{ "cells": [ { "cell_type": "markdown", "id": "dae25842", "metadata": {}, "source": [ "# 数据集构建\n", "\n", "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0/zh_cn/migration_guide/model_development/mindspore_dataset.ipynb)\n", "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0/zh_cn/migration_guide/model_development/mindspore_dataset.py)\n", "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/source_zh_cn/migration_guide/model_development/dataset.ipynb)\n", "\n", "本章节主要对网络迁移中数据处理相关的注意事项加以说明,基本的数据处理请参考:\n", "\n", "[数据处理](https://www.mindspore.cn/tutorials/zh-CN/r2.0/beginner/dataset.html)\n", "\n", "[自动数据增强](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/dataset/augment.html)\n", "\n", "[轻量化数据处理](https://mindspore.cn/tutorials/zh-CN/r2.0/advanced/dataset/eager.html)\n", "\n", "[数据处理性能优化](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/dataset/optimize.html)\n", "\n", "## 数据构建基本流程\n", "\n", "整个数据构建基本流程主要包括:数据集加载和数据增强这两方面。\n", "\n", "### 数据集加载\n", "\n", "MindSpore提供了很多常见数据集的加载接口,用的比较多的接口有:\n", "\n", "| 数据接口 | 介绍 |\n", "| -------| ---- |\n", "| [Cifar10Dataset](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/mindspore.dataset.Cifar10Dataset.html#mindspore.dataset.Cifar10Dataset) | Cifar10 数据集读入接口(需要自行下载Cifar10的原始bin文件) |\n", "| [MNIST](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/mindspore.dataset.MnistDataset.html#mindspore.dataset.MnistDataset) | Minist 手写数字识别数据集(需要自行下载原始文件) |\n", "| [ImageFolderDataset](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/mindspore.dataset.ImageFolderDataset.html) | 以文件目录作为分类的数据组织格式的数据集读取方式(ImageNet常用) |\n", "| [MindDataset](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/mindspore.dataset.MindDataset.html#mindspore.dataset.MindDataset) | MindRecord数据读入接口 |\n", "| [GeneratorDataset](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/mindspore.dataset.GeneratorDataset.html) | 用户自定义数据接口 |\n", "| [FakeImageDataset](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/mindspore.dataset.FakeImageDataset.html) | 构造一个假的图像数据集 |\n", "\n", "还有很多常用的不同领域的数据集接口,详情参考[常见数据集的加载接口](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore.dataset.html)。\n", "\n", "#### 自定义数据集 GeneratorDataset\n", "\n", "构造自定义 Dataset 对象的基本流程如下:\n", "\n", "首先创建一个迭代器类,在该类中定义 `__init__` 、 `__getitem__` 、 `__len__` 三个方法:" ] }, { "cell_type": "code", "execution_count": 1, "id": "d0857d92", "metadata": { "ExecuteTime": { "end_time": "2022-09-08T06:24:16.510160Z", "start_time": "2022-09-08T06:24:16.326480Z" } }, "outputs": [], "source": [ "import numpy as np\n", "class MyDataset():\n", " \"\"\"Self Defined dataset.\"\"\"\n", " def __init__(self, n):\n", " self.data = []\n", " self.label = []\n", " for _ in range(n):\n", " self.data.append(np.zeros((3, 4, 5)))\n", " self.label.append(np.ones((1)))\n", " def __len__(self):\n", " return len(self.data)\n", " def __getitem__(self, idx):\n", " data = self.data[idx]\n", " label = self.label[idx]\n", " return data, label" ] }, { "cell_type": "markdown", "id": "aadecd65", "metadata": {}, "source": [ "> 在迭代器类中最好不要使用MindSpore的算子,因为一般在数据处理阶段会加多线程,这样会有问题。\n", ">\n", "> 迭代器的输出需要是numpy的array。\n", ">\n", "> 迭代器必须要设置`__len__`方法,返回的结果一定要是真实的数据集大小,设置大了在getitem取值时会有问题。\n", "\n", "然后使用 GeneratorDataset 封装迭代器类:" ] }, { "cell_type": "code", "execution_count": 2, "id": "23cb730a", "metadata": { "ExecuteTime": { "end_time": "2022-09-08T06:24:19.388920Z", "start_time": "2022-09-08T06:24:16.512171Z" } }, "outputs": [], "source": [ "import mindspore.dataset as ds\n", "\n", "my_dataset = MyDataset(10)\n", "# corresponding to torch.utils.data.DataLoader(my_dataset)\n", "dataset = ds.GeneratorDataset(my_dataset, column_names=[\"data\", \"label\"])" ] }, { "cell_type": "markdown", "id": "9af5a401", "metadata": {}, "source": [ "在自定义数据集时需要给每一个输出设置一个名字,如上面的`column_names=[\"data\", \"label\"]`,表示迭代器的第一个输出列叫`data`,第二个叫`label`。在后续的数据增强以及数据迭代获取阶段,可以通过名字来分别对不同列进行处理。\n", "\n", "**所有MindSpore的数据接口**,有一些通用的属性控制,这里介绍一些常用的:\n", "\n", "| 属性 | 介绍 |\n", "| ---- | ---- |\n", "| num_samples(int) | 规定数据总的sample数 |\n", "| shuffle(bool) | 是否对数据做随机打乱 |\n", "| sampler(Sampler) | 数据取样器,可以自定义数据打乱、分配,`sampler`设置和`num_shards`、`shard_id`互斥 |\n", "| num_shards(int) | 用于分布式场景,将数据分为多少份,与`shard_id`配合使用 |\n", "| shard_id(int) | 用于分布式场景,取第几份数据(0~n-1,n为设置的`num_shards`),与`num_shards`配合使用 |\n", "| num_parallel_workers(int) | 并行配置的线程数 |\n", "\n", "举个例子:" ] }, { "cell_type": "code", "execution_count": 3, "id": "7f539b71", "metadata": { "ExecuteTime": { "end_time": "2022-09-08T06:24:19.398835Z", "start_time": "2022-09-08T06:24:19.390999Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1000\n", "3\n", "125\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "\n", "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\n", "print(dataset.get_dataset_size())\n", "# 1000\n", "\n", "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0, num_samples=3)\n", "print(dataset.get_dataset_size())\n", "# 3\n", "\n", "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0,\n", " num_shards=8, shard_id=0)\n", "print(dataset.get_dataset_size())\n", "# 1000 / 8 = 125" ] }, { "cell_type": "markdown", "id": "d2191c57", "metadata": {}, "source": [ "### 数据处理及增强\n", "\n", "MindSpore的dataset对象使用map接口进行数据增强。我们一起来看下[map接口](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/dataset_method/operation/mindspore.dataset.Dataset.map.html#mindspore.dataset.Dataset.map):\n", "\n", "```text\n", "map(operations, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, **kwargs)\n", "```\n", "\n", "给定一组数据增强列表,按顺序将数据增强作用在数据集对象上。\n", "\n", "每个数据增强操作将数据集对象中的一个或多个数据列作为输入,将数据增强的结果输出为一个或多个数据列。第一个数据增强操作将 input_columns 中指定的列作为输入。如果数据增强列表中存在多个数据增强操作,则上一个数据增强的输出列将作为下一个数据增强的输入列。\n", "\n", "最后一个数据增强的输出列的列名由 `output_columns` 指定,如果没有指定 `output_columns` ,输出列名与 `input_columns` 一致。\n", "\n", "上面的介绍可能比较繁琐,简单来说 `map` 就是在数据集的某些列上做 `operations` 里规定的操作。这里的`operations`可以是MindSpore提供的数据增强操作:\n", "[audio](https://gitee.com/mindspore/mindspore/blob/r2.0/docs/api/api_python/mindspore.dataset.transforms.rst#音频)、[text](https://gitee.com/mindspore/mindspore/blob/r2.0/docs/api/api_python/mindspore.dataset.transforms.rst#文本)、[vision](https://gitee.com/mindspore/mindspore/blob/r2.0/docs/api/api_python/mindspore.dataset.transforms.rst#视觉)、[通用](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/mindspore.dataset.transforms.html)。详情请参考[数据变换 Transforms](https://www.mindspore.cn/tutorials/zh-CN/r2.0/beginner/transforms.html)\n", "也可以是python的方法,里面可以用 opencv,PIL,pandas 等一些三方的方法,和数据集加载一样,**不要使用MindSpore的算子**。\n", "\n", "MindSpore也提供了一些常用的随机增强方法:[自动数据增强](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/dataset/augment.html),在具体使用数据增强时最好先阅读[数据处理性能优化](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.0/dataset/optimize.html),按照推荐的顺序进行处理。\n", "\n", "在数据增强结束后,可以使用batch操作将数据集中连续 `batch_size` 条数据合并为一个批处理数据。详情请参考[batch](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/dataset_method/batch/mindspore.dataset.Dataset.batch.html#mindspore.dataset.Dataset.batch)。其中需要注意 `drop_remainder` 这个参数,当训练的时候需要设置成True,推理时设置成False。" ] }, { "cell_type": "code", "execution_count": 4, "id": "aee6e813", "metadata": { "ExecuteTime": { "end_time": "2022-09-08T06:24:19.412419Z", "start_time": "2022-09-08T06:24:19.401495Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "31\n", "32\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "\n", "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\\\n", " .batch(32, drop_remainder=True)\n", "print(dataset.get_dataset_size())\n", "# 1000 // 32 = 31\n", "\n", "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\\\n", " .batch(32, drop_remainder=False)\n", "print(dataset.get_dataset_size())\n", "# ceil(1000 / 32) = 32" ] }, { "cell_type": "markdown", "id": "e84cb613", "metadata": {}, "source": [ "batch操作也可以使用一些batch内的增强操作,详情可参考[YOLOv3](https://gitee.com/mindspore/models/blob/r2.0/official/cv/YOLOv3/src/yolo_dataset.py#L177)。\n", "\n", "## 数据迭代\n", "\n", "MindSpore的数据对象有以下几种方式迭代获取。\n", "\n", "### [create_dict_iterator](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_dict_iterator.html#mindspore.dataset.Dataset.create_dict_iterator)\n", "\n", "基于数据集对象创建迭代器,输出的数据为字典类型。" ] }, { "cell_type": "code", "execution_count": 5, "id": "eee2d49c", "metadata": { "ExecuteTime": { "end_time": "2022-09-08T06:24:19.435680Z", "start_time": "2022-09-08T06:24:19.413994Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "image (10, 32, 32, 3)\n", "label (10,)\n", "====================\n", "image (10, 32, 32, 3)\n", "label (10,)\n", "====================\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)\n", "dataset = dataset.batch(10, drop_remainder=True)\n", "iterator = dataset.create_dict_iterator()\n", "for data_dict in iterator:\n", " for name in data_dict.keys():\n", " print(name, data_dict[name].shape)\n", " print(\"=\"*20)" ] }, { "cell_type": "markdown", "id": "c79ed4d9", "metadata": {}, "source": [ "### [create_tuple_iterator](https://www.mindspore.cn/docs/zh-CN/r2.0/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_tuple_iterator.html#mindspore.dataset.Dataset.create_tuple_iterator)\n", "\n", "基于数据集对象创建迭代器,输出数据为 `numpy.ndarray` 组成的列表。\n", "\n", "可以通过参数 `columns` 指定输出的所有列名及列的顺序。如果columns未指定,列的顺序将保持不变。" ] }, { "cell_type": "code", "execution_count": 6, "id": "e674ccf0", "metadata": { "ExecuteTime": { "end_time": "2022-09-08T06:24:19.458382Z", "start_time": "2022-09-08T06:24:19.437707Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10, 32, 32, 3)\n", "(10,)\n", "====================\n", "(10, 32, 32, 3)\n", "(10,)\n", "====================\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)\n", "dataset = dataset.batch(10, drop_remainder=True)\n", "iterator = dataset.create_tuple_iterator()\n", "for data_tuple in iterator:\n", " for data in data_tuple:\n", " print(data.shape)\n", " print(\"=\"*20)" ] }, { "cell_type": "markdown", "id": "a60d5ae1", "metadata": {}, "source": [ "### 直接遍历dataset对象\n", "\n", "> 注意这种写法在遍历完一次epoch后不会`shuffle`,在训练时这样使用可能会影响精度,训练时需要直接数据迭代时建议使用上面的两种方法。" ] }, { "cell_type": "code", "execution_count": 7, "id": "cca2305f", "metadata": { "ExecuteTime": { "end_time": "2022-09-08T06:24:19.475797Z", "start_time": "2022-09-08T06:24:19.459963Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(10, 32, 32, 3)\n", "(10,)\n", "====================\n", "(10, 32, 32, 3)\n", "(10,)\n", "====================\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)\n", "dataset = dataset.batch(10, drop_remainder=True)\n", "\n", "for data in dataset:\n", " for data in data:\n", " print(data.shape)\n", " print(\"=\"*20)" ] }, { "cell_type": "markdown", "id": "152092b4", "metadata": {}, "source": [ "其中后两种在数据读入顺序和网络需要的顺序一致时,可以直接使用:\n", "\n", "```python\n", "for data in dataset:\n", " loss = net(*data)\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 5 }