mindspore.ops.HookBackward
- class mindspore.ops.HookBackward(hook_fn, cell_id='')[源代码]
用来导出中间变量中的梯度。请注意,此函数仅在PyNative模式下支持。
说明
钩子函数必须定义为 hook_fn(grad) -> new gradient or None ,其中’grad’是传递给Primitive的梯度。可以通过返回新的梯度并传递到下一个Primitive来修改’grad’。钩子函数和InsertGradientOf的回调的区别在于,钩子函数是在python环境中执行的,而回调将被解析并添加到图中。
- 参数:
hook_fn (Function) - Python函数。钩子函数。
cell_id (str,可选) - 用于标识钩子注册的函数是否实际注册在指定的cell对象上。例如,
mindspore.nn.Conv2d
是一个cell对象。默认值:””,此情况下系统将自动注册 cell_id 的值。 此参数目前不支持自定义。
- 输入:
input (Tensor) - 需要导出的变量的梯度。
- 输出:
output (Tensor) - 直接返回 input 。 HookBackward 不影响前向结果。
- 异常:
TypeError - 如果 input 不是Tensor。
TypeError - 如果 hook_fn 不是Python的函数。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore as ms >>> from mindspore import ops >>> from mindspore import Tensor >>> from mindspore.ops import GradOperation >>> ms.set_context(mode=ms.PYNATIVE_MODE) >>> def hook_fn(grad): ... print(grad) ... >>> hook = ops.HookBackward(hook_fn) >>> def hook_test(x, y): ... z = x * y ... z = hook(z) ... z = z * y ... return z ... >>> grad_all = GradOperation(get_all=True) >>> def backward(x, y): ... return grad_all(hook_test)(x, y) ... >>> output = backward(Tensor(1, ms.float32), Tensor(2, ms.float32)) (Tensor(shape=[], dtype=Float32, value= 2),) >>> print(output) (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))