Cell与参数
Cell作为神经网络构造的基础单元,与神经网络层(Layer)的概念相对应,对Tensor计算操作的抽象封装,能够更准确清晰地对神经网络结构进行表示。除了基础的Tensor计算流程定义外,神经网络层还包含了参数管理、状态管理等功能。而参数(Parameter)是神经网络训练的核心,通常作为神经网络层的内部成员变量。本节我们将系统介绍参数、神经网络层以及其相关使用方法。
Parameter
参数(Parameter)是一类特殊的Tensor,是指在模型训练过程中可以对其值进行更新的变量。MindSpore提供mindspore.Parameter
类进行Parameter的构造。为了对不同用途的Parameter进行区分,下面对两种不同类别的Parameter进行定义:
可训练参数。在模型训练过程中根据反向传播算法求得梯度后进行更新的Tensor,此时需要将
required_grad
设置为True
。不可训练参数。不参与反向传播,但需要更新值的Tensor(如BatchNorm中的
mean
和var
变量),此时需要将requires_grad
设置为False
。
Parameter默认设置
required_grad=True
。
下面我们构造一个简单的全连接层:
[1]:
import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter
class Network(nn.Cell):
def __init__(self):
super().__init__()
self.w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
self.b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias
def construct(self, x):
z = ops.matmul(x, self.w) + self.b
return z
net = Network()
在Cell
的__init__
方法中,我们定义了w
和b
两个Parameter,并配置name
进行命名空间管理。在construct
方法中使用self.attr
直接调用参与Tensor运算。
获取Parameter
在使用Cell+Parameter构造神经网络层后,我们可以使用多种方法来获取Cell管理的Parameter。
获取单个参数
单独获取某个特定参数,直接调用Python类的成员变量即可。
[2]:
print(net.b.asnumpy())
[-1.2192779 -0.36789745 0.0946381 ]
获取可训练参数
可使用Cell.trainable_params
方法获取可训练参数,通常在配置优化器时需调用此接口。
[3]:
print(net.trainable_params())
[Parameter (name=w, shape=(5, 3), dtype=Float32, requires_grad=True), Parameter (name=b, shape=(3,), dtype=Float32, requires_grad=True)]
获取所有参数
使用Cell.get_parameters()
方法可获取所有参数,此时会返回一个Python迭代器。
[4]:
print(type(net.get_parameters()))
<class 'generator'>
或者可以调用Cell.parameters_and_names
返回参数名称及参数。
[5]:
for name, param in net.parameters_and_names():
print(f"{name}:\n{param.asnumpy()}")
w:
[[ 4.15680408e-02 -1.20311625e-01 5.02573885e-02]
[ 1.22175144e-04 -1.34980649e-01 1.17642188e+00]
[ 7.57667869e-02 -1.74758151e-01 -5.19092619e-01]
[-1.67846107e+00 3.27240258e-01 -2.06452996e-01]
[ 5.72323874e-02 -8.27963874e-02 5.94243526e-01]]
b:
[-1.2192779 -0.36789745 0.0946381 ]
修改Parameter
直接修改参数值
Parameter是一种特殊的Tensor,因此可以使用Tensor索引修改的方式对其值进行修改。
[6]:
net.b[0] = 1.
print(net.b.asnumpy())
[ 1. -0.36789745 0.0946381 ]
覆盖修改参数值
可调用Parameter.set_data
方法,使用相同Shape的Tensor对Parameter进行覆盖。该方法常用于使用Initializer进行Cell遍历初始化。
[7]:
net.b.set_data(Tensor([3, 4, 5]))
print(net.b.asnumpy())
[3. 4. 5.]
运行时修改参数值
参数的主要作用为模型训练时对其值进行更新,在反向传播获得梯度后,或不可训练参数需要进行更新,都涉及到运行时参数修改。由于MindSpore的使用静态图加速编译设计,此时需要使用mindspore.ops.assign
接口对参数进行赋值。该方法常用于自定义优化器场景。下面是一个简单的运行时修改参数值样例:
[8]:
import mindspore as ms
@ms.jit
def modify_parameter():
b_hat = ms.Tensor([7, 8, 9])
ops.assign(net.b, b_hat)
return True
modify_parameter()
print(net.b.asnumpy())
[7. 8. 9.]
Parameter Tuple
变量元组ParameterTuple,用于保存多个Parameter,继承于元组tuple,提供克隆功能。
如下示例提供ParameterTuple创建方法:
[10]:
from mindspore.common.initializer import initializer
from mindspore import ParameterTuple
# 创建
x = Parameter(default_input=ms.Tensor(np.arange(2 * 3).reshape((2, 3))), name="x")
y = Parameter(default_input=initializer('ones', [1, 2, 3], ms.float32), name='y')
z = Parameter(default_input=2.0, name='z')
params = ParameterTuple((x, y, z))
# 从params克隆并修改名称为"params_copy"
params_copy = params.clone("params_copy")
print(params)
print(params_copy)
(Parameter (name=x, shape=(2, 3), dtype=Int64, requires_grad=True), Parameter (name=y, shape=(1, 2, 3), dtype=Float32, requires_grad=True), Parameter (name=z, shape=(), dtype=Float32, requires_grad=True))
(Parameter (name=params_copy.x, shape=(2, 3), dtype=Int64, requires_grad=True), Parameter (name=params_copy.y, shape=(1, 2, 3), dtype=Float32, requires_grad=True), Parameter (name=params_copy.z, shape=(), dtype=Float32, requires_grad=True))
Cell训练状态转换
神经网络中的部分Tensor操作在训练和推理时的表现并不相同,如nn.Dropout
在训练时进行随机丢弃,但在推理时则不丢弃,nn.BatchNorm
在训练时需要更新mean
和var
两个变量,在推理时则固定其值不变。因此我们可以通过Cell.set_train
接口来设置神经网络的状态。
set_train(True)
时,神经网络状态为train
, set_train
接口默认值为True
:
[11]:
net.set_train()
print(net.phase)
train
set_train(False)
时,神经网络状态为predict
:
[12]:
net.set_train(False)
print(net.phase)
predict
自定义神经网络层
通常情况下,MindSpore提供的神经网络层接口和function函数接口能够满足模型构造需求,但由于AI领域不断推陈出新,因此有可能遇到新网络结构没有内置模块的情况。此时我们可以根据需要,通过MindSpore提供的function接口、Primitive算子自定义神经网络层,并可以使用Cell.bprop
方法自定义反向。下面分别详述三种自定义方法。
使用function接口构造神经网络层
MindSpore提供大量基础的function接口,可以使用其构造复杂的Tensor操作,封装为神经网络层。下面以Threshold
为例,其公式如下:
可以看到Threshold
判断Tensor的值是否大于threshold
值,保留判断结果为True
的值,替换判断结果为False
的值。因此,对应实现如下:
[43]:
class Threshold(nn.Cell):
def __init__(self, threshold, value):
super().__init__()
self.threshold = threshold
self.value = value
def construct(self, inputs):
cond = ops.gt(inputs, self.threshold)
value = ops.fill(inputs.dtype, inputs.shape, self.value)
return ops.select(cond, inputs, value)
这里分别使用了ops.gt
、ops.fill
、ops.select
来实现判断和替换。下面执行自定义的Threshold
层:
[45]:
m = Threshold(0.1, 20)
inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)
m(inputs)
[45]:
Tensor(shape=[3], dtype=Float32, value= [ 2.00000000e+01, 2.00000003e-01, 3.00000012e-01])
可以看到inputs[0] = threshold
, 因此被替换为20
。
自定义Cell反向
在特殊场景下,我们不但需要自定义神经网络层的正向逻辑,也需要手动控制其反向的计算,此时我们可以通过Cell.bprop
接口对其反向进行定义。在全新的神经网络结构设计、反向传播速度优化等场景下会用到该功能。下面我们以Dropout2d
为例,介绍如何自定义Cell反向:
[55]:
class Dropout2d(nn.Cell):
def __init__(self, keep_prob):
super().__init__()
self.keep_prob = keep_prob
self.dropout2d = ops.Dropout2D(keep_prob)
def construct(self, x):
return self.dropout2d(x)
def bprop(self, x, out, dout):
_, mask = out
dy, _ = dout
if self.keep_prob != 0:
dy = dy * (1 / self.keep_prob)
dy = mask.astype(mindspore.float32) * dy
return (dy.astype(x.dtype), )
dropout_2d = Dropout2d(0.8)
dropout_2d.bprop_debug = True
bprop
方法分别有三个入参:
x: 正向输入,当正向输入为多个时,需同样数量的入参。
out: 正向输出。
dout: 反向传播时,当前Cell执行之前的反向结果。
一般我们需要根据正向输出和前层反向结果配合,根据反向求导公式计算反向结果,并将其返回。Dropout2d
的反向计算需要根据正向输出的mask
矩阵对前层反向结果进行mask,然后根据keep_prob
进行缩放。最终可得到正确的计算结果。
自定义Cell反向时,在PyNative模式下支持拓展写法,可以对Cell内部的权重求导,具体列子如下:
[ ]:
class NetWithParam(nn.Cell):
def __init__(self):
super(NetWithParam, self).__init__()
self.w = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name='weight')
self.internal_params = [self.w]
def construct(self, x):
output = self.w * x
return output
def bprop(self, *args):
return (self.w * args[-1],), {self.w: args[0] * args[-1]}
bprop
方法支持*args入参,args数组中最后一位args[-1]
为返回给该cell的梯度。通过self.internal_params
设置求导的权重,同时在bprop
函数的返回值为一个元组和一个字典,返回输入对应梯度的元组,以及以key为权重,value为权重对应梯度的字典。
Hook功能
调试深度学习网络是每一个深度学习领域的从业者需要面对且投入精力较大的工作。由于深度学习网络隐藏了中间层算子的输入、输出数据以及反向梯度,只提供网络输入数据(特征量、权重)的梯度,导致无法准确地感知中间层算子的数据变化,从而降低了调试效率。为了方便用户准确、快速地对深度学习网络进行调试,MindSpore在动态图模式下设计了Hook功能,使用Hook功能可以捕获中间层算子的输入、输出数据以及反向梯度。
目前,动态图模式下提供了四种形式的Hook功能,分别是:HookBackward算子和在Cell对象上进行注册的register_forward_pre_hook、register_forward_hook、register_backward_hook功能。
HookBackward算子
HookBackward将Hook功能以算子的形式实现。用户初始化一个HookBackward算子,将其安插到深度学习网络中需要捕获梯度的位置。在网络正向执行时,HookBackward算子将输入数据不做任何修改后原样输出;在网络反向传播梯度时,在HookBackward上注册的Hook函数将会捕获反向传播至此的梯度。用户可以在Hook函数中自定义对梯度的操作,比如打印梯度,或者返回新的梯度。
示例代码:
[5]:
import mindspore as ms
from mindspore import ops
ms.set_context(mode=ms.PYNATIVE_MODE)
def hook_fn(grad_out):
"""打印梯度"""
print("hook_fn print grad_out:", grad_out)
hook = ops.HookBackward(hook_fn)
def hook_test(x, y):
z = x * y
z = hook(z)
z = z * y
return z
def net(x, y):
return ms.grad(hook_test, grad_position=(0, 1))(x, y)
output = net(ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32))
print("output:", output)
hook_fn print grad_out: (Tensor(shape=[], dtype=Float32, value= 2),)
output: (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
更多HookBackward算子的说明可以参考API文档。
Cell对象的register_forward_pre_hook功能
用户可以在Cell对象上使用register_forward_pre_hook
函数来注册一个自定义的Hook函数,用来捕获正向传入该Cell对象的数据。该功能在静态图模式下和在使用@jit
修饰的函数内不起作用。register_forward_pre_hook
函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的handle
对象。用户可以通过调用handle
对象的remove()
函数来删除与之对应的Hook函数。每一次调用register_forward_pre_hook
函数,都会返回一个不同的handle
对象。Hook函数应该按照以下的方式进行定义。
[6]:
def forward_pre_hook_fn(cell, inputs):
print("forward inputs: ", inputs)
这里的cell是Cell对象,inputs是正向传入到Cell对象的数据。因此,用户可以使用register_forward_pre_hook函数来捕获网络中某一个Cell对象的正向输入数据。用户可以在Hook函数中自定义对输入数据的操作,比如查看、打印数据,或者返回新的输入数据给当前的Cell对象。如果在Hook函数中对Cell对象的原始输入数据进行计算操作后,再作为新的输入数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。
示例代码:
[7]:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
ms.set_context(mode=ms.PYNATIVE_MODE)
def forward_pre_hook_fn(cell, inputs):
print("forward inputs: ", inputs)
input_x = inputs[0]
return input_x
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)
def construct(self, x, y):
x = x + y
x = self.relu(x)
return x
net = Net()
grad_net = ms.grad(net, grad_position=(0, 1))
x = ms.Tensor(np.ones([1]).astype(np.float32))
y = ms.Tensor(np.ones([1]).astype(np.float32))
output = net(x, y)
print(output)
gradient = grad_net(x, y)
print(gradient)
net.handle.remove()
gradient = grad_net(x, y)
print(gradient)
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
[2.]
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))
(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))
用户如果在Hook函数中直接返回新创建的数据,而不是返回由原始输入数据经过计算后得到的数据,那么梯度的反向传播将会在该Cell对象上截止。
示例代码:
[8]:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
ms.set_context(mode=ms.PYNATIVE_MODE)
def forward_pre_hook_fn(cell, inputs):
print("forward inputs: ", inputs)
return ms.Tensor(np.ones([1]).astype(np.float32))
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)
def construct(self, x, y):
x = x + y
x = self.relu(x)
return x
net = Net()
grad_net = ms.grad(net, grad_position=(0, 1))
x = ms.Tensor(np.ones([1]).astype(np.float32))
y = ms.Tensor(np.ones([1]).astype(np.float32))
gradient = grad_net(x, y)
print(gradient)
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
(Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]))
为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 construct
函数中调用 register_forward_pre_hook
函数和 handle
对象的 remove()
函数。在动态图模式下,如果在Cell对象的 construct
函数中调用 register_forward_pre_hook
函数,那么Cell对象每次运行都将新注册一个Hook函数。
更多关于Cell对象的 register_forward_pre_hook
功能的说明可以参考API文档。
Cell对象的register_forward_hook功能
用户可以在Cell对象上使用register_forward_hook
函数来注册一个自定义的Hook函数,用来捕获正向传入Cell对象的数据和Cell对象的输出数据。该功能在静态图模式下和在使用@jit
修饰的函数内不起作用。register_forward_hook
函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的handle
对象。用户可以通过调用handle
对象的remove()
函数来删除与之对应的Hook函数。每一次调用register_forward_hook
函数,都会返回一个不同的handle
对象。Hook函数应该按照以下的方式进行定义。
示例代码:
[9]:
def forward_hook_fn(cell, inputs, outputs):
print("forward inputs: ", inputs)
print("forward outputs: ", outputs)
这里的cell
是Cell对象,inputs
是正向传入到Cell对象的数据,outputs
是Cell对象的正向输出数据。因此,用户可以使用register_forward_hook
函数来捕获网络中某一个Cell对象的正向输入数据和输出数据。用户可以在Hook函数中自定义对输入、输出数据的操作,比如查看、打印数据,或者返回新的输出数据。如果在Hook函数中对Cell对象的原始输出数据进行计算操作后,再作为新的输出数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。
示例代码:
[10]:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
ms.set_context(mode=ms.PYNATIVE_MODE)
def forward_hook_fn(cell, inputs, outputs):
print("forward inputs: ", inputs)
print("forward outputs: ", outputs)
outputs = outputs + outputs
return outputs
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.handle = self.relu.register_forward_hook(forward_hook_fn)
def construct(self, x, y):
x = x + y
x = self.relu(x)
return x
net = Net()
grad_net = ms.grad(net, grad_position=(0, 1))
x = ms.Tensor(np.ones([1]).astype(np.float32))
y = ms.Tensor(np.ones([1]).astype(np.float32))
gradient = grad_net(x, y)
print(gradient)
net.handle.remove()
gradient = grad_net(x, y)
print(gradient)
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
forward outputs: [2.]
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))
(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))
用户如果在Hook函数中直接返回新创建的数据,而不是将原始的输出数据经过计算后,将得到的新输出数据返回,那么梯度的反向传播将会在该Cell对象上截止。该现象可以参考register_forward_pre_hook
函数的用例说明。
为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的construct
函数中调用register_forward_hook
函数和handle
对象的remove()
函数。在动态图模式下,如果在Cell对象的construct
函数中调用register_forward_hook
函数,那么Cell对象每次运行都将新注册一个Hook函数。
更多关于Cell对象的register_forward_hook
功能的说明可以参考API文档。
Cell对象的register_backward_hook功能
用户可以在Cell对象上使用register_backward_hook
函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。该功能在图模式下或者在使用@jit
修饰的函数内不起作用。register_backward_hook
函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的handle
对象。用户可以通过调用handle
对象的remove()
函数来删除与之对应的Hook函数。每一次调用register_backward_hook
函数,都会返回一个不同的handle
对象。
与HookBackward算子所使用的自定义Hook函数有所不同,register_backward_hook
使用的Hook函数的入参中,包含了表示Cell对象名称与id信息的cell_id
、反向传入到Cell对象的梯度、以及Cell对象的反向输出的梯度。
示例代码:
[11]:
def backward_hook_function(cell_id, grad_input, grad_output):
print(grad_input)
print(grad_output)
这里的cell_id
是Cell对象的名称以及ID信息,grad_input
是网络反向传播时,传入到Cell对象的梯度,它对应于正向过程中下一个算子的反向输出梯度;grad_output
是Cell对象反向输出的梯度。因此,用户可以使用register_backward_hook
函数来捕获网络中某一个Cell对象的反向传入和反向输出梯度。用户可以在Hook函数中自定义对梯度的操作,比如查看、打印梯度,或者返回新的输出梯度。如果需要在Hook函数中返回新的输出梯度时,返回值必须是tuple
的形式。
示例代码:
[12]:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
ms.set_context(mode=ms.PYNATIVE_MODE)
def backward_hook_function(cell_id, grad_input, grad_output):
print(grad_input)
print(grad_output)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones")
self.handle = self.bn.register_backward_hook(backward_hook_function)
self.relu = nn.ReLU()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
net = Net()
grad_net = ms.grad(net)
output = grad_net(ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))
print(output)
net.handle.remove()
output = grad_net(ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))
print("-------------\n", output)
(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=
[[[[ 1.00000000e+00]],
[[ 1.00000000e+00]]]]),)
(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=
[[[[ 9.99994993e-01]],
[[ 9.99994993e-01]]]]),)
[[[[1.99999 1.99999]
[1.99999 1.99999]]]]
-------------
[[[[1.99999 1.99999]
[1.99999 1.99999]]]]
当 register_backward_hook
函数和 register_forward_pre_hook
函数、 register_forward_hook
函数同时作用于同一Cell对象时,如果 register_forward_pre_hook
和 register_forward_hook
函数中有添加其他算子进行数据处理,这些新增算子会在Cell对象执行前或者执行后参与数据的正向计算,但是这些新增算子的反向梯度不在 register_backward_hook
函数的捕获范围内。 register_backward_hook
中注册的Hook函数仅捕获原始Cell对象的输入、输出梯度。
示例代码:
[13]:
import numpy as np
import mindspore as ms
import mindspore.nn as nn
ms.set_context(mode=ms.PYNATIVE_MODE)
def forward_pre_hook_fn(cell, inputs):
print("forward inputs: ", inputs)
input_x = inputs[0]
return input_x
def forward_hook_fn(cell, inputs, outputs):
print("forward inputs: ", inputs)
print("forward outputs: ", outputs)
outputs = outputs + outputs
return outputs
def backward_hook_fn(cell_id, grad_input, grad_output):
print("grad input: ", grad_input)
print("grad output: ", grad_output)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.relu = nn.ReLU()
self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)
self.handle2 = self.relu.register_forward_hook(forward_hook_fn)
self.handle3 = self.relu.register_backward_hook(backward_hook_fn)
def construct(self, x, y):
x = x + y
x = self.relu(x)
return x
net = Net()
grad_net = ms.grad(net, grad_position=(0, 1))
gradient = grad_net(ms.Tensor(np.ones([1]).astype(np.float32)), ms.Tensor(np.ones([1]).astype(np.float32)))
print(gradient)
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
forward inputs: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
forward outputs: [2.]
grad input: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
grad output: (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))
这里的 grad_input
是梯度反向传播时传入self.relu
的梯度,而不是传入 forward_hook_fn
函数中,新增的 Add
算子的梯度。这里的 grad_output
是梯度反向传播时 self.relu
反向输出的梯度,而不是 forward_pre_hook_fn
函数中新增 Add
算子的反向输出梯度。 register_forward_pre_hook
函数和 register_forward_hook
函数是在Cell对象执行前后起作用,不会影响Cell对象上反向Hook函数的梯度捕获范围。
为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 construct
函数中调用 register_backward_hook
函数和 handle
对象的 remove()
函数。在PyNative模式下,如果在Cell对象的 construct
函数中调用 register_backward_hook
函数,那么Cell对象每次运行都将新注册一个Hook函数。
更多关于Cell对象的 register_backward_hook
功能的说明可以参考API文档。