diff --git a/docs/api/samplers.rst b/docs/api/samplers.rst index be7fbefb910..42fb3e0cbf5 100644 --- a/docs/api/samplers.rst +++ b/docs/api/samplers.rst @@ -32,6 +32,11 @@ Grid Geo Sampler .. autoclass:: GridGeoSampler +Pre-chipped Geo Sampler +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: PreChippedGeoSampler + Batch Samplers -------------- diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 529342b4b5f..e7545459dd1 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -11,7 +11,13 @@ from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples -from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler, Units +from torchgeo.samplers import ( + GeoSampler, + GridGeoSampler, + PreChippedGeoSampler, + RandomGeoSampler, + Units, +) class CustomGeoSampler(GeoSampler): @@ -189,3 +195,48 @@ def test_dataloader( ) for _ in dl: continue + + +class TestPreChippedGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 20, 0, 20, 0, 20)) + ds.index.insert(1, (0, 30, 0, 30, 0, 30)) + return ds + + @pytest.fixture(scope="function") + def sampler(self, dataset: CustomGeoDataset) -> PreChippedGeoSampler: + return PreChippedGeoSampler(dataset, shuffle=True) + + def test_iter(self, sampler: GridGeoSampler) -> None: + for _ in sampler: + continue + + def test_len(self, sampler: GridGeoSampler) -> None: + assert len(sampler) == 2 + + def test_roi(self, dataset: CustomGeoDataset) -> None: + roi = BoundingBox(5, 15, 5, 15, 5, 15) + sampler = PreChippedGeoSampler(dataset, roi=roi) + for query in sampler: + assert query == roi + + def test_point_data(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 0, 0, 0, 0, 0)) + ds.index.insert(1, (1, 1, 1, 1, 1, 1)) + sampler = PreChippedGeoSampler(ds) + for _ in sampler: + continue + + @pytest.mark.slow + @pytest.mark.parametrize("num_workers", [0, 1, 2]) + def test_dataloader( + self, dataset: CustomGeoDataset, sampler: PreChippedGeoSampler, num_workers: int + ) -> None: + dl = DataLoader( + dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + ) + for _ in dl: + continue diff --git a/torchgeo/samplers/__init__.py b/torchgeo/samplers/__init__.py index a6f63de1917..17b63603fe9 100644 --- a/torchgeo/samplers/__init__.py +++ b/torchgeo/samplers/__init__.py @@ -5,11 +5,12 @@ from .batch import BatchGeoSampler, RandomBatchGeoSampler from .constants import Units -from .single import GeoSampler, GridGeoSampler, RandomGeoSampler +from .single import GeoSampler, GridGeoSampler, PreChippedGeoSampler, RandomGeoSampler __all__ = ( # Samplers "GridGeoSampler", + "PreChippedGeoSampler", "RandomGeoSampler", # Batch samplers "RandomBatchGeoSampler", diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 781cb5c38b1..01846fe0e1c 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,8 +5,9 @@ import abc import random -from typing import Iterator, Optional, Tuple, Union +from typing import Callable, Iterable, Iterator, Optional, Tuple, Union +import torch from rtree.index import Index, Property from torch.utils.data import Sampler @@ -240,3 +241,62 @@ def __len__(self) -> int: number of patches that will be sampled """ return self.length + + +class PreChippedGeoSampler(GeoSampler): + """Samples entire files at a time. + + This is particularly useful for datasets that contain geospatial metadata + and subclass :class:`~torchgeo.datasets.GeoDataset` but have already been + pre-processed into :term:`chips `. + + This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`. + You may encounter problems when using an :term:`ROI ` + that partially intersects with one of the file bounding boxes, when using an + :class:`~torchgeo.datasets.IntersectionDataset`, or when each file is in a + different CRS. These issues can be solved by adding padding. + """ + + def __init__( + self, + dataset: GeoDataset, + roi: Optional[BoundingBox] = None, + shuffle: bool = False, + ) -> None: + """Initialize a new Sampler instance. + + Args: + dataset: dataset to index from + roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) + (defaults to the bounds of ``dataset.index``) + shuffle: if True, reshuffle data at every epoch + + .. versionadded:: 0.3 + """ + super().__init__(dataset, roi) + self.shuffle = shuffle + + self.hits = [] + for hit in self.index.intersection(tuple(self.roi), objects=True): + self.hits.append(hit) + + def __iter__(self) -> Iterator[BoundingBox]: + """Return the index of a dataset. + + Returns: + (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset + """ + generator: Callable[[int], Iterable[int]] = range + if self.shuffle: + generator = torch.randperm + + for idx in generator(len(self)): + yield BoundingBox(*self.hits[idx].bounds) + + def __len__(self) -> int: + """Return the number of samples over the ROI. + + Returns: + number of patches that will be sampled + """ + return len(self.hits)