
Commit

knock off a todo
lucidrains committed Apr 25, 2023
1 parent 6d5f4ec commit c5e75cd
Showing 4 changed files with 27 additions and 7 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -72,10 +72,10 @@ loss = model(seq, memory_replay_backprop = True) # memory efficient training from memformer paper
## Todo

- [ ] add an axial attention down the past memories axis as an option
- [ ] for autoregressive, run experiments between future memories being causal masked or not
- [ ] add sliding windows as an option, detached or with memory propagated using MRBP
- [ ] offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift

- [x] make memories being causally masked an option
- [x] add the memory replay backprop technique from memformer paper
- [x] relative positional encoding

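Usage-wise, the todo knocked off above is exposed through the new `memory_not_causal` keyword added in this commit. Below is a hedged sketch of turning it on; every constructor argument other than `memory_not_causal`, `causal`, and `use_flash_attn` is an assumption drawn from the package README rather than from this diff, and the forward return signature is likewise assumed.

```python
import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 256,          # assumed
    dim = 512,                 # assumed
    depth = 6,                 # assumed
    num_memory_tokens = 128,   # assumed
    seq_len = 1024,            # assumed
    causal = True,
    use_flash_attn = True,
    memory_not_causal = True   # new in this commit: memories are exempt from the causal mask
)

seq = torch.randint(0, 256, (1, 1024))
logits, memories = model(seq)  # assumed return signature; memories would be carried to the next segment
```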
recurrent_memory_transformer_pytorch/attend.py (2 changes: 1 addition & 1 deletion)
@@ -80,7 +80,7 @@ def flash_attn(self, q, k, v, mask = None):
# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L

- if exists(mask):
+ if exists(mask) and mask.ndim != 4:
mask = rearrange(mask, 'b j -> b 1 1 j')
mask = mask.expand(-1, heads, q_len, -1)

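The `mask.ndim != 4` guard is what lets a prebuilt `1 1 i j` mask (like the custom causal mask constructed later in this commit) flow through `flash_attn` untouched, while an ordinary `b j` padding mask is still broadcast up. A standalone sketch with assumed sizes:

```python
import torch
from einops import rearrange

batch, heads, q_len, k_len = 2, 8, 16, 16

# a 2D key-padding mask (B, j) is still expanded to (B, H, i, j)
padding_mask = torch.ones(batch, k_len, dtype = torch.bool)
if padding_mask.ndim != 4:
    padding_mask = rearrange(padding_mask, 'b j -> b 1 1 j')
    padding_mask = padding_mask.expand(-1, heads, q_len, -1)
print(padding_mask.shape)  # torch.Size([2, 8, 16, 16])

# a prebuilt 4D mask, as added in this commit, passes through unchanged
custom_mask = torch.ones(1, 1, q_len, k_len, dtype = torch.bool).tril()
if custom_mask.ndim != 4:
    custom_mask = rearrange(custom_mask, 'b j -> b 1 1 j')
print(custom_mask.shape)   # torch.Size([1, 1, 16, 16])
```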
recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py (28 changes: 24 additions & 4 deletions)
@@ -110,13 +110,18 @@ def __init__(
dim_head = 64,
heads = 8,
dropout = 0.,
- use_flash_attn = False
+ use_flash_attn = False,
+ use_custom_causal_attn_mask = False
):
super().__init__()
dim_inner = dim_head * heads
self.heads = heads

- self.attend = Attend(causal = causal, dropout = dropout, use_flash = use_flash_attn)
+ self.attend = Attend(
+ causal = causal and not use_custom_causal_attn_mask,
+ dropout = dropout,
+ use_flash = use_flash_attn
+ )

self.norm = RMSNorm(dim)

@@ -164,7 +169,8 @@ def __init__(
heads = 8,
ff_mult = 4,
use_flash_attn = False,
- ignore_index = -1
+ ignore_index = -1,
+ memory_not_causal = True # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
):
super().__init__()
self.causal = causal
@@ -197,7 +203,8 @@ def __init__(
dim_head = dim_head,
causal = causal,
heads = heads,
- use_flash_attn = use_flash_attn
+ use_flash_attn = use_flash_attn,
+ use_custom_causal_attn_mask = memory_not_causal
),
FeedForward(dim = dim, mult = ff_mult)
]))
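Worth noting why the Attention layer above receives `causal = causal and not use_custom_causal_attn_mask` (see the earlier hunk): once the custom mask is responsible for causality, the built-in causal masking has to be switched off, otherwise it would close the memory positions that the padded mask deliberately leaves open. A minimal sketch with assumed toy sizes; `read_mem` and `write_mem` are illustrative names mirroring `read_mem_length` and `mem_length` in the forward pass:

```python
import torch
import torch.nn.functional as F

# assumed toy sizes: 4 main tokens, 2 read memories in front, 2 write memories behind
n, read_mem, write_mem = 4, 2, 2
total = n + read_mem + write_mem

# causal over the main tokens, fully open for memory rows and columns (as built in forward())
custom = F.pad(torch.ones(n, n, dtype = torch.bool).tril(), (read_mem, write_mem) * 2, value = True)

# the mask Attend would apply on its own if its causal flag stayed on
builtin = torch.ones(total, total, dtype = torch.bool).tril()

# applying both collapses back to a plain causal mask, re-masking the memories,
# hence the causal flag is disabled whenever the custom mask is in play
print(torch.equal(custom & builtin, builtin))  # True
```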
@@ -209,6 +216,10 @@ def __init__(

self.ignore_index = ignore_index

+ # whether to use custom attention mask if causal and memory should not be causal

+ self.use_custom_causal_attn_mask = causal and memory_not_causal

def init_memory(self, batch):
return repeat(self.memory_tokens, 'm d -> b m d', b = batch)

@@ -243,6 +254,15 @@ def forward(
if exists(mask):
mask = F.pad(mask, (read_mem_length, mem_length), value = True)

+ # custom causal mask, if needed

+ if self.use_custom_causal_attn_mask:
+ causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).tril()
+ causal_mask = F.pad(causal_mask, (read_mem_length, mem_length) * 2, value = True)

+ assert not exists(mask)
+ mask = rearrange(causal_mask, 'i j -> 1 1 i j')

# rotary embedding - offset main positions by 10000, and keep all memories at position 0

pos = pos + 10000
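For reference, the mask construction added to `forward` above can be exercised standalone. A small sketch with assumed sizes (6 main tokens, 2 read memories, 2 write memories); the checks illustrate what the padding achieves:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

n = 6                 # main sequence length (assumed)
read_mem_length = 2   # read memories prepended to the sequence (assumed)
mem_length = 2        # write memories appended to the sequence (assumed)

causal_mask = torch.ones((n, n), dtype = torch.bool).tril()
causal_mask = F.pad(causal_mask, (read_mem_length, mem_length) * 2, value = True)

mask = rearrange(causal_mask, 'i j -> 1 1 i j')
print(mask.shape)  # torch.Size([1, 1, 10, 10]) - broadcasts over batch and heads

# memory rows and columns are fully visible, main tokens remain causal
assert mask[0, 0, :, :read_mem_length].all()                 # every position can attend to the read memories
assert mask[0, 0, :, -mem_length:].all()                     # and to the write memories
assert not mask[0, 0, read_mem_length, read_mem_length + 1]  # first main token still cannot see the next one
```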
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'recurrent-memory-transformer-pytorch',
packages = find_packages(exclude=[]),
- version = '0.1.4',
+ version = '0.1.5',
license='MIT',
description = 'Recurrent Memory Transformer - Pytorch',
author = 'Phil Wang',
