Merge pull request #11 from lucidrains/2d-attention
2d attention, attention across variates as well as time
lucidrains authored Oct 29, 2023
2 parents a2a6ddf + e248cd7 commit 8c587e5
Showing 7 changed files with 425 additions and 50 deletions.
31 changes: 29 additions & 2 deletions README.md
@@ -48,11 +48,38 @@ preds = model(time_series)
# -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
```

For an improvised version that does granular attention across time tokens (as well as the original per-variate tokens), just import `iTransformer2D` and set the additional `num_time_tokens` argument.

```python
import torch
from iTransformer import iTransformer2D

# using solar energy settings

model = iTransformer2D(
num_variates = 137,
num_time_tokens = 16, # number of time tokens (patch size will be (look back length // num_time_tokens))
lookback_len = 96, # the lookback length in the paper
dim = 256, # model dimensions
depth = 6, # depth
heads = 8, # attention heads
dim_head = 64, # head dimension
pred_length = (12, 24, 36, 48), # can be one prediction, or many
use_reversible_instance_norm = True # use reversible instance normalization
)

time_series = torch.randn(2, 96, 137) # (batch, lookback len, variates)

preds = model(time_series)

# preds -> Dict[int, Tensor[batch, pred_length, variate]]
# -> (12: (2, 12, 137), 24: (2, 24, 137), 36: (2, 36, 137), 48: (2, 48, 137))
```
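
A quick arithmetic note on the settings above (illustration only, not library code): the lookback window is split into `num_time_tokens` patches, so each time token here covers `96 // 16 = 6` time steps.

```python
# illustration of the patch-size relationship noted in the num_time_tokens comment above
lookback_len = 96
num_time_tokens = 16

patch_size = lookback_len // num_time_tokens   # 6 time steps per time token
# assumed for this sketch: the lookback length divides evenly into the time tokens
assert lookback_len % num_time_tokens == 0
```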

## Todo

- [x] beef up the transformer with latest findings

- [ ] improvise a 2d version - either global pool across time at end, or use a CLS token for attention pooling
- [x] improvise a 2d version across both variates and time

## Citation

2 changes: 2 additions & 0 deletions iTransformer/__init__.py
@@ -2,3 +2,5 @@
iTransformer,
RevIN
)

from iTransformer.iTransformer2D import iTransformer2D
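
With this export in place, both model classes should be importable from the package root, e.g.:

```python
# sketch: both variants imported from the package root after this change
from iTransformer import iTransformer, iTransformer2D
```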
30 changes: 12 additions & 18 deletions iTransformer/attend.py
@@ -45,13 +45,16 @@ def __init__(
heads = None,
scale = None,
flash = False,
causal = False
):
super().__init__()
self.scale = scale

self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)

self.causal = causal

# flash attention

self.flash = flash
@@ -81,17 +84,10 @@ def __init__(

def flash_attn(
self,
q, k, v,
mask = None
q, k, v
):
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

# expand key padding mask

if exists(mask):
assert mask.ndim == 4
mask = mask.expand(batch, heads, q_len, k_len)

# Check if there is a compatible device for flash attention

config = self.cuda_config if is_cuda else self.cpu_config
@@ -101,16 +97,15 @@ def flash_attn(
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
is_causal = self.causal,
dropout_p = self.dropout if self.training else 0.
)

return out

def forward(
self,
q, k, v,
mask = None
q, k, v
):
"""
einstein notation
@@ -125,16 +120,15 @@ def forward(
scale = default(self.scale, q.shape[-1] ** -0.5)

if self.flash:
return self.flash_attn(q, k, v, mask = mask)
return self.flash_attn(q, k, v)

sim = einsum(f'b h i d, b h j d -> b h i j', q, k) * scale

i, j, dtype = *sim.shape[-2:], sim.dtype

mask_value = -torch.finfo(sim.dtype).max

if exists(mask):
sim = sim.masked_fill(~mask, mask_value)
if self.causal:
i, j, dtype = *sim.shape[-2:], sim.dtype
mask_value = -torch.finfo(sim.dtype).max
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, mask_value)

attn = sim.softmax(dim = -1)
attn = attn.type(dtype)
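
The change above trades the explicit attention `mask` argument for a `causal` flag: the flash path forwards it as `is_causal` to `F.scaled_dot_product_attention`, while the non-flash path builds an upper-triangular mask itself. Below is a minimal standalone sketch of that masking step, mirroring the `triu`-based construction in the diff; the `apply_causal_mask` helper is hypothetical and not part of the module.

```python
import torch

def apply_causal_mask(sim: torch.Tensor) -> torch.Tensor:
    # sim: (batch, heads, query_len, key_len) attention logits
    i, j, device = *sim.shape[-2:], sim.device
    mask_value = -torch.finfo(sim.dtype).max
    # True strictly above the diagonal -> future positions a query may not attend to
    causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
    return sim.masked_fill(causal_mask, mask_value)

sim = torch.randn(2, 8, 4, 4)
masked = apply_causal_mask(sim)
print(masked[0, 0, 0, 1])   # large negative value, softmaxes to ~0
```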
31 changes: 2 additions & 29 deletions iTransformer/iTransformer.py
@@ -4,12 +4,13 @@
import torch.nn.functional as F

from beartype import beartype
from beartype.typing import Optional, Union, Tuple, Callable
from beartype.typing import Optional, Union, Tuple

from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

from iTransformer.attend import Attend
from iTransformer.revin import RevIN

# helper functions

@@ -25,34 +26,6 @@ def identity(t, *args, **kwargs):
def cast_tuple(t):
return (t,) if not isinstance(t, tuple) else t

# reversible instance normalization
# proposed in https://openreview.net/forum?id=cGDAkQo1C0p

class RevIN(Module):
def __init__(self, num_variates, eps = 1e-5):
super().__init__()
self.eps = eps
self.num_variates = num_variates
self.gamma = nn.Parameter(torch.ones(num_variates, 1))
self.beta = nn.Parameter(torch.zeros(num_variates, 1))

@beartype
def forward(self, x) -> Tuple[Tensor, Callable[Tensor, Tensor]]:
assert x.shape[1] == self.num_variates

var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
mean = torch.mean(x, dim = -1, keepdim = True)
var_rsqrt = var.clamp(min = self.eps).rsqrt()
instance_normalized = (x - mean) * var_rsqrt
rescaled = instance_normalized * self.gamma + self.beta

def reverse_fn(scaled_output):
clamped_gamma = torch.sign(self.gamma) * self.gamma.abs().clamp(min = self.eps)
unscaled_output = (scaled_output - self.beta) / clamped_gamma
return unscaled_output * var.sqrt() + mean

return rescaled, reverse_fn

# attention

class Attention(Module):
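
The `RevIN` block removed here is not deleted from the project: the new `from iTransformer.revin import RevIN` import near the top of this file's diff shows it was moved out into its own module. Based on the forward pass shown in the removed lines, here is a rough usage sketch of the normalize / reverse round trip, assuming `RevIN` remains exported from the package root as the `__init__.py` above suggests.

```python
import torch
from iTransformer import RevIN

revin = RevIN(num_variates = 137)

x = torch.randn(2, 137, 96)            # (batch, variates, time) - the removed forward asserts dim 1 == num_variates
normalized, reverse_fn = revin(x)      # instance-normalize each variate, return a closure that undoes it

reconstructed = reverse_fn(normalized) # applying the closure to the normalized tensor recovers x (up to eps)
print(torch.allclose(reconstructed, x, atol = 1e-4))
```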
