From d888fd01d2d06b5e6f53fb0ea68a4603d0a76d4f Mon Sep 17 00:00:00 2001 From: Davide Testuggine Date: Mon, 23 Jan 2023 18:35:16 -0800 Subject: [PATCH] Add Clipping schedulers (#556) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/556 This diff introduces gradient clipping schedulers that can be used to vary gradient clipping throughout training. Addresses #375 in OSS. Reviewed By: karthikprasad Differential Revision: D42644261 fbshipit-source-id: 7e200d704d97d0b0f5432153af32753c1d4e6204 --- opacus/privacy_engine.py | 17 +- opacus/schedulers/__init__.py | 34 ++++ opacus/schedulers/grad_clip_scheduler.py | 174 ++++++++++++++++++ .../noise_scheduler.py} | 52 +++++- opacus/tests/privacy_engine_test.py | 105 ++++++----- opacus/tests/schedulers/__init__.py | 14 ++ .../schedulers/grad_clip_scheduler_test.py | 75 ++++++++ .../noise_scheduler_test.py} | 4 +- 8 files changed, 412 insertions(+), 63 deletions(-) create mode 100644 opacus/schedulers/__init__.py create mode 100644 opacus/schedulers/grad_clip_scheduler.py rename opacus/{scheduler.py => schedulers/noise_scheduler.py} (66%) create mode 100644 opacus/tests/schedulers/__init__.py create mode 100644 opacus/tests/schedulers/grad_clip_scheduler_test.py rename opacus/tests/{scheduler_test.py => schedulers/noise_scheduler_test.py} (95%) diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py index 61450e44..0bcad4a2 100644 --- a/opacus/privacy_engine.py +++ b/opacus/privacy_engine.py @@ -29,7 +29,7 @@ wrap_model, ) from opacus.optimizers import DPOptimizer, get_optimizer_class -from opacus.scheduler import _NoiseScheduler +from opacus.schedulers import _GradClipScheduler, _NoiseScheduler from opacus.utils.module_utils import trainable_parameters from opacus.validators.module_validator import ModuleValidator from torch import nn, optim @@ -550,6 +550,7 @@ def save_checkpoint( module: GradSampleModule, optimizer: Optional[DPOptimizer] = None, noise_scheduler: Optional[_NoiseScheduler] = None, + grad_clip_scheduler: Optional[_GradClipScheduler] = None, checkpoint_dict: Optional[Dict[str, Any]] = None, module_state_dict_kwargs: Optional[Dict[str, Any]] = None, torch_save_kwargs: Optional[Dict[str, Any]] = None, @@ -560,6 +561,9 @@ def save_checkpoint( path: Path to save the state dict objects. module: GradSampleModule to save; wrapped module's state_dict is saved. optimizer: DPOptimizer to save; wrapped optimizer's state_dict is saved. + noise_scheduler: _NoiseScheduler whose state we should save. + grad_clip_scheduler: _GradClipScheduler whose state we should save. + checkpoint_dict: Dict[str, Any]; an already-filled checkpoint dict. module_state_dict_kwargs: dict of kwargs to pass to ``module.state_dict()`` torch_save_kwargs: dict of kwargs to pass to ``torch.save()`` @@ -573,6 +577,10 @@ def save_checkpoint( checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict() if noise_scheduler is not None: checkpoint_dict["noise_scheduler_state_dict"] = noise_scheduler.state_dict() + if grad_clip_scheduler is not None: + checkpoint_dict[ + "grad_clip_scheduler_state_dict" + ] = grad_clip_scheduler.state_dict() torch.save(checkpoint_dict, path, **(torch_save_kwargs or {})) @@ -583,6 +591,7 @@ def load_checkpoint( module: GradSampleModule, optimizer: Optional[DPOptimizer] = None, noise_scheduler: Optional[_NoiseScheduler] = None, + grad_clip_scheduler: Optional[_GradClipScheduler] = None, module_load_dict_kwargs: Optional[Dict[str, Any]] = None, torch_load_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict: @@ -606,4 +615,10 @@ def load_checkpoint( if noise_scheduler is not None and len(noise_scheduler_state_dict) > 0: noise_scheduler.load_state_dict(noise_scheduler_state_dict) + grad_clip_scheduler_state_dict = checkpoint.pop( + "grad_clip_scheduler_state_dict", {} + ) + if grad_clip_scheduler is not None and len(grad_clip_scheduler_state_dict) > 0: + grad_clip_scheduler.load_state_dict(grad_clip_scheduler_state_dict) + return checkpoint diff --git a/opacus/schedulers/__init__.py b/opacus/schedulers/__init__.py new file mode 100644 index 00000000..5a59107b --- /dev/null +++ b/opacus/schedulers/__init__.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .grad_clip_scheduler import ( + ExponentialGradClip, + LambdaGradClip, + StepGradClip, + _GradClipScheduler, +) +from .noise_scheduler import ExponentialNoise, LambdaNoise, StepNoise, _NoiseScheduler + + +__all__ = [ + "_GradClipScheduler", + "ExponentialGradClip", + "LambdaGradClip", + "StepGradClip", + "_NoiseScheduler", + "ExponentialNoise", + "LambdaNoise", + "StepNoise", +] diff --git a/opacus/schedulers/grad_clip_scheduler.py b/opacus/schedulers/grad_clip_scheduler.py new file mode 100644 index 00000000..ffe5d277 --- /dev/null +++ b/opacus/schedulers/grad_clip_scheduler.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict + +from opacus.optimizers import DPOptimizer + + +class _GradClipScheduler: + """Base class for gradient clipping schedulers. We follow the same API + as the standard PyTorch LR schedulers, but apply them to Opacus's + `max_grad_norm` param instead. + + This means it only works when you pass a opacus.DPOptimizer, since that + will have a `max_grad_norm` attribute. + """ + + def __init__(self, optimizer: DPOptimizer, *, last_epoch=-1): + """ + Args: + optimizer (DPOptimizer): The DPOptimizer + *: Any other positional args (this is an abstract base class) + last_epoch(int): The index of last epoch. Default: -1. + """ + if not hasattr(optimizer, "max_grad_norm"): + raise ValueError( + "GradClipSchedulers require your optimizer to have a .max_grad_norm attr. " + "Are you sure you are using a DPOptimizer? Those have it added for you." + ) + self.optimizer = optimizer + self.last_epoch = last_epoch + + self.step() + + def state_dict(self) -> Dict: + """Returns the state of the scheduler as a :class:`dict`. + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + + """ + return { + key: value for key, value in self.__dict__.items() if key != "optimizer" + } + + def load_state_dict(self, state_dict: Dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_max_grad_norm(self): + """Implement your scheduling logic here and return the new value for `max_grad_norm`.""" + raise NotImplementedError + + def step(self): + self.last_epoch += 1 + max_grad_norm = self.get_max_grad_norm() + self.optimizer.max_grad_norm = max_grad_norm + + +class ExponentialGradClip(_GradClipScheduler): + """ + Multiplies the max_grad_norm by gamma every epoch (so the gamma factors accumulate). + This means that: + - For gamma < 1, max_grad_norm will shrink and you'll clip more + - For gamma == 1, no effect + - For gamma > 1, max_grad_norm will expand so you'll clip less + + When last_epoch=-1, sets initial max_grad_norm as max_grad_norm. + """ + + def __init__(self, optimizer: DPOptimizer, *, gamma: float, last_epoch: int = -1): + """ + Args: + optimizer: Wrapped optimizer + gamma: Multiplicative factor of learning rate decay. + last_epoch: The index of last epoch. Default: -1. + """ + self.gamma = gamma + super().__init__(optimizer, last_epoch=last_epoch) + + def get_max_grad_norm(self): + if self.last_epoch == 0: + return self.optimizer.max_grad_norm + else: + return self.optimizer.max_grad_norm * self.gamma + + +class LambdaGradClip(_GradClipScheduler): + """ + Multiplies your *base* `max_grad_norm` by the output of a `scheduler_function` given + as input. + Note: the base max_grad_norm is recorded as the max_grad_norm your optimizer had set at + the very beginning. This means that the factors from the `scheduler_function` will *not* + accumulate, unlike in ExponentialGradClip. If you want some exponential-like behavior, + accumulation logic will have to be added in your `scheduler_function`. + + When last_epoch=-1, sets initial max_grad_norm as max_grad_norm. + """ + + def __init__( + self, + optimizer: DPOptimizer, + *, + scheduler_function: Callable[[int], float], + last_epoch: int = -1, + ): + """ + + Args: + optimizer: Wrapped optimizer. + scheduler_function: A function which computes a multiplicative factor given + an integer epoch + last_epoch: The index of last epoch. Default: -1. + """ + self.scheduler_function = scheduler_function + self.base_max_grad_norm = optimizer.max_grad_norm + super().__init__(optimizer, last_epoch=last_epoch) + + def get_max_grad_norm(self): + return self.base_max_grad_norm * self.scheduler_function(self.last_epoch) + + +class StepGradClip(_GradClipScheduler): + """ + Multiplies `max_grad_norm` by `gamma` every `step_size` epochs (so the `gamma` factors accumulate). + This means that: + - For gamma < 1, max_grad_norm will shrink and you'll clip more + - For gamma == 1, no effect + - For gamma > 1, max_grad_norm will expand so you'll clip less + + When last_epoch=-1, sets initial max_grad_norm as max_grad_norm. + """ + + def __init__( + self, + optimizer: DPOptimizer, + *, + step_size: int, + gamma: float, + last_epoch: int = -1, + ): + """ + + Args: + optimizer: Wrapped optimizer. + step_size: Period of learning rate decay. + gamma: Multiplicative factor of learning rate decay. + last_epoch: The index of last epoch + """ + self.step_size = step_size + self.gamma = gamma + super().__init__(optimizer, last_epoch=last_epoch) + + def get_max_grad_norm(self): + # Only change max_grad_norm when at a 'step' + if self.last_epoch == 0 or self.last_epoch % self.step_size != 0: + return self.optimizer.max_grad_norm + else: + return self.gamma * self.optimizer.max_grad_norm diff --git a/opacus/scheduler.py b/opacus/schedulers/noise_scheduler.py similarity index 66% rename from opacus/scheduler.py rename to opacus/schedulers/noise_scheduler.py index b1284f87..674b6387 100644 --- a/opacus/scheduler.py +++ b/opacus/schedulers/noise_scheduler.py @@ -14,11 +14,30 @@ from typing import Callable, Dict -from .optimizers import DPOptimizer +from opacus.optimizers import DPOptimizer -class _NoiseScheduler(object): +class _NoiseScheduler: + """Base class for noise multiplier schedulers. We follow the same API + as the standard PyTorch LR schedulers, but apply them to Opacus's noise + multiplier param instead. + + This means it only works when you pass a opacus.DPOptimizer, since that + will have a `noise_multiplier` attribute. + """ + def __init__(self, optimizer: DPOptimizer, *, last_epoch=-1): + """ + Args: + optimizer (DPOptimizer): The DPOptimizer + *: Any other positional args (this is an abstract base class) + last_epoch(int): The index of last epoch. Default: -1. + """ + if not hasattr(optimizer, "noise_multiplier"): + raise ValueError( + "NoiseSchedulers require your optimizer to have a .noise_multiplier attr. " + "Are you sure you are using a DPOptimizer? Those have it added for you." + ) self.optimizer = optimizer self.last_epoch = last_epoch @@ -44,7 +63,7 @@ def load_state_dict(self, state_dict: Dict): self.__dict__.update(state_dict) def get_noise_multiplier(self): - # Compute learning rate using chainable form of the scheduler + """Implement your scheduling logic here and return the new value for `noise_multiplier`.""" raise NotImplementedError def step(self): @@ -55,7 +74,12 @@ def step(self): class ExponentialNoise(_NoiseScheduler): """ - Decays the noise_multiplier by gamma every epoch. + Multiplies the noise_multiplier by gamma every epoch (so the gamma factors accumulate). + This means that: + - For gamma < 1, noise_multiplier will shrink + - For gamma == 1, no effect + - For gamma > 1, noise_multiplier will expand + When last_epoch=-1, sets initial noise_multiplier as noise_multiplier. """ @@ -66,7 +90,7 @@ def __init__(self, optimizer: DPOptimizer, *, gamma: float, last_epoch: int = -1 Args: optimizer: Wrapped optimizer gamma: Multiplicative factor of learning rate decay. - last_epoch: The index of last epoch + last_epoch: The index of last epoch. Default: -1. """ self.gamma = gamma super().__init__(optimizer, last_epoch=last_epoch) @@ -80,9 +104,15 @@ def get_noise_multiplier(self): class LambdaNoise(_NoiseScheduler): """ - Sets the noise_multiplier to the initial noise_multiplier times a given function. - When last_epoch=-1, sets initial noise_multiplier as noise_multiplier. + Multiplies your *base* `noise_multiplier` by the output of a `scheduler_function` given + as input. + Note: the base noise_multiplier is recorded as the noise_multiplier your optimizer + had set at the very beginning. This means that the factors from the `scheduler_function` + will *not* accumulate, unlike in ExponentialGradClip. + If you want some exponential-like behavior, accumulation logic will have to be + added in your `scheduler_function`. + When last_epoch=-1, sets initial noise_multiplier as noise_multiplier. """ def __init__( @@ -110,9 +140,13 @@ def get_noise_multiplier(self): class StepNoise(_NoiseScheduler): """ - Decays the noise_multiplier by gamma every step_size epochs. - When last_epoch=-1, sets initial noise_multiplier as noise_multiplier. + Multiplies `noise_multiplier` by `gamma` every `step_size` epochs (so the `gamma` factors accumulate). + This means that: + - For gamma < 1, noise_multiplier will shrink + - For gamma == 1, no effect + - For gamma > 1, noise_multiplier will expand + When last_epoch=-1, sets initial noise_multiplier as noise_multiplier. """ def __init__( diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py index 1b137bed..ba9e8312 100644 --- a/opacus/tests/privacy_engine_test.py +++ b/opacus/tests/privacy_engine_test.py @@ -19,7 +19,7 @@ import math import unittest from abc import ABC -from typing import Optional, OrderedDict, Type +from typing import Optional, OrderedDict from unittest.mock import MagicMock, patch import hypothesis.strategies as st @@ -30,7 +30,7 @@ from opacus import PrivacyEngine from opacus.layers.dp_multihead_attention import DPMultiheadAttention from opacus.optimizers.optimizer import _generate_noise -from opacus.scheduler import StepNoise +from opacus.schedulers import StepGradClip, StepNoise from opacus.utils.module_utils import are_state_dict_equal from opacus.validators.errors import UnsupportedModuleError from opacus.validators.module_validator import ModuleValidator @@ -550,10 +550,13 @@ def test_parameters_match(self): ) @given( - noise_scheduler=st.sampled_from([None, StepNoise]), + has_noise_scheduler=st.booleans(), + has_grad_clip_scheduler=st.booleans(), ) @settings(deadline=None) - def test_checkpoints(self, noise_scheduler: Optional[Type[StepNoise]]): + def test_checkpoints( + self, has_noise_scheduler: bool, has_grad_clip_scheduler: bool + ): # 1. Disable poisson sampling to avoid randomness in data loading caused by changing seeds. # 2. Use noise_multiplier=0.0 to avoid randomness in torch.normal() # create a set of components: set 1 @@ -563,11 +566,17 @@ def test_checkpoints(self, noise_scheduler: Optional[Type[StepNoise]]): poisson_sampling=False, grad_sample_mode=self.GRAD_SAMPLE_MODE, ) - s1 = ( - noise_scheduler(optimizer=opt1, step_size=1, gamma=1.0) - if noise_scheduler is not None + noise_scheduler1 = ( + StepNoise(optimizer=opt1, step_size=1, gamma=1.0) + if has_noise_scheduler + else None + ) + grad_clip_scheduler1 = ( + StepGradClip(optimizer=opt1, step_size=1, gamma=1.0) + if has_grad_clip_scheduler else None ) + # create a different set of components: set 2 torch.manual_seed(2) m2, opt2, _, pe2 = self._init_private_training( @@ -575,22 +584,37 @@ def test_checkpoints(self, noise_scheduler: Optional[Type[StepNoise]]): poisson_sampling=False, grad_sample_mode=self.GRAD_SAMPLE_MODE, ) - s2 = ( - noise_scheduler(optimizer=opt2, step_size=1, gamma=2.0) - if noise_scheduler is not None + noise_scheduler2 = ( + StepNoise(optimizer=opt2, step_size=1, gamma=2.0) + if has_noise_scheduler + else None + ) + grad_clip_scheduler2 = ( + StepGradClip(optimizer=opt2, step_size=1, gamma=2.0) + if has_grad_clip_scheduler else None ) # check that two sets of components are different self.assertFalse(are_state_dict_equal(m1.state_dict(), m2.state_dict())) - if noise_scheduler: - self.assertNotEqual(s1.state_dict(), s2.state_dict()) + if has_noise_scheduler: + self.assertNotEqual( + noise_scheduler1.state_dict(), noise_scheduler2.state_dict() + ) + + if has_grad_clip_scheduler: + self.assertNotEqual( + grad_clip_scheduler1.state_dict(), grad_clip_scheduler2.state_dict() + ) + self.assertNotEqual(opt1.noise_multiplier, opt2.noise_multiplier) # train set 1 for a few steps self._train_steps(m1, opt1, dl1) - if noise_scheduler: - s1.step() + if has_noise_scheduler: + noise_scheduler1.step() + if has_grad_clip_scheduler: + grad_clip_scheduler1.step() # load into set 2 checkpoint_to_save = {"foo": "bar"} @@ -599,12 +623,17 @@ def test_checkpoints(self, noise_scheduler: Optional[Type[StepNoise]]): path=bytesio, module=m1, optimizer=opt1, - noise_scheduler=s1, + noise_scheduler=noise_scheduler1, + grad_clip_scheduler=grad_clip_scheduler1, checkpoint_dict=checkpoint_to_save, ) bytesio.seek(0) loaded_checkpoint = pe2.load_checkpoint( - path=bytesio, module=m2, optimizer=opt2, noise_scheduler=s2 + path=bytesio, + module=m2, + optimizer=opt2, + noise_scheduler=noise_scheduler2, + grad_clip_scheduler=grad_clip_scheduler2, ) # check if loaded checkpoint has dummy dict @@ -614,44 +643,18 @@ def test_checkpoints(self, noise_scheduler: Optional[Type[StepNoise]]): # check the two sets of components are now the same self.assertEqual(pe1.accountant.state_dict(), pe2.accountant.state_dict()) self.assertTrue(are_state_dict_equal(m1.state_dict(), m2.state_dict())) - if noise_scheduler: - self.assertEqual(s1.state_dict(), s2.state_dict()) + if has_noise_scheduler: + self.assertEqual( + noise_scheduler1.state_dict(), noise_scheduler2.state_dict() + ) + if has_grad_clip_scheduler: + self.assertEqual( + grad_clip_scheduler1.state_dict(), grad_clip_scheduler2.state_dict() + ) + # check that non-state params are still different self.assertNotEqual(opt1.noise_multiplier, opt2.noise_multiplier) - # train the now loaded set 2 some more (change noise multiplier before doing so) - opt2.noise_multiplier = 0.0 - self._train_steps(m2, opt2, dl1) - if noise_scheduler: - s2.step() - - # recreate set 1 from scratch (set11) and check it is different from the trained set 2 - torch.manual_seed(1) - m11, opt11, dl11, _ = self._init_private_training( - noise_multiplier=0.0, - poisson_sampling=False, - grad_sample_mode=self.GRAD_SAMPLE_MODE, - ) - s11 = ( - noise_scheduler(optimizer=opt11, step_size=1, gamma=1.0) - if noise_scheduler is not None - else None - ) - self.assertFalse(are_state_dict_equal(m2.state_dict(), m11.state_dict())) - if noise_scheduler: - self.assertNotEqual(s2.state_dict(), s11.state_dict()) - # train the recreated set for the same number of steps - self._train_steps(m11, opt11, dl11) - if noise_scheduler: - s11.step() - self._train_steps(m11, opt11, dl11) - if noise_scheduler: - s11.step() - # check that recreated set is now same as the original set 1 after training - self.assertTrue(are_state_dict_equal(m2.state_dict(), m11.state_dict())) - if noise_scheduler: - self.assertEqual(s2.state_dict(), s11.state_dict()) - @given( noise_multiplier=st.floats(0.5, 5.0), max_steps=st.integers(8, 10), diff --git a/opacus/tests/schedulers/__init__.py b/opacus/tests/schedulers/__init__.py new file mode 100644 index 00000000..14d10d0a --- /dev/null +++ b/opacus/tests/schedulers/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/opacus/tests/schedulers/grad_clip_scheduler_test.py b/opacus/tests/schedulers/grad_clip_scheduler_test.py new file mode 100644 index 00000000..90dd04b4 --- /dev/null +++ b/opacus/tests/schedulers/grad_clip_scheduler_test.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from opacus import PrivacyEngine +from opacus.schedulers import ExponentialGradClip, LambdaGradClip, StepGradClip +from torch import nn, optim +from torch.utils.data import DataLoader, TensorDataset + + +class GradClipSchedulerTest(unittest.TestCase): + def setUp(self): + n_data, dim = 100, 10 + data = torch.randn(n_data, dim) + model = nn.Linear(10, 10) + optimizer = optim.SGD(model.parameters(), lr=0.1) + data_loader = DataLoader(TensorDataset(data), batch_size=10) + self.engine = PrivacyEngine() + + self.module, self.optimizer, self.data_loader = self.engine.make_private( + module=model, + optimizer=optimizer, + data_loader=data_loader, + noise_multiplier=1.0, + max_grad_norm=1.0, + ) + + def test_exponential_scheduler(self): + gamma = 0.99 + scheduler = ExponentialGradClip(self.optimizer, gamma=gamma) + + self.assertEqual(self.optimizer.max_grad_norm, 1.0) + scheduler.step() + self.assertEqual(self.optimizer.max_grad_norm, gamma) + + def test_step_scheduler(self): + gamma = 0.1 + step_size = 2 + scheduler = StepGradClip(self.optimizer, step_size=step_size, gamma=gamma) + + self.assertEqual(self.optimizer.max_grad_norm, 1.0) + scheduler.step() + self.assertEqual(self.optimizer.max_grad_norm, 1.0) + scheduler.step() + self.assertEqual(self.optimizer.max_grad_norm, gamma) + scheduler.step() + self.assertEqual(self.optimizer.max_grad_norm, gamma) + scheduler.step() + self.assertEqual(self.optimizer.max_grad_norm, gamma**2) + + def test_lambda_scheduler(self): + def scheduler_function(epoch): + return 1 - epoch / 10 + + scheduler = LambdaGradClip( + self.optimizer, scheduler_function=scheduler_function + ) + + self.assertEqual(self.optimizer.max_grad_norm, 1.0) + scheduler.step() + self.assertEqual(self.optimizer.max_grad_norm, scheduler_function(1)) diff --git a/opacus/tests/scheduler_test.py b/opacus/tests/schedulers/noise_scheduler_test.py similarity index 95% rename from opacus/tests/scheduler_test.py rename to opacus/tests/schedulers/noise_scheduler_test.py index e7a99efb..f406c387 100644 --- a/opacus/tests/scheduler_test.py +++ b/opacus/tests/schedulers/noise_scheduler_test.py @@ -17,12 +17,12 @@ import torch from opacus import PrivacyEngine -from opacus.scheduler import ExponentialNoise, LambdaNoise, StepNoise +from opacus.schedulers import ExponentialNoise, LambdaNoise, StepNoise from torch import nn, optim from torch.utils.data import DataLoader, TensorDataset -class SchedulerTest(unittest.TestCase): +class NoiseSchedulerTest(unittest.TestCase): def setUp(self): n_data, dim = 100, 10 data = torch.randn(n_data, dim)