# Source code for mindspore.parallel.nn.moe

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Note: Mixture of Expert (MoE) structure. This is an experimental interface that is subject to change and/or deletion.
"""
import math
import numpy as np
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from .op_parallel_config import default_dpmp_config

__all__ = [
    "MoEConfig"]


class MoEConfig:
    r"""
    The configuration of MoE (Mixture of Experts).

    This is a plain value holder: every argument is stored unchanged as an attribute of the same name.
    The ranges mentioned below are recommendations only; they are not validated here.

    Args:
        expert_num (int): The number of experts employed. Default: 1.
        capacity_factor (float): How much to expand the expert capacity; expected to be >= 1.0.
            Default: 1.1.
        aux_loss_factor (float): How much of the load-balance loss (produced by the router) is added to
            the entire model loss; expected to be < 1.0. Default: 0.05.
        num_experts_chosen (int): The number of experts chosen by each token. Default: 1.
        noisy_policy (str): The noise policy used when routing tokens to experts. Default: None.
        noisy_epsilon (float): The parameter used when adding noise while routing tokens to experts.
            Default: 1e-2.
    """

    def __init__(self, expert_num=1, capacity_factor=1.1, aux_loss_factor=0.05,
                 num_experts_chosen=1, noisy_policy=None, noisy_epsilon=1e-2):
        self.expert_num = expert_num
        self.capacity_factor = capacity_factor
        self.aux_loss_factor = aux_loss_factor
        self.num_experts_chosen = num_experts_chosen
        self.noisy_policy = noisy_policy
        self.noisy_epsilon = noisy_epsilon
default_moe_config = MoEConfig()


@constexpr
def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim):
    """Return the per-expert token capacity.

    Computed as ``ceil(k * tokens_per_device * capacity_factor / expert_dim)``, where ``k`` is the
    number of experts each token is routed to.
    """
    return math.ceil(k * tokens_per_device * capacity_factor / expert_dim)


class MoE(Cell):
    """
    The mixture of experts (MoE) implementation. The implementation includes a router and a FeedForward layer.
    The router dispatches tokens to experts in FeedForward, then FeedForward does computation, and the final output is
    obtained by multiplying FeedForward's output and router's combine weight.

    Args:
        hidden_size (int): The dimension of the inputs.
        ffn_hidden_size (int): The intermediate hidden size.
        dropout_rate (float): The dropout rate for the second linear's output.
        hidden_act (str): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. Default: gelu.
        param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
            Default: dtype.float32.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts). Default: `default_moe_config`,
            an instance of `MoEConfig` with default args.
        parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
            Default: `default_dpmp_config`, an instance of `OpParallelConfig` with default args.

    Inputs:
        - **input_tensor** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
          The total number of tokens (`batch * seq_length`) must be divisible by
          `parallel_config.data_parallel`, because the tokens are regrouped per data-parallel slice.

    Outputs:
        Tuple of 2 Tensors:

        - **output** (Tensor) - the output of this layer after mapping, with the same shape as `input_tensor`.
        - **aux_loss** (Tensor) - the router's load-balance loss scaled by `moe_config.aux_loss_factor`.
    """

    def __init__(self, hidden_size,
                 ffn_hidden_size,
                 dropout_rate,
                 hidden_act='gelu',
                 param_init_type=mstype.float32,
                 moe_config=default_moe_config,
                 parallel_config=default_dpmp_config):
        super(MoE, self).__init__()
        self.hidden_size = hidden_size
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.aux_loss_factor = moe_config.aux_loss_factor
        self.num_experts_chosen = moe_config.num_experts_chosen
        self.expert_parallel = parallel_config.data_parallel
        self.dp = parallel_config.data_parallel
        # Imported locally; presumably to avoid a circular import with `.transformer` -- TODO confirm.
        from .transformer import FeedForward

        self.ffn = FeedForward(hidden_size=hidden_size,
                               ffn_hidden_size=ffn_hidden_size,
                               dropout_rate=dropout_rate,
                               hidden_act=hidden_act,
                               expert_num=self.expert_dim,
                               param_init_type=param_init_type,
                               parallel_config=parallel_config)
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.transpose = P.Transpose().shard(((self.dp, 1, 1),))
        self.transpose2 = P.Transpose().shard(((self.dp, 1, 1, 1),))
        self.transpose3 = P.Transpose().shard(((self.dp, 1, 1, 1),))
        self.transpose4 = P.Transpose().shard(((self.dp, 1, 1),))
        self.transpose5 = P.Transpose().shard(((self.dp, 1, 1),))
        self.batch_mm = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        self.batch_mm2 = P.BatchMatMul().shard(((self.dp, 1, 1), (self.dp, 1, 1)))
        self.mul = P.Mul().shard(((), ()))
        self.router = Router(d_model=hidden_size, moe_config=moe_config, routing_policy=None,
                             training=True, parallel_config=parallel_config)
        self.cast = P.Cast()

    def construct(self, input_tensor):
        input_shape = F.shape(input_tensor)
        input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
        bs_and_dmodel = self.shape(input_tensor)
        # Floor division keeps this shape entry an int; true division would yield a float, which is not a
        # valid Reshape dimension.
        tokens_per_device = bs_and_dmodel[0] // self.expert_parallel
        input_tensor = self.reshape(input_tensor, (self.expert_parallel, tokens_per_device, self.hidden_size))

        # NOTE(review): the default SwitchRouter sizes its tensors with k=1, while this capacity uses
        # num_experts_chosen; the reshapes below presumably only line up when num_experts_chosen == 1 -- confirm.
        expert_capacity = calculate_expert_capacity(self.num_experts_chosen, tokens_per_device,
                                                    self.capacity_factor, self.expert_dim)
        # dispatch_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
        # combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
        dispatch_tensor, combine_tensor, aux_loss = self.router(input_tensor)

        # after transpose, input_tensor's shape: (self.expert_parallel, self.hidden_size, tokens_per_device)
        input_tensor = self.transpose(input_tensor, (0, 2, 1))
        dispatch_tensor = self.reshape(dispatch_tensor, (self.expert_parallel, tokens_per_device,
                                                         self.expert_dim * expert_capacity))
        dispatch_tensor = self.cast(dispatch_tensor, F.dtype(input_tensor))
        # expert_input's shape: (self.expert_parallel, self.hidden_size, self.expert_dim * expert_capacity)
        expert_input = self.batch_mm(input_tensor, dispatch_tensor)
        expert_input = self.reshape(expert_input, (self.expert_parallel, self.hidden_size, self.expert_dim,
                                                   expert_capacity))
        # expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size)
        expert_input = self.transpose2(expert_input, (2, 0, 3, 1))
        expert_input = self.reshape(expert_input, (self.expert_dim * self.expert_parallel * expert_capacity,
                                                   self.hidden_size))

        # expert_output's shape: (self.expert_dim, self.expert_parallel*expert_capacity, self.hidden_size)
        expert_output = self.ffn(expert_input)
        expert_output = self.reshape(expert_output, (self.expert_dim, self.expert_parallel,
                                                     expert_capacity, self.hidden_size))
        # expert_output's shape: (self.expert_parallel, self.hidden_size, self.expert_dim, expert_capacity)
        expert_output = self.transpose3(expert_output, (1, 3, 0, 2))
        expert_output = self.reshape(expert_output, (self.expert_parallel, self.hidden_size,
                                                     self.expert_dim * expert_capacity))
        combine_tensor = self.reshape(combine_tensor, (self.expert_parallel, tokens_per_device,
                                                       self.expert_dim * expert_capacity))
        # combine_tensor's shape: (self.expert_parallel, self.expert_dim*expert_capacity, tokens_per_device)
        combine_tensor = self.transpose4(combine_tensor, (0, 2, 1))
        combine_tensor = self.cast(combine_tensor, F.dtype(expert_output))

        # combined_output's shape: (self.expert_parallel, self.hidden_size, tokens_per_device)
        combined_output = self.batch_mm2(expert_output, combine_tensor)
        # combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size)
        combined_output = self.transpose5(combined_output, (0, 2, 1))
        combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
        combined_output = self.reshape(combined_output, input_shape)

        aux_loss = self.mul(self.aux_loss_factor, aux_loss)
        return combined_output, aux_loss


class _CumSum(Cell):
    r"""
    A layer computing an exclusive cumulative sum of `expert_mask` along the token axis (axis 1).

    The sum is computed as a matrix product with a strictly upper-triangular 0/1 matrix, so the value at
    token ``j`` is the sum of the mask over tokens ``0 .. j-1`` (the token itself is not included).

    Inputs:
        - **expert_mask** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.

    Outputs:
        Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
    """

    def __init__(self, config):
        super(_CumSum, self).__init__()
        dp = config.data_parallel
        self.range = P.Range().shard(((1,),))
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().shard(((dp, 1), (1, 1)))
        self.shape = P.Shape()
        self.cast = P.Cast()

        self.transpose = P.Transpose().shard(((dp, 1, 1),))
        self.transpose2 = P.Transpose().shard(((1, 1),))
        self.transpose3 = P.Transpose().shard(((dp, 1, 1),))
        self.expand = P.ExpandDims().shard(((1,),))
        self.greater = P.Greater().shard(((1, 1), (1, 1)))

        self.start = Tensor(0, mstype.int32)
        self.limit = Tensor(0, mstype.int32)
        self.delta = Tensor(1, mstype.int32)
        self.add = P.Add().shard(((1,), ()))

    def construct(self, expert_mask):
        # origin_shape: (expert_parallel, tokens_per_device, self.expert_dim)
        origin_shape = self.shape(expert_mask)
        tokens_per_device = origin_shape[1]
        # expert_mask_trans's shape: (expert_parallel, self.expert_dim, tokens_per_device)
        expert_mask_trans = self.transpose(expert_mask, (0, 2, 1))
        # expert_mask_reshaped's shape: (expert_parallel*self.expert_dim, tokens_per_device)
        expert_mask_reshaped = self.reshape(expert_mask_trans, (-1, tokens_per_device))

        # one_dim: (1, tokens_per_device) row [0, 1, ..., tokens_per_device - 1]; other_dim is its column form.
        one_dim = self.expand(self.range(self.start, self.add(self.limit, tokens_per_device), self.delta), 0)
        other_dim = self.transpose2(one_dim, (1, 0))
        # up_tri_matrix's shape: (tokens_per_device, tokens_per_device); entry [i, j] is 1 when j > i.
        up_tri_matrix = self.greater(one_dim, other_dim)
        up_tri_matrix = self.cast(up_tri_matrix, mstype.float32)

        # cum_sum's shape: (expert_parallel*self.expert_dim, tokens_per_device)
        cum_sum = self.matmul(expert_mask_reshaped, up_tri_matrix)
        # cum_sum's shape: (expert_parallel, self.expert_dim, tokens_per_device)
        cum_sum = self.reshape(cum_sum, (origin_shape[0], origin_shape[2], tokens_per_device))
        # cum_sum's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        cum_sum = self.transpose3(cum_sum, (0, 2, 1))
        return cum_sum


class Router(Cell):
    r"""
    A router backbone used to calculate logits of each token, which should be cascaded by router implementations
    mapping tokens to experts.

    Args:
        d_model (int): The hidden size of each token.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
        routing_policy: The policy of mapping tokens to experts. Default: None, which uses `SwitchRouter`.
        training (bool): The value indicating whether is in training phase.
        parallel_config: The parallel-related configuration. Default: None, which uses `default_dpmp_config`.

    Inputs:
        - **input_tensor** (Tensor) - Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, hidden\_size)`.

    Outputs:
        Whatever the routing policy returns for the router logits of shape
        :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`. With the default `SwitchRouter`, this is
        `(dispatch_tensor, combine_tensor, aux_loss)`.
    """

    def __init__(self,
                 d_model,
                 moe_config,
                 routing_policy=None,
                 training=True,
                 parallel_config=None):
        super(Router, self).__init__()
        if parallel_config is None:
            parallel_config = default_dpmp_config
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.training = training
        self.routing_policy = routing_policy
        # candidate: ["jitter", "rsample", "None"]; only "jitter" is acted on in construct().
        self.noisy_policy = moe_config.noisy_policy
        self.noisy_epsilon = moe_config.noisy_epsilon
        # np.random.uniform yields float64; pin the noise to float32 so it matches the float32 input it scales.
        # The noise is sampled once here, so the same jitter factors are reused on every call.
        self.noise = Tensor(np.random.uniform(1 - self.noisy_epsilon, 1 + self.noisy_epsilon, (d_model,)),
                            mstype.float32)

        self.dense = Dense(in_channels=self.d_model, out_channels=self.expert_dim, has_bias=False)
        self.dense.matmul.shard(((dp, 1), (1, 1)))
        # The (d_model,) noise broadcasts against the input's last axis, which is not split, so the noise is
        # replicated instead of being split by dp.
        self.mul = P.Mul().shard(((dp, 1, 1), (1,)))
        self.cast = P.Cast()

        if self.routing_policy is None:
            self.router = SwitchRouter(d_model=d_model, moe_config=moe_config, training=training,
                                       parallel_config=parallel_config)
        else:
            self.router = routing_policy

    def construct(self, input_tensor):
        input_tensor = self.cast(input_tensor, mstype.float32)
        if self.noisy_policy == "jitter" and self.training is True:
            # Here, we temporarily implement the multiplicative jitter this way,
            # for the lack of a UniformReal parallel operator.
            input_tensor = self.mul(input_tensor, self.noise)

        router_logits = self.dense(input_tensor)
        return self.router(router_logits)


class SwitchRouter(Cell):
    r"""
    A router implementation which maps each token to its top-1 expert.
    Reference: https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py

    Args:
        d_model (int): The hidden size of each token.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Experts).
        training (bool): The value indicating whether is in training phase.
        parallel_config: The parallel-related configuration. Default: None, which uses `default_dpmp_config`.

    Inputs:
        - **router_logits** (Tensor) - Router logits whose last dimension is `expert_dim`; they are reshaped to
          :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.

    Outputs:
        - **dispatch_tensor** (Tensor) - bool Tensor of shape
          :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`.
        - **combine_tensor** (Tensor) - Tensor of shape
          :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`.
        - **loss** (Tensor) - scalar load-balance loss.
    """

    def __init__(self,
                 d_model,
                 moe_config,
                 training=True,
                 parallel_config=None):
        super(SwitchRouter, self).__init__()
        if parallel_config is None:
            parallel_config = default_dpmp_config
        dp = parallel_config.data_parallel
        self.d_model = d_model
        self.expert_dim = moe_config.expert_num
        self.capacity_factor = moe_config.capacity_factor
        self.training = training
        self.expert_parallel = dp
        self.noisy_policy = moe_config.noisy_policy
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.softmax = P.Softmax(axis=-1).shard(((dp, 1, 1,),))
        self.argmax = P.ArgMaxWithValue(axis=-1, keep_dims=False).shard(((dp, 1, 1),))

        self.onehot = P.OneHot().shard(((dp, 1, 1), (), ()))
        self.onehot2 = P.OneHot().shard(((dp, 1, 1), (), ()))
        self.onehot3 = P.OneHot().shard(((dp, 1, 1, 1), (), ()))
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)

        self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
        self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
        self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),))
        self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul2 = P.Mul().shard(((1,), ()))
        self.mul3 = P.Mul().shard(((1,), ()))
        self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul6 = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul7 = P.Mul().shard(((dp, 1), (dp, 1)))
        self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))

        self.cumsum = _CumSum(config=parallel_config)
        self.less = P.Less().shard(((dp, 1, 1), ()))
        self.reduce_sum = P.ReduceSum(keep_dims=False).shard(((dp, 1, 1),))
        self.expand = P.ExpandDims().shard(((dp, 1),))
        self.expand2 = P.ExpandDims().shard(((dp, 1, 1),))

    def _auxiliary_loss(self, expert_mask, router_prob):
        """
        Compute the load-balance loss: mean(density * density_proxy) * expert_dim ** 2.
        """
        # density_1's shape: (expert_parallel, self.expert_dim)
        density_1 = self.reduce_mean(expert_mask, 1)
        # density_1_proxy's shape: (expert_parallel, self.expert_dim)
        density_1_proxy = self.reduce_mean2(router_prob, 1)
        loss = self.mul(density_1, density_1_proxy)
        loss = self.reduce_mean3(loss)
        loss = self.mul3(self.mul2(loss, self.expert_dim), self.expert_dim)
        return loss

    def _maskout_overflowed_tokens(self, expert_mask, expert_capacity, expert_gate):
        """
        Keep only the tokens whose position within their expert is below expert_capacity.
        """
        cumsum = self.cumsum(expert_mask)
        # position_in_expert's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        position_in_expert = self.mul4(cumsum, expert_mask)
        less_result = self.less(position_in_expert, expert_capacity)
        # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        expert_mask = self.mul5(less_result, expert_mask)
        # expert_mask_flat's shape: (expert_parallel, tokens_per_device)
        expert_mask_flat = self.reduce_sum(expert_mask, -1)

        # Mask out the experts that have overflowed the expert_capacity.
        # expert_gate's shape: (expert_parallel, tokens_per_device)
        expert_gate = self.mul6(expert_gate, expert_mask_flat)
        return expert_gate, expert_mask_flat, position_in_expert

    def construct(self, router_logits):
        router_logits_shape = self.shape(router_logits)
        router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
        logits_shape = self.shape(router_logits)
        # Floor division keeps this shape entry an int for the Reshape below.
        tokens_per_device = logits_shape[0] // self.expert_parallel
        expert_capacity = calculate_expert_capacity(1, tokens_per_device, self.capacity_factor, self.expert_dim)
        router_logits = self.reshape(router_logits, (self.expert_parallel, tokens_per_device, self.expert_dim))
        # Currently, lack of gumbel sampler for router_logits.

        # Probability of each token being sent to each expert.
        router_prob = self.softmax(router_logits)
        # shape is : (expert_parallel, tokens_per_device)
        expert_index, expert_gate = self.argmax(router_prob)
        # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)

        # Computing the load balance loss:
        loss = self._auxiliary_loss(expert_mask, router_prob)

        expert_gate, expert_mask_flat, position_in_expert = \
            self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate)

        # combine_tensor's shape: (expert_parallel, tokens_per_device)
        combine_tensor = self.mul7(expert_gate, expert_mask_flat)
        # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim)
        combine_tensor = self.mul8(self.expand(combine_tensor, -1),
                                   self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
        # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim, expert_capacity)
        combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
                                   self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
                                                self.on_value, self.off_value))
        dispatch_tensor = self.cast(combine_tensor, mstype.bool_)
        return dispatch_tensor, combine_tensor, loss