# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScopPruner."""
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore.ops import constexpr
from mindspore import Parameter
from mindspore_gs.common.validator import Validator, Rel
from ...comp_algo import CompAlgo, CompAlgoConfig
@constexpr
def generate_int(shape):
    """Return ``shape // 2`` as a Python int (evaluated as a compile-time constant via ``constexpr``).

    Args:
        shape (int): A dimension size, e.g. the batch dimension of a tensor.

    Returns:
        int: Half of ``shape``, rounded down.
    """
    half = shape // 2
    return int(half)
class KfConv2d(nn.Cell):
    """Conv2d + BN wrapper that blends real and knockoff features with a learnable per-channel scale.

    During training the incoming batch is expected to be ``[real; knockoff]`` stacked on axis 0
    (this is what ``KfCallback`` builds). After the convolution, the first half is replaced by
    ``kfscale * real + (1 - kfscale) * knockoff``, while the second half is passed through unchanged.

    Args:
        conv_ori (Cell): Original convolution cell; its ``out_channels`` sizes ``kfscale``.
        bn_ori (Cell): Cell applied after the blending step (the BN that followed ``conv_ori``).
        prex (str): Prefix used to build the name of the ``kfscale`` parameter.
    """

    def __init__(self, conv_ori, bn_ori, prex):
        super(KfConv2d, self).__init__()
        self.conv = conv_ori
        self.bn = bn_ori
        self.out_channels = self.conv.out_channels
        # Build the 0.5 initial value directly. Filling the array returned by ``asnumpy()`` in place
        # only works if that array shares storage with the parameter's data, which is not
        # guaranteed on every backend/device.
        init_scale = ops.Ones()((1, self.out_channels, 1, 1), mindspore.float32) * 0.5
        self.kfscale = Parameter(init_scale, requires_grad=True, name=prex + '.kfscale')
        self.concat_op = ops.Concat(axis=0)

    def construct(self, x):
        """Apply conv, blend the real half with the knockoff half when training, then apply bn."""
        x = self.conv(x)
        if self.training:
            num_ori = generate_int(x.shape[0])
            real, knockoff = x[:num_ori], x[num_ori:]
            blended = self.kfscale * real + (1 - self.kfscale) * knockoff
            x = self.concat_op((blended, knockoff))
        x = self.bn(x)
        return x
@constexpr
def generate_tensor(shape, mask_list):
    """Build the compile-time mask used by ``MaskedFill`` on Ascend.

    A float16 array of ones with ``shape`` is created, every channel (axis 1) listed in
    ``mask_list`` is zeroed, and the result is wrapped in a Tensor re-typed to ``bool_``.

    Args:
        shape (tuple): Shape of the feature map, assumed 4-D (N, C, H, W) given the indexing below.
        mask_list (list): Channel indices to zero out in the mask.

    Returns:
        Tensor: The mask tensor with dtype ``bool_``.
    """
    buf = ops.Ones()(shape, mstype.float16).asnumpy()
    for channel in mask_list:
        buf[:, channel, :, :] = 0.0
    result = Tensor(buf)
    result.set_dtype(mstype.bool_)
    return result
class MaskedConv2dbn(nn.Cell):
    """Conv2d + BN whose output keeps only the channels recorded in ``out_index``; the rest become 0.

    Args:
        kf_conv2d_ori (KfConv2d): Source cell; its ``conv``, ``bn`` and ``out_index`` are reused.
        prex (str): Prefix used to build the name of the ``out_index`` parameter.
    """

    def __init__(self, kf_conv2d_ori, prex):
        super(MaskedConv2dbn, self).__init__()
        # The device target is read once; construct chooses the masking strategy from it.
        self.target = context.get_context("device_target").upper()
        self.conv = kf_conv2d_ori.conv
        self.bn = kf_conv2d_ori.bn
        self.zeros = ops.Zeros()
        self.one = ops.Ones()  # currently unused in construct
        self.out_index = Parameter(kf_conv2d_ori.out_index, requires_grad=False, name=prex + '.out_index')
        self.cast = ops.Cast()  # currently unused in construct
        # Kept channel indices as a plain Python list, used as a static index in construct.
        self.mask = self.out_index.asnumpy().tolist()

    def construct(self, x):
        """Run conv and bn, then zero every channel that is not listed in ``self.mask``."""
        feature = self.bn(self.conv(x))
        if self.target == 'ASCEND':
            # generate_tensor zeroes the kept channels in the mask, so MaskedFill overwrites
            # only the remaining (pruned) channels with 0.
            fill_mask = generate_tensor(feature.shape, self.mask)
            return ops.MaskedFill()(feature, fill_mask, 0.0)
        # Other targets: multiply by a 0/1 tensor that is 1 only on the kept channels.
        keep = self.zeros(feature.shape, mstype.float32)
        keep[:, self.mask, :, :] = 1.0
        return feature * keep
