mindspore.train.BackupAndRestore

class mindspore.train.BackupAndRestore(backup_dir, save_freq='epoch', delete_checkpoint=True)

Callback to back up and restore the parameters during training.

Note

This callback can only be used during training.

Parameters
  • backup_dir (str) – Path to store and load the checkpoint file.

  • save_freq (Union['epoch', int]) – When set to 'epoch', the callback saves the checkpoint at the end of every epoch. When set to an integer, the callback saves the checkpoint once every save_freq epochs. Default: "epoch".

  • delete_checkpoint (bool) – If True, the backup checkpoint is deleted after training finishes. Default: True.
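
To make the effect of save_freq concrete, the following framework-free sketch lists the epochs after which a backup would be written. It assumes that an integer save_freq triggers a save whenever the 1-based epoch number is a multiple of save_freq; the helper name backup_epochs is hypothetical and is not part of MindSpore.

```python
def backup_epochs(save_freq, total_epochs):
    """Return the 1-based epoch numbers after which a backup would be saved.

    Assumption (not taken from the MindSpore source): an integer save_freq
    triggers a save whenever the epoch number is a multiple of save_freq.
    """
    if save_freq == "epoch":
        interval = 1
    elif isinstance(save_freq, int) and not isinstance(save_freq, bool) and save_freq > 0:
        interval = save_freq
    else:
        raise ValueError("save_freq must be 'epoch' or a positive int")
    return [epoch for epoch in range(1, total_epochs + 1) if epoch % interval == 0]


print(backup_epochs("epoch", 4))  # [1, 2, 3, 4]
print(backup_epochs(3, 10))       # [3, 6, 9]
```

Under this assumption, a save_freq larger than the number of training epochs would never write a backup, so it is worth keeping the interval small relative to the total epoch count.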

Raises

Examples

>>> from mindspore import nn, Tensor
>>> from mindspore.train import Model, BackupAndRestore, RunContext
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> backup_ckpt = BackupAndRestore("backup")
>>> cb_params = {}
>>> cb_params["cur_epoch_num"] = 4
>>> cb_params["epoch_num"] = 4
>>> cb_params["cur_step_num"] = 2
>>> cb_params["batch_num"] = 2
>>> cb_params["net_outputs"] = Tensor(2.0)
>>> run_context = RunContext(cb_params)
>>> backup_ckpt.on_train_begin(run_context)
>>> backup_ckpt.on_train_epoch_end(run_context)
>>> backup_ckpt.on_train_end(run_context)
>>> model.train(10, dataset, callbacks=backup_ckpt)

on_train_begin(run_context)

Load the backup checkpoint file, if one exists, at the beginning of training.

Parameters

run_context (RunContext) – Context information of the running training process. For more details, refer to mindspore.train.RunContext.

on_train_end(run_context)

Delete the backup checkpoint file at the end of training if delete_checkpoint is True.

Parameters

run_context (RunContext) – Context information of the running training process. For more details, refer to mindspore.train.RunContext.

on_train_epoch_end(run_context)

Back up the checkpoint file at the end of a training epoch, according to save_freq.

Parameters

run_context (RunContext) – Context information of the running training process. For more details, refer to mindspore.train.RunContext.
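
The three hooks together form a backup-and-restore cycle: restore at the start of training, save at epoch boundaries, and clean up at the end. The sketch below imitates that cycle in plain Python so it can be run without MindSpore. It is an illustration under stated assumptions, not the library's implementation: ToyBackupAndRestore, the train helper, the JSON file name, and the simplified hook signatures (plain arguments instead of a RunContext) are all hypothetical, and only the parameters are restored, not the epoch counter.

```python
import json
import os
import tempfile


class ToyBackupAndRestore:
    """Framework-free stand-in for the three hooks (hypothetical, not MindSpore)."""

    def __init__(self, backup_dir, save_freq="epoch", delete_checkpoint=True):
        self.backup_file = os.path.join(backup_dir, "backup.json")  # hypothetical name
        self.save_freq = save_freq
        self.delete_checkpoint = delete_checkpoint

    def on_train_begin(self, params):
        # Restore parameters if an earlier run left a backup behind.
        if os.path.exists(self.backup_file):
            with open(self.backup_file) as f:
                params.update(json.load(f))

    def on_train_epoch_end(self, params, cur_epoch_num):
        # Save after every epoch, or after every save_freq-th epoch.
        if self.save_freq == "epoch" or cur_epoch_num % self.save_freq == 0:
            with open(self.backup_file, "w") as f:
                json.dump(params, f)

    def on_train_end(self, params):
        # Remove the backup once training has finished, if requested.
        if self.delete_checkpoint and os.path.exists(self.backup_file):
            os.remove(self.backup_file)


def train(callback, epochs, crash_after=None):
    """Toy training loop: 'weight' simply counts the epochs it has seen."""
    params = {"weight": 0}
    callback.on_train_begin(params)
    for epoch in range(1, epochs + 1):
        params["weight"] += 1  # stand-in for one epoch of optimisation
        callback.on_train_epoch_end(params, epoch)
        if crash_after == epoch:
            return params  # simulate an interruption: on_train_end never runs
    callback.on_train_end(params)
    return params


backup_dir = tempfile.mkdtemp()
cb = ToyBackupAndRestore(backup_dir)
first = train(cb, epochs=5, crash_after=2)  # interrupted after 2 of 5 epochs
second = train(cb, epochs=3)                # restores the weights, runs the remaining 3
print(first, second)
```

In this toy model the second run starts from the restored weight rather than from zero, and the backup file is removed once that run completes because delete_checkpoint defaults to True. Passing delete_checkpoint=False would keep the file for a later run.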