回调机制 Callback

下载Notebook下载样例代码查看源文件

在深度学习训练过程中,为及时掌握网络模型的训练状态、实时观察网络模型各参数的变化情况和实现训练过程中用户自定义的一些操作,MindSpore提供了回调机制(Callback)来实现上述功能。

Callback回调机制一般用在网络模型训练过程Model.train中,MindSpore的Model会按照Callback列表callbacks顺序执行回调函数,用户可以通过设置不同的回调类来实现在训练过程中或者训练后执行的功能。

更多内置回调类的信息及使用方式请参考API文档

Callback介绍

当聊到回调Callback的时候,大部分用户都会觉得很难理解,是不是需要堆栈或者特殊的调度方式,实际上我们简单的理解回调:

假设函数A有一个参数,这个参数是个函数B,当函数A执行完以后执行函数B,那么这个过程就叫回调。

Callback是回调的意思,MindSpore中的回调函数实际上不是一个函数而是一个类,用户可以使用回调机制来观察训练过程中网络内部的状态和相关信息,或在特定时期执行特定动作

例如监控损失函数Loss、保存模型参数ckpt、动态调整参数lr、提前终止训练任务等。下面我们继续以手写体识别模型为例,介绍常见的内置回调函数和自定义回调函数。

[1]:
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
from mindspore.train import Model

# Download data from open datasets
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)


def datapipe(path, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

# Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)

file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:01<00:00, 10.0MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./
[2]:
train_dataset = datapipe('MNIST_Data/train', 64)
test_dataset = datapipe('MNIST_Data/test', 64)

trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})

常用的内置回调函数

MindSpore提供Callback能力,支持用户在训练/推理的特定阶段,插入自定义的操作。

ModelCheckpoint

用于保存训练后的网络模型和参数,方便进行再推理或再训练,MindSpore提供了ModelCheckpoint接口,一般与配置保存信息接口CheckpointConfig配合使用。

[3]:
from mindspore.train import CheckpointConfig, ModelCheckpoint

# 设置保存模型的配置信息
config = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
# 实例化保存模型回调接口,定义保存路径和前缀名
ckpt_callback = ModelCheckpoint(prefix="mnist", directory="./checkpoint", config=config)

# 开始训练,加载保存模型和参数回调函数
trainer.train(1, train_dataset, callbacks=[ckpt_callback])

上面代码运行后,生成的Checkpoint文件目录结构如下:

./checkpoint/
├── mnist-1_938.ckpt # 保存参数文件
└── mnist-graph.meta # 编译后的计算图

LossMonitor

用于监控训练或测试过程中的损失函数值Loss变化情况,可设置per_print_times控制打印Loss值的间隔。

[4]:
from mindspore.train import LossMonitor

loss_monitor = LossMonitor(300)
# 开始训练,加载保存模型和参数回调函数,LossMonitor的入参0.01为学习率,300为步长
trainer.train(1, train_dataset, callbacks=[loss_monitor])
epoch: 1 step: 300, loss is 0.45305341482162476
epoch: 1 step: 600, loss is 0.2915695905685425
epoch: 1 step: 900, loss is 0.5174192190170288

训练场景下,LossMonitor监控训练的Loss值;边训练边推理场景下,监控训练的Loss值和推理的Metrics值。

[5]:
trainer.fit(1, train_dataset, test_dataset, callbacks=[loss_monitor])
epoch: 1 step: 300, loss is 0.3167177438735962
epoch: 1 step: 600, loss is 0.36215940117836
epoch: 1 step: 900, loss is 0.25714176893234253
Eval result: epoch 1, metrics: {'accuracy': 0.9202}

TimeMonitor

用于监控训练或测试过程的执行时间。可设置data_size控制打印执行时间的间隔。

[6]:
from mindspore.train import TimeMonitor

time_monitor = TimeMonitor()
trainer.train(1, train_dataset, callbacks=[time_monitor])
Train epoch time: 7388.254 ms, per step time: 7.877 ms

自定义回调机制

MindSpore不仅有功能强大的内置回调函数,当用户有自己的特殊需求时,还可以基于Callback基类自定义回调类。

用户可以基于Callback基类,根据自身的需求,实现自定义CallbackCallback基类定义如下所示:

[7]:
class Callback():
    """Callback base class"""
    def on_train_begin(self, run_context):
        """Called once before the network executing."""

    def on_train_epoch_begin(self, run_context):
        """Called before each epoch beginning."""

    def on_train_epoch_end(self, run_context):
        """Called after each epoch finished."""

    def on_train_step_begin(self, run_context):
        """Called before each step beginning."""

    def on_train_step_end(self, run_context):
        """Called after each step finished."""

    def on_train_end(self, run_context):
        """Called once after network training."""

