
Commit

address #19 again
lucidrains committed Aug 31, 2023
1 parent 90de2ac commit d45ef72
Showing 2 changed files with 16 additions and 2 deletions.
@@ -322,7 +322,8 @@ def forward(
         *,
         mask = None,
         labels = None,
-        xl_memories: Optional[List[Tensor]] = None
+        xl_memories: Optional[List[Tensor]] = None,
+        mask_out_read_memories = False # in the case one is passing in 0s for read memories, for onnx-able model
     ):
         has_xl_memories = exists(xl_memories) and len(xl_memories) > 0

@@ -354,6 +355,9 @@ def forward(
         # prepare read memories

         if exists(read_memories):
+            if read_memories.ndim == 2:
+                read_memories = repeat(read_memories, 'n d -> b n d', b = b)
+
             read_mem_length = mem_length
             read_memories = read_memories + self.read_memory_emb
         elif self.always_have_read_memories:
@@ -388,6 +392,16 @@ def forward(
         else:
             mask = causal_mask

+        # masking out read memories, either for passing in 0s for read memories on first step, or if you are doing some regularization game on the memories
+
+        if read_mem_length > 0 and mask_out_read_memories:
+            read_mem_mask = torch.arange(x.shape[-2], device = device) < read_mem_length
+
+            if exists(mask):
+                mask = mask & ~read_mem_mask
+            else:
+                mask = ~read_mem_mask
+
         # rotary embedding - offset main positions by 10000, and keep all memories at position 0

         rotary_emb = None
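For context, here is a minimal usage sketch of the new arguments, assuming the constructor hyperparameters and return structure from the library's README (they are illustrative, not part of this commit). On the first segment there are no real read memories yet, so fixed-shape zeros are passed in and masked out, which keeps every input shape static for ONNX export; the zeros can now be 2-dimensional, since the new repeat broadcasts them across the batch.

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

# illustrative hyperparameters, assumed rather than taken from this commit
model = RecurrentMemoryTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    num_memory_tokens = 128,
    seq_len = 1024
)

x = torch.randint(0, 20000, (1, 1024))

# first segment: pass fixed-shape zero read memories and ask the model to mask them out
zero_read_memories = torch.zeros(128, 512)  # (num_memory_tokens, dim), broadcast over the batch by the new repeat

out = model(x, zero_read_memories, mask_out_read_memories = True)
logits, write_memories = out[0], out[1]  # return structure assumed from the README example

# subsequent segments: feed the written memories back in as usual
out = model(x, write_memories)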
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.4',
+  version = '0.5.5',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
