mindspore.save_checkpoint

mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode='AES-GCM', choice_func=None, **kwargs)

Save a checkpoint to the specified file.

Parameters
  • save_obj (Union[Cell, list]) – The Cell object, or a list of data in which each element is a dictionary of the form {"name": param_name, "data": param_data}, where param_name is a string and param_data is a Parameter or Tensor.

  • ckpt_file_name (str) – Checkpoint file name. If a file with this name already exists, it is overwritten.

  • integrated_save (bool) – Whether to merge and save split Tensors in the automatic model-parallel scenario. Default: True.

  • async_save (bool) – Whether to save the checkpoint file in a separate thread. Default: False.

  • append_dict (dict) – Additional information to save. Keys must be str, and each value must be an int, float, bool, str, Parameter, or Tensor. Default: None.

  • enc_key (Union[None, bytes]) – Key of type bytes used for encryption. If None, the checkpoint is not encrypted. Default: None.

  • enc_mode (str) – The encryption mode; takes effect only when enc_key is not None. Currently supports "AES-GCM", "AES-CBC", and "SM4-CBC". Default: "AES-GCM".

  • choice_func (function) – A function that selects which parameters to save. It receives a parameter name (str) and returns a bool: Parameters for which it returns True are saved, and those for which it returns False are skipped. Default: None.

  • kwargs (dict) –

    Configuration options dictionary.

    • incremental (bool): Whether to export the checkpoint for MapParameter incrementally.
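The two structured arguments, save_obj in its list form and append_dict, have shape rules that can be checked before calling save_checkpoint. The sketch below is a pure-Python illustration under stated assumptions: to_save_list, is_valid_save_list, and is_valid_append_dict are hypothetical helpers (not MindSpore functions), the parameter names are illustrative, plain nested lists stand in for the Tensor or Parameter data real code would pass, and Parameter/Tensor values in append_dict are not checked because MindSpore is not imported.

```python
from typing import Any, Dict, Iterable, List, Tuple

# Built-in value types the append_dict description allows. Parameter and Tensor
# are also allowed but are not checked here, because this sketch does not
# import MindSpore.
APPEND_VALUE_TYPES = (int, float, bool, str)


def to_save_list(pairs: Iterable[Tuple[str, Any]]) -> List[Dict[str, Any]]:
    """Turn (name, value) pairs into the list form of save_obj."""
    return [{"name": name, "data": data} for name, data in pairs]


def is_valid_save_list(save_obj: Any) -> bool:
    """True if save_obj is a list of dicts, each with a str "name" and a "data" entry."""
    return isinstance(save_obj, list) and all(
        isinstance(item, dict) and isinstance(item.get("name"), str) and "data" in item
        for item in save_obj
    )


def is_valid_append_dict(append_dict: Any) -> bool:
    """True if every key is a str and every value is one of APPEND_VALUE_TYPES."""
    return isinstance(append_dict, dict) and all(
        isinstance(key, str) and isinstance(value, APPEND_VALUE_TYPES)
        for key, value in append_dict.items()
    )


# Plain nested lists stand in for the Tensor/Parameter data real code would pass.
save_list = to_save_list([("fc.weight", [[0.1, 0.2], [0.3, 0.4]]), ("fc.bias", [0.0, 0.0])])
extra = {"epoch_num": 10, "lr": 0.01, "tag": "run-1"}

print(is_valid_save_list(save_list))         # True
print(is_valid_append_dict(extra))           # True
print(is_valid_append_dict({"hist": [1]}))   # False: list is not an allowed value type
```

With MindSpore available and real Tensors as the data values, the two objects would be passed as ms.save_checkpoint(save_list, "./fc.ckpt", append_dict=extra).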

Raises
  • TypeError – If save_obj is not of type nn.Cell or list.

  • TypeError – If integrated_save or async_save is not of type bool.

  • TypeError – If ckpt_file_name is not of type str.
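The three TypeError conditions can be mirrored in a small argument check. This is an illustrative sketch, not MindSpore's implementation: validate_save_args is hypothetical, and the stand-in class Cell replaces mindspore.nn.Cell only so that the snippet runs without MindSpore installed.

```python
from typing import Any


class Cell:
    """Stand-in for mindspore.nn.Cell so the sketch runs without MindSpore (assumption)."""


def validate_save_args(save_obj: Any, ckpt_file_name: Any,
                       integrated_save: Any = True, async_save: Any = False) -> None:
    """Hypothetical check that raises TypeError in the three documented cases."""
    if not isinstance(save_obj, (Cell, list)):
        raise TypeError(f"save_obj must be a Cell or list, got {type(save_obj).__name__}")
    if not isinstance(integrated_save, bool) or not isinstance(async_save, bool):
        raise TypeError("integrated_save and async_save must be bool")
    if not isinstance(ckpt_file_name, str):
        raise TypeError(f"ckpt_file_name must be str, got {type(ckpt_file_name).__name__}")


validate_save_args(Cell(), "./lenet.ckpt")  # valid: no exception

# Each call below violates one documented condition.
for args in [({"w": 1}, "./a.ckpt"), (Cell(), 123), (Cell(), "./a.ckpt", 1)]:
    try:
        validate_save_args(*args)
    except TypeError as err:
        print("TypeError:", err)
```

In this sketch the value 1 is rejected for integrated_save even though it is truthy, because the Raises list requires the bool type rather than any value that behaves like one.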

Examples

>>> import mindspore as ms
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> ms.save_checkpoint(net, "./lenet.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
>>> param_dict = ms.load_checkpoint("./lenet.ckpt")
>>> print(param_dict)
{'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
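The filter used in this example can be tried on its own to see why only conv2.weight ends up in the checkpoint. The sketch below applies the same predicate to a list of parameter names; the names are assumed for illustration (modelled on a LeNet5-style network) and are not read from a real model.

```python
from typing import List


def choice_func(name: str) -> bool:
    """Same predicate as the example's lambda: names starting with "conv" but not "conv1"."""
    return name.startswith("conv") and not name.startswith("conv1")


# Hypothetical parameter names modelled on a LeNet5-style network.
param_names: List[str] = [
    "conv1.weight", "conv2.weight",
    "fc1.weight", "fc1.bias",
    "fc2.weight", "fc2.bias",
    "fc3.weight", "fc3.bias",
]

# Keep only the names for which choice_func returns True, mirroring how
# save_checkpoint decides which Parameters to write.
selected = [name for name in param_names if choice_func(name)]
print(selected)  # ['conv2.weight']
```

Prefix tests are literal: a parameter named conv10.weight would also be rejected by not name.startswith("conv1"). Comparing the module prefix exactly, for example name.split(".")[0] != "conv1", avoids that pitfall in networks with ten or more convolution layers.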