数据迭代

Ascend GPU CPU 数据准备

下载样例代码下载Notebook查看源文件

概述

原始数据集通过数据集加载接口读取到内存,再通过数据增强操作进行数据变换,得到的数据集对象有两种常规的数据迭代方法:

  • 创建迭代器进行数据迭代。

  • 传入Model接口(如model.trainmodel.eval等)进行迭代训练或推理。

创建迭代器进行数据迭代

数据集对象通常可以创建两种不同的迭代器来遍历数据,分别为元组迭代器和字典迭代器。

创建元组迭代器的接口为create_tuple_iterator,创建字典迭代器的接口为create_dict_iterator,具体使用方法如下。

首先,任意创建一个数据集对象作为演示说明。

[1]:
import mindspore.dataset as ds

np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], shuffle=False)

则可使用以下方法创建数据迭代器。

[2]:
# 创建元组迭代器
print("\n create tuple iterator")
for item in dataset.create_tuple_iterator():
    print("item:\n", item[0])

# 创建字典迭代器
print("\n create dict iterator")
for item in dataset.create_dict_iterator():
    print("item:\n", item["data"])

# 直接遍历数据集对象(等同于创建元组迭代器)
print("\n iterate dataset object directly")
for item in dataset:
    print("item:\n", item[0])

# 使用enumerate方式遍历(等同于创建元组迭代器)
print("\n iterate dataset using enumerate")
for index, item in enumerate(dataset):
    print("index: {}, item:\n {}".format(index, item[0]))

 create tuple iterator
item:
 [[1 2]
 [3 4]]
item:
 [[5 6]
 [7 8]]

 create dict iterator
item:
 [[1 2]
 [3 4]]
item:
 [[5 6]
 [7 8]]

 iterate dataset object directly
item:
 [[1 2]
 [3 4]]
item:
 [[5 6]
 [7 8]]

 iterate dataset using enumerate
index: 0, item:
 [[1 2]
 [3 4]]
index: 1, item:
 [[5 6]
 [7 8]]

此外,如果需要产生多个Epoch的数据,可以相应地调整入参num_epochs的取值。相比于多次调用迭代器接口,直接设置Epoch数可以提高数据迭代的性能。

[3]:
# 创建元组迭代器产生2个Epoch的数据
epoch = 2
iterator = dataset.create_tuple_iterator(num_epochs=epoch)
for i in range(epoch):
    print("epoch: ", i)
    for item in iterator:
        print("item:\n", item[0])
epoch:  0
item:
 [[1 2]
 [3 4]]
item:
 [[5 6]
 [7 8]]
epoch:  1
item:
 [[1 2]
 [3 4]]
item:
 [[5 6]
 [7 8]]

迭代器默认输出的数据类型为mindspore.Tensor,如果希望得到numpy.ndarray类型的数据,可以设置入参output_numpy=True

[4]:
# 默认输出类型为mindspore.Tensor
for item in dataset.create_tuple_iterator():
    print("dtype: ", type(item[0]), "\nitem:", item[0])

# 设置输出类型为numpy.ndarray
for item in dataset.create_tuple_iterator(output_numpy=True):
    print("dtype: ", type(item[0]), "\nitem:", item[0])
dtype:  <class 'mindspore.common.tensor.Tensor'>
item: [[1 2]
 [3 4]]
dtype:  <class 'mindspore.common.tensor.Tensor'>
item: [[5 6]
 [7 8]]
dtype:  <class 'numpy.ndarray'>
item: [[1 2]
 [3 4]]
dtype:  <class 'numpy.ndarray'>
item: [[5 6]
 [7 8]]

更详细的说明,请参考create_tuple_iteratorcreate_dict_iterator的API文档。

传入Model接口进行迭代训练或推理

数据集对象创建后,可通过传入Model接口,由接口内部进行数据迭代,并送入网络执行训练或推理。

[3]:
import numpy as np
from mindspore import ms_function
from mindspore import context, nn, Model
import mindspore.dataset as ds
import mindspore.ops as ops


def create_dataset():
    np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    np_data = np.array(np_data, dtype=np.float16)
    dataset = ds.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
    return dataset


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu = ops.ReLU()
        self.print = ops.Print()

    @ms_function
    def construct(self, x):
        self.print(x)
        return self.relu(x)


if __name__ == "__main__":
    # it is supported to run in CPU, GPU or Ascend
    context.set_context(mode=context.GRAPH_MODE)
    dataset = create_dataset()
    network = Net()
    model = Model(network)

    # do training, sink to device defaultly
    model.train(epoch=1, train_dataset=dataset, dataset_sink_mode=True)

Model接口中的dataset_sink_mode参数用于设置是否将数据下沉到Device。若设置为不下沉,则内部会创建上述迭代器,逐条遍历数据并送入网络;若设置为下沉,则内部会将数据直接发送给Device,并送入网络进行迭代训练或推理。

更加详细的使用方法,可参见Model基本使用