Skip to content

Commit

Permalink
Fix the initialization function of "GradSampleModuleFastGradientClipp…
Browse files Browse the repository at this point in the history
…ing" (pytorch#675)

Summary:
Pull Request resolved: pytorch#675

``GradSampleModuleFastGradientClipping`` does not correctly take ``strict`` and ``force_functorch`` in its initialization function. Made the fix to allow the change of the values of the two parameters.

Reviewed By: iden-kalemaj

Differential Revision: D62676700

fbshipit-source-id: 6df643fb5e9ea47fe91490eeb01c32bd4ed8d743
  • Loading branch information
HuanyuZhang authored and facebook-github-bot committed Sep 25, 2024
1 parent a246aa6 commit aea78b3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 4 additions & 5 deletions opacus/grad_sample/grad_sample_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,9 @@ def __init__(
``[K, batch_size, ...]``
loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
is a sum or a mean operation. Can take values "sum" or "mean"
strict: If set to ``True``, the input module will be validated to check that
``GradSampleModule`` has grad sampler functions for all submodules of
the input module (i.e. if it knows how to calculate per sample gradients)
for all model parameters. If set to ``False``, per sample gradients will
strict: If set to ``True``, the input module will be validated to make sure that none of its submodules includes buffers,
which is not currently supported by Opacus.
If set to ``False``, per sample gradients will
be computed on "best effort" basis - they will be available where
possible and set to None otherwise. This is not recommended, because
some unsupported modules (e.g. BatchNorm) affect other parameters and
Expand All @@ -120,7 +119,7 @@ def __init__(
Raises:
NotImplementedError
If ``strict`` is set to ``True`` and module ``m`` (or any of its
submodules) doesn't have a registered grad sampler function.
submodules) includes a buffer.
"""
super().__init__(
m,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ def __init__(
Raises:
NotImplementedError
If ``strict`` is set to ``True`` and module ``m`` (or any of its
submodules) doesn't have a registered grad sampler function.
submodules) includes a buffer.
"""

super().__init__(
m,
batch_first=batch_first,
loss_reduction=loss_reduction,
strict=strict,
force_functorch=force_functorch,
)
self.trainable_parameters = [p for _, p in trainable_parameters(self._module)]
self.max_grad_norm = max_grad_norm
Expand Down

0 comments on commit aea78b3

Please sign in to comment.