构建网络
MindSpore的Cell
类是构建所有网络的基类,也是网络的基本单元。自定义网络时,需要继承Cell
类,本章主要介绍网络基本单元Cell
和自定义前向网络。
本章主要介绍前向网络模型的构建和网络模型的基本单元,因为不涉及到训练,因此没有反向传播和反向图。
网络基本单元 Cell
当用户需要自定义网络时,需要继承Cell
类,并重写__init__
方法和construct
方法。损失函数、优化器和模型层等本质上也属于网络结构,也需要继承Cell
类才能实现功能,同样用户也可以根据业务需求自定义这部分内容。
下面介绍Cell
的关键成员函数。
construct方法
Cell
类重写了__call__
方法,在Cell
类的实例被调用时,会执行construct
方法。网络结构在construct
方法里面定义。
如下样例中,构建了一个简单的卷积网络,卷积网络在__init__
中定义,在construct
方法传入输入数据x
执行卷积计算,并返回计算结果。
[1]:
from mindspore import nn
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(10, 20, 3, has_bias=True, weight_init='normal')
def construct(self, x):
out = self.conv(x)
return out
获取网络参数
nn.Cell
中返回参数的方法有parameters_dict
、get_parameters
和trainable_params
。
parameters_dict
:获取网络结构中所有参数,返回一个以key为参数名,value为参数值的OrderedDict。get_parameters
:获取网络结构中的所有参数,返回Cell中Parameter的迭代器。trainable_params
:获取Parameter中requires_grad
为True的属性,返回可训参数的列表。
如下示例分别使用上述方法获取网络参数并打印。
[2]:
net = Net()
# 获取网络结构中的所有参数
result = net.parameters_dict()
print("parameters_dict of result:\n", result)
# 获取网络结构中的所有参数
print("\nget_parameters of result:")
for m in net.get_parameters():
print(m)
# 获取可训练参数列表
result = net.trainable_params()
print("\ntrainable_params of result:\n", result)
parameters_dict of result:
OrderedDict([('conv.weight', Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True)), ('conv.bias', Parameter (name=conv.bias, shape=(20,), dtype=Float32, requires_grad=True))])
get_parameters of result:
Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True)
Parameter (name=conv.bias, shape=(20,), dtype=Float32, requires_grad=True)
trainable_params of result:
[Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True), Parameter (name=conv.bias, shape=(20,), dtype=Float32, requires_grad=True)]
相关属性
cells_and_names
cells_and_names
方法是一个迭代器,返回网络中每个Cell
的名字和它的内容本身。代码样例如下:
[3]:
net = Net()
for m in net.cells_and_names():
print(m)
('', Net<
(conv): Conv2d<input_channels=10, output_channels=20, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=normal, bias_init=zeros, format=NCHW>
>)
('conv', Conv2d<input_channels=10, output_channels=20, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=normal, bias_init=zeros, format=NCHW>)
set_grad
set_grad
用于指定网络是否需要计算梯度。在不传入参数调用时,默认设置requires_grad
为True,在执行前向网络时将会构建用于计算梯度的反向网络。TrainOneStepCell
和GradOperation
接口,无需使用set_grad
,其内部已实现。若用户需要自定义此类训练功能的接口,需要在其内部或者外部设置set_grad
。
[4]:
class CustomTrainOneStepCell(nn.Cell):
def __init__(self, network, optimizer, sens=1.0):
"""入参有三个:训练网络,优化器和反向传播缩放比例"""
super(CustomTrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network # 前向网络
self.network.set_grad() # 构建计算梯度的反向网络
self.optimizer = optimizer # 优化器
CustomTrainOneStepCell
代码详细内容可参见自定义训练与评估网络
set_train
set_train
接口指定模型是否为训练模式,在不传入参数调用时,默认设置的mode
属性为True。
在实现训练和推理结构不同的网络时可以通过training
属性区分训练和推理场景,当mode
设置为True时,为训练场景;当mode
设置为False时,为推理场景。
MindSpore中的nn.Dropout
算子,根据Cell
的mode
属性区分了两种执行逻辑,mode
为False时直接返回输入,mode
为True时执行算子。
[5]:
import numpy as np
import mindspore as ms
x = ms.Tensor(np.ones([2, 2, 3]), ms.float32)
net = nn.Dropout(keep_prob=0.7)
# 执行训练
net.set_train()
output = net(x)
print("training result:\n", output)
# 执行推理
net.set_train(mode=False)
output = net(x)
print("\ninfer result:\n", output)
training result:
[[[1.4285715 1.4285715 1.4285715]
[1.4285715 0. 0. ]]
[[1.4285715 1.4285715 1.4285715]
[1.4285715 1.4285715 1.4285715]]]
infer result:
[[[1. 1. 1.]
[1. 1. 1.]]
[[1. 1. 1.]
[1. 1. 1.]]]
to_float
to_float
接口递归地在配置了当前Cell
和所有子Cell
的强制转换类型,以使当前网络结构以使用特定的Float类型运行,通常在混合精度场景使用。
如下示例分别对nn.dense
层使用float32类型和float16类型进行运算,并打印输出结果的数据类型。
[6]:
import numpy as np
from mindspore import nn
import mindspore as ms
# float32进行计算
x = ms.Tensor(np.ones([2, 2, 3]), ms.float32)
net = nn.Dense(3, 2)
output = net(x)
print(output.dtype)
# float16进行计算
net1 = nn.Dense(3, 2)
net1.to_float(ms.float16)
output = net1(x)
print(output.dtype)
Float32
Float16
构建网络
构建网络时,可以继承nn.Cell
类,在__init__
构造函数中申明各个层的定义,在construct
中实现层之间的连接关系,完成神经网络的前向构造。
mindspore.ops
模块提供了基础算子的实现,如神经网络算子、数组算子和数学算子等。
mindspore.nn
模块实现了对基础算子的进一步封装,用户可以根据需要,灵活使用不同的算子。
同时,为了更好地构建和管理复杂的网络,mindspore.nn
提供了两种容器对网络中的子模块或模型层进行管理,分别为nn.CellList
和nn.SequentialCell
两种方式。
Ops算子构建网络
mindspore.ops模块提供了基础算子的实现,如神经网络算子、数组算子和数学算子等。
用户可使用mindspore.ops
中的算子来构建一个简单的算法 \(f(x)=x^2+w\),示例如下:
[7]:
import numpy as np
import mindspore as ms
from mindspore import nn, ops
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.mul = ops.Mul()
self.add = ops.Add()
self.weight = ms.Parameter(ms.Tensor(np.array([2, 2, 2]), ms.float32))
def construct(self, x):
return self.add(self.mul(x, x), self.weight)
net = Net()
input = ms.Tensor(np.array([1, 2, 3]), ms.float32)
output = net(input)
print(output)
[ 3. 6. 11.]
nn层构建网络
尽管mindspore.ops
模块提供的多样算子可以基本满足网络构建的诉求,但为了在复杂的深度网络中提供更方便易用的接口,mindspore.nn对mindspore.ops
算子进行了进一步的封装。
mindspore.nn
模块主要包括神经网络(neural network)中常用的卷积层(如nn.Conv2d
)、池化层(nn.MaxPool2d
)、非线性激活函数(如nn.ReLU
)、损失函数(如nn.LossBase
)、优化器(如nn.Momentum
)等,为用户的使用提供了便利。
下面示例代码中,使用mindspore.nn
模块构建一个Conv + Batch Normalization + ReLu模型网络。
[8]:
import numpy as np
from mindspore import nn
class ConvBNReLU(nn.Cell):
def __init__(self):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(3, 64, 3)
self.bn = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
out = self.relu(x)
return out
net = ConvBNReLU()
print(net)
ConvBNReLU<
(conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
(relu): ReLU<>
>
容器构建网络
为了便于管理和组成更复杂的网络,mindspore.nn
提供了容器对网络中的子模型块或模型层进行管理,有nn.CellList
和nn.SequentialCell
两种方式。
CellList构建网络
使用nn.CellList
构造的Cell既可以是模型层,也可以是构建的网络子块。nn.CellList
支持append
、extend
方法和insert
方法三种方法。
在运行网络时,可以在construct方法里,使用for循环,运行输出结果。
append(cell)
:在列表末尾添加一个cell。extend(cells)
:将cells添加至列表末尾。insert(index, cell)
:在列表给定的索引之前插入给定的cell。
如下使用nn.CellList
构建并执行一个网络,依次包含一个之前定义的模型子块ConvBNReLU、一个Conv2d层、一个BatchNorm2d层和一个ReLU层:
[9]:
import numpy as np
import mindspore as ms
from mindspore import nn
class MyNet(nn.Cell):
def __init__(self):
super(MyNet, self).__init__()
layers = [ConvBNReLU()]
# 使用CellList对网络进行管理
self.build_block = nn.CellList(layers)
# 使用append方法添加Conv2d层和ReLU层
self.build_block.append(nn.Conv2d(64, 4, 4))
self.build_block.append(nn.ReLU())
# 使用insert方法在Conv2d层和ReLU层中间插入BatchNorm2d
self.build_block.insert(-1, nn.BatchNorm2d(4))
def construct(self, x):
# for循环执行网络
for layer in self.build_block:
x = layer(x)
return x
net = MyNet()
print(net)
MyNet<
(build_block): CellList<
(0): ConvBNReLU<
(conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
(relu): ReLU<>
>
(1): Conv2d<input_channels=64, output_channels=4, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(2): BatchNorm2d<num_features=4, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.2.gamma, shape=(4,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.2.beta, shape=(4,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.2.moving_mean, shape=(4,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.2.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)>
(3): ReLU<>
>
>
把数据输入到网络模型中:
[10]:
input = ms.Tensor(np.ones([1, 3, 64, 32]), ms.float32)
output = net(input)
print(output.shape)
(1, 4, 64, 32)
SequentialCell构建网络
使用nn.SequentialCell
构造Cell顺序容器,支持子模块以List或OrderedDict格式作为输入。
不同于nn.CellList
的是,nn.SequentialCell
类内部实现了construct
方法,可以直接输出结果。
如下示例使用nn.SequentialCell
构建一个网络,输入为List,网络结构依次包含一个之前定义的模型子块ConvBNReLU、一个Conv2d层、一个BatchNorm2d层和一个ReLU层:
[11]:
import numpy as np
import mindspore as ms
from mindspore import nn
class MyNet(nn.Cell):
def __init__(self):
super(MyNet, self).__init__()
layers = [ConvBNReLU()]
layers.extend([nn.Conv2d(64, 4, 4),
nn.BatchNorm2d(4),
nn.ReLU()])
self.build_block = nn.SequentialCell(layers) # 使用SequentialCell对网络进行管理
def construct(self, x):
return self.build_block(x)
net = MyNet()
print(net)
MyNet<
(build_block): SequentialCell<
(0): ConvBNReLU<
(conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
(relu): ReLU<>
>
(1): Conv2d<input_channels=64, output_channels=4, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(2): BatchNorm2d<num_features=4, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.2.gamma, shape=(4,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.2.beta, shape=(4,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.2.moving_mean, shape=(4,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.2.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)>
(3): ReLU<>
>
>
把数据输入到网络模型中:
[12]:
input = ms.Tensor(np.ones([1, 3, 64, 32]), ms.float32)
output = net(input)
print(output.shape)
(1, 4, 64, 32)
如下示例使用nn.SequentialCell
构建一个网络,输入为OrderedDict:
[13]:
import numpy as np
import mindspore as ms
from mindspore import nn
from collections import OrderedDict
class MyNet(nn.Cell):
def __init__(self):
super(MyNet, self).__init__()
layers = OrderedDict()
# 将cells加入字典
layers["ConvBNReLU"] = ConvBNReLU()
layers["conv"] = nn.Conv2d(64, 4, 4)
layers["norm"] = nn.BatchNorm2d(4)
layers["relu"] = nn.ReLU()
# 使用SequentialCell对网络进行管理
self.build_block = nn.SequentialCell(layers)
def construct(self, x):
return self.build_block(x)
net = MyNet()
print(net)
input = ms.Tensor(np.ones([1, 3, 64, 32]), ms.float32)
output = net(input)
print(output.shape)
MyNet<
(build_block): SequentialCell<
(ConvBNReLU): ConvBNReLU<
(conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.ConvBNReLU.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.ConvBNReLU.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.ConvBNReLU.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.ConvBNReLU.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
(relu): ReLU<>
>
(conv): Conv2d<input_channels=64, output_channels=4, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(norm): BatchNorm2d<num_features=4, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.norm.gamma, shape=(4,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.norm.beta, shape=(4,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.norm.moving_mean, shape=(4,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.norm.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)>
(relu): ReLU<>
>
>
(1, 4, 64, 32)
nn与ops关系
mindspore.nn
模块是Python实现的模型组件,对低阶API的封装,主要包括神经网络模型相关的各种模型层、损失函数、优化器等。
同时mindspore.nn
也提供了部分与mindspore.ops
算子同名的接口,主要作用是对mindspore.ops
算子进行进一步封装,为用户提供更友好的API。用户也可使用mindspore.ops
算子根据实际场景实现自定义的网络。
如下示例使用mindspore.ops.Conv2D
算子实现卷积计算功能,即nn.Conv2d
算子功能。
[14]:
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms
from mindspore.common.initializer import initializer
class Net(nn.Cell):
def __init__(self, in_channels=10, out_channels=20, kernel_size=3):
super(Net, self).__init__()
self.conv2d = ops.Conv2D(out_channels, kernel_size)
self.bias_add = ops.BiasAdd()
self.weight = ms.Parameter(
initializer('normal', [out_channels, in_channels, kernel_size, kernel_size]),
name='conv.weight')
self.bias = ms.Parameter(initializer('normal', [out_channels]), name='conv.bias')
def construct(self, x):
"""输入数据x"""
output = self.conv2d(x, self.weight)
output = self.bias_add(output, self.bias)
return output