diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index eb708984..33180d8a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macOS-latest] # add windows-2019 when poetry allows installation with `-f` flag + os: [ubuntu-latest, macos-13] # add windows-2019 when poetry allows installation with `-f` flag python-version: ["3.8", "3.9", "3.10"] defaults: run: @@ -62,7 +62,7 @@ jobs: run: poetry run python -m pip install pip -U - name: Install dependencies - run: poetry install -E "github-actions graph mqf2" + run: poetry install -E "github-actions graph" # - name: Install pytorch geometric dependencies # shell: bash diff --git a/docs/requirements.txt b/docs/requirements.txt index 09435203..c68cd39e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,6 +7,7 @@ lightning >=2.0.0 cloudpickle torch >=2.0,!=2.0.1 optuna >=3.1.0 +optuna-integration scipy pandas >=1.3 scikit-learn >1.2 diff --git a/poetry.lock b/poetry.lock index cd46f873..39d38c00 100644 --- a/poetry.lock +++ b/poetry.lock @@ -302,7 +302,7 @@ files = [ name = "backports-functools-lru-cache" version = "1.6.6" description = "Backport of functools.lru_cache" -optional = true +optional = false python-versions = ">=2.6" files = [ {file = "backports.functools_lru_cache-1.6.6-py2.py3-none-any.whl", hash = "sha256:77e27d0ffbb463904bdd5ef8b44363f6cd5ef503e664b3f599a3bf5843ed37cf"}, @@ -798,7 +798,7 @@ toml = ["tomli"] name = "cpflows" version = "0.1.2" description = "Convex Potential Flows package" -optional = true +optional = false python-versions = "*" files = [ {file = "cpflows-0.1.2.tar.gz", hash = "sha256:a88f5c8f948776d0619c78bf183ab639543ef4cf8e4d91e64e1e45b13a61bbdd"}, @@ -1212,7 +1212,7 @@ tqdm = ["tqdm"] name = "future" version = "0.18.3" description = "Clean single-source support for Python 3 and 2" -optional = true +optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" files = [ {file = "future-0.18.3.tar.gz", hash = "sha256:34a17436ed1e96697a86f9de3d15a3b0be01d8bc8de9c1dffd59fb8234ed5307"}, @@ -1399,7 +1399,7 @@ protobuf = ["grpcio-tools (>=1.58.0)"] name = "h5py" version = "3.9.0" description = "Read and write HDF5 files from Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "h5py-3.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eb7bdd5e601dd1739698af383be03f3dad0465fe67184ebd5afca770f50df9d6"}, @@ -2701,6 +2701,26 @@ document = ["ase", "cmaes (>=0.10.0)", "fvcore", "lightgbm", "matplotlib (!=3.6. optional = ["boto3", "cmaes (>=0.10.0)", "google-cloud-storage", "matplotlib (!=3.6.0)", "pandas", "plotly (>=4.9.0)", "redis", "scikit-learn (>=0.24.2)", "scipy", "torch"] test = ["coverage", "fakeredis[lua]", "kaleido", "moto", "pytest", "scipy (>=1.9.2)", "torch"] +[[package]] +name = "optuna-integration" +version = "3.6.0" +description = "Integration libraries of Optuna." +optional = false +python-versions = "*" +files = [ + {file = "optuna-integration-3.6.0.tar.gz", hash = "sha256:f261c38586b22cd95639287ca694fc0f788482cfbb7bb83803caf404ce06a55a"}, + {file = "optuna_integration-3.6.0-py3-none-any.whl", hash = "sha256:e281d4902ab728b4c86a997eb01e7bc54d921ae7cff40ed8f4e083f49d37e033"}, +] + +[package.dependencies] +optuna = "*" + +[package.extras] +all = ["botorch (<0.10.0)", "catalyst", "catboost (>=0.26)", "catboost (>=0.26,<1.2)", "cma", "distributed", "fastai", "gpytorch", "lightgbm", "lightning", "mlflow", "mxnet", "pandas", "pytorch-ignite", "scikit-learn (>=0.24.2)", "scikit-optimize", "scipy (>=1.9.2)", "shap", "skorch", "tensorboard", "tensorflow", "torch", "wandb", "xgboost"] +checking = ["black", "blackdoc", "hacking", "isort", "mypy", "types-PyYAML", "types-redis", "types-setuptools", "typing-extensions (>=3.10.0.0)"] +document = ["cma", "mlflow", "pandas", "scikit-learn (>=0.24.2)", "scipy (>=1.9.2)", "sphinx", "sphinx-rtd-theme"] +test = ["coverage", "fakeredis[lua]", "pytest"] + [[package]] name = "packaging" version = "23.1" @@ -4138,7 +4158,7 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo name = "seaborn" version = "0.12.2" description = "Statistical data visualization" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "seaborn-0.12.2-py3-none-any.whl", hash = "sha256:ebf15355a4dba46037dfd65b7350f014ceb1f13c05e814eda2c9f5fd731afc08"}, @@ -4460,9 +4480,34 @@ description = "Statistical computations and models for Python" optional = false python-versions = ">=3.8" files = [ + {file = "statsmodels-0.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:43af9c0b07c9d72f275cf14ea54a481a3f20911f0b443181be4769def258fdeb"}, + {file = "statsmodels-0.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16975ab6ad505d837ba9aee11f92a8c5b49c4fa1ff45b60fe23780b19e5705e"}, + {file = "statsmodels-0.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e278fe74da5ed5e06c11a30851eda1af08ef5af6be8507c2c45d2e08f7550dde"}, + {file = "statsmodels-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0564d92cb05b219b4538ed09e77d96658a924a691255e1f7dd23ee338df441b"}, + {file = "statsmodels-0.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5385e22e72159a09c099c4fb975f350a9f3afeb57c1efce273b89dcf1fe44c0f"}, {file = "statsmodels-0.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:0a8aae75a2e08ebd990e5fa394f8e32738b55785cb70798449a3f4207085e667"}, + {file = "statsmodels-0.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b69a63ad6c979a6e4cde11870ffa727c76a318c225a7e509f031fbbdfb4e416a"}, + {file = "statsmodels-0.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7562cb18a90a114f39fab6f1c25b9c7b39d9cd5f433d0044b430ca9d44a8b52c"}, + {file = "statsmodels-0.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3abaca4b963259a2bf349c7609cfbb0ce64ad5fb3d92d6f08e21453e4890248"}, + {file = "statsmodels-0.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0f727fe697f6406d5f677b67211abe5a55101896abdfacdb3f38410405f6ad8"}, + {file = "statsmodels-0.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b6838ac6bdb286daabb5e91af90fd4258f09d0cec9aace78cc441cb2b17df428"}, {file = "statsmodels-0.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:709bfcef2dbe66f705b17e56d1021abad02243ee1a5d1efdb90f9bad8b06a329"}, + {file = "statsmodels-0.14.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f32a7cd424cf33304a54daee39d32cccf1d0265e652c920adeaeedff6d576457"}, + {file = "statsmodels-0.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f8c30181c084173d662aaf0531867667be2ff1bee103b84feb64f149f792dbd2"}, + {file = "statsmodels-0.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de2b97413913d52ad6342dece2d653e77f78620013b7705fad291d4e4266ccb"}, + {file = "statsmodels-0.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3420f88289c593ba2bca33619023059c476674c160733bd7d858564787c83d3"}, + {file = "statsmodels-0.14.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c008e16096f24f0514e53907890ccac6589a16ad6c81c218f2ee6752fdada555"}, {file = "statsmodels-0.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:bc0351d279c4e080f0ce638a3d886d312aa29eade96042e3ba0a73771b1abdfb"}, + {file = "statsmodels-0.14.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bf293ada63b2859d95210165ad1dfcd97bd7b994a5266d6fbeb23659d8f0bf68"}, + {file = "statsmodels-0.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:44ca8cb88fa3d3a4ffaff1fb8eb0e98bbf83fc936fcd9b9eedee258ecc76696a"}, + {file = "statsmodels-0.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d5373d176239993c095b00d06036690a50309a4e00c2da553b65b840f956ae6"}, + {file = "statsmodels-0.14.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a532dfe899f8b6632cd8caa0b089b403415618f51e840d1817a1e4b97e200c73"}, + {file = "statsmodels-0.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:4fe0a60695952b82139ae8750952786a700292f9e0551d572d7685070944487b"}, + {file = "statsmodels-0.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:04293890f153ffe577e60a227bd43babd5f6c1fc50ea56a3ab1862ae85247a95"}, + {file = "statsmodels-0.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3e70a2e93d54d40b2cb6426072acbc04f35501b1ea2569f6786964adde6ca572"}, + {file = "statsmodels-0.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab3a73d16c0569adbba181ebb967e5baaa74935f6d2efe86ac6fc5857449b07d"}, + {file = "statsmodels-0.14.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eefa5bcff335440ee93e28745eab63559a20cd34eea0375c66d96b016de909b3"}, + {file = "statsmodels-0.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:bc43765710099ca6a942b5ffa1bac7668965052542ba793dd072d26c83453572"}, {file = "statsmodels-0.14.1.tar.gz", hash = "sha256:2260efdc1ef89f39c670a0bd8151b1d0843567781bcafec6cda0534eb47a94f6"}, ] @@ -4485,7 +4530,7 @@ docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "n name = "subprocess32" version = "3.5.4" description = "A backport of the subprocess module from Python 3 for use on 2.x." -optional = true +optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, <4" files = [ {file = "subprocess32-3.5.4-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:88e37c1aac5388df41cc8a8456bb49ebffd321a3ad4d70358e3518176de3a56b"}, @@ -4710,7 +4755,7 @@ visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.2.0)"] name = "torchvision" version = "0.17.1" description = "image and video datasets and models for torch deep learning" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "torchvision-0.17.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:06418880212b66e45e855dd39f536e7fd48b4e6b034a11dd9fe9e2384afb51ec"}, @@ -5053,9 +5098,8 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [extras] github-actions = ["pytest-github-actions-annotate-failures"] graph = ["networkx"] -mqf2 = ["cpflows"] [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.11" -content-hash = "acf0ee98a7ed5f9c84477905279c6eb686064b786d51ba5e5dab7501cb3252f9" +content-hash = "63e6eae67fe328b80bfa092d1ce44663d22a9bf90afc756b98b04232ba376b5b" diff --git a/pyproject.toml b/pyproject.toml index 826fcf26..c0d5064a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,15 +58,18 @@ python = ">=3.8,<3.11" torch = "^2.0.0,!=2.0.1" lightning = "^2.0.0" optuna = "^3.1.0" +optuna-integration="*" scipy = "^1.8" pandas = ">=1.3.0,<=3.0.0" +pyarrow = "*" scikit-learn = "^1.2" matplotlib = "*" +tensorboard = "^2.12.1" +cpflows = "^0.1.2" statsmodels = "*" pytest-github-actions-annotate-failures = { version = "*", optional = true } networkx = { version = "^3.0.0", optional = true } -cpflows = { version = "^0.1.2", optional = true } fastapi = ">=0.80" pytorch-optimizer = "^2.5.1" @@ -108,7 +111,6 @@ pandoc = "^2.3" [tool.poetry.extras] # extras github-actions = ["pytest-github-actions-annotate-failures"] graph = ["networkx"] -mqf2 = ["cpflows"] [tool.poetry-dynamic-versioning] enable = true diff --git a/pytorch_forecasting/data/examples.py b/pytorch_forecasting/data/examples.py index 80409170..03ae566a 100644 --- a/pytorch_forecasting/data/examples.py +++ b/pytorch_forecasting/data/examples.py @@ -3,10 +3,10 @@ """ from pathlib import Path +from urllib.request import urlretrieve import numpy as np import pandas as pd -import requests BASE_URL = "https://github.com/jdb78/pytorch-forecasting/raw/master/examples/data/" @@ -28,9 +28,7 @@ def _get_data_by_filename(fname: str) -> Path: # check if file exists - download if necessary if not full_fname.exists(): url = BASE_URL + fname - download = requests.get(url, allow_redirects=True) - with open(full_fname, "wb") as file: - file.write(download.content) + urlretrieve(url, full_fname) return full_fname diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 7ee1524e..2f7e6e5b 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -18,7 +18,7 @@ from lightning.pytorch.utilities.parsing import AttributeDict, get_init_args import matplotlib.pyplot as plt import numpy as np -from numpy.lib.function_base import iterable +from numpy import iterable import pandas as pd import pytorch_optimizer from pytorch_optimizer import Ranger21