[MNT] isolate optuna and statsmodels as soft dependencies #1629

Merged · 13 commits · Sep 4, 2024
4 changes: 1 addition & 3 deletions pyproject.toml
@@ -58,13 +58,11 @@ dependencies = [
"numpy<2.0.0",
"torch >=2.0.0,!=2.0.1,<3.0.0",
"lightning >=2.0.0,<3.0.0",
"optuna >=3.1.0,<3.3.0",
"scipy >=1.8,<2.0",
"pandas >=1.3.0,<3.0.0",
"scikit-learn >=1.2,<2.0",
"matplotlib",
"statsmodels",
"pytorch_optimizer >=2.5.1,<4.0.0",
"pytorch-optimizer >=2.5.1,<4.0.0",
]

[project.optional-dependencies]
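Note: with optuna and statsmodels dropped from the core install, downstream code has to detect them at runtime. A minimal sketch of such a check using only the standard library (illustrative only; the body of the `_get_installed_packages` helper used in the diff below is not shown in this PR):

```python
from importlib.metadata import distributions

# Names of all installed distributions; name normalization
# (e.g. "pytorch-optimizer" vs "pytorch_optimizer") is ignored here.
installed = {dist.metadata["Name"] for dist in distributions()}

for pkg in ("optuna", "statsmodels"):
    if pkg not in installed:
        print(f"{pkg} is missing; run `pip install {pkg}` to enable tuning")
```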
38 changes: 25 additions & 13 deletions pytorch_forecasting/models/temporal_fusion_transformer/tuning.py
@@ -12,25 +12,16 @@
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
import numpy as np
-import optuna
-from optuna.integration import PyTorchLightningPruningCallback
-import optuna.logging
-import statsmodels.api as sm
import torch
from torch.utils.data import DataLoader

from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
+from pytorch_forecasting.utils._dependencies import _get_installed_packages

optuna_logger = logging.getLogger("optuna")


-# need to inherit from callback for this to work
-class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback):
-    pass


def optimize_hyperparameters(
    train_dataloaders: DataLoader,
    val_dataloaders: DataLoader,
@@ -47,11 +38,11 @@ def optimize_hyperparameters(
    use_learning_rate_finder: bool = True,
    trainer_kwargs: Dict[str, Any] = {},
    log_dir: str = "lightning_logs",
-   study: optuna.Study = None,
+   study=None,
    verbose: Union[int, bool] = None,
-   pruner: optuna.pruners.BasePruner = optuna.pruners.SuccessiveHalvingPruner(),
+   pruner=None,
    **kwargs,
-) -> optuna.Study:
+):
"""
Optimize Temporal Fusion Transformer hyperparameters.

@@ -96,6 +87,27 @@
    Returns:
        optuna.Study: optuna study results
    """
+   pkgs = _get_installed_packages()
+
+   if "optuna" not in pkgs or "statsmodels" not in pkgs:
+       raise ImportError(
+           "optimize_hyperparameters requires optuna and statsmodels. "
+           "Please install these packages with `pip install optuna statsmodels`. "
+           "From optuna 3.3.0, optuna-integration is also required."
+       )
+
+   import optuna
+   from optuna.integration import PyTorchLightningPruningCallback
+   import optuna.logging
+   import statsmodels.api as sm
+
+   # need to inherit from callback for this to work
+   class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback):  # noqa: E501
+       pass
+
+   if pruner is None:
+       pruner = optuna.pruners.SuccessiveHalvingPruner()
+
    assert isinstance(train_dataloaders.dataset, TimeSeriesDataSet) and isinstance(
        val_dataloaders.dataset, TimeSeriesDataSet
    ), "dataloaders must be built from timeseriesdataset"
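With this change, importing pytorch_forecasting no longer pulls in optuna or statsmodels; the guard above only fires when tuning is actually requested. A hypothetical usage sketch (the dataloader variables and the `model_path` value are placeholders, not taken from this diff):

```python
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
    optimize_hyperparameters,
)

# train_dataloader / val_dataloader are assumed to exist already,
# built from a TimeSeriesDataSet as in the pytorch-forecasting tutorials
try:
    study = optimize_hyperparameters(
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        model_path="optuna_tft",  # placeholder output directory
        n_trials=20,
        pruner=None,  # resolved to optuna.pruners.SuccessiveHalvingPruner() inside
    )
    print(study.best_trial.params)
except ImportError as err:
    # raised by the new guard when optuna or statsmodels is not installed
    print(err)
```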
13 changes: 11 additions & 2 deletions tests/test_models/test_temporal_fusion_transformer.py
@@ -381,9 +381,18 @@ def test_prediction_with_dataframe(model, data_with_covariates):
    model.predict(data_with_covariates, fast_dev_run=True)


+SKIP_HYPERPARAM_TEST = (
+    sys.platform.startswith("win")
+    # Test skipped on Windows OS due to issues with ddp, see #1632
+    or "optuna" not in _get_installed_packages()
+    or "statsmodels" not in _get_installed_packages()
+    # Test skipped if required package optuna or statsmodels is not available
+)


@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Test skipped on Windows OS due to issues with ddp, see #1632",
SKIP_HYPEPARAM_TEST,
reason="Test skipped on Win due to bug #1632, or if missing required packages",
)
@pytest.mark.parametrize("use_learning_rate_finder", [True, False])
def test_hyperparameter_optimization_integration(dataloaders_with_covariates, tmp_path, use_learning_rate_finder):
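For reference, pytest also ships a built-in way to express the same soft-dependency skip; a sketch of that alternative (not what this PR uses):

```python
import pytest


def test_tuning_soft_deps_importorskip():
    # importorskip skips the test (instead of failing) when the
    # package is absent, and returns the imported module otherwise
    optuna = pytest.importorskip("optuna")
    pytest.importorskip("statsmodels")
    assert hasattr(optuna, "create_study")
```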