{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "[![Run online](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svcjIuMC4wLWFscGhhL3R1dG9yaWFscy96aF9jbi9iZWdpbm5lci9taW5kc3BvcmVfZGF0YXNldC5pcHluYg==&imageid=77ef960a-bd26-4de4-9695-5b85a786fb90) [![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0.0-alpha/tutorials/zh_cn/beginner/mindspore_dataset.ipynb) [![Download sample code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0.0-alpha/tutorials/zh_cn/beginner/mindspore_dataset.py) [![View source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r2.0.0-alpha/tutorials/source_zh_cn/beginner/dataset.ipynb)\n", "\n", "[Introduction](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/introduction.html) || [Quick Start](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/quick_start.html) || [Tensor](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/tensor.html) || **Dataset** || [Transforms](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/transforms.html) || [Building a Network](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/model.html) || [Functional Automatic Differentiation](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/autograd.html) || [Model Training](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/train.html) || 
[Saving and Loading](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/save_load.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Dataset\n", "\n", "Data is the foundation of deep learning, and high-quality input data benefits the entire deep neural network. MindSpore provides a pipeline-based [data engine](https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/design/data_engine.html) that achieves efficient data preprocessing through [Dataset](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/dataset.html) and [Transforms](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/transforms.html). Dataset is the start of the pipeline and loads the raw data. `mindspore.dataset` provides built-in loading interfaces for text, image, audio, and other datasets, as well as interfaces for loading custom datasets.\n", "\n", "In addition, MindSpore's domain development libraries provide a large number of preloaded datasets that can be downloaded and used with a single API call. This tutorial describes the different dataset loading methods, common dataset operations, and custom dataset methods in detail." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from mindspore.dataset import vision\n", "from mindspore.dataset import MnistDataset, GeneratorDataset\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading a Dataset\n", "\n", "We use the **Mnist** dataset as an example to show how to load data with `mindspore.dataset`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The interfaces provided by `mindspore.dataset` **only support decompressed data files**, so we use the `download` library to download the dataset and extract it." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)\n", "\n", "file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:02<00:00, 3.96MB/s]\n", "Extracting zip file...\n", "Successfully downloaded / unzipped to ./\n" ] } ], "source": [ "# Download data from open datasets\n", "from download import download\n", "\n", "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/\" \\\n", "      \"notebook/datasets/MNIST_Data.zip\"\n", "path = download(url, \"./\", kind=\"zip\", replace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"压缩文件删除后,直接加载,可以看到其数据类型为MnistDataset。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "train_dataset = MnistDataset(\"MNIST_Data/train\", shuffle=False)\n", "print(type(train_dataset))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 数据集迭代\n", "\n", "数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。我们可以用`create_tuple_iterator`或`create_dict_iterator`接口创建数据迭代器,迭代访问数据。\n", "\n", "访问的数据类型默认为`Tensor`;若设置`output_numpy=True`,访问的数据类型为`Numpy`。\n", "\n", "下面定义一个可视化函数,迭代9张图片进行展示。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def visualize(dataset):\n", " figure = plt.figure(figsize=(4, 4))\n", " cols, rows = 3, 3\n", "\n", " for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):\n", " figure.add_subplot(rows, cols, idx + 1)\n", " plt.title(int(label))\n", " plt.axis(\"off\")\n", " plt.imshow(image.asnumpy().squeeze(), cmap=\"gray\")\n", " if idx == cols * rows - 1:\n", " break\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAUkAAAFeCAYAAAAIWe2LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhoUlEQVR4nO3de5xV8/7H8c/UUU1qpoukOExKugmRNKczDVK605VTKidyxFSOO0kuRcJDhORSSueE6HqEYlJUYzrknGRMwqimOzWjmhlN/f44Px/ftZr9bc/sy9qX1/PxmMfjvWfttfcnuz7Wd+21vt+Eo0ePHhUAQJkqeV0AAEQymiQAWNAkAcCCJgkAFjRJALCgSQKABU0SACxokgBgQZMEAAuaJABYxFSTXLFihSQkJJT5s3btWq/Lg0VxcbHcdddd0rBhQ0lMTJR27drJsmXLvC4L5TBhwgRJSEiQVq1aeV1KUP3B6wJCYdSoUdK2bVvH75o0aeJRNfDHsGHDZN68eTJmzBg566yzZObMmdKtWzfJzMyUDh06eF0ejmPr1q0yceJEOfHEE70uJegSYmmCixUrVsgll1wib731lvTr18/rcuCnzz77TNq1ayeTJ0+W22+/XUREioqKpFWrVnLyySfL6tWrPa4Qx3P11VfL7t27pbS0VPbs2SMbNmzwuqSgianhtqmwsFAOHz7sdRnww7x586Ry5coyYsQI/V21atVk+PDhsmbNGtmyZYuH1eF4Vq5cKfPmzZOnn37a61JCIiab5HXXXSdJSUlSrVo1ueSSS2TdunVelwSLL774Qpo2bSpJSUmO31900UUiIrJ+/XoPqoI/SktLJSMjQ66//no555xzvC4nJGLqnGSVKlWkb9++0q1bNznppJNk48aN8sQTT8if//xnWb16tZx//vlel4gybN++XRo0aHDM73/7XX5+frhLgp+mTZsmeXl5snz5cq9LCZmYapKpqamSmpqqj3v16iX9+vWT1q1byz333CPvvfeeh9XBl0OHDknVqlWP+X21atV0OyLP3r17Zdy4cXL//fdLvXr1vC4nZGJyuG1q0qSJ9O7dWzIzM6W0tNTrclCGxMREKS4uPub3RUVFuh2RZ+zYsVKnTh3JyMjwupSQiqkjSV/++Mc/SklJiRw4cOCY817wXoMGDWTbtm3H/H779u0iItKwYcNwl4Tj2LRpk0yfPl2efvppx+mQoqIi+fXXX+WHH36QpKQkqVOnjodVBkfMH0mKiHz33XdSrVo1qVGjhteloAznnXee5ObmSkFBgeP3WVlZuh2RZdu2bXLkyBEZNWqUNGrUSH+ysrIkNzdXGjVqJA899JDXZQZFTF0nuXv37mPOjXz55ZfStm1b6dq1qyxcuNCjymCTlZUlF198seM6yeLiYmnVqpXUrVuXu6Ui0J49e+STTz455vdjx46VwsJCmTJlijRu3DgmvvGOqSZ56aWXSmJioqSmpsrJJ58sGzdulOnTp8sJJ5wga9askebNm3tdInwYMGCAzJ8/X2699VZp0qSJvPbaa/LZZ5/Jhx9+KGlpaV6XBz+lp6fH3MXkMXVO8sorr5Q5c+bIU089JQUFBVKvXj3p06ePPPDAA9yWGOFmzZol999/v8yePVt+/vlnad26tSxZsoQGCc/F1JEkAARbXHxxAwAVRZMEAAuaJABY0CQBwIImCQAWNEkAsKBJAoCF3xeTJyQkhLIOGIJ16SqfWfgE83JjPrfw8edz40gSACxokgBgQZMEAAuaJABY0CQBwIImCQAWNEkAsKBJAoAFTRIALGiSAGBBkwQAC5okAFjE1GqJiD0XXHCB5ltuuUXzkCFDNM+aNcuxz7PPPqv5888/D2F1iAccSQKABU0SACz8Xnc7kue4q1y5subk5GS/9jGHbtWrV9d89tlnO5538803a37iiSc0X3PNNZqLiooc+zz22GOaH3zwQb/qMcXzfJLnnXee4/FHH32kOSkpya/X2L9/v+a6des
Gpa7jYT7J4Lrssss0z5kzR3PHjh0dz/vmm28Ceh/mkwSAANEkAcAi4r7dPv300zVXqVJFc2pqquN5HTp00FyrVi3Nffv2Dej9t27d6nj8zDPPaL7qqqs0FxYWav7yyy8d+3z88ccB1RBvLrroIs1vv/22Y5t5+sQcGpn//UtKShz7mEPsiy++WLP5Tbd7n2iXlpam2fzzz58/34tyAta2bVvN2dnZHlbCkSQAWNEkAcCCJgkAFhFxTtK87MO85MPfy3kCdeTIEc1jx451bPvll180m5cibN++XfPPP//s2CfQyxJilXmpVZs2bTS//vrrmhs0aODXa23atEnz448/7tg2d+5czZ9++qlm87N99NFH/XqfaJGenq75rLPO0hxN5yQrVfr9mK1Ro0aazzjjDM1eXB7FkSQAWNAkAcAiIobbP/74o+a9e/dqDnS4nZWV5Xi8b98+zZdccolm83KQ2bNnB/Se8O3FF1/UbN6xVBHmcL1GjRqObeYlWOYwtHXr1gG9ZyQzJ/xYs2aNh5VUnHmq5YYbbtBsno7JyckJa00iHEkCgBVNEgAsImK4/dNPP2m+4447NPfo0UPzF1984djHvBPGtH79es2XX365Y9uBAwc0t2zZUvPo0aPLVzD8Ys4FKSLSvXt3zb6+pXTfrbR48WLN5gQj+fn5mt1/N8yrDS699NLjvmcsML8ZjlYvv/xymb83r2TwQvT/lwWAEKJJAoBFRAy3TQsWLNBsXlhuTmggInLuuedqHj58uGZzSGYOr92++uorzSNGjKhQrTiWeWPAsmXLHNvM+SDNySqWLl2q2f2ttzl/oHkxuDk02717t2Mfc8IR80YBc7hvfjsuEp3LPJjf1tevX9/DSoLD19Us7r9H4caRJABY0CQBwIImCQAWEXdO0lRQUOBzm7mOicm8Uv+NN95wbDPPTyF4mjZtqtm8hMt9jmnPnj2azQlCXnvtNc3mhCIiIv/617/KzBWRmJio+bbbbnNsGzRoUECv7YVu3bppNv9s0cJ9HtWc1MK0bdu2cJTjE0eSAGBBkwQAi4gebtuMHz9es3lnh3nJSKdOnRz7fPDBByGvK15UrVpVs3nZlTkEdF+2ZU7CsG7dOs1eDBXNtZSilXv549+Yl7dFMvPvjYhz+J2bm6vZ/fco3DiSBAALmiQAWETtcNu8m8b8Rtu8c+Kll15y7JOZmanZHO4999xzms07QeDb+eefr9kcYpt69+7teMxSu+Hh9RKsIs67q6644grNgwcP1ty5c2ef+z/88MOazXlgvcCRJABY0CQBwCJqh9umzZs3ax42bJjmGTNmOJ537bXXlplPPPFEzbNmzXLsY170jN899dRTms15Gs0hdSQMr815FuPlZoI6deqUex9zwhj3vJvmVSKnnXaa5ipVqmh2X4xv/nc/dOiQZnNJleLiYsc+f/jD7+3o3//+t9+1hxpHkgBgQZMEAAuaJABYxMQ5SdP8+fM1u9fGMM+jXXbZZZonTpyo+YwzznDsM2HCBM1e32jvJXO9IRHn5LrmZVOLFi0KV0l+Mc9DmnWaayFFK/Ncn/lnmzZtmuZ7773Xr9cyJ/B1n5M8fPiw5oMHD2reuHGj5ldffdWxj3mJnXlueufOnZq3bt3q2Me888qLpWN94UgSACxokgBgEXPDbdOGDRscjwcMGKC5Z8+ems1LhW688UbHPmeddZZm9xK18cQ9CYV5+ceuXbs0u+fwDAdzsg0R5+QnJnPNpHvuuSeUJYXFyJEjNefl5WlOTU0t92v9+OOPms11pkREvv76a81r164t92ubzPWk6tWr59j23XffBfTaocKRJABY0CQBwCKmh9tu5o3ys2fP1mwuT2pe9S8ikpaWpjk9PV3zihUrgl5ftDLvnAjXHUrmENtcalbEuYSE+Q3qk08+qdm9TES0mzRpktcl+MW8qsTt7bffDmMl/uNIEgAsaJIAYBHTw23zAlkRkX79+mlu27atZvcQ22ReMLty5co
gVhc7wnUBuXkBuzmkHjhwoON5Cxcu1Ny3b9+Q14XgMG8EiSQcSQKABU0SACxokgBgERPnJM2lNW+55RbNffr0cTzvlFNOOe5rlZaWOh6bl7TEy6StZXFPemA+vvLKKzWPHj06qO976623ar7//vs1Jycna54zZ45jH3PpWiBQHEkCgAVNEgAsoma47R4qX3PNNZrNIXZKSkq5X9uc+86cP1Ik8uZH9Ip7qV3zsfnZPPPMM5rdcwzu3btX88UXX6zZXG/IXGtFxLmmijkJw/vvv6/5+eefP/4fABHHfQqnadOmmgOdSCOYOJIEAAuaJABYRNxwu379+ppbtGiheerUqY7nNWvWrNyvbS5nOXnyZM3mHRrx/A12RVWuXFmzOceh+26XgoICzeY8nTarV6/WnJmZqXncuHHlrhORxX0Kx1yGNpJEZlUAECFokgBg4clwu06dOo7HL774omZzEoMzzzyz3K9tDs/M+QNFnN+ImivN4fjWrFnjeJydna3ZnCzE5L4iwTyVYjK/9Z47d65jW7AvTkfkat++veaZM2d6V4gLR5IAYEGTBAALmiQAWIT0nGS7du00m5OkXnTRRY7nnXrqqeV+7YMHD2o27/KYOHGi5gMHDpT7dVE2c60YEefkIeYyvO71ZnyZMmWK5hdeeEHzt99+W9ESEWXcd9xEKo4kAcCCJgkAFiEdbl911VVlZhtzTZklS5ZoPnz4sON55uU95lKxCA9zns3x48eXmQG3pUuXau7fv7+HlfiPI0kAsKBJAoBFwlH3Xea+nhgl30TFAj8/kuPiMwufYH1mInxu4eTP58aRJABY0CQBwIImCQAWNEkAsKBJAoAFTRIALGiSAGBBkwQAC5okAFj43SSPHj0adT+/rcMyY8YMz2spz0+weP3n8OenZs2a0r9//2N+3717d6lSpYoUFhZ6XmM4P7No+dx8/bRp00batGnjeR3B/Nw4koSniouLJTEx8ZjfV69eXUpKSmTDhg0eVIWKOHr0qOzcuVNOOukkr0sJKpokPHX22WfL2rVrpbS0VH9XUlIiWVlZIiKybds2r0pDOc2ZM0e2bdsmAwcO9LqUoKJJwlMjR46U3NxcGT58uGzcuFE2bNggQ4YM0fkqWfo3OuTk5MjNN98s7du3l6FDh3pdTlDRJOGpv/3tb3LvvffKP/7xD2nZsqWcc845snnzZrnzzjtFRKRGjRoeV4jj2bFjh3Tv3l2Sk5Nl3rx5UrlyZa9LCiqaJDw3YcIE2blzp6xatUr+85//SHZ2thw5ckRERJo2bepxdbDZv3+/dO3aVfbt2yfvvfeeNGzY0OuSgi6kyzcA/qpdu7Z06NBBHy9fvlxOO+00adasmYdVwaaoqEh69uwpubm5snz5cmnRooXXJYUETRIR54033pDs7Gx54oknpFIlBjuRqLS0VAYOHChr1qyRhQsXSvv27b0uKWRisklOnTpV9u3bJ/n5+SIisnjxYl03OiMjQ5KTk70sD4aVK1fKQw89JJ07d5a6devK2rVrZcaMGXLFFVfI6NGjvS4PPtx2222yaNEi6dmzp/z000/y+uuvO7YPHjzYo8qCz+/lG6JJSkqK5OXllbnt+++/l5SUlPAWBJ82b94sI0eOlM8//1wKCwulUaNGMnToUPn73/8uVapU8bo8+JCeni4ff/yxz+2x1FZiskkCQLBwwgcALGiSAGBBkwQAC5okAFjQJAHAgiYJABY0SQCw8PuOm4SEhFDWAUOwLl3lMwufYF5uzOcWPv58bhxJAoAFTRIALGiSAGBBkwQAC5okAFjQJAHAgiYJABY0SQCwoEkCgAVNEgAsaJIAYEGTBAALmiQAWNAkAcDC76nS4sXYsWMdjx988EHNlSr9/v+U9PR0zbb1h4F4VLNmTc01atTQ3L17d8316tVz7PPUU09pLi4uDmF15cORJABY0CQBwILhtogMGzZM81133eXYduTIkTL3CeZM1EA0SklJ0ez+d9O+fXvNrVq
18uv1GjRooHnUqFGBFRdEHEkCgAVNEgAsaJIAYME5SRE544wzNFerVs3DSuJHu3btNA8ePFhzx44dHc9r2bJlmfvffvvtmvPz8x3bOnTooPn111/XnJWVVbFi41yzZs00jxkzRvOgQYM0JyYmOvYxV3zcsmWL5sLCQs3Nmzd37DNgwADNzz//vOacnJwKVB08HEkCgAVNEgAs4na43alTJ80ZGRk+n2ce6vfo0UPzzp07Q1NYDBs4cKDmKVOmaD7ppJM0m8M0EZEVK1ZoNu/QmDx5ss/3MV/D3Ofqq68uX8FxJDk5WfOkSZMc28zPzbyTxmbTpk2au3TpovmEE07Q7B5Gm38PzOw1jiQBwIImCQAWcTXcNr/1nDFjhmZzqOFmDuvy8vJCU1gM+cMffv8rdeGFFzq2vfTSS5qrV6+ueeXKlZoffvhhxz6ffPKJ5qpVq2p+8803NXfu3NlnPevWrfOn7Lh31VVXab7++uvLvf/mzZsdjy+//HLN5rfbTZo0qUB13uJIEgAsaJIAYBFXw+2hQ4dqbtiwYZnPMb9NFRGZNWtWKEuKOeaF4S+//LLP5y1btkyz+e1pQUGBz33M59mG2Fu3btX82muv+S4Wqn///n4974cfftCcnZ2t2T3BhTnENrkvII8GHEkCgAVNEgAsaJIAYBHT5yTdV+3/9a9/1WxOprtv3z7NjzzySMjrijXmZTv33nuvZvfExOakBeZaQrbzkKb77rvPr+eZE7bu3r3br33i3Q033KB5xIgRjm0ffPCB5m+//Vbzrl27yv0+9evXr0B13uJIEgAsaJIAYBFzw21z3Y23337br32effZZzZmZmcEuKSaNGzdOsznELikp0fz+++879jEvEzl06FCZr+uez9O81Of000/XbE5i4T5FsnDhQmvtOJY5J+f48eND9j7m2jfRgiNJALCgSQKARcwNt6+44grNrVu39vm8Dz/8ULM5tyHKVqtWLcfjkSNHaja/xTaH2FdeeaVfr21OejBnzhzHtgsuuKDMfebNm6f58ccf9+t9EFzmVQQnnniiX/ucc845PretXr1a85o1aypeWJBxJAkAFjRJALBIOOq+4tfXE13T6kcSc1g3c+ZMze4hgHk4b67MFmlLMfj5kRxXMD+zk08+2fHYvULhb84880zNRUVFjm3XXXed5l69emlu1aqV5ho1ajj2Mf9bmLlPnz6aFy9ebK09HIL1mYl4/2/NnOtTRKRFixaaH3jgAc3dunXz+RqVKv1+/GXeuGFy/x1KT0/X7J6fMlT8+dw4kgQAC5okAFjQJAHAImovAarInTXfffed5kg7DxnpzDtpRJwTR5jLtn7//fea/T1PZ56bck920aBBA8179uzRHAnnIaOdubzr+eefr9n978n8DMw7pczPzX3Jjnkpnvsc52/M9ZBEnOeZzcvy3H/3wo0jSQCwoEkCgEXUDrfNyRJ8XWLg9thjj4WqnJhnzrkp4rzsasmSJZrr1Kmj2X0ZhznxhHmp1k8//aR57ty5jn3MoZ57G8qvSpUqms0h8TvvvONznwcffFDzRx99pPnTTz/VbH7u7ueZl3iZzNM0IiKPPvqo5h9//FHzggULNBcXF/usM1Q4kgQAC5okAFhEzXD7vPPOczy2LSn6G/e8gt98800wS4prWVlZmt3DpvJKS0vT3LFjR8c281SKeXUC/GN+gy3iHDrfcccdZe6zdOlSx2NzvlXztIv5ub/77ruOfcyJLMxvp83JSNzD8N69e2s2JzpZvny55kmTJjn2+fnnn8v8M6xfv77M31cER5IAYEGTBACLqJngwr0yW+3atct83tq1azV37drVse2XX34JfmEhEIkTXIRSly5dNLuHbeZ/C/Ob7khbBTGSJrioXLmy5gkTJji23X777ZoPHDig+e6779bsvorAHNJeeOGFmqdOnVrm70WcqyredNNNms3lUZKSkhz7pKamah40aJBmczIU27yVW7Zs0dyoUSOfzzMxwQU
ABIgmCQAWNEkAsIiac5KlpaWOx77ushkyZIjmf/7znyGtKVTi7Zykyf05c06y/MxzgOblOyIiBw8e1DxixAjNH3zwgeZ27do59jEnSzbP8ycmJmp+6KGHHPvMmDFDs3musCKuueYazX/5y198Pu/WW2/VbJ4TteGcJAAEiCYJABYRPdw2D9mHDRvm2OZruG2usZKXlxeSukIt3obbXALkFOjntn37ds3uu6HMCSJycnI0m5fWmEv82owfP16zOTmFyLGnTSIVw20ACBBNEgAsIm6CC3Mii06dOml2D6/Nm+afe+45zSzLEH3MUyQI3I4dOzS7h9tVq1bVfO6555a5v/uUx8qVKzWbczv+8MMPmqNleF0RHEkCgAVNEgAsIm64XatWLc2nnHKKz+dt27ZNs3nTPqLPqlWrNFeq5Pz/tr9Lc+B35vyc5jIbIiJt2rTRbE4a8+qrr2p2z9Ho9WqFXuNIEgAsaJIAYEGTBACLiDsnifizYcMGzZs2bXJsMy8Paty4seZIu+MmkhQWFmqePXu2Y5v7MY6PI0kAsKBJAoBFxA23zZvuV69erblDhw5elIMwmzhxouPxyy+/rNlcryUjI0Pzxo0bQ18Y4hZHkgBgQZMEAIuInk8yXsXbfJIm9zKjb775pmZzwpN33nlHs7m8gIhzqdRwiaT5JOE/5pMEgADRJAHAguF2BIrn4babOfw2v902VwRs3bq1Yx8vvu1muB2dGG4DQIBokgBgQZMEAAvOSUYgzklGH85JRifOSQJAgGiSAGDhd5M8evRo1P1kZ2eLiMiMGTM8r6U8P8Hi9Z8j0J8ePXpI/fr1Pa8jnJ9ZtH5usfxvLeJmAUL8OnDggBw6dEj2798vixYtkqVLl8rAgQO9LgtxjiaJiHHbbbfJiy++KCL/WzWxT58+MnXqVI+rQryjSSJijBkzRvr16yf5+fny5ptvSmlpadwvZwrv8cUNIkazZs2kU6dOMmTIEFmyZIn88ssv0rNnz6Cf8wPKgyaJiNWvXz/Jzs6W3Nxcr0tBHKNJImIdOnRIRET279/vcSWIZzRJeG7Xrl3H/O7XX3+VWbNmSWJiorRo0cKDqoD/ickvbqZOnSr79u2T/Px8ERFZvHixbN26VUT+t4BUcnKyl+XB5cYbb5SCggJJS0uTU089VXbs2CFz5syRnJwcefLJJ6VGjRpelwgf4uHfmt/3bkeTlJQUycvLK3Pb999/LykpKeEtCFZz586VV155Rf773//K3r17pWbNmnLBBRdIRkaG9OrVy+vyYBEP/9ZiskkCQLBwThIALGiSAGBBkwQAC5okAFjQJAHAgiYJABY0SQCw8PuOGxYnCp9gXbrKZxY+wbzcmM8tfPz53DiSBAALmiQAWNAkAcCCJgkAFjRJALCgSQKABU0SACxokgBgQZMEAAuaJABY0CQBwIImCQAWMbGk7JQpUzSPGjVK84YNGxzP69Gjh2ZfK7wBgIkjSQCwoEkCgEXUDrfNRc8HDx6s+ciRI5qbN2/u2KdZs2aaGW6HX9OmTTWfcMIJmtPS0jQ///zzjn3Mz7MiFi5cqPnqq6/WXFJSEtDrxivzc0tNTdU8ceJEx/P+9Kc/ha2mUONIEgAsaJIAYBG1w+3du3drXrlypeZevXp5UQ7+X8uWLTUPGzbMsa1///6aK1X6/f/PDRs21OweXge6LIL592HatGmax4wZ43heQUFBQO8TL5KTkzVnZmZq3rFjh+N5p5xyis9t0YYjSQCwoEkCgAVNEgAsovac5IEDBzRzOU/kePTRRzV369bNw0qONWTIEM2vvPKKY9unn34a7nJiinkO0v2Yc5IAEMNokgBgEbXD7Vq1amk+99xzvSsEDsuWLdNsG27v2rVLszn0NS8NEvF9x415t0fHjh3LXSeCKyEhwesSQoYjSQCwoEkCgEXUDrerV6+u+fTTT/drn7Zt22rOycnRzLfjwfPCCy9oXrBggc/n/frrr5or8u1nUlKSZve8oeY
dPCaznnXr1pX7PeGb+86oatWqeVRJ8HEkCQAWNEkAsIja4XZ+fr7mmTNnah4/frzPfcxt+/bt0zx16tQgVhbfDh8+rHnLli0he58uXbporl27tl/7bN26VXNxcXHQa8LvLrzwQs1r1671sJLAcSQJABY0SQCwoEkCgEXUnpM0Pfzww5pt5yQR3cw1am644QbNiYmJfu0/bty4oNcUb8xzzvv379dsTsYrItK4ceOw1RRqHEkCgAVNEgAsYmK4bTInSAh0OVKE36BBgxyP7777bs1NmjTRbC5tarN+/XrN5l0+qBjz0rlVq1Zp7tGjhwfVhAdHkgBgQZMEAIuYG26bQ+xAlyNF+aWkpGi+9tprHds6dep03P07dOjgeOzPZ+heDtYcor/77ruaDx06dNzXAtw4kgQAC5okAFjE3HAb4deqVSvNixYt0uzvPJ+BMr9lFRGZPn16WN4XvtWtW9frEoKGI0kAsKBJAoAFTRIALDgniaAylxatyDKj/i4pa3Lf7dG1a1fNS5cuLXcNCFyvXr28LiFoOJIEAAuaJABYxNxw298JLtLS0jSzxk1gzCVd09PTNQ8ePNjxvPfff19zUVFRud9n+PDhmjMyMsq9P4IrMzNTMxNcAECcokkCgEXCUT9ngajIN5VeKC0t1ezvBBetW7fWvHHjxqDXVF7BmpgjWj4zf5lLBOzdu9fn83r27Kk5XN9uB3MylWj53Pr27av5rbfecmwzJxNp0aKF5ry8vNAXVg7+fG4cSQKABU0SACxi7tvtadOmab7xxhv92mfEiBGax4wZE+ySECRdunTxugQYzJUT3cxTBlWrVg1HOSHDkSQAWNAkAcCCJgkAFjF3TjInJ8frEmKSewnXzp07a/7oo480B3sdmeuuu07zlClTgvraCMzChQs1u//dNWvWTLN5nn/kyJEhryvYOJIEAAuaJABYxNwdN6bc3FzH48aNG5f5PHNSjCZNmji2bd68OfiFHUek3HFjLu963333ObZdfvnlmhs1aqR5y5Yt5X6fOnXqaO7WrZtj27PPPqu5Zs2aZe7vHuKbcxmakzCEUjzecWN6+umnHY/N0yT169fXXJGJTUKJO24AIEA0SQCwiLlvt01fffWV4/GZZ55Z5vP8WSIgHpnzbJrLxrrdeeedmgsLC8v9PubQvU2bNo5tvoZDK1as0PzCCy84toVriA3fzM+tpKTEw0oCx5EkAFjQJAHAIqaH29OnT3c8NucZRPDcdNNNIXvtXbt2aV68eLHm0aNHa460b0whkpSUpLl3796a58+f70U5AeFIEgAsaJIAYEGTBACLmD4n6V6v5uuvv9bcvHnzcJcTdYYNG6bZvYTr0KFDA3pt806mgwcPal61apXjeeZ5ZXPpWkSWAQMGOB4XFxdrNv/dRSOOJAHAgiYJABYxPcFFtIqUCS5M7nVKzKH4I488orl27dqaFyxY4Nhn2bJlms25CHfs2BGkKr0T7xNczJ071/HYPJ1lTjjCkrIAEGNokgBgwXA7AkXicBt28T7cjlYMtwEgQDRJALCgSQKABU0SACxokgBgQZMEAAuaJABY0CQBwIImCQAWft9xAwDxiCNJALCgSQKABU0SACxokgBgQZMEAAuaJABY0CQBwIImCQAWNEkAsPg/JUsafnMQnogAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "visualize(train_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 数据集常用操作" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pipeline的设计理念使得数据集的常用操作采用`dataset = dataset.operation()`的异步执行方式,执行操作返回新的Dataset,此时不执行具体操作,而是在Pipeline中加入节点,最终进行迭代时,并行执行整个Pipeline。\n", "\n", "下面分别介绍几种常见的数据集操作。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### shuffle\n", "\n", "数据集随机`shuffle`可以消除数据排列造成的分布不均问题。\n", "\n", "![op-shuffle](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/tutorials/source_zh_cn/advanced/dataset/images/op_shuffle.png)\n", "\n", "`mindspore.dataset`提供的数据集在加载时可配置`shuffle=True`,或使用如下操作:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFeCAYAAAAIWe2LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAls0lEQVR4nO3dd3RU1drH8QdCS+i9CIKAgBRFqvrS9CJFig0EAQHhCogKiAgXuAiC4r26ECmKoBBBBQEbRRArRWlBQKR4c6Wo9A6hBRLn/eMuH/cZMptJZjJnZvL9rJW1fmfOnJkdxjyeveecvbN5PB6PAADSlN3tBgBAOKNIAoAFRRIALCiSAGBBkQQAC4okAFhQJAHAgiIJABYUSQCwoEgCgEXUFcnk5GQZNmyYlClTRmJjY6Vhw4by5Zdfut0sWOzYsUM6duwoFStWlLi4OClWrJg0adJElixZ4nbTYHHu3DkZPXq0tGrVSooUKSLZsmWTd955x+1mBV3UFcmePXvKq6++Kl27dpVJkyZJTEyM3HPPPfLdd9+53TT48Ouvv0pSUpL06NFDJk2aJKNGjRIRkfbt28uMGTNcbh18OX78uIwdO1Z27dolt9xyi9vNyTTZommCi40bN0rDhg3llVdekSFDhoiIyKVLl6RmzZpSokQJWbt2rcsthL9SU1Olbt26cunSJfn555/dbg7SkJycLKdOnZJSpUrJpk2bpH79+hIfHy89e/Z0u2lBFVVnkh9++KHExMRInz599LE8efJI7969Zd26dfL777+72DqkR0xMjJQrV05Onz7tdlPgQ+7cuaVUqVJuNyPT5XC7AcG0ZcsWqVKlihQoUMDxeIMGDUREZOvWrVKuXDk3mgY/nD9/Xi5evChnzpyRxYsXy/Lly6VTp05uNwtZXFQVyUOHDknp0qWvevzPxw4ePBjqJiEdnnnmGZk+fbqIiGTPnl0eeOABmTp1qsutQlYXVUXy4sWLkjt37qsez5Mnj+5H+Bo0aJB06NBBDh48KAsWLJDU1FS5fPmy281CFhdVY5KxsbGSnJx81eOXLl3S/Qhf1apVk+bNm0v37t1l6dKlcu7cOWnXrp1E0XeLiEBRVSRLly4thw4duurxPx8rU6ZMqJuEAHTo0EESEhI
kMTHR7aYgC4uqIlm7dm1JTEyUs2fPOh7fsGGD7kfk+HN45MyZMy63BFlZVBXJDh06SGpqquMC5OTkZImPj5eGDRvyzXaYOnr06FWPXblyRebMmSOxsbFSvXp1F1oF/E9UfXHTsGFD6dixowwfPlyOHj0qlStXltmzZ8u+fftk5syZbjcPPvTt21fOnj0rTZo0keuuu04OHz4s77//vvz8888yYcIEyZcvn9tNhA9Tp06V06dP65UjS5Yskf3794uIyFNPPSUFCxZ0s3nB4YkyFy9e9AwZMsRTqlQpT+7cuT3169f3fP755243Cxbz5s3zNG/e3FOyZElPjhw5PIULF/Y0b97cs2jRIrebhmsoX768R0TS/Nm7d6/bzQuKqLotEQCCLarGJAEg2CiSAGBBkQQAC4okAFhQJAHAgiIJABYUSQCw8PuOm2zZsmVmO2AI1qWrfGahE8zLjfncQsefz40zSQCwoEgCgAVFEgAsKJIAYEGRBAALiiQAWFAkAcCCIgkAFhRJALCIqjVuAISHefPmab7ttts0d+7cWfOfq5iGO84kAcCCIgkAFlmqu926dWvNS5cu1fzncpgiIn369HEcs2nTJs3Hjh3LxNZFH++lYGvWrKm5Q4cOms+ePav51ltvdRxTunRpzW+++abmOXPmaP7jjz8CbyyCqnz58porVKig+b333tPsvZ76lStXMr1dGcGZJABYUCQBwMLvdbejYY47s7u9ePFiv45ZsmSJ5gceeCDobUpLpM0nWalSJc3jxo3T3KpVK8fzChUqpPnSpUuaU1JSNOfNm9dxTHJysuY8efJovvvuuzV//fXXGWh1cGX1+STLlSvn2N69e7fmnDlzpnlMXFycY/vixYvBb9g1MJ8kAASIIgkAFlH37bbZpXvjjTcc+xo3bpzu10tISAi0SVHv9ddf12x+02x+kykicuLECc3r1q3T/PPPP2suUKCA4xizW24OfTz55JOaw6G7ndUVLFjQse2ri/3pp59qNodSwhlnkgBgQZEEAAuKJABYRN2Y5M0336y5QYMGjn1lypTR7O9dGmPHjtW8c+dOzYsWLcpoE6OOeZfSb7/9lmnvY45PtmzZUrP3OKZ5Bw8yT44cf5WP4cOH+3XM3LlzNUfKnVKcSQKABUUSACyirrtduHBhzd53bwRq+vTpms2ugnlpSlaUmV3shg0bajYnyPj44481JyUlZdr7w7eJEydq7tKli4styVycSQKABUUSACwitrs9adIkzebdFzbZswf2/4SSJUtqNufLQ/B4D5G88847ms07dp544gnNwZxcAnaPPfaY5t69e7vYktDhTBIALCiSAGARsd1ts4uVkYtSFy5cqHnNmjWamzRp4nierzkk77//fs0ffPCBY9/x48fT3Z6srFixYpoXLFjg2GfOVXnXXXdp5t84dB599FHNU6dO1ZwrVy7NmzdvdhxTp06dzG9YiHAmCQAWFEkAsKBIAoBFWI9JXnfddZq9xwq7deuW5jGnT5/W7D1uZS4Pa142ZK6tUaJECb/aZrbHe8JRxsvSVqpUKc3du3fX3LlzZ821a9d2HHP58mXN5viw+Txz0gQRkZMnTwba1KjhvazvLbfcorlKlSqazTubHnroIccx5l1spgEDBmhetmyZY98vv/yS/saGKc4kAcCCIgkAFmHX3Ta7A+YaKdWrV3c8z9dlP++++67mwYMH+/WelStX1jxixAi/jsG1NW3a1LFt3j3j7x1L5mUmgwYNSvM55l0gIs7/hrK6smXLOrZnzZql2exum86cOePYfuuttzS//PLLmvft2+fzfaIJZ5IAYEGRBACLsOtum98aV6tWLSTveejQIc1vvvmmY1+/fv2uefyYMWMc24888khQ2hXpzp8/79jesmWL5jlz5mjes2ePZn+XxXj44Yc1v/baa459o0aN0jxu3Di/Xi9amcv1ijiXN7nxxhvTPMZ7+Ytgzhca7DleQ4EzSQCwoEgCgEU2j5+T8WX
Lli3TGmGegr/yyiua+/btq9l7Lsjt27drbtGihWaz6xzo+9va4Ov9M9oGU7DmR8zMzyyceC+f0ahRI82+LoQOtmDOaRmJn1vRokUd2+bfh3kTwaeffqrZnCTGLf58bpxJAoAFRRIALCiSAGARFpcAPf/885rNuydsk+mal+oEOgZYunTpNN/fuw07d+7U3LVr16C9PwIzc+ZMx7Y5JonQMNcfEhHZu3evZnNM8ttvvw1Zm4KFM0kAsKBIAoBFWHS3M7IeRvHixTXnzJlT85UrV4LSprT07NlT87Zt2zLtfRCYHDn++s/aXD+HeT7dF4lDU5xJAoAFRRIALMKiu/3ZZ59pbty4sV/HmN9gmssn2LpUFSpU0NyuXTvNZpfM2+LFizXXq1dP8w8//OBXO5H5vD+/lJQUzXSx3Wfe1XL06FEXW5IxnEkCgAVFEgAswmKCi//85z+aK1asmOZzNm/e7Nju2LGjZn/nu2vdurVmsxtt07ZtW80rVqzw65hAMcFF+nh34cyrHZjgIv3M5UxERIoUKZLm8y5cuKDZe4XKO+64Q7O55MPQoUM1r1q1ynFMXFyc5hdeeEHzwoULNfv7d+svJrgAgABRJAHAgiIJABZhcQmQeTmN93jInxo0aODYHjZsmOalS5dqNi/tMSfMFXFOmutr8ozJkyc7tkM1DhlpzLta/vWvf2kePny45mDf/RQTE6N5ypQpmr0vAcrq69rYmEv0muP/ffr00ez9d2OOFZouX76s+dy5c459vsYxzfHFY8eO+WybeVnf4cOHNQd7TNIfnEkCgAVFEgAswqK7ffDgQc22OSRNZvfAzCbba5n7zNP+jz76yK/3z+rMO54GDx6s+aabbtL89NNPO45JTExM9/uYXcIZM2ZovuuuuzT/9NNPjmPMrnhWV7JkScf2pEmTNHfq1Cndr2dOUGFePrNjxw7H83788cd0v7Yvs2fPDtprZQRnkgBgQZEEAIuw6G5nZD7JjDh//rzmAwcOaDbnidywYUNI2hLpNm3apHn//v2a77nnHs3Vq1d3HGN+821OPGF23cuVK+c4xryzKl++fJrN7lyrVq0cxzCpxV+6dOni2Pani21eLSIiMmHCBM3ff/+95sycuzWccCYJABYUSQCwCIsJLsylGD744APNTZo00WxeCC7i37fg3heefvHFF5qnT5+e7naGSqRNcFGzZk3Nc+fOTfNxf3m32fy3+PrrrzU/++yzmrdu3Zru9wm2cJ3gwpxDVcT5N2FeVTJ//nzN8fHxQXv/cMcEFwAQIIokAFhQJAHAIizGJE3ly5fXbI6T1K9f3/E8c0zyscce02zeEeC9Dk2kXBoSaWOSpmrVqml++OGHHfsGDBig2Zyw1ZxQ2RyTFnGuf5SUlKQ5NTU18MYGUbiOScKOMUkACBBFEgAswq67jcjubmdVdLcjE91tAAgQRRIALCiSAGBBkQQAC4okAFhQJAHAgiIJABYUSQCwoEgCgAVFEgAsKJIAYEGRBAALv4ukx+MJ+5/8+fNLx44dr3q8TZs2kitXLklKSnK9jf78BIvbv4e/P9OmTRMRkZ07d4rH45Fz585Jamqq6+1y4zOLpM8trZ86depInTp1XG9HMD+3qDqTTE5OltjY2Ksej4uLk8uXL8v27dtdaBWu5auvvpICBQrIgQMHpGrVqpIvXz4pUKCAPP7443Lp0iW3mwc/eTweOXLkiBQrVsztpgRVVBXJqlWryvr16x2zVl++fFk2bNggIiIHDhxwq2mw+O9//yspKSly7733SsuWLeWjjz6SXr16yZtvvimPPvqo282Dn95//305cOCAdOrUye2mBJcnikybNs0jIp4ePXp4duzY4fnpp588nTp18uTMmdMjIp53333X7SYiDRUrVvSIiKdfv36Ox/v27esREU9iYqJLLYO/du3a5SlQoIDn9ttv96SkpLjdnKCKqjPJfv36yYgRI2Tu3LlSo0YNqVWrluzevVuGDh0qIiL58uV
zuYVIy59DJN5r4nTp0kVERNatWxfyNsF/hw8fljZt2kjBggXlww8/lJiYGLebFFRRVSRFRF588UU5cuSIrFmzRrZt2yYJCQm6aFiVKlVcbh3SUqZMGRERKVmypOPxEiVKiIjIqVOnQt4m+OfMmTPSunVrOX36tHz++ef6WUaTqCuSIiKFCxeWRo0aSa1atUTkf18MlC1b1rGSH8JH3bp1ReTqMeODBw+KiEjx4sVD3iZc26VLl6Rdu3aSmJgoS5culerVq7vdpEwRlUXSNH/+fElISJBBgwZJ9uxR/+tGpIceekhERGbOnOl4/O2335YcOXJIs2bNXGgVbFJTU6VTp06ybt06Wbhwodx+++1uNynT5HC7AcG0evVqGTt2rLRo0UKKFi0q69evl/j4eGnVqpUMHDjQ7ebBh1tvvVV69eols2bNkpSUFGnatKmsXLlSFi5cKMOHD4/KLlyke+aZZ2Tx4sXSrl07OXnypLz33nuO/d26dXOpZcHn92qJkWD37t3Sv39/2bx5syQlJckNN9wgPXr0kMGDB0uuXLncbh4srly5IuPHj5f4+Hg5ePCglC9fXp544gkZNGiQ201DGpo1ayarVq3yuT+Kykp0FUkACDYG6QDAgiIJABYUSQCwoEgCgAVFEgAsKJIAYEGRBAALv++4yZYtW2a2A4ZgXbrKZxY6wbzcmM8tdPz53DiTBAALiiQAWFAkAcCCIgkAFhRJALCgSAKABUUSACwokgBgEVXLNyDytWzZ0rH9j3/8Q/MXX3yh+YcffkjzcSDYOJMEAAuKJABY0N1GWGnbtq1ju0mTJpqbNm2q2VyE6vvvv3ccc/78+UxqHR588EHNefLk0VyvXj3H88wF3L799lvN5rLBu3btchyzefPmYDUzqDiTBAALiiQAWPi9pCzTN4VOVpsqrXXr1prnz5/v2Jc3b17N5u9j/huVK1fOcczBgweD3cRrivSp0mJjYzVXrVrVsW/cuHGa//a3v2nOnTt3QO+5d+9ex/Y333yjediwYZrPnj2rOTU1NaD39MZUaQAQIIokAFjQ3Q5DWaG7XalSJc2bNm3SXKBAAZ/HmL/Pe++9p7lXr16O56WkpASjiekSKd3tm2++WXPjxo01mxfxt2nTJtPePyOef/55zR9//LFj3/bt2wN6bbrbABAgiiQAWFAkAcAiKu64adiwoeZu3bppNu/QEBGpUaNGmscPGTJEs/flI40aNdJsjoNt2LAhY42FiIgMHDhQs20c0vT1119rHjt2rGY3xiAjlTkOOXny5HQf/9tvv2nOyOU4pUuX1mzesWMzevRozceOHXPsC3RM0h+cSQKABUUSACwitrvdqVMnzZMmTdJcrFgxzd6XUqxcuVJz8eLFNb/yyis+38d8DfOYzp07p6/BcPw7m5+fv+6+++5gNifL+/TTTzXfd999mg8fPux43ttvv63Z/AzPnTuX7vccMGCA5okTJ6b7eDdwJgkAFhRJALAI6+52jhx/Nc97vrq33npLc1xcnObVq1drNm/MFxH57rvvNJs35y9YsEBzixYtfLbHvDME6Ve7dm3N5rCI6Y8//nBsT5kyJTOblOXMnTtX87vvvqt55MiRmi9duuQ4Zt++fUF7/40bN6b7GHN+0OPHjwetLf7iTBIALCiSAGAR1t1t88Jw8xs2b19++aVm81tTcx46b+bzbF3s/fv3a549e7bvxuIq3hfvV69e/ZrHvPPOO47twYMHB7NJWd6pU6fSfNz2t5IROXPm1Dx+/HjNHTt2TPdrmXNLLly4MLCGZQBnkgBgQZEEAAuKJABYhN2YpHnZzogRIzR7T475xhtvaP7nP/+p2d+xFfOSBxvzDgHvm+th16dPH8d2qVKlrnnM1q1bM6k1yEx33nmnY/vpp5/WnJFJfPfs2aP5k08+yXjDgoAzSQCwoEgCgEVYdLefe+45zWYX+/Lly5pXrFjhOMa8LODixYtpvq73fHXmpT7XX3+9ZnMSixdeeMF
xzKJFi6xth2+33Xab201AJnr00Uc1T58+3bEvJiYm3a9nzhFqTr7hPeFGqHEmCQAWFEkAsHClu12oUCHHdv/+/TWb32KbXWxzvjubypUra37//fcd++rWrZvmMR9++KHml19+2a/3QdqqVq2quUyZMo59/iyV6j1npK+lHSZMmKDZHJZB8JnL0N57772aR40apdnf7rU5ecayZcsc+8w72oI5qUagOJMEAAuKJABYZPN4X6Xt64l+dJX8VaJECce29wqFf6pYsaJm7znuzG/W2rdvr7lmzZqa8+XL5zjG/FXN/MADD2hesmSJte2h4OdHck3B/Mz8df/992s2hzGCwfx9zKUDfvjhB8fzzM/T14QOwRasz0zEnc/NnJCiUqVKjn3mFR7mcJbJe+XEK1eupPk880oWc8jELf58bpxJAoAFRRIALCiSAGDhyiVA3pdsmBNHmMu27t27V7O/Yz7m+Kb3ZBelS5fWbK6VEQ7jkEifvHnzam7SpIljnzmxxsyZMzW7sT5KpDDvYHv++ef9OmbNmjWa58+f79g3bdq04DQsDHAmCQAWFEkAsHClu3369GnHtnk3zdKlSzUXKVJE8+7dux3HmJclmOuinDx5UvMHH3zgOMbsbnvvQ3CYwx3en7P3nVZp+fHHHx3b5pCLeXmRjbmmSvPmzTV36dJFc1adG9QcpjAv5+nZs6dfx3/77beaH3nkEc2HDh0KvHFhijNJALCgSAKAhSt33GQm85vOVatWOfb98ccfmgcNGqR5ypQpmd6u9IjkO25MGzZscGzXq1fvmsdMnTrVsT1x4kTNxYoV02wuMVyrVi2/2vP4449rnjFjhl/H+CtS7rh54oknNE+ePNmvY1auXKnZHBpLSkoKVrNcwx03ABAgiiQAWITF8g3BFBsbq9nsXos4T635djtzNGrUSLM5t6S/nnzyScd269atNXfr1k1zRrqk5qQowe5uh6tq1ao5tocOHXrNY77++mvHtvnvHmgXu3z58prNb9pFnEunmM+zMSc6GT58uOa1a9dmtIlX4UwSACwokgBgQZEEAIuoG5P0XnoWobVnzx7NP/30k2PfHXfcke7XMyeANS8p8h5v9uXChQuas8o4ZO3atTUvWLDAsa9s2bLXPP6XX35xbN94442ajx49muYxY8aMcWz7WvOma9eumv0dd7Qxx5mDOQ5p4kwSACwokgBgEXXd7ZYtW7rdhCzNnODioYcecuwz17y57bbbQtIec/glPj4+JO/pNrO7/M033zj2ea9fk5a+ffs6ts3P0XuO1j9df/31ju1Q3e113XXXZfp7cCYJABYUSQCwiLrutrkMLdzlPcegudSrOdnE008/rdl7GeCM2Lhxo+b+/fsH/HqRxrwLZeDAgY59BQsW1Ow9HOJL4cKF08yhMnr0aM0nTpxw7Js1a1amvz9nkgBgQZEEAIuom0+yZs2amr0vZjYvQC5VqpTmcJvKP1rmk/RX/vz5NZvzHYqItGjRQnNycrJms1v+2WefOY556623NHt3zzJLpMwn2bZtW83mnKp33nlnpr3n77//rrlz586ad+3a5dfx5qQa/t5E4C/mkwSAAFEkAcCCIgkAFlE3JmlKTEx0bJuXB5mTw65fvz5kbfJHVhuTjAaRMiZpMsd127Vrp7lChQqO55mT4ZrMCUNWr17t833MSU+81z1yG2OSABAgiiQAWER1d7tnz56ObXMZUnO52aeeekrzzp07M71d10J3O/JEYncbdLcBIGAUSQCwiOrudoECBRzb5lT2zZs31/zxxx9rNqeDFxE5f/58JrXON7rbkYfudmSiuw0AAaJIAoBFVHe3vZnd7xdffFGzObfhzTff7DjGjW+76W5HHrrbkYnuNgAEiCIJABYUSQCwyFJjkpGCMcnIw5hkZGJMEgACRJEEAAu/i6TH4wn7n/z580vHjh2verxNmzaSK1cuSUpKcr2N/vwEi9u/R0Z+EhISREQkPj7e9ba48ZlFyuc2bNg
wEfnf+lDej2fPnl3OnTvnehuD9blF1ZlkcnKyxMbGXvV4XFycXL58WbZv3+5Cq4Do06xZMxER6d27t2zdulV+//13mT9/vkybNk0GDBggefPmdbeBQZTD7QYEU9WqVWX9+vWSmpoqMTExIiJy+fJlnQ35wIEDbjYPiBqtWrWScePGyfjx42Xx4sX6+MiRI33OZB6poupMsn///pKYmCi9e/eWnTt3yvbt26V79+5y6NAhERG5ePGiyy0EokeFChWkSZMmMmPGDPnoo4+kV69eMn78eJk6darbTQsuT5QZMWKEJ2fOnB4R8YiIp169ep6RI0d6RMTzySefuN08XENCQoJHRDzx8fFuNwUW8+bN88TGxnp+//13x+M9e/b0xMXFeY4fP+5Sy4Ivqs4kRf53T/aRI0dkzZo1sm3bNklISNAFzatUqeJy64Do8MYbb8itt94qZcuWdTzevn17uXDhgmzZssWllgVfVI1J/qlw4cKO1RC/+uorKVu2rFSrVs3FVgHR48iRI1K4cOGrHr9y5YqIiKSkpIS6SZkm6s4kvc2fP18SEhJk0KBBkj171P+6QEhUqVJFtmzZctWyzfPmzZPs2bNfNZtWJIuqM8nVq1fL2LFjpUWLFlK0aFFZv369xMfHS6tWrWTgwIFuNw8WU6dOldOnT8vBgwdFRGTJkiWyf/9+EfnfQm0FCxZ0s3nw8uyzz8ry5culcePG8uSTT0rRokVl6dKlsnz5cvn73/8uZcqUcbuJQeP3vduRYPfu3dK/f3/ZvHmzJCUlyQ033CA9evSQwYMHS65cudxuHiwqVKggv/76a5r79u7dKxUqVAhtg3BNGzdulDFjxsiWLVvkxIkT+vc2dOhQyZEjes6/oqpIAkCwMUgHABYUSQCwoEgCgAVFEgAsKJIAYEGRBAALiiQAWPh9xSeLE4VOsC5d5TMLnWBebsznFjr+fG6cSQKABUUSACwokgBgQZEEAAuKJABYUCQBwIIiCQAWFEkAsKBIAoAFRRIALCiSAGBBkQQAi+hZ0gwRZdKkSZoHDBigefv27Y7ntW3bVrOv1RQRHYYMGaJ59OjRmp977jnNEydODGmbRDiTBAAriiQAWNDdRshUqFBBc7du3TT/8ccfmm+66SbHMdWqVdNMdzt81ahRw7GdI0fapSUpKUnznj17fL5eXFyc5kceeUQz3W0ACDMUSQCwoLuNkDl27Jjm1atXa27fvr0bzUEGFC1aVPNrr72m+cEHH3Q8L3fu3JrNJRJOnjypuUGDBn695759+9LZyuDiTBIALCiSAGBBkQQAi7Aek6xdu7bmdu3aOfaZd2kUK1ZMszn+MXLkSMcxL730UkDtyZ8/v+bhw4drrlWrluN5L774oub169cH9J7R5Pz585q5nCdyNGrUSPOYMWM0N2vWLN2vVaRIEc358uVz7OvevbvmEydOaO7du3e63yeYOJMEAAuKJABYhF13+91339XcqVMnzTExMT6PMe/YMI0bN86xvXbtWs2rVq3yqz2FChXSvHz5cs22yxfMy1vobv/F/Le85ZZb3GsIrFq3bu3YXrBggebY2Ng0j5k2bZrP12vatKnm5s2bax46dKjjeebdVcePH9d86tSpa7Q4c3EmCQAWFEkAsHClu12vXj3HtjmPXIcOHTRny5ZN888//+w4pk2bNprNU/PKlStrbty4seOY7777Lt1tnTBhgmZfXewvvvjCsW3OlYi/mJMWXH/99X4dU79+fc3mfwN8O555li1b5tg2h7O2bt2quVWrVprNu6lEnN98Jycna54yZUqazxFxDql9/vnn6WpzZuJMEgAsKJIAYJHNY159bXui0fUN1MKFCx3bDzzwgOb+/ftr/vDDDzWbp+wiIufOnQtae0xdu3Z1bL/99tuac+XKpdm8Ub9cuXKOYy5duhRQG/z8SK4pmJ9ZsI0aNUqzeYGy7XcfNGiQ5qlTp2ZGszIsWJ+ZiPufW2pqqmPb/N1mzZqleeDAgZo
vXrzo8/XMb7Rt3Wizmx+qSU/8+dw4kwQAC4okAFhQJAHAwpUxSe8xu5w5c2ouUaKEZvMm92ArXry45g8++EDzbbfd5nhenjx5NF+4cEFzly5dNC9ZsiSobcsKY5ImcwyMMUn3P7d///vfju1nnnkmzeetW7dO81NPPeXz9cy/j9KlS2s+dOiQ43nmJDbmpUaZiTFJAAgQRRIALFy54+bHH390bJt34HTs2FHz7NmzNdsuMfClZcuWju06depoNi81KlOmjF+v99xzz2kOdhc7K8ue/a//V/uarASh889//tOxbU48Yd7pdvvtt2v2vrTHHDIw18Uxu9j33nuv45hQdbHTizNJALCgSAKAhSvfbptTuIuIfPbZZ5rNSST27NmjOSUlJd3v4z2JgvlNtb9WrFih+eGHH9Z85syZdL+Wv/h2O218u+2OggULah4xYoRmX996izh/h6NHj2o256oMh+41324DQIAokgBg4cq32+bkECLOi0jNuRjNb9XMlRNtzKUTvOegNLv55gpwpn379jm2zYvGM7OLDYQr8797czkS27CAecXC4cOHNYdDFzu9OJMEAAuKJABYUCQBwCIslpQ116gxJ70tVqyY5vLly/v1WuY45Pnz5x37/u///k+zOXZp8l6f5vTp0369LxCt6tatq3ny5MmazctnvNe4MccrzUmpa9asqXn79u1BbWdm4UwSACwokgBgERbdbV/MbriZM6pixYppPr57927N77//fsDvg/Txd4KLJk2aaA63O26iiffldp9++qlmX/NBmpfxiTgvnRs8eLDmJ598UnO/fv0CbWpIcCYJABYUSQCwCOvudqDuu+8+x/Zrr72W5vNef/11zZm5ZATSZnaxbRMOmEsPV69eXfPOnTszp2FZlDmpi4h/80F630ljdrdN5vIohQsXduw7depUutsaCpxJAoAFRRIALKK6u+09312hQoU0JyYmap43b16omoQ0vPnmm5r79u3r1zF9+vTRbM4ziYyZNWuWZvMmDhHnEIj57fTmzZvT/T758+fXnJH5Xd3AmSQAWFAkAcCCIgkAFlE3JmmOeeTNm9ex78KFC5pffvllzeYaHAg978mRERqPPPKI5h49emg274AScU76smjRIr9e2/zbMye7MC8VMi8nCmecSQKABUUSACxcWVI2M5ldiHfeecex7/vvv9dsTpYQbrLakrIm89IsEZFKlSql+TyzS1i5cmXHPnPCklCJxCVl586dq/mhhx7SbC7xK+K808lc/tlUo0YNx/a2bds0//bbb5rvuOMOzeHQ3WZJWQAIEEUSACyi4tttc3r5V1991efzPvroo1A0BwHYsWOHY9vXHKC2eSfhn5tuuinNx8+dO+fYvnLliuYWLVpoNicZGT9+vM/32bJli+Zw6GKnF2eSAGBBkQQAi6jobpctW1ZzkSJFNCcnJzuel5CQELI2IWNmzJjh2PZeFgDBY87DWatWLc3mRDAiIsuWLUv3a0+fPl3zyJEj09+4MMKZJABYUCQBwIIiCQAWUTEm2bZt2zQf/+qrrxzba9euDUVzEADv9Wp27dql2dclK8gY83K5/fv3ax4yZIjPYzZt2qQ5Li5Oc/v27R3P27dvXxBaGB44kwQAC4okAFhExQQXL730kuahQ4dqbtmypeN5K1eu1JySkpLp7cqorDzBRaSKxAkuwAQXABAwiiQAWETFt9vr169P8/EVK1Y4ts2b8EeNGpWpbQIQHTiTBAALiiQAWERFd9tcbW/OnDmavSe4mDlzZsjaBCA6cCYJABYUSQCwoEgCgEVU3HETbbjjJvJwx01k4o4bAAgQRRIALPzubgNAVsSZJABYUCQBwIIiCQAWFEkAsKBIAoAFRRIALCiSAGBBkQQAC4okAFj8P5KLNhLoIQY2AAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_dataset = train_dataset.shuffle(buffer_size=64)\n", "\n", "visualize(train_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### map" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`map`操作是数据预处理的关键操作,可以针对数据集指定列(column)添加数据变换(Transforms),将数据变换应用于该列数据的每个元素,并返回包含变换后元素的新数据集。这里我们对Mnist数据集做数据缩放处理,将图像统一除255,数据类型由uint8转为float32。\n", "\n", "> Dataset支持的不同变换类型详见[数据变换Transforms](https://www.mindspore.cn/tutorials/zh-CN/r2.0.0-alpha/beginner/transforms.html)。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(28, 28, 1) UInt8\n" ] } ], "source": [ "image, label = next(train_dataset.create_tuple_iterator())\n", "print(image.shape, image.dtype)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "train_dataset = train_dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "对比map前后的数据,可以看到数据类型变化。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(28, 28, 1) Float32\n" ] } ], "source": [ "image, label = next(train_dataset.create_tuple_iterator())\n", "print(image.shape, image.dtype)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### batch\n", "\n", "将数据集打包为固定大小的`batch`是在有限硬件资源下使用梯度下降进行模型优化的折中方法,可以保证梯度下降的随机性和优化计算量。\n", "\n", "![op-batch](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/tutorials/source_zh_cn/advanced/dataset/images/op_batch.png)\n", "\n", "一般我们会设置一个固定的batch size,将连续的数据分为若干批(batch)。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "train_dataset = train_dataset.batch(batch_size=32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "batch后的数据增加一维,大小为`batch_size`。" ] 
}, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(32, 28, 28, 1) Float32\n" ] } ], "source": [ "image, label = next(train_dataset.create_tuple_iterator())\n", "print(image.shape, image.dtype)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom Datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`mindspore.dataset` provides loading interfaces for a number of common datasets and standard-format datasets. For datasets that MindSpore cannot yet load directly, you can construct a custom dataset class or a custom dataset generator function to produce the data, and then load it through the `GeneratorDataset` interface.\n", "\n", "`GeneratorDataset` supports constructing custom datasets from iterable objects, iterators, and generator functions, each of which is described below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Iterable Objects\n", "\n", "In Python, any object whose elements can all be traversed with a for loop can be called an iterable object (Iterable). We can construct an iterable object by implementing the `__getitem__` method and load it into `GeneratorDataset`." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Iterable object as input source\n", "class Iterable:\n", "    def __init__(self):\n", "        # 5 random samples with 2 features each, and 5 random labels\n", "        self._data = np.random.sample((5, 2))\n", "        self._label = np.random.sample((5, 1))\n", "\n", "    def __getitem__(self, index):\n", "        return self._data[index], self._label[index]\n", "\n", "    def __len__(self):\n", "        return len(self._data)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "data = Iterable()\n", "dataset = GeneratorDataset(source=data, column_names=[\"data\", \"label\"])" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# A list, dict, or tuple is also an iterable object.\n", "dataset = GeneratorDataset(source=[(np.array(0),), (np.array(1),), (np.array(2),)], column_names=[\"col\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Iterators\n", "\n", "In Python, an object that implements the `__iter__` and `__next__` methods is called an iterator (Iterator). The following constructs a simple iterator and loads it into `GeneratorDataset`." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Iterator as input source\n", "class Iterator:\n", 
def __init__(self):\n", " self._index = 0\n", " self._data = np.random.sample((5, 2))\n", " self._label = np.random.sample((5, 1))\n", "\n", " def __next__(self):\n", " if self._index >= len(self._data):\n", " raise StopIteration\n", " else:\n", " item = (self._data[self._index], self._label[self._index])\n", " self._index += 1\n", " return item\n", "\n", " def __iter__(self):\n", " self._index = 0\n", " return self\n", "\n", " def __len__(self):\n", " return len(self._data)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "data = Iterator()\n", "dataset = GeneratorDataset(source=data, column_names=[\"data\", \"label\"])" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" }, "vscode": { "interpreter": { "hash": "8c9da313289c39257cb28b126d2dadd33153d4da4d524f730c81a4aaccbd2ca7" } } }, "nbformat": 4, "nbformat_minor": 4 }