clean up encoder / decoder code
lucidrains committed Dec 8, 2020
1 parent 36bcb03 commit 0c92b80
Showing 4 changed files with 34 additions and 72 deletions.
68 changes: 12 additions & 56 deletions README.md
@@ -66,85 +66,41 @@ x = torch.randn(1, 2048, 512)
model(x) # (1, 2048, 512)
```

Full encoder / decoder
Encoder / Decoder - Made possible by <a href="https://github.com/gulnazaki">Thomas Melistas</a>

```python
import torch
from performer_pytorch import PerformerLM

enc = PerformerLM(
num_tokens = 20000,
max_seq_len = 2048,
dim = 512,
depth = 6,
heads = 8
).cuda()

dec = PerformerLM(
num_tokens = 20000,
max_seq_len = 2048,
dim = 512,
depth = 6,
heads = 8,
causal = True,
cross_attend = True
).cuda()

src = torch.randint(0, 20000, (1, 2048)).cuda()
tgt = torch.randint(0, 20000, (1, 2048)).cuda()
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

encodings = enc(src, mask = src_mask, return_encodings = True)
logits = dec(tgt, context = encodings, mask = tgt_mask, context_mask = src_mask) # (1, 2048, 20000)
```

You can also use the Performer Encoder / Decoder architecture, made by <a href="https://github.com/gulnazaki">Thomas Melistas</a>.


```python
from performer_pytorch import PerformerEncDec
import torch

IN_SEQ_LEN = 4096
OUT_SEQ_LEN = 4096
SRC_SEQ_LEN = 4096
TGT_SEQ_LEN = 4096
GENERATE_LEN = 512

enc_dec = PerformerEncDec(
dim = 512,
tie_token_embeds = True,
enc_num_tokens = 20000,
enc_depth = 6,
enc_heads = 8,
enc_max_seq_len = IN_SEQ_LEN,
enc_max_seq_len = SRC_SEQ_LEN,
dec_num_tokens = 20000,
dec_depth = 6,
dec_heads = 8,
dec_max_seq_len = OUT_SEQ_LEN
dec_max_seq_len = TGT_SEQ_LEN,
)

# if you have variable length sequences padding is done for you
train_in = [
torch.randint(0, 20000, (120,)).long(),
torch.randint(0, 20000, (253,)).long(),
torch.randint(0, 20000, (646,)).long()
]
train_out = [
torch.randint(0, 20000, (110,)).long(),
torch.randint(0, 20000, (500,)).long(),
torch.randint(0, 20000, (585,)).long()
]

# you have to use masks for variable length sequences (decoder mask should be 1 smaller than longest tensor)
in_mask = torch.arange(646).view(1, -1).expand(3, -1) < torch.tensor([120,253,646]).view(-1, 1)
out_mask = torch.arange(584).view(1, -1).expand(3, -1) < torch.tensor([110,500,585]).view(-1, 1)
src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
src_mask = torch.ones_like(src).bool()
tgt_mask = torch.ones_like(tgt).bool()

# train
enc_dec.train()
loss = enc_dec(train_in, train_out, return_loss = True, enc_mask = in_mask, dec_mask = out_mask)
loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask, return_loss = True)
loss.backward()

# generate
generate_in = torch.randint(0, 20000, (1, IN_SEQ_LEN)).long()
generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN)).long()
generate_out_prime = torch.tensor([[0.]]).long() # prime with <bos> token
samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
print(samples.shape) # (1, <= GENERATE_LEN) decode the tokens
```
22 changes: 9 additions & 13 deletions performer_pytorch/autoregressive_wrapper.py
@@ -84,19 +84,15 @@ def forward(self, x, return_loss = False, **kwargs):
x = pad(x)
return self.net(x, **kwargs)

if isinstance(x, torch.Tensor):
xi = x[:, :-1]
xo = x[:, 1:]

# help auto-solve an area of confusion around input masks in auto-regressive
# if user supplies a mask that is only off by one from the source sequence, resolve it for them
mask = kwargs.pop('mask', None)
if mask is not None and mask.shape[1] == x.shape[1]:
mask = mask[:, :-1]
kwargs.update(mask = mask)
else:
xi = pad(list(map(lambda t: t[:-1], x)))
xo = pad(list(map(lambda t: t[1:], x)))
xi = x[:, :-1]
xo = x[:, 1:]

# help auto-resolve an area of confusion around input masks in auto-regressive training
# if user supplies a mask that is only off by one from the source sequence, resolve it for them
mask = kwargs.pop('mask', None)
if mask is not None and mask.shape[1] == x.shape[1]:
mask = mask[:, :-1]
kwargs.update(mask = mask)

out = self.net(xi, **kwargs)

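For readers skimming the diff: the wrapper now assumes its input is a single padded tensor, shifts it by one to build inputs and targets, and trims a full-length mask so it lines up with the shifted inputs. A minimal sketch of that logic (the helper name and shapes below are illustrative, not part of the library API):

```python
import torch

def shift_for_autoregressive(x, mask = None):
    # inputs are every token except the last, targets are every token except the first
    xi, xo = x[:, :-1], x[:, 1:]
    # if the supplied mask still covers the full sequence, trim it by one
    # so it matches the shifted inputs
    if mask is not None and mask.shape[1] == x.shape[1]:
        mask = mask[:, :-1]
    return xi, xo, mask

x = torch.randint(0, 20000, (2, 2048))
mask = torch.ones_like(x).bool()        # full-length mask, one column longer than xi
xi, xo, mask = shift_for_autoregressive(x, mask)
print(xi.shape, xo.shape, mask.shape)   # all torch.Size([2, 2047])
```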
14 changes: 12 additions & 2 deletions performer_pytorch/performer_enc_dec.py
@@ -37,7 +37,14 @@ def extract_and_set_enc_dec_kwargs(kwargs):
return enc_kwargs, dec_kwargs, kwargs

class PerformerEncDec(nn.Module):
def __init__(self, dim, ignore_index = 0, pad_value = 0, **kwargs):
def __init__(
self,
dim,
ignore_index = 0,
pad_value = 0,
tie_token_embeds = False,
**kwargs
):
super().__init__()
enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)

@@ -50,7 +57,10 @@ def __init__(self, dim, ignore_index = 0, pad_value = 0, **kwargs):
enc = PerformerLM(**enc_kwargs)
dec = PerformerLM(**dec_kwargs)

self.enc = AutoregressiveWrapper(enc, ignore_index = ignore_index, pad_value = pad_value)
if tie_token_embeds:
enc.token_embed = dec.token_embed

self.enc = enc
self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)

def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
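The new `tie_token_embeds` flag simply points the encoder at the decoder's token embedding so both share a single weight matrix. A rough sketch of the idea, with a toy module standing in for `PerformerLM` (the class here is illustrative, not the real model):

```python
import torch
from torch import nn

class ToyLM(nn.Module):
    # toy stand-in for a language model, just enough to show embedding tying
    def __init__(self, num_tokens = 20000, dim = 512):
        super().__init__()
        self.token_embed = nn.Embedding(num_tokens, dim)

    def forward(self, x):
        return self.token_embed(x)

enc, dec = ToyLM(), ToyLM()

# what tie_token_embeds = True amounts to: reuse the decoder's embedding module
# in the encoder, so gradients from both sides update one parameter tensor
enc.token_embed = dec.token_embed
assert enc.token_embed.weight is dec.token_embed.weight
```

Note also that only the decoder stays wrapped in `AutoregressiveWrapper`; the encoder now just produces encodings that the decoder cross-attends to.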
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'performer-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.12.0',
version = '0.12.1',
license='MIT',
description = 'Performer - Pytorch',
author = 'Phil Wang',
