# Source code for mindspore.train.callback

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Callback related classes and functions."""

import os
import stat
import shutil
import time
import numpy as np

import mindspore.context as context
from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
from mindspore.train._utils import _make_directory
from mindspore import log as logger
from mindspore._checkparam import check_int_non_negative
from mindspore.common.tensor import Tensor
from .summary.summary_record import _cache_summary_tensor_data


# Names exported by a star-import of this module.
__all__ = ["Callback", "LossMonitor", "TimeMonitor", "ModelCheckpoint", "SummaryStep", "CheckpointConfig", "RunContext"]


# Working directory captured once, at import time; ModelCheckpoint uses it when no directory is given.
_cur_dir = os.getcwd()
# Network handed to the checkpoint save-op callback; assigned through _set_cur_net and reset to None after use.
_cur_net = None
# Directory used for the most recent checkpoint save; rebound by ModelCheckpoint._save_ckpt.
_save_dir = _cur_dir


class _CheckpointManager:
    """
    Track the checkpoint files that belong to one prefix and prune them on request.

    The manager holds an in-memory list of checkpoint paths. The list is rebuilt from
    the directory content by `update_ckpoint_filelist` and shrinks whenever a file is
    removed through this object.
    """
    def __init__(self):
        self._ckpoint_filelist = []

    @property
    def ckpoint_filelist(self):
        """Get all the related checkpoint files managed here."""
        return self._ckpoint_filelist

    @property
    def ckpoint_num(self):
        """Get the number of the related checkpoint files managed here."""
        return len(self._ckpoint_filelist)

    def update_ckpoint_filelist(self, directory, prefix):
        """
        Rebuild the checkpoint file list from the content of `directory`.

        A file is kept when it ends with ".ckpt", starts with `prefix`, and the part
        between the prefix and the extension contains no alphabetic character.

        Args:
            directory (str): Directory to scan.
            prefix (str): Checkpoint file name prefix.
        """
        self._ckpoint_filelist = []
        for filename in os.listdir(directory):
            if os.path.splitext(filename)[-1] != ".ckpt" or not filename.startswith(prefix):
                continue
            # Text between the prefix and the ".ckpt" extension (5 characters).
            mid_name = filename[len(prefix):-5]
            # A letter here means the file belongs to a different, longer prefix.
            if any(char.isalpha() for char in mid_name):
                continue
            self._ckpoint_filelist.append(os.path.join(directory, filename))

    def remove_ckpoint_file(self, file_name):
        """
        Remove the specified checkpoint file from this checkpoint manager and also from the directory.

        Failures are logged as warnings and never raised to the caller.

        Args:
            file_name (str): Path of the checkpoint file to delete.
        """
        try:
            # Make the file writable first so that it can be deleted.
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
            self._ckpoint_filelist.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def remove_oldest_ckpoint_file(self):
        """
        Remove the checkpoint file with the smallest modification time.

        Does nothing when no checkpoint file is currently managed.
        """
        if not self._ckpoint_filelist:
            return
        oldest_file = min(self._ckpoint_filelist, key=os.path.getmtime)
        self.remove_ckpoint_file(oldest_file)

    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
        """
        Keep a single checkpoint among those modified within the last `minutes` minutes.

        Of all managed files whose modification time is less than `minutes` minutes
        before `cur_time`, the one with the oldest modification time is kept and the
        others are removed. Files outside that window are left untouched.

        Args:
            minutes (int): Length of the window, in minutes.
            cur_time (float): Reference time, in seconds since the epoch.
        """
        recent_files = []
        oldest_file = ''
        oldest_time = cur_time
        for ck_file in self._ckpoint_filelist:
            modify_time = os.path.getmtime(ck_file)
            if cur_time - modify_time < 60 * minutes:
                recent_files.append(ck_file)
                if modify_time < oldest_time:
                    oldest_time = modify_time
                    oldest_file = ck_file

        # Iterate over the separate list: remove_ckpoint_file mutates the managed list.
        for ck_file in recent_files:
            if ck_file != oldest_file:
                self.remove_ckpoint_file(ck_file)


def _check_file_name_prefix(file_name_prefix):
    """
    Tell whether `file_name_prefix` is usable as a checkpoint file name prefix.

    The prefix must be a string and must not contain '/'. This naming rule only
    targets Linux.

    Returns:
        bool, True when the prefix is acceptable, False otherwise.
    """
    return isinstance(file_name_prefix, str) and '/' not in file_name_prefix


def _chg_ckpt_file_name_if_same_exist(directory, prefix):
    """
    Return a prefix that does not clash with checkpoint files already in `directory`.

    Existing ".ckpt" files named "<prefix>-..." or "<prefix>_<n>-..." are looked up.
    When any is found, "_<k>" is appended to the prefix, where k is 1 if only
    un-suffixed files exist and (largest n) + 1 otherwise.

    Args:
        directory (str): Directory that will receive the checkpoint files.
        prefix (str): Requested checkpoint file name prefix.

    Returns:
        str, the original prefix or the prefix extended with a numeric suffix.
    """
    suffix_num = 0
    pre_len = len(prefix)
    for filename in os.listdir(directory):
        if os.path.splitext(filename)[-1] != ".ckpt":
            continue
        if not filename.startswith(prefix):
            continue
        rest = filename[pre_len:]
        # An empty remainder means the file name equals the prefix itself; a leading
        # letter means the file belongs to a different, longer prefix.
        if not rest or rest[0].isalpha():
            continue
        index = rest.find("-")
        if index == 0:
            # "<prefix>-...": the bare prefix is already taken.
            suffix_num = max(suffix_num, 1)
        elif index != -1:
            # "<prefix>_<n>-...": skip the separator character and read n.
            num = rest[1:index]
            if num.isdigit():
                suffix_num = max(suffix_num, int(num) + 1)

    if suffix_num != 0:
        prefix = prefix + "_" + str(suffix_num)

    return prefix


class CheckpointConfig:
    """
    The config for model checkpoint.

    Args:
        save_checkpoint_steps (int): Steps to save checkpoint. Default: 1.
        save_checkpoint_seconds (int): Seconds to save checkpoint. Default: 0.
            Can't be used with save_checkpoint_steps at the same time.
        keep_checkpoint_max (int): Maximum number of checkpoint files to keep. Default: 5.
        keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
            Can't be used with keep_checkpoint_max at the same time.

    Raises:
        ValueError: If all the input parameters are None or 0.

    Examples:
        >>> config = CheckpointConfig()
        >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
        >>> model.train(10, dataset, callbacks=ckpoint_cb)
    """
    def __init__(self,
                 save_checkpoint_steps=1,
                 save_checkpoint_seconds=0,
                 keep_checkpoint_max=5,
                 keep_checkpoint_per_n_minutes=0):
        if not any((save_checkpoint_steps, save_checkpoint_seconds,
                    keep_checkpoint_max, keep_checkpoint_per_n_minutes)):
            raise ValueError("The input_param can't be all None or 0")

        # Only values that were actually supplied (truthy) go through validation.
        if save_checkpoint_steps:
            save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
        if save_checkpoint_seconds:
            save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
        if keep_checkpoint_max:
            keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
        if keep_checkpoint_per_n_minutes:
            keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)

        # A step-based trigger takes precedence over the time-based one.
        self._save_checkpoint_steps = save_checkpoint_steps
        if save_checkpoint_steps and save_checkpoint_steps > 0:
            self._save_checkpoint_seconds = None
        else:
            self._save_checkpoint_seconds = save_checkpoint_seconds

        # A count-based retention takes precedence over the time-based one; when
        # neither is given, exactly one checkpoint is kept.
        if keep_checkpoint_max and keep_checkpoint_max > 0:
            self._keep_checkpoint_max = keep_checkpoint_max
            self._keep_checkpoint_per_n_minutes = None
        elif keep_checkpoint_per_n_minutes:
            self._keep_checkpoint_max = keep_checkpoint_max
            self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes
        else:
            self._keep_checkpoint_max = 1
            self._keep_checkpoint_per_n_minutes = keep_checkpoint_per_n_minutes

    @property
    def save_checkpoint_steps(self):
        """Get the value of _save_checkpoint_steps."""
        return self._save_checkpoint_steps

    @property
    def save_checkpoint_seconds(self):
        """Get the value of _save_checkpoint_seconds."""
        return self._save_checkpoint_seconds

    @property
    def keep_checkpoint_max(self):
        """Get the value of _keep_checkpoint_max."""
        return self._keep_checkpoint_max

    @property
    def keep_checkpoint_per_n_minutes(self):
        """Get the value of _keep_checkpoint_per_n_minutes."""
        return self._keep_checkpoint_per_n_minutes

    def get_checkpoint_policy(self):
        """Get the policy of checkpoint as a dict keyed by the constructor argument names."""
        return {
            'save_checkpoint_steps': self._save_checkpoint_steps,
            'save_checkpoint_seconds': self._save_checkpoint_seconds,
            'keep_checkpoint_max': self._keep_checkpoint_max,
            'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
        }
def _set_cur_net(net):
    """
    Record the network that the checkpoint save callback should fill.

    Args:
        net (Cell): train network
    """
    global _cur_net
    _cur_net = net


def _checkpoint_cb_for_save_op(parameter_list):
    """
    The checkpoint callback function for MindSpore.

    Will be executed by checkpoint save op. Fills the parameters into the network
    registered through `_set_cur_net`, then clears that registration.

    Args:
        parameter_list (list): Format is like [{"name",name},{"data",value}] and value type is Tensor.

    Returns:
        bool, True when the parameters were filled into the network, False when
        no network was registered.
    """
    net = _cur_net
    if net is not None:
        logger.info("update parameters in the net.")
        _fill_param_into_net(net, parameter_list)
        _set_cur_net(None)
        return True

    logger.warning("_cur_net is None. parameters are not updated.")
    return False


def _summary_cb_for_save_op(summary_list):
    """
    The summary callback function for MindSpore.

    Will be executed by summary op.

    Args:
        summary_list (list): Format is like [{"name": tag_name, "data": tensor},...] and value is Scalar/Tensor.

    Returns:
        The value returned by `_cache_summary_tensor_data` for `summary_list`.
    """
    return _cache_summary_tensor_data(summary_list)


def _build_callbacks(callbacks):
    """
    Wrap the given callbacks into a single `_ListCallback`.

    Args:
        callbacks (list): Callback functions list, Support None, a single Callback object, or a list.

    Returns:
        _ListCallback, runs the callbacks in the given order.

    Raises:
        TypeError: If `callbacks` is a non-empty tuple, or if any element is not a Callback.
    """
    if not callbacks:
        candidates = []
    elif isinstance(callbacks, tuple):
        raise TypeError("Callbacks cannot be a tuple. Please check it.")
    elif isinstance(callbacks, list):
        candidates = callbacks
    else:
        candidates = [callbacks]

    validated = []
    for candidate in candidates:
        if not (candidate is not None and isinstance(candidate, Callback)):
            raise TypeError("Callback must inheriting base class Callback. Some callback is Wrong. Please check it.")
        validated.append(candidate)

    return _ListCallback(validated)


class _ListCallback:
    """
    Sequential execution of callback functions.

    Each hook forwards the run context to the same-named hook of every wrapped
    callback, in list order.

    Args:
        callbacks (list): Callback functions list.
    """
    def __init__(self, callbacks):
        super().__init__()
        self._callbacks = callbacks

    def _fan_out(self, hook_name, run_context):
        """Call the hook called `hook_name` on every wrapped callback, in order."""
        for callback in self._callbacks:
            getattr(callback, hook_name)(run_context)

    def begin(self, run_context):
        """Called once before network training."""
        self._fan_out("begin", run_context)

    def epoch_begin(self, run_context):
        """Called before each epoch begin."""
        self._fan_out("epoch_begin", run_context)

    def epoch_end(self, run_context):
        """Called after each epoch finished."""
        self._fan_out("epoch_end", run_context)

    def step_begin(self, run_context):
        """Called before each step begin."""
        self._fan_out("step_begin", run_context)

    def step_end(self, run_context):
        """Called after each step finished."""
        self._fan_out("step_end", run_context)

    def end(self, run_context):
        """Called once after network training."""
        self._fan_out("end", run_context)
class Callback:
    """
    Abstract base class used to build a callback function.

    A callback performs some operation at the current step or epoch. Every hook of
    this base class does nothing; subclasses override the hooks they need.

    Examples:
        >>> class Print_info(Callback):
        >>>     def step_end(self, run_context):
        >>>         cb_params = run_context.original_args()
        >>>         print(cb_params.cur_epoch_num)
        >>>         print(cb_params.cur_step_num)
        >>>
        >>> print_cb = Print_info()
        >>> model.train(epoch, dataset, callbacks=print_cb)
    """
    def __init__(self):
        """Create the callback; the base class keeps no state."""

    def begin(self, run_context):
        """
        Called once before the network executing.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_begin(self, run_context):
        """
        Called before each epoch beginning.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def epoch_end(self, run_context):
        """
        Called after each epoch finished.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_begin(self, run_context):
        """
        Called before each step beginning.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def step_end(self, run_context):
        """
        Called after each step finished.

        Args:
            run_context (RunContext): Include some information of the model.
        """

    def end(self, run_context):
        """
        Called once after network training.

        Args:
            run_context (RunContext): Include some information of the model.
        """
class SummaryStep(Callback):
    """
    The summary callback class.

    Args:
        summary (Object): Summary record object.
        flush_step (int): Number of interval steps to execute. Default: 10.

    Raises:
        ValueError: If `flush_step` is not an int (bool is rejected) or is not greater than 0.
    """
    def __init__(self, summary, flush_step=10):
        super().__init__()
        is_valid_step = isinstance(flush_step, int) and not isinstance(flush_step, bool) and flush_step > 0
        if not is_valid_step:
            raise ValueError("`flush_step` should be int and greater than 0")
        self._flush_step = flush_step
        self._summary = summary

    def step_end(self, run_context):
        """
        Save summary.

        The summary is recorded only on steps that are a multiple of `flush_step`.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        if cb_params.cur_step_num % self._flush_step != 0:
            return
        self._summary.record(cb_params.cur_step_num, cb_params.train_network)

    @property
    def summary_file_name(self):
        """Return the `full_file_name` of the wrapped summary object."""
        return self._summary.full_file_name
