diff --git a/predict.py b/predict.py
new file mode 100644
index 00000000000..a1edd6df95e
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""torchgeo model inference script."""
+
+import argparse
+import os
+from typing import Dict, Tuple, Type, cast
+
+import pytorch_lightning as pl
+import rasterio as rio
+import torch
+from kornia.contrib import CombineTensorPatches
+from omegaconf import OmegaConf
+
+from torchgeo.datamodules import (
+    BigEarthNetDataModule,
+    ChesapeakeCVPRDataModule,
+    COWCCountingDataModule,
+    CycloneDataModule,
+    ETCI2021DataModule,
+    EuroSATDataModule,
+    InriaAerialImageLabelingDataModule,
+    LandCoverAIDataModule,
+    NAIPChesapeakeDataModule,
+    OSCDDataModule,
+    RESISC45DataModule,
+    SEN12MSDataModule,
+    So2SatDataModule,
+    UCMercedDataModule,
+)
+from torchgeo.trainers import (
+    BYOLTask,
+    ClassificationTask,
+    MultiLabelClassificationTask,
+    RegressionTask,
+    SemanticSegmentationTask,
+)
+
+TASK_TO_MODULES_MAPPING: Dict[
+    str, Tuple[Type[pl.LightningModule], Type[pl.LightningDataModule]]
+] = {
+    "bigearthnet": (MultiLabelClassificationTask, BigEarthNetDataModule),
+    "byol": (BYOLTask, ChesapeakeCVPRDataModule),
+    "chesapeake_cvpr": (SemanticSegmentationTask, ChesapeakeCVPRDataModule),
+    "cowc_counting": (RegressionTask, COWCCountingDataModule),
+    "cyclone": (RegressionTask, CycloneDataModule),
+    "eurosat": (ClassificationTask, EuroSATDataModule),
+    "etci2021": (SemanticSegmentationTask, ETCI2021DataModule),
+    "inria": (SemanticSegmentationTask, InriaAerialImageLabelingDataModule),
+    "landcoverai": (SemanticSegmentationTask, LandCoverAIDataModule),
+    "naipchesapeake": (SemanticSegmentationTask, NAIPChesapeakeDataModule),
+    "oscd": (SemanticSegmentationTask, OSCDDataModule),
+    "resisc45": (ClassificationTask, RESISC45DataModule),
+    "sen12ms": (SemanticSegmentationTask, SEN12MSDataModule),
+    "so2sat": (ClassificationTask, So2SatDataModule),
+    "ucmerced": (ClassificationTask, UCMercedDataModule),
+}
+
+
+def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None:
+    """Write mask to specified output directory with same filename as input raster.
+
+    Args:
+        mask (torch.Tensor): mask tensor
+        output_dir (str): output directory
+        input_filename (str): path to input raster
+    """
+    output_path = os.path.join(output_dir, os.path.basename(input_filename))
+    with rio.open(input_filename) as src:
+        profile = src.profile
+    profile["count"] = 1
+    profile["dtype"] = "uint8"
+    mask = mask.cpu().numpy()
+    with rio.open(output_path, "w", **profile) as ds:
+        ds.write(mask)
+
+
+def main(config_dir: str, predict_on: str, output_dir: str, device: str) -> None:
+    """Main inference loop.
+
+    Args:
+        config_dir (str): Path to config-dir to load config and ckpt
+        predict_on (str): Directory/Dataset to run inference on
+        output_dir (str): Path to output_directory to save predicted masks
+        device (str): Choice of device. Must be in [cuda, cpu]
+
+    Raises:
+        ValueError: Raised if task name is not in TASK_TO_MODULES_MAPPING
+        FileExistsError: Raised if specified output directory contains
+            files and overwrite=False.
+    """
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Load checkpoint and config
+    conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml"))
+    ckpt = os.path.join(config_dir, "last.ckpt")
+
+    # Load model
+    task_name = conf.experiment.task
+    datamodule: pl.LightningDataModule
+    task: pl.LightningModule
+    if task_name not in TASK_TO_MODULES_MAPPING:
+        raise ValueError(
+            f"experiment.task={task_name} is not recognized as a valid task"
+        )
+    task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name]
+    task = task_class.load_from_checkpoint(ckpt)
+    task = task.to(device)
+    task.eval()
+
+    # Load datamodule and dataloader
+    conf.experiment.datamodule["predict_on"] = predict_on
+    datamodule = datamodule_class(**conf.experiment.datamodule)
+    datamodule.setup()
+    dataloader = datamodule.predict_dataloader()
+
+    if len(os.listdir(output_dir)) > 0:
+        if conf.program.overwrite:
+            print(
+                f"WARNING! The output directory, {output_dir}, already exists, "
+                + "we will overwrite data in it!"
+            )
+        else:
+            raise FileExistsError(
+                f"The predictions directory, {output_dir}, already exists and isn't "
+                + "empty. We don't want to overwrite any existing results, exiting..."
+            )
+
+    for i, batch in enumerate(dataloader):
+        x = batch["image"].to(device)  # (N, B, C, H, W)
+        assert len(x.shape) in {4, 5}
+        if len(x.shape) == 5:
+            masks = []
+
+            def tensor_to_int(
+                tensor_tuple: Tuple[torch.Tensor, ...]
+            ) -> Tuple[int, ...]:
+                """Convert tuple of tensors to tuple of ints."""
+                return tuple(int(i.item()) for i in tensor_tuple)
+
+            original_shape = cast(
+                Tuple[int, int], tensor_to_int(batch["original_shape"])
+            )
+            patch_shape = cast(Tuple[int, int], tensor_to_int(batch["patch_shape"]))
+            padding = cast(Tuple[int, int], tensor_to_int(batch["padding"]))
+            patch_combine = CombineTensorPatches(
+                original_size=original_shape, window_size=patch_shape, unpadding=padding
+            )
+
+            for tile in x:
+                mask = task(tile)
+                mask = mask.argmax(dim=1)
+                masks.append(mask)
+
+            masks_arr = torch.stack(masks, dim=0)
+            masks_arr = masks_arr.unsqueeze(0)
+            masks_combined = patch_combine(masks_arr)[0]
+            filename = datamodule.predict_dataset.files[i]["image"]
+            write_mask(masks_combined, output_dir, filename)
+        else:
+            mask = task(x)
+            mask = mask.argmax(dim=1)
+            filename = datamodule.predict_dataset.files[i]["image"]
+            write_mask(mask, output_dir, filename)
+
+
+if __name__ == "__main__":
+    # Taken from https://github.com/pangeo-data/cog-best-practices
+    _rasterio_best_practices = {
+        "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
+        "AWS_NO_SIGN_REQUEST": "YES",
+        "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000",
+        "GDAL_SWATH_SIZE": "200000000",
+        "VSI_CURL_CACHE_SIZE": "200000000",
+    }
+    os.environ.update(_rasterio_best_practices)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config-dir",
+        type=str,
+        required=True,
+        help="Path to config-dir to load config and ckpt",
+    )
+
+    parser.add_argument(
+        "--predict_on",
+        type=str,
+        required=True,
+        help="Directory/Dataset to run inference on",
+    )
+
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        required=True,
+        help="Path to output_directory to save predicted mask geotiffs",
+    )
+
+    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
+    args = parser.parse_args()
+    main(args.config_dir, args.predict_on, args.output_dir, args.device)
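
For context, main() above reads only a few keys from <config-dir>/experiment_config.yaml: experiment.task (a key of TASK_TO_MODULES_MAPPING), experiment.datamodule (keyword arguments for the datamodule), and program.overwrite. A minimal sketch of an equivalent config built with OmegaConf in Python; the datamodule values are placeholders, not anything prescribed by this patch:

# Illustrative only: these are the keys main() accesses; the datamodule kwargs
# must match whatever datamodule_class(**conf.experiment.datamodule) accepts.
from omegaconf import OmegaConf

conf = OmegaConf.create(
    {
        "program": {"overwrite": False},
        "experiment": {
            "task": "inria",  # selects SemanticSegmentationTask + Inria datamodule
            "datamodule": {"root_dir": "tests/data/inria", "patch_size": 512},
        },
    }
)
print(OmegaConf.to_yaml(conf))  # roughly what experiment_config.yaml could contain

The script itself is then run with the four flags defined in the __main__ block above, e.g. python predict.py --config-dir <dir> --predict_on <split-or-image-dir> --output-dir <dir> --device cuda.
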
+ """ + os.makedirs(output_dir, exist_ok=True) + + # Load checkpoint and config + conf = OmegaConf.load(os.path.join(config_dir, "experiment_config.yaml")) + ckpt = os.path.join(config_dir, "last.ckpt") + + # Load model + task_name = conf.experiment.task + datamodule: pl.LightningDataModule + task: pl.LightningModule + if task_name not in TASK_TO_MODULES_MAPPING: + raise ValueError( + f"experiment.task={task_name} is not recognized as a valid task" + ) + task_class, datamodule_class = TASK_TO_MODULES_MAPPING[task_name] + task = task_class.load_from_checkpoint(ckpt) + task = task.to(device) + task.eval() + + # Load datamodule and dataloader + conf.experiment.datamodule["predict_on"] = predict_on + datamodule = datamodule_class(**conf.experiment.datamodule) + datamodule.setup() + dataloader = datamodule.predict_dataloader() + + if len(os.listdir(output_dir)) > 0: + if conf.program.overwrite: + print( + f"WARNING! The output directory, {output_dir}, already exists, " + + "we will overwrite data in it!" + ) + else: + raise FileExistsError( + f"The predictions directory, {output_dir}, already exists and isn't " + + "empty. We don't want to overwrite any existing results, exiting..." + ) + + for i, batch in enumerate(dataloader): + x = batch["image"].to(device) # (N, B, C, H, W) + assert len(x.shape) in {4, 5} + if len(x.shape) == 5: + masks = [] + + def tensor_to_int( + tensor_tuple: Tuple[torch.Tensor, ...] + ) -> Tuple[int, ...]: + """Convert tuple of tensors to tuple of ints.""" + return tuple(int(i.item()) for i in tensor_tuple) + + original_shape = cast( + Tuple[int, int], tensor_to_int(batch["original_shape"]) + ) + patch_shape = cast(Tuple[int, int], tensor_to_int(batch["patch_shape"])) + padding = cast(Tuple[int, int], tensor_to_int(batch["padding"])) + patch_combine = CombineTensorPatches( + original_size=original_shape, window_size=patch_shape, unpadding=padding + ) + + for tile in x: + mask = task(tile) + mask = mask.argmax(dim=1) + masks.append(mask) + + masks_arr = torch.stack(masks, dim=0) + masks_arr = masks_arr.unsqueeze(0) + masks_combined = patch_combine(masks_arr)[0] + filename = datamodule.predict_dataset.files[i]["image"] + write_mask(masks_combined, output_dir, filename) + else: + mask = task(x) + mask = mask.argmax(dim=1) + filename = datamodule.predict_dataset.files[i]["image"] + write_mask(mask, output_dir, filename) + + +if __name__ == "__main__": + # Taken from https://github.com/pangeo-data/cog-best-practices + _rasterio_best_practices = { + "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR", + "AWS_NO_SIGN_REQUEST": "YES", + "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000", + "GDAL_SWATH_SIZE": "200000000", + "VSI_CURL_CACHE_SIZE": "200000000", + } + os.environ.update(_rasterio_best_practices) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-dir", + type=str, + required=True, + help="Path to config-dir to load config and ckpt", + ) + + parser.add_argument( + "--predict_on", + type=str, + required=True, + help="Directory/Dataset to run inference on", + ) + + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Path to output_directory to save predicted mask geotiffs", + ) + + parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"]) + args = parser.parse_args() + main(args.config_dir, args.predict_on, args.output_dir, args.device) diff --git a/tests/datamodules/test_inria.py b/tests/datamodules/test_inria.py index 8ad1d33d2d6..ca982dfea65 100644 --- a/tests/datamodules/test_inria.py +++ 
b/tests/datamodules/test_inria.py @@ -9,11 +9,12 @@ from torchgeo.datamodules import InriaAerialImageLabelingDataModule TEST_DATA_DIR = os.path.join("tests", "data", "inria") +PREDICT_DATA_DIR = os.path.join(TEST_DATA_DIR, "AerialImageDataset/test/images") class TestInriaAerialImageLabelingDataModule: @pytest.fixture( - params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", "test", "test"]) + params=zip([0.2, 0.2, 0.0], [0.2, 0.0, 0.0], ["test", "test", PREDICT_DATA_DIR]) ) def datamodule(self, request: SubRequest) -> InriaAerialImageLabelingDataModule: val_split_pct, test_split_pct, predict_on = request.param diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index 43b4dff641a..e2cda64dc01 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -16,12 +16,14 @@ import numpy as np import pytest import torch +from _pytest.fixtures import SubRequest from _pytest.monkeypatch import MonkeyPatch from rasterio.crs import CRS import torchgeo.datasets.utils from torchgeo.datasets.utils import ( BoundingBox, + PredictDataset, concat_samples, disambiguate_timestamp, download_and_extract_archive, @@ -582,3 +584,26 @@ def test_percentile_normalization() -> None: img = percentile_normalization(img, 2, 98) assert img.min() == 0 assert img.max() == 1 + + +class TestPredictDataset: + @pytest.fixture( + params=zip([None, torch.nn.Identity(), None], [(2, 2), (8, 8), (16, 16)]) + ) + def dataset(self, request: SubRequest) -> PredictDataset: + root = os.path.join( + "tests", "data", "inria", "AerialImageDataset", "test", "images" + ) + transforms, patch_size = request.param + return PredictDataset(root, patch_size=patch_size, transforms=transforms) + + def test_getitem(self, dataset: PredictDataset) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert x["image"].ndim == 5 + assert len(x["original_shape"]) == len(x["patch_shape"]) == 2 + assert len(x["padding"]) == 4 + + def test_len(self, dataset: PredictDataset) -> None: + assert len(dataset) == 5 diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 9d5c177f39a..5e54b963076 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -3,6 +3,7 @@ """InriaAerialImageLabeling datamodule.""" +import os from typing import Any, Dict, List, Optional, Tuple, Union, cast import kornia.augmentation as K @@ -14,7 +15,7 @@ from torch.utils.data import DataLoader, Dataset from torch.utils.data._utils.collate import default_collate -from ..datasets import InriaAerialImageLabeling +from ..datasets import InriaAerialImageLabeling, PredictDataset from ..samplers.utils import _to_tuple from .utils import dataset_split @@ -167,10 +168,15 @@ def setup(self, stage: Optional[str] = None) -> None: self.val_dataset = train_dataset self.test_dataset = train_dataset - assert self.predict_on == "test" - self.predict_dataset = InriaAerialImageLabeling( - self.root_dir, self.predict_on, transforms=test_transforms - ) + if os.path.isdir(self.predict_on): + self.predict_dataset = PredictDataset( + self.predict_on, patch_size=self.patch_size, transforms=self.preprocess + ) + else: + assert self.predict_on == "test" + self.predict_dataset = InriaAerialImageLabeling( # type: ignore[assignment] + self.root_dir, self.predict_on, transforms=test_transforms + ) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training.""" diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 
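
With the datamodule change above, predict_on may point at a directory of imagery instead of the literal split name "test". A minimal sketch of driving that branch, assuming the datamodule's existing root_dir and patch_size constructor arguments and reusing the Inria test fixtures:

# Sketch only: exercises the new os.path.isdir(self.predict_on) branch of setup().
import os

from torchgeo.datamodules import InriaAerialImageLabelingDataModule

root = os.path.join("tests", "data", "inria")
predict_dir = os.path.join(root, "AerialImageDataset", "test", "images")

dm = InriaAerialImageLabelingDataModule(
    root_dir=root,
    patch_size=(512, 512),  # assumed constructor argument, mirrored by self.patch_size
    predict_on=predict_dir,  # a directory, so setup() builds a PredictDataset
)
dm.setup()
batch = next(iter(dm.predict_dataloader()))
print(batch["image"].shape)  # patched tiles, later recombined by CombineTensorPatches
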
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 5b39365a2c5..25438e32b23 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -97,6 +97,7 @@
 from .usavars import USAVars
 from .utils import (
     BoundingBox,
+    PredictDataset,
     concat_samples,
     merge_samples,
     stack_samples,
@@ -202,6 +203,7 @@
     "VisionDataset",
     # Utilities
     "BoundingBox",
+    "PredictDataset",
     "concat_samples",
     "merge_samples",
     "stack_samples",
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py
index 658220d4903..2c40c4931c5 100644
--- a/torchgeo/datasets/utils.py
+++ b/torchgeo/datasets/utils.py
@@ -15,6 +15,7 @@
 from datetime import datetime, timedelta
 from typing import (
     Any,
+    Callable,
     Dict,
     Iterable,
     Iterator,
@@ -30,7 +31,11 @@
 import numpy as np
 import rasterio
 import torch
+import torchvision.transforms as T
+from einops import rearrange
+from kornia.contrib import compute_padding, extract_tensor_patches
 from torch import Tensor
+from torch.utils.data import Dataset
 from torchvision.datasets.utils import check_integrity, download_url
 from torchvision.utils import draw_segmentation_masks
 
@@ -51,9 +56,108 @@
     "draw_semantic_segmentation_masks",
     "rgb_to_mask",
     "percentile_normalization",
+    "PredictDataset",
 )
 
 
+class PredictDataset(Dataset[Any]):
+    """Prediction dataset for VisionDatasets."""
+
+    def __init__(
+        self,
+        root: str,
+        patch_size: Tuple[int, int],
+        transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+        bands: Tuple[int, ...] = (1, 2, 3),
+    ) -> None:
+        """Initialize a new PredictDataset instance.
+
+        Args:
+            root: root directory where dataset can be found
+            patch_size: size of patch used as input for the model
+            transforms: a function/transform that takes an input sample
+                and returns a transformed version
+            bands: bands to be used
+
+        """
+        self.root = root
+        self.patch_size = patch_size
+        self.bands = bands
+        # patch_sample is always applied last; do not include it in ``transforms``
+        if transforms:
+            self.transforms = T.Compose([transforms, self.patch_sample])
+        else:
+            self.transforms = self.patch_sample
+        self.files = self._load_files(root)
+
+    def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        """Extract patches from a single sample."""
+        assert sample["image"].ndim == 3
+        _, h, w = sample["image"].shape
+
+        padding = compute_padding((h, w), self.patch_size)
+        sample["original_shape"] = (h, w)
+        sample["patch_shape"] = self.patch_size
+        sample["padding"] = padding
+        sample["image"] = extract_tensor_patches(
+            sample["image"].unsqueeze(0),
+            self.patch_size,
+            self.patch_size,
+            padding=padding,
+        )
+        sample["image"] = rearrange(sample["image"], "() t c h w -> t () c h w")
+        return sample
+
+    def _load_files(self, root: str) -> List[Dict[str, str]]:
+        """Return the paths of the files in the dataset.
+
+        Args:
+            root: root dir of dataset
+
+        Returns:
+            list of dicts containing the path to each image
+        """
+        images = [os.path.join(root, i) for i in os.listdir(root)]
+        return [{"image": img} for img in images]
+
+    def __len__(self) -> int:
+        """Return the number of samples in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        return len(self.files)
+
+    def _load_image(self, path: str) -> Tensor:
+        """Load a single image.
+
+        Args:
+            path: path to the image
+
+        Returns:
+            the image
+        """
+        with rasterio.open(path) as img:
+            array = img.read(self.bands).astype(np.int32)
+            tensor: Tensor = torch.from_numpy(array)
+        return tensor
+
+    def __getitem__(self, index: int) -> Dict[str, Tensor]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data at that index
+        """
+        file = self.files[index]
+        img = self._load_image(file["image"])
+        sample = {"image": img}
+        sample = self.transforms(sample)
+        return sample
+
+
 class _rarfile:
     class RarFile:
         def __init__(self, *args: Any, **kwargs: Any) -> None:
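
Finally, a minimal round trip through the new PredictDataset and kornia's CombineTensorPatches, mirroring what predict.py does with the model output. The image directory and (2, 2) patch size come from the test fixture above; the zero tensors stand in for task(tile).argmax(dim=1):

# Sketch under the same assumptions as the tests above.
import os

import torch
from kornia.contrib import CombineTensorPatches

from torchgeo.datasets import PredictDataset

root = os.path.join("tests", "data", "inria", "AerialImageDataset", "test", "images")
ds = PredictDataset(root, patch_size=(2, 2))

sample = ds[0]
patches = sample["image"]  # (num_patches, 1, C, h, w) after patch_sample

# stand-in for the per-patch predictions computed in predict.py
masks = [torch.zeros(1, *sample["patch_shape"]) for _ in patches]
masks_arr = torch.stack(masks, dim=0).unsqueeze(0)  # (1, num_patches, 1, h, w)

combine = CombineTensorPatches(
    original_size=sample["original_shape"],
    window_size=sample["patch_shape"],
    unpadding=sample["padding"],
)
mask = combine(masks_arr)[0]  # (1, H, W), back at the tile's original size
print(mask.shape, sample["original_shape"])
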