address #19 by allowing for an option to attend to raw read memory positional embeddings on first step
lucidrains committed Aug 31, 2023
1 parent 3be7d43 commit 90de2ac
Showing 2 changed files with 10 additions and 1 deletion.
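For context, a minimal usage sketch of the new option is below. Only always_have_read_memories is introduced by this commit; the remaining constructor arguments, the forward signature, and the return values are assumptions based on the repository's README, not part of this diff.

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 256,                    # assumed hyperparameters for illustration
    dim = 512,
    depth = 6,
    seq_len = 1024,
    num_memory_tokens = 128,
    always_have_read_memories = True     # new flag: first step attends to the raw read memory positional embeddings
)

x = torch.randint(0, 256, (1, 1024))

# first step: no read memories are passed in, but with the flag on the model still
# builds them from the positional embeddings, so the graph shape matches later steps
# (return arity assumed here)
logits, memories = model(x)

# later steps: feed the returned write memories back in as read memories
logits, memories = model(x, read_memories = memories)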
@@ -229,6 +229,7 @@ def __init__(
memory_not_causal = True, # flash attention runs a bit more efficiently if the causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
add_write_to_next_write_mem = False, # add the write memories of the previous step to the next write step - thanks to @IcarusWizard for pointing out this discrepancy
next_write_mem_stop_grad = True, # whether to stop the gradient of the previous read memory -> next write memory
always_have_read_memories = True, # whether to always have read memories, even on the first step, so as to make the model onnx-able
resi_dual_scale = 1., # in the case of overflows in fp16 on the prenorm branch, set this to a value less than 1.
):
super().__init__()
@@ -306,6 +307,11 @@ def __init__(
self.add_write_to_next_write_mem = add_write_to_next_write_mem
self.next_write_mem_stop_grad = next_write_mem_stop_grad

# allow for attending to the raw read memory positional embeddings on the first step
# a hack to make the model onnx-able; it should not hurt

self.always_have_read_memories = always_have_read_memories

def init_memory(self, batch):
return repeat(self.memory_tokens, 'm d -> b m d', b = batch)
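For reference, init_memory above simply broadcasts the learned memory tokens across the batch; a standalone sketch of that shape behavior follows (the tensor sizes are made up for illustration):

import torch
from einops import repeat

memory_tokens = torch.randn(128, 512)                        # (num_memory_tokens, dim), sizes assumed
initial_memories = repeat(memory_tokens, 'm d -> b m d', b = 4)
print(initial_memories.shape)                                # torch.Size([4, 128, 512])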

@@ -350,6 +356,9 @@ def forward(
if exists(read_memories):
read_mem_length = mem_length
read_memories = read_memories + self.read_memory_emb
elif self.always_have_read_memories:
read_mem_length = mem_length
read_memories = repeat(self.read_memory_emb, 'n d -> b n d', b = b)
else:
read_mem_length = 0
read_memories = x[:, 0:0]
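The new elif branch above is the core of the change: on the first step, rather than slicing out an empty tensor, the model attends to the raw read memory positional embeddings broadcast across the batch, so the number of tokens entering attention is the same on every step and the exported graph has a fixed shape. A self-contained sketch of that selection logic follows (the helper function and tensor shapes are illustrative; exists and the variable names come from the diff):

import torch
from einops import repeat

def exists(val):
    return val is not None

def select_read_memories(read_memories, read_memory_emb, x, always_have_read_memories):
    # read_memory_emb: (mem_length, dim) learned positional embedding for the read memories
    # x: (batch, seq, dim) token embeddings for the current segment
    b, mem_length = x.shape[0], read_memory_emb.shape[0]

    if exists(read_memories):
        # later steps: add the positional embedding to the memories written on the previous step
        return read_memories + read_memory_emb, mem_length

    if always_have_read_memories:
        # first step: attend to the raw positional embeddings themselves,
        # keeping the read memory length constant so the model stays onnx-able
        return repeat(read_memory_emb, 'n d -> b n d', b = b), mem_length

    # original behavior: no read memories at all on the first step
    return x[:, 0:0], 0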
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'recurrent-memory-transformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.5.3',
version = '0.5.4',
license='MIT',
description = 'Recurrent Memory Transformer - Pytorch',
author = 'Phil Wang',
