Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Application] cuda support for example of pytorch yolo v2 @open sesame 05/10 19:54 #2575

Merged
merged 1 commit into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions Applications/YOLOv2/PyTorch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# @author Seungbaek Hong <[email protected]>

import glob
import re
import numpy as np
import torch
from torch.utils.data import Dataset
Expand All @@ -21,30 +22,36 @@
class YOLODataset(Dataset):
def __init__(self, img_dir, ann_dir):
super().__init__()
img_list = glob.glob(img_dir)
ann_list = glob.glob(ann_dir)
img_list.sort()
ann_list.sort()
pattern = re.compile("\/(\d+)\.")
img_list = glob.glob(img_dir + "*")
ann_list = glob.glob(ann_dir + "*")

img_ids = list(map(lambda x: pattern.search(x).group(1), img_list))
ann_ids = list(map(lambda x: pattern.search(x).group(1), ann_list))
ids_list = list(set(img_ids) & set(ann_ids))

self.length = len(img_list)
self.input_images = []
self.bbox_gt = []
self.cls_gt = []

for i in range(len(img_list)):
img = np.array(Image.open(img_list[i]).resize((416, 416))) / 255
for ids in ids_list:
img = np.array(Image.open(img_dir + ids + ".jpg").resize((416, 416))) / 255
label_bbox = []
label_cls = []
with open(ann_list[i], "rt", encoding="utf-8") as f:
with open(ann_dir + ids + ".txt", "rt", encoding="utf-8") as f:
for line in f.readlines():
line = [float(i) for i in line.split()]
label_bbox.append(np.array(line[1:], dtype=np.float32) / 416)
label_cls.append(int(line[0]))

if len(label_cls) == 0:
continue

self.input_images.append(img)
self.bbox_gt.append(label_bbox)
self.cls_gt.append(label_cls)

self.length = len(self.input_images)
self.input_images = np.array(self.input_images)
self.input_images = torch.FloatTensor(self.input_images).permute((0, 3, 1, 2))

Expand Down
50 changes: 21 additions & 29 deletions Applications/YOLOv2/PyTorch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from yolo_loss import YoloV2_LOSS
from dataset import YOLODataset, collate_db

device = "cuda" if torch.cuda.is_available() else "cpu"


# get pyutils path using relative path
def get_util_path():
Expand All @@ -39,10 +41,10 @@ def get_util_path():
epochs = 3
batch_size = 4

train_img_dir = "/home/user/TRAIN_DIR/images/*"
train_ann_dir = "/home/user/TRAIN_DIR/annotations/*"
valid_img_dir = "/home/user/VALID_DIR/images/*"
valid_ann_dir = "/home/user/VALID_DIR/annotations/*"
train_img_dir = "/home/user/TRAIN_DIR/images/"
train_ann_dir = "/home/user/TRAIN_DIR/annotations/"
valid_img_dir = "/home/user/VALID_DIR/images/"
valid_ann_dir = "/home/user/VALID_DIR/annotations/"

# load data
train_dataset = YOLODataset(train_img_dir, train_ann_dir)
Expand All @@ -63,10 +65,12 @@ def get_util_path():
)

# set model, loss and optimizer
model = YoloV2(num_classes=num_classes)
criterion = YoloV2_LOSS(num_classes=num_classes)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
model = YoloV2(num_classes=num_classes).to(device)
criterion = YoloV2_LOSS(
num_classes=num_classes, img_shape=(416, 416), device=device
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)

# save init model
save_bin(model, "init_model")
Expand All @@ -77,11 +81,11 @@ def get_util_path():
for epoch in range(epochs):
epoch_train_loss = 0
epoch_valid_loss = 0
model.train()
for idx, (img, bbox, cls) in enumerate(train_loader):
model.train()
optimizer.zero_grad()
# model prediction
hypothesis = model(img).permute((0, 2, 3, 1))
hypothesis = model(img.to(device)).permute((0, 2, 3, 1))
hypothesis = hypothesis.reshape(
(batch_size, out_size**2, num_anchors, 5 + num_classes)
)
Expand All @@ -95,24 +99,18 @@ def get_util_path():
score_pred.shape
)
# calc loss
loss = criterion(
torch.FloatTensor(bbox_pred),
torch.FloatTensor(iou_pred),
torch.FloatTensor(prob_pred),
bbox,
cls,
)
loss = criterion(bbox_pred, iou_pred, prob_pred, bbox, cls)
# back prop
loss.backward()
optimizer.step()
# scheduler.step()
scheduler.step()
epoch_train_loss += loss.item()

