mindspore.Tensor.register_hook
- Tensor.register_hook(hook_fn)
Registers a backward hook for the tensor.
Note
register_hook(hook_fn) does not work in graph mode or in functions decorated with 'jit'.
hook_fn must have the signature hook_fn(grad) -> new_grad, where grad is the gradient passed to the tensor. The hook can modify that gradient by returning a new output gradient.
hook_fn must always return a gradient: returning None, or omitting the return statement, is not allowed.
- Parameters
hook_fn (function) – A Python function used as the tensor's backward hook.
- Returns
A handle corresponding to hook_fn. The handle can be used to remove the added hook_fn by calling handle.remove().
- Raises
TypeError – If hook_fn is not a Python function.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     return grad * 2  # double the gradient flowing into z
...
>>> def hook_test(x, y):
...     z = x * y
...     z.register_hook(hook_fn)
...     z = z * y
...     return z
...
>>> ms_grad = ms.grad(hook_test, grad_position=(0, 1))
>>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32))
>>> print(output)
(Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))
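The values 8 and 6 in the example above follow from the chain rule once the hook doubles the gradient arriving at the intermediate tensor. The plain-Python trace below (no MindSpore needed; the variable names are illustrative) reproduces that arithmetic step by step.

```python
# Hand trace of the example's gradients in plain Python.
x, y = 1.0, 2.0

# Forward pass: z1 = x * y, out = z1 * y
z1 = x * y

# Backward pass, starting from d(out)/d(out) = 1.
d_out = 1.0
d_z1 = d_out * y             # gradient reaching z1 before the hook: 2.0
d_z1 = d_z1 * 2              # hook_fn returns grad * 2: 4.0
d_y_direct = d_out * z1      # out = z1 * y gives a direct term z1: 2.0

d_x = d_z1 * y               # dz1/dx = y: 4.0 * 2.0 = 8.0
d_y = d_z1 * x + d_y_direct  # dz1/dy = x, plus the direct term: 4.0 + 2.0 = 6.0

print(d_x, d_y)
```

Because the hook scales the gradient at z1 before it is propagated further, both paths through z1 are affected, while the direct dependence of the output on y is not.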