mindspore.dataset.WaitedDSCallback

查看源文件
class mindspore.dataset.WaitedDSCallback(step_size=1)[源代码]

阻塞式数据处理回调类的抽象基类,用于与训练回调类 mindspore.train.Callback 的同步。

可用于在step或epoch开始前执行自定义的回调方法,例如在自动数据增强中根据上一个epoch的loss值来更新增强操作参数配置。

用户可通过 train_run_context 获取网络训练相关信息,如 networktrain_networkepoch_numbatch_numloss_fnoptimizerparallel_modedevice_numberlist_callbackcur_epoch_numcur_step_numdataset_sink_modenet_outputs 等,详见 mindspore.train.Callback

用户可通过 ds_run_context 获取数据处理管道相关信息,包括 cur_epoch_num (当前epoch数)、 cur_step_num_in_epoch (当前epoch的step数)、 cur_step_num (当前step数)。

说明

注意,第2个step或epoch开始时才会触发该调用。

参数:
  • step_size (int, 可选) - 每个step包含的数据行数。通常step_size与batch_size一致。默认值: 1

样例:

>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> import mindspore.nn as nn
>>> from mindspore.dataset import WaitedDSCallback
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")
>>>
>>> # custom callback class for data synchronization in data pipeline
>>> class MyWaitedCallback(WaitedDSCallback):
...     def __init__(self, events, step_size=1):
...         super().__init__(step_size)
...         self.events = events
...
...     # callback method to be executed by data pipeline before the epoch starts
...     def sync_epoch_begin(self, train_run_context, ds_run_context):
...         event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
...         self.events.append(event)
...
...     # callback method to be executed by data pipeline before the step starts
...     def sync_step_begin(self, train_run_context, ds_run_context):
...         event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
...         self.events.append(event)
>>>
>>> # custom callback class for data synchronization in network training
>>> class MyMSCallback(ms.Callback):
...     def __init__(self, events):
...         self.events = events
...
...     # callback method to be executed by network training after the epoch ends
...     def epoch_end(self, run_context):
...         cb_params = run_context.original_args()
...         event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
...         self.events.append(event)
...
...     # callback method to be executed by network training after the step ends
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
...         self.events.append(event)
>>>
>>> # custom network
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x
>>>
>>> # define a parameter that needs to be synchronized between data pipeline and network training
>>> events = []
>>>
>>> # define callback classes of data pipeline and netwok training
>>> my_cb1 = MyWaitedCallback(events, 1)
>>> my_cb2 = MyMSCallback(events)
>>> arr = [1, 2, 3, 4]
>>>
>>> # construct data pipeline
>>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
>>> # map the data callback object into the pipeline
>>> data = data.map(operations=(lambda x: x), callbacks=my_cb1)
>>>
>>> net = Net()
>>> model = ms.train.Model(net)
>>>
>>> # add the data and network callback objects to the model training callback list
>>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
sync_epoch_begin(train_run_context, ds_run_context)[源代码]

用于定义在数据epoch开始前,训练epoch结束后执行的回调方法。

参数:
  • train_run_context - 包含前一个epoch的反馈信息的网络训练运行信息。

  • ds_run_context - 数据处理管道运行信息。

sync_step_begin(train_run_context, ds_run_context)[源代码]

用于定义在数据step开始前,训练step结束后执行的回调方法。

参数:
  • train_run_context - 包含前一个step的反馈信息的网络训练运行信息。

  • ds_run_context - 数据处理管道运行信息。