mindspore_gs.quantization.SlbQuantAwareTraining

查看源文件
class mindspore_gs.quantization.SlbQuantAwareTraining(config=None)[源代码]

SLB(Searching for Low-Bit Weights in Quantized Neural Networks)算法的实现,该算法将量化神经网络中的离散权重作为可搜索的变量,并实现了一种微分方法去精确的实现该搜索。具体来说,是将每个权重表示为在离散值集上的概率分布,通过训练来优化该概率分布,最终具有最高概率的离散值就是搜索的结果,也就是量化的结果。更多详细信息见 Searching for Low-Bit Weights in Quantized Neural Networks

说明

  • 本方法会调用其它接口来设置参数,所以报错时需要参考其他的接口,比如 quant_dtype 要参考 set_weight_quant_dtypeset_act_quant_dtype

参数:
  • config (dict) - 以字典的形式存放用于量化训练的属性,默认值为 None。下面列出了受支持的属性:

    • quant_dtype (Union[QuantDtype, list(QuantDtype), tuple(QuantDtype)]) - 用于量化权重和激活的数据类型。类型为 QuantDtype 或包含两个 QuantDtype 的list或者tuple。如果 quant_dtype 是一个 QuantDtype ,则会被复制成包含两个 QuantDtype 的list。第一个元素表示激活的量化数据类型,第二个元素表示权重的量化数据类型。在实际量化推理场景中需要考虑硬件器件的精度支持。当前权重量化支持1、2、4比特,激活量化支持8比特。默认值:(QuantDtype.INT8, QuantDtype.INT1)

    • enable_act_quant (bool) - 在训练中是否开启激活量化。默认值:False

    • enable_bn_calibration (bool) - 在训练中是否开启BN层矫正功能。默认值:False

    • epoch_size (int) - 训练的总epoch数。

    • has_trained_epoch (int) - 预训练的epoch数。

    • t_start_val (float) - 温度初始值。默认值:1.0

    • t_start_time (float) - 温度开始变化时间。默认值:0.2

    • t_end_time (float) - 温度停止变化时间。默认值:0.6

    • t_factor (float) - 温度变化因子。默认值:1.2

异常:
  • TypeError - quant_dtype 的数据类型不是 QuantDtype ,或者 quant_dtype 存在不是 QuantDtype 的元素。

  • TypeError - enable_act_quant 或者 enable_bn_calibration 的数据类型不是bool。

  • ValueError - quant_dtype 的长度大于2。

  • TypeError - epoch_sizehas_trained_epoch 的数据类型不是int。

  • TypeError - t_start_valt_start_timet_end_timet_factor 的数据类型不是float。

  • ValueError - epoch_size 小于等于0。

  • ValueError - has_trained_epoch 小于0。

  • ValueError - t_start_valt_factor 小于等于0.0。

  • ValueError - t_start_timet_end_time 小于0.0。

  • ValueError - t_start_timet_end_time 大于1.0。

支持平台:

