fix: 🎨 Debug using fictus experiment
adjavon committed Apr 28, 2024
1 parent 68114c7 commit 22c4f35
Showing 6 changed files with 267 additions and 123 deletions.
7 changes: 6 additions & 1 deletion src/quac/training/classification.py
@@ -2,6 +2,11 @@
from torchvision import transforms


class Identity(torch.nn.Module):
def forward(self, x):
return x


class ClassifierWrapper(torch.nn.Module):
"""
This class expects a torchscript model. See [here](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format)
@@ -15,7 +20,7 @@ def __init__(self, model_checkpoint, mean: None, std: None):
self.model.eval()
self.transform = transforms.Normalize(mean, std)
if mean is None:
self.transform = lambda x: x
self.transform = Identity()

def forward(self, x, assume_normalized=False):
"""Assumes that x is between -1 and 1."""
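A plain lambda stored on an nn.Module cannot be pickled or compiled with TorchScript, while a registered submodule can; that is a likely motivation for swapping the lambda for the Identity module here. A minimal sketch of the failure mode (the wrapper objects and file names are illustrative, not from the repository); note that the built-in torch.nn.Identity would serve the same purpose as the hand-rolled class:

import torch

class Identity(torch.nn.Module):
    def forward(self, x):
        return x

good = torch.nn.Module()
good.transform = Identity()      # registered as a submodule: picklable and scriptable
torch.save(good, "good.pt")      # pickles fine

bad = torch.nn.Module()
bad.transform = lambda x: x      # stored as a plain attribute
# torch.save(bad, "bad.pt")      # raises: Can't pickle <function <lambda>>
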
21 changes: 15 additions & 6 deletions src/quac/training/config.py
@@ -7,22 +7,22 @@ class ModelConfig(BaseModel):
style_dim: int = 64
latent_dim: int = 16
num_domains: int = 5
input_dim: int = 3
final_activation: str = "tanh"


class DataConfig(BaseModel):
source: str
reference: str
validation: str
img_size: int = 128
batch_size: int = 16
batch_size: int = 1
num_workers: int = 4
grayscale: bool = False
latent_dim: int = 16


class RunConfig(BaseModel):
resume_iter: int = 0
total_iter: int = 100000
total_iters: int = 100000
log_every: int = 1000
save_every: int = 10000
eval_every: int = 10000
@@ -33,6 +33,9 @@ class ValConfig(BaseModel):
num_outs_per_domain: int = 10
mean: Optional[float] = 0.5
std: Optional[float] = 0.5
img_size: int = 128
val_batch_size: int = 16
assume_normalized: bool = False


class LossConfig(BaseModel):
@@ -44,7 +47,7 @@ class LossConfig(BaseModel):


class SolverConfig(BaseModel):
checkpoint_dir: str
root_dir: str
f_lr: float = 1e-4
lr: float = 1e-4
beta1: float = 0.5
@@ -53,10 +56,16 @@


class ExperimentConfig(BaseModel):
# Metadata for keeping track of experiments
project: str = "default"
name: str = "default"
notes: str = ""
tags: list = []
# Some input required
data: DataConfig
solver: SolverConfig
val: ValConfig
validation_data: DataConfig
validation_config: ValConfig
# No input required
model: ModelConfig = ModelConfig()
run: RunConfig = RunConfig()
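Put together, the updated ExperimentConfig splits validation into its own data and settings blocks and adds experiment-tracking metadata. A sketch of instantiating it under the definitions shown above (all paths and names are placeholders):

from quac.training.config import (
    DataConfig,
    ExperimentConfig,
    SolverConfig,
    ValConfig,
)

config = ExperimentConfig(
    project="fictus",
    name="debug-run",
    data=DataConfig(
        source="data/train/source",
        reference="data/train/reference",
        validation="data/val",
    ),
    solver=SolverConfig(root_dir="runs/fictus"),  # renamed from checkpoint_dir
    validation_data=DataConfig(                   # replaces the single `val` field
        source="data/val/source",
        reference="data/val/reference",
        validation="data/val",
    ),
    validation_config=ValConfig(val_batch_size=16, assume_normalized=False),
)

ModelConfig and RunConfig fall back to their defaults, so only the required blocks need to be supplied.
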
32 changes: 10 additions & 22 deletions src/quac/training/data_loader.py
@@ -101,7 +101,7 @@ def __init__(self, root, transform=None):
self.transform = transform

def _make_dataset(self, root):
domains = os.listdir(root)
domains = glob.glob(os.path.join(root, "*"))
fnames, fnames2, labels = [], [], []
for idx, domain in enumerate(sorted(domains)):
class_dir = os.path.join(root, domain)
@@ -366,9 +366,7 @@ def __init__(
batch_size=8,
num_workers=4,
grayscale=False,
latent_dim=16,
):
self.latent_dim = latent_dim
self.src = get_train_loader(
root=source,
which="source",
Expand All @@ -386,14 +384,6 @@ def __init__(
grayscale=grayscale,
)

# NOTE: Made these properties so that they are recomputed
# Not sure if this is necessary
@property
def data_fetcher(self):
return AugmentedInputFetcher(
self.src, self.reference, mode="train", latent_dim=self.latent_dim
)


class ValidationData:
"""
Expand All @@ -403,10 +393,10 @@ class ValidationData:

def __init__(
self,
source_directory,
ref_directory=None,
source,
reference=None,
mode="latent",
image_size=128,
img_size=128,
batch_size=32,
num_workers=4,
grayscale=False,
@@ -435,7 +425,7 @@ def __init__(
"""
assert mode in ["latent", "reference"]
# parameters
self.image_size = image_size
self.image_size = img_size
self.batch_size = batch_size
self.num_workers = num_workers
self.grayscale = grayscale
@@ -445,17 +435,15 @@
self.source = None
self.target = None
# The roots of the source and target directories
self.source_root = Path(source_directory)
if ref_directory is not None:
self.ref_root = Path(ref_directory)
self.source_root = Path(source)
if reference is not None:
self.ref_root = Path(reference)
else:
self.ref_root = source_directory
self.ref_root = self.source_root

# Available classes
self.available_sources = [
subdir.name
for subdir in Path(source_directory).iterdir()
if subdir.is_dir()
subdir.name for subdir in self.source_root.iterdir() if subdir.is_dir()
]
self._available_targets = None
self.set_mode(mode)
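The renamed ValidationData arguments (source, reference, img_size) now mirror the DataConfig field names from config.py, which appears to be the point of the rename. A call-site sketch (the directory layout is a placeholder):

from quac.training.data_loader import ValidationData

val_data = ValidationData(
    source="data/val/source",        # was source_directory
    reference="data/val/reference",  # was ref_directory; falls back to source if omitted
    mode="reference",
    img_size=128,                    # was image_size
    batch_size=32,
)
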
79 changes: 64 additions & 15 deletions src/quac/training/eval.py
@@ -13,25 +13,28 @@

import numpy as np
from pathlib import Path
from starganv2.metrics.conversion import calculate_conversion_given_path
from starganv2.core import utils
from quac.training import utils
from quac.training.classification import ClassifierWrapper
from quac.training.data_loader import get_eval_loader
import torch
from tqdm import tqdm


@torch.no_grad()
def calculate_metrics(
eval_dir,
step,
mode,
classifier_checkpoint,
img_size,
val_batch_size,
num_outs_per_domain,
mean,
std,
step=0,
mode="latent",
classifier_checkpoint=None,
img_size=128,
val_batch_size=16,
num_outs_per_domain=10,
mean=None,
std=None,
input_dim=3,
run=None,
):
print("Calculating conversion rate for all tasks...")
print("Calculating conversion rate for all tasks...", flush=True)
translation_rate_values = (
OrderedDict()
) # How many output images are of the right class
@@ -42,17 +45,17 @@ def calculate_metrics(
domains = [subdir.name for subdir in Path(eval_dir).iterdir() if subdir.is_dir()]

for subdir in Path(eval_dir).iterdir():
if not subdir.is_dir() or subdir.startswith("."): # Skip hidden files
if not subdir.is_dir() or subdir.name.startswith("."): # Skip hidden files
continue
src_domain = subdir.name

for subdir2 in Path(subdir).iterdir():
if not subdir2.is_dir() or subdir2.startswith("."):
if not subdir2.is_dir() or subdir2.name.startswith("."):
continue
trg_domain = subdir2.name

task = "%s_to_%s" % (src_domain, trg_domain)
print("Calculating conversion rate for %s..." % task)
task = "%s/%s" % (src_domain, trg_domain)
print("Calculating conversion rate for %s..." % task, flush=True)
target_class = domains.index(trg_domain)

translation_rate, conversion_rate = calculate_conversion_given_path(
@@ -90,3 +93,49 @@
# report translation rate values
filename = os.path.join(eval_dir, "translation_rate_%.5i_%s.json" % (step, mode))
utils.save_json(translation_rate_values, filename)
if run is not None:
run.log(conversion_rate_values, step=step)
run.log(translation_rate_values, step=step)


@torch.no_grad()
def calculate_conversion_given_path(
path,
model_checkpoint,
target_class,
img_size=128,
batch_size=50,
num_outs_per_domain=10,
mean=0.5,
std=0.5,
grayscale=False,
):
print("Calculating conversion given path %s..." % path, flush=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier = ClassifierWrapper(model_checkpoint, mean=mean, std=std)
classifier.to(device)
classifier.eval()

loader = get_eval_loader(
path,
img_size=img_size,
batch_size=batch_size,
imagenet_normalize=False,
shuffle=False,
grayscale=grayscale,
)

predictions = []
for x in tqdm(loader, total=len(loader)):
x = x.to(device)
predictions.append(classifier(x).cpu().numpy())
predictions = np.concatenate(predictions, axis=0)
# Vectorized: group the num_outs_per_domain outputs generated for each source image
predictions = predictions.reshape(-1, num_outs_per_domain, predictions.shape[-1])
predictions = predictions.argmax(axis=-1)
# Conversion: a source image counts if at least one of its outputs hits the target class
at_least_one = np.any(predictions == target_class, axis=1)
conversion_rate = np.mean(at_least_one)  # equivalent to sum(at_least_one) / len(at_least_one)
# Translation: fraction of all generated outputs classified as the target class
translation_rate = np.mean(predictions == target_class)
return translation_rate, conversion_rate
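
The two metrics differ in what they average over: conversion counts a source image as a success if any of its num_outs_per_domain generated outputs is classified as the target class, while translation averages over every generated output individually. A toy illustration with invented numbers:

import numpy as np

# 3 source images, 4 generated outputs each; classifier argmax per output
predictions = np.array([
    [2, 2, 1, 2],
    [0, 1, 1, 1],
    [2, 0, 0, 0],
])
target_class = 2

at_least_one = np.any(predictions == target_class, axis=1)  # [True, False, True]
conversion_rate = np.mean(at_least_one)                     # 2/3 of images converted
translation_rate = np.mean(predictions == target_class)    # 4/12 of outputs translated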
