From 9b2f7821315f054eaa38db6acca4b820e1381467 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 19 Dec 2023 10:22:21 -0800
Subject: [PATCH] also gate the agent tokens

---
 agent_attention_pytorch/agent_transformer.py | 11 ++++++-----
 setup.py                                     |  2 +-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/agent_attention_pytorch/agent_transformer.py b/agent_attention_pytorch/agent_transformer.py
index e3144e1..c806962 100644
--- a/agent_attention_pytorch/agent_transformer.py
+++ b/agent_attention_pytorch/agent_transformer.py
@@ -85,9 +85,9 @@ def forward(
         mask = None
     ):
         x = self.norm(x)
-        agent_tokens = self.norm(agent_tokens)
+        a = self.norm(agent_tokens)
 
-        x_and_agents, xa_ps = pack([agent_tokens, x], 'b * d')
+        x_and_agents, xa_ps = pack([a, x], 'b * d')
         qkv = self.to_qkv(x_and_agents)
 
         qkv_agent, qkv_input = unpack(qkv, xa_ps, 'qkv b h * d')
@@ -114,18 +114,19 @@ def forward(
         qa_attn = self.qa_talking_heads(qa_attn)
         ak_attn = self.ak_talking_heads(ak_attn)
 
-        agent_gathered_tokens = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
+        agent_out = einsum('b h i j, b h j d -> b h i d', ak_attn, v)
 
-        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_gathered_tokens)
+        out = einsum('b h i j, b h j d -> b h i d', qa_attn, agent_out)
 
         if exists(mask):
             out = out.masked_fill(~rearrange(mask, 'b n -> b 1 n 1'), 0.)
 
         if exists(self.to_gates):
             out = out * self.to_gates(x)
+            agent_out = agent_out * self.to_gates(a)
 
         out = self.to_out(out)
-        agent_out = self.to_out(agent_gathered_tokens)
+        agent_out = self.to_out(agent_out)
 
         return out, agent_out
 
diff --git a/setup.py b/setup.py
index c547896..80e8426 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'agent-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.2',
+  version = '0.1.4',
   license='MIT',
   description = 'Agent Attention - Pytorch',
   author = 'Phil Wang',
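
Note on the change: when output gating is enabled, it is now applied to both streams, i.e. the regular-token output is gated by the normalized input tokens `x` and the agent-token output is gated by the normalized agent tokens `a`, before both pass through the shared `to_out` projection. A minimal standalone sketch of this gating pattern, assuming a hypothetical `to_gates` built from a per-head Linear plus sigmoid (the actual module in the repo may differ):

# Illustrative sketch only, not the Attention module from agent_attention_pytorch.
import torch
from torch import nn
from einops import rearrange

heads, dim, dim_head = 8, 512, 64

# hypothetical gating projection: one sigmoid gate per head, per token
to_gates = nn.Sequential(nn.Linear(dim, heads), nn.Sigmoid())

x = torch.randn(2, 1024, dim)   # normalized regular tokens
a = torch.randn(2, 128, dim)    # normalized agent tokens

out = torch.randn(2, heads, 1024, dim_head)        # per-head output for regular tokens
agent_out = torch.randn(2, heads, 128, dim_head)   # per-head output gathered via agent tokens

# gate both streams before the shared output projection, mirroring the patch
out = out * rearrange(to_gates(x), 'b n h -> b h n 1')
agent_out = agent_out * rearrange(to_gates(a), 'b n h -> b h n 1')

Reusing the same gating projection for both branches keeps the two outputs symmetric, which is why the patch simply calls `self.to_gates` a second time on the agent tokens.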