Skip to content

Commit

Permalink
Use real type annotations in the Hydra config
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 committed Mar 23, 2024
1 parent c26c763 commit 94e1b6f
Show file tree
Hide file tree
Showing 14 changed files with 30 additions and 46 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion conf/experiment/nicopp/rn18/only_pred_y_loss.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

defaults:
- /ae: cosine_annealing
- /ae_opt: cosine_annealing
- /alg: only_pred_y_loss
- /eval: nicopp
- override /ae_arch: resnet
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ numpy = { version = ">=1.23.2" }
pandas = { version = ">=1.5.0" }
pillow = "*"
python = ">=3.10,<3.13"
ranzen = { version = "^2.1.2" }
ranzen = { version = "^2.5.0" }
scikit-image = ">=0.14"
scikit_learn = { version = ">=0.20.1" }
scipy = { version = ">=1.2.1" }
Expand Down
1 change: 1 addition & 0 deletions src/algs/adv/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def balanced_accuracy(y_pred: Tensor, *, y_true: Tensor) -> Tensor:
return cdtm.subclass_balanced_accuracy(y_pred=y_pred, y_true=y_true, s=y_true)


@dataclass(eq=False)
class Scorer(ABC):
@abstractmethod
def run(
Expand Down
2 changes: 2 additions & 0 deletions src/arch/predictors/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TypeVar
from typing_extensions import TypeAliasType

Expand All @@ -10,6 +11,7 @@
PredictorFactoryOut = TypeAliasType("PredictorFactoryOut", tuple[M, int], type_params=(M,))


@dataclass(eq=False)
class PredictorFactory(ABC):
@abstractmethod
def __call__(
Expand Down
1 change: 1 addition & 0 deletions src/data/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def num_samples_te(self) -> int:
return len(self.test)


@dataclass
class DatasetFactory(ABC):
@abstractmethod
def __call__(self) -> Dataset:
Expand Down
1 change: 1 addition & 0 deletions src/labelling/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
]


@dataclass(repr=False, eq=False)
class Labeller(ABC):
@abstractmethod
def run(self, dm: DataModule, device: torch.device) -> Tensor | None:
Expand Down
6 changes: 2 additions & 4 deletions src/relay/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
import os
from typing import Any, ClassVar
from typing import ClassVar

from loguru import logger
from ranzen.torch import random_seed
Expand All @@ -18,7 +18,7 @@
@dataclass(eq=False, kw_only=True)
class BaseRelay:
dm: DataModuleConf = field(default_factory=DataModuleConf)
split: Any
split: DataSplitter
wandb: WandbConf = field(default_factory=WandbConf)
seed: int = 0

Expand All @@ -36,8 +36,6 @@ def init_dm(
labeller: Labeller,
device: torch.device,
) -> DataModule:
assert isinstance(self.split, DataSplitter)

logger.info(f"Current working directory: '{os.getcwd()}'")
random_seed(self.seed, use_cuda=True)
torch.multiprocessing.set_sharing_strategy("file_system")
Expand Down
13 changes: 4 additions & 9 deletions src/relay/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ class FsRelay(BaseRelay):
]
)

alg: Any
ds: Any
backbone: Any
alg: FsAlg
ds: DatasetFactory
backbone: BackboneFactory
predictor: Fcn = field(default_factory=Fcn)
labeller: Any
labeller: Labeller

options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
"ds": {
Expand Down Expand Up @@ -70,11 +70,6 @@ class FsRelay(BaseRelay):
}

def run(self, raw_config: dict[str, Any] | None = None) -> float | None:
assert isinstance(self.alg, FsAlg)
assert isinstance(self.backbone, BackboneFactory)
assert isinstance(self.ds, DatasetFactory)
assert isinstance(self.labeller, Labeller)

ds = self.ds()
run = self.wandb.init(raw_config, (ds, self.labeller, self.backbone, self.predictor))
dm = self.init_dm(ds, self.labeller, device=self.alg.device)
Expand Down
7 changes: 2 additions & 5 deletions src/relay/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class LabelRelay(BaseRelay):
]
)

ds: Any # CdtDataset
labeller: Any # Labeller
ds: DatasetFactory
labeller: Labeller
gpu: int = 0

