Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom disaster-based train/test splits for xView2 dataset #2416

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,4 @@ dmypy.json

# Pyre type checker
.pyre/
xbdood.ipynb
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ xView2
^^^^^^

.. autoclass:: XView2
.. autoclass:: XView2DistShift

ZueriCrop
^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_xview.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

# Merged the duplicated torchgeo.datasets import lines into one sorted
# import (fixes ruff I001; the old duplicate re-imported the same names).
from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift

Check failure on line 15 in tests/datasets/test_xview.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

tests/datasets/test_xview.py:4:1: I001 Import block is un-sorted or un-formatted

Check failure on line 15 in tests/datasets/test_xview.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

tests/datasets/test_xview.py:15:61: F401 `torchgeo.datasets.XView2DistShift` imported but unused

class TestXView2:
@pytest.fixture(params=['train', 'test'])
Expand All @@ -27,6 +26,7 @@
'md5': '373e61d55c1b294aa76b94dbbd81332b',
'directory': 'train',
},

'test': {
'filename': 'test_images_labels_targets.tar.gz',
'md5': 'bc6de81c956a3bada38b5b4e246266a1',
Expand Down
3 changes: 2 additions & 1 deletion torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@
from .vaihingen import Vaihingen2D
from .vhr10 import VHR10
from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
from .xview import XView2
from .xview import XView2, XView2DistShift
from .zuericrop import ZueriCrop

__all__ = (
Expand Down Expand Up @@ -292,6 +292,7 @@
'VectorDataset',
'WesternUSALiveFuelMoisture',
'XView2',
'XView2DistShift',
'ZueriCrop',
'concat_samples',
'merge_samples',
Expand Down
207 changes: 206 additions & 1 deletion torchgeo/datasets/xview.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@

Args:
root: root dir of dataset
split: subset of dataset, one of [train, test]
split: subset of dataset, one of ['train', 'test']

Returns:
list of dicts containing paths for each pair of images and masks
Expand Down Expand Up @@ -282,3 +282,208 @@
plt.suptitle(suptitle)

return fig


class XView2DistShift(XView2):
    """XView2 variant with custom disaster-based train/test splits.

    Reformats the original xView2 train/test splits so that one disaster
    serves as the in-distribution (ID) training set and a different disaster
    serves as the out-of-distribution (OOD) test set. This enables measuring
    how well a model trained on one disaster type generalizes to another.
    """

    # Binary labels used by ``__getitem__``, which collapses the damage
    # classes into building/background.
    # NOTE(review): ruff RUF012 — mutable class attribute; annotate with
    # ``typing.ClassVar[list[str]]`` once the import is added at module top.
    classes = ['background', 'building']

Check failure on line 298 in torchgeo/datasets/xview.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF012)

torchgeo/datasets/xview.py:298:15: RUF012 Mutable class attributes should be annotated with `typing.ClassVar`

# List of disaster names
valid_disasters = [
'hurricane-harvey',
'socal-fire',
'hurricane-matthew',
'mexico-earthquake',
'guatemala-volcano',
'santa-rosa-wildfire',
'palu-tsunami',
'hurricane-florence',
'hurricane-michael',
'midwest-flooding',
]

Check failure on line 312 in torchgeo/datasets/xview.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF012)

torchgeo/datasets/xview.py:301:23: RUF012 Mutable class attributes should be annotated with `typing.ClassVar`

