Commit

Refactor FedOpt
Andreas Hellander committed Jan 10, 2025
1 parent e2a7180 commit aa11dc7
Showing 1 changed file with 100 additions and 93 deletions.
193 changes: 100 additions & 93 deletions fedn/network/combiner/aggregators/fedopt.py
@@ -1,10 +1,13 @@
import math
import time
import traceback
from typing import Any, Dict, Optional, Tuple

from fedn.common.exceptions import InvalidParameterError
from fedn.common.log_config import logger
from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase
from fedn.utils.helpers.helperbase import HelperBase
from fedn.utils.parameters import Parameters


class Aggregator(AggregatorBase):
@@ -17,6 +20,9 @@ class Aggregator(AggregatorBase):
A server-side scheme is then applied, currently supported schemes
are "adam", "yogi", "adagrad".
Limitations:
- Only supports one combiner.
- Momentum is reset for each new invocation of a training session.
:param control: A handle to the :class: `fedn.network.combiner.updatehandler.UpdateHandler`
:type control: class: `fedn.network.combiner.updatehandler.UpdateHandler`
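For reference, the three supported schemes share the hyperparameter schema validated further down in combine_models. A minimal sketch of overriding it, assuming Parameters can be constructed from a plain dict (its exact constructor is not shown in this diff):

from fedn.utils.parameters import Parameters

# Override only what differs from the defaults; unspecified keys
# (beta1, beta2, tau) are filled in by the aggregator itself.
fedopt_parameters = Parameters({"serveropt": "yogi", "learning_rate": 1e-2})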
@@ -31,43 +37,23 @@ def __init__(self, update_handler):
self.v = None
self.m = None

def combine_models(self, helper=None, delete_models=True, parameters=None):
"""Compute pseudo gradients using model updates in the queue.
:param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase`, ML framework specific helper, defaults to None
:type helper: class: `fedn.utils.helpers.helpers.HelperBase`, optional
:param time_window: The time window for model aggregation, defaults to 180
:type time_window: int, optional
:param max_nr_models: The maximum number of updates aggregated, defaults to 100
:type max_nr_models: int, optional
:param delete_models: Delete models from storage after aggregation, defaults to True
:type delete_models: bool, optional
:param parameters: Aggregator hyperparameters.
:type parameters: `fedn.utils.parameters.Parameters`, optional
:return: The global model and metadata
:rtype: tuple
def combine_models(
self,
helper: Optional[HelperBase] = None,
delete_models: bool = True,
parameters: Optional[Parameters] = None
) -> Tuple[Optional[Any], Dict[str, float]]:
"""
data = {}
data["time_model_load"] = 0.0
data["time_model_aggregation"] = 0.0
Compute pseudo gradients using model updates in the queue.
# Define parameter schema
parameter_schema = {
"serveropt": str,
"learning_rate": float,
"beta1": float,
"beta2": float,
"tau": float,
}

try:
parameters.validate(parameter_schema)
except InvalidParameterError as e:
logger.error(
"Aggregator {} recieved invalid parameters. Reason {}".format(self.name, e))
return None, data
:param helper: ML framework-specific helper, defaults to None.
:param delete_models: Delete models from storage after aggregation, defaults to True.
:param parameters: Aggregator hyperparameters, defaults to None.
:return: The global model and metadata.
"""
data = {"time_model_load": 0.0, "time_model_aggregation": 0.0}

# Default hyperparameters. Note that these may need fine tuning.
# Default hyperparameters
default_parameters = {
"serveropt": "adam",
"learning_rate": 1e-3,
@@ -76,101 +62,122 @@ def combine_models(self, helper=None, delete_models=True, parameters=None):
"tau": 1e-4,
}

# Validate parameters
if parameters:
try:
parameters.validate(parameter_schema)
except InvalidParameterError as e:
logger.error(
"Aggregator {} recieved invalid parameters. Reason {}".format(self.name, e))
return None, data
else:
logger.info("Aggregator {} using default parameteres.",
format(self.name))
parameters = self.default_parameters
# Validate and merge parameters
try:
parameters = self._validate_and_merge_parameters(
parameters, default_parameters)
except InvalidParameterError as e:
logger.error(
f"Aggregator {self.name} received invalid parameters: {e}")
return None, data

# Override missing parameters with defaults
for key, value in default_parameters.items():
if key not in parameters:
parameters[key] = value
logger.info(f"Aggregator {self.name} starting model aggregation.")

# Aggregation initialization
model, pseudo_gradient = None, None
pseudo_gradient, model_old = None, None
nr_aggregated_models, total_examples = 0, 0

logger.info(
"AGGREGATOR({}): Aggregating model updates... ".format(self.name))

while not self.update_handler.model_updates.empty():
try:
logger.info(
"AGGREGATOR({}): Getting next model update from queue.".format(self.name))
f"Aggregator {self.name}: Fetching next model update.")
model_update = self.update_handler.next_model_update()
# Load model parameters and metadata

tic = time.time()
model_next, metadata = self.update_handler.load_model_update(
model_update, helper)
data["time_model_load"] += time.time()-tic
data["time_model_load"] += time.time() - tic

logger.info("AGGREGATOR({}): Processing model update {}".format(
self.name, model_update.model_update_id))
logger.info(
f"Processing model update {model_update.model_update_id}")

# Increment total number of examples
# Increment total examples
total_examples += metadata["num_examples"]

tic = time.time()
if nr_aggregated_models == 0:
model_old = self.update_handler.load_model(
helper, model_update.model_id)
pseudo_gradient = helper.subtract(model_next, model_old)
else:
pseudo_gradient_next = helper.subtract(
model_next, model_old)
pseudo_gradient = helper.increment_average(
pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples)
data["time_model_aggregation"] += time.time()-tic
pseudo_gradient, model_old = self._update_pseudo_gradient(
helper, pseudo_gradient, model_next, model_old, metadata, nr_aggregated_models, total_examples
)
data["time_model_aggregation"] += time.time() - tic

nr_aggregated_models += 1
# Delete model from storage

if delete_models:
self.update_handler.delete_model(model_update)
logger.info("AGGREGATOR({}): Deleted model update {} from storage.".format(
self.name, model_update.model_update_id))
logger.info(
f"Deleted model update {model_update.model_update_id} from storage.")
except Exception as e:
logger.error(
"AGGREGATOR({}): Error encoutered while processing model update {}, skiphttps://github.com/scaleoutsystems/fedn/pull/770ping this update.".format(self.name, e))
f"Error processing model update: {e}. Skipping this update.")
logger.error(traceback.format_exc())
continue

data["nr_aggregated_models"] = nr_aggregated_models

if pseudo_gradient:
try:
if parameters["serveropt"] == "adam":
model = self.serveropt_adam(
helper, pseudo_gradient, model_old, parameters)
elif parameters["serveropt"] == "yogi":
model = self.serveropt_yogi(
helper, pseudo_gradient, model_old, parameters)
elif parameters["serveropt"] == "adagrad":
model = self.serveropt_adagrad(
helper, pseudo_gradient, model_old, parameters)
else:
logger.error(
"Unsupported server optimizer passed to FedOpt.")
return None, data
model = self._apply_server_optimizer(
helper, pseudo_gradient, model_old, parameters)
except Exception as e:
tb = traceback.format_exc()
logger.error(
"AGGREGATOR({}): Error encoutered while while aggregating: {}".format(self.name, e))
logger.error(tb)
logger.error(f"Error during model aggregation: {e}")
logger.error(traceback.format_exc())
return None, data
else:
return None, data

logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(
self.name, nr_aggregated_models))
logger.info(
f"Aggregator {self.name} completed. Aggregated {nr_aggregated_models} models.")
return model, data

def _validate_and_merge_parameters(
self, parameters: Optional[Parameters], default_parameters: Dict[str, Any]
) -> Dict[str, Any]:
"""Validate and merge default parameters."""
parameter_schema = {
"serveropt": str,
"learning_rate": float,
"beta1": float,
"beta2": float,
"tau": float,
}
if parameters:
parameters.validate(parameter_schema)
else:
logger.info(f"Aggregator {self.name} using default parameters.")
parameters = {}
return {**default_parameters, **parameters}
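# Merge semantics, for illustration: with defaults {"serveropt": "adam",
# "tau": 1e-4} and user input {"serveropt": "yogi"}, the result is
# {"serveropt": "yogi", "tau": 1e-4}: user values win, everything else
# keeps its default.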

def _update_pseudo_gradient(
self, helper: HelperBase, pseudo_gradient: Any, model_next: Any, model_old: Any,
metadata: Dict[str, Any], nr_aggregated_models: int, total_examples: int
) -> Tuple[Any, Any]:
"""Update pseudo gradient based on the current model."""
if nr_aggregated_models == 0:
model_old = self.update_handler.load_model(
helper, metadata["model_id"])
pseudo_gradient = helper.subtract(model_next, model_old)
else:
pseudo_gradient_next = helper.subtract(model_next, model_old)
pseudo_gradient = helper.increment_average(
pseudo_gradient, pseudo_gradient_next, metadata["num_examples"], total_examples
)
return pseudo_gradient, model_old
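# Running state, for illustration: after k updates the pseudo gradient
# equals sum_i n_i * (model_i - model_old) / sum_i n_i over the k
# clients seen so far, i.e. the example-weighted mean of client deltas,
# which increment_average maintains one update at a time.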

def _apply_server_optimizer(
self, helper: HelperBase, pseudo_gradient: Any, model_old: Any, parameters: Dict[str, Any]
) -> Any:
"""Apply the selected server optimizer to compute the new model."""
optimizer_map = {
"adam": self.serveropt_adam,
"yogi": self.serveropt_yogi,
"adagrad": self.serveropt_adagrad,
}
optimizer = optimizer_map.get(parameters["serveropt"])
if not optimizer:
raise ValueError(
f"Unsupported server optimizer: {parameters['serveropt']}")
return optimizer(helper, pseudo_gradient, model_old, parameters)

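# Server-side Adam update applied below, following the FedOpt
# formulation (Reddi et al.) that this aggregator is based on:
#   m <- beta1 * m + (1 - beta1) * pseudo_gradient
#   v <- beta2 * v + (1 - beta2) * pseudo_gradient**2
#   model <- model_old + learning_rate * m / (sqrt(v) + tau)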
def serveropt_adam(self, helper, pseudo_gradient, model_old, parameters):
"""Server side optimization, FedAdam.
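For context, a hypothetical sketch of selecting this aggregator for a training session through the FEDn API client. The start_session keyword names below are assumptions for illustration, not part of this commit:

from fedn import APIClient

client = APIClient(host="localhost", port=8092)  # hypothetical deployment
client.start_session(
    rounds=10,
    aggregator="fedopt",  # route server-side aggregation through this class
    aggregator_kwargs={"serveropt": "adam", "learning_rate": 1e-3},
)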
