From c51a63bdb188217cc32343f66b88566905bac6d8 Mon Sep 17 00:00:00 2001 From: Sieger Falkena Date: Mon, 23 Sep 2024 13:21:12 +0200 Subject: [PATCH] Revert "Merge branch 'vers_working_branch' into geosampler_prechipping" This reverts commit c7f3e4cf8b3e1ae6274d6f19f7417157a0515955, reversing changes made to d8cb4b207874fe69ab295e7e341edddd3b765644. --- tests/samplers/test_batch.py | 15 ------------ tests/samplers/test_single.py | 35 ---------------------------- torchgeo/datamodules/agrifieldnet.py | 6 +---- torchgeo/samplers/batch.py | 7 +----- torchgeo/samplers/single.py | 18 +++----------- torchgeo/samplers/utils.py | 10 +++----- 6 files changed, 8 insertions(+), 83 deletions(-) diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 16b99e16a93..59c8aaa00be 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -6,7 +6,6 @@ from itertools import product import pytest -import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -145,20 +144,6 @@ def test_weighted_sampling(self) -> None: for bbox in batch: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) - def test_random_seed(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) - for bbox in sampler: - sample1 = bbox - break - - sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) - for bbox in sampler: - sample2 = bbox - break - assert sample1 == sample2 - @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 77db86f2b57..6fdbf4712fc 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -7,7 +7,6 @@ import geopandas as gpd import pytest -import torch from _pytest.fixtures import SubRequest from geopandas import GeoDataFrame from rasterio.crs import CRS @@ -223,21 +222,6 @@ def test_weighted_sampling(self) -> None: for bbox in sampler: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) - def test_random_seed(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - generator = torch.manual_seed(0) - sampler = RandomGeoSampler(ds, 1, 1, generator=generator) - for bbox in sampler: - sample1 = bbox - break - - sampler = RandomGeoSampler(ds, 1, 1, generator=generator) - for bbox in sampler: - sample2 = bbox - break - assert sample1 == sample2 - @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -387,25 +371,6 @@ def test_point_data(self) -> None: for _ in sampler: continue - def test_shuffle_seed(self) -> None: - ds = CustomGeoDataset() - ds.index.insert(0, (0, 10, 0, 10, 0, 10)) - ds.index.insert(1, (0, 11, 0, 11, 0, 11)) - generator = torch.manual_seed(2) - sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) - for bbox in sampler1: - sample1 = bbox - print(sample1) - break - - generator = torch.manual_seed(2) - sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) - for bbox in sampler2: - sample2 = bbox - print(sample2) - break - assert sample1 == sample2 - @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index c5b92b6b01a..bed6365d4a2 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -74,11 +74,7 @@ def setup(self, stage: str) -> None: if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( - self.train_dataset, - self.patch_size, - self.batch_size, - self.length, - generator=generator, + self.train_dataset, self.patch_size, self.batch_size, self.length ) if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 396ad0f0c7b..22726f74b2c 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -70,7 +70,6 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, - generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -98,11 +97,9 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units - generator: random number generator """ super().__init__(dataset, roi) self.size = _to_tuple(size) - self.generator = generator if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -147,9 +144,7 @@ def __iter__(self) -> Iterator[list[BoundingBox]]: # Choose random indices within that tile batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box( - bounds, self.size, self.res, self.generator - ) + bounding_box = get_random_bounding_box(bounds, self.size, self.res) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 452dfaaadad..fcb4ed536f8 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,7 +5,6 @@ import abc from collections.abc import Callable, Iterable, Iterator -from functools import partial import geopandas as gpd import numpy as np @@ -211,7 +210,6 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, - generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -238,8 +236,6 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units - generator: The random generator used for sampling. - """ super().__init__(dataset, roi) self.size = _to_tuple(size) @@ -247,7 +243,6 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) - self.generator = generator self.length = 0 self.hits = [] areas = [] @@ -309,7 +304,7 @@ def get_chips(self) -> GeoDataFrame: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bbox = get_random_bounding_box(bounds, self.size, self.res, self.generator) + bbox = get_random_bounding_box(bounds, self.size, self.res) minx, maxx, miny, maxy, mint, maxt = tuple(bbox) chip = { 'geometry': box(minx, miny, maxx, maxy), @@ -452,11 +447,7 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, - dataset: GeoDataset, - roi: BoundingBox | None = None, - shuffle: bool = False, - generator: torch.Generator | None = None, + self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False ) -> None: """Initialize a new Sampler instance. @@ -467,12 +458,9 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) shuffle: if True, reshuffle data at every epoch - generator: The random number generator used in combination with shuffle. - """ super().__init__(dataset, roi) self.shuffle = shuffle - self.generator = generator self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -489,7 +477,7 @@ def get_chips(self) -> GeoDataFrame: """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: - generator = partial(torch.randperm, generator=self.generator) + generator = torch.randperm print('generating samples... ') chips = [] diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index 258f74a5425..a1fca673a3a 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -35,10 +35,7 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, - size: tuple[float, float] | float, - res: float, - generator: torch.Generator | None = None, + bounds: BoundingBox, size: tuple[float, float] | float, res: float ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -53,7 +50,6 @@ def get_random_bounding_box( bounds: the larger bounding box to sample from size: the size of the bounding box to sample res: the resolution of the image - generator: random number generator Returns: randomly sampled bounding box from the extent of the input @@ -68,8 +64,8 @@ def get_random_bounding_box( miny = bounds.miny # Use an integer multiple of res to avoid resampling - minx += int(torch.rand(1, generator=generator).item() * width) * res - miny += int(torch.rand(1, generator=generator).item() * height) * res + minx += int(torch.rand(1).item() * width) * res + miny += int(torch.rand(1).item() * height) * res maxx = minx + t_size[1] maxy = miny + t_size[0]