
Commit 68114c7
refactor: 🚧 Match train code to API
adjavon committed Apr 23, 2024
1 parent de3b5a5 · commit 68114c7
Showing 7 changed files with 780 additions and 347 deletions.
9 changes: 6 additions & 3 deletions src/quac/training/classification.py
@@ -8,16 +8,19 @@ class ClassifierWrapper(torch.nn.Module):
     for how to convert a model to torchscript.
     """
 
-    def __init__(self, model_checkpoint, mean, std):
+    def __init__(self, model_checkpoint, mean: None, std: None):
         """Wraps a torchscript model, and applies normalization."""
         super().__init__()
         self.model = torch.jit.load(model_checkpoint)
         self.model.eval()
         self.transform = transforms.Normalize(mean, std)
+        if mean is None:
+            self.transform = lambda x: x
 
-    def forward(self, x):
+    def forward(self, x, assume_normalized=False):
         """Assumes that x is between -1 and 1."""
         # TODO it would be even better if the range was between 0 and 1 so we wouldn't have to do the below
-        x = (x + 1) / 2
+        if not assume_normalized:
+            x = (x + 1) / 2
         x = self.transform(x)
         return self.model(x)
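
For orientation, a minimal usage sketch of the updated wrapper (the checkpoint path, tensor shapes, and normalization statistics below are hypothetical, not part of this commit):

import torch

from quac.training.classification import ClassifierWrapper

# Hypothetical torchscript checkpoint and per-channel statistics.
classifier = ClassifierWrapper("checkpoints/classifier.ts", mean=0.5, std=0.5)

x = torch.rand(4, 3, 128, 128) * 2 - 1            # generator output in [-1, 1]
logits = classifier(x)                            # rescaled to [0, 1] internally, then normalized

x01 = torch.rand(4, 3, 128, 128)                  # input already in [0, 1]
logits = classifier(x01, assume_normalized=True)  # new flag: skip the [-1, 1] -> [0, 1] rescaling

Passing mean=None replaces the normalization with the identity transform added in __init__.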
63 changes: 47 additions & 16 deletions src/quac/training/config.py
@@ -1,4 +1,5 @@
 from pydantic import BaseModel
+from typing import Optional
 
 
 class ModelConfig(BaseModel):
@@ -9,24 +10,54 @@ class ModelConfig(BaseModel):
 
 
 class DataConfig(BaseModel):
-    train_img_dir: str
+    source: str
+    reference: str
+    validation: str
     img_size: int = 128
     batch_size: int = 16
     randcrop_prob: float = 0.0
     num_workers: int = 4
     grayscale: bool = False
+    latent_dim: int = 16
+
+
+class RunConfig(BaseModel):
+    resume_iter: int = 0
+    total_iter: int = 100000
+    log_every: int = 1000
+    save_every: int = 10000
+    eval_every: int = 10000
+
+
+class ValConfig(BaseModel):
+    classifier_checkpoint: str
+    num_outs_per_domain: int = 10
+    mean: Optional[float] = 0.5
+    std: Optional[float] = 0.5
+
+
+class LossConfig(BaseModel):
+    lambda_ds: float = 1.0
+    lambda_sty: float = 1.0
+    lambda_cyc: float = 1.0
+    lambda_reg: float = 1.0
+    ds_iter: int = 100000
+
+
+class SolverConfig(BaseModel):
+    checkpoint_dir: str
+    f_lr: float = 1e-4
+    lr: float = 1e-4
+    beta1: float = 0.5
+    beta2: float = 0.99
+    weight_decay: float = 0.1
 
 
-class TrainConfig(BaseModel):
-    f_lr: float = 1e-4  # Learning rate for the mapping network
-    lr: float = 1e-4  # Learning rate for the other networks
-    beta1: float = 0.5  # Beta1 for Adam optimizer
-    beta2: float = 0.999  # Beta2 for Adam optimizer
-    weight_decay: float = 1e-4  # Weight decay for Adam optimizer
-    latent_dim: int = 16  # Latent dimension for the mapping network
-    resume_iter: int = 0  # Iteration to resume training from
-    lamdba_ds: float = 1.0  # Weight for the diversity sensitive loss
-    total_iters: int = 100000  # Total number of iterations to train the model
-    ds_iter: int = 1000  # Number of iterations to optimize the diversity sensitive loss
-    log_every: int = 1000  # How often (iterations) to log training progress
-    save_every: int = 10000  # How often (iterations) to save the model
-    eval_every: int = 10000  # How often (iterations) to evaluate the model
+class ExperimentConfig(BaseModel):
+    # Some input required
+    data: DataConfig
+    solver: SolverConfig
+    val: ValConfig
+    # No input required
+    model: ModelConfig = ModelConfig()
+    run: RunConfig = RunConfig()
+    loss: LossConfig = LossConfig()
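
To illustrate the new layout, a hedged sketch of building the nested configuration in code (the paths are placeholders, and any YAML or CLI loading layer is an assumption, not part of this commit):

from quac.training.config import (
    DataConfig,
    ExperimentConfig,
    SolverConfig,
    ValConfig,
)

config = ExperimentConfig(
    # Required sections
    data=DataConfig(
        source="data/train/source",
        reference="data/train/reference",
        validation="data/val",
    ),
    solver=SolverConfig(checkpoint_dir="checkpoints/stargan"),
    val=ValConfig(classifier_checkpoint="checkpoints/classifier.ts"),
    # model, run, and loss fall back to ModelConfig(), RunConfig(), LossConfig()
)
print(config.run.total_iter, config.loss.lambda_ds)

Since every section is a pydantic BaseModel, the same structure could also be validated from a parsed YAML dictionary via ExperimentConfig(**yaml_dict), if the project chooses to load configs that way.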
200 changes: 194 additions & 6 deletions src/quac/training/data_loader.py
@@ -190,11 +190,13 @@ def get_eval_loader(
     root,
     img_size=256,
     batch_size=32,
-    imagenet_normalize=True,
+    imagenet_normalize=False,
     shuffle=True,
     num_workers=4,
     drop_last=False,
     grayscale=False,
+    mean=0.5,
+    std=0.5,
 ):
     print("Preparing DataLoader for the evaluation phase...")
     if imagenet_normalize:
@@ -203,8 +205,11 @@
         std = [0.229, 0.224, 0.225]
     else:
         height, width = img_size, img_size
-        mean = 0.5
-        std = 0.5
 
+    if mean is not None:
+        normalize = transforms.Normalize(mean=mean, std=std)
+    else:
+        normalize = transforms.Lambda(lambda x: x)
+
     transform_list = []
     if grayscale:
@@ -215,7 +220,7 @@ def get_eval_loader(
             *transform_list,
             transforms.Resize([height, width]),
             transforms.ToTensor(),
-            transforms.Normalize(mean=mean, std=std),
+            normalize,
         ]
     )
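
A hedged example of calling the updated evaluation loader (the directory is a placeholder; parts of the signature not shown in these hunks are assumed unchanged):

from quac.training.data_loader import get_eval_loader

# With imagenet_normalize now defaulting to False, mean/std come from the new
# keyword arguments; mean=None would skip normalization altogether.
loader = get_eval_loader(
    "data/val/class_a",
    img_size=128,
    batch_size=32,
    grayscale=True,
    mean=0.5,
    std=0.5,
)
batch = next(iter(loader))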

@@ -303,7 +308,7 @@ def __next__(self):
                 z_trg2=z_trg2,
             )
         elif self.mode == "val":
-            x_ref, y_ref = self._fetch_inputs()
+            x_ref, y_ref = self._fetch_refs()
             inputs = Munch(x_src=x, y_src=y, x_ref=x_ref, y_ref=y_ref)
         elif self.mode == "test":
             inputs = Munch(x=x, y=y)
@@ -342,11 +347,194 @@ def __next__(self):
                 z_trg2=z_trg2,
             )
         elif self.mode == "val":
-            x_ref, _, y_ref = self._fetch_inputs()
+            x_ref, _, y_ref = self._fetch_refs()
             inputs = Munch(x_src=x, y_src=y, x_ref=x_ref, y_ref=y_ref)
         elif self.mode == "test":
             inputs = Munch(x=x, y=y)
         else:
             raise NotImplementedError
 
         return Munch({k: v.to(self.device) for k, v in inputs.items()})
+
+
+class TrainingData:
+    def __init__(
+        self,
+        source,
+        reference,
+        img_size=128,
+        batch_size=8,
+        num_workers=4,
+        grayscale=False,
+        latent_dim=16,
+    ):
+        self.latent_dim = latent_dim
+        self.src = get_train_loader(
+            root=source,
+            which="source",
+            img_size=img_size,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            grayscale=grayscale,
+        )
+        self.reference = get_train_loader(
+            root=reference,
+            which="reference",
+            img_size=img_size,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            grayscale=grayscale,
+        )
+
+    # NOTE: Made these properties so that they are recomputed
+    # Not sure if this is necessary
+    @property
+    def data_fetcher(self):
+        return AugmentedInputFetcher(
+            self.src, self.reference, mode="train", latent_dim=self.latent_dim
+        )
+
+
+class ValidationData:
+    """
+    A data loader for validation.
+    """
+
+    def __init__(
+        self,
+        source_directory,
+        ref_directory=None,
+        mode="latent",
+        image_size=128,
+        batch_size=32,
+        num_workers=4,
+        grayscale=False,
+        mean=None,
+        std=None,
+    ):
+        """
+        Parameters
+        ----------
+        source_directory : str
+            The directory containing the source images.
+        ref_directory : str
+            The directory containing the reference images, defaults to source_directory if None.
+        mode : str
+            The mode of the data loader, either "latent" or "reference".
+            If "latent", the data loader will only load the source images.
+            If "reference", the data loader will load both the source and reference images.
+        image_size : int
+            The size of the images; images of a different size will be resized.
+        batch_size : int
+            The batch size for source data.
+        num_workers : int
+            The number of workers for the data loader.
+        grayscale : bool
+            Whether the images are grayscale.
+        """
+        assert mode in ["latent", "reference"]
+        # parameters
+        self.image_size = image_size
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.grayscale = grayscale
+        self.mean = mean
+        self.std = std
+        # The source and target classes
+        self.source = None
+        self.target = None
+        # The roots of the source and target directories
+        self.source_root = Path(source_directory)
+        if ref_directory is not None:
+            self.ref_root = Path(ref_directory)
+        else:
+            self.ref_root = source_directory
+
+        # Available classes
+        self.available_sources = [
+            subdir.name
+            for subdir in Path(source_directory).iterdir()
+            if subdir.is_dir()
+        ]
+        self._available_targets = None
+        self.set_mode(mode)
+
+    def set_mode(self, mode):
+        assert mode in ["latent", "reference"]
+        self.mode = mode
+
+    @property
+    def available_targets(self):
+        if self.mode == "latent":
+            return self.available_sources
+        elif self._available_targets is None:
+            self._available_targets = [
+                subdir.name
+                for subdir in Path(self.ref_root).iterdir()
+                if subdir.is_dir()
+            ]
+        return self._available_targets
+
+    def set_target(self, target):
+        assert (
+            target in self.available_targets
+        ), f"{target} not in {self.available_targets}"
+        self.target = target
+
+    def set_source(self, source):
+        assert (
+            source in self.available_sources
+        ), f"{source} not in {self.available_sources}"
+        self.source = source
+
+    @property
+    def reference_directory(self):
+        if self.mode == "latent":
+            return None
+        if self.target is None:
+            raise (ValueError("Target not set."))
+        return self.ref_root / self.target
+
+    @property
+    def source_directory(self):
+        if self.source is None:
+            raise (ValueError("Source not set."))
+        return self.source_root / self.source
+
+    def print_info(self):
+        print(f"Avaliable sources: {self.available_sources}")
+        print(f"Avaliable targets: {self.available_targets}")
+        print(f"Mode: {self.mode}")
+        try:
+            print(f"Current source directory: {self.source_directory}")
+        except ValueError:
+            print("Source not set.")
+        try:
+            print(f"Current target directory: {self.reference_directory}")
+        except ValueError:
+            print("Target not set.")
+
+    @property
+    def loader_src(self):
+        return get_eval_loader(
+            self.source_directory,
+            img_size=self.image_size,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            grayscale=self.grayscale,
+            mean=self.mean,
+            std=self.std,
+        )
+
+    @property
+    def loader_ref(self):
+        return get_eval_loader(
+            self.reference_directory,
+            img_size=self.image_size,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            grayscale=self.grayscale,
+            mean=self.mean,
+            std=self.std,
+        )
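
Finally, a sketch of how the new helpers might be driven during training and validation (the directory names are placeholders and the translation/evaluation step is elided; this workflow is an assumption, not shown in the commit):

from quac.training.data_loader import TrainingData, ValidationData

train_data = TrainingData(
    source="data/train", reference="data/train", img_size=128, batch_size=8
)
fetcher = train_data.data_fetcher          # AugmentedInputFetcher over source + reference

val_data = ValidationData(
    "data/val",
    ref_directory="data/val",
    mode="reference",
    image_size=128,
    batch_size=32,
    mean=0.5,
    std=0.5,
)
val_data.print_info()

for source in val_data.available_sources:
    val_data.set_source(source)
    for target in val_data.available_targets:
        if target == source:
            continue
        val_data.set_target(target)
        src_loader = val_data.loader_src   # images to translate
        ref_loader = val_data.loader_ref   # reference images of the target class
        ...                                # run conversion / compute metrics here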