mindspore.ops.GetNext

class mindspore.ops.GetNext(types, shapes, output_num, shared_name)[源代码]

返回数据集队列中的下一个元素。

Note

GetNext操作需要联网,且依赖’dataset’接口,例如: mindspore.dataset.MnistDataset 。不能单独操作。详见 mindspore.connect_network_with_dataset 的源码。

参数:

  • types (list[mindspore.dtype]) - 输出的数据类型。

  • shapes (list[tuple[int]]) - 输出数据的shape大小。

  • output_num (int) - 输出编号、 typesshapes 的长度。

  • shared_name (str) - 待获取数据的队列名称。

输入:

没有输入。

输出:

tuple[Tensor],Dataset的输出。Shape和类型参见 shapestypes

支持平台:

Ascend GPU

样例:

>>> import mindspore
>>> from mindspore import ops
>>> from mindspore import dataset as ds
>>> from mindspore.common import dtype as mstype
>>> data_path = "/path/to/MNIST_Data/train/"
>>> train_dataset = ds.MnistDataset(data_path, num_samples=10)
>>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
>>> dataset = dataset_helper.iter.dataset
>>> dataset_types, dataset_shapes = dataset_helper.types_shapes()
>>> queue_name = dataset.__transfer_dataset__.queue_name
>>> get_next = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
>>> data, label = get_next()
>>> relu = ops.ReLU()
>>> result = relu(data.astype(mstype.float32))
>>> print(result.shape)
(28, 28, 1)