Source code for mindspore.dataset.callback.ds_callback

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Python callback classes for the dataset pipeline (DSCallback and WaitedDSCallback).
"""
import threading
from mindspore._c_dataengine import PyDSCallback
from mindspore.train.callback import Callback
import mindspore.dataset as ds
from .validators import check_callback


class DSCallback:
    """
    Abstract base class used to build a dataset callback class.

    Subclasses override one or more of the ``ds_*`` hook methods below. Only the
    hooks that are actually overridden get registered on the runtime object
    built by ``create_runtime_obj``.

    Args:
        step_size (int, optional): The number of steps between consecutive calls to
            ds_step_begin and ds_step_end (Default=1).

    Examples:
        >>> class PrintInfo(DSCallback):
        >>>     def ds_epoch_end(self, ds_run_context):
        >>>         print(ds_run_context.cur_epoch_num)
        >>>         print(ds_run_context.cur_step_num)
        >>>
        >>> data = data.map(operations=op, callbacks=PrintInfo())
    """

    @check_callback
    def __init__(self, step_size=1):
        self.step_size = step_size

    def ds_begin(self, ds_run_context):
        """
        Called before the data pipeline is started.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_epoch_begin(self, ds_run_context):
        """
        Called before a new epoch is started.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_epoch_end(self, ds_run_context):
        """
        Called after an epoch is finished.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_step_begin(self, ds_run_context):
        """
        Called before n steps are started.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def ds_step_end(self, ds_run_context):
        """
        Called after n steps are finished.

        Args:
            ds_run_context (RunContext): Include some information of the pipeline.
        """

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user.

        A hook is registered only when the subclass overrides it, i.e. when the
        attribute found on the instance's class differs from the one defined on
        DSCallback.

        Returns:
            _c_dataengine.PyDSCallback

        Raises:
            AttributeError: If the subclass overrides none of the hook methods.
        """
        # (hook method name, name of the PyDSCallback setter that registers it)
        hook_setters = (
            ("ds_begin", "set_begin"),
            ("ds_epoch_begin", "set_epoch_begin"),
            ("ds_epoch_end", "set_epoch_end"),
            ("ds_step_begin", "set_step_begin"),
            ("ds_step_end", "set_step_end"),
        )

        c_cb = PyDSCallback(self.step_size)
        at_least_one = False
        for hook_name, setter_name in hook_setters:
            if getattr(self.__class__, hook_name) != getattr(DSCallback, hook_name):
                # Pass the bound method so the runtime invokes it on this instance.
                getattr(c_cb, setter_name)(getattr(self, hook_name))
                at_least_one = True

        if not at_least_one:
            # The count is derived from the table so the message cannot drift
            # from the number of hooks that are actually checked.
            raise AttributeError(
                f"Provided Callback class did not override any of the {len(hook_setters)} callback methods.")

        return c_cb
class WaitedDSCallback(Callback, DSCallback):
    """
    Abstract base class used to build a dataset callback class that is synchronized with the training callback.

    This class can be used to execute a user defined logic right after the previous step or epoch.
    For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.

    Users override ``sync_epoch_begin`` and/or ``sync_step_begin``. The remaining methods
    (``epoch_end``, ``step_end``, ``end``, ``ds_epoch_begin``, ``ds_step_begin``) are the internal
    plumbing that hands the training run context over to the dataset side through
    ``threading.Event`` objects.

    Args:
        step_size (int, optional): The number of rows in each step.
            Usually the step size will be equal to the batch size (Default=1).

    Examples:
        >>> my_cb = MyWaitedCallback(32)
        >>> data = data.map(operations=AugOp(), callbacks=my_cb)
        >>> data = data.batch(32)
        >>> # define the model
        >>> model.train(epochs, data, callbacks=[my_cb])
    """

    def __init__(self, step_size=1):
        super().__init__()
        # NOTE(review): step_size is assigned directly here; whether super().__init__()
        # also reaches DSCallback.__init__ depends on Callback.__init__, which is not
        # visible in this file -- confirm before relying on its validation.
        self.step_size = step_size

        # Set by step_end(), awaited and cleared by ds_step_begin().
        self.step_event = threading.Event()
        self.step_run_context = None

        # Set by epoch_end(), awaited and cleared by ds_epoch_begin().
        self.epoch_event = threading.Event()
        self.epoch_run_context = None

        # Once True, the ds_* hooks stop waiting on the events.
        self.training_ended = False

    def sync_epoch_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset epoch is started and after the previous training epoch is ended.

        Args:
            train_run_context: Include some information of the model with feedback from the previous epoch.
            ds_run_context: Include some information of the dataset pipeline.
        """

    def sync_step_begin(self, train_run_context, ds_run_context):
        """
        Called before a new dataset step is started and after the previous training step is ended.

        Args:
            train_run_context: Include some information of the model with feedback from the previous step.
            ds_run_context: Include some information of the dataset pipeline.
        """

    def _wait_for_training(self, event, hook_name):
        """
        Internal helper. Block until `event` is set by the training side, then clear it.

        Returns immediately when training has already ended.

        Args:
            event (threading.Event): The event to wait on.
            hook_name (str): Name of the calling hook, used in the timeout message.

        Raises:
            RuntimeError: If the event is not set within the configured callback timeout.
        """
        if self.training_ended:
            return

        # Read the timeout once so the error message reports the value actually waited for.
        timeout = ds.config.get_callback_timeout()
        released = event.wait(timeout=timeout)
        event.clear()
        if not released:
            raise RuntimeError(f"{hook_name} timed out after {timeout} second(s).")

    def epoch_end(self, run_context):
        """
        Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.

        Args:
            run_context: Include some information of the model.
        """
        # Store the context before setting the event so the woken waiter reads the new value.
        self.epoch_run_context = run_context
        self.epoch_event.set()

    def ds_epoch_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.

        Nothing happens for the first epoch (cur_epoch_num <= 1): there is no previous
        training epoch to wait for, and sync_epoch_begin is not called.

        Args:
            ds_run_context: Include some information of the pipeline.

        Raises:
            RuntimeError: If the training epoch_end callback does not arrive within the callback timeout.
        """
        if ds_run_context.cur_epoch_num > 1:
            self._wait_for_training(self.epoch_event, "ds_epoch_begin")
            # by the time this thread wakes up, self.epoch_run_context is already available
            self.sync_epoch_begin(self.epoch_run_context, ds_run_context)

    def step_end(self, run_context):
        """
        Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.

        Args:
            run_context: Include some information of the model.
        """
        # Store the context before setting the event so the woken waiter reads the new value.
        self.step_run_context = run_context
        self.step_event.set()

    def ds_step_begin(self, ds_run_context):
        """
        Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.

        Nothing happens while cur_step_num <= step_size: there is no previous
        training step to wait for, and sync_step_begin is not called.

        Args:
            ds_run_context: Include some information of the pipeline.

        Raises:
            RuntimeError: If the training step_end callback does not arrive within the callback timeout.
        """
        if ds_run_context.cur_step_num > self.step_size:
            self._wait_for_training(self.step_event, "ds_step_begin")
            # by the time this thread wakes up, self.step_run_context is already available
            self.sync_step_begin(self.step_run_context, ds_run_context)

    def create_runtime_obj(self):
        """
        Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.

        The internal ds_step_begin / ds_epoch_begin wrappers are registered only when the
        user has overridden the matching sync_step_begin / sync_epoch_begin method.

        Returns:
            _c_dataengine.PyDSCallback

        Raises:
            AttributeError: If neither sync_step_begin nor sync_epoch_begin is overridden.
        """
        c_cb = PyDSCallback(self.step_size)
        at_least_one = False

        if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
            c_cb.set_step_begin(self.ds_step_begin)
            at_least_one = True

        if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
            c_cb.set_epoch_begin(self.ds_epoch_begin)
            at_least_one = True

        if not at_least_one:
            raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")

        return c_cb

    def end(self, run_context):
        """
        Internal method, do not call/override. Marks training as ended and releases any pending wait.

        Args:
            run_context: Include some information of the model.
        """
        # Raise the flag BEFORE releasing the events: a waiter woken by the events
        # below must see training_ended on its next ds_* call, otherwise it could
        # start a new wait that nobody will ever release and time out.
        self.training_ended = True
        self.epoch_end(run_context)
        self.step_end(run_context)