# Source code for mindformers.models.llama.llama

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LLaMA models' APIs."""
import copy
from multiprocessing.managers import DictProxy
from multiprocessing.synchronize import Condition

import numpy as np

import mindspore.common.dtype as mstype
from mindspore import Tensor, nn, mint
from mindspore.context import ParallelMode
from mindspore.ops import operations as P
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation

from mindformers.core.loss.loss import CrossEntropyLoss
from mindformers.mindformer_book import MindFormerBook
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.models.utils import LayerSetting, check_fine_grain_interleave_valid
from mindformers.modules.layers import Linear, FreqsMgr
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
from mindformers.tools.utils import get_predict_run_mode

from .llama_config import LlamaConfig
from .llama_layer import LlamaEmbedding, LlamaRMSNorm
from .llama_transformer import LLamaDecodeLayer
from .llama_interleave import LLamaDecodeLayerInterleave
from ..utils import lazy_inline
from ...tools.logger import logger

__all__ = ['LlamaModel', 'LlamaForCausalLM']


class LlamaPreTrainedModel(PreTrainedModel):
    """
    Common base class of the LLaMA networks in this module.

    It only binds the LLaMA-specific class attributes; everything else (weight initialization, downloading and
    loading of pretrained models) is inherited from `PreTrainedModel`.
    """

    # Prefix naming the base model inside derived networks.
    base_model_prefix = "llama"
    # Config class associated with this model family.
    config_class = LlamaConfig


