{ "cells": [ { "cell_type": "markdown", "id": "055c9867-e640-43fd-b3c4-658730291ded", "metadata": {}, "source": [ "[![在线运行](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9vYnMuZHVhbHN0YWNrLmNuLW5vcnRoLTQubXlodWF3ZWljbG91ZC5jb20vbWluZHNwb3JlLXdlYnNpdGUvbm90ZWJvb2svcjIuMC4wLWFscGhhL3R1dG9yaWFscy96aF9jbi9hZHZhbmNlZC9taW5kc3BvcmVfY29tcHV0ZV9ncmFwaC5pcHluYg==&imageid=77ef960a-bd26-4de4-9695-5b85a786fb90) [![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0.0-alpha/tutorials/zh_cn/advanced/mindspore_compute_graph.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.0.0-alpha/tutorials/zh_cn/advanced/mindspore_compute_graph.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r2.0.0-alpha/tutorials/source_zh_cn/advanced/compute_graph.ipynb)\n", "\n", "# 计算图\n", "\n", "![comp-graph.png](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.0.0-alpha/tutorials/source_zh_cn/advanced/images/comp-graph.png)\n", "\n", "计算图是使用有向图表示数学表达式的一种形式。如图所示,神经网络结构可以被视为由Tensor数据和Tensor运算作为节点构成的计算图,因此使用深度学习框架构造神经网络并训练,即是构造计算图,执行计算图的过程。当前在业界框架对计算图的支持分为动态图和静态图两种模式,动态图通过解释执行,具有动态语法亲和性,表达灵活;静态图使用JIT(just in time)编译优化执行,使用静态语法,在语法上有较多限制。\n", "\n", "MindSpore同时支持两种计算图模式,采用统一的API表达,在两种模式下使用相同的API,并采用统一的自动微分机制,实现动态图与静态图的统一。下面我们分别介绍MindSpore的两种计算图模式。" ] }, { "cell_type": "markdown", "id": "2954f3ac-cc4b-424f-8f85-56edc822597d", "metadata": {}, "source": [ "## 动态图" ] }, { "cell_type": "markdown", "id": "0e8af791-b6ea-4624-95f6-f8a8eea2ad71", "metadata": {}, "source": [ "动态图的特点是计算图的构建和计算同时发生(Define by run),其符合Python的解释执行方式,在计算图中定义一个Tensor时,其值就已经被计算且确定,因此在调试模型时较为方便,能够实时得到中间结果的值,但由于所有节点都需要被保存,导致难以对整个计算图进行优化。\n", "\n", "在MindSpore中,动态图模式又被称为PyNative模式。由于动态图的解释执行特性,在脚本开发和网络流程调试过程中,推荐使用动态图模式进行调试。\n", "\n", "> MindSpore默认计算图模式为PyNative模式。\n", "\n", "如需要手动控制框架采用PyNative模式,可以通过以下代码进行配置:" ] }, { "cell_type": "code", "execution_count": 1, "id": "6744e02b-a720-4945-99b1-b521f01e4832", "metadata": {}, "outputs": [], "source": [ "import mindspore as ms\n", "\n", "ms.set_context(mode=ms.PYNATIVE_MODE)" ] }, { "cell_type": "markdown", "id": "4fd50adf-c2b6-4355-8f1d-7457c8b4de41", "metadata": {}, "source": [ "在PyNative模式下,所有计算节点对应的底层算子均采用单Kernel执行的方式,因此可以任意进行计算结果的打印和调试,如:" ] }, { "cell_type": "code", "execution_count": 2, "id": "544dc330-054d-4804-99fc-f96eb03e90a0", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from mindspore import nn\n", "from mindspore import ops\n", "from mindspore import Tensor, Parameter\n", "\n", "class Network(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.w = Parameter(Tensor(np.random.randn(5, 3), ms.float32), name='w') # weight\n", " self.b = Parameter(Tensor(np.random.randn(3,), ms.float32), name='b') # bias\n", "\n", " def construct(self, x):\n", " out = ops.matmul(x, self.w)\n", " print('matmul: ', out)\n", " out = out + self.b\n", " print('add bias: ', out)\n", " return out\n", "\n", "model = Network()" ] }, { "cell_type": "markdown", "id": 
"dd902edb-93d9-4fa7-a22e-1d4abbd33166", "metadata": {}, "source": [ "我们简单定义一个shape为(5,)的Tensor作为输入,观察输出情况。可以看到在`construct`方法中插入的`print`语句将中间结果进行实时的打印输出。" ] }, { "cell_type": "code", "execution_count": 3, "id": "dfdbd030-92a0-42a0-8a62-489442fcd729", "metadata": {}, "outputs": [], "source": [ "x = ops.ones(5, ms.float32) # input tensor" ] }, { "cell_type": "code", "execution_count": 4, "id": "deb97de3-d806-438e-8be0-097f1d08f316", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "matmul: [-1.8809001 2.0400267 0.32370526]\n", "add bias: [-1.6770952 1.5087128 0.15726662]\n", "out: [-1.6770952 1.5087128 0.15726662]\n" ] } ], "source": [ "out = model(x)\n", "print('out: ', out)" ] }, { "cell_type": "markdown", "id": "f8eb6d22-7f19-4f91-a60c-45104f362bb5", "metadata": {}, "source": [ "## 静态图" ] }, { "cell_type": "markdown", "id": "4cd72031-836d-475e-a74d-5264b29393d1", "metadata": {}, "source": [ "相较于动态图而言,静态图的特点是将计算图的构建和实际计算分开(Define and run)。在构建阶段,根据完整的计算流程对原始的计算图进行优化和调整,编译得到更省内存和计算量更少的计算图。由于编译之后图的结构不再改变,所以称之为 “静态图” 。在计算阶段,根据输入数据执行编译好的计算图得到计算结果。相较于动态图,静态图对全局的信息掌握更丰富,可做的优化也会更多,但是其中间过程对于用户来说是黑盒,无法像动态图一样实时拿到中间计算结果。\n", "\n", "在MindSpore中,静态图模式又被称为Graph模式,在Graph模式下,基于图优化、计算图整图下沉等技术,编译器可以针对图进行全局的优化,获得较好的性能,因此比较适合网络固定且需要高性能的场景。\n", "\n", "如需要手动控制框架采用Graph模式,可以通过以下代码进行配置:" ] }, { "cell_type": "code", "execution_count": 5, "id": "3f4a630c-05f7-4686-9e65-dc7701f6c1bf", "metadata": {}, "outputs": [], "source": [ "ms.set_context(mode=ms.GRAPH_MODE)" ] }, { "cell_type": "markdown", "id": "183ce4a4-e779-4959-98bc-ddde6a97a28d", "metadata": {}, "source": [ "### 基于源码转换的图编译\n", "\n", "在静态图模式下,MindSpore通过源码转换的方式,将Python的源码转换成中间表达IR(Intermediate Representation),并在此基础上对IR图进行优化,最终在硬件设备上执行优化后的图。MindSpore使用基于图表示的函数式IR,称为MindIR,详情可参考[中间表示MindIR](https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/design/mindir.html)。\n", "\n", "MindSpore的静态图执行过程实际包含两步,对应静态图的Define和Run阶段,但在实际使用中,在实例化的Cell对象被调用时并不会感知,MindSpore将两阶段均封装在Cell的`__call__`方法中,因此实际调用过程为:\n", "\n", "`model(inputs) = model.compile(inputs) + model.construct(inputs)`,其中`model`为实例化Cell对象。\n", "\n", "下面我们显式调用`compile`方法进行示例:" ] }, { "cell_type": "code", "execution_count": 6, "id": "ca7a48b6-882b-4053-adf3-0709d54fe318", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "...\n", "matmul:\n", "Tensor(shape=[3], dtype=Float32, value=[-4.01971531e+00 -5.79053342e-01 3.41115999e+00])\n", "add bias:\n", "Tensor(shape=[3], dtype=Float32, value=[-3.94732714e+00 -1.46257186e+00 4.50144434e+00])\n", "out: [-3.9473271 -1.4625719 4.5014443]\n" ] } ], "source": [ "model = Network()\n", "\n", "model.compile(x)\n", "out = model(x)\n", "print('out: ', out)" ] }, { "cell_type": "markdown", "id": "ba7a09fc-af3d-42f2-946b-835187999864", "metadata": {}, "source": [ "### 静态图语法\n", "\n", "在Graph模式下,Python代码并不是由Python解释器去执行,而是将代码编译成静态计算图,然后执行静态计算图。因此,编译器无法支持全量的Python语法。MindSpore的静态图编译器维护了Python常用语法子集,以支持神经网络的构建及训练。详情可参考[静态图语法支持](https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/note/static_graph_syntax_support.html)。" ] }, { "cell_type": "markdown", "id": "fdff81ad-7cba-438f-ac2c-27a4719fe5fa", "metadata": {}, "source": [ "### 静态图控制流\n", "\n", "在PyNative模式下,MindSpore完全支持Python原生语法的流程控制语句。Graph模式下,MindSpore在编译时做了性能优化,因此,在定义网络时使用流程控制语句时会有部分特殊约束,其他部分仍和Python原生语法保持一致。详情可参考[流程控制语句](https://mindspore.cn/tutorials/experts/zh-CN/r2.0.0-alpha/network/control_flow.html)。" ] }, { "cell_type": "markdown", "id": "59bea676-a87b-4166-8277-ec201874ab89", "metadata": {}, "source": [ "## 即时编译" ] }, { "cell_type": "markdown", "id": 
"432f1bd1-4b9f-4024-a37f-9c8c74625db9", "metadata": {}, "source": [ "通常情况下,由于动态图的灵活性,我们会选择使用PyNative模式来进行自由的神经网络构建,以实现模型的创新和优化。但是当需要进行性能加速时,我们需要对神经网络部分或整体进行加速。此时,直接切换为Graph模式是一种简单选择,但是由于静态图对语法和控制流等限制,使得我们无法从动态图无感知地切换至静态图。\n", "\n", "为此,MindSpore提供了`jit`装饰器,可以通过修饰Python函数或者Python类的成员函数使其被编译成计算图,通过图优化等技术提高运行速度。此时我们可以简单的对想要进行性能优化的模块进行图编译加速,而模型其他部分,仍旧使用解释执行方式,不丢失动态图的灵活性。" ] }, { "cell_type": "markdown", "id": "b62ea5ca-8ee8-4805-89ca-33f95a735af6", "metadata": {}, "source": [ "### Cell模块编译\n", "\n", "当我们需要对神经网络的某部分进行加速时,可以直接在`construct`方法上使用`jit`修饰器,在调用实例化对象时,该模块自动被编译为静态图。示例如下:" ] }, { "cell_type": "code", "execution_count": 7, "id": "4faf3c9d-c74e-497f-a478-bfa0b6dd6e34", "metadata": {}, "outputs": [], "source": [ "import mindspore as ms\n", "from mindspore import nn\n", "\n", "class Network(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.fc = nn.Dense(10, 1)\n", "\n", " @ms.jit\n", " def construct(self, x):\n", " return self.fc(x)" ] }, { "cell_type": "markdown", "id": "dcd7590d-dddb-4a24-ab10-fbcc908de49d", "metadata": {}, "source": [ "### 函数编译\n", "\n", "与Cell模块编译类似,在需要对Tensor的某些运算进行编译加速时,可以在其定义的函数上使用`jit`修饰器,在调用该函数时,该模块自动被编译为静态图。示例如下:\n", "\n", "> 基于MindSpore的函数式自动微分特性,推荐使用函数编译方式对Tensor运算进行即时编译加速。" ] }, { "cell_type": "code", "execution_count": 8, "id": "64f717f0-30d8-4d99-9489-1dc619b65720", "metadata": {}, "outputs": [], "source": [ "@ms.jit\n", "def mul(x, y):\n", " return x * y" ] }, { "cell_type": "markdown", "id": "cb665231-5419-4925-93a5-19819e8f5d02", "metadata": {}, "source": [ "### 整图编译\n", "\n", "MindSpore支持将神经网络训练的正向计算、反向传播、梯度优化更新等步骤合为一个计算图进行编译优化,此方法称为整图编译。此时,仅需将神经网络训练逻辑构造为函数,并在函数上使用`jit`修饰器,即可达到整图编译的效果。\n", "\n", "下面使用简单的全连接网络进行举例:" ] }, { "cell_type": "code", "execution_count": 9, "id": "33fb98bc-3a5d-458e-88b6-a11fcd68b8aa", "metadata": {}, "outputs": [], "source": [ "network = nn.Dense(10, 1)\n", "loss_fn = nn.BCELoss()\n", "optimizer = nn.Adam(network.trainable_params(), 0.01)\n", "\n", "def forward_fn(data, label):\n", " logits = network(data)\n", " loss = loss_fn(logits, label)\n", " return loss\n", "\n", "grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)\n", "\n", "@ms.jit\n", "def train_step(data, label):\n", " loss, grads = grad_fn(data, label)\n", " optimizer(grads)\n", " return loss" ] }, { "cell_type": "markdown", "id": "e89bfe8c-7ffc-40fb-9901-6541d808ba4e", "metadata": {}, "source": [ "如上述代码所示,将神经网络正向执行与损失函数封装为`forward_fn`后,执行函数变换获得梯度计算函数。而后将梯度计算函数、优化器调用封装为`train_step`函数,并使用`jit`进行修饰,调用`train_step`函数时,会进行静态图编译,获得整图并执行。\n", "\n", "除使用修饰器外,也可使用函数变换方式调用`jit`方法,示例如下:" ] }, { "cell_type": "code", "execution_count": 10, "id": "f905fd70-ce8a-4f3e-86ee-f001e6880d31", "metadata": {}, "outputs": [], "source": [ "train_step = ms.jit(train_step)" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 5 }