{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ms_function动静结合\n", "\n", "`Ascend` `GPU` `CPU` `模型运行`\n", "\n", "[![在线运行](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tYXN0ZXIvcHJvZ3JhbW1pbmdfZ3VpZGUvemhfY24vbWluZHNwb3JlX21zX2Z1bmN0aW9uLmlweW5i&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c) [![下载Notebook](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_notebook.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.6/programming_guide/zh_cn/mindspore_ms_function.ipynb) [![下载样例代码](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_download_code.png)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r1.6/programming_guide/zh_cn/mindspore_ms_function.py) [![查看源文件](https://gitee.com/mindspore/docs/raw/r1.6/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.6/docs/mindspore/programming_guide/source_zh_cn/ms_function.ipynb)\n", "\n", "## 概述\n", "\n", "ms_function的作用是在PyNative模式下提升执行性能。在MindSpore框架中,PyNative模式(即动态图模式)下,用户可以使用完整的Python语法,更加简单方便地使用MindSpore进行网络调优。与此同时,PyNative模式也会导致一部分性能的损失。\n", "\n", "ms_function支持在PyNative模式下,让被ms_function修饰的程序以静态图的方式来运行。ms_function会将修饰的程序通过静态编译的方式来生成可执行图,整体下发执行,从而提升该修饰部分的执行性能。\n", "\n", "本文档主要介绍ms_function的使用方法和工作原理,以便您可以更有效地使用ms_function功能。\n", "\n", "## 修饰独立函数\n", "\n", "使用ms_function装饰器时,可以对独立定义的函数进行修饰。" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.406170Z", "start_time": "2022-01-04T11:36:29.874130Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[2. 2. 2.]\n", " [2. 2. 2.]\n", " [2. 2. 2.]]\n" ] } ], "source": [ "# pylint: disable=W0235,W0612\n", "import numpy as np\n", "import mindspore.ops as ops\n", "from mindspore import context, Tensor, ms_function\n", "\n", "@ms_function\n", "def add_func(x, y):\n", " return ops.add(x, y)\n", "\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "x = Tensor(np.ones([3, 3], dtype=np.float32))\n", "y = Tensor(np.ones([3, 3], dtype=np.float32))\n", "out = add_func(x, y)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 修饰Cell的成员函数\n", "\n", "使用ms_function装饰器时,可以对Cell的成员函数进行修饰。" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.498145Z", "start_time": "2022-01-04T11:36:31.408221Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Tensor(shape=[3, 3], dtype=Float32, value=\n", "[[1.00000000e+000, 1.00000000e+000, 1.00000000e+000],\n", " [1.00000000e+000, 1.00000000e+000, 1.00000000e+000],\n", " [1.00000000e+000, 1.00000000e+000, 1.00000000e+000]]), Tensor(shape=[3, 3], dtype=Float32, value=\n", "[[1.00000000e+000, 1.00000000e+000, 1.00000000e+000],\n", " [1.00000000e+000, 1.00000000e+000, 1.00000000e+000],\n", " [1.00000000e+000, 1.00000000e+000, 1.00000000e+000]]))\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "import mindspore.ops as ops\n", "from mindspore import context, Tensor, ms_function\n", "\n", "class Add(nn.Cell):\n", " def __init__(self):\n", " super(Add, self).__init__()\n", "\n", " @ms_function\n", " def construct(self, x, y):\n", " out = x * y\n", " return out\n", "\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "x = Tensor(np.ones([3, 3], dtype=np.float32))\n", "y = Tensor(np.ones([3, 3], dtype=np.float32))\n", "grad_ops = ops.GradOperation(get_all=True)\n", "net = Add()\n", "grad_out = grad_ops(net)(x, y)\n", "print(grad_out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 实现原理\n", "\n", "本小节将介绍ms_function的实现原理,当你深入了解了ms_function的工作原理时,你将会更高效地使用ms_function。\n", "\n", "以一个简单的动静结合的用例来说明,如下:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.517439Z", "start_time": "2022-01-04T11:36:31.499697Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[1024. 1024. 1024.]\n", " [1024. 1024. 1024.]\n", " [1024. 1024. 1024.]]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "from mindspore import context, Tensor, ms_function\n", "\n", "class Add(nn.Cell):\n", " def __init__(self):\n", " super(Add, self).__init__()\n", "\n", " def construct(self, x):\n", " x = x + x\n", " x = x + x\n", " return x\n", "\n", "\n", "class Mul(nn.Cell):\n", " def __init__(self):\n", " super(Mul, self).__init__()\n", "\n", " @ms_function\n", " def construct(self, x):\n", " x = x * x\n", " x = x * x\n", " return x\n", "\n", "\n", "class Test(nn.Cell):\n", " def __init__(self):\n", " super(Test, self).__init__()\n", " self.add = Add()\n", " self.mul = Mul()\n", "\n", " def construct(self, x):\n", " x = self.add(x)\n", " x = self.mul(x)\n", " x = self.add(x)\n", " return x\n", "\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "x = Tensor(np.ones([3, 3], dtype=np.float32))\n", "net = Test()\n", "out = net(x)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "该用例按照执行序,编译的方式如下图所示:\n", "\n", "![image](./images/ms_function.png)\n", "\n", "被ms_function修饰的函数将会按照静态图的方式进行编译和执行。如果网络涉及到反向求导,被ms_function修饰的部分也将以整图的形式来生成反向图,并与前后单个算子的反向图连成整体的反向图,下发执行。\n", "其中,缓存的策略与静态图的缓存策略一致,相同的函数对象在输入Shape和Type信息一致时,编译的图结构将会被缓存。\n", "\n", "## 使用须知\n", "\n", "在使用ms_function来修饰函数,加速执行效率时,请注意以下几点:\n", "\n", "1. ms_function修饰的函数须在静态图编译支持的语法范围内,包括但不限于数据类型等。\n", "\n", "2. ms_function修饰的函数所支持的控制流语法,与静态图保持一致。其中,仅对固定循环次数或者分支条件的控制流结构具有加速效果。\n", "\n", "3. 在PyNative模式下使用ms_function功能时,非ms_function修饰的部分支持断点调试;被ms_function修饰的部分由于是以静态图的方式编译,不支持断点调试。\n", "\n", "4. 由于ms_function修饰的函数将按照静态图的方式编译执行,因此ms_function不支持修饰的函数中含有Hook算子,以及不支持修饰自定义Bprop函数等。\n", "\n", "5. ms_function修饰的函数会受到静态图函数副作用的影响。\n", "\n", "函数副作用指:当调用函数时,除了函数返回值之外,还对主调用函数产生的附加影响。例如修改全局变量(函数外的变量),修改函数的参数等。\n", "\n", "场景1:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.534298Z", "start_time": "2022-01-04T11:36:31.519461Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "5\n" ] } ], "source": [ "import numpy as np\n", "from mindspore import context, Tensor, ms_function\n", "\n", "value = 5\n", "\n", "@ms_function\n", "def func(x, y):\n", " out = x + y\n", " value = 1\n", " return out\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "x = Tensor(np.ones([3, 3], dtype=np.float32))\n", "y = Tensor(np.ones([3, 3], dtype=np.float32))\n", "func(x, y)\n", "print(value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "该场景下,`value`是全局变量且在`func`函数中被修改。此时,如果用ms_function修饰`func`函数,全局变量`value`的值将不会被修改。原因是静态图编译时,会优化掉与返回值无关的语句。\n", "\n", "场景2:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2022-01-04T11:36:31.552710Z", "start_time": "2022-01-04T11:36:31.535881Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[6. 6. 6.]\n", " [6. 6. 6.]\n", " [6. 6. 6.]]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore.nn as nn\n", "from mindspore import context, Tensor, ms_function\n", "\n", "class Func(nn.Cell):\n", " def __init__(self):\n", " super(Func, self).__init__()\n", " self.value = 5\n", "\n", " @ms_function\n", " def construct(self, x):\n", " out = self.value + x\n", " return out\n", "\n", "context.set_context(mode=context.PYNATIVE_MODE)\n", "x = Tensor(np.ones([3, 3], dtype=np.float32))\n", "func = Func()\n", "out = func(x)\n", "func.value = 1\n", "out = func(x)\n", "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "该场景下,`value`是`Func`对象的参数,此时如果用ms_function修饰`Func`对象的`construct`成员函数。执行`Func`时将会以静态图的方式编译执行。由于静态图会缓存编译结果,第二次调用`Func`时,对`value`的修改将不会生效。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 4 }