diff --git a/tests/datasets/azcopy b/tests/datasets/azcopy
new file mode 100755
index 00000000000..1f74b4c4d0b
--- /dev/null
+++ b/tests/datasets/azcopy
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Basic mock-up of the azcopy CLI."""
+
+import argparse
+import shutil
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    subparsers = parser.add_subparsers()
+    copy = subparsers.add_parser('copy')
+    copy.add_argument('source')
+    copy.add_argument('destination')
+    copy.add_argument('--recursive', default='false')
+    sync = subparsers.add_parser('sync')
+    sync.add_argument('source')
+    sync.add_argument('destination')
+    sync.add_argument('--recursive', default='true')
+    args, _ = parser.parse_known_args()
+
+    if args.recursive == 'true':
+        shutil.copytree(args.source, args.destination, dirs_exist_ok=True)
+    else:
+        shutil.copy(args.source, args.destination)
diff --git a/tests/datasets/conftest.py b/tests/datasets/conftest.py
new file mode 100644
index 00000000000..3f59d69581b
--- /dev/null
+++ b/tests/datasets/conftest.py
@@ -0,0 +1,16 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import pytest
+from pytest import MonkeyPatch
+
+from torchgeo.datasets.utils import Executable, which
+
+
+@pytest.fixture
+def azcopy(monkeypatch: MonkeyPatch) -> Executable:
+    path = os.path.dirname(os.path.realpath(__file__))
+    monkeypatch.setenv('PATH', path, prepend=os.pathsep)
+    return which('azcopy')
diff --git a/tests/datasets/test_errors.py b/tests/datasets/test_errors.py
index f87ab6f03a3..a32c00dbf02 100644
--- a/tests/datasets/test_errors.py
+++ b/tests/datasets/test_errors.py
@@ -59,9 +59,10 @@ def test_paths_download(self) -> None:
             raise DatasetNotFoundError(ds)
 
 
-def test_missing_dependency() -> None:
-    with pytest.raises(DependencyNotFoundError, match='pip install foo'):
-        raise DependencyNotFoundError('foo')
+def test_dependency_not_found() -> None:
+    msg = 'foo not installed'
+    with pytest.raises(DependencyNotFoundError, match=msg):
+        raise DependencyNotFoundError(msg)
 
 
 def test_rgb_bands_missing() -> None:
diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py
index d8e44885418..be8ddf667b8 100644
--- a/tests/datasets/test_utils.py
+++ b/tests/datasets/test_utils.py
@@ -21,6 +21,7 @@
 import torchgeo.datasets.utils
 from torchgeo.datasets import BoundingBox, DependencyNotFoundError
 from torchgeo.datasets.utils import (
+    Executable,
     array_to_tensor,
     concat_samples,
     disambiguate_timestamp,
@@ -33,6 +34,7 @@
     percentile_normalization,
     stack_samples,
     unbind_samples,
+    which,
     working_dir,
 )
 
@@ -590,3 +592,14 @@ def test_lazy_import(name: str) -> None:
 def test_lazy_import_missing(name: str) -> None:
     with pytest.raises(DependencyNotFoundError, match='pip install foo-bar\n'):
         lazy_import(name)
+
+
+def test_azcopy(tmp_path: Path, azcopy: Executable) -> None:
+    source = os.path.join('tests', 'data', 'cyclone')
+    azcopy('sync', source, tmp_path, '--recursive=true')
+    assert os.path.exists(tmp_path / 'nasa_tropical_storm_competition_test_labels')
+
+
+def test_which() -> None:
+    with pytest.raises(DependencyNotFoundError, match='foo is not installed'):
+        which('foo')
diff --git a/torchgeo/datasets/errors.py b/torchgeo/datasets/errors.py
index 4b8d7a75982..e2124c1b1d2 100644
--- a/torchgeo/datasets/errors.py
+++ b/torchgeo/datasets/errors.py
@@ -49,31 +49,12 @@ def __init__(self, dataset: Dataset[object]) -> None:
         super().__init__(msg)
 
 
-class DependencyNotFoundError(ModuleNotFoundError):
+class DependencyNotFoundError(Exception):
     """Raised when an optional dataset dependency is not installed.
 
     .. versionadded:: 0.6
     """
 
-    def __init__(self, name: str) -> None:
-        """Initialize a new DependencyNotFoundError instance.
-
-        Args:
-            name: Name of missing dependency.
-        """
-        msg = f"""\
-{name} is not installed and is required to use this dataset. Either run:
-
-$ pip install {name}
-
-to install just this dependency, or:
-
-$ pip install torchgeo[datasets]
-
-to install all optional dataset dependencies."""
-
-        super().__init__(msg)
-
 
 class RGBBandsMissingError(ValueError):
     """Raised when a dataset is missing RGB bands for plotting.
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py
index 3e27d7fc6a7..12ed2aad40b 100644
--- a/torchgeo/datasets/utils.py
+++ b/torchgeo/datasets/utils.py
@@ -13,6 +13,8 @@
 import importlib
 import lzma
 import os
+import shutil
+import subprocess
 import sys
 import tarfile
 from collections.abc import Iterable, Iterator, Sequence
@@ -402,6 +404,34 @@ def split(
         return bbox1, bbox2
 
 
+class Executable:
+    """Command-line executable.
+
+    .. versionadded:: 0.6
+    """
+
+    def __init__(self, name: str) -> None:
+        """Initialize a new Executable instance.
+
+        Args:
+            name: Command name.
+        """
+        self.name = name
+
+    def __call__(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess[bytes]:
+        """Run the command.
+
+        Args:
+            args: Arguments to pass to the command.
+            kwargs: Keyword arguments to pass to :func:`subprocess.run`.
+
+        Returns:
+            The completed process.
+        """
+        kwargs['check'] = True
+        return subprocess.run((self.name,) + args, **kwargs)
+
+
 def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
     """Disambiguate partial timestamps.
 
@@ -772,6 +802,9 @@ def lazy_import(name: str) -> Any:
     Args:
         name: Name of module to import.
 
+    Returns:
+        Module import.
+
     Raises:
         DependencyNotFoundError: If *name* is not installed.
 
@@ -785,4 +818,35 @@ def lazy_import(name: str) -> Any:
         module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name)
         module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
         name = module_to_pypi[name]
-        raise DependencyNotFoundError(name) from None
+        msg = f"""\
+{name} is not installed and is required to use this dataset. Either run:
+
+$ pip install {name}
+
+to install just this dependency, or:
+
+$ pip install torchgeo[datasets]
+
+to install all optional dataset dependencies."""
+        raise DependencyNotFoundError(msg) from None
+
+
+def which(name: str) -> Executable:
+    """Search for executable *name*.
+
+    Args:
+        name: Name of executable to search for.
+
+    Returns:
+        Callable executable instance.
+
+    Raises:
+        DependencyNotFoundError: If *name* is not installed.
+
+    .. versionadded:: 0.6
+    """
+    if shutil.which(name):
+        return Executable(name)
+    else:
+        msg = f'{name} is not installed and is required to use this dataset.'
+        raise DependencyNotFoundError(msg) from None