expose cross attention, in ready to use Performer in alphafold2 repo
lucidrains committed Apr 20, 2021
1 parent 924ea74 commit 639735c
Showing 4 changed files with 34 additions and 6 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -125,6 +125,24 @@ x = torch.randn(1, 1024, 512).cuda()
attn(x) # (1, 1024, 512)
```

Cross attention works similarly.

```python
import torch
from performer_pytorch import CrossAttention

attn = CrossAttention(
dim = 512,
heads = 8,
causal = False,
).cuda()

x = torch.randn(1, 1024, 512).cuda()
context = torch.randn(1, 512, 512).cuda()

attn(x, context = context) # (1, 1024, 512)
```

To minimize model surgery, you could also simply rewrite the code, so that the attention step is done by the `FastAttention` module, as follows.

```python
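# the original example is collapsed in this diff view; what follows is a minimal
# sketch of standalone FastAttention usage (the shapes and the nb_features value
# are illustrative assumptions, not taken from this commit)
import torch
from performer_pytorch import FastAttention

# queries / keys / values with heads already split out: (batch, heads, seq_len, dim_head)
q = torch.randn(1, 8, 512, 64)
k = torch.randn(1, 8, 512, 64)
v = torch.randn(1, 8, 512, 64)

attn_fn = FastAttention(
    dim_heads = 64,
    nb_features = 256,  # assumed number of random features
    causal = False
)

out = attn_fn(q, k, v)  # (1, 8, 512, 64)
```
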
2 changes: 1 addition & 1 deletion performer_pytorch/__init__.py
@@ -1,3 +1,3 @@
from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention
from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention, CrossAttention
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from performer_pytorch.performer_enc_dec import PerformerEncDec
18 changes: 14 additions & 4 deletions performer_pytorch/performer_pytorch.py
@@ -164,15 +164,15 @@ def causal_linear_attention(q, k, v, eps = 1e-6):

# inefficient causal linear attention, without cuda code, for reader's reference
# not being used
def causal_linear_attention_noncuda(q, k, v, chunk_size = 128):
def causal_linear_attention_noncuda(q, k, v, chunk_size = 128, eps = 1e-6):
last_k_cumsum = 0
last_context_cumsum = 0
outs = []

for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim = -2), (q, k, v))):
k_cumsum = last_k_cumsum + k.cumsum(dim=-2)

D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))
D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q) + eps)
context = torch.einsum('...nd,...ne->...nde', k, v)
context_cumsum = last_context_cumsum + context.cumsum(dim=-3)
out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv)
@@ -305,7 +305,7 @@ def forward(self, x, **kwargs):
x = self.w2(x)
return x

class SelfAttention(nn.Module):
class Attention(nn.Module):
def __init__(
self,
dim,
@@ -375,6 +375,16 @@ def forward(self, x, pos_emb = None, context = None, mask = None, context_mask =
out = self.to_out(out)
return self.dropout(out)

class SelfAttention(Attention):
def forward(self, *args, context = None, **kwargs):
assert not exists(context), 'self attention should not receive context'
return super().forward(*args, **kwargs)

class CrossAttention(Attention):
def forward(self, *args, context = None, **kwargs):
assert exists(context), 'cross attention should receive context'
return super().forward(*args, context = context, **kwargs)

# positional embeddings

class AbsolutePositionalEmbedding(nn.Module):
@@ -469,7 +479,7 @@ def __init__(
continue

layers.append(nn.ModuleList([
wrapper_fn(SelfAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias)),
wrapper_fn(CrossAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias, attn_out_bias = attn_out_bias)),
wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
]))

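Per the commit title, the point of these changes is to make cross attention usable from the ready-to-use `Performer`. A minimal sketch of that usage, assuming `Performer` already exposes a `cross_attend` flag and routes a `context` keyword through to the layers above (neither is visible in these hunks):

```python
import torch
from performer_pytorch import Performer

# sketch only: the cross_attend flag and the context routing are assumptions,
# not shown in this diff
model = Performer(
    dim = 512,
    depth = 2,
    heads = 8,
    causal = False,
    cross_attend = True  # assumed flag enabling the CrossAttention layers added above
)

x = torch.randn(1, 1024, 512)
context = torch.randn(1, 512, 512)

model(x, context = context)  # (1, 1024, 512)
```

Note that `SelfAttention` and `CrossAttention` share `Attention`'s parameters and forward pass; they differ only in whether `context` is required.
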
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'performer-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.0.7',
version = '1.0.8',
license='MIT',
description = 'Performer - Pytorch',
author = 'Phil Wang',
