mindspore.train.BackupAndRestore

查看源文件
class mindspore.train.BackupAndRestore(backup_dir, save_freq='epoch', delete_checkpoint=True)[源代码]

在训练过程中备份和恢复训练参数的回调函数。

说明

只能在训练过程使用这个方法。

参数:
  • backup_dir (str) - 保存和恢复checkpoint文件的路径。

  • save_freq (Union[“epoch”, int]) - 当设置为 "epoch" 时,在每个epoch进行备份,当设置为整数时,将在每隔 save_freq 个epoch进行备份。默认值: "epoch"

  • delete_checkpoint (bool) - 如果 delete_checkpoint=True ,将在训练结束的时候删除备份文件,否则保留备份文件。默认值: True

异常:
  • ValueError - 如果 backup_dir 参数不是str类型。

  • ValueError - 如果 save_freq 参数不是 "epoch" 或int类型。

  • ValueError - 如果 delete_checkpoint 参数不是bool类型。

样例:

>>> from mindspore import nn
>>> from mindspore.train import Model, BackupAndRestore, RunContext
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> backup_ckpt = BackupAndRestore("backup")
>>> model.train(10, dataset, callbacks=backup_ckpt)
on_train_begin(run_context)[源代码]

在训练开始时,加载备份的checkpoint文件。

参数:
on_train_end(run_context)[源代码]

在训练结束时,判断是否删除备份的checkpoint文件。

参数:
on_train_epoch_end(run_context)[源代码]

在每个epoch结束时,判断是否需要备份checkpoint文件。

参数: