# Quickstart

## Installation

First, we need to install the `fl4health` package. The easiest and recommended way to do this is via `pip`.

```sh
pip install fl4health
```

## A simple FL task

In federated learning, a model is trained collaboratively by a set of distributed nodes called `clients`. This collaboration is facilitated by another node, the `server`. To set up an FL task, we need to define our `Client` and our `Server` in the scripts `client.py` and `server.py`, respectively.

### `client.py`

```python
from pathlib import Path

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import Net
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data
from fl4health.utils.metrics import Accuracy


class CifarClient(BasicClient):
    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        batch_size = narrow_dict_type(config, "batch_size", int)
        train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
        return train_loader, val_loader

    def get_test_data_loader(self, config: Config) -> DataLoader | None:
        batch_size = narrow_dict_type(config, "batch_size", int)
        test_loader, _ = load_cifar10_test_data(self.data_path, batch_size)
        return test_loader

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()

    def get_optimizer(self, config: Config) -> Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

    def get_model(self, config: Config) -> nn.Module:
        return Net().to(self.device)
```
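
The snippet above only defines `CifarClient`; the `flwr` and `Path` imports are used by the script's entry point. A minimal launcher sketch is shown below, assuming `BasicClient` is constructed from a data path, a list of metrics, and a device, and that the `--dataset_path` flag matches the run command later in this guide; verify these details against your installed version.

```python
# Hypothetical entry point appended to client.py; it reuses the imports above.
# The BasicClient constructor signature and .to_client() conversion are
# assumptions to verify against your installed versions of fl4health and flwr.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CIFAR-10 FL Client")
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the CIFAR-10 dataset")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    client = CifarClient(Path(args.dataset_path), [Accuracy("accuracy")], device)
    # Convert to a flwr client and connect to the server (see server.py below).
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
```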

### `server.py`

```python
from functools import partial
from typing import Any

import flwr as fl
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.models.cnn_model import Net
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters


def fit_config(
    batch_size: int,
    current_server_round: int,
    local_epochs: int | None = None,
    local_steps: int | None = None,
) -> Config:
    return {
        **make_dict_with_epochs_or_steps(local_epochs, local_steps),
        "batch_size": batch_size,
        "current_server_round": current_server_round,
    }


def main(config: dict[str, Any]) -> None:
    # This function will be used to produce a config that is sent to each client to initialize their own environment
    fit_config_fn = partial(
        fit_config,
        config["batch_size"],
        local_epochs=config.get("local_epochs"),
        local_steps=config.get("local_steps"),
    )

    # Initializing the model on the server side
    model = Net()
    # To facilitate checkpointing
    parameter_exchanger = FullParameterExchanger()
    checkpointers = [
        BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"),
        LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"),
    ]
    checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
        model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointers
    )

    # Server performs simple FedAveraging as its server-side optimization strategy
    strategy = FedAvg(
        min_fit_clients=config["n_clients"],
        min_evaluate_clients=config["n_clients"],
        # Server waits for min_available_clients before starting FL rounds
        min_available_clients=config["n_clients"],
        on_fit_config_fn=fit_config_fn,
        # We use the same fit config function, as nothing changes for eval
        on_evaluate_config_fn=fit_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_all_model_parameters(model),
    )

    server = FlServer(
        client_manager=SimpleClientManager(),
        fl_config=config,
        strategy=strategy,
        checkpoint_and_state_module=checkpoint_and_state_module,
        accept_failures=False,
    )

    fl.server.start_server(
        server=server,
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
    )
```
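
As with the client, the full script needs an entry point; note that `load_config` is imported above for exactly this purpose. A minimal sketch follows, assuming `load_config` parses the YAML file into a dictionary and that the `--config_path` flag matches the run command later in this guide.

```python
# Hypothetical entry point appended to server.py; it reuses the imports above
# and assumes load_config reads the YAML config at --config_path into a dict.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL Server")
    parser.add_argument("--config_path", type=str, required=True, help="Path to the YAML configuration file")
    args = parser.parse_args()

    config = load_config(args.config_path)
    main(config)
```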

## Running the FL task

Now that we have our server and clients defined, we can run the FL system!

### Starting Server

The next step is to start the server by running

```sh
python -m examples.basic_example.server --config_path /path/to/config.yaml
```

from the FL4Health directory. The following arguments must be present in the specified config file:
* `n_clients`: number of clients the server waits for in order to run the FL training
* `local_epochs`: number of epochs each client will train for locally
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: number of rounds to run FL
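
For concreteness, a minimal `config.yaml` might look like the sketch below. The values are illustrative only; note that the `server.py` above also reads a `checkpoint_path` key, and that `local_steps` may be supplied in place of `local_epochs`.

```yaml
# Illustrative values only; adjust for your experiment.
n_clients: 2        # number of clients the server waits for
local_epochs: 1     # or specify local_steps instead
batch_size: 32
n_server_rounds: 5
checkpoint_path: "examples/basic_example/checkpoints"  # used by the checkpointers in server.py
```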

### Starting Clients

Once the server has started and logged "FL starting," start the two clients in separate terminals. This is done by running (remembering to activate your environment in each terminal)

```sh
python -m examples.basic_example.client --dataset_path /path/to/data
```

**NOTE**: The `dataset_path` argument serves two purposes, depending on whether the dataset already exists locally. If the dataset exists at the specified path, it will be loaded from there. Otherwise, it will be automatically downloaded to that path and used in the run.
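
If you would rather launch both clients from a single shell instead of separate terminals, a sketch like the following should also work, with the caveat that logs from the two clients will interleave:

```sh
# Launch two clients in the background from one shell.
python -m examples.basic_example.client --dataset_path /path/to/data &
python -m examples.basic_example.client --dataset_path /path/to/data &
wait  # block until both clients exit
```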

After both clients have been started, federated learning should commence.