diff --git a/tests/datasets/test_substation.py b/tests/datasets/test_substation.py index 0ec61591c73..6022dca9ad6 100644 --- a/tests/datasets/test_substation.py +++ b/tests/datasets/test_substation.py @@ -87,9 +87,9 @@ def test_getitem_semantic(self, config: dict[str, Any]) -> None: x = dataset[0] assert isinstance(x, dict), f'Expected dict, got {type(x)}' - assert isinstance(x['image'], torch.Tensor), ( - 'Expected image to be a torch.Tensor' - ) + assert isinstance( + x['image'], torch.Tensor + ), 'Expected image to be a torch.Tensor' assert isinstance(x['mask'], torch.Tensor), 'Expected mask to be a torch.Tensor' def test_len(self, dataset: Substation) -> None: @@ -167,6 +167,44 @@ def test_not_downloaded(self, tmp_path: Path) -> None: root=tmp_path, ) + def test_not_downloaded_with_download( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + filename = 'image_stack' + maskname = 'mask' + source_image_path = os.path.join('tests', 'data', 'substation', filename) + source_mask_path = os.path.join('tests', 'data', 'substation', maskname) + target_image_path = tmp_path / filename + target_mask_path = tmp_path / maskname + + def mock_download_side_effect() -> None: + shutil.copytree(source_image_path, target_image_path) + shutil.copytree(source_mask_path, target_mask_path) + + mock_download = MagicMock(side_effect=mock_download_side_effect) + mock_extract = MagicMock() + monkeypatch.setattr( + 'torchgeo.datasets.substation.Substation._download', mock_download + ) + monkeypatch.setattr( + 'torchgeo.datasets.substation.Substation._extract', mock_extract + ) + + # Create the Substation instance + Substation( + bands=[1, 2, 3], + use_timepoints=True, + mask_2d=True, + timepoint_aggregation='median', + num_of_timepoints=4, + root=tmp_path, + download=True, + ) + + # Verify the mocked methods were called + mock_download.assert_called_once() + mock_extract.assert_called_once() + def test_extract(self, tmp_path: Path) -> None: filename = Substation.filename_images maskname = Substation.filename_masks diff --git a/torchgeo/datasets/substation.py b/torchgeo/datasets/substation.py index 89825d99af5..6df6ddcec05 100644 --- a/torchgeo/datasets/substation.py +++ b/torchgeo/datasets/substation.py @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class Substation(NonGeoDataset): @@ -46,7 +46,7 @@ class Substation(NonGeoDataset): def __init__( self, - root: str, + root: Path, bands: list[int], mask_2d: bool, timepoint_aggregation: str = 'concat', @@ -220,9 +220,7 @@ def _verify(self) -> None: # If dataset files are missing and download is not allowed, raise an error if not getattr(self, 'download', True): - raise DatasetNotFoundError( - f'Dataset files not found in {self.root}. Enable downloading or provide the files.' - ) + raise DatasetNotFoundError(self) # Download and extract the dataset self._download()