Skip to content

Commit

Permalink
Removed the inversion of the mask
Browse files Browse the repository at this point in the history
  • Loading branch information
sleepyeldrazi committed Jun 3, 2024
1 parent d928b2e commit 8c2efcd
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions i6_models/parts/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ def forward(
value: torch.Tensor,
key_padding_mask: torch.Tensor):

- if key_padding_mask is not None:
- inv_sequence_mask = compat.logical_not(key_padding_mask)
- else:
- inv_sequence_mask = None
assert query is value is key, "only supports self attention for now"

batch_dim , num_tokens, embed_dim = query.shape
Expand All @@ -70,9 +66,9 @@ def forward(
dot = torch.matmul(query, key) # [B, D//H, T, T]
dot = dot / self.norm

- if inv_sequence_mask is not None:
- inv_sequence_mask = inv_sequence_mask.view(batch_dim, 1, 1, inv_sequence_mask.size(1))
- dot = dot.masked_fill(inv_sequence_mask, -float('inf'))
+ if key_padding_mask is not None:
+ key_padding_mask = key_padding_mask.view(batch_dim, 1, 1, key_padding_mask.size(1))
+ dot = dot.masked_fill(key_padding_mask, -float('inf'))

alpha = self.softmax(dot)# [B, D//H, T, T]
alpha = self.dropout(alpha)
Expand Down

0 comments on commit 8c2efcd

Please sign in to comment.