迁移脚本
概述
本文档介绍如何将网络脚本从TensorFlow或PyTorch框架迁移至MindSpore。
TensorFlow脚本迁移MindSpore
通过读TensorBoard图,进行脚本迁移。
以TensorFlow实现的PoseNet为例,演示如何利用TensorBoard读图,编写MindSpore代码,将TensorFlow模型迁移到MindSpore上。
此处提到的PoseNet代码基于Python2,需要做一些语法修改才能在Python3上运行,具体修改内容不予赘述。
改写代码,利用`tf.summary`接口保存TensorBoard所需的log,并启动TensorBoard。打开的TensorBoard如图所示。图例仅供参考,由于log生成方式的差异,TensorBoard展示的图也可能有所不同。
找到3个输入的Placeholder,通过看图并阅读代码得知,第二、第三个输入都只在计算loss时使用。
至此,我们可以初步将网络模型的构造划分为三步:
第一步,在网络的三个输入中,第一个输入将在backbone中计算出六个输出;
第二步,上一步结果与第二、第三个输入在loss子网中计算loss;
第三步,利用`TrainOneStepCell`自动微分构造反向网络,并根据TensorFlow工程中提供的Adam优化器及其属性,写出对应的MindSpore优化器来更新参数。网络脚本骨干可写作:
[1]:
from mindspore import nn
from mindspore.nn import TrainOneStepCell
from mindspore.nn import Adam
from mindspore.common.initializer import Normal
# combine backbone and loss
class PoseNetLossCell(nn.Cell):
    """Chain the PoseNet backbone and its loss into one Cell whose output is the loss.

    Args:
        backbone: Cell mapping ``input_1`` to the six outputs
            ``(p1_x, p1_q, p2_x, p2_q, p3_x, p3_q)``.
        loss: Cell computing the loss from those six outputs plus
            ``input_2`` and ``input_3``.
    """

    def __init__(self, backbone, loss):
        super(PoseNetLossCell, self).__init__()
        self.pose_net = backbone
        self.loss = loss

    def construct(self, input_1, input_2, input_3):
        # Only input_1 goes through the backbone; input_2 and input_3 are
        # consumed by the loss alone.
        p1_x, p1_q, p2_x, p2_q, p3_x, p3_q = self.pose_net(input_1)
        loss = self.loss(p1_x, p1_q, p2_x, p2_q, p3_x, p3_q, input_2, input_3)
        return loss
# define backbone
class PoseNet(nn.Cell):
    """Placeholder backbone used only to sketch the training-script skeleton.

    A single Dense layer stands in for the real network; the full backbone
    is written out later in this document.
    """
    def __init__(self):
        super(PoseNet, self).__init__()
        # Dense(1 -> 6) with Normal(0.02)-initialized weight and bias.
        self.fc = nn.Dense(1, 6, Normal(0.02), Normal(0.02))
    def construct(self, input_1):
        """Placeholder: map input_1 to six outputs (p1_x, p1_q, p2_x, p2_q, p3_x, p3_q)."""
        # NOTE(review): this unpacks the Dense output tensor along its first
        # axis, so it only yields six values when that axis has length 6 —
        # illustrative only, not meant to run as-is.
        p1_x, p1_q, p2_x, p2_q, p3_x, p3_q = self.fc(input_1)
        return p1_x, p1_q, p2_x, p2_q, p3_x, p3_q
# define loss
class PoseNetLoss(nn.Cell):
    """Placeholder loss Cell; at this stage only its call signature matters.

    A complete implementation appears later in this document.
    """
    def __init__(self):
        super(PoseNetLoss, self).__init__()
    def construct(self, p1_x, p1_q, p2_x, p2_q, p3_x, p3_q, poses_x, poses_q):
        """Placeholder: compute the loss from the six backbone outputs and the
        ground-truth inputs poses_x / poses_q."""
        # NOTE(review): ``loss`` is never assigned in this method. In plain
        # Python it would resolve to the module-level ``loss = PoseNetLoss()``
        # defined below (the Cell itself, not a loss value) — replace with the
        # real computation before running.
        return loss
# define network
backbone = PoseNet()
loss = PoseNetLoss()
# Wrap backbone + loss so that the Cell's output is the training loss.
net_with_loss = PoseNetLossCell(backbone=backbone, loss=loss)
# Adam configured with the hyper-parameters of the original TensorFlow project.
opt = Adam(
    net_with_loss.trainable_params(),
    learning_rate=0.001,
    beta1=0.9,
    beta2=0.999,
    eps=1e-08,
    use_locking=False,
)
# TrainOneStepCell builds the backward graph automatically and applies `opt`.
net_with_grad = TrainOneStepCell(network=net_with_loss, optimizer=opt)
接下来,我们来具体实现backbone中的计算逻辑。
第一个输入首先经过了一个名为conv1的子图,通过看图可得,其中计算逻辑为:
输入->Conv2D->BiasAdd->ReLU。虽然从图上看,BiasAdd后的算子名为conv1,但其实际执行的是ReLU。
这样一来,第一个子图conv1可以定义如下,具体参数需与原工程中的参数对齐:
class Conv1(nn.Cell):
    """First backbone sub-graph ``conv1``: Conv2D -> BiasAdd -> ReLU.

    The BiasAdd is folded into the convolution via ``has_bias=True``. The
    defaults follow the backbone's conv1 layer, ``ConvReLU(3, 7, 64, 2)``:
    3 -> 64 channels, 7x7 kernel, stride 2.

    Args:
        channel_in: number of input channels.
        channel_out: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
    """

    def __init__(self, channel_in=3, channel_out=64, kernel_size=7, stride=2):
        super(Conv1, self).__init__()
        # Conv2d requires in/out channels and kernel size, so they must be given
        # explicitly; use the `nn.` prefix because Conv2d/ReLU are not imported
        # by name at this point.
        self.conv = nn.Conv2d(channel_in, channel_out, kernel_size, stride, has_bias=True)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x
通过观察TensorBoard图和代码,我们不难发现,原TensorFlow工程中定义的conv这一类型的子网,可以复写为MindSpore的子网,减少重复代码。
TensorFlow工程conv子网定义:
def conv(self, input, k_h, k_w, c_o, s_h, s_w, name, relu=True, padding=DEFAULT_PADDING, group=1, biased=True):
    """Convolution layer helper from the original TensorFlow PoseNet project.

    Builds conv2d (optionally grouped) -> optional bias add -> optional ReLU
    inside ``tf.variable_scope(name)``.

    Args:
        input: input tensor. Strides are passed as [1, s_h, s_w, 1] and the
            channel count is read from the last dimension, i.e. NHWC layout.
        k_h, k_w: kernel height / width.
        c_o: number of output channels.
        s_h, s_w: stride along height / width.
        name: variable-scope name; the ReLU op is also named after this scope.
        relu: if True, apply ReLU to the output.
        padding: padding mode for tf.nn.conv2d, checked by validate_padding.
        group: number of channel groups; must divide both c_i and c_o.
        biased: if True, add a learned bias vector of length c_o.

    Returns:
        The resulting output tensor.

    NOTE(review): quoted as-is from the Python 2 / pre-TF-1.0 code base.
    Under Python 3, ``c_i / group`` is true division (``//`` is presumably
    intended), and ``tf.split(3, group, x)`` / ``tf.concat(3, xs)`` use the
    old axis-first argument order; newer TF expects
    ``tf.split(x, group, axis=3)`` / ``tf.concat(xs, axis=3)``.
    """
    # Verify that the padding is acceptable
    self.validate_padding(padding)
    # Get the number of channels in the input
    c_i = input.get_shape()[-1]
    # Verify that the grouping parameter is valid
    assert c_i % group == 0
    assert c_o % group == 0
    # Convolution for a given input and kernel
    convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
    with tf.variable_scope(name) as scope:
        # Kernel shape: [k_h, k_w, input channels per group, c_o].
        kernel = self.make_var('weights', shape=[k_h, k_w, c_i / group, c_o])
        if group == 1:
            # This is the common-case. Convolve the input without any further complications.
            output = convolve(input, kernel)
        else:
            # Split the input into groups and then convolve each of them independently
            # (input is split on its channel axis, the kernel on its output-channel axis).
            input_groups = tf.split(3, group, input)
            kernel_groups = tf.split(3, group, kernel)
            output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
            # Concatenate the groups
            output = tf.concat(3, output_groups)
        # Add the biases
        if biased:
            biases = self.make_var('biases', [c_o])
            output = tf.nn.bias_add(output, biases)
        if relu:
            # ReLU non-linearity
            output = tf.nn.relu(output, name=scope.name)
        return output
则对应MindSpore子网定义如下:
[2]:
from mindspore import nn
from mindspore.nn import Conv2d, ReLU
class ConvReLU(nn.Cell):
    """Conv2d with bias followed by ReLU (no channel grouping).

    MindSpore counterpart of the TensorFlow ``conv`` helper used with
    ``relu=True, biased=True, group=1``.

    Note the argument order: the kernel size comes before the number of
    output channels.

    Args:
        channel_in: number of input channels.
        kernel_size: convolution kernel size.
        channel_out: number of output channels.
        strides: convolution stride.
    """

    def __init__(self, channel_in, kernel_size, channel_out, strides):
        super(ConvReLU, self).__init__()
        self.conv = Conv2d(
            in_channels=channel_in,
            out_channels=channel_out,
            kernel_size=kernel_size,
            stride=strides,
            has_bias=True,
        )
        self.relu = ReLU()

    def construct(self, x):
        return self.relu(self.conv(x))
那么,对照着TensorBoard中的数据流向与算子属性,backbone计算逻辑可编写如下:
from mindspore.nn import MaxPool2d
import mindspore.ops as ops
class LRN(nn.Cell):
    """Local response normalization Cell wrapping ``ops.LRN``.

    Args:
        radius: depth radius of the normalization window.
        alpha: scale factor.
        beta: exponent.
        bias: additive offset (default 1.0).
    """

    def __init__(self, radius, alpha, beta, bias=1.0):
        super(LRN, self).__init__()
        # ops.LRN orders its arguments (depth_radius, bias, alpha, beta); pass
        # them by keyword so the reordering relative to this constructor is explicit.
        self.lrn = ops.LRN(depth_radius=radius, bias=bias, alpha=alpha, beta=beta)

    def construct(self, x):
        normalized = self.lrn(x)
        return normalized
