比较与tf.data.Dataset.from_generator的功能差异
tf.data.Dataset.from_generator
@staticmethod
tf.data.Dataset.from_generator(
generator,
output_types,
output_shapes=None,
args=None
)
mindspore.dataset.GeneratorDataset
class mindspore.dataset.GeneratorDataset(
source,
column_names=None,
column_types=None,
schema=None,
num_samples=None,
num_parallel_workers=1,
shuffle=None,
sampler=None,
num_shards=None,
shard_id=None,
python_multiprocessing=True,
max_rowsize=6
)
使用方式
TensorFlow：一个静态方法，支持从可调用对象中创建数据集，并通过 output_types 和 output_shapes 指定数据的类型和形状。
MindSpore：一个数据集类，支持从可调用对象、可迭代对象或可随机访问对象中创建数据集，可通过 column_types 指定各列的数据类型，或通过 schema 指定数据的类型和形状。
代码示例
# The following implements GeneratorDataset with MindSpore.
import numpy as np
import mindspore.dataset as ds


def gen():
    """Yield two rows of (col1, col2) NumPy arrays: ([1], [1]) and ([2], [1, 1])."""
    for n in range(1, 3):
        col1 = np.array([n])
        col2 = np.array([1] * n)
        yield col1, col2


columns = ["col1", "col2"]
# A generator function is callable, so it can be passed directly as `source`;
# each yielded tuple becomes one row whose fields are named by `columns` in order.
dataset = ds.GeneratorDataset(source=gen, column_names=columns)
for row in dataset.create_dict_iterator():
    print(row["col1"], row["col2"])
# [1] [1]
# [2] [1 1]
# The following implements from_generator with TensorFlow.
import tensorflow as tf

# Eager execution is already the default in TensorFlow 2.x, so this call is
# redundant there; it is kept only to make the eager requirement explicit.
tf.compat.v1.enable_eager_execution()


def gen():
    """Yield (scalar, variable-length list) pairs: (1, [1]) and (2, [1, 1])."""
    for i in range(1, 3):
        yield i, [1] * i


# output_types / output_shapes: the first component is a scalar int64, the
# second an int64 vector whose length varies per element (TensorShape([None])).
dataset = tf.data.Dataset.from_generator(
    gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
# Iterating `dataset` directly yields tf.Tensor objects, which print as
# "<tf.Tensor: shape=(), dtype=int64, numpy=1>". as_numpy_iterator() converts
# each element to NumPy values so the output matches the comments below.
for value in dataset.as_numpy_iterator():
    print(value)
# (1, array([1]))
# (2, array([1, 1]))
# NOTE: with NumPy >= 2 the scalar is shown as np.int64(1) instead of 1.