{
    "cells": [
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "# 自定义神经网络层\n",
       "\n",
       "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.5.0/zh_cn/model_train/custom_program/mindspore_network_custom.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.5.0/zh_cn/model_train/custom_program/mindspore_network_custom.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.5.0/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindspore/source_zh_cn/model_train/custom_program/network_custom.ipynb)"
      ]
     },
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "通常情况下,MindSpore提供的神经网络层接口和function函数接口能够满足模型构造需求,但由于AI领域不断推陈出新,因此有可能遇到新网络结构没有内置模块的情况。此时我们可以根据需要,通过MindSpore提供的function接口、Primitive算子自定义神经网络层,并可以使用`Cell.bprop`方法自定义反向。下面分别详述三种自定义方法。\n",
       "\n",
       "## 使用function接口构造神经网络层\n",
       "\n",
       "MindSpore提供大量基础的function接口,可以使用其构造复杂的Tensor操作,封装为神经网络层。下面以`Threshold`为例,其公式如下:\n",
       "\n",
       "$$\n",
       "y =\\begin{cases}\n",
       "   x, &\\text{ if } x > \\text{threshold} \\\\\n",
       "   \\text{value}, &\\text{ otherwise }\n",
       "   \\end{cases}\n",
       "$$\n",
       "\n",
       "可以看到`Threshold`判断Tensor的值是否大于`threshold`值,保留判断结果为`True`的值,替换判断结果为`False`的值。因此,对应实现如下:"
      ]
     },
     {
      "cell_type": "code",
      "execution_count": 43,
      "metadata": {},
      "outputs": [],
      "source": [
       "import mindspore\n",
       "import numpy as np\n",
       "from mindspore import nn, ops, Tensor, Parameter\n",
       "class Threshold(nn.Cell):\n",
       "    def __init__(self, threshold, value):\n",
       "        super().__init__()\n",
       "        self.threshold = threshold\n",
       "        self.value = value\n",
       "\n",
       "    def construct(self, inputs):\n",
       "        cond = ops.gt(inputs, self.threshold)\n",
       "        value = ops.fill(inputs.dtype, inputs.shape, self.value)\n",
       "        return ops.select(cond, inputs, value)"
      ]
     },
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "这里分别使用了`ops.gt`、`ops.fill`、`ops.select`来实现判断和替换。下面执行自定义的`Threshold`层:"
      ]
     },
     {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {},
      "outputs": [
       {
        "data": {
         "text/plain": [
          "Tensor(shape=[3], dtype=Float32, value= [ 2.00000000e+01,  2.00000003e-01,  3.00000012e-01])"
         ]
        },
        "execution_count": 45,
        "metadata": {},
        "output_type": "execute_result"
       }
      ],
      "source": [
       "m = Threshold(0.1, 20)\n",
       "inputs = mindspore.Tensor([0.1, 0.2, 0.3], mindspore.float32)\n",
       "m(inputs)"
      ]
     },
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "可以看到`inputs[0] = threshold`, 因此被替换为`20`。"
      ]
     },
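     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "As a cross-check of the formula, here is a minimal NumPy sketch of the same element-wise rule. It assumes NumPy is installed; `threshold_ref` is a hypothetical helper for illustration, not a MindSpore API. `np.where` plays the role of `ops.select`, and the strict comparison `>` corresponds to `ops.gt`.\n",
       "\n",
       "```python\n",
       "import numpy as np\n",
       "\n",
       "\n",
       "def threshold_ref(x, threshold, value):\n",
       "    # Hypothetical NumPy reference: keep x where x > threshold, else use value\n",
       "    x = np.asarray(x, dtype=np.float32)\n",
       "    return np.where(x > np.float32(threshold), x, np.float32(value))\n",
       "\n",
       "\n",
       "out = threshold_ref([0.1, 0.2, 0.3], 0.1, 20)\n",
       "print(out)\n",
       "```\n",
       "\n",
       "Because the comparison is strict, an element exactly equal to `threshold` is replaced, which agrees with the MindSpore result above."
      ]
     },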
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "## 自定义Cell反向\n",
       "\n",
       "在特殊场景下,我们不但需要自定义神经网络层的正向逻辑,也需要手动控制其反向的计算,此时我们可以通过`Cell.bprop`接口对其反向进行定义。在全新的神经网络结构设计、反向传播速度优化等场景下会用到该功能。下面我们以`Dropout2d`为例,介绍如何自定义Cell反向:"
      ]
     },
     {
      "cell_type": "code",
      "execution_count": 55,
      "metadata": {},
      "outputs": [],
      "source": [
       "class Dropout2d(nn.Cell):\n",
       "    def __init__(self, keep_prob):\n",
       "        super().__init__()\n",
       "        self.keep_prob = keep_prob\n",
       "        self.dropout2d = ops.Dropout2D(keep_prob)\n",
       "\n",
       "    def construct(self, x):\n",
       "        return self.dropout2d(x)\n",
       "\n",
       "    def bprop(self, x, out, dout):\n",
       "        _, mask = out\n",
       "        dy, _ = dout\n",
       "        if self.keep_prob != 0:\n",
       "            dy = dy * (1 / self.keep_prob)\n",
       "        dy = mask.astype(mindspore.float32) * dy\n",
       "        return (dy.astype(x.dtype), )\n",
       "\n",
       "dropout_2d = Dropout2d(0.8)\n",
       "dropout_2d.bprop_debug = True"
      ]
     },
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "`bprop`方法分别有三个入参:\n",
       "\n",
       "- *x*: 正向输入,当正向输入为多个时,需同样数量的入参。\n",
       "- *out*: 正向输出。\n",
       "- *dout*: 反向传播时,当前Cell执行之前的反向结果。\n",
       "\n",
       "一般我们需要根据正向输出和前层反向结果配合,根据反向求导公式计算反向结果,并将其返回。`Dropout2d`的反向计算需要根据正向输出的`mask`矩阵对前层反向结果进行mask,然后根据`keep_prob`进行缩放。最终可得到正确的计算结果。"
      ]
     },
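     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "The masking and scaling rule can be checked with a minimal NumPy sketch. It assumes channel-wise inverted dropout, where each (sample, channel) slice of an `(N, C, H, W)` input is either zeroed or scaled by `1 / keep_prob` in the forward pass, as the `bprop` above implies. `dropout2d_forward` and `dropout2d_backward` are hypothetical helpers, not MindSpore APIs, and the random mask only stands in for the one `ops.Dropout2D` generates.\n",
       "\n",
       "```python\n",
       "import numpy as np\n",
       "\n",
       "\n",
       "def dropout2d_forward(x, keep_prob, rng):\n",
       "    # Hypothetical sketch: one keep/drop decision per (sample, channel) pair\n",
       "    n, c = x.shape[:2]\n",
       "    keep = rng.random((n, c, 1, 1)) < keep_prob\n",
       "    mask = np.broadcast_to(keep, x.shape)\n",
       "    # Inverted dropout: kept channels are scaled by 1 / keep_prob\n",
       "    return x * mask / keep_prob, mask\n",
       "\n",
       "\n",
       "def dropout2d_backward(mask, dout, keep_prob):\n",
       "    # Same rule as Dropout2d.bprop: scale by 1 / keep_prob, then apply the mask\n",
       "    dy = dout * (1 / keep_prob) if keep_prob != 0 else dout\n",
       "    return mask.astype(np.float32) * dy\n",
       "\n",
       "\n",
       "rng = np.random.default_rng(0)\n",
       "x = rng.standard_normal((2, 3, 4, 4)).astype(np.float32)\n",
       "out, mask = dropout2d_forward(x, 0.8, rng)\n",
       "dx = dropout2d_backward(mask, np.ones_like(out), 0.8)\n",
       "print(dx[:, :, 0, 0])\n",
       "```\n",
       "\n",
       "For a fixed mask the forward output is linear in `x`, so with an all-ones `dout` the gradient is simply `mask / keep_prob`: `1.25` for kept channels and `0` for dropped ones when `keep_prob = 0.8`."
      ]
     },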
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "自定义Cell反向时,在PyNative模式下支持拓展写法,可以对Cell内部的权重求导,具体列子如下:"
      ]
     },
     {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
       "class NetWithParam(nn.Cell):\n",
       "    def __init__(self):\n",
       "        super(NetWithParam, self).__init__()\n",
       "        self.w = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name='weight')\n",
       "        self.internal_params = [self.w]\n",
       "\n",
       "    def construct(self, x):\n",
       "        output = self.w * x\n",
       "        return output\n",
       "\n",
       "    def bprop(self, *args):\n",
       "        return (self.w * args[-1],), {self.w: args[0] * args[-1]}"
      ]
     },
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "`bprop`方法支持*args入参,args数组中最后一位`args[-1]`为返回给该cell的梯度。通过`self.internal_params`设置求导的权重,同时在`bprop`函数的返回值为一个元组和一个字典,返回输入对应梯度的元组,以及以key为权重,value为权重对应梯度的字典。"
      ]
     }
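     ,
     {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
       "The gradients returned by this `bprop` follow from the partial derivatives of `output = w * x`: the input gradient is `w * dout` and the weight gradient is `x * dout`. The minimal NumPy sketch below checks both against central finite differences of the scalar loss `sum(w * x * dout)`. `forward`, `bprop_ref`, and `loss` are hypothetical helpers, and `x` is assumed to have the same one-element shape as `w`; for larger inputs the weight gradient would generally have to be summed back to the shape of `w`.\n",
       "\n",
       "```python\n",
       "import numpy as np\n",
       "\n",
       "\n",
       "def forward(w, x):\n",
       "    # Same computation as NetWithParam.construct: output = w * x\n",
       "    return w * x\n",
       "\n",
       "\n",
       "def bprop_ref(w, x, dout):\n",
       "    # Mirrors NetWithParam.bprop: (input gradient, weight gradient)\n",
       "    return w * dout, x * dout\n",
       "\n",
       "\n",
       "def loss(w, x, dout):\n",
       "    # Scalar loss whose gradient w.r.t. the output is exactly dout\n",
       "    return np.sum(forward(w, x) * dout)\n",
       "\n",
       "\n",
       "w = np.array([2.0])\n",
       "x = np.array([3.0])\n",
       "dout = np.array([1.5])\n",
       "dx, dw = bprop_ref(w, x, dout)\n",
       "\n",
       "# Central finite differences as an independent check\n",
       "eps = 1e-6\n",
       "dx_num = (loss(w, x + eps, dout) - loss(w, x - eps, dout)) / (2 * eps)\n",
       "dw_num = (loss(w + eps, x, dout) - loss(w - eps, x, dout)) / (2 * eps)\n",
       "print(dx, dx_num)\n",
       "print(dw, dw_num)\n",
       "```"
      ]
     }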
    ],
    "metadata": {
     "kernelspec": {
      "display_name": "MindSpore",
      "language": "python",
      "name": "mindspore"
     },
     "language_info": {
      "codemirror_mode": {
       "name": "ipython",
       "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.5"
     }
    },
    "nbformat": 4,
    "nbformat_minor": 4
   }