From f5e15a41c7ff08ef0b69b9dff64ca5a5ca0b0aab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 29 Aug 2024 12:51:00 +0100 Subject: [PATCH 1/8] add deputils --- pytorch_forecasting/utils/__init__.py | 33 +++++++++++++++ pytorch_forecasting/utils/_dependencies.py | 40 +++++++++++++++++++ .../{utils.py => utils/_utils.py} | 0 3 files changed, 73 insertions(+) create mode 100644 pytorch_forecasting/utils/__init__.py create mode 100644 pytorch_forecasting/utils/_dependencies.py rename pytorch_forecasting/{utils.py => utils/_utils.py} (100%) diff --git a/pytorch_forecasting/utils/__init__.py b/pytorch_forecasting/utils/__init__.py new file mode 100644 index 00000000..10741943 --- /dev/null +++ b/pytorch_forecasting/utils/__init__.py @@ -0,0 +1,33 @@ +""" +PyTorch Forecasting package for timeseries forecasting with PyTorch. +""" + +from pytorch_forecasting.utils._utils import ( + apply_to_list, + autocorrelation, + create_mask, + detach, + get_embedding_size, + groupby_apply, + integer_histogram, + move_to_device, + profile, + to_list, + unpack_sequence, +) + +__all__ = [ + "apply_to_list", + "autocorrelation", + "get_embedding_size", + "create_mask", + "to_list", + "RecurrentNetwork", + "DecoderMLP", + "detach", + "move_to_device", + "integer_histogram", + "groupby_apply", + "profile", + "unpack_sequence", +] diff --git a/pytorch_forecasting/utils/_dependencies.py b/pytorch_forecasting/utils/_dependencies.py new file mode 100644 index 00000000..3d4f1279 --- /dev/null +++ b/pytorch_forecasting/utils/_dependencies.py @@ -0,0 +1,40 @@ +"""Utilities for managing dependencies. + +Copied from sktime/skbase. +""" + +from functools import lru_cache + + +@lru_cache +def _get_installed_packages_private(): + """Get a dictionary of installed packages and their versions. + + Same as _get_installed_packages, but internal to avoid mutating the lru_cache + by accident. + """ + from importlib.metadata import distributions, version + + dists = distributions() + package_names = {dist.metadata["Name"] for dist in dists} + package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names} + # developer note: + # we cannot just use distributions naively, + # because the same top level package name may appear *twice*, + # e.g., in a situation where a virtual env overrides a base env, + # such as in deployment environments like databricks. + # the "version" contract ensures we always get the version that corresponds + # to the importable distribution, i.e., the top one in the sys.path. + return package_versions + + +def _get_installed_packages(): + """Get a dictionary of installed packages and their versions. 
+ + Returns + ------- + dict : dictionary of installed packages and their versions + keys are PEP 440 compatible package names, values are package versions + MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3" + """ + return _get_installed_packages_private().copy() diff --git a/pytorch_forecasting/utils.py b/pytorch_forecasting/utils/_utils.py similarity index 100% rename from pytorch_forecasting/utils.py rename to pytorch_forecasting/utils/_utils.py From bedd86e378806e8ce3df8188c5d92babd89b2801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 29 Aug 2024 12:51:24 +0100 Subject: [PATCH 2/8] isolate --- pyproject.toml | 8 +++- .../temporal_fusion_transformer/tuning.py | 38 ++++++++++++------- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3c5ed511..f07d52ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,17 +58,21 @@ dependencies = [ "numpy<2.0.0", "torch >=2.0.0,!=2.0.1,<3.0.0", "lightning >=2.0.0,<3.0.0", - "optuna >=3.1.0,<3.3.0", "scipy >=1.8,<2.0", "pandas >=1.3.0,<3.0.0", "scikit-learn >=1.2,<2.0", "matplotlib", - "statsmodels", "pytorch-optimizer >=2.5.1,<3.0.0", ] [project.optional-dependencies] +all_extras = [ + "optuna >=3.1.0,<4.0.0", + "optuna-integration", + "statsmodels", +] + dev = [ "pydocstyle >=6.1.1,<7.0.0", # checks and make tools diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index 2bc06fc8..bba49803 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -12,25 +12,16 @@ from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.tuner import Tuner import numpy as np -import optuna -from optuna.integration import PyTorchLightningPruningCallback -import optuna.logging -import statsmodels.api as sm -import torch from torch.utils.data import DataLoader from pytorch_forecasting import TemporalFusionTransformer from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.metrics import QuantileLoss +from pytorch_forecastingu.utils._dependencies import _get_installed_packages optuna_logger = logging.getLogger("optuna") -# need to inherit from callback for this to work -class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback): - pass - - def optimize_hyperparameters( train_dataloaders: DataLoader, val_dataloaders: DataLoader, @@ -47,11 +38,11 @@ def optimize_hyperparameters( use_learning_rate_finder: bool = True, trainer_kwargs: Dict[str, Any] = {}, log_dir: str = "lightning_logs", - study: optuna.Study = None, + study=None, verbose: Union[int, bool] = None, - pruner: optuna.pruners.BasePruner = optuna.pruners.SuccessiveHalvingPruner(), + pruner=None, **kwargs, -) -> optuna.Study: +): """ Optimize Temporal Fusion Transformer hyperparameters. @@ -96,6 +87,27 @@ def optimize_hyperparameters( Returns: optuna.Study: optuna study results """ + pkgs = _get_installed_packages() + + if "optuna" not in pkgs or "statsmodels" not in pkgs: + raise ImportError( + "optimize_hyperparameters requires optuna and statsmodels. " + "Please install these packages with `pip install optuna statsmodels`. " + "From optuna 3.3.0, optuna-integration is also required." 
+ ) + + import optuna + from optuna.integration import PyTorchLightningPruningCallback + import optuna.logging + import statsmodels.api as sm + + # need to inherit from callback for this to work + class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback): # noqa: E501 + pass + + if pruner is None: + pruner = optuna.pruners.SuccessiveHalvingPruner() + assert isinstance(train_dataloaders.dataset, TimeSeriesDataSet) and isinstance( val_dataloaders.dataset, TimeSeriesDataSet ), "dataloaders must be built from timeseriesdataset" From b99f944ae75dfb5eada5c122d052e6a09670ee7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 29 Aug 2024 13:08:21 +0100 Subject: [PATCH 3/8] exports --- pytorch_forecasting/utils/__init__.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pytorch_forecasting/utils/__init__.py b/pytorch_forecasting/utils/__init__.py index 10741943..16b39215 100644 --- a/pytorch_forecasting/utils/__init__.py +++ b/pytorch_forecasting/utils/__init__.py @@ -3,31 +3,49 @@ """ from pytorch_forecasting.utils._utils import ( + InitialParameterRepresenterMixIn, + OutputMixIn, + TupleOutputMixIn, apply_to_list, autocorrelation, + concat_sequences, create_mask, detach, get_embedding_size, groupby_apply, integer_histogram, + masked_op, move_to_device, + padded_stack, profile, + redirect_stdout, + repr_class, to_list, unpack_sequence, + unsqueeze_like, ) __all__ = [ + "InitialParameterRepresenterMixIn", + "OutputMixIn", + "TupleOutputMixIn", "apply_to_list", "autocorrelation", "get_embedding_size", + "concat_sequences", "create_mask", "to_list", "RecurrentNetwork", "DecoderMLP", "detach", + "masked_op", "move_to_device", "integer_histogram", "groupby_apply", + "padded_stack", "profile", + "redirect_stdout", + "repr_class", "unpack_sequence", + "unsqueeze_like", ] From d560e25282162130420e828bfd46f43c857586f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 29 Aug 2024 13:10:03 +0100 Subject: [PATCH 4/8] Update test.yml --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dd2cba1d..4490f904 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -64,7 +64,7 @@ jobs: - name: Install dependencies shell: bash run: | - pip install ".[dev,github-actions,graph,mqf2]" + pip install ".[dev,all_extras,github-actions,graph,mqf2]" - name: Show dependencies run: python -m pip list From 505bc4433f2862a5d4b73e912f118588ad4ba78f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 29 Aug 2024 13:24:48 +0100 Subject: [PATCH 5/8] typo fix --- .../models/temporal_fusion_transformer/tuning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py index bba49803..7f2dfab8 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tuning.py @@ -17,7 +17,7 @@ from pytorch_forecasting import TemporalFusionTransformer from pytorch_forecasting.data import TimeSeriesDataSet from pytorch_forecasting.metrics import QuantileLoss -from pytorch_forecastingu.utils._dependencies import _get_installed_packages +from pytorch_forecasting.utils._dependencies import _get_installed_packages optuna_logger = logging.getLogger("optuna") From 
41db0492527208db8a0125e46217cbd7a85777ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Fri, 30 Aug 2024 21:25:19 +0100
Subject: [PATCH 6/8] iso test

---
 .../test_models/test_temporal_fusion_transformer.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py
index 04210806..c2aa49a7 100644
--- a/tests/test_models/test_temporal_fusion_transformer.py
+++ b/tests/test_models/test_temporal_fusion_transformer.py
@@ -381,9 +381,18 @@ def test_prediction_with_dataframe(model, data_with_covariates):
     model.predict(data_with_covariates, fast_dev_run=True)
 
 
+SKIP_HYPERPARAM_TEST = (
+    sys.platform.startswith("win")
+    # Test skipped on Windows OS due to issues with ddp, see #1632
+    or "optuna" not in _get_installed_packages()
+    or "statsmodels" not in _get_installed_packages()
+    # Test skipped if required package optuna or statsmodels not available
+)
+
+
 @pytest.mark.skipif(
-    sys.platform.startswith("win"),
-    reason="Test skipped on Windows OS due to issues with ddp, see #1632",
+    SKIP_HYPERPARAM_TEST,
+    reason="Test skipped on Win due to bug #1632, or if missing required packages",
 )
 @pytest.mark.parametrize("use_learning_rate_finder", [True, False])
 def test_hyperparameter_optimization_integration(dataloaders_with_covariates, tmp_path, use_learning_rate_finder):

From dc074f7666d9a5b40074d7789496a76e043a8f21 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Wed, 4 Sep 2024 18:34:36 +0100
Subject: [PATCH 7/8] Update pyproject.toml

---
 pyproject.toml | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index f07d52ef..07882a9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,6 +67,12 @@ dependencies = [
 
 [project.optional-dependencies]
 
+tuning = [
+    "optuna >=3.1.0,<4.0.0",
+    "optuna-integration",
+    "statsmodels",
+]
+
 all_extras = [
     "optuna >=3.1.0,<4.0.0",
     "optuna-integration",
     "statsmodels",

From 24803446394a081b6b50f5816597876a884bb662 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franz=20Kir=C3=A1ly?=
Date: Wed, 4 Sep 2024 18:40:03 +0100
Subject: [PATCH 8/8] Update pyproject.toml

---
 pyproject.toml | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 07882a9b..5fb6c890 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,19 +66,41 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
+# there are the following dependency sets:
+# - all_extras - all soft dependencies
+# - granular dependency sets:
+#   - tuning - dependencies for tuning hyperparameters via optuna
+#   - mqf2 - dependencies for multivariate quantile loss
+#   - graph - dependencies for graph based forecasting
+# - dev - the developer dependency set, for contributors to pytorch-forecasting
+# - CI related: e.g., dev, github-actions. Not for users of pytorch-forecasting.
+#
+# soft dependencies are not required for the core functionality of pytorch-forecasting,
+# but are required by specific features, e.g., hyperparameter tuning via optuna.
-tuning = [ +# all soft dependencies +# +# users can install via "pip install pytorch-forecasting[all_extras]" +# +all_extras = [ + "cpflows", + "networkx", "optuna >=3.1.0,<4.0.0", "optuna-integration", "statsmodels", ] -all_extras = [ +tuning = [ "optuna >=3.1.0,<4.0.0", "optuna-integration", "statsmodels", ] +graph = ["networkx"] + +mqf2 = ["cpflows"] + +# dev - the developer dependency set, for contributors to pytorch-forecasting dev = [ "pydocstyle >=6.1.1,<7.0.0", # checks and make tools @@ -111,8 +133,6 @@ dev = [ ] github-actions = ["pytest-github-actions-annotate-failures"] -graph = ["networkx"] -mqf2 = ["cpflows"] [tool.setuptools.packages.find] exclude = ["build_tools"]
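
Usage sketch (illustrative only, not part of the patch set): with the soft-dependency handling introduced in this series, downstream code can probe for the optional tuning dependencies before importing the optuna-based tuner. The import paths and package names below come from the patches; the guard itself is a hypothetical example.

    from pytorch_forecasting.utils._dependencies import _get_installed_packages

    # mapping of installed distribution names to versions, e.g. {"optuna": "3.1.0", ...}
    pkgs = _get_installed_packages()

    if "optuna" in pkgs and "statsmodels" in pkgs:
        # soft dependencies present, e.g. after `pip install pytorch-forecasting[tuning]`
        from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
            optimize_hyperparameters,
        )
    else:
        # tuning extra not installed; calling optimize_hyperparameters would raise ImportError
        optimize_hyperparameters = None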