adopt a new promising attention stabilizing trick out of Google Brain
lucidrains committed Feb 13, 2023
1 parent 74b16bb commit e021548
Showing 4 changed files with 82 additions and 78 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -888,3 +888,12 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
     volume  = {abs/2210.02303}
 }
 ```
+
+```bibtex
+@misc{gilmer2023intriguing,
+    title   = {Intriguing Properties of Transformer Training Instabilities},
+    author  = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
+    year    = {2023},
+    status  = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
+}
+```
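
The technique this new citation refers to is visible in the diff below: instead of scaling queries by `dim_head ** -0.5`, queries and keys are l2-normalized along the head dimension, multiplied by learned per-dimension gains (`q_scale`, `k_scale`, initialized to ones), and the attention logits are multiplied by a fixed scale of 8. A minimal sketch of that idea in isolation; the function name and tensor layout here are illustrative, not part of the library:

```python
import torch
import torch.nn.functional as F
from torch import einsum

def qk_rmsnorm_attention(q, k, v, q_scale, k_scale, scale = 8):
    # q, k, v: (batch, heads, seq_len, dim_head)
    # q_scale, k_scale: learned gains of shape (dim_head,), initialized to ones

    # l2-normalize queries and keys along the head dimension, then apply the learned gains
    q = F.normalize(q, dim = -1) * q_scale
    k = F.normalize(k, dim = -1) * k_scale

    # fixed logit scale replaces the usual dim_head ** -0.5
    sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = sim.softmax(dim = -1)
    return einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 32, 64)
out = qk_rmsnorm_attention(q, k, v, torch.ones(64), torch.ones(64))  # (1, 8, 32, 64)
```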
72 changes: 35 additions & 37 deletions imagen_pytorch/imagen_pytorch.py
@@ -377,12 +377,10 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
-        cosine_sim_attn = False
+        scale = 8
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
-        self.cosine_sim_attn = cosine_sim_attn
-        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
+        self.scale = scale
 
         self.heads = heads
         inner_dim = dim_head * heads
@@ -393,6 +391,9 @@ def __init__(
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
 
+        self.q_scale = nn.Parameter(torch.ones(dim_head))
+        self.k_scale = nn.Parameter(torch.ones(dim_head))
+
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
             nn.LayerNorm(dim)
@@ -412,16 +413,15 @@ def forward(self, x, latents, mask = None):
 
         q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
 
-        q = q * self.scale
-
-        # cosine sim attention
+        # qk rmsnorm
 
-        if self.cosine_sim_attn:
-            q, k = map(l2norm, (q, k))
+        q, k = map(l2norm, (q, k))
+        q = q * self.q_scale
+        k = k * self.k_scale
 
         # similarities and masking
 
-        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale
+        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.scale
 
         if exists(mask):
             max_neg_value = -torch.finfo(sim.dtype).max
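
The rewritten forward pass above relies on an `l2norm` helper that is defined elsewhere in imagen_pytorch.py and does not appear in this diff. Assuming it simply unit-normalizes along the last dimension, a minimal equivalent would be:

```python
import torch.nn.functional as F

def l2norm(t):
    # unit-normalize along the last (feature) dimension
    return F.normalize(t, dim = -1)
```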
@@ -449,8 +449,7 @@ def __init__(
         num_latents = 64,
         num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
         max_seq_len = 512,
-        ff_mult = 4,
-        cosine_sim_attn = False
+        ff_mult = 4
     ):
         super().__init__()
         self.pos_emb = nn.Embedding(max_seq_len, dim)
@@ -469,7 +468,7 @@ def __init__(
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
+                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                 FeedForward(dim = dim, mult = ff_mult)
             ]))

@@ -502,12 +501,10 @@ def __init__(
         dim_head = 64,
         heads = 8,
         context_dim = None,
-        cosine_sim_attn = False
+        scale = 8
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
-        self.cosine_sim_attn = cosine_sim_attn
-        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
+        self.scale = scale
 
         self.heads = heads
         inner_dim = dim_head * heads
@@ -518,6 +515,9 @@ def __init__(
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
 
+        self.q_scale = nn.Parameter(torch.ones(dim_head))
+        self.k_scale = nn.Parameter(torch.ones(dim_head))
+
         self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None
 
         self.to_out = nn.Sequential(
@@ -533,7 +533,6 @@ def forward(self, x, context = None, mask = None, attn_bias = None):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
 
         q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
-        q = q * self.scale
 
         # add null key / value for classifier free guidance in prior net

@@ -549,14 +548,15 @@ def forward(self, x, context = None, mask = None, attn_bias = None):
             k = torch.cat((ck, k), dim = -2)
             v = torch.cat((cv, v), dim = -2)
 
-        # cosine sim attention
+        # qk rmsnorm
 
-        if self.cosine_sim_attn:
-            q, k = map(l2norm, (q, k))
+        q, k = map(l2norm, (q, k))
+        q = q * self.q_scale
+        k = k * self.k_scale
 
         # calculate query / key similarities
 
-        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale
+        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale
 
         # relative positional encoding (T5 style)
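
One way to read the fixed scale of 8: once q and k are unit-normalized, every query-key dot product lies in [-1, 1], so the pre-softmax logits are bounded by the scale regardless of `dim_head`, whereas the old `dim_head ** -0.5` scaling places no such bound on the logits. A quick illustrative check (ignoring the learned `q_scale` / `k_scale` gains, which start at ones and can loosen the bound as training proceeds):

```python
import torch
import torch.nn.functional as F

# unit-normalized queries and keys: dot products are bounded by 1 in magnitude
q = F.normalize(torch.randn(2, 8, 128, 64), dim = -1)
k = F.normalize(torch.randn(2, 8, 128, 64), dim = -1)

sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * 8

print(sim.abs().max().item())  # never meaningfully above 8.0, whatever dim_head is
```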

@@ -761,12 +761,10 @@ def __init__(
         dim_head = 64,
         heads = 8,
         norm_context = False,
-        cosine_sim_attn = False
+        scale = 8
     ):
         super().__init__()
-        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
-        self.cosine_sim_attn = cosine_sim_attn
-        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
+        self.scale = scale
 
         self.heads = heads
         inner_dim = dim_head * heads
@@ -780,6 +778,9 @@ def __init__(
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
 
+        self.q_scale = nn.Parameter(torch.ones(dim_head))
+        self.k_scale = nn.Parameter(torch.ones(dim_head))
+
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
             LayerNorm(dim)
@@ -802,16 +803,15 @@ def forward(self, x, context, mask = None):
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
-        q = q * self.scale
-
-        # cosine sim attention
+        # qk rmsnorm
 
-        if self.cosine_sim_attn:
-            q, k = map(l2norm, (q, k))
+        q, k = map(l2norm, (q, k))
+        q = q * self.q_scale
+        k = k * self.k_scale
 
         # similarities
 
-        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         # masking

@@ -994,15 +994,14 @@ def __init__(
         heads = 8,
         dim_head = 32,
         ff_mult = 2,
-        context_dim = None,
-        cosine_sim_attn = False
+        context_dim = None
     ):
         super().__init__()
         self.layers = nn.ModuleList([])
 
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn),
+                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
                 FeedForward(dim = dim, mult = ff_mult)
             ]))

@@ -1153,7 +1152,6 @@ def __init__(
         scale_skip_connection = True,
         final_resnet_block = True,
         final_conv_kernel_size = 3,
-        cosine_sim_attn = False,
         self_cond = False,
         resize_mode = 'nearest',
         combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
@@ -1265,7 +1263,7 @@ def __init__(
 
         # attention pooling
 
-        self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None
+        self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents) if attn_pool_text else None
 
         # for classifier free guidance

@@ -1288,7 +1286,7 @@ def __init__(
 
         # attention related params
 
-        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)
+        attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
 
         num_layers = len(in_out)
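
With this commit, callers drop the `cosine_sim_attn` flag entirely; the attention classes instead expose a fixed `scale` (default 8). A usage sketch based on the constructor and forward signatures shown in this diff; shapes are illustrative:

```python
import torch
from imagen_pytorch.imagen_pytorch import Attention

# before this commit: Attention(dim = 512, ..., cosine_sim_attn = True)
attn = Attention(dim = 512, dim_head = 64, heads = 8, scale = 8)

x = torch.randn(1, 1024, 512)  # (batch, sequence length, dim)
out = attn(x)                  # expected shape: (1, 1024, 512)
```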
