One backward function for Ghost Clipping (pytorch#661)
Summary:
Pull Request resolved: pytorch#661

Simplified the training loop for ghost clipping by using a single "double backward" function.

Differential Revision: D60427371
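
For illustration, the training step for fast gradient / ghost clipping now reduces to a single call. This is a minimal sketch with hypothetical names (module, optimizer, criterion, input_data, target_data); it assumes the model is already wrapped in GradSampleModuleFastGradientClipping, the optimizer in DPOptimizerFastGradientClipping, and that the criterion is constructed with reduction="none" so it returns one loss value per sample:

    from opacus.utils.fast_gradient_clipping_utils import double_backward

    output = module(input_data)                          # forward pass through the wrapped module
    loss_per_sample = criterion(output, target_data)     # per-sample losses, shape [batch_size]
    double_backward(module, optimizer, loss_per_sample)  # both backward passes, loss rescaling, hook toggling
    optimizer.step()                                     # apply the clipped (and noised) gradients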
EnayatUllah authored and facebook-github-bot committed Jul 31, 2024
1 parent f1d0e02 commit 3e1d14c
Showing 3 changed files with 52 additions and 14 deletions.
2 changes: 1 addition & 1 deletion opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py
@@ -24,7 +24,7 @@

class DistributedDPOptimizerFastGradientClipping(DPOptimizerFastGradientClipping):
"""
:class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with
:class:`opacus.optimizers.optimizer.DPOptimizer` compatible with
distributed data processing
"""

18 changes: 5 additions & 13 deletions opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
Expand Up @@ -23,6 +23,7 @@
from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping
from opacus.utils.per_sample_gradients_utils import clone_module
from opacus.utils.fast_gradient_clipping_utils import double_backward
from torch.utils.data import DataLoader, Dataset

from .grad_sample_module_test import GradSampleModuleTest, SampleConvNet
@@ -108,7 +109,7 @@ def setUp_data_sequantial(self, size, length, dim):
@settings(deadline=1000000)
def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
"""
Tests if norm calculation is same between standard (opacus) and fast gradient clipping"
Tests if norm calculation is the same between standard (opacus) and fast gradient clipping"
"""
self.length = length
self.size = size
@@ -189,7 +190,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
@settings(deadline=1000000)
def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
"""
Tests if gradients are same between standard (opacus) and fast gradient clipping"
Tests if gradients are the same between standard (opacus) and fast gradient clipping, using double_backward function"
"""

noise_multiplier = 0.0
@@ -237,19 +238,10 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
]
flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])

self.grad_sample_module.enable_hooks()
output_gc = self.grad_sample_module(input_data)

first_loss_per_sample = self.criterion(output_gc, target_data)
first_loss = torch.mean(first_loss_per_sample)
first_loss.backward(retain_graph=True)

optimizer_gc.zero_grad()
coeff = self.grad_sample_module.get_coeff()
second_loss_per_sample = coeff * first_loss_per_sample
second_loss = torch.sum(second_loss_per_sample)
self.grad_sample_module.disable_hooks()
second_loss.backward()
double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)

all_grads_gc = [param.grad for param in self.grad_sample_module.parameters()]
flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
@@ -261,5 +253,5 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
]
)
logging.info(f"Diff = {diff}")
msg = "FAIL: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
46 changes: 46 additions & 0 deletions opacus/utils/fast_gradient_clipping_utils.py
@@ -0,0 +1,46 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
GradSampleModuleFastGradientClipping,
)
from opacus.optimizers import DPOptimizerFastGradientClipping


def double_backward(
module: GradSampleModuleFastGradientClipping,
optimizer: DPOptimizerFastGradientClipping,
loss_per_sample: torch.Tensor,
) -> None:
"""
Packages the training loop for Fast Gradient and Ghost Clipping. It does the two backward passes, as well as the loss rescaling and hook operations in between.
Args:
module: The DP gradient sample module to train
optimizer: The DP optimizer used to train the module
loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
Returns:
None
"""

torch.mean(loss_per_sample).backward(retain_graph=True)
optimizer.zero_grad()
rescaled_loss_per_sample = module.get_coeff() * loss_per_sample
rescaled_loss = torch.sum(rescaled_loss_per_sample)
module.disable_hooks()
rescaled_loss.backward()
module.enable_hooks()
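
A short note on the design, based on the function body above: the first backward pass runs with hooks enabled so the wrapped module can compute per-sample gradient norms, and the resulting unclipped gradients are immediately discarded via optimizer.zero_grad(). The loss is then rescaled by the per-sample clipping coefficients returned by module.get_coeff() (roughly min(1, max_grad_norm / per-sample norm), per standard DP-SGD clipping), hooks are disabled, and the second backward pass accumulates the clipped gradients; hooks are re-enabled so the module is ready for the next iteration. The retain_graph=True on the first backward is what allows the second backward over the same graph.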
