mindspore.Tensor.register_hook
- mindspore.Tensor.register_hook(hook)[源代码]
设置Tensor对象的反向hook函数。
说明
hook 必须有如下代码定义: grad 是反向传递给 Tensor 对象的梯度。 用户可以在 hook 中打印梯度数据或者返回新的输出梯度。
hook 返回新的梯度输出,不能不设置返回值: hook(grad) -> New grad_output。
静态图模式下需满足如下约束:
hook 同样需满足静态图模式下的语法约束。
不支持在图内(即 Cell.construct 函数或被 @jit 修饰的函数)对 Parameter 注册 hook。
不支持在图内对 hook 进行删除。
图内对 Tensor 注册 hook 将返回 Tensor 本身。
- 参数:
hook (function) - 捕获 Tensor 反向传播时的梯度,并输出或更改该梯度的 hook 函数。
- 返回:
返回与该 hook 函数对应的 handle 对象。可通过调用 handle.remove() 来删除添加的 hook 函数。
- 异常:
TypeError - 如果 hook 不是Python函数。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore as ms >>> from mindspore import Tensor >>> ms.set_context(mode=ms.PYNATIVE_MODE) >>> def hook_fn(grad): ... return grad * 2 ... >>> def hook_test(x, y): ... z = x * y ... z.register_hook(hook_fn) ... z = z * y ... return z ... >>> ms_grad = ms.grad(hook_test, grad_position=(0,1)) >>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32)) >>> print(output) (Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))