Skip to content

Commit

Permalink
allow for network to produce multiple tokens per variant, so to let i…
Browse files Browse the repository at this point in the history
…t decide how to divide up time
  • Loading branch information
lucidrains committed Oct 11, 2023
1 parent faa2594 commit 58e068b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions iTransformer/iTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ class iTransformer(Module):
def __init__(
self,
*,
num_variates,
lookback_len,
depth,
dim,
num_variates: int,
lookback_len: int,
depth: int,
dim: int,
num_tokens_per_variant = 1,
pred_length: Union[int, Tuple[int, ...]],
dim_head = 32,
heads = 4,
Expand Down Expand Up @@ -110,15 +111,17 @@ def __init__(
]))

self.mlp_in = nn.Sequential(
nn.Linear(lookback_len, dim),
nn.Linear(lookback_len, dim * num_tokens_per_variant),
Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variant),
nn.LayerNorm(dim)
)

self.pred_heads = ModuleList([])

for one_pred_length in pred_length:
head = nn.Sequential(
nn.Linear(dim, one_pred_length),
Rearrange('b (v n) d -> b v (n d)', n = num_tokens_per_variant),
nn.Linear(dim * num_tokens_per_variant, one_pred_length),
Rearrange('b v n -> b n v')
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'iTransformer',
packages = find_packages(exclude=[]),
version = '0.0.4',
version = '0.0.5',
license='MIT',
description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting',
author = 'Phil Wang',
Expand Down

0 comments on commit 58e068b

Please sign in to comment.