class LlamaModel(LlamaPreTrainedModel):
    r"""
    LLaMA transformer decoder: token embedding, *config.num_layers* decoder layers and a final RMSNorm.

    Every layer is a `LLamaDecodeLayer`, or a `LLamaDecodeLayerInterleave` when
    `check_fine_grain_interleave_valid` accepts the fine-grain interleave setting of the config.

    Args:
        config(LlamaConfig): the config of network

    Returns:
            output: Tensor, the output of llama decoderlayer

    Examples:
        >>> from mindformers import LlamaModel
        >>> network = LlamaModel.from_pretrained('llama_7b')
        >>> type(network)
        <class 'mindformers.models.llama.llama.LlamaModel'>
    """
    _support_list = MindFormerBook.get_model_support_list()['llama']

    def __init__(self,
                 config: LlamaConfig = None):
        super().__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)

        # Plain hyper-parameters mirrored from the config.
        self.dtype = config.compute_dtype
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.n_head = config.num_heads
        self.head_dim = self.hidden_size // self.n_head
        self.pad_token_id = config.pad_token_id
        self.is_first_iteration = True
        self.use_past = config.use_past
        self.use_flash_attention = config.use_flash_attention
        self.use_ring_attention = config.use_ring_attention
        self.parallel_decoding = config.parallel_decoding_params is not None

        # Primitive operators used by construct().
        self.concat = P.Concat(-1)
        self.cast = P.Cast()
        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.rmsnorm_compute_2d = config.rmsnorm_compute_2d

        ffn_kind_message = ("MoE config is provided, use MoE FFN" if config.moe_config.expert_num > 1
                            else "MoE config is None, use normal FFN")
        logger.info(ffn_kind_message)

        # Ring attention is only accepted together with flash attention.
        if self.use_ring_attention and not self.use_flash_attention:
            raise ValueError("When the ring_attention = True, the flash_attention must be True ")

        self.freqs_mgr = FreqsMgr(head_dim=self.head_dim,
                                  seq_length=config.seq_length,
                                  max_position_embedding=config.max_position_embedding,
                                  rotary_dtype=config.rotary_dtype,
                                  theta=config.theta,
                                  scaling_factor=config.scaling_factor,
                                  extend_method=config.extend_method,
                                  parallel_config=config.parallel_config,
                                  is_dynamic=config.is_dynamic)

        self.residual_cast_flag = config.residual_dtype != self.dtype
        if self.residual_cast_flag:
            logger.info(f"residual in llama model cast flag: {self.residual_cast_flag}, "
                        f"residual dtype: {config.residual_dtype}")

        self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length,
                                                          compute_type=config.compute_dtype,
                                                          is_dynamic=config.is_dynamic,
                                                          pad_token_id=config.pad_token_id,
                                                          use_flash_attention=config.use_flash_attention,
                                                          use_attn_mask_compression=config.use_attn_mask_compression,
                                                          use_past=config.use_past)
        self.tok_embeddings = LlamaEmbedding(vocab_table_size=config.vocab_size,
                                             embedding_size=config.hidden_size,
                                             init_method_std=config.init_method_std,
                                             param_init_type=config.embedding_init_type,
                                             parallel_optimizer=config.parallel_optimizer,
                                             rmsnorm_compute_2d=config.rmsnorm_compute_2d)
        self.fine_grain_interleave = check_fine_grain_interleave_valid(config.fine_grain_interleave,
                                                                       config.parallel_config)

        self.layers = nn.CellList()
        self.layer_setting = LayerSetting(config.num_layers,
                                          config.offset,
                                          config.parallel_config,
                                          config.pp_interleave_num)
        for layer_id in range(config.num_layers):
            # Keyword arguments accepted by both decoder-layer implementations.
            shared_layer_kwargs = {
                "dim": config.hidden_size,
                "n_heads": config.num_heads,
                "n_kv_heads": config.n_kv_heads,
                "intermediate_size": config.intermediate_size,
                "multiple_of": config.multiple_of,
                "ffn_dim_multiplier": config.ffn_dim_multiplier,
                "norm_eps": config.rms_norm_eps,
                "qkv_has_bias": config.qkv_has_bias,
                "attn_proj_has_bias": config.attn_proj_has_bias,
                "qkv_concat": config.qkv_concat,
                "compute_dtype": config.compute_dtype,
                "layernorm_compute_dtype": config.layernorm_compute_type,
                "softmax_compute_dtype": config.softmax_compute_type,
                "rotary_dtype": config.rotary_dtype,
                "param_init_type": config.param_init_type,
                "residual_dtype": config.residual_dtype,
                "use_flash_attention": config.use_flash_attention,
                "use_ring_attention": config.use_ring_attention,
                "use_attn_mask_compression": config.use_attn_mask_compression,
                "parallel_config": config.parallel_config,
                "init_method_std": config.init_method_std,
            }
            if self.fine_grain_interleave:
                layer = LLamaDecodeLayerInterleave(config.batch_size,
                                                   config.seq_length,
                                                   layer_id,
                                                   num_layers=config.num_layers,
                                                   fine_grain_interleave=config.fine_grain_interleave,
                                                   **shared_layer_kwargs)
            else:
                layer = LLamaDecodeLayer(config.seq_length,
                                         layer_id,
                                         use_past=config.use_past,
                                         is_dynamic=config.is_dynamic,
                                         block_size=config.block_size,
                                         num_blocks=config.num_blocks,
                                         use_rope_slice=config.use_rope_slice,
                                         rmsnorm_compute_2d=config.rmsnorm_compute_2d,
                                         moe_config=config.moe_config,
                                         parallel_decoding=self.parallel_decoding,
                                         fused_kernel=config.fused_rms_norm,
                                         **shared_layer_kwargs)
            self.layer_setting(layer, layer_id)
            self.layers.append(layer)

        self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps,
                                     compute_type=config.layernorm_compute_type,
                                     fused_kernel=config.fused_rms_norm)

        dp = config.parallel_config.data_parallel
        cp = config.parallel_config.context_parallel

        # Embedding sits on the first pipeline stage, the final norm on the last one.
        self.tok_embeddings.pipeline_stage = 0
        if config.parallel_config.pipeline_stage > 1:
            self.norm_out.pipeline_stage = config.parallel_config.pipeline_stage - 1
            fusion_id = 2
        else:
            fusion_id = config.parallel_config.gradient_aggregation_group
        self.tok_embeddings.set_comm_fusion(fusion_id)
        self.norm_out.set_comm_fusion(fusion_id)

        # Both parallel modes shard the same cells; only the explicit Concat strategy is skipped
        # when auto-parallel sharding propagation is active.
        sharding_propagation = (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,)
                                and _is_sharding_propagation())
        self.tok_embeddings.shard(config.parallel_config)
        self.casual_mask.shard(config.parallel_config)
        if not sharding_propagation:
            self.concat.shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
        if self.fine_grain_interleave or config.rmsnorm_compute_2d:
            self.norm_out.shard((dp * cp, 1))
        else:
            self.norm_out.shard((dp, cp, 1))

    # pylint: disable=W0613
    def construct(self, tokens: Tensor, input_embeds=None, batch_valid_length=None, batch_index=None,
                  zactivate_len=None, block_tables=None, slot_mapping=None, prefix_keys_values=None,
                  attention_mask=None, position_ids=None, q_seq_lens=None):
        """
        Forward of llama model.

        Args:
            tokens: the tokenized inputs with datatype int32
            input_embeds: the embedding Tensor of tokens, Tensor of shape:math:`(batch_size, seq/_length, hidden_size)`.
                When given it is used instead of the embedding lookup of `tokens`. Default None.
            batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
                prediction. Tensor of shape :math:`(batch_size,)`. Default None.
            batch_index: not read by this method. Default None.
            zactivate_len: not read by this method. Default None.
            block_tables (Tensor[int64]): Store mapping tables for each sequence.
            slot_mapping (Tensor[int32]): Store token cache physical slot index.
            prefix_keys_values: per-layer prefix key/values; entry `i` is handed to layer `i`, and the mask is
                extended by `prefix_keys_values[0].shape[2]` zero columns. Default None.
            attention_mask: used as the attention mask in parallel decoding mode only. Default None.
            position_ids: used to build the rotary frequencies in parallel decoding mode only. Default None.
            q_seq_lens: forwarded unchanged to every decoder layer. Default None.
        Returns:
            output: Tensor, the output of llama decoderlayer
        """
        # preprocess
        batch, seq = self.shape(tokens)
        norm_in_2d = self.training and self.rmsnorm_compute_2d
        if self.parallel_decoding:
            # FA with TH layout, mask is 2D, FA with BSH layout, mask is 4D
            mask = attention_mask
            freqs_cis = self.freqs_mgr.increment_multi_ids(position_ids)
        else:
            mask = None
            if self.use_past:
                if self.is_first_iteration:
                    # Prefill step of incremental inference.
                    freqs_cis = self.freqs_mgr.prefill(batch, seq)
                    mask = self.casual_mask.prefill()
                    if prefix_keys_values is not None:
                        if mask is None:
                            mask = self.casual_mask(tokens)
                        prefix_len = prefix_keys_values[0].shape[2]
                        zeros_for_prefix = Tensor(np.zeros((batch, 1, seq, prefix_len)), dtype=mask.dtype)
                        mask = self.concat((zeros_for_prefix, mask))
                else:
                    # Decode step: no mask is built here.
                    freqs_cis = self.freqs_mgr.increment(batch_valid_length)
            else:
                if not self.use_ring_attention:
                    mask = self.casual_mask(tokens)
                freqs_cis = self.freqs_mgr(seq)
                if prefix_keys_values is not None:
                    # NOTE(review): with use_ring_attention the mask is still None at this point, so
                    # `mask.dtype` would fail - confirm that ring attention + prefix KV is unsupported.
                    prefix_len = prefix_keys_values[0].shape[2]
                    zeros_for_prefix = Tensor(np.zeros((batch, 1, seq, prefix_len)), dtype=mask.dtype)
                    mask = self.concat((zeros_for_prefix, mask))

        # tokens: [bs, seq/1]
        if input_embeds is not None:
            hidden = self.cast(input_embeds, self.dtype)
        else:
            hidden = self.cast(self.tok_embeddings(tokens), self.dtype)
        if not norm_in_2d:
            hidden = self.reshape(hidden, (batch, seq, self.hidden_size))    # hidden: [bs, seq/1, hidden_dim]
        for layer_idx in range(self.num_layers):
            layer_prefix_kv = prefix_keys_values[layer_idx] if prefix_keys_values is not None else None
            hidden = self.layers[layer_idx](hidden, freqs_cis, mask, batch_valid_length=batch_valid_length,
                                            block_tables=block_tables, slot_mapping=slot_mapping,
                                            prefix_keys_values=layer_prefix_kv, q_seq_lens=q_seq_lens)
        if norm_in_2d:
            # Flatten batch and sequence before the final norm in the 2D rmsnorm training path.
            hidden = self.reshape(hidden, (batch * seq, -1))
        output = self.norm_out(hidden)
        return output


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlamaForCausalLM(LlamaPreTrainedModel):
    r"""
    Provide llama training loss or logits through network.

    Args:
        config (LlamaConfig): The config of llama model. Default: `None` .

    Inputs:
        - **input_ids** (Tensor) - the indices of input sequence tokens in the vocabulary with data type Int64/Int32,
          Tensor of shape :math:`(batch, seq\_length)`.
        - **labels** (Tensor, optional) - the labels of inputs with data type Int64/Int32, Tensor of
          shape :math:`(batch, seq\_length)` . Default: ``None`` .
        - **input_position** (Tensor, optional) - the position ids of inputs (at incremental reasoning mode) which is
          an increasing sequence with data type Int64/Int32, Tensor :math:`(batch, seq\_length)`.
          Default: ``None`` .
        - **position_ids** (Tensor, optional) - the position ids of inputs which is
          an increasing sequence with data type
          Int64/Int32, Tensor :math:`(batch, seq\_length)`. Default: ``None`` .
        - **attention_mask** (Tensor, optional) - input sentences padding mask, where 0 indicates padding position with
          data type Int64/Int32, Tensor of shape :math:`(batch, seq\_length)`. Default: ``None`` .
        - **input_embeds** (Tensor, optional) - the embedding of inputs with data type Float32/Float16, Tensor of
          shape :math:`(batch, seq\_length, hidden\_size)`. Default: ``None`` .
        - **init_reset** (Tensor, optional) - A Bool tensor with shape [1], used to clear the past key parameter and
          past value parameter used in the incremental prediction. Only valid when use_past is True.
          Tensor of shape :math:`(1)`. Default: ``Tensor([True])`` .
        - **batch_valid_length** (Tensor, optional) - Int32 tensor with shape [batch_size]
          the past calculated the index.
          Used for incremental prediction when the use_past is True. Default: ``None`` .
        - **block_tables** (Tensor, optional) - Int64 type Tensor, Store mapping tables for each sequence.
          Default: ``None`` .
        - **slot_mapping** (Tensor, optional) - Int32 type Tensor, token cache physical slot index.
          Default:``None`` .

    Outputs:
        Tensor. If it is in training mode, the output Tensor contains loss;
        If it is in prediction mode, the output Tensor contains logits;
        If it is in evaluation mode, the output Tensor contains logits, tokens, and input masks.

    Examples:
        >>> from mindformers.models.llama import LlamaConfig, LlamaForCausalLM
        >>> import mindspore as ms
        >>> ms.set_context(mode=0)
        >>> config = LlamaConfig(batch_size=2)
        >>> network = LlamaForCausalLM(config=config)
        >>> type(network)
        <class 'mindformers.models.llama.llama.LlamaForCausalLM'>
        >>> from mindformers import LlamaForCausalLM
        >>> network = LlamaForCausalLM.from_pretrained('llama_7b')
        >>> type(network)
        <class 'mindformers.models.llama.llama.LlamaForCausalLM'>
    """
    _support_list = MindFormerBook.get_model_support_list()['llama']

    @lazy_inline
    def __init__(self, config: LlamaConfig = None):
        super(LlamaForCausalLM, self).__init__(config, auto_prefix=True)
        _check_config(config.parallel_config)
        self.config = config
        self.ignore_token_id = config.ignore_token_id
        self.pad_token_id = config.pad_token_id
        self.use_past = config.use_past
        self.vocab_size = config.vocab_size
        self.is_first_iteration = True

        self.shape = P.Shape()
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        self.slice = P.StridedSlice()
        self.not_equal = P.NotEqual()
        self.mul = P.Mul()
        self.add = P.Add()
        self.ones = P.Ones()
        self.gather = P.Gather(1)
        self.prefill_gather_flatten = P.Gather()
        self.sub_batch_valid_len = P.Sub()
        self.model = LlamaModel(config=config)
        self.lm_head = Linear(in_channels=config.hidden_size,
                              out_channels=config.vocab_size,
                              has_bias=False,
                              compute_dtype=config.compute_dtype,
                              param_init_type=config.param_init_type,
                              weight_init="normal")  # meta default: xavier_normal
        if config.tie_word_embeddings:
            # Share the output projection weight with the input embedding table.
            self.lm_head.weight = self.model.tok_embeddings.embedding_weight

        mp = config.parallel_config.model_parallel
        vocab_size = config.vocab_size
        # The loss gets its own copy so the adjustments below do not touch the model's parallel config.
        loss_parallel_config = copy.deepcopy(config.parallel_config)
        if vocab_size % mp != 0:
            logger.warning("The vocab size of Loss is: %s, it is not divide by model_parallel: %s",
                           vocab_size, mp)
            logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1")
            loss_parallel_config.model_parallel = 1
        # NOTE(review): nesting reconstructed from a whitespace-collapsed source; this line is assumed to
        # run unconditionally (not only when the vocab size is indivisible) - confirm against upstream.
        loss_parallel_config.data_parallel *= loss_parallel_config.context_parallel
        check_for_nan_in_loss_and_grad = getattr(config, "check_for_nan_in_loss_and_grad", False)
        calculate_per_token_loss = getattr(config, "calculate_per_token_loss", False)
        self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config,
                                     check_for_nan_in_loss_and_grad=check_for_nan_in_loss_and_grad,
                                     calculate_per_token_loss=calculate_per_token_loss)

        dp = config.parallel_config.data_parallel
        mp = config.parallel_config.model_parallel
        cp = config.parallel_config.context_parallel
        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
            self.slice.shard(((dp, 1),))
            self.not_equal.shard(((dp, 1), ()))
            if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0):
                self.lm_head.shard(strategy_matmul=((dp * cp, 1), (1, 1)))
            else:
                self.lm_head.shard(strategy_matmul=((dp * cp, 1), (mp, 1)))
        else:
            self.slice.shard(((dp, 1),))
            self.not_equal.shard(((dp, 1), ()))
            self.mul.shard(((dp, 1), (dp, 1)))
            self.add.shard(((dp, 1), ()))
            self.gather.shard(((dp, 1, 1), (dp,)))
            self.prefill_gather_flatten.shard(((dp, 1, 1), (dp,)))
            self.sub_batch_valid_len.shard(((1,), ()))
            if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0):
                self.lm_head.shard(strategy_matmul=((dp * cp, 1), (1, 1)))
            else:
                self.lm_head.shard(strategy_matmul=((dp * cp, 1), (mp, 1)))
            # NOTE(review): nesting reconstructed from a whitespace-collapsed source; this check is assumed
            # to belong to the non-sharding-propagation branch - confirm against upstream.
            if config.parallel_config.pipeline_stage > 1:
                self.lm_head.pipeline_stage = config.parallel_config.pipeline_stage - 1

        self.load_checkpoint(config)
        self.predict_run_mode = get_predict_run_mode()

        logger.info("Predict run mode:{}".format(self.predict_run_mode))
        self.parallel_decoding = config.parallel_decoding_params is not None
        self.input_sliced_sig = config.input_sliced_sig

    def to_embeddings(self, tokens):
        """return embedding tokens"""
        return self.model.tok_embeddings(tokens)

    # pylint: disable=W0613
    def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
        """Get Llama model input tuple for transform ckpt.

        The returned tuple follows the positional order of `construct` up to `prefix_keys_values`;
        every slot other than `input_ids`, `labels`, `slot_mapping` and `prefix_keys_values` is None.
        """
        input_ids = Tensor(input_ids, mstype.int32)
        labels = Tensor(kwargs["labels"]) if "labels" in kwargs else None
        bs, seq = input_ids.shape[0], input_ids.shape[1]
        slot_mapping = Tensor(np.ones(shape=tuple([bs * seq])), mstype.int32)
        prefix_keys_values = Tensor(kwargs["prefix_keys_values"]) if "prefix_keys_values" in kwargs else None
        return input_ids, labels, None, None, None, None, None, None, None, None, None, slot_mapping, \
            prefix_keys_values

    def set_dynamic_inputs(self, **kwargs):
        """Register dynamic-shape placeholder inputs for `construct` through `set_inputs`.

        Args:
            **kwargs: optional flags.
                have_prefix_keys_values (bool): also register a dynamic `prefix_keys_values` placeholder.
                pre_gather (bool): when neither the prefix flag nor `use_past` applies, register a dynamic
                    `batch_valid_length` in addition to `input_ids`.
        """
        dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
        dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
        # kwargs is a dict: it must be queried with .get(); getattr() on it never finds the flag.
        have_prefix_keys_values = kwargs.get("have_prefix_keys_values", False)
        # The three placeholders below are only registered in parallel decoding mode.
        dynamic_position_ids = Tensor(shape=[None, None], dtype=mstype.int32) if self.parallel_decoding else None
        dynamic_mask = Tensor(shape=[None, None], dtype=mstype.float16) if self.parallel_decoding else None
        dynamic_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) if self.parallel_decoding else None
        if have_prefix_keys_values:
            dynamic_prefix_keys_values = Tensor(shape=[2, None, None, None, None], dtype=mstype.float16)
            self.set_inputs(dynamic_input_ids, None, None, dynamic_position_ids, dynamic_mask, None, None,
                            dynamic_batch_valid_length, None, None, dynamic_block_tables,
                            dynamic_slot_mapping, dynamic_prefix_keys_values, None, dynamic_q_seq_lens)
        elif self.use_past:
            self.set_inputs(dynamic_input_ids, None, None, dynamic_position_ids, dynamic_mask, None, None,
                            dynamic_batch_valid_length, None, None, dynamic_block_tables,
                            dynamic_slot_mapping, None, None, dynamic_q_seq_lens)
        elif kwargs.get("pre_gather", False):
            self.set_inputs(dynamic_input_ids, None, None, None, None, None, None, dynamic_batch_valid_length,
                            None, None, None, None, None)
        else:
            self.set_inputs(dynamic_input_ids, None, None, None, None, None, None, None,
                            None, None, None, None, None)
        logger.info("Set dynamic input for llama.")

    def add_flags_custom(self, is_first_iteration):
        """Add customized attributes for specific cells in the model."""
        self.add_flags(is_first_iteration=is_first_iteration)
        self.model.add_flags(is_first_iteration=is_first_iteration)
        for layer in self.model.layers:
            layer.add_flags(is_first_iteration=is_first_iteration)
            layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration)
            layer.attention.infer_attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)

    def pre_gather_func(self, pre_gather, output, batch_valid_length):
        """Pre gather operation in infer mode.

        When `pre_gather` is true, only the hidden states at index `batch_valid_length - 1` are kept;
        otherwise `output` is returned unchanged.
        """
        if not pre_gather:
            return output
        if self.parallel_decoding and self.is_first_iteration:
            output = output.reshape(-1, output.shape[-1])
            output = output[self.sub_batch_valid_len(batch_valid_length, 1)]
        elif pre_gather:
            if self.config.is_dynamic:
                # Cumulative lengths are used as gather indices in the dynamic-shape path.
                batch_valid_length = mint.cumsum(batch_valid_length, 0)
                output = self.prefill_gather_flatten(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
            else:
                output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
        return output

    # pylint: disable=W0613
    def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
                  input_embeds=None, init_reset=None, batch_valid_length=None, batch_index=None, zactivate_len=None,
                  block_tables=None, slot_mapping=None, prefix_keys_values=None, llm_boost_inputs=None,
                  q_seq_lens=None):
        r"""
        LlamaForCausalLM forward.

        Args:
            input_ids(Tensor): the tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
            labels(Tensor): the tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`.
            input_position(Tensor): current position, used by model.predict.
            position_ids(Tensor): forwarded to the backbone, which reads it in parallel decoding mode.
            attention_mask(Tensor): forwarded to the backbone, which reads it in parallel decoding mode.
            input_embeds(Tensor): the input embedding Tensor of shape :math:`(batch, seq\_length, hidden_size)`.
                Default None.
            init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and
                past value parameter used in the incremental prediction. Default True.
            batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
                prediction. Tensor of shape :math:`(batch_size,)`. Default None.
            block_tables (Tensor[int64]): Store mapping tables for each sequence.
            slot_mapping (Tensor[int32]): Store token cache physical slot index.
            q_seq_lens (Tensor[int32]): In parallel decoding, the query may be flattened.
                The Paged Attention operator need `q_seq_lens` to obtain the length information.
        Returns:
            Tensor, The loss or (logits, tokens, input_mask) of the network.
        """
        input_sliced_sig = self.input_sliced_sig
        # Without separate labels the labels are sliced from input_ids, so inputs cannot count as pre-sliced.
        if self.training and input_sliced_sig and labels is None:
            input_sliced_sig = False

        bsz, seqlen = self.shape(input_ids)
        if self.use_past:
            if not isinstance(batch_valid_length, Tensor):
                batch_valid_length = self.ones((bsz,), mstype.int32)
        if not input_sliced_sig and self.training:
            # Drop the last position: the model input is the sequence shifted against its labels.
            tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1))
        else:
            tokens = input_ids
        if batch_valid_length is not None:
            batch_valid_length = self.reshape(batch_valid_length, (-1,))
        output = self.model(tokens, input_embeds, batch_valid_length, batch_index, zactivate_len, block_tables,
                            slot_mapping, prefix_keys_values, attention_mask, position_ids, q_seq_lens)
        pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
        output = self.pre_gather_func(pre_gather, output, batch_valid_length)
        logits = self.lm_head(output)
        input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)
        if labels is None:
            labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1))
        else:
            if labels.ndim > 1:
                if not input_sliced_sig and self.training:
                    labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1))
                label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32)
                input_mask = self.mul(input_mask, label_mask)

        if not self.training:
            logits = self.cast(logits, mstype.float32)
            if self.predict_run_mode:
                logits = self.reshape(logits, (-1, logits.shape[-1]))
                return logits
            return logits, tokens, input_mask

        if logits.ndim > 2:
            logits = self.reshape(logits, (-1, logits.shape[-1]))
        logits = self.cast(logits, mstype.float32)
        labels = self.reshape(labels, (-1,))
        input_mask = self.reshape(input_mask, (-1,))
        loss = self.loss(logits, labels, input_mask)
        return loss

    def kvcache(self, layer_idx):
        """Return the (key_cache, value_cache) pair held by the paged attention manager of layer `layer_idx`."""
        paged_attention_mgr = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr
        key_cache = paged_attention_mgr.key_cache
        value_cache = paged_attention_mgr.value_cache
        return key_cache, value_cache

    @classmethod
    def convert_name(cls, weight_name):
        """convert HuggingFace weight name to MindFormers weight name"""
        origin_name = weight_name
        # Applied in order: the last rule relies on 'embed_tokens.' already being renamed.
        replacements = (
            ('embed_tokens.', 'tok_embeddings.'),
            ('.self_attn.q_proj.', '.attention.wq.'),
            ('.self_attn.k_proj.', '.attention.wk.'),
            ('.self_attn.v_proj.', '.attention.wv.'),
            ('.self_attn.o_proj.', '.attention.wo.'),
            ('.mlp.gate_proj.', '.feed_forward.w1.'),
            ('.mlp.down_proj.', '.feed_forward.w2.'),
            ('.mlp.up_proj.', '.feed_forward.w3.'),
            ('.input_layernorm.', '.attention_norm.'),
            ('.post_attention_layernorm.', '.ffn_norm.'),
            ('.norm.', '.norm_out.'),
            ('output.', 'lm_head.'),
            ('.tok_embeddings.weight', '.tok_embeddings.embedding_weight'),
        )
        for old, new in replacements:
            weight_name = weight_name.replace(old, new)
        if weight_name == origin_name:
            logger.warning(f"weight name '{weight_name}' does not change after conversion. "
                           f"Please check if it is as expected.")
        return weight_name

    @classmethod
    def convert_weight_dict(cls, source_dict, **kwargs):
        """convert HuggingFace weight dict to MindFormers weight dict

        Args:
            source_dict (dict): weight name -> weight value, with HuggingFace names.
            **kwargs:
                qkv_concat (bool): fuse wq/wk/wv into `w_qkv` and w1/w3 into `w_gate_hidden`. Default False.
                qkv_dict (DictProxy): shared dict used to exchange weights whose concat partner is held by
                    another conversion process. Required when `qkv_concat` is True.
                condition (Condition): multiprocessing condition guarding `qkv_dict`. Required when
                    `qkv_concat` is True.

        Returns:
            dict, weight dict with MindFormers names.

        Raises:
            ValueError: if `qkv_concat` is True and `qkv_dict` / `condition` have the wrong type.
        """
        qkv_concat = kwargs.get("qkv_concat", False)
        target_dict = {}
        wq_keys = []
        wk_keys = []
        wv_keys = []
        w1_keys = []
        w3_keys = []
        # The second-to-last name component selects the list a key is recorded in.
        keys_by_tag = {'wq': wq_keys, 'wk': wk_keys, 'wv': wv_keys, 'w1': w1_keys, 'w3': w3_keys}

        for k, v in source_dict.items():
            k = cls.convert_name(k)
            target_dict.update({k: v})
            if qkv_concat:
                part = k.split('.')
                # Names without a '.' have no second-to-last component and cannot be qkv/ffn weights.
                if len(part) < 2:
                    continue
                tagged_keys = keys_by_tag.get(part[-2])
                if tagged_keys is not None:
                    tagged_keys.append(k)

        if qkv_concat:
            qkv_dict = kwargs.get('qkv_dict', None)
            if not isinstance(qkv_dict, DictProxy):
                raise ValueError(f'qkv_dict must be a DictProxy, when qkv_concat is True, but got {qkv_dict}.')
            condition = kwargs.get('condition', None)
            if not isinstance(condition, Condition):
                raise ValueError(f'condition must be a Condition, when qkv_concat is True, but got {condition}.')
            _concat_qkv_weight(wq_keys, wk_keys, wv_keys, w1_keys, w3_keys, qkv_dict, condition, target_dict)

        return target_dict

    @classmethod
    def convert_map_dict(cls, source_dict, **kwargs):
        """convert HuggingFace map dict to MindFormers map dict

        With `qkv_concat=True` the wq/wk/wv entries of a layer collapse into one `w_qkv` entry carrying the
        wq value, and w1/w3 collapse into one `w_gate_hidden` entry carrying the w1 value.
        """
        qkv_concat = kwargs.pop("qkv_concat", False)
        target_dict = {}
        wq_keys = []
        w1_keys = []

        for k, v in source_dict.items():
            k = cls.convert_name(k)
            target_dict.update({k: v})
            if qkv_concat:
                part = k.split('.')
                # Names without a '.' have no second-to-last component and cannot be qkv/ffn weights.
                if len(part) < 2:
                    continue
                if part[-2] == 'wq':
                    wq_keys.append(k)
                if part[-2] == 'w1':
                    w1_keys.append(k)

        if qkv_concat:
            for wq_key in wq_keys:
                wq_value = target_dict.pop(wq_key)
                target_dict.pop(wq_key.replace('wq', 'wk'))
                target_dict.pop(wq_key.replace('wq', 'wv'))
                target_dict[wq_key.replace('wq', 'w_qkv')] = wq_value
            for w1_key in w1_keys:
                w1_value = target_dict.pop(w1_key)
                target_dict.pop(w1_key.replace('w1', 'w3'))
                target_dict[w1_key.replace('w1', 'w_gate_hidden')] = w1_value

        return target_dict

    @classmethod
    def obtain_qkv_ffn_concat_keys(cls):
        """Return the name fragments of the fused qkv and ffn weights: ['w_qkv', 'w_gate_hidden']."""
        qkv_key = "w_qkv"
        ffn_key = "w_gate_hidden"
        concat_keys = [qkv_key, ffn_key]
        logger.info(f"{cls.__name__} qkv/ffn concat keys are {concat_keys}")
        return concat_keys
