diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index d02b7238218..458269e24c7 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -7,34 +7,6 @@ on:
branches:
- release**
jobs:
- datasets:
- name: datasets
- runs-on: ubuntu-latest
- steps:
- - name: Clone repo
- uses: actions/checkout@v4.1.5
- - name: Set up python
- uses: actions/setup-python@v5.1.0
- with:
- python-version: "3.12"
- - name: Cache dependencies
- uses: actions/cache@v4.0.2
- id: cache
- with:
- path: ${{ env.pythonLocation }}
- key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-datasets
- - name: Install pip dependencies
- if: steps.cache.outputs.cache-hit != 'true'
- run: |
- pip install .[tests]
- pip cache purge
- - name: List pip dependencies
- run: pip list
- - name: Run pytest checks
- run: |
- pytest --cov=torchgeo --cov-report=xml --durations=10
- python -m torchgeo --help
- torchgeo --help
integration:
name: integration
runs-on: ubuntu-latest
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index 58d2f8f3b01..9988d3cb5df 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -99,6 +99,39 @@ jobs:
uses: codecov/codecov-action@v4.3.1
with:
token: ${{ secrets.CODECOV_TOKEN }}
+ datasets:
+ name: datasets
+ runs-on: ubuntu-latest
+ env:
+ MPLBACKEND: Agg
+ steps:
+ - name: Clone repo
+        uses: actions/checkout@v4.1.5
+ - name: Set up python
+ uses: actions/setup-python@v5.1.0
+ with:
+ python-version: "3.12"
+ - name: Cache dependencies
+ uses: actions/cache@v4.0.2
+ id: cache
+ with:
+ path: ${{ env.pythonLocation }}
+ key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/tests.txt') }}
+ - name: Install pip dependencies
+ if: steps.cache.outputs.cache-hit != 'true'
+ run: |
+ pip install -r requirements/required.txt -r requirements/tests.txt
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
+ - name: Run pytest checks
+ run: |
+ pytest --cov=torchgeo --cov-report=xml --durations=10
+ python3 -m torchgeo --help
+ - name: Report coverage
+        uses: codecov/codecov-action@v4.3.1
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
cancel-in-progress: true
diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 265bac96109..168e291cdb4 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -535,4 +535,5 @@ Errors
------
.. autoclass:: DatasetNotFoundError
+.. autoclass:: DependencyNotFoundError
.. autoclass:: RGBBandsMissingError
diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py
index 004750c2840..c6d5b7c77bf 100644
--- a/tests/datamodules/test_usavars.py
+++ b/tests/datamodules/test_usavars.py
@@ -14,7 +14,6 @@
class TestUSAVarsDataModule:
@pytest.fixture
def datamodule(self, request: SubRequest) -> USAVarsDataModule:
- pytest.importorskip('pandas', minversion='1.1.3')
root = os.path.join('tests', 'data', 'usavars')
batch_size = 1
num_workers = 0
diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py
index efa3ac538f8..f2a34b89f4c 100644
--- a/tests/datasets/test_advance.py
+++ b/tests/datasets/test_advance.py
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -16,6 +14,8 @@
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE, DatasetNotFoundError
+pytest.importorskip('scipy', minversion='1.7.2')
+
def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)
@@ -37,19 +37,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE:
transforms = nn.Identity()
return ADVANCE(root, transforms, download=True, checksum=True)
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == 'scipy.io':
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
def test_getitem(self, dataset: ADVANCE) -> None:
- pytest.importorskip('scipy', minversion='1.6.2')
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
@@ -71,17 +59,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ADVANCE(str(tmp_path))
- def test_mock_missing_module(
- self, dataset: ADVANCE, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match='scipy is not installed and is required to use this dataset',
- ):
- dataset[0]
-
def test_plot(self, dataset: ADVANCE) -> None:
- pytest.importorskip('scipy', minversion='1.6.2')
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()
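
The change above replaces the per-test `importorskip` calls and the `mock_missing_module` fixture with a single module-level guard. A minimal sketch of the pattern, assuming only that `scipy` is the optional dependency being guarded:

    import pytest

    # Module-scoped guard: evaluated once at collection time. If scipy
    # (>= 1.7.2) is missing, pytest skips every test in this file, so no
    # per-test importorskip calls or import-mocking fixtures are needed.
    scipy = pytest.importorskip('scipy', minversion='1.7.2')


    def test_needs_scipy() -> None:
        # Runs only when the module-level guard succeeded.
        assert hasattr(scipy, '__version__')
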
diff --git a/tests/datasets/test_chabud.py b/tests/datasets/test_chabud.py
index 955104a53fb..074674a1733 100644
--- a/tests/datasets/test_chabud.py
+++ b/tests/datasets/test_chabud.py
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -17,7 +15,7 @@
import torchgeo.datasets.utils
from torchgeo.datasets import ChaBuD, DatasetNotFoundError
-pytest.importorskip('h5py', minversion='3')
+pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
@@ -47,17 +45,6 @@ def dataset(
checksum=True,
)
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == 'h5py':
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
def test_getitem(self, dataset: ChaBuD) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -85,15 +72,6 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ChaBuD(str(tmp_path))
- def test_mock_missing_module(
- self, dataset: ChaBuD, tmp_path: Path, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match='h5py is not installed and is required to use this dataset',
- ):
- ChaBuD(dataset.root, download=True, checksum=True)
-
def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
ChaBuD(bands=['OK', 'BK'])
diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py
index 459c6bcdbdf..e31b441529b 100644
--- a/tests/datasets/test_chesapeake.py
+++ b/tests/datasets/test_chesapeake.py
@@ -23,14 +23,14 @@
UnionDataset,
)
+pytest.importorskip('zipfile_deflate64')
+
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
class TestChesapeake13:
- pytest.importorskip('zipfile_deflate64')
-
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13:
monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url)
diff --git a/tests/datasets/test_cropharvest.py b/tests/datasets/test_cropharvest.py
index f478cdf53ad..2ad82fca137 100644
--- a/tests/datasets/test_cropharvest.py
+++ b/tests/datasets/test_cropharvest.py
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -16,7 +14,7 @@
import torchgeo.datasets.utils
from torchgeo.datasets import CropHarvest, DatasetNotFoundError
-pytest.importorskip('h5py', minversion='3')
+pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, filename: str, md5: str) -> None:
@@ -24,17 +22,6 @@ def download_url(url: str, root: str, filename: str, md5: str) -> None:
class TestCropHarvest:
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == 'h5py':
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest:
monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url)
@@ -89,12 +76,3 @@ def test_plot(self, dataset: CropHarvest) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()
-
- def test_mock_missing_module(
- self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match='h5py is not installed and is required to use this dataset',
- ):
- CropHarvest(root=str(tmp_path), download=True)[0]
diff --git a/tests/datasets/test_errors.py b/tests/datasets/test_errors.py
index 4aa6fbc1c56..f87ab6f03a3 100644
--- a/tests/datasets/test_errors.py
+++ b/tests/datasets/test_errors.py
@@ -6,7 +6,11 @@
import pytest
from torch.utils.data import Dataset
-from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError
+from torchgeo.datasets import (
+ DatasetNotFoundError,
+ DependencyNotFoundError,
+ RGBBandsMissingError,
+)
class TestDatasetNotFoundError:
@@ -55,6 +59,11 @@ def test_paths_download(self) -> None:
raise DatasetNotFoundError(ds)
+def test_missing_dependency() -> None:
+ with pytest.raises(DependencyNotFoundError, match='pip install foo'):
+ raise DependencyNotFoundError('foo')
+
+
def test_rgb_bands_missing() -> None:
match = 'Dataset does not contain some of the RGB bands'
with pytest.raises(RGBBandsMissingError, match=match):
diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py
index e907213c29c..a4c05580b58 100644
--- a/tests/datasets/test_idtrees.py
+++ b/tests/datasets/test_idtrees.py
@@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import glob
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -50,19 +48,6 @@ def dataset(
transforms = nn.Identity()
return IDTReeS(root, split, task, transforms, download=True, checksum=True)
- @pytest.fixture(params=['laspy', 'pyvista'])
- def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
- import_orig = builtins.__import__
- package = str(request.param)
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == package:
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
- return package
-
def test_getitem(self, dataset: IDTReeS) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -101,24 +86,6 @@ def test_not_extracted(self, tmp_path: Path) -> None:
shutil.copy(zipfile, root)
IDTReeS(root)
- def test_mock_missing_module(
- self, dataset: IDTReeS, mock_missing_module: str
- ) -> None:
- package = mock_missing_module
-
- if package == 'laspy':
- with pytest.raises(
- ImportError,
- match=f'{package} is not installed and is required to use this dataset',
- ):
- IDTReeS(dataset.root, dataset.split, dataset.task)
- elif package == 'pyvista':
- with pytest.raises(
- ImportError,
- match=f'{package} is not installed and is required to plot point cloud',
- ):
- dataset.plot_las(0)
-
def test_plot(self, dataset: IDTReeS) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py
index 7222ff78bbc..7c81f257250 100644
--- a/tests/datasets/test_landcoverai.py
+++ b/tests/datasets/test_landcoverai.py
@@ -76,11 +76,12 @@ def test_plot(self, dataset: LandCoverAIGeo) -> None:
class TestLandCoverAI:
+ pytest.importorskip('cv2', minversion='4.5.4')
+
@pytest.fixture(params=['train', 'val', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> LandCoverAI:
- pytest.importorskip('cv2', minversion='4.4.0')
monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url)
md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
monkeypatch.setattr(LandCoverAI, 'md5', md5)
@@ -111,7 +112,6 @@ def test_already_extracted(self, dataset: LandCoverAI) -> None:
LandCoverAI(root=dataset.root, download=True)
def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
- pytest.importorskip('cv2', minversion='4.4.0')
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py
index 59596f2c0c5..636e9d7a666 100644
--- a/tests/datasets/test_quakeset.py
+++ b/tests/datasets/test_quakeset.py
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -17,6 +15,8 @@
import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, QuakeSet
+pytest.importorskip('h5py', minversion='3.6')
+
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@@ -39,26 +39,6 @@ def dataset(
root, split, transforms=transforms, download=True, checksum=True
)
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == 'h5py':
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
- def test_mock_missing_module(
- self, dataset: QuakeSet, tmp_path: Path, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match='h5py is not installed and is required to use this dataset',
- ):
- QuakeSet(dataset.root, download=True, checksum=True)
-
def test_getitem(self, dataset: QuakeSet) -> None:
x = dataset[0]
assert isinstance(x, dict)
diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py
index d84a9709602..d52d2d01194 100644
--- a/tests/datasets/test_resisc45.py
+++ b/tests/datasets/test_resisc45.py
@@ -15,6 +15,8 @@
import torchgeo.datasets.utils
from torchgeo.datasets import RESISC45, DatasetNotFoundError
+pytest.importorskip('rarfile', minversion='4')
+
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@@ -25,8 +27,6 @@ class TestRESISC45:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> RESISC45:
- pytest.importorskip('rarfile', minversion='4')
-
monkeypatch.setattr(torchgeo.datasets.resisc45, 'download_url', download_url)
md5 = '5895dea3757ba88707d52f5521c444d3'
monkeypatch.setattr(RESISC45, 'md5', md5)
diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py
index 01907d22cea..d4deb975b3b 100644
--- a/tests/datasets/test_skippd.py
+++ b/tests/datasets/test_skippd.py
@@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from itertools import product
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -18,7 +16,7 @@
import torchgeo.datasets.utils
from torchgeo.datasets import SKIPPD, DatasetNotFoundError
-pytest.importorskip('h5py', minversion='3')
+pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -53,26 +51,6 @@ def dataset(
checksum=True,
)
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == 'h5py':
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
- def test_mock_missing_module(
- self, dataset: SKIPPD, tmp_path: Path, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match='h5py is not installed and is required to use this dataset',
- ):
- SKIPPD(dataset.root, download=True, checksum=True)
-
def test_already_extracted(self, dataset: SKIPPD) -> None:
SKIPPD(root=dataset.root, download=True)
diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py
index 6daa4057bed..1caf86b6c30 100644
--- a/tests/datasets/test_so2sat.py
+++ b/tests/datasets/test_so2sat.py
@@ -1,10 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -15,7 +13,7 @@
from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, So2Sat
-pytest.importorskip('h5py', minversion='3')
+pytest.importorskip('h5py', minversion='3.6')
class TestSo2Sat:
@@ -35,17 +33,6 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> So2Sat:
transforms = nn.Identity()
return So2Sat(root=root, split=split, transforms=transforms, checksum=True)
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == 'h5py':
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
def test_getitem(self, dataset: So2Sat) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -89,12 +76,3 @@ def test_plot_rgb(self, dataset: So2Sat) -> None:
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
):
dataset.plot(dataset[0], suptitle='Single Band')
-
- def test_mock_missing_module(
- self, dataset: So2Sat, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match='h5py is not installed and is required to use this dataset',
- ):
- So2Sat(dataset.root)
diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py
index bd3e720652f..d8e44885418 100644
--- a/tests/datasets/test_utils.py
+++ b/tests/datasets/test_utils.py
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import glob
import math
import os
@@ -20,8 +19,8 @@
from rasterio.crs import CRS
import torchgeo.datasets.utils
+from torchgeo.datasets import BoundingBox, DependencyNotFoundError
from torchgeo.datasets.utils import (
- BoundingBox,
array_to_tensor,
concat_samples,
disambiguate_timestamp,
@@ -29,6 +28,7 @@
download_radiant_mlhub_collection,
download_radiant_mlhub_dataset,
extract_archive,
+ lazy_import,
merge_samples,
percentile_normalization,
stack_samples,
@@ -37,18 +37,6 @@
)
-@pytest.fixture
-def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name in ['radiant_mlhub', 'rarfile', 'zipfile_deflate64']:
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
-
class MLHubDataset:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
@@ -79,10 +67,6 @@ def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)
-def test_mock_missing_module(mock_missing_module: None) -> None:
- import sys # noqa: F401
-
-
@pytest.mark.parametrize(
'src',
[
@@ -102,21 +86,6 @@ def test_extract_archive(src: str, tmp_path: Path) -> None:
extract_archive(os.path.join('tests', 'data', src), str(tmp_path))
-def test_missing_rarfile(mock_missing_module: None) -> None:
- with pytest.raises(
- ImportError,
- match='rarfile is not installed and is required to extract this dataset',
- ):
- extract_archive(
- os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar')
- )
-
-
-def test_missing_zipfile_deflate64(mock_missing_module: None) -> None:
- # Should fallback on Python builtin zipfile
- extract_archive(os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'))
-
-
def test_unsupported_scheme() -> None:
with pytest.raises(
RuntimeError, match='src file has unknown archival/compression scheme'
@@ -148,21 +117,6 @@ def test_download_radiant_mlhub_collection(
download_radiant_mlhub_collection('', str(tmp_path))
-def test_missing_radiant_mlhub(mock_missing_module: None) -> None:
- with pytest.raises(
- ImportError,
- match='radiant_mlhub is not installed and is required to download this dataset',
- ):
- download_radiant_mlhub_dataset('', '')
-
- with pytest.raises(
- ImportError,
- match='radiant_mlhub is not installed and is required to download this'
- + ' collection',
- ):
- download_radiant_mlhub_collection('', '')
-
-
class TestBoundingBox:
def test_repr_str(self) -> None:
bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
@@ -625,3 +579,14 @@ def test_array_to_tensor(array_dtype: 'np.typing.DTypeLike') -> None:
# values equal even if they differ.
assert array[0].item() == tensor[0].item()
assert array[1].item() == tensor[1].item()
+
+
+@pytest.mark.parametrize('name', ['collections', 'collections.abc'])
+def test_lazy_import(name: str) -> None:
+ lazy_import(name)
+
+
+@pytest.mark.parametrize('name', ['foo_bar', 'foo_bar.baz'])
+def test_lazy_import_missing(name: str) -> None:
+ with pytest.raises(DependencyNotFoundError, match='pip install foo-bar\n'):
+ lazy_import(name)
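
The `foo_bar`/`foo-bar` pair in `test_lazy_import_missing` exercises the import-name to PyPI-name normalization that `lazy_import` performs before raising. A standalone sketch of that mapping, mirroring the logic added to `torchgeo/datasets/utils.py` further down:

    import collections


    def pypi_name(module: str) -> str:
        # Keep only the top-level package and swap underscores for hyphens,
        # e.g. 'foo_bar.baz' -> 'foo-bar'.
        name = module.split('.')[0].replace('_', '-')
        # Some packages publish under a different name than they import as;
        # everything else falls through to the normalized name unchanged.
        module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name)
        module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
        return module_to_pypi[name]


    assert pypi_name('foo_bar.baz') == 'foo-bar'
    assert pypi_name('cv2') == 'opencv-python'
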
diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py
index de4c6c2d507..dee46c1db88 100644
--- a/tests/datasets/test_vhr10.py
+++ b/tests/datasets/test_vhr10.py
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -19,6 +17,7 @@
from torchgeo.datasets import VHR10, DatasetNotFoundError
pytest.importorskip('pycocotools')
+pytest.importorskip('rarfile', minversion='4')
def download_url(url: str, root: str, *args: str) -> None:
@@ -30,7 +29,6 @@ class TestVHR10:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> VHR10:
- pytest.importorskip('rarfile', minversion='4')
monkeypatch.setattr(torchgeo.datasets.vhr10, 'download_url', download_url)
monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
url = os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar')
@@ -46,17 +44,6 @@ def dataset(
transforms = nn.Identity()
return VHR10(root, split, transforms, download=True, checksum=True)
- @pytest.fixture
- def mock_missing_modules(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name in {'pycocotools.coco', 'skimage.measure'}:
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
def test_getitem(self, dataset: VHR10) -> None:
for i in range(2):
x = dataset[i]
@@ -93,25 +80,8 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
VHR10(str(tmp_path))
- def test_mock_missing_module(
- self, dataset: VHR10, mock_missing_modules: None
- ) -> None:
- if dataset.split == 'positive':
- with pytest.raises(
- ImportError,
- match='pycocotools is not installed and is required to use this datase',
- ):
- VHR10(dataset.root, dataset.split)
-
- with pytest.raises(
- ImportError,
- match='scikit-image is not installed and is required to plot masks',
- ):
- x = dataset[0]
- dataset.plot(x)
-
def test_plot(self, dataset: VHR10) -> None:
- pytest.importorskip('skimage', minversion='0.18')
+ pytest.importorskip('skimage', minversion='0.19')
x = dataset[1].copy()
dataset.plot(x, suptitle='Test')
plt.close()
diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py
index 330866b36f0..bea0d9e8519 100644
--- a/tests/datasets/test_zuericrop.py
+++ b/tests/datasets/test_zuericrop.py
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -16,7 +14,7 @@
import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriCrop
-pytest.importorskip('h5py', minversion='3')
+pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -39,17 +37,6 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ZueriCrop:
transforms = nn.Identity()
return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True)
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == 'h5py':
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, '__import__', mocked_import)
-
def test_getitem(self, dataset: ZueriCrop) -> None:
x = dataset[0]
assert isinstance(x, dict)
@@ -82,15 +69,6 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ZueriCrop(str(tmp_path))
- def test_mock_missing_module(
- self, dataset: ZueriCrop, tmp_path: Path, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match='h5py is not installed and is required to use this dataset',
- ):
- ZueriCrop(dataset.root, download=True, checksum=True)
-
def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):
ZueriCrop(bands=('OK', 'BK'))
diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py
index ca4da48bb78..3402bd0f3fc 100644
--- a/tests/trainers/test_classification.py
+++ b/tests/trainers/test_classification.py
@@ -89,7 +89,7 @@ def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
if name.startswith('so2sat') or name == 'quakeset':
- pytest.importorskip('h5py', minversion='3')
+ pytest.importorskip('h5py', minversion='3.6')
config = os.path.join('tests', 'conf', name + '.yaml')
diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py
index ef3c6164d98..c62c808c72f 100644
--- a/tests/trainers/test_regression.py
+++ b/tests/trainers/test_regression.py
@@ -71,7 +71,7 @@ def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
if name == 'skippd':
- pytest.importorskip('h5py', minversion='3')
+ pytest.importorskip('h5py', minversion='3.6')
config = os.path.join('tests', 'conf', name + '.yaml')
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index 88cd58c0553..d8b207d5d2d 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -87,12 +87,16 @@ class TestSemanticSegmentationTask:
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
- if name == 'naipchesapeake':
- pytest.importorskip('zipfile_deflate64')
-
- if name == 'landcoverai':
- sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
- monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
+ match name:
+ case 'chabud':
+ pytest.importorskip('h5py', minversion='3.6')
+ case 'landcoverai':
+ sha256 = (
+ 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
+ )
+ monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
+ case 'naipchesapeake':
+ pytest.importorskip('zipfile_deflate64')
config = os.path.join('tests', 'conf', name + '.yaml')
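
The rewrite above folds the per-dataset setup into one `match` statement; on a plain string this behaves like an if/elif chain with no fall-through. A minimal sketch (the dataset names mirror the diff; the return values are illustrative only):

    def special_setup(name: str) -> str:
        # Cases are tested in order and at most one arm runs;
        # '_' is the wildcard default.
        match name:
            case 'chabud':
                return 'needs h5py >= 3.6'
            case 'landcoverai':
                return 'patches the sha256 checksum'
            case 'naipchesapeake':
                return 'needs zipfile_deflate64'
            case _:
                return 'no special setup'


    assert special_setup('landcoverai') == 'patches the sha256 checksum'
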
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index db3c9786547..56303ceaefa 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -37,7 +37,7 @@
from .dfc2022 import DFC2022
from .eddmaps import EDDMapS
from .enviroatlas import EnviroAtlas
-from .errors import DatasetNotFoundError, RGBBandsMissingError
+from .errors import DatasetNotFoundError, DependencyNotFoundError, RGBBandsMissingError
from .esri2020 import Esri2020
from .etci2021 import ETCI2021
from .eudem import EUDEM
@@ -280,5 +280,6 @@
'time_series_split',
# Errors
'DatasetNotFoundError',
+ 'DependencyNotFoundError',
'RGBBandsMissingError',
)
diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py
index f14392b24dd..0cc02544715 100644
--- a/torchgeo/datasets/advance.py
+++ b/torchgeo/datasets/advance.py
@@ -17,7 +17,7 @@
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
-from .utils import download_and_extract_archive
+from .utils import download_and_extract_archive, lazy_import
class ADVANCE(NonGeoDataset):
@@ -104,7 +104,10 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
+ DependencyNotFoundError: If scipy is not installed.
"""
+ lazy_import('scipy.io.wavfile')
+
self.root = root
self.transforms = transforms
self.checksum = checksum
@@ -191,14 +194,8 @@ def _load_target(self, path: str) -> Tensor:
Returns:
the target audio
"""
- try:
- from scipy.io import wavfile
- except ImportError:
- raise ImportError(
- 'scipy is not installed and is required to use this dataset'
- )
-
- array = wavfile.read(path, mmap=True)[1]
+ siw = lazy_import('scipy.io.wavfile')
+ array = siw.read(path, mmap=True)[1]
tensor = torch.from_numpy(array)
tensor = tensor.unsqueeze(0)
return tensor
diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py
index 42e3fdad339..905c2d8496e 100644
--- a/torchgeo/datasets/chabud.py
+++ b/torchgeo/datasets/chabud.py
@@ -14,7 +14,7 @@
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
-from .utils import download_url, percentile_normalization
+from .utils import download_url, lazy_import, percentile_normalization
class ChaBuD(NonGeoDataset):
@@ -96,7 +96,10 @@ def __init__(
Raises:
AssertionError: If ``split`` or ``bands`` arguments are invalid.
DatasetNotFoundError: If dataset is not found and *download* is False.
+ DependencyNotFoundError: If h5py is not installed.
"""
+ lazy_import('h5py')
+
assert split in self.folds
assert set(bands) <= set(self.all_bands)
@@ -111,13 +114,6 @@ def __init__(
self._verify()
- try:
- import h5py # noqa: F401
- except ImportError:
- raise ImportError(
- 'h5py is not installed and is required to use this dataset'
- )
-
self.uuids = self._load_uuids()
def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -153,8 +149,7 @@ def _load_uuids(self) -> list[str]:
Returns:
the image uuids
"""
- import h5py
-
+ h5py = lazy_import('h5py')
uuids = []
with h5py.File(self.filepath, 'r') as f:
for k, v in f.items():
@@ -173,8 +168,7 @@ def _load_image(self, index: int) -> Tensor:
Returns:
the image
"""
- import h5py
-
+ h5py = lazy_import('h5py')
uuid = self.uuids[index]
with h5py.File(self.filepath, 'r') as f:
pre_array = f[uuid]['pre_fire'][:]
@@ -199,8 +193,7 @@ def _load_target(self, index: int) -> Tensor:
Returns:
the target mask
"""
- import h5py
-
+ h5py = lazy_import('h5py')
uuid = self.uuids[index]
with h5py.File(self.filepath, 'r') as f:
array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1)
diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py
index 7a41e829971..400b5ceb63c 100644
--- a/torchgeo/datasets/cropharvest.py
+++ b/torchgeo/datasets/cropharvest.py
@@ -17,7 +17,7 @@
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive
+from .utils import download_url, extract_archive, lazy_import
class CropHarvest(NonGeoDataset):
@@ -112,14 +112,9 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
- ImportError: If h5py is not installed
+ DependencyNotFoundError: If h5py is not installed.
"""
- try:
- import h5py # noqa: F401
- except ImportError:
- raise ImportError(
- 'h5py is not installed and is required to use this dataset'
- )
+ lazy_import('h5py')
self.root = root
self.transforms = transforms
@@ -210,8 +205,7 @@ def _load_array(self, path: str) -> Tensor:
Returns:
the image
"""
- import h5py
-
+ h5py = lazy_import('h5py')
filename = os.path.join(path)
with h5py.File(filename, 'r') as f:
array = f.get('array')[()]
diff --git a/torchgeo/datasets/errors.py b/torchgeo/datasets/errors.py
index c78a7819227..4b8d7a75982 100644
--- a/torchgeo/datasets/errors.py
+++ b/torchgeo/datasets/errors.py
@@ -49,6 +49,32 @@ def __init__(self, dataset: Dataset[object]) -> None:
super().__init__(msg)
+class DependencyNotFoundError(ModuleNotFoundError):
+ """Raised when an optional dataset dependency is not installed.
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(self, name: str) -> None:
+ """Initialize a new DependencyNotFoundError instance.
+
+ Args:
+ name: Name of missing dependency.
+ """
+ msg = f"""\
+{name} is not installed and is required to use this dataset. Either run:
+
+$ pip install {name}
+
+to install just this dependency, or:
+
+$ pip install torchgeo[datasets]
+
+to install all optional dataset dependencies."""
+
+ super().__init__(msg)
+
+
class RGBBandsMissingError(ValueError):
"""Raised when a dataset is missing RGB bands for plotting.
diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py
index b36f2a9db9c..6ddc0c9661b 100644
--- a/torchgeo/datasets/idtrees.py
+++ b/torchgeo/datasets/idtrees.py
@@ -22,7 +22,7 @@
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive
+from .utils import download_url, extract_archive, lazy_import
class IDTReeS(NonGeoDataset):
@@ -92,6 +92,11 @@ class IDTReeS(NonGeoDataset):
* https://doi.org/10.1101/2021.08.06.453503
+ This dataset requires the following additional libraries to be installed:
+
+       * `laspy <https://pypi.org/project/laspy/>`_ to read lidar point clouds
+       * `pyvista <https://pypi.org/project/pyvista/>`_ to plot lidar point clouds
+
.. versionadded:: 0.2
"""
@@ -167,11 +172,14 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- ImportError: if laspy is not installed
DatasetNotFoundError: If dataset is not found and *download* is False.
+ DependencyNotFoundError: If laspy is not installed.
"""
+ lazy_import('laspy')
+
assert split in ['train', 'test']
assert task in ['task1', 'task2']
+
self.root = root
self.split = split
self.task = task
@@ -182,14 +190,6 @@ def __init__(
self.idx2class = {i: c for i, c in enumerate(self.classes)}
self.num_classes = len(self.classes)
self._verify()
-
- try:
- import laspy # noqa: F401
- except ImportError:
- raise ImportError(
- 'laspy is not installed and is required to use this dataset'
- )
-
self.images, self.geometries, self.labels = self._load(root)
def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -263,8 +263,7 @@ def _load_las(self, path: str) -> Tensor:
Returns:
the point cloud
"""
- import laspy
-
+ laspy = lazy_import('laspy')
las = laspy.read(path)
array: 'np.typing.NDArray[np.int_]' = np.stack([las.x, las.y, las.z], axis=0)
tensor = torch.from_numpy(array)
@@ -561,19 +560,13 @@ def plot_las(self, index: int) -> 'pyvista.Plotter': # type: ignore[name-define
pyvista.PolyData object. Run pyvista.plot(point_cloud, ...) to display
Raises:
- ImportError: if pyvista is not installed
+            DependencyNotFoundError: If laspy or pyvista is not installed.
.. versionchanged:: 0.4
Ported from Open3D to PyVista, *colormap* parameter removed.
"""
- try:
- import pyvista # noqa: F401
- except ImportError:
- raise ImportError(
- 'pyvista is not installed and is required to plot point clouds'
- )
- import laspy
-
+ laspy = lazy_import('laspy')
+ pyvista = lazy_import('pyvista')
path = self.images[index]
path = path.replace('RGB', 'LAS').replace('.tif', '.las')
las = laspy.read(path)
diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py
index 7ca10f02718..ce5d9a3bd2c 100644
--- a/torchgeo/datasets/quakeset.py
+++ b/torchgeo/datasets/quakeset.py
@@ -15,7 +15,7 @@
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
-from .utils import download_url, percentile_normalization
+from .utils import download_url, lazy_import, percentile_normalization
class QuakeSet(NonGeoDataset):
@@ -85,8 +85,10 @@ def __init__(
Raises:
AssertionError: If ``split`` argument is invalid.
DatasetNotFoundError: If dataset is not found and *download* is False.
- ImportError: if h5py is not installed
+ DependencyNotFoundError: If h5py is not installed.
"""
+ lazy_import('h5py')
+
assert split in self.splits
self.root = root
@@ -95,16 +97,7 @@ def __init__(
self.download = download
self.checksum = checksum
self.filepath = os.path.join(root, self.filename)
-
self._verify()
-
- try:
- import h5py # noqa: F401
- except ImportError:
- raise ImportError(
- 'h5py is not installed and is required to use this dataset'
- )
-
self.data = self._load_data()
def __getitem__(self, index: int) -> dict[str, Tensor]:
@@ -141,8 +134,7 @@ def _load_data(self) -> list[dict[str, Any]]:
Returns:
the sample keys, patches, images, labels, and magnitudes
"""
- import h5py
-
+ h5py = lazy_import('h5py')
data = []
with h5py.File(self.filepath) as f:
for k in sorted(f.keys()):
@@ -185,7 +177,7 @@ def _load_image(self, index: int) -> Tensor:
Returns:
the image
"""
- import h5py
+ h5py = lazy_import('h5py')
key = self.data[index]['key']
patch = self.data[index]['patch']
diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py
index 376e4a1ce47..0d111ae15b9 100644
--- a/torchgeo/datasets/skippd.py
+++ b/torchgeo/datasets/skippd.py
@@ -16,7 +16,7 @@
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive
+from .utils import download_url, extract_archive, lazy_import
class SKIPPD(NonGeoDataset):
@@ -53,6 +53,12 @@ class SKIPPD(NonGeoDataset):
* https://doi.org/10.48550/arXiv.2207.00913
+ .. note::
+
+ This dataset requires the following additional library to be installed:
+
+       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
+
.. versionadded:: 0.5
"""
@@ -94,8 +100,10 @@ def __init__(
Raises:
AssertionError: if ``task`` or ``split`` is invalid
DatasetNotFoundError: If dataset is not found and *download* is False.
- ImportError: if h5py is not installed
+ DependencyNotFoundError: If h5py is not installed.
"""
+ lazy_import('h5py')
+
assert (
split in self.valid_splits
), f'Please choose one of these valid data splits {self.valid_splits}.'
@@ -110,14 +118,6 @@ def __init__(
self.transforms = transforms
self.download = download
self.checksum = checksum
-
- try:
- import h5py # noqa: F401
- except ImportError:
- raise ImportError(
- 'h5py is not installed and is required to use this dataset'
- )
-
self._verify()
def __len__(self) -> int:
@@ -126,8 +126,7 @@ def __len__(self) -> int:
Returns:
length of the dataset
"""
- import h5py
-
+ h5py = lazy_import('h5py')
with h5py.File(
os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
) as f:
@@ -161,8 +160,7 @@ def _load_image(self, index: int) -> Tensor:
Returns:
image tensor at index
"""
- import h5py
-
+ h5py = lazy_import('h5py')
with h5py.File(
os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
) as f:
@@ -187,8 +185,7 @@ def _load_features(self, index: int) -> dict[str, str | Tensor]:
Returns:
label tensor at index
"""
- import h5py
-
+ h5py = lazy_import('h5py')
with h5py.File(
os.path.join(self.root, self.data_file_name.format(self.task)), 'r'
) as f:
diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py
index 77281ce3e7b..ee6f1435474 100644
--- a/torchgeo/datasets/so2sat.py
+++ b/torchgeo/datasets/so2sat.py
@@ -15,7 +15,7 @@
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
-from .utils import check_integrity, percentile_normalization
+from .utils import check_integrity, lazy_import, percentile_normalization
class So2Sat(NonGeoDataset):
@@ -97,6 +97,12 @@ class So2Sat(NonGeoDataset):
done
or manually downloaded from https://mediatum.ub.tum.de/1613658
+
+ .. note::
+
+ This dataset requires the following additional library to be installed:
+
+       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
""" # noqa: E501
versions = ['2', '3_random', '3_block', '3_culture_10']
@@ -210,6 +216,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
DatasetNotFoundError: If dataset is not found.
+ DependencyNotFoundError: If h5py is not installed.
.. versionadded:: 0.3
The *bands* parameter.
@@ -217,12 +224,8 @@ def __init__(
.. versionadded:: 0.5
The *version* parameter.
"""
- try:
- import h5py # noqa: F401
- except ImportError:
- raise ImportError(
- 'h5py is not installed and is required to use this dataset'
- )
+ h5py = lazy_import('h5py')
+
assert version in self.versions
assert split in self.filenames_by_version[version]
@@ -272,8 +275,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
Returns:
data and label at that index
"""
- import h5py
-
+ h5py = lazy_import('h5py')
with h5py.File(self.fn, 'r') as f:
             s1 = f['sen1'][index].astype(np.float64)  # convert from <f8
             s2 = f['sen2'][index].astype(np.float64)  # convert from <f8
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py
--- a/torchgeo/datasets/utils.py
+++ b/torchgeo/datasets/utils.py
@@ ... @@
 import gzip
+import importlib
 import lzma
@@ ... @@
+from .errors import DependencyNotFoundError
@@ ... @@ class _rarfile:
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.args = args
         self.kwargs = kwargs
def __enter__(self) -> Any:
- try:
- import rarfile
- except ImportError:
- raise ImportError(
- 'rarfile is not installed and is required to extract this dataset'
- )
-
+ rarfile = lazy_import('rarfile')
# TODO: catch exception for when rarfile is installed but not
# unrar/unar/bsdtar
return rarfile.RarFile(*self.args, **self.kwargs)
@@ -157,14 +154,11 @@ def download_radiant_mlhub_dataset(
api_key: the API key to use for all requests from the session. Can also be
passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
``~/.mlhub/profiles``.
- """
- try:
- import radiant_mlhub
- except ImportError:
- raise ImportError(
- 'radiant_mlhub is not installed and is required to download this dataset'
- )
+ Raises:
+ DependencyNotFoundError: If radiant_mlhub is not installed.
+ """
+ radiant_mlhub = lazy_import('radiant_mlhub')
dataset = radiant_mlhub.Dataset.fetch(dataset_id, api_key=api_key)
dataset.download(output_dir=download_root, api_key=api_key)
@@ -180,14 +174,11 @@ def download_radiant_mlhub_collection(
api_key: the API key to use for all requests from the session. Can also be
passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
``~/.mlhub/profiles``.
- """
- try:
- import radiant_mlhub
- except ImportError:
- raise ImportError(
- 'radiant_mlhub is not installed and is required to download this collection'
- )
+ Raises:
+ DependencyNotFoundError: If radiant_mlhub is not installed.
+ """
+ radiant_mlhub = lazy_import('radiant_mlhub')
collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key)
collection.download(output_dir=download_root, api_key=api_key)
@@ -773,3 +764,25 @@ def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor:
elif array.dtype == np.uint32:
array = array.astype(np.int64)
return torch.tensor(array)
+
+
+def lazy_import(name: str) -> Any:
+ """Lazy import of *name*.
+
+ Args:
+ name: Name of module to import.
+
+ Raises:
+ DependencyNotFoundError: If *name* is not installed.
+
+ .. versionadded:: 0.6
+ """
+ try:
+ return importlib.import_module(name)
+ except ModuleNotFoundError:
+ # Map from import name to package name on PyPI
+ name = name.split('.')[0].replace('_', '-')
+ module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name)
+ module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
+ name = module_to_pypi[name]
+ raise DependencyNotFoundError(name) from None
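
A short usage sketch of the new helper: like `importlib.import_module`, a dotted name returns the leaf submodule, and a missing top-level package surfaces as `DependencyNotFoundError` carrying the PyPI-normalized name (`foo_bar.baz` is a deliberately nonexistent module):

    from torchgeo.datasets import DependencyNotFoundError
    from torchgeo.datasets.utils import lazy_import

    # Present module: the submodule object itself is returned.
    abc = lazy_import('collections.abc')
    assert hasattr(abc, 'Mapping')

    # Missing module: ModuleNotFoundError is remapped to the torchgeo
    # error, with 'foo_bar.baz' normalized to the PyPI name 'foo-bar'.
    try:
        lazy_import('foo_bar.baz')
    except DependencyNotFoundError as e:
        assert 'pip install foo-bar' in str(e)
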
diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index 13937336281..7f77b29513c 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -17,7 +17,12 @@
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
-from .utils import check_integrity, download_and_extract_archive, download_url
+from .utils import (
+ check_integrity,
+ download_and_extract_archive,
+ download_url,
+ lazy_import,
+)
def convert_coco_poly_to_mask(
@@ -32,13 +37,15 @@ def convert_coco_poly_to_mask(
Returns:
Tensor: Mask tensor
- """
- from pycocotools import mask as coco_mask # noqa: F401
+ Raises:
+ DependencyNotFoundError: If pycocotools is not installed.
+ """
+ pycocotools = lazy_import('pycocotools')
masks = []
for polygons in segmentations:
- rles = coco_mask.frPyObjects(polygons, height, width)
- mask = coco_mask.decode(rles)
+ rles = pycocotools.mask.frPyObjects(polygons, height, width)
+ mask = pycocotools.mask.decode(rles)
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
@@ -196,8 +203,9 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- ImportError: if ``split="positive"`` and pycocotools is not installed
DatasetNotFoundError: If dataset is not found and *download* is False.
+            DependencyNotFoundError: If ``split="positive"`` and pycocotools is
+                not installed.
"""
assert split in ['positive', 'negative']
@@ -213,20 +221,12 @@ def __init__(
raise DatasetNotFoundError(self)
if split == 'positive':
- # Must be installed to parse annotations file
- try:
- from pycocotools.coco import COCO # noqa: F401
- except ImportError:
- raise ImportError(
- 'pycocotools is not installed and is required to use this dataset'
- )
-
- self.coco = COCO(
+ pc = lazy_import('pycocotools.coco')
+ self.coco = pc.COCO(
os.path.join(
self.root, 'NWPU VHR-10 dataset', self.target_meta['filename']
)
)
-
self.coco_convert = ConvertCocoAnnotations()
self.ids = list(sorted(self.coco.imgs.keys()))
@@ -381,7 +381,7 @@ def plot(
Raises:
AssertionError: if ``show_feats`` argument is invalid
- ImportError: if plotting masks and scikit-image is not installed
+ DependencyNotFoundError: If plotting masks and scikit-image is not installed.
.. versionadded:: 0.4
"""
@@ -397,12 +397,7 @@ def plot(
return fig
if show_feats != 'boxes':
- try:
- from skimage.measure import find_contours # noqa: F401
- except ImportError:
- raise ImportError(
- 'scikit-image is not installed and is required to plot masks.'
- )
+ skimage = lazy_import('skimage')
image = sample['image'].permute(1, 2, 0).numpy()
boxes = sample['boxes'].cpu().numpy()
@@ -465,7 +460,7 @@ def plot(
# Add masks
if show_feats in {'masks', 'both'} and 'masks' in sample:
mask = masks[i]
- contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
+ contours = skimage.measure.find_contours(mask, 0.5)
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(
@@ -517,7 +512,7 @@ def plot(
# Add masks
if show_pred_masks:
mask = prediction_masks[i]
- contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
+ contours = skimage.measure.find_contours(mask, 0.5)
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(
diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py
index aeb972898c2..e1a5f4d2870 100644
--- a/torchgeo/datasets/zuericrop.py
+++ b/torchgeo/datasets/zuericrop.py
@@ -13,7 +13,7 @@
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
-from .utils import download_url, percentile_normalization
+from .utils import download_url, lazy_import, percentile_normalization
class ZueriCrop(NonGeoDataset):
@@ -82,7 +82,10 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
+ DependencyNotFoundError: If h5py is not installed.
"""
+ lazy_import('h5py')
+
self._validate_bands(bands)
self.band_indices = torch.tensor(
[self.band_names.index(b) for b in bands]
@@ -97,13 +100,6 @@ def __init__(
self._verify()
- try:
- import h5py # noqa: F401
- except ImportError:
- raise ImportError(
- 'h5py is not installed and is required to use this dataset'
- )
-
def __getitem__(self, index: int) -> dict[str, Tensor]:
"""Return an index within the dataset.
@@ -129,8 +125,7 @@ def __len__(self) -> int:
Returns:
length of the dataset
"""
- import h5py
-
+ h5py = lazy_import('h5py')
with h5py.File(self.filepath, 'r') as f:
length: int = f['data'].shape[0]
return length
@@ -144,8 +139,7 @@ def _load_image(self, index: int) -> Tensor:
Returns:
the image
"""
- import h5py
-
+ h5py = lazy_import('h5py')
with h5py.File(self.filepath, 'r') as f:
array = f['data'][index, ...]