Skip to content

Commit

Permalink
Add tests for seed
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena committed Sep 20, 2024
1 parent 494fbd7 commit 5a9e107
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
16 changes: 16 additions & 0 deletions tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -144,6 +145,21 @@ def test_weighted_sampling(self) -> None:
for bbox in batch:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
generator = torch.manual_seed(0)
sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator)
for bbox in sampler:
sample1 = bbox
break

sampler = RandomBatchGeoSampler(ds, 1, 1, generator=generator)
for bbox in sampler:
sample2 = bbox
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down
32 changes: 32 additions & 0 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from itertools import product

import pytest
import torch
from _pytest.fixtures import SubRequest
from rasterio.crs import CRS
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -139,6 +140,21 @@ def test_weighted_sampling(self) -> None:
for bbox in sampler:
assert bbox == BoundingBox(0, 10, 0, 10, 0, 10)

def test_random_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
generator = torch.manual_seed(0)
sampler = RandomGeoSampler(ds, 1, 1, generator=generator)
for bbox in sampler:
sample1 = bbox
break

sampler = RandomGeoSampler(ds, 1, 1, generator=generator)
for bbox in sampler:
sample2 = bbox
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down Expand Up @@ -288,6 +304,22 @@ def test_point_data(self) -> None:
for _ in sampler:
continue

def test_shuffle_seed(self) -> None:
ds = CustomGeoDataset()
ds.index.insert(0, (0, 10, 0, 10, 0, 10))
ds.index.insert(1, (0, 11, 0, 11, 0, 11))
generator = torch.manual_seed(0)
sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler:
sample1 = bbox
break

sampler = PreChippedGeoSampler(ds, shuffle=True, generator=generator)
for bbox in sampler:
sample2 = bbox
break
assert sample1 == sample2

@pytest.mark.slow
@pytest.mark.parametrize('num_workers', [0, 1, 2])
def test_dataloader(
Expand Down

0 comments on commit 5a9e107

Please sign in to comment.