mindspore.save_checkpoint

mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode='AES-GCM', choice_func=None, crc_check=False, format='ckpt', **kwargs)

Save a checkpoint to the specified file.

Note

The enc_mode and crc_check parameters are mutually exclusive and cannot be configured simultaneously.

Parameters
  • save_obj (Union[Cell, list, dict]) – The object to be saved. The data type can be mindspore.nn.Cell, list, or dict. If a list, it can be the return value of Cell.trainable_params(), or a list of dicts in which each element has the form {"name": param_name, "data": param_data}, for example [{"name": param_name, "data": param_data}, …]; param_name must be a string and param_data must be a Parameter or Tensor. If a dict, it can be the return value of mindspore.load_checkpoint().

  • ckpt_file_name (str) – Checkpoint file name. If the file name already exists, it will be overwritten.

  • integrated_save (bool) – Whether to merge split Tensors and save them as complete parameters in the automatic model parallel scenario. Default: True.

  • async_save (Union[bool, str]) – Whether to save the checkpoint file asynchronously. If True, an asynchronous thread is used by default. If a string, it specifies the asynchronous saving method and must be "process" or "thread". Default: False.

  • append_dict (dict) – Additional information to be saved. Keys must be str, and each value must be an int, float, bool, str, Parameter, or Tensor. Default: None.

  • enc_key (Union[None, bytes]) – Key of type bytes used for encryption. If None, the file is not encrypted. Default: None.

  • enc_mode (str) – Encryption mode, valid only when enc_key is not None. Currently supports "AES-GCM", "AES-CBC", and "SM4-CBC". Default: "AES-GCM".

  • choice_func (function) – A function for selecting which parameters to save. It receives a parameter name (str) and returns a bool: parameters for which it returns True are saved, and those for which it returns False are skipped. Default: None.

  • crc_check (bool) – Whether to compute a CRC32 checksum when saving the checkpoint and store the result in the file. Default: False.

  • format (str) – Format of the output file, can be "ckpt" or "safetensors". Default: "ckpt".

  • kwargs (dict) – Configuration options dictionary.
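The crc_check option relies on the general property that a stored CRC32 value lets a reader detect accidental corruption of the saved bytes. The sketch below illustrates only that principle with Python's standard zlib.crc32; the 4-byte trailer layout and the helpers append_crc32 and verify_crc32 are hypothetical and do not describe MindSpore's actual checkpoint file format.

```python
import zlib


def append_crc32(payload: bytes) -> bytes:
    """Hypothetical layout: payload followed by its CRC32 as a 4-byte big-endian trailer."""
    crc = zlib.crc32(payload) & 0xFFFFFFFF
    return payload + crc.to_bytes(4, "big")


def verify_crc32(blob: bytes) -> bool:
    """Recompute the payload's CRC32 and compare it with the stored trailer."""
    payload, stored = blob[:-4], int.from_bytes(blob[-4:], "big")
    return (zlib.crc32(payload) & 0xFFFFFFFF) == stored


blob = append_crc32(b"serialized parameters")
print(verify_crc32(blob))  # True

# Flip a single bit in the payload; CRC32 detects every single-bit error.
corrupted = bytes([blob[0] ^ 0x01]) + blob[1:]
print(verify_crc32(corrupted))  # False
```

The checksum only detects corruption; it does not protect against deliberate tampering, which is what the enc_key/enc_mode encryption options address.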

Raises
  • TypeError – If the parameter save_obj is not of type mindspore.nn.Cell, list, or dict.

  • TypeError – If the parameter integrated_save is not of type bool.

  • TypeError – If the parameter ckpt_file_name is not of type str.

  • TypeError – If the parameter async_save is not of type bool or str.

  • ValueError – If the parameter async_save is a str but not one of ["process", "thread"].
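The documented type and value rules for async_save can be expressed as a small validation function. This is a hedged sketch that mirrors only the behavior described in this reference; resolve_async_save is a hypothetical helper, not part of the MindSpore API, and MindSpore's internal checks may be implemented differently.

```python
from typing import Optional, Union


def resolve_async_save(async_save: Union[bool, str]) -> Optional[str]:
    """Hypothetical helper: map async_save to a mode name, or None for synchronous saving."""
    # Check bool first: True/False are accepted, while plain ints are not.
    if isinstance(async_save, bool):
        # True selects the default asynchronous mode, a thread.
        return "thread" if async_save else None
    if isinstance(async_save, str):
        if async_save not in ("process", "thread"):
            raise ValueError(f"async_save must be 'process' or 'thread', got {async_save!r}")
        return async_save
    raise TypeError(f"async_save must be bool or str, got {type(async_save).__name__}")


print(resolve_async_save(False))      # None
print(resolve_async_save(True))       # thread
print(resolve_async_save("process"))  # process
```

Passing "fork" raises ValueError, and passing 1 raises TypeError, matching the Raises entries for async_save.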

Examples

>>> import mindspore as ms
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> ms.save_checkpoint(net, "./lenet.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
>>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
>>> print(param_dict1)
{'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
>>> params_list = net.trainable_params()
>>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
>>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
>>> print(param_dict2)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
>>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
>>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
>>> print(param_dict3)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
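Because choice_func receives only the parameter name, its effect can be checked in plain Python before calling save_checkpoint. In this sketch, preview_selection is a hypothetical helper, and the name list is an assumption about the LeNet5 definition linked in the examples (two convolution layers without bias, three dense layers with bias); the predicates are the same as the lambdas used above, and the results agree with the printed checkpoint contents.

```python
# Parameter names assumed for the LeNet5 network referenced in the examples.
LENET5_PARAM_NAMES = [
    "conv1.weight", "conv2.weight",
    "fc1.weight", "fc1.bias",
    "fc2.weight", "fc2.bias",
    "fc3.weight", "fc3.bias",
]


def preview_selection(names, choice_func):
    """Hypothetical helper: return the names that choice_func maps to True (i.e. would be saved)."""
    return [name for name in names if choice_func(name)]


def keep_conv_except_conv1(name: str) -> bool:
    # Same predicate as the first example's lambda.
    return name.startswith("conv") and not name.startswith("conv1")


def keep_conv_except_conv2(name: str) -> bool:
    # Same predicate as the second example's lambda.
    return name.startswith("conv") and not name.startswith("conv2")


print(preview_selection(LENET5_PARAM_NAMES, keep_conv_except_conv1))  # ['conv2.weight']
print(preview_selection(LENET5_PARAM_NAMES, keep_conv_except_conv2))  # ['conv1.weight']
```

Note that prefix matching is loose: name.startswith("conv1") is also true for a name such as "conv10.weight", so matching on "conv1." is stricter when layer names share a prefix.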