静态图

https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.svg https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.svg 查看源文件

概述

在Graph模式下,Python代码并不是由Python解释器去执行,而是将代码编译成静态计算图,然后执行静态计算图。

在静态图模式下,MindSpore通过源码转换的方式,将Python的源码转换成中间表达IR(Intermediate Representation),并在此基础上对IR图进行优化,最终在硬件设备上执行优化后的图。MindSpore使用基于图表示的函数式IR,称为MindIR,详情可参考中间表示MindIR

MindSpore的静态图执行过程实际包含两步,对应静态图的Define和Run阶段,但在实际使用中,在实例化的Cell对象被调用时用户并不会分别感知到这两阶段,MindSpore将两阶段均封装在Cell的__call__方法中,因此实际调用过程为:

model(inputs) = model.compile(inputs) + model.construct(inputs),其中model为实例化Cell对象。

使用Graph模式需要设置ms.set_context(mode=ms.GRAPH_MODE),使用Cell类并且在construct函数中编写执行代码,此时construct函数的代码将会被编译成静态计算图。Cell定义详见Cell API文档

由于语法解析的限制,当前在编译构图时,支持的数据类型、语法以及相关操作并没有完全与Python语法保持一致,部分使用受限。借鉴传统JIT编译的思路,从图模式的角度考虑动静图的统一,扩展图模式的语法能力,使得静态图提供接近动态图的语法使用体验,从而实现动静统一。为了便于用户选择是否扩展静态图语法,提供了JIT语法支持级别选项jit_syntax_level,其值必须在[STRICT,LAX]范围内,选择STRICT则认为使用基础语法,不扩展静态图语法。默认值为LAX,更多请参考本文的扩展语法(LAX级别)章节。全部级别都支持所有后端。

  • STRICT: 仅支持基础语法,且执行性能最佳。可用于MindIR导入导出。

  • LAX: 支持更多复杂语法,最大程度地兼容Python所有语法。由于存在可能无法导出的语法,不能用于MindIR导入导出。

本文主要介绍,在编译静态图时,支持的数据类型、语法以及相关操作,这些规则仅适用于Graph模式。

基础语法(STRICT级别)

静态图内的常量与变量

在静态图中,常量与变量是理解静态图语法的一个重要概念,很多语法在常量输入和变量输入情况下支持的方法与程度是不同的。因此,在介绍静态图具体支持的语法之前,本小节先会对静态图中常量与变量的概念进行说明。

在静态图模式下,一段程序的运行会被分为编译期以及执行期。 在编译期,程序会被编译成一张中间表示图,并且程序不会真正的执行,而是通过抽象推导的方式对中间表示进行静态解析。这使得在编译期时,我们无法保证能获取到所有中间表示中节点的值。 常量和变量也就是通过能否能在编译器获取到其真实值来区分的。

  • 常量: 编译期内可以获取到值的量。

  • 变量: 编译期内无法获取到值的量。

常量产生场景

  • 作为图模式输入的标量,列表以及元组均为常量(在不使用mutable接口的情况下)。例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    a = 1
    b = [Tensor([1]), Tensor([2])]
    c = ["a", "b", "c"]
    
    class Net(nn.Cell):
       def construct(self, a, b, c):
          return a, b, c
    

    上述代码中,输入abc均为常量。

  • 图模式内生成的标量或者Tensor为常量。例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          a = 1
          b = "2"
          c = Tensor([1, 2, 3])
          return a, b, c
    

    上述代码中, abc均为常量。

  • 常量运算得到的结果为常量。例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          a = Tensor([1, 2, 3])
          b = Tensor([1, 1, 1])
          c = a + b
          return c
    

    上述代码中,ab均为图模式内产生的Tensor为常量,因此其计算得到的结果也是常量。但如果其中之一为变量时,其返回值也会为变量。

变量产生场景

  • 所有mutable接口的返回值均为变量(无论是在图外使用mutable还是在图内使用)。例如:

    from mindspore import mutable
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    a = mutable([Tensor([1]), Tensor([2])])
    
    class Net(nn.Cell):
       def construct(self, a):
          b = mutable(Tensor([3]))
          c = mutable((Tensor([1]), Tensor([2])))
          return a, b, c
    

    上述代码中,a是在图外调用mutable接口的,bc是在图内调用mutable接口生成的,abc均为变量。

  • 作为静态图的输入的Tensor都是变量。例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    a = Tensor([1])
    b = (Tensor([1]), Tensor([2]))
    
    class Net(nn.Cell):
       def construct(self, a, b):
          return a, b
    

    上述代码中,a是作为图模式输入的Tensor,因此其为变量。但b是作为图模式输入的元组,非Tensor类型,即使其内部的元素均为Tensor,b也是常量。

  • 通过变量计算得到的是变量。

    如果一个量是算子的输出,那么其多数情况下为变量。例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    a = Tensor([1])
    b = Tensor([2])
    
    class Net(nn.Cell):
       def construct(self, a, b):
          c = a + b
          return c
    

    在这种情况下,cab计算来的结果,且用来计算的输入ab均为变量,因此c也是变量。

数据类型

Python内置数据类型

当前支持的Python内置数据类型包括:NumberStringListTupleDictionary

Number

支持int(整型)、float(浮点型)、bool(布尔类型),不支持complex(复数)。

支持在网络里定义Number,即支持语法:y = 1y = 1.2y = True

当数据为常量时,编译时期可以获取到数值,在网络中可以支持强转Number的语法:y = int(x)y = float(x)y = bool(x)。 当数据为变量时,即需要在运行时期才可以获取到数值,也支持使用int(),float(),bool()等内置函数Python内置函数进行数据类型的转换。例如:

import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import Tensor

context.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self, x):
      out1 = int(11.1)
      out2 = int(Tensor([10]))
      out3 = int(x.asnumpy())
      return out1, out2, out3

net = Net()
res = net(Tensor(2))
print("res[0]:", res[0])
print("res[1]:", res[1])
print("res[2]:", res[2])

运行结果如下:

res[0]: 11
res[1]: 10
res[2]: 2

支持返回Number类型。例如:

import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import Tensor

context.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self, x, y):
      return x + y

net = Net()
res = net(ms.mutable(1), ms.mutable(2))
print(res)

运行结果如下:

3
String

支持在网络里构造String,即支持使用引号('")来创建字符串,如x = 'abcd'y = "efgh"。可以通过str()的方式进行将常量转换成字符串。支持对字符串连接,截取,以及使用成员运算符(innot in)判断字符串是否包含指定的字符。支持格式化字符串的输出,将一个值插入到一个有字符串格式符%s的字符串中。支持在常量场景下使用格式化字符串函数str.format()

例如:

import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import Tensor

context.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self):
      var1 = 'Hello!'
      var2 = "MindSpore"
      var3 = str(123)
      var4 = "{} is {}".format("string", var3)
      return var1[0], var2[4:9], var1 + var2, var2 * 2, "H" in var1, "My name is %s!" % var2, var4

net = Net()
res = net()
print("res:", res)

运行结果如下:

res: ('H', 'Spore', 'Hello!MindSpore', 'MindSporeMindSpore', True, 'My name is MindSpore!', 'string is 123')
List

JIT_SYNTAX_LEVEL设置为LAX的情况下,静态图模式可以支持部分List对象的inplace操作,具体介绍详见支持列表就地修改操作章节。

List的基础使用场景如下:

  • 图模式支持图内创建List

    支持在图模式内创建List对象,且List内对象的元素可以包含任意图模式支持的类型,也支持多层嵌套。例如:

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          a = [1, 2, 3, 4]
          b = ["1", "2", "a"]
          c = [ms.Tensor([1]), ms.Tensor([2])]
          d = [a, b, c, (4, 5)]
          return d
    

    上述示例代码中,所有的List对象都可以被正常的创建。

  • 图模式支持返回List

    在MindSpore2.0版本之前,当图模式返回List 对象时,List会被转换为Tuple。MindSpore2.0版本已经可以支持返回List对象。例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          a = [1, 2, 3, 4]
          return a
    
    net = Net()
    output = net()  # output: [1, 2, 3, 4]
    

    与图模式内创建List 相同,图模式返回List对象可以包括任意图模式支持的类型,也支持多层嵌套。

  • 图模式支持从全局变量中获取List对象。

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    global_list = [1, 2, 3, 4]
    
    class Net(nn.Cell):
       def construct(self):
          global_list.reverse()
          return global_list
    
    net = Net()
    output = net()  # output: [4, 3, 2, 1]
    

    需要注意的是,在基础场景下图模式返回的列表与全局变量的列表不是同一个对象,当JIT_SYNTAX_LEVEL设置为LAX时,返回的对象与全局对象为统一对象。

  • 图模式支持以List作为输入。

    图模式支持List作为静态图的输入,作为输入的List对象的元素必须为图模式支持的输入类型,也支持多层嵌套。

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    list_input = [1, 2, 3, 4]
    
    class Net(nn.Cell):
       def construct(self, x):
          return x
    
    net = Net()
    output = net(list_input)  # output: [1, 2, 3, 4]
    

    需要注意的是,List作为静态图输入时,无论其内部的元素是什么类型,一律被视为常量。

  • 图模式支持List的内置方法。

    List 内置方法的详细介绍如下:

    • List索引取值

      基础语法:element = list_object[index]

      基础语义:将List对象中位于第index位的元素提取出来(index从0开始)。支持多层索引取值。

      索引值index支持类型包括intTensorslice。其中,int以及Tensor类型的输入可以支持常量以及变量,slice内部数据必须为编译时能够确定的常量。

      示例如下:

      import mindspore as ms
      from mindspore import nn
      from mindspore import context
      from mindspore import Tensor
      
      context.set_context(mode=ms.GRAPH_MODE)
      
      class Net(nn.Cell):
         def construct(self):
            x = [[1, 2], 3, 4]
            a = x[0]
            b = x[0][ms.Tensor([1])]
            c = x[1:3:1]
            return a, b, c
      
      net = Net()
      a, b, c = net()
      print('a:{}'.format(a))
      print('b:{}'.format(b))
      print('c:{}'.format(c))
      

      运行结果如下:

      a:[1, 2]
      b:2
      c:[3, 4]
      
    • List索引赋值

      基础语法:list_object[index] = target_element

      基础语义:将List对象中位于第index位的元素赋值为 target_elementindex从0开始)。支持多层索引赋值。

      索引值index支持类型包括intTensorslice。其中,int 以及Tensor类型的输入可以支持常量以及变量,slice内部数据必须为编译时能够确定的常量。

      索引赋值对象target_element支持所有图模式支持的数据类型。

      目前,List索引赋值不支持inplace操作, 索引赋值后将会生成一个新的对象。该操作后续将会支持inplace操作。

      示例如下:

      import mindspore as ms
      from mindspore import nn
      from mindspore import context
      from mindspore import Tensor
      
      context.set_context(mode=ms.GRAPH_MODE)
      
      
      class Net(nn.Cell):
         def construct(self):
            x = [[0, 1], 2, 3, 4]
            x[1] = 10
            x[2] = "ok"
            x[3] = (1, 2, 3)
            x[0][1] = 88
            return x
      
      net = Net()
      output = net()
      print('output:{}'.format(output))
      

      运行结果如下:

      output:[[0, 88], 10, 'ok', (1, 2, 3)]
      
    • List.append

      基础语法:list_object.append(target_element)

      基础语义:向List对象list_object的最后追加元素target_element

      目前,List.append不支持inplace操作, 追加元素后将会生成一个新的对象。该操作后续将会支持inplace操作。

      示例如下:

      import mindspore as ms
      from mindspore import nn
      from mindspore import context
      from mindspore import Tensor
      
      context.set_context(mode=ms.GRAPH_MODE)
      
      class Net(nn.Cell):
         def construct(self):
            x = [1, 2, 3]
            x.append(4)
            return x
      
      net = Net()
      x = net()
      print('x:{}'.format(x))
      

      运行结果如下:

      x:[1, 2, 3, 4]
      
    • List.clear

      基础语法:list_object.clear()

      基础语义:清空List对象 list_object中包含的元素。

      目前,List.clear不支持inplace, 清空元素后将会生成一个新的对象。该操作后续将会支持inplace。

      示例如下:

      import mindspore as ms
      from mindspore import nn
      from mindspore import context
      from mindspore import Tensor
      
      context.set_context(mode=ms.GRAPH_MODE)
      
      class Net(nn.Cell):
         def construct(self):
            x = [1, 3, 4]
            x.clear()
            return x
      
      net = Net()
      x = net()
      print('x:{}'.format(x))
      

      运行结果如下:

      x:[]
      
    • List.extend

      基础语法:list_object.extend(target)

      基础语义:向List对象list_object的最后依次插入target内的所有元素。

      target支持的类型为TupleList以及Tensor。其中,如果target类型为Tensor的情况下,会先将该Tensor转换为List,再进行插入操作。

      示例如下:

      import mindspore as ms
      from mindspore import nn
      from mindspore import context
      from mindspore import Tensor
      
      context.set_context(mode=ms.GRAPH_MODE)
      
      class Net(nn.Cell):
         def construct(self):
            x1 = [1, 2, 3]
            x1.extend((4, "a"))
            x2 = [1, 2, 3]
            x2.extend(ms.Tensor([4, 5]))
            return x1, x2
      
      net = Net()
      output1, output2 = net()
      print('output1:{}'.format(output1))
      print('output2:{}'.format(output2))
      

      运行结果如下:

      output1:[1, 2, 3, 4, 'a']
      output2:[1, 2, 3, Tensor(shape=[], dtype=Int64, value= 4), Tensor(shape=[], dtype=Int64, value= 5)]
      
    • List.pop

      基础语法:pop_element = list_object.pop(index=-1)

      基础语义:将List对象list_object 的第index个元素从list_object中删除,并返回该元素。

      index 要求必须为常量int, 当list_object的长度为list_obj_size时,index的取值范围为:[-list_obj_size,list_obj_size-1]index为负数,代表从后往前的位数。当没有输入index时,默认值为-1,即删除最后一个元素。

      import mindspore as ms
      from mindspore import nn
      from mindspore import context
      from mindspore import Tensor
      
      context.set_context(mode=ms.GRAPH_MODE)
      
      class Net(nn.Cell):
         def construct(self):
            x = [1, 2, 3]
            b = x.pop()
            return b, x
      
      net = Net()
      pop_element, res_list = net()
      print('pop_element:{}'.format(pop_element))
      print('res_list:{}'.format(res_list))
      

      运行结果如下:

      pop_element:3
      res_list:[1, 2]
      
    • List.reverse

      基础语法:list_object.reverse()

      基础语义:将List对象list_object的元素顺序倒转。

      示例如下:

      import mindspore as ms
      from mindspore import nn
      from mindspore import context
      from mindspore import Tensor
      
      context.set_context(mode=ms.GRAPH_MODE)
      
      class Net(nn.Cell):
         def construct(self):
            x = [1, 2, 3]
            x.reverse()
            return x
      
      net = Net()
      output = net()
      print('output:{}'.format(output))
      

      运行结果如下:

      output:[3, 2, 1]
      
    • List.insert

      基础语法:list_object.insert(index, target_obj)

      基础语义:将target_obj插入到list_object的第index位。

      index要求必须为常量int。如果list_object的长度为list_obj_size。当index < -list_obj_size时,插入到List的第一位。当index >= list_obj_size时,插入到List的最后。index为负数代表从后往前的位数。

      示例如下:

      import mindspore as ms
      from mindspore import nn
      from mindspore import context
      from mindspore import Tensor
      
      context.set_context(mode=ms.GRAPH_MODE)
      
      class Net(nn.Cell):
         def construct(self):
            x = [1, 2, 3]
            x.insert(3, 4)
            return x
      
      net = Net()
      output = net()
      print('output:{}'.format(output))
      

      运行结果如下:

      output:[1, 2, 3, 4]
      
