保存与加载
上一章节的内容里面主要是介绍了如何调整超参数,并进行网络模型训练。训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型部署和推理,本章节我们开始学习如何保存与加载模型。
模型训练
下面我们以MNIST数据集为例,介绍网络模型的保存与加载方式。首先,我们需要获取MNIST数据集并训练模型,示例代码如下:
[2]:
import mindspore.nn as nn
from mindspore.train import Model
from mindvision.classification.dataset import Mnist
from mindvision.classification.models import lenet
from mindvision.engine.callback import LossMonitor
epochs = 10 # 训练轮次
# 1. 构建数据集
download_train = Mnist(path="./mnist", split="train", batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)
dataset_train = download_train.run()
# 2. 定义神经网络
network = lenet(num_classes=10, pretrained=False)
# 3.1 定义损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 3.2 定义优化器函数
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
# 3.3 初始化模型参数
model = Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'accuracy'})
# 4. 对神经网络执行训练
model.train(epochs, dataset_train, callbacks=[LossMonitor(0.01, 1875)])
Epoch:[ 0/ 10], step:[ 1875/ 1875], loss:[0.148/1.210], time:2.021 ms, lr:0.01000
Epoch time: 4251.808 ms, per step time: 2.268 ms, avg loss: 1.210
Epoch:[ 1/ 10], step:[ 1875/ 1875], loss:[0.049/0.081], time:2.048 ms, lr:0.01000
Epoch time: 4301.405 ms, per step time: 2.294 ms, avg loss: 0.081
Epoch:[ 2/ 10], step:[ 1875/ 1875], loss:[0.014/0.050], time:1.992 ms, lr:0.01000
Epoch time: 4278.799 ms, per step time: 2.282 ms, avg loss: 0.050
Epoch:[ 3/ 10], step:[ 1875/ 1875], loss:[0.035/0.038], time:2.254 ms, lr:0.01000
Epoch time: 4380.553 ms, per step time: 2.336 ms, avg loss: 0.038
Epoch:[ 4/ 10], step:[ 1875/ 1875], loss:[0.130/0.031], time:1.932 ms, lr:0.01000
Epoch time: 4287.547 ms, per step time: 2.287 ms, avg loss: 0.031
Epoch:[ 5/ 10], step:[ 1875/ 1875], loss:[0.003/0.027], time:1.981 ms, lr:0.01000
Epoch time: 4377.000 ms, per step time: 2.334 ms, avg loss: 0.027
Epoch:[ 6/ 10], step:[ 1875/ 1875], loss:[0.004/0.023], time:2.167 ms, lr:0.01000
Epoch time: 4687.250 ms, per step time: 2.500 ms, avg loss: 0.023
Epoch:[ 7/ 10], step:[ 1875/ 1875], loss:[0.004/0.020], time:2.226 ms, lr:0.01000
Epoch time: 4685.529 ms, per step time: 2.499 ms, avg loss: 0.020
Epoch:[ 8/ 10], step:[ 1875/ 1875], loss:[0.000/0.016], time:2.275 ms, lr:0.01000
Epoch time: 4651.129 ms, per step time: 2.481 ms, avg loss: 0.016
Epoch:[ 9/ 10], step:[ 1875/ 1875], loss:[0.022/0.015], time:2.177 ms, lr:0.01000
Epoch time: 4623.760 ms, per step time: 2.466 ms, avg loss: 0.015
从上面的打印结果可以看出,随着训练轮次的增加,损失值趋于收敛。
保存模型
在训练完网络完成后,下面我们将网络模型以文件的形式保存下来。保存模型的接口有主要2种:
简单的对网络模型进行保存,可以在训练前后进行保存。这种方式的优点是接口简单易用,但是只保留执行命令时候的网络模型状态;
在网络模型训练中进行保存,MindSpore在网络模型训练的过程中,自动保存训练时候设定好的epoch数和step数的参数,也就是把模型训练过程中产生的中间权重参数也保存下来,方便进行网络微调和停止训练;
直接保存模型
使用MindSpore提供的save_checkpoint保存模型,传入网络和保存路径:
[3]:
import mindspore as ms
# 定义的网络模型为net,一般在训练前或者训练后使用
ms.save_checkpoint(network, "./MyNet.ckpt")
其中,network
为训练网络,"./MyNet.ckpt"
为网络模型的保存路径。
训练过程中保存模型
在模型训练的过程中,使用model.train
里面的callbacks
参数传入保存模型的对象 ModelCheckpoint(一般与CheckpointConfig配合使用),可以保存模型参数,生成CheckPoint(简称ckpt)文件。
用户可以根据具体需求通过设置CheckpointConfig
来对CheckPoint策略进行配置。具体用法如下:
[4]:
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
# 设置epoch_num数量
epoch_num = 5
# 设置模型保存参数
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
# 应用模型保存参数
ckpoint = ModelCheckpoint(prefix="lenet", directory="./lenet", config=config_ck)
model.train(epoch_num, dataset_train, callbacks=[ckpoint])
上述代码中,首先需要初始化一个CheckpointConfig
类对象,用来设置保存策略。
save_checkpoint_steps
表示每隔多少个step保存一次。keep_checkpoint_max
表示最多保留CheckPoint文件的数量。prefix
表示生成CheckPoint文件的前缀名。directory
表示存放文件的目录。
创建一个ModelCheckpoint
对象把它传递给model.train
方法,就可以在训练过程中使用CheckPoint功能了。
生成的CheckPoint文件如下:
lenet-graph.meta # 编译后的计算图
lenet-1_1875.ckpt # CheckPoint文件后缀名为'.ckpt'
lenet-2_1875.ckpt # 文件的命名方式表示保存参数所在的epoch和step数,这里为第2个epoch的第1875个step的模型参数
lenet-3_1875.ckpt # 表示保存的是第3个epoch的第1875个step的模型参数
...
如果用户使用相同的前缀名,运行多次训练脚本,可能会生成同名CheckPoint文件。MindSpore为方便用户区分每次生成的文件,会在用户定义的前缀后添加”_”和数字加以区分。如果想要删除.ckpt
文件时,请同步删除.meta
文件。
例:lenet_3-2_1875.ckpt
表示运行第3次脚本生成的第2个epoch的第1875个step的CheckPoint文件。
加载模型
要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpoint
和load_param_into_net
方法加载参数。
示例代码如下:
[5]:
from mindspore import load_checkpoint, load_param_into_net
from mindvision.classification.dataset import Mnist
from mindvision.classification.models import lenet
# 将模型参数存入parameter的字典中,这里加载的是上面训练过程中保存的模型参数
param_dict = load_checkpoint("./lenet/lenet-5_1875.ckpt")
# 重新定义一个LeNet神经网络
net = lenet(num_classes=10, pretrained=False)
# 将参数加载到网络中
load_param_into_net(net, param_dict)
# 重新定义优化器函数
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={"accuracy"})
load_checkpoint
方法会把参数文件中的网络参数加载到字典param_dict
中。load_param_into_net
方法会把字典param_dict
中的参数加载到网络或者优化器中,加载后,网络中的参数就是CheckPoint保存的。
模型验证
在上述模块把参数加载到网络中之后,针对推理场景,可以调用eval
函数进行推理验证。示例代码如下:
[8]:
# 调用eval()进行推理
download_eval = Mnist(path="./mnist", split="test", batch_size=32, resize=32, download=True)
dataset_eval = download_eval.run()
acc = model.eval(dataset_eval)
print("{}".format(acc))
{'accuracy': 0.9866786858974359}
用于迁移学习
针对任务中断再训练及微调(Fine-tuning)场景,可以调用train
函数进行迁移学习。示例代码如下:
[9]:
# 定义训练数据集
download_train = Mnist(path="./mnist", split="train", batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)
dataset_train = download_train.run()
# 网络模型调用train()继续进行训练
model.train(epoch_num, dataset_train, callbacks=[LossMonitor(0.01, 1875)])
Epoch:[ 0/ 5], step:[ 1875/ 1875], loss:[0.000/0.010], time:2.193 ms, lr:0.01000
Epoch time: 4106.620 ms, per step time: 2.190 ms, avg loss: 0.010
Epoch:[ 1/ 5], step:[ 1875/ 1875], loss:[0.000/0.009], time:2.036 ms, lr:0.01000
Epoch time: 4233.697 ms, per step time: 2.258 ms, avg loss: 0.009
Epoch:[ 2/ 5], step:[ 1875/ 1875], loss:[0.000/0.010], time:2.045 ms, lr:0.01000
Epoch time: 4246.248 ms, per step time: 2.265 ms, avg loss: 0.010
Epoch:[ 3/ 5], step:[ 1875/ 1875], loss:[0.000/0.008], time:2.001 ms, lr:0.01000
Epoch time: 4235.036 ms, per step time: 2.259 ms, avg loss: 0.008
Epoch:[ 4/ 5], step:[ 1875/ 1875], loss:[0.002/0.008], time:2.039 ms, lr:0.01000
Epoch time: 4354.482 ms, per step time: 2.322 ms, avg loss: 0.008