网络迁移调试实例
本章将以经典网络 ResNet50 为例,结合代码来详细介绍网络迁移方法。
模型分析与准备
假设已经按照环境准备章节配置好了MindSpore的运行环境。且假设resnet50在models仓还没有实现。
首先需要分析算法及网络结构。
残差神经网络(ResNet)由微软研究院何凯明等人提出,通过ResNet单元,成功训练152层神经网络,赢得了ILSVRC2015冠军。传统的卷积网络或全连接网络或多或少存在信息丢失的问题,还会造成梯度消失或爆炸,导致深度网络训练失败,ResNet则在一定程度上解决了这个问题。通过将输入信息传递给输出,确保信息完整性。整个网络只需要学习输入和输出的差异部分,简化了学习目标和难度。ResNet的结构大幅提高了神经网络训练的速度,并且大大提高了模型的准确率。
论文:Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun."Deep Residual Learning for Image Recognition"
我们找到了一份PyTorch ResNet50 Cifar10的示例代码,里面包含了PyTorch ResNet的实现,Cifar10数据处理,网络训练及推理流程。
checklist
在阅读论文和参考实现过程中,我们分析填写以下checklist:
trick |
记录 |
---|---|
数据增强 |
RandomCrop,RandomHorizontalFlip,Resize,Normalize |
学习率衰减策略 |
固定学习率 0.001 |
优化器参数 |
Adam优化器,weight_decay=1e-5 |
训练参数 |
batch_size=32,epochs=90 |
网络结构优化点 |
Bottleneck |
训练流程优化点 |
无 |
复现参考实现
下载PyTorch的代码,CIFAR-10的数据集,对网络进行训练:
Train Epoch: 89 [0/1563 (0%)] Loss: 0.010917
Train Epoch: 89 [100/1563 (6%)] Loss: 0.013386
Train Epoch: 89 [200/1563 (13%)] Loss: 0.078772
Train Epoch: 89 [300/1563 (19%)] Loss: 0.031228
Train Epoch: 89 [400/1563 (26%)] Loss: 0.073462
Train Epoch: 89 [500/1563 (32%)] Loss: 0.098645
Train Epoch: 89 [600/1563 (38%)] Loss: 0.112967
Train Epoch: 89 [700/1563 (45%)] Loss: 0.137923
Train Epoch: 89 [800/1563 (51%)] Loss: 0.143274
Train Epoch: 89 [900/1563 (58%)] Loss: 0.088426
Train Epoch: 89 [1000/1563 (64%)] Loss: 0.071185
Train Epoch: 89 [1100/1563 (70%)] Loss: 0.094342
Train Epoch: 89 [1200/1563 (77%)] Loss: 0.126669
Train Epoch: 89 [1300/1563 (83%)] Loss: 0.245604
Train Epoch: 89 [1400/1563 (90%)] Loss: 0.050761
Train Epoch: 89 [1500/1563 (96%)] Loss: 0.080932
Test set: Average loss: -9.7052, Accuracy: 91%
Finished Training
可以从resnet_pytorch_res下载到训练时日志和保存的参数文件。
分析API/特性缺失
API分析
PyTorch 使用API
MindSpore 对应API
是否有差异
nn.Conv2D
nn.Conv2d
有,差异对比
nn.BatchNorm2D
nn.BatchNom2d
有,差异对比
nn.ReLU
nn.ReLU
无
nn.MaxPool2D
nn.MaxPool2d
有,差异对比
nn.AdaptiveAvgPool2D
nn.AdaptiveAvgPool2D
无
nn.Linear
nn.Dense
有,差异对比
torch.flatten
nn.Flatten
无
可通过借助MindSpore Dev ToolkitAPI扫描工具,或查看PyTorch API映射来获取API差异。
功能分析
Pytorch 使用功能
MindSpore 对应功能
nn.init.kaiming_normal_
initializer(init='HeNormal')
nn.init.constant_
initializer(init='Constant')
nn.Sequential
nn.SequentialCell
nn.Module
nn.Cell
nn.distibuted
set_auto_parallel_context
torch.optim.SGD
nn.optim.SGD
ornn.optim.Momentum
(由于MindSpore 和 PyTorch 在接口设计上不完全一致,这里仅列出关键功能的比对)
经过API和功能分析,我们发现,相比 PyTorch,MindSpore 上没有缺失的API和功能。
MindSpore模型实现
数据集
以CIFAR-10数据集为例,其目录组织参考:
└─dataset_path
├─cifar-10-batches-bin # train dataset
├─ data_batch_1.bin
├─ data_batch_2.bin
├─ data_batch_3.bin
├─ data_batch_4.bin
├─ data_batch_5.bin
└─cifar-10-verify-bin # evaluate dataset
├─ test_batch.bin
PyTorch和MindSpore的数据集处理代码如下:
PyTorch 数据集处理 | MindSpore 数据集处理 |
|
|
网络模型实现
参考PyTorch resnet,我们实现了一版MindSpore resnet,通过比较工具发现,实现只有几个地方有差别:
PyTorch | MindSpore |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Loss函数
PyTorch | MindSpore |
|
|
学习率与优化器
PyTorch | MindSpore |
|
|
模型验证
在复现参考实现章节我们获取到了训练好的PyTorch的参数,我们怎样将参数文件转换成MindSpore能够使用的checkpoint文件呢?
基本需要以下几个流程:
打印PyTorch的参数文件里所有参数的参数名和shape,打印需要加载参数的MindSpore Cell里所有参数的参数名和shape;
比较参数名和shape,构造参数映射关系;
按照参数映射将PyTorch的参数 -> numpy -> MindSpore的Parameter,构成Parameter List后保存成checkpoint;
单元测试:PyTorch加载参数,MindSpore加载参数,构造随机输入,对比输出。
打印参数
PyTorch | MindSpore |
|
|
参数映射及checkpoint保存
发现除了BatchNorm的参数外,其他参数的名字和shape是完全能够对的上的,这时可以写一个简单的python脚本来做参数映射:
import mindspore as ms
def param_convert(ms_params, pt_params, ckpt_path):
# 参数名映射字典
bn_ms2pt = {"gamma": "weight",
"beta": "bias",
"moving_mean": "running_mean",
"moving_variance": "running_var"}
new_params_list = []
for ms_param in ms_params.keys():
# 在参数列表中,只有包含bn和downsample.1的参数是BatchNorm算子的参数
if "bn" in ms_param or "downsample.1" in ms_param:
ms_param_item = ms_param.split(".")
pt_param_item = ms_param_item[:-1] + [bn_ms2pt[ms_param_item[-1]]]
pt_param = ".".join(pt_param_item)
# 如找到参数对应且shape一致,加入到参数列表
if pt_param in pt_params and pt_params[pt_param].shape == ms_params[ms_param].shape:
ms_value = pt_params[pt_param]
new_params_list.append({"name": ms_param, "data": ms.Tensor(ms_value)})
else:
print(ms_param, "not match in pt_params")
# 其他参数
else:
# 如找到参数对应且shape一致,加入到参数列表
if ms_param in pt_params and pt_params[ms_param].shape == ms_params[ms_param].shape:
ms_value = pt_params[ms_param]
new_params_list.append({"name": ms_param, "data": ms.Tensor(ms_value)})
else:
print(ms_param, "not match in pt_params")
# 保存成MindSpore的checkpoint
ms.save_checkpoint(new_params_list, ckpt_path)
ckpt_path = "resnet50.ckpt"
param_convert(ms_params, pt_params, ckpt_path)
执行完成可以在ckpt_path
找到生成的checkpoint文件。
当参数映射关系非常复杂,通过参数名很难找到映射关系时,可以写一个参数映射字典,如:
param = {
'bn1.bias': 'bn1.beta',
'bn1.weight': 'bn1.gamma',
'IN.weight': 'IN.gamma',
'IN.bias': 'IN.beta',
'BN.bias': 'BN.beta',
'in.weight': 'in.gamma',
'bn.weight': 'bn.gamma',
'bn.bias': 'bn.beta',
'bn2.weight': 'bn2.gamma',
'bn2.bias': 'bn2.beta',
'bn3.bias': 'bn3.beta',
'bn3.weight': 'bn3.gamma',
'BN.running_mean': 'BN.moving_mean',
'BN.running_var': 'BN.moving_variance',
'bn.running_mean': 'bn.moving_mean',
'bn.running_var': 'bn.moving_variance',
'bn1.running_mean': 'bn1.moving_mean',
'bn1.running_var': 'bn1.moving_variance',
'bn2.running_mean': 'bn2.moving_mean',
'bn2.running_var': 'bn2.moving_variance',
'bn3.running_mean': 'bn3.moving_mean',
'bn3.running_var': 'bn3.moving_variance',
'downsample.1.running_mean': 'downsample.1.moving_mean',
'downsample.1.running_var': 'downsample.1.moving_variance',
'downsample.0.weight': 'downsample.1.weight',
'downsample.1.bias': 'downsample.1.beta',
'downsample.1.weight': 'downsample.1.gamma'
}
再结合param_convert
的相关流程就可以获取到参数文件了。
单元测试
获得对应的参数文件后,我们需要对整个模型做一次单元测试,保证模型的一致性:
import numpy as np
import torch
import mindspore as ms
from resnet_ms.src.resnet import resnet50 as ms_resnet50
from resnet_pytorch.resnet import resnet50 as pt_resnet50
def check_res(pth_path, ckpt_path):
inp = np.random.uniform(-1, 1, (4, 3, 224, 224)).astype(np.float32)
# 注意做单元测试时,需要给Cell打训练或推理的标签
ms_resnet = ms_resnet50(num_classes=10).set_train(False)
pt_resnet = pt_resnet50(num_classes=10).eval()
pt_resnet.load_state_dict(torch.load(pth_path, map_location='cpu'))
ms.load_checkpoint(ckpt_path, ms_resnet)
print("========= pt_resnet conv1.weight ==========")
print(pt_resnet.conv1.weight.detach().numpy().reshape((-1,))[:10])
print("========= ms_resnet conv1.weight ==========")
print(ms_resnet.conv1.weight.data.asnumpy().reshape((-1,))[:10])
pt_res = pt_resnet(torch.from_numpy(inp))
ms_res = ms_resnet(ms.Tensor(inp))
print("========= pt_resnet res ==========")
print(pt_res)
print("========= ms_resnet res ==========")
print(ms_res)
print("diff", np.max(np.abs(pt_res.detach().numpy() - ms_res.asnumpy())))
pth_path = "resnet.pth"
ckpt_path = "resnet50.ckpt"
check_res(pth_path, ckpt_path)
注意做单元测试时,需要给Cell打训练或推理的标签,PyTorch 训练 .train()
,推理.eval()
,MindSpore训练.set_train()
,推理.set_train(False)
。
打印结果为:
========= pt_resnet conv1.weight ==========
[ 1.091892e-40 -1.819391e-39 3.509566e-40 -8.281730e-40 1.207908e-39
-3.576954e-41 -1.000796e-39 1.115791e-39 -1.077758e-39 -6.031427e-40]
========= ms_resnet conv1.weight ==========
[ 1.091892e-40 -1.819391e-39 3.509566e-40 -8.281730e-40 1.207908e-39
-3.576954e-41 -1.000796e-39 1.115791e-39 -1.077758e-39 -6.031427e-40]
========= pt_resnet res ==========
tensor([[-15.1945, -5.6529, 6.5738, 9.7807, -2.4615, 3.0365, -4.7216,
-11.1005, 2.7121, -9.3612],
[-14.2412, -5.9004, 5.6366, 9.7030, -1.6322, 2.6926, -3.7307,
-10.7582, 1.4195, -7.9930],
[-13.4795, -5.6582, 5.6432, 8.9152, -1.5169, 2.6958, -3.4469,
-10.5300, 1.3318, -8.1476],
[-13.6448, -5.4239, 5.8254, 9.3094, -2.1969, 2.7042, -4.1194,
-10.4388, 1.9331, -8.1746]], grad_fn=<AddmmBackward0>)
========= ms_resnet res ==========
[[-15.194535 -5.652934 6.5737996 9.780719 -2.4615316 3.0365033
-4.7215843 -11.100524 2.7121294 -9.361177 ]
[-14.24116 -5.9004383 5.6366115 9.702984 -1.6322318 2.69261
-3.7307222 -10.758192 1.4194587 -7.992969 ]
[-13.47945 -5.658216 5.6432185 8.915173 -1.5169426 2.6957715
-3.446888 -10.529953 1.3317728 -8.147601 ]
[-13.644804 -5.423854 5.825424 9.309403 -2.1969485 2.7042081
-4.119426 -10.438771 1.9330862 -8.174606 ]]
diff 2.861023e-06
可以看到最后的结果差不大,基本符合预期。 当结果差很大时,可在完成参数映射后,固定PyTorch和MindSpore的随机性,再使用工具:TroubleShooter API级别网络结果自动比较进行网络正向和反向的结果对比,提升定位效率。
推理流程
PyTorch | MindSpore |
|
|
执行 |
|
得到推理精度结果: |
得到推理精度结果: |
推理精度一致。
当推理结果不一致时,这里可借助工具TroubleShooter比较MindSpore和PyTorch网络输出是否一致比较PyTorch和MindSpore网络的推理结果,定位网络输出哪里开始不一致,提升迁移效率。
训练流程
PyTorch的训练流程参考pytoch resnet50 CIFAR-10的示例代码,日志文件和训练好的pth保存在resnet_pytorch_res。
对应的MindSpore代码:
import numpy as np
import mindspore as ms
from mindspore.train import Model
from mindspore import nn
from mindspore import Profiler
from src.dataset import create_dataset
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.config import config
from src.utils import init_env
from src.resnet import resnet50
def train_epoch(epoch, model, loss_fn, optimizer, data_loader):
model.set_train()
# Define forward function
def forward_fn(data, label):
logits = model(data)
loss = loss_fn(logits, label)
return loss, logits
# Get gradient function
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# Define function of one-step training
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
optimizer(grads)
return loss
dataset_size = data_loader.get_dataset_size()
for batch_idx, (data, target) in enumerate(data_loader):
loss = float(train_step(data, target).asnumpy())
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx, dataset_size,
100. * batch_idx / dataset_size, loss))
def test_epoch(model, data_loader, loss_func):
model.set_train(False)
test_loss = 0
correct = 0
for data, target in data_loader:
output = model(data)
test_loss += float(loss_func(output, target).asnumpy())
pred = np.argmax(output.asnumpy(), axis=1)
correct += (pred == target.asnumpy()).sum()
dataset_size = data_loader.get_dataset_size()
test_loss /= dataset_size
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
test_loss, 100. * correct / dataset_size))
@moxing_wrapper()
def train_net():
init_env(config)
if config.enable_profiling:
profiler = Profiler()
train_dataset = create_dataset(config.dataset_name, config.data_path, True, batch_size=config.batch_size,
image_size=(int(config.image_height), int(config.image_width)),
rank_size=40, rank_id=config.rank_id)
eval_dataset = create_dataset(config.dataset_name, config.data_path, False, batch_size=1,
image_size=(int(config.image_height), int(config.image_width)))
config.steps_per_epoch = train_dataset.get_dataset_size()
resnet = resnet50(num_classes=config.class_num)
optimizer = nn.Adam(resnet.trainable_params(), config.lr, weight_decay=config.weight_decay)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
for epoch in range(config.epoch_size):
train_epoch(epoch, train_net, loss_fn, optimizer, train_dataset)
test_epoch(resnet, eval_dataset, loss_fn)
print('Finished Training')
save_path = './resnet.ckpt'
ms.save_checkpoint(resnet, save_path)
if __name__ == '__main__':
train_net()