mindspore.ops.NeighborExchange

class mindspore.ops.NeighborExchange(send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type, group=GlobalComm.WORLD_COMM_GROUP)

NeighborExchange is a collective operation.

NeighborExchange sends data from the local rank to the ranks in send_rank_ids, while receiving data from the ranks in recv_rank_ids.
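To make the routing concrete, here is a minimal single-process sketch of how data would move between ranks. It is not MindSpore code: `simulate_neighbor_exchange` is a hypothetical helper that models each rank's `send_rank_ids` and `recv_rank_ids` as plain Python lists and each rank's input tuple as a tuple of arbitrary objects. Shapes, data types and the actual communication are deliberately left out.

```python
from typing import Dict, List, Sequence, Tuple


def simulate_neighbor_exchange(
    send_rank_ids: Dict[int, List[int]],
    recv_rank_ids: Dict[int, List[int]],
    inputs: Dict[int, Sequence[object]],
) -> Dict[int, Tuple[object, ...]]:
    """Hypothetical single-process model of NeighborExchange routing (not MindSpore code).

    send_rank_ids[r] / recv_rank_ids[r] are the operator arguments rank r would pass;
    inputs[r] is rank r's input tuple, with one item per entry of send_rank_ids[r].
    Returns, for every rank, the tuple it receives, ordered like its recv_rank_ids.
    """
    # Step 1: every rank posts its i-th input to rank send_rank_ids[rank][i].
    mailbox = {}
    for src, dsts in send_rank_ids.items():
        if len(dsts) != len(inputs[src]):
            raise ValueError(f"rank {src}: {len(dsts)} destinations but {len(inputs[src])} inputs")
        for dst, data in zip(dsts, inputs[src]):
            mailbox[(src, dst)] = data

    # Step 2: every rank collects one item from each rank in recv_rank_ids, in that order.
    outputs = {}
    for dst, srcs in recv_rank_ids.items():
        received = []
        for src in srcs:
            if (src, dst) not in mailbox:
                raise ValueError(f"rank {dst} expects data from rank {src}, which sends nothing to it")
            received.append(mailbox[(src, dst)])
        outputs[dst] = tuple(received)
    return outputs


# Usage: a three-rank ring where rank r sends to r + 1 and receives from r - 1 (mod 3).
n = 3
sends = {r: [(r + 1) % n] for r in range(n)}
recvs = {r: [(r - 1) % n] for r in range(n)}
result = simulate_neighbor_exchange(sends, recvs, {r: (f"data from rank {r}",) for r in range(n)})
print(result[0])  # ('data from rank 2',)
```

The model makes one point explicit: a rank only receives something from a peer if that peer lists it in its own send_rank_ids, so the arguments on all participating ranks have to be chosen together.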

Note

The user needs to preset the communication environment variables before running the following example. Please check the details on the official MindSpore website.

This operator requires a full-mesh network topology: every device must have the same VLAN ID, and the IP addresses and masks must be in the same subnet. Please check the details on the official MindSpore website.
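The VLAN and subnet conditions can be checked mechanically once the per-device network settings are known. The sketch below assumes hypothetical device records with `ip`, `mask` and `vlan_id` keys; how those values are read from real devices is not covered here, and the check does not verify full-mesh connectivity itself. It uses only Python's standard `ipaddress` module.

```python
import ipaddress
from typing import Dict, Iterable


def check_topology_prereqs(devices: Iterable[Dict[str, object]]) -> bool:
    """Hypothetical pre-flight check: same VLAN ID and same IPv4 subnet on every device.

    Each record is assumed to look like {"ip": "192.168.1.10", "mask": "255.255.255.0", "vlan_id": 10}.
    Returns False for an empty device list. Full-mesh reachability is NOT checked.
    """
    devices = list(devices)
    vlan_ids = {d["vlan_id"] for d in devices}
    # IPv4Interface accepts "address/netmask" and exposes the network it belongs to.
    networks = {ipaddress.IPv4Interface(f"{d['ip']}/{d['mask']}").network for d in devices}
    return len(vlan_ids) == 1 and len(networks) == 1


ok = [
    {"ip": "192.168.1.10", "mask": "255.255.255.0", "vlan_id": 10},
    {"ip": "192.168.1.11", "mask": "255.255.255.0", "vlan_id": 10},
]
print(check_topology_prereqs(ok))  # True
```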

Parameters
  • send_rank_ids (list(int)) – Ranks to which the data is sent.

  • recv_rank_ids (list(int)) – Ranks from which the data is received.

  • recv_shapes (tuple(list(int))) – Shapes of the data received from each rank in recv_rank_ids.

  • send_shapes (tuple(list(int))) – Shapes of the data sent to each rank in send_rank_ids.

  • recv_type (type) – Data type of the data received from recv_rank_ids.

  • group (str) – The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
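Because each rank passes its own arguments, the send side of one rank and the receive side of its peer have to agree: what rank A sends to rank B with a given shape is what B must declare in recv_shapes for A. A minimal static check of that pairing might look like the sketch below. `check_exchange_configs` and the dictionary layout are hypothetical; recv_type and group are not checked.

```python
from typing import Dict, List


def check_exchange_configs(configs: Dict[int, dict]) -> List[str]:
    """Hypothetical static check that per-rank NeighborExchange arguments agree with each other.

    configs[r] holds rank r's send_rank_ids, recv_rank_ids, send_shapes and recv_shapes.
    Returns a list of human-readable problems; an empty list means no mismatch was found.
    """
    problems = []
    for rank, cfg in configs.items():
        # Each destination/source rank needs exactly one shape entry.
        if len(cfg["send_rank_ids"]) != len(cfg["send_shapes"]):
            problems.append(f"rank {rank}: send_rank_ids and send_shapes differ in length")
        if len(cfg["recv_rank_ids"]) != len(cfg["recv_shapes"]):
            problems.append(f"rank {rank}: recv_rank_ids and recv_shapes differ in length")

    for rank, cfg in configs.items():
        # What this rank sends to dst must match what dst expects to receive from it.
        for dst, shape in zip(cfg["send_rank_ids"], cfg["send_shapes"]):
            peer = configs.get(dst)
            if peer is None or rank not in peer["recv_rank_ids"]:
                problems.append(f"rank {rank} sends to rank {dst}, which does not receive from it")
                continue
            idx = peer["recv_rank_ids"].index(rank)
            if idx < len(peer["recv_shapes"]) and list(peer["recv_shapes"][idx]) != list(shape):
                problems.append(
                    f"rank {rank} sends shape {list(shape)} to rank {dst}, "
                    f"which expects {list(peer['recv_shapes'][idx])}"
                )
        # Every rank this rank listens to must actually send to it.
        for src in cfg["recv_rank_ids"]:
            sender = configs.get(src)
            if sender is None or rank not in sender["send_rank_ids"]:
                problems.append(f"rank {rank} receives from rank {src}, which does not send to it")
    return problems


# Rank 0 uses the arguments from the example below; rank 1's arguments are an assumed counterpart.
configs = {
    0: {"send_rank_ids": [1], "send_shapes": ([3, 3],), "recv_rank_ids": [1], "recv_shapes": ([2, 2],)},
    1: {"send_rank_ids": [0], "send_shapes": ([2, 2],), "recv_rank_ids": [0], "recv_shapes": ([3, 3],)},
}
print(check_exchange_configs(configs))  # []
```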

Supported Platforms:

Ascend

Example

>>> # This example should be run with 2 devices. See the Distributed Training tutorial on mindspore.cn.
>>> # The code shows rank 0's side; rank 1 is assumed to run a matching program that sends it a [2, 2] float32 tensor.
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.neighborexchange = ops.NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1],
...                                                      recv_shapes=([2, 2],), send_shapes=([3, 3],),
...                                                      recv_type=ms.float32)
...
...
...     def construct(self, x):
...         # NeighborExchange takes and returns a tuple of tensors; return the single received tensor.
...         out = self.neighborexchange((x,))
...         return out[0]
...
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target='Ascend')
>>> init()
>>> net = Net()
>>> input_x = Tensor(np.ones([3, 3]), dtype=ms.float32)
>>> output = net(input_x)
>>> print(output)
[[2. 2.]
 [2. 2.]]