mindspore.ops.GetNext
- class mindspore.ops.GetNext(types, shapes, output_num, shared_name)[源代码]
返回数据集队列中的下一个元素。
Note
GetNext操作需要联网,且依赖’dataset’接口,例如:
mindspore.dataset.MnistDataset
。不能单独操作。详见mindspore.connect_network_with_dataset
的源码。参数:
types (list[
mindspore.dtype
]) - 输出的数据类型。shapes (list[tuple[int]]) - 输出数据的shape大小。
output_num (int) - 输出编号、 types 和 shapes 的长度。
shared_name (str) - 待获取数据的队列名称。
输入:
没有输入。
输出:
tuple[Tensor],Dataset的输出。Shape和类型参见 shapes 、 types 。
- 支持平台:
Ascend
GPU
样例:
>>> import mindspore >>> from mindspore import ops >>> from mindspore import dataset as ds >>> from mindspore.common import dtype as mstype >>> data_path = "/path/to/MNIST_Data/train/" >>> train_dataset = ds.MnistDataset(data_path, num_samples=10) >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True) >>> dataset = dataset_helper.iter.dataset >>> dataset_types, dataset_shapes = dataset_helper.types_shapes() >>> queue_name = dataset.__transfer_dataset__.queue_name >>> get_next = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name) >>> data, label = get_next() >>> relu = ops.ReLU() >>> result = relu(data.astype(mstype.float32)) >>> print(result.shape) (28, 28, 1)