Skip to content

Commit

Permalink
IntersectionDataset: better error message when no overlap (#1192)
Browse files Browse the repository at this point in the history
* IntersectionDataset: better error message when no overlap

* Update split tests

* Document error
  • Loading branch information
adamjstewart authored and calebrob6 committed Apr 10, 2023
1 parent 4dfbcd5 commit 586190f
Show file tree
Hide file tree
Showing 3 changed files with 343 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,10 +405,20 @@ def test_nongeo_dataset(self) -> None:
IntersectionDataset(ds1, ds2) # type: ignore[arg-type]

def test_different_crs(self) -> None:
ds1 = CustomGeoDataset(crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(crs=CRS.from_epsg(32616))
ds1 = CustomGeoDataset(BoundingBox(0, 1, 0, 1, 0, 1), crs=CRS.from_epsg(3005))
ds2 = CustomGeoDataset(
BoundingBox(
-3547229.913123814,
6360089.518213182,
-3547229.913123814,
6360089.518213182,
-3547229.913123814,
6360089.518213182,
),
crs=CRS.from_epsg(32616),
)
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 0
assert len(ds) == 1

def test_different_res(self) -> None:
ds1 = CustomGeoDataset(res=1)
Expand All @@ -419,8 +429,9 @@ def test_different_res(self) -> None:
def test_no_overlap(self) -> None:
ds1 = CustomGeoDataset(BoundingBox(0, 1, 2, 3, 4, 5))
ds2 = CustomGeoDataset(BoundingBox(6, 7, 8, 9, 10, 11))
ds = IntersectionDataset(ds1, ds2)
assert len(ds) == 0
msg = "Datasets have no spatiotemporal intersection"
with pytest.raises(RuntimeError, match=msg):
IntersectionDataset(ds1, ds2)

def test_invalid_query(self, dataset: IntersectionDataset) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
Expand Down
323 changes: 323 additions & 0 deletions tests/datasets/test_splits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,323 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from math import floor, isclose
from typing import Any, Dict, List, Sequence, Tuple, Union

import pytest
from rasterio.crs import CRS

from torchgeo.datasets import (
BoundingBox,
GeoDataset,
random_bbox_assignment,
random_bbox_splitting,
random_grid_cell_assignment,
roi_split,
time_series_split,
)


def total_area(dataset: GeoDataset) -> float:
total_area = 0.0
for hit in dataset.index.intersection(dataset.index.bounds, objects=True):
total_area += BoundingBox(*hit.bounds).area

return total_area


def no_overlap(ds1: GeoDataset, ds2: GeoDataset) -> bool:
try:
ds = ds1 & ds2
except RuntimeError:
return True
else:
return isclose(total_area(ds), 0)


class CustomGeoDataset(GeoDataset):
def __init__(
self,
items: List[Tuple[BoundingBox, str]] = [(BoundingBox(0, 1, 0, 1, 0, 40), "")],
crs: CRS = CRS.from_epsg(3005),
res: float = 1,
) -> None:
super().__init__()
for box, content in items:
self.index.insert(0, tuple(box), content)
self._crs = crs
self.res = res

def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
hits = self.index.intersection(tuple(query), objects=True)
hit = next(iter(hits))
return {"content": hit.object}


@pytest.mark.parametrize(
"lengths,expected_lengths",
[
# List of lengths
([2, 1, 1], [2, 1, 1]),
# List of fractions (with remainder)
([1 / 3, 1 / 3, 1 / 3], [2, 1, 1]),
],
)
def test_random_bbox_assignment(
lengths: Sequence[Union[int, float]], expected_lengths: Sequence[int]
) -> None:
ds = CustomGeoDataset(
[
(BoundingBox(0, 1, 0, 1, 0, 0), "a"),
(BoundingBox(1, 2, 0, 1, 0, 0), "b"),
(BoundingBox(2, 3, 0, 1, 0, 0), "c"),
(BoundingBox(3, 4, 0, 1, 0, 0), "d"),
]
)

train_ds, val_ds, test_ds = random_bbox_assignment(ds, lengths)

# Check datasets lengths
assert len(train_ds) == expected_lengths[0]
assert len(val_ds) == expected_lengths[1]
assert len(test_ds) == expected_lengths[2]

# No overlap
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds

# Test __getitem__
x = train_ds[train_ds.bounds]
assert isinstance(x, dict)
assert isinstance(x["content"], str)


def test_random_bbox_assignment_invalid_inputs() -> None:
with pytest.raises(
ValueError,
match="Sum of input lengths must equal 1 or the length of dataset's index.",
):
random_bbox_assignment(CustomGeoDataset(), lengths=[2, 2, 1])
with pytest.raises(
ValueError, match="All items in input lengths must be greater than 0."
):
random_bbox_assignment(CustomGeoDataset(), lengths=[1 / 2, 3 / 4, -1 / 4])


def test_random_bbox_splitting() -> None:
ds = CustomGeoDataset(
[
(BoundingBox(0, 1, 0, 1, 0, 0), "a"),
(BoundingBox(1, 2, 0, 1, 0, 0), "b"),
(BoundingBox(2, 3, 0, 1, 0, 0), "c"),
(BoundingBox(3, 4, 0, 1, 0, 0), "d"),
]
)

ds_area = total_area(ds)

train_ds, val_ds, test_ds = random_bbox_splitting(
ds, fractions=[1 / 2, 1 / 4, 1 / 4]
)
train_ds_area = total_area(train_ds)
val_ds_area = total_area(val_ds)
test_ds_area = total_area(test_ds)

# Check datasets areas
assert train_ds_area == ds_area / 2
assert val_ds_area == ds_area / 4
assert test_ds_area == ds_area / 4

# No overlap
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(total_area(train_ds | val_ds | test_ds), ds_area)

# Test __get_item__
x = train_ds[train_ds.bounds]
assert isinstance(x, dict)
assert isinstance(x["content"], str)

# Test invalid input fractions
with pytest.raises(ValueError, match="Sum of input fractions must equal 1."):
random_bbox_splitting(ds, fractions=[1 / 2, 1 / 3, 1 / 4])
with pytest.raises(
ValueError, match="All items in input fractions must be greater than 0."
):
random_bbox_splitting(ds, fractions=[1 / 2, 3 / 4, -1 / 4])


def test_random_grid_cell_assignment() -> None:
ds = CustomGeoDataset(
[
(BoundingBox(0, 12, 0, 12, 0, 0), "a"),
(BoundingBox(12, 24, 0, 12, 0, 0), "b"),
]
)

train_ds, val_ds, test_ds = random_grid_cell_assignment(
ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=5
)

# Check datasets lengths
assert len(train_ds) == 1 / 2 * 2 * 5**2 + 1
assert len(val_ds) == floor(1 / 4 * 2 * 5**2)
assert len(test_ds) == floor(1 / 4 * 2 * 5**2)

# No overlap
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(total_area(train_ds | val_ds | test_ds), total_area(ds))

# Test __get_item__
x = train_ds[train_ds.bounds]
assert isinstance(x, dict)
assert isinstance(x["content"], str)

# Test invalid input fractions
with pytest.raises(ValueError, match="Sum of input fractions must equal 1."):
random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 3, 1 / 4])
with pytest.raises(
ValueError, match="All items in input fractions must be greater than 0."
):
random_grid_cell_assignment(ds, fractions=[1 / 2, 3 / 4, -1 / 4])
with pytest.raises(ValueError, match="Input grid_size must be greater than 1."):
random_grid_cell_assignment(ds, fractions=[1 / 2, 1 / 4, 1 / 4], grid_size=1)


