# Source code for mindspore_gs.pruner.scop.scop_pruner

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ScopPruner."""

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore.ops import constexpr
from mindspore import Parameter
from mindspore_gs.common.validator import Validator, Rel
from ...comp_algo import CompAlgo, CompAlgoConfig


@constexpr
def generate_int(shape):
    """Return half of ``shape`` (floor division) as a Python ``int``.

    Decorated with ``constexpr`` so it is evaluated at graph-compile time;
    ``KfConv2d.construct`` uses it to find the midpoint of the batch axis.
    """
    half = shape // 2
    return int(half)


class KfConv2d(nn.Cell):
    """Conv2d + BatchNorm wrapper that mixes original and knockoff features during training.

    In training mode the incoming batch is treated as two halves: the first half
    (original samples) and the second half (the shuffled copy appended by ``KfCallback``).
    After the convolution, the first half is replaced by
    ``kfscale * first_half + (1 - kfscale) * second_half`` and the second half is kept
    unchanged. The result then goes through ``bn``. In inference mode, this is just
    ``bn(conv(x))``.

    Args:
        conv_ori (nn.Conv2d): Convolution to wrap.
        bn_ori (nn.Cell): Cell applied after the convolution (presumably a BatchNorm2d;
            ``PrunerFtCompressAlgo`` later reads ``bn.gamma`` from it).
        prex (str): Prefix used to build the name of the ``kfscale`` parameter.
    """

    def __init__(self, conv_ori, bn_ori, prex):
        super(KfConv2d, self).__init__()
        self.conv = conv_ori
        self.bn = bn_ori
        self.out_channels = self.conv.out_channels
        # Per-output-channel mixing coefficient, initialised to 0.5. The value is built
        # directly instead of calling ``.fill`` on the ndarray from ``asnumpy()``; that
        # mutates a host copy which is not guaranteed to reach the tensor's device storage.
        init_scale = ops.Ones()((1, self.out_channels, 1, 1), mindspore.float32) * 0.5
        self.kfscale = Parameter(init_scale, requires_grad=True, name=prex + '.kfscale')
        self.concat_op = ops.Concat(axis=0)

    def construct(self, x):
        """Apply conv, blend the second half of the batch into the first (training only), then bn."""
        x = self.conv(x)
        if self.training:
            # First half of the batch: original samples; second half: knockoff samples.
            num_ori = generate_int(x.shape[0])
            x = self.concat_op((self.kfscale * x[:num_ori] + (1 - self.kfscale) * x[num_ori:], x[num_ori:]))
        x = self.bn(x)
        return x


@constexpr
def generate_tensor(shape, mask_list):
    """Build a boolean mask tensor for ``ops.MaskedFill``.

    Every element starts as True. Every channel (axis 1) listed in ``mask_list`` is set
    to False, so ``MaskedFill`` leaves those channels untouched and overwrites the rest.
    Evaluated at graph-compile time via ``constexpr``.

    Args:
        shape: Target shape. The indexing below assumes it is 4-D.
        mask_list: Iterable of channel indices to mark False.

    Returns:
        Tensor: The mask, with dtype set to ``mstype.bool_``.
    """
    host_mask = ops.Ones()(shape, mstype.float16).asnumpy()
    for channel in mask_list:
        host_mask[:, channel, :, :] = 0.0
    result = Tensor(host_mask)
    result.set_dtype(mstype.bool_)
    return result


class MaskedConv2dbn(nn.Cell):
    """Mask Conv2d and bn.

    Runs the wrapped conv and bn of a ``KfConv2d`` and then zeroes every output channel
    that is not listed in the ``KfConv2d``'s ``out_index`` (the channels kept after scoring).
    The weights are not physically pruned; channels are only masked.

    Args:
        kf_conv2d_ori (KfConv2d): Source module. Must already have ``out_index`` set
            (``PrunerFtCompressAlgo._recover`` assigns it before building this cell).
        prex (str): Prefix used to name the ``out_index`` parameter.
    """

    def __init__(self, kf_conv2d_ori, prex):
        super(MaskedConv2dbn, self).__init__()
        # Device target selects the masking implementation in construct().
        self.target = context.get_context("device_target").upper()
        self.conv = kf_conv2d_ori.conv
        self.bn = kf_conv2d_ori.bn
        self.zeros = ops.Zeros()
        # NOTE(review): `one` and `cast` are not used anywhere in this class.
        self.one = ops.Ones()
        # Indices of the output channels to keep; stored as a non-trainable parameter.
        self.out_index = Parameter(kf_conv2d_ori.out_index, requires_grad=False, name=prex + '.out_index')
        self.cast = ops.Cast()
        # Plain Python list copy of the kept indices, used as a constant inside construct().
        self.mask = self.out_index.asnumpy().tolist()

    def construct(self, x):
        """Apply conv and bn, then zero the channels that are not in `self.mask`."""
        x = self.conv(x)
        x = self.bn(x)
        if self.target == 'ASCEND':
            # Kept channels are False in new_mask; MaskedFill writes 0.0 where the mask is True.
            new_mask = generate_tensor(x.shape, self.mask)
            output = ops.MaskedFill()(x, new_mask, 0.0)
            return output
        # Other targets: multiply by a 0/1 mask that is 1 only on the kept channels.
        mask = self.zeros((x.shape), mstype.float32)
        mask[:, self.mask, :, :] = 1.0
        x = x * mask
        return x


class PrunedConv2dbn1(nn.Cell):
    """Physically pruned Conv2d + BatchNorm that keeps only the selected output channels.

    The input-channel dimension is left intact. The convolution weight and the BatchNorm
    parameters/statistics are gathered along axis 0 with ``masked_module.out_index``.

    Args:
        masked_module (MaskedConv2dbn): Module providing ``conv``, ``bn`` and ``out_index``
            (indices of output channels to keep).
    """

    def __init__(self, masked_module):
        super(PrunedConv2dbn1, self).__init__()
        src_conv = masked_module.conv
        src_bn = masked_module.bn
        out_index = masked_module.out_index
        gather = ops.Gather()

        # Reuse the source conv's pad_mode, as PrunedConv2dbnmiddle does, instead of forcing
        # 'pad'. Forcing it turns e.g. a 3x3 'same' conv (padding=0) into an unpadded conv
        # and changes the output spatial size.
        newconv = nn.Conv2d(in_channels=src_conv.in_channels, out_channels=len(out_index),
                            kernel_size=src_conv.kernel_size, stride=src_conv.stride,
                            has_bias=False, padding=src_conv.padding, pad_mode=src_conv.pad_mode)
        self.conv = newconv
        self.conv.weight = Parameter(gather(src_conv.weight.data.clone(), out_index, 0), requires_grad=True,
                                     name=src_conv.weight.name)

        self.bn = nn.BatchNorm2d(len(out_index))
        self.bn.gamma = Parameter(gather(src_bn.gamma.data.clone(), out_index, 0),
                                  requires_grad=True, name=src_bn.gamma.name)
        self.bn.beta = Parameter(gather(src_bn.beta.data.clone(), out_index, 0),
                                 requires_grad=True, name=src_bn.beta.name)
        self.bn.moving_mean = Parameter(
            gather(src_bn.moving_mean.data.clone(), out_index, 0), requires_grad=False,
            name=src_bn.moving_mean.name)
        self.bn.moving_variance = Parameter(
            gather(src_bn.moving_variance.data.clone(), out_index, 0),
            requires_grad=False, name=src_bn.moving_variance.name)

        # Original channel count and kept indices, retained for mapping back to the full layout.
        self.oriout_channels = src_conv.out_channels
        self.out_index = out_index

    def construct(self, x):
        """Apply the pruned convolution followed by batch norm."""
        x = self.conv(x)
        x = self.bn(x)
        return x


class PrunedConv2dbnmiddle(nn.Cell):
    """Physically pruned Conv2d + BatchNorm that keeps selected input and output channels.

    The convolution weight is gathered along axis 0 with ``out_index`` and along axis 1
    with ``in_index``. The BatchNorm parameters/statistics are gathered with ``out_index``.

    Args:
        masked_module (MaskedConv2dbn): Module providing ``conv``, ``bn`` and ``out_index``.
            It must also carry an ``in_index`` attribute. NOTE(review): ``MaskedConv2dbn``
            does not set ``in_index`` itself, so a caller must assign it first.
    """

    def __init__(self, masked_module):
        super(PrunedConv2dbnmiddle, self).__init__()
        gather = ops.Gather()
        src_conv, src_bn = masked_module.conv, masked_module.bn
        keep_out, keep_in = masked_module.out_index, masked_module.in_index

        self.conv = nn.Conv2d(in_channels=len(keep_in), out_channels=len(keep_out),
                              kernel_size=src_conv.kernel_size, stride=src_conv.stride,
                              has_bias=False, padding=src_conv.padding, pad_mode=src_conv.pad_mode)
        kept_rows = gather(src_conv.weight.data.clone(), keep_out, 0)
        self.conv.weight = Parameter(gather(kept_rows, keep_in, 1), requires_grad=True,
                                     name=src_conv.weight.name)

        self.bn = nn.BatchNorm2d(len(keep_out))
        # Affine parameters stay trainable; running statistics do not.
        bn_fields = (('gamma', True), ('beta', True), ('moving_mean', False), ('moving_variance', False))
        for attr, trainable in bn_fields:
            src_param = getattr(src_bn, attr)
            setattr(self.bn, attr, Parameter(gather(src_param.data.clone(), keep_out, 0),
                                             requires_grad=trainable, name=src_param.name))

        self.oriout_channels = src_conv.out_channels
        self.out_index = keep_out

    def construct(self, x):
        """Apply the pruned convolution followed by batch norm."""
        return self.bn(self.conv(x))


class PrunedConv2dbn2(nn.Cell):
    """Pruned Conv2d + BatchNorm whose output is scattered back to the original channel count.

    The convolution keeps only ``in_index`` input channels and ``out_index`` output
    channels. In ``construct`` the pruned output is written into a zero tensor with the
    original number of output channels, at positions ``out_index``.

    Args:
        masked_module (MaskedConv2dbn): Module providing ``conv``, ``bn`` and ``out_index``.
            It must also carry an ``in_index`` attribute. NOTE(review): ``MaskedConv2dbn``
            does not set ``in_index`` itself, so a caller must assign it first.
    """

    def __init__(self, masked_module):
        super(PrunedConv2dbn2, self).__init__()
        src_conv = masked_module.conv
        src_bn = masked_module.bn
        out_index = masked_module.out_index
        in_index = masked_module.in_index
        gather = ops.Gather()

        # Reuse the source conv's pad_mode, as PrunedConv2dbnmiddle does, instead of forcing
        # 'pad'. Forcing it changes the output spatial size of 'same' convs with kernel > 1.
        newconv = nn.Conv2d(in_channels=len(in_index), out_channels=len(out_index),
                            kernel_size=src_conv.kernel_size, stride=src_conv.stride,
                            has_bias=False, padding=src_conv.padding, pad_mode=src_conv.pad_mode)
        self.conv = newconv

        weight_data = gather(gather(src_conv.weight.data.clone(), out_index, 0), in_index, 1)
        self.conv.weight = Parameter(weight_data, requires_grad=True, name=src_conv.weight.name)

        self.bn = nn.BatchNorm2d(len(out_index))
        self.bn.gamma = Parameter(gather(src_bn.gamma.data.clone(), out_index, 0),
                                  requires_grad=True, name=src_bn.gamma.name)
        self.bn.beta = Parameter(gather(src_bn.beta.data.clone(), out_index, 0),
                                 requires_grad=True, name=src_bn.beta.name)
        self.bn.moving_mean = Parameter(
            gather(src_bn.moving_mean.data.clone(), out_index, 0), requires_grad=False,
            name=src_bn.moving_mean.name)
        self.bn.moving_variance = Parameter(
            gather(src_bn.moving_variance.data.clone(), out_index, 0),
            requires_grad=False, name=src_bn.moving_variance.name)

        # Needed in construct() to rebuild the full-width output.
        self.oriout_channels = src_conv.out_channels
        self.out_index = out_index
        self.zeros = ops.Zeros()

    def construct(self, x):
        """Apply pruned conv + bn, then place the result into a zero tensor of the original width."""
        x = self.conv(x)
        x = self.bn(x)
        output = self.zeros((x.shape[0], self.oriout_channels, x.shape[2], x.shape[3]), mstype.float32)
        output[:, self.out_index, :, :] = x
        return output


class KfCallback(Callback):
    """
    Callback for the SCOP algorithm that appends a shuffled (knockoff) copy of every training batch.
    """

    def step_begin(self, run_context):
        """
        Before each step, double the batch with a randomly permuted copy of itself.

        Elements 0 (inputs) and 1 (labels) of ``train_dataset_element`` are each concatenated
        along axis 0 with the same permutation of themselves. The original samples come first
        and the permuted copies second.
        """
        cb_params = run_context.original_args()
        batch = cb_params.train_dataset_element
        inputs, labels = batch[0], batch[1]
        batch_size = inputs.shape[0]
        perm = ops.Randperm(max_length=batch_size)(mindspore.Tensor([batch_size], dtype=mstype.int32))
        shuffled_inputs = inputs[perm, :].view(inputs.shape)
        shuffled_labels = labels[perm].view(labels.shape)
        concat = ops.Concat(axis=0)
        batch[0] = concat((inputs, shuffled_inputs))
        batch[1] = concat((labels, shuffled_labels))
        cb_params.train_dataset_element = batch


class PrunerKfCompressAlgo(CompAlgo):
    """
    `PrunerKfCompressAlgo` is a subclass of CompAlgo. It implements the SCOP step that uses high-imitation
    (knockoff) data to learn and discover redundant convolution kernels.

    Note:
        `PrunerKfCompressAlgo` currently has no configurable options. The `config` parameter is kept for
        compatibility and is replaced with an empty dictionary during initialization. For example,
        `kf_pruning = PrunerKfCompressAlgo({})`.

    Args:
        config (dict): Configuration of `PrunerKfCompressAlgo`. There are no configurable options for
            `PrunerKfCompressAlgo` currently; the `config` parameter of the constructor is retained
            for compatibility.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore_gs.pruner import PrunerKfCompressAlgo
        >>> from mindspore import nn
        >>> class Net(nn.Cell):
        ...     def __init__(self, num_channel=1):
        ...         super(Net, self).__init__()
        ...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.bn = nn.BatchNorm2d(6)
        ...
        ...     def construct(self, x):
        ...         x = self.conv(x)
        ...         x = self.bn(x)
        ...         return x
        ...
        >>> class NetToPrune(nn.Cell):
        ...     def __init__(self):
        ...         super(NetToPrune, self).__init__()
        ...         self.layer = Net()
        ...
        ...     def construct(self, x):
        ...         x = self.layer(x)
        ...         return x
        ...
        >>> ## 1) Define network to be pruned
        >>> net = NetToPrune()
        >>> ## 2) Define Knockoff Algorithm
        >>> kf_pruning = PrunerKfCompressAlgo({})
        >>> ## 3) Apply Knockoff-algorithm to origin network
        >>> net_pruning = kf_pruning.apply(net)
        >>> ## 4) Print network and check the result. Conv2d and bn should be transformed to KfConv2d.
        >>> print(net_pruning)
        NetToPrune<
          (layer): Net<
            (conv): KfConv2d<
              (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
              (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=conv.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=conv.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=conv.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=conv.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
              >
            (bn): SequentialCell<>
            >
          >
    """

    def callbacks(self, *args, **kwargs):
        """
        Define the callbacks for the SCOP algorithm: the callback that generates knockoff data,
        followed by the base-class callbacks.

        Returns:
            List of instance of SCOP Callbacks.
        """
        cb = []
        cb.append(KfCallback())
        cb.extend(super(PrunerKfCompressAlgo, self).callbacks())
        return cb

    @staticmethod
    def _param_prefix(full_name, relative_name):
        """Return `full_name` without a trailing `relative_name` (unchanged if it does not end with it).

        Replaces the former ``full_name.strip(relative_name)``. ``str.strip`` removes a *set of
        characters* from both ends, so e.g. "head.conv.weight" lost its leading "he".
        """
        if relative_name and full_name.endswith(relative_name):
            return full_name[:-len(relative_name)]
        return full_name

    def _tranform(self, net):
        """Wrap conv+bn pairs inside top-level children whose key contains 'layer'.

        Afterwards every parameter of `net` is frozen except the `kfscale` of each `KfConv2d`.
        """
        module = net._cells
        for k in list(module.keys()):
            if 'layer' in k:
                module[k] = self._tranform_conv(module[k])
        for param in net.get_parameters():
            param.requires_grad = False
        for _, cell in net.cells_and_names():
            if isinstance(cell, KfConv2d):
                cell.kfscale.requires_grad = True
        return net

    def _tranform_conv(self, net):
        """Replace each eligible Conv2d and its following sibling with a `KfConv2d` (recursively)."""

        def _inject(modules):
            keys = list(modules.keys())
            for ik, k in enumerate(keys):
                if isinstance(modules[k], nn.Conv2d):
                    # Skip the stem convolutions. Also skip a conv with no following sibling to
                    # use as its bn; this previously raised IndexError.
                    if k in ('0', 'conv1_3x3', 'conv1_7x7') or ik + 1 >= len(keys):
                        continue
                    prex = ''
                    # The prefix is taken from the conv's last parameter.
                    for value, param in modules[k].parameters_and_names():
                        prex = self._param_prefix(param.name, value)
                    modules[k] = KfConv2d(modules[k], modules[keys[ik + 1]], prex)
                    for params in modules[k].conv.get_parameters():
                        params.name = prex + params.name
                    for params in modules[k].bn.get_parameters():
                        params.name = prex + params.name
                    # The bn now lives inside the KfConv2d; leave an empty cell in its slot.
                    modules[keys[ik + 1]] = nn.SequentialCell()
                elif (not isinstance(modules[k], KfConv2d)) and modules[k]._cells:
                    _inject(modules[k]._cells)

        _inject(net._cells)
        return net

    def apply(self, network, **kwargs):
        """
        Transform input `network` to a knockoff network.

        Args:
            network (Cell): Network to be pruned.

        Returns:
            Knockoff network.

        Raises:
            TypeError: If `network` is not Cell.
        """
        network = Validator.check_isinstance('network', network, nn.Cell)
        return self._tranform(network)
class PrunerFtCompressAlgo(CompAlgo):
    """
    `PrunerFtCompressAlgo` is a subclass of CompAlgo. It removes redundant convolution kernels from a knockoff
    network and prepares the network for full fine-tuning.

    Args:
        config (dict): Configuration of `PrunerFtCompressAlgo`. Keys are attribute names and values are attribute
            values. Supported attributes are listed below:

            - prune_rate (float): number in [0.0, 1.0).

    Raises:
        TypeError: If `prune_rate` is not float.
        ValueError: If `prune_rate` is less than 0.0 or greater than or equal to 1.0.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore_gs.pruner import PrunerKfCompressAlgo, PrunerFtCompressAlgo
        >>> from mindspore import nn
        >>> class Net(nn.Cell):
        ...     def __init__(self, num_channel=1):
        ...         super(Net, self).__init__()
        ...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.bn = nn.BatchNorm2d(6)
        ...
        ...     def construct(self, x):
        ...         x = self.conv(x)
        ...         x = self.bn(x)
        ...         return x
        ...
        >>> class NetToPrune(nn.Cell):
        ...     def __init__(self):
        ...         super(NetToPrune, self).__init__()
        ...         self.layer = Net()
        ...
        ...     def construct(self, x):
        ...         x = self.layer(x)
        ...         return x
        ...
        >>> net = NetToPrune()
        >>> kf_pruning = PrunerKfCompressAlgo({})
        >>> net_pruning_kf = kf_pruning.apply(net)
        >>> ## 1) Define FineTune Algorithm
        >>> ft_pruning = PrunerFtCompressAlgo({'prune_rate': 0.5})
        >>> ## 2) Apply FineTune-algorithm to the knockoff network
        >>> net_pruning_ft = ft_pruning.apply(net_pruning_kf)
        >>> ## 3) Print network and check the result. KfConv2d should be transformed to MaskedConv2dbn.
        >>> print(net_pruning_ft)
        NetToPrune<
          (layer): Net<
            (conv): MaskedConv2dbn<
              (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
              (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=conv.bn.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=conv.bn.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=conv.bn.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=conv.bn.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
              >
            (bn): SequentialCell<>
            >
          >
    """

    def _create_config(self):
        """Create PrunerFtCompressConfig."""
        self._config = PrunerFtCompressConfig()

    def _update_config_from_dict(self, config: dict):
        """Update prune `config` from a dict."""
        self.set_prune_rate(config.get("prune_rate", 0.0))

    def set_prune_rate(self, prune_rate: float):
        """
        Set value of prune_rate of `_config`.

        Args:
            prune_rate (float): Fraction of output channels to prune in each knockoff convolution.

        Raises:
            TypeError: If `prune_rate` is not float.
            ValueError: If `prune_rate` is less than 0.0 or greater than or equal to 1.0.
        """
        prune_rate = Validator.check_float_range(prune_rate, 0.0, 1.0, Rel.INC_LEFT,
                                                 "prune_rate", self.__class__.__name__)
        self._config.prune_rate = prune_rate

    @staticmethod
    def _param_prefix(full_name, relative_name):
        """Return `full_name` without a trailing `relative_name` (unchanged if it does not end with it).

        Replaces the former ``full_name.strip(relative_name)``. ``str.strip`` removes a *set of
        characters* from both ends, so e.g. "backbone.bn.beta" lost its leading "ba".
        """
        if relative_name and full_name.endswith(relative_name):
            return full_name[:-len(relative_name)]
        return full_name

    def _recover(self, net):
        """Score the channels of every `KfConv2d`, choose the ones to keep, and swap in `MaskedConv2dbn`.

        All parameters of `net` are made trainable again. For each `KfConv2d` the channel score is
        ``|bn.gamma| * (kfscale - (1 - kfscale))``. The ``int(prune_rate * num_channels)`` lowest-scoring
        channels are dropped, and the indices of the rest are stored in `out_index`.
        """
        kfconv_list = [cell for _, cell in net.cells_and_names() if isinstance(cell, KfConv2d)]
        for param in net.get_parameters():
            param.requires_grad = True
        for kfconv in kfconv_list:
            # kfscale weights the first (original) half of the batch and 1 - kfscale the second
            # (knockoff) half. Their difference is positive when the original half dominates.
            # Squeeze drops the singleton axes of the (1, C, 1, 1) scale.
            kfconv.score = kfconv.bn.gamma.data.abs() * ops.Squeeze()(
                kfconv.kfscale.data - (1 - kfconv.kfscale.data))
            kfconv.prune_rate = self._config.prune_rate
            # Ascending sort: the first `num_pruned_channel` indices are the weakest channels.
            _, index = ops.Sort()(kfconv.score)
            num_pruned_channel = int(kfconv.prune_rate * kfconv.score.shape[0])
            kfconv.out_index = index[num_pruned_channel:]
        return self._recover_conv(net)

    def _recover_conv(self, net):
        """Replace every `KfConv2d` in `net` with a `MaskedConv2dbn` and prefix its parameter names."""

        def _inject(modules):
            for k in list(modules.keys()):
                if isinstance(modules[k], KfConv2d):
                    prex = ''
                    # The prefix is taken from the module's last parameter.
                    for value, param in modules[k].parameters_and_names():
                        prex = self._param_prefix(param.name, value.split('.')[-1])
                    modules[k] = MaskedConv2dbn(modules[k], prex)
                    for params in modules[k].conv.get_parameters():
                        params.name = prex + params.name
                    for params in modules[k].bn.get_parameters():
                        params.name = prex + params.name
                elif modules[k]._cells:
                    _inject(modules[k]._cells)

        _inject(net._cells)
        return net

    def _pruning_conv(self, net):
        """Replace `MaskedConv2dbn` cells with physically pruned cells, chosen by child key.

        NOTE(review): not called from within this class.
        """

        def _inject(modules):
            for k in list(modules.keys()):
                if isinstance(modules[k], MaskedConv2dbn):
                    if 'conv1' in k:
                        modules[k] = PrunedConv2dbn1(modules[k])
                    elif 'conv2' in k:
                        modules[k] = PrunedConv2dbnmiddle(modules[k])
                    elif 'conv3' in k:
                        modules[k] = PrunedConv2dbn2(modules[k])
                elif (not isinstance(modules[k], KfConv2d)) and modules[k]._cells:
                    _inject(modules[k]._cells)

        _inject(net._cells)
        return net

    def apply(self, network, **kwargs):
        """
        Transform a knockoff `network` to a normal and pruned network.

        Args:
            network (Cell): Knockoff network.

        Returns:
            Pruned network.

        Raises:
            TypeError: If `network` is not Cell.
        """
        network = Validator.check_isinstance('network', network, nn.Cell)
        return self._recover(network)
class PrunerFtCompressConfig(CompAlgoConfig):
    """Configuration holder for `PrunerFtCompressAlgo`.

    Attributes:
        prune_rate (float): Fraction of channels pruned in each knockoff convolution; starts at 0.0.
    """

    def __init__(self):
        super().__init__()
        self.prune_rate = 0.0