diff --git a/.circleci/config.yml b/.circleci/config.yml index f149f43e..99f1f43f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -343,7 +343,7 @@ jobs: integrationtest_py39_torch_release_cuda: machine: resource_class: gpu.nvidia.small.multi - image: ubuntu-2004-cuda-11.4:202110-01 + image: linux-cuda-12:default steps: - checkout - py_3_9_setup @@ -363,7 +363,7 @@ jobs: micro_benchmarks_py39_torch_release_cuda: machine: resource_class: gpu.nvidia.small.multi - image: ubuntu-2004-cuda-11.4:202110-01 + image: linux-cuda-12:default steps: - checkout - py_3_9_setup @@ -447,7 +447,7 @@ jobs: unittest_multi_gpu: machine: resource_class: gpu.nvidia.medium.multi - image: ubuntu-2004-cuda-11.4:202110-01 + image: linux-cuda-12:default steps: - checkout - py_3_9_setup @@ -515,4 +515,3 @@ workflows: filters: *exclude_ghpages - micro_benchmarks_py39_torch_release_cuda: filters: *exclude_ghpages - diff --git a/opacus/accountants/analysis/prv/prvs.py b/opacus/accountants/analysis/prv/prvs.py index 66df68e6..9650b759 100644 --- a/opacus/accountants/analysis/prv/prvs.py +++ b/opacus/accountants/analysis/prv/prvs.py @@ -96,11 +96,15 @@ def mean(self) -> float: """ Calculate the mean using numerical integration. """ + # determine points based on t_min and t_max + lower_exponent = int(np.log10(np.abs(self.t_min))) + upper_exponent = int(np.log10(self.t_max)) points = np.concatenate( [ [self.t_min], - -np.logspace(-5, -1, 5)[::-1], - np.logspace(-5, -1, 5), + -np.logspace(start=lower_exponent, stop=-5, num=10), + [0], + np.logspace(start=-5, stop=upper_exponent, num=10), [self.t_max], ] ) diff --git a/opacus/layers/dp_multihead_attention.py b/opacus/layers/dp_multihead_attention.py index acbdf31e..40b5c8ed 100644 --- a/opacus/layers/dp_multihead_attention.py +++ b/opacus/layers/dp_multihead_attention.py @@ -89,6 +89,7 @@ def __init__( add_zero_attn=False, kdim=None, vdim=None, + batch_first=False, device=None, dtype=None, ): @@ -96,10 +97,13 @@ def __init__( self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + # when self._qkv_same_embed_dim = True, "in_proj_weight" rather than "q,k,v_weight" and fast path calculation will be used in "nn.transformer", which should be avoided. This is why we force self._qkv_same_embed_dim = False. + self._qkv_same_embed_dim = False self.num_heads = num_heads self.dropout = dropout + self.batch_first = batch_first self.head_dim = embed_dim // num_heads assert ( self.head_dim * num_heads == self.embed_dim @@ -120,6 +124,10 @@ def __init__( self.dropout = nn.Dropout(dropout) + # to avoid null pointers in Transformer.forward + self.in_proj_weight = None + self.in_proj_bias = None + def load_state_dict(self, state_dict): r""" Loads module from previously saved state. @@ -178,7 +186,33 @@ def forward( key_padding_mask=None, need_weights=True, attn_mask=None, + is_causal=False, ): + is_batched = query.dim() == 3 + + assert is_batched == True, "The query must have a dimension of 3." + + r""" + As per https://github.com/pytorch/opacus/issues/596, we have to include ``is_causal`` as a dummy parameter of the function, + since it is used in the ``forward`` function of parent class ``nn.TransformerEncoderLayer``. + """ + assert ( + is_causal == False + ), "We currently do not support causal mask. Will fix it in the future." + + r""" + Using the same logic with ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html). + """ + if self.batch_first: + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + tgt_len, bsz, embed_dim = query.size() if embed_dim != self.embed_dim: raise ValueError( @@ -323,6 +357,9 @@ def forward( ) attn_output = self.out_proj(attn_output) + if self.batch_first: + attn_output = attn_output.transpose(1, 0) + if need_weights: # average attention weights over heads attn_output_weights = attn_output_weights.view( @@ -361,7 +398,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False): keep_vars=keep_vars, ) - if self._qkv_same_embed_dim: + if (self.kdim == self.embed_dim) and (self.vdim == self.embed_dim): destination_alter[prefix + "in_proj_weight"] = torch.cat( ( destination[prefix + "qlinear.weight"], diff --git a/opacus/tests/batch_memory_manager_test.py b/opacus/tests/batch_memory_manager_test.py index 26288d1e..bfe1721a 100644 --- a/opacus/tests/batch_memory_manager_test.py +++ b/opacus/tests/batch_memory_manager_test.py @@ -16,7 +16,7 @@ import torch import torch.nn as nn -from hypothesis import given, settings +from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st from opacus import PrivacyEngine from opacus.utils.batch_memory_manager import BatchMemoryManager @@ -59,7 +59,7 @@ def _init_training(self, batch_size=10, **data_loader_kwargs): batch_size=st.sampled_from([8, 16, 64]), max_physical_batch_size=st.sampled_from([4, 8]), ) - @settings(deadline=10000) + @settings(suppress_health_check=list(HealthCheck), deadline=10000) def test_basic( self, num_workers: int, @@ -119,7 +119,7 @@ def test_basic( num_workers=st.integers(0, 4), pin_memory=st.booleans(), ) - @settings(deadline=10000) + @settings(suppress_health_check=list(HealthCheck), deadline=10000) def test_empty_batch( self, num_workers: int, diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py index d89f96e4..4e2b33ff 100644 --- a/opacus/tests/privacy_engine_test.py +++ b/opacus/tests/privacy_engine_test.py @@ -26,7 +26,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from hypothesis import given, settings +from hypothesis import HealthCheck, given, settings from opacus import PrivacyEngine from opacus.layers.dp_multihead_attention import DPMultiheadAttention from opacus.optimizers.optimizer import _generate_noise @@ -266,7 +266,7 @@ def _compare_to_vanilla( use_closure=st.booleans(), max_steps=st.sampled_from([1, 4]), ) - @settings(deadline=None) + @settings(suppress_health_check=list(HealthCheck), deadline=None) def test_compare_to_vanilla( self, do_clip: bool, @@ -552,7 +552,7 @@ def test_parameters_match(self): has_noise_scheduler=st.booleans(), has_grad_clip_scheduler=st.booleans(), ) - @settings(deadline=None) + @settings(suppress_health_check=list(HealthCheck), deadline=None) def test_checkpoints( self, has_noise_scheduler: bool, has_grad_clip_scheduler: bool ): @@ -659,7 +659,7 @@ def test_checkpoints( max_steps=st.integers(8, 10), secure_mode=st.just(False), # TODO: enable after fixing torchcsprng build ) - @settings(deadline=None) + @settings(suppress_health_check=list(HealthCheck), deadline=None) def test_noise_level( self, noise_multiplier: float, diff --git a/opacus/validators/multihead_attention.py b/opacus/validators/multihead_attention.py index acf80aba..1bccf541 100644 --- a/opacus/validators/multihead_attention.py +++ b/opacus/validators/multihead_attention.py @@ -45,6 +45,7 @@ def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention: add_zero_attn=module.add_zero_attn, kdim=module.kdim, vdim=module.vdim, + batch_first=module.batch_first, ) dp_attn.load_state_dict(module.state_dict()) return dp_attn