mindspore.save_checkpoint
- mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode='AES-GCM', choice_func=None, **kwargs)
Save checkpoint to a specified file.
- Parameters
save_obj (Union[Cell, list]) – The cell object or a data list. Each element of the list is a dictionary such as {"name": param_name, "data": param_data}, where param_name is a string and param_data is a Parameter or Tensor.
ckpt_file_name (str) – Checkpoint file name. If the file name already exists, it will be overwritten.
integrated_save (bool) – Whether to merge sliced parameters and save them as complete parameters in automatic model parallel scenarios. Default: True.
async_save (bool) – Whether to save the checkpoint file in a separate thread. Default: False.
append_dict (dict) – Additional information to be saved. Keys must be str; values must be int, float, bool, str, Parameter or Tensor. Default: None.
enc_key (Union[None, bytes]) – Key of type bytes used for encryption. If None, the checkpoint is not encrypted. Default: None.
enc_mode (str) – Takes effect only when enc_key is not None. Specifies the encryption mode; currently "AES-GCM", "AES-CBC" and "SM4-CBC" are supported. Default: "AES-GCM".
choice_func (function) – A function that selects which parameters to save. It receives a parameter name (str) and returns a bool: if it returns True, the parameter matching the custom condition is saved; if it returns False, the parameter is not saved. Default: None.
kwargs (dict) – Configuration options dictionary.
incremental (bool): Whether to export the checkpoint for MapParameter incrementally.
- Raises
- Examples
```python
>>> import mindspore as ms
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> ms.save_checkpoint(net, "./lenet.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
>>> param_dict = ms.load_checkpoint("./lenet.ckpt")
>>> print(param_dict)
{'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
```
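The choice_func in the LeNet5 example is an ordinary predicate over parameter names, so its effect can be checked without MindSpore. The sketch below applies it to a list of names assumed for LeNet5 (convolution layers without bias, dense layers with bias); the name list and the helpers keep_conv_except_conv1 and skip_biases are illustrative assumptions, not taken from the lenet.py source.

```python
# Parameter names assumed for the LeNet5 network used in the example
# (Conv2d layers without bias, Dense layers with bias).
lenet5_param_names = [
    "conv1.weight", "conv2.weight",
    "fc1.weight", "fc1.bias",
    "fc2.weight", "fc2.bias",
    "fc3.weight", "fc3.bias",
]


def keep_conv_except_conv1(name: str) -> bool:
    """Same condition as the lambda passed as choice_func in the example."""
    return name.startswith("conv") and not name.startswith("conv1")


def skip_biases(name: str) -> bool:
    """Hypothetical alternative predicate: keep every parameter except biases."""
    return not name.endswith(".bias")


# A parameter is saved only when choice_func returns True for its name.
conv_only = [n for n in lenet5_param_names if keep_conv_except_conv1(n)]
no_bias = [n for n in lenet5_param_names if skip_biases(n)]
print(conv_only)
print(no_bias)
```

Either function can be passed directly, for example choice_func=skip_biases. Note that a plain prefix test on "conv1" would also exclude a layer named conv10; comparing name.split(".")[0] against the layer name is stricter.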