allow for the latents to be propagated down multiple attention blocks, like in isab
lucidrains committed Dec 18, 2023
1 parent cc9d1b4 commit 28344a9
Showing 2 changed files with 17 additions and 6 deletions.
21 changes: 16 additions & 5 deletions agent_attention_pytorch/agent_attention_pytorch.py
@@ -56,13 +56,18 @@ def __init__(
     def forward(
         self,
         x,
-        mask = None
+        mask = None,
+        agent_tokens = None,
+        return_agent_tokens = False
     ):
         batch = x.shape[0]

         q, k, v = self.to_qkv(x)

-        a = repeat(self.agent_tokens, 'h m d -> b h m d', b = batch)
+        if exists(agent_tokens):
+            a = agent_tokens
+        else:
+            a = repeat(self.agent_tokens, 'h m d -> b h m d', b = batch)

         a = a * self.scale
@@ -82,13 +87,19 @@ def forward(
         qa_attn = self.qa_talking_heads(qa_attn)
         ak_attn = self.ak_talking_heads(ak_attn)

-        out = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
-        out = einsum('b h i j, b h j d -> b h i d', qa_attn, out)
+        agent_gathered_tokens = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
+
+        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_gathered_tokens)

         if exists(mask):
             out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)

         if exists(self.to_gates):
             out = out * self.to_gates(x)

-        return self.to_out(out)
+        out = self.to_out(out)
+
+        if not return_agent_tokens:
+            return out
+
+        return out, agent_gathered_tokens
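
The new agent_tokens and return_agent_tokens arguments allow the agent tokens produced by one block to be fed into the next, so the same latents can be refined down a stack of attention blocks, in the spirit of ISAB. Below is a minimal usage sketch, not taken from the commit: it assumes the class exported by this file is AgentSelfAttention and that its constructor accepts dim and num_agent_tokens (argument names assumed; check the repository README for the exact signature), and it chains the blocks directly without residuals purely for illustration.

import torch
from agent_attention_pytorch import AgentSelfAttention  # class name assumed

# two blocks with matching dimensions so agent tokens can be passed between them
attn1 = AgentSelfAttention(dim = 512, num_agent_tokens = 128)  # arg names assumed
attn2 = AgentSelfAttention(dim = 512, num_agent_tokens = 128)

x = torch.randn(2, 1024, 512)

# first block also returns its agent-gathered tokens (shape: batch, heads, num agents, dim head)
out, agent_tokens = attn1(x, return_agent_tokens = True)

# second block uses those tokens in place of its own learned agent tokens
out = attn2(out, agent_tokens = agent_tokens)
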
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'agent-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'Agent Attention - Pytorch',
   author = 'Phil Wang',
