mindspore_gs.quantization.SimulatedQuantizationAwareTraining
- class mindspore_gs.quantization.SimulatedQuantizationAwareTraining(config=None)[source]
Basic implementation of simulated quantization aware training. The algorithm uses fake quantizers to simulate the loss introduced by quantized computation and updates the network parameters through backpropagation, so that the parameters adapt to that loss. For more details, see A White Paper on Neural Network Quantization <https://arxiv.org/pdf/2106.08295.pdf>.
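To make the idea of a fake quantizer concrete, here is a minimal sketch of asymmetric uniform quantization followed by immediate dequantization, in plain Python. The function name `fake_quantize` and its parameters are illustrative assumptions and are not part of the mindspore_gs API; the library's own quantizer may differ in detail.

```python
def fake_quantize(x, min_val, max_val, num_bits=8):
    """Quantize x onto an unsigned integer grid, then map it back to float."""
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (max_val - min_val) / (qmax - qmin)   # float width of one integer step
    zero_point = qmin - round(min_val / scale)    # integer that represents 0.0
    q = round(x / scale) + zero_point             # quantize
    q = max(qmin, min(qmax, q))                   # clamp to the representable range
    return (q - zero_point) * scale               # dequantize


# With the range [-128, 127] the scale is exactly 1.0, so the rounding is easy to see.
print(fake_quantize(3.4, -128.0, 127.0))    # 3.0
print(fake_quantize(200.0, -128.0, 127.0))  # 127.0 (clamped)
```

The output stays in floating point but only takes values on the integer grid, which is how the rounding and clamping error becomes visible to the training loss.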
- Parameters
config (dict) –
Stores the attributes for quantization aware training: keys are attribute names and values are attribute values. The supported attributes are listed below:
quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized during training and evaluation. The first element applies to the data flow (activations) and the second element to the weights. Default: (0, 0).
quant_dtype (Union[QuantDtype, list, tuple]): Data type used to quantize weights and activations. The first element applies to the data flow and the second element to the weights. In practical quantized inference scenarios, the precision supported by the target hardware must be taken into account. Default: (QuantDtype.INT8, QuantDtype.INT8).
per_channel (Union[bool, list, tuple]): Quantization granularity, per layer or per channel. If True, quantization is per channel; otherwise it is per layer. The first element applies to the data flow and the second element to the weights. The first element must currently be False. Default: (False, False).
symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric. If True, quantization is symmetric; otherwise it is asymmetric. The first element applies to the data flow and the second element to the weights. Default: (False, False).
narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses a narrow range. The first element applies to the data flow and the second element to the weights. Default: (False, False).
enable_fusion (bool): Whether to apply fusion before applying quantization. Default: False.
freeze_bn (int): Number of steps after which the BatchNorm OP parameters are fixed to the global mean and variance. Default: 10000000.
bn_fold (bool): Whether to use bn fold ops for simulation inference operation. Default: False.
one_conv_fold (bool): Whether to use one conv bn fold ops for simulation inference operation. Default: True.
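As a rough illustration of how symmetric and narrow_range can shape the integer grid, the sketch below computes the integer bounds under one common convention: symmetric quantization uses a signed grid, asymmetric quantization an unsigned one, and narrow range drops the lowest value. This convention and the helper name `quant_int_range` are assumptions for illustration; this page does not state how mindspore_gs maps these flags to integer bounds.

```python
def quant_int_range(num_bits=8, symmetric=False, narrow_range=False):
    """Return (qmin, qmax) for the integer grid under the assumed convention."""
    if symmetric:
        # Signed grid centred on zero, e.g. [-128, 127] for 8 bits.
        qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    else:
        # Unsigned grid, e.g. [0, 255] for 8 bits.
        qmin, qmax = 0, 2 ** num_bits - 1
    if narrow_range:
        # Assumed meaning of narrow range: give up the lowest integer value.
        qmin += 1
    return qmin, qmax


print(quant_int_range())                                    # (0, 255)
print(quant_int_range(symmetric=True))                      # (-128, 127)
print(quant_int_range(symmetric=True, narrow_range=True))   # (-127, 127)
```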
- Raises
TypeError – If bn_fold, one_conv_fold or enable_fusion is not bool.
TypeError – If freeze_bn is not int.
TypeError – If quant_delay is not int, or any element of quant_delay is not int.
TypeError – If quant_dtype is not QuantDtype, or any element of quant_dtype is not QuantDtype.
TypeError – If per_channel is not bool, or any element of per_channel is not bool.
TypeError – If symmetric is not bool, or any element of symmetric is not bool.
TypeError – If narrow_range is not bool, or any element of narrow_range is not bool.
ValueError – If freeze_bn is less than 0.
ValueError – If the length of quant_delay, quant_dtype, per_channel, symmetric or narrow_range is greater than 2.
ValueError – If quant_delay is less than 0, or any element of quant_delay is less than 0.
TypeError – If quant_dtype is not QuantDtype.INT8, or any element of quant_dtype is not QuantDtype.INT8.
ValueError – If per_channel is True, or the first element of per_channel is True.
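The checks listed above can be pictured with a small validation sketch in plain Python. The helpers `normalize_pair`, `check_quant_delay` and `check_per_channel` are hypothetical and not part of mindspore_gs. The sketch assumes that a scalar applies to both activations and weights and that sequences of one or two elements are accepted, as the two-element defaults suggest.

```python
def normalize_pair(name, value, expected_type):
    """Hypothetical helper: turn a scalar or sequence into an (activation, weight) pair."""
    items = list(value) if isinstance(value, (list, tuple)) else [value]
    if not 1 <= len(items) <= 2:
        raise ValueError(f"{name} must have 1 or 2 elements, got {len(items)}")
    for item in items:
        # bool is a subclass of int, so reject it explicitly for int fields.
        wrong_bool = expected_type is int and isinstance(item, bool)
        if wrong_bool or not isinstance(item, expected_type):
            raise TypeError(f"{name} elements must be {expected_type.__name__}")
    return items[0], items[-1]


def check_quant_delay(quant_delay):
    act, weight = normalize_pair("quant_delay", quant_delay, int)
    if act < 0 or weight < 0:
        raise ValueError("quant_delay must not be negative")
    return act, weight


def check_per_channel(per_channel):
    act, weight = normalize_pair("per_channel", per_channel, bool)
    if act:
        # Per-channel quantization of activations is documented as unsupported.
        raise ValueError("per_channel for activations must be False")
    return act, weight


print(check_quant_delay(900))             # (900, 900)
print(check_per_channel((False, True)))   # (False, True)
```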
- Supported Platforms:
GPU
Examples
>>> from mindspore_gs.quantization.simulated_quantization import SimulatedQuantizationAwareTraining
>>> from mindspore import nn
>>> class NetToQuant(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToQuant, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be quantized
>>> net = NetToQuant()
>>> ## 2) Define SimQAT Algorithm
>>> simulated_quantization = SimulatedQuantizationAwareTraining()
>>> ## 3) Use set functions to change config
>>> simulated_quantization.set_enable_fusion(True)
>>> simulated_quantization.set_bn_fold(False)
>>> simulated_quantization.set_act_quant_delay(900)
>>> simulated_quantization.set_weight_quant_delay(900)
>>> simulated_quantization.set_act_per_channel(False)
>>> simulated_quantization.set_weight_per_channel(True)
>>> simulated_quantization.set_act_narrow_range(False)
>>> simulated_quantization.set_weight_narrow_range(False)
>>> ## 4) Apply SimQAT algorithm to origin network
>>> net_qat = simulated_quantization.apply(net)
>>> ## 5) Print network and check the result. Conv2d and Dense should be transformed to QuantizeWrapperCells.
>>> ## Since we set enable_fusion to True and bn_fold to False, the Conv2d and BatchNorm2d Cells are
>>> ## fused and converted to Conv2dBnWithoutFoldQuant.
>>> ## Since we set act_quant_delay to 900, the quant_delay value of _input_quantizer and _output_quantizer
>>> ## is set to 900.
>>> ## Since we set weight_quant_delay to 900, the quant_delay value of fake_quant_weight is set to 900.
>>> ## Since we set act_per_channel to False, the per_channel value of _input_quantizer and
>>> ## _output_quantizer is set to False.
>>> ## Since we set weight_per_channel to True, the per_channel value of fake_quant_weight is set to
>>> ## True.
>>> ## Since we set act_narrow_range to False, the narrow_range value of _input_quantizer and
>>> ## _output_quantizer is set to False.
>>> ## Since we set weight_narrow_range to False, the narrow_range value of fake_quant_weight is set to
>>> ## False.
>>> print(net_qat)
NetToQuantOpt<
  (_handler): NetToQuant<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=_handler.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=_handler.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=_handler.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=_handler.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
    >
  (Conv2dBnWithoutFoldQuant): QuantizeWrapperCell<
    handler: in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, input quantizer: bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900, output quantizer: bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900
    (_handler): Conv2dBnWithoutFoldQuant<
      in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False
      (fake_quant_weight): SimulatedFakeQuantizerPerChannel<bit_num=8, symmetric=True, narrow_range=False, ema=False(0.999), per_channel=True(0, 6), quant_delay=900>
      (batchnorm): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.0030000000000000027, gamma=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=Conv2dBnWithoutFoldQuant._handler.batchnorm.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
      >
    (_input_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    (_output_quantizer): SimulatedFakeQuantizerPerLayer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
    >
  >
- apply(network: Cell)[source]
Apply the SimQAT algorithm to the network. The following steps make the network ready for quantization aware training:
Fuse certain cells in the network using the pattern engine, which is defined by the net policy. Default fuse patterns: Conv2d + BatchNorm2d + ReLU, Conv2d + ReLU, Dense + BatchNorm2d + ReLU, Dense + BatchNorm2d, Dense + ReLU.
Propagate the LayerPolicies defined in the NetPolicy through the network.
Remove redundant fake quantizers, that is, cases where two or more fake quantizers exist on one tensor.
Apply the LayerPolicies to convert normal cells to QuantizeWrapperCells. The actual fake quantizers are inserted into the network in this step.
- Parameters
network (Cell) – Network to be quantized.
- Returns
Quantized network.
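The fusion step can be pictured with a deliberately simplified sketch: a greedy matcher that groups a flat list of layer type names using the default fuse patterns listed above. The real pattern engine works on the network graph rather than on a list of strings, so the function `fuse_layers` and its data layout are hypothetical.

```python
# Default fuse patterns from the description above, longest first so that
# the greedy match prefers the three-cell patterns.
FUSE_PATTERNS = [
    ("Conv2d", "BatchNorm2d", "ReLU"),
    ("Dense", "BatchNorm2d", "ReLU"),
    ("Conv2d", "ReLU"),
    ("Dense", "BatchNorm2d"),
    ("Dense", "ReLU"),
]


def fuse_layers(layers):
    """Greedily group a flat list of layer type names into fused units."""
    fused, i = [], 0
    while i < len(layers):
        for pattern in FUSE_PATTERNS:
            if tuple(layers[i:i + len(pattern)]) == pattern:
                fused.append("+".join(pattern))
                i += len(pattern)
                break
        else:
            # No pattern starts at this position: keep the layer as it is.
            fused.append(layers[i])
            i += 1
    return fused


print(fuse_layers(["Conv2d", "BatchNorm2d", "ReLU", "Dense", "ReLU", "Flatten"]))
# ['Conv2d+BatchNorm2d+ReLU', 'Dense+ReLU', 'Flatten']
```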
- set_act_narrow_range(act_narrow_range)[source]
Set the value of act_narrow_range in the quantization aware training config.
- set_act_per_channel(act_per_channel)[source]
Set the value of act_per_channel in the quantization aware training config.
- Parameters
act_per_channel (bool) – Quantization granularity, per layer or per channel. If True, quantization is per channel; otherwise it is per layer. Only False is currently supported.
- Raises
TypeError – If act_per_channel is not bool.
ValueError – If act_per_channel is True; only False is currently supported.
- set_act_quant_delay(act_quant_delay)[source]
Set the value of act_quant_delay in the quantization aware training config.
- Parameters
act_quant_delay (int) – Number of steps after which activations are quantized during training and evaluation.
- Raises
TypeError – If act_quant_delay is not int.
ValueError – If act_quant_delay is less than 0.
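The effect of a quantization delay can be sketched with a toy quantizer that passes values through unchanged until the delay has elapsed and rounds them to a fixed grid afterwards. The class `DelayedFakeQuant`, its `step_size` parameter and the exact step at which quantization starts are assumptions for illustration, not the mindspore_gs implementation.

```python
class DelayedFakeQuant:
    """Toy quantizer that is a pass-through for the first `quant_delay` calls."""

    def __init__(self, quant_delay=0, step_size=0.5):
        if isinstance(quant_delay, bool) or not isinstance(quant_delay, int):
            raise TypeError("quant_delay must be int")
        if quant_delay < 0:
            raise ValueError("quant_delay must not be negative")
        self.quant_delay = quant_delay
        self.step_size = step_size  # hypothetical fixed spacing of the grid
        self.step = 0               # number of calls seen so far

    def __call__(self, x):
        active = self.step >= self.quant_delay
        self.step += 1
        if not active:
            return x  # still inside the delay window: no quantization
        return round(x / self.step_size) * self.step_size


quantizer = DelayedFakeQuant(quant_delay=2, step_size=0.5)
print([quantizer(0.3) for _ in range(3)])  # [0.3, 0.3, 0.5]
```

Delaying quantization in this way lets the network train in full precision for a while before the quantization error is introduced.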
- set_act_quant_dtype(act_quant_dtype)[source]
Set the value of act_quant_dtype in the quantization aware training config.
- set_act_symmetric(act_symmetric)[source]
Set the value of act_symmetric in the quantization aware training config.
- set_enable_fusion(enable_fusion)[source]
Set the value of enable_fusion in the quantization aware training config.
- set_freeze_bn(freeze_bn)[source]
Set the value of freeze_bn in the quantization aware training config.
- Parameters
freeze_bn (int) – Number of steps after which the BatchNorm OP parameters are fixed to the global mean and variance.
- Raises
TypeError – If freeze_bn is not int.
ValueError – If freeze_bn is less than 0.
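One way to picture freeze_bn is a toy one-dimensional batch normalization that normalizes with batch statistics and keeps updating its moving statistics for the first `freeze_bn` steps, then switches to the stored global statistics. The class `ToyBatchNorm`, its `decay` parameter and the update rule are illustrative assumptions and do not reproduce the BatchNorm cells used by mindspore_gs.

```python
import math


class ToyBatchNorm:
    """Toy 1-D batch norm that freezes its statistics after `freeze_bn` steps."""

    def __init__(self, freeze_bn, decay=0.9, eps=1e-5):
        self.freeze_bn = freeze_bn
        self.decay = decay          # hypothetical weight of the old moving statistics
        self.eps = eps
        self.moving_mean = 0.0
        self.moving_var = 1.0
        self.step = 0

    def __call__(self, batch):
        if self.step < self.freeze_bn:
            # Before the freeze: use batch statistics and update the moving ones.
            mean = sum(batch) / len(batch)
            var = sum((v - mean) ** 2 for v in batch) / len(batch)
            self.moving_mean = self.decay * self.moving_mean + (1 - self.decay) * mean
            self.moving_var = self.decay * self.moving_var + (1 - self.decay) * var
        else:
            # After the freeze: normalize with the fixed global statistics.
            mean, var = self.moving_mean, self.moving_var
        self.step += 1
        return [(v - mean) / math.sqrt(var + self.eps) for v in batch]


bn = ToyBatchNorm(freeze_bn=1)
first = bn([1.0, 3.0])    # normalized with the batch mean 2.0 and variance 1.0
second = bn([1.0, 3.0])   # normalized with the frozen moving statistics
```

After the freeze the same input produces a different output because the stored statistics, not the batch statistics, are used.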
- set_one_conv_fold(one_conv_fold)[source]
Set the value of one_conv_fold in the quantization aware training config.
- set_weight_narrow_range(weight_narrow_range)[source]
Set the value of weight_narrow_range in the quantization aware training config.
- set_weight_per_channel(weight_per_channel)[source]
Set the value of weight_per_channel in the quantization aware training config.
- set_weight_quant_delay(weight_quant_delay)[source]
Set the value of weight_quant_delay in the quantization aware training config.
- Parameters
weight_quant_delay (int) – Number of steps after which weights are quantized during training and evaluation.
- Raises
TypeError – If weight_quant_delay is not int.
ValueError – If weight_quant_delay is less than 0.
- set_weight_quant_dtype(weight_quant_dtype)[source]
Set the value of weight_quant_dtype in the quantization aware training config.