自定义Cell的反向
用户可以自定义nn.Cell对象的反向传播(计算)函数,从而控制nn.Cell对象梯度计算的过程,定位梯度问题。
自定义bprop函数的使用方法是:在定义的nn.Cell对象里面增加一个用户自定义的bprop函数。训练的过程中会使用用户自定义的bprop函数来生成反向图。
示例代码:
[5]:
ms.set_context(mode=ms.PYNATIVE_MODE)
class Net(nn.Cell):
def construct(self, x, y):
z = x * y
z = z * y
return z
def bprop(self, x, y, out, dout):
x_dout = x + y
y_dout = x * y
return x_dout, y_dout
grad_all = ops.GradOperation(get_all=True)
output = grad_all(Net())(ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32))
print(output)
(Tensor(shape=[], dtype=Float32, value= 3), Tensor(shape=[], dtype=Float32, value= 2))