Skip to content

Commit

Permalink
feat: goodSet UNet3+ #25
Browse files Browse the repository at this point in the history
  • Loading branch information
MCG committed Nov 26, 2024
1 parent f6b2179 commit 2746f1f
Show file tree
Hide file tree
Showing 15 changed files with 1,538 additions and 67 deletions.
53 changes: 37 additions & 16 deletions UNet3+/Code/CropTrainRun.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import torch.optim as optim
from torch.utils.data import DataLoader

from Model.HRnetModel import UNet3PlusHRNet
from Model.HRNetModel_Reduce_finalConv_weight import UNet3PlusHRNet
from DataSet.DataLoder import get_image_label_paths
from config import IMAGE_ROOT, LABEL_ROOT, BATCH_SIZE, IMSIZE, CLASSES, MILESTONES, GAMMA, LR, SAVED_DIR, VISUALIZE_TRAIN_DATA, SAVE_VISUALIZE_TRAIN_DATA_PATH,NUM_EPOCHS
from config import IMAGE_ROOT, LABEL_ROOT, BATCH_SIZE, IMSIZE, CLASSES, MILESTONES, GAMMA, LR, SAVED_DIR, VISUALIZE_TRAIN_DATA,NUM_EPOCHS
from DataSet.LabelBaseCropDataset import XRayDataset
from Loss.Loss import CombinedLoss
from Train import train
from NoMixedWeightTrain import train
from Util.SetSeed import set_seed
from sklearn.utils import shuffle

Expand All @@ -37,7 +37,7 @@ def main():

# 전체 데이터의 20%를 validation data로 쓰기 위해 `n_splits`를
# 5으로 설정하여 GroupKFold를 수행합니다.
gkf = GroupKFold(n_splits=5)
gkf = GroupKFold(n_splits=4)

train_filenames = []
train_labelnames = []
Expand All @@ -55,29 +55,48 @@ def main():


# tf = A.Resize(IMSIZE,IMSIZE)
train_dataset = XRayDataset(
train_dataset1 = XRayDataset(
train_filenames,
train_labelnames,
is_train=True,
save_dir=SAVE_VISUALIZE_TRAIN_DATA_PATH,
#save_dir='/data/ephemeral/home/MCG/YOLO_Detection_Model/crop_image',
draw_enabled=False,
)
valid_dataset = XRayDataset(
valid_filenames,
valid_labelnames,
is_train=False,
save_dir=SAVE_VISUALIZE_TRAIN_DATA_PATH,
tf = A.Compose([
#A.ElasticTransform(alpha=2, sigma=80, p=0.5), # Elastic Transform
A.Rotate(limit=10, p=0.8), # Random Rotation (-12 ~ 12도, 70% 확률)
A.HorizontalFlip(p=1), # Horizontal Flip (항상 적용)
A.RandomBrightnessContrast(
brightness_limit=0.24, # 밝기 조정 범위: ±20%
contrast_limit=0.24, # 대비 조정 범위: ±20%
brightness_by_max=False, # 정규화된 값 기준으로 밝기 조정
p=0.8 # 50% 확률로 적용
),
])

train_dataset2 = XRayDataset(
train_filenames,
train_labelnames,
is_train=True,
transforms=tf,
save_dir=None,
draw_enabled=False,
)

train_dataset=train_dataset1+train_dataset2
train_loader = DataLoader(
dataset=train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=8, # 멀티프로세싱을 사용할 수 있도록 설정
drop_last=True,
)

valid_dataset = XRayDataset(
valid_filenames,
valid_labelnames,
is_train=False,
save_dir=None,
draw_enabled=False,
)
valid_loader = DataLoader(
dataset=valid_dataset,
batch_size=1,
Expand All @@ -89,11 +108,13 @@ def main():
model = UNet3PlusHRNet(n_classes=len(CLASSES))

# Loss function 정의
criterion = CombinedLoss(focal_weight=1, iou_weight=1, ms_ssim_weight=1, dice_weight=0)
criterion = CombinedLoss(focal_weight=1.5, iou_weight=1, ms_ssim_weight=1, dice_weight=0)

# Optimizer 정의
optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=2e-6)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='max',factor=0.3,patience=3,verbose=True)
'''optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=2e-6)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='max',factor=0.5,patience=4,verbose=True)'''
optimizer = optim.AdamW(params=model.parameters(), lr=3e-4, weight_decay=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=55, T_mult=2, eta_min=5e-6)

train(model, train_loader, valid_loader, criterion, optimizer, scheduler)

Expand Down
5 changes: 3 additions & 2 deletions UNet3+/Code/DataSet/LabelBaseCropDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, filenames, labelnames, transforms=None,
self.transforms = transforms
self.save_dir = save_dir # Crop된 이미지 저장 디렉토리
self.draw_enabled = draw_enabled # 라벨 그리기 기능 활성화 여부

self.save_once=False
def __len__(self):
return len(self.filenames)

Expand Down Expand Up @@ -70,10 +70,11 @@ def __getitem__(self, item):
image = result["image"]
label = result["mask"] if self.is_train else label

if self.draw_enabled and self.save_dir:
if self.draw_enabled and self.save_dir and not self.save_once:
os.makedirs(self.save_dir, exist_ok=True)
save_path = os.path.join(self.save_dir, f"cropped_{os.path.basename(self.filenames[item])}")
draw_and_save_crop(image, label, save_path)
self.save_once = True

# Convert to tensor
image = image.transpose(2, 0, 1) # Channel first
Expand Down
6 changes: 3 additions & 3 deletions UNet3+/Code/Loss/Loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,16 @@ def boundary_loss(self, logits, targets):

def forward(self, logits, targets):
focal = self.focal_loss(logits, targets) * self.focal_weight
dice = self.dice_loss(logits, targets) * self.dice_weight
dice = 0 #self.dice_loss(logits, targets) * self.dice_weight
iou = self.iou_loss(logits, targets) * self.iou_weight
#bce=self.bce_loss_fn(logits, targets)
#gdl = self.gdl_loss(logits, targets) * self.gdl_weight
ms_ssim = self.ms_ssim(logits, targets) * self.ms_ssim_weight
# boundary = self.boundary_loss(logits, targets) * self.boundary_weight

# Combined loss
total_loss = focal + dice + iou + ms_ssim #gdl
return total_loss, focal, iou, dice, ms_ssim #gdl
total_loss = focal + iou + ms_ssim #+ dice#gdl
return total_loss, focal, iou, ms_ssim #dice#gdl


'''
Expand Down
Loading

0 comments on commit 2746f1f

Please sign in to comment.