自定义神经网络层
通常情况下,MindSpore提供的神经网络层接口和function函数接口能够满足模型构造需求,但由于AI领域不断推陈出新,因此有可能遇到新网络结构没有内置模块的情况。此时我们可以根据需要,通过MindSpore提供的function接口、Primitive算子自定义神经网络层,并可以使用Cell.bprop
方法自定义反向。下面分别详述三种自定义方法。
使用function接口构造神经网络层
MindSpore提供大量基础的function接口,可以使用其构造复杂的Tensor操作,封装为神经网络层。下面以Threshold
为例,其公式如下:
可以看到Threshold
判断Tensor的值是否大于threshold
值,保留判断结果为True
的值,替换判断结果为False
的值。因此,对应实现如下:
[43]:
import mindspore
import numpy as np
from mindspore import nn, ops, Tensor, Parameter
class Threshold(nn.Cell):
def __init__(self, threshold, value):
super().__init__()
self.threshold = threshold
self.value = value
def construct(self, inputs):
cond = ops.gt(inputs, self.threshold)
value = ops.fill(inputs.dtype, inputs.shape, self.value)
return ops.select(cond, inputs, value)
这里分别使用了ops.gt
、ops.fill
、ops.select
来实现判断和替换。下面执行自定义的Threshold
层:
[45]:
m = Threshold(0.1, 20)
inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)
m(inputs)
[45]:
Tensor(shape=[3], dtype=Float32, value= [ 2.00000000e+01, 2.00000003e-01, 3.00000012e-01])
可以看到inputs[0] = threshold
, 因此被替换为20
。
自定义Cell反向
在特殊场景下,我们不但需要自定义神经网络层的正向逻辑,也需要手动控制其反向的计算,此时我们可以通过Cell.bprop
接口对其反向进行定义。在全新的神经网络结构设计、反向传播速度优化等场景下会用到该功能。下面我们以Dropout2d
为例,介绍如何自定义Cell反向:
[55]:
class Dropout2d(nn.Cell):
def __init__(self, keep_prob):
super().__init__()
self.keep_prob = keep_prob
self.dropout2d = ops.Dropout2D(keep_prob)
def construct(self, x):
return self.dropout2d(x)
def bprop(self, x, out, dout):
_, mask = out
dy, _ = dout
if self.keep_prob != 0:
dy = dy * (1 / self.keep_prob)
dy = mask.astype(mindspore.float32) * dy
return (dy.astype(x.dtype), )
dropout_2d = Dropout2d(0.8)
dropout_2d.bprop_debug = True
bprop
方法分别有三个入参:
x: 正向输入,当正向输入为多个时,需同样数量的入参。
out: 正向输出。
dout: 反向传播时,当前Cell执行之前的反向结果。
一般我们需要根据正向输出和前层反向结果配合,根据反向求导公式计算反向结果,并将其返回。Dropout2d
的反向计算需要根据正向输出的mask
矩阵对前层反向结果进行mask,然后根据keep_prob
进行缩放。最终可得到正确的计算结果。
自定义Cell反向时,在PyNative模式下支持拓展写法,可以对Cell内部的权重求导,具体列子如下:
[ ]:
class NetWithParam(nn.Cell):
def __init__(self):
super(NetWithParam, self).__init__()
self.w = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name='weight')
self.internal_params = [self.w]
def construct(self, x):
output = self.w * x
return output
def bprop(self, *args):
return (self.w * args[-1],), {self.w: args[0] * args[-1]}
bprop
方法支持*args入参,args数组中最后一位args[-1]
为返回给该cell的梯度。通过self.internal_params
设置求导的权重,同时在bprop
函数的返回值为一个元组和一个字典,返回输入对应梯度的元组,以及以key为权重,value为权重对应梯度的字典。