def test_roi_split() -> None:
ds = CustomGeoDataset(
[
(BoundingBox(0, 1, 0, 1, 0, 0), "a"),
(BoundingBox(1, 2, 0, 1, 0, 0), "b"),
(BoundingBox(2, 3, 0, 1, 0, 0), "c"),
(BoundingBox(3, 4, 0, 1, 0, 0), "d"),
]
)

train_ds, val_ds, test_ds = roi_split(
ds,
rois=[
BoundingBox(0, 2, 0, 1, 0, 0),
BoundingBox(2, 3.5, 0, 1, 0, 0),
BoundingBox(3.5, 4, 0, 1, 0, 0),
],
)

# Check datasets lengths
assert len(train_ds) == 2
assert len(val_ds) == 2
assert len(test_ds) == 1

# No overlap
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds
assert isclose(total_area(train_ds | val_ds | test_ds), total_area(ds))

# Test __get_item__
x = train_ds[train_ds.bounds]
assert isinstance(x, dict)
assert isinstance(x["content"], str)

# Test invalid input rois
with pytest.raises(ValueError, match="ROIs in input rois can't overlap."):
roi_split(
ds, rois=[BoundingBox(0, 2, 0, 1, 0, 0), BoundingBox(1, 3, 0, 1, 0, 0)]
)


