# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LLaMA models' APIs."""
import copy
from multiprocessing.managers import DictProxy
from multiprocessing.synchronize import Condition
import numpy as np
from safetensors import safe_open
import mindspore.common.dtype as mstype
from mindspore import Tensor, nn, mint, Parameter
from mindspore.common.initializer import initializer
from mindspore.context import ParallelMode
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindformers.core.loss.loss import CrossEntropyLoss
from mindformers.mindformer_book import MindFormerBook
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.models.utils import LayerSetting, check_fine_grain_interleave_valid, check_use_3d_tensor_parallel_valid
from mindformers.modules.layers import Linear, FreqsMgr
from mindformers.modules.transformer import LowerTriangularMaskWithDynamic
from mindformers.modules.transformer.op_parallel_config import _check_config
from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
from mindformers.tools.utils import get_predict_run_mode, is_pynative
from mindformers.version_control import check_seqpp_fa_opt_support
from .llama_config import LlamaConfig
from .llama_layer import LlamaEmbedding, LlamaRMSNorm
from .llama_transformer import LLamaDecodeLayer
from .llama_interleave import LLamaDecodeLayerInterleave
from ..utils import lazy_inline
from ...tools.logger import logger
__all__ = ['LlamaModel', 'LlamaForCausalLM']


class LlamaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = LlamaConfig
    base_model_prefix = "llama"


class LlamaModel(LlamaPreTrainedModel):
r"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config(LlamaConfig): the config of network
Returns:
output: Tensor, the output of llama decoderlayer
Examples:
>>> from mindformers import LlamaModel
>>> network = LlamaModel.from_pretrained('llama_7b')
>>> type(network)
<class 'mindformers.models.llama.llama.LlamaModel'>
"""
_support_list = MindFormerBook.get_model_support_list()['llama']
def __init__(self,
config: LlamaConfig = None):
super().__init__(config, auto_prefix=True)
_check_config(config.parallel_config)
self.dtype = config.compute_dtype
self.hidden_size = config.hidden_size
self.num_layers = config.num_layers
self.n_head = config.num_heads
self.head_dim = self.hidden_size // self.n_head
self.pad_token_id = config.pad_token_id
self.is_first_iteration = True
self.chunk_prefill = config.chunk_prefill
self.use_past = config.use_past
self.use_eod_attn_mask_compression = config.use_eod_attn_mask_compression
self.use_flash_attention = config.use_flash_attention
self.use_ring_attention = config.use_ring_attention
self.parallel_decoding = config.parallel_decoding_params is not None
self.concat = P.Concat(-1)
self.cast = P.Cast()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.rmsnorm_compute_2d = config.rmsnorm_compute_2d
self.rl_config = config.rl_config
self.is_pynative = is_pynative()
if config.moe_config.expert_num > 1:
logger.info("MoE config is provided, use MoE FFN")
else:
logger.info("MoE config is None, use normal FFN")
        if not self.use_flash_attention and self.use_ring_attention:
            raise ValueError("When use_ring_attention is True, use_flash_attention must also be True.")
        if not self.use_flash_attention and self.use_eod_attn_mask_compression:
            raise ValueError("When use_eod_attn_mask_compression is True, use_flash_attention must also be True.")
self.seq_split_num = config.parallel_config.seq_split_num
self.seq_pipe = self.seq_split_num > 1
if self.seq_pipe:
dp = config.parallel_config.data_parallel
            if self.use_ring_attention:
                raise ValueError("When seq_pipe is True, use_ring_attention cannot be True.")
            if config.use_attn_mask_compression and not check_seqpp_fa_opt_support():
                raise ValueError("Currently, when seq_pipe is True, "
                                 "use_attn_mask_compression must be False with mindspore < 2.6.0. "
                                 "If you want to enable it, please upgrade mindspore to 2.6.0 or later.")
            if config.use_eod_attn_mask_compression:
                raise ValueError("Currently, when seq_pipe is True, "
                                 "use_eod_attn_mask_compression cannot be True.")
self.n_kv_head = self.n_head if config.n_kv_heads is None else config.n_kv_heads
kv_shape = (config.batch_size * dp, self.n_kv_head, config.seq_length, self.head_dim)
self.zeros = initializer('zeros', kv_shape, dtype=self.dtype)
self.seq_update = Tensor(1, dtype=mstype.int32)
self.seq_zero = Tensor(0, dtype=mstype.int32)
self.seq_seg_len = config.seq_length // self.seq_split_num
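            # Tag every sequence position with the index of the sequence chunk it
            # belongs to. At runtime, `equal_kv` compares these tags with the
            # current `seq_chunk` counter to build a mask selecting the KV entries
            # of the chunk being processed (see construct()).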
kv_mask = np.zeros((1, self.n_kv_head, config.seq_length, self.head_dim), np.int32)
for s in range(self.seq_split_num):
kv_mask[:, :, s * self.seq_seg_len: (s + 1) * self.seq_seg_len, :] = s
self.kv_mask = Tensor(kv_mask)
self.seq_chunk = Parameter(Tensor(0, dtype=mstype.int32), name="seq_chunk",
requires_grad=False, parallel_optimizer=False)
cp = config.parallel_config.context_parallel
mp = config.parallel_config.model_parallel
self.equal_kv = P.Equal().shard(((dp, mp, cp, 1), ()))
self.kv_mask_add = P.Add().shard(((dp, mp, cp, 1), (1, mp, cp, 1)))
self.assign_add_count = P.AssignAdd()
self.assign_count = P.Assign()
self.assign_mask = P.Assign().shard(((dp, 1), (dp, 1)))
self.mask_zeros = Tensor(np.zeros((config.batch_size * dp, config.seq_length)), mstype.float32)
self.freqs_mgr = FreqsMgr(head_dim=self.head_dim,
seq_length=config.seq_length,
max_position_embedding=config.max_position_embedding,
rotary_dtype=config.rotary_dtype,
theta=config.theta,
scaling_factor=config.scaling_factor,
extend_method=config.extend_method,
parallel_config=config.parallel_config,
is_dynamic=config.is_dynamic)
self.residual_cast_flag = config.residual_dtype != self.dtype
if self.residual_cast_flag:
logger.info(f"residual in llama model cast flag: {self.residual_cast_flag}, "
f"residual dtype: {config.residual_dtype}")
total_batch_size_in_dp = config.batch_size * config.parallel_config.data_parallel
use_attn_mask_compression = config.use_attn_mask_compression or config.use_eod_attn_mask_compression
self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length,
batch_size=total_batch_size_in_dp,
compute_type=config.compute_dtype,
is_dynamic=config.is_dynamic,
pad_token_id=config.pad_token_id,
use_flash_attention=config.use_flash_attention,
use_attn_mask_compression=use_attn_mask_compression,
use_past=config.use_past,
seq_split_num=self.seq_split_num,
chunk_prefill=config.chunk_prefill)
self.tok_embeddings = LlamaEmbedding(vocab_table_size=config.vocab_size,
embedding_size=config.hidden_size,
init_method_std=config.init_method_std,
param_init_type=config.embedding_init_type,
parallel_optimizer=config.parallel_optimizer,
rmsnorm_compute_2d=config.rmsnorm_compute_2d)
self.fine_grain_interleave = check_fine_grain_interleave_valid(config.fine_grain_interleave,
config.parallel_config)
self.use_3d_tensor_parallel = check_use_3d_tensor_parallel_valid(config)
self.tp_x = getattr(config, "tp_x", 1)
self.tp_y = getattr(config, "tp_y", 1)
self.tp_z = getattr(config, "tp_z", 1)
self.layers = nn.CellList()
self.layer_setting = LayerSetting(config.num_layers,
config.offset,
config.parallel_config,
config.pp_interleave_num,
config.start_stage,
config.stage_num)
for layer_id in range(config.num_layers):
if self.fine_grain_interleave:
layer = LLamaDecodeLayerInterleave(config.batch_size,
config.seq_length,
layer_id,
dim=config.hidden_size,
n_heads=config.num_heads,
num_layers=config.num_layers,
multiple_of=config.multiple_of,
n_kv_heads=config.n_kv_heads,
intermediate_size=config.intermediate_size,
ffn_dim_multiplier=config.ffn_dim_multiplier,
norm_eps=config.rms_norm_eps,
qkv_has_bias=config.qkv_has_bias,
attn_proj_has_bias=config.attn_proj_has_bias,
qkv_concat=config.qkv_concat,
compute_dtype=config.compute_dtype,
layernorm_compute_dtype=config.layernorm_compute_type,
softmax_compute_dtype=config.softmax_compute_type,
rotary_dtype=config.rotary_dtype,
param_init_type=config.param_init_type,
residual_dtype=config.residual_dtype,
use_flash_attention=config.use_flash_attention,
use_ring_attention=config.use_ring_attention,
use_attn_mask_compression=config.use_attn_mask_compression,
use_eod_attn_mask_compression=config.use_eod_attn_mask_compression,
fine_grain_interleave=config.fine_grain_interleave,
init_method_std=config.init_method_std,
parallel_config=config.parallel_config)
else:
layer = LLamaDecodeLayer(config.seq_length,
layer_id,
dim=config.hidden_size,
n_heads=config.num_heads,
n_kv_heads=config.n_kv_heads,
intermediate_size=config.intermediate_size,
multiple_of=config.multiple_of,
ffn_dim_multiplier=config.ffn_dim_multiplier,
norm_eps=config.rms_norm_eps,
qkv_has_bias=config.qkv_has_bias,
attn_proj_has_bias=config.attn_proj_has_bias,
qkv_concat=config.qkv_concat,
compute_dtype=config.compute_dtype,
layernorm_compute_dtype=config.layernorm_compute_type,
softmax_compute_dtype=config.softmax_compute_type,
rotary_dtype=config.rotary_dtype,
param_init_type=config.param_init_type,
residual_dtype=config.residual_dtype,
use_past=config.use_past,
is_dynamic=config.is_dynamic,
use_flash_attention=config.use_flash_attention,
use_ring_attention=config.use_ring_attention,
use_attn_mask_compression=config.use_attn_mask_compression,
use_eod_attn_mask_compression=config.use_eod_attn_mask_compression,
block_size=config.block_size,
num_blocks=config.num_blocks,
use_rope_slice=config.use_rope_slice,
rmsnorm_compute_2d=config.rmsnorm_compute_2d,
batch_size=config.batch_size,
moe_config=config.moe_config,
parallel_config=config.parallel_config,
parallel_decoding=self.parallel_decoding,
rl_config=self.rl_config,
fused_kernel=config.fused_rms_norm,
init_method_std=config.init_method_std,
chunk_prefill=config.chunk_prefill,
use_3d_tensor_parallel=self.use_3d_tensor_parallel,
tp_x=self.tp_x,
tp_y=self.tp_y,
tp_z=self.tp_z
)
self.layer_setting(layer, layer_id)
self.layers.append(layer)
self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps,
compute_type=config.layernorm_compute_type,
fused_kernel=config.fused_rms_norm)
dp = config.parallel_config.data_parallel
cp = config.parallel_config.context_parallel
mp = config.parallel_config.model_parallel
self.tok_embeddings.pipeline_stage = config.start_stage
if config.parallel_config.pipeline_stage > 1:
if config.stage_num == 0:
self.norm_out.pipeline_stage = config.parallel_config.pipeline_stage - 1
else:
self.norm_out.pipeline_stage = config.start_stage + config.stage_num - 1
self.tok_embeddings.set_comm_fusion(2)
self.norm_out.set_comm_fusion(2)
else:
self.tok_embeddings.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
self.norm_out.set_comm_fusion(config.parallel_config.gradient_aggregation_group)
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
self.tok_embeddings.shard(config.parallel_config)
self.casual_mask.shard(config.parallel_config)
if self.fine_grain_interleave or config.rmsnorm_compute_2d:
self.norm_out.shard((dp * cp, 1))
else:
self.norm_out.shard((dp, cp, 1))
elif _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
self.tok_embeddings.shard(config.parallel_config)
self.casual_mask.shard(config.parallel_config)
self.concat.shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
if self.fine_grain_interleave or config.rmsnorm_compute_2d:
if self.rl_config is not None:
self.norm_out.shard((dp * cp * mp, 1))
else:
self.norm_out.shard((dp * cp, 1))
else:
self.norm_out.shard((dp, cp, 1))
# pylint: disable=W0613
def construct(self, tokens: Tensor, input_embeds=None, batch_valid_length=None, batch_index=None,
zactivate_len=None, block_tables=None, slot_mapping=None, prefix_keys_values=None,
attention_mask=None, position_ids=None, q_seq_lens=None, seq_range=None, actual_seq_len=None):
"""
Forward of llama model.
Args:
tokens: the tokenized inputs with datatype int32
input_embeds: the embedding Tensor of tokens, Tensor of shape:math:`(batch_size, seq/_length, hidden_size)`.
Default None.
batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental
prediction. Tensor of shape :math:`(batch_size,)`. Default None.
block_tables (Tensor[int64]): Store mapping tables for each sequence.
slot_mapping (Tensor[int32]): Store token cache physical slot index.
Returns:
output: Tensor, the output of llama decoderlayer
"""
# preprocess
bs, seq_len = self.shape(tokens)
if actual_seq_len is not None:
actual_seq_len = self.reshape(actual_seq_len, (-1,))
kv_mask = None
seq_chunk = None
rmsnorm_compute_2d = self.training and self.rmsnorm_compute_2d
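        # Select the attention mask and rotary embeddings according to the
        # execution mode: chunked prefill, parallel decoding, EOD-compressed
        # masks, a user-provided mask, incremental inference with KV cache,
        # or plain training (optionally with sequence pipelining).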
if self.chunk_prefill and self.is_first_iteration:
# get chunk + decode masks
if attention_mask is not None:
mask = attention_mask
else:
mask = self.casual_mask.chunk_masks(seq_range)
# get chunk + decode pos
freqs_cis = self.freqs_mgr.chunk_with_decode(seq_range)
elif self.parallel_decoding:
# FA with TH layout, mask is 2D, FA with BSH layout, mask is 4D
if self.is_first_iteration:
mask = self.casual_mask.prefill()
else:
mask = attention_mask
freqs_cis = self.freqs_mgr.increment_multi_ids(position_ids)
elif self.use_eod_attn_mask_compression and not self.use_ring_attention:
mask = self.casual_mask()
freqs_cis = self.freqs_mgr(seq_len)
elif attention_mask is not None:
mask = attention_mask
mask = self.cast(mask, mstype.uint8)
freqs_cis = self.freqs_mgr(seq_len)
if self.seq_pipe:
raise ValueError("When the seq_pipe = True, the attention_mask must be None.")
else:
mask = None
if self.use_past:
if self.is_first_iteration:
freqs_cis = self.freqs_mgr.prefill(bs, seq_len)
if self.use_flash_attention:
if self.is_pynative:
mask = self.casual_mask(tokens)
else:
mask = self.casual_mask.prefill()
else:
mask = self.casual_mask(tokens)
if prefix_keys_values is not None:
if mask is None:
mask = self.casual_mask(tokens)
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
else:
freqs_cis = self.freqs_mgr.increment(batch_valid_length)
else:
if self.seq_pipe:
mask = self.casual_mask(tokens, seq_chunk=self.seq_chunk)
seq_chunk = P.ReLU()(self.seq_chunk)
kv_mask = self.cast(self.equal_kv(self.kv_mask_add(self.zeros, self.kv_mask), seq_chunk),
self.dtype)
seq_update = F.depend(self.seq_update, mask)
seq_update = F.depend(seq_update, kv_mask)
mask = F.depend(mask, self.assign_add_count(self.seq_chunk, seq_update))
elif not self.use_ring_attention:
mask = self.casual_mask(tokens)
freqs_cis = self.freqs_mgr(seq_len, seq_chunk=seq_chunk)
if prefix_keys_values is not None:
prefix_length = prefix_keys_values[0].shape[2]
prefix_mask = Tensor(np.zeros((bs, 1, seq_len, prefix_length)), dtype=mask.dtype)
mask = self.concat((prefix_mask, mask))
# tokens: [bs, seq/1]
if input_embeds is not None:
h = self.cast(input_embeds, self.dtype)
else:
h = self.cast(self.tok_embeddings(tokens), self.dtype)
if not rmsnorm_compute_2d:
h = self.reshape(h, (bs, seq_len, self.hidden_size)) # h: [bs, seq/1, hidden_dim]
for i in range(self.num_layers):
prefix_kv = prefix_keys_values[i] if prefix_keys_values is not None else None
h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables,
slot_mapping=slot_mapping, prefix_keys_values=prefix_kv, q_seq_lens=q_seq_lens,
kv_mask=kv_mask, seq_chunk=seq_chunk, actual_seq_len=actual_seq_len)
if rmsnorm_compute_2d:
h = self.reshape(h, (bs * seq_len, -1))
output = self.norm_out(h)
return output
def clear_kv_cache(self):
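        """Reset the sequence-chunk counter and the cached causal mask.

        F.depend ties the Assign operations to the returned value so that graph
        optimization does not prune them.
        """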
zeros = 0.0
return_tuple = ()
return_tuple += (self.assign_count(self.seq_chunk, self.seq_zero),)
return_tuple += (self.assign_mask(self.casual_mask.mask_cache, self.mask_zeros),)
        return F.depend(zeros, return_tuple)


@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlamaForCausalLM(LlamaPreTrainedModel):
r"""
    Provide llama training loss or logits through the network.

    Args:
        config (LlamaConfig, optional): The config of the llama model. Default: ``None``.

    Inputs:
- **input_ids** (Tensor) - the indices of input sequence tokens in the vocabulary with data type Int64/Int32,
Tensor of shape :math:`(batch, seq\_length)`.
- **labels** (Tensor, optional) - the labels of inputs with data type Int64/Int32, Tensor of
shape :math:`(batch, seq\_length)` . Default: ``None``.
        - **input_position** (Tensor, optional) - the position ids of inputs (in incremental inference mode),
          an increasing sequence with data type Int64/Int32, Tensor of shape :math:`(batch, seq\_length)`.
          Default: ``None``.
        - **position_ids** (Tensor, optional) - the position ids of inputs, an increasing sequence with data type
          Int64/Int32, Tensor of shape :math:`(batch, seq\_length)`. Default: ``None``.
        - **attention_mask** (Tensor, optional) - padding mask of the input sentences, where 0 indicates a padding
          position, with data type Int64/Int32, Tensor of shape :math:`(batch, seq\_length)`. Default: ``None``.
- **input_embeds** (Tensor, optional) - the embedding of inputs with data type Float32/Float16, Tensor of
shape :math:`(batch, seq\_length, hidden\_size)`. Default: ``None``.
- **init_reset** (Tensor, optional) - A Bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True.
Tensor of shape :math:`(1)`. Default: ``Tensor([True])``.
        - **batch_valid_length** (Tensor, optional) - Int32 tensor of shape [batch_size], recording the number of
          tokens already computed for each sequence. Used for incremental prediction when use_past is True.
          Default: ``None``.
        - **batch_index** (Tensor, optional) - Deprecated argument. Will be deleted in the future. Default: ``None``.
        - **zactivate_len** (Tensor, optional) - Deprecated argument. Will be deleted in the future. Default: ``None``.
- **block_tables** (Tensor, optional) - Int64 type Tensor, store mapping tables for each sequence.
Default: ``None``.
- **slot_mapping** (Tensor, optional) - Int32 type Tensor, token cache physical slot index. Default: ``None``.
        - **prefix_keys_values** (Tensor, optional) - Deprecated argument. Will be deleted in the future.
          Default: ``None``.
        - **llm_boost_inputs** (Tensor, optional) - Deprecated argument. Will be deleted in the future.
          Default: ``None``.
        - **q_seq_lens** (Tensor, optional) - In parallel decoding, the query may be flattened.
          The Paged Attention operator needs `q_seq_lens` to obtain the length information. Default: ``None``.
        - **loss_mask** (Tensor, optional) - Float32/Int32 type tensor, determining whether the corresponding token
          position participates in the loss calculation: positions with value 1 contribute to the loss, positions
          with value 0 do not. Default: ``None``.
- **gather_index** (Tensor, optional) - Int32 type Tensor, used to obtain the last latent vector of
each sequence. Default: ``None``.
        - **seq_range** (Tensor, optional) - Int32 type Tensor, used to obtain the mask and positional encoding of
          valid tokens for each sequence. Default: ``None``.
- **actual_seq_len** (Tensor, optional) - Int32 type Tensor, used to automatically generate attention mask
          within FlashAttention for eod text. Default: ``None``.

    Outputs:
        Tensor. In training mode, the output contains the loss;
        in prediction mode, it contains the logits;
        in evaluation mode, it contains logits, tokens, and the input mask.

    Examples:
>>> from mindformers.models.llama import LlamaConfig, LlamaForCausalLM
>>> import mindspore as ms
>>> ms.set_context(mode=0)
>>> config = LlamaConfig(batch_size=2)
>>> network = LlamaForCausalLM(config=config)
>>> type(network)
<class 'mindformers.models.llama.llama.LlamaForCausalLM'>
>>> from mindformers import LlamaForCausalLM
>>> network = LlamaForCausalLM.from_pretrained('llama2_7b')
>>> type(network)
<class 'mindformers.models.llama.llama.LlamaForCausalLM'>
"""
_support_list = MindFormerBook.get_model_support_list()['llama']
@lazy_inline
def __init__(self, config: LlamaConfig = None):
        super().__init__(config, auto_prefix=True)
_check_config(config.parallel_config)
self.config = config
self.ignore_token_id = config.ignore_token_id
self.pad_token_id = config.pad_token_id
self.use_past = config.use_past
self.vocab_size = config.vocab_size
self.is_first_iteration = True
self.chunk_prefill = config.chunk_prefill
self.rl_config = config.rl_config
self.shape = P.Shape()
self.reshape = P.Reshape()
self.cast = P.Cast()
self.slice = P.StridedSlice()
self.not_equal = P.NotEqual()
self.mul = P.Mul()
self.add = P.Add()
self.ones = P.Ones()
self.gather = P.Gather(1)
self.prefill_gather_flatten = P.Gather()
self.sub_batch_valid_len = P.Sub()
self.predict_run_mode = get_predict_run_mode()
logger.info("Predict run mode: {}".format(self.predict_run_mode))
if self.predict_run_mode and self.config.is_dynamic:
logger.info("use_flash_attention is set to True when run_mode is predict and is_dynamic is True.")
self.config.use_flash_attention = True
self.model = LlamaModel(config=config)
self.lm_head = Linear(in_channels=config.hidden_size,
out_channels=config.vocab_size,
has_bias=False,
compute_dtype=config.compute_dtype,
param_init_type=config.param_init_type,
weight_init="normal") # meta default: xavier_normal
if config.tie_word_embeddings:
self.lm_head.weight = self.model.tok_embeddings.embedding_weight
mp = config.parallel_config.model_parallel
vocab_size = config.vocab_size
loss_parallel_config = copy.deepcopy(config.parallel_config)
if vocab_size % mp != 0:
logger.warning("The vocab size of Loss is: %s, it is not divide by model_parallel: %s",
vocab_size, mp)
logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1")
loss_parallel_config.model_parallel = 1
loss_parallel_config.data_parallel *= loss_parallel_config.context_parallel
check_for_nan_in_loss_and_grad = getattr(config, "check_for_nan_in_loss_and_grad", False)
calculate_per_token_loss = getattr(config, "calculate_per_token_loss", False)
self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config,
check_for_nan_in_loss_and_grad=check_for_nan_in_loss_and_grad,
calculate_per_token_loss=calculate_per_token_loss,
seq_split_num=config.parallel_config.seq_split_num)
dp = config.parallel_config.data_parallel
mp = config.parallel_config.model_parallel
cp = config.parallel_config.context_parallel
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
self.slice.shard(((dp, 1),))
self.not_equal.shard(((dp, 1), ()))
if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0):
self.lm_head.shard(strategy_matmul=((dp * cp, 1), (1, 1)))
else:
self.lm_head.shard(strategy_matmul=((dp * cp, 1), (mp, 1)))
elif _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
self.slice.shard(((dp, 1),))
self.not_equal.shard(((dp, 1), ()))
self.mul.shard(((dp, 1), (dp, 1)))
self.add.shard(((dp, 1), ()))
self.gather.shard(((dp, 1, 1), (dp,)))
self.prefill_gather_flatten.shard(((dp, 1, 1), (dp,)))
self.sub_batch_valid_len.shard(((1,), ()))
if config.parallel_config.vocab_emb_dp or (vocab_size % mp != 0):
if self.rl_config is not None:
self.lm_head.shard(strategy_matmul=((dp * cp * mp, 1), (1, 1)))
else:
self.lm_head.shard(strategy_matmul=((dp * cp, 1), (1, 1)))
else:
self.lm_head.shard(strategy_matmul=((dp * cp, 1), (mp, 1)))
if config.parallel_config.pipeline_stage > 1:
self.lm_head.pipeline_stage = config.parallel_config.pipeline_stage - 1
self.load_checkpoint(config)
self.parallel_decoding = config.parallel_decoding_params is not None
self.input_sliced_sig = config.input_sliced_sig
def to_embeddings(self, tokens):
"""return embedding tokens"""
return self.model.tok_embeddings(tokens)
# pylint: disable=W0613
def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
"""Get Llama model input tuple for transform ckpt."""
input_ids = Tensor(input_ids, mstype.int32)
labels = Tensor(kwargs["labels"]) if "labels" in kwargs else None
bs, seq = input_ids.shape[0], input_ids.shape[1]
slot_mapping = Tensor(np.ones(shape=tuple([bs * seq])), mstype.int32)
prefix_keys_values = Tensor(kwargs["prefix_keys_values"]) if "prefix_keys_values" in kwargs else None
position_ids = Tensor(np.zeros(shape=tuple([bs, seq])), mstype.int32) if self.parallel_decoding else None
mask = Tensor(np.zeros(shape=tuple([seq, seq])), mstype.float16) if self.parallel_decoding else None
q_seq_lens = Tensor(np.zeros(shape=tuple([bs])), mstype.int32) if self.parallel_decoding else None
outputs = (input_ids, labels, None, position_ids, mask, None, None, None, None, None, None, slot_mapping,
prefix_keys_values, None, q_seq_lens)
return outputs
def set_dynamic_inputs(self, **kwargs):
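        """Register dynamic-shape input placeholders for graph compilation."""
        # A dimension declared as None is resolved at runtime, so a single
        # compiled graph can serve varying batch sizes and sequence lengths.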
dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
        have_prefix_keys_values = kwargs.get("have_prefix_keys_values", False)
dynamic_position_ids = Tensor(shape=[None, None], dtype=mstype.int32) if self.parallel_decoding else None
dynamic_mask = Tensor(shape=[None, None], dtype=mstype.float16) if self.parallel_decoding else None
dynamic_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) if self.parallel_decoding else None
if have_prefix_keys_values:
dynamic_prefix_keys_values = Tensor(shape=[2, None, None, None, None], dtype=mstype.float16)
self.set_inputs(dynamic_input_ids, None, None, dynamic_position_ids, dynamic_mask, None, None,
dynamic_batch_valid_length, None, None, dynamic_block_tables,
dynamic_slot_mapping, dynamic_prefix_keys_values, None, dynamic_q_seq_lens, None,
None, None, None)
elif self.use_past:
self.set_inputs(dynamic_input_ids, None, None, dynamic_position_ids, dynamic_mask, None, None,
dynamic_batch_valid_length, None, None, dynamic_block_tables,
dynamic_slot_mapping, None, None, dynamic_q_seq_lens, None, None, None, None)
elif kwargs.get("pre_gather", False):
self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
dynamic_batch_valid_length, None, None, None, None, None)
else:
self.set_inputs(dynamic_input_ids, None, None, None, None, None, None,
None, None, None, None, None, None, None, None, None, None, None, None)
logger.info("Set dynamic input for llama.")
def add_flags_custom(self, is_first_iteration):
"""Add customized attributes for specific cells in the model."""
self.add_flags(is_first_iteration=is_first_iteration)
self.model.add_flags(is_first_iteration=is_first_iteration)
for layer in self.model.layers:
layer.add_flags(is_first_iteration=is_first_iteration)
layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration)
layer.attention.infer_attention.paged_attention_mgr.add_flags(is_first_iteration=is_first_iteration)
def pre_gather_func(self, pre_gather, output, batch_valid_length, gather_index=None):
"""Pre gather operation in infer mode."""
        if not pre_gather:
            return output
        if self.chunk_prefill and self.is_first_iteration:
            output = output.reshape(-1, output.shape[-1])
            output = output[self.sub_batch_valid_len(gather_index, 1)]
        elif self.config.is_dynamic:
            batch_valid_length = mint.cumsum(batch_valid_length, 0)
            output = self.prefill_gather_flatten(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
        else:
            output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1)
return output
# pylint: disable=W0613
def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None,
input_embeds=None, init_reset=None, batch_valid_length=None, batch_index=None, zactivate_len=None,
block_tables=None, slot_mapping=None, prefix_keys_values=None, llm_boost_inputs=None, q_seq_lens=None,
loss_mask=None, gather_index=None, seq_range=None, actual_seq_len=None):
r"""LlamaForCausalLM forward."""
has_loss_mask = loss_mask is not None
input_sliced_sig = self.input_sliced_sig
if self.training and input_sliced_sig and labels is None:
input_sliced_sig = False
bsz, seqlen = self.shape(input_ids)
if self.use_past:
if not isinstance(batch_valid_length, Tensor):
batch_valid_length = self.ones((bsz,), mstype.int32)
if not input_sliced_sig and self.training:
tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1))
if has_loss_mask:
loss_mask = self.slice(loss_mask, (0, 0), (bsz, seqlen - 1), (1, 1))
else:
tokens = input_ids
if batch_valid_length is not None:
batch_valid_length = self.reshape(batch_valid_length, (-1,))
        output = self.model(tokens, input_embeds, batch_valid_length, batch_index, zactivate_len, block_tables,
                            slot_mapping, prefix_keys_values, attention_mask, position_ids, q_seq_lens,
                            seq_range, actual_seq_len)
pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None
output = self.pre_gather_func(pre_gather, output, batch_valid_length, gather_index)
logits = self.lm_head(output)
input_mask = loss_mask if has_loss_mask \
else self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32)
if self.rl_config is not None:
return logits
if labels is None:
labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1))
else:
if labels.ndim > 1:
if not input_sliced_sig and self.training:
labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1))
if not has_loss_mask:
label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32)
input_mask = self.mul(input_mask, label_mask)
if not self.training:
logits = self.cast(logits, mstype.float32)
if self.predict_run_mode:
logits = self.reshape(logits, (-1, logits.shape[-1]))
return logits
return logits, tokens, input_mask
if logits.ndim > 2:
logits = self.reshape(logits, (-1, logits.shape[-1]))
logits = self.cast(logits, mstype.float32)
labels = self.reshape(labels, (-1,))
input_mask = self.reshape(input_mask, (-1,))
loss = self.loss(logits, labels, input_mask)
return loss
def kvcache(self, layer_idx):
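        """Return the paged-attention key and value cache tensors of the given layer."""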
key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache
value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache
return key_cache, value_cache
@classmethod
def convert_name(cls, weight_name):
"""convert HuggingFace weight name to MindFormers weight name"""
origin_name = weight_name
weight_name = weight_name.replace('embed_tokens.', 'tok_embeddings.')
weight_name = weight_name.replace('.self_attn.q_proj.', '.attention.wq.')
weight_name = weight_name.replace('.self_attn.k_proj.', '.attention.wk.')
weight_name = weight_name.replace('.self_attn.v_proj.', '.attention.wv.')
weight_name = weight_name.replace('.self_attn.o_proj.', '.attention.wo.')
weight_name = weight_name.replace('.mlp.gate_proj.', '.feed_forward.w1.')
weight_name = weight_name.replace('.mlp.down_proj.', '.feed_forward.w2.')
weight_name = weight_name.replace('.mlp.up_proj.', '.feed_forward.w3.')
weight_name = weight_name.replace('.input_layernorm.', '.attention_norm.')
weight_name = weight_name.replace('.post_attention_layernorm.', '.ffn_norm.')
weight_name = weight_name.replace('.norm.', '.norm_out.')
weight_name = weight_name.replace('output.', 'lm_head.')
weight_name = weight_name.replace('.tok_embeddings.weight', '.tok_embeddings.embedding_weight')
if weight_name == origin_name:
logger.warning(f"weight name '{weight_name}' does not change after conversion. "
f"Please check if it is as expected.")
return weight_name
@classmethod
def convert_weight_dict(cls, source_dict, **kwargs):
"""convert HuggingFace weight dict to MindFormers weight dict"""
model_config = kwargs.get("model_config")
        qkv_concat = getattr(model_config, 'qkv_concat', False)
target_dict = {}
wq_keys = []
wk_keys = []
wv_keys = []
w1_keys = []
w3_keys = []
for k, v in source_dict.items():
k = cls.convert_name(k)
target_dict.update({k: v})
if qkv_concat:
part = k.split('.')
if part[-2] == 'wq':
wq_keys.append(k)
if part[-2] == 'wk':
wk_keys.append(k)
if part[-2] == 'wv':
wv_keys.append(k)
if part[-2] == 'w1':
w1_keys.append(k)
if part[-2] == 'w3':
w3_keys.append(k)
if qkv_concat:
qkv_dict = kwargs.get('qkv_dict', None)
if not isinstance(qkv_dict, DictProxy):
            raise ValueError(f'qkv_dict must be a DictProxy when qkv_concat is True, but got {qkv_dict}.')
condition = kwargs.get('condition', None)
if not isinstance(condition, Condition):
            raise ValueError(f'condition must be a Condition when qkv_concat is True, but got {condition}.')
_concat_qkv_weight(wq_keys, wk_keys, wv_keys, model_config, qkv_dict, condition, target_dict)
_concat_ffn_weight(w1_keys, w3_keys, model_config, qkv_dict, condition, target_dict)
return target_dict
@classmethod
def convert_map_dict(cls, source_dict, **kwargs):
"""convert HuggingFace map dict to MindFormers map dict"""
qkv_concat = kwargs.pop("qkv_concat", False)
target_dict = {}
wq_keys = []
w1_keys = []
for k, v in source_dict.items():
k = cls.convert_name(k)
target_dict.update({k: v})
if qkv_concat:
part = k.split('.')
if part[-2] == 'wq':
wq_keys.append(k)
if part[-2] == 'w1':
w1_keys.append(k)
if qkv_concat:
for wq_key in wq_keys:
wk_key = wq_key.replace('wq', 'wk')
wv_key = wq_key.replace('wq', 'wv')
wq_value = target_dict.pop(wq_key)
target_dict.pop(wk_key)
target_dict.pop(wv_key)
w_qkv_key = wq_key.replace('wq', 'w_qkv')
w_qkv_value = wq_value
target_dict.update({w_qkv_key: w_qkv_value})
for w1_key in w1_keys:
w3_key = w1_key.replace('w1', 'w3')
w1_value = target_dict.pop(w1_key)
target_dict.pop(w3_key)
w_gate_hidden_key = w1_key.replace('w1', 'w_gate_hidden')
w_gate_hidden_value = w1_value
target_dict.update({w_gate_hidden_key: w_gate_hidden_value})
return target_dict
@classmethod
def obtain_qkv_ffn_concat_keys(cls):
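        """Return the parameter-name keys produced by qkv/ffn weight concatenation."""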
qkv_key = "w_qkv"
ffn_key = "w_gate_hidden"
concat_keys = [qkv_key, ffn_key]
logger.info(f"{cls.__name__} qkv/ffn concat keys are {concat_keys}")
return concat_keys
@classmethod
def obtain_name_map(cls, load_checkpoint_files):
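        """Build a mapping from converted MindFormers weight names to the original HuggingFace names."""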
name_map = dict()
for checkpoint_file in load_checkpoint_files:
with safe_open(checkpoint_file, framework="np") as f:
for k in f.keys():
name_map.update({cls.convert_name(k): k})
return name_map
def clear_kv_cache(self):
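        """Delegate KV cache clearing to the backbone model."""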
        return self.model.clear_kv_cache()


def _concat_qkv_weight(wq_keys, wk_keys, wv_keys, model_config, qkv_dict, condition, target_dict):
"""concat qkv weight from dicts"""
from mindformers.utils.convert_utils import qkv_concat_hf2mg
num_heads = model_config.num_heads
n_kv_heads = model_config.n_kv_heads or num_heads
hidden_size = model_config.hidden_size
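    # Checkpoint weights may be split across files handled by different worker
    # processes, so q/k/v pieces that cannot be paired locally are exchanged
    # through the shared `qkv_dict`, with `condition` providing synchronization.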
# pop extra weight to shared dict if there is no corresponding weight for concat in the target dict
for wk_key in wk_keys:
wq_key = wk_key.replace('wk', 'wq')
if wq_key not in wq_keys:
with condition:
qkv_dict[wk_key] = target_dict.pop(wk_key) # add extra weight to shared dict
condition.notify_all()
for wv_key in wv_keys:
wq_key = wv_key.replace('wv', 'wq')
if wq_key not in wq_keys:
with condition:
qkv_dict[wv_key] = target_dict.pop(wv_key) # add extra weight to shared dict
condition.notify_all()
# concat qkv
for wq_key in wq_keys:
wk_key = wq_key.replace('wq', 'wk')
wv_key = wq_key.replace('wq', 'wv')
wq_value = target_dict.pop(wq_key)
wk_value = target_dict.pop(wk_key, None)
wv_value = target_dict.pop(wv_key, None)
# get missing weight from shared dict
if wk_value is None:
with condition:
condition.wait_for(lambda: wk_key in qkv_dict.keys())
wk_value = qkv_dict.pop(wk_key)
if wv_value is None:
with condition:
condition.wait_for(lambda: wv_key in qkv_dict.keys())
wv_value = qkv_dict.pop(wv_key)
w_qkv_key = wq_key.replace('wq', 'w_qkv')
w_qkv_value = np.concatenate((wq_value, wk_value, wv_value), 0)
# qkv weight format: hf -> mg
w_qkv_value_mg = qkv_concat_hf2mg(w_qkv_value, num_heads, n_kv_heads, hidden_size)
        target_dict.update({w_qkv_key: w_qkv_value_mg})


def _concat_ffn_weight(w1_keys, w3_keys, model_config, qkv_dict, condition, target_dict):
"""concat ffn weight from dicts"""
from mindformers.utils.convert_utils import ffn_concat_hf2mg
intermediate_size = model_config.intermediate_size
ffn_dim_multiplier = model_config.ffn_dim_multiplier
multiple_of = model_config.multiple_of or 256
ffn_hidden_size = model_config.hidden_size * 4
if intermediate_size is not None:
ffn_hidden_size = intermediate_size
else:
if ffn_dim_multiplier is not None:
ffn_hidden_size = int((ffn_dim_multiplier + 0.01) * ffn_hidden_size)
ffn_hidden_size = int(2 * ffn_hidden_size / 3)
ffn_hidden_size = multiple_of * \
((ffn_hidden_size + multiple_of - 1) // multiple_of)
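    # Worked example with llama-7b-like settings (hidden_size=4096,
    # intermediate_size=None, ffn_dim_multiplier=None, multiple_of=256,
    # shown only as an illustration):
    #   4 * 4096 = 16384 -> int(2 * 16384 / 3) = 10922
    #   -> 256 * ((10922 + 255) // 256) = 256 * 43 = 11008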
# pop extra weight to shared dict if there is no corresponding weight for concat in the target dict
for w3_key in w3_keys:
w1_key = w3_key.replace('w3', 'w1')
if w1_key not in w1_keys:
with condition:
qkv_dict[w3_key] = target_dict.pop(w3_key) # add extra weight to shared dict
condition.notify_all()
# concat ffn
for w1_key in w1_keys:
w3_key = w1_key.replace('w1', 'w3')
w1_value = target_dict.pop(w1_key)
w3_value = target_dict.pop(w3_key, None)
# get missing weight from shared dict
if w3_value is None:
with condition:
condition.wait_for(lambda: w3_key in qkv_dict.keys())
w3_value = qkv_dict.pop(w3_key)
w_gate_hidden_key = w1_key.replace('w1', 'w_gate_hidden')
w_gate_hidden_value = np.concatenate((w1_value, w3_value), 0)
# ffn weight format: hf -> mg
w_gate_hidden_value_mg = ffn_concat_hf2mg(w_gate_hidden_value, ffn_hidden_size)
target_dict.update({w_gate_hidden_key: w_gate_hidden_value_mg})