采样器

image0image1image2

概述

MindSpore提供了多种用途的采样器(Sampler),帮助用户对数据集进行不同形式的采样,以满足训练需求,能够解决诸如数据集过大或样本类别分布不均等问题。只需在加载数据集时传入采样器对象,即可实现数据的采样。

MindSpore目前提供的部分采样器类别如下表所示。此外,用户也可以根据需要实现自定义的采样器类。更多采样器的使用方法参见API文档

采样器名称

采样器说明

RandomSampler

随机采样器,在数据集中随机地采样指定数目的数据。

WeightedRandomSampler

带权随机采样器,依照长度为N的概率列表,在前N个样本中随机采样指定数目的数据。

SubsetRandomSampler

子集随机采样器,在指定的索引范围内随机采样指定数目的数据。

PKSampler

PK采样器,在指定的数据集类别P中,每种类别各采样K条数据。

DistributedSampler

分布式采样器,在分布式训练中对数据集分片进行采样。

MindSpore采样器

下面以CIFAR-10数据集为例,介绍几种常用MindSpore采样器的使用方法。

下载CIFAR-10数据集并解压到指定路径,执行如下命令:

[1]:
!wget -N https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz
!mkdir -p datasets
!tar -xzf cifar-10-binary.tar.gz -C datasets
!mkdir -p datasets/cifar-10-batches-bin/train datasets/cifar-10-batches-bin/test
!mv -f datasets/cifar-10-batches-bin/test_batch.bin datasets/cifar-10-batches-bin/test
!mv -f datasets/cifar-10-batches-bin/data_batch*.bin datasets/cifar-10-batches-bin/batches.meta.txt datasets/cifar-10-batches-bin/train
!tree ./datasets/cifar-10-batches-bin
./datasets/cifar-10-batches-bin
├── readme.html
├── test
│   └── test_batch.bin
└── train
    ├── batches.meta.txt
    ├── data_batch_1.bin
    ├── data_batch_2.bin
    ├── data_batch_3.bin
    ├── data_batch_4.bin
    └── data_batch_5.bin

2 directories, 8 files

RandomSampler

从索引序列中随机采样指定数目的数据。

下面的样例使用随机采样器分别从CIFAR-10数据集中有放回和无放回地随机采样5个数据,并展示已加载数据的形状和标签。

[2]:
import mindspore.dataset as ds

ds.config.set_seed(0)

DATA_DIR = "./datasets/cifar-10-batches-bin/train/"

print("------ Without Replacement ------")

sampler = ds.RandomSampler(num_samples=5)
dataset1 = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)

for data in dataset1.create_dict_iterator():
    print("Image shape:", data['image'].shape, ", Label:", data['label'])

print("------ With Replacement ------")

sampler = ds.RandomSampler(replacement=True, num_samples=5)
dataset2 = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)

for data in dataset2.create_dict_iterator():
    print("Image shape:", data['image'].shape, ", Label:", data['label'])
------ Without Replacement ------
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 0
Image shape: (32, 32, 3) , Label: 4
------ With Replacement ------
Image shape: (32, 32, 3) , Label: 0
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 3
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 6

WeightedRandomSampler

指定长度为N的采样概率列表,按照概率在前N个样本中随机采样指定数目的数据。

下面的样例使用带权随机采样器从CIFAR-10数据集的前10个样本中按概率获取6个样本,并展示已读取数据的形状和标签。

[3]:
import mindspore.dataset as ds

ds.config.set_seed(1)

DATA_DIR = "./datasets/cifar-10-batches-bin/train/"

weights = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
sampler = ds.WeightedRandomSampler(weights, num_samples=6)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)

for data in dataset.create_dict_iterator():
    print("Image shape:", data['image'].shape, ", Label:", data['label'])
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 6

SubsetRandomSampler

从指定索引子序列中随机采样指定数目的数据。

下面的样例使用子序列随机采样器从CIFAR-10数据集的指定子序列中抽样3个样本,并展示已读取数据的形状和标签。

[4]:
import mindspore.dataset as ds

ds.config.set_seed(2)

DATA_DIR = "./datasets/cifar-10-batches-bin/train/"

indices = [0, 1, 2, 3, 4, 5]
sampler = ds.SubsetRandomSampler(indices, num_samples=3)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)

for data in dataset.create_dict_iterator():
    print("Image shape:", data['image'].shape, ", Label:", data['label'])
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 4

PKSampler

在指定的数据集类别P中,每种类别各采样K条数据。

下面的样例使用PK采样器从CIFAR-10数据集中每种类别抽样2个样本,最多20个样本,并展示已读取数据的形状和标签。

[5]:
import mindspore.dataset as ds

ds.config.set_seed(3)

DATA_DIR = "./datasets/cifar-10-batches-bin/train/"

sampler = ds.PKSampler(num_val=2, class_column='label', num_samples=20)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)

for data in dataset.create_dict_iterator():
    print("Image shape:", data['image'].shape, ", Label:", data['label'])
Image shape: (32, 32, 3) , Label: 0
Image shape: (32, 32, 3) , Label: 0
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 2
Image shape: (32, 32, 3) , Label: 2
Image shape: (32, 32, 3) , Label: 3
Image shape: (32, 32, 3) , Label: 3
Image shape: (32, 32, 3) , Label: 4
Image shape: (32, 32, 3) , Label: 4
Image shape: (32, 32, 3) , Label: 5
Image shape: (32, 32, 3) , Label: 5
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 7
Image shape: (32, 32, 3) , Label: 7
Image shape: (32, 32, 3) , Label: 8
Image shape: (32, 32, 3) , Label: 8
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 9

DistributedSampler

在分布式训练中,对数据集分片进行采样。

下面的样例使用分布式采样器将构建的数据集分为3片,在每个分片中采样3个数据样本,并展示已读取的数据。

[6]:
import numpy as np
import mindspore.dataset as ds

data_source = [0, 1, 2, 3, 4, 5, 6, 7, 8]

sampler = ds.DistributedSampler(num_shards=3, shard_id=0, shuffle=False, num_samples=3)
dataset = ds.NumpySlicesDataset(data_source, column_names=["data"], sampler=sampler)

for data in dataset.create_dict_iterator():
    print(data)
{'data': Tensor(shape=[], dtype=Int64, value= 0)}
{'data': Tensor(shape=[], dtype=Int64, value= 3)}
{'data': Tensor(shape=[], dtype=Int64, value= 6)}

自定义采样器

用户可以继承Sampler基类,通过实现__iter__方法来自定义采样器的采样方式。

下面的样例定义了一个从下标0至下标9间隔为2采样的采样器,将其作用于CIFAR-10数据集,并展示已读取数据的形状和标签。

[7]:
import mindspore.dataset as ds

class MySampler(ds.Sampler):
    def __iter__(self):
        for i in range(0, 10, 2):
            yield i

DATA_DIR = "./datasets/cifar-10-batches-bin/train/"

dataset = ds.Cifar10Dataset(DATA_DIR, sampler=MySampler())

for data in dataset.create_dict_iterator():
    print("Image shape:", data['image'].shape, ", Label:", data['label'])
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 1
Image shape: (32, 32, 3) , Label: 2
Image shape: (32, 32, 3) , Label: 8