Skip to content

Commit

Permalink
give a null key / value to protect against entirely masked out row, as well as to give attention the ability to attend to nothing
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 9, 2023
1 parent 14ecab4 commit 976af1e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def __init__(
use_flash = use_flash_attn
)

self.null_kv = nn.Parameter(torch.randn(2, heads, dim_head))

self.to_q = Linear(dim, dim_inner)
self.to_kv = Linear(dim, dim_inner * 2)
self.to_out = Linear(dim_inner, dim)
Expand All @@ -164,14 +166,29 @@ def forward(

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

# add a null key / value
# to protect against an entirely masked out sequence
# as well as giving attention ability to attend to nothing

nk, nv = map(lambda t: repeat(t, 'h d -> b h 1 d', b = x.shape[0]), self.null_kv)

k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)

if exists(mask):
mask = F.pad(mask, (1, 0), value = True)

# manage memories

next_xl_memories = torch.stack((k, v))

if exists(xl_memories):
kx, vx = xl_memories
k = torch.cat((kx, k), dim = -2)
v = torch.cat((vx, v), dim = -2)

mask = F.pad(mask, (xl_memories.shape[-2], 0), value = True)
if exists(mask):
mask = F.pad(mask, (xl_memories.shape[-2], 0), value = True)

if exists(rotary_emb):
q_rotary_emb, k_rotary_emb = rotary_emb
Expand Down Expand Up @@ -372,9 +389,14 @@ def forward(
if has_xl_memories:
k_pos = torch.arange(xl_mem_length, device = device) + mem_rel_dist
k_pos = torch.cat((k_pos, q_pos), dim = -1)
k_rotary_emb = self.rotary_pos_emb(k_pos)
else:
k_rotary_emb = q_rotary_emb
k_pos = q_pos

# account for null key / value

k_pos = F.pad(k_pos, (1, 0), value = mem_rel_dist - 1) # give a null memory token, to allow for attending to nothing

k_rotary_emb = self.rotary_pos_emb(k_pos)

rotary_emb = (q_rotary_emb, k_rotary_emb)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'recurrent-memory-transformer-pytorch',
packages = find_packages(exclude=[]),
version = '0.4.3',
version = '0.5.0',
license='MIT',
description = 'Recurrent Memory Transformer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 976af1e

Please sign in to comment.