-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
MCG
committed
Nov 19, 2024
1 parent
658d4a9
commit 9d9a200
Showing
17 changed files
with
3,923 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
import numpy as np | ||
|
||
def get_image_label_paths(IMAGE_ROOT, LABEL_ROOT=None):
    """Collect sorted image paths and, optionally, their label paths.

    Args:
        IMAGE_ROOT: Directory walked recursively for ``.png`` files.
        LABEL_ROOT: Directory walked recursively for ``.json`` files. When
            falsy, no labels are collected.

    Returns:
        A ``(pngs, jsons)`` tuple. ``pngs`` is a sorted ``np.ndarray`` of
        image paths relative to ``IMAGE_ROOT``. ``jsons`` is the sorted
        ``np.ndarray`` of label paths relative to ``LABEL_ROOT``, or ``None``
        when ``LABEL_ROOT`` is not given.

    Raises:
        AssertionError: If images and labels do not pair up one-to-one by
            their extension-less relative path.
    """
    # Collect image file paths (relative to IMAGE_ROOT).
    pngs = {
        os.path.relpath(os.path.join(root, fname), start=IMAGE_ROOT)
        for root, _dirs, files in os.walk(IMAGE_ROOT)
        for fname in files
        if os.path.splitext(fname)[1].lower() == ".png"
    }

    # Bound up front so the no-label path returns a value instead of raising
    # UnboundLocalError at the return statement.
    jsons = None

    if LABEL_ROOT:
        # Collect label file paths (relative to LABEL_ROOT).
        jsons = {
            os.path.relpath(os.path.join(root, fname), start=LABEL_ROOT)
            for root, _dirs, files in os.walk(LABEL_ROOT)
            for fname in files
            if os.path.splitext(fname)[1].lower() == ".json"
        }

        # Keep only the extension-less prefixes.
        jsons_fn_prefix = {os.path.splitext(fname)[0] for fname in jsons}
        pngs_fn_prefix = {os.path.splitext(fname)[0] for fname in pngs}

        # Images and labels must match exactly, prefix for prefix.
        labels_without_image = jsons_fn_prefix - pngs_fn_prefix
        images_without_label = pngs_fn_prefix - jsons_fn_prefix
        assert len(labels_without_image) == 0, (
            f"labels without a matching image: {sorted(labels_without_image)}"
        )
        assert len(images_without_label) == 0, (
            f"images without a matching label: {sorted(images_without_label)}"
        )

        jsons = np.array(sorted(jsons))
    else:
        print("NO LABEL ROOT")

    # Sort in both branches so the result is deterministic and indexable.
    pngs = np.array(sorted(pngs))

    return pngs, jsons
|
||
|
||
|
||
def get_test_images(TEST_IMAGE_ROOT):
    """Collect every ``.png`` file under ``TEST_IMAGE_ROOT``.

    Args:
        TEST_IMAGE_ROOT: Directory walked recursively for test images.

    Returns:
        A set of image paths relative to ``TEST_IMAGE_ROOT``. The extension
        match is case-insensitive. The set is unordered; callers sort it.
    """
    pngs = {
        os.path.relpath(os.path.join(root, fname), start=TEST_IMAGE_ROOT)
        for root, _dirs, files in os.walk(TEST_IMAGE_ROOT)
        for fname in files
        if os.path.splitext(fname)[1].lower() == ".png"
    }
    # Without this return the function yields None and the caller's
    # sorted(...) call fails.
    return pngs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import os | ||
