# Source code for mindspore.nn.transformer.moe

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Note: Mixture of Expert (MoE) structure. This is an experimental interface that is subject to change or deletion.
"""
import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from .op_parallel_config import default_dpmp_config

__all__ = [
    "MoEConfig"]


class MoEConfig:
    r"""
    Hyper-parameters of a MoE (Mixture of Expert) layer.

    Args:
        expert_num (int): How many experts the layer holds; validated as a positive integer. Default: 1.
        capacity_factor (float): Multiplier used to enlarge each expert's token capacity; validated as a
            positive float and must be >= 1.0. Default: 1.1.
        aux_loss_factor (float): Weight applied to the router's load-balance loss before it is added to the
            whole model loss; validated as a positive float and must be < 1.0. Default: 0.05.
        num_experts_chosen (int): How many experts each token is sent to. Only the 'Top1' routing policy is
            implemented, so the value must be 1. Default: 1.

    Raises:
        ValueError: If `capacity_factor` < 1.0, `aux_loss_factor` >= 1.0 or `num_experts_chosen` != 1.
            The positivity/type checks performed by ``Validator`` run before these range checks.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.nn.transformer import MoEConfig
        >>> moe_config = MoEConfig(expert_num=4, capacity_factor=5.0, aux_loss_factor=0.05, num_experts_chosen=1)
    """
    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05, num_experts_chosen=1):
        # Generic validation first, in argument order (same order as the original checks).
        positivity_checks = (
            (Validator.check_positive_int, expert_num, "expert_num"),
            (Validator.check_positive_float, capacity_factor, "capacity_factor"),
            (Validator.check_positive_float, aux_loss_factor, "aux_loss_factor"),
            (Validator.check_positive_int, num_experts_chosen, "num_experts_chosen"),
        )
        for check, value, arg_name in positivity_checks:
            check(value, arg_name)

        # Range checks specific to the MoE configuration.
        if capacity_factor < 1.0:
            message = (f"'capacity_factor' should be equal to or greater than 1.0, "
                       f"but got {capacity_factor}.")
            raise ValueError(message)
        if aux_loss_factor >= 1.0:
            message = (f"'aux_loss_factor' should be less than 1.0, "
                       f"but got {aux_loss_factor}.")
            raise ValueError(message)
        if num_experts_chosen != 1:
            message = ("'num_experts_chosen' should be 1. Since only 'Top1' routing policy supported currently, "
                       "the value should be 1.")
            raise ValueError(message)

        self.expert_num = expert_num
        self.capacity_factor = capacity_factor
        self.aux_loss_factor = aux_loss_factor
        self.num_experts_chosen = num_experts_chosen
default_moe_config = MoEConfig()


def _check_moe_config(moe_config=None, parallel_config=None):
    """
    Validate `moe_config` and its compatibility with `parallel_config`.

    Args:
        moe_config (MoEConfig): The MoE configuration to validate.
        parallel_config: The parallel configuration. Its `data_parallel` attribute is read only when MoE is
            enabled, i.e. when `moe_config.expert_num > 1`.

    Raises:
        TypeError: If `moe_config` is not an instance of MoEConfig.
        ValueError: If MoE is enabled and `expert_num` is not a multiple of `data_parallel`.
    """
    if not isinstance(moe_config, MoEConfig):
        raise TypeError(f"'moe_config' should be an instance of MoEConfig, but got {type(moe_config).__name__}.")
    use_moe = (moe_config.expert_num > 1)
    # MoE uses `data_parallel` as its expert-parallel degree (see MoE.expert_parallel), and this check
    # requires the experts to split evenly over it.
    if use_moe and moe_config.expert_num % parallel_config.data_parallel != 0:
        raise ValueError(f"When using MoE, the 'expert_num' in {type(moe_config).__name__} must be a multiple "
                         f"of 'data_parallel' value in {type(parallel_config).__name__}, but got "
                         f"{moe_config.expert_num} for 'expert_num' and {parallel_config.data_parallel} for "
                         f"'data_parallel'.")


@constexpr
def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim):
    """
    Return the token capacity of each expert: ceil(k * tokens_per_device * capacity_factor / expert_dim).

    Decorated with `constexpr` so that MindSpore can evaluate it from constant inputs while compiling the graph.
    """
    return math.ceil(k * tokens_per_device * capacity_factor / expert_dim)


class MoE(Cell):
    """
    The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer.
    The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is
    obtained by multiplying FeedForward's output and router's combine weight.

    Args:
        hidden_size (int): The dimension of the inputs.
        ffn_hidden_size (int): The intermediate hidden size.
        dropout_rate (float): The dropout rate for the second linear's output.
        hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
            with default values. Please see `MoEConfig`.
        parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
            Default `default_dpmp_config`, an instance of `OpParallelConfig` with default args.

    Inputs:
        - **input_tensor** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
          The total number of tokens (all dims but the last, multiplied together) must be divisible by
          `parallel_config.data_parallel`, because the tokens are reshaped into
          `(data_parallel, tokens_per_device, hidden_size)`.

    Outputs:
        Tuple of 2 Tensors, (output, aux_loss).

        - **output** (Tensor) - the output of this layer after mapping, reshaped back to the shape of
          `input_tensor`, i.e. `[batch, seq_length, hidden_size]`.
        - **aux_loss** (Tensor) - the router's load-balance loss multiplied by `moe_config.aux_loss_factor`.
    """
    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 dropout_rate,
                 hidden_act='gelu',
                 param_init_type=mstype.float32,
                 moe_config=default_moe_config,
                 parallel_config=default_dpmp_config):
        super(MoE, self).__init__()
        self.hidden_size = hidden_size
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.aux_loss_factor = moe_config.aux_loss_factor
        self.num_experts_chosen = moe_config.num_experts_chosen
        # The data-parallel degree doubles as the expert-parallel degree.
        self.expert_parallel = parallel_config.data_parallel
        self.dp = parallel_config.data_parallel
        # Local import: FeedForward lives in .transformer (deferred import, presumably to avoid a cycle).
        from .transformer import FeedForward

        self.ffn = FeedForward(hidden_size=hidden_size,
                               ffn_hidden_size=ffn_hidden_size,
                               dropout_rate=dropout_rate,
                               hidden_act=hidden_act,
                               expert_num=self.expert_dim,
                               param_init_type=param_init_type,
                               parallel_config=parallel_config)
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.transpose_2dim = P.Transpose().shard(((self.dp, 1),))
        self.transpose_3dim = P.Transpose().shard(((self.dp, 1, 1),))
        self.transpose_4dim = P.Transpose().shard(((self.dp, 1, 1, 1),))
        self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        self.mul = P.Mul().shard(((), ()))
        self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                             training=True, parallel_config=parallel_config)
        self.cast = P.Cast()

    def construct(self, input_tensor):
        input_shape = F.shape(input_tensor)
        input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
        bs_and_dmodel = self.shape(input_tensor)
        tokens_per_device = bs_and_dmodel[0] // self.expert_parallel
        input_tensor = self.reshape(input_tensor, (self.expert_parallel, tokens_per_device, self.hidden_size))

        expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_device,
                                                    self.capacity_factor, self.expert_dim)
        # dispatch_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
        # combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
        dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor)

        # after transpose, input_tensor's shape: (self.expert_parallel, self.hidden_size, tokens_per_device)
        input_tensor = self.transpose_3dim(input_tensor, (0, 2, 1))
        dispatch_tensor = self.reshape(dispatch_tensor, (self.expert_parallel, tokens_per_device,
                                                         self.expert_dim * expert_capacity))
        dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor))
        # expert_input's shape: (self.expert_parallel, self.hidden_size, self.expert_dim * expert_capacity)
        expert_input = self.batch_mm(input_tensor, dispatch_tensor)
        expert_input = self.reshape(expert_input, (self.expert_parallel, self.hidden_size, self.expert_dim,
                                                   expert_capacity))
        # The following four ops are to implement transpose(expert_input, (2, 0, 3, 1)), for that a single transpose
        # has bad performance
        expert_input = self.reshape(expert_input, (self.expert_parallel * self.hidden_size,
                                                   self.expert_dim * expert_capacity))
        expert_input = self.transpose_2dim(expert_input, (1, 0))
        expert_input = self.reshape(expert_input, (self.expert_dim, expert_capacity, self.expert_parallel,
                                                   self.hidden_size))
        # expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size)
        expert_input = self.transpose_4dim(expert_input, (0, 2, 1, 3))
        expert_input = self.reshape(expert_input, (self.expert_dim * self.expert_parallel * expert_capacity,
                                                   self.hidden_size))

        # expert_output's shape (presumably): (self.expert_dim, self.expert_parallel*expert_capacity,
        # self.hidden_size); it is reshaped explicitly right below in any case.
        expert_output = self.ffn(expert_input)
        expert_output = self.reshape(expert_output, (self.expert_dim, self.expert_parallel,
                                                     expert_capacity, self.hidden_size))
        # The following five ops are to implement transpose(expert_output, (1, 3, 0, 2)), for that a single transpose
        # has bad performance
        expert_output = self.reshape(expert_output, (self.expert_dim,
                                                     self.expert_parallel * expert_capacity * self.hidden_size))
        expert_output = self.transpose_2dim(expert_output, (1, 0))
        expert_output = self.reshape(expert_output, (self.expert_parallel, expert_capacity,
                                                     self.hidden_size * self.expert_dim))
        expert_output = self.transpose_3dim(expert_output, (0, 2, 1))
        # expert_output's shape: (self.expert_parallel, self.hidden_size, self.expert_dim, expert_capacity)
        expert_output = self.reshape(expert_output, (self.expert_parallel, self.hidden_size, self.expert_dim,
                                                     expert_capacity))
        expert_output = self.reshape(expert_output, (self.expert_parallel, self.hidden_size,
                                                     self.expert_dim * expert_capacity))
        combine_tensor = self.reshape(combine_tensor, (self.expert_parallel, tokens_per_device,
                                                       self.expert_dim * expert_capacity))
        # combine_tensor's shape: (self.expert_parallel, self.expert_dim*expert_capacity, tokens_per_device)
        combine_tensor = self.transpose_3dim(combine_tensor, (0, 2, 1))
        combine_tensor = self.cast(combine_tensor, F.dtype(expert_output))

        # combined_output's shape: (self.expert_parallel, self.hidden_size, tokens_per_device)
        combined_output = self.batch_mm2(expert_output, combine_tensor)
        # combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size)
        combined_output = self.transpose_3dim(combined_output, (0, 2, 1))
        combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
        combined_output = self.reshape(combined_output, input_shape)

        # Scale the router's load-balance loss before handing it back to the caller.
        aux_loss = self.mul(self.aux_loss_factor, aux_loss)
        return combined_output, aux_loss


class Router(Cell):
    r"""
    A router backbone used to calculate logits of each token, which should be cascaded by router implementations
    mapping tokens to experts.

    Args:
        d_model (int): The hidden size of each token.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
        routing_policy: The policy of mapping tokens to experts. Default: None, in which case a `SwitchRouter`
            is created and used.
        training (bool): The value indicating whether is in training phase.
        parallel_config: The parallel-related configuration.

    Inputs:
        - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, hidden\_size)`.

    Outputs:
        The result of the routing policy applied to the router logits, whose last dimension is `expert_dim`.
        With the default `SwitchRouter` this is a tuple (dispatch_tensor, combine_tensor, aux_loss), see
        `SwitchRouter`.
    """
    def __init__(self,
                 d_model,
                 moe_config,
                 routing_policy=None,
                 training=True,
                 parallel_config=None):
        super(Router, self).__init__()
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.training = training
        self.routing_policy = routing_policy
        self.noisy_policy = None  # candidate: ["jitter", "rsample", "None"]
        self.noisy_epsilon = 1e-2
        # Multiplicative jitter factors, one per hidden unit. Built as float32 to match `input_tensor`, which
        # construct() casts to float32 before applying the jitter; numpy's default float64 would otherwise
        # promote the jittered input to float64 before it reaches the Dense layer.
        self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)),
                            mstype.float32)

        self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
        self.dense.matmul.shard(((dp, 1), (1, 1)))
        # `noise` of shape (d_model,) broadcasts against the last axis of input_tensor, which is not split
        # (strategy 1), so the noise operand must not be split either.
        self.mul = P.Mul().shard(((dp, 1, 1), (1,)))
        self.cast = P.Cast()

        if self.routing_policy is None:
            self.router = SwitchRouter(d_model=d_model, moe_config=moe_config, training=training,
                                       parallel_config=parallel_config)
        else:
            self.router = routing_policy

    def construct(self, input_tensor):
        input_tensor = self.cast(input_tensor, mstype.float32)
        if self.noisy_policy == "jitter" and self.training is True:
            # Here, we temporarily implement the multiplicative jitter this way,
            # for the lack of UniformReal parallel operator.
            input_tensor = self.mul(input_tensor, self.noise)

        router_logits = self.dense(input_tensor)
        return self.router(router_logits)


class SwitchRouter(Cell):
    r"""
    A router implementation which maps each tokens to the top1 expert.
    Reference: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py

    Args:
        d_model (int): The hidden size of each token.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
        training (bool): The value indicating whether is in training phase.
        parallel_config: The parallel-related configuration.

    Inputs:
        - **router_logits** (Tensor) - Router logits whose last dimension is `expert_dim`, e.g. of shape
          :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`. All leading dimensions are flattened and
          the token count must be divisible by `parallel_config.data_parallel`.

    Outputs:
        - **dispatch_tensor** (Tensor) - Boolean tensor of shape
          :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`.
        - **combine_tensor** (Tensor) - Combine weights of shape
          :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`.
        - **loss** (Tensor) - The load-balance loss, reduced over all dimensions.
    """
    def __init__(self,
                 d_model,
                 moe_config,
                 training=True,
                 parallel_config=None):
        super(SwitchRouter, self).__init__()
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.training = training
        self.expert_parallel = dp
        self.noisy_policy = None
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.softmax = P.Softmax(axis=-1).shard(((dp, 1, 1,),))
        self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False).shard(((dp, 1, 1),))

        self.onehot = P.OneHot().shard(((dp, 1, 1), (), ()))
        self.onehot2 = P.OneHot().shard(((dp, 1, 1), (), ()))
        self.onehot3 = P.OneHot().shard(((dp, 1, 1, 1), (), ()))
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)

        self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
        self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
        self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),))
        self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul2 = P.Mul().shard(((1,), ()))
        self.mul3 = P.Mul().shard(((1,), ()))
        self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul6 = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul7 = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))

        self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ()))
        self.cumsum = P.CumSum(exclusive=True).shard(((dp, 1, 1),))
        self.less = P.Less().shard(((dp, 1, 1), ()))
        self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),))
        self.expand = P.ExpandDims().shard(((dp, 1),))
        self.expand2 = P.ExpandDims().shard(((dp, 1, 1),))

    def _auxiliary_loss(self, expert_mask, router_prob):
        """
        Computing the load balance loss.

        The per-group mean of (fraction of tokens routed to each expert) * (mean router probability of that
        expert) is averaged over all entries and then multiplied by expert_dim twice.
        """
        # density_1's shape: (expert_parallel, self.expert_dim)
        density_1 = self.reduce_mean(expert_mask, 1)
        # density_1_proxy's shape: (expert_parallel, self.expert_dim)
        density_1_proxy = self.reduce_mean2(router_prob, 1)
        loss = self.mul(density_1, density_1_proxy)
        loss = self.reduce_mean3(loss)
        loss = self.mul3(self.mul2(loss, self.expert_dim), self.expert_dim)
        return loss

    def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate):
        """
        Keeping only the tokens that fit within expert_capacity.

        Returns the masked gate, the per-token keep flag and each token's (0-based) position in the queue
        of its chosen expert.
        """
        # Exclusive cumsum over the token axis: number of earlier tokens routed to the same expert.
        cumsum = self.cumsum(expert_mask, 1)
        # position_in_expert's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        position_in_expert = self.mul4(cumsum, expert_mask)
        less_result = self.less(position_in_expert, expert_capacity)
        # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        expert_mask = self.mul5(less_result, expert_mask)
        # expert_mask_flat's shape: (expert_parallel, tokens_per_device)
        expert_mask_flat = self.reduce_sum(expert_mask, -1)

        # Mask out the experts that have overflowed the expert_capacity.
        # expert_gate's shape: (expert_parallel, tokens_per_device)
        expert_gate = self.mul6(expert_gate, expert_mask_flat)
        return expert_gate, expert_mask_flat, position_in_expert

    def construct(self, router_logits):
        router_logits_shape = self.shape(router_logits)
        router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
        logits_shape = self.shape(router_logits)
        tokens_per_device = logits_shape[0] // self.expert_parallel
        expert_capacity = calculate_expert_capacity(1, tokens_per_device, self.capacity_factor, self.expert_dim)
        router_logits = self.reshape(router_logits, (self.expert_parallel, tokens_per_device, self.expert_dim))
        # Currently, lack of gumbel sampler for router_logits.

        # Probabilities for each token of what expert is should be sent to
        router_prob = self.softmax(router_logits)
        # shape is : (expert_parallel, tokens_per_device)
        expert_index, expert_gate = self.argmax(router_prob)
        # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)

        # Computing the load balance loss:
        loss = self._auxiliary_loss(expert_mask, router_prob)

        expert_gate, expert_mask_flat, position_in_expert = \
            self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate)

        # combine_tensor's shape: (expert_parallel, tokens_per_device)
        combine_tensor = self.mul7(expert_gate, expert_mask_flat)
        # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        combine_tensor = self.mul8(self.expand(combine_tensor, -1),
                                   self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
        # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
        combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
                                   self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
                                                self.on_value, self.off_value))

        # dispatch_tensor is of boolean type. Here, using NotEqual instead of Cast, for that 'Cast to bool' has
        # bad performance
        dispatch_tensor = self.not_equal(combine_tensor, 0.0)
        return dispatch_tensor, combine_tensor, loss