mindspore.rewrite

查看源文件

MindSpore的ReWrite模块为用户提供了基于自定义规则,对网络的前向计算过程进行修改的能力,如插入、删除和替换语句。

如何快速使用ReWrite,请参考 使用ReWrite修改网络

class mindspore.rewrite.Node(node: NodeImpl)[源代码]

节点是表达网络中源码语句的一种数据结构。

每一个节点通常对应一条前向计算过程展开后的语句。

节点可以表达前向计算过程的Cell调用语句、Primitive调用语句、算术运算语句、返回语句等。

参数:
  • node (NodeImpl) - Node 的内部实现实例。建议调用Node下的指定方法来创建Node,例如 create_call_cell ,而不直接 调用Node的构造函数。不需关心NodeImpl是什么,只需作为句柄看待。

static create_call_cell(cell: Cell, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None, name: str = '', is_sub_net: bool = False)[源代码]

通过该接口可以根据 cell 对象创建一个Node实例。节点对应的源代码格式:

targets = self.name(*args, **kwargs)

参数:
  • cell (Cell) - 该节点对应的前向计算的Cell对象。

  • targets (List[Union[ScopedValue, str]]) - 表示输出名称。在源代码中作为节点的输出变量名。

  • args (List[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。默认值: None ,表示 cell 没有参数输入。

  • kwargs (Dict[str, ScopedValue]) - 键的类型必须是str,值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 kwargs。默认值: None ,表示 cell 没有 kwargs 输入。

  • name (str) - 表示节点的名称。用作源代码中的字段名称。当未提供名称时,ReWrite将根据 target 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。默认值: ""

  • is_sub_net (bool) - 表示 cell 是否是一个网络。如果 is_sub_netTrue ,Rewrite将尝试将 cell 解析为TreeNode,否则为CallCell节点。默认值: False

返回:

Node实例。

异常:
  • TypeError - 如果参数 cell 不是Cell类型。

  • TypeError - 如果参数 targets 不是list类型。

  • TypeError - 如果参数 targets 的成员不是str或者ScopedValue类型。

  • TypeError - 如果参数 args 不是ScopedValue类型。

  • TypeError - 如果参数 kwargkey 不是str类型或者 value 不是ScopedValue类型。

样例:

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.insert(position, new_node)
>>> print(type(new_node))
<class 'mindspore.rewrite.api.node.Node'>
static create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None)[源代码]

通过该接口可以根据一个函数调用创建一个Node实例。

说明

函数内部的代码不会被解析。

参数:
  • function (FunctionType) - 被调用的函数定义。

  • targets (List[Union[ScopedValue, str]]) - 表示输出名称。在源代码中作为节点的输出变量名。

  • args (List[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。默认值: None ,表示 function 没有参数输入。

  • kwargs (Dict[str, ScopedValue]) - 键的类型必须是str,值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 kwargs。默认值: None ,表示 function 没有 kwargs 输入。

返回:

Node实例。

异常:
  • TypeError - 如果参数 function 不是FunctionType类型。

  • TypeError - 如果参数 targets 不是list类型。

  • TypeError - 如果参数 targets 的成员不是str或者ScopedValue类型。

  • TypeError - 如果参数 args 不是ScopedValue类型。

  • TypeError - 如果参数 kwargkey 不是str类型或者 value 不是ScopedValue类型。

样例:

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
...                                      args=[ScopedValue.create_naming_value('x')])
>>> stree.insert(position, new_node)
>>> print(new_node.get_node_type())
NodeType.CallFunction
get_args()[源代码]

获取当前节点的参数列表。

返回:

参数列表,参数类型为 ScopedValue

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> print(node.get_args())
[x]
get_inputs()[源代码]

获取一个节点列表,列表里的节点的输出作为当前节点的输入。

返回:

节点列表。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv2")
>>> inputs = node.get_inputs()
>>> print([input.get_name() for input in inputs])
['max_pool2d']
get_instance_type()[源代码]

获取当前节点对应的代码语句里调用的对象类型。

  • 如果当前节点的 node_typeCallCell,表示该节点的语句调用了一个 Cell 类型对象。

  • 如果当前节点的 node_typeCallPrimitive,表示该节点的语句调用了一个 Primitive 类型对象。

  • 如果当前节点的 node_typeTree,表示该节点的语句调用了一个网络类型的对象。

  • 如果当前节点的 node_typePythonInputOutputCallMethod,返回的对象类型是 NoneType

返回:

当前节点对应的代码语句里调用的对象类型。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> instance_type = node.get_instance_type()
>>> print(instance_type)
<class 'mindspore.nn.layer.conv.Conv2d'>
get_kwargs()[源代码]

获取当前节点的关键字参数列表。

返回:

一个包含关键字参数的字典,key的类型为str,value的类型为 ScopedValue

样例:

>>> from mindspore.rewrite import SymbolTree
>>> from mindspore import nn
>>>
>>> class ReLUNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, input):
...         output = self.relu(x=input)
...         return output
>>>
>>> net = ReLUNet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("relu")
>>> print(node.get_kwargs())
{'x': input}
get_name()[源代码]

获取当前节点的名称。当节点被插入到SymbolTree时,节点的名称在SymbolTree中应该是唯一的。

返回:

节点的名称,类型为str。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> name = node.get_name()
>>> print(name)
conv1
get_node_type()[源代码]

获取当前节点的类型。节点类型详见 mindspore.rewrite.NodeType

返回:

NodeType,当前节点的类型。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node_type = node.get_node_type()
>>> print(node_type)
NodeType.CallCell
get_sub_tree()[源代码]

获取类型为 NodeType.Tree 的节点里保存的符号树。节点类型详见 mindspore.rewrite.NodeType

返回:

保存在Tree类型节点里的符号树。

异常:
  • TypeError - 如果当前节点的类型不是 NodeType.Tree

  • AttributeError - 如果当前Tree类型节点里没有保存符号树。

样例:

>>> import mindspore.nn as nn
>>> from mindspore.rewrite import SymbolTree
>>>
>>> class SubNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.subnet = SubNet()
...
...     def construct(self, x):
...         x = self.subnet(x)
...         return x
>>>
>>> net = Net()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("subnet")
>>> print(type(node.get_sub_tree()))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
get_symbol_tree()[源代码]

获取当前节点所属的SymbolTree。

返回:

SymbolTree,如果当前节点不属于任何SymbolTree,则返回 None .

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> print(type(node.get_symbol_tree()))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
get_targets()[源代码]

获取当前节点的输出值列表。

返回:

输出值列表,参数类型为 ScopedValue

get_users()[源代码]

获取一个节点列表,列表里的节点使用当前节点的输出作为输入。

返回:

节点列表。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> users = node.get_users()
>>> print([user.get_name() for user in users])
['relu']
set_arg(index: int, arg: Union[ScopedValue, str])[源代码]

设置当前节点的输入参数。

参数:
  • index (int) - 要设置的参数索引。

  • arg (Union[ScopedValue, str]) - 新参数的值。

异常:
  • TypeError - 如果参数 index 不是int类型。

  • TypeError - 如果参数 arg 不是str或者ScopedValue类型。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("relu_3")
>>> node.set_arg(0, "fc1")
>>> print(node.get_args())
[fc1]
set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)[源代码]

将另一个节点设置为当前节点的输入。

参数:
  • arg_idx (int) - 要设置的参数索引。

  • src_node (Node) - 输入的节点。

  • out_idx (int,可选) - 指定输入节点的哪个输出作为当前节点输入,则取第一个输出。默认值: None

异常:
  • TypeError - 如果参数 arg_idx 不是int类型。

  • ValueError - 如果参数 arg_idx 超出了当前节点的参数数量。

  • TypeError - 如果参数 src_node 不是Node类型。

  • TypeError - 如果参数 out_idx 不是int类型。

  • ValueError - 如果参数 out_idx 超出了 src_node 的输出数量。

  • ValueError - 当 out_idx 为None或者没有给 out_idx 赋值时,参数 src_node 有多个输出。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> src_node = stree.get_node("fc1")
>>> dst_node = stree.get_node("relu_3")
>>> dst_node.set_arg_by_node(0, src_node, 0)
>>> print(dst_node.get_args())
[fc1_var]
class mindspore.rewrite.NodeType[源代码]

NodeType表示Node的类型。

  • Unknown:未初始化的节点类型。

  • CallCellCallCell 节点表示在前向计算中调用Cell对象。

  • CallPrimitiveCallPrimitive 节点代表在前向计算中调用Primitive对象。

  • CallFunctionCallFunction 节点代表在前向计算中调用了一个函数。

  • CallMethodCallMethod 不能对应到Cell或者Primitive的节点。

  • PythonPython 节点代表不支持的 ast 节点或无需解析的 ast 节点。

  • InputInput 节点代表SymbolTree的输入,对应方法的参数。

  • OutputOutput 节点代表SymbolTree的输出,对应方法的 return 语句。

  • TreeTree 节点代表前向计算中调用了别的网络。

  • CellContainer: CellContainer 节点代表在前向计算中调用 mindspore.nn.SequentialCell 函数。

  • MathOpsMathOps 节点代表在前向计算中的一个运算操作,如加法运算或比较运算。

  • ControlFlowControlFlow 节点代表一个控制流语句,如 if 语句。

class mindspore.rewrite.ScopedValue(arg_type: ValueType, scope: str = '', value=None)[源代码]

ScopedValue表示具有完整范围的值。

ScopedValue用于表示:左值,如赋值语句的目标,或可调用对象,如调用语句的 func,或右值,如赋值语句的 argskwargs

参数:
  • arg_type (ValueType) - 当前值的类型。

  • scope (str,可选) - 字符串表示当前值的范围。以"self.var1"为例,这个var1的作用域是"self"。默认值: ""

  • value - 当前ScopedValue中保存的值。值的类型对应于 arg_type。默认值: None

static create_name_values(names: Union[List[str], Tuple[str]], scopes: Union[List[str], Tuple[str]] = None)[源代码]

创建ScopedValue的列表。

参数:
  • names (List[str] or Tuple[str]) - 引用变量的名称,类型为str的列表或元组。

  • scopes (List[str] or Tuple[str],可选) - 引用变量的范围,类型为str的列表或元组。默认值: None ,表示没有指定作用范围。

返回:

ScopedValue的实例列表。

异常:
  • TypeError - 如果 names 不是 listtuple 或者其中的元素不是str类型。

  • TypeError - 如果 scopes 不是 listtuple 或者其中的元素不是str类型。

  • ValueError - 如果 names 的长度不等于 scopes 的长度,而作用域不是None。

样例:

>>> from mindspore.rewrite import ScopedValue
>>> variables = ScopedValue.create_name_values(names=["z", "z_1"], scopes=["self", "self"])
>>> print(variables)
[self.z, self.z_1]
classmethod create_naming_value(name: str, scope: str = '')[源代码]

创建一个使用变量名称命名的ScopedValue。NamingValue表示对另一个变量的引用。

参数:
  • name (str) - 表示变量的字符串。

  • scope (str,可选) - 表示变量范围的字符串,默认值: "" ,表示没有指定作用范围。

返回:

ScopedValue的实例。

异常:
  • TypeError - 如果 name 不是str类型。

  • TypeError - 如果 scope 不是str类型。

样例:

>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_naming_value("conv", "self")
>>> print(variable)
self.conv
classmethod create_variable_value(value)[源代码]

创建一个保存变量的ScopedValue。ScopedValue的类型由值的类型决定。ScopedValue的范围是空的。

参数:
  • value - 要转换为ScopedValue的值。

返回:

ScopedValue的实例。

样例:

>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_variable_value(2)
>>> print(variable)
2
class mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)[源代码]

SymbolTree保存了一个网络的信息,包括网络前向计算过程的语句,和语句输入输出之间的拓扑关系。

网络里的语句以节点的形式保存在SymbolTree中,通过对SymbolTree里的节点进行处理,可以实现网络代码的删除、插入、替换等操作, 并得到修改后的网络代码及网络实例。

参数:
  • handler (SymbolTreeImpl) - SymbolTree内部实现实例。建议调用SymbolTree下的 create 方法来创建SymbolTree,而不直接 调用SymbolTree的构造函数。不需关心SymbolTreeImpl是什么,只需作为句柄看待。

after(node: Union[Node, str])[源代码]

返回一个位置信息,位置为 node 之后。该接口的返回值作为插入操作的参数使用。

参数:
  • node (Union[Node, str]) - 指定插入位置在哪个节点之后,可以是Node或者Node的名称。

返回:

Position,指定插入节点的位置。

异常:
  • TypeError - 参数不是Node或str类型。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.after(node)
before(node: Union[Node, str])[源代码]

返回一个位置信息,位置为 node 之前。该接口的返回值作为插入操作的参数使用。

参数:
  • node (Union[Node, str]) - 指定插入位置在哪个节点之前,可以是Node或者Node的名称。

返回:

Position,指定插入节点的位置。

异常:
  • TypeError - 参数不是Node或str类型。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.before(node)
classmethod create(network)[源代码]

通过传入网络实例 network ,创建一个SymbolTree对象。

该接口会解析传入的网络实例,将前向计算过程的每一条源码语句展开,并解析为节点,存储在SymbolTree中。具体流程如下:

  1. 获取网络实例对应的源码代码

  2. 对网络进行AST解析,获取网络里各个语句的AST节点(抽象语法树)

  3. 将网络前向计算过程里的复杂语句展开为多个简单语句

  4. 创建SymbolTree对象,每个SymbolTree对应一个网络实例

  5. 使用rewrite节点存储网络前向计算过程的每条语句,节点记录了语句的输入、输出等信息

  6. 将rewrite节点保存到SymbolTree里,同时更新和维护节点间的拓扑连接关系

  7. 返回网络实例对应的SymbolTree对象

如果网络的前向计算过程里调用了类型为 mindspore.nn.Cell 的用户自定义网络,rewrite会为对应语句生成类型 为 NodeType.Tree 的节点,这类节点内部保存了一个新的SymbolTree,这个SymbolTree解析并维护着自定义网络的节点信息。

如果网络的前向计算过程里调用了以下类型的语句,rewrite会将该语句所对应的内部语句进行解析,并生成对应节点:

说明

由于网络在rewrite操作期间,控制流的具体执行分支还处于未知状态,因此控制流内部的节点和外部的节点之间不会建立拓扑信息。 用户在控制流外部使用 mindspore.rewrite.Node.get_inputs()mindspore.rewrite.Node.get_users() 接口获取节点时, 无法获取控制流内部的节点。用户在控制流内部使用这些接口,也无法获取控制流外部的节点。 因此用户在进行网络修改时,需要手动处理好控制流内部和外部的节点信息。

当前rewrite模块存在以下语法限制:

  • 仅支持类型为 mindspore.nn.Cell 的网络作为rewrite模块的输入。

  • 暂不支持对单行控制流语法(如单行if-else、单行for循环等)进行解析。

  • 暂不支持对装饰器语法进行解析。

  • 暂不支持对局部类和内嵌类进行解析,即类的定义需要放在最外层。

  • 暂不支持对闭包语法进行解析,即类外函数的定义需要放在最外层。

  • 暂不支持对lambda表达式语法进行解析。

  • 暂不支持对全局变量进行解析,即需要将全局变量转换为类变量或局部变量后再使用。

  • 暂不支持对父类里的方法进行解析。

对于不支持解析的语句,rewrite会为对应语句生成类型为 NodeType.Python 的节点,以确保rewrite后的网络可以正常运行。 Python 节点不支持对语句的输入和输出进行修改,且可能出现变量名与rewrite生成的变量名的问题,此时用户需要手动对变量名进行调整。

参数:
  • network (Cell) - 待修改的网络实例。

返回:

SymbolTree,基于 network 创建的SymbolTree。

异常:
  • TypeError - 参数 network 不是Cell类型对象。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> print(type(stree))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
erase(node: Union[Node, str])[源代码]

删除SymbolTree中的一个节点。

参数:
  • node (Union[Node, str]) - 被删除的节点。可以是Node或者Node的名称。

返回:

如果 node 属于当前的SymbolTree则返回被删除节点。否则返回None。

异常:
  • TypeError - 参数不是Node或str类型。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> stree.erase(node)
get_code()[源代码]

获取SymbolTree里的网络信息所对应的源码。如果网络已经被修改过,则返回的是修改后的源码。

返回:

str,SymbolTree对应的源码字符串。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> codes = stree.get_code()
>>> print(codes)
get_network()[源代码]

获取基于SymbolTree生成的网络对象。源码会保存到文件中,文件保存在当前目录的 rewritten_network 文件夹里。

说明

  • rewrite模块对网络的修改基于对原有网络实例的AST树的修改实现,且新的网络实例会从原有网络实例里获取属性信息, 因此,新网络实例和原有网络实例存在数据关联,原有网络不应该再被使用。

  • 由于新网络和原有网络实例存在数据关联,暂不支持使用rewrite生成的源码文件手动创建网络实例。

