diff --git a/iTransformer/iTransformer.py b/iTransformer/iTransformer.py index adb195d..e3e5749 100644 --- a/iTransformer/iTransformer.py +++ b/iTransformer/iTransformer.py @@ -79,10 +79,11 @@ class iTransformer(Module): def __init__( self, *, - num_variates, - lookback_len, - depth, - dim, + num_variates: int, + lookback_len: int, + depth: int, + dim: int, + num_tokens_per_variant = 1, pred_length: Union[int, Tuple[int, ...]], dim_head = 32, heads = 4, @@ -110,7 +111,8 @@ def __init__( ])) self.mlp_in = nn.Sequential( - nn.Linear(lookback_len, dim), + nn.Linear(lookback_len, dim * num_tokens_per_variant), + Rearrange('b v (n d) -> b (v n) d', n = num_tokens_per_variant), nn.LayerNorm(dim) ) @@ -118,7 +120,8 @@ def __init__( for one_pred_length in pred_length: head = nn.Sequential( - nn.Linear(dim, one_pred_length), + Rearrange('b (v n) d -> b v (n d)', n = num_tokens_per_variant), + nn.Linear(dim * num_tokens_per_variant, one_pred_length), Rearrange('b v n -> b n v') ) diff --git a/setup.py b/setup.py index 5d07ac6..ab0c842 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'iTransformer', packages = find_packages(exclude=[]), - version = '0.0.4', + version = '0.0.5', license='MIT', description = 'iTransformer - Inverted Transformer Are Effective for Time Series Forecasting', author = 'Phil Wang',