mindspore.train.CheckpointConfig
- class mindspore.train.CheckpointConfig(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM', exception_save=False)
The configuration of model checkpoint.
Note
During training, if the dataset is transmitted through the data channel, it is recommended to set save_checkpoint_steps to an integer multiple of loop_size; otherwise, checkpoints may not be saved at the intended steps. It is also recommended to set only one save strategy and one keep strategy at a time. If both save_checkpoint_steps and save_checkpoint_seconds are set, save_checkpoint_seconds is ignored. If both keep_checkpoint_max and keep_checkpoint_per_n_minutes are set, keep_checkpoint_per_n_minutes is ignored.
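The recommendations in the Note can be restated in plain Python. The helpers align_to_loop_size and resolve_strategies below are hypothetical and do not call MindSpore; they are a sketch that only mirrors the documented rules (treating 0 as "not set", and using the CheckpointConfig defaults), not the library's implementation.

```python
def align_to_loop_size(save_steps: int, loop_size: int) -> int:
    """Hypothetical helper: round save_steps up to the nearest integer multiple of loop_size."""
    if save_steps <= 0 or loop_size <= 0:
        raise ValueError("save_steps and loop_size must be positive")
    # Ceiling division, then scale back up to a multiple of loop_size.
    return -(-save_steps // loop_size) * loop_size


def resolve_strategies(save_checkpoint_steps=1, save_checkpoint_seconds=0,
                       keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0):
    """Hypothetical helper: apply the Note's precedence rules and mark ignored strategies as None."""
    policy = {
        "save_checkpoint_steps": save_checkpoint_steps,
        "save_checkpoint_seconds": save_checkpoint_seconds,
        "keep_checkpoint_max": keep_checkpoint_max,
        "keep_checkpoint_per_n_minutes": keep_checkpoint_per_n_minutes,
    }
    # A positive step interval takes precedence over the time-based save strategy.
    if save_checkpoint_steps and save_checkpoint_steps > 0:
        policy["save_checkpoint_seconds"] = None
    # A positive file limit takes precedence over the time-based keep strategy.
    if keep_checkpoint_max and keep_checkpoint_max > 0:
        policy["keep_checkpoint_per_n_minutes"] = None
    return policy


print(align_to_loop_size(100, 32))  # 128
print(resolve_strategies(save_checkpoint_steps=10, save_checkpoint_seconds=60))
```

With the defaults, both time-based strategies end up ignored, which is why only one save strategy and one keep strategy should be set deliberately.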
- Parameters
save_checkpoint_steps (int) – Interval, in steps, at which to save a checkpoint. Default: 1.
save_checkpoint_seconds (int) – Interval, in seconds, at which to save a checkpoint. Cannot be used together with save_checkpoint_steps. Default: 0.
keep_checkpoint_max (int) – Maximum number of checkpoint files to keep. Default: 5.
keep_checkpoint_per_n_minutes (int) – Keep one checkpoint file every keep_checkpoint_per_n_minutes minutes. Cannot be used together with keep_checkpoint_max. Default: 0.
integrated_save (bool) – Whether to merge and save the split Tensors in the automatic parallel scenario. Integrated save is supported only in automatic parallel scenarios, not in manual parallel. Default: True.
async_save (bool) – Whether to save the checkpoint to a file asynchronously. Default: False.
saved_network (Cell) – Network to be saved in the checkpoint file. If saved_network is unrelated to the network being trained, its initial values are saved. Default: None.
append_info (list) – Additional information to save to the checkpoint file. Supports “epoch_num”, “step_num”, and dict. Dict keys must be str, and dict values must be one of int, float, bool, Parameter, or Tensor. Default: None.
enc_key (Union[None, bytes]) – Byte-type key used for encryption. If None, the checkpoint is not encrypted. Default: None.
enc_mode (str) – Encryption mode; takes effect only when enc_key is not None. Currently supports ‘AES-GCM’, ‘AES-CBC’, and ‘SM4-CBC’. Default: ‘AES-GCM’.
exception_save (bool) – Whether to save the current checkpoint when an exception occurs. Default: False.
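The constraints on append_info can be expressed as a validation check. validate_append_info below is a hypothetical pure-Python helper, not part of MindSpore. To stay runnable without MindSpore installed, it checks dict values only against int, float, and bool and leaves out Parameter and Tensor; it raises ValueError to match the Raises section, although the library's own checks may differ.

```python
ALLOWED_NAMES = ("epoch_num", "step_num")
# Parameter and Tensor are deliberately omitted so the sketch runs without MindSpore.
ALLOWED_VALUE_TYPES = (int, float, bool)


def validate_append_info(append_info):
    """Hypothetical helper: check append_info against the documented rules and return it unchanged."""
    if append_info is None:
        return None
    if not isinstance(append_info, list):
        raise ValueError("append_info must be a list")
    for item in append_info:
        if isinstance(item, str):
            # Only the two documented names are accepted as plain strings.
            if item not in ALLOWED_NAMES:
                raise ValueError(f"unsupported name: {item!r}")
        elif isinstance(item, dict):
            for key, value in item.items():
                if not isinstance(key, str):
                    raise ValueError(f"dict key must be str, got {type(key).__name__}")
                if not isinstance(value, ALLOWED_VALUE_TYPES):
                    raise ValueError(f"unsupported value type for {key!r}: {type(value).__name__}")
        else:
            raise ValueError(f"unsupported item: {item!r}")
    return append_info


print(validate_append_info(["epoch_num", "step_num", {"lr": 0.01, "warmup": True}]))
```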
- Raises
ValueError – If an input parameter is not of the correct type.
Examples
Note
Before running the following example, you need to define the dataset preparation function create_dataset; the network LeNet5 is defined in the example itself. Refer to Building a Network and Dataset.
>>> from mindspore import nn
>>> from mindspore.common.initializer import Normal
>>> from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
>>>
>>> class LeNet5(nn.Cell):
...     def __init__(self, num_class=10, num_channel=1):
...         super(LeNet5, self).__init__()
...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
...         self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
...         self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
...         self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
...         self.relu = nn.ReLU()
...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
...         self.flatten = nn.Flatten()
...
...     def construct(self, x):
...         x = self.max_pool2d(self.relu(self.conv1(x)))
...         x = self.max_pool2d(self.relu(self.conv2(x)))
...         x = self.flatten(x)
...         x = self.relu(self.fc1(x))
...         x = self.relu(self.fc2(x))
...         x = self.fc3(x)
...         return x
>>>
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> data_path = './MNIST_Data'
>>> dataset = create_dataset(data_path)
>>> config = CheckpointConfig(saved_network=net)
>>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
>>> model.train(10, dataset, callbacks=ckpoint_cb)
- property append_dict
Get the dict of information saved to the checkpoint file.
- Returns
Dict, the information saved to checkpoint file.
- property async_save
Get whether the checkpoint is saved to a file asynchronously.
- Returns
Bool, whether the checkpoint is saved to a file asynchronously.
- property enc_key
Get the byte-type key used for encryption.
- Returns
Union[None, bytes], byte-type key used for encryption.
- property enc_mode
Get the value of the encryption mode.
- Returns
str, encryption mode.
- get_checkpoint_policy()
Get the checkpoint policy.
- Returns
Dict, the checkpoint policy settings.
- property integrated_save
Get whether to merge and save the split Tensors in the automatic parallel scenario.
- Returns
Bool, whether to merge and save the split Tensors in the automatic parallel scenario.
- property keep_checkpoint_max
Get the maximum number of checkpoint files that can be kept.
- Returns
Int, maximum number of checkpoint files that can be kept.
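A minimal sketch of a count-based keep strategy, assuming that once the limit is exceeded the oldest file is evicted first. The helper save_with_limit and the file names are hypothetical; in MindSpore the actual file deletion is handled by the checkpoint callback, not by user code like this.

```python
from collections import deque


def save_with_limit(kept, new_file, keep_checkpoint_max):
    """Hypothetical helper: record new_file and return the files evicted to respect the limit."""
    kept.append(new_file)
    removed = []
    # Evict the oldest entries first until at most keep_checkpoint_max remain.
    while len(kept) > keep_checkpoint_max:
        removed.append(kept.popleft())
    return removed


kept = deque()
for step in range(1, 8):
    save_with_limit(kept, f"ckpt_{step}.ckpt", keep_checkpoint_max=5)
print(list(kept))  # the five most recent files: ckpt_3.ckpt ... ckpt_7.ckpt
```

A deque is used because evicting from the front and appending at the back are both constant-time.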
- property keep_checkpoint_per_n_minutes
Get the interval n, in minutes, at which one checkpoint file is kept.
- Returns
Int, keep one checkpoint file every n minutes.
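The idea of a time-based keep strategy can be illustrated with a simplified, hypothetical model: group save times into n-minute windows and retain one file per window. thin_by_minutes below is not MindSpore's algorithm, and the library may differ in which file of each window it keeps.

```python
def thin_by_minutes(save_times_s, n_minutes):
    """Hypothetical model: keep the first save in each n-minute window.

    save_times_s holds save times in seconds since training started.
    """
    window = 60 * n_minutes
    kept, seen_windows = [], set()
    for t in sorted(save_times_s):
        bucket = int(t // window)  # index of the n-minute window this save falls in
        if bucket not in seen_windows:
            seen_windows.add(bucket)
            kept.append(t)
    return kept


# Saves every 90 s for 10 minutes, thinned to one per 5-minute window.
print(thin_by_minutes(range(0, 601, 90), n_minutes=5))  # [0, 360]
```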
- property save_checkpoint_seconds
Get the interval, in seconds, at which checkpoints are saved.
- Returns
Int, interval, in seconds, at which checkpoints are saved.
- property save_checkpoint_steps
Get the interval, in steps, at which checkpoints are saved.
- Returns
Int, interval, in steps, at which checkpoints are saved.
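The two save strategies can be contrasted with a pair of trigger predicates, assuming a save is due when the global step reaches a multiple of save_checkpoint_steps, or when at least save_checkpoint_seconds have elapsed since the previous save. due_by_steps and due_by_seconds are hypothetical helpers for illustration, not MindSpore functions.

```python
def due_by_steps(step, save_checkpoint_steps):
    """Hypothetical trigger: True when the 1-based global step is a multiple of the step interval."""
    return save_checkpoint_steps > 0 and step % save_checkpoint_steps == 0


def due_by_seconds(now_s, last_save_s, save_checkpoint_seconds):
    """Hypothetical trigger: True when enough wall-clock time has passed since the last save."""
    return save_checkpoint_seconds > 0 and now_s - last_save_s >= save_checkpoint_seconds


print([s for s in range(1, 11) if due_by_steps(s, 4)])  # [4, 8]
print(due_by_seconds(130.0, 60.0, 60))  # True
```

An interval of 0 disables the corresponding trigger, matching the default save_checkpoint_seconds=0.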
- property saved_network
Get the network to be saved in the checkpoint file.
- Returns
Cell, network to be saved in the checkpoint file.