mindspore_gs.quantization.SlbQuantAwareTraining
- class mindspore_gs.quantization.SlbQuantAwareTraining(config=None)
Derived class of GoldenStick that implements the SLB (Searching for Low-Bit Weights) QAT algorithm.
- Parameters
config (dict) –
Stores the attributes for quantization aware training: keys are attribute names and values are attribute values. Supported attributes are listed below:
quant_dtype (QuantDtype): Data type used to quantize weights. Weight quantization currently supports int4, int2 and int1. Default: QuantDtype.INT1.
epoch_size (int): Total number of training epochs.
has_trained_epoch (int): Number of epochs already trained.
t_start_val (float): Initial value of the temperature hyperparameter. Default: 1.0.
t_start_time (float): Fraction of epochs after which the temperature hyperparameter starts changing. Default: 0.2.
t_end_time (float): Fraction of epochs after which the temperature hyperparameter stops changing. Default: 0.6.
t_factor (float): Multiplicative factor applied to the temperature hyperparameter while it is changing. Default: 1.2.
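To make the interaction of the four temperature attributes concrete, here is a minimal pure-Python sketch of one plausible schedule: the temperature stays at t_start_val until the t_start_time fraction of epoch_size has passed, is then multiplied by t_factor once per epoch, and is frozen once the t_end_time fraction is reached. The function name `slb_temperature` and the truncating conversion from fractions to epoch indices are assumptions made for illustration; this is not the library's implementation.

```python
def slb_temperature(epoch, epoch_size, t_start_val=1.0, t_start_time=0.2,
                    t_end_time=0.6, t_factor=1.2):
    """Hypothetical helper: temperature used at a given 1-based epoch."""
    # Assumption: fractions of epoch_size are turned into epoch indices
    # by truncation.
    t_start_epoch = int(epoch_size * t_start_time)
    t_end_epoch = int(epoch_size * t_end_time)

    temperature = t_start_val
    if epoch > t_start_epoch:
        # One multiplication by t_factor per elapsed epoch, capped at
        # t_end_epoch so the value stops changing afterwards.
        temperature *= t_factor ** (min(epoch, t_end_epoch) - t_start_epoch)
    return temperature


if __name__ == "__main__":
    # Print the temperature before, during and after the changing phase.
    for epoch in (5, 24, 25, 70, 90):
        print(epoch, slb_temperature(epoch, epoch_size=100))
```

Under these assumptions, a larger t_factor or a wider gap between t_start_time and t_end_time yields a higher final temperature.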
- Raises
TypeError – If quant_dtype is not QuantDtype.
TypeError – If epoch_size or has_trained_epoch is not an int.
TypeError – If t_start_val, t_start_time, t_end_time or t_factor is not float.
ValueError – If epoch_size is not greater than 0.
ValueError – If has_trained_epoch is less than 0.
ValueError – If t_start_val or t_factor is not greater than 0.
ValueError – If t_start_time or t_end_time is less than 0.
ValueError – If t_start_time or t_end_time is greater than 1.
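The documented checks can be summarized in a short stand-alone sketch. The helper `validate_slb_config` below is hypothetical and only mirrors the rules listed under Raises (the quant_dtype check is omitted because it needs the library's QuantDtype type); it is not the library's own validation code.

```python
def validate_slb_config(epoch_size, has_trained_epoch, t_start_val=1.0,
                        t_start_time=0.2, t_end_time=0.6, t_factor=1.2):
    """Hypothetical helper mirroring the documented type and range checks."""
    ints = {"epoch_size": epoch_size, "has_trained_epoch": has_trained_epoch}
    floats = {"t_start_val": t_start_val, "t_start_time": t_start_time,
              "t_end_time": t_end_time, "t_factor": t_factor}

    # Type checks come first, as in the documented TypeError cases.
    for name, value in ints.items():
        if not isinstance(value, int):
            raise TypeError(f"{name} must be int, got {type(value).__name__}")
    for name, value in floats.items():
        if not isinstance(value, float):
            raise TypeError(f"{name} must be float, got {type(value).__name__}")

    # Range checks, as in the documented ValueError cases.
    if epoch_size <= 0:
        raise ValueError("epoch_size must be greater than 0")
    if has_trained_epoch < 0:
        raise ValueError("has_trained_epoch must not be less than 0")
    for name in ("t_start_val", "t_factor"):
        if floats[name] <= 0:
            raise ValueError(f"{name} must be greater than 0")
    for name in ("t_start_time", "t_end_time"):
        if not 0.0 <= floats[name] <= 1.0:
            raise ValueError(f"{name} must be in the range [0, 1]")
```

Note that in this sketch an integer such as `1` passed for a float attribute is rejected with TypeError, which matches the wording "is not float" above.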
- Supported Platforms:
GPU
Examples
>>> from mindspore_gs.quantization.slb import SlbQuantAwareTraining
>>> from mindspore_gs.quantization.constant import QuantDtype
>>> from mindspore import nn
>>> class NetToQuant(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(NetToQuant, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> ## 1) Define network to be quantized
>>> net = NetToQuant()
>>> ## 2) Define SLB QAT-Algorithm
>>> slb_quantization = SlbQuantAwareTraining()
>>> ## 3) Use set functions to change config
>>> ## 3.1) set_weight_quant_dtype is used to set the weight quantization bit, and support QuantDtype.INT4, QuantDtype.INT2,
>>> ## QuantDtype.INT1 now.
>>> slb_quantization.set_weight_quant_dtype(QuantDtype.INT1)
>>> ## 3.2) set_epoch_size is used to set the epoch size of training.
>>> slb_quantization.set_epoch_size(100)
>>> ## 3.3) set_has_trained_epoch is used to set the trained epoch size of training.
>>> slb_quantization.set_has_trained_epoch(0)
>>> ## 3.4) set_t_start_val is used to set the initial value of temperature hyperparameters.
>>> slb_quantization.set_t_start_val(1.0)
>>> ## 3.5) set_t_start_time is used to set the fraction of epochs after which temperature hyperparameters starting changing.
>>> slb_quantization.set_t_start_time(0.2)
>>> ## 3.6) set_t_end_time is used to set the fraction of epochs after which temperature hyperparameters stopping changing.
>>> slb_quantization.set_t_end_time(0.6)
>>> ## 3.7) set_t_factor is used to set the multiplicative factor of temperature hyperparameters changing.
>>> slb_quantization.set_t_factor(1.2)
>>> ## 4) Print SLB QAT-Algorithm object and check the config setting result
>>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the value of the attribute weight_quant_dtype is INT1
>>> ## Since we set epoch_size to be 100, the value of the attribute epoch_size is 100
>>> ## Since we set has_trained_epoch to be 0, the value of the attribute has_trained_epoch is 0
>>> ## Since we set t_start_val to be 1.0, the value of the attribute t_start_val is 1.0
>>> ## Since we set t_start_time to be 0.2, the value of the attribute t_start_time is 0.2
>>> ## Since we set t_end_time to be 0.6, the value of the attribute t_end_time is 0.6
>>> ## Since we set t_factor to be 1.2, the value of the attribute t_factor is 1.2
>>> print(slb_quantization)
SlbQuantAwareTraining<weight_quant_dtype=INT1, epoch_size=100, has_trained_epoch=0, t_start_val=1.0, t_start_time=0.2, t_end_time=0.6, t_factor=1.2>
>>> ## 5) Apply SLB QAT-algorithm to origin network
>>> net_qat = slb_quantization.apply(net)
>>> ## 6) Print network and check the result. Conv2d should be transformed to QuantizeWrapperCells.
>>> ## Since we set weight_quant_dtype to be QuantDtype.INT1, the bit_num value of fake_quant_weight
>>> ## should be 1, and the weight_bit_num value of Conv2dSlbQuant should be 1.
>>> print(net_qat)
NetToQuantOpt<
  (_handler): NetToQuant<
    (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
    >
  (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
  (Conv2dSlbQuant): QuantizeWrapperCell<
    (_handler): Conv2dSlbQuant<
      in_channels=1, out_channels=6, kernel_size=(5, 5), weight_bit_num=1, stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False
      (fake_quant_weight): SlbFakeQuantizerPerLayer<bit_num=1>
      >
    >
  >
- apply(network: Cell)
Apply the SLB quantization algorithm to the network, using the following steps to make the network available for quantization aware training:
Fuse certain cells in the network using the pattern engine defined by the net policy.
Propagate the layer policies defined through cells.
Remove redundant fake quantizers.
Apply the layer policies to convert normal cells to QuantizeWrapperCell.
- Parameters
network (Cell) – Network to be quantized.
- Returns
Quantized network.
- callbacks(model: Model)
Define the TemperatureScheduler callback for the SLB QAT algorithm.
- Parameters
model (Model) – Model to be used.
- Raises
RuntimeError – If epoch_size is not initialized.
RuntimeError – If has_trained_epoch is not initialized.
ValueError – If epoch_size is not greater than has_trained_epoch.
ValueError – If t_end_time is less than t_start_time.
TypeError – If model is not Model.
- Returns
List of Callback instances.
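The two ValueError conditions above relate epoch_size to has_trained_epoch and t_end_time to t_start_time. A minimal sketch of those preconditions, useful when planning a resumed run, might look as follows. The helper name `remaining_epochs` is hypothetical, and treating `epoch_size - has_trained_epoch` as the number of epochs still to train is an assumption about how a resumed training loop would be driven.

```python
def remaining_epochs(epoch_size, has_trained_epoch,
                     t_start_time=0.2, t_end_time=0.6):
    """Hypothetical helper: epochs left to train after the documented checks."""
    # Mirrors "If epoch_size is not greater than has_trained_epoch."
    if epoch_size <= has_trained_epoch:
        raise ValueError("epoch_size must be greater than has_trained_epoch")
    # Mirrors "If t_end_time is less than t_start_time."
    if t_end_time < t_start_time:
        raise ValueError("t_end_time must not be less than t_start_time")
    # Assumption: a resumed run trains only the epochs not yet completed.
    return epoch_size - has_trained_epoch
```

For example, with epoch_size=100 and has_trained_epoch=40 the sketch returns 60, while has_trained_epoch=100 is rejected.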
- set_epoch_size(epoch_size)
Set the value of epoch_size in _config.
- Parameters
epoch_size (int) – The epoch size of training.
- Raises
TypeError – If epoch_size is not int.
ValueError – If epoch_size is not greater than 0.
- set_has_trained_epoch(has_trained_epoch)
Set the value of has_trained_epoch in _config.
- Parameters
has_trained_epoch (int) – The number of epochs already trained.
- Raises
TypeError – If has_trained_epoch is not int.
ValueError – If has_trained_epoch is less than 0.
- set_t_end_time(t_end_time)
Set the value of t_end_time in _config.
- Parameters
t_end_time (float) – Fraction of epochs after which the temperature hyperparameter stops changing. Default: 0.6.
- Raises
TypeError – If t_end_time is not float.
ValueError – If t_end_time is less than 0 or greater than 1.
- set_t_factor(t_factor)
Set the value of t_factor in _config.
- Parameters
t_factor (float) – Multiplicative factor applied to the temperature hyperparameter while it is changing. Default: 1.2.
- Raises
TypeError – If t_factor is not float.
ValueError – If t_factor is not greater than 0.
- set_t_start_time(t_start_time)
Set the value of t_start_time in _config.
- Parameters
t_start_time (float) – Fraction of epochs after which the temperature hyperparameter starts changing. Default: 0.2.
- Raises
TypeError – If t_start_time is not float.
ValueError – If t_start_time is less than 0 or greater than 1.
- set_t_start_val(t_start_val)
Set the value of t_start_val in _config.
- Parameters
t_start_val (float) – Initial value of the temperature hyperparameter. Default: 1.0.
- Raises
TypeError – If t_start_val is not float.
ValueError – If t_start_val is not greater than 0.