diff --git a/egs/librispeech/SSL/hubert/attention_module.py b/egs/librispeech/SSL/hubert/attention_module.py
index 8e47ed7ab6..acfcbabf0c 100644
--- a/egs/librispeech/SSL/hubert/attention_module.py
+++ b/egs/librispeech/SSL/hubert/attention_module.py
@@ -155,8 +155,7 @@ def __init__(
         self.encoder_decoder_attention = encoder_decoder_attention
 
         assert not self.self_attention or self.qkv_same_dim, (
-            "Self-attention requires query, key and "
-            "value to be of the same size"
+            "Self-attention requires query, key and value to be of the same size"
         )
 
         self.k_proj = quant_noise(
@@ -219,35 +218,57 @@ def _get_reserve_head_index(self, num_heads_to_keep: int):
             end_idx = (i + 1) * self.head_dim
             k_proj_heads_norm.append(
                 torch.sum(
-                    torch.abs(self.k_proj.weight[start_idx:end_idx,])
+                    torch.abs(
+                        self.k_proj.weight[
+                            start_idx:end_idx,
+                        ]
+                    )
                 ).tolist()
                 + torch.sum(
-                    torch.abs(self.k_proj.bias[start_idx:end_idx])
+                    torch.abs(
+                        self.k_proj.bias[
+                            start_idx:end_idx
+                        ]
+                    )
                 ).tolist()
             )
             q_proj_heads_norm.append(
                 torch.sum(
-                    torch.abs(self.q_proj.weight[start_idx:end_idx,])
+                    torch.abs(
+                        self.q_proj.weight[
+                            start_idx:end_idx,
+                        ]
+                    )
                 ).tolist()
                 + torch.sum(
-                    torch.abs(self.q_proj.bias[start_idx:end_idx])
+                    torch.abs(
+                        self.q_proj.bias[
+                            start_idx:end_idx
+                        ]
+                    )
                 ).tolist()
             )
             v_proj_heads_norm.append(
                 torch.sum(
-                    torch.abs(self.v_proj.weight[start_idx:end_idx,])
+                    torch.abs(
+                        self.v_proj.weight[
+                            start_idx:end_idx,
+                        ]
+                    )
                 ).tolist()
                 + torch.sum(
-                    torch.abs(self.v_proj.bias[start_idx:end_idx])
+                    torch.abs(
+                        self.v_proj.bias[
+                            start_idx:end_idx
+                        ]
+                    )
                 ).tolist()
             )
 
         heads_norm = []
         for i in range(self.num_heads):
             heads_norm.append(
-                k_proj_heads_norm[i]
-                + q_proj_heads_norm[i]
-                + v_proj_heads_norm[i]
+                k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i]
             )
 
         sorted_head_index = sorted(
@@ -271,19 +292,29 @@ def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]):
         for ele in reserve_head_index:
             start_idx, end_idx = ele
 
-            new_q_weight.append(self.q_proj.weight[start_idx:end_idx,])
+            new_q_weight.append(
+                self.q_proj.weight[
+                    start_idx:end_idx,
+                ]
+            )
             new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
 
-            new_k_weight.append(self.k_proj.weight[start_idx:end_idx,])
+            new_k_weight.append(
+                self.k_proj.weight[
+                    start_idx:end_idx,
+                ]
+            )
 
             new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
 
-            new_v_weight.append(self.v_proj.weight[start_idx:end_idx,])
+            new_v_weight.append(
+                self.v_proj.weight[
+                    start_idx:end_idx,
+                ]
+            )
             new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
 
-            new_out_proj_weight.append(
-                self.out_proj.weight[:, start_idx:end_idx]
-            )
+            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
 
         new_q_weight = torch.cat(new_q_weight).detach()
         new_k_weight = torch.cat(new_k_weight).detach()
@@ -330,9 +361,7 @@ def _pad_masks(
     ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
         if attn_mask is not None:
             shape = attn_mask.size()[:-1] + torch.Size([1])
-            attn_mask = torch.cat(
-                [attn_mask, attn_mask.new_zeros(shape)], dim=-1
-            )
+            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
         if key_padding_mask is not None:
             shape = key_padding_mask.size()[:-1] + torch.Size([1])
             key_padding_mask = torch.cat(
@@ -388,9 +417,7 @@ def forward(
         key: Optional[Tensor],
         value: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
-        incremental_state: Optional[
-            Dict[str, Dict[str, Optional[Tensor]]]
-        ] = None,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
         need_weights: bool = True,
         static_kv: bool = False,
         attn_mask: Optional[Tensor] = None,
@@ -455,9 +482,7 @@ def forward(
                 self.embed_dim,
                 self.num_heads,
                 torch.empty([0]),
-                torch.cat(
-                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
-                ),
+                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                 self.bias_k,
                 self.bias_v,
                 self.add_zero_attn,
@@ -465,9 +490,7 @@ def forward(
                 self.out_proj.weight,
                 self.out_proj.bias,
                 self.training or self.dropout_module.apply_during_inference,
-                key_padding_mask.bool()
-                if key_padding_mask is not None
-                else None,
+                key_padding_mask.bool() if key_padding_mask is not None else None,
                 need_weights,
                 attn_mask,
                 use_separate_proj_weight=True,
@@ -482,10 +505,7 @@ def forward(
                 # previous time steps are cached - no need to recompute
                 # key and value if they are static
                 if static_kv:
-                    assert (
-                        self.encoder_decoder_attention
-                        and not self.self_attention
-                    )
+                    assert self.encoder_decoder_attention and not self.self_attention
                     key = value = None
         else:
             saved_state = None
@@ -503,9 +523,9 @@ def forward(
             else:
                 if self.beam_size > 1 and bsz == key.size(1):
                     # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
-                    key = key.view(
-                        key.size(0), -1, self.beam_size, key.size(2)
-                    )[:, :, 0, :]
+                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[
+                        :, :, 0, :
+                    ]
                     if key_padding_mask is not None:
                         key_padding_mask = key_padding_mask.view(
                             -1, self.beam_size, key_padding_mask.size(1)
@@ -552,9 +572,7 @@ def forward(
                 _prev_key = saved_state["prev_key"]
                 assert _prev_key is not None
                 kv_bsz = _prev_key.size(0)
-                prev_key = _prev_key.view(
-                    kv_bsz * self.num_heads, -1, self.head_dim
-                )
+                prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
                 if static_kv:
                     k = prev_key
                 else:
@@ -585,18 +603,14 @@ def forward(
                 static_kv=static_kv,
             )
 
-            saved_state["prev_key"] = k.view(
-                kv_bsz, self.num_heads, -1, self.head_dim
-            )
+            saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
             saved_state["prev_value"] = v.view(
                 kv_bsz, self.num_heads, -1, self.head_dim
            )
             saved_state["prev_key_padding_mask"] = key_padding_mask
             # In this branch incremental_state is never None
             assert incremental_state is not None
-            incremental_state = self._set_input_buffer(
-                incremental_state, saved_state
-            )
+            incremental_state = self._set_input_buffer(incremental_state, saved_state)
         assert k is not None
         assert k.size(1) == src_len
 
@@ -622,14 +636,10 @@ def forward(
                 q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]),
                 k.view((kv_bsz, self.num_heads) + k.size()[1:]),
             )
-            attn_weights = attn_weights.reshape(
-                (-1,) + attn_weights.size()[-2:]
-            )
+            attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
         else:
             attn_weights = torch.bmm(q, k.transpose(1, 2))
-        attn_weights = self.apply_sparse_mask(
-            attn_weights, tgt_len, src_len, bsz
-        )
+        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
 
         assert list(attn_weights.size()) == [
             bsz * self.num_heads,
@@ -645,9 +655,7 @@ def forward(
 
         if key_padding_mask is not None:
             # don't attend to padding symbols
-            attn_weights = attn_weights.view(
-                bsz, self.num_heads, tgt_len, src_len
-            )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             if not is_tpu:
                 attn_weights = attn_weights.view(
                     kv_bsz, -1, self.num_heads, tgt_len, src_len
@@ -661,13 +669,9 @@ def forward(
                 )
             else:
                 attn_weights = attn_weights.transpose(0, 2)
-                attn_weights = attn_weights.masked_fill(
-                    key_padding_mask, float("-inf")
-                )
+                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                 attn_weights = attn_weights.transpose(0, 2)
-            attn_weights = attn_weights.view(
-                bsz * self.num_heads, tgt_len, src_len
-            )
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
         if before_softmax:
             return attn_weights, v
@@ -712,11 +716,7 @@ def forward(
             # the transpose is a no-op copy before view, thus unnecessary
             attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim)
         else:
-            attn = (
-                attn.transpose(0, 1)
-                .contiguous()
-                .view(tgt_len, bsz, self.embed_dim)
-            )
+            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
         attn = self.out_proj(attn)
         attn_weights: Optional[Tensor] = None
         if need_weights:
@@ -786,9 +786,7 @@ def reorder_incremental_state(
                 input_buffer_k = input_buffer[k]
                 if input_buffer_k is not None:
                     if self.encoder_decoder_attention:
-                        if input_buffer_k.size(
-                            0
-                        ) * self.beam_size == new_order.size(0):
+                        if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
                             return incremental_state
                         elif self.beam_size > 1:
                             input_buffer[k] = input_buffer_k.index_select(
@@ -797,16 +795,10 @@ def reorder_incremental_state(
                                 // self.beam_size,
                             )
                         else:
-                            input_buffer[k] = input_buffer_k.index_select(
-                                0, new_order
-                            )
+                            input_buffer[k] = input_buffer_k.index_select(0, new_order)
                     else:
-                        input_buffer[k] = input_buffer_k.index_select(
-                            0, new_order
-                        )
-            incremental_state = self._set_input_buffer(
-                incremental_state, input_buffer
-            )
+                        input_buffer[k] = input_buffer_k.index_select(0, new_order)
+            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
         return incremental_state
 
     def set_beam_size(self, beam_size):
@@ -829,13 +821,9 @@ def _set_input_buffer(
         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
         buffer: Dict[str, Optional[Tensor]],
     ):
-        return self.set_incremental_state(
-            incremental_state, "attn_state", buffer
-        )
+        return self.set_incremental_state(incremental_state, "attn_state", buffer)
 
-    def apply_sparse_mask(
-        self, attn_weights, tgt_len: int, src_len: int, bsz: int
-    ):
+    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
         return attn_weights
 
     def upgrade_state_dict_named(self, state_dict, name):
@@ -847,27 +835,19 @@ def upgrade_state_dict_named(self, state_dict, name):
                 # in_proj_weight used to be q + k + v with same dimensions
                 dim = int(state_dict[k].shape[0] / 3)
                 items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
-                items_to_add[prefix + "k_proj.weight"] = state_dict[k][
-                    dim : 2 * dim
-                ]
-                items_to_add[prefix + "v_proj.weight"] = state_dict[k][
-                    2 * dim :
-                ]
+                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
+                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
 
                 keys_to_remove.append(k)
 
                 k_bias = prefix + "in_proj_bias"
                 if k_bias in state_dict.keys():
                     dim = int(state_dict[k].shape[0] / 3)
-                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][
-                        :dim
-                    ]
+                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                     items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                         dim : 2 * dim
                     ]
-                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][
-                        2 * dim :
-                    ]
+                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
 
                     keys_to_remove.append(prefix + "in_proj_bias")
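
Most of the hunks above only re-wrap lines, but two of them touch logic that is worth illustrating. The `_get_reserve_head_index` / `_adaptive_prune_heads` pair implements head pruning: each attention head is scored by the L1 norm of its slice of the q/k/v projection weights and biases, and only the highest-scoring heads (plus the matching columns of `out_proj`) are kept. Below is a minimal standalone sketch of that ranking step; `rank_heads` is a hypothetical helper written for illustration, not part of `attention_module.py`.

```python
# Standalone sketch of the head-ranking idea from _get_reserve_head_index:
# score each head by the L1 norm of its q/k/v projection slice and keep the
# top num_heads_to_keep heads. `rank_heads` is a hypothetical helper.
import torch


def rank_heads(q_proj, k_proj, v_proj, num_heads, num_heads_to_keep):
    head_dim = q_proj.weight.size(0) // num_heads
    scores = []
    for i in range(num_heads):
        s, e = i * head_dim, (i + 1) * head_dim
        # total L1 norm of this head's rows of the q/k/v projections
        score = sum(
            torch.abs(proj.weight[s:e]).sum() + torch.abs(proj.bias[s:e]).sum()
            for proj in (q_proj, k_proj, v_proj)
        )
        scores.append((float(score), (s, e)))
    # keep the (start, end) index ranges of the largest-norm heads
    scores.sort(key=lambda x: x[0], reverse=True)
    return [idx for _, idx in scores[:num_heads_to_keep]]


if __name__ == "__main__":
    embed_dim, num_heads = 16, 4
    q = torch.nn.Linear(embed_dim, embed_dim)
    k = torch.nn.Linear(embed_dim, embed_dim)
    v = torch.nn.Linear(embed_dim, embed_dim)
    print(rank_heads(q, k, v, num_heads, num_heads_to_keep=2))
```

The returned `(start_idx, end_idx)` ranges play the same role as `reserve_head_index` in `_adaptive_prune_heads`, which concatenates the corresponding weight slices to build the pruned projections.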
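The last hunk reformats `upgrade_state_dict_named`, which migrates old checkpoints whose fused `in_proj_weight`/`in_proj_bias` stacked q, k and v along dimension 0 into the separate `q_proj`/`k_proj`/`v_proj` parameters. A short sketch of that split, assuming only the slicing shown in the hunk; `split_in_proj` is a hypothetical helper, whereas the real method edits `state_dict` in place.

```python
# Sketch of the in_proj -> q/k/v split performed by upgrade_state_dict_named.
import torch


def split_in_proj(in_proj_weight: torch.Tensor, in_proj_bias: torch.Tensor):
    dim = in_proj_weight.shape[0] // 3  # q, k, v are stacked along dim 0
    return {
        "q_proj.weight": in_proj_weight[:dim],
        "k_proj.weight": in_proj_weight[dim : 2 * dim],
        "v_proj.weight": in_proj_weight[2 * dim :],
        "q_proj.bias": in_proj_bias[:dim],
        "k_proj.bias": in_proj_bias[dim : 2 * dim],
        "v_proj.bias": in_proj_bias[2 * dim :],
    }


if __name__ == "__main__":
    embed_dim = 8
    out = split_in_proj(
        torch.randn(3 * embed_dim, embed_dim), torch.randn(3 * embed_dim)
    )
    print({name: tuple(t.shape) for name, t in out.items()})
```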