Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pixel sampling mode #294

Merged
merged 17 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,16 @@ def set_up_parser() -> argparse.ArgumentParser:
"--patch-size",
default=224,
type=int,
help="height/width of each patch",
metavar="SIZE",
help="height/width of each patch in pixels",
metavar="PIXELS",
)
parser.add_argument(
"-s",
"--stride",
default=112,
type=int,
help="sampling stride for GridGeoSampler",
help="sampling stride for GridGeoSampler in pixels",
metavar="PIXELS",
)
parser.add_argument(
"-w",
Expand Down Expand Up @@ -139,15 +140,11 @@ def main(args: argparse.Namespace) -> None:
length = args.num_batches * args.batch_size
num_batches = args.num_batches

# Convert from pixel coords to CRS coords
size = args.patch_size * cdl.res
stride = args.stride * cdl.res

samplers = [
RandomGeoSampler(landsat, size=size, length=length),
GridGeoSampler(landsat, size=size, stride=stride),
RandomGeoSampler(landsat, size=args.patch_size, length=length),
GridGeoSampler(landsat, size=args.patch_size, stride=args.stride),
RandomBatchGeoSampler(
landsat, size=size, batch_size=args.batch_size, length=length
landsat, size=args.patch_size, batch_size=args.batch_size, length=length
),
]

Expand Down
29 changes: 27 additions & 2 deletions docs/api/samplers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ Samplers are used to index a dataset, retrieving a single query at a time. For :
from torchgeo.samplers import RandomGeoSampler

dataset = Landsat(...)
sampler = RandomGeoSampler(dataset, size=1000, length=100)
sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, sampler=sampler)


This data loader will return 256x256 px images and has an epoch length of 10,000.

Random Geo Sampler
^^^^^^^^^^^^^^^^^^

Expand All @@ -43,10 +45,12 @@ When working with large tile-based datasets, randomly sampling patches from each
from torchgeo.samplers import RandomBatchGeoSampler

dataset = Landsat(...)
sampler = RandomBatchGeoSampler(dataset, size=1000, batch_size=10, length=100)
sampler = RandomBatchGeoSampler(dataset, size=256, batch_size=128, length=10000)
dataloader = DataLoader(dataset, batch_sampler=sampler)


This data loader will return 256x256 px images, and has a batch size of 128 and an epoch length of 10,000.

Random Batch Geo Sampler
^^^^^^^^^^^^^^^^^^^^^^^^

Expand All @@ -66,3 +70,24 @@ Batch Geo Sampler
^^^^^^^^^^^^^^^^^

.. autoclass:: BatchGeoSampler

Units
-----

By default, the ``size`` parameter specifies the size of the image in *pixel* units. If you would instead like to specify the size in *CRS* units, you can change the ``units`` parameter like so:

.. code-block:: python

from torch.utils.data import DataLoader

from torchgeo.datasets import Landsat
from torchgeo.samplers import RandomGeoSampler, Units

dataset = Landsat(...)
sampler = RandomGeoSampler(dataset, size=256 * 30, length=10000, units=Units.CRS)
dataloader = DataLoader(dataset, sampler=sampler)


Assuming a dataset resolution of 30 m per pixel in its CRS, this data loader will return 256x256 px images and has an epoch length of 10,000.

.. autoclass:: Units
74 changes: 50 additions & 24 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

import math
from itertools import product
from typing import Dict, Iterator, List

import pytest
Expand All @@ -10,7 +11,7 @@
from torch.utils.data import DataLoader

from torchgeo.datasets import BoundingBox, GeoDataset, stack_samples
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler
from torchgeo.samplers import BatchGeoSampler, RandomBatchGeoSampler, Units


class CustomBatchGeoSampler(BatchGeoSampler):
Expand All @@ -26,7 +27,7 @@ def __len__(self) -> int:


