{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model基本使用\n", "\n", "[](https://openi.pcl.ac.cn/MindSpore/docs/src/branch/r1.9/tutorials/source_zh_cn/advanced/model/model.ipynb?card=2&image=MindSpore1.8.1) [](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.9/tutorials/zh_cn/advanced/model/mindspore_model.ipynb) [](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.9/tutorials/zh_cn/advanced/model/mindspore_model.py) [](https://gitee.com/mindspore/docs/blob/r1.9/tutorials/source_zh_cn/advanced/model/model.ipynb)\n", "\n", "通常情况下,定义训练和评估网络并直接运行,已经可以满足基本需求。\n", "\n", "一方面,`Model`可以在一定程度上简化代码。例如:无需手动遍历数据集;在不需要自定义`nn.TrainOneStepCell`的场景下,可以借助`Model`自动构建训练网络;可以使用`Model`的`eval`接口进行模型评估,直接输出评估结果,无需手动调用评价指标的`clear`、`update`、`eval`函数等。\n", "\n", "另一方面,`Model`提供了很多高阶功能,如数据下沉、混合精度等,在不借助`Model`的情况下,使用这些功能需要花费较多的时间仿照`Model`进行自定义。\n", "\n", "本文档首先对MindSpore的Model进行基本介绍,然后重点讲解如何使用`Model`进行模型训练、评估和推理。\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import mindspore\n", "from mindspore import nn\n", "from mindspore.dataset import MnistDataset, vision, transforms\n", "from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model基本介绍\n", "\n", "[Model](https://www.mindspore.cn/docs/en/r1.9/api_python/mindspore/mindspore.Model.html#mindspore.Model)是MindSpore提供的高阶API,可以进行模型训练、评估和推理。其接口的常用参数如下:\n", "\n", "- `network`:用于训练或推理的神经网络。\n", "- `loss_fn`:所使用的损失函数。\n", "- `optimizer`:所使用的优化器。\n", "- `metrics`:用于模型评估的评价函数。\n", "- `eval_network`:模型评估所使用的网络,未定义情况下,`Model`会使用`network`和`loss_fn`进行封装。\n", "\n", "`Model`提供了以下接口用于模型训练、评估和推理:\n", "\n", "- `fit`:边训练边评估模型。\n", "- `train`:用于在训练集上进行模型训练。\n", "- `eval`:用于在验证集上进行模型评估。\n", "- `predict`:用于对输入的一组数据进行推理,输出预测结果。\n", "\n", "### 使用Model接口\n", "\n", "对于简单场景的神经网络,可以在定义`Model`时指定前向网络`network`、损失函数`loss_fn`、优化器`optimizer`和评价函数`metrics`。\n", "\n", "## 下载并处理数据集" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Download data from open datasets\n", "from download import download\n", "\n", "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/\" \\\n", " \"notebook/datasets/MNIST_Data.zip\"\n", "path = download(url, \"./\", kind=\"zip\")\n", "\n", "\n", "def datapipe(path, batch_size):\n", " image_transforms = [\n", " vision.Rescale(1.0 / 255.0, 0),\n", " vision.Normalize(mean=(0.1307,), std=(0.3081,)),\n", " vision.HWC2CHW()\n", " ]\n", " label_transform = transforms.TypeCast(mindspore.int32)\n", "\n", " dataset = MnistDataset(path)\n", " dataset = dataset.map(image_transforms, 'image')\n", " dataset = dataset.map(label_transform, 'label')\n", " dataset = dataset.batch(batch_size, drop_remainder=True)\n", " return dataset\n", "\n", "train_dataset = datapipe('MNIST_Data/train', 64)\n", "test_dataset = datapipe('MNIST_Data/test', 64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 创建模型" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:30.392367Z", "start_time": "2022-01-04T06:43:28.436687Z" } }, "outputs": [], "source": [ "# Define model\n", "class Network(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.flatten = nn.Flatten()\n", " self.dense_relu_sequential = nn.SequentialCell(\n", " nn.Dense(28*28, 512),\n", " nn.ReLU(),\n", " nn.Dense(512, 512),\n", " nn.ReLU(),\n", " nn.Dense(512, 10)\n", " )\n", "\n", " def construct(self, x):\n", " x = self.flatten(x)\n", " logits = self.dense_relu_sequential(x)\n", " return logits\n", "\n", "model = Network()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 定义损失函数和优化器\n", "\n", "要训练神经网络模型,需要定义损失函数和优化器函数。\n", "\n", "- 损失函数这里使用交叉熵损失函数`CrossEntropyLoss`。\n", "- 优化器这里使用`SGD`。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Instantiate loss function and optimizer\n", "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = nn.SGD(model.trainable_params(), 1e-2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 训练及保存模型\n", "\n", "在开始训练之前,MindSpore需要提前声明网络模型在训练过程中是否需要保存中间过程和结果,因此使用`ModelCheckpoint`接口用于保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "steps_per_epoch = train_dataset.get_dataset_size()\n", "config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch)\n", "\n", "ckpt_callback = ModelCheckpoint(prefix=\"mnist\", directory=\"./checkpoint\", config=config)\n", "loss_callback = LossMonitor(steps_per_epoch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "通过MindSpore提供的`model.fit`接口可以方便地进行网络的训练,`LossMonitor`可以监控训练过程中`loss`值的变化。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 1 step: 937, loss is 0.5313149094581604\n", "Eval result: epoch 1, metrics: {'accuracy': 0.8557692307692307}\n", "epoch: 2 step: 937, loss is 0.2875961363315582\n", "Eval result: epoch 2, metrics: {'accuracy': 0.9007411858974359}\n", "epoch: 3 step: 937, loss is 0.19009104371070862\n", "Eval result: epoch 3, metrics: {'accuracy': 0.9191706730769231}\n", "epoch: 4 step: 937, loss is 0.24231787025928497\n", "Eval result: epoch 4, metrics: {'accuracy': 0.9296875}\n", "epoch: 5 step: 937, loss is 0.16016671061515808\n", "Eval result: epoch 5, metrics: {'accuracy': 0.9386017628205128}\n", "epoch: 6 step: 937, loss is 0.4830142855644226\n", "Eval result: epoch 6, metrics: {'accuracy': 0.9444110576923077}\n", "epoch: 7 step: 937, loss is 0.20778779685497284\n", "Eval result: epoch 7, metrics: {'accuracy': 0.9508213141025641}\n", "epoch: 8 step: 937, loss is 0.22020074725151062\n", "Eval result: epoch 8, metrics: {'accuracy': 0.9540264423076923}\n", "epoch: 9 step: 937, loss is 0.15951070189476013\n", "Eval result: epoch 9, metrics: {'accuracy': 0.9575320512820513}\n", "epoch: 10 step: 937, loss is 0.11161471903324127\n", "Eval result: epoch 10, metrics: {'accuracy': 0.9608373397435898}\n" ] } ], "source": [ "trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})\n", "\n", "trainer.fit(10, train_dataset, test_dataset, callbacks=[ckpt_callback, loss_callback])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "训练过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。\n", "\n", "通过模型运行测试数据集得到的结果,验证模型的泛化能力:\n", "\n", "1. 使用`model.eval`接口读入测试数据集。\n", "2. 使用保存后的模型参数进行推理。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T06:43:30.952190Z", "start_time": "2022-01-04T06:43:30.525149Z" }, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'accuracy': 0.9607371794871795}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "acc = trainer.eval(test_dataset)\n", "acc" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "可以在打印信息中看出模型精度数据,示例中精度数据达到95%以上,模型质量良好。随着网络迭代次数增加,模型精度会进一步提高。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13 (default, Mar 28 2022, 01:56:06) [MSC v.1916 32 bit (Intel)]" }, "vscode": { "interpreter": { "hash": "d735f6cc625ca8b095620eeb46a50be5e34ded063dbe74c6d5dc8e1ec88bb29c" } } }, "nbformat": 4, "nbformat_minor": 4 }