mindspore.ops.Receive

查看源文件
class mindspore.ops.Receive(sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP, group_back=GlobalComm.WORLD_COMM_GROUP)[源代码]

接收来自 src_rank 线程的张量。

说明

Send 和 Receive 算子需组合使用,且有同一个 sr_tag

参数:
  • sr_tag (int) - 用于区分发送、接收消息的标签。该消息将被接收来自相同 sr_tag 的Send发送的张量。

  • src_rank (int) - 表示发送源的进程编号。只会接收来自源进程的张量。

  • shape (list[int]) - 表示发送源的张量形状。

  • dtype (Type) - 表示发送源的张量类型。所支持的类型有:int8、int16、int32、float16、float32。

  • group (str,可选) - 表示通信域。默认值: GlobalComm.WORLD_COMM_GROUP

  • group_back (str,可选) - 表示计算反向传播时的通信域。默认值: GlobalComm.WORLD_COMM_GROUP

输出:
  • Tensor - Tensor的shape与Send算子所发送Tensor的shape相同。

异常:
  • TypeError - src_rank不是int或group不是str。

  • RuntimeError - 如果目标设备无效,或者后端无效,或者分布式初始化失败。

  • ValueError - 如果该线程的rank id 大于通信组的rank size。

支持平台:

Ascend GPU

样例:

说明

运行以下样例之前,需要配置好通信环境变量。

针对Ascend/GPU/CPU设备,推荐使用msrun启动方式,无第三方以及配置文件依赖。详见 msrun启动

该样例需要在2卡环境下运行。

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> from mindspore import ops
>>>
>>> init()
>>> class ReceiveNet(nn.Cell):
>>>     def __init__(self):
>>>         super(ReceiveNet, self).__init__()
>>>         self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=ms.float32,
>>>                               group="hccl_world_group")
>>>
>>>     def construct(self):
>>>         out = self.recv()
>>>         return out
>>>
>>> net = Net()
>>> output = net()
教程样例: