Feature/SK-1356 | Added example of keyword detection using client api
carl-andersson authored Feb 7, 2025
1 parent 83041d2 commit 9e5ce66
Showing 10 changed files with 832 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/pytorch-keyworddetection-api/.gitignore
@@ -0,0 +1,3 @@
data
*.npz
*.yaml
62 changes: 62 additions & 0 deletions examples/pytorch-keyworddetection-api/README.rst
@@ -0,0 +1,62 @@
FEDn Project: Keyword Detection (PyTorch)
---------------------------------------------

This example showcases how to set up a FednClient and use the APIClient to set up and manage training from Python.
The machine learning task is based on the Speech Commands dataset from Google: https://huggingface.co/datasets/google/speech_commands.

The example is intended as a minimalistic quickstart for learning how to use FEDn.


**Note: It is recommended to complete the example at https://docs.scaleoutsystems.com/en/stable/quickstart.html before starting this one.**

Prerequisites
-------------

- `Python >=3.9, <=3.12 <https://www.python.org/downloads>`__
- `A project in FEDn Studio <https://fedn.scaleoutsystems.com/signup>`__

Installing prerequisites and creating the seed model
----------------------------------------------------

There are two ways to install the required packages: using conda or using pip.

.. code-block::

   conda env create -n <name-of-env> --file env.yaml

Or, if you prefer to use pip:

.. code-block::

   pip install -r requirements.txt

Note that when installing with pip you also need to install either sox (macOS or Linux) or soundfile (Windows), depending on your platform, since torchaudio requires one of these audio backends.
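
To verify that torchaudio can find an audio backend, you can run a quick check. This is a minimal sketch; ``list_audio_backends`` is available in recent torchaudio versions:

.. code-block:: python

   import torchaudio

   # An empty list means no backend is installed yet; install sox
   # (macOS/Linux) or soundfile (Windows) and re-run this check.
   print(torchaudio.list_audio_backends())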


Clone this repository, then navigate to this directory:

.. code-block::

   git clone https://github.com/scaleoutsystems/fedn.git
   cd fedn/examples/pytorch-keyworddetection-api

Next, we need to set up the APIClient. The guide at https://docs.scaleoutsystems.com/en/stable/apiclient.html helps you obtain the hostname and access token. Edit the file fedn_api.py and insert your HOST and TOKEN.
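
For reference, the connection setup inside fedn_api.py might look like the sketch below. HOST and TOKEN are placeholders for the values from your Studio project, and the exact APIClient constructor arguments should be verified against your installed FEDn version:

.. code-block:: python

   from fedn import APIClient

   # Placeholders: copy these values from your FEDn Studio project.
   HOST = "<your-fedn-host>"
   TOKEN = "<your-access-token>"

   # secure/verify enable HTTPS with certificate verification.
   client = APIClient(host=HOST, token=TOKEN, secure=True, verify=True)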

Next, generate the seed model:

.. code-block::

   python fedn_api.py --init-seed

This will create a model file 'seed.npz' in the root of the project and upload it to the server.
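
Under the hood, the seed initialization plausibly looks something like the sketch below, assuming fedn_api.py reuses the model helpers from model.py; the helper signatures and the ``set_active_model`` upload call are assumptions to verify against the actual sources:

.. code-block:: python

   from model import compile_model, save_parameters

   # Sketch: build an untrained model and serialize it as the seed.
   # (The real script derives the model hyperparameters from the dataset.)
   model = compile_model()
   save_parameters(model, "seed.npz")

   # Upload the seed model so the first training round starts from it.
   client.set_active_model("seed.npz")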


Now we need to start the clients. Download a set of client configurations by following the quickstart tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html#start-clients.

Start the clients with the following command:

.. code-block::

   python client_sc.py --dataset-split-idx 0 --client-yaml client0.yaml

where each client is started with a different dataset split index and client yaml file; an example client yaml is sketched below.
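
Based on the keys that client_sc.py reads, a client yaml file might look like the following sketch (all values are placeholders; ``discover_host`` is accepted as an alias for ``api_url``, and ``api_port`` is optional):

.. code-block:: yaml

   name: client0
   client_id: <unique-client-id>
   api_url: <your-fedn-host>
   token: <your-client-token>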

215 changes: 215 additions & 0 deletions examples/pytorch-keyworddetection-api/client_sc.py
@@ -0,0 +1,215 @@
"""Client SC Example for PyTorch Keyword Detection API.

This module contains the implementation of the client for the federated learning
network using PyTorch for keyword detection.
"""

import argparse
import io
from pathlib import Path

from data import get_dataloaders
from model import compile_model, load_parameters, model_hyperparams, save_parameters
from settings import BATCHSIZE_VALID, DATASET_PATH, DATASET_TOTAL_SPLITS, KEYWORDS
from torch import nn
from torch.optim import Adam
from torcheval.metrics import MulticlassAccuracy
from util import construct_api_url, read_settings

from fedn.network.clients.fedn_client import ConnectToApiResult, FednClient


def main() -> None:
"""Parse arguments and start the client."""
parser = argparse.ArgumentParser(description="Client SC Example")
parser.add_argument("--client-yaml", type=str, required=False, help="Settings specfic for the client (default: client.yaml)")
parser.add_argument("--dataset-split-idx", type=int, required=True, help="Setting for which dataset split this client uses")

parser.set_defaults(client_yaml="client.yaml")
args = parser.parse_args()

start_client(args.client_yaml, args.dataset_split_idx)


def start_client(client_yaml: str, dataset_split_idx: int) -> None:
    """Start the client with the given configuration and dataset split index.

    Args:
client_yaml (str): Path to the client configuration YAML file.
dataset_split_idx (int): Index of the dataset split to use.
"""
DATASET_SPLIT_IDX = dataset_split_idx

cfg = load_client_config(client_yaml)
url = construct_api_url(cfg["api_url"], cfg.get("api_port", None))

    # Register the callbacks FEDn invokes when the server dispatches
    # train, validate, and predict tasks to this client.
    fedn_client = FednClient(
        train_callback=lambda params, settings: on_train(params, settings, DATASET_SPLIT_IDX),
        validate_callback=lambda params: on_validate(params, DATASET_SPLIT_IDX),
        predict_callback=lambda params: on_predict(params, DATASET_SPLIT_IDX),
    )

configure_fedn_client(fedn_client, cfg)

result, combiner_config = fedn_client.connect_to_api(url, cfg["token"], get_controller_config(fedn_client))

if result != ConnectToApiResult.Assigned:
print("Failed to connect to API, exiting.")
return

    # Set up the gRPC channel to the assigned combiner; exit if it fails.
    if not fedn_client.init_grpchandler(config=combiner_config, client_name=fedn_client.client_id, token=cfg["token"]):
        return

    # Run the client loop, processing tasks from the server until stopped.
    fedn_client.run()


def load_client_config(client_yaml: str) -> dict:
    """Load the client configuration from a YAML file.

    Args:
client_yaml (str): Path to the client configuration YAML file.
Returns:
dict: The client configuration as a dictionary.
"""
if Path(client_yaml).exists():
cfg = read_settings(client_yaml)
else:
        raise FileNotFoundError(f"Client yaml file not found: {client_yaml}")

if "discover_host" in cfg:
cfg["api_url"] = cfg["discover_host"]

return cfg


def configure_fedn_client(fedn_client: FednClient, cfg: dict) -> None:
    """Configure the FednClient with the given settings.

    Args:
fedn_client (FednClient): The FednClient instance to configure.
cfg (dict): The configuration dictionary containing client settings.
"""
fedn_client.set_name(cfg["name"])
fedn_client.set_client_id(cfg["client_id"])


def get_controller_config(fedn_client: FednClient) -> dict:
    """Get the controller configuration for the FednClient.

    Args:
fedn_client (FednClient): The FednClient instance.
Returns:
dict: The controller configuration dictionary.
"""
return {
"name": fedn_client.name,
"client_id": fedn_client.client_id,
"package": "local",
"preferred_combiner": "",
}


def on_train(model_params, settings, dataset_split_idx) -> tuple:
    """Train the model with the given parameters and settings.

    Args:
model_params: The model parameters.
settings: The training settings.
dataset_split_idx: The index of the dataset split to use.
Returns:
tuple: The trained model parameters and metadata.
"""
training_metadata = {"batchsize_train": 64, "lr": 1e-3, "n_epochs": 1}

dataloader_train, _, _ = get_dataloaders(
DATASET_PATH, KEYWORDS, dataset_split_idx, DATASET_TOTAL_SPLITS, training_metadata["batchsize_train"], BATCHSIZE_VALID
)

sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
load_parameters(sc_model, model_params)
optimizer = Adam(sc_model.parameters(), lr=training_metadata["lr"])
loss_fn = nn.CrossEntropyLoss()
n_epochs = training_metadata["n_epochs"]

for epoch in range(n_epochs):
sc_model.train()
for idx, (y_labels, x_spectrograms) in enumerate(dataloader_train):
optimizer.zero_grad()
_, logits = sc_model(x_spectrograms)

loss = loss_fn(logits, y_labels)
loss.backward()

optimizer.step()

if idx % 100 == 0:
print(f"Epoch {epoch + 1}/{n_epochs} | Batch: {idx + 1}/{len(dataloader_train)} | Loss: {loss.item()}")

out_model = save_parameters(sc_model, io.BytesIO())

metadata = {"training_metadata": {"num_examples": len(dataloader_train.dataset)}}

return out_model, metadata


def on_validate(model_params, dataset_split_idx) -> dict:
    """Validate the model with the given parameters and dataset split index.

    Args:
model_params: The model parameters.
dataset_split_idx: The index of the dataset split to use.
Returns:
dict: The validation report containing training and validation accuracy.
"""
dataloader_train, dataloader_valid, dataloader_test = get_dataloaders(
DATASET_PATH, KEYWORDS, dataset_split_idx, DATASET_TOTAL_SPLITS, BATCHSIZE_VALID, BATCHSIZE_VALID
)

n_labels = dataloader_train.dataset.n_labels

sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
load_parameters(sc_model, model_params)

def evaluate(dataloader) -> float:
sc_model.eval()
metric = MulticlassAccuracy(num_classes=n_labels)
for y_labels, x_spectrograms in dataloader:
probs, _ = sc_model(x_spectrograms)

y_pred = probs.argmax(-1)
metric.update(y_pred, y_labels)
return metric.compute().item()

return {"training_acc": evaluate(dataloader_train), "validation_acc": evaluate(dataloader_valid)}


def on_predict(model_params, dataset_split_idx) -> dict:
    """Generate predictions using the model parameters and dataset split index.

    Args:
model_params: The model parameters.
dataset_split_idx: The index of the dataset split to use.
Returns:
dict: The prediction results.
"""
dataloader_train, _, _ = get_dataloaders(DATASET_PATH, KEYWORDS, dataset_split_idx, DATASET_TOTAL_SPLITS, BATCHSIZE_VALID, BATCHSIZE_VALID)
sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
load_parameters(sc_model, model_params)
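    # Note: this example does not compute any predictions; the model is
    # loaded only to demonstrate how the predict callback is wired up.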

return {}


if __name__ == "__main__":
main()