mindspore.ops.HookBackward

class mindspore.ops.HookBackward(hook_fn, cell_id='')

This operation is used as a tag to hook the gradients of intermediate variables. Note that this operation is only supported in PyNative mode.

Note

The hook function must be defined as hook_fn(grad) -> new gradient or None, where grad is the gradient passed to the primitive. The hook can modify the gradient by returning a new one, which is then passed on to the next primitive; returning None leaves the gradient unchanged. The difference between a hook function and the callback of InsertGradientOf is that the hook function is executed in the Python environment, whereas the callback is parsed and added to the graph.
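To make the "new gradient or None" contract concrete, the following is a minimal pure-Python sketch, not MindSpore code. It hand-writes the reverse pass for the same computation used in the example below (z = x * y, then hook, then z * y), with plain floats standing in for tensors. The names apply_hook and hook_test_grads are hypothetical helpers introduced only for illustration; the real HookBackward passes the gradient as a tuple of Tensors, as the example output shows.

```python
from typing import Callable, Optional, Tuple

# Hypothetical hook type mirroring hook_fn(grad) -> new gradient or None.
# Plain floats stand in for MindSpore Tensors in this toy model.
Hook = Callable[[float], Optional[float]]


def apply_hook(grad: float, hook_fn: Hook) -> float:
    """Return the gradient that flows on to the next primitive."""
    new_grad = hook_fn(grad)
    # None means "observe only": the original gradient passes through unchanged.
    return grad if new_grad is None else new_grad


def hook_test_grads(x: float, y: float, hook_fn: Hook) -> Tuple[float, float]:
    """Hand-written reverse pass for h = hook(x * y); z = h * y."""
    # Forward pass: the hook is an identity, so h == x * y.
    h = x * y
    # Backward through z = h * y.
    dz_dh = y
    dz_dy_outer = h
    # The hook sees the gradient arriving at h and may replace it.
    dz_dh = apply_hook(dz_dh, hook_fn)
    # Backward through h = x * y; y also receives the outer term.
    dz_dx = dz_dh * y
    dz_dy = dz_dy_outer + dz_dh * x
    return dz_dx, dz_dy


if __name__ == "__main__":
    print(hook_test_grads(1.0, 2.0, lambda g: None))   # observe-only hook
    print(hook_test_grads(1.0, 2.0, lambda g: g * 2))  # gradient-doubling hook
```

With x = 1.0 and y = 2.0, an observe-only hook yields (4.0, 4.0), matching the MindSpore example, while a hook that doubles its input yields (8.0, 6.0). Only the path through the hooked node is scaled, so dz/dy keeps its unhooked term h = 2.0.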

Parameters
  • hook_fn (Function) – The Python hook function that is called with the gradient of the hooked variable.

  • cell_id (str) – Used to identify whether the function registered by the hook is actually registered on the specified cell object. For example, 'nn.Conv2d' is a cell object. The default value of cell_id is an empty string (""), in which case the system automatically registers a value of cell_id. Custom values of cell_id are currently not supported.

Inputs:
  • input (Tensor) - The variable to hook.

Outputs:
  • output (Tensor) - Returns input directly. HookBackward does not affect the forward result.

Raises
  • TypeError – If input is not a tensor.

  • TypeError – If hook_fn is not a Python function.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore as ms
>>> from mindspore import ops
>>> from mindspore import Tensor
>>> from mindspore.ops import GradOperation
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     print(grad)
...
>>> hook = ops.HookBackward(hook_fn)
>>> def hook_test(x, y):
...     z = x * y
...     z = hook(z)
...     z = z * y
...     return z
...
>>> grad_all = GradOperation(get_all=True)
>>> def backward(x, y):
...     return grad_all(hook_test)(x, y)
...
>>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32))
(Tensor(shape=[], dtype=Float32, value= 2),)
>>> print(output)
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))
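The printed values in the example follow from the chain rule. Writing h for the output of hook (an identity in the forward pass), the function is z = x y^2, and the gradient arriving at h is the value that hook_fn prints:

```latex
\begin{aligned}
h &= \operatorname{hook}(xy) = xy, \qquad z = h\,y = x y^{2},\\
\frac{\partial z}{\partial h} &= y = 2 \quad \text{(the gradient printed by hook\_fn)},\\
\frac{\partial z}{\partial x} &= \frac{\partial z}{\partial h}\, y = y^{2} = 4,\\
\frac{\partial z}{\partial y} &= h + \frac{\partial z}{\partial h}\, x = 2xy = 4.
\end{aligned}
```

Because hook_fn returns None, the gradient at h passes through unchanged, so the final gradients (4, 4) are the same as they would be without the hook.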