Skip to content

Commit

Permalink
add model parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
wonbeomjang committed Jan 31, 2022
1 parent cd32a3a commit 76e2896
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 51 deletions.
14 changes: 1 addition & 13 deletions data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,7 @@
from PIL import Image
import torchvision.transforms.functional as TF
import random


def check_data(data_folder):
masks = set(os.listdir(f'{data_folder}/masks/'))
image = set(os.listdir(f'{data_folder}/images/'))

intersection = masks.intersection(image)
union = masks.union(image)
print(f"[!] {len(union) - len(intersection)} of {len(union)} images doesn't have mask")

intersection = list(intersection)

return intersection
from data.utils import check_data


def transform(image, mask, image_size=224):
Expand Down
22 changes: 4 additions & 18 deletions data/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,22 @@
import torchvision.transforms as transforms
from PIL import Image
import torchvision.transforms.functional as TF
from random import random
from data.utils import check_data


def transform(image, mask, image_size=224):
resize = transforms.Resize(size=(image_size, image_size))
image = resize(image)
mask = resize(mask)

# if random() > 0.5:
# image = TF.vflip(image)
# mask = TF.vflip(mask)

# if random() > 0.5:
# image = TF.hflip(image)
# mask = TF.hflip(mask)

# angle = random() * 12 - 6
# image = TF.rotate(image, angle)
# mask = TF.rotate(mask, angle)

# pad_size = random() * image_size
# image = TF.pad(image, pad_size, padding_mode='edge')
# mask = TF.pad(mask, pad_size, padding_mode='edge')
mask = TF.to_grayscale(mask)

# Transform to tensor
image = TF.to_tensor(image)
mask = TF.to_tensor(mask)

# Normalize Data
image = TF.normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
image = TF.normalize(image, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

return image, mask

Expand All @@ -45,7 +31,7 @@ def __init__(self, data_folder, image_size):
raise Exception(f"[!] {self.data_folder} not exists.")

self.objects_path = []
self.image_name = os.listdir(os.path.join(data_folder, "images"))
self.image_name = check_data(data_folder)
if len(self.image_name) == 0:
raise Exception(f"No image found in {self.image_name}")
for p in os.listdir(data_folder):
Expand Down
14 changes: 14 additions & 0 deletions data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import os


def check_data(data_folder):
masks = set(os.listdir(f'{data_folder}/masks/'))
image = set(os.listdir(f'{data_folder}/images/'))

intersection = masks.intersection(image)
union = masks.union(image)
print(f"[!] {len(union) - len(intersection)} of {len(union)} images doesn't have mask")

intersection = list(intersection)

return intersection
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def main(config):
if torch.cuda.is_available():
torch.cuda.manual_seed_all(config.manual_seed)

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if not config.test:
trainer = Trainer(config)
Expand Down
Binary file modified param/best.pt
Binary file not shown.
Binary file modified param/last.pt
Binary file not shown.
Binary file modified param/quantized.pt
Binary file not shown.
35 changes: 18 additions & 17 deletions src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,28 @@ def load_model(self):
def test(self, net=None):
if net:
self.net = net
avg_meter = AverageMeter()

unnormal = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
pbar = tqdm(enumerate(self.data_loader), total=len(self.data_loader))
for step, (image, mask) in pbar:
image = image.to(self.device)
#image = unnormal(image.to(self.device))
result = self.net(image)
avg_meter = AverageMeter()
with torch.no_grad():
unnormal = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
pbar = tqdm(enumerate(self.data_loader), total=len(self.data_loader))
for step, (image, mask) in pbar:
image = image.to(self.device)
#image = unnormal(image.to(self.device))
result = self.net(image)

mask = mask.to(self.device)
mask = mask.to(self.device)

avg_meter.update(iou_loss(result, mask))
pbar.set_description(f'IOU: {avg_meter.avg:.4f}')
avg_meter.update(iou_loss(result, mask))
pbar.set_description(f'IOU: {avg_meter.avg:.4f}')

mask = mask.repeat_interleave(3, 1)
argmax = torch.argmax(result, dim=1).unsqueeze(dim=1)
result = result[:, 1, :, :].unsqueeze(dim=1)
result = result * argmax
result = result.repeat_interleave(3, 1)
torch.cat([image, result, mask])
mask = mask.repeat_interleave(3, 1)
argmax = torch.argmax(result, dim=1).unsqueeze(dim=1)
result = result[:, 1, :, :].unsqueeze(dim=1)
result = result * argmax
result = result.repeat_interleave(3, 1)
torch.cat([image, result, mask])


save_image(torch.cat([image, result, mask]), os.path.join(self.sample_dir, f"{step}.png"))
save_image(torch.cat([image, result, mask]), os.path.join(self.sample_dir, f"{step}.png"))

6 changes: 4 additions & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def quantize_model(self):
return

print('Load Best Model')
ckpt = f'{self.checkpoint_dir}/best.pt'
ckpt = f'{self.checkpoint_dir}/mobilenetv2.pt'
save_info = torch.load(ckpt, map_location=self.device)
self.net = save_info['model']
self.net.load_state_dict(save_info['state_dict'])
Expand All @@ -154,6 +154,9 @@ def quantize_model(self):

temp = self.num_epoch
self.num_epoch = self.num_quantize_train
self.lr_scheduler = OneCycleLR(optimizer=self.optimizer, max_lr=self.lr, epochs=self.num_epoch,
steps_per_epoch=self.image_len, cycle_momentum=False)

for i in range(self.num_quantize_train):
self._train_one_epoch(i, image_gradient_criterion, bce_criterion, quantize=True)
self.num_epoch = temp
Expand Down Expand Up @@ -216,7 +219,6 @@ def _train_one_epoch(self, epoch, image_gradient_criterion, bce_criterion, quant
'lr_scheduler': self.lr_scheduler.state_dict(), 'run_id': self.run.id}
torch.save(save_info, f'{self.checkpoint_dir}/last.pt')
wandb.save(f'{self.checkpoint_dir}/last.pt', './')
print(f"[*] Save model Epoch: {epoch}")
if image is not None:
img = torch.cat(
[image[0], mask[0].repeat(3, 1, 1), pred[0].argmax(dim=0).unsqueeze(dim=0).repeat(3, 1, 1)], dim=2)
Expand Down

0 comments on commit 76e2896

Please sign in to comment.