From d45ef72a40324c6224ffacb890d5593a69db73de Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Thu, 31 Aug 2023 14:14:22 -0700
Subject: [PATCH] address https://github.com/lucidrains/recurrent-memory-transformer-pytorch/issues/19 again

---
 .../recurrent_memory_transformer.py | 16 +++++++++++++++-
 setup.py                            |  2 +-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
index 0675751..313b962 100644
--- a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
+++ b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
@@ -322,7 +322,8 @@ def forward(
         *,
         mask = None,
         labels = None,
-        xl_memories: Optional[List[Tensor]] = None
+        xl_memories: Optional[List[Tensor]] = None,
+        mask_out_read_memories = False  # in the case one is passing in 0s for read memories, for onnx-able model
     ):
 
         has_xl_memories = exists(xl_memories) and len(xl_memories) > 0
@@ -354,6 +355,9 @@ def forward(
         # prepare read memories
 
         if exists(read_memories):
+            if read_memories.ndim == 2:
+                read_memories = repeat(read_memories, 'n d -> b n d', b = b)
+
             read_mem_length = mem_length
             read_memories = read_memories + self.read_memory_emb
         elif self.always_have_read_memories:
@@ -388,6 +392,16 @@ def forward(
             else:
                 mask = causal_mask
 
+        # masking out read memories, either for passing in 0s for read memories on first step, or if you are doing some regularization game on the memories
+
+        if read_mem_length > 0 and mask_out_read_memories:
+            read_mem_mask = torch.arange(x.shape[-2], device = device) < read_mem_length
+
+            if exists(mask):
+                mask = mask & ~read_mem_mask
+            else:
+                mask = ~read_mem_mask
+
         # rotary embedding - offset main positions by 10000, and keep all memories at position 0
 
         rotary_emb = None
diff --git a/setup.py b/setup.py
index 81d4627..5189119 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.4',
+  version = '0.5.5',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
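
A minimal usage sketch of the new flag follows. The constructor hyperparameters and the three-value return unpacking are assumptions taken from the repository README, not part of this patch; what the patch itself adds is that a 2d read_memories tensor is repeated across the batch, and that mask_out_read_memories removes the read-memory positions from attention, which is what makes zero-filled placeholder memories on the first step safe for an ONNX-exportable graph.

    import torch
    from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

    # hypothetical hyperparameters, chosen only for illustration
    model = RecurrentMemoryTransformer(
        num_tokens = 256,
        dim = 512,
        depth = 6,
        num_memory_tokens = 16,
        seq_len = 1024,
        causal = True
    )

    x = torch.randint(0, 256, (1, 1024))

    # first segment: no real memories yet, so pass a 2d zero tensor
    # (repeated across the batch by this patch) and mask it out of attention
    zero_read_memories = torch.zeros(16, 512)

    logits, memories, _ = model(
        x,
        read_memories = zero_read_memories,
        mask_out_read_memories = True
    )

    # subsequent segments: feed the written memories back in, no masking needed
    logits, memories, _ = model(x, read_memories = memories)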