model.eval()
for idx, (img, bbox, cls) in enumerate(valid_loader):
model.eval()
with torch.no_grad():
# model prediction
hypothesis = model(img).permute((0, 2, 3, 1))
hypothesis = model(img.to(device)).permute((0, 2, 3, 1))
hypothesis = hypothesis.reshape(
(hypothesis.shape[0], out_size**2, num_anchors, 5 + num_classes)
)
Expand All @@ -126,13 +124,7 @@ def get_util_path():
score_pred.shape
)
# calc loss
loss = criterion(
torch.FloatTensor(bbox_pred),
torch.FloatTensor(iou_pred),
torch.FloatTensor(prob_pred),
bbox,
cls,
)
loss = criterion(bbox_pred, iou_pred, prob_pred, bbox, cls)
epoch_valid_loss += loss.item()

if epoch_valid_loss < best_loss:
Expand Down Expand Up @@ -175,8 +167,8 @@ def post_process_for_bbox(bbox_p):
bbox_p[:, :, :, :2] /= 13

# apply anchors to w, h
anchor_w = anchors[:, 0].contiguous().view(-1, 1)
anchor_h = anchors[:, 1].contiguous().view(-1, 1)
anchor_w = anchors[:, 0].contiguous().view(-1, 1).to(device)
anchor_h = anchors[:, 1].contiguous().view(-1, 1).to(device)
bbox_p[:, :, :, 2:3] *= anchor_w
bbox_p[:, :, :, 3:4] *= anchor_h

Expand Down
20 changes: 14 additions & 6 deletions Applications/YOLOv2/PyTorch/yolo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ def find_best_ratio(anchors, bbox):
class YoloV2_LOSS(nn.Module):
"""Yolo v2 loss"""

def __init__(self, num_classes, img_shape=(416, 416), outsize=(13, 13)):
def __init__(self, num_classes, img_shape, device="cpu", outsize=(13, 13)):
super().__init__()
self.device = device
self.num_classes = num_classes
self.img_shape = img_shape
self.outsize = outsize
Expand Down Expand Up @@ -136,8 +137,8 @@ def apply_anchors_to_bbox(self, bbox_pred):
@param bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
@return bbox_pred shape(batch_size, cell_h x cell_w, num_anchors, 4)
"""
anchor_w = self.anchors[:, 0].contiguous().view(-1, 1)
anchor_h = self.anchors[:, 1].contiguous().view(-1, 1)
anchor_w = self.anchors[:, 0].contiguous().view(-1, 1).to(self.device)
anchor_h = self.anchors[:, 1].contiguous().view(-1, 1).to(self.device)
bbox_pred_tmp = bbox_pred.clone()
bbox_pred_tmp[:, :, :, 2:3] = torch.sqrt(bbox_pred[:, :, :, 2:3] * anchor_w)
bbox_pred_tmp[:, :, :, 3:4] = torch.sqrt(bbox_pred[:, :, :, 3:4] * anchor_h)
Expand All @@ -159,7 +160,7 @@ def _build_target(self, bbox_pred, bbox_gt, cls_gt):
for i in range(batch_size):
_bbox_built, _iou_built, _cls_built, _bbox_mask, _iou_mask, _cls_mask = (
self._make_target_per_sample(
torch.FloatTensor(bbox_pred[i]),
bbox_pred[i],
torch.FloatTensor(np.array(bbox_gt[i])),
torch.LongTensor(cls_gt[i]),
)
Expand All @@ -179,7 +180,14 @@ def _build_target(self, bbox_pred, bbox_gt, cls_gt):
cls_built = torch.stack(cls_built)
cls_mask = torch.stack(cls_mask)

return bbox_built, iou_built, cls_built, bbox_mask, iou_mask, cls_mask
return (
bbox_built.to(self.device),
iou_built.to(self.device),
cls_built.to(self.device),
bbox_mask.to(self.device),
iou_mask.to(self.device),
cls_mask.to(self.device),
)

def _make_target_per_sample(self, _bbox_pred, _bbox_gt, _cls_gt):
"""
Expand Down Expand Up @@ -226,7 +234,7 @@ def _make_target_per_sample(self, _bbox_pred, _bbox_gt, _cls_gt):

# set confidence score of gt
_iou_built = calculate_iou(
_bbox_pred.reshape(-1, 4), _bbox_built.view(-1, 4)
_bbox_pred.reshape(-1, 4), _bbox_built.view(-1, 4).to(self.device)
).detach()
_iou_built = _iou_built.view(hw, num_anchors, 1)
_iou_mask[cell_idx, best_anchors, :] = 1
Expand Down