Skip to content

Commit

Permalink
test: iou UNet3+ #25
Browse files Browse the repository at this point in the history
  • Loading branch information
MCG committed Nov 21, 2024
1 parent bc39d2d commit c1d0c8e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
6 changes: 3 additions & 3 deletions UNet3+/Code/Loss/Loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ def dice_loss(self, logits, targets):
def forward(self, logits, targets):
# Calculate individual losses
focal = self.focal_loss(logits, targets)
#iou = self.iou_loss(logits, targets)
iou = self.iou_loss(logits, targets)
ms_ssim_loss = 1 - self.ms_ssim(torch.sigmoid(logits), targets)
dice = self.dice_loss(logits, targets)
#dice = self.dice_loss(logits, targets)

# Combine losses with respective weights
total_loss = self.alpha * focal + self.gamma * ms_ssim_loss + self.delta * dice #+ self.beta * iou
total_loss = self.alpha * focal + self.gamma * ms_ssim_loss + self.beta * iou#+ self.delta * dice
return total_loss

10 changes: 5 additions & 5 deletions UNet3+/Code/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,24 @@
RANDOM_SEED = 21

# 적절하게 조절
NUM_EPOCHS =30
VAL_EVERY = 2
NUM_EPOCHS =20
VAL_EVERY = 1
IMSIZE=512
LR = 0.0001
MILESTONES=[7,16,23,27]
GAMMA=0.3


SAVED_DIR = "/data/ephemeral/home/MCG/UNetRefactored/Creadted_model/"
MODELNAME="dice_512.pt"
MODELNAME="iou_512.pt"
if not os.path.isdir(SAVED_DIR):
os.mkdir(SAVED_DIR)



INFERENCE_MODEL_NAME="dice_512.pt"
INFERENCE_MODEL_NAME="iou_512.pt"

TEST_IMAGE_ROOT="/data/ephemeral/home/MCG/data/test/DCM"

CSVDIR="/data/ephemeral/home/MCG/UNetRefactored/CSV"
CSVNAME="dice_512.CSV"
CSVNAME="iou_512.CSV"
13 changes: 12 additions & 1 deletion image-binary.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -303,6 +303,17 @@
"Cell \u001b[0;32mIn[25], line 87\u001b[0m\n\u001b[1;32m 84\u001b[0m draw_polygons_with_arrows(image, data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mannotations\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 86\u001b[0m \u001b[38;5;66;03m# Save the annotated image\u001b[39;00m\n\u001b[0;32m---> 87\u001b[0m \u001b[43mcv2\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutput_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimage\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m현재 셀 또는 이전 셀에서 코드를 실행하는 동안 Kernel이 충돌했습니다. \n",
"\u001b[1;31m셀의 코드를 검토하여 가능한 오류 원인을 식별하세요. \n",
"\u001b[1;31m자세한 내용을 보려면 <a href='https://aka.ms/vscodeJupyterKernelCrash'>여기</a>를 클릭하세요. \n",
"\u001b[1;31m자세한 내용은 Jupyter <a href='command:jupyter.viewOutput'>로그</a>를 참조하세요."
]
}
],
"source": [
Expand Down

0 comments on commit c1d0c8e

Please sign in to comment.