mindspore_gs.quantization.SimulatedQuantizationAwareTraining
- class mindspore_gs.quantization.SimulatedQuantizationAwareTraining(config=None)[source]
Basic implementation of simulated quantization aware training. This algorithm uses fake quantizers to simulate the loss introduced by quantized computation, and network parameters are updated through backpropagation so that they can better adapt to that loss. See more details in A White Paper on Neural Network Quantization.
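The core idea can be pictured with a plain-Python sketch (illustrative only, not the library's implementation; the function name `fake_quant` is ours): a fake quantizer maps a float to its nearest representable quantized value and back, so the forward pass already carries the rounding and clipping error that real int8 inference will introduce.

```python
def fake_quant(x, scale, zero_point, qmin=-128, qmax=127):
    """Simulate int8 quantization: quantize to the integer grid, then dequantize.

    The returned float carries the rounding/clipping error that a real
    int8 kernel would introduce, so training can adapt to it.
    """
    q = round(x / scale) + zero_point      # quantize
    q = max(qmin, min(qmax, q))            # clip to the signed 8-bit range
    return (q - zero_point) * scale        # dequantize back to float

print(fake_quant(1.5, scale=0.5, zero_point=0))     # 1.5 (already on the grid)
print(fake_quant(1.3, scale=0.5, zero_point=0))     # 1.5 (1.3/0.5=2.6 rounds to 3)
print(fake_quant(1000.0, scale=0.5, zero_point=0))  # 63.5 (clipped at qmax=127)
```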
- Parameters
config (dict) –
Store attributes for quantization aware training; keys are attribute names and values are attribute values. Default: None. Supported attributes are listed below:
quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized during train and eval. The first element represents activations and the second element represents weights. Default: (0, 0).
quant_dtype (Union[QuantDtype, list, tuple]): The target data type for quantization. It is necessary to consider the precision support of hardware devices when setting quant_dtype. The first element represents activations and the second element represents weights. Default: (QuantDtype.INT8, QuantDtype.INT8).
per_channel (Union[bool, list, tuple]): Quantization granularity: per channel if True, otherwise per layer. The first element represents activations and the second element represents weights; the first element must be False for now. Default: (False, False).
symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric (True) or asymmetric (False). The first element represents activations and the second element represents weights. Default: (False, False).
narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range. The first element represents activations and the second element represents weights. Default: (False, False).
enable_fusion (bool): Whether to apply fusion before quantization. Default: False.
freeze_bn (int): Number of steps after which BatchNorm OP parameters are fixed to the global mean and variance. Default: 10000000.
bn_fold (bool): Whether to use bn fold ops for the simulated inference operation. Default: False.
one_conv_fold (bool): Whether to use one conv bn fold ops for the simulated inference operation. Default: True.
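Each pair-valued attribute accepts either a scalar, which applies to both activations and weights, or a 2-element list/tuple read as (activation, weight). A plain-Python sketch of that convention (illustrative only; `normalize_pair` is not the library's parsing code):

```python
def normalize_pair(value):
    """Expand a config attribute into an (activation, weight) pair.

    A scalar applies to both; a 2-element list/tuple is read as
    (activation, weight). Illustrative sketch of the documented
    Union[..., list, tuple] convention, not library code.
    """
    if isinstance(value, (list, tuple)):
        if len(value) > 2:
            raise ValueError("at most 2 elements: (activation, weight)")
        return tuple(value)
    return (value, value)

print(normalize_pair(900))       # (900, 900): same delay for both
print(normalize_pair((0, 900)))  # (0, 900): activations first, weights second
```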
- Raises
TypeError – If bn_fold, one_conv_fold or enable_fusion is not bool.
TypeError – If freeze_bn is not int.
TypeError – If quant_delay is not int, or every element of quant_delay is not int.
TypeError – If quant_dtype is not QuantDtype, or every element of quant_dtype is not QuantDtype.
TypeError – If per_channel is not bool, or every element of per_channel is not bool.
TypeError – If symmetric is not bool, or every element of symmetric is not bool.
TypeError – If narrow_range is not bool, or every element of narrow_range is not bool.
ValueError – If freeze_bn is less than 0.
ValueError – If the length of quant_delay, quant_dtype, per_channel, symmetric or narrow_range is more than 2.
ValueError – If quant_delay is less than 0, or any element of quant_delay is less than 0.
ValueError – If quant_dtype is not QuantDtype.INT8, or any element of quant_dtype is not QuantDtype.INT8.
ValueError – If per_channel is True, or the first element of per_channel is True.
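To make the symmetric and narrow_range options concrete, here is the textbook arithmetic behind them (an illustrative sketch; the function names are ours, not the library's): symmetric quantization pins the zero point at 0, and narrow_range shrinks the signed 8-bit range from [-128, 127] to [-127, 127] so that negating a quantized value stays representable.

```python
def int8_limits(narrow_range):
    """Return the (qmin, qmax) integer range for signed 8-bit quantization."""
    return (-127 if narrow_range else -128, 127)

def scale_zero_point(xmin, xmax, symmetric, narrow_range):
    """Compute scale and zero point for the observed float range [xmin, xmax].

    Textbook formulas for illustration, not the library's implementation.
    """
    qmin, qmax = int8_limits(narrow_range)
    if symmetric:
        # Range symmetric around 0; the zero point is fixed at 0.
        bound = max(abs(xmin), abs(xmax))
        return bound / qmax, 0
    scale = (xmax - xmin) / (qmax - qmin)
    zero_point = round(qmin - xmin / scale)
    return scale, zero_point

print(scale_zero_point(-1.0, 1.0, symmetric=True, narrow_range=True))
print(scale_zero_point(0.0, 2.55, symmetric=False, narrow_range=False))
```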
- Supported Platforms:
GPU
Examples
>>> from mindspore_gs.quantization import SimulatedQuantizationAwareTraining
>>> from mindspore import nn
>>> class NetToQuant(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToQuant, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be quantized
>>> net = NetToQuant()
>>> ## 2) Define SimQAT Algorithm
>>> simulated_quantization = SimulatedQuantizationAwareTraining()
>>> ## 3) Use set functions to change config
>>> simulated_quantization.set_enable_fusion(True)
>>> simulated_quantization.set_bn_fold(False)
>>> simulated_quantization.set_act_quant_delay(900)
>>> simulated_quantization.set_weight_quant_delay(900)
>>> simulated_quantization.set_act_per_channel(False)
>>> simulated_quantization.set_weight_per_channel(True)
>>> simulated_quantization.set_act_narrow_range(False)
>>> simulated_quantization.set_weight_narrow_range(False)
>>> ## 4) Apply SimQAT algorithm to origin network
>>> net_qat = simulated_quantization.apply(net)
>>> ## 5) Print network and check the result. Conv2d and Dense should be transformed to QuantizeWrapperCells.
>>> ## Since we set enable_fusion to True and bn_fold to False, the Conv2d and BatchNorm2d Cells are
>>> ## fused and converted to Conv2dBnWithoutFoldQuant.
>>> ## Since we set act_quant_delay to 900, the quant_delay value of _input_quantizer and _output_quantizer
>>> ## is set to 900.
>>> ## Since we set weight_quant_delay to 900, the quant_delay value of fake_quant_weight is set to 900.
>>> ## Since we set act_per_channel to False, the per_channel value of _input_quantizer and
>>> ## _output_quantizer is set to False.
>>> ## Since we set weight_per_channel to True, the per_channel value of fake_quant_weight is set to True.
>>> ## Since we set act_narrow_range to False, the narrow_range value of _input_quantizer and
>>> ## _output_quantizer is set to False.
>>> ## Since we set weight_narrow_range to False, the narrow_range value of fake_quant_weight is set to False.
>>> print(net_qat)
NetToQuantOpt<
  (_handler): NetToQuant<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=_handler.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=_handler.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=_handler.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=_handler.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
    >
  (Conv2dBnWithoutFoldQuant): QuantCell<
    handler: in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, input quantizer: bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900, output quantizer: bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900
    (_handler): Conv2dBnWithoutFoldQuant<
      in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False
      (fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 6), quant_delay=900>
      (batchnorm): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.0030000000000000027, gamma=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
      >
    (_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  >
- apply(network: Cell, **kwargs)[source]
Apply the SimQAT algorithm on the network. The following steps make the network ready for quantization aware training:
Fuse certain cells in the network using the pattern engine defined by the net policy. Default fusion patterns: Conv2d + BatchNorm2d + ReLU, Conv2d + ReLU, Dense + BatchNorm2d + ReLU, Dense + BatchNorm2d, Dense + ReLU.
Propagate the LayerPolicies defined in the NetPolicy through the network.
Reduce redundant fake quantizers, i.e. cases where two or more fake quantizers end up on one tensor.
Apply LayerPolicies to convert normal cells to QuantCells; the real fake quantizers are inserted into the network in this step.
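The third step, reducing redundant fake quantizers, can be pictured with a small sketch (hypothetical data structures; the real pattern engine is more involved): when policy propagation attaches several fake quantizers to the same tensor, only one is kept.

```python
def reduce_redundant_quantizers(tensor_quantizers):
    """Keep one fake quantizer per tensor.

    tensor_quantizers maps a tensor name to the list of fake quantizers
    attached to it after policy propagation; when two or more land on
    the same tensor, only the first is kept. Illustrative sketch only.
    """
    return {name: quants[0] for name, quants in tensor_quantizers.items() if quants}

# The output of conv feeds relu, so both a conv output quantizer and a
# relu input quantizer ended up on the same tensor (names are made up).
placed = {
    "conv.output": ["q_out_conv", "q_in_relu"],
    "relu.output": ["q_out_relu"],
}
print(reduce_redundant_quantizers(placed))
# {'conv.output': 'q_out_conv', 'relu.output': 'q_out_relu'}
```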
- Parameters
network (Cell) – Network to be quantized.
- Returns
Quantized network.
- convert(net_opt: Cell, ckpt_path='')[source]
Define how to convert a compressed network to a standard network before exporting to MindIR.
- Parameters
net_opt (Cell) – Network to be converted which is transformed by SimulatedQuantizationAwareTraining.apply.
ckpt_path (str) – Path to the checkpoint file for net_opt. Default: "", which means no checkpoint file is loaded into net_opt.
- Returns
An instance of Cell represents converted network.
- Raises
TypeError – If net_opt is not Cell.
TypeError – If ckpt_path is not string.
ValueError – If ckpt_path is not empty but is invalid.
RuntimeError – If loading ckpt_path fails.
- set_act_narrow_range(act_narrow_range)[source]
Set value of act_narrow_range of quantization aware training config.
- set_act_per_channel(act_per_channel)[source]
Set value of act_per_channel of quantization aware training config.
- Parameters
act_per_channel (bool) – Quantization granularity: per channel if True, otherwise per layer. Only False is supported for now.
- Raises
TypeError – If act_per_channel is not bool.
ValueError – If act_per_channel is not False; only False is supported for now.
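The difference between per-layer and per-channel granularity is just where the scale is computed; a minimal sketch (illustrative, not library code) shows why per-channel is attractive for weights: a large-magnitude channel no longer inflates the scale, and thus coarsens the grid, of every other channel.

```python
def per_layer_scale(weights, qmax=127):
    """One scale for the whole weight tensor (rows = output channels)."""
    return max(abs(w) for row in weights for w in row) / qmax

def per_channel_scales(weights, qmax=127):
    """One scale per output channel (row)."""
    return [max(abs(w) for w in row) / qmax for row in weights]

w = [[0.1, -0.2], [12.7, 6.0]]  # channel 1 has a much larger range
print(per_layer_scale(w))       # ~0.1: dominated by channel 1
print(per_channel_scales(w))    # ~[0.0016, 0.1]: channel 0 keeps a fine grid
```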
- set_act_quant_delay(act_quant_delay)[source]
Set value of act_quant_delay of quantization aware training config.
- Parameters
act_quant_delay (int) – Number of steps after which activation is quantized during train and eval.
- Raises
TypeError – If act_quant_delay is not int.
ValueError – If act_quant_delay is less than 0.
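quant_delay gates the fake quantizer by training step, letting the network train in full precision for a warm-up period before quantization error is introduced. A sketch of the behavior (`snap` is a toy quantizer, not the library's):

```python
def maybe_fake_quant(x, step, quant_delay, quantize):
    """Pass x through untouched until quant_delay steps have elapsed,
    then apply the fake quantizer. Illustrative sketch of quant_delay."""
    if step < quant_delay:
        return x
    return quantize(x)

snap = lambda x: round(x * 2) / 2  # toy quantizer: snap to a 0.5 grid
print(maybe_fake_quant(1.3, step=100, quant_delay=900, quantize=snap))  # 1.3
print(maybe_fake_quant(1.3, step=900, quant_delay=900, quantize=snap))  # 1.5
```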
- set_act_quant_dtype(act_quant_dtype)[source]
Set value of act_quant_dtype of quantization aware training config.
- Parameters
act_quant_dtype (QuantDtype) – Datatype used to quantize activations.
- Raises
TypeError – If act_quant_dtype is not QuantDtype.
ValueError – If act_quant_dtype is not QuantDtype.INT8; only INT8 is supported for now.
- set_act_symmetric(act_symmetric)[source]
Set value of act_symmetric of quantization aware training config.
- set_enable_fusion(enable_fusion)[source]
Set value of enable_fusion of quantization aware training config.
- set_freeze_bn(freeze_bn)[source]
Set value of freeze_bn of quantization aware training config.
- Parameters
freeze_bn (int) – Number of steps after which BatchNorm OP parameters are fixed to the global mean and variance.
- Raises
TypeError – If freeze_bn is not int.
ValueError – If freeze_bn is less than 0.
- set_one_conv_fold(one_conv_fold)[source]
Set value of one_conv_fold of quantization aware training config.
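The bn_fold and one_conv_fold options simulate folding BatchNorm into the preceding layer so that inference needs no separate BatchNorm op. The standard single-channel fold looks like this (a textbook sketch assuming affine BN with running statistics; not the library's implementation):

```python
import math

def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm parameters into a preceding layer's weight and bias.

    Standard textbook fold for one channel:
      BN(w*x + b) == folded_w*x + folded_b  for any x.
    """
    std = math.sqrt(var + eps)
    folded_w = weight * gamma / std
    folded_b = (bias - mean) * gamma / std + beta
    return folded_w, folded_b

# Verify the folded layer matches conv-then-BN on a sample input.
w, b = fold_bn(weight=2.0, bias=1.0, gamma=0.5, beta=0.1, mean=0.3, var=4.0)
x = 1.7
bn_out = (2.0 * x + 1.0 - 0.3) * 0.5 / math.sqrt(4.0 + 1e-5) + 0.1
assert math.isclose(w * x + b, bn_out)
```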
- set_weight_narrow_range(weight_narrow_range)[source]
Set value of weight_narrow_range of quantization aware training config.
- set_weight_per_channel(weight_per_channel)[source]
Set value of weight_per_channel of quantization aware training config.
- set_weight_quant_delay(weight_quant_delay)[source]
Set value of weight_quant_delay of quantization aware training config.
- Parameters
weight_quant_delay (int) – Number of steps after which weight is quantized during train and eval.
- Raises
TypeError – If weight_quant_delay is not int.
ValueError – If weight_quant_delay is less than 0.
- set_weight_quant_dtype(weight_quant_dtype)[source]
Set value of weight_quant_dtype of quantization aware training config.
- Parameters
weight_quant_dtype (QuantDtype) – Datatype used to quantize weight.
- Raises
TypeError – If weight_quant_dtype is not QuantDtype.
ValueError – If weight_quant_dtype is not QuantDtype.INT8; only INT8 is supported for now.