mindspore.ops.HyperMap

class mindspore.ops.HyperMap(ops=None, reverse=False)[源代码]

对输入序列做集合运算。

对序列的每个元素或嵌套序列进行运算。与 mindspore.ops.Map 不同,HyperMap 能够用于嵌套结构。

参数:

  • ops (Union[MultitypeFuncGraph, None]) – ops 是指定运算操作。如果 ops 为None,则运算应该作为 HyperMap 实例的第一个入参。默认值为None。

  • reverse (bool) - 在某些场景下,需要逆向以提高计算的并行性能,一般情况下,用户可以忽略。reverse 用于决定是否逆向执行运算,仅在图模式下支持。默认值为False。

输入:

  • args (Tuple[sequence]) - 如果 ops 不是None,则所有入参都应该是具有相同长度的序列,并且序列的每一行都是运算的输入。如果 ops 是None,则第一个入参是运算,其余都是输入。

Note

输入数量等于 ops 的输入数量。

输出:

序列或嵌套序列,执行函数如 operation(args[0][i], args[1][i]) 之后输出的序列。

异常:

  • TypeError - 如果 ops 既不是 MultitypeFuncGraph 也不是None。

  • TypeError - 如果 args 不是一个tuple。

支持平台:

Ascend GPU CPU

样例:

>>> from mindspore import Tensor, ops
>>> from mindspore.ops.composite.base import MultitypeFuncGraph, HyperMap
>>> from mindspore import dtype as mstype
>>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
>>> # square all the tensor in the nested list
>>>
>>> square = MultitypeFuncGraph('square')
>>> @square.register("Tensor")
... def square_tensor(x):
...     return ops.square(x)
>>>
>>> common_map = HyperMap()
>>> output = common_map(square, nest_tensor_list)
>>> print(output)
((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
(Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
>>> square_map = HyperMap(square, False)
>>> output = square_map(nest_tensor_list)
>>> print(output)
((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
(Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))