options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
Expand All @@ -50,9 +50,6 @@ class LabelRelay(BaseRelay):
}

def run(self, raw_config: dict[str, Any] | None = None) -> float | None:
assert isinstance(self.ds, DatasetFactory)
assert isinstance(self.labeller, Labeller)

ds = self.ds()
run = self.wandb.init(raw_config, (ds, self.labeller))
device = resolve_device(self.gpu)
Expand Down
10 changes: 3 additions & 7 deletions src/relay/mimin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ class MiMinRelay(BaseRelay):
)

alg: MiMin = field(default_factory=MiMin)
ae_arch: Any
ae_arch: AeFactory
disc_arch: Fcn = field(default_factory=Fcn)
disc: OptimizerCfg = field(default_factory=OptimizerCfg)
eval: Evaluator = field(default_factory=Evaluator)
ae: SplitAeCfg = field(default_factory=SplitAeCfg)
ae_opt: OptimizerCfg = field(default_factory=OptimizerCfg)
ds: Any
labeller: Any
ds: DatasetFactory
labeller: Labeller

options: ClassVar[dict[str, dict[str, type]]] = BaseRelay.options | {
"ds": {
Expand All @@ -70,10 +70,6 @@ class MiMinRelay(BaseRelay):
}

def run(self, raw_config: dict[str, Any] | None = None) -> None:
assert isinstance(self.ae_arch, AeFactory)
assert isinstance(self.ds, DatasetFactory)
assert isinstance(self.labeller, Labeller)

ds = self.ds()
run = self.wandb.init(raw_config, (ds, self.labeller, self.ae_arch, self.disc_arch))
dm = self.init_dm(ds, self.labeller, device=self.alg.device)
Expand Down
7 changes: 3 additions & 4 deletions src/relay/split.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from typing import Any, ClassVar

from src.data import RandomSplitter
from src.data import DataSplitter, RandomSplitter
from src.data.common import DatasetFactory
from src.data.factories import NICOPPCfg
from src.data.nih import NIHChestXRayDatasetCfg
Expand All @@ -15,8 +15,8 @@
class SplitRelay:
defaults: list[Any] = field(default_factory=lambda: [{"ds": "celeba"}, {"split": "random"}])

ds: Any # CdtDataset
split: Any
ds: DatasetFactory
split: DataSplitter
wandb: WandbConf = field(default_factory=WandbConf)

options: ClassVar[dict[str, dict[str, type]]] = {
Expand All @@ -30,7 +30,6 @@ class SplitRelay:
}

def run(self, raw_config: dict[str, Any] | None = None) -> None:
assert isinstance(self.ds, DatasetFactory)
assert isinstance(self.split, RandomSplitter)

ds = self.ds()
Expand Down
16 changes: 5 additions & 11 deletions src/relay/supmatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ class SupMatchRelay(BaseRelay):
alg: SupportMatching = field(default_factory=SupportMatching)
ae: SplitAeCfg = field(default_factory=SplitAeCfg)
ae_opt: OptimizerCfg = field(default_factory=OptimizerCfg)
ae_arch: Any # AeFactory
ds: Any # DatasetFactory
disc_arch: Any # PredictorFactory
ae_arch: AeFactory
ds: DatasetFactory
disc_arch: PredictorFactory
disc: DiscOptimizerCfg = field(default_factory=DiscOptimizerCfg)
eval: Evaluator = field(default_factory=Evaluator)
labeller: Any # Labeller
scorer: Any # Scorer
labeller: Labeller
scorer: Scorer
artifact_name: str | None = None
"""Save model weights under this name."""

Expand Down Expand Up @@ -96,12 +96,6 @@ class SupMatchRelay(BaseRelay):
}

def run(self, raw_config: dict[str, Any] | None = None) -> float | None:
assert isinstance(self.ae_arch, AeFactory)
assert isinstance(self.disc_arch, PredictorFactory)
assert isinstance(self.ds, DatasetFactory)
assert isinstance(self.labeller, Labeller)
assert isinstance(self.scorer, Scorer)

ds = self.ds()
run = self.wandb.init(raw_config, (ds, self.labeller, self.ae_arch, self.disc_arch))
dm = self.init_dm(ds, self.labeller, device=self.alg.device)
Expand Down

0 comments on commit 94e1b6f

Please sign in to comment.