mindspore.ops.NeighborExchangeV2

class mindspore.ops.NeighborExchangeV2(send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group=GlobalComm.WORLD_COMM_GROUP)

NeighborExchangeV2 is a collective communication operation.

NeighborExchangeV2 sends data from the local rank to the ranks in send_rank_ids, while receiving data from the ranks in recv_rank_ids. Please refer to the tutorial examples below to learn how the data is exchanged between neighboring devices.

Note

This operator requires a full-mesh network topology: every device must have the same VLAN ID, and the IP addresses and masks must be in the same subnet.

Parameters
  • send_rank_ids (list(int)) – Ranks to which the data is sent. The 8 rank ids represent 8 directions; if no data is sent in a direction, set that entry to -1.

  • recv_rank_ids (list(int)) – Ranks from which the data is received. The 8 rank ids represent 8 directions; if no data is received from a direction, set that entry to -1.

  • send_lens (list(int)) – Lengths of the data sent to send_rank_ids. The 4 numbers are the lengths of [send_top, send_bottom, send_left, send_right].

  • recv_lens (list(int)) – Lengths of the data received from recv_rank_ids. The 4 numbers are the lengths of [recv_top, recv_bottom, recv_left, recv_right].

  • data_format (str) – Data format. Only NCHW is currently supported.

  • group (str, optional) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP, which means "hccl_world_group" on Ascend and "nccl_world_group" on GPU.

Inputs:
  • input_x (Tensor) - The Tensor before being exchanged. It has a shape of \((N, C, H, W)\).

Outputs:

The Tensor after being exchanged. If the input shape is \((N, C, H, W)\), the output shape is \((N, C, H+recv\_top+recv\_bottom, W+recv\_left+recv\_right)\).
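The shape rule above can be checked with a few lines of plain Python. This is only an illustrative sketch: `exchanged_shape` is a hypothetical helper, not part of MindSpore, and it assumes recv_lens is ordered as [recv_top, recv_bottom, recv_left, recv_right] as described in the parameter list.

```python
def exchanged_shape(input_shape, recv_lens):
    """Return the NCHW output shape implied by the documented formula.

    Hypothetical helper for illustration; it is not a MindSpore API.
    """
    n, c, h, w = input_shape
    recv_top, recv_bottom, recv_left, recv_right = recv_lens
    # Received rows extend H, received columns extend W; N and C are unchanged.
    return (n, c, h + recv_top + recv_bottom, w + recv_left + recv_right)


# A (1, 1, 2, 2) input that receives one row from the bottom neighbor.
print(exchanged_shape((1, 1, 2, 2), [0, 1, 0, 0]))  # (1, 1, 3, 2)
```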

Raises
  • TypeError – If group is not a string or any one of send_rank_ids, recv_rank_ids, send_lens, recv_lens is not a list.

  • ValueError – If send_rank_ids or recv_rank_ids contains a value less than -1 or contains repeated rank ids (other than -1).

  • ValueError – If send_lens or recv_lens contains a value less than 0.

  • ValueError – If data_format is not “NCHW”.
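As a rough illustration of the conditions listed above, the checks can be written as a small pre-flight function in plain Python. `check_neighbor_exchange_args` is a hypothetical helper and does not reproduce MindSpore's actual validation code or error messages. It also assumes that the "repeated values" rule applies only to real rank ids, since -1 appears several times in valid arguments.

```python
def check_neighbor_exchange_args(send_rank_ids, send_lens, recv_rank_ids, recv_lens,
                                 data_format, group="hccl_world_group"):
    """Hypothetical helper mirroring the documented TypeError/ValueError cases."""
    if not isinstance(group, str):
        raise TypeError("group must be a string")
    lists = {"send_rank_ids": send_rank_ids, "recv_rank_ids": recv_rank_ids,
             "send_lens": send_lens, "recv_lens": recv_lens}
    for name, value in lists.items():
        if not isinstance(value, list):
            raise TypeError(f"{name} must be a list")
    for name in ("send_rank_ids", "recv_rank_ids"):
        ids = lists[name]
        if any(r < -1 for r in ids):
            raise ValueError(f"{name} has a value less than -1")
        # Assumption: -1 marks an unused direction and may repeat freely.
        used = [r for r in ids if r != -1]
        if len(used) != len(set(used)):
            raise ValueError(f"{name} has repeated rank ids")
    for name in ("send_lens", "recv_lens"):
        if any(n < 0 for n in lists[name]):
            raise ValueError(f"{name} has a value less than 0")
    if data_format != "NCHW":
        raise ValueError("data_format must be 'NCHW'")


# The arguments used by Net0 in the Examples section pass every check.
check_neighbor_exchange_args([-1, -1, -1, -1, 1, -1, -1, -1], [0, 1, 0, 0],
                             [-1, -1, -1, -1, 1, -1, -1, -1], [0, 1, 0, 0], "NCHW")
```

Running such a check before constructing the operator can surface argument mistakes early, without needing a configured communication environment.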

Supported Platforms:

Ascend

Examples

Note

Before running the following examples, you need to configure the communication environment variables.

For Ascend devices, users need to prepare the rank table and set rank_id and device_id. Please see the rank table Startup for more details.

For GPU devices, users need to prepare the host file and MPI. Please see the mpirun Startup.

For CPU devices, users need to write a dynamic cluster startup script. Please see the Dynamic Cluster Startup.

This example should be run with 2 devices.

>>> import os
>>> import mindspore as ms
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>>
>>> class Net0(nn.Cell):
...     def __init__(self):
...         super(Net0, self).__init__()
...         self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
...                                                           send_lens=[0, 1, 0, 0],
...                                                           recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
...                                                           recv_lens=[0, 1, 0, 0], data_format="NCHW")
...
...     def construct(self, x):
...         out = self.neighbor_exchangev2(x)
...         return out
>>>
>>> class Net1(nn.Cell):
...     def __init__(self):
...         super(Net1, self).__init__()
...         self.neighbor_exchangev2 = ops.NeighborExchangeV2(send_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
...                                                           send_lens=[1, 0, 0, 0],
...                                                           recv_rank_ids=[0, -1, -1, -1, -1, -1, -1, -1],
...                                                           recv_lens=[1, 0, 0, 0], data_format="NCHW")
...
...     def construct(self, x):
...         out = self.neighbor_exchangev2(x)
...         return out
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> init()
>>> rank_id = int(os.getenv("RANK_ID"))
>>> if rank_id % 2 == 0:
...     input_x = ms.Tensor(np.ones([1, 1, 2, 2]), dtype=ms.float32)
...     net = Net0()
...     output = net(input_x)
...     print(output)
... else:
...     input_x = ms.Tensor(np.ones([1, 1, 2, 2]) * 2, dtype=ms.float32)
...     net = Net1()
...     output = net(input_x)
...     print(output)
[[[[1. 1.], [1. 1.], [2. 2.]]]]
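To make the data movement in the example easier to follow, the sketch below reproduces it with plain Python lists and no communication at all. It is an assumption-laden illustration: `exchange_rows` is a hypothetical helper, only the H x W plane is modeled (N = C = 1), and it assumes that a rank sending send_bottom rows sends its last rows while a rank sending send_top rows sends its first rows. The first result matches the output printed above for the even rank; the second result is what this model predicts for the odd rank and is not taken from the original documentation.

```python
def exchange_rows(local, recv_top_rows, recv_bottom_rows):
    """Extend a 2-D H x W block with rows received from the top and bottom neighbors.

    Hypothetical helper for illustration; it is not a MindSpore API.
    """
    return recv_top_rows + local + recv_bottom_rows


rank0 = [[1.0, 1.0], [1.0, 1.0]]  # input on the even rank (Net0)
rank1 = [[2.0, 2.0], [2.0, 2.0]]  # input on the odd rank (Net1)

# Net0 uses send_lens=[0, 1, 0, 0]: one row is sent toward the bottom neighbor.
sent_by_rank0 = rank0[-1:]
# Net1 uses send_lens=[1, 0, 0, 0]: one row is sent toward the top neighbor.
sent_by_rank1 = rank1[:1]

# Net0 uses recv_lens=[0, 1, 0, 0]; Net1 uses recv_lens=[1, 0, 0, 0].
out0 = exchange_rows(rank0, [], sent_by_rank1)
out1 = exchange_rows(rank1, sent_by_rank0, [])

print(out0)  # [[1.0, 1.0], [1.0, 1.0], [2.0, 2.0]]
print(out1)  # [[1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]
```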
Tutorial Examples: