From 315b579f2c0b8604802950702aecb8c9a3d78041 Mon Sep 17 00:00:00 2001
From: Nick Rossenbach
Date: Wed, 2 Aug 2023 11:00:51 +0200
Subject: [PATCH] remove explicit device, fix return value

---
 i6_models/decoder/attention.py | 20 ++++++--------------
 1 file changed, 6 insertions(+), 14 deletions(-)

diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py
index cc526951..68726d30 100644
--- a/i6_models/decoder/attention.py
+++ b/i6_models/decoder/attention.py
@@ -39,7 +39,6 @@ def forward(
         query: torch.Tensor,
         weight_feedback: torch.Tensor,
         enc_seq_len: torch.Tensor,
-        device: str,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         :param key: encoder keys of shape [B,T,D_k]
@@ -47,14 +46,13 @@ def forward(
         :param query: query of shape [B,D_k]
         :param weight_feedback: shape is [B,T,D_k]
         :param enc_seq_len: encoder sequence lengths [B]
-        :param device: device where to run the model (cpu or cuda)
         :return: attention context [B,D_v], attention weights [B,T,1]
         """
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
-        time_arange = torch.arange(energies.size(1), device=device)  # [T]
+        time_arange = torch.arange(energies.size(1), device=energies.device)  # [T]
         seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
-        energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
+        energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf")))
         weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
         weights = self.att_weights_drop(weights)
         context = torch.bmm(weights.transpose(1, 2), value)  # [B,1,D_v]
@@ -76,7 +74,6 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: attention config
     output_proj_dim: output projection dimension
     output_dropout: output dropout
-    device: device where to run the model (cpu or cuda)
     """
 
     encoder_dim: int
@@ -89,7 +86,6 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: AdditiveAttentionConfig
     output_proj_dim: int
     output_dropout: float
-    device: str
 
 
 class AttentionLSTMDecoderV1(nn.Module):
@@ -100,6 +96,7 @@ class AttentionLSTMDecoderV1(nn.Module):
 
     def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         super().__init__()
+        print(cfg.vocab_size)
         self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
         self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)
 
@@ -130,10 +127,6 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
         self.output_dropout = nn.Dropout(cfg.output_dropout)
 
-        if "cuda" in cfg.device:
-            assert torch.cuda.is_available(), "CUDA is not available"
-            self.device = cfg.device
-
     def forward(
         self,
         encoder_outputs: torch.Tensor,
@@ -148,10 +141,10 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device=self.device)
+            zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size))
             lstm_state = (zeros, zeros)
-            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=self.device)
-            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=self.device)
+            att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
+            accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
         else:
             lstm_state, att_context, accum_att_weights = state
 
@@ -187,7 +180,6 @@ def forward(
                 query=s_transformed,
                 weight_feedback=weight_feedback,
                 enc_seq_len=enc_seq_len,
-                device=self.device,
             )
             att_context_list.append(att_context)
             accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5
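Note on the pattern the patch switches to (a minimal illustrative sketch, not code from the repository; the shapes and variable names below are made up): instead of threading a device string through the config and forward(), new tensors are created from an existing input tensor via Tensor.new_zeros / Tensor.new_tensor, or with device=<tensor>.device, so they automatically land on whatever device the inputs already live on.

import torch

# Stand-in for an encoder output; in practice this may sit on CUDA after
# the module and its inputs have been moved with .to("cuda").
encoder_outputs = torch.randn(2, 10, 8)

# New tensors inherit device (and dtype) from the existing tensor,
# so no explicit device argument needs to be passed around.
zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), 4))
neg_inf = encoder_outputs.new_tensor(-float("inf"))
time_arange = torch.arange(encoder_outputs.size(1), device=encoder_outputs.device)

assert zeros.device == encoder_outputs.device == time_arange.device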