文档反馈

问题文档片段

问题文档片段包含公式时,显示为空格。

提交类型
issue

有点复杂...

找人问问吧。

PR

小问题,全程线上修改...

一键搞定!

请选择提交类型

问题类型
规范和低错类

- 规范和低错类:

- 错别字或拼写错误,标点符号使用错误、公式错误或显示异常。

- 链接错误、空单元格、格式错误。

- 英文中包含中文字符。

- 界面和描述不一致,但不影响操作。

- 表述不通顺,但不影响理解。

- 版本号不匹配:如软件包名称、界面版本号。

易用性

- 易用性:

- 关键步骤错误或缺失,无法指导用户完成任务。

- 缺少主要功能描述、关键词解释、必要前提条件、注意事项等。

- 描述内容存在歧义指代不明、上下文矛盾。

- 逻辑不清晰,该分类、分项、分步骤的没有给出。

正确性

- 正确性:

- 技术原理、功能、支持平台、参数类型、异常报错等描述和软件实现不一致。

- 原理图、架构图等存在错误。

- 命令、命令参数等错误。

- 代码片段错误。

- 命令无法完成对应功能。

- 界面错误,无法指导操作。

- 代码样例运行报错、运行结果不符。

风险提示

- 风险提示:

- 对重要数据或系统存在风险的操作,缺少安全提示。

内容合规

- 内容合规:

- 违反法律法规,涉及政治、领土主权等敏感词。

- 内容侵权。

请选择问题类型

问题描述

点击输入详细问题描述,以帮助我们快速定位问题。

mindspore.ops.matmul_reduce_scatter

查看源文件
mindspore.ops.matmul_reduce_scatter(input, x2, group, world_size, *, reduce_op='sum', bias=None, comm_turn=0, trans_input=False, trans_x2=False)[源代码]

TP 切分场景下,实现 matmul 和 reducescatter 的融合,融合算子内部实现通信和计算流水并行。

output=reducescatter(input@x2)

警告

这是一个实验性 API,后续可能修改或删除。

参数:
  • input (Tensor) - matmul 的左矩阵,dtype 支持 float16、bfloat16,shape 支持二维,数据格式支持 ND。

  • x2 (Tensor) - matmul 的右矩阵,dtype 需要和 input 一致,shape 支持二维,数据格式支持 ND。

  • group (str) - 通信组名称,可以由 create_group 方法创建,或者使用默认组 mindspore.communication.GlobalComm.WORLD_COMM_GROUP

  • world_size (int) - 通信组的总进程数,要求与实际运行的卡数一致,支持 248

关键字参数:
  • reduce_op (str, 可选) - reduce 操作类型。当前仅支持 'sum' 。默认 'sum'

  • bias (Tensor, 可选) - 当前仅支持 None 。默认 None

  • comm_turn (int, 可选) - 表示进程间通信切分粒度。当前仅支持 0 。默认 0

  • trans_input (bool, 可选) - 表示 input 是否转置。当前仅支持 False 。默认 False

  • trans_x2 (bool, 可选) - 表示 x2 是否转置。默认 False

返回:
  • output (Tensor) - matmul 和 reducescatter 融合计算的结果。

说明

  • 使用该接口时,请确保驱动固件包和 CANN 包都为配套的 8.0.RC2 版本或者配套的更高版本,否则将会引发报错,比如 BUS ERROR 等。

  • input 的 shape 为 (m, k), x2 的 shape 为 (k, n),要求 k 相等,且 k 的取值范围为 [256, 65535),要求 m 是 world_size 的整数倍。 output 的 shape 为 (m * world_size, n)。

  • 一个模型中的通算融合算子仅支持相同通信组。

异常:
  • TypeError - 参数的类型不对。

  • RuntimeError - inputx2 的 dtype 不是 float16 或 bfloat16。

  • RuntimeError - inputx2 的 dtype 不一致。

  • RuntimeError - inputx2 的 shape 不是二维。

  • RuntimeError - input shape 和 x2 shape 的 k 不相等。

  • RuntimeError - k 小于 256 或大于等于 65535

  • RuntimeError - bias 不是 None

  • RuntimeError - group 不存在。

  • RuntimeError - world_size 与实际运行的卡数不一致。

  • RuntimeError - world_size 不等于 248

  • RuntimeError - reduce_op 不是 'sum'

  • RuntimeError - trans_inputTrue

支持平台:

Ascend

样例:

说明

运行以下样例之前,需要配置好通信环境变量。

针对Ascend/GPU/CPU设备,推荐使用msrun启动方式,无第三方以及配置文件依赖。详见 msrun启动

该样例需要在 2 卡环境下运行。

>>> import mindspore as ms
>>> from mindspore import ops
>>> import numpy as np
>>> ms.communication.init()
>>> rank = ms.communication.get_rank()
>>> np.random.seed(rank)
>>> input = ms.Tensor(np.random.randn(1024, 256).astype(np.float32), dtype=ms.float16)
>>> x2 = ms.Tensor(np.random.randn(256, 512).astype(np.float32), dtype=ms.float16)
>>> group = ms.communication.GlobalComm.WORLD_COMM_GROUP
>>> world_size = ms.communication.get_group_size()
>>> reduce_op = ops.ReduceOp.SUM
>>> output = ops.matmul_reduce_scatter(
...    input,
...    x2,
...    group,
...    world_size,
...    reduce_op=reduce_op,
...    bias=None,
...    comm_turn=0,
...    trans_input=False,
...    trans_x2=False,
... )
>>> print(output.shape)
(512, 512)