mindspore.parallel.function.reshard
- mindspore.parallel.function.reshard(tensor, layout)
Converts a tensor from one distributed arrangement to another. The given layout must be of type mindspore.parallel.Layout; see mindspore.parallel.Layout for reference.
Note
In Graph mode, this function can set the sharding strategy of a tensor for sharding propagation. For tensors whose strategies are not set manually, the strategies are decided automatically by the sharding strategy propagation algorithm.
In PyNative mode, you can use this function to arrange tensors inside a cell that is executed in parallel in graph mode (that is, a cell that uses Cell.shard/F.shard in PyNative mode).
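To make "converting from one distributed arrangement to another" concrete, here is a minimal pure-Python sketch that models only the global view of the data. The helpers `split` and `merge` are hypothetical names introduced for illustration and are not part of MindSpore. The sketch assumes even block splits of a 2-D tensor and does not model how the real operator moves data between devices.

```python
def split(matrix, row_parts, col_parts):
    """Cut a 2-D list into row_parts x col_parts equal blocks, keyed by block index."""
    rows, cols = len(matrix), len(matrix[0])
    rh, cw = rows // row_parts, cols // col_parts
    return {
        (i, j): [row[j * cw:(j + 1) * cw] for row in matrix[i * rh:(i + 1) * rh]]
        for i in range(row_parts)
        for j in range(col_parts)
    }


def merge(blocks, row_parts, col_parts):
    """Reassemble the blocks produced by split() into the full 2-D list."""
    full = []
    for i in range(row_parts):
        for r in range(len(blocks[(i, 0)])):
            row = []
            for j in range(col_parts):
                row.extend(blocks[(i, j)][r])
            full.append(row)
    return full


# A small 4 x 4 "global" tensor.
global_tensor = [[r * 4 + c for c in range(4)] for r in range(4)]

# Arrangement A: rows split 4 ways, columns not split.
shards_a = split(global_tensor, 4, 1)

# Conceptual reshard: recover the global view, then cut it into 2 x 2 blocks.
shards_b = split(merge(shards_a, 4, 1), 2, 2)

# The logical value is unchanged; only the per-device slices differ.
assert merge(shards_b, 2, 2) == global_tensor
```

In this model each "device" holds a different slice before and after the conversion, yet reassembling the slices yields the same tensor in both cases.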
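- Parameters
tensor (Tensor) - The tensor to be resharded.
layout (Layout) - The target distributed arrangement for the tensor, of type mindspore.parallel.Layout.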
- Returns
Tensor. Mathematically equivalent to the input tensor.
- Raises
- Supported Platforms:
Ascend
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For Ascend/GPU/CPU devices, the msrun startup method is recommended, since it has no third-party or configuration-file dependencies. Please see msrun start-up for more details.
This example should be run with 8 devices.
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import ops, nn, Tensor, context, Layout
>>> from mindspore.parallel.function import reshard
>>> from mindspore.nn.utils import no_init_parameters
>>> from mindspore.parallel.auto_parallel import AutoParallel
>>> from mindspore.communication import init
>>> context.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> class Network(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.matmul = ops.MatMul()
...         self.relu = ops.ReLU()
...     def construct(self, x, layout):
...         x = self.relu(x)
...         x_reshard = reshard(x, layout)
...         y = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
...         x = self.matmul(x_reshard, y)
...         return x
>>> layout = Layout((4, 2), ("dp", "mp"))
>>> input_layout = layout("dp", "mp")
>>> with no_init_parameters():
...     net = Network()
>>> parallel_net = AutoParallel(net, parallel_mode='sharding_propagation')
>>> tensor = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
>>> out = parallel_net(tensor, input_layout)
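As a rough check on what the layout in the example above implies, the following sketch computes the slice shape each device would hold. It runs without MindSpore or any devices. The helper `local_shape` is hypothetical and not a MindSpore API. It assumes that `layout("dp", "mp")` maps tensor dimension 0 to the "dp" axis (4 devices) and dimension 1 to the "mp" axis (2 devices), that each mapped dimension is split evenly, and that the string "None" marks an unsplit dimension.

```python
def local_shape(global_shape, device_matrix, alias_name, tensor_map):
    """Return the per-device slice shape, assuming even splits.

    tensor_map[i] names the device-matrix axis that tensor dimension i is
    split along, or "None" (assumed convention) if it is not split.
    """
    axis_size = dict(zip(alias_name, device_matrix))
    shape = []
    for dim, axis in zip(global_shape, tensor_map):
        parts = 1 if axis == "None" else axis_size[axis]
        if dim % parts != 0:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)


# Values taken from the example: 8 devices arranged as (4, 2), 128 x 128 tensor.
print(local_shape((128, 128), (4, 2), ("dp", "mp"), ("dp", "mp")))  # (32, 64)
```

Under these assumptions, each of the 8 devices holds a 32 x 64 slice of the 128 x 128 input after `reshard`.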