diff --git a/tests/conf/inria.yaml b/tests/conf/inria_test.yaml similarity index 100% rename from tests/conf/inria.yaml rename to tests/conf/inria_test.yaml diff --git a/tests/conf/inria_train.yaml b/tests/conf/inria_train.yaml new file mode 100644 index 00000000000..99db7925f27 --- /dev/null +++ b/tests/conf/inria_train.yaml @@ -0,0 +1,20 @@ +experiment: + task: "inria" + module: + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 2 + ignore_index: null + datamodule: + root: "tests/data/inria" + batch_size: 1 + num_workers: 0 + val_split_pct: 0.0 + test_split_pct: 0.0 + patch_size: 2 + num_patches_per_tile: 2 diff --git a/tests/conf/inria_val.yaml b/tests/conf/inria_val.yaml new file mode 100644 index 00000000000..c20f8923439 --- /dev/null +++ b/tests/conf/inria_val.yaml @@ -0,0 +1,20 @@ +experiment: + task: "inria" + module: + loss: "ce" + model: "unet" + backbone: "resnet18" + weights: "imagenet" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + in_channels: 3 + num_classes: 2 + ignore_index: null + datamodule: + root: "tests/data/inria" + batch_size: 1 + num_workers: 0 + val_split_pct: 0.2 + test_split_pct: 0.0 + patch_size: 2 + num_patches_per_tile: 2 diff --git a/tests/datamodules/test_inria.py b/tests/datamodules/test_inria.py deleted file mode 100644 index e4415db96e5..00000000000 --- a/tests/datamodules/test_inria.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os - -import pytest -from _pytest.fixtures import SubRequest - -from torchgeo.datamodules import InriaAerialImageLabelingDataModule - -TEST_DATA_DIR = os.path.join("tests", "data", "inria") - - -class TestInriaAerialImageLabelingDataModule: - @pytest.fixture(params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0])) - def datamodule(self, request: SubRequest) -> InriaAerialImageLabelingDataModule: - val_split_pct, test_split_pct = request.param - patch_size = 2 # (2,2) - num_patches_per_tile = 2 - root = TEST_DATA_DIR - batch_size = 1 - num_workers = 0 - dm = InriaAerialImageLabelingDataModule( - root=root, - batch_size=batch_size, - num_workers=num_workers, - val_split_pct=val_split_pct, - test_split_pct=test_split_pct, - patch_size=patch_size, - num_patches_per_tile=num_patches_per_tile, - ) - dm.prepare_data() - dm.setup() - return dm - - def test_train_dataloader( - self, datamodule: InriaAerialImageLabelingDataModule - ) -> None: - sample = next(iter(datamodule.train_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 - assert sample["image"].shape[1] == 3 - assert sample["mask"].shape[1] == 1 - - def test_val_dataloader( - self, datamodule: InriaAerialImageLabelingDataModule - ) -> None: - sample = next(iter(datamodule.val_dataloader())) - if datamodule.val_split_pct > 0.0: - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 - - def test_test_dataloader( - self, datamodule: InriaAerialImageLabelingDataModule - ) -> None: - sample = next(iter(datamodule.test_dataloader())) - if datamodule.test_split_pct > 0.0: - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 - - def test_predict_dataloader( - self, datamodule: InriaAerialImageLabelingDataModule - ) -> None: - sample = next(iter(datamodule.predict_dataloader())) - assert len(sample["image"].shape) == 5 - assert sample["image"].shape[-2:] == (2, 2) - assert sample["image"].shape[2] == 3 diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index dc6be108adf..d30df3d993d 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -40,7 +40,9 @@ class TestSemanticSegmentationTask: ("deepglobelandcover_0", DeepGlobeLandCoverDataModule), ("deepglobelandcover_5", DeepGlobeLandCoverDataModule), ("etci2021", ETCI2021DataModule), - ("inria", InriaAerialImageLabelingDataModule), + ("inria_train", InriaAerialImageLabelingDataModule), + ("inria_val", InriaAerialImageLabelingDataModule), + ("inria_test", InriaAerialImageLabelingDataModule), ("landcoverai", LandCoverAIDataModule), ("naipchesapeake", NAIPChesapeakeDataModule), ("oscd_all", OSCDDataModule), @@ -80,7 +82,9 @@ def test_trainer( trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) + + if hasattr(datamodule, "predict_dataset"): + trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "landcoverai.yaml")) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index df8b07f0c21..92d26dc87b2 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -98,7 +98,9 @@ def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: self.patch_size, padding=padding, ) - sample["image"] = rearrange(sample["image"], "() t c h w -> t () c h w") + # Needed for reconstruction of patches later + sample["num_patches"] = sample["image"].shape[1] + sample["image"] = rearrange(sample["image"], "b n c h w -> (b n) c h w") return sample def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: