# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""quantization aware."""
import copy
import re
import numpy as np
import mindspore.context as context
from ... import log as logger
from ... import nn, ops
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...common import Tensor
from ...common import dtype as mstype
from ...common.api import _executor
from ...nn.layer import quant
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils
# Lookup table from a float activation cell class to the quantization aware cell
# class that replaces it; activation classes missing from this table are not
# supported by the automatic conversion.
_ACTIVATION_MAP = {
    nn.ReLU: quant.ActQuant,
    nn.ReLU6: quant.ActQuant,
    nn.Sigmoid: quant.ActQuant,
    nn.LeakyReLU: quant.LeakyReLUQuant,
    nn.HSigmoid: quant.HSigmoidQuant,
    nn.HSwish: quant.HSwishQuant,
}
class _AddFakeQuantInput(nn.Cell):
    """
    Wrap a network so that its input goes through a FakeQuant op first.

    Only the single-input case is supported.

    Args:
        network (Cell): Network that receives the fake-quantized input.
        quant_delay (int): Forwarded to `FakeQuantWithMinMax` as `quant_delay`. Default: 0.
    """

    def __init__(self, network, quant_delay=0):
        super().__init__(auto_prefix=False)
        self.fake_quant_input = quant.FakeQuantWithMinMax(
            min_init=-6,
            max_init=6,
            quant_delay=quant_delay,
            ema=True,
        )
        # Give the input quantizer's parameters a fixed, recognizable prefix.
        self.fake_quant_input.update_parameters_name('fake_quant_input.')
        self.network = network

    def construct(self, data):
        quantized = self.fake_quant_input(data)
        return self.network(quantized)