def __init__(
self,
root: str = 'data',
split: str = 'train',
id_ood_disaster: list[dict[str, str]] = [
{'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
{'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
],
transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None,

Check failure on line 322 in torchgeo/datasets/xview.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (RUF013)

torchgeo/datasets/xview.py:322:21: RUF013 PEP 484 prohibits implicit `Optional`
checksum: bool = False,
) -> None:
"""Initialize the XView2DistShift dataset instance.

Args:
root: Root directory where the dataset is located.
split: One of "train" or "test".
id_ood_disaster: List containing in-distribution and out-of-distribution disaster names.
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)

Raises:
AssertionError: If *split* is invalid.
ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters.
DatasetNotFoundError: If dataset is not found.
"""
assert split in ['train', 'test'], "Split must be either 'train' or 'test'."
# Validate that the disasters are valid

if (
id_ood_disaster[0]['disaster_name'] not in self.valid_disasters
or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters
):
raise ValueError(
f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}"
)

self.root = root
self.split = split
self.transforms = transforms
self.checksum = checksum

self._verify()

# Load all files and compute basenames and disasters only once
self.all_files = self._initialize_files(root)

# Split logic by disaster and pre-post type
self.files = self._load_split_files_by_disaster_and_type(
self.all_files, id_ood_disaster[0], id_ood_disaster[1]
)

train_size, test_size = self.get_id_ood_sizes()
print(f"ID sample len: {train_size}, OOD sample len: {test_size}")

def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
"""Get an item from the dataset at the given index."""
file_info = (
self.files['train'][index]
if self.split == 'train'
else self.files['test'][index]
)

image = self._load_image(file_info['image'])
mask = self._load_target(file_info['mask']).long()

# Reformulate as building segmentation task
mask[mask == 2] = 1 # Map damage class 2 to 1
mask[(mask == 3) | (mask == 4)] = 0 # Map 3 and 4 damage classes to background

sample = {'image': image, 'mask': mask}

if self.transforms:
sample = self.transforms(sample)

return sample

def __len__(self) -> int:
"""Return the total number of samples in the dataset."""
return (
len(self.files['train'])
if self.split == 'train'
else len(self.files['test'])
)


def get_id_ood_sizes(self) -> tuple[int, int]:
"""Return the number of samples in the train and test splits."""
return (len(self.files['train']), len(self.files['test']))


    def _initialize_files(self, root: str) -> list[dict[str, str]]:
        """Initialize the dataset by loading file paths and computing basenames with sample numbers.

        Scans the 'images' directory of every split listed in ``self.metadata``
        and records, per image, its path, the expected target-mask path, and a
        ``<disaster>_<sample>`` basename.

        Args:
            root: Root directory of the extracted xView2 dataset.

        Returns:
            One dict per image with keys 'image', 'mask' and 'basename'.
        """
        all_files = []
        for split in self.metadata.keys():
            image_root = os.path.join(root, split, 'images')
            mask_root = os.path.join(root, split, 'targets')
            images = glob.glob(os.path.join(image_root, '*.png'))

            # Extract basenames while preserving the disaster-name and sample number
            for img in images:
                basename_parts = os.path.basename(img).split('_')
                event_name = basename_parts[0]  # e.g., mexico-earthquake
                sample_number = basename_parts[1]  # e.g., 00000001
                basename = (
                    f'{event_name}_{sample_number}'  # e.g., mexico-earthquake_00000001
                )

                file_info = {
                    'image': img,
                    # NOTE(review): the mask path always points at the
                    # *pre*-disaster target, even for post-disaster images —
                    # confirm this is intended, since __getitem__ remaps damage
                    # classes (2-4) that appear only in post-disaster targets.
                    'mask': os.path.join(
                        mask_root, f'{basename}_pre_disaster_target.png'
                    ),
                    'basename': basename,
                }
                all_files.append(file_info)
        return all_files

def _load_split_files_by_disaster_and_type(
self,
files: list[dict[str, str]],
id_disaster: dict[str, str],
ood_disaster: dict[str, str],
) -> dict[str, list[dict[str, str]]]:
"""Return the filepaths for the train (ID) and test (OOD) sets based on disaster name and pre-post disaster type.

Args:
files: List of file paths with their corresponding information.
id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}).
ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}).

Returns:
A dictionary containing 'train' (ID) and 'test' (OOD) file lists.
"""
train_files = []
test_files = []
disaster_list = []

for file_info in files:
basename = file_info['basename']
disaster_name = basename.split('_')[
0
] # Extract disaster name from basename
pre_post = (
'pre' if 'pre_disaster' in file_info['image'] else 'post'
) # Identify pre/post type

disaster_list.append(disaster_name)

# Filter for in-distribution (ID) training set
if disaster_name == id_disaster['disaster_name']:
if (
id_disaster.get('pre-post') == 'both'
or id_disaster['pre-post'] == pre_post
):
image = (
file_info['image'].replace('post_disaster', 'pre_disaster')
if pre_post == 'pre'
else file_info['image']
)
mask = (
file_info['mask'].replace('post_disaster', 'pre_disaster')
if pre_post == 'pre'
else file_info['mask']
)
train_files.append(dict(image=image, mask=mask))

# Filter for out-of-distribution (OOD) test set
if disaster_name == ood_disaster['disaster_name']:
if (
ood_disaster.get('pre-post') == 'both'
or ood_disaster['pre-post'] == pre_post
):
test_files.append(file_info)

return {'train': train_files, 'test': test_files, 'disasters': disaster_list}
Loading