mindformers.pet.pet_config.LoraConfig

class mindformers.pet.pet_config.LoraConfig(lora_rank: int = 8, lora_alpha: int = 16, lora_dropout: float = 0.01, lora_a_init: str = 'normal', lora_b_init: str = 'zero', param_init_type: str = 'float16', compute_dtype: str = 'float16', target_modules: str = None, exclude_layers: str = None, freeze_include: List[str] = None, freeze_exclude: List[str] = None, **kwargs)[source]

LoRA algorithm configuration, used to set the parameters of a LoRA model at runtime.

Parameters
  • lora_rank (int, optional) – The rank of the LoRA matrices, i.e. the number of rows (columns) of the low-rank matrices. Default: 8.

  • lora_alpha (int, optional) – The scaling constant for the LoRA update; the low-rank update is scaled by lora_alpha / lora_rank, as illustrated in the sketch after this list. Default: 16.

  • lora_dropout (float, optional) – The dropout rate, greater than or equal to 0 and less than 1. Default: 0.01.

  • lora_a_init (str, optional) – The initialization strategy of the LoRA A matrix. See https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html. Default: 'normal'.

  • lora_b_init (str, optional) – The initialization strategy of the LoRA B matrix. See https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html. Default: 'zero'.

  • param_init_type (str, optional) – The data type of the initialized tensors. Default: float16.

  • compute_dtype (str, optional) – The data type used for computation. Default: float16.

  • target_modules (str, optional) – The layers to be replaced with the LoRA algorithm, specified as a regular expression over layer names. Default: None.

  • exclude_layers (str, optional) – The layers to be excluded from replacement with the LoRA algorithm. Default: None.

  • freeze_include (List[str], optional) – List of modules to be frozen. Default: None.

  • freeze_exclude (List[str], optional) – List of modules that are not to be frozen. If an item appears in both freeze_include and freeze_exclude, the modules matching that item are not processed. Default: None.
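
The roles of lora_rank, lora_alpha, lora_a_init and lora_b_init can be seen in the standard LoRA update, where the low-rank product is scaled by lora_alpha / lora_rank. The NumPy sketch below is purely illustrative and is not how MindFormers applies LoRA internally; the toy dimensions are assumptions.

>>> # Illustrative sketch only (plain NumPy, not the MindFormers internals).
>>> import numpy as np
>>> d_in, d_out, lora_rank, lora_alpha = 32, 32, 8, 16   # toy dimensions
>>> W = np.random.normal(size=(d_out, d_in))             # frozen base weight
>>> A = np.random.normal(size=(lora_rank, d_in))         # lora_a_init='normal'
>>> B = np.zeros((d_out, lora_rank))                     # lora_b_init='zero'
>>> # Scaled low-rank update; with B initialized to zero, W_eff equals W at the start of training.
>>> W_eff = W + (lora_alpha / lora_rank) * (B @ A)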

Returns

An instance of LoraConfig.

Examples

>>> from mindformers.pet.pet_config import LoraConfig
>>> config = LoraConfig(target_modules='.*wq|.*wk|.*wv|.*wo')
>>> print(config)
{'pet_type': 'lora', 'lora_rank': 8, 'lora_alpha': 16, 'lora_dropout': 0.01,
'lora_a_init': 'normal', 'lora_b_init': 'zero', 'param_init_type': mindspore.float16,
'compute_dtype': mindspore.float16, 'target_modules': '.*wq|.*wk|.*wv|.*wo',
'exclude_layers': None, 'freeze_include': None, 'freeze_exclude': None}
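
As a further sketch, the documented hyperparameters can also be set explicitly at construction time. The numeric values and module patterns below are illustrative assumptions, not recommended settings.

>>> # Illustrative configuration sketch; values and patterns are assumptions.
>>> from mindformers.pet.pet_config import LoraConfig
>>> config = LoraConfig(
...     lora_rank=16,
...     lora_alpha=32,
...     lora_dropout=0.05,
...     lora_a_init='normal',
...     lora_b_init='zero',
...     param_init_type='float16',
...     compute_dtype='float16',
...     target_modules='.*wq|.*wv',   # hypothetical attention projection layers
...     exclude_layers='.*ffn',       # hypothetical layers to leave unchanged
... )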