diff --git a/performer_pytorch/performer_pytorch.py b/performer_pytorch/performer_pytorch.py
index 8fac32f..bd08509 100644
--- a/performer_pytorch/performer_pytorch.py
+++ b/performer_pytorch/performer_pytorch.py
@@ -128,11 +128,10 @@ def causal_linear_attention(q, k, v):
 # inefficient causal linear attention, without cuda code, for reader's reference
 # not being used
 def causal_linear_attention_noncuda(q, k, v):
-    k_cumsum = k.cumsum(dim=-2)
+    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k.cumsum(dim=-2))
     context = torch.einsum('...nd,...ne->...nde', k, v)
     context = context.cumsum(dim=-3)
-    context /= k_cumsum.unsqueeze(dim=-1)
-    out = torch.einsum('...nde,...nd->...ne', context, q)
+    out = torch.einsum('...nde,...nd,...n->...ne', context, q, D_inv)
     return out
 
 class FastAttention(nn.Module):
diff --git a/setup.py b/setup.py
index b5e8c5d..9a8872e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'performer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.7.4',
+  version = '0.7.5',
   license='MIT',
   description = 'Performer - Pytorch',
   author = 'Phil Wang',
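
Note (not part of the patch): a minimal sketch checking that the fused-einsum form above agrees with a direct per-position computation of causal linear attention, out_n = (q_n^T S_n) / (q_n . z_n) with S_n = sum_{i<=n} k_i v_i^T and z_n = sum_{i<=n} k_i. The reference helper name, tensor shapes, and tolerance below are illustrative assumptions, not code from the repository; only the first function mirrors the patched implementation.

import torch

def causal_linear_attention_noncuda(q, k, v):
    # patched version: the normalization is folded into the final einsum
    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k.cumsum(dim=-2))
    context = torch.einsum('...nd,...ne->...nde', k, v)
    context = context.cumsum(dim=-3)
    return torch.einsum('...nde,...nd,...n->...ne', context, q, D_inv)

def causal_linear_attention_reference(q, k, v):
    # hypothetical reference: compute each position explicitly
    # out_n = (q_n^T S_n) / (q_n . z_n), S_n = sum_{i<=n} k_i v_i^T, z_n = sum_{i<=n} k_i
    outs = []
    for n in range(q.shape[-2]):
        S_n = torch.einsum('...id,...ie->...de', k[..., :n + 1, :], v[..., :n + 1, :])
        z_n = k[..., :n + 1, :].sum(dim=-2)
        num = torch.einsum('...d,...de->...e', q[..., n, :], S_n)
        den = torch.einsum('...d,...d->...', q[..., n, :], z_n)
        outs.append(num / den.unsqueeze(-1))
    return torch.stack(outs, dim=-2)

# positive random features of shape (batch, heads, seq, dim), so denominators stay nonzero
q, k, v = (torch.rand(1, 2, 8, 4) for _ in range(3))
print(torch.allclose(causal_linear_attention_noncuda(q, k, v),
                     causal_linear_attention_reference(q, k, v), atol=1e-5))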