Skip to content

Commit

Permalink
Datasets: improve lazy import error msg for missing deps (microsoft#2054
Browse files Browse the repository at this point in the history
)

* Datasets: improve lazy import error msg for missing deps

* Add type annotation

* Use lazy imports throughout datasets

* Fix support for older scipy

* Fix support for older scipy

* CI: test optional datasets on every commit

* Update minversion and fix tests

* Double quotes preferred over single quotes

* Undo for now

* Fast-fail during dataset initialization

* Remove extraneous space

* MissingDependencyError -> DependencyNotFoundError
  • Loading branch information
adamjstewart authored May 15, 2024
1 parent ac16f49 commit 189dabd
Show file tree
Hide file tree
Showing 33 changed files with 235 additions and 470 deletions.
28 changes: 0 additions & 28 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,6 @@ on:
branches:
- release**
jobs:
datasets:
name: datasets
runs-on: ubuntu-latest
steps:
- name: Clone repo
uses: actions/[email protected]
- name: Set up python
uses: actions/[email protected]
with:
python-version: "3.12"
- name: Cache dependencies
uses: actions/[email protected]
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-datasets
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install .[tests]
pip cache purge
- name: List pip dependencies
run: pip list
- name: Run pytest checks
run: |
pytest --cov=torchgeo --cov-report=xml --durations=10
python -m torchgeo --help
torchgeo --help
integration:
name: integration
runs-on: ubuntu-latest
Expand Down
33 changes: 33 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,39 @@ jobs:
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}
datasets:
name: datasets
runs-on: ubuntu-latest
env:
MPLBACKEND: Agg
steps:
- name: Clone repo
uses: actions/[email protected]
- name: Set up python
uses: actions/[email protected]
with:
python-version: "3.12"
- name: Cache dependencies
uses: actions/[email protected]
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/tests.txt') }}
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/required.txt -r requirements/tests.txt
pip cache purge
- name: List pip dependencies
run: pip list
- name: Run pytest checks
run: |
pytest --cov=torchgeo --cov-report=xml --durations=10
python3 -m torchgeo --help
- name: Report coverage
uses: codecov/[email protected]
with:
token: ${{ secrets.CODECOV_TOKEN }}
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
cancel-in-progress: true
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -535,4 +535,5 @@ Errors
------

.. autoclass:: DatasetNotFoundError
.. autoclass:: DependencyNotFoundError
.. autoclass:: RGBBandsMissingError
1 change: 0 additions & 1 deletion tests/datamodules/test_usavars.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
class TestUSAVarsDataModule:
@pytest.fixture
def datamodule(self, request: SubRequest) -> USAVarsDataModule:
pytest.importorskip('pandas', minversion='1.1.3')
root = os.path.join('tests', 'data', 'usavars')
batch_size = 1
num_workers = 0
Expand Down
26 changes: 2 additions & 24 deletions tests/datasets/test_advance.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
Expand All @@ -16,6 +14,8 @@
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE, DatasetNotFoundError

pytest.importorskip('scipy', minversion='1.7.2')


def download_url(url: str, root: str, *args: str) -> None:
shutil.copy(url, root)
Expand All @@ -37,19 +37,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE:
transforms = nn.Identity()
return ADVANCE(root, transforms, download=True, checksum=True)

@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'scipy.io':
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, '__import__', mocked_import)

def test_getitem(self, dataset: ADVANCE) -> None:
pytest.importorskip('scipy', minversion='1.6.2')
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
Expand All @@ -71,17 +59,7 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ADVANCE(str(tmp_path))

def test_mock_missing_module(
self, dataset: ADVANCE, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='scipy is not installed and is required to use this dataset',
):
dataset[0]

def test_plot(self, dataset: ADVANCE) -> None:
pytest.importorskip('scipy', minversion='1.6.2')
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()
Expand Down
24 changes: 1 addition & 23 deletions tests/datasets/test_chabud.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
Expand All @@ -17,7 +15,7 @@
import torchgeo.datasets.utils
from torchgeo.datasets import ChaBuD, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')


def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
Expand Down Expand Up @@ -47,17 +45,6 @@ def dataset(
checksum=True,
)

@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, '__import__', mocked_import)

def test_getitem(self, dataset: ChaBuD) -> None:
x = dataset[0]
assert isinstance(x, dict)
Expand Down Expand Up @@ -85,15 +72,6 @@ def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ChaBuD(str(tmp_path))

