Towards making the interface of ghost clipping same as that of PyTorch (#668)

Summary:
Pull Request resolved: #668

We define two classes, DPLossFastGradientClipping and DPTensorFastGradientClipping, in the utils file, which allow us to repurpose loss.backward() to perform the two backward passes and the loss scaling required to implement ghost clipping.
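
With this change, ghost clipping can be driven like a standard PyTorch training loop: when grad_sample_mode="ghost", make_private() also wraps and returns the criterion, and calling backward() on the loss it produces runs both backward passes and the rescaling internally. A minimal sketch of the intended usage (the model, data, and hyperparameter values below are illustrative assumptions, not part of this commit):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Linear(16, 2)  # stand-in model
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
data_loader = DataLoader(dataset, batch_size=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

privacy_engine = PrivacyEngine()
# With grad_sample_mode="ghost", make_private also returns the wrapped criterion.
model, optimizer, criterion, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    criterion=criterion,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="ghost",
)

for x, y in data_loader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)  # a DPTensorFastGradientClipping
    loss.backward()                # two backward passes + loss rescaling happen here
    optimizer.step()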

Reviewed By: HuanyuZhang

Differential Revision: D61162530

fbshipit-source-id: 9b832469e1645513a13e1c962a13500169a3806b
EnayatUllah authored and facebook-github-bot committed Aug 29, 2024
1 parent 27e6a1d commit 36f7a34
Showing 6 changed files with 141 additions and 42 deletions.
25 changes: 17 additions & 8 deletions opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -47,12 +47,21 @@ def create_norm_sample(
"""

if param.requires_grad:
param._norm_sample = torch.zeros(
torch.Size([max_batch_len, 1]),
device=grad_sample.device,
dtype=grad_sample.dtype,
)
param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1)
if (
max_batch_len == 0
): # To handle the case of empty batch that may arise from Poisson sampling
param._norm_sample = torch.tensor(
[], device=grad_sample.device, dtype=grad_sample.dtype
)
else:
param._norm_sample = torch.zeros(
torch.Size([max_batch_len, 1]),
device=grad_sample.device,
dtype=grad_sample.dtype,
)
param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(
2, dim=-1
)


class GradSampleModuleFastGradientClipping(GradSampleModule):
@@ -110,7 +119,7 @@ def __init__(
self.max_grad_norm = max_grad_norm
self.use_ghost_clipping = use_ghost_clipping

def get_coeff(self) -> torch.Tensor:
def get_clipping_coef(self) -> torch.Tensor:
"""Get per-example gradient scaling factor for clipping."""
norm_sample = self.get_norm_sample()
return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
@@ -175,6 +184,7 @@ def capture_backprops_hook(
return

backprops = forward_output[0].detach()

activations, backprops = self.rearrange_grad_samples(
module=module,
backprops=backprops,
@@ -216,7 +226,6 @@ def capture_backprops_hook(
max_batch_len=module.max_batch_len,
)
del p.grad_sample

if len(module.activations) == 0:
if hasattr(module, "max_batch_len"):
del module.max_batch_len
15 changes: 10 additions & 5 deletions opacus/optimizers/__init__.py
@@ -39,12 +39,17 @@


def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None):
if clipping == "flat" and distributed is False:
if grad_sample_mode == "ghost":
if clipping == "flat" and distributed is False:
return DPOptimizerFastGradientClipping
elif clipping == "flat" and distributed is True:
return DistributedDPOptimizerFastGradientClipping
else:
raise ValueError(
f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
)
elif clipping == "flat" and distributed is False:
return DPOptimizer
elif clipping == "ghost" and distributed is False:
return DPOptimizerFastGradientClipping
elif clipping == "ghost" and distributed is True:
return DistributedDPOptimizerFastGradientClipping
elif clipping == "flat" and distributed is True:
return DistributedDPOptimizer
elif clipping == "per_layer" and distributed is False:
9 changes: 9 additions & 0 deletions opacus/privacy_engine.py
@@ -30,6 +30,7 @@
)
from opacus.optimizers import DPOptimizer, get_optimizer_class
from opacus.schedulers import _GradClipScheduler, _NoiseScheduler
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping
from opacus.validators.module_validator import ModuleValidator
from torch import nn, optim
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -277,6 +278,7 @@ def make_private(
*,
module: nn.Module,
optimizer: optim.Optimizer,
criterion=nn.CrossEntropyLoss(), # Added default for backward compatibility
data_loader: DataLoader,
noise_multiplier: float,
max_grad_norm: Union[float, List[float]],
@@ -400,6 +402,11 @@ def make_private(
optimizer.attach_step_hook(
self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
)
if grad_sample_mode == "ghost":
criterion = DPLossFastGradientClipping(
module, optimizer, criterion, loss_reduction
)
return module, optimizer, criterion, data_loader

return module, optimizer, data_loader

@@ -408,6 +415,7 @@ def make_private_with_epsilon(
*,
module: nn.Module,
optimizer: optim.Optimizer,
criterion=nn.CrossEntropyLoss(), # Added default for backward compatibility
data_loader: DataLoader,
target_epsilon: float,
target_delta: float,
@@ -487,6 +495,7 @@ def make_private_with_epsilon(
module=module,
optimizer=optimizer,
data_loader=data_loader,
criterion=criterion,
noise_multiplier=get_noise_multiplier(
target_epsilon=target_epsilon,
target_delta=target_delta,
21 changes: 13 additions & 8 deletions opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
@@ -22,7 +22,7 @@
from hypothesis import given, settings
from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import double_backward
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping
from opacus.utils.per_sample_gradients_utils import clone_module
from torch.utils.data import DataLoader, Dataset

@@ -146,7 +146,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
(input_data, target_data) = list(self.dl)[0]
optimizer_normal.zero_grad()
output_normal = self.model_normal(input_data)
loss_normal = torch.mean(self.criterion(output_normal, target_data))
loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
loss_normal.backward()
all_norms_normal = torch.stack(
[
@@ -165,7 +165,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
first_loss.backward(retain_graph=True)

optimizer_gc.zero_grad()
coeff = self.grad_sample_module.get_coeff()
coeff = self.grad_sample_module.get_clipping_coef()
second_loss_per_sample = coeff * first_loss_per_sample
second_loss = torch.sum(second_loss_per_sample)
self.grad_sample_module.disable_hooks()
@@ -190,7 +190,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
@settings(deadline=1000000)
def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
"""
Tests if gradients are the same between standard (opacus) and fast gradient clipping, using double_backward function"
Tests if gradients are the same between standard (opacus) and fast gradient clipping"
"""

noise_multiplier = 0.0
@@ -200,7 +200,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
self.dim = dim
self.setUp_data_sequantial(self.size, self.length, self.dim)
max_grad_norm = 1.0
self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
self.criterion = torch.nn.CrossEntropyLoss()

sample_module = SampleModule()
self.model_normal = GradSampleModule(clone_module(sample_module))
@@ -226,10 +226,14 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
expected_batch_size=batch_size,
)

criterion_gc = DPLossFastGradientClipping(
self.grad_sample_module, optimizer_gc, self.criterion
)

(input_data, target_data) = list(self.dl)[0]
optimizer_normal.zero_grad()
output_normal = self.model_normal(input_data)
loss_normal = torch.mean(self.criterion(output_normal, target_data))
loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
loss_normal.backward()
optimizer_normal.step()

@@ -240,8 +244,9 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):

output_gc = self.grad_sample_module(input_data)

first_loss_per_sample = self.criterion(output_gc, target_data)
double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)
loss_gc = criterion_gc(output_gc, target_data)
loss_gc.backward()
# double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)

all_grads_gc = [param.grad for param in self.grad_sample_module.parameters()]
flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
2 changes: 1 addition & 1 deletion opacus/tests/multigpu_gradcheck.py
@@ -101,7 +101,7 @@ def run_ghost_clipping_test(
loss_per_sample = loss_fn(outputs, y)
torch.mean(loss_per_sample).backward(retain_graph=True)
optimizer.zero_grad()
rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
rescaled_loss_per_sample = ddp_model.get_clipping_coef() * loss_per_sample
rescaled_loss = torch.sum(rescaled_loss_per_sample)
ddp_model.disable_hooks()
rescaled_loss.backward()
111 changes: 91 additions & 20 deletions opacus/utils/fast_gradient_clipping_utils.py
@@ -20,28 +20,99 @@
from opacus.optimizers import DPOptimizerFastGradientClipping


def double_backward(
module: GradSampleModuleFastGradientClipping,
optimizer: DPOptimizerFastGradientClipping,
loss_per_sample: torch.Tensor,
) -> None:
class DPTensorFastGradientClipping:
"""
Packages the training loop for Fast Gradient and Ghost Clipping. It does the two backward passes, as well as the loss rescaling and hook operations in between.
This function also works with DistributedDPOptimizer.
Packages the training loop for Fast Gradient and Ghost Clipping into loss.backward().
"""

def __init__(
self,
module: GradSampleModuleFastGradientClipping,
optimizer: DPOptimizerFastGradientClipping,
loss_per_sample: torch.Tensor,
loss_reduction: str = "mean",
):
"""
Args:
module: the module to train
optimizer: the optimizer used to train the module
loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
"""

self.module = module
self.optimizer = optimizer
self.loss_per_sample = loss_per_sample
self.loss_reduction = loss_reduction

def item(self):
if self.loss_reduction == "mean":
return torch.mean(self.loss_per_sample).detach().item()
elif self.loss_reduction == "sum":
return torch.sum(self.loss_per_sample).detach().item()

Args:
module: The DP gradient sample module to train
optimizer: The DP optimizer used to train the module
loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
def backward(self):
"""
Repurposes loss.backward() to perform two backward passes, as well as the loss rescaling and hook operations in between
"""

Returns:
None
if self.loss_reduction == "mean":
reduced_loss = torch.mean(self.loss_per_sample, dim=0)
elif self.loss_reduction == "sum":
reduced_loss = torch.sum(self.loss_per_sample, dim=0)
else:
raise ValueError(
f"loss_reduction = {self.loss_reduction}. Only 'sum' and 'mean' losses are supported"
)
reduced_loss.backward(retain_graph=True)
self.optimizer.zero_grad()
coeff = self.module.get_clipping_coef()
second_loss_per_sample = coeff * self.loss_per_sample
second_loss = torch.sum(second_loss_per_sample)
self.module.disable_hooks()
second_loss.backward()
self.module.enable_hooks()


class DPLossFastGradientClipping:
"""
Wrapper on the loss function to be used with Fast Gradient and Ghost Clipping. It computes the per-sample loss, and wraps it in DPTensorFastGradientClipping.
"""

torch.mean(loss_per_sample).backward(retain_graph=True)
optimizer.zero_grad()
rescaled_loss_per_sample = module.get_coeff() * loss_per_sample
rescaled_loss = torch.sum(rescaled_loss_per_sample)
module.disable_hooks()
rescaled_loss.backward()
module.enable_hooks()
def __init__(
self,
module: GradSampleModuleFastGradientClipping,
optimizer: DPOptimizerFastGradientClipping,
criterion,
loss_reduction: str = "mean",
):
assert loss_reduction in [
"mean",
"sum",
], "loss_reduction should be either 'mean' or 'sum'"
assert (
loss_reduction
== criterion.reduction
== module.loss_reduction
== optimizer.loss_reduction
), "loss_reduction should be the same across GradSampleModule, Optimizer, Criterion, and loss_reduction"

self.optimizer = optimizer
self.module = module
self.criterion = criterion
self.loss_reduction = loss_reduction
self.criterion.reduction = "none"

def __call__(self, input, target) -> DPTensorFastGradientClipping:
"""
Redefining the forward function to compute per-sample loss and wrap it in DPTensorFastGradientClipping
"""

loss_per_sample = self.criterion(
input,
target,
)
return DPTensorFastGradientClipping(
self.module, self.optimizer, loss_per_sample, self.loss_reduction
)
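
For reference, the new wrapper can also be constructed directly, mirroring the updated test above; a minimal sketch in which the layer shape, batch size, and optimizer settings are illustrative assumptions:

import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

# Wrap the model and optimizer for fast gradient / ghost clipping.
model = GradSampleModuleFastGradientClipping(
    nn.Linear(16, 2), max_grad_norm=1.0, use_ghost_clipping=True
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(model.parameters(), lr=0.1),
    noise_multiplier=0.0,
    max_grad_norm=1.0,
    expected_batch_size=8,
)
criterion = DPLossFastGradientClipping(model, optimizer, nn.CrossEntropyLoss())

x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
loss = criterion(model(x), y)  # returns a DPTensorFastGradientClipping
loss.backward()                # two backward passes + per-sample rescaling
optimizer.step()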
