{ "cells": [ { "cell_type": "markdown", "id": "ac2df353", "metadata": {}, "source": [ "# 静态图语法支持\n", "\n", "[](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.0/zh_cn/note/mindspore_static_graph_syntax_support.ipynb) [](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3.0/zh_cn/note/mindspore_static_graph_syntax_support.py) [](https://gitee.com/mindspore/docs/blob/r2.3.0/docs/mindspore/source_zh_cn/note/static_graph_syntax_support.ipynb)\n", "\n", "## 概述\n", "\n", "在Graph模式下,Python代码并不是由Python解释器去执行,而是将代码编译成静态计算图,然后执行静态计算图。\n", "\n", "在静态图模式下,MindSpore通过源码转换的方式,将Python的源码转换成中间表达IR(Intermediate Representation),并在此基础上对IR图进行优化,最终在硬件设备上执行优化后的图。MindSpore使用基于图表示的函数式IR,称为MindIR,详情可参考[中间表示MindIR](https://www.mindspore.cn/docs/zh-CN/r2.3.0/design/all_scenarios.html#中间表示mindir)。\n", "\n", "MindSpore的静态图执行过程实际包含两步,对应静态图的Define和Run阶段,但在实际使用中,在实例化的Cell对象被调用时用户并不会分别感知到这两阶段,MindSpore将两阶段均封装在Cell的`__call__`方法中,因此实际调用过程为:\n", "\n", "`model(inputs) = model.compile(inputs) + model.construct(inputs)`,其中`model`为实例化Cell对象。\n", "\n", "使用Graph模式有两种方式:一是调用`@jit`装饰器修饰函数或者类的成员方法,所修饰的函数或方法将会被编译成静态计算图。`jit`使用规则详见[jit API文档](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/mindspore/mindspore.jit.html#mindspore.jit)。二是设置`ms.set_context(mode=ms.GRAPH_MODE)`,使用`Cell`类并且在`construct`函数中编写执行代码,此时`construct`函数的代码将会被编译成静态计算图。`Cell`定义详见[Cell API文档](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/nn/mindspore.nn.Cell.html)。\n", "\n", "由于语法解析的限制,当前在编译构图时,支持的数据类型、语法以及相关操作并没有完全与Python语法保持一致,部分使用受限。借鉴传统JIT编译的思路,从图模式的角度考虑动静图的统一,扩展图模式的语法能力,使得静态图提供接近动态图的语法使用体验,从而实现动静统一。为了便于用户选择是否扩展静态图语法,提供了JIT语法支持级别选项`jit_syntax_level`,其值必须在[STRICT,LAX]范围内,选择`STRICT`则认为使用基础语法,不扩展静态图语法。默认值为`LAX`,更多请参考本文的[扩展语法(LAX级别)](#扩展语法lax级别)章节。全部级别都支持所有后端。\n", "\n", "- STRICT: 仅支持基础语法,且执行性能最佳。可用于MindIR导入导出。\n", "- LAX: 支持更多复杂语法,最大程度地兼容Python所有语法。由于存在可能无法导出的语法,不能用于MindIR导入导出。\n", "\n", "本文主要介绍,在编译静态图时,支持的数据类型、语法以及相关操作,这些规则仅适用于Graph模式。\n", "\n", "## 基础语法(STRICT级别)\n", "\n", "### 静态图内的常量与变量\n", "\n", "在静态图中,常量与变量是理解静态图语法的一个重要概念,很多语法在常量输入和变量输入情况下支持的方法与程度是不同的。因此,在介绍静态图具体支持的语法之前,本小节先会对静态图中常量与变量的概念进行说明。\n", "\n", "在静态图模式下,一段程序的运行会被分为编译期以及执行期。 在编译期,程序会被编译成一张中间表示图,并且程序不会真正的执行,而是通过抽象推导的方式对中间表示进行静态解析。这使得在编译期时,我们无法保证能获取到所有中间表示中节点的值。 常量和变量也就是通过能否能在编译器获取到其真实值来区分的。\n", "\n", "- 常量: 编译期内可以获取到值的量。\n", "- 变量: 编译期内无法获取到值的量。\n", "\n", "#### 常量产生场景\n", "\n", "- 作为图模式输入的标量,列表以及元组均为常量(在不使用mutable接口的情况下)。例如:" ] }, { "cell_type": "code", "execution_count": 102, "id": "e6964906", "metadata": {}, "outputs": [], "source": [ "from mindspore import Tensor, jit\n", "\n", "a = 1\n", "b = [Tensor([1]), Tensor([2])]\n", "c = [\"a\", \"b\", \"c\"]\n", "\n", "@jit\n", "def foo(a, b, c):\n", " return a, b, c" ] }, { "cell_type": "markdown", "id": "b511f34e", "metadata": {}, "source": [ "上述代码中,输入`a`,`b`,`c`均为常量。\n", "\n", "- 图模式内生成的标量或者Tensor为常量。例如:" ] }, { "cell_type": "code", "execution_count": 103, "id": "36a141ef", "metadata": {}, "outputs": [], "source": [ "from mindspore import jit, Tensor\n", "\n", "@jit\n", "def foo():\n", " a = 1\n", " b = \"2\"\n", " c = Tensor([1, 2, 3])\n", " return a, b, c" ] }, { "cell_type": "markdown", "id": "538635fe", "metadata": {}, "source": [ "上述代码中, `a`,`b`,`c`均为常量。\n", "\n", "- 常量运算得到的结果为常量。例如:" ] }, { "cell_type": "code", "execution_count": 104, "id": "b7e995be", "metadata": {}, "outputs": [], "source": [ "from mindspore import jit, Tensor\n", "\n", "@jit\n", "def foo():\n", " a = Tensor([1, 2, 3])\n", " b = Tensor([1, 1, 1])\n", " c = a + b\n", " return c" ] }, { "cell_type": "markdown", "id": "0e306590", "metadata": {}, "source": [ "上述代码中,`a`、`b`均为图模式内产生的Tensor为常量,因此其计算得到的结果也是常量。但如果其中之一为变量时,其返回值也会为变量。\n", "\n", "#### 变量产生场景\n", "\n", "- 所有mutable接口的返回值均为变量(无论是在图外使用mutable还是在图内使用)。例如:" ] }, { "cell_type": "code", "execution_count": 105, "id": "43e94877", "metadata": {}, "outputs": [], "source": [ "from mindspore import Tensor, jit\n", "from mindspore import mutable\n", "\n", "a = mutable([Tensor([1]), Tensor([2])])\n", "\n", "@jit\n", "def foo(a):\n", " b = mutable(Tensor([3]))\n", " c = mutable((Tensor([1]), Tensor([2])))\n", " return a, b, c" ] }, { "cell_type": "markdown", "id": "08aeaed5", "metadata": {}, "source": [ "上述代码中,`a`是在图外调用mutable接口的,`b`和`c`是在图内调用mutable接口生成的,`a`、`b`、`c`均为变量。\n", "\n", "- 作为静态图的输入的Tensor都是变量。例如:" ] }, { "cell_type": "code", "execution_count": 106, "id": "d4bac49b", "metadata": {}, "outputs": [], "source": [ "from mindspore import Tensor, jit\n", "\n", "a = Tensor([1])\n", "b = (Tensor([1]), Tensor([2]))\n", "\n", "@jit\n", "def foo(a, b):\n", " return a, b" ] }, { "cell_type": "markdown", "id": "9cf26d3e", "metadata": {}, "source": [ "上述代码中,`a`是作为图模式输入的Tensor,因此其为变量。但`b`是作为图模式输入的元组,非Tensor类型,即使其内部的元素均为Tensor,`b`也是常量。\n", "\n", "- 通过变量计算得到的是变量。\n", "\n", " 如果一个量是算子的输出,那么其多数情况下为常量。例如:" ] }, { "cell_type": "code", "execution_count": 107, "id": "b9765b45", "metadata": {}, "outputs": [], "source": [ "from mindspore import Tensor, jit, ops\n", "\n", "a = Tensor([1])\n", "b = Tensor([2])\n", "\n", "@jit\n", "def foo(a, b):\n", " c = a + b\n", " return c" ] }, { "cell_type": "markdown", "id": "081696e5", "metadata": {}, "source": [ "在这种情况下,`c`是`a`和`b`计算来的结果,且用来计算的输入`a`、`b`均为变量,因此`c`也是变量。\n", "\n", "### 数据类型\n", "\n", "#### Python内置数据类型\n", "\n", "当前支持的`Python`内置数据类型包括:`Number`、`String`、`List`、`Tuple`和`Dictionary`。\n", "\n", "##### Number\n", "\n", "支持`int`(整型)、`float`(浮点型)、`bool`(布尔类型),不支持`complex`(复数)。\n", "\n", "支持在网络里定义`Number`,即支持语法:`y = 1`、`y = 1.2`、`y = True`。\n", "\n", "当数据为常量时,编译时期可以获取到数值,在网络中可以支持强转`Number`的语法:`y = int(x)`、`y = float(x)`、`y = bool(x)`。\n", "当数据为变量时,即需要在运行时期才可以获取到数值,也支持使用int(),float(),bool()等内置函数[Python内置函数](#python内置函数)进行数据类型的转换。例如:" ] }, { "cell_type": "code", "execution_count": 108, "id": "539145a1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "res[0]: 11\n", "res[1]: 10\n", "res[2]: 2\n" ] } ], "source": [ "from mindspore import Tensor, jit\n", "\n", "@jit\n", "def foo(x):\n", " out1 = int(11.1)\n", " out2 = int(Tensor([10]))\n", " out3 = int(x.asnumpy())\n", " return out1, out2, out3\n", "\n", "res = foo(Tensor(2))\n", "print(\"res[0]:\", res[0])\n", "print(\"res[1]:\", res[1])\n", "print(\"res[2]:\", res[2])" ] }, { "cell_type": "markdown", "id": "38ca4b5a", "metadata": {}, "source": [ "支持返回Number类型。例如:" ] }, { "cell_type": "code", "execution_count": 109, "id": "30b19828", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "3\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit\n", "def test_return_scalar(x, y):\n", " return x + y\n", "\n", "res = test_return_scalar(ms.mutable(1), ms.mutable(2))\n", "print(res)" ] }, { "cell_type": "markdown", "id": "3cf060f3", "metadata": {}, "source": [ "##### String\n", "\n", "支持在网络里构造`String`,即支持使用引号(`'`或`\"`)来创建字符串,如`x = 'abcd'`或`y = \"efgh\"`。可以通过`str()`的方式进行将常量转换成字符串。支持对字符串连接,截取,以及使用成员运算符(`in`或`not in`)判断字符串是否包含指定的字符。支持格式化字符串的输出,将一个值插入到一个有字符串格式符`%s`的字符串中。支持在常量场景下使用格式化字符串函数`str.format()`。\n", "\n", "例如:" ] }, { "cell_type": "code", "execution_count": 110, "id": "19bf6d12", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "res: ('H', 'Spore', 'Hello!MindSpore', 'MindSporeMindSpore', True, 'My name is MindSpore!', 'string is 123')\n" ] } ], "source": [ "from mindspore import jit\n", "\n", "@jit\n", "def foo():\n", " var1 = 'Hello!'\n", " var2 = \"MindSpore\"\n", " var3 = str(123)\n", " var4 = \"{} is {}\".format(\"string\", var3)\n", " return var1[0], var2[4:9], var1 + var2, var2 * 2, \"H\" in var1, \"My name is %s!\" % var2, var4\n", "\n", "res = foo()\n", "print(\"res:\", res)" ] }, { "cell_type": "markdown", "id": "72d7a9be", "metadata": {}, "source": [ "##### List\n", "\n", "在`JIT_SYNTAX_LEVEL`设置为`LAX`的情况下,静态图模式可以支持部分`List`对象的inplace操作,具体介绍详见[支持列表就地修改操作](#支持列表就地修改操作)章节。\n", "\n", "`List`的基础使用场景如下:\n", "\n", "- 图模式支持图内创建`List`。\n", "\n", " 支持在图模式内创建`List`对象,且`List`内对象的元素可以包含任意图模式支持的类型,也支持多层嵌套。例如:" ] }, { "cell_type": "code", "execution_count": 111, "id": "2da43e45", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import mindspore as ms\n", "\n", "@ms.jit\n", "def generate_list():\n", " a = [1, 2, 3, 4]\n", " b = [\"1\", \"2\", \"a\"]\n", " c = [ms.Tensor([1]), ms.Tensor([2])]\n", " d = [a, b, c, (4, 5)]\n", " return d" ] }, { "cell_type": "markdown", "id": "b0ef63b1", "metadata": {}, "source": [ "上述示例代码中,所有的`List`对象都可以被正常的创建。\n", "\n", "- 图模式支持返回`List`。\n", "\n", " 在MindSpore2.0版本之前,当图模式返回`List` 对象时,`List`会被转换为`Tuple`。MindSpore2.0版本已经可以支持返回`List`对象。例如:" ] }, { "cell_type": "code", "execution_count": 112, "id": "ea7c7945", "metadata": {}, "outputs": [ ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit\n", "def list_func():\n", " a = [1, 2, 3, 4]\n", " return a\n", "\n", "output = list_func() # output: [1, 2, 3, 4]" ] }, { "cell_type": "markdown", "id": "a28f9d17", "metadata": {}, "source": [ "与图模式内创建`List` 相同,图模式返回`List`对象可以包括任意图模式支持的类型,也支持多层嵌套。\n", "\n", "- 图模式支持从全局变量中获取`List`对象。" ] }, { "cell_type": "code", "execution_count": 113, "id": "889d0ce8", "metadata": {}, "outputs": [ ], "source": [ "import mindspore as ms\n", "\n", "global_list = [1, 2, 3, 4]\n", "\n", "@ms.jit\n", "def list_func():\n", " global_list.reverse()\n", " return global_list\n", "\n", "output = list_func() # output: [4, 3, 2, 1]" ] }, { "cell_type": "markdown", "id": "5f31207d", "metadata": {}, "source": [ "需要注意的是,在基础场景下图模式返回的列表与全局变量的列表不是同一个对象,当`JIT_SYNTAX_LEVEL`设置为`LAX`时,返回的对象与全局对象为统一对象。\n", "\n", "- 图模式支持以`List`作为输入。\n", "\n", " 图模式支持`List`作为静态图的输入,作为输入的`List`对象的元素必须为图模式支持的输入类型,也支持多层嵌套。" ] }, { "cell_type": "code", "execution_count": 114, "id": "0ca96da0", "metadata": {}, "outputs": [ ], "source": [ "import mindspore as ms\n", "\n", "list_input = [1, 2, 3, 4]\n", "\n", "@ms.jit\n", "def list_func(x):\n", " return x\n", "\n", "output = list_func(list_input) # output: [1, 2, 3, 4]" ] }, { "cell_type": "markdown", "id": "18fb25a1", "metadata": {}, "source": [ "需要注意的是,`List`作为静态图输入时,无论其内部的元素是什么类型,一律被视为常量。\n", "\n", "- 图模式支持List的内置方法。\n", "\n", " `List` 内置方法的详细介绍如下:\n", "\n", "- List索引取值\n", "\n", " 基础语法:```element = list_object[index]```。\n", "\n", " 基础语义:将`List`对象中位于第`index`位的元素提取出来(`index`从0开始)。支持多层索引取值。\n", "\n", " 索引值`index`支持类型包括`int`,`Tensor`和`slice`。其中,`int`以及`Tensor`类型的输入可以支持常量以及变量,`slice`内部数据必须为编译时能够确定的常量。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 115, "id": "cef90ce6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a:[1, 2]\n", "b:2\n", "c:[3, 4]\n" ] } ], "source": [ "import mindspore as ms\n", "@ms.jit()\n", "def list_getitem_func():\n", " x = [[1, 2], 3, 4]\n", " a = x[0]\n", " b = x[0][ms.Tensor([1])]\n", " c = x[1:3:1]\n", " return a, b, c\n", "\n", "a, b, c = list_getitem_func()\n", "print('a:{}'.format(a))\n", "print('b:{}'.format(b))\n", "print('c:{}'.format(c))" ] }, { "cell_type": "markdown", "id": "1635578f", "metadata": {}, "source": [ "- List索引赋值\n", "\n", " 基础语法:```list_object[index] = target_element```。\n", "\n", " 基础语义:将`List`对象中位于第`index`位的元素赋值为 `target_element`(`index`从0开始)。支持多层索引赋值。\n", "\n", " 索引值`index`支持类型包括`int`,`Tensor`和`slice`。其中,`int` 以及`Tensor`类型的输入可以支持常量以及变量,`slice`内部数据必须为编译时能够确定的常量。\n", "\n", " 索引赋值对象`target_element`支持所有图模式支持的数据类型。\n", "\n", " 目前,`List`索引赋值不支持inplace操作, 索引赋值后将会生成一个新的对象。该操作后续将会支持inplace操作。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 116, "id": "bb06d465", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output:[[0, 88], 10, 'ok', (1, 2, 3)]\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_setitem_func():\n", " x = [[0, 1], 2, 3, 4]\n", " x[1] = 10\n", " x[2] = \"ok\"\n", " x[3] = (1, 2, 3)\n", " x[0][1] = 88\n", " return x\n", "\n", "output = test_setitem_func()\n", "print('output:{}'.format(output))" ] }, { "cell_type": "markdown", "id": "138d7848", "metadata": {}, "source": [ "- List.append\n", "\n", " 基础语法:```list_object.append(target_element)```。\n", "\n", " 基础语义:向`List`对象`list_object`的最后追加元素`target_element`。\n", "\n", " 目前,`List.append`不支持inplace操作, 索引赋值后将会生成一个新的对象。该操作后续将会支持inplace操作。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 117, "id": "232b8c0c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x:[1, 2, 3, 4]\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_list():\n", " x = [1, 2, 3]\n", " x.append(4)\n", " return x\n", "\n", "x = test_list()\n", "print('x:{}'.format(x))" ] }, { "cell_type": "markdown", "id": "7fb155fb", "metadata": {}, "source": [ "- List.clear\n", "\n", " 基础语法:```list_object.clear()```。\n", "\n", " 基础语义:清空`List`对象 `list_object`中包含的元素。\n", "\n", " 目前,`List.clear`不支持inplace, 索引赋值后将会生成一个新的对象。该操作后续将会支持inplace。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 118, "id": "b8437a5e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x:[]\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_list_clear():\n", " x = [1, 3, 4]\n", " x.clear()\n", " return x\n", "\n", "x = test_list_clear()\n", "print('x:{}'.format(x))" ] }, { "cell_type": "markdown", "id": "f6789309", "metadata": {}, "source": [ "- List.extend\n", "\n", " 基础语法:```list_object.extend(target)```。\n", "\n", " 基础语义:向`List`对象`list_object`的最后依次插入`target`内的所有元素。\n", "\n", " `target`支持的类型为`Tuple`,`List`以及`Tensor`。其中,如果`target`类型为`Tensor`的情况下,会先将该`Tensor`转换为`List`,再进行插入操作。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 119, "id": "15065fda", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output1:[1, 2, 3, 4, 'a']\n", "output2:[1, 2, 3, Tensor(shape=[], dtype=Int64, value= 4), Tensor(shape=[], dtype=Int64, value= 5)]\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_list_extend():\n", " x1 = [1, 2, 3]\n", " x1.extend((4, \"a\"))\n", " x2 = [1, 2, 3]\n", " x2.extend(ms.Tensor([4, 5]))\n", " return x1, x2\n", "\n", "output1, output2 = test_list_extend()\n", "print('output1:{}'.format(output1))\n", "print('output2:{}'.format(output2))" ] }, { "cell_type": "markdown", "id": "d275131e", "metadata": {}, "source": [ "- List.pop\n", "\n", " 基础语法:```pop_element = list_object.pop(index=-1)```。\n", "\n", " 基础语义:将`List`对象`list_object` 的第`index`个元素从`list_object`中删除,并返回该元素。\n", "\n", " `index` 要求必须为常量`int`, 当`list_object`的长度为`list_obj_size`时,`index`的取值范围为:`[-list_obj_size,list_obj_size-1]`。`index`为负数,代表从后往前的位数。当没有输入`index`时,默认值为-1,即删除最后一个元素。" ] }, { "cell_type": "code", "execution_count": 120, "id": "6c401a86", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pop_element:3\n", "res_list:[1, 2]\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_list_pop():\n", " x = [1, 2, 3]\n", " b = x.pop()\n", " return b, x\n", "\n", "pop_element, res_list = test_list_pop()\n", "print('pop_element:{}'.format(pop_element))\n", "print('res_list:{}'.format(res_list))" ] }, { "cell_type": "markdown", "id": "5231ac8e", "metadata": {}, "source": [ "- List.reverse\n", "\n", " 基础语法:```list_object.reverse()```。\n", "\n", " 基础语义:将`List`对象`list_object`的元素顺序倒转。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 121, "id": "35790450", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output:[3, 2, 1]\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_list_reverse():\n", " x = [1, 2, 3]\n", " x.reverse()\n", " return x\n", "\n", "output = test_list_reverse()\n", "print('output:{}'.format(output))" ] }, { "cell_type": "markdown", "id": "c8116f8e", "metadata": {}, "source": [ "- List.insert\n", "\n", " 基础语法:```list_object.insert(index, target_obj)```。\n", "\n", " 基础语义:将`target_obj`插入到`list_object`的第`index`位。\n", "\n", " `index`要求必须为常量`int`。如果`list_object`的长度为`list_obj_size`。当`index < -list_obj_size`时,插入到`List`的第一位。当`index >= list_obj_size`时,插入到`List`的最后。`index`为负数代表从后往前的位数。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 122, "id": "b6257418", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "output:[1, 2, 3, 4]\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_list_insert():\n", " x = [1, 2, 3]\n", " x.insert(3, 4)\n", " return x\n", "\n", "output = test_list_insert()\n", "print('output:{}'.format(output))" ] }, { "cell_type": "markdown", "id": "bba335fa", "metadata": {}, "source": [ "##### Tuple\n", "\n", "支持在网络里构造元组`Tuple`,使用小括号包含元素,即支持语法`y = (1, 2, 3)`。元组`Tuple`的元素不能修改,但支持索引访问元组`Tuple`中的元素,支持对元组进行连接组合。\n", "\n", "- 支持索引取值。\n", "\n", " 支持使用方括号加下标索引的形式来访问元组`Tuple`中的元素,索引值支持`int`、`slice`、`Tensor`,也支持多层索引取值,即支持语法`data = tuple_x[index0][index1]...`。\n", "\n", " 索引值为`Tensor`有如下限制:\n", "\n", "- `Tuple`里存放的都是`Cell`,每个`Cell`要在`Tuple`定义之前完成定义,每个`Cell`的入参个数、入参类型和入参`shape`要求一致,每个`Cell`的输出个数、输出类型和输出`shape`也要求一致。\n", "\n", "- 索引`Tensor`是一个`dtype`为`int32`的标量`Tensor`,取值范围在`[-tuple_len, tuple_len)`,`Ascend`后端不支持负数索引。\n", "\n", "- 支持`CPU`、`GPU`和`Ascend`后端。" ] }, { "cell_type": "markdown", "id": "b887d8e1", "metadata": {}, "source": [ " `int`、`slice`索引示例如下:" ] }, { "cell_type": "code", "execution_count": 123, "id": "cf9f7dce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y:3\n", "z:[1 2 3]\n", "m:((2, 3, 4), 3, 4)\n", "n:(2, 3, 4)\n" ] } ], "source": [ "import numpy as np\n", "import mindspore as ms\n", "\n", "t = ms.Tensor(np.array([1, 2, 3]))\n", "\n", "@ms.jit()\n", "def test_index():\n", " x = (1, (2, 3, 4), 3, 4, t)\n", " y = x[1][1]\n", " z = x[4]\n", " m = x[1:4]\n", " n = x[-4]\n", " return y, z, m, n\n", "\n", "y, z, m, n = test_index()\n", "print('y:{}'.format(y))\n", "print('z:{}'.format(z))\n", "print('m:{}'.format(m))\n", "print('n:{}'.format(n))" ] }, { "cell_type": "markdown", "id": "2d1c3a30", "metadata": {}, "source": [ " `Tensor`索引示例如下:" ] }, { "cell_type": "code", "execution_count": 124, "id": "e006a4ed", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ret:[0.]\n" ] } ], "source": [ "import mindspore as ms\n", "from mindspore import nn, set_context\n", "\n", "set_context(mode=ms.GRAPH_MODE)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.relu = nn.ReLU()\n", " self.softmax = nn.Softmax()\n", " self.layers = (self.relu, self.softmax)\n", "\n", " def construct(self, x, index):\n", " ret = self.layers[index](x)\n", " return ret\n", "\n", "x = ms.Tensor([-1.0], ms.float32)\n", "\n", "net = Net()\n", "ret = net(x, 0)\n", "print('ret:{}'.format(ret))" ] }, { "cell_type": "markdown", "id": "e669cb4f", "metadata": {}, "source": [ "- 支持连接组合。\n", "\n", " 与字符串`String`类似,元组支持使用`+`和`*`进行组合,得到一个新的元组`Tuple`,例如:" ] }, { "cell_type": "code", "execution_count": 125, "id": "ef1d6027", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out1:(1, 2, 3, 4, 5, 6)\n", "out2:(1, 2, 3, 1, 2, 3)\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_index():\n", " x = (1, 2, 3)\n", " y = (4, 5, 6)\n", " return x + y, x * 2\n", "\n", "out1, out2 = test_index()\n", "print('out1:{}'.format(out1))\n", "print('out2:{}'.format(out2))" ] }, { "cell_type": "markdown", "id": "9d414438", "metadata": {}, "source": [ "##### Dictionary\n", "\n", "支持在网络里构造字典`Dictionary`,每个键值`key:value`用冒号`:`分割,每个键值对之间用逗号`,`分割,整个字典使用大括号`{}`包含键值对,即支持语法`y = {\"a\": 1, \"b\": 2}`。\n", "\n", "键`key`是唯一的,如果字典中存在多个相同的`key`,则重复的`key`以最后一个作为最终结果;而值`value`可以不是唯一的。键`key`需要保证是不可变的。当前键`key`支持`String`、`Number`、常量`Tensor`以及只包含这些类型对象的`Tuple`;值`value`支持`Number`、`Tuple`、`Tensor`、`List`、`Dictionary`和`None`。\n", "\n", "- 支持接口。\n", "\n", " `keys`:取出`dict`里所有的`key`值,组成`Tuple`返回。\n", "\n", " `values`:取出`dict`里所有的`value`值,组成`Tuple`返回。\n", "\n", " `items`:取出`dict`里每一对`key`和`value`组成的`Tuple`,最终组成`List`返回。\n", "\n", " `get`:`dict.get(key[, value])`返回指定`key`对应的`value`值,如果指定`key`不存在,返回默认值`None`或者设置的默认值`value`。\n", "\n", " `clear`:删除`dict`里所有的元素。\n", "\n", " `has_key`:`dict.has_key(key)`判断`dict`里是否存在指定`key`。\n", "\n", " `update`:`dict1.update(dict2)`把`dict2`中的元素更新到`dict1`中。\n", "\n", " `fromkeys`:`dict.fromkeys(seq([, value]))`用于创建新的`Dictionary`,以序列`seq`中的元素做`Dictionary`的`key`,`value`为所有`key`对应的初始值。\n", "\n", " 示例如下,其中返回值中的`x`和`new_dict`是一个`Dictionary`,在图模式JIT语法支持级别选项为LAX下扩展支持,更多Dictionary的高阶使用请参考本文的[支持Dictionary的高阶用法](#支持dictionary的高阶用法)章节。" ] }, { "cell_type": "code", "execution_count": 126, "id": "123eaeb2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_keys:('a', 'b', 'c')\n", "x_values:(Tensor(shape=[3], dtype=Int64, value= [1, 2, 3]), Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]), Tensor(shape=[3], dtype=Int64, value= [7, 8, 9]))\n", "x_items:[('a', Tensor(shape=[3], dtype=Int64, value= [1, 2, 3])), ('b', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6])), ('c', Tensor(shape=[3], dtype=Int64, value= [7, 8, 9]))]\n", "value_a:[1 2 3]\n", "check_key:True\n", "new_x:{'a': Tensor(shape=[3], dtype=Int64, value= [0, 0, 0]), 'b': Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]), 'c': Tensor(shape=[3], dtype=Int64, value= [7, 8, 9])}\n", "new_dict:{'a': 123, 'b': 123, 'c': 123, 'd': 123}\n" ] } ], "source": [ "import mindspore as ms\n", "import numpy as np\n", "\n", "x = {\"a\": ms.Tensor(np.array([1, 2, 3])), \"b\": ms.Tensor(np.array([4, 5, 6])), \"c\": ms.Tensor(np.array([7, 8, 9]))}\n", "\n", "@ms.jit()\n", "def test_dict():\n", " x_keys = x.keys()\n", " x_values = x.values()\n", " x_items = x.items()\n", " value_a = x.get(\"a\")\n", " check_key = x.has_key(\"a\")\n", " y = {\"a\": ms.Tensor(np.array([0, 0, 0]))}\n", " x.update(y)\n", " new_dict = x.fromkeys(\"abcd\", 123)\n", " return x_keys, x_values, x_items, value_a, check_key, x, new_dict\n", "\n", "x_keys, x_values, x_items, value_a, check_key, new_x, new_dict = test_dict()\n", "print('x_keys:{}'.format(x_keys))\n", "print('x_values:{}'.format(x_values))\n", "print('x_items:{}'.format(x_items))\n", "print('value_a:{}'.format(value_a))\n", "print('check_key:{}'.format(check_key))\n", "print('new_x:{}'.format(new_x))\n", "print('new_dict:{}'.format(new_dict))" ] }, { "cell_type": "markdown", "id": "2674c0bf", "metadata": {}, "source": [ "#### MindSpore自定义数据类型\n", "\n", "当前MindSpore自定义数据类型包括:`Tensor`、`Primitive`、`Cell`和`Parameter`。\n", "\n", "##### Tensor\n", "\n", "Tensor的属性与接口详见[Tensor API文档](https://mindspore.cn/docs/zh-CN/r2.3.0/api_python/mindspore/mindspore.Tensor.html#mindspore-tensor)。\n", "\n", "支持在静态图模式下创建和使用Tensor。创建方式有使用[tensor函数接口](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/mindspore/mindspore.tensor.html#mindspore.tensor)和使用`Tensor`类接口。推荐使用tensor函数接口,用户可以使用指定所需要的dtype类型。代码用例如下。" ] }, { "cell_type": "code", "execution_count": 127, "id": "2737331d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "import mindspore as ms\n", "import mindspore.nn as nn\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", "\n", " @ms.jit\n", " def construct(self, x):\n", " return ms.tensor(x.asnumpy(), dtype=ms.float32)\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "net = Net()\n", "x = ms.Tensor(1, dtype=ms.int32)\n", "print(net(x))" ] }, { "cell_type": "markdown", "id": "3eb7ed38", "metadata": {}, "source": [ "##### Primitive\n", "\n", "当前支持在construct里构造`Primitive`及其子类的实例。\n", "\n", "但在调用时,参数只能通过位置参数方式传入,不支持通过键值对方式传入。\n", "\n", "示例如下:" ] }, { "cell_type": "markdown", "id": "1d0c8c2a", "metadata": {}, "source": [ "```python\n", "import mindspore as ms\n", "from mindspore import nn, ops, Tensor, set_context\n", "import numpy as np\n", "\n", "set_context(mode=ms.GRAPH_MODE)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", "\n", " def construct(self, x):\n", " reduce_sum = ops.ReduceSum(True) #支持在construct里构造`Primitive`及其子类的实例\n", " ret = reduce_sum(x, axis=2)\n", " return ret\n", "\n", "x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))\n", "net = Net()\n", "ret = net(x)\n", "print('ret.shape:{}'.format(ret.shape))\n", "```\n", "\n", "上面所定义的网络里,reduce_sum(x, axis=2)的参数不支持通过键值对方式传入,只能通过位置参数方式传入,即reduce_sum(x, 2)。\n", "\n", "结果报错如下:\n", "\n", "```text\n", "TypeError: For Primitive[ReduceSum], only positional arguments as inputs are supported, but got AbstractKeywordArg(key: axis, value: AbstractScalar(type: Int64, Value: 2, Shape: NoShape))\n", "```" ] }, { "cell_type": "markdown", "id": "88491f2a", "metadata": {}, "source": [ "当前不支持在网络调用`Primitive`及其子类相关属性和接口。\n", "\n", "当前已定义的`Primitive`详见[Primitive API文档](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/ops/mindspore.ops.Primitive.html#mindspore.ops.Primitive)。\n", "\n", "##### Cell\n", "\n", "当前支持在网络里构造`Cell`及其子类的实例,即支持语法`cell = Cell(args...)`。\n", "\n", "但在调用时,参数只能通过位置参数方式传入,不支持通过键值对方式传入,即不支持在语法`cell = Cell(arg_name=value)`。\n", "\n", "当前不支持在网络调用`Cell`及其子类相关属性和接口,除非是在`Cell`自己的`construct`中通过`self`调用。\n", "\n", "`Cell`定义详见[Cell API文档](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/nn/mindspore.nn.Cell.html)。\n", "\n", "##### Parameter\n", "\n", "`Parameter`是变量张量,代表在训练网络时,需要被更新的参数。\n", "\n", "`Parameter`的定义和使用详见[Parameter API文档](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/mindspore/mindspore.Parameter.html#mindspore.Parameter)。\n", "\n", "### 运算符\n", "\n", "算术运算符和赋值运算符支持`Number`和`Tensor`运算,也支持不同`dtype`的`Tensor`运算。详见[运算符](https://www.mindspore.cn/docs/zh-CN/r2.3.0/note/static_graph_syntax/operators.html)。\n", "\n", "### 原型\n", "\n", "原型代表编程语言中最紧密绑定的操作。\n", "\n", "#### 属性引用与修改\n", "\n", "属性引用是后面带有一个句点加一个名称的原型。\n", "\n", "在MindSpore的Cell 实例中使用属性引用作为左值需满足如下要求:\n", "\n", "- 被修改的属性属于本`cell`对象,即必须为`self.xxx`。\n", "- 该属性在Cell的`__init__`函数中完成初始化且其为Parameter类型。\n", "\n", "在JIT语法支持级别选项为`LAX`时,可以支持更多情况的属性修改,具体详见[支持属性设置与修改](#支持属性设置与修改)。\n", "\n", "示例如下:" ] }, { "cell_type": "code", "execution_count": 128, "id": "a72c6830", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ret:1\n" ] } ], "source": [ "import mindspore as ms\n", "from mindspore import nn, set_context\n", "\n", "set_context(mode=ms.GRAPH_MODE)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.weight = ms.Parameter(ms.Tensor(3, ms.float32), name=\"w\")\n", " self.m = 2\n", "\n", " def construct(self, x, y):\n", " self.weight = x # 满足条件可以修改\n", " # self.m = 3 # self.m 非Parameter类型禁止修改\n", " # y.weight = x # y不是self,禁止修改\n", " return x\n", "\n", "net = Net()\n", "ret = net(1, 2)\n", "print('ret:{}'.format(ret))" ] }, { "cell_type": "markdown", "id": "376ee059", "metadata": {}, "source": [ "#### 索引取值\n", "\n", "对序列`Tuple`、`List`、`Dictionary`、`Tensor`的索引取值操作(Python称为抽取)。\n", "\n", "`Tuple`的索引取值请参考本文的[Tuple](#tuple)章节。\n", "\n", "`List`的索引取值请参考本文的[List](#list)章节。\n", "\n", "`Dictionary`的索引取值请参考本文的[Dictionary](#dictionary)章节。\n", "\n", "`Tensor`的索引取详见[Tensor 索引取值文档](https://www.mindspore.cn/docs/zh-CN/r2.3.0/note/index_support.html#索引取值)。\n", "\n", "#### 调用\n", "\n", "所谓调用就是附带可能为空的一系列参数来执行一个可调用对象(例如:`Cell`、`Primitive`)。\n", "\n", "示例如下:" ] }, { "cell_type": "code", "execution_count": 129, "id": "3f56afba", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ret:[[3. 3. 3. 3.]]\n" ] } ], "source": [ "import mindspore as ms\n", "from mindspore import nn, ops, set_context\n", "import numpy as np\n", "\n", "set_context(mode=ms.GRAPH_MODE)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.matmul = ops.MatMul()\n", "\n", " def construct(self, x, y):\n", " out = self.matmul(x, y) # Primitive调用\n", " return out\n", "\n", "x = ms.Tensor(np.ones(shape=[1, 3]), ms.float32)\n", "y = ms.Tensor(np.ones(shape=[3, 4]), ms.float32)\n", "net = Net()\n", "ret = net(x, y)\n", "print('ret:{}'.format(ret))" ] }, { "cell_type": "markdown", "id": "3049a46e", "metadata": {}, "source": [ "### 语句\n", "\n", "当前静态图模式支持部分Python语句,包括raise语句、assert语句、pass语句、return语句、break语句、continue语句、if语句、for语句、while语句、with语句、列表生成式、生成器表达式、函数定义语句等,详见[Python语句](https://www.mindspore.cn/docs/zh-CN/r2.3.0/note/static_graph_syntax/statements.html)。\n", "\n", "### Python内置函数\n", "\n", "当前静态图模式支持部分Python内置函数,其使用方法与对应的Python内置函数类似,详见[Python内置函数](https://www.mindspore.cn/docs/zh-CN/r2.3.0/note/static_graph_syntax/python_builtin_functions.html)。\n", "\n", "### 网络定义\n", "\n", "#### 网络入参\n", "\n", "在对整网入参求梯度的时候,会忽略非`Tensor`的入参,只计算`Tensor`入参的梯度。\n", "\n", "示例如下。整网入参`(x, y, z)`中,`x`和`z`是`Tensor`,`y`是非`Tensor`。因此,`grad_net`在对整网入参`(x, y, z)`求梯度的时候,会自动忽略`y`的梯度,只计算`x`和`z`的梯度,返回`(grad_x, grad_z)`。" ] }, { "cell_type": "code", "execution_count": 130, "id": "054e6db9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ret:(Tensor(shape=[1], dtype=Int64, value= [1]), Tensor(shape=[1], dtype=Int64, value= [1]))\n" ] } ], "source": [ "import numpy as np\n", "import mindspore as ms\n", "from mindspore import nn\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", "\n", " def construct(self, x, y, z):\n", " return x + y + z\n", "\n", "class GradNet(nn.Cell):\n", " def __init__(self, net):\n", " super(GradNet, self).__init__()\n", " self.forward_net = net\n", "\n", " def construct(self, x, y, z):\n", " return ms.grad(self.forward_net, grad_position=(0, 1, 2))(x, y, z)\n", "\n", "input_x = ms.Tensor([1])\n", "input_y = 2\n", "input_z = ms.Tensor([3])\n", "\n", "net = Net()\n", "grad_net = GradNet(net)\n", "ret = grad_net(input_x, input_y, input_z)\n", "print('ret:{}'.format(ret))" ] }, { "cell_type": "markdown", "id": "9b823a13", "metadata": {}, "source": [ "## 基础语法的语法约束\n", "\n", "图模式下的执行图是从源码转换而来,并不是所有的Python语法都能支持。下面介绍在基础语法下存在的一些语法约束。更多网络编译问题可见[网络编译](https://www.mindspore.cn/docs/zh-CN/r2.3.0/faq/network_compilation.html)。\n", "\n", "1. 当`construct`函数里,使用未定义的类成员时,将抛出`AttributeError`异常。示例如下:\n", "\n", " ```python\n", " import mindspore as ms\n", " from mindspore import nn, set_context\n", "\n", " set_context(mode=ms.GRAPH_MODE)\n", "\n", " class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", "\n", " def construct(self, x):\n", " return x + self.y\n", "\n", " net = Net()\n", " net(1)\n", " ```\n", "\n", " 结果报错如下:\n", "\n", " AttributeError: External object has no attribute y" ] }, { "cell_type": "markdown", "id": "6b65d6ae", "metadata": {}, "source": [ "2. `nn.Cell`不支持`classmethod`修饰的类方法。示例如下:\n", "\n", " ```python\n", " import mindspore as ms\n", "\n", " ms.set_context(mode=ms.GRAPH_MODE)\n", "\n", " class Net(ms.nn.Cell):\n", " @classmethod\n", " def func(cls, x, y):\n", " return x + y\n", "\n", " def construct(self, x, y):\n", " return self.func(x, y)\n", "\n", " net = Net()\n", " out = net(ms.Tensor(1), ms.Tensor(2))\n", " print(out)\n", " ```\n", "\n", " 结果报错如下:\n", "\n", " TypeError: The parameters number of the function is 3, but the number of provided arguments is 2." ] }, { "cell_type": "markdown", "id": "b16fc212", "metadata": {}, "source": [ "3. 在图模式下,有些Python语法难以转换成图模式下的[中间表示MindIR](https://www.mindspore.cn/docs/zh-CN/r2.3.0/design/all_scenarios.html#中间表示mindir)。对标Python的关键字,存在部分关键字在图模式下是不支持的:AsyncFunctionDef、Delete、AnnAssign、AsyncFor、AsyncWith、Match、Try、Import、ImportFrom、Nonlocal、NamedExpr、Set、SetComp、Await、Yield、YieldFrom、Starred。如果在图模式下使用相关的语法,将会有相应的报错信息提醒用户。\n", "\n", " 如果使用Try语句,示例如下:\n", "\n", " ```python\n", " import mindspore as ms\n", "\n", " @ms.jit\n", " def test_try_except(x, y):\n", " global_out = 1\n", " try:\n", " global_out = x / y\n", " except ZeroDivisionError:\n", " print(\"division by zero, y is zero.\")\n", " return global_out\n", "\n", " test_try_except_out = test_try_except(1, 0)\n", " print(\"out:\", test_try_except_out)\n", " ```\n", "\n", " 结果报错如下:\n", "\n", " RuntimeError: Unsupported statement 'Try'." ] }, { "cell_type": "markdown", "id": "88db8a87", "metadata": {}, "source": [ "4. 对标Python内置数据类型,除去当前图模式下支持的[Python内置数据类型](#python内置数据类型),复数`complex`和集合`set`类型是不支持的。列表`list`和字典`dictionary`的一些高阶用法在基础语法场景下是不支持的,需要在JIT语法支持级别选项`jit_syntax_level`为`LAX`时才支持,更多请参考本文的[扩展语法(LAX级别)](#扩展语法lax级别)章节。\n", "\n", "5. 对标Python的内置函数,在基础语法场景下,除去当前图模式下支持的[Python内置函数](https://www.mindspore.cn/docs/zh-CN/r2.3.0/note/static_graph_syntax/python_builtin_functions.html),仍存在部分内置函数在图模式下是不支持的,例如:basestring、bin、bytearray、callable、chr、cmp、compile、 delattr、dir、divmod、eval、execfile、file、frozenset、hash、hex、id、input、issubclass、iter、locals、long、memoryview、next、object、oct、open、ord、property、raw_input、reduce、reload、repr、reverse、set、slice、sorted、unichr、unicode、vars、xrange、\\_\\_import\\_\\_。\n", "\n", "6. Python提供了很多第三方库,通常需要通过import语句调用。在图模式下JIT语法支持级别为STRICT时,不能直接使用第三方库。如果需要在图模式下使用第三方库的数据类型或者调用第三方库的方法,需要在JIT语法支持级别选项`jit_syntax_level`为`LAX`时才支持,更多请参考本文的[扩展语法(LAX级别)](#扩展语法lax级别)中的[调用第三方库](#调用第三方库)章节。\n", "\n", "## 扩展语法(LAX级别)\n", "\n", "下面主要介绍当前扩展支持的静态图语法。\n", "\n", "### 调用第三方库\n", "\n", "- 第三方库\n", "\n", " 1. Python内置模块和Python标准库。例如`os`、`sys`、`math`、`time`等模块。\n", "\n", " 2. 第三方代码库。路径在Python安装目录的`site-packages`目录下,需要先安装后导入,例如`NumPy`、`SciPy`等。需要注意的是,`mindyolo`、`mindflow`等MindSpore套件不被视作第三方库,具体列表可以参考[parser](https://gitee.com/mindspore/mindspore/blob/v2.3.0/mindspore/python/mindspore/_extends/parse/parser.py)文件的 `_modules_from_mindspore` 列表。\n", "\n", " 3. 通过环境变量`MS_JIT_IGNORE_MODULES`指定的模块。与之相对的有环境变量`MS_JIT_MODULES`,具体使用方法请参考[环境变量](https://www.mindspore.cn/docs/zh-CN/r2.3.0/note/env_var_list.html)。\n", "\n", "- 支持第三方库的数据类型,允许调用和返回第三方库的对象。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 131, "id": "6c620a7b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[5 7 9]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore as ms\n", "\n", "@ms.jit\n", "def func():\n", " a = np.array([1, 2, 3])\n", " b = np.array([4, 5, 6])\n", " out = a + b\n", " return out\n", "\n", "print(func())" ] }, { "cell_type": "markdown", "id": "462efb7d", "metadata": {}, "source": [ "- 支持调用第三方库的方法。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 132, "id": "f753b9bb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(2, 2)\n" ] } ], "source": [ "from scipy import linalg\n", "import mindspore as ms\n", "\n", "@ms.jit\n", "def func():\n", " x = [[1, 2], [3, 4]]\n", " return linalg.qr(x)\n", "\n", "out = func()\n", "print(out[0].shape)" ] }, { "cell_type": "markdown", "id": "c8534673", "metadata": {}, "source": [ "- 支持使用NumPy第三方库数据类型创建Tensor对象。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 133, "id": "2da4da0a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2 3 4]\n" ] } ], "source": [ "import numpy as np\n", "import mindspore as ms\n", "\n", "@ms.jit\n", "def func():\n", " x = np.array([1, 2, 3])\n", " out = ms.Tensor(x) + 1\n", " return out\n", "\n", "print(func())" ] }, { "cell_type": "markdown", "id": "427f2db9", "metadata": {}, "source": [ "- 暂不支持对第三方库数据类型的下标索引赋值。\n", "\n", " 示例如下:\n", "\n", " ```python\n", " import numpy as np\n", " import mindspore as ms\n", "\n", " @ms.jit\n", " def func():\n", " x = np.array([1, 2, 3])\n", " x[0] += 1\n", " return ms.Tensor(x)\n", "\n", " res = func()\n", " print(\"res: \", res)\n", " ```\n", "\n", " 报错信息如下:\n", "\n", " RuntimeError: For operation 'setitem', current input arguments types are <External, Number, Number>. The 1-th argument type 'External' is not supported now." ] }, { "cell_type": "markdown", "id": "86f8c818", "metadata": {}, "source": [ "### 支持自定义类的使用\n", "\n", "支持在图模式下使用用户自定义的类,可以对类进行实例化,使用对象的属性及方法。\n", "\n", "例如下面的例子,其中`GetattrClass`是用户自定义的类,没有使用`@jit_class`修饰,也没有继承`nn.Cell`。" ] }, { "cell_type": "code", "execution_count": 134, "id": "09b2b6ad", "metadata": {}, "outputs": [ ], "source": [ "import mindspore as ms\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "\n", "class GetattrClass():\n", " def __init__(self):\n", " self.attr1 = 99\n", " self.attr2 = 1\n", "\n", " def method1(self, x):\n", " return x + self.attr2\n", "\n", "class GetattrClassNet(ms.nn.Cell):\n", " def __init__(self):\n", " super(GetattrClassNet, self).__init__()\n", " self.cls = GetattrClass()\n", "\n", " def construct(self):\n", " return self.cls.method1(self.cls.attr1)\n", "\n", "net = GetattrClassNet()\n", "out = net()\n", "assert out == 100" ] }, { "cell_type": "markdown", "id": "8ce2737c", "metadata": {}, "source": [ "### 基础运算符支持更多数据类型\n", "\n", "在静态图语法重载了以下运算符: ['+', '-', '*', '/', '//', '%', '**', '<<', '>>', '&', '|', '^', 'not', '==', '!=', '<', '>', '<=', '>=', 'in', 'not in', 'y=x[0]']。图模式重载的运算符详见[运算符](https://www.mindspore.cn/docs/zh-CN/r2.3.0/note/static_graph_syntax/operators.html)。列表中的运算符在输入图模式中不支持的输入类型时将使用扩展静态图语法支持,并使输出结果与动态图模式下的输出结果一致。\n", "\n", "代码用例如下。" ] }, { "cell_type": "code", "execution_count": 135, "id": "7e5bfe55", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[5 7]\n" ] } ], "source": [ "import mindspore as ms\n", "import mindspore.nn as nn\n", "from mindspore import Tensor\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "\n", "class InnerClass(nn.Cell):\n", " def construct(self, x, y):\n", " return x.asnumpy() + y.asnumpy()\n", "\n", "net = InnerClass()\n", "ret = net(Tensor([4, 5]), Tensor([1, 2]))\n", "print(ret)" ] }, { "cell_type": "markdown", "id": "d0a195f3", "metadata": {}, "source": [ "上述例子中,`.asnumpy()`输出的数据类型: `numpy.ndarray`为运算符`+`在图模式中不支持的输入类型。因此`x.asnumpy() + y.asnumpy()`将使用扩展语法支持。\n", "\n", "在另一个用例中:" ] }, { "cell_type": "code", "execution_count": 136, "id": "191b044c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "True\n" ] } ], "source": [ "class InnerClass(nn.Cell):\n", " def construct(self):\n", " return (None, 1) in ((None, 1), 1, 2, 3)\n", "\n", "net = InnerClass()\n", "print(net())" ] }, { "cell_type": "markdown", "id": "608b823f", "metadata": {}, "source": [ "`tuple` in `tuple`在原本的图模式中是不支持的运算,现已使用扩展静态图语法支持。\n", "\n", "### 基础类型\n", "\n", "扩展对Python原生数据类型`List`、`Dictionary`、`None`的支持。\n", "\n", "#### 支持列表就地修改操作\n", "\n", "列表`List`以及元组`Tuple`是Python中最基本的序列内置类型,`List`与`Tuple`最核心的区别是`List`是可以改变的对象,而`Tuple`是不可以更改的。这意味着`Tuple`一旦被创建,就不可以在对象地址不变的情况下更改。而`List`则可以通过一系列inplace操作,在不改变对象地址的情况下,对对象进行修改。例如:" ] }, { "cell_type": "code", "execution_count": 137, "id": "c1a0a4c6", "metadata": {}, "outputs": [], "source": [ "a = [1, 2, 3, 4]\n", "a_id = id(a)\n", "a.append(5)\n", "a_after_id = id(a)\n", "assert a_id == a_after_id" ] }, { "cell_type": "markdown", "id": "c87bbd4b", "metadata": {}, "source": [ "上述示例代码中,通过`append`这个inplace语法更改`List`对象的时候,其对象的地址并没有被修改。而`Tuple`是不支持这种inplace操作的。在`JIT_SYNTAX_LEVEL`设置为`LAX`的情况下,静态图模式可以支持部分`List`对象的inplace操作。\n", "\n", "具体使用场景如下:\n", "\n", "- 支持从全局变量中获取原`List`对象。\n", "\n", " 在下面示例中,静态图获取到`List`对象,并在原有对象上进行了图模式支持的inplace操作`list.reverse()`, 并将原有对象返回。可以看到图模式返回的对象与原有的全局变量对象id相同,即两者为同一对象。若`JIT_SYNTAX_LEVEL`设置为`STRICT`选项,则返回的`List`对象与全局对象为两个不同的对象。" ] }, { "cell_type": "code", "execution_count": 138, "id": "7a3a949a", "metadata": {}, "outputs": [ ], "source": [ "import mindspore as ms\n", "\n", "global_list = [1, 2, 3, 4]\n", "\n", "@ms.jit\n", "def list_func():\n", " global_list.reverse()\n", " return global_list\n", "\n", "output = list_func() # output: [4, 3, 2, 1]\n", "assert id(global_list) == id(output)" ] }, { "cell_type": "markdown", "id": "b2babb6b", "metadata": {}, "source": [ "- 不支持对输入`List`对象进行inplace操作。\n", "\n", " `List`作为静态图输入时,会对该`List`对象进行一次复制,并使用该复制对象进行后续的计算,因此无法对原输入对象进行inplace操作。例如:\n", "\n", " ```python\n", " import mindspore as ms\n", "\n", " list_input = [1, 2, 3, 4]\n", "\n", " @ms.jit\n", " def list_func(x):\n", " x.reverse()\n", " return x\n", "\n", " output = list_func(list_input) # output: [4, 3, 2, 1] list_input: [1, 2, 3, 4]\n", " assert id(output) != id(list_input)\n", " ```\n", "\n", " 如上述用例所示,`List`对象作为图模式输入时无法在原有对象上进行inplace操作。图模式返回的对象与输入的对象id不同,为不同对象。\n", "\n", "- 支持部分`List`内置函数的就地修改操作。\n", "\n", " 在`JIT_SYNTAX_LEVEL`设置为`LAX`的情况下,图模式部分`List`内置函数支持inplace。在 `JIT_SYNTAX_LEVEL`为 `STRICT` 的情况下,所有方法均不支持inplace操作。\n", "\n", " 目前,图模式支持的`List`就地修改内置方法有`extend`、`pop`、`reverse`以及`insert`。内置方法`append`、`clear`以及索引赋值暂不支持就地修改,后续版本将会支持。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 139, "id": "a069afee", "metadata": {}, "outputs": [ ], "source": [ "import mindspore as ms\n", "\n", "list_input = [1, 2, 3, 4]\n", "\n", "@ms.jit\n", "def list_func():\n", " list_input.reverse()\n", " return list_input\n", "\n", "output = list_func() # output: [4, 3, 2, 1] list_input: [4, 3, 2, 1]\n", "assert id(output) == id(list_input)" ] }, { "cell_type": "markdown", "id": "6c698dab", "metadata": {}, "source": [ "#### 支持Dictionary的高阶用法\n", "\n", "- 支持顶图返回Dictionary。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 140, "id": "8c364502", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out: {'y': 'a'}\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit()\n", "def test_dict():\n", " x = {'a': 'a', 'b': 'b'}\n", " y = x.get('a')\n", " z = dict(y=y)\n", " return z\n", "\n", "out = test_dict()\n", "print(\"out:\", out)" ] }, { "cell_type": "markdown", "id": "a699b0a9", "metadata": {}, "source": [ "- 支持Dictionary索引取值和赋值。\n", "\n", " 示例如下:" ] }, { "cell_type": "code", "execution_count": 141, "id": "acd9c9fe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out1:{'a': (2, 3, 4), 'b': Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]), 'c': Tensor(shape=[3], dtype=Int64, value= [7, 8, 9])}\n", "out2:[4 5 6]\n" ] } ], "source": [ "import mindspore as ms\n", "import numpy as np\n", "\n", "x = {\"a\": ms.Tensor(np.array([1, 2, 3])), \"b\": ms.Tensor(np.array([4, 5, 6])), \"c\": ms.Tensor(np.array([7, 8, 9]))}\n", "\n", "@ms.jit()\n", "def test_dict():\n", " y = x[\"b\"]\n", " x[\"a\"] = (2, 3, 4)\n", " return x, y\n", "\n", "out1, out2 = test_dict()\n", "print('out1:{}'.format(out1))\n", "print('out2:{}'.format(out2))" ] }, { "cell_type": "markdown", "id": "1da1d787", "metadata": {}, "source": [ "#### 支持使用None\n", "\n", "`None`是Python中的一个特殊值,表示空,可以赋值给任何变量。对于没有返回值语句的函数认为返回`None`。同时也支持`None`作为顶图或者子图的入参或者返回值。支持`None`作为切片的下标,作为`List`、`Tuple`、`Dictionary`的输入。\n", "\n", "示例如下:" ] }, { "cell_type": "code", "execution_count": 142, "id": "464eb3da", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 'a', None)\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit\n", "def test_return_none():\n", " return 1, \"a\", None\n", "\n", "res = test_return_none()\n", "print(res)" ] }, { "cell_type": "markdown", "id": "ee35cd73", "metadata": {}, "source": [ "对于没有返回值的函数,默认返回`None`对象。" ] }, { "cell_type": "code", "execution_count": 143, "id": "e884ed7d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x:\n", "3\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit\n", "def foo():\n", " x = 3\n", " print(\"x:\", x)\n", "\n", "res = foo()\n", "assert res is None" ] }, { "cell_type": "markdown", "id": "518306a3", "metadata": {}, "source": [ "如下面例子,`None`作为顶图的默认入参。" ] }, { "cell_type": "code", "execution_count": 144, "id": "83c20151", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y is None\n", "x:\n", "[1, 2]\n" ] } ], "source": [ "import mindspore as ms\n", "\n", "@ms.jit\n", "def foo(x, y=None):\n", " if y is not None:\n", " print(\"y:\", y)\n", " else:\n", " print(\"y is None\")\n", " print(\"x:\", x)\n", " return y\n", "\n", "x = [1, 2]\n", "res = foo(x)\n", "assert res is None" ] }, { "cell_type": "markdown", "id": "d8903185", "metadata": {}, "source": [ "### 内置函数支持更多数据类型\n", "\n", "扩展内置函数的支持范围。Python内置函数完善支持更多输入类型,例如第三方库数据类型。\n", "\n", "例如下面的例子,`x.asnumpy()`和`np.ndarray`均是扩展支持的类型。更多内置函数的支持情况可见[Python内置函数](https://www.mindspore.cn/docs/zh-CN/r2.3.0/note/static_graph_syntax/python_builtin_functions.html)章节。" ] }, { "cell_type": "code", "execution_count": 145, "id": "539339ca", "metadata": {}, "outputs": [ ], "source": [ "import numpy as np\n", "import mindspore as ms\n", "import mindspore.nn as nn\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "\n", "class Net(nn.Cell):\n", " def construct(self, x):\n", " return isinstance(x.asnumpy(), np.ndarray)\n", "\n", "x = ms.Tensor(np.array([-1, 2, 4]))\n", "net = Net()\n", "out = net(x)\n", "assert out" ] }, { "cell_type": "markdown", "id": "94fb0c44", "metadata": {}, "source": [ "### 支持控制流\n", "\n", "为了提高Python标准语法支持度,实现动静统一,扩展支持更多数据类型在控制流语句的使用。控制流语句是指`if`、`for`、`while`等流程控制语句。理论上,通过扩展支持的语法,在控制流场景中也支持。代码用例如下:" ] }, { "cell_type": "code", "execution_count": 146, "id": "22cbbf77", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "res: 2\n" ] } ], "source": [ "import numpy as np\n", "import mindspore as ms\n", "\n", "@ms.jit\n", "def func():\n", " x = np.array(1)\n", " if x <= 1:\n", " x += 1\n", " return ms.Tensor(x)\n", "\n", "res = func()\n", "print(\"res: \", res)" ] }, { "cell_type": "markdown", "id": "cd526a08", "metadata": {}, "source": [ "### 支持属性设置与修改\n", "\n", "具体使用场景如下:\n", "\n", "- 对自定义类对象以及第三方类型的属性进行设置与修改。\n", "\n", " 图模式下支持对自定义类对象的属性进行设置与修改,例如:" ] }, { "cell_type": "code", "execution_count": 147, "id": "e1d4b665", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "obj.x is: 100\n" ] } ], "source": [ "from mindspore import jit\n", "\n", "class AssignClass():\n", " def __init__(self):\n", " self.x = 1\n", "\n", "obj = AssignClass()\n", "\n", "@jit\n", "def foo():\n", " obj.x = 100\n", "\n", "foo()\n", "print(f\"obj.x is: {obj.x}\")" ] }, { "cell_type": "markdown", "id": "eb441e21", "metadata": {}, "source": [ " 图模式下支持对第三方库对象的属性进行设置与修改,例如:\n", "\n" ] }, { "cell_type": "code", "execution_count": 148, "id": "70e5ef89", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape is (2, 2)\n" ] } ], "source": [ "from mindspore import jit\n", "import numpy as np\n", "\n", "@jit\n", "def foo():\n", " a = np.array([1, 2, 3, 4])\n", " a.shape = (2, 2)\n", " return a.shape\n", "\n", "shape = foo()\n", "print(f\"shape is {shape}\")\n", "\n" ] }, { "cell_type": "markdown", "id": "aa48654a", "metadata": {}, "source": [ "- 对Cell的self对象进行修改,例如:\n", "\n" ] }, { "cell_type": "code", "execution_count": 148, "id": "70e5ef89", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "net.m is 3\n" ] } ], "source": [ "import mindspore as ms\n", "from mindspore import nn, set_context\n", "set_context(mode=ms.GRAPH_MODE)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.m = 2\n", "\n", " def construct(self):\n", " self.m = 3\n", " return 0\n", "\n", "net = Net()\n", "net()\n", "print(f\"net.m is {net.m}\")\n", "\n" ] }, { "cell_type": "markdown", "id": "88157a32", "metadata": {}, "source": [ " 注意,self对象支持属性修改和设置。若`__init__`内没有定义某个属性,对齐PYNATIVE模式,图模式也允许设置此属性。例如:\n", "\n" ] }, { "cell_type": "code", "execution_count": 149, "id": "b8a58e03", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "net.m2 is 3\n" ] } ], "source": [ "import mindspore as ms\n", "from mindspore import nn, set_context\n", "set_context(mode=ms.GRAPH_MODE)\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super().__init__()\n", " self.m = 2\n", "\n", " def construct(self):\n", " self.m2 = 3\n", " return 0\n", "\n", "net = Net()\n", "net()\n", "print(f\"net.m2 is {net.m2}\")" ] }, { "cell_type": "markdown", "id": "4f23b346", "metadata": {}, "source": [ "- 对静态图内的Cell对象以及jit_class对象进行设置与修改。\n", "\n", " 支持对图模式jit_class对象进行属性修改,例如:\n", "\n" ] }, { "cell_type": "code", "execution_count": 150, "id": "c80302dd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "net.inner.x is 100\n" ] } ], "source": [ "import mindspore as ms\n", "from mindspore import nn, set_context, jit_class\n", "set_context(mode=ms.GRAPH_MODE)\n", "\n", "@jit_class\n", "class InnerClass():\n", " def __init__(self):\n", " self.x = 10\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.inner = InnerClass()\n", "\n", " def construct(self):\n", " self.inner.x = 100\n", " return 0\n", "\n", "net = Net()\n", "net()\n", "print(f\"net.inner.x is {net.inner.x}\")\n" ] }, { "cell_type": "markdown", "id": "2df38d32", "metadata": {}, "source": [ "### 支持求导\n", "\n", "扩展支持的静态图语法,同样支持其在求导中使用,例如:" ] }, { "cell_type": "code", "execution_count": 151, "id": "ed624aff", "metadata": {}, "outputs": [], "source": [ "import mindspore as ms\n", "from mindspore import ops, set_context\n", "set_context(mode=ms.GRAPH_MODE)\n", "\n", "@ms.jit\n", "def dict_net(a):\n", " x = {'a': a, 'b': 2}\n", " return a, (x, (1, 2))\n", "\n", "out = ops.grad(dict_net)(ms.Tensor([1]))\n", "assert out == 2" ] }, { "cell_type": "markdown", "id": "62e773d5", "metadata": {}, "source": [ "### Annotation Type\n", "\n", "对于运行时的扩展支持的语法,会产生一些无法被类型推导出的节点,比如动态创建Tensor等。这种类型称为`Any`类型。因为该类型无法在编译时推导出正确的类型,所以这种`Any`将会以一种默认最大精度`float64`进行运算,防止其精度丢失。为了能更好的优化相关性能,需要减少`Any`类型数据的产生。当用户可以明确知道当前通过扩展支持的语句会产生具体类型的时候,我们推荐使用`Annotation @jit.typing:`的方式进行指定对应Python语句类型,从而确定解释节点的类型避免`Any`类型的生成。\n", "\n", "例如,[Tensor](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/mindspore/mindspore.Tensor.html#mindspore.Tensor)类和[tensor](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/mindspore/mindspore.tensor.html#mindspore.tensor)接口的区别就在于在`tensor`接口内部运用了Annotation Type机制。当`tensor`函数的`dtype`确定时,函数内部会利用`Annotation`指定输出类型从而避免`Any`类型的产生。`Annotation Type`的使用只需要在对应Python语句上面或者后面加上注释 `# @jit.typing: () -> tensor_type[float32]` 即可,其中 `->` 后面的 `tensor_type[float32]` 指示了被注释的语句输出类型。\n", "\n", "代码用例如下。" ] }, { "cell_type": "code", "execution_count": 152, "id": "faed046a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "y1 value is 2.0, dtype is Float32\n", "y2 value is 2.0, dtype is Float32\n", "y3 value is 2.0, dtype is Float64\n", "y4 value is 2.0, dtype is Float32\n" ] } ], "source": [ "import mindspore as ms\n", "import mindspore.nn as nn\n", "from mindspore import ops, Tensor\n", "\n", "class Net(nn.Cell):\n", " def __init__(self):\n", " super(Net, self).__init__()\n", " self.abs = ops.Abs()\n", "\n", " @ms.jit\n", " def construct(self, x, y):\n", " y1 = ms.tensor(x.asnumpy() + y.asnumpy(), dtype=ms.float32)\n", " y2 = ms.Tensor(x.asnumpy() + y.asnumpy(), dtype=ms.float32) # @jit.typing: () -> tensor_type[float32]\n", " y3 = Tensor(x.asnumpy() + y.asnumpy())\n", " y4 = Tensor(x.asnumpy() + y.asnumpy(), dtype=ms.float32)\n", " return self.abs(y1), self.abs(y2), self.abs(y3), self.abs(y4)\n", "\n", "ms.set_context(mode=ms.GRAPH_MODE)\n", "net = Net()\n", "x = ms.Tensor(-1, dtype=ms.int32)\n", "y = ms.Tensor(-1, dtype=ms.float32)\n", "y1, y2, y3, y4 = net(x, y)\n", "\n", "print(f\"y1 value is {y1}, dtype is {y1.dtype}\")\n", "print(f\"y2 value is {y2}, dtype is {y2.dtype}\")\n", "print(f\"y3 value is {y3}, dtype is {y3.dtype}\")\n", "print(f\"y4 value is {y4}, dtype is {y4.dtype}\")" ] }, { "cell_type": "markdown", "id": "7c94fb01", "metadata": {}, "source": [ "上述例子,可以看到创建了`Tensor`的相关区别。对于`y3`、`y4`,因为`Tensor`类没有增加`Annotation`指示,`y3`、`y4`没有办法推出正确的类型,导致只能按照最高精度`float64`进行运算。\n", "对于`y2`,由于创建`Tensor`时,通过`Annotation`指定了对应类型,使得其类型可以按照指定类型进行运算。\n", "对于`y1`,由于使用了`tensor`函数接口创建`Tensor`,传入的`dtype`参数作为`Annotation`的指定类型,所以也避免了`Any`类型的产生。\n", "\n", "## 扩展语法的语法约束\n", "\n", "在使用静态图扩展支持语法时,请注意以下几点:\n", "\n", "1. 对标动态图的支持能力,即:须在动态图语法范围内,包括但不限于数据类型等。\n", "\n", "2. 在扩展静态图语法时,支持了更多的语法,但执行性能可能会受影响,不是最佳。\n", "\n", "3. 在扩展静态图语法时,支持了更多的语法,由于使用Python的能力,不能使用MindIR导入导出的能力。\n", "\n", "4. 暂不支持跨Python文件重复定义同名的全局变量,且这些全局变量在网络中会被用到。" ] } ], "metadata": { "kernelspec": { "display_name": "MindSpore", "language": "python", "name": "mindspore" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }