diff --git a/.gitignore b/.gitignore
index 161756d..1f9d717 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ checkpoints
 samples
 image_preprocess.py
 results.csv
+wandb
\ No newline at end of file
diff --git a/config/config.py b/config/config.py
index d0526eb..5942478 100644
--- a/config/config.py
+++ b/config/config.py
@@ -7,7 +7,7 @@
 parser.add_argument('--num_classes', type=int, default=2, help='number of model output channels')
 parser.add_argument('--batch_size', type=int, default=32, help='batch size')
 parser.add_argument('--test_batch_size', type=int, default=1, help='test batch size')
-parser.add_argument('--num_epoch', type=int, default=500, help='number of epochs to train for')
+parser.add_argument('--num_epoch', type=int, default=200, help='number of epochs to train for')
 parser.add_argument('--decay_epoch', type=int, default=100, help='learning rate decay start epoch num')
 parser.add_argument('--lr', type=float, default=1, help='learning rate')
 parser.add_argument('--rho', type=float, default=0.95, help='adadelta rho')
diff --git a/data/dataloader.py b/data/dataloader.py
index f2e4fc4..323c438 100644
--- a/data/dataloader.py
+++ b/data/dataloader.py
@@ -8,7 +8,6 @@
 #
 #
 import os
-import cv2
 from tqdm import tqdm
 
 import torch.utils.data
@@ -28,18 +27,6 @@ def check_data(data_folder):
 
     intersection = list(intersection)
 
-    # print('[*] Check that if mask image is single channel')
-    # index = 0
-    # for image in tqdm(intersection):
-    #     img = cv2.imread(f'{data_folder}/masks/{image}')
-    #
-    #     if img.shape[-1] == 3:
-    #         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-    #         cv2.imwrite(f'{data_folder}/masks/{image}', img)
-    #         index += 1
-    #
-    # print(f"[!] {index} images are changed")
-
     return intersection
 
 
@@ -80,6 +67,7 @@ def transform(image, mask, image_size=224):
 
     # Make gray scale image
     gray_image = TF.to_grayscale(image)
+    mask = TF.to_grayscale(mask)
 
     # Transform to tensor
     image = TF.to_tensor(image)
diff --git a/main.py b/main.py
index fa46e28..bd96efb 100644
--- a/main.py
+++ b/main.py
@@ -1,13 +1,15 @@
 import os
 
-import wandb
+import random
+import torch
+import torch.backends.cudnn
 
 from config.config import get_config
 import data.test_loader
 from src.train import Trainer
 from src.test import Tester
 
-wandb.log
+
 def main(config):
     if config.checkpoint_dir is None:
         config.checkpoint_dir = 'checkpoints'
@@ -17,20 +19,18 @@ def main(config):
     if not os.path.exists(config.sample_dir):
         os.makedirs(config.sample_dir)
 
-    # config.manual_seed = random.randint(1, 10000)
-    # print("Random Seed: ", config.manual_seed)
-    # random.seed(config.manual_seed)
-    # torch.manual_seed(config.manual_seed)
-
-    # if torch.cuda.is_available():
-    #     torch.cuda.manual_seed_all(config.manual_seed)
+    config.manual_seed = 100
+    print("Random Seed: ", config.manual_seed)
+    random.seed(config.manual_seed)
+    torch.manual_seed(config.manual_seed)
 
-    # cudnn.benchmark = True
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(config.manual_seed)
 
-    run = wandb.init(project='hair_segmentation', resume=True)
+    torch.backends.cudnn.benchmark = True
 
     if not config.test:
-        trainer = Trainer(config, wandb)
+        trainer = Trainer(config)
         trainer.train()
 
     if config.quantize:
@@ -41,8 +41,6 @@ def main(config):
         tester = Tester(config, test_loader)
         tester.test()
 
-    run.finish()
-
 
 if __name__ == "__main__":
     config = get_config()
diff --git a/requirements.txt b/requirements.txt
index fd4ee8e..87c4602 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/src/test.py b/src/test.py
index 3cec0ab..761b5ba 100644
--- a/src/test.py
+++ b/src/test.py
@@ -1,15 +1,12 @@
 import torch
 import os
-import numpy as np
-from glob import glob
 from tqdm import tqdm
-import matplotlib.pyplot as plt
 from torchvision.utils import save_image
 from utils.custom_transfrom import UnNormalize
-from utils.util import quantize_model, AverageMeter
-from models.quantization.modelv2 import QuantizableMobileHairNetV2
+from utils.util import AverageMeter
 from loss.loss import iou_loss
 
+
 class Tester:
     def __init__(self, config, dataloader):
         self.batch_size = config.batch_size
@@ -29,18 +26,13 @@ def __init__(self, config, dataloader):
 
     def load_model(self):
         ckpt = f'{self.checkpoint_dir}/quantized.pt' if self.quantize else f'{self.checkpoint_dir}/best.pt'
         print(f'[*] Load Model from {ckpt}')
-
-        # save_info = {'model': self.net, 'state_dict': self.net.state_dict(), 'optimizer' : self.optimizer.state_dict()}
-
-
+
         if self.quantize:
             self.net = torch.jit.load(ckpt)
         else:
             save_info = torch.load(ckpt, map_location=self.device)
