From aea78b3912a5297cabd5f2afba924bdab00b1fad Mon Sep 17 00:00:00 2001
From: Huanyu Zhang
Date: Wed, 25 Sep 2024 09:23:34 -0700
Subject: [PATCH] Fix the initialization function of "GradSampleModuleFastGradientClipping" (#675)

Summary:
Pull Request resolved: https://github.com/pytorch/opacus/pull/675

``GradSampleModuleFastGradientClipping`` does not correctly handle ``strict`` and ``force_functorch`` in its initialization function: both arguments are accepted but never forwarded to the parent class. This fix forwards them, so that user-supplied values for the two parameters take effect.

Reviewed By: iden-kalemaj

Differential Revision: D62676700

fbshipit-source-id: 6df643fb5e9ea47fe91490eeb01c32bd4ed8d743
---
 opacus/grad_sample/grad_sample_module.py          | 9 ++++-----
 .../grad_sample_module_fast_gradient_clipping.py  | 4 +++-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py
index f659f357..19b5ffa6 100644
--- a/opacus/grad_sample/grad_sample_module.py
+++ b/opacus/grad_sample/grad_sample_module.py
@@ -105,10 +105,9 @@ def __init__(
                 ``[K, batch_size, ...]``
             loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
                 is a sum or a mean operation. Can take values "sum" or "mean"
-            strict: If set to ``True``, the input module will be validated to check that
-                ``GradSampleModule`` has grad sampler functions for all submodules of
-                the input module (i.e. if it knows how to calculate per sample gradients)
-                for all model parameters. If set to ``False``, per sample gradients will
+            strict: If set to ``True``, the input module will be validated to make sure that none of its submodules includes buffers,
+                which is not currently supported by Opacus.
+                If set to ``False``, per sample gradients will
                 be computed on "best effort" basis - they will be available where
                 possible and set to None otherwise. This is not recommended, because
                 some unsupported modules (e.g. BatchNorm) affect other parameters and
@@ -120,7 +119,7 @@
         Raises:
             NotImplementedError
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
-                submodules) doesn't have a registered grad sampler function.
+                submodules) includes a buffer.
         """
         super().__init__(
             m,
diff --git a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
index 8e23b9b3..deaeb385 100644
--- a/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
+++ b/opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py
@@ -107,13 +107,15 @@ def __init__(
         Raises:
             NotImplementedError
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
-                submodules) doesn't have a registered grad sampler function.
+                submodules) includes a buffer.
         """
 
         super().__init__(
             m,
             batch_first=batch_first,
             loss_reduction=loss_reduction,
+            strict=strict,
+            force_functorch=force_functorch,
         )
         self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
         self.max_grad_norm = max_grad_norm
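For reference, a minimal usage sketch of the fixed constructor, assuming a standard Opacus install. The model, tensor shapes, and argument values below are illustrative, not part of the patch; only the parameter names shown in the diff (``batch_first``, ``loss_reduction``, ``max_grad_norm``, ``strict``, ``force_functorch``) are taken from the source.

```python
import torch
import torch.nn as nn

from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
    GradSampleModuleFastGradientClipping,
)

# Illustrative model; any nn.Module supported by Opacus works here.
model = nn.Linear(16, 4)

# Before this patch, ``strict`` and ``force_functorch`` were accepted by the
# constructor but never forwarded to the parent class, so the parent always
# ran with its defaults. With the fix, both values below take effect.
wrapped = GradSampleModuleFastGradientClipping(
    model,
    batch_first=True,
    loss_reduction="mean",
    max_grad_norm=1.0,
    strict=False,          # best-effort mode: skip the buffer validation
    force_functorch=True,  # force functorch-based per-sample gradients
)

# Standard forward pass through the wrapper.
out = wrapped(torch.randn(8, 16))
```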