mindspore.ops.all_gather_matmul
- mindspore.ops.all_gather_matmul(input, x2, group, world_size, *, bias=None, gather_index=0, gather_output=True, comm_turn=0, trans_input=False, trans_x2=False) → Tensor
In the tensor-parallel (TP) sharding scenario, allgather and matmul are fused into a single operator, and the communication and computation pipelines run in parallel inside it.
\[
\begin{aligned}
output &= allgather(input)\,@\,x2 \\
gather\_out &= allgather(input)
\end{aligned}
\]
Warning
This is an experimental API that is subject to change or deletion.
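Conceptually, the fused operator computes the same result as running allgather and matmul separately. The sketch below spells out the unfused equivalent; it is illustrative only (reference_all_gather_matmul is a hypothetical helper name, and it assumes ms.communication.init() has already been called on every rank):

from mindspore import ops

def reference_all_gather_matmul(input, x2):
    # AllGather concatenates the (m, k) shards from all ranks along axis 0,
    # yielding a (m * world_size, k) matrix.
    gather_out = ops.AllGather()(input)
    # Multiplying by the (k, n) right matrix gives (m * world_size, n).
    output = ops.matmul(gather_out, x2)
    return output, gather_out

The fused operator avoids materializing this two-step pipeline serially by overlapping the communication with the matmul computation.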
- Parameters
input (Tensor) – The left matrix of matmul, the dtype supports float16 and bfloat16, the shape supports 2 dimensions, and the data format supports ND.
x2 (Tensor) – The right matrix of matmul, the dtype needs to be consistent with input, the shape supports 2 dimensions, and the data format supports ND.
group (str) – Communication group name, which can be created by the create_group method, or use the default group mindspore.communication.GlobalComm.WORLD_COMM_GROUP.
world_size (int) – The total number of ranks in the communication group. It should be consistent with the number of devices actually running and supports 2, 4, and 8.
- Keyword Arguments
bias (Tensor, optional) – Currently only None is supported. Default: None.
gather_index (int, optional) – Indicates the allgather operation object: 0 means gather input, 1 means gather x2. Currently only 0 is supported. Default: 0.
gather_output (bool, optional) – Indicates whether the gather output is required. Default: True.
comm_turn (int, optional) – Indicates the granularity of communication between ranks. Currently only 0 is supported. Default: 0.
trans_input (bool, optional) – Indicates whether input is transposed. Currently only False is supported. Default: False.
trans_x2 (bool, optional) – Indicates whether x2 is transposed. Default: False.
- Returns
output (Tensor) - The result of allgather and matmul fusion calculations.
gather_out (Tensor) - The result of allgather. If gather_output is False, gather_out returns a tensor with shape 0.
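For example (a minimal sketch reusing the setup from the Examples section below), disabling the gather output leaves the second result empty:

>>> output, gather_out = ops.all_gather_matmul(input, x2, group, world_size,
...                                            gather_output=False)
>>> # gather_out is a shape-0 (empty) tensor in this case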
Note
When using this interface, ensure that the driver and firmware packages and the CANN package are both the matching 8.0.RC2 version or higher; otherwise an error, such as BUS ERROR, will be reported.
The shape of input is (m, k) and the shape of x2 is (k, n); the k axes of the two matrices must be equal, and the value range of k is [256, 65535). The shape of output is (m * world_size, n), and the shape of gather_out is (m * world_size, k). A shape sanity check is sketched below. Fusion operators of this kind within a model only support the same communication group.
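As a quick sanity check before calling the operator, the constraints above can be validated on the host. This is an illustrative helper, not part of the API:

def check_all_gather_matmul_shapes(input_shape, x2_shape, world_size):
    # input is (m, k), x2 is (k, n); the k axes must match.
    m, k = input_shape
    k2, n = x2_shape
    assert k == k2, "k axes of input and x2 must be equal"
    assert 256 <= k < 65535, "k must lie in [256, 65535)"
    assert world_size in (2, 4, 8), "world_size must be 2, 4, or 8"
    # Expected result shapes: (output, gather_out).
    return (m * world_size, n), (m * world_size, k)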
- Raises
TypeError – Any arg is of wrong type.
RuntimeError – The dtype of input or x2 is neither float16 nor bfloat16.
RuntimeError – The dtypes of input and x2 are different.
RuntimeError – The shape of input or x2 is not two-dimensional.
RuntimeError – The k axes of the input and x2 shapes are not equal.
RuntimeError – k is less than 256 or greater than or equal to 65535.
RuntimeError – bias is not None.
RuntimeError – group does not exist.
RuntimeError – world_size is inconsistent with the actual number of running cards.
RuntimeError – world_size is not equal to 2, 4, or 8.
RuntimeError – gather_index is not 0.
RuntimeError – trans_input is True.
- Supported Platforms:
Ascend
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method, which has no third-party or configuration-file dependencies. See the msrun startup documentation for more details.
This example should be run with 2 devices.
>>> import mindspore as ms
>>> import numpy as np
>>> from mindspore import ops
>>> ms.communication.init()
>>> rank = ms.communication.get_rank()
>>> np.random.seed(rank)
>>> input = ms.Tensor(np.random.randn(128, 256).astype(np.float32), dtype=ms.float16)
>>> x2 = ms.Tensor(np.random.randn(256, 512).astype(np.float32), dtype=ms.float16)
>>> group = ms.communication.GlobalComm.WORLD_COMM_GROUP
>>> world_size = ms.communication.get_group_size()
>>> output, gather_out = ops.all_gather_matmul(
...     input,
...     x2,
...     group,
...     world_size,
...     bias=None,
...     gather_index=0,
...     gather_output=True,
...     comm_turn=0,
...     trans_input=False,
...     trans_x2=False,
... )
>>> print(output.shape)
(256, 512)
>>> print(gather_out.shape)
(256, 256)
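With world_size equal to 2, each rank's (128, 256) input shard is gathered into a (256, 256) matrix, so output has shape (256, 256) @ (256, 512) = (256, 512) and gather_out has shape (256, 256), matching the printed results.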