diff --git a/.github/workflows/smoke_tests.yaml b/.github/workflows/smoke_tests.yaml
index 775cce716..e832efd02 100644
--- a/.github/workflows/smoke_tests.yaml
+++ b/.github/workflows/smoke_tests.yaml
@@ -22,6 +22,9 @@ jobs:
       # Display the Python version being used
       - name: Display Python version
         run: python -c "import sys; print(sys.version)"
+      - name: Set up file descriptor limit
+        run: |
+          ulimit -n 4096
       - name: Install and configure Poetry
         uses: snok/install-poetry@v1
         with:
diff --git a/.github/workflows/static_code_checks.yaml b/.github/workflows/static_code_checks.yaml
index 7cc294f51..48ebebf4b 100644
--- a/.github/workflows/static_code_checks.yaml
+++ b/.github/workflows/static_code_checks.yaml
@@ -48,6 +48,7 @@ jobs:
           # pip-audit to ignore these warnings
           # Vulnerabilities from GHSA-x38x-g6gr-jqff to GHSA-7p8j-qv6x-f4g4
           # originate from mlflow
+          # Ignore pytorch vulnerability GHSA-pg7h-5qx3-wjr3
           ignore-vulns: |
             GHSA-x38x-g6gr-jqff
             GHSA-j8mg-pqc5-x9gj
diff --git a/.gitignore b/.gitignore
index c207f71ca..059adf536 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,6 +150,7 @@ settings.json
 **/datasets/skin_cancer/PAD-UFES-20/**
 **/datasets/skin_cancer/ISIC_2019/**
 **/datasets/skin_cancer/Derm7pt/**
+**/datasets/nnunet/**
 
 # logs
diff --git a/examples/nnunet_example/README.md b/examples/nnunet_example/README.md
new file mode 100644
index 000000000..2955b5ad9
--- /dev/null
+++ b/examples/nnunet_example/README.md
@@ -0,0 +1,40 @@
+# NnUNetClient Example
+
+This example demonstrates how to use the NnUNetClient to train nnunet segmentation models in a federated setting.
+
+By default, this example trains an nnunet model on the Task04_Hippocampus dataset from the Medical Segmentation Decathlon (MSD). However, any of the MSD datasets can be used by specifying it with the client's msd_dataset_name flag. To run this example, first create a config file for the server. An example config has been provided in this directory. The required keys for the config are:
+
+```yaml
+# Parameters that describe the server
+n_server_rounds: 1
+
+# Parameters that describe the clients
+n_clients: 1
+local_epochs: 1 # Or local_steps, one or the other must be chosen
+
+nnunet_config: 2d
+```
+
+The only additional parameter required by nnunet is nnunet_config, which must be one of the official nnunet configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres).
+
+One may also add the following optional keys to the config yaml file:
+
+```yaml
+# Optional config parameters
+nnunet_plans: /Path/to/nnunet/plans.json
+starting_checkpoint: /Path/to/starting/checkpoint.pth
+```
+
+To run a federated learning experiment with nnunet models, first ensure you are in the FL4Health directory and then start the nnunet server using the following command. To view a list of optional flags, use the --help flag:
+
+```bash
+python -m examples.nnunet_example.server --config_path examples/nnunet_example/config.yaml
+```
+
+Once the server has started, start the necessary number of clients as specified by the n_clients key in the config file. Each client can be started by running the following command in a separate session. To view a list of optional flags, use the --help flag:
+
+```bash
+python -m examples.nnunet_example.client --dataset_path examples/datasets/nnunet
+```
+
+The MSD dataset will be downloaded and prepared automatically by the example script if it does not already exist; the dataset_path flag serves as a working directory for the client's data. The client will create nnunet_raw, nnunet_preprocessed and nnunet_results subdirectories in the dataset_path folder if they do not already exist. The dataset itself is stored in a folder within nnunet_raw. Therefore, when checking whether the data already exists, the client looks for the folder '{dataset_path}/nnunet_raw/{dataset_name}'.
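+
+For example, to train on a different MSD dataset or validation fold, the optional client flags defined by the client's argument parser can be combined (Task02_Heart is one of the MSD tasks listed in fl4health/utils/msd_dataset_sources.py):
+
+```bash
+python -m examples.nnunet_example.client \
+    --dataset_path examples/datasets/nnunet \
+    --msd_dataset_name Task02_Heart \
+    --fold all
+```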
diff --git a/examples/nnunet_example/__init__.py b/examples/nnunet_example/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/nnunet_example/client.py b/examples/nnunet_example/client.py
new file mode 100644
index 000000000..4b66c8b1f
--- /dev/null
+++ b/examples/nnunet_example/client.py
@@ -0,0 +1,172 @@
+import argparse
+import os
+import warnings
+from logging import INFO
+from os.path import exists, join
+from pathlib import Path
+from typing import Union
+
+with warnings.catch_warnings():
+    # Need to import lightning utilities now in order to avoid deprecation
+    # warnings. Ignore the flake8 warning saying that it is unused.
+    # lightning_utilities is imported by some of the dependencies, so by
+    # importing it now and filtering the warnings we avoid the issue:
+    # https://github.com/Lightning-AI/utilities/issues/119
+    warnings.filterwarnings("ignore", category=DeprecationWarning)
+    import lightning_utilities  # noqa: F401
+
+import torch
+from flwr.client import start_client
+from flwr.common.logger import log
+from torchmetrics.segmentation import GeneralizedDiceScore
+
+from fl4health.utils.load_data import load_msd_dataset
+from fl4health.utils.metrics import TorchMetric, TransformsMetric
+from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels
+from research.picai.fl_nnunet.transforms import get_annotations_from_probs, get_probabilities_from_logits
+
+
+def main(
+    dataset_path: Path,
+    msd_dataset_name: str,
+    server_address: str,
+    fold: Union[int, str],
+    always_preprocess: bool = False,
+) -> None:
+
+    # Log device and server address
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    log(INFO, f"Using device: {DEVICE}")
+    log(INFO, f"Using server address: {server_address}")
+
+    # Load the dataset if necessary
+    msd_dataset_enum = get_msd_dataset_enum(msd_dataset_name)
+    nnUNet_raw = join(dataset_path, "nnunet_raw")
+    if not exists(join(nnUNet_raw, msd_dataset_enum.value)):
+        log(INFO, f"Downloading and extracting {msd_dataset_enum.value} dataset")
+        load_msd_dataset(nnUNet_raw, msd_dataset_name)
+
+    # The dataset ID will be the same as the MSD Task number
+    dataset_id = int(msd_dataset_enum.value[4:6])
+    nnunet_dataset_name = f"Dataset{dataset_id:03d}_{msd_dataset_enum.value.split('_')[1]}"
+
+    # Convert the msd dataset if necessary. convert_msd_dataset is imported
+    # in the __main__ block below, after the nnunet environment variables
+    # have been set, so it is available here as a module global.
+    if not exists(join(nnUNet_raw, nnunet_dataset_name)):
+        log(INFO, f"Converting {msd_dataset_enum.value} into nnunet dataset")
+        convert_msd_dataset(source_folder=join(nnUNet_raw, msd_dataset_enum.value))
+
+    # Create a metric
+    dice = TransformsMetric(
+        metric=TorchMetric(
+            name="Pseudo DICE",
+            metric=GeneralizedDiceScore(
+                num_classes=msd_num_labels[msd_dataset_enum], weight_type="square", include_background=False
+            ).to(DEVICE),
+        ),
+        transforms=[get_probabilities_from_logits, get_annotations_from_probs],
+    )
+
+    # Create client
+    client = nnUNetClient(
+        # Args specific to nnUNetClient
+        dataset_id=dataset_id,
+        fold=fold,
+        always_preprocess=always_preprocess,
+        # BaseClient Args
+        device=DEVICE,
+        metrics=[dice],
+        data_path=dataset_path,  # Argument not actually used by nnUNetClient
+    )
+
+    start_client(server_address=server_address, client=client.to_client())
+
+    # Shutdown the client
+    client.shutdown()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="nnunet_example/client.py",
+        description="""An example of the nnUNetClient on any of the Medical
+            Segmentation Decathlon (MSD) datasets. Automatically generates an
+            nnunet segmentation model and trains it in a federated setting""",
+    )
+
+    # Underscores are used instead of dashes because that is how the flags
+    # are defined in run_smoke_test.py
+    parser.add_argument(
+        "--dataset_path",
+        type=str,
+        required=True,
+        help="""Path to the folder in which data should be stored. This script
+            will automatically create nnunet_raw, and nnunet_preprocessed
+            subfolders if they don't already exist. This script will also
+            attempt to download and prepare the MSD Dataset into the
+            nnunet_raw folder if it does not already exist.""",
+    )
+    parser.add_argument(
+        "--fold",
+        type=str,
+        required=False,
+        default="0",
+        help="""[OPTIONAL] Which fold of the local client dataset to use for
+            validation. nnunet defaults to 5 folds (0 to 4). Can also be set
+            to 'all' to use all the data for both training and validation.
+            Defaults to fold 0""",
+    )
+    parser.add_argument(
+        "--msd_dataset_name",
+        type=str,
+        required=False,
+        default="Task04_Hippocampus",  # The smallest dataset
+        help="""[OPTIONAL] Name of the MSD dataset to use. The options are
+            defined by the values of the MsdDataset enum as returned by the
+            get_msd_dataset_enum function""",
+    )
+    parser.add_argument(
+        "--always-preprocess",
+        action="store_true",
+        required=False,
+        help="""[OPTIONAL] Use this to force preprocessing of the nnunet data
+            even if the preprocessed data is found to already exist""",
+    )
+    parser.add_argument(
+        "--server_address",
+        type=str,
+        required=False,
+        default="0.0.0.0:8080",
+        help="""[OPTIONAL] The server address for the clients to communicate
+            to the server through. Defaults to 0.0.0.0:8080""",
+    )
+    args = parser.parse_args()
+
+    # Create nnunet directory structure and set environment variables
+    nnUNet_raw = join(args.dataset_path, "nnunet_raw")
+    nnUNet_preprocessed = join(args.dataset_path, "nnunet_preprocessed")
+    if not exists(nnUNet_raw):
+        os.makedirs(nnUNet_raw)
+    if not exists(nnUNet_preprocessed):
+        os.makedirs(nnUNet_preprocessed)
+    os.environ["nnUNet_raw"] = nnUNet_raw
+    os.environ["nnUNet_preprocessed"] = nnUNet_preprocessed
+    os.environ["nnUNet_results"] = join(args.dataset_path, "nnunet_results")
+    log(INFO, "Setting nnunet environment variables")
+    log(INFO, f"\tnnUNet_raw: {nnUNet_raw}")
+    log(INFO, f"\tnnUNet_preprocessed: {nnUNet_preprocessed}")
+    log(INFO, f"\tnnUNet_results: {join(args.dataset_path, 'nnunet_results')}")
+
+    # Everything that uses the nnunetv2 module can only be imported after the
+    # environment variables are set
+    from nnunetv2.dataset_conversion.convert_MSD_dataset import convert_msd_dataset
+
+    from research.picai.fl_nnunet.nnunet_client import nnUNetClient
+
+    # Check fold argument and start main method
+    fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold)
+    main(
+        dataset_path=Path(args.dataset_path),
+        msd_dataset_name=args.msd_dataset_name,
+        server_address=args.server_address,
+        fold=fold,
+        always_preprocess=args.always_preprocess,
+    )
diff --git a/examples/nnunet_example/config.yaml b/examples/nnunet_example/config.yaml
new file mode 100644
index 000000000..d2371d0d6
--- /dev/null
+++ b/examples/nnunet_example/config.yaml
@@ -0,0 +1,8 @@
+# Parameters that describe the server
+n_server_rounds: 1
+
+# Parameters that describe the clients
+n_clients: 1
+local_epochs: 1
+
+nnunet_config: 2d
diff --git a/examples/nnunet_example/server.py b/examples/nnunet_example/server.py
new file mode 100644
index 000000000..56a23a717
--- /dev/null
+++ b/examples/nnunet_example/server.py
@@ -0,0 +1,122 @@
+import argparse
+import json
+import pickle
+import warnings
+from functools import partial
+from typing import Optional
+
+import yaml
+
+with warnings.catch_warnings():
+    # Need to import lightning utilities now in order to avoid deprecation
+    # warnings. Ignore the flake8 warning saying that it is unused.
+    # lightning_utilities is imported by some of the dependencies, so by
+    # importing it now and filtering the warnings we avoid the issue:
+    # https://github.com/Lightning-AI/utilities/issues/119
+    warnings.filterwarnings("ignore", category=DeprecationWarning)
+    import lightning_utilities  # noqa: F401
+
+import flwr as fl
+import torch
+from flwr.common.parameter import ndarrays_to_parameters
+from flwr.common.typing import Config
+from flwr.server.client_manager import SimpleClientManager
+from flwr.server.strategy import FedAvg
+
+from examples.utils.functions import make_dict_with_epochs_or_steps
+from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
+from research.picai.fl_nnunet.nnunet_server import NnUNetServer
+
+
+def get_config(
+    current_server_round: int,
+    nnunet_config: str,
+    n_server_rounds: int,
+    batch_size: int,
+    n_clients: int,
+    nnunet_plans: Optional[str] = None,
+    local_epochs: Optional[int] = None,
+    local_steps: Optional[int] = None,
+) -> Config:
+    # Create config
+    config: Config = {
+        "n_clients": n_clients,
+        "nnunet_config": nnunet_config,
+        "n_server_rounds": n_server_rounds,
+        "batch_size": batch_size,
+        **make_dict_with_epochs_or_steps(local_epochs, local_steps),
+        "current_server_round": current_server_round,
+    }
+
+    # Check if plans were provided
+    if nnunet_plans is not None:
+        with open(nnunet_plans, "r") as f:
+            plans_bytes = pickle.dumps(json.load(f))
+        config["nnunet_plans"] = plans_bytes
+
+    return config
+
+
+def main(config: dict, server_address: str) -> None:
+    # Partial function with everything set except current server round
+    fit_config_fn = partial(
+        get_config,
+        n_clients=config["n_clients"],
+        nnunet_config=config["nnunet_config"],
+        n_server_rounds=config["n_server_rounds"],
+        batch_size=0,  # Set this to 0 because we're not using it
+        nnunet_plans=config.get("nnunet_plans"),
+        local_epochs=config.get("local_epochs"),
+        local_steps=config.get("local_steps"),
+    )
+
+    if config.get("starting_checkpoint"):
+        model = torch.load(config["starting_checkpoint"])
+        # nnunet stores its pytorch models in its own checkpoint format
+        params = ndarrays_to_parameters([val.cpu().numpy() for _, val in model["network_weights"].items()])
+    else:
+        params = None
+
+    strategy = FedAvg(
+        min_fit_clients=config["n_clients"],
+        min_evaluate_clients=config["n_clients"],
+        min_available_clients=config["n_clients"],
+        on_fit_config_fn=fit_config_fn,
+        on_evaluate_config_fn=fit_config_fn,  # Nothing changes for eval
+        fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
+        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
+        initial_parameters=params,
+    )
+
+    server = NnUNetServer(
+        client_manager=SimpleClientManager(),
+        strategy=strategy,
+    )
+
+    fl.server.start_server(
+        server=server,
+        server_address=server_address,
+        config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
+    )
+
+    # Shutdown server
+    server.shutdown()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config_path", action="store", type=str, help="Path to the configuration file")
+    parser.add_argument(
+        "--server-address",
+        type=str,
+        required=False,
+        default="0.0.0.0:8080",
+        help="""[OPTIONAL] The address to use for the server. Defaults to
+            0.0.0.0:8080""",
+    )
+
+    args = parser.parse_args()
+
+    with open(args.config_path, "r") as f:
+        config = yaml.safe_load(f)
+
+    main(config, server_address=args.server_address)
diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py
index 01db446bf..307d5e5b7 100644
--- a/fl4health/server/base_server.py
+++ b/fl4health/server/base_server.py
@@ -1,12 +1,12 @@
 import datetime
-from logging import DEBUG, INFO, WARNING
+from logging import DEBUG, INFO, WARN, WARNING
 from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import torch.nn as nn
 from flwr.common import EvaluateRes, Parameters
 from flwr.common.logger import log
 from flwr.common.parameter import parameters_to_ndarrays
-from flwr.common.typing import Scalar
+from flwr.common.typing import Code, GetParametersIns, Scalar
 from flwr.server.client_manager import ClientManager
 from flwr.server.client_proxy import ClientProxy
 from flwr.server.history import History
@@ -350,3 +350,101 @@ def _hydrate_model_for_checkpointing(self) -> nn.Module:
         model_ndarrays = parameters_to_ndarrays(self.parameters)
         self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model)
         return self.server_model
+
+
+class FlServerWithInitializer(FlServer):
+    def __init__(
+        self,
+        client_manager: ClientManager,
+        strategy: Optional[Strategy] = None,
+        wandb_reporter: Optional[ServerWandBReporter] = None,
+        checkpointer: Optional[TorchCheckpointer] = None,
+        metrics_reporter: Optional[MetricsReporter] = None,
+    ) -> None:
+        """
+        Server with an initialize hook method that is called prior to fit.
+        Override the self.initialize method to do server initialization prior
+        to training but after the clients have been created. Can be useful if
+        the state of the server depends on the properties of the clients,
+        e.g. the nnunet server requests an nnunet plans dict to be generated
+        by a client if one was not provided.
+
+        Args:
+            client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if
+                they are to be sampled at all.
+            strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle
+                client updates and other information potentially sent by the participating clients. If None, the
+                strategy is FedAvg as set by the flwr Server.
+            wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log
+                information and results to a Weights and Biases account. If None is provided, no logging occurs.
+                Defaults to None.
+            checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform
+                server side checkpointing based on some criteria. If None, then no server-side checkpointing is
+                performed. Defaults to None.
+            metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
+                during the execution. Defaults to an instance of MetricsReporter with default init parameters.
+        """
+        super().__init__(client_manager, strategy, wandb_reporter, checkpointer, metrics_reporter)
+        self.initialized = False
+
+    def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters:
+        """
+        Get initial parameters from one of the available clients. Same as the
+        parent method except that a config is provided to the client when
+        requesting the initial parameters.
+        https://github.com/adap/flower/issues/3770
+
+        Note:
+            configure_fit is used here to bypass mypy errors, since
+            on_fit_config is not defined in the Strategy base class. The
+            downside is that configure_fit will wait until enough clients for
+            training are present instead of just sampling one client. Defining
+            a new init_config attribute was considered, but it would add
+            complexity to the Strategy interface.
+        """
+        # Server-side parameter initialization
+        parameters: Optional[Parameters] = self.strategy.initialize_parameters(client_manager=self._client_manager)
+        if parameters is not None:
+            log(INFO, "Using initial global parameters provided by strategy")
+            return parameters
+
+        # Get initial parameters from one of the clients
+        log(INFO, "Requesting initial parameters from one random client")
+        random_client = self._client_manager.sample(1)[0]
+        dummy_params = Parameters([], "None")
+        config = self.strategy.configure_fit(server_round, dummy_params, self._client_manager)[0][1].config
+        ins = GetParametersIns(config=config)
+        get_parameters_res = random_client.get_parameters(ins=ins, timeout=timeout, group_id=server_round)
+        if get_parameters_res.status.code == Code.OK:
+            log(INFO, "Received initial parameters from one random client")
+        else:
+            log(
+                WARN,
+                "Failed to receive initial parameters from the client. Empty initial parameters will be used.",
+            )
+        return get_parameters_res.parameters
+
+    def initialize(self, server_round: int, timeout: Optional[float] = None) -> None:
+        """
+        Hook method to allow the server to do some additional initialization
+        prior to training. For example, NnUNetServer uses this method to ask a
+        client to initialize the global nnunet plans if they are not provided
+        in the config.
+
+        Args:
+            server_round (int): The current server round. This hook method is
+                only called with server_round=0 at the beginning of self.fit
+            timeout (Optional[float], optional): The server's timeout
+                parameter. Useful if one is requesting information from a
+                client. Defaults to None.
+        """
+        self.initialized = True
+
+    def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
+        """
+        Same as the parent method except that the initialize hook method is
+        called first
+        """
+        # Initialize the server
+        if not self.initialized:
+            self.initialize(server_round=0, timeout=timeout)
+
+        return super().fit(num_rounds, timeout)
diff --git a/fl4health/utils/load_data.py b/fl4health/utils/load_data.py
index e72f9ae5b..606f233dd 100644
--- a/fl4health/utils/load_data.py
+++ b/fl4health/utils/load_data.py
@@ -1,4 +1,5 @@
 import random
+import warnings
 from logging import INFO
 from pathlib import Path
 from typing import Callable, Dict, Optional, Tuple
@@ -12,8 +13,14 @@
 from fl4health.utils.dataset import TensorDataset
 from fl4health.utils.dataset_converter import DatasetConverter
+from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_md5_hashes, msd_urls
 from fl4health.utils.sampler import LabelBasedSampler
 
+with warnings.catch_warnings():
+    # ignoring some annoying scipy deprecation warnings
+    warnings.simplefilter("ignore", category=DeprecationWarning)
+    from monai.apps.utils import download_and_extract
+
 
 class ToNumpy:
     def __call__(self, tensor: torch.Tensor) -> np.ndarray:
@@ -255,3 +262,22 @@ def load_cifar10_test_data(
     evaluation_loader = DataLoader(evaluation_set, batch_size=batch_size, shuffle=False)
     num_examples = {"eval_set": len(evaluation_set)}
     return evaluation_loader, num_examples
+
+
+def load_msd_dataset(data_path: str, msd_dataset_name: str) -> None:
+    """
+    Downloads and extracts one of the 10 Medical Segmentation Decathlon (MSD)
+    datasets.
+
+    Args:
+        data_path (str): Path to the folder in which to extract the
+            dataset. The data itself will be in a subfolder named after the
+            dataset, not in the data_path directory itself. The name of the
+            folder will be the name of the dataset as defined by the values
+            of the MsdDataset enum returned by get_msd_dataset_enum
+        msd_dataset_name (str): One of the 10 MSD datasets
+    """
+    msd_enum = get_msd_dataset_enum(msd_dataset_name)
+    msd_hash = msd_md5_hashes[msd_enum]
+    url = msd_urls[msd_enum]
+    download_and_extract(url=url, output_dir=data_path, hash_val=msd_hash, hash_type="md5", progress=True)
diff --git a/fl4health/utils/msd_dataset_sources.py b/fl4health/utils/msd_dataset_sources.py
new file mode 100644
index 000000000..ba07cc808
--- /dev/null
+++ b/fl4health/utils/msd_dataset_sources.py
@@ -0,0 +1,63 @@
+from enum import Enum
+
+
+class MsdDataset(Enum):
+    TASK01_BRAINTUMOUR = "Task01_BrainTumour"
+    TASK02_HEART = "Task02_Heart"
+    TASK03_LIVER = "Task03_Liver"
+    TASK04_HIPPOCAMPUS = "Task04_Hippocampus"
+    TASK05_PROSTATE = "Task05_Prostate"
+    TASK06_LUNG = "Task06_Lung"
+    TASK07_PANCREAS = "Task07_Pancreas"
+    TASK08_HEPATICVESSEL = "Task08_HepaticVessel"
+    TASK09_SPLEEN = "Task09_Spleen"
+    TASK10_COLON = "Task10_Colon"
+
+
+def get_msd_dataset_enum(dataset_name: str) -> MsdDataset:
+    try:
+        return MsdDataset(dataset_name)
+    except ValueError:
+        raise ValueError(f"'{dataset_name}' is not a valid MSD dataset name. Must be one of the MsdDataset enum values")
+
+
+msd_urls = {
+    MsdDataset.TASK01_BRAINTUMOUR: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task01_BrainTumour.tar",
+    MsdDataset.TASK02_HEART: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task02_Heart.tar",
+    MsdDataset.TASK03_LIVER: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task03_Liver.tar",
+    MsdDataset.TASK04_HIPPOCAMPUS: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar",
+    MsdDataset.TASK05_PROSTATE: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task05_Prostate.tar",
+    MsdDataset.TASK06_LUNG: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task06_Lung.tar",
+    MsdDataset.TASK07_PANCREAS: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task07_Pancreas.tar",
+    MsdDataset.TASK08_HEPATICVESSEL: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task08_HepaticVessel.tar",
+    MsdDataset.TASK09_SPLEEN: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar",
+    MsdDataset.TASK10_COLON: "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task10_Colon.tar",
+}
+
+msd_md5_hashes = {
+    MsdDataset.TASK01_BRAINTUMOUR: "240a19d752f0d9e9101544901065d872",
+    MsdDataset.TASK02_HEART: "06ee59366e1e5124267b774dbd654057",
+    MsdDataset.TASK03_LIVER: "a90ec6c4aa7f6a3d087205e23d4e6397",
+    MsdDataset.TASK04_HIPPOCAMPUS: "9d24dba78a72977dbd1d2e110310f31b",
+    MsdDataset.TASK05_PROSTATE: "35138f08b1efaef89d7424d2bcc928db",
+    MsdDataset.TASK06_LUNG: "8afd997733c7fc0432f71255ba4e52dc",
+    MsdDataset.TASK07_PANCREAS: "4f7080cfca169fa8066d17ce6eb061e4",
+    MsdDataset.TASK08_HEPATICVESSEL: "641d79e80ec66453921d997fbf12a29c",
+    MsdDataset.TASK09_SPLEEN: "410d4a301da4e5b2f6f86ec3ddba524e",
+    MsdDataset.TASK10_COLON: "bad7a188931dc2f6acf72b08eb6202d0",
+}
+
+# The number of classes for each MSD dataset (including background)
+# Taken from the MSD paper; not every dataset was downloaded to double check
+msd_num_labels = {
+    MsdDataset.TASK01_BRAINTUMOUR: 4,
+    MsdDataset.TASK02_HEART: 2,
+    MsdDataset.TASK03_LIVER: 3,
+    MsdDataset.TASK04_HIPPOCAMPUS: 3,
+    MsdDataset.TASK05_PROSTATE: 3,
+    MsdDataset.TASK06_LUNG: 2,
+    MsdDataset.TASK07_PANCREAS: 3,
+    MsdDataset.TASK08_HEPATICVESSEL: 3,
+    MsdDataset.TASK09_SPLEEN: 2,
+    MsdDataset.TASK10_COLON: 2,
+}
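Taken together, load_msd_dataset and the lookup tables above cover the whole download path. A minimal usage sketch, illustrative only, using just the functions introduced in this diff (the target directory is an arbitrary choice):

```python
from fl4health.utils.load_data import load_msd_dataset
from fl4health.utils.msd_dataset_sources import get_msd_dataset_enum, msd_num_labels

# Resolve the enum member from the task name, then look up its label count
msd_enum = get_msd_dataset_enum("Task04_Hippocampus")
num_labels = msd_num_labels[msd_enum]  # 3, including background

# Downloads and extracts the tarball into {data_path}/Task04_Hippocampus
load_msd_dataset(data_path="examples/datasets/nnunet/nnunet_raw", msd_dataset_name="Task04_Hippocampus")
```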
diff --git a/poetry.lock b/poetry.lock
index 24ed51840..a3dcab397 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -2611,7 +2611,6 @@ description = "Clang Python Bindings, mirrored from the official LLVM repo: http
 optional = false
 python-versions = "*"
 files = [
-    {file = "libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a"},
     {file = "libclang-18.1.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5"},
     {file = "libclang-18.1.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8"},
     {file = "libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl", hash = "sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b"},
@@ -5656,18 +5655,19 @@ test = ["pytest"]
 
 [[package]]
 name = "setuptools"
-version = "70.0.0"
+version = "69.5.1"
 description = "Easily download, build, install, upgrade, and uninstall Python packages"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"},
-    {file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"},
+    {file = "setuptools-69.5.1-py3-none-any.whl", hash = "sha256:c636ac361bc47580504644275c9ad802c50415c7522212252c033bd15f301f32"},
+    {file = "setuptools-69.5.1.tar.gz", hash = "sha256:6c1fccdac05a97e598fb0ae3bbed5904ccb317337a51139dcd51453611bbb987"},
 ]
 
 [package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
-testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
+testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
+testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
 
 [[package]]
 name = "shellingham"
@@ -7312,4 +7312,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9.0,<3.11"
-content-hash = "3e99bd34f28aaf5adc97f53a9e991a3085749c6e203eefd546aaa0951a6ecbd3"
+content-hash = "a98dd7469d0a2d873b1f4dc04ab31a89fba8642221d8d97c0fbfaf6ad2bfbf62"
diff --git a/pyproject.toml b/pyproject.toml
index d067fc4b7..4041a079b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,11 @@ dp-accounting = "^0.4.3"
 torchmetrics = "^1.3.0"
 aiohttp = "^3.9.3"
 urllib3 = "^2.2.2"
+setuptools = "69.5.1"
+# Documented issues with setuptools 70.0.0
+# https://stackoverflow.com/a/78606253/24046590
+# https://github.com/vllm-project/vllm/issues/4961
+# Temporary solution is to pin to 69.5.1
 
 [tool.poetry.group.dev.dependencies]
 # locked the 2.13 version because of restrictions with tensorflow-io
diff --git a/research/picai/fl_nnunet/config.yaml b/research/picai/fl_nnunet/config.yaml
index 850008a28..c43bd2e42 100644
--- a/research/picai/fl_nnunet/config.yaml
+++ b/research/picai/fl_nnunet/config.yaml
@@ -1,9 +1,7 @@
 # You should set these yourself
 n_clients: 1
 nnunet_config: 2d
-nnunet_plans: /home/shawn/Code/nnunet_storage/nnUNet_preprocessed/Dataset012_PICAI-debug/nnUNetPlans.json
-fold: 0
+# nnunet_plans: /home/shawn/Code/nnunet_storage/nnUNet_preprocessed/Dataset012_PICAI-debug/nnUNetPlans.json
 n_server_rounds: 1
-local_epochs: 3
-server_address: '0.0.0.0:8080'
-starting_checkpoint: /home/shawn/Code/nnunet_storage/nnUNet_results/Dataset012_PICAI-debug/nnUNetTrainer_1epoch__nnUNetPlans__2d/fold_0/checkpoint_best.pth
+local_epochs: 2
+# starting_checkpoint: /home/shawn/Code/nnunet_storage/nnUNet_results/Dataset012_PICAI-debug/nnUNetTrainer_1epoch__nnUNetPlans__2d/fold_0/checkpoint_best.pth
diff --git a/research/picai/fl_nnunet/nnunet_client.py b/research/picai/fl_nnunet/nnunet_client.py
index 7492e2b52..a5cc19696 100644
--- a/research/picai/fl_nnunet/nnunet_client.py
+++ b/research/picai/fl_nnunet/nnunet_client.py
@@ -1,16 +1,16 @@
 import logging
+import os
 import pickle
 import signal
 import warnings
 from logging import INFO
-from os import makedirs
 from os.path import exists, join
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from flwr.common.logger import log
-from flwr.common.typing import Config
+from flwr.common.typing import Config, Scalar
 from torch import nn
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
@@ -37,13 +37,17 @@
     # silences a bunch of deprecation warnings related to scipy.ndimage
     # Raised an issue with nnunet. https://github.com/MIC-DKFZ/nnUNet/issues/2370
     warnings.filterwarnings("ignore", category=DeprecationWarning)
+    from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+    from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
     from batchgenerators.utilities.file_and_folder_operations import load_json, save_json
+    from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
     from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, preprocess_dataset
     from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
     from nnunetv2.training.dataloading.utils import unpack_dataset
     from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
     from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
+
 
 # Get the default signal handlers used by python before flwr overrides them
 # We need these because the nnunet dataloaders spawn child processes
 # and flwr throws errors when those processes end. So we set the signal handlers
@@ -137,6 +141,7 @@
         self.data_identifier = data_identifier
         self.always_preprocess: bool = always_preprocess
         self.plans_name = plans_identifier
+        self.fingerprint_extracted = False
 
         # nnunet specific attributes to be initialized in setup_client
         self.nnunet_trainer: nnUNetTrainer
@@ -239,7 +244,7 @@ def create_plans(self, config: Config) -> Dict[str, Any]:
         # Can't run nnunet preprocessing without saving plans file
         if not exists(join(nnUNet_preprocessed, self.dataset_name)):
-            makedirs(join(nnUNet_preprocessed, self.dataset_name))
+            os.makedirs(join(nnUNet_preprocessed, self.dataset_name))
         plans_save_path = join(nnUNet_preprocessed, self.dataset_name, self.plans_name + ".json")
         save_json(plans, plans_save_path, sort_keys=False)
         return plans
@@ -269,7 +274,22 @@ def maybe_preprocess(self, nnunet_config: NnUNetConfig) -> None:
                 configurations=[nnunet_config.value],
             )
         else:
-            log(INFO, "nnunet preprocessed data seems to already exist. Skipping preprocessing")
+            log(INFO, "\tnnunet preprocessed data seems to already exist. Skipping preprocessing")
+
+    def maybe_extract_fingerprint(self) -> None:
+        """
+        Checks if the nnunet dataset fingerprint already exists and, if not, extracts one from the dataset
+        """
+        fp_path = join(nnUNet_preprocessed, self.dataset_name, "dataset_fingerprint.json")
+        if self.always_preprocess or not exists(fp_path):
+            log(INFO, "\tExtracting nnunet dataset fingerprint")
+            with nostdout():  # prevent print statements from nnunet method
+                extract_fingerprints(dataset_ids=[self.dataset_id])
+        else:
+            log(INFO, "\tnnunet dataset fingerprint already exists. Skipping fingerprint extraction")
+
+        # Avoid extracting the fingerprint multiple times when always_preprocess is true
+        self.fingerprint_extracted = True
 
     def setup_client(self, config: Config) -> None:
         """
@@ -291,15 +311,11 @@
         # Get nnunet config
         self.nnunet_config = get_valid_nnunet_config(narrow_config_type(config, "nnunet_config", str))
 
-        # Check if dataset fingerprint has been extracted
-        if self.always_preprocess or not exists(
-            join(nnUNet_preprocessed, self.dataset_name, "dataset_fingerprint.json")
-        ):
-            log(INFO, "Extracting nnunet dataset fingerprint")
-            with nostdout():  # prevent print statements from nnunet method
-                extract_fingerprints(dataset_ids=[self.dataset_id])
+        # Check if dataset fingerprint has already been extracted
+        if not self.fingerprint_extracted:
+            self.maybe_extract_fingerprint()
         else:
-            log(INFO, "nnunet dataset fingerprint already exists. Skipping fingerprint extraction")
+            log(INFO, "\tDataset fingerprint has already been extracted. Skipping.")
 
         # Create the nnunet plans for the local client
         self.plans = self.create_plans(config=config)
@@ -319,6 +335,17 @@
         # do it manually since nnunet_trainer not being used for training
         self.nnunet_trainer.set_deep_supervision_enabled(self.nnunet_trainer.enable_deep_supervision)
 
+        # Prevent nnunet from generating log files, and delete empty output directories
+        os.remove(self.nnunet_trainer.log_file)
+        self.nnunet_trainer.log_file = os.devnull
+        output_folder = Path(self.nnunet_trainer.output_folder)
+        # Walk up the tree, removing directories as long as they are empty
+        while len(os.listdir(output_folder)) == 0:
+            os.rmdir(output_folder)
+            output_folder = output_folder.parent
+
         # Preprocess nnunet_raw data if needed
         self.maybe_preprocess(self.nnunet_config)
         unpack_dataset(  # Reduces load on CPU and RAM during training
@@ -522,3 +549,72 @@ def update_before_epoch(self, epoch: int) -> None:
     def get_client_specific_logs(self) -> Tuple[str, List[Tuple[LogLevel, str]]]:
         lr = self.optimizers["global"].param_groups[0]["lr"]
         return f" Current LR: {lr}", []
+
+    def get_properties(self, config: Config) -> Dict[str, Scalar]:
+        """
+        Return properties (sample counts and nnunet plans) of client.
+
+        If nnunet plans are not provided by the server, creates a new set of
+        nnunet plans from the local client dataset. These plans are intended
+        to be used for initializing global nnunet plans when they are not
+        provided.
+
+        Args:
+            config (Config): The config from the server
+
+        Returns:
+            Dict[str, Scalar]: A dictionary containing the train and
+            validation sample counts as well as the serialized nnunet plans
+        """
+        # Check if nnunet plans have already been initialized
+        if "nnunet_plans" in config.keys():
+            properties = super().get_properties(config)
+            properties["nnunet_plans"] = config["nnunet_plans"]
+            return properties
+
+        # Check if the local nnunet dataset fingerprint needs to be extracted
+        if not self.fingerprint_extracted:
+            self.maybe_extract_fingerprint()
+
+        # Create experiment planner and plans
+        planner = ExperimentPlanner(dataset_name_or_id=self.dataset_id, plans_name="temp_plans")
+        with nostdout():  # Prevent print statements from experiment planner
+            plans = planner.plan_experiment()
+        plans_bytes = pickle.dumps(plans)
+
+        # Remove plans file that was created by planner
+        plans_path = join(nnUNet_preprocessed, self.dataset_name, planner.plans_identifier + ".json")
+        if exists(plans_path):
+            os.remove(plans_path)
+
+        # Return properties with the initialized nnunet plans. The plans must
+        # be provided in the config since the client needs to be set up to
+        # get properties
+        config["nnunet_plans"] = plans_bytes
+        properties = super().get_properties(config)
+        properties["nnunet_plans"] = plans_bytes  # Already serialized above; do not pickle twice
+        return properties
+
+    def shutdown_dataloader(self, dataloader: Optional[DataLoader], dl_name: Optional[str] = None) -> None:
+        """
+        Checks the dataloader's type and, if it is a MultiThreadedAugmenter or
+        NonDetMultiThreadedAugmenter, calls the _finish method to ensure it is
+        properly shut down
+
+        Args:
+            dataloader (Optional[DataLoader]): The dataloader to shut down
+            dl_name (Optional[str]): A string that identifies the dataloader
+                to shut down. Used for logging purposes. Defaults to None
+        """
+        if dataloader is not None and isinstance(dataloader, nnUNetDataLoaderWrapper):
+            if isinstance(dataloader.nnunet_dataloader, (NonDetMultiThreadedAugmenter, MultiThreadedAugmenter)):
+                if dl_name is not None:
+                    log(INFO, f"\tShutting down nnunet dataloader: {dl_name}")
+                dataloader.nnunet_dataloader._finish()
+
+    def shutdown(self) -> None:
+        # Not entirely sure whether the processes potentially opened by the
+        # nnunet dataloaders were being ended, so ensure that they are ended here
+        self.shutdown_dataloader(self.train_loader, "train_loader")
+        self.shutdown_dataloader(self.val_loader, "val_loader")
+        self.shutdown_dataloader(self.test_loader, "test_loader")
+        return super().shutdown()
diff --git a/research/picai/fl_nnunet/nnunet_server.py b/research/picai/fl_nnunet/nnunet_server.py
new file mode 100644
index 000000000..9f4a9a0c0
--- /dev/null
+++ b/research/picai/fl_nnunet/nnunet_server.py
@@ -0,0 +1,86 @@
+from logging import INFO, WARN
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+from flwr.common import Parameters
+from flwr.common.logger import log
+from flwr.common.typing import Code, Config, EvaluateIns, FitIns, GetPropertiesIns
+from flwr.server.client_manager import ClientManager
+from flwr.server.client_proxy import ClientProxy
+
+from fl4health.server.base_server import FlServerWithInitializer
+
+FIT_CFG_FN = Callable[[int, Parameters, ClientManager], List[Tuple[ClientProxy, FitIns]]]
+EVAL_CFG_FN = Callable[[int, Parameters, ClientManager], List[Tuple[ClientProxy, EvaluateIns]]]
+CFG_FN = Union[FIT_CFG_FN, EVAL_CFG_FN]
+
+
+def add_items_to_config_fn(fn: CFG_FN, items: Config) -> CFG_FN:
+    """
+    Accepts a flwr Strategy configure function (either configure_fit or
+    configure_evaluate) and returns a new function that returns the same
+    thing, except that the dictionary items in the items argument have been
+    added to the config that is returned by the original function
+
+    Args:
+        fn (CFG_FN): The Strategy configure function to wrap
+        items (Config): A Config containing additional items to update the
+            original config with
+
+    Returns:
+        CFG_FN: The wrapped function. Argument and return types are unchanged
+    """
+
+    def new_fn(*args: Any, **kwargs: Any) -> Any:
+        cfg_ins = fn(*args, **kwargs)
+        for _, ins in cfg_ins:
+            ins.config.update(items)
+        return cfg_ins
+
+    return new_fn
+
+
+class NnUNetServer(FlServerWithInitializer):
+    """
+    A basic FlServer with added functionality to ask a client to initialize
+    the global nnunet plans if they were not provided in the config. Intended
+    for use with NnUNetClient
+    """
+
+    def initialize(self, server_round: int, timeout: Optional[float] = None) -> None:
+        # Get fit config
+        dummy_params = Parameters([], "None")
+        config = self.strategy.configure_fit(server_round, dummy_params, self._client_manager)[0][1].config
+
+        # Check if plans need to be initialized
+        if config.get("nnunet_plans") is not None:
+            self.initialized = True
+            return
+
+        # Sample properties from a random client to initialize plans
+        log(INFO, "")
+        log(INFO, "[PRE-INIT]")
+        log(INFO, "Requesting initialization of global nnunet plans from one random client via get_properties")
+        random_client = self._client_manager.sample(1)[0]
+        ins = GetPropertiesIns(config=config)
+        properties_res = random_client.get_properties(ins=ins, timeout=timeout, group_id=server_round)
+
+        if properties_res.status.code == Code.OK:
+            log(INFO, "Received global nnunet plans from one random client")
+        else:
+            log(WARN, "Failed to receive properties from client to initialize nnunet plans")
+
+        properties = properties_res.properties
+
+        # NnUNetClient has serialized nnunet_plans as a property
+        plans_bytes = properties["nnunet_plans"]
+
+        # Wrap config functions so that nnunet_plans is included
+        log(INFO, "Wrapping strategy config functions to return nnunet_plans")
+        new_fit_cfg_fn = add_items_to_config_fn(self.strategy.configure_fit, {"nnunet_plans": plans_bytes})
+        new_eval_cfg_fn = add_items_to_config_fn(self.strategy.configure_evaluate, {"nnunet_plans": plans_bytes})
+        setattr(self.strategy, "configure_fit", new_fit_cfg_fn)
+        setattr(self.strategy, "configure_evaluate", new_eval_cfg_fn)
+
+        # Finish
+        self.initialized = True
+        log(INFO, "")
diff --git a/research/picai/fl_nnunet/start_client.py b/research/picai/fl_nnunet/start_client.py
index 4b56885a6..f269fbe83 100644
--- a/research/picai/fl_nnunet/start_client.py
+++ b/research/picai/fl_nnunet/start_client.py
@@ -1,7 +1,7 @@
 import argparse
 import warnings
 from logging import INFO
-from os.path import join
+from pathlib import Path
 from typing import Optional, Union
 
 with warnings.catch_warnings():
@@ -16,8 +16,6 @@
 import flwr as fl
 import torch
 from flwr.common.logger import log
-from nnunetv2.paths import nnUNet_preprocessed
-from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
 from torchmetrics.classification import Dice
 from torchmetrics.segmentation import GeneralizedDiceScore
 
@@ -61,7 +59,6 @@ def main(
     metrics = [dice1, dice2]  # Oddly each of these dice metrics is drastically different.
 
     # Create and start client
-    dataset_name = convert_id_to_dataset_name(dataset_id)
     client = nnUNetClient(
         # Args specific to nnUNetClient
         dataset_id=dataset_id,
@@ -72,9 +69,7 @@
         # BaseClient Args
         device=DEVICE,
         metrics=metrics,
-        data_path=join(
-            nnUNet_preprocessed, dataset_name
-        ),  # data_path is not actually used but is required by BaseClient
+        data_path=Path("dummy/path"),  # Argument not used by nnUNetClient
     )
 
     fl.client.start_client(server_address=server_address, client=client.to_client())
@@ -135,16 +130,7 @@
     args = parser.parse_args()
 
     # Convert fold to an integer if it is not 'all'
-    if args.fold != "all":
-        try:
-            fold = int(args.fold)
-        except ValueError as e:
-            print(
-                f"Unable to convert given value for fold to int: {args.fold}. Fold must be either 'all' or an integer"
-            )
-            raise e
-    else:
-        fold = args.fold
+    fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold)
 
     main(
         dataset_id=args.dataset_id,
diff --git a/research/picai/fl_nnunet/start_server.py b/research/picai/fl_nnunet/start_server.py
index dbda59298..890366b53 100644
--- a/research/picai/fl_nnunet/start_server.py
+++ b/research/picai/fl_nnunet/start_server.py
@@ -23,41 +23,47 @@
 from flwr.server.strategy import FedAvg
 
 from examples.utils.functions import make_dict_with_epochs_or_steps
-from fl4health.server.base_server import FlServer  # This is the lightning utils deprecation warning culprit
 from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
+from research.picai.fl_nnunet.nnunet_server import NnUNetServer
 
 
 def get_config(
     current_server_round: int,
     nnunet_config: str,
-    nnunet_plans: str,
     n_server_rounds: int,
     batch_size: int,
     n_clients: int,
+    nnunet_plans: Optional[str] = None,
     local_epochs: Optional[int] = None,
     local_steps: Optional[int] = None,
 ) -> Config:
-    nnunet_plans_dict = pickle.dumps(json.load(open(nnunet_plans, "r")))
-    return {
+    # Create config
+    config: Config = {
         "n_clients": n_clients,
         "nnunet_config": nnunet_config,
-        "nnunet_plans": nnunet_plans_dict,
         "n_server_rounds": n_server_rounds,
         "batch_size": batch_size,
         **make_dict_with_epochs_or_steps(local_epochs, local_steps),
        "current_server_round": current_server_round,
     }
 
+    # Check if plans were provided
+    if nnunet_plans is not None:
+        plans_bytes = pickle.dumps(json.load(open(nnunet_plans, "r")))
+        config["nnunet_plans"] = plans_bytes
+
+    return config
+
 
-def main(config: dict) -> None:
+def main(config: dict, server_address: str) -> None:
     # Partial function with everything set except current server round
     fit_config_fn = partial(
         get_config,
         n_clients=config["n_clients"],
         nnunet_config=config["nnunet_config"],
-        nnunet_plans=config["nnunet_plans"],
         n_server_rounds=config["n_server_rounds"],
         batch_size=0,  # Set this to 0 because we're not using it
+        nnunet_plans=config.get("nnunet_plans"),
         local_epochs=config.get("local_epochs"),
         local_steps=config.get("local_steps"),
     )
@@ -67,12 +73,12 @@
         # Of course nnunet stores their pytorch models differently.
         params = ndarrays_to_parameters([val.cpu().numpy() for _, val in model["network_weights"].items()])
     else:
-        raise Exception(
-            "There is a bug right now where params can not be None. \
-            Therefore a starting checkpoint must be provided because I don't \
-            want to mess up my code. I hav raised an issue with flwr"
-        )
-        # params = None
+        # flwr previously errored when initial parameters were None (an issue
+        # was raised with flwr); params can now be None
+        params = None
 
     strategy = FedAvg(
         min_fit_clients=config["n_clients"],
@@ -85,25 +91,35 @@
         initial_parameters=params,
     )
 
-    server = FlServer(client_manager=SimpleClientManager(), strategy=strategy)
+    server = NnUNetServer(client_manager=SimpleClientManager(), strategy=strategy)
 
     fl.server.start_server(
         server=server,
-        server_address=config["server_address"],
+        server_address=server_address,
         config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
     )
 
     # Shutdown server
-    # server.shutdown()
+    server.shutdown()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--config-path", action="store", type=str, help="Path to the configuration file")
+    parser.add_argument(
+        "--server-address",
+        type=str,
+        required=False,
+        default="0.0.0.0:8080",
+        help="""[OPTIONAL] The address to use for the server. Defaults to
+            0.0.0.0:8080""",
+    )
+
     args = parser.parse_args()
 
     with open(args.config_path, "r") as f:
         config = yaml.safe_load(f)
 
-    main(config)
+    main(config, args.server_address)
diff --git a/tests/smoke_tests/nnunet_config_2d.yaml b/tests/smoke_tests/nnunet_config_2d.yaml
new file mode 100644
index 000000000..d2371d0d6
--- /dev/null
+++ b/tests/smoke_tests/nnunet_config_2d.yaml
@@ -0,0 +1,8 @@
+# Parameters that describe the server
+n_server_rounds: 1
+
+# Parameters that describe the clients
+n_clients: 1
+local_epochs: 1
+
+nnunet_config: 2d
diff --git a/tests/smoke_tests/nnunet_config_3d.yaml b/tests/smoke_tests/nnunet_config_3d.yaml
new file mode 100644
index 000000000..5244fdacc
--- /dev/null
+++ b/tests/smoke_tests/nnunet_config_3d.yaml
@@ -0,0 +1,8 @@
+# Parameters that describe the server
+n_server_rounds: 1
+
+# Parameters that describe the clients
+n_clients: 1
+local_steps: 5
+
+nnunet_config: 3d_fullres
diff --git a/tests/smoke_tests/run_smoke_test.py b/tests/smoke_tests/run_smoke_test.py
index 841ba9da4..0b332f4ac 100644
--- a/tests/smoke_tests/run_smoke_test.py
+++ b/tests/smoke_tests/run_smoke_test.py
@@ -153,6 +153,7 @@ async def run_smoke_test(
     full_server_output = ""
     startup_messages = [
         # printed by fedprox, apfl, basic_example, fedbn, fedper, fedrep, and ditto, FENDA, fl_plus_local_ft and moon
+        # Update: this is no longer in the output; these examples are actually matched by the "[ROUND 1]" startup message
        "FL starting",
         # printed by scaffold
         "Using Warm Start Strategy. Waiting for clients to be available for polling",
@@ -161,6 +162,10 @@
         # printed by federated_eval
         "Federated Evaluation Starting",
         "[ROUND 1]",
+        # As far as I can tell, this is printed by most servers that inherit from FlServer
+        "Flower ECE: gRPC server running ",
+        "gRPC server running",
+        "server running",
     ]
 
     output_found = False
@@ -640,4 +645,20 @@ def load_metrics_from_file(file_path: str) -> Dict[str, Any]:
             dataset_path="examples/datasets/cifar_data/",
         )
     )
+    loop.run_until_complete(
+        run_smoke_test(  # By default this will use the Task04_Hippocampus dataset
+            server_python_path="examples.nnunet_example.server",
+            client_python_path="examples.nnunet_example.client",
+            config_path="tests/smoke_tests/nnunet_config_2d.yaml",
+            dataset_path="examples/datasets/nnunet",
+        )
+    )
+    loop.run_until_complete(
+        run_smoke_test(  # By default this will use the Task04_Hippocampus dataset
+            server_python_path="examples.nnunet_example.server",
+            client_python_path="examples.nnunet_example.client",
+            config_path="tests/smoke_tests/nnunet_config_3d.yaml",
+            dataset_path="examples/datasets/nnunet",
+        )
+    )
     loop.close()
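For local verification, one of the new smoke tests can also be run on its own. A minimal sketch, assuming the loop.run_until_complete calls above sit behind an if __name__ == "__main__" guard so that importing the module has no side effects:

```python
import asyncio

from tests.smoke_tests.run_smoke_test import run_smoke_test

# Run just the 2d nnunet smoke test (mirrors the loop.run_until_complete calls above)
asyncio.run(
    run_smoke_test(
        server_python_path="examples.nnunet_example.server",
        client_python_path="examples.nnunet_example.client",
        config_path="tests/smoke_tests/nnunet_config_2d.yaml",
        dataset_path="examples/datasets/nnunet",
    )
)
```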