mindspore.dataset.Dataset.flat_map
- mindspore.dataset.Dataset.flat_map(func)[源代码]
对数据集对象中每一条数据执行给定的数据处理,并将结果展平。
- 参数:
func (function) - 数据处理函数,要求输入必须为一个 numpy.ndarray ,返回值是一个 Dataset 对象。
- 返回:
执行给定操作后的数据集对象。
- 异常:
TypeError - func 不是函数。
TypeError - func 的返回值不是 Dataset 对象。
样例:
>>> import mindspore.dataset as ds >>> # 1) flat_map on one column dataset >>> dataset = ds.NumpySlicesDataset([[0, 1], [2, 3]], shuffle=False) >>> >>> def repeat(array): ... # create a NumpySlicesDataset with the array ... data = ds.NumpySlicesDataset(array, shuffle=False) ... # repeat the dataset twice ... data = data.repeat(2) ... return data >>> >>> dataset = dataset.flat_map(repeat) >>> # [0, 1, 0, 1, 2, 3, 2, 3] >>> >>> # 2) flat_map on multi column dataset >>> dataset = ds.NumpySlicesDataset(([[0, 1], [2, 3]], [[0, -1], [-2, -3]]), shuffle=False) >>> >>> def plus_and_minus(col1, col2): ... # apply different methods on columns ... data = ds.NumpySlicesDataset((col1 + 1, col2 - 1), shuffle=False) ... return data >>> >>> dataset = dataset.flat_map(plus_and_minus) >>> # ([1, 2, 3, 4], [-1, -2, -3, -4])