class _MissingCallbackParamError(AttributeError, KeyError):
    """
    Raised when a callback parameter is read as an attribute but was never set.

    It is an AttributeError so that `hasattr`, `getattr` with a default, `copy` and
    `pickle` treat the name as absent, and a KeyError so that callers which caught
    the KeyError raised by the plain dict lookup keep working.
    """


class _InternalCallbackParam(dict):
    """
    Internal callback object's parameters.

    A dict whose items can also be read and written as attributes.
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, so dict methods still win.
        try:
            return self[key]
        except KeyError:
            raise _MissingCallbackParamError(key) from None

    def __setattr__(self, key, value):
        self[key] = value
class RunContext:
    """
    Provides information about the model.

    Holds the original request passed to the model function. Callback objects can
    stop the loop by calling request_stop() of run_context.

    Args:
        original_args (dict): Holding the related information of model etc.

    Raises:
        TypeError: If `original_args` is not a dict.
    """
    def __init__(self, original_args):
        if isinstance(original_args, dict):
            self._original_args = original_args
            self._stop_requested = False
        else:
            raise TypeError("The arg of RunContext should be dict type.")

    def original_args(self):
        """
        Get the _original_args object.

        Returns:
            dict, the very object that was passed to the constructor.
        """
        return self._original_args

    def request_stop(self):
        """
        Sets stop requested during training.

        Callbacks can use this function to request stop of iterations.
        model.train() checks whether this is called or not.
        """
        self._stop_requested = True

    def get_stop_requested(self):
        """
        Returns whether a stop is requested or not.

        Returns:
            bool, if true, model.train() stops iterations.
        """
        return self._stop_requested
class ModelCheckpoint(Callback):
    """
    The checkpoint callback class.

    It is called to combine with train process and save the model and network
    parameters after training.

    Args:
        prefix (str): Checkpoint files names prefix. Default: "CKP".
        directory (str): Folder path into which checkpoint files will be saved. Default: None.
        config (CheckpointConfig): Checkpoint strategy config. Default: None.

    Raises:
        ValueError: If the prefix is invalid.
        TypeError: If the config is not CheckpointConfig type.
    """
    def __init__(self, prefix='CKP', directory=None, config=None):
        super().__init__()
        self._latest_ckpt_file_name = ""
        self._init_time = time.time()
        self._last_time = time.time()
        self._last_time_for_keep = time.time()
        self._last_triggered_step = 0

        if not _check_file_name_prefix(prefix):
            raise ValueError("Prefix {} for checkpoint file name invalid, "
                             "please check and correct it and then continue.".format(prefix))
        self._prefix = prefix

        # Without an explicit directory, the working directory captured at import time is used.
        self._directory = _make_directory(directory) if directory else _cur_dir

        if config is None:
            config = CheckpointConfig()
        elif not isinstance(config, CheckpointConfig):
            raise TypeError("config should be CheckpointConfig type.")
        self._config = config

        self._manager = _CheckpointManager()
        # Pick a prefix that does not clash with checkpoint files already in the directory.
        self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
        self._graph_saved = False

    def step_end(self, run_context):
        """
        Save the checkpoint at the end of step.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        # The graph file is written on the first call only.
        if not self._graph_saved:
            graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
            _save_graph(cb_params.train_network, graph_file_name)
            self._graph_saved = True
        self._save_ckpt(cb_params)

    def end(self, run_context):
        """
        Save the last checkpoint after training finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        self._save_ckpt(cb_params, True)

        from mindspore.parallel._cell_wrapper import destroy_allgather_cell
        destroy_allgather_cell()

    def _check_save_ckpt(self, cb_params, force_to_save):
        """
        Check whether save checkpoint files or not.

        Step-based saving is checked first, then time-based saving; when neither is
        configured, a checkpoint is due only on the final step.
        """
        # NOTE(review): the source this was restored from had lost its indentation; the
        # final-step fallback is read as the `else` of the outer if/elif chain - confirm.
        config = self._config
        if config.save_checkpoint_steps and config.save_checkpoint_steps > 0:
            next_due_step = self._last_triggered_step + config.save_checkpoint_steps
            if cb_params.cur_step_num >= next_due_step or force_to_save is True:
                return True
            return False

        if config.save_checkpoint_seconds and config.save_checkpoint_seconds > 0:
            self._cur_time = time.time()
            elapsed = self._cur_time - self._last_time
            if elapsed > config.save_checkpoint_seconds or force_to_save is True:
                self._last_time = self._cur_time
                return True
            return False

        if cb_params.cur_step_num == cb_params.step_num:
            return True
        return False

    def _save_ckpt(self, cb_params, force_to_save=False):
        """
        Save checkpoint files.

        Nothing happens when the current step already triggered a save or when no
        save is due. Otherwise older files are pruned according to the config, then
        the checkpoint is written to a per-process temporary file and moved to its
        final name.
        """
        global _save_dir

        if cb_params.cur_step_num == self._last_triggered_step:
            return

        should_save = self._check_save_ckpt(cb_params, force_to_save)
        step_num_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        if not should_save:
            return

        cur_ckpoint_file = "".join(
            [self._prefix, "-", str(cb_params.cur_epoch_num), "_", str(step_num_in_epoch), ".ckpt"])

        # Refresh the list of checkpoint files found on disk.
        self._manager.update_ckpoint_filelist(self._directory, self._prefix)

        # Retention: either cap the number of files, or thin out recent files by time.
        max_kept = self._config.keep_checkpoint_max
        keep_minutes = self._config.keep_checkpoint_per_n_minutes
        if max_kept and 0 < max_kept <= self._manager.ckpoint_num:
            self._manager.remove_oldest_ckpoint_file()
        elif keep_minutes and keep_minutes > 0:
            self._cur_time_for_keep = time.time()
            since_last_keep = self._cur_time_for_keep - self._last_time_for_keep
            if since_last_keep < keep_minutes * 60:
                self._manager.keep_one_ckpoint_per_minutes(keep_minutes, self._cur_time_for_keep)

        # Write to a per-process temporary name first, then move to the final name.
        _save_dir = self._directory
        cur_file = os.path.join(self._directory, cur_ckpoint_file)
        tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
        gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
        self._last_time_for_keep = time.time()
        self._last_triggered_step = cb_params.cur_step_num

        if context.get_context("enable_ge"):
            _set_cur_net(cb_params.train_network)
            cb_params.train_network.exec_checkpoint_graph()

        _exec_save_checkpoint(cb_params.train_network, gen_file)

        if os.path.exists(gen_file):
            shutil.move(gen_file, cur_file)
        self._latest_ckpt_file_name = cur_file

    @property
    def latest_ckpt_file_name(self):
        """Return the latest checkpoint path and file name."""
        return self._latest_ckpt_file_name
class LossMonitor(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, it will terminate training.

    Note:
        If per_print_times is 0 do not print loss.

    Args:
        per_print_times (int): Print loss every times. Default: 1.

    Raises:
        ValueError: If per_print_times is not int or less than zero.
    """
    def __init__(self, per_print_times=1):
        super(LossMonitor, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        """
        Check the loss of the finished step and print it when due.

        Args:
            run_context (RunContext): Context of the train running.

        Raises:
            ValueError: If the loss is NAN or INF.
        """
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        # When the network returns several outputs, the first one is taken as the loss.
        if isinstance(loss, (tuple, list)) and loss:
            if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, Tensor):
            # Convert once and reuse the array for both the check and the mean.
            loss_array = loss.asnumpy()
            if isinstance(loss_array, np.ndarray):
                loss = np.mean(loss_array)

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        # np.mean yields a NumPy scalar; only np.float64 subclasses the builtin float,
        # so np.floating is needed to catch float16/float32 losses as well.
        if isinstance(loss, (float, np.floating)) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training."
                             .format(cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
class TimeMonitor(Callback):
    """
    Print the wall-clock duration of every step and every epoch, in milliseconds.

    Args:
        data_size (int): Divisor used to turn the epoch duration into a per-step duration.
    """
    def __init__(self, data_size):
        super().__init__()
        self.data_size = data_size

    def epoch_begin(self, run_context):
        """Remember when the epoch started."""
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """Print the epoch duration and that duration divided by `data_size`."""
        elapsed_ms = (time.time() - self.epoch_time) * 1000
        per_step_ms = elapsed_ms / self.data_size
        print("epoch time: {0}, per step time: {1}".format(elapsed_ms, per_step_ms), flush=True)

    def step_begin(self, run_context):
        """Remember when the step started."""
        self.step_time = time.time()

    def step_end(self, run_context):
        """Print the step duration."""
        elapsed_ms = (time.time() - self.step_time) * 1000
        print('step time', elapsed_ms, flush=True)