Skip to content

Commit

Permalink
Datasets: add CLI support (microsoft#2064)
Browse files Browse the repository at this point in the history
* Datasets: add CLI support

* Return completed process

* Fix return type

* More powerful azcopy mock
  • Loading branch information
adamjstewart authored May 23, 2024
1 parent 39cc9b6 commit 91bbb83
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 24 deletions.
27 changes: 27 additions & 0 deletions tests/datasets/azcopy
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Basic mock-up of the azcopy CLI."""

import argparse
import shutil

if __name__ == '__main__':
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()
copy = subparsers.add_parser('copy')
copy.add_argument('source')
copy.add_argument('destination')
copy.add_argument('--recursive', default='false')
sync = subparsers.add_parser('sync')
sync.add_argument('source')
sync.add_argument('destination')
sync.add_argument('--recursive', default='true')
args, _ = parser.parse_known_args()

if args.recursive == 'true':
shutil.copytree(args.source, args.destination, dirs_exist_ok=True)
else:
shutil.copy(args.source, args.destination)
16 changes: 16 additions & 0 deletions tests/datasets/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import pytest
from pytest import MonkeyPatch

from torchgeo.datasets.utils import Executable, which


@pytest.fixture
def azcopy(monkeypatch: MonkeyPatch) -> Executable:
path = os.path.dirname(os.path.realpath(__file__))
monkeypatch.setenv('PATH', path, prepend=os.pathsep)
return which('azcopy')
7 changes: 4 additions & 3 deletions tests/datasets/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ def test_paths_download(self) -> None:
raise DatasetNotFoundError(ds)


def test_missing_dependency() -> None:
with pytest.raises(DependencyNotFoundError, match='pip install foo'):
raise DependencyNotFoundError('foo')
def test_dependency_not_found() -> None:
msg = 'foo not installed'
with pytest.raises(DependencyNotFoundError, match=msg):
raise DependencyNotFoundError(msg)


def test_rgb_bands_missing() -> None:
Expand Down
13 changes: 13 additions & 0 deletions tests/datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, DependencyNotFoundError
from torchgeo.datasets.utils import (
Executable,
array_to_tensor,
concat_samples,
disambiguate_timestamp,
Expand All @@ -33,6 +34,7 @@
percentile_normalization,
stack_samples,
unbind_samples,
which,
working_dir,
)

Expand Down Expand Up @@ -590,3 +592,14 @@ def test_lazy_import(name: str) -> None:
def test_lazy_import_missing(name: str) -> None:
with pytest.raises(DependencyNotFoundError, match='pip install foo-bar\n'):
lazy_import(name)


def test_azcopy(tmp_path: Path, azcopy: Executable) -> None:
source = os.path.join('tests', 'data', 'cyclone')
azcopy('sync', source, tmp_path, '--recursive=true')
assert os.path.exists(tmp_path / 'nasa_tropical_storm_competition_test_labels')


def test_which() -> None:
with pytest.raises(DependencyNotFoundError, match='foo is not installed'):
which('foo')
21 changes: 1 addition & 20 deletions torchgeo/datasets/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,31 +49,12 @@ def __init__(self, dataset: Dataset[object]) -> None:
super().__init__(msg)


class DependencyNotFoundError(ModuleNotFoundError):
class DependencyNotFoundError(Exception):
"""Raised when an optional dataset dependency is not installed.
.. versionadded:: 0.6
"""

def __init__(self, name: str) -> None:
"""Initialize a new DependencyNotFoundError instance.
Args:
name: Name of missing dependency.
"""
msg = f"""\
{name} is not installed and is required to use this dataset. Either run:
$ pip install {name}
to install just this dependency, or:
$ pip install torchgeo[datasets]
to install all optional dataset dependencies."""

super().__init__(msg)


class RGBBandsMissingError(ValueError):
"""Raised when a dataset is missing RGB bands for plotting.
Expand Down
66 changes: 65 additions & 1 deletion torchgeo/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import importlib
import lzma
import os
import shutil
import subprocess
import sys
import tarfile
from collections.abc import Iterable, Iterator, Sequence
Expand Down Expand Up @@ -402,6 +404,34 @@ def split(
return bbox1, bbox2


class Executable:
"""Command-line executable.
.. versionadded:: 0.6
"""

def __init__(self, name: str) -> None:
"""Initialize a new Executable instance.
Args:
name: Command name.
"""
self.name = name

def __call__(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess[bytes]:
"""Run the command.
Args:
args: Arguments to pass to the command.
kwargs: Keyword arguments to pass to :func:`subprocess.run`.
Returns:
The completed process.
"""
kwargs['check'] = True
return subprocess.run((self.name,) + args, **kwargs)


def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]:
"""Disambiguate partial timestamps.
Expand Down Expand Up @@ -772,6 +802,9 @@ def lazy_import(name: str) -> Any:
Args:
name: Name of module to import.
Returns:
Module import.
Raises:
DependencyNotFoundError: If *name* is not installed.
Expand All @@ -785,4 +818,35 @@ def lazy_import(name: str) -> Any:
module_to_pypi: dict[str, str] = collections.defaultdict(lambda: name)
module_to_pypi |= {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
name = module_to_pypi[name]
raise DependencyNotFoundError(name) from None
msg = f"""\
{name} is not installed and is required to use this dataset. Either run:
$ pip install {name}
to install just this dependency, or:
$ pip install torchgeo[datasets]
to install all optional dataset dependencies."""
raise DependencyNotFoundError(msg) from None


def which(name: str) -> Executable:
"""Search for executable *name*.
Args:
name: Name of executable to search for.
Returns:
Callable executable instance.
Raises:
DependencyNotFoundError: If *name* is not installed.
.. versionadded:: 0.6
"""
if shutil.which(name):
return Executable(name)
else:
msg = f'{name} is not installed and is required to use this dataset.'
raise DependencyNotFoundError(msg) from None

0 comments on commit 91bbb83

Please sign in to comment.