mindspore.dataset.BatchInfo
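- class mindspore.dataset.BatchInfo
When the batch_size or per_batch_map argument of a batch operation is a callable, a BatchInfo object is passed to that callable. Its methods return information about the dataset pipeline, namely the current batch number and epoch number.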
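The sketch below shows the callback contract without requiring MindSpore. MockBatchInfo, run_one_epoch and offset_by_batch are hypothetical names invented for illustration, and the loop is not MindSpore's implementation. It also simplifies per_batch_map to a single column passed as a list. The real per_batch_map receives one list per input column plus the BatchInfo object, and it returns a tuple of lists. The sketch assumes a short final batch is kept.

```python
from typing import Callable, List, Optional


class MockBatchInfo:
    """Hypothetical stand-in for mindspore.dataset.BatchInfo; NOT the real class."""

    def __init__(self, batch_num: int, epoch_num: int) -> None:
        self._batch_num = batch_num
        self._epoch_num = epoch_num

    def get_batch_num(self) -> int:
        return self._batch_num

    def get_epoch_num(self) -> int:
        return self._epoch_num


def run_one_epoch(
    rows: List[int],
    epoch_num: int,
    batch_size_fn: Callable[[MockBatchInfo], int],
    per_batch_map: Optional[Callable[[List[int], MockBatchInfo], List[int]]] = None,
) -> List[List[int]]:
    """Hypothetical single-column batching loop (not MindSpore's implementation).

    Before each batch it builds a MockBatchInfo, asks batch_size_fn for the
    size, slices the rows, and optionally hands the batch plus the same info
    object to per_batch_map. A short final batch is kept.
    """
    batches: List[List[int]] = []
    start = 0
    batch_num = 0
    while start < len(rows):
        info = MockBatchInfo(batch_num, epoch_num)
        size = batch_size_fn(info)
        if size < 1:
            raise ValueError("batch size must be a positive integer")
        batch = rows[start:start + size]
        if per_batch_map is not None:
            batch = per_batch_map(batch, info)
        batches.append(batch)
        start += size
        batch_num += 1
    return batches


def offset_by_batch(column: List[int], info: MockBatchInfo) -> List[int]:
    # Example per_batch_map: shift every value by 100 * the current batch number.
    return [x + 100 * info.get_batch_num() for x in column]


print(run_one_epoch([0, 1, 2, 3], 0, lambda info: 2, offset_by_batch))
# [[0, 1], [102, 103]]
```

Both callables receive the same info object for a given batch. As a result, a per_batch_map can see the same batch and epoch numbers that determined the batch size.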
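- get_batch_num()
Return the number of batches already processed in the current epoch. The count starts from 0, so the callable receives 0 while the first batch of each epoch is being formed.
Examples: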
>>> # Create a dataset whose batch size is dynamic.
>>> # Define a callable batch size function that increases the batch size by 1 for each batch.
>>> import mindspore.dataset as ds
>>> from mindspore.dataset import BatchInfo
>>>
>>> dataset = ds.GeneratorDataset([i for i in range(3)], "column1", shuffle=False)
>>> def add_one(batch_info: BatchInfo) -> int:
...     return batch_info.get_batch_num() + 1
>>> dataset = dataset.batch(batch_size=add_one)
>>> print(list(dataset))
[[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[2], dtype=Int64, value= [1, 2])]]
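The batch sizes in this example follow from simple arithmetic. Let k be the value returned by get_batch_num(), and let b_k = k + 1 be the size chosen by add_one. The first n batches then hold S_n rows. This assumes the default drop_remainder=False, under which a final short batch is kept. Under that assumption, the number of batches n* is the smallest n whose cumulative size reaches the row count N:

```latex
S_n = \sum_{k=0}^{n-1} b_k = \sum_{k=0}^{n-1} (k + 1) = \frac{n(n+1)}{2},
\qquad
n^{\ast} = \min\{\, n \ge 1 : S_n \ge N \,\}
```

With N = 3, S_1 = 1 < 3 and S_2 = 3, so n* = 2. This gives two batches of sizes 1 and 2, matching the printed output.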
- get_epoch_num()
Return the current epoch number, counting from 0.
Examples:
>>> # Create a dataset whose batch size is dynamic.
>>> # Define a callable batch size function that increases the batch size by 1 for each epoch.
>>> import mindspore.dataset as ds
>>> from mindspore.dataset import BatchInfo
>>>
>>> dataset = ds.GeneratorDataset([i for i in range(4)], "column1", shuffle=False)
>>> def add_one_by_epoch(batch_info: BatchInfo) -> int:
...     return batch_info.get_epoch_num() + 1
>>> dataset = dataset.batch(batch_size=add_one_by_epoch)
>>>
>>> result = []
>>> epoch = 2
>>> iterator = dataset.create_tuple_iterator(num_epochs=epoch)
>>> for _ in range(epoch):
...     result.extend(list(iterator))
>>> # result:
>>> # [[Tensor(shape=[1], dtype=Int64, value= [0])], [Tensor(shape=[1], dtype=Int64, value= [1])],
>>> #  [Tensor(shape=[1], dtype=Int64, value= [2])], [Tensor(shape=[1], dtype=Int64, value= [3])],
>>> #  [Tensor(shape=[2], dtype=Int64, value= [0, 1])], [Tensor(shape=[2], dtype=Int64, value= [2, 3])]]
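The pure-Python sketch below traces how the two counters evolve across epochs. multi_epoch_layout and size_fn are hypothetical names. size_fn(epoch_num, batch_num) stands in for a batch_size callable that reads get_epoch_num() and get_batch_num(). The loop models only the documented counter semantics, not MindSpore's internals: the epoch number increases by one per epoch, and the batch number restarts from 0 in each epoch. A short final batch is assumed to be kept.

```python
from typing import Callable, List, Tuple


def multi_epoch_layout(
    num_rows: int,
    num_epochs: int,
    size_fn: Callable[[int, int], int],
) -> List[Tuple[int, int, List[int]]]:
    """Hypothetical sketch of dynamic batching over several epochs.

    Returns (epoch_num, batch_num, rows) triples, where rows are the
    indices of the dataset rows placed in that batch.
    """
    layout: List[Tuple[int, int, List[int]]] = []
    for epoch_num in range(num_epochs):
        start, batch_num = 0, 0  # the batch counter restarts in every epoch
        while start < num_rows:
            size = size_fn(epoch_num, batch_num)
            if size < 1:
                raise ValueError("batch size must be a positive integer")
            rows = list(range(start, min(start + size, num_rows)))
            layout.append((epoch_num, batch_num, rows))
            start += size
            batch_num += 1
    return layout


# Mirror add_one_by_epoch: batch size = epoch number + 1.
for epoch_num, batch_num, rows in multi_epoch_layout(4, 2, lambda e, b: e + 1):
    print(epoch_num, batch_num, rows)
```

With size_fn = lambda e, b: e + 1, the printed rows match the result comment in the example. Epoch 0 yields four single-row batches (batch numbers 0 to 3), and epoch 1 yields two two-row batches (batch numbers 0 and 1).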