From 076160ae9f47a33cee717b068a85f810f5386e5f Mon Sep 17 00:00:00 2001
From: vasiliyeskin
Date: Thu, 5 Aug 2021 12:55:50 +0300
Subject: [PATCH 1/2] Remove the unused parameter 'pad_value'

---
 performer_pytorch/autoregressive_wrapper.py | 3 +--
 performer_pytorch/performer_enc_dec.py      | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/performer_pytorch/autoregressive_wrapper.py b/performer_pytorch/autoregressive_wrapper.py
index 8707bf7..94f3ba8 100644
--- a/performer_pytorch/autoregressive_wrapper.py
+++ b/performer_pytorch/autoregressive_wrapper.py
@@ -26,9 +26,8 @@ def top_k(logits, thres = 0.9):
     return probs
 
 class AutoregressiveWrapper(nn.Module):
-    def __init__(self, net, ignore_index = 0, pad_value = 0):
+    def __init__(self, net, ignore_index = 0):
         super().__init__()
-        self.pad_value = pad_value
         self.ignore_index = ignore_index
         self.net = net
 
diff --git a/performer_pytorch/performer_enc_dec.py b/performer_pytorch/performer_enc_dec.py
index 1a88cb3..d670fb3 100644
--- a/performer_pytorch/performer_enc_dec.py
+++ b/performer_pytorch/performer_enc_dec.py
@@ -42,7 +42,6 @@ def __init__(
         self,
         dim,
         ignore_index = 0,
-        pad_value = 0,
         tie_token_embeds = False,
         no_projection = False,
         **kwargs
@@ -65,7 +64,7 @@ def __init__(
             enc.token_emb = dec.token_emb
 
         self.enc = enc
-        self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)
+        self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index)
 
     @torch.no_grad()
     def generate(self, seq_in, seq_out_start, seq_len, **kwargs):

From 2f968ea14a6bd2bbe441e2ef98bb44ae86429c30 Mon Sep 17 00:00:00 2001
From: Vasiliy Es'kin
Date: Thu, 5 Aug 2021 13:01:35 +0300
Subject: [PATCH 2/2] Remove the extra commas in the Performer and Attention definitions

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 2b84fbc..ee6957d 100644
--- a/README.md
+++ b/README.md
@@ -94,7 +94,7 @@ enc_dec = PerformerEncDec(
     dec_num_tokens = 20000,
     dec_depth = 6,
     dec_heads = 8,
-    dec_max_seq_len = TGT_SEQ_LEN,
+    dec_max_seq_len = TGT_SEQ_LEN
 )
 
 src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
@@ -124,7 +124,7 @@ from performer_pytorch import SelfAttention
 attn = SelfAttention(
     dim = 512,
     heads = 8,
-    causal = False,
+    causal = False
 ).cuda()
 
 x = torch.randn(1, 1024, 512).cuda()
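
A minimal sketch of how the touched call sites look once both patches are applied. The SelfAttention arguments come from the README hunk above; the PerformerLM argument values and the direct AutoregressiveWrapper usage are illustrative assumptions and are not part of these diffs.

import torch
from performer_pytorch import PerformerLM, SelfAttention
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# PATCH 1/2: the wrapper no longer accepts (or stores) pad_value; only
# ignore_index is passed through, as PerformerEncDec now does internally.
lm = PerformerLM(          # argument values assumed for illustration, not taken from this diff
    num_tokens = 20000,
    max_seq_len = 1024,
    dim = 512,
    depth = 6,
    heads = 8,
    causal = True
)
wrapped_lm = AutoregressiveWrapper(lm, ignore_index = 0)

# PATCH 2/2: the README definition now ends without a trailing comma.
attn = SelfAttention(
    dim = 512,
    heads = 8,
    causal = False
)

x = torch.randn(1, 1024, 512)
out = attn(x)   # expected shape: (1, 1024, 512)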