diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index d02b7238218..458269e24c7 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -7,34 +7,6 @@ on: branches: - release** jobs: - datasets: - name: datasets - runs-on: ubuntu-latest - steps: - - name: Clone repo - uses: actions/checkout@v4.1.5 - - name: Set up python - uses: actions/setup-python@v5.1.0 - with: - python-version: "3.12" - - name: Cache dependencies - uses: actions/cache@v4.0.2 - id: cache - with: - path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-datasets - - name: Install pip dependencies - if: steps.cache.outputs.cache-hit != 'true' - run: | - pip install .[tests] - pip cache purge - - name: List pip dependencies - run: pip list - - name: Run pytest checks - run: | - pytest --cov=torchgeo --cov-report=xml --durations=10 - python -m torchgeo --help - torchgeo --help integration: name: integration runs-on: ubuntu-latest diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 58d2f8f3b01..9988d3cb5df 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -99,6 +99,39 @@ jobs: uses: codecov/codecov-action@v4.3.1 with: token: ${{ secrets.CODECOV_TOKEN }} + datasets: + name: datasets + runs-on: ubuntu-latest + env: + MPLBACKEND: Agg + steps: + - name: Clone repo + uses: actions/checkout@v4.1.4 + - name: Set up python + uses: actions/setup-python@v5.1.0 + with: + python-version: "3.12" + - name: Cache dependencies + uses: actions/cache@v4.0.2 + id: cache + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/tests.txt') }} + - name: Install pip dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: | + pip install -r requirements/required.txt -r requirements/tests.txt + pip cache purge + - name: List pip dependencies + run: pip list + - name: Run pytest checks + run: | + pytest --cov=torchgeo --cov-report=xml --durations=10 + python3 -m torchgeo --help + - name: Report coverage + uses: codecov/codecov-action@v4.3.0 + with: + token: ${{ secrets.CODECOV_TOKEN }} concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }} cancel-in-progress: true diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 265bac96109..168e291cdb4 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -535,4 +535,5 @@ Errors ------ .. autoclass:: DatasetNotFoundError +.. autoclass:: DependencyNotFoundError .. autoclass:: RGBBandsMissingError diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index 004750c2840..c6d5b7c77bf 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -14,7 +14,6 @@ class TestUSAVarsDataModule: @pytest.fixture def datamodule(self, request: SubRequest) -> USAVarsDataModule: - pytest.importorskip('pandas', minversion='1.1.3') root = os.path.join('tests', 'data', 'usavars') batch_size = 1 num_workers = 0 diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index efa3ac538f8..f2a34b89f4c 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -16,6 +14,8 @@ import torchgeo.datasets.utils from torchgeo.datasets import ADVANCE, DatasetNotFoundError +pytest.importorskip('scipy', minversion='1.7.2') + def download_url(url: str, root: str, *args: str) -> None: shutil.copy(url, root) @@ -37,19 +37,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE: transforms = nn.Identity() return ADVANCE(root, transforms, download=True, checksum=True) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'scipy.io': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: ADVANCE) -> None: - pytest.importorskip('scipy', minversion='1.6.2') x = dataset[0] assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) @@ -71,17 +59,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ADVANCE(str(tmp_path)) - def test_mock_missing_module( - self, dataset: ADVANCE, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='scipy is not installed and is required to use this dataset', - ): - dataset[0] - def test_plot(self, dataset: ADVANCE) -> None: - pytest.importorskip('scipy', minversion='1.6.2') x = dataset[0].copy() dataset.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_chabud.py b/tests/datasets/test_chabud.py index 955104a53fb..074674a1733 100644 --- a/tests/datasets/test_chabud.py +++ b/tests/datasets/test_chabud.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -17,7 +15,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import ChaBuD, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3') +pytest.importorskip('h5py', minversion='3.6') def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None: @@ -47,17 +45,6 @@ def dataset( checksum=True, ) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: ChaBuD) -> None: x = dataset[0] assert isinstance(x, dict) @@ -85,15 +72,6 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ChaBuD(str(tmp_path)) - def test_mock_missing_module( - self, dataset: ChaBuD, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - ChaBuD(dataset.root, download=True, checksum=True) - def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): ChaBuD(bands=['OK', 'BK']) diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 459c6bcdbdf..e31b441529b 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -23,14 +23,14 @@ UnionDataset, ) +pytest.importorskip('zipfile_deflate64') + def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) class TestChesapeake13: - pytest.importorskip('zipfile_deflate64') - @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13: monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url) diff --git a/tests/datasets/test_cropharvest.py b/tests/datasets/test_cropharvest.py index f478cdf53ad..2ad82fca137 100644 --- a/tests/datasets/test_cropharvest.py +++ b/tests/datasets/test_cropharvest.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -16,7 +14,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import CropHarvest, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3') +pytest.importorskip('h5py', minversion='3.6') def download_url(url: str, root: str, filename: str, md5: str) -> None: @@ -24,17 +22,6 @@ def download_url(url: str, root: str, filename: str, md5: str) -> None: class TestCropHarvest: - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url) @@ -89,12 +76,3 @@ def test_plot(self, dataset: CropHarvest) -> None: x = dataset[0].copy() dataset.plot(x, suptitle='Test') plt.close() - - def test_mock_missing_module( - self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - CropHarvest(root=str(tmp_path), download=True)[0] diff --git a/tests/datasets/test_errors.py b/tests/datasets/test_errors.py index 4aa6fbc1c56..f87ab6f03a3 100644 --- a/tests/datasets/test_errors.py +++ b/tests/datasets/test_errors.py @@ -6,7 +6,11 @@ import pytest from torch.utils.data import Dataset -from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError +from torchgeo.datasets import ( + DatasetNotFoundError, + DependencyNotFoundError, + RGBBandsMissingError, +) class TestDatasetNotFoundError: @@ -55,6 +59,11 @@ def test_paths_download(self) -> None: raise DatasetNotFoundError(ds) +def test_missing_dependency() -> None: + with pytest.raises(DependencyNotFoundError, match='pip install foo'): + raise DependencyNotFoundError('foo') + + def test_rgb_bands_missing() -> None: match = 'Dataset does not contain some of the RGB bands' with pytest.raises(RGBBandsMissingError, match=match): diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index e907213c29c..a4c05580b58 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -1,12 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import glob import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -50,19 +48,6 @@ def dataset( transforms = nn.Identity() return IDTReeS(root, split, task, transforms, download=True, checksum=True) - @pytest.fixture(params=['laspy', 'pyvista']) - def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str: - import_orig = builtins.__import__ - package = str(request.param) - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == package: - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - return package - def test_getitem(self, dataset: IDTReeS) -> None: x = dataset[0] assert isinstance(x, dict) @@ -101,24 +86,6 @@ def test_not_extracted(self, tmp_path: Path) -> None: shutil.copy(zipfile, root) IDTReeS(root) - def test_mock_missing_module( - self, dataset: IDTReeS, mock_missing_module: str - ) -> None: - package = mock_missing_module - - if package == 'laspy': - with pytest.raises( - ImportError, - match=f'{package} is not installed and is required to use this dataset', - ): - IDTReeS(dataset.root, dataset.split, dataset.task) - elif package == 'pyvista': - with pytest.raises( - ImportError, - match=f'{package} is not installed and is required to plot point cloud', - ): - dataset.plot_las(0) - def test_plot(self, dataset: IDTReeS) -> None: x = dataset[0].copy() dataset.plot(x, suptitle='Test') diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 7222ff78bbc..7c81f257250 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -76,11 +76,12 @@ def test_plot(self, dataset: LandCoverAIGeo) -> None: class TestLandCoverAI: + pytest.importorskip('cv2', minversion='4.5.4') + @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LandCoverAI: - pytest.importorskip('cv2', minversion='4.4.0') monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url) md5 = 'ff8998857cc8511f644d3f7d0f3688d0' monkeypatch.setattr(LandCoverAI, 'md5', md5) @@ -111,7 +112,6 @@ def test_already_extracted(self, dataset: LandCoverAI) -> None: LandCoverAI(root=dataset.root, download=True) def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: - pytest.importorskip('cv2', minversion='4.4.0') sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' monkeypatch.setattr(LandCoverAI, 'sha256', sha256) url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index 59596f2c0c5..636e9d7a666 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -17,6 +15,8 @@ import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, QuakeSet +pytest.importorskip('h5py', minversion='3.6') + def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -39,26 +39,6 @@ def dataset( root, split, transforms=transforms, download=True, checksum=True ) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - - def test_mock_missing_module( - self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - QuakeSet(dataset.root, download=True, checksum=True) - def test_getitem(self, dataset: QuakeSet) -> None: x = dataset[0] assert isinstance(x, dict) diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py index d84a9709602..d52d2d01194 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -15,6 +15,8 @@ import torchgeo.datasets.utils from torchgeo.datasets import RESISC45, DatasetNotFoundError +pytest.importorskip('rarfile', minversion='4') + def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) @@ -25,8 +27,6 @@ class TestRESISC45: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> RESISC45: - pytest.importorskip('rarfile', minversion='4') - monkeypatch.setattr(torchgeo.datasets.resisc45, 'download_url', download_url) md5 = '5895dea3757ba88707d52f5521c444d3' monkeypatch.setattr(RESISC45, 'md5', md5) diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py index 01907d22cea..d4deb975b3b 100644 --- a/tests/datasets/test_skippd.py +++ b/tests/datasets/test_skippd.py @@ -1,12 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from itertools import product from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -18,7 +16,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import SKIPPD, DatasetNotFoundError -pytest.importorskip('h5py', minversion='3') +pytest.importorskip('h5py', minversion='3.6') def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -53,26 +51,6 @@ def dataset( checksum=True, ) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - - def test_mock_missing_module( - self, dataset: SKIPPD, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - SKIPPD(dataset.root, download=True, checksum=True) - def test_already_extracted(self, dataset: SKIPPD) -> None: SKIPPD(root=dataset.root, download=True) diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index 6daa4057bed..1caf86b6c30 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -1,10 +1,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -15,7 +13,7 @@ from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, So2Sat -pytest.importorskip('h5py', minversion='3') +pytest.importorskip('h5py', minversion='3.6') class TestSo2Sat: @@ -35,17 +33,6 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> So2Sat: transforms = nn.Identity() return So2Sat(root=root, split=split, transforms=transforms, checksum=True) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: So2Sat) -> None: x = dataset[0] assert isinstance(x, dict) @@ -89,12 +76,3 @@ def test_plot_rgb(self, dataset: So2Sat) -> None: RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): dataset.plot(dataset[0], suptitle='Single Band') - - def test_mock_missing_module( - self, dataset: So2Sat, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - So2Sat(dataset.root) diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index bd3e720652f..d8e44885418 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import glob import math import os @@ -20,8 +19,8 @@ from rasterio.crs import CRS import torchgeo.datasets.utils +from torchgeo.datasets import BoundingBox, DependencyNotFoundError from torchgeo.datasets.utils import ( - BoundingBox, array_to_tensor, concat_samples, disambiguate_timestamp, @@ -29,6 +28,7 @@ download_radiant_mlhub_collection, download_radiant_mlhub_dataset, extract_archive, + lazy_import, merge_samples, percentile_normalization, stack_samples, @@ -37,18 +37,6 @@ ) -@pytest.fixture -def mock_missing_module(monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name in ['radiant_mlhub', 'rarfile', 'zipfile_deflate64']: - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - - class MLHubDataset: def download(self, output_dir: str, **kwargs: str) -> None: glob_path = os.path.join( @@ -79,10 +67,6 @@ def download_url(url: str, root: str, *args: str) -> None: shutil.copy(url, root) -def test_mock_missing_module(mock_missing_module: None) -> None: - import sys # noqa: F401 - - @pytest.mark.parametrize( 'src', [ @@ -102,21 +86,6 @@ def test_extract_archive(src: str, tmp_path: Path) -> None: extract_archive(os.path.join('tests', 'data', src), str(tmp_path)) -def test_missing_rarfile(mock_missing_module: None) -> None: - with pytest.raises( - ImportError, - match='rarfile is not installed and is required to extract this dataset', - ): - extract_archive( - os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar') - ) - - -def test_missing_zipfile_deflate64(mock_missing_module: None) -> None: - # Should fallback on Python builtin zipfile - extract_archive(os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')) - - def test_unsupported_scheme() -> None: with pytest.raises( RuntimeError, match='src file has unknown archival/compression scheme' @@ -148,21 +117,6 @@ def test_download_radiant_mlhub_collection( download_radiant_mlhub_collection('', str(tmp_path)) -def test_missing_radiant_mlhub(mock_missing_module: None) -> None: - with pytest.raises( - ImportError, - match='radiant_mlhub is not installed and is required to download this dataset', - ): - download_radiant_mlhub_dataset('', '') - - with pytest.raises( - ImportError, - match='radiant_mlhub is not installed and is required to download this' - + ' collection', - ): - download_radiant_mlhub_collection('', '') - - class TestBoundingBox: def test_repr_str(self) -> None: bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4) @@ -625,3 +579,14 @@ def test_array_to_tensor(array_dtype: 'np.typing.DTypeLike') -> None: # values equal even if they differ. assert array[0].item() == tensor[0].item() assert array[1].item() == tensor[1].item() + + +@pytest.mark.parametrize('name', ['collections', 'collections.abc']) +def test_lazy_import(name: str) -> None: + lazy_import(name) + + +@pytest.mark.parametrize('name', ['foo_bar', 'foo_bar.baz']) +def test_lazy_import_missing(name: str) -> None: + with pytest.raises(DependencyNotFoundError, match='pip install foo-bar\n'): + lazy_import(name) diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index de4c6c2d507..dee46c1db88 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -19,6 +17,7 @@ from torchgeo.datasets import VHR10, DatasetNotFoundError pytest.importorskip('pycocotools') +pytest.importorskip('rarfile', minversion='4') def download_url(url: str, root: str, *args: str) -> None: @@ -30,7 +29,6 @@ class TestVHR10: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> VHR10: - pytest.importorskip('rarfile', minversion='4') monkeypatch.setattr(torchgeo.datasets.vhr10, 'download_url', download_url) monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) url = os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar') @@ -46,17 +44,6 @@ def dataset( transforms = nn.Identity() return VHR10(root, split, transforms, download=True, checksum=True) - @pytest.fixture - def mock_missing_modules(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name in {'pycocotools.coco', 'skimage.measure'}: - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: VHR10) -> None: for i in range(2): x = dataset[i] @@ -93,25 +80,8 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): VHR10(str(tmp_path)) - def test_mock_missing_module( - self, dataset: VHR10, mock_missing_modules: None - ) -> None: - if dataset.split == 'positive': - with pytest.raises( - ImportError, - match='pycocotools is not installed and is required to use this datase', - ): - VHR10(dataset.root, dataset.split) - - with pytest.raises( - ImportError, - match='scikit-image is not installed and is required to plot masks', - ): - x = dataset[0] - dataset.plot(x) - def test_plot(self, dataset: VHR10) -> None: - pytest.importorskip('skimage', minversion='0.18') + pytest.importorskip('skimage', minversion='0.19') x = dataset[1].copy() dataset.plot(x, suptitle='Test') plt.close() diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py index 330866b36f0..bea0d9e8519 100644 --- a/tests/datasets/test_zuericrop.py +++ b/tests/datasets/test_zuericrop.py @@ -1,11 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import builtins import os import shutil from pathlib import Path -from typing import Any import matplotlib.pyplot as plt import pytest @@ -16,7 +14,7 @@ import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriCrop -pytest.importorskip('h5py', minversion='3') +pytest.importorskip('h5py', minversion='3.6') def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: @@ -39,17 +37,6 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ZueriCrop: transforms = nn.Identity() return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True) - @pytest.fixture - def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: - import_orig = builtins.__import__ - - def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: - if name == 'h5py': - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, '__import__', mocked_import) - def test_getitem(self, dataset: ZueriCrop) -> None: x = dataset[0] assert isinstance(x, dict) @@ -82,15 +69,6 @@ def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): ZueriCrop(str(tmp_path)) - def test_mock_missing_module( - self, dataset: ZueriCrop, tmp_path: Path, mock_missing_module: None - ) -> None: - with pytest.raises( - ImportError, - match='h5py is not installed and is required to use this dataset', - ): - ZueriCrop(dataset.root, download=True, checksum=True) - def test_invalid_bands(self) -> None: with pytest.raises(ValueError): ZueriCrop(bands=('OK', 'BK')) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index ca4da48bb78..3402bd0f3fc 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -89,7 +89,7 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: if name.startswith('so2sat') or name == 'quakeset': - pytest.importorskip('h5py', minversion='3') + pytest.importorskip('h5py', minversion='3.6') config = os.path.join('tests', 'conf', name + '.yaml') diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index ef3c6164d98..c62c808c72f 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -71,7 +71,7 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: if name == 'skippd': - pytest.importorskip('h5py', minversion='3') + pytest.importorskip('h5py', minversion='3.6') config = os.path.join('tests', 'conf', name + '.yaml') diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 88cd58c0553..d8b207d5d2d 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -87,12 +87,16 @@ class TestSemanticSegmentationTask: def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - if name == 'naipchesapeake': - pytest.importorskip('zipfile_deflate64') - - if name == 'landcoverai': - sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' - monkeypatch.setattr(LandCoverAI, 'sha256', sha256) + match name: + case 'chabud': + pytest.importorskip('h5py', minversion='3.6') + case 'landcoverai': + sha256 = ( + 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' + ) + monkeypatch.setattr(LandCoverAI, 'sha256', sha256) + case 'naipchesapeake': + pytest.importorskip('zipfile_deflate64') config = os.path.join('tests', 'conf', name + '.yaml') diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index db3c9786547..56303ceaefa 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -37,7 +37,7 @@ from .dfc2022 import DFC2022 from .eddmaps import EDDMapS from .enviroatlas import EnviroAtlas -from .errors import DatasetNotFoundError, RGBBandsMissingError +from .errors import DatasetNotFoundError, DependencyNotFoundError, RGBBandsMissingError from .esri2020 import Esri2020 from .etci2021 import ETCI2021 from .eudem import EUDEM @@ -280,5 +280,6 @@ 'time_series_split', # Errors 'DatasetNotFoundError', + 'DependencyNotFoundError', 'RGBBandsMissingError', ) diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index f14392b24dd..0cc02544715 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import download_and_extract_archive, lazy_import class ADVANCE(NonGeoDataset): @@ -104,7 +104,10 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. + DependencyNotFoundError: If scipy is not installed. """ + lazy_import('scipy.io.wavfile') + self.root = root self.transforms = transforms self.checksum = checksum @@ -191,14 +194,8 @@ def _load_target(self, path: str) -> Tensor: Returns: the target audio """ - try: - from scipy.io import wavfile - except ImportError: - raise ImportError( - 'scipy is not installed and is required to use this dataset' - ) - - array = wavfile.read(path, mmap=True)[1] + siw = lazy_import('scipy.io.wavfile') + array = siw.read(path, mmap=True)[1] tensor = torch.from_numpy(array) tensor = tensor.unsqueeze(0) return tensor diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py index 42e3fdad339..905c2d8496e 100644 --- a/torchgeo/datasets/chabud.py +++ b/torchgeo/datasets/chabud.py @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, percentile_normalization +from .utils import download_url, lazy_import, percentile_normalization class ChaBuD(NonGeoDataset): @@ -96,7 +96,10 @@ def __init__( Raises: AssertionError: If ``split`` or ``bands`` arguments are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. + DependencyNotFoundError: If h5py is not installed. """ + lazy_import('h5py') + assert split in self.folds assert set(bands) <= set(self.all_bands) @@ -111,13 +114,6 @@ def __init__( self._verify() - try: - import h5py # noqa: F401 - except ImportError: - raise ImportError( - 'h5py is not installed and is required to use this dataset' - ) - self.uuids = self._load_uuids() def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -153,8 +149,7 @@ def _load_uuids(self) -> list[str]: Returns: the image uuids """ - import h5py - + h5py = lazy_import('h5py') uuids = [] with h5py.File(self.filepath, 'r') as f: for k, v in f.items(): @@ -173,8 +168,7 @@ def _load_image(self, index: int) -> Tensor: Returns: the image """ - import h5py - + h5py = lazy_import('h5py') uuid = self.uuids[index] with h5py.File(self.filepath, 'r') as f: pre_array = f[uuid]['pre_fire'][:] @@ -199,8 +193,7 @@ def _load_target(self, index: int) -> Tensor: Returns: the target mask """ - import h5py - + h5py = lazy_import('h5py') uuid = self.uuids[index] with h5py.File(self.filepath, 'r') as f: array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1) diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py index 7a41e829971..400b5ceb63c 100644 --- a/torchgeo/datasets/cropharvest.py +++ b/torchgeo/datasets/cropharvest.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import download_url, extract_archive, lazy_import class CropHarvest(NonGeoDataset): @@ -112,14 +112,9 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. - ImportError: If h5py is not installed + DependencyNotFoundError: If h5py is not installed. """ - try: - import h5py # noqa: F401 - except ImportError: - raise ImportError( - 'h5py is not installed and is required to use this dataset' - ) + lazy_import('h5py') self.root = root self.transforms = transforms @@ -210,8 +205,7 @@ def _load_array(self, path: str) -> Tensor: Returns: the image """ - import h5py - + h5py = lazy_import('h5py') filename = os.path.join(path) with h5py.File(filename, 'r') as f: array = f.get('array')[()] diff --git a/torchgeo/datasets/errors.py b/torchgeo/datasets/errors.py index c78a7819227..4b8d7a75982 100644 --- a/torchgeo/datasets/errors.py +++ b/torchgeo/datasets/errors.py @@ -49,6 +49,32 @@ def __init__(self, dataset: Dataset[object]) -> None: super().__init__(msg) +class DependencyNotFoundError(ModuleNotFoundError): + """Raised when an optional dataset dependency is not installed. + + .. versionadded:: 0.6 + """ + + def __init__(self, name: str) -> None: + """Initialize a new DependencyNotFoundError instance. + + Args: + name: Name of missing dependency. + """ + msg = f"""\ +{name} is not installed and is required to use this dataset. Either run: + +$ pip install {name} + +to install just this dependency, or: + +$ pip install torchgeo[datasets] + +to install all optional dataset dependencies.""" + + super().__init__(msg) + + class RGBBandsMissingError(ValueError): """Raised when a dataset is missing RGB bands for plotting. diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index b36f2a9db9c..6ddc0c9661b 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -22,7 +22,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import download_url, extract_archive, lazy_import class IDTReeS(NonGeoDataset): @@ -92,6 +92,11 @@ class IDTReeS(NonGeoDataset): * https://doi.org/10.1101/2021.08.06.453503 + This dataset requires the following additional libraries to be installed: + + * `laspy `_ to read lidar point clouds + * `pyvista `_ to plot lidar point clouds + .. versionadded:: 0.2 """ @@ -167,11 +172,14 @@ def __init__( checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - ImportError: if laspy is not installed DatasetNotFoundError: If dataset is not found and *download* is False. + DependencyNotFoundError: If laspy is not installed. """ + lazy_import('laspy') + assert split in ['train', 'test'] assert task in ['task1', 'task2'] + self.root = root self.split = split self.task = task @@ -182,14 +190,6 @@ def __init__( self.idx2class = {i: c for i, c in enumerate(self.classes)} self.num_classes = len(self.classes) self._verify() - - try: - import laspy # noqa: F401 - except ImportError: - raise ImportError( - 'laspy is not installed and is required to use this dataset' - ) - self.images, self.geometries, self.labels = self._load(root) def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -263,8 +263,7 @@ def _load_las(self, path: str) -> Tensor: Returns: the point cloud """ - import laspy - + laspy = lazy_import('laspy') las = laspy.read(path) array: 'np.typing.NDArray[np.int_]' = np.stack([las.x, las.y, las.z], axis=0) tensor = torch.from_numpy(array) @@ -561,19 +560,13 @@ def plot_las(self, index: int) -> 'pyvista.Plotter': # type: ignore[name-define pyvista.PolyData object. Run pyvista.plot(point_cloud, ...) to display Raises: - ImportError: if pyvista is not installed + DependencyNotFoundError: If laspy or pyvista are not installed. .. versionchanged:: 0.4 Ported from Open3D to PyVista, *colormap* parameter removed. """ - try: - import pyvista # noqa: F401 - except ImportError: - raise ImportError( - 'pyvista is not installed and is required to plot point clouds' - ) - import laspy - + laspy = lazy_import('laspy') + pyvista = lazy_import('pyvista') path = self.images[index] path = path.replace('RGB', 'LAS').replace('.tif', '.las') las = laspy.read(path) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index 7ca10f02718..ce5d9a3bd2c 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, percentile_normalization +from .utils import download_url, lazy_import, percentile_normalization class QuakeSet(NonGeoDataset): @@ -85,8 +85,10 @@ def __init__( Raises: AssertionError: If ``split`` argument is invalid. DatasetNotFoundError: If dataset is not found and *download* is False. - ImportError: if h5py is not installed + DependencyNotFoundError: If h5py is not installed. """ + lazy_import('h5py') + assert split in self.splits self.root = root @@ -95,16 +97,7 @@ def __init__( self.download = download self.checksum = checksum self.filepath = os.path.join(root, self.filename) - self._verify() - - try: - import h5py # noqa: F401 - except ImportError: - raise ImportError( - 'h5py is not installed and is required to use this dataset' - ) - self.data = self._load_data() def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -141,8 +134,7 @@ def _load_data(self) -> list[dict[str, Any]]: Returns: the sample keys, patches, images, labels, and magnitudes """ - import h5py - + h5py = lazy_import('h5py') data = [] with h5py.File(self.filepath) as f: for k in sorted(f.keys()): @@ -185,7 +177,7 @@ def _load_image(self, index: int) -> Tensor: Returns: the image """ - import h5py + h5py = lazy_import('h5py') key = self.data[index]['key'] patch = self.data[index]['patch'] diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 376e4a1ce47..0d111ae15b9 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import download_url, extract_archive, lazy_import class SKIPPD(NonGeoDataset): @@ -53,6 +53,12 @@ class SKIPPD(NonGeoDataset): * https://doi.org/10.48550/arXiv.2207.00913 + .. note:: + + This dataset requires the following additional library to be installed: + + * ``_ to load the dataset + .. versionadded:: 0.5 """ @@ -94,8 +100,10 @@ def __init__( Raises: AssertionError: if ``task`` or ``split`` is invalid DatasetNotFoundError: If dataset is not found and *download* is False. - ImportError: if h5py is not installed + DependencyNotFoundError: If h5py is not installed. """ + lazy_import('h5py') + assert ( split in self.valid_splits ), f'Please choose one of these valid data splits {self.valid_splits}.' @@ -110,14 +118,6 @@ def __init__( self.transforms = transforms self.download = download self.checksum = checksum - - try: - import h5py # noqa: F401 - except ImportError: - raise ImportError( - 'h5py is not installed and is required to use this dataset' - ) - self._verify() def __len__(self) -> int: @@ -126,8 +126,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - import h5py - + h5py = lazy_import('h5py') with h5py.File( os.path.join(self.root, self.data_file_name.format(self.task)), 'r' ) as f: @@ -161,8 +160,7 @@ def _load_image(self, index: int) -> Tensor: Returns: image tensor at index """ - import h5py - + h5py = lazy_import('h5py') with h5py.File( os.path.join(self.root, self.data_file_name.format(self.task)), 'r' ) as f: @@ -187,8 +185,7 @@ def _load_features(self, index: int) -> dict[str, str | Tensor]: Returns: label tensor at index """ - import h5py - + h5py = lazy_import('h5py') with h5py.File( os.path.join(self.root, self.data_file_name.format(self.task)), 'r' ) as f: diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 77281ce3e7b..ee6f1435474 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, percentile_normalization +from .utils import check_integrity, lazy_import, percentile_normalization class So2Sat(NonGeoDataset): @@ -97,6 +97,12 @@ class So2Sat(NonGeoDataset): done or manually downloaded from https://mediatum.ub.tum.de/1613658 + + .. note:: + + This dataset requires the following additional library to be installed: + + * ``_ to load the dataset """ # noqa: E501 versions = ['2', '3_random', '3_block', '3_culture_10'] @@ -210,6 +216,7 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found. + DependencyNotFoundError: If h5py is not installed. .. versionadded:: 0.3 The *bands* parameter. @@ -217,12 +224,8 @@ def __init__( .. versionadded:: 0.5 The *version* parameter. """ - try: - import h5py # noqa: F401 - except ImportError: - raise ImportError( - 'h5py is not installed and is required to use this dataset' - ) + h5py = lazy_import('h5py') + assert version in self.versions assert split in self.filenames_by_version[version] @@ -272,8 +275,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at that index """ - import h5py - + h5py = lazy_import('h5py') with h5py.File(self.fn, 'r') as f: s1 = f['sen1'][index].astype(np.float64) # convert from None: self.kwargs = kwargs def __enter__(self) -> Any: - try: - import rarfile - except ImportError: - raise ImportError( - 'rarfile is not installed and is required to extract this dataset' - ) - + rarfile = lazy_import('rarfile') # TODO: catch exception for when rarfile is installed but not # unrar/unar/bsdtar return rarfile.RarFile(*self.args, **self.kwargs) @@ -157,14 +154,11 @@ def download_radiant_mlhub_dataset( api_key: the API key to use for all requests from the session. Can also be passed in via the ``MLHUB_API_KEY`` environment variable, or configured in ``~/.mlhub/profiles``. - """ - try: - import radiant_mlhub - except ImportError: - raise ImportError( - 'radiant_mlhub is not installed and is required to download this dataset' - ) + Raises: + DependencyNotFoundError: If radiant_mlhub is not installed. + """ + radiant_mlhub = lazy_import('radiant_mlhub') dataset = radiant_mlhub.Dataset.fetch(dataset_id, api_key=api_key) dataset.download(output_dir=download_root, api_key=api_key) @@ -180,14 +174,11 @@ def download_radiant_mlhub_collection( api_key: the API key to use for all requests from the session. Can also be passed in via the ``MLHUB_API_KEY`` environment variable, or configured in ``~/.mlhub/profiles``. - """ - try: - import radiant_mlhub - except ImportError: - raise ImportError( - 'radiant_mlhub is not installed and is required to download this collection' - ) + Raises: + DependencyNotFoundError: If radiant_mlhub is not installed. + """ + radiant_mlhub = lazy_import('radiant_mlhub') collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key) collection.download(output_dir=download_root, api_key=api_key) @@ -773,3 +764,25 @@ def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor: elif array.dtype == np.uint32: array = array.astype(np.int64) return torch.tensor(array) + + +def lazy_import(name: str) -> Any: + """Lazy import of *name*. + + Args: + name: Name of module to import. + + Raises: + DependencyNotFoundError: If *name* is not installed. + + .. versionadded:: 0.6 + """ + try: + return importlib.import_module(name) + except ModuleNotFoundError: + # Map from import name to package name on PyPI + name = name.split('.')[0].replace('_', '-') + module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name) + module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'} + name = module_to_pypi[name] + raise DependencyNotFoundError(name) from None diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index 13937336281..7f77b29513c 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -17,7 +17,12 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, download_url +from .utils import ( + check_integrity, + download_and_extract_archive, + download_url, + lazy_import, +) def convert_coco_poly_to_mask( @@ -32,13 +37,15 @@ def convert_coco_poly_to_mask( Returns: Tensor: Mask tensor - """ - from pycocotools import mask as coco_mask # noqa: F401 + Raises: + DependencyNotFoundError: If pycocotools is not installed. + """ + pycocotools = lazy_import('pycocotools') masks = [] for polygons in segmentations: - rles = coco_mask.frPyObjects(polygons, height, width) - mask = coco_mask.decode(rles) + rles = pycocotools.mask.frPyObjects(polygons, height, width) + mask = pycocotools.mask.decode(rles) mask = torch.as_tensor(mask, dtype=torch.uint8) mask = mask.any(dim=2) masks.append(mask) @@ -196,8 +203,9 @@ def __init__( Raises: AssertionError: if ``split`` argument is invalid - ImportError: if ``split="positive"`` and pycocotools is not installed DatasetNotFoundError: If dataset is not found and *download* is False. + DependencyNotFoundError: if ``split="positive"`` and pycocotools is + not installed. """ assert split in ['positive', 'negative'] @@ -213,20 +221,12 @@ def __init__( raise DatasetNotFoundError(self) if split == 'positive': - # Must be installed to parse annotations file - try: - from pycocotools.coco import COCO # noqa: F401 - except ImportError: - raise ImportError( - 'pycocotools is not installed and is required to use this dataset' - ) - - self.coco = COCO( + pc = lazy_import('pycocotools.coco') + self.coco = pc.COCO( os.path.join( self.root, 'NWPU VHR-10 dataset', self.target_meta['filename'] ) ) - self.coco_convert = ConvertCocoAnnotations() self.ids = list(sorted(self.coco.imgs.keys())) @@ -381,7 +381,7 @@ def plot( Raises: AssertionError: if ``show_feats`` argument is invalid - ImportError: if plotting masks and scikit-image is not installed + DependencyNotFoundError: If plotting masks and scikit-image is not installed. .. versionadded:: 0.4 """ @@ -397,12 +397,7 @@ def plot( return fig if show_feats != 'boxes': - try: - from skimage.measure import find_contours # noqa: F401 - except ImportError: - raise ImportError( - 'scikit-image is not installed and is required to plot masks.' - ) + skimage = lazy_import('skimage') image = sample['image'].permute(1, 2, 0).numpy() boxes = sample['boxes'].cpu().numpy() @@ -465,7 +460,7 @@ def plot( # Add masks if show_feats in {'masks', 'both'} and 'masks' in sample: mask = masks[i] - contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call] + contours = skimage.measure.find_contours(mask, 0.5) for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( @@ -517,7 +512,7 @@ def plot( # Add masks if show_pred_masks: mask = prediction_masks[i] - contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call] + contours = skimage.measure.find_contours(mask, 0.5) for verts in contours: verts = np.fliplr(verts) p = patches.Polygon( diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index aeb972898c2..e1a5f4d2870 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import download_url, percentile_normalization +from .utils import download_url, lazy_import, percentile_normalization class ZueriCrop(NonGeoDataset): @@ -82,7 +82,10 @@ def __init__( Raises: DatasetNotFoundError: If dataset is not found and *download* is False. + DependencyNotFoundError: If h5py is not installed. """ + lazy_import('h5py') + self._validate_bands(bands) self.band_indices = torch.tensor( [self.band_names.index(b) for b in bands] @@ -97,13 +100,6 @@ def __init__( self._verify() - try: - import h5py # noqa: F401 - except ImportError: - raise ImportError( - 'h5py is not installed and is required to use this dataset' - ) - def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. @@ -129,8 +125,7 @@ def __len__(self) -> int: Returns: length of the dataset """ - import h5py - + h5py = lazy_import('h5py') with h5py.File(self.filepath, 'r') as f: length: int = f['data'].shape[0] return length @@ -144,8 +139,7 @@ def _load_image(self, index: int) -> Tensor: Returns: the image """ - import h5py - + h5py = lazy_import('h5py') with h5py.File(self.filepath, 'r') as f: array = f['data'][index, ...]