# Source code for mindspore_federated.trainer._fl_manager

# pylint: disable=missing-docstring
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FederatedLearningManager related class and functions."""

from copy import deepcopy
import numpy as np
import mindspore.ops as ops
from mindspore import nn
from mindspore.nn import Cell
from mindspore import load_param_into_net
from mindspore.communication.management import init, get_rank
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.train.callback import Callback
from mindspore_federated.common import _checkparam as validator
from mindspore_federated._mindspore_federated import Federated_, FLContext
from mindspore_federated import log as logger

from ..startup.ssl_config import init_ssl_config
from ..startup.yaml_config import load_yaml_config
from ..common import _fl_context, check_type

# Remainder of `global_step % sync_frequency` at which FederatedLearningManager._start_pull_weight
# proceeds with pulling weights (any other remainder makes it return early).
TRAIN_BEGIN_STEP_NUM = 1
# Remainder of `global_step % sync_frequency` at which FederatedLearningManager._start_push_weight
# proceeds with pushing weights (any other remainder makes it return early).
TRAIN_END_STEP_NUM = 0


class _StartFLJob:
    """
    Thin wrapper around ``Federated_.start_fl_job`` for a Federated Learning Worker.

    Args:
        data_size: Value forwarded unchanged to ``Federated_.start_fl_job``.
    """

    def __init__(self, data_size):
        # Stored on the instance so that construct() can be called without arguments.
        self._data_size = data_size

    def construct(self):
        """Call ``Federated_.start_fl_job`` with the stored data size and return its result."""
        job_result = Federated_.start_fl_job(self._data_size)
        return job_result


class _UpdateAndGetModel:
    """
    Thin wrapper around ``Federated_.update_and_get_model`` for a Federated Learning Worker.

    Args:
        weights: Value forwarded unchanged to ``Federated_.update_and_get_model``.
    """

    def __init__(self, weights):
        super().__init__()
        # Stored on the instance so that construct() can be called without arguments.
        self._weights = weights

    def construct(self):
        """Call ``Federated_.update_and_get_model`` with the stored weights and return its result."""
        updated_model = Federated_.update_and_get_model(self._weights)
        return updated_model


class _ExchangeKeys:
    """
    Thin wrapper around ``Federated_.exchange_keys`` (used for Stable PW Encrypt).
    """

    @staticmethod
    def construct():
        """Call ``Federated_.exchange_keys`` and return whatever it returns."""
        exchange_result = Federated_.exchange_keys()
        return exchange_result


class _GetKeys:
    """
    Thin wrapper around ``Federated_.get_keys`` (used for Stable PW Encrypt).
    """

    @staticmethod
    def construct():
        """Call ``Federated_.get_keys`` and return whatever it returns."""
        keys_result = Federated_.get_keys()
        return keys_result


class _PullWeight:
    """
    Thin wrapper around ``Federated_.pull_weight`` for a Federated Learning Worker.

    Args:
        pull_weight_params: Value forwarded unchanged to ``Federated_.pull_weight``.
    """

    def __init__(self, pull_weight_params):
        # Public attribute (no leading underscore), kept under the same name for callers.
        self.pull_weight_params = pull_weight_params

    def construct(self):
        """Call ``Federated_.pull_weight`` with the stored parameters and return its result."""
        pulled = Federated_.pull_weight(self.pull_weight_params)
        return pulled


class _PushWeight:
    """
    Thin wrapper around ``Federated_.push_weight`` for a Federated Learning Worker.

    Args:
        weights: Value forwarded unchanged to ``Federated_.push_weight``.
    """

    def __init__(self, weights):
        # Stored on the instance so that construct() can be called without arguments.
        self._weights = weights

    def construct(self):
        """Call ``Federated_.push_weight`` with the stored weights and return its result."""
        push_result = Federated_.push_weight(self._weights)
        return push_result


class PushMetrics:
    """
    Thin wrapper around ``Federated_.push_metrics`` for a Federated Learning Worker.
    """

    @staticmethod
    def construct(loss, accuracy):
        """Forward ``loss`` and ``accuracy`` to ``Federated_.push_metrics`` and return its result."""
        push_result = Federated_.push_metrics(loss, accuracy)
        return push_result


class BroadcastNet(Cell):
    """
    Cell that feeds a single input to a Broadcast operator.
    """

    def __init__(self):
        super().__init__()
        # 0 is the first positional argument of ops.Broadcast
        # (presumably the root rank -- confirm against the MindSpore docs).
        self._broadcast = ops.Broadcast(0)

    def construct(self, input_x):
        """Wrap ``input_x`` in a one-element tuple and return the Broadcast operator's output."""
        wrapped_input = (input_x,)
        return self._broadcast(wrapped_input)


