Skip to content

Commit

Permalink
feat: support unbalanced classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Dariush Wahdany committed Dec 13, 2023
1 parent 900c495 commit aaa3070
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ num_samples = 100
dimensionality = 10
num_classes = 5
train_preds = np.random.normal(0,1,(num_samples, dimensionality))
train_targets = np.resize(np.arange(num_classes), num_samples)
train_targets = np.random.randint(num_classes, size=(num_samples)) # Supports unbalanced classes
private_prototypes = dp_learning_ff.give_private_prototypes(train_preds, train_targets, Ps)
>>> private_prototypes.shape
(5, 10)
Expand Down
15 changes: 8 additions & 7 deletions src/dp_learning_ff/prototypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,16 @@ def give_private_prototypes(
np.ndarray: (k, d)-array containing the private prototypes for each class.
"""
targets = np.unique(train_targets)
train_preds_sorted = np.stack(
[train_preds[train_targets == target] for target in targets]
).copy()
train_preds_sorted = [
train_preds[train_targets == target].copy() for target in targets
]
if subsampling < 1.0:
rng = np.random.default_rng(seed)
rng.shuffle(train_preds_sorted, axis=1)
train_preds_sorted = train_preds_sorted[
:, : int(subsampling * train_preds_sorted.shape[1])
]
subsampled = []
for M_x in train_preds_sorted:
rng.shuffle(M_x, axis=0)
subsampled.append(M_x[: int(subsampling * M_x.shape[0])])
train_preds_sorted = subsampled
protos = np.asarray(
[private_mean(train_preds_sorted[i], Ps) for i, target in enumerate(targets)]
)
Expand Down

0 comments on commit aaa3070

Please sign in to comment.