mindspore.dataset.Dataset.flat_map

Dataset.flat_map(func)

Apply func to each row of the dataset and flatten the results into a single dataset.

Parameters

func (function) – A function that takes one numpy.ndarray as its argument and returns a Dataset. For a multi-column dataset, func receives one array per column, as in the second example below.

Returns

Dataset, the flattened dataset produced by applying func to each row.

Examples

>>> import mindspore.dataset as ds
>>>
>>> # 1) flat_map on one column dataset
>>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False)
>>>
>>> def repeat(array):
...     # create a NumpySlicesDataset with the array
...     data = ds.NumpySlicesDataset(array, shuffle=False)
...     # repeat the dataset twice
...     data = data.repeat(2)
...     return data
>>>
>>> dataset = dataset.flat_map(repeat)
>>> # [0, 1, 0, 1, 2, 3, 2, 3]
>>>
>>> # 2) flat_map on multi column dataset
>>> dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), shuffle=False)
>>>
>>> def plus_and_minus(col1, col2):
...     # apply different methods on columns
...     data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False)
...     return data
>>>
>>> dataset = dataset.flat_map(plus_and_minus)
>>> # ([1, 2, 3, 4], [-1, -2, -3, -4])

Raises

  • TypeError – If func is not a function.

  • TypeError – If func doesn’t return a Dataset.
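
The map-then-flatten behavior can be illustrated without MindSpore. The following is a minimal plain-Python analogy, not the library's implementation: `flat_map_sketch` and `repeat_twice` are hypothetical names, rows are modeled as ordinary lists instead of Dataset objects, and only the first TypeError above (a non-callable func) is mirrored.

```python
from itertools import chain
from typing import Callable, Iterable, List


def flat_map_sketch(rows: Iterable[list], func: Callable[[list], Iterable]) -> List:
    """Plain-Python analogy of flat_map (hypothetical helper, not MindSpore code)."""
    if not callable(func):
        raise TypeError("func must be callable")
    # func turns each row into a small sequence; chaining those sequences
    # end to end yields one flat result.
    return list(chain.from_iterable(func(row) for row in rows))


def repeat_twice(row: list) -> list:
    # Mirrors the first example above: emit the row's elements twice over.
    return row * 2


result = flat_map_sketch([[0, 1], [2, 3]], repeat_twice)
print(result)  # [0, 1, 0, 1, 2, 3, 2, 3]
```

The output matches the first example: each row is expanded independently, and the expansions are concatenated in row order.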