mindspore.Tensor.register_hook
- Tensor.register_hook(hook)
Registers a backward hook for the tensor.
Note
The hook must be defined as follows. `grad` is the gradient passed to the tensor; the hook can modify it by returning a new output gradient.
The hook should have the signature hook(grad) -> new output gradient. It must not return None and must not omit the return value.
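The signature rule can be illustrated with plain Python functions. The function names below are hypothetical, and the sketch calls them on floats so it runs without MindSpore; the same bodies work on a Tensor gradient because `*` is defined for Tensors. The point is that a hook whose body has no `return` statement implicitly returns None, which the rule above forbids.

```python
# Illustrative hooks only; the names are hypothetical and no MindSpore import is needed.

def double_grad(grad):
    """Valid hook: returns a new output gradient (twice the incoming one)."""
    return grad * 2


def passthrough(grad):
    """Valid hook: inspects the gradient, then returns it unchanged."""
    print("grad:", grad)
    return grad


def logging_only(grad):
    """Invalid hook: no return statement, so Python returns None implicitly."""
    print("grad:", grad)


print(double_grad(3.0))   # 6.0
print(passthrough(1.5))   # prints "grad: 1.5", then 1.5
print(logging_only(3.0))  # prints "grad: 3.0", then None
```

A hook that only needs to observe the gradient should therefore still end with `return grad`, as `passthrough` does.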
In graph mode, the following constraints apply:
The hook must satisfy the syntax constraints of graph mode.
Deleting a hook inside the graph is not supported.
Registering a hook on a Tensor that has already been used is not supported.
Registering multiple hooks for the same Tensor inside the graph is not supported.
Registering a hook inside the graph returns the Tensor itself.
- Parameters
hook (function) – Python function used as the Tensor's backward hook.
- Returns
A handle corresponding to the hook. The handle can be used to remove the added hook by calling handle.remove().
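The handle-based removal pattern can be sketched in plain Python. This is a toy model, not MindSpore's implementation: `ToyTensor`, `_HookHandle`, and `apply_hooks` are hypothetical names, and the `callable` check is only an assumption standing in for the documented TypeError. It shows the idea that `register_hook` returns an object whose `remove()` unregisters exactly that hook.

```python
class _HookHandle:
    """Hypothetical handle: remove() unregisters one hook by its id."""

    def __init__(self, hooks, hook_id):
        self._hooks = hooks
        self._hook_id = hook_id

    def remove(self):
        # pop with a default makes repeated remove() calls harmless.
        self._hooks.pop(self._hook_id, None)


class ToyTensor:
    """Hypothetical stand-in that only tracks backward hooks."""

    def __init__(self):
        self._hooks = {}
        self._next_id = 0

    def register_hook(self, hook):
        # Assumed check, mirroring the documented TypeError for non-functions.
        if not callable(hook):
            raise TypeError("hook must be a Python function")
        hook_id = self._next_id
        self._next_id += 1
        self._hooks[hook_id] = hook
        return _HookHandle(self._hooks, hook_id)

    def apply_hooks(self, grad):
        # In this toy, hooks run in registration order (dicts keep insertion order).
        for hook in self._hooks.values():
            grad = hook(grad)
        return grad


t = ToyTensor()
handle = t.register_hook(lambda g: g * 2)
print(t.apply_hooks(1.0))  # 2.0
handle.remove()
print(t.apply_hooks(1.0))  # 1.0
```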
- Raises
TypeError – If hook is not a Python function.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     return grad * 2
...
>>> def hook_test(x, y):
...     z = x * y
...     z.register_hook(hook_fn)
...     z = z * y
...     return z
...
>>> ms_grad = ms.grad(hook_test, grad_position=(0,1))
>>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32))
>>> print(output)
(Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))
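To see where the values 8 and 6 come from, the backward pass of `hook_test` can be worked by hand. The sketch below is plain Python on floats, not MindSpore code; `grads_of_hook_test` is a hypothetical helper that applies the chain rule step by step and calls the hook at the point where the gradient flows into the intermediate `z = x * y`.

```python
def hook_fn(grad):
    return grad * 2


def grads_of_hook_test(x, y, hook=hook_fn):
    """Hand-derived gradients of hook_test w.r.t. (x, y), applying `hook` on z1."""
    # Forward pass: z1 = x * y, then out = z1 * y (out itself is not needed below).
    z1 = x * y
    # Backward pass, seeded with d(out)/d(out) = 1.
    d_out = 1.0
    # out = z1 * y: gradient into z1, and into y through its direct use.
    d_z1 = d_out * y
    d_y = d_out * z1
    # The hook registered on z1 rewrites the gradient flowing into z1.
    d_z1 = hook(d_z1)
    # z1 = x * y: propagate into x and accumulate into y.
    d_x = d_z1 * y
    d_y += d_z1 * x
    return d_x, d_y


print(grads_of_hook_test(1.0, 2.0))                     # (8.0, 6.0)
print(grads_of_hook_test(1.0, 2.0, hook=lambda g: g))   # (4.0, 4.0), no hook effect
```

The hook doubles only the gradient that passes through `z1`, so x's gradient doubles from 4 to 8, while y's gradient rises from 4 to 6: its direct contribution from `z * y` (2) is untouched and only the contribution through `z1` (2) is doubled to 4.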