mindspore.reshard
- mindspore.reshard(tensor, layout)
Set the sharding layout of a tensor to the given layout. The layout must be of type mindspore.Layout; see mindspore.Layout for details. In Graph mode, this function sets the sharding propagation strategy of a tensor. The strategies of tensors that are not set manually are decided automatically by the sharding strategy propagation algorithm.
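To make the layout semantics concrete, here is a pure-Python sketch (no MindSpore required) of how a device matrix, its axis aliases, and a per-dimension tensor map determine the slice each device holds. The helper `shard_shape` is hypothetical and not a MindSpore API. Treating the string "None" as an unsplit dimension and requiring each dimension to divide evenly are assumptions modeled on how mindspore.Layout is used; the numbers reuse Layout((4, 2), ("dp", "mp")) from the Examples section.

```python
# Hypothetical helper, NOT part of MindSpore: a pure-Python model of how a
# Layout-style tensor map splits a tensor across a device matrix.
from math import prod
from typing import Sequence, Tuple, Union

AxisSpec = Union[str, Tuple[str, ...]]


def shard_shape(tensor_shape: Sequence[int],
                device_matrix: Sequence[int],
                alias_name: Sequence[str],
                tensor_map: Sequence[AxisSpec]) -> Tuple[int, ...]:
    """Return the per-device slice shape of a tensor laid out by tensor_map.

    Each tensor_map entry names the device-matrix axis (or a tuple of axes)
    that splits the matching tensor dimension; "None" (assumed convention)
    means that dimension is not split.
    """
    if len(device_matrix) != len(alias_name):
        raise ValueError("device_matrix and alias_name must have equal length")
    if len(tensor_map) != len(tensor_shape):
        raise ValueError("tensor_map needs one entry per tensor dimension")
    axis_size = dict(zip(alias_name, device_matrix))
    result = []
    for dim, spec in zip(tensor_shape, tensor_map):
        names = spec if isinstance(spec, tuple) else (spec,)
        # Multiply the sizes of all device axes splitting this dimension;
        # prod() of an empty selection is 1, so "None" leaves it whole.
        splits = prod(axis_size[n] for n in names if n != "None")
        if dim % splits != 0:
            raise ValueError(f"dimension {dim} is not divisible by {splits}")
        result.append(dim // splits)
    return tuple(result)


# 8 devices arranged as (4, 2) and named ("dp", "mp").
print(shard_shape((128, 128), (4, 2), ("dp", "mp"), ("dp", "mp")))    # (32, 64)
print(shard_shape((128, 128), (4, 2), ("dp", "mp"), ("dp", "None")))  # (32, 128)
```

Under this model, layout("dp", "mp") on a 128 x 128 tensor gives each of the 8 devices a 32 x 64 slice: rows are split 4 ways along "dp" and columns 2 ways along "mp".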
In PyNative mode, this function can set the sharding strategy of a tensor inside a Cell that runs in Graph mode (i.e., a Cell processed by Cell.shard/F.shard).
Note
In auto parallel mode, an exception is thrown if the search mode is not "sharding_propagation".
In semi-auto parallel mode, the parallel mode is automatically switched to auto parallel mode with the search mode set to "sharding_propagation".
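The two rules in the Note amount to a small configuration check. The sketch below models them in plain Python. `effective_parallel_config` is hypothetical, the mode strings are the values accepted by set_auto_parallel_context, and RuntimeError is only a stand-in for whatever exception MindSpore actually raises, which this page does not name.

```python
# Hypothetical model of the two rules in the Note; NOT MindSpore code.
from typing import Tuple

SHARDING_PROPAGATION = "sharding_propagation"


def effective_parallel_config(parallel_mode: str,
                              search_mode: str) -> Tuple[str, str]:
    """Return the (parallel_mode, search_mode) pair that reshard works under."""
    if parallel_mode == "auto_parallel":
        # Auto parallel: any search mode other than sharding propagation fails.
        if search_mode != SHARDING_PROPAGATION:
            # RuntimeError is a stand-in; the real exception type is not documented here.
            raise RuntimeError(
                f"reshard requires search_mode={SHARDING_PROPAGATION!r}, "
                f"got {search_mode!r}")
        return parallel_mode, search_mode
    if parallel_mode == "semi_auto_parallel":
        # Semi-auto parallel: switched to auto parallel with sharding propagation.
        return "auto_parallel", SHARDING_PROPAGATION
    # The Note says nothing about other modes, so this sketch refuses to guess.
    raise ValueError(f"mode {parallel_mode!r} is not covered by this sketch")


print(effective_parallel_config("semi_auto_parallel", "recursive_programming"))
```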
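- Parameters
tensor (Tensor): The tensor whose sharding strategy is to be set.
layout (Layout): The layout used to shard the tensor, including the device arrangement (device_matrix) and the aliases of the device matrix axes (alias_name).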
- Returns
Tensor, mathematically equivalent to the input tensor.
- Raises
TypeError: If tensor is not a Tensor.
TypeError: If layout is not a mindspore.Layout.
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import ops, nn, Tensor, context, Layout
>>> from mindspore.communication import init
>>> # Run on 8 devices: the device matrix (4, 2) below spans 4 * 2 = 8 ranks.
>>> context.set_context(mode=ms.GRAPH_MODE)
>>> context.set_auto_parallel_context(parallel_mode=ms.ParallelMode.AUTO_PARALLEL,
...                                   search_mode="sharding_propagation")
>>> init()
>>> class Network(nn.Cell):
...     def __init__(self, layout):
...         super().__init__()
...         self.matmul = ops.MatMul()
...         self.relu = ops.ReLU()
...         self.layout = layout  # layout applied to the ReLU output
...         self.y = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
...     def construct(self, x):
...         x = self.relu(x)
...         # Fix the layout of the ReLU output; strategies of the other
...         # tensors are derived by sharding propagation.
...         x_reshard = ms.reshard(x, self.layout)
...         return self.matmul(x_reshard, self.y)
>>>
>>> layout = Layout((4, 2), ("dp", "mp"))
>>> input_layout = layout("dp", "mp")
>>> net = Network(input_layout)
>>> tensor = Tensor(np.ones(shape=(128, 128)), dtype=ms.float32)
>>> out = net(tensor)