Commit 8d560e9

address #12
lucidrains committed Aug 8, 2023
1 parent 98bf309 commit 8d560e9
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions recurrent_memory_transformer_pytorch/attend.py
@@ -80,8 +80,9 @@ def flash_attn(self, q, k, v, mask = None):
         # Check if mask exists and expand to compatible shape
         # The mask is B L, so it would have to be expanded to B H N L

-        if exists(mask) and mask.ndim != 4:
-            mask = rearrange(mask, 'b j -> b 1 1 j')
+        if exists(mask):
+            if mask.ndim != 4:
+                mask = rearrange(mask, 'b j -> b 1 1 j')
             mask = mask.expand(-1, heads, q_len, -1)

         # Check if there is a compatible device for flash attention
@@ -123,7 +124,8 @@ def forward(self, q, k, v, mask = None):
         # key padding mask

         if exists(mask):
-            mask = rearrange(mask, 'b j -> b 1 1 j')
+            if mask.ndim != 4:
+                mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

         # causal mask
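Context for the attend.py change: previously, both the rearrange and the expand in flash_attn sat behind the mask.ndim != 4 check, so a mask passed in already with 4 dims was never expanded, and the rearrange in forward ran unconditionally, which errors on a 4-dim mask. Below is a minimal sketch, not part of the commit, in which expand_mask is a hypothetical standalone helper that mirrors the patched flash_attn branch:

import torch
from einops import rearrange

def expand_mask(mask, heads, q_len):
    # hypothetical helper, illustrative only: mirrors the patched mask handling
    if mask.ndim != 4:
        # a (batch, key_len) padding mask is first broadcast to (batch, 1, 1, key_len)
        mask = rearrange(mask, 'b j -> b 1 1 j')
    # after the fix, the expansion to (batch, heads, q_len, key_len) runs for any mask,
    # including one that already arrived with 4 dims
    return mask.expand(-1, heads, q_len, -1)

b, h, n, j = 2, 8, 16, 16
padding_mask = torch.ones(b, j, dtype = torch.bool)        # (batch, key_len)
full_mask    = torch.ones(b, 1, n, j, dtype = torch.bool)  # already 4-dim

assert expand_mask(padding_mask, h, n).shape == (b, h, n, j)
assert expand_mask(full_mask, h, n).shape == (b, h, n, j)

The non-flash path in forward only needs the added ndim guard, since sim.masked_fill(~mask, ...) broadcasts the mask and does not require an explicit expand.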
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.1',
+  version = '0.4.2',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