回调机制可以把训练过程中的重要信息记录下来,通过把一个字典类型变量RunContext.original_args(),传递给Callback对象,使得用户可以在各个自定义的Callback中获取到相关属性,执行自定义操作,也可以自定义其他变量传递给RunContext.original_args()对象。

RunContext.original_args()中的常用属性有:

  • epoch_num:训练的epoch的数量

  • batch_num:一个epoch中step的数量

  • cur_epoch_num:当前的epoch数

  • cur_step_num:当前的step数

  • loss_fn:损失函数

  • optimizer:优化器

  • train_network:训练的网络

  • train_dataset:训练的数据集

  • net_outputs:网络的输出结果

  • parallel_mode:并行模式

  • list_callback:所有的Callback函数

通过下面两个场景,我们可以增加对自定义Callback回调机制功能的了解。

自定义终止训练

实现在规定时间内终止训练功能。用户可以设定时间阈值,当训练时间达到这个阈值后就终止训练过程。

下面代码中,通过run_context.original_args方法可以获取到cb_params字典,字典里会包含前文描述的主要属性信息。

同时可以对字典内的值进行修改和添加,在begin函数中定义一个init_time对象传递给cb_params字典。每个数据迭代结束step_end之后会进行判断,当训练时间大于设置的时间阈值时,会向run_context传递终止训练的信号,提前终止训练,并打印当前的epoch、step、loss的值。

[8]:
import time
import mindspore as ms

class StopTimeMonitor(ms.train.Callback):

    def __init__(self, run_time):
        """定义初始化过程"""
        super(StopTimeMonitor, self).__init__()
        self.run_time = run_time            # 定义执行时间

    def on_train_begin(self, run_context):
        """开始训练时的操作"""
        cb_params = run_context.original_args()
        cb_params.init_time = time.time()   # 获取当前时间戳作为开始训练时间
        print(f"Begin training, time is: {cb_params.init_time}")

    def on_train_step_end(self, run_context):
        """每个step结束后执行的操作"""
        cb_params = run_context.original_args()
        epoch_num = cb_params.cur_epoch_num  # 获取epoch值
        step_num = cb_params.cur_step_num    # 获取step值
        loss = cb_params.net_outputs         # 获取损失值loss
        cur_time = time.time()               # 获取当前时间戳

        if (cur_time - cb_params.init_time) > self.run_time:
            print(f"End training, time: {cur_time}, epoch: {epoch_num}, step: {step_num}, loss:{loss}")
            run_context.request_stop()       # 停止训练

datasize = train_dataset.get_dataset_size()
trainer.train(5, train_dataset, callbacks=[LossMonitor(datasize), StopTimeMonitor(4)])
Begin training, time is: 1665892816.363511
End training, time: 1665892820.3696215, epoch: 1, step: 575, loss:Tensor(shape=[], dtype=Float32, value= 0.35758)

从上面的打印结果可以看出,当第3个epoch的第4673个step执行完时,运行时间到达了阈值并结束了训练。

自定义阈值保存模型

该回调机制实现当loss小于设定的阈值时,保存网络模型权重ckpt文件。

示例代码如下:

[9]:
import mindspore as ms

# 定义保存ckpt文件的回调接口
class SaveCkptMonitor(ms.train.Callback):
    """定义初始化过程"""

    def __init__(self, loss):
        super(SaveCkptMonitor, self).__init__()
        self.loss = loss  # 定义损失值阈值

    def on_train_step_end(self, run_context):
        """定义step结束时的执行操作"""
        cb_params = run_context.original_args()
        cur_loss = cb_params.net_outputs.asnumpy() # 获取当前损失值

        # 如果当前损失值小于设定的阈值就保存模型
        if cur_loss < self.loss:
            # 自定义保存文件名
            file_name = f"./checkpoint/{cb_params.cur_epoch_num}_{cb_params.cur_step_num}.ckpt"
            # 保存网络模型
            ms.save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
            print("Saved checkpoint, loss:{:8.7f}, current step num:{:4}.".format(cur_loss, cb_params.cur_step_num))

trainer.train(1, train_dataset, callbacks=[SaveCkptMonitor(0.05)])
Saved checkpoint, loss:0.0390485, current step num: 154.
Saved checkpoint, loss:0.0481475, current step num: 234.
Saved checkpoint, loss:0.0477566, current step num: 361.
Saved checkpoint, loss:0.0314977, current step num: 444.
Saved checkpoint, loss:0.0463577, current step num: 513.
Saved checkpoint, loss:0.0408403, current step num: 764.
Saved checkpoint, loss:0.0308827, current step num: 899.

保存目录结构如下:

./checkpoint/
├── 1_154.ckpt
├── 1_234.ckpt
├── 1_361.ckpt
├── 1_444.ckpt
├── 1_513.ckpt
├── 1_764.ckpt
├── 1_899.ckpt