Skip to content

Commit

Permalink
More intelligent sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Aug 16, 2021
1 parent d5b4a5c commit c385433
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 71 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,6 @@
"python": ("https://docs.python.org/3", None),
"pytorch-lightning": ("https://pytorch-lightning.readthedocs.io/en/latest/", None),
"rasterio": ("https://rasterio.readthedocs.io/en/latest/", None),
"rtree": ("https://rtree.readthedocs.io/en/latest/", None),
"torch": ("https://pytorch.org/docs/stable", None),
}
6 changes: 3 additions & 3 deletions docs/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ torchgeo.datasets

.. module:: torchgeo.datasets

In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets` and :ref:`Non-geospatial Datasets`. These abstract base classes are documented in more detail in :ref:`Base Classes`.
In :mod:`torchgeo`, we define two types of datasets: :ref:`Geospatial Datasets` and :ref:`Non-geospatial Datasets`. These abstract base classes are documented in more detail in :ref:`Dataset Base Classes`.

Geospatial Datasets
-------------------
Expand Down Expand Up @@ -107,8 +107,8 @@ NWPU VHR-10

.. autoclass:: VHR10

Base Classes
------------
Dataset Base Classes
--------------------

If you want to write your own custom dataset, you can extend one of these abstract base classes.

Expand Down
66 changes: 65 additions & 1 deletion docs/samplers.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,68 @@
torchgeo.samplers
=================

.. automodule:: torchgeo.samplers
.. module:: torchgeo.samplers

Samplers
--------

Samplers are used to index a dataset, retrieving a single query at a time. For :class:`~torchgeo.datasets.VisionDataset`, dataset objects can be indexed with integers, and PyTorch's builtin samplers are sufficient. For :class:`~torchgeo.datasets.GeoDataset`, dataset objects require a bounding box for indexing. For this reason, we define our own :class:`GeoSampler` implementations below. These can be used like so:

.. code-block:: python
from torch.utils.data import DataLoader
from torchgeo.datasets import Landsat
from torchgeo.samplers import RandomGeoSampler
dataset = Landsat(...)
sampler = RandomGeoSampler(dataset.index, size=1000, length=100)
dataloader = DataLoader(dataset, sampler=sampler)
Random Geo Sampler
^^^^^^^^^^^^^^^^^^

.. autoclass:: RandomGeoSampler

Grid Geo Sampler
^^^^^^^^^^^^^^^^

.. autoclass:: GridGeoSampler

Batch Samplers
--------------

When working with large tile-based datasets, randomly sampling patches from each tile can be extremely time consuming. It's much more efficient to choose a tile, load it, warp it to the appropriate :term:`coordinate reference system (CRS)` and resolution, and then sample random patches from that tile to construct a mini-batch of data. For this reason, we define our own :class:`BatchGeoSampler` implementations below. These can be used like so:

.. code-block:: python
from torch.utils.data import DataLoader
from torchgeo.datasets import Landsat
from torchgeo.samplers import RandomBatchGeoSampler
dataset = Landsat(...)
sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100)
dataloader = DataLoader(dataset, batch_sampler=sampler)
Random Batch Geo Sampler
^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: RandomBatchGeoSampler

Sampler Base Classes
--------------------

If you want to write your own custom sampler, you can extend one of these abstract base classes.

Geo Sampler
^^^^^^^^^^^

.. autoclass:: GeoSampler

Batch Geo Sampler
^^^^^^^^^^^^^^^^^

.. autoclass:: BatchGeoSampler
64 changes: 64 additions & 0 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import math
from typing import Iterator, List

import pytest
from _pytest.fixtures import SubRequest
from rtree.index import Index, Property

from torchgeo.datasets import BoundingBox
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler


class CustomBatchGeoSampler(BatchGeoSampler):
def __init__(self) -> None:
pass

def __iter__(self) -> Iterator[List[BoundingBox]]:
for i in range(len(self)):
yield [BoundingBox(j, j, j, j, j, j) for j in range(len(self))]

def __len__(self) -> int:
return 2


class TestBatchGeoSampler:
@pytest.fixture(scope="function")
def sampler(self) -> CustomBatchGeoSampler:
return CustomBatchGeoSampler()

def test_iter(self, sampler: CustomBatchGeoSampler) -> None:
expected = [BoundingBox(0, 0, 0, 0, 0, 0), BoundingBox(1, 1, 1, 1, 1, 1)]
assert next(iter(sampler)) == expected

def test_len(self, sampler: CustomBatchGeoSampler) -> None:
assert len(sampler) == 2

def test_abstract(self) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
BatchGeoSampler(None) # type: ignore[abstract]


class TestRandomBatchGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomBatchGeoSampler(index, size, batch_size=2, length=10)

def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
for batch in sampler:
for query in batch:
assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx
assert sampler.roi.miny <= query.miny <= query.miny <= sampler.roi.maxy
assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt

assert math.isclose(query.maxx - query.minx, sampler.size[1])
assert math.isclose(query.maxy - query.miny, sampler.size[0])
assert math.isclose(
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)

def test_len(self, sampler: RandomBatchGeoSampler) -> None:
assert len(sampler) == sampler.length // sampler.batch_size
22 changes: 12 additions & 10 deletions tests/samplers/test_samplers.py → tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest
from _pytest.fixtures import SubRequest
from rtree.index import Index, Property

from torchgeo.datasets import BoundingBox
from torchgeo.samplers import GeoSampler, GridGeoSampler, RandomGeoSampler
Expand All @@ -22,13 +23,13 @@ def __len__(self) -> int:

class TestGeoSampler:
@pytest.fixture(scope="function")
def sampler(self) -> GeoSampler:
def sampler(self) -> CustomGeoSampler:
return CustomGeoSampler()

def test_iter(self, sampler: GeoSampler) -> None:
def test_iter(self, sampler: CustomGeoSampler) -> None:
assert next(iter(sampler)) == BoundingBox(0, 0, 0, 0, 0, 0)

def test_len(self, sampler: GeoSampler) -> None:
def test_len(self, sampler: CustomGeoSampler) -> None:
assert len(sampler) == 2

def test_abstract(self) -> None:
Expand All @@ -39,9 +40,11 @@ def test_abstract(self) -> None:
class TestRandomGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomGeoSampler:
roi = BoundingBox(0, 10, 20, 30, 40, 50)
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomGeoSampler(roi, size, length=10)
return RandomGeoSampler(index, size, length=10)

def test_iter(self, sampler: RandomGeoSampler) -> None:
for query in sampler:
Expand Down Expand Up @@ -72,9 +75,11 @@ class TestGridGeoSampler:
],
)
def sampler(self, request: SubRequest) -> GridGeoSampler:
roi = BoundingBox(0, 10, 20, 30, 40, 50)
index = Index(interleaved=False, properties=Property(dimension=3))
index.insert(0, (0, 10, 20, 30, 40, 50))
index.insert(1, (0, 10, 20, 30, 40, 50))
size, stride = request.param
return GridGeoSampler(roi, size, stride)
return GridGeoSampler(index, size, stride)

def test_iter(self, sampler: GridGeoSampler) -> None:
for query in sampler:
Expand All @@ -87,6 +92,3 @@ def test_iter(self, sampler: GridGeoSampler) -> None:
assert math.isclose(
query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
)

def test_len(self, sampler: RandomGeoSampler) -> None:
assert len(sampler) == 9
14 changes: 12 additions & 2 deletions torchgeo/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
"""TorchGeo samplers."""

from .samplers import GeoSampler, GridGeoSampler, RandomGeoSampler
from .batch import BatchGeoSampler, RandomBatchGeoSampler
from .single import GeoSampler, GridGeoSampler, RandomGeoSampler

__all__ = ("GeoSampler", "GridGeoSampler", "RandomGeoSampler")
__all__ = (
# Samplers
"GridGeoSampler",
"RandomGeoSampler",
# Batch samplers
"RandomBatchGeoSampler",
# Base classes
"GeoSampler",
"BatchGeoSampler",
)

# https://stackoverflow.com/questions/40018681
for module in __all__:
Expand Down
114 changes: 114 additions & 0 deletions torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""TorchGeo batch samplers."""

import abc
import random
from typing import Iterator, List, Optional, Tuple, Union

from rtree.index import Index
from torch.utils.data import Sampler

from torchgeo.datasets import BoundingBox

from .utils import _to_tuple

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Sampler.__module__ = "torch.utils.data"


class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC):
"""Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.
Unlike PyTorch's :class:`~torch.utils.data.BatchSampler`, :class:`BatchGeoSampler`
returns enough geospatial information to uniquely index any
:class:`~torchgeo.datasets.GeoDataset`. This includes things like latitude,
longitude, height, width, projection, coordinate system, and time.
"""

@abc.abstractmethod
def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return a batch of indices of a dataset.
Returns:
batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""


class RandomBatchGeoSampler(BatchGeoSampler):
"""Samples batches of elements from a region of interest randomly.
This is particularly useful during training when you want to maximize the size of
the dataset and return as many random :term:`chips <chip>` as possible.
When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index`` should come
from a tile-based dataset if possible.
"""

def __init__(
self,
index: Index,
size: Union[Tuple[float, float], float],
batch_size: int,
length: int,
roi: Optional[BoundingBox] = None,
) -> None:
"""Initialize a new Sampler instance.
The ``size`` argument can either be:
* a single ``float`` - in which case the same value is used for the height and
width dimension
* a ``tuple`` of two floats - in which case, the first *float* is used for the
height dimension, and the second *float* for the width dimension
Args:
index: index of a :class:`~torchgeo.datasets.GeoDataset`
size: dimensions of each :term:`patch` in units of CRS
batch_size: number of samples per batch
length: number of samples per epoch
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``index``)
"""
self.index = index
self.size = _to_tuple(size)
self.batch_size = batch_size
self.length = length
if roi is None:
roi = BoundingBox(*index.bounds)
self.roi = roi
self.hits = list(index.intersection(roi, objects=True))

def __iter__(self) -> Iterator[List[BoundingBox]]:
"""Return the indices of a dataset.
Returns:
batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset
"""
for _ in range(len(self)):
# Choose a random tile
hit = random.choice(self.hits)
bounds = BoundingBox(*hit.bounds)

# Choose random indices within that tile
batch = []
for _ in range(self.batch_size):
minx = random.uniform(bounds.minx, bounds.maxx - self.size[1])
maxx = minx + self.size[1]

miny = random.uniform(bounds.miny, bounds.maxy - self.size[0])
maxy = miny + self.size[0]

mint = bounds.mint
maxt = bounds.maxt

batch.append(BoundingBox(minx, maxx, miny, maxy, mint, maxt))

yield batch

def __len__(self) -> int:
"""Return the number of batches in a single epoch.
Returns:
number of batches in an epoch
"""
return self.length // self.batch_size
Loading

0 comments on commit c385433

Please sign in to comment.