Commit: Full pipeline with model architectures, model training/testing, and data processing
Showing 20 changed files with 4,254 additions and 2,638 deletions.
Two binary files are not shown.
MemoryMapDataset.py
@@ -0,0 +1,132 @@
from torch.utils.data import Dataset, Sampler, BatchSampler
import numpy as np
from typing import Optional
import torch


class MemmapDataset(Dataset):
    def __init__(self, images: np.ndarray, labels: np.ndarray, validation_indices: Optional[tuple] = None):
        """
        Inputs are expected to be memory-mapped numpy arrays (.npy).

        Args:
            images (np.ndarray): Memory-mapped numpy array of images
            labels (np.ndarray): Memory-mapped numpy array of labels
            validation_indices (Optional[tuple], optional): (begin, end) range this dataset
                occupies in the parent dataset, used for cross validation.
        """
        self.images = images
        self.labels = labels
        self.indices = validation_indices

        # Standard ImageNet mean and std values used with ResNet
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        # Convert mean and std to tensors with shape [C, 1, 1] so they broadcast over H and W
        self.mean_tensor = torch.tensor(mean, dtype=torch.float32).view(3, 1, 1)
        self.std_tensor = torch.tensor(std, dtype=torch.float32).view(3, 1, 1)

    def __len__(self) -> int:
        return self.images.shape[0]

    def __getitem__(self, idx) -> tuple:
        image = self.images[idx]
        label = self.labels[idx]

        # Scale to [0, 1], then normalize with the ImageNet statistics
        image = torch.tensor(image, dtype=torch.float32)
        image.div_(255.0)
        if image.dim() == 4:
            # idx was a slice, so a whole [N, C, H, W] batch came back at once
            image = (image - self.mean_tensor.unsqueeze(0)) / self.std_tensor.unsqueeze(0)
        else:
            image = (image - self.mean_tensor) / self.std_tensor

        return image, torch.tensor(label, dtype=torch.long)

    def split(self, split_ratio: float):
        split_index = int(self.images.shape[0] * split_ratio)

        # Create views for training and validation sets
        train_images = self.images[:split_index]
        val_images = self.images[split_index:]

        train_labels = self.labels[:split_index]
        val_labels = self.labels[split_index:]

        train_dataset = MemmapDataset(train_images, train_labels)
        val_dataset = MemmapDataset(val_images, val_labels)

        return train_dataset, val_dataset

    def split_into_folds(self, num_folds: int) -> list[Dataset]:
        """
        Creates a list of validation datasets for cross validation.
        The original dataset is used as the training dataset. When training,
        make sure the index range covered by the current validation dataset is
        excluded from the training batches (see SliceSampler's skip_indices).
        """
        fold_size = self.images.shape[0] // num_folds
        validation_datasets = []

        for i in range(num_folds):
            begin = i * fold_size
            end = (i + 1) * fold_size

            val_images = self.images[begin:end]
            val_labels = self.labels[begin:end]

            validation_datasets.append(MemmapDataset(val_images, val_labels, validation_indices=(begin, end)))

        return validation_datasets


def slice_collate_fn(batch):
    """
    Returns the slice as the batch (the dataset already returns batched tensors).
    """
    return batch[0]


class SliceSampler(Sampler):
    """
    Yields contiguous slices of the dataset to minimize the overhead of accessing
    a memory-mapped array. Can optionally skip an index range to allow for cross
    validation with memory mapping.
    """
    def __init__(self, dataset_len, batch_size, skip_indices: Optional[tuple] = None):
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.start_skip = None
        self.end_skip = None
        if skip_indices:
            self.start_skip = skip_indices[0]
            self.end_skip = skip_indices[1]

    def __iter__(self):
        for start_idx in range(0, self.dataset_len, self.batch_size):
            end_idx = min(start_idx + self.batch_size, self.dataset_len)

            if self.start_skip is None:
                yield slice(start_idx, end_idx)
                continue

            # Skip any slice that overlaps the held-out [start_skip, end_skip) range
            if start_idx < self.end_skip and end_idx > self.start_skip:
                continue

            yield slice(start_idx, end_idx)

    def __len__(self):
        # Number of batches (an upper bound when skip_indices is set)
        return (self.dataset_len + self.batch_size - 1) // self.batch_size


class SliceBatchSampler(BatchSampler):
    """
    Passes the batch along untouched.
    """
    def __init__(self, sampler, batch_size, drop_last):
        super().__init__(sampler, batch_size, drop_last)

    def __iter__(self):
        for batch in super().__iter__():
            yield batch

    def __len__(self):
        return super().__len__()
__init__.py
@@ -0,0 +1,4 @@
from .MemoryMapDataset import MemmapDataset
from .MemoryMapDataset import slice_collate_fn
from .MemoryMapDataset import SliceBatchSampler
from .MemoryMapDataset import SliceSampler
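A minimal usage sketch of how these exported pieces are meant to fit together: `SliceSampler` yields contiguous slices, `SliceBatchSampler` passes them through one at a time, and `slice_collate_fn` unwraps the already-batched tensors. The `.npy` file names are the ones the notebook below writes; the paths and batch size here are illustrative assumptions, not from this commit.

```python
import numpy as np
from torch.utils.data import DataLoader
from MemoryMapDataset import MemmapDataset, SliceSampler, SliceBatchSampler, slice_collate_fn

# Hypothetical paths; the notebook writes these under its base_path
images = np.load("256dataset_images.npy", mmap_mode="r")
labels = np.load("256dataset_labels.npy", mmap_mode="r")
dataset = MemmapDataset(images, labels)

# Each element the loader sees is one contiguous slice of the memmap,
# so the BatchSampler level uses batch_size=1: one slice in, one batch out.
sampler = SliceSampler(len(dataset), batch_size=64)
batch_sampler = SliceBatchSampler(sampler, batch_size=1, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=slice_collate_fn)

for image_batch, label_batch in loader:
    print(image_batch.shape)  # e.g. [64, 3, 256, 256] for 256-pixel tiles
    break
```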
@@ -0,0 +1,253 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from preprocessing import tile_tiff_pair, rasterize_shapefile\n",
"from MemoryMapDataset import MemmapDataset\n",
"import numpy as np\n",
"import psutil\n",
"import os\n",
"import gc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The processing pipeline assumes that the data in the Chunks folder is in the following format:\n",
"- Each chunk is in its own folder named 'Chunk x' or 'Chunk x x-x'\n",
"- The RGB tif should be named 'Chunkx.tif' or 'Chunkx_x-x.tif'\n",
"- The label shapefile and its companion files should be in a folder called 'labels' inside the matching 'Chunk x' / 'Chunk x x-x' folder; those file names do not need to follow any format."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"base_path = '/Users/gage/Desktop/mangrove_data/Chunks'\n",
"TILE_SIZE = 256\n",
"\n",
"combined_images_file = os.path.join(base_path, f'{TILE_SIZE}dataset_images.npy')\n",
"combined_labels_file = os.path.join(base_path, f'{TILE_SIZE}dataset_labels.npy')\n",
"\n",
"# RAM thresholds\n",
"TOTAL_RAM_MB = psutil.virtual_memory().total / (1024 ** 2)\n",
"SAFE_RAM_USAGE_MB = TOTAL_RAM_MB - 16 * 1024  # 16GB below total RAM\n",
"CHUNK_BUFFER_SIZE = 15  # Number of chunks to keep in memory at a time"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Convert all label shapefiles to tif\n",
"for entry in os.listdir(base_path):\n",
"    if 'Chunk' in entry:\n",
"        chunk_path = os.path.join(base_path, entry)\n",
"        rasterize_shapefile(chunk_path)\n",
"print('\\nDone rasterizing shapefiles')\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Functions to monitor memory usage\n",
"def print_memory_usage():\n",
"    process = psutil.Process(os.getpid())\n",
"    mem_info = process.memory_info()\n",
"    print(f\"Memory Usage: {mem_info.rss / (1024 ** 2):.2f} MB\")\n",
"\n",
"def get_memory_usage():\n",
"    process = psutil.Process(os.getpid())\n",
"    mem_info = process.memory_info()\n",
"    return mem_info.rss / (1024 ** 2)  # Return memory usage in MB"
]
},
"\"\"\"\n", | ||
"Convert all tif pairs into tiled datasets\n", | ||
"\n", | ||
"NOTE: This will take a lot of time, memory, and storage space.\n", | ||
"You should have at least 32GB of RAM and triple the chunk folder size of storage. If you don't have enough RAM,\n", | ||
"you can run this script in smaller chunks by lowering the CHUNK_BUFFER_SIZE variable.\n", | ||
"\"\"\"\n", | ||
"\n", | ||
"# Function to append data to memory-mapped file\n", | ||
"def append_to_memmap(file_path, data, dtype):\n", | ||
" if not os.path.exists(file_path):\n", | ||
" print(f\"Creating new memmap file at {file_path}\")\n", | ||
" new_memmap = np.lib.format.open_memmap(file_path, mode='w+', dtype=dtype, shape=data.shape)\n", | ||
" new_memmap[:] = data\n", | ||
" else:\n", | ||
" # Load the existing memmap\n", | ||
" memmap = np.load(file_path, mmap_mode='r+')\n", | ||
" new_shape = (memmap.shape[0] + data.shape[0],) + memmap.shape[1:]\n", | ||
" \n", | ||
" # Create a temporary memmap with the expanded size\n", | ||
" temp_file_path = file_path + '.tmp'\n", | ||
" new_memmap = np.lib.format.open_memmap(temp_file_path, mode='w+', dtype=dtype, shape=new_shape)\n", | ||
" \n", | ||
" # Copy old data into the temporary memmap\n", | ||
" new_memmap[:memmap.shape[0]] = memmap[:]\n", | ||
" \n", | ||
" # Append new data\n", | ||
" new_memmap[memmap.shape[0]:] = data\n", | ||
" \n", | ||
" # Flush and delete the old memmap\n", | ||
" del memmap\n", | ||
" new_memmap.flush()\n", | ||
" \n", | ||
" # Replace the original file with the temporary file\n", | ||
" os.replace(temp_file_path, file_path)\n", | ||
"\n", | ||
"# Buffer for storing data before appending to memmap\n", | ||
"image_buffer = []\n", | ||
"label_buffer = []\n", | ||
"\n", | ||
"num_chunks = len([entry for entry in os.listdir(base_path) if 'Chunk' in entry])\n", | ||
"print(f\"Processing {num_chunks} chunk directories\")\n", | ||
"\n", | ||
"# Iterate over each chunk directory and process TIFF pairs\n", | ||
"current_chunk = 0\n", | ||
"for entry in os.listdir(base_path):\n", | ||
" if 'Chunk' in entry:\n", | ||
" current_chunk += 1\n", | ||
" print(f\"\\nChunk {current_chunk}/{num_chunks}\")\n", | ||
" chunk_path = os.path.join(base_path, entry)\n", | ||
" \n", | ||
" # Generate tiled images and labels\n", | ||
" images, labels = tile_tiff_pair(chunk_path, image_size=TILE_SIZE)\n", | ||
" if images[0].size == 0:\n", | ||
" print(f\"No valid tiles found at {entry}\")\n", | ||
" continue\n", | ||
" \n", | ||
" # Add to buffer\n", | ||
" image_buffer.append(images)\n", | ||
" label_buffer.append(labels)\n", | ||
"\n", | ||
" # Check memory usage and append to memmap if within threshold\n", | ||
" current_memory_usage = get_memory_usage()\n", | ||
" if current_memory_usage > SAFE_RAM_USAGE_MB or current_chunk % CHUNK_BUFFER_SIZE == 0:\n", | ||
" if current_memory_usage > SAFE_RAM_USAGE_MB:\n", | ||
" print(f\"Memory usage {current_memory_usage:.2f} MB exceeds {SAFE_RAM_USAGE_MB} threshold. Appending to memmap.\")\n", | ||
" else:\n", | ||
" print(\"Appending to memmap...\")\n", | ||
" images_to_append = np.concatenate(image_buffer, axis=0)\n", | ||
" labels_to_append = np.concatenate(label_buffer, axis=0)\n", | ||
" append_to_memmap(combined_images_file, images_to_append, np.uint8)\n", | ||
" append_to_memmap(combined_labels_file, labels_to_append, np.uint8)\n", | ||
" \n", | ||
" # Clear buffer\n", | ||
" image_buffer = []\n", | ||
" label_buffer = []\n", | ||
" \n", | ||
" # Memory management\n", | ||
" print_memory_usage()\n", | ||
" del images_to_append, labels_to_append\n", | ||
" gc.collect()\n", | ||
"\n", | ||
"# Final append if buffer is not empty\n", | ||
"if image_buffer:\n", | ||
" print(\"Appending remaining buffered data to memmap.\")\n", | ||
" images_buffer = np.concatenate(image_buffer, axis=0)\n", | ||
" labels_buffer = np.concatenate(label_buffer, axis=0)\n", | ||
" append_to_memmap(combined_images_file, images_buffer, np.uint8)\n", | ||
" append_to_memmap(combined_labels_file, labels_buffer, np.uint8)\n", | ||
" \n", | ||
" # Clear buffer\n", | ||
" image_buffer = []\n", | ||
" label_buffer = []\n", | ||
"\n", | ||
"print('\\nDone tiling tif pairs')\n" | ||
] | ||
}, | ||
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shuffle data one entry at a time using a Fisher-Yates shuffle\n",
"# This is necessary because the data is too large to load into memory all at once\n",
"def shuffle_data(images, labels):\n",
"    dataset_size = images.shape[0]\n",
"\n",
"    for i in range(dataset_size - 1, 0, -1):\n",
"        print(f\"Percent Shuffled: {100 * (dataset_size - i) / dataset_size:.2f}%\", end='\\r')\n",
"        j = np.random.randint(0, i + 1)\n",
"        # Swap via explicit copies: numpy row indexing returns views into the\n",
"        # memmap, so a bare tuple-swap would duplicate rows instead of swapping them\n",
"        image_tmp = images[i].copy()\n",
"        images[i] = images[j]\n",
"        images[j] = image_tmp\n",
"        label_tmp = labels[i].copy()\n",
"        labels[i] = labels[j]\n",
"        labels[j] = label_tmp\n",
"\n",
"images = np.load(combined_images_file, mmap_mode='r+')\n",
"labels = np.load(combined_labels_file, mmap_mode='r+')\n",
"\n",
"shuffle_data(images, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example dataset\n",
"dataset = MemmapDataset(images, labels)\n",
"print(f\"Dataset length: {len(dataset)}\")\n",
"print(f\"Dataset image shape: {dataset.images[0].shape}\")\n",
"print(f\"Dataset label shape: {dataset.labels[0].shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Remove the rasterized labels.tif files once they are no longer needed\n",
"for entry in os.listdir(base_path):\n",
"    if 'Chunk' in entry:\n",
"        chunk_path = os.path.join(base_path, entry)\n",
"        os.remove(os.path.join(chunk_path, 'labels.tif'))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mangrove",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
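The notebook never exercises the cross-validation path, so here is a hedged sketch of how `split_into_folds` and `SliceSampler`'s `skip_indices` are presumably meant to combine. The fold count and batch size are illustrative, and the training/evaluation loop itself is omitted.

```python
import numpy as np
from torch.utils.data import DataLoader
from MemoryMapDataset import MemmapDataset, SliceSampler, SliceBatchSampler, slice_collate_fn

# Hypothetical paths; the notebook writes these under its base_path
images = np.load("256dataset_images.npy", mmap_mode="r")
labels = np.load("256dataset_labels.npy", mmap_mode="r")
full_dataset = MemmapDataset(images, labels)

for fold, val_dataset in enumerate(full_dataset.split_into_folds(num_folds=5)):
    # Train on the full dataset while skipping the fold's (begin, end) range
    train_sampler = SliceSampler(len(full_dataset), batch_size=64,
                                 skip_indices=val_dataset.indices)
    train_loader = DataLoader(full_dataset,
                              batch_sampler=SliceBatchSampler(train_sampler, 1, False),
                              collate_fn=slice_collate_fn)

    # Validate on the held-out fold
    val_sampler = SliceSampler(len(val_dataset), batch_size=64)
    val_loader = DataLoader(val_dataset,
                            batch_sampler=SliceBatchSampler(val_sampler, 1, False),
                            collate_fn=slice_collate_fn)
    # ... fit on train_loader, evaluate on val_loader ...
```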
preprocessing/__init__.py
@@ -0,0 +1 @@
from .tiff_processing_tools import tile_tiff_pair, rasterize_shapefile
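The `tiff_processing_tools` module itself is not part of this diff. As a rough illustration only, a `rasterize_shapefile` along these lines would match how the notebook calls it (one `chunk_path` in, a `labels.tif` aligned to the chunk's RGB GeoTIFF out); every path pattern and library call below is an assumption about the implementation, not the commit's actual code.

```python
# Hypothetical sketch only: the real implementation lives in
# tiff_processing_tools.py, which this commit does not show.
import glob
import os

import geopandas as gpd
import rasterio
from rasterio import features

def rasterize_shapefile(chunk_path: str) -> str:
    """Burn the chunk's label polygons into a labels.tif aligned with its RGB tif."""
    # Assumed file layout, per the notebook's markdown cell
    rgb_path = glob.glob(os.path.join(chunk_path, "Chunk*.tif"))[0]
    shp_path = glob.glob(os.path.join(chunk_path, "labels", "*.shp"))[0]

    with rasterio.open(rgb_path) as src:
        # Reproject the labels into the raster's CRS so pixels line up
        geoms = gpd.read_file(shp_path).to_crs(src.crs).geometry
        # Single uint8 band: 1 inside a labeled polygon, 0 elsewhere
        mask = features.rasterize(
            ((geom, 1) for geom in geoms),
            out_shape=(src.height, src.width),
            transform=src.transform,
            fill=0,
            dtype="uint8",
        )
        meta = src.meta.copy()

    meta.update(count=1, dtype="uint8")
    out_path = os.path.join(chunk_path, "labels.tif")
    with rasterio.open(out_path, "w", **meta) as dst:
        dst.write(mask, 1)
    return out_path
```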