mindformers.pet.models.LoraModel
- class mindformers.pet.models.LoraModel(config: LoraConfig, base_model: PreTrainedModel)
LoRA model for large language models (LLMs). It provides a flexible and efficient way to adapt and optimize a pre-trained model by adding the LoRA structure to the base pre-trained model.
- Parameters:
config (LoraConfig) - Configuration of the parameter-efficient tuning (PET) algorithm, here the LoRA configuration.
base_model (PreTrainedModel) - The pre-trained model to be tuned.
- Returns:
A LoraModel instance.
Examples:
>>> import mindspore as ms
>>> from mindformers.pet import LoraModel, LoraConfig
>>> from mindformers.models import LlamaConfig, LlamaForCausalLM
>>> ms.set_context(mode=0)
>>> config = LlamaConfig(num_layers=2)
>>> lora_config = LoraConfig(target_modules='.*wq|.*wk|.*wv|.*wo')
>>> model = LlamaForCausalLM(config)
>>> lora_model = LoraModel(lora_config, model)
>>> print(lora_model.lora_model)
LlamaForCausalLM<
  (model): LlamaModel<
    (freqs_mgr): FreqsMgr<>
    (casual_mask): LowerTriangularMaskWithDynamic<>
    (tok_embeddings): LlamaEmbedding<>
    (layers): CellList<
      (0): LLamaDecodeLayer<
        (ffn_norm): LlamaRMSNorm<>
        (attention_norm): LlamaRMSNorm<>
        (attention): LLamaAttention<
          (wq): LoRADense<input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wk): LoRADense<input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wv): LoRADense<input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wo): LoRADense<input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (apply_rotary_emb): RotaryEmbedding<>
          >
        (feed_forward): LlamaFeedForward<
          (w1): Linear<
            (activation): LlamaSiLU<>
            >
          (w2): Linear<>
          (w3): Linear<>
          >
        >
      (1): LLamaDecodeLayer<
        (ffn_norm): LlamaRMSNorm<>
        (attention_norm): LlamaRMSNorm<>
        (attention): LLamaAttention<
          (wq): LoRADense<input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wk): LoRADense<input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wv): LoRADense<input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (wo): LoRADense<input_channels=4096, output_channels=4096
            (lora_dropout): Dropout<p=0.01>
            >
          (apply_rotary_emb): RotaryEmbedding<>
          >
        (feed_forward): LlamaFeedForward<
          (w1): Linear<
            (activation): LlamaSiLU<>
            >
          (w2): Linear<>
          (w3): Linear<>
          >
        >
      >
    (norm_out): LlamaRMSNorm<>
    >
  (lm_head): Linear<>
  (loss): CrossEntropyLoss<
    (_log_softmax): _LogSoftmax<>
    (_nllloss): _NLLLoss<>
    >
  >
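As a follow-up, the sketch below shows how the LoRA hyper-parameters could be set explicitly and how the injected low-rank parameters might be collected for an optimizer. It is a minimal sketch, assuming that LoraConfig accepts lora_rank, lora_alpha and lora_dropout alongside target_modules, and that the parameters added by LoRADense carry "lora" in their names; the values shown are illustrative only.
>>> import mindspore as ms
>>> from mindformers.pet import LoraModel, LoraConfig
>>> from mindformers.models import LlamaConfig, LlamaForCausalLM
>>> ms.set_context(mode=0)
>>> # Illustrative hyper-parameters; lora_rank/lora_alpha/lora_dropout are assumed
>>> # to be accepted by LoraConfig in addition to target_modules.
>>> lora_config = LoraConfig(lora_rank=8, lora_alpha=16, lora_dropout=0.05,
...                          target_modules='.*wq|.*wk|.*wv|.*wo')
>>> model = LlamaForCausalLM(LlamaConfig(num_layers=2))
>>> lora_model = LoraModel(lora_config, model)
>>> # Collect the injected low-rank parameters (assumed to contain 'lora' in their
>>> # names) so that only they are passed to the optimizer during fine-tuning.
>>> lora_params = [p for p in lora_model.lora_model.trainable_params() if 'lora' in p.name]
In a typical LoRA fine-tuning setup only these low-rank parameters are updated, while the original weights of the base pre-trained model remain frozen.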