Model Quantization
Overview
MindSpore is an all-scenario AI framework. When a model is deployed on a device or other lightweight hardware, memory, power consumption, and latency are constrained, so the model needs to be compressed before deployment.
MindSpore Golden Stick provides model compression capabilities for MindSpore. It is a set of model compression algorithms jointly designed and developed by the Huawei Noah team and the Huawei MindSpore team, and it supports quantization modes such as A16W8, A16W4, A8W8, and KVCache quantization. For details, see MindSpore Golden Stick.
Basic Model Quantization Process
To help you understand the basic quantization process of MindSpore Golden Stick, this section walks through an example quantization algorithm and its basic usage.
Procedure
MindSpore Golden Stick quantization algorithms work in two phases: quantization and deployment. The quantization phase is completed before deployment and consists of the following tasks: collecting the weight distribution, calculating quantization parameters, quantizing the weight data, and inserting dequantization nodes. The deployment phase uses the MindSpore framework to run inference on the quantized model in the production environment.
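The four quantization-phase tasks can be illustrated with a generic symmetric, per-output-channel int8 round-to-nearest scheme. This is a minimal pure-Python sketch under that assumption; the function names are hypothetical, and Golden Stick's actual kernels and parameter layout may differ.

```python
# Hypothetical sketch: symmetric per-output-channel int8 weight quantization
# (round to nearest). Not Golden Stick's implementation.

def quantize_per_channel(weight, num_bits=8):
    """Quantize a 2-D weight given as a list of rows (one row per output channel).

    Returns (integer weights, per-row scales).
    """
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    q_weight, scales = [], []
    for row in weight:
        # 1. Collect the weight distribution: here only the max absolute value.
        max_abs = max(abs(v) for v in row)
        # 2. Calculate the quantization parameter (scale) for this channel.
        scale = max_abs / qmax if max_abs > 0 else 1.0
        # 3. Quantize: round to nearest, then clamp into [-qmax, qmax].
        q_row = [max(-qmax, min(qmax, round(v / scale))) for v in row]
        q_weight.append(q_row)
        scales.append(scale)
    return q_weight, scales


def dequantize(q_weight, scales):
    # 4. What an inserted dequantization node computes: w ~= q * scale.
    return [[q * s for q in row] for row, s in zip(q_weight, scales)]


w = [[0.5, -1.27, 0.02], [2.54, 0.0, -0.3]]
q, s = quantize_per_channel(w)
print(q)
print(dequantize(q, s))
```

Per-channel scales are used here because a single outlier row then only coarsens its own grid; the round-trip error of each unclipped weight is at most half of its channel's scale.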
MindSpore Golden Stick uses PTQConfig to define quantization and deployment, and uses the apply and convert APIs to implement them. In PTQConfig, you can configure the data calibration policy, whether to quantize the weights, activations, and KVCache, and the quantization bit width. For details, see PTQConfig Configuration Description.
The quantization procedure of MindSpore Golden Stick is as follows:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindformers.modules import Linear
from mindspore_gs.common import BackendTarget
from mindspore_gs.ptq import PTQMode, PTQConfig
from mindspore_gs.ptq import RoundToNearest as RTN
from mindspore_gs.ptq.network_helpers import NetworkHelper
class SimpleNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.linear = Linear(in_channels=5, out_channels=6, transpose_b=True, bias_init="normal", weight_init="normal")
    def construct(self, x):
        return self.linear(x)
class SimpleNetworkHelper(NetworkHelper):
    def __init__(self, **kwargs):
        self.attrs = kwargs
    def get_spec(self, name: str):
        return self.attrs.get(name, None)
    def generate(self, network: nn.Cell, input_ids: np.ndarray, max_new_tokens=1, **kwargs):
        # Right-pad input_ids with zeros up to seq_length, then run one forward pass.
        input_ids = np.pad(input_ids, ((0, 0), (0, self.get_spec("seq_length") - input_ids.shape[1])), 'constant', constant_values=0)
        network(Tensor(input_ids, dtype=ms.dtype.float16))
net = SimpleNet()  # The float model that needs to be quantized
cfg = PTQConfig(mode=PTQMode.QUANTIZE, backend=BackendTarget.ASCEND, weight_quant_dtype=ms.dtype.int8)
net_helper = SimpleNetworkHelper(batch_size=1, seq_length=5)
rtn = RTN(cfg)
rtn.apply(net, net_helper)  # Build the pseudo-quantized network and collect statistics
rtn.convert(net)  # Produce the real quantized network
ms.save_checkpoint(net.parameters_dict(), './simplenet_rtn.ckpt')
Use nn.Cell to define the network. In practice, you train the model, obtain its floating-point weights, and load those weights into the network before quantizing it. The example above simplifies this by creating the network directly and quantizing its initial floating-point weights.
Use PTQConfig to set mode to quantization mode (PTQMode.QUANTIZE), set the backend to Ascend, and quantize the weights to 8 bits (int8). For details, see PTQConfig Configuration Description.
Use the apply API to convert the network into a pseudo-quantized network and collect statistics on the objects to be quantized, based on the configuration in PTQConfig.
Use the convert API to quantize the pseudo-quantized network from the previous step and obtain the quantized network.
After the quantization is complete, you can use the quantized model for inference as follows:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindformers.modules import Linear
from mindspore_gs.common import BackendTarget
from mindspore_gs.ptq import PTQMode, PTQConfig
from mindspore_gs.ptq import RoundToNearest as RTN
class SimpleNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.linear = Linear(in_channels=5, out_channels=6, transpose_b=True, bias_init="normal", weight_init="normal")
    def construct(self, x):
        return self.linear(x)
net = SimpleNet()
cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND, weight_quant_dtype=ms.dtype.int8)
rtn = RTN(cfg)
rtn.apply(net)
rtn.convert(net)
ms.load_checkpoint('./simplenet_rtn.ckpt', net)
inputs = Tensor(np.ones((5, 5), dtype=np.float32), dtype=ms.dtype.float32)
output = net(inputs)
print(output)
Use PTQConfig to set mode to deployment mode (PTQMode.DEPLOY), set the backend to Ascend, and quantize the weights to 8 bits (int8). For details, see PTQConfig Configuration Description.
Use the apply and convert APIs to convert the network into a quantized network. In the deployment phase, no statistics are collected and no quantization parameters are computed; only the network structure is converted into a quantized network.
Load the quantized weights into the quantized network and run inference.
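Conceptually, an int8 weight-only (A16W8) linear layer computes y = x · (q · scale)ᵀ + b. The pure-Python sketch below shows that arithmetic, assuming one scale per output channel and the transpose_b=True layout used by SimpleNet (weight shape (out_channels, in_channels)). The function name is hypothetical, and real Ascend kernels fuse these steps differently.

```python
# Hypothetical sketch of the arithmetic behind a weight-only int8 linear layer.

def a16w8_linear(x, q_weight, scales, bias):
    """x: float input rows; q_weight: int8 rows (out_channels x in_channels);
    scales: one float per output channel; bias: one float per output channel."""
    outputs = []
    for row in x:
        out = []
        for q_row, s, b in zip(q_weight, scales, bias):
            # Accumulate x * q, then apply the channel scale once; this equals
            # dequantizing every weight first, sum(x * (q * s)).
            acc = sum(xi * qi for xi, qi in zip(row, q_row))
            out.append(acc * s + b)
        outputs.append(out)
    return outputs


q_weight = [[50, -127, 2], [127, 0, -15]]  # int8 values
scales = [0.01, 0.02]
bias = [0.0, 1.0]
print(a16w8_linear([[1.0, 1.0, 1.0]], q_weight, scales, bias))
```

Factoring the scale out of the sum is what lets the weights stay in int8 in memory while activations remain in floating point.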
PTQConfig Configuration Description
You can customize PTQConfig to enable different quantization capabilities. For details about PTQConfig, see its API documentation. The following are configuration examples for some algorithms.
In this notation, A indicates activation, W indicates weight, C indicates KVCache, and the number indicates the bit width. For example, A16W8 means that activations are kept in float16 and weights are quantized to int8.
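The notation maps directly onto the three dtype fields used in the PTQConfig examples that follow: 16 bits means the tensor stays in floating point (None), and 8 bits means int8. The helper below is a hypothetical illustration, not a Golden Stick API; it returns dtype names as strings rather than real MindSpore dtypes and only covers the 8-bit and 16-bit cases.

```python
import re

# Hypothetical helper: notation -> PTQConfig-style field values (as strings).
FIELDS = {"A": "act_quant_dtype", "W": "weight_quant_dtype", "C": "kvcache_quant_dtype"}
BITS_TO_DTYPE = {16: None, 8: "int8"}  # only the cases this sketch handles


def notation_to_fields(name):
    tokens = re.findall(r"([AWC])(\d+)", name)
    # Reject anything that is not a clean sequence of letter+bits tokens.
    if not tokens or "".join(k + b for k, b in tokens) != name:
        raise ValueError(f"unrecognized notation: {name!r}")
    fields = dict.fromkeys(FIELDS.values())  # absent components stay None
    for key, bits in tokens:
        bits = int(bits)
        if bits not in BITS_TO_DTYPE:
            raise ValueError(f"{bits}-bit {key} is outside this sketch")
        fields[FIELDS[key]] = BITS_TO_DTYPE[bits]
    return fields


print(notation_to_fields("A16W8"))
print(notation_to_fields("A8W8"))
```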
A16W8 quantization
from mindspore import dtype as msdtype
from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType
ptq_config = PTQConfig(weight_quant_dtype=msdtype.int8, act_quant_dtype=None, kvcache_quant_dtype=None,
                       outliers_suppression=OutliersSuppressionType.NONE)
A8W8 quantization
A8W8 quantization is based on the SmoothQuant algorithm. The outliers_suppression field of PTQConfig specifies whether to apply the smoothing operation.
from mindspore import dtype as msdtype
from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType
ptq_config = PTQConfig(weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8, kvcache_quant_dtype=None,
                       outliers_suppression=OutliersSuppressionType.SMOOTH)
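The idea behind SmoothQuant's smoothing step is to migrate activation outliers into the weights with a per-input-channel factor s_j = max|X_j|^α / max|W_j|^(1-α): activations are divided by s_j and the matching weight columns are multiplied by s_j, which leaves the matrix product unchanged. The pure-Python sketch below assumes α = 0.5 and toy data; how Golden Stick computes and applies the factors internally may differ.

```python
# Sketch of SmoothQuant-style smoothing, assuming alpha = 0.5.
# Weights use the (out_channels, in_channels) layout.

def smooth_factors(x_rows, w_rows, alpha=0.5):
    n_in = len(w_rows[0])
    act_max = [max(abs(r[j]) for r in x_rows) for j in range(n_in)]
    w_max = [max(abs(r[j]) for r in w_rows) for j in range(n_in)]
    # Assumes no input channel is all zeros in the weights.
    return [(a ** alpha) / (w ** (1 - alpha)) for a, w in zip(act_max, w_max)]


def matmul_t(x_rows, w_rows):
    # x @ w^T
    return [[sum(a * b for a, b in zip(xr, wr)) for wr in w_rows] for xr in x_rows]


x = [[1.0, 40.0], [-2.0, 30.0]]  # input channel 1 carries an activation outlier
w = [[0.5, 0.1], [-0.3, 0.2]]
s = smooth_factors(x, w)
x_s = [[v / sj for v, sj in zip(r, s)] for r in x]  # smoothed activations
w_s = [[v * sj for v, sj in zip(r, s)] for r in w]  # compensated weights
print(matmul_t(x, w))
print(matmul_t(x_s, w_s))
```

Both products print the same values (up to floating-point rounding), while the largest activation in channel 1 drops from 40 to about 2.8, which makes the activations easier to quantize to int8.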
KVCache int8 quantization
from mindspore import dtype as msdtype
from mindspore_gs.ptq import PTQConfig, OutliersSuppressionType
ptq_config = PTQConfig(weight_quant_dtype=None, act_quant_dtype=None, kvcache_quant_dtype=msdtype.int8,
                       outliers_suppression=OutliersSuppressionType.NONE)
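A back-of-the-envelope calculation shows why KVCache int8 quantization matters: the cache stores one key and one value vector per layer and per token, so its size scales with sequence length and batch size. The shape below (32 layers, hidden size 4096, resembling Llama2-7B) is an assumption for illustration, and the estimate ignores grouped-query attention and the small extra storage for quantization parameters.

```python
# Rough KV cache size estimate; model shape is an illustrative assumption.

def kv_cache_bytes(num_layers, hidden_size, seq_len, batch_size, bytes_per_elem):
    # Factor 2: one key tensor and one value tensor per layer.
    return 2 * num_layers * batch_size * seq_len * hidden_size * bytes_per_elem


fp16 = kv_cache_bytes(32, 4096, seq_len=4096, batch_size=1, bytes_per_elem=2)
int8 = kv_cache_bytes(32, 4096, seq_len=4096, batch_size=1, bytes_per_elem=1)
print(fp16 / 2**30, int8 / 2**30)  # sizes in GiB
```

Under these assumptions the cache shrinks from 2 GiB in float16 to 1 GiB in int8 for a single 4096-token sequence, and the saving grows linearly with batch size.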
Case Analysis
Post-training Quantization
The following examples walk through the complete quantization and deployment process for the PTQ algorithm and the RoundToNearest algorithm on the Llama2 network.
PTQ algorithm: supports 8-bit weight quantization, 8-bit full quantization, and KVCache int8 quantization. SmoothQuant can be enabled to improve quantization accuracy, and different quantization algorithms can be combined to improve quantized inference performance.
RoundToNearest algorithm: the simplest 8-bit PTQ algorithm, supporting linear-layer weight quantization and KVCache int8 quantization. This algorithm will be deprecated in the future; use the PTQ algorithm instead.
Quantization-Aware Training
SimQAT algorithm: a basic quantization-aware training algorithm based on pseudo-quantization.
SLB quantization algorithm: a non-linear, low-bit quantization-aware training algorithm.
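Pseudo-quantization, the technique SimQAT is described as building on, can be sketched as follows: the forward pass quantizes and immediately dequantizes values so training sees the rounding error, and the backward pass commonly uses a straight-through estimator (STE) because rounding has zero gradient almost everywhere. This pure-Python sketch assumes a symmetric int8 grid with a fixed scale and a clipped STE; it is not SimQAT's actual code.

```python
# Sketch of fake quantization with a straight-through estimator (STE).
# Assumes a symmetric int8 grid with a fixed scale.

def fake_quant_forward(values, scale, num_bits=8):
    qmax = 2 ** (num_bits - 1) - 1
    # Quantize then dequantize: outputs stay float but lie on the int8 grid.
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


def fake_quant_backward(values, grad_out, scale, num_bits=8):
    qmax = 2 ** (num_bits - 1) - 1
    limit = qmax * scale
    # STE: treat rounding as identity, so gradients pass through unchanged
    # inside the representable range and are zeroed where values were clipped.
    return [g if -limit <= v <= limit else 0.0 for v, g in zip(values, grad_out)]


vals = [0.013, -0.5, 3.0]
print(fake_quant_forward(vals, scale=0.01))
print(fake_quant_backward(vals, [1.0, 1.0, 1.0], scale=0.01))
```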
Pruning
SCOP pruning algorithm example: a structured weight pruning algorithm.