mindspore.train.BackupAndRestore
- class mindspore.train.BackupAndRestore(backup_dir, save_freq='epoch', delete_checkpoint=True)[源代码]
在训练过程中备份和恢复训练参数的回调函数。
说明
只能在训练过程使用这个方法。
- 参数:
backup_dir (str) - 保存和恢复checkpoint文件的路径。
save_freq (Union[‘epoch’, int]) - 当设置为’epoch’时,在每个epoch进行备份,当设置为整数时,将在每隔 save_freq 个epoch进行备份。默认值:
"epoch"
。delete_checkpoint (bool) - 如果 delete_checkpoint=True ,将在训练结束的时候删除备份文件,否则保留备份文件。默认值:
True
。
- 异常:
ValueError - 如果 backup_dir 参数不是str类型。
ValueError - 如果 save_freq 参数不是’epoch’或str类型。
ValueError - 如果 delete_checkpoint 参数不是bool类型。
样例:
>>> from mindspore import nn >>> from mindspore.train import Model, BackupAndRestore, RunContext >>> >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9) >>> model = Model(net, loss_fn=loss, optimizer=optim) >>> # Create the dataset taking MNIST as an example. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/mnist.py >>> dataset = create_dataset() >>> backup_ckpt = BackupAndRestore("backup") >>> cb_params = {} >>> cb_params["cur_epoch_num"] = 4 >>> cb_params["epoch_num"] = 4 >>> cb_params["cur_step_num"] = 2 >>> cb_params["batch_num"] = 2 >>> cb_params["net_outputs"] = Tensor(2.0) >>> run_context = RunContext(cb_params) >>> backup_ckpt.on_train_begin(run_context) >>> backup_ckpt.on_train_epoch_end(run_context) >>> backup_ckpt.on_train_end(run_context) >>> model.train(10, dataset, callbacks=backup_ckpt)
- on_train_begin(run_context)[源代码]
在训练开始时,加载备份的checkpoint文件。
- 参数:
run_context (RunContext) - 包含模型的一些基本信息。详情请参考
mindspore.train.RunContext
。
- on_train_end(run_context)[源代码]
在训练结束时,判断是否删除备份的checkpoint文件。
- 参数:
run_context (RunContext) - 包含模型的一些基本信息。详情请参考
mindspore.train.RunContext
。
- on_train_epoch_end(run_context)[源代码]
在每个epoch结束时,判断是否需要备份checkpoint文件。
- 参数:
run_context (RunContext) - 包含模型的一些基本信息。详情请参考
mindspore.train.RunContext
。