diff --git a/i6_models/parts/mha.py b/i6_models/parts/mha.py
index 4bfc4296..09972dad 100644
--- a/i6_models/parts/mha.py
+++ b/i6_models/parts/mha.py
@@ -45,10 +45,6 @@ def forward(
         value: torch.Tensor,
         key_padding_mask: torch.Tensor):
-        if key_padding_mask is not None:
-            inv_sequence_mask = compat.logical_not(key_padding_mask)
-        else:
-            inv_sequence_mask = None
         assert query is value is key, "only supports self attention for now"
 
         batch_dim , num_tokens, embed_dim = query.shape
 
@@ -70,9 +66,9 @@ def forward(
         dot = torch.matmul(query, key) # [B, D//H, T, T]
         dot = dot / self.norm
 
-        if inv_sequence_mask is not None:
-            inv_sequence_mask = inv_sequence_mask.view(batch_dim, 1, 1, inv_sequence_mask.size(1))
-            dot = dot.masked_fill(inv_sequence_mask, -float('inf'))
+        if key_padding_mask is not None:
+            key_padding_mask = key_padding_mask.view(batch_dim, 1, 1, key_padding_mask.size(1))
+            dot = dot.masked_fill(key_padding_mask, -float('inf'))
         alpha = self.softmax(dot)# [B, D//H, T, T]
         alpha = self.dropout(alpha)
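
Not part of the diff, but a minimal standalone sketch of the masking semantics the change relies on: masked_fill() writes -inf wherever the mask is True, so key_padding_mask is assumed to follow the usual PyTorch convention of True marking padded key positions (the removed code instead inverted a True-means-valid sequence mask via compat.logical_not). The tensor names and shapes below are illustrative only, not taken from i6_models.

# Sketch assuming key_padding_mask uses True = padded position (PyTorch
# nn.MultiheadAttention convention); shapes are illustrative.
import torch

batch_dim, num_heads, num_tokens = 2, 4, 5
dot = torch.randn(batch_dim, num_heads, num_tokens, num_tokens)  # attention logits [B, H, T, T]

# True marks padded key positions; here the last two tokens of sequence 0 are padding.
key_padding_mask = torch.zeros(batch_dim, num_tokens, dtype=torch.bool)
key_padding_mask[0, 3:] = True

# Broadcast [B, T] -> [B, 1, 1, T] so every head and every query position sees the same key mask.
mask = key_padding_mask.view(batch_dim, 1, 1, num_tokens)
dot = dot.masked_fill(mask, -float("inf"))

alpha = torch.softmax(dot, dim=-1)
# Softmax over a row containing -inf entries assigns them zero weight,
# i.e. padded keys receive no attention.
assert torch.allclose(alpha[0, :, :, 3:], torch.zeros(()))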