mindspore.ops.matmul_reduce_scatter =================================== .. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg :target: https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/ops/mindspore.ops.func_matmul_reduce_scatter.rst :alt: 查看源文件 .. py:function:: mindspore.ops.matmul_reduce_scatter(input, x2, group, world_size, *, reduce_op='sum', bias=None,\ comm_turn=0, trans_input=False, trans_x2=False) TP 切分场景下,实现 matmul 和 reducescatter 的融合,融合算子内部实现通信和计算流水并行。 .. math:: output = reducescatter(input@x2) .. warning:: 这是一个实验性 API,后续可能修改或删除。 参数: - **input** (Tensor) - matmul 的左矩阵,dtype 支持 float16、bfloat16,shape 支持二维,数据格式支持 ND。 - **x2** (Tensor) - matmul 的右矩阵,dtype 需要和 ``input`` 一致,shape 支持二维,数据格式支持 ND。 - **group** (str) - 通信组名称,可以由 ``create_group`` 方法创建,或者使用默认组 ``mindspore.communication.GlobalComm.WORLD_COMM_GROUP`` 。 - **world_size** (int) - 通信组的总进程数,要求与实际运行的卡数一致,支持 ``2`` 、 ``4`` 、 ``8`` 。 关键字参数: - **reduce_op** (str, 可选) - reduce 操作类型。当前仅支持 ``'sum'`` 。默认 ``'sum'`` 。 - **bias** (Tensor, 可选) - 当前仅支持 ``None`` 。默认 ``None`` 。 - **comm_turn** (int, 可选) - 表示进程间通信切分粒度。当前仅支持 ``0`` 。默认 ``0`` 。 - **trans_input** (bool, 可选) - 表示 ``input`` 是否转置。当前仅支持 ``False`` 。默认 ``False`` 。 - **trans_x2** (bool, 可选) - 表示 ``x2`` 是否转置。默认 ``False`` 。 返回: - **output** (Tensor) - matmul 和 reducescatter 融合计算的结果。 .. note:: - 使用该接口时,请确保驱动固件包和 CANN 包都为配套的 8.0.RC2 版本或者配套的更高版本,否则将会引发报错,比如 BUS ERROR 等。 - ``input`` 的 shape 为 (m, k), ``x2`` 的 shape 为 (k, n),要求 k 相等,且 k 的取值范围为 [256, 65535),要求 m 是 ``world_size`` 的\ 整数倍。 ``output`` 的 shape 为 (m * world_size, n)。 - 一个模型中的通算融合算子仅支持相同通信组。 异常: - **TypeError** - 参数的类型不对。 - **RuntimeError** - ``input`` 或 ``x2`` 的 dtype 不是 float16 或 bfloat16。 - **RuntimeError** - ``input`` 和 ``x2`` 的 dtype 不一致。 - **RuntimeError** - ``input`` 或 ``x2`` 的 shape 不是二维。 - **RuntimeError** - ``input`` shape 和 ``x2`` shape 的 k 不相等。 - **RuntimeError** - k 小于 ``256`` 或大于等于 ``65535`` 。 - **RuntimeError** - ``bias`` 不是 ``None`` 。 - **RuntimeError** - ``group`` 不存在。 - **RuntimeError** - ``world_size`` 与实际运行的卡数不一致。 - **RuntimeError** - ``world_size`` 不等于 ``2`` 、 ``4`` 、 ``8`` 。 - **RuntimeError** - ``reduce_op`` 不是 ``'sum'`` 。 - **RuntimeError** - ``trans_input`` 为 ``True`` 。 样例: .. note:: .. include:: mindspore.ops.comm_note.txt 该样例需要在 2 卡环境下运行。