mindspore.Tensor.register_hook

查看源文件
mindspore.Tensor.register_hook(hook)[源代码]

设置Tensor对象的反向hook函数。

说明

  • hook 必须有如下代码定义: grad 是反向传递给 Tensor 对象的梯度。 用户可以在 hook 中打印梯度数据或者返回新的输出梯度。

  • hook 返回新的梯度输出,不能不设置返回值: hook(grad) -> New grad_output

  • 静态图模式下需满足如下约束:

    • hook 同样需满足静态图模式下的语法约束。

    • 不支持在图内(即 Cell.construct 函数或被 @jit 修饰的函数)对 Parameter 注册 hook

    • 不支持在图内对 hook 进行删除。

    • 图内对 Tensor 注册 hook 将返回 Tensor 本身。

参数:
  • hook (function) - 捕获 Tensor 反向传播时的梯度,并输出或更改该梯度的 hook 函数。

返回:

返回与该 hook 函数对应的 handle 对象。可通过调用 handle.remove() 来删除添加的 hook 函数。

异常:
  • TypeError - 如果 hook 不是Python函数。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     return grad * 2
...
>>> def hook_test(x, y):
...     z = x * y
...     z.register_hook(hook_fn)
...     z = z * y
...     return z
...
>>> ms_grad = ms.grad(hook_test, grad_position=(0,1))
>>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32))
>>> print(output)
(Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))