Skip to content

Commit

Permalink
added more tests for dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
rijuld committed Jan 19, 2025
1 parent d4bf9fb commit 9a050bd
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
44 changes: 41 additions & 3 deletions tests/datasets/test_substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def test_getitem_semantic(self, config: dict[str, Any]) -> None:

x = dataset[0]
assert isinstance(x, dict), f'Expected dict, got {type(x)}'
assert isinstance(x['image'], torch.Tensor), (
'Expected image to be a torch.Tensor'
)
assert isinstance(
x['image'], torch.Tensor
), 'Expected image to be a torch.Tensor'
assert isinstance(x['mask'], torch.Tensor), 'Expected mask to be a torch.Tensor'

def test_len(self, dataset: Substation) -> None:
Expand Down Expand Up @@ -167,6 +167,44 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
root=tmp_path,
)

def test_not_downloaded_with_download(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
filename = 'image_stack'
maskname = 'mask'
source_image_path = os.path.join('tests', 'data', 'substation', filename)
source_mask_path = os.path.join('tests', 'data', 'substation', maskname)
target_image_path = tmp_path / filename
target_mask_path = tmp_path / maskname

def mock_download_side_effect() -> None:
shutil.copytree(source_image_path, target_image_path)
shutil.copytree(source_mask_path, target_mask_path)

mock_download = MagicMock(side_effect=mock_download_side_effect)
mock_extract = MagicMock()
monkeypatch.setattr(
'torchgeo.datasets.substation.Substation._download', mock_download
)
monkeypatch.setattr(
'torchgeo.datasets.substation.Substation._extract', mock_extract
)

# Create the Substation instance
Substation(
bands=[1, 2, 3],
use_timepoints=True,
mask_2d=True,
timepoint_aggregation='median',
num_of_timepoints=4,
root=tmp_path,
download=True,
)

# Verify the mocked methods were called
mock_download.assert_called_once()
mock_extract.assert_called_once()

def test_extract(self, tmp_path: Path) -> None:
filename = Substation.filename_images
maskname = Substation.filename_masks
Expand Down
8 changes: 3 additions & 5 deletions torchgeo/datasets/substation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive
from .utils import Path, download_url, extract_archive


class Substation(NonGeoDataset):
Expand Down Expand Up @@ -46,7 +46,7 @@ class Substation(NonGeoDataset):

def __init__(
self,
root: str,
root: Path,
bands: list[int],
mask_2d: bool,
timepoint_aggregation: str = 'concat',
Expand Down Expand Up @@ -220,9 +220,7 @@ def _verify(self) -> None:

# If dataset files are missing and download is not allowed, raise an error
if not getattr(self, 'download', True):
raise DatasetNotFoundError(
f'Dataset files not found in {self.root}. Enable downloading or provide the files.'
)
raise DatasetNotFoundError(self)

# Download and extract the dataset
self._download()
Expand Down

0 comments on commit 9a050bd

Please sign in to comment.