mindspore.ops.AlltoAll

查看源文件
class mindspore.ops.AlltoAll(split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP)[源代码]

AlltoAll是一个集合通信函数。

AlltoAll将输入数据在特定的维度切分成特定的块数(blocks),并按顺序发送给其他rank。一般有两个阶段:

  • 分发阶段:在每个进程上,操作数沿着 split_dim 拆分为 split_count 个块(blocks),且分发到指定的rank上,例如,第i块被发送到第i个rank上。

  • 聚合阶段:每个rank沿着 concat_dimension 拼接接收到的数据。

说明

聚合阶段,所有进程中的Tensor必须具有相同的shape和格式。

要求全连接配网方式,每台设备具有相同的vlan id,ip和mask在同一子网,请查看 详细信息

参数:
  • split_count (int) - 在每个进程上,将块(blocks)拆分为 split_count 个。

  • split_dim (int) - 在每个进程上,沿着 split_dim 维度进行拆分。

  • concat_dim (int) - 在每个进程上,沿着 concat_dimension 拼接接收到的块(blocks)。

  • group (str) - AlltoAll的通信域。默认值: GlobalComm.WORLD_COMM_GROUP

输入:
  • input_x (Tensor) - shape为 \((x_1, x_2, ..., x_R)\)

输出:

Tensor,设输入的shape是 \((x_1, x_2, ..., x_R)\),则输出的shape为 \((y_1, y_2, ..., y_R)\),其中:

  • \(y_{split\_dim} = x_{split\_dim} / split\_count\)

  • \(y_{concat\_dim} = x_{concat\_dim} * split\_count\)

  • \(y_{other} = x_{other}\)

异常:
  • TypeError - 如果 group 不是字符串。

支持平台:

Ascend

样例:

说明

运行以下样例之前,需要配置好通信环境变量。

针对Ascend/GPU/CPU设备,推荐使用msrun启动方式,无第三方以及配置文件依赖。详见 msrun启动

该样例需要在8卡环境下运行。

>>> import os
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.alltoall = ops.AlltoAll(split_count = 8, split_dim = -2, concat_dim = -1)
...
...     def construct(self, x):
...         out = self.alltoall(x)
...         return out
...
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> net = Net()
>>> rank_id = int(os.getenv("RANK_ID"))
>>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype = ms.float32)
>>> output = net(input_x)
>>> print(output)
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
教程样例: