mindspore_gs.pruner.PrunerFtCompressAlgo
- class mindspore_gs.pruner.PrunerFtCompressAlgo(config=None)
PrunerFtCompressAlgo is a subclass of CompAlgo that removes redundant convolution kernels and performs full training of the network.
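To make "removing redundant convolution kernels" concrete, here is a minimal pure-Python sketch of choosing which output kernels to drop at a given rate. The helper `select_kernels_to_keep` and the per-kernel `scores` are hypothetical. This page does not say how PrunerFtCompressAlgo ranks kernels, so the sketch simply assumes that a lower score means a less important kernel and prunes the lowest `floor(rate * n)` of them. It illustrates the idea only and is not the library's implementation.

```python
import math
from typing import List, Sequence


def select_kernels_to_keep(scores: Sequence[float], rate: float) -> List[bool]:
    """Return a keep-mask over output kernels (hypothetical helper).

    The floor(rate * n) lowest-scoring kernels are marked False (pruned);
    all other kernels are marked True (kept). Ties are broken by index.
    """
    # Same valid range as the documented prune_rate: [0.0, 1.0).
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0.0, 1.0)")
    n = len(scores)
    num_pruned = math.floor(rate * n)
    # Sort kernel indices from least to most important (assumed criterion).
    order = sorted(range(n), key=lambda i: (scores[i], i))
    pruned = set(order[:num_pruned])
    return [i not in pruned for i in range(n)]


# Six output kernels, as in the Conv2d of the example below; rate 0.5 prunes 3.
mask = select_kernels_to_keep([0.9, 0.1, 0.5, 0.05, 0.7, 0.3], 0.5)
print(mask)  # [True, False, True, False, True, False]
```

Sorting by `(score, index)` keeps the result deterministic when two kernels have equal scores.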
- Args:
config (dict) - Configuration for pruning training, given as a dict. The supported keys are listed below:
prune_rate (float) - Pruning rate. Must be in the range [0.0, 1.0).
- Raises:
TypeError - prune_rate is not a float.
ValueError - prune_rate is less than 0 or greater than or equal to 1.
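The documented exceptions amount to the following check on `prune_rate`. This is a hedged sketch: `check_prune_rate` is a hypothetical helper written from the exception list, not a function exported by mindspore_gs. It assumes that an `int` such as `1` counts as "not a float" and is rejected with TypeError.

```python
def check_prune_rate(prune_rate: object) -> float:
    """Validate prune_rate as this page describes (hypothetical helper)."""
    # Documented TypeError: only float values are accepted.
    if not isinstance(prune_rate, float):
        raise TypeError(f"prune_rate must be a float, got {type(prune_rate).__name__}")
    # Documented ValueError: the valid range is [0.0, 1.0).
    # Written as a chained comparison so that NaN is rejected as well.
    if not 0.0 <= prune_rate < 1.0:
        raise ValueError(f"prune_rate must be in [0.0, 1.0), got {prune_rate}")
    return prune_rate


for value in (0.5, 1, 1.0, -0.1):
    try:
        print(value, "->", check_prune_rate(value))
    except (TypeError, ValueError) as err:
        print(value, "->", type(err).__name__)
```

Running the loop prints `0.5 -> 0.5`, `1 -> TypeError`, `1.0 -> ValueError` and `-0.1 -> ValueError`, one per line.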
- Supported Platforms:
Ascend
GPU
Examples:
>>> from mindspore_gs.pruner import PrunerKfCompressAlgo, PrunerFtCompressAlgo
>>> from mindspore import nn
>>> class Net(nn.Cell):
...     def __init__(self, num_channel=1):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.bn = nn.BatchNorm2d(6)
...
...     def construct(self, x):
...         x = self.conv(x)
...         x = self.bn(x)
...         return x
...
>>> class NetToPrune(nn.Cell):
...     def __init__(self):
...         super(NetToPrune, self).__init__()
...         self.layer = Net()
...
...     def construct(self, x):
...         x = self.layer(x)
...         return x
...
>>> net = NetToPrune()
>>> kf_pruning = PrunerKfCompressAlgo({})
>>> net_pruning_kf = kf_pruning.apply(net)
>>> ## 1) Define the FineTune algorithm.
>>> ft_pruning = PrunerFtCompressAlgo({'prune_rate': 0.5})
>>> ## 2) Apply the FineTune algorithm to the network returned by PrunerKfCompressAlgo.
>>> net_pruning_ft = ft_pruning.apply(net_pruning_kf)
>>> ## 3) Print the network and check the result. Conv2d and bn should be transformed to MaskedConv2dbn.
>>> print(net_pruning_ft)
NetToPrune<
  (layer): Net<
    (conv): MaskedConv2dbn<
      (conv): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
      (bn): BatchNorm2d<num_features=6, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=conv.bn.bn.gamma, shape=(6,), dtype=Float32, requires_grad=True), beta=Parameter (name=conv.bn.bn.beta, shape=(6,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=conv.bn.bn.moving_mean, shape=(6,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=conv.bn.bn.moving_variance, shape=(6,), dtype=Float32, requires_grad=False)>
      >
    (bn): SequentialCell<>
    >
  >