diff --git a/README.md b/README.md
index 645792c..c40cc97 100644
--- a/README.md
+++ b/README.md
@@ -71,10 +71,9 @@ loss = model(seq, memory_replay_backprop = True) # memory efficient training fro
 
 ## Todo
 
-- [ ] add an axial attention down the past memories axis as an option
 - [ ] add sliding windows as an option, detached or with memory propagated using MRBP
 
-- [ ] offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift
+- [x] offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift
 - [x] make memories being causally masked an option
 - [x] add the memory replay backprop technique from memformer paper
 - [x] relative positional encoding
diff --git a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
index e86989e..e9b4ed9 100644
--- a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
+++ b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
@@ -1,4 +1,5 @@
 import math
+from functools import partial
 from itertools import zip_longest
 from contextlib import nullcontext
 
@@ -15,6 +16,9 @@
 def exists(val):
     return val is not None
 
+def identity(t, *args, **kwargs):
+    return t
+
 def default(*vals):
     for val in vals:
         if exists(val):
@@ -52,6 +56,13 @@ def top_k(logits, thres = 0.9):
     probs.scatter_(1, ind, val)
     return probs
 
+def token_shift_fn(t, ps):
+    read_mem, t, write_mem = unpack(t, ps, 'b * d')
+    t, t_shift = t.chunk(2, dim = -1)
+    t_shift = F.pad(t_shift, (0, 0, 1, -1), value = 0.)
+    t = torch.cat((t, t_shift), dim = -1)
+    return torch.cat((read_mem, t, write_mem), dim = -2)
+
 # rotary embedding
 
 class RotaryEmbedding(nn.Module):
@@ -170,6 +181,9 @@ def __init__(
         ff_mult = 4,
         use_flash_attn = False,
         ignore_index = -1,
+        abs_pos_emb = True,
+        rotary_pos_emb = False,
+        token_shift = True,
         memory_not_causal = True # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
     ):
         super().__init__()
@@ -182,8 +196,13 @@ def __init__(
 
         # positions
 
-        self.pos_emb = nn.Embedding(seq_len, dim)
-        self.rotary_pos_emb = RotaryEmbedding(dim_head)
+        assert any([abs_pos_emb, rotary_pos_emb, token_shift])
+
+        self.pos_emb = nn.Embedding(seq_len, dim) if abs_pos_emb else None
+
+        self.rotary_pos_emb = RotaryEmbedding(dim_head) if rotary_pos_emb else None
+
+        self.maybe_token_shift = token_shift_fn if token_shift else identity
 
         # memory related
 
@@ -236,7 +255,11 @@ def forward(
         pos = torch.arange(n, device = device)
 
         x = self.token_emb(x)
-        x = x + self.pos_emb(pos)
+
+        # maybe absolute positional embedding
+
+        if exists(self.pos_emb):
+            x = x + self.pos_emb(pos)
 
         # prepare read and write memories, as in paper
 
@@ -265,16 +288,21 @@ def forward(
 
         # rotary embedding - offset main positions by 10000, and keep all memories at position 0
 
-        pos = pos + 10000
-        pos = F.pad(pos, (read_mem_length, mem_length), value = 0)
+        rotary_emb = None
+
+        if exists(self.rotary_pos_emb):
+            pos = pos + 10000
+            pos = F.pad(pos, (read_mem_length, mem_length), value = 0)
+
+            rotary_emb = self.rotary_pos_emb(pos)
 
-        rotary_emb = self.rotary_pos_emb(pos)
+        shift_fn = partial(self.maybe_token_shift, ps = ps)
 
         # attention and feedforward
 
         for attn, ff in self.layers:
-            x = attn(x, mask = mask, rotary_emb = rotary_emb) + x
-            x = ff(x) + x
+            x = attn(shift_fn(x), mask = mask, rotary_emb = rotary_emb) + x
+            x = ff(shift_fn(x)) + x
 
         # split out memories using unpack
 
diff --git a/setup.py b/setup.py
index 76ecfcd..d315503 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.5',
+  version = '0.1.6',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
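A note on the new options above: `abs_pos_emb`, `rotary_pos_emb`, and `token_shift` can be toggled independently, and the constructor asserts that at least one of the three is enabled. The following is a minimal usage sketch, not taken from the repository: the three flags and `seq_len` come from this diff, while the class name `RecurrentMemoryTransformer`, the remaining constructor arguments (`num_tokens`, `dim`, `depth`, `num_memory_tokens`), and the return values are assumptions about how the rest of the package is laid out.

```python
import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 256,          # assumed vocab size, not part of this diff
    dim = 512,                 # assumed model width, not part of this diff
    depth = 6,                 # assumed number of layers, not part of this diff
    num_memory_tokens = 128,   # assumed memory size, not part of this diff
    seq_len = 1024,            # segment length, sizes the absolute positional embedding
    abs_pos_emb = False,       # skip nn.Embedding(seq_len, dim)
    rotary_pos_emb = False,    # rotary embeddings now default to off
    token_shift = True         # rely on token shift alone for positional information
)

x = torch.randint(0, 256, (1, 1024))
out = model(x)  # assumed to return logits plus the write memories for the next segment
```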
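Similarly, a small standalone sketch of the `token_shift_fn` helper added above, run on hypothetical tensor sizes: half of the channels of the main tokens are shifted forward by one position, while the packed read and write memories pass through untouched.

```python
import torch
import torch.nn.functional as F
from einops import pack, unpack

# copy of token_shift_fn from the diff above
def token_shift_fn(t, ps):
    read_mem, t, write_mem = unpack(t, ps, 'b * d')
    t, t_shift = t.chunk(2, dim = -1)
    t_shift = F.pad(t_shift, (0, 0, 1, -1), value = 0.)  # shift half the channels forward by one token
    t = torch.cat((t, t_shift), dim = -1)
    return torch.cat((read_mem, t, write_mem), dim = -2)

# hypothetical shapes: 4 read memories, 16 main tokens, 4 write memories, feature dim 8
read_mem  = torch.randn(1, 4, 8)
seq       = torch.randn(1, 16, 8)
write_mem = torch.randn(1, 4, 8)

packed, ps = pack((read_mem, seq, write_mem), 'b * d')
shifted = token_shift_fn(packed, ps)

assert torch.equal(shifted[:, :4], read_mem)               # read memories untouched
assert torch.equal(shifted[:, -4:], write_mem)             # write memories untouched
assert torch.equal(shifted[:, 4:20, :4], seq[..., :4])     # first half of channels unchanged
assert torch.equal(shifted[:, 5:20, 4:], seq[:, :15, 4:])  # second half shifted down one position
```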