mindspore.ops.HyperMap

查看源文件
class mindspore.ops.HyperMap(ops=None, reverse=False)[源代码]

对输入序列做集合运算。

对序列的每个元素或嵌套序列进行运算。与 mindspore.ops.Map 不同,HyperMap 能够用于嵌套结构。

参数:
  • ops (Union[MultitypeFuncGraph, None],可选) - ops 是指定运算操作。如果 opsNone ,则运算应该作为 HyperMap 实例的第一个入参。默认值为 None

  • reverse (bool,可选) - 在某些场景下,需要逆向以提高计算的并行性能,一般情况下,用户可以忽略。 reverse 用于决定是否逆向执行运算,仅在图模式下支持。默认值为 False

输入:
  • args (Tuple[sequence]) -

    • 如果 ops 不是 None ,则所有入参都应该是具有相同长度的序列,并且序列的每一行都是运算的输入。

    • 如果 opsNone ,则第一个入参是运算,其余都是输入。

说明

输入数量等于 ops 的输入数量。

输出:

序列或嵌套序列,执行 operation(args[0][i], args[1][i]) 之后输出的序列,其中 operationops 指定的一个函数。

异常:
支持平台:

Ascend GPU CPU

样例:

>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
>>> # square all the tensor in the nested list
>>>
>>> square = ops.MultitypeFuncGraph('square')
>>> @square.register("Tensor")
... def square_tensor(x):
...     return ops.square(x)
>>>
>>> common_map = ops.HyperMap()
>>> output = common_map(square, nest_tensor_list)
>>> print(output)
((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
(Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
>>> square_map = ops.HyperMap(square, False)
>>> output = square_map(nest_tensor_list)
>>> print(output)
((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
(Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))