class PoseNet(nn.Cell):
    """GoogLeNet-style PoseNet backbone, transcribed from the TensorBoard graph.

    Only the stem and the first inception modules are written out; the
    remaining layers and the pose-regression outputs are elided ("etc").
    Inception branches are concatenated on axis 1, the channel axis of
    MindSpore's default NCHW layout.

    ConvReLU arguments are (channel_in, kernel_size, channel_out, strides);
    LRN arguments are (radius, alpha, beta).
    """

    def __init__(self):
        super(PoseNet, self).__init__()
        # Stem: conv1 -> pool1 -> norm1 -> reduction2 -> conv2 -> norm2 -> pool2.
        self.conv1 = ConvReLU(3, 7, 64, 2)
        self.pool1 = MaxPool2d(3, 2, pad_mode="SAME")
        self.norm1 = LRN(2, 2e-05, 0.75)
        self.reduction2 = ConvReLU(64, 1, 64, 1)
        self.conv2 = ConvReLU(64, 3, 192, 1)
        self.norm2 = LRN(2, 2e-05, 0.75)
        self.pool2 = MaxPool2d(3, 2, pad_mode="SAME")
        # Inception icp1: 192 input channels -> 64 + 128 + 32 + 32 = 256 output channels.
        self.icp1_reduction1 = ConvReLU(192, 1, 96, 1)
        self.icp1_out1 = ConvReLU(96, 3, 128, 1)
        self.icp1_reduction2 = ConvReLU(192, 1, 16, 1)
        self.icp1_out2 = ConvReLU(16, 5, 32, 1)
        self.icp1_pool = MaxPool2d(3, 1, pad_mode="SAME")
        # Pool-projection branch: a 1x1 conv, consistent with icp2_out3 and
        # icp3_out3 (the inception pool-projection layout), not a 5x5 conv.
        self.icp1_out3 = ConvReLU(192, 1, 32, 1)
        self.icp1_out0 = ConvReLU(192, 1, 64, 1)
        self.concat = ops.Concat(axis=1)
        # Inception icp2: 256 input channels -> 128 + 192 + 96 + 64 = 480 output channels.
        self.icp2_reduction1 = ConvReLU(256, 1, 128, 1)
        self.icp2_out1 = ConvReLU(128, 3, 192, 1)
        self.icp2_reduction2 = ConvReLU(256, 1, 32, 1)
        self.icp2_out2 = ConvReLU(32, 5, 96, 1)
        self.icp2_pool = MaxPool2d(3, 1, pad_mode="SAME")
        self.icp2_out3 = ConvReLU(256, 1, 64, 1)
        self.icp2_out0 = ConvReLU(256, 1, 128, 1)
        # Inception icp3: stride-2 max pool, then 480 input channels.
        self.icp3_in = MaxPool2d(3, 2, pad_mode="SAME")
        self.icp3_reduction1 = ConvReLU(480, 1, 96, 1)
        self.icp3_out1 = ConvReLU(96, 3, 208, 1)
        self.icp3_reduction2 = ConvReLU(480, 1, 16, 1)
        self.icp3_out2 = ConvReLU(16, 5, 48, 1)
        self.icp3_pool = MaxPool2d(3, 1, pad_mode="SAME")
        self.icp3_out3 = ConvReLU(480, 1, 64, 1)
        self.icp3_out0 = ConvReLU(480, 1, 192, 1)
        # etc
        # ...

    def construct(self, input_1):
        """Run the backbone on input_1 and return the six outputs
        (p1_x, p1_q, p2_x, p2_q, p3_x, p3_q)."""
        # Stem.
        x = self.conv1(input_1)
        x = self.pool1(x)
        x = self.norm1(x)
        x = self.reduction2(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.pool2(x)
        pool2 = x
        # icp1: four parallel branches over pool2, concatenated on channels.
        x = self.icp1_reduction1(x)
        x = self.icp1_out1(x)
        icp1_out1 = x
        icp1_reduction2 = self.icp1_reduction2(pool2)
        icp1_out2 = self.icp1_out2(icp1_reduction2)
        icp1_pool = self.icp1_pool(pool2)
        icp1_out3 = self.icp1_out3(icp1_pool)
        icp1_out0 = self.icp1_out0(pool2)
        icp2_in = self.concat((icp1_out0, icp1_out1, icp1_out2, icp1_out3))
        # etc
        # ...
        # NOTE: p1_x ... p3_q are produced by the elided layers above; this
        # skeleton does not run until those layers are written out.
        return p1_x, p1_q, p2_x, p2_q, p3_x, p3_q
相应的,loss计算逻辑可编写如下:
[3]:
from mindspore import ops
class PoseNetLoss(nn.Cell):
    """PoseNet training loss: weighted distances between predicted and true poses.

    For each of the three outputs, the distance to ``poses_x`` and to
    ``poses_q`` is weighted as (0.3, 150), (0.3, 150) and (1, 500)
    respectively, and all six terms are summed. Each distance is
    ``sqrt(sum((pred - target) ** 2))``, where ReduceSum is called without
    an axis.
    """

    def __init__(self):
        super(PoseNetLoss, self).__init__()
        self.sub = ops.Sub()
        self.square = ops.Square()
        self.reduce_sum = ops.ReduceSum()
        self.sqrt = ops.Sqrt()

    def _distance(self, pred, target):
        """Return sqrt(reduce_sum(square(pred - target)))."""
        diff = self.sub(pred, target)
        return self.sqrt(self.reduce_sum(self.square(diff)))

    def construct(self, p1_x, p1_q, p2_x, p2_q, p3_x, p3_q, poses_x, poses_q):
        # Accumulate strictly left to right, term by term.
        total = self._distance(p1_x, poses_x) * 0.3
        total = total + self._distance(p1_q, poses_q) * 150
        total = total + self._distance(p2_x, poses_x) * 0.3
        total = total + self._distance(p2_q, poses_q) * 150
        total = total + self._distance(p3_x, poses_x) * 1
        total = total + self._distance(p3_q, poses_q) * 500
        return total
最终,你的训练脚本应该类似如下所示:
import mindspore as ms
from mindspore import dataset as ds
import numpy as np
if __name__ == "__main__":
    epoch_size = 5
    backbone = PoseNet()
    loss = PoseNetLoss()
    net_with_loss = PoseNetLossCell(backbone, loss)
    opt = Adam(net_with_loss.trainable_params(), learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-08, use_locking=False)
    net_with_grad = TrainOneStepCell(net_with_loss, opt)
    # dataset define
    # Synthetic placeholder so `dataset` is defined and the script can run;
    # replace it with the real data pipeline. The three columns are fed
    # positionally to net_with_grad as (input_1, input_2, input_3).
    # NOTE(review): the 3x224x224 image size and the pose widths (3 for
    # position, 4 for quaternion) are assumptions — confirm against the data.
    num_samples, batch_size = 32, 8
    images = np.random.rand(num_samples, 3, 224, 224).astype(np.float32)
    poses_x = np.random.rand(num_samples, 3).astype(np.float32)
    poses_q = np.random.rand(num_samples, 4).astype(np.float32)
    dataset = ds.NumpySlicesDataset(
        (images, poses_x, poses_q),
        column_names=["image", "poses_x", "poses_q"],
        shuffle=True,
    )
    dataset = dataset.batch(batch_size, drop_remainder=True)
    model = ms.Model(net_with_grad)
    model.train(epoch_size, dataset)
这样,就基本完成了模型脚本从TensorFlow到MindSpore的迁移,接下来就是利用丰富的MindSpore工具和计算策略,对精度进行调优,在此不予详述。
PyTorch脚本迁移MindSpore
通过读PyTorch脚本,直接进行迁移。
PyTorch子网模块通常继承`torch.nn.Module`,MindSpore子网模块通常继承`mindspore.nn.Cell`;PyTorch子网模块的正向计算逻辑需要重写`forward`方法,MindSpore子网模块的正向计算逻辑需要重写`construct`方法。下面以常见的Bottleneck类在MindSpore下的迁移为例。
PyTorch工程代码
# defined in PyTorch
class Bottleneck(nn.Module):
    """Pre-activation bottleneck residual block: three (BN -> ReLU -> Conv) stages.

    Args:
        inplanes: channels of the input tensor.
        planes: channels of the output tensor; the inner width is planes // 4.
        stride: stride of the 3x3 convolution and of the projection shortcut.
        mode: 'UP' replaces the shortcut with ``squeeze_idt``; any other
            value uses a projection shortcut when shapes differ.
        k: channel-group size summed by ``squeeze_idt``.
        dilation: dilation (and padding) of the 3x3 convolution.
    """

    def __init__(self, inplanes, planes, stride=1, mode='NORM', k=1, dilation=1):
        super(Bottleneck, self).__init__()
        self.mode = mode
        self.relu = nn.ReLU(inplace=True)
        self.k = k
        mid = planes // 4
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, mid, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid)
        self.conv2 = nn.Conv2d(mid, mid, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn3 = nn.BatchNorm2d(mid)
        self.conv3 = nn.Conv2d(mid, planes, kernel_size=1, bias=False)
        # A projection is needed only outside 'UP' mode when channels or
        # resolution change between input and output.
        needs_projection = mode != 'UP' and (inplanes != planes or stride > 1)
        self.shortcut = (
            nn.Sequential(
                nn.BatchNorm2d(inplanes),
                self.relu,
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
            )
            if needs_projection
            else None
        )

    def _pre_act_forward(self, x):
        out = x
        stages = ((self.bn1, self.conv1), (self.bn2, self.conv2), (self.bn3, self.conv3))
        for bn, conv in stages:
            out = conv(self.relu(bn(out)))
        if self.mode == 'UP':
            residual = self.squeeze_idt(x)
        elif self.shortcut is not None:
            residual = self.shortcut(x)
        else:
            residual = x
        out += residual
        return out

    def squeeze_idt(self, idt):
        """Sum each group of k adjacent channels: (n, c, h, w) -> (n, c // k, h, w)."""
        n, c, h, w = idt.size()
        grouped = idt.view(n, c // self.k, self.k, h, w)
        return grouped.sum(2)

    def forward(self, x):
        return self._pre_act_forward(x)
根据PyTorch和MindSpore对卷积参数定义的区别,可以翻译成如下定义:
from mindspore import nn
import mindspore.ops as ops
# defined in MindSpore
class Bottleneck(nn.Cell):
    """MindSpore port of the PyTorch pre-activation Bottleneck block.

    Translation notes:
      * ``nn.BatchNorm2d(momentum=0.9)`` corresponds to PyTorch's default
        momentum of 0.1 (the two frameworks weight the running statistic
        in opposite ways).
      * PyTorch ``padding=p`` becomes ``pad_mode='pad', padding=p`` and
        ``bias=False`` becomes ``has_bias=False``.
      * PyTorch's ``inplace=True`` on ReLU has no counterpart and is dropped.
      * ``view`` / ``size`` / ``sum`` become Reshape / Shape / ReduceSum
        in ``squeeze_idt``.

    Args:
        inplanes: channels of the input tensor.
        planes: channels of the output tensor; the inner width is planes // 4.
        stride: stride of the 3x3 convolution and of the projection shortcut.
        k: channel-group size summed by ``squeeze_idt`` in 'UP' mode.
        dilation: dilation (and padding) of the 3x3 convolution.
        mode: 'UP' replaces the shortcut with ``squeeze_idt``. It is the last
            parameter (default 'NORM') so positional calls for the other
            arguments keep working.
    """

    def __init__(self, inplanes, planes, stride=1, k=1, dilation=1, mode='NORM'):
        super(Bottleneck, self).__init__()
        self.mode = mode
        self.relu = nn.ReLU()
        self.k = k
        btnk_ch = planes // 4
        self.bn1 = nn.BatchNorm2d(num_features=inplanes, momentum=0.9)
        self.conv1 = nn.Conv2d(in_channels=inplanes, out_channels=btnk_ch, kernel_size=1, pad_mode='pad', has_bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=btnk_ch, momentum=0.9)
        self.conv2 = nn.Conv2d(in_channels=btnk_ch, out_channels=btnk_ch, kernel_size=3, stride=stride, pad_mode='pad', padding=dilation, dilation=dilation, has_bias=False)
        self.bn3 = nn.BatchNorm2d(num_features=btnk_ch, momentum=0.9)
        self.conv3 = nn.Conv2d(in_channels=btnk_ch, out_channels=planes, kernel_size=1, pad_mode='pad', has_bias=False)
        # Operators used by squeeze_idt.
        self.shape = ops.Shape()
        self.reshape = ops.Reshape()
        self.reduce_sum = ops.ReduceSum()
        if mode == 'UP':
            self.shortcut = None
        elif inplanes != planes or stride > 1:
            self.shortcut = nn.SequentialCell([
                nn.BatchNorm2d(num_features=inplanes, momentum=0.9),
                nn.ReLU(),
                nn.Conv2d(in_channels=inplanes, out_channels=planes, kernel_size=1, stride=stride, pad_mode='pad', has_bias=False)])
        else:
            self.shortcut = None

    def _pre_act_forward(self, x):
        residual = x
        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)
        # Mirror the PyTorch residual logic, including the 'UP' squeeze path.
        if self.mode == 'UP':
            residual = self.squeeze_idt(x)
        elif self.shortcut is not None:
            residual = self.shortcut(residual)
        out += residual
        return out

    def squeeze_idt(self, idt):
        """Sum each group of k adjacent channels: (n, c, h, w) -> (n, c // k, h, w).

        Equivalent of PyTorch ``idt.view(n, c // k, k, h, w).sum(2)``.
        """
        n, c, h, w = self.shape(idt)
        return self.reduce_sum(self.reshape(idt, (n, c // self.k, self.k, h, w)), 2)

    def construct(self, x):
        out = self._pre_act_forward(x)
        return out
PyTorch的反向传播通常使用`loss.backward()`实现,参数更新通过`optimizer.step()`实现。在MindSpore中,这些操作不需要用户显式调用,可以交给`TrainOneStepCell`类完成反向传播和梯度更新。最后,训练脚本结构应如下所示:
import mindspore as ms
from mindspore.nn import TrainOneStepCell, WithLossCell

# number of training epochs passed to model.train below
epoch_size = 5
# define dataset
dataset = ...
# define backbone and loss
backbone = Net()
loss = NetLoss()
# combine backbone and loss: WithLossCell computes loss(backbone(data), label)
net_with_loss = WithLossCell(backbone, loss)
# define optimizer
opt = ...
# combine forward and backward: TrainOneStepCell computes the gradients and
# applies `opt`, replacing PyTorch's loss.backward() / optimizer.step()
net_with_grad = TrainOneStepCell(net_with_loss, opt)
# define model and train
model = ms.Model(net_with_grad)
model.train(epoch_size, dataset)
PyTorch和MindSpore在一些基础API的定义上比较相似,比如`mindspore.nn.SequentialCell`和`torch.nn.Sequential`;另外,一些算子API可能不尽相同。此处列举一些常见的API对照,更多信息可以参考MindSpore官网的MindSpore与PyTorch API对照表。
| PyTorch | MindSpore |
|---|---|
| `tensor.view()` | `mindspore.ops.operations.Reshape()(tensor)` |
| `tensor.size()` | `mindspore.ops.operations.Shape()(tensor)` |
| `tensor.sum(axis)` | `mindspore.ops.operations.ReduceSum()(tensor, axis)` |
| `torch.nn.Upsample[mode: nearest]` | `mindspore.ops.operations.ResizeNearestNeighbor` |
| `torch.nn.Upsample[mode: bilinear]` | `mindspore.ops.operations.ResizeBilinear` |
| `torch.nn.Linear` | `mindspore.nn.Dense` |
| `torch.nn.PixelShuffle` | `mindspore.ops.operations.DepthToSpace` |
值得注意的是,尽管`torch.nn.MaxPool2d`和`mindspore.nn.MaxPool2d`在接口定义上较为相似,但在Ascend上的训练过程中,MindSpore实际调用的是`MaxPoolWithArgMax`算子,该算子与TensorFlow的同名算子功能相同。因此,迁移过程中MaxPool层的输出在MindSpore与PyTorch之间不一致是正常现象,理论上不影响最终训练结果。