mindspore.Tensor.register_hook
- Tensor.register_hook(hook)[source]
Registers a backward hook for the tensor.
Note
The hook must conform to the signature below. grad is the gradient passed to the tensor; the hook may modify the gradient by returning a new output gradient.
The hook should have the following signature: hook(grad) -> New output gradient. The hook must not return None and must not omit its return value.
The following constraints must be met in graph mode:
The hook must satisfy the syntax constraints of graph mode.
Registering a hook for a Parameter is not supported inside a graph (i.e., inside Cell.construct or a function decorated with @jit).
Deleting a hook inside a graph is not supported.
Registering a hook inside a graph returns the Tensor itself.
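As a plain-Python sketch of the signature constraint above (MindSpore-free, so the shapes of the gradients are stand-ins): a valid hook takes the incoming gradient and always returns a new gradient, never None.

```python
# A minimal valid hook: receives the incoming gradient and returns
# a new output gradient. Returning None (or omitting the return)
# would violate the signature constraint.
def double_grad(grad):
    # Scale the incoming gradient by 2 and return it as the new gradient.
    return grad * 2

# The simplest valid hook is the identity: it still returns grad.
def identity_hook(grad):
    return grad
```

The same callables can be passed directly to Tensor.register_hook, since the hook only needs to accept one gradient argument and return a gradient.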
- Parameters
hook (function) – A Python function used as the tensor's backward hook.
- Returns
A handle corresponding to the hook. The handle can be used to remove the added hook by calling handle.remove().
- Raises
TypeError – If the hook is not a Python function.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     return grad * 2
...
>>> def hook_test(x, y):
...     z = x * y
...     z.register_hook(hook_fn)
...     z = z * y
...     return z
...
>>> ms_grad = ms.grad(hook_test, grad_position=(0, 1))
>>> output = ms_grad(Tensor(1, ms.float32), Tensor(2, ms.float32))
>>> print(output)
(Tensor(shape=[], dtype=Float32, value=8), Tensor(shape=[], dtype=Float32, value=6))