mindspore.dataset.DSCallback
- class mindspore.dataset.DSCallback(step_size=1)[源代码]
数据处理回调类的抽象基类,用户可以基于此类实现自己的回调操作。
用户可通过 ds_run_context 获取数据处理管道相关信息,包括 cur_epoch_num (当前epoch数)、 cur_step_num_in_epoch (当前epoch的step数)、 cur_step_num (当前step数)。
- 参数:
step_size (int, 可选) - 定义相邻的 ds_step_begin/ds_step_end 调用之间相隔的step数。默认值:
1
,表示每个step都会调用。
样例:
>>> import mindspore.dataset as ds >>> from mindspore.dataset import DSCallback >>> >>> class PrintInfo(DSCallback): ... def ds_begin(self, ds_run_context): ... print("callback: start dataset pipeline", flush=True) ... ... def ds_epoch_begin(self, ds_run_context): ... print("callback: epoch begin, we are in epoch", ds_run_context.cur_epoch_num, flush=True) ... ... def ds_epoch_end(self, ds_run_context): ... print("callback: epoch end, we are in epoch", ds_run_context.cur_epoch_num, flush=True) ... ... def ds_step_begin(self, ds_run_context): ... print("callback: step begin, step", ds_run_context.cur_step_num_in_epoch, flush=True) ... ... def ds_step_end(self, ds_run_context): ... print("callback: step end, step", ds_run_context.cur_step_num_in_epoch, flush=True) >>> >>> dataset = ds.GeneratorDataset([1, 2], "col1", shuffle=False, num_parallel_workers=1) >>> dataset = dataset.map(operations=lambda x: x, callbacks=PrintInfo()) >>> >>> # Start dataset pipeline >>> iterator = dataset.create_tuple_iterator(num_epochs=2) >>> for i in range(2): ... for d in iterator: ... pass callback: start dataset pipeline callback: epoch begin, we are in epoch 1 callback: step begin, step 1 callback: step begin, step 2 callback: step end, step 1 callback: step end, step 2 callback: epoch end, we are in epoch 1 callback: epoch begin, we are in epoch 2 callback: step begin, step 1 callback: step begin, step 2 callback: step end, step 1 callback: step end, step 2 callback: epoch end, we are in epoch 2
- ds_epoch_begin(ds_run_context)[源代码]
用于定义在每个数据epoch开始前执行的回调方法。
- 参数:
ds_run_context (RunContext) - 数据处理管道运行信息。
- ds_epoch_end(ds_run_context)[源代码]
用于定义在每个数据epoch结束后执行的回调方法。
- 参数:
ds_run_context (RunContext) - 数据处理管道运行信息。