mindformers.pet.models.LoraModel

class mindformers.pet.models.LoraModel(config: LoraConfig, base_model: PreTrainedModel)

LoRA model for LLMs. It provides a flexible and efficient way to adapt and optimize a pretrained model by adding LoRA structures to the base pretrained model.

Parameters:
  • config (LoraConfig) - Configuration of the LoRA parameter-efficient tuning (PET) algorithm; see the brief sketch after this list.

  • base_model (PreTrainedModel) - The pretrained model to be tuned.
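
A minimal configuration sketch is shown below. The hyperparameter names lora_rank, lora_alpha and lora_dropout are assumed from typical LoraConfig usage and may differ across MindFormers versions; target_modules follows the regular-expression form used in the example further down.

>>> from mindformers.pet import LoraConfig
>>> # Assumed hyperparameter names; verify them against your MindFormers version.
>>> lora_config = LoraConfig(lora_rank=8, lora_alpha=16, lora_dropout=0.05,
...                          target_modules='.*wq|.*wv')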

Returns:

A LoraModel instance.

Examples:

>>> import mindspore as ms
>>> from mindformers.pet import LoraModel, LoraConfig
>>> from mindformers.models import LlamaConfig, LlamaForCausalLM
>>> ms.set_context(mode=0)  # 0 corresponds to GRAPH_MODE
>>> # Build a small 2-layer Llama model and wrap its attention projections with LoRA.
>>> config = LlamaConfig(num_layers=2)
>>> lora_config = LoraConfig(target_modules='.*wq|.*wk|.*wv|.*wo')
>>> model = LlamaForCausalLM(config)
>>> lora_model = LoraModel(lora_config, model)
>>> print(lora_model.lora_model)
LlamaForCausalLM<
  (model): LlamaModel<
    (freqs_mgr): FreqsMgr<>
    (casual_mask): LowerTriangularMaskWithDynamic<>
    (tok_embeddings): LlamaEmbedding<>
    (layers): CellList<
      (0): LLamaDecodeLayer<
        (ffn_norm): LlamaRMSNorm<>
        (attention_norm): LlamaRMSNorm<>
        (attention): LLamaAttention<
          (wq): LoRADense<
            input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wk): LoRADense<
            input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wv): LoRADense<
            input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wo): LoRADense<
            input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (apply_rotary_emb): RotaryEmbedding<>
          >
        (feed_forward): LlamaFeedForward<
          (w1): Linear<
            (activation): LlamaSiLU<>
            >
          (w2): Linear<>
          (w3): Linear<>
          >
        >
      (1): LLamaDecodeLayer<
        (ffn_norm): LlamaRMSNorm<>
        (attention_norm): LlamaRMSNorm<>
        (attention): LLamaAttention<
          (wq): LoRADense<
            input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wk): LoRADense<
            input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wv): LoRADense<
            input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wo): LoRADense<
            input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (apply_rotary_emb): RotaryEmbedding<>
          >
        (feed_forward): LlamaFeedForward<
          (w1): Linear<
            (activation): LlamaSiLU<>
            >
          (w2): Linear<>
          (w3): Linear<>
          >
        >
      >
    (norm_out): LlamaRMSNorm<>
    >
  (lm_head): Linear<>
  (loss): CrossEntropyLoss<
    (_log_softmax): _LogSoftmax<>
    (_nllloss): _NLLLoss<>
    >
  >
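
As a quick sanity check after injection, the snippet below lists the added LoRA parameters using ordinary MindSpore Cell APIs (get_parameters); it is not part of the LoraModel API, and the assumption that the injected parameter names contain the substring "lora" is based on the LoRADense cells shown in the structure above.

>>> lora_params = [p.name for p in lora_model.lora_model.get_parameters()
...                if 'lora' in p.name.lower()]
>>> print(len(lora_params) > 0)
True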