mindspore.connect_network_with_dataset

mindspore.connect_network_with_dataset(network, dataset_helper)[源代码]

networkdataset_helper 中的数据集连接,只支持 下沉模式,(dataset_sink_mode=True)。

参数:
  • network (Cell) - 数据集的训练网络。

  • dataset_helper (DatasetHelper) - 一个处理MindData数据集的类,提供了数据集的类型、形状(shape)和队列名称。

返回:

Cell,一个新网络,包含数据集的类型、形状(shape)和队列名称信息。

异常:
  • RuntimeError - 如果该接口在非数据下沉模式调用。

支持平台:

Ascend GPU

样例:

>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore import dataset as ds
>>>
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> dataset_helper = ms.DatasetHelper(train_dataset, dataset_sink_mode=True)
>>> net = nn.Dense(10, 5)
>>> net_with_dataset = ms.connect_network_with_dataset(net, dataset_helper)