Skip to content

Commit

Permalink
Add Clipping schedulers (#556)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #556

This diff introduces gradient clipping schedulers that can be used to vary gradient clipping throughout training.

Addresses #375 in OSS.

Reviewed By: karthikprasad

Differential Revision: D42644261

fbshipit-source-id: 7e200d704d97d0b0f5432153af32753c1d4e6204
  • Loading branch information
Darktex authored and facebook-github-bot committed Jan 24, 2023
1 parent 76f289c commit d888fd0
Show file tree
Hide file tree
Showing 8 changed files with 412 additions and 63 deletions.
17 changes: 16 additions & 1 deletion opacus/privacy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
wrap_model,
)
from opacus.optimizers import DPOptimizer, get_optimizer_class
from opacus.scheduler import _NoiseScheduler
from opacus.schedulers import _GradClipScheduler, _NoiseScheduler
from opacus.utils.module_utils import trainable_parameters
from opacus.validators.module_validator import ModuleValidator
from torch import nn, optim
Expand Down Expand Up @@ -550,6 +550,7 @@ def save_checkpoint(
module: GradSampleModule,
optimizer: Optional[DPOptimizer] = None,
noise_scheduler: Optional[_NoiseScheduler] = None,
grad_clip_scheduler: Optional[_GradClipScheduler] = None,
checkpoint_dict: Optional[Dict[str, Any]] = None,
module_state_dict_kwargs: Optional[Dict[str, Any]] = None,
torch_save_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -560,6 +561,9 @@ def save_checkpoint(
path: Path to save the state dict objects.
module: GradSampleModule to save; wrapped module's state_dict is saved.
optimizer: DPOptimizer to save; wrapped optimizer's state_dict is saved.
noise_scheduler: _NoiseScheduler whose state we should save.
grad_clip_scheduler: _GradClipScheduler whose state we should save.
checkpoint_dict: Dict[str, Any]; an already-filled checkpoint dict.
module_state_dict_kwargs: dict of kwargs to pass to ``module.state_dict()``
torch_save_kwargs: dict of kwargs to pass to ``torch.save()``
Expand All @@ -573,6 +577,10 @@ def save_checkpoint(
checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict()
if noise_scheduler is not None:
checkpoint_dict["noise_scheduler_state_dict"] = noise_scheduler.state_dict()
if grad_clip_scheduler is not None:
checkpoint_dict[
"grad_clip_scheduler_state_dict"
] = grad_clip_scheduler.state_dict()

torch.save(checkpoint_dict, path, **(torch_save_kwargs or {}))

Expand All @@ -583,6 +591,7 @@ def load_checkpoint(
module: GradSampleModule,
optimizer: Optional[DPOptimizer] = None,
noise_scheduler: Optional[_NoiseScheduler] = None,
grad_clip_scheduler: Optional[_GradClipScheduler] = None,
module_load_dict_kwargs: Optional[Dict[str, Any]] = None,
torch_load_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict:
Expand All @@ -606,4 +615,10 @@ def load_checkpoint(
if noise_scheduler is not None and len(noise_scheduler_state_dict) > 0:
noise_scheduler.load_state_dict(noise_scheduler_state_dict)

grad_clip_scheduler_state_dict = checkpoint.pop(
"grad_clip_scheduler_state_dict", {}
)
if grad_clip_scheduler is not None and len(grad_clip_scheduler_state_dict) > 0:
grad_clip_scheduler.load_state_dict(grad_clip_scheduler_state_dict)

return checkpoint
34 changes: 34 additions & 0 deletions opacus/schedulers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .grad_clip_scheduler import (
ExponentialGradClip,
LambdaGradClip,
StepGradClip,
_GradClipScheduler,
)
from .noise_scheduler import ExponentialNoise, LambdaNoise, StepNoise, _NoiseScheduler


__all__ = [
"_GradClipScheduler",
"ExponentialGradClip",
"LambdaGradClip",
"StepGradClip",
"_NoiseScheduler",
"ExponentialNoise",
"LambdaNoise",
"StepNoise",
]
174 changes: 174 additions & 0 deletions opacus/schedulers/grad_clip_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict

from opacus.optimizers import DPOptimizer


class _GradClipScheduler:
"""Base class for gradient clipping schedulers. We follow the same API
as the standard PyTorch LR schedulers, but apply them to Opacus's
`max_grad_norm` param instead.
This means it only works when you pass a opacus.DPOptimizer, since that
will have a `max_grad_norm` attribute.
"""

def __init__(self, optimizer: DPOptimizer, *, last_epoch=-1):
"""
Args:
optimizer (DPOptimizer): The DPOptimizer
*: Any other positional args (this is an abstract base class)
last_epoch(int): The index of last epoch. Default: -1.
"""
if not hasattr(optimizer, "max_grad_norm"):
raise ValueError(
"GradClipSchedulers require your optimizer to have a .max_grad_norm attr. "
"Are you sure you are using a DPOptimizer? Those have it added for you."
)
self.optimizer = optimizer
self.last_epoch = last_epoch

self.step()

def state_dict(self) -> Dict:
"""Returns the state of the scheduler as a :class:`dict`.
It contains an entry for every variable in self.__dict__ which
is not the optimizer.
"""
return {
key: value for key, value in self.__dict__.items() if key != "optimizer"
}

def load_state_dict(self, state_dict: Dict):
"""Loads the schedulers state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)

def get_max_grad_norm(self):
"""Implement your scheduling logic here and return the new value for `max_grad_norm`."""
raise NotImplementedError

def step(self):
self.last_epoch += 1
max_grad_norm = self.get_max_grad_norm()
self.optimizer.max_grad_norm = max_grad_norm


class ExponentialGradClip(_GradClipScheduler):
"""
Multiplies the max_grad_norm by gamma every epoch (so the gamma factors accumulate).
This means that:
- For gamma < 1, max_grad_norm will shrink and you'll clip more
- For gamma == 1, no effect
- For gamma > 1, max_grad_norm will expand so you'll clip less
When last_epoch=-1, sets initial max_grad_norm as max_grad_norm.
"""

def __init__(self, optimizer: DPOptimizer, *, gamma: float, last_epoch: int = -1):
"""
Args:
optimizer: Wrapped optimizer
gamma: Multiplicative factor of learning rate decay.
last_epoch: The index of last epoch. Default: -1.
"""
self.gamma = gamma
super().__init__(optimizer, last_epoch=last_epoch)

def get_max_grad_norm(self):
if self.last_epoch == 0:
return self.optimizer.max_grad_norm
else:
return self.optimizer.max_grad_norm * self.gamma


class LambdaGradClip(_GradClipScheduler):
"""
Multiplies your *base* `max_grad_norm` by the output of a `scheduler_function` given
as input.
Note: the base max_grad_norm is recorded as the max_grad_norm your optimizer had set at
the very beginning. This means that the factors from the `scheduler_function` will *not*
accumulate, unlike in ExponentialGradClip. If you want some exponential-like behavior,
accumulation logic will have to be added in your `scheduler_function`.
When last_epoch=-1, sets initial max_grad_norm as max_grad_norm.
"""

def __init__(
self,
optimizer: DPOptimizer,
*,
scheduler_function: Callable[[int], float],
last_epoch: int = -1,
):
"""
Args:
optimizer: Wrapped optimizer.
scheduler_function: A function which computes a multiplicative factor given
an integer epoch
last_epoch: The index of last epoch. Default: -1.
"""
self.scheduler_function = scheduler_function
self.base_max_grad_norm = optimizer.max_grad_norm
super().__init__(optimizer, last_epoch=last_epoch)

def get_max_grad_norm(self):
return self.base_max_grad_norm * self.scheduler_function(self.last_epoch)


class StepGradClip(_GradClipScheduler):
"""
Multiplies `max_grad_norm` by `gamma` every `step_size` epochs (so the `gamma` factors accumulate).
This means that:
- For gamma < 1, max_grad_norm will shrink and you'll clip more
- For gamma == 1, no effect
- For gamma > 1, max_grad_norm will expand so you'll clip less
When last_epoch=-1, sets initial max_grad_norm as max_grad_norm.
"""

def __init__(
self,
optimizer: DPOptimizer,
*,
step_size: int,
gamma: float,
last_epoch: int = -1,
):
"""
Args:
optimizer: Wrapped optimizer.
step_size: Period of learning rate decay.
gamma: Multiplicative factor of learning rate decay.
last_epoch: The index of last epoch
"""
self.step_size = step_size
self.gamma = gamma
super().__init__(optimizer, last_epoch=last_epoch)

def get_max_grad_norm(self):
# Only change max_grad_norm when at a 'step'
if self.last_epoch == 0 or self.last_epoch % self.step_size != 0:
return self.optimizer.max_grad_norm
else:
return self.gamma * self.optimizer.max_grad_norm
52 changes: 43 additions & 9 deletions opacus/scheduler.py → opacus/schedulers/noise_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,30 @@

from typing import Callable, Dict

from .optimizers import DPOptimizer
from opacus.optimizers import DPOptimizer


class _NoiseScheduler(object):
class _NoiseScheduler:
"""Base class for noise multiplier schedulers. We follow the same API
as the standard PyTorch LR schedulers, but apply them to Opacus's noise
multiplier param instead.
This means it only works when you pass a opacus.DPOptimizer, since that
will have a `noise_multiplier` attribute.
"""

def __init__(self, optimizer: DPOptimizer, *, last_epoch=-1):
"""
Args:
optimizer (DPOptimizer): The DPOptimizer
*: Any other positional args (this is an abstract base class)
last_epoch(int): The index of last epoch. Default: -1.
"""
if not hasattr(optimizer, "noise_multiplier"):
raise ValueError(
"NoiseSchedulers require your optimizer to have a .noise_multiplier attr. "
"Are you sure you are using a DPOptimizer? Those have it added for you."
)
self.optimizer = optimizer
self.last_epoch = last_epoch

Expand All @@ -44,7 +63,7 @@ def load_state_dict(self, state_dict: Dict):
self.__dict__.update(state_dict)

def get_noise_multiplier(self):
# Compute learning rate using chainable form of the scheduler
"""Implement your scheduling logic here and return the new value for `noise_multiplier`."""
raise NotImplementedError

def step(self):
Expand All @@ -55,7 +74,12 @@ def step(self):

class ExponentialNoise(_NoiseScheduler):
"""
Decays the noise_multiplier by gamma every epoch.
Multiplies the noise_multiplier by gamma every epoch (so the gamma factors accumulate).
This means that:
- For gamma < 1, noise_multiplier will shrink
- For gamma == 1, no effect
- For gamma > 1, noise_multiplier will expand
When last_epoch=-1, sets initial noise_multiplier as noise_multiplier.
"""
Expand All @@ -66,7 +90,7 @@ def __init__(self, optimizer: DPOptimizer, *, gamma: float, last_epoch: int = -1
Args:
optimizer: Wrapped optimizer
gamma: Multiplicative factor of learning rate decay.
last_epoch: The index of last epoch
last_epoch: The index of last epoch. Default: -1.
"""
self.gamma = gamma
super().__init__(optimizer, last_epoch=last_epoch)
Expand All @@ -80,9 +104,15 @@ def get_noise_multiplier(self):

class LambdaNoise(_NoiseScheduler):
"""
Sets the noise_multiplier to the initial noise_multiplier times a given function.
When last_epoch=-1, sets initial noise_multiplier as noise_multiplier.
Multiplies your *base* `noise_multiplier` by the output of a `scheduler_function` given
as input.
Note: the base noise_multiplier is recorded as the noise_multiplier your optimizer
had set at the very beginning. This means that the factors from the `scheduler_function`
will *not* accumulate, unlike in ExponentialGradClip.
If you want some exponential-like behavior, accumulation logic will have to be
added in your `scheduler_function`.
When last_epoch=-1, sets initial noise_multiplier as noise_multiplier.
"""

def __init__(
Expand Down Expand Up @@ -110,9 +140,13 @@ def get_noise_multiplier(self):

class StepNoise(_NoiseScheduler):
"""
Decays the noise_multiplier by gamma every step_size epochs.
When last_epoch=-1, sets initial noise_multiplier as noise_multiplier.
Multiplies `noise_multiplier` by `gamma` every `step_size` epochs (so the `gamma` factors accumulate).
This means that:
- For gamma < 1, noise_multiplier will shrink
- For gamma == 1, no effect
- For gamma > 1, noise_multiplier will expand
When last_epoch=-1, sets initial noise_multiplier as noise_multiplier.
"""

def __init__(
Expand Down
Loading

0 comments on commit d888fd0

Please sign in to comment.