From 8dfdc2ed9af7ff4aa64cab65c4831233468fe988 Mon Sep 17 00:00:00 2001 From: Huanyu Zhang Date: Thu, 9 Nov 2023 10:08:09 -0800 Subject: [PATCH 1/4] Fix Opacus's failed tests (#609) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/609 Checked that the new type of (and failed) health check from hypothesis 4.57.1 (https://hypothesis.readthedocs.io/en/latest/settings.html#health-checks) is not very important, so I just disabled it. Also fixed the expired image in the "config.yml". Reviewed By: lucamelis Differential Revision: D51126461 fbshipit-source-id: 6a03dc1ea27e6ccac51b407c48c07e9b071b31bb --- .circleci/config.yml | 7 +++---- opacus/tests/batch_memory_manager_test.py | 6 +++--- opacus/tests/privacy_engine_test.py | 8 ++++---- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f149f43e..99f1f43f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -343,7 +343,7 @@ jobs: integrationtest_py39_torch_release_cuda: machine: resource_class: gpu.nvidia.small.multi - image: ubuntu-2004-cuda-11.4:202110-01 + image: linux-cuda-12:default steps: - checkout - py_3_9_setup @@ -363,7 +363,7 @@ jobs: micro_benchmarks_py39_torch_release_cuda: machine: resource_class: gpu.nvidia.small.multi - image: ubuntu-2004-cuda-11.4:202110-01 + image: linux-cuda-12:default steps: - checkout - py_3_9_setup @@ -447,7 +447,7 @@ jobs: unittest_multi_gpu: machine: resource_class: gpu.nvidia.medium.multi - image: ubuntu-2004-cuda-11.4:202110-01 + image: linux-cuda-12:default steps: - checkout - py_3_9_setup @@ -515,4 +515,3 @@ workflows: filters: *exclude_ghpages - micro_benchmarks_py39_torch_release_cuda: filters: *exclude_ghpages - diff --git a/opacus/tests/batch_memory_manager_test.py b/opacus/tests/batch_memory_manager_test.py index 26288d1e..347da1ad 100644 --- a/opacus/tests/batch_memory_manager_test.py +++ b/opacus/tests/batch_memory_manager_test.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from hypothesis import given, settings +from hypothesis import given, settings, HealthCheck from hypothesis import strategies as st from opacus import PrivacyEngine from opacus.utils.batch_memory_manager import BatchMemoryManager @@ -59,7 +59,7 @@ def _init_training(self, batch_size=10, **data_loader_kwargs): batch_size=st.sampled_from([8, 16, 64]), max_physical_batch_size=st.sampled_from([4, 8]), ) - @settings(deadline=10000) + @settings(suppress_health_check=list(HealthCheck), deadline=10000) def test_basic( self, num_workers: int, @@ -119,7 +119,7 @@ def test_basic( num_workers=st.integers(0, 4), pin_memory=st.booleans(), ) - @settings(deadline=10000) + @settings(suppress_health_check=list(HealthCheck), deadline=10000) def test_empty_batch( self, num_workers: int, diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py index d89f96e4..05f3e535 100644 --- a/opacus/tests/privacy_engine_test.py +++ b/opacus/tests/privacy_engine_test.py @@ -26,7 +26,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from hypothesis import given, settings +from hypothesis import given, settings, HealthCheck from opacus import PrivacyEngine from opacus.layers.dp_multihead_attention import DPMultiheadAttention from opacus.optimizers.optimizer import _generate_noise @@ -266,7 +266,7 @@ def _compare_to_vanilla( use_closure=st.booleans(), max_steps=st.sampled_from([1, 4]), ) - @settings(deadline=None) + @settings(suppress_health_check=list(HealthCheck), deadline=None) def test_compare_to_vanilla( self, do_clip: bool, @@ -552,7 +552,7 @@ def test_parameters_match(self): has_noise_scheduler=st.booleans(), has_grad_clip_scheduler=st.booleans(), ) - @settings(deadline=None) + @settings(suppress_health_check=list(HealthCheck), deadline=None) def test_checkpoints( self, has_noise_scheduler: bool, has_grad_clip_scheduler: bool ): @@ -659,7 +659,7 @@ def test_checkpoints( max_steps=st.integers(8, 10), secure_mode=st.just(False), # TODO: enable after fixing torchcsprng build ) - @settings(deadline=None) + @settings(suppress_health_check=list(HealthCheck), deadline=None) def test_noise_level( self, noise_multiplier: float, From 3d622d0d4bafed63ab1cc2bdc6f8a88f26ab0089 Mon Sep 17 00:00:00 2001 From: Huanyu Zhang Date: Thu, 9 Nov 2023 12:52:16 -0800 Subject: [PATCH 2/4] Fix the import order for D51126461 (#610) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/610 In D51126461, we did not import the library according to the alphabetical order, therefore triggering some linter test failure. We therefore made this diff to fix it. Reviewed By: lucamelis Differential Revision: D51166750 fbshipit-source-id: 7dd441b49728a63502a6a3f536559e754ccae576 --- opacus/tests/batch_memory_manager_test.py | 2 +- opacus/tests/privacy_engine_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/opacus/tests/batch_memory_manager_test.py b/opacus/tests/batch_memory_manager_test.py index 347da1ad..bfe1721a 100644 --- a/opacus/tests/batch_memory_manager_test.py +++ b/opacus/tests/batch_memory_manager_test.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from hypothesis import given, settings, HealthCheck +from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st from opacus import PrivacyEngine from opacus.utils.batch_memory_manager import BatchMemoryManager diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py index 05f3e535..4e2b33ff 100644 --- a/opacus/tests/privacy_engine_test.py +++ b/opacus/tests/privacy_engine_test.py @@ -26,7 +26,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from hypothesis import given, settings, HealthCheck +from hypothesis import HealthCheck, given, settings from opacus import PrivacyEngine from opacus.layers.dp_multihead_attention import DPMultiheadAttention from opacus.optimizers.optimizer import _generate_noise From ad084da9e46b22d6bc341958855a04c00ffb9b1f Mon Sep 17 00:00:00 2001 From: Solosneros <24623119+Solosneros@users.noreply.github.com> Date: Tue, 28 Nov 2023 08:14:52 -0800 Subject: [PATCH 3/4] fix: make prv accountant robust to larger epsilons (#606) Summary: ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Docs change / refactoring / dependency upgrade ## Motivation and Context / Related issue Hi, this PR fixes https://github.com/pytorch/opacus/issues/601 and https://github.com/pytorch/opacus/issues/604. It will introduce the same fix as in https://github.com/microsoft/prv_accountant/pull/38. Lukas (author of prv accountant, wulu473) said that `In general, adding any additional points is safe and won't affect the robustness negatively.` The cause of these errors seems to be the grid for computing the `mean()` function of the `PrivacyRandomVariableTruncated` class. The grid (`points` variable) used to compute the mean is constant apart from the lowest (`self.t_min`) and highest point (`self.t_max`). This PR determines the grid (`points` variable) based on the lowest and highest point. More information is below. Best **Observation** I debugged the code and arrived at some point at the `mean()` function of the `PrivacyRandomVariableTruncated` class. The grid (`points` variable) used to compute the mean is constant apart from the lowest (`self.t_min`) and highest point (`self.t_max`). See the line of code [here](https://github.com/microsoft/prv_accountant/blob/a95c4e2d41ff4886c3e4a84925edf878a6540e0a/prv_accountant/privacy_random_variables/abstract_privacy_random_variable.py#L52). It looks like this `[self.tmin, -0.1, -0.01, -0.001, -0.0001, -1e-05, 1e-05, 0.0001, 0.001, 0.01, 0.1, self.tmax]`. It seems that the `tmin` and `tmax` are of the order of `[-12,12]` for the examples that I posted above and even up to `[-48,48]` for the example that jeandut posted in the https://github.com/pytorch/opacus/issues/604 issue whereas they are more like `[-7,7]` for the [readme example for DP-SGD](https://github.com/microsoft/prv_accountant#dp-sgd). We suspect that the integration breaks down when the gridspacing between between `tmin` / `tmax` get's too large. **Proposed solution** Determine the points grid based on `tmin` and `tmax` but determines the start and end of the logspace based on `tmin` and `tmax`. Before: (https://github.com/pytorch/opacus/blob/95df0904ae5d2b3aaa26b708e5067e9271624036/opacus/accountants/analysis/prv/prvs.py#L99-L106) After: ``` # determine points based on t_min and t_max lower_exponent = int(np.log10(np.abs(self.t_min))) upper_exponent = int(np.log10(self.t_max)) points = np.concatenate( [ [self.t_min], -np.logspace(start=lower_exponent, stop=-5, num=10), [0], np.logspace(start=-5, stop=upper_exponent, num=10), [self.t_max], ] ) ``` ## How Has This Been Tested (if it applies) I ran the examples from the issues https://github.com/pytorch/opacus/issues/601 and https://github.com/pytorch/opacus/issues/604 and they don't break anymore. ``` import opacus target_delta = 0.001 target_epsilon = 20 steps = 5000 sample_rate=0.19120458891013384 for target_epsilon in [20, 50]: noise_multiplier = opacus.privacy_engine.get_noise_multiplier(target_delta=target_delta, target_epsilon=target_epsilon, steps=steps, sample_rate=sample_rate, accountant="prv") prv_accountant = opacus.accountants.utils.create_accountant("prv") prv_accountant.history = [(noise_multiplier, sample_rate, steps)] obtained_epsilon = prv_accountant.get_epsilon(delta=target_delta) print(f"target epsilon {target_epsilon}, obtained epsilon {obtained_epsilon}") ``` > target epsilon 20, obtained epsilon 19.999332284974717 target epsilon 50, obtained epsilon 49.99460075990896 ``` target_epsilon = 4 batch_size = 50 epochs = 5 delta = 1e-05 expected_len_dataloader = 500 // batch_size sample_rate = 1/expected_len_dataloader noise_multiplier = opacus.privacy_engine.get_noise_multiplier(target_delta=target_delta, target_epsilon=target_epsilon, epochs=epochs, sample_rate=sample_rate, accountant="prv") prv_accountant = opacus.accountants.utils.create_accountant("prv") prv_accountant.history = [(noise_multiplier, sample_rate, int(epochs / sample_rate))] obtained_epsilon = prv_accountant.get_epsilon(delta=target_delta) print(f"target epsilon {target_epsilon}, obtained epsilon {obtained_epsilon}") ``` > target epsilon 4, obtained epsilon 3.9968389923130356 ## Checklist - [x] The documentation is up-to-date with the changes I made. - [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**). - [ ] All tests passed, and additional code has been covered with new tests. Not able to run all tests locally and unsure if new tests should be added. Pull Request resolved: https://github.com/pytorch/opacus/pull/606 Reviewed By: HuanyuZhang Differential Revision: D50111887 fbshipit-source-id: 2f77f8bc0e59837f765b87f2e107bc01015b9481 --- opacus/accountants/analysis/prv/prvs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/opacus/accountants/analysis/prv/prvs.py b/opacus/accountants/analysis/prv/prvs.py index 66df68e6..9650b759 100644 --- a/opacus/accountants/analysis/prv/prvs.py +++ b/opacus/accountants/analysis/prv/prvs.py @@ -96,11 +96,15 @@ def mean(self) -> float: """ Calculate the mean using numerical integration. """ + # determine points based on t_min and t_max + lower_exponent = int(np.log10(np.abs(self.t_min))) + upper_exponent = int(np.log10(self.t_max)) points = np.concatenate( [ [self.t_min], - -np.logspace(-5, -1, 5)[::-1], - np.logspace(-5, -1, 5), + -np.logspace(start=lower_exponent, stop=-5, num=10), + [0], + np.logspace(start=-5, stop=upper_exponent, num=10), [self.t_max], ] ) From da05f77877152cf4e7fb8f66fbdc1ed351bf4d76 Mon Sep 17 00:00:00 2001 From: Huanyu Zhang Date: Tue, 5 Dec 2023 17:55:01 -0800 Subject: [PATCH 4/4] Fixing bugs for DP MultiheadAttention (#598) Summary: Pull Request resolved: https://github.com/pytorch/opacus/pull/598 Fixing the null pointers in calling DP MultiheadAttention by transform.forward Reviewed By: karthikprasad Differential Revision: D47405312 fbshipit-source-id: c323503ed5ecf2e8f0fc8e5d588cee563d972a4a --- opacus/layers/dp_multihead_attention.py | 41 ++++++++++++++++++++++-- opacus/validators/multihead_attention.py | 1 + 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/opacus/layers/dp_multihead_attention.py b/opacus/layers/dp_multihead_attention.py index acbdf31e..40b5c8ed 100644 --- a/opacus/layers/dp_multihead_attention.py +++ b/opacus/layers/dp_multihead_attention.py @@ -89,6 +89,7 @@ def __init__( add_zero_attn=False, kdim=None, vdim=None, + batch_first=False, device=None, dtype=None, ): @@ -96,10 +97,13 @@ def __init__( self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + # when self._qkv_same_embed_dim = True, "in_proj_weight" rather than "q,k,v_weight" and fast path calculation will be used in "nn.transformer", which should be avoided. This is why we force self._qkv_same_embed_dim = False. + self._qkv_same_embed_dim = False self.num_heads = num_heads self.dropout = dropout + self.batch_first = batch_first self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim @@ -120,6 +124,10 @@ def __init__( self.dropout = nn.Dropout(dropout) + # to avoid null pointers in Transformer.forward + self.in_proj_weight = None + self.in_proj_bias = None + def load_state_dict(self, state_dict): r""" Loads module from previously saved state. @@ -178,7 +186,33 @@ def forward( key_padding_mask=None, need_weights=True, attn_mask=None, + is_causal=False, ): + is_batched = query.dim() == 3 + + assert is_batched == True, "The query must have a dimension of 3." + + r""" + As per https://github.com/pytorch/opacus/issues/596, we have to include ``is_causal`` as a dummy parameter of the function, + since it is used in the ``forward`` function of parent class ``nn.TransformerEncoderLayer``. + """ + assert ( + is_causal == False + ), "We currently do not support causal mask. Will fix it in the future." + + r""" + Using the same logic with ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html). + """ + if self.batch_first: + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + tgt_len, bsz, embed_dim = query.size() if embed_dim != self.embed_dim: raise ValueError( @@ -323,6 +357,9 @@ def forward( ) attn_output = self.out_proj(attn_output) + if self.batch_first: + attn_output = attn_output.transpose(1, 0) + if need_weights: # average attention weights over heads attn_output_weights = attn_output_weights.view( @@ -361,7 +398,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): keep_vars=keep_vars, ) - if self._qkv_same_embed_dim: + if (self.kdim == self.embed_dim) and (self.vdim == self.embed_dim): destination_alter[prefix + "in_proj_weight"] = torch.cat( ( destination[prefix + "qlinear.weight"], diff --git a/opacus/validators/multihead_attention.py b/opacus/validators/multihead_attention.py index acf80aba..1bccf541 100644 --- a/opacus/validators/multihead_attention.py +++ b/opacus/validators/multihead_attention.py @@ -45,6 +45,7 @@ def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention: add_zero_attn=module.add_zero_attn, kdim=module.kdim, vdim=module.vdim, + batch_first=module.batch_first, ) dp_attn.load_state_dict(module.state_dict()) return dp_attn