diff --git a/README.md b/README.md
index ecfa4eb..d519a13 100644
--- a/README.md
+++ b/README.md
@@ -66,85 +66,41 @@ x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)
```
-Full encoder / decoder
+Encoder / Decoder - Made possible by Thomas Melistas
```python
import torch
-from performer_pytorch import PerformerLM
-
-enc = PerformerLM(
- num_tokens = 20000,
- max_seq_len = 2048,
- dim = 512,
- depth = 6,
- heads = 8
-).cuda()
-
-dec = PerformerLM(
- num_tokens = 20000,
- max_seq_len = 2048,
- dim = 512,
- depth = 6,
- heads = 8,
- causal = True,
- cross_attend = True
-).cuda()
-
-src = torch.randint(0, 20000, (1, 2048)).cuda()
-tgt = torch.randint(0, 20000, (1, 2048)).cuda()
-src_mask = torch.ones_like(src).bool()
-tgt_mask = torch.ones_like(src).bool()
-
-encodings = enc(src, mask = src_mask, return_encodings = True)
-logits = dec(tgt, context = encodings, mask = tgt_mask, context_mask = src_mask) # (1, 2048, 20000)
-```
-
-You can also use the Performer Encoder Decoder Architecture. Made by Thomas Melistas
-
-
-```python
from performer_pytorch import PerformerEncDec
-import torch
-IN_SEQ_LEN = 4096
-OUT_SEQ_LEN = 4096
+SRC_SEQ_LEN = 4096
+TGT_SEQ_LEN = 4096
GENERATE_LEN = 512
enc_dec = PerformerEncDec(
dim = 512,
+ tie_token_embeds = True,
enc_num_tokens = 20000,
enc_depth = 6,
enc_heads = 8,
- enc_max_seq_len = IN_SEQ_LEN,
+ enc_max_seq_len = SRC_SEQ_LEN,
dec_num_tokens = 20000,
dec_depth = 6,
dec_heads = 8,
- dec_max_seq_len = OUT_SEQ_LEN
+ dec_max_seq_len = TGT_SEQ_LEN,
)
-# if you have variable length sequences padding is done for you
-train_in = [
- torch.randint(0, 20000, (120,)).long(),
- torch.randint(0, 20000, (253,)).long(),
- torch.randint(0, 20000, (646,)).long()
-]
-train_out = [
- torch.randint(0, 20000, (110,)).long(),
- torch.randint(0, 20000, (500,)).long(),
- torch.randint(0, 20000, (585,)).long()
-]
-
-# you have to use masks for variable length sequences (decoder mask should be 1 smaller than longest tensor)
-in_mask = torch.arange(646).view(1, -1).expand(3, -1) < torch.tensor([120,253,646]).view(-1, 1)
-out_mask = torch.arange(584).view(1, -1).expand(3, -1) < torch.tensor([110,500,585]).view(-1, 1)
+src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
+tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
+src_mask = torch.ones_like(src).bool()
+tgt_mask = torch.ones_like(tgt).bool()
# train
enc_dec.train()
-loss = enc_dec(train_in, train_out, return_loss = True, enc_mask = in_mask, dec_mask = out_mask)
+loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask, return_loss = True)
loss.backward()
# generate
-generate_in = torch.randint(0, 20000, (1, IN_SEQ_LEN)).long()
+generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN)).long()
generate_out_prime = torch.tensor([[0.]]).long() # prime with token
samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= GENERATE_LEN) decode the tokens
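```
+
+Setting `tie_token_embeds = True` shares a single token embedding table between the encoder and decoder, which assumes a shared vocabulary (`enc_num_tokens == dec_num_tokens`, both 20000 above). A minimal sanity check, as a sketch (assuming the decoder wrapper exposes the underlying Performer as `.net`, as in `autoregressive_wrapper.py`):
+
+```python
+# with tie_token_embeds = True, encoder and decoder point at the same Embedding module
+assert enc_dec.enc.token_embed is enc_dec.dec.net.token_embed
+```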
diff --git a/performer_pytorch/autoregressive_wrapper.py b/performer_pytorch/autoregressive_wrapper.py
index cb980ab..a28d687 100644
--- a/performer_pytorch/autoregressive_wrapper.py
+++ b/performer_pytorch/autoregressive_wrapper.py
@@ -84,19 +84,15 @@ def forward(self, x, return_loss = False, **kwargs):
x = pad(x)
return self.net(x, **kwargs)
- if isinstance(x, torch.Tensor):
- xi = x[:, :-1]
- xo = x[:, 1:]
-
- # help auto-solve an area of confusion around input masks in auto-regressive
- # if user supplies a mask that is only off by one from the source sequence, resolve it for them
- mask = kwargs.pop('mask', None)
- if mask is not None and mask.shape[1] == x.shape[1]:
- mask = mask[:, :-1]
- kwargs.update(mask = mask)
- else:
- xi = pad(list(map(lambda t: t[:-1], x)))
- xo = pad(list(map(lambda t: t[1:], x)))
+ xi = x[:, :-1]
+ xo = x[:, 1:]
+
+ # help resolve a common point of confusion around input masks in autoregressive training:
+ # if the user supplies a mask sized to the full source sequence, trim it by one to match the shifted input
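+ # e.g. x: (b, n) with mask: (b, n) -> mask becomes (b, n - 1) to align with xi;
+ # a mask already shaped (b, n - 1) is passed through unchanged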
+ mask = kwargs.pop('mask', None)
+ if mask is not None and mask.shape[1] == x.shape[1]:
+ mask = mask[:, :-1]
+ kwargs.update(mask = mask)
out = self.net(xi, **kwargs)
diff --git a/performer_pytorch/performer_enc_dec.py b/performer_pytorch/performer_enc_dec.py
index b27db21..6bb5c05 100644
--- a/performer_pytorch/performer_enc_dec.py
+++ b/performer_pytorch/performer_enc_dec.py
@@ -37,7 +37,14 @@ def extract_and_set_enc_dec_kwargs(kwargs):
return enc_kwargs, dec_kwargs, kwargs
class PerformerEncDec(nn.Module):
- def __init__(self, dim, ignore_index = 0, pad_value = 0, **kwargs):
+ def __init__(
+ self,
+ dim,
+ ignore_index = 0,
+ pad_value = 0,
+ tie_token_embeds = False,
+ **kwargs
+ ):
super().__init__()
enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)
@@ -50,7 +57,10 @@ def __init__(self, dim, ignore_index = 0, pad_value = 0, **kwargs):
enc = PerformerLM(**enc_kwargs)
dec = PerformerLM(**dec_kwargs)
- self.enc = AutoregressiveWrapper(enc, ignore_index = ignore_index, pad_value = pad_value)
+ if tie_token_embeds:
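+ # share a single token embedding table between encoder and decoder
+ # (assumes enc_num_tokens == dec_num_tokens)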
+ enc.token_embed = dec.token_embed
+
+ self.enc = enc
self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)
def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
diff --git a/setup.py b/setup.py
index c366128..1647404 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
setup(
name = 'performer-pytorch',
packages = find_packages(exclude=['examples']),
- version = '0.12.0',
+ version = '0.12.1',
license='MIT',
description = 'Performer - Pytorch',
author = 'Phil Wang',