Tuple

支持在网络里构造元组Tuple,使用小括号包含元素,即支持语法y = (1, 2, 3)。元组Tuple的元素不能修改,但支持索引访问元组Tuple中的元素,支持对元组进行连接组合。

  • 支持索引取值。

    支持使用方括号加下标索引的形式来访问元组Tuple中的元素,索引值支持intsliceTensor,也支持多层索引取值,即支持语法data = tuple_x[index0][index1]...

    索引值为Tensor有如下限制:

    • Tuple里存放的都是Cell,每个Cell要在Tuple定义之前完成定义,每个Cell的入参个数、入参类型和入参shape要求一致,每个Cell的输出个数、输出类型和输出shape也要求一致。

    • 索引Tensor是一个dtypeint32的标量Tensor,取值范围在[-tuple_len, tuple_len)

    • 支持CPUGPUAscend后端。

    intslice索引示例如下:

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    t = ms.Tensor(np.array([1, 2, 3]))
    
    class Net(nn.Cell):
       def construct(self):
          x = (1, (2, 3, 4), 3, 4, t)
          y = x[1][1]
          z = x[4]
          m = x[1:4]
          n = x[-4]
          return y, z, m, n
    
    net = Net()
    y, z, m, n = net()
    print('y:{}'.format(y))
    print('z:{}'.format(z))
    print('m:{}'.format(m))
    print('n:{}'.format(n))
    

    运行结果如下:

    y:3
    z:[1 2 3]
    m:((2, 3, 4), 3, 4)
    n:(2, 3, 4)
    

    Tensor索引示例如下:

    import mindspore as ms
    from mindspore import nn, set_context
    
    set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def __init__(self):
          super(Net, self).__init__()
          self.relu = nn.ReLU()
          self.softmax = nn.Softmax()
          self.layers = (self.relu, self.softmax)
    
       def construct(self, x, index):
          ret = self.layers[index](x)
          return ret
    
    x = ms.Tensor([-1.0], ms.float32)
    
    net = Net()
    ret = net(x, 0)
    print('ret:{}'.format(ret))
    

    运行结果如下:

    ret:[0.]
    
  • 支持连接组合。

    与字符串String类似,元组支持使用+*进行组合,得到一个新的元组Tuple,例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          x = (1, 2, 3)
          y = (4, 5, 6)
          return x + y, x * 2
    
    net = Net()
    out1, out2 = net()
    print('out1:{}'.format(out1))
    print('out2:{}'.format(out2))
    

    运行结果如下:

    out1:(1, 2, 3, 4, 5, 6)
    out2:(1, 2, 3, 1, 2, 3)
    
Dictionary

支持在网络里构造字典Dictionary,每个键值key:value用冒号:分割,每个键值对之间用逗号,分割,整个字典使用大括号{}包含键值对,即支持语法y = {"a": 1, "b": 2}

key是唯一的,如果字典中存在多个相同的key,则重复的key以最后一个作为最终结果;而值value可以不是唯一的。键key需要保证是不可变的。当前键key支持StringNumber、常量Tensor以及只包含这些类型对象的Tuple;值value支持NumberTupleTensorListDictionaryNone

  • 支持接口。

    keys:取出dict里所有的key值,组成Tuple返回。

    values:取出dict里所有的value值,组成Tuple返回。

    items:取出dict里每一对keyvalue组成的Tuple,最终组成List返回。

    getdict.get(key[, value])返回指定key对应的value值,如果指定key不存在,返回默认值None或者设置的默认值value

    clear:删除dict里所有的元素。

    has_keydict.has_key(key)判断dict里是否存在指定key

    updatedict1.update(dict2)dict2中的元素更新到dict1中。

    fromkeysdict.fromkeys(seq([, value]))用于创建新的Dictionary,以序列seq中的元素做Dictionarykeyvalue为所有key对应的初始值。

    示例如下,其中返回值中的xnew_dict是一个Dictionary,在图模式JIT语法支持级别选项为LAX下扩展支持,更多Dictionary的高阶使用请参考本文的支持Dictionary的高阶用法章节。

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    x = {"a": ms.Tensor(np.array([1, 2, 3])), "b": ms.Tensor(np.array([4, 5, 6])), "c": ms.Tensor(np.array([7, 8, 9]))}
    
    class Net(nn.Cell):
       def construct(self):
          x_keys = x.keys()
          x_values = x.values()
          x_items = x.items()
          value_a = x.get("a")
          check_key = x.has_key("a")
          y = {"a": ms.Tensor(np.array([0, 0, 0]))}
          x.update(y)
          new_dict = x.fromkeys("abcd", 123)
          return x_keys, x_values, x_items, value_a, check_key, x, new_dict
    
    net = Net()
    x_keys, x_values, x_items, value_a, check_key, new_x, new_dict = net()
    print('x_keys:{}'.format(x_keys))
    print('x_values:{}'.format(x_values))
    print('x_items:{}'.format(x_items))
    print('value_a:{}'.format(value_a))
    print('check_key:{}'.format(check_key))
    print('new_x:{}'.format(new_x))
    print('new_dict:{}'.format(new_dict))
    

    运行结果如下:

    x_keys:('a', 'b', 'c')
    x_values:(Tensor(shape=[3], dtype=Int64, value= [1, 2, 3]), Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]), Tensor(shape=[3], dtype=Int64, value= [7, 8, 9]))
    x_items:[('a', Tensor(shape=[3], dtype=Int64, value= [1, 2, 3])), ('b', Tensor(shape=[3], dtype=Int64, value= [4, 5, 6])), ('c', Tensor(shape=[3], dtype=Int64, value= [7, 8, 9]))]
    value_a:[1 2 3]
    check_key:True
    new_x:{'a': Tensor(shape=[3], dtype=Int64, value= [0, 0, 0]), 'b': Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]), 'c': Tensor(shape=[3], dtype=Int64, value= [7, 8, 9])}
    new_dict:{'a': 123, 'b': 123, 'c': 123, 'd': 123}
    

MindSpore自定义数据类型

当前MindSpore自定义数据类型包括:TensorPrimitiveCellParameter

Tensor

Tensor的属性与接口详见Tensor API文档

支持在静态图模式下创建和使用Tensor。创建方式有使用tensor函数接口和使用Tensor类接口。推荐使用tensor函数接口,用户可以使用指定所需要的dtype类型。代码用例如下。

import mindspore as ms
import mindspore.nn as nn

class Net(nn.Cell):
   def __init__(self):
      super(Net, self).__init__()

   def construct(self, x):
      return ms.tensor(x.asnumpy(), dtype=ms.float32)

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
x = ms.Tensor(1, dtype=ms.int32)
print(net(x))

运行结果如下:

1.0
Primitive

当前支持在construct里构造Primitive及其子类的实例。

示例如下:

import mindspore as ms
from mindspore import nn, ops, Tensor, set_context
import numpy as np

set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def __init__(self):
      super(Net, self).__init__()

   def construct(self, x):
      reduce_sum = ops.ReduceSum(True) #支持在construct里构造`Primitive`及其子类的实例
      ret = reduce_sum(x, axis=2)
      return ret

x = Tensor(np.random.randn(3, 4, 5, 6).astype(np.float32))
net = Net()
ret = net(x)
print('ret.shape:{}'.format(ret.shape))

运行结果如下:

ret.shape:(3, 4, 1, 6)

当前不支持在网络调用Primitive及其子类相关属性和接口。

当前已定义的Primitive详见Primitive API文档

Cell

当前支持在网络里构造Cell及其子类的实例,即支持语法cell = Cell(args...)

但在调用时,参数只能通过位置参数方式传入,不支持通过键值对方式传入,即不支持在语法cell = Cell(arg_name=value)

当前不支持在网络调用Cell及其子类相关属性和接口,除非是在Cell自己的construct中通过self调用。

Cell定义详见Cell API文档

Parameter

Parameter是变量张量,代表在训练网络时,需要被更新的参数。

Parameter的定义和使用详见Parameter API文档

运算符

算术运算符和赋值运算符支持NumberTensor运算,也支持不同dtypeTensor运算。详见运算符

原型

原型代表编程语言中最紧密绑定的操作。

属性引用与修改

属性引用是后面带有一个句点加一个名称的原型。

在MindSpore的Cell 实例中使用属性引用作为左值需满足如下要求:

  • 被修改的属性属于本cell对象,即必须为self.xxx

  • 该属性在Cell的__init__函数中完成初始化且其为Parameter类型。

在JIT语法支持级别选项为LAX时,可以支持更多情况的属性修改,具体详见支持属性设置与修改

示例如下:

import mindspore as ms
from mindspore import nn, set_context

set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def __init__(self):
      super().__init__()
      self.weight = ms.Parameter(ms.Tensor(3, ms.float32), name="w")
      self.m = 2

   def construct(self, x, y):
      self.weight = x  # 满足条件可以修改
      # self.m = 3     # self.m 非Parameter类型禁止修改
      # y.weight = x   # y不是self,禁止修改
      return x

net = Net()
ret = net(1, 2)
print('ret:{}'.format(ret))

运行结果如下:

ret:1

索引取值

对序列TupleListDictionaryTensor的索引取值操作(Python称为抽取)。

Tuple的索引取值请参考本文的Tuple章节。

List的索引取值请参考本文的List章节。

Dictionary的索引取值请参考本文的Dictionary章节。

Tensor的索引取详见Tensor 索引取值文档

调用

所谓调用就是附带可能为空的一系列参数来执行一个可调用对象(例如:CellPrimitive)。

示例如下:

import mindspore as ms
from mindspore import nn, ops, set_context
import numpy as np

set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def __init__(self):
      super().__init__()
      self.matmul = ops.MatMul()

   def construct(self, x, y):
      out = self.matmul(x, y)  # Primitive调用
      return out

x = ms.Tensor(np.ones(shape=[1, 3]), ms.float32)
y = ms.Tensor(np.ones(shape=[3, 4]), ms.float32)
net = Net()
ret = net(x, y)
print('ret:{}'.format(ret))

运行结果如下:

ret:[[3. 3. 3. 3.]]

语句

当前静态图模式支持部分Python语句,包括raise语句、assert语句、pass语句、return语句、break语句、continue语句、if语句、for语句、while语句、with语句、列表生成式、生成器表达式、函数定义语句等,详见Python语句

Python内置函数

当前静态图模式支持部分Python内置函数,其使用方法与对应的Python内置函数类似,详见Python内置函数

网络定义

网络入参

在对整网入参求梯度的时候,会忽略非Tensor的入参,只计算Tensor入参的梯度。

示例如下。整网入参(x, y, z)中,xzTensory是非Tensor。因此,grad_net在对整网入参(x, y, z)求梯度的时候,会自动忽略y的梯度,只计算xz的梯度,返回(grad_x, grad_z)

import mindspore as ms
from mindspore import nn

ms.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def __init__(self):
      super(Net, self).__init__()

   def construct(self, x, y, z):
      return x + y + z

class GradNet(nn.Cell):
   def __init__(self, net):
      super(GradNet, self).__init__()
      self.forward_net = net

   def construct(self, x, y, z):
      return ms.grad(self.forward_net, grad_position=(0, 1, 2))(x, y, z)

input_x = ms.Tensor([1])
input_y = 2
input_z = ms.Tensor([3])

net = Net()
grad_net = GradNet(net)
ret = grad_net(input_x, input_y, input_z)
print('ret:{}'.format(ret))

运行结果如下:

ret:(Tensor(shape=[1], dtype=Int64, value= [1]), Tensor(shape=[1], dtype=Int64, value= [1]))

基础语法的语法约束

图模式下的执行图是从源码转换而来,并不是所有的Python语法都能支持。下面介绍在基础语法下存在的一些语法约束。更多网络编译问题可见网络编译

  1. construct函数里,使用未定义的类成员时,将抛出AttributeError异常。示例如下:

    import mindspore as ms
    from mindspore import nn, set_context
    
    set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def __init__(self):
          super(Net, self).__init__()
    
       def construct(self, x):
          return x + self.y
    
    net = Net()
    net(1)
    

    结果报错如下:

    AttributeError: External object has no attribute y
    
  2. nn.Cell不支持classmethod修饰的类方法。示例如下:

    import mindspore as ms
    
    ms.set_context(mode=ms.GRAPH_MODE)
    
    class Net(ms.nn.Cell):
       @classmethod
       def func(cls, x, y):
          return x + y
    
       def construct(self, x, y):
          return self.func(x, y)
    
    net = Net()
    out = net(ms.Tensor(1), ms.Tensor(2))
    print(out)
    

    结果报错如下:

    TypeError: The parameters number of the function is 3, but the number of provided arguments is 2.
    
  3. 在图模式下,有些Python语法难以转换成图模式下的中间表示MindIR。对标Python的关键字,存在部分关键字在图模式下是不支持的:AsyncFunctionDef、Delete、AnnAssign、AsyncFor、AsyncWith、Match、Try、Import、ImportFrom、Nonlocal、NamedExpr、Set、SetComp、Await、Yield、YieldFrom。如果在图模式下使用相关的语法,将会有相应的报错信息提醒用户。

    如果使用Try语句,示例如下:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self, x, y):
          global_out = 1
          try:
             global_out = x / y
          except ZeroDivisionError:
             print("division by zero, y is zero.")
          return global_out
    
    net = Net()
    test_try_except_out = net(1, 0)
    print("out:", test_try_except_out)
    

    结果报错如下:

    RuntimeError: Unsupported statement 'Try'.
    
  4. 对标Python内置数据类型,除去当前图模式下支持的Python内置数据类型,复数complex和集合set类型是不支持的。列表list和字典dictionary的一些高阶用法在基础语法场景下是不支持的,需要在JIT语法支持级别选项jit_syntax_levelLAX时才支持,更多请参考本文的扩展语法(LAX级别)章节。

  5. 对标Python的内置函数,在基础语法场景下,除去当前图模式下支持的Python内置函数,仍存在部分内置函数在图模式下是不支持的,例如:basestring、bin、bytearray、callable、chr、cmp、compile、 delattr、dir、divmod、eval、execfile、file、frozenset、hash、hex、id、input、issubclass、iter、locals、long、memoryview、next、object、oct、open、ord、property、raw_input、reduce、reload、repr、reverse、set、slice、sorted、unichr、unicode、vars、xrange、__import__。

  6. Python提供了很多第三方库,通常需要通过import语句调用。在图模式下JIT语法支持级别为STRICT时,不能直接使用第三方库。如果需要在图模式下使用第三方库的数据类型或者调用第三方库的方法,需要在JIT语法支持级别选项jit_syntax_levelLAX时才支持,更多请参考本文的扩展语法(LAX级别)中的调用第三方库章节。

  7. 在图模式下,不感知在图外对类的属性的修改,即图外对类的属性修改不会生效。例如:

    import mindspore as ms
    from mindspore import nn, ops, Tensor, context
    
    class Net(nn.Cell):
    def __init__(self):
       super().__init__()
       self.len = 1
    
    def construct(self, inputs):
       x = inputs + self.len
       return x
    
    context.set_context(mode=ms.GRAPH_MODE)
    inputs = 2
    net = Net()
    print("out1:", net(inputs))
    net.len = 2
    print("out2:", net(inputs))
    

    输出的结果将不会发生变化:

    out1: 3
    out2: 3
    

扩展语法(LAX级别)

下面主要介绍当前扩展支持的静态图语法。

调用第三方库

  • 第三方库

    1. Python内置模块和Python标准库。例如ossysmathtime等模块。

    2. 第三方代码库。路径在Python安装目录的site-packages目录下,需要先安装后导入,例如NumPySciPy等。需要注意的是,mindyolomindflow等MindSpore套件不被视作第三方库,具体列表可以参考parser文件的 _modules_from_mindspore 列表。

    3. 通过环境变量MS_JIT_IGNORE_MODULES指定的模块。与之相对的有环境变量MS_JIT_MODULES,具体使用方法请参考环境变量

  • 支持第三方库的数据类型,允许调用和返回第三方库的对象。

    示例如下:

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          a = np.array([1, 2, 3])
          b = np.array([4, 5, 6])
          out = a + b
          return out
    
    net = Net()
    print(net())
    

    运行结果如下:

    [5 7 9]
    
  • 支持调用第三方库的方法。

    示例如下:

    from scipy import linalg
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          x = [[1, 2], [3, 4]]
          return linalg.qr(x)
    
    net = Net()
    out = net()
    print(out[0].shape)
    

    运行结果如下:

    (2, 2)
    
  • 支持使用NumPy第三方库数据类型创建Tensor对象。

    示例如下:

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          x = np.array([1, 2, 3])
          out = ms.Tensor(x) + 1
          return out
    
    net = Net()
    print(net())
    

    运行结果如下:

    [2 3 4]
    
  • 支持对第三方库数据类型的下标索引赋值。

    示例如下:

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          x = np.array([1, 2, 3])
          x[0] += 1
          return ms.Tensor(x)
    
    net = Net()
    res = net()
    print("res: ", res)
    

    运行结果如下:

    res: [2 2 3]
    

支持自定义类的使用

支持在图模式下使用用户自定义的类,可以对类进行实例化,使用对象的属性及方法。

例如下面的例子,其中GetattrClass是用户自定义的类,没有使用@jit_class修饰,也没有继承nn.Cell

import mindspore as ms

ms.set_context(mode=ms.GRAPH_MODE)

class GetattrClass():
   def __init__(self):
      self.attr1 = 99
      self.attr2 = 1

   def method1(self, x):
      return x + self.attr2

class GetattrClassNet(ms.nn.Cell):
   def __init__(self):
      super(GetattrClassNet, self).__init__()
      self.cls = GetattrClass()

   def construct(self):
      return self.cls.method1(self.cls.attr1)

net = GetattrClassNet()
out = net()
assert out == 100

基础运算符支持更多数据类型

在静态图语法重载了以下运算符: ['+', '-', '*','/','//','%','**','<<','>>','&','|','^', 'not', '==', '!=', '<', '>', '<=', '>=', 'in', 'not in', 'y=x[0]']。图模式重载的运算符详见运算符。列表中的运算符在输入图模式中不支持的输入类型时将使用扩展静态图语法支持,并使输出结果与动态图模式下的输出结果一致。

代码用例如下。

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
ms.set_context(mode=ms.GRAPH_MODE)

class InnerClass(nn.Cell):
   def construct(self, x, y):
      return x.asnumpy() + y.asnumpy()

net = InnerClass()
ret = net(Tensor([4, 5]), Tensor([1, 2]))
print(ret)

运行结果如下:

[5 7]

上述例子中,.asnumpy()输出的数据类型: numpy.ndarray为运算符+在图模式中不支持的输入类型。因此x.asnumpy() + y.asnumpy()将使用扩展语法支持。

在另一个用例中:

import mindspore as ms
import mindspore.nn as nn
ms.set_context(mode=ms.GRAPH_MODE)

class InnerClass(nn.Cell):
   def construct(self):
      return (None, 1) in ((None, 1), 1, 2, 3)

net = InnerClass()
print(net())

运行结果如下:

True

tuple in tuple在原本的图模式中是不支持的运算,现已使用扩展静态图语法支持。

基础类型

扩展对Python原生数据类型ListDictionaryNone的支持。

支持列表就地修改操作

列表List以及元组Tuple是Python中最基本的序列内置类型,ListTuple最核心的区别是List是可以改变的对象,而Tuple是不可以更改的。这意味着Tuple一旦被创建,就不可以在对象地址不变的情况下更改。而List则可以通过一系列inplace操作,在不改变对象地址的情况下,对对象进行修改。例如:

a = [1, 2, 3, 4]
a_id = id(a)
a.append(5)
a_after_id = id(a)
assert a_id == a_after_id

上述示例代码中,通过append这个inplace语法更改List对象的时候,其对象的地址并没有被修改。而Tuple是不支持这种inplace操作的。在JIT_SYNTAX_LEVEL设置为LAX的情况下,静态图模式可以支持部分List对象的inplace操作。

具体使用场景如下:

  • 支持从全局变量中获取原List对象。

    在下面示例中,静态图获取到List对象,并在原有对象上进行了图模式支持的inplace操作list.reverse(), 并将原有对象返回。可以看到图模式返回的对象与原有的全局变量对象id相同,即两者为同一对象。若JIT_SYNTAX_LEVEL设置为STRICT选项,则返回的List对象与全局对象为两个不同的对象。

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    global_list = [1, 2, 3, 4]
    
    class Net(nn.Cell):
       def construct(self):
          global_list.reverse()
          return global_list
    
    net = Net()
    output = net()  # output: [4, 3, 2, 1]
    assert id(global_list) == id(output)
    
  • 不支持对输入List对象进行inplace操作。

    List作为静态图输入时,会对该List对象进行一次复制,并使用该复制对象进行后续的计算,因此无法对原输入对象进行inplace操作。例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    list_input = [1, 2, 3, 4]
    
    class Net(nn.Cell):
       def construct(self, x):
          x.reverse()
          return x
    
    net = Net()
    output = net(list_input)  # output: [4, 3, 2, 1]  list_input: [1, 2, 3, 4]
    assert id(output) != id(list_input)
    

    如上述用例所示,List对象作为图模式输入时无法在原有对象上进行inplace操作。图模式返回的对象与输入的对象id不同,为不同对象。

  • 支持部分List内置函数的就地修改操作。

    JIT_SYNTAX_LEVEL设置为LAX的情况下,图模式部分List内置函数支持inplace。在 JIT_SYNTAX_LEVELSTRICT 的情况下,所有方法均不支持inplace操作。

    目前,图模式支持的List就地修改内置方法有extendpopreverse以及insert。内置方法appendclear以及索引赋值暂不支持就地修改,后续版本将会支持。

    示例如下:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    list_input = [1, 2, 3, 4]
    
    class Net(nn.Cell):
       def construct(self):
          list_input.reverse()
          return list_input
    
    net = Net()
    output = net()  # output: [4, 3, 2, 1]  list_input: [4, 3, 2, 1]
    assert id(output) == id(list_input)
    

支持Dictionary的高阶用法

  • 支持顶图返回Dictionary。

    示例如下:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          x = {'a': 'a', 'b': 'b'}
          y = x.get('a')
          z = dict(y=y)
          return z
    
    net = Net()
    out = net()
    print("out:", out)
    

    运行结果如下:

    out: {'y': 'a'}
    
  • 支持Dictionary索引取值和赋值。

    示例如下:

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    x = {"a": ms.Tensor(np.array([1, 2, 3])), "b": ms.Tensor(np.array([4, 5, 6])), "c": ms.Tensor(np.array([7, 8, 9]))}
    
    class Net(nn.Cell):
       def construct(self):
          y = x["b"]
          x["a"] = (2, 3, 4)
          return x, y
    
    net = Net()
    out1, out2 = net()
    print('out1:{}'.format(out1))
    print('out2:{}'.format(out2))
    

    运行结果如下:

    out1:{'a': (2, 3, 4), 'b': Tensor(shape=[3], dtype=Int64, value= [4, 5, 6]), 'c': Tensor(shape=[3], dtype=Int64, value= [7, 8, 9])}
    out2:[4 5 6]
    

支持使用None

None是Python中的一个特殊值,表示空,可以赋值给任何变量。对于没有返回值语句的函数认为返回None。同时也支持None作为顶图或者子图的入参或者返回值。支持None作为切片的下标,作为ListTupleDictionary的输入。

示例如下:

import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import Tensor

context.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self):
      return 1, "a", None

net = Net()
res = net()
print(res)

运行结果如下:

(1, 'a', None)

对于没有返回值的函数,默认返回None对象。

import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import Tensor

context.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self):
      x = 3
      print("x:", x)

net = Net()
res = net()
assert res is None

运行结果如下:

x:
3

如下面例子,None作为顶图的默认入参。

import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import Tensor

context.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self, x, y=None):
      if y is not None:
         print("y:", y)
      else:
         print("y is None")
      print("x:", x)
      return y

x = [1, 2]
net = Net()
res = net(x)
assert res is None

运行结果如下:

y is None
x:
[1, 2]

内置函数支持更多数据类型

扩展内置函数的支持范围。Python内置函数完善支持更多输入类型,例如第三方库数据类型。

例如下面的例子,x.asnumpy()np.ndarray均是扩展支持的类型。更多内置函数的支持情况可见Python内置函数章节。

import numpy as np
import mindspore as ms
import mindspore.nn as nn

ms.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self, x):
      return isinstance(x.asnumpy(), np.ndarray)

x = ms.Tensor(np.array([-1, 2, 4]))
net = Net()
out = net(x)
assert out

支持控制流

为了提高Python标准语法支持度,实现动静统一,扩展支持更多数据类型在控制流语句的使用。控制流语句是指ifforwhile等流程控制语句。理论上,通过扩展支持的语法,在控制流场景中也支持。代码用例如下:

import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore import context
from mindspore import Tensor

context.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self):
      x = np.array(1)
      if x <= 1:
         x += 1
      return ms.Tensor(x)

net = Net()
res = net()
print("res: ", res)

运行结果如下:

res:  2

支持属性设置与修改

