fix for black
yifanyeung committed Feb 18, 2024
1 parent c0a5601 commit 911bfac
Showing 1 changed file with 6 additions and 24 deletions.
30 changes: 6 additions & 24 deletions egs/librispeech/SSL/hubert/attention_module.py
@@ -154,9 +154,9 @@ def __init__(
         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention
 
-        assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and value to be of the same size"
-        )
+        assert (
+            not self.self_attention or self.qkv_same_dim
+        ), "Self-attention requires query, key and value to be of the same size"
 
         self.k_proj = quant_noise(
             nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
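
Note: the two assert layouts are equivalent at runtime; the hunk only adopts the layout that black (the Python formatter named in the commit title) produces, which parenthesizes the condition rather than the failure message. A minimal standalone sketch of the two forms (illustrative local variables, not the module's attributes):

```python
self_attention, qkv_same_dim = True, True  # illustrative values

# Old layout: the failure message is parenthesized.
assert not self_attention or qkv_same_dim, (
    "Self-attention requires query, key and value to be of the same size"
)

# Black's layout: the condition is parenthesized; the message follows.
assert (
    not self_attention or qkv_same_dim
), "Self-attention requires query, key and value to be of the same size"
```
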
@@ -224,13 +224,7 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.k_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
             )
             q_proj_heads_norm.append(
                 torch.sum(
@@ -240,13 +234,7 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.q_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
             )
             v_proj_heads_norm.append(
                 torch.sum(
@@ -256,13 +244,7 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
                         ]
                     )
                 ).tolist()
-                + torch.sum(
-                    torch.abs(
-                        self.v_proj.bias[
-                            start_idx:end_idx
-                        ]
-                    )
-                ).tolist()
+                + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
             )
 
         heads_norm = []
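
The three hunks in _get_reserve_head_index are identical in shape: each collapses the multi-line bias term for k_proj, q_proj, and v_proj onto a single line, with no behavioral change. As a sanity check on what each appended value is, here is a minimal sketch under assumed toy sizes (not the repository's module): the L1 norm of a head's slice of the projection weight plus the L1 norm of its bias slice. torch.sum(...) yields a 0-dim tensor, so .tolist() returns a plain Python float and the two terms add as ordinary numbers.

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2  # toy sizes, assumed for illustration
head_dim = embed_dim // num_heads
k_proj = nn.Linear(embed_dim, embed_dim)

k_proj_heads_norm = []
for i in range(num_heads):
    start_idx, end_idx = i * head_dim, (i + 1) * head_dim
    # Per-head score: L1 norm of the weight rows plus L1 norm of the bias
    # entries belonging to this head. .tolist() on a 0-dim tensor gives a
    # float, so the addition below is ordinary float addition.
    k_proj_heads_norm.append(
        torch.sum(torch.abs(k_proj.weight[start_idx:end_idx])).tolist()
        + torch.sum(torch.abs(k_proj.bias[start_idx:end_idx])).tolist()
    )

print(k_proj_heads_norm)  # one scalar per head, used to rank heads to keep
```
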
