mindspore.ops.AlltoAllV
- class mindspore.ops.AlltoAllV(group=GlobalComm.WORLD_COMM_GROUP)[source]
Compared with AlltoAll, the AlltoAllV operator supports uneven splitting and gathering of data.
Note
Only one-dimensional input is supported. Flatten the input data into a 1-D tensor before calling this interface.
- Parameters:
group (str) - The communication group of AlltoAll. Default:
GlobalComm.WORLD_COMM_GROUP
, which is "hccl_world_group" on the Ascend platform.
- Inputs:
input_x (Tensor) - The 1-D tensor to scatter, with shape \((x_1)\).
send_numel_list (Union[tuple[int], list[int], Tensor]) - The number of elements sent to each rank.
recv_numel_list (Union[tuple[int], list[int], Tensor]) - The number of elements gathered from each rank.
- Outputs:
Tensor, the 1-D result gathered from each rank. If the result is empty, it returns a meaningless value 0.
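The uneven split/gather semantics of send_numel_list and recv_numel_list can be sketched in plain Python. This is an illustrative single-process simulation, not the MindSpore API; the function name alltoallv_simulate and its signature are assumptions for demonstration only.

```python
# Pure-Python sketch (assumption, not MindSpore API): simulate AlltoAllV
# semantics across `world_size` ranks in one process.
def alltoallv_simulate(inputs, send_numel_lists):
    """inputs[r]: flattened 1-D data held by rank r.
    send_numel_lists[r]: how many elements rank r sends to each rank.
    Returns outputs[r]: the 1-D data each rank gathers."""
    world_size = len(inputs)
    # Split each rank's input into per-destination chunks of the given sizes.
    chunks = []
    for r in range(world_size):
        offset, per_dst = 0, []
        for n in send_numel_lists[r]:
            per_dst.append(inputs[r][offset:offset + n])
            offset += n
        chunks.append(per_dst)
    # Each rank r concatenates the chunk addressed to it from every source rank.
    return [sum((chunks[src][r] for src in range(world_size)), [])
            for r in range(world_size)]

# Mirrors the documented example: rank 0 holds [0, 1, 2], rank 1 holds [3, 4, 5].
outputs = alltoallv_simulate(
    inputs=[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
    send_numel_lists=[[1, 2], [2, 1]],
)
print(outputs)  # [[0.0, 3.0, 4.0], [1.0, 2.0, 5.0]]
```

Note how the recv_numel_list of each rank must match the corresponding send counts of its peers: rank 0 receives 1 element from itself and 2 from rank 1, so its recv_numel_list is [1, 2].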
- Supported Platforms:
Ascend
Examples:
>>> from mindspore import ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init, get_rank
>>> from mindspore import Tensor
>>>
>>> init()
>>> rank = get_rank()
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.all_to_all = ops.AlltoAllV()
...
...     def construct(self, x, send_numel_list, recv_numel_list):
...         return self.all_to_all(x, send_numel_list, recv_numel_list)
...
>>> send_numel_list = []
>>> recv_numel_list = []
>>> if rank == 0:
...     send_tensor = Tensor([0, 1, 2.])
...     send_numel_list = [1, 2]
...     recv_numel_list = [1, 2]
... elif rank == 1:
...     send_tensor = Tensor([3, 4, 5.])
...     send_numel_list = [2, 1]
...     recv_numel_list = [2, 1]
...
>>> net = Net()
>>> output = net(send_tensor, send_numel_list, recv_numel_list)
>>> print(output)
rank 0: [0. 3. 4.]
rank 1: [1. 2. 5.]
- Tutorial examples: