mindspore.rewrite
MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进行插入、删除和替换语句。该功能目前处于开发调试阶段,可能会更改或删除。
- class mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)[源代码]
SymbolTree通常对应于网络的前向计算过程。
- 参数:
handler (SymbolTreeImpl) - SymbolTree内部实现实例。
- after(node: Node)[源代码]
获取插入位置,位置为 node 之后。 返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 Position是什么,只需将其视为处理程序并将其用作SymbolTree的插入接口的参数。
- 参数:
node (Node) - 指定插入位置在哪个节点之后,可以是Node或者Node的名称。
- 返回:
Position,指定插入节点的位置。
- 异常:
TypeError - 参数不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> for node in stree.nodes(): ... if node.get_name() == "conv1": ... position = stree.after(node)
- before(node: Node)[源代码]
与after的区别是,该接口返回的位置为 node 之前。 返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 Position 是什么,只需将其视为处理程序并将其用作 SymbolTree 的插入接口的参数。
- 参数:
node (Node) - 指定插入位置在哪个节点之前,可以是Node或者Node的名称。
- 返回:
Position,指定插入节点的位置。
- 异常:
TypeError - 参数不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> for node in stree.nodes(): ... if node.get_name() == "conv1": ... position = stree.before(node)
- create(network)[源代码]
根据传入的 network 创建SymbolTree对象。
- 参数:
network (Cell) - 重写的网络。
- 返回:
SymbolTree,基于 network 创建的符号树。
- 异常:
TypeError - 参数 network 不是Cell类型对象。
- create_call_function(func, targets, args, kwargs)[源代码]
创建一个Node对象,并生成执行代码插入源码中。源码中以 args 和 kwargs 为参数调用 func 函数。
- 参数:
func (FunctionType) - 要被调用的函数。
targets (list[str]) - 表示输出名称。在源代码中作为节点的输出。
args (Union[MsDtypes, ParamTypes]) - 该节点的参数名称。用作源代码中代码语句的参数。默认为None表示 cell 没有参数输入。
kwargs (dict{str,Union[MsDtypes, ParamTypes]}) - 键的类型必须是str,值必须是MsDtypes或类型必须是ParamTypes。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 kwargs。默认为None,表示没有 kwargs 输入。
- 返回:
一个Node实例。
- 异常:
TypeError - 如果参数 func 不是FunctionType类型。
TypeError - 如果参数 targets 不是list类型。
TypeError - 如果参数 targets 的成员不是str类型。
TypeError - 如果参数 args 不是ParamType类型。
TypeError - 如果参数 kwarg 的 key 不是str类型或者 value 不是ParamType类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> new_node = stree.create_call_function(F.abs, ["x"], node)
- erase_node(node: Node)[源代码]
删除SymbolTree中的一个节点。被删除的节点必须不被其他节点依赖。
- 参数:
node (Node) - 被删除的节点。可以是Node或者Node的名称。
- 返回:
如果 node 属于当前的SymbolTree则返回被删除节点。否则返回None。
- 异常:
TypeError - 如果参数不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> input_node = node.get_inputs()[0] >>> output_nodes = node.get_users() >>> for n in output_nodes: ... n.set_arg(0, "x") >>> stree.erase_node(node)
- get_code()[源代码]
获取SymbolTree所对应的源代码。
- 返回:
str,SymbolTree对应的源码字符串。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> stree.get_code()
- get_handler()[源代码]
获取SymbolTree对应实现的handle。
- 返回:
SymbolTree对象。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> handler = stree.get_handler()
- get_network()[源代码]
获取SymbolTree所对应的生成的网络对象。源码会保存到文件中,默认的文件名为 network_define.py。
- 返回:
根据SymbolTree生成的网络对象。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> stree.get_network()
- get_node(node_name: str)[源代码]
获取节点名为 node_name 的节点。
- 参数:
node_name (str) - 节点的名称。
- 返回:
如果找到则返回结果,否则返回 None。
- 异常:
TypeError - 如果 node_name 不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1")
- insert(position, node: Node)[源代码]
在SymbolTree的 position 位置插入一个节点。 position 可以通过 before 或 after 来获得。
- 参数:
position (Position) - 插入位置。
node (Node) - 要插入的节点。
- 返回:
Node,被插入的节点, 当调用此方法时会对参数进行唯一性处理, node 会被修改。
- 异常:
RuntimeError - 如果 position 指定的不是该SymbolTree内的位置。
TypeError - 如果参数 position 不是Position类型。
TypeError - 如果参数 node 不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> position = stree.after(node) >>> new_node = stree.create_call_function(F.abs, ["x"], node) >>> stree.insert(position, new_node)
- nodes()[源代码]
获取当前SymbolTree的节点,用于遍历。
- 返回:
当前SymbolTree中节点的生成器。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> for node in stree.nodes(): ... node.set_attribute("channel", 3)
- print_node_tabulate()[源代码]
打印当前SymbolTree的节点信息表格。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> stree.print_node_tabulate()
- replace(old_node: Node, new_nodes: [Node])[源代码]
使用新节点列表来替代旧节点。
说明
仅支持一对一更换或一对多替换。如果需要多对多替换,请参考PatternEngine。
当一对多替换时,Rewrite会将 new_nodes 中所有节点插入到 symbol_tree 中。
调用者应指定子树内节点的参数和输出来确定子树内的拓扑关系。
调用者应指定子树输入节点的参数来确定子树与原始树中节点的拓扑关系。
ReWrite将维护子树的前置节点的参数,用于指定子树输出的拓扑关系。
将 new_nodes 替换到SymbolTree后,ReWrite将维护节点的所有输入。
- 参数:
old_node (Node) - 被替换节点。
new_nodes (list[Node]) - 要替换进SymbolTree的节点列表。
- 返回:
替换到SymbolTree的节点列表的根节点。
- 异常:
RuntimeError - 如果 old_node 仍然被其他节点依赖。
TypeError - 如果参数 new_nodes 不是list,或者列表中的成员不是Node类型。
TypeError - 如果参数 old_node 不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> new_node = stree.create_call_function(F.abs, ["x"], node) >>> stree.replace(node, [new_node])
- class mindspore.rewrite.Node(node: NodeImpl)[源代码]
节点是表达网络中源代码的一种数据结构。
在大多数情况下,Node表示前向计算的的运算,它可以是Cell的实例、Primitive的实例或可调用的方法。
- 参数:
node (NodeImpl) - NodeImpl 的handle。NodeImpl是Node的实现,不是Rewrite的接口。Rewrite建议调用Node的特定 create 方法来实例化Node的实例,例如 create_call_cell,而不直接调用Node的构造函数,不需关心NodeImpl是什么,只需作为handle看待。
- static create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, kwargs: {str: ScopedValue} = None, name: str = '', is_sub_net: bool = False)[源代码]
通过该接口可以根据 cell 对象创建一个Node实例。节点对应的源代码格式:
targets = self.name(*args, **kwargs)
。- 参数:
cell (Cell) - 该节点对应的前向计算的Cell对象。
targets (list[ScopedValue]) - 表示输出名称。在源代码中作为节点的输出。Rewrite将在插入节点时检查并确保每个目标的唯一性。
args (list[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。表示 cell 没有参数输入。Rewrite将在插入节点时检查并确保每个 arg 的唯一性。默认值:None。
kwargs (dict) - 键的类型必须是str,值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 kwargs。表示 cell 没有 kwargs 输入。Rewrite将在插入节点时检查并确保每个 kwarg 的唯一性。默认值:None。
name (str) - 表示节点的名称。用作源代码中的字段名称。当名称为无时,ReWrite将根据 target 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。默认值:None。
is_sub_net (bool) - 表示 cell 是否是一个网络。如果 is_sub_net 为真,Rewrite将尝试将 cell 解析为TreeNode,否则为CallCell节点。默认值:False。
- 返回:
Node实例。
- 异常:
TypeError - 如果参数 cell 不是Cell类型。
TypeError - 如果参数 targets 不是list类型。
TypeError - 如果参数 targets 的成员不是str或者ScopedValue类型。
TypeError - 如果参数 args 不是ScopedValue类型。
TypeError - 如果参数 kwarg 的 key 不是str类型或者 value 不是ScopedValue类型。
- get_args()[源代码]
获取当前节点的参数。
当前节点的 node_type 为 CallCell、 CallPrimitive 或 Tree 时,返回值对应于 ast.Call 的 args,表示调用 cell-op 或 primitive-op 的 forward 方法的参数。
当前节点的 node_type 为 Input 时,返回值为函数参数的默认值。
当前节点的 node_type 为 Output 时,返回值为网络的返回值。
当前节点的 node_type 为 Python 时,没有实际含义,可以忽略。
- 返回:
ScopedValue 实例的列表。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> args = node.get_args()
- get_attribute(key: str)[源代码]
获取当前节点属性 key 的值。
- 参数:
key (str) - 属性的名称。
- 返回:
属性值,可能是任意类型。
- 异常:
TypeError - 如果参数 key 不是str类型。
- get_instance()[源代码]
获取当前节点对应的 operation 实例。
如果当前节点的 node_type 是 CallCell,该节点的实例是一个Cell的对象。
如果当前节点的 node_type 是 CallPrimitive,该节点的实例是一个Primitive的对象。
如果当前节点的 node_type 是 Tree,该节点的实例是一个网络的对象。
如果当前节点的 node_type 是 Python、 Input、 Output、 CallMethod,该节点的实例为None。
- 返回:
当前节点的 operation 实例。
- get_instance_type()[源代码]
获取当前节点对应的 operation 实例类型。
如果当前节点的 node_type 是 CallCell,该节点是Cell对象。
如果当前节点的 node_type 是 CallPrimitive,该节点的是Primitive对象。
如果当前节点的 node_type 是 Tree,该节点的类型是网络。
如果当前节点的 node_type 是 Python、 Input、 Output、 CallMethod,该节点的类型为NoneType。
- 返回:
当前节点的 operation 类型。
- get_kwargs()[源代码]
获取当前节点带 key 值的参数。
当前节点的 node_type 为 CallCell、 CallPrimitive 或 Tree 时,关键字参数对应于 ast.Call 的 kwargs,表示调用 cell-op 或 Primitive-op 方法的参数。
当前节点的 node_type 为 Python、 Input 或 Output 时,不关心关键字参数。
- 返回:
key 为str, value 为ScopedValue的字典。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> kwargs = node.get_kwargs()
- get_name()[源代码]
获取当前节点的名称。当节点被插入到SymbolTree时,节点的名称在SymbolTree中应该是唯一的。
- 返回:
节点的名称,类型为str。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> name = node.get_name()
- get_node_type()[源代码]
获取当前节点节点的类型。
- 返回:
NodeType,当前节点的类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> node_type = node.get_node_type()
- get_targets()[源代码]
获取当前节点的输出名称。
当前节点的 node_type 为 CallCell、 CallPrimitive、 CallMethod 或 Tree 时, target 为字符串,表示单元操作或原始操作或函数调用的调用结果,它们对应于 ast.Assign 的 targets。
当前节点的 node_type 为 Input 时, targets 应该只有一个元素,字符串代表函数的参数。
当前节点的 node_type 为 Python 或 Output 时, target 不需要关心。
- 返回:
节点输出的ScopedValue列表。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> targets = node.get_targets()
- get_users()[源代码]
按拓扑顺序获取当前节点的输出节点。
- 返回:
输出节点的列表。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> users = node.get_users()
- set_arg(index: int, arg: Union[ScopedValue, str])[源代码]
设置当前节点的输入参数。
- 参数:
index (int) - 要设置的参数索引。
arg (Union[ScopedValue, str]) - 新参数的值。
- 异常:
TypeError - 如果参数 index 不是int类型。
TypeError - 如果参数 arg 不是str或者ScopedValue类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> node.set_arg(0, "x")
- set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)[源代码]
将另一个节点设置为当前节点的输入。
- 参数:
arg_idx (int) - 要设置的参数索引。
src_node (Node) - 输入的节点。
out_idx (int,optional) - 指定输入节点的哪个输出作为当前节点输入,则取第一个输出。默认值:None。
- 异常:
RuntimeError - 如果 src_node 不属于当前的SymbolTree。
RuntimeError - 如果当前节点和 src_node 不属于同一个SymbolTree。
TypeError - 如果参数 arg_idx 不是int类型。
ValueError - 如果参数 arg_idx 超出了当前节点的参数数量。
TypeError - 如果参数 src_node 不是Node类型。
TypeError - 如果参数 out_idx 不是int类型。
ValueError - 如果参数 out_idx 超出了 src_node 的输出数量。
ValueError - 当 out_idx 为None或者没有给 out_idx 赋值时,参数 src_node 有多个输出。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> src_node = stree.get_node("conv1") >>> dst_node = stree.get_node("conv2") >>> dst_node.set_arg_by_node(0, src_node)
- set_attribute(key: str, value)[源代码]
设置当前节点的属性。
- 参数:
key (str) - 属性的名称。
value (object) - 属性值。
- 异常:
TypeError - 如果参数 key 不是str类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from lenet import Lenet >>> net = Lenet() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> node.set_attribute("channel", 3)
- class mindspore.rewrite.NodeType[源代码]
NodeType表示Node的类型。
Unknown:未初始化的节点类型。
CallCell: CallCell 节点表示在前向计算中调用Cell对象。
CallPrimitive: CallPrimitive 节点代表在前向计算中调用Primitive对象。
CallMethod: CallMethod 不能对应到Cell或者Primitive的节点。
Python: Python 节点包含不支持的 ast 的节点类型或不必要的解析 ast 节点。
Input:输入节点代表SymbolTree的输入,对应方法的参数。
Output: 输出节点代表SymbolTree的输出,对应方法的 return 语句。
Tree: 树节点代表转发方法中的子网调用。
- class mindspore.rewrite.ScopedValue(arg_type: ValueType, scope: str = '', value=None)[源代码]
ScopedValue表示具有完整范围的值。
ScopedValue用于表示:左值,如赋值语句的目标,或可调用对象,如调用语句的 func,或右值,如赋值语句的 args 和 kwargs。
- 参数:
arg_type (ValueType) - 当前值的类型。
scope (str) - 字符串表示当前值的范围。以”self.var1”为例,这个var1的作用域是”self”。默认值: “”。
value - 当前ScopedValue中保存的值。值的类型对应于 arg_type。默认值:None。
- static create_name_values(names: Union[list, tuple], scopes: Union[list, tuple] = None)[源代码]
创建ScopedValue的列表。
- 参数:
names (list[str] or tuple[str]) – 引用变量的名称,类型为str的列表或元组。
scopes (list[str] or tuple[str]) – 引用变量的范围,类型为str的列表或元组。表示没有指定作用范围。默认值:None。
- 返回:
ScopedValue的实例列表。
- 异常:
TypeError - 如果 names 不是 list 或 tuple 或者其中的元素不是str类型。
TypeError - 如果 scopes 不是 list 或 tuple 或者其中的元素不是str类型。
RuntimeError - 如果 names 的长度不等于 scopes 的长度,而作用域不是None。
样例:
>>> from mindspore.rewrite import ScopedValue >>> variables = ScopedValue.create_name_values(["z", "z_1"]), name="subnet")
- create_naming_value(name: str, scope: str = '')[源代码]
创建一个使用变量名称命名的ScopedValue。NamingValue表示对另一个变量的引用。
- 参数:
name (str) – 表示变量的字符串。
scope (str) – 表示变量范围的字符串,表示没有指定作用范围。默认值:空字符串。
- 返回:
ScopedValue的实例。
- 异常:
TypeError - 如果 name 不是str类型。
TypeError - 如果 scope 不是str类型。
样例:
>>> from mindspore.rewrite import ScopedValue >>> variable = ScopedValue.create_naming_value("conv", "self")
- class mindspore.rewrite.ValueType[源代码]
ValueType表示ScopedValue的类型。
NamingValue表示对另一个变量的引用。
CustomObjValue表示自定义类的实例,或类型超出ValueType的基本类型和容器类型范围的对象。
- class mindspore.rewrite.PatternEngine(pattern: Union[PatternNode, List], replacement: Replacement = None)[源代码]
PatternEngine通过PattenNode修改SymbolTree。
- 参数:
pattern (Union[PatternNode, List]) - PatternNode的实例或用于构造 Pattent 的Cell类型列表。
replacement (callable) - 生成新节点的接口实现。
- apply(stree: SymbolTree)[源代码]
在 stree 上面执行当前的匹配模式。
说明
当前还不支持子树节点。
- 参数:
stree (SymbolTree) - 要修改的SymbolTree。
- 返回:
bool,表示是否对 stree 进行了修改。
- 异常:
TypeError - 如果参数 stree 不是SymbolTree类型。
- class mindspore.rewrite.PatternNode(pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None)[源代码]
PatternNode在定义 pattern 时被定义为一个节点。
- 参数:
pattern_node_name (str) - 节点名称。
match_type (Type) - 当前节点的匹配类型。默认值:None。
inputs (list[PatternNode]) - 当前节点的输入节点。默认值:None。
- add_input(node)[源代码]
为当前节点添加输入。
- 参数:
node (PatternNode) - 新增的输入节点。
- 异常:
TypeError - 如果参数 node 不是PattenNode类型。
- static create_pattern_from_list(type_list: [])[源代码]
使用类型的列表来创建Pattern。
- 参数:
type_list (list[type]) - 类型列表。
- 返回:
根据列表生成的模式的根节点。
- 异常:
TypeError - 如果 type_list 不是list类型。
- static create_pattern_from_node(node: Node)[源代码]
根据节点及其输入创建Pattern。
- 参数:
node (Node) - 要修改的节点。
- 返回:
根据 node 创建的PattentNode。
- 异常:
TypeError - 如果 node 不是Node类型。
- static from_node(node: Node)[源代码]
根据 node 创建PatternNode。
- 参数:
node (Node) - 要修改的节点。
- 返回:
根据 node 创建的PattentNode。
- 异常:
TypeError - 如果 node 不是Node类型。
- match(node: Node)[源代码]
检查当前PatternNode是否可以与node匹配。
- 参数:
node (Node) - 要匹配的节点。
- 异常:
TypeError - 如果参数 node 不是PattenNode类型。
- class mindspore.rewrite.Replacement[源代码]
替换的接口定义。
样例:
>>> from mindspore.rewrite import Replacement, Node >>> from mindspore.nn import nn >>> class BnReplacement(Replacement): ... def build(self, pattern, is_chain_pattern: bool, matched): ... bn_node: Node = matched.get(pattern.name()) ... conv = nn.Conv2d(16, 16, 3) ... conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs()) ... return [conv_node]
- abstract build(pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict)[源代码]
用于从匹配结果创建替换节点的接口定义。
说明
返回值将作为SymbolTree的替换函数的参数,返回值应遵循替换函数参数的 new_nodes 的约束。请参阅SymbolTree的 replace 的文档字符串中的详细信息。
- 参数:
pattern (PatternNode) - 当前模式的根节点。
is_chain_pattern (bool) - 标记,标记模式是链模式或树模式。
matched (OrderedDict) - 匹配结果,从名称映射到节点的字典。
- 返回:
作为替换节点的节点实例列表。