{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model基本使用\n",
    "\n",
    "[![在OpenI运行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.9/resource/_static/logo_openi.png)](https://openi.pcl.ac.cn/MindSpore/docs/src/branch/r1.9/tutorials/source_zh_cn/advanced/model/model.ipynb?card=2&image=MindSpore1.8.1) [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.9/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.9/tutorials/zh_cn/advanced/model/mindspore_model.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.9/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.9/tutorials/zh_cn/advanced/model/mindspore_model.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.9/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.9/tutorials/source_zh_cn/advanced/model/model.ipynb)\n",
    "\n",
    "通常情况下,定义训练和评估网络并直接运行,已经可以满足基本需求。\n",
    "\n",
    "一方面,`Model`可以在一定程度上简化代码。例如:无需手动遍历数据集;在不需要自定义`nn.TrainOneStepCell`的场景下,可以借助`Model`自动构建训练网络;可以使用`Model`的`eval`接口进行模型评估,直接输出评估结果,无需手动调用评价指标的`clear`、`update`、`eval`函数等。\n",
    "\n",
    "另一方面,`Model`提供了很多高阶功能,如数据下沉、混合精度等,在不借助`Model`的情况下,使用这些功能需要花费较多的时间仿照`Model`进行自定义。\n",
    "\n",
    "本文档首先对MindSpore的Model进行基本介绍,然后重点讲解如何使用`Model`进行模型训练、评估和推理。\n",
    "\n",
    "![model](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.9/tutorials/source_zh_cn/advanced/model/images/model.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "from mindspore import nn\n",
    "from mindspore.dataset import MnistDataset, vision, transforms\n",
    "from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model基本介绍\n",
    "\n",
    "[Model](https://www.mindspore.cn/docs/en/r1.9/api_python/mindspore/mindspore.Model.html#mindspore.Model)是MindSpore提供的高阶API,可以进行模型训练、评估和推理。其接口的常用参数如下:\n",
    "\n",
    "- `network`:用于训练或推理的神经网络。\n",
    "- `loss_fn`:所使用的损失函数。\n",
    "- `optimizer`:所使用的优化器。\n",
    "- `metrics`:用于模型评估的评价函数。\n",
    "- `eval_network`:模型评估所使用的网络,未定义情况下,`Model`会使用`network`和`loss_fn`进行封装。\n",
    "\n",
    "`Model`提供了以下接口用于模型训练、评估和推理:\n",
    "\n",
    "- `fit`:边训练边评估模型。\n",
    "- `train`:用于在训练集上进行模型训练。\n",
    "- `eval`:用于在验证集上进行模型评估。\n",
    "- `predict`:用于对输入的一组数据进行推理,输出预测结果。\n",
    "\n",
    "### 使用Model接口\n",
    "\n",
    "对于简单场景的神经网络,可以在定义`Model`时指定前向网络`network`、损失函数`loss_fn`、优化器`optimizer`和评价函数`metrics`。\n",
    "\n",
    "## 下载并处理数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download data from open datasets\n",
    "from download import download\n",
    "\n",
    "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/\" \\\n",
    "      \"notebook/datasets/MNIST_Data.zip\"\n",
    "path = download(url, \"./\", kind=\"zip\")\n",
    "\n",
    "\n",
    "def datapipe(path, batch_size):\n",
    "    image_transforms = [\n",
    "        vision.Rescale(1.0 / 255.0, 0),\n",
    "        vision.Normalize(mean=(0.1307,), std=(0.3081,)),\n",
    "        vision.HWC2CHW()\n",
    "    ]\n",
    "    label_transform = transforms.TypeCast(mindspore.int32)\n",
    "\n",
    "    dataset = MnistDataset(path)\n",
    "    dataset = dataset.map(image_transforms, 'image')\n",
    "    dataset = dataset.map(label_transform, 'label')\n",
    "    dataset = dataset.batch(batch_size, drop_remainder=True)\n",
    "    return dataset\n",
    "\n",
    "train_dataset = datapipe('MNIST_Data/train', 64)\n",
    "test_dataset = datapipe('MNIST_Data/test', 64)"
   ]
  },
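  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `drop_remainder=True` discards the last incomplete batch, the number of steps per epoch and the number of evaluated test images follow directly from the dataset sizes. The plain-Python sketch below assumes the standard MNIST split of 60,000 training and 10,000 test images; it only does the arithmetic and does not call MindSpore.\n",
    "\n",
    "```python\n",
    "# Assumed standard MNIST split sizes\n",
    "train_size, test_size = 60000, 10000\n",
    "batch_size = 64\n",
    "\n",
    "# With drop_remainder=True only full batches are kept\n",
    "train_steps = train_size // batch_size\n",
    "test_batches = test_size // batch_size\n",
    "evaluated_samples = test_batches * batch_size\n",
    "\n",
    "print(train_steps)        # 937\n",
    "print(test_batches)       # 156\n",
    "print(evaluated_samples)  # 9984\n",
    "```\n",
    "\n",
    "The 937 steps match the `step: 937` entries in the training log, and accuracy is computed over 9,984 rather than all 10,000 test images."
   ]
  },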
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-01-04T06:43:30.392367Z",
     "start_time": "2022-01-04T06:43:28.436687Z"
    }
   },
   "outputs": [],
   "source": [
    "# Define model\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits\n",
    "\n",
    "model = Network()"
   ]
  },
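  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check on the network size: each `nn.Dense(in_channels, out_channels)` layer holds an `in_channels * out_channels` weight matrix plus `out_channels` biases (assuming the default `has_bias=True`), while `nn.Flatten` and `nn.ReLU` have no parameters. The plain-Python sketch below totals them; `dense_params` is a hypothetical helper, not a MindSpore API.\n",
    "\n",
    "```python\n",
    "def dense_params(in_channels, out_channels):\n",
    "    # Weight matrix plus one bias per output unit\n",
    "    return in_channels * out_channels + out_channels\n",
    "\n",
    "\n",
    "# (in, out) sizes of the three Dense layers in Network\n",
    "layers = [(28 * 28, 512), (512, 512), (512, 10)]\n",
    "total = sum(dense_params(i, o) for i, o in layers)\n",
    "print(total)  # 669706\n",
    "```"
   ]
  },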
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义损失函数和优化器\n",
    "\n",
    "要训练神经网络模型,需要定义损失函数和优化器函数。\n",
    "\n",
    "- 损失函数这里使用交叉熵损失函数`CrossEntropyLoss`。\n",
    "- 优化器这里使用`SGD`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate loss function and optimizer\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = nn.SGD(model.trainable_params(), 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练及保存模型\n",
    "\n",
    "在开始训练之前,MindSpore需要提前声明网络模型在训练过程中是否需要保存中间过程和结果,因此使用`ModelCheckpoint`接口用于保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "steps_per_epoch = train_dataset.get_dataset_size()\n",
    "config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch)\n",
    "\n",
    "ckpt_callback = ModelCheckpoint(prefix=\"mnist\", directory=\"./checkpoint\", config=config)\n",
    "loss_callback = LossMonitor(steps_per_epoch)"
   ]
  },
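  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since `save_checkpoint_steps` is set to `steps_per_epoch`, a checkpoint is written once per epoch. The plain-Python sketch below lists the global steps at which that happens for the 10-epoch run used below. It assumes 937 steps per epoch (60,000 MNIST training images in full batches of 64) and a simple trigger rule of saving every `save_checkpoint_steps` steps; it does not call MindSpore.\n",
    "\n",
    "```python\n",
    "steps_per_epoch = 937  # assumed: 60000 // 64 with drop_remainder=True\n",
    "epochs = 10\n",
    "save_checkpoint_steps = steps_per_epoch\n",
    "\n",
    "# Assumed trigger rule: save whenever the global step is a multiple of save_checkpoint_steps\n",
    "total_steps = steps_per_epoch * epochs\n",
    "save_points = [s for s in range(1, total_steps + 1) if s % save_checkpoint_steps == 0]\n",
    "\n",
    "print(len(save_points))  # 10\n",
    "print(save_points[:3])   # [937, 1874, 2811]\n",
    "```"
   ]
  },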
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过MindSpore提供的`model.fit`接口可以方便地进行网络的训练,`LossMonitor`可以监控训练过程中`loss`值的变化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1 step: 937, loss is 0.5313149094581604\n",
      "Eval result: epoch 1, metrics: {'accuracy': 0.8557692307692307}\n",
      "epoch: 2 step: 937, loss is 0.2875961363315582\n",
      "Eval result: epoch 2, metrics: {'accuracy': 0.9007411858974359}\n",
      "epoch: 3 step: 937, loss is 0.19009104371070862\n",
      "Eval result: epoch 3, metrics: {'accuracy': 0.9191706730769231}\n",
      "epoch: 4 step: 937, loss is 0.24231787025928497\n",
      "Eval result: epoch 4, metrics: {'accuracy': 0.9296875}\n",
      "epoch: 5 step: 937, loss is 0.16016671061515808\n",
      "Eval result: epoch 5, metrics: {'accuracy': 0.9386017628205128}\n",
      "epoch: 6 step: 937, loss is 0.4830142855644226\n",
      "Eval result: epoch 6, metrics: {'accuracy': 0.9444110576923077}\n",
      "epoch: 7 step: 937, loss is 0.20778779685497284\n",
      "Eval result: epoch 7, metrics: {'accuracy': 0.9508213141025641}\n",
      "epoch: 8 step: 937, loss is 0.22020074725151062\n",
      "Eval result: epoch 8, metrics: {'accuracy': 0.9540264423076923}\n",
      "epoch: 9 step: 937, loss is 0.15951070189476013\n",
      "Eval result: epoch 9, metrics: {'accuracy': 0.9575320512820513}\n",
      "epoch: 10 step: 937, loss is 0.11161471903324127\n",
      "Eval result: epoch 10, metrics: {'accuracy': 0.9608373397435898}\n"
     ]
    }
   ],
   "source": [
    "trainer = Model(model, loss_fn=loss_fn, optimizer=optimizer, metrics={'accuracy'})\n",
    "\n",
    "trainer.fit(10, train_dataset, test_dataset, callbacks=[ckpt_callback, loss_callback])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。\n",
    "\n",
    "通过模型运行测试数据集得到的结果,验证模型的泛化能力:\n",
    "\n",
    "1. 使用`model.eval`接口读入测试数据集。\n",
    "2. 使用保存后的模型参数进行推理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-01-04T06:43:30.952190Z",
     "start_time": "2022-01-04T06:43:30.525149Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.9607371794871795}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acc = trainer.eval(test_dataset)\n",
    "acc"
   ]
  },
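  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reported accuracy is the fraction of correctly classified images among those evaluated. Assuming 156 full test batches of 64 images (9,984 images, because `drop_remainder=True` drops the last 16 of the 10,000 MNIST test images), the sketch below converts the accuracy values printed above back into counts of correct predictions. It is plain Python and does not call MindSpore.\n",
    "\n",
    "```python\n",
    "evaluated = 156 * 64  # assumed: full test batches only\n",
    "eval_accuracy = 0.9607371794871795  # returned by trainer.eval above\n",
    "fit_accuracy = 0.9608373397435898   # epoch-10 value reported by fit above\n",
    "\n",
    "eval_correct = round(eval_accuracy * evaluated)\n",
    "fit_correct = round(fit_accuracy * evaluated)\n",
    "\n",
    "print(evaluated)                                              # 9984\n",
    "print(eval_correct, fit_correct)                              # 9592 9593\n",
    "print(abs(eval_correct / evaluated - eval_accuracy) < 1e-12)  # True\n",
    "```\n",
    "\n",
    "Under this assumption the two evaluations differ by a single image."
   ]
  },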
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以在打印信息中看出模型精度数据,示例中精度数据达到95%以上,模型质量良好。随着网络迭代次数增加,模型精度会进一步提高。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13 (default, Mar 28 2022, 01:56:06) [MSC v.1916 32 bit (Intel)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "d735f6cc625ca8b095620eeb46a50be5e34ded063dbe74c6d5dc8e1ec88bb29c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}