From dc162b4c120a29b733716d559d1dbbfa5d67ad30 Mon Sep 17 00:00:00 2001 From: Enayat Ullah Date: Sat, 20 Jul 2024 18:25:07 -0700 Subject: [PATCH] Fast Gradient and Ghost Clipping (#656) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/656 Itroducing Fast Gradient Clipping and Ghost Clipping to Opacus for memory-efficient training with DP SGD. Reviewed By: HuanyuZhang Differential Revision: D58210796 --- opacus/grad_sample/__init__.py | 11 +- opacus/grad_sample/grad_sample_module.py | 5 +- ...ad_sample_module_fast_gradient_clipping.py | 222 +++++++++++++++ opacus/grad_sample/linear.py | 46 ++- opacus/grad_sample/utils.py | 34 ++- opacus/optimizers/__init__.py | 10 + .../ddpoptimizer_fast_gradient_clipping.py | 81 ++++++ opacus/optimizers/optimizer.py | 6 +- .../optimizer_fast_gradient_clipping.py | 189 +++++++++++++ ...mple_module_fast_gradient_clipping_test.py | 265 ++++++++++++++++++ 10 files changed, 856 insertions(+), 13 deletions(-) create mode 100644 opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py create mode 100644 opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py create mode 100644 opacus/optimizers/optimizer_fast_gradient_clipping.py create mode 100644 opacus/tests/grad_sample_module_fast_gradient_clipping_test.py diff --git a/opacus/grad_sample/__init__.py b/opacus/grad_sample/__init__.py index 60b0403b..65b7af87 100644 --- a/opacus/grad_sample/__init__.py +++ b/opacus/grad_sample/__init__.py @@ -17,7 +17,10 @@ from .dp_multihead_attention import compute_sequence_bias_grad_sample # noqa from .dp_rnn import compute_rnn_linear_grad_sample # noqa from .embedding import compute_embedding_grad_sample # noqa -from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample +from .grad_sample_module import (GradSampleModule, + create_or_accumulate_grad_sample) +from .grad_sample_module_fast_gradient_clipping import \ + GradSampleModuleFastGradientClipping # noqa from .group_norm import compute_group_norm_grad_sample # noqa from .gsm_base import AbstractGradSampleModule from .gsm_exp_weights import GradSampleModuleExpandedWeights @@ -25,15 +28,17 @@ from .instance_norm import compute_instance_norm_grad_sample # noqa from .layer_norm import compute_layer_norm_grad_sample # noqa from .linear import compute_linear_grad_sample # noqa -from .utils import get_gsm_class, register_grad_sampler, wrap_model - +from .utils import (get_gsm_class, register_grad_sampler, + register_norm_sampler, wrap_model) __all__ = [ "GradSampleModule", + "GradSampleModuleFastGradientClipping", "GradSampleModuleExpandedWeights", "GradSampleModuleNoOp", "AbstractGradSampleModule", "register_grad_sampler", + "register_norm_sampler", "create_or_accumulate_grad_sample", "wrap_model", "get_gsm_class", diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py index b7e07491..f659f357 100644 --- a/opacus/grad_sample/grad_sample_module.py +++ b/opacus/grad_sample/grad_sample_module.py @@ -34,6 +34,7 @@ logger = logging.getLogger(__name__) +logger.disabled = True def create_or_accumulate_grad_sample( @@ -465,10 +466,8 @@ def validate( errors.extend( [ NotImplementedError( - f"Model contains a trainable layer " + f"Model contains a trainable layer with buffers" f"that Opacus doesn't currently support({m_name}:{m}). " - f"Please implement and register grad sampler for this layer. 
" - f"(See opacus.grad_sample.utils.register_grad_sampler)" ) for m_name, m in trainable_modules(module) # With functorch, all modules are trainable diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py new file mode 100644 index 00000000..c86f28c6 --- /dev/null +++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import List + +import torch +import torch.nn as nn +from opacus.grad_sample.functorch import ft_compute_per_sample_gradient +from opacus.grad_sample.grad_sample_module import ( + GradSampleModule, + create_or_accumulate_grad_sample, + promote_current_grad_sample, +) +from opacus.utils.module_utils import requires_grad, trainable_parameters + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +def create_norm_sample( + *, param: torch.Tensor, grad_sample: torch.Tensor, max_batch_len: int +) -> None: + """ + Creates a ``_norm_sample`` attribute in the given parameter + + + Args: + param: Parameter to which ``_norm_sample`` will be added + grad_sample: Per-sample gradients tensor. Must be of the same + shape as ``param`` with extra batch dimension + """ + + if param.requires_grad: + param._norm_sample = torch.zeros( + torch.Size([max_batch_len, 1]), + device=grad_sample.device, + dtype=grad_sample.dtype, + ) + param._norm_sample = grad_sample.reshape(len(grad_sample), -1).norm(2, dim=-1) + + +class GradSampleModuleFastGradientClipping(GradSampleModule): + """ + Hooks-based implementation of GradSampleModule with Fast Gradient and Ghost Clipping + + Computes norms of gradients without gradient instantiation + """ + + NORM_SAMPLERS = {} + + def __init__( + self, + m: nn.Module, + *, + batch_first=True, + loss_reduction="mean", + strict: bool = True, + force_functorch=False, + max_grad_norm=1, + use_ghost_clipping=True, + ): + """ + + Args: + m: nn.Module to be wrapped + batch_first: Flag to indicate if the input tensor to the corresponding module + has the first dimension representing the batch. If set to True, dimensions on + input tensor are expected be ``[batch_size, ...]``, otherwise + ``[K, batch_size, ...]`` + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) + is a sum or a mean operation. Can take values "sum" or "mean" + max_grad_norm: The value at which gradients are to be clipped. + strict: If set to True, the input module will be validated to make sure that + it does not have buffers in all its submodules. + force_functorch: If set to ``True``, will use functorch to compute + all per sample gradients. Otherwise, functorch will be used only + for layers without registered grad sampler methods. + use_ghost_clipping: If set to ``True``, Ghost Clipping + will be used for clipping gradients of supported layers. 
If ``False``, Fast + Gradient Clipping will be used for all layers. + + Raises: + NotImplementedError + If ``strict`` is set to ``True`` and module ``m`` (or any of its + submodules) doesn't have a registered grad sampler function. + """ + + super().__init__( + m, + batch_first=batch_first, + loss_reduction=loss_reduction, + ) + self.trainable_parameters = [p for _, p in trainable_parameters(self._module)] + self.max_grad_norm = max_grad_norm + self.use_ghost_clipping = use_ghost_clipping + + def get_coeff(self) -> torch.Tensor: + """Get per-example gradient scaling factor for clipping.""" + norm_sample = self.get_norm_sample() + return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0) + + def get_norm_sample(self) -> torch.Tensor: + """Get per-example gradient norms.""" + norm_sample = torch.stack( + [param._norm_sample for param in self.trainable_parameters], dim=0 + ).norm(2, dim=0) + return norm_sample + + def capture_activations_hook( + self, + module: nn.Module, + forward_input: List[torch.Tensor], + _forward_output: torch.Tensor, + ): + if ( + not requires_grad(module) + or not module.training + or not torch.is_grad_enabled() + or not self.hooks_enabled + ): + return + + if not hasattr(module, "activations"): + module.activations = [] + module.activations.append([t.detach() for t in forward_input]) # pyre-ignore + + for _, p in trainable_parameters(module): + p._forward_counter += 1 + if ( + self.use_ghost_clipping + and p._forward_counter > 1 + and type(module) in self.NORM_SAMPLERS + ): + raise NotImplementedError( + "Parameter tying is not supported with Ghost Clipping" + ) + + def capture_backprops_hook( + self, + module: nn.Module, + _forward_input: torch.Tensor, + forward_output: torch.Tensor, + loss_reduction: str, + batch_first: bool, + ): + """ + Computes norms of per sample gradient given the current backprops and activations + stored by the associated forward hook. Computed per sample gradient norms are + stored in ``norm_sample`` field in each parameter. + + Args: + module: nn.Module, + _forward_input: torch.Tensor, + forward_output: torch.Tensor, + loss_reduction: str, + batch_first: bool, + """ + if not self.hooks_enabled: + return + + backprops = forward_output[0].detach() + activations, backprops = self.rearrange_grad_samples( + module=module, + backprops=backprops, + loss_reduction=loss_reduction, + batch_first=batch_first, + ) + + if self.use_ghost_clipping and type(module) in self.NORM_SAMPLERS: + norm_sampler_fn = self.NORM_SAMPLERS[type(module)] + norm_samples = norm_sampler_fn(module, activations, backprops) + + for param, ns in norm_samples.items(): + if param.requires_grad: + param._norm_sample = ns + param._forward_counter -= 1 + + else: + if not self.force_functorch and type(module) in self.GRAD_SAMPLERS: + grad_sampler_fn = self.GRAD_SAMPLERS[type(module)] + else: + grad_sampler_fn = ft_compute_per_sample_gradient + + grad_samples = grad_sampler_fn(module, activations, backprops) + for param, gs in grad_samples.items(): + create_or_accumulate_grad_sample( + param=param, grad_sample=gs, max_batch_len=module.max_batch_len + ) + del grad_samples + # Detect end of current batch processing and switch accumulation + # mode from sum to stacking. 
Used for RNNs and tied parameters + # (See #417 for details) + for _, p in trainable_parameters(module): + p._forward_counter -= 1 + if p._forward_counter == 0: + promote_current_grad_sample(p) + create_norm_sample( + param=p, + grad_sample=p.grad_sample, + max_batch_len=module.max_batch_len, + ) + del p.grad_sample + + if len(module.activations) == 0: + if hasattr(module, "max_batch_len"): + del module.max_batch_len diff --git a/opacus/grad_sample/linear.py b/opacus/grad_sample/linear.py index 5ab2739b..2cdb84cd 100644 --- a/opacus/grad_sample/linear.py +++ b/opacus/grad_sample/linear.py @@ -13,13 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import Dict, List import torch import torch.nn as nn from opt_einsum import contract -from .utils import register_grad_sampler +from .utils import register_grad_sampler, register_norm_sampler + + +logger = logging.getLogger(__name__) +logging.disabled = False @register_grad_sampler(nn.Linear) @@ -42,3 +47,42 @@ def compute_linear_grad_sample( if layer.bias is not None and layer.bias.requires_grad: ret[layer.bias] = contract("n...k->nk", backprops) return ret + + +@register_norm_sampler(nn.Linear) +def compute_linear_norm_sample( + layer: nn.Linear, activations: List[torch.Tensor], backprops: torch.Tensor +) -> Dict[nn.Parameter, torch.Tensor]: + """ + Computes per sample gradient norms for ``nn.Linear`` layer + + Args: + layer: Layer + activations: Activations + backprops: Backpropagations + """ + activations = activations[0] + ret = {} + + if backprops.dim() == 2: + if layer.weight.requires_grad: + g = contract("n...i,n...i->n", backprops, backprops) + a = contract("n...j,n...j->n", activations, activations) + ret[layer.weight] = torch.sqrt((g * a).flatten()) + if layer.bias is not None and layer.bias.requires_grad: + ret[layer.bias] = torch.sqrt( + contract("n...i,n...i->n", backprops, backprops).flatten() + ) + elif backprops.dim() == 3: + if layer.weight.requires_grad: + + ggT = contract("nik,njk->nij", backprops, backprops) # batchwise g g^T + aaT = contract("nik,njk->nij", activations, activations) # batchwise a a^T + ga = contract("n...i,n...i->n", ggT, aaT).clamp(min=0) + + ret[layer.weight] = torch.sqrt(ga) + if layer.bias is not None and layer.bias.requires_grad: + ggT = contract("nik,njk->nij", backprops, backprops) + gg = contract("n...i,n...i->n", ggT, ggT).clamp(min=0) + ret[layer.bias] = torch.sqrt(gg) + return ret diff --git a/opacus/grad_sample/utils.py b/opacus/grad_sample/utils.py index 8b5e8ff5..f4011361 100644 --- a/opacus/grad_sample/utils.py +++ b/opacus/grad_sample/utils.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +# !/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,6 +18,9 @@ import torch.nn as nn from .grad_sample_module import GradSampleModule +from .grad_sample_module_fast_gradient_clipping import ( + GradSampleModuleFastGradientClipping, +) from .gsm_base import AbstractGradSampleModule from .gsm_exp_weights import GradSampleModuleExpandedWeights from .gsm_no_op import GradSampleModuleNoOp @@ -46,6 +49,33 @@ def decorator(f): ) for target_class in target_classes: GradSampleModule.GRAD_SAMPLERS[target_class] = f + GradSampleModuleFastGradientClipping.GRAD_SAMPLERS[target_class] = f + return f + + return decorator + + +def register_norm_sampler( + target_class_or_classes: Union[Type[nn.Module], Sequence[Type[nn.Module]]] +): + """ + Registers the decorated function as the ``norm_sampler`` of ``target_class_or_classes``, which is + the function that will be invoked every time you want to compute a per-sample gradient norm + of ``target_class_or_classes``. The signature of every norm_sampler is always the same: + + >>> @register_norm_sampler(MyCustomModel) + ... def compute_grad_norm_sample(module, activations, backprops): + ... pass + """ + + def decorator(f): + target_classes = ( + target_class_or_classes + if isinstance(target_class_or_classes, Sequence) + else [target_class_or_classes] + ) + for target_class in target_classes: + GradSampleModuleFastGradientClipping.NORM_SAMPLERS[target_class] = f return f return decorator @@ -70,6 +100,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]: return GradSampleModule elif grad_sample_mode == "ew": return GradSampleModuleExpandedWeights + elif grad_sample_mode == "ghost": + return GradSampleModuleFastGradientClipping elif grad_sample_mode == "no_op": return GradSampleModuleNoOp else: diff --git a/opacus/optimizers/__init__.py b/opacus/optimizers/__init__.py index 55297c7c..5867e127 100644 --- a/opacus/optimizers/__init__.py +++ b/opacus/optimizers/__init__.py @@ -18,7 +18,11 @@ SimpleDistributedPerLayerOptimizer, ) from .ddpoptimizer import DistributedDPOptimizer +from .ddpoptimizer_fast_gradient_clipping import ( + DistributedDPOptimizerFastGradientClipping, +) from .optimizer import DPOptimizer +from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping from .perlayeroptimizer import DPPerLayerOptimizer @@ -27,6 +31,8 @@ "DistributedPerLayerOptimizer", "DistributedDPOptimizer", "DPOptimizer", + "DPOptimizerFastGradientClipping", + "DistributedDPOptimizerFastGradientlipping", "DPPerLayerOptimizer", "SimpleDistributedPerLayerOptimizer", ] @@ -35,6 +41,10 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str = None): if clipping == "flat" and distributed is False: return DPOptimizer + elif clipping == "ghost" and distributed is False: + return DPOptimizerFastGradientClipping + elif clipping == "ghost" and distributed is True: + return DistributedDPOptimizerFastGradientClipping elif clipping == "flat" and distributed is True: return DistributedDPOptimizer elif clipping == "per_layer" and distributed is False: diff --git a/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py b/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py new file mode 100644 index 00000000..b604911f --- /dev/null +++ b/opacus/optimizers/ddpoptimizer_fast_gradient_clipping.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Callable, Optional + +import torch +from torch.optim import Optimizer + +from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping + + +class DistributedDPOptimizerFastGradientClipping(DPOptimizerFastGradientClipping): + """ + :class:`~opacus.optimizers.optimizer.DPOptimizer` compatible with + distributed data processing + """ + + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + ): + super().__init__( + optimizer, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=expected_batch_size, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + ) + self.rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def add_noise(self): + # Noise only gets added to the first worker + if self.rank == 0: + super().add_noise() + else: + for p in self.params: + p.grad = p.summed_grad.view_as(p) + + def reduce_gradients(self): + for p in self.params: + if not p.requires_grad: + continue + torch.distributed.all_reduce(p.grad, op=torch.distributed.ReduceOp.SUM) + if self.loss_reduction == "mean": + p.grad /= self.world_size + + def step( + self, closure: Optional[Callable[[], float]] = None + ) -> Optional[torch.Tensor]: + if closure is not None: + with torch.enable_grad(): + closure() + + if self.pre_step(): + self.reduce_gradients() + return self.original_optimizer.original_optimizer.step() + else: + return None diff --git a/opacus/optimizers/optimizer.py b/opacus/optimizers/optimizer.py index 53eb3e50..bbd554a6 100644 --- a/opacus/optimizers/optimizer.py +++ b/opacus/optimizers/optimizer.py @@ -25,6 +25,7 @@ logger = logging.getLogger(__name__) +logger.disabled = True def _mark_as_processed(obj: Union[torch.Tensor, List[torch.Tensor]]): @@ -497,18 +498,14 @@ def pre_step( # Essentially the DPOptimizer act as a normal optimizer if self.grad_samples is None or len(self.grad_samples) == 0: return True - self.clip_and_accumulate() if self._check_skip_next_step(): self._is_last_step_skipped = True return False - self.add_noise() self.scale_grad() - if self.step_hook: self.step_hook(self) - self._is_last_step_skipped = False return True @@ -516,7 +513,6 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] if closure is not None: with torch.enable_grad(): closure() - if self.pre_step(): return self.original_optimizer.step() else: diff --git a/opacus/optimizers/optimizer_fast_gradient_clipping.py b/opacus/optimizers/optimizer_fast_gradient_clipping.py new file mode 100644 index 00000000..aa415e33 --- /dev/null +++ b/opacus/optimizers/optimizer_fast_gradient_clipping.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Callable, Optional + +import torch +from torch.optim import Optimizer + +from .optimizer import DPOptimizer + + +logger = logging.getLogger(__name__) +logger.disabled = True + + +class DPOptimizerFastGradientClipping(DPOptimizer): + """ + ``torch.optim.Optimizer`` wrapper to implement Fast Gradient and Ghost Clipping -- modifies DPOptimizer + to only add noise to the average gradient, without clipping. + + Can be used with any ``torch.optim.Optimizer`` subclass as an underlying optimizer. + ``DPOptimizerFastGradientClipping`` assumes that parameters over which it performs optimization belong + to GradSampleModuleFastGradientClipping and therefore have the ``grad_sample`` attribute. + + On a high level ``DPOptimizerFastGradientClipping``'s step looks like this: + 1) Add Gaussian noise to ``p.grad`` calibrated to a given noise multiplier and + max grad norm limit (``std = noise_multiplier * max_grad_norm``). + 2) Call underlying optimizer to perform optimization step + + Examples: + >>> module = MyCustomModel() + >>> optimizer = torch.optim.SGD(module.parameters(), lr=0.1) + >>> dp_optimizer = DPOptimizerFastGradientClipping( + ... optimizer=optimizer, + ... noise_multiplier=1.0, + ... max_grad_norm=1.0, + ... expected_batch_size=4, + ... ) + """ + + def __init__( + self, + optimizer: Optimizer, + *, + noise_multiplier: float, + max_grad_norm: float, + expected_batch_size: Optional[int], + loss_reduction: str = "mean", + generator=None, + secure_mode: bool = False, + ): + """ + + Args: + optimizer: wrapped optimizer. + noise_multiplier: noise multiplier + max_grad_norm: max grad norm used for calculating the standard devition of noise added + expected_batch_size: batch_size used for averaging gradients. When using + Poisson sampling averaging denominator can't be inferred from the + actual batch size. Required is ``loss_reduction="mean"``, ignored if + ``loss_reduction="sum"`` + loss_reduction: Indicates if the loss reduction (for aggregating the gradients) + is a sum or a mean operation. Can take values "sum" or "mean" + generator: torch.Generator() object used as a source of randomness for + the noise + secure_mode: if ``True`` uses noise generation approach robust to floating + point arithmetic attacks. + See :meth:`~opacus.optimizers.optimizer._generate_noise` for details + """ + + super().__init__( + optimizer=optimizer, + noise_multiplier=noise_multiplier, + expected_batch_size=expected_batch_size, + max_grad_norm=max_grad_norm, + loss_reduction=loss_reduction, + generator=generator, + secure_mode=secure_mode, + ) + + @property + def accumulated_iterations(self) -> int: + """ + Returns number of batches currently accumulated and not yet processed. + + In other words ``accumulated_iterations`` tracks the number of forward/backward + passed done in between two optimizer steps. The value would typically be 1, + but there are possible exceptions. 
+ + Used by privacy accountants to calculate real sampling rate. + """ + return 1 + + def accumulate(self): + """ + Performs gradient accumulation. + Stores aggregated gradients into `p.summed_grad``` + """ + for p in self.params: + if p.summed_grad is not None: + p.summed_grad += p.grad + else: + p.summed_grad = p.grad + + def zero_grad(self, set_to_none: bool = False): + """ + Clear gradients. + + Clears ``p.grad``, ``p.grad_sample`` and ``p.summed_grad`` for all of it's parameters + + Notes: + ``set_to_none`` argument only affects ``p.grad``. ``p.grad_sample`` and + ``p.summed_grad`` is never zeroed out and always set to None. + Normal grads can do this, because their shape is always the same. + Grad samples do not behave like this, as we accumulate gradients from different + batches in a list + + Args: + set_to_none: instead of setting to zero, set the grads to None. (only + affects regular gradients. Per sample gradients are always set to None) + """ + + if set_to_none is False: + logger.debug( + "Despite set_to_none is set to False, " + "opacus will set p.grad_sample and p.summed_grad to None due to " + "non-trivial gradient accumulation behaviour" + ) + + for p in self.params: + p.grad_sample = None + + if not self._is_last_step_skipped: + p.summed_grad = None + self.original_optimizer.zero_grad(set_to_none) + + def pre_step( + self, closure: Optional[Callable[[], float]] = None + ) -> Optional[float]: + """ + Perform actions specific to ``DPOptimizer`` before calling + underlying ``optimizer.step()`` + + Args: + closure: A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + """ + # The corner case when the optimizer has no trainable parameters. + # Essentially the DPOptimizer act as a normal optimizer + + self.accumulate() + if self._check_skip_next_step(): + self._is_last_step_skipped = True + return False + + self.add_noise() + self.scale_grad() + + if self.step_hook: + self.step_hook(self) + + self._is_last_step_skipped = False + return True + + def _get_flat_grad_sample(self, p: torch.Tensor): + """ + Redefines a parent class' function to not do anything + """ + pass + + def clip_and_accumulate(self): + """ + Redefines a parent class' function to not do anything + """ + pass diff --git a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py new file mode 100644 index 00000000..819a1649 --- /dev/null +++ b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +import hypothesis.strategies as st +import torch +import torch.nn as nn +import torch.nn.functional as F +from hypothesis import given, settings +from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping +from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping +from opacus.utils.per_sample_gradients_utils import clone_module +from torch.utils.data import DataLoader, Dataset + +from .grad_sample_module_test import GradSampleModuleTest, SampleConvNet + + +class SyntheticDataset(Dataset): + def __init__(self, size, length, dim): + self.size = size + self.length = length + self.dim = dim + self.images = torch.randn(self.size, self.length, self.dim, dtype=torch.float32) + self.labels = torch.randint( + 0, 2, size=(self.size, self.length), dtype=torch.float32 + ) + + def __len__(self): + return self.size + + def __getitem__(self, index): + image = self.images[index] + label = self.labels[index] + return image, label + + +class SampleModule(nn.Module): + def __init__(self): + super(SampleModule, self).__init__() + self.fc1 = nn.Linear(2, 2) + self.fc3 = nn.Linear(2, 1024) + self.fc4 = nn.Linear(1024, 1024) + self.fc5 = nn.Linear(1024, 1) + self.layer_norm = nn.LayerNorm(2) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = self.layer_norm(x) + x = self.fc3(x) + x = self.fc4(x) + x = self.fc5(x).flatten(start_dim=1) + x = F.softmax(x) + return x + + +class GradSampleModuleFastGradientClippingTest(GradSampleModuleTest): + CLS = GradSampleModuleFastGradientClipping + + def setUp(self): + self.dim = 2 + self.size = 10 + self.length = 5 + # self.original_model = SampleModule() + # copy_of_original_model = SampleModule() + + self.original_model = SampleConvNet() + copy_of_original_model = SampleConvNet() + + copy_of_original_model.load_state_dict( + self.original_model.state_dict(), strict=True + ) + + self.grad_sample_module = self.CLS( + copy_of_original_model, + batch_first=True, + max_grad_norm=1, + use_ghost_clipping=True, + ) + self.DATA_SIZE = self.size + self.setUp_data() + self.criterion = nn.L1Loss() + + def setUp_data_sequantial(self, size, length, dim): + self.size = size + self.length = length + self.dim = dim + dataset = SyntheticDataset(size=size, length=length, dim=dim) + self.dl = DataLoader(dataset, batch_size=size, shuffle=True) + + @given( + size=st.sampled_from([10]), + length=st.sampled_from([1]), + dim=st.sampled_from([2]), + ) + @settings(deadline=1000000) + def test_norm_calculation_fast_gradient_clipping(self, size, length, dim): + """ + Tests if norm calculation is same between standard (opacus) and fast gradient clipping" + """ + self.length = length + self.size = size + self.dim = dim + + self.criterion = torch.nn.CrossEntropyLoss(reduction="none") + self.setUp_data_sequantial(self.size, self.length, self.dim) + noise_multiplier = 0.0 + batch_size = self.size + max_grad_norm = 1.0 + sample_module = SampleModule() + self.model_normal = GradSampleModule(clone_module(sample_module)) + optimizer_normal = torch.optim.SGD(self.model_normal.parameters(), lr=1) + optimizer_normal = DPOptimizer( + optimizer_normal, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=batch_size, + ) + + self.grad_sample_module = GradSampleModuleFastGradientClipping( + clone_module(sample_module), + max_grad_norm=max_grad_norm, + use_ghost_clipping=True, + ) + optimizer_gc = torch.optim.SGD(self.grad_sample_module.parameters(), lr=1) + optimizer_gc = DPOptimizerFastGradientClipping( + 
optimizer_gc,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        (input_data, target_data) = list(self.dl)[0]
+        optimizer_normal.zero_grad()
+        output_normal = self.model_normal(input_data)
+        loss_normal = torch.mean(self.criterion(output_normal, target_data))
+        loss_normal.backward()
+        all_norms_normal = torch.stack(
+            [
+                torch.stack([g.norm() for g in param.grad_sample], dim=0)
+                for param in self.model_normal.parameters()
+            ],
+            dim=0,
+        )
+        flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])
+
+        self.grad_sample_module.enable_hooks()
+        output_gc = self.grad_sample_module(input_data)
+
+        first_loss_per_sample = self.criterion(output_gc, target_data)
+        first_loss = torch.mean(first_loss_per_sample)
+        first_loss.backward(retain_graph=True)
+
+        optimizer_gc.zero_grad()
+        coeff = self.grad_sample_module.get_coeff()
+        second_loss_per_sample = coeff * first_loss_per_sample
+        second_loss = torch.sum(second_loss_per_sample)
+        self.grad_sample_module.disable_hooks()
+        second_loss.backward()
+
+        all_norms_gc = [
+            param._norm_sample for param in self.grad_sample_module.parameters()
+        ]
+        flat_norms_gc = torch.cat([p.flatten() for p in all_norms_gc])
+
+        diff = flat_norms_normal - flat_norms_gc
+
+        logging.info(f"Diff = {diff}")
+        msg = "Fail: Per-sample gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
+        assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
+
+    @given(
+        size=st.sampled_from([10]),
+        length=st.sampled_from([1, 5]),
+        dim=st.sampled_from([2]),
+    )
+    @settings(deadline=1000000)
+    def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
+        """
+        Tests whether gradients are the same between standard (Opacus) and fast gradient clipping.
+        """
+
+        noise_multiplier = 0.0
+        batch_size = size
+        self.length = length
+        self.size = size
+        self.dim = dim
+        self.setUp_data_sequantial(self.size, self.length, self.dim)
+        max_grad_norm = 1.0
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
+
+        sample_module = SampleModule()
+        self.model_normal = GradSampleModule(clone_module(sample_module))
+        self.grad_sample_module = GradSampleModuleFastGradientClipping(
+            clone_module(sample_module),
+            max_grad_norm=max_grad_norm,
+            use_ghost_clipping=True,
+        )
+
+        optimizer_normal = torch.optim.SGD(self.model_normal.parameters(), lr=1)
+        optimizer_normal = DPOptimizer(
+            optimizer_normal,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        optimizer_gc = torch.optim.SGD(self.grad_sample_module.parameters(), lr=1)
+        optimizer_gc = DPOptimizerFastGradientClipping(
+            optimizer_gc,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=batch_size,
+        )
+
+        (input_data, target_data) = list(self.dl)[0]
+        optimizer_normal.zero_grad()
+        output_normal = self.model_normal(input_data)
+        loss_normal = torch.mean(self.criterion(output_normal, target_data))
+        loss_normal.backward()
+        optimizer_normal.step()
+
+        all_grads_normal = [
+            param.summed_grad for param in self.model_normal.parameters()
+        ]
+        flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal])
+
+        self.grad_sample_module.enable_hooks()
+        output_gc = self.grad_sample_module(input_data)
+
+        first_loss_per_sample = self.criterion(output_gc, target_data)
+        first_loss = torch.mean(first_loss_per_sample)
+        first_loss.backward(retain_graph=True)
+
+        optimizer_gc.zero_grad()
+        coeff = self.grad_sample_module.get_coeff()
+        second_loss_per_sample = coeff * first_loss_per_sample
+        second_loss = torch.sum(second_loss_per_sample)
+        self.grad_sample_module.disable_hooks()
+        second_loss.backward()
+
+        all_grads_gc = [param.grad for param in self.grad_sample_module.parameters()]
+        flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc])
+
+        diff = torch.tensor(
+            [
+                (g_gc - g_normal).norm()
+                for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
+            ]
+        )
+        logging.info(f"Diff = {diff}")
+        msg = "FAIL: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
+        assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
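
Usage note: a minimal sketch of how the two-pass flow exercised in the tests above could look in a single training step. The model, data and hyperparameters below are illustrative assumptions and not part of this patch; only ``GradSampleModuleFastGradientClipping``, ``DPOptimizerFastGradientClipping``, ``get_coeff()`` and the ``enable_hooks()``/``disable_hooks()`` toggles come from the code being added.

    import torch
    import torch.nn as nn

    from opacus.grad_sample import GradSampleModuleFastGradientClipping
    from opacus.optimizers import DPOptimizerFastGradientClipping

    # Illustrative model, data and hyperparameters (not part of this patch).
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
    criterion = nn.CrossEntropyLoss(reduction="none")  # per-sample losses required
    x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))

    module = GradSampleModuleFastGradientClipping(
        model,
        batch_first=True,
        max_grad_norm=1.0,
        use_ghost_clipping=True,  # False -> Fast Gradient Clipping for all layers
    )
    optimizer = DPOptimizerFastGradientClipping(
        torch.optim.SGD(module.parameters(), lr=0.1),
        noise_multiplier=1.0,
        max_grad_norm=1.0,
        expected_batch_size=8,
    )

    # Pass 1: backward with hooks enabled computes per-sample gradient norms only.
    loss_per_sample = criterion(module(x), y)
    torch.mean(loss_per_sample).backward(retain_graph=True)

    # Pass 2: reweight the per-sample losses by the clipping coefficients and
    # backpropagate again with hooks disabled to get clipped, aggregated gradients.
    optimizer.zero_grad()
    coeff = module.get_coeff()  # min(max_grad_norm / per-sample norm, 1)
    second_loss = torch.sum(coeff * loss_per_sample)
    module.disable_hooks()
    second_loss.backward()
    optimizer.step()       # adds noise, rescales, then steps the wrapped SGD
    module.enable_hooks()  # re-enable before the next iteration

The first backward pass only populates per-sample gradient norms (no per-sample gradients are materialized); the second backward pass produces the already-clipped aggregate gradients, and ``DPOptimizerFastGradientClipping.step()`` then adds noise and rescales before stepping the wrapped optimizer.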
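
The 3-D branch of ``compute_linear_norm_sample`` relies on the identity ||B_i^T A_i||_F^2 = <B_i B_i^T, A_i A_i^T>, so per-sample weight-gradient norms of a linear layer can be computed from T x T Gram matrices instead of materializing d_out x d_in per-sample gradients. A self-contained numerical check of that identity; the shapes and the use of plain ``torch.einsum`` instead of ``opt_einsum`` are illustrative assumptions.

    import torch

    torch.manual_seed(0)
    n, T, d_in, d_out = 4, 3, 5, 2
    a = torch.randn(n, T, d_in)   # per-sample activations
    b = torch.randn(n, T, d_out)  # per-sample backprops w.r.t. the layer output

    # Explicit per-sample weight gradients G_i = B_i^T A_i, shape (n, d_out, d_in)
    per_sample_grads = torch.einsum("nti,ntj->nij", b, a)
    explicit_norms = per_sample_grads.flatten(start_dim=1).norm(2, dim=1)

    # Ghost-clipping route: only the T x T Gram matrices are formed
    ggT = torch.einsum("ntk,nsk->nts", b, b)
    aaT = torch.einsum("ntk,nsk->nts", a, a)
    ghost_norms = torch.sqrt((ggT * aaT).sum(dim=(1, 2)).clamp(min=0))

    assert torch.allclose(explicit_norms, ghost_norms, atol=1e-5)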
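
``register_norm_sampler`` mirrors ``register_grad_sampler`` for layers that Ghost Clipping should handle directly. A sketch of what registering one for a custom layer could look like; the ``Scale`` layer and its norm formula are hypothetical, only the decorator and the dict-of-per-sample-norms contract (one ``(batch_size,)`` tensor per trainable parameter, as in ``compute_linear_norm_sample``) come from this patch.

    from typing import Dict, List

    import torch
    import torch.nn as nn

    from opacus.grad_sample import register_norm_sampler


    class Scale(nn.Module):
        # Hypothetical layer: y = weight * x, one weight per feature.
        def __init__(self, dim: int):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.weight


    @register_norm_sampler(Scale)
    def compute_scale_norm_sample(
        layer: Scale, activations: List[torch.Tensor], backprops: torch.Tensor
    ) -> Dict[nn.Parameter, torch.Tensor]:
        # Per-sample weight gradient is activation * backprop, summed over any
        # non-batch, non-feature dims; only its norm is returned.
        ret = {}
        if layer.weight.requires_grad:
            per_sample_grad = activations[0] * backprops
            if per_sample_grad.dim() > 2:  # e.g. (batch, time, features)
                per_sample_grad = per_sample_grad.sum(dim=1)
            ret[layer.weight] = per_sample_grad.norm(2, dim=-1)
        return ret

Layers without a registered norm sampler still work under ``use_ghost_clipping=True``: the backward hook falls back to materializing their per-sample gradients (via a registered grad sampler or functorch) and deriving the norms with ``create_norm_sample``.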