# Source code for mindspore_federated.startup.vertical_federated_local

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Interface for start up vertical federated communicator"""
from mindspore_federated._mindspore_federated import VerticalFederated_
from mindspore_federated.common import tensor_utils, data_join_utils
from mindspore_federated.data_join.context import _WorkerRegister

from .server_config import ServerConfig, init_server_config
from .ssl_config import SSLConfig, init_vertical_ssl_config, init_vertical_enable_ssl


class VerticalFederatedCommunicator:
    """
    Define the vertical federated communicator.

    Args:
        http_server_config (ServerConfig): Configuration of local http server.
        remote_server_config (Union[ServerConfig, list]): Configuration of remote http server. A list is
            also accepted (presumably a list of ServerConfig; the element type is not checked here).
        enable_ssl (bool, optional): whether to enable ssl communication. Default: ``False``.
        ssl_config (SSLConfig, optional): Configuration of ssl encryption. Default: ``None``.
        compress_configs (dict, optional): Configuration of communication compression. Default: ``None``,
            in which case an empty dict is used.

    Raises:
        RuntimeError: If `http_server_config`, `remote_server_config`, `ssl_config` or `compress_configs`
            is not ``None`` and is not of the expected type.

    Examples:
        >>> from mindspore_federated import VerticalFederatedCommunicator, ServerConfig
        >>> http_server_config = ServerConfig(server_name='server', server_address="127.0.0.1:1086")
        >>> remote_server_config = ServerConfig(server_name='client', server_address="127.0.0.1:1087")
        >>> vertical_communicator = VerticalFederatedCommunicator(http_server_config=http_server_config,
        ...                                                       remote_server_config=remote_server_config)
        >>> vertical_communicator.launch()
        >>> vertical_communicator.data_join_wait_for_start()
    """

    def __init__(self, http_server_config, remote_server_config, enable_ssl=False, ssl_config=None,
                 compress_configs=None):
        # Validate in the same order as before so the first offending argument is the one reported.
        self._check_optional_type("http_server_config", http_server_config, ServerConfig,
                                  "instance of ServerConfig")
        self._check_optional_type("remote_server_config", remote_server_config, (ServerConfig, list),
                                  "instance of ServerConfig or list")
        self._check_optional_type("ssl_config", ssl_config, SSLConfig,
                                  "None or instance of SSLConfig")
        self._check_optional_type("compress_configs", compress_configs, dict,
                                  "None or instance of dict")

        self._http_server_config = http_server_config
        self._remote_server_config = remote_server_config
        self._enable_ssl = enable_ssl
        self._ssl_config = ssl_config
        # Push the SSL and server settings into the federated runtime's configuration.
        init_vertical_enable_ssl(self._enable_ssl)
        init_vertical_ssl_config(self._ssl_config)
        init_server_config(self._http_server_config, self._remote_server_config)
        # Always keep a dict so send_tensors can forward it without a None check.
        self._compress_configs = compress_configs if compress_configs is not None else {}

    @staticmethod
    def _check_optional_type(param_name, value, expected_types, expectation):
        """
        Raise RuntimeError unless `value` is ``None`` or an instance of `expected_types`.

        Args:
            param_name (str): Name of the constructor parameter, used in the error message.
            value (Any): The argument to validate.
            expected_types (Union[type, tuple]): Type or tuple of types passed to ``isinstance``.
            expectation (str): Human-readable description of what is accepted.

        Raises:
            RuntimeError: If `value` is not ``None`` and fails the ``isinstance`` check.
        """
        if value is not None and not isinstance(value, expected_types):
            raise RuntimeError(
                f"Parameter '{param_name}' should be {expectation}, but got {type(value)}")

    def launch(self):
        """
        Start vertical federated learning communicator.
        """
        VerticalFederated_.start_vertical_communicator()

    def send_tensors(self, target_server_name, tensor_dict):
        """
        Send distributed training tensor data.

        The tensors are packed into a tensor-list object using the compression configuration given at
        construction time, then handed to the communicator.

        Args:
            target_server_name (str): Specifies the name of the remote server.
            tensor_dict (OrderedDict): The dict of Tensors to be sent.

        Returns:
            The value returned by ``VerticalFederated_.send_tensor_list``.

        Examples:
            >>> from collections import OrderedDict
            >>> import numpy as np
            >>> from mindspore import Tensor
            >>> backbone_out = OrderedDict()
            >>> backbone_out["hidden_states"] = Tensor(np.random.random(size=(2, 2048, 1280)).astype(np.float16))
            >>> vertical_communicator.send_tensors("leader", backbone_out)
        """
        tensor_list_item_py = tensor_utils.tensor_dict_to_tensor_list_pybind_obj(
            ts_dict=tensor_dict, name="", compress_configs=self._compress_configs)
        return VerticalFederated_.send_tensor_list(target_server_name, tensor_list_item_py)

    def send_register(self, target_server_name: str, worker_register: _WorkerRegister):
        """
        Send a worker registration to the remote server and return its reply.

        Args:
            target_server_name (str): Specifies the name of the remote server.
            worker_register (_WorkerRegister): The registration information to send.

        Returns:
            The worker configuration converted from the remote server's response by
            ``data_join_utils.pybind_obj_to_worker_config``.
        """
        worker_register_item_py = data_join_utils.worker_register_to_pybind_obj(worker_register)
        worker_config_item_py = VerticalFederated_.send_worker_register(target_server_name,
                                                                        worker_register_item_py)
        return data_join_utils.pybind_obj_to_worker_config(worker_config_item_py)

    def receive(self, target_server_name: str):
        """
        Get the tensor data sent by the remote server.

        Args:
            target_server_name (str): Specifies the name of the remote server.

        Returns:
            The dict of tensors converted from the received tensor list.
        """
        tensor_list_item_py = VerticalFederated_.receive(target_server_name)
        # The converter returns a pair; only the tensor dict (second element) is exposed to callers.
        _, tensor_dict = tensor_utils.tensor_list_pybind_obj_to_tensor_dict(tensor_list_item_py)
        return tensor_dict

    def data_join_wait_for_start(self):
        """
        Block and wait for the registration information of the client worker.
        """
        return VerticalFederated_.data_join_wait_for_start()

    def http_server_config(self):
        """
        Returns the local server configuration.
        """
        return self._http_server_config

    def remote_server_config(self):
        """
        Returns the remote server configuration.
        """
        return self._remote_server_config