mindspore.ops.CollectiveScatter
- class mindspore.ops.CollectiveScatter(src_rank=0, group=GlobalComm.WORLD_COMM_GROUP)
Scatters a tensor evenly across the processes in the specified communication group.
Note
This interface only supports Tensor input and only scatters evenly. Only the tensor held by the process whose global rank is src_rank is scattered; the inputs on the other processes are not used as scatter data.
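The data movement described in this note can be modeled in a single process with NumPy. This is a sketch for intuition only, not MindSpore's implementation: `simulate_collective_scatter` is a hypothetical helper, each list entry stands in for one rank's local tensor, and equal contiguous chunks along dimension 0 are assumed to be delivered to the ranks in rank order, which matches the 2-device example on this page.

```python
import numpy as np


def simulate_collective_scatter(inputs_by_rank, src_rank=0):
    """Hypothetical single-process model of an even scatter along dimension 0.

    inputs_by_rank: list where entry i is the tensor held by rank i.
    Only inputs_by_rank[src_rank] is read; the other ranks' tensors are ignored.
    Returns a list where entry i is the chunk that rank i would receive.
    """
    rank_size = len(inputs_by_rank)
    src = np.asarray(inputs_by_rank[src_rank])
    # The operator only scatters evenly, so dimension 0 must divide by rank_size.
    if src.shape[0] % rank_size != 0:
        raise ValueError(
            f"dimension 0 ({src.shape[0]}) is not divisible by rank size ({rank_size})"
        )
    # An integer section count makes np.split return rank_size equal, contiguous chunks.
    return np.split(src, rank_size, axis=0)


x = np.arange(8, dtype=np.float32).reshape(4, 2)
# Rank 1 holds different data, but it is never read because src_rank=0.
chunks = simulate_collective_scatter([x, np.zeros_like(x)], src_rank=0)
for rank, chunk in enumerate(chunks):
    print(f"rank {rank}: {chunk.tolist()}")
# rank 0: [[0.0, 1.0], [2.0, 3.0]]
# rank 1: [[4.0, 5.0], [6.0, 7.0]]
```

Reading only `inputs_by_rank[src_rank]` mirrors the note: every process calls the operator, but only the source process contributes data.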
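- Parameters
src_rank (int, optional) - The global rank of the process whose tensor is scattered. Default: 0.
group (str, optional) - The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.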
- Inputs:
input_x (Tensor) - The input tensor to be scattered. The shape of the tensor is \((x_1, x_2, ..., x_R)\).
- Outputs:
Tensor, the shape of output is \((x_1/rank\_size, x_2, ..., x_R)\), where \(rank\_size\) is the number of processes in group. Dimension 0 of the output equals dimension 0 of the input divided by \(rank\_size\); the other dimensions are unchanged.
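As a quick check of the shape arithmetic, the hypothetical helper below computes the per-rank output shape from the input shape and the group size. It assumes the divisor is the number of processes in group, which is what the 2-device example on this page implies (4 input rows become 2 rows per rank); in a real job that count would typically come from `mindspore.communication.get_group_size()`. It is plain Python and does not call MindSpore.

```python
def collective_scatter_output_shape(input_shape, rank_size):
    """Hypothetical helper: per-rank output shape of an even scatter along dimension 0.

    input_shape: tuple (x_1, x_2, ..., x_R) of the source tensor.
    rank_size: number of processes in the communication group (assumed divisor).
    """
    x1, *rest = input_shape
    # Check rank_size first so the modulo below never divides by zero.
    if rank_size <= 0 or x1 % rank_size != 0:
        raise ValueError(
            f"dimension 0 ({x1}) must divide evenly by rank_size ({rank_size})"
        )
    # Only dimension 0 shrinks; x_2, ..., x_R are carried over unchanged.
    return (x1 // rank_size, *rest)


print(collective_scatter_output_shape((4, 2), 2))  # (2, 2), as in the 2-device example
```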
- Raises
TypeError – If group is not a str.
RuntimeError – If the device target or backend is invalid, or distributed initialization fails.
ValueError – If the local rank id of the calling process in the group is larger than the group's rank size.
- Supported Platforms:
Ascend
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For Ascend/GPU/CPU devices, the msrun startup method is recommended because it has no third-party or configuration-file dependencies. See msrun startup for more details.
This example should be run with 2 devices.
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import Tensor
>>> from mindspore.communication.management import init, get_rank
>>> from mindspore import ops
>>> # Launch 2 processes.
>>> init()
>>> class CollectiveScatterNet(nn.Cell):
...     def __init__(self):
...         super(CollectiveScatterNet, self).__init__()
...         # The process with global rank 0 is the source whose tensor is split.
...         self.collective_scatter = ops.CollectiveScatter(src_rank=0)
...
...     def construct(self, x):
...         return self.collective_scatter(x)
...
>>> input_x = Tensor(np.arange(8).reshape([4, 2]).astype(np.float32))
>>> net = CollectiveScatterNet()
>>> output = net(input_x)
>>> print(output)
Process with rank 0: [[0. 1.]
                      [2. 3.]]
Process with rank 1: [[4. 5.]
                      [6. 7.]]
- Tutorial Examples: