mindspore.Tensor.register_hook
- mindspore.Tensor.register_hook(hook_fn)[源代码]
设置Tensor对象的反向hook函数。
说明
register_hook(hook_fn) 在图模式下,或者在PyNative模式下使用 jit 装饰器功能时不起作用。
hook_fn必须有如下代码定义:grad 是反向传递给Tensor对象的梯度。 用户可以在hook_fn中打印梯度数据或者返回新的输出梯度。
hook_fn返回新的梯度输出,不能不设置返回值:hook_fn(grad) -> New grad_output。
- 参数:
hook_fn (function) - 捕获Tensor反向传播时的梯度,并输出或更改该梯度的 hook_fn 函数。
- 返回:
返回与该hook_fn函数对应的handle对象。可通过调用handle.remove()来删除添加的hook_fn函数。
- 异常:
TypeError - 如果 hook_fn 不是Python函数。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore as ms >>> from mindspore import Tensor >>> ms.set_context(mode=ms.PYNATIVE_MODE) >>> def hook_fn(grad): ... return grad * 2 ... >>> def hook_test(x, y): ... z = x * y ... z.register_hook(hook_fn) ... z = z * y ... return z ... >>> ms_grad = ms.grad(hook_test, grad_position=(0,1)) >>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32)) >>> print(output) (Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))