Source code for mindspore.nn.transformer.moe

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Note: Mixture of Expert (MoE) structure. This is an experimental interface that is subject to change or deletion.
"""
from __future__ import absolute_import
from __future__ import division
import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.nn.transformer.op_parallel_config import default_moeparallel_config

__all__ = [
    "MoEConfig"]


[docs]class MoEConfig: r""" The configuration of MoE (Mixture of Expert). Args: expert_num (int): The number of experts employed. Default: 1 capacity_factor (float): The factor is used to indicate how much to expand expert capacity, which is >=1.0. Default: 1.1. aux_loss_factor (float): The factor is used to indicate how much the load balance loss (produced by the router) to be added to the entire model loss, which is < 1.0. Default: 0.05. num_experts_chosen (int): The number of experts is chosen by each token and it should not be larger than expert_num. Default: 1. expert_group_size (int): The number of tokens in each data parallel group. Default: None. This parameter is effective only when in AUTO_PARALLEL mode, and NOT SHARDING_PROPAGATION. Supported Platforms: ``Ascend`` Examples: >>> from mindspore.nn.transformer import MoEConfig >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1, ... expert_group_size=64) """ def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1, expert_group_size=None): Validator.check_positive_int(expert_num, "expert_num") Validator.check_positive_float(capacity_factor, "capacity_factor") Validator.check_positive_float(aux_loss_factor, "aux_loss_factor") Validator.check_positive_int(num_experts_chosen, "num_experts_chosen") if expert_group_size is not None: Validator.check_positive_int(expert_group_size, "expert_group_size") if capacity_factor < 1.0: raise ValueError(f"'capacity_factor' must be equal to or greater than 1.0, " f"but got {capacity_factor}.") if aux_loss_factor >= 1.0: raise ValueError(f"'aux_loss_factor' must be less than 1.0, " f"but got {aux_loss_factor}.") if num_experts_chosen > expert_num: raise ValueError(f"'num_experts_chosen' must not be larger than 'expert_num', " f"but got {num_experts_chosen}.") self.expert_num = expert_num self.capacity_factor = capacity_factor self.aux_loss_factor = aux_loss_factor self.num_experts_chosen = num_experts_chosen self.expert_group_size = expert_group_size
default_moe_config = MoEConfig() def _check_moe_config(moe_config=None, parallel_config=None): """ check if MoE with right configuration. """ if not isinstance(moe_config, MoEConfig): raise TypeError(f"'moe_config' must be an instance of MoEConfig, but got {type(moe_config).__name__}.") use_moe = (moe_config.expert_num > 1) if use_moe is False: return if moe_config.expert_num % parallel_config.expert_parallel != 0: raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple " f"of 'expert_parallel' value in {type(parallel_config).__name__}, but got " f"{moe_config.expert_num} for 'expert_num' and {parallel_config.expert_parallel} for " f"'expert_parallel'.") device_num = D.get_group_size() if device_num % parallel_config.expert_parallel != 0: raise ValueError(f"device_num: {device_num} must be a multiple of expert_parallel: " f"{parallel_config.expert_parallel}.") if parallel_config.data_parallel % parallel_config.expert_parallel != 0: raise ValueError(f"data parallel: {parallel_config.data_parallel} must be a multiple of " f"expert_parallel: {parallel_config.expert_parallel} when using MoE.") if parallel_config.data_parallel * parallel_config.model_parallel > device_num: raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel} and " f"model parallel: {parallel_config.model_parallel} " f"should be less than device_num: {device_num}.") @constexpr def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim): return math.ceil(k * tokens_per_group * capacity_factor / expert_dim) class MoE(Cell): """ The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer. The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is obtained by multiplying FeedForward's output and router's combine weight. Args: hidden_size (int): The dimension of the inputs. ffn_hidden_size (int): The intermediate hidden size. dropout_rate (float): The dropout rate for the second linear's output. hidden_act (str): The activation of the internal feedforward layer. Supports 'relu', 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid', 'logsigmoid' and so on. Default: gelu. param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16. moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with default values. Please see `MoEConfig`. parallel_config(MoEParallelConfig): The parallel config for MoE, see `MoEParallelConfig`. Default `default_moeparallel_config`, an instance of `MoEParallelConfig` with default args. Inputs: - **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor. Outputs: Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`. """ def __init__(self, hidden_size, ffn_hidden_size, dropout_rate, hidden_act='gelu', param_init_type=mstype.float32, moe_config=default_moe_config, parallel_config=default_moeparallel_config): super(MoE, self).__init__() self.dp = parallel_config.data_parallel self.ep = parallel_config.expert_parallel self.hidden_size = hidden_size self.expert_dim = moe_config.expert_num self.capacity_factor = moe_config.capacity_factor self.aux_loss_factor = moe_config.aux_loss_factor self.num_experts_chosen = moe_config.num_experts_chosen self.dp_group = parallel_config.data_parallel from mindspore.nn.transformer import FeedForward self.reshape = P.Reshape() self.shape = P.Shape() self.transpose_2dim = P.Transpose().shard(((self.dp, 1),)) self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),)) self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),)) self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),)) self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1))) self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1))) self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None, training=True, parallel_config=parallel_config) self.cast = P.Cast() if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation(): self.expert_group_size = moe_config.expert_group_size self.ffn = FeedForward(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, dropout_rate=dropout_rate, hidden_act=hidden_act, expert_group_size=self.expert_group_size, expert_num=self.expert_dim, param_init_type=param_init_type, parallel_config=parallel_config) self.mul = P.Mul() else: self.ffn = FeedForward(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, dropout_rate=dropout_rate, hidden_act=hidden_act, expert_num=self.expert_dim, param_init_type=param_init_type, parallel_config=parallel_config) self.mul = P.Mul().shard(((), ())) def construct(self, input_tensor): input_shape = F.shape(input_tensor) input_tensor = self.reshape(input_tensor, (-1, self.hidden_size)) bs_and_dmodel = self.shape(input_tensor) tokens_per_group = bs_and_dmodel[0] // self.dp_group input_tensor = self.reshape(input_tensor, (self.dp_group, tokens_per_group, self.hidden_size)) expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group, self.capacity_factor, self.expert_dim) # dispatch_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity) # combine_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity) dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor) # after transpose, input_tensor's shape: (self.dp_group, self.hidden_size, tokens_per_group) input_tensor = self.transpose_3dim(input_tensor, (0, 2, 1)) dispatch_tensor = self.reshape(dispatch_tensor, (self.dp_group, tokens_per_group, self.expert_dim * expert_capacity)) dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor)) # expert_input's shape: (self.dp_group, self.hidden_size, self.expert_dim * expert_capacity) expert_input = self.batch_mm(input_tensor, dispatch_tensor) expert_input = self.reshape(expert_input, (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)) # The following four ops are to implement transpose(expert_input, (2, 0, 3, 1)), for that a single transpose # has bad performance expert_input = self.reshape(expert_input, (self.dp_group * self.hidden_size, self.expert_dim * expert_capacity)) expert_input = self.transpose_2dim(expert_input, (1, 0)) expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.dp_group, self.hidden_size)) # expert_input's shape: (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size) expert_input = self.transpose_4dim_ep(expert_input, (0, 2, 1, 3)) expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * expert_capacity, self.hidden_size)) # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size) expert_output = self.ffn(expert_input) expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size)) # The following five ops are to implement transpose(expert_output, (1, 3, 0, 2)), for that a single transpose # has bad performance expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group * expert_capacity * self.hidden_size)) expert_output = self.transpose_2dim_ep(expert_output, (1, 0)) expert_output = self.reshape(expert_output, (self.dp_group, expert_capacity, self.hidden_size * self.expert_dim)) expert_output = self.transpose_3dim(expert_output, (0, 2, 1)) # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity) expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)) expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size, self.expert_dim * expert_capacity)) combine_tensor = self.reshape(combine_tensor, (self.dp_group, tokens_per_group, self.expert_dim * expert_capacity)) # combine_tensor's shape: (self.dp_group, self.expert_dim*expert_capacity, tokens_per_group) combine_tensor = self.transpose_3dim(combine_tensor, (0, 2, 1)) combine_tensor = self.cast(combine_tensor, F.dtype(expert_output)) # combined_output's shape: (self.dp_group, self.hidden_size, tokens_per_group) combined_output = self.batch_mm2(expert_output, combine_tensor) # combined_output's shape: (self.dp_group, tokens_per_group, self.hidden_size) combined_output = self.transpose_3dim(combined_output, (0, 2, 1)) combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1])) combined_output = self.reshape(combined_output, input_shape) aux_loss = self.mul(self.aux_loss_factor, aux_loss) return combined_output, aux_loss class Router(Cell): r""" A router backbone used to calculate logits of each token, which should be cascaded by router implementations mapping tokens to experts. when moe_config.num_experts_chosen = 1, use top1 routing; when moe_config.num_experts_chosen > 1, use topk routing Args: d_model (int): The hidden size of each token. moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). routing_policy: The policy of mapping tokens to experts. Default: topkRouter training (bool): The value indicating whether is in training phase. parallel_config: The parallel-related configuration. Inputs: - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, hidden\_size)`. Outputs: Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`. """ def __init__(self, d_model, moe_config, routing_policy=None, training=True, parallel_config=None): super(Router, self).__init__() dp = parallel_config.data_parallel self.d_model = d_model self.expert_dim = moe_config.expert_num self.capacity_factor = moe_config.capacity_factor self.num_experts_chosen = moe_config.num_experts_chosen self.training = training self.routing_policy = routing_policy self.noisy_policy = None # candidate: ["jitter", "rsample", "None"] self.noisy_epsilon = 1e-2 self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,))) self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False) self.router = routing_policy if self.routing_policy is None: self.router = TopkRouter(d_model=d_model, moe_config=moe_config, training=training, parallel_config=parallel_config) if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation(): self.dense.matmul.shard(((dp, 1), (1, 1))) self.mul = P.Mul() self.cast = P.Cast() else: self.dense.matmul.shard(((dp, 1), (1, 1))) self.mul = P.Mul().shard(((dp, 1, 1), (dp,))) self.cast = P.Cast() def construct(self, input_tensor): input_tensor = self.cast(input_tensor, mstype.float32) if self.noisy_policy == "jitter" and self.training: # Here, we temporarily implement the multiplicative jitter this way, # for the lack of UniforReal parallel operator. input_tensor = self.mul(input_tensor, self.noise) router_logits = self.dense(input_tensor) return self.router(router_logits) class TopkRouter(Cell): r""" A router implementation which maps each tokens to the topk expert. Args: d_model (int): The hidden size of each token. moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). training (bool): The value indicating whether is in training phase. config: The parallel-related configuration. Inputs: - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, hidden\_size)`. Outputs: Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`, Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`, Tensor of shape :math:`(1)`. """ def __init__(self, d_model, moe_config, training=True, parallel_config=None): super(TopkRouter, self).__init__() dp = parallel_config.data_parallel self.d_model = d_model self.expert_dim = moe_config.expert_num self.capacity_factor = moe_config.capacity_factor self.training = training self.dp_group = dp self.noisy_policy = None self.cast = P.Cast() self.reshape = P.Reshape() self.shape = P.Shape() self.on_value = Tensor(1.0, mstype.float32) self.off_value = Tensor(0.0, mstype.float32) self.num_experts_chosen = moe_config.num_experts_chosen if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation(): self.softmax = P.Softmax(axis=-1) self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False) self.onehot = P.OneHot() self.onehot2 = P.OneHot() self.onehot3 = P.OneHot() self.reduce_mean = P.ReduceMean(keep_dims=False) self.reduce_mean2 = P.ReduceMean(keep_dims=False) self.reduce_mean3 = P.ReduceMean(keep_dims=False) self.mul = P.Mul() self.mul2 = P.Mul() self.mul3 = P.Mul() self.mul4 = P.Mul() self.mul5 = P.Mul() self.mul6 = P.Mul() self.mul7 = P.Mul() self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1))) self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.not_equal = P.NotEqual() self.div1 = P.RealDiv() self.div2 = P.RealDiv() self.add = P.Add() self.add1 = P.Add() self.add2 = P.Add() self.add3 = P.Add() self.add4 = P.Add() self.sub = P.Sub() self.cumsum = P.CumSum(exclusive=True) self.less = P.Less() self.reduce_sum = P.ReduceSum(keep_dims=False) self.reduce_sum_keep = P.ReduceSum(keep_dims=True) self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True) self.expand = P.ExpandDims() self.expand2 = P.ExpandDims() self.add_scala = P.Add() self.init_loss = Tensor(0.0, mstype.float32) else: self.softmax = P.Softmax(axis=-1).shard(((dp, 1, 1,),)) self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False).shard(((dp, 1, 1),)) self.onehot = P.OneHot().shard(((dp, 1, 1), (), ())) self.onehot2 = P.OneHot().shard(((dp, 1, 1), (), ())) self.onehot3 = P.OneHot().shard(((dp, 1, 1, 1), (), ())) self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),)) self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),)) self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),)) self.mul = P.Mul().shard(((dp, 1), (dp, 1))) self.mul2 = P.Mul().shard(((), ())) self.mul3 = P.Mul().shard(((), ())) self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1))) self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1))) self.mul6 = P.Mul().shard(((dp, 1), (dp, 1))) self.mul7 = P.Mul().shard(((dp, 1), (dp, 1))) self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1))) self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ())) self.div1 = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1))) self.div2 = P.RealDiv().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1))) self.add1 = P.Add().shard(((dp, 1, 1), ())) self.add2 = P.Add().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.add3 = P.Add().shard(((dp, 1), (dp, 1))) self.add4 = P.Add().shard(((dp, 1, 1, 1), ())) self.sub = P.Sub().shard(((), (dp, 1, 1))) self.cumsum = P.CumSum(exclusive=True).shard(((dp, 1, 1),)) self.less = P.Less().shard(((dp, 1, 1), ())) self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),)) self.reduce_sum_keep = P.ReduceSum(keep_dims=True).shard(((dp, 1, 1),)) self.reduce_sum_keep2 = P.ReduceSum(keep_dims=True).shard(((dp, 1, 1, 1),)) self.expand = P.ExpandDims().shard(((dp, 1),)) self.expand2 = P.ExpandDims().shard(((dp, 1, 1),)) self.add_scala = P.Add().shard(((), ())) self.init_loss = Tensor(0.0, mstype.float32) def construct(self, router_logits): router_logits_shape = self.shape(router_logits) router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1])) logits_shape = self.shape(router_logits) tokens_per_group = logits_shape[0] // self.dp_group expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group, self.capacity_factor, self.expert_dim) router_logits = self.reshape(router_logits, (self.dp_group, tokens_per_group, self.expert_dim)) accum_expert_mask = 0 accum_expert_gate = 0 loss = self.init_loss mask_count = 0 accum_combine_tensor = 0 # Probabilities for each token of what expert is should be sent to router_prob = self.softmax(router_logits) for expert_chosen_index in range(self.num_experts_chosen): # for each token, set the router_prob of the selected experts to zero router_prob = self.mul4(router_prob, self.sub(self.on_value, accum_expert_mask)) # shape is : (dp_group, tokens_per_group) expert_index, expert_gate = self.argmax(router_prob) # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim) expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value) # renormalize the rest prob to be of sum 1 router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9)) # the balance loss is computed at each routing step loss = self.add_scala(loss, self._auxiliary_loss(expert_mask, router_prob_normal)) output = self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate, mask_count, expert_chosen_index) expert_mask, expert_gate, expert_mask_flat, position_in_expert = output[0], output[1], output[2], output[3] accum_expert_mask = self.add(accum_expert_mask, expert_mask) accum_expert_gate = self.add3(accum_expert_gate, expert_gate) mask_count = self.add(mask_count, self.reduce_sum_keep(expert_mask, 1)) # combine_tensor's shape: (dp_group, tokens_per_group) combine_tensor = self.mul7(expert_gate, expert_mask_flat) # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim) combine_tensor = self.mul8(self.expand(combine_tensor, -1), self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value)) # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim, self.expert_capacity) combine_tensor = self.mul9(self.expand2(combine_tensor, -1), self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity, self.on_value, self.off_value)) accum_combine_tensor = self.add2(accum_combine_tensor, combine_tensor) # expert weights normalization combine_tensor_sum = self.reduce_sum_keep2(self.reduce_sum_keep2(accum_combine_tensor, -1), -2) accum_combine_tensor = self.div2(accum_combine_tensor, self.add4(combine_tensor_sum, 1e-9)) # dispatch_tensor is of boolean type. Here, using NotEqual instead of Cast, for that 'Cast to bool' has # bad performance dispatch_tensor = self.not_equal(accum_combine_tensor, 0.0) return dispatch_tensor, accum_combine_tensor, loss def _auxiliary_loss(self, expert_mask, router_prob): """ Computing the load balance loss. """ # density_1's shape: (dp_group, self.expert_dim) density_1 = self.reduce_mean(expert_mask, 1) # density_1_proxy's shape: (dp_group, self.expert_dim) density_1_proxy = self.reduce_mean2(router_prob, 1) loss = self.mul(density_1, density_1_proxy) loss = self.reduce_mean3(loss) loss = self.mul3(self.mul2(loss, self.expert_dim), self.expert_dim) return loss def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate, last_num, expert_chosen_index): """ Keeping only the tokens that fit within expert_capacity. """ cumsum = self.cumsum(expert_mask, 1) if expert_chosen_index > 0: cumsum = self.add(cumsum, last_num) # position_in_expert's shape: (dp_group, tokens_per_group, self.expert_dim) position_in_expert = self.mul4(cumsum, expert_mask) less_result = self.less(position_in_expert, expert_capacity) # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim) expert_mask = self.mul5(less_result, expert_mask) # expert_mask_flat's shape: (dp_group, tokens_per_group) expert_mask_flat = self.reduce_sum(expert_mask, -1) # Mask out the experts that have overflowed the expert_capacity. # expert_gate's shape: (dp_group, tokens_per_group) expert_gate = self.mul6(expert_gate, expert_mask_flat) output = (expert_mask, expert_gate, expert_mask_flat, position_in_expert) return output