创建网络
神经网络模型由多个数据操作层组成,mindspore.nn
提供了各种网络基础模块。本章以构建LeNet-5网络为例,先展示使用mindspore.nn
建立神经网络模型,再展示使用mindvision.classification.models
快速构建LeNet-5网络模型。
mindvision.classification.models
是基于mindspore.nn
开发的网络模型接口,提供了一些经典且常用的网络模型,方便用户使用。
LeNet-5模型
LeNet-5是Yann LeCun教授于1998年提出的一种典型的卷积神经网络,在MNIST数据集上达到99.4%准确率,是CNN领域的第一篇经典之作。其模型结构如下图所示:
按照LeNet的网络结构,LeNet除去输入层共有7层,其中有2个卷积层,2个子采样层,3个全连接层。
定义模型类
上图中用C代表卷积层,用S代表采样层,用F代表全连接层。
图片的输入size固定在\(32*32\),为了获得良好的卷积效果,要求数字在图片的中央,所以输入\(32*32\)其实为\(28*28\)图片填充后的结果。另外不像CNN网络三通道的输入图片,LeNet图片的输入仅是规范化后的二值图像。网络的输出为0~9十个数字的预测概率,可以理解为输入图像属于0~9数字的可能性大小。
MindSpore的Cell
类是构建所有网络的基类,也是网络的基本单元。构建神经网络时,需要继承Cell
类,并重写__init__
方法和construct
方法。
[11]:
import mindspore.nn as nn
class LeNet5(nn.Cell):
"""
LeNet-5网络结构
"""
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
# 卷积层,输入的通道数为num_channel,输出的通道数为6,卷积核大小为5*5
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
# 卷积层,输入的通道数为6,输出的通道数为16,卷积核大小为5*5
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
# 全连接层,输入个数为16*5*5,输出个数为120
self.fc1 = nn.Dense(16 * 5 * 5, 120)
# 全连接层,输入个数为120,输出个数为84
self.fc2 = nn.Dense(120, 84)
# 全连接层,输入个数为84,分类的个数为num_class
self.fc3 = nn.Dense(84, num_class)
# ReLU激活函数
self.relu = nn.ReLU()
# 池化层
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
# 多维数组展平为一维数组
self.flatten = nn.Flatten()
def construct(self, x):
# 使用定义好的运算构建前向网络
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
接下来建立上面定义的神经网络模型,并查看该网络模型的结构。
[12]:
model = LeNet5()
print(model)
LeNet5<
(conv1): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(conv2): Conv2d<input_channels=6, output_channels=16, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
(fc1): Dense<input_channels=400, output_channels=120, has_bias=True>
(fc2): Dense<input_channels=120, output_channels=84, has_bias=True>
(fc3): Dense<input_channels=84, output_channels=10, has_bias=True>
(relu): ReLU<>
(max_pool2d): MaxPool2d<kernel_size=2, stride=2, pad_mode=VALID>
(flatten): Flatten<>
>
模型层
本小节内容首先将会介绍LeNet-5网络中使用到Cell
类的关键成员函数,然后通过实例化网络介绍如何利用Cell
类访问模型参数,更多Cell
类内容参考mindspore.nn接口。
nn.Conv2d
加入nn.Conv2d
层,给网络中加入卷积函数,帮助神经网络提取特征。
[13]:
import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
# 输入的通道数为1,输出的通道数为6,卷积核大小为5*5,使用normal算子初始化参数,不填充像素
conv2d = nn.Conv2d(1, 6, 5, has_bias=False, weight_init='normal', pad_mode='same')
input_x = Tensor(np.ones([1, 1, 32, 32]), mstype.float32)
print(conv2d(input_x).shape)
(1, 6, 32, 32)
nn.ReLU
加入nn.ReLU
层,给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。
[7]:
relu = nn.ReLU()
input_x = Tensor(np.array([-1, 2, -3, 2, -1]), mstype.float16)
output = relu(input_x)
print(output)
[0. 2. 0. 2. 0.]
nn.MaxPool2d
初始化nn.MaxPool2d
层,将6×28×28的张量降采样为6×7x7的张量。
[10]:
max_pool2d = nn.MaxPool2d(kernel_size=4, stride=4)
input_x = Tensor(np.ones([1, 6, 28, 28]), mstype.float32)
print(max_pool2d(input_x).shape)
(1, 6, 7, 7)
nn.Flatten
初始化nn.Flatten
层,将1×16×5×5的四维张量转换为400个连续元素的二维张量。
[11]:
flatten = nn.Flatten()
input_x = Tensor(np.ones([1, 16, 5, 5]), mstype.float32)
output = flatten(input_x)
print(output.shape)
(1, 400)
nn.Dense
初始化nn.Dense
层,对输入矩阵进行线性变换。
[8]:
dense = nn.Dense(400, 120, weight_init='normal')
input_x = Tensor(np.ones([1, 400]), mstype.float32)
output = dense(input_x)
print(output.shape)
(1, 120)
模型参数
网络内部的卷积层和全连接层等实例化后,即具有权重参数和偏置参数,这些参数会在训练过程中不断进行优化,在训练过程中可通过 get_parameters()
来查看网络各层的名字、形状、数据类型和是否反向计算等信息。
[10]:
for m in model.get_parameters():
print(f"layer:{m.name}, shape:{m.shape}, dtype:{m.dtype}, requeires_grad:{m.requires_grad}")
layer:backbone.conv1.weight, shape:(6, 1, 5, 5), dtype:Float32, requeires_grad:True
layer:backbone.conv2.weight, shape:(16, 6, 5, 5), dtype:Float32, requeires_grad:True
layer:backbone.fc1.weight, shape:(120, 400), dtype:Float32, requeires_grad:True
layer:backbone.fc1.bias, shape:(120,), dtype:Float32, requeires_grad:True
layer:backbone.fc2.weight, shape:(84, 120), dtype:Float32, requeires_grad:True
layer:backbone.fc2.bias, shape:(84,), dtype:Float32, requeires_grad:True
layer:backbone.fc3.weight, shape:(10, 84), dtype:Float32, requeires_grad:True
layer:backbone.fc3.bias, shape:(10,), dtype:Float32, requeires_grad:True
快速构建LeNet-5网络模型
上述介绍了使用mindspore.nn.cell
构建LeNet-5网络模型,在mindvision.classification.models
中已有构建好的网络模型接口,也可使用lenet
接口直接构建LeNet-5网络模型。
[12]:
from mindvision.classification.models import lenet
# num_classes表示分类的数量,pretrained表示是否使用预训练模型进行训练
model = lenet(num_classes=10, pretrained=False)
for m in model.get_parameters():
print(f"layer:{m.name}, shape:{m.shape}, dtype:{m.dtype}, requeires_grad:{m.requires_grad}")
layer:backbone.conv1.weight, shape:(6, 1, 5, 5), dtype:Float32, requeires_grad:True
layer:backbone.conv2.weight, shape:(16, 6, 5, 5), dtype:Float32, requeires_grad:True
layer:backbone.fc1.weight, shape:(120, 400), dtype:Float32, requeires_grad:True
layer:backbone.fc1.bias, shape:(120,), dtype:Float32, requeires_grad:True
layer:backbone.fc2.weight, shape:(84, 120), dtype:Float32, requeires_grad:True
layer:backbone.fc2.bias, shape:(84,), dtype:Float32, requeires_grad:True
layer:backbone.fc3.weight, shape:(10, 84), dtype:Float32, requeires_grad:True
layer:backbone.fc3.bias, shape:(10,), dtype:Float32, requeires_grad:True