
fix bug with reversible networks and redrawing projections on forward
lucidrains committed Dec 11, 2020
1 parent 31ac2e7 commit 79a3357
Showing 3 changed files with 17 additions and 12 deletions.
24 changes: 13 additions & 11 deletions performer_pytorch/performer_pytorch.py
@@ -41,7 +41,7 @@ def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, ep
 
     ratio = (projection_matrix.shape[0] ** -0.5)
 
-    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h).type_as(data)
+    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
 
     data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
@@ -176,16 +176,17 @@ def __init__(self, dim_heads, nb_features = None, feature_redraw_interval = 1000
                 print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
                 self.causal_linear_fn = causal_linear_attention_noncuda
 
-    def forward(self, q, k, v):
+    def forward(self, q, k, v, can_redraw_projection = True):
         device = q.device
 
-        # It's time to redraw the projection matrix
-        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
-            self.projection_matrix = self.create_projection(device = device)
-            self.calls_since_last_redraw = torch.tensor(0)
-        # Keep track of how many forward passes we do before we redraw again
-        else:
-            self.calls_since_last_redraw += 1
+        if can_redraw_projection:
+            # It's time to redraw the projection matrix
+            if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
+                self.projection_matrix.copy_(self.create_projection(device = device))
+                self.calls_since_last_redraw = torch.tensor(0)
+            # Keep track of how many forward passes we do before we redraw again
+            else:
+                self.calls_since_last_redraw += 1
 
         if self.generalized_attention:
             create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
@@ -283,8 +284,9 @@ def __init__(self, dim, causal = False, heads = 8, local_heads = 0, local_window
         self.to_out = nn.Linear(dim, dim)
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x, context = None, mask = None, context_mask = None):
+    def forward(self, x, context = None, mask = None, context_mask = None, **kwargs):
         b, n, _, h, gh = *x.shape, self.heads, self.global_heads
+        is_reverse = kwargs.pop('_reverse', False)
 
         cross_attend = exists(context)
         context = default(context, x)
@@ -302,7 +304,7 @@ def forward(self, x, context = None, mask = None, context_mask = None):
                 global_mask = context_mask[:, None, :, None]
                 k.masked_fill_(~global_mask, 0)
 
-            out = self.fast_attention(q, k, v)
+            out = self.fast_attention(q, k, v, can_redraw_projection = not is_reverse)
             attn_outs.append(out)
 
         if not empty(lq):
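
For context on what the guarded redraw above does, here is a minimal stand-alone sketch. The class name RandomProjectionHolder and its exact fields are illustrative only, not the library's actual FastAttention; it just shows the pattern the diff introduces: the random features live in a registered buffer, are redrawn in place with copy_ once a call-count interval is reached, and the redraw plus the counting are both skipped when can_redraw_projection is False.

import torch
from torch import nn

class RandomProjectionHolder(nn.Module):
    def __init__(self, dim, nb_features, feature_redraw_interval = 1000):
        super().__init__()
        self.feature_redraw_interval = feature_redraw_interval
        # track forward calls and hold the random features as buffers so they
        # follow the module across .to(device) / .half()
        self.register_buffer('calls_since_last_redraw', torch.tensor(0))
        self.register_buffer('projection_matrix', torch.randn(nb_features, dim))

    def create_projection(self, device = None):
        # draw a fresh random feature matrix matching the current buffer's shape
        return torch.randn(self.projection_matrix.shape, device = device)

    def forward(self, x, can_redraw_projection = True):
        if can_redraw_projection:
            if self.calls_since_last_redraw >= self.feature_redraw_interval:
                # update the existing buffer in place rather than rebinding the
                # attribute, so the registered buffer's dtype and device are kept
                self.projection_matrix.copy_(self.create_projection(device = x.device))
                self.calls_since_last_redraw.zero_()
            else:
                self.calls_since_last_redraw += 1
        # project the input with the (possibly redrawn) random features
        return x @ self.projection_matrix.t()

holder = RandomProjectionHolder(dim = 16, nb_features = 8, feature_redraw_interval = 2)
x = torch.randn(4, 16)
out = holder(x)                                  # normal call: counts toward the redraw interval
out = holder(x, can_redraw_projection = False)   # e.g. a reverse-pass call: no redraw, no counting
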
3 changes: 3 additions & 0 deletions performer_pytorch/reversible.py
@@ -60,6 +60,7 @@ def __init__(self, f, g):
     def forward(self, x, f_args = {}, g_args = {}):
         x1, x2 = torch.chunk(x, 2, dim=2)
         y1, y2 = None, None
+        f_args['_reverse'] = g_args['_reverse'] = False
 
         with torch.no_grad():
             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
@@ -74,6 +75,8 @@ def backward_pass(self, y, dy, f_args = {}, g_args = {}):
         dy1, dy2 = torch.chunk(dy, 2, dim=2)
         del dy
 
+        f_args['_reverse'] = g_args['_reverse'] = True
+
         with torch.enable_grad():
             y1.requires_grad = True
             gy1 = self.g(y1, set_rng=True, **g_args)
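
Why the _reverse flag matters: in a reversible network the backward pass reconstructs the block inputs by re-running f and g, so those functions must return exactly what they returned on the true forward pass; if the attention layer redrew its projection during that recomputation, the reconstruction (and hence the gradients) would be wrong. The toy sketch below uses assumed names (it is not the library's ReversibleBlock) to show the coupling, the reconstruction, and a stateful function that redraws its projection only when _reverse is False.

import torch

def reversible_forward(x1, x2, f, g, f_args = {}, g_args = {}):
    # forward coupling: y1 = x1 + f(x2), y2 = x2 + g(y1)
    f_args = {**f_args, '_reverse': False}
    g_args = {**g_args, '_reverse': False}
    y1 = x1 + f(x2, **f_args)
    y2 = x2 + g(y1, **g_args)
    return y1, y2

def reversible_reconstruct(y1, y2, f, g, f_args = {}, g_args = {}):
    # reverse pass: recover the inputs without having stored them,
    # relying on f and g producing the same outputs as on the forward pass
    f_args = {**f_args, '_reverse': True}
    g_args = {**g_args, '_reverse': True}
    x2 = y2 - g(y1, **g_args)
    x1 = y1 - f(x2, **f_args)
    return x1, x2

class StatefulFn:
    # toy stand-in for an attention layer that redraws a random projection
    # unless it is told the call belongs to the reverse pass
    def __init__(self, dim):
        self.proj = torch.randn(dim, dim)

    def __call__(self, x, _reverse = False):
        if not _reverse:
            self.proj = torch.randn_like(self.proj)  # redraw only on true forward calls
        return x @ self.proj

f, g = StatefulFn(8), StatefulFn(8)
x1, x2 = torch.randn(2, 8), torch.randn(2, 8)
y1, y2 = reversible_forward(x1, x2, f, g)
r1, r2 = reversible_reconstruct(y1, y2, f, g)
print(torch.allclose(x1, r1), torch.allclose(x2, r2))  # True True: inputs recovered exactly

If the _reverse guard is removed from StatefulFn, the reconstruction uses a freshly redrawn projection and the recovered x1, x2 no longer match the originals, which is precisely the failure mode this commit guards against.
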
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'performer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.12.8',
+  version = '0.12.9',
   license='MIT',
   description = 'Performer - Pytorch',
   author = 'Phil Wang',
