流程控制语句

下载Notebook下载样例代码查看源文件

目前主流的深度学习框架的执行模式有两种,分别为静态图模式GRAPH_MODE和动态图PYNATIVE_MODE模式。

PYNATIVE_MODE模式下,MindSpore完全支持Python原生语法的流程控制语句。GRAPH_MODE模式下,MindSpore在编译时做了性能优化,因此,在定义网络时使用流程控制语句时会有部分特殊约束,其他部分仍和Python原生语法保持一致。

运行模式从动态图切换到静态图时,请留意静态图语法支持。下面我们详细介绍在GRAPH_MODE模式下定义网络时流程控制语句的使用方式。

常量与变量条件

GRAPH_MODE模式下定义网络,MindSpore将流程控制语句中的条件表达式分为两类:即常量条件和变量条件。在图编译时可以确定结果为True或False的条件表达式为常量条件,在图编译时不能确定结果为True或False的条件表达式为变量条件。只有当条件表达式为变量条件时,MindSpore才会在网络中生成控制流算子

需要注意的是,当网络中存在控制流算子时,网络会被切分成多个执行子图,子图间进行流程跳转和数据传递会产生一定的性能损耗。

常量条件

判断方式:

  • 条件表达式中不存在Tensor类型,且也不存在元素为Tensor类型的List、Tuple、Dict。

  • 条件表达式中存在Tensor类型,或者元素为Tensor类型的List、Tuple、Dict,但是表达式结果不受Tensor的值影响。

举例:

  • for i in range(0,10)i为标量:潜在的条件表达式i < 10在图编译时可以确定结果,因此为常量条件;

  • self.flagself.flag为标量:此处self.flag为一个bool类型标量,其值在构建Cell对象时已确定;

  • x + 1 < 10x为标量:此处x + 1的值在构建Cell对象时是不确定的,但是在图编译时MindSpore会计算所有标量表达式的结果,因此该表达式的值也是在编译期确定的。

  • len(my_list) < 10my_list为元素是Tensor类型的List对象:该条件表达式包含Tensor,但是表达式结果不受Tensor的值影响,只与my_list中Tensor的数量有关;

变量条件

判断方式:

  • 条件表达式中存在Tensor类型或者元素为Tensor类型的List、Tuple、Dict,并且条件表达式的结果受Tensor的值影响。

举例:

  • x < yxy为算子输出。

  • x in listx为算子输出。

由于算子输出是图在各个step执行时才能确定,因此上面两个都属于变量条件。

if语句

GRAPH_MODE模式下定义网络时,使用if语句需要注意:在条件表达式为变量条件时,在不同分支的同一变量应被赋予相同的数据类型

变量条件的if语句

在下面代码中,在ifelse分支中,变量outif语句不同分支被赋予的Tensor的Shape分别是()和(2,)。网络最终返回的Tensor的shape由条件x < y决定,而在图编译时期无法确定x < y的结果,因此图编译时期无法确定out的Shape是()还是(2,),MindSpore最终因类型推导失败而抛出异常。

import numpy as np
import mindspore as ms
from mindspore import nn

class SingleIfNet(nn.Cell):

    def construct(self, x, y, z):
        # 构造条件表达式为变量条件的if语句
        if x < y:
            out = x
        else:
            out = z
        out = out + 1
        return out

forward_net = SingleIfNet()

x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)
z = ms.Tensor(np.array([1, 2]), dtype=ms.int32)

output = forward_net(x, y, z)

执行上面的代码,报错信息如下:

ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:800 ProcessEvalResults] Cannot join the return values of different branches, perhaps you need to make them equal.
Shape Join Failed: shape1 = (), shape2 = (2).

常量条件的if语句

if语句中的条件表达式为常量条件时,其使用方式与Python原生语法保持一致,并无额外的约束。如下代码中的if语句条件表达式x < y + 1为常量条件(因为x和y都是标量常量类型),在图编译时期可确定变量out的类型为标量int类型,网络可正常编译和执行,输出正确结果1

[1]:
import numpy as np
import mindspore as ms
from mindspore import nn

class SingleIfNet(nn.Cell):

    def construct(self, z):
        x = 0
        y = 1

        # 构造条件表达式为常量条件的if语句
        if x < y + 1:
            out = x
        else:
            out = z
        out = out + 1

        return out

z = ms.Tensor(np.array([0, 1]), dtype=ms.int32)
forward_net = SingleIfNet()

output = forward_net(z)
print("output:", output)
output: 1

for语句

for语句会将循环体展开,因此使用for语句的网络的子图数量、算子数量取决于for语句的循环次数,算子数量过多或者子图过多会消耗更多的硬件资源。

下面的示例代码中,for语句中的循环体会被执行3次,输出结果为5

[2]:
import numpy as np
from mindspore import nn
import mindspore as ms

class IfInForNet(nn.Cell):

    def construct(self, x, y):
        out = 0

        # 构造条件表达式为常量条件的for语句
        for i in range(0, 3):
            # 构造条件表达式为变量条件的if语句
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1

        return out

forward_net = IfInForNet()

x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)

output = forward_net(x, y)
print("output:", output)
output: 5

