bet on new attention stabilization technique
lucidrains committed Feb 13, 2023
1 parent 2c680bb commit 72de0b9
Showing 3 changed files with 24 additions and 6 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -462,3 +462,12 @@ trainer.train()
copyright = {Creative Commons Attribution 4.0 International}
}
```

```bibtex
@misc{gilmer2023intriguing,
title = {Intriguing Properties of Transformer Training Instabilities},
author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
year = {2023},
status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
```
19 changes: 14 additions & 5 deletions phenaki_pytorch/attention.py
@@ -19,6 +19,9 @@ def default(val, d):
def leaky_relu(p = 0.1):
return nn.LeakyReLU(p)

def l2norm(t):
return F.normalize(t, dim = -1)

# bias-less layernorm, being used in more recent T5s, PaLM, also in @borisdayma 's experiments shared with me
# greater stability
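
A minimal sketch of a bias-free LayerNorm of the kind the comment above describes: only a learned gain is trained, and a fixed zero buffer is passed as the bias, so the normalization has no additive shift. The class below is an illustrative assumption and need not match the exact definition in attention.py.

```python
import torch
from torch import nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learned per-feature gain, no learned bias
        self.gamma = nn.Parameter(torch.ones(dim))
        # fixed zero bias, registered as a buffer so it follows the module's device / dtype
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        # normalize over the last dimension, apply the gain, add the (zero) bias
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
```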

@@ -92,12 +95,13 @@ def __init__(
causal = False,
num_null_kv = 0,
norm_context = True,
dropout = 0.
dropout = 0.,
scale = 8
):
super().__init__()
self.heads = heads
self.causal = causal
self.scale = dim_head ** -0.5
self.scale = scale
inner_dim = dim_head * heads
dim_context = default(dim_context, dim)

@@ -115,6 +119,9 @@ def __init__(
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)

self.q_scale = nn.Parameter(torch.ones(dim_head))
self.k_scale = nn.Parameter(torch.ones(dim_head))

self.to_out = nn.Linear(inner_dim, dim, bias = False)

def forward(
@@ -137,14 +144,16 @@ def forward(

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

q = q * self.scale

nk, nv = repeat(self.null_kv, 'h (n r) d -> b h n r d', b = batch, r = 2).unbind(dim = -2)

k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)

sim = einsum('b h i d, b h j d -> b h i j', q, k)
q, k = map(l2norm, (q, k))
q = q * self.q_scale
k = k * self.k_scale

sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

i, j = sim.shape[-2:]

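
As a standalone reference, here is a minimal sketch of the stabilization the diff above bets on, sometimes called cosine-similarity attention or qk normalization: queries and keys are l2-normalized along the head dimension, rescaled by learned per-dimension gains (q_scale, k_scale), and the similarity logits use a fixed scale (8 by default) instead of the usual dim_head ** -0.5. The module below is a simplified assumption for illustration; it drops the null key / values, masking, causal option, and cross-attention context handling of the full Attention class.

```python
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange

def l2norm(t):
    return F.normalize(t, dim = -1)

class CosineSimAttention(nn.Module):
    # illustrative name; the real class in attention.py is more featureful
    def __init__(self, dim, heads = 8, dim_head = 64, scale = 8):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = scale

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # learned per-dimension gains applied after l2 normalization
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        # the stabilization: unit-norm q and k, learned scales, fixed temperature
        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# usage
# attn = CosineSimAttention(dim = 512)
# tokens = torch.randn(2, 1024, 512)
# out = attn(tokens)   # (2, 1024, 512)
```

Because q and k are unit-normalized, their dot products are bounded in [-1, 1], so the fixed scale acts as a temperature on cosine similarities rather than compensating for growing query / key norms.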
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.2.0',
version = '0.3.0',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
