mindspore.ops.NeighborExchangeV2
- class mindspore.ops.NeighborExchangeV2(send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group=GlobalComm.WORLD_COMM_GROUP)
NeighborExchangeV2 is a collective communication operation.
NeighborExchangeV2 sends data from the local rank to the ranks specified in send_rank_ids, while receiving data from the ranks in recv_rank_ids. Please refer to Distributed Set Communication Primitives - NeighborExchangeV2 to learn how the data is exchanged between neighboring devices.
Note
This operator requires a full-mesh network topology: each device must have the same vlan id, and the ip & mask must be in the same subnet. Please check the details.
- Parameters
send_rank_ids (list(int)) – Ranks to which the data is sent. The 8 rank_ids represent 8 directions; if no data is sent in a direction, set the corresponding rank_id to -1 (see the indexing sketch after this list).
recv_rank_ids (list(int)) – Ranks from which the data is received. The 8 rank_ids represent 8 directions; if no data is received from a direction, set the corresponding rank_id to -1.
send_lens (list(int)) – Lengths of the data sent to send_rank_ids. The 4 numbers represent the lengths of [send_top, send_bottom, send_left, send_right].
recv_lens (list(int)) – Lengths of the data received from recv_rank_ids. The 4 numbers represent the lengths of [recv_top, recv_bottom, recv_left, recv_right].
data_format (str) – Data format. Only “NCHW” is supported currently.
group (str, optional) – The communication group to work on. Default: “GlobalComm.WORLD_COMM_GROUP”, which means “hccl_world_group” in Ascend, and “nccl_world_group” in GPU.
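The sketch below shows how the 8-element rank_id lists and the 4-element length lists fit together. It is a minimal illustration, not part of the official example: the direction ordering in the comments is an assumption inferred from the example below and the distributed communication tutorial, the index constants are hypothetical helpers, and the snippet is meant to run inside a script whose communication environment has already been initialized.

import mindspore.ops as ops

# Assumed direction order for the 8 rank_ids (an assumption, verify against the tutorial):
# index: 0=top, 1=top-right, 2=right, 3=bottom-right,
#        4=bottom, 5=bottom-left, 6=left, 7=top-left
TOP, TOP_RIGHT, RIGHT, BOTTOM_RIGHT, BOTTOM, BOTTOM_LEFT, LEFT, TOP_LEFT = range(8)

send_rank_ids = [-1] * 8
recv_rank_ids = [-1] * 8
send_rank_ids[BOTTOM] = 1   # send rows from the bottom of the local slice to rank 1
recv_rank_ids[BOTTOM] = 1   # receive rows from rank 1 and append them at the bottom

# send_lens / recv_lens are ordered [top, bottom, left, right]
exchange = ops.NeighborExchangeV2(send_rank_ids=send_rank_ids,
                                  send_lens=[0, 1, 0, 0],
                                  recv_rank_ids=recv_rank_ids,
                                  recv_lens=[0, 1, 0, 0],
                                  data_format="NCHW")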
- Inputs:
input_x (Tensor) - The Tensor before being exchanged. It has a shape of \((N, C, H, W)\).
- Outputs:
The Tensor after being exchanged. If input shape is \((N, C, H, W)\), output shape is \((N, C, H+recv\_top+recv\_bottom, W+recv\_left+recv\_right)\).
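The output shape rule above can be checked with a short sketch. The helper function below is hypothetical and only restates the shape formula; the concrete numbers match the example at the end of this page.

# Hypothetical helper showing the output-shape rule stated above.
def neighbor_exchange_v2_out_shape(in_shape, recv_lens):
    n, c, h, w = in_shape
    recv_top, recv_bottom, recv_left, recv_right = recv_lens
    return (n, c, h + recv_top + recv_bottom, w + recv_left + recv_right)

# With the example below: input (1, 1, 2, 2) and recv_lens=[0, 1, 0, 0]
print(neighbor_exchange_v2_out_shape((1, 1, 2, 2), [0, 1, 0, 0]))  # (1, 1, 3, 2)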
- Raises
TypeError – If group is not a string or any one of send_rank_ids, recv_rank_ids, send_lens, recv_lens is not a list.
ValueError – If send_rank_ids or recv_rank_ids has a value less than -1 or has repeated values.
ValueError – If send_lens or recv_lens has a value less than 0.
ValueError – If data_format is not “NCHW”.
- Supported Platforms:
Ascend
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For the Ascend devices, users need to prepare the rank table, and set rank_id and device_id. Please see the Ascend tutorial for more details.
For the GPU devices, users need to prepare the host file and mpi. Please see the GPU tutorial for more details.
This example should be run with 2 devices.
>>> import os
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.neighborexchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
...                                                          send_lens=[0, 1, 0, 0],
...                                                          recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
...                                                          recv_lens=[0, 1, 0, 0],
...                                                          data_format="NCHW")
...
...     def construct(self, x):
...         out = self.neighborexchangev2(x)
...         return out
...
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target='Ascend')
>>> init()
>>> input_x = Tensor(np.ones([1, 1, 2, 2]), dtype=ms.float32)
>>> net = Net()
>>> output = net(input_x)
>>> print(output)
[[[[1. 1.]
   [1. 1.]
   [2. 2.]]]]