# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LLaMA models' APIs."""
import copy
from multiprocessing.managers import DictProxy
from multiprocessing.synchronize import Condition
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn, mint, Parameter
from mindspore.common.initializer import initializer
from mindspore.context import ParallelMode
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindformers.core.loss.loss import CrossEntropyLoss
from mindformers.mindformer_book import MindFormerBook
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.models.utils import LayerSetting, check_fine_grain_interleave_valid
from mindformers.modules.layers import Linear, FreqsMgr
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
from mindformers.tools.utils import get_predict_run_mode
from .llama_config import LlamaConfig
from .llama_layer import LlamaEmbedding, LlamaRMSNorm
from .llama_transformer import LLamaDecodeLayer
from .llama_interleave import LLamaDecodeLayerInterleave
from ..utils import lazy_inline
from ...tools.logger import logger
__all__ = ['LlamaModel', 'LlamaForCausalLM']
class LlamaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LlamaConfig
base_model_prefix = "llama"
class LlamaModel(LlamaPreTrainedModel):
r"""
    Transformer decoder consisting of *config.num_layers* layers. Each layer is a [`LLamaDecodeLayer`].
    Args:
        config(LlamaConfig): the config of the network
    Returns:
        output: Tensor, the output of the llama decoder layers
Examples:
>>> from mindformers import LlamaModel
>>> network = LlamaModel.from_pretrained('llama_7b')
>>> type(network)
<class 'mindformers.models.llama.llama.LlamaModel'>
"""
_support_list = MindFormerBook.get_model_support_list()['llama']
def __init__(self,
config: LlamaConfig = None):
super().__init__(config, auto_prefix=True)
_check_config(config.parallel_config)
self.dtype = config.compute_dtype
self.hidden_size = config.hidden_size
self.num_layers = config.num_layers
self.n_head = config.num_heads
self.head_dim = self.hidden_size // self.n_head
self.pad_token_id = config.pad_token_id
self.is_first_iteration = True
self.chunk_prefill = config.chunk_prefill
self.use_past = config.use_past
self.use_flash_attention = config.use_flash_attention
self.use_ring_attention = config.use_ring_attention
self.parallel_decoding = config.parallel_decoding_params is not None
self.concat = P.Concat(-1)
self.cast = P.Cast()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.rmsnorm_compute_2d = config.rmsnorm_compute_2d
if config.moe_config.expert_num > 1:
logger.info("MoE config is provided, use MoE FFN")
else:
logger.info("MoE config is None, use normal FFN")
if not self.use_flash_attention and self.use_ring_attention:
raise ValueError(f"When the ring_attention = True, the flash_attention must be True ")
self.seq_split_num = config.parallel_config.seq_split_num
self.seq_pipe = self.seq_split_num > 1
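        # Sequence pipelining: each training sequence is split into seq_split_num segments
        # that are processed one after another; the state below tracks the active segment
        # and masks the KV cache accordingly.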
if self.seq_pipe:
dp = config.parallel_config.data_parallel
if self.use_ring_attention:
raise ValueError(f"When the seq_pipe = True, the use_ring_attention cannot be True ")
if config.use_attn_mask_compression:
raise ValueError(f"Currently, when the seq_pipe = True, the use_attn_mask_compression cannot be True ")
self.n_kv_head = self.n_head if config.n_kv_heads is None else config.n_kv_heads
kv_shape = (config.batch_size * dp, self.n_kv_head, config.seq_length, self.head_dim)
self.zeros = initializer('zeros', kv_shape, dtype=self.dtype)
self.seq_update = Tensor(1, dtype=mstype.int32)
self.seq_zero = Tensor(0, dtype=mstype.int32)
self.seq_seg_len = config.seq_length // self.seq_split_num
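            # Label each seq_seg_len-long segment of the KV cache with its chunk index so that
            # equal_kv can later keep only the KV entries of the chunk currently being processed.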
kv_mask = np.zeros((1, self.n_kv_head, config.seq_length, self.head_dim), np.int32)
for s in range(self.seq_split_num):
kv_mask[:, :, s * self.seq_seg_len: (s + 1) * self.seq_seg_len, :] = s
self.kv_mask = Tensor(kv_mask)
self.seq_chunk = Parameter(Tensor(0, dtype=mstype.int32), name="seq_chunk",
requires_grad=False, parallel_optimizer=False)
cp = config.parallel_config.context_parallel
mp = config.parallel_config.model_parallel
self.equal_kv = P.Equal().shard(((dp, mp, cp, 1), ()))
self.kv_mask_add = P.Add().shard(((dp, mp, cp, 1), (1, mp, cp, 1)))
self.assign_add_count = P.AssignAdd()
self.assign_count = P.Assign()
self.assign_mask = P.Assign().shard(((dp, 1), (dp, 1)))
self.mask_zeros = Tensor(np.zeros((config.batch_size * dp, config.seq_length)), mstype.float32)
self.freqs_mgr = FreqsMgr(head_dim=self.head_dim,
seq_length=config.seq_length,
max_position_embedding=config.max_position_embedding,
rotary_dtype=config.rotary_dtype,
theta=config.theta,
scaling_factor=config.scaling_factor,
extend_method=config.extend_method,
parallel_config=config.parallel_config,
is_dynamic=config.is_dynamic)
self.residual_cast_flag = config.residual_dtype != self.dtype
if self.residual_cast_flag:
logger.info(f"residual in llama model cast flag: {self.residual_cast_flag}, "
f"residual dtype: {config.residual_dtype}")
self.freqs_mgr.shard(config.parallel_config)
total_batch_size_in_dp = config.batch_size * config.parallel_config.data_parallel
self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length,
batch_size=total_batch_size_in_dp,
compute_type=config.compute_dtype,
is_dynamic=config.is_dynamic,
pad_token_id=config.pad_token_id,
use_flash_attention=config.use_flash_attention,
use_attn_mask_compression=config.use_attn_mask_compression,
use_past=config.use_past,
seq_split_num=self.seq_split_num,
chunk_prefill=config.chunk_prefill)
self.tok_embeddings = LlamaEmbedding(vocab_table_size=config.vocab_size,
embedding_size=config.hidden_size,
init_method_std=config.init_method_std,
param_init_type=config.embedding_init_type,
parallel_optimizer=config.parallel_optimizer,
rmsnorm_compute_2d=config.rmsnorm_compute_2d)
self.fine_grain_interleave = check_fine_grain_interleave_valid(config.fine_grain_interleave,
config.parallel_config)
self.layers = nn.CellList()
self.layer_setting = LayerSetting(config.num_layers,
config.offset,
config.parallel_config,
config.pp_interleave_num,
config.start_stage,
config.stage_num)
for layer_id in range(config.num_layers):
if self.fine_grain_interleave:
layer = LLamaDecodeLayerInterleave(config.batch_size,
config.seq_length,
layer_id,
dim=config.hidden_size,
n_heads=config.num_heads,
num_layers=config.num_layers,
multiple_of=config.multiple_of,
n_kv_heads=config.n_kv_heads,
intermediate_size=config.intermediate_size,
ffn_dim_multiplier=config.ffn_dim_multiplier,
norm_eps=config.rms_norm_eps,
qkv_has_bias=config.qkv_has_bias,
attn_proj_has_bias=config.attn_proj_has_bias,
qkv_concat=config.qkv_concat,
compute_dtype=config.compute_dtype,
layernorm_compute_dtype=config.layernorm_compute_type,
softmax_compute_dtype=config.softmax_compute_type,
rotary_dtype=config.rotary_dtype,
param_init_type=config.param_init_type,
residual_dtype=config.residual_dtype,
use_flash_attention=config.use_flash_attention,
use_ring_attention=config.use_ring_attention,
use_attn_mask_compression=config.use_attn_mask_compression,
fine_grain_interleave=config.fine_grain_interleave,
init_method_std=config.init_method_std,
parallel_config=config.parallel_config)
else:
layer = LLamaDecodeLayer(config.seq_length,
layer_id,
dim=config.hidden_size,
n_heads=config.num_heads,
n_kv_heads=config.n_kv_heads,
intermediate_size=config.intermediate_size,
multiple_of=config.multiple_of,
ffn_dim_multiplier=config.ffn_dim_multiplier,
norm_eps=config.rms_norm_eps,
qkv_has_bias=config.qkv_has_bias,
attn_proj_has_bias=config.attn_proj_has_bias,
qkv_concat=config.qkv_concat,
compute_dtype=config.compute_dtype,
layernorm_compute_dtype=config.layernorm_compute_type,
softmax_compute_dtype=config.softmax_compute_type,
rotary_dtype=config.rotary_dtype,
param_init_type=config.param_init_type,
residual_dtype=config.residual_dtype,
use_past=config.use_past,
is_dynamic=config.is_dynamic,
use_flash_attention=config.use_flash_attention,
use_ring_attention=config.use_ring_attention,
use_attn_mask_compression=config.use_attn_mask_compression,
block_size=config.block_size,
num_blocks=config.num_blocks,
use_rope_slice=config.use_rope_slice,
rmsnorm_compute_2d=config.rmsnorm_compute_2d,
batch_size=config.batch_size,
moe_config=config.moe_config,
parallel_config=config.parallel_config,
parallel_decoding=self.parallel_decoding,
fused_kernel=config.fused_rms_norm,
init_method_std=config.init_method_std,
chunk_prefill=config.chunk_prefill
)
self.layer_setting(layer, layer_id)
self.layers.append(layer)
self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps,
compute_type=config.layernorm_compute_type,
fused_kernel=config.fused_rms_norm)
dp = config.parallel_config.data_parallel
cp = config.parallel_config.context_parallel
self.tok_embeddings.pipeline_stage = config.start_stage
if config.parallel_config.pipeline_stage > 1:
if config.stage_num == 0:
self.norm_out.pipeline_stage = config.parallel_config.pipeline_stage - 1
else:
self.norm_out.pipeline_stage = config.start_stage + config.stage_num - 1
self.tok_embeddings.set_comm_fusion(2)
self.norm_out.set_comm_fusion(2)
else:
self.tok_embeddings.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
self.norm_out.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
self.tok_embeddings.shard(config.parallel_config)
self.casual_mask.shard(config.parallel_config)
if self.fine_grain_interleave or config.rmsnorm_compute_2d:
self.norm_out.shard((dp * cp, 1))
else:
self.norm_out.shard((dp, cp, 1))
else:
self.tok_embeddings.shard(config.parallel_config)
self.casual_mask.shard(config.parallel_config)
self.concat.shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
if self.fine_grain_interleave or config.rmsnorm_compute_2d:
self.norm_out.shard((dp * cp, 1))
else:
self.norm_out.shard((dp, cp, 1))
# pylint: disable=W0613
def construct(self, tokens: Tensor, input_embeds=None, batch_valid_length=None, batch_index=None,
zactivate_len=None, block_tables=None, slot_mapping=None, prefix_keys_values=None,
attention_mask=None, position_ids=None, q_seq_lens=None, seq_range=None):
"""
Forward of llama model.
Args:
tokens: the tokenized inputs with datatype int32
            input_embeds: the embedding Tensor of tokens, Tensor of shape :math:`(batch_size, seq\_length, hidden_size)`.
                Default None.
            batch_valid_length(Tensor): the computed length of each sequence so far, with datatype int32; used for
                incremental prediction. Tensor of shape :math:`(batch_size,)`. Default None.
block_tables (Tensor[int64]): Store mapping tables for each sequence.
slot_mapping (Tensor[int32]): Store token cache physical slot index.
Returns:
            output: Tensor, the output of the llama decoder layers
"""
# preprocess
bs, seq_len = self.shape(tokens)
kv_mask = None
seq_chunk = None
rmsnorm_compute_2d = self.training and self.rmsnorm_compute_2d
if self.chunk_prefill and self.is_first_iteration:
# get chunk + decode masks
if attention_mask is not None:
mask = attention_mask
else:
mask = self.casual_mask.chunk_masks(seq_range)
# get chunk + decode pos
freqs_cis = self.freqs_mgr.chunk_with_decode(seq_range)
elif self.parallel_decoding:
            # With TH layout the FA mask is 2D; with BSH layout it is 4D.
if self.is_first_iteration:
mask = self.casual_mask.prefill()
else:
mask = attention_mask
freqs_cis = self.freqs_mgr.increment_multi_ids(position_ids)
elif attention_mask is not None:
mask = attention_mask
mask = self.cast(mask, mstype.uint8)
freqs_cis = self.freqs_mgr(seq_len, position_ids)
else:
mask = None
if self.use_past:
if self.is_first_iteration:
freqs_cis = self.freqs_mgr.prefill(bs, seq_len)
mask = self.casual_mask.prefill()
if prefix_keys_values is not None:
if mask is None:
mask = self.casual_mask(tokens)
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)
else:
if self.seq_pipe:
mask = self.casual_mask(tokens, seq_chunk=self.seq_chunk)
seq_chunk = P.ReLU()(self.seq_chunk)
kv_mask = self.cast(self.equal_kv(self.kv_mask_add(self.zeros, self.kv_mask), seq_chunk),
self.dtype)
seq_update = F.depend(self.seq_update, mask)
seq_update = F.depend(seq_update, kv_mask)
mask = F.depend(mask, self.assign_add_count(self.seq_chunk, seq_update))
elif not self.use_ring_attention:
mask = self.casual_mask(tokens)
freqs_cis = self.freqs_mgr(seq_len, seq_chunk=seq_chunk)
if prefix_keys_values is not None:
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
# tokens: [bs, seq/1]
if input_embeds is not None:
h = self.cast(input_embeds, self.dtype)
else:
h = self.cast(self.tok_embeddings(tokens), self.dtype)
if not rmsnorm_compute_2d:
h = self.reshape(h, (bs, seq_len, self.hidden_size)) # h: [bs, seq/1, hidden_dim]
for i in range(self.num_layers):
prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None
h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables,
slot_mapping=slot_mapping, prefix_keys_values=prefix_kv, q_seq_lens=q_seq_lens,
kv_mask=kv_mask, seq_chunk=seq_chunk)
if rmsnorm_compute_2d:
h = self.reshape(h, (bs * seq_len, -1))
output = self.norm_out(h)
return output
def clear_kv_cache(self):
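        """Reset seq-pipe state: zero the seq_chunk counter and clear the cached causal mask."""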
zeros = 0.0
return_tuple = ()
return_tuple += (self.assign_count(self.seq_chunk, self.seq_zero),)
return_tuple += (self.assign_mask(self.casual_mask.mask_cache, self.mask_zeros),)
return F.depend(zeros, return_tuple)
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlamaForCausalLM(LlamaPreTrainedModel):
r"""
    Provide llama training loss or logits through the network.
Args:
config (LlamaConfig, optional): The config of llama model. Default: `None` .
Inputs:
- **input_ids** (Tensor) - the indices of input sequence tokens in the vocabulary with data type Int64/Int32,
Tensor of shape :math:`(batch, seq\_length)`.
- **labels** (Tensor, optional) - the labels of inputs with data type Int64/Int32, Tensor of
shape :math:`(batch, seq\_length)` . Default: ``None``.
        - **input_position** (Tensor, optional) - the position ids of inputs (in incremental inference mode), which
          form an increasing sequence, with data type Int64/Int32, Tensor :math:`(batch, seq\_length)`.
          Default: ``None``.
        - **position_ids** (Tensor, optional) - the position ids of inputs, which form an increasing sequence,
          with data type Int64/Int32, Tensor :math:`(batch, seq\_length)`. Default: ``None``.
- **attention_mask** (Tensor, optional) - input sentences padding mask, where 0 indicates padding position with
data type Int64/Int32, Tensor of shape :math:`(batch, seq\_length)`. Default: ``None``.
- **input_embeds** (Tensor, optional) - the embedding of inputs with data type Float32/Float16, Tensor of
shape :math:`(batch, seq\_length, hidden\_size)`. Default: ``None``.
- **init_reset** (Tensor, optional) - A Bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True.
Tensor of shape :math:`(1)`. Default: ``Tensor([True])``.
        - **batch_valid_length** (Tensor, optional) - Int32 tensor of shape [batch_size], recording the computed
          length of each sequence. Used for incremental prediction when use_past is True. Default: ``None``.
        - **batch_index** (Tensor, optional) - Deprecated argument. Will be removed in the future. Default: ``None``.
        - **zactivate_len** (Tensor, optional) - Deprecated argument. Will be removed in the future. Default: ``None``.
- **block_tables** (Tensor, optional) - Int64 type Tensor, store mapping tables for each sequence.
Default: ``None``.
- **slot_mapping** (Tensor, optional) - Int32 type Tensor, token cache physical slot index. Default: ``None``.
        - **prefix_keys_values** (Tensor, optional) - Deprecated argument. Will be removed in the future.
          Default: ``None``.
        - **llm_boost_inputs** (Tensor, optional) - Deprecated argument. Will be removed in the future. Default: ``None``.
        - **q_seq_lens** (Tensor, optional) - In parallel decoding, the query may be flattened.
          The Paged Attention operator needs `q_seq_lens` to obtain the length information. Default: ``None``.
        - **loss_mask** (Tensor, optional) - Float32/Int32 type tensor, used to determine whether the corresponding
          token position participates in the loss calculation: positions with value 1 contribute to the loss,
          while positions with value 0 do not. Default: ``None``.
        - **gather_index** (Tensor, optional) - Int32 type Tensor, used to obtain the last hidden state of
          each sequence. Default: ``None``.
- **seq_range** (Tensor, optional) - Int32 type Tensor, used to obtain Mask and positional encoding of
valid tokens for each sequence. Default: ``None``.
Outputs:
Tensor. If it is in training mode, the output Tensor contains loss;
If it is in prediction mode, the output Tensor contains logits;
If it is in evaluation mode, the output Tensor contains logits, tokens, and input masks.
Examples:
>>> from mindformers.models.llama import LlamaConfig, LlamaForCausalLM
>>> import mindspore as ms
>>> ms.set_context(mode=0)
>>> config = LlamaConfig(batch_size=2)
>>> network = LlamaForCausalLM(config=config)
>>> type(network)
<class 'mindformers.models.llama.llama.LlamaForCausalLM'>
>>> from mindformers import LlamaForCausalLM
>>> network = LlamaForCausalLM.from_pretrained('llama2_7b')
>>> type(network)
<class 'mindformers.models.llama.llama.LlamaForCausalLM'>
"""
_support_list = MindFormerBook.get_model_support_list()['llama']
@lazy_inline
def __init__(self, config: LlamaConfig = None):
        super().__init__(config, auto_prefix=True)
_check_config(config.parallel_config)
self.config = config
self.ignore_token_id = config.ignore_token_id
self.pad_token_id = config.pad_token_id
self.use_past = config.use_past
self.vocab_size = config.vocab_size
self.is_first_iteration = True
self.chunk_prefill = config.chunk_prefill
self.shape = P.Shape()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.slice = P.StridedSlice()
self.not_equal = P.NotEqual()
self.mul = P.Mul()
self.add = P.Add()
self.ones = P.Ones()
self.gather = P.Gather(1)
self.prefill_gather_flatten = P.Gather()
self.sub_batch_valid_len = P.Sub()
self.model = LlamaModel(config=config)
self.lm_head = Linear(in_channels=config.hidden_size,
out_channels=config.vocab_size,
has_bias=False,
compute_dtype=config.compute_dtype,
param_init_type=config.param_init_type,
weight_init="normal") # meta default: xavier_normal
if config.tie_word_embeddings:
self.lm_head.weight = self.model.tok_embeddings.embedding_weight
mp = config.parallel_config.model_parallel
vocab_size = config.vocab_size
loss_parallel_config = copy.deepcopy(config.parallel_config)
if vocab_size % mp != 0:
logger.warning("The vocab size of Loss is: %s, it is not divide by model_parallel: %s",
vocab_size, mp)
logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1")
loss_parallel_config.model_parallel = 1
loss_parallel_config.data_parallel *= loss_parallel_config.context_parallel
check_for_nan_in_loss_and_grad = getattr(config, "check_for_nan_in_loss_and_grad", False)
calculate_per_token_loss = getattr(config, "calculate_per_token_loss", False)
self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config,
check_for_nan_in_loss_and_grad=check_for_nan_in_loss_and_grad,
calculate_per_token_loss=calculate_per_token_loss,
seq_split_num=config.parallel_config.seq_split_num)
dp = config.parallel_config.data_parallel
mp = config.parallel_config.model_parallel
cp = config.parallel_config.context_parallel
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
self.slice.shard(((dp, 1),))
self.not_equal.shard(((dp, 1), ()))
if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0):
self.lm_head.shard(strategy_matmul=((dp * cp, 1), (1, 1)))
else:
self.lm_head.shard(strategy_matmul=((dp * cp, 1), (mp, 1)))
else:
self.slice.shard(((dp, 1),))
self.not_equal.shard(((dp, 1), ()))
self.mul.shard(((dp, 1), (dp, 1)))
self.add.shard(((dp, 1), ()))
self.gather.shard(((dp, 1, 1), (dp,)))
self.prefill_gather_flatten.shard(((dp, 1, 1), (dp,)))
self.sub_batch_valid_len.shard(((1,), ()))
if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0):
self.lm_head.shard(strategy_matmul=((dp * cp, 1), (1, 1)))
else:
self.lm_head.shard(strategy_matmul=((dp * cp, 1), (mp, 1)))
if config.parallel_config.pipeline_stage > 1:
self.lm_head.pipeline_stage = config.parallel_config.pipeline_stage - 1
self.load_checkpoint(config)
self.predict_run_mode = get_predict_run_mode()
logger.info("Predict run mode:{}".format(self.predict_run_mode))
self.parallel_decoding = config.parallel_decoding_params is not None
self.input_sliced_sig = config.input_sliced_sig
def to_embeddings(self, tokens):
"""return embedding tokens"""
return self.model.tok_embeddings(tokens)
# pylint: disable=W0613
def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
"""Get Llama model input tuple for transform ckpt."""
input_ids = Tensor(input_ids, mstype.int32)
labels = Tensor(kwargs["labels"]) if "labels" in kwargs else None
bs, seq = input_ids.shape[0], input_ids.shape[1]
slot_mapping = Tensor(np.ones(shape=tuple([bs * seq])), mstype.int32)
prefix_keys_values = Tensor(kwargs["prefix_keys_values"]) if "prefix_keys_values" in kwargs else None
position_ids = Tensor(np.zeros(shape=tuple([bs, seq])), mstype.int32) if self.parallel_decoding else None
mask = Tensor(np.zeros(shape=tuple([seq, seq])), mstype.float16) if self.parallel_decoding else None
q_seq_lens = Tensor(np.zeros(shape=tuple([bs])), mstype.int32) if self.parallel_decoding else None
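        # The tuple below is ordered to match the leading parameters of construct().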
outputs = (input_ids, labels, None, position_ids, mask, None, None, None, None, None, None, slot_mapping,
prefix_keys_values, None, q_seq_lens)
return outputs
def set_dynamic_inputs(self, **kwargs):
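        # Each set_inputs call below passes positional arguments matching the leading
        # parameters of construct(); None marks an input that is not made dynamic.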
dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
        have_prefix_keys_values = kwargs.get("have_prefix_keys_values", False)
dynamic_position_ids = Tensor(shape=[None, None], dtype=mstype.int32) if self.parallel_decoding else None
dynamic_mask = Tensor(shape=[None, None], dtype=mstype.float16) if self.parallel_decoding else None
dynamic_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) if self.parallel_decoding else None
if have_prefix_keys_values:
dynamic_prefix_keys_values = Tensor(shape=[2, None, None, None, None], dtype=mstype.float16)
self.set_inputs(dynamic_input_ids, None, None, dynamic_position_ids, dynamic_mask, None, None,
dynamic_batch_valid_length, None, None, dynamic_block_tables,
dynamic_slot_mapping, dynamic_prefix_keys_values, None, dynamic_q_seq_lens, None,
None, None)
elif self.use_past:
self.set_inputs(dynamic_input_ids, None, None, dynamic_position_ids, dynamic_mask, None, None,
dynamic_batch_valid_length, None, None, dynamic_block_tables,
dynamic_slot_mapping, None, None, dynamic_q_seq_lens, None, None, None)
elif kwargs.get("pre_gather", False):
self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
dynamic_batch_valid_length, None, None, None, None, None)
else:
self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
None, None, None, None, None, None)
logger.info("Set dynamic input for llama.")
def add_flags_custom(self, is_first_iteration):
"""Add customized attributes for specific cells in the model."""
self.add_flags(is_first_iteration=is_first_iteration)
self.model.add_flags(is_first_iteration=is_first_iteration)
for layer in self.model.layers:
layer.add_flags(is_first_iteration=is_first_iteration)
layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration)
layer.attention.infer_attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)
def pre_gather_func(self, pre_gather, output, batch_valid_length, gather_index=None):
"""Pre gather operation in infer mode."""
        if not pre_gather:
            return output
        if self.chunk_prefill and self.is_first_iteration:
            output = output.reshape(-1, output.shape[-1])
            output = output[self.sub_batch_valid_len(gather_index, 1)]
        elif self.config.is_dynamic:
            batch_valid_length = mint.cumsum(batch_valid_length, 0)
            output = self.prefill_gather_flatten(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
        else:
            output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
        return output
# pylint: disable=W0613
def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
input_embeds=None, init_reset=None, batch_valid_length=None, batch_index=None,
zactivate_len=None, block_tables=None, slot_mapping=None, prefix_keys_values=None,
llm_boost_inputs=None, q_seq_lens=None, loss_mask=None, gather_index=None, seq_range=None):
r"""LlamaForCausalLM forward."""
has_loss_mask = loss_mask is not None
input_sliced_sig = self.input_sliced_sig
if self.training and input_sliced_sig and labels is None:
input_sliced_sig = False
bsz, seqlen = self.shape(input_ids)
if self.use_past:
if not isinstance(batch_valid_length, Tensor):
batch_valid_length = self.ones((bsz,), mstype.int32)
if not input_sliced_sig and self.training:
tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1))
if has_loss_mask:
loss_mask = self.slice(loss_mask, (0, 0), (bsz, seqlen - 1), (1, 1))
else:
tokens = input_ids
if batch_valid_length is not None:
batch_valid_length = self.reshape(batch_valid_length, (-1,))
        output = self.model(tokens, input_embeds, batch_valid_length, batch_index, zactivate_len, block_tables,
                            slot_mapping, prefix_keys_values, attention_mask, position_ids, q_seq_lens,
                            seq_range)
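        # In prefill, gather only each sequence's last valid hidden state before the lm_head,
        # so logits are not computed for every position.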
pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
output = self.pre_gather_func(pre_gather, output, batch_valid_length, gather_index)
logits = self.lm_head(output)
input_mask = loss_mask if has_loss_mask \
else self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)
if labels is None:
labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1))
else:
if labels.ndim > 1:
if not input_sliced_sig and self.training:
labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1))
if not has_loss_mask:
label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32)
input_mask = self.mul(input_mask, label_mask)
if not self.training:
logits = self.cast(logits, mstype.float32)
if self.predict_run_mode:
logits = self.reshape(logits, (-1, logits.shape[-1]))
return logits
return logits, tokens, input_mask
if logits.ndim > 2:
logits = self.reshape(logits, (-1, logits.shape[-1]))
logits = self.cast(logits, mstype.float32)
labels = self.reshape(labels, (-1,))
input_mask = self.reshape(input_mask, (-1,))
loss = self.loss(logits, labels, input_mask)
return loss
def kvcache(self, layer_idx):
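        """Return the key and value cache tensors of the given decoder layer."""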
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
return key_cache, value_cache
@classmethod
def convert_name(cls, weight_name):
"""convert HuggingFace weight name to MindFormers weight name"""
origin_name = weight_name
weight_name = weight_name.replace('embed_tokens.', 'tok_embeddings.')
weight_name = weight_name.replace('.self_attn.q_proj.', '.attention.wq.')
weight_name = weight_name.replace('.self_attn.k_proj.', '.attention.wk.')
weight_name = weight_name.replace('.self_attn.v_proj.', '.attention.wv.')
weight_name = weight_name.replace('.self_attn.o_proj.', '.attention.wo.')
weight_name = weight_name.replace('.mlp.gate_proj.', '.feed_forward.w1.')
weight_name = weight_name.replace('.mlp.down_proj.', '.feed_forward.w2.')
weight_name = weight_name.replace('.mlp.up_proj.', '.feed_forward.w3.')
weight_name = weight_name.replace('.input_layernorm.', '.attention_norm.')
weight_name = weight_name.replace('.post_attention_layernorm.', '.ffn_norm.')
weight_name = weight_name.replace('.norm.', '.norm_out.')
weight_name = weight_name.replace('output.', 'lm_head.')
weight_name = weight_name.replace('.tok_embeddings.weight', '.tok_embeddings.embedding_weight')
if weight_name == origin_name:
logger.warning(f"weight name '{weight_name}' does not change after conversion. "
f"Please check if it is as expected.")
return weight_name
@classmethod
def convert_weight_dict(cls, source_dict, **kwargs):
"""convert HuggingFace weight dict to MindFormers weight dict"""
model_config = kwargs.get("model_config")
qkv_concat = model_config.qkv_concat
target_dict = {}
wq_keys = []
wk_keys = []
wv_keys = []
w1_keys = []
w3_keys = []
for k, v in source_dict.items():
k = cls.convert_name(k)
target_dict.update({k: v})
if qkv_concat:
part = k.split('.')
if part[-2] == 'wq':
wq_keys.append(k)
if part[-2] == 'wk':
wk_keys.append(k)
if part[-2] == 'wv':
wv_keys.append(k)
if part[-2] == 'w1':
w1_keys.append(k)
if part[-2] == 'w3':
w3_keys.append(k)
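        # q/k/v (and w1/w3) weights may be spread across checkpoint shards handled by different
        # worker processes; the managed qkv_dict plus condition hand stray halves between
        # workers so that concatenation always finds its counterpart.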
if qkv_concat:
qkv_dict = kwargs.get('qkv_dict', None)
if not isinstance(qkv_dict, DictProxy):
                raise ValueError(f'qkv_dict must be a DictProxy when qkv_concat is True, but got {qkv_dict}.')
condition = kwargs.get('condition', None)
if not isinstance(condition, Condition):
raise ValueError(f'condition must be a Condition, when qkv_concat is True, but got {condition}.')
_concat_qkv_weight(wq_keys, wk_keys, wv_keys, model_config, qkv_dict, condition, target_dict)
_concat_ffn_weight(w1_keys, w3_keys, model_config, qkv_dict, condition, target_dict)
return target_dict
@classmethod
def convert_map_dict(cls, source_dict, **kwargs):
"""convert HuggingFace map dict to MindFormers map dict"""
qkv_concat = kwargs.pop("qkv_concat", False)
target_dict = {}
wq_keys = []
w1_keys = []
for k, v in source_dict.items():
k = cls.convert_name(k)
target_dict.update({k: v})
if qkv_concat:
part = k.split('.')
if part[-2] == 'wq':
wq_keys.append(k)
if part[-2] == 'w1':
w1_keys.append(k)
if qkv_concat:
for wq_key in wq_keys:
wk_key = wq_key.replace('wq', 'wk')
wv_key = wq_key.replace('wq', 'wv')
wq_value = target_dict.pop(wq_key)
target_dict.pop(wk_key)
target_dict.pop(wv_key)
w_qkv_key = wq_key.replace('wq', 'w_qkv')
w_qkv_value = wq_value
target_dict.update({w_qkv_key: w_qkv_value})
for w1_key in w1_keys:
w3_key = w1_key.replace('w1', 'w3')
w1_value = target_dict.pop(w1_key)
target_dict.pop(w3_key)
w_gate_hidden_key = w1_key.replace('w1', 'w_gate_hidden')
w_gate_hidden_value = w1_value
target_dict.update({w_gate_hidden_key: w_gate_hidden_value})
return target_dict
@classmethod
def obtain_qkv_ffn_concat_keys(cls):
qkv_key = "w_qkv"
ffn_key = "w_gate_hidden"
concat_keys = [qkv_key, ffn_key]
logger.info(f"{cls.__name__} qkv/ffn concat keys are {concat_keys}")
return concat_keys
def clear_kv_cache(self):
return self.model.clear_kv_cache()
def _concat_qkv_weight(wq_keys, wk_keys, wv_keys, model_config, qkv_dict, condition, target_dict):
"""concat qkv weight from dicts"""
from mindformers.utils.convert_utils import qkv_concat_hf2mg
num_heads = model_config.num_heads
n_kv_heads = model_config.n_kv_heads or num_heads
hidden_size = model_config.hidden_size
# pop extra weight to shared dict if there is no corresponding weight for concat in the target dict
for wk_key in wk_keys:
wq_key = wk_key.replace('wk', 'wq')
if wq_key not in wq_keys:
with condition:
qkv_dict[wk_key] = target_dict.pop(wk_key) # add extra weight to shared dict
condition.notify_all()
for wv_key in wv_keys:
wq_key = wv_key.replace('wv', 'wq')
if wq_key not in wq_keys:
with condition:
qkv_dict[wv_key] = target_dict.pop(wv_key) # add extra weight to shared dict
condition.notify_all()
# concat qkv
for wq_key in wq_keys:
wk_key = wq_key.replace('wq', 'wk')
wv_key = wq_key.replace('wq', 'wv')
wq_value = target_dict.pop(wq_key)
wk_value = target_dict.pop(wk_key, None)
wv_value = target_dict.pop(wv_key, None)
# get missing weight from shared dict
if wk_value is None:
with condition:
condition.wait_for(lambda: wk_key in qkv_dict.keys())
wk_value = qkv_dict.pop(wk_key)
if wv_value is None:
with condition:
condition.wait_for(lambda: wv_key in qkv_dict.keys())
wv_value = qkv_dict.pop(wv_key)
w_qkv_key = wq_key.replace('wq', 'w_qkv')
w_qkv_value = np.concatenate((wq_value, wk_value, wv_value), 0)
# qkv weight format: hf -> mg
w_qkv_value_mg = qkv_concat_hf2mg(w_qkv_value, num_heads, n_kv_heads, hidden_size)
target_dict.update({w_qkv_key: w_qkv_value_mg})
def _concat_ffn_weight(w1_keys, w3_keys, model_config, qkv_dict, condition, target_dict):
"""concat ffn weight from dicts"""
from mindformers.utils.convert_utils import ffn_concat_hf2mg
intermediate_size = model_config.intermediate_size
ffn_dim_multiplier = model_config.ffn_dim_multiplier
multiple_of = model_config.multiple_of or 256
ffn_hidden_size = model_config.hidden_size * 4
if intermediate_size is not None:
ffn_hidden_size = intermediate_size
else:
if ffn_dim_multiplier is not None:
ffn_hidden_size = int((ffn_dim_multiplier + 0.01) * ffn_hidden_size)
ffn_hidden_size = int(2 * ffn_hidden_size / 3)
ffn_hidden_size = multiple_of * \
((ffn_hidden_size + multiple_of - 1) // multiple_of)
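    # Worked example with LLaMA-7B-like settings (hidden_size=4096, multiple_of=256, and no
    # intermediate_size or ffn_dim_multiplier): 4 * 4096 = 16384 -> int(2 * 16384 / 3) = 10922
    # -> rounded up to a multiple of 256 gives 11008, matching LLaMA-7B's FFN width.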
# pop extra weight to shared dict if there is no corresponding weight for concat in the target dict
for w3_key in w3_keys:
w1_key = w3_key.replace('w3', 'w1')
if w1_key not in w1_keys:
with condition:
qkv_dict[w3_key] = target_dict.pop(w3_key) # add extra weight to shared dict
condition.notify_all()
# concat ffn
for w1_key in w1_keys:
w3_key = w1_key.replace('w1', 'w3')
w1_value = target_dict.pop(w1_key)
w3_value = target_dict.pop(w3_key, None)
# get missing weight from shared dict
if w3_value is None:
with condition:
condition.wait_for(lambda: w3_key in qkv_dict.keys())
w3_value = qkv_dict.pop(w3_key)
w_gate_hidden_key = w1_key.replace('w1', 'w_gate_hidden')
w_gate_hidden_value = np.concatenate((w1_value, w3_value), 0)
# ffn weight format: hf -> mg
w_gate_hidden_value_mg = ffn_concat_hf2mg(w_gate_hidden_value, ffn_hidden_size)
target_dict.update({w_gate_hidden_key: w_gate_hidden_value_mg})