mindspore.DatasetHelper
- class mindspore.DatasetHelper(dataset, dataset_sink_mode=True, sink_size=- 1, epoch_num=1)[源代码]
DatasetHelper是一个处理MindData数据集的类,提供数据集信息。
根据不同的上下文,改变数据集的迭代,在不同的上下文中使用相同的迭代。
说明
DatasetHelper的迭代将提供一个epoch的数据。
- 参数:
dataset (Dataset) - 训练数据集迭代器。数据集可以由数据集生成器API在 mindspore.dataset 中生成,例如
mindspore.dataset.ImageFolderDataset
。dataset_sink_mode (bool) - 如果值为True,使用
mindspore.ops.GetNext
在设备(Device)上通过数据通道中获取数据,否则在主机(Host)直接遍历数据集获取数据。默认值:True。sink_size (int) - 控制每个下沉中的数据量。如果 sink_size 为-1,则下沉每个epoch的完整数据集。如果 sink_size 大于0,则下沉每个epoch的 sink_size 数据。默认值:-1。
epoch_num (int) - 控制待发送的epoch数据量。默认值:1。
样例:
>>> import numpy as np >>> import mindspore as ms >>> from mindspore import nn >>> from mindspore import dataset as ds >>> >>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))} >>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32) >>> set_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=False) >>> >>> net = nn.Dense(10, 5) >>> # Object of DatasetHelper is iterable >>> for next_element in set_helper: ... # `next_element` includes data and label, using data to run the net ... data = next_element[0] ... net(data)
- sink_size()[源代码]
获取每次迭代的 sink_size 。
样例:
>>> import mindspore as ms >>> import numpy as np >>> >>> # Define a dataset pipeline >>> def generator(): ... for i in range(5): ... yield (np.ones((32, 10)),) >>> >>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"]) >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1) >>> >>> # if sink_size==-1, then will return the full size of source dataset. >>> sink_size = dataset_helper.sink_size()
- types_shapes()[源代码]
从当前配置中的数据集获取类型和形状(shape)。
样例:
>>> import mindspore as ms >>> import numpy as np >>> >>> # Define a dataset pipeline >>> def generator(): ... for i in range(5): ... yield (np.ones((32, 10)),) >>> >>> train_dataset = ms.dataset.GeneratorDataset(generator, ["data"]) >>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True) >>> >>> types, shapes = dataset_helper.types_shapes()