Commit

feat: RefactoredUNet3+

MCG committed Nov 19, 2024
1 parent 658d4a9 commit 9d9a200
Showing 17 changed files with 3,923 additions and 0 deletions.
48 changes: 48 additions & 0 deletions UNet3+/Code/DataSet/DataLoder.py
@@ -0,0 +1,48 @@
import os
import numpy as np


def get_image_label_paths(IMAGE_ROOT, LABEL_ROOT=None):
    # collect image file paths relative to IMAGE_ROOT
    pngs = {
        os.path.relpath(os.path.join(root, fname), start=IMAGE_ROOT)
        for root, _dirs, files in os.walk(IMAGE_ROOT)
        for fname in files
        if os.path.splitext(fname)[1].lower() == ".png"
    }

    if LABEL_ROOT:
        # collect label file paths relative to LABEL_ROOT
        jsons = {
            os.path.relpath(os.path.join(root, fname), start=LABEL_ROOT)
            for root, _dirs, files in os.walk(LABEL_ROOT)
            for fname in files
            if os.path.splitext(fname)[1].lower() == ".json"
        }

        # extract filename prefixes (paths without extensions)
        jsons_fn_prefix = {os.path.splitext(fname)[0] for fname in jsons}
        pngs_fn_prefix = {os.path.splitext(fname)[0] for fname in pngs}

        # every image must have a matching label and vice versa
        assert len(jsons_fn_prefix - pngs_fn_prefix) == 0
        assert len(pngs_fn_prefix - jsons_fn_prefix) == 0

        # sort so images and labels pair up by index
        pngs = np.array(sorted(pngs))
        jsons = np.array(sorted(jsons))
    else:
        print("NO LABEL ROOT")
        # no labels available; return sorted images and None
        pngs = np.array(sorted(pngs))
        jsons = None

    return pngs, jsons


def get_test_images(TEST_IMAGE_ROOT):
    pngs = {
        os.path.relpath(os.path.join(root, fname), start=TEST_IMAGE_ROOT)
        for root, _dirs, files in os.walk(TEST_IMAGE_ROOT)
        for fname in files
        if os.path.splitext(fname)[1].lower() == ".png"
    }
    return pngs
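A minimal usage sketch for these helpers; the data roots below are hypothetical placeholders, not paths from this commit:

# Hypothetical roots for illustration only.
pngs, jsons = get_image_label_paths("data/train/DCM", "data/train/outputs_json")
# The prefix asserts above guarantee a one-to-one pairing, so sorted order aligns them.
assert len(pngs) == len(jsons)
print(pngs[0], jsons[0])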
101 changes: 101 additions & 0 deletions UNet3+/Code/DataSet/Dataset.py
@@ -0,0 +1,101 @@
import os
import cv2
import numpy as np
import json
import torch
from torch.utils.data import Dataset

from config import CLASS2IND, CLASSES, IMAGE_ROOT, LABEL_ROOT, TEST_IMAGE_ROOT
from DataLoder import get_test_images
from Util.SetSeed import set_seed

set_seed()


class XRayDataset(Dataset):
    def __init__(self, filenames, labelnames, transforms=None, is_train=False):
        self.filenames = filenames
        self.labelnames = labelnames
        self.is_train = is_train
        self.transforms = transforms

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        image_name = self.filenames[item]
        image_path = os.path.join(IMAGE_ROOT, image_name)

        image = cv2.imread(image_path)
        image = image / 255.

        label_name = self.labelnames[item]
        label_path = os.path.join(LABEL_ROOT, label_name)

        # build a label of shape (H, W, NC), one binary channel per class
        label_shape = tuple(image.shape[:2]) + (len(CLASSES), )
        label = np.zeros(label_shape, dtype=np.uint8)

        # read label file
        with open(label_path, "r") as f:
            annotations = json.load(f)
        annotations = annotations["annotations"]

        # rasterize each annotated polygon into its class channel
        for ann in annotations:
            c = ann["label"]
            class_ind = CLASS2IND[c]
            points = np.array(ann["points"])

            # polygon to mask
            class_label = np.zeros(image.shape[:2], dtype=np.uint8)
            cv2.fillPoly(class_label, [points], 1)
            label[..., class_ind] = class_label

        if self.transforms is not None:
            inputs = {"image": image, "mask": label} if self.is_train else {"image": image}
            result = self.transforms(**inputs)

            image = result["image"]
            label = result["mask"] if self.is_train else label

        # make channel-first and convert to tensor
        image = image.transpose(2, 0, 1)
        label = label.transpose(2, 0, 1)

        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).float()

        return image, label


class XRayInferenceDataset(Dataset):
    def __init__(self, transforms=None):
        pngs = get_test_images(TEST_IMAGE_ROOT)
        self.filenames = np.array(sorted(pngs))
        self.transforms = transforms

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, item):
        image_name = self.filenames[item]
        # test images live under TEST_IMAGE_ROOT, not IMAGE_ROOT
        image_path = os.path.join(TEST_IMAGE_ROOT, image_name)

        image = cv2.imread(image_path)
        image = image / 255.

        if self.transforms is not None:
            inputs = {"image": image}
            result = self.transforms(**inputs)
            image = result["image"]

        # make channel-first and convert to tensor
        image = image.transpose(2, 0, 1)
        image = torch.from_numpy(image).float()

        return image, image_name
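For context, a hedged sketch of wiring XRayDataset to a DataLoader; the resize resolution and batch size here are illustrative assumptions, not values from this commit:

import albumentations as A
from torch.utils.data import DataLoader
from config import IMAGE_ROOT, LABEL_ROOT
from DataLoder import get_image_label_paths

pngs, jsons = get_image_label_paths(IMAGE_ROOT, LABEL_ROOT)
tf = A.Resize(512, 512)  # illustrative size
train_dataset = XRayDataset(pngs, jsons, transforms=tf, is_train=True)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
images, labels = next(iter(train_loader))  # images: (B, 3, 512, 512), labels: (B, NC, 512, 512)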
35 changes: 35 additions & 0 deletions UNet3+/Code/InfetenceRun.py
@@ -0,0 +1,35 @@
import os

import albumentations as A
import pandas as pd
import torch
from torch.utils.data import DataLoader

from config import SAVED_DIR, INFERENCE_MODEL_NAME, IMSIZE, CSVDIR, CSVNAME
from DataSet.Dataset import XRayInferenceDataset
from Infrence import test

model = torch.load(os.path.join(SAVED_DIR, INFERENCE_MODEL_NAME))

tf = A.Resize(IMSIZE, IMSIZE)
test_dataset = XRayInferenceDataset(transforms=tf)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    drop_last=False,
)

rles, filename_and_class = test(model, test_loader)

# split only on the first underscore so underscores inside filenames survive
classes, filename = zip(*[x.split("_", 1) for x in filename_and_class])

image_name = [os.path.basename(f) for f in filename]

df = pd.DataFrame({
    "image_name": image_name,
    "class": classes,
    "rle": rles,
})

df.to_csv(os.path.join(CSVDIR, CSVNAME), index=False)
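To sanity-check a submission file, the RLEs can be decoded back into masks with decode_rle_to_mask from Infrence.py; a sketch, assuming the 2048x2048 size used by the interpolation in test():

import os
import pandas as pd
from config import CSVDIR, CSVNAME
from Infrence import decode_rle_to_mask

df = pd.read_csv(os.path.join(CSVDIR, CSVNAME))
row = df.iloc[0]
if isinstance(row["rle"], str):  # an all-background mask encodes to an empty string and reads back as NaN
    mask = decode_rle_to_mask(row["rle"], 2048, 2048)
    print(row["image_name"], row["class"], mask.sum())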
59 changes: 59 additions & 0 deletions UNet3+/Code/Infrence.py
@@ -0,0 +1,59 @@
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from config import CLASSES, IND2CLASS


def encode_mask_to_rle(mask):
    '''
    mask: binary numpy array (1 - mask, 0 - background)
    Returns the run-length encoding as a space-separated string.
    '''
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def decode_rle_to_mask(rle, height, width):
    s = rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(height * width, dtype=np.uint8)

    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1

    return img.reshape(height, width)


def test(model, data_loader, thr=0.5):
    model = model.cuda()
    model.eval()

    rles = []
    filename_and_class = []
    with torch.no_grad():
        n_class = len(CLASSES)

        for step, (images, image_names) in tqdm(enumerate(data_loader), total=len(data_loader)):
            images = images.cuda()
            outputs = model(images)

            # restore original size
            outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear")
            outputs = torch.sigmoid(outputs)
            outputs = (outputs > thr).detach().cpu().numpy()

            for output, image_name in zip(outputs, image_names):
                for c, segm in enumerate(output):
                    rle = encode_mask_to_rle(segm)
                    rles.append(rle)
                    filename_and_class.append(f"{IND2CLASS[c]}_{image_name}")

    return rles, filename_and_class
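A quick round-trip check of the two RLE helpers on a toy mask (illustrative only):

import numpy as np

toy = np.zeros((4, 4), dtype=np.uint8)
toy[1:3, 1:3] = 1                      # a 2x2 square of foreground
rle = encode_mask_to_rle(toy)          # "6 2 10 2" for this toy mask
restored = decode_rle_to_mask(rle, 4, 4)
assert (restored == toy).all()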
128 changes: 128 additions & 0 deletions UNet3+/Code/Loss/Loss.py
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    # dynamic range of the inputs; default assumes images normalized to [0, 1]
    L = 1 if val_range is None else val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


def msssim(img1, img2, window_size=11, size_average=True, normalize=False):
    device = img1.device
    weights = torch.FloatTensor([0.1448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True)
        mssim.append(sim)
        mcs.append(cs)

        # downsample by 2 for the next scale
        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)

    if normalize:
        # map values from [-1, 1] to [0, 1] to avoid negative bases under fractional powers
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = mssim ** weights
    output = torch.prod(pow1[:-1] * pow2[-1])
    return output


class MSSSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel

    def forward(self, img1, img2):
        return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=True)


class CombinedLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.3, gamma=0.2, smooth=1e-6):
        """
        Combined Loss = alpha * Focal Loss + beta * IoU Loss + gamma * MS-SSIM Loss
        """
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth
        self.ms_ssim = MSSSIM(window_size=11, size_average=True, channel=1)

    def focal_loss(self, logits, targets, alpha=0.8, gamma=2.0):
        probs = torch.sigmoid(logits)
        focal_loss = -alpha * (1 - probs) ** gamma * targets * torch.log(probs + 1e-6) \
                     - (1 - alpha) * probs ** gamma * (1 - targets) * torch.log(1 - probs + 1e-6)
        return focal_loss.mean()

    def iou_loss(self, logits, targets):
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum(dim=(2, 3))
        union = probs.sum(dim=(2, 3)) + targets.sum(dim=(2, 3)) - intersection
        iou_loss = 1 - (intersection + self.smooth) / (union + self.smooth)
        return iou_loss.mean()

    def forward(self, logits, targets):
        focal = self.focal_loss(logits, targets)
        iou = self.iou_loss(logits, targets)
        ms_ssim_loss = 1 - self.ms_ssim(torch.sigmoid(logits), targets)

        total_loss = self.alpha * focal + self.beta * iou + self.gamma * ms_ssim_loss
        return total_loss
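A smoke test of CombinedLoss on random tensors; shapes are illustrative, with 29 standing in for len(CLASSES):

import torch

logits = torch.randn(2, 29, 256, 256)                     # (B, NC, H, W)
targets = torch.randint(0, 2, (2, 29, 256, 256)).float()  # binary multi-label masks

criterion = CombinedLoss(alpha=0.5, beta=0.3, gamma=0.2)
loss = criterion(logits, targets)
print(loss.item())  # weighted sum of focal, IoU, and MS-SSIM terms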