Commit

Cleaned code and added docstrings
carl-andersson committed Feb 7, 2025
1 parent 8fce5a8 commit d0ff7d6
Showing 4 changed files with 126 additions and 100 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch-keyworddetection-api/README.rst
@@ -2,7 +2,7 @@ FEDn Project: Keyword Detection (PyTorch)
-----------------------------

This example showcases how to set up FEDnClient and use the APIClient to set up and manage training from Python.
The machine learning project is based on the Speech Commands dataset from Google, https://huggingface.co/datasets/google/speech_commands.

The example is intended as a minimalistic quickstart to learn how to use FEDn.

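As a minimal sketch of that API workflow (hedged: only ``set_active_model`` and ``start_session`` are confirmed by this example's code; the ``APIClient`` keyword arguments shown here are assumptions, and the host and token values are placeholders):

.. code-block:: python

    from fedn import APIClient

    # Placeholders: fill in your project's API host and admin token.
    api_client = APIClient(host="<HOST>", token="<TOKEN>", secure=True, verify=True)

    # Upload a locally built seed model, then start a training session.
    print(api_client.set_active_model("seed.npz"))
    print(api_client.start_session(round_timeout=600))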
185 changes: 108 additions & 77 deletions examples/pytorch-keyworddetection-api/data.py
@@ -1,37 +1,42 @@
import hashlib
import json
from pathlib import Path
from typing import Callable

import numpy as np
import pyloudnorm as pyln
import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset

SAMPLERATE = 16000


class BackgroundNoise(Dataset):
    """Dataset for background noise samples using all *.wav in the given path."""

    def __init__(self, path: str, dataset_split_idx: int, dataset_total_splits: int) -> None:
        """Initialize the dataset."""
        super().__init__()
        self._path = path
        self._walker = sorted(str(p) for p in Path(self._path).glob("*.wav"))

        self._dataset_split_idx = dataset_split_idx
        self._dataset_total_splits = dataset_total_splits

        self._start_idx = int(self._dataset_split_idx * len(self._walker) / self._dataset_total_splits)
        self._end_idx = int((self._dataset_split_idx + 1) * len(self._walker) / self._dataset_total_splits)

        self._loudness_meter = pyln.Meter(SAMPLERATE)

        self._loudness = [self._loudness_meter.integrated_loudness(self._load_audio(file)[0].numpy()) for file in self._walker]

    def _load_audio(self, filename: str) -> tuple[torch.Tensor, int]:
        data, sr = torchaudio.load(filename)
        data.squeeze_()
        return data, sr

    def __getitem__(self, index: int) -> torch.Tensor:
        """Get the audio sample at the given index."""
        index = index + self._start_idx
        filename = self._walker[index]
        audio, sr = self._load_audio(filename)
@@ -45,15 +50,31 @@ def __getitem__(self, index):
            raise ValueError(f"sample rate should be {SAMPLERATE}, but got {sr}")
        return audio

    def __len__(self) -> int:
        """Get the number of samples in the dataset."""
        return self._end_idx - self._start_idx

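# Editor's sketch (not part of the original commit): BackgroundNoise partitions the
# sorted *.wav list across clients. The directory below is a placeholder for the
# "_background_noise_" folder shipped with SPEECHCOMMANDS; any folder of 16 kHz
# mono wav files works.
def _demo_background_noise(noise_dir: str = "data/_background_noise_") -> None:
    ds = BackgroundNoise(noise_dir, dataset_split_idx=0, dataset_total_splits=2)  # first half of the files
    clip = ds[0]  # 1-D float tensor at SAMPLERATE
    print(len(ds), clip.shape)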

class FedSCDataset(Dataset):
    """Dataset for the Federated Speech Commands dataset."""

    NEGATIVE_KEYWORD = "<negative>"
    SEED = 1

    def __init__(  # noqa: PLR0913
        self, path: str, keywords: list[str], subset: str, dataset_split_idx: int, dataset_total_splits: int, data_augmentation: bool = False
    ) -> None:
"""Initialize the dataset.
Args:
path (str): Path to the dataset.
keywords (list[str]): List of keywords to detect.
subset (str): Subset of the dataset to use. One of "training", "validation", or "testing".
dataset_split_idx (int): Index of the dataset split to use.
dataset_total_splits (int): Total number of dataset splits.
data_augmentation (bool): Whether to apply data augmentation.
"""
        super().__init__()
        self._path = path
        self._subset = subset
@@ -62,14 +83,17 @@ def __init__(self, keywords, subset, dataset_split_idx, dataset_total_spli
        self._dataset_split_idx = dataset_split_idx
        self._dataset_total_splits = dataset_total_splits
        self._dataset = torchaudio.datasets.SPEECHCOMMANDS(path, subset=subset, download=True)
        self._start_idx = int(dataset_split_idx * len(self._dataset) / self._dataset_total_splits)
        self._end_idx = int((dataset_split_idx + 1) * len(self._dataset) / self._dataset_total_splits)

        if data_augmentation:
            self._noise_prob = 0.5
            self._noise_mag = 0.9
            self._noise_dataset = BackgroundNoise(
                Path(self._dataset._path).joinpath("_background_noise_").as_posix(),
                dataset_split_idx=self._dataset_split_idx,
                dataset_total_splits=self._dataset_total_splits,
            )
        else:
            self._noise_prob = 0.0
            self._noise_mag = 0.0
@@ -84,28 +108,28 @@ def __init__(self, path, keywords, subset, dataset_split_idx, dataset_total_spli
        self._white_noise_mag = 0.0015
        self._transform = self._get_spectogram_transform(self._n_mels, self._hop_length, SAMPLERATE, data_augmentation)

        self._spectrogram_size = (self._n_mels, SAMPLERATE // self._hop_length)

        # Reinitialize the RNG with different seeds for the different splits
        self._rng = np.random.RandomState(self.SEED + self._dataset_split_idx)

    @property
    def labels(self) -> list[str]:
        return self._labels

    @property
    def n_mels(self) -> int:
        return self._n_mels

    @property
    def n_labels(self) -> int:
        return len(self._labels)

    @property
    def spectrogram_size(self) -> tuple[int, int]:
        return (self.n_mels, 100)

    def __getitem__(self, index: int) -> tuple[int, str, torch.Tensor, torch.Tensor]:
        shuffled_index = self._shuffle_order[index]
        sample, sr, label, _, _ = self._dataset[shuffled_index]
        sample.squeeze_()
@@ -121,9 +145,8 @@ def __getitem__(self, index):
            noise_idx = self._rng.randint(len(self._noise_dataset))
            waveform = self._noise_dataset[noise_idx]
            sub_start_idx = self._rng.randint(waveform.shape[-1] - SAMPLERATE)
            noise = waveform[sub_start_idx : sub_start_idx + SAMPLERATE]
            sample += self._noise_mag * noise * self._rng.rand()

        sample += self._rng.normal(scale=self._white_noise_mag, size=sample.shape).astype(np.float32)

@@ -132,23 +155,24 @@

        return y, label, spectrogram, sample

    def __len__(self) -> int:
        return self._end_idx - self._start_idx

    def get_label_from_text(self, text_label: str) -> int:
        if text_label in self._labels:
            y = self._labels.index(text_label)
        else:
            y = len(self._labels) - 1
        return y

    def get_spectrogram(self, sample: torch.Tensor) -> torch.Tensor:
        start_idx = self._rng.randint(0, self._hop_length)
        length = sample.shape[0] - self._hop_length
        return self._transform(sample[start_idx : start_idx + length])

    def get_stats(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Get the mean and std of the training data. If the stats are already calculated, they are loaded from disk."""
        sha1 = hashlib.sha1()  # noqa: S324
        for word in self.labels:
            sha1.update(str.encode(word))
        sha1.update(str.encode(str(self._n_mels)))
@@ -162,75 +186,83 @@ def get_stats(self):
        if filepath.exists():
            with open(filepath, "r") as file:
                data = json.load(file)
                if (
                    data["n_mels"] == self._n_mels
                    and data["labels"] == self.labels
                    and data["split_index"] == self._dataset_split_idx
                    and data["dataset_total_splits"] == self._dataset_total_splits
                    and data["hop_length"] == self._hop_length
                    and data["white_noise_mag"] == self._white_noise_mag
                    and data["SEED"] == self.SEED
                ):
                    return torch.tensor(data["label_mean"]), torch.tensor(data["spectrogram_mean"])[:, None], torch.tensor(data["spectrogram_std"])[:, None]

        dataset = FedSCDataset(self._path, [], subset="training", dataset_split_idx=self._dataset_split_idx, dataset_total_splits=self._dataset_total_splits)
        label_count = np.zeros(len(self.labels))
        spectrogram_sum = torch.zeros(self._n_mels)
        spectrogram_square_sum = torch.zeros(self._n_mels)

        print("Calculating training data statistics...")  # noqa: T201
        n_samples = len(dataset)
        n_spectrogram_cols = 0
        for i in range(n_samples):
            _, label, spectrogram, _ = dataset[i]
            spectrogram_sum += spectrogram.sum(-1)
            spectrogram_square_sum += spectrogram.square().sum(-1)

            n_spectrogram_cols += spectrogram.shape[-1]

            if label in self.labels:
                idx = self.labels.index(label)
                label_count[idx] += 1
            else:
                label_count[-1] += 1

        label_mean = label_count / n_samples
        spectrogram_mean = spectrogram_sum / n_spectrogram_cols
        # Center the sum of squares by n * mean^2 to get the sample variance.
        spectrogram_std = (spectrogram_square_sum - n_spectrogram_cols * spectrogram_mean.square()) / (n_spectrogram_cols - 1)
        spectrogram_std.sqrt_()
        with open(filepath, "w") as file:
            d = {
                "labels": self.labels,
                "n_mels": self._n_mels,
                "white_noise_mag": self._white_noise_mag,
                "SEED": self.SEED,
                "split_index": self._dataset_split_idx,
                "dataset_total_splits": self._dataset_total_splits,
                "hop_length": self._hop_length,
                "label_mean": label_mean.tolist(),
                "spectrogram_mean": spectrogram_mean.numpy().tolist(),
                "spectrogram_std": spectrogram_std.numpy().tolist(),
            }
            json.dump(d, file)
        return torch.tensor(label_mean), spectrogram_mean[:, None], spectrogram_std[:, None]

    def get_collate_fn(self) -> Callable:
        def collate_fn(batch: list[tuple[int, str, torch.Tensor, torch.Tensor]]) -> tuple[torch.Tensor, torch.Tensor]:
            ys, _, spectrogram, _ = zip(*batch)
            return torch.tensor(ys, dtype=torch.long), torch.stack(spectrogram)

        return collate_fn

    def _get_spectogram_transform(self, n_mels: int, hop_length: int, sr: int, data_augmentation: bool = False) -> torch.nn.Sequential:
        if data_augmentation:
            return torch.nn.Sequential(
                torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=320, hop_length=hop_length, n_mels=n_mels),
                torchaudio.transforms.FrequencyMasking(freq_mask_param=int(n_mels * 0.2)),
                torchaudio.transforms.TimeMasking(time_mask_param=int(0.2 * 16000 / 160)),
                torchaudio.transforms.AmplitudeToDB(stype="power", top_db=80),
            )
        return torch.nn.Sequential(
            torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=320, hop_length=hop_length, n_mels=n_mels),
            torchaudio.transforms.AmplitudeToDB(stype="power", top_db=80),
        )


def get_dataloaders(
    path: str, keywords: list[str], dataset_split_idx: int, dataset_total_splits: int, batchsize_train: int, batchsize_valid: int
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Get the dataloaders for the training, validation, and testing datasets."""
    dataset_train = FedSCDataset(path, keywords, "training", dataset_split_idx, dataset_total_splits, data_augmentation=True)
    dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize_train, collate_fn=dataset_train.get_collate_fn(), shuffle=True, drop_last=True)

@@ -241,4 +273,3 @@ def get_dataloaders(path, keywords, dataset_split_idx, dataset_total_splits, bat
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=batchsize_valid, collate_fn=dataset_test.get_collate_fn(), shuffle=False, drop_last=False)

    return dataloader_train, dataloader_valid, dataloader_test
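
# Editor's usage sketch (assumptions: the path and keyword list below are
# placeholders; SPEECHCOMMANDS downloads on first use). It shows the tensors the
# collate_fn yields for one federated split out of two, plus the cached
# normalization statistics from get_stats().
if __name__ == "__main__":
    train_dl, valid_dl, test_dl = get_dataloaders(
        "./data", ["yes", "no"], dataset_split_idx=0, dataset_total_splits=2, batchsize_train=32, batchsize_valid=64
    )
    ys, spectrograms = next(iter(train_dl))
    print(ys.shape, spectrograms.shape)  # torch.Size([32]) and a (32, n_mels, time) spectrogram batch

    label_mean, spec_mean, spec_std = train_dl.dataset.get_stats()
    print(label_mean.shape, spec_mean.shape, spec_std.shape)  # (n_labels,), (n_mels, 1), (n_mels, 1)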

20 changes: 8 additions & 12 deletions examples/pytorch-keyworddetection-api/fedn_api.py
@@ -1,26 +1,23 @@
import argparse

from data import get_dataloaders
from model import compile_model, model_hyperparams, save_parameters
from settings import BATCHSIZE_VALID, DATASET_PATH, KEYWORDS

from fedn import APIClient

HOST = "" ## INSERT HOST
HOST = "" ## INSERT HOST
TOKEN = "" ## INSERT TOKEN

def init_seedmodel(api_client: APIClient) -> dict:
    """Send a seed model to the server. The seed model is normalized with all training data."""
    dataloader_train, _, _ = get_dataloaders(DATASET_PATH, KEYWORDS, 0, 1, BATCHSIZE_VALID, BATCHSIZE_VALID)
    sc_model = compile_model(**model_hyperparams(dataloader_train.dataset))
    seed_path = "seed.npz"
    save_parameters(sc_model, seed_path)

    return api_client.set_active_model(seed_path)


def main():
@@ -39,13 +36,12 @@ def main():
        response = init_seedmodel(api_client)
        print(response)
    elif args.start_session:
        # Depending on the computer hosting the clients, this round_timeout might need to be increased
        response = api_client.start_session(round_timeout=600)
        print(response)
    else:
        print("No flag passed")


if __name__ == "__main__":
    main()
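
The argparse wiring in main() is elided from this diff. A minimal sketch consistent with the attribute names used above (the flag spellings --init-seedmodel and --start-session and the description string are assumptions, not confirmed by the diff):

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Manage the FEDn keyword-detection example")  # assumed description
    parser.add_argument("--init-seedmodel", action="store_true", help="build and upload a seed model")  # assumed flag
    parser.add_argument("--start-session", action="store_true", help="start a training session")  # assumed flag
    return parser.parse_args()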