-
Notifications
You must be signed in to change notification settings - Fork 385
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d5b4a5c
commit c385433
Showing
9 changed files
with
349 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,68 @@ | ||
torchgeo.samplers | ||
================= | ||
|
||
.. automodule:: torchgeo.samplers | ||
.. module:: torchgeo.samplers | ||
|
||
Samplers | ||
-------- | ||
|
||
Samplers are used to index a dataset, retrieving a single query at a time. For :class:`~torchgeo.datasets.VisionDataset`, dataset objects can be indexed with integers, and PyTorch's built-in samplers are sufficient. For :class:`~torchgeo.datasets.GeoDataset`, dataset objects require a bounding box for indexing. For this reason, we define our own :class:`GeoSampler` implementations below. These can be used like so: | ||
|
||
.. code-block:: python

   from torch.utils.data import DataLoader
   from torchgeo.datasets import Landsat
   from torchgeo.samplers import RandomGeoSampler

   dataset = Landsat(...)
   sampler = RandomGeoSampler(dataset.index, size=1000, length=100)
   dataloader = DataLoader(dataset, sampler=sampler)

Random Geo Sampler
^^^^^^^^^^^^^^^^^^
|
||
.. autoclass:: RandomGeoSampler | ||
|
||
Grid Geo Sampler | ||
^^^^^^^^^^^^^^^^ | ||
|
||
.. autoclass:: GridGeoSampler | ||
|
||
Batch Samplers | ||
-------------- | ||
|
||
When working with large tile-based datasets, randomly sampling patches from each tile can be extremely time-consuming. It's much more efficient to choose a tile, load it, warp it to the appropriate :term:`coordinate reference system (CRS)` and resolution, and then sample random patches from that tile to construct a mini-batch of data. For this reason, we define our own :class:`BatchGeoSampler` implementations below. These can be used like so: | ||
|
||
.. code-block:: python

   from torch.utils.data import DataLoader
   from torchgeo.datasets import Landsat
   from torchgeo.samplers import RandomBatchGeoSampler

   dataset = Landsat(...)
   sampler = RandomBatchGeoSampler(dataset.index, size=1000, batch_size=10, length=100)
   dataloader = DataLoader(dataset, batch_sampler=sampler)

Random Batch Geo Sampler
^^^^^^^^^^^^^^^^^^^^^^^^
|
||
.. autoclass:: RandomBatchGeoSampler | ||
|
||
Sampler Base Classes | ||
-------------------- | ||
|
||
If you want to write your own custom sampler, you can extend one of these abstract base classes. | ||
|
||
Geo Sampler | ||
^^^^^^^^^^^ | ||
|
||
.. autoclass:: GeoSampler | ||
|
||
Batch Geo Sampler | ||
^^^^^^^^^^^^^^^^^ | ||
|
||
.. autoclass:: BatchGeoSampler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import math | ||
from typing import Iterator, List | ||
|
||
import pytest | ||
from _pytest.fixtures import SubRequest | ||
from rtree.index import Index, Property | ||
|
||
from torchgeo.datasets import BoundingBox | ||
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler | ||
|
||
|
||
class CustomBatchGeoSampler(BatchGeoSampler):
    """Minimal concrete :class:`BatchGeoSampler` used to exercise the base class.

    Yields ``len(self)`` batches, each containing ``len(self)`` degenerate
    bounding boxes whose six coordinates all equal their position in the batch.
    """

    def __init__(self) -> None:
        """Create the sampler; it holds no state and takes no arguments."""

    def __iter__(self) -> Iterator[List[BoundingBox]]:
        """Yield one list of bounding boxes per batch.

        Returns:
            batches of the form ``[BoundingBox(0, ...), BoundingBox(1, ...), ...]``
        """
        for _ in range(len(self)):
            yield [
                BoundingBox(coord, coord, coord, coord, coord, coord)
                for coord in range(len(self))
            ]

    def __len__(self) -> int:
        """Return the fixed number of batches (and boxes per batch).

        Returns:
            always 2
        """
        return 2
|
||
|
||
class TestBatchGeoSampler:
    """Tests for the abstract :class:`BatchGeoSampler` via a trivial subclass."""

    @pytest.fixture(scope="function")
    def sampler(self) -> CustomBatchGeoSampler:
        """Provide a fresh ``CustomBatchGeoSampler`` for each test."""
        return CustomBatchGeoSampler()

    def test_iter(self, sampler: CustomBatchGeoSampler) -> None:
        """The first batch holds the two degenerate boxes for positions 0 and 1."""
        first_batch = next(iter(sampler))
        assert first_batch == [
            BoundingBox(0, 0, 0, 0, 0, 0),
            BoundingBox(1, 1, 1, 1, 1, 1),
        ]

    def test_len(self, sampler: CustomBatchGeoSampler) -> None:
        """The custom sampler reports a length of two."""
        num_batches = len(sampler)
        assert num_batches == 2

    def test_abstract(self) -> None:
        """The base class itself cannot be instantiated."""
        with pytest.raises(TypeError, match="Can't instantiate abstract class"):
            BatchGeoSampler(None)  # type: ignore[abstract]
