mindspore.dataset.BatchInfo
- class mindspore.dataset.BatchInfo
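This class provides two methods for reading the current batch number and epoch number of a dataset. These values can be used inside the callables passed as the batch_size and per_batch_map parameters of the batch operation.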
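The per_batch_map calling convention can be sketched in plain Python. This is an illustrative model, not MindSpore code: StubBatchInfo, scale_by_batch, and run_batches are hypothetical names. The sketch assumes that a per_batch_map callable receives one list per input column followed by the BatchInfo object as its last argument, and returns a tuple with one list per output column, each the same length as its input. In a real pipeline the column values would typically be NumPy arrays, and the function would be passed as something like dataset.batch(batch_size=2, input_columns=["column1"], per_batch_map=scale_by_batch).

```python
from typing import List, Tuple


class StubBatchInfo:
    """Hypothetical plain-Python stand-in for mindspore.dataset.BatchInfo."""

    def __init__(self, batch_num: int, epoch_num: int) -> None:
        self._batch_num = batch_num
        self._epoch_num = epoch_num

    def get_batch_num(self) -> int:
        return self._batch_num

    def get_epoch_num(self) -> int:
        return self._epoch_num


def scale_by_batch(col: List[int], batch_info: StubBatchInfo) -> Tuple[List[int]]:
    # Hypothetical per_batch_map-style callable: column data first, BatchInfo last.
    # Multiplies each value by (batch number + 1) and returns one list per output column.
    factor = batch_info.get_batch_num() + 1
    return ([x * factor for x in col],)


def run_batches(rows: List[int], batch_size: int, epoch: int = 0) -> List[List[int]]:
    # Hypothetical driver: cut rows into fixed-size batches and apply
    # scale_by_batch to each one with a fresh StubBatchInfo.
    out = []
    for batch_num, start in enumerate(range(0, len(rows), batch_size)):
        (mapped,) = scale_by_batch(rows[start:start + batch_size], StubBatchInfo(batch_num, epoch))
        out.append(mapped)
    return out


print(run_batches(list(range(6)), batch_size=2))  # [[0, 1], [4, 6], [12, 15]]
```

Each batch sees a different batch number, so the same function applies a different transformation to batch 0, batch 1, and batch 2.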
- get_batch_num()
Returns the number of the batch currently being processed in the current epoch, counting from 0.
Examples:
>>> # Create a dataset whose batch size is dynamic.
>>> # Define a callable batch size function so that the batch size increases by 1 with each batch.
>>> import mindspore.dataset as ds
>>> from mindspore.dataset import BatchInfo
>>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
>>> def add_one(batch_info: BatchInfo) -> int:
...     return batch_info.get_batch_num() + 1
>>> dataset = dataset.batch(batch_size=add_one)
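To see which rows end up in which batch with add_one, the scheduling can be modeled in plain Python. This is an illustrative model, not MindSpore's implementation: StubBatchInfo and split_one_epoch are hypothetical names, and the model assumes the batch-size callable is invoked once before each batch is formed, with a batch number starting at 0. Under these assumptions the ten rows are grouped into batches of 1, 2, 3, and 4 rows.

```python
from typing import Callable, List


class StubBatchInfo:
    """Hypothetical plain-Python stand-in for mindspore.dataset.BatchInfo."""

    def __init__(self, batch_num: int, epoch_num: int) -> None:
        self._batch_num = batch_num
        self._epoch_num = epoch_num

    def get_batch_num(self) -> int:
        return self._batch_num

    def get_epoch_num(self) -> int:
        return self._epoch_num


def add_one(batch_info: StubBatchInfo) -> int:
    # Same rule as the doctest: batch size = batch number + 1.
    return batch_info.get_batch_num() + 1


def split_one_epoch(rows: List[int], size_fn: Callable[[StubBatchInfo], int],
                    epoch: int = 0) -> List[List[int]]:
    # Hypothetical model of a callable batch_size: before forming each batch,
    # ask size_fn for its size; slicing keeps a final short batch if one remains.
    batches = []
    start = 0
    batch_num = 0
    while start < len(rows):
        size = size_fn(StubBatchInfo(batch_num, epoch))
        batches.append(rows[start:start + size])
        start += size
        batch_num += 1
    return batches


print(split_one_epoch(list(range(10)), add_one))
# [[0], [1, 2], [3, 4, 5], [6, 7, 8, 9]]
```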
- get_epoch_num()
Returns the current epoch number, counting from 0.
Examples:
>>> # Create a dataset whose batch size is dynamic.
>>> # Define a callable batch size function so that the batch size increases by 1 with each epoch.
>>> import mindspore.dataset as ds
>>> from mindspore.dataset import BatchInfo
>>> dataset = ds.GeneratorDataset([i for i in range(10)], "column1")
>>> def add_one_by_epoch(batch_info: BatchInfo) -> int:
...     return batch_info.get_epoch_num() + 1
>>> dataset = dataset.batch(batch_size=add_one_by_epoch)
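The effect of add_one_by_epoch over several epochs can be modeled in plain Python. This is an illustrative sketch, not MindSpore code: StubBatchInfo and batch_sizes_per_epoch are hypothetical names, the epoch and batch numbers are assumed to start at 0 (with the batch number restarting each epoch), and a final short batch is assumed to be kept, as with the default drop_remainder=False.

```python
from typing import Callable, List


class StubBatchInfo:
    """Hypothetical plain-Python stand-in for mindspore.dataset.BatchInfo."""

    def __init__(self, batch_num: int, epoch_num: int) -> None:
        self._batch_num = batch_num
        self._epoch_num = epoch_num

    def get_batch_num(self) -> int:
        return self._batch_num

    def get_epoch_num(self) -> int:
        return self._epoch_num


def add_one_by_epoch(batch_info: StubBatchInfo) -> int:
    # Same rule as the doctest: batch size = epoch number + 1.
    return batch_info.get_epoch_num() + 1


def batch_sizes_per_epoch(num_rows: int, size_fn: Callable[[StubBatchInfo], int],
                          num_epochs: int) -> List[List[int]]:
    # Hypothetical model: for each epoch, ask size_fn for the size of every batch;
    # the last batch of an epoch is truncated to the rows that remain.
    result = []
    for epoch in range(num_epochs):
        sizes = []
        remaining = num_rows
        batch_num = 0
        while remaining > 0:
            size = min(size_fn(StubBatchInfo(batch_num, epoch)), remaining)
            sizes.append(size)
            remaining -= size
            batch_num += 1
        result.append(sizes)
    return result


print(batch_sizes_per_epoch(10, add_one_by_epoch, num_epochs=3))
# [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3, 3, 1]]
```

Because the size depends only on the epoch number, every full batch within one epoch has the same size, and only the number of batches per epoch changes.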