class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Run a sub cell (or op) and pass its output through a FakeQuant op.

    Args:
        subcell: Callable invoked with the inputs of `construct`.
        kwargs: Must provide `num_bits`, `quant_delay`, `per_channel`, `symmetric`
            and `narrow_range`; each is forwarded to `FakeQuantWithMinMax`.
    """

    def __init__(self, subcell, **kwargs):
        super().__init__(auto_prefix=False)
        self.subcell = subcell
        # Pick exactly the settings the FakeQuant op needs; a missing key raises KeyError.
        forwarded_keys = ("num_bits", "quant_delay", "per_channel", "symmetric", "narrow_range")
        fake_quant_settings = {key: kwargs[key] for key in forwarded_keys}
        self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6,
                                                        max_init=6,
                                                        ema=True,
                                                        **fake_quant_settings)

    def construct(self, *data):
        return self.fake_quant_act(self.subcell(*data))
class ConvertToQuantNetwork:
    """
    Convert a network to a quantization aware network.

    `Conv2dBnAct` and `DenseBnAct` sub cells are rewritten in place so that they use
    quantization aware cells, and a FakeQuant op is appended after every primitive
    attribute of a cell whose op name is listed in `__quant_op_name__`.

    Args:
        kwargs: Keyword configuration. The following keys are read:

            - network (Cell): Network to convert.
            - quant_delay (list or tuple): Non-negative ints; first element is used for
              weights, last element for activations.
            - bn_fold (bool): Whether to use the BatchNorm-folding conv cell.
            - freeze_bn (int): Non-negative int forwarded to `Conv2dBnFoldQuant`.
            - num_bits (list or tuple): Non-negative ints; first element is used for
              weights, last element for activations.
            - per_channel (list or tuple): Bools; first for weights, last for activations.
            - symmetric (list or tuple): Bools; first for weights, last for activations.
            - narrow_range (list or tuple): Bools; first for weights, last for activations.
    """
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self, **kwargs):
        self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
        self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
        self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
        self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"])
        self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
        self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
        self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
        self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0])
        self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1])
        self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0])
        self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1])
        self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0])
        self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1])
        # Dispatch table: combined float cell type -> method that rewrites it in place.
        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
                                    quant.DenseBnAct: self._convert_dense}

    def _convert_op_name(self, name):
        """Convert a CamelCase op name to snake_case, e.g. `TensorAdd` -> `tensor_add`."""
        name_new = re.sub(r'([A-Z])', r'_\1', name).lower()
        # A leading capital produces one leading underscore; drop it.
        # `startswith` (instead of indexing) keeps an empty name from raising IndexError.
        if name_new.startswith('_'):
            name_new = name_new[1:]
        return name_new

    def _weight_quant_kwargs(self):
        """Return the weight quantization settings as keyword arguments."""
        return {"quant_delay": self.weight_qdelay,
                "per_channel": self.weight_channel,
                "num_bits": self.weight_bits,
                "symmetric": self.weight_symmetric,
                "narrow_range": self.weight_range}

    def _act_quant_kwargs(self):
        """Return the activation (data flow) quantization settings as keyword arguments."""
        return {"num_bits": self.act_bits,
                "quant_delay": self.act_qdelay,
                "per_channel": self.act_channel,
                "symmetric": self.act_symmetric,
                "narrow_range": self.act_range}

    def run(self):
        """Convert `self.network` in place and return it."""
        self.network.update_cell_prefix()
        network = self._convert_subcells2quant(self.network)
        self.network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        Recursively convert sub cells like `Conv2dBnAct` and `DenseBnAct` to quant cells.

        Also wraps every public primitive attribute of `network` whose op name is in
        `__quant_op_name__` with `_AddFakeQuantAfterSubCell`.

        Returns:
            Cell, the same `network` object, modified in place.
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            if isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                self._convert_subcells2quant(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            # Refresh the cached cell list after children were replaced.
            network.cell_list = list(network.cells())

        # Add a FakeQuant op after every op in the white list. Collect the targets
        # first because `network.__dict__` is modified while they are replaced.
        quant_ops = [(name, attr) for name, attr in network.__dict__.items()
                     if not name.startswith('_')
                     and isinstance(attr, ops.Primitive)
                     and attr.name in self.__quant_op_name__]
        for name, prim_op in quant_ops:
            add_quant = _AddFakeQuantAfterSubCell(prim_op, **self._act_quant_kwargs())
            prefix = self._convert_op_name(prim_op.name)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, prefix])
            add_quant.update_parameters_name(prefix + '.')
            # Swap the plain attribute for a registered child cell of the same name.
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _quantize_block_output(self, subcell):
        """
        Quantize the output stage of a `Conv2dBnAct`/`DenseBnAct` cell in place.

        An existing activation is replaced by its quant counterpart; otherwise, when
        `subcell.after_fake` is set, an identity op followed by FakeQuant is installed.
        """
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, **self._act_quant_kwargs())

    def _convert_conv(self, subcell):
        """
        Convert the conv (and batchnorm, if any) of a `Conv2dBnAct` cell to a quant cell.

        Returns:
            Cell, the same `subcell` object, modified in place.
        """
        float_conv = subcell.conv
        channels = (float_conv.in_channels, float_conv.out_channels)
        conv_kwargs = {"kernel_size": float_conv.kernel_size,
                       "stride": float_conv.stride,
                       "pad_mode": float_conv.pad_mode,
                       "padding": float_conv.padding,
                       "dilation": float_conv.dilation,
                       "group": float_conv.group}
        conv_kwargs.update(self._weight_quant_kwargs())

        if subcell.has_bn:
            float_bn = subcell.batchnorm
            if self.bn_fold:
                conv_inner = quant.Conv2dBnFoldQuant(*channels,
                                                     eps=float_bn.eps,
                                                     momentum=float_bn.momentum,
                                                     freeze_bn=self.freeze_bn,
                                                     fake=True,
                                                     **conv_kwargs)
                # The folding cell holds the BatchNorm parameters itself.
                bn_target = conv_inner
            else:
                conv_inner = quant.Conv2dBnWithoutFoldQuant(*channels,
                                                            eps=float_bn.eps,
                                                            momentum=float_bn.momentum,
                                                            **conv_kwargs)
                bn_target = conv_inner.batchnorm
            # Hand the original network's BatchNorm parameters over to the quant cell.
            bn_target.gamma = float_bn.gamma
            bn_target.beta = float_bn.beta
            bn_target.moving_mean = float_bn.moving_mean
            bn_target.moving_variance = float_bn.moving_variance
            # BatchNorm now lives inside the quant conv cell; remove it from the block.
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(*channels,
                                           has_bias=float_conv.has_bias,
                                           **conv_kwargs)

        # Hand the original network's Conv2D parameters over to the quant cell.
        conv_inner.weight = float_conv.weight
        if float_conv.has_bias:
            conv_inner.bias = float_conv.bias
        subcell.conv = conv_inner
        self._quantize_block_output(subcell)
        return subcell

    def _convert_dense(self, subcell):
        """
        Convert the dense layer of a `DenseBnAct` cell to a quant cell.

        Returns:
            Cell, the same `subcell` object, modified in place.
        """
        float_dense = subcell.dense
        dense_inner = quant.DenseQuant(float_dense.in_channels,
                                       float_dense.out_channels,
                                       has_bias=float_dense.has_bias,
                                       **self._weight_quant_kwargs())
        # Hand the original network's Dense parameters over to the quant cell.
        dense_inner.weight = float_dense.weight
        if float_dense.has_bias:
            dense_inner.bias = float_dense.bias
        subcell.dense = dense_inner
        self._quantize_block_output(subcell)
        return subcell

    def _convert_activation(self, activation):
        """
        Build the quant counterpart of `activation`.

        Raises:
            ValueError: If the class of `activation` is not in `_ACTIVATION_MAP`.
        """
        act_class = activation.__class__
        if act_class not in _ACTIVATION_MAP:
            raise ValueError(f"Unsupported activation in auto quant: {act_class}")
        return _ACTIVATION_MAP[act_class](activation=activation, **self._act_quant_kwargs())
class ExportToQuantInferNetwork:
    """
    Convert quantization aware network to infer network.

    Args:
        network (Cell): MindSpore network produced by API `convert_quant_network`.
        mean (int, float): Input data mean.
        std_dev (int, float): Input data standard deviation; must be non-zero because
            the input scale is computed as `1 / std_dev`.
        inputs (Tensor): Input tensors of the `quantization aware training network`.

    Returns:
        Cell, Infer network.
    """
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self, network, mean, std_dev, *inputs):
        network = validator.check_isinstance('network', network, (nn.Cell,))
        self.input_scale, self.input_zero_point = self._input_quant_params(mean, std_dev)
        self.data_type = mstype.int8
        # Work on a copy so the caller's training network is left untouched.
        self.network = copy.deepcopy(network)
        self.all_parameters = {p.name: p for p in self.network.get_parameters()}
        self.get_inputs_table(inputs)

    @staticmethod
    def _input_quant_params(mean, std_dev):
        """
        Return `(scale, zero_point)` describing how the network input is quantized.

        Uses the convention
            quantize for inputs:    q = f / scale + zero_point
            dequantize for outputs: f = (q - zero_point) * scale
        so normalizing with `f = (q - mean) / std_dev` gives `scale = 1 / std_dev`
        and `zero_point = round(mean)`.
        """
        return 1 / std_dev, round(mean)

    def get_inputs_table(self, inputs):
        """Get the support info for quant export."""
        phase_name = 'export_quant'
        graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
        self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)

    def run(self):
        """Start to convert."""
        self.network.update_cell_prefix()
        network = self.network
        # Strip the input FakeQuant wrapper, if present, before converting.
        if isinstance(network, _AddFakeQuantInput):
            network = network.network
        network = self._convert_quant2deploy(network)
        return network

    def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
        """
        Convert one quant sub cell of the network to a deploy sub cell.

        Args:
            cell_core (Cell): Quant conv/dense cell holding `fake_quant_weight` and `weight`.
            activation: Activation attached to the block; unwrapped to its original op.
            fake_quant_a_out (Cell): FakeQuant cell applied to the block output.

        Returns:
            Cell, a `QuantBlock`, or None when no input FakeQuant information is found
            for this cell in the quant info table.
        """
        # Calculate the scale and zero point
        w_minq_name = cell_core.fake_quant_weight.minq.name
        np_type = mstype.dtype_to_nptype(self.data_type)
        scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
        scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)

        info = self.quant_info_table.get(w_minq_name, None)
        if not info:
            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
            return None
        fack_quant_a_in_op, minq_name = info
        if minq_name == 'input':
            # Fed directly by the network input: use the scale/zero point from mean/std_dev.
            scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
        else:
            # The matching `maxq` parameter name is `minq_name` with its 4-char suffix replaced.
            maxq = self.all_parameters[minq_name[:-4] + "maxq"]
            minq = self.all_parameters[minq_name]
            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)

        # Build the `Quant` `Dequant` op.
        # Quant only support perlayer version. Need check here.
        quant_op = inner.Quant(float(scale_a_in), float(zp_a_in))
        sqrt_mode = False
        scale_deq = scale_a_out * scale_w
        # When every dequant scale is below 2 ** -14, store its square root instead and
        # flag that through `sqrt_mode`.
        # NOTE(review): presumably guards against float16 underflow - confirm.
        if (scale_deq < 2 ** -14).all():
            scale_deq = np.sqrt(scale_deq)
            sqrt_mode = True
        dequant_op = inner.Dequant(sqrt_mode)

        # Recover the original (non-quant) activation for the deploy block.
        if isinstance(activation, _AddFakeQuantAfterSubCell):
            activation = activation.subcell
        elif hasattr(activation, "get_origin"):
            activation = activation.get_origin()

        # get the `weight` and `bias`
        weight = cell_core.weight.data.asnumpy()
        bias = None
        if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
            if cell_core.has_bias:
                bias = cell_core.bias.data.asnumpy()
        elif isinstance(cell_core, quant.Conv2dBnFoldQuant):
            weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
        elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
            weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)

        # apply the quant
        weight = quant_utils.weight2int(weight, scale_w, zp_w)
        if bias is not None:
            bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
        scale_deq = Tensor(scale_deq, mstype.float16)

        # get op
        if isinstance(cell_core, quant.DenseQuant):
            op_core = P.MatMul()
            weight = np.transpose(weight)
        else:
            op_core = cell_core.conv
        weight = Tensor(weight, self.data_type)
        return quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)

    def _convert_quant2deploy(self, network):
        """
        Convert network's all quant subcell to deploy subcell.

        Returns:
            Cell, the same `network` object, modified in place.
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            cell_core = None
            fake_quant_act = None
            activation = None
            # NOTE(review): `activation.fake_quant_act` fails with AttributeError when
            # the block has no activation (None) - confirm that cannot occur here.
            if isinstance(subcell, quant.Conv2dBnAct):
                cell_core = subcell.conv
                activation = subcell.activation
                fake_quant_act = activation.fake_quant_act
            elif isinstance(subcell, quant.DenseBnAct):
                cell_core = subcell.dense
                activation = subcell.activation
                fake_quant_act = activation.fake_quant_act

            if cell_core is not None:
                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
                if new_subcell:
                    prefix = subcell.param_prefix
                    new_subcell.update_parameters_name(prefix + '.')
                    network.insert_child_to_cell(name, new_subcell)
                    change = True
            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                op = subcell.subcell
                # Check the type before reading `.name` so a non-primitive sub cell
                # without that attribute cannot raise here.
                if isinstance(op, ops.Primitive) and op.name in self.__quant_op_name__:
                    # Drop the FakeQuant wrapper and restore the bare primitive.
                    network.__delattr__(name)
                    network.__setattr__(name, op)
                    change = True
            else:
                self._convert_quant2deploy(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            # Refresh the cached cell list after children were replaced.
            network.cell_list = list(network.cells())
        return network
def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='AIR'):
    """
    Export a MindSpore quantization aware network as a deployable predict model.

    The network is switched to inference mode, converted with
    `ExportToQuantInferNetwork` and then written out through `serialization.export`.

    Args:
        network (Cell): MindSpore network produced by `convert_quant_network`.
        inputs (Tensor): Inputs of the `quantization aware training network`.
        file_name (str): File name of model to export.
        mean (int, float): Input data mean. Default: 127.5.
        std_dev (int, float): Input data standard deviation. Default: 127.5.
        file_format (str): Format of the exported quantization aware model; 'AIR' and
            'MINDIR' are accepted. Default: 'AIR'.

            - AIR: Graph Engine Intermediate Representation. An intermediate representation
              format of Ascend model.
            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate
              representation format for MindSpore models.
              Recommended suffix for output file is '.mindir'.

    Raises:
        KeyError: If the configured device target is neither "Ascend" nor "GPU".
        ValueError: If `file_format` is neither 'AIR' nor 'MINDIR'.
    """
    mean = validator.check_type("mean", mean, (int, float))
    std_dev = validator.check_type("std_dev", std_dev, (int, float))

    device_target = context.get_context('device_target')
    if device_target not in ("Ascend", "GPU"):
        raise KeyError("Unsupported {} device target.".format(device_target))
    if file_format not in ('AIR', 'MINDIR'):
        raise ValueError('Illegal file format {}.'.format(file_format))

    # Export is done on the inference graph.
    network.set_train(False)
    deploy_net = ExportToQuantInferNetwork(network, mean, std_dev, *inputs).run()
    serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
