mindspore.ops.MultitypeFuncGraph

class mindspore.ops.MultitypeFuncGraph(name, read_value=False)[源代码]

MultitypeFuncGraph是一个用于生成重载函数的类,使用不同类型作为输入。使用 name 去初始化一个MultitypeFuncGraph对象,然后用带有输入类型的 register 注册器进行装饰注册类型。这样使该函数可以使用不同的类型作为输入调用,一般与 HyperMapMap 结合使用。

参数:
  • name (str) - 操作名。

  • read_value (bool, 可选) - 如果注册函数不需要对输入的值进行更改,即所有输入都为按值传递,则将 read_value 设置为 True 。默认值: False

异常:
  • ValueError - 找不到给定参数类型所匹配的函数。

支持平台:

Ascend GPU CPU

样例:

>>> # `add` is a metagraph object which will add two objects according to
>>> # input type using ".register" decorator.
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> from mindspore import dtype as mstype
>>> import mindspore.ops as ops
>>>
>>> tensor_add = ops.Add()
>>> add = ops.MultitypeFuncGraph('add')
>>> @add.register("Number", "Number")
... def add_scala(x, y):
...     return x + y
>>> @add.register("Tensor", "Tensor")
... def add_tensor(x, y):
...     return tensor_add(x, y)
>>> output = add(1, 2)
>>> print(output)
3
>>> output = add(Tensor([0.1, 0.6, 1.2], dtype=mstype.float32), Tensor([0.1, 0.6, 1.2], dtype=mstype.float32))
>>> print(output)
[0.2 1.2 2.4]
register(*type_names)[源代码]

根据给出的字符串内容注册不同输入类型的函数。

参数:
  • type_names (Union[str, mindspore.dtype]) - 输入类型的名或者一个类型列表。

返回:

装饰器, 一个根据 type_names 指定输入类型的注册函数的装饰器。