# Source code for mindspore.nn.transformer.moe

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Note: Mixture of Expert (MoE) structure. This is an experimental interface that is subject to change or deletion.
"""
import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from .op_parallel_config import default_moeparallel_config

__all__ = [
    "MoEConfig"]


class MoEConfig:
    r"""
    The configuration of MoE (Mixture of Expert).

    Args:
        expert_num (int): The number of experts employed. Default: 1
        capacity_factor (float): The factor is used to indicate how much to expand expert capacity,
            which is >=1.0. Default: 1.1.
        aux_loss_factor (float): The factor is used to indicate how much the load balance loss (produced by the
            router) to be added to the entire model loss, which is < 1.0 (and must be positive). Default: 0.05.
        num_experts_chosen (int): The number of experts is chosen by each token and it should not be larger
            than expert_num. Default: 1.

    Raises:
        ValueError: If `capacity_factor` is less than 1.0, `aux_loss_factor` is not less than 1.0, or
            `num_experts_chosen` is larger than `expert_num`.

        The preliminary type/positivity checks are delegated to `Validator.check_positive_int` and
        `Validator.check_positive_float`, which raise their own errors.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> from mindspore.nn.transformer import MoEConfig
        >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1)
    """

    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1):
        # Type/positivity checks first, so the range comparisons below only ever see positive numbers.
        Validator.check_positive_int(expert_num, "expert_num")
        Validator.check_positive_float(capacity_factor, "capacity_factor")
        Validator.check_positive_float(aux_loss_factor, "aux_loss_factor")
        Validator.check_positive_int(num_experts_chosen, "num_experts_chosen")
        if capacity_factor < 1.0:
            raise ValueError(f"'capacity_factor' should be equal to or greater than 1.0, "
                             f"but got {capacity_factor}.")
        if aux_loss_factor >= 1.0:
            raise ValueError(f"'aux_loss_factor' should be less than 1.0, "
                             f"but got {aux_loss_factor}.")
        if num_experts_chosen > expert_num:
            # Report both values: the violated bound is 'expert_num', not a fixed constant.
            raise ValueError(f"'num_experts_chosen' should not be larger than 'expert_num', "
                             f"but got {num_experts_chosen} for 'num_experts_chosen' and "
                             f"{expert_num} for 'expert_num'.")
        self.expert_num = expert_num
        self.capacity_factor = capacity_factor
        self.aux_loss_factor = aux_loss_factor
        self.num_experts_chosen = num_experts_chosen
# Shared default configuration used as the default argument of `MoE`.
default_moe_config = MoEConfig()


def _check_moe_config(moe_config=None, parallel_config=None):
    """
    Check that the MoE configuration is consistent with the parallel configuration.

    Only the type of `moe_config` is checked when `moe_config.expert_num` is not greater than 1
    (i.e. MoE is not in use).

    Args:
        moe_config (MoEConfig): The MoE configuration to validate.
        parallel_config: The parallel configuration; its `expert_parallel`, `data_parallel` and
            `model_parallel` attributes are read.

    Raises:
        TypeError: If `moe_config` is not an instance of MoEConfig.
        ValueError: If `expert_num` is not a multiple of `expert_parallel`, the device number is not a
            multiple of `expert_parallel`, `data_parallel` is not a multiple of `expert_parallel`, or
            `data_parallel * model_parallel` is larger than the device number.
    """
    if not isinstance(moe_config, MoEConfig):
        raise TypeError(f"'moe_config' should be an instance of MoEConfig, but got {type(moe_config).__name__}.")
    use_moe = (moe_config.expert_num > 1)
    if not use_moe:
        return
    if moe_config.expert_num % parallel_config.expert_parallel != 0:
        raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple "
                         f"of 'expert_parallel' value in {type(parallel_config).__name__}, but got "
                         f"{moe_config.expert_num} for 'expert_num' and {parallel_config.expert_parallel} for "
                         f"'expert_parallel'.")

    device_num = D.get_group_size()
    if device_num % parallel_config.expert_parallel != 0:
        raise ValueError(f"device_num: {device_num} should be a multiple of expert_parallel: "
                         f"{parallel_config.expert_parallel}.")
    if parallel_config.data_parallel % parallel_config.expert_parallel != 0:
        raise ValueError(f"data parallel: {parallel_config.data_parallel} should be a multiple of "
                         f"expert_parallel: {parallel_config.expert_parallel} when using MoE.")
    if parallel_config.data_parallel * parallel_config.model_parallel > device_num:
        # The check only rejects products strictly greater than device_num, so equality is allowed.
        raise ValueError(f"The product of the data parallel: {parallel_config.data_parallel} and "
                         f"model parallel: {parallel_config.model_parallel} "
                         f"should be less than or equal to device_num: {device_num}.")


@constexpr
def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
    """
    Return the number of token slots per expert and group:
    ``ceil(k * tokens_per_group * capacity_factor / expert_dim)``.

    Decorated with `constexpr`, so the value is computed as a constant when the graph is compiled.
    """
    return math.ceil(k * tokens_per_group * capacity_factor / expert_dim)


class MoE(Cell):
    """
    The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer.
    The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is
    obtained by multiplying FeedForward's output and router's combine weight.

    Args:
        hidden_size (int): The dimension of the inputs.
        ffn_hidden_size (int): The intermediate hidden size.
        dropout_rate (float): The dropout rate for the second linear's output.
        hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig with
            default values. Please see `MoEConfig`.
        parallel_config(MoEParallelConfig): The parallel config for MoE, see `MoEParallelConfig`.
            Default `default_moeparallel_config`, an instance of `MoEParallelConfig` with default args.

    Inputs:
        - **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor. It is flattened to
          `[tokens, hidden_size]` and then reshaped to `[data_parallel, tokens // data_parallel, hidden_size]`,
          so the total number of tokens must be divisible by `data_parallel`.

    Outputs:
        Tuple of 2 tensors:

        - **output** (Tensor) - the output of this layer after mapping, with the same shape as `x`
          (`[batch, seq_length, hidden_size]`).
        - **aux_loss** (Tensor) - the router's load balance loss multiplied by `aux_loss_factor`.
    """

    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 dropout_rate,
                 hidden_act='gelu',
                 param_init_type=mstype.float32,
                 moe_config=default_moe_config,
                 parallel_config=default_moeparallel_config):
        super(MoE, self).__init__()
        # Under AUTO_PARALLEL with sharding propagation, the strategy of `self.mul` is left to the propagator;
        # this is the only difference between the two modes in this cell.
        sharding_propagation = _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()
        self.hidden_size = hidden_size
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.aux_loss_factor = moe_config.aux_loss_factor
        self.num_experts_chosen = moe_config.num_experts_chosen
        self.dp_group = parallel_config.data_parallel
        self.dp = parallel_config.data_parallel
        self.ep = parallel_config.expert_parallel
        # Imported here rather than at module level (presumably to avoid a circular import with .transformer).
        from .transformer import FeedForward

        self.ffn = FeedForward(hidden_size=hidden_size,
                               ffn_hidden_size=ffn_hidden_size,
                               dropout_rate=dropout_rate,
                               hidden_act=hidden_act,
                               expert_num=self.expert_dim,
                               param_init_type=param_init_type,
                               parallel_config=parallel_config)
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
        self.transpose_2dim_ep = P.Transpose().shard(((self.ep, 1),))
        self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
        self.transpose_4dim_ep = P.Transpose().shard(((self.ep, 1, 1, 1),))
        self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        if sharding_propagation:
            self.mul = P.Mul()
        else:
            self.mul = P.Mul().shard(((), ()))
        self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                             training=True, parallel_config=parallel_config)
        self.cast = P.Cast()

    def construct(self, input_tensor):
        input_shape = F.shape(input_tensor)
        input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
        bs_and_dmodel = self.shape(input_tensor)
        tokens_per_group = bs_and_dmodel[0] // self.dp_group
        input_tensor = self.reshape(input_tensor, (self.dp_group, tokens_per_group, self.hidden_size))

        expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group,
                                                    self.capacity_factor, self.expert_dim)
        # dispatch_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity)
        # combine_tensor's shape: (self.dp_group, tokens_per_group, self.expert_dim, expert_capacity)
        dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor)

        # after transpose, input_tensor's shape: (self.dp_group, self.hidden_size, tokens_per_group)
        input_tensor = self.transpose_3dim(input_tensor, (0, 2, 1))
        dispatch_tensor = self.reshape(dispatch_tensor, (self.dp_group, tokens_per_group,
                                                         self.expert_dim * expert_capacity))
        dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor))
        # expert_input's shape: (self.dp_group, self.hidden_size, self.expert_dim * expert_capacity)
        expert_input = self.batch_mm(input_tensor, dispatch_tensor)
        expert_input = self.reshape(expert_input, (self.dp_group, self.hidden_size, self.expert_dim,
                                                   expert_capacity))
        # The following four ops are to implement transpose(expert_input, (2, 0, 3, 1)), for that a single transpose
        # has bad performance
        expert_input = self.reshape(expert_input, (self.dp_group * self.hidden_size,
                                                   self.expert_dim * expert_capacity))
        expert_input = self.transpose_2dim(expert_input, (1, 0))
        expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.dp_group,
                                                   self.hidden_size))
        # expert_input's shape: (self.expert_dim, self.dp_group, expert_capacity, self.hidden_size)
        expert_input = self.transpose_4dim_ep(expert_input, (0, 2, 1, 3))
        expert_input = self.reshape(expert_input, (self.expert_dim * self.dp_group * expert_capacity,
                                                   self.hidden_size))

        # expert_output's shape: (self.expert_dim, self.dp_group*expert_capacity, self.hidden_size)
        expert_output = self.ffn(expert_input)
        expert_output = self.reshape(expert_output, (self.expert_dim, self.dp_group,
                                                     expert_capacity, self.hidden_size))
        # The following five ops are to implement transpose(expert_output, (1, 3, 0, 2)), for that a single transpose
        # has bad performance
        expert_output = self.reshape(expert_output, (self.expert_dim,
                                                     self.dp_group * expert_capacity * self.hidden_size))
        expert_output = self.transpose_2dim_ep(expert_output, (1, 0))
        expert_output = self.reshape(expert_output, (self.dp_group, expert_capacity,
                                                     self.hidden_size * self.expert_dim))
        expert_output = self.transpose_3dim(expert_output, (0, 2, 1))
        # expert_output's shape: (self.dp_group, self.hidden_size, self.expert_dim, expert_capacity)
        expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size, self.expert_dim,
                                                     expert_capacity))
        expert_output = self.reshape(expert_output, (self.dp_group, self.hidden_size,
                                                     self.expert_dim * expert_capacity))
        combine_tensor = self.reshape(combine_tensor, (self.dp_group, tokens_per_group,
                                                       self.expert_dim * expert_capacity))
        # combine_tensor's shape: (self.dp_group, self.expert_dim*expert_capacity, tokens_per_group)
        combine_tensor = self.transpose_3dim(combine_tensor, (0, 2, 1))
        combine_tensor = self.cast(combine_tensor, F.dtype(expert_output))

        # combined_output's shape: (self.dp_group, self.hidden_size, tokens_per_group)
        combined_output = self.batch_mm2(expert_output, combine_tensor)
        # combined_output's shape: (self.dp_group, tokens_per_group, self.hidden_size)
        combined_output = self.transpose_3dim(combined_output, (0, 2, 1))
        combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
        combined_output = self.reshape(combined_output, input_shape)

        aux_loss = self.mul(self.aux_loss_factor, aux_loss)
        return combined_output, aux_loss


class Router(Cell):
    r"""
    A router backbone used to calculate logits of each token, which should be cascaded by router implementations
    mapping tokens to experts.
    when moe_config.num_experts_chosen = 1, use top1 routing;
    when moe_config.num_experts_chosen > 1, use topk routing

    Args:
        d_model (int): The hidden size of each token.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
        routing_policy: The policy of mapping tokens to experts. Default: None, in which case a `TopkRouter`
            is created.
        training (bool): The value indicating whether is in training phase.
        parallel_config: The parallel-related configuration; its `data_parallel` attribute is read unless
            AUTO_PARALLEL with sharding propagation is active.

    Inputs:
        - **input_tensor** (Tensor) - Tensor whose last dimension is `d_model`. `MoE` passes a tensor of shape
          :math:`(data\_parallel, tokens\_per\_group, hidden\_size)`.

    Outputs:
        The output of the routing policy applied to the router logits (last dimension `expert_dim`). For the
        default `TopkRouter` this is the tuple `(dispatch_tensor, combine_tensor, aux_loss)`; see `TopkRouter`.
    """

    def __init__(self,
                 d_model,
                 moe_config,
                 routing_policy=None,
                 training=True,
                 parallel_config=None):
        super(Router, self).__init__()
        # With AUTO_PARALLEL + sharding propagation no explicit strategies are set here.
        sharding_propagation = _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.num_experts_chosen = moe_config.num_experts_chosen
        self.training = training
        self.routing_policy = routing_policy
        self.noisy_policy = None  # candidate: ["jitter", "rsample", "None"]
        self.noisy_epsilon = 1e-2
        # Fixed multiplicative noise drawn once at construction, uniform in [1 - eps, 1 + eps).
        self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)))

        self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
        if sharding_propagation:
            self.mul = P.Mul()
        else:
            dp = parallel_config.data_parallel
            self.dense.matmul.shard(((dp, 1), (1, 1)))
            # NOTE(review): `self.noise` has shape (d_model,) but its strategy shards that axis by dp while the
            # matching (last) axis of the first input is unsharded; this looks inconsistent. It is only reached
            # when noisy_policy == "jitter", which nothing in this module sets.
            self.mul = P.Mul().shard(((dp, 1, 1), (dp,)))
        self.cast = P.Cast()

        if self.routing_policy is None:
            self.router = TopkRouter(d_model=d_model, moe_config=moe_config, training=training,
                                     parallel_config=parallel_config)
        else:
            self.router = routing_policy

    def construct(self, input_tensor):
        input_tensor = self.cast(input_tensor, mstype.float32)
        if self.noisy_policy == "jitter" and self.training:
            # Here, we temporarily implement the multiplicative jitter this way,
            # for the lack of UniforReal parallel operator.
            input_tensor = self.mul(input_tensor, self.noise)

        router_logits = self.dense(input_tensor)
        return self.router(router_logits)


class TopkRouter(Cell):
    r"""
    A router implementation which maps each token to the top-k experts (k = `moe_config.num_experts_chosen`).

    Args:
        d_model (int): The hidden size of each token.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
        training (bool): The value indicating whether is in training phase.
        parallel_config: The parallel-related configuration; its `data_parallel` attribute is read.

    Inputs:
        - **router_logits** (Tensor) - Tensor whose last dimension is `expert_dim`. It is reshaped to
          :math:`(data\_parallel, tokens\_per\_group, expert\_dim)`.

    Outputs:
        Tuple of 3 tensors:

        - **dispatch_tensor** - boolean tensor of shape
          :math:`(data\_parallel, tokens\_per\_group, expert\_dim, expert\_capacity)`.
        - **combine_tensor** - tensor of shape
          :math:`(data\_parallel, tokens\_per\_group, expert\_dim, expert\_capacity)`.
        - **aux_loss** - the load balance loss, summed over the k routing steps.
    """

    def __init__(self,
                 d_model,
                 moe_config,
                 training=True,
                 parallel_config=None):
        super(TopkRouter, self).__init__()
        # With AUTO_PARALLEL + sharding propagation most strategies are left to the propagator (only mul8/mul9
        # stay explicit); otherwise every operator gets an explicit data-parallel strategy.
        explicit_shard = not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation())

        def maybe_shard(prim, strategy):
            """Apply `strategy` to `prim` only when explicit sharding is in effect; return the primitive."""
            return prim.shard(strategy) if explicit_shard else prim

        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.training = training
        self.dp_group = dp
        self.noisy_policy = None
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.softmax = maybe_shard(P.Softmax(axis=-1), ((dp, 1, 1,),))
        self.argmax = maybe_shard(P.ArgMaxWithValue(axis=-1, keep_dims=False), ((dp, 1, 1),))
        self.num_experts_chosen = moe_config.num_experts_chosen
        self.onehot = maybe_shard(P.OneHot(), ((dp, 1, 1), (), ()))
        self.onehot2 = maybe_shard(P.OneHot(), ((dp, 1, 1), (), ()))
        self.onehot3 = maybe_shard(P.OneHot(), ((dp, 1, 1, 1), (), ()))
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)

        self.reduce_mean = maybe_shard(P.ReduceMean(keep_dims=False), ((dp, 1, 1),))
        self.reduce_mean2 = maybe_shard(P.ReduceMean(keep_dims=False), ((dp, 1, 1),))
        self.reduce_mean3 = maybe_shard(P.ReduceMean(keep_dims=False), ((dp, 1),))
        self.mul = maybe_shard(P.Mul(), ((dp, 1), (dp, 1)))
        self.mul2 = maybe_shard(P.Mul(), ((), ()))
        self.mul3 = maybe_shard(P.Mul(), ((), ()))
        self.mul4 = maybe_shard(P.Mul(), ((dp, 1, 1), (dp, 1, 1)))
        self.mul5 = maybe_shard(P.Mul(), ((dp, 1, 1), (dp, 1, 1)))
        self.mul6 = maybe_shard(P.Mul(), ((dp, 1), (dp, 1)))
        self.mul7 = maybe_shard(P.Mul(), ((dp, 1), (dp, 1)))
        # mul8 and mul9 carry explicit strategies in both modes.
        self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
        self.not_equal = maybe_shard(P.NotEqual(), ((dp, 1, 1, 1), ()))

        self.div1 = maybe_shard(P.RealDiv(), ((dp, 1, 1), (dp, 1, 1)))
        self.div2 = maybe_shard(P.RealDiv(), ((dp, 1, 1, 1), (dp, 1, 1, 1)))
        self.add = maybe_shard(P.Add(), ((dp, 1, 1), (dp, 1, 1)))
        self.add1 = maybe_shard(P.Add(), ((dp, 1, 1), ()))
        self.add2 = maybe_shard(P.Add(), ((dp, 1, 1, 1), (dp, 1, 1, 1)))
        self.add3 = maybe_shard(P.Add(), ((dp, 1), (dp, 1)))
        self.add4 = maybe_shard(P.Add(), ((dp, 1, 1, 1), ()))
        self.sub = maybe_shard(P.Sub(), ((), (dp, 1, 1)))

        self.cumsum = maybe_shard(P.CumSum(exclusive=True), ((dp, 1, 1),))
        self.less = maybe_shard(P.Less(), ((dp, 1, 1), ()))
        self.reduce_sum = maybe_shard(P.ReduceSum(keep_dims=False), ((dp, 1, 1),))
        self.reduce_sum_keep = maybe_shard(P.ReduceSum(keep_dims=True), ((dp, 1, 1),))
        self.reduce_sum_keep2 = maybe_shard(P.ReduceSum(keep_dims=True), ((dp, 1, 1, 1),))
        self.expand = maybe_shard(P.ExpandDims(), ((dp, 1),))
        self.expand2 = maybe_shard(P.ExpandDims(), ((dp, 1, 1),))

    def _auxiliary_loss(self, expert_mask, router_prob):
        """
        Compute the load balance loss for one routing step.

        The fraction of tokens routed to each expert (`expert_mask` averaged over the token axis) is multiplied by
        the mean router probability for that expert. The result is averaged over groups and experts and then scaled
        by `expert_dim ** 2`.
        """
        # density_1's shape: (dp_group, self.expert_dim)
        density_1 = self.reduce_mean(expert_mask, 1)
        # density_1_proxy's shape: (dp_group, self.expert_dim)
        density_1_proxy = self.reduce_mean2(router_prob, 1)
        loss = self.mul(density_1, density_1_proxy)
        loss = self.reduce_mean3(loss)
        loss = self.mul3(self.mul2(loss, self.expert_dim), self.expert_dim)
        return loss

    def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate, last_num, expert_chosen_index):
        """
        Keep only the tokens that fit within `expert_capacity`.

        A token's position in its expert is the exclusive cumulative sum of `expert_mask` along the token axis. On
        routing steps after the first, `last_num` (the per-expert count of tokens kept in earlier steps) is added.

        Returns:
            Tuple `(expert_mask, expert_gate, expert_mask_flat, position_in_expert)` after masking.
        """
        cumsum = self.cumsum(expert_mask, 1)
        if expert_chosen_index > 0:
            cumsum = self.add(cumsum, last_num)
        # position_in_expert's shape: (dp_group, tokens_per_group, self.expert_dim)
        position_in_expert = self.mul4(cumsum, expert_mask)
        less_result = self.less(position_in_expert, expert_capacity)
        # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
        expert_mask = self.mul5(less_result, expert_mask)
        # expert_mask_flat's shape: (dp_group, tokens_per_group)
        expert_mask_flat = self.reduce_sum(expert_mask, -1)

        # Mask out the experts that have overflowed the expert_capacity.
        # expert_gate's shape: (dp_group, tokens_per_group)
        expert_gate = self.mul6(expert_gate, expert_mask_flat)
        output = (expert_mask, expert_gate, expert_mask_flat, position_in_expert)
        return output

    def construct(self, router_logits):
        router_logits_shape = self.shape(router_logits)
        router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
        logits_shape = self.shape(router_logits)
        tokens_per_group = logits_shape[0] // self.dp_group
        expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_group, self.capacity_factor,
                                                    self.expert_dim)
        router_logits = self.reshape(router_logits, (self.dp_group, tokens_per_group, self.expert_dim))

        accum_expert_mask = 0
        # NOTE(review): accum_expert_gate is accumulated below but never read afterwards.
        accum_expert_gate = 0
        loss = 0
        mask_count = 0
        accum_combine_tensor = 0
        # Probabilities for each token of what expert it should be sent to
        router_prob = self.softmax(router_logits)

        for expert_chosen_index in range(self.num_experts_chosen):
            # for each token, set the router_prob of the selected experts to zero
            router_prob = self.mul4(router_prob, self.sub(self.on_value, accum_expert_mask))
            # shape is : (dp_group, tokens_per_group)
            expert_index, expert_gate = self.argmax(router_prob)
            # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
            expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)
            # renormalize the rest prob to be of sum 1
            router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9))

            # the balance loss is computed at each routing step
            loss += self._auxiliary_loss(expert_mask, router_prob_normal)

            output = self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate,
                                                     mask_count, expert_chosen_index)
            expert_mask, expert_gate, expert_mask_flat, position_in_expert = output[0], output[1], output[2], output[3]
            accum_expert_mask = self.add(accum_expert_mask, expert_mask)
            accum_expert_gate = self.add3(accum_expert_gate, expert_gate)
            mask_count = self.add(mask_count, self.reduce_sum_keep(expert_mask, 1))

            # combine_tensor's shape: (dp_group, tokens_per_group)
            combine_tensor = self.mul7(expert_gate, expert_mask_flat)
            # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim)
            combine_tensor = self.mul8(self.expand(combine_tensor, -1),
                                       self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
            # combine_tensor's shape: (dp_group, tokens_per_group, self.expert_dim, expert_capacity)
            combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
                                       self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
                                                    self.on_value, self.off_value))
            accum_combine_tensor = self.add2(accum_combine_tensor, combine_tensor)

        # expert weights normalization
        combine_tensor_sum = self.reduce_sum_keep2(self.reduce_sum_keep2(accum_combine_tensor, -1), -2)
        accum_combine_tensor = self.div2(accum_combine_tensor, self.add4(combine_tensor_sum, 1e-9))
        # dispatch_tensor is of boolean type. Here, using NotEqual instead of Cast, for that 'Cast to bool' has
        # bad performance
        dispatch_tensor = self.not_equal(accum_combine_tensor, 0.0)
        return dispatch_tensor, accum_combine_tensor, loss