mindspore.Tensor.register_hook

查看源文件
mindspore.Tensor.register_hook(hook_fn)[源代码]

设置Tensor对象的反向hook函数。

说明

  • register_hook(hook_fn) 在图模式下,或者在PyNative模式下使用 jit 装饰器功能时不起作用。

  • hook_fn必须有如下代码定义:grad 是反向传递给Tensor对象的梯度。 用户可以在hook_fn中打印梯度数据或者返回新的输出梯度。

  • hook_fn返回新的梯度输出,不能不设置返回值:hook_fn(grad) -> New grad_output。

参数:
  • hook_fn (function) - 捕获Tensor反向传播时的梯度,并输出或更改该梯度的 hook_fn 函数。

返回:

返回与该hook_fn函数对应的handle对象。可通过调用handle.remove()来删除添加的hook_fn函数。

异常:
  • TypeError - 如果 hook_fn 不是Python函数。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     return grad * 2
...
>>> def hook_test(x, y):
...     z = x * y
...     z.register_hook(hook_fn)
...     z = z * y
...     return z
...
>>> ms_grad = ms.grad(hook_test, grad_position=(0,1))
>>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32))
>>> print(output)
(Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))