-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: box detect and segmentation UNet3+ #25
- Loading branch information
MCG
committed
Nov 21, 2024
1 parent
b4b5866
commit 52a3042
Showing
3 changed files
with
239 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
|
||
# python native | ||
import os | ||
from ultralytics import YOLO | ||
|
||
from sklearn.model_selection import GroupKFold | ||
import albumentations as A | ||
# torch | ||
import torch | ||
import torch.optim as optim | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
from Model.FixedModel import UNet_3Plus_DeepSup | ||
from DataSet.DataLoder import get_image_label_paths | ||
from config import IMAGE_ROOT,LABEL_ROOT,BATCH_SIZE,IMSIZE,CLASSES,MILESTONES, GAMMA,LR, SAVED_DIR,YOLO_MODEL_PATH | ||
from DataSet.YOLO_Crop_Dataset import XRayDataset | ||
from Loss.Loss import CombinedLoss | ||
from Train import train | ||
from Util.SetSeed import set_seed | ||
|
||
|
||
|
||
|
||
# Training entry script: builds a group-wise train/validation split, wraps the
# data in YOLO-cropping datasets, and trains UNet 3+ (deep supervision).
#
# NOTE(review): everything below runs at import time. With num_workers > 0 on
# spawn-based platforms (Windows, macOS default) DataLoader workers re-import
# the main module; consider moving this into a main() behind an
# `if __name__ == "__main__":` guard if those platforms need to be supported.

set_seed()

# Load the YOLO model; it is handed to the datasets, which call predict() on it.
YoloModel = YOLO(YOLO_MODEL_PATH)

# exist_ok avoids the check-then-create race of isdir() + makedirs().
os.makedirs(SAVED_DIR, exist_ok=True)

pngs, jsons = get_image_label_paths(IMAGE_ROOT=IMAGE_ROOT, LABEL_ROOT=LABEL_ROOT)

# Split train/valid.
# Each folder holds the `.png` files of both hands of a single person, so the
# folder name is used as the group for GroupKFold. This keeps the hands of the
# same person from being split across train and valid.
groups = [os.path.dirname(fname) for fname in pngs]

# Dummy labels: GroupKFold.split() takes a y argument, but only the groups
# matter for this split.
ys = [0] * len(pngs)

# n_splits=5 so that one fold (20% of the data) can serve as validation data.
gkf = GroupKFold(n_splits=5)

train_filenames = []
train_labelnames = []
valid_filenames = []
valid_labelnames = []
# Only the held-out indices of each fold are used: fold 0 becomes the
# validation set and the held-out parts of the other folds form the train set.
# pngs / jsons are indexed with index arrays here, so they are presumably
# NumPy arrays -- confirm against get_image_label_paths.
for fold, (_, held_out) in enumerate(gkf.split(pngs, ys, groups)):
    if fold == 0:
        valid_filenames += list(pngs[held_out])
        valid_labelnames += list(jsons[held_out])
    else:
        train_filenames += list(pngs[held_out])
        train_labelnames += list(jsons[held_out])

# No albumentations transform is passed; the datasets rely on the YOLO crop.
train_dataset = XRayDataset(
    train_filenames, train_labelnames, is_train=True, yolo_model=YoloModel
)
valid_dataset = XRayDataset(
    valid_filenames, valid_labelnames, is_train=False, yolo_model=YoloModel
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)

# Validation runs one sample at a time, in a fixed order.
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    drop_last=False,
)

model = UNet_3Plus_DeepSup(n_classes=len(CLASSES))

# Loss function: all four terms are weighted equally.
criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=1)

# Optimizer and step-wise learning-rate schedule.
optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=MILESTONES, gamma=GAMMA
)

