mindspore.ops.ReduceScatter

class mindspore.ops.ReduceScatter(op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP)

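Reduces tensors element-wise across all devices in the specified communication group, then scatters the result so that each device receives an equal slice along the first dimension.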
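To make the semantics concrete, the following single-process NumPy sketch models reduce-scatter with the default SUM reduction over a list of per-rank arrays. It does not call MindSpore. The function name `simulate_reduce_scatter_sum` is hypothetical, and the sketch assumes that rank i receives the i-th equal slice of the reduced tensor along the first dimension, which is consistent with the documented `ValueError` and with the example output.

```python
import numpy as np


def simulate_reduce_scatter_sum(per_rank_inputs):
    """Hypothetical single-process model of ReduceScatter with op=SUM.

    per_rank_inputs: list of equally shaped arrays, one per rank.
    Returns a list whose i-th entry is what rank i would receive.
    """
    rank_size = len(per_rank_inputs)
    # Element-wise sum across ranks (stack adds a leading rank axis).
    reduced = np.sum(np.stack(per_rank_inputs), axis=0)
    # Assumption: rank i receives the i-th equal slice along dimension 0.
    # np.split raises ValueError if the first dimension is not divisible.
    return np.split(reduced, rank_size, axis=0)


# Two simulated ranks, each holding a (4, 2) array.
rank0 = np.arange(8, dtype=np.float32).reshape(4, 2)
rank1 = np.ones((4, 2), dtype=np.float32)
out = simulate_reduce_scatter_sum([rank0, rank1])
print(out[0])  # rows 0-1 of rank0 + rank1 -> [[1. 2.] [3. 4.]]
print(out[1])  # rows 2-3 of rank0 + rank1 -> [[5. 6.] [7. 8.]]
```

Every rank ends up with a tensor whose first dimension is the input's first dimension divided by the number of ranks.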

Note

Backpropagation for this operator is not supported yet. The input tensors must have the same shape and format in every process of the communication group. Before running the example below, set the required communication environment variables; see the Communication Operator API page on the official MindSpore website for details.

Parameters
  • op (str) – The operation used for element-wise reduction, such as SUM or MAX. Default: ReduceOp.SUM.

  • group (str) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
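The `op` argument only changes how the per-rank values are combined element-wise; the scatter step is unchanged. The NumPy sketch below illustrates this for two simulated ranks using the two reductions named above. The `REDUCERS` mapping and the `simulate_reduce_scatter` function are hypothetical illustrations, not MindSpore's implementation.

```python
import numpy as np

# Hypothetical mapping from reduction names to NumPy element-wise reducers.
REDUCERS = {"SUM": np.sum, "MAX": np.max}


def simulate_reduce_scatter(per_rank_inputs, op="SUM"):
    """Reduce across ranks with `op`, then give rank i the i-th slice along dim 0."""
    reduced = REDUCERS[op](np.stack(per_rank_inputs), axis=0)
    return np.split(reduced, len(per_rank_inputs), axis=0)


rank0 = np.array([[1.0, 5.0], [2.0, 0.0]], dtype=np.float32)
rank1 = np.array([[3.0, 4.0], [1.0, 6.0]], dtype=np.float32)
print(simulate_reduce_scatter([rank0, rank1], op="SUM")[0])  # rank 0 gets [[4. 9.]]
print(simulate_reduce_scatter([rank0, rank1], op="MAX")[0])  # rank 0 gets [[3. 5.]]
```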

Raises
  • TypeError – If op or group is not a string.

  • ValueError – If the first dimension of the input is not divisible by the rank size (the number of devices in the group).
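The shape condition behind the `ValueError`, together with the note's requirement that every process pass the same shape, can be checked before launching the collective. The helper below is hypothetical (not part of MindSpore) and only inspects shapes, not formats. It takes one shape per rank, so the rank size is the length of that list, and it returns the per-rank output shape: the input shape with the first dimension divided by the rank size.

```python
def check_reduce_scatter_shapes(shapes):
    """Hypothetical pre-flight check mirroring the documented shape constraints.

    shapes: list of input shapes (tuples), one per rank in the group.
    Returns the output shape each rank would receive.
    """
    rank_size = len(shapes)
    if len(set(shapes)) != 1:
        raise ValueError(f"inputs must have the same shape on every rank, got {shapes}")
    shape = shapes[0]
    if shape[0] % rank_size != 0:
        raise ValueError(
            f"first dimension {shape[0]} is not divisible by rank size {rank_size}"
        )
    # Only the first dimension is split; the remaining dimensions are kept.
    return (shape[0] // rank_size,) + tuple(shape[1:])


print(check_reduce_scatter_shapes([(8, 8), (8, 8)]))  # (4, 8)
```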

Supported Platforms:

Ascend GPU

Examples

>>> # This example should be run with two devices. Refer to the Distributed Training tutorial on mindspore.cn.
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> from mindspore.ops import ReduceOp
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.reducescatter = ops.ReduceScatter(ReduceOp.SUM)
...
...     def construct(self, x):
...         return self.reducescatter(x)
...
>>> input_ = Tensor(np.ones([8, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
>>> print(output)
[[2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2.]]
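The printed result follows from the semantics: each of the two devices contributes an 8x8 tensor of ones, the element-wise sum is an 8x8 tensor of 2s, and each device keeps 8 / 2 = 4 rows. The NumPy sketch below reproduces that arithmetic in a single process. It assumes two ranks with identical inputs, as in the example, and does not call MindSpore.

```python
import numpy as np

rank_size = 2  # two devices, as in the example above
inputs = [np.ones([8, 8], dtype=np.float32) for _ in range(rank_size)]

# ReduceOp.SUM across devices: every element becomes 1 + 1 = 2.
reduced = np.sum(np.stack(inputs), axis=0)
# Scatter: each device keeps 8 / 2 = 4 rows of the reduced tensor.
per_rank = np.split(reduced, rank_size, axis=0)

print(per_rank[0].shape)  # (4, 8)
print(per_rank[0])        # same 4x8 matrix of 2. as printed above
```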