mindspore.ops.NeighborExchange
- class mindspore.ops.NeighborExchange(send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type, group=GlobalComm.WORLD_COMM_GROUP)
NeighborExchange is a collective communication operation.
NeighborExchange sends data from the local rank to the ranks in send_rank_ids, while receiving data from the ranks in recv_rank_ids.
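To make the data flow concrete, the following is a single-process model of the exchange. It assumes that on each rank the i-th input is delivered to send_rank_ids[i] and that the j-th output holds the data received from recv_rank_ids[j]. The function `neighbor_exchange` and its arguments are hypothetical and illustrative only (not part of MindSpore); nested Python lists stand in for tensors, and all ranks run in one process.

```python
# Hypothetical single-process model of NeighborExchange (not MindSpore code).
# Assumption: on each rank, inputs[i] is delivered to send_rank_ids[i], and
# output j holds whatever rank recv_rank_ids[j] sent to this rank.


def neighbor_exchange(configs, inputs):
    """Simulate one exchange step across all ranks.

    configs: {rank: (send_rank_ids, recv_rank_ids)}
    inputs:  {rank: tuple with one payload per entry of send_rank_ids}
    Returns {rank: tuple of received payloads, ordered like recv_rank_ids}.
    """
    mailbox = {}  # (src, dst) -> payload that src sent to dst
    for rank, (send_ids, _) in configs.items():
        payloads = inputs[rank]
        if len(payloads) != len(send_ids):
            raise ValueError(
                f"rank {rank}: {len(payloads)} inputs for {len(send_ids)} destinations")
        for dst, payload in zip(send_ids, payloads):
            mailbox[(rank, dst)] = payload

    outputs = {}
    for rank, (_, recv_ids) in configs.items():
        # A rank waiting on a peer that never sends to it would block forever.
        missing = [src for src in recv_ids if (src, rank) not in mailbox]
        if missing:
            raise ValueError(f"rank {rank} waits on ranks {missing}, which send nothing to it")
        outputs[rank] = tuple(mailbox[(src, rank)] for src in recv_ids)
    return outputs


# Two ranks with mirrored settings: rank 0 sends a 3x3 block of ones,
# rank 1 sends a 2x2 block of twos.
configs = {0: ([1], [1]), 1: ([0], [0])}
inputs = {
    0: ([[1.0] * 3 for _ in range(3)],),
    1: ([[2.0] * 2 for _ in range(2)],),
}
result = neighbor_exchange(configs, inputs)
print(result[0])  # ([[2.0, 2.0], [2.0, 2.0]],)
```

Both the input and the output are tuples, one entry per peer, which mirrors the tuple[Tensor] types listed under Inputs and Outputs.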
Note
The user needs to preset the communication environment variables before running the following example. Please check the details on the official MindSpore website.
This operator requires a full-mesh network topology: every device must have the same VLAN ID, and the IP addresses and masks must be in the same subnet. Please check the details on the official MindSpore website.
- Parameters
send_rank_ids (list(int)) – Ranks to which the data is sent.
recv_rank_ids (list(int)) – Ranks from which the data is received.
recv_shapes (tuple(list(int))) – Shapes of the data received from recv_rank_ids.
send_shapes (tuple(list(int))) – Shapes of the data sent to send_rank_ids.
recv_type (type) – Data type of the data received from recv_rank_ids.
group (str) – The communication group to work on. Default: “GlobalComm.WORLD_COMM_GROUP”.
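Because every rank constructs its own operator, the shapes declared on the two ends of each link have to agree. The hypothetical helper `check_shapes` sketches that cross-rank check, assuming that the shape rank A lists for peer B in send_shapes must equal the shape rank B lists for peer A in recv_shapes. It is an illustration, not a MindSpore API.

```python
# Hypothetical cross-rank shape check (not a MindSpore API).
# Assumption: the shape rank A lists for peer B in send_shapes must equal
# the shape rank B lists for peer A in recv_shapes.


def check_shapes(configs):
    """configs: {rank: dict with send_rank_ids, recv_rank_ids, send_shapes, recv_shapes}.
    Returns a list of human-readable problems; an empty list means consistent."""
    problems = []
    sent = {}  # (src, dst) -> shape that src declares it sends to dst
    for rank, c in configs.items():
        if len(c["send_shapes"]) != len(c["send_rank_ids"]):
            problems.append(f"rank {rank}: send_shapes and send_rank_ids differ in length")
        if len(c["recv_shapes"]) != len(c["recv_rank_ids"]):
            problems.append(f"rank {rank}: recv_shapes and recv_rank_ids differ in length")
        for dst, shape in zip(c["send_rank_ids"], c["send_shapes"]):
            sent[(rank, dst)] = list(shape)

    # Every expected receive must be matched by a send of the same shape.
    for rank, c in configs.items():
        for src, shape in zip(c["recv_rank_ids"], c["recv_shapes"]):
            got = sent.get((src, rank))
            if got is None:
                problems.append(f"rank {rank} expects data from rank {src}, which sends it nothing")
            elif got != list(shape):
                problems.append(f"rank {rank} expects {list(shape)} from rank {src}, which sends {got}")

    # Every send must be expected by its destination.
    for (src, dst) in sent:
        if dst in configs and src not in configs[dst]["recv_rank_ids"]:
            problems.append(f"rank {src} sends to rank {dst}, which does not receive from it")
    return problems


rank0 = dict(send_rank_ids=[1], recv_rank_ids=[1], send_shapes=([3, 3],), recv_shapes=([2, 2],))
rank1 = dict(send_rank_ids=[0], recv_rank_ids=[0], send_shapes=([2, 2],), recv_shapes=([3, 3],))
print(check_shapes({0: rank0, 1: rank1}))  # []

# Copying rank 0's settings to rank 1 unchanged breaks the pairing.
print(len(check_shapes({0: rank0, 1: rank0})))  # 3
```

A rank that lists a peer in send_rank_ids must be listed by that peer in recv_rank_ids, which is why a script run on several ranks usually needs its rank IDs and shapes mirrored per rank.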
- Inputs:
input_x (tuple[Tensor]) - The tensors to send. Their shapes must be the same as those in send_shapes.
- Outputs:
tuple[Tensor], whose shapes are the same as those in recv_shapes.
- Supported Platforms:
Ascend
Examples
>>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn
>>> # The code below is what rank 0 runs. Rank 1 must run the mirrored configuration
>>> # (send_rank_ids=[0], recv_rank_ids=[0], recv_shapes=([3, 3],), send_shapes=([2, 2],))
>>> # and, for the output shown, send a 2x2 tensor filled with 2.
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.neighborexchange = ops.NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1],
...                                                      recv_shapes=([2, 2],), send_shapes=([3, 3],),
...                                                      recv_type=ms.float32)
...
...     def construct(self, x):
...         # The operator takes a tuple of tensors and returns a tuple of tensors.
...         out = self.neighborexchange((x,))
...         return out
...
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target='Ascend')
>>> init()
>>> net = Net()
>>> input_x = Tensor(np.ones([3, 3]), dtype=ms.float32)
>>> output = net(input_x)
>>> print(output[0])
[[2. 2.]
 [2. 2.]]