Source code for mindvision.engine.callback.monitor

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" Init for base architecture engine monitor register. """

import time
import os
import stat
from typing import Optional, Union, Iterable
import numpy as np

import mindspore as ms
from mindspore import save_checkpoint
from mindspore.train.callback import Callback

from mindvision.check_param import Rel, Validator as validator

__all__ = ["LossMonitor", "ValAccMonitor"]


class LossMonitor(Callback):
    """
    Loss Monitor for classification.

    Prints the loss, step time and (when available) learning rate every
    `per_print_times` steps, and an epoch summary at the end of each epoch.

    Args:
        lr_init (Union[float, Iterable], optional): The learning rate schedule. A list, tuple or
            numpy array is indexed by the global step; any other non-None value is printed as is.
            When None, the learning rate is omitted from the log line. Default: None.
        per_print_times (int): Every how many steps to print the log information. Default: 1.

    Raises:
        ValueError: In `step_end`, if the loss is NaN or Inf.

    Examples:
        >>> from mindvision.engine.callback import LossMonitor
        >>> lr = [0.01, 0.008, 0.006, 0.005, 0.002]
        >>> monitor = LossMonitor(lr_init=lr, per_print_times=100)
    """

    def __init__(self,
                 lr_init: Optional[Union[float, Iterable]] = None,
                 per_print_times: int = 1):
        super(LossMonitor, self).__init__()
        self.lr_init = lr_init
        self.per_print_times = per_print_times
        self.last_print_time = 0
        # Initialised here so that step_end/epoch_end never hit a missing attribute
        # if they are invoked before the matching *_begin hook.
        self.losses = []
        self.epoch_time = time.time()
        self.step_time = time.time()

    # pylint: disable=unused-argument
    def epoch_begin(self, run_context):
        """
        Record time at the beginning of epoch.

        Args:
            run_context (RunContext): Context of the process running.
        """
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        """
        Print training info at the end of epoch.

        Args:
            run_context (RunContext): Context of the process running.
        """
        callback_params = run_context.original_args()
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / callback_params.batch_num
        # np.mean([]) warns and returns nan; produce the nan directly instead.
        avg_loss = np.mean(self.losses) if self.losses else float("nan")
        print(f"Epoch time: {epoch_mseconds:5.3f} ms, "
              f"per step time: {per_step_mseconds:5.3f} ms, "
              f"avg loss: {avg_loss:5.3f}", flush=True)

    # pylint: disable=unused-argument
    def step_begin(self, run_context):
        """
        Record time at the beginning of step.

        Args:
            run_context (RunContext): Context of the process running.
        """
        self.step_time = time.time()

    def step_end(self, run_context):
        """
        Print training info at the end of step.

        Args:
            run_context (RunContext): Context of the process running.

        Raises:
            ValueError: If the loss of the current step is NaN or Inf.
        """
        callback_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        loss = callback_params.net_outputs

        # The network may return several outputs; the first one is taken as the loss.
        if isinstance(loss, (tuple, list)) and isinstance(loss[0], ms.Tensor):
            loss = loss[0]

        if isinstance(loss, ms.Tensor):
            loss = np.mean(loss.asnumpy())

        self.losses.append(loss)
        cur_step_in_epoch = (callback_params.cur_step_num - 1) % callback_params.batch_num + 1

        # Boundary check. np.floating is included because np.mean on a float32/float16
        # array returns a numpy scalar that is not an instance of the builtin float.
        if isinstance(loss, (float, np.floating)) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("Invalid loss, terminate training.")

        if (callback_params.cur_step_num - self.last_print_time) >= self.per_print_times:
            self.last_print_time = callback_params.cur_step_num
            message = (f"Epoch:[{(callback_params.cur_epoch_num - 1):3d}/{callback_params.epoch_num:3d}], "
                       f"step:[{cur_step_in_epoch:5d}/{callback_params.batch_num:5d}], "
                       f"loss:[{loss:5.3f}/{np.mean(self.losses):5.3f}], "
                       f"time:{step_mseconds:5.3f} ms")
            lr_output = self._current_lr(callback_params.cur_step_num)
            # lr_init is optional; only report the learning rate when one was supplied.
            if lr_output is not None:
                message += f", lr:{lr_output:5.5f}"
            print(message, flush=True)

    def _current_lr(self, cur_step_num):
        """Return the learning rate for the 1-based global step, or None when `lr_init` is None."""
        if self.lr_init is None:
            return None
        if isinstance(self.lr_init, (list, tuple, np.ndarray)):
            return self.lr_init[cur_step_num - 1]
        return self.lr_init
