Saving and Loading Model Parameters
Overview
During model training, you can add CheckPoints to save model parameters for inference and for retraining after an interruption.
The application scenarios are as follows:
- Inference after training.
  - After model training, save the model parameters for inference or prediction.
  - During training, validate the accuracy in real time and save the model parameters with the highest accuracy for prediction.
- Retraining.
  - During a long training run, save the generated CheckPoint files to prevent the training from starting over after it exits abnormally.
  - Train a model, save its parameters, and perform fine-tuning for different tasks.
A CheckPoint file of MindSpore is a binary file that stores the values of all training parameters. It adopts the Google Protocol Buffers mechanism, which has good scalability and is independent of the development language and platform. The protocol format of CheckPoints is defined in mindspore/ccsrc/utils/checkpoint.proto.
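Because a CheckPoint file is a serialized collection of named parameters, its contents can be inspected after loading. The following is a minimal sketch that prints every parameter name and shape; it assumes a saved file such as resnet50-2_32.ckpt exists and that load_checkpoint is imported from the mindspore.train.serialization module:
from mindspore.train.serialization import load_checkpoint

# load_checkpoint returns a dictionary that maps parameter names to Parameter objects
param_dict = load_checkpoint("resnet50-2_32.ckpt")
for name, param in param_dict.items():
    print(name, param.data.asnumpy().shape)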
The following example describes the saving and loading functions of MindSpore, using the ResNet-50 network and the MNIST dataset.
Saving Model Parameters
During model training, use the callback mechanism to pass a ModelCheckpoint callback object to save model parameters and generate CheckPoint files. You can use the CheckpointConfig object to set the CheckPoint saving policies.
The saved parameters are classified into network parameters and optimizer parameters.
ModelCheckpoint() provides default configuration policies for users to quickly get started. The following describes the usage:
from mindspore.train.callback import ModelCheckpoint

# save CheckPoint files during training with the default policy
ckpoint_cb = ModelCheckpoint()
model.train(epoch_num, dataset, callbacks=ckpoint_cb)
You can configure the CheckPoint policies as required. The following describes the usage:
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# save parameters every 32 steps and keep at most 10 CheckPoint files
config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='resnet50', directory=None, config=config_ck)
model.train(epoch_num, dataset, callbacks=ckpoint_cb)
In the preceding code, initialize a CheckpointConfig class object to set the saving policies. save_checkpoint_steps indicates the saving frequency, that is, parameters are saved every specified number of steps. keep_checkpoint_max indicates the maximum number of CheckPoint files that can be saved. prefix indicates the prefix name of the generated CheckPoint files, and directory indicates the directory for storing them.
Create a ModelCheckpoint object and pass it to the model.train method. Then you can use the CheckPoint function during training.
The generated CheckPoint files are as follows:
resnet50-graph.meta # Compiled computation graph.
resnet50-1_32.ckpt # The file name extension is .ckpt.
resnet50-2_32.ckpt # The file name contains the epoch and step corresponding to the saved parameters.
resnet50-3_32.ckpt # This file saves the model parameters generated during the 32nd step of the third epoch.
…
If you use the same prefix and run the training script multiple times, CheckPoint files with the same name may be generated. To distinguish them, MindSpore appends an underscore (_) and a digit to the user-defined prefix. For example, resnet50_3-2_32.ckpt indicates the CheckPoint file generated during the 32nd step of the second epoch after the script is executed for the third time.
CheckPoint Configuration Policies
MindSpore provides two types of CheckPoint saving policies: iteration policy and time policy. You can create a CheckpointConfig object to set the corresponding policies.
CheckpointConfig contains the following four parameters:
- save_checkpoint_steps: the step interval for saving a CheckPoint file, that is, parameters are saved every specified number of steps. The default value is 1.
- save_checkpoint_seconds: the time interval for saving a CheckPoint file, that is, parameters are saved every specified number of seconds. The default value is 0.
- keep_checkpoint_max: the maximum number of CheckPoint files that can be saved. The default value is 5.
- keep_checkpoint_per_n_minutes: the interval for retaining CheckPoint files, that is, one CheckPoint file is kept every specified number of minutes. The default value is 0.
save_checkpoint_steps and keep_checkpoint_max are iteration policies, which are configured based on the number of training iterations. save_checkpoint_seconds and keep_checkpoint_per_n_minutes are time policies, which are configured based on the training duration.
The two types of policies cannot be used together. Iteration policies have a higher priority than time policies: when both are configured, only the iteration policies take effect. If a parameter is set to None, the related policy is canceled. After the training script is executed normally, the CheckPoint file generated during the last step is saved by default.
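For example, to use a pure time policy, cancel the iteration policy by setting its parameters to None. The following is a minimal sketch; the interval values are illustrative:
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# cancel the iteration policy so that the time policy takes effect
config_ck = CheckpointConfig(save_checkpoint_steps=None,
                             keep_checkpoint_max=None,
                             save_checkpoint_seconds=60,
                             keep_checkpoint_per_n_minutes=10)
ckpoint_cb = ModelCheckpoint(prefix='resnet50', config=config_ck)
model.train(epoch_num, dataset, callbacks=ckpoint_cb)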
Loading Model Parameters
After saving CheckPoint files, you can load parameters.
For Inference Validation
In inference-only scenarios, use load_checkpoint to directly load parameters into the network for subsequent inference validation.
The sample code is as follows:
resnet = ResNet50()
# load the saved parameters directly into the network
load_checkpoint("resnet50-2_32.ckpt", net=resnet)
dataset_eval = create_dataset(os.path.join(mnist_path, "test"), 32, 1) # define the test dataset
loss = CrossEntropyLoss()
model = Model(resnet, loss, metrics={"accuracy"}) # metrics are needed by model.eval
acc = model.eval(dataset_eval)
The load_checkpoint method loads the network parameters in the parameter file into the model. After the loading, the parameters in the network are those saved in the CheckPoint file. The eval method validates the accuracy of the trained model.
For Retraining
In scenarios such as retraining after a task interruption and fine-tuning, you can load both the network parameters and the optimizer parameters into the model.
The sample code is as follows:
# return a parameter dict for the model
param_dict = load_checkpoint("resnet50-2_32.ckpt")
resnet = ResNet50()
# the Momentum optimizer needs the parameters to update; the hyperparameter values here are illustrative
opt = Momentum(resnet.trainable_params(), learning_rate=0.01, momentum=0.9)
# load the parameters into the network
load_param_into_net(resnet, param_dict)
# load the parameters into the optimizer
load_param_into_net(opt, param_dict)
loss = SoftmaxCrossEntropyWithLogits()
model = Model(resnet, loss, opt)
model.train(epoch, dataset)
The load_checkpoint method returns a parameter dictionary, and the load_param_into_net method then loads the parameters in the dictionary into the network or optimizer.
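In the fine-tuning scenario, the parameter dictionary can also be filtered before it is loaded, for example to skip the classification head when the new task has a different number of classes. The following is a hypothetical sketch; 'end_point' is an assumed name for the final layer, not a name defined by this tutorial:
param_dict = load_checkpoint("resnet50-2_32.ckpt")
# keep every parameter except those of the (hypothetical) final layer named 'end_point'
filtered_dict = {k: v for k, v in param_dict.items() if not k.startswith("end_point")}
load_param_into_net(resnet, filtered_dict)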