mindspore.train.CheckpointConfig

class mindspore.train.CheckpointConfig(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM', exception_save=False, crc_check=False, remove_redundancy=False, format='ckpt', **kwargs)

The configuration of model checkpoint.

Note

  • During training, if the dataset is transmitted through the data channel, it is recommended to set save_checkpoint_steps to an integer multiple of loop_size; otherwise the time at which the checkpoint is saved may be skewed. It is recommended to configure only one save strategy and one keep strategy at a time. If both save_checkpoint_steps and save_checkpoint_seconds are set, save_checkpoint_seconds is invalid; if both keep_checkpoint_max and keep_checkpoint_per_n_minutes are set, keep_checkpoint_per_n_minutes is invalid (see the sketch after this note).

  • The enc_mode and crc_check parameters are mutually exclusive and cannot be configured simultaneously.
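
For example, a minimal sketch of the precedence rule above; the outputs are assumptions that follow the behaviour described in this note, where the step-based strategy takes effect and the seconds-based one is dropped:

>>> from mindspore.train import CheckpointConfig
>>> config = CheckpointConfig(save_checkpoint_steps=100, save_checkpoint_seconds=60)
>>> config.save_checkpoint_steps
100
>>> config.save_checkpoint_seconds  # returns None because save_checkpoint_steps takes effect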

Parameters
  • save_checkpoint_steps (int) – Steps to save checkpoint. Default: 1.

  • save_checkpoint_seconds (int) – Seconds to save checkpoint. Can't be used with save_checkpoint_steps at the same time. Default: 0.

  • keep_checkpoint_max (int) – Maximum number of checkpoint files that can be saved. Default: 5.

  • keep_checkpoint_per_n_minutes (int) – Save a checkpoint file every keep_checkpoint_per_n_minutes minutes. Can't be used with keep_checkpoint_max at the same time. Default: 0.

  • integrated_save (bool) – Whether to merge and save the split Tensor in the automatic parallel scenario. The integrated save function is supported only in automatic parallel scenarios, not in manual parallel. Default: True.

  • async_save (Union[bool, str]) – Whether to save the checkpoint file asynchronously. If True, an asynchronous thread is used by default. If a string is passed, it specifies the asynchronous saving method and can be "process" or "thread". Default: False.

  • saved_network (Cell) – Network to be saved in the checkpoint file. If saved_network has no relation to the network in training, the initial value of saved_network will be saved. Default: None.

  • append_info (list) – Information to be saved to the checkpoint file. Supports "epoch_num", "step_num" and dict. The keys of the dict must be str, and the values must be one of int, float, bool, Parameter or Tensor. See the sketch after this parameter list. Default: None.

  • enc_key (Union[None, bytes]) – Byte-type key used for encryption. If the value is None, encryption is not performed. Default: None.

  • enc_mode (str) – This parameter is valid only when enc_key is not set to None. Specifies the encryption mode; currently supports 'AES-GCM', 'AES-CBC' and 'SM4-CBC'. Default: 'AES-GCM'.

  • exception_save (bool) – Whether to save the current checkpoint when an exception occurs. Default: False.

  • crc_check (bool) – Whether to perform a crc32 calculation when saving the checkpoint and append the result to the end of the ckpt file. Default: False.

  • remove_redundancy (bool) – Whether to save the checkpoint with redundancy removed. Redundancy removal refers to eliminating redundant data in data parallelism mode. Default: False, meaning redundancy-free saving is not enabled.

  • format (str) – Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".

  • kwargs (dict) – Configuration options dictionary.
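
The following is a rough, illustrative sketch of how several of these options can be combined; the concrete values are placeholders chosen for demonstration, not recommendations:

>>> from mindspore.train import CheckpointConfig
>>> config = CheckpointConfig(
...     save_checkpoint_steps=100,                    # save every 100 steps
...     keep_checkpoint_max=10,                       # keep at most 10 checkpoint files
...     async_save="process",                         # save asynchronously in a separate process
...     append_info=["epoch_num", {"version": 1}],    # extra metadata written into the checkpoint
...     format="ckpt")                                # output format, "ckpt" or "safetensors"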

Raises

ValueError – If an input parameter is not of the correct type.

Examples

>>> from mindspore import nn
>>> from mindspore.train import Model, CheckpointConfig, ModelCheckpoint
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> config = CheckpointConfig(save_checkpoint_seconds=100, keep_checkpoint_per_n_minutes=5, saved_network=net)
>>> config.save_checkpoint_steps
1
>>> config.save_checkpoint_seconds
>>> config.keep_checkpoint_max
5
>>> config.keep_checkpoint_per_n_minutes
>>> config.integrated_save
True
>>> config.async_save
False
>>> config.saved_network
>>> config.enc_key
>>> config.enc_mode
'AES-GCM'
>>> config.append_dict
>>> config.get_checkpoint_policy
>>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
>>> model.train(10, dataset, callbacks=ckpoint_cb)
property append_dict

Get the information dict that will be saved to the checkpoint file.

Returns

dict, the information to be saved to the checkpoint file.

property async_save

Get whether, and by which method, the checkpoint file is saved asynchronously.

Returns

Union[bool, str], whether the checkpoint file is saved asynchronously and, when a str, the asynchronous saving method.
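
A brief illustrative sketch; it assumes the property simply reflects the value passed to the constructor:

>>> from mindspore.train import CheckpointConfig
>>> CheckpointConfig(async_save=True).async_save
True
>>> CheckpointConfig(async_save="process").async_save
'process'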

property crc_check

Get whether crc check is enabled.

Returns

bool, whether to enable crc check.

property enc_key

Get the byte-type key used for encryption.

Returns

Union[None, bytes], the byte-type key used for encryption.

property enc_mode

Get the value of the encryption mode.

Returns

str, encryption mode.

get_checkpoint_policy()

Get the checkpoint saving policy.

Returns

dict, the information of the checkpoint policy.
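
A short usage sketch; the exact keys of the returned dict are not specified here, only that it summarizes the configured save/keep policy:

>>> from mindspore.train import CheckpointConfig
>>> config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10)
>>> isinstance(config.get_checkpoint_policy(), dict)
True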

property integrated_save

Get the value of whether to merge and save the split Tensor in the automatic parallel scenario.

Returns

bool, whether to merge and save the split Tensor in the automatic parallel scenario.

property keep_checkpoint_max

Get the maximum number of checkpoint files that can be saved.

Returns

int, the maximum number of checkpoint files that can be saved.

property keep_checkpoint_per_n_minutes

Get the interval, in minutes, at which checkpoint files are saved.

Returns

int, the interval, in minutes, at which checkpoint files are saved.

property map_param_inc

Get the value of whether to save map Parameter incrementally.

Returns

bool, whether to save map Parameter incrementally.

property save_checkpoint_seconds

Get the value of seconds to save the checkpoint file.

Returns

int, seconds to save the checkpoint file.

property save_checkpoint_steps

Get the value of steps to save checkpoint.

Returns

int, steps to save checkpoint.

property saved_network

Get the value of network to be saved in checkpoint file.

Returns

Cell, network to be saved in checkpoint file.