Skip to content

Commit

Permalink
Merge pull request #6 from sede-open/random_seed
Browse files Browse the repository at this point in the history
Add random generator
  • Loading branch information
sfalkena authored Sep 23, 2024
2 parents 6a5aaf4 + 46e1f11 commit bfe635a
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 14 deletions.
4 changes: 3 additions & 1 deletion .git-blame-ignore-revs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Double quote -> single quote
# Prettier: double quote -> single quote
6a5aaf4b93507072d40dcd78114893362c4eaf6e
# Ruff: double quote -> single quote
b09122f3e4a9cb422f6747bf33eca02993f67549
# Prettier
bd9c75798eede1a4b7d7ecd6203179d3cb5e54dd
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tutorials.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac
pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac .
pip cache purge
- name: List pip dependencies
run: pip list
Expand Down
8 changes: 4 additions & 4 deletions tests/conf/landcoverai100.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "unet"
backbone: "resnet18"
loss: 'ce'
model: 'unet'
backbone: 'resnet18'
in_channels: 3
num_classes: 5
num_filters: 1
Expand All @@ -13,4 +13,4 @@ data:
init_args:
batch_size: 1
dict_kwargs:
root: "tests/data/landcoverai"
root: 'tests/data/landcoverai'
15 changes: 15 additions & 0 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -144,6 +145,20 @@ def test_weighted_sampling(self) -> None:
for bbox in batch:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
for bbox in sampler:
sample1 = bbox
break

sampler = RandomBatchGeoSampler(ds, 1, 1, generator=torch.manual_seed(0))
for bbox in sampler:
sample2 = bbox
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
35 changes: 35 additions & 0 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -139,6 +140,21 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
generator = torch.manual_seed(0)
sampler = RandomGeoSampler(ds, 1, 1, generator=generator)
for bbox in sampler:
sample1 = bbox
break

sampler = RandomGeoSampler(ds, 1, 1, generator=generator)
for bbox in sampler:
sample2 = bbox
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -288,6 +304,25 @@ def test_point_data(self) -> None:
for _ in sampler:
continue

def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
generator = torch.manual_seed(2)
sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler1:
sample1 = bbox
print(sample1)
break

generator = torch.manual_seed(2)
sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler2:
sample2 = bbox
print(sample2)
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
6 changes: 5 additions & 1 deletion torchgeo/datamodules/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None:

if stage in ['fit']:
self.train_batch_sampler = RandomBatchGeoSampler(
self.train_dataset, self.patch_size, self.batch_size, self.length
self.train_dataset,
self.patch_size,
self.batch_size,
self.length,
generator=generator,
)
if stage in ['fit', 'validate']:
self.val_sampler = GridGeoSampler(
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
batch_sampler=batch_sampler,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
persistent_workers=self.num_workers > 0,
)

def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
Expand Down Expand Up @@ -429,6 +430,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]:
shuffle=split == 'train',
num_workers=self.num_workers,
collate_fn=self.collate_fn,
persistent_workers=self.num_workers > 0,
)

def train_dataloader(self) -> DataLoader[dict[str, Tensor]]:
Expand Down
7 changes: 6 additions & 1 deletion torchgeo/samplers/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand Down Expand Up @@ -97,9 +98,11 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: random number generator
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)
self.generator = generator

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)
Expand Down Expand Up @@ -144,7 +147,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]:
# Choose random indices within that tile
batch = []
for _ in range(self.batch_size):
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)
batch.append(bounding_box)

yield batch
Expand Down
20 changes: 17 additions & 3 deletions torchgeo/samplers/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import abc
from collections.abc import Callable, Iterable, Iterator
from functools import partial

import torch
from rtree.index import Index, Property
Expand Down Expand Up @@ -72,6 +73,7 @@ def __init__(
length: int | None = None,
roi: BoundingBox | None = None,
units: Units = Units.PIXELS,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand All @@ -98,13 +100,16 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
units: defines if ``size`` is in pixel or CRS units
generator: The random generator used for sampling.
"""
super().__init__(dataset, roi)
self.size = _to_tuple(size)

if units == Units.PIXELS:
self.size = (self.size[0] * self.res, self.size[1] * self.res)

self.generator = generator
self.length = 0
self.hits = []
areas = []
Expand Down Expand Up @@ -142,7 +147,9 @@ def __iter__(self) -> Iterator[BoundingBox]:
bounds = BoundingBox(*hit.bounds)

# Choose a random index within that tile
bounding_box = get_random_bounding_box(bounds, self.size, self.res)
bounding_box = get_random_bounding_box(
bounds, self.size, self.res, self.generator
)

yield bounding_box

Expand Down Expand Up @@ -270,7 +277,11 @@ class PreChippedGeoSampler(GeoSampler):
"""

def __init__(
self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False
self,
dataset: GeoDataset,
roi: BoundingBox | None = None,
shuffle: bool = False,
generator: torch.Generator | None = None,
) -> None:
"""Initialize a new Sampler instance.
Expand All @@ -281,9 +292,12 @@ def __init__(
roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt)
(defaults to the bounds of ``dataset.index``)
shuffle: if True, reshuffle data at every epoch
generator: The random number generator used in combination with shuffle.
"""
super().__init__(dataset, roi)
self.shuffle = shuffle
self.generator = generator

self.hits = []
for hit in self.index.intersection(tuple(self.roi), objects=True):
Expand All @@ -297,7 +311,7 @@ def __iter__(self) -> Iterator[BoundingBox]:
"""
generator: Callable[[int], Iterable[int]] = range
if self.shuffle:
generator = torch.randperm
generator = partial(torch.randperm, generator=self.generator)

for idx in generator(len(self)):
yield BoundingBox(*self.hits[idx].bounds)
Expand Down
10 changes: 7 additions & 3 deletions torchgeo/samplers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]:


def get_random_bounding_box(
bounds: BoundingBox, size: tuple[float, float] | float, res: float
bounds: BoundingBox,
size: tuple[float, float] | float,
res: float,
generator: torch.Generator | None = None,
) -> BoundingBox:
"""Returns a random bounding box within a given bounding box.
Expand All @@ -50,6 +53,7 @@ def get_random_bounding_box(
bounds: the larger bounding box to sample from
size: the size of the bounding box to sample
res: the resolution of the image
generator: random number generator
Returns:
randomly sampled bounding box from the extent of the input
Expand All @@ -64,8 +68,8 @@ def get_random_bounding_box(
miny = bounds.miny

# Use an integer multiple of res to avoid resampling
minx += int(torch.rand(1).item() * width) * res
miny += int(torch.rand(1).item() * height) * res
minx += int(torch.rand(1, generator=generator).item() * width) * res
miny += int(torch.rand(1, generator=generator).item() * height) * res

maxx = minx + t_size[1]
maxy = miny + t_size[0]
Expand Down

0 comments on commit bfe635a

Please sign in to comment.