mindspore.communication.comm_func.all_to_all_single_with_output_shape

查看源文件
mindspore.communication.comm_func.all_to_all_single_with_output_shape(output_shape, tensor, output_split_sizes=None, input_split_sizes=None, group=None)[源代码]

根据用户输入的切分大小,把输入tensor切分后,发送到其他的设备上,并从其他设备接收切分块,然后合并到一个输出tensor中。

说明

各个rank之间发送和接收的切分块大小需要互相匹配。 仅支持PyNative模式,目前不支持Graph模式。

参数:
  • output_shape (Union(Tensor, Tuple(int))) - 表示接收的张量的形状。

  • tensor (Tensor) - 要发送到远端设备的张量。

  • output_split_sizes (Union(Tuple(int), List(int))) - 接收张量在0维的切分大小列表。默认值:None,表示均匀切分。

  • input_split_sizes (Union(Tuple(int), List(int))) - 发送张量在0维的切分大小列表。默认值:None,表示均匀切分。

  • group (str, 可选) - 通信执行所在的通信组。默认值:None。为None时,在Ascend上将使用为 hccl_world_group,在GPU上使用 nccl_world_group

返回:

从远端设备接收分块并合并的张量。如果从其他设备接收的张量为空,它将返回一个没有实际意义的值为0的张量。

异常:
  • TypeError - tensor 不是张量类型。

  • TypeError - output_shape 不是元组或者张量类型。

支持平台:

Ascend

样例:

说明

运行以下样例之前,需要配置好通信环境变量。

针对Ascend/GPU/CPU设备,推荐使用msrun启动方式,无第三方以及配置文件依赖。详见 msrun启动

该样例需要在2卡环境下运行。

>>> import numpy as np
>>> import mindspore
>>> from mindspore.communication import init, get_rank, get_group_size
>>> from mindspore.communication.comm_func import all_to_all_single_with_output_shape
>>> from mindspore import Tensor
>>> from mindspore.ops import zeros
>>>
>>> init()
>>> this_rank = get_rank()
>>> if this_rank == 0:
>>>     output_shape = (3, 3)
>>>     tensor = Tensor([[0, 1, 2.], [3, 4, 5], [6, 7, 8]])
>>>     result = all_to_all_single_with_output_shape(output_shape, tensor, [2, 1], [2, 1])
>>> if this_rank == 1:
>>>     output_shape = (2, 3)
>>>     tensor = Tensor([[9, 10., 11], [12, 13, 14]])
>>>     result = all_to_all_single_with_output_shape(output_shape, tensor)
>>> print(result)
rank 0:
[[ 0.  1.  2.]
[ 3.  4.  5.]
[ 9. 10. 11.]]
rank 1:
[[ 6.  7.  8.]
[12. 13. 14.]]