Commit 9b2f782

also gate the agent tokens

lucidrains committed Dec 19, 2023
1 parent 1e91815 commit 9b2f782

Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions agent_attention_pytorch/agent_transformer.py
@@ -85,9 +85,9 @@ def forward(
         mask = None
     ):
         x = self.norm(x)
-        agent_tokens = self.norm(agent_tokens)
+        a = self.norm(agent_tokens)
 
-        x_and_agents, xa_ps = pack([agent_tokens, x], 'b * d')
+        x_and_agents, xa_ps = pack([a, x], 'b * d')
         qkv = self.to_qkv(x_and_agents)
 
         qkv_agent, qkv_input = unpack(qkv, xa_ps, 'qkv b h * d')
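An aside on the pack/unpack calls above: these are einops utilities that concatenate the agent tokens and the input tokens along the sequence axis, so a single qkv projection serves both, and later split the result back apart using the saved shapes. A minimal sketch with the simpler 'b * d' pattern (the dim of 512 and token counts below are illustrative, not from this repo):

    import torch
    from einops import pack, unpack

    a = torch.randn(2, 16, 512)    # agent tokens (batch, num_agents, dim)
    x = torch.randn(2, 128, 512)   # input tokens (batch, seq_len, dim)

    # concatenate along the wildcard '*' axis, remembering original shapes
    x_and_agents, xa_ps = pack([a, x], 'b * d')   # -> (2, 144, 512)

    # split back into the two original tensors using the saved shapes
    a_rec, x_rec = unpack(x_and_agents, xa_ps, 'b * d')

In the diff itself the unpack happens after the qkv projection and head-splitting, hence the longer 'qkv b h * d' pattern.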
@@ -114,18 +114,19 @@ def forward(
         qa_attn = self.qa_talking_heads(qa_attn)
         ak_attn = self.ak_talking_heads(ak_attn)
 
-        agent_gathered_tokens = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
+        agent_out = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
 
-        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_gathered_tokens)
+        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_out)
 
         if exists(mask):
             out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)
 
         if exists(self.to_gates):
             out = out * self.to_gates(x)
+            agent_out = agent_out * self.to_gates(a)
 
         out = self.to_out(out)
-        agent_out = self.to_out(agent_gathered_tokens)
+        agent_out = self.to_out(agent_out)
 
         return out, agent_out

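The substance of the commit: previously only the regular attention output out was gated, via self.to_gates(x); the agent branch now gets the same treatment, with agent_out gated by the normalized agent tokens a before the output projection. The definition of self.to_gates is outside this diff, so the sketch below is an assumption — the per-head sigmoid gate is a pattern lucidrains uses in other repositories, and the module shape and dimensions here are illustrative:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads = 512, 8

    # assumed gate module: one sigmoid gate per head and per position,
    # broadcast over the feature dimension of the attention output
    to_gates = nn.Sequential(
        nn.Linear(dim, heads),
        Rearrange('b n h -> b h n 1'),
        nn.Sigmoid()
    )

    a = torch.randn(2, 16, dim)                           # normalized agent tokens
    agent_out = torch.randn(2, heads, 16, dim // heads)   # per-head agent outputs

    # the new line in this commit: agent outputs gated by the agent tokens,
    # mirroring how out is gated by the input tokens x
    agent_out = agent_out * to_gates(a)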
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'agent-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.2',
+  version = '0.1.4',
   license='MIT',
   description = 'Agent Attention - Pytorch',
   author = 'Phil Wang',
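The version bump to 0.1.4 ships the gating fix. Assuming the release was published to PyPI as usual for this repo, it can be picked up with pip install agent-attention-pytorch==0.1.4, or pip install -U agent-attention-pytorch.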