def _get_fl_param_names(network, fl_param_names, requires_aggr=False):
    """
    Collect the names of the parameters held by every optimizer found under ``network``.

    The sub-cells of ``network`` are walked recursively; for each one that is an
    ``nn.Optimizer`` the names of its ``parameters`` are appended to ``fl_param_names``
    (in place), skipping names that are already present.

    Args:
        network: Object exposing ``cells()``; its sub-cells are searched recursively.
        fl_param_names (list): Accumulator of names; mutated in place and also returned.
        requires_aggr (bool): When True, parameters whose ``requires_aggr`` attribute is
            falsy are skipped. Default: False.

    Returns:
        list, the accumulator with the newly discovered names appended.
    """
    for child in network.cells():
        # Descend first so that names from nested optimizers are appended before this child's.
        fl_param_names = _get_fl_param_names(child, fl_param_names, requires_aggr)
        if not isinstance(child, nn.Optimizer):
            continue
        for opt_param in child.parameters:
            excluded = requires_aggr and not opt_param.requires_aggr
            if excluded or opt_param.name in fl_param_names:
                continue
            fl_param_names.append(opt_param.name)
    return fl_param_names


def _get_lr(network):
    """
    Return the learning rate of the first optimizer found under ``network``.

    Sub-cells are visited in ``cells()`` order; a sub-cell that is itself an
    ``nn.Optimizer`` is used directly, otherwise it is searched recursively.

    Args:
        network: Object exposing ``cells()``.

    Returns:
        The result of ``get_lr().asnumpy()`` on the first optimizer found, or None when
        no optimizer exists anywhere under ``network``.
    """
    for child in network.cells():
        if isinstance(child, nn.Optimizer):
            return child.get_lr().asnumpy()
        nested_lr = _get_lr(child)
        if nested_lr is None:
            continue
        return nested_lr
    return None


