高阶封装:Model
通常情况下,定义训练和评估网络并直接运行,已经可以满足基本需求。
一方面,Model
可以在一定程度上简化代码。例如:无需手动遍历数据集;在不需要自定义nn.TrainOneStepCell
的场景下,可以借助Model
自动构建训练网络;可以使用Model
的eval
接口进行模型评估,直接输出评估结果,无需手动调用评价指标的clear
、update
、eval
函数等。
另一方面,Model
提供了很多高阶功能,如数据下沉、混合精度等,在不借助Model
的情况下,使用这些功能需要花费较多的时间仿照Model
进行自定义。
本文档首先对MindSpore的Model进行基本介绍,然后重点讲解如何使用Model
进行模型训练、评估和推理。
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor
Model基本介绍
Model是MindSpore提供的高阶API,可以进行模型训练、评估和推理。其接口的常用参数如下:
network
:用于训练或推理的神经网络。loss_fn
:所使用的损失函数。optimizer
:所使用的优化器。metrics
:用于模型评估的评价函数。eval_network
:模型评估所使用的网络,未定义情况下,Model
会使用network
和loss_fn
进行封装。
Model
提供了以下接口用于模型训练、评估和推理:
fit
:边训练边评估模型。train
:用于在训练集上进行模型训练。eval
:用于在验证集上进行模型评估。predict
:用于对输入的一组数据进行推理,输出预测结果。
使用Model接口
对于简单场景的神经网络,可以在定义Model
时指定前向网络network
、损失函数loss_fn
、优化器optimizer
和评价函数metrics
。
下载并处理数据集
使用download库下载数据集,通过 vison.Rescale 接口对图片进行缩放, vision.Normalize 接口对输入图片进行归一化处理, vision.HWC2CHW 接口对数据格式进行转换。
# Download data from open datasets
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
def datapipe(path, batch_size):
image_transforms = [
vision.Rescale(1.0 / 255.0, 0),
vision.Normalize(mean=(0.1307,), std=(0.3081,)),
vision.HWC2CHW()
]
label_transform = transforms.TypeCast(mindspore.int32)
dataset = MnistDataset(path)
dataset = dataset.map(image_transforms, 'image')
dataset = dataset.map(label_transform, 'label')
dataset = dataset.batch(batch_size)
return dataset
train_dataset = datapipe('MNIST_Data/train', 64)
test_dataset = datapipe('MNIST_Data/test', 64)
创建模型
关于模型创建的讲解可以参考 网络构建。
# Define model
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.dense_relu_sequential = nn.SequentialCell(
nn.Dense(28*28, 512),
nn.ReLU(),
nn.Dense(512, 512),
nn.ReLU(),
nn.Dense(512, 10)
)
def construct(self, x):
x = self.flatten(x)
logits = self.dense_relu_sequential(x)
return logits
model = Network()
定义损失函数和优化器
要训练神经网络模型,需要定义损失函数和优化器函数。
损失函数这里使用交叉熵损失函数
CrossEntropyLoss
。优化器这里使用
SGD
。
# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)
训练及保存模型
在开始训练之前,MindSpore需要提前声明网络模型在训练过程中是否需要保存中间过程和结果,因此使用ModelCheckpoint
接口用于保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。
steps_per_epoch = train_dataset.get_dataset_size()
config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch)
ckpt_callback = ModelCheckpoint(prefix="mnist", directory="./checkpoint", config=config)
loss_callback = LossMonitor(steps_per_epoch)
通过MindSpore提供的model.fit
接口可以方便地进行网络的训练与评估,LossMonitor
可以监控训练过程中loss
值的变化。
trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})
trainer.fit(10, train_dataset, test_dataset, callbacks=[ckpt_callback, loss_callback])
epoch: 1 step: 938, loss is 0.602992594242096
Eval result: epoch 1, metrics: {'accuracy': 0.8435}
epoch: 2 step: 938, loss is 0.2797124981880188
Eval result: epoch 2, metrics: {'accuracy': 0.9003}
epoch: 3 step: 938, loss is 0.32015785574913025
Eval result: epoch 3, metrics: {'accuracy': 0.9179}
epoch: 4 step: 938, loss is 0.17153620719909668
Eval result: epoch 4, metrics: {'accuracy': 0.9308}
epoch: 5 step: 938, loss is 0.18772485852241516
Eval result: epoch 5, metrics: {'accuracy': 0.9382}
epoch: 6 step: 938, loss is 0.45641791820526123
Eval result: epoch 6, metrics: {'accuracy': 0.946}
epoch: 7 step: 938, loss is 0.11519066989421844
Eval result: epoch 7, metrics: {'accuracy': 0.9506}
epoch: 8 step: 938, loss is 0.43486487865448
Eval result: epoch 8, metrics: {'accuracy': 0.9555}
epoch: 9 step: 938, loss is 0.1941455900669098
Eval result: epoch 9, metrics: {'accuracy': 0.9588}
epoch: 10 step: 938, loss is 0.13441434502601624
Eval result: epoch 10, metrics: {'accuracy': 0.9632}
训练过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。
通过模型运行测试数据集得到的结果,验证模型的泛化能力:
使用
model.eval
接口读入测试数据集。使用保存后的模型参数进行推理。
acc = trainer.eval(test_dataset)
acc
{'accuracy': 0.9632}
可以在打印信息中看出模型精度数据,示例中精度数据达到95%以上,模型质量良好。随着网络迭代次数增加,模型精度会进一步提高。