mindspore_gs
- class mindspore_gs.CompAlgo(config=None)
Base class of algorithms in GoldenStick.
- Parameters
config (dict) – User config for network compression. Default: None. The config specification is defined by each derived class; the base attributes are listed below:
  save_mindir (bool): If True, export MindIR automatically after training; otherwise, do not. Default: False.
  save_mindir_path (str): The path to export MindIR, including the directory and file name. It can be a relative or an absolute path; the user must ensure write permission. Default: './network'.
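The config description above implies that the two base attributes fall back to their documented defaults while each derived class may add its own keys. The sketch below illustrates that merge in plain Python. `resolve_base_config` and the key `my_algo_option` are hypothetical names used for illustration only; this is not how mindspore_gs necessarily parses its config.

```python
from typing import Any, Dict, Optional

# Documented defaults of the two base attributes of CompAlgo's config.
BASE_DEFAULTS: Dict[str, Any] = {"save_mindir": False, "save_mindir_path": "./network"}


def resolve_base_config(config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper (not part of mindspore_gs): overlay a user config on
    the base defaults. Keys that a derived class defines are passed through as-is."""
    resolved = dict(BASE_DEFAULTS)  # copy, so the defaults are never mutated
    if config:
        resolved.update(config)  # user values win; the caller's dict is untouched
    return resolved


print(resolve_base_config())
# {'save_mindir': False, 'save_mindir_path': './network'}
print(resolve_base_config({"save_mindir": True, "my_algo_option": 3}))
# {'save_mindir': True, 'save_mindir_path': './network', 'my_algo_option': 3}
```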
- abstract apply(network: Cell, **kwargs)
Define how to compress the input network. This method must be overridden by all subclasses.
- Parameters
network (Cell) – Network to be compressed.
kwargs (Dict) – Extensible parameters for subclasses.
- Returns
Compressed network.
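Because apply is abstract, a subclass that does not override it cannot be instantiated. The sketch below models only that contract with Python's `abc` module. `CompAlgoSketch` and `TagAlgo` are hypothetical stand-ins, not mindspore_gs classes, and a plain string stands in for a MindSpore `Cell`.

```python
from abc import ABC, abstractmethod


class CompAlgoSketch(ABC):
    """Hypothetical stand-in for CompAlgo; models only the apply() contract."""

    def __init__(self, config=None):
        self._config = dict(config) if config else {}

    @abstractmethod
    def apply(self, network, **kwargs):
        """Return the compressed network; every subclass must override this."""


class TagAlgo(CompAlgoSketch):
    """Toy subclass: 'compresses' a network by wrapping it with its options."""

    def apply(self, network, **kwargs):
        return {"compressed": network, "options": kwargs}


try:
    CompAlgoSketch()  # abstract: raises TypeError because apply is not overridden
except TypeError as err:
    print("cannot instantiate:", err)

print(TagAlgo().apply("lenet", bits=8))
# {'compressed': 'lenet', 'options': {'bits': 8}}
```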
- callbacks(*args, **kwargs)
Define the tasks to be performed during training. When a subclass overrides this method, the base implementation must be called at the end of the subclass's callbacks.
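One reading of "must be called at the end of child callbacks" is that an overriding subclass collects its own callbacks first and then appends whatever the base implementation returns. The sketch below shows that pattern in plain Python. Both classes and the string placeholders are hypothetical; in real use the list would hold MindSpore callback objects.

```python
class BaseAlgoSketch:
    """Hypothetical stand-in for CompAlgo.callbacks(); not the library code."""

    def callbacks(self, *args, **kwargs):
        # Base tasks (e.g. the MindIR export step), represented by a string here.
        return ["export_mindir_if_enabled"]


class MyAlgoSketch(BaseAlgoSketch):
    def callbacks(self, *args, **kwargs):
        cbs = ["update_quant_params"]  # hypothetical subclass-specific task
        # Call the base implementation last so its tasks come after ours.
        cbs.extend(super().callbacks(*args, **kwargs))
        return cbs


print(MyAlgoSketch().callbacks())
# ['update_quant_params', 'export_mindir_if_enabled']
```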
- convert(net_opt: Cell, ckpt_path='')
Define how to convert a compressed network to a standard network before exporting to MindIR.
- Parameters
net_opt (Cell) – Network to be converted which is transformed by CompAlgo.apply.
ckpt_path (str) – Path to the checkpoint file for net_opt. Default: "", which means no checkpoint file is loaded into net_opt. This parameter will be deprecated in a future version.
- Returns
An instance of Cell represents converted network.
Examples
>>> from mindspore_gs.quantization import SimulatedQuantizationAwareTraining as SimQAT
>>> ## 1) Define the network to be trained. LeNet is assumed to be defined by the user.
>>> network = LeNet(10)
>>> ## 2) Define a MindSpore Golden Stick algorithm; here we use SimQAT, a subclass of CompAlgo.
>>> algo = SimQAT()
>>> ## 3) Apply the MindSpore Golden Stick algorithm to the original network.
>>> network = algo.apply(network)
>>> ## 4) After training, convert the compressed network to a standard network.
>>> ##    There are two ways to do that.
>>> ## 4.1) Convert without a checkpoint.
>>> net_deploy = algo.convert(network)
>>> ## 4.2) Convert with a checkpoint (hypothetical path to a checkpoint saved during training).
>>> ckpt_path = "./lenet.ckpt"
>>> net_deploy = algo.convert(network, ckpt_path)
- loss(loss_fn: callable)
Define how to adjust the loss function for the algorithm. Subclasses do not need to override this method if the algorithm does not modify the loss function.
- Parameters
loss_fn (callable) – Original loss function.
- Returns
Adjusted loss function.
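The loss hook wraps the original loss function and returns a new callable. The sketch below assumes, as the description suggests, that the base implementation returns loss_fn unchanged, and shows a hypothetical subclass that adds an L1 penalty on a list of weights (the kind of term a sparsity-oriented algorithm might add). All class and function names are illustrative; plain Python floats stand in for MindSpore tensors.

```python
from typing import Callable, List, Sequence


class BaseLossSketch:
    """Hypothetical stand-in: assumes the base loss() returns loss_fn unchanged."""

    def loss(self, loss_fn: Callable) -> Callable:
        return loss_fn


class L1SparsityAlgo(BaseLossSketch):
    """Hypothetical subclass that adds lam * sum(|w|) over a list of weights."""

    def __init__(self, weights: List[float], lam: float):
        self.weights = weights
        self.lam = lam

    def loss(self, loss_fn: Callable) -> Callable:
        def adjusted(pred: Sequence[float], target: Sequence[float]) -> float:
            penalty = self.lam * sum(abs(w) for w in self.weights)
            return loss_fn(pred, target) + penalty
        return adjusted


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


base_loss = BaseLossSketch().loss(mse)
sparse_loss = L1SparsityAlgo(weights=[0.5, -1.5], lam=0.1).loss(mse)
print(base_loss([1.0, 2.0], [1.0, 4.0]))    # 2.0
print(sparse_loss([1.0, 2.0], [1.0, 4.0]))  # 2.0 plus a penalty of 0.1 * 2.0
```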
- set_save_mindir(save_mindir: bool)
Set whether to automatically export MindIR after training.
- Parameters
save_mindir (bool) – If True, export MindIR automatically after training; otherwise, do not.
- Raises
TypeError – If save_mindir is not a bool.
Examples
>>> import mindspore as ms
>>> from mindspore_gs.quantization import SimulatedQuantizationAwareTraining as SimQAT
>>> ## 1) Define the network to be trained. LeNet is assumed to be defined by the user.
>>> network = LeNet(10)
>>> ## 2) Define a MindSpore Golden Stick algorithm; here we use SimQAT, a subclass of CompAlgo.
>>> algo = SimQAT()
>>> ## 3) Enable automatic MindIR export after training.
>>> algo.set_save_mindir(save_mindir=True)
>>> ## 4) Set the MindIR output path.
>>> algo.set_save_mindir_path(save_mindir_path="./lenet")
>>> ## 5) Apply the MindSpore Golden Stick algorithm to the original network.
>>> network = algo.apply(network)
>>> ## 6) Set up the Model. create_custom_dataset is assumed to be defined by the user.
>>> train_dataset = create_custom_dataset()
>>> net_loss = ms.nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
>>> net_opt = ms.nn.Momentum(network.trainable_params(), 0.01, 0.9)
>>> model = ms.Model(network, net_loss, net_opt, metrics={"Accuracy": ms.train.Accuracy()})
>>> ## 7) Pass the algorithm callbacks to model.train and start training; MindIR is exported afterwards.
>>> model.train(1, train_dataset, callbacks=algo.callbacks())
- set_save_mindir_path(save_mindir_path: str)
Set the path to export MindIR. This only takes effect if save_mindir is True.
- Parameters
save_mindir_path (str) – The path to export MindIR, including the directory and file name. It can be a relative or an absolute path; the user must ensure write permission.
- Raises
ValueError – If save_mindir_path is not a non-empty str.
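The two setters above document a TypeError for a non-bool flag, a ValueError for an empty or non-str path, and a path that only matters when the flag is True. The plain-Python sketch below mirrors those documented checks. `SaveMindirSettings` and `export_path_if_enabled` are hypothetical names, not part of mindspore_gs, and the real validation code may differ.

```python
from typing import Optional


class SaveMindirSettings:
    """Hypothetical stand-in mirroring the documented checks of
    set_save_mindir and set_save_mindir_path; not the library code."""

    def __init__(self) -> None:
        self.save_mindir = False              # documented default
        self.save_mindir_path = "./network"   # documented default

    def set_save_mindir(self, save_mindir: bool) -> None:
        # Documented: TypeError if the flag is not a bool (1 or "yes" are rejected).
        if not isinstance(save_mindir, bool):
            raise TypeError(f"save_mindir must be bool, got {type(save_mindir).__name__}")
        self.save_mindir = save_mindir

    def set_save_mindir_path(self, save_mindir_path: str) -> None:
        # Documented: ValueError if the path is not a non-empty str.
        if not isinstance(save_mindir_path, str) or not save_mindir_path:
            raise ValueError("save_mindir_path must be a non-empty str")
        self.save_mindir_path = save_mindir_path

    def export_path_if_enabled(self) -> Optional[str]:
        # The path only takes effect when save_mindir is True.
        return self.save_mindir_path if self.save_mindir else None


settings = SaveMindirSettings()
settings.set_save_mindir_path("./lenet")
print(settings.export_path_if_enabled())  # None
settings.set_save_mindir(True)
print(settings.export_path_if_enabled())  # ./lenet
```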