回调机制 Callback
在深度学习训练过程中,为及时掌握网络模型的训练状态、实时观察网络模型各参数的变化情况和实现训练过程中用户自定义的一些操作,MindSpore提供了回调机制(Callback)来实现上述功能。
Callback回调机制一般用在网络模型训练过程Model.train
中,MindSpore的Model
会按照Callback列表callbacks
顺序执行回调函数,用户可以通过设置不同的回调类来实现在训练过程中或者训练后执行的功能。
更多内置回调类的信息及使用方式请参考API文档。
Callback介绍
当聊到回调Callback的时候,大部分用户都会觉得很难理解,是不是需要堆栈或者特殊的调度方式,实际上我们简单的理解回调:
假设函数A有一个参数,这个参数是个函数B,当函数A执行完以后执行函数B,那么这个过程就叫回调。
Callback
是回调的意思,MindSpore中的回调函数实际上不是一个函数而是一个类,用户可以使用回调机制来观察训练过程中网络内部的状态和相关信息,或在特定时期执行特定动作。
例如监控损失函数Loss、保存模型参数ckpt、动态调整参数lr、提前终止训练任务等。下面我们继续以手写体识别模型为例,介绍常见的内置回调函数和自定义回调函数。
[1]:
from download import download
from mindspore import nn, Model
from mindspore.dataset import vision, transforms, MnistDataset
from mindspore.common.initializer import Normal
from mindspore import dtype as mstype
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
"notebook/datasets/MNIST_Data.zip"
# 下载MNIST数据集
download(url, "./", kind="zip", replace=True)
# 数据处理
def proc_dataset(data_path, batch_size=32):
mnist_ds = MnistDataset(data_path, shuffle=True)
# define map operations
image_transforms = [
vision.Resize(32),
vision.Rescale(1.0 / 255.0, 0),
vision.Normalize(mean=(0.1307,), std=(0.3081,)),
vision.HWC2CHW()
]
label_transform = transforms.TypeCast(mstype.int32)
mnist_ds = mnist_ds.map(operations=label_transform, input_columns="label")
mnist_ds = mnist_ds.map(operations=image_transforms, input_columns="image")
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds
train_dataset = proc_dataset('MNIST_Data/train')
# 定义LeNet-5网络模型
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def create_model():
model = LeNet5()
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(model.trainable_params(), learning_rate=0.01, momentum=0.9)
trainer = Model(model, loss_fn=net_loss, optimizer=net_opt, metrics={"Accuracy": nn.Accuracy()})
return trainer
trainer = create_model()
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)
file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:00<00:00, 23.8MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./
常用的内置回调函数
MindSpore提供Callback
能力,支持用户在训练/推理的特定阶段,插入自定义的操作。
ModelCheckpoint
用于保存训练后的网络模型和参数,方便进行再推理或再训练,MindSpore提供了ModelCheckpoint接口,一般与配置保存信息接口CheckpointConfig配合使用。
[2]:
from mindspore import CheckpointConfig, ModelCheckpoint
# 设置保存模型的配置信息
config = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
# 实例化保存模型回调接口,定义保存路径和前缀名
ckpt_callback = ModelCheckpoint(prefix="mnist", directory="./checkpoint", config=config)
# 开始训练,加载保存模型和参数回调函数
trainer.train(1, train_dataset, callbacks=[ckpt_callback])
上面代码运行后,生成的Checkpoint文件目录结构如下:
./checkpoint/
├── mnist-1_1875.ckpt # 保存参数文件
└── mnist-graph.meta # 编译后的计算图
LossMonitor
用于监控训练或测试过程中的损失函数值Loss变化情况,可设置per_print_times
控制打印Loss值的间隔。
[3]:
from mindspore import LossMonitor
loss_monitor = LossMonitor(1875)
trainer.train(3, train_dataset, callbacks=[loss_monitor])
epoch: 1 step: 1875, loss is 0.008795851841568947
epoch: 2 step: 1875, loss is 0.007240554317831993
epoch: 3 step: 1875, loss is 0.0036914246156811714
训练场景下,LossMonitor监控训练的Loss值;边训练边推理场景下,监控训练的Loss值和推理的Metrics值。
[4]:
test_dataset = proc_dataset('MNIST_Data/test')
trainer.fit(2, train_dataset, test_dataset, callbacks=[loss_monitor])
epoch: 1 step: 1875, loss is 0.0026960039976984262
Eval result: epoch 1, metrics: {'Accuracy': 0.9888822115384616}
epoch: 2 step: 1875, loss is 0.00038617433165200055
Eval result: epoch 2, metrics: {'Accuracy': 0.9877804487179487}
TimeMonitor
用于监控训练或测试过程的执行时间。可设置data_size
控制打印执行时间的间隔。
[5]:
from mindspore import TimeMonitor
time_monitor = TimeMonitor(1875)
trainer.train(1, train_dataset, callbacks=[time_monitor])
Train epoch time: 3876.302 ms, per step time: 2.067 ms
自定义回调机制
MindSpore不仅有功能强大的内置回调函数,当用户有自己的特殊需求时,还可以基于Callback
基类自定义回调类。
用户可以基于Callback
基类,根据自身的需求,实现自定义Callback
。Callback
基类定义如下所示:
[6]:
class Callback():
"""Callback base class"""
def on_train_begin(self, run_context):
"""Called once before the network executing."""
def on_train_epoch_begin(self, run_context):
"""Called before each epoch beginning."""
def on_train_epoch_end(self, run_context):
"""Called after each epoch finished."""
def on_train_step_begin(self, run_context):
"""Called before each step beginning."""
def on_train_step_end(self, run_context):
"""Called after each step finished."""
def on_train_end(self, run_context):
"""Called once after network training."""
回调机制可以把训练过程中的重要信息记录下来,通过把一个字典类型变量RunContext.original_args()
,传递给Callback对象,使得用户可以在各个自定义的Callback中获取到相关属性,执行自定义操作,也可以自定义其他变量传递给RunContext.original_args()
对象。
RunContext.original_args()
中的常用属性有:
epoch_num:训练的epoch的数量
batch_num:一个epoch中step的数量
cur_epoch_num:当前的epoch数
cur_step_num:当前的step数
loss_fn:损失函数
optimizer:优化器
train_network:训练的网络
train_dataset:训练的数据集
net_outputs:网络的输出结果
parallel_mode:并行模式
list_callback:所有的Callback函数
通过下面两个场景,我们可以增加对自定义Callback回调机制功能的了解。
自定义终止训练
实现在规定时间内终止训练功能。用户可以设定时间阈值,当训练时间达到这个阈值后就终止训练过程。
下面代码中,通过run_context.original_args
方法可以获取到cb_params
字典,字典里会包含前文描述的主要属性信息。
同时可以对字典内的值进行修改和添加,在begin
函数中定义一个init_time
对象传递给cb_params
字典。每个数据迭代结束step_end
之后会进行判断,当训练时间大于设置的时间阈值时,会向run_context传递终止训练的信号,提前终止训练,并打印当前的epoch、step、loss的值。
[7]:
import time
from mindspore import Callback
class StopTimeMonitor(Callback):
def __init__(self, run_time):
"""定义初始化过程"""
super(StopTimeMonitor, self).__init__()
self.run_time = run_time # 定义执行时间
def on_train_begin(self, run_context):
"""开始训练时的操作"""
cb_params = run_context.original_args()
cb_params.init_time = time.time() # 获取当前时间戳作为开始训练时间
print(f"Begin training, time is: {cb_params.init_time}")
def on_train_step_end(self, run_context):
"""每个step结束后执行的操作"""
cb_params = run_context.original_args()
epoch_num = cb_params.cur_epoch_num # 获取epoch值
step_num = cb_params.cur_step_num # 获取step值
loss = cb_params.net_outputs # 获取损失值loss
cur_time = time.time() # 获取当前时间戳
if (cur_time - cb_params.init_time) > self.run_time:
print(f"End training, time: {cur_time}, epoch: {epoch_num}, step: {step_num}, loss:{loss}")
run_context.request_stop() # 停止训练
train_dataset = proc_dataset('MNIST_Data/train')
trainer.train(10, train_dataset, callbacks=[LossMonitor(), StopTimeMonitor(4)])
Begin training, time is: 1673515004.6783535
epoch: 1 step: 1875, loss is 0.0006050781812518835
End training, time: 1673515009.1824663, epoch: 1, step: 1875, loss:0.0006050782
从上面的打印结果可以看出,运行时间到达了阈值后就结束了训练。
自定义阈值保存模型
该回调机制实现当loss小于设定的阈值时,保存网络模型权重ckpt文件。
示例代码如下:
[8]:
from mindspore import save_checkpoint
# 定义保存ckpt文件的回调接口
class SaveCkptMonitor(Callback):
"""定义初始化过程"""
def __init__(self, loss):
super(SaveCkptMonitor, self).__init__()
self.loss = loss # 定义损失值阈值
def on_train_step_end(self, run_context):
"""定义step结束时的执行操作"""
cb_params = run_context.original_args()
cur_loss = cb_params.net_outputs.asnumpy() # 获取当前损失值
# 如果当前损失值小于设定的阈值就停止训练
if cur_loss < self.loss:
# 自定义保存文件名
file_name = f"./checkpoint/{cb_params.cur_epoch_num}_{cb_params.cur_step_num}.ckpt"
# 保存网络模型
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
print("Saved checkpoint, loss:{:8.7f}, current step num:{:4}.".format(cur_loss, cb_params.cur_step_num))
trainer = create_model()
train_dataset = proc_dataset('MNIST_Data/train')
trainer.train(5, train_dataset, callbacks=[LossMonitor(), SaveCkptMonitor(0.01)])
epoch: 1 step: 1875, loss is 0.15191984176635742
epoch: 2 step: 1875, loss is 0.14701086282730103
epoch: 3 step: 1875, loss is 0.0020134493242949247
Saved checkpoint, loss:0.0020134, current step num:5625.
epoch: 4 step: 1875, loss is 0.018305214121937752
epoch: 5 step: 1875, loss is 0.00019801077723968774
Saved checkpoint, loss:0.0001980, current step num:9375.
最终,损失值小于设定阈值的网络权重被保存在./checkpoint/
目录下。