mindspore.DatasetHelper

class mindspore.DatasetHelper(dataset, dataset_sink_mode=True, sink_size=- 1, epoch_num=1)[源代码]

DatasetHelper是一个处理MindData数据集的类,提供数据集信息。

根据不同的上下文,改变数据集的迭代,在不同的上下文中使用相同的迭代。

说明

DatasetHelper的迭代将提供一个epoch的数据。

参数:
  • dataset (Dataset) - 训练数据集迭代器。数据集可以由数据集生成器API在 mindspore.dataset 中生成,例如 mindspore.dataset.ImageFolderDataset

  • dataset_sink_mode (bool) - 如果值为True,使用 mindspore.ops.GetNext 在设备(Device)上通过数据通道中获取数据,否则在主机(Host)直接遍历数据集获取数据。默认值:True。

  • sink_size (int) - 控制每个下沉中的数据量。如果 sink_size 为-1,则下沉每个epoch的完整数据集。如果 sink_size 大于0,则下沉每个epoch的 sink_size 数据。默认值:-1。

  • epoch_num (int) - 控制待发送的epoch数据量。默认值:1。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore import dataset as ds
>>>
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> set_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=False)
>>>
>>> net = nn.Dense(10, 5)
>>> # Object of DatasetHelper is iterable
>>> for next_element in set_helper:
...     # `next_element` includes data and label, using data to run the net
...     data = next_element[0]
...     result = net(data)
continue_send()[源代码]

在epoch开始时继续向设备发送数据。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore import dataset as ds
>>>
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
>>> dataset_helper.continue_send()
release()[源代码]

释放数据下沉资源。

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore import dataset as ds
>>>
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
>>> dataset_helper.release()
sink_size()[源代码]

获取每次迭代的 sink_size

样例:

>>> import mindspore as ms
>>> import numpy as np
>>>
>>> # Define a dataset pipeline
>>> def generator():
...    for i in range(5):
...        yield (np.ones((32, 10)),)
>>>
>>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
>>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1)
>>>
>>> # if sink_size==-1, then will return the full size of source dataset.
>>> sink_size = dataset_helper.sink_size()
stop_send()[源代码]

停止发送数据下沉数据。

样例:

>>> import mindspore as ms
>>> import numpy as np
>>> # Define a dataset pipeline
>>> def generator():
...    for i in range(5):
...        yield (np.ones((32, 10)),)
>>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
>>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1)
>>> dataset_helper.stop_send()
types_shapes()[源代码]

从当前配置中的数据集获取类型和形状(shape)。

样例:

>>> import mindspore as ms
>>> import numpy as np
>>>
>>> # Define a dataset pipeline
>>> def generator():
...    for i in range(5):
...        yield (np.ones((32, 10)),)
>>>
>>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"])
>>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
>>>
>>> types, shapes = dataset_helper.types_shapes()