From 552be5df61e199ad147275fb1ebac299eb19692c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 23 Mar 2022 14:31:48 -0500 Subject: [PATCH 1/4] Add PreChippedGeoSampler for pre-chipped geospatial datasets --- docs/api/samplers.rst | 5 ++++ tests/samplers/test_single.py | 47 +++++++++++++++++++++++++++++++- torchgeo/samplers/__init__.py | 3 ++- torchgeo/samplers/single.py | 50 +++++++++++++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 2 deletions(-) diff --git a/docs/api/samplers.rst b/docs/api/samplers.rst index be7fbefb910..42fb3e0cbf5 100644 --- a/docs/api/samplers.rst +++ b/docs/api/samplers.rst @@ -32,6 +32,11 @@ Grid Geo Sampler .. autoclass:: GridGeoSampler +Pre-chipped Geo Sampler +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: PreChippedGeoSampler + Batch Samplers -------------- diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 529342b4b5f..ae748ae5c6a 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples -from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler, Units +from torchgeo.samplers import GeoSampler, GridGeoSampler, PreChippedGeoSampler, RandomGeoSampler, Units class CustomGeoSampler(GeoSampler): @@ -189,3 +189,48 @@ def test_dataloader( ) for _ in dl: continue + + +class TestPreChippedGeoSampler: + @pytest.fixture(scope="class") + def dataset(self) -> CustomGeoDataset: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 20, 0, 20, 0, 20)) + ds.index.insert(1, (0, 30, 0, 30, 0, 30)) + return ds + + @pytest.fixture(scope="function") + def sampler(self, dataset: CustomGeoDataset) -> PreChippedGeoSampler: + return PreChippedGeoSampler(dataset) + + def test_iter(self, sampler: GridGeoSampler) -> None: + for _ in sampler: + continue + + def test_len(self, sampler: GridGeoSampler) -> None: + assert len(sampler) == 2 + + def test_roi(self, 
dataset: CustomGeoDataset) -> None: + roi = BoundingBox(5, 15, 5, 15, 5, 15) + sampler = PreChippedGeoSampler(dataset, roi=roi) + for query in sampler: + assert query == roi + + def test_point_data(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 0, 0, 0, 0, 0)) + ds.index.insert(1, (1, 1, 1, 1, 1, 1)) + sampler = PreChippedGeoSampler(ds) + for _ in sampler: + continue + + @pytest.mark.slow + @pytest.mark.parametrize("num_workers", [0, 1, 2]) + def test_dataloader( + self, dataset: CustomGeoDataset, sampler: PreChippedGeoSampler, num_workers: int + ) -> None: + dl = DataLoader( + dataset, sampler=sampler, num_workers=num_workers, collate_fn=stack_samples + ) + for _ in dl: + continue diff --git a/torchgeo/samplers/__init__.py b/torchgeo/samplers/__init__.py index a6f63de1917..17b63603fe9 100644 --- a/torchgeo/samplers/__init__.py +++ b/torchgeo/samplers/__init__.py @@ -5,11 +5,12 @@ from .batch import BatchGeoSampler, RandomBatchGeoSampler from .constants import Units -from .single import GeoSampler, GridGeoSampler, RandomGeoSampler +from .single import GeoSampler, GridGeoSampler, PreChippedGeoSampler, RandomGeoSampler __all__ = ( # Samplers "GridGeoSampler", + "PreChippedGeoSampler", "RandomGeoSampler", # Batch samplers "RandomBatchGeoSampler", diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 781cb5c38b1..603b27cebc7 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -240,3 +240,53 @@ def __len__(self) -> int: number of patches that will be sampled """ return self.length + + +class PreChippedGeoSampler(GeoSampler): + """Samples entire files at a time. + + This is particularly useful for datasets that contain geospatial metadata + and subclass :class:`~torchgeo.datasets.GeoDataset` but have already been + pre-processed into :term:`chips <chip>`. + + This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`. 
+ You may encounter problems when using an :term:`ROI <region of interest (ROI)>` + that partially intersects with one of the file bounding boxes, or when using an + :class:`~torchgeo.datasets.IntersectionDataset`. These issues can be solved by + adding padding. + """ + + def __init__( + self, + dataset: GeoDataset, + roi: Optional[BoundingBox] = None, + ) -> None: + """Initialize a new Sampler instance. + + Args: + dataset: dataset to index from + roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) + (defaults to the bounds of ``dataset.index``) + + .. versionadded:: 0.3 + """ + super().__init__(dataset, roi) + + def __iter__(self) -> Iterator[BoundingBox]: + """Return the index of a dataset. + + Returns: + (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset + """ + # For each tile... + for hit in self.index.intersection(tuple(self.roi), objects=True): + yield BoundingBox(*hit.bounds) + + def __len__(self) -> int: + """Return the number of samples over the ROI. + + Returns: + number of patches that will be sampled + """ + count: int = self.index.get_size() + return count From 1ee5239ef1746b400c64c6e68df3e80d07df7da7 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Wed, 23 Mar 2022 16:59:55 -0500 Subject: [PATCH 2/4] Add shuffle parameter --- tests/samplers/test_single.py | 8 +++++++- torchgeo/samplers/single.py | 20 +++++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index ae748ae5c6a..10bba9a4ab5 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -11,7 +11,13 @@ from torch.utils.data import DataLoader from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples -from torchgeo.samplers import GeoSampler, GridGeoSampler, PreChippedGeoSampler, RandomGeoSampler, Units +from torchgeo.samplers import ( + GeoSampler, + GridGeoSampler, + PreChippedGeoSampler, + RandomGeoSampler, + Units, +) class CustomGeoSampler(GeoSampler): diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 603b27cebc7..0039fa495e7 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -7,6 +7,7 @@ import random from typing import Iterator, Optional, Tuple, Union +import torch from rtree.index import Index, Property from torch.utils.data import Sampler @@ -260,6 +261,7 @@ def __init__( self, dataset: GeoDataset, roi: Optional[BoundingBox] = None, + shuffle: bool = False, ) -> None: """Initialize a new Sampler instance. @@ -267,10 +269,16 @@ def __init__( dataset: dataset to index from roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) + shuffle: if True, reshuffle data at every epoch .. versionadded:: 0.3 """ super().__init__(dataset, roi) + self.shuffle = shuffle + + self.hits = [] + for hit in self.index.intersection(tuple(self.roi), objects=True): + self.hits.append(hit) def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. 
@@ -278,9 +286,12 @@ def __iter__(self) -> Iterator[BoundingBox]: Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ - # For each tile... - for hit in self.index.intersection(tuple(self.roi), objects=True): - yield BoundingBox(*hit.bounds) + generator = range + if self.shuffle: + generator = torch.randperm + + for idx in generator(len(self)): + yield BoundingBox(*self.hits[idx].bounds) def __len__(self) -> int: """Return the number of samples over the ROI. @@ -288,5 +299,4 @@ def __len__(self) -> int: Returns: number of patches that will be sampled """ - count: int = self.index.get_size() - return count + return len(self.hits) From 5cca5f51b93d004b0dbc9322f2265de772067d9d Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 23 Mar 2022 17:28:47 -0500 Subject: [PATCH 3/4] Add tests, fix type hints --- tests/samplers/test_single.py | 2 +- torchgeo/samplers/single.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 10bba9a4ab5..e7545459dd1 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -207,7 +207,7 @@ def dataset(self) -> CustomGeoDataset: @pytest.fixture(scope="function") def sampler(self, dataset: CustomGeoDataset) -> PreChippedGeoSampler: - return PreChippedGeoSampler(dataset) + return PreChippedGeoSampler(dataset, shuffle=True) def test_iter(self, sampler: GridGeoSampler) -> None: for _ in sampler: diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 0039fa495e7..2001d7d4c68 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,7 +5,7 @@ import abc import random -from typing import Iterator, Optional, Tuple, Union +from typing import Callable, Iterable, Iterator, Optional, Tuple, Union import torch from rtree.index import Index, Property @@ -286,7 +286,7 @@ def __iter__(self) -> Iterator[BoundingBox]: Returns: (minx, maxx, miny, maxy, mint, maxt) 
coordinates to index a dataset """ - generator = range + generator: Callable[[int], Iterable[int]] = range if self.shuffle: generator = torch.randperm From 73a311668ba08fc1cfd1cde02e0f67639bd756cc Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 30 Mar 2022 13:54:13 -0500 Subject: [PATCH 4/4] Warn about multi-CRS datasets --- torchgeo/samplers/single.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 2001d7d4c68..01846fe0e1c 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -252,9 +252,9 @@ class PreChippedGeoSampler(GeoSampler): This sampler should not be used with :class:`~torchgeo.datasets.VisionDataset`. You may encounter problems when using an :term:`ROI <region of interest (ROI)>` - that partially intersects with one of the file bounding boxes, or when using an - :class:`~torchgeo.datasets.IntersectionDataset`. These issues can be solved by - adding padding. + that partially intersects with one of the file bounding boxes, when using an + :class:`~torchgeo.datasets.IntersectionDataset`, or when each file is in a + different CRS. These issues can be solved by adding padding. """ def __init__(