Custom Debugging Information
Overview
This section describes how to use the customized capabilities provided by MindSpore, such as Callback, metrics, the Print operator, and log printing, to help you quickly debug the training network.
Introduction to Callback
Although the name suggests a callback function, in MindSpore a Callback is not a function but a class. You can use callbacks to observe the internal status and related information of the network during training, or to perform specific actions in a specific period.
For example, you can monitor the loss, save model parameters, dynamically adjust parameters, and terminate training tasks in advance.
Callback Capabilities of MindSpore
MindSpore provides the Callback capability to allow users to insert customized operations in a specific phase of training or inference, including:
Callback classes provided by the MindSpore framework, such as ModelCheckpoint, LossMonitor, and SummaryCollector.
User-customized Callback classes supported by MindSpore.
Usage: transfer the Callback object in the model.train method. It can be a Callback list, for example:
import mindspore as ms
ckpt_cb = ms.ModelCheckpoint()
loss_cb = ms.LossMonitor()
summary_cb = ms.SummaryCollector(summary_dir='./summary_dir')
model.train(epoch, dataset, callbacks=[ckpt_cb, loss_cb, summary_cb])
ModelCheckpoint can save model parameters for retraining or inference.
LossMonitor can output the loss in the log for easy viewing; it also monitors the changes in the loss value during training and terminates the training when the loss value is NaN or Inf.
SummaryCollector can save the training information to files for subsequent visualization.
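These callbacks also accept configuration parameters. The following is a minimal sketch of commonly used options; the parameter values are illustrative only, and CheckpointConfig is assumed to be exposed from the same mindspore namespace as ModelCheckpoint:
import mindspore as ms

# Save a checkpoint every 1875 steps and keep at most 10 checkpoint files (illustrative values).
config_ck = ms.CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
ckpt_cb = ms.ModelCheckpoint(prefix='net', directory='./checkpoint', config=config_ck)
# Print the loss once every 100 steps instead of every step.
loss_cb = ms.LossMonitor(per_print_times=100)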
During the training process, the Callback list executes the Callback functions in the defined order. Therefore, dependencies between Callback objects need to be considered when the list is defined.
Custom Callback
You can customize a Callback based on the Callback base class as required.
The Callback base class is defined as follows:
class Callback():
    """Callback base class"""
    def begin(self, run_context):
        """Called once before the network executing."""
        pass

    def epoch_begin(self, run_context):
        """Called before each epoch beginning."""
        pass

    def epoch_end(self, run_context):
        """Called after each epoch finished."""
        pass

    def step_begin(self, run_context):
        """Called before each step beginning."""
        pass

    def step_end(self, run_context):
        """Called after each step finished."""
        pass

    def end(self, run_context):
        """Called once after network training."""
        pass
The Callback can record important information during training and transfer the information to the Callback object through the dictionary variable returned by RunContext.original_args(). You can obtain related attributes from each custom Callback and perform customized operations. You can also customize other variables and transfer them to the RunContext.original_args() object.
The main attributes of RunContext.original_args() are as follows:
loss_fn: Loss function
optimizer: Optimizer
train_dataset: Training dataset
epoch_num: Number of training epochs
batch_num: Number of batches in an epoch
train_network: Training network
cur_epoch_num: Number of current epochs
cur_step_num: Number of current steps
parallel_mode: Parallel mode
list_callback: All callback functions
net_outputs: Network output results
…
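For instance, a minimal sketch of a callback that reads a few of these attributes at the end of each epoch (the class name EpochSummary is illustrative only):
import mindspore as ms

class EpochSummary(ms.Callback):
    def epoch_end(self, run_context):
        # Obtain the dictionary of training attributes and print a few of them.
        cb_params = run_context.original_args()
        print("epoch:", cb_params.cur_epoch_num,
              "step:", cb_params.cur_step_num,
              "loss:", cb_params.net_outputs)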
You can inherit the Callback base class to customize a callback object.
Here are two examples to further explain the usage of custom Callback.
Custom Callback sample code: https://gitee.com/mindspore/docs/blob/r1.8/docs/sample_code/debugging_info/custom_callback.py
Terminate training within the specified time.
import time
import mindspore as ms

class StopAtTime(ms.Callback):
    def __init__(self, run_time):
        super(StopAtTime, self).__init__()
        self.run_time = run_time*60

    def begin(self, run_context):
        cb_params = run_context.original_args()
        cb_params.init_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_num = cb_params.cur_epoch_num
        step_num = cb_params.cur_step_num
        loss = cb_params.net_outputs
        cur_time = time.time()
        if (cur_time - cb_params.init_time) > self.run_time:
            print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
            run_context.request_stop()
The implementation principle is: you can use the run_context.original_args method to obtain the cb_params dictionary, which contains the main attribute information described above. In addition, you can modify and add values in the dictionary. In the preceding example, an init_time object is defined in begin and transferred to the cb_params dictionary. A decision is made at each step_end: when the training time exceeds the configured time threshold, a training termination signal is sent to the run_context to terminate the training in advance, and the current values of epoch, step, and loss are printed.
Save the checkpoint file with the highest accuracy during training.
import mindspore as ms

class SaveCallback(ms.Callback):
    def __init__(self, eval_model, ds_eval):
        super(SaveCallback, self).__init__()
        self.model = eval_model
        self.ds_eval = ds_eval
        self.acc = 0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        result = self.model.eval(self.ds_eval)
        if result['accuracy'] > self.acc:
            self.acc = result['accuracy']
            file_name = str(self.acc) + ".ckpt"
            ms.save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
            print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)
The specific implementation principle is: define a Callback object whose initialization receives the model object and the ds_eval (validation dataset). Verify the accuracy of the model in the step_end phase. When the accuracy is the current highest, automatically trigger the save-checkpoint method to save the current parameters.
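Both custom callbacks are attached through model.train in the same way as the built-in ones. A minimal sketch, assuming epoch, dataset, model, eval_model, and ds_eval are defined as in the preceding snippets:
# Stop training after roughly 10 minutes and keep the best checkpoint on the validation set.
stop_cb = StopAtTime(run_time=10)
save_cb = SaveCallback(eval_model, ds_eval)
model.train(epoch, dataset, callbacks=[stop_cb, save_cb])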
MindSpore Metrics Introduction
After the training is complete, you can use metrics to evaluate the training result.
MindSpore provides multiple metrics, such as accuracy, loss, precision, recall, and F1.
You can define a metrics dictionary object that contains multiple metrics, transfer it to the model object, and use the model.eval function to verify the training result.
import mindspore as ms
import mindspore.nn as nn
metrics = {
'accuracy': nn.Accuracy(),
'loss': nn.Loss(),
'precision': nn.Precision(),
'recall': nn.Recall(),
'f1_score': nn.F1()
}
model = ms.Model(network=net, loss_fn=net_loss, optimizer=net_opt, metrics=metrics)
result = model.eval(ds_eval)
The model.eval method returns a dictionary that contains the metrics and the results computed by the metrics.
The Callback function can also be used in the eval process, and the user can call the related APIs or customize the Callback method to achieve the desired function.
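For example, a Callback list can be passed to model.eval in the same way as to model.train; a minimal sketch, assuming ds_eval and loss_cb are defined as above:
result = model.eval(ds_eval, callbacks=[loss_cb])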
You can also define your own metrics class by inheriting the Metric base class and rewriting the clear, update, and eval methods.
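For instance, the following is a minimal sketch of a custom metric that computes the mean absolute error; the class name MyMAE and its internal variables are illustrative only, and _convert_data is the helper provided by the Metric base class to convert inputs to NumPy arrays:
import numpy as np
import mindspore.nn as nn

class MyMAE(nn.Metric):
    def __init__(self):
        super(MyMAE, self).__init__()
        self.clear()

    def clear(self):
        """Initialize the internal accumulation variables."""
        self._abs_error_sum = 0
        self._samples_num = 0

    def update(self, *inputs):
        """Accumulate the absolute error between predictions and labels."""
        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])
        self._abs_error_sum += np.abs(y_pred - y).sum()
        self._samples_num += y.shape[0]

    def eval(self):
        """Return the mean absolute error over all updated samples."""
        return self._abs_error_sum / self._samples_num
Such a custom metric can then be added to the metrics dictionary passed to Model in the same way as the built-in ones.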
The Accuracy operator is used as an example to describe the internal implementation principle.
The Accuracy inherits the EvaluationBase base class and rewrites the preceding three methods.
The clear method initializes related calculation parameters in the class.
The update method accepts the predicted value and the label value and updates the internal variables of Accuracy.
The eval method calculates related indicators and returns the calculation result.
By invoking the eval method of Accuracy, you will obtain the calculation result.
You can understand how Accuracy runs by using the following code:
import mindspore as ms
from mindspore.nn import Accuracy
import numpy as np
x = ms.Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = ms.Tensor(np.array([1, 0, 1]))
metric = Accuracy()
metric.clear()
metric.update(x, y)
accuracy = metric.eval()
print('Accuracy is ', accuracy)
The output is as follows:
Accuracy is 0.6667
MindSpore Print Operator Introduction
The MindSpore-developed Print operator is used to print the tensors or character strings input by users. Multiple strings, multiple tensors, and a combination of tensors and strings are supported, separated by commas. The Print operator is supported in the Ascend and GPU environments.
The method of using the MindSpore Print operator is the same as that of other operators: declare the operator in __init__ of the network and call it in construct. The specific usage example and output are as follows:
import numpy as np
import mindspore as ms
import mindspore.ops as ops
import mindspore.nn as nn
ms.set_context(mode=ms.GRAPH_MODE)
class PrintDemo(nn.Cell):
    def __init__(self):
        super(PrintDemo, self).__init__()
        self.print = ops.Print()

    def construct(self, x, y):
        self.print('print Tensor x and Tensor y:', x, y)
        return x
x = ms.Tensor(np.ones([2, 1]).astype(np.int32))
y = ms.Tensor(np.ones([2, 2]).astype(np.int32))
net = PrintDemo()
output = net(x, y)
The output is as follows:
print Tensor x and Tensor y:
Tensor(shape=[2, 1], dtype=Int32, value=
[[1]
[1]])
Tensor(shape=[2, 2], dtype=Int32, value=
[[1 1]
[1 1]])
Running Data Recorder
Running Data Recorder (RDR) is a feature MindSpore provides to record data while the training program is running. If a running exception occurs in MindSpore, the pre-recorded data is automatically exported to assist in locating the cause of the exception. Different exceptions export different data; for instance, when a Run task error exception occurs, the computational graph, the execution sequence of the graph, memory allocation, and other information are exported to assist in locating the cause of the exception.
Not all running exceptions export data; only some exceptions are currently supported.
Only the data collection of CPU/Ascend/GPU in the training scenario with graph mode is supported.
Usage
Set RDR by Configuration File
Create the configuration file mindspore_config.json:
{
    "rdr": {
        "enable": true,
        "mode": 1,
        "path": "/path/to/rdr/dir"
    }
}
enable: Controls whether RDR is enabled.
mode: Controls the RDR data exporting mode. When mode is set to 1, RDR exports data only in the exceptional scenario. When mode is set to 2, RDR exports data in both exceptional and normal scenarios.
path: Sets the path to which RDR stores data. Only an absolute path is supported.
Configure RDR via context:
import mindspore as ms
ms.set_context(env_config_path="./mindspore_config.json")
Set RDR by Environment Variables
Set export MS_RDR_ENABLE=1 to enable RDR, set export MS_RDR_MODE=1 or export MS_RDR_MODE=2 to control the exporting mode for RDR data, and set the root directory for recording data with export MS_RDR_PATH=/path/to/root/dir. The final directory for recording data is /path/to/root/dir/rank_{RANK_ID}/rdr/. RANK_ID is the unique ID of a card in multi-card training; the single-card scenario defaults to RANK_ID=0.
The configuration file set by the user takes precedence over the environment variables.
Exception Handling
Suppose MindSpore is used for training on Ascend 910 and a Run task error exception occurs during training.
When we go to the directory for recording data, we can see several files in this directory, each representing a kind of data. For example, hwopt_d_before_graph_0.ir is a computational graph file. You can open this file with a text tool to view the computational graph and analyze whether it meets your expectations.
Diagnosis Handling
When RDR is enabled and the environment variable export MS_RDR_MODE=2 is set, RDR runs in diagnostic mode. After graph compilation is complete, the same files that are saved during exception handling can also be found in the RDR export directory.
Memory Reuse
Memory reuse lets different tensors share the same part of memory to reduce memory overhead and support larger networks. After the function is turned off, each tensor has its own independent memory space, and tensors share no memory.
The MindSpore memory reuse function is turned on by default, and it can be manually turned off and on in the following ways.
Usage
Construct the configuration file mindspore_config.json:
{
    "sys": {
        "mem_reuse": true
    }
}
mem_reuse: Controls whether the memory reuse function is turned on. When it is set to true, the memory reuse function is turned on; when it is set to false, the memory reuse function is turned off.
Configure the memory reuse function through context:
import mindspore as ms
ms.set_context(env_config_path="./mindspore_config.json")