mindspore.ops.moe_token_permute
- mindspore.ops.moe_token_permute(tokens, indices, num_out_tokens=None, padded_mode=False)
Permute the tokens based on the indices. Tokens with the same index are grouped together.
Warning
It is only supported on Atlas A2 Training Series Products.
The input tokens support only the bfloat16 data type in the current version.
This is an experimental API that is subject to change or deletion.
When indices is 2-D, the size of the second dim must be less than or equal to 512.
- Parameters
tokens (Tensor) – The input token tensor to be permuted. The dtype is bfloat16. The shape is \((num\_tokens, hidden\_size)\) , where num_tokens and hidden_size are positive integers.
indices (Tensor) – The tensor that specifies the indices used to permute the tokens. The dtype is int32 or int64. The shape is \((num\_tokens, topk)\) or \((num\_tokens,)\), where num_tokens and topk are positive integers. In the latter case, topk is implied to be 1.
num_out_tokens (int, optional) – The effective number of output tokens. When the capacity factor is enabled, it should equal the number of tokens that are not dropped. Default: None, meaning no tokens are dropped.
padded_mode (bool, optional) – If True, the indices are padded to denote the selected tokens per expert. Currently only False is supported. Default: False.
- Returns
tuple (Tensor), a tuple of 2 tensors containing the permuted tokens and the sorted indices.
permuted_tokens (Tensor) - The permuted tensor, with the same dtype as tokens.
sorted_indices (Tensor) - The indices tensor of dtype int32, corresponding to the permuted tensor.
- Raises
TypeError – If tokens or indices is not a Tensor.
TypeError – If dtype of tokens is not bfloat16.
TypeError – If dtype of indices is not int32 or int64.
TypeError – If specified num_out_tokens is not an integer.
TypeError – If specified padded_mode is not a bool.
ValueError – If indices is 2-D and the size of its second dim is greater than 512.
ValueError – If padded_mode is set to True.
ValueError – If tokens is not 2-D or indices is not 1-D or 2-D Tensor.
RuntimeError – If first dimensions of tokens and indices are not consistent.
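The shape-related conditions listed above can be checked before calling the operator. The following is a minimal sketch in plain Python, not MindSpore code: the helper name `check_permute_inputs` and the constant `MAX_TOPK` are hypothetical, the order of the checks is an assumption, and dtype checks are omitted. It only mirrors the documented constraints on shapes and on the two optional arguments.

```python
MAX_TOPK = 512  # documented upper bound for the second dim of a 2-D indices


def check_permute_inputs(tokens_shape, indices_shape,
                         num_out_tokens=None, padded_mode=False):
    """Hypothetical pre-flight check; returns the implied topk on success."""
    if num_out_tokens is not None and not isinstance(num_out_tokens, int):
        raise TypeError("num_out_tokens must be an integer or None")
    if not isinstance(padded_mode, bool):
        raise TypeError("padded_mode must be a bool")
    if padded_mode:
        raise ValueError("padded_mode=True is not supported")
    if len(tokens_shape) != 2 or len(indices_shape) not in (1, 2):
        raise ValueError("tokens must be 2-D and indices must be 1-D or 2-D")
    # A 1-D indices tensor implies topk == 1.
    topk = 1 if len(indices_shape) == 1 else indices_shape[1]
    if topk > MAX_TOPK:
        raise ValueError("second dim of indices must be <= 512")
    if tokens_shape[0] != indices_shape[0]:
        raise RuntimeError("tokens and indices disagree on num_tokens")
    return topk


topk_1d = check_permute_inputs((6, 3), (6,))
topk_2d = check_permute_inputs((6, 3), (6, 2))
```

Passing only shapes keeps the check independent of any device, so it can run on a host without Ascend hardware.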
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> tokens = Tensor([[1, 1, 1],
...                  [7, 7, 7],
...                  [2, 2, 2],
...                  [1, 1, 1],
...                  [2, 2, 2],
...                  [3, 3, 3]], dtype=mindspore.bfloat16)
>>> indices = Tensor([5, 0, 3, 1, 2, 4], dtype=mindspore.int32)
>>> out = ops.moe_token_permute(tokens, indices)
>>> print(out)
(Tensor(shape=[6, 3], dtype=BFloat16, value=
[[7, 7, 7],
 [1, 1, 1],
 [2, 2, 2],
 [2, 2, 2],
 [3, 3, 3],
 [1, 1, 1]]), Tensor(shape=[6], dtype=Int32, value= [5, 0, 3, 1, 2, 4]))
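The grouping behaviour in the example above can be emulated on the host with NumPy, which may help when reasoning about results without Ascend hardware. This is a sketch under assumptions, not the operator's implementation: the function name `reference_token_permute` is hypothetical, a stable sort over the flattened indices is assumed, the second return value is assumed to be the destination row of each flattened entry (which reproduces the output shown above), float32 stands in for bfloat16, and `num_out_tokens` is not modelled.

```python
import numpy as np


def reference_token_permute(tokens, indices):
    """NumPy emulation of the grouping semantics (not the Ascend kernel)."""
    tokens = np.asarray(tokens)
    indices = np.asarray(indices)
    # A 1-D indices array implies topk == 1.
    topk = 1 if indices.ndim == 1 else indices.shape[1]
    flat = indices.reshape(-1)
    # Stable sort keeps entries with the same index in their original order.
    order = np.argsort(flat, kind="stable")
    # Flattened entry j was produced by source token j // topk.
    permuted_tokens = tokens[order // topk]
    # Destination row of every flattened entry (the inverse of `order`).
    sorted_indices = np.empty_like(order)
    sorted_indices[order] = np.arange(order.size)
    return permuted_tokens, sorted_indices.astype(np.int32)


tokens = np.array([[1, 1, 1],
                   [7, 7, 7],
                   [2, 2, 2],
                   [1, 1, 1],
                   [2, 2, 2],
                   [3, 3, 3]], dtype=np.float32)
indices = np.array([5, 0, 3, 1, 2, 4], dtype=np.int32)
permuted, dest = reference_token_permute(tokens, indices)
print(permuted)
print(dest)
```

With 2-D indices of shape (num_tokens, topk), the same sketch emits each token once per selected index, so the result has num_tokens * topk rows.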