Source code for mindspore.train.callback._fl_manager

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FederatedLearningManager related class and functions."""

from copy import deepcopy
import numpy as np
from mindspore import context, nn
from mindspore.common import Parameter, ParameterTuple
from mindspore.train.callback import Callback
from mindspore.ops import operations as P
from mindspore._checkparam import Validator, Rel


class _StartFLJob(nn.Cell):
    """
    StartFLJob for Federated Learning Worker.
    """

    def __init__(self, data_size):
        super(_StartFLJob, self).__init__()
        self.start_fl_job = P.StartFLJob(data_size)

    def construct(self):
        succ = self.start_fl_job()
        return succ


class _UpdateAndGetModel(nn.Cell):
    """
    Update and Get Model for Federated Learning Worker.
    """

    def __init__(self, weights, encrypt_type=""):
        super(_UpdateAndGetModel, self).__init__()
        self.update_model = P.UpdateModel(encrypt_type)
        self.get_model = P.GetModel()
        self.weights = weights

    def construct(self):
        self.update_model(self.weights)
        succ = self.get_model(self.weights)
        return succ


class _ExchangeKeys(nn.Cell):
    """
    Exchange Keys for Stable PW Encrypt.
    """

    def __init__(self):
        super(_ExchangeKeys, self).__init__()
        self.exchange_keys = P.ExchangeKeys()

    def construct(self):
        return self.exchange_keys()


class _GetKeys(nn.Cell):
    """
    Get Keys for Stable PW Encrypt.
    """

    def __init__(self):
        super(_GetKeys, self).__init__()
        self.get_keys = P.GetKeys()

    def construct(self):
        return self.get_keys()


[docs]class FederatedLearningManager(Callback): """ Manage Federated Learning during training. Args: model (nn.Cell): A training model. sync_frequency (int): Synchronization frequency of parameters in Federated Learning. Note that in dataset sink mode, the unit of the frequency is the number of epochs. Otherwise, the unit of the frequency is the number of steps. sync_type (str): Parameter synchronization type in Federated Learning. Supports ["fixed", "adaptive"]. Default: "fixed". - fixed: The frequency of parameter synchronization is fixed. - adaptive: The frequency of parameter synchronization changes adaptively. Note: This is an experimental prototype that is subject to change. """ def __init__(self, model, sync_frequency, sync_type='fixed', **kwargs): super(FederatedLearningManager, self).__init__() server_mode = context.get_fl_context("server_mode") if server_mode not in ("FEDERATED_LEARNING", "HYBRID_TRAINING"): raise ValueError("server_mode must in (\"FEDERATED_LEARNING\", \"HYBRID_TRAINING\")") Validator.check_isinstance('model', model, nn.Cell) Validator.check_positive_int(sync_frequency) Validator.check_string(sync_type, ["fixed", "adaptive"]) self._model = model self._sync_frequency = sync_frequency self._next_sync_iter_id = self._sync_frequency self._sync_type = sync_type self._global_step = 0 self._data_size = 0 self._encrypt_type = kwargs.get("encrypt_type", "NOT_ENCRYPT") if self._encrypt_type != "NOT_ENCRYPT" and self._encrypt_type != "STABLE_PW_ENCRYPT": raise ValueError( "encrypt_mode must be 'NOT_ENCRYPT' or 'STABLE_PW_ENCRYPT', but got {}.".format(self._encrypt_type)) if self._is_adaptive_sync(): self._as_set_init_state(kwargs) self._as_wrap_cell() def _is_adaptive_sync(self): """ Determine whether adaptive frequency synchronization is required. """ return self._sync_type == "adaptive" def _as_set_init_state(self, kwargs): """ Setting the initial state for adaptive synchronization. """ self._as_prefix = "as_abs_grad." self._min_consistent_rate = kwargs.get("min_consistent_rate", 1.1) Validator.check_non_negative_float(self._min_consistent_rate) self._min_consistent_rate_at_round = kwargs.get("min_consistent_rate_at_round", 0) Validator.check_non_negative_int(self._min_consistent_rate_at_round) self._ema_alpha = kwargs.get("ema_alpha", 0.5) Validator.check_float_range(self._ema_alpha, 0.0, 1.0, Rel.INC_NEITHER) self._observation_window_size = kwargs.get("observation_window_size", 5) Validator.check_positive_int(self._observation_window_size) self._frequency_increase_ratio = kwargs.get("frequency_increase_ratio", 2) Validator.check_positive_int(self._frequency_increase_ratio) self._unchanged_round = kwargs.get("unchanged_round", 0) Validator.check_non_negative_int(self._unchanged_round) self._round_id = 0 self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params() if self._as_prefix not in _.name} self._model_size = 0 self._grads_ema = dict() self._abs_grads_ema = dict() for param in self._model.trainable_params(): if self._as_prefix not in param.name: self._model_size += np.product(param.shape) self._grads_ema[param.name] = np.zeros(param.shape) self._abs_grads_ema[param.name] = np.zeros(param.shape) self._model_size = float(self._model_size) def _as_wrap_cell(self): """ Wrap Cell for adaptive synchronization. """ param_list = list() for param in self._model.trainable_params(): new_param = param.clone() new_param.name = self._as_prefix + param.name param_list.append(new_param) for param in param_list: self._model.insert_param_to_cell(param.name, param, False) def _as_set_grads(self): """ Set the absolute value of the gradient for adaptive synchronization. """ abs_grads = dict() for param in self._model.trainable_params(): if self._as_prefix not in param.name: abs_grads[self._as_prefix + param.name] = np.abs(param.asnumpy() - self._last_param[param.name]) for param in self._model.trainable_params(): if self._as_prefix in param.name: param.set_data(Parameter(abs_grads[param.name])) def _as_analyze_gradient(self): """ Analysis of relevant statistics based on gradient for adaptive synchronization. """ worker_num = context.get_fl_context("worker_num") ema_alpha = self._ema_alpha consistent_rate_sum = 0.0 grads = dict() abs_grads = dict() for param in self._model.trainable_params(): if self._as_prefix in param.name: abs_grads[param.name.replace(self._as_prefix, '')] = param.asnumpy() * worker_num else: grads[param.name] = (param.asnumpy() - self._last_param[param.name]) * worker_num for last_p in self._last_param: self._grads_ema[last_p] = ema_alpha * self._grads_ema[last_p] + (1 - ema_alpha) * grads[last_p] self._abs_grads_ema[last_p] = ema_alpha * self._abs_grads_ema[last_p] + (1 - ema_alpha) * abs_grads[last_p] divide_base = np.where(self._abs_grads_ema[last_p] == 0, np.ones(self._abs_grads_ema[last_p].shape), self._abs_grads_ema[last_p]) layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base consistent_rate_sum += np.sum(layer_consistent_rate) consistent_rate = float(consistent_rate_sum / self._model_size) if self._min_consistent_rate > consistent_rate: self._min_consistent_rate = consistent_rate self._min_consistent_rate_at_round = self._round_id elif self._round_id - self._min_consistent_rate_at_round > self._observation_window_size and \ self._sync_frequency > 1 and self._round_id > self._unchanged_round: self._sync_frequency = (self._sync_frequency + self._frequency_increase_ratio - 1) \ // self._frequency_increase_ratio self._min_consistent_rate = 1.1 self._min_consistent_rate_at_round = self._round_id self._observation_window_size *= self._frequency_increase_ratio for param in self._model.trainable_params(): if self._as_prefix not in param.name: self._grads_ema[param.name] = np.zeros(param.shape) self._abs_grads_ema[param.name] = np.zeros(param.shape) def _as_set_last_param(self): """ Set the value of last parameters for adaptive synchronization. """ self._last_param = {_.name: deepcopy(_.asnumpy()) for _ in self._model.trainable_params() if self._as_prefix not in _.name}
[docs] def step_end(self, run_context): """ Synchronization parameters at the end of step. If sync_type is "adaptive", the synchronous frequency is adaptively adjusted here. Args: run_context (RunContext): Include some information of the model. """ self._global_step += 1 cb_params = run_context.original_args() self._data_size += cb_params.batch_num if context.get_fl_context("ms_role") == "MS_WORKER": if self._global_step == self._next_sync_iter_id: start_fl_job = _StartFLJob(self._data_size) start_fl_job() self._data_size = 0 if self._is_adaptive_sync(): self._as_set_grads() if self._encrypt_type == "STABLE_PW_ENCRYPT": exchange_keys = _ExchangeKeys() exchange_keys() get_keys = _GetKeys() get_keys() update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()), self._encrypt_type) else: update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params())) update_and_get_model() self._next_sync_iter_id = self._global_step + self._sync_frequency if self._is_adaptive_sync(): self._as_analyze_gradient() self._round_id += 1 self._as_set_last_param() print("sync step is: {}".format(self._global_step))