返回:

根据SymbolTree生成的网络对象。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> new_net = stree.get_network()
get_node(node_name: str)[源代码]

获取SymbolTree里名称为 node_name 的节点。

参数:
  • node_name (str) - 节点名称。

返回:

名称为 node_name 的节点。如果SymbolTree里没有名称为 node_name 的节点,则返回 None

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node('conv1')
>>> print(node.get_name())
conv1
insert(position, node: Node)[源代码]

在SymbolTree的 position 位置插入一个节点。 position 通过 beforeafter 来获得。

参数:
  • position (Position) - 插入位置。

  • node (Node) - 要插入的节点。

返回:

Node,被插入的节点。

异常:
  • ValueError - 如果 position 指定的不是该SymbolTree内的位置。

  • TypeError - 如果参数 position 不是Position类型。

  • TypeError - 如果参数 node 不是Node类型。

样例:

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.insert(position, new_node)
nodes(all_nodes: bool = False)[源代码]

返回当前SymbolTree里节点的生成器,该接口用于遍历SymbolTree里的节点。

参数:
  • all_nodes (bool) - 获取所有节点,包括在 CallFunction 节点、 CellContainer 节点和 子SymbolTree里面的节点。默认值: False

返回:

SymbolTree中节点的生成器。

异常:
  • TypeError - 如果参数 all_nodes 不是bool类型。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> print([node.get_name() for node in stree.nodes()])
['input_x', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1',
 'unaryop_not', 'if_node', 'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', 'return_1']
print_node_tabulate(all_nodes: bool = False)[源代码]

打印SymbolTree里节点的拓扑信息,包括节点类型、节点名称、节点对应代码、节点的输入输出关系等。

信息通过print接口输出到屏幕上,包括以下信息:

  • node type (str):节点类型,具体类型参考 mindspore.rewrite.NodeType

  • name (str): 节点名称。

  • codes (str): 节点对应的SymbolTree里的代码语句。

  • arg providers (Dict[int, Tuple[str, int]]): 格式为 {[idx, (n, k)]} ,代表该节点的第 idx 个参数是节点 n 的第 k 个输出提供的。

  • target users (Dict[int, List[Tuple[str, int]]]): 格式为 {[idx, [(n, k)]]} ,代表该节点的第 idx 个输出被用作节点 n 的第 k 个参数。

参数:
  • all_nodes (bool) - 打印所有节点的信息,包括在 CallFunction 节点、 CellContainer 节点和 子SymbolTree里面的节点。默认值: False

异常:
  • TypeError - 如果参数 all_nodes 不是bool类型。

样例:

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> stree.print_node_tabulate()
replace(old_node: Node, new_nodes: List[Node])[源代码]

使用 new_nodes 列表里的节点来替代旧节点 old_node

该接口会将 new_nodes 里的节点按顺序插入到SymbolTree中,然后删除旧节点 old_node

说明

  • 仅支持一对一更换或一对多替换。如果需要多对多替换,请参考PatternEngine。

  • 调用者应维护好 new_nodes 里每个节点间的拓扑关系,以及 new_nodes 里的节点与原始树中节点的拓扑关系。

参数:
  • old_node (Node) - 被替换节点。

  • new_nodes (List[Node]) - 要替换进SymbolTree的节点列表。

返回:

替换到SymbolTree的节点列表的根节点。

异常:
  • TypeError - 如果参数 new_nodes 不是list,或者列表中的成员不是Node类型。

  • TypeError - 如果参数 old_node 不是Node类型。

样例:

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.replace(node, [new_node])
unique_name(name: str = 'output')[源代码]

基于给定 name ,返回一个SymbolTree内唯一的新的名称。当需要一个不冲突的变量名时,可以使用该接口。

参数:
  • name (str, 可选) - 名称前缀。默认值: "output"

返回:

str,一个SymbolTree内唯一的新的名称,名称格式为 name_n ,其中 n 为数字下标。如果输入 name 没有名称冲突,则没有数字下标。

异常:
  • TypeError - 如果参数 name 不是str类型。

class mindspore.rewrite.ValueType[源代码]

ValueType表示ScopedValue的类型。

  • NamingValue表示对另一个变量的引用。

  • CustomObjValue表示自定义类的实例,或类型超出ValueType的基本类型和容器类型范围的对象。