Skip to content

Commit

Permalink
add wandb resume
Browse files Browse the repository at this point in the history
  • Loading branch information
wonbeomjang committed Jan 29, 2022
1 parent 3c08e18 commit cd32a3a
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 113 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ checkpoints
samples
image_preprocess.py
results.csv
wandb
2 changes: 1 addition & 1 deletion config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
parser.add_argument('--num_classes', type=int, default=2, help='number of model output channels')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--test_batch_size', type=int, default=1, help='test batch size')
parser.add_argument('--num_epoch', type=int, default=500, help='number of epochs to train for')
parser.add_argument('--num_epoch', type=int, default=200, help='number of epochs to train for')
parser.add_argument('--decay_epoch', type=int, default=100, help='learning rate decay start epoch num')
parser.add_argument('--lr', type=float, default=1, help='learning rate')
parser.add_argument('--rho', type=float, default=0.95, help='adadelta rho')
Expand Down
14 changes: 1 addition & 13 deletions data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#
#
import os
import cv2

from tqdm import tqdm
import torch.utils.data
Expand All @@ -28,18 +27,6 @@ def check_data(data_folder):

intersection = list(intersection)

# print('[*] Check that if mask image is single channel')
# index = 0
# for image in tqdm(intersection):
# img = cv2.imread(f'{data_folder}/masks/{image}')
#
# if img.shape[-1] == 3:
# img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# cv2.imwrite(f'{data_folder}/masks/{image}', img)
# index += 1
#
# print(f"[!] {index} images are changed")

return intersection


Expand Down Expand Up @@ -80,6 +67,7 @@ def transform(image, mask, image_size=224):

# Make gray scale image
gray_image = TF.to_grayscale(image)
mask = TF.to_grayscale(mask)

# Transform to tensor
image = TF.to_tensor(image)
Expand Down
26 changes: 12 additions & 14 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os

import wandb
import random
import torch
import torch.backends.cudnn

from config.config import get_config
import data.test_loader
from src.train import Trainer
from src.test import Tester

wandb.log

def main(config):
if config.checkpoint_dir is None:
config.checkpoint_dir = 'checkpoints'
Expand All @@ -17,20 +19,18 @@ def main(config):
if not os.path.exists(config.sample_dir):
os.makedirs(config.sample_dir)

# config.manual_seed = random.randint(1, 10000)
# print("Random Seed: ", config.manual_seed)
# random.seed(config.manual_seed)
# torch.manual_seed(config.manual_seed)

# if torch.cuda.is_available():
# torch.cuda.manual_seed_all(config.manual_seed)
config.manual_seed = 100
print("Random Seed: ", config.manual_seed)
random.seed(config.manual_seed)
torch.manual_seed(config.manual_seed)

# cudnn.benchmark = True
if torch.cuda.is_available():
torch.cuda.manual_seed_all(config.manual_seed)

run = wandb.init(project='hair_segmentation', resume=True)
torch.backends.cudnn.benchmark = True

if not config.test:
trainer = Trainer(config, wandb)
trainer = Trainer(config)
trainer.train()

if config.quantize:
Expand All @@ -41,8 +41,6 @@ def main(config):
tester = Tester(config, test_loader)
tester.test()

run.finish()


if __name__ == "__main__":
config = get_config()
Expand Down
Binary file modified requirements.txt
Binary file not shown.
16 changes: 4 additions & 12 deletions src/test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import torch
import os
import numpy as np
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from utils.custom_transfrom import UnNormalize
from utils.util import quantize_model, AverageMeter
from models.quantization.modelv2 import QuantizableMobileHairNetV2
from utils.util import AverageMeter
from loss.loss import iou_loss


class Tester:
def __init__(self, config, dataloader):
self.batch_size = config.batch_size
Expand All @@ -29,18 +26,13 @@ def __init__(self, config, dataloader):
def load_model(self):
ckpt = f'{self.checkpoint_dir}/quantized.pt' if self.quantize else f'{self.checkpoint_dir}/best.pt'
print(f'[*] Load Model from {ckpt}')

# save_info = {'model': self.net, 'state_dict': self.net.state_dict(), 'optimizer' : self.optimizer.state_dict()}



if self.quantize:
self.net = torch.jit.load(ckpt)
else:
save_info = torch.load(ckpt, map_location=self.device)
self.net = save_info['model']

self.net = save_info['model']
self.net.load_state_dict(save_info['state_dict'])


def test(self, net=None):
if net:
Expand Down
119 changes: 46 additions & 73 deletions src/train.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os
import time
import traceback
from shutil import move

import wandb
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.optim.adadelta import Adadelta
from torch.optim.lr_scheduler import OneCycleLR

Expand All @@ -16,13 +14,17 @@
from models.quantization.modelv1 import QuantizableMobileHairNet
from models.quantization.modelv2 import QuantizableMobileHairNetV2
from loss.loss import ImageGradientLoss, iou_loss
from utils.util import LambdaLR, AverageMeter
from utils.util import AverageMeter

