diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_baseline.png b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_baseline.png
new file mode 100644
index 0000000000..9ff1fcdb4c
Binary files /dev/null and b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_baseline.png differ
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_fl.png b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_fl.png
new file mode 100644
index 0000000000..df082d406a
Binary files /dev/null and b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_fl.png differ
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_fl_he.png b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_fl_he.png
new file mode 100644
index 0000000000..b1610c4183
Binary files /dev/null and b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/figs/km_curve_fl_he.png differ
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/km_job.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/km_job.py
new file mode 100644
index 0000000000..6a3e79f164
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/km_job.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+from src.kaplan_meier_wf import KM
+from src.kaplan_meier_wf_he import KM_HE
+
+from nvflare import FedJob
+from nvflare.job_config.script_runner import ScriptRunner
+
+
+def main():
+    args = define_parser()
+    # Default paths
+    data_root = "/tmp/nvflare/dataset/km_data"
+    he_context_path = "/tmp/nvflare/he_context/he_context_client.txt"
+
+    # Set the script and config
+    if args.encryption:
+        job_name = "KM_HE"
+        train_script = "src/kaplan_meier_train_he.py"
+        script_args = f"--data_root {data_root} --he_context_path {he_context_path}"
+    else:
+        job_name = "KM"
+        train_script = "src/kaplan_meier_train.py"
+        script_args = f"--data_root {data_root}"
+
+    # Set the number of clients and threads
+    num_clients = args.num_clients
+    if args.num_threads:
+        num_threads = args.num_threads
+    else:
+        num_threads = num_clients
+
+    # Set the output workspace and job directories
+    workspace_dir = os.path.join(args.workspace_dir, job_name)
+    job_dir = args.job_dir
+
+    # Create the FedJob
+    job = FedJob(name=job_name, min_clients=num_clients)
+
+    # Define the KM controller workflow and send to server
+    if args.encryption:
+        controller = KM_HE(min_clients=num_clients, he_context_path=he_context_path)
+    else:
+        controller = KM(min_clients=num_clients)
+    job.to_server(controller)
+
+    # Define the ScriptRunner and send to all clients
+    runner = ScriptRunner(
+        script=train_script,
+        script_args=script_args,
+        params_exchange_format="raw",
+        launch_external_process=False,
+    )
+    job.to_clients(runner, tasks=["train"])
+
+    # Export the job
+    print("job_dir=", job_dir)
+    job.export_job(job_dir)
+
+    # Run the job
+    print("workspace_dir=", workspace_dir)
+    print("num_threads=", num_threads)
+    job.simulator_run(workspace_dir, n_clients=num_clients, threads=num_threads)
+
+
+def define_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--workspace_dir",
+        type=str,
+        default="/tmp/nvflare/jobs/km/workdir",
+        help="work directory, defaults to '/tmp/nvflare/jobs/km/workdir'",
+    )
+    parser.add_argument(
+        "--job_dir",
+        type=str,
+        default="/tmp/nvflare/jobs/km/jobdir",
+        help="directory for job export, defaults to '/tmp/nvflare/jobs/km/jobdir'",
+    )
+    parser.add_argument(
+        "--encryption",
+        action=argparse.BooleanOptionalAction,
+        help="whether to enable encryption, defaults to False",
+    )
+    parser.add_argument(
+        "--num_clients",
+        type=int,
+        default=5,
+        help="number of clients to simulate, defaults to 5",
+    )
+    parser.add_argument(
+        "--num_threads",
+        type=int,
+        help="number of threads to use for FL simulation, defaults to the number of clients if not specified",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/requirements.txt b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/requirements.txt
new file mode 100644
index 0000000000..e6d18ba9a3
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/requirements.txt
@@ -0,0 +1,3 @@
+lifelines
+tenseal
+scikit-survival
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_train.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_train.py
new file mode 100644
index 0000000000..d8d7e55d28
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_train.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
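+
+# This client script implements the two-round federated Kaplan-Meier workflow:
+# round 1: condense the local event data into observed/censored histograms and send them to the server;
+# round 2: receive the aggregated global histograms, unfold them back into an event list, and fit the
+# Kaplan-Meier estimator locally on the global data.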
+
+import argparse
+import json
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from lifelines import KaplanMeierFitter
+from lifelines.utils import survival_table_from_events
+
+# (1) import nvflare client API
+import nvflare.client as flare
+from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
+
+
+# Client code
+def details_save(kmf):
+    # Get the survival function at all observed time points
+    survival_function_at_all_times = kmf.survival_function_
+    # Get the timeline (time points)
+    timeline = survival_function_at_all_times.index.values
+    # Get the KM estimate
+    km_estimate = survival_function_at_all_times["KM_estimate"].values
+    # Get the observed event count at each time point
+    # (lifelines' event_table puts "removed" in the first column, so select "observed" by name)
+    event_count = kmf.event_table["observed"].values
+    # Get the cumulative event probability at each time point (1 - survival probability)
+    cumulative_event_prob = 1 - survival_function_at_all_times["KM_estimate"].values
+    # Save the results
+    results = {
+        "timeline": timeline.tolist(),
+        "km_estimate": km_estimate.tolist(),
+        "event_count": event_count.tolist(),
+        "cumulative_event_prob": cumulative_event_prob.tolist(),
+    }
+    file_path = os.path.join(os.getcwd(), "km_global.json")
+    print(f"save the details of KM analysis result to {file_path} \n")
+    with open(file_path, "w") as json_file:
+        json.dump(results, json_file, indent=4)
+
+
+def plot_and_save(kmf):
+    # Plot and save the Kaplan-Meier survival curve
+    plt.figure()
+    plt.title("Federated")
+    kmf.plot_survival_function()
+    plt.ylim(0, 1)
+    plt.ylabel("prob")
+    plt.xlabel("time")
+    plt.legend("", frameon=False)
+    plt.tight_layout()
+    file_path = os.path.join(os.getcwd(), "km_curve_fl.png")
+    print(f"save the curve plot to {file_path} \n")
+    plt.savefig(file_path)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="KM analysis")
+    parser.add_argument("--data_root", type=str, help="Root path for data files")
+    args = parser.parse_args()
+
+    flare.init()
+
+    site_name = flare.get_site_name()
+    print(f"Kaplan-Meier analysis for {site_name}")
+
+    # get local data
+    data_path = os.path.join(args.data_root, site_name + ".csv")
+    data = pd.read_csv(data_path)
+    event_local = data["event"]
+    time_local = data["time"]
+
+    while flare.is_running():
+        # receive the global message from NVFlare
+        global_msg = flare.receive()
+        curr_round = global_msg.current_round
+        print(f"current_round={curr_round}")
+
+        if curr_round == 1:
+            # First round:
+            # Empty payload from server, send local histogram
+            # Convert local data to histogram
+            event_table = survival_table_from_events(time_local, event_local)
+            hist_idx = event_table.index.values.astype(int)
+            # Initialize empty bins covering indices 0..max
+            hist_obs = {}
+            hist_cen = {}
+            for idx in range(max(hist_idx) + 1):
+                hist_obs[idx] = 0
+                hist_cen[idx] = 0
+            # Assign values from the local survival table
+            observed = event_table["observed"].to_numpy()
+            censored = event_table["censored"].to_numpy()
+            for i in range(len(hist_idx)):
+                hist_obs[hist_idx[i]] = observed[i]
+                hist_cen[hist_idx[i]] = censored[i]
+            # Send histograms to server
+            response = FLModel(params={"hist_obs": hist_obs, "hist_cen": hist_cen}, params_type=ParamsType.FULL)
+            flare.send(response)
+
+        elif curr_round == 2:
+            # Get global histograms
+            hist_obs_global = global_msg.params["hist_obs_global"]
+            hist_cen_global = global_msg.params["hist_cen_global"]
+            # Unfold histogram to event list
+            time_unfold = []
+            event_unfold = []
+            for i in hist_obs_global.keys():
+                for j in range(hist_obs_global[i]):
+                    time_unfold.append(i)
+                    event_unfold.append(True)
+                for k in range(hist_cen_global[i]):
+                    time_unfold.append(i)
+                    event_unfold.append(False)
+            time_unfold = np.array(time_unfold)
+            event_unfold = np.array(event_unfold)
+
+            # Perform Kaplan-Meier analysis on global aggregated information
+            # Create a Kaplan-Meier estimator
+            kmf = KaplanMeierFitter()
+
+            # Fit the model
+            kmf.fit(durations=time_unfold, event_observed=event_unfold)
+
+            # Plot and save the KM curve
+            plot_and_save(kmf)
+
+            # Save details of the KM result to a json file
+            details_save(kmf)
+
+            # Send a simple response to server
+            response = FLModel(params={}, params_type=ParamsType.FULL)
+            flare.send(response)
+
+    print(f"finish send for {site_name}, complete")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_train_he.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_train_he.py
new file mode 100644
index 0000000000..1ff9c69dbb
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_train_he.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
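+
+# This client script implements the three-round HE variant of the federated Kaplan-Meier
+# workflow: round 1 reports the local maximum histogram index so a global size can be agreed,
+# round 2 sends BFV-encrypted local histograms of that global size, and round 3 decrypts the
+# aggregated global histograms and fits the Kaplan-Meier estimator locally.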
+
+import argparse
+import base64
+import json
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import tenseal as ts
+from lifelines import KaplanMeierFitter
+from lifelines.utils import survival_table_from_events
+
+# (1) import nvflare client API
+import nvflare.client as flare
+from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
+
+
+# Client code
+def read_data(file_name: str):
+    with open(file_name, "rb") as f:
+        data = f.read()
+    return base64.b64decode(data)
+
+
+def details_save(kmf):
+    # Get the survival function at all observed time points
+    survival_function_at_all_times = kmf.survival_function_
+    # Get the timeline (time points)
+    timeline = survival_function_at_all_times.index.values
+    # Get the KM estimate
+    km_estimate = survival_function_at_all_times["KM_estimate"].values
+    # Get the observed event count at each time point
+    # (lifelines' event_table puts "removed" in the first column, so select "observed" by name)
+    event_count = kmf.event_table["observed"].values
+    # Get the cumulative event probability at each time point (1 - survival probability)
+    cumulative_event_prob = 1 - survival_function_at_all_times["KM_estimate"].values
+    # Save the results
+    results = {
+        "timeline": timeline.tolist(),
+        "km_estimate": km_estimate.tolist(),
+        "event_count": event_count.tolist(),
+        "cumulative_event_prob": cumulative_event_prob.tolist(),
+    }
+    file_path = os.path.join(os.getcwd(), "km_global.json")
+    print(f"save the details of KM analysis result to {file_path} \n")
+    with open(file_path, "w") as json_file:
+        json.dump(results, json_file, indent=4)
+
+
+def plot_and_save(kmf):
+    # Plot and save the Kaplan-Meier survival curve
+    plt.figure()
+    plt.title("Federated HE")
+    kmf.plot_survival_function()
+    plt.ylim(0, 1)
+    plt.ylabel("prob")
+    plt.xlabel("time")
+    plt.legend("", frameon=False)
+    plt.tight_layout()
+    file_path = os.path.join(os.getcwd(), "km_curve_fl_he.png")
+    print(f"save the curve plot to {file_path} \n")
+    plt.savefig(file_path)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="KM analysis")
+    parser.add_argument("--data_root", type=str, help="Root path for data files")
+    parser.add_argument("--he_context_path", type=str, help="Path for the HE context file")
+    args = parser.parse_args()
+
+    flare.init()
+
+    site_name = flare.get_site_name()
+    print(f"Kaplan-Meier analysis for {site_name}")
+
+    # get local data
+    data_path = os.path.join(args.data_root, site_name + ".csv")
+    data = pd.read_csv(data_path)
+    event_local = data["event"]
+    time_local = data["time"]
+
+    # HE context
+    # In a real-life application, the HE context is prepared by secure provisioning
+    he_context_serial = read_data(args.he_context_path)
+    he_context = ts.context_from(he_context_serial)
+
+    while flare.is_running():
+        # receive the global message from NVFlare
+        global_msg = flare.receive()
+        curr_round = global_msg.current_round
+        print(f"current_round={curr_round}")
+
+        if curr_round == 1:
+            # First round:
+            # Empty payload from server, send max index back
+            # Condense local data to histogram
+            event_table = survival_table_from_events(time_local, event_local)
+            hist_idx = event_table.index.values.astype(int)
+            # Get the max index to be synced globally
+            max_hist_idx = max(hist_idx)
+
+            # Send max to server
+            print(f"send max hist index for site = {flare.get_site_name()}")
+            model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL)
+            flare.send(model)
+
+        elif curr_round == 2:
+            # Second round, get global max index
+            # Organize local histogram and encrypt
+            max_idx_global = global_msg.params["max_idx_global"]
global_msg.params["max_idx_global"] + print("Global Max Idx") + print(max_idx_global) + # Convert local table to uniform histogram + hist_obs = {} + hist_cen = {} + for idx in range(max_idx_global): + hist_obs[idx] = 0 + hist_cen[idx] = 0 + # assign values + idx = event_table.index.values.astype(int) + observed = event_table["observed"].to_numpy() + censored = event_table["censored"].to_numpy() + for i in range(len(idx)): + hist_obs[idx[i]] = observed[i] + hist_cen[idx[i]] = censored[i] + # Encrypt with tenseal using BFV scheme since observations are integers + hist_obs_he = ts.bfv_vector(he_context, list(hist_obs.values())) + hist_cen_he = ts.bfv_vector(he_context, list(hist_cen.values())) + # Serialize for transmission + hist_obs_he_serial = hist_obs_he.serialize() + hist_cen_he_serial = hist_cen_he.serialize() + # Send encrypted histograms to server + response = FLModel( + params={"hist_obs": hist_obs_he_serial, "hist_cen": hist_cen_he_serial}, params_type=ParamsType.FULL + ) + flare.send(response) + + elif curr_round == 3: + # Get global histograms + hist_obs_global_serial = global_msg.params["hist_obs_global"] + hist_cen_global_serial = global_msg.params["hist_cen_global"] + # Deserialize + hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_global_serial) + hist_cen_global = ts.bfv_vector_from(he_context, hist_cen_global_serial) + # Decrypt + hist_obs_global = hist_obs_global.decrypt() + hist_cen_global = hist_cen_global.decrypt() + # Unfold histogram to event list + time_unfold = [] + event_unfold = [] + for i in range(max_idx_global): + for j in range(hist_obs_global[i]): + time_unfold.append(i) + event_unfold.append(True) + for k in range(hist_cen_global[i]): + time_unfold.append(i) + event_unfold.append(False) + time_unfold = np.array(time_unfold) + event_unfold = np.array(event_unfold) + + # Perform Kaplan-Meier analysis on global aggregated information + # Create a Kaplan-Meier estimator + kmf = KaplanMeierFitter() + + # Fit the model + kmf.fit(durations=time_unfold, event_observed=event_unfold) + + # Plot and save the KM curve + plot_and_save(kmf) + + # Save details of the KM result to a json file + details_save(kmf) + + # Send a simple response to server + response = FLModel(params={}, params_type=ParamsType.FULL) + flare.send(response) + + print(f"finish send for {site_name}, complete") + + +if __name__ == "__main__": + main() diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_wf.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_wf.py new file mode 100644 index 0000000000..54fa1d384c --- /dev/null +++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_wf.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import List
+
+from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
+from nvflare.app_common.workflows.model_controller import ModelController
+
+
+# Controller Workflow
+class KM(ModelController):
+    def __init__(self, min_clients: int):
+        super(KM, self).__init__()
+        self.logger = logging.getLogger(self.__class__.__name__)
+        self.min_clients = min_clients
+        self.num_rounds = 2
+
+    def run(self):
+        hist_local = self.start_fl_collect_hist()
+        hist_obs_global, hist_cen_global = self.aggr_hist(hist_local)
+        _ = self.distribute_global_hist(hist_obs_global, hist_cen_global)
+
+    def start_fl_collect_hist(self):
+        self.logger.info("send initial message to all sites to start FL \n")
+        model = FLModel(params={}, start_round=1, current_round=1, total_rounds=self.num_rounds)
+
+        results = self.send_model_and_wait(data=model)
+        return results
+
+    def aggr_hist(self, sag_result: List[FLModel]):
+        self.logger.info("aggregate histogram \n")
+
+        if not sag_result:
+            raise RuntimeError("input is None or empty")
+
+        # The global histogram length is the largest local bin index plus one
+        hist_idx_max = 0
+        for fl_model in sag_result:
+            hist = fl_model.params["hist_obs"]
+            if hist_idx_max < max(hist.keys()):
+                hist_idx_max = max(hist.keys())
+        hist_idx_max += 1
+
+        hist_obs_global = {}
+        hist_cen_global = {}
+        for idx in range(hist_idx_max):
+            hist_obs_global[idx] = 0
+            hist_cen_global[idx] = 0
+
+        for fl_model in sag_result:
+            hist_obs = fl_model.params["hist_obs"]
+            hist_cen = fl_model.params["hist_cen"]
+            for i in hist_obs.keys():
+                hist_obs_global[i] += hist_obs[i]
+            for i in hist_cen.keys():
+                hist_cen_global[i] += hist_cen[i]
+
+        return hist_obs_global, hist_cen_global
+
+    def distribute_global_hist(self, hist_obs_global, hist_cen_global):
+        self.logger.info("send global accumulated histograms to all sites \n")
+
+        model = FLModel(
+            params={"hist_obs_global": hist_obs_global, "hist_cen_global": hist_cen_global},
+            params_type=ParamsType.FULL,
+            start_round=1,
+            current_round=2,
+            total_rounds=self.num_rounds,
+        )
+
+        results = self.send_model_and_wait(data=model)
+        return results
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_wf_he.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_wf_he.py
new file mode 100644
index 0000000000..12acf51f4b
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/src/kaplan_meier_wf_he.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import logging
+from typing import List
+
+import tenseal as ts
+
+from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
+from nvflare.app_common.workflows.model_controller import ModelController
+
+
+# Controller Workflow
+class KM_HE(ModelController):
+    def __init__(self, min_clients: int, he_context_path: str):
+        super(KM_HE, self).__init__()
+        self.logger = logging.getLogger(self.__class__.__name__)
+        self.min_clients = min_clients
+        self.he_context_path = he_context_path
+        self.num_rounds = 3
+
+    def run(self):
+        max_idx_results = self.start_fl_collect_max_idx()
+        global_res = self.aggr_max_idx(max_idx_results)
+        enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res)
+        hist_obs_global, hist_cen_global = self.aggr_he_hist(enc_hist_results)
+        _ = self.distribute_global_hist(hist_obs_global, hist_cen_global)
+
+    def read_data(self, file_name: str):
+        with open(file_name, "rb") as f:
+            data = f.read()
+        return base64.b64decode(data)
+
+    def start_fl_collect_max_idx(self):
+        self.logger.info("send initial message to all sites to start FL \n")
+        model = FLModel(params={}, start_round=1, current_round=1, total_rounds=self.num_rounds)
+
+        results = self.send_model_and_wait(data=model)
+        return results
+
+    def aggr_max_idx(self, sag_result: List[FLModel]):
+        self.logger.info("aggregate max histogram index \n")
+
+        if not sag_result:
+            raise RuntimeError("input is None or empty")
+
+        max_idx_global = []
+        for fl_model in sag_result:
+            max_idx = fl_model.params["max_idx"]
+            max_idx_global.append(max_idx)
+        # actual time point as index, so plus 1 for storage
+        return max(max_idx_global) + 1
+
+    def distribute_max_idx_collect_enc_stats(self, result: int):
+        self.logger.info("send global max_index to all sites \n")
+
+        model = FLModel(
+            params={"max_idx_global": result},
+            params_type=ParamsType.FULL,
+            start_round=1,
+            current_round=2,
+            total_rounds=self.num_rounds,
+        )
+
+        results = self.send_model_and_wait(data=model)
+        return results
+
+    def aggr_he_hist(self, sag_result: List[FLModel]):
+        self.logger.info("aggregate histogram within HE \n")
+
+        # Load HE context
+        he_context_serial = self.read_data(self.he_context_path)
+        he_context = ts.context_from(he_context_serial)
+
+        if not sag_result:
+            raise RuntimeError("input is None or empty")
+
+        hist_obs_global = None
+        hist_cen_global = None
+        for fl_model in sag_result:
+            site = fl_model.meta.get("client_name", None)
+            hist_obs_he_serial = fl_model.params["hist_obs"]
+            hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial)
+            hist_cen_he_serial = fl_model.params["hist_cen"]
+            hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial)
+
+            # Explicit None checks: truthiness is not well defined for encrypted vectors
+            if hist_obs_global is None:
+                print(f"assign global hist with result from {site}")
+                hist_obs_global = hist_obs_he
+            else:
+                print(f"add to global hist with result from {site}")
+                hist_obs_global += hist_obs_he
+
+            if hist_cen_global is None:
+                print(f"assign global hist with result from {site}")
result from {site}") + hist_cen_global = hist_cen_he + else: + print(f"add to global hist with result from {site}") + hist_cen_global += hist_cen_he + + # return the two accumulated vectors, serialized for transmission + hist_obs_global_serial = hist_obs_global.serialize() + hist_cen_global_serial = hist_cen_global.serialize() + return hist_obs_global_serial, hist_cen_global_serial + + def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial): + self.logger.info("send global accumulated histograms within HE to all sites \n") + + model = FLModel( + params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, + params_type=ParamsType.FULL, + start_round=1, + current_round=3, + total_rounds=self.num_rounds, + ) + + results = self.send_model_and_wait(data=model) + return results diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/baseline_kaplan_meier.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/baseline_kaplan_meier.py new file mode 100644 index 0000000000..0bd37b0bb1 --- /dev/null +++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/baseline_kaplan_meier.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import argparse
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+from lifelines import KaplanMeierFitter
+from sksurv.datasets import load_veterans_lung_cancer
+
+
+def args_parser():
+    parser = argparse.ArgumentParser(description="Kaplan-Meier Survival Analysis Baseline")
+    parser.add_argument(
+        "--output_curve_path",
+        type=str,
+        default="/tmp/nvflare/baseline/km_curve_baseline.png",
+        help="save path for the output curve",
+    )
+    return parser
+
+
+def prepare_data(bin_days: int = 7):
+    data_x, data_y = load_veterans_lung_cancer()
+    event = data_y["Status"]
+    time = data_y["Survival_in_days"]
+    # Bin the event times; the default bin is a week (7 days)
+    time = np.ceil(time / bin_days).astype(int) * bin_days
+    return event, time
+
+
+def main():
+    parser = args_parser()
+    args = parser.parse_args()
+
+    # Set parameters
+    output_curve_path = args.output_curve_path
+
+    # Set plot
+    plt.figure()
+    plt.title("Baseline")
+
+    # Fit and plot Kaplan-Meier curve with lifelines
+
+    # Generate data with binning
+    event, time = prepare_data(bin_days=7)
+    kmf = KaplanMeierFitter()
+    # Fit the survival data
+    kmf.fit(time, event)
+    # Plot the Kaplan-Meier survival curve
+    kmf.plot_survival_function(label="Binned Weekly")
+
+    # Generate data without binning
+    event, time = prepare_data(bin_days=1)
+    kmf = KaplanMeierFitter()
+    # Fit the survival data
+    kmf.fit(time, event)
+    # Plot the Kaplan-Meier survival curve
+    kmf.plot_survival_function(label="No binning - Daily")
+
+    plt.ylim(0, 1)
+    plt.ylabel("prob")
+    plt.xlabel("time")
+    plt.tight_layout()
+    plt.legend()
+    # Make sure the output directory exists before saving
+    os.makedirs(os.path.dirname(output_curve_path), exist_ok=True)
+    plt.savefig(output_curve_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/prepare_data.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/prepare_data.py
new file mode 100644
index 0000000000..0517ad6274
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/prepare_data.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+from sksurv.datasets import load_veterans_lung_cancer
+
+np.random.seed(77)
+
+
+def data_split_args_parser():
+    parser = argparse.ArgumentParser(description="Generate data split for dataset")
+    parser.add_argument("--site_num", type=int, default=5, help="Total number of sites, default is 5")
+    parser.add_argument(
+        "--site_name_prefix",
+        type=str,
+        default="site-",
+        help="Site name prefix, default is site-",
+    )
+    parser.add_argument("--bin_days", type=int, default=1, help="Bin days for categorizing data")
+    parser.add_argument("--out_path", type=str, help="Output root path for split data files")
+    return parser
+
+
+def prepare_data(data, site_num, bin_days, site_name_prefix="site-"):
+    # Get total data count
+    total_data_num = data.shape[0]
+    print(f"Total data count: {total_data_num}")
+    # Get event and time
+    event = data["Status"]
+    time = data["Survival_in_days"]
+    # Bin the event times into intervals of bin_days days
+    time = np.ceil(time / bin_days).astype(int) * bin_days
+    # Shuffle data
+    idx = np.random.permutation(total_data_num)
+    # Split data to clients, keyed by the same site name prefix used for the output files
+    event_clients = {}
+    time_clients = {}
+    for i in range(site_num):
+        start = int(i * total_data_num / site_num)
+        end = int((i + 1) * total_data_num / site_num)
+        event_i = event[idx[start:end]]
+        time_i = time[idx[start:end]]
+        event_clients[site_name_prefix + str(i + 1)] = event_i
+        time_clients[site_name_prefix + str(i + 1)] = time_i
+    return event_clients, time_clients
+
+
+def main():
+    parser = data_split_args_parser()
+    args = parser.parse_args()
+
+    # Load data
+    # For this KM analysis, we use full timeline and event label only
+    _, data = load_veterans_lung_cancer()
+
+    # Prepare data
+    event_clients, time_clients = prepare_data(
+        data=data, site_num=args.site_num, bin_days=args.bin_days, site_name_prefix=args.site_name_prefix
+    )
+
+    # Save data to csv files
+    os.makedirs(args.out_path, exist_ok=True)
+    for site in range(args.site_num):
+        site_name = f"{args.site_name_prefix}{site + 1}"
+        output_file = os.path.join(args.out_path, f"{site_name}.csv")
+        df = pd.DataFrame(
+            {
+                "event": event_clients[site_name],
+                "time": time_clients[site_name],
+            }
+        )
+        df.to_csv(output_file, index=False)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/prepare_he_context.py b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/prepare_he_context.py
new file mode 100644
index 0000000000..ceedf4c9a4
--- /dev/null
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/code/utils/prepare_he_context.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import base64
+import os
+
+import tenseal as ts
+
+
+def he_context_args_parser():
+    parser = argparse.ArgumentParser(description="Generate HE context")
+    parser.add_argument("--scheme", type=str, default="BFV", help="HE scheme, default is BFV")
+    parser.add_argument("--poly_modulus_degree", type=int, default=4096, help="Poly modulus degree, default is 4096")
+    parser.add_argument("--out_path", type=str, help="Output root path for HE context files for client and server")
+    return parser
+
+
+def write_data(file_name: str, data: bytes):
+    data = base64.b64encode(data)
+    with open(file_name, "wb") as f:
+        f.write(data)
+
+
+def main():
+    parser = he_context_args_parser()
+    args = parser.parse_args()
+    if args.scheme == "BFV":
+        scheme = ts.SCHEME_TYPE.BFV
+        # Generate HE context
+        context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree, plain_modulus=1032193)
+    elif args.scheme == "CKKS":
+        scheme = ts.SCHEME_TYPE.CKKS
+        # Generate HE context, CKKS does not need plain_modulus
+        context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree)
+    else:
+        raise ValueError("HE scheme not supported")
+
+    # Save HE context to file for client (includes the secret key)
+    os.makedirs(args.out_path, exist_ok=True)
+    context_serial = context.serialize(save_secret_key=True)
+    write_data(os.path.join(args.out_path, "he_context_client.txt"), context_serial)
+
+    # Save HE context to file for server (public context only, no secret key)
+    context_serial = context.serialize(save_secret_key=False)
+    write_data(os.path.join(args.out_path, "he_context_server.txt"), context_serial)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/convert_survival_analysis_to_fl.ipynb b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/convert_survival_analysis_to_fl.ipynb
index f53ad94e0f..3626b93cc5 100644
--- a/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/convert_survival_analysis_to_fl.ipynb
+++ b/examples/tutorials/self-paced-training/part-1_federated_learning_introduction/chapter-2_develop_federated_learning_applications/02.3_convert_machine_learning_to_federated_learning/02.3.3_convert_survival_analysis_to_federated_learning/convert_survival_analysis_to_fl.ipynb
@@ -1,19 +1,279 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "d40828dd",
+   "metadata": {},
+   "source": [
+    "# Secure Federated Kaplan-Meier Analysis via Time-Binning and Homomorphic Encryption"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c0937cf5",
+   "metadata": {},
+   "source": [
+    "This example illustrates two features:\n",
+    "* How to perform Kaplan-Meier survival analysis in a federated setting, without and with secure features via time-binning and Homomorphic Encryption (HE).\n",
+    "* How to use the FLARE ModelController API to construct a workflow that facilitates HE under simulator mode."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "da8644ba",
+   "metadata": {},
+   "source": [
+    "## Basics of Kaplan-Meier Analysis\n",
+    "Kaplan-Meier survival analysis is a non-parametric statistic used to estimate the survival function from lifetime data. It is used to analyze the time it takes for an event of interest to occur. For example, during a clinical trial, the Kaplan-Meier estimator can be used to estimate the proportion of patients who survive a certain amount of time after treatment.\n",
+    "\n",
+    "The Kaplan-Meier estimator takes into account the time of the event (e.g. \"Survival Days\") and whether the event was observed or censored. An event is observed if the event of interest (e.g. \"death\") occurred at the end of the observation process. An event is censored if the event of interest did not occur (i.e. the patient was still alive) at the end of the observation process.\n",
+    "\n",
+    "The example dataset used here for Kaplan-Meier analysis is the `veterans_lung_cancer` dataset, which contains information about the survival time of veterans with advanced lung cancer. Below we provide some samples of the dataset:\n",
+    "\n",
+    "| ID | Age | Celltype | Karnofsky | Diagtime | Prior | Treat | Status | Survival Days |\n",
+    "|----|-----|------------|------------|----------|-------|-----------|--------|---------------|\n",
+    "| 1 | 64 | squamous | 70 | 5 | yes | standard | TRUE | 411 |\n",
+    "| 20 | 55 | smallcell | 40 | 3 | no | standard | FALSE | 123 |\n",
+    "| 45 | 61 | adeno | 20 | 19 | yes | standard | TRUE | 8 |\n",
+    "| 63 | 62 | large | 90 | 2 | no | standard | FALSE | 182 |\n",
+    "\n",
+    "To perform the analysis, we need two pieces of information from this data:\n",
+    "- Time `Survival Days`: days elapsed from the beginning of the observation to its end\n",
+    "- Event `Status`: whether the event of interest (i.e. death) had happened by the end of the observation\n",
+    "\n",
+    "Based on the above, we can interpret the data as follows:\n",
+    "- Patient #1 goes through an observation period of 411 days, and passes away at Day 411\n",
+    "- Patient #20 goes through an observation period of 123 days, and is still alive when the observation stops at Day 123\n",
+    "\n",
+    "The purpose of Kaplan-Meier analysis is to estimate the survival function, which is the probability that a patient survives beyond a certain time. Naturally, it is a monotonically decreasing function, since the probability of surviving beyond a given time can only decrease as time goes by."
+   ]
+  },
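+  {
+   "cell_type": "markdown",
+   "id": "b3e1f0a2",
+   "metadata": {},
+   "source": [
+    "Before moving to the federated setting, the cell below is a minimal sketch of a local Kaplan-Meier fit with `lifelines`. The event list here is small and hypothetical, purely for illustration:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c4f2a1b3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal local Kaplan-Meier fit (illustrative sketch with hypothetical data)\n",
+    "import numpy as np\n",
+    "from lifelines import KaplanMeierFitter\n",
+    "\n",
+    "durations = np.array([411, 123, 8, 182, 20, 283])  # survival days\n",
+    "events = np.array([True, False, True, False, True, True])  # True = death observed, False = censored\n",
+    "\n",
+    "kmf = KaplanMeierFitter()\n",
+    "kmf.fit(durations=durations, event_observed=events)\n",
+    "print(kmf.survival_function_)  # estimated probability of surviving beyond each time point"
+   ]
+  },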
+  {
+   "cell_type": "markdown",
+   "id": "06986478",
+   "metadata": {},
+   "source": [
+    "## Secure Multi-party Kaplan-Meier Analysis\n",
+    "As described above, Kaplan-Meier survival analysis is a one-shot (non-iterative) analysis performed on a list of events (`Status`) and their corresponding times (`Survival Days`). In this example, we use [lifelines](https://zenodo.org/records/10456828) to perform this analysis.\n",
+    "\n",
+    "Essentially, the estimator needs access to this event list; under the setting of federated analysis, that means the aggregated event list from all participants.\n",
+    "\n",
+    "However, this poses a data security concern: the event list is equivalent to the raw data, and exposing it to external parties essentially breaks the core value of federated analysis.\n",
+    "\n",
+    "Therefore, we would like to design a secure mechanism that enables collaborative Kaplan-Meier analysis without the risk of exposing raw information from any participant. The targeted protections are:\n",
+    "- Prevent clients from getting RAW data from each other;\n",
+    "- Prevent the aggregation server from accessing ANY information in participants' submissions.\n",
+    "\n",
+    "This is achieved by two techniques:\n",
+    "- Condense the raw event list into two histograms (one for observed events and the other for censored events) using binning at a certain interval (e.g. a week)\n",
+    "- Perform the aggregation of the histograms using Homomorphic Encryption (HE)\n",
+    "\n",
+    "With time-binning at a weekly interval, the above event list is converted to histograms as follows (the binning takes the ceiling of time / 7):\n",
+    "- Patient #1 contributes 1 to the 59th bin of the observed-event histogram\n",
+    "- Patient #20 contributes 1 to the 18th bin of the censored-event histogram\n",
+    "\n",
+    "In this way, events falling into the same bin at different participants are aggregated together and are no longer distinguishable in the final aggregated histograms. Note that coarser binning leads to higher protection, but also lower resolution of the final Kaplan-Meier curve.\n",
+    "\n",
+    "Each local histogram is then encrypted as a single vector before being sent to the server, and the global aggregation at the server side is performed entirely in encrypted space with HE. This causes no information loss, while the server never gains access to any plain-text information.\n",
+    "\n",
+    "With these two mechanisms, the server has no access to any knowledge of local submissions, and participants only receive globally aggregated histograms that carry no distinguishable information about any individual participant (assuming at least 3 clients; with only two participants, each could recover the other's histograms by subtracting its own from the global result).\n",
+    "\n",
+    "The final Kaplan-Meier survival analysis is performed locally on the globally aggregated event list, recovered from the decrypted global histograms."
+   ]
+  },
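+  {
+   "cell_type": "markdown",
+   "id": "d5a3b2c4",
+   "metadata": {},
+   "source": [
+    "The cell below is a tiny standalone sketch of the encrypted-aggregation idea with TenSEAL's BFV scheme: two toy histograms are encrypted, added while still encrypted, and only then decrypted. The parameters and histogram values are illustrative, not the ones provisioned by this example's scripts:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e6b4c3d5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Toy sketch of HE histogram aggregation with the BFV scheme (illustrative parameters)\n",
+    "import tenseal as ts\n",
+    "\n",
+    "context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193)\n",
+    "hist_site1 = [3, 0, 2, 1]  # toy observed-event histogram from one site\n",
+    "hist_site2 = [1, 2, 0, 4]  # toy histogram from another site\n",
+    "\n",
+    "enc1 = ts.bfv_vector(context, hist_site1)\n",
+    "enc2 = ts.bfv_vector(context, hist_site2)\n",
+    "enc_sum = enc1 + enc2  # the addition happens entirely in encrypted space\n",
+    "print(enc_sum.decrypt())  # expected: [4, 2, 2, 5]"
+   ]
+  },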
+  {
+   "cell_type": "markdown",
+   "id": "f75beeb3",
+   "metadata": {},
+   "source": [
+    "## Install requirements\n",
+    "Make sure to install the required packages:"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "348e87c0-2e9f-4852-9d6a-1a9db5cb5dde",
+   "id": "56133db2",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "%pip install -r code/requirements.txt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d4b57b15",
+   "metadata": {},
+   "source": [
+    "## Baseline Kaplan-Meier Analysis\n",
+    "We first illustrate the baseline centralized Kaplan-Meier analysis without any secure features. We use the `veterans_lung_cancer` dataset, loaded via\n",
+    "`from sksurv.datasets import load_veterans_lung_cancer`, with `Status` as the event type and `Survival_in_days` as the event time to construct the event list.\n",
+    "\n",
+    "To run the baseline script, simply execute:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "41206a7d",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "! python3 code/utils/baseline_kaplan_meier.py"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "31ab94be",
+   "metadata": {},
+   "source": [
+    "By default, this generates a KM curve image `km_curve_baseline.png` under the `/tmp/nvflare/baseline` directory. The resulting KM curve is shown below:\n",
+    "\n",
+    "![KM survival baseline](code/figs/km_curve_baseline.png)\n",
+    "\n",
+    "Here, we show the survival curve for both daily (without binning) and weekly binning. The two curves align well with each other, while the weekly-binned curve has lower resolution."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a42f69c0",
+   "metadata": {},
+   "source": [
+    "## Federated Kaplan-Meier Analysis without and with Homomorphic Encryption\n",
+    "We make use of the FLARE ModelController API to implement the federated Kaplan-Meier analysis, both without and with HE.\n",
+    "\n",
+    "The FLARE ModelController API (`ModelController`) provides flexible FLModel payloads for each round of federated analysis. This gives us the flexibility to transmit the various pieces of information needed by our scheme at different stages of the federated job.\n",
+    "\n",
+    "Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) use a data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under the [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate the ModelController's capability to support customized needs beyond the existing HE functionality (designed mainly for encrypting deep learning models):\n",
+    "- a different HE scheme (BFV) rather than CKKS\n",
+    "- different content at different rounds of federated learning, where only specific payloads need to be encrypted\n",
+    "\n",
+    "With the ModelController API, such experiments become easy. In this example, the federated analysis pipeline includes 2 rounds without HE, or 3 rounds with HE.\n",
+    "\n",
+    "For the federated analysis without HE, the detailed steps are as follows:\n",
+    "1. The server sends a simple start message without any payload.\n",
+    "2. Clients submit their local event histograms to the server. The server aggregates the histograms of varying lengths by adding the event counts of matching slots together, and sends the aggregated histograms back to the clients.\n",
+    "\n",
+    "For the federated analysis with HE, we need to ensure proper HE aggregation using BFV, and the detailed steps are as follows:\n",
+    "1. The server sends a simple start message without any payload.\n",
+    "2. Clients send their local maximum bin index (for event time) to the server, and the server aggregates this information by selecting the maximum among all clients. The global maximum is then distributed back to the clients. This step is necessary to standardize the histograms generated by all clients, so that they have exactly the same length and can be encrypted as vectors of the same size, which can then be added together.\n",
+    "3. Clients condense their local raw event lists into two histograms of the received global length, encrypt the histogram value vectors, and send them to the server. The server aggregates the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to the clients.\n",
+    "\n",
+    "After these rounds, the federated work is complete. Each client then decrypts the aggregated histograms, converts them back to an event list, and performs the Kaplan-Meier analysis on the global information. The core local step of condensing an event list into fixed-length histograms is sketched right below."
+   ]
+  },
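+  {
+   "cell_type": "markdown",
+   "id": "f7c5d4e6",
+   "metadata": {},
+   "source": [
+    "The sketch below mirrors that client-side condensing step, using the same `lifelines` helper as the client code. The event times, event flags, and the agreed global maximum index are toy values for illustration:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a8d6e5f7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: condense a local event list into fixed-length observed/censored histograms (toy values)\n",
+    "import numpy as np\n",
+    "from lifelines.utils import survival_table_from_events\n",
+    "\n",
+    "time_local = np.array([2, 2, 3, 5, 7])  # binned event times\n",
+    "event_local = np.array([True, False, True, True, False])  # observed vs censored\n",
+    "max_idx_global = 10  # global maximum bin index agreed across sites\n",
+    "\n",
+    "table = survival_table_from_events(time_local, event_local)\n",
+    "hist_obs = {idx: 0 for idx in range(max_idx_global)}\n",
+    "hist_cen = {idx: 0 for idx in range(max_idx_global)}\n",
+    "for t, obs, cen in zip(\n",
+    "    table.index.values.astype(int), table[\"observed\"].to_numpy(), table[\"censored\"].to_numpy()\n",
+    "):\n",
+    "    hist_obs[t] = obs\n",
+    "    hist_cen[t] = cen\n",
+    "print(list(hist_obs.values()))\n",
+    "print(list(hist_cen.values()))"
+   ]
+  },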
+  {
+   "cell_type": "markdown",
+   "id": "302c4285",
+   "metadata": {},
+   "source": [
+    "## Run the job\n",
+    "First, we prepare data for a 5-client federated job. We split and generate the data files for each client with a binning interval of 7 days."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8a354d0d",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "! python3 code/utils/prepare_data.py --site_num 5 --bin_days 7 --out_path \"/tmp/nvflare/dataset/km_data\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "40d6fa4e",
+   "metadata": {},
+   "source": [
+    "Then, we prepare the HE context for the clients and the server. Note that in real-life applications this step is done by secure provisioning, but for this experiment with the BFV scheme, we use this script to distribute the HE context."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b12b162d",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "! python3 code/utils/prepare_he_context.py --out_path \"/tmp/nvflare/he_context\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7cc4d792",
+   "metadata": {},
+   "source": [
+    "Next, we run the federated job using the NVFlare Simulator via the [JobAPI](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html), both without and with HE:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a4c91649",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "! python3 code/km_job.py"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0c24c50a",
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "! python3 code/km_job.py --encryption"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e31897b5",
+   "metadata": {},
+   "source": [
+    "By default, this generates the KM curve images `km_curve_fl.png` and `km_curve_fl_he.png` under each client's working directory."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e12cde9e",
+   "metadata": {},
+   "source": [
+    "## Display Result\n",
+    "\n",
+    "Comparing the two curves, with and without HE, we can observe that they are identical:\n",
+    "\n",
+    "![KM survival fl](code/figs/km_curve_fl.png)\n",
+    "![KM survival fl_he](code/figs/km_curve_fl_he.png)\n"
+   ]
+  }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "nvflare_example",
+   "display_name": "Python 3",
    "language": "python",
-   "name": "nvflare_example"
+   "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
@@ -25,7 +285,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.2"
+   "version": "3.10.12"
   }
  },
 "nbformat": 4,