# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Llama Config API."""
from typing import Optional, Union
from mindspore._checkparam import args_type_check
from mindformers.modules.transformer.moe import MoEConfig
from mindformers.modules.transformer.transformer import default_transformer_config, \
TransformerOpParallelConfig, default_moe_config
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from mindformers.models.configuration_utils import PretrainedConfig
from mindformers.models.utils import convert_mstype
from mindformers.mindformer_book import MindFormerBook
__all__ = ['LlamaConfig']


@MindFormerRegister.register(MindFormerModuleType.CONFIG)
class LlamaConfig(PretrainedConfig):
    """
    Llama config class which defines the model size.

    Args:
        batch_size (int, optional): Batch size for input data, used in predict. Default: ``1``.
        seq_length (int, optional): The sequence length of input_ids. Default: ``2048``.
        hidden_size (int, optional): Dimensionality of the hidden states. Default: ``4096``.
        num_layers (int, optional): Number of Transformer layers. Default: ``32``.
        num_heads (int, optional): Number of attention heads for each attention layer. Default: ``32``.
        n_kv_heads (int, optional): Number of key/value heads for multi-group head attention. Default: ``None``.
        max_position_embedding (int, optional):
            Maximum position embedding length. When not set (``None`` or ``0``), ``seq_length`` is used.
            Default: ``None``.
        intermediate_size (int, optional): Hidden size of the feed-forward layer. Default: ``None``.
        vocab_size (int, optional): Vocabulary size of the model; should match the tokenizer. Default: ``32000``.
        multiple_of (int, optional): The SwiGLU hidden layer size is made a multiple of this value. Default: ``256``.
        ffn_dim_multiplier (int, optional): Multiplier of the feed-forward layer dimension. Default: ``None``.
        rms_norm_eps (float, optional): The epsilon value added to the RMSNorm denominator. Default: ``1e-5``.
        bos_token_id (int, optional): The id of the *beginning-of-sequence* token. Default: ``1``.
        eos_token_id (int, optional): The id of the *end-of-sequence* token. Default: ``2``.
        pad_token_id (int, optional): The id of the *padding* token. Default: ``0``.
        ignore_token_id (int, optional): The id of the *ignoring* token. Default: ``-100``.
        theta (float, optional): Base of the rotary position embedding (RoPE) frequencies. Default: ``10000.0``.
        compute_dtype (str, optional): Linear layer compute dtype. Default: ``float16``.
        layernorm_compute_type (str, optional): Layernorm compute dtype. Default: ``float32``.
        softmax_compute_type (str, optional): Softmax compute dtype. Default: ``float32``.
        rotary_dtype (str, optional): RoPE compute dtype. Default: ``float32``.
        param_init_type (str, optional): Parameter initialization dtype. Default: ``float16``.
        embedding_init_type (str, optional):
            Embedding parameter initialization dtype. Falls back to ``param_init_type`` when ``None``.
            Default: ``None``.
        qkv_has_bias (bool, optional):
            Whether the Query, Key, and Value projection has bias. Default: ``False``.
        qkv_concat (bool, optional):
            Whether to concatenate the Query, Key, and Value projections. Default: ``False``.
        parallel_config (Union[dict, TransformerOpParallelConfig], optional):
            The parallel configuration; a dict is converted to `TransformerOpParallelConfig`.
            Default: ``default_transformer_config``, an instance of `TransformerOpParallelConfig` with default args.
        moe_config (Union[dict, MoEConfig], optional):
            The MoE configuration; a dict is converted to `MoEConfig`. Default: ``default_moe_config``.
        use_past (bool, optional):
            Whether the model should use the past last key/values attentions
            (if applicable to the model) to speed up decoding. Default: ``False``.
        extend_method (str, optional):
            The method used to extend the sequence length at inference.
            Default: ``"None"`` (the string, not the ``None`` object).
        scaling_factor (float, optional): Scaling factor associated with ``extend_method``. Default: ``1.0``.
        is_dynamic (bool, optional): Whether the model runs with dynamic shapes. Default: ``False``.
        use_rope_slice (bool, optional): Whether to enable RoPE slicing. Default: ``False``.
        use_flash_attention (bool, optional): Whether to enable flash attention ops. Default: ``False``.
        use_ring_attention (bool, optional): Whether to enable ring attention ops. Default: ``False``.
        use_attn_mask_compression (bool, optional):
            Whether to enable attention mask compression. Default: ``False``.
        parallel_optimizer (bool, optional): Whether optimizer parallelism is enabled. Default: ``False``.
        fine_grain_interleave (int, optional): Fine-grained interleave number. Default: ``1``.
        pp_interleave_num (int, optional): Pipeline-parallel interleave number. Default: ``1``.
        offset (int, optional): Offset of transformer layer when set pipeline stage number. Default: ``0``.
        checkpoint_name_or_path (str, optional):
            Checkpoint path or name used to load to the network. Default: ``""``.
        repetition_penalty (float, optional):
            The parameter for repetition penalty. 1.0 means no penalty.
            See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`_ for more details. Default: ``1.0``.
        max_decode_length (int, optional):
            The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
            `max_new_tokens`. Its effect is overridden by `max_new_tokens`, if also set. Default: ``1024``.
        block_size (int, optional):
            The maximum number of tokens in one block can have when using paged attention. Default: ``16``.
        num_blocks (int, optional):
            The maximum number of blocks when using paged attention. Default: ``512``.
        top_k (int, optional):
            The number of highest probability vocabulary tokens to keep for top-k-filtering. Default: ``5``.
        top_p (float, optional):
            If set to float < 1, only the smallest set of most probable tokens with probabilities
            that add up to `top_p` or higher are kept for generation. Default: ``1.0``.
        do_sample (bool, optional):
            Whether to use sampling; use greedy decoding otherwise. Default: ``True``.
        quant_config (dict, optional): Quantization configuration. Default: ``None``.
        tie_word_embeddings (bool, optional):
            Whether to tie input and output embeddings. Default: ``False``.
        llm_backend (str, optional): LLM boost backend. Default: ``""``.
        fused_rms_norm (bool, optional): Whether to use the fused RMSNorm operator. Default: ``True``.
        **kwargs: Forwarded to `PretrainedConfig`; ``parallel_decoding_params`` is also read from here.

    Returns:
        LlamaConfig, a LlamaConfig instance.

    Examples:
        >>> from mindformers.models import LlamaConfig
        >>> config = LlamaConfig(num_layers=2, seq_length=1024)
        >>> print(config)
        LlamaConfig {
          "batch_size": 1,
          "block_size": 16,
          "bos_token_id": 1,
          "checkpoint_name_or_path": "",
          "compute_dtype": "float16",
          "do_sample": true,
          "embedding_init_type": "float16",
          "eos_token_id": 2,
          "extend_method": "None",
          "ffn_dim_multiplier": null,
          "fine_grain_interleave": 1,
          "hidden_size": 4096,
          "ignore_token_id": -100,
          "intermediate_size": null,
          "is_dynamic": false,
          "layernorm_compute_type": "float32",
          "llm_backend": "",
          "max_decode_length": 1024,
          "max_position_embedding": 1024,
          "mindformers_version": "dev",
          "model_type": "llama",
          "multiple_of": 256,
          "n_kv_heads": null,
          "num_blocks": 512,
          "num_heads": 32,
          "num_layers": 2,
          "offset": 0,
          "pad_token_id": 0,
          "parallel_decoding_params": null,
          "parallel_optimizer": false,
          "param_init_type": "float16",
          "pp_interleave_num": 1,
          "qkv_concat": false,
          "qkv_has_bias": false,
          "quant_config": null,
          "repetition_penalty": 1.0,
          "rms_norm_eps": 1e-05,
          "rotary_dtype": "float32",
          "scaling_factor": 1.0,
          "seq_length": 1024,
          "softmax_compute_type": "float32",
          "theta": 10000.0,
          "tie_word_embeddings": false,
          "top_k": 5,
          "top_p": 1.0,
          "use_attn_mask_compression": false,
          "use_flash_attention": false,
          "use_past": false,
          "use_ring_attention": false,
          "use_rope_slice": false,
          "vocab_size": 32000
        }
    """

    model_type = "llama"
    _support_list = MindFormerBook.get_config_support_list()['llama']

    @args_type_check(parallel_config=(dict, TransformerOpParallelConfig))
    def __init__(self,
                 batch_size: int = 1,
                 seq_length: int = 2048,
                 hidden_size: int = 4096,
                 num_layers: int = 32,
                 num_heads: int = 32,
                 n_kv_heads: Optional[int] = None,
                 max_position_embedding: Optional[int] = None,
                 intermediate_size: Optional[int] = None,
                 vocab_size: int = 32000,
                 multiple_of: int = 256,
                 ffn_dim_multiplier: Optional[int] = None,
                 rms_norm_eps: float = 1e-5,
                 bos_token_id: int = 1,
                 eos_token_id: int = 2,
                 pad_token_id: int = 0,
                 ignore_token_id: int = -100,
                 theta: float = 10000.0,
                 compute_dtype: str = "float16",
                 layernorm_compute_type: str = "float32",
                 softmax_compute_type: str = "float32",
                 rotary_dtype: str = "float32",
                 param_init_type: str = "float16",
                 embedding_init_type: Optional[str] = None,
                 qkv_has_bias: bool = False,
                 qkv_concat: bool = False,
                 parallel_config: Union[dict, TransformerOpParallelConfig] = default_transformer_config,
                 moe_config: Union[dict, MoEConfig] = default_moe_config,
                 use_past: bool = False,
                 extend_method: str = "None",
                 scaling_factor: float = 1.0,
                 is_dynamic: bool = False,
                 use_rope_slice: bool = False,
                 use_flash_attention: bool = False,
                 use_ring_attention: bool = False,
                 use_attn_mask_compression: bool = False,
                 parallel_optimizer: bool = False,
                 fine_grain_interleave: int = 1,
                 pp_interleave_num: int = 1,
                 offset: int = 0,
                 checkpoint_name_or_path: str = "",
                 repetition_penalty: float = 1.0,
                 max_decode_length: int = 1024,
                 block_size: int = 16,
                 num_blocks: int = 512,
                 top_k: int = 5,
                 top_p: float = 1.0,
                 do_sample: bool = True,
                 quant_config: Optional[dict] = None,
                 tie_word_embeddings: bool = False,
                 llm_backend: str = "",
                 fused_rms_norm: bool = True,
                 **kwargs):
        """Initialize the Llama config; see the class docstring for parameter descriptions.

        Note:
            ``vocab_size`` is defined later by the tokenizer; ``multiple_of`` makes the SwiGLU
            hidden layer size a multiple of a large power of 2.
        """
        super().__init__(**kwargs)

        # Accept plain dicts as well as already-built config objects.
        # NOTE(review): the parallel_config default is the shared module-level
        # `default_transformer_config` and is stored by reference (not copied), so
        # mutating self.parallel_config in place also mutates that default.
        if isinstance(parallel_config, dict):
            parallel_config = TransformerOpParallelConfig(**parallel_config)
        if isinstance(moe_config, dict):
            moe_config = MoEConfig(**moe_config)

        # Model shape.
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Falsy values (None or 0) fall back to the sequence length.
        self.max_position_embedding = max_position_embedding or seq_length
        self.intermediate_size = intermediate_size
        self.multiple_of = multiple_of
        self.n_kv_heads = n_kv_heads
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.rms_norm_eps = rms_norm_eps
        self.qkv_concat = qkv_concat

        # Dtypes are given as strings and converted to MindSpore types.
        self.param_init_type = convert_mstype(param_init_type)
        if embedding_init_type is not None:
            self.embedding_init_type = convert_mstype(embedding_init_type)
        else:
            # Without an explicit embedding dtype, reuse the parameter init dtype.
            self.embedding_init_type = self.param_init_type
        self.qkv_has_bias = qkv_has_bias
        self.layernorm_compute_type = convert_mstype(layernorm_compute_type)
        self.softmax_compute_type = convert_mstype(softmax_compute_type)
        self.rotary_dtype = convert_mstype(rotary_dtype)
        self.compute_dtype = convert_mstype(compute_dtype)

        # Parallelism, checkpoint and special tokens.
        self.parallel_config = parallel_config
        self.moe_config = moe_config
        self.checkpoint_name_or_path = checkpoint_name_or_path
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.ignore_token_id = ignore_token_id

        # Attention / position-embedding feature switches.
        self.use_past = use_past
        self.extend_method = extend_method
        self.scaling_factor = scaling_factor
        self.is_dynamic = is_dynamic
        self.use_rope_slice = use_rope_slice
        self.use_flash_attention = use_flash_attention
        self.use_ring_attention = use_ring_attention
        self.use_attn_mask_compression = use_attn_mask_compression
        self.parallel_optimizer = parallel_optimizer
        self.fine_grain_interleave = fine_grain_interleave
        self.pp_interleave_num = pp_interleave_num
        self.offset = offset

        # Generation settings.
        self.repetition_penalty = repetition_penalty
        self.max_decode_length = max_decode_length
        self.top_k = top_k
        self.top_p = top_p
        self.do_sample = do_sample
        self.theta = theta

        # Paged attention, quantization and backend options.
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.quant_config = quant_config
        self.tie_word_embeddings = tie_word_embeddings
        self.llm_backend = llm_backend
        # Not a named parameter: read from the extra keyword arguments (None if absent).
        self.parallel_decoding_params = kwargs.get('parallel_decoding_params')
        self.fused_rms_norm = fused_rms_norm