创建网络

在线运行下载Notebook下载样例代码查看源文件

神经网络模型由多个数据操作层组成,mindspore.nn提供了各种网络基础模块。本章以构建LeNet-5网络为例,先展示使用mindspore.nn建立神经网络模型,再展示使用mindvision.classification.models快速构建LeNet-5网络模型。

mindvision.classification.models是基于mindspore.nn开发的网络模型接口,提供了一些经典且常用的网络模型,方便用户使用。

LeNet-5模型

LeNet-5是Yann LeCun教授于1998年提出的一种典型的卷积神经网络,在MNIST数据集上达到99.4%准确率,是CNN领域的第一篇经典之作。其模型结构如下图所示:

LeNet-5

按照LeNet的网络结构,LeNet除去输入层共有7层,其中有2个卷积层,2个子采样层,3个全连接层。

定义模型类

上图中用C代表卷积层,用S代表采样层,用F代表全连接层。

图片的输入size固定在\(32*32\),为了获得良好的卷积效果,要求数字在图片的中央,所以输入\(32*32\)其实为\(28*28\)图片填充后的结果。另外不像CNN网络三通道的输入图片,LeNet图片的输入仅是规范化后的二值图像。网络的输出为0~9十个数字的预测概率,可以理解为输入图像属于0~9数字的可能性大小。

MindSpore的Cell类是构建所有网络的基类,也是网络的基本单元。构建神经网络时,需要继承Cell类,并重写__init__方法和construct方法。

[11]:
import mindspore.nn as nn

class LeNet5(nn.Cell):
    """
    LeNet-5网络结构
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # 卷积层,输入的通道数为num_channel,输出的通道数为6,卷积核大小为5*5
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        # 卷积层,输入的通道数为6,输出的通道数为16,卷积核大小为5*5
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        # 全连接层,输入个数为16*5*5,输出个数为120
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        # 全连接层,输入个数为120,输出个数为84
        self.fc2 = nn.Dense(120, 84)
        # 全连接层,输入个数为84,分类的个数为num_class
        self.fc3 = nn.Dense(84, num_class)
        # ReLU激活函数
        self.relu = nn.ReLU()
        # 池化层
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        # 多维数组展平为一维数组
        self.flatten = nn.Flatten()

    def construct(self, x):
        # 使用定义好的运算构建前向网络
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

接下来建立上面定义的神经网络模型,并查看该网络模型的结构。

[12]:
model = LeNet5()

print(model)
LeNet5<
  (conv1): Conv2d<input_channels=1, output_channels=6, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
  (conv2): Conv2d<input_channels=6, output_channels=16, kernel_size=(5, 5), stride=(1, 1), pad_mode=valid, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
  (fc1): Dense<input_channels=400, output_channels=120, has_bias=True>
  (fc2): Dense<input_channels=120, output_channels=84, has_bias=True>
  (fc3): Dense<input_channels=84, output_channels=10, has_bias=True>
  (relu): ReLU<>
  (max_pool2d): MaxPool2d<kernel_size=2, stride=2, pad_mode=VALID>
  (flatten): Flatten<>
  >

模型层

本小节内容首先将会介绍LeNet-5网络中使用到Cell类的关键成员函数,然后通过实例化网络介绍如何利用Cell类访问模型参数,更多Cell类内容参考mindspore.nn接口

nn.Conv2d

加入nn.Conv2d层,给网络中加入卷积函数,帮助神经网络提取特征。

[13]:
import numpy as np

from mindspore import Tensor
from mindspore import dtype as mstype

# 输入的通道数为1,输出的通道数为6,卷积核大小为5*5,使用normal算子初始化参数,不填充像素
conv2d = nn.Conv2d(1, 6, 5, has_bias=False, weight_init='normal', pad_mode='same')
input_x = Tensor(np.ones([1, 1, 32, 32]), mstype.float32)

print(conv2d(input_x).shape)
(1, 6, 32, 32)

nn.ReLU

加入nn.ReLU层,给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。

[7]:
relu = nn.ReLU()

input_x = Tensor(np.array([-1, 2, -3, 2, -1]), mstype.float16)

output = relu(input_x)
print(output)
[0. 2. 0. 2. 0.]

nn.MaxPool2d

初始化nn.MaxPool2d层,将6×28×28的张量降采样为6×7x7的张量。

[10]:
max_pool2d = nn.MaxPool2d(kernel_size=4, stride=4)
input_x = Tensor(np.ones([1, 6, 28, 28]), mstype.float32)

print(max_pool2d(input_x).shape)
(1, 6, 7, 7)

nn.Flatten

初始化nn.Flatten层,将1×16×5×5的四维张量转换为400个连续元素的二维张量。

[11]:
flatten = nn.Flatten()
input_x = Tensor(np.ones([1, 16, 5, 5]), mstype.float32)
output = flatten(input_x)

print(output.shape)
(1, 400)

nn.Dense

初始化nn.Dense层,对输入矩阵进行线性变换。

[8]:
dense = nn.Dense(400, 120, weight_init='normal')
input_x = Tensor(np.ones([1, 400]), mstype.float32)
output = dense(input_x)

print(output.shape)
(1, 120)

模型参数

网络内部的卷积层和全连接层等实例化后,即具有权重参数和偏置参数,这些参数会在训练过程中不断进行优化,在训练过程中可通过 get_parameters() 来查看网络各层的名字、形状、数据类型和是否反向计算等信息。

[10]:
for m in model.get_parameters():
    print(f"layer:{m.name}, shape:{m.shape}, dtype:{m.dtype}, requeires_grad:{m.requires_grad}")
layer:backbone.conv1.weight, shape:(6, 1, 5, 5), dtype:Float32, requeires_grad:True
layer:backbone.conv2.weight, shape:(16, 6, 5, 5), dtype:Float32, requeires_grad:True
layer:backbone.fc1.weight, shape:(120, 400), dtype:Float32, requeires_grad:True
layer:backbone.fc1.bias, shape:(120,), dtype:Float32, requeires_grad:True
layer:backbone.fc2.weight, shape:(84, 120), dtype:Float32, requeires_grad:True
layer:backbone.fc2.bias, shape:(84,), dtype:Float32, requeires_grad:True
layer:backbone.fc3.weight, shape:(10, 84), dtype:Float32, requeires_grad:True
layer:backbone.fc3.bias, shape:(10,), dtype:Float32, requeires_grad:True

快速构建LeNet-5网络模型

上述介绍了使用mindspore.nn.cell构建LeNet-5网络模型,在mindvision.classification.models中已有构建好的网络模型接口,也可使用lenet接口直接构建LeNet-5网络模型。

[12]:
from mindvision.classification.models import lenet

# num_classes表示分类的数量,pretrained表示是否使用预训练模型进行训练
model = lenet(num_classes=10, pretrained=False)

for m in model.get_parameters():
    print(f"layer:{m.name}, shape:{m.shape}, dtype:{m.dtype}, requeires_grad:{m.requires_grad}")
layer:backbone.conv1.weight, shape:(6, 1, 5, 5), dtype:Float32, requeires_grad:True
layer:backbone.conv2.weight, shape:(16, 6, 5, 5), dtype:Float32, requeires_grad:True
layer:backbone.fc1.weight, shape:(120, 400), dtype:Float32, requeires_grad:True
layer:backbone.fc1.bias, shape:(120,), dtype:Float32, requeires_grad:True
layer:backbone.fc2.weight, shape:(84, 120), dtype:Float32, requeires_grad:True
layer:backbone.fc2.bias, shape:(84,), dtype:Float32, requeires_grad:True
layer:backbone.fc3.weight, shape:(10, 84), dtype:Float32, requeires_grad:True
layer:backbone.fc3.bias, shape:(10,), dtype:Float32, requeires_grad:True