class CustomGeoDataset(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 1) -> None:
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
super().__init__()
self._crs = crs
self.res = res
Expand All @@ -36,6 +37,10 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, BoundingBox]:


class TestBatchGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
return CustomGeoDataset()

@pytest.fixture(scope="function")
def sampler(self) -> CustomBatchGeoSampler:
return CustomBatchGeoSampler()
Expand All @@ -49,28 +54,45 @@ def test_len(self, sampler: CustomBatchGeoSampler) -> None:

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: CustomBatchGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self,
dataset: CustomGeoDataset,
sampler: CustomBatchGeoSampler,
num_workers: int,
) -> None:
dl = DataLoader(
ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset,
batch_sampler=sampler,
num_workers=num_workers,
collate_fn=stack_samples,
)
for _ in dl:
continue

def test_abstract(self) -> None:
ds = CustomGeoDataset()
def test_abstract(self, dataset: CustomGeoDataset) -> None:
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
BatchGeoSampler(ds) # type: ignore[abstract]
BatchGeoSampler(dataset) # type: ignore[abstract]


class TestRandomBatchGeoSampler:
@pytest.fixture(scope="function", params=[3, 4.5, (2, 2), (3, 4.5), (4.5, 3)])
def sampler(self, request: SubRequest) -> RandomBatchGeoSampler:
@pytest.fixture(scope="class")
def dataset(self) -> CustomGeoDataset:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 20, 30, 40, 50))
ds.index.insert(1, (0, 10, 20, 30, 40, 50))
size = request.param
return RandomBatchGeoSampler(ds, size, batch_size=2, length=10)
ds.index.insert(0, (0, 100, 200, 300, 400, 500))
ds.index.insert(1, (0, 100, 200, 300, 400, 500))
return ds

@pytest.fixture(
scope="function",
params=product([3, 4.5, (2, 2), (3, 4.5), (4.5, 3)], [Units.PIXELS, Units.CRS]),
)
def sampler(
self, dataset: CustomGeoDataset, request: SubRequest
) -> RandomBatchGeoSampler:
size, units = request.param
return RandomBatchGeoSampler(
dataset, size, batch_size=2, length=10, units=units
)

def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
for batch in sampler:
Expand All @@ -88,18 +110,15 @@ def test_iter(self, sampler: RandomBatchGeoSampler) -> None:
def test_len(self, sampler: RandomBatchGeoSampler) -> None:
assert len(sampler) == sampler.length // sampler.batch_size

def test_roi(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (5, 15, 5, 15, 5, 15))
roi = BoundingBox(0, 10, 0, 10, 0, 10)
sampler = RandomBatchGeoSampler(ds, 2, 2, 10, roi=roi)
def test_roi(self, dataset: CustomGeoDataset) -> None:
roi = BoundingBox(0, 50, 200, 250, 400, 450)
sampler = RandomBatchGeoSampler(dataset, 2, 2, 10, roi=roi)
for batch in sampler:
for query in batch:
assert query in roi

def test_small_area(self) -> None:
ds = CustomGeoDataset()
ds = CustomGeoDataset(res=1)
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (20, 21, 20, 21, 20, 21))
sampler = RandomBatchGeoSampler(ds, 2, 2, 10)
Expand All @@ -108,10 +127,17 @@ def test_small_area(self) -> None:

@pytest.mark.slow
@pytest.mark.parametrize("num_workers", [0, 1, 2])
def test_dataloader(self, sampler: RandomBatchGeoSampler, num_workers: int) -> None:
ds = CustomGeoDataset()
def test_dataloader(
self,
dataset: CustomGeoDataset,
sampler: RandomBatchGeoSampler,
num_workers: int,
) -> None:
dl = DataLoader(
ds, batch_sampler=sampler, num_workers=num_workers, collate_fn=stack_samples
dataset,
batch_sampler=sampler,
num_workers=num_workers,
collate_fn=stack_samples,
)
for _ in dl:
continue
Loading