Skip to content

Commit

Permalink
Allow pytorch 2.0 & pytorch-lightning 2.0 via Copilot Workspace
Browse files Browse the repository at this point in the history
Fixes borchero#54

Update the package to support PyTorch 2.0 and PyTorch-Lightning 2.0.

* **pyproject.toml**
  - Update `torch` version to "^2.0.0".
  - Update `pytorch-lightning` version to "^2.0.0".

* **.github/workflows/ci.yml**
  - Add Python 3.11 to the matrix of Python versions for unit tests.

* **pycave/bayes/gmm/lightning_module.py**
  - Update import for `EarlyStopping` to ensure compatibility with PyTorch-Lightning 2.0.

* **pycave/bayes/markov_chain/estimator.py**
  - Add import for `Trainer` from `pytorch_lightning`.

* **pycave/utils/lightning_module.py**
  - Update import for `pytorch_lightning` to `lightning.pytorch` to ensure compatibility with PyTorch-Lightning 2.0.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/borchero/pycave/issues/54?shareId=XXXX-XXXX-XXXX-XXXX).
  • Loading branch information
EvenStrangest committed Sep 28, 2024
1 parent 3d25c64 commit 9e0aeb5
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout
uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion pycave/bayes/gmm/lightning_module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torchmetrics import MeanMetric
from pycave.bayes.core import cholesky_precision
from pycave.utils import NonparametricLightningModule
Expand Down
1 change: 1 addition & 0 deletions pycave/bayes/markov_chain/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lightkit.data import DataLoader, dataset_from_tensors
from torch.nn.utils.rnn import PackedSequence
from torch.utils.data import Dataset
from pytorch_lightning import Trainer
from .lightning_module import MarkovChainLightningModule
from .model import MarkovChainModel, MarkovChainModelConfig
from .types import collate_sequences, collate_sequences_same_length, SequenceData
Expand Down
2 changes: 1 addition & 1 deletion pycave/utils/lightning_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from typing import List
import pytorch_lightning as pl
import lightning.pytorch as pl
import torch
from torch import nn

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ version = "0.0.0"
lightkit = "^0.5.0"
numpy = "^1.20.3"
python = ">=3.8,<3.11"
pytorch-lightning = "^1.6.0"
torch = "^1.8.0"
pytorch-lightning = "^2.0.0"
torch = "^2.0.0"
torchmetrics = ">=0.6,<0.12"

[tool.poetry.group.pre-commit.dependencies]
Expand Down

0 comments on commit 9e0aeb5

Please sign in to comment.