mindspore.ops.moe_token_unpermute
- mindspore.ops.moe_token_unpermute(permuted_tokens, sorted_indices, probs=None, padded_mode=False, restore_shape=None)[源代码]
根据排序的索引对已排列的标记进行反排列,并可选择将标记与其对应的概率合并。
警告
仅支持 Atlas A2 训练系列产品。
当前版本下,输入 permuted_tokens 和 probs 仅支持bfloat16类型。
这是一个实验性API,后续可能修改或删除。
- 参数:
permuted_tokens (Tensor) - 要进行反排列的已排列标记的Tensor。 shape为 \([num\_tokens * topk, hidden\_size]\),其中 num_tokens、 topk 和 hidden_size 都是正整数。
sorted_indices (Tensor) - 用于反排列标记的排列索引Tensor。shape为 \([num\_tokens * topk,]\) , 其中 num_tokens 和 topk 都是正整数。
probs (Tensor,可选) - 与已排列标记对应的概率Tensor。如果提供,反排列的标记将与其相应的概率合并。 shape为 \([num\_tokens, topk]\) ,其中 num_tokens 和 topk 都是正整数。默认值:
None
。padded_mode (bool, 可选) - 如果为
True
,表示索引被填充,以表示每个专家选择的标记。默认值:False
。restore_shape (Union[tuple[int], list[int]],可选) - 排列之前的输入形状,仅在填充模式下使用。默认值:
None
。
- 返回:
Tensor,类型与 permuted_tokens 一致。如果 padded_mode 为
False
, 则shape为[num_tokens, hidden_size]。 如果 padded_mode 为True
,则shape由 restore_shape 指定。- 异常:
TypeError - input 不是Tensor。
ValueError - 仅支持 padded_mode 为
False
。
- 支持平台:
Ascend
样例:
>>> import mindspore >>> from mindspore import Tensor, mint >>> permuted_token = Tensor([ ... [1, 1, 1], ... [0, 0, 0], ... [0, 0, 0], ... [3, 3, 3], ... [2, 2, 2], ... [1, 1, 1], ... [2, 2, 2], ... [3, 3, 3]], dtype=mindspore.bfloat16) >>> sorted_indices = Tensor([0, 6, 7, 5, 3, 1, 2, 4], dtype=mindspore.int32) >>> out = mint.npu_moe_untoken_unpermute(permuted_token, sorted_indices) >>> out.shape (8, 3)