mindspore.ops.HookBackward

查看源文件
class mindspore.ops.HookBackward(hook_fn, cell_id='')[源代码]

用来导出中间变量中的梯度。请注意,此函数仅在PyNative模式下支持。

说明

钩子函数必须定义为 hook_fn(grad) -> new gradient or None ,其中’grad’是传递给Primitive的梯度。可以通过返回新的梯度并传递到下一个Primitive来修改’grad’。钩子函数和InsertGradientOf的回调的区别在于,钩子函数是在python环境中执行的,而回调将被解析并添加到图中。

参数:
  • hook_fn (Function) - Python函数。钩子函数。

  • cell_id (str,可选) - 用于标识钩子注册的函数是否实际注册在指定的cell对象上。例如,mindspore.nn.Conv2d 是一个cell对象。默认值: "" ,此情况下系统将自动注册 cell_id 的值。 此参数目前不支持自定义。

输入:
  • input (Tensor) - 需要导出的变量的梯度。

输出:
  • output (Tensor) - 直接返回 inputHookBackward 不影响前向结果。

异常:
  • TypeError - 如果 input 不是Tensor。

  • TypeError - 如果 hook_fn 不是Python的函数。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore as ms
>>> from mindspore import ops
>>> from mindspore import Tensor
>>> from mindspore.ops import GradOperation
>>> ms.set_context(mode=ms.PYNATIVE_MODE)
>>> def hook_fn(grad):
...     print(grad)
...
>>> hook = ops.HookBackward(hook_fn)
>>> def hook_test(x, y):
...     z = x * y
...     z = hook(z)
...     z = z * y
...     return z
...
>>> grad_all = GradOperation(get_all=True)
>>> def backward(x, y):
...     return grad_all(hook_test)(x, y)
...
>>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32))
(Tensor(shape=[], dtype=Float32, value= 2),)
>>> print(output)
(Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))