Skip to content

Commit

Permalink
Merge pull request #206 from TeamEpochGithub/v0.3.7
Browse files Browse the repository at this point in the history
V0.3.7
  • Loading branch information
Gregoire-Andre-Dumont authored Jun 24, 2024
2 parents 7e2c7b6 + f7158f0 commit 4fc7221
Show file tree
Hide file tree
Showing 25 changed files with 797 additions and 178 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/branch-name-check.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Fails any pull request targeting main whose source branch is not named vX.Y.Z.
name: Branch Name Check

on:
  pull_request:
    branches:
      - main

jobs:
  check-branch-name:
    runs-on: ubuntu-latest

    steps:
      - name: Check if the branch name follows the version pattern
        # The head branch name is chosen by the PR author. Expanding it directly
        # inside the script text would let a crafted branch name run arbitrary
        # shell commands, so it is handed over as an environment variable and
        # only ever referenced as quoted shell data.
        env:
          HEAD_REF: ${{ github.head_ref }}
        run: |
          if [[ ! "$HEAD_REF" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
            echo "Error: Branch name $HEAD_REF does not follow the required version pattern vX.Y.Z"
            exit 1
          fi
2 changes: 1 addition & 1 deletion .github/workflows/main-branch-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
run: |
venv/bin/python -m pip install --upgrade pip
venv/bin/python -m pip install pytest
venv/bin/python -m pip install -r requirements.txt
venv/bin/python -m pip install -r requirements-dev.lock
venv/bin/python -m pip install pytest-cov coverage
- name: Test with pytest
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/version-branch-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
run: |
venv/bin/python -m pip install --upgrade pip
venv/bin/python -m pip install pytest
venv/bin/python -m pip install -r requirements.txt
venv/bin/python -m pip install -r requirements-dev.lock
venv/bin/python -m pip install pytest-cov coverage
- name: Test with pytest
Expand Down
67 changes: 48 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,44 @@
exclude: ^(venv/|.venv/|tests/)
exclude: ^(external/|venv/|.venv/|tests/|.cache)
repos:
- repo: local # Remove this when a new version of pre-commit-hooks (>4.6.0) is released
hooks:
- id: check-illegal-windows-names
name: check illegal windows names
entry: Illegal Windows filenames detected
language: fail
files: '(?i)((^|/)(CON|PRN|AUX|NUL|COM[\d¹²³]|LPT[\d¹²³])(\.|/|$)|[<>:\"\\|?*\x00-\x1F]|/[^/]*[\.\s]/|[^/]*[\.\s]$)'
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: check-ast
- id: check-added-large-files
- id: check-builtin-literals
- id: check-executables-have-shebangs
# - id: check-illegal-windows-names # Uncomment this when a new version of pre-commit-hooks (>4.6.0) is released
- id: check-json
- id: pretty-format-json
args: [--autofix, --no-sort-keys]
args: [ "--autofix", "--no-sort-keys" ]
- id: check-merge-conflict
- id: check-shebang-scripts-are-executable
- id: check-symlinks
- id: check-toml
- id: check-xml
- id: check-yaml
- id: detect-private-key
- id: destroyed-symlinks
- id: mixed-line-ending
- id: end-of-file-fixer
- id: fix-byte-order-marker
- id: name-tests-test
args: [ "--pytest-test-first" ]
- id: trailing-whitespace
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.16
rev: v0.18
hooks:
- id: validate-pyproject
- repo: https://github.com/citation-file-format/cffconvert
rev: b6045d78aac9e02b039703b030588d54d53262ac
hooks:
- id: validate-cff
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
Expand All @@ -32,23 +53,31 @@ repos:
- id: sphinx-lint
args: [ --enable=default-role ]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.9
rev: v0.4.10
hooks:
- id: ruff
- id: ruff-format
- id: ruff
# Pydoclint is useful, but gives too many false positives
# - repo: https://github.com/jsh9/pydoclint
# rev: 0.4.1
# hooks:
# - id: pydoclint
# args: [ --quiet ]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.10.0
hooks:
- id: mypy
additional_dependencies:
- agogos
- annotated-types
- dask
- matplotlib
- numpy
- polars
- torch
- traitlets
- timm
- kornia
args: [ --disallow-any-generics, --disallow-untyped-defs, --disable-error-code=import-untyped]
- numpy==1.26.4
- pandas-stubs>=2.2.2.240514
- matplotlib==3.8.4
- torch==2.3.1
- dask==2024.6.2
- typing_extensions==4.9.0
- annotated-types==0.7.0
- polars==0.20.31
- kornia==0.7.2
- timm==1.0.7
args:
- "--fast-module-lookup"
- "--disallow-any-generics"
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.10.14
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ authors:
- family-names: "Kopar"
given-names: "Cahit Tolga"
affiliation: "TU Delft Dream Team Epoch"
email: "https://github.com/tolgakopar"
email: "cahittolgakopar@gmail.com"
- family-names: "Selm"
name-particle: "van"
given-names: "Jasper"
Expand Down
12 changes: 6 additions & 6 deletions epochalyst/_core/_caching/_cacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
try:
import polars as pl
except ImportError:
"""User doen't require these packages"""
"""User doesn't require these packages"""

from epochalyst._core._logging._logger import _Logger

if sys.version_info < (3, 11):
if sys.version_info < (3, 11): # pragma: no cover (<py311)
from typing_extensions import NotRequired
else:
else: # pragma: no cover (py311+)
from typing import NotRequired


Expand Down Expand Up @@ -197,7 +197,7 @@ def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any: #
if output_data_type == "polars_dataframe":
return pl.read_parquet(storage_path + name + ".parquet", **read_args)

self.log_to_debug(
self.log_to_debug( # type: ignore[unreachable]
f"Invalid output data type: {output_data_type}, for loading .parquet file.",
)
raise ValueError(
Expand Down Expand Up @@ -239,7 +239,7 @@ def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any: #
with open(storage_path + name + ".pkl", "rb") as file:
return pickle.load(file, **read_args) # noqa: S301

self.log_to_debug(f"Invalid storage type: {storage_type}")
self.log_to_debug(f"Invalid storage type: {storage_type}") # type: ignore[unreachable]
raise ValueError(
"storage_type must be .npy, .parquet, .csv, or .npy_stack, other types not supported yet",
)
Expand Down Expand Up @@ -348,7 +348,7 @@ def _store_cache(self, name: str, data: Any, cache_args: CacheArgs | None = None
**({"protocol": pickle.HIGHEST_PROTOCOL} | store_args),
)
else:
self.log_to_debug(f"Invalid storage type: {storage_type}")
self.log_to_debug(f"Invalid storage type: {storage_type}") # type: ignore[unreachable]
raise ValueError(
"storage_type must be .npy, .parquet, .csv or .npy_stack, other types not supported yet",
)
1 change: 1 addition & 0 deletions epochalyst/_core/_logging/_logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""_Logger adds abstract logging functionality to other classes."""

from abc import abstractmethod
from typing import Any

Expand Down
1 change: 1 addition & 0 deletions epochalyst/pipeline/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""EnsemblePipeline for ensembling multiple ModelPipelines."""

from typing import Any

from agogos.training import ParallelTrainingSystem
Expand Down
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""ModelPipeline connects multiple transforming and training systems for extended training functionality."""

from typing import Any

from agogos.training import Pipeline
Expand Down
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/training/models/timm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Timm model for 2D image classification."""

import torch
from torch import nn

Expand Down
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/training/pretrain_block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""PretrainBlock to implement modules such as scalers."""

from abc import abstractmethod
from dataclasses import dataclass
from typing import Any
Expand Down
46 changes: 27 additions & 19 deletions epochalyst/pipeline/model/training/torch_trainer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""TorchTrainer is a module that allows for the training of Torch models."""

import copy
import functools
import gc
from collections.abc import Callable
from dataclasses import dataclass, field
from os import PathLike
from pathlib import Path
from typing import Annotated, Any, TypeVar
from typing import Annotated, Any, Literal, TypeVar

import numpy as np
import numpy.typing as npt
import torch
from annotated_types import Gt, Interval
from annotated_types import Ge, Gt, Interval
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
Expand Down Expand Up @@ -163,7 +165,7 @@ def log_to_terminal(self, message: str) -> None:

# Training parameters
epochs: Annotated[int, Gt(0)] = 10
patience: Annotated[int, Gt(0)] = 5 # Early stopping
patience: Annotated[int, Gt(0)] = -1 # Early stopping
batch_size: Annotated[int, Gt(0)] = 32
collate_fn: Callable[[tuple[Tensor, ...]], tuple[Tensor, ...]] = field(default=custom_collate, init=True, repr=False, compare=False)

Expand All @@ -174,18 +176,22 @@ def log_to_terminal(self, message: str) -> None:

# Misc
model_name: str | None = None # No spaces allowed
trained_models_directory: Path = Path("tm")
to_predict: str = "validation"
trained_models_directory: PathLike[str] = field(default=Path("tm/"), repr=False, compare=False)
to_predict: Literal["validation", "all", "none"] = field(default="validation", repr=False, compare=False)

# Parameters relevant for Hashing
n_folds: float = field(default=-1, init=True, repr=False, compare=False)
n_folds: Annotated[int, Ge(0)] = field(default=-1, init=True, repr=False, compare=False)
_fold: int = field(default=-1, init=False, repr=False, compare=False)
validation_size: Annotated[float, Interval(ge=0, le=1)] = 0.2

# Types for tensors
x_tensor_type: str = "float"
y_tensor_type: str = "float"

# Prefix and postfix for logging to external
logging_prefix: str = field(default="", init=True, repr=False, compare=False)
logging_postfix: str = field(default="", init=True, repr=False, compare=False)

def __post_init__(self) -> None:
"""Post init method for the TorchTrainer class."""
# Make sure to_predict is either "validation" or "all" or "none"
Expand Down Expand Up @@ -570,8 +576,8 @@ def _training_loop(
if fold > -1:
fold_no = f"_{fold}"

self.external_define_metric(f"Training/Train Loss{fold_no}", "epoch")
self.external_define_metric(f"Validation/Validation Loss{fold_no}", "epoch")
self.external_define_metric(self.wrap_log(f"Training/Train Loss{fold_no}"), self.wrap_log("epoch"))
self.external_define_metric(self.wrap_log(f"Validation/Validation Loss{fold_no}"), self.wrap_log("epoch"))

# Set the scheduler to the correct epoch
if self.initialized_scheduler is not None:
Expand All @@ -586,8 +592,8 @@ def _training_loop(
# Log train loss
self.log_to_external(
message={
f"Training/Train Loss{fold_no}": train_losses[-1],
"epoch": epoch,
self.wrap_log(f"Training/Train Loss{fold_no}"): train_losses[-1],
self.wrap_log("epoch"): epoch,
},
)

Expand Down Expand Up @@ -616,8 +622,8 @@ def _training_loop(
# Log validation loss and plot train/val loss against each other
self.log_to_external(
message={
f"Validation/Validation Loss{fold_no}": val_losses[-1],
"epoch": epoch,
self.wrap_log(f"Validation/Validation Loss{fold_no}"): val_losses[-1],
self.wrap_log("epoch"): epoch,
},
)

Expand All @@ -631,21 +637,19 @@ def _training_loop(
), # Ensure it's a list, not a range object
"ys": [train_losses, val_losses],
"keys": [f"Train{fold_no}", f"Validation{fold_no}"],
"title": f"Training/Loss{fold_no}",
"title": self.wrap_log(f"Training/Loss{fold_no}"),
"xname": "Epoch",
},
},
)

# Early stopping
if self._early_stopping():
self.log_to_external(
message={f"Epochs{fold_no}": (epoch + 1) - self.patience},
)
self.log_to_external(message={self.wrap_log(f"Epochs{fold_no}"): (epoch + 1) - self.patience})
break

# Log the trained epochs to wandb if we finished training
self.log_to_external(message={f"Epochs{fold_no}": epoch + 1})
self.log_to_external(message={self.wrap_log(f"Epochs{fold_no}"): epoch + 1})

def _train_one_epoch(
self,
Expand Down Expand Up @@ -815,15 +819,19 @@ def get_model_path(self) -> Path:
:return: The model path.
"""
return Path(f"{self.trained_models_directory}/{self.get_hash()}.pt")
return Path(self.trained_models_directory) / f"{self.get_hash()}.pt"

def get_model_checkpoint_path(self, epoch: int) -> Path:
"""Get the checkpoint path.
:param epoch: The epoch number.
:return: The checkpoint path.
"""
return Path(f"{self.trained_models_directory}/{self.get_hash()}_checkpoint_{epoch}.pt")
return Path(self.trained_models_directory) / f"{self.get_hash()}_checkpoint_{epoch}.pt"

def wrap_log(self, text: str) -> str:
    """Surround a log key or title with the configured logging prefix and postfix.

    :param text: The text to wrap.
    :return: The logging prefix, followed by the text, followed by the logging postfix.
    """
    prefix, postfix = self.logging_prefix, self.logging_postfix
    return "{}{}{}".format(prefix, text, postfix)


class TrainValidationDataset(Dataset[T_co]):
Expand Down
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/training/training.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""TrainingPipeline for creating a sequential pipeline of TrainType classes."""

from typing import Any

from agogos.training import TrainingSystem, TrainType
Expand Down
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/training/training_block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""TrainingBlock that can be inherited from to make blocks for a training pipeline."""

from abc import abstractmethod
from typing import Any

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module with tensor functions."""

import torch
from torch import Tensor

Expand Down
1 change: 1 addition & 0 deletions epochalyst/pipeline/model/transformation/transformation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""TransformationPipeline that extends from TransformingSystem, _Cacher and _Logger."""

from dataclasses import dataclass
from typing import Any

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""TransformationBlock module that can be extended by implementing the custom_transform method."""

from abc import abstractmethod
from typing import Any

Expand Down
Loading

0 comments on commit 4fc7221

Please sign in to comment.