ACS experiments (#336)
tmke8 authored Mar 19, 2024
2 parents e1b77e7 + 1dbb67e commit 431b9a3
Showing 37 changed files with 188 additions and 149 deletions.
1 change: 1 addition & 0 deletions analysis/wandb_downloader.py
@@ -1,4 +1,5 @@
 """Download logged runs from W&B."""
+
 from collections.abc import Sequence
 from functools import lru_cache
 from typing import Union
9 changes: 6 additions & 3 deletions conf/dm/acs.yaml
@@ -1,4 +1,7 @@
-stratified_sampler: approx_group
-num_workers: 1
-batch_size_tr: 128
+# stratified_sampler: approx_group
+stratified_sampler: exact
+num_workers: 4
+# batch_size_tr: 128
+batch_size_tr: 32
 batch_size_te: 100000
+num_samples_per_group_per_bag: 1
5 changes: 3 additions & 2 deletions conf/eval/nicopp.yaml
@@ -1,7 +1,8 @@
batch_size: 1
balanced_sampling: true
hidden_dim: null
num_hidden: 1
model:
hidden_dim: null
num_hidden: 1
steps: 10000
opt:
lr: 1.e-4
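Note on this change (and the matching eval-config changes below): the classifier-architecture keys (hidden_dim, num_hidden) move from the top level of the eval config into a nested model: block, mirroring the new Fcn-typed model field on Evaluator in src/algs/adv/evaluator.py further down. A minimal sketch of how such a nested block binds to a structured config, assuming simplified stand-in dataclasses (this Fcn is a placeholder, not the real src.arch.predictors.Fcn):

from dataclasses import dataclass, field
from typing import Optional

from omegaconf import OmegaConf


@dataclass
class Fcn:
    hidden_dim: Optional[int] = None
    num_hidden: int = 0


@dataclass
class Evaluator:
    batch_size: int = 1
    model: Fcn = field(default_factory=Fcn)


# The nested ``model:`` block in the YAML lands on the nested dataclass field.
schema = OmegaConf.structured(Evaluator)
cfg = OmegaConf.merge(schema, OmegaConf.create("model:\n  num_hidden: 1"))
assert cfg.model.num_hidden == 1 and cfg.model.hidden_dim is None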
41 changes: 33 additions & 8 deletions conf/experiment/acs/fcn.yaml
@@ -1,7 +1,7 @@
 # @package _global_
 
 defaults:
-  - /alg: supmatch_no_disc
+  # - /alg: supmatch_no_disc
   - override /dm: acs
   - override /ds: acs/employment_dis_fl
   - override /split: acs/employment_dis
@@ -11,12 +11,18 @@ defaults:
 alg:
   use_amp: False
   pred:
-    lr: ${ ae.lr }
+    lr: ${ ae_opt.lr }
+    optimizer_cls: ${ ae_opt.optimizer_cls }
+    # weight_decay: ${ ae_opt.weight_decay }
+    scheduler_cls: ${ ae_opt.scheduler_cls }
+    scheduler_kwargs:
+      T_max: ${ ae_opt.scheduler_kwargs.T_max }
+      eta_min: ${ ae_opt.scheduler_kwargs.eta_min }
   steps: 10000
   val_freq: 1000
   log_freq: ${ alg.steps }
-  # num_disc_updates: 3
-  # disc_loss_w: 0.03
+  num_disc_updates: 3
+  disc_loss_w: 0.03
   # ga_steps: 1
   # max_grad_norm: null

@@ -27,12 +33,31 @@ ae:
 ae_opt:
   lr: 1.e-4
   optimizer_cls: ADAM
-  weight_decay: 0
+  weight_decay: 1.e-2
+  scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
+  scheduler_kwargs:
+    T_max: ${ alg.steps }
+    eta_min: 1.e-6
 
 ae_arch:
-  hidden_dim: 64
+  hidden_dim: 95
   latent_dim: 64
-  num_hidden: 2
+  num_hidden: 3
+  dropout_prob: 0.1
+
+split:
+  seed: ${ seed }
 
 eval:
-  batch_size: 128
+  steps: 5000
+  batch_size: 32
+  balanced_sampling: true
+  model:
+    num_hidden: 0
+  opt:
+    lr: 5.e-4
+    optimizer_cls: ADAM
+    scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
+    scheduler_kwargs:
+      T_max: ${ eval.steps }
+      eta_min: 5e-7
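Note: the recurring one-line fix in this commit — lr: ${ ae.lr } becoming lr: ${ ae_opt.lr } — repoints an OmegaConf interpolation at the group that actually holds the optimizer settings; the learning rate lives under ae_opt, not ae. A minimal sketch of how such references resolve, with made-up values:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    ae_opt:
      lr: 1.e-4
    alg:
      pred:
        lr: ${ ae_opt.lr }
    """
)
# Interpolations are resolved on access, so the predictor's lr always
# tracks whatever value ae_opt.lr ends up with after overrides.
assert cfg.alg.pred.lr == cfg.ae_opt.lr == 1.0e-4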
7 changes: 4 additions & 3 deletions conf/experiment/celeba/rn50/pretrained_enc.yaml
@@ -17,7 +17,7 @@ ae_opt:
 alg:
   use_amp: true
   pred:
-    lr: ${ ae.lr }
+    lr: ${ ae_opt.lr }
   log_freq: ${ alg.steps }
   val_freq: 200
   num_disc_updates: 5
@@ -49,8 +49,9 @@ dm:
 eval:
   batch_size: 10
   balanced_sampling: true
-  hidden_dim: null
-  num_hidden: 1
+  model:
+    hidden_dim: null
+    num_hidden: 1
   steps: 10000
   opt:
     lr: 1.e-4
5 changes: 3 additions & 2 deletions conf/experiment/celeba/sm/northern_resonance.yaml
@@ -68,8 +68,9 @@ eval:
   batch_size: 12
   balanced_sampling: true
   steps: 10000
-  num_hidden: 1
-  hidden_dim: null
+  model:
+    num_hidden: 1
+    hidden_dim: null
   opt:
     lr: 1.e-4
     scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
5 changes: 3 additions & 2 deletions conf/experiment/celeba/sm/pt.yaml
@@ -45,8 +45,9 @@ eval:
   batch_size: 12
   balanced_sampling: true
   steps: 10000
-  num_hidden: 1
-  hidden_dim: null
+  model:
+    num_hidden: 1
+    hidden_dim: null
   opt:
     lr: 1.e-4
     scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
2 changes: 1 addition & 1 deletion conf/experiment/nicopp/rn18/only_pred_y_loss.yaml
@@ -22,7 +22,7 @@ alg:
   steps: 30000
   use_amp: true
   pred:
-    lr: ${ ae.lr }
+    lr: ${ ae_opt.lr }
   log_freq: 100000000000 # never
   val_freq: 1000

2 changes: 1 addition & 1 deletion conf/experiment/nicopp/rn18/pretrained_enc.yaml
@@ -26,7 +26,7 @@ ae_opt:
 alg:
   use_amp: true
   pred:
-    lr: ${ ae.lr }
+    lr: ${ ae_opt.lr }
   log_freq: 100000000000 # never
   val_freq: 200
   num_disc_updates: 5
2 changes: 1 addition & 1 deletion conf/experiment/nicopp/rn50/only_pred_y_loss.yaml
@@ -15,7 +15,7 @@ alg:
   steps: 30000
   use_amp: true
   pred:
-    lr: ${ ae.lr }
+    lr: ${ ae_opt.lr }
   log_freq: 100000000000 # never
   val_freq: 1000

2 changes: 1 addition & 1 deletion conf/experiment/nicopp/rn50/pretrained_enc.yaml
@@ -19,7 +19,7 @@ ae_opt:
 alg:
   use_amp: true
   pred:
-    lr: ${ ae.lr }
+    lr: ${ ae_opt.lr }
   log_freq: ${ alg.steps }
   val_freq: 200
   num_disc_updates: 5
2 changes: 1 addition & 1 deletion conf/experiment/nicopp/vqgan/from_pt.yaml
@@ -21,7 +21,7 @@ ae_opt:
 alg:
   use_amp: true
   pred:
-    lr: ${ ae.lr }
+    lr: ${ ae_opt.lr }
   steps: 30000
   val_freq: 1000
   # val_freq: ${ alg.steps }
2 changes: 1 addition & 1 deletion conf/experiment/nicopp/vqgan/from_pt_dropout.yaml
@@ -21,7 +21,7 @@ ae_opt:
 alg:
   use_amp: true
   pred:
-    lr: ${ ae.lr }
+    lr: ${ ae_opt.lr }
   steps: 30000
   val_freq: 1000
   # log_freq: ${ alg.steps }
3 changes: 2 additions & 1 deletion conf/experiment/nih/sb10_mimin.yaml
@@ -59,7 +59,8 @@ eval:
   batch_size: 12
   balanced_sampling: true
   steps: 10000
-  num_hidden: 1
+  model:
+    num_hidden: 1
   opt:
     lr: 1.e-4
     scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
5 changes: 3 additions & 2 deletions conf/experiment/nih/sm/sb10_gender_infiltration_pt.yaml
@@ -46,8 +46,9 @@ eval:
   batch_size: 12
   balanced_sampling: true
   steps: 10000
-  num_hidden: 1
-  hidden_dim: null
+  model:
+    num_hidden: 1
+    hidden_dim: null
   opt:
     lr: 1.e-4
     scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
5 changes: 3 additions & 2 deletions conf/experiment/nih/sm/sb10_sm.yaml
@@ -69,8 +69,9 @@ eval:
   batch_size: 12
   balanced_sampling: true
   steps: 10000
-  num_hidden: 1
-  hidden_dim: null
+  model:
+    num_hidden: 1
+    hidden_dim: null
   opt:
     lr: 1.e-4
     scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
5 changes: 3 additions & 2 deletions conf/experiment/nih/sm/sb10_sm_winter_pine.yaml
@@ -65,8 +65,9 @@ eval:
   batch_size: 12
   balanced_sampling: true
   steps: 10000
-  num_hidden: 1
-  hidden_dim: null
+  model:
+    num_hidden: 1
+    hidden_dim: null
   opt:
     lr: 1.e-4
     scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
2 changes: 1 addition & 1 deletion conf/hydra/sweeper/sm.yaml
@@ -28,7 +28,7 @@ params:
   alg.num_disc_updates: range(3, 6)
   alg.disc_loss_w: interval(1.e-1, 1.0)
   alg.twoway_disc_loss: choice(true, false)
-  ae.lr: tag(log, interval(1.e-6, 1.e-4))
+  ae_opt.lr: tag(log, interval(1.e-6, 1.e-4))
   disc.lr: tag(log, interval(1.e-6, 1.e-4))
   disc.criterion: choice(LOGISTIC_NS)
   # disc_arch.num_hidden_pre: choice(1, 2)
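For context: these look like search spaces for Hydra's Optuna sweeper, where range(...) draws from a discrete range, interval(a, b) draws a float uniformly, tag(log, ...) switches the draw to a log scale, and choice(...) is categorical. Roughly the equivalent raw-Optuna calls, as a sketch (assuming Hydra's Python-style, stop-exclusive range semantics; the training body is elided):

import optuna


def objective(trial: optuna.Trial) -> float:
    # range(3, 6) -> one of 3, 4, 5
    num_disc_updates = trial.suggest_int("alg.num_disc_updates", 3, 5)
    # interval(1.e-1, 1.0) -> uniform float
    disc_loss_w = trial.suggest_float("alg.disc_loss_w", 1e-1, 1.0)
    # choice(true, false) -> categorical
    twoway = trial.suggest_categorical("alg.twoway_disc_loss", [True, False])
    # tag(log, interval(...)) -> log-uniform float
    ae_lr = trial.suggest_float("ae_opt.lr", 1e-6, 1e-4, log=True)
    disc_lr = trial.suggest_float("disc.lr", 1e-6, 1e-4, log=True)
    ...  # train with these values and return the metric being optimized
    return 0.0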
40 changes: 20 additions & 20 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -58,7 +58,7 @@ pandas-stubs = "*"
 python-type-stubs = { git = "https://github.com/wearepal/python-type-stubs.git", rev = "8d5f608" }
 
 [tool.poetry.group.lint.dependencies]
-ruff = "*"
+ruff = ">=0.3.0"
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
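Side note: pinning ruff to >=0.3.0 lines up with the mechanical churn elsewhere in this commit — ruff 0.3.0 stabilized its 2024 formatter style, which (likely the cause of the changes here) inserts a blank line after a module docstring and collapses ...-only bodies onto the signature line. Schematically, with an illustrative function name:

# Old formatting (shape of the deleted lines):
"""Download logged runs from W&B."""
from collections.abc import Sequence


def encode(items: Sequence[int]) -> int:
    ...


# New formatting under ruff >= 0.3.0 (shape of the added lines):
"""Download logged runs from W&B."""

from collections.abc import Sequence


def encode(items: Sequence[int]) -> int: ...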
21 changes: 6 additions & 15 deletions src/algs/adv/evaluator.py
@@ -17,7 +17,6 @@
 import umap
 import wandb
 
-from src.arch.common import Activation
 from src.arch.predictors import Fcn
 from src.data import DataModule, Dataset, group_id_to_label, labels_to_group_id, resolve_device
 from src.evaluation.metrics import EmEvalPair, compute_metrics
@@ -75,8 +74,7 @@ def encode_dataset(
     device: str | torch.device,
     segment: Literal["zs"] = ...,
     use_amp: bool = False,
-) -> InvariantDatasets[Dataset[Tensor], None]:
-    ...
+) -> InvariantDatasets[Dataset[Tensor], None]: ...
 
 
 @overload
@@ -87,8 +85,7 @@ def encode_dataset(
     device: str | torch.device,
     segment: Literal["zy"] = ...,
     use_amp: bool = False,
-) -> InvariantDatasets[None, Dataset[Tensor]]:
-    ...
+) -> InvariantDatasets[None, Dataset[Tensor]]: ...
 
 
 @overload
@@ -99,8 +96,7 @@ def encode_dataset(
     device: str | torch.device,
     segment: Literal["both"],
     use_amp: bool = False,
-) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]:
-    ...
+) -> InvariantDatasets[Dataset[Tensor], Dataset[Tensor]]: ...
 
 
 def encode_dataset(
@@ -222,24 +218,18 @@ def _flip(items: Sequence[Any], ncol: int) -> Sequence[Any]:
 class Evaluator:
     steps: int = 10_000
     batch_size: int = 128
-    hidden_dim: int | None = None
-    num_hidden: int = 0
     eval_s_from_zs: EvalTrainData | None = None
     balanced_sampling: bool = True
     umap_viz: bool = False
     save_summary: bool = True
 
-    activation: Activation = Activation.GELU
+    model: Fcn = field(default_factory=Fcn)
     opt: OptimizerCfg = field(default_factory=OptimizerCfg)
     """Optimization parameters."""
 
     def _fit_classifier(
         self, dm: DataModule, *, pred_s: bool, input_dim: int, device: torch.device
     ) -> Classifier:
-        model_fn = Fcn(
-            hidden_dim=self.hidden_dim, num_hidden=self.num_hidden, activation=self.activation
-        )
-        model, _ = model_fn(input_dim, target_dim=dm.card_y)
+        model, _ = self.model(input_dim, target_dim=dm.card_y)
 
         clf = Classifier(model, opt=self.opt)

@@ -253,6 +243,7 @@ def _fit_classifier(
             device=torch.device(device),
             pred_s=pred_s,
             use_wandb=False,
+            val_interval=1.1, # never
         )
 
         return clf
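The net effect of the evaluator.py diff: Evaluator no longer carries loose hidden_dim / num_hidden / activation attributes but holds the whole classifier architecture as a single Fcn field, which is what the new eval.model blocks in the YAML configs above bind to. A minimal sketch of the pattern with a simplified stand-in Fcn (the real one lives in src.arch.predictors):

from __future__ import annotations

from dataclasses import dataclass, field

import torch.nn as nn


@dataclass
class Fcn:
    """Simplified stand-in: builds a fully connected net when called."""

    hidden_dim: int | None = None
    num_hidden: int = 0

    def __call__(self, input_dim: int, *, target_dim: int) -> tuple[nn.Module, int]:
        hidden = self.hidden_dim if self.hidden_dim is not None else input_dim
        layers: list[nn.Module] = []
        in_dim = input_dim
        for _ in range(self.num_hidden):
            layers += [nn.Linear(in_dim, hidden), nn.GELU()]
            in_dim = hidden
        layers.append(nn.Linear(in_dim, target_dim))  # final projection to the target
        return nn.Sequential(*layers), target_dim


@dataclass
class Evaluator:
    model: Fcn = field(default_factory=Fcn)

    def fit_classifier(self, input_dim: int, card_y: int) -> nn.Module:
        # The architecture now comes straight from the config-backed field
        # instead of being re-assembled from separate attributes.
        net, _ = self.model(input_dim, target_dim=card_y)
        return net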
1 change: 1 addition & 0 deletions src/arch/aggregation.py
@@ -1,4 +1,5 @@
 """Modules that aggregate over a batch."""
+
 from dataclasses import dataclass, field
 from typing_extensions import override
