diff --git a/conf/ae/cosine_annealing.yaml b/conf/ae_opt/cosine_annealing.yaml
similarity index 100%
rename from conf/ae/cosine_annealing.yaml
rename to conf/ae_opt/cosine_annealing.yaml
diff --git a/conf/experiment/nicopp/rn18/only_pred_y_loss.yaml b/conf/experiment/nicopp/rn18/only_pred_y_loss.yaml
index 5c209987..10e8305e 100644
--- a/conf/experiment/nicopp/rn18/only_pred_y_loss.yaml
+++ b/conf/experiment/nicopp/rn18/only_pred_y_loss.yaml
@@ -1,7 +1,7 @@
 # @package _global_
 
 defaults:
-  - /ae: cosine_annealing
+  - /ae_opt: cosine_annealing
   - /alg: only_pred_y_loss
   - /eval: nicopp
   - override /ae_arch: resnet
diff --git a/poetry.lock b/poetry.lock
index 0f856ad2..75db8f40 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2120,13 +2120,13 @@ typing-extensions = "*"
 
 [[package]]
 name = "ranzen"
-version = "2.4.2"
+version = "2.5.1"
 description = "A toolkit facilitating machine-learning experimentation."
 optional = false
-python-versions = ">=3.10,<3.13"
+python-versions = "<3.13,>=3.10"
 files = [
-    {file = "ranzen-2.4.2-py3-none-any.whl", hash = "sha256:ec16544b3f996f9d7d83ac4427b21653769935784541698ef689583f1a936366"},
-    {file = "ranzen-2.4.2.tar.gz", hash = "sha256:295a00c3e97a7e6ec5de63b2a0fe6879ef142d516d942eb52471ed95590da67a"},
+    {file = "ranzen-2.5.1-py3-none-any.whl", hash = "sha256:5ec6b78aa8c7c8fffc4a3437bdb491ef0036ae5f5af089dd664d2000007e5e8a"},
+    {file = "ranzen-2.5.1.tar.gz", hash = "sha256:6a9e21a44c9795000cc08cdf579968cf79a54b97824b2d5d9aca4ded986d2bdc"},
 ]
 
 [package.dependencies]
diff --git a/pyproject.toml b/pyproject.toml
index 133f4c5b..3dee9ba5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ numpy = { version = ">=1.23.2" }
 pandas = { version = ">=1.5.0" }
 pillow = "*"
 python = ">=3.10,<3.13"
-ranzen = { version = "^2.1.2" }
+ranzen = { version = "^2.5.0" }
 scikit-image = ">=0.14"
 scikit_learn = { version = ">=0.20.1" }
 scipy = { version = ">=1.2.1" }
diff --git a/src/algs/adv/scorer.py b/src/algs/adv/scorer.py
index 777af815..c5e72e29 100644
--- a/src/algs/adv/scorer.py
+++ b/src/algs/adv/scorer.py
@@ -66,6 +66,7 @@ def balanced_accuracy(y_pred: Tensor, *, y_true: Tensor) -> Tensor:
     return cdtm.subclass_balanced_accuracy(y_pred=y_pred, y_true=y_true, s=y_true)
 
 
+@dataclass(eq=False)
 class Scorer(ABC):
     @abstractmethod
     def run(
diff --git a/src/arch/predictors/base.py b/src/arch/predictors/base.py
index c32cbc8d..c913091e 100644
--- a/src/arch/predictors/base.py
+++ b/src/arch/predictors/base.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import TypeVar
 
 from typing_extensions import TypeAliasType
@@ -10,6 +11,7 @@
 PredictorFactoryOut = TypeAliasType("PredictorFactoryOut", tuple[M, int], type_params=(M,))
 
 
+@dataclass(eq=False)
 class PredictorFactory(ABC):
     @abstractmethod
     def __call__(
diff --git a/src/data/common.py b/src/data/common.py
index 6d0dc9fa..d86d8fc5 100644
--- a/src/data/common.py
+++ b/src/data/common.py
@@ -75,6 +75,7 @@ def num_samples_te(self) -> int:
         return len(self.test)
 
 
+@dataclass
 class DatasetFactory(ABC):
     @abstractmethod
     def __call__(self) -> Dataset:
diff --git a/src/labelling/pipeline.py b/src/labelling/pipeline.py
index a0f75bba..4a488713 100644
--- a/src/labelling/pipeline.py
+++ b/src/labelling/pipeline.py
@@ -38,6 +38,7 @@
 ]
 
 
+@dataclass(repr=False, eq=False)
 class Labeller(ABC):
     @abstractmethod
     def run(self, dm: DataModule, device: torch.device) -> Tensor | None:
diff --git a/src/relay/base.py b/src/relay/base.py
index 320bbdeb..bc59d561 100644
--- a/src/relay/base.py
+++ b/src/relay/base.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass, field
 import os
-from typing import Any, ClassVar
+from typing import ClassVar
 
 from loguru import logger
 from ranzen.torch import random_seed
@@ -18,7 +18,7 @@
 @dataclass(eq=False, kw_only=True)
 class BaseRelay:
     dm: DataModuleConf = field(default_factory=DataModuleConf)
-    split: Any
+    split: DataSplitter
     wandb: WandbConf = field(default_factory=WandbConf)
     seed: int = 0
 
@@ -36,8 +36,6 @@ def init_dm(
         labeller: Labeller,
         device: torch.device,
     ) -> DataModule:
-        assert isinstance(self.split, DataSplitter)
-
         logger.info(f"Current working directory: '{os.getcwd()}'")
         random_seed(self.seed, use_cuda=True)
         torch.multiprocessing.set_sharing_strategy("file_system")
diff --git a/src/relay/fs.py b/src/relay/fs.py
index 103f33c7..ae0639b6 100644
--- a/src/relay/fs.py
+++ b/src/relay/fs.py
@@ -37,11 +37,11 @@ class FsRelay(BaseRelay):
         ]
     )
 
-    alg: Any
-    ds: Any
-    backbone: Any
+    alg: FsAlg
+    ds: DatasetFactory
+    backbone: BackboneFactory
     predictor: Fcn = field(default_factory=Fcn)
-    labeller: Any
+    labeller: Labeller
 
     options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
         "ds": {
@@ -70,11 +70,6 @@ class FsRelay(BaseRelay):
     }
 
     def run(self, raw_config: dict[str, Any] | None = None) -> float | None:
-        assert isinstance(self.alg, FsAlg)
-        assert isinstance(self.backbone, BackboneFactory)
-        assert isinstance(self.ds, DatasetFactory)
-        assert isinstance(self.labeller, Labeller)
-
         ds = self.ds()
         run = self.wandb.init(raw_config, (ds, self.labeller, self.backbone, self.predictor))
         dm = self.init_dm(ds, self.labeller, device=self.alg.device)
diff --git a/src/relay/label.py b/src/relay/label.py
index 6bc13e25..dade77af 100644
--- a/src/relay/label.py
+++ b/src/relay/label.py
@@ -29,8 +29,8 @@ class LabelRelay(BaseRelay):
         ]
     )
 
-    ds: Any  # CdtDataset
-    labeller: Any  # Labeller
+    ds: DatasetFactory
+    labeller: Labeller
     gpu: int = 0
 
     options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
@@ -50,9 +50,6 @@ class LabelRelay(BaseRelay):
     }
 
     def run(self, raw_config: dict[str, Any] | None = None) -> float | None:
-        assert isinstance(self.ds, DatasetFactory)
-        assert isinstance(self.labeller, Labeller)
-
         ds = self.ds()
         run = self.wandb.init(raw_config, (ds, self.labeller))
         device = resolve_device(self.gpu)
diff --git a/src/relay/mimin.py b/src/relay/mimin.py
index 436aa138..ca4fcfc7 100644
--- a/src/relay/mimin.py
+++ b/src/relay/mimin.py
@@ -37,14 +37,14 @@ class MiMinRelay(BaseRelay):
     )
 
     alg: MiMin = field(default_factory=MiMin)
-    ae_arch: Any
+    ae_arch: AeFactory
     disc_arch: Fcn = field(default_factory=Fcn)
     disc: OptimizerCfg = field(default_factory=OptimizerCfg)
     eval: Evaluator = field(default_factory=Evaluator)
     ae: SplitAeCfg = field(default_factory=SplitAeCfg)
     ae_opt: OptimizerCfg = field(default_factory=OptimizerCfg)
-    ds: Any
-    labeller: Any
+    ds: DatasetFactory
+    labeller: Labeller
 
     options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
         "ds": {
@@ -70,10 +70,6 @@ class MiMinRelay(BaseRelay):
     }
 
     def run(self, raw_config: dict[str, Any] | None = None) -> None:
-        assert isinstance(self.ae_arch, AeFactory)
-        assert isinstance(self.ds, DatasetFactory)
-        assert isinstance(self.labeller, Labeller)
-
         ds = self.ds()
         run = self.wandb.init(raw_config, (ds, self.labeller, self.ae_arch, self.disc_arch))
         dm = self.init_dm(ds, self.labeller, device=self.alg.device)
diff --git a/src/relay/split.py b/src/relay/split.py
index 41bc0f17..c3add4cf 100644
--- a/src/relay/split.py
+++ b/src/relay/split.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
 from typing import Any, ClassVar
 
-from src.data import RandomSplitter
+from src.data import DataSplitter, RandomSplitter
 from src.data.common import DatasetFactory
 from src.data.factories import NICOPPCfg
 from src.data.nih import NIHChestXRayDatasetCfg
@@ -15,8 +15,8 @@
 class SplitRelay:
     defaults: list[Any] = field(default_factory=lambda: [{"ds": "celeba"}, {"split": "random"}])
 
-    ds: Any  # CdtDataset
-    split: Any
+    ds: DatasetFactory
+    split: DataSplitter
     wandb: WandbConf = field(default_factory=WandbConf)
 
     options: ClassVar[dict[str, dict[str, type]]] = {
@@ -30,7 +30,6 @@ class SplitRelay:
     }
 
     def run(self, raw_config: dict[str, Any] | None = None) -> None:
-        assert isinstance(self.ds, DatasetFactory)
         assert isinstance(self.split, RandomSplitter)
 
         ds = self.ds()
diff --git a/src/relay/supmatch.py b/src/relay/supmatch.py
index 38cab6e4..e9e9a787 100644
--- a/src/relay/supmatch.py
+++ b/src/relay/supmatch.py
@@ -57,13 +57,13 @@ class SupMatchRelay(BaseRelay):
     alg: SupportMatching = field(default_factory=SupportMatching)
     ae: SplitAeCfg = field(default_factory=SplitAeCfg)
     ae_opt: OptimizerCfg = field(default_factory=OptimizerCfg)
-    ae_arch: Any  # AeFactory
-    ds: Any  # DatasetFactory
-    disc_arch: Any  # PredictorFactory
+    ae_arch: AeFactory
+    ds: DatasetFactory
+    disc_arch: PredictorFactory
     disc: DiscOptimizerCfg = field(default_factory=DiscOptimizerCfg)
     eval: Evaluator = field(default_factory=Evaluator)
-    labeller: Any  # Labeller
-    scorer: Any  # Scorer
+    labeller: Labeller
+    scorer: Scorer
     artifact_name: str | None = None
     """Save model weights under this name."""
 
@@ -96,12 +96,6 @@ class SupMatchRelay(BaseRelay):
     }
 
    def run(self, raw_config: dict[str, Any] | None = None) -> float | None:
-        assert isinstance(self.ae_arch, AeFactory)
-        assert isinstance(self.disc_arch, PredictorFactory)
-        assert isinstance(self.ds, DatasetFactory)
-        assert isinstance(self.labeller, Labeller)
-        assert isinstance(self.scorer, Scorer)
-
         ds = self.ds()
         run = self.wandb.init(raw_config, (ds, self.labeller, self.ae_arch, self.disc_arch))
         dm = self.init_dm(ds, self.labeller, device=self.alg.device)
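
The recurring change across the `src/relay` modules follows one pattern: the abstract bases (`DatasetFactory`, `Labeller`, `PredictorFactory`, `Scorer`) are decorated with `@dataclass(...)`, and the relay fields previously typed as `Any` (with `isinstance` asserts at the top of `run()`) are now annotated with those base types directly. The sketch below illustrates the pattern in isolation; it is not code from this repo. `ToyDatasetCfg` and `ToyRelay` are hypothetical stand-ins, and the assumption that the ranzen 2.5 bump is what permits dataclass-typed abstract fields in the relay/config machinery is inferred from the diff, not confirmed here.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

# Hypothetical stand-ins for the repo's factories and relays; names are illustrative only.


@dataclass(eq=False)  # eq=False keeps identity-based equality/hashing, as in the diff
class DatasetFactory(ABC):
    """Abstract config-as-dataclass: concrete variants subclass this."""

    @abstractmethod
    def __call__(self) -> str: ...


@dataclass(eq=False)
class ToyDatasetCfg(DatasetFactory):
    root: str = "data/"

    def __call__(self) -> str:
        return f"dataset loaded from {self.root}"


@dataclass(eq=False, kw_only=True)
class ToyRelay:
    # Typed against the abstract base rather than Any: any registered subclass
    # satisfies the annotation, so no isinstance assert is needed before use.
    ds: DatasetFactory

    def run(self) -> None:
        print(self.ds())


if __name__ == "__main__":
    ToyRelay(ds=ToyDatasetCfg(root="data/nicopp")).run()
```

With the bases as dataclasses, static checkers and the structured-config layer see the field type directly, which is why the per-field `isinstance` asserts removed in this diff become redundant.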