mindspore.ops.AlltoAll
- class mindspore.ops.AlltoAll(split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP)[source]
AlltoAll is a collective operation.
AlltoAll sends data from all processes to all processes in the specified group. It has two phases:
The scatter phase: on each process, the operand is split into split_count blocks along split_dim, and the blocks are scattered to all processes, e.g., the i-th block is sent to the i-th process.
The gather phase: each process concatenates the received blocks along concat_dim.
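The two phases can be illustrated with a minimal NumPy sketch that simulates the data movement in host memory (illustration only, not MindSpore code; the rank count and tensor shape are chosen for the example):

import numpy as np

# Simulate AlltoAll semantics for a group of 4 "ranks" in host memory.
# Each rank holds a (4, 2) tensor; split_count=4, split_dim=0, concat_dim=1.
split_count, split_dim, concat_dim = 4, 0, 1
inputs = [np.full((4, 2), rank, dtype=np.float32) for rank in range(split_count)]

# Scatter phase: each rank splits its operand into split_count blocks
# along split_dim; block i is destined for rank i.
blocks = [np.split(x, split_count, axis=split_dim) for x in inputs]

# Gather phase: rank r concatenates the r-th block received from every
# rank along concat_dim.
outputs = [np.concatenate([blocks[src][r] for src in range(split_count)],
                          axis=concat_dim) for r in range(split_count)]

print(outputs[0].shape)  # (1, 8): split_dim shrank 4x, concat_dim grew 4x
print(outputs[0])        # [[0. 0. 1. 1. 2. 2. 3. 3.]]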
Note
This operator requires a full-mesh network topology: each device must have the same VLAN ID, and the IPs and masks must be in the same subnet. Please check the details.
- Parameters
split_count (int) – On each process, the number of blocks the input tensor is divided into.
split_dim (int) – On each process, the dimension along which the input tensor is split.
concat_dim (int) – On each process, the dimension along which the received blocks are concatenated.
group (str) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
- Inputs:
input_x (Tensor) - The shape of the tensor is \((x_1, x_2, ..., x_R)\).
- Outputs:
Tensor. If the shape of the input tensor is \((x_1, x_2, ..., x_R)\), then the shape of the output tensor is \((y_1, y_2, ..., y_R)\), where:
\(y_{split\_dim} = x_{split\_dim} / split\_count\)
\(y_{concat\_dim} = x_{concat\_dim} * split\_count\)
\(y_{other} = x_{other}\).
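As a worked instance of this rule (matching the 8-device example below: input shape (1, 1, 8, 1), split_count=8, split_dim=-2, concat_dim=-1), the output shape is (1, 1, 1, 8). A small sketch makes the arithmetic explicit (alltoall_output_shape is a hypothetical helper, not part of the API):

def alltoall_output_shape(shape, split_count, split_dim, concat_dim):
    """Apply the AlltoAll shape rule: divide split_dim, multiply concat_dim."""
    out = list(shape)
    out[split_dim] = shape[split_dim] // split_count
    out[concat_dim] = shape[concat_dim] * split_count
    return tuple(out)

print(alltoall_output_shape((1, 1, 8, 1), 8, -2, -1))  # (1, 1, 1, 8)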
- Raises
TypeError – If group is not a string.
- Supported Platforms:
Ascend
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method, which has no third-party or configuration-file dependencies. Please see the msrun startup for more details.
This example should be run with 8 devices.
>>> import os
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.alltoall = ops.AlltoAll(split_count=8, split_dim=-2, concat_dim=-1)
...
...     def construct(self, x):
...         out = self.alltoall(x)
...         return out
...
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> net = Net()
>>> rank_id = int(os.getenv("RANK_ID"))
>>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype=ms.float32)
>>> output = net(input_x)
>>> print(output)
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]