From 911bfacffd10d1c94cc3567127c588b947d8db98 Mon Sep 17 00:00:00 2001
From: yifanyeung
Date: Sun, 18 Feb 2024 13:24:02 +0800
Subject: [PATCH] fix for black

---
 .../SSL/hubert/attention_module.py | 30 ++++----------------
 1 file changed, 6 insertions(+), 24 deletions(-)

diff --git a/egs/librispeech/SSL/hubert/attention_module.py b/egs/librispeech/SSL/hubert/attention_module.py
index acfcbabf0c..39ef8698ee 100644
--- a/egs/librispeech/SSL/hubert/attention_module.py
+++ b/egs/librispeech/SSL/hubert/attention_module.py
@@ -154,9 +154,9 @@ def __init__(
         self.self_attention = self_attention
         self.encoder_decoder_attention = encoder_decoder_attention
 
-        assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and value to be of the same size"
-        )
+        assert (
+            not self.self_attention or self.qkv_same_dim
+        ), "Self-attention requires query, key and value to be of the same size"
 
         self.k_proj = quant_noise(
             nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
@@ -224,13 +224,7 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
                     ]
                 )
             ).tolist()
-            + torch.sum(
-                torch.abs(
-                    self.k_proj.bias[
-                        start_idx:end_idx
-                    ]
-                )
-            ).tolist()
+            + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist()
         )
         q_proj_heads_norm.append(
             torch.sum(
@@ -240,13 +234,7 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
                     ]
                 )
             ).tolist()
-            + torch.sum(
-                torch.abs(
-                    self.q_proj.bias[
-                        start_idx:end_idx
-                    ]
-                )
-            ).tolist()
+            + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist()
         )
         v_proj_heads_norm.append(
             torch.sum(
@@ -256,13 +244,7 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
                     ]
                )
             ).tolist()
-            + torch.sum(
-                torch.abs(
-                    self.v_proj.bias[
-                        start_idx:end_idx
-                    ]
-                )
-            ).tolist()
+            + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist()
         )
 
         heads_norm = []
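
For context (not part of the patch itself): the reformatted lines sit inside `_get_reserve_head_index`, which scores each attention head by the L1 norm of its slice of the k/q/v projection weights and biases so the highest-norm heads can be kept. Below is a minimal standalone sketch of that per-head scoring under assumed toy dimensions; only the k projection is shown for brevity (the module sums k, q, and v scores), and the `num_heads_to_keep` selection step here is illustrative, not copied from the file.

import torch
import torch.nn as nn

# Illustrative sizes; the real module takes these from its constructor arguments.
embed_dim, num_heads = 16, 4
head_dim = embed_dim // num_heads
k_proj = nn.Linear(embed_dim, embed_dim)

# Per-head score: L1 norm of the head's rows of the projection weight
# plus the L1 norm of the matching bias slice (same pattern as in the diff).
heads_norm = []
for i in range(num_heads):
    start_idx, end_idx = i * head_dim, (i + 1) * head_dim
    heads_norm.append(
        torch.sum(torch.abs(k_proj.weight[start_idx:end_idx])).tolist()
        + torch.sum(torch.abs(k_proj.bias[start_idx:end_idx])).tolist()
    )

# Keep the indices of the highest-norm heads (hypothetical pruning step).
num_heads_to_keep = 2
reserve_head_index = sorted(
    range(num_heads), key=lambda i: heads_norm[i], reverse=True
)[:num_heads_to_keep]
print(reserve_head_index)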