Operation Overloading



mindspore.ops.composite provides operator combinations related to graph transformation, such as MultitypeFuncGraph and HyperMap.


MultitypeFuncGraph is used to generate overloaded functions that support different input types. Users can use MultitypeFuncGraph to define a group of overloaded functions, and the implementation that runs is selected according to the types of the inputs. To use it, first initialize a MultitypeFuncGraph object, then decorate each implementation with the object's register method, passing the input types that implementation accepts. The object can then be called with inputs of any registered types. For more instructions, see MultitypeFuncGraph.

A code example is as follows:

import numpy as np
from mindspore.ops import MultitypeFuncGraph
from mindspore import Tensor
import mindspore.ops as ops

add = MultitypeFuncGraph('add')
@add.register("Number", "Number")
def add_scalar(x, y):
    return ops.scalar_add(x, y)

@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    return ops.add(x, y)

tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
print('tensor', add(tensor1, tensor2))
print('scalar', add(1, 2))

The following information is displayed:

tensor [[2.4 4.2]
 [4.4 6.4]]
scalar 3
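To make the dispatch idea concrete without a MindSpore installation, the following is a minimal pure-Python sketch of a type-keyed overload registry. It is an illustrative analogue only, not how MultitypeFuncGraph is implemented: the class TypeDispatcher, its helper _type_name, and the "List" and "Bool" type names are hypothetical, and the sketch resolves overloads at call time in plain Python rather than during graph compilation. Only the "Number" type string is borrowed from the example above.

```python
from typing import Any, Callable, Dict, Tuple


class TypeDispatcher:
    """Hypothetical pure-Python analogue of MultitypeFuncGraph (illustration only).

    Each registered implementation is stored under the tuple of type names it
    accepts; calling the object picks the implementation whose key matches the
    type names of the actual arguments.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self._registry: Dict[Tuple[str, ...], Callable[..., Any]] = {}

    def register(self, *type_names: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        # Decorator factory: store fn under the given tuple of type names.
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self._registry[type_names] = fn
            return fn
        return decorator

    @staticmethod
    def _type_name(value: Any) -> str:
        # Coarse, assumed type names; bool is checked first because it subclasses int.
        if isinstance(value, bool):
            return "Bool"
        if isinstance(value, (int, float)):
            return "Number"
        if isinstance(value, list):
            return "List"
        return type(value).__name__

    def __call__(self, *args: Any) -> Any:
        key = tuple(self._type_name(a) for a in args)
        try:
            fn = self._registry[key]
        except KeyError:
            raise TypeError(f"{self.name}: no overload registered for {key}") from None
        return fn(*args)


add = TypeDispatcher("add")


@add.register("Number", "Number")
def add_scalar(x, y):
    return x + y


@add.register("List", "List")
def add_list(x, y):
    # Element-wise addition of two equal-length lists.
    return [a + b for a, b in zip(x, y)]


print("scalar", add(1, 2))                   # scalar 3
print("list", add([1.5, 2.0], [0.5, 1.0]))   # list [2.0, 3.0]
```

Keying the registry on a tuple of type names mirrors the register("Number", "Number") call style above; an unregistered combination such as add("a", 1) raises TypeError in this sketch.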


HyperMap applies a specified operation to one or more input sequences and can be used together with MultitypeFuncGraph. For example, after defining a group of overloaded add functions, we can apply the add operation to multiple groups of inputs of different types. Unlike Map, HyperMap also works on nested structures: it applies the specified operation to the elements of a sequence or a nested sequence. For more instructions, see HyperMap.

A code example is as follows:

from mindspore import dtype as mstype
from mindspore import Tensor
from mindspore.ops import MultitypeFuncGraph, HyperMap
import mindspore.ops as ops

add = MultitypeFuncGraph('add')
@add.register("Number", "Number")
def add_scalar(x, y):
    return ops.scalar_add(x, y)

@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    return ops.add(x, y)

add_map = HyperMap(add)
output = add_map((Tensor(1, mstype.float32), Tensor(2, mstype.float32), 1),
                 (Tensor(3, mstype.float32), Tensor(4, mstype.float32), 2))
print("output =", output)

The following information is displayed:

output = (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 6), 3)

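In this example, add_map receives two sequences. HyperMap takes the elements at the same position i from both sequences and passes them to add as x and y, in the form operation(args[0][i], args[1][i]). For example, the first call is add(Tensor(1, mstype.float32), Tensor(3, mstype.float32)). Because add is a MultitypeFuncGraph, each call is dispatched by input type: the two Tensor pairs use add_tensor, while the scalar pair (1, 2) uses add_scalar, which produces the 3 in the output.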
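The pairing rule operation(args[0][i], args[1][i]) and the support for nested structures can be illustrated with a short pure-Python sketch. This is a hypothetical analogue, not MindSpore's HyperMap: the function hyper_map is invented for illustration, it treats only Python tuples and lists as sequences (everything else is a leaf), and it runs eagerly instead of inside a compiled graph.

```python
from typing import Any, Callable


def hyper_map(fn: Callable[..., Any], *seqs: Any) -> Any:
    """Hypothetical pure-Python analogue of HyperMap (illustration only).

    Applies fn to the i-th elements of all inputs, recursing into nested
    tuples and lists so that the output keeps the input structure.
    """
    first = seqs[0]
    if isinstance(first, (tuple, list)):
        # At this level every input must be the same kind of sequence with the same length.
        if not all(isinstance(s, type(first)) and len(s) == len(first) for s in seqs):
            raise ValueError("inputs must have matching structure")
        # Pair up position i across all inputs, i.e. fn(args[0][i], args[1][i], ...), recursively.
        return type(first)(hyper_map(fn, *items) for items in zip(*seqs))
    # Leaf: call fn directly on the corresponding leaves.
    return fn(*seqs)


def add(x, y):
    return x + y


# Flat case, shaped like the example above (plain floats stand in for Tensors).
print(hyper_map(add, (1.0, 2.0, 1), (3.0, 4.0, 2)))                  # (4.0, 6.0, 3)
# Nested case: the recursion descends into inner tuples and lists.
print(hyper_map(add, (1, (2, [3, 4])), (10, (20, [30, 40]))))        # (11, (22, [33, 44]))
```

The recursion is what distinguishes this from a flat map: a plain map over the outer tuple would try to add (2, [3, 4]) and (20, [30, 40]) directly, whereas hyper_map descends until it reaches leaves.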