diff --git a/pyproject.toml b/pyproject.toml index f8e2b167..4350d9b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,13 +58,11 @@ dependencies = [ "numpy<2.0.0", "torch >=2.0.0,!=2.0.1,<3.0.0", "lightning >=2.0.0,<3.0.0", - "optuna >=3.1.0,<3.3.0", "scipy >=1.8,<2.0", "pandas >=1.3.0,<3.0.0", "scikit-learn >=1.2,<2.0", "matplotlib", - "statsmodels", - "pytorch_optimizer >=2.5.1,<4.0.0", + "pytorch-optimizer >=2.5.1,<4.0.0", ] [project.optional-dependencies] diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index 2bc06fc8..7f2dfab8 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -12,25 +12,16 @@ from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.tuner import Tuner import numpy as np -import optuna -from optuna.integration import PyTorchLightningPruningCallback -import optuna.logging -import statsmodels.api as sm -import torch from torch.utils.data import DataLoader from pytorch_forecasting import TemporalFusionTransformer from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.metrics import QuantileLoss +from pytorch_forecasting.utils._dependencies import _get_installed_packages optuna_logger = logging.getLogger("optuna") -# need to inherit from callback for this to work -class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback): - pass - - def optimize_hyperparameters( train_dataloaders: DataLoader, val_dataloaders: DataLoader, @@ -47,11 +38,11 @@ def optimize_hyperparameters( use_learning_rate_finder: bool = True, trainer_kwargs: Dict[str, Any] = {}, log_dir: str = "lightning_logs", - study: optuna.Study = None, + study=None, verbose: Union[int, bool] = None, - pruner: optuna.pruners.BasePruner = optuna.pruners.SuccessiveHalvingPruner(), + pruner=None, **kwargs, -) -> 
optuna.Study: +): """ Optimize Temporal Fusion Transformer hyperparameters. @@ -96,6 +87,27 @@ def optimize_hyperparameters( Returns: optuna.Study: optuna study results """ + pkgs = _get_installed_packages() + + if "optuna" not in pkgs or "statsmodels" not in pkgs: + raise ImportError( + "optimize_hyperparameters requires optuna and statsmodels. " + "Please install these packages with `pip install optuna statsmodels`. " + "From optuna 3.3.0, optuna-integration is also required." + ) + + import optuna + from optuna.integration import PyTorchLightningPruningCallback + import optuna.logging + import statsmodels.api as sm + + # need to inherit from callback for this to work + class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback): # noqa: E501 + pass + + if pruner is None: + pruner = optuna.pruners.SuccessiveHalvingPruner() + assert isinstance(train_dataloaders.dataset, TimeSeriesDataSet) and isinstance( val_dataloaders.dataset, TimeSeriesDataSet ), "dataloaders must be built from timeseriesdataset" diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py index 04210806..c2aa49a7 100644 --- a/tests/test_models/test_temporal_fusion_transformer.py +++ b/tests/test_models/test_temporal_fusion_transformer.py @@ -381,9 +381,18 @@ def test_prediction_with_dataframe(model, data_with_covariates): model.predict(data_with_covariates, fast_dev_run=True) +SKIP_HYPERPARAM_TEST = ( + sys.platform.startswith("win") + # Test skipped on Windows OS due to issues with ddp, see #1632 + or "optuna" not in _get_installed_packages() + or "statsmodels" not in _get_installed_packages() + # Test skipped if required package optuna or statsmodels not available +) + + @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Test skipped on Windows OS due to issues with ddp, see #1632", + SKIP_HYPERPARAM_TEST, + reason="Test skipped on Win due to bug #1632, or if missing required packages",
) @pytest.mark.parametrize("use_learning_rate_finder", [True, False]) def test_hyperparameter_optimization_integration(dataloaders_with_covariates, tmp_path, use_learning_rate_finder):