TSMixer #1375

Open · wants to merge 2 commits into main

22 changes: 22 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,27 @@
# Release Notes

## v1.1.0 Adding TSMixer

### Added

- New TSMixer model, an all-MLP architecture reported to outperform the Temporal Fusion Transformer, based on [TSMixer: An All-MLP Architecture for Time Series Forecasting](https://arxiv.org/abs/2303.06053). See the usage sketch below.
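
A minimal usage sketch: `training` is an assumed, already-prepared `TimeSeriesDataSet`, and the snippet reflects the intended API of this PR (the `forward` pass is still a stub, so training will not produce predictions yet):

```python
import lightning.pytorch as pl

from pytorch_forecasting import TSMixer
from pytorch_forecasting.metrics import QuantileLoss

# `training` is assumed to be an existing TimeSeriesDataSet
model = TSMixer.from_dataset(
    training,
    hidden_size=16,
    dropout=0.1,
    loss=QuantileLoss(),  # the default loss; shown explicitly for clarity
)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(
    model,
    train_dataloaders=training.to_dataloader(train=True, batch_size=64),
)
```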

### Fixes

- Multiple small fixes

### Contributors

- jdb78
- jurgispods
- jacktang
- andre-marcos-perez
- tmxt
- bohdan-safoniuk
- maartensukel
- CahidArda
- MBelniak

## v1.0.0 Update to PyTorch 2.0 (10/04/2023)


2 changes: 2 additions & 0 deletions pytorch_forecasting/__init__.py
@@ -43,6 +43,7 @@
NHiTS,
RecurrentNetwork,
TemporalFusionTransformer,
TSMixer,
get_rnn,
)
from pytorch_forecasting.utils import (
@@ -69,6 +70,7 @@
"NBeats",
"NHiTS",
"Baseline",
"TSMixer",
"DeepAR",
"BaseModel",
"BaseModelWithCovariates",
2 changes: 2 additions & 0 deletions pytorch_forecasting/models/__init__.py
@@ -15,6 +15,7 @@
from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn
from pytorch_forecasting.models.rnn import RecurrentNetwork
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer
from pytorch_forecasting.models.tsmixer import TSMixer

__all__ = [
"NBeats",
@@ -32,4 +33,5 @@
"GRU",
"MultiEmbedding",
"DecoderMLP",
"TSMixer",
]
190 changes: 190 additions & 0 deletions pytorch_forecasting/models/tsmixer/__init__.py
@@ -0,0 +1,190 @@
"""
TSMixer is a comparatively simple all-MLP architecture that has been reported to outperform the Temporal Fusion Transformer.

Reference: `TSMixer: An All-MLP Architecture for Time Series Forecasting <https://arxiv.org/abs/2303.06053>`_
"""

from copy import copy
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from torch import nn
from torchmetrics import Metric as LightningMetric

from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import MAE, MAPE, RMSE, SMAPE, MultiHorizonMetric, QuantileLoss
from pytorch_forecasting.models.base_model import BaseModelWithCovariates
from pytorch_forecasting.models.nn import MultiEmbedding
from pytorch_forecasting.models.tsmixer.submodules import TSMixerEncoder


class TSMixer(BaseModelWithCovariates):
def __init__(
self,
hidden_size: int = 16,
lstm_layers: int = 1,
dropout: float = 0.1,
output_size: Union[int, List[int]] = 7,
loss: MultiHorizonMetric = None,
attention_head_size: int = 4,
max_encoder_length: int = 10,
static_categoricals: List[str] = [],
static_reals: List[str] = [],
time_varying_categoricals_encoder: List[str] = [],
time_varying_categoricals_decoder: List[str] = [],
categorical_groups: Dict[str, List[str]] = {},
time_varying_reals_encoder: List[str] = [],
time_varying_reals_decoder: List[str] = [],
x_reals: List[str] = [],
x_categoricals: List[str] = [],
hidden_continuous_size: int = 8,
hidden_continuous_sizes: Dict[str, int] = {},
embedding_sizes: Dict[str, Tuple[int, int]] = {},
embedding_paddings: List[str] = [],
embedding_labels: Dict[str, np.ndarray] = {},
learning_rate: float = 1e-3,
log_interval: Union[int, float] = -1,
log_val_interval: Union[int, float] = None,
log_gradient_flow: bool = False,
reduce_on_plateau_patience: int = 1000,
monotone_constaints: Dict[str, int] = {},
share_single_variable_networks: bool = False,
causal_attention: bool = True,
logging_metrics: nn.ModuleList = None,
**kwargs,
):
"""
TSMixer for forecasting timeseries - use its :py:meth:`~from_dataset` method if possible.

Implementation of the article
`TSMixer: An All-MLP Architecture for Time Series Forecasting <https://arxiv.org/abs/2303.06053>`_.

Args:
hidden_size: hidden size of the network; this is its main hyperparameter and typically ranges from 8 to 512
lstm_layers: number of LSTM layers (2 is mostly optimal)
dropout: dropout rate
output_size: number of outputs (e.g. number of quantiles for QuantileLoss and one target or list
of output sizes).
loss: loss function taking prediction and targets
attention_head_size: number of attention heads (4 is a good default)
max_encoder_length: length to encode (can be far longer than the decoder length but does not have to be)
static_categoricals: names of static categorical variables
static_reals: names of static continuous variables
time_varying_categoricals_encoder: names of categorical variables for encoder
time_varying_categoricals_decoder: names of categorical variables for decoder
time_varying_reals_encoder: names of continuous variables for encoder
time_varying_reals_decoder: names of continuous variables for decoder
categorical_groups: dictionary where values
are list of categorical variables that are forming together a new categorical
variable which is the key in the dictionary
x_reals: order of continuous variables in tensor passed to forward function
x_categoricals: order of categorical variables in tensor passed to forward function
hidden_continuous_size: default for hidden size for processing continuous variables (similar to categorical
embedding size)
hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection
(fallback to hidden_continuous_size if index is not in dictionary)
embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and
embedding size
embedding_paddings: list of indices for embeddings for which the embedding of the zero category is fixed to a zero vector
embedding_labels: dictionary mapping (string) indices to list of categorical labels
learning_rate: learning rate
log_interval: log predictions every x batches; do not log if 0 or less, log interpretation if > 0. If < 1.0,
multiple entries per batch are logged. Defaults to -1.
log_val_interval: frequency with which to log validation set metrics, defaults to log_interval
log_gradient_flow: whether to log gradient flow; this takes time and should only be done to diagnose training
failures
reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder
variables mapping
position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive,
larger numbers add more weight to the constraint vs. the loss but are usually not necessary).
This constraint significantly slows down training. Defaults to {}.
share_single_variable_networks (bool): whether to share the single variable networks between the encoder and
decoder. Defaults to False.
causal_attention (bool): whether to attend only to previous timesteps in the decoder or to also include future
predictions. Defaults to True.
logging_metrics (nn.ModuleList[LightningMetric]): list of metrics that are logged during training.
Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]).
**kwargs: additional arguments to :py:class:`~BaseModel`.
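
Example:
    A minimal construction sketch (assuming ``dataset`` is an already-prepared ``TimeSeriesDataSet``)::

        model = TSMixer.from_dataset(
            dataset,
            hidden_size=32,
            dropout=0.1,
            learning_rate=1e-3,
        )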
"""
if logging_metrics is None:
logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()])
if loss is None:
loss = QuantileLoss()
self.save_hyperparameters()
# store loss function separately as it is a module
assert isinstance(loss, LightningMetric), "Loss has to be a PyTorch Lightning `Metric`"
super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)

# processing inputs
# embeddings
self.input_embeddings = MultiEmbedding(
embedding_sizes=self.hparams.embedding_sizes,
categorical_groups=self.hparams.categorical_groups,
embedding_paddings=self.hparams.embedding_paddings,
x_categoricals=self.hparams.x_categoricals,
max_embedding_size=self.hparams.hidden_size,
)

@classmethod
def from_dataset(
cls,
dataset: TimeSeriesDataSet,
allowed_encoder_known_variable_names: List[str] = None,
**kwargs,
):
"""
Create model from dataset.

Args:
dataset: timeseries dataset
allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all
**kwargs: additional arguments such as hyperparameters for model (see ``__init__()``)

Returns:
TSMixer
"""
# add maximum encoder length
# update defaults
new_kwargs = copy(kwargs)
new_kwargs["max_encoder_length"] = dataset.max_encoder_length
new_kwargs.update(cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss()))

# create class and return
return super().from_dataset(
dataset, allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, **new_kwargs
)

def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Run the network. Input tensors in ``x`` have dimensions n_samples x time x variables; expected keys
include ``encoder_cat``, ``encoder_cont``, ``decoder_cat``, ``decoder_cont``, ``encoder_lengths``
and ``decoder_lengths``.
"""
encoder_lengths = x["encoder_lengths"]
decoder_lengths = x["decoder_lengths"]
x_cat = torch.cat([x["encoder_cat"], x["decoder_cat"]], dim=1) # concatenate in time dimension
x_cont = torch.cat([x["encoder_cont"], x["decoder_cont"]], dim=1) # concatenate in time dimension
# timesteps = x_cont.size(1) # encode + decode length
# max_encoder_length = int(encoder_lengths.max())
input_vectors = self.input_embeddings(x_cat)
input_vectors.update(
{
name: x_cont[..., idx].unsqueeze(-1)
for idx, name in enumerate(self.hparams.x_reals)
if name in self.reals
}
)
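# NOTE: the mixing network itself is not wired up yet. A hedged sketch of the
# intended flow (assuming the imported TSMixerEncoder maps the stacked per-variable
# embeddings to one hidden vector per timestep, and that a hypothetical
# self.output_layer projects hidden vectors to output_size) might look like:
#   embeddings = torch.cat(list(input_vectors.values()), dim=-1)
#   hidden = self.tsmixer_encoder(embeddings)
#   prediction = self.output_layer(hidden[:, -decoder_lengths.max() :])
# Until then, the network output below carries no prediction.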

return self.to_network_output(
# prediction=self.transform_output(output, target_scale=x["target_scale"]),
# encoder_attention=attn_output_weights[..., :max_encoder_length],
# decoder_attention=attn_output_weights[..., max_encoder_length:],
# static_variables=static_variable_selection,
# encoder_variables=encoder_sparse_weights,
# decoder_variables=decoder_sparse_weights,
decoder_lengths=decoder_lengths,
encoder_lengths=encoder_lengths,
)