import cv2 | ||
import numpy as np | ||
import json | ||
import torch | ||
from config import CLASS2IND, CLASSES, IMAGE_ROOT, LABEL_ROOT,TEST_IMAGE_ROOT | ||
from DataLoder import get_test_images | ||
from Util.SetSeed import set_seed | ||
|
||
# Call the project's set_seed helper (imported from Util.SetSeed) as soon as
# this module is imported — presumably to make runs reproducible; confirm
# which RNGs it seeds in Util/SetSeed.
set_seed()
|
||
from torch.utils.data import Dataset | ||
class XRayDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs built from polygon annotations.

    Images are read from ``IMAGE_ROOT`` and labels from ``LABEL_ROOT``. Each
    label JSON holds an ``"annotations"`` list whose entries carry a class
    ``"label"`` and polygon ``"points"``; every polygon is rasterized into its
    class channel of an ``(H, W, len(CLASSES))`` mask.

    Args:
        filenames: Image paths relative to ``IMAGE_ROOT``.
        labelnames: Label paths relative to ``LABEL_ROOT``, index-aligned
            with ``filenames``.
        transforms: Optional callable invoked with keyword arguments
            (``image=...`` and, when training, ``mask=...``) and returning a
            dict with the same keys.
        is_train: When true, the mask is passed through ``transforms`` along
            with the image.
    """

    def __init__(self, filenames, labelnames, transforms=None, is_train=False):
        self.filenames = filenames
        self.labelnames = labelnames
        self.is_train = is_train
        self.transforms = transforms

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        """Return the ``item``-th sample as channel-first float tensors.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        image_name = self.filenames[item]
        image_path = os.path.join(IMAGE_ROOT, image_name)

        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of a TypeError below.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = image / 255.

        label_name = self.labelnames[item]
        label_path = os.path.join(LABEL_ROOT, label_name)

        # process a label of shape (H, W, NC)
        label_shape = tuple(image.shape[:2]) + (len(CLASSES), )
        label = np.zeros(label_shape, dtype=np.uint8)

        # read label file (JSON is UTF-8; do not depend on the locale default)
        with open(label_path, "r", encoding="utf-8") as f:
            annotations = json.load(f)
        annotations = annotations["annotations"]

        # iterate each class
        for ann in annotations:
            c = ann["label"]
            class_ind = CLASS2IND[c]
            points = np.array(ann["points"])

            # polygon to mask
            class_label = np.zeros(image.shape[:2], dtype=np.uint8)
            cv2.fillPoly(class_label, [points], 1)
            label[..., class_ind] = class_label

        if self.transforms is not None:
            # Outside training only the image is transformed; the mask keeps
            # its original resolution.
            # NOTE(review): callers presumably resize predictions back to the
            # mask size during validation — confirm.
            inputs = {"image": image, "mask": label} if self.is_train else {"image": image}
            result = self.transforms(**inputs)

            image = result["image"]
            label = result["mask"] if self.is_train else label

        # make channel first: (H, W, C) -> (C, H, W)
        image = image.transpose(2, 0, 1)
        label = label.transpose(2, 0, 1)

        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).float()

        return image, label
