{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 回调机制 Callback\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/advanced/train/mindspore_callback.ipynb) \n", "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/advanced/train/mindspore_callback.py) \n", "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/tutorials/source_zh_cn/advanced/train/callback.ipynb)\n", "\n", "在深度学习训练过程中,为及时掌握网络模型的训练状态、实时观察网络模型各参数的变化情况和实现训练过程中用户自定义的一些操作,MindSpore提供了回调机制(Callback)来实现上述功能。\n", "\n", "Callback回调机制一般用在网络模型训练过程`Model.train`中,MindSpore的`Model`会按照Callback列表`callbacks`顺序执行回调函数,用户可以通过设置不同的回调类来实现在训练过程中或者训练后执行的功能。\n", "\n", "> 更多内置回调类的信息及使用方式请参考[API文档](https://www.mindspore.cn/docs/zh-CN/r1.8/api_python/mindspore/mindspore.Callback.html#mindspore.Callback)。\n", "\n", "## Callback介绍和使用\n", "\n", "当聊到回调Callback的时候,大部分用户都会觉得很难理解,是不是需要堆栈或者特殊的调度方式,实际上我们简单的理解回调:\n", "\n", "假设函数A有一个参数,这个参数是个函数B,当函数A执行完以后执行函数B,那么这个过程就叫回调。\n", "\n", "`Callback`是回调的意思,MindSpore中的回调函数实际上不是一个函数而是一个类,用户可以使用回调机制来**观察训练过程中网络内部的状态和相关信息,或在特定时期执行特定动作**。\n", "\n", "例如监控损失函数Loss、保存模型参数ckpt、动态调整参数lr、提前终止训练任务等。\n", "\n", "下面以基于MNIST数据集训练LeNet-5网络模型为例,介绍几种常用的MindSpore内置回调类。\n", "\n", "首先需要下载并处理MNIST数据,构建LeNet-5网络模型,示例代码如下:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import mindspore.nn as nn\n", "import mindspore as ms\n", "from mindvision.classification.dataset import Mnist\n", "from mindvision.classification.models import lenet\n", "\n", "download_train = Mnist(path=\"./mnist\", split=\"train\", download=True)\n", "dataset_train = download_train.run()\n", "\n", "network = lenet(num_classes=10, pretrained=False)\n", "net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n", "net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)\n", "\n", "# 定义网络模型\n", "model = ms.Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={\"Accuracy\": nn.Accuracy()})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "回调机制的使用方法,在`model.train`方法中传入`Callback`对象,它可以是一个`Callback`列表,示例代码如下,其中[ModelCheckpoint](https://mindspore.cn/docs/zh-CN/r1.8/api_python/mindspore/mindspore.ModelCheckpoint.html#mindspore.ModelCheckpoint)和[LossMonitor](https://mindspore.cn/docs/zh-CN/r1.8/api_python/mindspore/mindspore.LossMonitor.html#mindspore.LossMonitor)是MindSpore提供的回调类。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 1875, loss is 0.257398396730423\n", "epoch: 2 step: 1875, loss is 0.04801357910037041\n", "epoch: 3 step: 1875, loss is 0.028765171766281128\n", "epoch: 4 step: 1875, loss is 0.008372672833502293\n", "epoch: 5 step: 1875, loss is 0.0016194271156564355\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "# 定义回调类\n", "ckpt_cb = ms.ModelCheckpoint()\n", "loss_cb = ms.LossMonitor(1875)\n", "\n", "model.train(5, dataset_train, callbacks=[ckpt_cb, loss_cb])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 常用的内置回调函数\n", "\n", "MindSpore提供`Callback`能力,支持用户在训练/推理的特定阶段,插入自定义的操作。\n", "\n", "### ModelCheckpoint\n", "\n", "为了保存训练后的网络模型和参数,方便进行再推理或再训练,MindSpore提供了[ModelCheckpoint](https://mindspore.cn/docs/zh-CN/r1.8/api_python/mindspore/mindspore.ModelCheckpoint.html#mindspore.ModelCheckpoint)接口,一般与配置保存信息接口[CheckpointConfig](https://mindspore.cn/docs/zh-CN/r1.8/api_python/mindspore/mindspore.CheckpointConfig.html#mindspore.CheckpointConfig)配合使用。\n", "\n", "下面我们通过一段示例代码来说明如何保存训练后的网络模型和参数:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import mindspore as ms\n", "\n", "# 设置保存模型的配置信息\n", "config_ck = ms.CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)\n", "# 实例化保存模型回调接口,定义保存路径和前缀名\n", "ckpoint = ms.ModelCheckpoint(prefix=\"lenet\", directory=\"./lenet\", config=config_ck)\n", "\n", "# 开始训练,加载保存模型和参数回调函数\n", "model.train(1, dataset_train, callbacks=[ckpoint])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "上面代码运行后,生成的Checkpoint文件目录结构如下:\n", "\n", "```text\n", "./lenet/\n", "├── lenet-1_1875.ckpt # 保存参数文件\n", "└── lenet-graph.meta # 编译后的计算图\n", "```\n", "\n", "### LossMonitor\n", "\n", "为了监控训练过程中的损失函数值Loss变化情况,观察训练过程中每个epoch、每个step的运行时间,[MindSpore Vision](https://mindspore.cn/vision/docs/zh-CN/r0.1/index.html)提供了`LossMonitor`接口(与MindSpore提供的`LossMonitor`接口有区别)。\n", "\n", "下面我们通过示例代码说明:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch:[ 0/ 5], step:[ 375/ 1875], loss:[0.041/0.023], time:0.670 ms, lr:0.01000\n", "Epoch:[ 0/ 5], step:[ 750/ 1875], loss:[0.002/0.023], time:0.723 ms, lr:0.01000\n", "Epoch:[ 0/ 5], step:[ 1125/ 1875], loss:[0.006/0.023], time:0.662 ms, lr:0.01000\n", "Epoch:[ 0/ 5], step:[ 1500/ 1875], loss:[0.000/0.024], time:0.664 ms, lr:0.01000\n", "Epoch:[ 0/ 5], step:[ 1875/ 1875], loss:[0.009/0.024], time:0.661 ms, lr:0.01000\n", "Epoch time: 1759.622 ms, per step time: 0.938 ms, avg loss: 0.024\n", "Epoch:[ 1/ 5], step:[ 375/ 1875], loss:[0.001/0.020], time:0.658 ms, lr:0.01000\n", "Epoch:[ 1/ 5], step:[ 750/ 1875], loss:[0.002/0.021], time:0.661 ms, lr:0.01000\n", "Epoch:[ 1/ 5], step:[ 1125/ 1875], loss:[0.000/0.021], time:0.663 ms, lr:0.01000\n", "Epoch:[ 1/ 5], step:[ 1500/ 1875], loss:[0.048/0.022], time:0.655 ms, lr:0.01000\n", "Epoch:[ 1/ 5], step:[ 1875/ 1875], loss:[0.018/0.022], time:0.646 ms, lr:0.01000\n", "Epoch time: 1551.506 ms, per step time: 0.827 ms, avg loss: 0.022\n", "Epoch:[ 2/ 5], step:[ 375/ 1875], loss:[0.001/0.017], time:0.674 ms, lr:0.01000\n", "Epoch:[ 2/ 5], step:[ 750/ 1875], loss:[0.001/0.018], time:0.669 ms, lr:0.01000\n", "Epoch:[ 2/ 5], step:[ 1125/ 1875], loss:[0.004/0.019], time:0.683 ms, lr:0.01000\n", "Epoch:[ 2/ 5], step:[ 1500/ 1875], loss:[0.003/0.020], time:0.657 ms, lr:0.01000\n", "Epoch:[ 2/ 5], step:[ 1875/ 1875], loss:[0.041/0.019], time:1.447 ms, lr:0.01000\n", "Epoch time: 1616.589 ms, per step time: 0.862 ms, avg loss: 0.019\n", "Epoch:[ 3/ 5], step:[ 375/ 1875], loss:[0.000/0.011], time:0.672 ms, lr:0.01000\n", "Epoch:[ 3/ 5], step:[ 750/ 1875], loss:[0.001/0.013], time:0.687 ms, lr:0.01000\n", "Epoch:[ 3/ 5], step:[ 1125/ 1875], loss:[0.016/0.014], time:0.665 ms, lr:0.01000\n", "Epoch:[ 3/ 5], step:[ 1500/ 1875], loss:[0.001/0.015], time:0.674 ms, lr:0.01000\n", "Epoch:[ 3/ 5], step:[ 1875/ 1875], loss:[0.001/0.015], time:0.666 ms, lr:0.01000\n", "Epoch time: 1586.809 ms, per step time: 0.846 ms, avg loss: 0.015\n", "Epoch:[ 4/ 5], step:[ 375/ 1875], loss:[0.000/0.008], time:0.671 ms, lr:0.01000\n", "Epoch:[ 4/ 5], step:[ 750/ 1875], loss:[0.000/0.013], time:0.701 ms, lr:0.01000\n", "Epoch:[ 4/ 5], step:[ 1125/ 1875], loss:[0.009/0.015], time:0.666 ms, lr:0.01000\n", "Epoch:[ 4/ 5], step:[ 1500/ 1875], loss:[0.008/0.015], time:0.941 ms, lr:0.01000\n", "Epoch:[ 4/ 5], step:[ 1875/ 1875], loss:[0.008/0.015], time:0.661 ms, lr:0.01000\n", "Epoch time: 1584.785 ms, per step time: 0.845 ms, avg loss: 0.015\n" ] } ], "source": [ "from mindvision.engine.callback import LossMonitor\n", "\n", "# 开始训练,加载保存模型和参数回调函数,LossMonitor的入参0.01为学习率,375为步长\n", "model.train(5, dataset_train, callbacks=[LossMonitor(0.01, 375)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印结果可以看出,[MindSpore Vision套件](https://mindspore.cn/vision/docs/zh-CN/r0.1/index.html)提供的`LossMonitor`接口打印信息更加详细。由于步长设置的是375,所以每375个step会打印一条,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。\n", "\n", "### ValAccMonitor\n", "\n", "为了在训练过程中保存精度最优的网络模型和参数,需要边训练边验证,MindSpore Vision提供了`ValAccMonitor`接口。\n", "\n", "下面我们通过一段示例来介绍:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------\n", "Epoch: [ 1 / 1], Train Loss: [0.000], Accuracy: 0.988\n", "================================================================================\n", "End of validation the best Accuracy is: 0.988, save the best ckpt file in ./best.ckpt\n" ] } ], "source": [ "from mindvision.engine.callback import ValAccMonitor\n", "\n", "download_eval = Mnist(path=\"./mnist\", split=\"test\", download=True)\n", "dataset_eval = download_eval.run()\n", "\n", "# 开始训练,加载保存模型和参数回调函数\n", "model.train(1, dataset_train, callbacks=[ValAccMonitor(model, dataset_eval, num_epochs=1)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "上面代码执行后,精度最优的网络模型和参数会被保存在当前目录下,文件名为\"best.ckpt\"。\n", "\n", "## 自定义回调机制\n", "\n", "MindSpore不仅有功能强大的内置回调函数,当用户有自己的特殊需求时,还可以基于`Callback`基类自定义回调类。\n", "\n", "用户可以基于`Callback`基类,根据自身的需求,实现自定义`Callback`。`Callback`基类定义如下所示:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class Callback():\n", " \"\"\"Callback base class\"\"\"\n", " def begin(self, run_context):\n", " \"\"\"Called once before the network executing.\"\"\"\n", " pass # pylint: disable=W0107\n", "\n", " def epoch_begin(self, run_context):\n", " \"\"\"Called before each epoch beginning.\"\"\"\n", " pass # pylint: disable=W0107\n", "\n", " def epoch_end(self, run_context):\n", " \"\"\"Called after each epoch finished.\"\"\"\n", " pass # pylint: disable=W0107\n", "\n", " def step_begin(self, run_context):\n", " \"\"\"Called before each step beginning.\"\"\"\n", " pass # pylint: disable=W0107\n", "\n", " def step_end(self, run_context):\n", " \"\"\"Called after each step finished.\"\"\"\n", " pass # pylint: disable=W0107\n", "\n", " def end(self, run_context):\n", " \"\"\"Called once after network training.\"\"\"\n", " pass # pylint: disable=W0107" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "回调机制可以把训练过程中的重要信息记录下来,通过把一个字典类型变量`RunContext.original_args()`,传递给Callback对象,使得用户可以在各个自定义的Callback中获取到相关属性,执行自定义操作,也可以自定义其他变量传递给`RunContext.original_args()`对象。\n", "\n", "`RunContext.original_args()`中的常用属性有:\n", "\n", "- epoch_num:训练的epoch的数量\n", "- batch_num:一个epoch中step的数量\n", "- cur_epoch_num:当前的epoch数\n", "- cur_step_num:当前的step数\n", "\n", "- loss_fn:损失函数\n", "- optimizer:优化器\n", "- train_network:训练的网络\n", "- train_dataset:训练的数据集\n", "- net_outputs:网络的输出结果\n", "\n", "- parallel_mode:并行模式\n", "- list_callback:所有的Callback函数\n", "\n", "通过下面两个场景,我们可以增加对自定义Callback回调机制功能的了解。\n", "\n", "### 自定义终止训练\n", "\n", "实现在规定时间内终止训练功能。用户可以设定时间阈值,当训练时间达到这个阈值后就终止训练过程。\n", "\n", "下面代码中,通过`run_context.original_args`方法可以获取到`cb_params`字典,字典里会包含前文描述的主要属性信息。\n", "\n", "同时可以对字典内的值进行修改和添加,在`begin`函数中定义一个`init_time`对象传递给`cb_params`字典。每个数据迭代结束`step_end`之后会进行判断,当训练时间大于设置的时间阈值时,会向run_context传递终止训练的信号,提前终止训练,并打印当前的epoch、step、loss的值。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Begin training, time is: 1648452437.2004516\n", "Epoch:[ 0/ 5], step:[ 1875/ 1875], loss:[0.011/0.012], time:0.678 ms, lr:0.01000\n", "Epoch time: 1603.104 ms, per step time: 0.855 ms, avg loss: 0.012\n", "Epoch:[ 1/ 5], step:[ 1875/ 1875], loss:[0.000/0.011], time:0.688 ms, lr:0.01000\n", "Epoch time: 1602.716 ms, per step time: 0.855 ms, avg loss: 0.011\n", "End training, time: 1648452441.20081 ,epoch: 3 ,step: 4673 ,loss: 0.014888153\n", "Epoch time: 792.901 ms, per step time: 0.423 ms, avg loss: 0.010\n" ] } ], "source": [ "import time\n", "import mindspore as ms\n", "\n", "class StopTimeMonitor(ms.Callback):\n", "\n", " def __init__(self, run_time):\n", " \"\"\"定义初始化过程\"\"\"\n", " super(StopTimeMonitor, self).__init__()\n", " self.run_time = run_time # 定义执行时间\n", "\n", " def begin(self, run_context):\n", " \"\"\"开始训练时的操作\"\"\"\n", " cb_params = run_context.original_args()\n", " cb_params.init_time = time.time() # 获取当前时间戳作为开始训练时间\n", " print(\"Begin training, time is:\", cb_params.init_time)\n", "\n", " def step_end(self, run_context):\n", " \"\"\"每个step结束后执行的操作\"\"\"\n", " cb_params = run_context.original_args()\n", " epoch_num = cb_params.cur_epoch_num # 获取epoch值\n", " step_num = cb_params.cur_step_num # 获取step值\n", " loss = cb_params.net_outputs # 获取损失值loss\n", " cur_time = time.time() # 获取当前时间戳\n", "\n", " if (cur_time - cb_params.init_time) > self.run_time:\n", " print(\"End training, time:\", cur_time, \",epoch:\", epoch_num, \",step:\", step_num, \",loss:\", loss)\n", " run_context.request_stop() # 停止训练\n", "\n", "download_train = Mnist(path=\"./mnist\", split=\"train\", download=True)\n", "dataset = download_train.run()\n", "model.train(5, dataset, callbacks=[LossMonitor(0.01, 1875), StopTimeMonitor(4)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从上面的打印结果可以看出,当第3个epoch的第4673个step执行完时,运行时间到达了阈值并结束了训练。\n", "\n", "### 自定义阈值保存模型\n", "\n", "该回调机制实现当loss小于设定的阈值时,保存网络模型权重ckpt文件。\n", "\n", "示例代码如下:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved checkpoint, loss:0.0000001, current step num: 253.\n", "Saved checkpoint, loss:0.0000005, current step num: 258.\n", "Saved checkpoint, loss:0.0000001, current step num: 265.\n", "Saved checkpoint, loss:0.0000000, current step num: 332.\n", "Saved checkpoint, loss:0.0000003, current step num: 358.\n", "Saved checkpoint, loss:0.0000003, current step num: 380.\n", "Saved checkpoint, loss:0.0000003, current step num: 395.\n", "Saved checkpoint, loss:0.0000005, current step num:1151.\n", "Saved checkpoint, loss:0.0000005, current step num:1358.\n", "Saved checkpoint, loss:0.0000002, current step num:1524.\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "# 定义保存ckpt文件的回调接口\n", "class SaveCkptMonitor(ms.Callback):\n", " \"\"\"定义初始化过程\"\"\"\n", "\n", " def __init__(self, loss):\n", " super(SaveCkptMonitor, self).__init__()\n", " self.loss = loss # 定义损失值阈值\n", "\n", " def step_end(self, run_context):\n", " \"\"\"定义step结束时的执行操作\"\"\"\n", " cb_params = run_context.original_args()\n", " cur_loss = cb_params.net_outputs.asnumpy() # 获取当前损失值\n", "\n", " # 如果当前损失值小于设定的阈值就停止训练\n", " if cur_loss < self.loss:\n", " # 自定义保存文件名\n", " file_name = str(cb_params.cur_epoch_num) + \"_\" + str(cb_params.cur_step_num) + \".ckpt\"\n", " # 保存网络模型\n", " ms.save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)\n", " print(\"Saved checkpoint, loss:{:8.7f}, current step num:{:4}.\".format(cur_loss, cb_params.cur_step_num))\n", "\n", "model.train(1, dataset_train, callbacks=[SaveCkptMonitor(5e-7)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "保存目录结构如下:\n", "\n", "```text\n", "./\n", "├── 1_253.ckpt\n", "├── 1_258.ckpt\n", "├── 1_265.ckpt\n", "├── 1_332.ckpt\n", "├── 1_358.ckpt\n", "├── 1_380.ckpt\n", "├── 1_395.ckpt\n", "├── 1_1151.ckpt\n", "├── 1_1358.ckpt\n", "├── 1_1524.ckpt\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 5 }