{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hook编程\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/zh_cn/custom_program/mindspore_hook_program.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/zh_cn/custom_program/mindspore_hook_program.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/custom_program/hook_program.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "调试深度学习网络是每一个深度学习领域的从业者需要面对且投入精力较大的工作。由于深度学习网络隐藏了中间层算子的输入、输出数据以及反向梯度,只提供网络输入数据(特征量、权重)的梯度,导致无法准确地感知中间层算子的数据变化,从而降低了调试效率。为了方便用户准确、快速地对深度学习网络进行调试,MindSpore在动态图模式下设计了Hook功能,**使用Hook功能可以捕获中间层算子的输入、输出数据以及反向梯度**。\n",
    "\n",
    "目前,动态图模式下提供了五种形式的Hook功能,分别是:HookBackward算子和在Cell对象上进行注册的register_forward_pre_hook、register_forward_hook、register_backward_pre_hook、register_backward_hook功能。\n",
    "\n",
    "## HookBackward算子\n",
    "\n",
    "HookBackward将Hook功能以算子的形式实现。用户初始化一个HookBackward算子,将其安插到深度学习网络中需要捕获梯度的位置。在网络正向执行时,HookBackward算子将输入数据不做任何修改后原样输出;在网络反向传播梯度时,在HookBackward上注册的Hook函数将会捕获反向传播至此的梯度。用户可以在Hook函数中自定义对梯度的操作,比如打印梯度,或者返回新的梯度。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-15T03:32:04.585336Z",
     "start_time": "2024-08-15T03:32:04.578481Z"
    }
   },
   "source": [
    "import mindspore as ms\n",
    "from mindspore import ops\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def hook_fn(grad_out):\n",
    "    \"\"\"打印梯度\"\"\"\n",
    "    print(\"hook_fn print grad_out:\", grad_out)\n",
    "\n",
    "hook = ops.HookBackward(hook_fn)\n",
    "def hook_test(x, y):\n",
    "    z = x * y\n",
    "    z = hook(z)\n",
    "    z = z * y\n",
    "    return z\n",
    "\n",
    "def net(x, y):\n",
    "    return ms.grad(hook_test, grad_position=(0, 1))(x, y)\n",
    "\n",
    "output = net(ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32))\n",
    "print(\"output:\", output)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hook_fn print grad_out: (Tensor(shape=[], dtype=Float32, value= 2),)\n",
      "output: (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))\n"
     ]
    }
   ],
   "execution_count": 15
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "更多HookBackward算子的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/master/api_python/ops/mindspore.ops.HookBackward.html)。\n",
    "\n",
    "## Cell对象的register_forward_pre_hook功能\n",
    "\n",
    "用户可以对Cell对象使用`register_forward_pre_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入该Cell对象的数据。该功能在静态图模式下和在使用`@jit`修饰的函数内不起作用。`register_forward_pre_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_pre_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_pre_hook_fn(cell, inputs):\n",
    "    print(\"forward inputs: \", inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的cell是Cell对象,inputs是正向传入到Cell对象的数据。因此,用户可以使用register_forward_pre_hook函数来捕获网络中某一个Cell对象的正向输入数据。用户可以在Hook函数中自定义对输入数据的操作,比如查看、打印数据,或者返回新的输入数据给当前的Cell对象。如果在Hook函数中对Cell对象的原始输入数据进行计算操作后,再作为新的输入数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-15T02:33:03.197024Z",
     "start_time": "2024-08-15T02:33:03.181122Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "[2.]\n",
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward_pre_hook_fn(cell, inputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    input_x = inputs[0]\n",
    "    return input_x\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x + y\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net, grad_position=(0, 1))\n",
    "\n",
    "x = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "y = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "\n",
    "output = net(x, y)\n",
    "print(output)\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)\n",
    "net.handle.remove()\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用户如果在Hook函数中直接返回新创建的数据,而不是返回由原始输入数据经过计算后得到的数据,那么梯度的反向传播将会在该Cell对象上截止。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-15T02:33:11.216359Z",
     "start_time": "2024-08-15T02:33:11.205063Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]))\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward_pre_hook_fn(cell, inputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    return ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x + y\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net, grad_position=(0, 1))\n",
    "\n",
    "x = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "y = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook` 函数和 `handle` 对象的 `remove()` 函数。在动态图模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook` 函数,那么Cell对象每次运行都将注册一个新的Hook函数。\n",
    "\n",
    "更多关于Cell对象的 `register_forward_pre_hook` 功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_forward_pre_hook)。\n",
    "\n",
    "## Cell对象的register_forward_hook功能\n",
    "\n",
    "用户可以在Cell对象上使用`register_forward_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入Cell对象的数据和Cell对象的输出数据。该功能在静态图模式下和在使用`@jit`修饰的函数内不起作用。`register_forward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_hook_fn(cell, inputs, outputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    print(\"forward outputs: \", outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的`cell`是Cell对象,`inputs`是正向传入到Cell对象的数据,`outputs`是Cell对象的正向输出数据。因此,用户可以使用`register_forward_hook`函数来捕获网络中某一个Cell对象的正向输入数据和输出数据。用户可以在Hook函数中自定义对输入、输出数据的操作,比如查看、打印数据,或者返回新的输出数据。如果在Hook函数中对Cell对象的原始输出数据进行计算操作后,再作为新的输出数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-15T02:33:22.691493Z",
     "start_time": "2024-08-15T02:33:22.680164Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "forward outputs:  [2.]\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward_hook_fn(cell, inputs, outputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    print(\"forward outputs: \", outputs)\n",
    "    outputs = outputs + outputs\n",
    "    return outputs\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.handle = self.relu.register_forward_hook(forward_hook_fn)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x + y\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net, grad_position=(0, 1))\n",
    "\n",
    "x = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "y = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)\n",
    "net.handle.remove()\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用户如果在Hook函数中直接返回新创建的数据,而不是将原始的输出数据经过计算后,将得到的新输出数据返回,那么梯度的反向传播将会在该Cell对象上截止。该现象可以参考`register_forward_pre_hook`函数的用例说明。\n",
    "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的`construct`函数中调用`register_forward_hook`函数和`handle`对象的`remove()`函数。在动态图模式下,如果在Cell对象的`construct`函数中调用`register_forward_hook`函数,那么Cell对象每次运行都将注册一个新的Hook函数。\n",
    "\n",
    "更多关于Cell对象的`register_forward_hook`功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_forward_hook)。\n",
    "\n",
    "## Cell对象的register_backward_pre_hook功能\n",
    "\n",
    "用户可以在Cell对象上使用`register_backward_pre_hook`函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。该功能在图模式下或者在使用`@jit`修饰的函数内不起作用。`register_backward_pre_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_backward_pre_hook`函数,都会返回一个不同的`handle`对象。\n",
    "\n",
    "与HookBackward算子所使用的自定义Hook函数有所不同,`register_backward_pre_hook`使用的Hook函数的入参中,包含了表示Cell对象信息`cell`以及反向传入到Cell对象的梯度。\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def backward_pre_hook_function(grad_output):\n",
    "    print(grad_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的`cell`是Cell对象信息,`grad_output`是网络反向传播时,传入到Cell对象的梯度。因此,用户可以使用`register_backward_pre_hook`函数来捕获网络中某一个Cell对象的反向梯度输入值。用户可以在Hook函数中自定义对梯度的操作,比如查看、打印梯度,或者返回新的输入梯度。如果需要在Hook函数中返回新的输入梯度时,返回值必须是`tuple`的形式。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-15T02:33:48.925726Z",
     "start_time": "2024-08-15T02:33:48.885275Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=\n",
      "[[[[ 1.00000000e+00]],\n",
      "  [[ 1.00000000e+00]]]]),)\n",
      "[[[[1.99999 1.99999]\n",
      "   [1.99999 1.99999]]]]\n",
      "-------------\n",
      " [[[[1.99999 1.99999]\n",
      "   [1.99999 1.99999]]]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def backward_pre_hook_function(cell, grad_output):\n",
    "    print(grad_output)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init=\"ones\", pad_mode=\"valid\")\n",
    "        self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init=\"ones\")\n",
    "        self.handle = self.bn.register_backward_pre_hook(backward_pre_hook_function)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = self.bn(x)\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net)\n",
    "output = grad_net(ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))\n",
    "print(output)\n",
    "net.handle.remove()\n",
    "output = grad_net(ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))\n",
    "print(\"-------------\\n\", output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_backward_pre_hook` 函数和 `handle` 对象的 `remove()` 函数。在PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_backward_pre_hook` 函数,那么Cell对象每次运行都将注册一个新的Hook函数。\n",
    "\n",
    "更多关于Cell对象的`register_backward_pre_hook`功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_backward_pre_hook)。\n",
    "\n",
    "## Cell对象的register_backward_hook功能\n",
    "\n",
    "用户可以在Cell对象上使用`register_backward_hook`函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。该功能在图模式下或者在使用`@jit`修饰的函数内不起作用。`register_backward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_backward_hook`函数,都会返回一个不同的`handle`对象。\n",
    "\n",
    "与HookBackward算子所使用的自定义Hook函数有所不同,`register_backward_hook`使用的Hook函数的入参中,包含了表示Cell对象信息`cell`、反向传入到Cell对象的梯度、以及Cell对象的反向输出的梯度。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-15T01:31:51.395381Z",
     "start_time": "2024-08-15T01:31:51.393199Z"
    }
   },
   "outputs": [],
   "source": [
    "def backward_hook_function(cell, grad_input, grad_output):\n",
    "    print(grad_input)\n",
    "    print(grad_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的`cell`是Cell对象信息,`grad_input`是Cell对象反向输出的梯度,`grad_output`是网络反向传播时,传入到Cell对象的梯度。因此,用户可以使用`register_backward_hook`函数来捕获网络中某一个Cell对象的反向传入和反向输出梯度。用户可以在Hook函数中自定义对梯度的操作,比如查看、打印梯度,或者返回新的输出梯度。如果需要在Hook函数中返回新的输出梯度时,返回值必须是`tuple`的形式。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-15T02:34:40.550644Z",
     "start_time": "2024-08-15T02:34:40.530050Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=\n",
      "[[[[ 9.99994993e-01]],\n",
      "  [[ 9.99994993e-01]]]]),)\n",
      "(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=\n",
      "[[[[ 1.00000000e+00]],\n",
      "  [[ 1.00000000e+00]]]]),)\n",
      "[[[[1.99999 1.99999]\n",
      "   [1.99999 1.99999]]]]\n",
      "-------------\n",
      " [[[[1.99999 1.99999]\n",
      "   [1.99999 1.99999]]]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def backward_hook_function(cell, grad_input, grad_output):\n",
    "    print(grad_input)\n",
    "    print(grad_output)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init=\"ones\", pad_mode=\"valid\")\n",
    "        self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init=\"ones\")\n",
    "        self.handle = self.bn.register_backward_hook(backward_hook_function)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = self.bn(x)\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net)\n",
    "output = grad_net(ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))\n",
    "print(output)\n",
    "net.handle.remove()\n",
    "output = grad_net(ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))\n",
    "print(\"-------------\\n\", output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_backward_hook` 函数和 `handle` 对象的 `remove()` 函数。在PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_backward_hook` 函数,那么Cell对象每次运行都将注册一个新的Hook函数。\n",
    "\n",
    "更多关于Cell对象的 `register_backward_hook` 功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_backward_hook)。\n",
    "\n",
    "## Cell对象使用多个hook功能\n",
    "\n",
    "当 `register_backward_pre_hook` 函数、 `register_backward_hook` 函数、`register_forward_pre_hook` 函数、 `register_forward_hook` 函数同时作用于同一Cell对象时,如果 `register_forward_pre_hook` 和 `register_forward_hook` 函数中有添加其他算子进行数据处理,这些新增算子会在Cell对象执行前或者执行后参与数据的正向计算,但是这些新增算子的反向梯度不在 `register_backward_pre_hook` 函数和 `register_backward_hook` 函数的捕获范围内。 `register_backward_pre_hook` 中注册的Hook函数仅捕获原始Cell对象的输入梯度。`register_backward_hook` 中注册的Hook函数仅捕获原始Cell对象的输入、输出梯度。\n",
    "\n",
    "示例代码如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-15T03:04:03.727934Z",
     "start_time": "2024-08-15T03:04:03.712374Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "forward outputs:  [2.]\n",
      "grad input:  (Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]),)\n",
      "grad input:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "grad output:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward_pre_hook_fn(cell, inputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    input_x = inputs[0]\n",
    "    return input_x\n",
    "\n",
    "def forward_hook_fn(cell, inputs, outputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    print(\"forward outputs: \", outputs)\n",
    "    outputs = outputs + outputs\n",
    "    return outputs\n",
    "\n",
    "def backward_pre_hook_fn(cell, grad_output):\n",
    "    print(\"grad input: \", grad_output)\n",
    "\n",
    "def backward_hook_fn(cell, grad_input, grad_output):\n",
    "    print(\"grad input: \", grad_output)\n",
    "    print(\"grad output: \", grad_input)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)\n",
    "        self.handle2 = self.relu.register_forward_hook(forward_hook_fn)\n",
    "        self.handle3 = self.relu.register_backward_pre_hook(backward_pre_hook_fn)\n",
    "        self.handle4 = self.relu.register_backward_hook(backward_hook_fn)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x + y\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net, grad_position=(0, 1))\n",
    "gradient = grad_net(ms.Tensor(np.ones([1]).astype(np.float32)), ms.Tensor(np.ones([1]).astype(np.float32)))\n",
    "print(gradient)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的 `grad_output` 是梯度反向传播时传入`self.relu`的梯度,而不是传入 `forward_hook_fn` 函数中,新增的 `Add` 算子的梯度。这里的 `grad_input` 是梯度反向传播时 `self.relu` 反向输出的梯度,而不是 `forward_pre_hook_fn` 函数中新增 `Add` 算子的反向输出梯度。 `register_forward_pre_hook` 函数和 `register_forward_hook` 函数是在Cell对象执行前后起作用,不会影响Cell对象上反向Hook函数的梯度捕获范围。\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}