mindspore.ops.HookBackward

class mindspore.ops.HookBackward(hook_fn, cell_id='')

Tags an intermediate variable so that a hook function runs on its gradient during backpropagation. This operation is only supported in PyNative mode.

Note

The hook function must be defined like hook_fn(grad) -> Tensor or None, where grad is the gradient passed to the primitive. The hook may modify the gradient, and the result is passed on to the next primitive. Unlike the callback of InsertGradientOf, which is parsed and added to the graph, a hook function is executed in the Python environment.

Parameters

hook_fn (Function) – Python function used as the hook.

Inputs:
  • inputs (Tensor) - The variable to hook.

Raises
  • TypeError – If inputs is not a Tensor.

  • TypeError – If hook_fn is not a Python function.

Supported Platforms:

Ascend GPU CPU

Examples

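>>> import mindspore
>>> from mindspore import Tensor, context, ops
>>> from mindspore.ops import GradOperation
>>> # HookBackward is only supported in PyNative mode.
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> def hook_fn(grad_out):
...     # Called during backpropagation with the gradient of the hooked variable.
...     print(grad_out)
...
>>> grad_all = GradOperation(get_all=True)
>>> hook = ops.HookBackward(hook_fn)
>>> def hook_test(x, y):
...     z = x * y
...     z = hook(z)  # tag z so hook_fn sees its gradient
...     z = z * y
...     return z
...
>>> def backward(x, y):
...     return grad_all(hook_test)(x, y)
...
>>> # Inputs must be Tensors, not Python scalars.
>>> output = backward(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))
>>> print(output)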
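
To make it easier to see which gradient the hook receives in the example above, the chain rule for hook_test can be traced by hand. The following is a plain-Python sketch, not MindSpore code: apply_hook, backward_by_hand, observe, and double are hypothetical helper names, and the sketch assumes that a hook returning None leaves the gradient unchanged, as the print-only hook in the example suggests.

```python
# Plain-Python model of hook_test(x, y): z = x * y; z = hook(z); out = z * y.
# This does not call MindSpore; it only traces the chain rule by hand.

def apply_hook(hook_fn, grad):
    """Mimic the hook contract (assumption: None means 'keep the gradient')."""
    result = hook_fn(grad)
    return grad if result is None else result


def backward_by_hand(x, y, hook_fn):
    z = x * y                              # forward value at the hooked variable
    grad_z = y                             # d(out)/dz for out = z * y
    grad_z = apply_hook(hook_fn, grad_z)   # the hook sees, and may replace, this value
    grad_x = grad_z * y                    # dz/dx = y
    grad_y = grad_z * x + z                # path through z, plus the direct factor y
    return grad_x, grad_y


seen = []


def observe(grad):
    # Records the gradient and returns None, like the print-only hook above.
    seen.append(grad)


def double(grad):
    # Returns a modified gradient, which then flows to the earlier operations.
    return grad * 2


plain = backward_by_hand(1.0, 2.0, observe)
scaled = backward_by_hand(1.0, 2.0, double)
print(seen, plain, scaled)
```

In this model the observing hook receives 2.0 and the input gradients are (4.0, 4.0), while the doubling hook changes them to (8.0, 6.0). Only the path that passes through the hooked variable is scaled, so the direct contribution of y in the final multiplication is unaffected.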