文档反馈

问题文档片段

问题文档片段包含公式时,显示为空格。

提交类型
issue

有点复杂...

找人问问吧。

PR

小问题,全程线上修改...

一键搞定!

请选择提交类型

问题类型
规范和低错类

- 规范和低错类:

- 错别字或拼写错误,标点符号使用错误、公式错误或显示异常。

- 链接错误、空单元格、格式错误。

- 英文中包含中文字符。

- 界面和描述不一致,但不影响操作。

- 表述不通顺,但不影响理解。

- 版本号不匹配:如软件包名称、界面版本号。

易用性

- 易用性:

- 关键步骤错误或缺失,无法指导用户完成任务。

- 缺少主要功能描述、关键词解释、必要前提条件、注意事项等。

- 描述内容存在歧义指代不明、上下文矛盾。

- 逻辑不清晰,该分类、分项、分步骤的没有给出。

正确性

- 正确性:

- 技术原理、功能、支持平台、参数类型、异常报错等描述和软件实现不一致。

- 原理图、架构图等存在错误。

- 命令、命令参数等错误。

- 代码片段错误。

- 命令无法完成对应功能。

- 界面错误,无法指导操作。

- 代码样例运行报错、运行结果不符。

风险提示

- 风险提示:

- 对重要数据或系统存在风险的操作,缺少安全提示。

内容合规

- 内容合规:

- 违反法律法规,涉及政治、领土主权等敏感词。

- 内容侵权。

请选择问题类型

问题描述

点击输入详细问题描述,以帮助我们快速定位问题。

mindspore.ops.AlltoAll

查看源文件
class mindspore.ops.AlltoAll(split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP)[源代码]

AlltoAll是一个集合通信函数。

AlltoAll将输入数据在特定的维度切分成特定的块数(blocks),并按顺序发送给其他rank。一般有两个阶段:

  • 分发阶段:在每个进程上,操作数沿着 split_dim 拆分为 split_count 个块(blocks),且分发到指定的rank上,例如,第i块被发送到第i个rank上。

  • 聚合阶段:每个rank沿着 concat_dimension 拼接接收到的数据。

说明

聚合阶段,所有进程中的Tensor必须具有相同的shape和格式。

要求全连接配网方式,每台设备具有相同的vlan id,ip和mask在同一子网,请查看 详细信息

参数:
  • split_count (int) - 在每个进程上,将块(blocks)拆分为 split_count 个。

  • split_dim (int) - 在每个进程上,沿着 split_dim 维度进行拆分。

  • concat_dim (int) - 在每个进程上,沿着 concat_dimension 拼接接收到的块(blocks)。

  • group (str) - AlltoAll的通信域。默认值: GlobalComm.WORLD_COMM_GROUP

输入:
  • input_x (Tensor) - shape为 (x1,x2,...,xR)

输出:

Tensor,设输入的shape是 (x1,x2,...,xR),则输出的shape为 (y1,y2,...,yR),其中:

  • ysplit_dim=xsplit_dim/split_count

  • yconcat_dim=xconcat_dimsplit_count

  • yother=xother

异常:
  • TypeError - 如果 group 不是字符串。

支持平台:

Ascend

样例:

说明

运行以下样例之前,需要配置好通信环境变量。

针对Ascend设备,用户需要准备rank表,设置rank_id和device_id,详见 rank table启动

针对GPU设备,用户需要准备host文件和mpi,详见 mpirun启动

针对CPU设备,用户需要编写动态组网启动脚本,详见 动态组网启动

该样例需要在8卡环境下运行。

>>> import os
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.alltoall = ops.AlltoAll(split_count = 8, split_dim = -2, concat_dim = -1)
...
...     def construct(self, x):
...         out = self.alltoall(x)
...         return out
...
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> net = Net()
>>> rank_id = int(os.getenv("RANK_ID"))
>>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype = ms.float32)
>>> output = net(input_x)
>>> print(output)
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
教程样例: