mindspore.rewrite

The ReWrite module in MindSpore provides users with the ability to modify the network’s forward computation process based on custom rules, such as inserting, deleting, and replacing statements.

For a quick start of using ReWrite, please refer to Modifying Network With ReWrite.

MindSpore Rewrite package. This is an experimental Python package that is subject to change or deletion.

class mindspore.rewrite.Node(node: NodeImpl)

A node is a data structure that expresses source code statements in a network.

Each node usually corresponds to a statement in the expanded forward evaluation process.

Nodes can express a Cell call statement, a Primitive call statement, an arithmetic operation statement, a return statement, etc. of the forward computation process.

Parameters

node (NodeImpl) – A handle to a NodeImpl. It is recommended to create a Node with the specific methods provided in Node, such as ‘create_call_cell’, rather than calling the Node constructor directly. Users do not need to know what NodeImpl is; just treat it as a handle.

static create_call_cell(cell: Cell, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None, name: str = '', is_sub_net: bool = False)

Create a node. Currently, a node can only be created from a Cell.

The created node corresponds to source code like:

targets = self.name(*args, **kwargs)

Parameters
  • cell (Cell) – Cell-operator of this forward-layer.

  • targets (List[Union[ScopedValue, str]]) – Indicate output names. Used as targets of an assign statement in source code.

  • args (List[ScopedValue]) – Indicate input names. Used as args of a call expression of an assign statement in source code. Default: None , which indicates the cell has no args inputs.

  • kwargs (Dict[str, ScopedValue]) – Type of key must be str and type of value must be ScopedValue. Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source code. Default: None , which indicates the cell has no kwargs inputs.

  • name (str) – Indicates the name of the node. Used as the field name in source code. If name is empty, Rewrite generates a name from cell. Rewrite checks and ensures the uniqueness of the name when the node is inserted. Default: "" .

  • is_sub_net (bool) – Indicates whether cell is a network. If is_sub_net is True, Rewrite tries to parse the cell into a TreeNode; otherwise, the cell is parsed into a CallCell node. Default: False .
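To make the mapping between these parameters and the generated statement concrete, here is a minimal plain-Python sketch that renders the `targets = self.name(*args, **kwargs)` line. The helper `render_call_cell` is hypothetical and not part of mindspore.rewrite; plain strings stand in for ScopedValue objects.

```python
from typing import Dict, List, Optional


def render_call_cell(targets: List[str], name: str,
                     args: Optional[List[str]] = None,
                     kwargs: Optional[Dict[str, str]] = None) -> str:
    """Hypothetical helper: build the assign statement a CallCell node stands for.

    Plain strings stand in for ScopedValue objects in this sketch.
    """
    args = args or []        # None means "no positional inputs"
    kwargs = kwargs or {}    # None means "no keyword inputs"
    call_args = list(args) + [f"{key}={value}" for key, value in kwargs.items()]
    return f"{', '.join(targets)} = self.{name}({', '.join(call_args)})"


print(render_call_cell(["x"], "new_relu", ["x"]))               # x = self.new_relu(x)
print(render_call_cell(["out"], "relu", kwargs={"x": "input"}))  # out = self.relu(x=input)
```

Here name plays the role of the field name on self, which is why Rewrite must keep it unique when the node is inserted.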

Returns

An instance of Node.

Raises
  • TypeError – If cell is not a Cell.

  • TypeError – If targets is not a list.

  • TypeError – If the type of an element in targets is not in [ScopedValue, str].

  • TypeError – If arg in args is not a ScopedValue.

  • TypeError – If key of kwarg is not a str or value of kwarg in kwargs is not a ScopedValue.

Examples

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.insert(position, new_node)
>>> print(type(new_node))
<class 'mindspore.rewrite.api.node.Node'>

static create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None)

Create a node that corresponds to a function call.

Note

The code inside the function will not be parsed.

Parameters
  • function (FunctionType) – The function to be called.

  • targets (List[Union[ScopedValue, str]]) – Indicate output names. Used as targets of an assign statement in source code.

  • args (List[ScopedValue]) – Indicate input names. Used as args of a call expression of an assign statement in source code. Default: None , which indicates the function has no args inputs.

  • kwargs (Dict[str, ScopedValue]) – Type of key must be str and type of value must be ScopedValue. Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source code. Default: None , which indicates the function has no kwargs inputs.

Returns

An instance of Node.

Raises
  • TypeError – If function is not a FunctionType.

  • TypeError – If targets is not a list.

  • TypeError – If the type of an element in targets is not in [ScopedValue, str].

  • TypeError – If arg in args is not a ScopedValue.

  • TypeError – If key of kwarg is not a str or value of kwarg in kwargs is not a ScopedValue.

Examples

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
...                                      args=[ScopedValue.create_naming_value('x')])
>>> stree.insert(position, new_node)
>>> print(new_node.get_node_type())
NodeType.CallFunction

get_args()

Get arguments of current node.

Returns

A list of arguments of type ScopedValue .

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> print(node.get_args())
[x]

get_inputs()

Gets a list of nodes whose output values are used as input values for the current node.

Returns

A list of nodes.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv2")
>>> inputs = node.get_inputs()
>>> print([input.get_name() for input in inputs])
['max_pool2d']

get_instance_type()

Gets the instance type called in the code corresponding to the current node.

  • When node_type of current node is CallCell, the code for that node calls an instance of type Cell .

  • When node_type of current node is CallPrimitive, the code for that node calls an instance of type Primitive .

  • When node_type of current node is Tree, the code for that node calls an instance of network type.

  • When node_type of current node is Python, Input, Output or CallMethod, the instance type is NoneType .

Returns

The type of instance called in the statement corresponding to the current node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> instance_type = node.get_instance_type()
>>> print(instance_type)
<class 'mindspore.nn.layer.conv.Conv2d'>

get_kwargs()

Get keyword arguments of current node.

Returns

A dict of keyword arguments, where key is of type str, and value is of type ScopedValue .

Examples

>>> from mindspore.rewrite import SymbolTree
>>> from mindspore import nn
>>>
>>> class ReLUNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, input):
...         output = self.relu(x=input)
...         return output
>>>
>>> net = ReLUNet()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("relu")
>>> print(node.get_kwargs())
{'x': input}

get_name()

Get the name of current node.

After the node has been inserted into a SymbolTree, its name should be unique within that SymbolTree.

Returns

A string as name of node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> name = node.get_name()
>>> print(name)
conv1

get_node_type()

Get the node_type of current node. See mindspore.rewrite.NodeType for details on node types.

Returns

A NodeType as node_type of node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node_type = node.get_node_type()
>>> print(node_type)
NodeType.CallCell

get_sub_tree()

Get the sub symbol tree stored in node with type of NodeType.Tree . See mindspore.rewrite.NodeType for details on node types.

Returns

SymbolTree stored in Tree node.

Raises
  • TypeError – If current node is not type of NodeType.Tree .

  • AttributeError – If no symbol tree is stored in Tree node.

Examples

>>> import mindspore.nn as nn
>>> from mindspore.rewrite import SymbolTree
>>>
>>> class SubNet(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.relu = nn.ReLU()
...
...     def construct(self, x):
...         x = self.relu(x)
...         return x
...
>>> class Net(nn.Cell):
...     def __init__(self):
...         super().__init__()
...         self.subnet = SubNet()
...
...     def construct(self, x):
...         x = self.subnet(x)
...         return x
>>>
>>> net = Net()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("subnet")
>>> print(type(node.get_sub_tree()))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>

get_symbol_tree()

Get the symbol tree which current node belongs to.

Returns

SymbolTree, None if current node does not belong to any SymbolTree.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> print(type(node.get_symbol_tree()))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>

get_targets()

Gets a list of output values for the current node.

Returns

A list of outputs of type ScopedValue .

get_users()

Get a list of nodes that use the output of the current node as input.

Returns

A list of nodes.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> users = node.get_users()
>>> print([user.get_name() for user in users])
['relu']

set_arg(index: int, arg: Union[ScopedValue, str])

Set an argument of the current node.

Parameters
  • index (int) – Indicates which input is modified.

  • arg (Union[ScopedValue, str]) – New argument to set.

Raises
  • TypeError – If index is not an int.

  • TypeError – If the type of arg is not in [ScopedValue, str].

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("relu_3")
>>> node.set_arg(0, "fc1")
>>> print(node.get_args())
[fc1]

set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)

Set an argument of the current node to an output of another Node.

Parameters
  • arg_idx (int) – Indicates which input is modified.

  • src_node (Node) – A Node whose output is used as the new input.

  • out_idx (int, optional) – Indicates which output of src_node is used as the new input of the current node. Default: None , which means the output of src_node is used directly; in this case src_node must have exactly one output.

Raises
  • TypeError – If arg_idx is not an int.

  • ValueError – If arg_idx is out of range.

  • TypeError – If src_node is not a Node instance.

  • TypeError – If out_idx is not an int.

  • ValueError – If out_idx is out of range.

  • ValueError – If src_node has multiple outputs while out_idx is None.
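The out_idx rules in the Raises list above can be summarized with a small plain-Python sketch. `resolve_output` is a hypothetical helper written for illustration only; it models a node's outputs as a list of target names and is not MindSpore's implementation.

```python
from typing import List, Optional


def resolve_output(src_targets: List[str], out_idx: Optional[int] = None) -> str:
    """Hypothetical helper: choose the output of a source node that feeds an argument."""
    if out_idx is None:
        # With no explicit index, the source node must have exactly one output.
        if len(src_targets) != 1:
            raise ValueError("out_idx is required when src_node does not have exactly one output")
        return src_targets[0]
    if not isinstance(out_idx, int):
        raise TypeError("out_idx must be an int")
    if not 0 <= out_idx < len(src_targets):
        raise ValueError(f"out_idx {out_idx} is out of range")
    return src_targets[out_idx]


print(resolve_output(["fc1_var"]))    # fc1_var
print(resolve_output(["a", "b"], 1))  # b
```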

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> src_node = stree.get_node("fc1")
>>> dst_node = stree.get_node("relu_3")
>>> dst_node.set_arg_by_node(0, src_node, 0)
>>> print(dst_node.get_args())
[fc1_var]

class mindspore.rewrite.NodeType

NodeType represents type of Node.

  • Unknown: An uninitialized NodeType.

  • CallCell: A CallCell node represents invoking a cell-op in the forward method.

  • CallPrimitive: A CallPrimitive node represents invoking a primitive-op in the forward method.

  • CallFunction: A CallFunction node represents invoking a function in the forward method.

  • CallMethod: A CallMethod node represents invoking a method in the forward method that cannot be mapped to a cell-op or primitive-op in MindSpore.

  • Python: A Python node holds an AST node that is unsupported or does not need to be parsed.

  • Input: An Input node represents an input of the SymbolTree, corresponding to an argument of the forward method.

  • Output: An Output node represents an output of the SymbolTree, corresponding to the return statement of the forward method.

  • Tree: A Tree node represents invoking a sub-network in the forward method.

  • CellContainer: A CellContainer node represents invoking a mindspore.nn.SequentialCell in the forward method.

  • MathOps: A MathOps node represents a mathematical operation, such as addition or comparison, in the forward method.

  • ControlFlow: A ControlFlow node represents a control flow statement, such as an if statement.

class mindspore.rewrite.ScopedValue(arg_type: ValueType, scope: str = '', value=None)

ScopedValue represents a value together with its full scope.

ScopedValue is used to express: a left-value such as target of an assign statement, or a callable object such as func of a call statement, or a right-value such as args and kwargs of an assign statement.

Parameters
  • arg_type (ValueType) – A ValueType representing the type of the current value.

  • scope (str, optional) – A string representing the scope of the current value. Take “self.var1” as an example: the scope of var1 is “self”. Default: "" .

  • value – The value itself, whose type corresponds to arg_type. Default: None .

static create_name_values(names: Union[List[str], Tuple[str]], scopes: Union[List[str], Tuple[str]] = None)

Create a list of naming ScopedValue.

Parameters
  • names (List[str] or Tuple[str]) – A list or tuple of str representing the names of referenced variables.

  • scopes (List[str] or Tuple[str], optional) – A list or tuple of str representing the scopes of referenced variables. Default: None .

Returns

A list of ScopedValue instances.

Raises
  • TypeError – If names is not a list or tuple, or any name in names is not a str.

  • TypeError – If scopes is not a list or tuple, or any scope in scopes is not a str.

  • ValueError – If the length of names is not equal to the length of scopes when scopes is not None.

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variables = ScopedValue.create_name_values(names=["z", "z_1"], scopes=["self", "self"])
>>> print(variables)
[self.z, self.z_1]
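The notion of scope and the length check of create_name_values can be illustrated with a simplified stand-in. `MiniScopedValue` and `name_values` are hypothetical, named to avoid confusion with the real class; they only model the printed form seen in these examples (a value with scope "self" prints as self.z, a value without scope prints bare) and omit the ValueType the real class carries.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence


@dataclass(frozen=True)
class MiniScopedValue:
    """Hypothetical simplified model: a value plus the scope it lives in."""
    value: Any
    scope: str = ""

    def __str__(self) -> str:
        # "self.var1" -> scope "self", value "var1"; no scope -> bare value
        return f"{self.scope}.{self.value}" if self.scope else str(self.value)


def name_values(names: Sequence[str],
                scopes: Optional[Sequence[str]] = None) -> List[MiniScopedValue]:
    """Pair each name with its scope, rejecting mismatched lengths."""
    if scopes is None:
        scopes = [""] * len(names)
    if len(names) != len(scopes):
        raise ValueError("names and scopes must have the same length")
    return [MiniScopedValue(name, scope) for name, scope in zip(names, scopes)]


print([str(v) for v in name_values(["z", "z_1"], ["self", "self"])])  # ['self.z', 'self.z_1']
print(MiniScopedValue(2))  # 2
```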

classmethod create_naming_value(name: str, scope: str = '')

Create a naming ScopedValue. A NamingValue represents a reference to another variable.

Parameters
  • name (str) – A string representing the identifier of another variable.

  • scope (str, optional) – A string representing the scope of another variable. Default: "" .

Returns

An instance of ScopedValue.

Raises

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_naming_value("conv", "self")
>>> print(variable)
self.conv

classmethod create_variable_value(value)

Create ScopedValue from a variable.

ScopedValue’s type is determined by type of value. ScopedValue’s scope is empty.

Parameters

value – The value to be converted to ScopedValue.

Returns

An instance of ScopedValue.

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_variable_value(2)
>>> print(variable)
2

class mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)

SymbolTree stores information about a network, including statements of the network’s forward computation process and the topological relationship between statement input and output.

The statements in the network are saved in the SymbolTree as nodes. By processing the nodes in the SymbolTree, you can delete, insert, and replace statements of the network code, and then get the modified network code and network instance.

Parameters

handler (SymbolTreeImpl) – SymbolTree internal implementation instance. It is recommended to create a SymbolTree with the create method of SymbolTree rather than calling the SymbolTree constructor directly. Users do not need to know what SymbolTreeImpl is; just treat it as a handle.

after(node: Union[Node, str])

Return a position after node. The return value of this interface is used as a parameter of the insert operation.

Parameters

node (Union[Node, str]) – The node after which the position is located. Can be a node or the name of a node.

Returns

A Position to indicate where to insert node.

Raises

TypeError – If node is not a Node or str.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.after(node)

before(node: Union[Node, str])

Return a position before node. The return value of this interface is used as a parameter of the insert operation.

Parameters

node (Union[Node, str]) – The node before which the position is located. Can be a node or the name of a node.

Returns

A Position to indicate where to insert node.

Raises

TypeError – If node is not a Node or str.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.before(node)

classmethod create(network)

Create a SymbolTree object by passing in the network instance network.

This interface parses the network instance, expands each source code statement of the forward computation process, and parses the statements into nodes, which are stored in the SymbolTree. The specific process is as follows:

  1. Obtain the source code of the network instance.

  2. Perform AST parsing on the network and obtain the AST nodes (abstract syntax trees) of each statement in the network.

  3. Expand complex statements in the network forward evaluation process into multiple simple statements.

  4. Create a SymbolTree object. Each SymbolTree corresponds to one network instance.

  5. Use the rewrite node to store each statement of the network forward computation process. The node records the input, output, and other information of the statement.

  6. Save the rewrite node to the SymbolTree, and update and maintain the topological connection between the nodes.

  7. Return the SymbolTree object corresponding to the network instance.
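Steps 2, 3 and 5 can be approximated with Python's standard ast module. The sketch below is an illustration under stated assumptions, not MindSpore's implementation: it parses a source string directly instead of fetching it from a live network, handles only assignments whose right side is a call plus return statements, hoists nested calls into temporaries named tmp_1, tmp_2, ... (a naming scheme made up for this sketch; rewrite generates its own names), and records each simple statement as a (targets, callee, args) tuple. It needs Python 3.9+ for ast.unparse.

```python
import ast
from typing import List, Tuple

# Hypothetical network source; rewrite would obtain the real source from the instance.
SOURCE = '''
class Net:
    def construct(self, x):
        x = self.relu(self.conv1(x))
        x = self.flatten(x)
        return x
'''

# One simple statement: (targets, callee, positional args)
Record = Tuple[List[str], str, List[str]]


def flatten_construct(source: str) -> List[Record]:
    """Parse `construct` and expand nested calls into one record per simple call."""
    tree = ast.parse(source)
    func = next(node for node in ast.walk(tree)
                if isinstance(node, ast.FunctionDef) and node.name == "construct")
    records: List[Record] = []
    counter = 0

    def lower(expr: ast.expr) -> str:
        """Return a name for `expr`, emitting records for nested calls first."""
        nonlocal counter
        if isinstance(expr, ast.Call):
            args = [lower(arg) for arg in expr.args]  # keyword args ignored here
            counter += 1
            tmp = f"tmp_{counter}"  # made-up temporary name
            records.append(([tmp], ast.unparse(expr.func), args))
            return tmp
        return ast.unparse(expr)

    for stmt in func.body:
        if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Call):
            args = [lower(arg) for arg in stmt.value.args]
            targets = [ast.unparse(target) for target in stmt.targets]
            records.append((targets, ast.unparse(stmt.value.func), args))
        elif isinstance(stmt, ast.Return) and stmt.value is not None:
            records.append(([], "return", [lower(stmt.value)]))
        # Any other statement kind is skipped in this sketch.
    return records


for record in flatten_construct(SOURCE):
    print(record)
```

For the sample source this prints four records in order: (['tmp_1'], 'self.conv1', ['x']), (['x'], 'self.relu', ['tmp_1']), (['x'], 'self.flatten', ['x']) and ([], 'return', ['x']). The nested call is emitted first, which is what expanding one complex statement into several simple ones means here.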

If a user-defined network of type mindspore.nn.Cell is called in the forward computation process of the network, rewrite will generate a node of type NodeType.Tree for the corresponding statement. This type of node stores a new SymbolTree, which parses and maintains the node information of the user-defined network.

If the following types of statements are called in the forward computation process of the network, rewrite parses the statements inside them and generates corresponding nodes:

  • mindspore.nn.SequentialCell

  • Functions (excluding Python built-in functions and third-party library functions)

  • Control flow statements, such as if statements

Note

Because the execution branch of a control flow statement is still unknown when the network is rewritten, no topological relationship is established between nodes inside the control flow and nodes outside it. When users call interfaces such as mindspore.rewrite.Node.get_inputs() and mindspore.rewrite.Node.get_users() on nodes outside the control flow, nodes inside the control flow are not returned, and vice versa. Therefore, when modifying the network, users need to handle the relationships between nodes inside and outside the control flow manually.

The current rewrite module has the following syntax limitations:

  • Only networks of type mindspore.nn.Cell are supported as input to the rewrite module.

  • Parsing one-line control flow syntax (e.g. one-line if-else, one-line for loop) is not currently supported.

  • Parsing decorator syntax is not currently supported.

  • Parsing local classes and nested classes is not currently supported; that is, class definitions must be placed at the outermost level.

  • Parsing closure syntax is not currently supported; that is, functions defined outside classes must be placed at the outermost level.

  • Parsing lambda expression syntax is not currently supported.

  • Parsing global variables is not currently supported; that is, global variables must be converted to class variables or local variables before they can be used.

  • Parsing methods in the parent classes is not currently supported.

For statements that cannot be parsed, rewrite generates nodes of type NodeType.Python to ensure that the rewritten network can still run normally. A Python node does not support modifying the inputs and outputs of its statement, and variable names in it may conflict with names generated by rewrite. In this case, users need to adjust the variable names manually.
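Before calling create, some of the limitations listed above can be pre-checked with the standard ast module. `find_unsupported` is a hypothetical helper, not part of mindspore.rewrite: it reports only lambda expressions, decorators, global statements, and classes defined inside another class or function, and it does not reproduce rewrite's actual parser behavior.

```python
import ast
from typing import List


def find_unsupported(source: str) -> List[str]:
    """Hypothetical pre-check for a few constructs rewrite documents as unsupported."""
    findings: List[str] = []

    def visit(node: ast.AST, nested: bool) -> None:
        for child in ast.iter_child_nodes(node):
            if isinstance(child, ast.Lambda):
                findings.append(f"line {child.lineno}: lambda expression")
            elif isinstance(child, ast.Global):
                findings.append(f"line {child.lineno}: global statement")
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                if child.decorator_list:
                    findings.append(f"line {child.lineno}: decorator on {child.name}")
                if isinstance(child, ast.ClassDef) and nested:
                    findings.append(f"line {child.lineno}: nested class {child.name}")
                visit(child, nested=True)  # anything defined below here is nested
            else:
                visit(child, nested)

    visit(ast.parse(source), nested=False)
    return findings


SAMPLE = '''
SCALE = 2.0

class Net:
    class Inner:
        pass

    @staticmethod
    def helper(x):
        return x

    def construct(self, x):
        global SCALE
        f = lambda v: v * SCALE
        return f(x)
'''
for finding in find_unsupported(SAMPLE):
    print(finding)
```

The source is only parsed, never executed, so the sample does not need MindSpore installed.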

Parameters

network (Cell) – network used to create SymbolTree.

Returns

SymbolTree, a SymbolTree created based on network.

Raises

TypeError – If network is not a Cell instance.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> print(type(stree))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>

erase(node: Union[Node, str])

Erase a node from the SymbolTree.

Parameters

node (Union[Node, str]) – A Node to be erased. Can be a node or name of node.

Returns

An instance of Node being erased if node is in SymbolTree else None.

Raises

TypeError – The type of node is not Node or str.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> stree.erase(node)

get_code()

Get source code corresponding to the network information in SymbolTree. If the network has already been modified, the source code of modified network is returned.

Returns

A str represents source code of modified network.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> codes = stree.get_code()
>>> print(codes)

get_network()

Get the network object generated based on SymbolTree. The source code is saved to a file in the ‘rewritten_network’ folder of the current directory.

Note

  • The rewrite module modifies a network by modifying the AST of the original network instance, and the new network instance obtains attribute information from the original network instance. Because the new network instance and the original network instance therefore share data, the original network should no longer be used.

  • Due to the data association between the new network and the original network instance, manually creating a network instance using the source code file generated by rewrite is not currently supported.

Returns

A network object generated from SymbolTree.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> new_net = stree.get_network()

get_node(node_name: str)

Get the node with the name node_name in the SymbolTree.

Parameters

node_name (str) – The name of node.

Returns

Node with name of node_name . Return None if there is no node named node_name in SymbolTree.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node('conv1')
>>> print(node.get_name())
conv1

insert(position, node: Node)

Insert a node into SymbolTree at position.

position is obtained from the before or after interface of SymbolTree.
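The relationship between before, after and insert can be pictured with an ordinary Python list of node names. Everything below (`Position`, `before`, `after`, `insert`) is hypothetical and only illustrates the idea that a position is an anchor node plus a side, which insert then consumes.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Position:
    """Hypothetical position: an anchor node name plus the side to insert on."""
    anchor: str
    is_before: bool


def before(name: str) -> Position:
    return Position(name, is_before=True)


def after(name: str) -> Position:
    return Position(name, is_before=False)


def insert(nodes: List[str], position: Position, new_node: str) -> str:
    """Insert new_node next to the anchor; list.index raises ValueError if the
    anchor is not in this list, loosely mirroring the documented ValueError."""
    index = nodes.index(position.anchor)
    nodes.insert(index if position.is_before else index + 1, new_node)
    return new_node


nodes = ["input_x", "conv1", "relu", "return_1"]
insert(nodes, after("conv1"), "new_relu")
print(nodes)  # ['input_x', 'conv1', 'new_relu', 'relu', 'return_1']
```

The sketch only reorders names; it does not connect inputs and outputs, which in rewrite is expressed through the node's args and targets.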

Parameters
  • position (Position) – Indicate where to insert node.

  • node (Node) – An instance of Node to be inserted.

Returns

An instance of Node being inserted.

Raises
  • ValueError – If position does not belong to the current SymbolTree.

  • TypeError – If position is not a Position.

  • TypeError – If node is not a Node.

Examples

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.insert(position, new_node)

nodes(all_nodes: bool = False)

Get a generator over the nodes in the current SymbolTree, which can be used to iterate through them.

Parameters

all_nodes (bool) – Whether to also get nodes inside CallFunction nodes, CellContainer nodes, and sub symbol trees. Default: False .

Returns

A generator for nodes in SymbolTree.

Raises

TypeError – If all_nodes is not bool.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> print([node.get_name() for node in stree.nodes()])
['input_x', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1',
 'unaryop_not', 'if_node', 'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', 'return_1']
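The effect of all_nodes can be sketched as a recursive generator. The structure below is a hypothetical nested list of (name, inner nodes) pairs, where the inner nodes stand for those held by a CallFunction node, a CellContainer node or a sub symbol tree; SymbolTree does not store nodes this way, and yielding a parent before its inner nodes is an assumption of the sketch.

```python
from typing import Iterator, List, Tuple

# Hypothetical node record: (name, inner nodes)
NodeSpec = Tuple[str, List["NodeSpec"]]


def iter_nodes(nodes: List[NodeSpec], all_nodes: bool = False) -> Iterator[str]:
    """Yield node names in order; descend into inner nodes only if all_nodes is True."""
    for name, inner in nodes:
        yield name
        if all_nodes:
            yield from iter_nodes(inner, all_nodes=True)


tree: List[NodeSpec] = [
    ("input_x", []),
    ("subnet", [("relu", []), ("return", [])]),
    ("return_1", []),
]
print(list(iter_nodes(tree)))                  # ['input_x', 'subnet', 'return_1']
print(list(iter_nodes(tree, all_nodes=True)))  # ['input_x', 'subnet', 'relu', 'return', 'return_1']
```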

print_node_tabulate(all_nodes: bool = False)

Print the topology information of nodes in SymbolTree, including node type, node name, node code, and node input-output relationship.

The information is output to the screen using the print interface, including the following information:

  • node type (str): The type of node; see mindspore.rewrite.NodeType .

  • name (str): The name of node.

  • codes (str): The code statement in the SymbolTree corresponding to the node.

  • arg providers (Dict[int, Tuple[str, int]]): The format is {idx: (n, k)} , which means the idx-th parameter of the node is provided by the k-th output of node n .

  • target users (Dict[int, List[Tuple[str, int]]]): The format is {idx: [(n, k)]} , which means the idx-th output of the node is used as the k-th parameter of node n .
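The arg providers and target users columns are two views of the same def-use relation, which is also what get_inputs and get_users expose. As a hypothetical illustration in plain Python (not the rewrite implementation), the sketch below derives both maps from a small list of (name, targets, args) statements in execution order, letting the most recent assignment of a variable be its provider, so repeated reassignment of x works as in LeNet5.

```python
from typing import Dict, List, Tuple

# Hypothetical statements in execution order: (node name, targets, args)
Stmt = Tuple[str, List[str], List[str]]
Providers = Dict[str, Dict[int, Tuple[str, int]]]
Users = Dict[str, Dict[int, List[Tuple[str, int]]]]


def topology(stmts: List[Stmt]) -> Tuple[Providers, Users]:
    """Return (arg_providers, target_users) keyed by node name."""
    latest: Dict[str, Tuple[str, int]] = {}  # variable -> (producer node, output index)
    providers: Providers = {}
    users: Users = {}
    for name, targets, args in stmts:
        providers[name] = {}
        users[name] = {idx: [] for idx in range(len(targets))}
        for arg_idx, var in enumerate(args):
            if var in latest:  # the most recent assignment of var provides this arg
                src, out_idx = latest[var]
                providers[name][arg_idx] = (src, out_idx)
                users[src][out_idx].append((name, arg_idx))
        for out_idx, var in enumerate(targets):
            latest[var] = (name, out_idx)
    return providers, users


stmts = [
    ("input_x", ["x"], []),
    ("conv1", ["x"], ["x"]),
    ("relu", ["x"], ["x"]),
    ("return_1", [], ["x"]),
]
providers, users = topology(stmts)
print(providers["relu"])  # {0: ('conv1', 0)}
print(users["conv1"])     # {0: [('relu', 0)]}
```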

Parameters

all_nodes (bool) – Whether to also print information of nodes inside CallFunction nodes, CellContainer nodes, and sub symbol trees. Default: False .

Raises

TypeError – If all_nodes is not bool.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> stree.print_node_tabulate()

replace(old_node: Node, new_nodes: List[Node])

Replace the old_node with nodes in the new_nodes list.

Nodes in new_nodes will be inserted into SymbolTree sequentially, and then old_node will be deleted.

Note

  • replace supports one-to-one and one-to-many replacement. For many-to-many replacement, refer to PatternEngine.

  • Caller should maintain the topological relationship between each node in the new_nodes , as well as the topological relationship between nodes in the new_nodes and nodes in the original tree.
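The documented order of a one-to-many replacement (insert every node of new_nodes in sequence, then delete old_node) can be shown on a plain list of node names. This `replace` is a hypothetical sketch: it ignores the return value of the real method and does not rewire any arguments, which the note above leaves to the caller.

```python
from typing import List


def replace(nodes: List[str], old: str, new_nodes: List[str]) -> None:
    """Hypothetical sketch: put new_nodes where `old` was, in order, then delete `old`."""
    index = nodes.index(old)
    for offset, new in enumerate(new_nodes):
        nodes.insert(index + offset, new)  # each new node goes right after the previous one
    del nodes[index + len(new_nodes)]      # `old` has been pushed past the new nodes


nodes = ["input_x", "conv1", "relu", "return_1"]
replace(nodes, "conv1", ["conv1_a", "conv1_b"])
print(nodes)  # ['input_x', 'conv1_a', 'conv1_b', 'relu', 'return_1']
```

Deleting by position rather than by name keeps the sketch correct even if a new node reuses the old node's name.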

Parameters
  • old_node (Node) – Node to be replaced.

  • new_nodes (List[Node]) – Nodes that replace old_node.

Returns

An instance of Node representing the root of the node tree that replaced old_node.

Raises
  • TypeError – If old_node is not a Node.

  • TypeError – If new_nodes is not a list or node in new_nodes is not a Node.

Examples

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.replace(node, [new_node])

unique_name(name: str = 'output')

Based on the given name , returns a new name that is unique within the symbol tree. This interface can be used when a variable name that does not conflict is required.

Parameters

name (str, optional) – The prefix of the name. Default: "output" .

Returns

str, a new name that is unique within the symbol tree, in the format name_n, where n is a numeric suffix. If name does not conflict with any existing name, it is returned without a numeric suffix.

Raises

TypeError – The type of name is not str.
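The naming format described under Returns can be sketched as follows. `make_unique` is a hypothetical helper that models the name / name_n pattern against a set of names already in use; the real method consults the SymbolTree itself. Counting from 1 is an assumption based on names such as relu_1 in the nodes() example, and reserving the returned name is a choice of this sketch.

```python
from typing import Set


def make_unique(name: str, used: Set[str]) -> str:
    """Hypothetical helper: return `name`, or `name_n` with the smallest free n >= 1."""
    if not isinstance(name, str):
        raise TypeError("name must be a str")
    candidate, n = name, 0
    while candidate in used:
        n += 1
        candidate = f"{name}_{n}"
    used.add(candidate)  # reserve it so the next call gets a fresh suffix
    return candidate


used = {"output", "relu", "relu_1"}
print(make_unique("output", used))  # output_1
print(make_unique("relu", used))    # relu_2
print(make_unique("conv", used))    # conv
```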

class mindspore.rewrite.ValueType

ValueType represents type of ScopedValue.

  • A NamingValue represents a reference to another variable.

  • A CustomObjValue represents an instance of a custom class, or an object whose type is not one of the base types or container types of ValueType.