{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![在线运行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_modelarts.svg)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9yMi4yL3R1dG9yaWFscy96aF9jbi9hZHZhbmNlZC9tb2R1bGVzL21pbmRzcG9yZV9sYXllci5pcHluYg==&imageid=4c43b3ad-9df7-4b83-a096-c775dc4ba243) [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/tutorials/zh_cn/advanced/modules/mindspore_layer.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.2/tutorials/zh_cn/advanced/modules/mindspore_layer.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.2/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.2/tutorials/source_zh_cn/advanced/modules/layer.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Cell与参数\n",
    "\n",
    "Cell作为神经网络构造的基础单元,与神经网络层(Layer)的概念相对应,对Tensor计算操作的抽象封装,能够更准确清晰地对神经网络结构进行表示。除了基础的Tensor计算流程定义外,神经网络层还包含了参数管理、状态管理等功能。而参数(Parameter)是神经网络训练的核心,通常作为神经网络层的内部成员变量。本节我们将系统介绍参数、神经网络层以及其相关使用方法。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter\n",
    "\n",
    "参数(Parameter)是一类特殊的Tensor,是指在模型训练过程中可以对其值进行更新的变量。MindSpore提供`mindspore.Parameter`类进行Parameter的构造。为了对不同用途的Parameter进行区分,下面对两种不同类别的Parameter进行定义:\n",
    "\n",
    "- 可训练参数。在模型训练过程中根据反向传播算法求得梯度后进行更新的Tensor,此时需要将`required_grad`设置为`True`。\n",
    "- 不可训练参数。不参与反向传播,但需要更新值的Tensor(如BatchNorm中的`mean`和`var`变量),此时需要将`requires_grad`设置为`False`。\n",
    "\n",
    "> Parameter默认设置`required_grad=True`。\n",
    "\n",
    "下面我们构造一个简单的全连接层:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "from mindspore import ops\n",
    "from mindspore import Tensor, Parameter\n",
    "\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight\n",
    "        self.b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias\n",
    "\n",
    "    def construct(self, x):\n",
    "        z = ops.matmul(x, self.w) + self.b\n",
    "        return z\n",
    "\n",
    "net = Network()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在`Cell`的`__init__`方法中,我们定义了`w`和`b`两个Parameter,并配置`name`进行命名空间管理。在`construct`方法中使用`self.attr`直接调用参与Tensor运算。"
   ]
  },
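  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the trainable/non-trainable distinction above concrete, here is a minimal sketch (the `moving_mean` name and its zero initialization are illustrative assumptions, in the spirit of BatchNorm's running statistics) of constructing a Parameter with `requires_grad=False`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A non-trainable Parameter: it is excluded from trainable_params(),\n",
    "# but its value can still be updated, like BatchNorm's running statistics.\n",
    "moving_mean = Parameter(Tensor(np.zeros(3), mindspore.float32), name='moving_mean', requires_grad=False)\n",
    "print(moving_mean)"
   ]
  },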
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 获取Parameter\n",
    "\n",
    "在使用Cell+Parameter构造神经网络层后,我们可以使用多种方法来获取Cell管理的Parameter。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 获取单个参数\n",
    "\n",
    "单独获取某个特定参数,直接调用Python类的成员变量即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-12-29T03:42:22.729232Z",
     "start_time": "2021-12-29T03:42:22.723517Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-1.2192779  -0.36789745  0.0946381 ]\n"
     ]
    }
   ],
   "source": [
    "print(net.b.asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 获取可训练参数\n",
    "\n",
    "可使用`Cell.trainable_params`方法获取可训练参数,通常在配置优化器时需调用此接口。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Parameter (name=w, shape=(5, 3), dtype=Float32, requires_grad=True), Parameter (name=b, shape=(3,), dtype=Float32, requires_grad=True)]\n"
     ]
    }
   ],
   "source": [
    "print(net.trainable_params())"
   ]
  },
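  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the optimizer use case mentioned above (the choice of `nn.SGD` and the learning rate are illustrative assumptions, not part of the original example), the list returned by `trainable_params` is typically passed straight to an optimizer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the trainable parameters over to an optimizer.\n",
    "optimizer = nn.SGD(net.trainable_params(), learning_rate=0.01)\n",
    "print(optimizer.parameters)"
   ]
  },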
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 获取所有参数\n",
    "\n",
    "使用`Cell.get_parameters()`方法可获取所有参数,此时会返回一个Python迭代器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'generator'>\n"
     ]
    }
   ],
   "source": [
    "print(type(net.get_parameters()))"
   ]
  },
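  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `get_parameters()` returns an iterator, it can be consumed with an ordinary `for` loop. A small sketch printing each parameter's name and shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Iterate over every Parameter managed by the Cell.\n",
    "for param in net.get_parameters():\n",
    "    print(param.name, param.shape)"
   ]
  },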
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "或者可以调用`Cell.parameters_and_names`返回参数名称及参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "w:\n",
      "[[ 4.15680408e-02 -1.20311625e-01  5.02573885e-02]\n",
      " [ 1.22175144e-04 -1.34980649e-01  1.17642188e+00]\n",
      " [ 7.57667869e-02 -1.74758151e-01 -5.19092619e-01]\n",
      " [-1.67846107e+00  3.27240258e-01 -2.06452996e-01]\n",
      " [ 5.72323874e-02 -8.27963874e-02  5.94243526e-01]]\n",
      "b:\n",
      "[-1.2192779  -0.36789745  0.0946381 ]\n"
     ]
    }
   ],
   "source": [
    "for name, param in net.parameters_and_names():\n",
    "    print(f\"{name}:\\n{param.asnumpy()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 修改Parameter\n",
    "\n",
    "#### 直接修改参数值\n",
    "\n",
    "Parameter是一种特殊的Tensor,因此可以使用Tensor索引修改的方式对其值进行修改。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1.         -0.36789745  0.0946381 ]\n"
     ]
    }
   ],
   "source": [
    "net.b[0] = 1.\n",
    "print(net.b.asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 覆盖修改参数值\n",
    "\n",
    "可调用`Parameter.set_data`方法,使用相同Shape的Tensor对Parameter进行覆盖。该方法常用于使用Initializer进行[Cell遍历初始化](https://www.mindspore.cn/tutorials/zh-CN/r2.2/advanced/modules/initializer.html#cell%E9%81%8D%E5%8E%86%E5%88%9D%E5%A7%8B%E5%8C%96)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3. 4. 5.]\n"
     ]
    }
   ],
   "source": [
    "net.b.set_data(Tensor([3, 4, 5]))\n",
    "print(net.b.asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 运行时修改参数值\n",
    "\n",
    "参数的主要作用为模型训练时对其值进行更新,在反向传播获得梯度后,或不可训练参数需要进行更新,都涉及到运行时参数修改。由于MindSpore的[使用静态图加速](https://www.mindspore.cn/tutorials/zh-CN/r2.2/beginner/accelerate_with_static_graph.html)编译设计,此时需要使用`mindspore.ops.assign`接口对参数进行赋值。该方法常用于[自定义优化器](https://www.mindspore.cn/tutorials/zh-CN/r2.2/advanced/modules/optimizer.html#%E8%87%AA%E5%AE%9A%E4%B9%89%E4%BC%98%E5%8C%96%E5%99%A8)场景。下面是一个简单的运行时修改参数值样例:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[7. 8. 9.]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "\n",
    "@ms.jit\n",
    "def modify_parameter():\n",
    "    b_hat = ms.Tensor([7, 8, 9])\n",
    "    ops.assign(net.b, b_hat)\n",
    "    return True\n",
    "\n",
    "modify_parameter()\n",
    "print(net.b.asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameter Tuple\n",
    "\n",
    "变量元组ParameterTuple,用于保存多个Parameter,继承于元组tuple,提供克隆功能。\n",
    "\n",
    "如下示例提供ParameterTuple创建方法:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Parameter (name=x, shape=(2, 3), dtype=Int64, requires_grad=True), Parameter (name=y, shape=(1, 2, 3), dtype=Float32, requires_grad=True), Parameter (name=z, shape=(), dtype=Float32, requires_grad=True))\n",
      "(Parameter (name=params_copy.x, shape=(2, 3), dtype=Int64, requires_grad=True), Parameter (name=params_copy.y, shape=(1, 2, 3), dtype=Float32, requires_grad=True), Parameter (name=params_copy.z, shape=(), dtype=Float32, requires_grad=True))\n"
     ]
    }
   ],
   "source": [
    "from mindspore.common.initializer import initializer\n",
    "from mindspore import ParameterTuple\n",
    "# 创建\n",
    "x = Parameter(default_input=ms.Tensor(np.arange(2 * 3).reshape((2, 3))), name=\"x\")\n",
    "y = Parameter(default_input=initializer('ones', [1, 2, 3], ms.float32), name='y')\n",
    "z = Parameter(default_input=2.0, name='z')\n",
    "params = ParameterTuple((x, y, z))\n",
    "\n",
    "# 从params克隆并修改名称为\"params_copy\"\n",
    "params_copy = params.clone(\"params_copy\")\n",
    "\n",
    "print(params)\n",
    "print(params_copy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Cell训练状态转换\n",
    "\n",
    "神经网络中的部分Tensor操作在训练和推理时的表现并不相同,如`nn.Dropout`在训练时进行随机丢弃,但在推理时则不丢弃,`nn.BatchNorm`在训练时需要更新`mean`和`var`两个变量,在推理时则固定其值不变。因此我们可以通过`Cell.set_train`接口来设置神经网络的状态。\n",
    "\n",
    "`set_train(True)`时,神经网络状态为`train`, `set_train`接口默认值为`True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train\n"
     ]
    }
   ],
   "source": [
    "net.set_train()\n",
    "print(net.phase)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`set_train(False)`时,神经网络状态为`predict`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict\n"
     ]
    }
   ],
   "source": [
    "net.set_train(False)\n",
    "print(net.phase)"
   ]
  },
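  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the behavioral difference concrete, here is a minimal sketch (the dropout probability `p=0.5` and the all-ones input are illustrative assumptions): `nn.Dropout` randomly zeroes elements in the `train` state but passes the input through unchanged in the `predict` state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dropout = nn.Dropout(p=0.5)\n",
    "x = Tensor(np.ones([1, 4]), mindspore.float32)\n",
    "\n",
    "dropout.set_train(True)   # train state: elements are randomly zeroed, the rest rescaled by 1/(1-p)\n",
    "print(dropout(x))\n",
    "\n",
    "dropout.set_train(False)  # predict state: identity mapping\n",
    "print(dropout(x))"
   ]
  },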
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义神经网络层\n",
    "\n",
    "通常情况下,MindSpore提供的神经网络层接口和function函数接口能够满足模型构造需求,但由于AI领域不断推陈出新,因此有可能遇到新网络结构没有内置模块的情况。此时我们可以根据需要,通过MindSpore提供的function接口、Primitive算子自定义神经网络层,并可以使用`Cell.bprop`方法自定义反向。下面分别详述三种自定义方法。\n",
    "\n",
    "### 使用function接口构造神经网络层\n",
    "\n",
    "MindSpore提供大量基础的function接口,可以使用其构造复杂的Tensor操作,封装为神经网络层。下面以`Threshold`为例,其公式如下:\n",
    "\n",
    "$$\n",
    "y =\\begin{cases}\n",
    "   x, &\\text{ if } x > \\text{threshold} \\\\\n",
    "   \\text{value}, &\\text{ otherwise }\n",
    "   \\end{cases}\n",
    "$$\n",
    "\n",
    "可以看到`Threshold`判断Tensor的值是否大于`threshold`值,保留判断结果为`True`的值,替换判断结果为`False`的值。因此,对应实现如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Threshold(nn.Cell):\n",
    "    def __init__(self, threshold, value):\n",
    "        super().__init__()\n",
    "        self.threshold = threshold\n",
    "        self.value = value\n",
    "\n",
    "    def construct(self, inputs):\n",
    "        cond = ops.gt(inputs, self.threshold)\n",
    "        value = ops.fill(inputs.dtype, inputs.shape, self.value)\n",
    "        return ops.select(cond, inputs, value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里分别使用了`ops.gt`、`ops.fill`、`ops.select`来实现判断和替换。下面执行自定义的`Threshold`层:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[3], dtype=Float32, value= [ 2.00000000e+01,  2.00000003e-01,  3.00000012e-01])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = Threshold(0.1, 20)\n",
    "inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)\n",
    "m(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到`inputs[0] = threshold`, 因此被替换为`20`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 自定义Cell反向\n",
    "\n",
    "在特殊场景下,我们不但需要自定义神经网络层的正向逻辑,也需要手动控制其反向的计算,此时我们可以通过`Cell.bprop`接口对其反向进行定义。在全新的神经网络结构设计、反向传播速度优化等场景下会用到该功能。下面我们以`Dropout2d`为例,介绍如何自定义Cell反向:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dropout2d(nn.Cell):\n",
    "    def __init__(self, keep_prob):\n",
    "        super().__init__()\n",
    "        self.keep_prob = keep_prob\n",
    "        self.dropout2d = ops.Dropout2D(keep_prob)\n",
    "\n",
    "    def construct(self, x):\n",
    "        return self.dropout2d(x)\n",
    "\n",
    "    def bprop(self, x, out, dout):\n",
    "        _, mask = out\n",
    "        dy, _ = dout\n",
    "        if self.keep_prob != 0:\n",
    "            dy = dy * (1 / self.keep_prob)\n",
    "        dy = mask.astype(mindspore.float32) * dy\n",
    "        return (dy.astype(x.dtype), )\n",
    "\n",
    "dropout_2d = Dropout2d(0.8)\n",
    "dropout_2d.bprop_debug = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`bprop`方法分别有三个入参:\n",
    "\n",
    "- *x*: 正向输入,当正向输入为多个时,需同样数量的入参。\n",
    "- *out*: 正向输出。\n",
    "- *dout*: 反向传播时,当前Cell执行之前的反向结果。\n",
    "\n",
    "一般我们需要根据正向输出和前层反向结果配合,根据反向求导公式计算反向结果,并将其返回。`Dropout2d`的反向计算需要根据正向输出的`mask`矩阵对前层反向结果进行mask,然后根据`keep_prob`进行缩放。最终可得到正确的计算结果。"
   ]
  },
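  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check (a sketch assuming PyNative mode; the input shape and the `sum` reduction are illustrative assumptions), we can differentiate through the custom `bprop` defined above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore as ms\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward(x):\n",
    "    y, _ = dropout_2d(x)  # keep the data output, discard the mask\n",
    "    return y.sum()\n",
    "\n",
    "x = ms.Tensor(np.ones((1, 2, 2, 2)), ms.float32)\n",
    "print(ms.grad(forward)(x))  # the gradient flows through the custom bprop"
   ]
  },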
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hook功能\n",
    "\n",
    "调试深度学习网络是每一个深度学习领域的从业者需要面对且投入精力较大的工作。由于深度学习网络隐藏了中间层算子的输入、输出数据以及反向梯度,只提供网络输入数据(特征量、权重)的梯度,导致无法准确地感知中间层算子的数据变化,从而降低了调试效率。为了方便用户准确、快速地对深度学习网络进行调试,MindSpore在动态图模式下设计了Hook功能,**使用Hook功能可以捕获中间层算子的输入、输出数据以及反向梯度**。\n",
    "\n",
    "目前,动态图模式下提供了四种形式的Hook功能,分别是:HookBackward算子和在Cell对象上进行注册的register_forward_pre_hook、register_forward_hook、register_backward_hook功能。\n",
    "\n",
    "### HookBackward算子\n",
    "\n",
    "HookBackward将Hook功能以算子的形式实现。用户初始化一个HookBackward算子,将其安插到深度学习网络中需要捕获梯度的位置。在网络正向执行时,HookBackward算子将输入数据不做任何修改后原样输出;在网络反向传播梯度时,在HookBackward上注册的Hook函数将会捕获反向传播至此的梯度。用户可以在Hook函数中自定义对梯度的操作,比如打印梯度,或者返回新的梯度。\n",
    "\n",
    "示例代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-02T09:13:38.660885Z",
     "start_time": "2022-03-02T09:13:38.645597Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hook_fn print grad_out: (Tensor(shape=[], dtype=Float32, value= 2),)\n",
      "output: (Tensor(shape=[], dtype=Float32, value= 4), Tensor(shape=[], dtype=Float32, value= 4))\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "from mindspore import ops\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def hook_fn(grad_out):\n",
    "    \"\"\"打印梯度\"\"\"\n",
    "    print(\"hook_fn print grad_out:\", grad_out)\n",
    "\n",
    "hook = ops.HookBackward(hook_fn)\n",
    "def hook_test(x, y):\n",
    "    z = x * y\n",
    "    z = hook(z)\n",
    "    z = z * y\n",
    "    return z\n",
    "\n",
    "def net(x, y):\n",
    "    return ms.grad(hook_test, grad_position=(0, 1))(x, y)\n",
    "\n",
    "output = net(ms.Tensor(1, ms.float32), ms.Tensor(2, ms.float32))\n",
    "print(\"output:\", output)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "更多HookBackward算子的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/r2.2/api_python/ops/mindspore.ops.HookBackward.html)。\n",
    "\n",
    "### Cell对象的register_forward_pre_hook功能\n",
    "\n",
    "用户可以在Cell对象上使用`register_forward_pre_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入该Cell对象的数据。该功能在静态图模式下和在使用`@jit`修饰的函数内不起作用。`register_forward_pre_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_pre_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_pre_hook_fn(cell_id, inputs):\n",
    "    print(\"forward inputs: \", inputs)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的cell_id是Cell对象的名称以及ID信息,inputs是正向传入到Cell对象的数据。因此,用户可以使用register_forward_pre_hook函数来捕获网络中某一个Cell对象的正向输入数据。用户可以在Hook函数中自定义对输入数据的操作,比如查看、打印数据,或者返回新的输入数据给当前的Cell对象。如果在Hook函数中对Cell对象的原始输入数据进行计算操作后,再作为新的输入数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。\n",
    "\n",
    "示例代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "[2.]\n",
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward_pre_hook_fn(cell_id, inputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    input_x = inputs[0]\n",
    "    return input_x\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x + y\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net, grad_position=(0, 1))\n",
    "\n",
    "x = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "y = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "\n",
    "output = net(x, y)\n",
    "print(output)\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)\n",
    "net.handle.remove()\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用户如果在Hook函数中直接返回新创建的数据,而不是返回由原始输入数据经过计算后得到的数据,那么梯度的反向传播将会在该Cell对象上截止。\n",
    "\n",
    "示例代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 0.00000000e+00]))\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward_pre_hook_fn(cell_id, inputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    return ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x + y\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net, grad_position=(0, 1))\n",
    "\n",
    "x = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "y = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook` 函数和 `handle` 对象的 `remove()` 函数。在动态图模式下,如果在Cell对象的 `construct` 函数中调用 `register_forward_pre_hook` 函数,那么Cell对象每次运行都将新注册一个Hook函数。\n",
    "\n",
    "更多关于Cell对象的 `register_forward_pre_hook` 功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/r2.2/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_forward_pre_hook)。\n",
    "\n",
    "### Cell对象的register_forward_hook功能\n",
    "\n",
    "用户可以在Cell对象上使用`register_forward_hook`函数来注册一个自定义的Hook函数,用来捕获正向传入Cell对象的数据和Cell对象的输出数据。该功能在静态图模式下和在使用`@jit`修饰的函数内不起作用。`register_forward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_forward_hook`函数,都会返回一个不同的`handle`对象。Hook函数应该按照以下的方式进行定义。\n",
    "\n",
    "示例代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_hook_fn(cell_id, inputs, outputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    print(\"forward outputs: \", outputs)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的`cell_id`是Cell对象的名称以及ID信息,`inputs`是正向传入到Cell对象的数据,`outputs`是Cell对象的正向输出数据。因此,用户可以使用`register_forward_hook`函数来捕获网络中某一个Cell对象的正向输入数据和输出数据。用户可以在Hook函数中自定义对输入、输出数据的操作,比如查看、打印数据,或者返回新的输出数据。如果在Hook函数中对Cell对象的原始输出数据进行计算操作后,再作为新的输出数据返回,这些新增的计算操作将会同时作用于梯度的反向传播。\n",
    "\n",
    "示例代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "forward outputs:  [2.]\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 1.00000000e+00]))\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward_hook_fn(cell_id, inputs, outputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    print(\"forward outputs: \", outputs)\n",
    "    outputs = outputs + outputs\n",
    "    return outputs\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.handle = self.relu.register_forward_hook(forward_hook_fn)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x + y\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net, grad_position=(0, 1))\n",
    "\n",
    "x = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "y = ms.Tensor(np.ones([1]).astype(np.float32))\n",
    "\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)\n",
    "net.handle.remove()\n",
    "gradient = grad_net(x, y)\n",
    "print(gradient)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "用户如果在Hook函数中直接返回新创建的数据,而不是将原始的输出数据经过计算后,将得到的新输出数据返回,那么梯度的反向传播将会在该Cell对象上截止。该现象可以参考`register_forward_pre_hook`函数的用例说明。\n",
    "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的`construct`函数中调用`register_forward_hook`函数和`handle`对象的`remove()`函数。在动态图模式下,如果在Cell对象的`construct`函数中调用`register_forward_hook`函数,那么Cell对象每次运行都将新注册一个Hook函数。\n",
    "\n",
    "更多关于Cell对象的`register_forward_hook`功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/r2.2/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_forward_hook)。\n",
    "\n",
    "### Cell对象的register_backward_hook功能\n",
    "\n",
    "用户可以在Cell对象上使用`register_backward_hook`函数来注册一个自定义的Hook函数,用来捕获网络反向传播时与Cell对象相关联的梯度。该功能在图模式下或者在使用`@jit`修饰的函数内不起作用。`register_backward_hook`函数接收Hook函数作为入参,并返回一个与Hook函数一一对应的`handle`对象。用户可以通过调用`handle`对象的`remove()`函数来删除与之对应的Hook函数。每一次调用`register_backward_hook`函数,都会返回一个不同的`handle`对象。\n",
    "\n",
    "与HookBackward算子所使用的自定义Hook函数有所不同,`register_backward_hook`使用的Hook函数的入参中,包含了表示Cell对象名称与id信息的`cell_id`、反向传入到Cell对象的梯度、以及Cell对象的反向输出的梯度。\n",
    "\n",
    "示例代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def backward_hook_function(cell_id, grad_input, grad_output):\n",
    "    print(grad_input)\n",
    "    print(grad_output)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的`cell_id`是Cell对象的名称以及ID信息,`grad_input`是网络反向传播时,传入到Cell对象的梯度,它对应于正向过程中下一个算子的反向输出梯度;`grad_output`是Cell对象反向输出的梯度。因此,用户可以使用`register_backward_hook`函数来捕获网络中某一个Cell对象的反向传入和反向输出梯度。用户可以在Hook函数中自定义对梯度的操作,比如查看、打印梯度,或者返回新的输出梯度。如果需要在Hook函数中返回新的输出梯度时,返回值必须是`tuple`的形式。\n",
    "\n",
    "示例代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-03-02T09:14:26.523389Z",
     "start_time": "2022-03-02T09:14:26.506784Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=\n",
      "[[[[ 1.00000000e+00]],\n",
      "  [[ 1.00000000e+00]]]]),)\n",
      "(Tensor(shape=[1, 2, 1, 1], dtype=Float32, value=\n",
      "[[[[ 9.99994993e-01]],\n",
      "  [[ 9.99994993e-01]]]]),)\n",
      "[[[[1.99999 1.99999]\n",
      "   [1.99999 1.99999]]]]\n",
      "-------------\n",
      " [[[[1.99999 1.99999]\n",
      "   [1.99999 1.99999]]]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def backward_hook_function(cell_id, grad_input, grad_output):\n",
    "    print(grad_input)\n",
    "    print(grad_output)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init=\"ones\", pad_mode=\"valid\")\n",
    "        self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init=\"ones\")\n",
    "        self.handle = self.bn.register_backward_hook(backward_hook_function)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = self.bn(x)\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net)\n",
    "output = grad_net(ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))\n",
    "print(output)\n",
    "net.handle.remove()\n",
    "output = grad_net(ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32)))\n",
    "print(\"-------------\\n\", output)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当 `register_backward_hook` 函数和 `register_forward_pre_hook` 函数、 `register_forward_hook` 函数同时作用于同一Cell对象时,如果 `register_forward_pre_hook` 和 `register_forward_hook` 函数中有添加其他算子进行数据处理,这些新增算子会在Cell对象执行前或者执行后参与数据的正向计算,但是这些新增算子的反向梯度不在 `register_backward_hook` 函数的捕获范围内。 `register_backward_hook` 中注册的Hook函数仅捕获原始Cell对象的输入、输出梯度。\n",
    "\n",
    "示例代码:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "forward inputs:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "forward outputs:  [2.]\n",
      "grad input:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "grad output:  (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)\n",
      "(Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]))\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)\n",
    "\n",
    "def forward_pre_hook_fn(cell_id, inputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    input_x = inputs[0]\n",
    "    return input_x\n",
    "\n",
    "def forward_hook_fn(cell_id, inputs, outputs):\n",
    "    print(\"forward inputs: \", inputs)\n",
    "    print(\"forward outputs: \", outputs)\n",
    "    outputs = outputs + outputs\n",
    "    return outputs\n",
    "\n",
    "def backward_hook_fn(cell_id, grad_input, grad_output):\n",
    "    print(\"grad input: \", grad_input)\n",
    "    print(\"grad output: \", grad_output)\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "        self.handle = self.relu.register_forward_pre_hook(forward_pre_hook_fn)\n",
    "        self.handle2 = self.relu.register_forward_hook(forward_hook_fn)\n",
    "        self.handle3 = self.relu.register_backward_hook(backward_hook_fn)\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x + y\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "net = Net()\n",
    "grad_net = ms.grad(net, grad_position=(0, 1))\n",
    "gradient = grad_net(ms.Tensor(np.ones([1]).astype(np.float32)), ms.Tensor(np.ones([1]).astype(np.float32)))\n",
    "print(gradient)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里的 `grad_input` 是梯度反向传播时传入`self.relu`的梯度,而不是传入 `forward_hook_fn` 函数中,新增的 `Add` 算子的梯度。这里的 `grad_output` 是梯度反向传播时 `self.relu` 反向输出的梯度,而不是 `forward_pre_hook_fn` 函数中新增 `Add` 算子的反向输出梯度。 `register_forward_pre_hook` 函数和 `register_forward_hook` 函数是在Cell对象执行前后起作用,不会影响Cell对象上反向Hook函数的梯度捕获范围。\n",
    "\n",
    "为了避免脚本在切换到图模式时运行失败,不建议在Cell对象的 `construct` 函数中调用 `register_backward_hook` 函数和 `handle` 对象的 `remove()` 函数。在PyNative模式下,如果在Cell对象的 `construct` 函数中调用 `register_backward_hook` 函数,那么Cell对象每次运行都将新注册一个Hook函数。\n",
    "\n",
    "更多关于Cell对象的 `register_backward_hook` 功能的说明可以参考[API文档](https://mindspore.cn/docs/zh-CN/r2.2/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.register_backward_hook)。\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}