mindspore.ops.moe_token_unpermute

mindspore.ops.moe_token_unpermute(permuted_tokens, sorted_indices, probs=None, padded_mode=False, restore_shape=None)

Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
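The unpermute step can be sketched in NumPy for the default path (padded_mode=False). This is a minimal reference under stated assumptions, not the MindSpore kernel. It assumes three things: row i of permuted_tokens is scattered back to flat slot sorted_indices[i]; the flat layout stores the topk copies of each token in consecutive rows, which matches the [num_tokens, topk] shape of probs; and "merging" means a probability-weighted sum over those copies. The function name unpermute_reference is hypothetical. Because NumPy has no bfloat16, float32 stands in for it.

```python
import numpy as np


def unpermute_reference(permuted_tokens, sorted_indices, probs=None):
    """Hypothetical NumPy sketch of moe_token_unpermute with padded_mode=False.

    Assumed semantics (not taken from MindSpore): row i of permuted_tokens is
    scattered back to flat slot sorted_indices[i]; the flat rows hold the topk
    copies of each token consecutively; probs weights those copies before summing.
    """
    tokens = np.asarray(permuted_tokens, dtype=np.float32)  # float32 stands in for bfloat16
    indices = np.asarray(sorted_indices, dtype=np.int64)
    if tokens.ndim != 2 or indices.shape != (tokens.shape[0],):
        raise ValueError("expected [num_tokens * topk, hidden_size] and [num_tokens * topk]")
    # Without probs, each permuted row is treated as its own token (topk = 1).
    weights = None if probs is None else np.asarray(probs, dtype=np.float32)
    topk = 1 if weights is None else weights.shape[1]
    hidden_size = tokens.shape[1]

    # Scatter: undo the permutation by writing row i to slot sorted_indices[i].
    flat = np.zeros_like(tokens)
    flat[indices] = tokens

    # Group the topk copies of each token: [num_tokens, topk, hidden_size].
    grouped = flat.reshape(-1, topk, hidden_size)
    if weights is not None:
        # Weight each copy by its routing probability (broadcast over hidden_size).
        grouped = grouped * weights[:, :, None]
    # Merge the copies into one row per token: [num_tokens, hidden_size].
    return grouped.sum(axis=1)


# Inputs from the Examples section (no probs, so topk = 1).
permuted = [[1, 1, 1], [0, 0, 0], [0, 0, 0], [3, 3, 3],
            [2, 2, 2], [1, 1, 1], [2, 2, 2], [3, 3, 3]]
indices = [0, 6, 7, 5, 3, 1, 2, 4]
out = unpermute_reference(permuted, indices)
print(out.shape)           # (8, 3)
print(out[:, 0].tolist())  # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0, 0.0]

# Two tokens, each routed to topk = 2 slots, merged with probabilities.
merged = unpermute_reference([[10], [20], [30], [40]], [2, 0, 3, 1],
                             probs=[[0.25, 0.75], [0.5, 0.5]])
print(merged.tolist())     # [[35.0], [20.0]]
```

The scatter-then-reshape structure keeps the sketch close to the documented shapes: the scatter restores the [num_tokens * topk, hidden_size] layout, and the reshape plus sum reduces it to [num_tokens, hidden_size].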

Warning

  • It is only supported on Atlas A2 Training Series Products.

  • In the current version, the inputs permuted_tokens and probs support only the bfloat16 data type.

  • This is an experimental API that is subject to change or deletion.

Parameters
  • permuted_tokens (Tensor) – The tensor of permuted tokens to be unpermuted. The shape is \([num\_tokens * topk, hidden\_size]\), where num_tokens, topk and hidden_size are positive integers.

  • sorted_indices (Tensor) – The tensor of sorted indices used to unpermute the tokens. The shape is \([num\_tokens * topk]\), where num_tokens and topk are positive integers.

  • probs (Tensor, optional) – The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens are merged using their respective probabilities. The shape is \([num\_tokens, topk]\), where num_tokens and topk are positive integers. Default: None.

  • padded_mode (bool, optional) – If True, the indices are treated as padded to denote the tokens selected for each expert. Only False is currently supported. Default: False.

  • restore_shape (Union[tuple[int], list[int]], optional) – The shape of the input before permutation. Used only when padded_mode is True. Default: None.
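The shape relations in the parameter list can be checked before calling the operator. The sketch below uses a hypothetical helper, expected_unpermute_shape, which is not part of MindSpore. It returns the documented output shape for padded_mode=False. Treating a call without probs as topk = 1 is an assumption, consistent with the (8, 3) result of the example on this page.

```python
from typing import Optional, Sequence, Tuple


def expected_unpermute_shape(
    permuted_shape: Sequence[int],
    indices_shape: Sequence[int],
    probs_shape: Optional[Sequence[int]] = None,
    padded_mode: bool = False,
) -> Tuple[int, int]:
    """Hypothetical check of the documented shapes; returns the output shape."""
    if padded_mode:
        # Mirrors the Raises section: only padded_mode=False is supported.
        raise ValueError("padded_mode=True is not supported")
    if len(permuted_shape) != 2 or len(indices_shape) != 1:
        raise ValueError("permuted_tokens must be 2-D and sorted_indices 1-D")
    rows, hidden_size = permuted_shape
    if indices_shape[0] != rows:
        raise ValueError("sorted_indices needs one entry per permuted row")
    if probs_shape is None:
        # Assumption: without probs every row is its own token (topk = 1).
        return (rows, hidden_size)
    if len(probs_shape) != 2:
        raise ValueError("probs must be 2-D: [num_tokens, topk]")
    num_tokens, topk = probs_shape
    if num_tokens * topk != rows:
        raise ValueError("num_tokens * topk must equal the permuted row count")
    return (num_tokens, hidden_size)


print(expected_unpermute_shape((8, 3), (8,)))          # (8, 3)
print(expected_unpermute_shape((8, 3), (8,), (4, 2)))  # (4, 3)
```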

Returns

Tensor, with the same dtype as permuted_tokens. If padded_mode is False, the shape will be [num_tokens, hidden_size]. If padded_mode is True, the shape will be specified by restore_shape.

Raises
  • TypeError – If any of the tensor inputs is not a Tensor.

  • ValueError – If padded_mode is True. Only padded_mode=False is currently supported.

Supported Platforms:

Ascend

Examples

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> permuted_token = Tensor([
...                          [1, 1, 1],
...                          [0, 0, 0],
...                          [0, 0, 0],
...                          [3, 3, 3],
...                          [2, 2, 2],
...                          [1, 1, 1],
...                          [2, 2, 2],
...                          [3, 3, 3]], dtype=mindspore.bfloat16)
>>> sorted_indices = Tensor([0, 6, 7, 5, 3, 1, 2, 4], dtype=mindspore.int32)
>>> out = ops.moe_token_unpermute(permuted_token, sorted_indices)
>>> out.shape
(8, 3)