From 899c3d9c5146576c372fcb03326308be182d3808 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 19 Nov 2023 14:50:06 -0800
Subject: [PATCH] finish last thought with the repository. will revisit attention for time series at the next sota

---
 README.md                                   |  32 ++-
 iTransformer/__init__.py                    |   4 +-
 iTransformer/iTransformerNormConditioned.py | 256 ++++++++++++++++++++
 setup.py                                    |   2 +-
 4 files changed, 290 insertions(+), 4 deletions(-)
 create mode 100644 iTransformer/iTransformerNormConditioned.py

diff --git a/README.md b/README.md
index d02e810..ab37bf5 100644
--- a/README.md
+++ b/README.md
@@ -112,13 +112,41 @@ preds = model(time_series)
 # -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
 ```
 
+### iTransformer with Normalization Statistics Conditioning
+
+Reversible instance normalization, but with all the statistics across variates concatenated and projected into a conditioning vector, which is then used for FiLM conditioning after each layernorm in the transformer.
+
+```python
+import torch
+from iTransformer import iTransformerNormConditioned
+
+# using solar energy settings
+
+model = iTransformerNormConditioned(
+    num_variates = 137,
+    lookback_len = 96,                  # or the lookback length in the paper
+    dim = 256,                          # model dimensions
+    depth = 6,                          # depth
+    heads = 8,                          # attention heads
+    dim_head = 64,                      # head dimension
+    pred_length = (12, 24, 36, 48),     # can be one prediction, or many
+    num_tokens_per_variate = 1,         # experimental setting that projects each variate to more than one token. the idea is that the network can learn to divide up into time tokens for more granular attention across time. thanks to flash attention, you should be able to accommodate long sequence lengths just fine
+)
+
+time_series = torch.randn(2, 96, 137)  # (batch, lookback len, variates)
+
+preds = model(time_series)
+
+# preds -> Dict[int, Tensor[batch, pred_length, variate]]
+# -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
+```
+
 ## Todo
 
 - [x] beef up the transformer with latest findings
 - [x] improvise a 2d version across both variates and time
 - [x] improvise a version that includes fft tokens
-
-- [ ] improvise a variant that uses adaptive normalization conditioned on statistics across all variates
+- [x] improvise a variant that uses adaptive normalization conditioned on statistics across all variates
 
 ## Citation
 
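To make the README blurb above concrete: the per-variate mean and variance that reversible instance normalization strips out are concatenated, projected into a single conditioning vector, and used to produce a FiLM scale and shift applied after each parameter-free layernorm. Below is a minimal sketch of just that mechanism; the dimensions, tensor names, and the fused `to_gamma_beta` projection are illustrative stand-ins, not the module's exact API.

```python
import torch
from torch import nn

dim, dim_cond, num_variates = 256, 1024, 137

# conditioning vector from concatenated per-variate statistics (mean, variance)
to_cond = nn.Sequential(nn.Linear(num_variates * 2, dim_cond), nn.SiLU())

# FiLM: predict a (gamma, beta) pair per feature from the conditioning vector
to_gamma_beta = nn.Linear(dim_cond, dim * 2)

# layernorm without learned affine - the scale/shift comes from FiLM instead
norm = nn.LayerNorm(dim, elementwise_affine = False)

x = torch.randn(2, 137, dim)                          # (batch, tokens, dim)
stats = torch.randn(2, 1, num_variates * 2)           # concatenated means and variances

cond = to_cond(stats)                                 # (2, 1, dim_cond)
gamma, beta = to_gamma_beta(cond).chunk(2, dim = -1)  # each (2, 1, dim)

x = norm(x) * gamma + beta                            # FiLM-conditioned layernorm output
```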
diff --git a/iTransformer/__init__.py b/iTransformer/__init__.py
index b0b19d2..deaecb3 100644
--- a/iTransformer/__init__.py
+++ b/iTransformer/__init__.py
@@ -5,4 +5,6 @@
 
 from iTransformer.iTransformer2D import iTransformer2D
 
-from iTransformer.iTransformerFFT import iTransformerFFT
\ No newline at end of file
+from iTransformer.iTransformerFFT import iTransformerFFT
+
+from iTransformer.iTransformerNormConditioned import iTransformerNormConditioned

diff --git a/iTransformer/iTransformerNormConditioned.py b/iTransformer/iTransformerNormConditioned.py
new file mode 100644
index 0000000..8838209
--- /dev/null
+++ b/iTransformer/iTransformerNormConditioned.py
@@ -0,0 +1,256 @@
+import torch
+from torch import nn, einsum, Tensor
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+
+from beartype import beartype
+from beartype.typing import Optional, Union, Tuple
+
+from einops import rearrange, reduce, repeat, pack, unpack
+from einops.layers.torch import Rearrange
+
+from iTransformer.attend import Attend
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def identity(t, *args, **kwargs):
+    return t
+
+def cast_tuple(t):
+    return (t,) if not isinstance(t, tuple) else t
+
+# attention
+
+class Attention(Module):
+    def __init__(
+        self,
+        dim,
+        dim_head = 32,
+        heads = 4,
+        dropout = 0.,
+        flash = True
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        dim_inner = dim_head * heads
+
+        self.to_qkv = nn.Sequential(
+            nn.Linear(dim, dim_inner * 3, bias = False),
+            Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
+        )
+
+        self.to_v_gates = nn.Sequential(
+            nn.Linear(dim, dim_inner, bias = False),
+            nn.SiLU(),
+            Rearrange('b n (h d) -> b h n d', h = heads)
+        )
+
+        self.attend = Attend(flash = flash, dropout = dropout)
+
+        self.to_out = nn.Sequential(
+            Rearrange('b h n d -> b n (h d)'),
+            nn.Linear(dim_inner, dim, bias = False),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x):
+        q, k, v = self.to_qkv(x)
+
+        out = self.attend(q, k, v)
+
+        out = out * self.to_v_gates(x)
+        return self.to_out(out)
+
+# feedforward
+
+class GEGLU(Module):
+    def forward(self, x):
+        x, gate = rearrange(x, '... (r d) -> r ... d', r = 2)
+        return x * F.gelu(gate)
+
+def FeedForward(dim, mult = 4, dropout = 0.):
+    dim_inner = int(dim * mult * 2 / 3)
+    return nn.Sequential(
+        nn.Linear(dim, dim_inner * 2),
+        GEGLU(),
+        nn.Dropout(dropout),
+        nn.Linear(dim_inner, dim)
+    )
+
+# film conditioning
+
+class FiLM(Module):
+    def __init__(self, dim_in, dim_cond):
+        super().__init__()
+        self.to_film_gamma_beta = nn.Sequential(
+            nn.Linear(dim_in, dim_cond * 2),
+            Rearrange('... (r d) -> r ... d', r = 2)
+        )
+
+    def forward(self, x, cond):
+        gamma, beta = self.to_film_gamma_beta(cond)
+        return x * gamma + beta
+
+# main class
+
+class iTransformerNormConditioned(Module):
+    @beartype
+    def __init__(
+        self,
+        *,
+        num_variates: int,
+        lookback_len: int,
+        depth: int,
+        dim: int,
+        num_tokens_per_variate = 1,
+        pred_length: Union[int, Tuple[int, ...]],
+        dim_head = 32,
+        heads = 4,
+        attn_dropout = 0.,
+        ff_mult = 4,
+        ff_dropout = 0.,
+        num_mem_tokens = 4,
+        flash_attn = True
+    ):
+        super().__init__()
+        self.num_variates = num_variates
+        self.lookback_len = lookback_len
+
+        self.mem_tokens = nn.Parameter(torch.randn(num_mem_tokens, dim)) if num_mem_tokens > 0 else None
+
+        pred_length = cast_tuple(pred_length)
+        self.pred_length = pred_length
+
+        dim_cond = dim * 4
+
+        self.to_norm_condition = nn.Sequential(
+            nn.Linear(num_variates * 2, dim_cond),
+            nn.SiLU()
+        )
+
+        self.layers = ModuleList([])
+        for _ in range(depth):
+            self.layers.append(ModuleList([
+                Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
+                nn.LayerNorm(dim, elementwise_affine = False),
+                FiLM(dim_cond, dim),
+                FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
+                nn.LayerNorm(dim, elementwise_affine = False),
+                FiLM(dim_cond, dim),
+            ]))
+
+        self.mlp_in = nn.Sequential(
+            nn.Linear(lookback_len, dim * num_tokens_per_variate),
+            Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variate),
+            nn.LayerNorm(dim)
+        )
+
+        self.pred_heads = ModuleList([])
+
+        for one_pred_length in pred_length:
+            head = nn.Sequential(
+                Rearrange('b (v n) d -> b v (n d)', n = num_tokens_per_variate),
+                nn.Linear(dim * num_tokens_per_variate, one_pred_length),
+                Rearrange('b v n -> b n v')
+            )
+
+            self.pred_heads.append(head)
+
+    @beartype
+    def forward(
+        self,
+        x: Tensor,
+        targets: Optional[Union[Tensor, Tuple[Tensor, ...]]] = None,
+        eps = 1e-5
+    ):
+        """
+        einstein notation
+
+        b - batch
+        n - time
+        v - variate
+        s - norm statistics
+        """
+        has_mem = exists(self.mem_tokens)
+        assert x.shape[1:] == (self.lookback_len, self.num_variates)
+
+        # the crux of the paper is basically treating variates as the spatial dimension in attention
+        # there is a lot of opportunity to improve on this, if the paper is successfully replicated
+
+        x = rearrange(x, 'b n v -> b v n')
+
+        # normalize
+
+        mean = x.mean(dim = -1, keepdim = True)
+        var = x.var(dim = -1, unbiased = False, keepdim = True)
+
+        x = (x - mean) * var.clamp(min = eps).rsqrt()
+
+        # concat statistics for adaptive layernorm
+
+        norm_stats = torch.cat((mean, var), dim = -1)
+        norm_stats = rearrange(norm_stats, 'b v s -> b 1 (v s)')
+
+        cond = self.to_norm_condition(norm_stats)
+
+        # mlp to tokens
+
+        x = self.mlp_in(x)
+
+        # memory tokens
+
+        if has_mem:
+            m = repeat(self.mem_tokens, 'm d -> b m d', b = x.shape[0])
+            x, mem_ps = pack([m, x], 'b * d')
+
+        # attention and feedforward layers
+
+        for attn, attn_post_norm, attn_film, ff, ff_post_norm, ff_film in self.layers:
+            x = attn(x) + x
+            x = attn_post_norm(x)
+            x = attn_film(x, cond)
+
+            x = ff(x) + x
+            x = ff_post_norm(x)
+            x = ff_film(x, cond)
+
+        # splice out memory tokens
+
+        if has_mem:
+            _, x = unpack(x, mem_ps, 'b * d')
+
+        # denormalize
+
+        x = (x * var.sqrt()) + mean
+
+        # predicting multiple times
+
+        pred_list = [fn(x) for fn in self.pred_heads]
+
+        # calculate loss if targets is passed in
+
+        if exists(targets):
+            targets = cast_tuple(targets)
+            assert len(targets) == len(pred_list)
+
+            assert self.training
+            mse_loss = 0.
+
+            for target, pred in zip(targets, pred_list):
+                assert target.shape == pred.shape
+
+                mse_loss = mse_loss + F.mse_loss(target, pred)
+
+            return mse_loss
+
+        if len(pred_list) == 1:
+            return pred_list[0]
+
+        pred_dict = dict(zip(self.pred_length, pred_list))
+        return pred_dict
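The conditioning signal in the forward pass above is built from exactly the statistics that are removed during instance normalization and restored at the end. Here is a standalone sketch of that path, using the same einops notation (b batch, v variates, n time, s statistics); the batch size and eps value are illustrative.

```python
import torch
from einops import rearrange

batch, lookback_len, num_variates = 2, 96, 137

x = torch.randn(batch, lookback_len, num_variates)
x = rearrange(x, 'b n v -> b v n')

# per-variate statistics over the lookback window
mean = x.mean(dim = -1, keepdim = True)                   # (b, v, 1)
var = x.var(dim = -1, unbiased = False, keepdim = True)   # (b, v, 1)

# reversible instance normalization
normed = (x - mean) * var.clamp(min = 1e-5).rsqrt()

# all statistics across variates, flattened into one vector per batch element
norm_stats = torch.cat((mean, var), dim = -1)             # (b, v, 2)
norm_stats = rearrange(norm_stats, 'b v s -> b 1 (v s)')  # (b, 1, v * 2) - input to the conditioning MLP

# predictions in the original scale are later recovered with the saved statistics
denormed = (normed * var.sqrt()) + mean
```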
diff --git a/setup.py b/setup.py
index 96a47cb..e0100d7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'iTransformer',
   packages = find_packages(exclude=[]),
-  version = '0.4.0',
+  version = '0.4.1',
   license='MIT',
   description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
   author = 'Phil Wang',
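As a closing note, the new module's `forward` also accepts a `targets` argument while in training mode, returning a summed MSE loss over all prediction heads (the branch guarded by `exists(targets)` above). A minimal training sketch under those assumptions; the batch size, horizons, and random tensors are placeholders only.

```python
import torch
from iTransformer import iTransformerNormConditioned

model = iTransformerNormConditioned(
    num_variates = 137,
    lookback_len = 96,
    dim = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    pred_length = (12, 24)
)

time_series = torch.randn(2, 96, 137)

# one target tensor per prediction length, each shaped (batch, pred_length, variates)
targets = (torch.randn(2, 12, 137), torch.randn(2, 24, 137))

loss = model(time_series, targets = targets)  # summed MSE across horizons
loss.backward()
```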