mindspore.save_checkpoint
- mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode='AES-GCM', choice_func=None, **kwargs)
Save a checkpoint to the specified file.
- Parameters
save_obj (Union[Cell, list, dict]) – The object to be saved. The data type can be mindspore.nn.Cell, list, or dict. If it is a list, it can be the return value of Cell.trainable_params(), or a list of dict elements, each of the form {"name": param_name, "data": param_data}, e.g. [{"name": param_name, "data": param_data}, …], where param_name must be a string and param_data must be a Parameter or Tensor. If it is a dict, it can be the return value of mindspore.load_checkpoint().
ckpt_file_name (str) – Checkpoint file name. If the file already exists, it is overwritten.
integrated_save (bool) – Whether to merge and save split tensors in the automatic model parallel scenario. Default: True.
async_save (bool) – Whether to save the checkpoint file in a separate thread. Default: False.
append_dict (dict) – Additional information to save. Keys must be str, and values must be int, float, bool, string, Parameter, or Tensor. Default: None.
enc_key (Union[None, bytes]) – Byte-type key used for encryption. If None, no encryption is applied. Default: None.
enc_mode (str) – Encryption mode. This parameter takes effect only when enc_key is not None. Supported modes are "AES-GCM", "AES-CBC", and "SM4-CBC". Default: "AES-GCM".
choice_func (function) – A function for selecting which parameters to save. It receives a parameter name as a string and returns a bool: if it returns True, the matching Parameter is saved; if it returns False, that Parameter is not saved. Default: None.
kwargs (dict) – Configuration options dictionary.
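The list-of-dict form of save_obj and the rules for append_dict can be illustrated with a plain-Python sketch. This is not MindSpore code: validate_param_list and validate_append_dict are hypothetical helpers that only mirror the constraints stated above. Parameter and Tensor values are not modelled (placeholder lists stand in for param_data, and Parameter/Tensor values in append_dict are rejected here), and the exception type MindSpore itself raises for such inputs is not documented on this page.

```python
# Hypothetical sketch (not MindSpore code): checks mirroring the documented
# constraints on a list-of-dict save_obj and on append_dict.
# Parameter/Tensor are NOT modelled, since that would require MindSpore.

ALLOWED_APPEND_TYPES = (int, float, bool, str)


def validate_param_list(params):
    """Check the [{"name": ..., "data": ...}, ...] shape and return the names."""
    for item in params:
        if not isinstance(item, dict) or "name" not in item or "data" not in item:
            raise TypeError(f"each element must be a dict with 'name' and 'data', got {item!r}")
        if not isinstance(item["name"], str):
            raise TypeError(f"param_name must be a string, got {type(item['name']).__name__}")
    return [item["name"] for item in params]


def validate_append_dict(extra):
    """Check str keys and simple value types (Parameter/Tensor omitted)."""
    for key, value in extra.items():
        if not isinstance(key, str):
            raise TypeError(f"append_dict key must be str, got {type(key).__name__}")
        if not isinstance(value, ALLOWED_APPEND_TYPES):
            raise TypeError(f"unsupported value type for {key!r}: {type(value).__name__}")


names = validate_param_list([
    {"name": "conv1.weight", "data": [0.1, 0.2]},  # placeholder for a Parameter/Tensor
    {"name": "fc1.bias", "data": [0.0]},
])
validate_append_dict({"epoch_num": 10, "lr": 0.01, "tag": "lenet"})
print(names)  # ['conv1.weight', 'fc1.bias']
```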
- Raises
TypeError – If save_obj is not of type mindspore.nn.Cell, list, or dict.
TypeError – If integrated_save or async_save is not of type bool.
TypeError – If ckpt_file_name is not of type str.
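A minimal sketch of the type checks listed under Raises, written in plain Python as an illustration of the documented contract rather than MindSpore's implementation. check_save_args is a hypothetical name; because a mindspore.nn.Cell cannot be created without MindSpore, this stand-in accepts only list and dict for save_obj, and the order of the checks is assumed.

```python
# Hypothetical sketch of the TypeError checks listed under "Raises".
# Not MindSpore's implementation: mindspore.nn.Cell is not modelled, so only
# list and dict are accepted for save_obj, and the check order is assumed.


def check_save_args(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
    """Raise TypeError for the argument types the Raises section rules out."""
    if not isinstance(save_obj, (list, dict)):
        raise TypeError(f"save_obj must be Cell, list or dict, got {type(save_obj).__name__}")
    # Strict bool check: ints such as 0 or 1 are rejected, since int is not bool.
    for name, flag in (("integrated_save", integrated_save), ("async_save", async_save)):
        if not isinstance(flag, bool):
            raise TypeError(f"{name} must be bool, got {type(flag).__name__}")
    if not isinstance(ckpt_file_name, str):
        raise TypeError(f"ckpt_file_name must be str, got {type(ckpt_file_name).__name__}")


check_save_args([], "./lenet.ckpt")  # passes silently
try:
    check_save_args([], "./lenet.ckpt", async_save=1)
except TypeError as err:
    print(err)  # async_save must be bool, got int
```

In this sketch a truthy int such as async_save=1 is rejected, because the check uses isinstance(flag, bool) rather than truthiness.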
Examples
```python
>>> import mindspore as ms
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> ms.save_checkpoint(net, "./lenet.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
>>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
>>> print(param_dict1)
{'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
>>> params_list = net.trainable_params()
>>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
>>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
>>> print(param_dict2)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
>>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
>>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
>>> print(param_dict3)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
```
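To see why each printed dictionary in the example contains a single entry, the following plain-Python sketch applies the two choice_func lambdas to a list of parameter names. LENET5_PARAM_NAMES is an assumption based on the LeNet5 definition linked in the example (convolution layers without bias, dense layers with bias); check net.trainable_params() on your own network. select is a hypothetical helper that only reproduces the documented rule: a parameter is kept exactly when choice_func returns True.

```python
# Assumed LeNet5 trainable parameter names (hypothetical list for illustration).
LENET5_PARAM_NAMES = [
    "conv1.weight", "conv2.weight",
    "fc1.weight", "fc1.bias",
    "fc2.weight", "fc2.bias",
    "fc3.weight", "fc3.bias",
]


def select(names, choice_func):
    """Hypothetical helper: keep exactly the names for which choice_func is True."""
    return [name for name in names if choice_func(name)]


# Same predicates as the two save_checkpoint calls in the example.
first = select(LENET5_PARAM_NAMES,
               lambda x: x.startswith("conv") and not x.startswith("conv1"))
second = select(LENET5_PARAM_NAMES,
                lambda x: x.startswith("conv") and not x.startswith("conv2"))
print(first)   # ['conv2.weight']
print(second)  # ['conv1.weight']
```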