diff --git a/README.md b/README.md
index 1e46c5f..bf0735b 100644
--- a/README.md
+++ b/README.md
@@ -462,3 +462,12 @@ trainer.train()
     copyright = {Creative Commons Attribution 4.0 International}
 }
 ```
+
+```bibtex
+@misc{gilmer2023intriguing,
+    title   = {Intriguing Properties of Transformer Training Instabilities},
+    author  = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
+    year    = {2023},
+    status  = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
+}
+```
diff --git a/phenaki_pytorch/attention.py b/phenaki_pytorch/attention.py
index dfbc4cd..b780b2f 100644
--- a/phenaki_pytorch/attention.py
+++ b/phenaki_pytorch/attention.py
@@ -19,6 +19,9 @@ def default(val, d):
 def leaky_relu(p = 0.1):
     return nn.LeakyReLU(p)
 
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
 # bias-less layernorm, being used in more recent T5s, PaLM, also in @borisdayma 's experiments shared with me
 # greater stability
 
@@ -92,12 +95,13 @@ def __init__(
         causal = False,
         num_null_kv = 0,
         norm_context = True,
-        dropout = 0.
+        dropout = 0.,
+        scale = 8
     ):
         super().__init__()
         self.heads = heads
         self.causal = causal
-        self.scale = dim_head ** -0.5
+        self.scale = scale
 
         inner_dim = dim_head * heads
         dim_context = default(dim_context, dim)
@@ -115,6 +119,9 @@ def __init__(
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim_context, inner_dim * 2, bias = False)
 
+        self.q_scale = nn.Parameter(torch.ones(dim_head))
+        self.k_scale = nn.Parameter(torch.ones(dim_head))
+
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
     def forward(
@@ -137,14 +144,16 @@ def forward(
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
 
-        q = q * self.scale
-
         nk, nv = repeat(self.null_kv, 'h (n r) d -> b h n r d', b = batch, r = 2).unbind(dim = -2)
 
         k = torch.cat((nk, k), dim = -2)
         v = torch.cat((nv, v), dim = -2)
 
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+        q, k = map(l2norm, (q, k))
+        q = q * self.q_scale
+        k = k * self.k_scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         i, j = sim.shape[-2:]
 
diff --git a/setup.py b/setup.py
index 04c4024..9ae24c5 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.3.0',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
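
For readers skimming the diff: the attention change replaces the usual `dim_head ** -0.5` query scaling with l2-normalized queries and keys, each multiplied by a learned per-dimension gain, plus a fixed temperature (`scale = 8`) applied to the similarity logits, i.e. the qk-norm style stabilization referenced by the Gilmer et al. citation added to the README. Below is a minimal, self-contained sketch of that scoring step; the class name `QKNormAttention` and the standalone usage are illustrative only and not part of the package API.

```python
# Minimal sketch of the l2-normalized (qk-norm) attention scoring used in the diff above.
# The class name `QKNormAttention` and its shape assumptions are illustrative only;
# they are not part of phenaki-pytorch's public API.

import torch
import torch.nn.functional as F
from torch import nn, einsum

class QKNormAttention(nn.Module):
    def __init__(self, dim_head, scale = 8):
        super().__init__()
        # fixed temperature replacing the usual dim_head ** -0.5 scaling
        self.scale = scale
        # learned per-dimension gains applied after unit-normalizing q and k
        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

    def forward(self, q, k, v):
        # q, k, v: (batch, heads, seq, dim_head)
        q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)
        q = q * self.q_scale
        k = k * self.k_scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        return einsum('b h i j, b h j d -> b h i d', attn, v)

# usage on random tensors
attn = QKNormAttention(dim_head = 64)
q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))
out = attn(q, k, v)   # (1, 8, 16, 64)
```

Because q and k are unit vectors after normalization, the pre-softmax logits are bounded cosine similarities scaled by a fixed constant, which is intended to keep attention logits from growing unboundedly during training; the learned per-dimension gains restore some of the expressiveness removed by the normalization.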