From c5e75cd0f904b7d278f2ea21219a0dc7e36b2b91 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 25 Apr 2023 15:57:31 -0700
Subject: [PATCH] knock off a todo

---
 README.md                                           |  2 +-
 .../attend.py                                       |  2 +-
 .../recurrent_memory_transformer.py                 | 28 ++++++++++++++++---
 setup.py                                            |  2 +-
 4 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 2b144bb..645792c 100644
--- a/README.md
+++ b/README.md
@@ -72,10 +72,10 @@ loss = model(seq, memory_replay_backprop = True) # memory efficient training fro
 ## Todo
 
 - [ ] add an axial attention down the past memories axis as an option
-- [ ] for autoregressive, run experiments between future memories being causal masked or not
 - [ ] add sliding windows as an option, detached or with memory propagated using MRBP
 - [ ] offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift
 
+- [x] make memories being causally masked an option
 - [x] add the memory replay backprop technique from memformer paper
 - [x] relative positional encoding
 
diff --git a/recurrent_memory_transformer_pytorch/attend.py b/recurrent_memory_transformer_pytorch/attend.py
index 8abc008..7a9497e 100644
--- a/recurrent_memory_transformer_pytorch/attend.py
+++ b/recurrent_memory_transformer_pytorch/attend.py
@@ -80,7 +80,7 @@ def flash_attn(self, q, k, v, mask = None):
         # Check if mask exists and expand to compatible shape
         # The mask is B L, so it would have to be expanded to B H N L
 
-        if exists(mask):
+        if exists(mask) and mask.ndim != 4:
             mask = rearrange(mask, 'b j -> b 1 1 j')
             mask = mask.expand(-1, heads, q_len, -1)
 
diff --git a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
index de18141..e86989e 100644
--- a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
+++ b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
@@ -110,13 +110,18 @@ def __init__(
         dim_head = 64,
         heads = 8,
         dropout = 0.,
-        use_flash_attn = False
+        use_flash_attn = False,
+        use_custom_causal_attn_mask = False
     ):
         super().__init__()
         dim_inner = dim_head * heads
         self.heads = heads
 
-        self.attend = Attend(causal = causal, dropout = dropout, use_flash = use_flash_attn)
+        self.attend = Attend(
+            causal = causal and not use_custom_causal_attn_mask,
+            dropout = dropout,
+            use_flash = use_flash_attn
+        )
 
         self.norm = RMSNorm(dim)
 
@@ -164,7 +169,8 @@ def __init__(
         heads = 8,
         ff_mult = 4,
         use_flash_attn = False,
-        ignore_index = -1
+        ignore_index = -1,
+        memory_not_causal = True # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
     ):
         super().__init__()
         self.causal = causal
@@ -197,7 +203,8 @@ def __init__(
                     dim_head = dim_head,
                     causal = causal,
                     heads = heads,
-                    use_flash_attn = use_flash_attn
+                    use_flash_attn = use_flash_attn,
+                    use_custom_causal_attn_mask = memory_not_causal
                 ),
                 FeedForward(dim = dim, mult = ff_mult)
             ]))
@@ -209,6 +216,10 @@ def __init__(
 
         self.ignore_index = ignore_index
 
+        # whether to use custom attention mask if causal and memory should not be causal
+
+        self.use_custom_causal_attn_mask = causal and memory_not_causal
+
     def init_memory(self, batch):
         return repeat(self.memory_tokens, 'm d -> b m d', b = batch)
 
@@ -243,6 +254,15 @@ def forward(
         if exists(mask):
             mask = F.pad(mask, (read_mem_length, mem_length), value = True)
 
+        # custom causal mask, if needed
+
+        if self.use_custom_causal_attn_mask:
+            causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).tril()
+            causal_mask = F.pad(causal_mask, (read_mem_length, mem_length) * 2, value = True)
+
+            assert not exists(mask)
+            mask = rearrange(causal_mask, 'i j -> 1 1 i j')
+
         # rotary embedding - offset main positions by 10000, and keep all memories at position 0
 
         pos = pos + 10000
diff --git a/setup.py b/setup.py
index 8f58b13..76ecfcd 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.4',
+  version = '0.1.5',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
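
For readers of this patch, here is a small standalone sketch (not part of the diff above) of the attention mask that the new `use_custom_causal_attn_mask` branch builds: the main sequence keeps its lower-triangular causal structure, while the read and write memory positions padded onto either side are left fully attendable. The toy sizes (`n = 4`, `read_mem_length = mem_length = 2`) are made up for illustration.

```python
import torch
import torch.nn.functional as F

# toy sizes, chosen only for illustration
n, read_mem_length, mem_length = 4, 2, 2

# same two lines as the new branch in recurrent_memory_transformer.py
causal_mask = torch.ones((n, n), dtype = torch.bool).tril()
causal_mask = F.pad(causal_mask, (read_mem_length, mem_length) * 2, value = True)

print(causal_mask.int())
# the first 2 rows/columns (read memories) and the last 2 (write memories) are all ones,
# while the middle 4 x 4 block keeps the causal lower-triangular pattern
```

Since this mask is reshaped to `1 1 i j` (4 dimensions) before it reaches attention, the `mask.ndim != 4` guard added to `attend.py` keeps the flash-attention path from re-expanding a mask that is already in `b h i j` form.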
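
And a minimal usage sketch for the new `memory_not_causal` option. Only `causal`, `heads`, `dim_head`, `use_flash_attn` and `memory_not_causal` appear in this patch; the remaining constructor arguments and the `(logits, memories)` return follow the README-style API of the project and are assumptions that may differ between versions.

```python
import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 20000,          # vocab size (assumed from README)
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    seq_len = 1024,              # segment length (assumed from README)
    num_memory_tokens = 128,     # read/write memory bottleneck (assumed from README)
    causal = True,
    use_flash_attn = True,
    memory_not_causal = True     # new in this patch, defaults to True - memories are excluded from the causal mask
)

x = torch.randint(0, 20000, (1, 1024))

logits, memories = model(x)            # first segment
logits, memories = model(x, memories)  # memories from one segment feed the next
```

Setting `memory_not_causal = False` keeps the plain causal mask inside the attention kernel, which, per the inline comment in the patch, is slightly more optimal with flash attention.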