From cbc61b84f3a57f5342811c6fcaf85857c02debad Mon Sep 17 00:00:00 2001
From: Huanyu Zhang
Date: Thu, 16 May 2024 10:31:33 -0700
Subject: [PATCH] Fix gradient shape error for DPMultiheadAttention (issue 650)

Summary:
When batch_first = True, the activations and partial gradients of each linear
layer in DPMultiheadAttention still carry batch_size in the second dimension,
which produces gradients of the wrong shape in optimizer.step().
Details in: https://github.com/pytorch/opacus/issues/650

Differential Revision: D57446245
---
 opacus/layers/dp_multihead_attention.py | 29 ++++++++++++-------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/opacus/layers/dp_multihead_attention.py b/opacus/layers/dp_multihead_attention.py
index 40b5c8ed..bc69a067 100644
--- a/opacus/layers/dp_multihead_attention.py
+++ b/opacus/layers/dp_multihead_attention.py
@@ -203,17 +203,12 @@ def forward(
         r"""
         Using the same logic with ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
         """
-        if self.batch_first:
-            if key is value:
-                if query is key:
-                    query = key = value = query.transpose(1, 0)
-                else:
-                    query, key = [x.transpose(1, 0) for x in (query, key)]
-                    value = key
-            else:
-                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
 
-        tgt_len, bsz, embed_dim = query.size()
+        if not self.batch_first:
+            tgt_len, bsz, embed_dim = query.size()
+        else:
+            bsz, tgt_len, embed_dim = query.size()
+
         if embed_dim != self.embed_dim:
             raise ValueError(
                 f"query has as size of {embed_dim} while the embedding"
@@ -234,6 +229,9 @@ def forward(
 
         q = q * scaling
 
+        if self.batch_first:
+            q, k, v = [x.transpose(1, 0) for x in (q, k, v)]
+
         if attn_mask is not None:
             if attn_mask.dtype not in (
                 torch.float32,
@@ -352,13 +350,14 @@ def forward(
 
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
-        attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = self.out_proj(attn_output)
 
         if self.batch_first:
-            attn_output = attn_output.transpose(1, 0)
+            attn_output = attn_output.contiguous().view(bsz, tgt_len, embed_dim)
+        else:
+            attn_output = (
+                attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            )
+        attn_output = self.out_proj(attn_output)
 
         if need_weights:
             # average attention weights over heads
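
Note (not part of the patch): below is a minimal sketch of how the symptom described in the summary can be checked, assuming an Opacus build that includes this fix. The batch size, sequence length, and embedding size are arbitrary illustration values, not taken from the Opacus test suite. Wrapping the layer in GradSampleModule and running a backward pass with batch_first=True inputs should leave every per-sample gradient with the batch dimension first, which is the shape DPOptimizer expects at optimizer.step().

    import torch

    from opacus import GradSampleModule
    from opacus.layers import DPMultiheadAttention

    # Illustrative sizes; any values work.
    batch_size, seq_len, embed_dim, num_heads = 4, 7, 16, 2

    # Wrap the layer so Opacus records per-sample gradients (p.grad_sample).
    attn = GradSampleModule(
        DPMultiheadAttention(embed_dim, num_heads, batch_first=True)
    )

    # With batch_first=True, inputs are (batch, seq, embed).
    x = torch.randn(batch_size, seq_len, embed_dim)
    out, _ = attn(x, x, x)
    out.sum().backward()

    # With the fix, every per-sample gradient is indexed by batch first.
    for name, p in attn.named_parameters():
        gs = getattr(p, "grad_sample", None)
        if gs is not None:
            assert gs.shape[0] == batch_size, (name, tuple(gs.shape))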