diff --git a/README.md b/README.md
index ecfa4eb..d519a13 100644
--- a/README.md
+++ b/README.md
@@ -66,85 +66,41 @@ x = torch.randn(1, 2048, 512)
 model(x) # (1, 2048, 512)
 ```
 
-Full encoder / decoder
+Encoder / Decoder - Made possible by Thomas Melistas
 
 ```python
 import torch
-from performer_pytorch import PerformerLM
-
-enc = PerformerLM(
-    num_tokens = 20000,
-    max_seq_len = 2048,
-    dim = 512,
-    depth = 6,
-    heads = 8
-).cuda()
-
-dec = PerformerLM(
-    num_tokens = 20000,
-    max_seq_len = 2048,
-    dim = 512,
-    depth = 6,
-    heads = 8,
-    causal = True,
-    cross_attend = True
-).cuda()
-
-src = torch.randint(0, 20000, (1, 2048)).cuda()
-tgt = torch.randint(0, 20000, (1, 2048)).cuda()
-src_mask = torch.ones_like(src).bool()
-tgt_mask = torch.ones_like(src).bool()
-
-encodings = enc(src, mask = src_mask, return_encodings = True)
-logits = dec(tgt, context = encodings, mask = tgt_mask, context_mask = src_mask) # (1, 2048, 20000)
-```
-
-You can also use the Performer Encoder Decoder Architecture. Made by Thomas Melistas
-
-
-```python
 from performer_pytorch import PerformerEncDec
-import torch
 
-IN_SEQ_LEN = 4096
-OUT_SEQ_LEN = 4096
+SRC_SEQ_LEN = 4096
+TGT_SEQ_LEN = 4096
 GENERATE_LEN = 512
 
 enc_dec = PerformerEncDec(
     dim = 512,
+    tie_token_embeds = True,
     enc_num_tokens = 20000,
     enc_depth = 6,
     enc_heads = 8,
-    enc_max_seq_len = IN_SEQ_LEN,
+    enc_max_seq_len = SRC_SEQ_LEN,
     dec_num_tokens = 20000,
     dec_depth = 6,
     dec_heads = 8,
-    dec_max_seq_len = OUT_SEQ_LEN
+    dec_max_seq_len = TGT_SEQ_LEN,
 )
 
-# if you have variable length sequences padding is done for you
-train_in = [
-    torch.randint(0, 20000, (120,)).long(),
-    torch.randint(0, 20000, (253,)).long(),
-    torch.randint(0, 20000, (646,)).long()
-]
-train_out = [
-    torch.randint(0, 20000, (110,)).long(),
-    torch.randint(0, 20000, (500,)).long(),
-    torch.randint(0, 20000, (585,)).long()
-]
-
-# you have to use masks for variable length sequences (decoder mask should be 1 smaller than longest tensor)
-in_mask = torch.arange(646).view(1, -1).expand(3, -1) < torch.tensor([120,253,646]).view(-1, 1)
-out_mask = torch.arange(584).view(1, -1).expand(3, -1) < torch.tensor([110,500,585]).view(-1, 1)
+src = torch.randint(0, 20000, (1, SRC_SEQ_LEN))
+tgt = torch.randint(0, 20000, (1, TGT_SEQ_LEN))
+src_mask = torch.ones_like(src).bool()
+tgt_mask = torch.ones_like(tgt).bool()
 
 # train
 enc_dec.train()
-loss = enc_dec(train_in, train_out, return_loss = True, enc_mask = in_mask, dec_mask = out_mask)
+loss = enc_dec(src, tgt, enc_mask = src_mask, dec_mask = tgt_mask, return_loss = True)
 loss.backward()
 
 # generate
-generate_in = torch.randint(0, 20000, (1, IN_SEQ_LEN)).long()
+generate_in = torch.randint(0, 20000, (1, SRC_SEQ_LEN)).long()
 generate_out_prime = torch.tensor([[0.]]).long() # prime with <bos> token
 samples = enc_dec.generate(generate_in, generate_out_prime, seq_len = GENERATE_LEN, eos_token = 1) # assume 1 is id of stop token
 print(samples.shape) # (1, <= GENERATE_LEN) decode the tokens
diff --git a/performer_pytorch/autoregressive_wrapper.py b/performer_pytorch/autoregressive_wrapper.py
index cb980ab..a28d687 100644
--- a/performer_pytorch/autoregressive_wrapper.py
+++ b/performer_pytorch/autoregressive_wrapper.py
@@ -84,19 +84,15 @@ def forward(self, x, return_loss = False, **kwargs):
                 x = pad(x)
             return self.net(x, **kwargs)
 
-        if isinstance(x, torch.Tensor):
-            xi = x[:, :-1]
-            xo = x[:, 1:]
-
-            # help auto-solve an area of confusion around input masks in auto-regressive
-            # if user supplies a mask that is only off by one from the source sequence, resolve it for them
-            mask = kwargs.pop('mask', None)
-            if mask is not None and mask.shape[1] == x.shape[1]:
-                mask = mask[:, :-1]
-                kwargs.update(mask = mask)
-        else:
-            xi = pad(list(map(lambda t: t[:-1], x)))
-            xo = pad(list(map(lambda t: t[1:], x)))
+        xi = x[:, :-1]
+        xo = x[:, 1:]
+
+        # help auto-solve an area of confusion around input masks in auto-regressive
+        # if user supplies a mask that is only off by one from the source sequence, resolve it for them
+        mask = kwargs.pop('mask', None)
+        if mask is not None and mask.shape[1] == x.shape[1]:
+            mask = mask[:, :-1]
+            kwargs.update(mask = mask)
 
         out = self.net(xi, **kwargs)
 
diff --git a/performer_pytorch/performer_enc_dec.py b/performer_pytorch/performer_enc_dec.py
index b27db21..6bb5c05 100644
--- a/performer_pytorch/performer_enc_dec.py
+++ b/performer_pytorch/performer_enc_dec.py
@@ -37,7 +37,14 @@ def extract_and_set_enc_dec_kwargs(kwargs):
     return enc_kwargs, dec_kwargs, kwargs
 
 class PerformerEncDec(nn.Module):
-    def __init__(self, dim, ignore_index = 0, pad_value = 0, **kwargs):
+    def __init__(
+        self,
+        dim,
+        ignore_index = 0,
+        pad_value = 0,
+        tie_token_embeds = False,
+        **kwargs
+    ):
         super().__init__()
         enc_kwargs, dec_kwargs, _ = extract_enc_dec_kwargs(kwargs)
 
@@ -50,7 +57,10 @@ def __init__(self, dim, ignore_index = 0, pad_value = 0, **kwargs):
         enc = PerformerLM(**enc_kwargs)
         dec = PerformerLM(**dec_kwargs)
 
-        self.enc = AutoregressiveWrapper(enc, ignore_index = ignore_index, pad_value = pad_value)
+        if tie_token_embeds:
+            enc.token_emb = dec.token_emb
+
+        self.enc = enc
         self.dec = AutoregressiveWrapper(dec, ignore_index = ignore_index, pad_value = pad_value)
 
     def generate(self, seq_in, seq_out_start, seq_len, **kwargs):
diff --git a/setup.py b/setup.py
index c366128..1647404 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'performer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.12.0',
+  version = '0.12.1',
   license='MIT',
   description = 'Performer - Pytorch',
   author = 'Phil Wang',
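
Note on the `AutoregressiveWrapper.forward` change above: when the supplied mask covers the full sequence, it is trimmed by one so it lines up with the shifted input `x[:, :-1]`. Below is a minimal standalone sketch of that off-by-one resolution; the helper name `shift_for_autoregressive` is made up for illustration and is not part of the library.

```python
import torch

# Illustrative sketch (not library code) mirroring the mask handling added above:
# a mask that matches the full sequence length is trimmed by one so it aligns
# with the shifted input used for next-token prediction.
def shift_for_autoregressive(x, mask = None):
    xi, xo = x[:, :-1], x[:, 1:]          # model input / prediction targets
    if mask is not None and mask.shape[1] == x.shape[1]:
        mask = mask[:, :-1]               # resolve the off-by-one for the user
    return xi, xo, mask

x = torch.randint(0, 20000, (2, 10))
mask = torch.ones(2, 10).bool()
xi, xo, mask = shift_for_autoregressive(x, mask)
print(xi.shape, xo.shape, mask.shape)     # each is torch.Size([2, 9])
```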