{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Cell\n",
    "\n",
    "`Ascend` `GPU` `CPU` `入门`\n",
    "\n",
    "[![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tYXN0ZXIvcHJvZ3JhbW1pbmdfZ3VpZGUvemhfY24vbWluZHNwb3JlX2NlbGwuaXB5bmI=&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.5/programming_guide/zh_cn/mindspore_cell.ipynb) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/r1.5/programming_guide/zh_cn/mindspore_cell.py) [![](https://gitee.com/mindspore/docs/raw/r1.5/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.5/docs/mindspore/programming_guide/source_zh_cn/cell.ipynb)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 概述\n",
    "\n",
    "MindSpore的`Cell`类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,需要继承`Cell`类,并重写`__init__`方法和`construct`方法。\n",
    "\n",
    "损失函数、优化器和模型层等本质上也属于网络结构,也需要继承`Cell`类才能实现功能,同样用户也可以根据业务需求自定义这部分内容。\n",
    "\n",
    "本节内容介绍`Cell`类的关键成员函数,“网络构建”中将介绍基于`Cell`实现的MindSpore内置损失函数、优化器和模型层及使用方法,以及通过实例介绍如何利用`Cell`类构建自定义网络。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 关键成员函数\n",
    "\n",
    "### construct方法\n",
    "\n",
    "`Cell`类重写了`__call__`方法,在`Cell`类的实例被调用时,会执行`construct`方法。网络结构在`construct`方法里面定义。\n",
    "\n",
    "下面的样例中,我们构建了一个简单的网络实现卷积计算功能。构成网络的算子在`__init__`中定义,在`construct`方法里面使用,用例的网络结构为`Conv2d` -> `BiasAdd`。\n",
    "\n",
    "在`construct`方法中,`x`为输入数据,`output`是经过网络结构计算后得到的计算结果。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "from mindspore import Parameter\n",
    "from mindspore.common.initializer import initializer\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self, in_channels=10, out_channels=20, kernel_size=3):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv2d = ops.Conv2D(out_channels, kernel_size)\n",
    "        self.bias_add = ops.BiasAdd()\n",
    "        self.weight = Parameter(\n",
    "            initializer('normal', [out_channels, in_channels, kernel_size, kernel_size]),\n",
    "            name='conv.weight')\n",
    "        self.bias = Parameter(initializer('normal', [out_channels]), name='conv.bias')\n",
    "\n",
    "    def construct(self, x):\n",
    "        output = self.conv2d(x, self.weight)\n",
    "        output = self.bias_add(output, self.bias)\n",
    "        return output"
   ],
   "outputs": [],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-02-08T01:01:31.855049Z",
     "start_time": "2021-02-08T01:01:31.084345Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### parameters_dict\n",
    "\n",
    "`parameters_dict`方法识别出网络结构中所有的参数,返回一个以key为参数名,value为参数值的`OrderedDict`。\n",
    "\n",
    "`Cell`类中返回参数的方法还有许多,例如`get_parameters`、`trainable_params`等,具体使用方法可以参见[API文档](https://www.mindspore.cn/docs/api/zh-CN/r1.5/api_python/nn/mindspore.nn.Cell.html)。\n",
    "\n",
    "代码样例如下:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "net = Net()\n",
    "result = net.parameters_dict()\n",
    "print(result.keys())\n",
    "print(result['conv.weight'])"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "odict_keys(['conv.weight'])\n",
      "Parameter (name=conv.weight, shape=(20, 10, 3, 3), dtype=Float32, requires_grad=True)\n"
     ]
    }
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-02-08T01:01:31.867924Z",
     "start_time": "2021-02-08T01:01:31.856066Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "样例中的`Net`采用上文构造网络的用例,打印了网络中所有参数的名字和`weight`参数的结果。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### cells_and_names\n",
    "\n",
    "`cells_and_names`方法是一个迭代器,返回网络中每个`Cell`的名字和它的内容本身。\n",
    "\n",
    "用例简单实现了获取与打印每个`Cell`名字的功能,其中根据网络结构可知,存在1个`Cell`为`nn.Conv2d`。\n",
    "\n",
    "其中`nn.Conv2d`是`MindSpore`以Cell为基类封装好的一个卷积层,其具体内容将在“模型层”中进行介绍。\n",
    "\n",
    "代码样例如下:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "import mindspore.nn as nn\n",
    "\n",
    "class Net1(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net1, self).__init__()\n",
    "        self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')\n",
    "\n",
    "    def construct(self, x):\n",
    "        out = self.conv(x)\n",
    "        return out\n",
    "\n",
    "net = Net1()\n",
    "names = []\n",
    "for m in net.cells_and_names():\n",
    "    print(m)\n",
    "    names.append(m[0]) if m[0] else None\n",
    "print('-------names-------')\n",
    "print(names)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "('', Net1<\n",
      "  (conv): Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>\n",
      "  >)\n",
      "('conv', Conv2d<input_channels=3, output_channels=64, kernel_size=(3, 3), stride=(1, 1), pad_mode=same, padding=0, dilation=(1, 1), group=1, has_bias=False, weight_init=normal, bias_init=zeros, format=NCHW>)\n",
      "-------names-------\n",
      "['conv']\n"
     ]
    }
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-02-08T01:01:31.893191Z",
     "start_time": "2021-02-08T01:01:31.870508Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### set_grad\n",
    "\n",
    "`set_grad`用于指定网络是否需要计算反向。在不传入参数调用时,默认设置`requires_grad`为True,在执行前向网络时将会构建用于计算梯度的反向网络。\n",
    "\n",
    "以`TrainOneStepCell`为例,其接口功能是使网络进行单步训练,需要计算网络反向,因此初始化方法里需要使用`set_grad`。\n",
    "\n",
    "`TrainOneStepCell`部分代码如下:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "```python\n",
    "class TrainOneStepCell(Cell):\n",
    "    def __init__(self, network, optimizer, sens=1.0):\n",
    "        super(TrainOneStepCell, self).__init__(auto_prefix=False)\n",
    "        self.network = network\n",
    "        self.network.set_grad()\n",
    "        ......\n",
    "```"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "如果用户使用`TrainOneStepCell`、`GradOperation`等类似接口,无需使用`set_grad`,内部已封装实现。\n",
    "\n",
    "若用户需要自定义此类训练功能的接口,需要在其内部调用,或者在外部设置`network.set_grad`。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### set_train\n",
    "\n",
    "`set_train`接口递归地配置了当前`Cell`和所有子`Cell`的training属性,在不传入参数调用时,默认设置的training属性为True。\n",
    "\n",
    "在实现训练和推理结构不同的网络时可以通过`training`属性区分训练和推理场景,在网络运行时结合`set_train`可以切换网络的执行逻辑。\n",
    "\n",
    "例如`nn.Dropout`部分代码如下:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "```py\n",
    "class Dropout(Cell):\n",
    "    def __init__(self, keep_prob=0.5, dtype=mstype.float32):\n",
    "        \"\"\"Initialize Dropout.\"\"\"\n",
    "        super(Dropout, self).__init__()\n",
    "        self.dropout = ops.Dropout(keep_prob, seed0, seed1)\n",
    "        ......\n",
    "\n",
    "    def construct(self, x):\n",
    "        if not self.training:\n",
    "            return x\n",
    "\n",
    "        if self.keep_prob == 1:\n",
    "            return x\n",
    "\n",
    "        out, _ = self.dropout(x)\n",
    "        return out\n",
    "```"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "`nn.Dropout`中根据`Cell`的training属性区分了两种执行逻辑,training为False时直接返回输入,training为True时执行`Dropout`算子。因此在定义网络时需要根据训练和推理场景设置网络的执行模式,以`nn.Dropout`为例:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import mindspore.nn as nn\n",
    "net = nn.Dropout()\n",
    "# 执行训练\n",
    "net.set_train()\n",
    "# 执行推理\n",
    "net.set_train(False)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### to_float\n",
    "\n",
    "`to_float`接口递归地在配置了当前`Cell`和所有子`Cell`的强制转换类型,以使当前网络结构以使用特定的float类型运行。通常在混合精度场景使用。\n",
    "\n",
    "`to_float`与混合精度详见<https://www.mindspore.cn/docs/programming_guide/zh-CN/r1.5/enable_mixed_precision.html>。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## nn模块与ops模块的关系\n",
    "\n",
    "MindSpore的nn模块是Python实现的模型组件,是对低阶API的封装,主要包括各种模型层、损失函数、优化器等。\n",
    "\n",
    "同时nn也提供了部分与`Primitive`算子同名的接口,主要作用是对`Primitive`算子进行进一步封装,为用户提供更友好的API。\n",
    "\n",
    "重新分析上文介绍`construct`方法的用例,此用例是MindSpore的`nn.Conv2d`源码简化内容,内部会调用`ops.Conv2D`。`nn.Conv2d`卷积API增加输入参数校验功能并判断是否`bias`等,是一个高级封装的模型层。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "import mindspore.nn as nn\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv = nn.Conv2d(10, 20, 3, has_bias=True, weight_init='normal')\n",
    "\n",
    "    def construct(self, x):\n",
    "        out = self.conv(x)\n",
    "        return out\n"
   ],
   "outputs": [],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-02-08T01:01:31.916550Z",
     "start_time": "2021-02-08T01:01:31.894206Z"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}