refactor transcribed jax code
lucidrains committed Nov 4, 2020
1 parent 8e470aa commit a23464e
Showing 1 changed file with 9 additions and 7 deletions.
performer_pytorch/performer_pytorch.py — 16 changes: 9 additions & 7 deletions
@@ -2,7 +2,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
-from einops import rearrange
+from einops import rearrange, repeat
 from functools import partial
 
 from local_attention import LocalAttention
@@ -34,17 +34,18 @@ def find_modules(nn_module, type):
 # https://github.com/google-research/google-research/blob/master/performer/fast_self_attention/fast_self_attention.py
 
 def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
+    b, h, *_ = data.shape
+
     if normalize_data:
         data_normalizer = 1.0 / (data.shape[-1] ** 0.25)
     else:
         data_normalizer = 1.0
 
     ratio = 1.0 / (projection_matrix.shape[0] ** 0.5)
 
-    data_mod_shape = data.shape[:(len(data.shape) - 2)] + projection_matrix.shape
-    data_thick_random_matrix = torch.zeros(data_mod_shape, device = device) + projection_matrix
+    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
 
-    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), data_thick_random_matrix)
+    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
 
     diag_data = data ** 2
     diag_data = torch.sum(diag_data, dim=-1)
@@ -62,6 +63,8 @@ def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, ep
     return data_dash
 
 def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None):
+    b, h, *_ = data.shape
+
     if normalize_data:
         data_normalizer = 1.0 / (data.shape[-1] ** 0.25)
     else:
@@ -70,10 +73,9 @@ def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel
     if projection_matrix is None:
         return kernel_fn(data_normalizer * data) + kernel_epsilon
 
-    data_mod_shape = data.shape[0:len(data.shape) - 2] + projection_matrix.shape
-    data_thick_random_matrix = torch.zeros(data_mod_shape, device = device) + projection_matrix
+    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
 
-    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), data_thick_random_matrix)
+    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
 
     data_prime = kernel_fn(data_dash) + kernel_epsilon
     return data_prime
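
A note on the change: the zeros-plus-broadcast trick being removed was transcribed from the JAX reference implementation, and einops.repeat builds the same (b, h, j, d) tiling of the projection matrix directly. A minimal sanity check of the equivalence (not part of the commit; shapes are made up for illustration):

import torch
from einops import repeat

b, h, n, d = 2, 4, 16, 32        # hypothetical batch, heads, sequence, head dim
nb_features = 64                 # hypothetical number of random features
data = torch.randn(b, h, n, d)
projection_matrix = torch.randn(nb_features, d)

# old: allocate zeros of shape (b, h, nb_features, d), broadcast-add the projection
data_mod_shape = data.shape[:-2] + projection_matrix.shape
data_thick_random_matrix = torch.zeros(data_mod_shape) + projection_matrix

# new: express the same tiling directly with einops
projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)

assert torch.equal(data_thick_random_matrix, projection)

# either tensor feeds the same einsum, so data_dash is identical in both kernels
data_dash = torch.einsum('...id,...jd->...ij', data, projection)
assert data_dash.shape == (b, h, n, nb_features)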
