Skip to content

Commit

Permalink
fix: Loss Bug UNet3+ #25
Browse files Browse the repository at this point in the history
  • Loading branch information
MCG committed Nov 23, 2024
1 parent f9af7e6 commit aa320b1
Show file tree
Hide file tree
Showing 8 changed files with 722 additions and 98 deletions.
12 changes: 7 additions & 5 deletions UNet3+/Code/CropTrainRun.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import torch.optim as optim
from torch.utils.data import DataLoader

from Model.FixedModel import UNet_3Plus_DeepSup
from Model.resnetModel import UNet_3Plus_DeepSup
from DataSet.DataLoder import get_image_label_paths
from config import IMAGE_ROOT, LABEL_ROOT, BATCH_SIZE, IMSIZE, CLASSES, MILESTONES, GAMMA, LR, SAVED_DIR, VISUALIZE_TRAIN_DATA, SAVE_VISUALIZE_TRAIN_DATA_PATH,NUM_EPOCHS
from DataSet.LabelBaseCropDataset import XRayDataset
from Loss.Loss import CombinedLoss
from Train import train
from Util.SetSeed import set_seed

from sklearn.utils import shuffle

def main():
set_seed()
Expand All @@ -31,7 +31,7 @@ def main():
# 폴더 이름을 그룹으로 해서 GroupKFold를 수행합니다.
# 동일 인물의 손이 train, valid에 따로 들어가는 것을 방지합니다.
groups = [os.path.dirname(fname) for fname in pngs]

groups = shuffle(groups, random_state=21)
# dummy label
ys = [0 for fname in pngs]

Expand All @@ -51,6 +51,8 @@ def main():
else:
train_filenames += list(pngs[y])
train_labelnames += list(jsons[y])



# tf = A.Resize(IMSIZE,IMSIZE)
train_dataset = XRayDataset(
Expand Down Expand Up @@ -90,8 +92,8 @@ def main():
criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=0)

# Optimizer 정의
optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)
optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=8e-7)

train(model, train_loader, valid_loader, criterion, optimizer, scheduler)

Expand Down
123 changes: 89 additions & 34 deletions UNet3+/Code/Loss/Loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from math import exp
import numpy as np
import torchvision
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from config import CLASSES

