use absolute positional embedding and token shift as default
lucidrains committed Apr 26, 2023
1 parent c5e75cd commit 969828a
Showing 3 changed files with 38 additions and 11 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -71,10 +71,9 @@ loss = model(seq, memory_replay_backprop = True) # memory efficient training fro

## Todo

- [ ] add an axial attention down the past memories axis as an option
- [ ] add sliding windows as an option, detached or with memory propagated using MRBP
- [ ] offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift

- [x] offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift
- [x] make memories being causally masked an option
- [x] add the memory replay backprop technique from memformer paper
- [x] relative positional encoding
@@ -1,4 +1,5 @@
import math
from functools import partial
from itertools import zip_longest
from contextlib import nullcontext

@@ -15,6 +16,9 @@
def exists(val):
return val is not None

def identity(t, *args, **kwargs):
return t

def default(*vals):
for val in vals:
if exists(val):
@@ -52,6 +56,13 @@ def top_k(logits, thres = 0.9):
probs.scatter_(1, ind, val)
return probs

def token_shift_fn(t, ps):
read_mem, t, write_mem = unpack(t, ps, 'b * d')
t, t_shift = t.chunk(2, dim = -1)
t_shift = F.pad(t_shift, (0, 0, 1, -1), value = 0.)
t = torch.cat((t, t_shift), dim = -1)
return torch.cat((read_mem, t, write_mem), dim = -2)
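
The token shift above mixes half of each position's feature channels with the previous position's channels, while leaving the read and write memory tokens untouched. Below is a minimal, self-contained sketch of the same steps with made-up tensors and lengths (none of the shapes are from the repository):

```python
import torch
import torch.nn.functional as F
from einops import pack, unpack

# hypothetical tensors: 4 read memories, 16 sequence tokens, 4 write memories, dim 8
read_mem  = torch.randn(1, 4, 8)
seq       = torch.randn(1, 16, 8)
write_mem = torch.randn(1, 4, 8)

t, ps = pack((read_mem, seq, write_mem), 'b * d')

# same steps as token_shift_fn: only the main sequence is shifted
r, s, w = unpack(t, ps, 'b * d')
s_keep, s_shift = s.chunk(2, dim = -1)
s_shift = F.pad(s_shift, (0, 0, 1, -1), value = 0.)   # shift right by one timestep
s = torch.cat((s_keep, s_shift), dim = -1)
out = torch.cat((r, s, w), dim = -2)

print(out.shape)  # torch.Size([1, 24, 8]) - memories pass through unchanged
```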

# rotary embedding

class RotaryEmbedding(nn.Module):
@@ -170,6 +181,9 @@ def __init__(
ff_mult = 4,
use_flash_attn = False,
ignore_index = -1,
abs_pos_emb = True,
rotary_pos_emb = False,
token_shift = True,
memory_not_causal = True # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
):
super().__init__()
@@ -182,8 +196,13 @@ def __init__(

# positions

self.pos_emb = nn.Embedding(seq_len, dim)
self.rotary_pos_emb = RotaryEmbedding(dim_head)
assert any([abs_pos_emb, rotary_pos_emb, token_shift])

self.pos_emb = nn.Embedding(seq_len, dim) if abs_pos_emb else None

self.rotary_pos_emb = RotaryEmbedding(dim_head) if rotary_pos_emb else None

self.maybe_token_shift = token_shift_fn if token_shift else identity
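
With this commit, absolute positional embeddings and token shift are on by default and rotary embeddings are off; the assert only requires that at least one of the three stays enabled. The sketch below shows how the flags might be passed at construction time. The class name and every kwarg other than `seq_len`, `dim_head`, and the three new flags are assumptions not confirmed by this diff:

```python
# hypothetical usage sketch - kwargs marked "assumed" are not shown in this diff
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 256,           # assumed
    dim = 512,                  # assumed
    depth = 6,                  # assumed
    num_memory_tokens = 128,    # assumed
    seq_len = 1024,
    dim_head = 64,
    abs_pos_emb = True,         # new default: absolute positional embedding on
    rotary_pos_emb = False,     # new default: rotary embeddings off
    token_shift = True          # new default: token shift on
)

# rotary-only variant - the assert demands at least one scheme remain enabled
rotary_model = RecurrentMemoryTransformer(
    num_tokens = 256, dim = 512, depth = 6, num_memory_tokens = 128,
    seq_len = 1024, dim_head = 64,
    abs_pos_emb = False, rotary_pos_emb = True, token_shift = False
)
```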

# memory related

@@ -236,7 +255,11 @@ def forward(
pos = torch.arange(n, device = device)

x = self.token_emb(x)
x = x + self.pos_emb(pos)

# maybe absolute positional embedding

if exists(self.pos_emb):
x = x + self.pos_emb(pos)

# prepare read and write memories, as in paper

@@ -265,16 +288,21 @@ def forward(

# rotary embedding - offset main positions by 10000, and keep all memories at position 0

pos = pos + 10000
pos = F.pad(pos, (read_mem_length, mem_length), value = 0)
rotary_emb = None

if exists(self.rotary_pos_emb):
pos = pos + 10000
pos = F.pad(pos, (read_mem_length, mem_length), value = 0)

rotary_emb = self.rotary_pos_emb(pos)

rotary_emb = self.rotary_pos_emb(pos)
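
When rotary embeddings are enabled, the main-sequence positions are offset by 10000 while both memory segments are padded in at position 0, so the memories carry no rotary offset. A standalone sketch of that position bookkeeping, using hypothetical lengths:

```python
import torch
import torch.nn.functional as F

# hypothetical lengths: 6 main tokens, 2 read memories, 2 write memories
n, read_mem_length, mem_length = 6, 2, 2

pos = torch.arange(n) + 10000                               # offset main positions
pos = F.pad(pos, (read_mem_length, mem_length), value = 0)  # memories pinned to position 0

print(pos)  # tensor([0, 0, 10000, 10001, 10002, 10003, 10004, 10005, 0, 0])
```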
shift_fn = partial(self.maybe_token_shift, ps = ps)

# attention and feedforward

for attn, ff in self.layers:
x = attn(x, mask = mask, rotary_emb = rotary_emb) + x
x = ff(x) + x
x = attn(shift_fn(x), mask = mask, rotary_emb = rotary_emb) + x
x = ff(shift_fn(x)) + x

# split out memories using unpack

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'recurrent-memory-transformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.5',
version = '0.1.6',
license='MIT',
description = 'Recurrent Memory Transformer - Pytorch',
author = 'Phil Wang',
