
Commit

remove device from attention test
JackTemaki committed Aug 2, 2023
1 parent 315b579 commit 44b9cdc
Showing 1 changed file with 1 addition and 3 deletions.
tests/test_enc_dec_att.py: 4 changes (1 addition, 3 deletions)
@@ -16,7 +16,7 @@ def test_additive_attention():
 
     # pass key as weight feedback just for testing
     context, weights = att(
-        key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len, device="cpu"
+        key=key, value=value, query=query, weight_feedback=key, enc_seq_len=enc_seq_len
     )
     assert context.shape == (10, 5)
     assert weights.shape == (10, 20, 1)
@@ -42,7 +42,6 @@ def test_encoder_decoder_attention_model():
         output_dropout=0.1,
         zoneout_drop_c=0.0,
         zoneout_drop_h=0.0,
-        device="cpu",
     )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     target_labels = torch.randint(low=0, high=15, size=(10, 7))  # [B,N]
@@ -69,7 +68,6 @@ def forward_decoder(zoneout_drop_c: float, zoneout_drop_h: float):
         output_dropout=0.1,
         zoneout_drop_c=zoneout_drop_c,
         zoneout_drop_h=zoneout_drop_h,
-        device="cpu",
     )
     decoder = AttentionLSTMDecoderV1(decoder_cfg)
     decoder_logits, _ = decoder(encoder_outputs=encoder, labels=target_labels, enc_seq_len=encoder_seq_len)
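For context on why the explicit device argument can be dropped: any tensor an attention module creates inside its forward pass (for example the padding mask) can be allocated on the device of the incoming tensors, so callers never need to pass device=... per call. Below is a minimal, self-contained sketch of that pattern. It is not the i6_models implementation; the class ToyAdditiveAttention, its layer names, and its dimensions are invented for illustration, and the weight_feedback argument used in the real test is omitted. Only the asserted output shapes (10, 5) and (10, 20, 1) are taken from the diff above.

# Hedged sketch, not the repository's code: an additive-attention module whose
# internal tensors follow the device of its inputs, so no device argument is needed.
import torch
from torch import nn


class ToyAdditiveAttention(nn.Module):
    def __init__(self, key_dim: int, query_dim: int, attn_dim: int):
        super().__init__()
        self.key_proj = nn.Linear(key_dim, attn_dim, bias=False)
        self.query_proj = nn.Linear(query_dim, attn_dim, bias=False)
        self.energy = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, key, value, query, enc_seq_len):
        # key/value: [B, T, D_k], query: [B, D_q], enc_seq_len: [B]
        scores = self.energy(
            torch.tanh(self.key_proj(key) + self.query_proj(query).unsqueeze(1))
        )  # [B, T, 1]
        # Build the padding mask on the same device as the inputs
        # instead of receiving a device argument from the caller.
        time_idx = torch.arange(key.size(1), device=key.device)  # [T]
        mask = time_idx[None, :] >= enc_seq_len[:, None]  # [B, T], True = padded
        scores = scores.masked_fill(mask.unsqueeze(-1), float("-inf"))
        weights = torch.softmax(scores, dim=1)  # [B, T, 1]
        context = torch.sum(weights * value, dim=1)  # [B, D_k]
        return context, weights


# Usage mirroring the test's shape checks (hypothetical sizes):
att = ToyAdditiveAttention(key_dim=5, query_dim=5, attn_dim=8)
key = value = torch.rand(10, 20, 5)
query = torch.rand(10, 5)
enc_seq_len = torch.randint(low=1, high=21, size=(10,))
context, weights = att(key=key, value=value, query=query, enc_seq_len=enc_seq_len)
assert context.shape == (10, 5)
assert weights.shape == (10, 20, 1)

With this style, running the same test on a GPU only requires moving the module and its inputs via .to(device); no per-call or per-config device plumbing is needed, which appears to be the rationale for removing the argument from these tests.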