'''
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
gauss = torch.tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()

def create_window(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
Expand Down Expand Up @@ -82,42 +84,75 @@ def msssim(img1, img2, window_size=11, size_average=True, normalize=False):
pow2 = mssim ** weights
output = torch.prod(pow1[:-1] * pow2[-1])
return output
'''
class MSSSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, channel=3):
super(MSSSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = channel
self.window = create_window(window_size, channel)
def forward(self, img1, img2):
# Ensure the images are normalized
img1 = img1 / img1.max() if img1.max() > 1 else img1
img2 = img2 / img2.max() if img2.max() > 1 else img2
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=True)
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=False)
'''
class MSSSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, channel=3):
super(MSSSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = channel

def forward(self, img1, img2):
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=True)

class MS_SSIM_Loss(nn.Module):
def __init__(self, data_range=1.0, size_average=True, win_size=11, win_sigma=1.5, weights=None):
"""
MS-SSIM Loss for PyTorch models.
Args:
data_range (float): Value range of input images. Default is 1.0 (normalized images).
size_average (bool): If True, average the MS-SSIM values over all samples.
win_size (int): Gaussian window size. Default is 11.
win_sigma (float): Standard deviation of the Gaussian window. Default is 1.5.
weights (list): Weights for different MS-SSIM levels. Default is None (uses preset weights).
"""
super(MS_SSIM_Loss, self).__init__()
self.ms_ssim = MS_SSIM(
data_range=data_range,
size_average=size_average,
win_size=win_size,
win_sigma=win_sigma,
weights=weights,
channel=len(CLASSES)
)

def forward(self, logits, targets):
"""
Forward pass for the loss calculation.
Args:
logits (Tensor): Model outputs, typically raw scores (B, C, H, W).
targets (Tensor): Ground truth images (B, C, H, W) normalized to [0, 1].
Returns:
Tensor: MS-SSIM loss value.
"""
# Convert logits to probabilities using Sigmoid (for binary/multi-label tasks) or Softmax (multi-class tasks)
probs = torch.sigmoid(logits) # Use softmax if multi-class: torch.softmax(logits, dim=1)

# Ensure targets are of the same dtype as probs
targets = targets.type_as(probs)

# Calculate MS-SSIM (higher values indicate better similarity)
ms_ssim_val = self.ms_ssim(probs, targets)

# Return 1 - MS-SSIM as the loss (lower MS-SSIM indicates higher loss)
return 1 - ms_ssim_val


class CombinedLoss(nn.Module):
def __init__(self, focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=1, smooth=1e-6, channel=3):
def __init__(self, focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=1, gdl_weight=0, smooth=1e-6, channel=3):
super(CombinedLoss, self).__init__()
self.focal_weight = focal_weight
self.iou_weight = iou_weight
self.ms_ssim_weight = ms_ssim_weight
self.dice_weight = dice_weight
self.gdl_weight = gdl_weight
self.smooth = smooth
self.ms_ssim = MSSSIM(window_size=11, size_average=True, channel=channel)
self.ms_ssim = MS_SSIM_Loss()
self.bce_loss_fn = nn.BCEWithLogitsLoss(reduction='mean') # BCE loss with logits

def adaptive_focal_loss(self, logits, targets, alpha=1, gamma_min=1.5, gamma_max=4.0, reduce=True):
# Compute BCE loss
BCE_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')#self.bce_loss_fn(logits, targets)
Expand All @@ -138,16 +173,15 @@ def adaptive_focal_loss(self, logits, targets, alpha=1, gamma_min=1.5, gamma_max
else:
return F_loss


def focal_loss(self, logits, targets, alpha=1, gamma=1.8, reduce=True):
def focal_loss(self, logits, targets, alpha=1, gamma=1.5, reduce=True):
BCE_loss= F.binary_cross_entropy_with_logits(logits, targets, reduction='none')#self.bce_loss_fn(logits, targets)
#print("BCE:",BCE_loss)
pt = torch.exp(-BCE_loss)
F_loss = alpha * (1-pt)**gamma * BCE_loss
if reduce:
return torch.mean(F_loss)
else:
return F_loss
return F_loss.sum() / logits.size(0)

def iou_loss(self, logits, targets):
probs = torch.sigmoid(logits)
Expand All @@ -163,21 +197,42 @@ def dice_loss(self, logits, targets):
sum_targets = targets.sum(dim=(2, 3))
dice = (2 * intersection + self.smooth) / (sum_probs + sum_targets + self.smooth)
return 1 - dice.mean()
def bce_loss(self, logits, targets):
# Use BCEWithLogitsLoss for numerical stability
return self.bce_loss_fn(logits, targets)
def forward(self, logits, targets):
focal = self.focal_loss(logits, targets)*self.focal_weight
ms_ssim_loss = 1 - self.ms_ssim(torch.sigmoid(logits), targets) * self.ms_ssim_weight
dice = self.dice_loss(logits, targets) * self.dice_weight
iou= self.iou_loss(logits,targets) * self.iou_weight
#bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

def gdl_loss(self, logits, targets):
probs = torch.sigmoid(logits)
d_probs_dx = torch.abs(probs[:, :, :, :-1] - probs[:, :, :, 1:])
d_targets_dx = torch.abs(targets[:, :, :, :-1] - targets[:, :, :, 1:])
d_probs_dy = torch.abs(probs[:, :, :-1, :] - probs[:, :, 1:, :])
d_targets_dy = torch.abs(targets[:, :, :-1, :] - targets[:, :, 1:, :])
gdl_x = torch.abs(d_probs_dx - d_targets_dx).mean()
gdl_y = torch.abs(d_probs_dy - d_targets_dy).mean()
return gdl_x + gdl_y

def forward(self, logits, targets):
focal = self.focal_loss(logits, targets) * self.focal_weight
dice = self.dice_loss(logits, targets) * self.dice_weight
iou = self.iou_loss(logits, targets) * self.iou_weight
#bce=self.bce_loss_fn(logits, targets)
#gdl = self.gdl_loss(logits, targets) * self.gdl_weight
ms_ssim = self.ms_ssim(logits, targets) * self.ms_ssim_weight
# Combined loss
total_loss = focal + ms_ssim_loss +iou + dice
total_loss = focal + dice + iou + ms_ssim #gdl
return total_loss, focal, iou, dice, ms_ssim #gdl

# 개별 손실 값 로깅을 위해 반환
return total_loss, focal, ms_ssim_loss,iou, dice,

'''
class MSSSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, channel=3):
super(MSSSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = channel
self.window = create_window(window_size, channel)
def forward(self, img1, img2):
# Ensure the images are normalized
img1 = img1 / img1.max() if img1.max() > 1 else img1
img2 = img2 / img2.max() if img2.max() > 1 else img2
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=True)
'''
Loading

0 comments on commit aa320b1

Please sign in to comment.