{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自定义参数初始化\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.4.10/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.4.10/zh_cn/model_train/custom_program/mindspore_initializer.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.4.10/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.4.10/zh_cn/model_train/custom_program/mindspore_initializer.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.4.10/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.4.10/docs/mindspore/source_zh_cn/model_train/custom_program/initializer.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用内置参数初始化\n",
    "\n",
    "MindSpore提供了多种网络参数初始化的方式,并在部分算子中封装了参数初始化的功能。本节以Conv2d为例,分别介绍如何使用Initializer子类,字符串进行参数初始化。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Initializer初始化\n",
    "\n",
    "`Initializer`是MindSpore内置的参数初始化基类,所有内置参数初始化方法均继承该类。`mindspore.nn`中提供的神经网络层封装均提供`weight_init`、`bias_init`等入参,可以直接使用实例化的Initializer进行参数初始化。样例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-12-29T03:42:22.717822Z",
     "start_time": "2021-12-29T03:42:20.636585Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "import mindspore as ms\n",
    "from mindspore.common.initializer import Normal, initializer\n",
    "\n",
    "input_data = ms.Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))\n",
    "# 卷积层,输入通道为3,输出通道为64,卷积核大小为3*3,权重参数使用正态分布生成的随机数\n",
    "net = nn.Conv2d(3, 64, 3, weight_init=Normal(0.2))\n",
    "# 网络输出\n",
    "output = net(input_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 字符串初始化\n",
    "\n",
    "除使用实例化的Initializer外,MindSpore也提供了参数初始化简易方法,即使用参数初始化方法名称的字符串。此方法使用Initializer的默认参数进行初始化。样例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "import mindspore as ms\n",
    "\n",
    "net = nn.Conv2d(3, 64, 3, weight_init='normal')\n",
    "output = net(input_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义参数初始化\n",
    "\n",
    "通常情况下,MindSpore提供的默认参数初始化可以满足常用神经网络层的初始化需求,在遇到需要自定义的参数初始化方法时,可以继承Initializer自定义参数初始化方法。下面以`XavierNormal`为例介绍自定义参数初始化方法:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-12-29T03:42:22.729232Z",
     "start_time": "2021-12-29T03:42:22.723517Z"
    }
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "from mindspore.common.initializer import Initializer\n",
    "\n",
    "\n",
    "def _calculate_fan_in_and_fan_out(arr):\n",
    "    # 计算fan_in和fan_out。fan_in是 `arr` 中输入单元的数量,fan_out是 `arr` 中输出单元的数量。\n",
    "    shape = arr.shape\n",
    "    dimensions = len(shape)\n",
    "    if dimensions < 2:\n",
    "        raise ValueError(\"'fan_in' and 'fan_out' can not be computed for arr with fewer than\"\n",
    "                         \" 2 dimensions, but got dimensions {}.\".format(dimensions))\n",
    "    if dimensions == 2:  # Linear\n",
    "        fan_in = shape[1]\n",
    "        fan_out = shape[0]\n",
    "    else:\n",
    "        num_input_fmaps = shape[1]\n",
    "        num_output_fmaps = shape[0]\n",
    "        receptive_field_size = 1\n",
    "        for i in range(2, dimensions):\n",
    "            receptive_field_size *= shape[i]\n",
    "        fan_in = num_input_fmaps * receptive_field_size\n",
    "        fan_out = num_output_fmaps * receptive_field_size\n",
    "    return fan_in, fan_out\n",
    "\n",
    "\n",
    "class XavierNormal(Initializer):\n",
    "    def __init__(self, gain=1):\n",
    "        super().__init__()\n",
    "        # 配置初始化所需要的参数\n",
    "        self.gain = gain\n",
    "\n",
    "    def _initialize(self, arr): # arr为需要初始化的Tensor\n",
    "        fan_in, fan_out = _calculate_fan_in_and_fan_out(arr) # 计算fan_in, fan_out值\n",
    "\n",
    "        std = self.gain * math.sqrt(2.0 / float(fan_in + fan_out)) # 根据公式计算std值\n",
    "        data = np.random.normal(0, std, arr.shape) # 使用numpy构造初始化好的ndarray\n",
    "\n",
    "        arr[:] = data[:] # 将初始化好的ndarray赋值到arr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完成自定义初始化方法后,我们可以像内置初始化方法一样进行调用:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Conv2d(3, 64, 3, weight_init=XavierNormal())\n",
    "# 网络输出\n",
    "output = net(input_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cell遍历初始化\n",
    "\n",
    "除了使用`weight_init`, `bias_init`等`mindspore.nn`接口提供的入参外,我们也习惯于先构造完整神经网络,然后对`weight`、`bias`等参数进行统一管理。此时需要先构造网络并实例化,然后对Cell进行遍历,并对参数进行赋值。下面是一个简单的样例:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, param in net.parameters_and_names():\n",
    "    if 'weight' in name:\n",
    "        param.set_data(initializer(Normal(), param.shape, param.dtype))\n",
    "    if 'bias' in name:\n",
    "        param.set_data(initializer('zeros', param.shape, param.dtype))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}