diff --git a/src/learning_ff/__init__.py b/src/learning_ff/__init__.py index b0cabcc..0f64023 100644 --- a/src/learning_ff/__init__.py +++ b/src/learning_ff/__init__.py @@ -1,3 +1,4 @@ from .least_squares import dp_least_squares from .prototypes import give_private_prototypes -__all__ = ["dp_least_squares", "give_private_prototypes"] \ No newline at end of file + +__all__ = ["dp_least_squares", "give_private_prototypes"] diff --git a/src/learning_ff/least_squares.py b/src/learning_ff/least_squares.py index 2244b7f..4199e54 100644 --- a/src/learning_ff/least_squares.py +++ b/src/learning_ff/least_squares.py @@ -33,7 +33,7 @@ def dp_covariance( cov = X_clip.T @ X_clip # Add Gaussian noise to the matrix cov += rng.normal( - scale=clipping_norm**2 * noise_multiplier * np.sqrt(k_classes), size=(d,d) + scale=clipping_norm**2 * noise_multiplier * np.sqrt(k_classes), size=(d, d) ) return cov @@ -65,11 +65,11 @@ def dp_least_squares( ): """Build and solve the differentially private least squares problem. Algorithm attempts to follow the description (Algorithm 3) in: - - Mehta, H., Krichene, W., Thakurta, A., Kurakin, A., & Cutkosky, A. (2022). - Differentially private image classification from features. + + Mehta, H., Krichene, W., Thakurta, A., Kurakin, A., & Cutkosky, A. (2022). + Differentially private image classification from features. arXiv preprint arXiv:2211.13403. - + Args: A: (n, d) matrix of features y: (n,) vector of labels @@ -95,7 +95,7 @@ def dp_least_squares( A_clip, clipping_norm, noise_multiplier, rng, k_classes=1 ) # k_classes is always 1 for global G targets = np.unique(y) - if k_classes == None: + if k_classes is None: k_classes = np.ones(n, dtype=int) else: assert len(k_classes) == n # each sample has a number of positive classes @@ -116,4 +116,3 @@ def dp_least_squares( thetas.append(theta_class) return np.asarray(thetas) - diff --git a/src/learning_ff/prototypes.py b/src/learning_ff/prototypes.py index 07ef05c..6972f23 100644 --- a/src/learning_ff/prototypes.py +++ b/src/learning_ff/prototypes.py @@ -1,5 +1,3 @@ -# Pathed Imports -import sys from abc import ABC, abstractmethod from dataclasses import dataclass