mindspore.rewrite

For a complete ReWrite example, refer to rewrite_example.py. The sample code shows how to create a SymbolTree from a network and how to insert, delete, and replace nodes in the SymbolTree. It also covers modifying a subnet and replacing nodes through pattern matching.

from typing import OrderedDict
import numpy as np

import mindspore
from mindspore import Tensor, export
from mindspore.rewrite import SymbolTree, ScopedValue, Node, NodeType, Replacement, PatternEngine, PatternNode, \
    TreeNodeHelper
import mindspore.nn as nn
import mindspore.ops as ops


class SubNet(nn.Cell):
    """Subnetwork definition.

    Small cell that applies a Conv2d layer followed by a Dense layer.
    ``Net`` stores an instance of it as ``simnet``, and
    ``insert_node_to_subtree`` modifies it through its sub-SymbolTree.
    """
    def __init__(self):
        super().__init__()
        # Dense layer with 32 input and 32 output channels, weights initialised with "ones".
        self.dense = nn.Dense(in_channels=32, out_channels=32, weight_init="ones")
        # NOTE(review): `mean` is created here but construct() never calls it.
        self.mean = ops.ReduceMean(keep_dims=False)
        self.conv1 = nn.Conv2d(1, 1, 1, stride=1)

    def construct(self, x):
        # Forward pass: convolution first, then the dense layer.
        x = self.conv1(x)
        x = self.dense(x)
        return x


