# mindspore_gs.ptq.ptq_config source code

# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""algorithm related configs"""

from dataclasses import dataclass, field, is_dataclass, asdict
from enum import Enum
from typing import List

from mindspore import QuantDtype

from mindspore_gs.common.config import GSBaseConfig
from mindspore_gs.common.utils import value_check
from mindspore_gs.common.register import RegisterMachine
from mindspore_gs.common.gs_enum import QuantCellType, BackendTarget

# Registry of per-algorithm argument config classes, keyed by PTQApproach member.
# Algorithm configs register themselves through the
# `@algo_cfg_register.register(PTQApproach.<X>)` decorator, and
# `InnerPTQConfig.__post_init__` looks up `algo_cfg_register[self.approach]`
# to fill `algo_args` with that algorithm's default arguments.
algo_cfg_register = RegisterMachine()


class PTQApproach(Enum):
    """
    Post-training quantization algorithms that can be selected.

    Every member's value is the lowercase identifier of the algorithm.
    """
    SMOOTH_QUANT = "smooth_quant"
    RTN = "rtn"
    GPTQ = "gptq"


class PTQMode(Enum):
    """
    Mode for ptq quantizer.

    - ``QUANTIZE``: indicate ptq quantizer in quantize mode.
    - ``DEPLOY``: indicate ptq quantizer in deploy mode.
    """
    QUANTIZE = 'quantize'
    DEPLOY = 'deploy'
@algo_cfg_register.register(PTQApproach.SMOOTH_QUANT)
@dataclass
class SmoothQuantConfig:
    """
    Config for smooth quant algorithm.

    Args:
        alpha (float): Algorithm hyper-parameter ``alpha``, must be a ``float``. Default: ``0.5``.
        is_deploy (bool): Deploy flag for the algorithm, must be a ``bool``. Default: ``False``.
    """
    alpha: float = 0.5
    is_deploy: bool = False

    def __post_init__(self):
        value_check('alpha', self.alpha, float)
        value_check('is_deploy', self.is_deploy, bool)


@algo_cfg_register.register(PTQApproach.RTN)
@dataclass
class RTNConfig:
    """
    Config for round to nearest algorithms.

    Declares no fields: RTN contributes no algorithm-specific arguments.
    """


@dataclass
class QuantizerConfig(GSBaseConfig):
    """
    Quantize related config.

    Args:
        bit_num (int): Number of quantization bits. Default: ``8``.
        optypes_exclude_output_quant (List[str]): Op type names whose output should not be quantized.
            Default: empty list.
        algo_args (dict): Algorithm-specific arguments. Default: empty dict.

    Raises:
        Whatever ``value_check`` raises when a field has the wrong type.
    """
    bit_num: int = 8
    optypes_exclude_output_quant: List[str] = field(default_factory=list)
    algo_args: dict = field(default_factory=dict)

    def __post_init__(self):
        value_check('bit_num', self.bit_num, int)
        # The field is a list of op-type names: validate the container itself as a list and then each
        # element as a str. Checking the whole list against `str` rejected every valid value, even the default [].
        value_check('optypes_exclude_output_quant', self.optypes_exclude_output_quant, list)
        for op_type in self.optypes_exclude_output_quant:
            value_check('optypes_exclude_output_quant', op_type, str)
        value_check('algo_args', self.algo_args, dict)
@dataclass
class PTQConfig:
    """
    Config for post training quantization.

    Args:
        mode (:class:`mindspore_gs.ptq.PTQMode`): Flag for ptq mode, ``QUANTIZE`` for quantize mode,
            ``DEPLOY`` for deploy mode.
        backend (:class:`mindspore_gs.common.BackendTarget`): Flag for backend target, ``NONE`` for no specific
            backend, ``ASCEND`` for ascend backend.

    Raises:
        ValueError: If `mode` is not in PTQMode's members.
        ValueError: If `backend` is not in BackendTarget's members.

    Example:
        >>> from mindspore_gs.ptq import PTQConfig, PTQMode
        >>> from mindspore_gs.common import BackendTarget
        >>> ascend_config = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND)
        >>> print(ascend_config)
        PTQConfig(mode=<PTQMode.DEPLOY: 'deploy'>, backend=<BackendTarget.ASCEND: 'ascend'>)
    """
    mode: PTQMode = field(default=PTQMode.QUANTIZE,
                          metadata={'valid_values': PTQMode.__members__.values()})
    backend: BackendTarget = field(default=BackendTarget.NONE,
                                   metadata={'valid_values': BackendTarget.__members__.values()})

    def __post_init__(self):
        # Both fields are tested against the enum's members view. `in` on a dict view compares by equality
        # and needs no hashing, so any bad argument (even an unhashable one) raises the documented ValueError.
        # The previous set-based check for `backend` raised TypeError for unhashable values instead.
        if self.mode not in PTQMode.__members__.values():
            raise ValueError(f'{self.mode} is not supported, mode shall be in {PTQMode.__members__.values()}')
        if self.backend not in BackendTarget.__members__.values():
            raise ValueError(f'{self.backend} is not supported, backend shall be in '
                             f'{BackendTarget.__members__.values()}')
