mindspore_gs.pruner.PrunerKfCompressAlgo
- class mindspore_gs.pruner.PrunerKfCompressAlgo(config=None)
PrunerKfCompressAlgo is a subclass of CompAlgo. It implements the stage of the SCOP algorithm that uses knockoff (high-fidelity imitation) data to learn and identify redundant convolution kernels.
Note
PrunerKfCompressAlgo currently has no configurable options. The config parameter is reserved for compatibility and is replaced with an empty dictionary during initialization, for example kf_pruning = PrunerKfCompressAlgo({}).
- Parameters
config (dict) – Configuration of PrunerKfCompressAlgo. There are currently no configurable options, but the config parameter is retained in the constructor of PrunerKfCompressAlgo for compatibility. Default: None.
- Supported Platforms:
Ascend
GPU
Examples
>>> from mindspore_gs.pruner import PrunerKfCompressAlgo
>>> from mindspore import nn
>>> class Net(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> class NetToPrune(nn.Cell):
...     def __init__(self):
...         super(NetToPrune, self).__init__()
...         self.layer = Net()
...
...     def construct(self, x):
...         x = self.layer(x)
...         return x
...
>>> ## 1) Define network to be pruned
>>> net = NetToPrune()
>>> ## 2) Define Knockoff Algorithm
>>> kf_pruning = PrunerKfCompressAlgo({})
>>> ## 3) Apply Knockoff algorithm to the original network
>>> net_pruning = kf_pruning.apply(net)
>>> ## 4) Print network and check the result. Conv2d and bn should be transformed to KfConv2d.
>>> print(net_pruning)
NetToPrune<
  (layer): Net<
    (conv): KfConv2d<
      (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
      (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=conv.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=conv.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=conv.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=conv.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
      >
    (bn): SequentialCell<>
    >
  >
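The class description mentions learning which convolution kernels are redundant by using knockoff data. The following is a minimal, framework-free sketch of that selection idea only, and it is not the mindspore_gs implementation. It assumes each kernel has already been assigned one learned scale for real features and one for knockoff features, and that a kernel whose real scale barely exceeds its knockoff scale is a pruning candidate. The function name select_redundant_kernels, the prune_rate parameter, and the sample values are all hypothetical.

```python
# Illustrative sketch only: NOT the mindspore_gs implementation.
# All names and values below are hypothetical.

def select_redundant_kernels(real_scales, knockoff_scales, prune_rate):
    """Return sorted indices of the kernels with the lowest importance.

    Importance is assumed to be (real scale - knockoff scale): a kernel
    that responds to knockoff features about as strongly as to real
    features carries little task-relevant information.
    """
    if len(real_scales) != len(knockoff_scales):
        raise ValueError("scale lists must have the same length")
    if not 0.0 <= prune_rate < 1.0:
        raise ValueError("prune_rate must be in [0, 1)")

    # Per-kernel importance score.
    scores = [r - k for r, k in zip(real_scales, knockoff_scales)]
    # Number of kernels to drop, rounded down.
    num_to_prune = int(len(scores) * prune_rate)
    # Kernel indices ordered from least to most important.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i])
    return sorted(ranked[:num_to_prune])


# Six kernels, matching the six output channels of the Conv2d in the example.
real = [0.9, 0.5, 0.7, 0.4, 0.8, 0.55]
knockoff = [0.1, 0.5, 0.3, 0.6, 0.2, 0.45]
pruned = select_redundant_kernels(real, knockoff, prune_rate=0.5)
print(pruned)  # [1, 3, 5]
```

In this sketch, kernels 1, 3, and 5 are selected because their real and knockoff scales are nearly equal, or the knockoff scale is larger. How the scales are learned and how the pruning is applied to the network are outside the scope of this illustration.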