mindspore.train.CheckpointConfig

class mindspore.train.CheckpointConfig(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM', exception_save=False)[source]

The configuration of model checkpoint.

Note

During the training process, if the dataset is transmitted through the data channel, it is recommended to set save_checkpoint_steps to an integer multiple of loop_size; otherwise, the time at which the checkpoint is saved may be biased. It is recommended to set only one save strategy and one keep strategy at a time. If both save_checkpoint_steps and save_checkpoint_seconds are set, save_checkpoint_seconds is invalid. If both keep_checkpoint_max and keep_checkpoint_per_n_minutes are set, keep_checkpoint_per_n_minutes is invalid.

Parameters
  • save_checkpoint_steps (int) – Steps to save checkpoint. Default: 1.

  • save_checkpoint_seconds (int) – Seconds to save checkpoint. Can’t be used with save_checkpoint_steps at the same time. Default: 0.

  • keep_checkpoint_max (int) – Maximum number of checkpoint files that can be saved. Default: 5.

  • keep_checkpoint_per_n_minutes (int) – Keep one checkpoint file every keep_checkpoint_per_n_minutes minutes. Cannot be used with keep_checkpoint_max at the same time. Default: 0.

  • integrated_save (bool) – Whether to merge and save the split Tensor in the automatic parallel scenario. The integrated save function is supported only in automatic parallel scenarios and is not supported in manual parallel scenarios. Default: True.

  • async_save (bool) – Whether to save the checkpoint file asynchronously. Default: False.

  • saved_network (Cell) – Network to be saved in the checkpoint file. If saved_network has no relation to the network being trained, the initial values of saved_network are saved. Default: None.

  • append_info (list) – The information to be saved in the checkpoint file. Supports "epoch_num", "step_num", and dict. The keys of the dict must be str, and the values must be one of int, float, bool, Parameter or Tensor. Default: None.

  • enc_key (Union[None, bytes]) – Byte type key used for encryption. If the value is None, the encryption is not required. Default: None.

  • enc_mode (str) – This parameter is valid only when enc_key is not set to None. Specifies the encryption mode, currently supports ‘AES-GCM’, ‘AES-CBC’ and ‘SM4-CBC’. Default: ‘AES-GCM’.

  • exception_save (bool) – Whether to save the current checkpoint when an exception occurs. Default: False.

Raises

ValueError – If the input parameter is not of the correct type.

Examples

Note

Before running the following example, you need to customize the network LeNet5 and the dataset preparation function create_dataset. Refer to Building a Network and Dataset.

>>> from mindspore import nn
>>> from mindspore.common.initializer import Normal
>>> from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
>>>
>>> class LeNet5(nn.Cell):
...     def __init__(self, num_class=10, num_channel=1):
...         super(LeNet5, self).__init__()
...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
...         self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
...         self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
...         self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
...         self.relu = nn.ReLU()
...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...         self.flatten = nn.Flatten()
...
...     def construct(self, x):
...         x = self.max_pool2d(self.relu(self.conv1(x)))
...         x = self.max_pool2d(self.relu(self.conv2(x)))
...         x = self.flatten(x)
...         x = self.relu(self.fc1(x))
...         x = self.relu(self.fc2(x))
...         x = self.fc3(x)
...         return x
>>>
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> data_path = './MNIST_Data'
>>> dataset = create_dataset(data_path)
>>> config = CheckpointConfig(saved_network=net)
>>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
>>> model.train(10, dataset, callbacks=ckpoint_cb)
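The following snippet is a minimal additional sketch of the save/keep strategies and the append_info parameter described above. It reuses the net, dataset, and model objects defined in the previous example; the appended key name "learning_rate" and the concrete values are arbitrary illustrations, not required names.

>>> # Hypothetical configuration: save every 100 steps, keep at most 10 files,
>>> # write files asynchronously, and append the built-in counters plus a
>>> # custom scalar to every checkpoint file.
>>> config = CheckpointConfig(save_checkpoint_steps=100,
...                           keep_checkpoint_max=10,
...                           async_save=True,
...                           append_info=["epoch_num", "step_num", {"learning_rate": 0.01}])
>>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
>>> model.train(10, dataset, callbacks=ckpoint_cb)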
property append_dict

Get the dict of information saved to the checkpoint file.

Returns

Dict, the information saved to the checkpoint file.

property async_save

Get whether the checkpoint file is saved asynchronously.

Returns

Bool, whether the checkpoint file is saved asynchronously.

property enc_key

Get the byte type key used for encryption.

Returns

Union[None, bytes], the byte type key used for encryption.

property enc_mode

Get the value of the encryption mode.

Returns

Str, the encryption mode.

get_checkpoint_policy()[source]

Get the checkpoint saving policy.

Returns

Dict, the information of the checkpoint policy.
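A minimal sketch of querying the policy is shown below; the exact keys of the returned dict are not documented here, so the print call is illustrative only.

>>> config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10)
>>> policy = config.get_checkpoint_policy()
>>> print(policy)  # a dict summarizing the configured save and keep strategies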

property integrated_save

Get the value of whether to merge and save the split Tensor in the automatic parallel scenario.

Returns

Bool, whether to merge and save the split Tensor in the automatic parallel scenario.

property keep_checkpoint_max

Get the maximum number of checkpoint files that can be saved.

Returns

Int, the maximum number of checkpoint files that can be saved.

property keep_checkpoint_per_n_minutes

Get the interval, in minutes, for keeping checkpoint files.

Returns

Int, the interval in minutes for keeping checkpoint files.

property save_checkpoint_seconds

Get the value of seconds to save the checkpoint file.

Returns

Int, seconds to save the checkpoint file.

property save_checkpoint_steps

Get the value of steps to save checkpoint.

Returns

Int, steps to save checkpoint.

property saved_network

Get the value of network to be saved in checkpoint file.

Returns

Cell, network to be saved in checkpoint file.
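A short sketch of reading the configuration back through the properties above; each property simply echoes the value passed to the constructor.

>>> config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10, async_save=True)
>>> config.save_checkpoint_steps
100
>>> config.keep_checkpoint_max
10
>>> config.async_save
True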