mindspore.dataset.WaitedDSCallback
class mindspore.dataset.WaitedDSCallback(step_size=1)
Abstract base class for blocking (waited) dataset callbacks. It synchronizes the data pipeline with the training callback class mindspore.train.callback.
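Here is a conceptual pure-Python sketch of what "blocking" means. It assumes the handshake can be modeled with a `threading.Event`: before every epoch after the first, the data side waits until the training side reports that the previous epoch has ended, and then reads what training left behind. `BlockingSyncSketch` and `run_demo` are hypothetical names. This is not MindSpore's implementation and it does not import MindSpore.

```python
import threading


class BlockingSyncSketch:
    """Hypothetical model of the data/training handshake (not MindSpore code)."""

    def __init__(self, timeout=5.0):
        self.epoch_event = threading.Event()
        self.last_train_info = None
        self.timeout = timeout
        self.log = []

    def epoch_end(self, train_info):
        # Training side: record the finished epoch's info, then wake the data side.
        self.last_train_info = train_info
        self.epoch_event.set()

    def ds_epoch_begin(self, epoch):
        # Data side: nothing to wait for before the first epoch.
        if epoch == 1:
            return
        # Block until training has finished the previous epoch.
        if not self.epoch_event.wait(self.timeout):
            raise TimeoutError("training side did not finish the previous epoch")
        self.epoch_event.clear()
        # A real callback would run its custom sync_epoch_begin logic here.
        self.log.append((epoch, self.last_train_info))


def run_demo(num_epochs=3):
    sync = BlockingSyncSketch()
    # ready[e] is set once the data side has "produced" epoch e.
    ready = [threading.Event() for _ in range(num_epochs + 1)]

    def data_pipeline():
        for epoch in range(1, num_epochs + 1):
            sync.ds_epoch_begin(epoch)
            ready[epoch].set()

    producer = threading.Thread(target=data_pipeline)
    producer.start()
    for epoch in range(1, num_epochs + 1):
        ready[epoch].wait()  # training consumes the epoch's data
        sync.epoch_end({"epoch": epoch, "loss": 1.0 / epoch})
    producer.join()
    return sync.log
```

`run_demo(3)` records one synchronization per epoch starting from epoch 2. Each entry carries the information from the epoch that just finished training.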
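It can run a custom callback method before a step or an epoch starts. For example, in automatic data augmentation it can update the parameter configuration of the augmentation operations based on the loss of the previous epoch.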
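The loss-driven augmentation use case can be sketched without MindSpore as follows. `AdaptiveAugmentSketch` is hypothetical and the policy is an arbitrary illustrative choice: raise an integer augmentation level when the previous epoch's loss improved, otherwise lower it. The signature is simplified to take the loss directly. A real `sync_epoch_begin` receives `(train_run_context, ds_run_context)` and would have to obtain the loss from the training context.

```python
class AdaptiveAugmentSketch:
    """Hypothetical stand-in for a WaitedDSCallback subclass (no MindSpore import).

    `level` plays the role of an augmentation magnitude, e.g. the strength
    of a random rotation or color-jitter op in the pipeline.
    """

    def __init__(self, level=5, min_level=0, max_level=10):
        self.level = level
        self.min_level = min_level
        self.max_level = max_level
        self.best_loss = float("inf")

    def sync_epoch_begin(self, prev_epoch_loss):
        # Simplified signature: the loss is passed in directly instead of
        # being read from train_run_context.
        if prev_epoch_loss < self.best_loss:
            # Loss improved: augment more aggressively, up to max_level.
            self.best_loss = prev_epoch_loss
            self.level = min(self.max_level, self.level + 1)
        else:
            # Loss did not improve: back off, down to min_level.
            self.level = max(self.min_level, self.level - 1)
        return self.level
```

Starting at level 5, the losses 2.0, 2.5 and 1.5 would move the level to 6, 5 and 6.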
Through train_run_context, users can obtain information about network training, such as network, train_network, epoch_num, batch_num, loss_fn, optimizer, parallel_mode, device_number, list_callback, cur_epoch_num, cur_step_num, dataset_sink_mode and net_outputs. See mindspore.train.callback for details.
Through ds_run_context, users can obtain information about the data pipeline: cur_epoch_num (the current epoch number), cur_step_num_in_epoch (the step number within the current epoch) and cur_step_num (the current overall step number).
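The three ds_run_context counters relate as sketched below. The sketch assumes all counters are 1-based and that every epoch has the same number of steps. `DSRunContextSketch` and `iter_ds_contexts` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass


@dataclass
class DSRunContextSketch:
    """Hypothetical mirror of the three ds_run_context counters (assumed 1-based)."""
    cur_epoch_num: int
    cur_step_num_in_epoch: int
    cur_step_num: int


def iter_ds_contexts(num_epochs, steps_per_epoch):
    """Yield the counters a data-side callback would see, assuming a fixed steps_per_epoch."""
    for epoch in range(1, num_epochs + 1):
        for step_in_epoch in range(1, steps_per_epoch + 1):
            yield DSRunContextSketch(
                cur_epoch_num=epoch,
                cur_step_num_in_epoch=step_in_epoch,
                # The global step keeps counting across epochs.
                cur_step_num=(epoch - 1) * steps_per_epoch + step_in_epoch,
            )
```

With 2 epochs of 4 steps, the first step of epoch 2 has cur_step_num_in_epoch = 1 and cur_step_num = 5.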
Note
The callback is first triggered at the beginning of the second step or epoch, not before the first one.
Parameters:
step_size (int, optional) - The number of data rows in each step. step_size is usually the same as batch_size. Default: 1.
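The arithmetic behind step_size can be sketched as follows. The sketch assumes a new data-pipeline step begins every step_size rows, counted from the first row of the epoch. Under that assumption, setting step_size to the batch size gives one data step per training batch. `step_begin_rows` and `step_of_row` are hypothetical helpers, not MindSpore functions.

```python
def step_begin_rows(num_rows, step_size=1):
    """0-based row indices at which a new step would begin (assumption: every step_size rows)."""
    if step_size < 1:
        raise ValueError("step_size must be a positive integer")
    return list(range(0, num_rows, step_size))


def step_of_row(row_index, step_size=1):
    """1-based step number that a 0-based row index falls into, under the same assumption."""
    return row_index // step_size + 1
```

For 8 rows and step_size=2, steps would begin at rows 0, 2, 4 and 6. These are four steps, matching four batches of 2.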
Examples:
>>> import mindspore.nn as nn
>>> import mindspore.dataset as ds
>>> from mindspore.dataset import WaitedDSCallback
>>> from mindspore import context
>>> from mindspore.train import Model
>>> from mindspore.train.callback import Callback
>>>
>>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
>>>
>>> # custom callback class for data synchronization in the data pipeline
>>> class MyWaitedCallback(WaitedDSCallback):
...     def __init__(self, events, step_size=1):
...         super().__init__(step_size)
...         self.events = events
...
...     # callback method executed by the data pipeline before the epoch starts
...     def sync_epoch_begin(self, train_run_context, ds_run_context):
...         event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
...         self.events.append(event)
...
...     # callback method executed by the data pipeline before the step starts
...     def sync_step_begin(self, train_run_context, ds_run_context):
...         event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
...         self.events.append(event)
>>>
>>> # custom callback class for data synchronization in network training
>>> class MyMSCallback(Callback):
...     def __init__(self, events):
...         self.events = events
...
...     # callback method executed by network training after the epoch ends
...     def epoch_end(self, run_context):
...         cb_params = run_context.original_args()
...         event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
...         self.events.append(event)
...
...     # callback method executed by network training after the step ends
...     def step_end(self, run_context):
...         cb_params = run_context.original_args()
...         event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
...         self.events.append(event)
>>>
>>> # custom network
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x
>>>
>>> # define a list shared by the data pipeline and network training
>>> events = []
>>>
>>> # define the callback objects of the data pipeline and network training
>>> my_cb1 = MyWaitedCallback(events, 1)
>>> my_cb2 = MyMSCallback(events)
>>> arr = [1, 2, 3, 4]
>>>
>>> # construct the data pipeline
>>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
>>> # attach the data callback object to the pipeline through map
>>> data = data.map(operations=(lambda x: x), callbacks=my_cb1)
>>>
>>> net = Net()
>>> model = Model(net)
>>>
>>> # add both the data and network callback objects to the training callback list
>>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])