Network Compilation
Q: What can I do if an error “Create python object `<class ‘mindspore.common.tensor.Tensor’>` failed, only support create Cell or Primitive object.” is reported?
A: Currently, in graph mode, the construct function (or a function decorated by the @ms_function decorator) only supports constructing Cell and Primitive objects. Constructing a Tensor, that is, the syntax x = Tensor(args...), is not supported.
If it is a constant tensor, please define it in the __init__ function. If not, you can decorate a function with the @constexpr decorator and generate the Tensor inside that function.
Please see the usage of @constexpr in https://www.mindspore.cn/docs/api/en/r1.6/api_python/ops/mindspore.ops.constexpr.html.
A constant Tensor used in the network can be defined as a network attribute in __init__, that is, self.x = Tensor(args...). The constant can then be used in the construct function (or the function decorated by the @ms_function decorator).
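In the following example, a Tensor with shape = (3, 4) and dtype = int64 is generated by @constexpr.
import numpy as np
from mindspore import Tensor
from mindspore.ops import constexpr
@constexpr
def generate_tensor():
    return Tensor(np.ones((3, 4)).astype(np.int64))  # astype() is applied to the array, not to the shape tuple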
Q: What can I do if an error “‘self.xx’ should be defined in the class ‘__init__’ function.” is reported?
A: If you want to assign a value to a class member such as self.xx in the construct function, self.xx must first have been defined as a Parameter; other types are not supported. Local variables such as xx are not subject to this restriction.
Q: What can I do if an error “This comparator ‘AnyValue’ is not supported. For statement ‘is’, only support compare with ‘None’, ‘False’ or ‘True’” is reported?
A: For the is and is not syntax, MindSpore currently only supports comparisons with True, False and None. Other types, such as strings, are not supported.
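Before compiling, it can help to scan a construct function for offending comparisons. The sketch below is a hypothetical plain-Python pre-check (unsupported_is_comparisons is not a MindSpore API) built on the standard ast module. It assumes the rule stated above is the only one that applies and only inspects the right-hand operand of each is or is not.

```python
import ast

ALLOWED = (None, True, False)

def unsupported_is_comparisons(src):
    """Return (line, text) for each `is`/`is not` whose right operand is not None/True/False."""
    problems = []
    for node in ast.walk(ast.parse(src)):
        if not isinstance(node, ast.Compare):
            continue
        for op, right in zip(node.ops, node.comparators):
            if isinstance(op, (ast.Is, ast.IsNot)):
                # Use identity checks so that 1 or 0 is not mistaken for True or False.
                allowed = isinstance(right, ast.Constant) and any(right.value is v for v in ALLOWED)
                if not allowed:
                    problems.append((node.lineno, ast.get_source_segment(src, node)))
    return problems

src = (
    "def construct(self, x, mode):\n"
    "    if x is None:\n"
    "        return x\n"
    '    if mode is "train":\n'
    "        return x\n"
    "    return x\n"
)
print(unsupported_is_comparisons(src))  # [(4, 'mode is "train"')]
```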
Q: What can I do if an error “MindSpore does not support comparison with operators more than one now, ops size =2” is reported?
A: MindSpore supports at most one operator in a comparison statement. Please modify your code. For example, you can use 1 < x and x < 3 instead of 1 < x < 3.
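The "ops size" in the message appears to correspond to the number of comparison operators in a single Python comparison expression. The plain-Python sketch below (no MindSpore involved; compare_op_counts is a hypothetical helper) uses the standard ast module to show that 1 < x < 3 parses as one comparison with two operators, while the suggested rewrite parses as two comparisons with one operator each. It also checks that both forms agree on ordinary integers.

```python
import ast

def compare_op_counts(expr):
    """Return the number of operators in each ast.Compare node of an expression."""
    tree = ast.parse(expr, mode="eval")
    return [len(node.ops) for node in ast.walk(tree) if isinstance(node, ast.Compare)]

print(compare_op_counts("1 < x < 3"))        # [2]: one comparison with two operators
print(compare_op_counts("1 < x and x < 3"))  # [1, 1]: two comparisons, one operator each

# For plain Python integers the rewrite is equivalent.
assert all((1 < x < 3) == (1 < x and x < 3) for x in range(-5, 6))
```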
Q: What can I do if an error “TypeError: The function construct need 1 positional argument and 0 default argument, but provided 2” is reported?
A: When you call a network instance, its construct function is executed. The program checks the number of parameters required by construct against the number of arguments actually given. If they are not equal, the above exception is thrown.
Please check your code to make sure that the number of inputs passed to the network matches the parameters of construct.
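Because construct is an ordinary Python method, the standard inspect module can show which parameters it expects before you call the network. The sketch below uses a plain-Python stand-in class Net (an assumption; a real network would subclass nn.Cell) and a hypothetical helper check_call that binds the arguments the same way a normal call would.

```python
import inspect

class Net:
    """Plain-Python stand-in for a network whose construct takes one input."""
    def construct(self, x):
        return x

def check_call(fn, *args):
    """Return None if fn accepts args, otherwise the binding error message."""
    try:
        inspect.signature(fn).bind(*args)
        return None
    except TypeError as err:
        return str(err)

net = Net()
print(inspect.signature(net.construct))  # (x)
print(check_call(net.construct, 1))      # None
print(check_call(net.construct, 1, 2))   # too many positional arguments
```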
Q: What can I do if an error “Type Join Failed” or “Shape Join Failed” is reported?
A: In the inference stage of front-end compilation, the abstract types of nodes, including type and shape, are inferred. Common abstract types include AbstractScalar, AbstractTensor, AbstractFunction, AbstractTuple, AbstractList, etc. In some scenarios, such as multi-branch scenarios, the abstract types of the return values of different branches are joined to infer the abstract type of the returned result. If these abstract types do not match, or their types or shapes are inconsistent, the above exception is thrown.
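The join step can be pictured with a toy model. In the sketch below, ToyTensor, ToyTuple and join are hypothetical plain-Python stand-ins, not MindSpore classes: each branch result is reduced to its kind, dtype and shape, and join either merges two identical abstracts or raises. The choice of TypeError for dtype and kind mismatches and ValueError for shape mismatches mirrors the error outputs shown in this answer. In this toy model, shapes must match exactly.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ToyTensor:
    """Toy stand-in for a tensor abstract: only shape and dtype."""
    shape: Tuple[int, ...]
    dtype: str

@dataclass(frozen=True)
class ToyTuple:
    """Toy stand-in for a tuple abstract."""
    elements: tuple

def join(a, b):
    """Join the abstracts of two branch results; raise if they cannot be merged."""
    if type(a) is not type(b):
        raise TypeError(f"Type Join Failed: abstract type {type(a).__name__} "
                        f"cannot join with {type(b).__name__}")
    if isinstance(a, ToyTuple):
        if len(a.elements) != len(b.elements):
            raise ValueError("toy join failed: tuple element counts differ")
        return ToyTuple(tuple(join(x, y) for x, y in zip(a.elements, b.elements)))
    if a.dtype != b.dtype:
        raise TypeError(f"Type Join Failed: dtype1 = {a.dtype}, dtype2 = {b.dtype}")
    if a.shape != b.shape:
        raise ValueError(f"Shape Join Failed: shape1 = {a.shape}, shape2 = {b.shape}")
    return a

true_branch = ToyTensor((2, 3, 4, 5), "Float32")
false_branch = ToyTensor((2, 3, 4, 5), "Float16")
try:
    join(true_branch, false_branch)
except TypeError as err:
    print(err)  # Type Join Failed: dtype1 = Float32, dtype2 = Float16
```

Making both branches return the same kind, dtype and shape is what lets the join succeed.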
When an error similar to “Type Join Failed: dtype1 = Float32, dtype2 = Float16” appears, it means that the data types are inconsistent, which causes an exception when joining the abstracts. The error can be quickly located from the reported data types and code line. In addition, the error message provides the specific abstract and node information, and you can view the MindIR information in the analyze_fail.dat file to locate and solve the problem. For details about MindIR, please refer to MindSpore IR (MindIR). The code sample is as follows:
import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import nn, Tensor, context
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()
        self.cast = ops.Cast()
    def construct(self, x, a, b):
        if a > b:  # The types of the two branches are inconsistent.
            return self.relu(x)  # shape: (2, 3, 4, 5), dtype: Float32
        else:
            return self.cast(self.relu(x), ms.float16)  # shape: (2, 3, 4, 5), dtype: Float16
input_x = Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = Tensor(2, ms.float32)
input_b = Tensor(6, ms.float32)
net = Net()
out_me = net(input_x, input_a, input_b)
The result is as follows:
TypeError: The return values of different branches do not match. Type Join Failed: dtype1 = Float32, dtype2 = Float16. The abstract type of the return value of the current branch is AbstractTensor(shape: (2, 3, 4, 5), element: AbstractScalar(Type: Float16, Value: AnyValue, Shape: NoShape), value_ptr: 0x32ed00e0, value: AnyValue), and that of the previous branch is AbstractTensor(shape: (2, 3, 4, 5), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x32ed00e0, value: AnyValue). Please check the node construct.4:[CNode]5{[0]: [CNode]6}, true branch: ✓construct.2, false branch: ✗construct.3. trace:
In file test_join.py(14)/ if a > b:/
The function call stack (See file 'analyze_fail.dat' for more details):
# 0 In file test_join.py(14)
if a > b:
When an error similar to “Shape Join Failed: shape1 = (2, 3, 4, 5), shape2 = ()” appears, it means that the shapes are inconsistent, which causes an exception when joining the abstracts. The code sample is as follows:
import numpy as np
import mindspore as ms
import mindspore.ops as ops
from mindspore import nn, Tensor, context
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()
        self.reducesum = ops.ReduceSum()
    def construct(self, x, a, b):
        if a > b:  # The shapes of the two branches are inconsistent.
            return self.relu(x)  # shape: (2, 3, 4, 5), dtype: Float32
        else:
            return self.reducesum(x)  # shape: (), dtype: Float32
input_x = Tensor(np.random.rand(2, 3, 4, 5).astype(np.float32))
input_a = Tensor(2, ms.float32)
input_b = Tensor(6, ms.float32)
net = Net()
out = net(input_x, input_a, input_b)
The result is as follows:
ValueError: The return values of different branches do not match. Shape Join Failed: shape1 = (2, 3, 4, 5), shape2 = (). The abstract type of the return value of the current branch is AbstractTensor(shape: (), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x239b5120, value: AnyValue), and that of the previous branch is AbstractTensor(shape: (2, 3, 4, 5), element: AbstractScalar(Type: Float32, Value: AnyValue, Shape: NoShape), value_ptr: 0x239b5120, value: AnyValue). Please check the node construct.4:[CNode]5{[0]: [CNode]6}, true branch: ✓construct.2, false branch: ✗construct.3. trace:
In file test_join1.py(14)/ if a > b:/
The function call stack (See file 'analyze_fail.dat' for more details):
# 0 In file test_join1.py(14)
if a > b:
When an error similar to “Type Join Failed: abstract type AbstractTensor can not join with AbstractTuple” appears, it means that the two abstract types do not match, which causes an exception when joining the abstracts. The code sample is as follows:
import mindspore.ops as ops
from mindspore import Tensor, ms_function
x = Tensor([1.0])
y = Tensor([2.0])
grad = ops.GradOperation(get_by_list=False, sens_param=True)
sens = 1.0
def test_net(a, b):
    return a, b
@ms_function()
def join_fail():
    sens_i = ops.Fill()(ops.DType()(x), ops.Shape()(x), sens)  # sens_i is a Tensor with shape: (1), dtype: Float64, value: 1.0
    # sens_i = (sens_i, sens_i)
    a = grad(test_net)(x, y, sens_i)  # test_net returns tuple(Tensor, Tensor), so sens_i must have the same type, but it is a single Tensor. Setting sens_i = (sens_i, sens_i) before grad fixes the problem.
    return a
join_fail()
The result is as follows:
TypeError: mindspore/core/abstract/abstract_value.cc:48 AbstractTypeJoinLogging] Type Join Failed: abstract type AbstractTensor cannot not join with AbstractTuple. For more details, please refer to the FAQ at https://www.mindspore.cn. this: AbstractTensor(shape: (1), element: AbstractScalar(Type: Float64, Value: AnyValue, Shape: NoShape), value_ptr: 0x55f643f283d0, value: Tensor(shape=[1], dtype=Float64, value= [ 1.00000000e+00])), other: AbstractTuple(element[0]: AbstractTensor(shape: (1), element: AbstractScalar(Type: Float64, Value: AnyValue, Shape: NoShape), value_ptr: 0x55f64473a500, value: Tensor(shape=[1], dtype=Float64, value= [ 1.00000000e+00])), element[1]: AbstractTensor(shape: (1), element: AbstractScalar(Type: Float64, Value: AnyValue, Shape: NoShape), value_ptr: 0x55f6447042c0, value: Tensor(shape=[1], dtype=Float64, value= [ 2.00000000e+00]))). Please check the node test_net.2:test_net{[0]: test_net, [1]: test_net}. trace:
In file test_shape_join_failed.py(9)/def test_net(a, b):/
In file test_shape_join_failed.py(15)/ a = grad(test_net)(x, y, sens_i)/
The function call stack (See file 'analyze_fail.dat' for more details):
# 0 In file test_shape_join_failed.py(15)
a = grad(test_net)(x, y, sens_i)
^
# 1 In file test_shape_join_failed.py(9)
def test_net(a, b):
^
Q: What can I do if an error “The params of function ‘bprop’ of Primitive or Cell requires the forward inputs as well as the ‘out’ and ‘dout’” is reported?
A: The inputs of a user-defined backpropagation function bprop must contain all the inputs of the forward pass, plus out and dout. An example follows:
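import mindspore.nn as nn
from mindspore.ops import operations as P
class BpropUserDefinedNet(nn.Cell):
    def __init__(self):
        super(BpropUserDefinedNet, self).__init__()
        self.zeros_like = P.ZerosLike()
    def construct(self, x, y):
        return x + y
    # bprop takes the forward inputs (x, y), then out (the forward output) and dout (the gradient of out).
    def bprop(self, x, y, out, dout):
        # Return one gradient per forward input.
        return self.zeros_like(out), self.zeros_like(out)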
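The rule above can be checked mechanically with the standard inspect module. In the sketch below, check_bprop_signature is a hypothetical helper, and GoodCell and BadCell are plain-Python stand-ins rather than nn.Cell subclasses. Since construct and bprop are ordinary methods, the same inspect calls should also work on a real Cell class. The sketch only compares the parameter count; it assumes MindSpore does not require the exact names out and dout.

```python
import inspect

def check_bprop_signature(cell_cls):
    """Hypothetical check: bprop must take the forward inputs plus out and dout."""
    forward = list(inspect.signature(cell_cls.construct).parameters)[1:]  # drop self
    backward = list(inspect.signature(cell_cls.bprop).parameters)[1:]     # drop self
    expected = forward + ["out", "dout"]
    # Only the count is compared; the names are returned for guidance.
    return len(backward) == len(expected), expected, backward

class GoodCell:
    def construct(self, x, y):
        return x + y
    def bprop(self, x, y, out, dout):
        return dout, dout

class BadCell:
    def construct(self, x, y):
        return x + y
    def bprop(self, out, dout):  # missing the forward inputs x and y
        return dout, dout

print(check_bprop_signature(GoodCell))  # (True, ['x', 'y', 'out', 'dout'], ['x', 'y', 'out', 'dout'])
print(check_bprop_signature(BadCell))   # (False, ['x', 'y', 'out', 'dout'], ['out', 'dout'])
```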
Q: What can I do if an error “There isn’t any branch that can be evaluated” is reported?
A: This error means that there may be infinite recursion or an infinite loop in the code, so that no branch of the if condition can be inferred with correct type and dimension information.
An example follows:
from mindspore import Tensor, ms_function
from mindspore import dtype as mstype
import mindspore.context as context
ZERO = Tensor([0], mstype.int32)
ONE = Tensor([1], mstype.int32)
@ms_function
def f(x):
    y = ZERO
    if x < 0:
        y = f(x - 3)
    elif x < 3:
        y = x * f(x - 1)
    elif x < 5:
        y = x * f(x - 2)
    else:
        y = f(x - 4)
    z = y + 1
    return z
def test_endless():
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor([5], mstype.int32)
    f(x)
In this example, every branch of the if statement in f calls f again and none of them returns without recursing, so the type and dimension information of f(x) cannot be deduced for any branch.
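To see why no branch can be inferred, the same control flow can be run as plain Python on integers. In this version, ZERO is replaced by 0 and no MindSpore is involved. Every branch of f calls f again, so f(5) never reaches a return and Python raises RecursionError. The f_fixed variant adds an early return for negative x. That base case is an arbitrary choice for illustration, and this sketch does not claim that graph mode compiles the fixed recursion.

```python
def f(x):
    y = 0
    if x < 0:
        y = f(x - 3)       # x stays negative, so this recurses forever
    elif x < 3:
        y = x * f(x - 1)
    elif x < 5:
        y = x * f(x - 2)
    else:
        y = f(x - 4)
    return y + 1

def f_fixed(x):
    if x < 0:              # hypothetical base case: stop once x is negative
        return 1
    if x < 3:
        y = x * f_fixed(x - 1)
    elif x < 5:
        y = x * f_fixed(x - 2)
    else:
        y = f_fixed(x - 4)
    return y + 1

def terminates(fn, arg):
    """Return True if fn(arg) finishes within Python's recursion limit."""
    try:
        fn(arg)
        return True
    except RecursionError:
        return False

print(terminates(f, 5))  # False: 5 -> 1 -> 0 -> -1 -> -4 -> -7 -> ...
print(f_fixed(5))        # 3
```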
Q: What can I do if an error “Exceed function call depth limit 1000” is reported?
A: This indicates that there is an infinite recursive loop in the code, or that the code is so complex that the call stack depth limit is exceeded.
In this case, you can call context.set_context(max_call_depth=value) to change the maximum call depth, and consider simplifying the code logic or checking whether there is infinite recursion or an infinite loop in the code.
Note that although setting max_call_depth raises MindSpore's recursion depth limit, a deeper recursion may then exceed the maximum depth of the system stack and cause a segmentation fault. In that case, you may also need to increase the system stack size, for example with ulimit -s on Linux.
Q: Why is an error ‘could not get source code’ or ‘Mindspore can not compile temporary source code in terminal. Please write source code to a python file and run the file.’ reported?
A: When compiling a network, MindSpore uses inspect.getsourcelines(self.fn) to get the source code. If the network is temporary code edited in a terminal, the source code cannot be retrieved and MindSpore reports the error above. To solve the problem, write the network to a Python file and run the file.
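The mechanism can be reproduced with the standard library alone. In this sketch (plain Python, no MindSpore; source_available is a hypothetical helper), the same function is compiled twice. The first time it gets a pseudo-filename, similar to code typed into an interactive terminal, and the second time it comes from a real .py file. inspect.getsourcelines only succeeds for the second case.

```python
import inspect
import os
import tempfile

SRC = "def construct(x):\n    return x + 1\n"

def source_available(fn):
    """Return True if inspect can recover fn's source lines."""
    try:
        inspect.getsourcelines(fn)
        return True
    except OSError:  # e.g. "could not get source code"
        return False

# Case 1: code with no backing file (pseudo-filename, as for terminal input).
ns_terminal = {}
exec(compile(SRC, "<string>", "exec"), ns_terminal)

# Case 2: the same code saved to a .py file first.
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(SRC)
    path = f.name
ns_file = {}
exec(compile(SRC, path, "exec"), ns_file)

from_terminal = source_available(ns_terminal["construct"])
from_file = source_available(ns_file["construct"])
os.remove(path)
print(from_terminal, from_file)  # False True
```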
Q: What do ‘Corresponding forward node candidate:’ and ‘Corresponding code candidate:’ in an error message mean?
A: “Corresponding forward node candidate:” is followed by code in the associated forward network and indicates the forward code that the backpropagation operator corresponds to. “Corresponding code candidate:” means that the operator was fused from the listed code, and the separator “-” is used to distinguish the different pieces of code.
For example:
The operator FusionOp_BNTrainingUpdate_ReLUV2 reported an error and printed the following code:
Corresponding code candidate: - In file /home/workspace/mindspore/build/package/mindspore/nn/layer/normalization.py(212)/ return self.bn_train(x,/ In file /home/workspace/mindspore/tests/st/tbe_networks/resnet.py(265)/ x = self.bn1(x)/ In file /home/workspace/mindspore/build/package/mindspore/nn/wrap/cell_wrapper.py(109)/ out = self._backbone(data)/ In file /home/workspace/mindspore/build/package/mindspore/nn/wrap/cell_wrapper.py(356)/ loss = self.network(*inputs)/ In file /home/workspace/mindspore/build/package/mindspore/train/dataset_helper.py(98)/ return self.network(*outputs)/ - In file /home/workspace/mindspore/tests/st/tbe_networks/resnet.py(266)/ x = self.relu(x)/ In file /home/workspace/mindspore/build/package/mindspore/nn/wrap/cell_wrapper.py(109)/ out = self._backbone(data)/ In file /home/workspace/mindspore/build/package/mindspore/nn/wrap/cell_wrapper.py(356)/ loss = self.network(*inputs)/ In file /home/workspace/mindspore/build/package/mindspore/train/dataset_helper.py(98)/ return self.network(*outputs)/
The code call stack after the first separator points to ‘x = self.bn1(x)’ on line 265 of the network script file, and the code call stack after the second separator points to ‘x = self.relu(x)’ on line 266 of the network script file. Thus, the operator FusionOp_BNTrainingUpdate_ReLUV2 is a fusion of these two lines of code.
The operator Conv2DBackpropFilter reported an error and printed the following code:
In file /home/workspace/mindspore/build/package/mindspore/ops/_grad/grad_nn_ops.py(65)/ dw = filter_grad(dout, x, w_shape)/ Corresponding forward node candidate: - In file /home/workspace/mindspore/build/package/mindspore/nn/layer/conv.py(266)/ output = self.conv2d(x, self.weight)/ In file /home/workspace/mindspore/tests/st/tbe_networks/resnet.py(149)/ out = self.conv1(x)/ In file /home/workspace/mindspore/tests/st/tbe_networks/resnet.py(195)/ x = self.a(x)/ In file /home/workspace/mindspore/tests/st/tbe_networks/resnet.py(270)/ x = self.layer2(x)/ In file /home/workspace/mindspore/build/package/mindspore/nn/wrap/cell_wrapper.py(109)/ out = self._backbone(data)/ In file /home/workspace/mindspore/build/package/mindspore/nn/wrap/cell_wrapper.py(356)/ loss = self.network(*inputs)/ In file /home/workspace/mindspore/build/package/mindspore/train/dataset_helper.py(98)/ return self.network(*outputs)/
The first line is the source code corresponding to the operator; the operator is a bprop operator implemented by MindSpore. The second line indicates that the operator has an associated forward node and points to ‘out = self.conv1(x)’ on line 149 of the network script file. In summary, the operator Conv2DBackpropFilter is a bprop operator whose corresponding forward node is a convolution operator.
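These one-line call stacks are easier to read once they are split apart. The helper below, parse_candidates, is a hypothetical plain-Python sketch and not part of MindSpore. It assumes the format printed in this answer: frames of the form "In file PATH(LINE)/ CODE/", with groups separated by " - ". Code text that itself contains "/" would be split incorrectly. The sample paths are shortened versions of the output above.

```python
import re

# One frame looks like: "In file /path/to/file.py(265)/ x = self.bn1(x)/"
FRAME = re.compile(r"In file (?P<file>.+?)\((?P<line>\d+)\)/(?P<code>.*?)/")

def parse_candidates(message):
    """Split a 'Corresponding ... candidate:' message into candidate groups.

    Each group is a list of (file, line, code) frames, innermost call first.
    """
    _, _, body = message.partition("candidate:")
    groups = []
    for chunk in re.split(r"\s-\s+(?=In file)", body):
        frames = [(m["file"], int(m["line"]), m["code"].strip())
                  for m in FRAME.finditer(chunk)]
        if frames:
            groups.append(frames)
    return groups

sample = ("Corresponding code candidate: "
          "- In file /ws/nn/layer/normalization.py(212)/ return self.bn_train(x,/ "
          "In file /ws/resnet.py(265)/ x = self.bn1(x)/ "
          "- In file /ws/resnet.py(266)/ x = self.relu(x)/")
for i, frames in enumerate(parse_candidates(sample), 1):
    print(f"candidate {i}:", [(path, line) for path, line, _ in frames])
```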
Q: What is “JIT Fallback”? What can I do if an error “Should not use Python object in runtime” is reported?
A: JIT Fallback aims to unify static graph mode and dynamic graph mode from the perspective of static graphs. With the JIT Fallback feature, static graph mode can support as much dynamic graph mode syntax as possible, providing a syntax experience close to that of dynamic graph mode. JIT Fallback is controlled by the environment variable DEV_ENV_ENABLE_FALLBACK and is enabled by default.
When the errors “Should not use Python object in runtime” and “We suppose all nodes generated by JIT Fallback would not return to outside of graph” appear, it means that there is incorrect syntax in the code. When JIT Fallback processes an unsupported syntax expression, it generates corresponding nodes, which must be inferred and executed at compile time; otherwise, these nodes throw an error when they are passed to the runtime. JIT Fallback currently supports only some constant scenarios in graph mode, and the code must still conform to MindSpore’s programming syntax. Please refer to Static Graph Syntax Support.
For example, when calling the third-party library NumPy, JIT Fallback supports the syntax np.add(x, y) and Tensor(np.add(x, y)), but MindSpore does not support returning a NumPy type. Therefore, the following program reports an error. The code sample is as follows:
import numpy as np
import mindspore.nn as nn
from mindspore import context
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def construct(self, x, y):
        out = np.add(x, y)
        return out
net = Net()
out = net(1, 1)
The result is as follows:
RuntimeError: mindspore/ccsrc/pipeline/jit/validator.cc:139 ValidateValueNode] Should not use Python object in runtime, node: ValueNode<InterpretedObject> InterpretedObject: '2'
We suppose all nodes generated by JIT Fallback not return to outside of graph.
# In file test.py(9)
out = np.add(x, y)
^
When there is an error related to JIT Fallback, please review the code syntax and modify it according to Static Graph Syntax Support and the reported code line. If you need to turn off JIT Fallback, you can run export DEV_ENV_ENABLE_FALLBACK=0.
Q: What can I do if an error “Operator[AddN] input(kNumberTypeBool,kNumberTypeBool) output(kNumberTypeBool) is not support. This error means the current input type is not supported, please refer to the MindSpore doc for supported types.” is reported?
A: Currently, MindSpore has limited support for Tensors of the bool data type, and only a few primitives support Tensor(bool). Even if Tensor(bool) is used correctly in the forward graph, computing the total derivative in the backward graph uses the primitive AddN, which does not support Tensor(bool), so an exception is raised.
An example follows:
from mindspore import context, ops, ms_function, Tensor, dtype
context.set_context(save_graphs=True, save_graphs_path='graph_path')
@ms_function
def test_logic(x, y):
    z = x and y
    return z and x
x = Tensor(True, dtype.bool_)
y = Tensor(True, dtype.bool_)
grad = ops.GradOperation(get_all=True)
grad_net = grad(test_logic)
out = grad_net(x, y)
The forward computation of the above code can be expressed as r = f(z, x), z = z(x, y), and the corresponding total derivative is dr/dx = df/dz * dz/dx + df/dx, where both f(z, x) and z(x, y) are the primitive and. The primitive and in the forward graph supports Tensor(bool), but the primitive AddN, which the backward graph uses to add up the two gradient terms, does not support Tensor(bool). Moreover, the error cannot be mapped to a specific line of forward code.
The result is as follows:
Traceback (most recent call last):
File "grad_fail.py", line 14, in <module>
out = grad_net(x, y)
File "/usr/local/python3.7/lib/python3.7/site-packages/mindspore/common/api.py", line 307, in staging_specialize
out = _MindsporeFunctionExecutor(func, ms_create_time, input_signature, process_obj)(*args)
File "/usr/local/python3.7/lib/python3.7/site-packages/mindspore/common/api.py", line 79, in wrapper
results = fn(*arg, **kwargs)
File "/usr/local/python3.7/lib/python3.7/site-packages/mindspore/common/api.py", line 221, in __call__
phase = self.compile(args_list, arg_names, parse_method)
File "/usr/local/python3.7/lib/python3.7/site-packages/mindspore/common/api.py", line 195, in compile
self.enable_tuple_broaden)
TypeError: mindspore/ccsrc/runtime/device/cpu/kernel_select_cpu.cc:235 KernelNotSupportException] Operator[AddN] input(kNumberTypeBool,kNumberTypeBool) output(kNumberTypeBool) is not support. This error means the current input type is not supported, please refer to the MindSpore doc for supported types.
Trace:
In file /usr/local/python3.7/lib/python3.7/site-packages/mindspore/ops/composite/multitype_ops/add_impl.py(287)/ return F.addn((x, y))/
If you encounter a problem like this one, please remove the use of Tensor(bool). In this example, replacing Tensor(bool) with bool solves the problem.
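AddN appears in the backward graph because x is used twice in the forward computation (in z = z(x, y) and in r = f(z, x)). Its gradient therefore has two contributions that must be summed; the traceback earlier in this answer shows this summation going through F.addn in add_impl.py. The toy reverse-mode autodiff below is plain Python and not MindSpore's implementation. It shows the same accumulation and uses multiplication instead of logical and, because and has no meaningful derivative. For r = (x * y) * x, the derivative dr/dx = df/dz * dz/dx + df/dx = x*y + z = 2xy.

```python
class Var:
    """A scalar node in a tiny reverse-mode autodiff graph (illustrative only)."""
    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # tuple of (parent Var, local derivative)
        self.grad = 0.0

def mul(a, b):
    # d(a*b)/da = b, d(a*b)/db = a
    return Var(a.value * b.value, ((a, b.value), (b, a.value)))

def backward(out):
    # Depth-first post-order gives a topological order of the graph.
    order, seen = [], set()
    def visit(v):
        if id(v) not in seen:
            seen.add(id(v))
            for parent, _ in v.parents:
                visit(parent)
            order.append(v)
    visit(out)
    out.grad = 1.0
    for v in reversed(order):
        for parent, local in v.parents:
            # A node used more than once receives several contributions,
            # which are summed here (the role AddN plays in the backward graph).
            parent.grad += v.grad * local

x, y = Var(3.0), Var(2.0)
z = mul(x, y)  # z = z(x, y)
r = mul(z, x)  # r = f(z, x): x is used twice
backward(r)
print(x.grad)  # 12.0, i.e. df/dz * dz/dx + df/dx = 3*2 + 6
print(y.grad)  # 9.0
```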