mindspore.ops.moe_token_unpermute
- mindspore.ops.moe_token_unpermute(permuted_tokens, sorted_indices, probs=None, padded_mode=False, restore_shape=None)
Unpermute a tensor of permuted tokens based on sorted indices, and optionally merge the tokens with their corresponding probabilities.
Warning
It is only supported on Atlas A2 Training Series Products.
The inputs permuted_tokens and probs only support the bfloat16 data type in the current version.
This is an experimental API that is subject to change or deletion.
- Parameters
permuted_tokens (Tensor) – The tensor of permuted tokens to be unpermuted. The shape is \([num\_tokens * topk, hidden\_size]\), where num_tokens, topk and hidden_size are positive integers.
sorted_indices (Tensor) – The tensor of sorted indices used to unpermute the tokens. The shape is \([num\_tokens * topk,]\), where num_tokens and topk are positive integers.
probs (Tensor, optional) – The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. The shape is \([num\_tokens, topk]\), where num_tokens and topk are positive integers. Default: None.
padded_mode (bool, optional) – If True, the indices are padded to denote the tokens selected per expert. Default: False.
restore_shape (Union[tuple[int], list[int]], optional) – The input shape before permutation, only used in padded mode. Default: None.
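The following pure-Python sketch models what the non-padded path is assumed to compute, following the Megatron-style unpermute that this API appears to mirror: row i of permuted_tokens is scattered back to position sorted_indices[i], then each group of topk consecutive rows is combined into one token, weighted by probs when it is given (topk is treated as 1 when probs is None). The name unpermute_reference and these scatter-then-weighted-sum semantics are assumptions for illustration, not MindSpore code; compare against the operator on Ascend before relying on them.

```python
from typing import List, Optional

Row = List[float]


def unpermute_reference(permuted_tokens: List[Row],
                        sorted_indices: List[int],
                        probs: Optional[List[Row]] = None) -> List[Row]:
    """Hypothetical reference model of non-padded unpermutation (assumed semantics)."""
    n = len(permuted_tokens)
    if len(sorted_indices) != n:
        raise ValueError("sorted_indices needs one entry per permuted token")
    hidden = len(permuted_tokens[0])
    # Assumption: topk comes from probs, and is 1 when probs is omitted.
    topk = len(probs[0]) if probs is not None else 1

    # Scatter: row i of permuted_tokens returns to position sorted_indices[i].
    restored: List[Row] = [[0.0] * hidden for _ in range(n)]
    for i, dst in enumerate(sorted_indices):
        restored[dst] = list(permuted_tokens[i])

    # Rows t*topk .. t*topk+topk-1 belong to token t; sum them,
    # weighting by probs[t][k] when probs is given.
    out: List[Row] = []
    for t in range(n // topk):
        acc = [0.0] * hidden
        for k in range(topk):
            weight = probs[t][k] if probs is not None else 1.0
            row = restored[t * topk + k]
            for h in range(hidden):
                acc[h] += weight * row[h]
        out.append(acc)
    return out


# Same data as the MindSpore example on this page, as plain Python lists.
permuted = [[1] * 3, [0] * 3, [0] * 3, [3] * 3,
            [2] * 3, [1] * 3, [2] * 3, [3] * 3]
indices = [0, 6, 7, 5, 3, 1, 2, 4]

out = unpermute_reference(permuted, indices)
print(len(out), len(out[0]))       # 8 3
print([row[0] for row in out])     # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0, 0.0]

# With probs of shape [num_tokens=4, topk=2], pairs of rows are merged.
probs = [[0.5, 0.5], [0.25, 0.75], [1.0, 0.0], [0.5, 0.5]]
merged = unpermute_reference(permuted, indices, probs)
print([row[0] for row in merged])  # [1.0, 2.0, 3.0, 0.0]
```

Under this model, omitting probs keeps one output row per permuted row, which is consistent with the (8, 3) shape in the example.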
- Returns
Tensor, with the same dtype as permuted_tokens. If padded_mode is False, the shape is [num_tokens, hidden_size]. If padded_mode is True, the shape is specified by restore_shape.
- Raises
TypeError – If input is not a Tensor.
ValueError – If padded_mode is not False; only padded_mode=False is currently supported.
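The shape rules in Returns and Raises can be summarized as a small checker. This is a hypothetical helper (expected_unpermute_shape is not part of MindSpore); it assumes topk is taken from the second dimension of probs, or is 1 when probs is omitted, and its error messages are illustrative rather than the operator's own.

```python
from typing import Optional, Sequence, Tuple


def expected_unpermute_shape(permuted_shape: Sequence[int],
                             sorted_indices_shape: Sequence[int],
                             probs_shape: Optional[Sequence[int]] = None,
                             padded_mode: bool = False) -> Tuple[int, int]:
    """Hypothetical checker for the documented shape contract (not MindSpore code)."""
    # Per Raises: only padded_mode=False is currently supported.
    if padded_mode:
        raise ValueError("padded_mode=True is not supported")
    rows, hidden = permuted_shape  # [num_tokens * topk, hidden_size]
    if tuple(sorted_indices_shape) != (rows,):
        raise ValueError("sorted_indices must have shape [num_tokens * topk]")
    if probs_shape is None:
        # Assumption: without probs, every permuted row is its own token (topk = 1).
        return (rows, hidden)
    num_tokens, topk = probs_shape  # [num_tokens, topk]
    if num_tokens * topk != rows:
        raise ValueError("probs shape does not match permuted_tokens")
    return (num_tokens, hidden)


print(expected_unpermute_shape((8, 3), (8,)))          # (8, 3)
print(expected_unpermute_shape((8, 3), (8,), (4, 2)))  # (4, 3)
```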
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> permuted_token = Tensor([
...     [1, 1, 1],
...     [0, 0, 0],
...     [0, 0, 0],
...     [3, 3, 3],
...     [2, 2, 2],
...     [1, 1, 1],
...     [2, 2, 2],
...     [3, 3, 3]], dtype=mindspore.bfloat16)
>>> sorted_indices = Tensor([0, 6, 7, 5, 3, 1, 2, 4], dtype=mindspore.int32)
>>> out = ops.moe_token_unpermute(permuted_token, sorted_indices)
>>> out.shape
(8, 3)