Skip to content

Commit

Permalink
Add Finished UNet Model
Browse files Browse the repository at this point in the history
  • Loading branch information
gagewrye committed Aug 22, 2024
1 parent c8d134d commit 49f9634
Show file tree
Hide file tree
Showing 7 changed files with 559 additions and 321 deletions.
19 changes: 14 additions & 5 deletions Drone Classification/data/prepare_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -26,11 +26,12 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"base_path = '/Users/gage/Desktop/mangrove_data/Chunks'\n",
"base_path = \"C:/Users/gwrye/OneDrive/Desktop\"\n",
"\n",
"TILE_SIZE = 256\n",
"\n",
"combined_images_file = os.path.join(base_path, f'{TILE_SIZE}dataset_images.npy')\n",
Expand Down Expand Up @@ -179,9 +180,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Percent Shuffled: 100.00%\r"
]
}
],
"source": [
"# Shuffle data one entry at a time using Fisher-Yates shuffle\n",
"# This is necessary because the data is too large to load into memory all at once\n",
Expand Down
Binary file modified Drone Classification/models/ResNet18_UNet.pth
Binary file not shown.
3 changes: 2 additions & 1 deletion Drone Classification/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .models import JaccardLoss, ResNet18_UNet, ResNet50_UNet, ResNet_FC, SegmentModelWrapper
from .models import ResNet_UNet, ResNet_FC, SegmentModelWrapper
from .loss import JaccardLoss, FocalLoss
121 changes: 121 additions & 0 deletions Drone Classification/models/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Optional

class BCEDiceLoss(nn.Module):
def __init__(self, weight: Optional[torch.tensor] = None, size_average=True):
super(BCEDiceLoss, self).__init__()
self.bce = nn.BCEWithLogitsLoss(pos_weight=weight, reduction='mean' if size_average else 'sum')

def forward(self, inputs, targets):
bce_loss = self.bce(inputs, targets)
intersection = (inputs * targets).sum()
dice = (2. * intersection + 1) / (inputs.sum() + targets.sum() + 1)
dice_loss = 1 - dice
return bce_loss + dice_loss

class BCETverskyLoss(nn.Module):
def __init__(self, alpha=0.5, beta=0.5, smooth=1e-5, pos_weight=None):
super(BCETverskyLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.smooth = smooth
self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

def forward(self, inputs, targets):
# BCE Loss
bce_loss = self.bce(inputs, targets)

# Apply sigmoid to the inputs
inputs = torch.sigmoid(inputs)

# Flatten the tensors
inputs = inputs.view(-1)
targets = targets.view(-1)

# True positives, false positives, and false negatives
TP = (inputs * targets).sum()
FP = ((1 - targets) * inputs).sum()
FN = (targets * (1 - inputs)).sum()

# Tversky index
Tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
tversky_loss = 1 - Tversky

return bce_loss + tversky_loss

class JaccardLoss(nn.Module):
def __init__(self, smooth=1e-10):
super(JaccardLoss, self).__init__()
self.smooth = smooth

def forward(self, y_pred : torch.tensor, y_true: torch.tensor):
y_pred = torch.sigmoid(y_pred)

# Flatten the tensors to simplify the calculation
y_pred = y_pred.view(-1)
y_true = y_true.view(-1)

# Calculate intersection and union
intersection = (y_pred * y_true).sum()
union = y_pred.sum() + y_true.sum() - intersection

# Calculate the Jaccard index
jaccard_index = (intersection + self.smooth) / (union + self.smooth)

# Return the Jaccard loss (1 - Jaccard index)
return 1 - jaccard_index

class FocalLoss(nn.Module):
"""
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
Args:
inputs (Tensor): A float tensor of arbitrary shape.
The predictions for each example.
targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha (float): Weighting factor in range (0,1) to balance
positive vs negative examples or -1 for ignore. Default: ``0.25``.
gamma (float): Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples. Default: ``2``.
reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
``'none'``: No reduction will be applied to the output.
``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``.
Returns:
Loss tensor with the reduction option applied.
"""
def __init__(self, alpha=0.1, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction

def forward(self, inputs, targets):
inputs = torch.sigmoid(inputs)

# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** self.gamma)

if self.alpha >= 0:
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
loss = alpha_t * loss

# Check reduction option and return loss accordingly
if self.reduction == "none":
pass
elif self.reduction == "mean":
loss = loss.mean()
elif self.reduction == "sum":
loss = loss.sum()
else:
raise ValueError(
f"Invalid Value for arg 'reduction': '{self.reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
)
return loss
31 changes: 17 additions & 14 deletions Drone Classification/models/model_testing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -12,15 +12,16 @@
"import torch\n",
"import os\n",
"import gc\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data import DataLoader, Dataset\n",
"import torch.nn as nn\n",
"from torch.nn import BCEWithLogitsLoss\n",
"from torch.optim import Adam\n",
"from sklearn.model_selection import train_test_split\n",
"from tqdm import tqdm\n",
"from data.datasets import SegmentationDataset\n",
"from models import *\n",
"from torchvision.models import resnet18, resnet50, ResNet18_Weights, ResNet50_Weights"
"from .models import *\n",
"from .loss import *\n",
"from torchvision.models import resnet18, resnet50, ResNet18_Weights, ResNet50_Weights\n",
"from torchgeo.models import get_weight"
]
},
{
Expand Down Expand Up @@ -86,8 +87,8 @@
" plt.show()\n",
"\n",
"# NOTE: we are not using transforms, because there are too many channels for standard PIL transforms\n",
"trainDS = SegmentationDataset(images=trainImages, labels=trainMasks) \n",
"testDS = SegmentationDataset(images=testImages, labels=testMasks)\n",
"trainDS = Dataset(images=trainImages, labels=trainMasks) \n",
"testDS = Dataset(images=testImages, labels=testMasks)\n",
"print(f\"[INFO] found {len(trainDS)} examples in the training set...\")\n",
"print(f\"[INFO] found {len(testDS)} examples in the test set...\")\n",
"\n",
Expand Down Expand Up @@ -562,7 +563,7 @@
"source": [
"loss = BCEWithLogitsLoss()\n",
"\n",
"sat_resnet18_UNet_BCE = ResNet18_UNet(input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"sat_resnet18_UNet_BCE = ResNet_UNet(input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"sat_resnet18_train_BCE, sat_resnet18_metrics_BCE = train(sat_resnet18_UNet_BCE, trainLoader, testLoader, loss)\n",
"sat_resnet18_valid_BCE = [x['Loss'] for x in sat_resnet18_metrics_BCE]"
]
Expand Down Expand Up @@ -765,7 +766,7 @@
"source": [
"loss = JaccardLoss()\n",
"\n",
"sat_resnet18_UNet_jaccard = ResNet18_UNet(input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"sat_resnet18_UNet_jaccard = ResNet_UNet(input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"sat_resnet18_train_jaccard, sat_resnet18_metrics_jaccard = train(sat_resnet18_UNet_jaccard, trainLoader, testLoader, loss)\n",
"sat_resnet18_valid_jaccard = [x['Loss'] for x in sat_resnet18_metrics_jaccard]"
]
Expand Down Expand Up @@ -969,7 +970,7 @@
"from models import FocalLoss\n",
"loss = FocalLoss(alpha=0.25,reduction='mean')\n",
"\n",
"resnet18_UNet_focal = ResNet18_UNet(input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"resnet18_UNet_focal = ResNet_UNet(input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"resnet18_train_focal, resnet18_metrics_focal = train(resnet18_UNet_focal, trainLoader, testLoader, loss)\n",
"resnet18_valid_focal = [x['Loss'] for x in resnet18_metrics_focal]"
]
Expand Down Expand Up @@ -1306,7 +1307,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Imagenet vs Satellite Pretraied ResNet18 "
"### Imagenet vs Satellite Pretrained ResNet18 "
]
},
{
Expand Down Expand Up @@ -1509,7 +1510,7 @@
"imagenet_resnet18 = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).to(DEVICE)\n",
"loss = JaccardLoss()\n",
"\n",
"imagenet_resnet18_unet = ResNet18_UNet(ResNet18=imagenet_resnet18, input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"imagenet_resnet18_unet = ResNet_UNet(ResNet=imagenet_resnet18, input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"imagenet_resnet18_train, imagenet_resnet18_metrics = train(imagenet_resnet18_unet, trainLoader, testLoader, loss)\n",
"imagenet_resnet18_valid = [x['Loss'] for x in imagenet_resnet18_metrics]"
]
Expand Down Expand Up @@ -1767,8 +1768,10 @@
"source": [
"# U-Net model that uses a ResNet50 from SSL4EO-12. https://github.com/zhu-xlab/SSL4EO-S12\n",
"# The ResNet is pretrained on Sentinel-2 3-channel RGB satellite imagery\n",
"resnet50 = resnet50(weights=get_weight(\"ResNet50_Weights.SENTINEL2_RGB_SECO\"))\n",
"loss = JaccardLoss()\n",
"sat_resnet50 = ResNet50_UNet(input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"\n",
"sat_resnet50 = ResNet_UNet(ResNet=resnet50, input_image_size=INPUT_IMAGE_SIZE).to(DEVICE)\n",
"sat_resnet50_train, sat_resnet50_metrics = train(sat_resnet50, trainLoader, testLoader,loss)\n",
"sat_resnet50_valid = [x['Loss'] for x in sat_resnet50_metrics]"
]
Expand Down Expand Up @@ -2021,7 +2024,7 @@
"# ImageNet ResNet50\n",
"loss = JaccardLoss()\n",
"imagenet_resnet50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)\n",
"imagenet_resnet50_unet = ResNet50_UNet(ResNet50=imagenet_resnet50, input_image_size=128).to(DEVICE)\n",
"imagenet_resnet50_unet = ResNet_UNet(ResNet=imagenet_resnet50, input_image_size=128).to(DEVICE)\n",
"imagenet_resnet50_train, imagenet_resnet50_metrics = train(imagenet_resnet50_unet, trainLoader, testLoader, loss)\n",
"imagenet_resnet50_valid = [x['Loss'] for x in imagenet_resnet50_metrics]"
]
Expand Down
Loading

0 comments on commit 49f9634

Please sign in to comment.