mindspore.parallel.function.reshard

mindspore.parallel.function.reshard(tensor, layout)

Set the sharding layout of a tensor. The given layout must be of type mindspore.parallel.Layout; see mindspore.parallel.Layout for reference.

  • In Graph mode, this function sets the sharding strategy of a tensor for sharding propagation. The strategies of tensors that are not set manually are decided automatically by the sharding strategy propagation algorithm.

  • In PyNative mode, this function sets the sharding strategy of a tensor inside a Cell that runs in Graph mode (i.e., a Cell processed by Cell.shard/F.shard).
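As a rough intuition for how manually set strategies coexist with automatically decided ones in Graph mode, the toy below fills in unset layouts along a chain of element-wise ops by copying from the nearest tensor whose layout was set. It is a deliberately naive stand-in, not MindSpore's sharding propagation algorithm, which considers redistribution cost across operators rather than walking a simple chain. The function `propagate` and the tuple encoding of layouts are hypothetical.

```python
def propagate(layouts):
    """Toy propagation: fill None entries from the nearest manually set neighbour.

    layouts[i] is the (hypothetical) layout of the i-th tensor along a chain of
    element-wise ops, or None if no layout was set manually for it.
    """
    filled = list(layouts)
    # Forward pass: an unset tensor inherits the layout of its producer.
    for i in range(1, len(filled)):
        if filled[i] is None:
            filled[i] = filled[i - 1]
    # Backward pass: leading unset tensors inherit from their first consumer.
    for i in range(len(filled) - 2, -1, -1):
        if filled[i] is None:
            filled[i] = filled[i + 1]
    return filled


print(propagate([None, None, ("dp", "mp"), None]))
# [('dp', 'mp'), ('dp', 'mp'), ('dp', 'mp'), ('dp', 'mp')]
```

Only one tensor is set by hand here, yet every tensor on the chain ends up with a layout, which mirrors the documented behavior at a much coarser level.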

Note

  • In auto parallel mode, an exception is thrown if the search mode is not "sharding_propagation".

  • In semi-auto parallel mode, the parallel mode automatically switches to auto parallel mode with the search mode set to "sharding_propagation".

  • Currently, configuring multi-dimensional and multi-copy reshard strategies in mindspore.parallel.Layout is not supported.
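The two notes about parallel modes amount to a small decision rule. The hypothetical function below encodes that reading in plain Python; it is not a MindSpore API. The mode strings follow the names used by MindSpore's parallel context ("auto_parallel", "semi_auto_parallel"), and the exception type is an assumption, since the note only says that an exception is thrown.

```python
def resolve_reshard_mode(parallel_mode, search_mode):
    """Hypothetical model of the notes on parallel modes; not a MindSpore API.

    Returns the (parallel_mode, search_mode) pair that reshard would run under.
    """
    if parallel_mode == "auto_parallel":
        # Auto parallel: reshard is only accepted with sharding propagation.
        # RuntimeError is an assumed exception type.
        if search_mode != "sharding_propagation":
            raise RuntimeError(
                "reshard requires search_mode='sharding_propagation', got "
                f"{search_mode!r}"
            )
        return parallel_mode, search_mode
    if parallel_mode == "semi_auto_parallel":
        # Semi-auto parallel: switched to auto parallel + sharding propagation.
        return "auto_parallel", "sharding_propagation"
    # Other modes are not covered by the notes; passed through unchanged here.
    return parallel_mode, search_mode


print(resolve_reshard_mode("semi_auto_parallel", "recursive_programming"))
# ('auto_parallel', 'sharding_propagation')
```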

Parameters
  • tensor (Tensor) – The tensor whose sharding strategy is to be set.

  • layout (Layout) – The layout that precisely specifies how to shard the tensor, including the device arrangement (device_matrix) and the aliases for the device matrix dimensions (alias_name).
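To make the roles of device_matrix and alias_name concrete, the sketch below computes which slice of a tensor each device would hold. It is a plain-Python model, not MindSpore code. The helper `shard_slices` is hypothetical; it assumes each tensor dimension is split evenly along the device-matrix axis named for it (with None meaning the dimension is replicated), and it assumes ranks enumerate device-matrix coordinates in row-major order. With the layout used in the Examples section, Layout((4, 2), ("dp", "mp")) applied as layout("dp", "mp"), a 128 x 128 tensor is cut into eight 32 x 64 blocks.

```python
from itertools import product


def shard_slices(shape, device_matrix, alias_name, tensor_map):
    """Map each rank to the (start, stop) bounds it holds along every tensor dim.

    tensor_map[i] names the device-matrix axis that splits tensor dim i,
    or is None when that dim is replicated (not split). Hypothetical helper.
    """
    axis_of = {name: i for i, name in enumerate(alias_name)}
    result = {}
    # Assumption: ranks enumerate device-matrix coordinates in row-major order,
    # so for device_matrix (4, 2), rank 3 has coordinate (1, 1).
    coords = product(*(range(n) for n in device_matrix))
    for rank, coord in enumerate(coords):
        bounds = []
        for dim_size, alias in zip(shape, tensor_map):
            if alias is None:
                bounds.append((0, dim_size))  # replicated: full extent
                continue
            axis = axis_of[alias]
            parts = device_matrix[axis]
            if dim_size % parts:
                raise ValueError(f"dim of size {dim_size} is not divisible by {parts}")
            step = dim_size // parts
            bounds.append((coord[axis] * step, (coord[axis] + 1) * step))
        result[rank] = tuple(bounds)
    return result


slices = shard_slices((128, 128), (4, 2), ("dp", "mp"), ("dp", "mp"))
print(len(slices))  # 8
print(slices[0])    # ((0, 32), (0, 64))
print(slices[3])    # ((32, 64), (64, 128))
```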

Returns

Tensor. A tensor mathematically equivalent to the input tensor.
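"Mathematically equivalent" means resharding changes where the pieces of a tensor live, not their values. The NumPy sketch below illustrates this outside MindSpore: it cuts a matrix the way layout("dp", "mp") on a (4, 2) device matrix would, reassembles it, and then shows one way a following MatMul can still produce the exact result when the contraction dimension is split, by summing partial products. The summation stands in for an all-reduce; this is an illustration of the arithmetic, not of how MindSpore schedules communication.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((128, 128))
y = np.ones((128, 128))

# Cut x the way layout("dp", "mp") on a (4, 2) device matrix would:
# 4 row blocks along "dp", each split into 2 column blocks along "mp".
blocks = [np.hsplit(rows, 2) for rows in np.vsplit(x, 4)]

# Reassembling the 8 shards gives back x exactly: only placement changed.
x_back = np.block(blocks)

# x's columns are the contraction dim of x @ y. With them split over "mp",
# each "mp" device multiplies its column block by the matching row block of y;
# summing the partial products (an all-reduce over "mp") gives the full result.
y_parts = np.vsplit(y, 2)
row0 = sum(blocks[0][j] @ y_parts[j] for j in range(2))

print(np.array_equal(x_back, x))        # True
print(np.allclose(row0, (x @ y)[:32]))  # True
```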

Raises
  • TypeError – Reshard takes in Tensor type as the first input param, but got: type(tensor).

  • TypeError – Reshard only support type mindspore.parallel.Layout but got: type(layout).

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import ops, nn, Tensor, context, Layout
>>> from mindspore.parallel.function import reshard
>>> from mindspore.nn.utils import no_init_parameters
>>> from mindspore.parallel.auto_parallel import AutoParallel
>>> context.set_context(mode=ms.GRAPH_MODE)
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.matmul = ops.MatMul()
...         self.relu = ops.ReLU()
...     def construct(self, x, layout):
...         x = self.relu(x)
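...         # Redistribute x to the given layout; the values of x are unchanged.
...         x_reshard = reshard(x, layout)
...         y = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
...         x = self.matmul(x_reshard, y)
...         return x
>>> # Device matrix (4, 2) covers 8 devices: "dp" splits dim 0 by 4, "mp" splits dim 1 by 2.
>>> layout = Layout((4, 2), ("dp", "mp"))
>>> input_layout = layout("dp", "mp")
>>> with no_init_parameters():
...     net = Network()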
>>> parallel_net = AutoParallel(net, parallel_mode='sharding_propagation')
>>> tensor = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
>>> out = parallel_net(tensor, input_layout)