# Source code for mindspore_serving.server.register.pipeline

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pipelineing registration interface"""

from mindspore_serving._mindspore_serving import PipelineStorage_, RequestSpec_
from mindspore_serving import log as logger
from mindspore_serving.server.common import check_type
from .utils import get_servable_dir, get_func_name


class PipelineStorage:
    """In-memory registry of pipeline functions keyed by pipeline name.

    Every entry stores the function together with its input and output counts.
    """

    def __init__(self):
        self.pipeline = {}
        self.is_register = False

    def clear(self):
        """Drop every registered pipeline and reset the registered flag."""
        self.is_register = False
        self.pipeline = {}

    def has_registered(self):
        """Return whether any pipeline has been registered since creation or the last clear."""
        return self.is_register

    def register(self, fun, pipeline_name, inputs_count, outputs_count):
        """Store (or overwrite) the entry for `pipeline_name` and mark the storage as registered."""
        entry = {
            "fun": fun,
            "inputs_count": inputs_count,
            "outputs_count": outputs_count,
        }
        self.pipeline[pipeline_name] = entry
        self.is_register = True

    def get(self, pipeline_name):
        """Return the entry stored for `pipeline_name`.

        Raises:
            RuntimeError: No pipeline is stored under `pipeline_name`.
        """
        entry = self.pipeline.get(pipeline_name)
        if entry is not None:
            return entry
        raise RuntimeError(f"Pipeline '{pipeline_name}' not found")


# Module-level PipelineStorage instance; register_pipeline_args() records pipelines into it.
pipeline_storage = PipelineStorage()


def register_pipeline_args(func, inputs_count, outputs_count):
    """Record `func` in the module-level pipeline storage.

    The storage key is "<servable dir>.<function name>", built from
    get_servable_dir() and get_func_name(func).

    Args:
        func: The pipeline function to register.
        inputs_count: Input count stored alongside the function.
        outputs_count: Output count stored alongside the function.
    """
    name_parts = (get_servable_dir(), get_func_name(func))
    name = ".".join(name_parts)

    logger.info(f"Register pipeline {name} {inputs_count} {outputs_count}")
    pipeline_storage.register(
        fun=func,
        pipeline_name=name,
        inputs_count=inputs_count,
        outputs_count=outputs_count,
    )


class PipelineServable:
    """
    Create Pipeline Servable for Servable calls.

    .. warning::
        This is a beta interface and may be changed in the future.

    Args:
        servable_name (str): The name of servable.
        method (str): The name of method supplied by servable.
        version_number (int, optional): The number of version supplied by servable. Default: 0.

    Raises:
        RuntimeError: The type or value of the parameters is invalid, or other errors happened.

    Examples:
        >>> from mindspore_serving.server import distributed
        >>> from mindspore_serving.server import register
        >>>
        >>> distributed.declare_servable(rank_size=8, stage_size=1, with_batch_dim=False)
        >>> @register.register_method(output_names=["y"])
        ... def fun(x):
        ...     y = register.call_servable(x)
        ...     return y
        >>> servable = register.PipelineServable(servable_name="service", method="fun")
        >>> @register.register_pipeline(output_names=["y"])
        ... def predict(x):
        ...     y = servable.run(x)
        ...     return y
    """

    def __init__(self, servable_name, method, version_number=0):
        # Validate arguments before touching the native objects.
        check_type.check_str('servable_name', servable_name)
        check_type.check_str('method', method)
        # NOTE(review): the trailing 0 is presumably the minimum allowed value - confirm in check_type.
        check_type.check_int('version_number', version_number, 0)

        self.spec = RequestSpec_()
        self.storage = PipelineStorage_.get_instance()
        self.spec.servable_name = servable_name
        self.spec.method_name = method
        self.spec.version_number = version_number

    def run(self, *args):
        """
        Servable calls function in Pipeline register function.

        Args:
            args (numpy.ndarray): One or more input numpy arrays. If exactly one
                argument is given and it is a list, every element of that list is
                wrapped in its own one-element tuple and the resulting list is
                submitted in a single call.

        Returns:
            numpy.ndarray, A numpy array object is returned if there is only one output;
            otherwise, a numpy array tuple is returned. When a single list argument is
            given, the result of the underlying storage call is returned unchanged.

        Raises:
            RuntimeError: The type or value of the parameters is invalid, or other errors happened.
        """
        is_single_list = len(args) == 1 and isinstance(args[0], list)
        if is_single_list:
            # One request entry per list element; the result is passed through as-is.
            arg_list = [(arg,) for arg in args[0]]
            return self.storage.run(self.spec, arg_list)

        # All positional arguments form one request entry; only its result is used.
        result = self.storage.run(self.spec, [args])
        first = result[0]
        if len(first) == 1:
            # Unwrap a lone output so callers get the value rather than a 1-tuple.
            return first[0]
        return first