数据迭代
Ascend
GPU
CPU
数据准备
概述
原始数据集通过数据集加载接口读取到内存,再通过数据增强操作进行数据变换,得到的数据集对象有两种常规的数据迭代方法:
创建迭代器进行数据迭代。
传入Model接口(如
model.train
、model.eval
等)进行迭代训练或推理。
创建迭代器进行数据迭代
数据集对象通常可以创建两种不同的迭代器来遍历数据,分别为元组迭代器和字典迭代器。
创建元组迭代器的接口为create_tuple_iterator
,创建字典迭代器的接口为create_dict_iterator
,具体使用方法如下。
首先,任意创建一个数据集对象作为演示说明。
[1]:
import mindspore.dataset as ds
np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
dataset = ds.NumpySlicesDataset(np_data, column_names=["data"], shuffle=False)
则可使用以下方法创建数据迭代器。
[2]:
# 创建元组迭代器
print("\n create tuple iterator")
for item in dataset.create_tuple_iterator():
print("item:\n", item[0])
# 创建字典迭代器
print("\n create dict iterator")
for item in dataset.create_dict_iterator():
print("item:\n", item["data"])
# 直接遍历数据集对象(等同于创建元组迭代器)
print("\n iterate dataset object directly")
for item in dataset:
print("item:\n", item[0])
# 使用enumerate方式遍历(等同于创建元组迭代器)
print("\n iterate dataset using enumerate")
for index, item in enumerate(dataset):
print("index: {}, item:\n {}".format(index, item[0]))
create tuple iterator
item:
[[1 2]
[3 4]]
item:
[[5 6]
[7 8]]
create dict iterator
item:
[[1 2]
[3 4]]
item:
[[5 6]
[7 8]]
iterate dataset object directly
item:
[[1 2]
[3 4]]
item:
[[5 6]
[7 8]]
iterate dataset using enumerate
index: 0, item:
[[1 2]
[3 4]]
index: 1, item:
[[5 6]
[7 8]]
此外,如果需要产生多个Epoch的数据,可以相应地调整入参num_epochs
的取值。相比于多次调用迭代器接口,直接设置Epoch数可以提高数据迭代的性能。
[3]:
# 创建元组迭代器产生2个Epoch的数据
epoch = 2
iterator = dataset.create_tuple_iterator(num_epochs=epoch)
for i in range(epoch):
print("epoch: ", i)
for item in iterator:
print("item:\n", item[0])
epoch: 0
item:
[[1 2]
[3 4]]
item:
[[5 6]
[7 8]]
epoch: 1
item:
[[1 2]
[3 4]]
item:
[[5 6]
[7 8]]
迭代器默认输出的数据类型为mindspore.Tensor
,如果希望得到numpy.ndarray
类型的数据,可以设置入参output_numpy=True
。
[4]:
# 默认输出类型为mindspore.Tensor
for item in dataset.create_tuple_iterator():
print("dtype: ", type(item[0]), "\nitem:", item[0])
# 设置输出类型为numpy.ndarray
for item in dataset.create_tuple_iterator(output_numpy=True):
print("dtype: ", type(item[0]), "\nitem:", item[0])
dtype: <class 'mindspore.common.tensor.Tensor'>
item: [[1 2]
[3 4]]
dtype: <class 'mindspore.common.tensor.Tensor'>
item: [[5 6]
[7 8]]
dtype: <class 'numpy.ndarray'>
item: [[1 2]
[3 4]]
dtype: <class 'numpy.ndarray'>
item: [[5 6]
[7 8]]
更详细的说明,请参考create_tuple_iterator 和create_dict_iterator的API文档。
传入Model接口进行迭代训练或推理
数据集对象创建后,可通过传入Model
接口,由接口内部进行数据迭代,并送入网络执行训练或推理。
[3]:
import numpy as np
from mindspore import ms_function
from mindspore import context, nn, Model
import mindspore.dataset as ds
import mindspore.ops as ops
def create_dataset():
np_data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
np_data = np.array(np_data, dtype=np.float16)
dataset = ds.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False)
return dataset
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.relu = ops.ReLU()
self.print = ops.Print()
@ms_function
def construct(self, x):
self.print(x)
return self.relu(x)
if __name__ == "__main__":
# it is supported to run in CPU, GPU or Ascend
context.set_context(mode=context.GRAPH_MODE)
dataset = create_dataset()
network = Net()
model = Model(network)
# do training, sink to device defaultly
model.train(epoch=1, train_dataset=dataset, dataset_sink_mode=True)
Model接口中的dataset_sink_mode
参数用于设置是否将数据下沉到Device。若设置为不下沉,则内部会创建上述迭代器,逐条遍历数据并送入网络;若设置为下沉,则内部会将数据直接发送给Device,并送入网络进行迭代训练或推理。
更加详细的使用方法,可参见Model基本使用。