|
||
|
||
class TestRandomBatchGeoSampler:
    """Tests for :class:`RandomBatchGeoSampler` over several patch sizes."""

    @pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
    def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
        """Build a sampler over a small 3-D index for each parametrized size.

        The index holds two entries with identical bounds; each fixture param is
        passed as the sampler's ``size`` (scalar or ``(height, width)`` tuple).
        """
        index = Index(interleaved=False, properties=Property(dimension=3))
        index.insert(0, (0, 10, 20, 30, 40, 50))
        index.insert(1, (0, 10, 20, 30, 40, 50))
        size = request.param
        return RandomBatchGeoSampler(index, size, batch_size=2, length=10)

    def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
        """Every sampled query lies inside the ROI and has the requested size."""
        for batch in sampler:
            for query in batch:
                # Each axis must be well-ordered and contained in the ROI.
                assert sampler.roi.minx <= query.minx <= query.maxx <= sampler.roi.maxx
                # Compare against query.maxy (not query.miny twice) so the upper
                # y bound of the patch is actually checked against the ROI.
                assert sampler.roi.miny <= query.miny <= query.maxy <= sampler.roi.maxy
                assert sampler.roi.mint <= query.mint <= query.maxt <= sampler.roi.maxt

                # size is (height, width): index 1 spans x, index 0 spans y.
                assert math.isclose(query.maxx - query.minx, sampler.size[1])
                assert math.isclose(query.maxy - query.miny, sampler.size[0])
                # The full time extent of the ROI is expected for every patch.
                assert math.isclose(
                    query.maxt - query.mint, sampler.roi.maxt - sampler.roi.mint
                )

    def test_len(self, sampler: RandomBatchGeoSampler) -> None:
        """Length is the number of whole batches that fit in ``length`` samples."""
        assert len(sampler) == sampler.length // sampler.batch_size
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
"""TorchGeo batch samplers.""" | ||
|
||
import abc | ||
import random | ||
from typing import Iterator, List, Optional, Tuple, Union | ||
|
||
from rtree.index import Index | ||
from torch.utils.data import Sampler | ||
|
||
from torchgeo.datasets import BoundingBox | ||
|
||
from .utils import _to_tuple | ||
|
||
# https://github.com/pytorch/pytorch/issues/60979 | ||
# https://github.com/pytorch/pytorch/pull/61045 | ||
Sampler.__module__ = "torch.utils.data" | ||
|
||
|
||
class BatchGeoSampler(Sampler[List[BoundingBox]], abc.ABC):
    """Abstract base class for sampling from :class:`~torchgeo.datasets.GeoDataset`.

    Unlike PyTorch's :class:`~torch.utils.data.BatchSampler`,
    :class:`BatchGeoSampler` returns enough geospatial information to uniquely
    index any :class:`~torchgeo.datasets.GeoDataset`. This includes things like
    latitude, longitude, height, width, projection, coordinate system, and time.

    Subclasses must implement :meth:`__iter__`.
    """

    @abc.abstractmethod
    def __iter__(self) -> Iterator[List[BoundingBox]]:
        """Return a batch of indices of a dataset.

        Returns:
            batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a
            dataset
        """
|
||
|
||
class RandomBatchGeoSampler(BatchGeoSampler):
    """Samples batches of elements from a region of interest randomly.

    This is particularly useful during training when you want to maximize the
    size of the dataset and return as many random :term:`chips <chip>` as
    possible.

    When sampling from :class:`~torchgeo.datasets.ZipDataset`, the ``index``
    should come from a tile-based dataset if possible.
    """

    def __init__(
        self,
        index: Index,
        size: Union[Tuple[float, float], float],
        batch_size: int,
        length: int,
        roi: Optional[BoundingBox] = None,
    ) -> None:
        """Initialize a new Sampler instance.

        The ``size`` argument can either be:

        * a single ``float`` - in which case the same value is used for the
          height and width dimension
        * a ``tuple`` of two floats - in which case, the first *float* is used
          for the height dimension, and the second *float* for the width
          dimension

        Args:
            index: index of a :class:`~torchgeo.datasets.GeoDataset`
            size: dimensions of each :term:`patch` in units of CRS
            batch_size: number of samples per batch
            length: number of samples per epoch
            roi: region of interest to sample from
                (minx, maxx, miny, maxy, mint, maxt)
                (defaults to the bounds of ``index``)
        """
        self.index = index
        self.size = _to_tuple(size)
        self.batch_size, self.length = batch_size, length
        # Without an explicit region, sample from everything the index covers.
        self.roi = BoundingBox(*index.bounds) if roi is None else roi
        # Cache the index entries overlapping the region; one is drawn per batch.
        self.hits = list(index.intersection(self.roi, objects=True))

    def __iter__(self) -> Iterator[List[BoundingBox]]:
        """Return batches of bounding boxes, each drawn from one random tile.

        Returns:
            batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a
            dataset
        """
        for _batch in range(len(self)):
            # All patches of a batch come from the same randomly chosen tile.
            tile = BoundingBox(*random.choice(self.hits).bounds)
            patch_height, patch_width = self.size[0], self.size[1]

            # NOTE(review): patches are positioned within the tile's own bounds,
            # not the tile/ROI intersection - confirm this is intended when a
            # custom ``roi`` is smaller than the tiles.
            chips: List[BoundingBox] = []
            for _sample in range(self.batch_size):
                left = random.uniform(tile.minx, tile.maxx - patch_width)
                bottom = random.uniform(tile.miny, tile.maxy - patch_height)
                chips.append(
                    BoundingBox(
                        left,
                        left + patch_width,
                        bottom,
                        bottom + patch_height,
                        # The time extent is copied from the tile unchanged.
                        tile.mint,
                        tile.maxt,
                    )
                )

            yield chips

    def __len__(self) -> int:
        """Return the number of batches in a single epoch.

        Returns:
            number of batches in an epoch
        """
        # Floor division: a trailing partial batch is not counted.
        batches_per_epoch = self.length // self.batch_size
        return batches_per_epoch
Oops, something went wrong.