Skip to content

Commit

Permalink
feat: Visualization Crop Image and Label UNet3+ #25
Browse files Browse the repository at this point in the history
  • Loading branch information
MCG committed Nov 21, 2024
1 parent 79b3538 commit cec78de
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 6 deletions.
6 changes: 3 additions & 3 deletions UNet3+/Code/CropTrainRun.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from Model.FixedModel import UNet_3Plus_DeepSup
from DataSet.DataLoder import get_image_label_paths
from config import IMAGE_ROOT, LABEL_ROOT, BATCH_SIZE, IMSIZE, CLASSES, MILESTONES, GAMMA, LR, SAVED_DIR, YOLO_MODEL_PATH
from config import IMAGE_ROOT, LABEL_ROOT, BATCH_SIZE, IMSIZE, CLASSES, MILESTONES, GAMMA, LR, SAVED_DIR, YOLO_MODEL_PATH, VISUALIZE_TRAIN_DATA, SAVE_VISUALIZE_TRAIN_DATA_PATH
from DataSet.YOLO_Crop_Dataset import XRayDataset
from Loss.Loss import CombinedLoss
from Train import train
Expand Down Expand Up @@ -60,8 +60,8 @@
train_labelnames += list(jsons[y])

# tf = A.Resize(IMSIZE,IMSIZE)
train_dataset = XRayDataset(train_filenames, train_labelnames, is_train=True, yolo_model=YoloModel)
valid_dataset = XRayDataset(valid_filenames, valid_labelnames, is_train=False, yolo_model=YoloModel)
train_dataset = XRayDataset(train_filenames, train_labelnames, is_train=True, yolo_model=YoloModel, save_dir=SAVE_VISUALIZE_TRAIN_DATA_PATH, draw_enabled=VISUALIZE_TRAIN_DATA)
valid_dataset = XRayDataset(valid_filenames, valid_labelnames, is_train=False, yolo_model=YoloModel, save_dir=None, draw_enabled=False)

train_loader = DataLoader(
dataset=train_dataset,
Expand Down
46 changes: 44 additions & 2 deletions UNet3+/Code/DataSet/YOLO_Crop_Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
from torch.utils.data import Dataset

class XRayDataset(Dataset):
def __init__(self, filenames, labelnames, transforms=None, is_train=False, yolo_model=None):
def __init__(self, filenames, labelnames, transforms=None,
is_train=False, yolo_model=None, save_dir=None, draw_enabled=False):
self.filenames = filenames
self.labelnames = labelnames
self.is_train = is_train
self.transforms = transforms
self.yolo_model = yolo_model # YOLO 모델 추가
self.save_dir = save_dir # Crop된 이미지 저장 디렉토리
self.draw_enabled = draw_enabled # 라벨 그리기 기능 활성화 여부

def __len__(self):
return len(self.filenames)
Expand Down Expand Up @@ -83,13 +86,19 @@ def __getitem__(self, item):
image = result["image"]
label = result["mask"] if self.is_train else label

if self.draw_enabled and self.save_dir:
os.makedirs(self.save_dir, exist_ok=True)
save_path = os.path.join(self.save_dir, f"cropped_{os.path.basename(self.filenames[item])}")
draw_and_save_crop(image, label, save_path)

# Convert to tensor
image = image.transpose(2, 0, 1) # Channel first
label = label.transpose(2, 0, 1)

image = torch.from_numpy(image).float()
label = torch.from_numpy(label).float()



return image, label

def calculate_crop_box_from_yolo(self, yolo_box, image_size, crop_size=IMSIZE):
Expand All @@ -116,3 +125,36 @@ def crop_label(self, label, crop_box):
"""Crop the label tensor to match the cropped image."""
start_x, start_y, end_x, end_y = crop_box
return label[start_y:end_y, start_x:end_x, :]



import cv2
import numpy as np

def draw_and_save_crop(image, label, save_path):
"""
Crop된 이미지 위에 라벨 정보를 그려 저장합니다.
Args:
image (np.ndarray): Crop된 이미지 (H, W, C).
label (np.ndarray): Crop된 라벨 (H, W, num_classes).
save_path (str): 저장할 파일 경로.
"""
# 이미지 복사
image_to_draw = (image * 255).astype(np.uint8).copy() # 이미지 복원 (0~255)

# 클래스별 색상 설정
num_classes = label.shape[-1]
colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)] # 클래스별 색상

# 클래스별로 라벨을 이미지에 그리기
for class_idx in range(num_classes):
mask = label[..., class_idx].astype(np.uint8)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

# 컨투어 그리기
for contour in contours:
cv2.drawContours(image_to_draw, [contour], -1, colors[class_idx % len(colors)], 2)

# 저장
cv2.imwrite(save_path, image_to_draw)
4 changes: 3 additions & 1 deletion UNet3+/Code/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@


YOLO_MODEL_PATH="/data/ephemeral/home/MCG/YOLO_Detection_Model/best.pt"
SAVE_VISUALIZE_TRAIN_DATA_PATH="/data/ephemeral/home/MCG/YOLO_Detection_Model/crop_train_Image"
VISUALIZE_TRAIN_DATA=True

IMAGE_ROOT = "/data/ephemeral/home/MCG/data/train/DCM"
LABEL_ROOT = "/data/ephemeral/home/MCG/data/train/outputs_json"
Expand Down Expand Up @@ -37,7 +39,7 @@
RANDOM_SEED = 21

# 적절하게 조절
NUM_EPOCHS =30
NUM_EPOCHS =20
VAL_EVERY = 1

BATCH_SIZE = 1
Expand Down

0 comments on commit cec78de

Please sign in to comment.