
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Serving, distributed worker agent startup"""

import os
import time
import sys
import traceback
import signal
from multiprocessing import Process, Pipe
import threading
import psutil

from mindspore_serving._mindspore_serving import ExitSignalHandle_
from mindspore_serving._mindspore_serving import WorkerAgent_, AgentStartUpConfig_
from mindspore_serving._mindspore_serving import DistributedServableConfig_, OneRankConfig_

from mindspore_serving import log as logger
from mindspore_serving.common import check_type
from mindspore_serving.worker.distributed import worker_agent


def _get_local_ip(rank_list, port):
    """Get the local ip from the rank table config"""
    import socket
    ip_list = set()
    for item in rank_list:
        ip_list.add(item.ip)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        for ip in ip_list:
            try:
                s.bind((ip, port))
                logger.info(f"Get local machine ip success, ip {ip}")
                return ip
            # pylint: disable=bare-except
            except:
                pass
    raise RuntimeError(f"Get local machine ip failed, rank table ips: {ip_list}, bind port {port}")


def _check_local_ip(agent_ip, port):
    """Check the local ip"""
    import socket
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        for i in range(8):
            try:
                s.bind((agent_ip, port + i))
                logger.info(f"Check local machine ip success, ip {agent_ip}")
                return True
            # pylint: disable=bare-except
            except:
                pass
    return False

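# A note on the two ip helpers above (illustration only, hypothetical addresses):
# given rank table entries whose ips are {"192.168.1.1", "192.168.1.2"},
# _get_local_ip(rank_list, 7000) returns whichever of those ips the current machine can
# actually bind on port 7000, while _check_local_ip("192.168.1.1", 7000) probes ports
# 7000-7007 to confirm that a user-supplied agent_ip is locally bindable.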

def _update_model_files_path(model_files, group_config_files):
    """Check and return model files or group config files"""
    script_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
    logger.info(f"input model files: {model_files}")
    logger.info(f"input group config files: {group_config_files}")
    model_files_temp = []
    for item in model_files:
        file_name = os.path.realpath(os.path.join(script_dir, item))
        if not os.access(file_name, os.R_OK):
            raise RuntimeError(f"Cannot access model file '{file_name}'")
        model_files_temp.append(file_name)

    if group_config_files is not None:
        group_files_temp = []
        for item in group_config_files:
            file_name = os.path.realpath(os.path.join(script_dir, item))
            if not os.access(file_name, os.R_OK):
                raise RuntimeError(f"Cannot access group config file '{file_name}'")
            group_files_temp.append(file_name)
    else:
        group_files_temp = None
    logger.info(f"absolute model files: {model_files_temp}")
    logger.info(f"absolute group config files: {group_files_temp}")
    return model_files_temp, group_files_temp


def _make_json_table_file(distributed_config):
    """Make rank table json file"""
    rank_size = len(distributed_config.rank_list)
    runtime_dir = os.path.abspath(".")
    time_stamp = str(time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())))
    rank_table_dir = os.path.join(runtime_dir, "temp_rank_table")
    try:
        os.mkdir(rank_table_dir)
    except FileExistsError:
        pass
    rank_table_file_name = os.path.join(rank_table_dir, f"hccl_rank_table_{time_stamp}_{rank_size}p.json")
    with open(rank_table_file_name, "w") as fp:
        fp.write(distributed_config.rank_table_content)
    return rank_table_file_name


signal_success = "Success"
signal_exit = "Exit"


def _recv_parent(parent_process, index, recv_pipe, handle_stop_signal=True):
    """Receive message from Start up process.
    Return False on Ctrl+C(and worker Stop message) Exit Signal, heartbeat failed, and signal_exit.
    Return True on receiving signal_success."""
    try:
        while True:
            while not recv_pipe.poll(0.1):
                if handle_stop_signal and ExitSignalHandle_.has_stopped():
                    logger.warning(f"Child {index}: Exit on Ctrl+C or stop message from worker")
                    return False
                if not parent_process.is_running():  # 3s
                    logger.warning(f"Child {index}: Exit on failure of exit of parent process")
                    return False
            parent_signal = recv_pipe.recv()
            break
        if parent_signal == signal_success:
            logger.info(f"Child {index}: Receive success")
            return True
        if parent_signal == signal_exit:
            logger.warning(f"Child {index}: Exit on receiving exit message")
    # pylint: disable=broad-except
    except Exception as e:
        logger.warning(f"Child {index}: Exit on exception: {e}")
    return False


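# Lifecycle of one agent subprocess: start the worker agent, report the result to the
# parent through send_pipe, wait for the parent's success/exit decision on recv_pipe,
# then idle until a stop signal is raised or the parent process disappears.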
def _agent_process(send_pipe, recv_pipe, index, start_config):
    """Agent process"""
    parent_process = psutil.Process(os.getppid())
    try:
        # listening success or failed message from parent process
        worker_agent.start_worker_agent(start_config=start_config)
        send_pipe.send((index, signal_success))
        success_msg = _recv_parent(parent_process, index, recv_pipe)
        if not success_msg:
            worker_agent.stop()
        send_pipe.close()
        recv_pipe.close()
        while not ExitSignalHandle_.has_stopped():
            if not parent_process.is_running():
                logger.warning(f"Child {index}, detect parent pid={parent_process.pid} has exited, child begin to exit")
                worker_agent.stop()
                return
            time.sleep(0.1)
    # pylint: disable=broad-except
    except Exception as e:
        traceback.print_exc()
        logger.error(f"Child {index}: Catch exception and notify exit of others")
        exception = RuntimeError(f"Child {index} exception happen: {e}")
        send_pipe.send((index, exception))
        _recv_parent(parent_process, index, recv_pipe, False)
        logger.error(f"Child {index}: end send message to parent")


def _send_pipe_msg(send_pipe, msg):
    """Send pipe message"""
    try:
        send_pipe.send(msg)
    # pylint: disable=broad-except
    except Exception as e:
        logger.warning(f"Send pipe message exception happen: {e}")


def _send_exit_signal_to_children(subprocess_list):
    """Send exit signal to all child processes, and terminate all child processes when they are still alive
    in some seconds later"""

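    # Shutdown escalation: give the children up to 3 seconds to exit on their own, then
    # send SIGINT to each remaining child and its sub-processes, wait up to 10 more
    # seconds, and finally send SIGKILL to anything still alive.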
    def wait_exit(wait_seconds, msg):
        for i in range(wait_seconds):
            all_exit = True
            for process in subprocess_list:
                if process.is_alive():
                    logger.warning(f"There are still child processes that have not exited and {msg} in "
                                   f"{wait_seconds - i} seconds.")
                    time.sleep(1)
                    all_exit = False
                    break
            if all_exit:
                logger.info(f"All Child process exited")
                return True
        return False

    if wait_exit(3, "SIGINT will be sent"):
        return
    # Send signal SIGINT
    for index, process in enumerate(subprocess_list):
        if process.is_alive():
            logger.warning(f"Send signal SIGINT to {index}")
            try:
                child_process = psutil.Process(process.pid)
                children_of_child = child_process.children(recursive=True)
                for item in children_of_child:
                    os.kill(item.pid, signal.SIGINT)
            # pylint: disable=broad-except
            except Exception as e:
                logger.warning(f"Get exception when send signal SIGINT to children of child {index}, exception: {e}")
            os.kill(process.pid, signal.SIGINT)

    if wait_exit(10, "will be forcibly killed"):
        return

    for index, process in enumerate(subprocess_list):
        if process.is_alive():
            logger.warning(f"Kill Child process {index}")
            try:
                child_process = psutil.Process(process.pid)
                children_of_child = child_process.children(recursive=True)
                for item in children_of_child:
                    os.kill(item.pid, signal.SIGKILL)
            # pylint: disable=broad-except
            except Exception as e:
                logger.warning(f"Get exception when send signal SIGKILL to children of child {index}, exception: {e}")
            os.kill(process.pid, signal.SIGKILL)


def _send_exit_msg_to_children(send_pipe_list, subprocess_list):
    """Send exit msg to all child processes, and terminate all child processes when they are still alive
    in some seconds later"""
    for index, (send_pipe, process) in enumerate(zip(send_pipe_list, subprocess_list)):
        if process.is_alive():
            logger.warning(f"Send exit message to Child {index}")
            _send_pipe_msg(send_pipe, signal_exit)
            logger.warning(f"End send exit message to Child {index}")
        else:
            logger.warning(f"Child {index} is not alive")
    _send_exit_signal_to_children(subprocess_list)


def _listening_agents_when_startup(p_recv_pipe, send_pipe_list, subprocess_list):
    """Listening child process"""
    count = len(send_pipe_list)
    for _ in range(count):
        while True:
            if p_recv_pipe.poll(0.1):
                break
            if ExitSignalHandle_.has_stopped():
                logger.warning("Fail to start agents because of Ctrl+C")
                _send_exit_msg_to_children(send_pipe_list, subprocess_list)
                raise RuntimeError("Fail to start agents because of Ctrl+C")
            for send_pipe, process in zip(send_pipe_list, subprocess_list):
                if process.is_alive():
                    continue
                logger.warning("Fail to start agents because of death of one agent")
                _send_exit_msg_to_children(send_pipe_list, subprocess_list)
                raise RuntimeError("Fail to start agents because of death of one agent")

        index, msg = p_recv_pipe.recv()
        logger.info(f"Receive msg from Child {index}: {msg}")
        if isinstance(msg, Exception):
            logger.warning("Fail to start agents because of exception raise by one agent")
            _send_exit_msg_to_children(send_pipe_list, subprocess_list)
            raise msg

    for send_pipe in send_pipe_list:
        _send_pipe_msg(send_pipe, signal_success)


def _listening_agents_after_startup(subprocess_list, worker_ip, worker_port, agent_ip):
    """Listening agent status after success start up of agents"""

    def wait_child_exit():
        while not ExitSignalHandle_.has_stopped():
            for index, process in enumerate(subprocess_list):
                if not process.is_alive():
                    logger.warning(f"Child {index}, pid={process.pid} has exited")
                    return
            time.sleep(0.1)

    def listening_thread_fun():
        wait_child_exit()
        WorkerAgent_.startup_notify_exit(worker_ip, worker_port, agent_ip)
        _send_exit_signal_to_children(subprocess_list)

    thread = threading.Thread(target=listening_thread_fun)
    thread.start()


def _startup_agents(common_meta, worker_ip, worker_port,
                    agent_ip, agent_start_port, device_id_list, rank_id_list,
                    model_files, group_config_files, rank_table_file):
    """Start up all agents in one machine"""
    servable_name = common_meta.servable_name
    send_pipe_list = []
    subprocess_list = []
    c_send_pipe, p_recv_pipe = Pipe()
    group_file = ""
    agents_count = len(device_id_list)
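    # Agent "index" serves device_id_list[index] / rank_id_list[index] with model_files[index]
    # and is assigned agent port agent_start_port + index.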
    for index in range(agents_count):
        device_id, rank_id, model_file = device_id_list[index], rank_id_list[index], model_files[index]
        if group_config_files is not None:
            group_file = group_config_files[index]

        p_send_pipe, c_recv_pipe = Pipe()
        send_pipe_list.append(p_send_pipe)

        agent_port = agent_start_port + index

        start_config = AgentStartUpConfig_()
        start_config.rank_id = rank_id
        start_config.device_id = device_id
        start_config.model_file_name = model_file
        start_config.group_file_name = group_file
        start_config.rank_table_json_file_name = rank_table_file
        start_config.agent_ip = agent_ip
        start_config.agent_port = agent_port
        start_config.worker_ip = worker_ip
        start_config.worker_port = worker_port
        start_config.common_meta = common_meta

        process = Process(target=_agent_process,
                          args=(c_send_pipe, c_recv_pipe, index, start_config),
                          name=f"{servable_name}_worker_agent_rank{rank_id}_device{device_id}")
        process.start()
        subprocess_list.append(process)

    msg = f"worker_ip: {worker_ip}, worker_port: {worker_port}, agent_ip: {agent_ip}, " \
          f"agent_start_port: {agent_start_port}, device ids: {device_id_list}, rank ids: {rank_id_list}, " \
          f"rank table file: {rank_table_file}, model files: {model_files}, group config files: {group_config_files}"

    try:
        _listening_agents_when_startup(p_recv_pipe, send_pipe_list, subprocess_list)
    # pylint: disable=broad-except
    except Exception as e:
        WorkerAgent_.notify_failed(worker_ip, worker_port)
        logger.error(f"Failed to start agents, {msg}")
        print(f"Failed to start agents, {msg}")
        raise e

    logger.info(f"Success to start agents, {msg}")
    print(f"Success to start agents, {msg}")
    _listening_agents_after_startup(subprocess_list, worker_ip, worker_port, agent_ip)


class DistributedServableConfig:
    """Python DistributedServableConfig"""

    def __init__(self):
        self.rank_table_content = ""
        self.rank_list = None
        self.common_meta = None
        self.distributed_meta = None

    def set(self, config):
        """Set from C++ DistributedServableConfig_ obj"""
        self.rank_table_content = config.rank_table_content
        self.rank_list = []
        for item in config.rank_list:
            new_item = {"device_id": item.device_id, "ip": item.ip}
            self.rank_list.append(new_item)
        self.common_meta = {"servable_name": config.common_meta.servable_name,
                            "with_batch_dim": config.common_meta.with_batch_dim,
                            "without_batch_dim_inputs": config.common_meta.without_batch_dim_inputs,
                            "inputs_count": config.common_meta.inputs_count,
                            "outputs_count": config.common_meta.outputs_count}

        self.distributed_meta = {"rank_size": config.distributed_meta.rank_size,
                                 "stage_size": config.distributed_meta.stage_size}

    def get(self):
        """Get as C++ DistributedServableConfig_ obj"""
        config = DistributedServableConfig_()
        config.rank_table_content = self.rank_table_content
        rank_list = []
        for item in self.rank_list:
            new_item = OneRankConfig_()
            new_item.device_id = item["device_id"]
            new_item.ip = item["ip"]
            rank_list.append(new_item)
        config.rank_list = rank_list
        config.common_meta.servable_name = self.common_meta["servable_name"]
        config.common_meta.with_batch_dim = self.common_meta["with_batch_dim"]
        config.common_meta.without_batch_dim_inputs = self.common_meta["without_batch_dim_inputs"]
        config.common_meta.inputs_count = self.common_meta["inputs_count"]
        config.common_meta.outputs_count = self.common_meta["outputs_count"]

        config.distributed_meta.rank_size = self.distributed_meta["rank_size"]
        config.distributed_meta.stage_size = self.distributed_meta["stage_size"]
        return config


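# DistributedServableConfig mirrors the C++ DistributedServableConfig_ object with plain
# Python containers, presumably so that the config can be pickled across the
# multiprocessing Pipe used in _get_worker_distributed_config below: the subprocess calls
# set() to copy the C++ fields out, and the parent calls get() to rebuild the C++ object.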
def _get_worker_distributed_config(worker_ip, worker_port):
    """Get worker distributed config from worker through sub process"""
    c_send_pipe, p_recv_pipe = Pipe()

    def process_fun(c_send_pipe):
        try:
            distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port)
            config = DistributedServableConfig()
            config.set(distributed_config)
            c_send_pipe.send(config)
        # pylint: disable=broad-except
        except Exception as e:
            c_send_pipe.send(e)

    process = Process(target=process_fun, args=(c_send_pipe,),
                      name=f"worker_agent_get_agents_config_from_worker")
    process.start()
    process.join()
    assert not process.is_alive()
    if p_recv_pipe.poll(0.1):
        config = p_recv_pipe.recv()
        if isinstance(config, Exception):
            raise config
        distributed_config = config.get()
        return distributed_config
    raise RuntimeError(f"Failed to get agents config from worker")


def startup_worker_agents(worker_ip, worker_port, model_files, group_config_files=None,
                          agent_start_port=7000, agent_ip=None, rank_start=None):
    r"""
    Start up all needed worker agents on the current machine.

    Serving has two running modes. One is running in a single process, providing the Serving service of a single
    model. The other includes a master and multiple workers. This interface is for the second scenario.

    The master is responsible for providing the Serving access interface for clients, while the worker is
    responsible for providing the inference service of the specific model. The master and workers communicate
    through gRPC, addressed by (master_ip, master_port) and (worker_ip, worker_port).

    Args:
        worker_ip (str): The ip of the worker that the agents link to.
        worker_port (int): The port of the worker that the agents link to.
        model_files (list or tuple of str): All model files needed on the current machine, as absolute paths or
            paths relative to this startup python script.
        group_config_files (None, list or tuple of str): All group config files needed on the current machine, as
            absolute paths or paths relative to this startup python script. Default None, which means there are
            no group configuration files.
        agent_start_port (int): The starting port of the agents that link to the worker. Default 7000.
        agent_ip (str or None): The local agent ip. If it's None, the agent ip will be obtained from the rank
            table file. Default None. Parameters agent_ip and rank_start must either both have values or both
            be None.
        rank_start (int or None): The starting rank id of this machine. If it's None, the rank id will be
            obtained from the rank table file. Default None. Parameters agent_ip and rank_start must either both
            have values or both be None.

    Examples:
        >>> import os
        >>> from mindspore_serving.worker import distributed
        >>> model_files = []
        >>> for i in range(8):
        ...     model_files.append(f"models/device{i}/matmul.mindir")
        >>> distributed.startup_worker_agents(worker_ip="127.0.0.1", worker_port=6200, model_files=model_files)
    """
    check_type.check_str("worker_ip", worker_ip)
    check_type.check_ip_port("worker_port", worker_port)
    check_type.check_int("agent_start_port", agent_start_port, 1, 65535 - 7)
    model_files = check_type.check_and_as_str_tuple_list("model_files", model_files)
    if group_config_files is not None:
        group_config_files = check_type.check_and_as_str_tuple_list("group_config_files", group_config_files)

    ExitSignalHandle_.start()
    distributed_config = _get_worker_distributed_config(worker_ip, worker_port)
    # distributed_config = WorkerAgent_.get_agents_config_from_worker(worker_ip, worker_port)

    # get machine ip
    rank_list = distributed_config.rank_list
    local_device_id_list = []
    local_rank_id_list = []
    if agent_ip is None:
        if rank_start is not None:
            raise RuntimeError("Parameter 'agent_ip' and parameter 'rank_start' must either both have values "
                               "or both be None.")
        local_ip = _get_local_ip(rank_list, agent_start_port)
        # get all device_id and rank_id
        for rank_id, item in enumerate(rank_list):
            if item.ip == local_ip:
                local_device_id_list.append(item.device_id)
                local_rank_id_list.append(rank_id)
    else:
        if rank_start is None:
            raise RuntimeError("Parameter 'agent_ip' and parameter 'rank_start' must either both have values "
                               "or both be None.")
        check_type.check_str("agent_ip", agent_ip)
        check_type.check_int("rank_start", rank_start, 0)
        if rank_start >= len(rank_list):
            raise RuntimeError(f"Parameter 'rank_start' cannot be equal to or larger than rank size "
                               f"{len(rank_list)}.")
        if not _check_local_ip(agent_ip, agent_start_port):
            raise RuntimeError(f"Failed to validate 'agent_ip' as a local ip, agent_ip: {agent_ip}")
        local_ip = agent_ip
        rank_table_ip = rank_list[rank_start].ip
        for rank_id, item in enumerate(rank_list):
            if item.ip == rank_table_ip:
                local_device_id_list.append(item.device_id)
                local_rank_id_list.append(rank_id)

    # handle model files and group config files
    if len(local_device_id_list) != len(model_files):
        raise RuntimeError(f"Card count {len(local_device_id_list)} described in the rank table does not equal "
                           f"the model files count {len(model_files)}, model files: {model_files}")
    if group_config_files is not None and len(model_files) != len(group_config_files):
        raise RuntimeError(f"Model files count {len(model_files)} does not equal group config files count "
                           f"{len(group_config_files)} when group_config_files is not None, "
                           f"model files: {model_files}, group config files: {group_config_files}")
    model_files, group_config_files = _update_model_files_path(model_files, group_config_files)

    # make json table file and export env
    rank_table_file = _make_json_table_file(distributed_config)
    _startup_agents(distributed_config.common_meta, worker_ip, worker_port, local_ip, agent_start_port,
                    local_device_id_list, local_rank_id_list, model_files, group_config_files, rank_table_file)
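# Usage sketch (illustration only, with hypothetical addresses and ranks): on the second
# machine of a two-machine, 16-card deployment, the agents can be pinned to a local ip
# and to ranks 8..15 explicitly instead of relying on the rank table lookup:
#
#     from mindspore_serving.worker import distributed
#     model_files = [f"models/device{i}/matmul.mindir" for i in range(8)]
#     distributed.startup_worker_agents(worker_ip="127.0.0.1", worker_port=6200,
#                                       model_files=model_files, agent_start_port=7000,
#                                       agent_ip="192.168.1.2", rank_start=8)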