# Source code for mindspore_gl.parser.vcg

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Translation."""
import ast
import inspect
from types import MethodType
from textwrap import dedent
from ast_decompiler import decompile
from .infer_expr_type_pass import InferExprTypePass
from .check_syntax_pass import CheckSyntaxPass
from .ast_rewriter import AstRewriter
from .code_comparator import CodeComparator
from .utils import src_to_function

# Display defaults consulted by ``translate``; ``set_display_config`` overwrites them.
SCREEN_WIDTH: int = 200  # width handed to CodeComparator when showing the diff
DISPLAY: bool = True  # whether ``translate`` shows the code comparison


def set_display_config(screen_width, display):
    """
    Configure how ``translate`` presents its code comparison.

    Both values are stored in module-level globals that ``translate`` reads
    on each call.

    Args:
        screen_width (int): Width handed to the code comparator when the
            original and translated code are displayed side by side.
        display (bool): Whether ``translate`` shows the code comparison.
    """
    global SCREEN_WIDTH
    global DISPLAY
    SCREEN_WIDTH, DISPLAY = screen_width, display


def translate(obj, method_name: str):
    """
    Translate the vertex central code into MindSpore understandable code.

    After translation, a new function is generated in ``/.mindspore_gl`` and
    the original method on ``obj`` is replaced with that function, bound to
    ``obj``.

    When ``DISPLAY`` is enabled (see ``set_display_config``), a comparison of
    the original and translated code is shown using ``SCREEN_WIDTH``.

    Args:
        obj (object): The object that owns the method to translate.
        method_name (str): The name of the method to be translated.
    """
    # SCREEN_WIDTH / DISPLAY are only read here, so no ``global`` statement
    # is needed.
    fn = getattr(obj, method_name)
    # dedent: a method's source is indented inside its class body, and
    # ast.parse rejects leading indentation.
    src = dedent(inspect.getsource(fn))
    py_ast = ast.parse(src)

    # The syntax checker receives the globals of the method's module.
    syntax_checker = CheckSyntaxPass(fn.__globals__)
    ret = syntax_checker.analyze(py_ast)
    type_inferer = InferExprTypePass(ret, src)
    ret = type_inferer.analyze(py_ast)

    # Read DISPLAY exactly once: the original checked it twice, so a change
    # between the checks could leave ``comparator`` unbound.
    comparator = None
    if DISPLAY:
        comparator = CodeComparator(SCREEN_WIDTH)
        # Recorded before rewriting; presumably used later to map the
        # rewritten tree back to the original line numbers.
        comparator.record_origin_lineno(py_ast)

    rewriter = AstRewriter(ret)
    new_ast = rewriter.visit(py_ast)

    if comparator is not None:
        comparator.mapping_by_origin_lineno(new_ast)
        comparator.show_diff()

    new_src = decompile(new_ast)
    new_fn = src_to_function(new_src, method_name, fn.__globals__)
    new_fn.__module__ = fn.__module__
    # Bind the generated function to ``obj`` so it is called like the
    # original method.
    setattr(obj, method_name, MethodType(new_fn, obj))