From 52a3042365b40b1f40dc39ba6d53922d4eb4fd6c Mon Sep 17 00:00:00 2001 From: MCG <000914m@gmail.com> Date: Thu, 21 Nov 2024 19:46:59 +0900 Subject: [PATCH] feat: box detect and segmentation UNet3+ #25 --- UNet3+/Code/CropTrainRun.py | 94 ++++++++++++++++++ UNet3+/Code/DataSet/YOLO_Crop_Dataset.py | 118 +++++++++++++++++++++++ UNet3+/Code/config.py | 34 +++++-- 3 files changed, 239 insertions(+), 7 deletions(-) create mode 100644 UNet3+/Code/CropTrainRun.py create mode 100644 UNet3+/Code/DataSet/YOLO_Crop_Dataset.py diff --git a/UNet3+/Code/CropTrainRun.py b/UNet3+/Code/CropTrainRun.py new file mode 100644 index 0000000..c4233a9 --- /dev/null +++ b/UNet3+/Code/CropTrainRun.py @@ -0,0 +1,94 @@ + +# python native +import os +from ultralytics import YOLO + +from sklearn.model_selection import GroupKFold +import albumentations as A +# torch +import torch +import torch.optim as optim +from torch.utils.data import DataLoader + + +from Model.FixedModel import UNet_3Plus_DeepSup +from DataSet.DataLoder import get_image_label_paths +from config import IMAGE_ROOT,LABEL_ROOT,BATCH_SIZE,IMSIZE,CLASSES,MILESTONES, GAMMA,LR, SAVED_DIR,YOLO_MODEL_PATH +from DataSet.YOLO_Crop_Dataset import XRayDataset +from Loss.Loss import CombinedLoss +from Train import train +from Util.SetSeed import set_seed + + + + +set_seed() +# Load the YOLO model +YoloModel = YOLO(YOLO_MODEL_PATH) + +if not os.path.isdir(SAVED_DIR): + os.makedirs(SAVED_DIR) + +pngs, jsons=get_image_label_paths(IMAGE_ROOT=IMAGE_ROOT,LABEL_ROOT=LABEL_ROOT) +#print(pngs, jsons) + +# split train-valid +# 한 폴더 안에 한 인물의 양손에 대한 `.png` 파일이 존재하기 때문에 +# 폴더 이름을 그룹으로 해서 GroupKFold를 수행합니다. +# 동일 인물의 손이 train, valid에 따로 들어가는 것을 방지합니다. +groups = [os.path.dirname(fname) for fname in pngs] + +# dummy label +ys = [0 for fname in pngs] + +# 전체 데이터의 20%를 validation data로 쓰기 위해 `n_splits`를 +# 5으로 설정하여 GroupKFold를 수행합니다. +gkf = GroupKFold(n_splits=5) + +train_filenames = [] +train_labelnames = [] +valid_filenames = [] +valid_labelnames = [] +for i, (x, y) in enumerate(gkf.split(pngs, ys, groups)): + # 0번을 validation dataset으로 사용합니다. + if i == 0: + valid_filenames += list(pngs[y]) + valid_labelnames += list(jsons[y]) + + else: + train_filenames += list(pngs[y]) + train_labelnames += list(jsons[y]) + +#tf = A.Resize(IMSIZE,IMSIZE) +train_dataset = XRayDataset(train_filenames, train_labelnames, is_train=True,yolo_model=YoloModel) +valid_dataset = XRayDataset(valid_filenames, valid_labelnames, is_train=False,yolo_model=YoloModel) + + +train_loader = DataLoader( + dataset=train_dataset, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=8, + drop_last=True, +) + +valid_loader = DataLoader( + dataset=valid_dataset, + batch_size=1, + shuffle=False, + num_workers=2, + drop_last=False +) + + +model = UNet_3Plus_DeepSup(n_classes=len(CLASSES)) + +# Loss function 정의 +criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=1) + +# Optimizer 정의 +optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-6) +scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=GAMMA) + + +train(model, train_loader, valid_loader, criterion, optimizer,scheduler) \ No newline at end of file diff --git a/UNet3+/Code/DataSet/YOLO_Crop_Dataset.py b/UNet3+/Code/DataSet/YOLO_Crop_Dataset.py new file mode 100644 index 0000000..37680ff --- /dev/null +++ b/UNet3+/Code/DataSet/YOLO_Crop_Dataset.py @@ -0,0 +1,118 @@ +import os +import cv2 +import numpy as np +import json +import torch +from config import CLASS2IND, CLASSES, IMAGE_ROOT, LABEL_ROOT,YOLO_NAMES,YOLO_SELECT_CLASS,IMSIZE +from Util.SetSeed import set_seed + +set_seed() + +from torch.utils.data import Dataset + +class XRayDataset(Dataset): + def __init__(self, filenames, labelnames, transforms=None, is_train=False, yolo_model=None): + self.filenames = filenames + self.labelnames = labelnames + self.is_train = is_train + self.transforms = transforms + self.yolo_model = yolo_model # YOLO 모델 추가 + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, item): + image_name = self.filenames[item] + image_path = os.path.join(IMAGE_ROOT, image_name) + + # Read and normalize image + image = cv2.imread(image_path) + image = image / 255.0 + + label_name = self.labelnames[item] + label_path = os.path.join(LABEL_ROOT, label_name) + + # Initialize label tensor + label_shape = tuple(image.shape[:2]) + (len(CLASSES), ) + label = np.zeros(label_shape, dtype=np.uint8) + + # Read label file + with open(label_path, "r") as f: + annotations = json.load(f) + annotations = annotations["annotations"] + + # Generate masks for all annotations + for ann in annotations: + c = ann["label"] + if c not in CLASSES: + continue + + class_ind = CLASS2IND[c] + points = np.array(ann["points"]) + + # Generate masks + class_label = np.zeros(image.shape[:2], dtype=np.uint8) + cv2.fillPoly(class_label, [points], 1) + label[..., class_ind] = class_label + + # YOLO 예측 결과에서 others 클래스 박스 가져오기 + if self.yolo_model: + results = self.yolo_model.predict(image_path, imgsz=2048, iou=0.3, conf=0.1, max_det=3) + result=results[0].boxes + yolo_boxes = result.xyxy.cpu().numpy() # (N, 4) 형식의 박스 좌표 + yolo_classes = result.cls.cpu().numpy() # (N,) 형식의 클래스 + yolo_confidences = result.conf.cpu().numpy() # (N,) 형식의 신뢰도 + + # others 클래스 필터링 + others_boxes = [ + (box, conf) for box, cls, conf in zip(yolo_boxes, yolo_classes, yolo_confidences) + if YOLO_NAMES[int(cls)] == YOLO_SELECT_CLASS + ] + + # 신뢰도가 가장 높은 박스 선택 + if others_boxes: + best_box, _ = max(others_boxes, key=lambda x: x[1]) # (x1, y1, x2, y2) 좌표 + crop_box = self.calculate_crop_box_from_yolo(best_box, image.shape[:2]) + image = self.crop_image(image, crop_box) + label = self.crop_label(label, crop_box) + + # Apply augmentations + if self.transforms is not None: + inputs = {"image": image, "mask": label} if self.is_train else {"image": image} + result = self.transforms(**inputs) + image = result["image"] + label = result["mask"] if self.is_train else label + + # Convert to tensor + image = image.transpose(2, 0, 1) # Channel first + label = label.transpose(2, 0, 1) + + image = torch.from_numpy(image).float() + label = torch.from_numpy(label).float() + + return image, label + + def calculate_crop_box_from_yolo(self, yolo_box, image_size, crop_size=IMSIZE): + """Calculate the crop box based on YOLO prediction.""" + x1, y1, x2, y2 = yolo_box + center_x = (x1 + x2) / 2 + center_y = (y1 + y2) / 2 + + half_size = crop_size / 2 + start_x = max(int(center_x - half_size), 0) + start_y = max(int(center_y - half_size), 0) + end_x = min(int(start_x + crop_size), image_size[1]) + end_y = min(int(start_y + crop_size), image_size[0]) + + return start_x, start_y, end_x, end_y + + def crop_image(self, image, crop_box): + """Crop the image to the specified box.""" + start_x, start_y, end_x, end_y = crop_box + cropped_image = image[start_y:end_y, start_x:end_x] + return cropped_image + + def crop_label(self, label, crop_box): + """Crop the label tensor to match the cropped image.""" + start_x, start_y, end_x, end_y = crop_box + return label[start_y:end_y, start_x:end_x, :] diff --git a/UNet3+/Code/config.py b/UNet3+/Code/config.py index d6acfd4..61428b5 100644 --- a/UNet3+/Code/config.py +++ b/UNet3+/Code/config.py @@ -3,9 +3,13 @@ WEBHOOK_URL = 'https://discord.com/api/webhooks/1306529597055041562/DUG0omhuBla0YM6SVqVgcSWgRBP2D0WZ5_xJt9aNvLL2QJpIHicb4tRupbvWYRLRtEgN' + +YOLO_MODEL_PATH="/data/ephemeral/home/MCG/YOLO_Detection_Model/best.pt" + IMAGE_ROOT = "/data/ephemeral/home/MCG/data/train/DCM" LABEL_ROOT = "/data/ephemeral/home/MCG/data/train/outputs_json" -CLASSES = [ + +'''CLASSES = [ 'finger-1', 'finger-2', 'finger-3', 'finger-4', 'finger-5', 'finger-6', 'finger-7', 'finger-8', 'finger-9', 'finger-10', 'finger-11', 'finger-12', 'finger-13', 'finger-14', 'finger-15', @@ -13,31 +17,47 @@ 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate', 'Triquetrum', 'Pisiform', 'Radius', 'Ulna', ] +''' + +CLASSES = [ + 'Trapezium', + 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate', + 'Triquetrum', 'Pisiform', +] + + CLASS2IND = {v: i for i, v in enumerate(CLASSES)} IND2CLASS = {v: k for k, v in CLASS2IND.items()} -BATCH_SIZE = 1 + +YOLO_NAMES = {0: "finger", 1: "radius_ulna", 2: "others"} # YOLO 클래스 인덱스 매핑 +YOLO_SELECT_CLASS="others" + + RANDOM_SEED = 21 # 적절하게 조절 NUM_EPOCHS =30 VAL_EVERY = 1 -IMSIZE=1024 + +BATCH_SIZE = 1 +IMSIZE=480 + LR = 0.0001 MILESTONES=[7,16,23,27] GAMMA=0.3 SAVED_DIR = "/data/ephemeral/home/MCG/UNetRefactored/Creadted_model/" -MODELNAME="best_NewModel.pt" +MODELNAME="othersCrop.pt" if not os.path.isdir(SAVED_DIR): os.mkdir(SAVED_DIR) -INFERENCE_MODEL_NAME="best_NewModel.pt" +INFERENCE_MODEL_NAME="othersCrop.pt" -TEST_IMAGE_ROOT="../../data/test/DCM" +TEST_IMAGE_ROOT="/data/ephemeral/home/MCG/data/test/DCM" CSVDIR="/data/ephemeral/home/MCG/UNetRefactored/CSV" -CSVNAME="best_NewModel" \ No newline at end of file +CSVNAME="othersCrop" \ No newline at end of file