class ValAccMonitor(Callback):
    """
    Monitors the train loss and the validation accuracy, after each epoch saves the
    best checkpoint file with highest validation accuracy.

    Args:
        model (ms.Model): The model to monitor.
        dataset_val (ms.dataset): The dataset that the model needs.
        num_epochs (int): The number of epochs.
        interval (int): Every how many epochs to validate and print information. Default: 1.
        eval_start_epoch (int): From which time to validate. Default: 1.
        save_best_ckpt (bool): Whether to save the checkpoint file which performs best. Default: True.
        ckpt_directory (str): The path to save checkpoint files. Default: './'.
        best_ckpt_name (str): The file name of the checkpoint file which performs best. Default: 'best.ckpt'.
        metric_name (str): The name of metric for model evaluation. Default: 'Accuracy'.
        dataset_sink_mode (bool): Whether to use the dataset sinking mode. Default: True.

    Raises:
        ValueError: If `interval` is less than 1.

    Examples:
        >>> import mindspore as ms
        >>> import mindspore.nn as nn
        >>> import mindspore.dataset as ds
        >>> from mindvision.classification.models import lenet
        >>> from mindvision.classification.dataset import Mnist
        >>>
        >>> net = lenet()
        >>> opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.001, momentum=0.9)
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True,reduction='mean')
        >>> model = ms.Model(net, loss,opt,metrics={"Accuracy":nn.Accuracy()})
        >>> dataset_val = Mnist("./mnist", split="test", batch_size=32, resize=32, download=True)
        >>> dataset_val = dataset_val.run()
        >>> monitor = ValAccMonitor(model, dataset_val, num_epochs=10)
    """

    def __init__(self,
                 model: ms.Model,
                 dataset_val: ms.dataset,
                 num_epochs: int,
                 interval: int = 1,
                 eval_start_epoch: int = 1,
                 save_best_ckpt: bool = True,
                 ckpt_directory: str = "./",
                 best_ckpt_name: str = "best.ckpt",
                 metric_name: str = "Accuracy",
                 dataset_sink_mode: bool = True):
        super(ValAccMonitor, self).__init__()
        self.model = model
        self.dataset_val = dataset_val
        self.num_epochs = num_epochs
        self.eval_start_epoch = eval_start_epoch
        self.save_best_ckpt = save_best_ckpt
        self.metric_name = metric_name
        self.interval = validator.check_int(interval, 1, Rel.GE, "interval")
        self.best_res = 0
        self.dataset_sink_mode = dataset_sink_mode

        # exist_ok avoids a check-then-create race when several processes start together.
        os.makedirs(ckpt_directory, exist_ok=True)

        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)

    def apply_eval(self):
        """Model evaluation, return validation accuracy."""
        return self.model.eval(self.dataset_val, dataset_sink_mode=self.dataset_sink_mode)[self.metric_name]

    def epoch_end(self, run_context):
        """
        After epoch, print train loss and val accuracy,
        save the best ckpt file with highest validation accuracy.

        Args:
            run_context (RunContext): Context of the process running.
        """
        callback_params = run_context.original_args()
        cur_epoch = callback_params.cur_epoch_num

        # Only validate from eval_start_epoch onwards, once every `interval` epochs.
        if cur_epoch < self.eval_start_epoch or (cur_epoch - self.eval_start_epoch) % self.interval != 0:
            return

        # Validation result
        res = self.apply_eval()
        train_loss = self._scalar_loss(callback_params.net_outputs)

        print("-" * 20)
        print(f"Epoch: [{cur_epoch: 3d} / {self.num_epochs: 3d}], "
              f"Train Loss: [{train_loss:5.3f}], "
              f"{self.metric_name}: {res: 5.3f}")

        # Save the best ckpt file
        if res >= self.best_res:
            self.best_res = res
            if self.save_best_ckpt:
                if os.path.exists(self.best_ckpt_path):
                    self._remove_ckpt_file(self.best_ckpt_path)
                save_checkpoint(callback_params.train_network, self.best_ckpt_path)

    # pylint: disable=unused-argument
    def end(self, run_context):
        """
        Print the best validation accuracy after network training.

        Args:
            run_context (RunContext): Context of the process running.
        """
        print("=" * 80)
        summary = f"End of validation the best {self.metric_name} is: {self.best_res: 5.3f}"
        # Only point at the checkpoint file when saving was actually enabled.
        if self.save_best_ckpt:
            summary += f", save the best ckpt file in {self.best_ckpt_path}"
        print(summary, flush=True)

    @staticmethod
    def _scalar_loss(net_outputs):
        """Reduce the network outputs to one printable loss value, as LossMonitor.step_end does."""
        loss = net_outputs
        # The network may return several outputs; the first one is taken as the loss.
        if isinstance(loss, (tuple, list)) and isinstance(loss[0], ms.Tensor):
            loss = loss[0]
        if isinstance(loss, ms.Tensor):
            loss = np.mean(loss.asnumpy())
        return loss

    @staticmethod
    def _remove_ckpt_file(file_name):
        """Make the checkpoint file writable, then delete it."""
        os.chmod(file_name, stat.S_IWRITE)
        os.remove(file_name)