diff --git a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
index 8a8dcdf..c7ca4b9 100644
--- a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
+++ b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
@@ -288,7 +288,9 @@ def forward(
 
         if self.use_custom_causal_attn_mask:
             causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).tril()
-            causal_mask = F.pad(causal_mask, (read_mem_length, mem_length) * 2, value = True)
+
+            causal_mask = F.pad(causal_mask, (0, mem_length, read_mem_length, 0), value = False)
+            causal_mask = F.pad(causal_mask, (read_mem_length, 0, 0, mem_length), value = True)
 
             assert not exists(mask)
             mask = rearrange(causal_mask, 'i j -> 1 1 i j')
diff --git a/setup.py b/setup.py
index da95df3..78f697a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.7',
+  version = '0.1.8',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
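The mask change above is easiest to sanity check in isolation. Below is a minimal standalone sketch (not part of the repo) that reproduces the two F.pad calls with made-up toy sizes for n, read_mem_length and mem_length, reading the layout implied by the pads as read memories prepended and write memories appended to the main sequence. It asserts the attention pattern the new code appears to intend: every query may attend to the read memories, only write-memory queries may attend to write-memory keys, and the main sequence keeps its causal triangle.

import torch
import torch.nn.functional as F

# toy sizes, assumed for illustration only
n, read_mem_length, mem_length = 4, 2, 2

# causal triangle over the main sequence
causal_mask = torch.ones((n, n), dtype = torch.bool).tril()

# block out write-memory columns (right) and read-memory rows (top)
causal_mask = F.pad(causal_mask, (0, mem_length, read_mem_length, 0), value = False)
# open up read-memory columns (left) and write-memory rows (bottom)
causal_mask = F.pad(causal_mask, (read_mem_length, 0, 0, mem_length), value = True)

total = read_mem_length + n + mem_length
assert causal_mask.shape == (total, total)

# every query may attend to the read memories
assert causal_mask[:, :read_mem_length].all()
# write-memory queries may attend to all positions
assert causal_mask[-mem_length:, :].all()
# no query before the write memories may attend to write-memory keys
assert not causal_mask[:-mem_length, -mem_length:].any()
# the main sequence keeps its causal triangle
assert torch.equal(
    causal_mask[read_mem_length:read_mem_length + n, read_mem_length:read_mem_length + n],
    torch.ones((n, n), dtype = torch.bool).tril()
)

For contrast, the removed one-liner padded every memory row and column with True, so in particular the main sequence could attend to the write-memory tokens; the replacement masks those keys for all queries except the write memories themselves.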