mindspore_gs.pruner.PrunerKfCompressAlgo

class mindspore_gs.pruner.PrunerKfCompressAlgo(config=None)

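PrunerKfCompressAlgo is a subclass of CompAlgo. It implements the stage of the SCOP algorithm that uses knockoff data to learn which convolution kernels are redundant.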
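
The idea of finding redundant kernels with knockoff data can be illustrated with a minimal plain-Python sketch. It is not the mindspore_gs implementation. It assumes the SCOP formulation in which each filter learns one scale for real features and one for knockoff features, with the two scales summing to one, and in which a filter whose real scale barely exceeds its knockoff scale is treated as redundant. The helper names, the scale values, and the ratio-based selection are all hypothetical.

```python
# Illustrative sketch only: plain Python, not the mindspore_gs implementation.
# All function names and numbers below are hypothetical.

def filter_importance(beta_real, beta_knockoff):
    """Importance of each filter: real-feature scale minus knockoff-feature scale."""
    return [r - k for r, k in zip(beta_real, beta_knockoff)]


def select_redundant(importance, prune_ratio):
    """Return the indices of the least important filters, in ascending index order."""
    num_to_prune = int(len(importance) * prune_ratio)
    ranked = sorted(range(len(importance)), key=lambda i: importance[i])
    return sorted(ranked[:num_to_prune])


# Hypothetical learned scales for a layer with 6 filters.
# Assumption: the real and knockoff scales of each filter sum to 1.
beta_real = [0.9, 0.55, 0.8, 0.5, 0.7, 0.45]
beta_knockoff = [1.0 - b for b in beta_real]

importance = filter_importance(beta_real, beta_knockoff)
redundant = select_redundant(importance, prune_ratio=0.5)
print(redundant)  # [1, 3, 5]
```

Filters 1, 3 and 5 are selected because the network relies on their knockoff features almost as much as on their real features, which suggests they carry little information about the target.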

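Note

  • PrunerKfCompressAlgo currently has no configurable options for the config argument. The argument is kept for compatibility, so pass an empty dict at initialization, for example kf_pruning = PrunerKfCompressAlgo({}).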

Parameters:
  • config (dict) - Algorithm configuration parameters.

Supported Platforms:

Ascend GPU

Examples:

>>> from mindspore_gs.pruner import PrunerKfCompressAlgo
>>> from mindspore import nn
>>> class Net(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> class NetToPrune(nn.Cell):
...     def __init__(self):
...         super(NetToPrune, self).__init__()
...         self.layer = Net()
...
...     def construct(self, x):
...         x = self.layer(x)
...         return x
...
>>> ## 1) Define network to be pruned
>>> net = NetToPrune()
>>> ## 2) Define Knockoff Algorithm
>>> kf_pruning = PrunerKfCompressAlgo({})
>>> ## 3) Apply Knockoff algorithm to the original network
>>> net_pruning = kf_pruning.apply(net)
>>> ## 4) Print network and check the result. Conv2d and bn should be transformed to KfConv2d.
>>> print(net_pruning)
NetToPrune<
(layer): Net<
(conv): KfConv2d<
(conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid,
padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter
(name=conv.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter
(name=conv.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter
(name=conv.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter
(name=conv.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
>
(bn): SequentialCell<>
>
>
apply(network)

Transforms the network into a Knockoff network.

Parameters:
  • network (Cell) - The original network to be pruned.

Returns:

The transformed Knockoff network.

Raises:
  • TypeError - If network is not a Cell.

callbacks(*args, **kwargs)

Defines the callbacks specific to the SCOP pruning algorithm, namely the callback that generates knockoff data.

Returns:

List of callbacks for the SCOP pruning algorithm.