mindspore.shard
- mindspore.shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device='Ascend', level=0)
Define the input and output layouts of this cell; the parallel strategies of the remaining operators are generated by sharding propagation. In PyNative mode, use this method to specify a Cell for distributed execution in graph mode. In Graph mode, use this method to specify the distribution strategy for a Cell; the strategies of the other operators are set by sharding propagation. in_strategy and out_strategy define the input and output layouts, respectively. Each should be a tuple whose elements give the desired layout of the corresponding input or output; None represents data parallelism. For details, see the description of mindspore.ops.Primitive.shard(). The parallel strategies of the remaining operators are derived from the strategies specified for the inputs and outputs.
Note
If ms.shard is called, the parallel mode (parallel_mode) in set_auto_parallel_context will be set to "auto_parallel" and the search mode (search_mode) to "sharding_propagation". If the inputs contain a Parameter, its strategy should be set in in_strategy.
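To make the tuple(int) form of a strategy concrete, the sketch below computes the slice shape a single device would hold. It assumes, following the description of mindspore.ops.Primitive.shard(), that each integer is the number of equal slices along the matching tensor dimension. The helper local_slice_shape is hypothetical and not part of MindSpore.

```python
from math import prod
from typing import Sequence, Tuple


def local_slice_shape(shape: Sequence[int], strategy: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: shape of the slice held by one device.

    Assumption: strategy[i] is the number of equal cuts along dimension i.
    """
    if len(shape) != len(strategy):
        raise ValueError("strategy must have one entry per tensor dimension")
    for dim, cuts in zip(shape, strategy):
        if cuts <= 0 or dim % cuts != 0:
            raise ValueError(f"dimension {dim} cannot be split into {cuts} equal slices")
    return tuple(dim // cuts for dim, cuts in zip(shape, strategy))


# Case 2 of the Examples shards (32, 10) inputs with the strategy (4, 2).
print(local_slice_shape((32, 10), (4, 2)))  # (8, 5)
print(prod((4, 2)))                         # 8 slices in total
```

Under this reading, the strategy (4, 2) cuts a (32, 10) tensor into 8 slices of shape (8, 5), one per device when device_num is 8.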
- Parameters
fn (Union[Cell, Function]) – Function to be executed in parallel. Its arguments and return value must be Tensor or Parameter. If fn is a Cell with parameters, fn needs to be an instantiated object, otherwise its arguments cannot be accessed.
in_strategy (tuple) – Defines the layout of the inputs. Each element of the tuple should be a tuple(int) or a tuple(mindspore.Layout) and defines the layout of the corresponding input.
out_strategy (Union[tuple, None]) – Defines the layout of the outputs, in the same way as in_strategy. It is not in use right now. Default: None.
parameter_plan (Union[dict, None]) – Defines the layout of the specified parameters. Each entry of the dict defines the layout of one parameter as "param_name: layout". The key is a parameter name of type str. The value is a 1-D integer tuple or a 1-D mindspore.Layout tuple giving the corresponding layout. If the parameter name is incorrect or the corresponding parameter has already been set, the setting is ignored. Default: None.
device (string) – Selects the device target. It is not in use right now. Supports ["CPU", "GPU", "Ascend"]. Default: "Ascend".
level (int) – Option for the parallel strategy inference algorithm, namely the objective function: maximize the computation-over-communication ratio, maximize speed, minimize memory usage, etc. It is not in use right now. Supports [0, 1, 2]. Default: 0.
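The rule that incorrect or already-configured parameter names are ignored can be modeled in plain Python. In this sketch, resolve_parameter_plan is a hypothetical helper, known_params stands in for the cell's parameter names, and already_set stands in for parameters that already carry a layout; it only restates the documented behavior and is not MindSpore's implementation.

```python
from typing import Dict, Iterable, Optional, Tuple


def resolve_parameter_plan(plan: Optional[Dict[str, Tuple[int, ...]]],
                           known_params: Iterable[str],
                           already_set: Iterable[str]) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical model of which parameter_plan entries take effect."""
    if plan is None:  # Default None: no parameter gets an explicit layout.
        return {}
    known = set(known_params)
    done = set(already_set)
    applied = {}
    for name, layout in plan.items():
        # Entries with an incorrect name, or for a parameter whose layout
        # has already been set, are skipped instead of raising an error.
        if name in known and name not in done:
            applied[name] = layout
    return applied


# Hypothetical parameter names, for illustration only.
plan = {"block1.weight": (4, 1), "block9.weight": (2, 2), "block2.weight": (8, 1)}
print(resolve_parameter_plan(plan, ["block1.weight", "block2.weight"], ["block2.weight"]))
# {'block1.weight': (4, 1)}
```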
- Returns
Function, the function that will be executed under the auto-parallel process.
- Raises
AssertionError – If the parallel mode is neither "auto_parallel" nor "semi_auto_parallel".
AssertionError – If device_target is neither "Ascend" nor "GPU".
TypeError – If in_strategy is not a tuple.
TypeError – If out_strategy is not a tuple or None.
TypeError – If any element in in_strategy is not a tuple(int) or tuple(mindspore.Layout).
TypeError – If any element in out_strategy is not a tuple(int) or tuple(mindspore.Layout).
TypeError – If parameter_plan is not a dict or None.
TypeError – If any key in parameter_plan is not a str.
TypeError – If any value in parameter_plan is not a tuple(int) or a tuple(mindspore.Layout).
TypeError – If device is not a str.
TypeError – If level is not an integer.
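The TypeError conditions above can be read as one validation pass over the arguments. The sketch below is a hypothetical restatement written only from this list; it does not import MindSpore, skips the two AssertionError checks on the context, and uses a placeholder class FakeLayout in place of mindspore.Layout. Because the Examples pass layout objects directly, the sketch accepts either a tuple of ints or a layout object as a valid layout.

```python
from typing import Any


class FakeLayout:
    """Placeholder for mindspore.Layout (assumption: used only for isinstance checks)."""


def _is_valid_layout(value: Any, layout_type: type) -> bool:
    # Accept a tuple of ints such as (4, 2), or a layout object as in the Examples.
    if isinstance(value, layout_type):
        return True
    return isinstance(value, tuple) and all(isinstance(v, int) for v in value)


def check_shard_args(in_strategy: Any, out_strategy: Any = None,
                     parameter_plan: Any = None, device: Any = "Ascend",
                     level: Any = 0, layout_type: type = FakeLayout) -> None:
    """Hypothetical restatement of the TypeError conditions of ms.shard."""
    if not isinstance(in_strategy, tuple):
        raise TypeError("in_strategy must be a tuple")
    if out_strategy is not None and not isinstance(out_strategy, tuple):
        raise TypeError("out_strategy must be a tuple or None")
    for s in in_strategy:
        if not _is_valid_layout(s, layout_type):
            raise TypeError(f"invalid in_strategy element: {s!r}")
    for s in out_strategy or ():
        if not _is_valid_layout(s, layout_type):
            raise TypeError(f"invalid out_strategy element: {s!r}")
    if parameter_plan is not None:
        if not isinstance(parameter_plan, dict):
            raise TypeError("parameter_plan must be a dict or None")
        for key, value in parameter_plan.items():
            if not isinstance(key, str):
                raise TypeError("parameter_plan keys must be str")
            if not _is_valid_layout(value, layout_type):
                raise TypeError(f"invalid layout for parameter {key!r}")
    if not isinstance(device, str):
        raise TypeError("device must be a str")
    if not isinstance(level, int):
        raise TypeError("level must be an int")


check_shard_args(((4, 2), (4, 2)))  # passes silently
try:
    check_shard_args(((4, 2),), parameter_plan={1: (4, 1)})
except TypeError as err:
    print(err)  # parameter_plan keys must be str
```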
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn, ops, Layout
>>> from mindspore.communication import init
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation",
...                              device_num=8)
>>>
>>> # Case 1: cell uses functional
>>> class BasicBlock(nn.Cell):
...     def __init__(self):
...         super(BasicBlock, self).__init__()
...         self.dense1 = nn.Dense(64, 64)
...         self.gelu = nn.GELU()
...         def my_add(x, y):
...             x = ops.abs(x)
...             return x + y
...         # shard a function with tuple(int) strategies
...         self.shard_my_add = ms.shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
...
...     def construct(self, x, u):
...         x = self.gelu(x)
...         y = self.gelu(u)
...         y = x * y
...         x = self.dense1(x)
...         x = self.shard_my_add(x, y)
...         return x
>>>
>>> class NetForward(nn.Cell):
...     def __init__(self):
...         super(NetForward, self).__init__()
...         self.block1 = BasicBlock()
...         self.block2 = BasicBlock()
...         self.matmul = ops.MatMul()
...
...     def construct(self, x, y):
...         x = self.matmul(x, y)
...         x = self.block1(x, x)
...         x = self.block2(x, x)
...         return x
>>>
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         # setting cell sharding strategy and parameter_plan by tuple(int)
...         self.layer_net1 = NetForward()
...         self.layer_net1_shard = ms.shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
...                                          parameter_plan={"self.layer_net1.block1.weight": (4, 1)})
...
...         # setting cell sharding strategy and parameter_plan by tuple(ms.Layout)
...         self.layer_net2 = NetForward()
...         layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
...         in_layout = (layout("dp", "mp"), layout("mp", "sp"))
...         param_layout = layout("dp", "sp")
...         self.layer_net2_shard = ms.shard(self.layer_net2, in_strategy=in_layout,
...                                          parameter_plan={"self.layer_net2.block2.weight": param_layout})
...         self.flatten = nn.Flatten()
...         self.layer1 = nn.Dense(64, 64)
...         self.layer2 = nn.Dense(64, 32)
...         self.add = ops.Add()
...         self.matmul = ops.MatMul()
...
...     def construct(self, x, y):
...         x = self.flatten(x)
...         y = self.flatten(y)
...         x = self.layer1(x)
...         x = self.layer_net1_shard(x, y)
...         x = self.layer_net2_shard(x, y)
...         x = self.layer2(x)
...         x = self.matmul(x, Tensor(np.ones(shape=(32, 32)), dtype=ms.float32))
...         return x
>>>
>>> net = Net()
>>> x = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
>>> y = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
>>> net(x, y)
>>>
>>> # Case 2: function uses functional sharding
>>> def test_shard(x, y):
...     return x + y
>>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
>>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
>>> output = ms.shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
>>> print(output.shape)
(32, 10)
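Case 1 builds Layout((4, 2, 1), ("dp", "mp", "sp")) and then calls layout("dp", "mp"). One way to read this, assuming each alias names one axis of a 4 x 2 x 1 device matrix and the i-th alias passed to the call decides how many slices the i-th tensor dimension gets, is sketched below. ToyLayout is a hypothetical stand-in, not mindspore.Layout; it ignores device ordering, ranks, and dimensions mapped to several axes.

```python
from math import prod
from typing import Dict, Sequence, Tuple


class ToyLayout:
    """Hypothetical stand-in for mindspore.Layout, used only to show slice counts."""

    def __init__(self, device_matrix: Sequence[int], alias_name: Sequence[str]):
        if len(device_matrix) != len(alias_name):
            raise ValueError("one alias per device-matrix axis is required")
        self.axis_size: Dict[str, int] = dict(zip(alias_name, device_matrix))
        self.device_num = prod(device_matrix)

    def __call__(self, *tensor_map: str) -> Tuple[int, ...]:
        # The i-th alias names the device axis that splits the i-th tensor
        # dimension; the result is the equivalent tuple(int) strategy.
        return tuple(self.axis_size[alias] for alias in tensor_map)


layout = ToyLayout((4, 2, 1), ("dp", "mp", "sp"))
print(layout.device_num)    # 8
print(layout("dp", "mp"))   # (4, 2)
print(layout("mp", "sp"))   # (2, 1)
print(layout("dp", "sp"))   # (4, 1)
```

Under this reading, the two halves of Case 1 request the same slicing: layer_net2's layouts reduce to ((4, 2), (2, 1)) for the inputs and (4, 1) for the parameter, the same tuples given to layer_net1.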
- Tutorial Examples: