mindspore.save_checkpoint

mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode='AES-GCM', choice_func=None, **kwargs)

Save a checkpoint to the specified file.

Parameters
  • save_obj (Union[Cell, list, dict]) – The object to be saved. The data type can be mindspore.nn.Cell, list, or dict. If it is a list, it can be the return value of Cell.trainable_params(), or a list of dict elements (each element is a dictionary such as [{"name": param_name, "data": param_data}, …], where param_name must be a string and param_data must be a Parameter or Tensor). If it is a dict, it can be the return value of mindspore.load_checkpoint().

  • ckpt_file_name (str) – Checkpoint file name. If the file already exists, it will be overwritten.

  • integrated_save (bool) – Whether to merge and save the split Tensors in the automatic model parallel scenario. Default: True.

  • async_save (bool) – Whether to save the checkpoint file asynchronously in an independent thread. Default: False.

  • append_dict (dict) – Additional information to be saved. Keys must be str, and each value must be an int, float, bool, str, Parameter, or Tensor. Default: None.

  • enc_key (Union[None, bytes]) – Key of type bytes used for encryption. If None, the checkpoint is not encrypted. Default: None.

  • enc_mode (str) – Encryption mode; takes effect only when enc_key is not None. Currently "AES-GCM", "AES-CBC", and "SM4-CBC" are supported. Default: "AES-GCM".

  • choice_func (function) – A function for selecting which parameters to save. It takes a parameter name (str) as input and returns a bool: if it returns True, the parameter with that name is saved; if it returns False, the parameter is not saved. Default: None.

  • kwargs (dict) – Configuration options dictionary.
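The list-of-dicts form of save_obj and the choice_func filter can be illustrated with a short plain-Python model. This is a sketch, not MindSpore's implementation: strings stand in for Parameter/Tensor data so it runs without MindSpore installed, the helpers to_entries and select are hypothetical, and treating choice_func=None as "keep every parameter" is an assumption based on the default.

```python
# Plain-Python model of two documented save_checkpoint inputs (not MindSpore code):
#  * the list form of save_obj: [{"name": param_name, "data": param_data}, ...]
#  * choice_func: called with each parameter name; True keeps it, False drops it.
# Strings stand in for Parameter/Tensor data so this runs without MindSpore.


def to_entries(named_data):
    """Hypothetical helper: build save_obj's list form from a {name: data} mapping."""
    entries = []
    for name, data in named_data.items():
        # The docs require param_name to be a string.
        if not isinstance(name, str):
            raise TypeError(f"param_name must be str, got {type(name).__name__}")
        entries.append({"name": name, "data": data})
    return entries


def select(entries, choice_func=None):
    """Hypothetical helper: keep entries whose name satisfies choice_func.

    Assumption: choice_func=None keeps every entry.
    """
    if choice_func is None:
        return list(entries)
    return [entry for entry in entries if choice_func(entry["name"])]


# Hypothetical LeNet5-style parameter names; the data values are placeholders.
params = {
    "conv1.weight": "data1",
    "conv2.weight": "data2",
    "fc1.weight": "data3",
}
entries = to_entries(params)

# Same predicate as the first doctest in Examples: conv layers except conv1.
kept = select(entries, lambda x: x.startswith("conv") and not x.startswith("conv1"))
print([entry["name"] for entry in kept])  # ['conv2.weight']
```

The predicate only ever sees the parameter name, so selection rules are naturally written as string tests such as startswith or membership in a set of names.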

Raises
  • TypeError – If save_obj is not of type mindspore.nn.Cell, list, or dict.

  • TypeError – If integrated_save or async_save is not of type bool.

  • TypeError – If ckpt_file_name is not a string.
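The argument checks listed under Raises can be modeled as follows. This is a hypothetical sketch: Cell is a placeholder class standing in for mindspore.nn.Cell, check_save_args is not a MindSpore function, and the order of the checks and the error messages are assumptions; only the conditions that trigger TypeError come from the list above.

```python
class Cell:
    """Hypothetical placeholder for mindspore.nn.Cell (not the real class)."""


def check_save_args(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
    """Hypothetical sketch: raise TypeError in the cases listed under Raises."""
    # save_obj must be a Cell, list, or dict.
    if not isinstance(save_obj, (Cell, list, dict)):
        raise TypeError("save_obj must be a Cell, list or dict, "
                        f"but got {type(save_obj).__name__}")
    # integrated_save and async_save must be real bools, not ints.
    for flag_name, flag in (("integrated_save", integrated_save),
                            ("async_save", async_save)):
        if not isinstance(flag, bool):
            raise TypeError(f"{flag_name} must be bool, but got {type(flag).__name__}")
    # ckpt_file_name must be a string.
    if not isinstance(ckpt_file_name, str):
        raise TypeError("ckpt_file_name must be str, "
                        f"but got {type(ckpt_file_name).__name__}")


check_save_args([], "./lenet.ckpt")  # valid arguments: no exception
try:
    check_save_args({}, "./lenet.ckpt", async_save=1)  # int is not bool
except TypeError as err:
    print(err)  # async_save must be bool, but got int
```

Because isinstance(1, bool) is False, an int flag is rejected in this sketch even though bool is a subclass of int.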

Examples

>>> import mindspore as ms
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> ms.save_checkpoint(net, "./lenet.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
>>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
>>> print(param_dict1)
{'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
>>> params_list = net.trainable_params()
>>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
>>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
>>> print(param_dict2)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
>>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
>>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
>>> print(param_dict3)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
Tutorial Examples: