From 3476c0716006df79ff152d55184eb6e2944d4021 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Wed, 26 Jul 2023 17:28:38 +0000 Subject: [PATCH 01/19] implemented enc-dec-att model --- i6_models/decoder/attention.py | 185 +++++++++++++++++++++++++++++++++ tests/test_enc_dec_att.py | 47 +++++++++ 2 files changed, 232 insertions(+) create mode 100644 i6_models/decoder/attention.py create mode 100644 tests/test_enc_dec_att.py diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py new file mode 100644 index 00000000..4b95dbc3 --- /dev/null +++ b/i6_models/decoder/attention.py @@ -0,0 +1,185 @@ +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + + +@dataclass +class AdditiveAttentionConfig: + """ + Attributes: + attention_dim: attention dimension + att_weights_dropout: attention weights dropout + """ + + attention_dim: int + att_weights_dropout: float + + +class AdditiveAttention(nn.Module): + """ + Additive attention mechanism. This is defined as: + energies = v^T * tanh(h + s + beta) where beta is weight feedback information + weights = softmax(energies) + context = weights * h + """ + + def __init__(self, cfg: AdditiveAttentionConfig): + super().__init__() + self.linear = nn.Linear(cfg.attention_dim, 1, bias=False) + self.att_weights_drop = nn.Dropout(cfg.att_weights_dropout) + + def forward( + self, + key: torch.Tensor, + value: torch.Tensor, + query: torch.Tensor, + weight_feedback: torch.Tensor, + enc_seq_len: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param key: encoder keys of shape [B,T,D_k] + :param value: encoder values of shape [B,T,D_v] + :param query: query of shape [B,D_k] + :param weight_feedback: shape is [B,T,D_k] + :param enc_seq_len: [B] + :return: context [B,1,D_v], weights [B,T,1] + """ + + # all inputs are already projected + energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback)) # [B,T,1] + time_arange = torch.arange(energies.size(1)) # [T] + seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None]) # [B,T] + energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf"))) + weights = nn.functional.softmax(energies, dim=1) # [B,T,1] + weights = self.att_weights_drop(weights) + context = torch.bmm(weights.transpose(1, 2), value) # [B,1,D_v] + context = context.reshape(context.size(0), -1) # [B,D_v] + return context, weights + + +@dataclass +class AttentionLstmDecoderV1Config: + """ + Attributes: + encoder_dim: encoder dimension + vocab_size: vocabulary size + target_embed_dim: embedding dimension + target_embed_dropout: embedding dropout + lstm_hidden_size: LSTM hidden size + attention_cfg: attention config + output_proj_dim: output projection dimension + output_dropout: output dropout + """ + + encoder_dim: int + vocab_size: int + target_embed_dim: int + target_embed_dropout: float + lstm_hidden_size: int + attention_cfg: AdditiveAttentionConfig + output_proj_dim: int + output_dropout: float + + +class AttentionLstmDecoderV1(nn.Module): + """ + Single-headed Attention decoder with additive attention mechanism. 
+ """ + + def __init__(self, cfg: AttentionLstmDecoderV1Config): + super().__init__() + + self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim) + self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout) + + self.s = nn.LSTMCell( + input_size=cfg.target_embed_dim + cfg.encoder_dim, + hidden_size=cfg.lstm_hidden_size, + ) + self.s_transformed = nn.Linear(cfg.lstm_hidden_size, cfg.attention_cfg.attention_dim, bias=False) # query + + # for attention + self.enc_ctx = nn.Linear(cfg.encoder_dim, cfg.attention_cfg.attention_dim) + self.attention = AdditiveAttention(cfg.attention_cfg) + + # for weight feedback + self.inv_fertility = nn.Linear(cfg.encoder_dim, 1, bias=False) # followed by sigmoid + self.weight_feedback = nn.Linear(1, cfg.attention_cfg.attention_dim, bias=False) + + self.readout_in = nn.Linear(cfg.lstm_hidden_size + cfg.target_embed_dim + cfg.encoder_dim, cfg.output_proj_dim) + self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size) + self.output_dropout = nn.Dropout(cfg.output_dropout) + + def forward( + self, + encoder_outputs: torch.Tensor, + labels: torch.Tensor, + enc_seq_len: torch.Tensor, + state: Optional[Tuple[torch.Tensor, ...]] = None, + ): + """ + :param encoder_outputs: encoder outputs of shape [B,T,D] + :param labels: labels of shape [B,T] + :param enc_seq_len: encoder sequence lengths of shape [B,T] + :param state: decoder state + """ + if state is None: + lstm_state = None + att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2))) + accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1)) + else: + lstm_state, att_context, accum_att_weights = state + + target_embeddings = self.target_embed(labels) # [B,N,D] + target_embeddings = self.target_embed_dropout(target_embeddings) + # pad for BOS and remove last token as this represents history and last token is not used + target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :] # [B,N,D] + + enc_ctx = self.enc_ctx(encoder_outputs) # [B,T,D] + enc_inv_fertility = nn.functional.sigmoid(self.inv_fertility(encoder_outputs)) # [B,T,1] + + num_steps = labels.size(1) # N + + # collect for computing later the decoder logits outside the loop + s_list = [] + att_context_list = [] + + # decoder loop + for step in range(num_steps): + target_embed = target_embeddings[:, step, :] # [B,D] + + lstm_state = self.s(torch.cat([target_embed, att_context], dim=-1), lstm_state) + lstm_out = lstm_state[0] + s_transformed = self.s_transformed(lstm_out) # project query + s_list.append(lstm_out) + + # attention mechanism + weight_feedback = self.weight_feedback(accum_att_weights) + att_context, att_weights = self.attention( + key=enc_ctx, + value=encoder_outputs, + query=s_transformed, + weight_feedback=weight_feedback, + enc_seq_len=enc_seq_len, + ) + att_context_list.append(att_context) + accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5 + + # output layer + s_stacked = torch.stack(s_list, dim=1) # [B,N,D] + att_context_stacked = torch.stack(att_context_list, dim=1) # [B,N,D] + readout_in = self.readout_in(torch.cat([s_stacked, target_embeddings, att_context_stacked], dim=-1)) # [B,N,D] + + # maxout layer + assert readout_in.size(-1) % 2 == 0 + readout_in = readout_in.view(readout_in.size(0), readout_in.size(1), -1, 2) # [B,N,D/2,2] + readout, _ = torch.max(readout_in, dim=-1) # [B,N,D/2] + + output = self.output(readout) + decoder_logits = 
self.output_dropout(output) + + state = lstm_state, att_context, accum_att_weights + + return decoder_logits, state diff --git a/tests/test_enc_dec_att.py b/tests/test_enc_dec_att.py new file mode 100644 index 00000000..07c84296 --- /dev/null +++ b/tests/test_enc_dec_att.py @@ -0,0 +1,47 @@ +import torch +from torch import nn + +from i6_models.decoder.attention import AdditiveAttention, AdditiveAttentionConfig +from i6_models.decoder.attention import AttentionLstmDecoderV1, AttentionLstmDecoderV1Config + + +def test_additive_attention(): + cfg = AdditiveAttentionConfig(attention_dim=5, att_weights_dropout=0.1) + att = AdditiveAttention(cfg) + key = torch.rand((10, 20, 5)) + value = torch.rand((10, 20, 5)) + query = torch.rand((10, 5)) + + enc_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] + + # pass key as weight feedback for testing + context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len) + assert context.shape == (10, 5) + assert weights.shape == (10, 20, 1) + + # Testing attention weights masking: + # for first seq, the enc seq length is 10 so half the weights should be 0 + assert torch.eq(weights[0, 10:, 0], torch.tensor(0.0)).all() + # test for other seqs + assert torch.eq(weights[5, 15:, 0], torch.tensor(0.0)).all() + + +def test_encoder_decoder_attention_model(): + encoder = torch.rand((10, 20, 5)) + encoder_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] + decoder_cfg = AttentionLstmDecoderV1Config( + encoder_dim=5, + vocab_size=15, + target_embed_dim=3, + target_embed_dropout=0.1, + lstm_hidden_size=12, + attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1), + output_proj_dim=12, + output_dropout=0.1, + ) + decoder = AttentionLstmDecoderV1(decoder_cfg) + target_labels = torch.randint(low=0, high=15, size=(10, 7)) # [B,N] + + decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len) + + assert decoder_logits.shape == (10, 7, 15) From b06aac0509d21c01019917af6f9de5c2258367dd Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Wed, 26 Jul 2023 17:34:23 +0000 Subject: [PATCH 02/19] fix docs --- i6_models/decoder/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index 4b95dbc3..9bbe3059 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -44,7 +44,7 @@ def forward( :param query: query of shape [B,D_k] :param weight_feedback: shape is [B,T,D_k] :param enc_seq_len: [B] - :return: context [B,1,D_v], weights [B,T,1] + :return: attention context [B,D_v], attention weights [B,T,1] """ # all inputs are already projected From 3f5ada0db1a1b328651e8e10e7e1cd41300ea5f6 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Wed, 26 Jul 2023 17:39:07 +0000 Subject: [PATCH 03/19] fix docs --- tests/test_enc_dec_att.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_enc_dec_att.py b/tests/test_enc_dec_att.py index 07c84296..0f60eb4c 100644 --- a/tests/test_enc_dec_att.py +++ b/tests/test_enc_dec_att.py @@ -14,7 +14,7 @@ def test_additive_attention(): enc_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] - # pass key as weight feedback for testing + # pass key as weight feedback just for testing context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len) assert context.shape == (10, 5) assert weights.shape == (10, 20, 1) From 
4378fb9017fa89db3d1853de0ce3c3c7989c5562 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Thu, 27 Jul 2023 09:56:22 +0000 Subject: [PATCH 04/19] better check --- i6_models/decoder/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index 9bbe3059..b8865a32 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -109,6 +109,7 @@ def __init__(self, cfg: AttentionLstmDecoderV1Config): self.weight_feedback = nn.Linear(1, cfg.attention_cfg.attention_dim, bias=False) self.readout_in = nn.Linear(cfg.lstm_hidden_size + cfg.target_embed_dim + cfg.encoder_dim, cfg.output_proj_dim) + assert cfg.output_proj_dim % 2 == 0, "output projection dimension must be even for MaxOut" self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size) self.output_dropout = nn.Dropout(cfg.output_dropout) @@ -173,7 +174,6 @@ def forward( readout_in = self.readout_in(torch.cat([s_stacked, target_embeddings, att_context_stacked], dim=-1)) # [B,N,D] # maxout layer - assert readout_in.size(-1) % 2 == 0 readout_in = readout_in.view(readout_in.size(0), readout_in.size(1), -1, 2) # [B,N,D/2,2] readout, _ = torch.max(readout_in, dim=-1) # [B,N,D/2] From 085f9f318333ac7cda4ef2fca3ef8fca55ff5a4f Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Thu, 27 Jul 2023 09:57:36 +0000 Subject: [PATCH 05/19] better comment --- i6_models/decoder/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index b8865a32..639ca8b0 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -109,7 +109,7 @@ def __init__(self, cfg: AttentionLstmDecoderV1Config): self.weight_feedback = nn.Linear(1, cfg.attention_cfg.attention_dim, bias=False) self.readout_in = nn.Linear(cfg.lstm_hidden_size + cfg.target_embed_dim + cfg.encoder_dim, cfg.output_proj_dim) - assert cfg.output_proj_dim % 2 == 0, "output projection dimension must be even for MaxOut" + assert cfg.output_proj_dim % 2 == 0, "output projection dimension must be even for the MaxOut op of 2 pieces" self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size) self.output_dropout = nn.Dropout(cfg.output_dropout) From 523ce6c8251d06017139cfa291ca0e050c746589 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Thu, 27 Jul 2023 09:58:25 +0000 Subject: [PATCH 06/19] better comment --- i6_models/decoder/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index 639ca8b0..476cfd65 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -43,7 +43,7 @@ def forward( :param value: encoder values of shape [B,T,D_v] :param query: query of shape [B,D_k] :param weight_feedback: shape is [B,T,D_k] - :param enc_seq_len: [B] + :param enc_seq_len: encoder sequence lengths [B] :return: attention context [B,D_v], attention weights [B,T,1] """ From 396a6642a3503f58343b5e26c769bd541a99d232 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Thu, 27 Jul 2023 11:12:32 +0000 Subject: [PATCH 07/19] refactor + implement zoneout --- i6_models/decoder/attention.py | 26 ++++++++++++++++++------ tests/test_enc_dec_att.py | 37 +++++++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index 476cfd65..d2b6388f 100644 --- a/i6_models/decoder/attention.py 
+++ b/i6_models/decoder/attention.py @@ -4,6 +4,8 @@ import torch from torch import nn +from .zoneout_lstm import ZoneoutLSTMCell + @dataclass class AdditiveAttentionConfig: @@ -46,7 +48,6 @@ def forward( :param enc_seq_len: encoder sequence lengths [B] :return: attention context [B,D_v], attention weights [B,T,1] """ - # all inputs are already projected energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback)) # [B,T,1] time_arange = torch.arange(energies.size(1)) # [T] @@ -60,7 +61,7 @@ def forward( @dataclass -class AttentionLstmDecoderV1Config: +class AttentionLSTMDecoderV1Config: """ Attributes: encoder_dim: encoder dimension @@ -68,6 +69,8 @@ class AttentionLstmDecoderV1Config: target_embed_dim: embedding dimension target_embed_dropout: embedding dropout lstm_hidden_size: LSTM hidden size + zoneout_drop_h: zoneout drop probability for hidden state + zoneout_drop_c: zoneout drop probability for cell state attention_cfg: attention config output_proj_dim: output projection dimension output_dropout: output dropout @@ -78,26 +81,36 @@ class AttentionLstmDecoderV1Config: target_embed_dim: int target_embed_dropout: float lstm_hidden_size: int + zoneout_drop_h: float + zoneout_drop_c: float attention_cfg: AdditiveAttentionConfig output_proj_dim: int output_dropout: float -class AttentionLstmDecoderV1(nn.Module): +class AttentionLSTMDecoderV1(nn.Module): """ Single-headed Attention decoder with additive attention mechanism. """ - def __init__(self, cfg: AttentionLstmDecoderV1Config): + def __init__(self, cfg: AttentionLSTMDecoderV1Config): super().__init__() self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim) self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout) - self.s = nn.LSTMCell( + lstm_cell = nn.LSTMCell( input_size=cfg.target_embed_dim + cfg.encoder_dim, hidden_size=cfg.lstm_hidden_size, ) + self.lstm_hidden_size = cfg.lstm_hidden_size + # if zoneout drop probs are 0, then it is equivalent to normal LSTMCell + self.s = ZoneoutLSTMCell( + cell=lstm_cell, + zoneout_h=cfg.zoneout_drop_h, + zoneout_c=cfg.zoneout_drop_c, + ) + self.s_transformed = nn.Linear(cfg.lstm_hidden_size, cfg.attention_cfg.attention_dim, bias=False) # query # for attention @@ -127,7 +140,8 @@ def forward( :param state: decoder state """ if state is None: - lstm_state = None + zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size)) + lstm_state = (zeros, zeros) att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2))) accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1)) else: diff --git a/tests/test_enc_dec_att.py b/tests/test_enc_dec_att.py index 0f60eb4c..56d20e57 100644 --- a/tests/test_enc_dec_att.py +++ b/tests/test_enc_dec_att.py @@ -2,7 +2,7 @@ from torch import nn from i6_models.decoder.attention import AdditiveAttention, AdditiveAttentionConfig -from i6_models.decoder.attention import AttentionLstmDecoderV1, AttentionLstmDecoderV1Config +from i6_models.decoder.attention import AttentionLSTMDecoderV1, AttentionLSTMDecoderV1Config def test_additive_attention(): @@ -29,7 +29,7 @@ def test_additive_attention(): def test_encoder_decoder_attention_model(): encoder = torch.rand((10, 20, 5)) encoder_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] - decoder_cfg = AttentionLstmDecoderV1Config( + decoder_cfg = AttentionLSTMDecoderV1Config( encoder_dim=5, vocab_size=15, target_embed_dim=3, @@ -38,10 +38,41 @@ def 
test_encoder_decoder_attention_model(): attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1), output_proj_dim=12, output_dropout=0.1, + zoneout_drop_c=0.0, + zoneout_drop_h=0.0, ) - decoder = AttentionLstmDecoderV1(decoder_cfg) + decoder = AttentionLSTMDecoderV1(decoder_cfg) target_labels = torch.randint(low=0, high=15, size=(10, 7)) # [B,N] decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len) assert decoder_logits.shape == (10, 7, 15) + + +def test_zoneout_lstm_cell(): + encoder = torch.rand((10, 20, 5)) + encoder_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] + target_labels = torch.randint(low=0, high=15, size=(10, 7)) # [B,N] + + def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float): + decoder_cfg = AttentionLSTMDecoderV1Config( + encoder_dim=5, + vocab_size=15, + target_embed_dim=3, + target_embed_dropout=0.1, + lstm_hidden_size=12, + attention_cfg=AdditiveAttentionConfig(attention_dim=10, att_weights_dropout=0.1), + output_proj_dim=12, + output_dropout=0.1, + zoneout_drop_c=zoneout_drop_c, + zoneout_drop_h=zoneout_drop_h, + ) + decoder = AttentionLSTMDecoderV1(decoder_cfg) + decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len) + return decoder_logits + + decoder_logits = forward_decoder(zoneout_drop_c=0.15, zoneout_drop_h=0.05) + assert decoder_logits.shape == (10, 7, 15) + + decoder_logits = forward_decoder(zoneout_drop_c=0.0, zoneout_drop_h=0.0) + assert decoder_logits.shape == (10, 7, 15) From 9c22aa223913e1b80645afab0bcf404e0e84c480 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Thu, 27 Jul 2023 11:12:50 +0000 Subject: [PATCH 08/19] implement zoneout lstm cell --- i6_models/decoder/zoneout_lstm.py | 47 +++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 i6_models/decoder/zoneout_lstm.py diff --git a/i6_models/decoder/zoneout_lstm.py b/i6_models/decoder/zoneout_lstm.py new file mode 100644 index 00000000..fef76390 --- /dev/null +++ b/i6_models/decoder/zoneout_lstm.py @@ -0,0 +1,47 @@ +import torch +from torch import nn + +from typing import Tuple + + +class ZoneoutLSTMCell(nn.Module): + """ + Wrap an LSTM cell with Zoneout regularization (https://arxiv.org/abs/1606.01305) + """ + + def __init__(self, cell: nn.RNNCellBase, zoneout_h: float, zoneout_c: float): + """ + :param cell: LSTM cell + :param zoneout_h: zoneout drop probability for hidden state + :param zoneout_c: zoneout drop probability for cell state + """ + super().__init__() + self.cell = cell + assert 0.0 <= zoneout_h <= 1.0 and 0.0 <= zoneout_c <= 1.0, "Zoneout drop probability must be in [0, 1]" + self.zoneout_h = zoneout_h + self.zoneout_c = zoneout_c + + def forward( + self, inputs: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + h, c = self.cell(inputs) + prev_h, prev_c = state + h = self._zoneout(prev_h, h, self.zoneout_h) + c = self._zoneout(prev_c, c, self.zoneout_c) + return h, c + + def _zoneout(self, prev_state: torch.Tensor, curr_state: torch.Tensor, factor: float): + """ + Apply Zoneout. 
+ + :param prev: previous state tensor + :param curr: current state tensor + :param factor: drop probability + """ + if factor == 0.0: + return curr_state + if self.training: + mask = curr_state.new_empty(size=curr_state.size()).bernoulli_(factor) + return mask * prev_state + (1 - mask) * curr_state + else: + return factor * prev_state + (1 - factor) * curr_state From c4d57107f8658b323f20079a9c2736e9eecc46fc Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Thu, 27 Jul 2023 11:17:22 +0000 Subject: [PATCH 09/19] fix --- i6_models/decoder/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index d2b6388f..f4fb3c70 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -143,7 +143,7 @@ def forward( zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size)) lstm_state = (zeros, zeros) att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2))) - accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1)) + accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1)) else: lstm_state, att_context, accum_att_weights = state From cd366abc1c36baa20d55fa404f72326201d474f8 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Fri, 28 Jul 2023 15:58:31 +0200 Subject: [PATCH 10/19] put tensors on cuda --- i6_models/decoder/attention.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index f4fb3c70..521828f2 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -24,7 +24,7 @@ class AdditiveAttention(nn.Module): Additive attention mechanism. 
This is defined as: energies = v^T * tanh(h + s + beta) where beta is weight feedback information weights = softmax(energies) - context = weights * h + context = sum_t weights_t * h_t """ def __init__(self, cfg: AdditiveAttentionConfig): @@ -50,7 +50,7 @@ def forward( """ # all inputs are already projected energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback)) # [B,T,1] - time_arange = torch.arange(energies.size(1)) # [T] + time_arange = torch.arange(energies.size(1), device="cuda") # [T] seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None]) # [B,T] energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf"))) weights = nn.functional.softmax(energies, dim=1) # [B,T,1] @@ -140,15 +140,16 @@ def forward( :param state: decoder state """ if state is None: - zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size)) + zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device="cuda") lstm_state = (zeros, zeros) - att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2))) - accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1)) + att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device="cuda") + accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device="cuda") else: lstm_state, att_context, accum_att_weights = state target_embeddings = self.target_embed(labels) # [B,N,D] target_embeddings = self.target_embed_dropout(target_embeddings) + # pad for BOS and remove last token as this represents history and last token is not used target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :] # [B,N,D] From d0ed59b8525f336f0440a971b002334a0ecd5b20 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Mon, 31 Jul 2023 15:30:58 +0200 Subject: [PATCH 11/19] make device configurable --- i6_models/decoder/attention.py | 15 +++++++++++---- tests/test_enc_dec_att.py | 6 +++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index 521828f2..c48b46f8 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -39,6 +39,7 @@ def forward( query: torch.Tensor, weight_feedback: torch.Tensor, enc_seq_len: torch.Tensor, + device: str, ) -> Tuple[torch.Tensor, torch.Tensor]: """ :param key: encoder keys of shape [B,T,D_k] @@ -46,11 +47,12 @@ def forward( :param query: query of shape [B,D_k] :param weight_feedback: shape is [B,T,D_k] :param enc_seq_len: encoder sequence lengths [B] + :param device: device where to run the model (cpu or cuda) :return: attention context [B,D_v], attention weights [B,T,1] """ # all inputs are already projected energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback)) # [B,T,1] - time_arange = torch.arange(energies.size(1), device="cuda") # [T] + time_arange = torch.arange(energies.size(1), device=device) # [T] seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None]) # [B,T] energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf"))) weights = nn.functional.softmax(energies, dim=1) # [B,T,1] @@ -74,6 +76,7 @@ class AttentionLSTMDecoderV1Config: attention_cfg: attention config output_proj_dim: output projection dimension output_dropout: output dropout + device: device where to run the model (cpu or cuda) """ encoder_dim: int @@ -86,6 +89,7 @@ class 
AttentionLSTMDecoderV1Config: attention_cfg: AdditiveAttentionConfig output_proj_dim: int output_dropout: float + device: str class AttentionLSTMDecoderV1(nn.Module): @@ -126,6 +130,8 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config): self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size) self.output_dropout = nn.Dropout(cfg.output_dropout) + self.device = cfg.device + def forward( self, encoder_outputs: torch.Tensor, @@ -140,10 +146,10 @@ def forward( :param state: decoder state """ if state is None: - zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device="cuda") + zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device=self.device) lstm_state = (zeros, zeros) - att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device="cuda") - accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device="cuda") + att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=self.device) + accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=self.device) else: lstm_state, att_context, accum_att_weights = state @@ -179,6 +185,7 @@ def forward( query=s_transformed, weight_feedback=weight_feedback, enc_seq_len=enc_seq_len, + device=self.device, ) att_context_list.append(att_context) accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5 diff --git a/tests/test_enc_dec_att.py b/tests/test_enc_dec_att.py index 56d20e57..4ca16128 100644 --- a/tests/test_enc_dec_att.py +++ b/tests/test_enc_dec_att.py @@ -15,7 +15,9 @@ def test_additive_attention(): enc_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] # pass key as weight feedback just for testing - context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len) + context, weights = att( + key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len, device="cpu" + ) assert context.shape == (10, 5) assert weights.shape == (10, 20, 1) @@ -40,6 +42,7 @@ def test_encoder_decoder_attention_model(): output_dropout=0.1, zoneout_drop_c=0.0, zoneout_drop_h=0.0, + device="cpu", ) decoder = AttentionLSTMDecoderV1(decoder_cfg) target_labels = torch.randint(low=0, high=15, size=(10, 7)) # [B,N] @@ -66,6 +69,7 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float): output_dropout=0.1, zoneout_drop_c=zoneout_drop_c, zoneout_drop_h=zoneout_drop_h, + device="cpu", ) decoder = AttentionLSTMDecoderV1(decoder_cfg) decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len) From 5a78e40c1c2f3180b0cb49f9f2e6ea94b995e0f7 Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Mon, 31 Jul 2023 15:32:43 +0200 Subject: [PATCH 12/19] check for cuda availability --- i6_models/decoder/attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index c48b46f8..cc526951 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -130,6 +130,8 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config): self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size) self.output_dropout = nn.Dropout(cfg.output_dropout) + if "cuda" in cfg.device: + assert torch.cuda.is_available(), "CUDA is not available" self.device = cfg.device def forward( From 163315a97139e921f99d457d644972df2fd67b12 Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Wed, 2 Aug 2023 11:00:51 +0200 
Subject: [PATCH 13/19] remove explicit device, fix return value --- i6_models/decoder/attention.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index cc526951..68726d30 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -39,7 +39,6 @@ def forward( query: torch.Tensor, weight_feedback: torch.Tensor, enc_seq_len: torch.Tensor, - device: str, ) -> Tuple[torch.Tensor, torch.Tensor]: """ :param key: encoder keys of shape [B,T,D_k] @@ -47,14 +46,13 @@ def forward( :param query: query of shape [B,D_k] :param weight_feedback: shape is [B,T,D_k] :param enc_seq_len: encoder sequence lengths [B] - :param device: device where to run the model (cpu or cuda) :return: attention context [B,D_v], attention weights [B,T,1] """ # all inputs are already projected energies = self.linear(nn.functional.tanh(key + query.unsqueeze(1) + weight_feedback)) # [B,T,1] - time_arange = torch.arange(energies.size(1), device=device) # [T] + time_arange = torch.arange(energies.size(1), device=energies.device) # [T] seq_len_mask = torch.less(time_arange[None, :], enc_seq_len[:, None]) # [B,T] - energies = torch.where(seq_len_mask.unsqueeze(2), energies, torch.tensor(-float("inf"))) + energies = torch.where(seq_len_mask.unsqueeze(2), energies, energies.new_tensor(-float("inf"))) weights = nn.functional.softmax(energies, dim=1) # [B,T,1] weights = self.att_weights_drop(weights) context = torch.bmm(weights.transpose(1, 2), value) # [B,1,D_v] @@ -76,7 +74,6 @@ class AttentionLSTMDecoderV1Config: attention_cfg: attention config output_proj_dim: output projection dimension output_dropout: output dropout - device: device where to run the model (cpu or cuda) """ encoder_dim: int @@ -89,7 +86,6 @@ class AttentionLSTMDecoderV1Config: attention_cfg: AdditiveAttentionConfig output_proj_dim: int output_dropout: float - device: str class AttentionLSTMDecoderV1(nn.Module): @@ -100,6 +96,7 @@ class AttentionLSTMDecoderV1(nn.Module): def __init__(self, cfg: AttentionLSTMDecoderV1Config): super().__init__() + print(cfg.vocab_size) self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim) self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout) @@ -130,10 +127,6 @@ def __init__(self, cfg: AttentionLSTMDecoderV1Config): self.output = nn.Linear(cfg.output_proj_dim // 2, cfg.vocab_size) self.output_dropout = nn.Dropout(cfg.output_dropout) - if "cuda" in cfg.device: - assert torch.cuda.is_available(), "CUDA is not available" - self.device = cfg.device - def forward( self, encoder_outputs: torch.Tensor, @@ -148,10 +141,10 @@ def forward( :param state: decoder state """ if state is None: - zeros = torch.zeros((encoder_outputs.size(0), self.lstm_hidden_size), device=self.device) + zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size)) lstm_state = (zeros, zeros) - att_context = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(2)), device=self.device) - accum_att_weights = torch.zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1), device=self.device) + att_context = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(2))) + accum_att_weights = encoder_outputs.new_zeros((encoder_outputs.size(0), encoder_outputs.size(1), 1)) else: lstm_state, att_context, accum_att_weights = state @@ -187,7 +180,6 @@ def forward( query=s_transformed, weight_feedback=weight_feedback, enc_seq_len=enc_seq_len, - 
device=self.device, ) att_context_list.append(att_context) accum_att_weights = accum_att_weights + att_weights * enc_inv_fertility * 0.5 From 6d200eb2237a3c361a74f54f1eedcda06b68db7e Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Wed, 2 Aug 2023 11:13:27 +0200 Subject: [PATCH 14/19] remove device from attention test --- tests/test_enc_dec_att.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/test_enc_dec_att.py b/tests/test_enc_dec_att.py index 4ca16128..56d20e57 100644 --- a/tests/test_enc_dec_att.py +++ b/tests/test_enc_dec_att.py @@ -15,9 +15,7 @@ def test_additive_attention(): enc_seq_len = torch.arange(start=10, end=20) # [10, ..., 19] # pass key as weight feedback just for testing - context, weights = att( - key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len, device="cpu" - ) + context, weights = att(key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len) assert context.shape == (10, 5) assert weights.shape == (10, 20, 1) @@ -42,7 +40,6 @@ def test_encoder_decoder_attention_model(): output_dropout=0.1, zoneout_drop_c=0.0, zoneout_drop_h=0.0, - device="cpu", ) decoder = AttentionLSTMDecoderV1(decoder_cfg) target_labels = torch.randint(low=0, high=15, size=(10, 7)) # [B,N] @@ -69,7 +66,6 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float): output_dropout=0.1, zoneout_drop_c=zoneout_drop_c, zoneout_drop_h=zoneout_drop_h, - device="cpu", ) decoder = AttentionLSTMDecoderV1(decoder_cfg) decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len) From 1b8dd52be339c3148279328bbaa904fce9f26a44 Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Wed, 2 Aug 2023 12:20:27 +0200 Subject: [PATCH 15/19] remove leftover print statement --- i6_models/decoder/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index 68726d30..2d0ca276 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -96,7 +96,6 @@ class AttentionLSTMDecoderV1(nn.Module): def __init__(self, cfg: AttentionLSTMDecoderV1Config): super().__init__() - print(cfg.vocab_size) self.target_embed = nn.Embedding(num_embeddings=cfg.vocab_size, embedding_dim=cfg.target_embed_dim) self.target_embed_dropout = nn.Dropout(cfg.target_embed_dropout) From dcf038171eedbf4dcd66432a957ca69f9204c88a Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Thu, 31 Aug 2023 13:21:30 +0200 Subject: [PATCH 16/19] Add shift_embeddings flag for Attention Decoder Allows to pass the label unshifted for step-wise search without needing a separate function besides "forward". 
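For illustration only (not part of this patch): a minimal greedy step-wise search sketch using the new flag. It assumes the `decoder`, `encoder`, and `encoder_seq_len` objects from the tests in this series; `bos_idx` and `max_steps` are made-up values.

import torch

bos_idx, max_steps = 0, 10  # hypothetical BOS index and search length limit
labels = torch.full((encoder.size(0), 1), fill_value=bos_idx, dtype=torch.long)  # [B,1]
state = None
hyps = []
for step in range(max_steps):
    # step 0: shift_embeddings=True yields a zero (BOS) embedding; later steps feed the
    # previously emitted label unshifted together with the returned decoder state
    logits, state = decoder(
        encoder_outputs=encoder,
        labels=labels,
        enc_seq_len=encoder_seq_len,
        state=state,
        shift_embeddings=(step == 0),
    )
    labels = torch.argmax(logits, dim=-1)  # [B,1] greedy label
    hyps.append(labels)
hyps = torch.cat(hyps, dim=1)  # [B,max_steps]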
--- i6_models/decoder/__init__.py | 0 i6_models/decoder/attention.py | 19 ++++++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 i6_models/decoder/__init__.py diff --git a/i6_models/decoder/__init__.py b/i6_models/decoder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index 2d0ca276..a294aeac 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -132,12 +132,20 @@ def forward( labels: torch.Tensor, enc_seq_len: torch.Tensor, state: Optional[Tuple[torch.Tensor, ...]] = None, + shift_embeddings=True, ): """ - :param encoder_outputs: encoder outputs of shape [B,T,D] - :param labels: labels of shape [B,T] - :param enc_seq_len: encoder sequence lengths of shape [B,T] + :param encoder_outputs: encoder outputs of shape [B,T,D], same for training and search + :param labels: + training: labels of shape [B,N] + (greedy-)search: hypotheses last label as [B,1] + :param enc_seq_len: encoder sequence lengths of shape [B,T], same for training and search :param state: decoder state + training: Usually None, unless decoding should be initialized with a certain state (e.g. for context init) + search: current state of the active hypotheses + :param shift_embeddings: shift the embeddings by one position along U, padding with zero in front and drop last + training: this should be "True", in order to start with a zero target embedding + search: use True for the first step in order to start with a zero embedding, False otherwise """ if state is None: zeros = encoder_outputs.new_zeros((encoder_outputs.size(0), self.lstm_hidden_size)) @@ -150,8 +158,9 @@ def forward( target_embeddings = self.target_embed(labels) # [B,N,D] target_embeddings = self.target_embed_dropout(target_embeddings) - # pad for BOS and remove last token as this represents history and last token is not used - target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :] # [B,N,D] + if shift_embeddings: + # pad for BOS and remove last token as this represents history and last token is not used + target_embeddings = nn.functional.pad(target_embeddings, (0, 0, 1, 0), value=0)[:, :-1, :] # [B,N,D] enc_ctx = self.enc_ctx(encoder_outputs) # [B,T,D] enc_inv_fertility = nn.functional.sigmoid(self.inv_fertility(encoder_outputs)) # [B,T,1] From 6a147b7117aa6afa3023c6a57e4339398d5e286f Mon Sep 17 00:00:00 2001 From: Mohammad Zeineldeen Date: Wed, 29 May 2024 17:10:48 +0200 Subject: [PATCH 17/19] fix logits dropout bug --- i6_models/decoder/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index a294aeac..985f5603 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -201,8 +201,8 @@ def forward( readout_in = readout_in.view(readout_in.size(0), readout_in.size(1), -1, 2) # [B,N,D/2,2] readout, _ = torch.max(readout_in, dim=-1) # [B,N,D/2] - output = self.output(readout) - decoder_logits = self.output_dropout(output) + readout_drop = self.output_dropout(readout) + decoder_logits = self.output(readout_drop) state = lstm_state, att_context, accum_att_weights From a988e85b7c973d93f2e0b29726cdc3315c7d28d8 Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Thu, 24 Oct 2024 14:58:53 +0200 Subject: [PATCH 18/19] Update i6_models/decoder/attention.py Co-authored-by: Benedikt Hilmes --- i6_models/decoder/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/i6_models/decoder/attention.py b/i6_models/decoder/attention.py index 985f5603..f8881079 100644 --- a/i6_models/decoder/attention.py +++ b/i6_models/decoder/attention.py @@ -132,7 +132,7 @@ def forward( labels: torch.Tensor, enc_seq_len: torch.Tensor, state: Optional[Tuple[torch.Tensor, ...]] = None, - shift_embeddings=True, + shift_embeddings: bool = True, ): """ :param encoder_outputs: encoder outputs of shape [B,T,D], same for training and search From bb8fa4e690117bae6ab7694b908ee3366376c54b Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Thu, 24 Oct 2024 16:29:56 +0200 Subject: [PATCH 19/19] Disable autocast for ZoneoutLSTM cell --- i6_models/decoder/zoneout_lstm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/i6_models/decoder/zoneout_lstm.py b/i6_models/decoder/zoneout_lstm.py index fef76390..d23ba8d0 100644 --- a/i6_models/decoder/zoneout_lstm.py +++ b/i6_models/decoder/zoneout_lstm.py @@ -24,7 +24,8 @@ def __init__(self, cell: nn.RNNCellBase, zoneout_h: float, zoneout_c: float): def forward( self, inputs: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: - h, c = self.cell(inputs) + with torch.autocast(device_type="cuda", enabled=False): + h, c = self.cell(inputs) prev_h, prev_c = state h = self._zoneout(prev_h, h, self.zoneout_h) c = self._zoneout(prev_c, c, self.zoneout_c)
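For reference, a minimal usage sketch of the ZoneoutLSTMCell wrapper introduced in this series (illustrative only; the sizes are made up, the module path is the one created by these patches). With both zoneout probabilities set to 0.0 the wrapper behaves like the plain nn.LSTMCell, as noted in the decoder code.

import torch
from torch import nn

from i6_models.decoder.zoneout_lstm import ZoneoutLSTMCell

batch, in_dim, hidden = 4, 8, 12  # made-up sizes
cell = ZoneoutLSTMCell(cell=nn.LSTMCell(in_dim, hidden), zoneout_h=0.05, zoneout_c=0.15)

h = torch.zeros(batch, hidden)
c = torch.zeros(batch, hidden)
inputs = torch.rand(batch, 5, in_dim)  # [B,N,D]
for t in range(inputs.size(1)):
    h, c = cell(inputs[:, t, :], (h, c))  # zoneout mixes the new state with the previous one
assert h.shape == (batch, hidden) and c.shape == (batch, hidden)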