mindspore.ops.AlltoAll

class mindspore.ops.AlltoAll(split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP)

AlltoAll is a collective operation.

AlltoAll sends data from every process to every process in the specified group. It has two phases:

  • The scatter phase: On each process, the operand is split into split_count blocks along split_dim, and the blocks are scattered to all processes, i.e., the ith block is sent to the ith process.

  • The gather phase: Each process concatenates the received blocks along concat_dim, ordered by the rank of the sending process.
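To make the two phases concrete, here is a minimal NumPy sketch that simulates AlltoAll on a single machine. It is not MindSpore's implementation: the function name simulate_alltoall is hypothetical, and the sketch assumes that split_count equals the number of processes in the group and that received blocks are concatenated in sender-rank order (consistent with the example output at the end of this page).

```python
import numpy as np


def simulate_alltoall(inputs, split_count, split_dim, concat_dim):
    """Hypothetical NumPy model of AlltoAll over len(inputs) simulated processes.

    inputs[r] is the tensor held by process r. Assumes split_count equals the
    number of processes and that the split dimension is divisible by split_count.
    """
    world = len(inputs)
    assert split_count == world, "this sketch assumes one block per process"
    # Scatter phase: every process cuts its tensor into split_count blocks.
    blocks = [np.split(x, split_count, axis=split_dim) for x in inputs]
    # Gather phase: process dst concatenates block dst received from every
    # sender, ordered by sender rank.
    return [
        np.concatenate([blocks[src][dst] for src in range(world)], axis=concat_dim)
        for dst in range(world)
    ]


# Four simulated processes; process r holds a 4x2 array filled with r.
inputs = [np.full((4, 2), r, dtype=np.float32) for r in range(4)]
outputs = simulate_alltoall(inputs, split_count=4, split_dim=0, concat_dim=1)
print(outputs[0].shape)  # (1, 8)
print(outputs[0])        # [[0. 0. 1. 1. 2. 2. 3. 3.]]
```

Negative dimensions work as in the example below, because np.split and np.concatenate accept negative axes.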

Note

This operator requires a full-mesh network topology: every device must have the same VLAN ID, and the IP addresses and masks must be in the same subnet. Please check the details.

Parameters
  • split_count (int) – The number of blocks into which the input is split on each process.

  • split_dim (int) – The dimension along which the input is split on each process.

  • concat_dim (int) – The dimension along which each process concatenates the received blocks.

  • group (str) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.

Inputs:
  • input_x (Tensor) - The shape of tensor is \((x_1, x_2, ..., x_R)\).

Outputs:

Tensor. If the shape of input tensor is \((x_1, x_2, ..., x_R)\), then the shape of output tensor is \((y_1, y_2, ..., y_R)\), where:

  • \(y_{split\_dim} = x_{split\_dim} / split\_count\)

  • \(y_{concat\_dim} = x_{concat\_dim} * split\_count\)

  • \(y_{other} = x_{other}\).
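The shape rule above can be checked with a small helper. This is a hypothetical sketch, not a MindSpore API: alltoall_output_shape normalizes negative dimensions the way Python indexing does, and it rejects inputs whose split dimension is not divisible by split_count, an assumption that follows from the division in the rule.

```python
def alltoall_output_shape(shape, split_count, split_dim, concat_dim):
    """Apply the AlltoAll output-shape rule to an input shape (hypothetical helper)."""
    rank = len(shape)
    for name, dim in (("split_dim", split_dim), ("concat_dim", concat_dim)):
        if not -rank <= dim < rank:
            raise ValueError(f"{name}={dim} is out of range for a rank-{rank} input")
    # Normalize negative dimensions, e.g. split_dim=-2 on a rank-4 input -> 2.
    s, c = split_dim % rank, concat_dim % rank
    if shape[s] % split_count != 0:
        raise ValueError(
            f"dimension {s} (size {shape[s]}) is not divisible by split_count={split_count}"
        )
    out = list(shape)
    out[s] //= split_count  # y_split_dim = x_split_dim / split_count
    out[c] *= split_count   # y_concat_dim = x_concat_dim * split_count
    return tuple(out)       # every other dimension is unchanged


print(alltoall_output_shape((1, 1, 8, 1), split_count=8, split_dim=-2, concat_dim=-1))
# (1, 1, 1, 8)
```

When split_dim and concat_dim name the same dimension, the division and multiplication cancel and the shape is unchanged.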

Raises

TypeError – If group is not a string.

Supported Platforms:

Ascend

Examples

Note

Before running the following examples, you need to configure the communication environment variables.

For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method, which has no third-party or configuration-file dependencies. Please see the msrun startup documentation for more details.

This example should be run with 8 devices.

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init, get_rank
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.alltoall = ops.AlltoAll(split_count=8, split_dim=-2, concat_dim=-1)
...
...     def construct(self, x):
...         out = self.alltoall(x)
...         return out
...
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> net = Net()
>>> rank_id = get_rank()
>>> # Each process fills its (1, 1, 8, 1) input with its own rank id.
>>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype=ms.float32)
>>> output = net(input_x)
>>> print(output)
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]