建立神经网络
Ascend
GPU
CPU
入门
模型开发
神经网络模型由多个数据操作层组成,mindspore.nn
提供了各种网络基础模块。
在以下内容中,我们将以构建LeNet网络为例,展示MindSpore是如何建立神经网络模型的。
首先导入本文档需要的模块和接口,如下所示:
[1]:
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
定义模型类
MindSpore的Cell
类是构建所有网络的基类,也是网络的基本单元。当用户需要神经网络时,需要继承Cell
类,并重写__init__
方法和construct
方法。
[2]:
class LeNet5(nn.Cell):
"""
Lenet网络结构
"""
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
# 定义所需要的运算
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
# 使用定义好的运算构建前向网络
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
模型层
本小节内容首先将会介绍LeNet网络中使用到Cell
类的关键成员函数,然后通过实例化网络介绍如何利用Cell
类访问模型参数。
nn.Conv2d
加入nn.Conv2d
层,给网络中加入卷积函数,帮助神经网络提取特征。
[3]:
conv2d = nn.Conv2d(1, 6, 5, has_bias=False, weight_init='normal', pad_mode='valid')
input_x = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
print(conv2d(input_x).shape)
(1, 6, 28, 28)
nn.ReLU
加入nn.ReLU
层,给网络中加入非线性的激活函数,帮助神经网络学习各种复杂的特征。
[4]:
relu = nn.ReLU()
input_x = Tensor(np.array([-1, 2, -3, 2, -1]), mindspore.float16)
output = relu(input_x)
print(output)
[0. 2. 0. 2. 0.]
nn.MaxPool2d
初始化nn.MaxPool2d
层,将6×28×28的数组降采样为6×14×14的数组。
[5]:
max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
input_x = Tensor(np.ones([1, 6, 28, 28]), mindspore.float32)
print(max_pool2d(input_x).shape)
(1, 6, 14, 14)
nn.Flatten
初始化nn.Flatten
层,将16×5×5的数组转换为400个连续数组。
[6]:
flatten = nn.Flatten()
input_x = Tensor(np.ones([1, 16, 5, 5]), mindspore.float32)
output = flatten(input_x)
print(output.shape)
(1, 400)
nn.Dense
初始化nn.Dense
层,对输入矩阵进行线性变换。
[7]:
dense = nn.Dense(400, 120, weight_init='normal')
input_x = Tensor(np.ones([1, 400]), mindspore.float32)
output = dense(input_x)
print(output.shape)
(1, 120)
模型参数
网络内部的卷积层和全连接层等实例化后,即具有权重和偏置,这些权重和偏置参数会在之后训练中进行优化。nn.Cell
中使用parameters_and_names()
方法访问所有参数。
在示例中,我们遍历每个参数,并打印网络各层名字和属性。
[8]:
model = LeNet5()
for m in model.parameters_and_names():
print(m)
('conv1.weight', Parameter (name=conv1.weight, shape=(6, 1, 5, 5), dtype=Float32, requires_grad=True))
('conv2.weight', Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True))
('fc1.weight', Parameter (name=fc1.weight, shape=(120, 400), dtype=Float32, requires_grad=True))
('fc1.bias', Parameter (name=fc1.bias, shape=(120,), dtype=Float32, requires_grad=True))
('fc2.weight', Parameter (name=fc2.weight, shape=(84, 120), dtype=Float32, requires_grad=True))
('fc2.bias', Parameter (name=fc2.bias, shape=(84,), dtype=Float32, requires_grad=True))
('fc3.weight', Parameter (name=fc3.weight, shape=(10, 84), dtype=Float32, requires_grad=True))
('fc3.bias', Parameter (name=fc3.bias, shape=(10,), dtype=Float32, requires_grad=True))