from data.dataloader import get_loader


class Trainer:
def __init__(self, config, wandb):
def __init__(self, config):
self.net = None
self.run = None
self.lr_scheduler = None
self.optimizer = None
self.data_loader, self.val_loader = get_loader(config.data_path, config.batch_size, config.image_size,
shuffle=True, num_workers=int(config.workers))
self.batch_size = config.batch_size
Expand All @@ -41,13 +43,12 @@ def __init__(self, config, wandb):
self.model_version = config.model_version
self.quantize = config.quantize
self.resume = config.resume
self.eps = config.eps
self.rho = config.rho
self.decay = config.decay
self.num_quantize_train = config.num_quantize_train

self.wandb = wandb

self.build_model()
self.optimizer = Adadelta(self.net.parameters(), lr=self.lr, eps=config.eps, rho=config.rho,
weight_decay=config.decay)

def build_model(self):
if self.model_version == 1:
Expand All @@ -64,40 +65,46 @@ def build_model(self):
else:
raise Exception('[!] Unexpected model version')

self.optimizer = Adadelta(self.net.parameters(), lr=self.lr, eps=self.eps, rho=self.rho,
weight_decay=self.decay)
self.lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=self.lr, epochs=self.num_epoch,
steps_per_epoch=self.image_len, cycle_momentum=False)
self.load_model()

def load_model(self):
if not self.model_path and not self.resume:
self.run = wandb.init(project='hair_segmentation', dir=os.getcwd())
return

ckpt = os.path.join(self.checkpoint_dir, 'last.pt') if self.resume else self.model_path
ckpt = f'{self.checkpoint_dir}/last.pt' if self.resume else self.model_path
save_info = torch.load(ckpt, map_location=self.device)
run_id = save_info['run_id'] if 'run_id' in save_info else None
self.run = wandb.init(id=run_id, project='hair_segmentation', resume="allow", dir=os.getcwd())

try:
save_info = torch.load(self.wandb.restore(ckpt), map_location=self.device)
except ValueError:
print(traceback.format_exc())
print("[!] Wandb load fail")
save_info = torch.load(ckpt, map_location=self.device)
# save_info = {'model': self.net, 'state_dict': self.net.state_dict(), 'optimizer' : self.optimizer.state_dict()}
# try:
# save_info = torch.load(wandb.restore(ckpt).name, map_location=self.device)
# except ValueError:
# print(traceback.format_exc())
# print(f"[!] {ckpt} is not exist in wandb")

self.epoch = save_info['epoch'] + 1
self.net = save_info['model']
self.optimizer = save_info['optimizer']

self.optimizer = Adadelta(self.net.parameters(), lr=self.lr, eps=self.eps, rho=self.rho,
weight_decay=self.decay)
self.lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=self.lr, epochs=self.num_epoch,
steps_per_epoch=self.image_len, cycle_momentum=False)

self.optimizer.load_state_dict(save_info['optimizer'])
self.net.load_state_dict(save_info['state_dict'])
self.lr_scheduler.load_state_dict(save_info['lr_scheduler'])

print(f"[*] Load Model from {ckpt}")

def train(self):
image_gradient_criterion = ImageGradientLoss().to(self.device)
bce_criterion = nn.CrossEntropyLoss().to(self.device)
best = 0
best_loss = 0

if os.path.exists('results.csv'):
f = open('results.csv', 'a')
else:
f = open('results.csv', 'w')
f.write('epoch,iou,loss\n')

for epoch in range(self.epoch, self.num_epoch):
results = self._train_one_epoch(epoch, image_gradient_criterion, bce_criterion)
Expand All @@ -106,33 +113,28 @@ def train(self):
val_results = self.val(image_gradient_criterion, bce_criterion)
results.update(val_results)
iou = val_results["val/iou"]
loss = val_results["val/loss"]

f.write(f'{epoch},{iou:.4f},{loss:4f}\n')
if iou > best:
best = iou
best_loss = loss

save_info = {'model': self.net, 'state_dict': self.net.state_dict(),
'optimizer': self.optimizer.state_dict(), 'epoch': epoch}
'optimizer': self.optimizer.state_dict(), 'epoch': epoch,
'lr_scheduler': self.lr_scheduler.state_dict(), 'run_id': self.run.id}
torch.save(save_info, f'{self.checkpoint_dir}/best.pt')
wandb.save(f'{self.checkpoint_dir}/best.pt', './', 'now')

if self.wandb:
self.wandb.log(results)

f.write(f'final,{best:.4f},{best_loss:.4f}\n')
wandb.log(results)
print(f'Final IOU: {best:.4f}')
f.close()
self.run.finish()

def quantize_model(self):
if not self.quantize:
return