class FederatedLearningManager(Callback):
    """
    Manage Federated Learning during training.

    Args:
        yaml_config (str): The yaml file path. For more detail see
            `federated_server_yaml <https://gitee.com/mindspore/federated/blob/master/docs/api/api_python_en/horizontal/federated_server_yaml.md>`_.
        model (nn.Cell): A model for Federated Training.
        sync_frequency (int): Synchronization frequency of parameters in Federated Learning. Indicating the number
            of steps between two adjacent synchronization operations when `dataset_sink_mode` is
            set to False. If `sync_type` is set to "fixed", it serves as a fixed number of steps.
            If `sync_type` is set to "adaptive", it serves as the initial value of the adaptive
            synchronization frequency. Note that its function is changed in dataset sink mode.
            If `dataset_sink_mode` is set to True and `sink_size` is set to a non-positive value,
            the synchronization operation will execute once every `sync_frequency` epochs. If
            `dataset_sink_mode` is set to True and `sink_size` is set to a positive value, the
            synchronization operation will execute once every `sink_size` * `sync_frequency` steps.
            The `dataset_sink_mode` and the `sink_size` is set by users in `mindspore.train.Model` .
        http_server_address (str): The http server address used for communicating. Default: "".
        data_size (int): The data size to be reported to the worker. Default: 1.
        sync_type (str): The synchronization type of parameter in Federated Learning.
            Supports ["fixed", "adaptive"]. Default: "fixed".

            - fixed: The frequency of parameter synchronization is fixed.

            - adaptive: The frequency of parameter synchronization changes adaptively.

        run_distribute (bool): Whether to open distribute training. Default: False.
        ssl_config (Union(None, SSLConfig)): Config of ssl. Default: None.
        min_consistent_rate (float): Minimum consistency ratio threshold. The greater the value, the more
            difficult it is to improve the synchronization frequency.
            Value range: greater than or equal to 0.0. Default: 1.1.
        min_consistent_rate_at_round (int): The number of rounds of the minimum consistency ratio threshold.
            The greater the value, the more difficult it is to improve the synchronization frequency.
            Value range: greater than or equal to 0. Default: 0.
        ema_alpha (float): Gradient consistency smoothing coefficient. The smaller the value, the more the
            frequency will be judged according to the gradient bifurcation of the current round.
            Otherwise it will be judged according to the historical gradient bifurcation more.
            Value range: (0.0, 1.0). Default: 0.5.
        observation_window_size (int): The number of rounds in the observation time window. The greater the
            value, the more difficult it is to reduce the synchronization frequency.
            Value range: greater than 0. Default: 5.
        frequency_increase_ratio (int): Frequency increase amplitude. The greater the value, the greater the
            frequency increase amplitude. Value range: greater than 0. Default: 2.
        unchanged_round (int): The number of rounds whose frequency does not change. The frequency is unchanged
            before unchanged_round rounds. Value range: greater than or equal to 0. Default: 0.

    Raises:
        ValueError: If the server mode is cloud and the configured aggregation type or encrypt type is not
            in the supported lists of `_fl_context`.

    Examples:
        >>> from mindspore_federated import FederatedLearningManager
        >>> from mindspore import nn, Model
        >>> from network.lenet import LeNet5, create_dataset_from_folder
        >>> network = LeNet5(62, 3)
        >>> federated_learning_manager = FederatedLearningManager(
        ...     yaml_config="default_yaml_config.yaml",
        ...     model=network,
        ...     sync_frequency=100,
        ...     http_server_address="127.0.0.1:10086",
        ... )
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
        >>> net_opt = nn.Momentum(network.trainable_params(), 0.001, 0.9)
        >>> model = Model(network, net_loss, net_opt)
        >>> dataset = create_dataset_from_folder("0/train/", 32, 16, 1)
        >>> model.train(100, dataset, callbacks=[federated_learning_manager], dataset_sink_mode=False)
    """

    def __init__(self, yaml_config, model, sync_frequency, http_server_address="", data_size=1, sync_type='fixed',
                 run_distribute=False, ssl_config=None, **kwargs):
        super().__init__()
        # Validate every user-supplied argument before anything with side effects runs, so that a bad
        # argument is reported as such instead of failing inside `model.trainable_params()` or after
        # the federated worker has already been started.
        check_type.check_str("yaml_config", yaml_config)
        validator.check_isinstance('model', model, nn.Cell)
        validator.check_positive_int(sync_frequency)
        validator.check_string(sync_type, ["fixed", "adaptive"])

        init_ssl_config(ssl_config)
        load_yaml_config(yaml_config, _fl_context.ROLE_OF_SERVER)

        ctx = FLContext.get_instance()
        server_mode = ctx.server_mode()
        aggregation_type = ctx.aggregation_type()
        encrypt_type = ctx.encrypt_type()
        ctx.set_http_server_address(http_server_address)

        # The worker is initialised with every trainable parameter flattened to one dimension.
        initial_model = {param.name: np.reshape(param.asnumpy(), -1) for param in model.trainable_params()}
        Federated_.init_federated_worker(initial_model)

        self._server_mode = server_mode
        self._model = model
        self._sync_frequency = sync_frequency
        self._next_begin_sync_iter = 1
        self._next_end_sync_iter = self._sync_frequency
        self._data_size = data_size
        self._sync_type = sync_type
        self._run_distribute = run_distribute
        if self._run_distribute:
            init()
            self._broadcast = BroadcastNet()
            # NOTE: `_rank_id` and `_broadcast` only exist when `run_distribute` is truthy.
            self._rank_id = get_rank()
            logger.info(f"Rank id is {self._rank_id}")

        self._global_step = 0
        self._aggregation_type = aggregation_type
        self._global_prefix = "global_weights"

        if self._aggregation_type not in _fl_context.SUPPORT_AGG_TYPES and \
                self._server_mode == _fl_context.SERVER_MODE_CLOUD:
            raise ValueError(
                "aggregation_type must be in {}, but got {}.".format(_fl_context.SUPPORT_AGG_TYPES,
                                                                     self._aggregation_type))

        if self._aggregation_type in (_fl_context.FEDPROX, _fl_context.FEDNOVA):
            # Keep a non-trainable copy of the trainable parameters inside the model under the global prefix.
            self._global_weights = ParameterTuple(self._model.trainable_params()).clone(prefix=self._global_prefix)
            for param in self._global_weights:
                param.requires_grad = False
                self._model.insert_param_to_cell(param.name, param, False)

        self._encrypt_type = encrypt_type
        if self._encrypt_type not in _fl_context.SUPPORT_ENC_TYPES_CLOUD and \
                self._server_mode == _fl_context.SERVER_MODE_CLOUD:
            raise ValueError(
                "encrypt_mode must be in {}, but got {}.".format(_fl_context.SUPPORT_ENC_TYPES_CLOUD,
                                                                 self._encrypt_type))

        if self._is_adaptive_sync():
            self._as_set_init_state(kwargs)
            self._as_wrap_cell()

        logger.info(f"Step number needs to run per iteration {self._next_end_sync_iter},"
                    f"server mode {self._server_mode}, aggregation type {self._aggregation_type},"
                    f"encrypt type {self._encrypt_type}, http server address {http_server_address}")

        # Prefer the parameter names held by optimizers; fall back to all trainable parameters.
        self._fl_param_names = _get_fl_param_names(self._model, [])
        if not self._fl_param_names:
            self._fl_param_names = [_.name for _ in self._model.trainable_params()]

        self._last_params = {}
        self._local_control_params = {}
        self._global_control_params = {}
        self._scaffold_prefix = "control."
        if self._is_scaffold():
            for param in self._model.trainable_params():
                if param.name in self._fl_param_names:
                    self._last_params[param.name] = deepcopy(param.asnumpy())
                    self._local_control_params[param.name] = np.zeros_like(param.asnumpy())
                    self._global_control_params[param.name] = np.zeros_like(param.asnumpy())

    def __del__(self):
        Federated_.stop_federated_worker()

    def _is_adaptive_sync(self):
        """
        Determine whether adaptive frequency synchronization is required.
        """
        return self._sync_type == "adaptive"

    def _is_scaffold(self):
        """
        Determine whether scaffold is required.
        """
        return self._aggregation_type == _fl_context.SCAFFOLD

    def _is_fednova(self):
        """
        Determine whether FedNova is required.
        """
        return self._aggregation_type == _fl_context.FEDNOVA

    def _as_set_init_state(self, kwargs):
        """
        Setting the initial state for adaptive synchronization.

        Args:
            kwargs (dict): Extra keyword arguments of the constructor; the adaptive-synchronization
                options are read from it, each with its documented default.
        """
        self._as_prefix = "as_abs_grad."

        self._min_consistent_rate = kwargs.get("min_consistent_rate", 1.1)
        validator.check_non_negative_float(self._min_consistent_rate)
        self._min_consistent_rate_at_round = kwargs.get("min_consistent_rate_at_round", 0)
        validator.check_non_negative_int(self._min_consistent_rate_at_round)
        self._ema_alpha = kwargs.get("ema_alpha", 0.5)
        validator.check_float_range(self._ema_alpha, 0.0, 1.0, validator.INC_NEITHER)
        self._observation_window_size = kwargs.get("observation_window_size", 5)
        validator.check_positive_int(self._observation_window_size)
        self._frequency_increase_ratio = kwargs.get("frequency_increase_ratio", 2)
        validator.check_positive_int(self._frequency_increase_ratio)
        self._unchanged_round = kwargs.get("unchanged_round", 0)
        validator.check_non_negative_int(self._unchanged_round)

        self._round_id = 0
        self._as_set_last_param()

        # Total element count of all non-prefixed trainable parameters.
        # `np.prod` is used because `np.product` was removed in NumPy 2.0.
        model_size = 0
        for param in self._model.trainable_params():
            if self._as_prefix not in param.name:
                model_size += np.prod(param.shape)
        self._model_size = float(model_size)

        self._grads_ema = {}
        self._abs_grads_ema = {}
        self._as_reset_ema()

    def _as_reset_ema(self):
        """
        Reset both exponential moving averages to zeros for every non-prefixed trainable parameter.
        """
        for param in self._model.trainable_params():
            if self._as_prefix not in param.name:
                self._grads_ema[param.name] = np.zeros(param.shape)
                self._abs_grads_ema[param.name] = np.zeros(param.shape)

    def _as_wrap_cell(self):
        """
        Wrap Cell for adaptive synchronization.

        Every trainable parameter is cloned under the adaptive-sync prefix and inserted into the model.
        """
        param_list = []
        for param in self._model.trainable_params():
            new_param = param.clone()
            new_param.name = self._as_prefix + param.name
            param_list.append(new_param)
        # Insert after the loop above so that the clones are not themselves cloned.
        for param in param_list:
            self._model.insert_param_to_cell(param.name, param, False)

    def _as_set_grads(self):
        """
        Set the absolute value of the gradient for adaptive synchronization.

        The "gradient" here is the difference between the current value of a parameter and its value
        stored in `_last_param`.
        """
        abs_grads = {}
        for param in self._model.trainable_params():
            if self._as_prefix not in param.name:
                abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
        for param in self._model.trainable_params():
            if self._as_prefix in param.name:
                param.set_data(Parameter(abs_grads[param.name]))

    def _as_analyze_gradient(self):
        """
        Analysis of relevant statistics based on gradient for adaptive synchronization.

        Updates the moving averages, derives a consistency rate from them and, depending on that rate,
        either records a new minimum or reduces `_sync_frequency`.
        """
        ctx = FLContext.get_instance()
        worker_num = int(ctx.start_fl_job_threshold() * ctx.update_model_ratio())
        ema_alpha = self._ema_alpha
        consistent_rate_sum = 0.0
        grads = {}
        abs_grads = {}
        for param in self._model.trainable_params():
            if self._as_prefix in param.name:
                abs_grads[param.name.replace(self._as_prefix, '')] = param.asnumpy() * worker_num
            else:
                grads[param.name] = (param.asnumpy() - self._last_param[param.name]) * worker_num
        for last_p in self._last_param:
            self._grads_ema[last_p] = ema_alpha * self._grads_ema[last_p] + (1 - ema_alpha) * grads[last_p]
            self._abs_grads_ema[last_p] = ema_alpha * self._abs_grads_ema[last_p] + \
                                          (1 - ema_alpha) * abs_grads[last_p]
            # Replace zeros in the denominator by ones to avoid dividing by zero.
            divide_base = np.where(self._abs_grads_ema[last_p] == 0,
                                   np.ones(self._abs_grads_ema[last_p].shape), self._abs_grads_ema[last_p])
            layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
            consistent_rate_sum += np.sum(layer_consistent_rate)

        consistent_rate = float(consistent_rate_sum / self._model_size)

        if self._min_consistent_rate > consistent_rate:
            self._min_consistent_rate = consistent_rate
            self._min_consistent_rate_at_round = self._round_id
        elif self._round_id - self._min_consistent_rate_at_round > self._observation_window_size and \
                self._sync_frequency > 1 and self._round_id > self._unchanged_round:
            # Ceiling division of the current frequency by the increase ratio.
            self._sync_frequency = (self._sync_frequency + self._frequency_increase_ratio - 1) \
                                   // self._frequency_increase_ratio
            self._min_consistent_rate = 1.1
            self._min_consistent_rate_at_round = self._round_id
            self._observation_window_size *= self._frequency_increase_ratio
            self._as_reset_ema()

    def _as_set_last_param(self):
        """
        Set the value of last parameters for adaptive synchronization.
        """
        self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params()
                            if self._as_prefix not in _.name}

    def _start_pull_weight(self):
        """
        Pull weight from server in hybrid training mode.

        Raises:
            ValueError: If the pulled weights are empty.
        """
        logger.info("Try to pull weights. Local step number: {}".format(self._global_step))
        # The worker has to train self._sync_frequency standalone iterations before it communicates with server.
        # NOTE(review): when `_sync_frequency` is 1 the remainder is always 0, so it never equals
        # TRAIN_BEGIN_STEP_NUM (1) and this method always returns here -- confirm that is intended.
        if self._global_step % self._sync_frequency != TRAIN_BEGIN_STEP_NUM:
            return

        pull_weight_params = _get_fl_param_names(self._model, [], True)
        if not pull_weight_params:
            pull_weight_params = [_.name for _ in self._model.trainable_params()]
        weight_infos = {}
        for param in self._model.trainable_params():
            if param.name not in pull_weight_params:
                continue
            param_np = param.asnumpy()
            # Only float32 parameters are loaded back from the pulled weights.
            if param_np.dtype != np.float32:
                continue
            weight_infos[param.name] = (param_np.shape, param_np.dtype)

        pull_weight = _PullWeight(pull_weight_params)
        weights = pull_weight.construct()
        if not weights:
            raise ValueError("Weights from pulling weight is empty!")
        parameter_dict = {}
        for key, value in weights.items():
            if key not in weight_infos:
                continue
            shape, dtype = weight_infos[key]
            param_data = np.reshape(value, shape).astype(dtype)
            parameter_dict[key] = Parameter(Tensor(param_data), name=key)
        load_param_into_net(self._model, parameter_dict)

    def _update_model_with_distribute(self, weights, weight_infos):
        """
        Update model with distributed training mode.

        Rank 0 exchanges the model through `_UpdateAndGetModel` and broadcasts every resulting tensor;
        the other ranks take part in the broadcast and load what they receive.

        Args:
            weights (dict): Flattened local weights keyed by parameter name.
            weight_infos (dict): Maps parameter name to a (shape, dtype) pair.

        Raises:
            ValueError: If rank 0 gets an empty feature map.
        """
        if self._rank_id == 0:
            update_and_get_model = _UpdateAndGetModel(weights)
            feature_map = update_and_get_model.construct()
            if not feature_map:
                raise ValueError("Feature map from getting model is empty!")
            parameter_dict = {}
            for key, weight_info in weight_infos.items():
                # NOTE(review): skipping a key here means rank 0 issues no broadcast for it while the
                # other ranks below still broadcast once per key -- confirm this cannot happen in practice.
                if not feature_map[key]:
                    continue
                value = feature_map[key]
                shape, dtype = weight_info[0], weight_info[1]
                param_data = np.reshape(value, shape).astype(dtype)
                tensor = Tensor(param_data)
                parameter_dict[key] = Parameter(tensor, name=key)
                self._broadcast(tensor)
            load_param_into_net(self._model, parameter_dict)
        else:
            parameter_dict = {}
            for key, weight_info in weight_infos.items():
                value = weights[key]
                shape, dtype = weight_info[0], weight_info[1]
                param_data = np.reshape(value, shape).astype(dtype)
                received_tensor = self._broadcast(Tensor(param_data))
                parameter_dict[key] = Parameter(received_tensor[0], name=key)
            load_param_into_net(self._model, parameter_dict)

    def _update_model(self, weights, weight_infos):
        """
        Update and get model without distributed training mode.

        Args:
            weights (dict): Flattened local weights keyed by parameter name.
            weight_infos (dict): Maps parameter name to a (shape, dtype) pair.

        Raises:
            ValueError: If the returned feature map is empty.
        """
        update_and_get_model = _UpdateAndGetModel(weights)
        feature_map = update_and_get_model.construct()
        if not feature_map:
            raise ValueError("Feature map from getting model is empty!")
        parameter_dict = {}
        parameter_dict_global = {}
        for key, value in feature_map.items():
            if key not in weight_infos:
                continue
            shape, dtype = weight_infos[key]
            param_data = np.reshape(value, shape).astype(dtype)
            parameter_dict[key] = Parameter(Tensor(param_data), name=key)
            global_key = self._global_prefix + "." + key
            parameter_dict_global[global_key] = Parameter(Tensor(param_data), name=global_key)
        load_param_into_net(self._model, parameter_dict)
        if self._aggregation_type in (_fl_context.FEDPROX, _fl_context.FEDNOVA):
            load_param_into_net(self._model, parameter_dict_global)

    def _start_push_weight(self):
        """
        Push weight to server in hybrid training mode.
        """
        logger.info("Try to push weights. Local step number: {}".format(self._global_step))
        if self._global_step % self._sync_frequency != TRAIN_END_STEP_NUM:
            return

        push_weight_params = _get_fl_param_names(self._model, [], True)
        if not push_weight_params:
            push_weight_params = [_.name for _ in self._model.trainable_params()]
        weights = {}
        for param in self._model.trainable_params():
            if param.name not in push_weight_params:
                continue
            weights[param.name] = param.asnumpy().reshape(-1).tolist()
        push_weight = _PushWeight(weights)
        push_weight.construct()

    def _scaffold_set_global_control_params(self, flattened_control_params):
        """
        Replace the global control parameters by the values received in `flattened_control_params`.

        Args:
            flattened_control_params (dict): Maps the scaffold prefix plus a parameter name to flattened values.

        Raises:
            ValueError: If a prefixed name is missing from `flattened_control_params`.
        """
        for name in self._global_control_params:
            control_name = self._scaffold_prefix + name
            if control_name not in flattened_control_params:
                raise ValueError("'{}' is not in control parameters sent by server".format(control_name))
            global_control_param = deepcopy(flattened_control_params[control_name])
            # Restore the shape that the stored control parameter already has.
            shape = self._global_control_params[name].shape
            self._global_control_params[name] = np.array(global_control_param, dtype=np.float32).reshape(shape)

    def _scaffold_update_params(self, lr):
        """
        Using control parameters to update parameters every step.

        Args:
            lr: Learning rate by which the control-parameter difference is scaled.

        Raises:
            ValueError: If a federated parameter has no global or local control parameter.
        """
        for param in self._model.trainable_params():
            name = param.name
            if name not in self._fl_param_names:
                continue
            if name not in self._global_control_params:
                raise ValueError("'{}' is not in global_control_params".format(name))
            global_control_param = self._global_control_params[name]
            if name not in self._local_control_params:
                raise ValueError("'{}' is not in local_control_params".format(name))
            local_control_param = self._local_control_params[name]
            control_params = lr * (global_control_param - local_control_param)
            param.set_data(Tensor(param.asnumpy() - control_params))

    def _scaffold_get_control_params(self, lr):
        """
        Get updated control parameters.

        Args:
            lr: Learning rate used to normalise the parameter change since the last synchronization.

        Returns:
            dict, maps each federated parameter name to its updated control parameter.

        Raises:
            ValueError: If a federated parameter is missing from the local control parameters, the global
                control parameters or the last parameters.
        """
        control_params = {}
        for param in self._model.trainable_params():
            name = param.name
            if name not in self._fl_param_names:
                continue
            if name not in self._local_control_params:
                raise ValueError("'{}' is not in local_control_params".format(name))
            local_control_param = deepcopy(self._local_control_params[name])
            if name not in self._global_control_params:
                raise ValueError("'{}' is not in global_control_params".format(name))
            global_control_param = deepcopy(self._global_control_params[name])
            temp1 = local_control_param - global_control_param
            if name not in self._last_params:
                raise ValueError("'{}' is not in last_params".format(name))
            temp2 = (self._last_params[name] - param.asnumpy()) / (self._sync_frequency * lr)
            control_params[name] = temp1 + temp2
        return control_params

    def _scaffold_set_last_params_and_local_control_params(self, control_params):
        """
        Snapshot the current federated parameters and store `control_params` as local control parameters.

        Args:
            control_params (dict): Maps parameter name to its new local control parameter.

        Raises:
            ValueError: If a federated parameter is missing from `control_params`.
        """
        for param in self._model.trainable_params():
            name = param.name
            if name not in self._fl_param_names:
                continue
            self._last_params[name] = deepcopy(param.asnumpy())
            if name not in control_params:
                raise ValueError("'{}' is not in control_params".format(name))
            self._local_control_params[name] = control_params[name]

    def _model_params_to_weights_dict(self, weights, weight_infos):
        """Extract trainable params from model, then fill into weights and weights_infos"""
        for param in self._model.trainable_params():
            if self._global_prefix in param.name:
                continue
            param_np = param.asnumpy()
            # Only float32 parameters are uploaded.
            if param_np.dtype != np.float32:
                continue
            weight_infos[param.name] = (param_np.shape, param_np.dtype)
            weights[param.name] = param_np.reshape(-1).tolist()

    def _model_params_to_weights_diff_dict(self, weights, weight_infos):
        """Extract trainable params from model, then calculate diff value for FedNova"""
        local_params = [x for x in self._model.get_parameters()
                        if self._global_prefix not in x.name and x.requires_grad]
        global_params = [x for x in self._model.get_parameters()
                         if self._global_prefix in x.name and not x.requires_grad]
        # Local and global parameters are paired purely by their position in the two lists.
        for local_param, global_param in zip(local_params, global_params):
            param = local_param - global_param
            param_np = param.asnumpy()
            weight_infos[local_param.name] = (param_np.shape, param_np.dtype)
            weights[local_param.name] = param_np.reshape(-1).tolist()

    def on_train_step_begin(self, run_context):
        """
        Advance the step counter and, in cloud mode at the first step of a round, start the FL job.

        Args:
            run_context: Callback context; its `original_args()` is read in FedNova mode.
        """
        self._global_step += 1
        is_cloud = self._server_mode == _fl_context.SERVER_MODE_CLOUD
        is_sync = self._global_step == self._next_begin_sync_iter
        # Without distributed training every worker qualifies; with it only rank 0 does.
        # `_rank_id` is only read when `_run_distribute` is truthy, because it only exists then.
        is_dist = (not self._run_distribute) or self._rank_id == 0
        if is_cloud and is_sync and is_dist:
            # In FedNova mode, the upload _data_size will be reset to the number of training steps
            if self._is_fednova():
                cb_params = run_context.original_args()
                self._data_size = cb_params.batch_num * self._sync_frequency \
                    if cb_params.dataset_sink_mode else self._sync_frequency
            start_fl_job = _StartFLJob(self._data_size)
            flattened_control_params = start_fl_job.construct()
            if self._is_scaffold() and self._global_step != 1:
                self._scaffold_set_global_control_params(flattened_control_params)
        logger.debug("run_context is %r", run_context)

    def on_train_step_end(self, run_context):
        """
        Synchronize with the server at the end of a step when the current step requires it.

        In cloud mode the model is uploaded and replaced at the last step of a round; in hybrid mode
        `_start_pull_weight` and `_start_push_weight` are invoked every step and decide on their own.

        Args:
            run_context: Callback context; its `original_args()` provides the training state.

        Raises:
            ValueError: If scaffold is used and no optimizer is found in the train network.
        """
        lr = 0.0
        if self._is_scaffold():
            cb_params = run_context.original_args()
            train_network = cb_params.train_network
            lr = _get_lr(train_network)
            if lr is None:
                raise ValueError("Can not find optimizer in train network!")
            self._scaffold_update_params(lr)

        if self._server_mode == _fl_context.SERVER_MODE_CLOUD:
            if self._global_step == self._next_end_sync_iter:
                if self._is_adaptive_sync():
                    self._as_set_grads()
                if self._encrypt_type == _fl_context.ENCRYPT_STABLE_PW:
                    exchange_keys = _ExchangeKeys()
                    exchange_keys.construct()
                    get_keys = _GetKeys()
                    get_keys.construct()
                control_params = {}
                if self._is_scaffold():
                    control_params = self._scaffold_get_control_params(lr)
                weights = {}
                weight_infos = {}
                if self._is_fednova():
                    self._model_params_to_weights_diff_dict(weights, weight_infos)
                else:
                    self._model_params_to_weights_dict(weights, weight_infos)
                if self._is_scaffold():
                    # Upload the change of each control parameter under the scaffold prefix.
                    for name in control_params:
                        delta_control_param = control_params[name] - self._local_control_params[name]
                        weights[self._scaffold_prefix + name] = delta_control_param.reshape(-1).tolist()
                if self._run_distribute:
                    self._update_model_with_distribute(weights, weight_infos)
                else:
                    self._update_model(weights, weight_infos)
                if self._is_scaffold():
                    self._scaffold_set_last_params_and_local_control_params(control_params)
                logger.info("Load params from getting model into net, global step is {}.".format(self._global_step))
                # The next round is scheduled before the adaptive analysis below may change the frequency.
                self._next_end_sync_iter = self._global_step + self._sync_frequency
                self._next_begin_sync_iter = self._global_step + 1
                if self._is_adaptive_sync():
                    self._as_analyze_gradient()
                    self._round_id += 1
                    self._as_set_last_param()
                cb_params = run_context.original_args()
                logger.info(
                    "total epoch num:{}, batch num:{}, Current epoch num is: {}, Current step num is: {}".format(
                        cb_params.epoch_num, cb_params.batch_num, cb_params.cur_epoch_num,
                        cb_params.cur_step_num))
        elif self._server_mode == _fl_context.SERVER_MODE_HYBRID:
            self._start_pull_weight()
            self._start_push_weight()