mindspore_gs.ptq.PTQ

class mindspore_gs.ptq.PTQ(config: Union[dict, PTQConfig] = None)[源代码]

量化算法PTQ的基本实现,支持激活、权重和kvcache的组合量化。

参数:
异常:
  • TypeError - config 在输入不为 None 时,元素类型不为 PTQConfig。

  • ValueError - config 中的 mode 是PTQMode.QUANTIZE时非PYNATIVE模式。

  • ValueError - 当act_quant_dtype是int8类型,weight_quant_dtype为None时。

样例:

>>> import mindspore_gs
>>> from mindspore_gs.ptq import PTQ
>>> from mindspore_gs.ptq import PTQConfig
>>> from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper
>>> from mindformers.tools.register.config import MindFormerConfig
>>> from mindformers import LlamaForCausalLM, LlamaConfig
>>> mf_yaml_config_file = "/path/to/mf_yaml_config_file"
>>> mfconfig = MindFormerConfig(mf_yaml_config_file)
>>> helper = MFLlama2Helper(mfconfig)
>>> ptq_config = PTQConfig(mode=PTQMode.QUANTIZE, backend=backend, opname_blacklist=["w2", "lm_head"],
weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8,
outliers_suppression=OutliersSuppressionType.SMOOTH)
>>> ptq = PTQ(ptq_config)
>>> network = LlamaForCausalLM(LlamaConfig(**mfconfig.model.model_config))
>>> fake_quant_net = ptq.apply(network, helper)
>>> quant_net = ptq.convert(fake_quant_net)
apply(network: Cell, network_helper: NetworkHelper = None, datasets=None, **kwargs)[源代码]

network 中添加伪量化节点,转换成一个伪量化网络。

参数:
  • network (Cell) - 待伪量化的网络。

  • network_helper (NetworkHelper) - 网络量化工具,用于解耦算法层和网络框架层。

  • datasets (Dataset) - 校准用的数据集。

返回:

伪量化后的网络。

异常:
  • RuntimeError - 如果当前算法没有有效的初始化。

  • TypeError - network 不是一个 Cell 对象。

  • ValueError - PTQMode.DEPLOY 模式时,network_helper 为空。

  • ValueError - 当datasets为空。

convert(net_opt: Cell, ckpt_path='')[源代码]

将量化网络 net_opt 转换为真实量化网络,后续导出用于部署。

参数:
  • net_opt (Cell) - 经过量化算法apply之后的网络。

  • ckpt_path (str) - 网络的checkpoint file文件路径,默认值为 "",表示不加载。注意,该参数会在后续版本中被遗弃。

返回:

转换后的网络。

异常:
  • TypeError - net_opt 数据类型不是Cell。

  • TypeError - ckpt_path 数据类型不是str。

  • ValueError - ckpt_path 非空但不是有效路径。