mindspore.Tensor.register_hook

Tensor.register_hook(hook)

Registers a backward hook for the tensor.

Note

  • The hook receives grad, the gradient passed to the tensor, and can modify it by returning a new output gradient.

  • The hook must have the signature hook(grad) -> new output gradient. It must return a value: returning None, or having no return statement at all, is not allowed.

  • The following constraints must be met under graph mode:

    • The hook must satisfy the syntax constraints of the graph mode.

    • Registering a hook on a Parameter is not supported inside the graph (i.e., in Cell.construct or in a function decorated with @jit).

    • Removing a hook inside the graph is not supported.

    • Registering a hook inside the graph returns the Tensor itself instead of a handle.
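To make the hook contract concrete, the following is a minimal pure-Python sketch of a toy reverse-mode autograd. It is a hypothetical stand-in, not MindSpore's implementation: Scalar, backward and hook_test_grads are invented names, and raising RuntimeError when a hook returns None is the toy's own way of enforcing the documented rule. In the toy, the gradient reaching a node passes through that node's hooks in order, and the hooked value is what flows on to the node's inputs.

```python
class Scalar:
    """Toy reverse-mode scalar (hypothetical; not mindspore.Tensor)."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (input_node, local_derivative)
        self.grad = 0.0
        self.hooks = []

    def __mul__(self, other):
        # d(a*b)/da = b and d(a*b)/db = a
        return Scalar(self.value * other.value,
                      ((self, other.value), (other, self.value)))

    def register_hook(self, hook):
        self.hooks.append(hook)


def backward(output):
    """Propagate gradients from output, applying each node's hooks."""
    order, seen = [], set()

    def visit(node):
        # Depth-first post-order: every input is listed before its consumer.
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent, _ in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    output.grad = 1.0
    for node in reversed(order):  # consumers are processed before their inputs
        grad = node.grad
        for hook in node.hooks:
            new_grad = hook(grad)
            if new_grad is None:
                raise RuntimeError("a backward hook must return a gradient, not None")
            grad = new_grad  # the hooked gradient is what flows onward
        for parent, local in node.parents:
            parent.grad += grad * local


def hook_test_grads():
    # Same computation as the Examples section: z = x * y, hook doubles dz, out = z * y
    x, y = Scalar(1.0), Scalar(2.0)
    z = x * y
    z.register_hook(lambda grad: grad * 2)
    out = z * y
    backward(out)
    return x.grad, y.grad


print(hook_test_grads())  # (8.0, 6.0)
```

The toy reproduces the values printed in the Examples section, which is a useful sanity check on the idea that the hook rewrites the gradient before it reaches x and y.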

Parameters

hook (function) – Python function to be used as the tensor's backward hook.

Returns

A handle corresponding to the hook. Calling handle.remove() removes the added hook.

Raises

TypeError – If hook is not a Python function.
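The handle returned on registration and the TypeError above follow a common registry pattern, sketched here in plain Python. HookRegistry and HookHandle are hypothetical names, not MindSpore classes. Treating "Python function" as anything for which inspect.isfunction or inspect.ismethod is true is an assumption, and chaining several hooks in registration order is this sketch's own choice.

```python
import inspect
import itertools


class HookHandle:
    """Hypothetical handle: remove() unregisters the hook it was created for."""

    def __init__(self, hooks, key):
        self._hooks = hooks
        self._key = key

    def remove(self):
        # pop() with a default makes a second remove() a harmless no-op
        self._hooks.pop(self._key, None)


class HookRegistry:
    """Hypothetical per-tensor hook table (not MindSpore's internals)."""

    def __init__(self):
        self._hooks = {}  # key -> hook, kept in registration order
        self._next_key = itertools.count()

    def register_hook(self, hook):
        # Assumed check mirroring the documented TypeError for non-functions
        if not (inspect.isfunction(hook) or inspect.ismethod(hook)):
            raise TypeError(f"hook must be a Python function, got {type(hook).__name__}")
        key = next(self._next_key)
        self._hooks[key] = hook
        return HookHandle(self._hooks, key)

    def apply(self, grad):
        # Chain the registered hooks; each one sees the previous hook's output
        for hook in list(self._hooks.values()):
            grad = hook(grad)
        return grad


registry = HookRegistry()
handle = registry.register_hook(lambda grad: grad * 2)
print(registry.apply(3.0))  # 6.0
handle.remove()
print(registry.apply(3.0))  # 3.0
```

Keying hooks by a counter rather than by the function object lets the same function be registered twice and removed independently through each handle.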

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     return grad * 2
...
>>> def hook_test(x, y):
...     z = x * y
...     z.register_hook(hook_fn)
...     z = z * y
...     return z
...
>>> ms_grad = ms.grad(hook_test, grad_position=(0,1))
>>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32))
>>> print(output)
(Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))
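The printed values follow from the chain rule. Writing the example as f(x, y) = z y with z = x y, and letting the hook h(g) = 2g rewrite the gradient that reaches z, the two gradients at x = 1, y = 2 work out as:

```latex
\begin{aligned}
\frac{\partial f}{\partial z} &= y = 2, \qquad h\!\left(\frac{\partial f}{\partial z}\right) = 2 \cdot 2 = 4,\\
\frac{\partial f}{\partial x} &= h\!\left(\frac{\partial f}{\partial z}\right)\frac{\partial z}{\partial x} = 4 \cdot y = 8,\\
\frac{\partial f}{\partial y} &= \underbrace{z}_{\text{direct path}} + h\!\left(\frac{\partial f}{\partial z}\right)\frac{\partial z}{\partial y} = xy + 4x = 2 + 4 = 6.
\end{aligned}
```

Without the hook, f = x y^2 and the same computation would give 4 and 4, so the difference comes entirely from the doubled gradient at z.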