加载模型用于推理或迁移学习

Linux Ascend GPU CPU 模型加载 初级 中级 高级

查看源文件    查看notebook    在线运行

概述

在模型训练过程中保存在本地的CheckPoint文件,可以帮助用户进行推理或迁移学习使用。

以下通过示例来介绍如何通过本地加载,用于推理验证和迁移学习。

本地加载模型

用于推理验证

针对仅推理场景可以使用load_checkpoint把参数直接加载到网络中,以便进行后续的推理验证。

示例代码如下:

resnet = ResNet50()
load_checkpoint("resnet50-2_32.ckpt", net=resnet)
dateset_eval = create_dataset(os.path.join(mnist_path, "test"), 32, 1) # define the test dataset
loss = CrossEntropyLoss()
model = Model(resnet, loss, metrics={"accuracy"})
acc = model.eval(dataset_eval)
  • load_checkpoint方法会把参数文件中的网络参数加载到模型中。加载后,网络中的参数就是CheckPoint保存的。

  • eval方法会验证训练后模型的精度。

用于迁移学习

针对任务中断再训练及微调(Fine Tune)场景,可以加载网络参数和优化器参数到模型中。

示例代码如下:

# return a parameter dict for model
param_dict = load_checkpoint("resnet50-2_32.ckpt")
resnet = ResNet50()
opt = Momentum(resnet.trainable_params(), 0.01, 0.9)
# load the parameter into net
load_param_into_net(resnet, param_dict)
# load the parameter into optimizer
load_param_into_net(opt, param_dict)
loss = SoftmaxCrossEntropyWithLogits()
model = Model(resnet, loss, opt)
model.train(epoch, dataset)
  • load_checkpoint方法会返回一个参数字典。

  • load_param_into_net会把参数字典中相应的参数加载到网络或优化器中。