{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 保存与加载\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/beginner/mindspore_save_load.ipynb) \n", "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/beginner/mindspore_save_load.py) \n", "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/tutorials/source_zh_cn/beginner/save_load.ipynb)\n", "\n", "上一章节的内容里面主要是介绍了如何调整超参数,并进行网络模型训练。训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型部署和推理,本章节我们开始学习如何保存与加载模型。\n", "\n", "## 模型训练\n", "\n", "下面我们以MNIST数据集为例,介绍网络模型的保存与加载方式。首先,我们需要获取MNIST数据集并训练模型,示例代码如下:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch:[ 0/ 10], step:[ 1875/ 1875], loss:[0.148/1.210], time:2.021 ms, lr:0.01000\n", "Epoch time: 4251.808 ms, per step time: 2.268 ms, avg loss: 1.210\n", "Epoch:[ 1/ 10], step:[ 1875/ 1875], loss:[0.049/0.081], time:2.048 ms, lr:0.01000\n", "Epoch time: 4301.405 ms, per step time: 2.294 ms, avg loss: 0.081\n", "Epoch:[ 2/ 10], step:[ 1875/ 1875], loss:[0.014/0.050], time:1.992 ms, lr:0.01000\n", "Epoch time: 4278.799 ms, per step time: 2.282 ms, avg loss: 0.050\n", "Epoch:[ 3/ 10], step:[ 1875/ 1875], loss:[0.035/0.038], time:2.254 ms, lr:0.01000\n", "Epoch time: 4380.553 ms, per step time: 2.336 ms, avg loss: 0.038\n", "Epoch:[ 4/ 10], step:[ 1875/ 1875], loss:[0.130/0.031], time:1.932 ms, lr:0.01000\n", "Epoch time: 4287.547 ms, per step time: 2.287 ms, avg loss: 0.031\n", "Epoch:[ 5/ 10], step:[ 1875/ 1875], loss:[0.003/0.027], time:1.981 ms, lr:0.01000\n", "Epoch time: 4377.000 ms, per step time: 2.334 ms, avg loss: 0.027\n", "Epoch:[ 6/ 10], step:[ 1875/ 1875], loss:[0.004/0.023], time:2.167 ms, lr:0.01000\n", "Epoch time: 4687.250 ms, per step time: 2.500 ms, avg loss: 0.023\n", "Epoch:[ 7/ 10], step:[ 1875/ 1875], loss:[0.004/0.020], time:2.226 ms, lr:0.01000\n", "Epoch time: 4685.529 ms, per step time: 2.499 ms, avg loss: 0.020\n", "Epoch:[ 8/ 10], step:[ 1875/ 1875], loss:[0.000/0.016], time:2.275 ms, lr:0.01000\n", "Epoch time: 4651.129 ms, per step time: 2.481 ms, avg loss: 0.016\n", "Epoch:[ 9/ 10], step:[ 1875/ 1875], loss:[0.022/0.015], time:2.177 ms, lr:0.01000\n", "Epoch time: 4623.760 ms, per step time: 2.466 ms, avg loss: 0.015\n" ] } ], "source": [ "import mindspore.nn as nn\n", "import mindspore as ms\n", "\n", "from mindvision.classification.dataset import Mnist\n", "from mindvision.classification.models import lenet\n", "from mindvision.engine.callback import LossMonitor\n", "\n", "epochs = 10 # 训练轮次\n", "\n", "# 1. 构建数据集\n", "download_train = Mnist(path=\"./mnist\", split=\"train\", batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)\n", "dataset_train = download_train.run()\n", "\n", "# 2. 定义神经网络\n", "network = lenet(num_classes=10, pretrained=False)\n", "# 3.1 定义损失函数\n", "net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n", "# 3.2 定义优化器函数\n", "net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n", "# 3.3 初始化模型参数\n", "model = ms.Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'accuracy'})\n", "\n", "# 4. 对神经网络执行训练\n", "model.train(epochs, dataset_train, callbacks=[LossMonitor(0.01, 1875)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印结果可以看出,随着训练轮次的增加,损失值趋于收敛。\n", "\n", "## 保存模型\n", "\n", "在训练完网络完成后,下面我们将网络模型以文件的形式保存下来。保存模型的接口有主要2种:\n", "\n", "1. 简单的对网络模型进行保存,可以在训练前后进行保存。这种方式的优点是接口简单易用,但是只保留执行命令时候的网络模型状态;\n", "2. 在网络模型训练中进行保存,MindSpore在网络模型训练的过程中,自动保存训练时候设定好的epoch数和step数的参数,也就是把模型训练过程中产生的中间权重参数也保存下来,方便进行网络微调和停止训练;\n", "\n", "### 直接保存模型\n", "\n", "使用MindSpore提供的save_checkpoint保存模型,传入网络和保存路径:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import mindspore as ms\n", "\n", "# 定义的网络模型为net,一般在训练前或者训练后使用\n", "ms.save_checkpoint(network, \"./MyNet.ckpt\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "其中,`network`为训练网络,`\"./MyNet.ckpt\"`为网络模型的保存路径。\n", "\n", "### 训练过程中保存模型\n", "\n", "在模型训练的过程中,使用`model.train`里面的`callbacks`参数传入保存模型的对象 [ModelCheckpoint](https://mindspore.cn/docs/zh-CN/r1.8/api_python/mindspore/mindspore.ModelCheckpoint.html#mindspore.ModelCheckpoint)(一般与[CheckpointConfig](https://mindspore.cn/docs/zh-CN/r1.8/api_python/mindspore/mindspore.CheckpointConfig.html#mindspore.CheckpointConfig)配合使用),可以保存模型参数,生成CheckPoint(简称ckpt)文件。\n", "\n", "用户可以根据具体需求通过设置`CheckpointConfig`来对CheckPoint策略进行配置。具体用法如下:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import mindspore as ms\n", "\n", "# 设置epoch_num数量\n", "epoch_num = 5\n", "\n", "# 设置模型保存参数\n", "config_ck = ms.CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)\n", "\n", "# 应用模型保存参数\n", "ckpoint = ms.ModelCheckpoint(prefix=\"lenet\", directory=\"./lenet\", config=config_ck)\n", "model.train(epoch_num, dataset_train, callbacks=[ckpoint])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "上述代码中,首先需要初始化一个`CheckpointConfig`类对象,用来设置保存策略。\n", "\n", "- `save_checkpoint_steps`表示每隔多少个step保存一次。\n", "- `keep_checkpoint_max`表示最多保留CheckPoint文件的数量。\n", "- `prefix`表示生成CheckPoint文件的前缀名。\n", "- `directory`表示存放文件的目录。\n", "\n", "创建一个`ModelCheckpoint`对象把它传递给`model.train`方法,就可以在训练过程中使用CheckPoint功能了。\n", "\n", "生成的CheckPoint文件如下:\n", "\n", "```text\n", "lenet-graph.meta # 编译后的计算图\n", "lenet-1_1875.ckpt # CheckPoint文件后缀名为'.ckpt'\n", "lenet-2_1875.ckpt # 文件的命名方式表示保存参数所在的epoch和step数,这里为第2个epoch的第1875个step的模型参数\n", "lenet-3_1875.ckpt # 表示保存的是第3个epoch的第1875个step的模型参数\n", "...\n", "```\n", "\n", "如果用户使用相同的前缀名,运行多次训练脚本,可能会生成同名CheckPoint文件。MindSpore为方便用户区分每次生成的文件,会在用户定义的前缀后添加\"_\"和数字加以区分。如果想要删除`.ckpt`文件时,请同步删除`.meta` 文件。\n", "\n", "例:`lenet_3-2_1875.ckpt` 表示运行第3次脚本生成的第2个epoch的第1875个step的CheckPoint文件。\n", "\n", "## 加载模型\n", "\n", "要加载模型权重,需要先创建相同模型的实例,然后使用`load_checkpoint`和`load_param_into_net`方法加载参数。\n", "\n", "示例代码如下:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import mindspore as ms\n", "\n", "from mindvision.classification.dataset import Mnist\n", "from mindvision.classification.models import lenet\n", "\n", "# 将模型参数存入parameter的字典中,这里加载的是上面训练过程中保存的模型参数\n", "param_dict = ms.load_checkpoint(\"./lenet/lenet-5_1875.ckpt\")\n", "\n", "# 重新定义一个LeNet神经网络\n", "net = lenet(num_classes=10, pretrained=False)\n", "\n", "# 将参数加载到网络中\n", "ms.load_param_into_net(net, param_dict)\n", "\n", "# 重新定义优化器函数\n", "net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)\n", "\n", "model = ms.Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={\"accuracy\"})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- `load_checkpoint`方法会把参数文件中的网络参数加载到字典`param_dict`中。\n", "- `load_param_into_net`方法会把字典`param_dict`中的参数加载到网络或者优化器中,加载后,网络中的参数就是CheckPoint保存的。\n", "\n", "### 模型验证\n", "\n", "在上述模块把参数加载到网络中之后,针对推理场景,可以调用`eval`函数进行推理验证。示例代码如下:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'accuracy': 0.9866786858974359}\n" ] } ], "source": [ "# 调用eval()进行推理\n", "download_eval = Mnist(path=\"./mnist\", split=\"test\", batch_size=32, resize=32, download=True)\n", "dataset_eval = download_eval.run()\n", "acc = model.eval(dataset_eval)\n", "\n", "print(\"{}\".format(acc))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 用于迁移学习\n", "\n", "针对任务中断再训练及微调(Fine-tuning)场景,可以调用`train`函数进行迁移学习。示例代码如下:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch:[ 0/ 5], step:[ 1875/ 1875], loss:[0.000/0.010], time:2.193 ms, lr:0.01000\n", "Epoch time: 4106.620 ms, per step time: 2.190 ms, avg loss: 0.010\n", "Epoch:[ 1/ 5], step:[ 1875/ 1875], loss:[0.000/0.009], time:2.036 ms, lr:0.01000\n", "Epoch time: 4233.697 ms, per step time: 2.258 ms, avg loss: 0.009\n", "Epoch:[ 2/ 5], step:[ 1875/ 1875], loss:[0.000/0.010], time:2.045 ms, lr:0.01000\n", "Epoch time: 4246.248 ms, per step time: 2.265 ms, avg loss: 0.010\n", "Epoch:[ 3/ 5], step:[ 1875/ 1875], loss:[0.000/0.008], time:2.001 ms, lr:0.01000\n", "Epoch time: 4235.036 ms, per step time: 2.259 ms, avg loss: 0.008\n", "Epoch:[ 4/ 5], step:[ 1875/ 1875], loss:[0.002/0.008], time:2.039 ms, lr:0.01000\n", "Epoch time: 4354.482 ms, per step time: 2.322 ms, avg loss: 0.008\n" ] } ], "source": [ "# 定义训练数据集\n", "download_train = Mnist(path=\"./mnist\", split=\"train\", batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)\n", "dataset_train = download_train.run()\n", "\n", "# 网络模型调用train()继续进行训练\n", "model.train(epoch_num, dataset_train, callbacks=[LossMonitor(0.01, 1875)])" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }