回调机制 Callback

下载Notebook下载样例代码查看源文件

在深度学习训练过程中,为及时掌握网络模型的训练状态、实时观察网络模型各参数的变化情况和实现训练过程中用户自定义的一些操作,MindSpore提供了回调机制(Callback)来实现上述功能。

Callback回调机制一般用在网络模型训练过程Model.train中,MindSpore的Model会按照Callback列表callbacks顺序执行回调函数,用户可以通过设置不同的回调类来实现在训练过程中或者训练后执行的功能。

更多内置回调类的信息及使用方式请参考API文档

Callback介绍和使用

当聊到回调Callback的时候,大部分用户都会觉得很难理解,是不是需要堆栈或者特殊的调度方式,实际上我们简单的理解回调:

假设函数A有一个参数,这个参数是个函数B,当函数A执行完以后执行函数B,那么这个过程就叫回调。

Callback是回调的意思,MindSpore中的回调函数实际上不是一个函数而是一个类,用户可以使用回调机制来观察训练过程中网络内部的状态和相关信息,或在特定时期执行特定动作

例如监控损失函数Loss、保存模型参数ckpt、动态调整参数lr、提前终止训练任务等。

下面以基于MNIST数据集训练LeNet-5网络模型为例,介绍几种常用的MindSpore内置回调类。

首先需要下载并处理MNIST数据,构建LeNet-5网络模型,示例代码如下:

[1]:
import mindspore.nn as nn
import mindspore as ms
from mindvision.classification.dataset import Mnist
from mindvision.classification.models import lenet

download_train = Mnist(path="./mnist", split="train", download=True)
dataset_train = download_train.run()

network = lenet(num_classes=10, pretrained=False)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

# 定义网络模型
model = ms.Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={"Accuracy": nn.Accuracy()})

回调机制的使用方法,在model.train方法中传入Callback对象,它可以是一个Callback列表,示例代码如下,其中ModelCheckpointLossMonitor是MindSpore提供的回调类。

[2]:
import mindspore as ms

# 定义回调类
ckpt_cb = ms.ModelCheckpoint()
loss_cb = ms.LossMonitor(1875)

model.train(5, dataset_train, callbacks=[ckpt_cb, loss_cb])
epoch: 1 step: 1875, loss is 0.257398396730423
epoch: 2 step: 1875, loss is 0.04801357910037041
epoch: 3 step: 1875, loss is 0.028765171766281128
epoch: 4 step: 1875, loss is 0.008372672833502293
epoch: 5 step: 1875, loss is 0.0016194271156564355

常用的内置回调函数

MindSpore提供Callback能力,支持用户在训练/推理的特定阶段,插入自定义的操作。

ModelCheckpoint

为了保存训练后的网络模型和参数,方便进行再推理或再训练,MindSpore提供了ModelCheckpoint接口,一般与配置保存信息接口CheckpointConfig配合使用。

下面我们通过一段示例代码来说明如何保存训练后的网络模型和参数:

[3]:
import mindspore as ms

# 设置保存模型的配置信息
config_ck = ms.CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
# 实例化保存模型回调接口,定义保存路径和前缀名
ckpoint = ms.ModelCheckpoint(prefix="lenet", directory="./lenet", config=config_ck)

# 开始训练,加载保存模型和参数回调函数
model.train(1, dataset_train, callbacks=[ckpoint])

上面代码运行后,生成的Checkpoint文件目录结构如下:

./lenet/
├── lenet-1_1875.ckpt # 保存参数文件
└── lenet-graph.meta # 编译后的计算图

LossMonitor

为了监控训练过程中的损失函数值Loss变化情况,观察训练过程中每个epoch、每个step的运行时间,MindSpore Vision提供了LossMonitor接口(与MindSpore提供的LossMonitor接口有区别)。

下面我们通过示例代码说明:

[4]:
from mindvision.engine.callback import LossMonitor

# 开始训练,加载保存模型和参数回调函数,LossMonitor的入参0.01为学习率,375为步长
model.train(5, dataset_train, callbacks=[LossMonitor(0.01, 375)])
Epoch:[  0/  5], step:[  375/ 1875], loss:[0.041/0.023], time:0.670 ms, lr:0.01000
Epoch:[  0/  5], step:[  750/ 1875], loss:[0.002/0.023], time:0.723 ms, lr:0.01000
Epoch:[  0/  5], step:[ 1125/ 1875], loss:[0.006/0.023], time:0.662 ms, lr:0.01000
Epoch:[  0/  5], step:[ 1500/ 1875], loss:[0.000/0.024], time:0.664 ms, lr:0.01000
Epoch:[  0/  5], step:[ 1875/ 1875], loss:[0.009/0.024], time:0.661 ms, lr:0.01000
Epoch time: 1759.622 ms, per step time: 0.938 ms, avg loss: 0.024
Epoch:[  1/  5], step:[  375/ 1875], loss:[0.001/0.020], time:0.658 ms, lr:0.01000
Epoch:[  1/  5], step:[  750/ 1875], loss:[0.002/0.021], time:0.661 ms, lr:0.01000
Epoch:[  1/  5], step:[ 1125/ 1875], loss:[0.000/0.021], time:0.663 ms, lr:0.01000
Epoch:[  1/  5], step:[ 1500/ 1875], loss:[0.048/0.022], time:0.655 ms, lr:0.01000
Epoch:[  1/  5], step:[ 1875/ 1875], loss:[0.018/0.022], time:0.646 ms, lr:0.01000
Epoch time: 1551.506 ms, per step time: 0.827 ms, avg loss: 0.022
Epoch:[  2/  5], step:[  375/ 1875], loss:[0.001/0.017], time:0.674 ms, lr:0.01000
Epoch:[  2/  5], step:[  750/ 1875], loss:[0.001/0.018], time:0.669 ms, lr:0.01000
Epoch:[  2/  5], step:[ 1125/ 1875], loss:[0.004/0.019], time:0.683 ms, lr:0.01000
Epoch:[  2/  5], step:[ 1500/ 1875], loss:[0.003/0.020], time:0.657 ms, lr:0.01000
Epoch:[  2/  5], step:[ 1875/ 1875], loss:[0.041/0.019], time:1.447 ms, lr:0.01000
Epoch time: 1616.589 ms, per step time: 0.862 ms, avg loss: 0.019
Epoch:[  3/  5], step:[  375/ 1875], loss:[0.000/0.011], time:0.672 ms, lr:0.01000
Epoch:[  3/  5], step:[  750/ 1875], loss:[0.001/0.013], time:0.687 ms, lr:0.01000
Epoch:[  3/  5], step:[ 1125/ 1875], loss:[0.016/0.014], time:0.665 ms, lr:0.01000
Epoch:[  3/  5], step:[ 1500/ 1875], loss:[0.001/0.015], time:0.674 ms, lr:0.01000
Epoch:[  3/  5], step:[ 1875/ 1875], loss:[0.001/0.015], time:0.666 ms, lr:0.01000
Epoch time: 1586.809 ms, per step time: 0.846 ms, avg loss: 0.015
Epoch:[  4/  5], step:[  375/ 1875], loss:[0.000/0.008], time:0.671 ms, lr:0.01000
Epoch:[  4/  5], step:[  750/ 1875], loss:[0.000/0.013], time:0.701 ms, lr:0.01000
Epoch:[  4/  5], step:[ 1125/ 1875], loss:[0.009/0.015], time:0.666 ms, lr:0.01000
Epoch:[  4/  5], step:[ 1500/ 1875], loss:[0.008/0.015], time:0.941 ms, lr:0.01000
Epoch:[  4/  5], step:[ 1875/ 1875], loss:[0.008/0.015], time:0.661 ms, lr:0.01000
Epoch time: 1584.785 ms, per step time: 0.845 ms, avg loss: 0.015

