feat: Crop Inference and Merge CSV UNet3+ #25
MCG committed Nov 24, 2024
1 parent 925e705, commit cf54ed2

Showing 3 changed files with 213 additions and 96 deletions.
@@ -1,128 +1,105 @@
import albumentations as A
from torch.utils.data import DataLoader,Dataset
from config import IND2CLASS,SAVED_DIR, INFERENCE_MODEL_NAME, IMSIZE, CSVDIR,CSVNAME,CLASSES,TEST_IMAGE_ROOT
from torch.utils.data import DataLoader, Dataset
from config import IND2CLASS, SAVED_DIR, INFERENCE_MODEL_NAME, IMSIZE, CSVDIR, CSVNAME, CLASSES, TEST_IMAGE_ROOT, SAVE_VISUALIZE_TRAIN_DATA_PATH
import os
import torch
import pandas as pd
import numpy as np
import cv2
from tqdm.auto import tqdm

from DataSet.YoloBaseCropDataset import XRayDataset

model = torch.load(os.path.join(SAVED_DIR, INFERENCE_MODEL_NAME))

pngs = {
    os.path.relpath(os.path.join(root, fname), start=TEST_IMAGE_ROOT)
    for root, _dirs, files in os.walk(TEST_IMAGE_ROOT)
    for fname in files
    if os.path.splitext(fname)[1].lower() == ".png"
}
from ultralytics import YOLO
from DataSet.YoloInferenceDataset import XRayInferenceDataset
import torch.nn.functional as F
import torch.multiprocessing as mp

def encode_mask_to_rle(mask):
    '''
    mask: numpy array binary mask
    1 - mask
    0 - background
    Returns encoded run length
    '''
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def decode_rle_to_mask(rle, height, width):
    s = rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(height * width, dtype=np.uint8)

    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1

    return img.reshape(height, width)

class XRayInferenceDataset(Dataset):
    def __init__(self, transforms=None):
        _filenames = pngs
        _filenames = np.array(sorted(_filenames))

        self.filenames = _filenames
        self.transforms = transforms

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        image_name = self.filenames[item]
        image_path = os.path.join(TEST_IMAGE_ROOT, image_name)

        image = cv2.imread(image_path)
        image = image / 255.

        if self.transforms is not None:
            inputs = {"image": image}
            result = self.transforms(**inputs)
            image = result["image"]

        # to tensor will be done later
        image = image.transpose(2, 0, 1)  # make channel first

        image = torch.from_numpy(image).float()

        return image, image_name

import torch.nn.functional as F
def test(model, data_loader, thr=0.5):
    model = model.cuda()
    model.eval()

    rles = []
    filename_and_class = []
    with torch.no_grad():
        n_class = len(CLASSES)

        for step, (images, image_names) in tqdm(enumerate(data_loader), total=len(data_loader)):
        for step, (images, image_names, crop_boxes) in tqdm(enumerate(data_loader), total=len(data_loader)):
            images = images.cuda()
            outputs = model(images)

            # restore original size
            outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear")
            outputs = (outputs > thr).detach().cpu().numpy()
            for output, image_name, crop_box in zip(outputs, image_names, crop_boxes):
                print(crop_box)
                start_x, start_y, end_x, end_y = crop_box
                crop_width = end_x - start_x
                crop_height = end_y - start_y

                # Interpolate to crop size
                output = F.interpolate(output.unsqueeze(0), size=(crop_height, crop_width), mode="nearest")
                output = (output > thr).squeeze(0).detach().cpu().numpy()

            for output, image_name in zip(outputs, image_names):
                for c, segm in enumerate(output):
                    rle = encode_mask_to_rle(segm)
                    full_size_mask = np.zeros((2048, 2048), dtype=np.uint8)
                    full_size_mask[start_y:end_y, start_x:end_x] = segm
                    rle = encode_mask_to_rle(full_size_mask)
                    rles.append(rle)
                    filename_and_class.append(f"{IND2CLASS[c]}_{image_name}")

    return rles, filename_and_class


tf = A.Resize(IMSIZE, IMSIZE)
test_dataset = XRayInferenceDataset(transforms=tf)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    drop_last=False
)

rles, filename_and_class = test(model, test_loader)

classes, filename = zip(*[x.split("_") for x in filename_and_class])

image_name = [os.path.basename(f) for f in filename]

df = pd.DataFrame({
    "image_name": image_name,
    "class": classes,
    "rle": rles,
})

df.to_csv(os.path.join(CSVDIR, CSVNAME),index=False)
if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # use the spawn start method to work around CUDA issues

    # load the models
    model = torch.load(os.path.join(SAVED_DIR, INFERENCE_MODEL_NAME))
    yolo_model = YOLO("/data/ephemeral/home/MCG/YOLO_Detection_Model/best.pt")  # YOLO detection model

    # collect the PNG files
    pngs = {
        os.path.relpath(os.path.join(root, fname), start=TEST_IMAGE_ROOT)
        for root, _dirs, files in os.walk(TEST_IMAGE_ROOT)
        for fname in files
        if os.path.splitext(fname)[1].lower() == ".png"
    }

    # build the dataset
    test_dataset = XRayInferenceDataset(
        filenames=pngs,
        yolo_model=yolo_model,  # pass the YOLO model
        save_dir=SAVE_VISUALIZE_TRAIN_DATA_PATH,
        draw_enabled=True
    )
    def custom_collate_fn(batch):
        images, image_names, crop_boxes = zip(*batch)
        return (
            torch.stack(images),  # stack the image tensors
            list(image_names),    # keep the filenames as a list
            list(crop_boxes)      # keep the crop boxes as a list
        )

    # build the data loader
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=0,  # disable worker multiprocessing
        drop_last=False,
        collate_fn=custom_collate_fn
    )

    # run inference
    rles, filename_and_class = test(model, test_loader)

    # save the results
    classes, filename = zip(*[x.split("_") for x in filename_and_class])
    image_name = [os.path.basename(f) for f in filename]

    df = pd.DataFrame({
        "image_name": image_name,
        "class": classes,
        "rle": rles,
    })

    df.to_csv(os.path.join(CSVDIR, CSVNAME), index=False)
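A quick sanity check (not part of the commit) of how the pieces above fit together: in test(), each cropped prediction is pasted into a 2048x2048 canvas at its crop box and only then run-length encoded, so the submitted RLE is always in full-image coordinates. The sketch below copies the two RLE helpers from the file above and verifies the round trip; the 1024x1024 crop and the crop-box values are illustrative assumptions.

import numpy as np

def encode_mask_to_rle(mask):
    # same helper as in the inference script above
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

def decode_rle_to_mask(rle, height, width):
    # same helper as in the inference script above
    s = rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(height * width, dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(height, width)

# a hypothetical cropped prediction pasted back at its crop box, as test() does
crop = np.ones((1024, 1024), dtype=np.uint8)
start_x, start_y, end_x, end_y = 512, 256, 1536, 1280  # made-up crop box
full_size_mask = np.zeros((2048, 2048), dtype=np.uint8)
full_size_mask[start_y:end_y, start_x:end_x] = crop

rle = encode_mask_to_rle(full_size_mask)
assert np.array_equal(decode_rle_to_mask(rle, 2048, 2048), full_size_mask)  # lossless round trip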
@@ -0,0 +1,97 @@
import os
import cv2
import numpy as np
import json
import torch
from config import CLASS2IND, CLASSES, IMAGE_ROOT, LABEL_ROOT, YOLO_NAMES, YOLO_SELECT_CLASS, IMSIZE, TEST_IMAGE_ROOT
from Util.SetSeed import set_seed

set_seed()

from torch.utils.data import Dataset

class XRayInferenceDataset(Dataset):
    def __init__(self, filenames, yolo_model, transforms=None, save_dir=None, draw_enabled=False):
        _filenames = filenames
        _filenames = np.array(sorted(_filenames))
        self.yolo_model = yolo_model
        self.filenames = _filenames
        self.transforms = transforms
        self.save_dir = save_dir
        self.draw_enabled = draw_enabled

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        image_name = self.filenames[item]
        image_path = os.path.join(TEST_IMAGE_ROOT, image_name)

        image = cv2.imread(image_path)

        if self.yolo_model:
            results = self.yolo_model.predict(image_path, imgsz=2048, iou=0.3, conf=0.1, max_det=3)
            result = results[0].boxes
            yolo_boxes = result.xyxy.cpu().numpy()        # box coordinates, shape (N, 4)
            yolo_classes = result.cls.cpu().numpy()       # class indices, shape (N,)
            yolo_confidences = result.conf.cpu().numpy()  # confidence scores, shape (N,)

            # filter for the selected class
            others_boxes = [
                (box, conf) for box, cls, conf in zip(yolo_boxes, yolo_classes, yolo_confidences)
                if YOLO_NAMES[int(cls)] == YOLO_SELECT_CLASS
            ]

            # select the box with the highest confidence
            if others_boxes:
                best_box, _ = max(others_boxes, key=lambda x: x[1])  # (x1, y1, x2, y2) coordinates
                crop_box = self.calculate_crop_box_from_yolo(best_box, image.shape[:2])
                image = self.crop_image(image, crop_box)
                print(crop_box, "@@@@")

        image = image / 255.

        if self.transforms is not None:
            inputs = {"image": image}
            result = self.transforms(**inputs)
            image = result["image"]

        if self.draw_enabled and self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
            save_path = os.path.join(self.save_dir, f"cropped_{os.path.basename(self.filenames[item])}")
            self.save_crop(image, save_path)

        # to tensor will be done later
        image = image.transpose(2, 0, 1)  # make channel first

        image = torch.from_numpy(image).float()

        return image, image_name, crop_box

    def calculate_crop_box_from_yolo(self, yolo_box, image_size, crop_size=IMSIZE):
        """Calculate the crop box based on YOLO prediction."""
        x1, y1, x2, y2 = yolo_box
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2

        half_size = crop_size / 2
        start_x = max(int(center_x - half_size), 0)
        start_y = max(int(center_y - half_size), 0)
        end_x = min(int(start_x + crop_size), image_size[1])
        end_y = min(int(start_y + crop_size), image_size[0])
        print(start_x, start_y, end_x, end_y)

        return start_x, start_y, end_x, end_y

    def crop_image(self, image, crop_box):
        """Crop the image to the specified box."""
        start_x, start_y, end_x, end_y = crop_box
        cropped_image = image[start_y:end_y, start_x:end_x]
        return cropped_image

    def save_crop(self, image, save_path):
        # restore the image to the 0-255 range and copy it
        image_to_draw = (image * 255).astype(np.uint8).copy()
        # write the crop to disk
        cv2.imwrite(save_path, image_to_draw)
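calculate_crop_box_from_yolo centres an IMSIZE window on the YOLO box and clamps it to the image bounds. A standalone sketch (not part of the commit, with IMSIZE assumed to be 1024) shows the border behaviour: clamping at the top/left keeps the full window size, while clamping at the bottom/right yields a crop smaller than IMSIZE, which is why test() interpolates each prediction back to (crop_height, crop_width) before pasting.

IMSIZE = 1024  # assumption for illustration; the real value comes from config

def calculate_crop_box(yolo_box, image_size, crop_size=IMSIZE):
    # same arithmetic as calculate_crop_box_from_yolo above, outside the class
    x1, y1, x2, y2 = yolo_box
    center_x = (x1 + x2) / 2
    center_y = (y1 + y2) / 2
    half_size = crop_size / 2
    start_x = max(int(center_x - half_size), 0)
    start_y = max(int(center_y - half_size), 0)
    end_x = min(int(start_x + crop_size), image_size[1])
    end_y = min(int(start_y + crop_size), image_size[0])
    return start_x, start_y, end_x, end_y

# box well inside a 2048x2048 image: a full 1024x1024 crop
print(calculate_crop_box((900, 900, 1100, 1100), (2048, 2048)))    # (488, 488, 1512, 1512)

# box near the top-left corner: start clamps to 0, the window keeps its full size
print(calculate_crop_box((50, 50, 150, 150), (2048, 2048)))        # (0, 0, 1024, 1024)

# box near the bottom-right corner: end clamps to the image size, the crop shrinks
print(calculate_crop_box((1950, 1950, 2040, 2040), (2048, 2048)))  # (1483, 1483, 2048, 2048)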
@@ -0,0 +1,43 @@
import pandas as pd

# paths to the first and second CSV files
csv1_path = "/data/ephemeral/home/MCG/UNetRefactored/CSV/097HR.csv"
csv2_path = "/data/ephemeral/home/MCG/UNetRefactored/CSV/CropOthers_ResNet152_Hybrid.csv"

# read the CSVs
df1 = pd.read_csv(csv1_path)
df2 = pd.read_csv(csv2_path)

# build keys for detecting overlapping rows
df1["key"] = df1["image_name"] + "_" + df1["class"]
df2["key"] = df2["image_name"] + "_" + df2["class"]

# convert the second CSV to a dictionary
df2_dict = df2.set_index("key").to_dict(orient="index")

# replace rows while preserving the order of the first CSV
updated_rows = []
for _, row in df1.iterrows():
    key = row["key"]
    if key in df2_dict:
        # take the replacement rle from df2_dict
        updated_row = {
            "image_name": row["image_name"],
            "class": row["class"],
            "rle": df2_dict[key]["rle"]
        }
        updated_rows.append(updated_row)
    else:
        updated_rows.append(row.to_dict())  # keep the existing row as a dict

# build the DataFrame
updated_df = pd.DataFrame(updated_rows)

# drop the helper key column
updated_df.drop(columns=["key"], inplace=True)

# save the merged CSV
output_path = "/data/ephemeral/home/MCG/UNetRefactored/CSV/updated_csv.csv"
updated_df.to_csv(output_path, index=False)

print(f"The merged CSV was saved to '{output_path}'.")