mindspore.ops.matmul_reduce_scatter
- mindspore.ops.matmul_reduce_scatter(input, x2, group, world_size, *, reduce_op='sum', bias=None, comm_turn=0, trans_input=False, trans_x2=False) → Tensor
In the tensor parallel (TP) sharding scenario, matmul and reduce_scatter are fused, and communication and computation are pipelined in parallel within the fused operator (a sketch of the equivalent unfused semantics follows the warning below).
Warning
This is an experimental API that is subject to change or deletion.
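For intuition, here is a minimal sketch of the equivalent unfused computation, assuming the standard `ops.ReduceScatter` primitive. It illustrates the semantics only; it is not the fused implementation and does not pipeline communication with computation:

```python
from mindspore import ops

def matmul_reduce_scatter_reference(input, x2, group, world_size):
    """Unfused reference: a local matmul followed by a reduce_scatter over `group`."""
    # Local matmul on each rank: (m, k) @ (k, n) -> (m, n).
    local = ops.matmul(input, x2)
    # Sum-reduce across all ranks, scattering row blocks so that each rank
    # keeps an (m // world_size, n) slice of the reduced result.
    assert input.shape[0] % world_size == 0
    reduce_scatter = ops.ReduceScatter(op=ops.ReduceOp.SUM, group=group)
    return reduce_scatter(local)
```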
- Parameters
  - input (Tensor) – The left matrix of matmul; the dtype supports float16 and bfloat16, the shape supports 2 dimensions, and the data format supports ND.
  - x2 (Tensor) – The right matrix of matmul; the dtype needs to be consistent with `input`, the shape supports 2 dimensions, and the data format supports ND.
  - group (str) – Communication group name; it can be created by the `create_group` method, or use the default group `mindspore.communication.GlobalComm.WORLD_COMM_GROUP` (see the sketch after this list).
  - world_size (int) – The total number of ranks in the communication group; it should be consistent with the number of devices actually running, supporting `2`, `4`, and `8`.
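As a hedged illustration of the `group` parameter, the following sketch shows both ways of obtaining a group name: the default world group, and a custom group built with `mindspore.communication.create_group` (the group name `"tp_group_01"` and the rank list are hypothetical values, not part of this API):

```python
from mindspore.communication import init, create_group, GlobalComm

init()
# Option 1: the default group covering all ranks.
group = GlobalComm.WORLD_COMM_GROUP
# Option 2 (hypothetical values): a custom group over ranks 0 and 1.
create_group("tp_group_01", [0, 1])
custom_group = "tp_group_01"
```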
- Keyword Arguments
  - reduce_op (str, optional) – The reduction operation. Currently only `'sum'` is supported. Default: `'sum'`.
  - bias (Tensor, optional) – Currently only `None` is supported. Default: `None`.
  - comm_turn (int, optional) – Indicates the granularity of communication between ranks. Currently only `0` is supported. Default: `0`.
  - trans_input (bool, optional) – Indicates whether `input` is transposed. Currently only `False` is supported. Default: `False`.
  - trans_x2 (bool, optional) – Indicates whether `x2` is transposed. Default: `False` (see the hedged sketch after this list).
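Of the keyword arguments, only `trans_x2` is currently configurable. Under the usual matmul-transpose convention (an assumption here; the shape for the transposed case is not stated above), `x2` would then be supplied with the layout (n, k):

```python
# Hedged sketch (assumption): with trans_x2=True, x2 is presumably supplied as
# (n, k) and transposed inside the operator, i.e. the equivalent of
# reduce_scatter(ops.matmul(input, x2.transpose())).
import numpy as np
import mindspore as ms
from mindspore import ops

ms.communication.init()
group = ms.communication.GlobalComm.WORLD_COMM_GROUP
world_size = ms.communication.get_group_size()
input = ms.Tensor(np.random.randn(1024, 256).astype(np.float32), dtype=ms.float16)
x2_t = ms.Tensor(np.random.randn(512, 256).astype(np.float32), dtype=ms.float16)  # (n, k)
output = ops.matmul_reduce_scatter(input, x2_t, group, world_size, trans_x2=True)
```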
- Returns
output (Tensor) - The result of the matmul and reduce_scatter fusion calculation.
Note
When using this interface, please ensure that the driver firmware package and the CANN package are both the matching 8.0.RC2 version or higher; otherwise an error such as BUS ERROR will be reported.
The shape of `input` is (m, k) and the shape of `x2` is (k, n); the two k axes are required to be equal, the value range of k is [256, 65535), and m is required to be an integer multiple of `world_size`. The shape of `output` is (m // world_size, n), since the reduced result is scattered across the ranks (consistent with the example below, which prints (512, 512) for m = 1024, n = 512, world_size = 2).
All such fusion operators within a model must use the same communication group.
- Raises
  - TypeError – Any argument is of the wrong type.
  - RuntimeError – The dtype of `input` or `x2` is neither float16 nor bfloat16.
  - RuntimeError – The dtypes of `input` and `x2` are different.
  - RuntimeError – The shape of `input` or `x2` is not two-dimensional.
  - RuntimeError – The k axes of the `input` shape and the `x2` shape are not equal.
  - RuntimeError – k is less than `256` or greater than or equal to `65535`.
  - RuntimeError – `bias` is not None.
  - RuntimeError – `group` does not exist.
  - RuntimeError – `world_size` is inconsistent with the actual number of running devices.
  - RuntimeError – `world_size` is not equal to `2`, `4`, or `8`.
  - RuntimeError – `reduce_op` is not `'sum'`.
  - RuntimeError – `trans_input` is `True`.
- Supported Platforms:
Ascend
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method, which has no third-party or configuration file dependencies. Please see the msrun startup documentation for more details.
This example should be run with 2 devices.
>>> import mindspore as ms
>>> from mindspore import ops
>>> import numpy as np
>>> ms.communication.init()
>>> rank = ms.communication.get_rank()
>>> np.random.seed(rank)
>>> input = ms.Tensor(np.random.randn(1024, 256).astype(np.float32), dtype=ms.float16)
>>> x2 = ms.Tensor(np.random.randn(256, 512).astype(np.float32), dtype=ms.float16)
>>> group = ms.communication.GlobalComm.WORLD_COMM_GROUP
>>> world_size = ms.communication.get_group_size()
>>> reduce_op = ops.ReduceOp.SUM
>>> output = ops.matmul_reduce_scatter(
...     input,
...     x2,
...     group,
...     world_size,
...     reduce_op=reduce_op,
...     bias=None,
...     comm_turn=0,
...     trans_input=False,
...     trans_x2=False,
... )
>>> print(output.shape)
(512, 512)