Add more baselines (#148)
blisky-li authored Sep 26, 2024
1 parent da12828 commit 0c60982
Showing 58 changed files with 6,891 additions and 0 deletions.
154 changes: 154 additions & 0 deletions baselines/ETSformer/Electricity.py
@@ -0,0 +1,154 @@
import os
import sys
from easydict import EasyDict
sys.path.append(os.path.abspath(__file__ + '/../../..'))
from basicts.metrics import masked_mae, masked_mse
from basicts.data import TimeSeriesForecastingDataset
from basicts.runners import SimpleTimeSeriesForecastingRunner
from basicts.scaler import ZScoreScaler
from basicts.utils import get_regular_settings

from .arch import ETSformer

############################## Hot Parameters ##############################
# Dataset & Metrics configuration
DATA_NAME = 'Electricity' # Dataset name
regular_settings = get_regular_settings(DATA_NAME)
INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
# Model architecture and parameters
MODEL_ARCH = ETSformer
NUM_NODES = 321
MODEL_PARAM = {
"enc_in": NUM_NODES, # num nodes
"dec_in": NUM_NODES,
"c_out": NUM_NODES,
"seq_len": INPUT_LEN,
"label_len": INPUT_LEN/2, # start token length used in decoder
"pred_len": OUTPUT_LEN, # prediction sequence length
"factor": 3, # attn factor
"d_model": 512,
"moving_avg": 25, # window size of moving average. This is a CRUCIAL hyper-parameter.
"n_heads": 8,
"e_layers": 2, # num of encoder layers
"d_layers": 2, # num of decoder layers
"d_ff": 2048,
"K": 3,
"sigma" : 0.2,
"dropout": 0.2,
"output_attention": False,
"embed": "timeF", # [timeF, fixed, learned]
"activation": "sigmoid",
"num_time_features": 4, # number of used time features
"time_of_day_size": 24,
"day_of_week_size": 7,
"day_of_month_size": 31,
"day_of_year_size": 366
}
NUM_EPOCHS = 100

############################## General Configuration ##############################
CFG = EasyDict()
# General settings
CFG.DESCRIPTION = 'ETSformer baseline on the Electricity dataset'
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode)
# Runner
CFG.RUNNER = SimpleTimeSeriesForecastingRunner

############################## Dataset Configuration ##############################
CFG.DATASET = EasyDict()
# Dataset settings
CFG.DATASET.NAME = DATA_NAME
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
CFG.DATASET.PARAM = EasyDict({
'dataset_name': DATA_NAME,
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
'input_len': INPUT_LEN,
'output_len': OUTPUT_LEN,
# 'mode' is automatically set by the runner
})

############################## Scaler Configuration ##############################
CFG.SCALER = EasyDict()
# Scaler settings
CFG.SCALER.TYPE = ZScoreScaler # Scaler class
CFG.SCALER.PARAM = EasyDict({
'dataset_name': DATA_NAME,
'train_ratio': TRAIN_VAL_TEST_RATIO[0],
'norm_each_channel': NORM_EACH_CHANNEL,
'rescale': RESCALE,
})
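# The scaler's statistics are fit on the training split only (train_ratio) and reused for validation/test.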

############################## Model Configuration ##############################
CFG.MODEL = EasyDict()
# Model settings
CFG.MODEL.NAME = MODEL_ARCH.__name__
CFG.MODEL.ARCH = MODEL_ARCH
CFG.MODEL.PARAM = MODEL_PARAM
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2, 3, 4]
CFG.MODEL.TARGET_FEATURES = [0]
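# Features 0-4 are the target series plus the four time encodings configured above;
# loss and metrics are computed on feature 0 (the raw series) only.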

############################## Metrics Configuration ##############################

CFG.METRICS = EasyDict()
# Metrics settings
CFG.METRICS.FUNCS = EasyDict({
'MAE': masked_mae,
'MSE': masked_mse
})
CFG.METRICS.TARGET = 'MAE'
CFG.METRICS.NULL_VAL = NULL_VAL

############################## Training Configuration ##############################
CFG.TRAIN = EasyDict()
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
'checkpoints',
MODEL_ARCH.__name__,
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
)
CFG.TRAIN.LOSS = masked_mae
# Optimizer settings
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
"lr": 0.0001,
}
# Learning rate scheduler settings
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
"milestones": [1, 25, 50],
"gamma": 0.5
}
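# With the 1e-4 base LR above, MultiStepLR halves the rate at epochs 1, 25 and 50:
# 1e-4 -> 5e-5 -> 2.5e-5 -> 1.25e-5.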
CFG.TRAIN.CLIP_GRAD_PARAM = {
'max_norm': 5.0
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 64
CFG.TRAIN.DATA.SHUFFLE = True

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 64

############################## Test Configuration ##############################
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
CFG.TEST.DATA.BATCH_SIZE = 64

############################## Evaluation Configuration ##############################

CFG.EVAL = EasyDict()

# Evaluation parameters
CFG.EVAL.HORIZONS = [12, 24, 48, 96, 192, 288, 336]
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True
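
A quick way to sanity-check a config like this, outside the full runner, is to instantiate the architecture directly from MODEL_PARAM. A minimal sketch, assuming the repository root is on the Python path so the module and its relative .arch import resolve:

from baselines.ETSformer.Electricity import MODEL_ARCH, MODEL_PARAM

# Mirrors how the runner is expected to build the model: CFG.MODEL.ARCH(**CFG.MODEL.PARAM)
model = MODEL_ARCH(**MODEL_PARAM)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{MODEL_ARCH.__name__}: {num_params:,} trainable parameters")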
3 changes: 3 additions & 0 deletions baselines/ETSformer/arch/__init__.py
@@ -0,0 +1,3 @@
from .etsformer_arch import ETSformer

__all__ = ["ETSformer"]
84 changes: 84 additions & 0 deletions baselines/ETSformer/arch/decoder.py
@@ -0,0 +1,84 @@
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat


class DampingLayer(nn.Module):

def __init__(self, pred_len, nhead, dropout=0.1, output_attention=False):
super().__init__()
self.pred_len = pred_len
self.nhead = nhead
self.output_attention = output_attention
self._damping_factor = nn.Parameter(torch.randn(1, nhead))
self.dropout = nn.Dropout(dropout)

def forward(self, x):
x = repeat(x, 'b 1 d -> b t d', t=self.pred_len)
b, t, d = x.shape

powers = torch.arange(self.pred_len).to(self._damping_factor.device) + 1
powers = powers.view(self.pred_len, 1)
damping_factors = self.damping_factor ** powers
damping_factors = damping_factors.cumsum(dim=0)
x = x.view(b, t, self.nhead, -1)
x = self.dropout(x) * damping_factors.unsqueeze(-1)
x = x.view(b, t, d)
if self.output_attention:
return x, damping_factors
return x, None

@property
def damping_factor(self):
return torch.sigmoid(self._damping_factor)


class DecoderLayer(nn.Module):

def __init__(self, d_model, nhead, c_out, pred_len, dropout=0.1, output_attention=False):
super().__init__()
self.d_model = d_model
self.nhead = nhead
self.c_out = c_out
self.pred_len = pred_len
self.output_attention = output_attention

self.growth_damping = DampingLayer(pred_len, nhead, dropout=dropout, output_attention=output_attention)
self.dropout1 = nn.Dropout(dropout)

def forward(self, growth, season):
growth_horizon, growth_damping = self.growth_damping(growth[:, -1:])
growth_horizon = self.dropout1(growth_horizon)

seasonal_horizon = season[:, -self.pred_len:]

if self.output_attention:
return growth_horizon, seasonal_horizon, growth_damping
return growth_horizon, seasonal_horizon, None


class Decoder(nn.Module):

def __init__(self, layers):
super().__init__()
self.d_model = layers[0].d_model
self.c_out = layers[0].c_out
self.pred_len = layers[0].pred_len
self.nhead = layers[0].nhead

self.layers = nn.ModuleList(layers)
self.pred = nn.Linear(self.d_model, self.c_out)

def forward(self, growths, seasons):
growth_repr = []
season_repr = []
growth_dampings = []

for idx, layer in enumerate(self.layers):
growth_horizon, season_horizon, growth_damping = layer(growths[idx], seasons[idx])
growth_repr.append(growth_horizon)
season_repr.append(season_horizon)
growth_dampings.append(growth_damping)
growth_repr = sum(growth_repr)
season_repr = sum(season_repr)
return self.pred(growth_repr), self.pred(season_repr), growth_dampings
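
For intuition, DampingLayer realizes a damped-trend extrapolation: the learnable factor is squashed through a sigmoid, raised to the powers 1..pred_len, and the resulting per-step weights are cumulatively summed before scaling the last growth state. A standalone sketch of that arithmetic with toy values (single head, scalar growth state):

import torch

pred_len = 4
damping_param = torch.zeros(1, 1)                    # stand-in for the learnable _damping_factor
damping_factor = torch.sigmoid(damping_param)        # -> 0.5

powers = torch.arange(pred_len).view(pred_len, 1) + 1
weights = (damping_factor ** powers).cumsum(dim=0)   # [[0.5], [0.75], [0.875], [0.9375]]

last_growth = torch.tensor(2.0)                      # last growth state produced by the encoder
horizon = last_growth * weights.squeeze()            # tensor([1.0000, 1.5000, 1.7500, 1.8750])
print(horizon)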