mindspore.ops.CollectiveGather
- class mindspore.ops.CollectiveGather(dest_rank, group=GlobalComm.WORLD_COMM_GROUP)[source]
Gathers tensors from all processes in the specified communication group. The tensors are concatenated along dimension 0.
Note
Only the process whose global rank is dest_rank keeps the gathered tensor. Every other process keeps a tensor of shape [1], which has no mathematical meaning.
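The behaviour described in this note can be illustrated without any devices. The following is a minimal single-process sketch using NumPy, not the operator itself: the helper name simulate_collective_gather is hypothetical, and filling the shape-[1] placeholder with zero is an assumption made only for illustration.

```python
import numpy as np


def simulate_collective_gather(per_rank_inputs, dest_rank):
    """Return what each rank would hold after the gather (single-process sketch).

    per_rank_inputs: list of arrays, one per simulated rank.
    dest_rank: index of the rank that keeps the gathered result.
    """
    # The gathered result is the concatenation of all inputs along dimension 0.
    gathered = np.concatenate(per_rank_inputs, axis=0)
    outputs = []
    for rank in range(len(per_rank_inputs)):
        if rank == dest_rank:
            outputs.append(gathered)
        else:
            # Placeholder of shape [1]; its value carries no meaning
            # (zero is an assumption for this sketch).
            outputs.append(np.zeros(1, dtype=gathered.dtype))
    return outputs


x = np.arange(4, dtype=np.float32).reshape(2, 2)
outputs = simulate_collective_gather([x, x], dest_rank=0)
print(outputs[0].shape)  # (4, 2)
print(outputs[1].shape)  # (1,)
```

With two simulated ranks that each hold a 2x2 tensor, rank 0 ends up with a 4x2 tensor and rank 1 with a one-element placeholder.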
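- Parameters
dest_rank (int) – The global rank of the process that keeps the gathered tensor.
group (str) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.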
- Inputs:
input_x (Tensor) - The tensor to be gathered. The shape of tensor is \((x_1, x_2, ..., x_R)\).
- Outputs:
Tensor, the shape of output is \((\sum x_1, x_2, ..., x_R)\). Dimension 0 of the output equals the sum of dimension 0 of the input tensors across all processes in the group; the other dimensions are unchanged.
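The shape rule above can be checked with a few lines of plain Python. This is a sketch of the arithmetic only: gathered_shape is a hypothetical helper, not part of MindSpore, and it assumes that every process supplies a tensor whose trailing dimensions match.

```python
def gathered_shape(input_shapes):
    """Compute the output shape on dest_rank from the per-process input shapes."""
    if not input_shapes:
        raise ValueError("at least one input shape is required")
    first = tuple(input_shapes[0])
    for shape in input_shapes:
        # Assumption: all dimensions except dimension 0 must agree.
        if tuple(shape[1:]) != first[1:]:
            raise ValueError("trailing dimensions must match across processes")
    # Dimension 0 is summed across processes; the rest is kept as-is.
    return (sum(shape[0] for shape in input_shapes),) + first[1:]


print(gathered_shape([(2, 2), (2, 2)]))  # (4, 2)
```

For example, four processes that each contribute a tensor of shape (3, 5, 7) would yield a gathered shape of (12, 5, 7) under this rule.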
- Raises
TypeError – If group is not a str.
RuntimeError – If device target is invalid, or backend is invalid, or distributed initialization fails.
ValueError – If the local rank id of the calling process in the group is larger than the group's rank size.
- Supported Platforms:
Ascend
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For Ascend/GPU/CPU devices, it is recommended to use the msrun startup method, which has no third-party or configuration file dependencies. See the msrun startup documentation for more details.
This example should be run with 2 devices.
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> from mindspore import ops
>>> # Launch 2 processes.
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> class CollectiveGatherNet(nn.Cell):
...     def __init__(self):
...         super(CollectiveGatherNet, self).__init__()
...         self.collective_gather = ops.CollectiveGather(dest_rank=0)
...
...     def construct(self, x):
...         return self.collective_gather(x)
...
>>> input = Tensor(np.arange(4).reshape([2, 2]).astype(np.float32))
>>> net = CollectiveGatherNet()
>>> output = net(input)
>>> print(output)
Process with rank 0: [[0. 1.],
                      [2. 3.],
                      [0. 1.],
                      [2. 3.]]
Process with rank 1: [0.]