从上面的打印结果可以看出,MindSpore Vision套件提供的LossMonitor接口打印信息更加详细。由于步长设置的是375,所以每375个step会打印一条,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。

ValAccMonitor

为了在训练过程中保存精度最优的网络模型和参数,需要边训练边验证,MindSpore Vision提供了ValAccMonitor接口。

下面我们通过一段示例来介绍:

[5]:
from mindvision.engine.callback import ValAccMonitor

download_eval = Mnist(path="./mnist", split="test", download=True)
dataset_eval = download_eval.run()

# 开始训练,加载保存模型和参数回调函数
model.train(1, dataset_train, callbacks=[ValAccMonitor(model, dataset_eval, num_epochs=1)])
--------------------
Epoch: [  1 /   1], Train Loss: [0.000], Accuracy:  0.988
================================================================================
End of validation the best Accuracy is:  0.988, save the best ckpt file in ./best.ckpt

上面代码执行后,精度最优的网络模型和参数会被保存在当前目录下,文件名为“best.ckpt”。

自定义回调机制

MindSpore不仅有功能强大的内置回调函数,当用户有自己的特殊需求时,还可以基于Callback基类自定义回调类。

用户可以基于Callback基类,根据自身的需求,实现自定义CallbackCallback基类定义如下所示:

[6]:
class Callback():
    """Callback base class"""
    def begin(self, run_context):
        """Called once before the network executing."""
        pass # pylint: disable=W0107

    def epoch_begin(self, run_context):
        """Called before each epoch beginning."""
        pass # pylint: disable=W0107

    def epoch_end(self, run_context):
        """Called after each epoch finished."""
        pass # pylint: disable=W0107

    def step_begin(self, run_context):
        """Called before each step beginning."""
        pass # pylint: disable=W0107

    def step_end(self, run_context):
        """Called after each step finished."""
        pass # pylint: disable=W0107

    def end(self, run_context):
        """Called once after network training."""
        pass # pylint: disable=W0107

回调机制可以把训练过程中的重要信息记录下来,通过把一个字典类型变量RunContext.original_args(),传递给Callback对象,使得用户可以在各个自定义的Callback中获取到相关属性,执行自定义操作,也可以自定义其他变量传递给RunContext.original_args()对象。

RunContext.original_args()中的常用属性有:

  • epoch_num:训练的epoch的数量

  • batch_num:一个epoch中step的数量

  • cur_epoch_num:当前的epoch数

  • cur_step_num:当前的step数

  • loss_fn:损失函数

  • optimizer:优化器

  • train_network:训练的网络

  • train_dataset:训练的数据集

  • net_outputs:网络的输出结果

  • parallel_mode:并行模式

  • list_callback:所有的Callback函数

通过下面两个场景,我们可以增加对自定义Callback回调机制功能的了解。

自定义终止训练

实现在规定时间内终止训练功能。用户可以设定时间阈值,当训练时间达到这个阈值后就终止训练过程。

下面代码中,通过run_context.original_args方法可以获取到cb_params字典,字典里会包含前文描述的主要属性信息。

同时可以对字典内的值进行修改和添加,在begin函数中定义一个init_time对象传递给cb_params字典。每个数据迭代结束step_end之后会进行判断,当训练时间大于设置的时间阈值时,会向run_context传递终止训练的信号,提前终止训练,并打印当前的epoch、step、loss的值。

[7]:
import time
import mindspore as ms

