mindspore.dataset.Dataset.flat_map

mindspore.dataset.Dataset.flat_map(func)[源代码]

对数据集对象中每一条数据执行给定的数据处理,并将结果展平。

参数:
  • func (function) - 数据处理函数,要求输入必须为一个 numpy.ndarray ,返回值是一个 Dataset 对象。

返回:

执行给定操作后的数据集对象。

异常:
  • TypeError - func 不是函数。

  • TypeError - func 的返回值不是 Dataset 对象。

样例:

>>> import mindspore.dataset as ds
>>> # 1) flat_map on one column dataset
>>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False)
>>>
>>> def repeat(array):
...     # create a NumpySlicesDataset with the array
...     data = ds.NumpySlicesDataset(array, shuffle=False)
...     # repeat the dataset twice
...     data = data.repeat(2)
...     return data
>>>
>>> dataset = dataset.flat_map(repeat)
>>> # [0, 1, 0, 1, 2, 3, 2, 3]
>>>
>>> # 2) flat_map on multi column dataset
>>> dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), shuffle=False)
>>>
>>> def plus_and_minus(col1, col2):
...     # apply different methods on columns
...     data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False)
...     return data
>>>
>>> dataset = dataset.flat_map(plus_and_minus)
>>> # ([1, 2, 3, 4], [-1, -2, -3, -4])