Using the Process Control Statement
Overview
MindSpore process control statements are similar to the native Python syntax, especially in PYNATIVE_MODE. However, there are some special constraints in GRAPH_MODE. All of the following process control statements are executed in GRAPH_MODE.
When a process control statement is used, MindSpore determines whether to generate a control flow operator on the network based on whether the condition is a variable: a control flow operator is generated only when the condition is a variable. If the result of a condition expression can be determined during graph build, the condition is a constant; otherwise, it is a variable. Note in particular that when a control flow operator exists in a network, the network is divided into multiple execution subgraphs, and the process jumps and data transfers between those subgraphs cause a certain performance loss.
In the scenario where the condition is a variable:

The condition expression contains tensors, or a list, tuple, or dict of tensors, and the result of the condition expression is affected by the tensor values.

Common variable conditions are as follows:

- (x < y).all(), where x or y is an operator output. Whether the condition is true depends on the operator outputs x and y, which can be determined only when each step is executed.
- x in list, where x is an operator output.
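As an illustrative sketch (a hypothetical network, assuming GRAPH_MODE as in the examples below), a branch whose condition is (x < y).all() depends on the runtime tensor values, so a control flow operator is generated:

import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class VariableCondNet(nn.Cell):
    def construct(self, x, y):
        # x and y are operator inputs, so the truth value of
        # (x < y).all() is known only at execution time: a variable condition.
        if (x < y).all():
            out = x + y
        else:
            out = x - y
        return out

net = VariableCondNet()
x = Tensor(np.array([0, 1]), dtype=ms.int32)
y = Tensor(np.array([2, 3]), dtype=ms.int32)
output = net(x, y)

Note that both branches assign out the same shape and data type, which is required when the condition is a variable, as described in the if statement section below.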
In the scenario where the condition is a constant:

- The condition expression does not contain tensors or a list, tuple, or dict of tensors.
- The condition expression contains tensors, or a list, tuple, or dict of tensors, but the result of the condition expression is not affected by the tensor values.

Common constant conditions are as follows:

- self.flag, a scalar of the Boolean type. The value of self.flag is determined when the cell object is created; therefore, self.flag is a constant condition.
- x + 1 < 10, where x is a scalar. Although the value of x + 1 is uncertain when the cell object is created, MindSpore computes the results of all scalar expressions during graph build. Therefore, the expression value is determined during build, and this is a constant condition.
- len(my_list) < 10, where my_list is a list of tensors. Although the condition expression contains tensors, the result is not affected by the tensor values and depends only on the number of tensors in my_list. Therefore, this is a constant condition.
- for i in range(0, 10), where i is a scalar; the implicit condition expression i < 10 is a constant condition.
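A hedged sketch of a constant condition (a hypothetical network): self.flag is fixed when the cell object is created, so only the taken branch is compiled and no control flow operator is generated.

import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class ConstantCondNet(nn.Cell):
    def __init__(self, flag):
        super().__init__()
        self.flag = flag  # Boolean scalar, fixed when the cell object is created

    def construct(self, x):
        # self.flag is known during graph build, so this is a constant condition
        # and only one branch exists in the compiled graph.
        if self.flag:
            out = x + 1
        else:
            out = x - 1
        return out

net = ConstantCondNet(True)
x = Tensor(np.array(0), dtype=ms.int32)
output = net(x)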
Using the if Statement
When using the if statement, note that if the condition is a variable, the same variable name must be assigned the same data type in different branches. In addition, the number of subgraphs in the execution graph generated for the network is directly proportional to the number of if statements with variable conditions. Too many if statements incur high performance overheads from the control flow operators and from data transmission between subgraphs.
Using an if Statement with a Variable Condition
In example 1, out is assigned the scalar tensor x (shape ()) in the true branch and the one-dimensional tensor z (shape (2)) in the false branch. Because x < y is a variable condition, the shape of out in the out = out + 1 statement cannot be determined during graph build, causing a graph build exception.
Example 1:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

# The examples in this document are executed in GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)

class SingleIfNet(nn.Cell):
    def construct(self, x, y, z):
        if x < y:
            out = x
        else:
            out = z
        out = out + 1
        return out

forward_net = SingleIfNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
z = Tensor(np.array([0, 1]), dtype=ms.int32)
output = forward_net(x, y, z)
The error information in example 1 is as follows:
ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:734 ProcessEvalResults] The return values of different branches do not match. Shape Join Failed: shape1 = (2), shape2 = ()..
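One possible fix, shown as a hedged sketch (the SingleIfNetFixed network is hypothetical, not the documented solution), is to make both branches produce the same shape and data type, for example by broadcasting x to the shape of z:

import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class SingleIfNetFixed(nn.Cell):
    def construct(self, x, y, z):
        if x < y:
            out = z * 0 + x  # broadcast x so both branches share shape (2)
        else:
            out = z
        out = out + 1  # the shape of out is now consistent across branches
        return out

forward_net = SingleIfNetFixed()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
z = Tensor(np.array([0, 1]), dtype=ms.int32)
output = forward_net(x, y, z)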
Using an if Statement with a Constant Condition
In example 2, out is assigned the scalar 0 in the true branch and the tensor z with shape (2) in the false branch. x and y are scalars, so x < y + 1 is a constant condition that is evaluated as true during the build phase. Therefore, only the content of the true branch exists on the network, and there is no control flow operator. The data type of the input out of out = out + 1 is fixed, so the test case executes properly.
Example 2:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class SingleIfNet(nn.Cell):
    def construct(self, z):
        x = 0
        y = 1
        if x < y + 1:
            out = x
        else:
            out = z
        out = out + 1
        return out

forward_net = SingleIfNet()
z = Tensor(np.array([0, 1]), dtype=ms.int32)
output = forward_net(z)
Using the for Statement
The for statement expands the loop body. In example 3, the for loop runs three times, so the structure of the generated execution graph is the same as that of example 4. Therefore, the number of subgraphs and operators in a network that uses the for statement depends on the number of for iterations, and too many operators or subgraphs consume more hardware resources than are available. If the for statement produces too many subgraphs, you can refer to the while writing style and try to convert the for statement into a while statement whose condition is a variable, as demonstrated in example 6 below.
Example 3:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class IfInForNet(nn.Cell):
    def construct(self, x, y):
        out = 0
        for i in range(0, 3):
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
        return out

forward_net = IfInForNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
Example 4:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class IfInForNet(nn.Cell):
    def construct(self, x, y):
        out = 0
        # cycle 0
        if x + 0 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1
        # cycle 1
        if x + 1 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1
        # cycle 2
        if x + 2 < y:
            out = out + x
        else:
            out = out + y
        out = out + 1
        return out

forward_net = IfInForNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
Using the while Statement
The while statement is more flexible than the for statement. When the condition of while is a constant, while processes and expands the loop body in a similar way to for. When the condition of while is a variable, while does not expand the loop body; instead, a control flow operator is generated and evaluated during graph execution.
Using a while Statement with a Constant Condition
As shown in example 5, the condition i < 3 is a constant, so the content of the while loop body is copied three times. Therefore, the generated execution graph is the same as that of example 4. When the while condition is a constant, the number of operators and subgraphs is proportional to the number of while iterations, and too many operators or subgraphs consume more hardware resources than are available.
Example 5:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class IfInWhileNet(nn.Cell):
    def construct(self, x, y):
        i = 0
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            i = i + 1
        return out

forward_net = IfInWhileNet()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)
Using a while Statement with a Variable Condition
As shown in example 6, the while condition is changed to a variable (the tensor input i), so while is not expanded. The final network output is the same as that of example 5, but the structure of the execution graph is different. The unexpanded execution graph in example 6 has fewer operators and more subgraphs, which shortens the build time and uses less device memory, but incurs extra performance overheads from executing the control flow operator and transferring data between subgraphs.
Example 6:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class IfInWhileNet(nn.Cell):
    def construct(self, x, y, i):
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            i = i + 1
        return out

forward_net = IfInWhileNet()
i = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
When the condition of while is a variable, the while loop body cannot be expanded, and the expressions in the loop body are evaluated during the run of each step. Therefore, computations on types other than tensors, such as operations on scalars, lists, and tuples, cannot appear in the loop body: these computations must be completed during graph build, which conflicts with the way while is evaluated during execution. As shown in example 7, the condition i < 3 is a variable condition, but the loop body contains the scalar computation j = j + 1. As a result, an error occurs during graph build.
Example 7:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class IfInWhileNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.nums = [1, 2, 3]

    def construct(self, x, y, i):
        j = 0
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + self.nums[j]
            i = i + 1
            j = j + 1
        return out

forward_net = IfInWhileNet()
i = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
The error information in example 7 is as follows:
IndexError: mindspore/core/abstract/prim_structures.cc:178 InferTupleOrListGetItem] list_getitem evaluator index should be in range[-3, 3), but got 3.
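If the loop bound is actually known at build time, one hedged workaround (an assumption, not the documented fix) is to make the while condition constant, so that the loop is expanded and the scalar j is resolved during graph build:

import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class IfInWhileNetFixed(nn.Cell):
    def __init__(self):
        super().__init__()
        self.nums = [1, 2, 3]

    def construct(self, x, y):
        i = 0  # Python scalar instead of a tensor input: the condition becomes constant
        j = 0
        out = x
        while i < 3:  # constant condition: the loop body is expanded during build
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + self.nums[j]  # j is now resolved during graph build
            i = i + 1
            j = j + 1
        return out

forward_net = IfInWhileNetFixed()
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y)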
When the while condition is a variable, the input shape of an operator cannot be changed in the loop body. MindSpore requires that the input shape of the same operator on the network be determined during graph build, whereas changing an operator's input shape in the while loop body would take effect during graph execution. As shown in example 8, the condition i < 3 is a variable condition, so while is not expanded, but the ExpandDims operator in the loop body changes the input shape of the expression out = out + 1 in the next iteration. As a result, an error occurs during graph build.
Example 8:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms
from mindspore import ops

context.set_context(mode=context.GRAPH_MODE)

class IfInWhileNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.expand_dims = ops.ExpandDims()

    def construct(self, x, y, i):
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1
            out = self.expand_dims(out, -1)
            i = i + 1
        return out

forward_net = IfInWhileNet()
i = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)
The error information in example 8 is as follows:
ValueError: mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc:734 ProcessEvalResults] The return values of different branches do not match. Shape Join Failed: shape1 = (1, 1), shape2 = (1)..
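A hedged sketch of one way around this (the IfInWhileNetFixed network is hypothetical): keep every operator's input shape fixed inside the loop and apply the shape-changing ExpandDims only after the loop ends.

import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms
from mindspore import ops

context.set_context(mode=context.GRAPH_MODE)

class IfInWhileNetFixed(nn.Cell):
    def __init__(self):
        super().__init__()
        self.expand_dims = ops.ExpandDims()

    def construct(self, x, y, i):
        out = x
        while i < 3:
            if x + i < y:
                out = out + x
            else:
                out = out + y
            out = out + 1  # the input shape of this add stays fixed across iterations
            i = i + 1
        out = self.expand_dims(out, -1)  # the shape change happens outside the loop
        return out

forward_net = IfInWhileNetFixed()
i = Tensor(np.array(0), dtype=ms.int32)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = forward_net(x, y, i)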
Constraints
In addition to the constraints in the variable-condition scenarios described above, the process control statements currently have constraints in other specific scenarios.
Side Effect
When a process control statement with a variable condition is used, the network model generated after graph build contains control flow operators, and in this scenario the forward graph is executed twice. If the forward graph contains side effect operators such as Assign in a training scenario, the computation result of the backward graph is therefore inconsistent with the expected result.

As shown in example 9, the expected gradient of x is 2, but the actual gradient is 3. The reason is that the forward graph is executed twice, so tmp = self.var + 1 and self.assign(self.var, tmp) are each executed twice. As a result, out = (self.var + 1) * x is actually out = (2 + 1) * x, and the gradient result is incorrect.
Example 9:
import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms
from mindspore import ops
from mindspore.ops import composite
from mindspore import Parameter

context.set_context(mode=context.GRAPH_MODE)

class ForwardNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.var = Parameter(Tensor(np.array(0), ms.int32))
        self.assign = ops.Assign()

    def construct(self, x, y):
        if x < y:
            tmp = self.var + 1
            self.assign(self.var, tmp)
        out = (self.var + 1) * x
        out = out + 1
        return out

class BackwardNet(nn.Cell):
    def __init__(self, net):
        super(BackwardNet, self).__init__(auto_prefix=False)
        self.forward_net = net
        self.grad = composite.GradOperation()

    def construct(self, *inputs):
        grads = self.grad(self.forward_net)(*inputs)
        return grads

forward_net = ForwardNet()
backward_net = BackwardNet(forward_net)
x = Tensor(np.array(0), dtype=ms.int32)
y = Tensor(np.array(1), dtype=ms.int32)
output = backward_net(x, y)
print("output:", output)
The execution result is as follows:
output: 3
The following side effect operators are not supported in the control flow training scenario:

Assign
AssignAdd
AssignSub
ScalarSummary
ImageSummary
TensorSummary
HistogramSummary
ScatterAdd
ScatterDiv
ScatterMax
ScatterMin
ScatterMul
ScatterNdAdd
ScatterNdSub
ScatterNdUpdate
ScatterNonAliasingAdd
ScatterSub
ScatterUpdate
Infinite Loop
If the value of cond in the expression while cond: is always the scalar True, an unexpected exception may be raised even if there is a break or return in the while body.
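As a minimal sketch of the pattern to avoid (the DeadCycleNet network is hypothetical; building or running it may raise an unexpected exception in GRAPH_MODE):

import numpy as np
from mindspore import context
from mindspore import Tensor, nn
from mindspore import dtype as ms

context.set_context(mode=context.GRAPH_MODE)

class DeadCycleNet(nn.Cell):
    def construct(self, x):
        while True:  # cond is always the scalar True
            x = x + 1
            if x > 3:
                break  # the break does not guarantee a successful build
        return x

# Calling the network may raise an unexpected exception:
# net = DeadCycleNet()
# output = net(Tensor(np.array(0), dtype=ms.int32))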