class PrunedConv2dbn1(nn.Cell):
    """Physically pruned Conv2d + BatchNorm2d that keeps only the selected output channels.

    Input channels are untouched; only the output channels listed in ``masked_module.out_index``
    are kept, so this cell outputs ``len(out_index)`` channels.

    Args:
        masked_module (MaskedConv2dbn): Source cell providing ``conv``, ``bn`` and ``out_index``.
    """

    def __init__(self, masked_module):
        super(PrunedConv2dbn1, self).__init__()
        src_conv = masked_module.conv
        src_bn = masked_module.bn
        out_index = masked_module.out_index
        num_out = len(out_index)
        # Reuse the source conv's pad_mode, as PrunedConv2dbnmiddle does. Forcing 'pad' would drop
        # the implicit padding of a 'same' conv (whose padding is 0) and change the output size.
        self.conv = nn.Conv2d(in_channels=src_conv.in_channels, out_channels=num_out,
                              kernel_size=src_conv.kernel_size, stride=src_conv.stride,
                              has_bias=False, padding=src_conv.padding, pad_mode=src_conv.pad_mode)
        # Keep only the filters selected by out_index (gathered along axis 0).
        self.conv.weight = Parameter(ops.Gather()(src_conv.weight.data.clone(), out_index, 0),
                                     requires_grad=True, name=src_conv.weight.name)
        self.bn = nn.BatchNorm2d(num_out)
        # Slice every per-channel BN tensor with the same indices; running statistics stay frozen.
        for attr, trainable in (('gamma', True), ('beta', True),
                                ('moving_mean', False), ('moving_variance', False)):
            src_param = getattr(src_bn, attr)
            setattr(self.bn, attr, Parameter(ops.Gather()(src_param.data.clone(), out_index, 0),
                                             requires_grad=trainable, name=src_param.name))
        self.oriout_channels = src_conv.out_channels
        self.out_index = out_index

    def construct(self, x):
        """Apply the pruned conv followed by the pruned bn."""
        x = self.conv(x)
        x = self.bn(x)
        return x
class PrunedConv2dbnmiddle(nn.Cell):
    """Physically pruned Conv2d + BatchNorm2d that keeps selected input AND output channels.

    Args:
        masked_module (MaskedConv2dbn): Source cell providing ``conv``, ``bn`` and ``out_index``.
            It must also carry an ``in_index`` attribute. That attribute is not set anywhere in this
            file, so presumably the caller attaches it (TODO confirm).
    """

    def __init__(self, masked_module):
        super(PrunedConv2dbnmiddle, self).__init__()
        src_conv = masked_module.conv
        src_bn = masked_module.bn
        keep_out = masked_module.out_index
        keep_in = masked_module.in_index
        self.conv = nn.Conv2d(in_channels=len(keep_in), out_channels=len(keep_out),
                              kernel_size=src_conv.kernel_size, stride=src_conv.stride,
                              has_bias=False, padding=src_conv.padding, pad_mode=src_conv.pad_mode)
        cloned = src_conv.weight.data.clone()
        # Select output filters along axis 0, then input channels along axis 1.
        by_out = ops.Gather()(cloned, keep_out, 0)
        by_in = ops.Gather()(by_out, keep_in, 1)
        self.conv.weight = Parameter(by_in, requires_grad=True, name=src_conv.weight.name)
        self.bn = nn.BatchNorm2d(len(keep_out))
        # Affine parameters stay trainable; running statistics stay frozen.
        for attr, trainable in (('gamma', True), ('beta', True),
                                ('moving_mean', False), ('moving_variance', False)):
            src_param = getattr(src_bn, attr)
            sliced = ops.Gather()(src_param.data.clone(), keep_out, 0)
            setattr(self.bn, attr, Parameter(sliced, requires_grad=trainable, name=src_param.name))
        self.oriout_channels = src_conv.out_channels
        self.out_index = keep_out

    def construct(self, x):
        """Apply the pruned conv followed by the pruned bn."""
        return self.bn(self.conv(x))
class PrunedConv2dbn2(nn.Cell):
    """Pruned Conv2d + BatchNorm2d whose output is scattered back to the original channel count.

    Selected input and output channels are kept for the computation. The result is then written into
    a zero tensor with the original number of output channels, at the positions given by ``out_index``.

    Args:
        masked_module (MaskedConv2dbn): Source cell providing ``conv``, ``bn`` and ``out_index``.
            It must also carry an ``in_index`` attribute. That attribute is not set anywhere in this
            file, so presumably the caller attaches it (TODO confirm).
    """

    def __init__(self, masked_module):
        super(PrunedConv2dbn2, self).__init__()
        src_conv = masked_module.conv
        src_bn = masked_module.bn
        out_index = masked_module.out_index
        in_index = masked_module.in_index
        num_out = len(out_index)
        # Reuse the source conv's pad_mode, as PrunedConv2dbnmiddle does. Forcing 'pad' would drop
        # the implicit padding of a 'same' conv (whose padding is 0) and change the output size.
        self.conv = nn.Conv2d(in_channels=len(in_index), out_channels=num_out,
                              kernel_size=src_conv.kernel_size, stride=src_conv.stride,
                              has_bias=False, padding=src_conv.padding, pad_mode=src_conv.pad_mode)
        weight_data = src_conv.weight.data.clone()
        # Select output filters along axis 0, then input channels along axis 1.
        weight_data = ops.Gather()(ops.Gather()(weight_data, out_index, 0), in_index, 1)
        self.conv.weight = Parameter(weight_data, requires_grad=True, name=src_conv.weight.name)
        self.bn = nn.BatchNorm2d(num_out)
        # Slice every per-channel BN tensor with the same indices; running statistics stay frozen.
        for attr, trainable in (('gamma', True), ('beta', True),
                                ('moving_mean', False), ('moving_variance', False)):
            src_param = getattr(src_bn, attr)
            setattr(self.bn, attr, Parameter(ops.Gather()(src_param.data.clone(), out_index, 0),
                                             requires_grad=trainable, name=src_param.name))
        self.oriout_channels = src_conv.out_channels
        self.out_index = out_index
        self.zeros = ops.Zeros()

    def construct(self, x):
        """Apply the pruned conv and bn, then scatter the result back to ``oriout_channels`` channels."""
        x = self.conv(x)
        x = self.bn(x)
        output = self.zeros((x.shape[0], self.oriout_channels, x.shape[2], x.shape[3]), mstype.float32)
        output[:, self.out_index, :, :] = x
        return output
class KfCallback(Callback):
    """
    Knockoff-data callback for the SCOP algorithm.

    At the beginning of every training step, a randomly permuted copy of the current batch (data and
    labels permuted with the same permutation) is appended to the batch along axis 0. The batch
    therefore becomes ``[real; knockoff]``.
    """

    def step_begin(self, run_context):
        """
        Append a shuffled copy of the current batch to itself.

        Args:
            run_context (RunContext): Training context; its ``train_dataset_element`` is read, and
                its first two entries (data, label) are replaced in place.
        """
        cb_params = run_context.original_args()
        batch = cb_params.train_dataset_element
        data, label = batch[0], batch[1]
        batch_size = data.shape[0]
        perm = ops.Randperm(max_length=batch_size)(mindspore.Tensor([batch_size], dtype=mstype.int32))
        shuffled_data = data[perm, :].view(data.shape)
        shuffled_label = label[perm].view(label.shape)
        concat = ops.Concat(axis=0)
        batch[0] = concat((data, shuffled_data))
        batch[1] = concat((label, shuffled_label))
        cb_params.train_dataset_element = batch
