diff --git a/opacus/layers/dp_multihead_attention.py b/opacus/layers/dp_multihead_attention.py
index acbdf31e..081938b7 100644
--- a/opacus/layers/dp_multihead_attention.py
+++ b/opacus/layers/dp_multihead_attention.py
@@ -89,6 +89,7 @@ def __init__(
         add_zero_attn=False,
         kdim=None,
         vdim=None,
+        batch_first=False,
         device=None,
         dtype=None,
     ):
@@ -96,10 +97,14 @@ def __init__(
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
         self.vdim = vdim if vdim is not None else embed_dim
-        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+        # When self._qkv_same_embed_dim is True, the transformer uses "in_proj_weight" rather than the separate q/k/v projection weights, which must be avoided here; this is why we force self._qkv_same_embed_dim = False.
+        # self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
+        self._qkv_same_embed_dim = False
 
         self.num_heads = num_heads
         self.dropout = dropout
+        self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
         assert (
             self.head_dim * num_heads == self.embed_dim
@@ -120,6 +125,10 @@ def __init__(
 
         self.dropout = nn.Dropout(dropout)
 
+        # Defined as None to avoid missing-attribute errors in Transformer.forward.
+        self.in_proj_weight = None
+        self.in_proj_bias = None
+
     def load_state_dict(self, state_dict):
         r"""
         Loads module from previously saved state.
@@ -178,7 +187,27 @@ def forward(
         key_padding_mask=None,
         need_weights=True,
         attn_mask=None,
+        is_causal=False,
     ):
+        is_batched = query.dim() == 3
+        if not is_batched:
+            raise ValueError(
+                "The query must be a 3-dimensional (batched) tensor."
+            )
+        if is_causal:
+            raise ValueError(
+                "Causal masks (is_causal=True) are not currently supported."
+            )
+        if self.batch_first:
+            if key is value:
+                if query is key:
+                    query = key = value = query.transpose(1, 0)
+                else:
+                    query, key = [x.transpose(1, 0) for x in (query, key)]
+                    value = key
+            else:
+                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
         tgt_len, bsz, embed_dim = query.size()
         if embed_dim != self.embed_dim:
             raise ValueError(
@@ -323,6 +352,9 @@ def forward(
         )
         attn_output = self.out_proj(attn_output)
 
+        if self.batch_first:
+            attn_output = attn_output.transpose(1, 0)
+
         if need_weights:
             # average attention weights over heads
             attn_output_weights = attn_output_weights.view(
@@ -361,7 +393,7 @@ def state_dict(self, destination=None, prefix="", keep_vars=False):
             keep_vars=keep_vars,
         )
 
-        if self._qkv_same_embed_dim:
+        if (self.kdim == self.embed_dim) and (self.vdim == self.embed_dim):
             destination_alter[prefix + "in_proj_weight"] = torch.cat(
                 (
                     destination[prefix + "qlinear.weight"],
diff --git a/opacus/validators/multihead_attention.py b/opacus/validators/multihead_attention.py
index acf80aba..d603976d 100644
--- a/opacus/validators/multihead_attention.py
+++ b/opacus/validators/multihead_attention.py
@@ -45,6 +45,7 @@ def fix(module: nn.MultiheadAttention) -> DPMultiheadAttention:
         add_zero_attn=module.add_zero_attn,
         kdim=module.kdim,
         vdim=module.vdim,
+        batch_first=module.batch_first,
     )
     dp_attn.load_state_dict(module.state_dict())
     return dp_attn
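
Below is a minimal smoke test sketching the behavior this patch adds. It assumes the patch above has been applied; the tensor sizes, variable names, and the nn.Sequential wrapper are illustrative placeholders, not part of the change. It checks that DPMultiheadAttention accepts batch_first=True and returns output in (batch, sequence, embedding) layout, and that ModuleValidator.fix now carries batch_first over when swapping in a stock nn.MultiheadAttention.

import torch
import torch.nn as nn

from opacus.layers.dp_multihead_attention import DPMultiheadAttention
from opacus.validators import ModuleValidator

# Placeholder sizes; embed_dim must be divisible by num_heads.
batch, seq_len, embed_dim, num_heads = 4, 10, 16, 2

# Direct construction with the new batch_first flag.
attn = DPMultiheadAttention(embed_dim, num_heads, batch_first=True)
q = torch.randn(batch, seq_len, embed_dim)  # (batch, sequence, embedding)
out, attn_weights = attn(q, q, q)
assert out.shape == (batch, seq_len, embed_dim)  # output keeps the batch-first layout

# Validator path: the registered fixer should now forward batch_first
# when replacing nn.MultiheadAttention with DPMultiheadAttention.
model = nn.Sequential(nn.MultiheadAttention(embed_dim, num_heads, batch_first=True))
fixed = ModuleValidator.fix(model)
assert isinstance(fixed[0], DPMultiheadAttention)
assert fixed[0].batch_first

Note that the patch only transposes query/key/value and attn_output; the averaged attention weights returned with need_weights=True keep their usual (batch, target_len, source_len) shape, mirroring torch.nn.MultiheadAttention.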