流程控制语句
目前主流的深度学习框架的执行模式有两种,分别为静态图模式GRAPH_MODE
和动态图PYNATIVE_MODE
模式。
在PYNATIVE_MODE
模式下,MindSpore完全支持Python原生语法的流程控制语句。GRAPH_MODE
模式下,MindSpore在编译时做了性能优化,因此,在定义网络时使用流程控制语句时会有部分特殊约束,其他部分仍和Python原生语法保持一致。
运行模式从动态图切换到静态图时,请留意静态图语法支持。下面我们详细介绍在GRAPH_MODE
模式下定义网络时流程控制语句的使用方式。
常量与变量条件
在GRAPH_MODE
模式下定义网络,MindSpore将流程控制语句中的条件表达式分为两类:即常量条件和变量条件。在图编译时可以确定结果为True或False的条件表达式为常量条件,在图编译时不能确定结果为True或False的条件表达式为变量条件。只有当条件表达式为变量条件时,MindSpore才会在网络中生成控制流算子。
需要注意的是,当网络中存在控制流算子时,网络会被切分成多个执行子图,子图间进行流程跳转和数据传递会产生一定的性能损耗。
常量条件
判断方式:
条件表达式中不存在Tensor类型,且也不存在元素为Tensor类型的List、Tuple、Dict。
条件表达式中存在Tensor类型,或者元素为Tensor类型的List、Tuple、Dict,但是表达式结果不受Tensor的值影响。
举例:
for i in range(0,10)
,i
为标量:潜在的条件表达式i < 10
在图编译时可以确定结果,因此为常量条件;self.flag
,self.flag
为标量:此处self.flag
为一个bool类型标量,其值在构建Cell对象时已确定;x + 1 < 10
,x
为标量:此处x + 1
的值在构建Cell对象时是不确定的,但是在图编译时MindSpore会计算所有标量表达式的结果,因此该表达式的值也是在编译期确定的。len(my_list) < 10
,my_list
为元素是Tensor类型的List对象:该条件表达式包含Tensor,但是表达式结果不受Tensor的值影响,只与my_list
中Tensor的数量有关;
变量条件
判断方式:
条件表达式中存在Tensor类型或者元素为Tensor类型的List、Tuple、Dict,并且条件表达式的结果受Tensor的值影响。
举例:
x < y
,x
和y
为算子输出。x in list
,x
为算子输出。
由于算子输出是图在各个step执行时才能确定,因此上面两个都属于变量条件。
if语句
在GRAPH_MODE
模式下定义网络时,使用if
语句需要注意:在条件表达式为变量条件时,在不同分支的同一变量应被赋予相同的数据类型,例如Tensor类型变量要求shape和type一致。Tensor变量的shape一致性约束详见ShapeJoin规则。
变量条件的if语句
在下面代码中,在if
和else
分支中,变量out
在if
语句不同分支被赋予的Tensor的Shape分别是()和(2,)。网络最终返回的Tensor的shape由条件x < y
决定,而在图编译时期无法确定x < y
的结果,因此图编译时期无法确定out
的Shape是()还是(2,),MindSpore最终因类型推导失败而抛出异常。
import numpy as np
import mindspore as ms
from mindspore import nn
class SingleIfNet(nn.Cell):
def construct(self, x, y, z):
# 构造条件表达式为变量条件的if语句
if x < y:
out = x
else:
out = z
out = out + 1
return out
forward_net = SingleIfNet()
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)
z = ms.Tensor(np.array([1, 2]), dtype=ms.int32)
output = forward_net(x, y, z)
执行上面的代码,报错信息如下:
ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:800 ProcessEvalResults] Cannot join the return values of different branches, perhaps you need to make them equal.
Shape Join Failed: shape1 = (), shape2 = (2).
常量条件的if语句
当if
语句中的条件表达式为常量条件时,其使用方式与Python原生语法保持一致,并无额外的约束。如下代码中的if
语句条件表达式x < y + 1
为常量条件(因为x和y都是标量常量类型),在图编译时期可确定变量out
的类型为标量int
类型,网络可正常编译和执行,输出正确结果1
。
[1]:
import numpy as np
import mindspore as ms
from mindspore import nn
class SingleIfNet(nn.Cell):
def construct(self, z):
x = 0
y = 1
# 构造条件表达式为常量条件的if语句
if x < y + 1:
out = x
else:
out = z
out = out + 1
return out
z = ms.Tensor(np.array([0, 1]), dtype=ms.int32)
forward_net = SingleIfNet()
output = forward_net(z)
print("output:", output)
output: 1
for语句
for
语句会将循环体展开,因此使用for
语句的网络的子图数量、算子数量取决于for
语句的循环次数,算子数量过多或者子图过多会消耗更多的硬件资源。
下面的示例代码中,for
语句中的循环体会被执行3次,输出结果为5
。
[2]:
import numpy as np
from mindspore import nn
import mindspore as ms
class IfInForNet(nn.Cell):
def construct(self, x, y):
out = 0
# 构造条件表达式为常量条件的for语句
for i in range(0, 3):
# 构造条件表达式为变量条件的if语句
if x + i < y:
out = out + x
else:
out = out + y
out = out + 1
return out
forward_net = IfInForNet()
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
print("output:", output)
output: 5
由于for
语句会展开循环体内容,所以上面的代码和下面的代码等价:
[3]:
import numpy as np
from mindspore import nn
import mindspore as ms
class IfInForNet(nn.Cell):
def construct(self, x, y):
out = 0
# 循环: 0
if x + 0 < y:
out = out + x
else:
out = out + y
out = out + 1
# 循环: 1
if x + 1 < y:
out = out + x
else:
out = out + y
out = out + 1
# 循环: 2
if x + 2 < y:
out = out + x
else:
out = out + y
out = out + 1
return out
forward_net = IfInForNet()
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
print("output:", output)
output: 5
从上面两段示例代码我们可以看出,在部分场景下,使用for
语句会导致出现子图过多的问题时。为了节约硬件资源开销,提升网络编译性能,可尝试将for
语句等价转换为条件表达式是变量条件的while
语句。
while语句
while
语句相比for
语句更为灵活。当while
的条件为常量时,while
对循环体的处理和for
类似,会展开循环体内容。
当while
的条件表达式是变量条件时,while
语句则不会展开循环体内容,而是在执行图中产生控制流算子,因此可以避免for
循环带来的子图过多的问题。
常量条件的while语句
下面的示例代码中,for
语句中的循环体会被执行3次,输出结果为5
,和上面介绍for
语句中的示例代码本质上是一样的。
[4]:
import numpy as np
from mindspore import nn
import mindspore as ms
class IfInWhileNet(nn.Cell):
def construct(self, x, y):
i = 0
out = x
# 构造条件表达式为常量条件的while语句
while i < 3:
# 构造条件表达式为变量条件的if语句
if x + i < y:
out = out + x
else:
out = out + y
out = out + 1
i = i + 1
return out
forward_net = IfInWhileNet()
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
print("output:", output)
output: 5
变量条件的while语句
约束一:当while语句中的条件表达式是变量条件时,while循环体内部不能出现标量、List、Tuple等非Tensor类型的计算操作。
为了避免产生过多的控制流算子,我们可以尝试使用条件表达式为变量条件的while
语句重写上面的代码:
[6]:
import numpy as np
from mindspore import nn
import mindspore as ms
class IfInWhileNet(nn.Cell):
def construct(self, x, y, i):
out = x
# 构造条件表达式为变量条件的while语句
while i < 3:
# 构造条件表达式为变量条件的if语句
if x + i < y:
out = out + x
else:
out = out + y
out = out + 1
i = i + 1
return out
forward_net = IfInWhileNet()
i = ms.Tensor(np.array(0), dtype=ms.int32)
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
print("output:", output)
output: 5
需要注意的是,在上面的代码中,while
语句的条件表达式为变量条件,while
循环体不会被展开,while
循环体内的表达式都是在各个step运行时计算,同时也产生了如下约束:
当
while
语句中的条件表达式是变量条件时,while
循环体内部不能出现标量、List、Tuple等非Tensor类型的计算操作。
因为这些类型的计算操作是在图编译时期完成的,这与while
循环体在执行期进行计算的机制是矛盾的。下面我们通过示例代码说明:
import numpy as np
from mindspore import nn
import mindspore as ms
class IfInWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.nums = [1, 2, 3]
def construct(self, x, y, i):
j = 0
out = x
# 构造条件表达式为变量条件的while语句
while i < 3:
if x + i < y:
out = out + x
else:
out = out + y
out = out + self.nums[j]
i = i + 1
# 在条件表达式为变量条件的while语句循环体内构造标量计算
j = j + 1
return out
forward_net = IfInWhileNet()
i = ms.Tensor(np.array(0), dtype=ms.int32)
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
上面的代码中,条件表达式i < 3
为变量条件的while
循环体内部存在标量计算j = j + 1
,因此会导致图编译出错。代码在执行时报错信息如下:
IndexError: mindspore/core/abstract/prim_structures.cc:127 InferTupleOrListGetItem] list_getitem evaluator index should be in range[-3, 3), but got 3.
约束二:当while语句中的条件表达式是变量条件时,循环体内部不能更改算子的输入shape,并且循环体内与循环体外相同名称变量数据类型应该一致,例如Tensor类型变量要求shape和type一致。Tensor变量的shape一致性约束详见ShapeJoin规则。
MindSpore要求网络的同一个算子的输入shape在图编译时是确定的,而在while
的循环体内部改变算子输入shape的操作是在图执行时生效,两者是矛盾的。
下面我们通过示例代码来说明:
import numpy as np
from mindspore import nn
import mindspore as ms
from mindspore import ops
class IfInWhileNet(nn.Cell):
def __init__(self):
super().__init__()
self.expand_dims = ops.ExpandDims()
def construct(self, x, y, i):
out = x
# 构造条件表达式为变量条件的while语句
while i < 3:
if x + i < y:
out = out + x
else:
out = out + y
out = out + 1
# 更改算子的输入shape
out = self.expand_dims(out, -1)
i = i + 1
return out
forward_net = IfInWhileNet()
i = ms.Tensor(np.array(0), dtype=ms.int32)
x = ms.Tensor(np.array(0), dtype=ms.int32)
y = ms.Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
上面的代码中,条件表达式i < 3
为变量条件的while
循环体内部的ExpandDims
算子会改变表达式out = out + 1
在下一轮循环的输入shape,因此会导致图编译出错。代码在执行时报错信息如下:
ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:800 ProcessEvalResults] Cannot join the return values of different branches, perhaps you need to make them equal.
Shape Join Failed: shape1 = (1), shape2 = (1, 1).
ShapeJoin规则
unknow_shape
表示动态shape场景下,对应维度的长度是动态的,unknown_rank
表示动态rank场景下shape的维度是动态的,shape1
与shape2
分别表示进行Join的两个分支的shape。当满足下面任一规则时,Shape Join会成功,否则会出现Shape Join Failed
异常。
规则1:
shape1和shape2维度均固定且两者维度相等,且shape1[i]等于shape2[i]。
规则2:
shape1和shape2维度均固定且两者维度相等,且shape1[i]或shape2[i]至少有一个是unknown_shape。
规则3:
shape1和shape2维度至少有一个是动态的,即shape1或者shape2是动态rank。
规则4:
shape1和shape2维度固定且两者维度不相等,较小的维度是m, 较大的维度是n。
在0~m-1维范围内满足:
shape1[i]或shape2[i]相等。
或shape1[i]与shape2[i]均是unknown_shape。
在m~n-1维范围内满足:维度较大者的shape[i]是unknown_shape。
下述列表是Shape Join的规则示例。
shape1 |
shape2 |
Join结果 |
---|---|---|
(3, 4) |
(3, 4) |
(3, 4) |
(3, 5) |
(3, 4) |
Join Fail |
(3, 4) |
(3, 4, 1) |
Join Fail |
(3, unknown_shape) |
(3, 4) |
(3, unknown_shape) |
unknown_rank |
(3, 4) |
unknown_rank |
(3, unknown_shape) |
(3, unknown_shape, unknown_shape) |
unknown_rank |
(3, unknown_shape) |
(4, unknown_shape, unknown_shape) |
Join Fail |
(3, unknown_shape) |
(3, 4, unknown_shape) |
Join Fail |