mindspore.reshard

mindspore.reshard(tensor, layout)

Specifies the sharding layout of a tensor. The layout must be of type mindspore.Layout; see mindspore.Layout for reference.

  • In Graph mode, this function sets the sharding strategy of a tensor for sharding propagation. Tensors whose strategies are not set manually have them determined automatically by the sharding strategy propagation algorithm.

  • In PyNative mode, this function sets the sharding strategy of a tensor inside a Cell that runs in Graph mode (i.e. a Cell processed by Cell.shard/F.shard).

Note

  • In auto parallel mode, an exception is raised if the search mode is not "sharding_propagation".

  • In semi-auto parallel mode, the parallel mode automatically switches to auto parallel mode with the search mode set to "sharding_propagation".

Parameters
  • tensor (Tensor) – The tensor whose sharding strategy is to be set.

  • layout (Layout) – The layout that specifies precisely how to shard the tensor, including the device arrangement (device_matrix) and the aliases for the axes of the device matrix (alias_name).
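
To make the layout parameter more concrete, the following plain-Python sketch estimates the per-device slice shape implied by a device matrix, its aliases, and a tensor map. The helper shard_shape is hypothetical and not part of the MindSpore API; the sketch assumes each tensor dimension is divided evenly by the size of the device axis it is mapped to, and that the string "None" marks a dimension that is not sharded.

```python
# Hypothetical helper for illustration only; not part of the MindSpore API.
def shard_shape(tensor_shape, device_matrix, alias_name, tensor_map):
    """Return the per-device slice shape of a tensor under a layout."""
    # Map each alias (e.g. "dp") to the size of its device axis (e.g. 4).
    axis_size = dict(zip(alias_name, device_matrix))
    local = []
    for dim, axis in zip(tensor_shape, tensor_map):
        # Assumption: "None" means the dimension is kept whole on every device.
        parts = 1 if axis == "None" else axis_size[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        local.append(dim // parts)
    return tuple(local)


# Device matrix (4, 2) with aliases ("dp", "mp"), as in the example on this page.
full = shard_shape((128, 128), (4, 2), ("dp", "mp"), ("dp", "mp"))
rows_only = shard_shape((128, 128), (4, 2), ("dp", "mp"), ("dp", "None"))
print(full)       # (32, 64)
print(rows_only)  # (32, 128)
```

Under these assumptions, a (128, 128) tensor sharded as ("dp", "mp") on a (4, 2) device matrix leaves a (32, 64) slice on each of the 8 devices.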

Returns

Tensor. Mathematically equivalent to the input tensor.
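
One way to read "mathematically equivalent" is that resharding changes where slices of the tensor are placed, not the values they hold. The sketch below illustrates this in plain Python by splitting a small matrix into a 4 x 2 grid of blocks and reassembling it. The helpers split_blocks and merge_blocks are hypothetical, do not call MindSpore, and assume the matrix divides evenly into the grid.

```python
# Hypothetical helpers for illustration only; they do not use MindSpore.
def split_blocks(matrix, row_parts, col_parts):
    """Split a 2-D list into a grid of blocks keyed by (row_index, col_index)."""
    rows, cols = len(matrix), len(matrix[0])
    rh, cw = rows // row_parts, cols // col_parts  # block height and width
    return {
        (i, j): [row[j * cw:(j + 1) * cw] for row in matrix[i * rh:(i + 1) * rh]]
        for i in range(row_parts)
        for j in range(col_parts)
    }


def merge_blocks(blocks, row_parts, col_parts):
    """Reassemble the full 2-D list from the grid of blocks."""
    merged = []
    for i in range(row_parts):
        for r in range(len(blocks[(i, 0)])):
            row = []
            for j in range(col_parts):
                row.extend(blocks[(i, j)][r])
            merged.append(row)
    return merged


matrix = [[r * 8 + c for c in range(8)] for r in range(8)]
blocks = split_blocks(matrix, 4, 2)      # one block per device in a (4, 2) grid
restored = merge_blocks(blocks, 4, 2)
print(len(blocks), len(blocks[(0, 0)]), len(blocks[(0, 0)][0]))  # 8 2 4
print(restored == matrix)  # True
```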

Raises
  • TypeError – If tensor is not of type Tensor.

  • TypeError – If layout is not of type mindspore.Layout.

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import ops, nn, Tensor, context, Layout
>>> context.set_context(mode=ms.GRAPH_MODE)
>>> context.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
...                                   search_mode="sharding_propagation")
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.matmul = ops.MatMul()
...         self.relu = ops.ReLU()
...     def construct(self, x, layout):
...         x = self.relu(x)
...         # Reshard x according to the layout passed to construct.
...         x_reshard = ms.reshard(x, layout)
...         y = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
...         x = self.matmul(x_reshard, y)
...         return x
...
>>> layout = Layout((4, 2), ("dp", "mp"))
>>> input_layout = layout("dp", "mp")
>>> net = Network()
>>> tensor = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
>>> out = net(tensor, input_layout)