{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 动静态图结合\n",
    "\n",
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.0rc2/zh_cn/design/mindspore_dynamic_graph_and_static_graph.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.0rc2/zh_cn/design/mindspore_dynamic_graph_and_static_graph.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.3.0rc2/docs/mindspore/source_zh_cn/design/dynamic_graph_and_static_graph.ipynb)\n",
    "\n",
    "## 静态图和动态图的概念\n",
    "\n",
    "目前主流的深度学习框架的执行模式有两种,分别为静态图模式和动态图模式。\n",
    "\n",
    "静态图模式下,程序在编译执行时先生成神经网络的图结构,然后再执行图中涉及的计算操作。因此,在静态图模式下,编译器利用图优化等技术对执行图进行更大程度的优化,从而获得更好的执行性能,有助于规模部署和跨平台运行。\n",
    "\n",
    "动态图模式下,程序按照代码的编写顺序执行,在执行正向过程中根据反向传播的原理,动态生成反向执行图。这种模式下,编译器将神经网络中的各个算子逐一下发执行,方便用户编写和调试神经网络模型。\n",
    "\n",
    "## MindSpore静态图\n",
    "\n",
    "在MindSpore中,静态图模式又被称为Graph模式,可以通过`set_context(mode=GRAPH_MODE)`来设置成静态图模式。静态图模式比较适合网络固定且需要高性能的场景。在静态图模式下,基于图优化、计算图整图下沉等技术,编译器可以针对图进行全局的优化,因此在静态图下能获得较好的性能,但是执行图是从源码转换而来,因此在静态图下不是所有的Python语法都能支持。\n",
    "\n",
    "### Graph模式执行原理\n",
    "\n",
    "在Graph模式下,MindSpore通过源码转换的方式,将Python的源码转换成IR,再在此基础上进行相关的图优化,最终在硬件设备上执行优化后的图。MindSpore使用的是一种基于图表示的函数式IR,即MindIR,采用了接近于ANF函数式的语义。Graph模式是基于MindIR进行编译优化的,使用Graph模式时,需要使用`nn.Cell`类并且在`construct`函数中编写执行代码,或者调用`@jit`装饰器。\n",
    "\n",
    "Graph模式的代码用例如下所示:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "import mindspore as ms\n",
    "\n",
    "ms.set_context(mode=ms.GRAPH_MODE, device_target=\"CPU\")\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.mul = ops.Mul()\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        return self.mul(x, y)\n",
    "\n",
    "x = ms.Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))\n",
    "y = ms.Tensor(np.array([4.0, 5.0, 6.0]).astype(np.float32))\n",
    "\n",
    "net = Net()\n",
    "print(net(x, y))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[ 4. 10. 18.]\n"
     ]
    }
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-01-04T10:51:23.282871Z",
     "start_time": "2022-01-04T10:51:21.743620Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Graph模式自动微分原理\n",
    "\n",
    "在MindSpore中,Graph模式下的自动微分原理可以参考[自动微分](https://www.mindspore.cn/tutorials/zh-CN/r2.3.0rc2/beginner/autograd.html)。\n",
    "\n",
    "## MindSpore动态图\n",
    "\n",
    "在MindSpore中,动态图模式又被称为PyNative模式,可以通过`set_context(mode=PYNATIVE_MODE)`来设置成动态图模式。在脚本开发和网络流程调试中,推荐使用动态图模式进行调试,支持执行单算子、普通函数和网络、以及单独求梯度的操作。\n",
    "\n",
    "### PyNative模式执行原理\n",
    "\n",
    "在PyNative模式下,用户可以使用完整的Python API,此外针对使用MindSpore提供的API时,框架会根据用户选择的硬件平台(Ascend,GPU,CPU),将算子API的操作在对应的硬件平台上执行,并返回相应的结果。框架整体的执行过程如下:\n",
    "\n",
    "![process](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/docs/mindspore/source_zh_cn/design/images/framework.png)\n",
    "\n",
    "通过前端的Python API,调用到框架层,最终到相应的硬件设备上进行计算。例如:完成一个加法。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "import mindspore as ms\n",
    "import mindspore.ops as ops\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"CPU\")\n",
    "x = ms.Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))\n",
    "y = ms.Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))\n",
    "output = ops.add(x, y)\n",
    "print(output.asnumpy())"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[[[[2. 2. 2. 2.]\n",
      "   [2. 2. 2. 2.]\n",
      "   [2. 2. 2. 2.]]\n",
      "\n",
      "  [[2. 2. 2. 2.]\n",
      "   [2. 2. 2. 2.]\n",
      "   [2. 2. 2. 2.]]\n",
      "\n",
      "  [[2. 2. 2. 2.]\n",
      "   [2. 2. 2. 2.]\n",
      "   [2. 2. 2. 2.]]]]\n"
     ]
    }
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-01-04T10:51:23.292278Z",
     "start_time": "2022-01-04T10:51:23.284465Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "此例中,当调用到Python接口ops.add(x, y)时,会将Python的接口调用通过Pybind11调用到框架的C++层,转换成C++的调用,接着框架会根据用户设置的device_target选择对应的硬件设备,在该硬件设备上执行add这个操作。\n",
    "\n",
    "从上述原理可以看到,在PyNative模式下,Python脚本代码会根据Python的语法进行执行,而执行过程中涉及到MindSpore的API,会根据用户设置在不同的硬件上进行执行,从而进行加速。因此,在PyNative模式下,用户可以随意使用Python的语法以及调试方法。例如可以使用常见的PyCharm、VS Code等IDE进行代码的调试。\n",
    "\n",
    "### PyNative模式自动微分原理\n",
    "\n",
    "在前面的介绍中,我们可以看出,在PyNative下执行正向过程完全是按照Python的语法进行执行。在PyNative下是基于Tensor进行实现反向传播的,我们在执行正向过程中,将所有应用于Tensor的操作记录下来,并针对每个操作求取其反向,并将所有反向过程串联起来形成整体反向传播图(简称反向图)。最终,将反向图在设备上进行执行计算出梯度。\n",
    "\n",
    "反向构图过程示例,如下代码,对矩阵x乘上固定参数z,然后与y进行矩阵乘法,最终对x进行求导。"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "import numpy as np\n",
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "import mindspore as ms\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"CPU\")\n",
    "\n",
    "class Net(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.matmul = ops.MatMul()\n",
    "        self.z = ms.Parameter(ms.Tensor(np.array([2.0], np.float32)), name='z')\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        x = x * self.z\n",
    "        out = self.matmul(x, y)\n",
    "        return out\n",
    "\n",
    "class GradNetWrtX(nn.Cell):\n",
    "    def __init__(self, net):\n",
    "        super(GradNetWrtX, self).__init__()\n",
    "        self.net = net\n",
    "\n",
    "    def construct(self, x, y):\n",
    "        gradient_function = ms.grad(self.net)\n",
    "        return gradient_function(x, y)\n",
    "\n",
    "x = ms.Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=ms.float32)\n",
    "y = ms.Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=ms.float32)\n",
    "output = GradNetWrtX(Net())(x, y)\n",
    "print(output)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[[9.02      5.4       7.2000003]\n",
      " [9.02      5.4       7.2000003]]\n"
     ]
    }
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-01-04T10:51:23.439867Z",
     "start_time": "2022-01-04T10:51:23.293334Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "![forward](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/docs/mindspore/source_zh_cn/design/images/forward.png) ![backward](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/docs/mindspore/source_zh_cn/design/images/backward.png)\n",
    "\n",
    "根据上述PyNative下构图原理,我们可以看到,在正向传播过程中,我们记录了Mul的计算过程,根据Mul对应的反向bprop的定义,得到了反向的MulGrad算子,根据Mul算子的bprop定义,如下:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "from mindspore.ops._grad.grad_base import bprop_getters\n",
    "\n",
    "@bprop_getters.register(ops.Mul)\n",
    "def get_bprop_mul(self):\n",
    "    \"\"\"Grad definition for `Mul` operation.\"\"\"\n",
    "    mul_func = P.Mul()\n",
    "\n",
    "    def bprop(x, y, out, dout):\n",
    "        bc_dx = mul_func(y, dout)\n",
    "        bc_dy = mul_func(x, dout)\n",
    "        return binop_grad_common(x, y, bc_dx, bc_dy)\n",
    "\n",
    "    return bprop"
   ],
   "outputs": [],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-01-04T10:51:23.445921Z",
     "start_time": "2022-01-04T10:51:23.441876Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "可以看到对Mul的输入求反向,需要两个输入和输出的反向传播梯度值,此时根据实际的输入值,可以将z连接到MulGrad。以此类推,对下一个算子Matmul,相应的得到MatmulGrad信息,再根据bprop的输入输出,将上下文梯度传播连接起来。\n",
    "\n",
    "同理对于输入y求导,可以使用同样的过程进行推导。\n",
    "\n",
    "### PyNative模式下的控制流\n",
    "\n",
    "在PyNative模式下,脚本按照Python的语法执行,因此在MindSpore中,针对控制流语法并没有做特殊处理,直接按照Python的语法直接展开执行,进而对展开的执行算子进行自动微分的操作。例如,对于for循环,在PyNative下会根据具体的循环次数,不断的执行for循环中的语句,并对其算子进行自动微分的操作。\n",
    "\n",
    "## 动静统一\n",
    "\n",
    "### 概述\n",
    "\n",
    "当前在业界支持动态图和静态图两种模式,动态图通过解释执行,具有动态语法亲和性,表达灵活;静态图使用jit编译优化执行,偏静态语法,在语法上有较多限制。动态图和静态图的编译流程不一致,语法约束不一致。MindSpore针对动态图和静态图模式,首先统一API表达,在两种模式下使用相同的API;其次统一动态图和静态图的底层微分机制。\n",
    "\n",
    "![dynamic](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3.0rc2/docs/mindspore/source_zh_cn/design/images/dynamic.png)\n",
    "\n",
    "### 动态图和静态图互相转换\n",
    "\n",
    "在MindSpore中,我们可以通过控制模式输入参数来切换执行使用动态图还是静态图。例如:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "ms.set_context(mode=ms.PYNATIVE_MODE)"
   ],
   "outputs": [],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-01-04T10:51:23.461198Z",
     "start_time": "2022-01-04T10:51:23.447508Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "由于在静态图下,对于Python语法有所限制,因此从动态图切换成静态图时,需要符合静态图的语法限制,才能正确使用静态图来进行执行。更多静态图的语法限制可以参考[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)。\n",
    "\n",
    "### 动静结合\n",
    "\n",
    "MindSpore支持在动态图下使用静态编译的方式来进行混合执行,通过使用jit修饰需要用静态图来执行的函数对象,即可实现动态图和静态图的混合执行,更多jit的使用可参考[jit文档](https://www.mindspore.cn/tutorials/zh-CN/r2.3.0rc2/beginner/accelerate_with_static_graph.html#基于装饰器的开启方式)。\n",
    "\n",
    "例如:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "\n",
    "class AddMulMul(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(AddMulMul, self).__init__()\n",
    "        self.param = ms.Parameter(ms.Tensor(0.5, ms.float32))\n",
    "\n",
    "    @ms.jit\n",
    "    def construct(self, x):\n",
    "        x = x + x\n",
    "        x = x * self.param\n",
    "        x = x * x\n",
    "        return x\n",
    "\n",
    "class CellCallSingleCell(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super(CellCallSingleCell, self).__init__()\n",
    "        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init=\"ones\", pad_mode=\"valid\")\n",
    "        self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init=\"ones\")\n",
    "        self.relu = nn.ReLU()\n",
    "        self.add_mul_mul = AddMulMul()\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = self.bn(x)\n",
    "        x = self.add_mul_mul(x)\n",
    "        x = self.relu(x)\n",
    "        return x\n",
    "\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"CPU\")\n",
    "inputs = ms.Tensor(np.ones([1, 1, 2, 2]).astype(np.float32))\n",
    "net = CellCallSingleCell()\n",
    "out = net(inputs)\n",
    "print(out)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[[[[15.99984]]\n",
      "\n",
      "  [[15.99984]]]]\n"
     ]
    }
   ],
   "metadata": {
    "ExecuteTime": {
     "end_time": "2022-01-04T10:51:23.514919Z",
     "start_time": "2022-01-04T10:51:23.462207Z"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 静态图语法增强技术\n",
    "\n",
    "在MindSpore静态图模式下,用户编写程序时需要遵循MindSpore[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html),语法使用存在约束限制。而在动态图模式下,Python脚本代码会根据Python语法进行执行,用户可以使用任意Python语法。可以看出,静态图和动态图的语法约束限制是不同的。\n",
    "\n",
    "JIT Fallback是从静态图的角度出发考虑静态图和动态图的统一。在编译过程中发现不支持的语法时,将该语法Fallback到Python解释器进行解释执行。通过JIT Fallback特性,静态图可以支持尽量多的动态图语法,使得静态图提供接近动态图的语法使用体验,从而实现动静统一。\n",
    "\n",
    "在图模式场景下,MindSpore框架在图编译过程中遇到不支持语法或不支持符号时会报错,多数是在类型推导阶段。图编译阶段首先对用户编写的Python源码进行解析,然后执行后续的静态分析、类型推导、优化等步骤。因此,JIT Fallback特性需要预先检测不支持语法。常见的不支持语法主要包括:调用第三方库的方法、调用类名来创建对象、调用未支持的Python内置函数等。对不支持的语法Fallback到Python解释器进行解释执行。由于图模式采用[中间表示MindIR](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/design/all_scenarios.html#中间表示mindir),需要将解释执行的语句转换到中间表示,记录下解释器需要的信息。\n",
    "\n",
    "下面主要介绍使用JIT Fallback扩展支持的静态图语法。JIT语法支持级别选项jit_syntax_level的默认值为`LAX`,即使用JIT Fallback的能力扩展静态图语法。\n",
    "\n",
    "#### 调用第三方库\n",
    "\n",
    "完善支持NumPy、SciPy等第三方库。静态图模式支持np.ndarray等众多第三方库数据类型及其运算操作,支持获取调用第三方库的属性和方法,并支持通过Tensor的asnumpy()等方法与NumPy等三方库进行交互处理。也就是说,用户可以在静态图模式下调用MindSpore自身接口和算子,或者直接调用三方库的接口,也可以把它们融合在一起使用。\n",
    "\n",
    "- 支持第三方库(如NumPy、SciPy等)的数据类型,允许调用和返回第三方库的对象。\n",
    "- 支持调用第三方库的方法。\n",
    "- 支持使用NumPy第三方库数据类型创建Tensor对象。\n",
    "- 暂不支持对第三方库数据类型的下标索引赋值。\n",
    "\n",
    "更多使用可见[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)中的[调用第三方库](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html#调用第三方库)。\n",
    "\n",
    "#### 支持自定义类的使用\n",
    "\n",
    "用户自定义的类,没有使用`@jit_class`修饰,也没有继承`nn.Cell`。通过JIT Fallback技术方案,静态图模式下允许创建和引用自定义类的实例,可以支持直接获取和调用自定义类实例的属性和方法,并且允许修改属性(Inplace操作)。\n",
    "\n",
    "更多使用可见[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)中的[支持自定义类的使用](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html#支持自定义类的使用)。\n",
    "\n",
    "#### 基础运算符支持更多数据类型\n",
    "\n",
    "在静态图语法重载了以下运算符: ['+', '-', '*', '/', '//', '%', '**', '<<', '>>', '&', '|', '^', 'not', '==', '!=', '<', '>', '<=', '>=', 'in', 'not in', 'y=x[0]']。图模式重载的运算符详见[运算符](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax/operators.html)。列表中的运算符在输入图模式中不支持的输入类型时将使用扩展静态图语法支持,并使输出结果与动态图模式下的输出结果一致。\n",
    "\n",
    "更多使用可见[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)中的[基础运算符支持更多数据类型](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html#基础运算符支持更多数据类型)。\n",
    "\n",
    "#### 基础类型\n",
    "\n",
    "扩展对Python原生数据类型`List`、`Dictionary`、`None`的支持。更多使用可见[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)中的[基础类型](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html#基础类型)。\n",
    "\n",
    "##### 支持列表就地修改操作\n",
    "\n",
    "- 支持从全局变量中获取原`List`对象。\n",
    "- 不支持对输入`List`对象进行inplace操作。\n",
    "- 支持部分`List`内置函数的就地修改操作。\n",
    "\n",
    "##### 支持Dictionary的高阶用法\n",
    "\n",
    "- 支持顶图返回Dictionary。\n",
    "- 支持Dictionary索引取值和赋值。\n",
    "\n",
    "##### 支持使用None\n",
    "\n",
    "`None`是Python中的一个特殊值,表示空,可以赋值给任何变量。对于没有返回值语句的函数认为返回`None`。同时也支持`None`作为顶图或者子图的入参或者返回值。支持`None`作为切片的下标,作为`List`、`Tuple`、`Dictionary`的输入。\n",
    "\n",
    "#### 内置函数支持更多数据类型\n",
    "\n",
    "扩展内置函数的支持范围。Python内置函数完善支持更多输入类型,例如第三方库数据类型。更多内置函数的支持情况可见[Python内置函数](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax/python_builtin_functions.html)章节。\n",
    "\n",
    "#### 支持控制流\n",
    "\n",
    "为了提高Python标准语法支持度,实现动静统一,扩展支持更多数据类型在控制流语句的使用。控制流语句是指`if`、`for`、`while`等流程控制语句。理论上,通过扩展支持的语法,在控制流场景中也支持。更多使用可见[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)中的[支持控制流](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html#支持控制流)。\n",
    "\n",
    "#### 支持属性设置与修改\n",
    "\n",
    "支持更多类型的Inplace操作。之前版本只支持通过Inplace算子对Parameter类型进行值修改,在MindSpore2.1版本静态图模式下,支持了自定义类,Cell子类,jit_class类的属性修改。除了支持更改类self属性和全局变量的属性以外,也新增支持对List类型的extend()、reverse()、insert()、pop()等Inplace操作。更多使用可见[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)中的[支持属性设置与修改](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html#支持属性设置与修改)。\n",
    "\n",
    "- 对自定义类对象以及第三方类型的属性进行设置与修改。\n",
    "- 对Cell的self对象进行修改。\n",
    "- 对静态图内的Cell对象以及jit_class对象进行设置与修改。\n",
    "\n",
    "#### 支持求导\n",
    "\n",
    "使用JIT Fallback支持的静态图语法,同样支持其在求导中使用。更多使用可见[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)中的[支持求导](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html#支持求导)。\n",
    "\n",
    "#### Annotation Type\n",
    "\n",
    "对于运行时的JIT Fallback支持的语法,会产生一些无法被类型推导出的节点,这种类型称为`Any`类型。因为该类型无法在编译时推导出正确的类型,所以这种`Any`将会以一种默认最大精度`float64`进行运算,防止其精度丢失。为了能更好的优化相关性能,需要减少`Any`类型数据的产生。当用户可以明确知道当前通过扩展支持的语句会产生具体类型的时候,我们推荐使用`Annotation @jit.typing:`的方式进行指定对应Python语句类型,从而确定解释节点的类型避免`Any`类型的生成。更多使用可见[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html)中的[Annotation Type](https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/note/static_graph_syntax_support.html#annotation-type)。\n",
    "\n",
    "#### 使用须知\n",
    "\n",
    "在使用静态图JIT Fallback扩展支持语法时,请注意以下几点:\n",
    "\n",
    "1. 对标动态图的支持能力,即:须在动态图语法范围内,包括但不限于数据类型等。\n",
    "\n",
    "2. 在扩展静态图语法时,支持了更多的语法,但执行性能可能会受影响,不是最佳。\n",
    "\n",
    "3. 在扩展静态图语法时,支持了更多的语法,由于使用Python的能力,不能使用MindIR导入导出的能力。\n",
    "\n",
    "4. 暂不支持跨Python文件重复定义同名的全局变量,且这些全局变量在网络中会被用到。"
   ],
   "metadata": {},
   "attachments": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 动态图转静态图技术\n",
    "\n",
    "MindSpore提供了一种在代码无需修改的情况下,直接将用户的动态图代码转换成静态图的功能——PIJit。该功能同时兼顾性能和易用性,去除动静模式切换的代价,真正做到动静统一。它基于Python字节码的分析,对Python的执行流进行图捕获,可以以静态图方式运行的子图便以静态图方式运行,Python语法不支持的子图便以动态图方式运行,同时通过修改调整字节码的方式链接静态图,达到动静混合执行。在满足易用性的前提下,尽可能地提高性能。\n",
    "\n",
    "#### PIJit包含以下功能\n",
    "\n",
    "- 1. 图捕获:对字节码预处理,动态跟踪解释执行,识别MindSpore可入图操作,提供裂图功能保证函数(字节码)功能的正确性。\n",
    "- 2. 字节码支持:当前支持Python3.7和Python3.9的版本的字节码,并计划支持Python3.8与Python3.10。\n",
    "- 3. 图优化:对图中生成的字节码进行优化,包括分支裁剪、字节码筛选、函数字节码内联等功能。\n",
    "- 4. 异常捕获机制:支持with、try-except语法。\n",
    "- 5. 支持循环处理:通过模拟字节码的操作栈实现图捕获、裂图等特性。\n",
    "- 6. UD分析:通过变量的user-def链分析的方法,解决部分参数类型不能作为静态图的返回值问题(函数、Bool、None),同时减少无用的入参,提高图的执行效率,减少数据的拷贝。\n",
    "- 7. 副作用分析处理:弥补静态图的副作用处理上的劣势,根据不同场景,收集记录产生副作用的变量及字节码,在保证程序语义的基础上,在静态图外补充副作用的处理。\n",
    "- 8. 守护门禁:门禁(Guard)记录了子图/优化进入的输入需要满足的条件,检查输入是否适合对应的子图优化。\n",
    "- 9. Cache:图管理(Cache)则缓存了子图/优化和门禁(Guard)对应关系。\n",
    "- **未来计划支持Symbolic Shape**\n",
    "\n",
    "#### 使用方式\n",
    "\n",
    "def jit(fn=None, input_signature=None, hash_args=None, jit_config=None, mode=\"PIJit\"):\n",
    "\n",
    "原Jit功能使用mode=\"PSJit\",新特性PIJit使用mode=\"PIJit\",jit_config传递一个参数字典,可以提供一些优化和调试选项。如:print_after_all可以打印入图的字节码和裂图信息,loop_unrolling可以提供循环展开功能。\n",
    "\n",
    "#### 使用限制\n",
    "\n",
    "- 不支持在静态图模式下,运行带装饰@jit(mode=\\\"PIJit\\\")的函数,此时该装饰@jit(mode=\\\"PIJit\\\")视为无效。\n",
    "- 不支持在@jit(mode=\\\"PSJit\\\")装饰的函数内部调用带装饰@jit(mode=\\\"PIJit\\\")的函数,该装饰 @jit(mode=\\\"PIJit\\\")视为无效。\n",
    "\n"
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}