mindspore.ops.HookBackward
- class mindspore.ops.HookBackward(hook_fn, cell_id='')
This operation is used as a tag to hook the gradient of an intermediate variable. It is supported only in PyNative mode.
Note
The hook function must have the form hook_fn(grad) -> new gradient or None, where 'grad' is the gradient passed to the primitive. The hook can modify 'grad' by returning a new gradient, which is then passed to the next primitive. A hook function differs from the callback of InsertGradientOf in that the hook function is executed in the Python environment, whereas the callback is parsed and added to the graph.
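The hook contract described in this note can be illustrated with a minimal pure-Python sketch. It is not MindSpore code and does not show how HookBackward is implemented: the function backward_with_hook is a hypothetical name introduced here, and it hand-codes the chain rule for one small computation (z = x * y, hook on z, out = z * y) using plain floats, only to show where the hook is called and what its return value does.

```python
# Conceptual sketch only: hand-written backward pass for
#   z = x * y;  z = hook(z);  out = z * y
# backward_with_hook is a hypothetical helper, not a MindSpore API.

def backward_with_hook(x, y, hook_fn):
    # Forward pass: the hook does not change the forward value of z.
    z = x * y
    out = z * y

    # Backward pass starts from d(out)/d(out) = 1.
    grad_out = 1.0

    # Gradient arriving at the hooked variable: d(out)/dz = y.
    grad_z = grad_out * y

    # Hook contract: a returned value replaces the gradient,
    # while None leaves the gradient unchanged.
    new_grad = hook_fn(grad_z)
    if new_grad is not None:
        grad_z = new_grad

    # Propagate through z = x * y. y also contributes directly
    # through out = z * y, which the hook does not touch.
    grad_x = grad_z * y
    grad_y = grad_z * x + grad_out * z
    return out, (grad_x, grad_y)


# A hook that only observes the gradient (returns None).
print(backward_with_hook(1.0, 2.0, lambda g: None))   # (4.0, (4.0, 4.0))

# A hook that doubles the gradient flowing into z.
print(backward_with_hook(1.0, 2.0, lambda g: g * 2))  # (4.0, (8.0, 6.0))
```

In both calls the forward result is the same, and only the gradients downstream of the hooked variable change when the hook returns a new value.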
- Parameters
hook_fn (Function) – The Python hook function.
cell_id (str) – Used to identify whether the function registered by the hook is actually registered on the specified cell object. For example, 'nn.Conv2d' is a cell object. The default value is an empty string (""); in this case, the system automatically assigns a value to cell_id. Custom values of cell_id are not currently supported.
- Inputs:
input (Tensor) - The variable to hook.
- Outputs:
output (Tensor) - Returns input directly. HookBackward does not affect the forward result.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore
>>> from mindspore import ops
>>> from mindspore import Tensor
>>> from mindspore import context
>>> from mindspore.ops import GradOperation
>>> context.set_context(mode=context.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     print(grad)
...
>>> hook = ops.HookBackward(hook_fn)
>>> def hook_test(x, y):
...     z = x * y
...     z = hook(z)
...     z = z * y
...     return z
...
>>> grad_all = GradOperation(get_all=True)
>>> def backward(x, y):
...     return grad_all(hook_test)(x, y)
...
>>> output = backward(Tensor(1, mindspore.float32), Tensor(2, mindspore.float32))
(Tensor(shape=[], dtype=Float32, value= 2),)
>>> print(output)
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))