{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 数据采样\n",
    "\n",
    "`Ascend` `GPU` `CPU` `数据准备`\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svbW9kZWxhcnRzL3Byb2dyYW1taW5nX2d1aWRlL21pbmRzcG9yZV9zYW1wbGVyLmlweW5i&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.5/programming_guide/zh_cn/mindspore_sampler.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.5/programming_guide/zh_cn/mindspore_sampler.py) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.5/docs/mindspore/programming_guide/source_zh_cn/sampler.ipynb)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 概述\n",
    "\n",
    "MindSpore提供了多种用途的采样器(Sampler),帮助用户对数据集进行不同形式的采样,以满足训练需求,能够解决诸如数据集过大或样本类别分布不均等问题。只需在加载数据集时传入采样器对象,即可实现数据的采样。\n",
    "\n",
    "MindSpore目前提供的部分采样器类别如下表所示。此外,用户也可以根据需要实现自定义的采样器类。更多采样器的使用方法参见[API文档](https://www.mindspore.cn/docs/api/zh-CN/r1.5/api_python/mindspore.dataset.html)。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "| 采样器名称  | 采样器说明 |\n",
    "| :----  | :----           |\n",
    "| RandomSampler | 随机采样器,在数据集中随机地采样指定数目的数据。 |\n",
    "| WeightedRandomSampler | 带权随机采样器,依照长度为N的概率列表,在前N个样本中随机采样指定数目的数据。 |\n",
    "| SubsetRandomSampler | 子集随机采样器,在指定的索引范围内随机采样指定数目的数据。 |\n",
    "| PKSampler | PK采样器,在指定的数据集类别P中,每种类别各采样K条数据。 |\n",
    "| DistributedSampler | 分布式采样器,在分布式训练中对数据集分片进行采样。 |"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## MindSpore采样器\n",
    "\n",
    "下面以CIFAR-10数据集为例,介绍几种常用MindSpore采样器的使用方法。\n",
    "\n",
    "下载CIFAR-10数据集并解压到指定路径,在Jupyter Notebook中执行如下命令:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz --no-check-certificate\n",
    "!mkdir -p datasets\n",
    "!tar -xzf cifar-10-binary.tar.gz -C datasets\n",
    "!mkdir -p datasets/cifar-10-batches-bin/train datasets/cifar-10-batches-bin/test\n",
    "!mv -f datasets/cifar-10-batches-bin/test_batch.bin datasets/cifar-10-batches-bin/test\n",
    "!mv -f datasets/cifar-10-batches-bin/data_batch*.bin datasets/cifar-10-batches-bin/batches.meta.txt datasets/cifar-10-batches-bin/train"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "解压后数据集文件的目录结构如下:\n",
    "\n",
    "```text\n",
    "./datasets/cifar-10-batches-bin\n",
    "├── readme.html\n",
    "├── test\n",
    "│   └── test_batch.bin\n",
    "└── train\n",
    "    ├── batches.meta.txt\n",
    "    ├── data_batch_1.bin\n",
    "    ├── data_batch_2.bin\n",
    "    ├── data_batch_3.bin\n",
    "    ├── data_batch_4.bin\n",
    "    └── data_batch_5.bin\n",
    "```"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### RandomSampler\n",
    "\n",
    "从索引序列中随机采样指定数目的数据。\n",
    "\n",
    "下面的样例使用随机采样器分别从CIFAR-10数据集中有放回和无放回地随机采样5个数据,并展示已加载数据的形状和标签。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "ds.config.set_seed(0)\n",
    "\n",
    "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n",
    "\n",
    "print(\"------ Without Replacement ------\")\n",
    "\n",
    "sampler = ds.RandomSampler(num_samples=5)\n",
    "dataset1 = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n",
    "\n",
    "for data in dataset1.create_dict_iterator():\n",
    "    print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])\n",
    "\n",
    "print(\"------ With Replacement ------\")\n",
    "\n",
    "sampler = ds.RandomSampler(replacement=True, num_samples=5)\n",
    "dataset2 = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n",
    "\n",
    "for data in dataset2.create_dict_iterator():\n",
    "    print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "------ Without Replacement ------\n",
      "Image shape: (32, 32, 3) , Label: 1\n",
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 0\n",
      "Image shape: (32, 32, 3) , Label: 4\n",
      "------ With Replacement ------\n",
      "Image shape: (32, 32, 3) , Label: 0\n",
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 3\n",
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 6\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### WeightedRandomSampler\n",
    "\n",
    "指定长度为N的采样概率列表,按照概率在前N个样本中随机采样指定数目的数据。\n",
    "\n",
    "下面的样例使用带权随机采样器从CIFAR-10数据集的前10个样本中按概率获取6个样本,并展示已读取数据的形状和标签。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "ds.config.set_seed(1)\n",
    "\n",
    "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n",
    "\n",
    "weights = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n",
    "sampler = ds.WeightedRandomSampler(weights, num_samples=6)\n",
    "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 6\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### SubsetRandomSampler\n",
    "\n",
    "从指定索引子序列中随机采样指定数目的数据。\n",
    "\n",
    "下面的样例使用子序列随机采样器从CIFAR-10数据集的指定子序列中抽样3个样本,并展示已读取数据的形状和标签。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "ds.config.set_seed(2)\n",
    "\n",
    "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n",
    "\n",
    "indices = [0, 1, 2, 3, 4, 5]\n",
    "sampler = ds.SubsetRandomSampler(indices, num_samples=3)\n",
    "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Image shape: (32, 32, 3) , Label: 1\n",
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 4\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### PKSampler\n",
    "\n",
    "在指定的数据集类别P中,每种类别各采样K条数据。\n",
    "\n",
    "下面的样例使用PK采样器从CIFAR-10数据集中每种类别抽样2个样本,最多20个样本,并展示已读取数据的形状和标签。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "ds.config.set_seed(3)\n",
    "\n",
    "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n",
    "\n",
    "sampler = ds.PKSampler(num_val=2, class_column='label', num_samples=20)\n",
    "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Image shape: (32, 32, 3) , Label: 0\n",
      "Image shape: (32, 32, 3) , Label: 0\n",
      "Image shape: (32, 32, 3) , Label: 1\n",
      "Image shape: (32, 32, 3) , Label: 1\n",
      "Image shape: (32, 32, 3) , Label: 2\n",
      "Image shape: (32, 32, 3) , Label: 2\n",
      "Image shape: (32, 32, 3) , Label: 3\n",
      "Image shape: (32, 32, 3) , Label: 3\n",
      "Image shape: (32, 32, 3) , Label: 4\n",
      "Image shape: (32, 32, 3) , Label: 4\n",
      "Image shape: (32, 32, 3) , Label: 5\n",
      "Image shape: (32, 32, 3) , Label: 5\n",
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 7\n",
      "Image shape: (32, 32, 3) , Label: 7\n",
      "Image shape: (32, 32, 3) , Label: 8\n",
      "Image shape: (32, 32, 3) , Label: 8\n",
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 9\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### DistributedSampler\n",
    "\n",
    "在分布式训练中,对数据集分片进行采样。\n",
    "\n",
    "下面的样例使用分布式采样器将构建的数据集分为3片,在每个分片中采样3个数据样本,并展示已读取的数据。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "import numpy as np\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "data_source = [0, 1, 2, 3, 4, 5, 6, 7, 8]\n",
    "\n",
    "sampler = ds.DistributedSampler(num_shards=3, shard_id=0, shuffle=False, num_samples=3)\n",
    "dataset = ds.NumpySlicesDataset(data_source, column_names=[\"data\"], sampler=sampler)\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(data)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{'data': Tensor(shape=[], dtype=Int64, value= 0)}\n",
      "{'data': Tensor(shape=[], dtype=Int64, value= 3)}\n",
      "{'data': Tensor(shape=[], dtype=Int64, value= 6)}\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 自定义采样器\n",
    "\n",
    "用户可以继承`Sampler`基类,通过实现`__iter__`方法来自定义采样器的采样方式。\n",
    "\n",
    "下面的样例定义了一个从下标0至下标9间隔为2采样的采样器,将其作用于CIFAR-10数据集,并展示已读取数据的形状和标签。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "import mindspore.dataset as ds\n",
    "\n",
    "class MySampler(ds.Sampler):\n",
    "    def __iter__(self):\n",
    "        for i in range(0, 10, 2):\n",
    "            yield i\n",
    "\n",
    "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n",
    "\n",
    "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=MySampler())\n",
    "\n",
    "for data in dataset.create_dict_iterator():\n",
    "    print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Image shape: (32, 32, 3) , Label: 6\n",
      "Image shape: (32, 32, 3) , Label: 9\n",
      "Image shape: (32, 32, 3) , Label: 1\n",
      "Image shape: (32, 32, 3) , Label: 2\n",
      "Image shape: (32, 32, 3) , Label: 8\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}