Skip to content

Commit

Permalink
fix: 🐛 a bug of subset forecasting
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Sep 26, 2024
1 parent f08fd8f commit a3c2847
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion basicts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .launcher import launch_training, launch_evaluation
from .runners import BaseRunner

__version__ = '0.4.2'
__version__ = '0.4.3'

__all__ = ['__version__', 'launch_training', 'launch_evaluation', 'BaseRunner']
9 changes: 9 additions & 0 deletions basicts/runners/base_tsf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(self, cfg: Dict):
self.prediction_length = cfg['TRAIN'].CL.get('PREDICTION_LENGTH')
self.cl_step_size = cfg['TRAIN'].CL.get('STEP_SIZE', 1)

self.target_time_series = cfg['MODEL'].get('TARGET_TIME_SERIES', None)

# Eealuation settings
self.if_evaluate_on_gpu = cfg.get('EVAL', EasyDict()).get('USE_GPU', True)
self.evaluation_horizons = [_ - 1 for _ in cfg.get('EVAL', EasyDict()).get('HORIZONS', [])]
Expand Down Expand Up @@ -421,10 +423,17 @@ def postprocessing(self, input_data: Dict) -> Dict:
Dict: Processed data.
"""

# rescale data
if self.scaler is not None and self.scaler.rescale:
input_data['prediction'] = self.scaler.inverse_transform(input_data['prediction'])
input_data['target'] = self.scaler.inverse_transform(input_data['target'])
input_data['inputs'] = self.scaler.inverse_transform(input_data['inputs'])

# subset forecasting
if self.target_time_series is not None:
input_data['target'] = input_data['target'][:, :, self.target_time_series, :]
input_data['prediction'] = input_data['prediction'][:, :, self.target_time_series, :]

# TODO: add more postprocessing steps as needed.
return input_data

Expand Down
7 changes: 1 addition & 6 deletions basicts/runners/runner_zoo/simple_tsf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def __init__(self, cfg: Dict):
super().__init__(cfg)
self.forward_features = cfg['MODEL'].get('FORWARD_FEATURES', None)
self.target_features = cfg['MODEL'].get('TARGET_FEATURES', None)
self.target_time_series = cfg['MODEL'].get('TARGET_TIME_SERIES', None)

def select_input_features(self, data: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -108,12 +107,8 @@ def forward(self, data: Dict, epoch: int = None, iter_num: int = None, train: bo
if 'target' not in model_return:
model_return['target'] = self.select_target_features(future_data)

if self.target_time_series is not None:
model_return['target'] = self.select_target_time_series(model_return['target'])
model_return['prediction'] = self.select_target_time_series(model_return['prediction'])

# Ensure the output shape is correct
assert list(model_return['prediction'].shape)[:3] == [batch_size, length, num_nodes if self.target_time_series is None else len(self.target_time_series)], \
assert list(model_return['prediction'].shape)[:3] == [batch_size, length, num_nodes], \
"The shape of the output is incorrect. Ensure it matches [B, L, N, C]."

return model_return

0 comments on commit a3c2847

Please sign in to comment.