mindspore.ops.moe_token_unpermute

查看源文件
mindspore.ops.moe_token_unpermute(permuted_tokens, sorted_indices, probs=None, padded_mode=False, restore_shape=None)[源代码]

根据排序的索引对已排列的标记进行反排列,并可选择将标记与其对应的概率合并。

警告

  • 仅支持 Atlas A2 训练系列产品。

  • 当前版本下,输入 permuted_tokensprobs 仅支持bfloat16类型。

  • 这是一个实验性API,后续可能修改或删除。

参数:
  • permuted_tokens (Tensor) - 要进行反排列的已排列标记的Tensor。 shape为 \([num\_tokens * topk, hidden\_size]\),其中 num_tokenstopkhidden_size 都是正整数。

  • sorted_indices (Tensor) - 用于反排列标记的排列索引Tensor。shape为 \([num\_tokens * topk,]\) , 其中 num_tokenstopk 都是正整数。

  • probs (Tensor,可选) - 与已排列标记对应的概率Tensor。如果提供,反排列的标记将与其相应的概率合并。 shape为 \([num\_tokens, topk]\) ,其中 num_tokenstopk 都是正整数。默认值: None

  • padded_mode (bool, 可选) - 如果为 True,表示索引被填充,以表示每个专家选择的标记。默认值: False

  • restore_shape (Union[tuple[int], list[int]],可选) - 排列之前的输入形状,仅在填充模式下使用。默认值: None

返回:

Tensor,类型与 permuted_tokens 一致。如果 padded_modeFalse, 则shape为[num_tokenshidden_size]。 如果 padded_modeTrue,则shape由 restore_shape 指定。

异常:
  • TypeError - input 不是Tensor。

  • ValueError - 仅支持 padded_modeFalse

支持平台:

Ascend

样例:

>>> import mindspore
>>> from mindspore import Tensor, mint
>>> permuted_token = Tensor([
...                          [1, 1, 1],
...                          [0, 0, 0],
...                          [0, 0, 0],
...                          [3, 3, 3],
...                          [2, 2, 2],
...                          [1, 1, 1],
...                          [2, 2, 2],
...                          [3, 3, 3]], dtype=mindspore.bfloat16)
>>> sorted_indices = Tensor([0, 6, 7, 5, 3, 1, 2, 4], dtype=mindspore.int32)
>>> out = mint.npu_moe_untoken_unpermute(permuted_token, sorted_indices)
>>> out.shape
(8, 3)