|
||
|
||
|
||
|
||
class XRayInferenceDataset(Dataset):
    """Dataset yielding ``(image, image_name)`` pairs for inference.

    The file list is every ``.png`` found under ``TEST_IMAGE_ROOT`` (via
    ``get_test_images``), sorted so iteration order is deterministic.

    Args:
        transforms: Optional callable invoked as ``transforms(image=...)``
            and returning a dict with an ``"image"`` entry.
    """

    def __init__(self, transforms=None):
        pngs = get_test_images(TEST_IMAGE_ROOT)
        _filenames = pngs
        _filenames = np.array(sorted(_filenames))

        self.filenames = _filenames
        self.transforms = transforms

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        """Return the ``item``-th image as a channel-first float tensor plus its name.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        image_name = self.filenames[item]
        # The names are relative to TEST_IMAGE_ROOT, so they must be resolved
        # against that root, not the training IMAGE_ROOT.
        image_path = os.path.join(TEST_IMAGE_ROOT, image_name)

        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = image / 255.

        if self.transforms is not None:
            inputs = {"image": image}
            result = self.transforms(**inputs)
            image = result["image"]

        # make channel first: (H, W, C) -> (C, H, W)
        image = image.transpose(2, 0, 1)

        image = torch.from_numpy(image).float()

        return image, image_name
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import albumentations as A | ||
from DataSet.Dataset import XRayInferenceDataset | ||
from torch.utils.data import DataLoader | ||
from config import SAVED_DIR, INFERENCE_MODEL_NAME, IMSIZE, CSVDIR,CSVNAME | ||
import os | ||
import torch | ||
from Infrence import test | ||
import pandas as pd | ||
|
||
# Inference entry script: load the saved model, run it over the test set and
# write one CSV row per (image, class) with the RLE-encoded mask.

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
model = torch.load(os.path.join(SAVED_DIR, INFERENCE_MODEL_NAME))

tf = A.Resize(IMSIZE, IMSIZE)
test_dataset = XRayInferenceDataset(transforms=tf)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=2,
    shuffle=False,  # keep rows in the dataset's sorted order
    num_workers=2,
    drop_last=False  # every test image must be predicted
)

rles, filename_and_class = test(model, test_loader)

# Each entry is "<class>_<image path>". Split on the FIRST underscore only so
# an image path that itself contains "_" is not broken into extra fields.
# Building the columns explicitly also copes with an empty result list.
pairs = [x.split("_", 1) for x in filename_and_class]
classes = [class_name for class_name, _ in pairs]
filename = [path for _, path in pairs]

image_name = [os.path.basename(f) for f in filename]

df = pd.DataFrame({
    "image_name": image_name,
    "class": classes,
    "rle": rles,
})

df.to_csv(os.path.join(CSVDIR, CSVNAME), index=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import torch | ||
import os | ||
import numpy as np | ||
from tqdm.auto import tqdm | ||
from config import CLASSES,IND2CLASS | ||
import torch.nn.functional as F | ||
|
||
def encode_mask_to_rle(mask):
    """Run-length encode a binary mask.

    Args:
        mask: numpy array binary mask (1 - mask, 0 - background).

    Returns:
        A space-separated string of ``start length`` pairs, where ``start``
        is the 1-based index into the flattened mask. An all-background mask
        gives an empty string.
    """
    flat = mask.flatten()
    # Pad with background on both ends so every run has a start and an end.
    padded = np.concatenate([[0], flat, [0]])
    # 1-based positions where the value flips: run starts and run ends
    # alternate.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each end position into a run length.
    boundaries[1::2] -= boundaries[::2]
    return ' '.join(map(str, boundaries))
|
||
|
||
def decode_rle_to_mask(rle, height, width):
    """Decode a run-length string into a ``(height, width)`` uint8 mask.

    Args:
        rle: Space-separated ``start length`` pairs with 1-based starts into
            the flattened mask.
        height: Number of rows in the decoded mask.
        width: Number of columns in the decoded mask.

    Returns:
        A ``np.uint8`` array of shape ``(height, width)`` with 1 inside the
        encoded runs and 0 elsewhere.
    """
    tokens = rle.split()
    # Even positions are starts (1-based in the string), odd ones are lengths.
    starts = np.asarray(tokens[0::2], dtype=int)
    lengths = np.asarray(tokens[1::2], dtype=int)
    starts -= 1
    ends = starts + lengths

    flat = np.zeros(height * width, dtype=np.uint8)
    for begin, stop in zip(starts, ends):
        flat[begin:stop] = 1

    return flat.reshape(height, width)
|
||
|
||
def test(model, data_loader, thr=0.5, output_size=(2048, 2048)):
    """Run inference and return RLE-encoded masks for every image and class.

    Args:
        model: Segmentation model mapping an image batch to per-class logits.
        data_loader: Loader yielding ``(images, image_names)`` batches.
        thr: Probability threshold applied after the sigmoid.
        output_size: ``(height, width)`` the logits are resized to before
            thresholding. Defaults to the previously hard-coded
            ``(2048, 2048)``.

    Returns:
        A ``(rles, filename_and_class)`` tuple of equal-length lists:
        ``rles[i]`` is the run-length string of one class mask and
        ``filename_and_class[i]`` is its ``"<class>_<image name>"`` key.
    """
    # Use the GPU when there is one, otherwise fall back to the CPU instead
    # of failing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    rles = []
    filename_and_class = []
    with torch.no_grad():
        for images, image_names in tqdm(data_loader, total=len(data_loader)):
            images = images.to(device)
            outputs = model(images)

            # restore original size
            outputs = F.interpolate(outputs, size=output_size, mode="bilinear")
            outputs = torch.sigmoid(outputs)
            outputs = (outputs > thr).detach().cpu().numpy()

            for output, image_name in zip(outputs, image_names):
                # One channel per class, in IND2CLASS index order.
                for c, segm in enumerate(output):
                    rles.append(encode_mask_to_rle(segm))
                    filename_and_class.append(f"{IND2CLASS[c]}_{image_name}")

    return rles, filename_and_class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from math import exp | ||
|
||
|
||
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel.

    Args:
        window_size: Number of taps; the kernel is centred on
            ``window_size // 2``.
        sigma: Standard deviation of the Gaussian.

    Returns:
        A 1-D tensor of length ``window_size`` whose entries sum to 1.
    """
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor(
        [exp(-(pos - center) ** 2 / denom) for pos in range(window_size)]
    )
    return weights / weights.sum()