class StopTimeMonitor(ms.Callback):

    def __init__(self, run_time):
        """定义初始化过程"""
        super(StopTimeMonitor, self).__init__()
        self.run_time = run_time            # 定义执行时间

    def begin(self, run_context):
        """开始训练时的操作"""
        cb_params = run_context.original_args()
        cb_params.init_time = time.time()   # 获取当前时间戳作为开始训练时间
        print("Begin training, time is:", cb_params.init_time)

    def step_end(self, run_context):
        """每个step结束后执行的操作"""
        cb_params = run_context.original_args()
        epoch_num = cb_params.cur_epoch_num  # 获取epoch值
        step_num = cb_params.cur_step_num    # 获取step值
        loss = cb_params.net_outputs         # 获取损失值loss
        cur_time = time.time()               # 获取当前时间戳

        if (cur_time - cb_params.init_time) > self.run_time:
            print("End training, time:", cur_time, ",epoch:", epoch_num, ",step:", step_num, ",loss:", loss)
            run_context.request_stop()       # 停止训练

download_train = Mnist(path="./mnist", split="train", download=True)
dataset = download_train.run()
model.train(5, dataset, callbacks=[LossMonitor(0.01, 1875), StopTimeMonitor(4)])
Begin training, time is: 1648452437.2004516
Epoch:[  0/  5], step:[ 1875/ 1875], loss:[0.011/0.012], time:0.678 ms, lr:0.01000
Epoch time: 1603.104 ms, per step time: 0.855 ms, avg loss: 0.012
Epoch:[  1/  5], step:[ 1875/ 1875], loss:[0.000/0.011], time:0.688 ms, lr:0.01000
Epoch time: 1602.716 ms, per step time: 0.855 ms, avg loss: 0.011
End training, time: 1648452441.20081 ,epoch: 3 ,step: 4673 ,loss: 0.014888153
Epoch time: 792.901 ms, per step time: 0.423 ms, avg loss: 0.010

从上面的打印结果可以看出,当第3个epoch的第4673个step执行完时,运行时间到达了阈值并结束了训练。

自定义阈值保存模型

该回调机制实现当loss小于设定的阈值时,保存网络模型权重ckpt文件。

示例代码如下:

[8]:
import mindspore as ms

# 定义保存ckpt文件的回调接口
class SaveCkptMonitor(ms.Callback):
    """定义初始化过程"""

    def __init__(self, loss):
        super(SaveCkptMonitor, self).__init__()
        self.loss = loss  # 定义损失值阈值

    def step_end(self, run_context):
        """定义step结束时的执行操作"""
        cb_params = run_context.original_args()
        cur_loss = cb_params.net_outputs.asnumpy() # 获取当前损失值

        # 如果当前损失值小于设定的阈值就停止训练
        if cur_loss < self.loss:
            # 自定义保存文件名
            file_name = str(cb_params.cur_epoch_num) + "_" + str(cb_params.cur_step_num) + ".ckpt"
            # 保存网络模型
            ms.save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
            print("Saved checkpoint, loss:{:8.7f}, current step num:{:4}.".format(cur_loss, cb_params.cur_step_num))

model.train(1, dataset_train, callbacks=[SaveCkptMonitor(5e-7)])
Saved checkpoint, loss:0.0000001, current step num: 253.
Saved checkpoint, loss:0.0000005, current step num: 258.
Saved checkpoint, loss:0.0000001, current step num: 265.
Saved checkpoint, loss:0.0000000, current step num: 332.
Saved checkpoint, loss:0.0000003, current step num: 358.
Saved checkpoint, loss:0.0000003, current step num: 380.
Saved checkpoint, loss:0.0000003, current step num: 395.
Saved checkpoint, loss:0.0000005, current step num:1151.
Saved checkpoint, loss:0.0000005, current step num:1358.
Saved checkpoint, loss:0.0000002, current step num:1524.

保存目录结构如下:

./
├── 1_253.ckpt
├── 1_258.ckpt
├── 1_265.ckpt
├── 1_332.ckpt
├── 1_358.ckpt
├── 1_380.ckpt
├── 1_395.ckpt
├── 1_1151.ckpt
├── 1_1358.ckpt
├── 1_1524.ckpt