complete a 2d variant that does attention across both variates and time
lucidrains committed Oct 29, 2023
1 parent 8d8b5c9 commit e248cd7
Showing 4 changed files with 160 additions and 44 deletions.
31 changes: 29 additions & 2 deletions README.md
@@ -48,11 +48,38 @@ preds = model(time_series)
# -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
```

For an improvised version that does granular attention across time tokens (as well as the original per-variate tokens), just import `iTransformer2D` and set the additional `num_time_tokens` argument.

```python
import torch
from iTransformer import iTransformer2D

# using solar energy settings

model = iTransformer2D(
num_variates = 137,
num_time_tokens = 16, # number of time tokens (patch size will be lookback_len // num_time_tokens)
lookback_len = 96, # the lookback length in the paper
dim = 256, # model dimensions
depth = 6, # depth
heads = 8, # attention heads
dim_head = 64, # head dimension
pred_length = (12, 24, 36, 48), # can be one prediction, or many
use_reversible_instance_norm = True # use reversible instance normalization
)

time_series = torch.randn(2, 96, 137) # (batch, lookback len, variates)

preds = model(time_series)

# preds -> Dict[int, Tensor[batch, pred_length, variate]]
# -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
```
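
Since `pred_length` is a tuple, the forward pass returns a dictionary keyed by horizon, as the comments above show. A minimal sketch of consuming it, continuing the example above (the `preds_24` name is just illustrative):

```python
preds = model(time_series)

# with num_time_tokens = 16 and lookback_len = 96, each time token covers
# a patch of 96 // 16 = 6 timesteps per variate

preds_24 = preds[24]    # forecast for the 24-step horizon -> (batch, pred_length, variates)

assert preds_24.shape == (2, 24, 137)
```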

## Todo

- [x] beef up the transformer with latest findings

- [ ] improvise a 2d version - either global pool across time at end, or use a CLS token for attention pooling
- [x] improvise a 2d version across both variates and time

## Citation

30 changes: 12 additions & 18 deletions iTransformer/attend.py
@@ -45,13 +45,16 @@ def __init__(
heads = None,
scale = None,
flash = False,
causal = False
):
super().__init__()
self.scale = scale

self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)

self.causal = causal

# flash attention

self.flash = flash
@@ -81,17 +84,10 @@ def __init__(

def flash_attn(
self,
q, k, v,
mask = None
q, k, v
):
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

# expand key padding mask

if exists(mask):
assert mask.ndim == 4
mask = mask.expand(batch, heads, q_len, k_len)

# Check if there is a compatible device for flash attention

config = self.cuda_config if is_cuda else self.cpu_config
@@ -101,16 +97,15 @@ def flash_attn(
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
is_causal = self.causal,
dropout_p = self.dropout if self.training else 0.
)

return out

def forward(
self,
q, k, v,
mask = None
q, k, v
):
"""
einstein notation
@@ -125,16 +120,15 @@ def forward(
scale = default(self.scale, q.shape[-1] ** -0.5)

if self.flash:
return self.flash_attn(q, k, v, mask = mask)
return self.flash_attn(q, k, v)

sim = einsum(f'b h i d, b h j d -> b h i j', q, k) * scale

i, j, dtype = *sim.shape[-2:], sim.dtype

mask_value = -torch.finfo(sim.dtype).max

if exists(mask):
sim = sim.masked_fill(~mask, mask_value)
if self.causal:
i, j, dtype = *sim.shape[-2:], sim.dtype
mask_value = -torch.finfo(sim.dtype).max
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, mask_value)

attn = sim.softmax(dim = -1)
attn = attn.type(dtype)
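
The non-flash branch above now builds the causal mask explicitly with `triu(j - i + 1)`, mirroring `is_causal = True` in the flash path. A standalone sketch (not repository code) of what that mask looks like for a short sequence:

```python
import torch

i = j = 4  # query and key lengths

causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

print(causal_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# True marks key positions strictly ahead of each query position; those entries are
# filled with a large negative value before softmax, so every time step attends only
# to itself and the past
```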
141 changes: 118 additions & 23 deletions iTransformer/iTransformer2D.py
@@ -22,9 +22,18 @@ def exists(v):
def default(v, d):
return v if exists(v) else d

def pack_one(t, pattern):
return pack([t], pattern)

def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]

def identity(t, *args, **kwargs):
return t

def divisible_by(num, den):
return (num % den) == 0

def cast_tuple(t):
return (t,) if not isinstance(t, tuple) else t

@@ -37,12 +46,16 @@ def __init__(
dim_head = 32,
heads = 4,
dropout = 0.,
flash = True
causal = False,
flash = True,
rotary_emb: Optional[RotaryEmbedding] = None,
):
super().__init__()
self.scale = dim_head ** -0.5
dim_inner = dim_head * heads

self.rotary_emb = rotary_emb

self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias = False),
Rearrange('b n (qkv h d) -> qkv b h n d', qkv = 3, h = heads)
@@ -54,17 +67,21 @@ def __init__(
Rearrange('b n (h d) -> b h n d', h = heads)
)

self.attend = Attend(flash = flash, dropout = dropout)
self.attend = Attend(flash = flash, dropout = dropout, causal = causal)

self.to_out = nn.Sequential(
Rearrange('b h n d -> b n (h d)'),
nn.Linear(dim_inner, dim, bias = False),
nn.Dropout(dropout)
)

@beartype
def forward(self, x):
q, k, v = self.to_qkv(x)

if exists(self.rotary_emb):
q, k = map(lambda t: self.rotary_emb.rotate_queries_or_keys(t), (q, k))

out = self.attend(q, k, v)

out = out * self.to_v_gates(x)
@@ -86,6 +103,39 @@ def FeedForward(dim, mult = 4, dropout = 0.):
nn.Linear(dim_inner, dim)
)

# transformer block

class TransformerBlock(Module):
def __init__(
self,
*,
dim,
causal = False,
dim_head = 32,
heads = 8,
ff_mult = 4,
attn_dropout = 0.,
ff_dropout = 0.,
rotary_emb: Optional[RotaryEmbedding] = None,
):
super().__init__()
self.rotary_emb = rotary_emb

self.attn = Attention(rotary_emb = rotary_emb, causal = causal, dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
self.attn_norm = nn.LayerNorm(dim)
self.ff_norm = nn.LayerNorm(dim)

def forward(self, x, rotary_emb: Optional[RotaryEmbedding] = None):

x = self.attn(x) + x
x = self.attn_norm(x)

x = self.ff(x) + x
x = self.ff_norm(x)

return x

# main class

class iTransformer2D(Module):
@@ -95,9 +145,9 @@ def __init__(
*,
num_variates: int,
lookback_len: int,
num_time_tokens: int,
depth: int,
dim: int,
num_tokens_per_variate = 1,
pred_length: Union[int, Tuple[int, ...]],
dim_head = 32,
heads = 4,
@@ -109,8 +159,12 @@ def __init__(
flash_attn = True
):
super().__init__()
assert divisible_by(lookback_len, num_time_tokens)
assert num_time_tokens >= 2

self.num_variates = num_variates
self.lookback_len = lookback_len
self.num_time_tokens = num_time_tokens

self.mem_tokens = nn.Parameter(torch.randn(num_mem_tokens, dim)) if num_mem_tokens > 0 else None

@@ -119,27 +173,41 @@ def __init__(

self.reversible_instance_norm = RevIN(num_variates) if use_reversible_instance_norm else None

self.rotary_emb = RotaryEmbedding(dim_head)

self.layers = ModuleList([])

block_kwargs = dict(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout
)

for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn),
nn.LayerNorm(dim),
FeedForward(dim, mult = ff_mult, dropout = ff_dropout),
nn.LayerNorm(dim)
TransformerBlock(causal = True, rotary_emb = self.rotary_emb, **block_kwargs),
TransformerBlock(causal = False, **block_kwargs)
]))

self.mlp_in = nn.Sequential(
nn.Linear(lookback_len, dim * num_tokens_per_variate),
Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variate),
self.to_variate_token = nn.Sequential(
nn.Linear(lookback_len, dim),
nn.LayerNorm(dim)
)

self.to_time_tokens = nn.Sequential(
Rearrange('b v (t n) -> b v t n', t = num_time_tokens),
nn.Linear(lookback_len // num_time_tokens, dim),
nn.LayerNorm(dim)
)

self.pred_heads = ModuleList([])

for one_pred_length in pred_length:
head = nn.Sequential(
Rearrange('b (v n) d -> b v (n d)', n = num_tokens_per_variate),
nn.Linear(dim * num_tokens_per_variate, one_pred_length),
nn.Linear(dim, one_pred_length),
Rearrange('b v n -> b n v')
)

@@ -157,6 +225,7 @@ def forward(
b - batch
n - time
v - variate
t - number of time tokens
"""

has_mem = exists(self.mem_tokens)
@@ -170,35 +239,61 @@ def forward(
if exists(self.reversible_instance_norm):
x, reverse_fn = self.reversible_instance_norm(x)

x = self.mlp_in(x)
# derive the time tokens per variate 't'

t = self.to_time_tokens(x)

# 'v' will be the variate pool token, which is the same as the token per variate from iTransformer

v = self.to_variate_token(x)

# combine time and variate tokens into 2d feature map of variates and time

v = rearrange(v, 'b v d -> b v 1 d')
x = torch.cat((t, v), dim = -2)

# memory tokens

if has_mem:
m = repeat(self.mem_tokens, 'm d -> b m d', b = x.shape[0])
x, mem_ps = pack([m, x], 'b * d')
m = repeat(self.mem_tokens, 'm d -> b m t d', b = x.shape[0], t = x.shape[-2])
x, mem_ps = pack([m, x], 'b * t d')

# attention and feedforward layers

for attn, attn_post_norm, ff, ff_post_norm in self.layers:
x = attn(x) + x
x = attn_post_norm(x)
x = ff(x) + x
x = ff_post_norm(x)
for time_attn_block, variate_attn_block in self.layers:
x, ps = pack_one(x, '* t d')

# causal attention across time for each variate
x = time_attn_block(x)

x = unpack_one(x, ps, '* t d')

x = rearrange(x, 'b v t d -> b t v d')
x, ps = pack_one(x, '* v d')

# full attention across variates (as in inverted Transformer paper)
x = variate_attn_block(x)

x = unpack_one(x, ps, '* v d')
x = rearrange(x, 'b t v d -> b v t d')

# splice out memory tokens

if has_mem:
_, x = unpack(x, mem_ps, 'b * d')
_, x = unpack(x, mem_ps, 'b * t d')

# get back the original variate pooled tokens

v = x[..., -1, :]

# reversible instance normalization, if needed

if exists(self.reversible_instance_norm):
x = reverse_fn(x)
v = reverse_fn(v)

# predicting multiple times

pred_list = [fn(x) for fn in self.pred_heads]
pred_list = [fn(v) for fn in self.pred_heads]

# calculate loss if targets is passed in

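The forward pass above alternates the attention axis by folding the other axis into the batch with `pack`/`unpack` and `rearrange`. A minimal standalone sketch of that shape bookkeeping (dimensions are illustrative and the attention calls are elided, so this is not the library code itself):

```python
import torch
from einops import rearrange, pack, unpack

b, v, t, d = 2, 137, 17, 256    # 16 time tokens + 1 variate pool token, per variate

x = torch.randn(b, v, t, d)

# causal attention across time: fold variates into the batch -> (b * v, t, d)
x, ps = pack([x], '* t d')
# ... time_attn_block(x) would run here ...
x = unpack(x, ps, '* t d')[0]

# full attention across variates: swap axes, fold time into the batch -> (b * t, v, d)
x = rearrange(x, 'b v t d -> b t v d')
x, ps = pack([x], '* v d')
# ... variate_attn_block(x) would run here ...
x = unpack(x, ps, '* v d')[0]
x = rearrange(x, 'b t v d -> b v t d')

assert x.shape == (b, v, t, d)
```
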
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'iTransformer',
packages = find_packages(exclude=[]),
version = '0.2.0',
version = '0.3.0',
license='MIT',
description = 'iTransformer - Inverted Transformers Are Effective for Time Series Forecasting',
author = 'Phil Wang',
