mindspore.save_checkpoint
- mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode='AES-GCM', choice_func=None, crc_check=False, format='ckpt', **kwargs)
Save checkpoint to a specified file.
Note
The enc_mode and crc_check parameters are mutually exclusive and cannot be configured simultaneously.
- Parameters
save_obj (Union[Cell, list, dict]) – The object to be saved. The data type can be mindspore.nn.Cell, list, or dict. If a list, it can be the return value of Cell.trainable_params(), or a list of dicts in which each element has the form {"name": param_name, "data": param_data}, for example [{"name": param_name, "data": param_data}, …]; param_name must be a string and param_data must be a Parameter or Tensor. If a dict, it can be the return value of mindspore.load_checkpoint().
ckpt_file_name (str) – Checkpoint file name. If the file already exists, it will be overwritten.
integrated_save (bool) – Whether to merge and save the split Tensors in the automatic model parallel scenario. Default: True.
async_save (Union[bool, str]) – Whether to save the checkpoint file asynchronously. If True, an asynchronous thread is used by default. If a string, it specifies the asynchronous saving method and must be "process" or "thread". Default: False.
append_dict (dict) – Additional information to be saved. The keys must be str, and each value must be an int, float, bool, str, Parameter, or Tensor. Default: None.
enc_key (Union[None, bytes]) – Key of bytes type used for encryption. If None, encryption is not performed. Default: None.
enc_mode (str) – Encryption mode. Valid only when enc_key is not None. Currently "AES-GCM", "AES-CBC", and "SM4-CBC" are supported. Default: "AES-GCM".
choice_func (function) – A function for selecting which parameters to save. It receives a parameter name as a string and returns a bool: if it returns True, the parameter matching the custom condition is saved; if it returns False, the parameter is not saved. Default: None.
crc_check (bool) – Whether to compute a CRC32 checksum when saving the checkpoint and store the result in the file. Default: False.
format (str) – Format of the output file, either "ckpt" or "safetensors". Default: "ckpt".
kwargs (dict) – Configuration options dictionary.
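The list-of-dicts form of save_obj can be built from a plain name-to-data mapping. The sketch below uses a hypothetical helper, to_save_list, which is not part of MindSpore; it only enforces the documented rule that each "name" is a string. Placeholder strings stand in for the Parameter or Tensor objects that real code would pass as "data".

```python
def to_save_list(named_data):
    """Hypothetical helper: convert {name: data} into save_obj's list-of-dicts form."""
    save_list = []
    for name, data in named_data.items():
        # save_checkpoint requires every "name" entry to be a string.
        if not isinstance(name, str):
            raise TypeError(f"parameter name must be str, got {type(name).__name__}")
        save_list.append({"name": name, "data": data})
    return save_list


# Placeholder values; in real use these would be Parameter or Tensor objects.
items = to_save_list({"conv1.weight": "<Tensor>", "fc1.bias": "<Tensor>"})
print(items)
```

With real Tensors or Parameters as values, the resulting list would then be passed as the first argument, e.g. ms.save_checkpoint(items, "./custom.ckpt").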
- Raises
TypeError – If the parameter save_obj is not of type mindspore.nn.Cell, list, or dict.
TypeError – If the parameter integrated_save is not of type bool.
TypeError – If the parameter ckpt_file_name is not string type.
TypeError – If the parameter async_save is not bool or string type.
ValueError – If the parameter async_save is string type but not in ["process", "thread"].
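The async_save rules in the Raises list can be summarized as a small validator. This is a hypothetical sketch written from the documented behavior, not MindSpore's internal code; check_async_save is an illustrative name.

```python
def check_async_save(async_save):
    """Hypothetical validator mirroring the documented async_save rules.

    Returns "thread" or "process" for asynchronous saving, or None for
    synchronous saving.
    """
    if not isinstance(async_save, (bool, str)):
        raise TypeError(f"async_save must be bool or str, got {type(async_save).__name__}")
    if isinstance(async_save, str):
        if async_save not in ("process", "thread"):
            raise ValueError(f'async_save must be "process" or "thread", got {async_save!r}')
        return async_save
    # True uses an asynchronous thread by default; False saves synchronously.
    return "thread" if async_save else None
```

Note that an integer such as 1 is rejected with TypeError, because it is neither a bool nor a str.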
Examples
>>> import mindspore as ms
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> ms.save_checkpoint(net, "./lenet.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
>>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
>>> print(param_dict1)
{'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
>>> params_list = net.trainable_params()
>>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
>>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
>>> print(param_dict2)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
>>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
>>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
>>> print(param_dict3)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
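To see why each checkpoint in the example keeps exactly one parameter, the two choice_func predicates can be applied to LeNet5's parameter names in plain Python. The name list below is an assumption based on the referenced lenet.py (convolutions without bias, three Dense layers with bias); no MindSpore call is involved.

```python
# Assumed LeNet5 trainable parameter names (see the referenced lenet.py).
LENET5_PARAM_NAMES = [
    "conv1.weight", "conv2.weight",
    "fc1.weight", "fc1.bias",
    "fc2.weight", "fc2.bias",
    "fc3.weight", "fc3.bias",
]


def keep_conv2(name):
    # Same condition as the first choice_func lambda in the example.
    return name.startswith("conv") and not name.startswith("conv1")


def keep_conv1(name):
    # Same condition as the second choice_func lambda in the example.
    return name.startswith("conv") and not name.startswith("conv2")


def select(names, choice_func):
    # A parameter is kept only when choice_func(name) returns True.
    return [name for name in names if choice_func(name)]


print(select(LENET5_PARAM_NAMES, keep_conv2))  # ['conv2.weight']
print(select(LENET5_PARAM_NAMES, keep_conv1))  # ['conv1.weight']
```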