print('Load Best Model')
ckpt = f'{self.checkpoint_dir}/last.pt'
ckpt = f'{self.checkpoint_dir}/best.pt'
save_info = torch.load(ckpt, map_location=self.device)
self.net = save_info['model']
self.net.load_state_dict(save_info['state_dict'])
# save_info = {'model': self.net, 'state_dict': self.net.state_dict(), 'optimizer' : self.optimizer.state_dict()}

print('Before quantize')
self.device = torch.device('cpu')
Expand Down Expand Up @@ -198,6 +200,7 @@ def _train_one_epoch(self, epoch, image_gradient_criterion, bce_criterion, quant
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.lr_scheduler.step()

iou = iou_loss(pred, mask)
bce_losses.update(bce_loss.item(), self.batch_size)
Expand All @@ -207,23 +210,17 @@ def _train_one_epoch(self, epoch, image_gradient_criterion, bce_criterion, quant
pbar.set_description(f"Epoch: [{epoch}/{self.num_epoch}] | Bce Loss: {bce_losses.avg:.4f} | "
f"Image Gradient Loss: {image_gradient_losses.avg:.4f} | IOU: {iou_avg.avg:.4f}")

# if step % self.sample_step == 0:
# self.save_sample_imgs(image[0], mask[0], torch.argmax(pred[0], 0), self.sample_dir, epoch, step)
# print('[*] Saved sample images')
if not quantize:
save_info = {'model': self.net, 'state_dict': self.net.state_dict(),
'optimizer': self.optimizer.state_dict(), 'epoch': epoch}
'optimizer': self.optimizer.state_dict(), 'epoch': epoch,
'lr_scheduler': self.lr_scheduler.state_dict(), 'run_id': self.run.id}
torch.save(save_info, f'{self.checkpoint_dir}/last.pt')
try:
self.wandb.save(f'{os.getcwd()}/{self.checkpoint_dir}/last.pt')
except OSError:
print(traceback.format_exc())
print("[!] Save wandb fail")

if self.wandb and image is not None:
wandb.save(f'{self.checkpoint_dir}/last.pt', './')
print(f"[*] Save model Epoch: {epoch}")
if image is not None:
img = torch.cat(
[image[0], mask[0].repeat(3, 1, 1), pred[0].argmax(dim=0).unsqueeze(dim=0).repeat(3, 1, 1)], dim=2)
results["prediction"] = self.wandb.Image(img)
results["prediction"] = wandb.Image(img)
results["train/iou"] = iou_avg.avg
results["train/loss"] = bce_losses.avg + image_gradient_losses.avg * self.gradient_loss_weight

Expand Down Expand Up @@ -275,27 +272,3 @@ def val(self, image_gradient_criterion, bce_criterion):
results["val/loss"] = bce_losses.avg + image_gradient_losses.avg * self.gradient_loss_weight

return results

def make_sample_imgs(self, real_img, real_mask, prediction, save_dir, epoch, step):
    """Save a side-by-side figure of an input image, its mask and the prediction.

    The three inputs are drawn as a 1x3 grid of subplots titled "Image",
    "Mask" and "Prediction", and the figure is written to
    ``<save_dir>/epoch-<epoch>_step-<step>.png``.

    Args:
        real_img: Input image; after ``squeeze()`` it is transposed with
            ``(1, 2, 0)``, so it is assumed to be channel-first (C, H, W)
            -- TODO confirm against the caller.
        real_mask: Ground-truth mask; after ``squeeze()`` it is assumed to be
            2-D (H, W) and is replicated to three channels for display.
        prediction: Predicted mask, handled the same way as ``real_mask``.
        save_dir: Directory the PNG is written to.
        epoch: Epoch number used in the file name.
        step: Step number used in the file name.
    """
    panels = [real_img, real_mask, prediction]
    titles = ["Image", "Mask", "Prediction"]

    fig = plt.figure()
    try:
        for idx, (tensor, title) in enumerate(zip(panels, titles)):
            im = tensor.squeeze().detach().cpu().numpy()

            if idx > 0:
                # Masks have no channel axis after squeeze(): add one and
                # replicate it three times so they go through the same
                # transpose/imshow path as the image.
                im = np.expand_dims(im, axis=0)
                im = np.concatenate((im, im, im), axis=0)

            # Channel-first -> channel-last for imshow. The (+1)/2 shift
            # presumably maps [-1, 1]-normalized values to [0, 1] -- confirm
            # against the dataset normalization.
            im = (im.transpose(1, 2, 0) + 1) / 2

            ax = fig.add_subplot(1, 3, idx + 1)
            ax.imshow(im)
            ax.set_title(title)
            ax.set_xticks([])
            ax.set_yticks([])

        out_path = os.path.join(save_dir, "epoch-%s_step-%s.png" % (epoch, step))
        # Save through the figure object rather than pyplot's implicit
        # "current figure", so the right figure is always written.
        fig.savefig(out_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, each call leaks one figure during long runs.
        plt.close(fig)

0 comments on commit cd32a3a

Please sign in to comment.