mindspore.load_checkpoint_async

查看源文件
mindspore.load_checkpoint_async(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode='AES-GCM', specify_prefix=None, choice_func=None)[源代码]

异步加载checkpoint文件。

警告

这是一个实验性API,后续可能修改或删除。

说明

  • specify_prefixfilter_prefix 的功能相互之间没有影响。

  • 如果发现没有参数被成功加载,将会报ValueError。

  • specify_prefixfilter_prefix 参数已被弃用,推荐使用 choice_func 代替。并且使用这两个参数中的任何一个都将覆盖 choice_func

参数:
  • ckpt_file_name (str) - checkpoint的文件名称。

  • net (Cell,可选) - 加载checkpoint参数的网络。默认值: None

  • strict_load (bool,可选) - 是否将严格加载参数到网络中。如果是 False ,它将根据相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行强制精度转换,比如将float32转换为float16。默认值: False

  • filter_prefix (Union[str, list[str], tuple[str]],可选) - 废弃(请参考参数 choice_func)。以 filter_prefix 开头的参数将不会被加载。默认值: None

  • dec_key (Union[None, bytes],可选) - 用于解密的字节类型密钥,如果值为 None ,则不需要解密。默认值: None

  • dec_mode (str,可选) - 该参数仅当 dec_key 不为 None 时有效。指定解密模式,目前支持 "AES-GCM""AES-CBC""SM4-CBC" 。默认值: "AES-GCM"

  • specify_prefix (Union[str, list[str], tuple[str]],可选) - 废弃(请参考参数 choice_func)。以 specify_prefix 开头的参数将会被加载。默认值: None

  • choice_func (Union[None, function],可选) - 函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回 True ,则匹配自定义条件的Parameter将被加载。 如果返回 False ,则匹配自定义条件的Parameter将被删除。默认值: None

返回:

自定义的内部类, 调用其 result 方法可以得到 mindspore.load_checkpoint() 返回的结果。

异常:
  • ValueError - checkpoint文件格式不正确。

  • ValueError - 没有一个参数被成功加载。

  • TypeError - specify_prefix 或者 filter_prefix 的数据类型不正确。

样例:

>>> import mindspore
>>> from mindspore import nn
>>> from mindspore.train import Model
>>> from mindspore.amp import FixedLossScaleManager
>>> from mindspore import context
>>> from mindspore import load_checkpoint_async
>>> from mindspore import load_param_into_net
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
>>> # Create the dataset taking MNIST as an example. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/mnist.py
>>> dataset = create_dataset()
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> ckpt_file = "./checkpoint/LeNet5-1_32.ckpt"
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
>>> loss_scale_manager = FixedLossScaleManager()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
...               loss_scale_manager=loss_scale_manager)
>>> pd_future = load_checkpoint_async(ckpt_file)
>>> model.build(train_dataset=dataset, epoch=2)
>>> param_dict = pd_future.result()
>>> load_param_into_net(net, param_dict)
>>> model.train(2, dataset)
>>> print("param dict len: ", len(param_dict), flush=True)