mindspore.vmap

mindspore.vmap(fn, in_axes=0, out_axes=0)[source]

Vectorizing map (vmap) is a higher-order function that maps fn along the parameter axes.

Vmap was pioneered by JAX. It removes the batch-dimension restriction on operators and provides a more convenient and unified way to express them. It can also be composed with other functional modules such as mindspore.grad(), which improves development efficiency. In addition, the vectorizing map does not execute loops outside the function; it sinks the loops into the function's primitive operations for better performance. Combining it with Graph Kernel Fusion improves execution efficiency further.

Warning

This is an experimental API that is subject to change or deletion.

Note

1. The power of vmap comes from the VmapRules implemented for primitives. Although a generalized rule is provided for user custom operators, it is not guaranteed to work well for all operators, so be aware of the risk when relying on it. For better performance, refer to the tutorial to implement a specific VmapRule for the custom operator, which does not take much time.

2. When random number generation methods are called within the scope of vmap, the same random number is generated in every vector branch each time. If you expect each vector branch to use different random numbers, generate the batch of random numbers externally in advance and then pass them to vmap.

Parameters
  • fn (Union[Cell, Function, CellList]) – Function to be mapped along the parameter axes. It takes at least one argument and returns one or more Tensors or values of a data type supported by the MindSpore Tensor. When fn is a CellList (the model ensembling scenario), make sure that every cell has the same structure and that the number of cells equals the size of the mapped axes (axis_size).

  • in_axes (Union[int, list, tuple]) – Specifies which dimensions (axes) of the inputs should be mapped over. If in_axes is an integer, all arguments of fn are mapped over that axis index. If in_axes is a tuple or list, it must consist only of integers or Nones, its length must equal the number of positional arguments of fn, and each element indicates which axis to map for the corresponding positional argument. Axis integers must be in the range \([-ndim, ndim)\) for each argument, where ndim is the number of dimensions of the corresponding argument. None means the argument is not mapped along any axis. At least one positional argument must have a non-None mapping axis in in_axes. The sizes of the mapped axes (axis_size) must be equal for all arguments. Default: 0.

  • out_axes (Union[int, list, tuple]) – Specifies where the mapped dimensions (axes) should appear in the outputs. If out_axes is an integer, all outputs of fn use that axis. If out_axes is a tuple or list, it must consist only of integers or Nones, and its length must equal the number of outputs of fn. Axis integers must be in the range \([-ndim, ndim)\) for each output, where ndim is the number of dimensions of the corresponding output of the vmap-mapped function. Every output with a non-None mapped axis must specify a non-None out_axes; if an output with a None mapped axis specifies a non-None out_axes, the result is broadcast across the mapped axis. Default: 0.
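The in_axes and out_axes semantics described above can be modeled with a plain NumPy loop. The sketch below is only an illustration of what the mapped function computes, not how MindSpore implements it (vmap sinks the loop into primitives rather than looping in Python). The names reference_vmap and add3 are hypothetical, and the sketch assumes a single-output fn, a flat in_axes, and an integer out_axes.

```python
import numpy as np

def reference_vmap(fn, in_axes=0, out_axes=0):
    """Loop-based model of vmap semantics (hypothetical helper, not a MindSpore API)."""
    def batched(*args):
        # An integer in_axes applies to every positional argument.
        axes = (in_axes,) * len(args) if isinstance(in_axes, int) else tuple(in_axes)
        # axis_size is the shared length of every mapped axis.
        sizes = {np.shape(a)[ax] for a, ax in zip(args, axes) if ax is not None}
        assert len(sizes) == 1, "all mapped axes must have the same size"
        axis_size = sizes.pop()
        results = []
        for i in range(axis_size):
            # Take slice i of each mapped argument; pass None-axis arguments whole.
            sliced = [a if ax is None else np.take(a, i, axis=ax)
                      for a, ax in zip(args, axes)]
            results.append(fn(*sliced))
        # Stack the per-slice results so the batch dimension lands at out_axes.
        return np.stack(results, axis=out_axes)
    return batched

def add3(x, y, z):
    return x + y + z

x = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)    # [b, a]
y = np.array([[-3, -2, -1], [3, 2, 1]], dtype=np.float32)   # [a, b]
z = np.array([0, 3], dtype=np.float32)                      # [a], not mapped
output = reference_vmap(add3, in_axes=(0, 1, None), out_axes=1)(x, y, z)
print(output.shape)  # (2, 3), i.e. [a, b]
```

Here x is sliced along axis 0, y along axis 1, and z is passed unchanged to every call; out_axes=1 places the batch dimension second in the result.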

Returns

Function, the vectorized (batched) version of fn. Its arguments and outputs correspond to those of fn, with an extra batch dimension added at the positions specified by in_axes and out_axes.

Raises

RuntimeError –
  • If a base element of in_axes or out_axes is neither None nor an integer.
  • If all base elements of in_axes or out_axes are None.
  • If in_axes is not a single integer and its length is not equal to the number of arguments.
  • If out_axes is not a single integer and its length is not equal to the number of outputs.
  • If the axis_size of the arguments in the scope of vmap is not the same for every argument.
  • If an axis in in_axes or out_axes is out of bounds.
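The in_axes conditions listed above can be restated as a small pure-Python checker. This is a minimal sketch under stated assumptions: check_in_axes is a hypothetical helper, not part of MindSpore, it handles only a flat in_axes, and its error messages are illustrative rather than the ones MindSpore emits.

```python
def check_in_axes(in_axes, arg_shapes):
    """Hypothetical checker mirroring the documented in_axes conditions.

    arg_shapes holds one shape tuple per positional argument.
    Returns axis_size when every check passes.
    """
    if isinstance(in_axes, (list, tuple)):
        if len(in_axes) != len(arg_shapes):
            raise RuntimeError("length of in_axes must equal the number of arguments")
        axes = list(in_axes)
    else:
        # A single value applies to every positional argument.
        axes = [in_axes] * len(arg_shapes)

    for axis in axes:
        if axis is not None and not isinstance(axis, int):
            raise RuntimeError("each element of in_axes must be None or an integer")
    if all(axis is None for axis in axes):
        raise RuntimeError("at least one element of in_axes must not be None")

    sizes = set()
    for axis, shape in zip(axes, arg_shapes):
        if axis is None:
            continue  # this argument is not mapped
        ndim = len(shape)
        if not -ndim <= axis < ndim:
            raise RuntimeError("axis is out of bounds")
        sizes.add(shape[axis])
    if len(sizes) != 1:
        raise RuntimeError("axis_size must be equal for all mapped arguments")
    return sizes.pop()

# Shapes [3, 2], [2, 3] and [2] mapped over axes 0, 1 and None share axis_size 3.
print(check_in_axes((0, 1, None), [(3, 2), (2, 3), (2,)]))  # 3
```

Mapping shapes [3, 2] and [2, 3] with in_axes=0 would fail the last check, because the mapped axes have sizes 3 and 2.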

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import vmap
>>> def test_vmap(x, y, z):                                              # ([a],[a],[a]) -> [a]
...     return x + y + z
>>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32))    # [b, a]
>>> y = Tensor(np.array([[-3, -2, -1], [3, 2, 1]]).astype(np.float32))   # [a, b]
>>> z = Tensor(np.array([0, 3]).astype(np.float32))                      # [a]
>>> output = vmap(test_vmap, in_axes=(0, 1, None), out_axes=1)(x, y, z)  # ([b, a],[a, b],[a]) -> [a, b]
>>> print(output)
[[-2  1  4]
 [ 8  9 10]]