mindspore.dataset.Dataset.padded_batch

查看源文件
mindspore.dataset.Dataset.padded_batch(batch_size, drop_remainder=False, num_parallel_workers=None, pad_info=None)[源代码]

将数据集中连续 batch_size 条数据组合为一个批数据,并可通过可选参数 pad_info 预先将样本补齐。

batch 操作要求每列中的数据具有相同的shape。

执行流程参考下图:

../../../../_images/padded_batch_cn.png

说明

执行 repeatpadded_batch 操作的先后顺序,会影响批处理数据的数量。建议在 padded_batch 操作完成后执行 repeat 操作。

参数:
  • batch_size (Union[int, Callable]) - 指定每个批处理数据包含的数据条目。 如果 batch_size 为整型,则直接表示每个批处理数据大小; 如果为可调用对象,则可以通过自定义行为动态指定每个批处理数据大小,要求该可调用对象接收一个参数BatchInfo,返回一个整形代表批处理大小,用法请参考样例(3)。

  • drop_remainder (bool, 可选) - 当最后一个批处理数据包含的数据条目小于 batch_size 时,是否将该批处理丢弃,不传递给下一个操作。默认值: False ,不丢弃。

  • num_parallel_workers (int, 可选) - 指定 padded_batch 操作的并发线程数。 默认值: None ,使用全局默认线程数(8),也可以通过 mindspore.dataset.config.set_num_parallel_workers() 配置全局线程数。

  • pad_info (dict, 可选) - 对给定数据列进行填充。通过传入dict来指定列信息与填充信息,例如 pad_info={“col1”:([224,224],0)} , 则将列名为”col1”的数据列扩充到shape为(224, 224)的Tensor,缺失的值使用0填充。如果 pad_info={} ,则每个 batch 中的所有样本会补齐至当前 batch 中样本最大的shape。默认值: None ,不填充。

返回:

Dataset,应用了上述操作的新数据集对象。

样例:

>>> # 1) Pad every sample to the largest sample's shape and batch the samples
>>> import mindspore.dataset as ds
>>> dataset = ds.NumpySlicesDataset([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]], "column1")
>>> dataset = dataset.padded_batch(2, True, pad_info={})
>>>
>>> # 2) Create a dataset where every 3 rows are combined into a batch
>>> # and drops the last incomplete batch if there is one.
>>> dataset = ds.NumpySlicesDataset([i for i in range(10)], "column1")
>>> dataset = dataset.padded_batch(3, True)
>>>
>>> # 3) Create a dataset where its batch size is dynamic
>>> # Define a callable batch size function and let batch size increase 1 each time.
>>> def add_one(BatchInfo):
...     return BatchInfo.get_batch_num() + 1
>>> dataset = dataset.padded_batch(batch_size=add_one, drop_remainder=True)