-            self.net = save_info['model']
-
+            self.net = save_info['model']
             self.net.load_state_dict(save_info['state_dict'])
-
     def test(self, net=None):
         if net:
diff --git a/src/train.py b/src/train.py
index 75548b3..ee35b4b 100644
--- a/src/train.py
+++ b/src/train.py
@@ -1,13 +1,11 @@
 import os
 import time
 import traceback
-from shutil import move
+import wandb
 
 import torch
 import torch.nn as nn
-import numpy as np
 from tqdm import tqdm
-import matplotlib.pyplot as plt
 
 from torch.optim.adadelta import Adadelta
 from torch.optim.lr_scheduler import OneCycleLR
@@ -16,13 +14,17 @@
 from models.quantization.modelv1 import QuantizableMobileHairNet
 from models.quantization.modelv2 import QuantizableMobileHairNetV2
 from loss.loss import ImageGradientLoss, iou_loss
-from utils.util import LambdaLR, AverageMeter
+from utils.util import AverageMeter
 from data.dataloader import get_loader
 
 
 class Trainer:
-    def __init__(self, config, wandb):
+    def __init__(self, config):
+        self.net = None
+        self.run = None
+        self.lr_scheduler = None
+        self.optimizer = None
         self.data_loader, self.val_loader = get_loader(config.data_path, config.batch_size, config.image_size,
                                                        shuffle=True, num_workers=int(config.workers))
         self.batch_size = config.batch_size
@@ -41,13 +43,12 @@ def __init__(self, config, wandb):
         self.model_version = config.model_version
         self.quantize = config.quantize
         self.resume = config.resume
+        self.eps = config.eps
+        self.rho = config.rho
+        self.decay = config.decay
         self.num_quantize_train = config.num_quantize_train
-        self.wandb = wandb
-
         self.build_model()
-        self.optimizer = Adadelta(self.net.parameters(), lr=self.lr, eps=config.eps, rho=config.rho,
-                                  weight_decay=config.decay)
 
     def build_model(self):
         if self.model_version == 1:
@@ -64,26 +65,39 @@
         else:
             raise Exception('[!] Unexpected model version')
 
+        self.optimizer = Adadelta(self.net.parameters(), lr=self.lr, eps=self.eps, rho=self.rho,
+                                  weight_decay=self.decay)
+        self.lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=self.lr, epochs=self.num_epoch,
+                                       steps_per_epoch=self.image_len, cycle_momentum=False)
         self.load_model()
 
     def load_model(self):
         if not self.model_path and not self.resume:
+            self.run = wandb.init(project='hair_segmentation', dir=os.getcwd())
             return
 
-        ckpt = os.path.join(self.checkpoint_dir, 'last.pt') if self.resume else self.model_path
+        ckpt = f'{self.checkpoint_dir}/last.pt' if self.resume else self.model_path
+        save_info = torch.load(ckpt, map_location=self.device)
+        run_id = save_info['run_id'] if 'run_id' in save_info else None
+        self.run = wandb.init(id=run_id, project='hair_segmentation', resume="allow", dir=os.getcwd())
 
-        try:
-            save_info = torch.load(self.wandb.restore(ckpt), map_location=self.device)
-        except ValueError:
-            print(traceback.format_exc())
-            print("[!] Wandb load fail")
-            save_info = torch.load(ckpt, map_location=self.device)
-        # save_info = {'model': self.net, 'state_dict': self.net.state_dict(), 'optimizer' : self.optimizer.state_dict()}
+        # try:
+        #     save_info = torch.load(wandb.restore(ckpt).name, map_location=self.device)
+        # except ValueError:
+        #     print(traceback.format_exc())
+        #     print(f"[!] {ckpt} is not exist in wandb")
 
         self.epoch = save_info['epoch'] + 1
         self.net = save_info['model']
-        self.optimizer = save_info['optimizer']
+
+        self.optimizer = Adadelta(self.net.parameters(), lr=self.lr, eps=self.eps, rho=self.rho,
+                                  weight_decay=self.decay)
+        self.lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=self.lr, epochs=self.num_epoch,
+                                       steps_per_epoch=self.image_len, cycle_momentum=False)
+
+        self.optimizer.load_state_dict(save_info['optimizer'])
         self.net.load_state_dict(save_info['state_dict'])
+        self.lr_scheduler.load_state_dict(save_info['lr_scheduler'])
 
         print(f"[*] Load Model from {ckpt}")
@@ -91,13 +105,6 @@ def train(self):
         image_gradient_criterion = ImageGradientLoss().to(self.device)
         bce_criterion = nn.CrossEntropyLoss().to(self.device)
         best = 0
-        best_loss = 0
-
-        if os.path.exists('results.csv'):
-            f = open('results.csv', 'a')
-        else:
-            f = open('results.csv', 'w')
-            f.write('epoch,iou,loss\n')
 
         for epoch in range(self.epoch, self.num_epoch):
             results = self._train_one_epoch(epoch, image_gradient_criterion, bce_criterion)
@@ -106,33 +113,28 @@
             val_results = self.val(image_gradient_criterion, bce_criterion)
             results.update(val_results)
             iou = val_results["val/iou"]
-            loss = val_results["val/loss"]
-            f.write(f'{epoch},{iou:.4f},{loss:4f}\n')
 
             if iou > best:
                 best = iou
-                best_loss = loss
-
                 save_info = {'model': self.net, 'state_dict': self.net.state_dict(),
-                             'optimizer': self.optimizer.state_dict(), 'epoch': epoch}
+                             'optimizer': self.optimizer.state_dict(), 'epoch': epoch,
+                             'lr_scheduler': self.lr_scheduler.state_dict(), 'run_id': self.run.id}
                 torch.save(save_info, f'{self.checkpoint_dir}/best.pt')
+                wandb.save(f'{self.checkpoint_dir}/best.pt', './', 'now')
 
-            if self.wandb:
-                self.wandb.log(results)
-
-        f.write(f'final,{best:.4f},{best_loss:.4f}\n')
+            wandb.log(results)
 
         print(f'Final IOU: {best:.4f}')
-        f.close()
+        self.run.finish()
 
     def quantize_model(self):
         if not self.quantize:
             return
 
         print('Load Best Model')
-        ckpt = f'{self.checkpoint_dir}/last.pt'
+        ckpt = f'{self.checkpoint_dir}/best.pt'
         save_info = torch.load(ckpt, map_location=self.device)
+        self.net = save_info['model']
         self.net.load_state_dict(save_info['state_dict'])
-        # save_info = {'model': self.net, 'state_dict': self.net.state_dict(), 'optimizer' : self.optimizer.state_dict()}
         print('Before quantize')
         self.device = torch.device('cpu')
@@ -198,6 +200,7 @@
             self.optimizer.zero_grad()
             loss.backward()
             self.optimizer.step()
+            self.lr_scheduler.step()
 
             iou = iou_loss(pred, mask)
             bce_losses.update(bce_loss.item(), self.batch_size)
@@ -207,23 +210,17 @@ def _train_one_epoch(self, epoch, image_gradient_criterion, bce_criterion, quant
             pbar.set_description(f"Epoch: [{epoch}/{self.num_epoch}] | Bce Loss: {bce_losses.avg:.4f} | "
                                  f"Image Gradient Loss: {image_gradient_losses.avg:.4f} | IOU: {iou_avg.avg:.4f}")
 
-            # if step % self.sample_step == 0:
-            #     self.save_sample_imgs(image[0], mask[0], torch.argmax(pred[0], 0), self.sample_dir, epoch, step)
-            #     print('[*] Saved sample images')
         if not quantize:
             save_info = {'model': self.net, 'state_dict': self.net.state_dict(),
-                         'optimizer': self.optimizer.state_dict(), 'epoch': epoch}
+                         'optimizer': self.optimizer.state_dict(), 'epoch': epoch,
+                         'lr_scheduler': self.lr_scheduler.state_dict(), 'run_id': self.run.id}
             torch.save(save_info, f'{self.checkpoint_dir}/last.pt')
-            try:
-                self.wandb.save(f'{os.getcwd()}/{self.checkpoint_dir}/last.pt')
-            except OSError:
-                print(traceback.format_exc())
-                print("[!] Save wandb fail")
-
-        if self.wandb and image is not None:
+            wandb.save(f'{self.checkpoint_dir}/last.pt', './')
+            print(f"[*] Save model Epoch: {epoch}")
+        if image is not None:
             img = torch.cat(
                 [image[0], mask[0].repeat(3, 1, 1), pred[0].argmax(dim=0).unsqueeze(dim=0).repeat(3, 1, 1)],
                 dim=2)
-            results["prediction"] = self.wandb.Image(img)
+            results["prediction"] = wandb.Image(img)
         results["train/iou"] = iou_avg.avg
         results["train/loss"] = bce_losses.avg + image_gradient_losses.avg * self.gradient_loss_weight
@@ -275,27 +272,3 @@ def val(self, image_gradient_criterion, bce_criterion):
         results["val/loss"] = bce_losses.avg + image_gradient_losses.avg * self.gradient_loss_weight
 
         return results
-
-    def make_sample_imgs(self, real_img, real_mask, prediction, save_dir, epoch, step):
-        data = [real_img, real_mask, prediction]
-        names = ["Image", "Mask", "Prediction"]
-
-        fig = plt.figure()
-        for i, d in enumerate(data):
-            d = d.squeeze()
-            im = d.data.cpu().numpy()
-
-            if i > 0:
-                im = np.expand_dims(im, axis=0)
-                im = np.concatenate((im, im, im), axis=0)
-
-            im = (im.transpose(1, 2, 0) + 1) / 2
-
-            f = fig.add_subplot(1, 3, i + 1)
-            f.imshow(im)
-            f.set_title(names[i])
-            f.set_xticks([])
-            f.set_yticks([])
-
-        p = os.path.join(save_dir, "epoch-%s_step-%s.png" % (epoch, step))
-        plt.savefig(p)