Skip to content

Commit

Permalink
fix datamodules failing test, better test for resampling
Browse files Browse the repository at this point in the history
  • Loading branch information
sfalkena committed Sep 18, 2024
1 parent 83411f4 commit 6fba6cc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
from geopandas import GeoDataFrame

from torchgeo.datamodules import (
GeoDataModule,
Expand Down Expand Up @@ -182,7 +183,7 @@ def test_zero_length_sampler(self) -> None:
dm = CustomGeoDataModule()
dm.dataset = CustomGeoDataset()
dm.sampler = RandomGeoSampler(dm.dataset, 1, 1)
dm.sampler.length = 0
dm.sampler.chips = GeoDataFrame()
msg = r'CustomGeoDataModule\.sampler has length 0.'
with pytest.raises(MisconfigurationException, match=msg):
dm.train_dataloader()
Expand Down
1 change: 1 addition & 0 deletions tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def test_empty(self, dataset: CustomGeoDataset) -> None:
assert len(sampler) == 0

def test_refresh_samples(self, dataset: CustomGeoDataset) -> None:
dataset.index.insert(0, (0, 100, 200, 300, 400, 500))
sampler = RandomGeoSampler(dataset, 5, length=1)
samples = list(sampler)
assert len(sampler) == 1
Expand Down

0 comments on commit 6fba6cc

Please sign in to comment.