Customizing the Backward Propagation Function of a Cell
When building a neural network with MindSpore, the network must inherit from the nn.Cell class. We may run into the following problems when constructing a network:

There are operations or operators in the Cell that are not differentiable, or whose backward propagation rules have not yet been defined.

When certain forward computation procedures of the Cell are replaced, the corresponding backward propagation function also needs to be customized.

In these cases, we can customize the backward propagation function of the Cell object. The format is as follows:
def bprop(self, ..., out, dout):
    return ...
Input parameters: the inputs of the forward propagation, followed by out and dout. out is the result of the forward computation, and dout is the gradient passed back to the nn.Cell object.

Return values: the gradient of each input of the forward propagation. The number of return values must be the same as the number of inputs of the forward propagation.
A complete simple example is as follows:
import mindspore.nn as nn
import mindspore as ms
import mindspore.ops as ops

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        out = self.matmul(x, y)
        return out

    def bprop(self, x, y, out, dout):
        dx = x + 1
        dy = y + 1
        return dx, dy

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(get_all=True)

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

x = ms.Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=ms.float32)
y = ms.Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=ms.float32)
out = GradNet(Net())(x, y)
print(out)
(Tensor(shape=[2, 3], dtype=Float32, value=
[[ 1.50000000e+00, 1.60000002e+00, 1.39999998e+00],
[ 2.20000005e+00, 2.29999995e+00, 2.09999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
[[ 1.00999999e+00, 1.29999995e+00, 2.09999990e+00],
[ 1.10000002e+00, 1.20000005e+00, 2.29999995e+00],
[ 3.09999990e+00, 2.20000005e+00, 4.30000019e+00]]))
This example customizes the gradient calculation of the MatMul operation by defining the bprop function of the Cell: dx is the gradient for the input x, dy is the gradient for the input y, out is the result of the MatMul computation, and dout is the gradient passed back to Net.
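For comparison, the following is a minimal sketch (not part of the original example) showing that the same network without a custom bprop returns the ordinary MatMul gradients produced by automatic differentiation rather than x + 1 and y + 1. PlainNet is a hypothetical name introduced only for this illustration; it reuses GradNet, x, and y defined above.

class PlainNet(nn.Cell):
    def __init__(self):
        super(PlainNet, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        return self.matmul(x, y)

# Without a bprop method, GradOperation uses the built-in MatMul gradient rules.
print(GradNet(PlainNet())(x, y))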
Application Examples
Some operators in a Cell are non-differentiable or have no backward propagation function defined. For example, the operator ReLU6 has no second-order backward propagation rule defined; such a rule can be supplied by customizing the bprop function of the Cell. The code is as follows:

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
import mindspore.ops as ops

class ReluNet(nn.Cell):
    def __init__(self):
        super(ReluNet, self).__init__()
        self.relu = ops.ReLU()

    def construct(self, x):
        return self.relu(x)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu6 = ops.ReLU6()
        self.relu = ReluNet()

    def construct(self, x):
        return self.relu6(x)

    def bprop(self, x, out, dout):
        dx = self.relu(x)
        return (dx,)

x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
net = Net()
out = ops.grad(ops.grad(net))(x)
print(out)
[[1. 1. 1.]
 [1. 1. 1.]]
The above code defines the first-order backward propagation rule by customizing the bprop function of Net, and the second-order backward propagation rule is obtained from the backward propagation rule of self.relu used inside that bprop.

We also need a customized backward propagation function when we want to replace part of the forward computation of a Cell. For example, the SNN network contains the following code:
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class relusigmoid(nn.Cell):
    def __init__(self):
        super().__init__()
        self.sigmoid = ops.Sigmoid()
        self.greater = ops.Greater()

    def construct(self, x):
        spike = self.greater(x, 0)
        return spike.astype(mindspore.float32)

    def bprop(self, x, out, dout):
        sgax = self.sigmoid(x * 5.0)
        grad_x = dout * (1 - sgax) * sgax * 5.0
        return (grad_x,)

class IFNode(nn.Cell):
    def __init__(self, v_threshold=1.0, fire=True, surrogate_function=relusigmoid()):
        super().__init__()
        self.v_threshold = v_threshold
        self.fire = fire
        self.surrogate_function = surrogate_function

    def construct(self, x, v):
        v = v + x
        if self.fire:
            spike = self.surrogate_function(v - self.v_threshold) * self.v_threshold
            v -= spike
            return spike, v
        return v, v
The above code replaces the original sigmoid activation function in the sub-network IFNode with the customized activation function relusigmoid, so a new backward propagation function has to be customized for the new activation function.
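As a usage illustration (not part of the original SNN code), the following sketch runs a single IFNode time step and differentiates the spike output with respect to the input current, so the surrogate gradient defined in relusigmoid.bprop is applied instead of the non-differentiable Greater operation. The SpikeOnly wrapper and the input values are hypothetical and introduced here only for demonstration.

from mindspore import Tensor

class SpikeOnly(nn.Cell):
    # Hypothetical wrapper: keeps only the spike output so ops.grad differentiates a single tensor.
    def __init__(self):
        super().__init__()
        self.node = IFNode()

    def construct(self, x, v):
        spike, _ = self.node(x, v)
        return spike

x = Tensor([[0.5, 1.2, 0.8]], mindspore.float32)   # assumed input current
v = Tensor([[0.0, 0.0, 0.0]], mindspore.float32)   # initial membrane potential
dx = ops.grad(SpikeOnly())(x, v)                   # gradient w.r.t. x flows through relusigmoid.bprop
print(dx)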
Constraints
If the bprop function returns only one value, the return value must be written in tuple format, that is, return (dx,) (see the sketch after this list).

In graph mode, the bprop function is converted into a graph IR, so it must comply with the static graph syntax. For details, see Static Graph Syntax Support.

Only the gradients of the forward propagation inputs can be returned; the gradient of a Parameter cannot be returned.

The use of Parameter is not supported inside bprop.
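The following is a minimal sketch (not from the original text) illustrating the single-return-value constraint: a Cell with one forward input still wraps its gradient in a one-element tuple. SquareNet and the custom rule 3 * x are hypothetical and used only for illustration.

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops

class SquareNet(nn.Cell):
    def __init__(self):
        super(SquareNet, self).__init__()
        self.square = ops.Square()

    def construct(self, x):
        return self.square(x)

    def bprop(self, x, out, dout):
        dx = 3 * x        # arbitrary custom rule, for illustration only
        return (dx,)      # a single gradient must still be returned as a tuple

x = ms.Tensor([1.0, 2.0, 3.0], ms.float32)
print(ops.grad(SquareNet())(x))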