{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "69a92ef2",
   "metadata": {},
   "source": [
    "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.4.1/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.4.1/tutorials/zh_cn/beginner/mindspore_accelerate_with_static_graph.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.4.1/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.4.1/tutorials/zh_cn/beginner/mindspore_accelerate_with_static_graph.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.4.1/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.4.1/tutorials/source_zh_cn/beginner/accelerate_with_static_graph.ipynb)\n",
    "\n",
    "[基本介绍](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/introduction.html) || [快速入门](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/quick_start.html) || [张量 Tensor](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/tensor.html) || [数据加载与处理](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/dataset.html) || [网络构建](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/model.html) || [函数式自动微分](https://gitee.com/mindspore/docs/blob/r2.4.1/tutorials/source_zh_cn/beginner/autograd.ipynb) || [模型训练](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/train.html) || [保存与加载](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/save_load.html) || **使用静态图加速** || [自动混合精度](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/mixed_precision.html) ||\n",
    "\n",
    "# 使用静态图加速\n",
    "\n",
    "## 背景介绍\n",
    "\n",
    "AI编译框架分为两种运行模式,分别是动态图模式以及静态图模式。MindSpore默认情况下是以动态图模式运行,但也支持手工切换为静态图模式。两种运行模式的详细介绍如下:\n",
    "\n",
    "### 动态图模式\n",
    "\n",
    "动态图的特点是计算图的构建和计算同时发生(Define by run),其符合Python的解释执行方式,在计算图中定义一个Tensor时,其值就已经被计算且确定,因此在调试模型时较为方便,能够实时得到中间结果的值,但由于所有节点都需要被保存,导致难以对整个计算图进行优化。\n",
    "\n",
    "在MindSpore中,动态图模式又被称为PyNative模式。由于动态图的解释执行特性,在脚本开发和网络流程调试过程中,推荐使用动态图模式进行调试。\n",
    "如需要手动控制框架采用PyNative模式,可以通过以下代码进行网络构建:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6dc8d4c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]\n",
      " [-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]\n",
      " [-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]\n",
      " [-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]\n",
      " [-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]\n",
      " ...\n",
      " [-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]\n",
      " [-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]\n",
      " [-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]\n",
      " [-0.00134926 -0.13563682 -0.02863023 -0.05452826  0.03290743 -0.12423715\n",
      "  -0.0582641  -0.10854103 -0.08558805  0.06099342]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import nn, Tensor\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)  # 使用set_context进行动态图模式的配置\n",
    "\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits\n",
    "\n",
    "model = Network()\n",
    "input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))\n",
    "output = model(input)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a9a195c",
   "metadata": {},
   "source": [
    "### 静态图模式\n",
    "\n",
    "相较于动态图而言,静态图的特点是将计算图的构建和实际计算分开(Define and run)。有关静态图模式的运行原理,可以参考[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.4.1/model_train/program_form/static_graph.html#概述)。\n",
    "\n",
    "在MindSpore中,静态图模式又被称为Graph模式,在Graph模式下,基于图优化、计算图整图下沉等技术,编译器可以针对图进行全局的优化,获得较好的性能,因此比较适合网络固定且需要高性能的场景。\n",
    "\n",
    "如需要手动控制框架采用静态图模式,可以通过以下代码进行网络构建:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2476e059",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.05363735  0.05117104 -0.03343301  0.06347139  0.07546629  0.03263091\n",
      "   0.02790363  0.06269836  0.01838502  0.04387159]\n",
      " [ 0.05363735  0.05117104 -0.03343301  0.06347139  0.07546629  0.03263091\n",
      "   0.02790363  0.06269836  0.01838502  0.04387159]\n",
      " [ 0.05363735  0.05117104 -0.03343301  0.06347139  0.07546629  0.03263091\n",
      "   0.02790363  0.06269836  0.01838502  0.04387159]\n",
      " [ 0.05363735  0.05117104 -0.03343301  0.06347139  0.07546629  0.03263091\n",
      "   0.02790363  0.06269836  0.01838502  0.04387159]\n",
      " ...\n",
      " [ 0.05363735  0.05117104 -0.03343301  0.06347139  0.07546629  0.03263091\n",
      "   0.02790363  0.06269836  0.01838502  0.04387159]\n",
      " [ 0.05363735  0.05117104 -0.03343301  0.06347139  0.07546629  0.03263091\n",
      "   0.02790363  0.06269836  0.01838502  0.04387159]\n",
      " [ 0.05363735  0.05117104 -0.03343301  0.06347139  0.07546629  0.03263091\n",
      "   0.02790363  0.06269836  0.01838502  0.04387159]\n",
      " [ 0.05363735  0.05117104 -0.03343301  0.06347139  0.07546629  0.03263091\n",
      "   0.02790363  0.06269836  0.01838502  0.04387159]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import nn, Tensor\n",
    "ms.set_context(mode=ms.GRAPH_MODE)  # 使用set_context进行运行静态图模式的配置\n",
    "\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits\n",
    "\n",
    "model = Network()\n",
    "input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))\n",
    "output = model(input)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89825897",
   "metadata": {},
   "source": [
    "## 静态图模式的使用场景\n",
    "\n",
    "MindSpore编译器重点面向Tensor数据的计算以及其微分处理。因此使用MindSpore API以及基于Tensor对象的操作更适合使用静态图编译优化。其他操作虽然可以部分入图编译,但实际优化作用有限。另外,静态图模式先编译后执行的模式导致其存在编译耗时。因此,如果函数无需反复执行,那么使用静态图加速也可能没有价值。\n",
    "\n",
    "有关使用静态图来进行网络编译的示例,请参考[网络构建](https://www.mindspore.cn/tutorials/zh-CN/r2.4.1/beginner/model.html)。\n",
    "\n",
    "## 静态图模式开启方式\n",
    "\n",
    "通常情况下,由于动态图的灵活性,我们会选择使用PyNative模式来进行自由的神经网络构建,以实现模型的创新和优化。但是当需要进行性能加速时,我们需要对神经网络部分或整体进行加速。MindSpore提供了两种切换为图模式的方式,分别是基于装饰器的开启方式以及基于全局context的开启方式。\n",
    "\n",
    "### 基于装饰器的开启方式\n",
    "\n",
    "MindSpore提供了jit装饰器,可以通过修饰Python函数或者Python类的成员函数使其被编译成计算图,通过图优化等技术提高运行速度。此时我们可以简单的对想要进行性能优化的模块进行图编译加速,而模型其他部分,仍旧使用解释执行方式,不丢失动态图的灵活性。无论全局context是设置成静态图模式还是动态图模式,被jit修饰的部分始终会以静态图模式进行运行。\n",
    "\n",
    "在需要对Tensor的某些运算进行编译加速时,可以在其定义的函数上使用jit修饰器,在调用该函数时,该模块自动被编译为静态图。需要注意的是,jit装饰器只能用来修饰函数,无法对类进行修饰。jit的使用示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d97a2f69",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.12126954  0.06986676 -0.2230821  -0.07087803 -0.01003947  0.01063392\n",
      "   0.10143848 -0.0200909  -0.09724037  0.0114444 ]\n",
      " [-0.12126954  0.06986676 -0.2230821  -0.07087803 -0.01003947  0.01063392\n",
      "   0.10143848 -0.0200909  -0.09724037  0.0114444 ]\n",
      " [-0.12126954  0.06986676 -0.2230821  -0.07087803 -0.01003947  0.01063392\n",
      "   0.10143848 -0.0200909  -0.09724037  0.0114444 ]\n",
      " [-0.12126954  0.06986676 -0.2230821  -0.07087803 -0.01003947  0.01063392\n",
      "   0.10143848 -0.0200909  -0.09724037  0.0114444 ]\n",
      " ...\n",
      " [-0.12126954  0.06986676 -0.2230821  -0.07087803 -0.01003947  0.01063392\n",
      "   0.10143848 -0.0200909  -0.09724037  0.0114444 ]\n",
      " [-0.12126954  0.06986676 -0.2230821  -0.07087803 -0.01003947  0.01063392\n",
      "   0.10143848 -0.0200909  -0.09724037  0.0114444 ]\n",
      " [-0.12126954  0.06986676 -0.2230821  -0.07087803 -0.01003947  0.01063392\n",
      "   0.10143848 -0.0200909  -0.09724037  0.0114444 ]\n",
      " [-0.12126954  0.06986676 -0.2230821  -0.07087803 -0.01003947  0.01063392\n",
      "   0.10143848 -0.0200909  -0.09724037  0.0114444 ]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import nn, Tensor\n",
    "\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits\n",
    "\n",
    "input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))\n",
    "\n",
    "@ms.jit  # 使用ms.jit装饰器,使被装饰的函数以静态图模式运行\n",
    "def run(x):\n",
    "    model = Network()\n",
    "    return model(x)\n",
    "\n",
    "output = run(input)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c5adf4d",
   "metadata": {},
   "source": [
    "除使用修饰器外,也可使用函数变换方式调用jit方法,示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7337a0bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.11027216 -0.09628229  0.0457969   0.05396656 -0.06958974  0.0428197\n",
      "  -0.1572069  -0.14151613 -0.04531277  0.07521383]\n",
      " [ 0.11027216 -0.09628229  0.0457969   0.05396656 -0.06958974  0.0428197\n",
      "  -0.1572069  -0.14151613 -0.04531277  0.07521383]\n",
      " [ 0.11027216 -0.09628229  0.0457969   0.05396656 -0.06958974  0.0428197\n",
      "  -0.1572069  -0.14151613 -0.04531277  0.07521383]\n",
      " [ 0.11027216 -0.09628229  0.0457969   0.05396656 -0.06958974  0.0428197\n",
      "  -0.1572069  -0.14151613 -0.04531277  0.07521383]\n",
      " ...\n",
      " [ 0.11027216 -0.09628229  0.0457969   0.05396656 -0.06958974  0.0428197\n",
      "  -0.1572069  -0.14151613 -0.04531277  0.07521383]\n",
      " [ 0.11027216 -0.09628229  0.0457969   0.05396656 -0.06958974  0.0428197\n",
      "  -0.1572069  -0.14151613 -0.04531277  0.07521383]\n",
      " [ 0.11027216 -0.09628229  0.0457969   0.05396656 -0.06958974  0.0428197\n",
      "  -0.1572069  -0.14151613 -0.04531277  0.07521383]\n",
      " [ 0.11027216 -0.09628229  0.0457969   0.05396656 -0.06958974  0.0428197\n",
      "  -0.1572069  -0.14151613 -0.04531277  0.07521383]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import nn, Tensor\n",
    "\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits\n",
    "\n",
    "input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))\n",
    "\n",
    "def run(x):\n",
    "    model = Network()\n",
    "    return model(x)\n",
    "\n",
    "run_with_jit = ms.jit(run)  # 通过调用jit将函数转换为以静态图方式执行\n",
    "output = run_with_jit(input)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df543f53",
   "metadata": {},
   "source": [
    "当我们需要对神经网络的某部分进行加速时,可以直接在construct方法上使用jit修饰器,在调用实例化对象时,该模块自动被编译为静态图。示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e51b32f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.10522258  0.06597593 -0.09440921 -0.04883489  0.07194916  0.1343117\n",
      "  -0.06813788  0.01986085  0.0216996  -0.05345828]\n",
      " [ 0.10522258  0.06597593 -0.09440921 -0.04883489  0.07194916  0.1343117\n",
      "  -0.06813788  0.01986085  0.0216996  -0.05345828]\n",
      " [ 0.10522258  0.06597593 -0.09440921 -0.04883489  0.07194916  0.1343117\n",
      "  -0.06813788  0.01986085  0.0216996  -0.05345828]\n",
      " [ 0.10522258  0.06597593 -0.09440921 -0.04883489  0.07194916  0.1343117\n",
      "  -0.06813788  0.01986085  0.0216996  -0.05345828]\n",
      " ...\n",
      " [ 0.10522258  0.06597593 -0.09440921 -0.04883489  0.07194916  0.1343117\n",
      "  -0.06813788  0.01986085  0.0216996  -0.05345828]\n",
      " [ 0.10522258  0.06597593 -0.09440921 -0.04883489  0.07194916  0.1343117\n",
      "  -0.06813788  0.01986085  0.0216996  -0.05345828]\n",
      " [ 0.10522258  0.06597593 -0.09440921 -0.04883489  0.07194916  0.1343117\n",
      "  -0.06813788  0.01986085  0.0216996  -0.05345828]\n",
      " [ 0.10522258  0.06597593 -0.09440921 -0.04883489  0.07194916  0.1343117\n",
      "  -0.06813788  0.01986085  0.0216996  -0.05345828]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import nn, Tensor\n",
    "\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    @ms.jit  # 使用ms.jit装饰器,使被装饰的函数以静态图模式运行\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits\n",
    "\n",
    "input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))\n",
    "model = Network()\n",
    "output = model(input)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8aef1df",
   "metadata": {},
   "source": [
    "### 基于context的开启方式\n",
    "\n",
    "context模式是一种全局的设置模式。代码示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bb1d7b61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.08501796 -0.04404321 -0.05165704  0.00357929  0.00051521  0.00946456\n",
      "   0.02748473 -0.19415936 -0.00278988  0.04024826]\n",
      " [ 0.08501796 -0.04404321 -0.05165704  0.00357929  0.00051521  0.00946456\n",
      "   0.02748473 -0.19415936 -0.00278988  0.04024826]\n",
      " [ 0.08501796 -0.04404321 -0.05165704  0.00357929  0.00051521  0.00946456\n",
      "   0.02748473 -0.19415936 -0.00278988  0.04024826]\n",
      " [ 0.08501796 -0.04404321 -0.05165704  0.00357929  0.00051521  0.00946456\n",
      "   0.02748473 -0.19415936 -0.00278988  0.04024826]\n",
      " ...\n",
      " [ 0.08501796 -0.04404321 -0.05165704  0.00357929  0.00051521  0.00946456\n",
      "   0.02748473 -0.19415936 -0.00278988  0.04024826]\n",
      " [ 0.08501796 -0.04404321 -0.05165704  0.00357929  0.00051521  0.00946456\n",
      "   0.02748473 -0.19415936 -0.00278988  0.04024826]\n",
      " [ 0.08501796 -0.04404321 -0.05165704  0.00357929  0.00051521  0.00946456\n",
      "   0.02748473 -0.19415936 -0.00278988  0.04024826]\n",
      " [ 0.08501796 -0.04404321 -0.05165704  0.00357929  0.00051521  0.00946456\n",
      "   0.02748473 -0.19415936 -0.00278988  0.04024826]]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import nn, Tensor\n",
    "ms.set_context(mode=ms.GRAPH_MODE)  # 使用set_context进行运行静态图模式的配置\n",
    "\n",
    "class Network(nn.Cell):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.dense_relu_sequential = nn.SequentialCell(\n",
    "            nn.Dense(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dense(512, 10)\n",
    "        )\n",
    "\n",
    "    def construct(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.dense_relu_sequential(x)\n",
    "        return logits\n",
    "\n",
    "model = Network()\n",
    "input = Tensor(np.ones([64, 1, 28, 28]).astype(np.float32))\n",
    "output = model(input)\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b863f0f7",
   "metadata": {},
   "source": [
    "## 静态图的语法约束\n",
    "\n",
    "在Graph模式下,Python代码并不是由Python解释器去执行,而是将代码编译成静态计算图,然后执行静态计算图。因此,编译器无法支持全量的Python语法。MindSpore的静态图编译器维护了Python常用语法子集,以支持神经网络的构建及训练。详情可参考[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.4.1/model_train/program_form/static_graph.html)。\n",
    "\n",
    "## 静态图高级编程技巧\n",
    "\n",
    "使用静态图高级编程技巧可以有效地提高编译效率以及执行效率,并可以使程序运行的更加稳定。详情可参考[静态图高级编程技巧](https://www.mindspore.cn/docs/zh-CN/r2.4.1/model_train/program_form/static_graph_syntax/static_graph_expert_programming.html)。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}