Commit f2dff9a

fix normalization for fast cuda version of causal
lucidrains committed Dec 7, 2020
1 parent 98532fc commit f2dff9a
Showing 2 changed files with 5 additions and 2 deletions.
performer_pytorch/performer_pytorch.py (4 additions, 1 deletion)
@@ -123,7 +123,10 @@ def linear_attention(q, k, v):
 # efficient causal linear attention, created by EPFL
 def causal_linear_attention(q, k, v):
     from fast_transformers.causal_product import CausalDotProduct
-    return CausalDotProduct.apply(q, k, v)
+    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k.cumsum(dim=-2))
+    out = CausalDotProduct.apply(q, k, v)
+    out = torch.einsum('...nd,...n->...nd', out, D_inv)
+    return out

 # inefficient causal linear attention, without cuda code, for reader's reference
 # not being used
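What the patch fixes: CausalDotProduct from fast_transformers returns only the unnormalized numerator of causal linear attention, i.e. for each position i it computes the sum over j <= i of (q_i · k_j) v_j. The new lines divide each row by the matching denominator D_i = q_i · (sum over j <= i of k_j), the normalizer of softmax-kernel attention; without it the outputs are unnormalized. Below is a minimal pure-PyTorch sketch of the same computation, for readers without the CUDA extension. The function name is hypothetical, and q and k are assumed to have already passed through the Performer feature map (so they are positive and D_i is nonzero); this is an illustration, not the repository's own reference implementation.

import torch

# Hypothetical pure-PyTorch reference (not the repository's code): computes the
# same quantity as the patched CUDA path, including the new normalization.
def causal_linear_attention_reference(q, k, v):
    # prefix sums of key features: K_i = sum_{j <= i} k_j
    k_cumsum = k.cumsum(dim = -2)
    # per-position normalizer D_i = <q_i, K_i>, inverted once (the commit's D_inv)
    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum)
    # causal numerator via prefix sums of the outer products k_j v_j^T
    context = torch.einsum('...nd,...ne->...nde', k, v).cumsum(dim = -3)
    out = torch.einsum('...nde,...nd->...ne', context, q)
    # apply the normalization this commit adds to the CUDA path
    return torch.einsum('...ne,...n->...ne', out, D_inv)

# usage sketch: positive random features standing in for the kernel map
q, k = torch.rand(2, 4, 16, 32), torch.rand(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 64)
out = causal_linear_attention_reference(q, k, v)   # shape (2, 4, 16, 64)

Note the sketch materializes the full (n, d, e) prefix tensor, so it is memory-hungry; that is exactly why the CUDA kernel from fast_transformers is used for the numerator.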
setup.py (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'performer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.11.2',
+  version = '0.11.3',
   license='MIT',
   description = 'Performer - Pytorch',
   author = 'Phil Wang',