def convert_quant_network(network,
                          bn_fold=True,
                          freeze_bn=10000000,
                          quant_delay=(0, 0),
                          num_bits=(8, 8),
                          per_channel=(False, False),
                          symmetric=(False, False),
                          narrow_range=(False, False)
                          ):
    r"""
    Create quantization aware training network.

    For every list/tuple argument below, the first element is applied to weights and the
    last element to activations (data flow). A scalar is used for both.

    Args:
        network (Cell): Network to convert to a quantization aware training network.
        bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7.
        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized.
            Default: (0, 0)
        num_bits (int, list or tuple): Number of bits to use for quantize weights and activations.
            Default: (8, 8)
        per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. Default: (False, False)
        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then
            base on symmetric otherwise base on asymmetric. Default: (False, False)
        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
            Default: (False, False)

    Returns:
        Cell, Network which has change to quantization aware training network cell.

    Raises:
        ValueError: If a list/tuple argument does not have 1 or 2 elements.
        KeyError: If the configured device target is neither "Ascend" nor "GPU".
    """
    support_device = ["Ascend", "GPU"]

    def convert2list(name, value):
        """Wrap a scalar in a list; check that a list/tuple has 1 or 2 elements."""
        if not isinstance(value, (list, tuple)):
            return [value]
        # An empty sequence would otherwise fail later with an opaque IndexError.
        if not 1 <= len(value) <= 2:
            raise ValueError("input `{}` len should be 1 or 2, but got {}".format(name, len(value)))
        return value

    quant_delay = convert2list("quant delay", quant_delay)
    num_bits = convert2list("num bits", num_bits)
    per_channel = convert2list("per channel", per_channel)
    symmetric = convert2list("symmetric", symmetric)
    narrow_range = convert2list("narrow range", narrow_range)

    if context.get_context('device_target') not in support_device:
        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

    net = ConvertToQuantNetwork(network=network,
                                quant_delay=quant_delay,
                                bn_fold=bn_fold,
                                freeze_bn=freeze_bn,
                                num_bits=num_bits,
                                per_channel=per_channel,
                                symmetric=symmetric,
                                narrow_range=narrow_range)
    return net.run()