From e248cd7d82dc94c1641d9a98b3a408be23c6e9bf Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 29 Oct 2023 09:52:12 -0700
Subject: [PATCH] complete a 2d variant that does attention across both
 variates and time

---
 README.md                      |  31 +++++++-
 iTransformer/attend.py         |  30 +++----
 iTransformer/iTransformer2D.py | 141 +++++++++++++++++++++++++++------
 setup.py                       |   2 +-
 4 files changed, 160 insertions(+), 44 deletions(-)

diff --git a/README.md b/README.md
index 885517a..dd8310d 100644
--- a/README.md
+++ b/README.md
@@ -48,11 +48,38 @@ preds = model(time_series)
 # -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
 ```
 
+For an improvised version that does granular attention across the time tokens (as well as the original per-variate tokens), just import `iTransformer2D` and set the additional `num_time_tokens` argument.
+
+```python
+import torch
+from iTransformer import iTransformer2D
+
+# using solar energy settings
+
+model = iTransformer2D(
+    num_variates = 137,
+    num_time_tokens = 16,               # number of time tokens (patch size will be lookback length // num_time_tokens)
+    lookback_len = 96,                  # the lookback length in the paper
+    dim = 256,                          # model dimensions
+    depth = 6,                          # depth
+    heads = 8,                          # attention heads
+    dim_head = 64,                      # head dimension
+    pred_length = (12, 24, 36, 48),     # can be one prediction, or many
+    use_reversible_instance_norm = True # use reversible instance normalization
+)
+
+time_series = torch.randn(2, 96, 137)   # (batch, lookback len, variates)
+
+preds = model(time_series)
+
+# preds -> Dict[int, Tensor[batch, pred_length, variate]]
+# -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
+```
+
 ## Todo
 
 - [x] beef up the transformer with latest findings
-
-- [ ] improvise a 2d version - either global pool across time at end, or use a CLS token for attention pooling
+- [x] improvise a 2d version across both variates and time
 
 ## Citation
 
diff --git a/iTransformer/attend.py b/iTransformer/attend.py
index e0b2451..abf62ab 100644
--- a/iTransformer/attend.py
+++ b/iTransformer/attend.py
@@ -45,6 +45,7 @@ def __init__(
         heads = None,
         scale = None,
         flash = False,
+        causal = False
     ):
         super().__init__()
         self.scale = scale
@@ -52,6 +53,8 @@ def __init__(
         self.dropout = dropout
         self.attn_dropout = nn.Dropout(dropout)
 
+        self.causal = causal
+
         # flash attention
 
         self.flash = flash
@@ -81,17 +84,10 @@ def __init__(
 
     def flash_attn(
         self,
-        q, k, v,
-        mask = None
+        q, k, v
     ):
         batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
 
-        # expand key padding mask
-
-        if exists(mask):
-            assert mask.ndim == 4
-            mask = mask.expand(batch, heads, q_len, k_len)
-
         # Check if there is a compatible device for flash attention
 
         config = self.cuda_config if is_cuda else self.cpu_config
@@ -101,7 +97,7 @@ def flash_attn(
         with torch.backends.cuda.sdp_kernel(**config._asdict()):
             out = F.scaled_dot_product_attention(
                 q, k, v,
-                attn_mask = mask,
+                is_causal = self.causal,
                 dropout_p = self.dropout if self.training else 0.
             )
 
@@ -109,8 +105,7 @@ def flash_attn(
 
     def forward(
         self,
-        q, k, v,
-        mask = None
+        q, k, v
     ):
         """
         einstein notation
@@ -125,16 +120,15 @@ def forward(
         scale = default(self.scale, q.shape[-1] ** -0.5)
 
         if self.flash:
-            return self.flash_attn(q, k, v, mask = mask)
+            return self.flash_attn(q, k, v)
 
         sim = einsum(f'b h i d, b h j d -> b h i j', q, k) * scale
 
-        i, j, dtype = *sim.shape[-2:], sim.dtype
-
-        mask_value = -torch.finfo(sim.dtype).max
-
-        if exists(mask):
-            sim = sim.masked_fill(~mask, mask_value)
+        i, j, dtype = *sim.shape[-2:], sim.dtype
+
+        if self.causal:
+            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
+            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
 
         attn = sim.softmax(dim = -1)
         attn = attn.type(dtype)
diff --git a/iTransformer/iTransformer2D.py b/iTransformer/iTransformer2D.py
index eac686a..655132e 100644
--- a/iTransformer/iTransformer2D.py
+++ b/iTransformer/iTransformer2D.py
@@ -22,9 +22,18 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
 def identity(t, *args, **kwargs):
     return t
 
+def divisible_by(num, den):
+    return (num % den) == 0
+
 def cast_tuple(t):
     return (t,) if not isinstance(t, tuple) else t
 
@@ -37,12 +46,16 @@ def __init__(
         dim_head = 32,
         heads = 4,
         dropout = 0.,
-        flash = True
+        causal = False,
+        flash = True,
+        rotary_emb: Optional[RotaryEmbedding] = None,
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         dim_inner = dim_head * heads
 
+        self.rotary_emb = rotary_emb
+
         self.to_qkv = nn.Sequential(
             nn.Linear(dim, dim_inner * 3, bias = False),
             Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
@@ -54,7 +67,7 @@ def __init__(
             Rearrange('b n (h d) -> b h n d', h = heads)
         )
 
-        self.attend = Attend(flash = flash, dropout = dropout)
+        self.attend = Attend(flash = flash, dropout = dropout, causal = causal)
 
         self.to_out = nn.Sequential(
             Rearrange('b h n d -> b n (h d)'),
@@ -62,9 +75,13 @@ def __init__(
             nn.Dropout(dropout)
         )
 
+    @beartype
     def forward(self, x):
         q, k, v = self.to_qkv(x)
 
+        if exists(self.rotary_emb):
+            q, k = map(lambda t: self.rotary_emb.rotate_queries_or_keys(t), (q, k))
+
         out = self.attend(q, k, v)
 
         out = out * self.to_v_gates(x)
@@ -86,6 +103,39 @@ def FeedForward(dim, mult = 4, dropout = 0.):
         nn.Linear(dim_inner, dim)
     )
 
+# transformer block
+
+class TransformerBlock(Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        causal = False,
+        dim_head = 32,
+        heads = 8,
+        ff_mult = 4,
+        attn_dropout = 0.,
+        ff_dropout = 0.,
+        rotary_emb: Optional[RotaryEmbedding] = None,
+    ):
+        super().__init__()
+        self.rotary_emb = rotary_emb
+
+        self.attn = Attention(rotary_emb = rotary_emb, causal = causal, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
+        self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+        self.attn_norm = nn.LayerNorm(dim)
+        self.ff_norm = nn.LayerNorm(dim)
+
+    def forward(self, x):
+
+        x = self.attn(x) + x
+        x = self.attn_norm(x)
+
+        x = self.ff(x) + x
+        x = self.ff_norm(x)
+
+        return x
+
 # main class
 
 class iTransformer2D(Module):
@@ -95,9 +145,9 @@ def __init__(
         *,
         num_variates: int,
         lookback_len: int,
+        num_time_tokens: int,
         depth: int,
         dim: int,
-        num_tokens_per_variate = 1,
         pred_length: Union[int, Tuple[int, ...]],
         dim_head = 32,
         heads = 4,
@@ -109,8 +159,12 @@ def __init__(
         flash_attn = True
     ):
         super().__init__()
 
+        assert divisible_by(lookback_len, num_time_tokens)
+        assert num_time_tokens >= 2
+
         self.num_variates = num_variates
         self.lookback_len = lookback_len
+        self.num_time_tokens = num_time_tokens
 
         self.mem_tokens = nn.Parameter(torch.randn(num_mem_tokens, dim)) if num_mem_tokens > 0 else None
 
@@ -119,18 +173,33 @@ def __init__(
 
         self.reversible_instance_norm = RevIN(num_variates) if use_reversible_instance_norm else None
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.layers = ModuleList([])
+
+        block_kwargs = dict(
+            dim = dim,
+            dim_head = dim_head,
+            heads = heads,
+            ff_mult = ff_mult,
+            attn_dropout = attn_dropout,
+            ff_dropout = ff_dropout
+        )
+
         for _ in range(depth):
             self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
-                nn.LayerNorm(dim),
-                FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
-                nn.LayerNorm(dim)
+                TransformerBlock(causal = True, rotary_emb = self.rotary_emb, **block_kwargs),
+                TransformerBlock(causal = False, **block_kwargs)
             ]))
 
-        self.mlp_in = nn.Sequential(
-            nn.Linear(lookback_len, dim * num_tokens_per_variate),
-            Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variate),
+        self.to_variate_token = nn.Sequential(
+            nn.Linear(lookback_len, dim),
+            nn.LayerNorm(dim)
+        )
+
+        self.to_time_tokens = nn.Sequential(
+            Rearrange('b v (t n) -> b v t n', t = num_time_tokens),
+            nn.Linear(lookback_len // num_time_tokens, dim),
             nn.LayerNorm(dim)
         )
 
@@ -138,8 +207,7 @@ def __init__(
 
         for one_pred_length in pred_length:
             head = nn.Sequential(
-                Rearrange('b (v n) d -> b v (n d)', n = num_tokens_per_variate),
-                nn.Linear(dim * num_tokens_per_variate, one_pred_length),
+                nn.Linear(dim, one_pred_length),
                 Rearrange('b v n -> b n v')
             )
 
@@ -157,6 +225,7 @@ def forward(
         b - batch
         n - time
         v - variate
+        t - number of time tokens
         """
 
         has_mem = exists(self.mem_tokens)
@@ -170,35 +239,61 @@ def forward(
         if exists(self.reversible_instance_norm):
             x, reverse_fn = self.reversible_instance_norm(x)
 
-        x = self.mlp_in(x)
+        # derive the time tokens per variate 't'
+
+        t = self.to_time_tokens(x)
+
+        # 'v' will be the variate pool token, which is the same as the token per variate from iTransformer
+
+        v = self.to_variate_token(x)
+
+        # combine time and variate tokens into 2d feature map of variates and time
+
+        v = rearrange(v, 'b v d -> b v 1 d')
+        x = torch.cat((t, v), dim = -2)
 
         # memory tokens
 
         if has_mem:
-            m = repeat(self.mem_tokens, 'm d -> b m d', b = x.shape[0])
-            x, mem_ps = pack([m, x], 'b * d')
+            m = repeat(self.mem_tokens, 'm d -> b m t d', b = x.shape[0], t = x.shape[-2])
+            x, mem_ps = pack([m, x], 'b * t d')
 
         # attention and feedforward layers
 
-        for attn, attn_post_norm, ff, ff_post_norm in self.layers:
-            x = attn(x) + x
-            x = attn_post_norm(x)
-            x = ff(x) + x
-            x = ff_post_norm(x)
+        for time_attn_block, variate_attn_block in self.layers:
+            x, ps = pack_one(x, '* t d')
+
+            # causal attention across time for each variate
+            x = time_attn_block(x)
+
+            x = unpack_one(x, ps, '* t d')
+
+            x = rearrange(x, 'b v t d -> b t v d')
+            x, ps = pack_one(x, '* v d')
+
+            # full attention across variates (as in inverted Transformer paper)
+            x = variate_attn_block(x)
+
+            x = unpack_one(x, ps, '* v d')
+            x = rearrange(x, 'b t v d -> b v t d')
 
         # splice out memory tokens
 
         if has_mem:
-            _, x = unpack(x, mem_ps, 'b * d')
+            _, x = unpack(x, mem_ps, 'b * t d')
+
+        # get back the original variate pooled tokens
+
+        v = x[..., -1, :]
 
         # reversible instance normaization, if needed
 
         if exists(self.reversible_instance_norm):
-            x = reverse_fn(x)
+            v = reverse_fn(v)
 
         # predicting multiple times
 
-        pred_list = [fn(x) for fn in self.pred_heads]
+        pred_list = [fn(v) for fn in self.pred_heads]
 
         # calculate loss if targets is passed in
 
diff --git a/setup.py b/setup.py
index 40fa8b8..a611ce2 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'iTransformer',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.3.0',
   license='MIT',
   description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
   author = 'Phil Wang',
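
A minimal sketch of the alternating time/variate (axial) attention pattern that the new `iTransformer2D.forward` loop implements. `toy_attend` below is a hypothetical stand-in for the patch's `TransformerBlock` (the real time-axis block is causal and carries rotary embeddings, both omitted here for brevity), and the shapes follow the solar-energy example in the README, where the 2d feature map is `(batch, variates, num_time_tokens + 1, dim)` once the variate pool token is appended.

```python
import torch
from einops import rearrange, pack, unpack

def toy_attend(tokens):
    # stand-in for a TransformerBlock: plain softmax attention over the token axis
    scale = tokens.shape[-1] ** -0.5
    sim = torch.einsum('b i d, b j d -> b i j', tokens, tokens) * scale
    return torch.einsum('b i j, b j d -> b i d', sim.softmax(dim = -1), tokens)

b, v, t, d = 2, 137, 16 + 1, 256   # batch, variates, time tokens + variate pool token, model dim
x = torch.randn(b, v, t, d)        # the 2d feature map assembled in iTransformer2D.forward

# 1. attention across time, independently for each variate (causal in the actual patch)
x, ps = pack([x], '* t d')         # fold (batch, variate) into one leading axis
x = toy_attend(x)
x = unpack(x, ps, '* t d')[0]

# 2. attention across variates, independently for each time token (as in the inverted Transformer)
x = rearrange(x, 'b v t d -> b t v d')
x, ps = pack([x], '* v d')
x = toy_attend(x)
x = unpack(x, ps, '* v d')[0]
x = rearrange(x, 'b t v d -> b v t d')

print(x.shape)                     # torch.Size([2, 137, 17, 256])
```

The variate pool token sits at index `-1` of the time axis, which is why the patch reads it back with `v = x[..., -1, :]` before handing it to the prediction heads.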