def _publish_unpaired_weights(keys, own_tag, anchor_tag, anchor_keys, qkv_dict, condition, target_dict):
    """Move every weight in `keys` whose concat anchor (wq / w1) is not held locally into the shared dict."""
    for key in keys:
        if key.replace(own_tag, anchor_tag) not in anchor_keys:
            with condition:
                qkv_dict[key] = target_dict.pop(key)  # add extra weight to shared dict
                condition.notify_all()


def _pop_local_or_shared(key, qkv_dict, condition, target_dict):
    """Pop `key` from the local dict; if absent, block until it shows up in the shared dict and pop it there."""
    value = target_dict.pop(key, None)
    if value is None:
        with condition:
            # Membership is tested on the dict itself so the full key list is not fetched on every poll.
            condition.wait_for(lambda: key in qkv_dict)
            value = qkv_dict.pop(key)
    return value


def _concat_qkv_weight(wq_keys, wk_keys, wv_keys, w1_keys, w3_keys, qkv_dict, condition, target_dict):
    """concat qkv weight and ffn weight from dicts

    `target_dict` is modified in place: for every key in `wq_keys` the wq/wk/wv weights are replaced by one
    `w_qkv` entry (concatenated along axis 0), and for every key in `w1_keys` the w1/w3 weights are replaced
    by one `w_gate_hidden` entry.

    Args:
        wq_keys (list[str]): names of the wq weights present in `target_dict`.
        wk_keys (list[str]): names of the wk weights present in `target_dict`.
        wv_keys (list[str]): names of the wv weights present in `target_dict`.
        w1_keys (list[str]): names of the w1 weights present in `target_dict`.
        w3_keys (list[str]): names of the w3 weights present in `target_dict`.
        qkv_dict: shared mapping used to hand over weights whose concat partner is not in `target_dict`.
        condition: condition object guarding `qkv_dict`; must support `with`, `wait_for` and `notify_all`.
        target_dict (dict): weight name -> weight array.
    """
    # Sets give O(1) membership tests; the lists may hold one entry per layer.
    local_wq = set(wq_keys)
    local_w1 = set(w1_keys)

    # pop extra weight to shared dict if there is no corresponding weight for concat in the target dict
    _publish_unpaired_weights(wk_keys, 'wk', 'wq', local_wq, qkv_dict, condition, target_dict)
    _publish_unpaired_weights(wv_keys, 'wv', 'wq', local_wq, qkv_dict, condition, target_dict)
    _publish_unpaired_weights(w3_keys, 'w3', 'w1', local_w1, qkv_dict, condition, target_dict)

    # concat qkv and ffn; missing partners are fetched from the shared dict
    for wq_key in wq_keys:
        wq_value = target_dict.pop(wq_key)
        wk_value = _pop_local_or_shared(wq_key.replace('wq', 'wk'), qkv_dict, condition, target_dict)
        wv_value = _pop_local_or_shared(wq_key.replace('wq', 'wv'), qkv_dict, condition, target_dict)
        target_dict[wq_key.replace('wq', 'w_qkv')] = np.concatenate((wq_value, wk_value, wv_value), 0)

    for w1_key in w1_keys:
        w1_value = target_dict.pop(w1_key)
        w3_value = _pop_local_or_shared(w1_key.replace('w1', 'w3'), qkv_dict, condition, target_dict)
        target_dict[w1_key.replace('w1', 'w_gate_hidden')] = np.concatenate((w1_value, w3_value), 0)