{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading and Processing Data\n", "\n", "[![](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.3/tutorials/source_zh_cn/dataset.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.3/tutorials/zh_cn/mindspore_dataset.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.3/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9yMS4zL3R1dG9yaWFscy96aF9jbi9taW5kc3BvcmVfZGF0YXNldC5pcHluYg==&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MindSpore provides loading APIs for some commonly used datasets and for datasets in standard formats. You can load data directly with the corresponding dataset loading class in `mindspore.dataset`. The dataset classes also provide common data processing APIs, so that data processing operations can be performed quickly.\n", "\n", "## Loading the Dataset\n", "\n", "The following example loads the CIFAR-10 dataset through the `Cifar10Dataset` API and uses a sequential sampler to obtain the first 5 samples." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import mindspore.dataset as ds\n", "\n", "DATA_DIR = \"./datasets/cifar-10-batches-bin/train\"\n", "sampler = ds.SequentialSampler(num_samples=5)\n", "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Iterating the Dataset\n", "\n", "You can use `create_dict_iterator` to create a data iterator and access the data iteratively. The following shows the shape and label of each image." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 4\n", "Image shape: (32, 32, 3) , Label: 1\n" ] } ], "source": [ "for data in dataset.create_dict_iterator():\n", "    print(\"Image shape: {}\".format(data['image'].shape), \", Label: 
{}\".format(data['label']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Customizing Datasets\n", "\n", "For datasets that MindSpore does not currently support loading directly, you can construct a custom dataset class and then load the data in a customized way through the `GeneratorDataset` API.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "np.random.seed(58)\n", "\n", "class DatasetGenerator:\n", "    def __init__(self):\n", "        self.data = np.random.sample((5, 2))\n", "        self.label = np.random.sample((5, 1))\n", "\n", "    def __getitem__(self, index):\n", "        return self.data[index], self.label[index]\n", "\n", "    def __len__(self):\n", "        return len(self.data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class methods that you need to define are as follows:\n", "\n", "- **\\_\\_init\\_\\_**\n", "\n", "    When a dataset object is instantiated, the `__init__` function is called. You can perform operations such as data initialization here.\n", "\n", "    ```python\n", "    def __init__(self):\n", "        self.data = np.random.sample((5, 2))\n", "        self.label = np.random.sample((5, 1))\n", "    ```\n", "\n", "- **\\_\\_getitem\\_\\_**\n", "\n", "    Define the `__getitem__` function of the dataset class so that it supports random access and can obtain and return the data in the dataset for a given index value `index`.\n", "\n", "    ```python\n", "    def __getitem__(self, index):\n", "        return self.data[index], self.label[index]\n", "    ```\n", "\n", "- **\\_\\_len\\_\\_**\n", "\n", "    Define the `__len__` function of the dataset class to return the number of samples in the dataset.\n", "\n", "    ```python\n", "    def __len__(self):\n", "        return len(self.data)\n", "    ```\n", "\n", "After the dataset class is defined, you can use the `GeneratorDataset` API to load and access dataset samples in the user-defined way." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.36510558 0.45120592] [0.78888122]\n", "[0.49606035 0.07562207] [0.38068183]\n", "[0.57176158 0.28963401] [0.16271622]\n", "[0.30880446 0.37487617] [0.54738768]\n", "[0.81585667 0.96883469] [0.77994068]\n" ] } ], "source": [ "dataset_generator = DatasetGenerator()\n", "dataset = ds.GeneratorDataset(dataset_generator, [\"data\", \"label\"], shuffle=False)\n", "\n", "for data in dataset.create_dict_iterator():\n", "    print('{}'.format(data[\"data\"]), 
'{}'.format(data[\"label\"]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Data Processing and Augmentation\n", "\n", "### Data Processing\n", "\n", "The dataset APIs provided by MindSpore include common data processing methods. You only need to call the corresponding function APIs to process data quickly.\n", "\n", "The following example first shuffles the dataset randomly and then groups the samples into batches of two." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "data: [[0.36510558 0.45120592]\n", " [0.57176158 0.28963401]]\n", "label: [[0.78888122]\n", " [0.16271622]]\n", "data: [[0.30880446 0.37487617]\n", " [0.49606035 0.07562207]]\n", "label: [[0.54738768]\n", " [0.38068183]]\n", "data: [[0.81585667 0.96883469]]\n", "label: [[0.77994068]]\n" ] } ], "source": [ "ds.config.set_seed(58)\n", "\n", "# Shuffle the data order randomly\n", "dataset = dataset.shuffle(buffer_size=10)\n", "# Split the dataset into batches\n", "dataset = dataset.batch(batch_size=2)\n", "\n", "for data in dataset.create_dict_iterator():\n", "    print(\"data: {}\".format(data[\"data\"]))\n", "    print(\"label: {}\".format(data[\"label\"]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the preceding code:\n", "\n", "`buffer_size`: size of the buffer used for the shuffle operation in the dataset.\n", "\n", "`batch_size`: number of data records in each batch. Here each batch contains 2 records." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### Data Augmentation\n", "\n", "Problems such as too little data or a narrow range of sample scenarios affect model training. You can use data augmentation to increase sample diversity and thereby improve the generalization capability of the model.\n", "\n", "The following example uses operators in the `mindspore.dataset.vision.c_transforms` module to perform data augmentation on the MNIST dataset.\n", "\n", "Import the `c_transforms` module and load the MNIST dataset." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAENCAYAAADJzhMWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAMaklEQVR4nO3dX6ik9X3H8fenJmnBeLFGul2MZtNUQiGlWkQKlWIpCdZeqDc2QsE0pZuLWhLIRcReRAiFUGzaQqF0Q2w2tjUEjFHE1lix2eQmuIrVVTFauxKX1Y0sbbQ3afTbi/OsPbuec+bs/Htmz/f9gmFmnjP7zHef3c/+/s3sL1WFpJ3vZ8YuQNJyGHapCcMuNWHYpSYMu9SEYZeaMOxSE4Zdp0hyJEltcntl7Po0vXeNXYBW0n8Df7XB8TeWXIfmKH6CTuslOQJQVXvHrUTzZjdeasJuvDbys0l+H7gY+B/gSeBgVb05blmahd14nWLoxn9ggx/9J/AHVfWd5VakebEbr9P9PfDbwC8A5wK/AvwdsBf45yS/Ol5pmoUtu7Ylye3AZ4FvVdX1Y9ejM2fYtS1Jfgl4HjhRVe8bux6dObvx2q4fDffnjlqFpmbYtV2/Pty/OGoVmpph19uS/HKSd7TcSfYCfzM8/YelFqW5cZ1d6/0e8NkkB4GXgNeBDwG/C/wc8ABw+3jlaRaGXes9AnwYuAz4DdbG5/8FfA+4E7iznNE9azkbLzXhmF1qwrBLTRh2qQnDLjWx1Nn4JM4GSgtWVdno+Ewte5KrkzyX5IUkt8xyLkmLNfXSW5JzgB8AHwVeBh4FbqyqZ7b4Nbbs0oItomW/Anihql6sqp8AXweuneF8khZolrBfCPxw3fOXh2OnSLIvyaEkh2Z4L0kzWvgEXVXtB/aD3XhpTLO07EeBi9Y9f/9wTNIKmiXsjwKXJPlgkvcAHwfum09ZkuZt6m58Vf00yc3Ag8A5wB1V9fTcKpM0V0v91ptjdmnxFvKhGklnD8MuNWHYpSYMu9SEYZeaMOxSE4ZdasKwS00YdqkJwy41YdilJgy71IRhl5ow7FIThl1qwrBLTRh2qQnDLjVh2KUmDLvUhGGXmjDsUhOGXWrCsEtNGHapCcMuNWHYpSYMu9SEYZeamHrLZp0dlrlLr/5fsuFGqqOaKexJjgCvA28CP62qy+dRlKT5m0fL/ltV9docziNpgRyzS03MGvYCvp3ksST7NnpBkn1JDiU5NON7SZpBZpnASXJhVR1N8vPAQ8CfVNXBLV7vbNGSOUE3jjEn6KpqwzefqWWvqqPD/XHgHuCKWc4naXGmDnuSc5Ocd/Ix8DHg8LwKkzRfs8zG7wbuGbor7wL+qar+ZS5V7TB2pbUKZhqzn/GbNR2zG/Z+dtyYXdLZw7BLTRh2qQnDLjVh2KUm/IrrHDjb3s8qfoV1Elt2qQnDLjVh2KUmDLvUhGGXmjDsUhOGXWrCdXatrLNxLXuV2bJLTRh2qQnDLjVh2KUmDLvUhGGXmjDsUhOus8/BpPXgnfx9986/97ONLbvUhGGXmjDsUhOGXWrCsEtNGHapCcMuNeE6+xKczWvRs36n3O+kr46JLXuSO5IcT3J43bHzkzyU5Pnhftdiy5Q0q+10478KXH3asVuAh6vqEuDh4bmkFTYx7FV1EDhx2uFrgQPD4wPAdfMtS9K8TTtm311Vx4bHrwC7N3thkn3AvinfR9KczDxBV1WVZNMZpqraD+wH2Op1khZr2qW3V5PsARjuj8+vJEmLMG3Y7wNuGh7fBNw7n3IkLUomrfEmuQu4CrgAeBX4PPAt4BvAxcBLwA1Vdfok3kbnshu/AGOu07uOvnqqasM/lIlhnyfDvhiGXettFnY/Lis1YdilJgy71IRhl5ow7FIThl1qwrBLTRh2qQnDLjVh2KUmDLvUhGGXmjDsUhP+V9I7wFbfPFv0N+IWeX6/UTdftuxSE4ZdasKwS00YdqkJwy41YdilJgy71ITr7Dv
c2bxd9Db+m/MlVbIz2LJLTRh2qQnDLjVh2KUmDLvUhGGXmjDsUhOuszfnOnwfE1v2JHckOZ7k8LpjtyU5muSJ4XbNYsuUNKvtdOO/Cly9wfG/rKpLh9sD8y1L0rxNDHtVHQROLKEWSQs0ywTdzUmeHLr5uzZ7UZJ9SQ4lOTTDe0maUbYzAZNkL3B/VX1keL4beA0o4AvAnqr65DbOs7qzPdrQKk/QTdJ1gq6qNvyNT9WyV9WrVfVmVb0FfBm4YpbiJC3eVGFPsmfd0+uBw5u9VtJqmLjOnuQu4CrggiQvA58HrkpyKWvd+CPApxZXosZ0Nq/D61TbGrPP7c0cs+84qxx2x+yn8uOyUhOGXWrCsEtNGHapCcMuNeFXXDWTWWa8x9xOuuNMvS271IRhl5ow7FIThl1qwrBLTRh2qQnDLjXhOru2tMrfatOZsWWXmjDsUhOGXWrCsEtNGHapCcMuNWHYpSZcZ9/hOq+Td/zO+lZs2aUmDLvUhGGXmjDsUhOGXWrCsEtNGHapiYlhT3JRkkeSPJPk6SSfHo6fn+ShJM8P97sWX25PVTX1bSdLsuVNp5q4ZXOSPcCeqno8yXnAY8B1wCeAE1X1xSS3ALuq6nMTzrWz//YtyE4P7bQM9Mam3rK5qo5V1ePD49eBZ4ELgWuBA8PLDrD2D4CkFXVGY/Yke4HLgO8Du6vq2PCjV4Dd8y1N0jxt+7PxSd4L3A18pqp+vL4LVVW1WRc9yT5g36yFSprNxDE7QJJ3A/cDD1bVl4ZjzwFXVdWxYVz/b1X14QnncfA5BcfsG3PMvrGpx+xZu6JfAZ49GfTBfcBNw+ObgHtnLVLS4mxnNv5K4LvAU8Bbw+FbWRu3fwO4GHgJuKGqTkw4V8smypZ5Orbc09msZd9WN35eDLvOhGGfztTdeEk7g2GXmjDsUhOGXWrCsEtNGHapCf8r6W1y+Ww6Lp+tDlt2qQnDLjVh2KUmDLvUhGGXmjDsUhOGXWqizTq76+TTcZ1857Bll5ow7FIThl1qwrBLTRh2qQnDLjVh2KUm2qyzd+U6uU6yZZeaMOxSE4ZdasKwS00YdqkJwy41YdilJiausye5CPgasBsoYH9V/XWS24A/An40vPTWqnpgUYXOyvVmdTdxf/Yke4A9VfV4kvOAx4DrgBuAN6rq9m2/WdP92aVl2mx/9okte1UdA44Nj19P8ixw4XzLk7RoZzRmT7IXuAz4/nDo5iRPJrkjya5Nfs2+JIeSHJqtVEmzmNiNf/uFyXuB7wB/VlXfTLIbeI21cfwXWOvqf3LCOezGSwu2WTd+W2FP8m7gfuDBqvrSBj/fC9xfVR+ZcB7DLi3YZmGf2I3P2jT2V4Bn1wd9mLg76Xrg8KxFSlqc7czGXwl8F3gKeGs4fCtwI3Apa934I8Cnhsm8rc5lyy4t2Ezd+Hkx7NLiTd2Nl7QzGHapCcMuNWHYpSYMu9SEYZeaMOxSE4ZdasKwS00YdqkJwy41YdilJgy71IRhl5pY9pbNrwEvrXt+wXBsFa1qbataF1jbtOZZ2wc2+8FSv8/+jjdPDlXV5aMVsIVVrW1V6wJrm9ayarMbLzVh2KUmxg77/pHffyurWtuq1gXWNq2l1DbqmF3S8ozdsktaEsMuNTFK2JNcneS5JC8kuWWMGjaT5EiSp5I8Mfb+dMMeeseTHF537PwkDyV5frjfcI+9kWq7LcnR4do9keSakWq7KMkjSZ5J8nSSTw/HR712W9S1lOu29DF7knOAHwAfBV4GHgVurKpnllrIJpIcAS6vqtE/gJHkN4E3gK+d3ForyZ8DJ6rqi8M/lLuq6nMrUtttnOE23guqbbNtxj/BiNduntufT2OMlv0K4IWqerGqfgJ8Hbh2hDpWXlUdBE6cdvha4MDw+ABrf1mWbpPaVkJVHauqx4fHrwMntxkf9dptUddSjBH2C4Efrnv+Mqu133sB307yWJJ9Yxezgd3rttl6Bdg9ZjEbmLiN9zKdts34yly
7abY/n5UTdO90ZVX9GvA7wB8P3dWVVGtjsFVaO/1b4EOs7QF4DPiLMYsZthm/G/hMVf14/c/GvHYb1LWU6zZG2I8CF617/v7h2EqoqqPD/XHgHtaGHavk1ZM76A73x0eu521V9WpVvVlVbwFfZsRrN2wzfjfwj1X1zeHw6Nduo7qWdd3GCPujwCVJPpjkPcDHgftGqOMdkpw7TJyQ5FzgY6zeVtT3ATcNj28C7h2xllOsyjbem20zzsjXbvTtz6tq6TfgGtZm5P8D+NMxatikrl8E/n24PT12bcBdrHXr/pe1uY0/BN4HPAw8D/wrcP4K1XYna1t7P8lasPaMVNuVrHXRnwSeGG7XjH3ttqhrKdfNj8tKTThBJzVh2KUmDLvUhGGXmjDsUhOGXWrCsEtN/B/M3kbdmYwBvQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "from mindspore.dataset.vision import Inter\n", "import mindspore.dataset.vision.c_transforms as c_vision\n", "\n", "DATA_DIR = './datasets/MNIST_Data/train'\n", "\n", "mnist_dataset = ds.MnistDataset(DATA_DIR, num_samples=6, shuffle=False)\n", "\n", "# 查看数据原图\n", "mnist_it = mnist_dataset.create_dict_iterator()\n", "data = next(mnist_it)\n", "plt.imshow(data['image'].asnumpy().squeeze(), cmap=plt.cm.gray)\n", "plt.title(data['label'].asnumpy(), fontsize=20)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "定义数据增强算子,对数据集进行`Resize`和`RandomCrop`操作,然后通过`map`映射将其插入数据处理管道。\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "resize_op = c_vision.Resize(size=(200,200), interpolation=Inter.LINEAR)\n", "crop_op = c_vision.RandomCrop(150)\n", "transforms_list = [resize_op, crop_op]\n", "mnist_dataset = mnist_dataset.map(operations=transforms_list, input_columns=[\"image\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "查看数据增强效果。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "mnist_dataset = mnist_dataset.create_dict_iterator()\n", "data = next(mnist_dataset)\n", "plt.imshow(data['image'].asnumpy().squeeze(), cmap=plt.cm.gray)\n", "plt.title(data['label'].asnumpy(), fontsize=20)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "想要了解更多可以参考编程指南中[数据增强](https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.3/augmentation.html)章节。" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }