mindspore.ops.HookBackward
- class mindspore.ops.HookBackward(hook_fn, cell_id='')[source]
This operation is used as a tag to hook the gradient of an intermediate variable. Note that this operation is only supported in PyNative mode.
Note
The hook function must be defined as hook_fn(grad) -> Tensor or None, where grad is the gradient passed to the primitive. The hook may return a modified gradient, which is then passed on to the next primitive. The difference between a hook function and the callback of InsertGradientOf is that a hook function executes in the Python environment, whereas the callback is parsed and added to the graph.
- Parameters
hook_fn (Function) – Python hook function that takes the gradient of the hooked variable as input and may return a modified gradient (or None).
cell_id (str) – Identifies the cell object on which the hook function is registered; it is used internally and usually left as the default. Default: ''.
- Inputs:
inputs (Tensor) - The variable to hook.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor, ops
>>> from mindspore.ops import GradOperation
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad_out):
...     print(grad_out)
...
>>> grad_all = GradOperation(get_all=True)
>>> hook = ops.HookBackward(hook_fn)
>>> def hook_test(x, y):
...     z = x * y
...     z = hook(z)
...     z = z * y
...     return z
...
>>> def backward(x, y):
...     return grad_all(hook_test)(x, y)
...
>>> output = backward(Tensor(1.0, ms.float32), Tensor(2.0, ms.float32))
>>> print(output)