From da2399b47997044d9956555bf229cdd453bc493d Mon Sep 17 00:00:00 2001 From: Burak Date: Tue, 9 Apr 2024 15:15:00 +0200 Subject: [PATCH 1/8] minor typo in custom_raster_dataset.ipynb --- docs/tutorials/custom_raster_dataset.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index 7401e580edb..e4da8499114 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -345,7 +345,7 @@ "\n", "### `rgb_bands`\n", "\n", - "If your data is a multispectral iamge, you can define which bands correspond to the red, green, and blue channels. In the case of Sentinel-2, this corresponds to B04, B03, and B02, in that order.\n", + "If your data is a multispectral image, you can define which bands correspond to the red, green, and blue channels. In the case of Sentinel-2, this corresponds to B04, B03, and B02, in that order.\n", "\n", "Putting this all together into a single class, we get:" ] From 0f57ecf9d88fea403701123a159f3e1bd8707ebb Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Mon, 18 Nov 2024 15:43:34 +0100 Subject: [PATCH 2/8] xview2 dist shift initial commit --- docs/api/datasets.rst | 1 + torchgeo/datasets/__init__.py | 3 +- torchgeo/datasets/xview.py | 169 ++++++++++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 20ce1bfcbac..8b5149ade43 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -462,6 +462,7 @@ xView2 ^^^^^^ .. autoclass:: XView2 +.. autoclass:: XView2DistShift ZueriCrop ^^^^^^^^^ diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 6a15fabdf76..e1daabf3ed9 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -136,7 +136,7 @@ from .vaihingen import Vaihingen2D from .vhr10 import VHR10 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture -from .xview import XView2 +from .xview import XView2, XView2DistShift from .zuericrop import ZueriCrop __all__ = ( @@ -258,6 +258,7 @@ 'VHR10', 'WesternUSALiveFuelMoisture', 'XView2', + 'XView2DistShift' 'ZueriCrop', # Base classes 'GeoDataset', diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 5716c06f593..9854d18458f 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -270,3 +270,172 @@ def plot( plt.suptitle(suptitle) return fig + + +class XView2DistShift(XView2): + """ + A subclass of the XView2 dataset designed to reformat the original train/test splits + based on specific in-domain (ID) and out-of-domain (OOD) disasters. + + This class allows for the selection of particular disasters to be used as the + training set (in-domain) and test set (out-of-domain). The dataset can be split + according to the disaster names specified by the user, enabling the model to train + on one disaster type and evaluate on a different, out-of-domain disaster. The goal + is to test the generalization ability of models trained on one disaster to perform + on others. + """ + + classes = ["background", "building"] + + # List of possible disaster names + valid_disasters = [ + 'hurricane-harvey', 'socal-fire', 'hurricane-matthew', 'mexico-earthquake', + 'guatemala-volcano', 'santa-rosa-wildfire', 'palu-tsunami', 'hurricane-florence', + 'hurricane-michael', 'midwest-flooding' + ] + + def __init__( + self, + root: str = "data", + split: str = "train", + id_ood_disaster: list[dict[str, str]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + checksum: bool = False, + ) -> None: + """Initialize the XView2DistShift dataset instance. + + Args: + root: Root directory where the dataset is located. + split: One of "train" or "test". + id_ood_disaster: List containing in-distribution and out-of-distribution disaster names. + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. + DatasetNotFoundError: If dataset is not found. + """ + assert split in ["train", "test"], "Split must be either 'train' or 'test'." + # Validate that the disasters are valid + + if id_ood_disaster[0]['disaster_name'] not in self.valid_disasters or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters: + raise ValueError(f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}") + + self.root = root + self.split = split + self.transforms = transforms + self.checksum = checksum + + self._verify() + + # Load all files and compute basenames and disasters only once + self.all_files = self._initialize_files(root) + + # Split logic by disaster and pre-post type + self.files = self._load_split_files_by_disaster_and_type(self.all_files, id_ood_disaster[0], id_ood_disaster[1]) + print(f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files.") + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get an item from the dataset at the given index.""" + file_info = ( + self.files["train"][index] + if self.split == "train" + else self.files["test"][index]) + + image = self._load_image(file_info["image"]).to("cuda") + mask = self._load_target(file_info["mask"]).long().to("cuda") + mask[mask == 2] = 1 + mask[(mask == 3) | (mask == 4)] = 0 + + sample = {"image": image, "mask": mask} + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the total number of samples in the dataset.""" + return ( + len(self.files["train"]) + if self.split == "train" + else len(self.files["test"]) + ) + + def _initialize_files(self, root: Path) -> List[Dict[str, str]]: + """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" + all_files = [] + for split in self.metadata.keys(): + image_root = os.path.join(root, split, "images") + mask_root = os.path.join(root, split, "targets") + images = glob.glob(os.path.join(image_root, "*.png")) + + # Extract basenames while preserving the event-name and sample number + for img in images: + basename_parts = os.path.basename(img).split("_") + event_name = basename_parts[0] # e.g., mexico-earthquake + sample_number = basename_parts[1] # e.g., 00000001 + basename = ( + f"{event_name}_{sample_number}" # e.g., mexico-earthquake_00000001 + ) + + + file_info = { + "image": img, + "mask": os.path.join( + mask_root, f"{basename}_pre_disaster_target.png" + ), + "basename": basename, + } + all_files.append(file_info) + return all_files + + def _load_split_files_by_disaster_and_type( + self, files: List[Dict[str, str]], id_disaster: Dict[str, str], ood_disaster: Dict[str, str] + ) -> Dict[str, List[Dict[str, str]]]: + """ + Return the paths of the files for the train (ID) and test (OOD) sets based on the specified disaster name + and pre-post disaster type. + + Args: + files: List of file paths with their corresponding information. + id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {"disaster_name": "guatemala-volcano", "pre-post": "pre"}). + ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {"disaster_name": "mexico-earthquake", "pre-post": "post"}). + + Returns: + A dictionary containing 'train' (ID) and 'test' (OOD) file lists. + """ + train_files = [] + test_files = [] + disaster_list = [] + + for file_info in files: + basename = file_info["basename"] + disaster_name = basename.split("_")[0] # Extract disaster name from basename + pre_post = ("pre" if "pre_disaster" in file_info["image"] else "post") # Identify pre/post type + + disaster_list.append(disaster_name) + + # Filter for in-domain (ID) training set + if disaster_name == id_disaster["disaster_name"]: + if id_disaster.get("pre-post") == "both" or id_disaster["pre-post"] == pre_post: + image = ( + file_info["image"].replace("post_disaster", "pre_disaster") + if pre_post == "pre" + else file_info["image"] + ) + mask = ( + file_info["mask"].replace("post_disaster", "pre_disaster") + if pre_post == "pre" + else file_info["mask"] + ) + train_files.append(dict(image=image, mask=mask)) + + # Filter for out-of-domain (OOD) test set + if disaster_name == ood_disaster["disaster_name"]: + if ood_disaster.get("pre-post") == "both" or ood_disaster["pre-post"] == pre_post: + test_files.append(file_info) + + return {"train": train_files, "test": test_files, "disasters":disaster_list} \ No newline at end of file From 62919bfcf08f73d74ce9553d617319b57437fc49 Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Mon, 18 Nov 2024 15:46:50 +0100 Subject: [PATCH 3/8] xview2distshift dataset --- docs/api/datasets.rst | 1 + torchgeo/datasets/__init__.py | 3 +- torchgeo/datasets/xview.py | 169 ++++++++++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 20ce1bfcbac..8b5149ade43 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -462,6 +462,7 @@ xView2 ^^^^^^ .. autoclass:: XView2 +.. autoclass:: XView2DistShift ZueriCrop ^^^^^^^^^ diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 6a15fabdf76..f84760ce865 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -136,7 +136,7 @@ from .vaihingen import Vaihingen2D from .vhr10 import VHR10 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture -from .xview import XView2 +from .xview import XView2, XView2DistShift from .zuericrop import ZueriCrop __all__ = ( @@ -258,6 +258,7 @@ 'VHR10', 'WesternUSALiveFuelMoisture', 'XView2', + 'XView2DistShift', 'ZueriCrop', # Base classes 'GeoDataset', diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 5716c06f593..98fe3ceb534 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -270,3 +270,172 @@ def plot( plt.suptitle(suptitle) return fig + +class XView2DistShift(XView2): + """ + A subclass of the XView2 dataset designed to reformat the original train/test splits + based on specific in-domain (ID) and out-of-domain (OOD) disasters. + + This class allows for the selection of particular disasters to be used as the + training set (in-domain) and test set (out-of-domain). The dataset can be split + according to the disaster names specified by the user, enabling the model to train + on one disaster type and evaluate on a different, out-of-domain disaster. The goal + is to test the generalization ability of models trained on one disaster to perform + on others. + """ + + classes = ["background", "building"] + + # List of disaster names + valid_disasters = [ + 'hurricane-harvey', 'socal-fire', 'hurricane-matthew', 'mexico-earthquake', + 'guatemala-volcano', 'santa-rosa-wildfire', 'palu-tsunami', 'hurricane-florence', + 'hurricane-michael', 'midwest-flooding' + ] + + def __init__( + self, + root: Path = "data", + split: str = "train", + id_ood_disaster: List[Dict[str, str]] = [{"disaster_name": "hurricane-matthew", "pre-post": "post"}, {"disaster_name": "mexico-earthquake", "pre-post": "post"}], + transforms: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None, + checksum: bool = False, + **kwargs + ) -> None: + """Initialize the XView2DistShift dataset instance. + + Args: + root: Root directory where the dataset is located. + split: One of "train" or "test". + id_ood_disaster: List containing in-distribution and out-of-distribution disaster names. + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. + DatasetNotFoundError: If dataset is not found. + """ + assert split in ["train", "test"], "Split must be either 'train' or 'test'." + # Validate that the disasters are valid + + if id_ood_disaster[0]['disaster_name'] not in self.valid_disasters or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters: + raise ValueError(f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}") + + self.root = root + self.split = split + self.transforms = transforms + self.checksum = checksum + + self._verify() + + # Load all files and compute basenames and disasters only once + self.all_files = self._initialize_files(root) + + # Split logic by disaster and pre-post type + self.files = self._load_split_files_by_disaster_and_type(self.all_files, id_ood_disaster[0], id_ood_disaster[1]) + print(f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files.") + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get an item from the dataset at the given index.""" + file_info = ( + self.files["train"][index] + if self.split == "train" + else self.files["test"][index]) + + image = self._load_image(file_info["image"]).to("cuda") + mask = self._load_target(file_info["mask"]).long().to("cuda") + mask[mask == 2] = 1 + mask[(mask == 3) | (mask == 4)] = 0 + + sample = {"image": image, "mask": mask} + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the total number of samples in the dataset.""" + return ( + len(self.files["train"]) + if self.split == "train" + else len(self.files["test"]) + ) + + def _initialize_files(self, root: Path) -> List[Dict[str, str]]: + """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" + all_files = [] + for split in self.metadata.keys(): + image_root = os.path.join(root, split, "images") + mask_root = os.path.join(root, split, "targets") + images = glob.glob(os.path.join(image_root, "*.png")) + + # Extract basenames while preserving the event-name and sample number + for img in images: + basename_parts = os.path.basename(img).split("_") + event_name = basename_parts[0] # e.g., mexico-earthquake + sample_number = basename_parts[1] # e.g., 00000001 + basename = ( + f"{event_name}_{sample_number}" # e.g., mexico-earthquake_00000001 + ) + + + file_info = { + "image": img, + "mask": os.path.join( + mask_root, f"{basename}_pre_disaster_target.png" + ), + "basename": basename, + } + all_files.append(file_info) + return all_files + + def _load_split_files_by_disaster_and_type( + self, files: List[Dict[str, str]], id_disaster: Dict[str, str], ood_disaster: Dict[str, str] + ) -> Dict[str, List[Dict[str, str]]]: + """ + Return the paths of the files for the train (ID) and test (OOD) sets based on the specified disaster name + and pre-post disaster type. + + Args: + files: List of file paths with their corresponding information. + id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {"disaster_name": "guatemala-volcano", "pre-post": "pre"}). + ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {"disaster_name": "mexico-earthquake", "pre-post": "post"}). + + Returns: + A dictionary containing 'train' (ID) and 'test' (OOD) file lists. + """ + train_files = [] + test_files = [] + disaster_list = [] + + for file_info in files: + basename = file_info["basename"] + disaster_name = basename.split("_")[0] # Extract disaster name from basename + pre_post = ("pre" if "pre_disaster" in file_info["image"] else "post") # Identify pre/post type + + disaster_list.append(disaster_name) + + # Filter for in-domain (ID) training set + if disaster_name == id_disaster["disaster_name"]: + if id_disaster.get("pre-post") == "both" or id_disaster["pre-post"] == pre_post: + image = ( + file_info["image"].replace("post_disaster", "pre_disaster") + if pre_post == "pre" + else file_info["image"] + ) + mask = ( + file_info["mask"].replace("post_disaster", "pre_disaster") + if pre_post == "pre" + else file_info["mask"] + ) + train_files.append(dict(image=image, mask=mask)) + + # Filter for out-of-domain (OOD) test set + if disaster_name == ood_disaster["disaster_name"]: + if ood_disaster.get("pre-post") == "both" or ood_disaster["pre-post"] == pre_post: + test_files.append(file_info) + + return {"train": train_files, "test": test_files, "disasters":disaster_list} \ No newline at end of file From 5985f44c451b99329e8e62b50a90f7c2fb2f724e Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:00:56 +0100 Subject: [PATCH 4/8] test xview2 --- tests/datasets/test_xview2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py index 7689acf5f78..35e02e27c28 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview2.py @@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets import DatasetNotFoundError, XView2 +from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift class TestXView2: From a23344e69567d6f435e8f95018a6716ae0fe5aa4 Mon Sep 17 00:00:00 2001 From: Burak Date: Mon, 18 Nov 2024 17:03:23 +0100 Subject: [PATCH 5/8] formatting --- tests/datasets/test_xview2.py | 2 +- torchgeo/datasets/xview.py | 211 +++++++++++++++++++--------------- 2 files changed, 121 insertions(+), 92 deletions(-) diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py index 35e02e27c28..dc8774d0933 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview2.py @@ -14,7 +14,6 @@ from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift - class TestXView2: @pytest.fixture(params=['train', 'test']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: @@ -27,6 +26,7 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: 'md5': '373e61d55c1b294aa76b94dbbd81332b', 'directory': 'train', }, + 'test': { 'filename': 'test_images_labels_targets.tar.gz', 'md5': 'bc6de81c956a3bada38b5b4e246266a1', diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 98fe3ceb534..bc9ce8d5992 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -19,6 +19,7 @@ from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive + class XView2(NonGeoDataset): """xView2 dataset. @@ -50,24 +51,24 @@ class XView2(NonGeoDataset): """ metadata = { - 'train': { - 'filename': 'train_images_labels_targets.tar.gz', - 'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16', - 'directory': 'train', + "train": { + "filename": "train_images_labels_targets.tar.gz", + "md5": "a20ebbfb7eb3452785b63ad02ffd1e16", + "directory": "train", }, - 'test': { - 'filename': 'test_images_labels_targets.tar.gz', - 'md5': '1b39c47e05d1319c17cc8763cee6fe0c', - 'directory': 'test', + "test": { + "filename": "test_images_labels_targets.tar.gz", + "md5": "1b39c47e05d1319c17cc8763cee6fe0c", + "directory": "test", }, } - classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed'] - colormap = ['green', 'blue', 'orange', 'red'] + classes = ["background", "no-damage", "minor-damage", "major-damage", "destroyed"] + colormap = ["green", "blue", "orange", "red"] def __init__( self, - root: str = 'data', - split: str = 'train', + root: str = "data", + split: str = "train", transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -105,14 +106,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image1 = self._load_image(files['image1']) - image2 = self._load_image(files['image2']) - mask1 = self._load_target(files['mask1']) - mask2 = self._load_target(files['mask2']) + image1 = self._load_image(files["image1"]) + image2 = self._load_image(files["image2"]) + mask1 = self._load_target(files["mask1"]) + mask2 = self._load_target(files["mask2"]) image = torch.stack(tensors=[image1, image2], dim=0) mask = torch.stack(tensors=[mask1, mask2], dim=0) - sample = {'image': image, 'mask': mask} + sample = {"image": image, "mask": mask} if self.transforms is not None: sample = self.transforms(sample) @@ -138,17 +139,17 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: list of dicts containing paths for each pair of images and masks """ files = [] - directory = self.metadata[split]['directory'] - image_root = os.path.join(root, directory, 'images') - mask_root = os.path.join(root, directory, 'targets') - images = glob.glob(os.path.join(image_root, '*.png')) + directory = self.metadata[split]["directory"] + image_root = os.path.join(root, directory, "images") + mask_root = os.path.join(root, directory, "targets") + images = glob.glob(os.path.join(image_root, "*.png")) basenames = [os.path.basename(f) for f in images] - basenames = ['_'.join(f.split('_')[:-2]) for f in basenames] + basenames = ["_".join(f.split("_")[:-2]) for f in basenames] for name in sorted(set(basenames)): - image1 = os.path.join(image_root, f'{name}_pre_disaster.png') - image2 = os.path.join(image_root, f'{name}_post_disaster.png') - mask1 = os.path.join(mask_root, f'{name}_pre_disaster_target.png') - mask2 = os.path.join(mask_root, f'{name}_post_disaster_target.png') + image1 = os.path.join(image_root, f"{name}_pre_disaster.png") + image2 = os.path.join(image_root, f"{name}_post_disaster.png") + mask1 = os.path.join(mask_root, f"{name}_pre_disaster_target.png") + mask2 = os.path.join(mask_root, f"{name}_post_disaster_target.png") files.append(dict(image1=image1, image2=image2, mask1=mask1, mask2=mask2)) return files @@ -163,7 +164,7 @@ def _load_image(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB')) + array: np.typing.NDArray[np.int_] = np.array(img.convert("RGB")) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -180,7 +181,7 @@ def _load_target(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: np.typing.NDArray[np.int_] = np.array(img.convert('L')) + array: np.typing.NDArray[np.int_] = np.array(img.convert("L")) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) return tensor @@ -190,10 +191,10 @@ def _verify(self) -> None: # Check if the files already exist exists = [] for split_info in self.metadata.values(): - for directory in ['images', 'targets']: + for directory in ["images", "targets"]: exists.append( os.path.exists( - os.path.join(self.root, split_info['directory'], directory) + os.path.join(self.root, split_info["directory"], directory) ) ) @@ -203,10 +204,10 @@ def _verify(self) -> None: # Check if .tar.gz files already exists (if so then extract) exists = [] for split_info in self.metadata.values(): - filepath = os.path.join(self.root, split_info['filename']) + filepath = os.path.join(self.root, split_info["filename"]) if os.path.isfile(filepath): - if self.checksum and not check_integrity(filepath, split_info['md5']): - raise RuntimeError('Dataset found, but corrupted.') + if self.checksum and not check_integrity(filepath, split_info["md5"]): + raise RuntimeError("Dataset found, but corrupted.") exists.append(True) extract_archive(filepath) else: @@ -237,70 +238,78 @@ def plot( """ ncols = 2 image1 = draw_semantic_segmentation_masks( - sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap + sample["image"][0], sample["mask"][0], alpha=alpha, colors=self.colormap ) image2 = draw_semantic_segmentation_masks( - sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap + sample["image"][1], sample["mask"][1], alpha=alpha, colors=self.colormap ) - if 'prediction' in sample: # NOTE: this assumes predictions are made for post + if "prediction" in sample: # NOTE: this assumes predictions are made for post ncols += 1 image3 = draw_semantic_segmentation_masks( - sample['image'][1], - sample['prediction'], + sample["image"][1], + sample["prediction"], alpha=alpha, colors=self.colormap, ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) - axs[0].axis('off') + axs[0].axis("off") axs[1].imshow(image2) - axs[1].axis('off') + axs[1].axis("off") if ncols > 2: axs[2].imshow(image3) - axs[2].axis('off') + axs[2].axis("off") if show_titles: - axs[0].set_title('Pre disaster') - axs[1].set_title('Post disaster') + axs[0].set_title("Pre disaster") + axs[1].set_title("Post disaster") if ncols > 2: - axs[2].set_title('Predictions') + axs[2].set_title("Predictions") if suptitle is not None: plt.suptitle(suptitle) return fig - + + class XView2DistShift(XView2): - """ - A subclass of the XView2 dataset designed to reformat the original train/test splits - based on specific in-domain (ID) and out-of-domain (OOD) disasters. - - This class allows for the selection of particular disasters to be used as the - training set (in-domain) and test set (out-of-domain). The dataset can be split - according to the disaster names specified by the user, enabling the model to train - on one disaster type and evaluate on a different, out-of-domain disaster. The goal - is to test the generalization ability of models trained on one disaster to perform + """A subclass of the XView2 dataset designed to reformat the original train/test splits. + + This class allows for the selection of particular disasters to be used as the + training set (in-domain) and test set (out-of-domain). The dataset can be split + according to the disaster names specified by the user, enabling the model to train + on one disaster type and evaluate on a different, out-of-domain disaster. The goal + is to test the generalization ability of models trained on one disaster to perform on others. """ - + classes = ["background", "building"] - + # List of disaster names valid_disasters = [ - 'hurricane-harvey', 'socal-fire', 'hurricane-matthew', 'mexico-earthquake', - 'guatemala-volcano', 'santa-rosa-wildfire', 'palu-tsunami', 'hurricane-florence', - 'hurricane-michael', 'midwest-flooding' + "hurricane-harvey", + "socal-fire", + "hurricane-matthew", + "mexico-earthquake", + "guatemala-volcano", + "santa-rosa-wildfire", + "palu-tsunami", + "hurricane-florence", + "hurricane-michael", + "midwest-flooding", ] - + def __init__( self, - root: Path = "data", + root: str = "data", split: str = "train", - id_ood_disaster: List[Dict[str, str]] = [{"disaster_name": "hurricane-matthew", "pre-post": "post"}, {"disaster_name": "mexico-earthquake", "pre-post": "post"}], - transforms: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None, + id_ood_disaster: list[dict[str, str]] = [ + {"disaster_name": "hurricane-matthew", "pre-post": "post"}, + {"disaster_name": "mexico-earthquake", "pre-post": "post"}, + ], + transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, checksum: bool = False, - **kwargs ) -> None: """Initialize the XView2DistShift dataset instance. @@ -311,7 +320,7 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version checksum: if True, check the MD5 of the downloaded files (may be slow) - + Raises: AssertionError: If *split* is invalid. ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. @@ -319,9 +328,14 @@ def __init__( """ assert split in ["train", "test"], "Split must be either 'train' or 'test'." # Validate that the disasters are valid - - if id_ood_disaster[0]['disaster_name'] not in self.valid_disasters or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters: - raise ValueError(f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}") + + if ( + id_ood_disaster[0]["disaster_name"] not in self.valid_disasters + or id_ood_disaster[1]["disaster_name"] not in self.valid_disasters + ): + raise ValueError( + f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}" + ) self.root = root self.split = split @@ -332,23 +346,28 @@ def __init__( # Load all files and compute basenames and disasters only once self.all_files = self._initialize_files(root) - + # Split logic by disaster and pre-post type - self.files = self._load_split_files_by_disaster_and_type(self.all_files, id_ood_disaster[0], id_ood_disaster[1]) - print(f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files.") + self.files = self._load_split_files_by_disaster_and_type( + self.all_files, id_ood_disaster[0], id_ood_disaster[1] + ) + print( + f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files." + ) - def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: """Get an item from the dataset at the given index.""" file_info = ( self.files["train"][index] if self.split == "train" - else self.files["test"][index]) + else self.files["test"][index] + ) image = self._load_image(file_info["image"]).to("cuda") mask = self._load_target(file_info["mask"]).long().to("cuda") mask[mask == 2] = 1 mask[(mask == 3) | (mask == 4)] = 0 - + sample = {"image": image, "mask": mask} if self.transforms: @@ -364,14 +383,14 @@ def __len__(self) -> int: else len(self.files["test"]) ) - def _initialize_files(self, root: Path) -> List[Dict[str, str]]: + def _initialize_files(self, root: str) -> list[dict[str, str]]: """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" all_files = [] for split in self.metadata.keys(): image_root = os.path.join(root, split, "images") mask_root = os.path.join(root, split, "targets") images = glob.glob(os.path.join(image_root, "*.png")) - + # Extract basenames while preserving the event-name and sample number for img in images: basename_parts = os.path.basename(img).split("_") @@ -381,7 +400,6 @@ def _initialize_files(self, root: Path) -> List[Dict[str, str]]: f"{event_name}_{sample_number}" # e.g., mexico-earthquake_00000001 ) - file_info = { "image": img, "mask": os.path.join( @@ -393,11 +411,12 @@ def _initialize_files(self, root: Path) -> List[Dict[str, str]]: return all_files def _load_split_files_by_disaster_and_type( - self, files: List[Dict[str, str]], id_disaster: Dict[str, str], ood_disaster: Dict[str, str] - ) -> Dict[str, List[Dict[str, str]]]: - """ - Return the paths of the files for the train (ID) and test (OOD) sets based on the specified disaster name - and pre-post disaster type. + self, + files: list[dict[str, str]], + id_disaster: dict[str, str], + ood_disaster: dict[str, str], + ) -> dict[str, list[dict[str, str]]]: + """Return the filepaths for the train (ID) and test (OOD) sets based on disaster name and pre-post disaster type. Args: files: List of file paths with their corresponding information. @@ -410,17 +429,24 @@ def _load_split_files_by_disaster_and_type( train_files = [] test_files = [] disaster_list = [] - + for file_info in files: basename = file_info["basename"] - disaster_name = basename.split("_")[0] # Extract disaster name from basename - pre_post = ("pre" if "pre_disaster" in file_info["image"] else "post") # Identify pre/post type - - disaster_list.append(disaster_name) - + disaster_name = basename.split("_")[ + 0 + ] # Extract disaster name from basename + pre_post = ( + "pre" if "pre_disaster" in file_info["image"] else "post" + ) # Identify pre/post type + + disaster_list.append(disaster_name) + # Filter for in-domain (ID) training set if disaster_name == id_disaster["disaster_name"]: - if id_disaster.get("pre-post") == "both" or id_disaster["pre-post"] == pre_post: + if ( + id_disaster.get("pre-post") == "both" + or id_disaster["pre-post"] == pre_post + ): image = ( file_info["image"].replace("post_disaster", "pre_disaster") if pre_post == "pre" @@ -435,7 +461,10 @@ def _load_split_files_by_disaster_and_type( # Filter for out-of-domain (OOD) test set if disaster_name == ood_disaster["disaster_name"]: - if ood_disaster.get("pre-post") == "both" or ood_disaster["pre-post"] == pre_post: + if ( + ood_disaster.get("pre-post") == "both" + or ood_disaster["pre-post"] == pre_post + ): test_files.append(file_info) - return {"train": train_files, "test": test_files, "disasters":disaster_list} \ No newline at end of file + return {"train": train_files, "test": test_files, "disasters": disaster_list} From 8239d32f80dd882f80f50553fd4638ef9d1c4bae Mon Sep 17 00:00:00 2001 From: Burak Date: Sat, 1 Feb 2025 16:41:18 +0100 Subject: [PATCH 6/8] " to ' --- torchgeo/datasets/xview.py | 178 ++++++++++++++++++------------------- 1 file changed, 89 insertions(+), 89 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 222790248cb..976ca219aad 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -61,10 +61,10 @@ class XView2(NonGeoDataset): 'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16', 'directory': 'train', }, - "test": { - "filename": "test_images_labels_targets.tar.gz", - "md5": "1b39c47e05d1319c17cc8763cee6fe0c", - "directory": "test", + 'test': { + 'filename': 'test_images_labels_targets.tar.gz', + 'md5': '1b39c47e05d1319c17cc8763cee6fe0c', + 'directory': 'test', }, } classes = ('background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed') @@ -111,14 +111,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image1 = self._load_image(files["image1"]) - image2 = self._load_image(files["image2"]) - mask1 = self._load_target(files["mask1"]) - mask2 = self._load_target(files["mask2"]) + image1 = self._load_image(files['image1']) + image2 = self._load_image(files['image2']) + mask1 = self._load_target(files['mask1']) + mask2 = self._load_target(files['mask2']) image = torch.stack(tensors=[image1, image2], dim=0) mask = torch.stack(tensors=[mask1, mask2], dim=0) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -138,23 +138,23 @@ def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: Args: root: root dir of dataset - split: subset of dataset, one of [train, test] + split: subset of dataset, one of ['train', 'test'] Returns: list of dicts containing paths for each pair of images and masks """ files = [] - directory = self.metadata[split]["directory"] - image_root = os.path.join(root, directory, "images") - mask_root = os.path.join(root, directory, "targets") - images = glob.glob(os.path.join(image_root, "*.png")) + directory = self.metadata[split]['directory'] + image_root = os.path.join(root, directory, 'images') + mask_root = os.path.join(root, directory, 'targets') + images = glob.glob(os.path.join(image_root, '*.png')) basenames = [os.path.basename(f) for f in images] - basenames = ["_".join(f.split("_")[:-2]) for f in basenames] + basenames = ['_'.join(f.split('_')[:-2]) for f in basenames] for name in sorted(set(basenames)): - image1 = os.path.join(image_root, f"{name}_pre_disaster.png") - image2 = os.path.join(image_root, f"{name}_post_disaster.png") - mask1 = os.path.join(mask_root, f"{name}_pre_disaster_target.png") - mask2 = os.path.join(mask_root, f"{name}_post_disaster_target.png") + image1 = os.path.join(image_root, f'{name}_pre_disaster.png') + image2 = os.path.join(image_root, f'{name}_post_disaster.png') + mask1 = os.path.join(mask_root, f'{name}_pre_disaster_target.png') + mask2 = os.path.join(mask_root, f'{name}_post_disaster_target.png') files.append(dict(image1=image1, image2=image2, mask1=mask1, mask2=mask2)) return files @@ -169,7 +169,7 @@ def _load_image(self, path: Path) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: np.typing.NDArray[np.int_] = np.array(img.convert("RGB")) + array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -186,7 +186,7 @@ def _load_target(self, path: Path) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: np.typing.NDArray[np.int_] = np.array(img.convert("L")) + array: np.typing.NDArray[np.int_] = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) return tensor @@ -196,10 +196,10 @@ def _verify(self) -> None: # Check if the files already exist exists = [] for split_info in self.metadata.values(): - for directory in ["images", "targets"]: + for directory in ['images', 'targets']: exists.append( os.path.exists( - os.path.join(self.root, split_info["directory"], directory) + os.path.join(self.root, split_info['directory'], directory) ) ) @@ -209,10 +209,10 @@ def _verify(self) -> None: # Check if .tar.gz files already exists (if so then extract) exists = [] for split_info in self.metadata.values(): - filepath = os.path.join(self.root, split_info["filename"]) + filepath = os.path.join(self.root, split_info['filename']) if os.path.isfile(filepath): - if self.checksum and not check_integrity(filepath, split_info["md5"]): - raise RuntimeError("Dataset found, but corrupted.") + if self.checksum and not check_integrity(filepath, split_info['md5']): + raise RuntimeError('Dataset found, but corrupted.') exists.append(True) extract_archive(filepath) else: @@ -254,29 +254,29 @@ def plot( alpha=alpha, colors=list(self.colormap), ) - if "prediction" in sample: # NOTE: this assumes predictions are made for post + if 'prediction' in sample: # NOTE: this assumes predictions are made for post ncols += 1 image3 = draw_semantic_segmentation_masks( - sample["image"][1], - sample["prediction"], + sample['image'][1], + sample['prediction'], alpha=alpha, colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(image2) - axs[1].axis("off") + axs[1].axis('off') if ncols > 2: axs[2].imshow(image3) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[0].set_title("Pre disaster") - axs[1].set_title("Post disaster") + axs[0].set_title('Pre disaster') + axs[1].set_title('Post disaster') if ncols > 2: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) @@ -295,29 +295,29 @@ class XView2DistShift(XView2): on others. """ - classes = ["background", "building"] + classes = ['background', 'building'] # List of disaster names valid_disasters = [ - "hurricane-harvey", - "socal-fire", - "hurricane-matthew", - "mexico-earthquake", - "guatemala-volcano", - "santa-rosa-wildfire", - "palu-tsunami", - "hurricane-florence", - "hurricane-michael", - "midwest-flooding", + 'hurricane-harvey', + 'socal-fire', + 'hurricane-matthew', + 'mexico-earthquake', + 'guatemala-volcano', + 'santa-rosa-wildfire', + 'palu-tsunami', + 'hurricane-florence', + 'hurricane-michael', + 'midwest-flooding', ] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', id_ood_disaster: list[dict[str, str]] = [ - {"disaster_name": "hurricane-matthew", "pre-post": "post"}, - {"disaster_name": "mexico-earthquake", "pre-post": "post"}, + {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, + {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, checksum: bool = False, @@ -337,12 +337,12 @@ def __init__( ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. DatasetNotFoundError: If dataset is not found. """ - assert split in ["train", "test"], "Split must be either 'train' or 'test'." + assert split in ['train', 'test'], "Split must be either 'train' or 'test'." # Validate that the disasters are valid if ( - id_ood_disaster[0]["disaster_name"] not in self.valid_disasters - or id_ood_disaster[1]["disaster_name"] not in self.valid_disasters + id_ood_disaster[0]['disaster_name'] not in self.valid_disasters + or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters ): raise ValueError( f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}" @@ -369,17 +369,17 @@ def __init__( def __getitem__(self, index: int) -> dict[str, torch.Tensor]: """Get an item from the dataset at the given index.""" file_info = ( - self.files["train"][index] - if self.split == "train" - else self.files["test"][index] + self.files['train'][index] + if self.split == 'train' + else self.files['test'][index] ) - image = self._load_image(file_info["image"]).to("cuda") - mask = self._load_target(file_info["mask"]).long().to("cuda") + image = self._load_image(file_info['image']).to('cuda') + mask = self._load_target(file_info['mask']).long().to('cuda') mask[mask == 2] = 1 mask[(mask == 3) | (mask == 4)] = 0 - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms: sample = self.transforms(sample) @@ -389,34 +389,34 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]: def __len__(self) -> int: """Return the total number of samples in the dataset.""" return ( - len(self.files["train"]) - if self.split == "train" - else len(self.files["test"]) + len(self.files['train']) + if self.split == 'train' + else len(self.files['test']) ) def _initialize_files(self, root: str) -> list[dict[str, str]]: """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" all_files = [] for split in self.metadata.keys(): - image_root = os.path.join(root, split, "images") - mask_root = os.path.join(root, split, "targets") - images = glob.glob(os.path.join(image_root, "*.png")) + image_root = os.path.join(root, split, 'images') + mask_root = os.path.join(root, split, 'targets') + images = glob.glob(os.path.join(image_root, '*.png')) # Extract basenames while preserving the event-name and sample number for img in images: - basename_parts = os.path.basename(img).split("_") + basename_parts = os.path.basename(img).split('_') event_name = basename_parts[0] # e.g., mexico-earthquake sample_number = basename_parts[1] # e.g., 00000001 basename = ( - f"{event_name}_{sample_number}" # e.g., mexico-earthquake_00000001 + f'{event_name}_{sample_number}' # e.g., mexico-earthquake_00000001 ) file_info = { - "image": img, - "mask": os.path.join( - mask_root, f"{basename}_pre_disaster_target.png" + 'image': img, + 'mask': os.path.join( + mask_root, f'{basename}_pre_disaster_target.png' ), - "basename": basename, + 'basename': basename, } all_files.append(file_info) return all_files @@ -431,8 +431,8 @@ def _load_split_files_by_disaster_and_type( Args: files: List of file paths with their corresponding information. - id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {"disaster_name": "guatemala-volcano", "pre-post": "pre"}). - ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {"disaster_name": "mexico-earthquake", "pre-post": "post"}). + id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}). + ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}). Returns: A dictionary containing 'train' (ID) and 'test' (OOD) file lists. @@ -442,40 +442,40 @@ def _load_split_files_by_disaster_and_type( disaster_list = [] for file_info in files: - basename = file_info["basename"] - disaster_name = basename.split("_")[ + basename = file_info['basename'] + disaster_name = basename.split('_')[ 0 ] # Extract disaster name from basename pre_post = ( - "pre" if "pre_disaster" in file_info["image"] else "post" + 'pre' if 'pre_disaster' in file_info['image'] else 'post' ) # Identify pre/post type disaster_list.append(disaster_name) # Filter for in-domain (ID) training set - if disaster_name == id_disaster["disaster_name"]: + if disaster_name == id_disaster['disaster_name']: if ( - id_disaster.get("pre-post") == "both" - or id_disaster["pre-post"] == pre_post + id_disaster.get('pre-post') == 'both' + or id_disaster['pre-post'] == pre_post ): image = ( - file_info["image"].replace("post_disaster", "pre_disaster") - if pre_post == "pre" - else file_info["image"] + file_info['image'].replace('post_disaster', 'pre_disaster') + if pre_post == 'pre' + else file_info['image'] ) mask = ( - file_info["mask"].replace("post_disaster", "pre_disaster") - if pre_post == "pre" - else file_info["mask"] + file_info['mask'].replace('post_disaster', 'pre_disaster') + if pre_post == 'pre' + else file_info['mask'] ) train_files.append(dict(image=image, mask=mask)) # Filter for out-of-domain (OOD) test set - if disaster_name == ood_disaster["disaster_name"]: + if disaster_name == ood_disaster['disaster_name']: if ( - ood_disaster.get("pre-post") == "both" - or ood_disaster["pre-post"] == pre_post + ood_disaster.get('pre-post') == 'both' + or ood_disaster['pre-post'] == pre_post ): test_files.append(file_info) - return {"train": train_files, "test": test_files, "disasters": disaster_list} + return {'train': train_files, 'test': test_files, 'disasters': disaster_list} From 0f3ceb7931601dd6e2af9824a7f7bbdeb68217e6 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sat, 1 Feb 2025 19:41:22 +0000 Subject: [PATCH 7/8] no cuda yes docstring --- .gitignore | 1 + torchgeo/datasets/xview.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 180c27c47b2..8357c23b3f8 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,4 @@ dmypy.json # Pyre type checker .pyre/ +xbdood.ipynb diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 976ca219aad..12a30d16991 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -374,10 +374,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]: else self.files['test'][index] ) - image = self._load_image(file_info['image']).to('cuda') - mask = self._load_target(file_info['mask']).long().to('cuda') - mask[mask == 2] = 1 - mask[(mask == 3) | (mask == 4)] = 0 + image = self._load_image(file_info['image']) + mask = self._load_target(file_info['mask']).long() + + # Reformulate as building segmentation task + mask[mask == 2] = 1 # Map damage class 2 to 1 + mask[(mask == 3) | (mask == 4)] = 0 # Map 3 and 4 damage classes to background sample = {'image': image, 'mask': mask} @@ -402,7 +404,7 @@ def _initialize_files(self, root: str) -> list[dict[str, str]]: mask_root = os.path.join(root, split, 'targets') images = glob.glob(os.path.join(image_root, '*.png')) - # Extract basenames while preserving the event-name and sample number + # Extract basenames while preserving the disaster-name and sample number for img in images: basename_parts = os.path.basename(img).split('_') event_name = basename_parts[0] # e.g., mexico-earthquake @@ -470,7 +472,7 @@ def _load_split_files_by_disaster_and_type( ) train_files.append(dict(image=image, mask=mask)) - # Filter for out-of-domain (OOD) test set + # Filter for out-of-distribution (OOD) test set if disaster_name == ood_disaster['disaster_name']: if ( ood_disaster.get('pre-post') == 'both' From a74c99fa2d211e07c647124ae54679a88ffaafc9 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sat, 1 Feb 2025 19:49:35 +0000 Subject: [PATCH 8/8] id ood length method and polishing --- torchgeo/datasets/xview.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 12a30d16991..f6ba2092792 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -362,10 +362,10 @@ def __init__( self.files = self._load_split_files_by_disaster_and_type( self.all_files, id_ood_disaster[0], id_ood_disaster[1] ) - print( - f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files." - ) + train_size, test_size = self.get_id_ood_sizes() + print(f"ID sample len: {train_size}, OOD sample len: {test_size}") + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: """Get an item from the dataset at the given index.""" file_info = ( @@ -396,6 +396,12 @@ def __len__(self) -> int: else len(self.files['test']) ) + + def get_id_ood_sizes(self) -> tuple[int, int]: + """Return the number of samples in the train and test splits.""" + return (len(self.files['train']), len(self.files['test'])) + + def _initialize_files(self, root: str) -> list[dict[str, str]]: """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" all_files = [] @@ -454,7 +460,7 @@ def _load_split_files_by_disaster_and_type( disaster_list.append(disaster_name) - # Filter for in-domain (ID) training set + # Filter for in-distribution (ID) training set if disaster_name == id_disaster['disaster_name']: if ( id_disaster.get('pre-post') == 'both'