From aa11dc79e9ce24bbe8383135057dfb45d0bd0aa4 Mon Sep 17 00:00:00 2001
From: Andreas Hellander
Date: Fri, 10 Jan 2025 12:10:09 +0100
Subject: [PATCH] Refactor FedOpt

---
 fedn/network/combiner/aggregators/fedopt.py | 193 ++++++++++----------
 1 file changed, 100 insertions(+), 93 deletions(-)

diff --git a/fedn/network/combiner/aggregators/fedopt.py b/fedn/network/combiner/aggregators/fedopt.py
index ff83b0c76..734b91dba 100644
--- a/fedn/network/combiner/aggregators/fedopt.py
+++ b/fedn/network/combiner/aggregators/fedopt.py
@@ -1,10 +1,13 @@
 import math
 import time
 import traceback
+from typing import Any, Dict, Optional, Tuple
 
 from fedn.common.exceptions import InvalidParameterError
 from fedn.common.log_config import logger
 from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase
+from fedn.utils.helpers.helperbase import HelperBase
+from fedn.utils.parameters import Parameters
 
 
 class Aggregator(AggregatorBase):
@@ -17,6 +20,9 @@ class Aggregator(AggregatorBase):
     A server-side scheme is then applied, currenty supported schemes are
     "adam", "yogi", "adagrad".
 
+    Limitations:
+        - Only supports one combiner.
+        - Momentum is reset for each new invocation of a training session.
 
     :param control: A handle to the :class: `fedn.network.combiner.updatehandler.UpdateHandler`
     :type control: class: `fedn.network.combiner.updatehandler.UpdateHandler`
@@ -31,43 +37,23 @@ def __init__(self, update_handler):
         self.v = None
         self.m = None
 
-    def combine_models(self, helper=None, delete_models=True, parameters=None):
-        """Compute pseudo gradients using model updates in the queue.
-
-        :param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase`, ML framework specific helper, defaults to None
-        :type helper: class: `fedn.utils.helpers.helpers.HelperBase`, optional
-        :param time_window: The time window for model aggregation, defaults to 180
-        :type time_window: int, optional
-        :param max_nr_models: The maximum number of updates aggregated, defaults to 100
-        :type max_nr_models: int, optional
-        :param delete_models: Delete models from storage after aggregation, defaults to True
-        :type delete_models: bool, optional
-        :param parameters: Aggregator hyperparameters.
-        :type parameters: `fedn.utils.parmeters.Parameters`, optional
-        :return: The global model and metadata
-        :rtype: tuple
+    def combine_models(
+        self,
+        helper: Optional[HelperBase] = None,
+        delete_models: bool = True,
+        parameters: Optional[Parameters] = None
+    ) -> Tuple[Optional[Any], Dict[str, float]]:
         """
-        data = {}
-        data["time_model_load"] = 0.0
-        data["time_model_aggregation"] = 0.0
+        Compute pseudo gradients using model updates in the queue.
 
-        # Define parameter schema
-        parameter_schema = {
-            "serveropt": str,
-            "learning_rate": float,
-            "beta1": float,
-            "beta2": float,
-            "tau": float,
-        }
-
-        try:
-            parameters.validate(parameter_schema)
-        except InvalidParameterError as e:
-            logger.error(
-                "Aggregator {} recieved invalid parameters. Reason {}".format(self.name, e))
-            return None, data
+        :param helper: ML framework-specific helper, defaults to None.
+        :param delete_models: Delete models from storage after aggregation, defaults to True.
+        :param parameters: Aggregator hyperparameters, defaults to None.
+        :return: The global model and metadata.
+        """
+        data = {"time_model_load": 0.0, "time_model_aggregation": 0.0}
 
-        # Default hyperparameters. Note that these may need fine tuning.
+        # Default hyperparameters
         default_parameters = {
             "serveropt": "adam",
             "learning_rate": 1e-3,
@@ -76,101 +62,122 @@ def combine_models(self, helper=None, delete_models=True, parameters=None):
             "tau": 1e-4,
         }
 
-        # Validate parameters
-        if parameters:
-            try:
-                parameters.validate(parameter_schema)
-            except InvalidParameterError as e:
-                logger.error(
-                    "Aggregator {} recieved invalid parameters. Reason {}".format(self.name, e))
-                return None, data
-        else:
-            logger.info("Aggregator {} using default parameteres.",
-                        format(self.name))
-            parameters = self.default_parameters
+        # Validate and merge parameters
+        try:
+            parameters = self._validate_and_merge_parameters(
+                parameters, default_parameters)
+        except InvalidParameterError as e:
+            logger.error(
+                f"Aggregator {self.name} received invalid parameters: {e}")
+            return None, data
 
-        # Override missing paramters with defaults
-        for key, value in default_parameters.items():
-            if key not in parameters:
-                parameters[key] = value
+        logger.info(f"Aggregator {self.name} starting model aggregation.")
 
         # Aggregation initialization
-        model, pseudo_gradient = None, None
+        pseudo_gradient, model_old = None, None
         nr_aggregated_models, total_examples = 0, 0
 
-        logger.info(
-            "AGGREGATOR({}): Aggregating model updates... ".format(self.name))
-
         while not self.update_handler.model_updates.empty():
             try:
                 logger.info(
-                    "AGGREGATOR({}): Getting next model update from queue.".format(self.name))
+                    f"Aggregator {self.name}: Fetching next model update.")
                 model_update = self.update_handler.next_model_update()
 
-                # Load model paratmeters and metadata
                 tic = time.time()
                 model_next, metadata = self.update_handler.load_model_update(
                     model_update, helper)
-                data["time_model_load"] += time.time()-tic
+                data["time_model_load"] += time.time() - tic
 
-                logger.info("AGGREGATOR({}): Processing model update {}".format(
-                    self.name, model_update.model_update_id))
+                logger.info(
+                    f"Processing model update {model_update.model_update_id}")
 
-                # Increment total number of examples
+                # Increment total examples
                 total_examples += metadata["num_examples"]
 
                 tic = time.time()
-                if nr_aggregated_models == 0:
-                    model_old = self.update_handler.load_model(
-                        helper, model_update.model_id)
-                    pseudo_gradient = helper.subtract(model_next, model_old)
-                else:
-                    pseudo_gradient_next = helper.subtract(
-                        model_next, model_old)
-                    pseudo_gradient = helper.increment_average(
-                        pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples)
-                data["time_model_aggregation"] += time.time()-tic
+                pseudo_gradient, model_old = self._update_pseudo_gradient(
+                    helper, pseudo_gradient, model_next, model_old, metadata, nr_aggregated_models, total_examples
+                )
+                data["time_model_aggregation"] += time.time() - tic
 
                 nr_aggregated_models += 1
 
-                # Delete model from storage
+
                 if delete_models:
                     self.update_handler.delete_model(model_update)
-                    logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(
-                        self.name, model_update.model_update_id))
+                    logger.info(
+                        f"Deleted model update {model_update.model_update_id} from storage.")
+
             except Exception as e:
                 logger.error(
-                    "AGGREGATOR({}): Error encoutered while processing model update {}, skipping this update.".format(self.name, e))
+                    f"Error processing model update: {e}. Skipping this update.")
+                logger.error(traceback.format_exc())
+                continue
 
         data["nr_aggregated_models"] = nr_aggregated_models
 
         if pseudo_gradient:
             try:
-                if parameters["serveropt"] == "adam":
-                    model = self.serveropt_adam(
-                        helper, pseudo_gradient, model_old, parameters)
-                elif parameters["serveropt"] == "yogi":
-                    model = self.serveropt_yogi(
-                        helper, pseudo_gradient, model_old, parameters)
-                elif parameters["serveropt"] == "adagrad":
-                    model = self.serveropt_adagrad(
-                        helper, pseudo_gradient, model_old, parameters)
-                else:
-                    logger.error(
-                        "Unsupported server optimizer passed to FedOpt.")
-                    return None, data
+                model = self._apply_server_optimizer(
+                    helper, pseudo_gradient, model_old, parameters)
             except Exception as e:
-                tb = traceback.format_exc()
-                logger.error(
-                    "AGGREGATOR({}): Error encoutered while while aggregating: {}".format(self.name, e))
-                logger.error(tb)
+                logger.error(f"Error during model aggregation: {e}")
+                logger.error(traceback.format_exc())
                 return None, data
         else:
             return None, data
 
-        logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(
-            self.name, nr_aggregated_models))
+        logger.info(
+            f"Aggregator {self.name} completed. Aggregated {nr_aggregated_models} models.")
 
         return model, data
 
+    def _validate_and_merge_parameters(
+        self, parameters: Optional[Parameters], default_parameters: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Validate and merge default parameters."""
+        parameter_schema = {
+            "serveropt": str,
+            "learning_rate": float,
+            "beta1": float,
+            "beta2": float,
+            "tau": float,
+        }
+        if parameters:
+            parameters.validate(parameter_schema)
+        else:
+            logger.info(f"Aggregator {self.name} using default parameters.")
+            parameters = {}
+        return {**default_parameters, **parameters}
+
+    def _update_pseudo_gradient(
+        self, helper: HelperBase, pseudo_gradient: Any, model_next: Any, model_old: Any,
+        metadata: Dict[str, Any], nr_aggregated_models: int, total_examples: int
+    ) -> Tuple[Any, Any]:
+        """Update pseudo gradient based on the current model."""
+        if nr_aggregated_models == 0:
+            model_old = self.update_handler.load_model(
+                helper, metadata["model_id"])
+            pseudo_gradient = helper.subtract(model_next, model_old)
+        else:
+            pseudo_gradient_next = helper.subtract(model_next, model_old)
+            pseudo_gradient = helper.increment_average(
+                pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples
+            )
+        return pseudo_gradient, model_old
+
+    def _apply_server_optimizer(
+        self, helper: HelperBase, pseudo_gradient: Any, model_old: Any, parameters: Dict[str, Any]
+    ) -> Any:
+        """Apply the selected server optimizer to compute the new model."""
+        optimizer_map = {
+            "adam": self.serveropt_adam,
+            "yogi": self.serveropt_yogi,
+            "adagrad": self.serveropt_adagrad,
+        }
+        optimizer = optimizer_map.get(parameters["serveropt"])
+        if not optimizer:
+            raise ValueError(
+                f"Unsupported server optimizer: {parameters['serveropt']}")
+        return optimizer(helper, pseudo_gradient, model_old, parameters)
+
     def serveropt_adam(self, helper, pseudo_gradient, model_old, parameters):
         """Server side optimization, FedAdam.
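
The patch excerpt ends just as `serveropt_adam` begins, so the optimizer bodies are not shown. For orientation, a minimal sketch of the FedAdam server-side step that `_apply_server_optimizer` dispatches to for `serveropt == "adam"` follows (Reddi et al., "Adaptive Federated Optimization"). It is illustrative only: the function name `fedadam_step` is hypothetical, it operates on plain lists of numpy arrays rather than going through FedN's `helper` abstraction, and the moment buffers stand in for the `self.m`/`self.v` attributes initialized in `__init__`.

    import numpy as np


    def fedadam_step(pseudo_gradient, model_old, parameters, m=None, v=None):
        """One FedAdam server step: x_new = x_old + lr * m / (sqrt(v) + tau).

        pseudo_gradient and model_old are lists of numpy arrays; m and v
        are the first/second moment buffers, zero-initialized on first use.
        """
        b1, b2 = parameters["beta1"], parameters["beta2"]
        lr, tau = parameters["learning_rate"], parameters["tau"]

        # Fresh moment buffers at the start of a session (momentum is
        # reset per session, as noted in the class docstring).
        if m is None:
            m = [np.zeros_like(g) for g in pseudo_gradient]
        if v is None:
            v = [np.zeros_like(g) for g in pseudo_gradient]

        model_new = []
        for i, (g, x) in enumerate(zip(pseudo_gradient, model_old)):
            m[i] = b1 * m[i] + (1.0 - b1) * g              # momentum estimate
            v[i] = b2 * v[i] + (1.0 - b2) * np.square(g)   # second-moment estimate
            model_new.append(x + lr * m[i] / (np.sqrt(v[i]) + tau))
        return model_new, m, v

Since the pseudo gradient is computed as `helper.subtract(model_next, model_old)`, adding `lr * m / (sqrt(v) + tau)` moves the global model toward the weighted average of client updates; with the defaults above (`learning_rate=1e-3`, `beta1=0.9`, `beta2=0.99`, `tau=1e-4`), `tau` acts as the adaptivity floor in the denominator.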