# Source code for mindspore.rewrite.api.symbol_tree

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rewrite module api: SymbolTree."""
from typing import Optional
from types import FunctionType
import mindspore as ms

from mindspore.nn import Cell
from ..._checkparam import Validator
from .node import Node
from ..symbol_tree_builder import SymbolTreeBuilder
from ..symbol_tree import Position, SymbolTree as SymbolTreeImpl

# Python value types accepted as positional/keyword arguments of a call-function
# node built by `SymbolTree.create_call_function`. `Node` arguments are replaced
# by their internal handlers before being forwarded to the implementation.
ParamTypes = (
    int,
    str,
    float,
    bool,
    Node,
)
# MindSpore dtype objects additionally accepted as such arguments; they are
# matched with the `in` operator (i.e. by equality), not by isinstance.
MsDtypes = (
    ms.float16,
    ms.float32,
    ms.float64,
)


class SymbolTree:
    """
    A `SymbolTree` usually corresponds to the forward method of a network.

    This is the public wrapper around the internal rewrite implementation
    (`SymbolTreeImpl`): it validates arguments and converts between public `Node`
    objects and the internal node handlers.

    Args:
        handler (SymbolTreeImpl): SymbolTree internal implementation instance.

    Raises:
        TypeError: If `handler` is not a `SymbolTreeImpl`.
    """

    def __init__(self, handler: SymbolTreeImpl):
        Validator.check_value_type("handler", handler, [SymbolTreeImpl], "SymbolTree")
        self._symbol_tree: SymbolTreeImpl = handler

    @classmethod
    def create(cls, network):
        """
        Create a new `SymbolTree` of the input `network`.

        Args:
            network (Cell): `network` used to create `SymbolTree`.

        Returns:
            SymbolTree, a `SymbolTree` created based on `network`.

        Raises:
            TypeError: If `network` is not a `Cell` instance.
        """
        Validator.check_value_type("network", network, [Cell], "SymbolTree")
        return cls(SymbolTreeBuilder(network).build())

    @staticmethod
    def _is_supported_value(value) -> bool:
        """Return True if `value` may be used as an argument of a call-function node."""
        # The cheap isinstance check runs first so that the equality-based `in`
        # test against MindSpore dtypes is only reached for non-ParamTypes values.
        return isinstance(value, ParamTypes) or value in MsDtypes

    @staticmethod
    def _check_args_type(args):
        """
        Validate positional arguments of a call-function node.

        Raises:
            TypeError: If any element of `args` is neither one of `ParamTypes` nor one of `MsDtypes`.
        """
        for arg in args:
            if not SymbolTree._is_supported_value(arg):
                raise TypeError(f"For call-function Node, got unsupported arg: {arg}, type: {type(arg)}")

    @staticmethod
    def _check_kwargs_type(kwargs):
        """
        Validate keyword arguments of a call-function node.

        Raises:
            TypeError: If a key is not a str, or a value is neither one of `ParamTypes` nor one of `MsDtypes`.
        """
        for k, v in kwargs.items():
            if not isinstance(k, str):
                # Report the type of the offending key (previously reported the value's type).
                raise TypeError(f"For call-function Node, key in kwarg must be a str, but got: {type(k)}")
            if not SymbolTree._is_supported_value(v):
                raise TypeError(f"For call-function Node, got unsupported kwarg value: {v}, type: {type(v)}")

    def create_call_function(self, func, targets, *args, **kwargs):
        r"""
        Create a Node object and generate the execution code to insert into the source code.
        The source code calls the 'func' function with 'args' and 'kwargs' as parameters.

        Args:
            func (FunctionType): The function to be called.
            targets (list[str]): Output names of the node; used as the assignment targets
                of the generated statement in source code.
            args (Union[MsDtypes, ParamTypes]): Positional arguments of the call. `Node`
                arguments are converted to their internal handlers. When omitted, the call
                has no positional arguments.
            kwargs (dict{str,Union[MsDtypes, ParamTypes]}): Keyword arguments of the call.
                Keys must be str; values follow the same rules as `args`. When omitted,
                the call has no keyword arguments.

        Returns:
            An instance of `Node`.

        Raises:
            TypeError: If `func` is not FunctionType.
            TypeError: If `targets` is not `list`.
            TypeError: If the type of an element of `targets` is not str.
            TypeError: If arg in `args` is not ParamType.
            TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not ParamType.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> new_node = stree.create_call_function(F.abs, ["x"], node)
        """
        Validator.check_value_type("func", func, [FunctionType], "SymbolTree node")
        Validator.check_element_type_of_iterable("targets", targets, [str], "SymbolTree node")
        SymbolTree._check_args_type(args)
        SymbolTree._check_kwargs_type(kwargs)
        # Public `Node` wrappers are unwrapped; other values are forwarded unchanged.
        args_ = [arg.get_handler() if isinstance(arg, Node) else arg for arg in args]
        kwargs_ = {key: value.get_handler() if isinstance(value, Node) else value
                   for key, value in kwargs.items()}
        return Node(self._symbol_tree.create_call_function(func, targets, args_, kwargs_))

    def get_handler(self) -> SymbolTreeImpl:
        """
        Get handler of `SymbolTree` implementation.

        Returns:
            The internal `SymbolTreeImpl` instance wrapped by this object.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> handler = stree.get_handler()
        """
        return self._symbol_tree

    def nodes(self):
        """
        Get a generator for node of corresponding network.

        Returns:
            A generator yielding a `Node` for each node of current `SymbolTree`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> for node in stree.nodes():
            ...     node.set_attribute("channel", 3)
        """
        for node in self._symbol_tree.nodes():
            yield Node(node)

    def get_node(self, node_name: str) -> Optional[Node]:
        """
        Get node by `node_name`.

        Args:
            node_name (str): A string represents name of node.

        Returns:
            An instance of node if found else None.

        Raises:
            TypeError: If `node_name` is not `str`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
        """
        Validator.check_value_type("node_name", node_name, [str], "SymbolTree")
        node_impl = self._symbol_tree.get_node(node_name)
        if node_impl is None:
            return None
        return Node(node_impl)

    def get_inputs(self) -> [Node]:
        """
        Get input nodes of current `SymbolTree`.

        Returns:
            A list with one `Node` wrapping each handler returned by the implementation's
            `get_inputs()` (presumably the inputs of the network's forward method).
        """
        return [Node(node_impl) for node_impl in self._symbol_tree.get_inputs()]

    def before(self, node: Node):
        """
        Get insert position before input `node`.

        `Position` is used to indicate where to insert node, it indicates position in source code rather than
        position in topological order. We don't need to care about what `Position` is, just treat it as a handler
        and use it as an argument of `insert` api of `SymbolTree`.

        Args:
            node (Node): Indicate the position before which node.

        Returns:
            A `Position` to indicate where to insert node.

        Raises:
            TypeError: If `node` is not a `Node`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> for node in stree.nodes():
            ...     if node.get_name() == "conv1":
            ...         position = stree.before(node)
        """
        Validator.check_value_type("node", node, [Node], "SymbolTree")
        return self._symbol_tree.before(node.get_handler())

    def after(self, node: Node):
        """
        Get insert position after input `node`.

        `Position` is used to indicate where to insert node, it indicates position in source code rather than
        position in topological order. We don't need to care about what `Position` is, just treat it as a handler
        and use it as an argument of `insert` api of `SymbolTree`.

        Args:
            node (Node): Indicate the position after which node.

        Returns:
            A `Position` to indicate where to insert node.

        Raises:
            TypeError: If `node` is not a `Node`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> for node in stree.nodes():
            ...     if node.get_name() == "conv1":
            ...         position = stree.after(node)
        """
        Validator.check_value_type("node", node, [Node], "SymbolTree")
        return self._symbol_tree.after(node.get_handler())

    def insert(self, position, node: Node) -> Node:
        """
        Insert a `node` into `SymbolTree` at `position`.

        `position` is obtained from `before` api or `after` api of `SymbolTree`.

        Args:
            position (Position): Indicate where to insert `node`.
            node (Node): An instance of Node to be inserted.

        Returns:
            An instance of Node being inserted. `node` could be changed while calling this method for uniqueness and
            custom-object in args or kwargs.

        Raises:
            RuntimeError: If `position` does not belong to current `SymbolTree`.
            TypeError: If `position` is not a `Position`.
            TypeError: If `node` is not a `Node`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> position = stree.after(node)
            >>> new_node = stree.create_call_function(F.abs, ["x"], node)
            >>> stree.insert(position, new_node)
        """
        Validator.check_value_type("position", position, [Position], "SymbolTree")
        Validator.check_value_type("node", node, [Node], "SymbolTree")
        return Node(self._symbol_tree.insert_node(position, node.get_handler()))

    def erase_node(self, node: Node) -> Optional[Node]:
        """
        Erase a `node` from rewrite. Can only erase a node not being depended on.

        Args:
            node (Node): A `Node` to be erased.

        Returns:
            An instance of `Node` being erased if node is in `SymbolTree` else None.

        Raises:
            TypeError: If `node` is not a `Node`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> input_node = node.get_inputs()[0]
            >>> output_nodes = node.get_users()
            >>> for n in output_nodes:
            ...     n.set_arg(0, "x")
            >>> stree.erase_node(node)
        """
        Validator.check_value_type("node", node, [Node], "SymbolTree")
        node_impl = self._symbol_tree.erase_node(node.get_handler())
        # Honour the documented contract: do not wrap a missing result in a Node.
        if node_impl is None:
            return None
        return Node(node_impl)

    def replace(self, old_node: Node, new_nodes: [Node]) -> Node:
        """
        Replace `old_node` with a node_tree.

        Note:
            1. Replace supports one-to-one replacement or one-to-multi replacement. If you need multi-to-multi
               replacement, please refer to `PatternEngine`.
            2. When applying one-to-multi replacement, Rewrite will insert all `new_nodes` into symbol_tree.
            3. Caller should maintain arguments and targets of nodes intra sub-tree for specifying topological
               relation intra sub-tree.
            4. Caller should maintain arguments of input nodes of sub-tree for specifying topological relation
               of inputs of sub-tree.
            5. Rewrite will maintain arguments of prepend node of sub-tree for specifying topological relation of
               outputs of sub-tree.
            6. Rewrite will maintain all inputs of nodes after replacing `new_nodes` into `SymbolTree`.

        Args:
            old_node (Node): Node to be replaced.
            new_nodes (list[Node]): Nodes of the node_tree to replace in.

        Returns:
            An instance of Node represents root of node_tree been replaced in.

        Raises:
            RuntimeError: Old node is not isolated.
            TypeError: If `old_node` is not a `Node`.
            TypeError: If `new_nodes` is not a `list` or node in `new_nodes` is not a `Node`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> new_node = stree.create_call_function(F.abs, ["x"], node)
            >>> stree.replace(node, [new_node])
        """
        Validator.check_value_type("old_node", old_node, [Node], "SymbolTree")
        Validator.check_element_type_of_iterable("new_nodes", new_nodes, [Node], "SymbolTree")
        nodes_impl = [node.get_handler() for node in new_nodes]
        return Node(self._symbol_tree.replace(old_node.get_handler(), nodes_impl))

    def set_output(self, index: int, return_value: str) -> Node:
        """
        Set an output (return value) of the network.

        Args:
            index (int): Position of the output to set.
            return_value (str): Name of the value to return at `index`.

        Returns:
            A `Node` wrapping the result of the implementation's `set_output`.

        Raises:
            TypeError: If `index` is not an int or `return_value` is not a str.
        """
        Validator.check_value_type("index", index, [int], "SymbolTree")
        Validator.check_value_type("return_value", return_value, [str], "SymbolTree")
        # The implementation takes (return_value, index), the reverse of this public API.
        return Node(self._symbol_tree.set_output(return_value, index))

    def dump(self):
        """
        Print the ir map information corresponding to the network in 'SymbolTree' to the screen.
        """
        self._symbol_tree.dump()

    def print_node_tabulate(self):
        """
        Print node information of graph.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> stree.print_node_tabulate()
        """
        self._symbol_tree.print_node_tabulate()

    def get_code(self) -> str:
        """
        Get source code of modified network.

        Returns:
            A str represents source code of modified network.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> stree.get_code()
        """
        return self._symbol_tree.get_code()

    def get_network(self) -> Cell:
        """
        Get modified network.
        The source code of network is saved to a file, the default file name is `network_define.py`.

        Returns:
            A network object.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from lenet import Lenet
            >>> net = Lenet()
            >>> stree = SymbolTree.create(net)
            >>> stree.get_network()
        """
        return self._symbol_tree.get_network()

    def set_saved_file_name(self, file_name: str):
        """
        Set the file name the generated network source code is saved to.

        Args:
            file_name (str): Target file name.

        Raises:
            TypeError: If `file_name` is not a str.
        """
        Validator.check_value_type("file_name", file_name, [str], "Saving network")
        self._symbol_tree.set_saved_file_name(file_name)

    def get_saved_file_name(self):
        """Return the file name the generated network source code is saved to."""
        return self._symbol_tree.get_saved_file_name()

    def save_network_to_file(self):
        """Write the generated network source code to the configured saved file."""
        self._symbol_tree.save_network_to_file()