自定义Cell的反向传播函数
使用MindSpore构建神经网络时,需要继承 nn.Cell
类。构建网络的过程中,我们可能会遇到一些问题,例如:
Cell中存在一些不可求导的或者是尚未定义反向传播规则的操作或算子;
替换Cell的某些正向计算过程时,需要自定义相应的反向传播函数。
这时我们可以使用自定义Cell对象的反向传播函数的功能,形式为:
def bprop(self, ..., out, dout):
return ...
输入参数:与正向部分相同的输入参数再加上
out
和dout
,out
表示正向部分的计算结果,dout
表示回传到该nn.Cell
对象的梯度。返回值:关于正向部分每个输入的梯度,所以返回值的数量需要与正向部分输入的数量相同。
一个简单的完整示例如下:
[1]:
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.matmul = ops.MatMul()
def construct(self, x, y):
out = self.matmul(x, y)
return out
def bprop(self, x, y, out, dout):
dx = x + 1
dy = y + 1
return dx, dy
class GradNet(nn.Cell):
def __init__(self, net):
super(GradNet, self).__init__()
self.net = net
self.grad_op = ops.GradOperation(get_all=True)
def construct(self, x, y):
gradient_function = self.grad_op(self.net)
return gradient_function(x, y)
x = ms.Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=ms.float32)
y = ms.Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=ms.float32)
out = GradNet(Net())(x, y)
print(out)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 1.50000000e+00, 1.60000002e+00, 1.39999998e+00],
[ 2.20000005e+00, 2.29999995e+00, 2.09999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 1.00999999e+00, 1.29999995e+00, 2.09999990e+00],
[ 1.10000002e+00, 1.20000005e+00, 2.29999995e+00],
[ 3.09999990e+00, 2.20000005e+00, 4.30000019e+00]]))
此示例通过定义Cell的 bprop
函数,对 MatMul
操作自定义了梯度计算过程,其中 dx
为对输入 x
的导数, dy
为对输入 y
的导数, out
为 MatMul
的计算结果, dout
为回传到 Net
的梯度。
应用样例
Cell中存在一些尚未定义反向传播规则的操作或算子。例如
ReLU6
算子尚未定义其二阶反向传播规则,这时我们可以通过自定义Cell的bprop
函数去自定义ReLU6
算子的二阶反向传播规则。代码如下:
[2]:
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
import mindspore.ops as ops
class ReluNet(nn.Cell):
def __init__(self):
super(ReluNet, self).__init__()
self.relu = ops.ReLU()
def construct(self, x):
return self.relu(x)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.relu6 = ops.ReLU6()
self.relu = ReluNet()
def construct(self, x):
return self.relu6(x)
def bprop(self, x, out, dout):
dx = self.relu(x)
return (dx, )
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
net = Net()
out = ops.grad(ops.grad(net))(x)
print(out)
[[1. 1. 1.]
[1. 1. 1.]]
此代码通过自定义 Net
的 bprop
函数,定义了一阶反向传播规则,而二阶反向传播规则通过 bprop
中的 self.relu
的反向传播规则得到。
替换Cell的某些正向计算过程时,需要自定义相应的反向传播函数。例如SNN网络有如下代码:
class relusigmoid(nn.Cell): def __init__(self): super().__init__() self.sigmoid = ops.Sigmoid() self.greater = ops.Greater() def construct(self, x): spike = self.greater(x, 0) return spike.astype(mindspore.float32) def bprop(self, x, out, dout): sgax = self.sigmoid(x * 5.0) grad_x = dout * (1 - sgax) * sgax * 5.0 return (grad_x,) class IFNode(nn.Cell): def __init__(self, v_threshold=1.0, fire=True, surrogate_function=relusigmoid()): super().__init__() self.v_threshold = v_threshold self.fire = fire self.surrogate_function = surrogate_function def construct(self, x, v): v = v + x if self.fire: spike = self.surrogate_function(v - self.v_threshold) * self.v_threshold v -= spike return spike, v return v, v
此代码自定义了一个新的激活函数relusigmoid,在子网
IFNode
里去替换原来的sigmoid激活函数,这时候就需要去自定义新的激活函数的反向传播规则。
约束与限制
当
bprop
函数的返回值数量为1时,也需要写成tuple的形式,即return (dx,)
。图模式下,
bprop
函数需要转换成图IR,所以需要遵循静态图语法,请参考静态图语法支持。只支持返回关于正向部分输入的梯度,不支持返回关于
Parameter
的梯度。不支持在
bprop
中使用Parameter
。