GPU

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore_gs.quantization import SlbQuantAwareTraining
>>> from mindspore.common.dtype import QuantDtype
>>> class NetToQuant(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToQuant, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be quantized
>>> net = NetToQuant()
>>> ## 2) Define SLB QAT-Algorithm
>>> slb_quantization = SlbQuantAwareTraining()
>>> ## 3) Use set functions to change config
>>> ## 3.1) set_weight_quant_dtype is used to set the weight quantization bit, and support QuantDtype.INT4, QuantDtype.INT2,
>>> ## QuantDtype.INT1 now.
>>> slb_quantization.set_weight_quant_dtype(QuantDtype.INT1)
>>> ## 3.2) set_act_quant_dtype is used to set the activation quantization bit, and support QuantDtype.INT8 now.
>>> slb_quantization.set_act_quant_dtype(QuantDtype.INT8)
>>> ## 3.3) set_enable_act_quant is used to set whether apply activation quantization.
>>> slb_quantization.set_enable_act_quant(True)
>>> ## 3.4) set_enable_bn_calibration is used to set whether apply batchnorm calibration.
>>> slb_quantization.set_enable_bn_calibration(True)
>>> ## 3.5) set_epoch_size is used to set the epoch size of training.
>>> slb_quantization.set_epoch_size(100)
>>> ## 3.6) set_has_trained_epoch is used to set the trained epoch size of training.
>>> slb_quantization.set_has_trained_epoch(0)
>>> ## 3.7) set_t_start_val is used to set the initial value of temperature hyperparameters.
>>> slb_quantization.set_t_start_val(1.0)
>>> ## 3.8) set_t_start_time is used to set the fraction of epochs after which temperature hyperparameters starting changing.
>>> slb_quantization.set_t_start_time(0.2)
>>> ## 3.9) set_t_end_time is used to set the fraction of epochs after which temperature hyperparameters stopping changing.
>>> slb_quantization.set_t_end_time(0.6)
>>> ## 3.10) set_t_factor is used to set the multiplicative factor of temperature hyperparameters changing.
>>> slb_quantization.set_t_factor(1.2)
>>> ## 4) Print SLB QAT-Algorithm object and check the config setting result
>>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the value of the attribute weight_quant_dtype is INT1
>>> ## Since we set act_quant_dtype to be QuantDtype.INT8, the value of the attribute weight_quant_dtype is INT8
>>> ## Since we set enable_act_quant to be True, the value of the attribute enable_act_quant is True
>>> ## Since we set enable_bn_calibration to be True, the value of the attribute enable_bn_calibration is True
>>> ## Since we set epoch_size to be 100, the value of the attribute epoch_size is 100
>>> ## Since we set has_trained_epoch to be 0, the value of the attribute has_trained_epoch is 0
>>> ## Since we set t_start_val to be 1.0, the value of the attribute t_start_val is 1.0
>>> ## Since we set t_start_time to be 0.2, the value of the attribute t_start_time is 0.2
>>> ## Since we set t_end_time to be 0.6, the value of the attribute t_end_time is 0.6
>>> ## Since we set t_factor to be 1.2, the value of the attribute t_factor is 1.2
>>> print(slb_quantization)
SlbQuantAwareTraining<weight_quant_dtype=INT1, act_quant_dtype=INT8, enable_act_quant=True, enable_bn_calibration=True, epoch_size=100, has_trained_epoch=0, t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2>
>>> ## 5) Apply SLB QAT-algorithm to origin network
>>> net_qat = slb_quantization.apply(net)
>>> ## 6) Print network and check the result. Conv2d should be transformed to QuantizeWrapperCells.
>>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the bit_num value of fake_quant_weight
>>> ## should be 1, and the weight_bit_num value of Conv2dSlbQuant should be 1.
>>> print(net_qat)
NetToQuantOpt<
(_handler): NetToQuant<
(conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.9, gamma=Parameter(name=bn.gamma, requires_grad=True, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.]), beta=Parameter(name=bn.beta, requires_grad=True, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_mean=Parameter(name=bn.moving_mean, requires_grad=False, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_variance=Parameter(name=bn.moving_variance, requires_grad=False, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.])>
>
(bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.9, gamma=Parameter(name=bn.gamma, requires_grad=True, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.]), beta=Parameter(name=bn.beta, requires_grad=True, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_mean=Parameter(name=bn.moving_mean, requires_grad=False, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_variance=Parameter(name=bn.moving_variance, requires_grad=False, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.])>
(Conv2dSlbQuant): QuantCell<
(_handler): Conv2dSlbQuant<
in_channels=1, out_channels=6, kernel_size=(5, 5), weight_bit_num=1, stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False
(fake_quant_weight): SlbFakeQuantizerPerLayer<bit_num=1>
>
(_input_quantizer): SlbActQuantizer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
(_output_quantizer): SlbActQuantizer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900>
>
>
>>> ## 7) convert a compressed network to a standard network before exporting to MindIR.
>>> net_qat = slb_quantization.convert(net_qat)
>>> data_in = mindspore.Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
>>> file_name = "./conv.mindir"
>>> mindspore.export(net_qat, data_in, file_name=file_name, file_format="MINDIR")
>>> graph = mindspore.load(file_name)
>>> mindspore.nn.GraphCell(graph)
apply(network: Cell)[源代码]

按照下面4个步骤对给定网络应用量化算法,得到带有伪量化节点的网络。

  1. 使用网络策略中定义的模式引擎在给定网络中融合特定的单元。

  2. 传播通过单元定义的层策略。

  3. 当量化器冗余时,减少冗余的伪量化器。

  4. 应用层策略将正常 Cell 转换为 QuantizeWrapperCell

参数:
  • network (Cell) - 即将被量化的网络。

返回:

在原网络定义的基础上,修改需要量化的网络层后生成带有伪量化节点的网络。

callbacks(model: Model, dataset: Dataset)[源代码]

定义SLB量化算法特有的一些callbacks,其中包括用于调节温度因子的callback。

参数:
  • model (Model) - 经过算法修改后的网络构造的mindspore的Model对象。

  • dataset (Dataset) - 加载了特定数据集的Dataset对象。

异常:
  • RuntimeError - epoch_size 没有初始化。

  • RuntimeError - has_trained_epoch 没有初始化。

  • ValueError - epoch_size 小于等于 has_trained_epoch

  • ValueError - t_end_time 小于 t_start_time

  • TypeError - model 的数据类型不是 mindspore.train.Model

  • TypeError - dataset 的数据类型不是 mindspore.dataset.Dataset

返回:

SLB量化算法特有的一些callbacks的列表。

convert(net_opt: Cell, ckpt_path='')[源代码]

定义将SLB量化网络转换成适配MindIR的标准网络的具体实现。

参数:
  • net_opt (Cell) - 经过SLB量化算法量化后的网络。

  • ckpt_path (str) - checkpoint文件的存储路径,为空时不加载,默认值为 ""

异常:
  • TypeError - net_opt 的数据类型不是 mindspore.nn.Cell

  • TypeError - ckpt_path 的数据类型不是str。

  • ValueError - ckpt_path 不为空,但不是有效文件。

  • RuntimeError - ckpt_path 是有效文件,但加载失败。

返回:

能适配MindIR的标准网络。

set_act_quant_dtype(act_quant_dtype=QuantDtype.INT8)[源代码]

设置激活量化的数据类型。

参数:
  • act_quant_dtype (QuantDtype) - 激活量化的数据类型。默认值:QuantDtype.INT8

异常:
  • TypeError - act_quant_dtype 的数据类型不是QuantDtype。

  • ValueError - act_quant_dtype 不是 QuantDtype.INT8

set_enable_act_quant(enable_act_quant=False)[源代码]

设置是否开启激活量化。

参数:
  • enable_act_quant (bool) - 在训练中是否开启激活量化。默认值:False

异常:
  • TypeError - enable_act_quant 的数据类型不是bool。

set_enable_bn_calibration(enable_bn_calibration=False)[源代码]

设置是否开启BatchNorm层矫正功能。

参数:
  • enable_bn_calibration (bool) - 在训练中是否开启BatchNorm层矫正功能。默认值:False

异常:
  • TypeError - enable_bn_calibration 的数据类型不是bool。

set_epoch_size(epoch_size)[源代码]

设置训练的总epoch数。

参数:
  • epoch_size (int) - 训练的总epoch数。

异常:
  • TypeError - epoch_size 的数据类型不是int。

  • ValueError - epoch_size 小于等于0。

set_has_trained_epoch(has_trained_epoch)[源代码]

设置预训练的epoch数。

参数:
  • has_trained_epoch (int) - 预训练的epoch数。

异常:
  • TypeError - has_trained_epoch 的数据类型不是int。

  • ValueError - has_trained_epoch 小于0。

set_t_end_time(t_end_time=0.6)[源代码]

设置温度停止变化时间。

参数:
  • t_end_time (float) - 温度停止变化时间。默认值:0.6

异常:
  • TypeError - t_end_time 的数据类型不是float。

  • ValueError - t_end_time 小于0.0或大于1.0。

set_t_factor(t_factor=1.2)[源代码]

设置温度变化因子。

参数:
  • t_factor (float) - 温度变化因子。默认值:1.2

异常:
  • TypeError - t_factor 的数据类型不是float。

  • ValueError - t_factor 小于等于0.0。

set_t_start_time(t_start_time=0.2)[源代码]

设置温度开始变化时间。

参数:
  • t_start_time (float) - 温度开始变化时间。默认值:0.2

异常:
  • TypeError - t_start_time 的数据类型不是float。

  • ValueError - t_start_time 小于0.0或大于1.0。

set_t_start_val(t_start_val=1.0)[源代码]

设置温度初始值。

参数:
  • t_start_val (float) - 温度初始值。默认值:1.0

异常:
  • TypeError - t_start_val 的数据类型不是float。

  • ValueError - t_start_val 小于等于0.0。

set_weight_quant_dtype(weight_quant_dtype=QuantDtype.INT1)[源代码]

设置权重量化的数据类型。

参数:
  • weight_quant_dtype (QuantDtype) - 权重量化的数据类型。默认值:QuantDtype.INT1

异常:
  • TypeError - weight_quant_dtype 的数据类型不是QuantDtype。

  • ValueError - weight_quant_dtype 不是 QuantDtype.INT1QuantDtype.INT2QuantDtype.INT4 中的一种。