{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dae25842",
   "metadata": {},
   "source": [
    "# 数据集构建\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.1/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.1/zh_cn/migration_guide/model_development/mindspore_dataset.ipynb)\n",
    "[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.1/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.1/zh_cn/migration_guide/model_development/mindspore_dataset.py)\n",
    "[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.1/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/source_zh_cn/migration_guide/model_development/dataset.ipynb)\n",
    "\n",
    "本章节主要对网络迁移中数据处理相关的注意事项加以说明,基本的数据处理请参考:\n",
    "\n",
    "[数据处理](https://www.mindspore.cn/tutorials/zh-CN/r2.1/beginner/dataset.html)\n",
    "\n",
    "[自动数据增强](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.1/dataset/augment.html)\n",
    "\n",
    "[轻量化数据处理](https://mindspore.cn/tutorials/zh-CN/r2.1/advanced/dataset/eager.html)\n",
    "\n",
    "[数据处理性能优化](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.1/dataset/optimize.html)\n",
    "\n",
    "## 数据构建基本流程\n",
    "\n",
    "整个数据构建基本流程主要包括:数据集加载和数据增强这两方面。\n",
    "\n",
    "### 数据集加载\n",
    "\n",
    "MindSpore提供了很多常见数据集的加载接口,用的比较多的接口有:\n",
    "\n",
    "| 数据接口 | 介绍 |\n",
    "| -------| ---- |\n",
    "| [Cifar10Dataset](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/mindspore.dataset.Cifar10Dataset.html#mindspore.dataset.Cifar10Dataset) | Cifar10 数据集读入接口(需要自行下载Cifar10的原始bin文件) |\n",
    "| [MNIST](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/mindspore.dataset.MnistDataset.html#mindspore.dataset.MnistDataset) | Minist 手写数字识别数据集(需要自行下载原始文件) |\n",
    "| [ImageFolderDataset](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/mindspore.dataset.ImageFolderDataset.html) | 以文件目录作为分类的数据组织格式的数据集读取方式(ImageNet常用) |\n",
    "| [MindDataset](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/mindspore.dataset.MindDataset.html#mindspore.dataset.MindDataset) | MindRecord数据读入接口 |\n",
    "| [GeneratorDataset](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/mindspore.dataset.GeneratorDataset.html) | 用户自定义数据接口 |\n",
    "| [FakeImageDataset](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/mindspore.dataset.FakeImageDataset.html) | 构造一个假的图像数据集 |\n",
    "\n",
    "还有很多常用的不同领域的数据集接口,详情参考[常见数据集的加载接口](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/mindspore.dataset.html)。\n",
    "\n",
    "#### 自定义数据集 GeneratorDataset\n",
    "\n",
    "构造自定义 Dataset 对象的基本流程如下:\n",
    "\n",
    "首先创建一个迭代器类,在该类中定义 `__init__` 、 `__getitem__` 、 `__len__` 三个方法:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d0857d92",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T06:24:16.510160Z",
     "start_time": "2022-09-08T06:24:16.326480Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "class MyDataset():\n",
    "    \"\"\"Self Defined dataset.\"\"\"\n",
    "    def __init__(self, n):\n",
    "        self.data = []\n",
    "        self.label = []\n",
    "        for _ in range(n):\n",
    "            self.data.append(np.zeros((3, 4, 5)))\n",
    "            self.label.append(np.ones((1)))\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "    def __getitem__(self, idx):\n",
    "        data = self.data[idx]\n",
    "        label = self.label[idx]\n",
    "        return data, label"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aadecd65",
   "metadata": {},
   "source": [
    "> 在迭代器类中最好不要使用MindSpore的算子,因为一般在数据处理阶段会加多线程,这样会有问题。\n",
    ">\n",
    "> 迭代器的输出需要是numpy的array。\n",
    ">\n",
    "> 迭代器必须要设置`__len__`方法,返回的结果一定要是真实的数据集大小,设置大了在getitem取值时会有问题。\n",
    "\n",
    "然后使用 GeneratorDataset 封装迭代器类:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "23cb730a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T06:24:19.388920Z",
     "start_time": "2022-09-08T06:24:16.512171Z"
    }
   },
   "outputs": [],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "my_dataset = MyDataset(10)\n",
    "# corresponding to torch.utils.data.DataLoader(my_dataset)\n",
    "dataset = ds.GeneratorDataset(my_dataset, column_names=[\"data\", \"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af5a401",
   "metadata": {},
   "source": [
    "在自定义数据集时需要给每一个输出设置一个名字,如上面的`column_names=[\"data\", \"label\"]`,表示迭代器的第一个输出列叫`data`,第二个叫`label`。在后续的数据增强以及数据迭代获取阶段,可以通过名字来分别对不同列进行处理。\n",
    "\n",
    "**所有MindSpore的数据接口**,有一些通用的属性控制,这里介绍一些常用的:\n",
    "\n",
    "| 属性 | 介绍 |\n",
    "| ---- | ---- |\n",
    "| num_samples(int) | 规定数据总的sample数 |\n",
    "| shuffle(bool)  | 是否对数据做随机打乱 |\n",
    "| sampler(Sampler) | 数据取样器,可以自定义数据打乱、分配,`sampler`设置和`num_shards`、`shard_id`互斥 |\n",
    "| num_shards(int) | 用于分布式场景,将数据分为多少份,与`shard_id`配合使用 |\n",
    "| shard_id(int) | 用于分布式场景,取第几份数据(0~n-1,n为设置的`num_shards`),与`num_shards`配合使用 |\n",
    "| num_parallel_workers(int) | 并行配置的线程数 |\n",
    "\n",
    "举个例子:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7f539b71",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T06:24:19.398835Z",
     "start_time": "2022-09-08T06:24:19.390999Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000\n",
      "3\n",
      "125\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\n",
    "print(dataset.get_dataset_size())\n",
    "# 1000\n",
    "\n",
    "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0, num_samples=3)\n",
    "print(dataset.get_dataset_size())\n",
    "# 3\n",
    "\n",
    "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0,\n",
    "                              num_shards=8, shard_id=0)\n",
    "print(dataset.get_dataset_size())\n",
    "# 1000 / 8 = 125"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2191c57",
   "metadata": {},
   "source": [
    "### 数据处理及增强\n",
    "\n",
    "MindSpore的dataset对象使用map接口进行数据增强。我们一起来看下[map接口](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/dataset_method/operation/mindspore.dataset.Dataset.map.html#mindspore.dataset.Dataset.map):\n",
    "\n",
    "```text\n",
    "map(operations, input_columns=None, output_columns=None, column_order=None, num_parallel_workers=None, **kwargs)\n",
    "```\n",
    "\n",
    "给定一组数据增强列表,按顺序将数据增强作用在数据集对象上。\n",
    "\n",
    "每个数据增强操作将数据集对象中的一个或多个数据列作为输入,将数据增强的结果输出为一个或多个数据列。第一个数据增强操作将 input_columns 中指定的列作为输入。如果数据增强列表中存在多个数据增强操作,则上一个数据增强的输出列将作为下一个数据增强的输入列。\n",
    "\n",
    "最后一个数据增强的输出列的列名由 `output_columns` 指定,如果没有指定 `output_columns` ,输出列名与 `input_columns` 一致。\n",
    "\n",
    "上面的介绍可能比较繁琐,简单来说 `map` 就是在数据集的某些列上做 `operations` 里规定的操作。这里的`operations`可以是MindSpore提供的数据增强操作:\n",
    "[audio](https://gitee.com/mindspore/mindspore/blob/r2.1/docs/api/api_python/mindspore.dataset.transforms.rst#音频)、[text](https://gitee.com/mindspore/mindspore/blob/r2.1/docs/api/api_python/mindspore.dataset.transforms.rst#文本)、[vision](https://gitee.com/mindspore/mindspore/blob/r2.1/docs/api/api_python/mindspore.dataset.transforms.rst#视觉)、[通用](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/mindspore.dataset.transforms.html)。详情请参考[数据变换 Transforms](https://www.mindspore.cn/tutorials/zh-CN/r2.1/beginner/transforms.html)\n",
    "也可以是python的方法,里面可以用 opencv,PIL,pandas 等一些三方的方法,和数据集加载一样,**不要使用MindSpore的算子**。\n",
    "\n",
    "MindSpore也提供了一些常用的随机增强方法:[自动数据增强](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.1/dataset/augment.html),在具体使用数据增强时最好先阅读[数据处理性能优化](https://www.mindspore.cn/tutorials/experts/zh-CN/r2.1/dataset/optimize.html),按照推荐的顺序进行处理。\n",
    "\n",
    "在数据增强结束后,可以使用batch操作将数据集中连续 `batch_size` 条数据合并为一个批处理数据。详情请参考[batch](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/dataset_method/batch/mindspore.dataset.Dataset.batch.html#mindspore.dataset.Dataset.batch)。其中需要注意 `drop_remainder` 这个参数,当训练的时候需要设置成True,推理时设置成False。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "aee6e813",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T06:24:19.412419Z",
     "start_time": "2022-09-08T06:24:19.401495Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31\n",
      "32\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\\\n",
    "    .batch(32, drop_remainder=True)\n",
    "print(dataset.get_dataset_size())\n",
    "# 1000 // 32 = 31\n",
    "\n",
    "dataset = ds.FakeImageDataset(num_images=1000, image_size=(32, 32, 3), num_classes=10, base_seed=0)\\\n",
    "    .batch(32, drop_remainder=False)\n",
    "print(dataset.get_dataset_size())\n",
    "# ceil(1000 / 32) = 32"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e84cb613",
   "metadata": {},
   "source": [
    "batch操作也可以使用一些batch内的增强操作,详情可参考[YOLOv3](https://gitee.com/mindspore/models/blob/r2.1/official/cv/YOLOv3/src/yolo_dataset.py#L177)。\n",
    "\n",
    "## 数据迭代\n",
    "\n",
    "MindSpore的数据对象有以下几种方式迭代获取。\n",
    "\n",
    "### [create_dict_iterator](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_dict_iterator.html#mindspore.dataset.Dataset.create_dict_iterator)\n",
    "\n",
    "基于数据集对象创建迭代器,输出的数据为字典类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "eee2d49c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T06:24:19.435680Z",
     "start_time": "2022-09-08T06:24:19.413994Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image (10, 32, 32, 3)\n",
      "label (10,)\n",
      "====================\n",
      "image (10, 32, 32, 3)\n",
      "label (10,)\n",
      "====================\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)\n",
    "dataset = dataset.batch(10, drop_remainder=True)\n",
    "iterator = dataset.create_dict_iterator()\n",
    "for data_dict in iterator:\n",
    "    for name in data_dict.keys():\n",
    "        print(name, data_dict[name].shape)\n",
    "    print(\"=\"*20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c79ed4d9",
   "metadata": {},
   "source": [
    "### [create_tuple_iterator](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_tuple_iterator.html#mindspore.dataset.Dataset.create_tuple_iterator)\n",
    "\n",
    "基于数据集对象创建迭代器,输出数据为 `numpy.ndarray` 组成的列表。\n",
    "\n",
    "可以通过参数 `columns` 指定输出的所有列名及列的顺序。如果columns未指定,列的顺序将保持不变。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e674ccf0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T06:24:19.458382Z",
     "start_time": "2022-09-08T06:24:19.437707Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 32, 32, 3)\n",
      "(10,)\n",
      "====================\n",
      "(10, 32, 32, 3)\n",
      "(10,)\n",
      "====================\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)\n",
    "dataset = dataset.batch(10, drop_remainder=True)\n",
    "iterator = dataset.create_tuple_iterator()\n",
    "for data_tuple in iterator:\n",
    "    for data in data_tuple:\n",
    "        print(data.shape)\n",
    "    print(\"=\"*20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a60d5ae1",
   "metadata": {},
   "source": [
    "### 直接遍历dataset对象\n",
    "\n",
    "> 注意这种写法在遍历完一次epoch后不会`shuffle`,在训练时这样使用可能会影响精度,训练时需要直接数据迭代时建议使用上面的两种方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cca2305f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-09-08T06:24:19.475797Z",
     "start_time": "2022-09-08T06:24:19.459963Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 32, 32, 3)\n",
      "(10,)\n",
      "====================\n",
      "(10, 32, 32, 3)\n",
      "(10,)\n",
      "====================\n"
     ]
    }
   ],
   "source": [
    "import mindspore.dataset as ds\n",
    "dataset = ds.FakeImageDataset(num_images=20, image_size=(32, 32, 3), num_classes=10, base_seed=0)\n",
    "dataset = dataset.batch(10, drop_remainder=True)\n",
    "\n",
    "for data in dataset:\n",
    "    for data in data:\n",
    "        print(data.shape)\n",
    "    print(\"=\"*20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "152092b4",
   "metadata": {},
   "source": [
    "其中后两种在数据读入顺序和网络需要的顺序一致时,可以直接使用:\n",
    "\n",
    "```python\n",
    "for data in dataset:\n",
    "    loss = net(*data)\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}