mindspore.rewrite
MindSpore的ReWrite功能用于修改网络前向的计算过程,对网络进行插入、删除和替换语句。该功能目前处于开发调试阶段,可能会更改或删除。
ReWrite完整示例请参考 rewrite_example.py 。 该样例代码的主要功能包括:怎么通过网络创建SymbolTree,并且对SymbolTree中的节点进行插入删除替换等操作,其中还包含了对子网络的修改和通过模式匹配进行节点替换。
from typing import OrderedDict
import numpy as np
import mindspore
from mindspore import Tensor, export
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, Replacement, PatternEngine, PatternNode, \
TreeNodeHelper
import mindspore.nn as nn
import mindspore.ops as ops
class SubNet(nn.Cell):
"""子网络定义"""
def __init__(self):
super().__init__()
self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones")
self.mean = ops.ReduceMean(keep_dims=False)
self.conv1 = nn.Conv2d(1, 1, 1, stride=1)
def construct(self, x):
x = self.conv1(x)
x = self.dense(x)
return x
class Net(nn.Cell):
"""网络定义"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 1, 1, pad_mode='valid')
self.conv2 = nn.Conv2d(1, 1, 1, pad_mode='valid')
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
self.simnet = SubNet()
def construct(self, x):
"""网络的前向计算过程"""
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.simnet(x)
x = self.flatten(x)
return x
def create_stree(network):
"""创建SymbolTree"""
stree = SymbolTree.create(network)
stree.dump()
return stree
def insert_node(stree):
"""在网络中插入节点"""
for node in stree.nodes():
if node.get_name() == "conv2": # 在名称为'conv2'的节点前面插入新的节点
position = stree.before(node)
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=node.get_args())
stree.insert(position, new_conv_node)
break
# 使用新节点更新已有节点的输入
if new_conv_node is not None:
for node in stree.nodes():
if node.get_name() == "relu_1":
node.set_arg_by_node(0, new_conv_node)
break
def insert_node_to_subtree(stree):
"""在子网络中插入节点"""
def _insert_conv(stree: SymbolTree):
for node in stree.nodes():
if node.get_instance_type() == nn.Conv2d:
position = stree.after(node)
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
args=[ScopedValue.create_naming_value('x_1')])
stree.insert(position, new_conv_node)
break
# 在名称为'simnet'的子网络中插入新节点
for node in stree.nodes():
if node.get_node_type() == NodeType.Tree and node.get_name() == "simnet":
_insert_conv(TreeNodeHelper.get_sub_tree(node))
break
def delete_node(stree):
"""删除类型为nn.Flatten的节点"""
for node in stree.nodes():
if node.get_instance_type() == nn.Flatten:
for n in node.get_users():
n.set_arg(0, "x_7")
stree.erase_node(node)
break
def replace_node(stree):
"""替换网络中的节点"""
new_conv = nn.Conv2d(1, 1, 1)
new_conv_node = Node.create_call_cell(new_conv, [ScopedValue.create_naming_value("replace_conv")],
args=[ScopedValue.create_naming_value('x')])
for node in stree.nodes():
if node.get_name() == "conv1":
new_conv_node = stree.replace(node, [new_conv_node])
def pattern_replace(stree):
"""通过模式匹配的方式替换节点"""
class ConvReplacement(Replacement):
"""创建新节点类的实现"""
def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
bn_node: Node = matched.get(pattern.name())
conv = nn.Conv2d(1, 1, 1)
conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs(),
name="pattern_conv")
return [conv_node]
class BnReplace(PatternEngine):
"""替换网络中nn.MaxPool2d类型的节点"""
def __init__(self):
super().__init__([nn.MaxPool2d], ConvReplacement())
bn_replace = BnReplace()
bn_replace.apply(stree)
def get_net(stree):
"""获取修改后的网络"""
return stree.get_network()
def get_code(stree):
"""获取修改后的网络代码"""
return stree.get_code()
def test_rewrite():
"""ReWrite测试函数"""
net = Net()
stree = create_stree(net)
print(f"origin code: {stree.get_code()}")
insert_node(stree)
print(f"after inser node code: {stree.get_code()}")
insert_node_to_subtree(stree)
print(f"after inser node to subtree code: {stree.get_code()}")
delete_node(stree)
print(f"after remove node code: {stree.get_code()}")
replace_node(stree)
print(f"after replace node code: {stree.get_code()}")
pattern_replace(stree)
print(f"after pattern replace node code: {stree.get_code()}")
inputs = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32) # pylint: disable=not-callable
new_net = get_net(stree)
source_code = get_code(stree)
print(source_code)
out = new_net(inputs)
print("out: ", out)
export(new_net, inputs, file_name="new_net", file_format="MINDIR")
if __name__ == "__main__":
test_rewrite()
- class mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)[源代码]
SymbolTree通常对应于网络的前向计算过程。
- 参数:
handler (SymbolTreeImpl) - SymbolTree内部实现实例。
- after(node: Node)[源代码]
获取插入位置,位置为 node 之后。 返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 Position是什么,只需将其视为处理程序并将其用作SymbolTree的插入接口的参数。
- 参数:
node (Node) - 指定插入位置在哪个节点之后,可以是Node或者Node的名称。
- 返回:
Position,指定插入节点的位置。
- 异常:
TypeError - 参数不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> for node in stree.nodes(): ... if node.get_name() == "conv1": ... position = stree.after(node)
- before(node: Node)[源代码]
与after的区别是,该接口返回的位置为 node 之前。 返回值用于指示插入节点的位置,它指示在源代码中的位置,而不是在拓扑顺序中的位置。不需要关心 Position 是什么,只需将其视为处理程序并将其用作 SymbolTree 的插入接口的参数。
- 参数:
node (Node) - 指定插入位置在哪个节点之前,可以是Node或者Node的名称。
- 返回:
Position,指定插入节点的位置。
- 异常:
TypeError - 参数不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> for node in stree.nodes(): ... if node.get_name() == "conv1": ... position = stree.before(node)
- create(network)[源代码]
根据传入的 network 创建SymbolTree对象。
- 参数:
network (Cell) - 重写的网络。
- 返回:
SymbolTree,基于 network 创建的符号树。
- 异常:
TypeError - 参数 network 不是Cell类型对象。
- erase_node(node: Node)[源代码]
删除SymbolTree中的一个节点。被删除的节点必须不被其他节点依赖。
- 参数:
node (Node) - 被删除的节点。可以是Node或者Node的名称。
- 返回:
如果 node 属于当前的SymbolTree则返回被删除节点。否则返回None。
- 异常:
TypeError - 如果参数不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> input_node = node.get_inputs()[0] >>> output_nodes = node.get_users() >>> for n in output_nodes: ... n.set_arg(0, "x") >>> stree.erase_node(node)
- get_code()[源代码]
获取SymbolTree所对应的源代码。
- 返回:
str,SymbolTree对应的源码字符串。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> codes = stree.get_code() >>> print(codes)
- get_network()[源代码]
获取SymbolTree所对应的生成的网络对象。源码会保存到文件中,默认的文件名为 network_define.py。
- 返回:
根据SymbolTree生成的网络对象。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> new_net = stree.get_network()
- insert(position, node: Node)[源代码]
在SymbolTree的 position 位置插入一个节点。 position 可以通过 before 或 after 来获得。
- 参数:
position (Position) - 插入位置。
node (Node) - 要插入的节点。
- 返回:
Node,被插入的节点, 当调用此方法时会对参数进行唯一性处理, node 会被修改。
- 异常:
RuntimeError - 如果 position 指定的不是该SymbolTree内的位置。
TypeError - 如果参数 position 不是Position类型。
TypeError - 如果参数 node 不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from mindspore.ops import abs >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> position = stree.after(node) >>> new_node = stree.create_call_function(abs, ["x"], node) >>> stree.insert(position, new_node)
- nodes()[源代码]
获取当前SymbolTree的节点,用于遍历。
- 返回:
当前SymbolTree中节点的生成器。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> for node in stree.nodes(): ... print(node.get_name())
- replace(old_node: Node, new_nodes: [Node])[源代码]
使用新节点列表来替代旧节点。
说明
仅支持一对一更换或一对多替换。如果需要多对多替换,请参考PatternEngine。
当一对多替换时,Rewrite会将 new_nodes 中所有节点插入到 symbol_tree 中。
调用者应指定子树内节点的参数和输出来确定子树内的拓扑关系。
调用者应指定子树输入节点的参数来确定子树与原始树中节点的拓扑关系。
ReWrite将维护子树的前置节点的参数,用于指定子树输出的拓扑关系。
将 new_nodes 替换到SymbolTree后,ReWrite将维护节点的所有输入。
- 参数:
old_node (Node) - 被替换节点。
new_nodes (list[Node]) - 要替换进SymbolTree的节点列表。
- 返回:
替换到SymbolTree的节点列表的根节点。
- 异常:
RuntimeError - 如果 old_node 仍然被其他节点依赖。
TypeError - 如果参数 new_nodes 不是list,或者列表中的成员不是Node类型。
TypeError - 如果参数 old_node 不是Node类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> from mindspore.ops import abs >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> input_node = stree.get_node("input_x") >>> node = stree.get_node("conv1") >>> new_node = stree.create_call_function(abs, ["x"], input_node) >>> stree.replace(node, [new_node])
- class mindspore.rewrite.Node(node: NodeImpl)[源代码]
节点是表达网络中源代码的一种数据结构。
在大多数情况下,Node表示前向计算的的运算,它可以是Cell的实例、Primitive的实例或可调用的方法。
- 参数:
node (NodeImpl) - NodeImpl 的handle。NodeImpl是Node的实现,不是Rewrite的接口。Rewrite建议调用Node的特定 create 方法来实例化Node的实例,例如 create_call_cell,而不直接调用Node的构造函数,不需关心NodeImpl是什么,只需作为handle看待。
- static create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, kwargs: {str: ScopedValue} = None, name: str = '', is_sub_net: bool = False)[源代码]
通过该接口可以根据 cell 对象创建一个Node实例。节点对应的源代码格式:
targets = self.name(*args, **kwargs)
。- 参数:
cell (Cell) - 该节点对应的前向计算的Cell对象。
targets (list[ScopedValue]) - 表示输出名称。在源代码中作为节点的输出。Rewrite将在插入节点时检查并确保每个目标的唯一性。
args (list[ScopedValue]) - 该节点的参数名称。用作源代码中代码语句的参数。Rewrite将在插入节点时检查并确保每个 arg 的唯一性。默认值:
None
,表示 cell 没有参数输入。kwargs (dict) - 键的类型必须是str,值的类型必须是ScopedValue。用来说明带有关键字的形参的输入参数名称。输入名称在源代码中作为语句表达式中的 kwargs。Rewrite将在插入节点时检查并确保每个 kwarg 的唯一性。默认值:
None
,表示 cell 没有 kwargs 输入。name (str) - 表示节点的名称。用作源代码中的字段名称。当未提供名称时,ReWrite将根据 target 生成一个默认名称。Rewrite将在插入节点时检查并确保名称的唯一性。默认值:
""
。is_sub_net (bool) - 表示 cell 是否是一个网络。如果 is_sub_net 为
True
,Rewrite将尝试将 cell 解析为TreeNode,否则为CallCell节点。默认值:False
。
- 返回:
Node实例。
- 异常:
TypeError - 如果参数 cell 不是Cell类型。
TypeError - 如果参数 targets 不是list类型。
TypeError - 如果参数 targets 的成员不是str或者ScopedValue类型。
TypeError - 如果参数 args 不是ScopedValue类型。
TypeError - 如果参数 kwarg 的 key 不是str类型或者 value 不是ScopedValue类型。
- get_instance_type()[源代码]
获取当前节点对应的 operation 实例类型。
如果当前节点的 node_type 是 CallCell,该节点是Cell对象。
如果当前节点的 node_type 是 CallPrimitive,该节点的是Primitive对象。
如果当前节点的 node_type 是 Tree,该节点的类型是网络。
如果当前节点的 node_type 是 Python、 Input、 Output、 CallMethod,该节点的类型为NoneType。
- 返回:
当前节点的 operation 类型。
- get_name()[源代码]
获取当前节点的名称。当节点被插入到SymbolTree时,节点的名称在SymbolTree中应该是唯一的。
- 返回:
节点的名称,类型为str。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> name = node.get_name()
- get_node_type()[源代码]
获取当前节点节点的类型。
- 返回:
NodeType,当前节点的类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> node_type = node.get_node_type()
- get_users()[源代码]
按拓扑顺序获取当前节点的输出节点。
- 返回:
输出节点的列表。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> users = node.get_users()
- set_arg(index: int, arg: Union[ScopedValue, str])[源代码]
设置当前节点的输入参数。
- 参数:
index (int) - 要设置的参数索引。
arg (Union[ScopedValue, str]) - 新参数的值。
- 异常:
TypeError - 如果参数 index 不是int类型。
TypeError - 如果参数 arg 不是str或者ScopedValue类型。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> node = stree.get_node("conv1") >>> node.set_arg(0, "x")
- set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)[源代码]
将另一个节点设置为当前节点的输入。
- 参数:
arg_idx (int) - 要设置的参数索引。
src_node (Node) - 输入的节点。
out_idx (int,optional) - 指定输入节点的哪个输出作为当前节点输入,则取第一个输出。默认值:None。
- 异常:
RuntimeError - 如果 src_node 不属于当前的SymbolTree。
TypeError - 如果参数 arg_idx 不是int类型。
ValueError - 如果参数 arg_idx 超出了当前节点的参数数量。
TypeError - 如果参数 src_node 不是Node类型。
TypeError - 如果参数 out_idx 不是int类型。
ValueError - 如果参数 out_idx 超出了 src_node 的输出数量。
ValueError - 当 out_idx 为None或者没有给 out_idx 赋值时,参数 src_node 有多个输出。
样例:
>>> from mindspore.rewrite import SymbolTree >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> stree = SymbolTree.create(net) >>> src_node = stree.get_node("conv1") >>> dst_node = stree.get_node("conv2") >>> dst_node.set_arg_by_node(0, src_node)
- class mindspore.rewrite.NodeType[源代码]
NodeType表示Node的类型。
Unknown:未初始化的节点类型。
CallCell: CallCell 节点表示在前向计算中调用Cell对象。
CallPrimitive: CallPrimitive 节点代表在前向计算中调用Primitive对象。
CallMethod: CallMethod 不能对应到Cell或者Primitive的节点。
Python: Python 节点包含不支持的 ast 的节点类型或不必要的解析 ast 节点。
Input:输入节点代表SymbolTree的输入,对应方法的参数。
Output: 输出节点代表SymbolTree的输出,对应方法的 return 语句。
Tree: 树节点代表转发方法中的子网调用。
- class mindspore.rewrite.ScopedValue(arg_type: ValueType, scope: str = '', value=None)[源代码]
ScopedValue表示具有完整范围的值。
ScopedValue用于表示:左值,如赋值语句的目标,或可调用对象,如调用语句的 func,或右值,如赋值语句的 args 和 kwargs。
- 参数:
arg_type (ValueType) - 当前值的类型。
scope (str) - 字符串表示当前值的范围。以”self.var1”为例,这个var1的作用域是”self”。默认值:
""
。value - 当前ScopedValue中保存的值。值的类型对应于 arg_type。默认值:
None
。
- static create_name_values(names: Union[list, tuple], scopes: Union[list, tuple] = None)[源代码]
创建ScopedValue的列表。
- 参数:
names (list[str] or tuple[str]) - 引用变量的名称,类型为str的列表或元组。
scopes (list[str] or tuple[str]) - 引用变量的范围,类型为str的列表或元组。默认值:
None
,表示没有指定作用范围。
- 返回:
ScopedValue的实例列表。
- 异常:
TypeError - 如果 names 不是 list 或 tuple 或者其中的元素不是str类型。
TypeError - 如果 scopes 不是 list 或 tuple 或者其中的元素不是str类型。
RuntimeError - 如果 names 的长度不等于 scopes 的长度,而作用域不是None。
样例:
>>> from mindspore.rewrite import ScopedValue >>> variables = ScopedValue.create_name_values(["z", "z_1"], name="subnet")
- create_naming_value(name: str, scope: str = '')[源代码]
创建一个使用变量名称命名的ScopedValue。NamingValue表示对另一个变量的引用。
- 参数:
name (str) – 表示变量的字符串。
scope (str) – 表示变量范围的字符串,默认值:
""
,表示没有指定作用范围。
- 返回:
ScopedValue的实例。
- 异常:
TypeError - 如果 name 不是str类型。
TypeError - 如果 scope 不是str类型。
样例:
>>> from mindspore.rewrite import ScopedValue >>> variable = ScopedValue.create_naming_value("conv", "self")
- class mindspore.rewrite.ValueType[源代码]
ValueType表示ScopedValue的类型。
NamingValue表示对另一个变量的引用。
CustomObjValue表示自定义类的实例,或类型超出ValueType的基本类型和容器类型范围的对象。
- class mindspore.rewrite.PatternEngine(pattern: Union[PatternNode, List], replacement: Replacement = None)[源代码]
PatternEngine通过PattenNode修改SymbolTree。
- 参数:
pattern (Union[PatternNode, List]) - PatternNode的实例或用于构造 Pattent 的Cell类型列表。
replacement (callable) - 生成新节点的接口实现。默认值:
None
。
- apply(stree: SymbolTree)[源代码]
在 stree 上面执行当前的匹配模式。
说明
当前还不支持子树节点。
- 参数:
stree (SymbolTree) - 要修改的SymbolTree。
- 返回:
bool,表示是否对 stree 进行了修改。
- 异常:
TypeError - 如果参数 stree 不是SymbolTree类型。
- class mindspore.rewrite.PatternNode(pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None)[源代码]
PatternNode在定义 pattern 时被定义为一个节点。
- 参数:
pattern_node_name (str) - 节点名称。
match_type (Type) - 当前节点的匹配类型。默认值:
Type[None]
。inputs (list[PatternNode]) - 当前节点的输入节点。默认值:
None
。
- add_input(node)[源代码]
为当前节点添加输入。
- 参数:
node (PatternNode) - 新增的输入节点。
- 异常:
TypeError - 如果参数 node 不是PattenNode类型。
- static create_pattern_from_list(type_list: [])[源代码]
使用类型的列表来创建Pattern。
- 参数:
type_list (list[type]) - 类型列表。
- 返回:
根据列表生成的模式的根节点。
- 异常:
TypeError - 如果 type_list 不是list类型。
- static create_pattern_from_node(node: Node)[源代码]
根据节点及其输入创建Pattern。
- 参数:
node (Node) - 要修改的节点。
- 返回:
根据 node 创建的PattentNode。
- 异常:
TypeError - 如果 node 不是Node类型。
- static from_node(node: Node)[源代码]
根据 node 创建PatternNode。
- 参数:
node (Node) - 要修改的节点。
- 返回:
根据 node 创建的PattentNode。
- 异常:
TypeError - 如果 node 不是Node类型。
- match(node: Node)[源代码]
检查当前PatternNode是否可以与node匹配。
- 参数:
node (Node) - 要匹配的节点。
- 异常:
TypeError - 如果参数 node 不是PattenNode类型。
- class mindspore.rewrite.Replacement[源代码]
替换的接口定义。
样例:
>>> from mindspore.rewrite import Replacement, Node >>> from mindspore.nn import nn >>> class BnReplacement(Replacement): ... def build(self, pattern, is_chain_pattern: bool, matched): ... bn_node: Node = matched.get(pattern.name()) ... conv = nn.Conv2d(16, 16, 3) ... conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs()) ... return [conv_node]
- abstract build(pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict)[源代码]
用于从匹配结果创建替换节点的接口定义。
说明
返回值将作为SymbolTree的替换函数的参数,返回值应遵循替换函数参数的 new_nodes 的约束。请参阅SymbolTree的 replace 的文档字符串中的详细信息。
- 参数:
pattern (PatternNode) - 当前模式的根节点。
is_chain_pattern (bool) - 标记,标记模式是链模式或树模式。
matched (OrderedDict) - 匹配结果,从名称映射到节点的字典。
- 返回:
作为替换节点的节点实例列表。
- class mindspore.rewrite.TreeNodeHelper[源代码]
TreeNodeHelper用于在从Tree类型节点获取 symbol_tree 时打破循环引用。
TreeNodeHelper提供了静态方法 get_sub_tree 用于从Tree类型节点获取 symbol_tree。
- mindspore.rewrite.sparsify(f, arg_types, sparse_rules=None)[源代码]
模型自动稀疏化接口,将稠密模型转换为稀疏模型。通过 arg_types 指定的参数类型,将稀疏参数在模型中传导,并调用相应的稀疏函数。
- 参数:
f (Cell) - 被稀疏化的网络。
arg_types (Tuple[ArgType] | Dict[int, ArgType]) - f 接受的参数类型(稀疏CSR/COO、非稀疏等)。如果是tuple,长度需要和 f 的参数数量相等;如果是dict,每个键值对应一个参数的索引,字典里没有表示的参数默认为非稀疏。
sparse_rules (Dict[str, SparseFunc], 可选) - 自定义稀疏规则。默认值:
None
。
- class mindspore.rewrite.SparseFunc(fn: Union[str, Callable], inputs: Optional[Any] = None, outputs: Optional[Any] = None)[源代码]
在稀疏化中表示一个稀疏函数。
说明
如果 fn 是一个包含类型注解的函数,且同时提供了 inputs,则类型注解中的输入类型将被忽略。outputs 同理。
- 参数:
fn (Union[str, Callable]) - 稀疏函数,如果是字符串,表示一个mindspore.ops.function接口;或者是任意函数对象。
inputs (Any, 可选) - 函数的输入类型。如果是
None
,则使用函数本身的类型注解。默认值:None
。outputs (Any, 可选) - 函数的输出类型。如果是
None
,则使用函数本身的类型注解。默认值:None
。