class Net(nn.Cell):
    """Network definition.

    Sample network that the rewrite helpers in this file operate on. It
    chains conv1 -> relu -> max_pool2d -> conv2 -> relu -> max_pool2d,
    then the ``SubNet`` instance stored as ``simnet``, then flatten.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, pad_mode='valid')
        self.conv2 = nn.Conv2d(1, 1, 1, pad_mode='valid')
        # relu and max_pool2d are each invoked twice in construct().
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # Nested sub-network; insert_node_to_subtree() looks it up by the name "simnet".
        self.simnet = SubNet()

    def construct(self, x):
        """The forward computing process of networks."""
        # NOTE(review): other helpers in this file find nodes by names such as
        # "conv2", "relu_1" and "x_7", which presumably derive from these
        # statements -- confirm before reordering or renaming anything here.
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.simnet(x)
        x = self.flatten(x)
        return x


def create_stree(network):
    """Build a SymbolTree for ``network``, dump it, and return it.

    Args:
        network: The cell passed to ``SymbolTree.create``.

    Returns:
        The SymbolTree created from ``network``.
    """
    symbol_tree = SymbolTree.create(network)
    # Show the tree once, right after creation.
    symbol_tree.dump()
    return symbol_tree


def insert_node(stree):
    """Insert a new Conv2d node before the node named 'conv2'.

    After the insertion, the first argument of the node named 'relu_1' is
    set to the new node. If the tree has no node named 'conv2', nothing is
    inserted and the tree is left unchanged.

    Args:
        stree: The SymbolTree to modify in place.
    """
    # Bind the name up front: when no 'conv2' node exists the loop body never
    # assigns it, and the check below would otherwise raise UnboundLocalError.
    new_conv_node = None
    for node in stree.nodes():
        if node.get_name() == "conv2":  # Insert a new node before the node named 'conv2'
            position = stree.before(node)
            new_conv = nn.Conv2d(1, 1, 1)
            new_conv_node = Node.create_call_cell(new_conv, targets=['x_1'], name='new_conv',
                                                  args=node.get_args())
            stree.insert(position, new_conv_node)
            break
    if new_conv_node is None:
        # Nothing was inserted, so there is no new node to wire in.
        return
    # Update the input of an existing node with a new node
    for node in stree.nodes():
        if node.get_name() == "relu_1":
            node.set_arg_by_node(0, new_conv_node)
            break


def insert_node_to_subtree(stree):
    """Insert a Conv2d node into the sub-network named 'simnet'.

    Args:
        stree: The SymbolTree whose 'simnet' Tree node is to be modified.
    """
    def _add_conv_after_first_conv(sub_stree: SymbolTree):
        # Put a new Conv2d node right after the first Conv2d node of the sub-tree.
        for candidate in sub_stree.nodes():
            is_conv = candidate.get_instance_type() == nn.Conv2d
            if not is_conv:
                continue
            insert_at = sub_stree.after(candidate)
            extra_conv = nn.Conv2d(1, 1, 1)
            extra_node = Node.create_call_cell(
                extra_conv,
                targets=['x_1'],
                name='new_conv',
                args=[ScopedValue.create_naming_value('x_1')],
            )
            sub_stree.insert(insert_at, extra_node)
            return

    # Only the first Tree-type node called 'simnet' is processed.
    for outer_node in stree.nodes():
        is_simnet = outer_node.get_node_type() == NodeType.Tree and outer_node.get_name() == "simnet"
        if not is_simnet:
            continue
        _add_conv_after_first_conv(TreeNodeHelper.get_sub_tree(outer_node))
        return


def delete_node(stree):
    """Erase the first node whose instance type is nn.Flatten.

    Before erasing, argument 0 of every user of that node is set to "x_7",
    so that no remaining node depends on the node being removed.

    Args:
        stree: The SymbolTree to modify in place.
    """
    for candidate in stree.nodes():
        is_flatten = candidate.get_instance_type() == nn.Flatten
        if not is_flatten:
            continue
        # NOTE(review): "x_7" is hard-coded; presumably it is the variable that
        # feeds the Flatten node in this particular network -- confirm.
        for user in candidate.get_users():
            user.set_arg(0, "x_7")
        stree.erase_node(candidate)
        return


def replace_node(stree):
    """Replace the node named 'conv1' with a newly created Conv2d node.

    Args:
        stree: The SymbolTree to modify in place.
    """
    replacement = Node.create_call_cell(
        nn.Conv2d(1, 1, 1),
        [ScopedValue.create_naming_value("replace_conv")],
        args=[ScopedValue.create_naming_value('x')],
    )
    for candidate in stree.nodes():
        if candidate.get_name() != "conv1":
            continue
        # Keep whatever replace() hands back as the node used from here on.
        replacement = stree.replace(candidate, [replacement])


def pattern_replace(stree):
    """Apply a PatternEngine that matches nn.MaxPool2d and substitutes a Conv2d node.

    Args:
        stree: The SymbolTree the pattern is applied to.
    """
    class _PoolToConv(Replacement):
        """Produces the Conv2d node that takes the place of a matched node."""
        def build(self, pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict) -> [Node]:
            # The matched node is looked up by the name of the pattern's root.
            matched_node: Node = matched.get(pattern.name())
            # The new node reuses the matched node's args and kwargs.
            return [Node.create_call_cell(nn.Conv2d(1, 1, 1), targets=['x1'],
                                          args=matched_node.get_args(),
                                          kwargs=matched_node.get_kwargs(),
                                          name="pattern_conv")]

    class _MaxPoolEngine(PatternEngine):
        """Pattern engine whose pattern is the single cell type nn.MaxPool2d."""
        def __init__(self):
            super().__init__([nn.MaxPool2d], _PoolToConv())

    _MaxPoolEngine().apply(stree)


def get_net(stree):
    """Return the network object produced by ``stree.get_network()``.

    Args:
        stree: The SymbolTree to obtain the network from.
    """
    rebuilt_network = stree.get_network()
    return rebuilt_network


def get_code(stree):
    """Return the source code produced by ``stree.get_code()``.

    Args:
        stree: The SymbolTree to obtain the source code from.
    """
    generated_source = stree.get_code()
    return generated_source


def test_rewrite():
    """Run every rewrite step on a fresh Net, then execute and export the result.

    The generated source code is printed after each modification so the
    effect of every step can be inspected.
    """
    stree = create_stree(Net())
    print(f"origin code: {stree.get_code()}")

    # Step 1: insert a node into the top-level tree.
    insert_node(stree)
    print(f"after inser node code: {stree.get_code()}")

    # Step 2: insert a node into the sub-network.
    insert_node_to_subtree(stree)
    print(f"after inser node to subtree code: {stree.get_code()}")

    # Step 3: delete a node.
    delete_node(stree)
    print(f"after remove node code: {stree.get_code()}")

    # Step 4: replace a single node.
    replace_node(stree)
    print(f"after replace node code: {stree.get_code()}")

    # Step 5: replace nodes through pattern matching.
    pattern_replace(stree)
    print(f"after pattern replace node code: {stree.get_code()}")

    # Run the rewritten network on an all-ones input and export it.
    sample_input = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
    rewritten_net = get_net(stree)
    print(get_code(stree))
    result = rewritten_net(sample_input)
    print("out: ", result)
    export(rewritten_net, sample_input, file_name="new_net", file_format="MINDIR")


# Run the full rewrite demonstration only when this file is executed as a script.
if __name__ == "__main__":
    test_rewrite()

MindSpore Rewrite package. This is an experimental python package that is subject to change or deletion.

class mindspore.rewrite.ArgType[source]

Argument types for sparsify.

  • CSR represents a CSRTensor.

  • COO represents a COOTensor.

  • NONSPARSE represents a non-sparse value.

class mindspore.rewrite.Node(node: NodeImpl)[source]

Node is a data structure that represents a line of source code in a network.

For the most part, a Node represents an operator invocation in the forward method, which could be an instance of Cell, an instance of Primitive or a callable method.

Parameters

node (NodeImpl) – A handler of NodeImpl. NodeImpl is the implementation of Node and is not part of the Rewrite interface. Rewrite recommends instantiating a Node through one of its specific create methods, such as create_call_cell, rather than invoking the constructor of Node directly, so there is no need to care about what NodeImpl is; just treat its instance as a handler.

static create_call_cell(cell: Cell, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, kwargs: {str: ScopedValue} = None, name: str = '', is_sub_net: bool = False)[source]

Create a node. Only support create from a Cell now.

A node is corresponding to source code like:

`targets` = self.`name`(*`args`, **`kwargs`)
Parameters
  • cell (Cell) – Cell-operator of this forward-layer.

  • targets (list[ScopedValue]) – Indicate output names. Used as targets of an assign statement in source code. Rewrite will check and ensure the uniqueness of each target while node being inserted.

  • args (list[ScopedValue]) – Indicate input names. Used as args of a call expression of an assign statement in source code. Rewrite will check and ensure the uniqueness of each arg while node being inserted. Default: None , which indicates the cell has no args inputs.

  • kwargs (dict) – Type of key must be str and type of value must be ScopedValue. Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source code. Rewrite will check and ensure the uniqueness of each kwarg while node being inserted. Default: None , which indicates the cell has no kwargs inputs.

  • name (str) – Indicate the name of node. Used as field name in source code. Rewrite will generate a name from targets when name is not provided. Rewrite will check and ensure the uniqueness of name while node being inserted. Default: "" .

  • is_sub_net (bool) – Indicate whether cell is a network. If is_sub_net is true, Rewrite will try to parse the cell to a TreeNode, otherwise the cell is parsed to a CallCell node. Default: False .

Returns

An instance of Node.

Raises
  • TypeError – If cell is not a Cell.

  • TypeError – If targets is not list.

  • TypeError – If the type of targets is not in [ScopedValue, str].

  • TypeError – If arg in args is not a ScopedValue.

  • TypeError – If key of kwarg is not a str or value of kwarg in kwargs is not a ScopedValue.

get_inputs()[source]

Get input nodes of current node in topological order.

Returns

A list of instances of Node as input nodes.

get_instance_type()[source]

Get the instance_type of current node.

  • When node_type of current node is CallCell, instance_type is type of cell-op.

  • When node_type of current node is CallPrimitive, instance_type is type of primitive-op.

  • When node_type of current node is Tree, instance_type is type of network-cell.

  • When node_type of current node is Python, Input, Output or CallMethod, instance_type should be NoneType.

Returns

A type object represents corresponding instance type of current node.

get_name()[source]

Get the name of current node.

When node has been inserted into SymbolTree, the name of node should be unique in SymbolTree.

Returns

A string as name of node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> name = node.get_name()
get_node_type()[source]

Get the node_type of current node.

Returns

A NodeType as node_type of node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node_type = node.get_node_type()
get_users()[source]

Get output nodes of current node in topological order.

Returns

A list of nodes that are users of the current node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> users = node.get_users()
set_arg(index: int, arg: Union[ScopedValue, str])[source]

Set argument of current node.

Parameters
  • index (int) – Indicate which input being modified.

  • arg (Union[ScopedValue, str]) – New argument to be set.

Raises
  • TypeError – If index is not a int number.

  • TypeError – If the type of arg is not in [ScopedValue, str].

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> node.set_arg(0, "x")
set_arg_by_node(arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None)[source]

Set argument of current node by another Node.

Parameters
  • arg_idx (int) – Indicate which input being modified.

  • src_node (Node) – A Node as new input. Can be a node or name of node.

  • out_idx (int, optional) – Indicate which output of src_node as new input of current node. Default is None which means use first output of src_node as new input.

Raises
  • RuntimeError – If src_node does not belong to current SymbolTree.

  • TypeError – If arg_idx is not a int number.

  • ValueError – If arg_idx is out of range.

  • TypeError – If src_node is not a Node instance.

  • TypeError – If out_idx is not a int number.

  • ValueError – If out_idx is out of range.

  • ValueError – If src_node has multi-outputs while out_idx is None or out_idx is not offered.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> src_node = stree.get_node("conv1")
>>> dst_node = stree.get_node("conv2")
>>> dst_node.set_arg_by_node(0, src_node)
class mindspore.rewrite.NodeType[source]

NodeType represents type of Node.

  • Unknown: Not inited NodeType.

  • CallCell: CallCell node represents invoking cell-op in forward method.

  • CallPrimitive: CallPrimitive node represents invoking primitive-op in forward method.

  • CallMethod: CallMethod node represents invoking of method in forward method which can not be mapped to cell-op or primitive-op in MindSpore.

  • Python: Python node holds unsupported-ast-node or unnecessary-to-parse-ast-node.

  • Input: Input node represents input of SymbolTree corresponding to arguments of forward method.

  • Output: Output node represents output of SymbolTree corresponding to return statement of forward method.

  • Tree: Tree node represents sub-network invoking in forward method.

  • MathOps: MathOps node represents a mathematical operation, such as adding or comparing in forward method.

class mindspore.rewrite.PatternEngine(pattern: Union[PatternNode, List], replacement: Replacement = None)[source]

PatternEngine defines how to transform a SymbolTree by PatternNode.

Parameters
  • pattern (Union[PatternNode, List]) – An instance of PatternNode or a cell-type-list to construct PatternNode as root of a pattern.

  • replacement (callable) – A callable that defines how to generate new_node. Default: None .

apply(stree: SymbolTree)[source]

Apply current pattern to a SymbolTree.

Note

Sub-tree node will be supported in the near future.

Parameters

stree (SymbolTree) – A SymbolTree to be transformed.

Returns

A bool indicating whether stree has been changed.

Raises

TypeError – If stree is not a SymbolTree instance.

pattern()[source]

Getter of pattern.

Returns

An instance of PatternNode, used to indicate the type that the current pattern needs to match.

class mindspore.rewrite.PatternNode(pattern_node_name: str, match_type: Type = Type[None], inputs: ['PatternNode'] = None)[source]

PatternNode is defined as a node while defining pattern.

Parameters
  • pattern_node_name (str) – Name of current node.

  • match_type (Type) – A type represents what type would be matched of current node. Default: Type[None] .

  • inputs (list[PatternNode]) – Input nodes of current node. Default: None .

add_input(node)[source]

Add an input for current PatternNode.

Parameters

node (PatternNode) – Cell type as an input.

Raises

TypeError – If node is not a PatternNode instance.

static create_pattern_from_list(type_list: [])[source]

Create a Pattern from a cell type list.

Parameters

type_list (list[type]) – Input cell type list.

Returns

A PatternNode as root of pattern created from cell type list.

Raises

TypeError – If type_list is not a list.

static create_pattern_from_node(node: Node)[source]

Create a Pattern from node with its inputs.

Parameters

node (Node) – Input rewrite node.

Returns

A PatternNode as root of pattern created from rewrite node.

Raises

TypeError – If node is not a Node instance.

static from_node(node: Node)[source]

Create a PatternNode from node.

Parameters

node (Node) – Input rewrite node.

Returns

A PatternNode created from node.

Raises

TypeError – If node is not a Node instance.

get_inputs()[source]

Getter of inputs.

Returns

A list of PatternNode, the inputs of current node.

match(node: Node)[source]

Check if current PatternNode can match with node.

Parameters

node (Node) – A rewrite node to be matched.

Raises

TypeError – If node is not a Node instance.

name()[source]

Getter of PatternNode name.

set_inputs(inputs)[source]

Set inputs for current PatternNode.

Parameters

inputs (list[PatternNode]) – Inputs to be set as inputs of current PatternNode.

Raises

TypeError – If inputs is not a list or input in inputs is not PatternNode instance.

type()[source]

Getter of PatternNode type.

class mindspore.rewrite.Replacement[source]

Interface of replacement function.

Examples

>>> from mindspore.rewrite import Replacement, Node
>>> import mindspore.nn as nn
>>> class BnReplacement(Replacement):
...     def build(self, pattern, is_chain_pattern: bool, matched):
...         bn_node: Node = matched.get(pattern.name())
...         conv = nn.Conv2d(16, 16, 3)
...         conv_node = Node.create_call_cell(conv, ['x1'], bn_node.get_args(), bn_node.get_kwargs())
...         return [conv_node]
abstract build(pattern: PatternNode, is_chain_pattern: bool, matched: OrderedDict)[source]

Interface define for creating replacement nodes from matched result.

Note

Return value will be delivered into replace api of SymbolTree as argument, return value should follow restraint of parameter new_nodes of replace api of SymbolTree. See details in the docstring of replace api of SymbolTree.

Parameters
  • pattern (PatternNode) – A PatternNode represents root node of current pattern.

  • is_chain_pattern (bool) – A bool indicated if pattern is a chain pattern or a tree pattern.

  • matched (OrderedDict) – An OrderedDict map from pattern_node name to node represents matched result.

Returns

A list of instance of Node as replacement nodes.

class mindspore.rewrite.ScopedValue(arg_type: ValueType, scope: str = '', value=None)[source]

ScopedValue represents a value with its full-scope.

ScopedValue is used to express: a left-value such as target of an assign statement, or a callable object such as func of a call statement, or a right-value such as args and kwargs of an assign statement.

Parameters
  • arg_type (ValueType) – A ValueType represents type of current value.

  • scope (str) – A string represents scope of current value. Take “self.var1” as an example, scope of this var1 is “self”. Default: "" .

  • value – A handler represents value of current value. The type of value is corresponding to arg_type. Default: None .

static create_name_values(names: Union[list, tuple], scopes: Union[list, tuple] = None)[source]

Create a list of naming ScopedValue.

Parameters
  • names (list[str] or tuple[str]) – List or tuple of str represents names of referenced variables.

  • scopes (list[str] or tuple[str]) – List or tuple of str represents scopes of referenced variables. Default: None .

Returns

A list of instances of ScopedValue.

Raises
  • RuntimeError – If the length of names is not equal to the length of scopes when scopes are not None.

  • TypeError – If names is not list or tuple, or name in names is not str.

  • TypeError – If scopes is not list or tuple, or scope in scopes is not str.

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variables = ScopedValue.create_name_values(["z", "z_1"], scopes=["self", "self"])
classmethod create_naming_value(name: str, scope: str = '')[source]

Create a naming ScopedValue. A NamingValue represents a reference to another variable.

Parameters
  • name (str) – A string represents the identifier of another variable.

  • scope (str) – A string represents the scope of another variable. Default: "" .

Returns

An instance of ScopedValue.

Raises

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_naming_value("conv", "self")
classmethod create_variable_value(value)[source]

Create ScopedValue from a variable.

ScopedValue’s type is determined by type of value. ScopedValue’s scope is empty.

Parameters

value – The value to be converted to ScopedValue.

Returns

An instance of ScopedValue.

Examples

>>> from mindspore.rewrite import ScopedValue
>>> variable = ScopedValue.create_variable_value(2)
class mindspore.rewrite.SymbolTree(handler: SymbolTreeImpl)[source]

A SymbolTree usually corresponding to forward method of a network.

Parameters

handler (SymbolTreeImpl) – SymbolTree internal implementation instance.

after(node: Node)[source]

Get insert position after input node.

Position is used to indicate where to insert node, it indicates position in source code rather than position in topological order. We don’t need to care about what Position is, just treat it as a handler and use it as an argument of insert api of SymbolTree.

Parameters

node (Node) – Indicate the position after which node. Can be a node or name of node.

Returns

A Position to indicate where to insert node.

Raises

TypeError – If node is not a Node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.after(node)
before(node: Node)[source]

Get insert position before input node.

Position is used to indicate where to insert node, it indicates position in source code rather than position in topological order. We don’t need to care about what Position is, just treat it as a handler and use it as an argument of insert api of SymbolTree.

Parameters

node (Node) – Indicate the position before which node. Can be a node or name of node.

Returns

A Position to indicate where to insert node.

Raises

TypeError – if node is not a Node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     if node.get_name() == "conv1":
...         position = stree.before(node)
classmethod create(network)[source]

Create a new SymbolTree of the input network.

Parameters

network (Cell) – network used to create SymbolTree.

Returns

SymbolTree, a SymbolTree created based on network.

Raises

TypeError – If network is not a Cell instance.

dump()[source]

Print the ir map information corresponding to the network in ‘SymbolTree’ to the screen.

erase_node(node: Node)[source]

Erase a node from rewrite. Can only erase a node not being depended on.

Parameters

node (Node) – A Node to be erased. Can be a node or name of node.

Returns

An instance of Node being erased if node is in SymbolTree else None.

Raises

TypeError – If node is not a Node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> input_node = node.get_inputs()[0]
>>> output_nodes = node.get_users()
>>> for n in output_nodes:
...     n.set_arg(0, "x")
>>> stree.erase_node(node)
get_code()[source]

Get source code of modified network.

Returns

A str represents source code of modified network.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> codes = stree.get_code()
>>> print(codes)
get_network()[source]

Get modified network. The source code of network is saved to a file, the default file name is network_define.py.

Returns

A network object.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> new_net = stree.get_network()
insert(position, node: Node)[source]

Insert a node into SymbolTree at position.

position is obtained from before api or after api of SymbolTree.

Parameters
  • position (Position) – Indicate where to insert node.

  • node (Node) – An instance of Node to be inserted.

Returns

An instance of Node being inserted. node could be changed while calling this method for uniqueness and custom-object in args or kwargs.

Raises

Examples

>>> from mindspore.rewrite import SymbolTree
>>> from mindspore.ops import abs
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> node = stree.get_node("conv1")
>>> position = stree.after(node)
>>> new_node = stree.create_call_function(abs, ["x"], node)
>>> stree.insert(position, new_node)
nodes()[source]

Get a generator for node of corresponding network.

Returns

A generator for node of current SymbolTree.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> for node in stree.nodes():
...     print(node.get_name())
replace(old_node: Node, new_nodes: [Node])[source]

Replace old_node with a node_tree.

Note

  1. Replace supports one-to-one replacement or one-to-multi replacement. If you need multi-to-multi replacement, please refer to PatternEngine.

  2. When applying one-to-multi replacement, Rewrite will insert all new_nodes into symbol_tree.

  3. Caller should maintain arguments and targets of nodes intra sub-tree for specifying topological relation intra sub-tree.

  4. Caller should maintain arguments of input nodes of sub-tree and for specifying topological relation of inputs of sub-tree.

  5. Rewrite will maintain arguments of prepend node of sub-tree for specifying topological relation of outputs of sub-tree.

  6. Rewrite will maintain all inputs of nodes after replace new_nodes into SymbolTree.

Parameters
  • old_node (Node) – Node to be replaced.

  • new_nodes (list[Node]) – Nodes of the node_tree to replace in.

Returns

An instance of Node representing the root of the node_tree that was replaced in.

Raises
  • RuntimeError – Old node is not isolated.

  • TypeError – If old_node is not a Node.

  • TypeError – If new_nodes is not a list or node in new_nodes is not a Node.

Examples

>>> from mindspore.rewrite import SymbolTree
>>> from mindspore.ops import abs
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> stree = SymbolTree.create(net)
>>> input_node = stree.get_node("input_x")
>>> node = stree.get_node("conv1")
>>> new_node = stree.create_call_function(abs, ["x"], input_node)
>>> stree.replace(node, [new_node])
class mindspore.rewrite.TreeNodeHelper[source]

TreeNodeHelper is used to break circular reference while getting symbol_tree from a Tree type Node.

TreeNodeHelper provides a staticmethod get_sub_tree for getting symbol_tree from a Tree type Node.

static get_sub_tree(node: Node)[source]

Getting symbol_tree from a Tree type Node.

Parameters

node (Node) – A Node which may hold a sub-symbol_tree.

Returns

An instance of SymbolTree representing the sub-symbol_tree. Note that node’s symbol_tree may be None; in this case, the method will return None.

Raises
class mindspore.rewrite.ValueType[source]

ValueType represents type of ScopedValue.

  • A NamingValue represents a reference to another variable.

  • A CustomObjValue represents an instance of custom class or an object whose type is out of range of base-type and container-type of ValueType.

class mindspore.rewrite.VarNode[source]

VarNode is a subclass of PatternNode whose match method always returns True.

mindspore.rewrite.sparsify(f, arg_types, sparse_rules=None)[source]

Sparsify a Cell object by inferring the appropriate sparse function calls to replace the original function calls by propagating sparse properties provided in arg_types.

Parameters
  • f (Cell) – Cell object to be sparsified.

  • arg_types (Tuple[ArgType] | Dict[int, ArgType]) – The type of argument (sparse csr, sparse coo, non-sparse etc.) expected by f. If arg_types is a tuple, its length should be the same as the number of arguments for f; if arg_types is a dictionary, each key represents an index into the arguments, and arguments not referenced by the dictionary are considered to be non-sparse.

  • sparse_rules (Dict[str, SparseFunc], Optional) – Additional sparse rules. Default: None .

class mindspore.rewrite.SparseFunc[source]

Represents a sparse function in sparsify.

Note

If fn is a function with type hints, inputs and/or outputs, when provided, override function type hints.

Parameters
  • fn (Union[str, Callable]) – a sparse function. If fn is a string, the function represents a mindspore functional op; or fn can be any function object.

  • inputs (Any, Optional) – input types for the function. If inputs is None, use the input types in function type hints. Default: None .

  • outputs (Any, Optional) – output types for the function. If outputs is None, use the output types in function type hints. Default: None .