class PrunerKfCompressAlgo(CompAlgo):
    """
    `PrunerKfCompressAlgo` is a subclass of CompAlgo, which implements the use of high imitation (knockoff) data to
    learn and discover redundant convolution kernels in the SCOP algorithm.

    Note:
        For the input parameter `config`, there is currently no optional configuration item for
        `PrunerKfCompressAlgo`. For compatibility, `config` is reserved, and an empty dictionary should be
        passed during initialization, such as `kf_pruning = PrunerKfCompressAlgo({})`.

    Args:
        config (dict): Configuration of `PrunerKfCompressAlgo`. There are no configurable options for
            `PrunerKfCompressAlgo` currently, but for compatibility, the config parameter in the constructor of
            this class is retained.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore_gs.pruner import PrunerKfCompressAlgo
        >>> from mindspore import nn
        >>> class Net(nn.Cell):
        ...     def __init__(self, num_channel=1):
        ...         super(Net, self).__init__()
        ...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.bn = nn.BatchNorm2d(6)
        ...
        ...     def construct(self, x):
        ...         x = self.conv(x)
        ...         x = self.bn(x)
        ...         return x
        ...
        >>> class NetToPrune(nn.Cell):
        ...     def __init__(self):
        ...         super(NetToPrune, self).__init__()
        ...         self.layer = Net()
        ...
        ...     def construct(self, x):
        ...         x = self.layer(x)
        ...         return x
        ...
        >>> ## 1) Define network to be pruned
        >>> net = NetToPrune()
        >>> ## 2) Define Knockoff Algorithm
        >>> kf_pruning = PrunerKfCompressAlgo({})
        >>> ## 3) Apply Knockoff-algorithm to origin network
        >>> net_pruning = kf_pruning.apply(net)
        >>> ## 4) Print network and check the result. Conv2d and bn should be transformed to KfConv2d.
        >>> print(net_pruning)
        NetToPrune<
          (layer): Net<
            (conv): KfConv2d<
              (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid,
              padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
              (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter
              (name=conv.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter
              (name=conv.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter
              (name=conv.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter
              (name=conv.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
              >
            (bn): SequentialCell<>
            >
          >
    """

    def callbacks(self, *args, **kwargs):
        """
        Define the callbacks for the SCOP algorithm, including the callback that generates knockoff data.

        Returns:
            List of instance of SCOP Callbacks.
        """
        cb = [KfCallback()]
        cb.extend(super(PrunerKfCompressAlgo, self).callbacks())
        return cb

    @staticmethod
    def _remove_suffix(name, suffix):
        """Return ``name`` with a trailing ``suffix`` removed (unchanged if it does not end with it).

        ``str.strip(suffix)`` must not be used here: it strips any *characters* of ``suffix`` from both
        ends, which mangles prefixes such as "head_layer..." or "encoder...".
        """
        if suffix and name.endswith(suffix):
            return name[:-len(suffix)]
        return name

    def _tranform(self, net):
        """Wrap Conv2d+BN pairs inside top-level cells whose key contains 'layer', then freeze training.

        After the wrap, every parameter of ``net`` has ``requires_grad`` set to False, except the
        ``kfscale`` of each KfConv2d.
        """
        cells = net._cells
        for name in list(cells.keys()):
            if 'layer' in name:
                cells[name] = self._tranform_conv(cells[name])
        for param in net.get_parameters():
            param.requires_grad = False
        for _, cell in net.cells_and_names():
            if isinstance(cell, KfConv2d):
                cell.kfscale.requires_grad = True
        return net

    def _tranform_conv(self, net):
        """Recursively replace each Conv2d and the cell that follows it with a KfConv2d.

        The follower (expected to be its BatchNorm) is moved into the KfConv2d, and its slot is
        replaced by an empty SequentialCell. Convs keyed '0', 'conv1_3x3' or 'conv1_7x7' are left
        untouched, as is a Conv2d that is the last child (it has no follower to fuse).
        """
        def _inject(modules):
            keys = list(modules.keys())
            for ik, k in enumerate(keys):
                if isinstance(modules[k], nn.Conv2d):
                    if k in ('0', 'conv1_3x3', 'conv1_7x7') or ik + 1 >= len(keys):
                        continue
                    # Parameter-name prefix of this conv, taken from its last parameter.
                    prex = ''
                    for value, param in modules[k].parameters_and_names():
                        prex = self._remove_suffix(param.name, value)
                    bn_key = keys[ik + 1]
                    modules[k] = KfConv2d(modules[k], modules[bn_key], prex)
                    for params in modules[k].conv.get_parameters():
                        params.name = prex + params.name
                    for params in modules[k].bn.get_parameters():
                        params.name = prex + params.name
                    modules[bn_key] = nn.SequentialCell()
                elif (not isinstance(modules[k], KfConv2d)) and modules[k]._cells:
                    _inject(modules[k]._cells)
        _inject(net._cells)
        return net

    def apply(self, network, **kwargs):
        """
        Transform input `network` to a knockoff network.

        Args:
            network (Cell): Network to be pruned.

        Returns:
            Knockoff network.

        Raises:
            TypeError: If `network` is not Cell.
        """
        network = Validator.check_isinstance('network', network, nn.Cell)
        return self._tranform(network)
