Skip to content

Commit

Permalink
feat: box detect and segmentation UNet3+ #25
Browse files Browse the repository at this point in the history
  • Loading branch information
MCG committed Nov 21, 2024
1 parent b4b5866 commit 52a3042
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 7 deletions.
94 changes: 94 additions & 0 deletions UNet3+/Code/CropTrainRun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

# python native
import os
from ultralytics import YOLO

from sklearn.model_selection import GroupKFold
import albumentations as A
# torch
import torch
import torch.optim as optim
from torch.utils.data import DataLoader


from Model.FixedModel import UNet_3Plus_DeepSup
from DataSet.DataLoder import get_image_label_paths
from config import IMAGE_ROOT,LABEL_ROOT,BATCH_SIZE,IMSIZE,CLASSES,MILESTONES, GAMMA,LR, SAVED_DIR,YOLO_MODEL_PATH
from DataSet.YOLO_Crop_Dataset import XRayDataset
from Loss.Loss import CombinedLoss
from Train import train
from Util.SetSeed import set_seed




set_seed()
# Load the YOLO model
YoloModel = YOLO(YOLO_MODEL_PATH)

if not os.path.isdir(SAVED_DIR):
os.makedirs(SAVED_DIR)

pngs, jsons=get_image_label_paths(IMAGE_ROOT=IMAGE_ROOT,LABEL_ROOT=LABEL_ROOT)
#print(pngs, jsons)

# split train-valid
# 한 폴더 안에 한 인물의 양손에 대한 `.png` 파일이 존재하기 때문에
# 폴더 이름을 그룹으로 해서 GroupKFold를 수행합니다.
# 동일 인물의 손이 train, valid에 따로 들어가는 것을 방지합니다.
groups = [os.path.dirname(fname) for fname in pngs]

# dummy label
ys = [0 for fname in pngs]

# 전체 데이터의 20%를 validation data로 쓰기 위해 `n_splits`를
# 5으로 설정하여 GroupKFold를 수행합니다.
gkf = GroupKFold(n_splits=5)

train_filenames = []
train_labelnames = []
valid_filenames = []
valid_labelnames = []
for i, (x, y) in enumerate(gkf.split(pngs, ys, groups)):
# 0번을 validation dataset으로 사용합니다.
if i == 0:
valid_filenames += list(pngs[y])
valid_labelnames += list(jsons[y])

else:
train_filenames += list(pngs[y])
train_labelnames += list(jsons[y])

#tf = A.Resize(IMSIZE,IMSIZE)
train_dataset = XRayDataset(train_filenames, train_labelnames, is_train=True,yolo_model=YoloModel)
valid_dataset = XRayDataset(valid_filenames, valid_labelnames, is_train=False,yolo_model=YoloModel)


train_loader = DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=8,
drop_last=True,
)

valid_loader = DataLoader(
dataset=valid_dataset,
batch_size=1,
shuffle=False,
num_workers=2,
drop_last=False
)


model = UNet_3Plus_DeepSup(n_classes=len(CLASSES))

# Loss function 정의
criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=1)

# Optimizer 정의
optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=GAMMA)


train(model, train_loader, valid_loader, criterion, optimizer,scheduler)
118 changes: 118 additions & 0 deletions UNet3+/Code/DataSet/YOLO_Crop_Dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import os
import cv2
import numpy as np
import json
import torch
from config import CLASS2IND, CLASSES, IMAGE_ROOT, LABEL_ROOT,YOLO_NAMES,YOLO_SELECT_CLASS,IMSIZE
from Util.SetSeed import set_seed

set_seed()

from torch.utils.data import Dataset

class XRayDataset(Dataset):
def __init__(self, filenames, labelnames, transforms=None, is_train=False, yolo_model=None):
self.filenames = filenames
self.labelnames = labelnames
self.is_train = is_train
self.transforms = transforms
self.yolo_model = yolo_model # YOLO 모델 추가

def __len__(self):
return len(self.filenames)

def __getitem__(self, item):
image_name = self.filenames[item]
image_path = os.path.join(IMAGE_ROOT, image_name)

# Read and normalize image
image = cv2.imread(image_path)
image = image / 255.0

label_name = self.labelnames[item]
label_path = os.path.join(LABEL_ROOT, label_name)

# Initialize label tensor
label_shape = tuple(image.shape[:2]) + (len(CLASSES), )
label = np.zeros(label_shape, dtype=np.uint8)

# Read label file
with open(label_path, "r") as f:
annotations = json.load(f)
annotations = annotations["annotations"]

# Generate masks for all annotations
for ann in annotations:
c = ann["label"]
if c not in CLASSES:
continue

class_ind = CLASS2IND[c]
points = np.array(ann["points"])

# Generate masks
class_label = np.zeros(image.shape[:2], dtype=np.uint8)
cv2.fillPoly(class_label, [points], 1)
label[..., class_ind] = class_label

# YOLO 예측 결과에서 others 클래스 박스 가져오기
if self.yolo_model:
results = self.yolo_model.predict(image_path, imgsz=2048, iou=0.3, conf=0.1, max_det=3)
result=results[0].boxes
yolo_boxes = result.xyxy.cpu().numpy() # (N, 4) 형식의 박스 좌표
yolo_classes = result.cls.cpu().numpy() # (N,) 형식의 클래스
yolo_confidences = result.conf.cpu().numpy() # (N,) 형식의 신뢰도

# others 클래스 필터링
others_boxes = [
(box, conf) for box, cls, conf in zip(yolo_boxes, yolo_classes, yolo_confidences)
if YOLO_NAMES[int(cls)] == YOLO_SELECT_CLASS
]

# 신뢰도가 가장 높은 박스 선택
if others_boxes:
best_box, _ = max(others_boxes, key=lambda x: x[1]) # (x1, y1, x2, y2) 좌표
crop_box = self.calculate_crop_box_from_yolo(best_box, image.shape[:2])
image = self.crop_image(image, crop_box)
label = self.crop_label(label, crop_box)

# Apply augmentations
if self.transforms is not None:
inputs = {"image": image, "mask": label} if self.is_train else {"image": image}
result = self.transforms(**inputs)
image = result["image"]
label = result["mask"] if self.is_train else label

# Convert to tensor
image = image.transpose(2, 0, 1) # Channel first
label = label.transpose(2, 0, 1)

image = torch.from_numpy(image).float()
label = torch.from_numpy(label).float()

return image, label

def calculate_crop_box_from_yolo(self, yolo_box, image_size, crop_size=IMSIZE):
"""Calculate the crop box based on YOLO prediction."""
x1, y1, x2, y2 = yolo_box
center_x = (x1 + x2) / 2
center_y = (y1 + y2) / 2

half_size = crop_size / 2
start_x = max(int(center_x - half_size), 0)
start_y = max(int(center_y - half_size), 0)
end_x = min(int(start_x + crop_size), image_size[1])
end_y = min(int(start_y + crop_size), image_size[0])

return start_x, start_y, end_x, end_y

def crop_image(self, image, crop_box):
"""Crop the image to the specified box."""
start_x, start_y, end_x, end_y = crop_box
cropped_image = image[start_y:end_y, start_x:end_x]
return cropped_image

def crop_label(self, label, crop_box):
"""Crop the label tensor to match the cropped image."""
start_x, start_y, end_x, end_y = crop_box
return label[start_y:end_y, start_x:end_x, :]
34 changes: 27 additions & 7 deletions UNet3+/Code/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,61 @@

WEBHOOK_URL = 'https://discord.com/api/webhooks/1306529597055041562/DUG0omhuBla0YM6SVqVgcSWgRBP2D0WZ5_xJt9aNvLL2QJpIHicb4tRupbvWYRLRtEgN'


YOLO_MODEL_PATH="/data/ephemeral/home/MCG/YOLO_Detection_Model/best.pt"

IMAGE_ROOT = "/data/ephemeral/home/MCG/data/train/DCM"
LABEL_ROOT = "/data/ephemeral/home/MCG/data/train/outputs_json"
CLASSES = [

'''CLASSES = [
'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5',
'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10',
'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15',
'finger-16', 'finger-17', 'finger-18', 'finger-19', 'Trapezium',
'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]
'''

CLASSES = [
'Trapezium',
'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
'Triquetrum', 'Pisiform',
]


CLASS2IND = {v: i for i, v in enumerate(CLASSES)}
IND2CLASS = {v: k for k, v in CLASS2IND.items()}

BATCH_SIZE = 1

YOLO_NAMES = {0: "finger", 1: "radius_ulna", 2: "others"} # YOLO 클래스 인덱스 매핑
YOLO_SELECT_CLASS="others"


RANDOM_SEED = 21

# 적절하게 조절
NUM_EPOCHS =30
VAL_EVERY = 1
IMSIZE=1024

BATCH_SIZE = 1
IMSIZE=480

LR = 0.0001
MILESTONES=[7,16,23,27]
GAMMA=0.3


SAVED_DIR = "/data/ephemeral/home/MCG/UNetRefactored/Creadted_model/"
MODELNAME="best_NewModel.pt"
MODELNAME="othersCrop.pt"
if not os.path.isdir(SAVED_DIR):
os.mkdir(SAVED_DIR)



INFERENCE_MODEL_NAME="best_NewModel.pt"
INFERENCE_MODEL_NAME="othersCrop.pt"

TEST_IMAGE_ROOT="../../data/test/DCM"
TEST_IMAGE_ROOT="/data/ephemeral/home/MCG/data/test/DCM"

CSVDIR="/data/ephemeral/home/MCG/UNetRefactored/CSV"
CSVNAME="best_NewModel"
CSVNAME="othersCrop"

0 comments on commit 52a3042

Please sign in to comment.