mindspore.ops.AlltoAll

class mindspore.ops.AlltoAll(split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP)

AlltoAll is a collective operation.

AlltoAll sends data from every process to every process in the specified group. It has two phases:

  • The scatter phase: On each process, the operand is split into split_count blocks along split_dim, and the blocks are scattered to all processes; that is, the i-th block is sent to the i-th process.

  • The gather phase: Each process concatenates the received blocks along concat_dim.
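The two phases can be mimicked on a single machine with NumPy. This is a sketch, not MindSpore code: `simulate_alltoall` is a hypothetical helper that treats each list entry as the tensor held by one process, and it assumes the group size equals split_count and that received blocks are concatenated in sender-rank order (consistent with the 8-device example output). The usage line reproduces the 8-device example in the Examples section.

```python
import numpy as np


def simulate_alltoall(inputs, split_count, split_dim, concat_dim):
    """Hypothetical single-process simulation of AlltoAll.

    inputs[r] is the tensor held by simulated process r.
    Assumption: the group size (len(inputs)) equals split_count.
    """
    world_size = len(inputs)
    if world_size != split_count:
        raise ValueError("this sketch assumes one block per process")
    # Scatter phase: each process splits its operand into split_count
    # blocks along split_dim; block i is destined for process i.
    blocks = [np.split(x, split_count, axis=split_dim) for x in inputs]
    # Gather phase: process dst concatenates the dst-th block from every
    # sender, ordered by sender rank, along concat_dim.
    return [
        np.concatenate([blocks[src][dst] for src in range(world_size)], axis=concat_dim)
        for dst in range(world_size)
    ]


# Mirror the 8-device example: rank r holds a (1, 1, 8, 1) tensor filled with r.
inputs = [np.ones((1, 1, 8, 1), dtype=np.float32) * r for r in range(8)]
outputs = simulate_alltoall(inputs, split_count=8, split_dim=-2, concat_dim=-1)
print(outputs[0].shape)  # (1, 1, 1, 8)
print(outputs[0])        # [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
```

Every simulated rank ends up with the same output here because each sender's blocks are identical; with varied inputs, each rank receives a different slice of every sender's tensor.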

Note

This operator requires a full-mesh network topology: every device must have the same VLAN ID, and the IP addresses and masks must be in the same subnet. Please check the details.

Parameters
  • split_count (int) – The number of blocks into which each process splits its input.

  • split_dim (int) – The dimension along which each process splits its input.

  • concat_dim (int) – The dimension along which each process concatenates the received blocks.

  • group (str, optional) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.

Inputs:
  • input_x (Tensor) - The input tensor, with shape $(x_1, x_2, \ldots, x_R)$.

Outputs:

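Tensor. If the shape of the input tensor is $(x_1, x_2, \ldots, x_R)$, then the shape of the output tensor is $(y_1, y_2, \ldots, y_R)$, where:

  • $y_{\text{split\_dim}} = x_{\text{split\_dim}} / \text{split\_count}$

  • $y_{\text{concat\_dim}} = x_{\text{concat\_dim}} \times \text{split\_count}$

  • $y_{\text{other}} = x_{\text{other}}$.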
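The shape rules above can be checked with a small pure-Python helper. This is a hypothetical sketch, not part of the MindSpore API: `alltoall_output_shape` normalizes negative dimensions the usual Python way and raises its own ValueError when split_dim is not evenly divisible; it makes no claim about how MindSpore itself reports such errors.

```python
def alltoall_output_shape(input_shape, split_count, split_dim, concat_dim):
    """Hypothetical helper applying the AlltoAll output-shape rules."""
    rank = len(input_shape)
    for dim in (split_dim, concat_dim):
        if not -rank <= dim < rank:
            raise ValueError(f"dim {dim} out of range for rank {rank}")
    # Normalize negative dims, e.g. -2 -> rank - 2.
    split_dim %= rank
    concat_dim %= rank
    if input_shape[split_dim] % split_count != 0:
        raise ValueError("input_shape[split_dim] must be divisible by split_count")
    out = list(input_shape)
    out[split_dim] //= split_count  # y_split_dim = x_split_dim / split_count
    out[concat_dim] *= split_count  # y_concat_dim = x_concat_dim * split_count
    return tuple(out)               # all other dims are unchanged


print(alltoall_output_shape((1, 1, 8, 1), split_count=8, split_dim=-2, concat_dim=-1))
# (1, 1, 1, 8)
```

When split_dim and concat_dim refer to the same dimension, the division and multiplication cancel, so that dimension keeps its original size.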

Raises

TypeError – If group is not a string.

Supported Platforms:

Ascend

Examples

Note

Before running the following examples, you need to configure the communication environment variables.

For Ascend/GPU/CPU devices, we recommend the msrun startup method, which has no third-party or configuration-file dependencies. See the msrun startup documentation for more details.

This example should be run with 8 devices.

>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops
>>> from mindspore.communication import init, get_rank
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         # Split along dim -2 into 8 blocks (one per device); concatenate received blocks along dim -1.
...         self.alltoall = ops.AlltoAll(split_count=8, split_dim=-2, concat_dim=-1)
...
...     def construct(self, x):
...         out = self.alltoall(x)
...         return out
...
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> net = Net()
>>> rank_id = get_rank()
>>> # Each rank holds a (1, 1, 8, 1) tensor filled with its own rank id.
>>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype=ms.float32)
>>> output = net(input_x)
>>> print(output)
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]