diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index d01a91dfe70..28176dd04c4 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -527,6 +527,7 @@ xView2
 ^^^^^^
 
 .. autoclass:: XView2
+.. autoclass:: XView2DistShift
 
 ZueriCrop
 ^^^^^^^^^
diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py
index c54b597fadf..e873cd567d5 100644
--- a/tests/datasets/test_xview.py
+++ b/tests/datasets/test_xview.py
@@ -12,7 +12,7 @@
 from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import DatasetNotFoundError, XView2
+from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift
 
 
 class TestXView2:
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 8177120c2a7..f1f0e9da07e 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -148,7 +148,7 @@ from .vaihingen import Vaihingen2D
 from .vhr10 import VHR10
 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
-from .xview import XView2
+from .xview import XView2, XView2DistShift
 from .zuericrop import ZueriCrop
 
 __all__ = (
@@ -292,6 +292,7 @@
     'VectorDataset',
     'WesternUSALiveFuelMoisture',
     'XView2',
+    'XView2DistShift',
     'ZueriCrop',
     'concat_samples',
     'merge_samples',
diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index a7f6a36456a..f6ba2092792 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -138,7 +138,7 @@ def _load_files(self, root: Path, split: str) -> list[dict[str, str]]:
 
         Args:
             root: root dir of dataset
-            split: subset of dataset, one of [train, test]
+            split: subset of dataset, one of ['train', 'test']
 
         Returns:
             list of dicts containing paths for each pair of images and masks
@@ -282,3 +282,208 @@ def plot(
             plt.suptitle(suptitle)
 
         return fig
+
+
+class XView2DistShift(XView2):
+    """A variant of the XView2 dataset that rebuilds the train/test splits by disaster.
+
+    This class lets the user choose one disaster as the training set (in-distribution)
+    and another as the test set (out-of-distribution). Splitting by disaster name
+    allows a model to be trained on one disaster type and evaluated on a different,
+    unseen one, in order to measure how well it generalizes across disasters.
+ """ + + classes = ['background', 'building'] + + # List of disaster names + valid_disasters = [ + 'hurricane-harvey', + 'socal-fire', + 'hurricane-matthew', + 'mexico-earthquake', + 'guatemala-volcano', + 'santa-rosa-wildfire', + 'palu-tsunami', + 'hurricane-florence', + 'hurricane-michael', + 'midwest-flooding', + ] + + def __init__( + self, + root: str = 'data', + split: str = 'train', + id_ood_disaster: list[dict[str, str]] = [ + {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, + {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, + ], + transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + checksum: bool = False, + ) -> None: + """Initialize the XView2DistShift dataset instance. + + Args: + root: Root directory where the dataset is located. + split: One of "train" or "test". + id_ood_disaster: List containing in-distribution and out-of-distribution disaster names. + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. + DatasetNotFoundError: If dataset is not found. + """ + assert split in ['train', 'test'], "Split must be either 'train' or 'test'." + # Validate that the disasters are valid + + if ( + id_ood_disaster[0]['disaster_name'] not in self.valid_disasters + or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters + ): + raise ValueError( + f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}" + ) + + self.root = root + self.split = split + self.transforms = transforms + self.checksum = checksum + + self._verify() + + # Load all files and compute basenames and disasters only once + self.all_files = self._initialize_files(root) + + # Split logic by disaster and pre-post type + self.files = self._load_split_files_by_disaster_and_type( + self.all_files, id_ood_disaster[0], id_ood_disaster[1] + ) + + train_size, test_size = self.get_id_ood_sizes() + print(f"ID sample len: {train_size}, OOD sample len: {test_size}") + + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: + """Get an item from the dataset at the given index.""" + file_info = ( + self.files['train'][index] + if self.split == 'train' + else self.files['test'][index] + ) + + image = self._load_image(file_info['image']) + mask = self._load_target(file_info['mask']).long() + + # Reformulate as building segmentation task + mask[mask == 2] = 1 # Map damage class 2 to 1 + mask[(mask == 3) | (mask == 4)] = 0 # Map 3 and 4 damage classes to background + + sample = {'image': image, 'mask': mask} + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the total number of samples in the dataset.""" + return ( + len(self.files['train']) + if self.split == 'train' + else len(self.files['test']) + ) + + + def get_id_ood_sizes(self) -> tuple[int, int]: + """Return the number of samples in the train and test splits.""" + return (len(self.files['train']), len(self.files['test'])) + + + def _initialize_files(self, root: str) -> list[dict[str, str]]: + """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" + all_files = [] + for split in self.metadata.keys(): + image_root = os.path.join(root, split, 'images') + mask_root = os.path.join(root, split, 'targets') + images 
+
+            # Extract the disaster name, sample number, and pre/post type from
+            # filenames such as 'mexico-earthquake_00000001_pre_disaster.png'
+            for img in images:
+                basename_parts = os.path.basename(img).split('_')
+                event_name = basename_parts[0]  # e.g. mexico-earthquake
+                sample_number = basename_parts[1]  # e.g. 00000001
+                pre_post = basename_parts[2]  # 'pre' or 'post'
+                basename = f'{event_name}_{sample_number}'
+
+                file_info = {
+                    'image': img,
+                    # Pair each image with the target of the same pre/post type
+                    'mask': os.path.join(
+                        mask_root, f'{basename}_{pre_post}_disaster_target.png'
+                    ),
+                    'basename': basename,
+                }
+                all_files.append(file_info)
+        return all_files
+
+    def _load_split_files_by_disaster_and_type(
+        self,
+        files: list[dict[str, str]],
+        id_disaster: dict[str, str],
+        ood_disaster: dict[str, str],
+    ) -> dict[str, list[dict[str, str]]]:
+        """Split the file list into a train (ID) and a test (OOD) set.
+
+        Args:
+            files: list of file dicts as produced by :meth:`_initialize_files`
+            id_disaster: dict selecting the in-distribution (ID) disaster and type,
+                e.g. {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}
+            ood_disaster: dict selecting the out-of-distribution (OOD) disaster and
+                type, e.g. {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}
+
+        Returns:
+            a dictionary containing the 'train' (ID) and 'test' (OOD) file lists
+        """
+        train_files = []
+        test_files = []
+
+        for file_info in files:
+            # The disaster name is the first part of the basename
+            disaster_name = file_info['basename'].split('_')[0]
+            # Identify whether the sample is pre- or post-disaster
+            pre_post = 'pre' if 'pre_disaster' in file_info['image'] else 'post'
+
+            # Collect the in-distribution (ID) training set
+            if disaster_name == id_disaster['disaster_name']:
+                if id_disaster.get('pre-post') in ('both', pre_post):
+                    train_files.append(file_info)
+
+            # Collect the out-of-distribution (OOD) test set
+            if disaster_name == ood_disaster['disaster_name']:
+                if ood_disaster.get('pre-post') in ('both', pre_post):
+                    test_files.append(file_info)
+
+        return {'train': train_files, 'test': test_files}
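
Usage sketch (illustrative, not part of the patch): assuming the xBD archives are
extracted under `data/train` and `data/test` in their usual layout, the new class
can be driven like this to train on one disaster and evaluate on another:

    from torchgeo.datasets import XView2DistShift

    # First entry selects the in-distribution (train) disaster, second the
    # out-of-distribution (test) one; 'pre-post' picks which imagery is used.
    splits = (
        {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
        {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
    )
    train_ds = XView2DistShift(root='data', split='train', id_ood_disaster=splits)
    test_ds = XView2DistShift(root='data', split='test', id_ood_disaster=splits)

    print(train_ds.get_id_ood_sizes())  # (ID sample count, OOD sample count)
    sample = train_ds[0]  # {'image': ..., 'mask': ...}, mask values in {0, 1}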