具体使用场景如下:

  • 对自定义类对象以及第三方类型的属性进行设置与修改。

    图模式下支持对自定义类对象的属性进行设置与修改,例如:

    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class AssignClass():
       def __init__(self):
          self.x = 1
    
    obj = AssignClass()
    
    class Net(nn.Cell):
       def construct(self):
          obj.x = 100
    
    net = Net()
    net()
    print(f"obj.x is: {obj.x}")
    

    运行结果如下:

    obj.x is: 100
    

    图模式下支持对第三方库对象的属性进行设置与修改,例如:

    import numpy as np
    import mindspore as ms
    from mindspore import nn
    from mindspore import context
    from mindspore import Tensor
    
    context.set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def construct(self):
          a = np.array([1, 2, 3, 4])
          a.shape = (2, 2)
          return a.shape
    
    net = Net()
    shape = net()
    print(f"shape is {shape}")
    

    运行结果如下:

    shape is (2, 2)
    
  • 对Cell的self对象进行修改,例如:

    import mindspore as ms
    from mindspore import nn, set_context
    set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def __init__(self):
          super().__init__()
          self.m = 2
    
       def construct(self):
          self.m = 3
          return 0
    
    net = Net()
    net()
    print(f"net.m is {net.m}")
    

    运行结果如下:

    net.m is 3
    

    注意,self对象支持属性修改和设置。若__init__内没有定义某个属性,对齐PYNATIVE模式,图模式也允许设置此属性。例如:

    import mindspore as ms
    from mindspore import nn, set_context
    set_context(mode=ms.GRAPH_MODE)
    
    class Net(nn.Cell):
       def __init__(self):
          super().__init__()
          self.m = 2
    
       def construct(self):
          self.m2 = 3
          return 0
    
    net = Net()
    net()
    print(f"net.m2 is {net.m2}")
    

    运行结果如下:

    net.m2 is 3
    
  • 对静态图内的Cell对象以及jit_class对象进行设置与修改。

    支持对图模式jit_class对象进行属性修改,例如:

    import mindspore as ms
    from mindspore import nn, set_context, jit_class
    set_context(mode=ms.GRAPH_MODE)
    
    @jit_class
    class InnerClass():
       def __init__(self):
          self.x = 10
    
    class Net(nn.Cell):
       def __init__(self):
          super(Net, self).__init__()
          self.inner = InnerClass()
    
       def construct(self):
          self.inner.x = 100
          return 0
    
    net = Net()
    net()
    print(f"net.inner.x is {net.inner.x}")
    

    运行结果如下:

    net.inner.x is 100
    

支持求导

扩展支持的静态图语法,同样支持其在求导中使用,例如:

import mindspore as ms
from mindspore import ops, set_context, nn
set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
   def construct(self, a):
      x = {'a': a, 'b': 2}
      return a, (x, (1, 2))

net = Net()
out = ops.grad(net)(ms.Tensor([1]))
assert out == 2

Annotation Type

对于运行时的扩展支持的语法,会产生一些无法被类型推导出的节点,比如动态创建Tensor等。这种类型称为Any类型。因为该类型无法在编译时推导出正确的类型,所以这种Any将会以一种默认最大精度float64进行运算,防止其精度丢失。为了能更好的优化相关性能,需要减少Any类型数据的产生。当用户可以明确知道当前通过扩展支持的语句会产生具体类型的时候,我们推荐使用Annotation @jit.typing:的方式进行指定对应Python语句类型,从而确定解释节点的类型避免Any类型的生成。

例如,Tensor类和tensor接口的区别就在于在tensor接口内部运用了Annotation Type机制。当tensor函数的dtype确定时,函数内部会利用Annotation指定输出类型从而避免Any类型的产生。Annotation Type的使用只需要在对应Python语句上面或者后面加上注释 # @jit.typing: () -> tensor_type[float32] 即可,其中 -> 后面的 tensor_type[float32] 指示了被注释的语句输出类型。

代码用例如下。

import mindspore as ms
import mindspore.nn as nn
from mindspore import ops, Tensor

class Net(nn.Cell):
   def __init__(self):
      super(Net, self).__init__()
      self.abs = ops.Abs()

   def construct(self, x, y):
      z = x.asnumpy() + y.asnumpy()
      y1 = ms.tensor(z, dtype=ms.float32)
      y2 = ms.Tensor(z, dtype=ms.float32) # @jit.typing: () -> tensor_type[float32]
      y3 = Tensor(z)
      y4 = Tensor(z, dtype=ms.float32)
      return self.abs(y1), self.abs(y2), self.abs(y3), self.abs(y4)

ms.set_context(mode=ms.GRAPH_MODE)
net = Net()
x = ms.Tensor(-1, dtype=ms.int32)
y = ms.Tensor(-1, dtype=ms.float32)
y1, y2, y3, y4 = net(x, y)

print(f"y1 value is {y1}, dtype is {y1.dtype}")
print(f"y2 value is {y2}, dtype is {y2.dtype}")
print(f"y3 value is {y3}, dtype is {y3.dtype}")
print(f"y4 value is {y4}, dtype is {y4.dtype}")

运行结果如下:

y1 value is 2.0, dtype is Float32
y2 value is 2.0, dtype is Float32
y3 value is 2.0, dtype is Float64
y4 value is 2.0, dtype is Float32

上述例子,可以看到创建了Tensor的相关区别。对于y3y4,因为Tensor类没有增加Annotation指示,y3y4没有办法推出正确的类型,导致只能按照最高精度float64进行运算。 对于y2,由于创建Tensor时,通过Annotation指定了对应类型,使得其类型可以按照指定类型进行运算。 对于y1,由于使用了tensor函数接口创建Tensor,传入的dtype参数作为Annotation的指定类型,所以也避免了Any类型的产生。

扩展语法的语法约束

在使用静态图扩展支持语法时,请注意以下几点:

  1. 对标动态图的支持能力,即:须在动态图语法范围内,包括但不限于数据类型等。

  2. 在扩展静态图语法时,支持了更多的语法,但执行性能可能会受影响,不是最佳。

  3. 在扩展静态图语法时,支持了更多的语法,由于使用Python的能力,不能使用MindIR导入导出的能力。