diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 59c8aaa00be..01dea100ca8 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -144,6 +145,15 @@ def test_weighted_sampling(self) -> None: for bbox in batch: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 1416368098a..743c8be70da 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -139,6 +140,15 @@ def test_weighted_sampling(self) -> None: for bbox in sampler: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler1 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + sampler2 = RandomGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -288,6 +298,20 @@ def test_point_data(self) -> None: for _ in sampler: continue + def test_shuffle_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 11, 0, 11, 0, 11)) + sampler1 = PreChippedGeoSampler( + ds, shuffle=True, generator=torch.manual_seed(2) + ) + sampler2 = PreChippedGeoSampler( + ds, shuffle=True, generator=torch.manual_seed(2) + ) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) + assert sample1 != sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index bed6365d4a2..c5b92b6b01a 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None: if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( - self.train_dataset, self.patch_size, self.batch_size, self.length + self.train_dataset, + self.patch_size, + self.batch_size, + self.length, + generator=generator, ) if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 22726f74b2c..686b458ce24 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -8,6 +8,7 @@ import torch from rtree.index import Index, Property +from torch import Generator from torch.utils.data import Sampler from ..datasets import BoundingBox, GeoDataset @@ -70,6 +71,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -86,6 +88,9 @@ def __init__( .. versionchanged:: 0.4 ``length`` parameter is now optional, a reasonable default will be used + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from size: dimensions of each :term:`patch` @@ -97,9 +102,11 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: pseudo-random number generator (PRNG). """ super().__init__(dataset, roi) self.size = _to_tuple(size) + self.generator = generator if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -144,7 +151,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]: # Choose random indices within that tile batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 094142cb647..6fa4331c4b7 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,9 +5,11 @@ import abc from collections.abc import Callable, Iterable, Iterator +from functools import partial import torch from rtree.index import Index, Property +from torch import Generator from torch.utils.data import Sampler from ..datasets import BoundingBox, GeoDataset @@ -72,6 +74,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -88,6 +91,9 @@ def __init__( .. versionchanged:: 0.4 ``length`` parameter is now optional, a reasonable default will be used + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from size: dimensions of each :term:`patch` @@ -98,6 +104,7 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: pseudo-random number generator (PRNG). """ super().__init__(dataset, roi) self.size = _to_tuple(size) @@ -105,6 +112,7 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.generator = generator self.length = 0 self.hits = [] areas = [] @@ -142,7 +150,9 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) yield bounding_box @@ -270,20 +280,30 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False + self, + dataset: GeoDataset, + roi: BoundingBox | None = None, + shuffle: bool = False, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. .. versionadded:: 0.3 + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) shuffle: if True, reshuffle data at every epoch + generator: pseudo-random number generator (PRNG) used in combination with shuffle. + """ super().__init__(dataset, roi) self.shuffle = shuffle + self.generator = generator self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -297,7 +317,7 @@ def __iter__(self) -> Iterator[BoundingBox]: """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: - generator = torch.randperm + generator = partial(torch.randperm, generator=self.generator) for idx in generator(len(self)): yield BoundingBox(*self.hits[idx].bounds) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index a1fca673a3a..48ad760f928 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -7,6 +7,7 @@ from typing import overload import torch +from torch import Generator from ..datasets import BoundingBox @@ -35,7 +36,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: tuple[float, float] | float, res: float + bounds: BoundingBox, + size: tuple[float, float] | float, + res: float, + generator: Generator | None = None, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -46,10 +50,14 @@ def get_random_bounding_box( * a ``tuple`` of two floats - in which case, the first *float* is used for the height dimension, and the second *float* for the width dimension + .. versionadded:: 0.7 + The *generator* parameter. + Args: bounds: the larger bounding box to sample from size: the size of the bounding box to sample res: the resolution of the image + generator: pseudo-random number generator (PRNG). Returns: randomly sampled bounding box from the extent of the input @@ -64,8 +72,8 @@ def get_random_bounding_box( miny = bounds.miny # Use an integer multiple of res to avoid resampling - minx += int(torch.rand(1).item() * width) * res - miny += int(torch.rand(1).item() * height) * res + minx += int(torch.rand(1, generator=generator).item() * width) * res + miny += int(torch.rand(1, generator=generator).item() * height) * res maxx = minx + t_size[1] maxy = miny + t_size[0]