mindspore.rewrite
- class mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)
A SymbolTree usually corresponds to the forward method of a network.
- Parameters
network (Cell) – Network to be rewritten. Only Cell-type networks are currently supported.
- Raises
RuntimeError – If network is not a Cell.
RuntimeError – If an unsupported AST node type is encountered while parsing or optimizing.
- after(node: Node)
Get the insert position after the input node.
A Position indicates where to insert a node. It refers to a position in the source code rather than a position in topological order. Callers do not need to know what a Position contains; treat it as a handle and pass it as an argument to the insert API of SymbolTree.
- before(node: Node)
Get the insert position before the input node.
A Position indicates where to insert a node. It refers to a position in the source code rather than a position in topological order. Callers do not need to know what a Position contains; treat it as a handle and pass it as an argument to the insert API of SymbolTree.
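Because a Position is only an opaque handle into source order, its behaviour can be pictured with a small toy model. The sketch below is hypothetical: `Position`, `ToyTree` and the node names are stand-ins, not MindSpore classes, and the toy tracks only source-code order, not data dependencies.

```python
from dataclasses import dataclass


# Hypothetical stand-in, not MindSpore's Position: a handle that records a
# reference line and whether to insert before or after it.
@dataclass(frozen=True)
class Position:
    ref: str      # name of the reference node
    before: bool  # True: insert before ref; False: insert after ref


class ToyTree:
    """Keeps node names in source-code order only (no topology)."""

    def __init__(self, names):
        self.lines = list(names)

    def before(self, ref):
        return Position(ref, True)

    def after(self, ref):
        return Position(ref, False)

    def insert(self, position, name):
        idx = self.lines.index(position.ref)  # locate the reference line
        if not position.before:
            idx += 1                          # "after" means the next slot
        self.lines.insert(idx, name)
        return name


tree = ToyTree(["input_x", "conv", "relu", "output"])
tree.insert(tree.after("conv"), "bn")
tree.insert(tree.before("conv"), "pad")
print(tree.lines)  # ['input_x', 'pad', 'conv', 'bn', 'relu', 'output']
```

The caller never inspects the Position; it only obtains one from before or after and hands it back to insert.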
- erase_node(node: Node)
Erase a node from the SymbolTree. Only a node that no other node depends on can be erased.
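The "not depended on" rule can be sketched with a toy def-use graph: a node may be erased only when no other node reads it. The graph layout, helper names and the choice of `RuntimeError` are assumptions of this sketch; the section above does not say which exception Rewrite raises.

```python
# Hypothetical toy graph: each node name maps to the names it reads as arguments.
graph = {
    "conv": [],
    "bn": ["conv"],
    "relu": ["bn"],
    "dropout": ["conv"],  # dead branch: nothing reads "dropout"
}


def users_of(name, graph):
    """Return the nodes whose arguments reference `name`."""
    return [n for n, args in graph.items() if name in args]


def erase_node(name, graph):
    """Erase `name` only if no other node depends on it."""
    users = users_of(name, graph)
    if users:
        # Toy choice of exception; Rewrite's actual exception is not documented here.
        raise RuntimeError(f"cannot erase {name!r}: still used by {users}")
    del graph[name]


erase_node("dropout", graph)  # allowed: no node depends on it
try:
    erase_node("bn", graph)   # rejected: "relu" still reads "bn"
except RuntimeError as err:
    print(err)  # cannot erase 'bn': still used by ['relu']
```

To remove a node that still has users, the users would first have to be rewired (for example with set_arg) so that nothing reads it.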
- get_code()
Get the source code of the modified network.
- Returns
A str representing the source code of the modified network.
- get_inputs()
Get the input nodes of the current SymbolTree.
- Returns
[Node], the list of input nodes of the current SymbolTree.
- get_network()
Get the modified network. The source code of the network is saved to a file whose default name is network_define.py.
- Returns
A network object.
- insert(position, node: Node)
Insert a node into the SymbolTree at position.
position is obtained from the before or after API of SymbolTree.
- Parameters
position (Position) – Indicate where to insert node.
node (Node) – An instance of Node to be inserted.
- Returns
The inserted Node instance. node may be changed by this method to ensure uniqueness and to handle custom objects in args or kwargs.
- Raises
RuntimeError – If position does not belong to the current SymbolTree.
TypeError – If position is not a Position.
TypeError – If node is not a Node.
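The Returns entry notes that insert may change node to keep names unique. A toy sketch of that idea follows; `ToyNode`, `unique` and the `_1`, `_2` suffix scheme are assumptions for illustration, not the renaming Rewrite actually performs.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    name: str
    targets: list = field(default_factory=list)


def unique(base, taken):
    """Return `base`, or `base_1`, `base_2`, ... if already taken (assumed scheme)."""
    if base not in taken:
        return base
    i = 1
    while f"{base}_{i}" in taken:
        i += 1
    return f"{base}_{i}"


def insert(nodes, index, node):
    """Insert `node` at `index`, renaming its name and targets on collision."""
    names = {n.name for n in nodes}
    outputs = {t for n in nodes for t in n.targets}
    # Return the node actually inserted, which may differ from the argument.
    new = ToyNode(unique(node.name, names),
                  [unique(t, outputs) for t in node.targets])
    nodes.insert(index, new)
    return new


nodes = [ToyNode("conv", ["x"])]
inserted = insert(nodes, 1, ToyNode("conv", ["x"]))
print(inserted)  # ToyNode(name='conv_1', targets=['x_1'])
```

In this sketch the argument object is left untouched, so the caller has to use the returned node to see the final name and targets.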
- nodes()
Get a generator over the nodes of the corresponding network.
- Returns
A generator of the nodes of the current SymbolTree.
- replace(old_node: Node, new_nodes: [Node])
Replace old_node with a node_tree.
Note
Replace supports one-to-one and one-to-many replacement. For many-to-many replacement, use PatternEngine.
For one-to-many replacement, Rewrite inserts all of new_nodes into the SymbolTree.
The caller must maintain the arguments and targets of the nodes inside the sub-tree to specify the topological relations within the sub-tree.
The caller must maintain the arguments of the input nodes of the sub-tree to specify the topological relations of the sub-tree's inputs.
Rewrite maintains the arguments of the prepend node of the sub-tree to specify the topological relations of the sub-tree's outputs.
Rewrite maintains all inputs of nodes after new_nodes are placed into the SymbolTree.
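- Parameters
old_node (Node) – Node to be replaced.
new_nodes (list[Node]) – Nodes of the node_tree that replaces old_node.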
- Returns
An instance of Node representing the root of the node_tree that was placed into the SymbolTree.
- Raises
RuntimeError – If old_node is isolated.
TypeError – If old_node is not a Node.
TypeError – If new_nodes is not a list or node in new_nodes is not a Node.
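The split of responsibilities in the note can be illustrated with a toy one-to-many replacement: the caller wires new_nodes to each other and to the old inputs, and the toy then rewires later nodes that read the old node's targets. Everything here is hypothetical; which node Rewrite treats as the root of the node_tree, and exactly how it rewires users, is not specified above. The toy simply uses the last new node for both.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    name: str
    targets: list
    args: list = field(default_factory=list)


def replace(nodes, old_name, new_nodes):
    """One-to-many replacement in a toy node list kept in source order."""
    idx = next(i for i, n in enumerate(nodes) if n.name == old_name)
    old = nodes[idx]
    # Caller's side: new_nodes are already wired to each other and the
    # first one reads old's inputs.
    # Toy "Rewrite" side: later nodes that read old's targets are rewired to
    # the targets of the last new node.
    rename = dict(zip(old.targets, new_nodes[-1].targets))
    for n in nodes[idx + 1:]:
        n.args = [rename.get(a, a) for a in n.args]
    nodes[idx:idx + 1] = new_nodes
    return new_nodes[-1]


nodes = [ToyNode("conv", ["x1"], ["x"]), ToyNode("relu", ["x2"], ["x1"])]
root = replace(nodes, "conv", [ToyNode("conv_a", ["y"], ["x"]),
                               ToyNode("conv_b", ["z"], ["y"])])
print([n.name for n in nodes], nodes[-1].args)  # ['conv_a', 'conv_b', 'relu'] ['z']
```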
- save_network_to_file()
Save the modified network to a file. The default file name is network_define.py.
- class mindspore.rewrite.Node(node: NodeImpl)
Node is a data structure that represents a source code line in a network.
In most cases, a Node represents an operator invocation in the forward method; the operator can be an instance of Cell, an instance of Primitive, or a callable method.
NodeImpl, mentioned below, is the implementation of Node and is not part of the Rewrite interface. Rewrite recommends instantiating a Node through a specific create method, such as create_call_cell, rather than calling the Node constructor directly, so callers need not care what NodeImpl is and can treat its instance simply as a handle.
- Parameters
node (NodeImpl) – A handler of NodeImpl.
- static create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, kwargs: {str: ScopedValue} = None, name: str = '', is_sub_net: bool = False)
Create a node. Currently only creation from a Cell is supported.
The node corresponds to source code such as:
`targets` = self.`name`(*`args`, **`kwargs`)
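The template above can be made concrete with a tiny formatter that renders the assign statement a CallCell node stands for. This is a hypothetical helper, not part of Rewrite, and it takes plain strings where create_call_cell takes ScopedValue instances.

```python
import ast


def render_call_cell(targets, name, args=None, kwargs=None):
    """Format `targets = self.name(*args, **kwargs)` (toy formatter)."""
    parts = list(args or [])
    parts += [f"{k}={v}" for k, v in (kwargs or {}).items()]
    return f"{', '.join(targets)} = self.{name}({', '.join(parts)})"


line = render_call_cell(["out", "state"], "lstm", ["x"], {"hx": "h0"})
print(line)      # out, state = self.lstm(x, hx=h0)
ast.parse(line)  # the rendered line is valid Python syntax
```

Here name becomes the attribute on self, which is why Rewrite has to keep it unique within the network.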
- Parameters
cell (Cell) – Cell operator of this forward layer.
targets (list[ScopedValue]) – Output names, used as the targets of an assign statement in source code. Rewrite checks and ensures the uniqueness of each target when the node is inserted.
args (list[ScopedValue]) – Input names, used as the args of the call expression of an assign statement in source code. Default is None, meaning the cell has no positional inputs. Rewrite checks and ensures the uniqueness of each arg when the node is inserted.
kwargs (dict) – Keys must be str and values must be ScopedValue. Keyword input names, used as the kwargs of the call expression of an assign statement in source code. Default is None, meaning the cell has no keyword inputs. Rewrite checks and ensures the uniqueness of each kwarg when the node is inserted.
name (str) – Name of the node, used as the field name in source code. Default is None. Rewrite generates a name from targets when name is None. Rewrite checks and ensures the uniqueness of name when the node is inserted.
is_sub_net (bool) – Whether cell is a network. If is_sub_net is True, Rewrite tries to parse the cell into a TreeNode; otherwise it creates a CallCell Node. Default is False.
- Returns
An instance of Node.
- Raises
- get_args()
Get the arguments of the current node.
When node_type of the current node is CallCell, CallPrimitive or Tree, the arguments correspond to the args of the ast.Call that invokes the forward method of the cell-op or primitive-op.
When node_type of the current node is Input, the arguments represent the default value of the function parameter.
When node_type of the current node is Output, the arguments represent the return values.
When node_type of the current node is Python, the arguments are meaningless and should be ignored.
- Returns
A list of instances of ScopedValue.
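Since the arguments are described in terms of Python AST nodes, the standard-library ast module shows which fields are meant: ast.Call.args for positional arguments, ast.Call.keywords for keyword arguments, and the function's defaults for an Input node. This uses only Python's own ast module (ast.unparse needs Python 3.9 or later), not Rewrite.

```python
import ast

# A CallCell-style statement: targets = self.<name>(*args, **kwargs)
stmt = ast.parse("y = self.conv(x, w, pad=p)").body[0]
call = stmt.value  # the ast.Call on the right-hand side
args = [ast.unparse(a) for a in call.args]                         # ['x', 'w']
kwargs = {kw.arg: ast.unparse(kw.value) for kw in call.keywords}  # {'pad': 'p'}

# For an Input node, the relevant value is the parameter's default.
fn = ast.parse("def construct(self, x, scale=2.0):\n    return x").body[0]
defaults = [ast.unparse(d) for d in fn.args.defaults]              # ['2.0']
print(args, kwargs, defaults)
```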
- get_attributes()
Get all attributes of the current node.
- Returns
A dict mapping attribute names (str) to attribute objects.
- get_inputs()
Get the input nodes of the current node in topological order.
- Returns
A list of instances of Node as input nodes.
- get_instance()
Get the instance of the current node.
When node_type of the current node is CallCell, instance is an instance of Cell.
When node_type of the current node is CallPrimitive, instance is an instance of Primitive.
When node_type of the current node is Tree, instance is an instance of the network Cell.
When node_type of the current node is Python, Input, Output or CallMethod, instance should be None.
- Returns
An object representing the instance corresponding to the current node.
- get_instance_type()
Get the instance_type of the current node.
When node_type of the current node is CallCell, instance_type is the type of the cell-op.
When node_type of the current node is CallPrimitive, instance_type is the type of the primitive-op.
When node_type of the current node is Tree, instance_type is the type of the network Cell.
When node_type of the current node is Python, Input, Output or CallMethod, instance_type should be NoneType.
- Returns
A type object representing the instance type of the current node.
- get_kwargs()
Get the keyword arguments of the current node.
When node_type of the current node is CallCell, CallPrimitive or Tree, the keyword arguments correspond to the kwargs of the ast.Call that invokes the forward method of the cell-op or primitive-op.
When node_type of the current node is Python, Input or Output, the keyword arguments are meaningless and should be ignored.
- Returns
A dict mapping str to ScopedValue instances.
- get_name()
Get the name of the current node.
Once the node has been inserted into a SymbolTree, its name should be unique within that SymbolTree.
- Returns
A string, the name of the node.
- get_next()
Get the next node of the current node in source-code order.
- Returns
The next node, as an instance of Node.
- get_prev()
Get the previous node of the current node in source-code order.
- Returns
The previous node, as an instance of Node.
- get_targets()
Get the targets of the current node.
When node_type of the current node is CallCell, CallPrimitive, CallMethod or Tree, the targets are strings representing the results of invoking the cell-op, primitive-op or function call, and correspond to the targets of ast.Assign.
When node_type of the current node is Input, targets should have exactly one element, a string representing the function parameter.
When node_type of the current node is Python or Output, the targets are meaningless and should be ignored.
- Returns
A list of ScopedValue instances, the targets of the node.
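The ast.Assign targets mentioned above can be inspected the same way with Python's standard ast module. The sketch shows how a tuple-unpacking assignment yields several target names and how an Input node's single target is a parameter name; it illustrates plain Python AST, not Rewrite internals.

```python
import ast

stmt = ast.parse("out, state = self.lstm(x)").body[0]
target = stmt.targets[0]  # tuple unpacking yields one ast.Tuple target
if isinstance(target, ast.Tuple):
    names = [ast.unparse(t) for t in target.elts]
else:
    names = [ast.unparse(target)]
print(names)  # ['out', 'state']

# For an Input node, the single target is the parameter name.
fn = ast.parse("def construct(self, x):\n    return x").body[0]
params = [a.arg for a in fn.args.args[1:]]  # skip self
print(params)  # ['x']
```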
- get_users()
Get the output nodes (users) of the current node in topological order.
- Returns
A list of nodes representing the users.
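The difference between topological relations (get_inputs, get_users) and source-code order (get_next, get_prev) can be modelled by linking each argument back to the node whose target produces it. This toy uses plain tuples and hypothetical helper names; it is not how Rewrite stores nodes.

```python
# Hypothetical toy: nodes in source order as (name, targets, args).
nodes = [
    ("input_x", ["x"], []),
    ("conv", ["x1"], ["x"]),
    ("bn", ["x2"], ["x1"]),
    ("add", ["x3"], ["x1", "x2"]),
    ("output", [], ["x3"]),
]
# Map each produced value to the node that produces it.
producer = {t: name for name, targets, _ in nodes for t in targets}


def get_inputs(name):
    """Nodes whose targets feed this node's arguments."""
    args = next(a for n, _, a in nodes if n == name)
    return [producer[a] for a in args if a in producer]


def get_users(name):
    """Nodes that read any target of this node."""
    outs = set(next(t for n, t, _ in nodes if n == name))
    return [n for n, _, args in nodes if outs & set(args)]


print(get_inputs("add"))  # ['conv', 'bn']
print(get_users("conv"))  # ['bn', 'add']
```

Note that "add" is a user of "conv" even though it is not the next node in source order.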
- set_arg(index: int, arg: Union[ScopedValue, str])
Set an argument of the current node.
- Parameters
index (int) – Index of the argument to modify.
arg (Union[ScopedValue, str]) – New argument to set.
- Raises
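- set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)
Set an argument of the current node to an output of another Node.
- Parameters
arg_idx (int) – Index of the argument to set.
src_node (Node) – Node whose output is used as the new argument.
out_idx (Optional[int]) – Index of the output of src_node to use. Default is None. Required when src_node has multiple outputs.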
- Raises
RuntimeError – If src_node does not belong to the current SymbolTree.
RuntimeError – If the current node and src_node do not belong to the same SymbolTree.
TypeError – If arg_idx is not an int.
ValueError – If arg_idx is out of range.
TypeError – If src_node is not a Node instance.
TypeError – If out_idx is not an int.
ValueError – If out_idx is out of range.
ValueError – If src_node has multiple outputs and out_idx is None or not provided.
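The out_idx rules above can be sketched on plain lists: the source node's targets stand in for its outputs, and the checks follow the order of the Raises list where they apply. The function is a hypothetical toy, not Rewrite's implementation; it skips the SymbolTree-membership checks because the toy has no SymbolTree.

```python
def set_arg_by_node(args, arg_idx, src_targets, out_idx=None):
    """Toy version: set args[arg_idx] to one output of a source node."""
    if not isinstance(arg_idx, int):
        raise TypeError("arg_idx must be an int")
    if not 0 <= arg_idx < len(args):
        raise ValueError("arg_idx is out of range")
    if out_idx is None:
        if len(src_targets) > 1:
            raise ValueError("src_node has multiple outputs; out_idx is required")
        out_idx = 0
    elif not isinstance(out_idx, int):
        raise TypeError("out_idx must be an int")
    if not 0 <= out_idx < len(src_targets):
        raise ValueError("out_idx is out of range")
    args[arg_idx] = src_targets[out_idx]


args = ["x", "y"]
set_arg_by_node(args, 1, ["out", "state"], out_idx=1)
print(args)  # ['x', 'state']
```

With a single-output source node, out_idx can be omitted and output 0 is used.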
- class mindspore.rewrite.NodeType
NodeType represents the type of a Node.
Unknown: an uninitialized NodeType.
CallCell: a CallCell node represents a cell-op invocation in the forward method.
CallPrimitive: a CallPrimitive node represents a primitive-op invocation in the forward method.
CallMethod: a CallMethod node represents a method invocation in the forward method that cannot be mapped to a cell-op or primitive-op in MindSpore.
Python: a Python node holds an AST node that is unsupported or does not need to be parsed.
Input: an Input node represents an input of the SymbolTree, corresponding to an argument of the forward method.
Output: an Output node represents the output of the SymbolTree, corresponding to the return statement of the forward method.
Tree: a Tree node represents a sub-network invocation in the forward method.
- class mindspore.rewrite.ScopedValue(arg_type: ValueType, scope: str = '', value=None)
ScopedValue represents a value together with its full scope.
ScopedValue is used to express a left value, such as the target of an assign statement; a callable object, such as the func of a call statement; or a right value, such as the args and kwargs of an assign statement.
- Parameters
arg_type (ValueType) – Type of the value.
scope (str) – Scope of the value. Default is ''.
value – The value itself. Default is None.
- static create_name_values(names: Union[list, tuple], scopes: Union[list, tuple] = None)
Create a list of naming ScopedValues.
- Parameters
names (list[str] or tuple[str]) – List or tuple of str representing the names of referenced variables.
scopes (list[str] or tuple[str]) – List or tuple of str representing the scopes of referenced variables.
- Returns
A list of ScopedValue instances.
- Raises
RuntimeError – If scopes is not None and its length differs from the length of names.
TypeError – If names is not a list or tuple, or if any name in names is not a str.
TypeError – If scopes is not a list or tuple, or if any scope in scopes is not a str.
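The documented contract of create_name_values (type checks, a length check when scopes is given, one value per name) can be mirrored in a short sketch. `ToyScopedValue` is a hypothetical stand-in that stores only scope and name, and defaulting to an empty scope is an assumption based on ScopedValue's `scope: str = ''` signature.

```python
from typing import NamedTuple


class ToyScopedValue(NamedTuple):
    """Hypothetical stand-in for a naming ScopedValue: (scope, name)."""
    scope: str
    value: str


def _all_str(seq):
    return isinstance(seq, (list, tuple)) and all(isinstance(s, str) for s in seq)


def create_name_values(names, scopes=None):
    """Mirror the documented checks, then pair each name with its scope."""
    if not _all_str(names):
        raise TypeError("names must be a list or tuple of str")
    if scopes is None:
        scopes = [""] * len(names)  # assumed empty scope, as in ScopedValue's default
    elif not _all_str(scopes):
        raise TypeError("scopes must be a list or tuple of str")
    elif len(scopes) != len(names):
        raise RuntimeError("names and scopes must have the same length")
    return [ToyScopedValue(s, n) for n, s in zip(names, scopes)]


print(create_name_values(["x", "weight"], ["", "self"]))
# [ToyScopedValue(scope='', value='x'), ToyScopedValue(scope='self', value='weight')]
```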
- class mindspore.rewrite.ValueType
ValueType represents the type of a ScopedValue.
A NamingValue represents a reference to another variable.
A CustomObjValue represents an instance of a custom class, or an object whose type falls outside the base types and container types of ValueType.
- class mindspore.rewrite.PatternEngine(pattern: Union[PatternNode, List], replacement: Replacement = None)
PatternEngine defines how to transform a SymbolTree by PatternNode.
- Parameters
pattern (Union[PatternNode, List]) – An instance of PatternNode, or a list of cell types from which a PatternNode is constructed as the root of a pattern.
replacement (callable) – A callable that defines how to generate new_node.
- apply(stree: SymbolTree)
Apply the current pattern to a SymbolTree.
Note
Sub-tree nodes will be supported in the near future.
- Parameters
stree (SymbolTree) – A SymbolTree to be transformed.
- Returns
A bool indicating whether stree has been changed.
- Raises
TypeError – If stree is not a SymbolTree instance.
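For a chain pattern, applying a pattern amounts to scanning the node sequence for consecutive runs whose types match the pattern and handing each run to the replacement. The sketch below does this over a plain list of type names; `apply_chain_pattern` and the "FusedConvBN" replacement are hypothetical, and PatternEngine itself works on SymbolTree nodes rather than strings.

```python
def apply_chain_pattern(ops, pattern, make_replacement):
    """Replace each non-overlapping run of `pattern` in `ops`, scanning left to right.

    Returns (new_ops, changed); `changed` mirrors apply's bool result.
    """
    if not pattern:
        raise ValueError("pattern must not be empty")
    k = len(pattern)
    out, i, changed = [], 0, False
    while i < len(ops):
        if ops[i:i + k] == pattern:
            out.extend(make_replacement(ops[i:i + k]))  # replacement gets the matched run
            i += k
            changed = True
        else:
            out.append(ops[i])
            i += 1
    return out, changed


ops = ["Conv2d", "BatchNorm2d", "ReLU", "Conv2d", "BatchNorm2d"]
new_ops, changed = apply_chain_pattern(ops, ["Conv2d", "BatchNorm2d"],
                                       lambda run: ["FusedConvBN"])
print(new_ops, changed)  # ['FusedConvBN', 'ReLU', 'FusedConvBN'] True
```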
- class mindspore.rewrite.PatternNode(pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None)
PatternNode represents a node in a pattern definition.
- Parameters
pattern_node_name (str) – Name of the current node.
match_type (Type) – The type that the current node matches.
inputs (list[PatternNode]) – Input nodes of the current node.
- add_input(node)
Add an input to the current PatternNode.
- Parameters
node (PatternNode) – A PatternNode to add as an input.
- Raises
TypeError – If node is not a PatternNode instance.
- set_inputs(inputs)
Set the inputs of the current PatternNode.
- Parameters
inputs (list[PatternNode]) – PatternNodes to set as the inputs of the current PatternNode.
- Raises
TypeError – If inputs is not a list or any element of inputs is not a PatternNode instance.
- class mindspore.rewrite.VarNode
VarNode is a subclass of PatternNode whose match method always returns True.
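A tree pattern built from PatternNode-like objects, with VarNode-like wildcards at the leaves, can be matched recursively against a small graph. The classes and `match_tree` below are hypothetical stand-ins; the recorded OrderedDict follows the shape described for Replacement.build (pattern_node name to matched node), but how PatternEngine actually traverses nodes is not specified here.

```python
from collections import OrderedDict


class ToyPatternNode:
    """Hypothetical stand-in for PatternNode: a name, a type to match, inputs."""

    def __init__(self, name, match_type=None, inputs=None):
        self.name = name
        self.match_type = match_type
        self.inputs = []
        self.set_inputs(inputs or [])

    def add_input(self, node):
        if not isinstance(node, ToyPatternNode):
            raise TypeError("node must be a ToyPatternNode")
        self.inputs.append(node)

    def set_inputs(self, inputs):
        if not isinstance(inputs, list) or not all(
                isinstance(i, ToyPatternNode) for i in inputs):
            raise TypeError("inputs must be a list of ToyPatternNode")
        self.inputs = list(inputs)

    def match(self, op_type):
        return op_type == self.match_type


class ToyVarNode(ToyPatternNode):
    """Like VarNode: match always returns True."""

    def match(self, op_type):
        return True


def match_tree(pattern, graph, node, matched):
    """graph maps node name -> (op_type, [input node names]).

    On failure `matched` may hold partial entries and should be discarded.
    """
    op_type, inputs = graph[node]
    if not pattern.match(op_type):
        return False
    matched[pattern.name] = node
    if not pattern.inputs:  # leaf of the pattern: stop here
        return True
    if len(pattern.inputs) != len(inputs):
        return False
    return all(match_tree(p, graph, i, matched)
               for p, i in zip(pattern.inputs, inputs))


graph = {
    "x": ("Input", []),
    "conv": ("Conv2d", ["x"]),
    "bn": ("BatchNorm2d", ["conv"]),
    "add": ("Add", ["bn", "x"]),
}
bn_p = ToyPatternNode("bn", "BatchNorm2d", [ToyVarNode("bn_in")])
add_p = ToyPatternNode("add", "Add")
add_p.add_input(bn_p)
add_p.add_input(ToyVarNode("skip"))
matched = OrderedDict()
ok = match_tree(add_p, graph, "add", matched)
print(ok, dict(matched))  # True {'add': 'add', 'bn': 'bn', 'bn_in': 'conv', 'skip': 'x'}
```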
- class mindspore.rewrite.Replacement
Interface of the replacement function.
- abstract build(pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict)
Interface for creating replacement nodes from a match result.
Note
The return value is passed as an argument to the replace API of SymbolTree, so it must satisfy the constraints on the new_nodes parameter of that API. See the docstring of the replace API of SymbolTree for details.
- Parameters
pattern (PatternNode) – A PatternNode represents root node of current pattern.
is_chain_pattern (bool) – A bool indicating whether pattern is a chain pattern or a tree pattern.
matched (OrderedDict) – An OrderedDict mapping pattern_node names to nodes, representing the match result.
- Returns
A list of Node instances as replacement nodes.
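Replacement is an abstract interface whose build method turns a match result into the new_nodes list consumed by SymbolTree.replace. The sketch mirrors that contract with Python's abc module; the dictionary-based nodes, the `FuseConvBN` subclass and the keys "conv" and "bn" (assumed pattern_node names) are hypothetical.

```python
from abc import ABC, abstractmethod
from collections import OrderedDict


class ToyReplacement(ABC):
    """Hypothetical stand-in for the Replacement interface."""

    @abstractmethod
    def build(self, pattern, is_chain_pattern, matched):
        """Return a list of replacement nodes for one match."""


class FuseConvBN(ToyReplacement):
    def build(self, pattern, is_chain_pattern, matched):
        conv, bn = matched["conv"], matched["bn"]
        # The two matched nodes collapse into one fused node that reads conv's
        # inputs and produces bn's targets, so downstream users keep the same names.
        fused = {"name": conv["name"] + "_fused",
                 "args": conv["args"],
                 "targets": bn["targets"]}
        return [fused]  # always a list, as replace's new_nodes expects


matched = OrderedDict(
    conv={"name": "conv1", "args": ["x"], "targets": ["x1"]},
    bn={"name": "bn1", "args": ["x1"], "targets": ["x2"]},
)
new_nodes = FuseConvBN().build(None, True, matched)
print(new_nodes)  # [{'name': 'conv1_fused', 'args': ['x'], 'targets': ['x2']}]
```

Because build is abstract, ToyReplacement itself cannot be instantiated; only subclasses that implement build can.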
- class mindspore.rewrite.TreeNodeHelper
TreeNodeHelper is used to break circular references when getting the symbol_tree from a Tree-type Node.
TreeNodeHelper provides a static method get_sub_tree for getting the symbol_tree from a Tree-type Node.
- static get_sub_tree(node: Node)
Get the symbol_tree from a Tree-type Node.
- Parameters
node (Node) – A Node that may hold a sub-symbol_tree.
- Returns
An instance of SymbolTree representing the sub-symbol_tree. Note that the node’s symbol_tree may be None, in which case this method returns None.
- Raises
RuntimeError – If node’s type is not NodeType.Tree.
TypeError – If node is not a Node instance.