mindspore.train.BackupAndRestore
- class mindspore.train.BackupAndRestore(backup_dir, save_freq='epoch', delete_checkpoint=True)
Callback to back up and restore the parameters during training.
Note
This callback can only be used during training.
- Parameters
backup_dir (str) – Path to store and load the checkpoint file.
save_freq (Union['epoch', int]) – When set to ‘epoch’, the callback saves the checkpoint at the end of each epoch. When set to an integer, the callback saves the checkpoint every save_freq epochs. Default: "epoch".
delete_checkpoint (bool) – If delete_checkpoint=True, the checkpoint will be deleted after training is finished. Default: True.
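To make the two save_freq modes concrete, the sketch below lists which epochs would end with a backup. It is a framework-free illustration, not MindSpore code: the helper name backup_epochs is hypothetical, and it assumes epochs are counted from 1 and that an integer save_freq triggers a save when the epoch number is a multiple of it.

```python
from typing import List, Union


def backup_epochs(num_epochs: int, save_freq: Union[str, int] = "epoch") -> List[int]:
    """Return the 1-based epochs at whose end a backup would be written.

    Hypothetical helper (assumption, not the MindSpore implementation):
    "epoch" means every epoch; a positive int n means every epoch whose
    number is a multiple of n.
    """
    step = 1 if save_freq == "epoch" else save_freq
    return [epoch for epoch in range(1, num_epochs + 1) if epoch % step == 0]


print(backup_epochs(4))       # [1, 2, 3, 4]
print(backup_epochs(10, 3))   # [3, 6, 9]
```

Under this assumption, with save_freq=3 a run interrupted during epoch 8 would resume from the state saved after epoch 6, so a larger save_freq trades less checkpoint I/O for more repeated work after a failure.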
- Raises
ValueError – If backup_dir is not str.
ValueError – If save_freq is not ‘epoch’ or int.
ValueError – If delete_checkpoint is not bool.
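The three Raises conditions can be read as a small argument check. The sketch below mirrors only what this section lists; validate_backup_args is a hypothetical name, and the explicit rejection of bool for save_freq is an assumption of the sketch (in Python, bool is a subclass of int), not documented MindSpore behavior.

```python
def validate_backup_args(backup_dir, save_freq="epoch", delete_checkpoint=True):
    """Hypothetical argument check mirroring the documented Raises section."""
    if not isinstance(backup_dir, str):
        raise ValueError(f"backup_dir must be str, got {type(backup_dir).__name__}")
    # Assumption: treat True/False as invalid for save_freq even though bool subclasses int.
    is_int = isinstance(save_freq, int) and not isinstance(save_freq, bool)
    if save_freq != "epoch" and not is_int:
        raise ValueError(f"save_freq must be 'epoch' or int, got {save_freq!r}")
    if not isinstance(delete_checkpoint, bool):
        raise ValueError(f"delete_checkpoint must be bool, got {delete_checkpoint!r}")


validate_backup_args("backup", save_freq=2)  # valid arguments: returns None silently
try:
    validate_backup_args("backup", save_freq="step")
except ValueError as err:
    print(err)  # save_freq must be 'epoch' or int, got 'step'
```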
Examples
>>> from mindspore import nn, Tensor
>>> from mindspore.train import Model, BackupAndRestore, RunContext
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> backup_ckpt = BackupAndRestore("backup")
>>> cb_params = {}
>>> cb_params["cur_epoch_num"] = 4
>>> cb_params["epoch_num"] = 4
>>> cb_params["cur_step_num"] = 2
>>> cb_params["batch_num"] = 2
>>> cb_params["net_outputs"] = Tensor(2.0)
>>> run_context = RunContext(cb_params)
>>> backup_ckpt.on_train_begin(run_context)
>>> backup_ckpt.on_train_epoch_end(run_context)
>>> backup_ckpt.on_train_end(run_context)
>>> model.train(10, dataset, callbacks=backup_ckpt)
- on_train_begin(run_context)
Load the backup checkpoint file at the beginning of training.
- Parameters
run_context (RunContext) – Context of the running process. For more details, please refer to mindspore.train.RunContext.
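The restore step amounts to "if a backup exists, overwrite the current parameters with it; otherwise start fresh". The sketch below shows that idea without MindSpore: a plain dict stands in for the network parameters and pickle stands in for the checkpoint format. The function restore_if_present and the file name backup.ckpt are hypothetical, and only files you wrote yourself should be loaded with pickle.

```python
import os
import pickle
import tempfile
from typing import Dict


def restore_if_present(backup_file: str, params: Dict[str, float]) -> bool:
    """Hypothetical stand-in for on_train_begin.

    Overwrites params with the backed-up values if backup_file exists and
    returns True; returns False (leaving params unchanged) otherwise.
    """
    if not os.path.exists(backup_file):
        return False  # fresh run: keep the initial parameters
    with open(backup_file, "rb") as f:
        saved = pickle.load(f)
    params.update(saved)  # resumed run: replace current values with the backup
    return True


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "backup.ckpt")
    params = {"w": 0.0}
    print(restore_if_present(path, params))          # False
    with open(path, "wb") as f:
        pickle.dump({"w": 1.5}, f)
    print(restore_if_present(path, params), params)  # True {'w': 1.5}
```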
- on_train_end(run_context)
Delete the backup checkpoint file at the end of training if delete_checkpoint is True.
- Parameters
run_context (RunContext) – Context of the running process. For more details, please refer to mindspore.train.RunContext.
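The end-of-training cleanup is gated by delete_checkpoint. A minimal framework-free sketch of that rule follows; cleanup_backup and the file name are hypothetical, and the function only models the documented "delete after training if delete_checkpoint=True" behavior.

```python
import os
import tempfile


def cleanup_backup(backup_file: str, delete_checkpoint: bool = True) -> bool:
    """Hypothetical stand-in for on_train_end: remove the backup after training.

    Returns True only if a file was actually deleted.
    """
    if delete_checkpoint and os.path.exists(backup_file):
        os.remove(backup_file)
        return True
    return False  # deletion disabled, or there is nothing to delete


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "backup.ckpt")
    open(path, "wb").close()  # create an empty placeholder backup
    print(cleanup_backup(path, delete_checkpoint=False), os.path.exists(path))  # False True
    print(cleanup_backup(path), os.path.exists(path))                           # True False
```

With delete_checkpoint=False the backup stays on disk after training, which matches the False case of that parameter.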
- on_train_epoch_end(run_context)
Back up the checkpoint file at the end of a training epoch, according to save_freq.
- Parameters
run_context (RunContext) – Context of the running process. For more details, please refer to mindspore.train.RunContext.
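The epoch-end backup combines the save_freq check with a write of the current parameters. The sketch below is a framework-free illustration under stated assumptions: a dict stands in for the parameters, pickle for the checkpoint format, and backup_on_epoch_end is a hypothetical name. Writing to a temporary file and then calling os.replace is a design choice of this sketch (so an interrupted write never leaves a truncated backup), not a claim about how MindSpore writes its checkpoint.

```python
import os
import pickle
import tempfile
from typing import Dict, Union


def backup_on_epoch_end(backup_file: str, params: Dict[str, float], cur_epoch: int,
                        save_freq: Union[str, int] = "epoch") -> bool:
    """Hypothetical stand-in for on_train_epoch_end.

    Writes params to backup_file when the epoch is due for a backup and
    returns True; otherwise leaves the file untouched and returns False.
    """
    if save_freq != "epoch" and cur_epoch % save_freq != 0:
        return False  # not a backup epoch under the integer save_freq
    # Write to a temp file in the same directory, then atomically swap it in.
    directory = os.path.dirname(os.path.abspath(backup_file))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump(dict(params), f)
    os.replace(tmp_path, backup_file)
    return True


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "backup.ckpt")
    print(backup_on_epoch_end(path, {"w": 0.5}, cur_epoch=1, save_freq=2))  # False
    print(backup_on_epoch_end(path, {"w": 0.7}, cur_epoch=2, save_freq=2))  # True
    with open(path, "rb") as f:
        print(pickle.load(f))  # {'w': 0.7}
```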