diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index f465d76f976..2127a1d22b7 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,4 +1,6 @@ -# Double quote -> single quote +# Prettier: double quote -> single quote +6a5aaf4b93507072d40dcd78114893362c4eaf6e +# Ruff: double quote -> single quote b09122f3e4a9cb422f6747bf33eca02993f67549 # Prettier bd9c75798eede1a4b7d7ecd6203179d3cb5e54dd diff --git a/.github/workflows/tutorials.yaml b/.github/workflows/tutorials.yaml index 67accad5e2c..f3e746499ee 100644 --- a/.github/workflows/tutorials.yaml +++ b/.github/workflows/tutorials.yaml @@ -33,7 +33,7 @@ jobs: - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | - pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac + pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac . pip cache purge - name: List pip dependencies run: pip list diff --git a/tests/conf/landcoverai100.yaml b/tests/conf/landcoverai100.yaml index 1610bb03990..f6461851fa3 100644 --- a/tests/conf/landcoverai100.yaml +++ b/tests/conf/landcoverai100.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 3 num_classes: 5 num_filters: 1 @@ -13,4 +13,4 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/landcoverai" + root: 'tests/data/landcoverai' diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 59c8aaa00be..16b99e16a93 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -144,6 +145,20 @@ def test_weighted_sampling(self) -> None: for bbox in batch: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + for bbox in sampler: + sample1 = bbox + break + + sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0)) + for bbox in sampler: + sample2 = bbox + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 1416368098a..abbf22d2727 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -139,6 +140,21 @@ def test_weighted_sampling(self) -> None: for bbox in sampler: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + generator = torch.manual_seed(0) + sampler = RandomGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample1 = bbox + break + + sampler = RandomGeoSampler(ds, 1, 1, generator=generator) + for bbox in sampler: + sample2 = bbox + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -288,6 +304,25 @@ def test_point_data(self) -> None: for _ in sampler: continue + def test_shuffle_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 11, 0, 11, 0, 11)) + generator = torch.manual_seed(2) + sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler1: + sample1 = bbox + print(sample1) + break + + generator = torch.manual_seed(2) + sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator) + for bbox in sampler2: + sample2 = bbox + print(sample2) + break + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index bed6365d4a2..c5b92b6b01a 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None: if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( - self.train_dataset, self.patch_size, self.batch_size, self.length + self.train_dataset, + self.patch_size, + self.batch_size, + self.length, + generator=generator, ) if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 5f5ae76b939..e8e3aedd194 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -286,6 +286,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: batch_sampler=batch_sampler, num_workers=self.num_workers, collate_fn=self.collate_fn, + persistent_workers=self.num_workers > 0, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: @@ -429,6 +430,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: shuffle=split == 'train', num_workers=self.num_workers, collate_fn=self.collate_fn, + persistent_workers=self.num_workers > 0, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 22726f74b2c..396ad0f0c7b 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -70,6 +70,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -97,9 +98,11 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: random number generator """ super().__init__(dataset, roi) self.size = _to_tuple(size) + self.generator = generator if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -144,7 +147,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]: # Choose random indices within that tile batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 094142cb647..ea943db3d53 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,6 +5,7 @@ import abc from collections.abc import Callable, Iterable, Iterator +from functools import partial import torch from rtree.index import Index, Property @@ -72,6 +73,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -98,6 +100,8 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: The random generator used for sampling. + """ super().__init__(dataset, roi) self.size = _to_tuple(size) @@ -105,6 +109,7 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.generator = generator self.length = 0 self.hits = [] areas = [] @@ -142,7 +147,9 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) yield bounding_box @@ -270,7 +277,11 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False + self, + dataset: GeoDataset, + roi: BoundingBox | None = None, + shuffle: bool = False, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -281,9 +292,12 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) shuffle: if True, reshuffle data at every epoch + generator: The random number generator used in combination with shuffle. + """ super().__init__(dataset, roi) self.shuffle = shuffle + self.generator = generator self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -297,7 +311,7 @@ def __iter__(self) -> Iterator[BoundingBox]: """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: - generator = torch.randperm + generator = partial(torch.randperm, generator=self.generator) for idx in generator(len(self)): yield BoundingBox(*self.hits[idx].bounds) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index a1fca673a3a..258f74a5425 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -35,7 +35,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: tuple[float, float] | float, res: float + bounds: BoundingBox, + size: tuple[float, float] | float, + res: float, + generator: torch.Generator | None = None, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -50,6 +53,7 @@ def get_random_bounding_box( bounds: the larger bounding box to sample from size: the size of the bounding box to sample res: the resolution of the image + generator: random number generator Returns: randomly sampled bounding box from the extent of the input @@ -64,8 +68,8 @@ def get_random_bounding_box( miny = bounds.miny # Use an integer multiple of res to avoid resampling - minx += int(torch.rand(1).item() * width) * res - miny += int(torch.rand(1).item() * height) * res + minx += int(torch.rand(1, generator=generator).item() * width) * res + miny += int(torch.rand(1, generator=generator).item() * height) * res maxx = minx + t_size[1] maxy = miny + t_size[0]