class PrunerFtCompressAlgo(CompAlgo):
    """
    `PrunerFtCompressAlgo` is a subclass of CompAlgo that implements the ability to remove redundant convolution kernels
    and fully train the network.

    Args:
        config (dict): Configuration of `PrunerFtCompressAlgo`, keys are attribute names,
            values are attribute values. Supported attribute are listed below:

            - prune_rate (float): number in [0.0, 1.0).

    Raises:
        TypeError: If `prune_rate` is not float.
        ValueError: If `prune_rate` is less than 0 or greater than or equal to 1.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore_gs.pruner import PrunerKfCompressAlgo, PrunerFtCompressAlgo
        >>> from mindspore import nn
        >>> class Net(nn.Cell):
        ...     def __init__(self, num_channel=1):
        ...         super(Net, self).__init__()
        ...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.bn = nn.BatchNorm2d(6)
        ...
        ...     def construct(self, x):
        ...         x = self.conv(x)
        ...         x = self.bn(x)
        ...         return x
        ...
        >>> class NetToPrune(nn.Cell):
        ...     def __init__(self):
        ...         super(NetToPrune, self).__init__()
        ...         self.layer = Net()
        ...
        ...     def construct(self, x):
        ...         x = self.layer(x)
        ...         return x
        ...
        >>> net = NetToPrune()
        >>> kf_pruning = PrunerKfCompressAlgo({})
        >>> net_pruning_kf = kf_pruning.apply(net)
        >>> ## 1) Define FineTune Algorithm
        >>> ft_pruning = PrunerFtCompressAlgo({'prune_rate': 0.5})
        >>> ## 2) Apply FineTune-algorithm to the knockoff network
        >>> net_pruning_ft = ft_pruning.apply(net_pruning_kf)
        >>> ## 3) Print network and check the result. KfConv2d should be transformed to MaskedConv2dbn.
        >>> print(net_pruning_ft)
        NetToPrune<
          (layer): Net<
            (conv): MaskedConv2dbn<
              (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid,
              padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
              (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter
              (name=conv.bn.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter
              (name=conv.bn.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter
              (name=conv.bn.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter
              (name=conv.bn.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
              >
            (bn): SequentialCell<>
            >
          >
    """

    def _create_config(self):
        """Create PrunerFtCompressConfig."""
        self._config = PrunerFtCompressConfig()

    def _update_config_from_dict(self, config: dict):
        """Update prune `config` from a dict; a missing ``prune_rate`` falls back to 0.0."""
        self.set_prune_rate(config.get("prune_rate", 0.0))

    def set_prune_rate(self, prune_rate: float):
        """
        Set value of prune_rate of `_config`.

        Args:
            prune_rate (float): Fraction of output channels of each KfConv2d to prune, in [0.0, 1.0).

        Raises:
            TypeError: If `prune_rate` is not float.
            ValueError: If `prune_rate` is less than 0.0 or greater than or equal to 1.0.
        """
        prune_rate = Validator.check_float_range(prune_rate, 0.0, 1.0, Rel.INC_LEFT,
                                                 "prune_rate", self.__class__.__name__)
        self._config.prune_rate = prune_rate

    @staticmethod
    def _remove_suffix(name, suffix):
        """Return ``name`` with a trailing ``suffix`` removed (unchanged if it does not end with it).

        ``str.strip(suffix)`` must not be used here: it strips any *characters* of ``suffix`` from both
        ends, e.g. it eats a leading "conv" when the suffix is "moving_variance".
        """
        if suffix and name.endswith(suffix):
            return name[:-len(suffix)]
        return name

    def _recover(self, net):
        """Rank the output channels of every KfConv2d, then swap each one for a MaskedConv2dbn."""
        # Make every parameter trainable again (the knockoff stage froze all but kfscale).
        for param in net.get_parameters():
            param.requires_grad = True
        prune_rate = self._config.prune_rate
        for _, module in net.cells_and_names():
            if not isinstance(module, KfConv2d):
                continue
            # score = |gamma| * (kfscale - (1 - kfscale)). Sort is ascending, so the
            # int(prune_rate * C) lowest-scoring channels are dropped and the rest are kept.
            module.score = module.bn.gamma.data.abs() * ops.Squeeze()(
                module.kfscale.data - (1 - module.kfscale.data))
            module.prune_rate = prune_rate
            _, index = ops.Sort()(module.score)
            num_pruned_channel = int(module.prune_rate * module.score.shape[0])
            module.out_index = index[num_pruned_channel:]
        return self._recover_conv(net)

    def _recover_conv(self, net):
        """Recursively replace every KfConv2d with a MaskedConv2dbn, re-prefixing its parameter names."""
        def _inject(modules):
            for k in list(modules.keys()):
                cell = modules[k]
                if isinstance(cell, KfConv2d):
                    # Parameter-name prefix of this cell, taken from its last parameter.
                    prex = ''
                    for value, param in cell.parameters_and_names():
                        prex = self._remove_suffix(param.name, value.split('.')[-1])
                    masked = MaskedConv2dbn(cell, prex)
                    modules[k] = masked
                    for params in masked.conv.get_parameters():
                        params.name = prex + params.name
                    for params in masked.bn.get_parameters():
                        params.name = prex + params.name
                elif cell._cells:
                    _inject(cell._cells)
        _inject(net._cells)
        return net

    def _pruning_conv(self, net):
        """Replace each MaskedConv2dbn keyed 'conv1'/'conv2'/'conv3' with its physically pruned counterpart."""
        def _inject(modules):
            for k in list(modules.keys()):
                if isinstance(modules[k], MaskedConv2dbn):
                    if 'conv1' in k:
                        modules[k] = PrunedConv2dbn1(modules[k])
                    elif 'conv2' in k:
                        modules[k] = PrunedConv2dbnmiddle(modules[k])
                    elif 'conv3' in k:
                        modules[k] = PrunedConv2dbn2(modules[k])
                elif (not isinstance(modules[k], KfConv2d)) and modules[k]._cells:
                    _inject(modules[k]._cells)
        _inject(net._cells)
        return net

    def apply(self, network, **kwargs):
        """
        Transform a knockoff `network` to a normal and pruned network.

        Args:
            network (Cell): Knockoff network.

        Returns:
            Pruned network.

        Raises:
            TypeError: If `network` is not Cell.
        """
        network = Validator.check_isinstance('network', network, nn.Cell)
        return self._recover(network)
class PrunerFtCompressConfig(CompAlgoConfig):
    """Configuration holder for `PrunerFtCompressAlgo`.

    Attributes:
        prune_rate (float): Fraction of output channels to prune per KfConv2d. It starts at 0.0,
            which prunes nothing unless a rate is configured.
    """

    def __init__(self):
        super().__init__()
        self.prune_rate = 0.0