diff --git a/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py b/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py
index b604911f..dd2c1b94 100644
--- a/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py
+++ b/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py
@@ -24,7 +24,7 @@

 class DistributedDPOptimizerFastGradientClipping(DPOptimizerFastGradientClipping):
     """
-    :class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with
+    :class:`opacus.optimizers.optimizer.DPOptimizer` compatible with
     distributed data processing
     """

diff --git a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
index 819a1649..f46a7a65 100644
--- a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
+++ b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py
@@ -23,6 +23,7 @@
 from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping
 from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping
 from opacus.utils.per_sample_gradients_utils import clone_module
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from torch.utils.data import DataLoader, Dataset

 from .grad_sample_module_test import GradSampleModuleTest, SampleConvNet
@@ -108,7 +109,7 @@ def setUp_data_sequantial(self, size, length, dim):
     @settings(deadline=1000000)
     def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         """
-        Tests if norm calculation is same between standard (opacus) and fast gradient clipping"
+        Tests if norm calculation is the same between standard (opacus) and fast gradient clipping
         """
         self.length = length
         self.size = size
@@ -189,7 +190,7 @@
     @settings(deadline=1000000)
     def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         """
-        Tests if gradients are same between standard (opacus) and fast gradient clipping"
+        Tests if gradients are the same between standard (opacus) and fast gradient clipping, using the double_backward function
         """

         noise_multiplier = 0.0
@@ -237,19 +238,10 @@
         ]
         flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])

-        self.grad_sample_module.enable_hooks()
         output_gc = self.grad_sample_module(input_data)

         first_loss_per_sample = self.criterion(output_gc, target_data)
-        first_loss = torch.mean(first_loss_per_sample)
-        first_loss.backward(retain_graph=True)
-
-        optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_coeff()
-        second_loss_per_sample = coeff * first_loss_per_sample
-        second_loss = torch.sum(second_loss_per_sample)
-        self.grad_sample_module.disable_hooks()
-        second_loss.backward()
+        double_backward(self.grad_sample_module, optimizer_gc, first_loss_per_sample)

         all_grads_gc = [param.grad for param in self.grad_sample_module.parameters()]
         flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
@@ -261,5 +253,5 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
            ]
        )
        logging.info(f"Diff = {diff}")
-        msg = "FAIL: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
+        msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
        assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
diff --git a/opacus/utils/fast_gradient_clipping_utils.py b/opacus/utils/fast_gradient_clipping_utils.py
new file mode 100644
index 00000000..61a3d2be
--- /dev/null
+++ b/opacus/utils/fast_gradient_clipping_utils.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
+    GradSampleModuleFastGradientClipping,
+)
+from opacus.optimizers import DPOptimizerFastGradientClipping
+
+
+def double_backward(
+    module: GradSampleModuleFastGradientClipping,
+    optimizer: DPOptimizerFastGradientClipping,
+    loss_per_sample: torch.Tensor,
+) -> None:
+    """
+    Packages the training loop for Fast Gradient and Ghost Clipping. It does the two backward passes, as well as the loss rescaling and hook operations in between.
+
+    Args:
+        module: The DP gradient sample module to train
+        optimizer: The DP optimizer used to train the module
+        loss_per_sample: loss on each sample in the mini-batch of size [batch_size, 1]
+
+    Returns:
+        None
+    """
+
+    torch.mean(loss_per_sample).backward(retain_graph=True)
+    optimizer.zero_grad()
+    rescaled_loss_per_sample = module.get_coeff() * loss_per_sample
+    rescaled_loss = torch.sum(rescaled_loss_per_sample)
+    module.disable_hooks()
+    rescaled_loss.backward()
+    module.enable_hooks()
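
For context, a minimal sketch of how the new double_backward helper might be used in a single training step. This is not part of the patch: it assumes the model has already been wrapped in GradSampleModuleFastGradientClipping and the optimizer in DPOptimizerFastGradientClipping (mirroring the test above), and that the criterion returns per-sample losses; train_step, batch, and targets are illustrative names.

import torch.nn as nn

from opacus.utils.fast_gradient_clipping_utils import double_backward

# Per-sample losses are required: the helper rescales each sample's loss by its
# clipping coefficient before the second backward pass.
criterion = nn.CrossEntropyLoss(reduction="none")


def train_step(model, optimizer, batch, targets):
    # First forward pass through the wrapped module (hooks enabled), giving
    # per-sample losses.
    outputs = model(batch)
    loss_per_sample = criterion(outputs, targets)

    # double_backward runs the first backward pass to obtain per-sample gradient
    # norms, zeroes the gradients, rescales the loss by the clipping coefficients,
    # and runs the second backward pass with hooks disabled to produce the
    # clipped gradients, re-enabling hooks afterwards.
    double_backward(model, optimizer, loss_per_sample)

    # The DP optimizer adds noise and applies the update.
    optimizer.step()

Compared with the removed block in the test, the helper keeps the hook enable/disable bookkeeping and loss rescaling in one place, so the training loop reduces to forward, double_backward, step.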