mindspore.train.ModelCheckpoint
- class mindspore.train.ModelCheckpoint(prefix='CKP', directory=None, config=None)[源代码]
checkpoint的回调函数。
在训练过程中调用该方法可以保存网络参数。
说明
在分布式训练场景下,请为每个训练进程指定不同的目录来保存checkpoint文件。否则,可能会训练失败。 如何在 model 方法中使用此回调函数,默认将会把优化器中的参数保存到checkpoint文件中。
- 参数:
prefix (Union[str, callable object]) - checkpoint文件的前缀名称,或者用来生成名称的可调用对象。默认值:
'CKP'
。directory (Union[str, callable object]) - 保存checkpoint文件的文件夹路径,或者用来生成路径的可调用对象。默认情况下,文件保存在当前目录下。默认值:
None
。config (CheckpointConfig) - checkpoint策略配置。默认值:
None
。
- 异常:
ValueError - 如果prefix参数不是str类型或包含’/’字符,且不是可调用对象。
ValueError - 如果directory参数不是str类型,且不是可调用对象。
TypeError - config不是CheckpointConfig类型。
样例:
>>> import numpy as np >>> import mindspore.dataset as ds >>> from mindspore import nn >>> from mindspore.train import Model, ModelCheckpoint >>> >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))} >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32) >>> net = nn.Dense(10, 5) >>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') >>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9) >>> ckpt_callback = ModelCheckpoint(prefix="myckpt") >>> model = Model(network=net, optimizer=opt, loss_fn=crit) >>> model.train(2, train_dataset, callbacks=[ckpt_callback])
- end(run_context)[源代码]
在训练结束后,会保存最后一个step的checkpoint。
- 参数:
run_context (RunContext) - 包含模型的一些基本信息。详情请参考
mindspore.train.RunContext
。
- property latest_ckpt_file_name
返回最新的checkpoint路径和文件名。
- step_end(run_context)[源代码]
在step结束时保存checkpoint。
- 参数:
run_context (RunContext) - 包含模型的一些基本信息。详情请参考
mindspore.train.RunContext
。