{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 采样器\n", "\n", "[![](https://gitee.com/mindspore/docs/raw/r1.2/docs/programming_guide/source_zh_cn/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.2/docs/programming_guide/source_zh_cn/sampler.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/docs/programming_guide/source_zh_cn/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.2/programming_guide/mindspore_sampler.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.2/docs/programming_guide/source_zh_cn/_static/logo_modelarts.png)](https://console.huaweicloud.com/modelarts/?region=cn-north-4#/notebook/loading?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svbW9kZWxhcnRzL3Byb2dyYW1taW5nX2d1aWRlL21pbmRzcG9yZV90b2tlbml6ZXIuaXB5bmI=&image_id=65f636a0-56cf-49df-b941-7d2a07ba8c8c)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 概述\n", "\n", "MindSpore提供了多种用途的采样器(Sampler),帮助用户对数据集进行不同形式的采样,以满足训练需求,能够解决诸如数据集过大或样本类别分布不均等问题。只需在加载数据集时传入采样器对象,即可实现数据的采样。\n", "\n", "MindSpore目前提供的部分采样器类别如下表所示。此外,用户也可以根据需要实现自定义的采样器类。更多采样器的使用方法参见[API文档](https://www.mindspore.cn/doc/api_python/zh-CN/r1.2/mindspore/mindspore.dataset.html)。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| 采样器名称 | 采样器说明 |\n", "| :---- | :---- |\n", "| RandomSampler | 随机采样器,在数据集中随机地采样指定数目的数据。 |\n", "| WeightedRandomSampler | 带权随机采样器,依照长度为N的概率列表,在前N个样本中随机采样指定数目的数据。 |\n", "| SubsetRandomSampler | 子集随机采样器,在指定的索引范围内随机采样指定数目的数据。 |\n", "| PKSampler | PK采样器,在指定的数据集类别P中,每种类别各采样K条数据。 |\n", "| DistributedSampler | 分布式采样器,在分布式训练中对数据集分片进行采样。 |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## MindSpore采样器\n", "\n", "下面以CIFAR-10数据集为例,介绍几种常用MindSpore采样器的使用方法。\n", "\n", "下载CIFAR-10数据集并解压到指定路径,执行如下命令:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "./datasets/cifar-10-batches-bin\n", "├── readme.html\n", "├── test\n", "│   └── test_batch.bin\n", "└── train\n", " ├── batches.meta.txt\n", " ├── data_batch_1.bin\n", " ├── data_batch_2.bin\n", " ├── data_batch_3.bin\n", " ├── data_batch_4.bin\n", " └── data_batch_5.bin\n", "\n", "2 directories, 8 files\n" ] } ], "source": [ "!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz\n", "!mkdir -p datasets\n", "!tar -xzf cifar-10-binary.tar.gz -C datasets\n", "!mkdir -p datasets/cifar-10-batches-bin/train datasets/cifar-10-batches-bin/test\n", "!mv -f datasets/cifar-10-batches-bin/test_batch.bin datasets/cifar-10-batches-bin/test\n", "!mv -f datasets/cifar-10-batches-bin/data_batch*.bin datasets/cifar-10-batches-bin/batches.meta.txt datasets/cifar-10-batches-bin/train\n", "!tree ./datasets/cifar-10-batches-bin" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RandomSampler\n", "\n", "从索引序列中随机采样指定数目的数据。\n", "\n", "下面的样例使用随机采样器分别从CIFAR-10数据集中有放回和无放回地随机采样5个数据,并展示已加载数据的形状和标签。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "------ Without Replacement ------\n", "Image shape: (32, 32, 3) , Label: 1\n", "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 0\n", "Image shape: (32, 32, 3) , Label: 4\n", "------ With Replacement ------\n", "Image shape: (32, 32, 3) , Label: 0\n", "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 3\n", "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 6\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "\n", "ds.config.set_seed(0)\n", "\n", "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n", "\n", "print(\"------ Without Replacement ------\")\n", "\n", "sampler = ds.RandomSampler(num_samples=5)\n", "dataset1 = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n", "\n", "for data in dataset1.create_dict_iterator():\n", " print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])\n", "\n", "print(\"------ With Replacement ------\")\n", "\n", "sampler = ds.RandomSampler(replacement=True, num_samples=5)\n", "dataset2 = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n", "\n", "for data in dataset2.create_dict_iterator():\n", " print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### WeightedRandomSampler\n", "\n", "指定长度为N的采样概率列表,按照概率在前N个样本中随机采样指定数目的数据。\n", "\n", "下面的样例使用带权随机采样器从CIFAR-10数据集的前10个样本中按概率获取6个样本,并展示已读取数据的形状和标签。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 6\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "\n", "ds.config.set_seed(1)\n", "\n", "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n", "\n", "weights = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]\n", "sampler = ds.WeightedRandomSampler(weights, num_samples=6)\n", "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n", "\n", "for data in dataset.create_dict_iterator():\n", " print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SubsetRandomSampler\n", "\n", "从指定索引子序列中随机采样指定数目的数据。\n", "\n", "下面的样例使用子序列随机采样器从CIFAR-10数据集的指定子序列中抽样3个样本,并展示已读取数据的形状和标签。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image shape: (32, 32, 3) , Label: 1\n", "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 4\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "\n", "ds.config.set_seed(2)\n", "\n", "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n", "\n", "indices = [0, 1, 2, 3, 4, 5]\n", "sampler = ds.SubsetRandomSampler(indices, num_samples=3)\n", "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n", "\n", "for data in dataset.create_dict_iterator():\n", " print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PKSampler\n", "\n", "在指定的数据集类别P中,每种类别各采样K条数据。\n", "\n", "下面的样例使用PK采样器从CIFAR-10数据集中每种类别抽样2个样本,最多20个样本,并展示已读取数据的形状和标签。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image shape: (32, 32, 3) , Label: 0\n", "Image shape: (32, 32, 3) , Label: 0\n", "Image shape: (32, 32, 3) , Label: 1\n", "Image shape: (32, 32, 3) , Label: 1\n", "Image shape: (32, 32, 3) , Label: 2\n", "Image shape: (32, 32, 3) , Label: 2\n", "Image shape: (32, 32, 3) , Label: 3\n", "Image shape: (32, 32, 3) , Label: 3\n", "Image shape: (32, 32, 3) , Label: 4\n", "Image shape: (32, 32, 3) , Label: 4\n", "Image shape: (32, 32, 3) , Label: 5\n", "Image shape: (32, 32, 3) , Label: 5\n", "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 7\n", "Image shape: (32, 32, 3) , Label: 7\n", "Image shape: (32, 32, 3) , Label: 8\n", "Image shape: (32, 32, 3) , Label: 8\n", "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 9\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "\n", "ds.config.set_seed(3)\n", "\n", "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n", "\n", "sampler = ds.PKSampler(num_val=2, class_column='label', num_samples=20)\n", "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)\n", "\n", "for data in dataset.create_dict_iterator():\n", " print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DistributedSampler\n", "\n", "在分布式训练中,对数据集分片进行采样。\n", "\n", "下面的样例使用分布式采样器将构建的数据集分为3片,在每个分片中采样3个数据样本,并展示已读取的数据。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'data': Tensor(shape=[], dtype=Int64, value= 0)}\n", "{'data': Tensor(shape=[], dtype=Int64, value= 3)}\n", "{'data': Tensor(shape=[], dtype=Int64, value= 6)}\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.dataset as ds\n", "\n", "data_source = [0, 1, 2, 3, 4, 5, 6, 7, 8]\n", "\n", "sampler = ds.DistributedSampler(num_shards=3, shard_id=0, shuffle=False, num_samples=3)\n", "dataset = ds.NumpySlicesDataset(data_source, column_names=[\"data\"], sampler=sampler)\n", "\n", "for data in dataset.create_dict_iterator():\n", " print(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 自定义采样器\n", "\n", "用户可以继承`Sampler`基类,通过实现`__iter__`方法来自定义采样器的采样方式。\n", "\n", "下面的样例定义了一个从下标0至下标9间隔为2采样的采样器,将其作用于CIFAR-10数据集,并展示已读取数据的形状和标签。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image shape: (32, 32, 3) , Label: 6\n", "Image shape: (32, 32, 3) , Label: 9\n", "Image shape: (32, 32, 3) , Label: 1\n", "Image shape: (32, 32, 3) , Label: 2\n", "Image shape: (32, 32, 3) , Label: 8\n" ] } ], "source": [ "import mindspore.dataset as ds\n", "\n", "class MySampler(ds.Sampler):\n", " def __iter__(self):\n", " for i in range(0, 10, 2):\n", " yield i\n", "\n", "DATA_DIR = \"./datasets/cifar-10-batches-bin/train/\"\n", "\n", "dataset = ds.Cifar10Dataset(DATA_DIR, sampler=MySampler())\n", "\n", "for data in dataset.create_dict_iterator():\n", " print(\"Image shape:\", data['image'].shape, \", Label:\", data['label'])" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore-1.1.1", "language": "python", "name": "mindspore-1.1.1" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }