
mindspore.ops.all_gather_matmul(input, x2, group, world_size, *, bias=None, gather_index=0, gather_output=True, comm_turn=0, trans_input=False, trans_x2=False) -> Tensor

In tensor-parallel (TP) sharding scenarios, this operator fuses allgather and matmul into a single operator and pipelines communication with computation inside it.

\[ \begin{align}\begin{aligned}output = allgather(input)@x2\\gather\_out = allgather(input)\end{aligned}\end{align} \]
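The formulas above can be sketched in a single process with NumPy. This is only a reference model of the semantics, not the fused kernel: the function name `all_gather_matmul_reference` and the list `per_rank_inputs` are hypothetical, and the sketch assumes that allgather concatenates each rank's `input` along dimension 0 in rank order. The sizes are toy values chosen for readability and do not respect the operator's real limits on `k`.

```python
import numpy as np

def all_gather_matmul_reference(per_rank_inputs, x2):
    """Single-process NumPy model of the formulas (not the fused operator).

    per_rank_inputs: list of (m, k) arrays, one per rank; a hypothetical
    stand-in for the tensor each device would pass as `input`.
    x2: the (k, n) right matrix held by the calling rank.
    """
    # Assumption: allgather concatenates the inputs along dim 0 in rank order.
    gather_out = np.concatenate(per_rank_inputs, axis=0)  # (m * world_size, k)
    output = gather_out @ x2                               # (m * world_size, n)
    return output, gather_out

world_size, m, k, n = 2, 4, 3, 5
rng = np.random.default_rng(0)
inputs = [rng.standard_normal((m, k)) for _ in range(world_size)]
x2 = rng.standard_normal((k, n))

output, gather_out = all_gather_matmul_reference(inputs, x2)
print(output.shape, gather_out.shape)  # (8, 5) (8, 3)
```

Each rank would obtain the same `gather_out`, while `output` depends on the `x2` held by that rank.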


Warning: This is an experimental API that is subject to change or deletion.

Parameters
  • input (Tensor) – The left matrix of the matmul. Supported dtypes are float16 and bfloat16. The tensor must be 2-dimensional and in ND data format.

  • x2 (Tensor) – The right matrix of the matmul. Its dtype must match that of input. The tensor must be 2-dimensional and in ND data format.

  • group (str) – Name of the communication group. It can be created with the create_group method, or the default group mindspore.communication.GlobalComm.WORLD_COMM_GROUP can be used.

  • world_size (int) – The total number of ranks in the communication group. It must match the number of devices actually running. Supported values are 2, 4, and 8.

Keyword Arguments
  • bias (Tensor, optional) – Currently only None is supported. Default: None.

  • gather_index (int, optional) – Selects which operand is all-gathered: 0 gathers input, 1 gathers x2. Currently only 0 is supported. Default: 0.

  • gather_output (bool, optional) – Whether the gathered tensor is returned. Default: True.

  • comm_turn (int, optional) – The granularity of communication between ranks. Currently only 0 is supported. Default: 0.

  • trans_input (bool, optional) – Whether input is transposed. Currently only False is supported. Default: False.

  • trans_x2 (bool, optional) – Whether x2 is transposed. Default: False.


Returns
  • output (Tensor) – The result of the fused allgather and matmul computation.

  • gather_out (Tensor) – The result of allgather. If gather_output is False, gather_out is a tensor with shape 0.


Note
  • This interface requires matching driver firmware and CANN packages, both at version 8.0.RC2 or later. Otherwise an error such as BUS ERROR is reported.

  • The shape of input is (m, k) and the shape of x2 is (k, n). The k dimensions of the two tensors must be equal, and k must lie in the range [256, 65535). The shape of output is (m * world_size, n), and the shape of gather_out is (m * world_size, k).

  • Fusion operators used within a single model must all share the same communication group.
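The shape rules above can be checked before launching a job. The helper below is a hypothetical pre-flight check, not part of MindSpore. It assumes `trans_x2=False` and `gather_output=True`, and it encodes only the constraints stated on this page: 2-dimensional operands, a `world_size` of 2, 4, or 8, matching `k`, and `k` in [256, 65535).

```python
def check_all_gather_matmul_shapes(input_shape, x2_shape, world_size):
    """Hypothetical helper: validate the documented constraints and
    return the expected (output_shape, gather_out_shape)."""
    if len(input_shape) != 2 or len(x2_shape) != 2:
        raise ValueError("input and x2 must both be 2-dimensional")
    if world_size not in (2, 4, 8):
        raise ValueError("world_size must be 2, 4, or 8")
    m, k = input_shape
    k2, n = x2_shape
    if k != k2:
        raise ValueError(f"k mismatch: input has k={k}, x2 has k={k2}")
    if not 256 <= k < 65535:
        raise ValueError(f"k={k} is outside the supported range [256, 65535)")
    # Gathering along dim 0 multiplies the row count by world_size.
    return (m * world_size, n), (m * world_size, k)

# Shapes of a per-rank (128, 256) input and a (256, 512) x2 on 2 devices.
print(check_all_gather_matmul_shapes((128, 256), (256, 512), 2))
# ((256, 512), (256, 256))
```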

Supported Platforms:
  Ascend

Before running the following examples, you need to configure the communication environment variables.

For Ascend/GPU/CPU devices, the msrun launcher is recommended because it needs no third-party dependencies or configuration files. See the msrun startup documentation for details.

This example should be run with 2 devices.

>>> import mindspore as ms
>>> import numpy as np
>>> from mindspore import ops
>>> ms.communication.init()
>>> rank = ms.communication.get_rank()
>>> np.random.seed(rank)
>>> input = ms.Tensor(np.random.randn(128, 256).astype(np.float32), dtype=ms.float16)
>>> x2 = ms.Tensor(np.random.randn(256, 512).astype(np.float32), dtype=ms.float16)
>>> group = ms.communication.GlobalComm.WORLD_COMM_GROUP
>>> world_size = ms.communication.get_group_size()
>>> output, gather_out = ops.all_gather_matmul(
...    input,
...    x2,
...    group,
...    world_size,
...    bias=None,
...    gather_index=0,
...    gather_output=True,
...    comm_turn=0,
...    trans_input=False,
...    trans_x2=False,
... )
>>> print(output.shape)
(256, 512)
>>> print(gather_out.shape)
(256, 256)