mindspore.rewrite
The ReWrite module in MindSpore provides users with the ability to modify the network's forward computation process based on custom rules, such as inserting, deleting, and replacing statements.
For a quick start on using ReWrite, refer to Modifying Network With ReWrite.
MindSpore Rewrite package. This is an experimental Python package that is subject to change or deletion.
- class mindspore.rewrite.Node(node: NodeImpl)[source]
A node is a data structure that expresses source code statements in a network.
Each node usually corresponds to a statement in the expanded forward evaluation process.
Nodes can express a Cell call statement, a Primitive call statement, an arithmetic operation statement, a return statement, etc. of the forward calculation process.
- Parameters
node (NodeImpl) – A handle to a NodeImpl. It is recommended to call the specific methods in Node to create a Node, such as 'create_call_cell', rather than calling the Node constructor directly. You do not need to know what NodeImpl is; just treat it as a handle.
- static create_call_cell(cell: Cell, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None, name: str = '', is_sub_net: bool = False)[source]
Create a node. Currently, a node can only be created from a Cell.
A node corresponds to source code like:
targets = self.name(*args, **kwargs)
- Parameters
cell (Cell) – Cell-operator of this forward-layer.
targets (List[Union[ScopedValue, str]]) – Indicate output names. Used as targets of an assign statement in source code.
args (List[ScopedValue]) – Indicate input names. Used as args of a call expression of an assign statement in source code. Default: None, which indicates the cell has no args inputs.
kwargs (Dict[str, ScopedValue]) – Type of key must be str and type of value must be ScopedValue. Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source code. Default: None, which indicates the cell has no kwargs inputs.
name (str) – Indicate the name of node. Used as field name in source code. Rewrite will generate a name from cell when name is not provided. Rewrite will check and ensure the uniqueness of name when the node is inserted. Default: "".
is_sub_net (bool) – Indicate whether cell is a network. If is_sub_net is True, Rewrite will try to parse the cell to a TreeNode, otherwise the cell is parsed to a CallCell node. Default: False.
- Returns
An instance of Node.
- Raises
Examples
>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.insert(position, new_node)
>>> print(type(new_node))
<class 'mindspore.rewrite.api.node.Node'>
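The statement form `targets = self.name(*args, **kwargs)` described for create_call_cell can be illustrated with a minimal sketch that renders such a statement from plain strings. `render_call_statement` is a hypothetical helper introduced here for illustration. It is not part of mindspore.rewrite, and plain strings stand in for ScopedValue objects.

```python
from typing import Dict, List, Optional


def render_call_statement(name: str, targets: List[str],
                          args: Optional[List[str]] = None,
                          kwargs: Optional[Dict[str, str]] = None) -> str:
    """Render `targets = self.name(*args, **kwargs)` as source text.

    Hypothetical helper for illustration only; plain strings stand in
    for the ScopedValue objects used by mindspore.rewrite.
    """
    # Positional arguments come first, then keyword arguments as key=value.
    call_args = list(args or [])
    call_args += [f"{key}={value}" for key, value in (kwargs or {}).items()]
    return f"{', '.join(targets)} = self.{name}({', '.join(call_args)})"


statement = render_call_statement("new_relu", targets=["x"], args=["x"])
print(statement)
```

Under these assumptions, the node created in the example above would correspond to a statement of the form `x = self.new_relu(x)`.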
- static create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None)[source]
Create a node that corresponds to a function call.
Note
The code inside the function will not be parsed.
- Parameters
function (FunctionType) – The function to be called.
targets (List[Union[ScopedValue, str]]) – Indicate output names. Used as targets of an assign statement in source code.
args (List[ScopedValue]) – Indicate input names. Used as args of a call expression of an assign statement in source code. Default: None, which indicates the function has no args inputs.
kwargs (Dict[str, ScopedValue]) – Type of key must be str and type of value must be ScopedValue. Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source code. Default: None, which indicates the function has no kwargs inputs.
- Returns
An instance of Node.
- Raises
Examples
>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> from mindspore import ops
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
...                                      args=[ScopedValue.create_naming_value('x')])
>>> stree.insert(position, new_node)
>>> print(new_node.get_node_type())
NodeType.CallFunction
- get_args()[source]
Get arguments of current node.
- Returns
A list of arguments of type ScopedValue.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> print(node.get_args())
[x]
- get_inputs()[source]
Gets a list of nodes whose output values are used as input values for the current node.
- Returns
A list of nodes.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv2")
>>> inputs = node.get_inputs()
>>> print([input.get_name() for input in inputs])
['max_pool2d']
- get_instance_type()[source]
Gets the instance type called in the code corresponding to the current node.
When node_type of current node is CallCell, the code for that node calls an instance of type Cell.
When node_type of current node is CallPrimitive, the code for that node calls an instance of type Primitive.
When node_type of current node is Tree, the code for that node calls an instance of network type.
When node_type of current node is Python, Input, Output or CallMethod, the instance type is NoneType.
- Returns
The type of instance called in the statement corresponding to the current node.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> instance_type = node.get_instance_type()
>>> print(instance_type)
<class 'mindspore.nn.layer.conv.Conv2d'>
- get_kwargs()[source]
Get keyword arguments of current node.
- Returns
A dict of keyword arguments, where key is of type str, and value is of type ScopedValue.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> from mindspore import nn
>>>
>>> class ReLUNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, input):
...         output = self.relu(x=input)
...         return output
>>>
>>> net = ReLUNet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("relu")
>>> print(node.get_kwargs())
{'x': input}
- get_name()[source]
Get the name of current node.
When node has been inserted into SymbolTree, the name of node should be unique in SymbolTree.
- Returns
A string as name of node.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> name = node.get_name()
>>> print(name)
conv1
- get_node_type()[source]
Get the node_type of current node. See mindspore.rewrite.NodeType for details on node types.
- Returns
A NodeType as node_type of node.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node_type = node.get_node_type()
>>> print(node_type)
NodeType.CallCell
- get_sub_tree()[source]
Get the sub symbol tree stored in a node of type NodeType.Tree. See mindspore.rewrite.NodeType for details on node types.
- Returns
SymbolTree stored in Tree node.
- Raises
TypeError – If current node is not of type NodeType.Tree.
AttributeError – If no symbol tree is stored in Tree node.
Examples
>>> import mindspore.nn as nn
>>> from mindspore.rewrite import SymbolTree
>>>
>>> class SubNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.subnet = SubNet()
...
...     def construct(self, x):
...         x = self.subnet(x)
...         return x
>>>
>>> net = Net()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("subnet")
>>> print(type(node.get_sub_tree()))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
- get_symbol_tree()[source]
Get the symbol tree which current node belongs to.
- Returns
SymbolTree, None if current node does not belong to any SymbolTree.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> print(type(node.get_symbol_tree()))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
- get_targets()[source]
Gets a list of output values for the current node.
- Returns
A list of outputs of type ScopedValue.
- get_users()[source]
Get a list of nodes that use the output of the current node as input.
- Returns
A list of nodes.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> users = node.get_users()
>>> print([user.get_name() for user in users])
['relu']
- set_arg(index: int, arg: Union[ScopedValue, str])[source]
Set argument of current node.
- Parameters
index (int) – Indicate which input is being modified.
arg (Union[ScopedValue, str]) – New argument to be set.
- Raises
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("relu_3")
>>> node.set_arg(0, "fc1")
>>> print(node.get_args())
[fc1]
- set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)[source]
Set argument of current node by another Node.
- Parameters
- Raises
TypeError – If arg_idx is not an int.
ValueError – If arg_idx is out of range.
TypeError – If src_node is not a Node instance.
TypeError – If out_idx is not an int.
ValueError – If out_idx is out of range.
ValueError – If src_node has multiple outputs while out_idx is None or not provided.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> src_node = stree.get_node("fc1")
>>> dst_node = stree.get_node("relu_3")
>>> dst_node.set_arg_by_node(0, src_node, 0)
>>> print(dst_node.get_args())
[fc1_var]
- class mindspore.rewrite.NodeType[source]
NodeType represents type of Node.
Unknown: Uninitialized NodeType.
CallCell: CallCell node represents invoking cell-op in forward method.
CallPrimitive: CallPrimitive node represents invoking primitive-op in forward method.
CallFunction: CallFunction node represents invoking a function in forward method.
CallMethod: CallMethod node represents invoking a method in forward method which cannot be mapped to cell-op or primitive-op in MindSpore.
Python: Python node holds unsupported-ast-node or unnecessary-to-parse-ast-node.
Input: Input node represents input of SymbolTree corresponding to arguments of forward method.
Output: Output node represents output of SymbolTree corresponding to return statement of forward method.
Tree: Tree node represents sub-network invoking in forward method.
CellContainer: CellContainer node represents invoking method mindspore.nn.SequentialCell in forward method.
MathOps: MathOps node represents a mathematical operation, such as adding or comparing in forward method.
ControlFlow: ControlFlow node represents a control flow statement, such as if statement.
- class mindspore.rewrite.ScopedValue(arg_type: ValueType, scope: str = '', value=None)[source]
ScopedValue represents a value with its full-scope.
ScopedValue is used to express: a left-value such as target of an assign statement, or a callable object such as func of a call statement, or a right-value such as args and kwargs of an assign statement.
- Parameters
arg_type (ValueType) – A ValueType representing the type of the current value.
scope (str, optional) – A string representing the scope of the current value. Take "self.var1" as an example, scope of this var1 is "self". Default: "".
value – A handle representing the current value. The type of value corresponds to arg_type. Default: None.
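The scope/value split described above can be pictured with a minimal sketch. `ScopedName` is a hypothetical stand-in written for illustration, not the real ScopedValue class. It models only the naming case, and its string form is assumed to match the printed output shown in the ScopedValue examples (such as `self.conv`).

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ScopedName:
    """Hypothetical stand-in for a naming ScopedValue (not the real class)."""
    value: Any
    scope: str = ""

    def __str__(self) -> str:
        # "self.var1" has scope "self" and value "var1"; an empty scope
        # means the bare name is used.
        return f"{self.scope}.{self.value}" if self.scope else str(self.value)


print(ScopedName("var1", "self"))
print(ScopedName("x"))
```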
- static create_name_values(names: Union[List[str], Tuple[str]], scopes: Union[List[str], Tuple[str]] = None)[source]
Create a list of naming ScopedValue.
- Parameters
- Returns
A list of ScopedValue instances.
- Raises
TypeError – If names is not a list or tuple, or a name in names is not a str.
TypeError – If scopes is not a list or tuple, or a scope in scopes is not a str.
ValueError – If the length of names is not equal to the length of scopes when scopes are not None.
Examples
>>> from mindspore.rewrite import ScopedValue
>>> variables = ScopedValue.create_name_values(names=["z", "z_1"], scopes=["self", "self"])
>>> print(variables)
[self.z, self.z_1]
- classmethod create_naming_value(name: str, scope: str = '')[source]
Create a naming ScopedValue. A NamingValue represents a reference to another variable.
- Parameters
- Returns
An instance of ScopedValue.
- Raises
Examples
>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_naming_value("conv", "self")
>>> print(variable)
self.conv
- classmethod create_variable_value(value)[source]
Create ScopedValue from a variable.
ScopedValue's type is determined by type of value. ScopedValue's scope is empty.
- Parameters
value – The value to be converted to ScopedValue.
- Returns
An instance of ScopedValue.
Examples
>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_variable_value(2)
>>> print(variable)
2
- class mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)[source]
SymbolTree stores information about a network, including statements of the network's forward computation process and the topological relationship between statement input and output.
The statements in the network are saved in the SymbolTree in the form of nodes. By processing the nodes in the SymbolTree, you can delete, insert, and replace network code, and get the modified network code and network instances.
- Parameters
handler (SymbolTreeImpl) – SymbolTree internal implementation instance. It is recommended to call the create method in SymbolTree to create a SymbolTree, rather than calling the SymbolTree constructor directly. You do not need to know what SymbolTreeImpl is; just treat it as a handle.
- after(node: Union[Node, str])[source]
Returns the position after node. The return value of this interface is used as a parameter for the insert operation.
- Parameters
node (Union[Node, str]) – The node after which the position is located. Can be a node or the name of a node.
- Returns
A Position to indicate where to insert node.
- Raises
TypeError – If node is not a Node or str.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.after(node)
- before(node: Union[Node, str])[source]
Returns the position before node. The return value of this interface is used as a parameter for the insert operation.
- Parameters
node (Union[Node, str]) – The node before which the position is located. Can be a node or the name of a node.
- Returns
A Position to indicate where to insert node.
- Raises
TypeError – If node is not a Node or str.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.before(node)
- classmethod create(network)[source]
Create a SymbolTree object by passing in the network instance network.
This interface parses the network instance, expands each source code statement of the forward computation process, and parses the statements into nodes, which are stored in the SymbolTree. The specific process is as follows:
Obtain the source code of the network instance.
Perform AST parsing on the network and obtain the AST nodes (abstract syntax trees) of each statement in the network.
Expand complex statements in the network forward evaluation process into multiple simple statements.
Create a SymbolTree object. Each SymbolTree corresponds to one network instance.
Use the rewrite node to store each statement of the network forward computation process. The node records the input, output, and other information of the statement.
Save the rewrite node to the SymbolTree, and update and maintain the topological connection between the nodes.
Return the SymbolTree object corresponding to the network instance.
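The parsing and per-statement recording steps above can be sketched with the standard-library `ast` module. This is an illustrative approximation under stated assumptions, not the implementation used by rewrite: the network is given as a source string rather than obtained from an instance, `statements_of_construct` is a hypothetical helper, only simple call assignments are handled, and `ast.unparse` requires Python 3.9 or later.

```python
import ast

# Source text standing in for the code of a network instance (assumption:
# a real implementation would obtain it from the instance itself).
SOURCE = '''
class Net:
    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        return x
'''


def statements_of_construct(source: str):
    """Return (targets, callee, args) for each call assignment in `construct`."""
    tree = ast.parse(source)
    records = []
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == "construct":
            for stmt in node.body:
                # One record per statement: its outputs, callee and inputs.
                if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Call):
                    targets = [ast.unparse(t) for t in stmt.targets]
                    callee = ast.unparse(stmt.value.func)
                    args = [ast.unparse(a) for a in stmt.value.args]
                    records.append((targets, callee, args))
    return records


for record in statements_of_construct(SOURCE):
    print(record)
```

Each record plays the role of one node: it keeps the targets and arguments of a statement, which is the information needed to link statements topologically.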
If a user-defined network of type mindspore.nn.Cell is called in the forward computation process of the network, rewrite will generate a node of type NodeType.Tree for the corresponding statement. This type of node stores a new SymbolTree, which parses and maintains the node information of the user-defined network.
If the following types of statements are called in the forward computation process of the network, rewrite will parse the internal statements in the statement and generate corresponding nodes:
Functions (excluding Python built-in functions and third-party library functions)
Control flow statements, such as if statements
Note
Because the specific execution branch of a control flow is still unknown during the rewrite operation of the network, no topology information will be established between the nodes inside the control flow and the nodes outside. Users cannot obtain nodes inside the control flow when they acquire nodes outside the control flow using interfaces like mindspore.rewrite.Node.get_inputs() and mindspore.rewrite.Node.get_users(). Users also cannot obtain nodes outside the control flow if they use these interfaces inside the control flow. Therefore, when users modify the network, they need to manually handle the node information inside and outside the control flow.
The current rewrite module has the following syntax limitations:
Only networks of type mindspore.nn.Cell are supported as input to the rewrite module.
Parsing one-line control flow syntax (e.g. one-line if-else, one-line for loop) is not currently supported.
Parsing decorator syntax is not currently supported.
Parsing local classes and embedded classes is not currently supported, that is, class definitions need to be placed at the outermost layer.
Parsing closure syntax is not currently supported, that is, definitions of out-of-class functions need to be placed at the outermost layer.
Parsing lambda expression syntax is not currently supported.
Parsing global variables is not currently supported, that is, global variables need to be converted to class variables or local variables before they can be used.
Parsing methods in the parent classes is not currently supported.
For statements that cannot be parsed, rewrite will generate nodes of type NodeType.Python for the corresponding statements to ensure that the network after rewrite can run normally. A Python node does not support modifying the inputs and outputs of its statement, and its variable names may conflict with those generated by rewrite. In this case, users need to adjust the variable names manually.
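Some of the limitations listed above can be checked in user code before calling SymbolTree.create. The sketch below uses the standard-library `ast` module to flag lambda expressions, decorators, and classes not defined at the outermost layer. `find_unsupported_syntax` is a hypothetical helper, not part of mindspore.rewrite, and it does not reproduce rewrite's own checks or cover every listed limitation.

```python
import ast
from operator import itemgetter

SAMPLE = '''
class Net:
    class Inner:
        pass

    @staticmethod
    def scale(x):
        return x

    def construct(self, x):
        f = lambda v: v + 1
        return f(x)
'''


def find_unsupported_syntax(source: str):
    """Return (kind, line) pairs for a few constructs named in the limitations."""
    tree = ast.parse(source)
    # Classes that appear directly in the module body are at the outermost layer.
    top_level = {id(node) for node in tree.body}
    findings = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Lambda):
            findings.append(("lambda", node.lineno))
        if isinstance(node, (ast.FunctionDef, ast.ClassDef)) and node.decorator_list:
            findings.append(("decorator", node.lineno))
        if isinstance(node, ast.ClassDef) and id(node) not in top_level:
            findings.append(("nested class", node.lineno))
    # Report findings in source order.
    return sorted(findings, key=itemgetter(1))


for kind, line in find_unsupported_syntax(SAMPLE):
    print(f"line {line}: {kind}")
```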
- Parameters
network (Cell) – network used to create SymbolTree.
- Returns
SymbolTree, a SymbolTree created based on network.
- Raises
TypeError – If network is not a Cell instance.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> print(type(stree))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
- erase(node: Union[Node, str])[source]
Erase a node from the SymbolTree.
- Parameters
node (Union[Node, str]) – A Node to be erased. Can be a node or name of node.
- Returns
An instance of Node being erased if node is in SymbolTree else None.
- Raises
TypeError – If the type of node is not Node or str.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> stree.erase(node)
- get_code()[source]
Get source code corresponding to the network information in SymbolTree. If the network has already been modified, the source code of modified network is returned.
- Returns
A str representing the source code of the modified network.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> codes = stree.get_code()
>>> print(codes)
- get_network()[source]
Get the network object generated based on SymbolTree. The source code is saved to a file in the 'rewritten_network' folder of the current directory.
Note
The rewrite module modifies a network by modifying the AST of the original network instance, and the new network instance obtains attribute information from the original network instance. The new and original network instances therefore share data, and the original network should no longer be used.
Due to the data association between the new network and the original network instance, manually creating a network instance using the source code file generated by rewrite is not currently supported.
- Returns
A network object generated from SymbolTree.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> new_net = stree.get_network()
- get_node(node_name: str)[source]
Get the node with the name node_name in the SymbolTree.
- Parameters
node_name (str) – The name of node.
- Returns
Node with name of node_name. Return None if there is no node named node_name in SymbolTree.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node('conv1')
>>> print(node.get_name())
conv1
- insert(position, node: Node)[source]
Insert a node into SymbolTree at position.
position is obtained from the before or after API of SymbolTree.
- Parameters
position (Position) – Indicate where to insert node.
node (Node) – An instance of Node to be inserted.
- Returns
An instance of Node being inserted.
- Raises
ValueError – If position does not belong to the current SymbolTree.
TypeError – If position is not a Position.
TypeError – If node is not a Node.
Examples
>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.insert(position, new_node)
- nodes(all_nodes: bool = False)[source]
Get a generator over the nodes in the current SymbolTree, which is used to iterate through the nodes in SymbolTree.
- Parameters
all_nodes (bool) – Get all nodes including nodes in CallFunction node, CellContainer node and sub symbol tree. Default: False.
- Returns
A generator for nodes in SymbolTree.
- Raises
TypeError – If all_nodes is not bool.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> print([node.get_name() for node in stree.nodes()])
['input_x', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1', 'unaryop_not', 'if_node', 'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', 'return_1']
- print_node_tabulate(all_nodes: bool = False)[source]
Print the topology information of nodes in SymbolTree, including node type, node name, node code, and node input-output relationship.
The information is output to the screen using the print interface, including the following information:
node type (str): The type of node, refer to mindspore.rewrite.NodeType.
name (str): The name of node.
codes (str): The code statement in the SymbolTree corresponding to the node.
arg providers (Dict[int, Tuple[str, int]]): The format is {[idx, (n, k)]}, which means the idx-th parameter of the node is provided by the k-th output of node n.
target users (Dict[int, List[Tuple[str, int]]]): The format is {[idx, [(n, k)]]}, which means the idx-th output of the node is used as the k-th parameter of node n.
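The arg providers and target users mappings can be derived from the targets and arguments of each statement. The sketch below does this for a hand-written list of statements. It is an illustration under stated assumptions, not the logic used by print_node_tabulate: `build_topology` is a hypothetical helper, and the node and variable names in `STATEMENTS` are invented.

```python
from typing import Dict, List, Tuple

# Each entry: (node name, output names, argument names), in execution order.
# All names are invented for illustration.
STATEMENTS = [
    ("conv1", ["x_1"], ["x"]),
    ("relu", ["x_2"], ["x_1"]),
    ("fc", ["y"], ["x_2"]),
]


def build_topology(statements):
    """Build per-node arg providers and target users from statement records."""
    producers: Dict[str, Tuple[str, int]] = {}  # variable -> (node, output idx)
    arg_providers: Dict[str, Dict[int, Tuple[str, int]]] = {}
    target_users: Dict[str, Dict[int, List[Tuple[str, int]]]] = {}
    for name, targets, args in statements:
        arg_providers[name] = {}
        target_users[name] = {}
        for idx, arg in enumerate(args):
            if arg in producers:
                src, out_idx = producers[arg]
                # The idx-th argument of `name` comes from output out_idx of src.
                arg_providers[name][idx] = (src, out_idx)
                # Output out_idx of src is used as the idx-th argument of `name`.
                target_users[src].setdefault(out_idx, []).append((name, idx))
        for out_idx, target in enumerate(targets):
            producers[target] = (name, out_idx)
    return arg_providers, target_users


arg_providers, target_users = build_topology(STATEMENTS)
print(arg_providers["relu"])
print(target_users["conv1"])
```

Here the network input `x` has no producing statement in the list, so `conv1` has no arg providers in this simplified model.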
- Parameters
all_nodes (bool) – Print information of all nodes, including nodes in CallFunction node, CellContainer node and sub symbol tree. Default: False.
- Raises
TypeError – If all_nodes is not bool.
Examples
>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> stree.print_node_tabulate()
- replace(old_node: Node, new_nodes: List[Node])[source]
Replace the old_node with nodes in the new_nodes list.
Nodes in new_nodes will be inserted into SymbolTree sequentially, and then old_node will be deleted.
Note
Replace only supports one-to-one replacement or one-to-multi replacement.
The caller should maintain the topological relationship between the nodes in new_nodes, as well as the topological relationship between nodes in new_nodes and nodes in the original tree.
- Parameters
- Returns
An instance of Node that represents the root of the node tree inserted as the replacement.
- Raises
Examples
>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.replace(node, [new_node])
- unique_name(name: str = 'output')[source]
Based on the given name, returns a new name that is unique within the symbol tree. This interface can be used when a variable name that does not conflict is required.
- Parameters
name (str, optional) – The prefix of the name. Default: "output".
- Returns
str, a new name that is unique within the symbol tree, in the format name_n, where n is a numeric subscript. If the given name does not conflict with an existing name, no numeric subscript is appended.
- Raises
TypeError – If the type of name is not str.
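The name_n scheme described for unique_name can be sketched as follows. This is a minimal illustration, not the implementation in SymbolTree: the standalone function takes the set of used names as an explicit (hypothetical) parameter, and starting the subscript at 1 is an assumption that matches node names such as relu_1 in the nodes() example.

```python
def unique_name(name: str, used: set) -> str:
    """Return `name` if unused, else the first free `name_n` (n = 1, 2, ...).

    Hypothetical standalone sketch; `used` is the set of names already taken
    and is updated in place with the returned name.
    """
    if not isinstance(name, str):
        raise TypeError("name must be a str")
    candidate = name
    n = 1
    while candidate in used:
        candidate = f"{name}_{n}"
        n += 1
    used.add(candidate)
    return candidate


used_names = {"relu", "conv1"}
print(unique_name("relu", used_names))
print(unique_name("relu", used_names))
print(unique_name("output", used_names))
```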