def test_mock_missing_module(
self, dataset: ChaBuD, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
ChaBuD(dataset.root, download=True, checksum=True)

def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
ChaBuD(bands=['OK', 'BK'])
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
UnionDataset,
)

pytest.importorskip('zipfile_deflate64')


def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)


class TestChesapeake13:
pytest.importorskip('zipfile_deflate64')

@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13:
monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url)
Expand Down
24 changes: 1 addition & 23 deletions tests/datasets/test_cropharvest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
Expand All @@ -16,25 +14,14 @@
import torchgeo.datasets.utils
from torchgeo.datasets import CropHarvest, DatasetNotFoundError

pytest.importorskip('h5py', minversion='3')
pytest.importorskip('h5py', minversion='3.6')


def download_url(url: str, root: str, filename: str, md5: str) -> None:
shutil.copy(url, os.path.join(root, filename))


class TestCropHarvest:
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == 'h5py':
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, '__import__', mocked_import)

@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest:
monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url)
Expand Down Expand Up @@ -89,12 +76,3 @@ def test_plot(self, dataset: CropHarvest) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
plt.close()

def test_mock_missing_module(
self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None
) -> None:
with pytest.raises(
ImportError,
match='h5py is not installed and is required to use this dataset',
):
CropHarvest(root=str(tmp_path), download=True)[0]
11 changes: 10 additions & 1 deletion tests/datasets/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
import pytest
from torch.utils.data import Dataset

from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError
from torchgeo.datasets import (
DatasetNotFoundError,
DependencyNotFoundError,
RGBBandsMissingError,
)


class TestDatasetNotFoundError:
Expand Down Expand Up @@ -55,6 +59,11 @@ def test_paths_download(self) -> None:
raise DatasetNotFoundError(ds)


def test_missing_dependency() -> None:
with pytest.raises(DependencyNotFoundError, match='pip install foo'):
raise DependencyNotFoundError('foo')


def test_rgb_bands_missing() -> None:
match = 'Dataset does not contain some of the RGB bands'
with pytest.raises(RGBBandsMissingError, match=match):
Expand Down
33 changes: 0 additions & 33 deletions tests/datasets/test_idtrees.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import builtins
import glob
import os
import shutil
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import pytest
Expand Down Expand Up @@ -50,19 +48,6 @@ def dataset(
transforms = nn.Identity()
return IDTReeS(root, split, task, transforms, download=True, checksum=True)

@pytest.fixture(params=['laspy', 'pyvista'])
def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
import_orig = builtins.__import__
package = str(request.param)

def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
if name == package:
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, '__import__', mocked_import)
return package

def test_getitem(self, dataset: IDTReeS) -> None:
x = dataset[0]
assert isinstance(x, dict)
Expand Down Expand Up @@ -101,24 +86,6 @@ def test_not_extracted(self, tmp_path: Path) -> None:
shutil.copy(zipfile, root)
IDTReeS(root)

def test_mock_missing_module(
self, dataset: IDTReeS, mock_missing_module: str
) -> None:
package = mock_missing_module

if package == 'laspy':
with pytest.raises(
ImportError,
match=f'{package} is not installed and is required to use this dataset',
):
IDTReeS(dataset.root, dataset.split, dataset.task)
elif package == 'pyvista':
with pytest.raises(
ImportError,
match=f'{package} is not installed and is required to plot point cloud',
):
dataset.plot_las(0)

def test_plot(self, dataset: IDTReeS) -> None:
x = dataset[0].copy()
dataset.plot(x, suptitle='Test')
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ def test_plot(self, dataset: LandCoverAIGeo) -> None:


class TestLandCoverAI:
pytest.importorskip('cv2', minversion='4.5.4')

@pytest.fixture(params=['train', 'val', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> LandCoverAI:
pytest.importorskip('cv2', minversion='4.4.0')
monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url)
md5 = 'ff8998857cc8511f644d3f7d0f3688d0'
monkeypatch.setattr(LandCoverAI, 'md5', md5)
Expand Down Expand Up @@ -111,7 +112,6 @@ def test_already_extracted(self, dataset: LandCoverAI) -> None:
LandCoverAI(root=dataset.root, download=True)

def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
pytest.importorskip('cv2', minversion='4.4.0')
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
Expand Down
Loading

0 comments on commit 189dabd

Please sign in to comment.