|
||
|
||
def create_window(window_size, channel=1):
    """Build a per-channel 2-D Gaussian window for a grouped convolution.

    Args:
        window_size: Side length of the square window.
        channel: Number of channels the window is replicated for.

    Returns:
        A contiguous float tensor of shape
        ``(channel, 1, window_size, window_size)``.
    """
    kernel_1d = gaussian(window_size, 1.5).unsqueeze(1)
    # The outer product of the 1-D kernel with itself gives the 2-D kernel.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float()
    kernel_2d = kernel_2d.unsqueeze(0).unsqueeze(0)
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
|
||
|
||
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Compute the structural similarity (SSIM) between two image batches.

    Args:
        img1: First batch, a 4-D tensor ``(N, C, H, W)``.
        img2: Second batch, same shape as ``img1``.
        window_size: Side of the Gaussian window; clipped to the image size
            when the window is built here.
        window: Optional pre-built convolution window. When ``None`` one is
            created with ``create_window``.
        size_average: If true return the mean over everything, otherwise one
            value per sample.
        full: If true also return the mean contrast-structure term.
        val_range: Dynamic range of the inputs. ``None`` means images
            normalized to [0, 1].

    Returns:
        The SSIM value(s), or ``(ssim, cs)`` when ``full`` is true.
    """
    # Previously L was only assigned when val_range was None, so passing an
    # explicit val_range raised UnboundLocalError further down.
    L = 1 if val_range is None else val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    # Local means; groups=channel filters every channel independently.
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances / covariance via E[x^2] - E[x]^2.
    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    # Stabilizing constants, scaled by the dynamic range.
    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        # Average over channel, height and width, keeping the batch axis.
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret
|
||
|
||
def msssim(img1, img2, window_size=11, size_average=True, normalize=False):
    """Compute the multi-scale SSIM (MS-SSIM) between two image batches.

    SSIM and its contrast-structure term are evaluated at five scales, each
    scale halving the resolution with 2x2 average pooling.

    Args:
        img1: First batch, a 4-D tensor ``(N, C, H, W)``.
        img2: Second batch, same shape as ``img1``.
        window_size: Gaussian window side passed to ``ssim``.
        size_average: Passed to ``ssim``; if false the result keeps one
            value per sample.
        normalize: If true map the per-scale values from [-1, 1] to [0, 1]
            before exponentiation, which avoids NaNs from negative bases.

    Returns:
        The MS-SSIM value (scalar tensor, or per-sample when
        ``size_average`` is false).
    """
    device = img1.device
    weights = torch.FloatTensor([0.1448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True)
        mssim.append(sim)
        mcs.append(cs)

        # Move to the next, coarser scale.
        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)

    if normalize:
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    # MS-SSIM = prod_{j<M} cs_j^{w_j} * ssim_M^{w_M}. The last-scale SSIM
    # term is applied once, after the product; multiplying it into every
    # factor before torch.prod would raise it to the 4th power.
    cs_term = torch.prod(mcs[:-1] ** weights[:-1])
    last_scale_term = mssim[-1] ** weights[-1]
    output = cs_term * last_scale_term
    return output
|
||
|
||
class MSSSIM(torch.nn.Module):
    """Module wrapper around ``msssim`` with normalization always enabled.

    Args:
        window_size: Gaussian window side forwarded to ``msssim``.
        size_average: Forwarded to ``msssim``.
        channel: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, window_size=11, size_average=True, channel=3):
        super().__init__()
        self.window_size, self.size_average, self.channel = (
            window_size,
            size_average,
            channel,
        )

    def forward(self, img1, img2):
        """Return the normalized MS-SSIM of ``img1`` against ``img2``."""
        return msssim(
            img1,
            img2,
            window_size=self.window_size,
            size_average=self.size_average,
            normalize=True,
        )
|
||
|
||
class CombinedLoss(nn.Module):
    """Weighted sum of focal, soft-IoU and MS-SSIM losses on logits."""

    def __init__(self, alpha=0.5, beta=0.3, gamma=0.2, smooth=1e-6):
        """
        Combined Loss = alpha * Focal Loss + beta * IoU Loss + gamma * MS-SSIM Loss

        ``smooth`` is added to the numerator and denominator of the IoU ratio.
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth
        self.ms_ssim = MSSSIM(window_size=11, size_average=True, channel=1)

    def focal_loss(self, logits, targets, alpha=0.8, gamma=2.0):
        """Return the mean binary focal loss of ``logits`` against ``targets``."""
        p = torch.sigmoid(logits)
        # The 1e-6 offsets keep log() finite when p saturates at 0 or 1.
        positive = alpha * (1 - p) ** gamma * targets * torch.log(p + 1e-6)
        negative = (1 - alpha) * p ** gamma * (1 - targets) * torch.log(1 - p + 1e-6)
        return (-positive - negative).mean()

    def iou_loss(self, logits, targets):
        """Return ``1 - soft IoU`` averaged over batch and channels."""
        p = torch.sigmoid(logits)
        spatial = (2, 3)
        overlap = (p * targets).sum(dim=spatial)
        total = p.sum(dim=spatial) + targets.sum(dim=spatial) - overlap
        ratio = (overlap + self.smooth) / (total + self.smooth)
        return (1 - ratio).mean()

    def forward(self, logits, targets):
        """Return the weighted total of the three loss terms."""
        focal_term = self.focal_loss(logits, targets)
        iou_term = self.iou_loss(logits, targets)
        ssim_term = 1 - self.ms_ssim(torch.sigmoid(logits), targets)

        return self.alpha * focal_term + self.beta * iou_term + self.gamma * ssim_term
Oops, something went wrong.