-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
35 changed files
with
1,525 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# 3DR-GAN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#!/usr/bin/env python | ||
from setuptools import setup | ||
|
||
setup(name="style_transfer", | ||
version="0.1", | ||
author="Danish, Manas, Hansal", | ||
zip_safe=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
import argparse | ||
import logging | ||
import os,sys | ||
from typing import Type | ||
import random | ||
from tqdm import tqdm | ||
|
||
import torch | ||
import numpy as np | ||
from torch import nn, optim | ||
from torch.utils.data import DataLoader | ||
|
||
from style_transfer.data.datasets import ShapenetDataset | ||
from style_transfer.models.base_nn import GraphConvClf | ||
from style_transfer.config import Config | ||
from style_transfer.utils.torch_utils import train_val_split, save_checkpoint, accuracy | ||
|
||
import warnings | ||
warnings.filterwarnings("ignore") | ||
# -------------------------------------------------------------------------------------------- | ||
# Argument Parser | ||
# -------------------------------------------------------------------------------------------- | ||
parser = argparse.ArgumentParser("Run training for a particular phase.") | ||
parser.add_argument( | ||
"--config-yml", required=True, help="Path to a config file for specified phase." | ||
) | ||
parser.add_argument( | ||
"--config-override", | ||
default=[], | ||
nargs="*", | ||
help="A sequence of key-value pairs specifying certain config arguments (with dict-like " | ||
"nesting) using a dot operator. The actual config will be updated and recorded in " | ||
"the results directory.", | ||
) | ||
|
||
|
||
logger: logging.Logger = logging.getLogger(__name__) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
# -------------------------------------------------------------------------------------------- | ||
# INPUT ARGUMENTS AND CONFIG | ||
# -------------------------------------------------------------------------------------------- | ||
_A = parser.parse_args() | ||
|
||
# Create a config with default values, then override from config file, and _A. | ||
# This config object is immutable, nothing can be changed in this anymore. | ||
_C = Config(_A.config_yml, _A.config_override) | ||
|
||
# Print configs and args. | ||
print(_C) | ||
for arg in vars(_A): | ||
print("{:<20}: {}".format(arg, getattr(_A, arg))) | ||
|
||
# Create serialization directory and save config in it. | ||
os.makedirs(_C.CKP.experiment_path, exist_ok=True) | ||
_C.dump(os.path.join(_C.CKP.experiment_path, "config.yml")) | ||
|
||
# For reproducibility - refer https://pytorch.org/docs/stable/notes/randomness.html | ||
# These five lines control all the major sources of randomness. | ||
np.random.seed(_C.RANDOM_SEED) | ||
torch.manual_seed(_C.RANDOM_SEED) | ||
torch.cuda.manual_seed_all(_C.RANDOM_SEED) | ||
torch.backends.cudnn.benchmark = False | ||
torch.backends.cudnn.deterministic = True | ||
|
||
device = torch.device("cuda:0") | ||
_C.DEVICE = device | ||
|
||
# -------------------------------------------------------------------------------------------- | ||
# INSTANTIATE DATALOADER, MODEL, OPTIMIZER & CRITERION | ||
# -------------------------------------------------------------------------------------------- | ||
## Datasets | ||
trn_objs, val_objs = train_val_split(config=_C) | ||
collate_fn = ShapenetDataset.collate_fn | ||
|
||
if _C.OVERFIT: | ||
trn_objs, val_objs = trn_objs[:10], val_objs[:10] | ||
|
||
trn_dataset = ShapenetDataset(_C, trn_objs) | ||
trn_dataloader = DataLoader(trn_dataset, | ||
batch_size=_C.OPTIM.BATCH_SIZE, | ||
shuffle=True, | ||
collate_fn=collate_fn, | ||
num_workers=_C.OPTIM.WORKERS) | ||
|
||
val_dataset = ShapenetDataset(_C, val_objs) | ||
val_dataloader = DataLoader(val_dataset, | ||
batch_size=_C.OPTIM.VAL_BATCH_SIZE, | ||
shuffle=True, | ||
collate_fn=collate_fn, | ||
num_workers=_C.OPTIM.WORKERS) | ||
|
||
print("Training Samples: "+str(len(trn_dataloader))) | ||
print("Validation Samples: "+str(len(val_dataloader))) | ||
|
||
model = GraphConvClf(_C).cuda() | ||
model.load_state_dict(torch.load('results/exp_03_16_11_22_19_10classes/[email protected]')['state_dict']) | ||
|
||
# optimizer = optim.SGD( | ||
# model.parameters(), | ||
# lr=_C.OPTIM.LR, | ||
# momentum=_C.OPTIM.MOMENTUM, | ||
# weight_decay=_C.OPTIM.WEIGHT_DECAY, | ||
# ) | ||
optimizer = optim.Adam( | ||
model.parameters(), | ||
lr=_C.OPTIM.LR, | ||
) | ||
# lr_scheduler = optim.lr_scheduler.LambdaLR( # type: ignore | ||
# optimizer, lr_lambda=lambda iteration: 1 - iteration / _C.OPTIM.NUM_ITERATIONS | ||
# ) | ||
|
||
criterion = nn.CrossEntropyLoss() | ||
args = {} | ||
args['EXPERIMENT_NAME'] = _C.EXPERIMENT_NAME | ||
args['full_experiment_name'] = _C.CKP.full_experiment_name | ||
args['experiment_path'] = _C.CKP.experiment_path | ||
args['best_loss'] = _C.CKP.best_loss | ||
args['best_acc'] = _C.CKP.best_acc | ||
# -------------------------------------------------------------------------------------------- | ||
# TRAINING LOOP | ||
# -------------------------------------------------------------------------------------------- | ||
total_step = len(trn_dataloader) | ||
print('\n ***************** Training *****************') | ||
for epoch in tqdm(range(4, _C.OPTIM.EPOCH)): | ||
# -------------------------------------------------------------------------------------------- | ||
# TRAINING | ||
# -------------------------------------------------------------------------------------------- | ||
running_loss = 0.0 | ||
print('Epoch: '+str(epoch)) | ||
model.train() | ||
|
||
for i, data in enumerate(tqdm(trn_dataloader), 0): | ||
if data[0] == None and data[1] == None: | ||
continue | ||
label = data[0].cuda() | ||
mesh = data[1].cuda() | ||
# zero the parameter gradients | ||
optimizer.zero_grad() | ||
# forward + backward + optimize | ||
outputs = model(mesh) | ||
#print(outputs, label) | ||
if outputs.size()[0] == label.size()[0]: | ||
loss = criterion(outputs, label) | ||
loss.backward() | ||
optimizer.step() | ||
#lr_scheduler.step() | ||
# print statistics | ||
running_loss += loss.item() | ||
else: | ||
print('Shape Mismatch') | ||
print(outputs.size(), label.size()) | ||
print(mesh.verts_packed_to_mesh_idx().unique(return_counts=True)[1]) | ||
running_loss /= len(trn_dataloader) | ||
print('\n\tTraining Loss: '+ str(running_loss)) | ||
|
||
# ---------------------------------------------------------------------------------------- | ||
# VALIDATION | ||
# ---------------------------------------------------------------------------------------- | ||
model.eval() | ||
val_loss = 0.0 | ||
val_acc = 0.0 | ||
print("\n\n\tEvaluating..") | ||
for i, data in enumerate(tqdm(val_dataloader), 0): | ||
if data[0] == None and data[1] == None: | ||
continue | ||
label = data[0].cuda() | ||
mesh = data[1].cuda() | ||
with torch.no_grad(): | ||
batch_prediction = model(mesh) | ||
if batch_prediction.size()[0] == label.size()[0]: | ||
loss = criterion(batch_prediction, label) | ||
acc = accuracy(batch_prediction, label) | ||
val_loss += loss.item() | ||
val_acc += np.sum(acc) | ||
else: | ||
print('Shape Mismatch') | ||
print(batch_prediction.size(), label.size()) | ||
print(mesh.verts_packed_to_mesh_idx().unique(return_counts=True)[1]) | ||
# Average out the loss | ||
val_loss /= len(val_dataloader) | ||
val_acc /= len(val_dataloader) | ||
print('\n\tValidation Loss: '+str(val_loss)) | ||
print('\tValidation Acc: '+str(val_acc.item())) | ||
# Final save of the model | ||
args = save_checkpoint(model = model, | ||
optimizer = optimizer, | ||
curr_epoch = epoch, | ||
curr_loss = val_loss, | ||
curr_step = (total_step * epoch), | ||
args = args, | ||
curr_acc = val_acc.item(), | ||
trn_loss = running_loss, | ||
filename = ('model@epoch%d.pkl' %(epoch))) | ||
|
||
print('---------------------------------------------------------------------------------------\n') | ||
print('Finished Training') | ||
print('Best Accuracy on validation',args['best_acc']) | ||
print('Best Loss on validation',args['best_loss']) |
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
r"""This module provides package-wide configuration management.""" | ||
from typing import Any, List | ||
import time | ||
import os, sys | ||
from yacs.config import CfgNode as CN | ||
|
||
class Config(object): | ||
r""" | ||
A collection of all the required configuration parameters. This class is a nested dict-like | ||
structure, with nested keys accessible as attributes. It contains sensible default values for | ||
all the parameters, which may be overriden by (first) through a YAML file and (second) through | ||
a list of attributes and values. | ||
Extended Summary | ||
---------------- | ||
This class definition contains default values corresponding to ``joint_training`` phase, as it | ||
is the final training phase and uses almost all the configuration parameters. Modification of | ||
any parameter after instantiating this class is not possible, so you must override required | ||
parameter values in either through ``config_yaml`` file or ``config_override`` list. | ||
Parameters | ||
---------- | ||
config_yaml: str | ||
Path to a YAML file containing configuration parameters to override. | ||
config_override: List[Any], optional (default= []) | ||
A list of sequential attributes and values of parameters to override. This happens after | ||
overriding from YAML file. | ||
Examples | ||
-------- | ||
Let a YAML file named "config.yaml" specify these parameters to override:: | ||
ALPHA: 1000.0 | ||
BETA: 0.5 | ||
>>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7]) | ||
>>> _C.ALPHA # default: 100.0 | ||
1000.0 | ||
>>> _C.BATCH_SIZE # default: 256 | ||
2048 | ||
>>> _C.BETA # default: 0.1 | ||
0.7 | ||
""" | ||
|
||
def __init__(self, config_yaml: str, config_override: List[Any] = []): | ||
self._C = CN() | ||
self._C.RANDOM_SEED = 0 | ||
self._C.PHASE = "training" | ||
self._C.EXPERIMENT_NAME = "default" | ||
self._C.RESULTS_DIR = "results" | ||
self._C.OVERFIT= False | ||
|
||
self._C.SHAPENET_DATA = CN() | ||
self._C.SHAPENET_DATA.PATH = '/scratch/jiadeng_root/jiadeng/shared_data/datasets/ShapeNetCore.v1/' | ||
# self._C.SHAPENET_DATA.TRANSFORM = None | ||
|
||
self._C.OPTIM = CN() | ||
self._C.OPTIM.BATCH_SIZE = 4 | ||
self._C.OPTIM.VAL_BATCH_SIZE = 16 | ||
self._C.OPTIM.WORKERS = 4 | ||
self._C.OPTIM.EPOCH = 2 | ||
self._C.OPTIM.LR = 0.015 | ||
self._C.OPTIM.MOMENTUM = 0.9 | ||
self._C.OPTIM.WEIGHT_DECAY = 0.001 | ||
self._C.OPTIM.CLIP_GRADIENTS = 12.5 | ||
|
||
self._C.GCC = CN() | ||
self._C.GCC.INPUT_MESH_FEATS = 3 | ||
self._C.GCC.HIDDEN_DIMS = [32, 64, 128] | ||
self._C.GCC.CLASSES = 57 | ||
self._C.GCC.CONV_INIT = "normal" | ||
|
||
self._C.merge_from_file(config_yaml) | ||
self._C.merge_from_list(config_override) | ||
|
||
self._C.CKP = CN() | ||
self._C.CKP.full_experiment_name = ("exp_%s_%s" % ( time.strftime("%m_%d_%H_%M_%S"), self._C.EXPERIMENT_NAME) ) | ||
self._C.CKP.experiment_path = os.path.join(self._C.RESULTS_DIR, self._C.CKP.full_experiment_name) | ||
self._C.CKP.best_loss = sys.float_info.max | ||
self._C.CKP.best_acc = 0. | ||
|
||
# Make an instantiated object of this class immutable. | ||
self._C.freeze() | ||
|
||
def dump(self, file_path: str): | ||
r"""Save config at the specified file path. | ||
Parameters | ||
---------- | ||
file_path: str | ||
(YAML) path to save config at. | ||
""" | ||
self._C.dump(stream=open(file_path, "w")) | ||
|
||
def __getattr__(self, attr: str): | ||
return self._C.__getattr__(attr) | ||
|
||
def __str__(self): | ||
common_string: str = str(CN({"RANDOM_SEED": self._C.RANDOM_SEED})) + "\n" | ||
common_string += str(CN({"DATA": self._C.SHAPENET_DATA})) + "\n" | ||
common_string += str(CN({"BASE_MODEL": self._C.GCC})) + "\n" | ||
common_string += str(CN({"OPTIM": self._C.OPTIM})) + "\n" | ||
common_string += str(CN({"CHECKPOINT": self._C.CKP})) + "\n" | ||
return common_string | ||
|
||
def __repr__(self): | ||
return self._C.__repr__() | ||
|
Empty file.
Binary file not shown.
Binary file not shown.
Oops, something went wrong.