mindspore.ops.MultitypeFuncGraph
- class mindspore.ops.MultitypeFuncGraph(name, read_value=False)
Generates overloaded functions.
MultitypeFuncGraph is a class used to generate overloaded functions that dispatch on the types of their inputs. Initialize a MultitypeFuncGraph object with a name, then use its register method, called with the input types, as a decorator for each function to be registered. The object can then be called with inputs of different types, and it also works with HyperMap and Map.
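The registration-and-lookup idea described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not MindSpore's implementation: OverloadSketch and type_name are hypothetical names, only "Number" mirrors a type name used on this page, "String" is an assumed name, and Tensor inputs, Parameter handling and graph compilation are ignored. Like the documented class, it raises ValueError when no registered signature matches.

```python
from numbers import Number


def type_name(value):
    """Map a Python value to a coarse type name (hypothetical naming scheme)."""
    if isinstance(value, Number):
        return "Number"
    if isinstance(value, str):
        return "String"
    return type(value).__name__


class OverloadSketch:
    """Hypothetical stand-in for MultitypeFuncGraph: dispatches on argument type names."""

    def __init__(self, name):
        self.name = name
        self._entries = {}  # maps a tuple of type names -> registered function

    def register(self, *type_names):
        # Returns a decorator that records fn under the given signature.
        def decorator(fn):
            self._entries[tuple(type_names)] = fn
            return fn
        return decorator

    def __call__(self, *args):
        key = tuple(type_name(arg) for arg in args)
        fn = self._entries.get(key)
        if fn is None:
            raise ValueError(f"{self.name}: no function registered for {key}")
        return fn(*args)


add = OverloadSketch("add")


@add.register("Number", "Number")
def add_number(x, y):
    return x + y


@add.register("String", "String")
def add_string(x, y):
    return x + " " + y


print(add(1, 2))      # 3
print(add("a", "b"))  # a b
```

Keying the registry on the full tuple of type names means a call matches only an exact signature; any other combination falls through to the ValueError branch.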
- Parameters
name (str) – Name of the MultitypeFuncGraph object.
read_value (bool) – Default: False.
- Raises
ValueError – If no registered function matches the types of the given arguments.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> # `add` is a metagraph object that adds two objects according to their
>>> # input types, using the ".register" decorator.
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> from mindspore import dtype as mstype
>>> from mindspore.ops.composite import MultitypeFuncGraph
>>>
>>> tensor_add = ops.Add()
>>> add = MultitypeFuncGraph('add')
>>> @add.register("Number", "Number")
... def add_scala(x, y):
...     return x + y
>>> @add.register("Tensor", "Tensor")
... def add_tensor(x, y):
...     return tensor_add(x, y)
>>> output = add(1, 2)
>>> print(output)
3
>>> output = add(Tensor([0.1, 0.6, 1.2], dtype=mstype.float32), Tensor([0.1, 0.6, 1.2], dtype=mstype.float32))
>>> print(output)
[0.2 1.2 2.4]
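The description at the top of this page notes that the object also works with HyperMap and Map, which apply an operation to every element of a sequence or nested sequence. As a plain-Python sketch of that pairing, the following uses functools.singledispatch as a stand-in for MultitypeFuncGraph (note that it dispatches on the first argument's type only) and a hypothetical hyper_map_sketch helper in place of HyperMap. Neither is MindSpore API, and the real HyperMap may differ in details such as which container types it recurses into.

```python
from functools import singledispatch


@singledispatch
def square(x):
    # Fallback when no overload matches, mirroring the ValueError documented above.
    raise ValueError(f"no overload registered for {type(x).__name__}")


@square.register(int)
@square.register(float)
def _(x):
    return x * x


@square.register(str)
def _(x):
    return x + x


def hyper_map_sketch(fn, *seqs):
    """Hypothetical helper: apply fn to corresponding leaves of nested tuples/lists."""
    first = seqs[0]
    if isinstance(first, (tuple, list)):
        if any(len(s) != len(first) for s in seqs):
            raise ValueError("all sequences must have the same length")
        # Recurse element-wise and rebuild a container of the same type.
        return type(first)(hyper_map_sketch(fn, *items) for items in zip(*seqs))
    return fn(*seqs)


print(hyper_map_sketch(square, ((1, 2), (3.0, "ab"))))  # ((1, 4), (9.0, 'abab'))
```

Each leaf is dispatched independently, so one nested structure can mix types as long as every leaf type has a registered overload.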
- register(*type_names)
Register a function for the given type string.
- Parameters
type_names (Union[str, mindspore.dtype]) – Input type names or a list of types.
- Returns
decorator, a decorator that registers the function to run when the object is called with inputs of the types described in type_names.