# Quickstart

## Installation

First, we need to install the `fl4health` package. The easiest and recommended
way to do this is via `pip`.

```sh
pip install fl4health
```
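Since `fl4health` pulls in PyTorch and Flower (`flwr`), you may prefer to install it into an isolated environment. A minimal sketch:

```sh
# Optional: create and activate a fresh virtual environment first
python -m venv .venv
source .venv/bin/activate
pip install fl4health
```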

## A simple FL task

With federated learning, the model is trained collaboratively by a set of
distributed nodes called `clients`. This collaboration is facilitated by another
node, namely the `server` node. To set up an FL task, we need to define our `Client`
as well as our `Server` in the scripts `client.py` and `server.py`, respectively.

### `client.py`

```python
from pathlib import Path

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import Net
from fl4health.clients.basic_client import BasicClient
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data
from fl4health.utils.metrics import Accuracy

class CifarClient(BasicClient):
    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
        batch_size = narrow_dict_type(config, "batch_size", int)
        train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
        return train_loader, val_loader

    def get_test_data_loader(self, config: Config) -> DataLoader | None:
        batch_size = narrow_dict_type(config, "batch_size", int)
        test_loader, _ = load_cifar10_test_data(self.data_path, batch_size)
        return test_loader

    def get_criterion(self, config: Config) -> _Loss:
        return torch.nn.CrossEntropyLoss()

    def get_optimizer(self, config: Config) -> Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

    def get_model(self, config: Config) -> nn.Module:
        return Net().to(self.device)
```
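Note that `client.py` imports `flwr` and `Path` without using them in the class above, which suggests the file ends with an entry point that constructs the client and connects it to the server. A minimal sketch of one follows; the `CifarClient(data_path, metrics, device)` constructor arguments and the `--dataset_path` flag are assumptions inferred from `BasicClient` and the run command shown later, not guaranteed by this excerpt.

```python
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="CIFAR-10 FL client")
    # Assumed flag name, matching the command under "Starting Clients" below
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the local CIFAR-10 data")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Assumption: BasicClient subclasses are constructed with (data_path, metrics, device)
    client = CifarClient(Path(args.dataset_path), [Accuracy()], device)
    # BasicClient wraps a Flower NumPyClient, so to_client() adapts it for start_client
    fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
```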

### `server.py`

```python
from functools import partial
from typing import Any

import flwr as fl
from flwr.common.typing import Config
from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.models.cnn_model import Net
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters


def fit_config(
    batch_size: int,
    current_server_round: int,
    local_epochs: int | None = None,
    local_steps: int | None = None,
) -> Config:
    return {
        **make_dict_with_epochs_or_steps(local_epochs, local_steps),
        "batch_size": batch_size,
        "current_server_round": current_server_round,
    }


def main(config: dict[str, Any]) -> None:
    # This function will be used to produce a config that is sent to each client to initialize their own environment
    fit_config_fn = partial(
        fit_config,
        config["batch_size"],
        local_epochs=config.get("local_epochs"),
        local_steps=config.get("local_steps"),
    )

    # Initializing the model on the server side
    model = Net()
    # To facilitate checkpointing
    parameter_exchanger = FullParameterExchanger()
    checkpointers = [
        BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"),
        LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"),
    ]
    checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
        model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointers
    )

    # Server performs simple FedAveraging as its server-side optimization strategy
    strategy = FedAvg(
        min_fit_clients=config["n_clients"],
        min_evaluate_clients=config["n_clients"],
        # Server waits for min_available_clients before starting FL rounds
        min_available_clients=config["n_clients"],
        on_fit_config_fn=fit_config_fn,
        # We use the same fit config function, as nothing changes for eval
        on_evaluate_config_fn=fit_config_fn,
        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        initial_parameters=get_all_model_parameters(model),
    )

    server = FlServer(
        client_manager=SimpleClientManager(),
        fl_config=config,
        strategy=strategy,
        checkpoint_and_state_module=checkpoint_and_state_module,
        accept_failures=False,
    )

    fl.server.start_server(
        server=server,
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
    )
```
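Similarly, `load_config` is imported above but never called, so the file presumably closes with an entry point that loads the YAML config and passes it to `main`. A minimal sketch, assuming `load_config` accepts a path string and returns a plain dict:

```python
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="FL server main")
    # Assumed flag name, matching the command under "Starting Server" below
    parser.add_argument("--config_path", type=str, required=True, help="Path to the YAML configuration file")
    args = parser.parse_args()

    config = load_config(args.config_path)
    main(config)
```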

## Running the FL task

Now that we have our server and clients defined, we can run the FL system!

### Starting Server

The next step is to start the server by running

```sh
python -m examples.basic_example.server --config_path /path/to/config.yaml
```

from the FL4Health directory. The following arguments must be present in the specified config file:
* `n_clients`: number of clients the server waits for in order to run the FL training
* `local_epochs`: number of epochs each client will train for locally
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: number of rounds to run FL
* `checkpoint_path`: directory where the server-side checkpointers defined in `server.py` above write `best_model.pkl` and `latest_model.pkl`
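
For reference, a minimal `config.yaml` covering these keys might look like the following sketch; the values are illustrative rather than prescribed by this guide.

```yaml
n_clients: 2        # clients the server waits for before starting FL
local_epochs: 1     # local training epochs per server round
batch_size: 32      # local training batch size
n_server_rounds: 5  # total FL rounds
checkpoint_path: "./checkpoints"  # directory for server-side model checkpoints
```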

### Starting Clients

Once the server has started and logged "FL starting," the next step is to start the two clients
(matching `n_clients` in the config), each in its own terminal. This is done by running the following
in each terminal (remembering to activate your environment first):

```sh
python -m examples.basic_example.client --dataset_path /path/to/data
```

**NOTE**: The argument `dataset_path` serves two purposes, depending on whether the dataset already exists
locally. If the dataset already exists at the specified path, it is loaded from there. Otherwise, the dataset
is automatically downloaded to that path and used in the run.

After both clients have been started, federated learning should commence.
