mindspore.ops.AlltoAll
- class mindspore.ops.AlltoAll(split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP)[源代码]
AlltoAll是一个集合通信函数。
AlltoAll将输入数据在特定的维度切分成特定的块数(blocks),并按顺序发送给其他rank。一般有两个阶段:
分发阶段:在每个进程上,操作数沿着 split_dim 拆分为 split_count 个块(blocks),且分发到指定的rank上,例如,第i块被发送到第i个rank上。
聚合阶段:每个rank沿着 concat_dimension 拼接接收到的数据。
- 参数:
split_count (int) - 在每个进程上,将块(blocks)拆分为 split_count 个。
split_dim (int) - 在每个进程上,沿着 split_dim 维度进行拆分。
concat_dim (int) - 在每个进程上,沿着 concat_dimension 拼接接收到的块(blocks)。
group (str) - AlltoAll的通信域。默认值:
GlobalComm.WORLD_COMM_GROUP
。
- 输入:
input_x (Tensor) - shape为 \((x_1, x_2, ..., x_R)\)。
- 输出:
Tensor,设输入的shape是 \((x_1, x_2, ..., x_R)\),则输出的shape为 \((y_1, y_2, ..., y_R)\),其中:
\(y_{split\_dim} = x_{split\_dim} / split\_count\)
\(y_{concat\_dim} = x_{concat\_dim} * split\_count\)
\(y_{other} = x_{other}\) 。
- 异常:
TypeError - 如果 group 不是字符串。
- 支持平台:
Ascend
样例:
说明
运行以下样例之前,需要配置好通信环境变量。
针对Ascend设备,用户需要准备rank表,设置rank_id和device_id,详见 rank table启动 。
针对GPU设备,用户需要准备host文件和mpi,详见 mpirun启动 。
针对CPU设备,用户需要编写动态组网启动脚本,详见 动态组网启动 。
该样例需要在8卡环境下运行。
>>> import os >>> import mindspore as ms >>> from mindspore import Tensor >>> from mindspore.communication import init >>> import mindspore.nn as nn >>> import mindspore.ops as ops >>> import numpy as np >>> class Net(nn.Cell): ... def __init__(self): ... super(Net, self).__init__() ... self.alltoall = ops.AlltoAll(split_count = 8, split_dim = -2, concat_dim = -1) ... ... def construct(self, x): ... out = self.alltoall(x) ... return out ... >>> ms.set_context(mode=ms.GRAPH_MODE) >>> init() >>> net = Net() >>> rank_id = int(os.getenv("RANK_ID")) >>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype = ms.float32) >>> output = net(input_x) >>> print(output) [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
- 教程样例: