# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PTQ algorithm."""
from typing import List, Union, Tuple
import time
import os
import copy
import tqdm
from mindspore import dtype, get_context, PYNATIVE_MODE
from mindspore.nn import Cell
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore_gs.comp_algo import CompAlgo
from mindspore_gs.common import logger
from mindspore_gs.common.utils import offload_network, value_check
from mindspore_gs.ptq.processor import transform_network_inplace
from mindspore_gs.ptq.ptq_config import PTQConfig, InnerPTQConfig, PTQApproach, PTQMode, OutliersSuppressionType
from mindspore_gs.ptq.network_helpers import NetworkHelper
from mindspore_gs.ptq.ptq.wrapper_cell import WrapperCell
from mindspore_gs.ptq.processor import Processor
from .algorithm import Algorithm
from .algorithms import LinearSmoother, Quantizer, Deployer
class InputCatcher(Cell):
    """
    Placeholder cell that captures the inputs fed to the first decoder layer.

    PTQ swaps this cell into the network at the position of the first decoder layer. Every call
    records its positional arguments (stored as a mutable list) and its keyword arguments, then
    raises ``GeneratorExit`` so the remainder of the forward pass is not executed.

    Args:
        handler (Cell): The decoder layer this catcher stands in for. It is kept on the instance so
            it can be put back into the network afterwards.
    """

    def __init__(self, handler):
        super().__init__()
        self.handler = handler
        # Mirror the wrapped layer's `attention` attribute under the same name when it exists.
        missing = object()
        attention = getattr(handler, "attention", missing)
        if attention is not missing:
            self.attention = attention
        # One entry per intercepted call: a list of positional args and a dict of keyword args.
        self.args = []
        self.kwargs = []

    def construct(self, *args, **kwargs):
        """Store this call's arguments, then abort the forward pass by raising ``GeneratorExit``."""
        captured_positional = [*args]
        self.args.append(captured_positional)
        self.kwargs.append(kwargs)
        raise GeneratorExit("already catch first layer inputs, do not need continue.")
class PTQ(CompAlgo):
    """
    Implementation of PTQ algorithm which supports the combination quantization of activation,
    weight, and kvcache.

    Args:
        config(:class:`mindspore_gs.ptq.PTQConfig`, optional): config for PTQ, default is ``None``.

    Raises:
        TypeError: If `config` type is not PTQConfig when it's not ``None``.
        ValueError: If not PYNATIVE mode when mode in config is PTQMode.QUANTIZE (raised by `apply`).
        ValueError: If act_quant_dtype is int8 and weight_quant_dtype is None.

    Examples:
        >>> import mindspore_gs
        >>> from mindspore import dtype as msdtype
        >>> from mindspore_gs.common import BackendTarget
        >>> from mindspore_gs.ptq import PTQ
        >>> from mindspore_gs.ptq import PTQConfig
        >>> from mindspore_gs.ptq.ptq_config import PTQMode, OutliersSuppressionType
        >>> from mindspore_gs.ptq.network_helpers.mf_net_helpers import MFLlama2Helper
        >>> from mindformers.tools.register.config import MindFormerConfig
        >>> from mindformers import LlamaForCausalLM, LlamaConfig
        >>> mf_yaml_config_file = "/path/to/mf_yaml_config_file"
        >>> mfconfig = MindFormerConfig(mf_yaml_config_file)
        >>> helper = MFLlama2Helper(mfconfig)
        >>> ptq_config = PTQConfig(mode=PTQMode.QUANTIZE, backend=BackendTarget.ASCEND,
        ...                        opname_blacklist=["w2", "lm_head"],
        ...                        weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8,
        ...                        outliers_suppression=OutliersSuppressionType.SMOOTH)
        >>> ptq = PTQ(ptq_config)
        >>> network = LlamaForCausalLM(LlamaConfig(**mfconfig.model.model_config))
        >>> # user-provided calibration dataset; each item must contain an "input_ids" column
        >>> calibration_ds = get_calibration_dataset()
        >>> fake_quant_net = ptq.apply(network, helper, datasets=calibration_ds)
        >>> quant_net = ptq.convert(fake_quant_net)
    """

    def __init__(self, config: Union[PTQConfig, None] = None):
        super().__init__()
        if config is None:
            config = PTQConfig()
        elif not isinstance(config, PTQConfig):
            raise TypeError(f'Shall init PTQ with PTQConfig, but got {type(config)}')
        # convert PTQConfig to InnerConfig to add inner parameters
        self._config = InnerPTQConfig().inner_config(config, approach=PTQApproach.PTQ)
        logger.info(f"Config for PTQ: {self._config}")
        PTQ._ptq_config_check(self._config)
        self.pipeline: List[Algorithm] = []
        self.decoder_layer_types: list = []
        self._build_pipeline()
        self._load_mindformers_plugin()
        self.context_mode = get_context("mode")

    def _build_pipeline(self):
        """
        Build the pipeline of algorithms for the configured mode.

        In QUANTIZE mode the pipeline gets a LinearSmoother when outliers_suppression is SMOOTH,
        followed by a Quantizer when any of activation, weight or kvcache is quantized to int8.
        In DEPLOY mode it gets a single Deployer.
        """
        if self._config.mode == PTQMode.QUANTIZE:
            if self._config.outliers_suppression == OutliersSuppressionType.SMOOTH:
                logger.info("Adding LinearSmoother to pipeline.")
                self.pipeline.append(LinearSmoother(self._config))
            if self._config.act_quant_dtype == dtype.int8 or \
                    self._config.weight_quant_dtype == dtype.int8 or \
                    self._config.kvcache_quant_dtype == dtype.int8:
                logger.info("Adding Quantizer to pipeline.")
                self.pipeline.append(Quantizer(self._config))
        elif self._config.mode == PTQMode.DEPLOY:
            logger.info("Adding Deploy to pipeline.")
            self.pipeline.append(Deployer(self._config))

    def _load_mindformers_plugin(self):
        """
        Load the MindFormers plugin of every pipeline algorithm and register MindFormers'
        decoder layer classes (LLamaDecodeLayer, ParallelTransformerLayer) as decoder layer types.
        """
        for algorithm in self.pipeline:
            algorithm.load_mindformers_plugin()
        # Imported here so MindFormers is only needed once a PTQ instance is constructed.
        from mindformers.models.llama.llama_transformer import LLamaDecodeLayer
        from mindformers.experimental.infer.core.transformer import ParallelTransformerLayer
        self.decoder_layer_types.append(LLamaDecodeLayer)
        self.decoder_layer_types.append(ParallelTransformerLayer)

    def _get_decoder_layers(self, network: Cell):
        """
        Get decoder layers from network.

        Args:
            network (nn.Cell): Network to get decoder layers.

        Returns:
            A list of tuples (cell_name, `Cell`) as decoder layers of network, in traversal order.
        """
        value_check('network', network, Cell)

        class NetworkWalker(Processor):
            """Collect every cell whose type is one of the given decoder layer types."""
            def __init__(self, decoder_layer_types_):
                self.layers = []
                self._decoder_layer_types = decoder_layer_types_

            def process_cell(self, cell_name: str, cell: Cell) -> Tuple[Cell, bool]:
                if isinstance(cell, self._decoder_layer_types):
                    self.layers.append((cell_name, cell))
                    # presumably True tells Processor not to descend into the matched layer
                    return cell, True
                return cell, False

        walker = NetworkWalker(tuple(self.decoder_layer_types))
        walker.process(network)
        return walker.layers

    @staticmethod
    def _ptq_config_check(config):
        """
        Validate a PTQ config, warning about accuracy-degrading or no-op combinations.

        Raises:
            ValueError: If activation is quantized to int8 while weight is not quantized.
        """
        # Outliers suppression may be disabled either as Python None or as OutliersSuppressionType.NONE.
        no_outliers_suppression = config.outliers_suppression is None or \
            config.outliers_suppression == OutliersSuppressionType.NONE
        if no_outliers_suppression and \
                config.weight_quant_dtype == dtype.int8 and \
                config.act_quant_dtype == dtype.int8:
            logger.warning("When outliers_suppression is None, A8W8 algorithm accuracy is expected to decline.")
        if config.weight_quant_dtype is None and \
                config.act_quant_dtype == dtype.int8:
            raise ValueError("PTQ algorithm not support only quant activation.")
        if config.weight_quant_dtype is None and config.act_quant_dtype is None \
                and config.kvcache_quant_dtype is None and no_outliers_suppression:
            logger.warning("PTQ algorithm does not quantize any layers when "
                           "weight_quant_dtype=None, "
                           "act_quant_dtype=None, "
                           "kvcache_quant_dtype=None and "
                           "outliers_suppression=None")

    # pylint: disable=arguments-differ
    def apply(self, network: Cell, network_helper: NetworkHelper = None, datasets=None, **kwargs) -> Cell:
        """
        Define how to add fake quantizer to `network`.

        In DEPLOY mode every decoder layer is run through the pipeline without calibration data.
        In QUANTIZE mode the inputs of the first decoder layer are captured from `datasets`, and
        then the decoder layers are processed one by one. Each layer's outputs become the next
        layer's inputs.

        Args:
            network (Cell): Network to be fake quantized. It is modified in place and returned.
            network_helper (NetworkHelper): Utils for decoupling algorithm with network framework.
            datasets (Dataset): Datasets for calibrating.

        Returns:
            fake quantized network.

        Raises:
            TypeError: If input `network` is not a Cell.
            ValueError: If not in PYNATIVE mode when mode is `PTQMode.QUANTIZE`.
            ValueError: If input `network_helper` is None when mode is `PTQMode.QUANTIZE`.
            ValueError: If input datasets is None when mode is `PTQMode.QUANTIZE`.
            ValueError: If kvcache_quant_dtype is int8 but `network_helper.get_spec("use_past")` is false.
        """
        if self._config.mode == PTQMode.DEPLOY:
            layers = self._get_decoder_layers(network)
            for layer_name, layer in tqdm.tqdm(layers, desc="Running PTQ Deploy..."):
                for processor in self.pipeline:
                    processor.replace(layer_name, layer)
                    processor.process(layer_name, layer)
                    processor.deploy(layer_name, layer)
                    network.update_parameters_name()
            return network
        if self._config.mode == PTQMode.QUANTIZE and get_context("mode") != PYNATIVE_MODE:
            raise ValueError("Quantization phase only support PYNATIVE MODE.")
        if not network_helper:
            raise ValueError("Please provide network_helper when PTQ in apply phase.")
        if not datasets:
            raise ValueError("please provide dataset when use PTQ quant to quantize network.")
        if self._config.kvcache_quant_dtype == dtype.int8 and not network_helper.get_spec("use_past"):
            raise ValueError("use_past need be true when doing kvcache quantize.")
        logger.info(f"Visible decoder layer types: {self.decoder_layer_types}. If decoder layer type of target network "
                    "not in list, please modify PTQ.decoder_layer_types before invoking apply method.")
        start_time = time.time()
        logger.info("Analysis network structure.")
        network_helper.analysis_decoder_groups(network)
        logger.info(f"analysis_decoder_groups time cost {time.time() - start_time}")

        # Check before catching inputs: _get_first_layer_input needs at least one decoder layer,
        # otherwise it would fail with an IndexError instead of this actionable warning.
        if not self._get_decoder_layers(network):
            logger.warning(
                f"No decoder layer found in network. Visible decoder layer types: {self.decoder_layer_types}, "
                "please modify PTQ.decoder_layer_types before invoking apply method.")
            return network

        start_time = time.time()
        logger.info(f"Catching inputs for first decoder layer with {datasets.get_dataset_size()} datasets samples.")
        catcher, network = self._get_first_layer_input(network, network_helper, datasets)
        all_args = catcher.args
        all_kwargs = catcher.kwargs
        logger.info(f"_get_first_layer_input time cost {time.time() - start_time}")
        start_time = time.time()
        layers = self._get_decoder_layers(network)
        logger.info(f"get_decoder_layers time cost {time.time() - start_time}")

        for i, (layer_name, layer) in enumerate(tqdm.tqdm(layers, desc="Running PTQ...")):
            logger.info(f"Quantize {i}th decoder layer.")
            # Snapshot this layer's inputs: all_args[*][0] is overwritten with the layer output below,
            # but every processor of the pipeline must see the layer's original inputs.
            cur_args, cur_kwargs = copy.deepcopy(all_args), copy.deepcopy(all_kwargs)
            for processor in self.pipeline:
                processor.replace(layer_name, layer, network_helper)
                logger.info("Catching inputs of all Linear in decoder layer.")
                start_time = time.time()
                # Run the layer with hooks enabled on its WrapperCells (presumably so they record the
                # inputs needed by `processor.process`); keep the outputs as the next layer's inputs.
                # NOTE(review): outputs are computed before this processor's process/deploy, so the next
                # layer is fed outputs of a layer not yet transformed by the last processor.
                transform_network_inplace(layer, WrapperCell, lambda _, cell: cell.add_hook())
                try:
                    for index, (args, kwargs) in enumerate(zip(cur_args, cur_kwargs)):
                        all_args[index][0] = layer(*args, **kwargs)
                finally:
                    transform_network_inplace(layer, WrapperCell, lambda _, cell: cell.remove_hook())
                logger.info(f"{i}th layer output refresh time cost {time.time() - start_time}")
                start_time = time.time()
                processor.process(layer_name, layer, network_helper)
                processor.deploy(layer_name, layer)
                network.update_parameters_name()
                logger.info(f"{i}th layer do {type(processor)} time cost {time.time() - start_time}")
            start_time = time.time()
            offload_network(layer)
            logger.info(f"{i}th layer offload network time cost {time.time() - start_time}")
        return network

    def _get_first_layer_input(self, network: Cell, network_helper: NetworkHelper = None, ds=None):
        """
        Capture the inputs of the first decoder layer for every sample in `ds`.

        The first decoder layer is temporarily replaced by an `InputCatcher`, and
        `network_helper.generate` is run on each sample's ``input_ids``. The catcher records the
        layer inputs and aborts generation with ``GeneratorExit``. The original layer is always put
        back, even if generation raises another exception.

        Args:
            network (Cell): Network containing the decoder layers.
            network_helper (NetworkHelper): Helper used to run generation on `network`.
            ds (Dataset): Calibration dataset whose items contain an ``input_ids`` column.

        Returns:
            A tuple (catcher, network). ``catcher.args`` / ``catcher.kwargs`` hold one entry per sample.

        Raises:
            ValueError: If `ds` is empty/None, or `network` contains no decoder layer.
        """
        # Validate before touching the network so a bad call never leaves the catcher in place.
        if not ds:
            raise ValueError("PTQ need dataset to calibrate, please provide dataset.")
        layers = self._get_decoder_layers(network)
        if not layers:
            raise ValueError(
                f"No decoder layer found in network. Visible decoder layer types: {self.decoder_layer_types}, "
                "please modify PTQ.decoder_layer_types before invoking apply method.")
        first_layer = layers[0][1]
        catcher = InputCatcher(first_layer)

        def replace_first_decoder(root: Cell, src: Cell, dst: Cell):
            """Recursively replace child cell `src` of `root` (at any depth) with `dst`."""
            if root is None:
                return
            for name, cell in root.name_cells().items():
                if cell is src:
                    root.insert_child_to_cell(name, dst)
                    return
                replace_first_decoder(cell, src, dst)

        replace_first_decoder(network, first_layer, catcher)
        try:
            total_count = ds.get_dataset_size()
            for data_count, ds_item in enumerate(ds.create_dict_iterator(), start=1):
                logger.info(f"Calibrating: dataset count: {data_count}/{total_count}")
                input_ids = ds_item['input_ids'].asnumpy()
                try:
                    network_helper.generate(network, input_ids, max_new_tokens=1)
                except GeneratorExit:
                    # Generation was aborted by the catcher; drop any cache state it left behind.
                    if hasattr(network, "block_mgr") and network.block_mgr:
                        network.block_mgr.clear_cache()
        finally:
            replace_first_decoder(network, catcher, catcher.handler)
        offload_network(network)
        return catcher, network

    def convert(self, net_opt: Cell, ckpt_path="") -> Cell:
        """
        Define how to convert a compressed network to a standard network before exporting.

        Args:
            net_opt (Cell): Network to be converted which is transformed by `PTQ.apply`.
            ckpt_path (str): Path to checkpoint file for `net_opt`. Default is ``""``, which means not loading
                checkpoint file to `net_opt`.

        Returns:
            An instance of Cell represents quantized network.

        Raises:
            TypeError: If `net_opt` is not Cell.
            TypeError: If `ckpt_path` is not string.
            ValueError: If `ckpt_path` is not empty and invalid.
        """
        if not isinstance(net_opt, Cell):
            raise TypeError(
                f'The parameter `net_opt` must be isinstance of Cell, but got {type(net_opt)}.')
        if not isinstance(ckpt_path, str):
            raise TypeError(
                f'The parameter `ckpt_path` must be isinstance of str, but got {type(ckpt_path)}.')
        if ckpt_path != "":
            real_path = os.path.realpath(ckpt_path)
            if not os.path.isfile(real_path):
                raise ValueError(
                    f'The parameter `ckpt_path` can only be empty or a valid file, but got {real_path}.')
            param_dict = load_checkpoint(ckpt_path)
            load_param_into_net(net_opt, param_dict)
        return net_opt