mindspore.ops.AlltoAll

class mindspore.ops.AlltoAll(split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP)

AlltoAll is a collective operation.

AlltoAll sends data from all processes to all processes in the specified group. It has two phases:

  • The scatter phase: On each process, the operand is split into split_count blocks along split_dim, and the blocks are scattered to all processes, e.g., the ith block is sent to the ith process.

  • The gather phase: Each process concatenates the received blocks along concat_dim.
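The two phases can be imitated on a single host with NumPy, which may help when reasoning about the data layout. This is only a sketch, not the operator's implementation: the helper name simulate_alltoall is hypothetical, and it assumes that split_count equals the number of processes in the group and that each process concatenates the received blocks in source-rank order.

```python
import numpy as np


def simulate_alltoall(inputs, split_count, split_dim, concat_dim):
    """Imitate AlltoAll on one host; inputs[r] stands in for the tensor on rank r.

    Assumptions (not taken from the operator's source): split_count equals
    the number of ranks, and received blocks are concatenated in rank order.
    """
    if len(inputs) != split_count:
        raise ValueError("this sketch expects one input tensor per rank")

    # Scatter phase: every rank splits its tensor into split_count blocks;
    # block i of each rank is destined for rank i.
    blocks = [np.split(x, split_count, axis=split_dim) for x in inputs]

    # Gather phase: rank dst concatenates the blocks it received from
    # ranks 0..split_count-1 along concat_dim.
    return [
        np.concatenate([blocks[src][dst] for src in range(split_count)],
                       axis=concat_dim)
        for dst in range(split_count)
    ]


# Same setup as an 8-device run where each rank holds a tensor filled with its rank id.
inputs = [np.ones([1, 1, 8, 1], dtype=np.float32) * rank for rank in range(8)]
outputs = simulate_alltoall(inputs, split_count=8, split_dim=-2, concat_dim=-1)
print(outputs[0].shape)  # (1, 1, 1, 8)
```

Under these assumptions every simulated rank ends up holding one block from each source rank, so each output contains the values 0 through 7 along the last axis.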

Note

The tensors must have the same shape and format in all processes of the collection. Communication environment variables must be set before running the example below; see the official MindSpore website for details.

This operator requires a full-mesh network topology: each device has the same VLAN ID, and the IP addresses and masks are in the same subnet. See the official documentation for details.

Parameters
  • split_count (int) – The number of blocks into which the operand is split on each process.

  • split_dim (int) – The dimension along which the operand is split on each process.

  • concat_dim (int) – The dimension along which each process concatenates the received blocks.

  • group (str) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
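As a rough illustration of how split_count, split_dim and concat_dim relate to the output shape, the sketch below derives the shape from the two phases described above: the split dimension shrinks by a factor of split_count and the concat dimension grows by the same factor. The helper alltoall_output_shape is hypothetical and not part of MindSpore, and the divisibility check is an assumption of this sketch rather than documented operator behavior.

```python
def alltoall_output_shape(shape, split_count, split_dim, concat_dim):
    """Estimate the per-process output shape implied by the scatter/gather phases.

    Hypothetical helper for illustration only; it is not a MindSpore API.
    """
    out = list(shape)
    # Assumption of this sketch: the split dimension divides evenly into blocks.
    if out[split_dim] % split_count != 0:
        raise ValueError("shape[split_dim] must be divisible by split_count")
    out[split_dim] //= split_count   # each block keeps 1/split_count of split_dim
    out[concat_dim] *= split_count   # split_count blocks are joined along concat_dim
    return tuple(out)


print(alltoall_output_shape((1, 1, 8, 1), split_count=8, split_dim=-2, concat_dim=-1))
# (1, 1, 1, 8)
```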

Raises

TypeError – If group is not a string.

Supported Platforms:

Ascend

Example

>>> # This example should be run with 8 devices. Refer to the tutorial > Distributed Training on mindspore.cn
>>> import os
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore import context
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.alltoall = ops.AlltoAll(split_count=8, split_dim=-2, concat_dim=-1)
...
...     def construct(self, x):
...         out = self.alltoall(x)
...         return out
...
>>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
>>> init()
>>> net = Net()
>>> rank_id = int(os.getenv("RANK_ID"))
>>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype=ms.float32)
>>> output = net(input_x)
>>> print(output)
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]