mindspore.parallel.function.reshard

mindspore.parallel.function.reshard(tensor, layout)[source]

Converts a tensor from one distributed arrangement to another. The given layout must be of type mindspore.parallel.Layout; see mindspore.parallel.Layout for details.

Note

  • In Graph mode, this function can set the sharding strategy of a tensor for sharding propagation. For tensors whose strategies are not set manually, the strategies are decided automatically by the sharding strategy propagation algorithm.

  • In PyNative mode, you can use this method to reshard tensors inside a cell that is executed in parallel in graph mode (that is, a cell that uses Cell.shard/F.shard in PyNative mode).

Parameters
  • tensor (Tensor) – The tensor whose sharding strategy is to be set.

  • layout (Layout) – The layout used to shard the tensor precisely, including the device arrangement (device_matrix) and the aliases for the axes of the device matrix (alias_name).
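
For illustration, a layout like the one used in the Examples section below can be constructed as follows; the axis aliases "dp" and "mp" are arbitrary names chosen for the device matrix axes:

>>> from mindspore.parallel import Layout
>>> # 8 devices arranged as a 4 x 2 device matrix with axis aliases "dp" and "mp"
>>> layout = Layout((4, 2), ("dp", "mp"))
>>> # Shard dimension 0 of the tensor over "dp" and dimension 1 over "mp"
>>> input_layout = layout("dp", "mp")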

Returns

Tensor. A tensor that is mathematically equivalent to the input tensor, redistributed according to the given layout.

Raises
  • TypeError – If tensor is not of type mindspore.Tensor.

  • TypeError – If layout is not of type mindspore.parallel.Layout.

Supported Platforms:

Ascend

Examples

Note

Before running the following examples, you need to configure the communication environment variables.

For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method, which has no third-party or configuration file dependencies. Please see msrun startup for more details.

This example should be run with 8 devices.

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import ops, nn, Tensor, context
>>> from mindspore.parallel import Layout
>>> from mindspore.parallel.function import reshard
>>> from mindspore.nn.utils import no_init_parameters
>>> from mindspore.parallel.auto_parallel import AutoParallel
>>> from mindspore.communication import init
>>> context.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> # Network that reshards its activation to the given layout before the MatMul
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.matmul = ops.MatMul()
...         self.relu = ops.ReLU()
...     def construct(self, x, layout):
...         x = self.relu(x)
...         x_reshard = reshard(x, layout)
...         y = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
...         x = self.matmul(x_reshard, y)
...         return x
>>> layout = Layout((4, 2), ("dp", "mp"))
>>> input_layout = layout("dp", "mp")
>>> with no_init_parameters():
...     net = Network()
>>> # Use sharding propagation so operators without manual strategies get them automatically
>>> parallel_net = AutoParallel(net, parallel_mode='sharding_propagation')
>>> tensor = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
>>> out = parallel_net(tensor, input_layout)
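
In this example, layout("dp", "mp") shards dimension 0 of the input over the "dp" axis (4 slices) and dimension 1 over the "mp" axis (2 slices), so after reshard each of the 8 devices is expected to hold a 32 x 64 slice of the 128 x 128 tensor; the sharding strategy propagation algorithm then derives strategies for the surrounding ReLU and MatMul automatically.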