Skip to content

Commit

Permalink
[MNT] Fix dependency issues and CI runners: numpy2, optuna, `requ…
Browse files Browse the repository at this point in the history
…ests`, MacOS MPS

Fixes #1594, fixes #1595, fixes #1596

Added or moved some dependencies to core dependency set.

Fixed some `numpy2` and `optuna-integrations` problems.

`requests` replaced by `urllib.request.urlretrieve`.
  • Loading branch information
fkiraly authored Aug 25, 2024
2 parents 5fb441b + f93b1e2 commit d82eaaf
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 18 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macOS-latest] # add windows-2019 when poetry allows installation with `-f` flag
os: [ubuntu-latest, macos-13] # add windows-2019 when poetry allows installation with `-f` flag
python-version: ["3.8", "3.9", "3.10"]
defaults:
run:
Expand Down Expand Up @@ -62,7 +62,7 @@ jobs:
run: poetry run python -m pip install pip -U

- name: Install dependencies
run: poetry install -E "github-actions graph mqf2"
run: poetry install -E "github-actions graph"

# - name: Install pytorch geometric dependencies
# shell: bash
Expand Down
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ lightning >=2.0.0
cloudpickle
torch >=2.0,!=2.0.1
optuna >=3.1.0
optuna-integration
scipy
pandas >=1.3
scikit-learn >1.2
Expand Down
62 changes: 53 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,18 @@ python = ">=3.8,<3.11"
torch = "^2.0.0,!=2.0.1"
lightning = "^2.0.0"
optuna = "^3.1.0"
optuna-integration="*"
scipy = "^1.8"
pandas = ">=1.3.0,<=3.0.0"
pyarrow = "*"
scikit-learn = "^1.2"
matplotlib = "*"
tensorboard = "^2.12.1"
cpflows = "^0.1.2"
statsmodels = "*"

pytest-github-actions-annotate-failures = { version = "*", optional = true }
networkx = { version = "^3.0.0", optional = true }
cpflows = { version = "^0.1.2", optional = true }
fastapi = ">=0.80"
pytorch-optimizer = "^2.5.1"

Expand Down Expand Up @@ -108,7 +111,6 @@ pandoc = "^2.3"
[tool.poetry.extras] # extras
github-actions = ["pytest-github-actions-annotate-failures"]
graph = ["networkx"]
mqf2 = ["cpflows"]

[tool.poetry-dynamic-versioning]
enable = true
Expand Down
6 changes: 2 additions & 4 deletions pytorch_forecasting/data/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
"""

from pathlib import Path
from urllib.request import urlretrieve

import numpy as np
import pandas as pd
import requests

BASE_URL = "https://github.com/jdb78/pytorch-forecasting/raw/master/examples/data/"

Expand All @@ -28,9 +28,7 @@ def _get_data_by_filename(fname: str) -> Path:
# check if file exists - download if necessary
if not full_fname.exists():
url = BASE_URL + fname
download = requests.get(url, allow_redirects=True)
with open(full_fname, "wb") as file:
file.write(download.content)
urlretrieve(url, full_fname)

return full_fname

Expand Down
2 changes: 1 addition & 1 deletion pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from lightning.pytorch.utilities.parsing import AttributeDict, get_init_args
import matplotlib.pyplot as plt
import numpy as np
from numpy.lib.function_base import iterable
from numpy import iterable
import pandas as pd
import pytorch_optimizer
from pytorch_optimizer import Ranger21
Expand Down

0 comments on commit d82eaaf

Please sign in to comment.