Fixing bugs for DP MultiheadAttention (#598)
Summary:
Pull Request resolved: #598

Fixes the null pointers hit when DPMultiheadAttention is called through Transformer.forward

Differential Revision: D47405312

fbshipit-source-id: fc0a62e790a023d2a9eff9c189b0e09e1ded68bc
HuanyuZhang authored and facebook-github-bot committed Jul 20, 2023
1 parent e8bc932 commit 288f55c
Showing 2 changed files with 35 additions and 2 deletions.
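
For context, a minimal sketch (not part of this commit, with made-up layer sizes and tensor shapes) of the call path the fix targets: after ModuleValidator.fix replaces nn.MultiheadAttention with DPMultiheadAttention inside a standard transformer, the transformer's forward inspects attributes of its attention module such as batch_first and in_proj_weight, which is where the missing attributes ("null pointers") used to surface.

import torch
import torch.nn as nn
from opacus.validators import ModuleValidator

# Illustrative sizes; the point is that the encoder's forward inspects
# attributes of its self-attention module (batch_first, in_proj_weight, ...).
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1)
dp_encoder = ModuleValidator.fix(encoder)  # swaps in DPMultiheadAttention

x = torch.randn(8, 5, 16)   # (batch, seq, d_model) because batch_first=True
out = dp_encoder(x)         # the call path this commit repairs
print(out.shape)            # torch.Size([8, 5, 16])
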
36 changes: 34 additions & 2 deletions opacus/layers/dp_multihead_attention.py
@@ -89,17 +89,22 @@ def __init__(
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
device=None,
dtype=None,
):
super(DPMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

# When self._qkv_same_embed_dim is True, the transformer uses "in_proj_weight"
# rather than the separate q/k/v projection weights, which should be avoided here.
# This is why we force self._qkv_same_embed_dim = False.
# self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self._qkv_same_embed_dim = False

self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
@@ -120,6 +125,10 @@ def __init__(

self.dropout = nn.Dropout(dropout)

# Define these as None (they are unused by this module) to avoid null pointers in Transformer.forward
self.in_proj_weight = None
self.in_proj_bias = None

def load_state_dict(self, state_dict):
r"""
Loads module from previously saved state.
@@ -178,7 +187,27 @@ def forward(
key_padding_mask=None,
need_weights=True,
attn_mask=None,
is_causal=False,
):
is_batched = query.dim() == 3
if not is_batched:
raise ValueError(
"query must be a 3-dimensional (batched) tensor."
)
if is_causal:
raise ValueError(
"Causal masks (is_causal=True) are not currently supported."
)
if self.batch_first:
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

tgt_len, bsz, embed_dim = query.size()
if embed_dim != self.embed_dim:
raise ValueError(
@@ -323,6 +352,9 @@ def forward(
)
attn_output = self.out_proj(attn_output)

if self.batch_first:
attn_output = attn_output.transpose(1, 0)

if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(
@@ -361,7 +393,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
keep_vars=keep_vars,
)

if self._qkv_same_embed_dim:
if (self.kdim == self.embed_dim) and (self.vdim == self.embed_dim):
destination_alter[prefix + "in_proj_weight"] = torch.cat(
(
destination[prefix + "qlinear.weight"],
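
As a hedged illustration (not part of the commit, with illustrative sizes) of the batch_first handling added to DPMultiheadAttention.forward in the hunks above: inputs given as (batch, seq, embed) are transposed to the internal (seq, batch, embed) layout, and the attention output is transposed back before being returned.

import torch
from opacus.layers import DPMultiheadAttention

attn = DPMultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
q = torch.randn(8, 5, 16)                        # (batch, seq, embed)
out, weights = attn(q, q, q, need_weights=True)  # transposed internally and back
print(out.shape)      # torch.Size([8, 5, 16]), still batch-first
print(weights.shape)  # torch.Size([8, 5, 5]), attention weights averaged over heads
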
1 change: 1 addition & 0 deletions opacus/validators/multihead_attention.py
@@ -45,6 +45,7 @@ def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention:
add_zero_attn=module.add_zero_attn,
kdim=module.kdim,
vdim=module.vdim,
batch_first=module.batch_first,
)
dp_attn.load_state_dict(module.state_dict())
return dp_attn
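
A small sketch (assuming the usual ModuleValidator.fix entry point that invokes this fixer, with illustrative sizes) of what this one-line validator change preserves: the replacement DPMultiheadAttention now inherits the original module's batch_first flag instead of silently falling back to the default False.

import torch.nn as nn
from opacus.validators import ModuleValidator

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

model = ModuleValidator.fix(Wrapper())
print(type(model.attn).__name__)  # DPMultiheadAttention
print(model.attn.batch_first)     # True, carried over from the original module
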
