mindspore.dataset

该模块提供了加载和处理各种通用数据集的API,如MNIST、CIFAR-10、CIFAR-100、VOC、COCO、ImageNet、CelebA、CLUE等, 也支持加载业界标准格式的数据集,包括MindRecord、TFRecord、Manifest等。此外,用户还可以使用此模块定义和加载自己的数据集。

该模块还提供了在加载时进行数据采样的API,如SequentialSample、RandomSampler、DistributedSampler等。

大多数数据集可以通过指定参数 cache 启用缓存服务,以提升整体数据处理效率。 请注意Windows平台上还不支持缓存服务,因此在Windows上加载和处理数据时,请勿使用。更多介绍和限制, 请参考 Single-Node Tensor Cache

在API示例中,常用的模块导入方法如下:

import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import mindspore.dataset.vision as vision

常用数据集术语说明如下:

  • Dataset,所有数据集的基类,提供了数据处理方法来帮助预处理数据。

  • SourceDataset,一个抽象类,表示数据集管道的来源,从文件和数据库等数据源生成数据。

  • MappableDataset,一个抽象类,表示支持随机访问的源数据集。

  • Iterator,用于枚举元素的数据集迭代器的基类。

数据处理Pipeline介绍

../_images/dataset_pipeline.png

如上图所示,MindSpore Dataset模块使得用户很简便地定义数据预处理Pipeline,并以最高效(多进程/多线程)的方式处理 数据集中样本,具体的步骤参考如下:

  • 加载数据集(Dataset):用户可以方便地使用 *Dataset 类来加载已支持的数据集,或者通过 UDF Loader + GeneratorDataset 实现Python层自定义数据集的加载,同时加载类方法可以使用多种Sampler、数据分片、数据shuffle等功能;

  • 数据集操作(filter/ skip):用户通过数据集对象方法 .shuffle / .filter / .skip / .split / .take / … 来实现数据集的进一步混洗、过滤、跳过、最多获取条数等操作;

  • 数据集样本增强操作(map):用户可以将数据增强操作 (vision类nlp类audio类 ) 添加到map操作中执行,数据预处理过程中可以定义多个map操作,用于执行不同增强操作,数据增强操作也可以是 用户自定义增强的 PyFunc

  • 批(batch):用户在样本完成增强后,使用 .batch 操作将多个样本组织成batch,也可以通过batch的参数 per_batch_map 来自定义batch逻辑;

  • 迭代器(create_dict_iterator):最后用户通过数据集对象方法 create_dict_iterator 来创建迭代器, 可以将预处理完成的数据循环输出。

数据处理Pipeline示例如下,完整示例请参考 datasets_example.py

import numpy as np
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms

# 构造图像和标签
data1 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
data2 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
data3 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)
data4 = np.array(np.random.sample(size=(300, 300, 3)) * 255, dtype=np.uint8)

label = [1, 2, 3, 4]

# 加载数据集
dataset = ds.NumpySlicesDataset(([data1, data2, data3, data4], label), ["data", "label"])

# 对data数据增强
dataset = dataset.map(operations=vision.RandomCrop(size=(250, 250)), input_columns="data")
dataset = dataset.map(operations=vision.Resize(size=(224, 224)), input_columns="data")
dataset = dataset.map(operations=vision.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                                  std=[0.229 * 255, 0.224 * 255, 0.225 * 255]),
                      input_columns="data")
dataset = dataset.map(operations=vision.HWC2CHW(), input_columns="data")

# 对label变换类型
dataset = dataset.map(operations=transforms.TypeCast(ms.int32), input_columns="label")

# batch操作
dataset = dataset.batch(batch_size=2)

# 创建迭代器
epochs = 2
ds_iter = dataset.create_dict_iterator(output_numpy=True, num_epochs=epochs)
for _ in range(epochs):
    for item in ds_iter:
        print("item: {}".format(item), flush=True)

视觉

mindspore.dataset.Caltech101Dataset

读取和解析Caltech101数据集的源文件构建数据集。

mindspore.dataset.Caltech256Dataset

读取和解析Caltech256数据集的源文件构建数据集。

mindspore.dataset.CelebADataset

读取和解析CelebA数据集的源文件构建数据集。

mindspore.dataset.Cifar10Dataset

读取和解析CIFAR-10数据集的源文件构建数据集。

mindspore.dataset.Cifar100Dataset

读取和解析CIFAR-100数据集的源文件构建数据集。

mindspore.dataset.CityscapesDataset

读取和解析Cityscapes数据集的源文件构建数据集。

mindspore.dataset.CocoDataset

读取和解析COCO数据集的源文件构建数据集。

mindspore.dataset.DIV2KDataset

读取和解析DIV2K数据集的源文件构建数据集。

mindspore.dataset.EMnistDataset

读取和解析EMNIST数据集的源文件构建数据集。

mindspore.dataset.FakeImageDataset

生成虚假图像构建数据集。

mindspore.dataset.FashionMnistDataset

读取和解析Fashion-MNIST数据集的源文件构建数据集。

mindspore.dataset.FlickrDataset

读取和解析Flickr8k和Flickr30k数据集的源文件构建数据集。

mindspore.dataset.Flowers102Dataset

读取和解析Flowers102数据集的源文件构建数据集。

mindspore.dataset.ImageFolderDataset

从树状结构的文件目录中读取图片构建源数据集。

mindspore.dataset.KMnistDataset

读取和解析KMNIST数据集的源文件构建数据集。

mindspore.dataset.ManifestDataset

读取和解析Manifest数据文件构建数据集。

mindspore.dataset.MnistDataset

读取和解析MNIST数据集的源文件构建数据集。

mindspore.dataset.PhotoTourDataset

读取和解析PhotoTour数据集的源数据集。

mindspore.dataset.Places365Dataset

读取和解析Places365数据集的源数据集。

mindspore.dataset.QMnistDataset

读取和解析QMNIST数据集的源文件构建数据集。

mindspore.dataset.SBDataset

读取和解析Semantic Boundaries数据集的源文件构建数据集。

mindspore.dataset.SBUDataset

读取和解析SBU数据集的源文件构建数据集。

mindspore.dataset.SemeionDataset

读取和解析Semeion数据集的源文件构建数据集。

mindspore.dataset.STL10Dataset

读取和解析STL10数据集的源文件构建数据集。

mindspore.dataset.SVHNDataset

读取和解析SVHN数据集的源文件构建数据集。

mindspore.dataset.USPSDataset

读取和解析USPS数据集的源数据集。

mindspore.dataset.VOCDataset

读取和解析VOC数据集的源文件构建数据集。

mindspore.dataset.WIDERFaceDataset

读取和解析WIDERFace数据集的源数据集。

文本

mindspore.dataset.AGNewsDataset

读取和解析AG News数据集的源文件构建数据集。

mindspore.dataset.AmazonReviewDataset

读取和解析Amazon Review Full和Amazon Review Polarity数据集的源数据集。

mindspore.dataset.CLUEDataset

读取和解析CLUE数据集的源文件构建数据集。

mindspore.dataset.CoNLL2000Dataset

读取和解析CoNLL2000分块数据集的源数据集。

mindspore.dataset.DBpediaDataset

读取和解析DBpedia数据集的源数据集。

mindspore.dataset.EnWik9Dataset

读取和解析EnWik9 Full和EnWik9 Polarity数据集。

mindspore.dataset.IMDBDataset

读取和解析互联网电影数据库(IMDb)的源数据集。

mindspore.dataset.IWSLT2016Dataset

读取和解析IWSLT2016数据集的源数据集。

mindspore.dataset.IWSLT2017Dataset

读取和解析IWSLT2017数据集的源数据集。

mindspore.dataset.PennTreebankDataset

读取和解析PennTreebank数据集的源数据集。

mindspore.dataset.SogouNewsDataset

读取和解析SogouNew数据集的源数据集。

mindspore.dataset.TextFileDataset

读取和解析文本文件构建数据集。

mindspore.dataset.UDPOSDataset

读取和解析UDPOS数据集的源数据集。

mindspore.dataset.WikiTextDataset

读取和解析WikiText2和WikiText103数据集。

mindspore.dataset.YahooAnswersDataset

读取和解析YahooAnswers数据集的源数据集。

mindspore.dataset.YelpReviewDataset

读取和解析Yelp Review Full和Yelp Review Polarity数据集的源数据集。

音频

mindspore.dataset.LJSpeechDataset

读取和解析LJSpeech数据集的源文件构建数据集。

mindspore.dataset.SpeechCommandsDataset

读取和解析SpeechCommands数据集的源数据集。

mindspore.dataset.TedliumDataset

读取和解析Tedlium数据集的源数据集。

mindspore.dataset.YesNoDataset

读取和解析YesNo数据集的源数据集。

标准格式

mindspore.dataset.CSVDataset

读取和解析CSV数据文件构建数据集。

mindspore.dataset.MindDataset

读取和解析MindRecord数据文件构建数据集。

mindspore.dataset.OBSMindDataset

读取和解析存放在华为云OBS、Minio以及AWS S3等云存储上的MindRecord格式数据集。

mindspore.dataset.TFRecordDataset

读取和解析TFData格式的数据文件构建数据集。

用户自定义

mindspore.dataset.GeneratorDataset

自定义Python数据源,通过迭代该数据源构造数据集。

mindspore.dataset.NumpySlicesDataset

由Python数据构建数据集。

mindspore.dataset.PaddedDataset

由用户提供的填充数据构建数据集。

mindspore.dataset.RandomDataset

生成随机数据的源数据集。

mindspore.dataset.ArgoverseDataset

加载argoverse数据集并进行图(Graph)初始化。

mindspore.dataset.Graph

主要用于存储图的结构信息和图特征属性,并提供图采样等能力。

mindspore.dataset.GraphData

从共享文件或数据库中读取用于GNN训练的图数据集。

mindspore.dataset.InMemoryGraphDataset

用于将图数据加载到内存中的Dataset基类。

采样器

mindspore.dataset.DistributedSampler

分布式采样器,将数据集进行分片用于分布式训练。

mindspore.dataset.PKSampler

为数据集中每P个类别各采样K个样本。

mindspore.dataset.RandomSampler

随机采样器。

mindspore.dataset.SequentialSampler

按数据集的读取顺序采样数据集样本,相当于不使用采样器。

mindspore.dataset.SubsetRandomSampler

给定样本的索引序列,从序列中随机获取索引对数据集进行采样。

mindspore.dataset.SubsetSampler

给定样本的索引序列,对数据集采样指定索引的样本。

mindspore.dataset.WeightedRandomSampler

给定样本的权重列表,根据权重决定样本的采样概率,随机采样[0,len(weights) - 1]中的样本。

配置

config模块能够设置或获取数据处理的全局配置参数。

mindspore.dataset.config.set_sending_batches

在昇腾设备中使用sink_mode=True进行训练时,设置默认的发送批次。

mindspore.dataset.config.load

根据文件内容加载项目配置文件。

mindspore.dataset.config.set_seed

设置随机种子,产生固定的随机数来达到确定的结果。

mindspore.dataset.config.get_seed

获取随机数的种子。

mindspore.dataset.config.set_prefetch_size

设置管道中线程的队列容量。

mindspore.dataset.config.get_prefetch_size

获取数据处理管道的输出缓存队列长度。

mindspore.dataset.config.set_num_parallel_workers

为并行工作线程数量设置新的全局配置默认值。

mindspore.dataset.config.get_num_parallel_workers

获取并行工作线程数量的全局配置。

mindspore.dataset.config.set_numa_enable

设置NUMA的默认状态为启动状态。

mindspore.dataset.config.get_numa_enable

获取NUMA的启动/禁用状态。

mindspore.dataset.config.set_monitor_sampling_interval

设置监测采样的默认间隔时间(毫秒)。

mindspore.dataset.config.get_monitor_sampling_interval

获取性能监控采样时间间隔的全局配置。

mindspore.dataset.config.set_callback_timeout

mindspore.dataset.WaitedDSCallback 设置的默认超时时间(秒)。

mindspore.dataset.config.get_callback_timeout

获取 mindspore.dataset.WaitedDSCallback 的默认超时时间。

mindspore.dataset.config.set_auto_num_workers

自动为每个数据集操作设置并行线程数量(默认情况下,此功能关闭)。

mindspore.dataset.config.get_auto_num_workers

获取当前是否开启自动线程调整。

mindspore.dataset.config.set_enable_shared_mem

设置共享内存标志的是否启用。

mindspore.dataset.config.get_enable_shared_mem

获取当前是否开启共享内存。

mindspore.dataset.config.set_enable_autotune

设置是否开启自动数据加速。

mindspore.dataset.config.get_enable_autotune

获取当前是否开启自动数据加速。

mindspore.dataset.config.set_autotune_interval

设置自动数据加速的配置调整step间隔。

mindspore.dataset.config.get_autotune_interval

获取当前自动数据加速的配置调整step间隔。

mindspore.dataset.config.set_auto_offload

设置是否开启数据异构加速。

mindspore.dataset.config.get_auto_offload

获取当前是否开启数据异构加速。

mindspore.dataset.config.set_enable_watchdog

设置watchdog Python线程是否启用。

mindspore.dataset.config.get_enable_watchdog

获取当前是否开启watchdog Python线程。

mindspore.dataset.config.set_fast_recovery

在数据集管道故障恢复时,是否开启快速恢复模式(快速恢复模式下,无法保证随机性的数据增强操作得到与故障之前相同的结果)。

mindspore.dataset.config.get_fast_recovery

获取当前数据管道是否开启快速恢复模式。

mindspore.dataset.config.set_multiprocessing_timeout_interval

设置在多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的默认时间间隔(秒)。

mindspore.dataset.config.get_multiprocessing_timeout_interval

获取在多进程/多线程下,主进程/主线程获取数据超时时,告警日志打印的时间间隔的全局配置。

其他

mindspore.dataset.BatchInfo

此类提供了两种方法获取数据集的批处理数量(batch size)和迭代数(epoch)属性,。

mindspore.dataset.DatasetCache

创建数据缓存客户端实例。

mindspore.dataset.DSCallback

数据处理回调类的抽象基类,用户可以基于此类实现自己的回调操作。

mindspore.dataset.SamplingStrategy

指定图数据采样策略的枚举类。

mindspore.dataset.Schema

用于解析和存储数据列属性的类。

mindspore.dataset.Shuffle

指定混洗模式的枚举类。

mindspore.dataset.WaitedDSCallback

阻塞式数据处理回调类的抽象基类,用于与训练回调类 mindspore.train.Callback 的同步。

mindspore.dataset.OutputFormat

通过API get_all_neighbors 获取所有相邻节点时,指定节点的存储格式。

mindspore.dataset.compare

比较两个数据处理管道是否相同。

mindspore.dataset.deserialize

数据处理管道反序列化,支持输入Python字典或使用 mindspore.dataset.serialize() 接口生成的JSON文件。

mindspore.dataset.serialize

将数据处理管道序列化成JSON文件。

mindspore.dataset.show

将数据处理管道图写入logger.info文件。

mindspore.dataset.sync_wait_for_dataset

等待所有的卡需要的数据集文件下载完成。

mindspore.dataset.utils.imshow_det_bbox

使用给定的边界框和类别置信度绘制图像。