remove explicit device, fix return value
JackTemaki committed Aug 2, 2023
1 parent 4ddfae6 commit 315b579
Showing 1 changed file with 6 additions and 14 deletions.
i6_models/decoder/attention.py (6 additions & 14 deletions)
@@ -39,22 +39,20 @@ def forward(
         query: torch.Tensor,
         weight_feedback: torch.Tensor,
         enc_seq_len: torch.Tensor,
-        device: str,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         :param key: encoder keys of shape [B,T,D_k]
         :param value: encoder values of shape [B,T,D_v]
         :param query: query of shape [B,D_k]
         :param weight_feedback: shape is [B,T,D_k]
         :param enc_seq_len: encoder sequence lengths [B]
-        :param device: device where to run the model (cpu or cuda)
         :return: attention context [B,D_v], attention weights [B,T,1]
         """
         # all inputs are already projected
         energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback))  # [B,T,1]
-        time_arange = torch.arange(energies.size(1), device=device)  # [T]
+        time_arange = torch.arange(energies.size(1), device=energies.device)  # [T]
         seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T]
-        energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf")))
+        energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf")))
         weights = nn.functional.softmax(energies, dim=1)  # [B,T,1]
         weights = self.att_weights_drop(weights)
         context = torch.bmm(weights.transpose(1, 2), value)  # [B,1,D_v]
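
For context, a minimal standalone sketch of the masking idiom the new lines rely on: both the arange device and the -inf fill value are derived from the energies tensor itself (via .device and new_tensor), so no explicit device argument is needed. Tensor names follow the diff; the shapes below are made up for illustration.

import torch

energies = torch.randn(2, 5, 1)     # [B,T,1]; could live on any device in real use
enc_seq_len = torch.tensor([5, 3])  # [B]

time_arange = torch.arange(energies.size(1), device=energies.device)   # [T], same device as energies
seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None])  # [B,T], True inside each sequence
# new_tensor creates the -inf fill value with the same dtype and device as energies
masked = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf")))
weights = torch.nn.functional.softmax(masked, dim=1)  # [B,T,1], padded positions get weight 0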
@@ -76,7 +74,6 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: attention config
     output_proj_dim: output projection dimension
     output_dropout: output dropout
-    device: device where to run the model (cpu or cuda)
     """

     encoder_dim: int
@@ -89,7 +86,6 @@ class AttentionLSTMDecoderV1Config:
     attention_cfg: AdditiveAttentionConfig
     output_proj_dim: int
     output_dropout: float
-    device: str


 class AttentionLSTMDecoderV1(nn.Module):
@@ -100,6 +96,7 @@ class AttentionLSTMDecoderV1(nn.Module):
     def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         super().__init__()

+        print(cfg.vocab_size)
         self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim)
         self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout)

@@ -130,10 +127,6 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config):
         self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size)
         self.output_dropout = nn.Dropout(cfg.output_dropout)

-        if "cuda" in cfg.device:
-            assert torch.cuda.is_available(), "CUDA is not available"
-        self.device = cfg.device
-
     def forward(
         self,
         encoder_outputs: torch.Tensor,
@@ -148,10 +141,10 @@ def forward(
         :param state: decoder state
         """
         if state is None:
-            zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device=self.device)
+            zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size))
             lstm_state = (zeros, zeros)
-            att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=self.device)
-            accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=self.device)
+            att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2)))
+            accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))
         else:
             lstm_state, att_context, accum_att_weights = state
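
A short sketch of the new_zeros idiom used in the hunk above: the zero tensors inherit dtype and device from encoder_outputs, so the initial decoder state automatically lives wherever the encoder output lives. The shapes and hidden size below are illustrative, not taken from the repository.

import torch

encoder_outputs = torch.randn(2, 7, 512)  # [B,T,D]; on GPU this would be a CUDA tensor
lstm_hidden_size = 256                    # illustrative value

zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), lstm_hidden_size))                        # [B,H]
att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2)))           # [B,D]
accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1))  # [B,T,1]
assert zeros.device == encoder_outputs.device and zeros.dtype == encoder_outputs.dtype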

Expand Down Expand Up @@ -187,7 +180,6 @@ def forward(
query=s_transformed,
weight_feedback=weight_feedback,
enc_seq_len=enc_seq_len,
device=self.device,
)
att_context_list.append(att_context)
accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5
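
With the device field and the CUDA assertion removed, device placement is presumably left to the caller in the standard PyTorch way: move the module with .to(device) and keep the inputs on the same device; the internally created tensors now follow the inputs. A hedged usage sketch with a stand-in module (the real class is constructed from its config object):

import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

decoder = nn.Linear(512, 128).to(device)                  # stand-in for AttentionLSTMDecoderV1(cfg).to(device)
encoder_outputs = torch.randn(2, 7, 512, device=device)   # inputs placed on the same device as the module
output = decoder(encoder_outputs)                         # internally created tensors follow the inputs
assert output.device.type == device.type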
