Add existing code for kmeans #3210

Merged
merged 6 commits into from Feb 13, 2025
12 changes: 2 additions & 10 deletions examples/advanced/sklearn-kmeans/sklearn_kmeans_iris.ipynb
@@ -103,7 +103,7 @@
 "id": "bd0713e2-e393-41c0-9da0-392535cf8a54",
 "metadata": {},
 "source": [
-"## 4. Run simulated kmeans experiment\n",
+"## 3. Run simulated kmeans experiment\n",
 "We run the federated training using NVFlare Simulator via [JobAPI](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html):"
 ]
 },
@@ -124,7 +124,7 @@
 "id": "913e9ee2-e993-442d-a525-d2baf92af539",
 "metadata": {},
 "source": [
-"## 5. Result visualization\n",
+"## 4. Result visualization\n",
 "Model accuracy is computed as the homogeneity score between the clusters formed and the ground-truth labels, and can be visualized in TensorBoard."
 ]
 },
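(Aside, not part of the diff: a minimal sketch of the homogeneity score referenced above, using scikit-learn's homogeneity_score; the label arrays are made up for illustration.)

from sklearn.metrics import homogeneity_score

# Homogeneity is 1.0 when every cluster contains members of a single class only;
# it is invariant to how the clusters happen to be numbered.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [1, 1, 0, 0, 2, 2]  # relabeled, but perfectly homogeneous
print(homogeneity_score(y_true, y_pred))  # -> 1.0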
@@ -140,14 +140,6 @@
 "%load_ext tensorboard\n",
 "%tensorboard --logdir /tmp/nvflare/workspace/works/kmeans/sklearn_kmeans_uniform_3_clients"
 ]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"id": "bea9ebcd-96f5-45c8-a490-0559fab9991f",
-"metadata": {},
-"outputs": [],
-"source": []
 }
],
"metadata": {
@@ -0,0 +1,240 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from enum import Enum
from typing import List

import numpy as np
from src.kmeans_assembler import KMeansAssembler
from src.kmeans_learner import KMeansLearner

from nvflare import FedJob
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor
from nvflare.app_opt.sklearn.sklearn_executor import SKLearnExecutor


class SplitMethod(Enum):
    UNIFORM = "uniform"
    LINEAR = "linear"
    SQUARE = "square"
    EXPONENTIAL = "exponential"


def get_split_ratios(site_num: int, split_method: SplitMethod):
    if split_method == SplitMethod.UNIFORM:
        ratio_vec = np.ones(site_num)
    elif split_method == SplitMethod.LINEAR:
        ratio_vec = np.linspace(1, site_num, num=site_num)
    elif split_method == SplitMethod.SQUARE:
        ratio_vec = np.square(np.linspace(1, site_num, num=site_num))
    elif split_method == SplitMethod.EXPONENTIAL:
        ratio_vec = np.exp(np.linspace(1, site_num, num=site_num))
    else:
        raise ValueError(f"Split method {split_method.name} not implemented!")

    return ratio_vec
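# For example, with site_num=3: UNIFORM -> [1, 1, 1], LINEAR -> [1, 2, 3],
# SQUARE -> [1, 4, 9], EXPONENTIAL -> [e^1, e^2, e^3]; callers normalize by the sum.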


def split_num_proportion(n, site_num, split_method: SplitMethod) -> List[int]:
    split = []
    ratio_vec = get_split_ratios(site_num, split_method)
    total = sum(ratio_vec)
    left = n
    for site in range(site_num - 1):
        x = int(n * ratio_vec[site] / total)
        left = left - x
        split.append(x)
    split.append(left)
    return split
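# For example, split_num_proportion(10, 3, SplitMethod.UNIFORM) returns [3, 3, 4]:
# the first two sites each get int(10 / 3) = 3 samples and the last site absorbs the remainder.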


def assign_data_index_to_sites(
    data_size: int,
    valid_fraction: float,
    num_sites: int,
    split_method: SplitMethod = SplitMethod.UNIFORM,
) -> dict:
    if valid_fraction > 1.0:
        raise ValueError("validation percent should be less than or equal to 100% of the total data")
    elif valid_fraction < 1.0:
        valid_size = int(round(data_size * valid_fraction, 0))
        train_size = data_size - valid_size
    else:
        valid_size = data_size
        train_size = data_size

    site_sizes = split_num_proportion(train_size, num_sites, split_method)
    split_data_indices = {
        "valid": {"start": 0, "end": valid_size},
    }
    for site in range(num_sites):
        site_id = site + 1
        if valid_fraction < 1.0:
            idx_start = valid_size + sum(site_sizes[:site])
            idx_end = valid_size + sum(site_sizes[: site + 1])
        else:
            idx_start = sum(site_sizes[:site])
            idx_end = sum(site_sizes[: site + 1])
        split_data_indices[site_id] = {"start": idx_start, "end": idx_end}

    return split_data_indices
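# For example, assign_data_index_to_sites(10, 0.2, 3) reserves rows [0, 2) for validation
# and yields train ranges {1: [2, 4), 2: [4, 6), 3: [6, 10)} for the three sites.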


def get_file_line_count(input_path: str) -> int:
    count = 0
    with open(input_path, "r") as fp:
        for i, _ in enumerate(fp):
            count += 1
    return count


def split_data(
    data_path: str,
    num_clients: int,
    valid_frac: float,
    split_method: SplitMethod = SplitMethod.UNIFORM,
):
    size_total_file = get_file_line_count(data_path)
    site_indices = assign_data_index_to_sites(size_total_file, valid_frac, num_clients, split_method)
    return site_indices


def define_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--workspace_dir",
        type=str,
        default="/tmp/nvflare/workspace/works/kmeans",
        help="work directory, default to '/tmp/nvflare/workspace/works/kmeans'",
    )
    parser.add_argument(
        "--job_dir",
        type=str,
        default="/tmp/nvflare/workspace/jobs/kmeans",
        help="directory for job export, default to '/tmp/nvflare/workspace/jobs/kmeans'",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="/tmp/nvflare/dataset/sklearn_iris.csv",
        help="data file path, default to '/tmp/nvflare/dataset/sklearn_iris.csv'",
    )
    parser.add_argument(
        "--num_clients",
        type=int,
        default=3,
        help="number of clients to simulate, default to 3",
    )
    parser.add_argument(
        "--num_rounds",
        type=int,
        default=5,
        help="number of rounds, default to 5",
    )
    parser.add_argument(
        "--split_mode",
        type=str,
        default="uniform",
        choices=["uniform", "linear", "square", "exponential"],
        help="how to split data among clients",
    )
    parser.add_argument(
        "--valid_frac",
        type=float,
        default=1,
        help="fraction of data to use for validation; default performs validation on all data",
    )
    return parser.parse_args()


def main():
    args = define_parser()
    # Get args
    data_path = args.data_path
    num_clients = args.num_clients
    num_rounds = args.num_rounds
    split_mode = args.split_mode
    valid_frac = args.valid_frac
    job_name = f"sklearn_kmeans_{split_mode}_{num_clients}_clients"

    # Set the output workspace and job directories
    workspace_dir = os.path.join(args.workspace_dir, job_name)
    job_dir = args.job_dir

    # Create the FedJob
    job = FedJob(name=job_name, min_clients=num_clients)

    # Define the controller workflow and send to server
    controller = ScatterAndGather(
        min_clients=num_clients,
        num_rounds=num_rounds,
        aggregator_id="aggregator",
        persistor_id="persistor",
        shareable_generator_id="shareable_generator",
        train_task_name="train",
    )
    job.to_server(controller, id="scatter_and_gather")

    # Define other server components
    assembler = KMeansAssembler()
    job.to_server(assembler, id="kmeans_assembler")
    aggregator = CollectAndAssembleAggregator(assembler_id="kmeans_assembler")
    job.to_server(aggregator, id="aggregator")
    shareable_generator = FullModelShareableGenerator()
    job.to_server(shareable_generator, id="shareable_generator")
    persistor = JoblibModelParamPersistor(
        initial_params={"n_clusters": 3},
    )
    job.to_server(persistor, id="persistor")

    # Generate the data split indices to assign to each client
    site_indices = split_data(
        data_path,
        num_clients,
        valid_frac,
        SplitMethod(split_mode),
    )

    for i in range(1, num_clients + 1):
        # Define the executor and send to clients
        runner = SKLearnExecutor(learner_id="kmeans_learner")
        job.to(runner, f"site-{i}", tasks=["train"])

        learner = KMeansLearner(
            data_path=data_path,
            train_start=site_indices[i]["start"],
            train_end=site_indices[i]["end"],
            valid_start=site_indices["valid"]["start"],
            valid_end=site_indices["valid"]["end"],
            random_state=0,
        )
        job.to(learner, f"site-{i}", id="kmeans_learner")

    # Export the job
    print("job_dir=", job_dir)
    job.export_job(job_dir)

    # Run the job
    print("workspace_dir=", workspace_dir)
    job.simulator_run(workspace_dir)


if __name__ == "__main__":
    main()
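(Aside, an illustrative sanity check on the index math above, not part of the PR: with the 150-sample Iris dataset, 3 clients, a uniform split, and the default valid_frac=1, validation covers all rows and the training rows are split evenly.)

indices = assign_data_index_to_sites(data_size=150, valid_fraction=1.0, num_sites=3)
# -> {"valid": {"start": 0, "end": 150},
#     1: {"start": 0, "end": 50},
#     2: {"start": 50, "end": 100},
#     3: {"start": 100, "end": 150}}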
@@ -0,0 +1,4 @@
pandas
scikit-learn
joblib
tensorboard
@@ -0,0 +1,75 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import numpy as np
from sklearn.cluster import KMeans

from nvflare.apis.dxo import DXO, DataKind
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.aggregators.assembler import Assembler
from nvflare.app_common.app_constant import AppConstants


class KMeansAssembler(Assembler):
    def __init__(self):
        super().__init__(data_kind=DataKind.WEIGHTS)
        # The aggregator needs to keep a record of the historical
        # center and count information for mini-batch kmeans
        self.center = None
        self.count = None
        self.n_cluster = 0

    def get_model_params(self, dxo: DXO):
        data = dxo.data
        return {"center": data["center"], "count": data["count"]}

    def assemble(self, data: Dict[str, dict], fl_ctx: FLContext) -> DXO:
        current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        if current_round == 0:
            # First round: collect the information regarding n_feature and n_cluster,
            # and initialize the aggregated center and count to all zero
            client_0 = list(self.collection.keys())[0]
            self.n_cluster = self.collection[client_0]["center"].shape[0]
            n_feature = self.collection[client_0]["center"].shape[1]
            self.center = np.zeros([self.n_cluster, n_feature])
            self.count = np.zeros([self.n_cluster])
            # Perform one round of KMeans over the submitted centers
            # to be used as the initial center points;
            # no count for this round
            center_collect = []
            for _, record in self.collection.items():
                center_collect.append(record["center"])
            centers = np.concatenate(center_collect)
            kmeans_center_initial = KMeans(n_clusters=self.n_cluster)
            kmeans_center_initial.fit(centers)
            self.center = kmeans_center_initial.cluster_centers_
        else:
            # Mini-batch k-means step to assemble the received centers
            for center_idx in range(self.n_cluster):
                centers_global_rescale = self.center[center_idx] * self.count[center_idx]
                # Aggregate centers: add each new center to the previous estimate, weighted by counts
                for _, record in self.collection.items():
                    centers_global_rescale += record["center"][center_idx] * record["count"][center_idx]
                    self.count[center_idx] += record["count"][center_idx]
                # Rescale to compute the mean of all points (old and new combined)
                alpha = 1 / self.count[center_idx]
                centers_global_rescale *= alpha
                # Update the global center
                self.center[center_idx] = centers_global_rescale
        params = {"center": self.center}
        dxo = DXO(data_kind=self.expected_data_kind, data=params)

        return dxo
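(Aside, not part of the diff: the else branch above is a running weighted mean. A self-contained numpy sketch of the same arithmetic for a single center, with made-up centers and counts.)

import numpy as np

# Previous global estimate: center c backed by n points.
c, n = np.array([0.0, 0.0]), 10
# Two clients report local centers and counts for this cluster index.
updates = [(np.array([3.0, 0.0]), 5), (np.array([0.0, 3.0]), 5)]

total = c * n  # un-average the old center back into a weighted sum
for c_k, n_k in updates:
    total += c_k * n_k
    n += n_k
c = total / n  # mean over old and new points combined
print(c, n)  # -> [0.75 0.75] 20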