train(model, train_loader, valid_loader, criterion, optimizer, scheduler)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import os | ||
import cv2 | ||
import numpy as np | ||
import json | ||
import torch | ||
from config import CLASS2IND, CLASSES, IMAGE_ROOT, LABEL_ROOT,YOLO_NAMES,YOLO_SELECT_CLASS,IMSIZE | ||
from Util.SetSeed import set_seed | ||
|
||
# Runs at import time of this module. Presumably fixes the RNG seeds for
# reproducibility -- what set_seed actually seeds is defined in Util.SetSeed
# and is not visible here; confirm there.
set_seed() | ||
|
||
from torch.utils.data import Dataset | ||
|
||
class XRayDataset(Dataset):
    """Dataset of X-ray images with per-class polygon masks and an optional YOLO crop.

    For every sample the image is read from ``IMAGE_ROOT``, a multi-channel
    mask (one channel per entry of ``CLASSES``) is rasterized from the JSON
    annotation under ``LABEL_ROOT``, and, when a YOLO model is supplied, both
    are cropped to a ``crop_size`` window centred on the most confident
    detection of class ``YOLO_SELECT_CLASS``.

    Args:
        filenames: image paths relative to ``IMAGE_ROOT``.
        labelnames: annotation paths relative to ``LABEL_ROOT``, aligned
            index-by-index with ``filenames``.
        transforms: optional callable invoked as ``transforms(image=..., mask=...)``
            (train) or ``transforms(image=...)`` (otherwise) that returns a
            dict with the transformed arrays.
        is_train: when True the mask is passed through ``transforms`` together
            with the image.
        yolo_model: optional object exposing ``predict(...)``; when given it is
            used to locate the crop window.
    """

    def __init__(self, filenames, labelnames, transforms=None, is_train=False, yolo_model=None):
        self.filenames = filenames
        self.labelnames = labelnames
        self.is_train = is_train
        self.transforms = transforms
        self.yolo_model = yolo_model  # optional YOLO detector used for cropping

    def __len__(self):
        """Return the number of samples."""
        return len(self.filenames)

    def __getitem__(self, item):
        """Return ``(image, label)`` float tensors in channel-first layout.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_path = os.path.join(IMAGE_ROOT, self.filenames[item])
        label_path = os.path.join(LABEL_ROOT, self.labelnames[item])

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        label = self._load_label(label_path, image.shape[:2])

        # Crop image and mask with the same window so they stay aligned.
        # NOTE(review): when no matching detection exists the sample is left
        # uncropped, so its size can differ from cropped samples.
        crop_box = self._find_crop_box(image_path, image.shape[:2])
        if crop_box is not None:
            image = self.crop_image(image, crop_box)
            label = self.crop_label(label, crop_box)

        # Normalize after cropping: same values as normalizing first, but the
        # float copy is only made for the pixels that are kept.
        image = image / 255.0

        # Apply augmentations.
        # NOTE(review): outside training only the image is transformed and the
        # mask is left untouched -- confirm this is intended for any transform
        # that changes geometry.
        if self.transforms is not None:
            inputs = {"image": image, "mask": label} if self.is_train else {"image": image}
            result = self.transforms(**inputs)
            image = result["image"]
            label = result["mask"] if self.is_train else label

        # Convert to channel-first tensors.
        image = image.transpose(2, 0, 1)
        label = label.transpose(2, 0, 1)

        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).float()

        return image, label

    def _load_label(self, label_path, image_size):
        """Rasterize the polygons of ``label_path`` into an (H, W, len(CLASSES)) uint8 mask."""
        label = np.zeros(tuple(image_size) + (len(CLASSES), ), dtype=np.uint8)

        with open(label_path, "r") as f:
            annotations = json.load(f)["annotations"]

        for ann in annotations:
            c = ann["label"]
            # Annotations of classes outside CLASSES are ignored.
            if c not in CLASSES:
                continue

            points = np.array(ann["points"])

            # Each annotation overwrites the whole channel of its class.
            class_label = np.zeros(image_size, dtype=np.uint8)
            cv2.fillPoly(class_label, [points], 1)
            label[..., CLASS2IND[c]] = class_label

        return label

    def _find_crop_box(self, image_path, image_size):
        """Return the crop window from the best YOLO detection, or None if there is none."""
        if self.yolo_model is None:
            return None

        results = self.yolo_model.predict(image_path, imgsz=2048, iou=0.3, conf=0.1, max_det=3)
        boxes = results[0].boxes
        yolo_boxes = boxes.xyxy.cpu().numpy()        # (N, 4) box coordinates
        yolo_classes = boxes.cls.cpu().numpy()       # (N,) class indices
        yolo_confidences = boxes.conf.cpu().numpy()  # (N,) confidences

        # Keep only detections of the selected class.
        candidates = [
            (box, conf)
            for box, cls, conf in zip(yolo_boxes, yolo_classes, yolo_confidences)
            if YOLO_NAMES[int(cls)] == YOLO_SELECT_CLASS
        ]
        if not candidates:
            return None

        # Use the most confident detection.
        best_box, _ = max(candidates, key=lambda pair: pair[1])
        return self.calculate_crop_box_from_yolo(best_box, image_size)

    @staticmethod
    def _crop_window(center, extent, size):
        """Return ``(start, end)`` of a ``size``-long window around ``center`` inside ``[0, extent]``."""
        start = max(int(center - size / 2), 0)
        # Shift the window back when it would run past the far edge, so the
        # crop keeps its full size whenever the image is large enough.
        start = max(min(start, int(extent - size)), 0)
        end = min(int(start + size), extent)
        return start, end

    def calculate_crop_box_from_yolo(self, yolo_box, image_size, crop_size=IMSIZE):
        """Calculate the crop box based on a YOLO prediction.

        Args:
            yolo_box: ``(x1, y1, x2, y2)`` corners of the detection.
            image_size: ``(height, width)`` of the image.
            crop_size: side length of the square crop window.

        Returns:
            ``(start_x, start_y, end_x, end_y)``. The window is centred on the
            box centre and shifted to lie inside the image; it is smaller than
            ``crop_size`` only along a dimension where the image itself is
            smaller than ``crop_size``.
        """
        x1, y1, x2, y2 = yolo_box
        start_x, end_x = self._crop_window((x1 + x2) / 2, image_size[1], crop_size)
        start_y, end_y = self._crop_window((y1 + y2) / 2, image_size[0], crop_size)

        return start_x, start_y, end_x, end_y

    def crop_image(self, image, crop_box):
        """Crop the image to the specified ``(start_x, start_y, end_x, end_y)`` box."""
        start_x, start_y, end_x, end_y = crop_box
        return image[start_y:end_y, start_x:end_x]

    def crop_label(self, label, crop_box):
        """Crop the label array to match the cropped image."""
        start_x, start_y, end_x, end_y = crop_box
        return label[start_y:end_y, start_x:end_x, :]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters