mindspore.rewrite

The ReWrite module in MindSpore provides users with the ability to modify the network’s forward computation process based on custom rules, such as inserting, deleting, and replacing statements.

For a quick start with ReWrite, please refer to Modifying Network With ReWrite.

MindSpore Rewrite package. This is an experimental Python package that is subject to change or deletion.

class mindspore.rewrite.Node(node: NodeImpl)

A node is a data structure that expresses source code statements in a network.

Each node usually corresponds to a statement in the expanded forward computation process.

Nodes can express a Cell call statement, a Primitive call statement, an arithmetic operation statement, a return statement, and so on, in the forward computation process.

Parameters

node (NodeImpl) – A handle to a NodeImpl. It is recommended to create a Node with the specific methods of Node, such as ‘create_call_cell’, rather than by calling the Node constructor directly. Users do not need to know what NodeImpl is; just treat it as an opaque handle.

static create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, kwargs: {str: ScopedValue} = None, name: str = '', is_sub_net: bool = False)

Create a node. Currently, only creation from a Cell is supported.

The created node corresponds to source code of the form:

`targets` = self.`name`(*`args`, **`kwargs`)

Parameters
  • cell (Cell) – The Cell operator called by this forward statement.

  • targets (list[Union[ScopedValue, str]]) – Output names, used as the targets of the assignment statement in the source code.

  • args (list[ScopedValue]) – Input names, used as the positional arguments of the call expression in the assignment statement. Default: None, which indicates that the cell has no positional inputs.

  • kwargs (dict) – Keyword input names, used as the keyword arguments of the call expression in the assignment statement. Keys must be str and values must be ScopedValue. Default: None, which indicates that the cell has no keyword inputs.

  • name (str) – The name of the node, used as the field name in the source code. If name is empty, Rewrite generates a name from targets. Rewrite checks and ensures the uniqueness of the name when the node is inserted. Default: "".

  • is_sub_net (bool) – Whether cell is a network. If True, Rewrite tries to parse cell into a TreeNode; otherwise, cell is parsed into a CallCell node. Default: False.

Returns

An instance of Node.

Raises
  • TypeError – If cell is not a Cell.

  • TypeError – If targets is not list.

  • TypeError – If an element of targets is neither a ScopedValue nor a str.

  • TypeError – If an element of args is not a ScopedValue.

  • TypeError – If a key in kwargs is not a str or a value in kwargs is not a ScopedValue.

Examples

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.insert(position, new_node)
>>> print(type(new_node))
<class 'mindspore.rewrite.api.node.Node'>
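The statement template targets = self.name(*args, **kwargs) given for create_call_cell can be illustrated without MindSpore. The following plain-Python sketch formats the assignment statement that a call-cell node stands for. render_call_cell is a hypothetical helper, not part of mindspore.rewrite, and it takes plain strings in place of ScopedValue objects.

```python
from typing import Dict, List, Optional


def render_call_cell(targets: List[str], name: str,
                     args: Optional[List[str]] = None,
                     kwargs: Optional[Dict[str, str]] = None) -> str:
    """Hypothetical helper: format `targets = self.name(*args, **kwargs)` as text."""
    if not isinstance(targets, list) or not targets:
        raise TypeError("targets must be a non-empty list")
    args = args or []      # None: the cell takes no positional inputs
    kwargs = kwargs or {}  # None: the cell takes no keyword inputs
    call_args = list(args) + [f"{key}={value}" for key, value in kwargs.items()]
    return f"{', '.join(targets)} = self.{name}({', '.join(call_args)})"


print(render_call_cell(["x"], "new_relu", ["x"]))                       # x = self.new_relu(x)
print(render_call_cell(["out"], "dense", ["x"], {"weight": "self.w"}))  # out = self.dense(x, weight=self.w)
```

With the arguments from the example above, the helper yields x = self.new_relu(x), which is the kind of statement the inserted new_relu node is meant to represent.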

get_args()

Get the arguments of the current node.

Returns

A list of arguments of type ScopedValue.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> print(node.get_args())
[x]

get_inputs()

Gets a list of nodes whose output values are used as input values for the current node.

Returns

A list of nodes.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv2")
>>> inputs = node.get_inputs()
>>> print([inp.get_name() for inp in inputs])
['max_pool2d']

get_instance_type()

Get the type of the instance called by the statement that corresponds to the current node.

  • When the node_type of the current node is CallCell, the statement calls an instance of type Cell.

  • When the node_type of the current node is CallPrimitive, the statement calls an instance of type Primitive.

  • When the node_type of the current node is Tree, the statement calls an instance of a network type.

  • When the node_type of the current node is Python, Input, Output or CallMethod, the instance type is NoneType.

Returns

The type of instance called in the statement corresponding to the current node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> instance_type = node.get_instance_type()
>>> print(instance_type)
<class 'mindspore.nn.layer.conv.Conv2d'>

get_name()

Get the name of current node.

Once a node has been inserted into a SymbolTree, its name is unique within that SymbolTree.

Returns

A string as name of node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> name = node.get_name()
>>> print(name)
conv1

get_node_type()

Get the node_type of current node. See mindspore.rewrite.NodeType for details on node types.

Returns

The NodeType of the node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node_type = node.get_node_type()
>>> print(node_type)
NodeType.CallCell

get_targets()

Gets a list of output values for the current node.

Returns

A list of outputs of type ScopedValue.

get_users()

Get a list of nodes that use the output of the current node as input.

Returns

A list of nodes.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> users = node.get_users()
>>> print([user.get_name() for user in users])
['relu']
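get_inputs and get_users are two views of one def-use relationship between statements. As a rough mental model, the sketch below recomputes both for a straight-line sequence of assignments in plain Python. The Stmt tuple and the build_inputs and build_users helpers are hypothetical, not MindSpore APIs; the statement list is modeled on the LeNet5 node order printed by nodes() (with the Expr node omitted). Because every statement reassigns x, each argument is resolved to the statement that assigned it most recently.

```python
from typing import Dict, List, NamedTuple


class Stmt(NamedTuple):
    """Hypothetical record for one forward statement: `targets = name(args)`."""
    name: str
    targets: List[str]
    args: List[str]


def build_inputs(stmts: List[Stmt]) -> Dict[str, List[str]]:
    """Map each statement to the statements whose outputs it reads."""
    latest_def: Dict[str, str] = {}  # variable -> statement that assigned it most recently
    inputs: Dict[str, List[str]] = {}
    for stmt in stmts:
        # Resolve arguments before recording this statement's own targets,
        # so `x = conv1(x)` reads the previous definition of x.
        inputs[stmt.name] = [latest_def[a] for a in stmt.args if a in latest_def]
        for target in stmt.targets:
            latest_def[target] = stmt.name
    return inputs


def build_users(stmts: List[Stmt]) -> Dict[str, List[str]]:
    """Invert build_inputs: map each statement to the statements that read its outputs."""
    users: Dict[str, List[str]] = {stmt.name: [] for stmt in stmts}
    for consumer, producers in build_inputs(stmts).items():
        for producer in producers:
            users[producer].append(consumer)
    return users


stmts = [
    Stmt("input_x", ["x"], []),
    Stmt("conv1", ["x"], ["x"]),
    Stmt("relu", ["x"], ["x"]),
    Stmt("max_pool2d", ["x"], ["x"]),
    Stmt("conv2", ["x"], ["x"]),
]
inputs = build_inputs(stmts)
users = build_users(stmts)
print(inputs["conv2"])  # ['max_pool2d']
print(users["conv1"])   # ['relu']
```

The two printed results line up with the get_inputs and get_users examples above, which is the point of the sketch; the real SymbolTree also tracks non-assignment statements that this model ignores.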

set_arg(index: int, arg: Union[ScopedValue, str])

Set an argument of the current node.

Parameters
  • index (int) – Index of the argument to modify.

  • arg (Union[ScopedValue, str]) – The new argument.

Raises
  • TypeError – If index is not an int.

  • TypeError – If arg is neither a ScopedValue nor a str.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("relu_3")
>>> node.set_arg(0, "fc1")
>>> print(node.get_args())
[fc1]

set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)

Set an argument of the current node to an output of another node.

Parameters
  • arg_idx (int) – Index of the argument to modify.

  • src_node (Node) – The node whose output becomes the new argument.

  • out_idx (int, optional) – Index of the output of src_node to use as the new argument. Default: None, which means the first output of src_node is used; if src_node has multiple outputs, out_idx must be provided.

Raises
  • RuntimeError – If src_node does not belong to the current SymbolTree.

  • TypeError – If arg_idx is not an int.

  • ValueError – If arg_idx is out of range.

  • TypeError – If src_node is not a Node instance.

  • TypeError – If out_idx is not an int.

  • ValueError – If out_idx is out of range.

  • ValueError – If src_node has multiple outputs and out_idx is None (not provided).

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> src_node = stree.get_node("fc1")
>>> dst_node = stree.get_node("relu_3")
>>> dst_node.set_arg_by_node(0, src_node, 0)
>>> print(dst_node.get_args())
[fc1]
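The out_idx rules of set_arg_by_node can be summarized in a small plain-Python sketch. It assumes a source node is described only by its list of output names, and both helpers (select_output and set_arg_from_output) are hypothetical illustrations of the documented checks, not MindSpore code.

```python
from typing import List, Optional


def select_output(src_targets: List[str], out_idx: Optional[int] = None) -> str:
    """Pick which output of the source node feeds the new argument."""
    if out_idx is None:
        # Without out_idx, an ambiguous multi-output source is rejected.
        if len(src_targets) > 1:
            raise ValueError("src_node has multiple outputs; pass out_idx")
        return src_targets[0]
    if not isinstance(out_idx, int):
        raise TypeError("out_idx must be an int")
    if not 0 <= out_idx < len(src_targets):
        raise ValueError(f"out_idx {out_idx} is out of range")
    return src_targets[out_idx]


def set_arg_from_output(args: List[str], arg_idx: int,
                        src_targets: List[str], out_idx: Optional[int] = None) -> List[str]:
    """Return a copy of args with args[arg_idx] replaced by one output of the source."""
    if not isinstance(arg_idx, int):
        raise TypeError("arg_idx must be an int")
    if not 0 <= arg_idx < len(args):
        raise ValueError(f"arg_idx {arg_idx} is out of range")
    new_args = list(args)
    new_args[arg_idx] = select_output(src_targets, out_idx)
    return new_args


print(set_arg_from_output(["x", "y"], 1, ["out"]))           # ['x', 'out']
print(set_arg_from_output(["x"], 0, ["h", "c"], out_idx=1))  # ['c']
```

Passing out_idx explicitly, as the example above does with 0, avoids the multi-output ValueError entirely.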

class mindspore.rewrite.NodeType

NodeType represents the type of a Node.

  • Unknown: Uninitialized NodeType.

  • CallCell: A CallCell node represents a call to a Cell operator in the forward method.

  • CallPrimitive: A CallPrimitive node represents a call to a Primitive operator in the forward method.

  • CallFunction: A CallFunction node represents a call to a MindSpore function in the forward method.

  • CallMethod: A CallMethod node represents a method call in the forward method that cannot be mapped to a Cell operator or Primitive operator in MindSpore.

  • Python: A Python node holds an AST node that is unsupported or does not need to be parsed.

  • Input: An Input node represents an input of the SymbolTree, corresponding to an argument of the forward method.

  • Output: An Output node represents the output of the SymbolTree, corresponding to the return statement of the forward method.

  • Tree: A Tree node represents a sub-network call in the forward method.

  • MathOps: A MathOps node represents a mathematical operation, such as an addition or a comparison, in the forward method.

class mindspore.rewrite.ScopedValue(arg_type: ValueType, scope: str = '', value=None)

ScopedValue represents a value together with its full scope.

ScopedValue is used to express a left value, such as the target of an assignment statement; a callable object, such as the function of a call statement; or a right value, such as the args and kwargs of an assignment statement.

Parameters
  • arg_type (ValueType) – The type of the current value.

  • scope (str) – The scope of the current value. For example, in “self.var1”, the scope of var1 is “self”. Default: "".

  • value – A handle to the current value. Its type corresponds to arg_type. Default: None.
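The printed forms in the examples for this class (self.z, self.conv, 2) are consistent with a simple rendering rule: scope.value when a scope is set, and the bare value otherwise. The sketch below models only that rule; ToyScopedValue is a hypothetical class that omits arg_type and is not the real ScopedValue.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyScopedValue:
    """Hypothetical stand-in for ScopedValue: a value plus the scope it lives in."""
    value: Any
    scope: str = ""

    def __str__(self) -> str:
        # "self.var1" is scope "self" + value "var1"; an empty scope prints the bare value.
        return f"{self.scope}.{self.value}" if self.scope else str(self.value)

    # Defining __repr__ here stops dataclass from generating one,
    # so lists print like [self.z, self.z_1].
    __repr__ = __str__


print(ToyScopedValue("var1", "self"))  # self.var1
print(ToyScopedValue(2))               # 2
print([ToyScopedValue("z", "self"), ToyScopedValue("z_1", "self")])  # [self.z, self.z_1]
```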

static create_name_values(names: Union[list, tuple], scopes: Union[list, tuple] = None)

Create a list of naming ScopedValue instances.

Parameters
  • names (list[str] or tuple[str]) – Names of the referenced variables.

  • scopes (list[str] or tuple[str]) – Scopes of the referenced variables. Default: None.

Returns

A list of ScopedValue instances.

Raises
  • RuntimeError – If scopes is not None and its length differs from the length of names.

  • TypeError – If names is not a list or tuple, or an element of names is not a str.

  • TypeError – If scopes is not a list or tuple, or an element of scopes is not a str.

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variables = ScopedValue.create_name_values(names=["z", "z_1"], scopes=["self", "self"])
>>> print(variables)
[self.z, self.z_1]
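The pairing and validation rules listed for create_name_values can be sketched in plain Python as follows. toy_create_name_values is hypothetical and returns strings rather than ScopedValue instances; it assumes that omitting scopes leaves every name unscoped.

```python
from typing import List, Optional, Sequence


def toy_create_name_values(names: Sequence[str],
                           scopes: Optional[Sequence[str]] = None) -> List[str]:
    """Hypothetical sketch: pair names with scopes and render them as 'scope.name'."""
    if not isinstance(names, (list, tuple)) or not all(isinstance(n, str) for n in names):
        raise TypeError("names must be a list or tuple of str")
    if scopes is None:
        scopes = [""] * len(names)  # assumption: no scopes means every name is unscoped
    elif not isinstance(scopes, (list, tuple)) or not all(isinstance(s, str) for s in scopes):
        raise TypeError("scopes must be a list or tuple of str")
    if len(scopes) != len(names):
        raise RuntimeError("names and scopes must have the same length")
    return [f"{scope}.{name}" if scope else name for name, scope in zip(names, scopes)]


print(toy_create_name_values(["z", "z_1"], ["self", "self"]))  # ['self.z', 'self.z_1']
print(toy_create_name_values(("a", "b")))                      # ['a', 'b']
```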

classmethod create_naming_value(name: str, scope: str = '')

Create a naming ScopedValue. A NamingValue represents a reference to another variable.

Parameters
  • name (str) – The identifier of the referenced variable.

  • scope (str) – The scope of the referenced variable. Default: "".

Returns

An instance of ScopedValue.

Raises

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_naming_value("conv", "self")
>>> print(variable)
self.conv

classmethod create_variable_value(value)

Create a ScopedValue from a variable.

The type of the ScopedValue is determined by the type of value, and its scope is empty.

Parameters

value – The value to be converted to ScopedValue.

Returns

An instance of ScopedValue.

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_variable_value(2)
>>> print(variable)
2

class mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)

SymbolTree stores information about a network, including statements of the network’s forward computation process and the topological relationship between statement input and output.

The statements of the network are stored in the SymbolTree as nodes. By processing these nodes, you can delete, insert, and replace statements of the network, and then obtain the modified network code and network instance.

Parameters

handler (SymbolTreeImpl) – Internal implementation instance of SymbolTree. It is recommended to create a SymbolTree with the create method rather than by calling the SymbolTree constructor directly. Users do not need to know what SymbolTreeImpl is; just treat it as an opaque handle.

after(node: Union[Node, str])

Return the position immediately after node. The return value is used as the position argument of insert.

Parameters

node (Union[Node, str]) – The node, or the name of the node, after which the position is located.

Returns

A Position indicating where to insert a node.

Raises

TypeError – If node is neither a Node nor a str.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.after(node)

before(node: Union[Node, str])

Return the position immediately before node. The return value is used as the position argument of insert.

Parameters

node (Union[Node, str]) – The node, or the name of the node, before which the position is located.

Returns

A Position indicating where to insert a node.

Raises

TypeError – If node is neither a Node nor a str.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.before(node)

classmethod create(network)

Create a SymbolTree from the network instance network.

This interface parses the network instance, expands each source code statement of the forward computation process, and parses the statements into nodes, which are stored in the SymbolTree.

Parameters

network (Cell) – The network used to create the SymbolTree.

Returns

SymbolTree, a SymbolTree created from network.

Raises

TypeError – If network is not a Cell instance.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> print(type(stree))
<class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>

erase(node: Union[Node, str])

Erase a node from the SymbolTree.

Parameters

node (Union[Node, str]) – The node to be erased, or its name.

Returns

The erased Node if node is in the SymbolTree; otherwise, None.

Raises

TypeError – If node is neither a Node nor a str.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> stree.erase(node)

get_code()

Get the source code of the network stored in the SymbolTree. If the network has been modified, the source code of the modified network is returned.

Returns

A str containing the source code of the network.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> codes = stree.get_code()
>>> print(codes)

get_network()

Get the network object generated from the SymbolTree. The generated source code is saved to a file in the ‘rewritten_network’ folder under the current directory.

Returns

A network object generated from SymbolTree.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> new_net = stree.get_network()

get_node(node_name: str)

Get the node with the name node_name in the SymbolTree.

Parameters

node_name (str) – The name of node.

Returns

The node named node_name, or None if there is no node named node_name in the SymbolTree.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node('conv1')
>>> print(node.get_name())
conv1

insert(position, node: Node)

Insert a node into SymbolTree at position.

position is obtained from the before or after method of SymbolTree.

Parameters
  • position (Position) – Where to insert node.

  • node (Node) – The Node instance to be inserted.

Returns

The inserted Node instance.

Raises

Examples

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.insert(position, new_node)
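The before/after/insert workflow can be pictured with a toy tree that keeps only an ordered list of node names. ToySymbolTree is hypothetical; its positions are plain list indexes, which shift whenever nodes are added or removed, so the sketch computes a position right before using it. Name-uniqueness handling and argument wiring are omitted.

```python
from typing import List, Optional


class ToySymbolTree:
    """Hypothetical stand-in: a SymbolTree reduced to an ordered list of node names."""

    def __init__(self, names: List[str]):
        self.names = list(names)

    def before(self, name: str) -> int:
        # A "position" here is simply the list index at which a new node would go.
        return self.names.index(name)

    def after(self, name: str) -> int:
        return self.names.index(name) + 1

    def insert(self, position: int, name: str) -> str:
        self.names.insert(position, name)
        return name

    def erase(self, name: str) -> Optional[str]:
        # Mirrors the documented contract: return the erased node, or None if absent.
        if name not in self.names:
            return None
        self.names.remove(name)
        return name


tree = ToySymbolTree(["input_x", "conv1", "relu", "return"])
tree.insert(tree.after("conv1"), "new_relu")
print(tree.names)             # ['input_x', 'conv1', 'new_relu', 'relu', 'return']
print(tree.erase("missing"))  # None
```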

nodes()

Get a generator over the nodes in the current SymbolTree, used to iterate through them.

Returns

A generator of the nodes in the current SymbolTree.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> print([node.get_name() for node in stree.nodes()])
['input_x', 'Expr', 'conv1', 'relu', 'max_pool2d', 'conv2', 'relu_1', 'max_pool2d_1',
 'flatten', 'fc1', 'relu_2', 'fc2', 'relu_3', 'fc3', 'return']

print_node_tabulate()

Print the topology information of the nodes in the SymbolTree, including node type, node name, node code, and the input-output relationships between nodes. The information is printed to the screen.

Warning

This is an experimental API that is subject to change or deletion.

replace(old_node: Node, new_nodes: [Node])

Replace old_node with the nodes in new_nodes.

The nodes in new_nodes are inserted into the SymbolTree sequentially, and then old_node is deleted.

Note

  • replace supports one-to-one and one-to-many replacement. For many-to-many replacement, refer to PatternEngine.

  • The caller is responsible for maintaining the topological relationships among the nodes in new_nodes, as well as between the nodes in new_nodes and the nodes in the original tree.

Parameters
  • old_node (Node) – The node to be replaced.

  • new_nodes (list[Node]) – The nodes that replace old_node.

Returns

A Node representing the root of the node tree that replaced old_node.

Raises
  • RuntimeError – If old_node is not isolated.

  • TypeError – If old_node is not a Node.

  • TypeError – If new_nodes is not a list or node in new_nodes is not a Node.

Examples

>>> from mindspore.rewrite import SymbolTree, ScopedValue
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
>>> stree.replace(node, [new_node])
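The documented replace semantics (insert the new nodes in order, then delete old_node) are sketched below on a plain list of node names. toy_replace is a hypothetical helper; it only reorders names and does not rewire arguments, which, as the note for replace states, remains the caller's responsibility.

```python
from typing import List


def toy_replace(names: List[str], old_name: str, new_names: List[str]) -> List[str]:
    """Hypothetical sketch of replace(): insert new nodes in order, then drop the old one."""
    if not isinstance(new_names, list):
        raise TypeError("new_names must be a list")
    result = list(names)
    position = result.index(old_name)
    for offset, new_name in enumerate(new_names):
        # Each new node goes in front of old_name, keeping new_names in order.
        result.insert(position + offset, new_name)
    # old_name has shifted right by len(new_names); delete it by index.
    del result[position + len(new_names)]
    return result


print(toy_replace(["input_x", "conv1", "relu", "return"], "conv1", ["conv1_a", "conv1_b"]))
# ['input_x', 'conv1_a', 'conv1_b', 'relu', 'return']
```

Deleting by index rather than by value keeps the sketch correct even if a new node happens to reuse the old node's name.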

unique_name(name: str = 'output')

Return a new name, based on name, that is unique within the SymbolTree. This interface can be used when a variable name that does not conflict with existing names is required.

Parameters

name (str, optional) – The prefix of the name. Default: "output".

Returns

str, a new name that is unique within the SymbolTree, in the format name_n, where n is a numeric suffix. If name does not conflict with any existing name, it is returned without a numeric suffix.

Raises

TypeError – If name is not a str.
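A minimal sketch of the naming rule described for unique_name, assuming suffixes start at 1 (consistent with the relu, relu_1, relu_2, relu_3 node names printed for LeNet5 in the nodes() example). toy_unique_name is hypothetical and takes the existing names explicitly; it is not the library implementation.

```python
from typing import Iterable


def toy_unique_name(existing: Iterable[str], name: str = "output") -> str:
    """Hypothetical sketch: return name, or name_n with the smallest free n >= 1."""
    if not isinstance(name, str):
        raise TypeError("name must be a str")
    taken = set(existing)
    if name not in taken:
        return name  # no conflict: no numeric suffix
    n = 1
    while f"{name}_{n}" in taken:
        n += 1
    return f"{name}_{n}"


nodes = ["input_x", "conv1", "relu", "relu_1", "relu_2", "relu_3", "return"]
print(toy_unique_name(nodes, "relu"))   # relu_4
print(toy_unique_name(nodes, "conv2"))  # conv2
print(toy_unique_name(nodes))           # output
```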

class mindspore.rewrite.ValueType

ValueType represents the type of a ScopedValue.

  • A NamingValue represents a reference to another variable.

  • A CustomObjValue represents an instance of a custom class, or an object whose type is neither a base type nor a container type of ValueType.