Network Migration Debugging Example
The following uses the classic ResNet-50 network as an example and walks through the network migration method in detail, with code.
Model Analysis and Preparation
Assume that the MindSpore operating environment has been configured according to Environment Preparation and Information Acquisition. Assume that ResNet-50 has not been implemented in the models repository.
First, analyze the algorithm and network structure.
The Residual Neural Network (ResNet) was proposed by Kaiming He et al. from Microsoft Research. Using residual units, they successfully trained a 152-layer neural network and won ILSVRC 2015. In a conventional convolutional or fully connected network, information is lost to some degree as it passes through the layers, which contributes to vanishing or exploding gradients, so training very deep networks fails. ResNet mitigates these problems to some extent. By passing the input directly to the output, it protects the integrity of the information, and the network only needs to learn the difference between the input and the output, which simplifies the learning objective. This structure speeds up neural network training and greatly improves the accuracy of the network model.
Paper: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Deep Residual Learning for Image Recognition"
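As a minimal sketch of the residual idea described above (not the paper's or any framework's implementation), the unit below adds its input to the output of a small branch, so the branch only has to model the difference. The names `residual_unit`, `tiny_branch`, and `zero_branch` are hypothetical and exist only for illustration.

```python
from typing import Callable, List

Vector = List[float]


def residual_unit(x: Vector, branch: Callable[[Vector], Vector]) -> Vector:
    """Return branch(x) + x, so the branch only has to model the difference."""
    fx = branch(x)
    if len(fx) != len(x):
        raise ValueError("the shortcut needs matching shapes")
    return [f + xi for f, xi in zip(fx, x)]


def tiny_branch(x: Vector) -> Vector:
    """Hypothetical learned branch that outputs a small correction."""
    return [0.1 * xi for xi in x]


def zero_branch(x: Vector) -> Vector:
    """A branch that has learned nothing yet."""
    return [0.0 for _ in x]


x = [1.0, -2.0, 3.0]
identity_out = residual_unit(x, zero_branch)   # input passes through unchanged
corrected_out = residual_unit(x, tiny_branch)  # input plus a learned difference
print(identity_out)
print(corrected_out)
```

Even when the branch contributes nothing, the input still reaches the output, which is the property the text refers to as protecting information integrity.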
The sample code of PyTorch ResNet-50 CIFAR-10 contains the PyTorch ResNet implementation, CIFAR-10 data processing, network training, and inference processes.
Checklist
When reading the paper and referring to the implementation, analyze and fill in the following checklist:
Trick | Record |
---|---|
Data augmentation | RandomCrop, RandomHorizontalFlip, Resize, Normalize |
Learning rate decay policy | Fixed learning rate = 0.001 |
Optimization parameters | Adam optimizer, weight_decay = 1e-5 |
Training parameters | batch_size = 32, epochs = 90 |
Network structure optimization | Bottleneck |
Training process optimization | None |
Reproducing Reference Implementation
Download the PyTorch code and CIFAR-10 dataset to train the network.
Train Epoch: 89 [0/1563 (0%)] Loss: 0.010917
Train Epoch: 89 [100/1563 (6%)] Loss: 0.013386
Train Epoch: 89 [200/1563 (13%)] Loss: 0.078772
Train Epoch: 89 [300/1563 (19%)] Loss: 0.031228
Train Epoch: 89 [400/1563 (26%)] Loss: 0.073462
Train Epoch: 89 [500/1563 (32%)] Loss: 0.098645
Train Epoch: 89 [600/1563 (38%)] Loss: 0.112967
Train Epoch: 89 [700/1563 (45%)] Loss: 0.137923
Train Epoch: 89 [800/1563 (51%)] Loss: 0.143274
Train Epoch: 89 [900/1563 (58%)] Loss: 0.088426
Train Epoch: 89 [1000/1563 (64%)] Loss: 0.071185
Train Epoch: 89 [1100/1563 (70%)] Loss: 0.094342
Train Epoch: 89 [1200/1563 (77%)] Loss: 0.126669
Train Epoch: 89 [1300/1563 (83%)] Loss: 0.245604
Train Epoch: 89 [1400/1563 (90%)] Loss: 0.050761
Train Epoch: 89 [1500/1563 (96%)] Loss: 0.080932
Test set: Average loss: -9.7052, Accuracy: 91%
Finished Training
You can download training logs and saved parameter files from resnet_pytorch_res.
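The step count of 1563 in the log is consistent with the 50,000 training images of CIFAR-10 and the batch size of 32 from the checklist, assuming the last partial batch is kept. The sketch below checks that arithmetic and parses one log line; `parse_train_line` is a hypothetical helper, not part of the sample code.

```python
import math
import re

# CIFAR-10 has 50,000 training images; batch_size = 32 comes from the checklist.
TRAIN_IMAGES = 50_000
BATCH_SIZE = 32
# Assumption: the last, smaller batch is not dropped.
steps_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)

LINE_RE = re.compile(
    r"Train Epoch: (\d+) \[(\d+)/(\d+) \((\d+)%\)\]\s+Loss: ([-\d.]+)"
)


def parse_train_line(line: str):
    """Return (epoch, step, total_steps, loss), or None if the line does not match."""
    m = LINE_RE.search(line)
    if m is None:
        return None
    epoch, step, total, _percent, loss = m.groups()
    return int(epoch), int(step), int(total), float(loss)


sample = "Train Epoch: 89 [1500/1563 (96%)] Loss: 0.080932"
record = parse_train_line(sample)
print(steps_per_epoch, record)
```

Parsing the reference log in this way makes it easier to compare the loss curve with the MindSpore run later.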
Analyzing API/Feature Missing
API analysis
PyTorch API | MindSpore API | Different or Not |
---|---|---|
nn.Conv2d | nn.Conv2d | Yes. Difference |
nn.BatchNorm2d | nn.BatchNorm2d | Yes. Difference |
nn.ReLU | nn.ReLU | No |
nn.MaxPool2d | nn.MaxPool2d | Yes. Difference |
nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool2d | No |
nn.Linear | nn.Dense | Yes. Difference |
torch.flatten | nn.Flatten | No |
Using the MindSpore Dev Toolkit or checking the PyTorch API Mapping, we find that four of these APIs differ.
Function analysis
PyTorch Function | MindSpore Function |
---|---|
nn.init.kaiming_normal_ | initializer(init='HeNormal') |
nn.init.constant_ | initializer(init='Constant') |
nn.Sequential | nn.SequentialCell |
nn.Module | nn.Cell |
torch.distributed | set_auto_parallel_context |
torch.optim.SGD | nn.optim.SGD or nn.optim.Momentum |
(The interface design of MindSpore is different from that of PyTorch. Therefore, only the comparison of key functions is listed here.)
After API and function analysis, we find that there are no missing APIs and functions on MindSpore compared with PyTorch.
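The API analysis can be kept as a small lookup table, so that every API used by the reference network is checked in one pass. This is only a sketch: the mapping restates the API table of this section, and `split_by_difference` is a hypothetical helper, not part of MindSpore or the Dev Toolkit.

```python
# Mapping restated from the API analysis table.
# Value: (MindSpore API, whether arguments or defaults differ).
PT_TO_MS = {
    "nn.Conv2d": ("nn.Conv2d", True),
    "nn.BatchNorm2d": ("nn.BatchNorm2d", True),
    "nn.ReLU": ("nn.ReLU", False),
    "nn.MaxPool2d": ("nn.MaxPool2d", True),
    "nn.AdaptiveAvgPool2d": ("nn.AdaptiveAvgPool2d", False),
    "nn.Linear": ("nn.Dense", True),
    "torch.flatten": ("nn.Flatten", False),
}


def split_by_difference(pt_apis):
    """Split PyTorch API names into differing, identical, and unmapped groups."""
    differing, identical, unmapped = [], [], []
    for name in pt_apis:
        if name not in PT_TO_MS:
            unmapped.append(name)
        elif PT_TO_MS[name][1]:
            differing.append(name)
        else:
            identical.append(name)
    return differing, identical, unmapped


# "nn.Dropout" is not in the table and stands in for an API still to be analyzed.
used = list(PT_TO_MS) + ["nn.Dropout"]
differing, identical, unmapped = split_by_difference(used)
print(differing)
print(unmapped)
```

APIs in the differing group are the ones whose arguments and defaults need to be read carefully during migration.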
MindSpore Model Implementation
Datasets
The CIFAR-10 dataset is as follows:
└─dataset_path
    ├─cifar-10-batches-bin  # train dataset
        ├─ data_batch_1.bin
        ├─ data_batch_2.bin
        ├─ data_batch_3.bin
        ├─ data_batch_4.bin
        ├─ data_batch_5.bin
    └─cifar-10-verify-bin  # evaluate dataset
        ├─ test_batch.bin
This operation is implemented on PyTorch/MindSpore as follows:
PyTorch Dataset Processing | MindSpore Dataset Processing |
|
|
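For orientation, the `.bin` files listed above use the CIFAR-10 binary layout: each record is one label byte followed by 3072 pixel bytes (32 x 32 pixels, three color planes). The sketch below reads such records from a byte stream. It only illustrates the file layout and is not how either framework loads the data; `read_records` is a hypothetical helper, and the two records are synthetic.

```python
import io

RECORD_BYTES = 1 + 3 * 32 * 32  # one label byte followed by 3072 pixel bytes


def read_records(stream):
    """Yield (label, pixels) pairs from a CIFAR-10 binary stream."""
    while True:
        chunk = stream.read(RECORD_BYTES)
        if not chunk:
            return
        if len(chunk) != RECORD_BYTES:
            raise ValueError("truncated record")
        yield chunk[0], chunk[1:]


# Two synthetic records stand in for a real file such as data_batch_1.bin.
fake_file = io.BytesIO(
    bytes([3]) + bytes(3072) + bytes([7]) + bytes([255]) * 3072
)
records = list(read_records(fake_file))
labels = [label for label, _ in records]
print(labels, len(records[0][1]))
```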
Network Model Implementation
By referring to PyTorch ResNet, we have implemented MindSpore ResNet. The comparison tool shows that the implementation is different in the following aspects:
PyTorch | MindSpore |
|
|
Loss Function
PyTorch | MindSpore |
|
|
Learning Rate and Optimizer
PyTorch | MindSpore |
|
|
Model Validation
The trained PyTorch parameters are obtained in Reproducing Reference Implementation. How do I convert the parameter file into a checkpoint file that can be used by MindSpore?
The following steps are required:
Print the names and shapes of all parameters in the PyTorch parameter file and the names and shapes of all parameters in the MindSpore cell to which parameters need to be loaded.
Compare the parameter name and shape to construct the parameter mapping.
Create a parameter list based on the parameter mapping (PyTorch parameters -> numpy -> MindSpore parameters) and save the parameter list as a checkpoint.
Unit test: Load PyTorch parameters and MindSpore parameters, construct random input, and compare the output.
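The first two steps can be sketched with plain dictionaries that map parameter names to shapes. The listings below are hypothetical excerpts written for illustration, not output from the real networks, and `compare_params` is a hypothetical helper.

```python
def compare_params(pt_shapes, ms_shapes):
    """Return MindSpore names with and without a same-name, same-shape PyTorch match."""
    matched, unmatched = [], []
    for name, shape in ms_shapes.items():
        if pt_shapes.get(name) == shape:
            matched.append(name)
        else:
            unmatched.append(name)
    return matched, unmatched


# Hypothetical excerpts of the two listings (name -> shape).
pt_shapes = {
    "conv1.weight": (64, 3, 7, 7),
    "bn1.weight": (64,),
    "bn1.bias": (64,),
    "bn1.running_mean": (64,),
    "bn1.running_var": (64,),
}
ms_shapes = {
    "conv1.weight": (64, 3, 7, 7),
    "bn1.gamma": (64,),
    "bn1.beta": (64,),
    "bn1.moving_mean": (64,),
    "bn1.moving_variance": (64,),
}
matched, unmatched = compare_params(pt_shapes, ms_shapes)
print(matched)
print(unmatched)
```

The unmatched names are the ones that need an explicit entry in the parameter mapping.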
Printing Parameters
PyTorch | MindSpore |
|
|
Parameter Mapping and Checkpoint Saving
Except for the BatchNorm parameters, the names and shapes of all parameters already match between the two frameworks. In this case, you can write a simple Python script for parameter mapping.
import mindspore as ms


def param_convert(ms_params, pt_params, ckpt_path):
    # pt_params is expected to hold NumPy arrays keyed by name
    # (PyTorch parameters -> numpy -> MindSpore parameters).
    # Parameter name mapping dictionary
    bn_ms2pt = {"gamma": "weight",
                "beta": "bias",
                "moving_mean": "running_mean",
                "moving_variance": "running_var"}
    new_params_list = []
    for ms_param in ms_params.keys():
        # In the parameter list, only the parameters that contain bn and downsample.1 are the parameters of the BatchNorm operator.
        if "bn" in ms_param or "downsample.1" in ms_param:
            ms_param_item = ms_param.split(".")
            pt_param_item = ms_param_item[:-1] + [bn_ms2pt[ms_param_item[-1]]]
            pt_param = ".".join(pt_param_item)
            # If the corresponding parameter is found and the shape is the same, add the parameter to the parameter list.
            if pt_param in pt_params and pt_params[pt_param].shape == ms_params[ms_param].shape:
                ms_value = pt_params[pt_param]
                new_params_list.append({"name": ms_param, "data": ms.Tensor(ms_value)})
            else:
                print(ms_param, "not match in pt_params")
        # Other parameters
        else:
            # If the corresponding parameter is found and the shape is the same, add the parameter to the parameter list.
            if ms_param in pt_params and pt_params[ms_param].shape == ms_params[ms_param].shape:
                ms_value = pt_params[ms_param]
                new_params_list.append({"name": ms_param, "data": ms.Tensor(ms_value)})
            else:
                print(ms_param, "not match in pt_params")
    # Save as MindSpore checkpoint.
    ms.save_checkpoint(new_params_list, ckpt_path)


ckpt_path = "resnet50.ckpt"
param_convert(ms_params, pt_params, ckpt_path)
After the execution is complete, you can find the generated checkpoint file in ckpt_path.
If the parameter mapping is complex and it is difficult to find the mapping based on the parameter name, you can write a parameter mapping dictionary, for example:
param = {
'bn1.bias': 'bn1.beta',
'bn1.weight': 'bn1.gamma',
'IN.weight': 'IN.gamma',
'IN.bias': 'IN.beta',
'BN.bias': 'BN.beta',
'in.weight': 'in.gamma',
'bn.weight': 'bn.gamma',
'bn.bias': 'bn.beta',
'bn2.weight': 'bn2.gamma',
'bn2.bias': 'bn2.beta',
'bn3.bias': 'bn3.beta',
'bn3.weight': 'bn3.gamma',
'BN.running_mean': 'BN.moving_mean',
'BN.running_var': 'BN.moving_variance',
'bn.running_mean': 'bn.moving_mean',
'bn.running_var': 'bn.moving_variance',
'bn1.running_mean': 'bn1.moving_mean',
'bn1.running_var': 'bn1.moving_variance',
'bn2.running_mean': 'bn2.moving_mean',
'bn2.running_var': 'bn2.moving_variance',
'bn3.running_mean': 'bn3.moving_mean',
'bn3.running_var': 'bn3.moving_variance',
'downsample.1.running_mean': 'downsample.1.moving_mean',
'downsample.1.running_var': 'downsample.1.moving_variance',
'downsample.0.weight': 'downsample.1.weight',
'downsample.1.bias': 'downsample.1.beta',
'downsample.1.weight': 'downsample.1.gamma'
}
Then, you can obtain the parameter file by following the param_convert process.
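One possible way to apply such a dictionary is to treat each key as a name suffix and replace the longest suffix that matches. This is a sketch under that assumption, not necessarily how the original script uses the dictionary. `rename_param` is a hypothetical helper, and the `layer1.0.` prefixes are assumed example names.

```python
def rename_param(pt_name, suffix_map):
    """Rename a PyTorch parameter by replacing its longest matching suffix."""
    best = None
    for pt_suffix in suffix_map:
        if pt_name == pt_suffix or pt_name.endswith("." + pt_suffix):
            if best is None or len(pt_suffix) > len(best):
                best = pt_suffix
    if best is None:
        return pt_name  # no rule applies, keep the name unchanged
    return pt_name[: len(pt_name) - len(best)] + suffix_map[best]


# A small excerpt of the mapping dictionary shown above.
suffix_map = {
    "bn1.weight": "bn1.gamma",
    "bn1.bias": "bn1.beta",
    "bn1.running_mean": "bn1.moving_mean",
    "bn1.running_var": "bn1.moving_variance",
    "downsample.1.weight": "downsample.1.gamma",
}

print(rename_param("layer1.0.bn1.running_var", suffix_map))
print(rename_param("layer1.0.downsample.1.weight", suffix_map))
print(rename_param("layer1.0.conv1.weight", suffix_map))
```

Matching on a dot boundary avoids renaming a parameter whose name merely ends with the same characters as a key.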
Unit Test
After obtaining the corresponding parameter file, you need to perform a unit test on the entire model to ensure model consistency.
import numpy as np
import torch
import mindspore as ms
from resnet_ms.src.resnet import resnet50 as ms_resnet50
from resnet_pytorch.resnet import resnet50 as pt_resnet50


def check_res(pth_path, ckpt_path):
    inp = np.random.uniform(-1, 1, (4, 3, 224, 224)).astype(np.float32)
    # When performing a unit test, you need to add a training or inference label to the cell.
    ms_resnet = ms_resnet50(num_classes=10).set_train(False)
    pt_resnet = pt_resnet50(num_classes=10).eval()
    pt_resnet.load_state_dict(torch.load(pth_path, map_location='cpu'))
    ms.load_checkpoint(ckpt_path, ms_resnet)
    print("========= pt_resnet conv1.weight ==========")
    print(pt_resnet.conv1.weight.detach().numpy().reshape((-1,))[:10])
    print("========= ms_resnet conv1.weight ==========")
    print(ms_resnet.conv1.weight.data.asnumpy().reshape((-1,))[:10])
    pt_res = pt_resnet(torch.from_numpy(inp))
    ms_res = ms_resnet(ms.Tensor(inp))
    print("========= pt_resnet res ==========")
    print(pt_res)
    print("========= ms_resnet res ==========")
    print(ms_res)
    print("diff", np.max(np.abs(pt_res.detach().numpy() - ms_res.asnumpy())))


pth_path = "resnet.pth"
ckpt_path = "resnet50.ckpt"
check_res(pth_path, ckpt_path)
During the unit test, you need to set the cell to training or inference mode. PyTorch uses .train() for training and .eval() for inference. MindSpore uses .set_train() for training and .set_train(False) for inference.
Result:
========= pt_resnet conv1.weight ==========
[ 1.091892e-40 -1.819391e-39 3.509566e-40 -8.281730e-40 1.207908e-39
-3.576954e-41 -1.000796e-39 1.115791e-39 -1.077758e-39 -6.031427e-40]
========= ms_resnet conv1.weight ==========
[ 1.091892e-40 -1.819391e-39 3.509566e-40 -8.281730e-40 1.207908e-39
-3.576954e-41 -1.000796e-39 1.115791e-39 -1.077758e-39 -6.031427e-40]
========= pt_resnet res ==========
tensor([[-15.1945, -5.6529, 6.5738, 9.7807, -2.4615, 3.0365, -4.7216,
-11.1005, 2.7121, -9.3612],
[-14.2412, -5.9004, 5.6366, 9.7030, -1.6322, 2.6926, -3.7307,
-10.7582, 1.4195, -7.9930],
[-13.4795, -5.6582, 5.6432, 8.9152, -1.5169, 2.6958, -3.4469,
-10.5300, 1.3318, -8.1476],
[-13.6448, -5.4239, 5.8254, 9.3094, -2.1969, 2.7042, -4.1194,
-10.4388, 1.9331, -8.1746]], grad_fn=<AddmmBackward0>)
========= ms_resnet res ==========
[[-15.194535 -5.652934 6.5737996 9.780719 -2.4615316 3.0365033
-4.7215843 -11.100524 2.7121294 -9.361177 ]
[-14.24116 -5.9004383 5.6366115 9.702984 -1.6322318 2.69261
-3.7307222 -10.758192 1.4194587 -7.992969 ]
[-13.47945 -5.658216 5.6432185 8.915173 -1.5169426 2.6957715
-3.446888 -10.529953 1.3317728 -8.147601 ]
[-13.644804 -5.423854 5.825424 9.309403 -2.1969485 2.7042081
-4.119426 -10.438771 1.9330862 -8.174606 ]]
diff 2.861023e-06
The two outputs agree to within about 2.9e-06, which meets the expectation. If the results differ significantly, you can fix the random seeds of PyTorch and MindSpore after completing the parameter mapping, and then use the tool "TroubleShooter API level network results automatic comparison" to compare the forward and backward results of the two networks and locate the problem more efficiently.
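A pass or fail check on the difference is often more convenient than reading printed tensors. The sketch below uses the first output row of each framework as printed in the result above. Because the PyTorch values are printed with four decimals, the difference computed here is dominated by print rounding and is larger than the full-precision value of 2.861023e-06. The tolerance of 1e-4 is an assumption chosen for this illustration, not a threshold from the original guide.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equally long lists."""
    if len(a) != len(b):
        raise ValueError("outputs must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def outputs_match(a, b, atol):
    """Return True when every element differs by at most atol."""
    return max_abs_diff(a, b) <= atol


# First output row of each framework, copied from the printed result.
pt_row = [-15.1945, -5.6529, 6.5738, 9.7807, -2.4615,
          3.0365, -4.7216, -11.1005, 2.7121, -9.3612]
ms_row = [-15.194535, -5.652934, 6.5737996, 9.780719, -2.4615316,
          3.0365033, -4.7215843, -11.100524, 2.7121294, -9.361177]

print(max_abs_diff(pt_row, ms_row))
print(outputs_match(pt_row, ms_row, atol=1e-4))
```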
Inference Process
PyTorch | MindSpore |
|
|
Execute: |
|
Result: |
Result: |
The inference accuracy is the same.
When the inference results are inconsistent, you can use the tool "TroubleShooter compares MindSpore and PyTorch network outputs for consistency" to compare the inference results of the PyTorch and MindSpore networks and locate where the network outputs start to diverge, which improves migration efficiency.
Training Process
For details about the PyTorch training process, see PyTorch ResNet-50 CIFAR-10 Sample Code. The log file and the trained parameter file are stored in resnet_pytorch_res.
The corresponding MindSpore code is as follows:
import numpy as np
import mindspore as ms
from mindspore.train import Model
from mindspore import nn, Profiler
from src.dataset import create_dataset
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.utils import init_env
from src.resnet import resnet50


def train_epoch(epoch, model, loss_fn, optimizer, data_loader):
    model.set_train()

    # Define forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get gradient function
    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define function of one-step training
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        optimizer(grads)
        return loss

    dataset_size = data_loader.get_dataset_size()
    for batch_idx, (data, target) in enumerate(data_loader):
        loss = float(train_step(data, target).asnumpy())
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, dataset_size,
                100. * batch_idx / dataset_size, loss))


def test_epoch(model, data_loader, loss_func):
    model.set_train(False)
    test_loss = 0
    correct = 0
    for data, target in data_loader:
        output = model(data)
        test_loss += float(loss_func(output, target).asnumpy())
        pred = np.argmax(output.asnumpy(), axis=1)
        correct += (pred == target.asnumpy()).sum()
    # The evaluation dataset uses batch_size=1, so the number of batches equals the number of samples.
    dataset_size = data_loader.get_dataset_size()
    test_loss /= dataset_size
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
        test_loss, 100. * correct / dataset_size))


@moxing_wrapper()
def train_net():
    init_env(config)
    if config.enable_profiling:
        profiler = Profiler()
    train_dataset = create_dataset(config.dataset_name, config.data_path, True, batch_size=config.batch_size,
                                   image_size=(int(config.image_height), int(config.image_width)),
                                   rank_size=config.device_num, rank_id=config.rank_id)
    eval_dataset = create_dataset(config.dataset_name, config.data_path, False, batch_size=1,
                                  image_size=(int(config.image_height), int(config.image_width)))
    config.steps_per_epoch = train_dataset.get_dataset_size()
    resnet = resnet50(num_classes=config.class_num)
    optimizer = nn.Adam(resnet.trainable_params(), config.lr, weight_decay=config.weight_decay)
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    for epoch in range(config.epoch_size):
        # Pass the network itself, not the train_net function.
        train_epoch(epoch, resnet, loss_fn, optimizer, train_dataset)
        test_epoch(resnet, eval_dataset, loss_fn)

    print('Finished Training')
    save_path = './resnet.ckpt'
    ms.save_checkpoint(resnet, save_path)


if __name__ == '__main__':
    train_net()
Performance Optimization
The preceding training turns out to be slow, so performance optimization is required. Before applying specific optimizations, run the profiler tool to obtain performance data. The profiler can only collect data from training that is encapsulated by Model, so you need to restructure the training process first.
device_num = config.device_num
if config.use_profilor:
    profiler = Profiler()
    # Note that the profiling data should not be too large. Otherwise, the processing will be slow. In this example, if use_profilor is set to True, the original dataset is divided into 40 copies.
    device_num = 40
train_dataset = create_dataset(config.dataset_name, config.data_path, True, batch_size=config.batch_size,
                               image_size=(int(config.image_height), int(config.image_width)),
                               rank_size=device_num, rank_id=config.rank_id)
.....
loss_scale = ms.amp.FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(resnet, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale)
if config.use_profilor:
    # Note that the profiling data should not be too large. Otherwise, the processing will be slow.
    model.train(3, train_dataset, callbacks=[LossMonitor(), TimeMonitor()], dataset_sink_mode=True)
    profiler.analyse()
else:
    # Model.train does not take an evaluation dataset, so evaluation is run separately.
    model.train(config.epoch_size, train_dataset, callbacks=[LossMonitor(), TimeMonitor()],
                dataset_sink_mode=False)
Set use_profilor=True. A data directory is generated in the running directory. Rename this directory to profiler_v1 and run the mindinsight start command in the same directory.
The following figure shows the MindSpore Insight profiler page. (This analysis is performed in the Ascend environment, which is similar to that in the GPU. The CPU does not support profiler.) There are three parts on the page.
The first part is step trace, which is the most basic part of the profiler. The data of a single device includes the step interval and the forward and backward propagation. The forward and backward time is the actual running time of the model on the device, and the step interval includes data processing, data printing, and the time spent saving parameters on the CPU during training. The step interval and the forward and backward execution time are roughly equal, so non-device operations such as data processing account for a large part of each step.
The second part is the forward and backward network execution time, where you can view details.
The upper part shows the proportion of each AI Core operator to the total time, and the lower part shows the details of each operator.
You can click an operator to obtain the execution time, scope, shape, and type of the operator.
In addition to the AI Core operators, there may be AI CPU and HOST CPU operators on the network. These operators take more time than the AI Core operators. You can click the tabs to view the time.
In addition to viewing the operator performance, you can also view the raw data for analysis.
Go to the profiler_v1/profiler/ directory and open the aicore_intermediate_0_type.csv file to view the statistics of each operator. There are 30 AI Core operators in total. The total execution time is 37.526 ms.
In addition, aicore_intermediate_0_detail.csv contains detailed data of each operator, which is similar to the operator details displayed in MindSpore Insight. ascend_timeline_display_0.json is a timeline data file. For details, see timeline.
The third part is the performance data during data processing. You can view the data queue status in this part.
And a queue status of each data processing operation:
Now, let’s analyze the process and solve the problem.
From the step trace, the step interval and the forward and backward execution time are roughly equal. MindSpore provides an on-device execution method that processes data and executes the network on the device concurrently. You only need to set dataset_sink_mode=True in model.train. Note that this configuration is True by default. When it is enabled, only one result is returned per epoch, so you are advised to change the value to False during debugging.
When dataset_sink_mode=True is set, the profiler result is as follows:
The execution time is reduced by half.
Let’s go on with the analysis and optimization. According to the execution time of the forward and backward operators, Cast and BatchNorm account for almost 50%. Why are there so many Cast operators? According to Constructing MindSpore Network, Conv, Sort, and TopK in the Ascend environment support only float16. Therefore, a Cast operator is added before and after each Conv calculation. The most direct method is to run the network computation in float16, so that Cast is needed only before the network input and before the loss computation, where its cost is negligible. This involves the mixed precision policy of MindSpore.
MindSpore has three methods to use mixed precision:
Use Cast to convert the network input to float16 and the loss input to float32.
Use the to_float method of Cell. For details, see Network Construction.
Use the amp_level interface of the Model to perform mixed precision. For details, see Automatic Mixed-Precision.
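To see why the loss is kept in float32 under each of these methods, it helps to look at how coarse float16 is. The sketch below uses only the Python standard library (the `"e"` format of `struct` is IEEE 754 half precision) and does not involve MindSpore. `to_float16` is a hypothetical helper.

```python
import struct


def to_float16(value: float) -> float:
    """Round a Python float to IEEE 754 half precision and convert it back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


rounded = to_float16(0.1)     # nearest float16 to 0.1
tiny = to_float16(1.0e-8)     # below the float16 range, underflows to 0.0
print(rounded, tiny)

try:
    to_float16(70000.0)       # above the float16 maximum of 65504
    overflowed = False
except OverflowError:
    overflowed = True
print(overflowed)
```

Small values that underflow to zero are also the reason a loss scale, such as the FixedLossScaleManager used in the profiling code above, usually accompanies float16 training.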
Use the third method: set amp_level in Model to O3 and check the profiler result.
Each step takes only 23 ms.
Finally, let’s look at data processing.
After the sink mode is added, there are two queues in total. The host queue is a queue in the memory. The dataset object continuously places the input data required by the network in the host queue. The other is a data queue on the device. The data in the host queue is cached to the data queue, and the network directly obtains the model input from the data queue.
The host queue is empty in many places, which indicates that data is taken by the data queue as soon as the dataset produces it. The data queue is almost always full. Therefore, data can keep up with network training, and data processing is not the bottleneck of network training.
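The two-queue arrangement can be pictured as a small producer and consumer pipeline. The sketch below is a conceptual analogy built on the Python standard library and is not how MindSpore implements sink mode. The function name `run_pipeline` and the queue sizes are assumptions made for illustration.

```python
import queue
import threading


def run_pipeline(num_batches, host_size=4, device_size=2):
    """Move batches through a host queue and a device queue, then consume them."""
    host_queue = queue.Queue(maxsize=host_size)      # filled by the dataset
    device_queue = queue.Queue(maxsize=device_size)  # read by the network
    consumed = []
    sentinel = None  # marks the end of the data

    def produce():
        for batch in range(num_batches):
            host_queue.put(batch)
        host_queue.put(sentinel)

    def transfer():
        while True:
            item = host_queue.get()
            device_queue.put(item)
            if item is sentinel:
                return

    def train():
        while True:
            item = device_queue.get()
            if item is sentinel:
                return
            consumed.append(item)

    threads = [threading.Thread(target=fn) for fn in (produce, transfer, train)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return consumed


print(run_pipeline(10))
```

In this analogy, a device queue that stays full means the consumer never waits for data, which matches the reading of the profiler page given above.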
If the data queue is empty most of the time, you need to optimize the data performance. For example:
In the queue of each data processing operation, the last operator and the batch operator are empty for a long time. In this case, you can increase the degree of parallelism of the batch operator. For details, see Data Processing Performance Tuning.
The code required for ResNet migration can be obtained from code.