mindspore.dataset.Dataset.filter

查看源文件
mindspore.dataset.Dataset.filter(predicate, input_columns=None, num_parallel_workers=None)[源代码]

通过自定义判断条件对数据集对象中的数据进行过滤。

参数:
  • predicate (callable) - Python可调用对象。要求该对象接收n个入参,用于指代每个数据列的数据,最后返回值一个bool值。 如果返回值为False,则表示过滤掉该条数据。注意n的值与参数 input_columns 表示的输入列数量一致。

  • input_columns (Union[str, list[str]], 可选) - filter 操作的输入数据列。默认值: Nonepredicate 将应用于数据集中的所有列。

  • num_parallel_workers (int, 可选) - 指定 filter 操作的并发线程数。默认值: None ,使用全局默认线程数(8),也可以通过 mindspore.dataset.config.set_num_parallel_workers() 配置全局线程数。

返回:

Dataset,应用了上述操作的新数据集对象。

样例:

>>> # generator data(0 ~ 19)
>>> # filter the data that greater than or equal to 11
>>> import mindspore.dataset as ds
>>> dataset = ds.GeneratorDataset([i for i in range(20)], "data")
>>> dataset = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"])