Custom Debugging Information
Overview
This section describes how to use the customized capabilities provided by MindSpore, such as callbacks, metrics, the Print operator, and log printing, to help you quickly debug the training network.
Introduction to Callback
Callback here is not a function but a class. You can use a callback to observe the internal status and related information of the network during training, or to perform specific actions at specific points in training. For example, you can monitor the loss, save model parameters, dynamically adjust parameters, and terminate training early.
Callback Capabilities of MindSpore
MindSpore provides callback capabilities that allow users to insert customized operations at specific phases of training or inference, including:
Callback classes provided by the MindSpore framework, such as ModelCheckpoint, LossMonitor, and SummaryCollector
Custom callback classes
Usage: pass callback objects to the callbacks parameter of the model.train method. You can pass a single callback object or a list of them, for example:
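from mindspore.train.callback import ModelCheckpoint, LossMonitor, SummaryCollector

# model, epoch, and dataset are assumed to be defined earlier
ckpt_cb = ModelCheckpoint()
loss_cb = LossMonitor()
summary_cb = SummaryCollector(summary_dir='./summary_dir')
model.train(epoch, dataset, callbacks=[ckpt_cb, loss_cb, summary_cb])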
ModelCheckpoint saves model parameters for retraining or inference.
LossMonitor prints loss information to the logs. It also monitors the loss value during training and terminates training when the loss becomes NaN or Inf.
SummaryCollector saves training information to files for later use.
During training, the callbacks in the list are executed in the order in which they are defined. Therefore, take any dependencies between callbacks into account when defining the list.
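The following minimal sketch, in plain Python without MindSpore, illustrates why callback order matters. The classes RecordLoss and ReportLoss and the function run_callbacks are hypothetical, and the shared cb_params dictionary only stands in for the object MindSpore passes to callbacks. ReportLoss reads a value that RecordLoss writes, so it sees the current step's loss only when it comes later in the list.

```python
class RecordLoss:
    """Hypothetical callback that stores the latest loss in the shared params."""

    def step_end(self, cb_params):
        cb_params["last_loss"] = cb_params["net_outputs"]


class ReportLoss:
    """Hypothetical callback that depends on RecordLoss having run first."""

    def __init__(self):
        self.reports = []

    def step_end(self, cb_params):
        # .get() returns None if RecordLoss has never run yet
        self.reports.append(cb_params.get("last_loss"))


def run_callbacks(callbacks, losses):
    """After every step, call step_end on each callback in list order."""
    cb_params = {}
    for loss in losses:
        cb_params["net_outputs"] = loss
        for cb in callbacks:
            cb.step_end(cb_params)


good = ReportLoss()
run_callbacks([RecordLoss(), good], [0.9, 0.5])
bad = ReportLoss()
run_callbacks([bad, RecordLoss()], [0.9, 0.5])
print(good.reports)  # [0.9, 0.5]
print(bad.reports)   # [None, 0.9]
```

With the wrong order, ReportLoss always lags one step behind.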
Custom Callback
You can define a custom callback by inheriting the Callback base class as required.
The callback base class is defined as follows:
class Callback():
    """Callback base class"""
    def begin(self, run_context):
        """Called once before the network executing."""
        pass

    def epoch_begin(self, run_context):
        """Called before each epoch beginning."""
        pass

    def epoch_end(self, run_context):
        """Called after each epoch finished."""
        pass

    def step_begin(self, run_context):
        """Called before each step beginning."""
        pass

    def step_end(self, run_context):
        """Called after each step finished."""
        pass

    def end(self, run_context):
        """Called once after network training."""
        pass
The callback mechanism records important information during training and passes it to the callback object through the dictionary variable cb_params. In each custom callback, you can read the related attributes and perform customized operations. You can also define your own variables and add them to the cb_params object.
The main attributes of cb_params are as follows:
loss_fn: loss function
optimizer: optimizer
train_dataset: training dataset
cur_epoch_num: current epoch number
cur_step_num: current step number
batch_num: number of batches in an epoch
…
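To make the roles of cb_params and run_context more concrete, here is a simplified pure-Python sketch of a training loop. The loop calls the hooks of the base class above and updates cur_epoch_num, cur_step_num, and batch_num. It is not MindSpore's implementation. RunContext, HookBase, train_loop, and StopAfterSteps are hypothetical, and the sketch assumes the step counter is cumulative across epochs.

```python
from types import SimpleNamespace


class RunContext:
    """Hypothetical stand-in for the run_context object passed to callbacks."""

    def __init__(self, cb_params):
        self._cb_params = cb_params
        self._stop_requested = False

    def original_args(self):
        return self._cb_params

    def request_stop(self):
        self._stop_requested = True

    def get_stop_requested(self):
        return self._stop_requested


class HookBase:
    """Hypothetical no-op base class with the same hook names as Callback."""

    def begin(self, run_context): pass
    def epoch_begin(self, run_context): pass
    def epoch_end(self, run_context): pass
    def step_begin(self, run_context): pass
    def step_end(self, run_context): pass
    def end(self, run_context): pass


def train_loop(callbacks, epochs, batch_num):
    """Call each hook at its phase and stop early if a callback requests it."""
    cb_params = SimpleNamespace(cur_epoch_num=0, cur_step_num=0, batch_num=batch_num)
    ctx = RunContext(cb_params)
    for cb in callbacks:
        cb.begin(ctx)
    for epoch in range(1, epochs + 1):
        cb_params.cur_epoch_num = epoch
        for cb in callbacks:
            cb.epoch_begin(ctx)
        for _ in range(batch_num):
            cb_params.cur_step_num += 1  # assumption: counts across epochs
            for cb in callbacks:
                cb.step_begin(ctx)
            # the real forward/backward pass would run here
            for cb in callbacks:
                cb.step_end(ctx)
            if ctx.get_stop_requested():
                break
        for cb in callbacks:
            cb.epoch_end(ctx)
        if ctx.get_stop_requested():
            break
    for cb in callbacks:
        cb.end(ctx)
    return cb_params


class StopAfterSteps(HookBase):
    """Hypothetical callback that requests a stop after max_steps steps."""

    def __init__(self, max_steps):
        self.max_steps = max_steps

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        if cb_params.cur_step_num >= self.max_steps:
            run_context.request_stop()


params = train_loop([StopAfterSteps(5)], epochs=3, batch_num=4)
print(params.cur_epoch_num, params.cur_step_num)  # 2 5
```

Here StopAfterSteps calls request_stop at step 5, so the loop ends during the second epoch.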
The following two examples further illustrate how to use custom callbacks.
Terminate training after a specified time.
import time
from mindspore.train.callback import Callback

class StopAtTime(Callback):
    def __init__(self, run_time):
        super(StopAtTime, self).__init__()
        # run_time is given in minutes; convert it to seconds
        self.run_time = run_time*60

    def begin(self, run_context):
        cb_params = run_context.original_args()
        # store the training start time in cb_params
        cb_params.init_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_num = cb_params.cur_epoch_num
        step_num = cb_params.cur_step_num
        loss = cb_params.net_outputs
        cur_time = time.time()
        if (cur_time - cb_params.init_time) > self.run_time:
            print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
            run_context.request_stop()

# model and dataset are assumed to be defined earlier
stop_cb = StopAtTime(run_time=10)
model.train(100, dataset, callbacks=stop_cb)
The output is as follows:
epoch: 20 step: 32 loss: 2.298344373703003
The implementation logic is as follows. The run_context.original_args method returns the cb_params dictionary, which contains the main attributes described above. You can also modify existing values and add new ones to this dictionary. In the preceding example, an init_time attribute is defined in begin and stored in the cb_params dictionary. At each step_end, the callback checks the elapsed time. When it exceeds the configured threshold, a termination signal is sent to run_context to stop training early, and the current epoch, step, and loss values are printed.
Save the checkpoint file with the highest accuracy during training.
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import Callback
from mindspore.train.serialization import save_checkpoint

class SaveCallback(Callback):
    def __init__(self, model, eval_dataset):
        super(SaveCallback, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        # only checkpoints whose accuracy exceeds this value are saved
        self.acc = 0.5

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        # evaluate the current parameters on the validation dataset
        result = self.model.eval(self.eval_dataset)
        if result['accuracy'] > self.acc:
            self.acc = result['accuracy']
            file_name = str(self.acc) + ".ckpt"
            save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
            print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)

# Lenet, epoch_size, ds_train, and ds_eval are assumed to be defined elsewhere
network = Lenet()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(network.trainable_params(), 0.01, 0.9)
model = Model(network, loss_fn=loss, optimizer=optimizer, metrics={"accuracy"})
model.train(epoch_size, train_dataset=ds_train, callbacks=SaveCallback(model, ds_eval))
The implementation logic is as follows. The callback class receives the model object and ds_eval (the validation dataset) in its constructor. The model's accuracy is evaluated in the step_end phase. Whenever the accuracy is the highest so far, save_checkpoint is called manually to save the current parameters.
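You can check the save-the-best rule used by SaveCallback in isolation. The hypothetical function below replays a sequence of accuracies and returns the file names that would be written. It uses the same 0.5 starting threshold and the same str(acc) + ".ckpt" naming. It does not touch MindSpore or the file system.

```python
def best_checkpoint_names(accuracies, threshold=0.5):
    """Return the checkpoint names the SaveCallback rule would produce.

    A checkpoint is saved only when the accuracy beats both the starting
    threshold and every earlier accuracy.
    """
    best = threshold
    saved = []
    for acc in accuracies:
        if acc > best:
            best = acc
            saved.append(str(best) + ".ckpt")
    return saved


print(best_checkpoint_names([0.4, 0.7, 0.65, 0.9, 0.9]))
# ['0.7.ckpt', '0.9.ckpt']
```

Because SaveCallback runs model.eval at every step_end, moving the check to epoch_end would reduce how often evaluation runs, at the cost of coarser tracking.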
MindSpore Metrics
After training is complete, you can use metrics to evaluate the training result.
MindSpore provides multiple metrics, such as accuracy, loss, precision, recall, and F1.
You can define a metrics dictionary that contains multiple metrics, pass it to the Model constructor, and then call model.eval to evaluate the training accuracy.
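import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import TimeMonitor

metrics = {
    'accuracy': nn.Accuracy(),
    'loss': nn.Loss(),
    'precision': nn.Precision(),
    'recall': nn.Recall(),
    'f1_score': nn.F1()
}
# ResNet and create_dataset are assumed to be defined elsewhere
net = ResNet()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, loss_fn=loss, optimizer=opt, metrics=metrics)
ds_eval = create_dataset()
# callbacks are passed to model.eval, not to the Model constructor
output = model.eval(ds_eval, callbacks=TimeMonitor())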
The model.eval method returns a dictionary that maps each metric name passed in metrics to its result.
Callbacks can also be used during evaluation. You can pass built-in callbacks or your own custom callbacks to model.eval to achieve the desired behavior.
You can also define your own metric class by inheriting the Metric base class and overriding the clear, update, and eval methods.
The Accuracy metric is used as an example to describe how this works internally.
Accuracy inherits the EvaluationBase base class and overrides these three methods:
The clear method initializes the related calculation parameters in the class.
The update method accepts the predicted values and label values and updates the internal variables of Accuracy.
The eval method calculates the metric and returns the result.
Calling the eval method of Accuracy returns the calculation result.
The following code shows how Accuracy works:
import numpy as np
from mindspore import Tensor
from mindspore.nn import Accuracy

x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([1, 0, 1]))
metric = Accuracy()
metric.clear()
metric.update(x, y)
accuracy = metric.eval()
print('Accuracy is ', accuracy)
The output is as follows:
Accuracy is 0.6667
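You can illustrate the clear/update/eval pattern without MindSpore. The hypothetical SimpleAccuracy class below is a pure-Python sketch, not the implementation of Accuracy. It takes per-sample class scores and integer labels. On the same data as the example above, it also yields 2/3.

```python
class SimpleAccuracy:
    """Hypothetical pure-Python metric following the clear/update/eval pattern."""

    def __init__(self):
        self.clear()

    def clear(self):
        # reset the running counts
        self._correct_num = 0
        self._total_num = 0

    def update(self, y_pred, y):
        # y_pred: per-sample class scores; y: integer class labels
        for scores, label in zip(y_pred, y):
            predicted = max(range(len(scores)), key=lambda i: scores[i])
            self._correct_num += int(predicted == label)
            self._total_num += 1

    def eval(self):
        if self._total_num == 0:
            raise RuntimeError("update must be called at least once before eval")
        return self._correct_num / self._total_num


metric = SimpleAccuracy()
metric.update([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]], [1, 0, 1])
print(round(metric.eval(), 4))  # 0.6667
```

The predicted classes are 1, 0, and 0. Two of the three match the labels, which gives 2/3.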
MindSpore Print Operator
The MindSpore Print operator prints tensors or strings passed in by users. It accepts multiple strings, multiple tensors, or a mix of tensors and strings, separated by commas (,).
The Print operator is used in the same way as other operators: instantiate it in __init__ and call it in construct. The following is an example.
import numpy as np
from mindspore import Tensor
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.context as context

context.set_context(mode=context.GRAPH_MODE)

class PrintDemo(nn.Cell):
    def __init__(self):
        super(PrintDemo, self).__init__()
        self.print = ops.Print()

    def construct(self, x, y):
        self.print('print Tensor x and Tensor y:', x, y)
        return x

x = Tensor(np.ones([2, 1]).astype(np.int32))
y = Tensor(np.ones([2, 2]).astype(np.int32))
net = PrintDemo()
output = net(x, y)
The output is as follows:
print Tensor x and Tensor y:
Tensor shape:[[const vector][2, 1]]Int32
val:[[1]
[1]]
Tensor shape:[[const vector][2, 2]]Int32
val:[[1 1]
[1 1]]
Data Dump Introduction
When training results deviate from expectations, you can use data dump to save operator inputs and outputs for debugging. Two modes are available: Synchronous Dump and Asynchronous Dump.
Synchronous Dump
Create a dump JSON file, for example data_dump.json. The name and location of the JSON file can be customized.
{
    "common_dump_settings": {
        "dump_mode": 0,
        "path": "/tmp/net/",
        "net_name": "ResNet50",
        "iteration": 0,
        "input_output": 0,
        "kernels": ["Default/Conv-op12"],
        "support_device": [0,1,2,3,4,5,6,7]
    },
    "e2e_dump_settings": {
        "enable": false,
        "trans_flag": false
    }
}
dump_mode: 0: dump all kernels in the graph; 1: dump only the kernels in the kernels list.
path: the absolute path where dump data is saved.
net_name: the network name, for example ResNet50.
iteration: the iteration to dump. All kernels in the graph will be dumped.
input_output: 0: dump kernel inputs and outputs; 1: dump kernel inputs only; 2: dump kernel outputs only.
kernels: the full names of kernels. Enable context.set_context(save_graphs=True) and get the full kernel names from the IR files: hwopt_d_end_graph_{graph_id}.ir when device_target is Ascend, or hwopt_pm_7_getitem_tuple.ir when device_target is GPU.
support_device: the supported devices; the default setting is [0,1,2,3,4,5,6,7]. You can specify particular device IDs to dump data only from those devices.
enable: whether to enable Synchronous Dump.
trans_flag: whether to enable format transformation, which transforms the device data format into NCHW.
Specify the location of the JSON file.
export MINDSPORE_DUMP_CONFIG={Absolute path of data_dump.json}
Set the environment variable before running the training script; setting it during training has no effect.
Dump environment variables must be configured before mindspore.communication.management.init is called.
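If you prefer to launch training from Python rather than from a shell, one option is to pass the variable to a child process. That way it is set before the training script starts, and therefore before mindspore.communication.management.init is called. The helpers build_dump_env and launch_training below are hypothetical, and any script name you pass to launch_training is a placeholder.

```python
import os
import subprocess
import sys


def build_dump_env(config_path, base_env=None):
    """Return a copy of the environment with MINDSPORE_DUMP_CONFIG set.

    The path is made absolute because the variable expects an absolute path.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["MINDSPORE_DUMP_CONFIG"] = os.path.abspath(config_path)
    return env


def launch_training(script_path, config_path):
    """Hypothetical launcher: the child process sees the variable from startup."""
    env = build_dump_env(config_path)
    return subprocess.run([sys.executable, script_path], env=env, check=True)


env = build_dump_env("data_dump.json", base_env={})
print(env)
```

For example, launch_training("train.py", "data_dump.json") would run a hypothetical train.py with the dump configuration already in its environment.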
Execute the training script to dump data.
You can set context.set_context(reserve_class_name_in_scope=False) in your training script to avoid dump failures caused by overly long file names.
Parse the dump file.
Call numpy.fromfile to parse the dump data file.
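The following is a minimal sketch of parsing a dump file with numpy.fromfile. It assumes the file contains only raw tensor data with no header and that the dtype and shape are known in advance. The helper load_dump and the file name Conv-op12_output_0.bin are hypothetical. The sketch writes a fake file with tofile and reads it back.

```python
import os
import tempfile

import numpy as np


def load_dump(file_path, dtype, shape):
    """Read a raw binary dump file into an array of the given dtype and shape.

    Assumes the file holds only the tensor data, with no header.
    """
    data = np.fromfile(file_path, dtype=dtype)
    return data.reshape(shape)


with tempfile.TemporaryDirectory() as tmp:
    # hypothetical file standing in for a real dump output
    fake_dump = os.path.join(tmp, "Conv-op12_output_0.bin")
    expected = np.arange(6, dtype=np.float32).reshape(2, 3)
    expected.tofile(fake_dump)
    restored = load_dump(fake_dump, np.float32, (2, 3))

print(restored.shape, restored.dtype)  # (2, 3) float32
```

If the dtype passed to np.fromfile does not match the one used to write the data, the values are misread without any error.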
Asynchronous Dump
Create a dump JSON file, for example data_dump.json. The name and location of the JSON file can be customized.
{
    "common_dump_settings": {
        "dump_mode": 0,
        "path": "/absolute_path",
        "net_name": "ResNet50",
        "iteration": 0,
        "input_output": 0,
        "kernels": ["Default/Conv-op12"],
        "support_device": [0,1,2,3,4,5,6,7]
    },
    "async_dump_settings": {
        "enable": false,
        "op_debug_mode": 0
    }
}
dump_mode: 0: dump all kernels in the graph; 1: dump only the kernels in the kernels list.
path: the absolute path where dump data is saved.
net_name: the network name, for example ResNet50.
iteration: the iteration to dump. When dataset_sink_mode is False, set iteration to 0; data of every iteration will then be dumped.
input_output: 0: dump kernel inputs and outputs; 1: dump kernel inputs only; 2: dump kernel outputs only. This parameter does not take effect on the GPU, where only operator outputs are dumped.
kernels: the full names of kernels. Enable context.set_context(save_graphs=True) and get the full kernel names from hwopt_d_end_graph_{graph_id}.ir. kernels supports only TBE operators, AiCPU operators, and communication operators. If kernels is set to the name of a communication operator, the data of that operator's input operators will be dumped.
support_device: the supported devices; the default setting is [0,1,2,3,4,5,6,7]. You can specify particular device IDs to dump data only from those devices.
enable: whether to enable Asynchronous Dump.
op_debug_mode: set this to 0.
Specify the location of the dump JSON configuration file.
export MINDSPORE_DUMP_CONFIG={Absolute path of data_dump.json}
Set the environment variable before running the training script; setting it during training has no effect.
Dump environment variables must be configured before mindspore.communication.management.init is called.
Execute the training script to dump data.
You can set context.set_context(reserve_class_name_in_scope=False) in your training script to avoid dump failures caused by overly long file names.
Parse the dump file.
After training, change to the /var/log/npu/ide_daemon/dump/ directory and run the following command to parse the dump data file:
python /usr/local/Ascend/toolkit/tools/operator_cmp/compare/dump_data_conversion.pyc -type offline -target numpy -i ./{Dump file path} -o ./{output file path}