mindspore.save_checkpoint

查看源文件
mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode='AES-GCM', choice_func=None, **kwargs)[源代码]

将网络权重保存到checkpoint文件中。

参数:
  • save_obj (Union[Cell, list, dict]) - 待保存的对象。数据类型可为 mindspore.nn.Cell 、list或dict。若为list,可以是 Cell.trainable_params() 的返回值,或元素为dict的列表(如[{“name”: param_name, “data”: param_data},…],param_name 的类型必须是str,param_data 的类型必须是Parameter或者Tensor);若为dict,可以是 mindspore.load_checkpoint() 的返回值。

  • ckpt_file_name (str) - checkpoint文件名称。如果文件已存在,将会覆盖原有文件。

  • integrated_save (bool) - 在并行场景下是否合并保存拆分的Tensor。默认值: True

  • async_save (bool) - 是否异步执行保存checkpoint文件。默认值: False

  • append_dict (dict) - 需要保存的其他信息。dict的键必须为str类型,dict的值类型必须是int、float、bool、string、Parameter或Tensor类型。默认值: None

  • enc_key (Union[None, bytes]) - 用于加密的字节类型密钥。如果值为 None ,那么不需要加密。默认值: None

  • enc_mode (str) - 该参数在 enc_key 不为 None 时有效,指定加密模式,目前仅支持 "AES-GCM""AES-CBC""SM4-CBC" 。默认值: "AES-GCM"

  • choice_func (function) - 一个用于自定义控制保存参数的函数。函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回 True ,则匹配自定义条件的Parameter将被保存。 如果返回 False ,则未匹配自定义条件的Parameter不会被保存。默认值: None

  • kwargs (dict) - 配置选项字典。

异常:
  • TypeError - 如果参数 save_obj 类型不为 mindspore.nn.Cell 、list或者dict。

  • TypeError - 如果参数 integrated_saveasync_save 不是bool类型。

  • TypeError - 如果参数 ckpt_file_name 不是字符串类型。

样例:

>>> import mindspore as ms
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> ms.save_checkpoint(net, "./lenet.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv1"))
>>> param_dict1 = ms.load_checkpoint("./lenet.ckpt")
>>> print(param_dict1)
{'conv2.weight': Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)}
>>> params_list = net.trainable_params()
>>> ms.save_checkpoint(params_list, "./lenet_list.ckpt",
...                    choice_func=lambda x: x.startswith("conv") and not x.startswith("conv2"))
>>> param_dict2 = ms.load_checkpoint("./lenet_list.ckpt")
>>> print(param_dict2)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
>>> ms.save_checkpoint(param_dict2, "./lenet_dict.ckpt")
>>> param_dict3 = ms.load_checkpoint("./lenet_dict.ckpt")
>>> print(param_dict3)
{'conv1.weight': Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True)}
教程样例: