文档反馈

问题文档片段

问题文档片段包含公式时,显示为空格。

提交类型
issue

有点复杂...

找人问问吧。

请选择提交类型

问题类型
规范和低错类

- 规范和低错类:

- 错别字或拼写错误,标点符号使用错误、公式错误或显示异常。

- 链接错误、空单元格、格式错误。

- 英文中包含中文字符。

- 界面和描述不一致,但不影响操作。

- 表述不通顺,但不影响理解。

- 版本号不匹配:如软件包名称、界面版本号。

易用性

- 易用性:

- 关键步骤错误或缺失,无法指导用户完成任务。

- 缺少主要功能描述、关键词解释、必要前提条件、注意事项等。

- 描述内容存在歧义指代不明、上下文矛盾。

- 逻辑不清晰,该分类、分项、分步骤的没有给出。

正确性

- 正确性:

- 技术原理、功能、支持平台、参数类型、异常报错等描述和软件实现不一致。

- 原理图、架构图等存在错误。

- 命令、命令参数等错误。

- 代码片段错误。

- 命令无法完成对应功能。

- 界面错误,无法指导操作。

- 代码样例运行报错、运行结果不符。

风险提示

- 风险提示:

- 对重要数据或系统存在风险的操作,缺少安全提示。

内容合规

- 内容合规:

- 违反法律法规,涉及政治、领土主权等敏感词。

- 内容侵权。

请选择问题类型

问题描述

点击输入详细问题描述,以帮助我们快速定位问题。

数据加载及处理

image0image1image2

MindSpore提供了部分常用数据集和标准格式数据集的加载接口,用户可以直接使用mindspore.dataset中对应的数据集加载类进行数据加载。数据集类为用户提供了常用的数据处理接口,使得用户能够快速进行数据处理操作。

加载数据集

下面的样例通过Cifar10Dataset接口加载CIFAR-10数据集,使用顺序采样器获取前5个样本。

[1]:
import mindspore.dataset as ds

DATA_DIR = "./datasets/cifar-10-batches-bin/train"
sampler = ds.SequentialSampler(num_samples=5)
dataset = ds.Cifar10Dataset(DATA_DIR, sampler=sampler)

迭代数据集

用户可以用create_dict_iterator创建数据迭代器,迭代访问数据,下面展示了对应图片的形状和标签。

[2]:
for data in dataset.create_dict_iterator():
    print("Image shape: {}".format(data['image'].shape), ", Label: {}".format(data['label']))
Image shape: (32, 32, 3) , Label: 6
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 9
Image shape: (32, 32, 3) , Label: 4
Image shape: (32, 32, 3) , Label: 1

自定义数据集

对于目前MindSpore不支持直接加载的数据集,可以构造自定义数据集类,然后通过GeneratorDataset接口实现自定义方式的数据加载。

[3]:
import numpy as np

np.random.seed(58)

class DatasetGenerator:
    def __init__(self):
        self.data = np.random.sample((5, 2))
        self.label = np.random.sample((5, 1))

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

其中用户需要自定义的类函数如下:

  • __init__

    实例化数据集对象时,__init__函数被调用,用户可以在此进行数据初始化等操作。

    def __init__(self):
        self.data = np.random.sample((5, 2))
        self.label = np.random.sample((5, 1))
    
  • __getitem__

    定义数据集类的__getitem__函数,使其支持随机访问,能够根据给定的索引值index,获取数据集中的数据并返回。

    def __getitem__(self, index):
        return self.data[index], self.label[index]
    
  • __len__

    定义数据集类的__len__函数,返回数据集的样本数量。

    def __len__(self):
        return len(self.data)
    

定义数据集类之后,就可以通过GeneratorDataset接口按照用户定义的方式加载并访问数据集样本。

[4]:
dataset_generator = DatasetGenerator()
dataset = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=False)

for data in dataset.create_dict_iterator():
    print('{}'.format(data["data"]), '{}'.format(data["label"]))
[0.36510558 0.45120592] [0.78888122]
[0.49606035 0.07562207] [0.38068183]
[0.57176158 0.28963401] [0.16271622]
[0.30880446 0.37487617] [0.54738768]
[0.81585667 0.96883469] [0.77994068]

数据处理及增强

数据处理

MindSpore提供的数据集接口具备常用的数据处理方法,用户只需调用相应的函数接口即可快速进行数据处理。

下面的样例先将数据集随机打乱顺序,然后将样本两两组成一个批次。

[5]:
ds.config.set_seed(58)

# 随机打乱数据顺序
dataset = dataset.shuffle(buffer_size=10)
# 对数据集进行分批
dataset = dataset.batch(batch_size=2)

for data in dataset.create_dict_iterator():
    print("data: {}".format(data["data"]))
    print("label: {}".format(data["label"]))
data: [[0.36510558 0.45120592]
 [0.57176158 0.28963401]]
label: [[0.78888122]
 [0.16271622]]
data: [[0.30880446 0.37487617]
 [0.49606035 0.07562207]]
label: [[0.54738768]
 [0.38068183]]
data: [[0.81585667 0.96883469]]
label: [[0.77994068]]

其中,

buffer_size:数据集中进行shuffle操作的缓存区的大小。

batch_size:每组包含的数据个数,现设置每组包含2个数据。

数据增强

数据量过小或是样本场景单一等问题会影响模型的训练效果,用户可以通过数据增强操作扩充样本多样性,从而提升模型的泛化能力。

下面的样例使用mindspore.dataset.vision.c_transforms模块中的算子对MNIST数据集进行数据增强。

导入c_transforms模块,加载MNIST数据集。

[6]:
import matplotlib.pyplot as plt

from mindspore.dataset.vision import Inter
import mindspore.dataset.vision.c_transforms as c_vision

DATA_DIR = './datasets/MNIST_Data/train'

mnist_dataset = ds.MnistDataset(DATA_DIR, num_samples=6, shuffle=False)

# 查看数据原图
mnist_it = mnist_dataset.create_dict_iterator()
data = next(mnist_it)
plt.imshow(data['image'].asnumpy().squeeze(), cmap=plt.cm.gray)
plt.title(data['label'].asnumpy(), fontsize=20)
plt.show()
_images/dataset_13_0.png

定义数据增强算子,对数据集进行ResizeRandomCrop操作,然后通过map映射将其插入数据处理管道。

[7]:
resize_op = c_vision.Resize(size=(200,200), interpolation=Inter.LINEAR)
crop_op = c_vision.RandomCrop(150)
transforms_list = [resize_op, crop_op]
mnist_dataset = mnist_dataset.map(operations=transforms_list, input_columns=["image"])

查看数据增强效果。

[8]:
mnist_dataset = mnist_dataset.create_dict_iterator()
data = next(mnist_dataset)
plt.imshow(data['image'].asnumpy().squeeze(), cmap=plt.cm.gray)
plt.title(data['label'].asnumpy(), fontsize=20)
plt.show()
_images/dataset_17_0.png

想要了解更多可以参考编程指南中数据增强章节。