Commit 51f5111
Merge branch 'main' into main
dwahdany authored Jan 5, 2024
2 parents 255e371 + da05f77
Showing 6 changed files with 56 additions and 15 deletions.
7 changes: 3 additions & 4 deletions .circleci/config.yml
@@ -343,7 +343,7 @@ jobs:
   integrationtest_py39_torch_release_cuda:
     machine:
       resource_class: gpu.nvidia.small.multi
-      image: ubuntu-2004-cuda-11.4:202110-01
+      image: linux-cuda-12:default
     steps:
       - checkout
       - py_3_9_setup
@@ -363,7 +363,7 @@ jobs:
   micro_benchmarks_py39_torch_release_cuda:
     machine:
       resource_class: gpu.nvidia.small.multi
-      image: ubuntu-2004-cuda-11.4:202110-01
+      image: linux-cuda-12:default
     steps:
       - checkout
       - py_3_9_setup
@@ -447,7 +447,7 @@ jobs:
   unittest_multi_gpu:
     machine:
       resource_class: gpu.nvidia.medium.multi
-      image: ubuntu-2004-cuda-11.4:202110-01
+      image: linux-cuda-12:default
     steps:
       - checkout
       - py_3_9_setup
@@ -515,4 +515,3 @@ workflows:
          filters: *exclude_ghpages
      - micro_benchmarks_py39_torch_release_cuda:
          filters: *exclude_ghpages
-
8 changes: 6 additions & 2 deletions opacus/accountants/analysis/prv/prvs.py
@@ -96,11 +96,15 @@ def mean(self) -> float:
         """
         Calculate the mean using numerical integration.
         """
+        # determine points based on t_min and t_max
+        lower_exponent = int(np.log10(np.abs(self.t_min)))
+        upper_exponent = int(np.log10(self.t_max))
         points = np.concatenate(
             [
                 [self.t_min],
-                -np.logspace(-5, -1, 5)[::-1],
-                np.logspace(-5, -1, 5),
+                -np.logspace(start=lower_exponent, stop=-5, num=10),
+                [0],
+                np.logspace(start=-5, stop=upper_exponent, num=10),
                 [self.t_max],
             ]
         )
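
Context for the hunk above: the old grid always sampled the fixed band from -0.1 to 0.1 (plus the two endpoints), leaving a single coarse segment between t_min and -0.1 and another between 0.1 and t_max; the new grid log-spaces its points out to the actual endpoints and includes 0. A minimal sketch of the construction with hypothetical endpoint values (the real t_min and t_max come from the PRV instance):

import numpy as np

t_min, t_max = -3.0, 5.0  # hypothetical endpoints standing in for self.t_min / self.t_max
lower_exponent = int(np.log10(np.abs(t_min)))  # 0 here, since |t_min| = 3
upper_exponent = int(np.log10(t_max))          # 0 here, since t_max = 5

# t_min, negatives shrinking toward zero, 0, positives growing toward t_max
points = np.concatenate(
    [
        [t_min],
        -np.logspace(start=lower_exponent, stop=-5, num=10),
        [0],
        np.logspace(start=-5, stop=upper_exponent, num=10),
        [t_max],
    ]
)
assert np.all(np.diff(points) > 0)  # strictly increasing, ready for quadrature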
41 changes: 39 additions & 2 deletions opacus/layers/dp_multihead_attention.py
@@ -89,17 +89,21 @@ def __init__(
         add_zero_attn=False,
         kdim=None,
         vdim=None,
+        batch_first=False,
         device=None,
         dtype=None,
     ):
         super(DPMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
         self.vdim = vdim if vdim is not None else embed_dim
-        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

+        # When self._qkv_same_embed_dim is True, nn.Transformer uses "in_proj_weight" rather than separate q/k/v weights and takes its fast-path calculation, which must be avoided here; we therefore force self._qkv_same_embed_dim = False.
+        self._qkv_same_embed_dim = False
+
         self.num_heads = num_heads
         self.dropout = dropout
+        self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
         assert (
             self.head_dim * num_heads == self.embed_dim
@@ -120,6 +124,10 @@ def __init__(

         self.dropout = nn.Dropout(dropout)

+        # to avoid null pointers in Transformer.forward
+        self.in_proj_weight = None
+        self.in_proj_bias = None
+
     def load_state_dict(self, state_dict):
         r"""
         Loads module from previously saved state.
@@ -178,7 +186,33 @@ def forward(
         key_padding_mask=None,
         need_weights=True,
         attn_mask=None,
+        is_causal=False,
     ):
+        is_batched = query.dim() == 3
+
+        assert is_batched, "The query must have a dimension of 3."
+
+        r"""
+        As per https://github.com/pytorch/opacus/issues/596, ``is_causal`` has to be included as a dummy parameter,
+        since the ``forward`` function of ``nn.TransformerEncoderLayer`` passes it to this module.
+        """
+        assert (
+            not is_causal
+        ), "We currently do not support causal masks; this will be fixed in the future."
+
+        r"""
+        Using the same logic as ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
+        """
+        if self.batch_first:
+            if key is value:
+                if query is key:
+                    query = key = value = query.transpose(1, 0)
+                else:
+                    query, key = [x.transpose(1, 0) for x in (query, key)]
+                    value = key
+            else:
+                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
         tgt_len, bsz, embed_dim = query.size()
         if embed_dim != self.embed_dim:
             raise ValueError(
@@ -323,6 +357,9 @@ def forward(
         )
         attn_output = self.out_proj(attn_output)

+        if self.batch_first:
+            attn_output = attn_output.transpose(1, 0)
+
         if need_weights:
             # average attention weights over heads
             attn_output_weights = attn_output_weights.view(
@@ -361,7 +398,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
             keep_vars=keep_vars,
         )

-        if self._qkv_same_embed_dim:
+        if (self.kdim == self.embed_dim) and (self.vdim == self.embed_dim):
             destination_alter[prefix + "in_proj_weight"] = torch.cat(
                 (
                     destination[prefix + "qlinear.weight"],
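
Taken together, the changes above let DPMultiheadAttention accept batch-first tensors like its torch counterpart: inputs are transposed to (seq, batch, embed) on the way in and the output is transposed back on the way out. A hedged usage sketch (dimensions are illustrative):

import torch
from opacus.layers.dp_multihead_attention import DPMultiheadAttention

attn = DPMultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(8, 10, 16)  # (batch, seq, embed) because batch_first=True
out, _ = attn(x, x, x)      # self-attention: query is key is value, so one transpose suffices
print(out.shape)            # torch.Size([8, 10, 16]), transposed back before returning

Note also the state_dict hunk: since _qkv_same_embed_dim is now always False, the export path checks kdim == embed_dim and vdim == embed_dim directly to decide when to re-concatenate the q/k/v weights into the in_proj_weight layout that nn.MultiheadAttention checkpoints expect.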
6 changes: 3 additions & 3 deletions opacus/tests/batch_memory_manager_test.py
@@ -16,7 +16,7 @@

 import torch
 import torch.nn as nn
-from hypothesis import given, settings
+from hypothesis import HealthCheck, given, settings
 from hypothesis import strategies as st
 from opacus import PrivacyEngine
 from opacus.utils.batch_memory_manager import BatchMemoryManager
@@ -59,7 +59,7 @@ def _init_training(self, batch_size=10, **data_loader_kwargs):
         batch_size=st.sampled_from([8, 16, 64]),
         max_physical_batch_size=st.sampled_from([4, 8]),
     )
-    @settings(deadline=10000)
+    @settings(suppress_health_check=list(HealthCheck), deadline=10000)
     def test_basic(
         self,
         num_workers: int,
@@ -119,7 +119,7 @@ def test_basic(
         num_workers=st.integers(0, 4),
         pin_memory=st.booleans(),
     )
-    @settings(deadline=10000)
+    @settings(suppress_health_check=list(HealthCheck), deadline=10000)
     def test_empty_batch(
         self,
         num_workers: int,
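
For reference, HealthCheck is an enum of Hypothesis's health checks (too_slow, data_too_large, and so on), so suppress_health_check=list(HealthCheck) disables all of them, presumably to keep slow CI runs from tripping spurious failures; deadline is in milliseconds. A minimal standalone illustration (the test body is a placeholder):

from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st


@given(batch_size=st.sampled_from([8, 16, 64]))
@settings(suppress_health_check=list(HealthCheck), deadline=10000)
def test_placeholder(batch_size):
    assert batch_size % 4 == 0  # trivially true for all sampled values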
8 changes: 4 additions & 4 deletions opacus/tests/privacy_engine_test.py
@@ -26,7 +26,7 @@

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from hypothesis import given, settings
+from hypothesis import HealthCheck, given, settings
 from opacus import PrivacyEngine
 from opacus.layers.dp_multihead_attention import DPMultiheadAttention
 from opacus.optimizers.optimizer import _generate_noise
@@ -266,7 +266,7 @@ def _compare_to_vanilla(
         use_closure=st.booleans(),
         max_steps=st.sampled_from([1, 4]),
     )
-    @settings(deadline=None)
+    @settings(suppress_health_check=list(HealthCheck), deadline=None)
     def test_compare_to_vanilla(
         self,
         do_clip: bool,
@@ -552,7 +552,7 @@ def test_parameters_match(self):
         has_noise_scheduler=st.booleans(),
         has_grad_clip_scheduler=st.booleans(),
     )
-    @settings(deadline=None)
+    @settings(suppress_health_check=list(HealthCheck), deadline=None)
     def test_checkpoints(
         self, has_noise_scheduler: bool, has_grad_clip_scheduler: bool
     ):
@@ -659,7 +659,7 @@ def test_checkpoints(
         max_steps=st.integers(8, 10),
         secure_mode=st.just(False),  # TODO: enable after fixing torchcsprng build
     )
-    @settings(deadline=None)
+    @settings(suppress_health_check=list(HealthCheck), deadline=None)
     def test_noise_level(
         self,
         noise_multiplier: float,
1 change: 1 addition & 0 deletions opacus/validators/multihead_attention.py
@@ -45,6 +45,7 @@ def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention:
         add_zero_attn=module.add_zero_attn,
         kdim=module.kdim,
         vdim=module.vdim,
+        batch_first=module.batch_first,
     )
     dp_attn.load_state_dict(module.state_dict())
     return dp_attn
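
With batch_first forwarded here, ModuleValidator.fix produces a drop-in replacement that keeps the wrapped module's batch layout. A hedged sketch (the Net container is hypothetical):

import torch.nn as nn
from opacus.validators import ModuleValidator


class Net(nn.Module):
    # hypothetical container; fix() swaps incompatible submodules in place
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return out


model = ModuleValidator.fix(Net())  # self.attn becomes a DPMultiheadAttention
assert model.attn.batch_first       # preserved by the fixer above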