@pytest.mark.parametrize(
"lengths,expected_lengths",
[
# List of timestamps
([(0, 20), (20, 35), (35, 40)], [2, 2, 1]),
# List of lengths
([20, 15, 5], [2, 2, 1]),
# List of fractions (with remainder)
([1 / 2, 3 / 8, 1 / 8], [2, 2, 1]),
],
)
def test_time_series_split(
lengths: Sequence[Union[Tuple[int, int], int, float]],
expected_lengths: Sequence[int],
) -> None:
ds = CustomGeoDataset(
[
(BoundingBox(0, 1, 0, 1, 0, 10), "a"),
(BoundingBox(0, 1, 0, 1, 10, 20), "b"),
(BoundingBox(0, 1, 0, 1, 20, 30), "c"),
(BoundingBox(0, 1, 0, 1, 30, 40), "d"),
]
)

train_ds, val_ds, test_ds = time_series_split(ds, lengths)

# Check datasets lengths
assert len(train_ds) == expected_lengths[0]
assert len(val_ds) == expected_lengths[1]
assert len(test_ds) == expected_lengths[2]

# No overlap
assert no_overlap(train_ds, val_ds)
assert no_overlap(val_ds, test_ds)
assert no_overlap(test_ds, train_ds)

# Union equals original
assert (train_ds | val_ds | test_ds).bounds == ds.bounds

# Test __get_item__
x = train_ds[train_ds.bounds]
assert isinstance(x, dict)
assert isinstance(x["content"], str)


def test_time_series_split_invalid_input() -> None:
with pytest.raises(
ValueError,
match="Pairs of timestamps in lengths must have end greater than start.",
):
time_series_split(CustomGeoDataset(), lengths=[(0, 20), (35, 20), (35, 40)])

with pytest.raises(
ValueError,
match="Pairs of timestamps in lengths must cover dataset's time bounds.",
):
time_series_split(CustomGeoDataset(), lengths=[(0, 20), (20, 35)])

with pytest.raises(
ValueError,
match="Pairs of timestamps in lengths can't be out of dataset's time bounds.",
):
time_series_split(CustomGeoDataset(), lengths=[(0, 20), (20, 45)])

with pytest.raises(
ValueError, match="Pairs of timestamps in lengths can't overlap."
):
time_series_split(CustomGeoDataset(), lengths=[(0, 10), (10, 20), (15, 40)])

with pytest.raises(
ValueError,
match="Sum of input lengths must equal 1 or the dataset's time length.",
):
time_series_split(CustomGeoDataset(), lengths=[1 / 2, 1 / 2, 1 / 2])

with pytest.raises(
ValueError, match="All items in input lengths must be greater than 0."
):
time_series_split(CustomGeoDataset(), lengths=[20, 25, -5])
4 changes: 4 additions & 0 deletions torchgeo/datasets/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ def __init__(
entry and returns a transformed version
Raises:
RuntimeError: if datasets have no spatiotemporal intersection
ValueError: if either dataset is not a :class:`GeoDataset`
.. versionadded:: 0.4
Expand Down Expand Up @@ -855,6 +856,9 @@ def _merge_dataset_indices(self) -> None:
self.index.insert(i, tuple(box1 & box2))
i += 1

if i == 0:
raise RuntimeError("Datasets have no spatiotemporal intersection")

def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
"""Retrieve image and metadata indexed by query.
Expand Down

0 comments on commit 586190f

Please sign in to comment.