diff --git a/docs/conf.py b/docs/conf.py index bd6d78a7525..91426900394 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -94,5 +94,6 @@ "python": ("https://docs.python.org/3", None), "pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None), "rasterio": ("https://rasterio.readthedocs.io/en/latest/", None), + "rtree": ("https://rtree.readthedocs.io/en/latest/", None), "torch": ("https://pytorch.org/docs/stable", None), } diff --git a/docs/datasets.rst b/docs/datasets.rst index 23322bbca17..73de012f4da 100644 --- a/docs/datasets.rst +++ b/docs/datasets.rst @@ -3,7 +3,7 @@ torchgeo.datasets .. module:: torchgeo.datasets -In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets` and :ref:`Non-geospatial Datasets`. These abstract base classes are documented in more detail in :ref:`Base Classes`. +In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets` and :ref:`Non-geospatial Datasets`. These abstract base classes are documented in more detail in :ref:`Dataset Base Classes`. Geospatial Datasets ------------------- @@ -107,8 +107,8 @@ NWPU VHR-10 .. autoclass:: VHR10 -Base Classes ------------- +Dataset Base Classes +-------------------- If you want to write your own custom dataset, you can extend one of these abstract base classes. diff --git a/docs/samplers.rst b/docs/samplers.rst index 8675f090993..ada9c457555 100644 --- a/docs/samplers.rst +++ b/docs/samplers.rst @@ -1,4 +1,68 @@ torchgeo.samplers ================= -.. automodule:: torchgeo.samplers +.. module:: torchgeo.samplers + +Samplers +-------- + +Samplers are used to index a dataset, retrieving a single query at a time. For :class:`~torchgeo.datasets.VisionDataset`, dataset objects can be indexed with integers, and PyTorch's builtin samplers are sufficient. For :class:`~torchgeo.datasets.GeoDataset`, dataset objects require a bounding box for indexing. For this reason, we define our own :class:`GeoSampler` implementations below. These can be used like so: + +.. code-block:: python + + from torch.utils.data import DataLoader + + from torchgeo.datasets import Landsat + from torchgeo.samplers import RandomGeoSampler + + dataset = Landsat(...) + sampler = RandomGeoSampler(dataset.index, size=1000, length=100) + dataloader = DataLoader(dataset, sampler=sampler) + + +Random Geo Sampler +^^^^^^^^^^^^^^^^^^ + +.. autoclass:: RandomGeoSampler + +Grid Geo Sampler +^^^^^^^^^^^^^^^^ + +.. autoclass:: GridGeoSampler + +Batch Samplers +-------------- + +When working with large tile-based datasets, randomly sampling patches from each tile can be extremely time consuming. It's much more efficient to choose a tile, load it, warp it to the appropriate :term:`coordinate reference system (CRS)` and resolution, and then sample random patches from that tile to construct a mini-batch of data. For this reason, we define our own :class:`BatchGeoSampler` implementations below. These can be used like so: + +.. code-block:: python + + from torch.utils.data import DataLoader + + from torchgeo.datasets import Landsat + from torchgeo.samplers import RandomBatchGeoSampler + + dataset = Landsat(...) + sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100) + dataloader = DataLoader(dataset, batch_sampler=sampler) + + +Random Batch Geo Sampler +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: RandomBatchGeoSampler + +Sampler Base Classes +-------------------- + +If you want to write your own custom sampler, you can extend one of these abstract base classes. + +Geo Sampler +^^^^^^^^^^^ + +.. autoclass:: GeoSampler + +Batch Geo Sampler +^^^^^^^^^^^^^^^^^ + +.. autoclass:: BatchGeoSampler diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py new file mode 100644 index 00000000000..dd4866f7f5b --- /dev/null +++ b/tests/samplers/test_batch.py @@ -0,0 +1,64 @@ +import math +from typing import Iterator, List + +import pytest +from _pytest.fixtures import SubRequest +from rtree.index import Index, Property + +from torchgeo.datasets import BoundingBox +from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler + + +class CustomBatchGeoSampler(BatchGeoSampler): + def __init__(self) -> None: + pass + + def __iter__(self) -> Iterator[List[BoundingBox]]: + for i in range(len(self)): + yield [BoundingBox(j, j, j, j, j, j) for j in range(len(self))] + + def __len__(self) -> int: + return 2 + + +class TestBatchGeoSampler: + @pytest.fixture(scope="function") + def sampler(self) -> CustomBatchGeoSampler: + return CustomBatchGeoSampler() + + def test_iter(self, sampler: CustomBatchGeoSampler) -> None: + expected = [BoundingBox(0, 0, 0, 0, 0, 0), BoundingBox(1, 1, 1, 1, 1, 1)] + assert next(iter(sampler)) == expected + + def test_len(self, sampler: CustomBatchGeoSampler) -> None: + assert len(sampler) == 2 + + def test_abstract(self) -> None: + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + BatchGeoSampler(None) # type: ignore[abstract] + + +class TestRandomBatchGeoSampler: + @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)]) + def sampler(self, request: SubRequest) -> RandomBatchGeoSampler: + index = Index(interleaved=False, properties=Property(dimension=3)) + index.insert(0, (0, 10, 20, 30, 40, 50)) + index.insert(1, (0, 10, 20, 30, 40, 50)) + size = request.param + return RandomBatchGeoSampler(index, size, batch_size=2, length=10) + + def test_iter(self, sampler: RandomBatchGeoSampler) -> None: + for batch in sampler: + for query in batch: + assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx + assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy + assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt + + assert math.isclose(query.maxx - query.minx, sampler.size[1]) + assert math.isclose(query.maxy - query.miny, sampler.size[0]) + assert math.isclose( + query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint + ) + + def test_len(self, sampler: RandomBatchGeoSampler) -> None: + assert len(sampler) == sampler.length // sampler.batch_size diff --git a/tests/samplers/test_samplers.py b/tests/samplers/test_single.py similarity index 80% rename from tests/samplers/test_samplers.py rename to tests/samplers/test_single.py index e3d492bbdd0..0fbae8dbd3d 100644 --- a/tests/samplers/test_samplers.py +++ b/tests/samplers/test_single.py @@ -3,6 +3,7 @@ import pytest from _pytest.fixtures import SubRequest +from rtree.index import Index, Property from torchgeo.datasets import BoundingBox from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler @@ -22,13 +23,13 @@ def __len__(self) -> int: class TestGeoSampler: @pytest.fixture(scope="function") - def sampler(self) -> GeoSampler: + def sampler(self) -> CustomGeoSampler: return CustomGeoSampler() - def test_iter(self, sampler: GeoSampler) -> None: + def test_iter(self, sampler: CustomGeoSampler) -> None: assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0) - def test_len(self, sampler: GeoSampler) -> None: + def test_len(self, sampler: CustomGeoSampler) -> None: assert len(sampler) == 2 def test_abstract(self) -> None: @@ -39,9 +40,11 @@ def test_abstract(self) -> None: class TestRandomGeoSampler: @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)]) def sampler(self, request: SubRequest) -> RandomGeoSampler: - roi = BoundingBox(0, 10, 20, 30, 40, 50) + index = Index(interleaved=False, properties=Property(dimension=3)) + index.insert(0, (0, 10, 20, 30, 40, 50)) + index.insert(1, (0, 10, 20, 30, 40, 50)) size = request.param - return RandomGeoSampler(roi, size, length=10) + return RandomGeoSampler(index, size, length=10) def test_iter(self, sampler: RandomGeoSampler) -> None: for query in sampler: @@ -72,9 +75,11 @@ class TestGridGeoSampler: ], ) def sampler(self, request: SubRequest) -> GridGeoSampler: - roi = BoundingBox(0, 10, 20, 30, 40, 50) + index = Index(interleaved=False, properties=Property(dimension=3)) + index.insert(0, (0, 10, 20, 30, 40, 50)) + index.insert(1, (0, 10, 20, 30, 40, 50)) size, stride = request.param - return GridGeoSampler(roi, size, stride) + return GridGeoSampler(index, size, stride) def test_iter(self, sampler: GridGeoSampler) -> None: for query in sampler: @@ -87,6 +92,3 @@ def test_iter(self, sampler: GridGeoSampler) -> None: assert math.isclose( query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint ) - - def test_len(self, sampler: RandomGeoSampler) -> None: - assert len(sampler) == 9 diff --git a/torchgeo/samplers/__init__.py b/torchgeo/samplers/__init__.py index 85d3d929a30..2e67cbae19f 100644 --- a/torchgeo/samplers/__init__.py +++ b/torchgeo/samplers/__init__.py @@ -1,8 +1,18 @@ """TorchGeo samplers.""" -from .samplers import GeoSampler, GridGeoSampler, RandomGeoSampler +from .batch import BatchGeoSampler, RandomBatchGeoSampler +from .single import GeoSampler, GridGeoSampler, RandomGeoSampler -__all__ = ("GeoSampler", "GridGeoSampler", "RandomGeoSampler") +__all__ = ( + # Samplers + "GridGeoSampler", + "RandomGeoSampler", + # Batch samplers + "RandomBatchGeoSampler", + # Base classes + "GeoSampler", + "BatchGeoSampler", +) # https://stackoverflow.com/questions/40018681 for module in __all__: diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py new file mode 100644 index 00000000000..9423fb44ef5 --- /dev/null +++ b/torchgeo/samplers/batch.py @@ -0,0 +1,114 @@ +"""TorchGeo batch samplers.""" + +import abc +import random +from typing import Iterator, List, Optional, Tuple, Union + +from rtree.index import Index +from torch.utils.data import Sampler + +from torchgeo.datasets import BoundingBox + +from .utils import _to_tuple + +# https://github.com/pytorch/pytorch/issues/60979 +# https://github.com/pytorch/pytorch/pull/61045 +Sampler.__module__ = "torch.utils.data" + + +class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC): + """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`. + + Unlike PyTorch's :class:`~torch.utils.data.BatchSampler`, :class:`BatchGeoSampler` + returns enough geospatial information to uniquely index any + :class:`~torchgeo.datasets.GeoDataset`. This includes things like latitude, + longitude, height, width, projection, coordinate system, and time. + """ + + @abc.abstractmethod + def __iter__(self) -> Iterator[List[BoundingBox]]: + """Return a batch of indices of a dataset. + + Returns: + batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset + """ + + +class RandomBatchGeoSampler(BatchGeoSampler): + """Samples batches of elements from a region of interest randomly. + + This is particularly useful during training when you want to maximize the size of + the dataset and return as many random :term:`chips ` as possible. + + When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come + from a tile-based dataset if possible. + """ + + def __init__( + self, + index: Index, + size: Union[Tuple[float, float], float], + batch_size: int, + length: int, + roi: Optional[BoundingBox] = None, + ) -> None: + """Initialize a new Sampler instance. + + The ``size`` argument can either be: + + * a single ``float`` - in which case the same value is used for the height and + width dimension + * a ``tuple`` of two floats - in which case, the first *float* is used for the + height dimension, and the second *float* for the width dimension + + Args: + index: index of a :class:`~torchgeo.datasets.GeoDataset` + size: dimensions of each :term:`patch` in units of CRS + batch_size: number of samples per batch + length: number of samples per epoch + roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) + (defaults to the bounds of ``index``) + """ + self.index = index + self.size = _to_tuple(size) + self.batch_size = batch_size + self.length = length + if roi is None: + roi = BoundingBox(*index.bounds) + self.roi = roi + self.hits = list(index.intersection(roi, objects=True)) + + def __iter__(self) -> Iterator[List[BoundingBox]]: + """Return the indices of a dataset. + + Returns: + batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset + """ + for _ in range(len(self)): + # Choose a random tile + hit = random.choice(self.hits) + bounds = BoundingBox(*hit.bounds) + + # Choose random indices within that tile + batch = [] + for _ in range(self.batch_size): + minx = random.uniform(bounds.minx, bounds.maxx - self.size[1]) + maxx = minx + self.size[1] + + miny = random.uniform(bounds.miny, bounds.maxy - self.size[0]) + maxy = miny + self.size[0] + + mint = bounds.mint + maxt = bounds.maxt + + batch.append(BoundingBox(minx, maxx, miny, maxy, mint, maxt)) + + yield batch + + def __len__(self) -> int: + """Return the number of batches in a single epoch. + + Returns: + number of batches in an epoch + """ + return self.length // self.batch_size diff --git a/torchgeo/samplers/samplers.py b/torchgeo/samplers/single.py similarity index 63% rename from torchgeo/samplers/samplers.py rename to torchgeo/samplers/single.py index fd5f71ac0dd..90b6670740d 100644 --- a/torchgeo/samplers/samplers.py +++ b/torchgeo/samplers/single.py @@ -2,33 +2,21 @@ import abc import random -from typing import Any, Iterator, Tuple, Union +from typing import Iterator, Optional, Tuple, Union +from rtree.index import Index from torch.utils.data import Sampler from torchgeo.datasets import BoundingBox +from .utils import _to_tuple + # https://github.com/pytorch/pytorch/issues/60979 # https://github.com/pytorch/pytorch/pull/61045 Sampler.__module__ = "torch.utils.data" -def _to_tuple(value: Union[Tuple[Any, Any], Any]) -> Tuple[Any, Any]: - """Convert value to a tuple if it is not already a tuple. - - Args: - value: input value - - Returns: - value if value is a tuple, else (value, value) - """ - if isinstance(value, (float, int)): - return (value, value) - else: - return value - - -class GeoSampler(Sampler[Tuple[Any, ...]], abc.ABC): +class GeoSampler(Sampler[BoundingBox], abc.ABC): """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`. Unlike PyTorch's :class:`~torch.utils.data.Sampler`, :class:`GeoSampler` @@ -45,26 +33,25 @@ def __iter__(self) -> Iterator[BoundingBox]: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ - @abc.abstractmethod - def __len__(self) -> int: - """Return the number of samples in a single epoch. - - Returns: - length of the epoch - """ - class RandomGeoSampler(GeoSampler): """Samples elements from a region of interest randomly. This is particularly useful during training when you want to maximize the size of the dataset and return as many random :term:`chips ` as possible. + + This sampler is not recommended for use with tile-based datasets. Use + :class:`RandomBatchGeoSampler` instead. """ def __init__( - self, roi: BoundingBox, size: Union[Tuple[float, float], float], length: int + self, + index: Index, + size: Union[Tuple[float, float], float], + length: int, + roi: Optional[BoundingBox] = None, ) -> None: - """Initialize a new RandomGeoSampler. + """Initialize a new Sampler instance. The ``size`` argument can either be: @@ -74,13 +61,19 @@ def __init__( height dimension, and the second *float* for the width dimension Args: - roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) + index: index of a :class:`~torchgeo.datasets.GeoDataset` size: dimensions of each :term:`patch` in units of CRS length: number of random samples to draw per epoch + roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) + (defaults to the bounds of ``index``) """ - self.roi = roi + self.index = index self.size = _to_tuple(size) self.length = length + if roi is None: + roi = BoundingBox(*index.bounds) + self.roi = roi + self.hits = list(index.intersection(roi, objects=True)) def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. @@ -89,15 +82,19 @@ def __iter__(self) -> Iterator[BoundingBox]: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ for _ in range(len(self)): - minx = random.uniform(self.roi.minx, self.roi.maxx - self.size[1]) + # Choose a random tile + hit = random.choice(self.hits) + bounds = BoundingBox(*hit.bounds) + + # Choose a random index within that tile + minx = random.uniform(bounds.minx, bounds.maxx - self.size[1]) maxx = minx + self.size[1] - miny = random.uniform(self.roi.miny, self.roi.maxy - self.size[0]) + miny = random.uniform(bounds.miny, bounds.maxy - self.size[0]) maxy = miny + self.size[0] - # TODO: figure out how to handle time - mint = self.roi.mint - maxt = self.roi.maxt + mint = bounds.mint + maxt = bounds.maxt yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) @@ -123,15 +120,19 @@ class GridGeoSampler(GeoSampler): The overlap between each chip (``chip_size - stride``) should be approximately equal to the `receptive field `_ of the CNN. + + When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come + from a non-tile-based dataset if possible. """ def __init__( self, - roi: BoundingBox, + index: Index, size: Union[Tuple[float, float], float], stride: Union[Tuple[float, float], float], + roi: Optional[BoundingBox] = None, ) -> None: - """Initialize a new RandomGeoSampler. + """Initialize a new Sampler instance. The ``size`` and ``stride`` arguments can either be: @@ -141,15 +142,18 @@ def __init__( height dimension, and the second *float* for the width dimension Args: + index: index of a :class:`~torchgeo.datasets.GeoDataset` roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) size: dimensions of each :term:`patch` in units of CRS stride: distance to skip between each patch """ - self.roi = roi + self.index = index self.size = _to_tuple(size) self.stride = _to_tuple(stride) - self.rows = int((roi.maxy - roi.miny - self.size[0]) // self.stride[0]) + 1 - self.cols = int((roi.maxx - roi.minx - self.size[1]) // self.stride[1]) + 1 + if roi is None: + roi = BoundingBox(*index.bounds) + self.roi = roi + self.hits = index.intersection(roi, objects=True) def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. @@ -157,23 +161,24 @@ def __iter__(self) -> Iterator[BoundingBox]: Returns: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ - for i in range(self.rows): - miny = self.roi.miny + i * self.stride[0] - maxy = miny + self.size[0] - for j in range(self.cols): - minx = self.roi.minx + j * self.stride[1] - maxx = minx + self.size[1] + # For each tile... + for hit in self.hits: + bounds = BoundingBox(*hit.bounds) - # TODO: figure out how to handle time - mint = self.roi.mint - maxt = self.roi.maxt + rows = int((bounds.maxy - bounds.miny - self.size[0]) // self.stride[0]) + 1 + cols = int((bounds.maxx - bounds.minx - self.size[1]) // self.stride[1]) + 1 - yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) + mint = bounds.mint + maxt = bounds.maxt - def __len__(self) -> int: - """Return the number of samples in a single epoch. + # For each row... + for i in range(rows): + miny = bounds.miny + i * self.stride[0] + maxy = miny + self.size[0] - Returns: - length of the epoch - """ - return self.rows * self.cols + # For each column... + for j in range(cols): + minx = bounds.minx + j * self.stride[1] + maxx = minx + self.size[1] + + yield BoundingBox(minx, maxx, miny, maxy, mint, maxt) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py new file mode 100644 index 00000000000..25022e656aa --- /dev/null +++ b/torchgeo/samplers/utils.py @@ -0,0 +1,18 @@ +"""Common sampler utilities.""" + +from typing import Tuple, Union + + +def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: + """Convert value to a tuple if it is not already a tuple. + + Args: + value: input value + + Returns: + value if value is a tuple, else (value, value) + """ + if isinstance(value, (float, int)): + return (value, value) + else: + return value