
Commit

fix
Wrede committed Jan 30, 2025
1 parent 99809e9 commit 8fce5a8
Showing 2 changed files with 151 additions and 78 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch-keyworddetection-api/README.rst
@@ -7,7 +7,7 @@ The machine learning project is based on the Speech Commands dataset from Google
 The example is intented as a minimalistic quickstart to learn how to use FEDn.


-**Note: It is recommended complete the example in https://docs.scaleoutsystems.com/en/stable/quickstart.html before starting this example **
+**Note: It is recommended to complete the example in https://docs.scaleoutsystems.com/en/stable/quickstart.html before starting this example **
 Prerequisites
 -------------
227 changes: 150 additions & 77 deletions examples/pytorch-keyworddetection-api/client_sc.py
@@ -1,22 +1,26 @@
-from torch import nn
-from torch.optim import Adam
+"""Client SC Example for PyTorch Keyword Detection API.
+
+This module contains the implementation of the client for the federated learning
+network using PyTorch for keyword detection.
+"""

 import argparse
 import io
 from pathlib import Path

-from torcheval.metrics import MulticlassAccuracy
-
-from fedn.network.clients.fedn_client import FednClient, ConnectToApiResult
-
 from data import get_dataloaders
-from model import model_hyperparams, compile_model, load_parameters, save_parameters
-from settings import KEYWORDS, BATCHSIZE_VALID, DATASET_PATH, DATASET_TOTAL_SPLITS
+from model import compile_model, load_parameters, model_hyperparams, save_parameters
+from settings import BATCHSIZE_VALID, DATASET_PATH, DATASET_TOTAL_SPLITS, KEYWORDS
+from torch import nn
+from torch.optim import Adam
+from torcheval.metrics import MulticlassAccuracy
 from util import construct_api_url, read_settings

+from fedn.network.clients.fedn_client import ConnectToApiResult, FednClient
+

-def main():
+def main() -> None:
+    """Parse arguments and start the client."""
     parser = argparse.ArgumentParser(description="Client SC Example")
     parser.add_argument("--client-yaml", type=str, required=False, help="Settings specfic for the client (default: client.yaml)")
     parser.add_argument("--dataset-split-idx", type=int, required=True, help="Setting for which dataset split this client uses")
@@ -27,115 +31,184 @@ def main():
     start_client(args.client_yaml, args.dataset_split_idx)


-def start_client(client_yaml, dataset_split_idx):
-    #Constant in this context
-    DATASET_SPLIT_IDX = dataset_split_idx
-
-    if Path(client_yaml).exists():
-        cfg = read_settings(client_yaml)
-    else:
-        raise Exception(f"Client yaml file not found: {client_yaml}")
-
-    if "discover_host" in cfg:
-        cfg["api_url"] = cfg["discover_host"]
-
-    url = construct_api_url(cfg["api_url"], cfg.get("api_port", None))
-
-    def on_train(model_params, settings):
-        training_metadata = {"batchsize_train":64, "lr": 1e-3, "n_epochs":1}
-        print(settings)
-
-        dataloader_train, _, _ = get_dataloaders(DATASET_PATH, KEYWORDS, DATASET_SPLIT_IDX,
-                                                 DATASET_TOTAL_SPLITS, training_metadata["batchsize_train"],
-                                                 BATCHSIZE_VALID)
-
-        sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
-        load_parameters(sc_model, model_params)
-        optimizer = Adam(sc_model.parameters(), lr=training_metadata["lr"])
-        loss_fn = nn.CrossEntropyLoss()
-        n_epochs = training_metadata["n_epochs"]
-
-        for epoch in range(n_epochs):
-            sc_model.train()
-            for idx, (y_labels, x_spectrograms) in enumerate(dataloader_train):
-                optimizer.zero_grad()
-                _, logits = sc_model(x_spectrograms)
-
-                loss = loss_fn(logits, y_labels)
-                loss.backward()
-
-                optimizer.step()
-
-                if idx%100 == 0:
-                    print(f"Epoch {epoch+1}/{n_epochs} | Batch: {idx+1}/{len(dataloader_train)} | Loss: {loss.item()}")
-
-        out_model = save_parameters(sc_model, io.BytesIO())
-
-        metadata = { "training_metadata": {"num_examples": len(dataloader_train.dataset)} }
-
-        return out_model, metadata
-
-    def on_validate(model_params):
-        dataloader_train, dataloader_valid, dataloader_test = get_dataloaders(DATASET_PATH, KEYWORDS, DATASET_SPLIT_IDX,
-                                                                              DATASET_TOTAL_SPLITS, BATCHSIZE_VALID, BATCHSIZE_VALID)
-
-        n_labels = dataloader_train.dataset.n_labels
-
-        sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
-        load_parameters(sc_model, model_params)
-
-        def evaluate(dataloader):
-            sc_model.eval()
-            metric = MulticlassAccuracy(num_classes=n_labels)
-            for (y_labels, x_spectrograms) in dataloader:
-
-                probs, _ = sc_model(x_spectrograms)
-
-                y_pred = probs.argmax(-1)
-                metric.update(y_pred, y_labels)
-            return metric.compute().item()
-
-        report = {"training_acc": evaluate(dataloader_train),
-                  "validation_acc": evaluate(dataloader_valid)}
-
-        return report
-
-    def on_predict(model_params):
-        dataloader_train, _, _ = get_dataloaders(DATASET_PATH, KEYWORDS, DATASET_SPLIT_IDX,
-                                                 DATASET_TOTAL_SPLITS, BATCHSIZE_VALID, BATCHSIZE_VALID)
-        sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
-        load_parameters(sc_model, model_params)
-
-        return {}
-
-    fedn_client = FednClient(train_callback=on_train, validate_callback=on_validate, predict_callback=on_predict)
-
-    fedn_client.set_name(cfg["name"])
-    fedn_client.set_client_id(cfg["client_id"])
-
-    controller_config = {
-        "name": fedn_client.name,
-        "client_id": fedn_client.client_id,
-        "package": "local",
-        "preferred_combiner": "",
-    }
-
-    result, combiner_config = fedn_client.connect_to_api(url, cfg["token"], controller_config)
-
-    if result != ConnectToApiResult.Assigned:
-        print("Failed to connect to API, exiting.")
-        return
-
-    result: bool = fedn_client.init_grpchandler(config=combiner_config, client_name=fedn_client.client_id, token=cfg["token"])
-
-    if not result:
-        return
-
-    fedn_client.run()
+def start_client(client_yaml: str, dataset_split_idx: int) -> None:
+    """Start the client with the given configuration and dataset split index.
+
+    Args:
+        client_yaml (str): Path to the client configuration YAML file.
+        dataset_split_idx (int): Index of the dataset split to use.
+    """
+    DATASET_SPLIT_IDX = dataset_split_idx
+
+    cfg = load_client_config(client_yaml)
+    url = construct_api_url(cfg["api_url"], cfg.get("api_port", None))
+
+    fedn_client = FednClient(
+        train_callback=lambda params, settings: on_train(params, settings, DATASET_SPLIT_IDX),
+        validate_callback=lambda params: on_validate(params, DATASET_SPLIT_IDX),
+        predict_callback=lambda params: on_predict(params, DATASET_SPLIT_IDX),
+    )
+
+    configure_fedn_client(fedn_client, cfg)
+
+    result, combiner_config = fedn_client.connect_to_api(url, cfg["token"], get_controller_config(fedn_client))
+
+    if result != ConnectToApiResult.Assigned:
+        print("Failed to connect to API, exiting.")
+        return
+
+    if not fedn_client.init_grpchandler(config=combiner_config, client_name=fedn_client.client_id, token=cfg["token"]):
+        return
+
+    fedn_client.run()
+
+
+def load_client_config(client_yaml: str) -> dict:
+    """Load the client configuration from a YAML file.
+
+    Args:
+        client_yaml (str): Path to the client configuration YAML file.
+
+    Returns:
+        dict: The client configuration as a dictionary.
+    """
+    if Path(client_yaml).exists():
+        cfg = read_settings(client_yaml)
+    else:
+        raise Exception(f"Client yaml file not found: {client_yaml}")
+
+    if "discover_host" in cfg:
+        cfg["api_url"] = cfg["discover_host"]
+
+    return cfg
+
+
+def configure_fedn_client(fedn_client: FednClient, cfg: dict) -> None:
+    """Configure the FednClient with the given settings.
+
+    Args:
+        fedn_client (FednClient): The FednClient instance to configure.
+        cfg (dict): The configuration dictionary containing client settings.
+    """
+    fedn_client.set_name(cfg["name"])
+    fedn_client.set_client_id(cfg["client_id"])
+
+
+def get_controller_config(fedn_client: FednClient) -> dict:
+    """Get the controller configuration for the FednClient.
+
+    Args:
+        fedn_client (FednClient): The FednClient instance.
+
+    Returns:
+        dict: The controller configuration dictionary.
+    """
+    return {
+        "name": fedn_client.name,
+        "client_id": fedn_client.client_id,
+        "package": "local",
+        "preferred_combiner": "",
+    }
+
+
+def on_train(model_params, settings, dataset_split_idx) -> tuple:
+    """Train the model with the given parameters and settings.
+
+    Args:
+        model_params: The model parameters.
+        settings: The training settings.
+        dataset_split_idx: The index of the dataset split to use.
+
+    Returns:
+        tuple: The trained model parameters and metadata.
+    """
+    training_metadata = {"batchsize_train": 64, "lr": 1e-3, "n_epochs": 1}
+
+    dataloader_train, _, _ = get_dataloaders(
+        DATASET_PATH, KEYWORDS, dataset_split_idx, DATASET_TOTAL_SPLITS, training_metadata["batchsize_train"], BATCHSIZE_VALID
+    )
+
+    sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
+    load_parameters(sc_model, model_params)
+    optimizer = Adam(sc_model.parameters(), lr=training_metadata["lr"])
+    loss_fn = nn.CrossEntropyLoss()
+    n_epochs = training_metadata["n_epochs"]
+
+    for epoch in range(n_epochs):
+        sc_model.train()
+        for idx, (y_labels, x_spectrograms) in enumerate(dataloader_train):
+            optimizer.zero_grad()
+            _, logits = sc_model(x_spectrograms)
+
+            loss = loss_fn(logits, y_labels)
+            loss.backward()
+
+            optimizer.step()
+
+            if idx % 100 == 0:
+                print(f"Epoch {epoch + 1}/{n_epochs} | Batch: {idx + 1}/{len(dataloader_train)} | Loss: {loss.item()}")
+
+    out_model = save_parameters(sc_model, io.BytesIO())
+
+    metadata = {"training_metadata": {"num_examples": len(dataloader_train.dataset)}}
+
+    return out_model, metadata
+
+
+def on_validate(model_params, dataset_split_idx) -> dict:
+    """Validate the model with the given parameters and dataset split index.
+
+    Args:
+        model_params: The model parameters.
+        dataset_split_idx: The index of the dataset split to use.
+
+    Returns:
+        dict: The validation report containing training and validation accuracy.
+    """
+    dataloader_train, dataloader_valid, dataloader_test = get_dataloaders(
+        DATASET_PATH, KEYWORDS, dataset_split_idx, DATASET_TOTAL_SPLITS, BATCHSIZE_VALID, BATCHSIZE_VALID
+    )

+    n_labels = dataloader_train.dataset.n_labels
+
+    sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
+    load_parameters(sc_model, model_params)
+
+    def evaluate(dataloader) -> float:
+        sc_model.eval()
+        metric = MulticlassAccuracy(num_classes=n_labels)
+        for y_labels, x_spectrograms in dataloader:
+            probs, _ = sc_model(x_spectrograms)
+
+            y_pred = probs.argmax(-1)
+            metric.update(y_pred, y_labels)
+        return metric.compute().item()
+
+    return {"training_acc": evaluate(dataloader_train), "validation_acc": evaluate(dataloader_valid)}
+
+
+def on_predict(model_params, dataset_split_idx) -> dict:
+    """Generate predictions using the model parameters and dataset split index.
+
+    Args:
+        model_params: The model parameters.
+        dataset_split_idx: The index of the dataset split to use.
+
+    Returns:
+        dict: The prediction results.
+    """
+    dataloader_train, _, _ = get_dataloaders(DATASET_PATH, KEYWORDS, dataset_split_idx, DATASET_TOTAL_SPLITS, BATCHSIZE_VALID, BATCHSIZE_VALID)
+    sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
+    load_parameters(sc_model, model_params)
+
+    return {}


 if __name__ == "__main__":
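For readers trying the refactored example: the command-line interface defined in main() is unchanged, so a client is still launched with python client_sc.py --client-yaml client.yaml --dataset-split-idx <n>. The snippet below is a minimal sketch, not part of the commit, showing the new start_client entry point driven programmatically. The client.yaml path and the split index 0 are assumptions for illustration, and client.yaml is expected to provide the keys the code reads (name, client_id, token, api_url or discover_host, and optionally api_port).

    # run_client.py (hypothetical helper; assumes client.yaml sits next to client_sc.py
    # and that dataset split 0 has been prepared as described in the example README)
    from client_sc import start_client

    if __name__ == "__main__":
        # Equivalent to: python client_sc.py --client-yaml client.yaml --dataset-split-idx 0
        start_client("client.yaml", dataset_split_idx=0)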