@dataclass
class InnerPTQConfig(QuantizerConfig, PTQConfig):
    """
    Config for post-training quantizer.

    Combines the fields inherited from ``QuantizerConfig`` (``bit_num``, ``optypes_exclude_output_quant``,
    ``algo_args``) and ``PTQConfig`` (``mode``, ``backend``) with the approach and quantization-scheme
    fields declared below.
    """
    approach: PTQApproach = field(default=PTQApproach.RTN,
                                  metadata={'valid_values': PTQApproach.__members__.values()})
    calibration_sampling_size: int = 0
    act_quant_dtype: QuantDtype = QuantDtype.INT8
    weight_quant_dtype: QuantDtype = QuantDtype.INT8
    weight_only: bool = True
    enable_kvcache_int8: bool = False
    act_per_channel: bool = False
    weight_per_channel: bool = True
    act_symmetric: bool = False
    weight_symmetric: bool = True
    act_narrow_range: bool = False
    weight_narrow_range: bool = False
    op_types: List[str] = field(default_factory=lambda: [QuantCellType.MF_LINEAR.value],
                                metadata={'choices': [item.value for item in QuantCellType.__members__.values()]})

    def __post_init__(self):
        # Note: `value_check` below is the module-level helper; the method of the same name defined on this
        # class is not visible from inside a method body.
        value_check('calibration_sampling_size', self.calibration_sampling_size, int)
        value_check('act_quant_dtype', self.act_quant_dtype, QuantDtype)
        value_check('weight_quant_dtype', self.weight_quant_dtype, QuantDtype)
        value_check('weight_only', self.weight_only, bool)
        value_check('enable_kvcache_int8', self.enable_kvcache_int8, bool)
        value_check('act_per_channel', self.act_per_channel, bool)
        value_check('weight_per_channel', self.weight_per_channel, bool)
        # Each symmetric flag is validated against its own field (previously 'act_symmetric' checked
        # weight_symmetric, leaving act_symmetric unvalidated).
        value_check('act_symmetric', self.act_symmetric, bool)
        value_check('weight_symmetric', self.weight_symmetric, bool)
        value_check('act_narrow_range', self.act_narrow_range, bool)
        value_check('weight_narrow_range', self.weight_narrow_range, bool)
        if self.approach not in PTQApproach.__members__.values():
            raise ValueError(f'Invalid approach: {self.approach}')
        support_op_types = {item.value for item in QuantCellType.__members__.values()}
        for op_type in self.op_types:
            if op_type not in support_op_types:
                raise ValueError(f'{op_type} is not supported, all support type is {support_op_types}')
        # Fill algorithm defaults only when the caller supplied no algo_args: the config class registered
        # for this approach is instantiated with its defaults and flattened into the dict.
        if not self.algo_args:
            args_config = algo_cfg_register[self.approach]
            if args_config is not None and is_dataclass(args_config):
                self.algo_args.update(asdict(args_config()))

    def value_check(self):
        """Re-run ``__post_init__``: validate every field and fill ``algo_args`` if it is empty."""
        self.__post_init__()

    def _parse_dict(self):
        """
        Parse this config into a readable dict.

        Enum-typed fields (``act_quant_dtype``, ``weight_quant_dtype``, ``backend``, ``mode``, ``approach``)
        are replaced by their member names.

        A shallow copy of ``self.__dict__`` is used. Converting the enum fields therefore does not overwrite
        this object's own attributes; the previous aliasing turned them into strings, and a second call failed.
        Nested containers such as ``op_types`` and ``algo_args`` are still shared with the config.
        """
        parsed_dict = dict(self.__dict__)
        parsed_dict['act_quant_dtype'] = self.act_quant_dtype.name
        parsed_dict['weight_quant_dtype'] = self.weight_quant_dtype.name
        parsed_dict['backend'] = self.backend.name
        parsed_dict['mode'] = self.mode.name
        parsed_dict['approach'] = self.approach.name
        return parsed_dict

    def _unparse_dict(self, data_dict):
        """
        Convert a readable dict back into this config.

        The enum-typed entries of ``data_dict`` are converted from member names to enum members in place.
        Every entry of ``data_dict`` is then copied onto this config's attributes.

        Raises:
            ValueError: If one of the enum-typed keys is missing from ``data_dict``.
            KeyError: (for standard ``Enum`` types) if a value is not a member name, since the lookup is
                ``enum_type[name]``.
        """
        unparse_list = [
            ('act_quant_dtype', QuantDtype),
            ('weight_quant_dtype', QuantDtype),
            ('mode', PTQMode),
            ('backend', BackendTarget),
            ('approach', PTQApproach),
        ]
        for key, enum_type in unparse_list:
            if key not in data_dict:
                raise ValueError(f'{key} shall in yaml, but not found')
            data_dict[key] = enum_type[data_dict[key]]
        self.__dict__.update(data_dict)

    @staticmethod
    def inner_config(cfg: PTQConfig):
        """
        Convert a ``PTQConfig`` into an ``InnerPTQConfig``.

        Every field of ``cfg`` (as produced by ``dataclasses.asdict``) is copied onto a default-constructed
        ``InnerPTQConfig``; all remaining fields keep their defaults.

        Raises:
            TypeError: If `cfg` is not a ``PTQConfig``.
        """
        if not isinstance(cfg, PTQConfig):
            raise TypeError(f'input config shall be PTQConfig, but got {type(cfg)}')
        inner_cfg = InnerPTQConfig()
        for key, val in asdict(cfg).items():
            setattr(inner_cfg, key, val)
        return inner_cfg