mindspore.ops.HyperMap
- class mindspore.ops.HyperMap(ops=None, reverse=False)
HyperMap applies the given operation to the input sequences.
The operation is applied to every element of a sequence or nested sequence. Unlike Map, HyperMap also supports nested structures.
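The difference between a flat and a nested mapping can be illustrated in plain Python. The helper `nested_apply` below is hypothetical and is not MindSpore's implementation; it is only a minimal sketch of the recursion, assuming that containers are tuples or lists and that every other value is a leaf.

```python
def nested_apply(fn, data):
    """Apply fn to every leaf of a (possibly nested) tuple or list.

    Hypothetical helper for illustration only, not part of MindSpore.
    """
    if isinstance(data, (tuple, list)):
        # Recurse into the container and rebuild it with the same type.
        return type(data)(nested_apply(fn, item) for item in data)
    # A leaf value: apply the operation directly.
    return fn(data)


nested = ((1, 2), (3, 4))
squared = nested_apply(lambda x: x * x, nested)
print(squared)  # ((1, 4), (9, 16))
```

A flat map over `nested` would instead pass the inner tuples `(1, 2)` and `(3, 4)` to the operation, which is the case the nested variant avoids.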
- Parameters
ops (Union[MultitypeFuncGraph, None]) – The operation to apply. If ops is None, the operation must be passed as the first input when the instance is called. Default: None.
reverse (bool) – Whether to apply the operation in reverse order. In some scenarios the optimizer needs to run in reverse order to improve parallel performance; general users can ignore this flag. Only supported in graph mode. Default: False.
- Inputs:
args (Tuple[sequence]) - If ops is not None, all inputs must be sequences of the same length. The i-th elements of these sequences together form the inputs of one call to the operation.
If ops is None, the first input is the operation and the remaining inputs are the sequences it is applied to.
- Outputs:
Sequence or nested sequence holding the results of applying the operation. For example, with two inputs, element i of the output is operation(args[0][i], args[1][i]).
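The row-wise behavior with several input sequences can be sketched in plain Python as follows. The function `hyper_map_sketch` is a hypothetical name and not MindSpore's implementation; raising `ValueError` on a length mismatch is a choice made for this sketch and is not taken from the MindSpore documentation.

```python
def hyper_map_sketch(fn, *seqs):
    """Apply fn element-wise across equally shaped nested sequences.

    Hypothetical illustration only; assumes all inputs share the same
    nesting structure and that containers are tuples or lists.
    """
    first = seqs[0]
    if not isinstance(first, (tuple, list)):
        # Leaves reached: call the operation on the matching elements.
        return fn(*seqs)
    if any(len(s) != len(first) for s in seqs):
        # Assumption of this sketch: mismatched lengths are an error.
        raise ValueError("all input sequences must have the same length")
    # zip(*seqs) yields one "row" per index i: (seqs[0][i], seqs[1][i], ...)
    return type(first)(hyper_map_sketch(fn, *row) for row in zip(*seqs))


a = ((1, 2), (3, 4))
b = ((10, 20), (30, 40))
out = hyper_map_sketch(lambda x, y: x + y, a, b)
print(out)  # ((11, 22), (33, 44))
```

Each output element is computed from the elements at the same position in every input, which mirrors the operation(args[0][i], args[1][i]) pattern described above.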
- Raises
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> from mindspore.ops import MultitypeFuncGraph, HyperMap
>>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
>>> # Square every tensor in the nested tuple.
>>>
>>> square = MultitypeFuncGraph('square')
>>> @square.register("Tensor")
... def square_tensor(x):
...     return ops.square(x)
>>>
>>> # ops is None, so the operation is passed as the first input.
>>> common_map = HyperMap()
>>> output = common_map(square, nest_tensor_list)
>>> print(output)
((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
(Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
>>> # ops is given at construction, so only the sequences are passed.
>>> square_map = HyperMap(square, False)
>>> output = square_map(nest_tensor_list)
>>> print(output)
((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
(Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))