From f262b005803e317c542cb7317c91c61b4c25fd31 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 5 Apr 2022 11:48:33 -0500 Subject: [PATCH] RandomGeoSampler: several bug fixes (#477) * RandomGeoSampler: prevent area bias * Use builtin PyTorch random Co-authored-by: Caleb Robinson --- tests/samplers/test_batch.py | 17 +++++++++++++++++ tests/samplers/test_single.py | 16 ++++++++++++++++ torchgeo/samplers/batch.py | 18 +++++++++++++----- torchgeo/samplers/single.py | 17 ++++++++++++----- torchgeo/samplers/utils.py | 18 +++++++++++++----- 5 files changed, 71 insertions(+), 15 deletions(-) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 92d4d8c47fb..9dc32394d3d 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -127,6 +127,23 @@ def test_small_area(self) -> None: for _ in sampler: continue + def test_point_data(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 0, 0, 0, 0, 0)) + ds.index.insert(1, (1, 1, 1, 1, 1, 1)) + sampler = RandomBatchGeoSampler(ds, 0, 2, 10) + for _ in sampler: + continue + + def test_weighted_sampling(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 0, 0, 0, 0, 0)) + ds.index.insert(1, (0, 10, 0, 10, 0, 10)) + sampler = RandomBatchGeoSampler(ds, 1, 2, 10) + for batch in sampler: + for bbox in batch: + assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_dataloader( diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index e7545459dd1..2380bb119fd 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -122,6 +122,22 @@ def test_small_area(self) -> None: for _ in sampler: continue + def test_point_data(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 0, 0, 0, 0, 0)) + ds.index.insert(1, (1, 1, 1, 1, 1, 1)) + sampler = RandomGeoSampler(ds, 0, 10) + for _ in sampler: + continue + + def test_weighted_sampling(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 0, 0, 0, 0, 0)) + ds.index.insert(1, (0, 10, 0, 10, 0, 10)) + sampler = RandomGeoSampler(ds, 1, 10) + for bbox in sampler: + assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + @pytest.mark.slow @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_dataloader( diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index e269a748db6..b8b06aa3fe8 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -4,9 +4,9 @@ """TorchGeo batch samplers.""" import abc -import random from typing import Iterator, List, Optional, Tuple, Union +import torch from rtree.index import Index, Property from torch.utils.data import Sampler @@ -104,13 +104,20 @@ def __init__( self.batch_size = batch_size self.length = length self.hits = [] + areas = [] for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) if ( - bounds.maxx - bounds.minx > self.size[1] - and bounds.maxy - bounds.miny > self.size[0] + bounds.maxx - bounds.minx >= self.size[1] + and bounds.maxy - bounds.miny >= self.size[0] ): self.hits.append(hit) + areas.append(bounds.area) + + # torch.multinomial requires float probabilities > 0 + self.areas = torch.tensor(areas, dtype=torch.float) + if torch.sum(self.areas) == 0: + self.areas += 1 def __iter__(self) -> Iterator[List[BoundingBox]]: """Return the indices of a dataset. @@ -119,8 +126,9 @@ def __iter__(self) -> Iterator[List[BoundingBox]]: batch of (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ for _ in range(len(self)): - # Choose a random tile - hit = random.choice(self.hits) + # Choose a random tile, weighted by area + idx = torch.multinomial(self.areas, 1) + hit = self.hits[idx] bounds = BoundingBox(*hit.bounds) # Choose random indices within that tile diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 01846fe0e1c..e31d13bdd44 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -4,7 +4,6 @@ """TorchGeo samplers.""" import abc -import random from typing import Callable, Iterable, Iterator, Optional, Tuple, Union import torch @@ -105,13 +104,20 @@ def __init__( self.length = length self.hits = [] + areas = [] for hit in self.index.intersection(tuple(self.roi), objects=True): bounds = BoundingBox(*hit.bounds) if ( - bounds.maxx - bounds.minx > self.size[1] - and bounds.maxy - bounds.miny > self.size[0] + bounds.maxx - bounds.minx >= self.size[1] + and bounds.maxy - bounds.miny >= self.size[0] ): self.hits.append(hit) + areas.append(bounds.area) + + # torch.multinomial requires float probabilities > 0 + self.areas = torch.tensor(areas, dtype=torch.float) + if torch.sum(self.areas) == 0: + self.areas += 1 def __iter__(self) -> Iterator[BoundingBox]: """Return the index of a dataset. @@ -120,8 +126,9 @@ def __iter__(self) -> Iterator[BoundingBox]: (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset """ for _ in range(len(self)): - # Choose a random tile - hit = random.choice(self.hits) + # Choose a random tile, weighted by area + idx = torch.multinomial(self.areas, 1) + hit = self.hits[idx] bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index f8382626ee8..94a8c5622a1 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -3,9 +3,10 @@ """Common sampler utilities.""" -import random from typing import Tuple, Union +import torch + from ..datasets import BoundingBox @@ -46,11 +47,18 @@ def get_random_bounding_box( t_size = _to_tuple(size) width = (bounds.maxx - bounds.minx - t_size[1]) // res - minx = random.randrange(int(width)) * res + bounds.minx - maxx = minx + t_size[1] - height = (bounds.maxy - bounds.miny - t_size[0]) // res - miny = random.randrange(int(height)) * res + bounds.miny + + minx = bounds.minx + miny = bounds.miny + + # random.randrange crashes for inputs <= 0 + if width > 0: + minx += torch.rand(1).item() * width * res + if height > 0: + miny += torch.rand(1).item() * height * res + + maxx = minx + t_size[1] maxy = miny + t_size[0] mint = bounds.mint