From 494fbd76672b4a29e0a6f8736b604369346d8bdd Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Fri, 20 Sep 2024 19:43:38 +0200 Subject: [PATCH 1/3] Add random generator --- torchgeo/datamodules/agrifieldnet.py | 6 +++++- torchgeo/samplers/batch.py | 7 ++++++- torchgeo/samplers/single.py | 20 +++++++++++++++++--- torchgeo/samplers/utils.py | 10 +++++++--- 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index bed6365d4a2..c5b92b6b01a 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None: if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( - self.train_dataset, self.patch_size, self.batch_size, self.length + self.train_dataset, + self.patch_size, + self.batch_size, + self.length, + generator=generator, ) if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 22726f74b2c..396ad0f0c7b 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -70,6 +70,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -97,9 +98,11 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: random number generator """ super().__init__(dataset, roi) self.size = _to_tuple(size) + self.generator = generator if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -144,7 +147,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]: # Choose random indices within that tile batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 094142cb647..ea943db3d53 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,6 +5,7 @@ import abc from collections.abc import Callable, Iterable, Iterator +from functools import partial import torch from rtree.index import Index, Property @@ -72,6 +73,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -98,6 +100,8 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: The random generator used for sampling. + """ super().__init__(dataset, roi) self.size = _to_tuple(size) @@ -105,6 +109,7 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.generator = generator self.length = 0 self.hits = [] areas = [] @@ -142,7 +147,9 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) yield bounding_box @@ -270,7 +277,11 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False + self, + dataset: GeoDataset, + roi: BoundingBox | None = None, + shuffle: bool = False, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -281,9 +292,12 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) shuffle: if True, reshuffle data at every epoch + generator: The random number generator used in combination with shuffle. + """ super().__init__(dataset, roi) self.shuffle = shuffle + self.generator = generator self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -297,7 +311,7 @@ def __iter__(self) -> Iterator[BoundingBox]: """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: - generator = torch.randperm + generator = partial(torch.randperm, generator=self.generator) for idx in generator(len(self)): yield BoundingBox(*self.hits[idx].bounds) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index a1fca673a3a..258f74a5425 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -35,7 +35,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: tuple[float, float] | float, res: float + bounds: BoundingBox, + size: tuple[float, float] | float, + res: float, + generator: torch.Generator | None = None, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -50,6 +53,7 @@ def get_random_bounding_box( bounds: the larger bounding box to sample from size: the size of the bounding box to sample res: the resolution of the image + generator: random number generator Returns: randomly sampled bounding box from the extent of the input @@ -64,8 +68,8 @@ def get_random_bounding_box( miny = bounds.miny # Use an integer multiple of res to avoid resampling - minx += int(torch.rand(1).item() * width) * res - miny += int(torch.rand(1).item() * height) * res + minx += int(torch.rand(1, generator=generator).item() * width) * res + miny += int(torch.rand(1, generator=generator).item() * height) * res maxx = minx + t_size[1] maxy = miny + t_size[0] From 5a9e107fd1b5556f177d9607e426e021ab57a75a Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Fri, 20 Sep 2024 20:21:19 +0200 Subject: [PATCH 2/3] Add tests for seed --- tests/samplers/test_batch.py | 16 ++++++++++++++++ tests/samplers/test_single.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 59c8aaa00be..20ad33a58c9 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -144,6 +145,21 @@ def test_weighted_sampling(self) -> None: for bbox in batch: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + generator = torch.manual_seed(0) + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample1 = bbox + break + + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample2 = bbox + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 1416368098a..15f1025f672 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -139,6 +140,21 @@ def test_weighted_sampling(self) -> None: for bbox in sampler: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + generator = torch.manual_seed(0) + sampler = RandomGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample1 = bbox + break + + sampler = RandomGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample2 = bbox + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -288,6 +304,22 @@ def test_point_data(self) -> None: for _ in sampler: continue + def test_shuffle_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 11, 0, 11, 0, 11)) + generator = torch.manual_seed(0) + sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler: + sample1 = bbox + break + + sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler: + sample2 = bbox + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( From 46e1f11d440ecf1363393d7e616666cbd7f3e9f5 Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Fri, 20 Sep 2024 18:49:44 +0000 Subject: [PATCH 3/3] pass generator every sampler --- tests/samplers/test_batch.py | 5 ++--- tests/samplers/test_single.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 20ad33a58c9..16b99e16a93 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -148,13 +148,12 @@ def test_weighted_sampling(self) -> None: def test_random_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - generator = torch.manual_seed(0) - sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator) + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) for bbox in sampler: sample1 = bbox break - sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator) + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) for bbox in sampler: sample2 = bbox break diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 15f1025f672..abbf22d2727 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -308,15 +308,18 @@ def test_shuffle_seed(self) -> None: ds = CustomGeoDataset() ds.index.insert(0, (0, 10, 0, 10, 0, 10)) ds.index.insert(1, (0, 11, 0, 11, 0, 11)) - generator = torch.manual_seed(0) - sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator) - for bbox in sampler: + generator = torch.manual_seed(2) + sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler1: sample1 = bbox + print(sample1) break - sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator) - for bbox in sampler: + generator = torch.manual_seed(2) + sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler2: sample2 = bbox + print(sample2) break assert sample1 == sample2