mindspore_gs.quantization.SlbQuantAwareTraining
- class mindspore_gs.quantization.SlbQuantAwareTraining(config=None)[源代码]
SLB(Searching for Low-Bit Weights in Quantized Neural Networks)算法的实现,该算法将量化神经网络中的离散权重作为可搜索的变量,并实现了一种微分方法去精确的实现该搜索。具体来说,是将每个权重表示为在离散值集上的概率分布,通过训练来优化该概率分布,最终具有最高概率的离散值就是搜索的结果,也就是量化的结果。更多详细信息见 Searching for Low-Bit Weights in Quantized Neural Networks。
- 参数:
config (dict) - 以字典的形式存放用于量化训练的属性,默认值为None。下面列出了受支持的属性:
quant_dtype (Union[QuantDtype, list(QuantDtype), tuple(QuantDtype)]) - 用于量化权重和激活的数据类型。类型为 QuantDtype 或包含两个 QuantDtype 的list或者tuple。如果 quant_dtype 是一个 QuantDtype ,则会被复制成包含两个 QuantDtype 的list。第一个元素表示激活的量化数据类型,第二个元素表示权重的量化数据类型。在实际量化推理场景中需要考虑硬件器件的精度支持。当前权重量化支持1、2、4比特,激活量化支持8比特。默认值:(QuantDtype.INT8, QuantDtype.INT1)。
enable_act_quant (bool) - 在训练中是否开启激活量化。默认值:False。
enable_bn_calibration (bool) - 在训练中是否开启BN层矫正功能。默认值:False。
epoch_size (int) - 训练的总epoch数。
has_trained_epoch (int) - 预训练的epoch数。
t_start_val (float) - 温度初始值。默认值:1.0。
t_start_time (float) - 温度开始变化时间。默认值:0.2。
t_end_time (float) - 温度停止变化时间。默认值:0.6。
t_factor (float) - 温度变化因子。默认值:1.2。
说明
本方法会调用其它接口来设置参数,所以报错时需要参考其他的接口,比如 quant_dtype 要参考 set_weight_quant_dtype 和 set_act_quant_dtype。
- 异常:
TypeError - quant_dtype 的数据类型不是 QuantDtype ,或者 quant_dtype 存在不是 QuantDtype 的元素。
TypeError - enable_act_quant 或者 enable_bn_calibration 的数据类型不是bool。
ValueError - quant_dtype 的长度大于2。
TypeError - epoch_size 或 has_trained_epoch 的数据类型不是int。
TypeError - t_start_val 、 t_start_time、 t_end_time 或 t_factor 的数据类型不是float。
ValueError - epoch_size 小于等于0。
ValueError - has_trained_epoch 小于0。
ValueError - t_start_val 或 t_factor 小于等于0.0。
ValueError - t_start_time 或 t_end_time 小于0.0。
ValueError - t_start_time 或 t_end_time 大于1.0。
- 支持平台:
GPU
样例:
>>> import mindspore >>> import numpy as np >>> from mindspore import nn >>> from mindspore_gs.quantization import SlbQuantAwareTraining >>> from mindspore.common.dtype import QuantDtype >>> class NetToQuant(nn.Cell): ... def __init__(self, num_channel=1): ... super(NetToQuant, self).__init__() ... self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') ... self.bn = nn.BatchNorm2d(6) ... ... def construct(self, x): ... x = self.conv(x) ... x = self.bn(x) ... return x ... >>> ## 1) Define network to be quantized >>> net = NetToQuant() >>> ## 2) Define SLB QAT-Algorithm >>> slb_quantization = SlbQuantAwareTraining() >>> ## 3) Use set functions to change config >>> ## 3.1) set_weight_quant_dtype is used to set the weight quantization bit, and support QuantDtype.INT4, QuantDtype.INT2, >>> ## QuantDtype.INT1 now. >>> slb_quantization.set_weight_quant_dtype(QuantDtype.INT1) >>> ## 3.2) set_act_quant_dtype is used to set the activation quantization bit, and support QuantDtype.INT8 now. >>> slb_quantization.set_act_quant_dtype(QuantDtype.INT8) >>> ## 3.3) set_enable_act_quant is used to set whether apply activation quantization. >>> slb_quantization.set_enable_act_quant(True) >>> ## 3.4) set_enable_bn_calibration is used to set whether apply batchnorm calibration. >>> slb_quantization.set_enable_bn_calibration(True) >>> ## 3.5) set_epoch_size is used to set the epoch size of training. >>> slb_quantization.set_epoch_size(100) >>> ## 3.6) set_has_trained_epoch is used to set the trained epoch size of training. >>> slb_quantization.set_has_trained_epoch(0) >>> ## 3.7) set_t_start_val is used to set the initial value of temperature hyperparameters. >>> slb_quantization.set_t_start_val(1.0) >>> ## 3.8) set_t_start_time is used to set the fraction of epochs after which temperature hyperparameters starting changing. >>> slb_quantization.set_t_start_time(0.2) >>> ## 3.9) set_t_end_time is used to set the fraction of epochs after which temperature hyperparameters stopping changing. >>> slb_quantization.set_t_end_time(0.6) >>> ## 3.10) set_t_factor is used to set the multiplicative factor of temperature hyperparameters changing. >>> slb_quantization.set_t_factor(1.2) >>> ## 4) Print SLB QAT-Algorithm object and check the config setting result >>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the value of the attribute weight_quant_dtype is INT1 >>> ## Since we set act_quant_dtype to be QuantDtype.INT8, the value of the attribute weight_quant_dtype is INT8 >>> ## Since we set enable_act_quant to be True, the value of the attribute enable_act_quant is True >>> ## Since we set enable_bn_calibration to be True, the value of the attribute enable_bn_calibration is True >>> ## Since we set epoch_size to be 100, the value of the attribute epoch_size is 100 >>> ## Since we set has_trained_epoch to be 0, the value of the attribute has_trained_epoch is 0 >>> ## Since we set t_start_val to be 1.0, the value of the attribute t_start_val is 1.0 >>> ## Since we set t_start_time to be 0.2, the value of the attribute t_start_time is 0.2 >>> ## Since we set t_end_time to be 0.6, the value of the attribute t_end_time is 0.6 >>> ## Since we set t_factor to be 1.2, the value of the attribute t_factor is 1.2 >>> print(slb_quantization) SlbQuantAwareTraining<weight_quant_dtype=INT1, act_quant_dtype=INT8, enable_act_quant=True, enable_bn_calibration=True, epoch_size=100, has_trained_epoch=0, t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2> >>> ## 5) Apply SLB QAT-algorithm to origin network >>> net_qat = slb_quantization.apply(net) >>> ## 6) Print network and check the result. Conv2d should be transformed to QuantizeWrapperCells. >>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the bit_num value of fake_quant_weight >>> ## should be 1, and the weight_bit_num value of Conv2dSlbQuant should be 1. >>> print(net_qat) NetToQuantOpt< (_handler): NetToQuant< (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW> (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.9, gamma=Parameter(name=bn.gamma, requires_grad=True, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.]), beta=Parameter(name=bn.beta, requires_grad=True, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_mean=Parameter(name=bn.moving_mean, requires_grad=False, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_variance=Parameter(name=bn.moving_variance, requires_grad=False, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.])> > (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.9, gamma=Parameter(name=bn.gamma, requires_grad=True, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.]), beta=Parameter(name=bn.beta, requires_grad=True, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_mean=Parameter(name=bn.moving_mean, requires_grad=False, shape=[6], dtype=Float32, value= [0., 0., 0., 0., 0., 0.]), moving_variance=Parameter(name=bn.moving_variance, requires_grad=False, shape=[6], dtype=Float32, value= [1., 1., 1., 1., 1., 1.])> (Conv2dSlbQuant): QuantizeWrapperCell< (_handler): Conv2dSlbQuant< in_channels=1, out_channels=6, kernel_size=(5, 5), weight_bit_num=1, stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False (fake_quant_weight): SlbFakeQuantizerPerLayer<bit_num=1> > (_input_quantizer): SlbActQuantizer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900> (_output_quantizer): SlbActQuantizer<bit_num=8, symmetric=False, narrow_range=False, ema=False(0.999), per_channel=False, quant_delay=900> > > >>> ## 7) convert a compressed network to a standard network before exporting to MindIR. >>> net_qat = slb_quantization.convert(net_qat) >>> data_in = mindspore.Tensor(np.ones([1, 1, 32, 32]), mindspore.float32) >>> file_name = "./conv.mindir" >>> mindspore.export(net_qat, data_in, file_name=file_name, file_format="MINDIR") >>> graph = mindspore.load(file_name) >>> mindspore.nn.GraphCell(graph)
- apply(network: Cell)[源代码]
按照下面4个步骤对给定网络应用量化算法,得到带有伪量化节点的网络。
使用网络策略中定义的模式引擎在给定网络中融合特定的单元。
传播通过单元定义的层策略。
当量化器冗余时,减少冗余的伪量化器。
应用层策略将正常 Cell 转换为 QuantizeWrapperCell 。
- 参数:
network (Cell) - 即将被量化的网络。
- 返回:
在原网络定义的基础上,修改需要量化的网络层后生成带有伪量化节点的网络。
- callbacks(model: Model, dataset: Dataset)[源代码]
定义SLB量化算法特有的一些callbacks,其中包括用于调节温度因子的callback。
- 参数:
model (Model) - 经过算法修改后的网络构造的mindspore的Model对象。
dataset (Dataset) - 加载了特定数据集的Dataset对象。
- 异常:
RuntimeError - epoch_size 没有初始化。
RuntimeError - has_trained_epoch 没有初始化。
ValueError - epoch_size 小于等于 has_trained_epoch 。
ValueError - t_end_time 小于 t_start_time 。
TypeError - model 的数据类型不是mindspore.Model。
TypeError - dataset 的数据类型不是mindspore.dataset.Dataset。
- 返回:
SLB量化算法特有的一些callbacks的列表。
- convert(net_opt: Cell, ckpt_path='')[源代码]
定义将SLB量化网络转换成适配MindIR的标准网络的具体实现。
- 参数:
net_opt (Cell) - 经过SLB量化算法量化后的网络。
ckpt_path (str) - checkpoint文件的存储路径,为空时不加载,默认值为空。
- 异常:
TypeError - net_opt 的数据类型不是mindspore.nn.Cell。
TypeError - ckpt_path 的数据类型不是str。
ValueError - ckpt_path 不为空,但不是有效文件。
RuntimeError - ckpt_path 是有效文件,但加载失败。
- 返回:
能适配MindIR的标准网络。
- set_act_quant_dtype(act_quant_dtype=QuantDtype.INT8)[源代码]
设置激活量化的数据类型。
- 参数:
act_quant_dtype (QuantDtype) - 激活量化的数据类型。默认值:QuantDtype.INT8。
- 异常:
TypeError - act_quant_dtype 的数据类型不是QuantDtype。
ValueError - act_quant_dtype 不是 QuantDtype.INT8 。
- set_enable_act_quant(enable_act_quant=False)[源代码]
设置是否开启激活量化。
- 参数:
enable_act_quant (bool) - 在训练中是否开启激活量化。默认值:False。
- 异常:
TypeError - enable_act_quant 的数据类型不是bool。
- set_enable_bn_calibration(enable_bn_calibration=False)[源代码]
设置是否开启BatchNorm层矫正功能。
- 参数:
enable_bn_calibration (bool) - 在训练中是否开启BatchNorm层矫正功能。默认值:False。
- 异常:
TypeError - enable_bn_calibration 的数据类型不是bool。
- set_epoch_size(epoch_size)[源代码]
设置训练的总epoch数。
- 参数:
epoch_size (int) - 训练的总epoch数。
- 异常:
TypeError - epoch_size 的数据类型不是int。
ValueError - epoch_size 小于等于0。
- set_has_trained_epoch(has_trained_epoch)[源代码]
设置预训练的epoch数。
- 参数:
has_trained_epoch (int) - 预训练的epoch数。
- 异常:
TypeError - has_trained_epoch 的数据类型不是int。
ValueError - has_trained_epoch 小于0。
- set_t_end_time(t_end_time=0.6)[源代码]
设置温度停止变化时间。
- 参数:
t_end_time (float) - 温度停止变化时间。默认值:0.6。
- 异常:
TypeError - t_end_time 的数据类型不是float。
ValueError - t_end_time 小于0.0或大于1.0。
- set_t_factor(t_factor=1.2)[源代码]
设置温度变化因子。
- 参数:
t_factor (float) - 温度变化因子。默认值:1.2。
- 异常:
TypeError - t_factor 的数据类型不是float。
ValueError - t_factor 小于等于0.0。
- set_t_start_time(t_start_time=0.2)[源代码]
设置温度开始变化时间。
- 参数:
t_start_time (float) - 温度开始变化时间。默认值:0.2。
- 异常:
TypeError - t_start_time 的数据类型不是float。
ValueError - t_start_time 小于0.0或大于1.0。