由于for语句会展开循环体内容,所以上面的代码和下面的代码等价:

[3]:
import numpy as np
from mindspore import nn
import mindspore as ms

class IfInForNet(nn.Cell):
    def construct(self, x, y):
        out = 0

        # 循环: 0
        if x + 0 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1
        # 循环: 1
        if x + 1 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1
        # 循环: 2
        if x + 2 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1

        return out

forward_net = IfInForNet()

x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)

output = forward_net(x, y)
print("output:", output)
output: 5

从上面两段示例代码我们可以看出,在部分场景下,使用for语句会导致出现子图过多的问题时。为了节约硬件资源开销,提升网络编译性能,可尝试将for语句等价转换为条件表达式是变量条件的while语句。

while语句

while语句相比for语句更为灵活。当while的条件为常量时,while对循环体的处理和for类似,会展开循环体内容。

while的条件表达式是变量条件时,while语句则不会展开循环体内容,而是在执行图中产生控制流算子,因此可以避免for循环带来的子图过多的问题。

常量条件的while语句

下面的示例代码中,for语句中的循环体会被执行3次,输出结果为5,和上面介绍for语句中的示例代码本质上是一样的。

[4]:
import numpy as np
from mindspore import nn
import mindspore as ms

class IfInWhileNet(nn.Cell):

    def construct(self, x, y):
        i = 0
        out = x
        # 构造条件表达式为常量条件的while语句
        while i < 3:
            # 构造条件表达式为变量条件的if语句
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            i = i + 1
        return out

forward_net = IfInWhileNet()
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)

output = forward_net(x, y)
print("output:", output)
output: 5

变量条件的while语句

  1. 约束一:当while语句中的条件表达式是变量条件时,while循环体内部不能出现标量、List、Tuple等非Tensor类型的计算操作

为了避免产生过多的控制流算子,我们可以尝试使用条件表达式为变量条件的while语句重写上面的代码:

[6]:
import numpy as np
from mindspore import nn
import mindspore as ms

class IfInWhileNet(nn.Cell):

    def construct(self, x, y, i):
        out = x
        # 构造条件表达式为变量条件的while语句
        while i < 3:
            # 构造条件表达式为变量条件的if语句
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            i = i + 1
        return out

forward_net = IfInWhileNet()
i = ms.Tensor(np.array(0), dtype=ms.int32)
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)

output = forward_net(x, y, i)
print("output:", output)
output: 5

需要注意的是,在上面的代码中,while语句的条件表达式为变量条件,while循环体不会被展开,while循环体内的表达式都是在各个step运行时计算,同时也产生了如下约束:

while语句中的条件表达式是变量条件时,while循环体内部不能出现标量、List、Tuple等非Tensor类型的计算操作。

因为这些类型的计算操作是在图编译时期完成的,这与while循环体在执行期进行计算的机制是矛盾的。下面我们通过示例代码说明:

import numpy as np
from mindspore import nn
import mindspore as ms
class IfInWhileNet(nn.Cell):

    def __init__(self):
        super().__init__()
        self.nums = [1, 2, 3]

    def construct(self, x, y, i):
        j = 0
        out = x

        # 构造条件表达式为变量条件的while语句
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + self.nums[j]
            i = i + 1
            # 在条件表达式为变量条件的while语句循环体内构造标量计算
            j = j + 1

        return out

forward_net = IfInWhileNet()
i = ms.Tensor(np.array(0), dtype=ms.int32)
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)

output = forward_net(x, y, i)

上面的代码中,条件表达式i < 3为变量条件的while循环体内部存在标量计算j = j + 1,因此会导致图编译出错。代码在执行时报错信息如下:

IndexError: mindspore/core/abstract/prim_structures.cc:127 InferTupleOrListGetItem] list_getitem evaluator index should be in range[-3, 3), but got 3.
  1. 约束二:当while语句中的条件表达式是变量条件时,循环体内部不能更改算子的输入shape。

MindSpore要求网络的同一个算子的输入shape在图编译时是确定的,而在while的循环体内部改变算子输入shape的操作是在图执行时生效,两者是矛盾的。

下面我们通过示例代码来说明:

import numpy as np
from mindspore import nn
import mindspore as ms
from mindspore import ops

class IfInWhileNet(nn.Cell):

    def __init__(self):
        super().__init__()
        self.expand_dims = ops.ExpandDims()

    def construct(self, x, y, i):
        out = x
        # 构造条件表达式为变量条件的while语句
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            # 更改算子的输入shape
            out = self.expand_dims(out, -1)
            i = i + 1
        return out

forward_net = IfInWhileNet()
i = ms.Tensor(np.array(0), dtype=ms.int32)
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)

output = forward_net(x, y, i)

上面的代码中,条件表达式i < 3为变量条件的while循环体内部的ExpandDims算子会改变表达式out = out + 1在下一轮循环的输入shape,因此会导致图编译出错。代码在执行时报错信息如下:

ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:800 ProcessEvalResults] Cannot join the return values of different branches, perhaps you need to make them equal.
Shape Join Failed: shape1 = (1), shape2 = (1, 1).