Skip to content

Commit

Permalink
Add ML Developments
Browse files Browse the repository at this point in the history
  • Loading branch information
gagewrye committed Jul 31, 2024
1 parent 338409b commit 2188c40
Show file tree
Hide file tree
Showing 7 changed files with 2,954 additions and 0 deletions.
20 changes: 20 additions & 0 deletions Segmentation/2024/datasets/HRLRDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from torch.utils.data import Dataset

class HRLRDataset(Dataset):
    """Dataset yielding index-aligned (high-res, low-res) image pairs.

    Args:
        hr_images: Indexable collection of high-resolution images.
        lr_images: Indexable collection of low-resolution images, looked up
            with the same index as ``hr_images``.
        transform: Optional callable applied to each image of the pair
            independently (high-res first, then low-res).
    """

    def __init__(self, hr_images, lr_images, transform=None):
        self.hr_images, self.lr_images = hr_images, lr_images
        self.transform = transform

    def __len__(self):
        # The dataset length is defined by the high-resolution collection.
        return len(self.hr_images)

    def __getitem__(self, idx):
        pair = (self.hr_images[idx], self.lr_images[idx])

        # No (or falsy) transform: hand back the raw pair untouched.
        if not self.transform:
            return pair

        return tuple(self.transform(img) for img in pair)
32 changes: 32 additions & 0 deletions Segmentation/2024/datasets/SegDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from torch.utils.data import Dataset
import numpy as np
import torch


class SegmentationDataset(Dataset):
    """In-memory dataset of (image, mask) pairs for segmentation.

    The images and masks are held as NumPy arrays (nothing is read from
    disk here); sample ``idx`` is ``(images[idx], labels[idx])``.

    Args:
        images: Array-like of images; converted to a NumPy array once.
        labels: Array-like of masks, indexed in step with ``images``.
        transforms: Optional callable applied separately to the image and
            to the mask of each sample.
    """

    def __init__(self, images, labels, transforms=None):
        self.transforms = transforms
        self.images = np.array(images)
        self.labels = np.array(labels)

    def __len__(self):
        # self.images is already an ndarray; re-wrapping it in np.array()
        # would copy the entire dataset on every len() call.
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for ``idx``; the mask is a float32 tensor."""
        image = self.images[idx]
        label = self.labels[idx]

        # NOTE(review): the same callable is applied to image and mask in two
        # separate calls, so a randomised transform would not stay in sync
        # between them -- confirm the transforms used are deterministic.
        if self.transforms is not None:
            image = self.transforms(image)
            label = self.transforms(label)

        # The legacy torch.Tensor(...) constructor is fragile (it can reject
        # tensors of non-float dtype and treats an integer scalar as a size),
        # so convert explicitly to float32 instead.
        if isinstance(label, torch.Tensor):
            label = label.to(torch.float32)
        else:
            label = torch.tensor(np.asarray(label), dtype=torch.float32)

        return (image, label)
Binary file not shown.
Binary file not shown.
93 changes: 93 additions & 0 deletions Segmentation/2024/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@

import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from models import ResNet18_UNet

import cv2
import numpy as np
from matplotlib import pyplot as plt



def normalize(image):
    """Divide ``image`` by 255, mapping 8-bit intensities (0-255) onto 0-1."""
    max_intensity = 255.0
    scaled = image / max_intensity
    return scaled

def prediction(model, image, patch_size, max_patches=None):
    """Segment ``image`` by running ``model`` on non-overlapping square patches.

    The image is tiled from the top-left corner in steps of ``patch_size``.
    Only full patches are processed, so any right/bottom remainder narrower
    than ``patch_size`` is never visited and stays 0 in the result.

    Args:
        model: Callable taking a float tensor of shape
            ``(1, C, patch_size, patch_size)`` and returning a tensor that
            squeezes to a 2-D map.
        image: Array indexed as ``(H, W, C)``; each patch is scaled with
            ``normalize`` before being fed to the model.
        patch_size: Side length of the square patches, in pixels.
        max_patches: Optional cap on the number of patches to process
            (useful for quick previews). ``None`` (default) processes the
            whole image.

    Returns:
        ``(H, W)`` float array holding the binary mask (0.0 / 1.0) for every
        covered pixel and 0.0 for uncovered pixels.
    """
    import itertools

    height, width = image.shape[:2]
    segm_img = np.zeros((height, width))
    weights_sum = np.zeros((height, width))  # how many patches hit each pixel

    # Outputs are compared against 0, which matches sigmoid(x) > 0.5.
    # NOTE(review): this assumes the model emits raw logits -- confirm.
    threshold = 0

    origins = (
        (row, col)
        for row in range(0, height - patch_size + 1, patch_size)
        for col in range(0, width - patch_size + 1, patch_size)
    )

    with torch.no_grad():
        # islice(..., None) imposes no limit, so the default covers everything.
        for processed, (i, j) in enumerate(itertools.islice(origins, max_patches), start=1):
            patch = normalize(image[i:i + patch_size, j:j + patch_size])
            # (H, W, C) -> (1, C, H, W)
            batch = np.transpose(np.expand_dims(patch, 0), (0, 3, 1, 2))
            output = model(torch.from_numpy(batch).float())

            binary_mask = (output > threshold).float()
            # .cpu() keeps this working if the model lives on an accelerator.
            mask = binary_mask.squeeze().detach().cpu().numpy()

            # Only resize when the model output is not already patch-sized.
            if mask.shape != (patch_size, patch_size):
                mask = cv2.resize(mask, (patch_size, patch_size))

            segm_img[i:i + patch_size, j:j + patch_size] += mask
            weights_sum[i:i + patch_size, j:j + patch_size] += 1

            if processed % 100 == 0:
                print("Finished processing patch number", processed, "at position", i, j)

    # `out=` is required alongside `where=`: without it, pixels that no patch
    # covered (weights_sum == 0) would be left as uninitialised memory.
    return np.divide(
        segm_img,
        weights_sum,
        out=np.zeros_like(segm_img),
        where=weights_sum > 0,
    )

# Script: segment one image with a trained ResNet18_UNet, display the result
# next to the input, and save the mask as an 8-bit PNG.
patch_size = 256

# NOTE(review): these are machine-specific paths (and the output path points
# at a different user's home directory) -- confirm before running elsewhere.
image_path = "/Users/gage/mangrove/data/jamaica3-31-34ortho-2-1.tif"
weights_path = 'sat_resnet18_UNet_256_BCEweighted.pth'
output_path = '/Users/aaryanpanthi/Desktop/segmented_image.png'

# cv2.imread signals failure by returning None rather than raising, so check
# explicitly instead of letting cvtColor fail with an opaque error.
large_image = cv2.imread(image_path)
if large_image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
large_image_rgb = cv2.cvtColor(large_image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB

model = ResNet18_UNet()

# The model and its inputs stay on the CPU, so map the checkpoint to CPU;
# mapping to 'mps' makes loading fail on machines without that backend.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))

model.eval()

# Perform prediction
segmented_image = prediction(model, large_image_rgb, patch_size)

# Plot the input and the predicted mask side by side
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.title('Large Image')
plt.imshow(large_image_rgb)  # RGB image for correct color display

# Custom colormap: 0 -> white, 1 -> yellow
yellow_cmap = mcolors.LinearSegmentedColormap.from_list('yellow_cmap', [(0, 'white'), (1, 'yellow')])

plt.subplot(122)
plt.title('Segmented Image')
plt.imshow(segmented_image, cmap=yellow_cmap)
plt.show()

# Save the mask scaled to 0-255. cv2.imwrite reports failure through its
# return value, so surface it instead of silently losing the result.
if not cv2.imwrite(output_path, (segmented_image * 255).astype(np.uint8)):
    print(f"Warning: could not write segmented image to {output_path}")
Loading

0 comments on commit 2188c40

Please sign in to comment.