fix: focal_Loss UNet3+ #25
MCG committed Nov 22, 2024
1 parent da95f38 commit ec8a4c1
Showing 10 changed files with 380 additions and 88 deletions.
16 changes: 8 additions & 8 deletions UNet3+/Code/CropTrainRun.py
@@ -10,7 +10,7 @@

from Model.FixedModel import UNet_3Plus_DeepSup
from DataSet.DataLoder import get_image_label_paths
-from config import IMAGE_ROOT, LABEL_ROOT, BATCH_SIZE, IMSIZE, CLASSES, MILESTONES, GAMMA, LR, SAVED_DIR, VISUALIZE_TRAIN_DATA, SAVE_VISUALIZE_TRAIN_DATA_PATH
+from config import IMAGE_ROOT, LABEL_ROOT, BATCH_SIZE, IMSIZE, CLASSES, MILESTONES, GAMMA, LR, SAVED_DIR, VISUALIZE_TRAIN_DATA, SAVE_VISUALIZE_TRAIN_DATA_PATH, NUM_EPOCHS
from DataSet.LabelBaseCropDataset import XRayDataset
from Loss.Loss import CombinedLoss
from Train import train
@@ -57,15 +57,15 @@ def main():
train_filenames,
train_labelnames,
is_train=True,
-save_dir=SAVE_VISUALIZE_TRAIN_DATA_PATH,
-draw_enabled=VISUALIZE_TRAIN_DATA,
+save_dir=None,
+draw_enabled=False,
)
valid_dataset = XRayDataset(
valid_filenames,
valid_labelnames,
is_train=False,
-save_dir=None,
-draw_enabled=False,
+save_dir=SAVE_VISUALIZE_TRAIN_DATA_PATH,
+draw_enabled=VISUALIZE_TRAIN_DATA,
)

train_loader = DataLoader(
@@ -87,11 +87,11 @@
model = UNet_3Plus_DeepSup(n_classes=len(CLASSES))

# Define the loss function
-criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=1)
+criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=0)

# Define the optimizer
-optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-6)
-scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=GAMMA)
+optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=1e-4)
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

train(model, train_loader, valid_loader, criterion, optimizer, scheduler)

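Note: the change above swaps step decay (MultiStepLR) for a cosine schedule spanning the whole run. A minimal sketch of the resulting learning-rate curve, using the values this commit sets in config.py (LR = 0.0008, NUM_EPOCHS = 75) and a stand-in parameter instead of the real model:

    import torch
    import torch.optim as optim

    # Stand-in parameter; in CropTrainRun.py this is model.parameters()
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.AdamW(params, lr=0.0008, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=75, eta_min=1e-6)

    for epoch in range(75):
        # ... one epoch of training would run here ...
        scheduler.step()
        if (epoch + 1) % 15 == 0:
            # LR decays smoothly from 8e-4 toward 1e-6, with no milestone jumps
            print(epoch + 1, optimizer.param_groups[0]["lr"])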
1 change: 0 additions & 1 deletion UNet3+/Code/InfetenceRun.py
@@ -89,7 +89,6 @@ def test(model, data_loader, thr=0.5):

# restore original size
outputs = F.interpolate(outputs, size=(2048, 2048), mode="bilinear")
-outputs = torch.sigmoid(outputs)
outputs = (outputs > thr).detach().cpu().numpy()

for output, image_name in zip(outputs, image_names):
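Note: with the sigmoid call removed, thr=0.5 is now compared against the model's raw outputs. If those outputs are logits rather than probabilities, the decision boundary shifts, since thresholding probabilities at 0.5 is equivalent to thresholding logits at 0. A quick illustration:

    import torch

    logits = torch.tensor([0.2, 0.7, -0.3])
    print(torch.sigmoid(logits) > 0.5)  # tensor([ True,  True, False]): boundary at logit 0.0
    print(logits > 0.5)                 # tensor([False,  True, False]): stricter boundary

If the model already applies a sigmoid internally the two are equivalent; the diff alone does not show which is the case here.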
83 changes: 53 additions & 30 deletions UNet3+/Code/Loss/Loss.py
@@ -2,7 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F
from math import exp

+import numpy as np
+import torchvision

def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
@@ -107,27 +108,46 @@ def __init__(self, window_size=11, size_average=True, channel=3):
def forward(self, img1, img2):
return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average, normalize=True)




class CombinedLoss(nn.Module):
def __init__(self, focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=1, smooth=1e-6, channel=3):
"""
Combined Loss = alpha * Focal Loss + beta * IoU Loss + gamma * MS-SSIM Loss + delta * Dice Loss
"""
super(CombinedLoss, self).__init__()
-self.alpha = focal_weight  # Weight for Focal Loss
-self.beta = iou_weight  # Weight for IoU Loss
-self.gamma = ms_ssim_weight  # Weight for MS-SSIM Loss
-self.delta = dice_weight  # Weight for Dice Loss
+self.focal_weight = focal_weight
+self.iou_weight = iou_weight
+self.ms_ssim_weight = ms_ssim_weight
+self.dice_weight = dice_weight
self.smooth = smooth
-self.ms_ssim = MSSSIM(window_size=7, size_average=True, channel=channel)
+self.ms_ssim = MSSSIM(window_size=11, size_average=True, channel=channel)
+self.bce_loss_fn = nn.BCEWithLogitsLoss(reduction='mean')  # BCE loss with logits
+def adaptive_focal_loss(self, logits, targets, alpha=1, gamma_min=1.5, gamma_max=4.0, reduce=True):
+# Compute BCE loss
+BCE_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # self.bce_loss_fn(logits, targets)

-def focal_loss(self, logits, targets, alpha=0.8, gamma=2):
-probs = torch.sigmoid(logits)
-focal_loss = -alpha * (1 - probs) ** gamma * targets * torch.log(probs + 1e-6) \
-    - (1 - alpha) * probs ** gamma * (1 - targets) * torch.log(1 - probs + 1e-6)
-return focal_loss.mean()
+# Compute pt (predicted probability for the true class)
+pt = torch.exp(-BCE_loss)

+# Dynamically adjust gamma based on pt
+gamma = gamma_min + (1 - pt) * (gamma_max - gamma_min)
+gamma = torch.clamp(gamma, gamma_min, gamma_max)  # keep gamma within [gamma_min, gamma_max]

+# Compute the focal loss
+F_loss = alpha * (1 - pt) ** gamma * BCE_loss

+# Reduce the loss if required
+if reduce:
+return torch.mean(F_loss)
+else:
+return F_loss


+def focal_loss(self, logits, targets, alpha=1, gamma=1.8, reduce=True):
+BCE_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # self.bce_loss_fn(logits, targets)
+#print("BCE:", BCE_loss)
+pt = torch.exp(-BCE_loss)
+F_loss = alpha * (1 - pt) ** gamma * BCE_loss
+if reduce:
+return torch.mean(F_loss)
+else:
+return F_loss

def iou_loss(self, logits, targets):
probs = torch.sigmoid(logits)
@@ -137,24 +157,27 @@ def iou_loss(self, logits, targets):
return iou_loss.mean()

def dice_loss(self, logits, targets):
"""
Dice Loss = 1 - (2 * intersection + smooth) / (sum_probs + sum_targets + smooth)
"""
probs = torch.sigmoid(logits)
intersection = (probs * targets).sum(dim=(2, 3))
sum_probs = probs.sum(dim=(2, 3))
sum_targets = targets.sum(dim=(2, 3))
dice = (2 * intersection + self.smooth) / (sum_probs + sum_targets + self.smooth)
return 1 - dice.mean()

+def bce_loss(self, logits, targets):
+# Use BCEWithLogitsLoss for numerical stability
+return self.bce_loss_fn(logits, targets)
def forward(self, logits, targets):
# Calculate individual losses
-focal = self.focal_loss(logits, targets)
-#iou = self.iou_loss(logits, targets)
-ms_ssim_loss = 1 - self.ms_ssim(torch.sigmoid(logits), targets)
-dice = self.dice_loss(logits, targets)

-# Combine losses with respective weights
-total_loss = self.alpha * focal + self.gamma * ms_ssim_loss + self.delta * dice  # self.beta * iou  #+ self.delta * dice
-return total_loss
+focal = self.focal_loss(logits, targets) * self.focal_weight
+ms_ssim_loss = (1 - self.ms_ssim(torch.sigmoid(logits), targets)) * self.ms_ssim_weight
+dice = self.dice_loss(logits, targets) * self.dice_weight
+iou = self.iou_loss(logits, targets) * self.iou_weight
+#bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

+# Combined loss
+total_loss = focal + ms_ssim_loss + iou + dice

+# Return the individual loss values for logging
+return total_loss, focal, ms_ssim_loss, iou, dice
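Note: two things stand out in the rewritten loss. forward() now returns a five-tuple so callers can log each term, and the new adaptive_focal_loss (dynamic gamma in [gamma_min, gamma_max]) is defined but forward() still calls the fixed-gamma focal_loss. A minimal usage sketch, with the CombinedLoss above in scope; the shapes and channel count are illustrative, and inputs must be large enough for the multi-scale SSIM pyramid:

    import torch

    C = 4  # assumed channel count; the project would use len(CLASSES)
    criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1,
                             dice_weight=0, channel=C)
    logits = torch.randn(2, C, 256, 256, requires_grad=True)
    targets = (torch.rand(2, C, 256, 256) > 0.5).float()

    # forward() returns the weighted total plus each component for logging
    total, focal, ms_ssim_l, iou, dice = criterion(logits, targets)
    total.backward()
    print(f"total={total.item():.4f}, focal={focal.item():.4f}, dice={dice.item():.4f}")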


9 changes: 4 additions & 5 deletions UNet3+/Code/Model/FixedModel.py
@@ -42,10 +42,8 @@ def __init__(self, in_channels=3, n_classes=1, feature_scale=4, is_deconv=True,
self.conv3 = self.convnext[3:5] # ConvNeXt Stage 2 (Output: 28x28, 384 channels)
self.conv4 = self.convnext[5:7]
self.conv5 = nn.Sequential(
-nn.Conv2d(filters[4], filters[4], kernel_size=3, stride=2, padding=1),  # DownSample
-nn.BatchNorm2d(filters[4]),
-nn.GELU(),  # GELU activation function
-self.convnext[7:])  # ConvNeXt Stage 4 (Output: 7x7, 1536 channels)
+nn.MaxPool2d(kernel_size=2, stride=2),  # DownSample using MaxPool
+self.convnext[7:])


## -------------Decoder--------------
@@ -401,8 +399,9 @@ def forward(self, inputs):
d5 = self.dotProduct(d5, cls_branch_mask)
'''

if self.training:
-return d1, d2, d3, d4, d5
+return torch.cat((d1, d2, d3, d4, d5), dim=0)
else:
+#print(d1)
return d1
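Note: returning torch.cat((d1, ..., d5), dim=0) stacks the five deep-supervision maps along the batch dimension; Train.py (below) mirrors this with masks.repeat(5, 1, 1, 1), so one criterion call covers all five outputs with equal weight, replacing the old per-output weight list. A shape check with illustrative sizes:

    import torch

    B, C, H, W = 2, 4, 32, 32
    d1 = d2 = d3 = d4 = d5 = torch.randn(B, C, H, W)  # stand-ins for the decoder outputs
    outputs = torch.cat((d1, d2, d3, d4, d5), dim=0)  # shape (5*B, C, H, W)

    masks = torch.randint(0, 2, (B, C, H, W)).float()
    batch_masks = masks.repeat(5, 1, 1, 1)            # also (5*B, C, H, W)
    assert outputs.shape == batch_masks.shape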
8 changes: 6 additions & 2 deletions UNet3+/Code/Model/model_shape_check.py
@@ -1,7 +1,11 @@
from torchvision.models import convnext_large
+import torch

# Load the ConvNeXt Large model
-model = convnext_large(pretrained=True)
+#model = convnext_large(pretrained=True)

# Print the model structure
-print(model)
+#print(model)

+ce_loss = torch.log(torch.tensor(1e-6))
+print(ce_loss)
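Note: torch.log(torch.tensor(1e-6)) prints roughly -13.8155, the largest-magnitude per-pixel value the old focal loss could produce with its log(x + 1e-6) clamping; the script appears to have been repurposed as a quick sanity check of that bound.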
36 changes: 21 additions & 15 deletions UNet3+/Code/Train.py
@@ -13,20 +13,18 @@ def save_model(model, file_name=MODELNAME):
output_path = os.path.join(SAVED_DIR, file_name)
torch.save(model, output_path)

-def train(model, data_loader, val_loader, criterion, optimizer, scheduler, accumulation_steps=ACCUMULATION_STEPS):
+def train(model, data_loader, val_loader, criterion, optimizer, scheduler, accumulation_steps=ACCUMULATION_STEPS, threshold=0.92):
"""
Args:
accumulation_steps (int): Number of steps to accumulate gradients before updating.
+threshold (float): validation Dice score at which the loss weighting is switched.
"""
print(f'Start training with Gradient Accumulation (accumulation_steps={accumulation_steps})..')
model.cuda()

n_class = len(CLASSES)
best_dice = 0.0

-# Loss weights (deep supervision)
-deep_sup_weights = [0.5, 0.3, 0.2, 0.15, 0.1]  # weight for each output

# Create the mixed-precision scaler
scaler = GradScaler()

@@ -47,18 +45,12 @@ def train(model, data_loader, val_loader, criterion, optimizer, scheduler, accum
# Run inference with mixed precision
with autocast():  # mixed-precision mode
outputs = model(images)
+batch_masks = masks.repeat(5, 1, 1, 1)

-# Deep supervision: assumes multiple outputs
-if isinstance(outputs, (tuple, list)):  # outputs given as a list/tuple
-total_loss = 0.0
-for i, output in enumerate(outputs):
-loss = criterion(output, masks)  # loss for each output
-total_loss += loss * deep_sup_weights[i]  # weighted sum
-else:  # single-tensor output (fallback)
-total_loss = criterion(outputs, masks)
+loss, focal, ms_ssim_loss, iou, dice = criterion(outputs, batch_masks)  # losses over all deep-supervision outputs

# Loss scaling and backpropagation (gradient accumulation)
-scaler.scale(total_loss).backward()
+scaler.scale(loss).backward()

# Update weights every accumulation_steps
if (step + 1) % accumulation_steps == 0:
@@ -72,7 +64,11 @@ def train(model, data_loader, val_loader, criterion, optimizer, scheduler, accum
f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} | '
f'Epoch [{epoch+1}/{NUM_EPOCHS}], '
f'Step [{step+1}/{len(data_loader)}], '
-f'Loss: {round(total_loss.item(), 4)}'
+f'Loss: {round(loss.item(), 4)} | '
+f'Focal: {round(focal.item(), 4)}, '
+f'MS-SSIM: {round(ms_ssim_loss.item(), 4)}, '
+f'IoU: {round(iou.item(), 4)}, '
+f'Dice: {round(dice.item(), 4)}'
)

# Gradient update after the last mini-batch
@@ -83,7 +79,17 @@ def train(model, data_loader, val_loader, criterion, optimizer, scheduler, accum

# Run validation every VAL_EVERY epochs and save the best model
if (epoch + 1) % VAL_EVERY == 0:
-dice = validation(epoch + 1, model, val_loader, criterion)
+dice = validation(epoch + 1, model, val_loader)

+# Select the loss weighting based on the validation result
+if dice < threshold:
+print(f"Validation Dice ({dice:.4f}) < Threshold ({threshold}), using IoU Loss.")
+criterion.dice_weight = 0  # disable Dice loss
+criterion.iou_weight = 1  # enable IoU loss
+else:
+print(f"Validation Dice ({dice:.4f}) >= Threshold ({threshold}), using Dice Loss.")
+criterion.dice_weight = 1  # enable Dice loss
+criterion.iou_weight = 0  # disable IoU loss

if best_dice < dice:
print(f"Best performance at epoch: {epoch + 1}, {best_dice:.4f} -> {dice:.4f}")
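Note: the training loop above combines mixed precision with gradient accumulation. Condensed to its core pattern, with toy stand-ins for the model and data (CUDA required, as in the original; this commit accumulates the raw losses, whereas dividing each loss by accumulation_steps is a common variant):

    import torch
    from torch import nn
    from torch.cuda.amp import autocast, GradScaler

    # Illustrative stand-ins: a 1-layer "model" and random data on the GPU
    model = nn.Conv2d(3, 1, 3, padding=1).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = GradScaler()
    accumulation_steps = 4  # config.py uses ACCUMULATION_STEPS = 16

    optimizer.zero_grad()
    for step in range(8):  # stand-in for iterating the DataLoader
        images = torch.randn(1, 3, 32, 32, device="cuda")
        masks = torch.rand(1, 1, 32, 32, device="cuda")
        with autocast():
            outputs = model(images)
            loss = nn.functional.binary_cross_entropy_with_logits(outputs, masks)
        scaler.scale(loss).backward()           # gradients accumulate across steps
        if (step + 1) % accumulation_steps == 0:
            scaler.step(optimizer)              # unscales gradients, then steps
            scaler.update()
            optimizer.zero_grad()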
2 changes: 1 addition & 1 deletion UNet3+/Code/TrainRun.py
@@ -81,7 +81,7 @@
criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=1)

# Define the optimizer
-optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-6)
+optimizer = optim.Adam(params=model.parameters(), lr=LR, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=GAMMA)


27 changes: 9 additions & 18 deletions UNet3+/Code/Validation.py
@@ -3,14 +3,12 @@
from config import CLASSES
import torch.nn.functional as F

-def validation(epoch, model, data_loader, criterion, thr=0.5):
+def validation(epoch, model, data_loader, thr=0.5):
print(f'Start validation #{epoch:2d}')
model.cuda()
model.eval()

dices = []
-total_loss = 0
-cnt = 0

with torch.no_grad():
total_steps = len(data_loader)  # total number of steps in the data loader
@@ -24,27 +22,21 @@ def validation(epoch, model, data_loader, criterion, thr=0.5):
# Resize outputs (only when needed)
if outputs.shape[-2:] != masks.shape[-2:]:
outputs = F.interpolate(outputs, size=masks.shape[-2:], mode="bilinear", align_corners=False)

-# Compute the loss
-loss = criterion(outputs, masks)
-total_loss += loss.item()
-cnt += 1


# Binarize outputs and compute Dice (on the GPU)
-outputs = (torch.sigmoid(outputs) > thr).float()
+outputs = (outputs > thr).float()
dice = dice_coef(outputs, masks)
dices.append(dice.detach())  # keep on the GPU

-# Print progress and loss
-if (step + 1) % 80 == 0 or (step + 1) == total_steps:  # every 10 steps or at the last step
-avg_loss = total_loss / cnt
-print(f"Validation Progress: Step {step + 1}/{total_steps}, Avg Loss: {avg_loss:.4f}")
+# Print progress
+if (step + 1) % 80 == 0 or (step + 1) == total_steps:  # every 80 steps or at the last step
+print(f"Validation Progress: Step {step + 1}/{total_steps}")

# Average Dice on the GPU
dices = torch.cat(dices, 0)
dices_per_class = dices.mean(dim=0)

-# Print logs
+# Print per-class Dice scores
dice_str = [
f"{c:<12}: {d.item():.4f}"
for c, d in zip(CLASSES, dices_per_class)
@@ -54,8 +46,7 @@

avg_dice = dices_per_class.mean().item()

-# Print the final average loss
-avg_loss = total_loss / cnt
-print(f"Validation Completed: Avg Loss: {avg_loss:.4f}, Avg Dice: {avg_dice:.4f}")
+# Print the final average Dice
+print(f"Validation Completed: Avg Dice: {avg_dice:.4f}")

return avg_dice
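Note: dice_coef is imported from elsewhere in the project and not shown in this diff. A sketch consistent with how it is called here (binary float tensors in, one score per sample per class out) might look like this; the name and signature are assumed, not confirmed by the diff:

    import torch

    def dice_coef(outputs: torch.Tensor, masks: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # outputs, masks: (B, C, H, W) binary float tensors
        intersection = (outputs * masks).sum(dim=(2, 3))
        total = outputs.sum(dim=(2, 3)) + masks.sum(dim=(2, 3))
        return (2.0 * intersection + eps) / (total + eps)  # shape (B, C)

torch.cat(dices, 0) then stacks these (B, C) scores over the whole loader, and .mean(dim=0) yields the per-class averages printed above.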
16 changes: 8 additions & 8 deletions UNet3+/Code/config.py
@@ -41,28 +41,28 @@
RANDOM_SEED = 21

# Tune as appropriate
-NUM_EPOCHS = 52
+NUM_EPOCHS = 75
VAL_EVERY = 1

-ACCUMULATION_STEPS=32
+ACCUMULATION_STEPS=16
BATCH_SIZE = 1
IMSIZE=480

-LR = 0.0003
-MILESTONES=[20,30,37]
-GAMMA=0.2
+LR = 0.0008
+MILESTONES=[5,20,32,40,47]
+GAMMA=0.3


SAVED_DIR = "/data/ephemeral/home/MCG/UNetRefactored/Creadted_model/"
-MODELNAME="othersCrop_AddBottleNeck_ConvTrans_dice_52.pt"
+MODELNAME="CropOthersChangeLoss.pt"
if not os.path.isdir(SAVED_DIR):
os.mkdir(SAVED_DIR)



-INFERENCE_MODEL_NAME="othersCrop_AddBottleNeck_ConvTrans_dice_52.pt"
+INFERENCE_MODEL_NAME="CropOthersChangeLoss.pt"

TEST_IMAGE_ROOT="/data/ephemeral/home/MCG/data/test/DCM"

CSVDIR="/data/ephemeral/home/MCG/UNetRefactored/CSV"
-CSVNAME="othersCrop_AddBottleNeck_ConvTrans_dice_52.csv"
+CSVNAME="CropOthersChangeLoss.csv"