{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Quick Start: Handwritten Digit Recognition\n", "\n", "[![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/beginner/mindspore_quick_start.ipynb) [![Download sample code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.8/tutorials/zh_cn/beginner/mindspore_quick_start.py) [![View source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/tutorials/source_zh_cn/beginner/quick_start.ipynb)\n", "\n", "This section walks through the basic MindSpore deep learning workflow, using the LeNet5 network model as an example to implement a common deep learning task." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Downloading and Processing the Dataset\n", "\n", "Datasets are crucial to model training, and a good dataset can effectively improve training accuracy and efficiency. The MNIST dataset used in this example consists of 28×28 grayscale images in 10 classes, with 60,000 images in the training set and 10,000 images in the test set.\n", "\n", "![mnist](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/tutorials/source_zh_cn/beginner/images/mnist.png)\n", "\n", "> You can download the dataset from the [MNIST dataset download page](http://yann.lecun.com/exdb/mnist/), decompress it, and place the files in the directory structure shown below.\n", "\n", "The [MindSpore Vision suite](https://mindspore.cn/vision/docs/zh-CN/r0.1/index.html) provides the Mnist module for downloading and processing the MNIST dataset. The following sample code downloads the dataset, decompresses it to the specified location, and processes it:\n", "\n", "> The sample code in this section depends on `mindvision`, which can be installed with `pip install mindvision`. If you run this document as a Notebook, restart the kernel after installation before executing the subsequent code.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from mindvision.dataset import Mnist\n", "\n", "# Download and process the MNIST dataset\n", "download_train = Mnist(path=\"./mnist\", split=\"train\", batch_size=32, repeat_num=1, shuffle=True, resize=32, download=True)\n", "\n", "download_eval = Mnist(path=\"./mnist\", split=\"test\", batch_size=32, resize=32, download=True)\n", "\n", 
"dataset_train = download_train.run()\n", "dataset_eval = download_eval.run()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Parameter description:\n", "\n", "- path: dataset path.\n", "- split: dataset type; supports train, test, and infer; the default is train.\n", "- batch_size: number of samples in each training batch; the default is 32.\n", "- repeat_num: number of times the dataset is traversed during training; the default is 1.\n", "- shuffle: whether to randomly shuffle the dataset (optional).\n", "- resize: size of the output images; the default is 32*32.\n", "- download: whether to download the dataset; the default is False.\n", "\n", "The downloaded dataset files have the following directory structure:\n", "\n", "```text\n", "./mnist/\n", "├── test\n", "│   ├── t10k-images-idx3-ubyte\n", "│   └── t10k-labels-idx1-ubyte\n", "└── train\n", "    ├── train-images-idx3-ubyte\n", "    └── train-labels-idx1-ubyte\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating the Model\n", "\n", "According to the LeNet network structure, LeNet has 7 layers excluding the input layer: 2 convolutional layers, 2 subsampling layers, and 3 fully connected layers.\n", "\n", "![](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/tutorials/source_zh_cn/beginner/images/lenet.png)\n", "\n", "The LeNet5 network model is implemented as follows:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import mindspore.nn as nn\n", "\n", "class LeNet5(nn.Cell):\n", "    \"\"\"\n", "    LeNet-5 network structure\n", "    \"\"\"\n", "    def __init__(self, num_class=10, num_channel=1):\n", "        super(LeNet5, self).__init__()\n", "        # Convolutional layer: num_channel input channels, 6 output channels, 5*5 kernel\n", "        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')\n", "        # Convolutional layer: 6 input channels, 16 output channels, 5*5 kernel\n", "        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')\n", "        # Fully connected layer: 16*5*5 inputs, 120 outputs\n", "        self.fc1 = nn.Dense(16 * 5 * 5, 120)\n", "        # Fully connected layer: 120 inputs, 84 outputs\n", "        self.fc2 = nn.Dense(120, 84)\n", "        # Fully connected layer: 84 inputs, num_class outputs (one per class)\n", "        self.fc3 = nn.Dense(84, num_class)\n", "        # ReLU activation function\n", "        self.relu = nn.ReLU()\n", "        # Pooling layer\n", "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n", "        # Flatten each sample's feature maps into a one-dimensional vector\n", "        self.flatten = nn.Flatten()\n", "\n", "    def construct(self, x):\n", "        # Build the forward network from the operators defined above\n", "        x = self.conv1(x)\n", "        x = self.relu(x)\n", "        x = self.max_pool2d(x)\n", "        x = 
self.conv2(x)\n", "        x = self.relu(x)\n", "        x = self.max_pool2d(x)\n", "        x = self.flatten(x)\n", "        x = self.fc1(x)\n", "        x = self.relu(x)\n", "        x = self.fc2(x)\n", "        x = self.relu(x)\n", "        x = self.fc3(x)\n", "        return x\n", "\n", "network = LeNet5(num_class=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The MindSpore Vision suite also provides the LeNet network model interface `lenet`. The network model can be defined with it as follows:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from mindvision.classification.models import lenet\n", "\n", "network = lenet(num_classes=10, pretrained=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining the Loss Function and Optimizer\n", "\n", "To train a neural network model, you need to define a loss function and an optimizer.\n", "\n", "- Loss function: the cross-entropy loss function `SoftmaxCrossEntropyWithLogits` is used here.\n", "- Optimizer: `Momentum` is used here." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import mindspore.nn as nn\n", "\n", "# Define the loss function\n", "net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n", "\n", "# Define the optimizer\n", "net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training and Saving the Model\n", "\n", "Before training starts, MindSpore needs to be told in advance whether the intermediate states and results of the network model should be saved during training. The `ModelCheckpoint` interface is therefore used to save the network model and its parameters for subsequent fine-tuning." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import mindspore as ms\n", "\n", "# Configure checkpoint saving: save every 1875 steps and keep at most 10 checkpoints.\n", "config_ck = ms.CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)\n", "\n", "# Apply the checkpoint configuration\n", "ckpoint = ms.ModelCheckpoint(prefix=\"lenet\", directory=\"./lenet\", config=config_ck)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `model.train` interface provided by MindSpore makes it easy to train the network, and `LossMonitor` monitors how the `loss` value changes during training." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch:[ 0/ 10], step:[ 1875/ 1875], 
loss:[0.314/0.314], time:2237.313 ms, lr:0.01000\n", "Epoch time: 3577.754 ms, per step time: 1.908 ms, avg loss: 0.314\n", "Epoch:[ 1/ 10], step:[ 1875/ 1875], loss:[0.031/0.031], time:1306.982 ms, lr:0.01000\n", "Epoch time: 1307.792 ms, per step time: 0.697 ms, avg loss: 0.031\n", "Epoch:[ 2/ 10], step:[ 1875/ 1875], loss:[0.007/0.007], time:1324.625 ms, lr:0.01000\n", "Epoch time: 1325.340 ms, per step time: 0.707 ms, avg loss: 0.007\n", "Epoch:[ 3/ 10], step:[ 1875/ 1875], loss:[0.021/0.021], time:1396.733 ms, lr:0.01000\n", "Epoch time: 1397.495 ms, per step time: 0.745 ms, avg loss: 0.021\n", "Epoch:[ 4/ 10], step:[ 1875/ 1875], loss:[0.028/0.028], time:1594.762 ms, lr:0.01000\n", "Epoch time: 1595.549 ms, per step time: 0.851 ms, avg loss: 0.028\n", "Epoch:[ 5/ 10], step:[ 1875/ 1875], loss:[0.007/0.007], time:1242.175 ms, lr:0.01000\n", "Epoch time: 1242.928 ms, per step time: 0.663 ms, avg loss: 0.007\n", "Epoch:[ 6/ 10], step:[ 1875/ 1875], loss:[0.033/0.033], time:1199.938 ms, lr:0.01000\n", "Epoch time: 1200.627 ms, per step time: 0.640 ms, avg loss: 0.033\n", "Epoch:[ 7/ 10], step:[ 1875/ 1875], loss:[0.175/0.175], time:1228.845 ms, lr:0.01000\n", "Epoch time: 1229.548 ms, per step time: 0.656 ms, avg loss: 0.175\n", "Epoch:[ 8/ 10], step:[ 1875/ 1875], loss:[0.009/0.009], time:1237.200 ms, lr:0.01000\n", "Epoch time: 1237.969 ms, per step time: 0.660 ms, avg loss: 0.009\n", "Epoch:[ 9/ 10], step:[ 1875/ 1875], loss:[0.000/0.000], time:1288.693 ms, lr:0.01000\n", "Epoch time: 1288.413 ms, per step time: 0.687 ms, avg loss: 0.000\n" ] } ], "source": [ "from mindvision.engine.callback import LossMonitor\n", "import mindspore as ms\n", "\n", "# Initialize the model with the network, loss function, optimizer, and metrics\n", "model = ms.Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={'accuracy'})\n", "\n", "# Train the network; a checkpoint is saved every 1875 steps, the first as lenet-1_1875.ckpt\n", "model.train(10, dataset_train, callbacks=[ckpoint, LossMonitor(0.01, 1875)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"The loss value is printed during training. It fluctuates, but overall it gradually decreases while accuracy gradually improves. Loss values involve some randomness, so the values from your run may not match these exactly.\n", "\n", "Run the model on the test dataset to verify its generalization ability:\n", "\n", "1. Use the `model.eval` interface to read in the test dataset.\n", "2. Use the saved model parameters for inference." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'accuracy': 0.9903846153846154}\n" ] } ], "source": [ "acc = model.eval(dataset_eval)\n", "\n", "print(\"{}\".format(acc))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The printed output shows the model accuracy. In this example the accuracy is above 95% (about 99% in the run shown), which indicates good model quality. As the number of training iterations increases, the model accuracy generally improves further.\n", "\n", "## Loading the Model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import mindspore as ms\n", "\n", "# Load the saved model checkpoint for testing\n", "param_dict = ms.load_checkpoint(\"./lenet/lenet-1_1875.ckpt\")\n", "# Load the parameters into the network\n", "ms.load_param_into_net(network, param_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Read more about [loading models](https://www.mindspore.cn/tutorials/zh-CN/r1.8/beginner/save_load.html#加载模型).\n", "\n", "## Validating the Model\n", "\n", "We use the trained model to predict the classes of individual images, as follows:\n", "\n", "> - The images to be predicted are selected at random, so the results may differ from run to run.\n", "> - The code uses the Tensor module. Read more about [Tensor](https://www.mindspore.cn/tutorials/zh-CN/r1.8/beginner/save_load.html)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD6CAYAAAC4RRw1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAlq0lEQVR4nO2de7AU1bXGvyWCIKCAICIvQRDEB6CAKD5QyqgYA1FjQGNhQoqUyU1pclORpCpaJpUqzB+JicaKJBpJYhnxqkgiRkBRERU58pCXAj5Q3iIiiE9w3z/OsPhmMnPOcM7MnNnd369qim/6Mb37rO5N79VrrW0hBAghhIiPQ5q6AUIIIRqGOnAhhIgUdeBCCBEp6sCFECJS1IELIUSkqAMXQohIaVQHbmYXm9nrZrbOzCaXqlGiaZFdk4tsmyysoXHgZtYMwBoAFwLYAGARgPEhhFWla56oNLJrcpFtk8ehjdh3GIB1IYQ3AcDM/glgDICCF4OZKWuoSgghWIFVsmvE1GFX4CBtK7tWFdtDCJ1yFzbGhdIVwLv0fUNmWRZmNsnMasysphHHEpVDdk0u9dpWdq1a1udb2Jgn8KIIIUwFMBXQ/+hJIgl2NTvwsNquXbusdb/85S9dP/zww64XLFjg+osvvihf45qIJNg1TTTmCXwjgO70vVtmmYgb2TW5yLYJozEd+CIAfc2sl5m1ADAOwMzSNEs0IbJrcpFtE0aDXSghhL1m9j8AngTQDMC9IYSVJWuZaBLSZNdmzZq5Hjp0aNa64447Lu92+/btK3u7ykWabJsWGuUDDyHMAjCrRG0RVYLsmlxk22ShTEwhhIiUskehCFGtHHrogcv/rLPOylr36aefut6zZ4/rL7/8svwNE6JI9AQuhBCRog5cCCEiRS4UkSrYbXLssce6Pu+887K2mz9/vuv33nuv/A0TogHoCVwIISJFHbgQQkSKXChFwIkcgwcPzlrXtm1b12+++abrzZs3u/7888/L2DpxMLC9zj//fNfdunXL2u65555zvWnTpvI3TDQpXBenRYsWrlu3bu26VatWrj/66CPXfE1t27bNdSXuez2BCyFEpKgDF0KISFEHLoQQkSIfeBG0bNnS9Y9//OOsdaeddprr3//+9665hjT7xUTl4dDB7t0PVFMdP36866VLl2bts3btWteffPJJ+RonSgLbmH3V7du3z9qOfd1MmzZtXB999NGue/bs6bpz586u168/ML9Cr169XD/00EN5twHKUz9eT+BCCBEp6sCFECJS5EIpAi5gtH379qx1nTodmGeUQ9EOP/zw8jdMFEXHjh1dc8bloEGDXN99991Z+3AxK1GdsNuE3R5cmOzyyy/P2qd58+Z5f2vAgAGu+/Tpk3f7Qu4XhvuKf/zjH1nrNm4s/eRHegIXQohIUQcuhBCRIhdKEfCw6N13381ax9lWI0aMcM2ZfG+//Xb5GifqhadHGzVqlOunn37a9V133ZW1T66rTFQf7N7o16+f69/97neu2cUJAIcckv+Zle/xzz77zDVnXPLxOPuSYZcNF0QD5EIRQghBqAMXQohISYQLpdDs4jxcfvzxx10vX77cdTGzjO/du9d17rBo0qRJrk8//XTXJ5xwgmseqpcjmF/8Nx06dHDN0Sb9+/d3fdttt7l+//33s/aPefb5tMDuEE6246Scd955J2ufxYsXu169erXrt956K+8+H3zwgeu+ffu6/s1vfuOak31CCMWfQAnQE7gQQkRKvR24md1rZtvMbAUt62Bmc8xsbebf9nX9hqg+ZNfkItumh2JcKPcBuBPA32jZZABPhRCmmNnkzPebSt+84mAXysCBA11fccUVro8//njXnLSxcOHCen+f31Dn1jfg6bZ69OjhmodbvPyNN96o93gV4j5UuV0bA7uzLr74Ytf892fXVsJqtt+HBNt2P5xsxbVsrr/+etc7d+7M2ofr9LN7hKNN9uzZ45r7li5durjmGiuc4PPMM8+43rJ
lS32n0GjqfQIPITwHYEfO4jEApmX0NABjS9ssUW5k1+Qi26aHhr7E7BxC2P9f2RYAnQttaGaTAEwqtF5UFbJrcinKtrJrXDQ6CiWEEMys4KvXEMJUAFMBoK7tSgVHD3AQ/0UXXeSak2yKcaHwm+UdO7IfbHiYxAkARxxxhOtCQf/VTLXZtRh4lnmuh8F1Mv7+97+73rBhg2t2kyWdumxbjXYtBN/rfB/OmDHDNUeQ5fueD55SjV2h3Ifw/c0uFI5mqkQ9nYZGoWw1sy4AkPlXBa+TgeyaXGTbBNLQDnwmgAkZPQHAY6VpjmhiZNfkItsmkHpdKGb2AICRADqa2QYAtwCYAmC6mU0EsB7AVeVsZH3wUGrVqlWud+3a5ZqHQr1793bNpUaLqX/BbhIA2L17t2senh122GF5dbUQg10PFi4VO2TIENcrVng0HWbPnu26lMk6HK3ApYR5VnMeXpczoSuJtq0PdnM21nXRrl0718OGDXP91a9+1TW7TTiahWslVWImp3o78BDC+AKrRhVYLiJAdk0usm16UCamEEJESiJqoXAEwcqVK12/9tprrnm2HE72OeWUU1zPmzevZG065phj8mpRWo466ijXXPvmyCOPdP2Xv/zFdWMTqbj+BidzcKIY11vhCJi5c+e65mtTs/80PTxhMbtNzj33XNc8Uw+7S1966SXXL7/8smt2rZQLPYELIUSkqAMXQohISYQLhd9Af/jhh655yHrSSSe5Pu2001xzzQyuY1BsWUiOZGBXDpeTPfHEE13PmjXLtUrLNgyezHbkyJGu2U22ZMkS1wsWLGjU8Xgmlu7du7vmqJexY8e65mF3q1atXLPL5Q9/+INruVCaBk6wGz16tOtx48a55lLEXC+Hy89OmzbNdaVnctITuBBCRIo6cCGEiJREuFAK8cILL7jmYREPcTl6gJMueFhbV50MLlfJQyyuy8HuFJ4pZuvWrXW2X+SHI0/Gjz8Q8syzsrz44ouuGzKs5UQNTvz6yU9+4vqqqw7kwnDyDl8vHK3A2xSaXDfNcMIbu8nYFgy7OfnvzPdhriuUXVpf+cpXXH/72992fc455+T9XY4cuvXWW13PnDkz77Erga4iIYSIFHXgQggRKYl2ofDkpFwGlodFl156qWse1j7yyCOuub5BLuwG4Vk9mK5du7rmiBS5UBoGD4M5YefJJ590zcPahsCJHZMnT3Y9ZswY1+wq4QQOHmqffPLJrjVR8n/DNWS+//3vu+YIH47eYThR5vnnn3f9t78dmIgod7Lq6667zjXP3MP3JcOJgVOmTHH92GMHaoE1ZTSZnsCFECJS1IELIUSkqAMXQohISbQPnLMyf/3rX7vmmeQvueQS19/4xjdcX3jhha7rCg3iqZU4RJBh3zr7/ETD4OJgHJrH9Zc//vjjRh2Di5xxcSr2uz744IOu+Z0J157nkNWNGze6TlsWLr+3GDx4sOuf/vSnrjlDmv3ehe4ZfqdwxhlnuP7Wt77lOnfKxPPPP991z549XfO1wyGo7E+vxixqPYELIUSkqAMXQohISbQLhcO8li9f7vquu+5yzdOucVZfjx49XDfW7cEuFM4wEw1j6NChrjt16tTg32G7sNsDACZOnOiaCxdxbXGeqo33v/rqq12zu27+/PmueSq+JJJrF57R/Xvf+55rdlW9+uqrrpcuXeqa72MOyWXXFrvVOASUM6KBbNcM34vLli1zze4wDk3lvqJa0BO4EEJEijpwIYSIlNSM53k2eR76/vWvf3W9Zs0a11zAiGepzoXdKyNGjHDNRbJ4ONmrV6+DaLXIB/9tuaZzsTXc98O2GzBgQNY6dolwJANHPlxwwQWuObqBizLdcccdrjlDs9JFjyoBR2GdffbZWeuuueYa1+yevP/++13PmTPHNWdOd+nSxTXPDJ/r9toPF78q1sXGUTIcncJuGnbDcoRbU6IncCGEiJR6O3Az625m88xslZmtNLMbMss7mNkcM1ub+Td/wQJ
RlciuyUR2TRfFuFD2AvjfEMJiM2sL4BUzmwPgOgBPhRCmmNlkAJMB3FS+ppYOLmbFxWrWrVvnmmtOc7JOLjwMZ82zkfMQkIvm8FRdTZAYEK1duVY725Lrrg8ZMsQ1D8c58qRNmzauuagZkG3/gQMH5j0Gu0o42mT69OmuOaKB3XhlpKJ2bdGihWt2Q3GCHJB9D9x7772uORmK7cp/Z3ZjHHfcca65/jtfB1yjn7cBsu3P9mO3CU+Pxy46TuKKxoUSQtgcQlic0bsBrAbQFcAYAPsng5sGYGyZ2ijKgOyaTGTXdHFQLzHN7DgAgwEsBNA5hLA5s2oLgM4F9pkEYFIj2ijKjOyaTGTX5FN0B25mbQA8DODGEMIuftMbQghmljcEIIQwFcDUzG8cXJhAheEh7qZNm/Lquli7dq1rHmLxbOkc3cLDNnbfVJIY7cr12bnmyahRo1yz24u35yE01xLnYXPuOna18G9xYs6jjz7qmqNWKuQ2+S8qZdeOHTu6Hj58uGuuTQIA27Ztc821tDnCiyNX2AXDtVM4EYcjedavX+/6mWeecZ07dR27OdmdyYlAbG92/XCkSrVQVBSKmTVH7cVwfwhhv1Nvq5l1yazvAmBbof1FdSK7JhPZNT0UE4ViAO4BsDqE8FtaNRPAhIyeAOCx3H1F9SK7JhPZNV0U40IZAeBaAMvNbGlm2c8BTAEw3cwmAlgP4Kr8u6cHdrXwG+tCbhOu6dEELpRo7crDcdYcedKvX7+8+3JdDXa/5CbWcFIXT5fGU7W98MILedtxsAlFJaaidu3evbvr008/3TVf80B2hMlll13mmhOg2FXCLiy2GU+Rxi7LuXPnur7nnntc87SKQHaSD7eDS9nyNqtWrXJd19SKTUW9HXgI4XkAVmD1qALLRZUjuyYT2TVdKBNTCCEiJTW1UCrBU0895ZrrZPDQkodnnMgjiof/zpwwxeVBOfmD4ZlXuCYOD8eB7ASTJUuWuE56GdiDhaNQuBZKbsTGsGHDXLOri91N7Crh6B12T3G0D5f2ZZdXsXVmZsyYkVfHhJ7AhRAiUtSBCyFEpMiFUkK2bNnimod9PKTj2gpc10EUD0f48DCaJ6DlhI1CUSFcGpY1kF2bhutsiGyeffZZ11w/hhNjgGyX1oYNG1zzfcL3zyuvvOJ69uzZrrl2EUe25NovLegJXAghIkUduBBCRIpcKCWEh9qvv/66a35D3qdPH9c80eusWbNcc9RDWoeGdcEuEY5WaKq6I2lmz549rh9//HHXixcvztru8MMPd80JVOyqYvvx73LkD7tNhJ7AhRAiWtSBCyFEpMiFUia4hsLq1atdn3zyya65zCzXgXj11Vddy4Uiqhl2Z/HMRxwpVNc+xSwXhdETuBBCRIo6cCGEiBR14EIIESnygZcJDiPksMBzzjnH9ebNm11v3brVtXyBIkb4utU1XBn0BC6EEJGiDlwIISJFLpQywVM/zZkzxzXXPN6+fbvrmpoa1yqeJIQoBj2BCyFEpKgDF0KISLFKvi02M72arhJCCIUmvj1oZNfqQXZNLK+EEIbkLtQTuBBCREq9HbiZtTSzl81smZmtNLNbM8t7mdlCM1tnZg+aWYvyN1eUCtk1mciuKSOEUOcHgAFok9HNASwEMBzAdADjMsv/BOD6In4r6FM1H9k1mR/ZNZmfmnw2qvcJPNTyUeZr88wnALgAwP9llk8DMLa+3xLVg+yaTGTXdFGUD9zMmpnZUgDbAMwB8AaAnSGE/QHLGwB0LbDvJDOrMbOafOtF0yG7JhPZNT0U1YGHEPaFEAYB6AZgGID+de+Rte/UEMKQfG9QRdMiuyYT2TU9HFQUSghhJ4B5AM4E0M7M9mdydgOwsbRNE5VCdk0msmvyKSYKpZOZtcvoVgAuBLAatRfGlZnNJgB4rExtFGVAdk0msmu6qDeRx8xORe1Lj2ao7fCnhxB+aWa9AfwTQAcASwB8K4RQ57TgZvYegD0Atte1XULpiOo
5754ARqG0dl2P6jrHSlFN5yy7lo5qO+eeIYROuQsrmokJAGZWk0b/WhrOOw3nmEsazjkN55hLLOesTEwhhIgUdeBCCBEpTdGBT22CY1YDaTjvNJxjLmk45zScYy5RnHPFfeBCCCFKg1woQggRKerAhRAiUiragZvZxWb2eqak5eRKHrtSmFl3M5tnZqsy5TxvyCzvYGZzzGxt5t/2Td3WUpEGuwLps63sWv12rZgP3MyaAViD2sywDQAWARgfQlhVkQZUCDPrAqBLCGGxmbUF8ApqK79dB2BHCGFK5mZoH0K4qelaWhrSYlcgXbaVXeOwayWfwIcBWBdCeDOE8Dlqs8LGVPD4FSGEsDmEsDijd6M2jbkras91WmazJJXzTIVdgdTZVnaNwK6V7MC7AniXvhcsaZkUzOw4AINRW1S/cwhhc2bVFgCdm6pdJSZ1dgVSYVvZNQK76iVmmTCzNgAeBnBjCGEXrwu1fivFb0aKbJtMYrRrJTvwjQC60/fElrQ0s+aovRDuDyE8klm8NeNr2+9z29ZU7SsxqbErkCrbyq4R2LWSHfgiAH0zk6u2ADAOwMwKHr8imJkBuAfA6hDCb2nVTNSW8QSSVc4zFXYFUmdb2TUCu1Y0E9PMRgO4HbWlLu8NIfy6YgevEGZ2NoD5AJYD+DKz+Oeo9alNB9ADtSU6rwoh7GiSRpaYNNgVSJ9tZdfqt6tS6YUQIlL0ElMIISJFHbgQQkRKozrwtKTapg3ZNbnItgkjhNCgD2pfbLwBoDeAFgCWARhQzz5Bn+r4yK7J/JTynm3qc9En6/NePhs15gk8Nam2KUN2TS6ybbysz7ewMR14Uam2ZjbJzGrMrKYRxxKVQ3ZNLvXaVnaNi0PLfYAQwlRkpicys1Du44nKILsmE9k1LhrzBJ6qVNsUIbsmF9k2YTSmA09Nqm3KkF2Ti2ybMBrsQgkh7DWz/wHwJA6k2q4sWctEkyC7JhfZNnlUuhaKfGpVQgjBSvVbsmv1ILsmlldCCENyFyoTUwghIkUduBBCRIo6cCGEiBR14EIIESllT+QRQoim4tBDs7u4c845x3Xfvn1df/DBB66XLFniet26dWVsXePRE7gQQkSKOnAhhIiUxLlQWrdu7bpnz56uu3Xr5rpVq1YH9ZscK79jR+Ep8Xbu3Ol69+7drnl4tmvXroM6tigttfPX/reui5YtW7o+6qijXHfo0MF18+bNXX/66aeuP//8c9fr1x8oKPfZZ58V2WJxsDRr1sz1oEGDstbdcMMNrkeOHOn6nXfecf3HP/7R9YYNG1yzXasFPYELIUSkqAMXQohIUQcuhBCRkggfOIcKDRs2zPXVV1/tevTo0a67dOlyUL+/b98+1ytWrMhat3fvXtevvfaa6zfeeMP1woULXT///POu2U8uGgb7sQ855MDzCL8Lad++vesjjzzSdYsWLYo6RqdOnVyfdtpprgcOHOi6bdu2rvmdx/vvv+/6nnvucc3XSjX6VmPmsMMOcz1x4sSsdRxGyNdCnz59XLONH3/8cdfsD68W9AQuhBCRog5cCCEiJREuFA7zuvnmm11zCBG7WTiU72DL6fbq1SvrOw/bTznllLzLFyxYkLet//73v11/8cUXB9UOUQu7Ljp27OiaXWljx451PWLECNccWtpYCl1HHEZ4xBFHuP7FL37hmkPYRONhW+TeV4XsxPclX1O8vBrRE7gQQkSKOnAhhIiURLhQOBKEozw48+2FF15wvXTpUteNzYw86aSTXF9++eWuOctr+PDhrtmVw26Whx9+uFHtSBMcYXLLLbe45r8/u1M42iS3uFExsBuEbfbll1/Wu83HH3/smqOfDj/88Lzb5/6uOHi4P3juueey1l122WWuOauWI8K2bdvmutozp/UELoQQkaIOXAghIiURLhQuDHTnnXe65qHzRx995JqHtZyk0xA2btzoevny5a65sNX48eNdH3/88a7ZzSIXSvGceeaZrjnahKNKCrlK+Fr
hZCtO2ACyoxc4SoRdH3xNvfXWW67btGnjeu3ata45YYeTQuQyKS18319xxRVZ6zipi+F+gN1h7I6pRvQELoQQkVJvB25m95rZNjNbQcs6mNkcM1ub+Tf/f2uiapFdk4tsmx6KcaHcB+BOAH+jZZMBPBVCmGJmkzPfbyp984qDg/O3bt1a0WPz22t+q81DMm4fu1Zy66pUmPtQ5XYtRP/+/V0fffTRrtltwm6JTZs2uZ49e7ZrTqRi9xeQPXTes2dP3mPwNnwdcG3wJqp3cx8itW1D4Zo47ObKTdZi2zBcm4anVKv2ekX1PoGHEJ4DkDuLwRgA0zJ6GoCxpW2WKDeya3KRbdNDQ19idg4hbM7oLQA6F9rQzCYBmNTA44jKIrsml6JsK7vGRaOjUEIIwcwKFhQJIUwFMBUA6tqumuHhWdeuXbPWnX322a7POuss1+eee65rdqFw5MOyZctK2s5SUm125cgCrjnDUQUcPcB/55dfftk1zzLOU2+xmwTILgnLkSvF1M6p9vKwddk21vuV3SaDBw923b1796zt2IXCLrDVq1e7XrVqletqr1HU0CiUrWbWBQAy/26rZ3sRB7JrcpFtE0hDO/CZACZk9AQAj5WmOaKJkV2Ti2ybQOp1oZjZAwBGAuhoZhsA3AJgCoDpZjYRwHoAV5WzkZWCh9Rc+vOEE05w/fWvfz1rH07G6devn2sezr/44ouuOWGHkz8qTWx25bKe/HdmOzHs9uLZ43lGlvPOO8/1kCFDsvZ//fXXXXNUAtuMIxSqKRknNtuWAi4Be8EFF7jmKCUgO4qoUBJeTOV96+3AQwjjC6waVeK2iAoiuyYX2TY9KBNTCCEiJRG1UBoDD8F79+7tmt9kn3/++a5zayvwUJ2HZOw2mT59uuu5c+e6rvZohaaG/7bs3uKEKZ7AluFkH9aFuOiii7K+cw2TZ555xjXbr6amxjXbnqNWRPng64Pv46FDh7quq3zwokWLXLMtOQKp2tETuBBCRIo6cCGEiJRUulB4BpQBAwa4/s53vuP6kksucc0zqbz77rtZv/X222+7njVrlusnnnjC9Zo1a1xzdIooHo7y4CEuz5jCduUEDE7SKfbvz1Ev11xzjWuOVpkxY4ZrLke7cuVK13KTlQ9O3mH3Z9++fV2z6w3Idrs8/fTTrjl5Jyb0BC6EEJGiDlwIISIlNS4UHjq1a9fONc/uwm6TY4891jXPvDJt2jQwd911l2ueDFU0Hq47wrMoPfroo665fPCRRx7pmhNuXnrpJdc8E05dsP3ZtcYRDjfddKAaK7tWbr75Ztc8gbZoPHwf9+rVy/U3v/lN17klZBm+ptgdumNHbvHGONATuBBCRIo6cCGEiJTUuFC47Oill17q+mtf+5prjjZht8m8efNc33HHHVm/G+vQKzY4qoTdVlOnTnXNw2uOWuGyocXWLOGIFr4WbrzxRtcjRoxwzSVu+ZqSC6W0cMIOl2/micPThJ7AhRAiUtSBCyFEpKgDF0KISEmND5xDB7kmNId/sd+zdevWrjl0jH2dALBw4ULXn3zySUnaKuqGi0VVonDUggULXPN1xAwaNMj18OHDXfO7l507d7ouZmq2NMO+bq65P2bMGNd8H3NRszT9bfUELoQQkaIOXAghIiU1LpT33nvP9QMPPOCai91w+BfXnO7UqZPrX/3qV1m/u2XLFtd33nmna87+U33ouNm+fbtrLlTFU29xSBvPhM7ZgsuWLXO9b9++krczSVx88cWur732WteF3FMMh5PmwrXd+d6NFT2BCyFEpKgDF0KISEmNC4VnEOfpzng51/bmGct5lmsufgVkD4XZHcNZnfPnz3fNU2+JOOBMzvXr17vmade45niHDh1cc7Ymu1/kQqmbnj17uu7Xr59rdm0WcpXUFYXCBazYZrFS7xO4mXU3s3lmtsrMVprZDZnlHcxsjpmtzfyb3yElqhLZNZnIrumiGBfKXgD/G0IYAGA4gB+Y2QAAkwE8FUL
oC+CpzHcRD7JrMpFdU0S9LpQQwmYAmzN6t5mtBtAVwBgAIzObTQPwDICb8vxE1cHTXHGxoU2bNrlesWKFa3Z7cJ1oADjhhBNcc8IBD6N52Pef//zHNdemrnSkSmx2bdGihWt2aa1bt841R4WUa+o6jkhh+7ErjmdCL2bIX0pis2sxsEuE7xNOjOJ7mt0vdf1WEhJ+DsoHbmbHARgMYCGAzpmLBQC2AOhcYJ9JACY1oo2izMiuyUR2TT5FR6GYWRsADwO4MYSwi9eF2v/K8v53FkKYGkIYEkIYkm+9aFpk12Qiu6aDop7Azaw5ai+G+0MIj2QWbzWzLiGEzWbWBUCU84lxhAG7UHiW8ZqaGtf9+/fP2p+H85dffrnrYcOGuW7btq3rjh07umZ3CtdUqRTVbld2ObRq1cr1D37wA9cc4fPII4+4Lpc7hWtu8Kzo3Nb333/f9eLFi11XKvKk2u1aDOye4uidN998M+82nGxXlwslaRQThWIA7gGwOoTwW1o1E8CEjJ4A4LHSN0+UC9k1mciu6aKYJ/ARAK4FsNzMlmaW/RzAFADTzWwigPUAripLC0W5kF2TieyaIoqJQnkeQKHX56NK25zqgV0rHHnAtRQAYNGiRa55VvQrr7zSNZca5QQfruVQaRdKzHZl18Xo0aNd87Rr//rXv1xzpEpD4AQtjjoaOHCga3atrF692jVfL9y+chGzXRmexpBdm5x8w66x7373u0X9Ll87HC0UK0qlF0KISFEHLoQQkRL/GKKE8Iw8nIjDs/Bs3bo1ax+un8JRJX379nV97rnnuuahNid58LGLnTk96XCiBSfK/PnPf3Z9000HclF+9KMfuW7ZsqXr22+/3TUnfBSbyMF24hmcBgwY4JqH9qtWrXK9a1dWBJ8oEi71WqjsK0cmcdTYD3/4w4K/y/cy3+OxoidwIYSIFHXgQggRKXKhEDyTyoQJE1zz0JzLiQLAz372M9cnnXSSax6esUuEh/Z8vN69e7vmZAW5U2rhv8OTTz7petSoA4EVY8eOdX399de7PuaYY1zfdtttrjm6oS64/s2FF17omt0pmzdvhqgO6nKNHX/88a579OjhmqNTPv744/I0rAzoCVwIISJFHbgQQkRK6l0offr0cc21TMaMGeP6iCOOcH3iiSdm7c8TGbN7pF27dq45EYSTSh566CHXPFOI3CZ18+GHH7qePXu2a44K4Vo048aNc81urjlz5rjOdYGwDdhNc8YZZ7jmRBBuU2MTh0RxcG0ZjgLatu1AmReukQJk34unn366a07I4yiiakdP4EIIESnqwIUQIlJS70LhCIPhw4e75poXXCqUZ4YBgF69euX9XR6C8/Cca2PMmDHDdaVn5IkZ/ts+++yzrtltxYlR7PY466yzXLPt6oo8OProo11z/RouWTt37lzXPGm2KB9cW4ZLyz7xxBOux48fn7UP37/s6spN0IsFPYELIUSkqAMXQohIUQcuhBCRknofOBcb4tnnecbr1q1bF/VbXDecZ7vn2sYcusbHEw2DQ8ZmzZrlmm0xcuRI14MHD3bNYYe57zL4vQdfIzy9F/u9H330Ude52bqiPHDGJfuw7777btc8HSKQHfrJxec++OCDcjSx7OgJXAghIkUduBBCRIoVWxO5JAczq9zBioSLTp166qmuhw4d6ppDDXloDWQP43jYvmzZMtc8M3mxBZTKTQih0LRbB0012pXrrvM0duxCKVR8DMgOQ+QsSy40tmTJEtdr1qxxzTXHK03S7ZpiXgkhDMldqCdwIYSIFHXgQggRKal3oaQVDbWTieyaWBrmQjGzlmb2spktM7OVZnZrZnkvM1toZuvM7EEza1Hfb4nqQXZNJrJryggh1PkBYADaZHRzAAsBDAcwHcC4zPI/Abi+iN8K+lTNR3ZN5kd2TeanJp+N6n0CD7V8lPnaPPMJAC4A8H+Z5dMAjK3vt0T1ILsmE9k1XRT1EtPMmpnZUgDbAMwB8AaAnSGEvZl
NNgDoWmDfSWZWY2Y1+daLpkN2TSaya3ooqgMPIewLIQwC0A3AMAD9iz1ACGFqCGFIPge8aFpk12Qiu6aHgwojDCHsBDAPwJkA2pnZ/sIC3QCosEekyK7JRHZNPsVEoXQys3YZ3QrAhQBWo/bCuDKz2QQAj5WpjaIMyK7JRHZNGUW8iT4VwBIArwJYAeDmzPLeAF4GsA7AQwAO01vtqD6yazI/smsyP3mjUCqdyPMegD0Atte3bQLpiOo5754hhE71b1YcGbuuR3WdY6WopnOWXUtHtZ1zXttWtAMHADOrSeMLkjScdxrOMZc0nHMazjGXWM5ZtVCEECJS1IELIUSkNEUHPrUJjlkNpOG803COuaThnNNwjrlEcc4V94ELIYQoDXKhCCFEpKgDF0KISKloB25mF5vZ65maxJMreexKYWbdzWyema3K1GO+IbO8g5nNMbO1mX/bN3VbS0Ua7Aqkz7aya/XbtWI+cDNrBmANalN7NwBYBGB8CGFVRRpQIcysC4AuIYTFZtYWwCuoLd15HYAdIYQpmZuhfQjhpqZraWlIi12BdNlWdo3DrpV8Ah8GYF0I4c0QwucA/glgTAWPXxFCCJtDCIszejdq61B0Re25TstslqR6zKmwK5A628quEdi1kh14VwDv0veCNYmTgpkdB2AwamdF6RxC2JxZtQVA56ZqV4lJnV2BVNhWdo3ArnqJWSbMrA2AhwHcGELYxetCrd9K8ZuRItsmkxjtWskOfCOA7vQ9sTWJzaw5ai+E+0MIj2QWb8342vb73LY1VftKTGrsCqTKtrJrBHatZAe+CEDfzOzYLQCMAzCzgsevCGZmAO4BsDqE8FtaNRO1dZiBZNVjToVdgdTZVnaNwK6VLic7GsDtAJoBuDeE8OuKHbxCmNnZAOYDWA7gy8zin6PWpzYdQA/Ului8KoSwo0kaWWLSYFcgfbaVXavfrkqlF0KISNFLTCGEiBR14EIIESnqwIUQIlLUgQshRKSoAxdCiEhRBy6EEJGiDlwIISLl/wFg7lq93NT4EgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"[4 6 2 3 5 1]\", Actual: \"[4 6 2 3 5 1]\"\n" ] } ], "source": [ "import numpy as np\n", "import mindspore as ms\n", "import matplotlib.pyplot as plt\n", "\n", "mnist = Mnist(\"./mnist\", split=\"train\", batch_size=6, resize=32)\n", "dataset_infer = mnist.run()\n", "ds_test = dataset_infer.create_dict_iterator()\n", "data = next(ds_test)\n", "images = data[\"image\"].asnumpy()\n", "labels = data[\"label\"].asnumpy()\n", "\n", "plt.figure()\n", "for i in range(1, 7):\n", "    plt.subplot(2, 3, i)\n", "    plt.imshow(images[i-1][0], interpolation=\"None\", cmap=\"gray\")\n", "plt.show()\n", "\n", "# Use model.predict to predict the class of each image\n", "output = model.predict(ms.Tensor(data['image']))\n", "predicted = np.argmax(output.asnumpy(), axis=1)\n", "\n", "# Print the predicted and actual classes\n", "print(f'Predicted: \"{predicted}\", Actual: \"{labels}\"')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The printed output above shows that the predicted values match the target values exactly." ] } ], "metadata": { "interpreter": { "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" }, "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }