构建网络

下载Notebook下载样例代码查看源文件

MindSpore的Cell类是构建所有网络的基类,也是网络的基本单元。自定义网络时,需要继承Cell类,本章主要介绍网络基本单元Cell和自定义前向网络。

本章主要介绍前向网络模型的构建和网络模型的基本单元,因为不涉及到训练,因此没有反向传播和反向图。

learningrate.png

网络基本单元 Cell

当用户需要自定义网络时,需要继承Cell类,并重写__init__方法和construct方法。损失函数、优化器和模型层等本质上也属于网络结构,也需要继承Cell类才能实现功能,同样用户也可以根据业务需求自定义这部分内容。

下面介绍Cell的关键成员函数。

construct方法

Cell类重写了__call__方法,在Cell类的实例被调用时,会执行construct方法。网络结构在construct方法里面定义。

如下样例中,构建了一个简单的卷积网络,卷积网络在__init__中定义,在construct方法传入输入数据x执行卷积计算,并返回计算结果。

[1]:
from mindspore import nn


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(10, 20, 3, has_bias=True, weight_init='normal')

    def construct(self, x):
        out = self.conv(x)
        return out

获取网络参数

nn.Cell中返回参数的方法有parameters_dictget_parameterstrainable_params

  • parameters_dict:获取网络结构中所有参数,返回一个以key为参数名,value为参数值的OrderedDict。

  • get_parameters:获取网络结构中的所有参数,返回Cell中Parameter的迭代器。

  • trainable_params:获取Parameter中requires_grad为True的属性,返回可训参数的列表。

如下示例分别使用上述方法获取网络参数并打印。

[2]:
net = Net()

# 获取网络结构中的所有参数
result = net.parameters_dict()
print("parameters_dict of result:\n", result)

# 获取网络结构中的所有参数
print("\nget_parameters of result:")
for m in net.get_parameters():
    print(m)

# 获取可训练参数列表
result = net.trainable_params()
print("\ntrainable_params of result:\n", result)
parameters_dict of result:
 OrderedDict([('conv.weight', Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True)), ('conv.bias', Parameter (name=conv.bias, shape=(20,), dtype=Float32, requires_grad=True))])

get_parameters of result:
Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True)
Parameter (name=conv.bias, shape=(20,), dtype=Float32, requires_grad=True)

trainable_params of result:
 [Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True), Parameter (name=conv.bias, shape=(20,), dtype=Float32, requires_grad=True)]

相关属性

  1. cells_and_names

cells_and_names方法是一个迭代器,返回网络中每个Cell的名字和它的内容本身。代码样例如下:

[3]:
net = Net()
for m in net.cells_and_names():
    print(m)
('', Net<
  (conv): Conv2d<input_channels=10, output_channels=20, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=normal, bias_init=zeros, format=NCHW>
  >)
('conv', Conv2d<input_channels=10, output_channels=20, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=True, weight_init=normal, bias_init=zeros, format=NCHW>)
  1. set_grad

set_grad用于指定网络是否需要计算梯度。在不传入参数调用时,默认设置requires_grad为True,在执行前向网络时将会构建用于计算梯度的反向网络。TrainOneStepCellGradOperation接口,无需使用set_grad,其内部已实现。若用户需要自定义此类训练功能的接口,需要在其内部或者外部设置set_grad

[4]:
class CustomTrainOneStepCell(nn.Cell):
    def __init__(self, network, optimizer, sens=1.0):
        """入参有三个:训练网络,优化器和反向传播缩放比例"""
        super(CustomTrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network      # 前向网络
        self.network.set_grad()     # 构建计算梯度的反向网络
        self.optimizer = optimizer  # 优化器

CustomTrainOneStepCell代码详细内容可参见自定义训练与评估网络

  1. set_train

set_train接口指定模型是否为训练模式,在不传入参数调用时,默认设置的mode属性为True。

在实现训练和推理结构不同的网络时可以通过training属性区分训练和推理场景,当mode设置为True时,为训练场景;当mode设置为False时,为推理场景。

MindSpore中的nn.Dropout算子,根据Cellmode属性区分了两种执行逻辑,mode为False时直接返回输入,mode为True时执行算子。

[5]:
import numpy as np
import mindspore as ms

x = ms.Tensor(np.ones([2, 2, 3]), ms.float32)
net = nn.Dropout(keep_prob=0.7)

# 执行训练
net.set_train()
output = net(x)
print("training result:\n", output)

# 执行推理
net.set_train(mode=False)
output = net(x)
print("\ninfer result:\n", output)
training result:
 [[[1.4285715 1.4285715 1.4285715]
  [1.4285715 0.        0.       ]]

 [[1.4285715 1.4285715 1.4285715]
  [1.4285715 1.4285715 1.4285715]]]

infer result:
 [[[1. 1. 1.]
  [1. 1. 1.]]

 [[1. 1. 1.]
  [1. 1. 1.]]]
  1. to_float

to_float接口递归地在配置了当前Cell和所有子Cell的强制转换类型,以使当前网络结构以使用特定的Float类型运行,通常在混合精度场景使用。

如下示例分别对nn.dense层使用float32类型和float16类型进行运算,并打印输出结果的数据类型。

[6]:
import numpy as np
from mindspore import nn
import mindspore as ms

# float32进行计算
x = ms.Tensor(np.ones([2, 2, 3]), ms.float32)
net = nn.Dense(3, 2)
output = net(x)
print(output.dtype)

# float16进行计算
net1 = nn.Dense(3, 2)
net1.to_float(ms.float16)
output = net1(x)
print(output.dtype)
Float32
Float16

构建网络

构建网络时,可以继承nn.Cell类,在__init__构造函数中申明各个层的定义,在construct中实现层之间的连接关系,完成神经网络的前向构造。

mindspore.ops模块提供了基础算子的实现,如神经网络算子、数组算子和数学算子等。

mindspore.nn模块实现了对基础算子的进一步封装,用户可以根据需要,灵活使用不同的算子。

同时,为了更好地构建和管理复杂的网络,mindspore.nn提供了两种容器对网络中的子模块或模型层进行管理,分别为nn.CellListnn.SequentialCell两种方式。

Ops算子构建网络

mindspore.ops模块提供了基础算子的实现,如神经网络算子、数组算子和数学算子等。

用户可使用mindspore.ops中的算子来构建一个简单的算法 \(f(x)=x^2+w\),示例如下:

[7]:
import numpy as np
import mindspore as ms
from mindspore import nn, ops

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.mul = ops.Mul()
        self.add = ops.Add()
        self.weight = ms.Parameter(ms.Tensor(np.array([2, 2, 2]), ms.float32))

    def construct(self, x):
        return self.add(self.mul(x, x), self.weight)

net = Net()
input = ms.Tensor(np.array([1, 2, 3]), ms.float32)
output = net(input)

print(output)
[ 3.  6. 11.]

nn层构建网络

尽管mindspore.ops模块提供的多样算子可以基本满足网络构建的诉求,但为了在复杂的深度网络中提供更方便易用的接口,mindspore.nnmindspore.ops算子进行了进一步的封装。

mindspore.nn模块主要包括神经网络(neural network)中常用的卷积层(如nn.Conv2d)、池化层(nn.MaxPool2d)、非线性激活函数(如nn.ReLU)、损失函数(如nn.LossBase)、优化器(如nn.Momentum)等,为用户的使用提供了便利。

下面示例代码中,使用mindspore.nn模块构建一个Conv + Batch Normalization + ReLu模型网络。

[8]:
import numpy as np
from mindspore import nn

class ConvBNReLU(nn.Cell):
    def __init__(self):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        out = self.relu(x)
        return out

net = ConvBNReLU()
print(net)
ConvBNReLU<
  (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
  (bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
  (relu): ReLU<>
  >

容器构建网络

为了便于管理和组成更复杂的网络,mindspore.nn提供了容器对网络中的子模型块或模型层进行管理,有nn.CellListnn.SequentialCell两种方式。

  1. CellList构建网络

使用nn.CellList构造的Cell既可以是模型层,也可以是构建的网络子块。nn.CellList支持appendextend方法和insert方法三种方法。

在运行网络时,可以在construct方法里,使用for循环,运行输出结果。

  • append(cell):在列表末尾添加一个cell。

  • extend(cells):将cells添加至列表末尾。

  • insert(index, cell):在列表给定的索引之前插入给定的cell。

如下使用nn.CellList构建并执行一个网络,依次包含一个之前定义的模型子块ConvBNReLU、一个Conv2d层、一个BatchNorm2d层和一个ReLU层:

[9]:
import numpy as np
import mindspore as ms
from mindspore import nn

class MyNet(nn.Cell):

    def __init__(self):
        super(MyNet, self).__init__()
        layers = [ConvBNReLU()]
        # 使用CellList对网络进行管理
        self.build_block = nn.CellList(layers)

        # 使用append方法添加Conv2d层和ReLU层
        self.build_block.append(nn.Conv2d(64, 4, 4))
        self.build_block.append(nn.ReLU())

        # 使用insert方法在Conv2d层和ReLU层中间插入BatchNorm2d
        self.build_block.insert(-1, nn.BatchNorm2d(4))

    def construct(self, x):
        # for循环执行网络
        for layer in self.build_block:
            x = layer(x)
        return x

net = MyNet()
print(net)
MyNet<
  (build_block): CellList<
    (0): ConvBNReLU<
      (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
      (bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
      (relu): ReLU<>
      >
    (1): Conv2d<input_channels=64, output_channels=4, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (2): BatchNorm2d<num_features=4, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.2.gamma, shape=(4,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.2.beta, shape=(4,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.2.moving_mean, shape=(4,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.2.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)>
    (3): ReLU<>
    >
  >

把数据输入到网络模型中:

[10]:
input = ms.Tensor(np.ones([1, 3, 64, 32]), ms.float32)
output = net(input)
print(output.shape)
(1, 4, 64, 32)
  1. SequentialCell构建网络

使用nn.SequentialCell构造Cell顺序容器,支持子模块以List或OrderedDict格式作为输入。

不同于nn.CellList的是,nn.SequentialCell类内部实现了construct方法,可以直接输出结果。

如下示例使用nn.SequentialCell构建一个网络,输入为List,网络结构依次包含一个之前定义的模型子块ConvBNReLU、一个Conv2d层、一个BatchNorm2d层和一个ReLU层:

[11]:
import numpy as np
import mindspore as ms
from mindspore import nn

class MyNet(nn.Cell):

    def __init__(self):
        super(MyNet, self).__init__()

        layers = [ConvBNReLU()]
        layers.extend([nn.Conv2d(64, 4, 4),
                       nn.BatchNorm2d(4),
                       nn.ReLU()])
        self.build_block = nn.SequentialCell(layers)  # 使用SequentialCell对网络进行管理

    def construct(self, x):
        return self.build_block(x)

net = MyNet()
print(net)
MyNet<
  (build_block): SequentialCell<
    (0): ConvBNReLU<
      (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
      (bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.0.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.0.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.0.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.0.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
      (relu): ReLU<>
      >
    (1): Conv2d<input_channels=64, output_channels=4, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (2): BatchNorm2d<num_features=4, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.2.gamma, shape=(4,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.2.beta, shape=(4,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.2.moving_mean, shape=(4,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.2.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)>
    (3): ReLU<>
    >
  >

把数据输入到网络模型中:

[12]:
input = ms.Tensor(np.ones([1, 3, 64, 32]), ms.float32)
output = net(input)
print(output.shape)
(1, 4, 64, 32)

如下示例使用nn.SequentialCell构建一个网络,输入为OrderedDict:

[13]:
import numpy as np
import mindspore as ms
from mindspore import nn
from collections import OrderedDict

class MyNet(nn.Cell):

    def __init__(self):
        super(MyNet, self).__init__()
        layers = OrderedDict()

        # 将cells加入字典
        layers["ConvBNReLU"] = ConvBNReLU()
        layers["conv"] = nn.Conv2d(64, 4, 4)
        layers["norm"] = nn.BatchNorm2d(4)
        layers["relu"] = nn.ReLU()

        # 使用SequentialCell对网络进行管理
        self.build_block = nn.SequentialCell(layers)

    def construct(self, x):
        return self.build_block(x)

net = MyNet()
print(net)

input = ms.Tensor(np.ones([1, 3, 64, 32]), ms.float32)
output = net(input)
print(output.shape)
MyNet<
  (build_block): SequentialCell<
    (ConvBNReLU): ConvBNReLU<
      (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
      (bn): BatchNorm2d<num_features=64, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.ConvBNReLU.bn.gamma, shape=(64,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.ConvBNReLU.bn.beta, shape=(64,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.ConvBNReLU.bn.moving_mean, shape=(64,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.ConvBNReLU.bn.moving_variance, shape=(64,), dtype=Float32, requires_grad=False)>
      (relu): ReLU<>
      >
    (conv): Conv2d<input_channels=64, output_channels=4, kernel_size=(4, 4), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>
    (norm): BatchNorm2d<num_features=4, eps=1e-05, momentum=0.09999999999999998, gamma=Parameter (name=build_block.norm.gamma, shape=(4,), dtype=Float32, requires_grad=True), beta=Parameter (name=build_block.norm.beta, shape=(4,), dtype=Float32, requires_grad=True), moving_mean=Parameter (name=build_block.norm.moving_mean, shape=(4,), dtype=Float32, requires_grad=False), moving_variance=Parameter (name=build_block.norm.moving_variance, shape=(4,), dtype=Float32, requires_grad=False)>
    (relu): ReLU<>
    >
  >
(1, 4, 64, 32)

nn与ops关系

mindspore.nn模块是Python实现的模型组件,对低阶API的封装,主要包括神经网络模型相关的各种模型层、损失函数、优化器等。

同时mindspore.nn也提供了部分与mindspore.ops算子同名的接口,主要作用是对mindspore.ops算子进行进一步封装,为用户提供更友好的API。用户也可使用mindspore.ops算子根据实际场景实现自定义的网络。

如下示例使用mindspore.ops.Conv2D算子实现卷积计算功能,即nn.Conv2d算子功能。

[14]:
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore as ms
from mindspore.common.initializer import initializer


class Net(nn.Cell):
    def __init__(self, in_channels=10, out_channels=20, kernel_size=3):
        super(Net, self).__init__()
        self.conv2d = ops.Conv2D(out_channels, kernel_size)
        self.bias_add = ops.BiasAdd()
        self.weight = ms.Parameter(
            initializer('normal', [out_channels, in_channels, kernel_size, kernel_size]),
            name='conv.weight')
        self.bias = ms.Parameter(initializer('normal', [out_channels]), name='conv.bias')

    def construct(self, x):
        """输入数据x"""
        output = self.conv2d(x, self.weight)
        output = self.bias_add(output, self.bias)
        return output