mindspore.shard

mindspore.shard(fn, in_strategy, out_strategy=None, parameter_plan=None, device='Ascend', level=0)[源代码]

指定输入/输出Tensor的分布策略,其余算子的策略推导得到。在PyNative模式下,可以利用此方法指定某个Cell以图模式进行分布式执行。 在图模式下, 可以利用此方法设置某个模块的分布式切分策略,未设置的会自动通过策略传播方式配置。in_strategy/out_strategy需要为元组类型, 其中的每一个元素指定对应的输入/输出的Tensor分布策略,可参考: mindspore.ops.Primitive.shard() 的描述。也可以设置为None,会默认以数据并行执行。 其余算子的并行策略由输入输出指定的策略推导得到。

说明

调用该方法后,并行模式(parallel_mode)会自动设置为"auto_parallel"且搜索模式(search_mode)自动设置为"sharding_propagation"。 如果输入含有Parameter,其对应的策略应该在 in_strategy 里设置。

参数:
  • fn (Union[Cell, Function]) - 待通过分布式并行执行的函数,它的参数和返回值类型应该均为Tensor或Parameter。 如果 fn 是Cell类型且含有参数,则 fn 必须是一个实例化的对象,否则无法访问到其内部参数。

  • in_strategy (tuple) - 指定各输入的切分策略,输入元组的每个元素可以为整数元组或mindspore.Layout的元组。元组即具体指定输入每一维的切分策略。

  • out_strategy (Union[tuple, None]) - 指定各输出的切分策略,用法同 in_strategy,目前未使能。默认值: None

  • parameter_plan (Union[dict, None]) - 指定各参数的切分策略,传入字典时,键是str类型的参数名,值是一维整数tuple或一维mindspore.Layout的tuple表示相应的切分策略。 如果参数名错误或对应参数已经设置了切分策略,该参数的设置会被跳过。默认值: None

  • device (string) - 指定执行设备,可以为["CPU", "GPU", "Ascend"]中任意一个,目前未使能。默认值: "Ascend"

  • level (int) - 指定搜索切分策略的目标函数,即是最大化计算通信比、最小化内存消耗、最大化执行速度等。可以为[0, 1, 2]中任意一个,默认值: 0 。目前仅支持最大化计算通信比,其余模式未使能。

返回:

Function,返回一个在自动并行流程下执行的函数。

异常:
  • AssertionError - 如果并行模式不是"auto_parallel"或"semi_auto_parallel"。

  • AssertionError - 如果后端不是"Ascend"或"GPU"。

  • TypeError - 如果 in_strategy 不是tuple。

  • TypeError - 如果 out_strategy 不是tuple。

  • TypeError - 如果 in_strategy 里的任何一个元素不是tuple(int)或者tuple(mindspore.Layout)。

  • TypeError - 如果 out_strategy 里的任何一个元素不是tuple(int)或者tuple(mindspore.Layout)。

  • TypeError - 如果 parameter_plan 不是dict或None。

  • TypeError - 如果 parameter_plan 里的任何一个键值类型不是str。

  • TypeError - 如果 parameter_plan 里的任何一个值类型不是tuple(int)或者tuple(mindspore.Layout)。

  • TypeError - 如果 device 不是str。

  • TypeError - 如果 level 不是int。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import Tensor, nn
>>> from mindspore.communication import init
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> ms.set_auto_parallel_context(parallel_mode="auto_parallel", search_mode="sharding_propagation",
...                              device_num=8)
>>>
>>> # Case 1: cell uses functional
>>> class BasicBlock(nn.Cell):
>>>     def __init__(self):
>>>         super(BasicBlock, self).__init__()
>>>         self.dense1 = nn.Dense(64, 64)
>>>         self.gelu = nn.GELU()
>>>         def my_add(x, y):
>>>             x = ops.abs(x)
>>>             return x + y
>>>         # shard a function with tuple(int) strategies
>>>         self.shard_my_add = ms.shard(my_add, in_strategy=((2, 2), (1, 4)), out_strategy=((4, 1),))
>>>
>>>     def construct(self, x, u):
>>>         x = self.gelu(x)
>>>         y = self.gelu(u)
>>>         y = x * y
>>>         x = self.dense1(x)
>>>         x = self.shard_my_add(x, y)
>>>         return x
>>>
>>> class NetForward(nn.Cell):
>>>     def __init__(self):
>>>         super(NetForward, self).__init__()
>>>         self.block1 = BasicBlock()
>>>         self.block2 = BasicBlock()
>>>         self.matmul = ops.MatMul()
>>>
>>>     def construct(self, x, y):
>>>         x = self.matmul(x, y)
>>>         x = self.block1(x, x)
>>>         x = self.block2(x, x)
>>>         return x
>>>
>>> class Net(nn.Cell):
>>>     def __init__(self):
>>>         super(Net, self).__init__()
>>>         # setting cell sharding strategy and parameter_plan by tuple(int)
>>>         self.layer_net1 = NetForward()
>>>         self.layer_net1_shard = ms.shard(self.layer_net1, in_strategy=((4, 2), (2, 1)),
...                                          parameter_plan={"self.layer_net1.block1.weight": (4, 1)})
>>>
>>>         # setting cell sharding strategy and parameter_plan by tuple(ms.Layout)
>>>         self.layer_net2 = NetForward()
>>>         layout = Layout((4, 2, 1), ("dp", "mp", "sp"))
>>>         in_layout = (layout("dp", "mp"), layout("mp", "sp"))
>>>         param_layout = layout("dp", "sp")
>>>         self.layer_net2_shard = ms.shard(self.layer_net2, in_strategy=in_layout,
...                                          parameter_plan={"self.layer_net2.block2.weight": param_layout})
>>>         self.flatten = nn.Flatten()
>>>         self.layer1 = nn.Dense(64, 64)
>>>         self.layer2 = nn.Dense(64, 32)
>>>         self.add = ops.Add()
>>>         self.matmul = ops.MatMul()
>>>
>>>     def construct(self, x, y):
>>>         x = self.flatten(x)
>>>         y = self.flatten(y)
>>>         x = self.layer1(x)
>>>         x = self.layer_net1_shard(x, y)
>>>         x = self.layer_net2_shard(x, y)
>>>         x = self.layer2(x)
>>>         x = self.matmul(x, Tensor(np.ones(shape=(32, 32)), dtype=ms.float32))
>>>         return x
>>>
>>> net = Net()
>>> x = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
>>> y = Tensor(np.ones(shape=(64, 1, 8, 8)), dtype=ms.float32)
>>> net(x, y)
>>>
>>> # Case 2: function uses functional sharding
>>> def test_shard(x, y):
...     return x + y
>>> x = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
>>> y = Tensor(np.ones(shape=(32, 10)), dtype=ms.float32)
>>> output = ms.shard(test_shard, in_strategy=((4, 2), (4, 2)))(x, y)
>>> print(output.shape)
(32, 10)
教程样例: