diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
index 78d364d9dde..ceb43f081f1 100644
--- a/.github/ISSUE_TEMPLATE/config.yml
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -3,3 +3,6 @@ contact_links:
- name: ❓ Questions
url: https://github.com/microsoft/torchgeo/discussions
about: Ask questions or discuss ideas with other TorchGeo users
+ - name: 💬 Chat
+ url: https://join.slack.com/t/torchgeo/shared_invite/zt-22rse667m-eqtCeNW0yI000Tl4B~2PIw
+ about: Chat with fellow TorchGeo users and developers on Slack
diff --git a/.github/labeler.yml b/.github/labeler.yml
index 683f7a9047b..a82ac823431 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -16,13 +16,13 @@ transforms:
# Other
dependencies:
-- "environment.yml"
-- "setup.cfg"
+- "pyproject.toml"
- "requirements/**"
documentation:
- "docs/**"
scripts:
-- "*.py"
+- "torchgeo/__main__.py"
+- "torchgeo/main.py"
- "experiments/**"
testing:
- "tests/**"
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 2de22e69dca..25e630508b1 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -12,70 +12,82 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
- key: ${{ env.pythonLocation }}-${{ hashFiles('setup.cfg') }}
+ key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-datasets
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install .[tests]
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run pytest checks
- run: pytest --cov=torchgeo --cov-report=xml --durations=10
+ run: |
+ pytest --cov=torchgeo --cov-report=xml --durations=10
+ python -m torchgeo --help
+ torchgeo --help
integration:
name: integration
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
- key: ${{ env.pythonLocation }}-${{ hashFiles('setup.cfg') }}
+ key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-integration
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install .[datasets,tests]
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run integration checks
- run: pytest -m slow --durations=10
+ run: |
+ pytest -m slow --durations=10
+ python -m torchgeo --help
+ torchgeo --help
notebooks:
name: notebooks
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
- key: ${{ env.pythonLocation }}-${{ hashFiles('setup.cfg') }}
+ key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-tutorials
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
- pip install .[docs,tests] planetary_computer pystac pytest-rerunfailures
- pip list
+ pip install .[docs,tests] planetary_computer pystac
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run notebook checks
- run: pytest --nbmake --durations=10 --reruns=10 docs/tutorials
+ run: pytest --nbmake --durations=10 docs/tutorials
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
cancel-in-progress: true
diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml
index 022a5752c9e..5cccfd96032 100644
--- a/.github/workflows/style.yaml
+++ b/.github/workflows/style.yaml
@@ -14,13 +14,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -29,7 +29,9 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/style.txt
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run black checks
run: black . --check --diff
flake8:
@@ -37,13 +39,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -52,7 +54,9 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/style.txt
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run flake8 checks
run: flake8
isort:
@@ -60,13 +64,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -75,7 +79,9 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/style.txt
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run isort checks
run: isort . --check --diff
pydocstyle:
@@ -83,13 +89,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -98,7 +104,9 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/style.txt
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run pydocstyle checks
run: pydocstyle
pyupgrade:
@@ -106,13 +114,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -121,7 +129,9 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/style.txt
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run pyupgrade checks
run: pyupgrade --py39-plus $(find . -path ./docs/src -prune -o -name "*.py" -print)
concurrency:
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index bd898e57253..a078180b0c1 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -14,13 +14,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -29,7 +29,9 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/required.txt -r requirements/datasets.txt -r requirements/tests.txt
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run mypy checks
run: mypy .
pytest:
@@ -43,13 +45,13 @@ jobs:
python-version: ['3.9', '3.10', '3.11']
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: ${{ matrix.python-version }}
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -66,15 +68,19 @@ jobs:
run: brew install rar
if: ${{ runner.os == 'macOS' }}
- name: Install choco dependencies (Windows)
- run: choco install unrar
+ run: choco install 7zip
if: ${{ runner.os == 'Windows' }}
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/required.txt -r requirements/datasets.txt -r requirements/tests.txt
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run pytest checks
- run: pytest --cov=torchgeo --cov-report=xml --durations=10
+ run: |
+ pytest --cov=torchgeo --cov-report=xml --durations=10
+ python3 -m torchgeo --help
- name: Report coverage
uses: codecov/codecov-action@v3.1.4
with:
@@ -86,13 +92,13 @@ jobs:
MPLBACKEND: Agg
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.9'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
@@ -107,9 +113,13 @@ jobs:
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install -r requirements/min-reqs.old -c requirements/min-cons.old
- pip list
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run pytest checks
- run: pytest --cov=torchgeo --cov-report=xml --durations=10
+ run: |
+ pytest --cov=torchgeo --cov-report=xml --durations=10
+ python3 -m torchgeo --help
- name: Report coverage
uses: codecov/codecov-action@v3.1.4
with:
diff --git a/.github/workflows/tutorials.yaml b/.github/workflows/tutorials.yaml
index cdd94735bb2..e0f28b64147 100644
--- a/.github/workflows/tutorials.yaml
+++ b/.github/workflows/tutorials.yaml
@@ -16,24 +16,26 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Clone repo
- uses: actions/checkout@v4.0.0
+ uses: actions/checkout@v4.1.1
- name: Set up python
- uses: actions/setup-python@v4.7.0
+ uses: actions/setup-python@v4.7.1
with:
python-version: '3.11'
- name: Cache dependencies
- uses: actions/cache@v3.3.1
+ uses: actions/cache@v3.3.2
id: cache
with:
path: ${{ env.pythonLocation }}
- key: ${{ env.pythonLocation }}-${{ hashFiles('setup.cfg') }}
+ key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-tutorials
- name: Install pip dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
- pip install .[docs,tests] planetary_computer pystac pytest-rerunfailures
- pip list
+ pip install .[docs,tests] planetary_computer pystac
+ pip cache purge
+ - name: List pip dependencies
+ run: pip list
- name: Run notebook checks
- run: pytest --nbmake --durations=10 --reruns=10 docs/tutorials
+ run: pytest --nbmake --durations=10 docs/tutorials
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }}
cancel-in-progress: true
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ba1822abf78..044afa0738c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/asottile/pyupgrade
- rev: v3.3.1
+ rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py39-plus]
@@ -12,13 +12,13 @@ repos:
additional_dependencies: ['.[colors]']
- repo: https://github.com/psf/black
- rev: 23.1.0
+ rev: 23.10.1
hooks:
- id: black
args: [--skip-magic-trailing-comma]
- repo: https://github.com/pycqa/flake8.git
- rev: 6.0.0
+ rev: 6.1.0
hooks:
- id: flake8
@@ -30,9 +30,9 @@ repos:
additional_dependencies: ['.[toml]']
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.0.1
+ rev: v1.6.1
hooks:
- id: mypy
args: [--strict, --ignore-missing-imports, --show-error-codes]
- additional_dependencies: [torch>=2, torchmetrics>=0.10, lightning>=2.0.3, pytest>=6.1.2, pyvista>=0.29, omegaconf>=2.0.1, hydra-core>=1, kornia>=0.6.5, numpy>=1.22]
+ additional_dependencies: [kornia>=0.6.5, lightning>=2.0.9, matplotlib>=3.8.1, numpy>=1.22, pytest>=6.1.2, pyvista>=0.34.2, torch>=2, torchmetrics>=0.10]
exclude: (build|data|dist|logo|logs|output)/
diff --git a/CITATION.cff b/CITATION.cff
index b2e4ad18ad3..e1d0db6a606 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -1,4 +1,18 @@
# https://github.com/citation-file-format/citation-file-format/blob/main/schema-guide.md
+# Can be validated using `cffconvert --validate`
+authors:
+- family-names: "Stewart"
+ given-names: "Adam J."
+- family-names: "Robinson"
+ given-names: "Caleb"
+- family-names: "Corley"
+ given-names: "Isaac A."
+- family-names: "Ortiz"
+ given-names: "Anthony"
+- family-names: "Lavista Ferres"
+ given-names: "Juan M."
+- family-names: "Banerjee"
+ given-names: "Arindam"
cff-version: "1.2.0"
message: "If you use this software, please cite it using the metadata from this file."
preferred-citation:
@@ -26,9 +40,11 @@ preferred-citation:
isbn: "9781450395298"
month: 11
number: 19
- publisher: "Association for Computing Machinery"
+ publisher:
+ name: "Association for Computing Machinery"
start: 1
title: "TorchGeo: Deep Learning With Geospatial Data"
type: "conference-paper"
url: "https://dl.acm.org/doi/10.1145/3557915.3560953"
year: 2022
+title: "TorchGeo: Deep Learning With Geospatial Data"
diff --git a/README.md b/README.md
index 8a5e59d114e..279a72fbf5f 100644
--- a/README.md
+++ b/README.md
@@ -7,17 +7,20 @@ The goal of this library is to make it simple:
1. for machine learning experts to work with geospatial data, and
2. for remote sensing experts to explore machine learning solutions.
-Testing:
-[![docs](https://readthedocs.org/projects/torchgeo/badge/?version=latest)](https://torchgeo.readthedocs.io/en/stable/)
-[![style](https://github.com/microsoft/torchgeo/actions/workflows/style.yaml/badge.svg)](https://github.com/microsoft/torchgeo/actions/workflows/style.yaml)
-[![tests](https://github.com/microsoft/torchgeo/actions/workflows/tests.yaml/badge.svg)](https://github.com/microsoft/torchgeo/actions/workflows/tests.yaml)
-[![codecov](https://codecov.io/gh/microsoft/torchgeo/branch/main/graph/badge.svg?token=oa3Z3PMVOg)](https://codecov.io/gh/microsoft/torchgeo)
+Community:
+[![slack](https://img.shields.io/badge/slack-join-purple?logo=slack)](https://join.slack.com/t/torchgeo/shared_invite/zt-22rse667m-eqtCeNW0yI000Tl4B~2PIw)
Packaging:
[![pypi](https://badge.fury.io/py/torchgeo.svg)](https://pypi.org/project/torchgeo/)
[![conda](https://anaconda.org/conda-forge/torchgeo/badges/version.svg)](https://anaconda.org/conda-forge/torchgeo)
[![spack](https://img.shields.io/spack/v/py-torchgeo)](https://spack.readthedocs.io/en/latest/package_list.html#py-torchgeo)
+Testing:
+[![docs](https://readthedocs.org/projects/torchgeo/badge/?version=latest)](https://torchgeo.readthedocs.io/en/stable/)
+[![style](https://github.com/microsoft/torchgeo/actions/workflows/style.yaml/badge.svg)](https://github.com/microsoft/torchgeo/actions/workflows/style.yaml)
+[![tests](https://github.com/microsoft/torchgeo/actions/workflows/tests.yaml/badge.svg)](https://github.com/microsoft/torchgeo/actions/workflows/tests.yaml)
+[![codecov](https://codecov.io/gh/microsoft/torchgeo/branch/main/graph/badge.svg?token=oa3Z3PMVOg)](https://codecov.io/gh/microsoft/torchgeo)
+
## Installation
The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/):
@@ -119,6 +122,21 @@ for batch in dataloader:
All TorchGeo datasets are compatible with PyTorch data loaders, making them easy to integrate into existing training workflows. The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each dataset returns a dictionary with keys for each PyTorch `Tensor`.
+### Pre-trained Weights
+
+Pre-trained weights have proven to be tremendously beneficial for transfer learning tasks in computer vision. Practitioners usually utilize models pre-trained on the ImageNet dataset, containing RGB images. However, remote sensing data often goes beyond RGB with additional multispectral channels that can vary across sensors. TorchGeo is the first library to support models pre-trained on different multispectral sensors, and adopts torchvision's [multi-weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). A summary of currently available weights can be seen in the [docs](https://torchgeo.readthedocs.io/en/stable/api/models.html#pretrained-weights). To create a [timm](https://github.com/huggingface/pytorch-image-models) Resnet-18 model with weights that have been pretrained on Sentinel-2 imagery, you can do the following:
+
+```python
+import timm
+from torchgeo.models import ResNet18_Weights
+
+weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
+model = timm.create_model("resnet18", in_chans=weights.meta["in_chans"], num_classes=10)
+model = model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+```
+
+These weights can also directly be used in TorchGeo Lightning modules that are shown in the following section via the `weights` argument. For a notebook example, see this [tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).
+
### Reproducibility with Lightning
In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created Lightning [*datamodules*](https://torchgeo.readthedocs.io/en/stable/api/datamodules.html) with well-defined train-val-test splits and [*trainers*](https://torchgeo.readthedocs.io/en/stable/api/trainers.html) for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the [Inria Aerial Image Labeling](https://project.inria.fr/aerialimagelabeling/) dataset is as easy as a few imports and four lines of code.
@@ -133,8 +151,8 @@ task = SemanticSegmentationTask(
num_classes=2,
loss="ce",
ignore_index=None,
- learning_rate=0.1,
- learning_rate_schedule_patience=6,
+ lr=0.1,
+ patience=6,
)
trainer = Trainer(default_root_dir="...")
@@ -143,12 +161,66 @@ trainer.fit(model=task, datamodule=datamodule)
-In our GitHub repo, we provide `train.py` and `evaluate.py` scripts to train and evaluate the performance of models using these datamodules and trainers. These scripts are configurable via the command line and/or via YAML configuration files. See the [conf](https://github.com/microsoft/torchgeo/blob/main/conf) directory for example configuration files that can be customized for different training runs.
+TorchGeo also supports command-line interface training using [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html). It can be invoked in two ways:
+
+```console
+# If torchgeo has been installed
+torchgeo
+# If torchgeo has been installed, or if it has been cloned to the current directory
+python3 -m torchgeo
+```
+
+It supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages:
+
+```console
+# See valid stages
+torchgeo --help
+# See valid trainer options
+torchgeo fit --help
+# See valid model options
+torchgeo fit --model.help ClassificationTask
+# See valid data options
+torchgeo fit --data.help EuroSAT100DataModule
+```
+
+Using the following config file:
+```yaml
+trainer:
+ max_epochs: 20
+model:
+ class_path: ClassificationTask
+ init_args:
+ model: "resnet18"
+ in_channels: 13
+ num_classes: 10
+data:
+ class_path: EuroSAT100DataModule
+ init_args:
+ batch_size: 8
+ dict_kwargs:
+ download: true
+```
+we can see the script in action:
```console
-$ python train.py config_file=conf/landcoverai.yaml
+# Train and validate a model
+torchgeo fit --config config.yaml
+# Validate-only
+torchgeo validate --config config.yaml
+# Calculate and report test accuracy
+torchgeo test --config config.yaml --trainer.ckpt_path=...
+```
+
+It can also be imported and used in a Python script if you need to extend it to add new features:
+
+```python
+from torchgeo.main import main
+
+main(["fit", "--config", "config.yaml"])
```
+See the [Lightning documentation](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) for more details.
+
## Citation
If you use this software in your work, please cite our [paper](https://dl.acm.org/doi/10.1145/3557915.3560953):
@@ -158,7 +230,7 @@ If you use this software in your work, please cite our [paper](https://dl.acm.or
author = {Stewart, Adam J. and Robinson, Caleb and Corley, Isaac A. and Ortiz, Anthony and Lavista Ferres, Juan M. and Banerjee, Arindam},
booktitle = {Proceedings of the 30th International Conference on Advances in Geographic Information Systems},
doi = {10.1145/3557915.3560953},
- month = {11},
+ month = nov,
pages = {1--12},
publisher = {Association for Computing Machinery},
series = {SIGSPATIAL '22},
diff --git a/conf/bigearthnet.yaml b/conf/bigearthnet.yaml
deleted file mode 100644
index 3f159efa4b1..00000000000
--- a/conf/bigearthnet.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-module:
- _target_: torchgeo.trainers.MultiLabelClassificationTask
- loss: "bce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 14
- num_classes: 19
-
-datamodule:
- _target_: torchgeo.datamodules.BigEarthNetDataModule
- root: "data/bigearthnet"
- bands: "all"
- num_classes: ${module.num_classes}
- batch_size: 128
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/chesapeake_cvpr.yaml b/conf/chesapeake_cvpr.yaml
deleted file mode 100644
index 81af245a35a..00000000000
--- a/conf/chesapeake_cvpr.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 4
- num_classes: 7
- num_filters: 256
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
- root: "data/chesapeake/cvpr"
- train_splits:
- - "de-train"
- val_splits:
- - "de-val"
- test_splits:
- - "de-test"
- batch_size: 200
- patch_size: 256
- num_workers: 4
- class_set: ${module.num_classes}
- use_prior_labels: False
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/cowc_counting.yaml b/conf/cowc_counting.yaml
deleted file mode 100644
index 3b5d36779aa..00000000000
--- a/conf/cowc_counting.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-module:
- _target_: torchgeo.trainers.RegressionTask
- model: resnet18
- weights: null
- num_outputs: 1
- in_channels: 3
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
-
-datamodule:
- _target_: torchgeo.datamodules.COWCCountingDataModule
- root: "data/cowc_counting"
- batch_size: 64
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/cyclone.yaml b/conf/cyclone.yaml
deleted file mode 100644
index 2bb689ed4bf..00000000000
--- a/conf/cyclone.yaml
+++ /dev/null
@@ -1,21 +0,0 @@
-module:
- _target_: torchgeo.trainers.RegressionTask
- model: "resnet18"
- weights: null
- num_outputs: 1
- in_channels: 3
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
-
-datamodule:
- _target_: torchgeo.datamodules.TropicalCycloneDataModule
- root: "data/cyclone"
- batch_size: 32
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/deepglobelandcover.yaml b/conf/deepglobelandcover.yaml
deleted file mode 100644
index 0260e0ac0f3..00000000000
--- a/conf/deepglobelandcover.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 7
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.DeepGlobeLandCoverDataModule
- root: "data/deepglobelandcover"
- batch_size: 1
- patch_size: 64
- val_split_pct: 0.5
- num_workers: 0
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/defaults.yaml b/conf/defaults.yaml
deleted file mode 100644
index 15d58be2656..00000000000
--- a/conf/defaults.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-config_file: null # This lets the user pass a config filename to load other arguments from
-
-program: # These are the arguments that define how the train.py script works
- seed: 0
- output_dir: output
- data_dir: data
- log_dir: logs
- overwrite: False
diff --git a/conf/etci2021.yaml b/conf/etci2021.yaml
deleted file mode 100644
index e993b8ac628..00000000000
--- a/conf/etci2021.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: true
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 6
- num_classes: 2
- ignore_index: 0
-
-datamodule:
- _target_: torchgeo.datamodules.ETCI2021DataModule
- root: "data/etci2021"
- batch_size: 32
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/eurosat.yaml b/conf/eurosat.yaml
deleted file mode 100644
index b90f7823e01..00000000000
--- a/conf/eurosat.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 13
- num_classes: 10
-
-datamodule:
- _target_: torchgeo.datamodules.EuroSATDataModule
- root: "data/eurosat"
- batch_size: 128
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/gid15.yaml b/conf/gid15.yaml
deleted file mode 100644
index f46672da6ce..00000000000
--- a/conf/gid15.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 16
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.GID15DataModule
- root: "data/gid15"
- batch_size: 1
- patch_size: 64
- val_split_pct: 0.5
- num_workers: 0
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/inria.yaml b/conf/inria.yaml
deleted file mode 100644
index bbf73669a6a..00000000000
--- a/conf/inria.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: true
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 3
- num_classes: 2
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.InriaAerialImageLabelingDataModule
- root: "data/inria"
- batch_size: 1
- patch_size: 512
- num_workers: 32
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/l7irish.yaml b/conf/l7irish.yaml
deleted file mode 100644
index 5f221aa9ef0..00000000000
--- a/conf/l7irish.yaml
+++ /dev/null
@@ -1,23 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 9
- num_classes: 5
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.L7IrishDataModule
- root: "data/l7irish"
- batch_size: 64
- patch_size: 224
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/l8biome.yaml b/conf/l8biome.yaml
deleted file mode 100644
index b5bf7b552de..00000000000
--- a/conf/l8biome.yaml
+++ /dev/null
@@ -1,23 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 11
- num_classes: 5
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.L8BiomeDataModule
- root: "data/l8biome"
- batch_size: 64
- patch_size: 224
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/landcoverai.yaml b/conf/landcoverai.yaml
deleted file mode 100644
index f70667fe056..00000000000
--- a/conf/landcoverai.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: true
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 3
- num_classes: 5
- num_filters: 256
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.LandCoverAIDataModule
- root: "data/landcoverai"
- batch_size: 32
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/naipchesapeake.yaml b/conf/naipchesapeake.yaml
deleted file mode 100644
index 94f6cafcab6..00000000000
--- a/conf/naipchesapeake.yaml
+++ /dev/null
@@ -1,27 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "deeplabv3+"
- backbone: "resnet34"
- weights: true
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- in_channels: 4
- num_classes: 14
- num_filters: 64
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.NAIPChesapeakeDataModule
- naip_root: "data/naip"
- chesapeake_root: "data/chesapeake/BAYWIDE"
- batch_size: 32
- num_workers: 4
- patch_size: 32
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/nasa_marine_debris.yaml b/conf/nasa_marine_debris.yaml
deleted file mode 100644
index d176e95c0e1..00000000000
--- a/conf/nasa_marine_debris.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-module:
- _target_: torchgeo.trainers.ObjectDetectionTask
- model: "faster-rcnn"
- backbone: "resnet50"
- num_classes: 2
- learning_rate: 1.2e-4
- learning_rate_schedule_patience: 6
- verbose: false
-
-datamodule:
- _target_: torchgeo.datamodules.NASAMarineDebrisDataModule
- root: "data/nasamr/nasa_marine_debris"
- batch_size: 4
- num_workers: 6
- val_split_pct: 0.2
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/potsdam2d.yaml b/conf/potsdam2d.yaml
deleted file mode 100644
index 747e99c2047..00000000000
--- a/conf/potsdam2d.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 4
- num_classes: 6
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.Potsdam2DDataModule
- root: "data/potsdam"
- batch_size: 1
- patch_size: 64
- val_split_pct: 0.5
- num_workers: 0
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/resisc45.yaml b/conf/resisc45.yaml
deleted file mode 100644
index fc22c9ca9e3..00000000000
--- a/conf/resisc45.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 3
- num_classes: 45
-
-datamodule:
- _target_: torchgeo.datamodules.RESISC45DataModule
- root: "data/resisc45"
- batch_size: 128
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/seco_100k.yaml b/conf/seco_100k.yaml
deleted file mode 100644
index 41c6338bc02..00000000000
--- a/conf/seco_100k.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-module:
- _target_: torchgeo.trainers.BYOLTask
- in_channels: 12
- backbone: "resnet18"
- weights: True
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- optimizer: "Adam"
-
-datamodule:
- _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
- root: "data/seco"
- version: "100k"
- seasons: 2
- bands: ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B11", "B12"]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/sen12ms.yaml b/conf/sen12ms.yaml
deleted file mode 100644
index f1b4643c426..00000000000
--- a/conf/sen12ms.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- in_channels: 15
- num_classes: 11
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.SEN12MSDataModule
- root: "data/sen12ms"
- band_set: "all"
- batch_size: 32
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml
deleted file mode 100644
index 4a785a50e00..00000000000
--- a/conf/so2sat.yaml
+++ /dev/null
@@ -1,23 +0,0 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 18
- num_classes: 17
-
-datamodule:
- _target_: torchgeo.datamodules.So2SatDataModule
- root: "data/so2sat"
- batch_size: 128
- num_workers: 4
- band_set: "all"
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/spacenet1.yaml b/conf/spacenet1.yaml
deleted file mode 100644
index 82955319a57..00000000000
--- a/conf/spacenet1.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: true
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 3
- num_classes: 3
- ignore_index: 0
-
-datamodule:
- _target_: torchgeo.datamodules.SpaceNet1DataModule
- root: "data/spacenet"
- batch_size: 32
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/ssl4eo_benchmark_etm_sr_cdl.yaml b/conf/ssl4eo_benchmark_etm_sr_cdl.yaml
deleted file mode 100644
index ed64e22b701..00000000000
--- a/conf/ssl4eo_benchmark_etm_sr_cdl.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 6
- num_classes: 18
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "etm_sr"
- product: "cdl"
- classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml b/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml
deleted file mode 100644
index ba6a6dd8dfc..00000000000
--- a/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 6
- num_classes: 14
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "etm_sr"
- product: "nlcd"
- classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_etm_toa_cdl.yaml b/conf/ssl4eo_benchmark_etm_toa_cdl.yaml
deleted file mode 100644
index da11cf9f42c..00000000000
--- a/conf/ssl4eo_benchmark_etm_toa_cdl.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 9
- num_classes: 18
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "etm_toa"
- product: "cdl"
- classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml b/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml
deleted file mode 100644
index 8e7e701416b..00000000000
--- a/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 9
- num_classes: 14
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "etm_toa"
- product: "nlcd"
- classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_oli_sr_cdl.yaml b/conf/ssl4eo_benchmark_oli_sr_cdl.yaml
deleted file mode 100644
index 292390cc25c..00000000000
--- a/conf/ssl4eo_benchmark_oli_sr_cdl.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 7
- num_classes: 18
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "oli_sr"
- product: "cdl"
- classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml b/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml
deleted file mode 100644
index 982f8cd5b02..00000000000
--- a/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 7
- num_classes: 14
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "oli_sr"
- product: "nlcd"
- classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml b/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml
deleted file mode 100644
index 7ab684024b0..00000000000
--- a/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 11
- num_classes: 18
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "oli_tirs_toa"
- product: "cdl"
- classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml b/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml
deleted file mode 100644
index 050801e6964..00000000000
--- a/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 11
- num_classes: 14
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "oli_tirs_toa"
- product: "nlcd"
- classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_tm_toa_cdl.yaml b/conf/ssl4eo_benchmark_tm_toa_cdl.yaml
deleted file mode 100644
index bc3ccdc4396..00000000000
--- a/conf/ssl4eo_benchmark_tm_toa_cdl.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 7
- num_classes: 18
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "tm_toa"
- product: "cdl"
- classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml b/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml
deleted file mode 100644
index d81cfaff6f5..00000000000
--- a/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml
+++ /dev/null
@@ -1,25 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- model: "unet"
- backbone: "resnet18"
- weights: null
- in_channels: 7
- num_classes: 14
- loss: "ce"
- ignore_index: 0
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "data/ssl4eo_benchmark"
- sensor: "tm_toa"
- product: "nlcd"
- classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
- batch_size: 64
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- min_epochs: 20
- max_epochs: 100
diff --git a/conf/ssl4eo_l_oli_tirs_toa_mocov2_resnet50.yaml b/conf/ssl4eo_l_oli_tirs_toa_mocov2_resnet50.yaml
deleted file mode 100644
index e65dc56dee8..00000000000
--- a/conf/ssl4eo_l_oli_tirs_toa_mocov2_resnet50.yaml
+++ /dev/null
@@ -1,40 +0,0 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: resnet50
- weights: True
- in_channels: 11
- version: 2
- layers: 2
- hidden_dim: 2048
- output_dim: 128
- lr: 0.12
- weight_decay: 1e-4
- momentum: 0.9
- schedule: [120, 160]
- temperature: 0.07
- memory_bank_size: 65536
- moco_momentum: 0.999
- gather_distributed: True
- size: 224
- grayscale_weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: /path/to/data/
- split: oli_tirs_toa
- seasons: 2
- batch_size: 256
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 4
- limit_val_batches: 0.0
- max_epochs: 200
- log_every_n_steps: 5
-
-program:
- overwrite: True
- output_dir: output/ssl4eo-l-oli-tirs-toa-mocov2-resnet50
- log_dir: logs/ssl4eo-l-oli-tirs-toa-mocov2-resnet50
diff --git a/conf/ssl4eo_l_oli_tirs_toa_mocov2_vits16.yaml b/conf/ssl4eo_l_oli_tirs_toa_mocov2_vits16.yaml
deleted file mode 100644
index c538f68b574..00000000000
--- a/conf/ssl4eo_l_oli_tirs_toa_mocov2_vits16.yaml
+++ /dev/null
@@ -1,40 +0,0 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: vit_small_patch16_224
- weights: True
- in_channels: 11
- version: 2
- layers: 2
- hidden_dim: 2048
- output_dim: 128
- lr: 0.012
- weight_decay: 1e-4
- momentum: 0.9
- schedule: [120, 160]
- temperature: 0.07
- memory_bank_size: 65536
- moco_momentum: 0.999
- gather_distributed: True
- size: 224
- grayscale_weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: /path/to/data/
- split: oli_tirs_toa
- seasons: 2
- batch_size: 256
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 4
- limit_val_batches: 0.0
- max_epochs: 200
- log_every_n_steps: 5
-
-program:
- overwrite: True
- output_dir: output/ssl4eo-l-oli-tirs-toa-mocov2-vits16
- log_dir: logs/ssl4eo-l-oli-tirs-toa-mocov2-vits16
diff --git a/conf/ssl4eo_l_oli_tirs_toa_simclr_resnet50.yaml b/conf/ssl4eo_l_oli_tirs_toa_simclr_resnet50.yaml
deleted file mode 100644
index 619e2bdf955..00000000000
--- a/conf/ssl4eo_l_oli_tirs_toa_simclr_resnet50.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: resnet50
- weights: True
- in_channels: 11
- version: 1
- layers: 2
- hidden_dim: 2048
- output_dim: 128
- lr: 0.12
- memory_bank_size: 0
- gather_distributed: True
- size: 224
- grayscale_weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: /path/to/data/
- split: oli_tirs_toa
- seasons: 2
- batch_size: 256
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 4
- limit_val_batches: 0.0
- max_epochs: 200
- log_every_n_steps: 5
-
-program:
- overwrite: True
- output_dir: output/ssl4eo-l-oli-tirs-toa-simclr-resnet50
- log_dir: logs/ssl4eo-l-oli-tirs-toa-simclr-resnet50
diff --git a/conf/ssl4eo_l_oli_tirs_toa_simclr_vits16.yaml b/conf/ssl4eo_l_oli_tirs_toa_simclr_vits16.yaml
deleted file mode 100644
index 8d89cb06ff4..00000000000
--- a/conf/ssl4eo_l_oli_tirs_toa_simclr_vits16.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: vit_small_patch16_224
- weights: True
- in_channels: 11
- version: 1
- layers: 2
- hidden_dim: 2048
- output_dim: 128
- lr: 0.012
- memory_bank_size: 0
- gather_distributed: True
- size: 224
- grayscale_weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: /path/to/data/
- split: oli_tirs_toa
- seasons: 2
- batch_size: 256
- num_workers: 16
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 4
- limit_val_batches: 0.0
- max_epochs: 200
- log_every_n_steps: 5
-
-program:
- overwrite: True
- output_dir: output/ssl4eo-l-oli-tirs-toa-simclr-vits16
- log_dir: logs/ssl4eo-l-oli-tirs-toa-simclr-vits16
diff --git a/conf/ucmerced.yaml b/conf/ucmerced.yaml
deleted file mode 100644
index 95fbe6fb87c..00000000000
--- a/conf/ucmerced.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 3
- num_classes: 21
-
-datamodule:
- _target_: torchgeo.datamodules.UCMercedDataModule
- root: "data/ucmerced"
- batch_size: 128
- num_workers: 4
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/conf/vaihingen2d.yaml b/conf/vaihingen2d.yaml
deleted file mode 100644
index 4c5cf3b139a..00000000000
--- a/conf/vaihingen2d.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 7
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.Vaihingen2DDataModule
- root: "data/vaihingen"
- batch_size: 1
- patch_size: 64
- val_split_pct: 0.5
- num_workers: 0
-
-trainer:
- _target_: lightning.pytorch.Trainer
- accelerator: gpu
- devices: 1
- min_epochs: 15
- max_epochs: 40
diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst
index 9676c7c0337..eca7faac028 100644
--- a/docs/api/datamodules.rst
+++ b/docs/api/datamodules.rst
@@ -131,6 +131,11 @@ SSL4EO
.. autoclass:: SSL4EOLDataModule
.. autoclass:: SSL4EOS12DataModule
+SSL4EO-L Benchmark
+^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: SSL4EOLBenchmarkDataModule
+
SustainBench Crop Yield
^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 3d417c8a7c4..b9fb8ce4eb9 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -174,6 +174,11 @@ BigEarthNet
.. autoclass:: BigEarthNet
+BioMassters
+^^^^^^^^^^^
+
+.. autoclass:: BioMassters
+
Cloud Cover Detection
^^^^^^^^^^^^^^^^^^^^^
@@ -257,6 +262,11 @@ LoveDA
.. autoclass:: LoveDA
+MapInWild
+^^^^^^^^^
+
+.. autoclass:: MapInWild
+
Million-AID
^^^^^^^^^^^
@@ -297,11 +307,21 @@ RESISC45
.. autoclass:: RESISC45
+Rwanda Field Boundary
+^^^^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: RwandaFieldBoundary
+
Seasonal Contrast
^^^^^^^^^^^^^^^^^
.. autoclass:: SeasonalContrastS2
+SeasoNet
+^^^^^^^^
+
+.. autoclass:: SeasoNet
+
SEN12MS
^^^^^^^
@@ -336,6 +356,11 @@ SSL4EO
.. autoclass:: SSL4EOL
.. autoclass:: SSL4EOS12
+SSL4EO-L Benchmark
+^^^^^^^^^^^^^^^^^^
+
+.. autoclass:: SSL4EOLBenchmark
+
SustainBench Crop Yield
^^^^^^^^^^^^^^^^^^^^^^^
@@ -444,3 +469,8 @@ Splitting Functions
.. autofunction:: random_grid_cell_assignment
.. autofunction:: roi_split
.. autofunction:: time_series_split
+
+Errors
+------
+
+.. autoclass:: DatasetNotFoundError
diff --git a/docs/api/landsat_pretrained_weights.csv b/docs/api/landsat_pretrained_weights.csv
new file mode 100644
index 00000000000..b5b9fcba74f
--- /dev/null
+++ b/docs/api/landsat_pretrained_weights.csv
@@ -0,0 +1,31 @@
+Weight,Landsat,Channels,Source,Citation,NLCD (Acc),NLCD (mIoU),CDL (Acc),CDL (mIoU)
+ResNet18_Weights.LANDSAT_TM_TOA_MOCO,4--5,5,`link `__,`link `__,67.65,51.11,68.70,52.32
+ResNet18_Weights.LANDSAT_TM_TOA_SIMCLR,4--5,5,`link `__,`link `__,60.86,43.74,61.94,44.86
+ResNet50_Weights.LANDSAT_TM_TOA_MOCO,4--5,5,`link `__,`link `__,68.75,53.28,69.45,53.20
+ResNet50_Weights.LANDSAT_TM_TOA_SIMCLR,4--5,5,`link `__,`link `__,62.05,44.98,62.80,45.77
+ViTSmall16_Weights.LANDSAT_TM_TOA_MOCO,4--5,5,`link `__,`link `__,67.17,50.57,67.60,51.07
+ViTSmall16_Weights.LANDSAT_TM_TOA_SIMCLR,4--5,5,`link `__,`link `__,66.82,50.17,66.92,50.28
+ResNet18_Weights.LANDSAT_ETM_TOA_MOCO,7,9,`link `__,`link `__,65.22,48.39,62.84,45.81
+ResNet18_Weights.LANDSAT_ETM_TOA_SIMCLR,7,9,`link `__,`link `__,58.76,41.60,56.47,39.34
+ResNet50_Weights.LANDSAT_ETM_TOA_MOCO,7,9,`link `__,`link `__,66.60,49.92,64.12,47.19
+ResNet50_Weights.LANDSAT_ETM_TOA_SIMCLR,7,9,`link `__,`link `__,57.17,40.02,54.95,37.88
+ViTSmall16_Weights.LANDSAT_ETM_TOA_MOCO,7,9,`link `__,`link `__,63.75,46.79,60.88,43.70
+ViTSmall16_Weights.LANDSAT_ETM_TOA_SIMCLR,7,9,`link `__,`link `__,63.33,46.34,59.06,41.91
+ResNet18_Weights.LANDSAT_ETM_SR_MOCO,7,6,`link `__,`link `__,64.18,47.25,67.30,50.71
+ResNet18_Weights.LANDSAT_ETM_SR_SIMCLR,7,6,`link `__,`link `__,57.26,40.11,54.42,37.48
+ResNet50_Weights.LANDSAT_ETM_SR_MOCO,7,6,`link `__,`link `__,64.37,47.46,62.35,45.30
+ResNet50_Weights.LANDSAT_ETM_SR_SIMCLR,7,6,`link `__,`link `__,57.79,40.64,55.69,38.59
+ViTSmall16_Weights.LANDSAT_ETM_SR_MOCO,7,6,`link `__,`link `__,64.09,47.21,52.37,35.48
+ViTSmall16_Weights.LANDSAT_ETM_SR_SIMCLR,7,6,`link `__,`link `__,63.99,47.05,53.17,36.21
+ResNet18_Weights.LANDSAT_OLI_TIRS_TOA_MOCO,8--9,11,`link `__,`link `__,67.82,51.30,65.74,48.96
+ResNet18_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR,8--9,11,`link `__,`link `__,62.14,45.08,60.01,42.86
+ResNet50_Weights.LANDSAT_OLI_TIRS_TOA_MOCO,8--9,11,`link `__,`link `__,69.17,52.87,67.29,50.70
+ResNet50_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR,8--9,11,`link `__,`link `__,64.66,47.78,62.08,45.01
+ViTSmall16_Weights.LANDSAT_OLI_TIRS_TOA_MOCO,8--9,11,`link `__,`link `__,67.11,50.49,64.62,47.73
+ViTSmall16_Weights.LANDSAT_OLI_TIRS_TOA_SIMCLR,8--9,11,`link `__,`link `__,66.12,49.39,63.88,46.94
+ResNet18_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,67.01,50.39,68.05,51.57
+ResNet18_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link `__,`link `__,59.93,42.79,57.44,40.30
+ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,67.44,50.88,65.96,49.21
+ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link `__,`link `__,63.65,46.68,60.01,43.17
+ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,66.81,50.16,64.17,47.24
+ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link `__,`link `__,65.04,48.20,62.61,45.46
diff --git a/docs/api/misc_pretrained_weights.csv b/docs/api/misc_pretrained_weights.csv
new file mode 100644
index 00000000000..43dd4c5405b
--- /dev/null
+++ b/docs/api/misc_pretrained_weights.csv
@@ -0,0 +1,2 @@
+Weight,Channels,Source,Citation
+ResNet50_Weights.FMOW_RGB_GASSL, 3,`link `__,`link `__
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 4e9889c1f1c..53d8bf0092a 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -39,24 +39,12 @@ ResNet
.. autoclass:: ResNet18_Weights
.. autoclass:: ResNet50_Weights
-.. csv-table::
- :widths: 45 10 10 10 15 10 10 10
- :header-rows: 1
- :align: center
- :file: resnet_pretrained_weights.csv
-
Vision Transformer
^^^^^^^^^^^^^^^^^^
.. autofunction:: vit_small_patch16_224
.. autoclass:: ViTSmall16_Weights
-.. csv-table::
- :widths: 45 10 10 10 15 10 10 10
- :header-rows: 1
- :align: center
- :file: vit_pretrained_weights.csv
-
Utility Functions
^^^^^^^^^^^^^^^^^
@@ -64,3 +52,45 @@ Utility Functions
.. autofunction:: get_model_weights
.. autofunction:: get_weight
.. autofunction:: list_models
+
+
+Pretrained Weights
+^^^^^^^^^^^^^^^^^^
+
+Landsat
+-------
+
+.. csv-table::
+ :widths: 65 10 10 10 10 10 10 10 10
+ :header-rows: 1
+ :align: center
+ :file: landsat_pretrained_weights.csv
+
+
+Sentinel-1
+----------
+
+.. csv-table::
+ :widths: 45 10 10 10
+ :header-rows: 1
+ :align: center
+ :file: sentinel1_pretrained_weights.csv
+
+
+Sentinel-2
+----------
+
+.. csv-table::
+ :widths: 45 10 10 10 15 10 10 10
+ :header-rows: 1
+ :align: center
+ :file: sentinel2_pretrained_weights.csv
+
+Other Data Sources
+------------------
+
+.. csv-table::
+ :widths: 45 10 10 10
+ :header-rows: 1
+ :align: center
+ :file: misc_pretrained_weights.csv
diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv
index 064257f8a08..903e0fbf6f4 100644
--- a/docs/api/non_geo_datasets.csv
+++ b/docs/api/non_geo_datasets.csv
@@ -2,6 +2,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`ADVANCE`_,C,"Google Earth, Freesound","5,075",13,512x512,0.5,RGB
`Benin Cashew Plantations`_,S,Airbus Pléiades,70,6,"1,122x1,186",10,MSI
`BigEarthNet`_,C,Sentinel-1/2,"590,326",19--43,120x120,10,"SAR, MSI"
+`BioMassters`_,R,Sentinel-1/2 and Lidar,,,256x256, 10, "SAR, MSI"
`Cloud Cover Detection`_,S,Sentinel-2,"22,728",2,512x512,10,MSI
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","388,435",2,256x256,0.15,RGB
`Kenya Crop Type`_,S,Sentinel-2,"4,688",7,"3,035x2,016",10,MSI
@@ -18,6 +19,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`LandCover.ai`_,S,Aerial,"10,674",5,512x512,0.25--0.5,RGB
`LEVIR-CD+`_,CD,Google Earth,985,2,"1,024x1,024",0.5,RGB
`LoveDA`_,S,Google Earth,"5,987",7,"1,024x1,024",0.3,RGB
+`MapInWild`_,S,"Sentinel-1/2, ESA WorldCover, NOAA VIIRS DNB",1018,1,1920x1920,10--463.83,"SAR, MSI, 2020_Map, avg_rad"
`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB
`NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB
`OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI
@@ -26,13 +28,17 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`Potsdam`_,S,Aerial,38,6,"6,000x6,000",0.05,MSI
`ReforesTree`_,"OD, R",Aerial,100,6,"4,000x4,000",0.02,RGB
`RESISC45`_,C,Google Earth,"31,500",45,256x256,0.2--30,RGB
+`Rwanda Field Boundary`_,S,Planetscope,70,2,256x256,4.7,RGB + NIR
`Seasonal Contrast`_,T,Sentinel-2,100K--1M,-,264x264,10,MSI
+`SeasoNet`_,S,Sentinel-2,"1,759,830",33,120x120,10,MSI
`SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI"
`SKIPP'D`_,R,"Fish-eye","363,375",-,64x64,-,RGB
`So2Sat`_,C,Sentinel-1/2,"400,673",17,32x32,10,"SAR, MSI"
`SpaceNet`_,I,WorldView-2/3 Planet Lab Dove,"1,889--28,728",2,102--900,0.5--4,MSI
`SSL4EO`_-L,T,Landsat,1M,-,264x264,30,MSI
`SSL4EO`_-S12,T,Sentinel-1/2,1M,-,264x264,10,"SAR, MSI"
+`SSL4EO-L Benchmark`_,S,Lansat & CDL,25K,134,264x264,30,MSI
+`SSL4EO-L Benchmark`_,S,Lansat & NLCD,25K,17,264x264,30,MSI
`SustainBench Crop Yield`_,R,MODIS,11k,-,32x32,-,MSI
`Tropical Cyclone`_,R,GOES 8--16,"108,110",-,256x256,4K--8K,MSI
`UC Merced`_,C,USGS National Map,"2,100",21,256x256,0.3,RGB
diff --git a/docs/api/sentinel1_pretrained_weights.csv b/docs/api/sentinel1_pretrained_weights.csv
new file mode 100644
index 00000000000..75e1224d30b
--- /dev/null
+++ b/docs/api/sentinel1_pretrained_weights.csv
@@ -0,0 +1,2 @@
+Weight,Channels,Source,Citation
+ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link `__,`link `__
diff --git a/docs/api/resnet_pretrained_weights.csv b/docs/api/sentinel2_pretrained_weights.csv
similarity index 78%
rename from docs/api/resnet_pretrained_weights.csv
rename to docs/api/sentinel2_pretrained_weights.csv
index 8eef92e4e46..8952b74c42f 100644
--- a/docs/api/resnet_pretrained_weights.csv
+++ b/docs/api/sentinel2_pretrained_weights.csv
@@ -2,9 +2,9 @@ Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,,,,
ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,,,,
ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,87.27,93.14,,46.94
-ResNet50_Weights.FMOW_RGB_GASSL, 3,`link `__,`link `__,,,,
-ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link `__,`link `__,,,,
ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,90.7,99.1,63.6,
ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,91.8,99.1,60.9,
ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,,,,
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,87.81,,,
+ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,90.5,99.0,62.2,
+ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,89.9,98.6,61.6,
diff --git a/docs/api/vit_pretrained_weights.csv b/docs/api/vit_pretrained_weights.csv
deleted file mode 100644
index 9f6899a7eb5..00000000000
--- a/docs/api/vit_pretrained_weights.csv
+++ /dev/null
@@ -1,3 +0,0 @@
-Weight,Channels,Source,Citation,BigEarthNet,EuroSAT,So2Sat,OSCD
-ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,90.5,99.0,62.2,
-ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,89.9,98.6,61.6,
diff --git a/docs/conf.py b/docs/conf.py
index 44514ff9419..e62e91172d8 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -59,6 +59,7 @@
# Undocumented classes
("py:class", "kornia.augmentation._2d.intensity.base.IntensityAugmentationBase2D"),
("py:class", "kornia.augmentation.base._AugmentationBase"),
+ ("py:class", "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig"),
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
("py:class", "timm.models.resnet.ResNet"),
("py:class", "timm.models.vision_transformer.VisionTransformer"),
diff --git a/docs/index.rst b/docs/index.rst
index d30b0c0eb73..aa3e1ce3fa2 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -32,7 +32,6 @@ torchgeo
tutorials/transforms
tutorials/indices
tutorials/trainers
- tutorials/benchmarking
tutorials/pretrained_weights
.. toctree::
diff --git a/docs/tutorials/benchmarking.ipynb b/docs/tutorials/benchmarking.ipynb
deleted file mode 100644
index eb07bdb192b..00000000000
--- a/docs/tutorials/benchmarking.ipynb
+++ /dev/null
@@ -1,364 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Copyright (c) Microsoft Corporation. All rights reserved.\n",
- "\n",
- "Licensed under the MIT License."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "OFXtoHmJClRf"
- },
- "source": [
- "# Benchmarking\n",
- "\n",
- "This tutorial benchmarks the performance of various sampling strategies, with and without caching.\n",
- "\n",
- "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Setup\n",
- "\n",
- "First, we install TorchGeo."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%pip install torchgeo"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hC3pauOLChi4",
- "nteract": {
- "transient": {
- "deleting": false
- }
- }
- },
- "source": [
- "## Imports\n",
- "\n",
- "Next, we import TorchGeo and any other libraries we need."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "gather": {
- "logged": 1629238744113
- },
- "id": "gjFiws-PChi8"
- },
- "outputs": [],
- "source": [
- "import os\n",
- "import tempfile\n",
- "import time\n",
- "from typing import Tuple\n",
- "\n",
- "from torch.utils.data import DataLoader\n",
- "\n",
- "from torchgeo.datasets import NAIP, ChesapeakeDE\n",
- "from torchgeo.datasets.utils import download_url, stack_samples\n",
- "from torchgeo.samplers import RandomGeoSampler, GridGeoSampler, RandomBatchGeoSampler"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Datasets\n",
- "\n",
- "For this tutorial, we'll be using imagery from the [National Agriculture Imagery Program (NAIP)](https://catalog.data.gov/dataset/national-agriculture-imagery-program-naip) and labels from the [Chesapeake Bay High-Resolution Land Cover Project](https://www.chesapeakeconservancy.org/conservation-innovation-center/high-resolution-data/land-cover-data-project/). First, we manually download a few NAIP tiles."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "naip_root = os.path.join(tempfile.gettempdir(), \"naip\")\n",
- "naip_url = (\n",
- " \"https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/\"\n",
- ")\n",
- "tiles = [\n",
- " \"m_3807511_ne_18_060_20181104.tif\",\n",
- " \"m_3807511_se_18_060_20181104.tif\",\n",
- " \"m_3807512_nw_18_060_20180815.tif\",\n",
- " \"m_3807512_sw_18_060_20180815.tif\",\n",
- "]\n",
- "for tile in tiles:\n",
- " download_url(naip_url + tile, naip_root)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, we tell TorchGeo to automatically download the corresponding Chesapeake labels."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "chesapeake_root = os.path.join(tempfile.gettempdir(), \"chesapeake\")\n",
- "chesapeake = ChesapeakeDE(chesapeake_root, download=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "n6HwpMz7Chi-",
- "nteract": {
- "transient": {
- "deleting": false
- }
- }
- },
- "source": [
- "## Timing function"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "gather": {
- "logged": 1629238744228
- },
- "id": "8-z6_y2xChi-",
- "nteract": {
- "transient": {
- "deleting": false
- }
- },
- "tags": []
- },
- "outputs": [],
- "source": [
- "def time_epoch(dataloader: DataLoader) -> Tuple[float, int]:\n",
- " tic = time.time()\n",
- " i = 0\n",
- " for _ in dataloader:\n",
- " i += 1\n",
- " toc = time.time()\n",
- " return toc - tic, i"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The following variables can be modified to control the number of samples drawn per epoch."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "nbmake": {
- "mock": {
- "batch_size": 1,
- "length": 1,
- "size": 1,
- "stride": 1000000
- }
- }
- },
- "outputs": [],
- "source": [
- "size = 1000\n",
- "length = 888\n",
- "batch_size = 12\n",
- "stride = 500"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "I3pkKYoeChi_",
- "nteract": {
- "transient": {
- "deleting": false
- }
- }
- },
- "source": [
- "## RandomGeoSampler"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "gather": {
- "logged": 1629248963725
- },
- "id": "jPjIZLF7Chi_",
- "nteract": {
- "transient": {
- "deleting": false
- }
- },
- "outputId": "edcc8199-bd09-4832-e50c-7be8ac78995b",
- "tags": []
- },
- "outputs": [],
- "source": [
- "for cache in [False, True]:\n",
- " chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
- " naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
- " dataset = chesapeake & naip\n",
- " sampler = RandomGeoSampler(dataset, size=size, length=length)\n",
- " dataloader = DataLoader(\n",
- " dataset, batch_size=batch_size, sampler=sampler, collate_fn=stack_samples\n",
- " )\n",
- " duration, count = time_epoch(dataloader)\n",
- " print(duration, count)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pHqLRDA_ChjB",
- "nteract": {
- "transient": {
- "deleting": false
- }
- }
- },
- "source": [
- "## GridGeoSampler"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "gather": {
- "logged": 1629239313388
- },
- "id": "K67vnCK4ChjC",
- "nteract": {
- "transient": {
- "deleting": false
- }
- },
- "outputId": "159ce99f-a438-4ecc-d218-9b9e28d02055",
- "tags": []
- },
- "outputs": [],
- "source": [
- "for cache in [False, True]:\n",
- " chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
- " naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
- " dataset = chesapeake & naip\n",
- " sampler = GridGeoSampler(dataset, size=size, stride=stride)\n",
- " dataloader = DataLoader(\n",
- " dataset, batch_size=batch_size, sampler=sampler, collate_fn=stack_samples\n",
- " )\n",
- " duration, count = time_epoch(dataloader)\n",
- " print(duration, count)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8rwjrOD1ChjD",
- "nteract": {
- "transient": {
- "deleting": false
- }
- }
- },
- "source": [
- "## RandomBatchGeoSampler"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "gather": {
- "logged": 1629249843438
- },
- "id": "v-N2fo6UChjE",
- "nteract": {
- "transient": {
- "deleting": false
- }
- },
- "outputId": "497f6869-1ab7-4db7-bbce-e943b493ca41",
- "tags": []
- },
- "outputs": [],
- "source": [
- "for cache in [False, True]:\n",
- " chesapeake = ChesapeakeDE(chesapeake_root, cache=cache)\n",
- " naip = NAIP(naip_root, crs=chesapeake.crs, res=chesapeake.res, cache=cache)\n",
- " dataset = chesapeake & naip\n",
- " sampler = RandomBatchGeoSampler(\n",
- " dataset, size=size, batch_size=batch_size, length=length\n",
- " )\n",
- " dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=stack_samples)\n",
- " duration, count = time_epoch(dataloader)\n",
- " print(duration, count)"
- ]
- }
- ],
- "metadata": {
- "colab": {
- "collapsed_sections": [],
- "name": "benchmarking.ipynb",
- "provenance": []
- },
- "execution": {
- "timeout": 1200
- },
- "kernel_info": {
- "name": "python38-azureml"
- },
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.8"
- },
- "nteract": {
- "version": "nteract-front-end@1.0.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb
index 826bcd0abc6..bfae62e29a7 100644
--- a/docs/tutorials/getting_started.ipynb
+++ b/docs/tutorials/getting_started.ipynb
@@ -189,6 +189,7 @@
"outputs": [],
"source": [
"chesapeake_root = os.path.join(tempfile.gettempdir(), \"chesapeake\")\n",
+ "os.makedirs(chesapeake_root, exist_ok=True)\n",
"chesapeake = ChesapeakeDE(chesapeake_root, crs=naip.crs, res=naip.res, download=True)"
]
},
diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb
index fe9660707d3..26d97fcbc6c 100644
--- a/docs/tutorials/pretrained_weights.ipynb
+++ b/docs/tutorials/pretrained_weights.ipynb
@@ -204,8 +204,8 @@
" weights=weights,\n",
" in_channels=13,\n",
" num_classes=10,\n",
- " learning_rate=0.001,\n",
- " learning_rate_schedule_patience=5,\n",
+ " lr=0.001,\n",
+ " patience=5,\n",
")"
]
},
diff --git a/docs/tutorials/trainers.ipynb b/docs/tutorials/trainers.ipynb
index 0e6b6786479..3e05930ffd2 100644
--- a/docs/tutorials/trainers.ipynb
+++ b/docs/tutorials/trainers.ipynb
@@ -174,8 +174,8 @@
" weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,\n",
" in_channels=13,\n",
" num_classes=10,\n",
- " learning_rate=0.1,\n",
- " learning_rate_schedule_patience=5,\n",
+ " lr=0.1,\n",
+ " patience=5,\n",
")"
]
},
diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb
index 71e2224a3da..80a53f52643 100644
--- a/docs/tutorials/transforms.ipynb
+++ b/docs/tutorials/transforms.ipynb
@@ -201,16 +201,16 @@
" ]\n",
")\n",
"bands = {\n",
- " \"B1\": \"Coastal Aerosol\",\n",
- " \"B2\": \"Blue\",\n",
- " \"B3\": \"Green\",\n",
- " \"B4\": \"Red\",\n",
- " \"B5\": \"Vegetation Red Edge 1\",\n",
- " \"B6\": \"Vegetation Red Edge 2\",\n",
- " \"B7\": \"Vegetation Red Edge 3\",\n",
- " \"B8\": \"NIR 1\",\n",
+ " \"B01\": \"Coastal Aerosol\",\n",
+ " \"B02\": \"Blue\",\n",
+ " \"B03\": \"Green\",\n",
+ " \"B04\": \"Red\",\n",
+ " \"B05\": \"Vegetation Red Edge 1\",\n",
+ " \"B06\": \"Vegetation Red Edge 2\",\n",
+ " \"B07\": \"Vegetation Red Edge 3\",\n",
+ " \"B08\": \"NIR 1\",\n",
" \"B8A\": \"NIR 2\",\n",
- " \"B9\": \"Water Vapour\",\n",
+ " \"B09\": \"Water Vapour\",\n",
" \"B10\": \"SWIR 1\",\n",
" \"B11\": \"SWIR 2\",\n",
" \"B12\": \"SWIR 3\",\n",
diff --git a/experiments/ssl4eo/flops.py b/experiments/ssl4eo/flops.py
index 985a1e72fc1..6bffb1835f9 100755
--- a/experiments/ssl4eo/flops.py
+++ b/experiments/ssl4eo/flops.py
@@ -22,7 +22,7 @@
# Calculate memory requirements of model
mem_params = sum([p.nelement() * p.element_size() for p in m.parameters()])
mem_bufs = sum([b.nelement() * b.element_size() for b in m.buffers()])
- mem = (mem_params + mem_bufs) / 2**20
+ mem = (mem_params + mem_bufs) / 1000000
print(f"Memory: {mem:.2f} MB")
with get_accelerator().device(0):
diff --git a/experiments/ssl4eo/landsat/README.md b/experiments/ssl4eo/landsat/README.md
index ef56db35bb7..d681ea28986 100644
--- a/experiments/ssl4eo/landsat/README.md
+++ b/experiments/ssl4eo/landsat/README.md
@@ -89,10 +89,10 @@ This will create patches of NLCD and CDL data with the same locations and dimens
Using either the newly created datasets or after downloading the datasets from Hugging Face, you can run each experiment using:
```console
-$ python3 ../../../train.py config_file=...
+$ torchgeo fit --config *.yaml
```
-The config files to be passed can be found in the `../../../conf/` directory. Feel free to tweak any hyperparameters you see in these files. The default values are the optimal hyperparameters we found.
+The config files to be passed can be found in the `conf/` directory. Feel free to tweak any hyperparameters you see in these files. The default values are the optimal hyperparameters we found.
## Plotting
diff --git a/experiments/ssl4eo/landsat/conf/l7irish.yaml b/experiments/ssl4eo/landsat/conf/l7irish.yaml
new file mode 100644
index 00000000000..91b1cbea15d
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/l7irish.yaml
@@ -0,0 +1,23 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 9
+ num_classes: 5
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: L7IrishDataModule
+ init_args:
+ batch_size: 64
+ patch_size: 224
+ num_workers: 16
+ dict_kwargs:
+ paths: "data/l7irish"
diff --git a/experiments/ssl4eo/landsat/conf/l8biome.yaml b/experiments/ssl4eo/landsat/conf/l8biome.yaml
new file mode 100644
index 00000000000..728073a56fa
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/l8biome.yaml
@@ -0,0 +1,23 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 11
+ num_classes: 5
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: L8BiomeDataModule
+ init_args:
+ batch_size: 64
+ patch_size: 224
+ num_workers: 16
+ dict_kwargs:
+ paths: "data/l8biome"
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_cdl.yaml
new file mode 100644
index 00000000000..30b41c1cb0e
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_cdl.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 6
+ num_classes: 18
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "etm_sr"
+ product: "cdl"
+ classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml
new file mode 100644
index 00000000000..51e1732e91d
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 6
+ num_classes: 14
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "etm_sr"
+ product: "nlcd"
+ classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_cdl.yaml
new file mode 100644
index 00000000000..7205bc6e20d
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_cdl.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 9
+ num_classes: 18
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "etm_toa"
+ product: "cdl"
+ classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml
new file mode 100644
index 00000000000..10b9ea0a24c
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 9
+ num_classes: 14
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "etm_toa"
+ product: "nlcd"
+ classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_cdl.yaml
new file mode 100644
index 00000000000..f32a344cd0a
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_cdl.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 7
+ num_classes: 18
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "oli_sr"
+ product: "cdl"
+ classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml
new file mode 100644
index 00000000000..4fa8059656d
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 7
+ num_classes: 14
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "oli_sr"
+ product: "nlcd"
+ classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml
new file mode 100644
index 00000000000..2beacfd672f
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 11
+ num_classes: 18
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "oli_tirs_toa"
+ product: "cdl"
+ classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml
new file mode 100644
index 00000000000..d14dbb6c6e2
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 11
+ num_classes: 14
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "oli_tirs_toa"
+ product: "nlcd"
+ classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_cdl.yaml
new file mode 100644
index 00000000000..d64d0ab65ac
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_cdl.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 7
+ num_classes: 18
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "tm_toa"
+ product: "cdl"
+ classes: [0, 1, 5, 24, 36, 37, 61, 111, 121, 122, 131, 141, 142, 143, 152, 176, 190, 195]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml
new file mode 100644
index 00000000000..c1f9d0bde17
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml
@@ -0,0 +1,25 @@
+trainer:
+ min_epochs: 20
+ max_epochs: 100
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ in_channels: 7
+ num_classes: 14
+ loss: "ce"
+ ignore_index: 0
+ lr: 1e-3
+ patience: 6
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 16
+ dict_kwargs:
+ root: "data/ssl4eo_benchmark"
+ sensor: "tm_toa"
+ product: "nlcd"
+ classes: [0, 11, 21, 22, 31, 41, 42, 43, 52, 71, 81, 82, 90, 95]
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_mocov2_resnet50.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_mocov2_resnet50.yaml
new file mode 100644
index 00000000000..440c02ed8ff
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_mocov2_resnet50.yaml
@@ -0,0 +1,33 @@
+trainer:
+ limit_val_batches: 0.0
+ max_epochs: 200
+ log_every_n_steps: 5
+model:
+ class_path: MoCoTask
+ init_args:
+ model: resnet50
+ weights: True
+ in_channels: 11
+ version: 2
+ layers: 2
+ hidden_dim: 2048
+ output_dim: 128
+ lr: 0.12
+ weight_decay: 1e-4
+ momentum: 0.9
+ schedule: [120, 160]
+ temperature: 0.07
+ memory_bank_size: 65536
+ moco_momentum: 0.999
+ gather_distributed: True
+ size: 224
+ grayscale_weights: null
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 256
+ num_workers: 16
+ dict_kwargs:
+ root: /path/to/data/
+ split: oli_tirs_toa
+ seasons: 2
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_mocov2_vits16.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_mocov2_vits16.yaml
new file mode 100644
index 00000000000..5345d21287b
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_mocov2_vits16.yaml
@@ -0,0 +1,33 @@
+trainer:
+ limit_val_batches: 0.0
+ max_epochs: 200
+ log_every_n_steps: 5
+model:
+ class_path: MoCoTask
+ init_args:
+ model: vit_small_patch16_224
+ weights: True
+ in_channels: 11
+ version: 2
+ layers: 2
+ hidden_dim: 2048
+ output_dim: 128
+ lr: 0.012
+ weight_decay: 1e-4
+ momentum: 0.9
+ schedule: [120, 160]
+ temperature: 0.07
+ memory_bank_size: 65536
+ moco_momentum: 0.999
+ gather_distributed: True
+ size: 224
+ grayscale_weights: null
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 256
+ num_workers: 16
+ dict_kwargs:
+ root: /path/to/data/
+ split: oli_tirs_toa
+ seasons: 2
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_simclr_resnet50.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_simclr_resnet50.yaml
new file mode 100644
index 00000000000..ca9c7a10aab
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_simclr_resnet50.yaml
@@ -0,0 +1,28 @@
+trainer:
+ limit_val_batches: 0.0
+ max_epochs: 200
+ log_every_n_steps: 5
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: resnet50
+ weights: True
+ in_channels: 11
+ version: 1
+ layers: 2
+ hidden_dim: 2048
+ output_dim: 128
+ lr: 0.12
+ memory_bank_size: 0
+ gather_distributed: True
+ size: 224
+ grayscale_weights: null
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 256
+ num_workers: 16
+ dict_kwargs:
+ root: /path/to/data/
+ split: oli_tirs_toa
+ seasons: 2
diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_simclr_vits16.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_simclr_vits16.yaml
new file mode 100644
index 00000000000..8383bb800b0
--- /dev/null
+++ b/experiments/ssl4eo/landsat/conf/ssl4eo_l_oli_tirs_toa_simclr_vits16.yaml
@@ -0,0 +1,28 @@
+trainer:
+ limit_val_batches: 0.0
+ max_epochs: 200
+ log_every_n_steps: 5
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: vit_small_patch16_224
+ weights: True
+ in_channels: 11
+ version: 1
+ layers: 2
+ hidden_dim: 2048
+ output_dim: 128
+ lr: 0.012
+ memory_bank_size: 0
+ gather_distributed: True
+ size: 224
+ grayscale_weights: null
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 256
+ num_workers: 16
+ dict_kwargs:
+ root: /path/to/data/
+ split: oli_tirs_toa
+ seasons: 2
diff --git a/experiments/ssl4eo/landsat/plot_landsat_bands.py b/experiments/ssl4eo/landsat/plot_landsat_bands.py
index 0de8e66396c..edfd6c86ec7 100755
--- a/experiments/ssl4eo/landsat/plot_landsat_bands.py
+++ b/experiments/ssl4eo/landsat/plot_landsat_bands.py
@@ -46,7 +46,9 @@
df = df.iloc[::-1]
fig, ax = plt.subplots(figsize=(5.5, args.fig_height))
-ax1, ax2 = fig.subplots(nrows=1, ncols=2, gridspec_kw={"width_ratios": [3, 1]})
+ax1, ax2 = fig.subplots(
+ nrows=1, ncols=2, gridspec_kw={"width_ratios": [3, 1]}
+) # type: ignore[misc]
sensor_names: list[str] = []
sensor_ylocs: list[float] = []
diff --git a/experiments/ssl4eo/landsat/plot_landsat_timeline.py b/experiments/ssl4eo/landsat/plot_landsat_timeline.py
index 48634d97006..94eb40dbe0e 100755
--- a/experiments/ssl4eo/landsat/plot_landsat_timeline.py
+++ b/experiments/ssl4eo/landsat/plot_landsat_timeline.py
@@ -74,7 +74,7 @@
fig, ax = plt.subplots(figsize=(5.5, 3))
-cmap = iter(plt.cm.tab10(range(9, 0, -1)))
+cmap = iter(plt.cm.tab10(range(9, 0, -1))) # type: ignore[attr-defined]
ymin = args.bar_start
yticks = []
for satellite in range(9, 0, -1):
diff --git a/experiments/torchgeo/conf/chesapeake_cvpr.yaml b/experiments/torchgeo/conf/chesapeake_cvpr.yaml
new file mode 100644
index 00000000000..da2e012ed05
--- /dev/null
+++ b/experiments/torchgeo/conf/chesapeake_cvpr.yaml
@@ -0,0 +1,32 @@
+trainer:
+ min_epochs: 15
+ max_epochs: 40
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ weights: null
+ lr: 1e-3
+ patience: 6
+ in_channels: 4
+ num_classes: 7
+ num_filters: 256
+ ignore_index: null
+data:
+ class_path: ChesapeakeCVPRDataModule
+ init_args:
+ train_splits:
+ - "de-train"
+ val_splits:
+ - "de-val"
+ test_splits:
+ - "de-test"
+ batch_size: 200
+ patch_size: 256
+ num_workers: 4
+ class_set: ${model.init_args.num_classes}
+ use_prior_labels: False
+ dict_kwargs:
+ root: "data/chesapeake/cvpr"
diff --git a/experiments/torchgeo/conf/cowc_counting.yaml b/experiments/torchgeo/conf/cowc_counting.yaml
new file mode 100644
index 00000000000..481ba40cd97
--- /dev/null
+++ b/experiments/torchgeo/conf/cowc_counting.yaml
@@ -0,0 +1,19 @@
+trainer:
+ min_epochs: 15
+ max_epochs: 40
+model:
+ class_path: RegressionTask
+ init_args:
+ model: resnet18
+ weights: null
+ num_outputs: 1
+ in_channels: 3
+ lr: 1e-3
+ patience: 2
+data:
+ class_path: COWCCountingDataModule
+ init_args:
+ batch_size: 64
+ num_workers: 4
+ dict_kwargs:
+ root: "data/cowc_counting"
diff --git a/experiments/torchgeo/conf/etci2021.yaml b/experiments/torchgeo/conf/etci2021.yaml
new file mode 100644
index 00000000000..c3f0ae487ca
--- /dev/null
+++ b/experiments/torchgeo/conf/etci2021.yaml
@@ -0,0 +1,22 @@
+trainer:
+ min_epochs: 15
+ max_epochs: 40
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ weights: true
+ lr: 1e-3
+ patience: 6
+ in_channels: 6
+ num_classes: 2
+ ignore_index: 0
+data:
+ class_path: ETCI2021DataModule
+ init_args:
+ batch_size: 32
+ num_workers: 4
+ dict_kwargs:
+ root: "data/etci2021"
diff --git a/experiments/torchgeo/conf/eurosat.yaml b/experiments/torchgeo/conf/eurosat.yaml
new file mode 100644
index 00000000000..6e788273aa6
--- /dev/null
+++ b/experiments/torchgeo/conf/eurosat.yaml
@@ -0,0 +1,20 @@
+trainer:
+ min_epochs: 15
+ max_epochs: 40
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ lr: 1e-3
+ patience: 6
+ weights: null
+ in_channels: 13
+ num_classes: 10
+data:
+ class_path: EuroSATDataModule
+ init_args:
+ batch_size: 128
+ num_workers: 4
+ dict_kwargs:
+ root: "data/eurosat"
diff --git a/experiments/torchgeo/conf/landcoverai.yaml b/experiments/torchgeo/conf/landcoverai.yaml
new file mode 100644
index 00000000000..e9ef4df66cf
--- /dev/null
+++ b/experiments/torchgeo/conf/landcoverai.yaml
@@ -0,0 +1,23 @@
+trainer:
+ min_epochs: 15
+ max_epochs: 40
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ weights: true
+ lr: 1e-3
+ patience: 6
+ in_channels: 3
+ num_classes: 5
+ num_filters: 256
+ ignore_index: null
+data:
+ class_path: LandCoverAIDataModule
+ init_args:
+ batch_size: 32
+ num_workers: 4
+ dict_kwargs:
+ root: "data/landcoverai"
diff --git a/experiments/torchgeo/conf/resisc45.yaml b/experiments/torchgeo/conf/resisc45.yaml
new file mode 100644
index 00000000000..8a9d34c4ede
--- /dev/null
+++ b/experiments/torchgeo/conf/resisc45.yaml
@@ -0,0 +1,20 @@
+trainer:
+ min_epochs: 15
+ max_epochs: 40
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ lr: 1e-3
+ patience: 6
+ weights: null
+ in_channels: 3
+ num_classes: 45
+data:
+ class_path: RESISC45DataModule
+ init_args:
+ batch_size: 128
+ num_workers: 4
+ dict_kwargs:
+ root: "data/resisc45"
diff --git a/experiments/torchgeo/conf/so2sat.yaml b/experiments/torchgeo/conf/so2sat.yaml
new file mode 100644
index 00000000000..1b9e7144263
--- /dev/null
+++ b/experiments/torchgeo/conf/so2sat.yaml
@@ -0,0 +1,21 @@
+trainer:
+ min_epochs: 15
+ max_epochs: 40
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ lr: 1e-3
+ patience: 6
+ weights: null
+ in_channels: 18
+ num_classes: 17
+data:
+ class_path: So2SatDataModule
+ init_args:
+ batch_size: 128
+ num_workers: 4
+ band_set: "all"
+ dict_kwargs:
+ root: "data/so2sat"
diff --git a/experiments/torchgeo/conf/ucmerced.yaml b/experiments/torchgeo/conf/ucmerced.yaml
new file mode 100644
index 00000000000..2a4d8786422
--- /dev/null
+++ b/experiments/torchgeo/conf/ucmerced.yaml
@@ -0,0 +1,20 @@
+trainer:
+ min_epochs: 15
+ max_epochs: 40
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ weights: null
+ lr: 1e-3
+ patience: 6
+ in_channels: 3
+ num_classes: 21
+data:
+ class_path: UCMercedDataModule
+ init_args:
+ batch_size: 128
+ num_workers: 4
+ dict_kwargs:
+ root: "data/ucmerced"
diff --git a/experiments/torchgeo/plot_dataloader_benchmark.py b/experiments/torchgeo/plot_dataloader_benchmark.py
index c0cd22185e1..4c313ca174f 100755
--- a/experiments/torchgeo/plot_dataloader_benchmark.py
+++ b/experiments/torchgeo/plot_dataloader_benchmark.py
@@ -32,7 +32,7 @@
ax.set_xscale("log")
ax.set_xticks([16, 32, 64, 128, 256])
-ax.set_xticklabels([16, 32, 64, 128, 256], fontsize=12)
+ax.set_xticklabels(["16", "32", "64", "128", "256"], fontsize=12)
ax.set_xlabel("batch size", fontsize=12)
ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
ax.legend(loc="center right", fontsize="large")
diff --git a/experiments/torchgeo/plot_percentage_benchmark.py b/experiments/torchgeo/plot_percentage_benchmark.py
index 57e59e3621a..e0f2aa3b0e8 100755
--- a/experiments/torchgeo/plot_percentage_benchmark.py
+++ b/experiments/torchgeo/plot_percentage_benchmark.py
@@ -49,7 +49,7 @@
ax.set_xscale("log")
ax.set_xticks([16, 32, 64, 128, 256])
-ax.set_xticklabels([16, 32, 64, 128, 256])
+ax.set_xticklabels(["16", "32", "64", "128", "256"])
ax.set_xlabel("batch size")
ax.set_ylabel("% sampling rate (patches/sec)")
ax.legend()
diff --git a/pyproject.toml b/pyproject.toml
index 0d3e23a20ae..6c5b892b468 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,12 +45,14 @@ dependencies = [
"kornia>=0.6.9",
# lightly 1.4.4+ required for MoCo v3 support
"lightly>=1.4.4",
- # lightning 1.8+ is first release
- "lightning>=1.8",
+ # lightning 2+ required for LightningCLI args + sys.argv support
+ "lightning[pytorch-extra]>=2",
# matplotlib 3.3.3+ required for Python 3.9 wheels
"matplotlib>=3.3.3",
# numpy 1.19.3+ required by Python 3.9 wheels
"numpy>=1.19.3",
+ # pandas 1.1.3+ required for Python 3.9 wheels
+ "pandas>=1.1.3",
# pillow 8+ required for Python 3.9 wheels
"pillow>=8",
# pyproj 3+ required for Python 3.9 wheels
@@ -82,12 +84,10 @@ datasets = [
"laspy>=2",
# opencv-python 4.4.0.46+ required for Python 3.9 wheels
"opencv-python>=4.4.0.46",
- # pandas 1.1.3+ required for Python 3.9 wheels
- "pandas>=1.1.3",
# pycocotools 2.0.5+ required for cython 3+ support
"pycocotools>=2.0.5",
- # pyvista 0.29+ required for to avoid segfault during testing
- "pyvista>=0.29",
+ # pyvista 0.34.2+ required to avoid ImportError in CI
+ "pyvista>=0.34.2",
# radiant-mlhub 0.3+ required for newer tqdm support required by lightning
"radiant-mlhub>=0.3",
# rarfile 4+ required for wheels
@@ -127,25 +127,22 @@ style = [
"pyupgrade>=2.8",
]
tests = [
- # hydra-core 1+ required for omegaconf 2 support
- "hydra-core>=1",
# mypy 0.900+ required for pyproject.toml support
"mypy>=0.900",
# nbmake 1.3.3+ required for variable mocking
"nbmake>=1.3.3",
- # omegaconf 2+ required by lightning, 2.0.1+ required by hydra-core
- "omegaconf>=2.0.1",
- # pytest 6.2+ required for pytest.MonkeyPatch
- "pytest>=6.2",
- # pytest-cov 2.4+ required for pytest --cov flags
- "pytest-cov>=2.4",
- # tensorboard 2.9.1+ required by lightning
- "tensorboard>=2.9.1",
+ # pytest 7.3+ required for tmp_path_retention_policy
+ "pytest>=7.3",
+ # pytest-cov 4+ required for pytest 7.2+ compatibility
+ "pytest-cov>=4",
]
all = [
"torchgeo[datasets,docs,style,tests]",
]
+[project.scripts]
+torchgeo = "torchgeo.main:main"
+
[project.urls]
Homepage = "https://github.com/microsoft/torchgeo"
Documentation = "https://torchgeo.readthedocs.io"
@@ -215,6 +212,9 @@ filterwarnings = [
# https://github.com/pytorch/pytorch/pull/69823
"ignore:distutils Version classes are deprecated. Use packaging.version instead:DeprecationWarning",
"ignore:The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:torch.utils.tensorboard",
+ # https://github.com/Lightning-AI/torchmetrics/issues/2121
+ # https://github.com/Lightning-AI/torchmetrics/pull/2137
+ "ignore:The distutils package is deprecated and slated for removal in Python 3.12:DeprecationWarning:torchmetrics.utilities.imports",
# https://github.com/Lightning-AI/lightning/issues/13256
# https://github.com/Lightning-AI/lightning/pull/13261
"ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning:torch.distributed._sharded_tensor",
@@ -245,15 +245,22 @@ filterwarnings = [
"ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning",
# https://github.com/pydata/xarray/issues/7259
"ignore: numpy.ndarray size changed, may indicate binary incompatibility. Expected 16 from C header, got 96 from PyObject",
+ "ignore:pkg_resources is deprecated as an API.:DeprecationWarning:lightning_utilities.core.imports",
+ "ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning:jsonargparse",
+ # https://github.com/pytorch/pytorch/issues/110549
+ "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning:einops",
# Expected warnings
# Lightning warns us about using num_workers=0, but it's faster on macOS
- "ignore:The dataloader, .*, does not have many workers which may be a bottleneck:UserWarning",
+ "ignore:The .*dataloader.* does not have many workers which may be a bottleneck:UserWarning:lightning",
+ "ignore:The .*dataloader.* does not have many workers which may be a bottleneck:lightning.fabric.utilities.warnings.PossibleUserWarning:lightning",
# Lightning warns us about using the CPU when GPU/MPS is available
"ignore:GPU available but not used.:UserWarning",
"ignore:MPS available but not used.:UserWarning",
# Lightning warns us if TensorBoard is not installed
"ignore:Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package:UserWarning",
+ # https://github.com/Lightning-AI/lightning/issues/18545
+ "ignore:LightningCLI's args parameter is intended to run from within Python like if it were from the command line.:UserWarning",
# https://github.com/kornia/kornia/pull/1611
"ignore:`ColorJitter` is now following Torchvision implementation.:DeprecationWarning:kornia.augmentation._2d.intensity.color_jitter",
# https://github.com/kornia/kornia/pull/1663
@@ -262,6 +269,8 @@ filterwarnings = [
# Unexpected warnings, worth investigating
# Lightning is having trouble inferring the batch size for ChesapeakeCVPRDataModule and CycloneDataModule for some reason
"ignore:Trying to infer the `batch_size` from an ambiguous collection:UserWarning",
+ # https://github.com/pytest-dev/pytest/issues/11461
+ "ignore::pytest.PytestUnraisableExceptionWarning",
]
markers = [
"slow: marks tests as slow",
@@ -275,6 +284,7 @@ testpaths = [
"tests",
"docs/tutorials",
]
+tmp_path_retention_policy = "failed"
# https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html
[tool.setuptools.dynamic]
diff --git a/requirements/datasets.txt b/requirements/datasets.txt
index fcfdfe55719..35c9a8f5158 100644
--- a/requirements/datasets.txt
+++ b/requirements/datasets.txt
@@ -1,16 +1,15 @@
# datasets
-h5py==3.9.0
+h5py==3.10.0
laspy==2.5.1
-opencv-python==4.8.0.76
-pandas==2.1.0
+opencv-python==4.8.1.78
pycocotools==2.0.7
-pyvista==0.42.1
+pyvista==0.42.3
radiant-mlhub==0.4.1
-rarfile==4.0
-scikit-image==0.21.0
+rarfile==4.1
+scikit-image==0.22.0
xarray==2023.7.0
rioxarray==0.14.1
xarray
netCDF4
-scipy==1.11.2
+scipy==1.11.3
zipfile-deflate64==0.2.0
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 61344db5686..898ddd75ed4 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,4 +1,4 @@
# docs
-ipywidgets==8.1.0
+ipywidgets==8.1.1
nbsphinx==0.9.3
sphinx==5.3.0
diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old
index 05422bf65e3..b7e79a4801a 100644
--- a/requirements/min-reqs.old
+++ b/requirements/min-reqs.old
@@ -6,9 +6,10 @@ einops==0.3.0
fiona==1.8.19
kornia==0.6.9
lightly==1.4.4
-lightning==1.8.0
+lightning[pytorch-extra]==2.0.0
matplotlib==3.3.3
numpy==1.19.3
+pandas==1.1.3
pillow==8.0.0
pyproj==3.0.0
rasterio==1.2.0
@@ -24,9 +25,8 @@ torchvision==0.13.0
h5py==3.0.0
laspy==2.0.0
opencv-python==4.4.0.46
-pandas==1.1.3
pycocotools==2.0.5
-pyvista==0.29.0
+pyvista==0.34.2
radiant-mlhub==0.3.0
rarfile==4.0
scikit-image==0.18.0
@@ -47,14 +47,7 @@ pydocstyle[toml]==6.1.0
pyupgrade==2.8.0
# tests
-hydra-core==1.0.0
mypy==0.900
nbmake==1.3.3
-omegaconf==2.0.1
-pytest==6.2.0
-pytest-cov==2.4.0
-tensorboard==2.9.1
-
-# Required dependency of lightning, wasn't properly listed until 1.9
-# https://github.com/Lightning-AI/lightning/pull/16302
-websockets
+pytest==7.3.0
+pytest-cov==4.0.0
diff --git a/requirements/required.txt b/requirements/required.txt
index cccd3c80fe9..ee5212a35ce 100644
--- a/requirements/required.txt
+++ b/requirements/required.txt
@@ -2,20 +2,21 @@
setuptools==68.2.0
# install
-einops==0.6.1
-fiona==1.9.4.post1
+einops==0.7.0
+fiona==1.9.5
kornia==0.7.0
-lightly==1.4.17
-lightning==2.0.8
-matplotlib==3.7.2
-numpy==1.25.2
-pillow==10.0.0
-pyproj==3.6.0
-rasterio==1.3.8
-rtree==1.0.1
+lightly==1.4.21
+lightning[pytorch-extra]==2.1.1
+matplotlib==3.8.1
+numpy==1.26.1
+pandas==2.1.3
+pillow==10.1.0
+pyproj==3.6.1
+rasterio==1.3.9
+rtree==1.1.0
segmentation-models-pytorch==0.3.3
-shapely==2.0.1
+shapely==2.0.2
timm==0.9.2
-torch==2.0.1
-torchmetrics==1.1.1
-torchvision==0.15.2
+torch==2.1.0
+torchmetrics==1.2.0
+torchvision==0.16.0
diff --git a/requirements/style.txt b/requirements/style.txt
index 98c251e8d33..3f507285103 100644
--- a/requirements/style.txt
+++ b/requirements/style.txt
@@ -1,6 +1,6 @@
# style
-black[jupyter]==23.7.0
+black[jupyter]==23.11.0
flake8==6.1.0
isort[colors]==5.12.0
pydocstyle[toml]==6.3.0
-pyupgrade==3.10.1
+pyupgrade==3.15.0
diff --git a/requirements/tests.txt b/requirements/tests.txt
index b0966be2969..1cf4888beec 100644
--- a/requirements/tests.txt
+++ b/requirements/tests.txt
@@ -1,8 +1,5 @@
# tests
-hydra-core==1.3.2
-mypy==1.5.1
-nbmake==1.4.3
-omegaconf==2.3.0
-pytest==7.4.2
+mypy==1.6.1
+nbmake==1.4.6
+pytest==7.4.3
pytest-cov==4.1.0
-tensorboard==2.14.0
diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml
index 3babdc7fd8b..2eba9c471d2 100644
--- a/tests/conf/bigearthnet_all.yaml
+++ b/tests/conf/bigearthnet_all.yaml
@@ -1,18 +1,16 @@
-module:
- _target_: torchgeo.trainers.MultiLabelClassificationTask
- loss: "bce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 14
- num_classes: 19
-
-datamodule:
- _target_: torchgeo.datamodules.BigEarthNetDataModule
- root: "tests/data/bigearthnet"
- bands: "all"
- num_classes: ${module.num_classes}
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: MultiLabelClassificationTask
+ init_args:
+ loss: "bce"
+ model: "resnet18"
+ in_channels: 14
+ num_classes: 19
+data:
+ class_path: BigEarthNetDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/bigearthnet"
+ bands: "all"
+ num_classes: 19
+ download: true
diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml
index 8c07950cb5f..b93d54ce1eb 100644
--- a/tests/conf/bigearthnet_s1.yaml
+++ b/tests/conf/bigearthnet_s1.yaml
@@ -1,18 +1,16 @@
-module:
- _target_: torchgeo.trainers.MultiLabelClassificationTask
- loss: "bce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 2
- num_classes: 19
-
-datamodule:
- _target_: torchgeo.datamodules.BigEarthNetDataModule
- root: "tests/data/bigearthnet"
- bands: "s1"
- num_classes: ${module.num_classes}
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: MultiLabelClassificationTask
+ init_args:
+ loss: "bce"
+ model: "resnet18"
+ in_channels: 2
+ num_classes: 19
+data:
+ class_path: BigEarthNetDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/bigearthnet"
+ bands: "s1"
+ num_classes: 19
+ download: true
diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml
index 9408e20b633..d00085a4879 100644
--- a/tests/conf/bigearthnet_s2.yaml
+++ b/tests/conf/bigearthnet_s2.yaml
@@ -1,18 +1,16 @@
-module:
- _target_: torchgeo.trainers.MultiLabelClassificationTask
- loss: "bce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 12
- num_classes: 19
-
-datamodule:
- _target_: torchgeo.datamodules.BigEarthNetDataModule
- root: "tests/data/bigearthnet"
- bands: "s2"
- num_classes: ${module.num_classes}
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: MultiLabelClassificationTask
+ init_args:
+ loss: "bce"
+ model: "resnet18"
+ in_channels: 12
+ num_classes: 19
+data:
+ class_path: BigEarthNetDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/bigearthnet"
+ bands: "s2"
+ num_classes: 19
+ download: true
diff --git a/tests/conf/chesapeake_cvpr_5.yaml b/tests/conf/chesapeake_cvpr_5.yaml
index a3f8e08b48d..4bab68f756e 100644
--- a/tests/conf/chesapeake_cvpr_5.yaml
+++ b/tests/conf/chesapeake_cvpr_5.yaml
@@ -1,28 +1,26 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet50"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 4
- num_classes: 5
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
- root: "tests/data/chesapeake/cvpr"
- download: true
- train_splits:
- - "de-test"
- val_splits:
- - "de-test"
- test_splits:
- - "de-test"
- batch_size: 2
- patch_size: 64
- num_workers: 0
- class_set: ${module.num_classes}
- use_prior_labels: False
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet50"
+ in_channels: 4
+ num_classes: 5
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: ChesapeakeCVPRDataModule
+ init_args:
+ train_splits:
+ - "de-test"
+ val_splits:
+ - "de-test"
+ test_splits:
+ - "de-test"
+ batch_size: 2
+ patch_size: 64
+ class_set: 5
+ use_prior_labels: False
+ dict_kwargs:
+ root: "tests/data/chesapeake/cvpr"
+ download: true
diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml
index 5b1f0669423..d4c11c86864 100644
--- a/tests/conf/chesapeake_cvpr_7.yaml
+++ b/tests/conf/chesapeake_cvpr_7.yaml
@@ -1,28 +1,26 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 4
- num_classes: 7
- num_filters: 1
- ignore_index: null
- weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
- root: "tests/data/chesapeake/cvpr"
- download: true
- train_splits:
- - "de-test"
- val_splits:
- - "de-test"
- test_splits:
- - "de-test"
- batch_size: 2
- patch_size: 64
- num_workers: 0
- class_set: ${module.num_classes}
- use_prior_labels: False
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 4
+ num_classes: 7
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: ChesapeakeCVPRDataModule
+ init_args:
+ train_splits:
+ - "de-test"
+ val_splits:
+ - "de-test"
+ test_splits:
+ - "de-test"
+ batch_size: 2
+ patch_size: 64
+ class_set: 7
+ use_prior_labels: False
+ dict_kwargs:
+ root: "tests/data/chesapeake/cvpr"
+ download: true
diff --git a/tests/conf/chesapeake_cvpr_prior_byol.yaml b/tests/conf/chesapeake_cvpr_prior_byol.yaml
index 3ccf939feff..1819c87a074 100644
--- a/tests/conf/chesapeake_cvpr_prior_byol.yaml
+++ b/tests/conf/chesapeake_cvpr_prior_byol.yaml
@@ -1,23 +1,21 @@
-module:
- _target_: torchgeo.trainers.BYOLTask
- in_channels: 4
- backbone: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
- root: "tests/data/chesapeake/cvpr"
- download: true
- train_splits:
- - "de-test"
- val_splits:
- - "de-test"
- test_splits:
- - "de-test"
- batch_size: 2
- patch_size: 64
- num_workers: 0
- class_set: 5
- use_prior_labels: True
+model:
+ class_path: BYOLTask
+ init_args:
+ in_channels: 4
+ model: "resnet18"
+data:
+ class_path: ChesapeakeCVPRDataModule
+ init_args:
+ train_splits:
+ - "de-test"
+ val_splits:
+ - "de-test"
+ test_splits:
+ - "de-test"
+ batch_size: 2
+ patch_size: 64
+ class_set: 5
+ use_prior_labels: True
+ dict_kwargs:
+ root: "tests/data/chesapeake/cvpr"
+ download: true
diff --git a/tests/conf/chesapeake_cvpr_prior_moco.yaml b/tests/conf/chesapeake_cvpr_prior_moco.yaml
index aa3728f1b02..87179a28e69 100644
--- a/tests/conf/chesapeake_cvpr_prior_moco.yaml
+++ b/tests/conf/chesapeake_cvpr_prior_moco.yaml
@@ -1,20 +1,21 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: "resnet18"
- in_channels: 4
-
-datamodule:
- _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
- root: "tests/data/chesapeake/cvpr"
- download: false
- train_splits:
- - "de-test"
- val_splits:
- - "de-test"
- test_splits:
- - "de-test"
- batch_size: 2
- patch_size: 64
- num_workers: 0
- class_set: 5
- use_prior_labels: True
+model:
+ class_path: MoCoTask
+ init_args:
+ model: "resnet18"
+ in_channels: 4
+data:
+ class_path: ChesapeakeCVPRDataModule
+ init_args:
+ train_splits:
+ - "de-test"
+ val_splits:
+ - "de-test"
+ test_splits:
+ - "de-test"
+ batch_size: 2
+ patch_size: 64
+ class_set: 5
+ use_prior_labels: True
+ dict_kwargs:
+ root: "tests/data/chesapeake/cvpr"
+ download: false
diff --git a/tests/conf/chesapeake_cvpr_prior_simclr.yaml b/tests/conf/chesapeake_cvpr_prior_simclr.yaml
index 731e9bf8bf2..96acc852815 100644
--- a/tests/conf/chesapeake_cvpr_prior_simclr.yaml
+++ b/tests/conf/chesapeake_cvpr_prior_simclr.yaml
@@ -1,23 +1,24 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: "resnet18"
- in_channels: 4
- version: 1
- layers: 2
- memory_bank_size: 0
-
-datamodule:
- _target_: torchgeo.datamodules.ChesapeakeCVPRDataModule
- root: "tests/data/chesapeake/cvpr"
- download: false
- train_splits:
- - "de-test"
- val_splits:
- - "de-test"
- test_splits:
- - "de-test"
- batch_size: 2
- patch_size: 64
- num_workers: 0
- class_set: 5
- use_prior_labels: True
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: "resnet18"
+ in_channels: 4
+ version: 1
+ layers: 2
+ memory_bank_size: 0
+data:
+ class_path: ChesapeakeCVPRDataModule
+ init_args:
+ train_splits:
+ - "de-test"
+ val_splits:
+ - "de-test"
+ test_splits:
+ - "de-test"
+ batch_size: 2
+ patch_size: 64
+ class_set: 5
+ use_prior_labels: True
+ dict_kwargs:
+ root: "tests/data/chesapeake/cvpr"
+ download: false
diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml
index f67b1b6a1be..b247b20cdd9 100644
--- a/tests/conf/cowc_counting.yaml
+++ b/tests/conf/cowc_counting.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.RegressionTask
- model: resnet18
- weights: null
- num_outputs: 1
- in_channels: 3
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- loss: "mse"
-
-datamodule:
- _target_: torchgeo.datamodules.COWCCountingDataModule
- root: "tests/data/cowc_counting"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: RegressionTask
+ init_args:
+ model: resnet18
+ num_outputs: 1
+ in_channels: 3
+ loss: "mse"
+data:
+ class_path: COWCCountingDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/cowc_counting"
+ download: true
diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml
index f7ecff850ba..a0c435e9549 100644
--- a/tests/conf/cyclone.yaml
+++ b/tests/conf/cyclone.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.RegressionTask
- model: "resnet18"
- weights: null
- num_outputs: 1
- in_channels: 3
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- loss: "mse"
-
-datamodule:
- _target_: torchgeo.datamodules.TropicalCycloneDataModule
- root: "tests/data/cyclone"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: RegressionTask
+ init_args:
+ model: "resnet18"
+ num_outputs: 1
+ in_channels: 3
+ loss: "mse"
+data:
+ class_path: TropicalCycloneDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/cyclone"
+ download: true
diff --git a/tests/conf/deepglobelandcover.yaml b/tests/conf/deepglobelandcover.yaml
index 392fe3ce7b7..08a29843fdc 100644
--- a/tests/conf/deepglobelandcover.yaml
+++ b/tests/conf/deepglobelandcover.yaml
@@ -1,21 +1,18 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 7
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.DeepGlobeLandCoverDataModule
- root: "tests/data/deepglobelandcover"
- batch_size: 1
- patch_size: 2
- val_split_pct: 0.5
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 3
+ num_classes: 7
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: DeepGlobeLandCoverDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 2
+ val_split_pct: 0.5
+ dict_kwargs:
+ root: "tests/data/deepglobelandcover"
diff --git a/tests/conf/etci2021.yaml b/tests/conf/etci2021.yaml
index 9af839e92e3..bdd08948433 100644
--- a/tests/conf/etci2021.yaml
+++ b/tests/conf/etci2021.yaml
@@ -1,18 +1,16 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 6
- num_classes: 2
- ignore_index: 0
-
-datamodule:
- _target_: torchgeo.datamodules.ETCI2021DataModule
- root: "tests/data/etci2021"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 6
+ num_classes: 2
+ ignore_index: 0
+data:
+ class_path: ETCI2021DataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/etci2021"
+ download: true
diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml
index 7066f7f66ce..365b46aa776 100644
--- a/tests/conf/eurosat.yaml
+++ b/tests/conf/eurosat.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 13
- num_classes: 2
-
-datamodule:
- _target_: torchgeo.datamodules.EuroSATDataModule
- root: "tests/data/eurosat"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ in_channels: 13
+ num_classes: 2
+data:
+ class_path: EuroSATDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/eurosat"
+ download: true
diff --git a/tests/conf/eurosat100.yaml b/tests/conf/eurosat100.yaml
index 65e4be957f2..0981e380548 100644
--- a/tests/conf/eurosat100.yaml
+++ b/tests/conf/eurosat100.yaml
@@ -1,16 +1,17 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 13
- num_classes: 2
-
-datamodule:
- _target_: torchgeo.datamodules.EuroSAT100DataModule
- root: "tests/data/eurosat"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ lr: 1e-3
+ patience: 6
+ weights: null
+ in_channels: 13
+ num_classes: 2
+data:
+ class_path: EuroSAT100DataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/eurosat"
+ download: true
diff --git a/tests/conf/fire_risk.yaml b/tests/conf/fire_risk.yaml
index 8971ee6839a..b4ff3467c04 100644
--- a/tests/conf/fire_risk.yaml
+++ b/tests/conf/fire_risk.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 3
- num_classes: 5
-
-datamodule:
- _target_: torchgeo.datamodules.FireRiskDataModule
- root: "tests/data/fire_risk"
- download: false
- batch_size: 2
- num_workers: 0
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ in_channels: 3
+ num_classes: 5
+data:
+ class_path: FireRiskDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/fire_risk"
+ download: false
diff --git a/tests/conf/gid15.yaml b/tests/conf/gid15.yaml
index c9af542d037..057a56696b2 100644
--- a/tests/conf/gid15.yaml
+++ b/tests/conf/gid15.yaml
@@ -1,22 +1,19 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 16
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.GID15DataModule
- root: "tests/data/gid15"
- download: true
- batch_size: 1
- patch_size: 2
- val_split_pct: 0.5
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 3
+ num_classes: 16
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: GID15DataModule
+ init_args:
+ batch_size: 1
+ patch_size: 2
+ val_split_pct: 0.5
+ dict_kwargs:
+ root: "tests/data/gid15"
+ download: true
diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml
index df4f4043fc4..4fbd3ded072 100644
--- a/tests/conf/inria.yaml
+++ b/tests/conf/inria.yaml
@@ -1,20 +1,16 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 3
- num_classes: 2
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.InriaAerialImageLabelingDataModule
- root: "tests/data/inria"
- batch_size: 1
- patch_size: 2
- num_workers: 0
- val_split_pct: 0.2
- test_split_pct: 0.2
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 3
+ num_classes: 2
+ ignore_index: null
+data:
+ class_path: InriaAerialImageLabelingDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 2
+ dict_kwargs:
+ root: "tests/data/inria"
diff --git a/tests/conf/inria_deeplab.yaml b/tests/conf/inria_deeplab.yaml
new file mode 100644
index 00000000000..e16ba15abe3
--- /dev/null
+++ b/tests/conf/inria_deeplab.yaml
@@ -0,0 +1,14 @@
+model:
+ class_path: PixelwiseRegressionTask
+ init_args:
+ model: "deeplabv3+"
+ backbone: "resnet18"
+ in_channels: 3
+ loss: "mae"
+data:
+ class_path: InriaAerialImageLabelingDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 2
+ dict_kwargs:
+ root: "tests/data/inria"
diff --git a/tests/conf/inria_fcn.yaml b/tests/conf/inria_fcn.yaml
new file mode 100644
index 00000000000..692db059dbf
--- /dev/null
+++ b/tests/conf/inria_fcn.yaml
@@ -0,0 +1,14 @@
+model:
+ class_path: PixelwiseRegressionTask
+ init_args:
+ model: "fcn"
+ backbone: "resnet18"
+ in_channels: 3
+ loss: "mae"
+data:
+ class_path: InriaAerialImageLabelingDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 2
+ dict_kwargs:
+ root: "tests/data/inria"
diff --git a/tests/conf/inria_unet.yaml b/tests/conf/inria_unet.yaml
new file mode 100644
index 00000000000..ded50ffe79c
--- /dev/null
+++ b/tests/conf/inria_unet.yaml
@@ -0,0 +1,14 @@
+model:
+ class_path: PixelwiseRegressionTask
+ init_args:
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 3
+ loss: "mae"
+data:
+ class_path: InriaAerialImageLabelingDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 2
+ dict_kwargs:
+ root: "tests/data/inria"
diff --git a/tests/conf/l7irish.yaml b/tests/conf/l7irish.yaml
index d5147b0032d..fc67fb8e1cc 100644
--- a/tests/conf/l7irish.yaml
+++ b/tests/conf/l7irish.yaml
@@ -1,22 +1,19 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 9
- num_classes: 5
- num_filters: 1
- ignore_index: 0
-
-datamodule:
- _target_: torchgeo.datamodules.L7IrishDataModule
- root: "tests/data/l7irish"
- download: true
- batch_size: 1
- patch_size: 32
- length: 5
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 9
+ num_classes: 5
+ num_filters: 1
+ ignore_index: 0
+data:
+ class_path: L7IrishDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 32
+ length: 5
+ dict_kwargs:
+ paths: "tests/data/l7irish"
+ download: true
diff --git a/tests/conf/l8biome.yaml b/tests/conf/l8biome.yaml
index ae42f6efff3..f33b4b36464 100644
--- a/tests/conf/l8biome.yaml
+++ b/tests/conf/l8biome.yaml
@@ -1,22 +1,19 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 11
- num_classes: 5
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.L8BiomeDataModule
- root: "tests/data/l8biome"
- download: true
- batch_size: 1
- patch_size: 32
- length: 5
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 11
+ num_classes: 5
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: L8BiomeDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 32
+ length: 5
+ dict_kwargs:
+ paths: "tests/data/l8biome"
+ download: true
diff --git a/tests/conf/landcoverai.yaml b/tests/conf/landcoverai.yaml
index 691d19bb9be..90978ef0141 100644
--- a/tests/conf/landcoverai.yaml
+++ b/tests/conf/landcoverai.yaml
@@ -1,20 +1,17 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 6
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.LandCoverAIDataModule
- root: "tests/data/landcoverai"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 3
+ num_classes: 6
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: LandCoverAIDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/landcoverai"
+ download: true
diff --git a/tests/conf/loveda.yaml b/tests/conf/loveda.yaml
index 7a558ea2207..44745a6d929 100644
--- a/tests/conf/loveda.yaml
+++ b/tests/conf/loveda.yaml
@@ -1,20 +1,17 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 8
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.LoveDADataModule
- root: "tests/data/loveda"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 3
+ num_classes: 8
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: LoveDADataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/loveda"
+ download: true
diff --git a/tests/conf/naipchesapeake.yaml b/tests/conf/naipchesapeake.yaml
index f9c0e4880fa..4b13865f1bd 100644
--- a/tests/conf/naipchesapeake.yaml
+++ b/tests/conf/naipchesapeake.yaml
@@ -1,21 +1,19 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "deeplabv3+"
- backbone: "resnet34"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- in_channels: 4
- num_classes: 14
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.NAIPChesapeakeDataModule
- naip_root: "tests/data/naip"
- chesapeake_root: "tests/data/chesapeake/BAYWIDE"
- chesapeake_download: true
- batch_size: 2
- num_workers: 0
- patch_size: 32
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "deeplabv3+"
+ backbone: "resnet34"
+ in_channels: 4
+ num_classes: 14
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: NAIPChesapeakeDataModule
+ init_args:
+ batch_size: 2
+ patch_size: 32
+ dict_kwargs:
+ naip_paths: "tests/data/naip"
+ chesapeake_paths: "tests/data/chesapeake/BAYWIDE"
+ chesapeake_download: true
diff --git a/tests/conf/nasa_marine_debris.yaml b/tests/conf/nasa_marine_debris.yaml
index 7103560c5f3..a0f30127414 100644
--- a/tests/conf/nasa_marine_debris.yaml
+++ b/tests/conf/nasa_marine_debris.yaml
@@ -1,15 +1,13 @@
-module:
- _target_: torchgeo.trainers.ObjectDetectionTask
- model: "faster-rcnn"
- backbone: "resnet18"
- num_classes: 2
- learning_rate: 1.2e-4
- learning_rate_schedule_patience: 6
- verbose: false
-
-datamodule:
- _target_: torchgeo.datamodules.NASAMarineDebrisDataModule
- root: "tests/data/nasa_marine_debris"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: ObjectDetectionTask
+ init_args:
+ model: "faster-rcnn"
+ backbone: "resnet18"
+ num_classes: 2
+data:
+ class_path: NASAMarineDebrisDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/nasa_marine_debris"
+ download: true
diff --git a/tests/conf/potsdam2d.yaml b/tests/conf/potsdam2d.yaml
index bd5f8f6c0ca..362ec81815d 100644
--- a/tests/conf/potsdam2d.yaml
+++ b/tests/conf/potsdam2d.yaml
@@ -1,21 +1,18 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 4
- num_classes: 6
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.Potsdam2DDataModule
- root: "tests/data/potsdam"
- batch_size: 1
- patch_size: 2
- val_split_pct: 0.5
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 4
+ num_classes: 6
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: Potsdam2DDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 2
+ val_split_pct: 0.5
+ dict_kwargs:
+ root: "tests/data/potsdam"
diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml
index f8d1729572e..86deb432f65 100644
--- a/tests/conf/resisc45.yaml
+++ b/tests/conf/resisc45.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 3
- num_classes: 3
-
-datamodule:
- _target_: torchgeo.datamodules.RESISC45DataModule
- root: "tests/data/resisc45"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ in_channels: 3
+ num_classes: 3
+data:
+ class_path: RESISC45DataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/resisc45"
+ download: true
diff --git a/tests/conf/seco_byol_1.yaml b/tests/conf/seco_byol_1.yaml
index 5f7e0b91b20..9d2680fdec1 100644
--- a/tests/conf/seco_byol_1.yaml
+++ b/tests/conf/seco_byol_1.yaml
@@ -1,14 +1,17 @@
-module:
- _target_: torchgeo.trainers.BYOLTask
- in_channels: 3
- backbone: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
- root: "tests/data/seco"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: BYOLTask
+ init_args:
+ in_channels: 3
+ model: "resnet18"
+data:
+ class_path: SeasonalContrastS2DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/seco"
+ seasons: 1
+ # https://github.com/Lightning-AI/lightning/issues/18616
+ bands:
+ - "B4"
+ - "B3"
+ - "B2"
diff --git a/tests/conf/seco_byol_2.yaml b/tests/conf/seco_byol_2.yaml
index 07ff81c0132..f3b51c00272 100644
--- a/tests/conf/seco_byol_2.yaml
+++ b/tests/conf/seco_byol_2.yaml
@@ -1,14 +1,17 @@
-module:
- _target_: torchgeo.trainers.BYOLTask
- in_channels: 3
- backbone: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
- root: "tests/data/seco"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: BYOLTask
+ init_args:
+ in_channels: 3
+ model: "resnet18"
+data:
+ class_path: SeasonalContrastS2DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/seco"
+ seasons: 2
+ # https://github.com/Lightning-AI/lightning/issues/18616
+ bands:
+ - "B4"
+ - "B3"
+ - "B2"
diff --git a/tests/conf/seco_moco_1.yaml b/tests/conf/seco_moco_1.yaml
index 21825535bca..164f13a7a23 100644
--- a/tests/conf/seco_moco_1.yaml
+++ b/tests/conf/seco_moco_1.yaml
@@ -1,16 +1,22 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: "resnet18"
- in_channels: 3
- version: 1
- weight_decay: 1e-4
- temperature: 0.07
- memory_bank_size: 10
- moco_momentum: 0.999
-
-datamodule:
- _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
- root: "tests/data/seco"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: MoCoTask
+ init_args:
+ model: "resnet18"
+ in_channels: 3
+ version: 1
+ weight_decay: 1e-4
+ temperature: 0.07
+ memory_bank_size: 10
+ moco_momentum: 0.999
+data:
+ class_path: SeasonalContrastS2DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/seco"
+ seasons: 1
+ # https://github.com/Lightning-AI/lightning/issues/18616
+ bands:
+ - "B4"
+ - "B3"
+ - "B2"
diff --git a/tests/conf/seco_moco_2.yaml b/tests/conf/seco_moco_2.yaml
index b4e1a5dcfd1..fee827289f2 100644
--- a/tests/conf/seco_moco_2.yaml
+++ b/tests/conf/seco_moco_2.yaml
@@ -1,19 +1,25 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: "resnet18"
- in_channels: 3
- version: 2
- layers: 2
- hidden_dim: 10
- output_dim: 5
- weight_decay: 1e-4
- temperature: 0.07
- memory_bank_size: 10
- moco_momentum: 0.999
-
-datamodule:
- _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
- root: "tests/data/seco"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: MoCoTask
+ init_args:
+ model: "resnet18"
+ in_channels: 3
+ version: 2
+ layers: 2
+ hidden_dim: 10
+ output_dim: 5
+ weight_decay: 1e-4
+ temperature: 0.07
+ memory_bank_size: 10
+ moco_momentum: 0.999
+data:
+ class_path: SeasonalContrastS2DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/seco"
+ seasons: 2
+ # https://github.com/Lightning-AI/lightning/issues/18616
+ bands:
+ - "B4"
+ - "B3"
+ - "B2"
diff --git a/tests/conf/seco_simclr_1.yaml b/tests/conf/seco_simclr_1.yaml
index ec0fa60d002..b23653e0cf0 100644
--- a/tests/conf/seco_simclr_1.yaml
+++ b/tests/conf/seco_simclr_1.yaml
@@ -1,17 +1,23 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: "resnet18"
- in_channels: 3
- version: 1
- layers: 2
- hidden_dim: 8
- output_dim: 8
- weight_decay: 1e-6
- memory_bank_size: 0
-
-datamodule:
- _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
- root: "tests/data/seco"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: "resnet18"
+ in_channels: 3
+ version: 1
+ layers: 2
+ hidden_dim: 8
+ output_dim: 8
+ weight_decay: 1e-6
+ memory_bank_size: 0
+data:
+ class_path: SeasonalContrastS2DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/seco"
+ seasons: 1
+ # https://github.com/Lightning-AI/lightning/issues/18616
+ bands:
+ - "B4"
+ - "B3"
+ - "B2"
diff --git a/tests/conf/seco_simclr_2.yaml b/tests/conf/seco_simclr_2.yaml
index 22e00585c20..1b06c5d5c3c 100644
--- a/tests/conf/seco_simclr_2.yaml
+++ b/tests/conf/seco_simclr_2.yaml
@@ -1,17 +1,23 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: "resnet18"
- in_channels: 3
- version: 2
- layers: 4
- hidden_dim: 8
- output_dim: 8
- weight_decay: 1e-4
- memory_bank_size: 10
-
-datamodule:
- _target_: torchgeo.datamodules.SeasonalContrastS2DataModule
- root: "tests/data/seco"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: "resnet18"
+ in_channels: 3
+ version: 2
+ layers: 4
+ hidden_dim: 8
+ output_dim: 8
+ weight_decay: 1e-4
+ memory_bank_size: 10
+data:
+ class_path: SeasonalContrastS2DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/seco"
+ seasons: 2
+ # https://github.com/Lightning-AI/lightning/issues/18616
+ bands:
+ - "B4"
+ - "B3"
+ - "B2"
diff --git a/tests/conf/sen12ms_all.yaml b/tests/conf/sen12ms_all.yaml
index fe3d592a356..3f83fa55085 100644
--- a/tests/conf/sen12ms_all.yaml
+++ b/tests/conf/sen12ms_all.yaml
@@ -1,18 +1,16 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- in_channels: 15
- num_classes: 11
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.SEN12MSDataModule
- root: "tests/data/sen12ms"
- band_set: "all"
- batch_size: 1
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 15
+ num_classes: 11
+ ignore_index: null
+data:
+ class_path: SEN12MSDataModule
+ init_args:
+ batch_size: 1
+ band_set: "all"
+ dict_kwargs:
+ root: "tests/data/sen12ms"
diff --git a/tests/conf/sen12ms_s1.yaml b/tests/conf/sen12ms_s1.yaml
index b0b9d553931..7e536d9e35a 100644
--- a/tests/conf/sen12ms_s1.yaml
+++ b/tests/conf/sen12ms_s1.yaml
@@ -1,19 +1,17 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "focal"
- model: "fcn"
- num_filters: 1
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- in_channels: 2
- num_classes: 11
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.SEN12MSDataModule
- root: "tests/data/sen12ms"
- band_set: "s1"
- batch_size: 1
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "focal"
+ model: "fcn"
+ num_filters: 1
+ backbone: "resnet18"
+ in_channels: 2
+ num_classes: 11
+ ignore_index: null
+data:
+ class_path: SEN12MSDataModule
+ init_args:
+ batch_size: 1
+ band_set: "s1"
+ dict_kwargs:
+ root: "tests/data/sen12ms"
diff --git a/tests/conf/sen12ms_s2_all.yaml b/tests/conf/sen12ms_s2_all.yaml
index e80b74896e0..b98d59d0c7f 100644
--- a/tests/conf/sen12ms_s2_all.yaml
+++ b/tests/conf/sen12ms_s2_all.yaml
@@ -1,18 +1,16 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- in_channels: 13
- num_classes: 11
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.SEN12MSDataModule
- root: "tests/data/sen12ms"
- band_set: "s2-all"
- batch_size: 1
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 13
+ num_classes: 11
+ ignore_index: null
+data:
+ class_path: SEN12MSDataModule
+ init_args:
+ batch_size: 1
+ band_set: "s2-all"
+ dict_kwargs:
+ root: "tests/data/sen12ms"
diff --git a/tests/conf/sen12ms_s2_reduced.yaml b/tests/conf/sen12ms_s2_reduced.yaml
index 15758690e03..770efaa6549 100644
--- a/tests/conf/sen12ms_s2_reduced.yaml
+++ b/tests/conf/sen12ms_s2_reduced.yaml
@@ -1,18 +1,16 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- in_channels: 6
- num_classes: 11
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.SEN12MSDataModule
- root: "tests/data/sen12ms"
- band_set: "s2-reduced"
- batch_size: 1
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 6
+ num_classes: 11
+ ignore_index: null
+data:
+ class_path: SEN12MSDataModule
+ init_args:
+ batch_size: 1
+ band_set: "s2-reduced"
+ dict_kwargs:
+ root: "tests/data/sen12ms"
diff --git a/tests/conf/skippd.yaml b/tests/conf/skippd.yaml
index 14dd1bcaabe..a18a05abfe6 100644
--- a/tests/conf/skippd.yaml
+++ b/tests/conf/skippd.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.RegressionTask
- model: "resnet18"
- weights: null
- num_outputs: 1
- in_channels: 3
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- loss: "mse"
-
-datamodule:
- _target_: torchgeo.datamodules.SKIPPDDataModule
- root: "tests/data/skippd"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: RegressionTask
+ init_args:
+ model: "resnet18"
+ num_outputs: 1
+ in_channels: 3
+ loss: "mse"
+data:
+ class_path: SKIPPDDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/skippd"
+ download: true
diff --git a/tests/conf/so2sat_all.yaml b/tests/conf/so2sat_all.yaml
index 22919afe697..c728c9d7179 100644
--- a/tests/conf/so2sat_all.yaml
+++ b/tests/conf/so2sat_all.yaml
@@ -1,17 +1,15 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 18
- num_classes: 17
-
-datamodule:
- _target_: torchgeo.datamodules.So2SatDataModule
- root: "tests/data/so2sat"
- batch_size: 1
- num_workers: 0
- version: "2"
- band_set: "all"
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ in_channels: 18
+ num_classes: 17
+data:
+ class_path: So2SatDataModule
+ init_args:
+ batch_size: 1
+ band_set: "all"
+ dict_kwargs:
+ root: "tests/data/so2sat"
+ version: "2"
diff --git a/tests/conf/so2sat_rgb.yaml b/tests/conf/so2sat_rgb.yaml
index 75f7490ce22..66e1e223561 100644
--- a/tests/conf/so2sat_rgb.yaml
+++ b/tests/conf/so2sat_rgb.yaml
@@ -1,18 +1,16 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 3
- num_classes: 17
-
-datamodule:
- _target_: torchgeo.datamodules.So2SatDataModule
- root: "tests/data/so2sat"
- batch_size: 1
- num_workers: 0
- version: "3_random"
- band_set: "rgb"
- val_split_pct: 0.5
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ in_channels: 3
+ num_classes: 17
+data:
+ class_path: So2SatDataModule
+ init_args:
+ batch_size: 1
+ band_set: "rgb"
+ val_split_pct: 0.5
+ dict_kwargs:
+ root: "tests/data/so2sat"
+ version: "3_random"
diff --git a/tests/conf/so2sat_s1.yaml b/tests/conf/so2sat_s1.yaml
index c81e79742b8..df7a9cb1ea9 100644
--- a/tests/conf/so2sat_s1.yaml
+++ b/tests/conf/so2sat_s1.yaml
@@ -1,17 +1,15 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "focal"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 8
- num_classes: 17
-
-datamodule:
- _target_: torchgeo.datamodules.So2SatDataModule
- root: "tests/data/so2sat"
- batch_size: 1
- num_workers: 0
- version: "2"
- band_set: "s1"
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "focal"
+ model: "resnet18"
+ in_channels: 8
+ num_classes: 17
+data:
+ class_path: So2SatDataModule
+ init_args:
+ batch_size: 1
+ band_set: "s1"
+ dict_kwargs:
+ root: "tests/data/so2sat"
+ version: "2"
diff --git a/tests/conf/so2sat_s2.yaml b/tests/conf/so2sat_s2.yaml
index d7ba063efac..fb41099e60e 100644
--- a/tests/conf/so2sat_s2.yaml
+++ b/tests/conf/so2sat_s2.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "jaccard"
- model: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
- in_channels: 10
- num_classes: 17
-
-datamodule:
- _target_: torchgeo.datamodules.So2SatDataModule
- root: "tests/data/so2sat"
- batch_size: 1
- num_workers: 0
- band_set: "s2"
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "jaccard"
+ model: "resnet18"
+ in_channels: 10
+ num_classes: 17
+data:
+ class_path: So2SatDataModule
+ init_args:
+ batch_size: 1
+ band_set: "s2"
+ dict_kwargs:
+ root: "tests/data/so2sat"
diff --git a/tests/conf/spacenet1.yaml b/tests/conf/spacenet1.yaml
index dc88c2504d1..0da6cd24c4c 100644
--- a/tests/conf/spacenet1.yaml
+++ b/tests/conf/spacenet1.yaml
@@ -1,22 +1,19 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 3
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.SpaceNet1DataModule
- root: "tests/data/spacenet"
- download: true
- batch_size: 1
- num_workers: 0
- val_split_pct: 0.33
- test_split_pct: 0.33
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 3
+ num_classes: 3
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: SpaceNet1DataModule
+ init_args:
+ batch_size: 1
+ val_split_pct: 0.33
+ test_split_pct: 0.33
+ dict_kwargs:
+ root: "tests/data/spacenet"
+ download: true
diff --git a/tests/conf/ssl4eo_l_benchmark_cdl.yaml b/tests/conf/ssl4eo_l_benchmark_cdl.yaml
index f44abedb3a7..a4a4a7b9203 100644
--- a/tests/conf/ssl4eo_l_benchmark_cdl.yaml
+++ b/tests/conf/ssl4eo_l_benchmark_cdl.yaml
@@ -1,20 +1,18 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 7
- num_classes: 134
- num_filters: 1
- ignore_index: 0
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "tests/data/ssl4eo_benchmark_landsat"
- sensor: "tm_toa"
- product: "cdl"
- batch_size: 2
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 7
+ num_classes: 134
+ num_filters: 1
+ ignore_index: 0
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo_benchmark_landsat"
+ sensor: "tm_toa"
+ product: "cdl"
diff --git a/tests/conf/ssl4eo_l_benchmark_nlcd.yaml b/tests/conf/ssl4eo_l_benchmark_nlcd.yaml
index 6dd85d935b7..89475a091b0 100644
--- a/tests/conf/ssl4eo_l_benchmark_nlcd.yaml
+++ b/tests/conf/ssl4eo_l_benchmark_nlcd.yaml
@@ -1,20 +1,18 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 6
- num_classes: 17
- num_filters: 1
- ignore_index: 0
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLBenchmarkDataModule
- root: "tests/data/ssl4eo_benchmark_landsat"
- sensor: "etm_sr"
- product: "nlcd"
- batch_size: 2
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 6
+ num_classes: 17
+ num_filters: 1
+ ignore_index: 0
+data:
+ class_path: SSL4EOLBenchmarkDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo_benchmark_landsat"
+ sensor: "etm_sr"
+ product: "nlcd"
diff --git a/tests/conf/ssl4eo_l_byol_1.yaml b/tests/conf/ssl4eo_l_byol_1.yaml
index a8e3dc0cd79..ed78b7fae37 100644
--- a/tests/conf/ssl4eo_l_byol_1.yaml
+++ b/tests/conf/ssl4eo_l_byol_1.yaml
@@ -1,15 +1,13 @@
-module:
- _target_: torchgeo.trainers.BYOLTask
- in_channels: 7
- backbone: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: "tests/data/ssl4eo/l"
- split: "tm_toa"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: BYOLTask
+ init_args:
+ in_channels: 7
+ model: "resnet18"
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/l"
+ split: "tm_toa"
+ seasons: 1
diff --git a/tests/conf/ssl4eo_l_byol_2.yaml b/tests/conf/ssl4eo_l_byol_2.yaml
index 2f1d87d83ff..6e1c6ab060d 100644
--- a/tests/conf/ssl4eo_l_byol_2.yaml
+++ b/tests/conf/ssl4eo_l_byol_2.yaml
@@ -1,15 +1,13 @@
-module:
- _target_: torchgeo.trainers.BYOLTask
- in_channels: 6
- backbone: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: "tests/data/ssl4eo/l"
- split: "etm_sr"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: BYOLTask
+ init_args:
+ in_channels: 6
+ model: "resnet18"
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/l"
+ split: "etm_sr"
+ seasons: 2
diff --git a/tests/conf/ssl4eo_l_moco_1.yaml b/tests/conf/ssl4eo_l_moco_1.yaml
index 56233c5e747..1486d29bf01 100644
--- a/tests/conf/ssl4eo_l_moco_1.yaml
+++ b/tests/conf/ssl4eo_l_moco_1.yaml
@@ -1,17 +1,24 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: "resnet18"
- in_channels: 9
- version: 1
- weight_decay: 1e-4
- temperature: 0.07
- memory_bank_size: 10
- moco_momentum: 0.999
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: "tests/data/ssl4eo/l"
- split: "etm_toa"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: MoCoTask
+ init_args:
+ model: "resnet18"
+ in_channels: 9
+ version: 1
+ weight_decay: 1e-4
+ temperature: 0.07
+ memory_bank_size: 10
+ moco_momentum: 0.999
+ augmentation1:
+ class_path: kornia.augmentation.RandomResizedCrop
+ init_args:
+ size:
+ - 224
+ - 224
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/l"
+ split: "etm_toa"
+ seasons: 1
diff --git a/tests/conf/ssl4eo_l_moco_2.yaml b/tests/conf/ssl4eo_l_moco_2.yaml
index 91cd9a1ca93..3edf6a52487 100644
--- a/tests/conf/ssl4eo_l_moco_2.yaml
+++ b/tests/conf/ssl4eo_l_moco_2.yaml
@@ -1,20 +1,21 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: "resnet18"
- in_channels: 11
- version: 2
- layers: 2
- hidden_dim: 10
- output_dim: 5
- weight_decay: 1e-4
- temperature: 0.07
- memory_bank_size: 10
- moco_momentum: 0.999
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: "tests/data/ssl4eo/l"
- split: "oli_tirs_toa"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: MoCoTask
+ init_args:
+ model: "resnet18"
+ in_channels: 11
+ version: 2
+ layers: 2
+ hidden_dim: 10
+ output_dim: 5
+ weight_decay: 1e-4
+ temperature: 0.07
+ memory_bank_size: 10
+ moco_momentum: 0.999
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/l"
+ split: "oli_tirs_toa"
+ seasons: 2
diff --git a/tests/conf/ssl4eo_l_simclr_1.yaml b/tests/conf/ssl4eo_l_simclr_1.yaml
index 8148ab2c5de..b705579173f 100644
--- a/tests/conf/ssl4eo_l_simclr_1.yaml
+++ b/tests/conf/ssl4eo_l_simclr_1.yaml
@@ -1,18 +1,19 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: "resnet18"
- in_channels: 7
- version: 1
- layers: 2
- hidden_dim: 8
- output_dim: 8
- weight_decay: 1e-6
- memory_bank_size: 0
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: "tests/data/ssl4eo/l"
- split: "oli_sr"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: "resnet18"
+ in_channels: 7
+ version: 1
+ layers: 2
+ hidden_dim: 8
+ output_dim: 8
+ weight_decay: 1e-6
+ memory_bank_size: 0
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/l"
+ split: "oli_sr"
+ seasons: 1
diff --git a/tests/conf/ssl4eo_l_simclr_2.yaml b/tests/conf/ssl4eo_l_simclr_2.yaml
index 8d80bab9068..7310bba9e95 100644
--- a/tests/conf/ssl4eo_l_simclr_2.yaml
+++ b/tests/conf/ssl4eo_l_simclr_2.yaml
@@ -1,18 +1,19 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: "resnet18"
- in_channels: 7
- version: 2
- layers: 3
- hidden_dim: 8
- output_dim: 8
- weight_decay: 1e-4
- memory_bank_size: 10
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOLDataModule
- root: "tests/data/ssl4eo/l"
- split: "tm_toa"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: "resnet18"
+ in_channels: 7
+ version: 2
+ layers: 3
+ hidden_dim: 8
+ output_dim: 8
+ weight_decay: 1e-4
+ memory_bank_size: 10
+data:
+ class_path: SSL4EOLDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/l"
+ split: "tm_toa"
+ seasons: 2
diff --git a/tests/conf/ssl4eo_s12_byol_1.yaml b/tests/conf/ssl4eo_s12_byol_1.yaml
index 8d261d9de27..ccdf4b5736d 100644
--- a/tests/conf/ssl4eo_s12_byol_1.yaml
+++ b/tests/conf/ssl4eo_s12_byol_1.yaml
@@ -1,15 +1,13 @@
-module:
- _target_: torchgeo.trainers.BYOLTask
- in_channels: 2
- backbone: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOS12DataModule
- root: "tests/data/ssl4eo/s12"
- split: "s1"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: BYOLTask
+ init_args:
+ in_channels: 2
+ model: "resnet18"
+data:
+ class_path: SSL4EOS12DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/s12"
+ split: "s1"
+ seasons: 1
diff --git a/tests/conf/ssl4eo_s12_byol_2.yaml b/tests/conf/ssl4eo_s12_byol_2.yaml
index 0bf2164b0b5..6368e8fdefe 100644
--- a/tests/conf/ssl4eo_s12_byol_2.yaml
+++ b/tests/conf/ssl4eo_s12_byol_2.yaml
@@ -1,15 +1,13 @@
-module:
- _target_: torchgeo.trainers.BYOLTask
- in_channels: 13
- backbone: "resnet18"
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- weights: null
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOS12DataModule
- root: "tests/data/ssl4eo/s12"
- split: "s2c"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: BYOLTask
+ init_args:
+ in_channels: 13
+ model: "resnet18"
+data:
+ class_path: SSL4EOS12DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/s12"
+ split: "s2c"
+ seasons: 2
diff --git a/tests/conf/ssl4eo_s12_moco_1.yaml b/tests/conf/ssl4eo_s12_moco_1.yaml
index f3ac6fbea84..513d5ae0842 100644
--- a/tests/conf/ssl4eo_s12_moco_1.yaml
+++ b/tests/conf/ssl4eo_s12_moco_1.yaml
@@ -1,17 +1,18 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: "resnet18"
- in_channels: 12
- version: 1
- weight_decay: 1e-4
- temperature: 0.07
- memory_bank_size: 10
- moco_momentum: 0.999
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOS12DataModule
- root: "tests/data/ssl4eo/s12"
- split: "s2a"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: MoCoTask
+ init_args:
+ model: "resnet18"
+ in_channels: 12
+ version: 1
+ weight_decay: 1e-4
+ temperature: 0.07
+ memory_bank_size: 10
+ moco_momentum: 0.999
+data:
+ class_path: SSL4EOS12DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/s12"
+ split: "s2a"
+ seasons: 1
diff --git a/tests/conf/ssl4eo_s12_moco_2.yaml b/tests/conf/ssl4eo_s12_moco_2.yaml
index f574e11e7ae..71d8ee43dc7 100644
--- a/tests/conf/ssl4eo_s12_moco_2.yaml
+++ b/tests/conf/ssl4eo_s12_moco_2.yaml
@@ -1,20 +1,21 @@
-module:
- _target_: torchgeo.trainers.MoCoTask
- model: "resnet18"
- in_channels: 2
- version: 2
- layers: 2
- hidden_dim: 10
- output_dim: 5
- weight_decay: 1e-4
- temperature: 0.07
- memory_bank_size: 10
- moco_momentum: 0.999
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOS12DataModule
- root: "tests/data/ssl4eo/s12"
- split: "s1"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: MoCoTask
+ init_args:
+ model: "resnet18"
+ in_channels: 2
+ version: 2
+ layers: 2
+ hidden_dim: 10
+ output_dim: 5
+ weight_decay: 1e-4
+ temperature: 0.07
+ memory_bank_size: 10
+ moco_momentum: 0.999
+data:
+ class_path: SSL4EOS12DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/s12"
+ split: "s1"
+ seasons: 2
diff --git a/tests/conf/ssl4eo_s12_simclr_1.yaml b/tests/conf/ssl4eo_s12_simclr_1.yaml
index 7d32f84b1d1..94444be5cc9 100644
--- a/tests/conf/ssl4eo_s12_simclr_1.yaml
+++ b/tests/conf/ssl4eo_s12_simclr_1.yaml
@@ -1,18 +1,19 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: "resnet18"
- in_channels: 13
- version: 1
- layers: 2
- hidden_dim: 8
- output_dim: 8
- weight_decay: 1e-6
- memory_bank_size: 0
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOS12DataModule
- root: "tests/data/ssl4eo/s12"
- split: "s2c"
- seasons: 1
- batch_size: 2
- num_workers: 0
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: "resnet18"
+ in_channels: 13
+ version: 1
+ layers: 2
+ hidden_dim: 8
+ output_dim: 8
+ weight_decay: 1e-6
+ memory_bank_size: 0
+data:
+ class_path: SSL4EOS12DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/s12"
+ split: "s2c"
+ seasons: 1
diff --git a/tests/conf/ssl4eo_s12_simclr_2.yaml b/tests/conf/ssl4eo_s12_simclr_2.yaml
index d97e6b21d46..7d88a3713ba 100644
--- a/tests/conf/ssl4eo_s12_simclr_2.yaml
+++ b/tests/conf/ssl4eo_s12_simclr_2.yaml
@@ -1,18 +1,19 @@
-module:
- _target_: torchgeo.trainers.SimCLRTask
- model: "resnet18"
- in_channels: 12
- version: 2
- layers: 3
- hidden_dim: 8
- output_dim: 8
- weight_decay: 1e-4
- memory_bank_size: 10
-
-datamodule:
- _target_: torchgeo.datamodules.SSL4EOS12DataModule
- root: "tests/data/ssl4eo/s12"
- split: "s2a"
- seasons: 2
- batch_size: 2
- num_workers: 0
+model:
+ class_path: SimCLRTask
+ init_args:
+ model: "resnet18"
+ in_channels: 12
+ version: 2
+ layers: 3
+ hidden_dim: 8
+ output_dim: 8
+ weight_decay: 1e-4
+ memory_bank_size: 10
+data:
+ class_path: SSL4EOS12DataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ssl4eo/s12"
+ split: "s2a"
+ seasons: 2
diff --git a/tests/conf/sustainbench_crop_yield.yaml b/tests/conf/sustainbench_crop_yield.yaml
index 9b092aab674..ba6e65af105 100644
--- a/tests/conf/sustainbench_crop_yield.yaml
+++ b/tests/conf/sustainbench_crop_yield.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.RegressionTask
- model: "resnet18"
- weights: null
- num_outputs: 1
- in_channels: 9
- learning_rate: 1e-3
- learning_rate_schedule_patience: 2
- loss: "mse"
-
-datamodule:
- _target_: torchgeo.datamodules.SustainBenchCropYieldDataModule
- root: "tests/data/sustainbench_crop_yield"
- download: true
- batch_size: 1
- num_workers: 0
+model:
+ class_path: RegressionTask
+ init_args:
+ model: "resnet18"
+ num_outputs: 1
+ in_channels: 9
+ loss: "mse"
+data:
+ class_path: SustainBenchCropYieldDataModule
+ init_args:
+ batch_size: 1
+ dict_kwargs:
+ root: "tests/data/sustainbench_crop_yield"
+ download: true
diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml
index 93e37db6059..d9c8752f1ec 100644
--- a/tests/conf/ucmerced.yaml
+++ b/tests/conf/ucmerced.yaml
@@ -1,16 +1,14 @@
-module:
- _target_: torchgeo.trainers.ClassificationTask
- loss: "ce"
- model: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- in_channels: 3
- num_classes: 2
-
-datamodule:
- _target_: torchgeo.datamodules.UCMercedDataModule
- root: "tests/data/ucmerced"
- download: true
- batch_size: 2
- num_workers: 0
+model:
+ class_path: ClassificationTask
+ init_args:
+ loss: "ce"
+ model: "resnet18"
+ in_channels: 3
+ num_classes: 2
+data:
+ class_path: UCMercedDataModule
+ init_args:
+ batch_size: 2
+ dict_kwargs:
+ root: "tests/data/ucmerced"
+ download: true
diff --git a/tests/conf/vaihingen2d.yaml b/tests/conf/vaihingen2d.yaml
index ebdc8613ad2..00404756ace 100644
--- a/tests/conf/vaihingen2d.yaml
+++ b/tests/conf/vaihingen2d.yaml
@@ -1,21 +1,18 @@
-module:
- _target_: torchgeo.trainers.SemanticSegmentationTask
- loss: "ce"
- model: "unet"
- backbone: "resnet18"
- weights: null
- learning_rate: 1e-3
- learning_rate_schedule_patience: 6
- verbose: false
- in_channels: 3
- num_classes: 7
- num_filters: 1
- ignore_index: null
-
-datamodule:
- _target_: torchgeo.datamodules.Vaihingen2DDataModule
- root: "tests/data/vaihingen"
- batch_size: 1
- patch_size: 2
- val_split_pct: 0.5
- num_workers: 0
+model:
+ class_path: SemanticSegmentationTask
+ init_args:
+ loss: "ce"
+ model: "unet"
+ backbone: "resnet18"
+ in_channels: 3
+ num_classes: 7
+ num_filters: 1
+ ignore_index: null
+data:
+ class_path: Vaihingen2DDataModule
+ init_args:
+ batch_size: 1
+ patch_size: 2
+ val_split_pct: 0.5
+ dict_kwargs:
+ root: "tests/data/vaihingen"
diff --git a/tests/data/agb_live_woody_density/Aboveground_Live_Woody_Biomass_Density.geojson b/tests/data/agb_live_woody_density/Aboveground_Live_Woody_Biomass_Density.geojson
index 169191641e1..7beec3e9db3 100644
--- a/tests/data/agb_live_woody_density/Aboveground_Live_Woody_Biomass_Density.geojson
+++ b/tests/data/agb_live_woody_density/Aboveground_Live_Woody_Biomass_Density.geojson
@@ -1 +1 @@
-{"type": "FeatureCollection", "name": "Aboveground_Live_Woody_Biomass_Density", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {"tile_id": "00N_000E", "download": "tests/data/agb_live_woody_density/00N_000E.tif", "ObjectId": 1, "Shape__Area": 1245542622548.87, "Shape__Length": 4464169.76558139}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [10.0, 0.0], [10.0, -10.0], [0.0, -10.0], [0.0, 0.0]]]}}]}
\ No newline at end of file
+{"type": "FeatureCollection", "name": "Aboveground_Live_Woody_Biomass_Density", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "properties": {"tile_id": "00N_000E", "Mg_px_1_download": "tests/data/agb_live_woody_density/00N_000E.tif", "ObjectId": 1, "Shape__Area": 1245542622548.87, "Shape__Length": 4464169.76558139}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [10.0, 0.0], [10.0, -10.0], [0.0, -10.0], [0.0, 0.0]]]}}]}
\ No newline at end of file
diff --git a/tests/data/agb_live_woody_density/data.py b/tests/data/agb_live_woody_density/data.py
index 5bbb9b1476d..115a9772fba 100755
--- a/tests/data/agb_live_woody_density/data.py
+++ b/tests/data/agb_live_woody_density/data.py
@@ -23,7 +23,7 @@
"type": "Feature",
"properties": {
"tile_id": "00N_000E",
- "download": os.path.join(
+ "Mg_px_1_download": os.path.join(
"tests", "data", "agb_live_woody_density", "00N_000E.tif"
),
"ObjectId": 1,
@@ -74,5 +74,5 @@ def create_file(path: str, dtype: str, num_channels: int) -> None:
json.dump(base_file, f)
for i in base_file["features"]:
- filepath = os.path.basename(i["properties"]["download"])
+ filepath = os.path.basename(i["properties"]["Mg_px_1_download"])
create_file(path=filepath, dtype="int32", num_channels=1)
diff --git a/tests/data/biomassters/The_BioMassters_-_features_metadata.csv.csv b/tests/data/biomassters/The_BioMassters_-_features_metadata.csv.csv
new file mode 100644
index 00000000000..f7bd35a4e89
--- /dev/null
+++ b/tests/data/biomassters/The_BioMassters_-_features_metadata.csv.csv
@@ -0,0 +1,21 @@
+filename,chip_id,satellite,split,month,size,cksum,s3path_us,s3path_eu,s3path_as,corresponding_agbm
+0003d2eb_S1_00.tif,0003d2eb,S1,train,September,0,0,path,path,path,0003d2eb_agbm.tif
+0003d2eb_S1_01.tif,0003d2eb,S1,train,October,0,0,path,path,path,0003d2eb_agbm.tif
+0003d2eb_S1_02.tif,0003d2eb,S1,train,November,0,0,path,path,path,0003d2eb_agbm.tif
+0003d2eb_S2_00.tif,0003d2eb,S2,train,September,0,0,path,path,path,0003d2eb_agbm.tif
+0003d2eb_S2_02.tif,0003d2eb,S2,train,November,0,0,path,path,path,0003d2eb_agbm.tif
+000aa810_S1_00.tif,000aa810,S1,train,September,0,0,path,path,path,000aa810_agbm.tif
+000aa810_S1_01.tif,000aa810,S1,train,October,0,0,path,path,path,000aa810_agbm.tif
+000aa810_S1_02.tif,000aa810,S1,train,November,0,0,path,path,path,000aa810_agbm.tif
+000aa810_S2_00.tif,000aa810,S2,train,September,0,0,path,path,path,000aa810_agbm.tif
+000aa810_S2_02.tif,000aa810,S2,train,November,0,0,path,path,path,000aa810_agbm.tif
+0003d2eb_S1_00.tif,0003d2eb,S1,test,September,0,0,path,path,path,0003d2eb_agbm.tif
+0003d2eb_S1_01.tif,0003d2eb,S1,test,October,0,0,path,path,path,0003d2eb_agbm.tif
+0003d2eb_S1_02.tif,0003d2eb,S1,test,November,0,0,path,path,path,0003d2eb_agbm.tif
+0003d2eb_S2_00.tif,0003d2eb,S2,test,September,0,0,path,path,path,0003d2eb_agbm.tif
+0003d2eb_S2_02.tif,0003d2eb,S2,test,November,0,0,path,path,path,0003d2eb_agbm.tif
+000aa810_S1_00.tif,000aa810,S1,test,September,0,0,path,path,path,000aa810_agbm.tif
+000aa810_S1_01.tif,000aa810,S1,test,October,0,0,path,path,path,000aa810_agbm.tif
+000aa810_S1_02.tif,000aa810,S1,test,November,0,0,path,path,path,000aa810_agbm.tif
+000aa810_S2_00.tif,000aa810,S2,test,September,0,0,path,path,path,000aa810_agbm.tif
+000aa810_S2_02.tif,000aa810,S2,test,November,0,0,path,path,path,000aa810_agbm.tif
diff --git a/tests/data/biomassters/data.py b/tests/data/biomassters/data.py
new file mode 100644
index 00000000000..648ae9b94f4
--- /dev/null
+++ b/tests/data/biomassters/data.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import csv
+import hashlib
+import os
+import shutil
+
+import numpy as np
+import rasterio
+
+metadata_train = "The_BioMassters_-_features_metadata.csv.csv"
+
+csv_columns = [
+ "filename",
+ "chip_id",
+ "satellite",
+ "split",
+ "month",
+ "size",
+ "cksum",
+ "s3path_us",
+ "s3path_eu",
+ "s3path_as",
+ "corresponding_agbm",
+]
+
+targets = "train_agbm.zip"
+
+splits = ["train", "test"]
+
+sample_ids = ["0003d2eb", "000aa810"]
+
+months = ["September", "October", "November"]
+
+satellite = ["S1", "S2"]
+
+SIZE = 32
+
+
+def create_tif_file(path: str, num_channels: int, dtype: str) -> None:
+ """Create S1 or S2 data with num channels.
+
+ Args:
+ path: path where to save tif
+ num_channels: number of channels (4 for S1, 11 for S2)
+ dtype: uint16 for image data and float 32 for target
+ """
+ profile = {}
+ profile["driver"] = "GTiff"
+ profile["dtype"] = dtype
+ profile["count"] = num_channels
+ profile["crs"] = "epsg:4326"
+ profile["transform"] = rasterio.transform.from_bounds(0, 0, 1, 1, 1, 1)
+ profile["height"] = SIZE
+ profile["width"] = SIZE
+ profile["compress"] = "lzw"
+ profile["predictor"] = 2
+
+ if "float" in profile["dtype"]:
+ Z = np.random.randn(SIZE, SIZE).astype(profile["dtype"])
+ else:
+ Z = np.random.randint(
+ np.iinfo(profile["dtype"]).max, size=(SIZE, SIZE), dtype=profile["dtype"]
+ )
+
+ with rasterio.open(path, "w", **profile) as src:
+ for i in range(1, profile["count"] + 1):
+ src.write(Z, i)
+
+
+# filename,chip_id,satellite,split,month,size,cksum,s3path_us,s3path_eu,s3path_as,corresponding_agbm
+if __name__ == "__main__":
+ csv_rows = []
+ for split in splits:
+ os.makedirs(f"{split}_features", exist_ok=True)
+ if split == "train":
+ os.makedirs("train_agbm", exist_ok=True)
+ for id in sample_ids:
+ for sat in satellite:
+ path = id + "_" + str(sat)
+ for idx, month in enumerate(months):
+ # S2 data is not present for every month
+ if sat == "S2" and idx == 1:
+ continue
+ file_path = path + "_" + f"{idx:02d}" + ".tif"
+
+ csv_rows.append(
+ [
+ file_path,
+ id,
+ sat,
+ split,
+ month,
+ "0",
+ "0",
+ "path",
+ "path",
+ "path",
+ id + "_agbm.tif",
+ ]
+ )
+
+ # file path to save
+ file_path = os.path.join(f"{split}_features", file_path)
+
+ if sat == "S1":
+ create_tif_file(file_path, num_channels=4, dtype="uint16")
+ else:
+ create_tif_file(file_path, num_channels=11, dtype="uint16")
+
+ # create target data one per id
+ if split == "train":
+ create_tif_file(
+ os.path.join(f"{split}_agbm", id + "_agbm.tif"),
+ num_channels=1,
+ dtype="float32",
+ )
+
+ # write out metadata
+
+ with open(metadata_train, "w") as csv_file:
+ wr = csv.writer(csv_file)
+ wr.writerow(csv_columns)
+ for row in csv_rows:
+ wr.writerow(row)
+
+ # zip up feature and target folders
+ zip_dirs = ["train_features", "test_features", "train_agbm"]
+ for dir in zip_dirs:
+ shutil.make_archive(dir, "zip", dir)
+ # Compute checksums
+ with open(dir + ".zip", "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"{dir}: {md5}")
diff --git a/tests/data/biomassters/test_features.zip b/tests/data/biomassters/test_features.zip
new file mode 100644
index 00000000000..381f3c3b6b5
Binary files /dev/null and b/tests/data/biomassters/test_features.zip differ
diff --git a/tests/data/biomassters/test_features/0003d2eb_S1_00.tif b/tests/data/biomassters/test_features/0003d2eb_S1_00.tif
new file mode 100644
index 00000000000..9cb4379b6fa
Binary files /dev/null and b/tests/data/biomassters/test_features/0003d2eb_S1_00.tif differ
diff --git a/tests/data/biomassters/test_features/0003d2eb_S1_01.tif b/tests/data/biomassters/test_features/0003d2eb_S1_01.tif
new file mode 100644
index 00000000000..c2de35a26ec
Binary files /dev/null and b/tests/data/biomassters/test_features/0003d2eb_S1_01.tif differ
diff --git a/tests/data/biomassters/test_features/0003d2eb_S1_02.tif b/tests/data/biomassters/test_features/0003d2eb_S1_02.tif
new file mode 100644
index 00000000000..8a641b34091
Binary files /dev/null and b/tests/data/biomassters/test_features/0003d2eb_S1_02.tif differ
diff --git a/tests/data/biomassters/test_features/0003d2eb_S2_00.tif b/tests/data/biomassters/test_features/0003d2eb_S2_00.tif
new file mode 100644
index 00000000000..0c46590e1f5
Binary files /dev/null and b/tests/data/biomassters/test_features/0003d2eb_S2_00.tif differ
diff --git a/tests/data/biomassters/test_features/0003d2eb_S2_01.tif b/tests/data/biomassters/test_features/0003d2eb_S2_01.tif
new file mode 100644
index 00000000000..aabb39f5213
Binary files /dev/null and b/tests/data/biomassters/test_features/0003d2eb_S2_01.tif differ
diff --git a/tests/data/biomassters/test_features/0003d2eb_S2_02.tif b/tests/data/biomassters/test_features/0003d2eb_S2_02.tif
new file mode 100644
index 00000000000..30aac1463df
Binary files /dev/null and b/tests/data/biomassters/test_features/0003d2eb_S2_02.tif differ
diff --git a/tests/data/biomassters/test_features/000aa810_S1_00.tif b/tests/data/biomassters/test_features/000aa810_S1_00.tif
new file mode 100644
index 00000000000..b1a3d20ad12
Binary files /dev/null and b/tests/data/biomassters/test_features/000aa810_S1_00.tif differ
diff --git a/tests/data/biomassters/test_features/000aa810_S1_01.tif b/tests/data/biomassters/test_features/000aa810_S1_01.tif
new file mode 100644
index 00000000000..67d0bd41449
Binary files /dev/null and b/tests/data/biomassters/test_features/000aa810_S1_01.tif differ
diff --git a/tests/data/biomassters/test_features/000aa810_S1_02.tif b/tests/data/biomassters/test_features/000aa810_S1_02.tif
new file mode 100644
index 00000000000..0574ec89162
Binary files /dev/null and b/tests/data/biomassters/test_features/000aa810_S1_02.tif differ
diff --git a/tests/data/biomassters/test_features/000aa810_S2_00.tif b/tests/data/biomassters/test_features/000aa810_S2_00.tif
new file mode 100644
index 00000000000..c281f0ee2ee
Binary files /dev/null and b/tests/data/biomassters/test_features/000aa810_S2_00.tif differ
diff --git a/tests/data/biomassters/test_features/000aa810_S2_01.tif b/tests/data/biomassters/test_features/000aa810_S2_01.tif
new file mode 100644
index 00000000000..4639d161491
Binary files /dev/null and b/tests/data/biomassters/test_features/000aa810_S2_01.tif differ
diff --git a/tests/data/biomassters/test_features/000aa810_S2_02.tif b/tests/data/biomassters/test_features/000aa810_S2_02.tif
new file mode 100644
index 00000000000..55155c4f5af
Binary files /dev/null and b/tests/data/biomassters/test_features/000aa810_S2_02.tif differ
diff --git a/tests/data/biomassters/train_agbm.zip b/tests/data/biomassters/train_agbm.zip
new file mode 100644
index 00000000000..29ffc826101
Binary files /dev/null and b/tests/data/biomassters/train_agbm.zip differ
diff --git a/tests/data/biomassters/train_agbm/0003d2eb_agbm.tif b/tests/data/biomassters/train_agbm/0003d2eb_agbm.tif
new file mode 100644
index 00000000000..f6f64a088e5
Binary files /dev/null and b/tests/data/biomassters/train_agbm/0003d2eb_agbm.tif differ
diff --git a/tests/data/biomassters/train_agbm/000aa810_agbm.tif b/tests/data/biomassters/train_agbm/000aa810_agbm.tif
new file mode 100644
index 00000000000..283dcffdad2
Binary files /dev/null and b/tests/data/biomassters/train_agbm/000aa810_agbm.tif differ
diff --git a/tests/data/biomassters/train_features.zip b/tests/data/biomassters/train_features.zip
new file mode 100644
index 00000000000..c22b3f9cadb
Binary files /dev/null and b/tests/data/biomassters/train_features.zip differ
diff --git a/tests/data/biomassters/train_features/0003d2eb_S1_00.tif b/tests/data/biomassters/train_features/0003d2eb_S1_00.tif
new file mode 100644
index 00000000000..fcfac72920a
Binary files /dev/null and b/tests/data/biomassters/train_features/0003d2eb_S1_00.tif differ
diff --git a/tests/data/biomassters/train_features/0003d2eb_S1_01.tif b/tests/data/biomassters/train_features/0003d2eb_S1_01.tif
new file mode 100644
index 00000000000..93aa9244f20
Binary files /dev/null and b/tests/data/biomassters/train_features/0003d2eb_S1_01.tif differ
diff --git a/tests/data/biomassters/train_features/0003d2eb_S1_02.tif b/tests/data/biomassters/train_features/0003d2eb_S1_02.tif
new file mode 100644
index 00000000000..672c393ec5a
Binary files /dev/null and b/tests/data/biomassters/train_features/0003d2eb_S1_02.tif differ
diff --git a/tests/data/biomassters/train_features/0003d2eb_S2_00.tif b/tests/data/biomassters/train_features/0003d2eb_S2_00.tif
new file mode 100644
index 00000000000..22f5f4f22d7
Binary files /dev/null and b/tests/data/biomassters/train_features/0003d2eb_S2_00.tif differ
diff --git a/tests/data/biomassters/train_features/0003d2eb_S2_01.tif b/tests/data/biomassters/train_features/0003d2eb_S2_01.tif
new file mode 100644
index 00000000000..e7f9fd34bd8
Binary files /dev/null and b/tests/data/biomassters/train_features/0003d2eb_S2_01.tif differ
diff --git a/tests/data/biomassters/train_features/0003d2eb_S2_02.tif b/tests/data/biomassters/train_features/0003d2eb_S2_02.tif
new file mode 100644
index 00000000000..d811ccee380
Binary files /dev/null and b/tests/data/biomassters/train_features/0003d2eb_S2_02.tif differ
diff --git a/tests/data/biomassters/train_features/000aa810_S1_00.tif b/tests/data/biomassters/train_features/000aa810_S1_00.tif
new file mode 100644
index 00000000000..3fa69bea3c0
Binary files /dev/null and b/tests/data/biomassters/train_features/000aa810_S1_00.tif differ
diff --git a/tests/data/biomassters/train_features/000aa810_S1_01.tif b/tests/data/biomassters/train_features/000aa810_S1_01.tif
new file mode 100644
index 00000000000..be09998809e
Binary files /dev/null and b/tests/data/biomassters/train_features/000aa810_S1_01.tif differ
diff --git a/tests/data/biomassters/train_features/000aa810_S1_02.tif b/tests/data/biomassters/train_features/000aa810_S1_02.tif
new file mode 100644
index 00000000000..160c07adb70
Binary files /dev/null and b/tests/data/biomassters/train_features/000aa810_S1_02.tif differ
diff --git a/tests/data/biomassters/train_features/000aa810_S2_00.tif b/tests/data/biomassters/train_features/000aa810_S2_00.tif
new file mode 100644
index 00000000000..0d40a3dd0da
Binary files /dev/null and b/tests/data/biomassters/train_features/000aa810_S2_00.tif differ
diff --git a/tests/data/biomassters/train_features/000aa810_S2_01.tif b/tests/data/biomassters/train_features/000aa810_S2_01.tif
new file mode 100644
index 00000000000..190e4de4f3d
Binary files /dev/null and b/tests/data/biomassters/train_features/000aa810_S2_01.tif differ
diff --git a/tests/data/biomassters/train_features/000aa810_S2_02.tif b/tests/data/biomassters/train_features/000aa810_S2_02.tif
new file mode 100644
index 00000000000..bc7a3689000
Binary files /dev/null and b/tests/data/biomassters/train_features/000aa810_S2_02.tif differ
diff --git a/tests/data/inria/AerialImageDataset/test/images/austin10.tif b/tests/data/inria/AerialImageDataset/test/images/austin10.tif
index d77ca4e7fa4..b615ffd68c7 100644
Binary files a/tests/data/inria/AerialImageDataset/test/images/austin10.tif and b/tests/data/inria/AerialImageDataset/test/images/austin10.tif differ
diff --git a/tests/data/inria/AerialImageDataset/test/images/austin11.tif b/tests/data/inria/AerialImageDataset/test/images/austin11.tif
index 0042958d9b7..9f613041e63 100644
Binary files a/tests/data/inria/AerialImageDataset/test/images/austin11.tif and b/tests/data/inria/AerialImageDataset/test/images/austin11.tif differ
diff --git a/tests/data/inria/AerialImageDataset/test/images/austin12.tif b/tests/data/inria/AerialImageDataset/test/images/austin12.tif
index c7c12752406..e3fef860af2 100644
Binary files a/tests/data/inria/AerialImageDataset/test/images/austin12.tif and b/tests/data/inria/AerialImageDataset/test/images/austin12.tif differ
diff --git a/tests/data/inria/AerialImageDataset/test/images/austin13.tif b/tests/data/inria/AerialImageDataset/test/images/austin13.tif
index 029444bc99f..e6b7763d438 100644
Binary files a/tests/data/inria/AerialImageDataset/test/images/austin13.tif and b/tests/data/inria/AerialImageDataset/test/images/austin13.tif differ
diff --git a/tests/data/inria/AerialImageDataset/test/images/austin14.tif b/tests/data/inria/AerialImageDataset/test/images/austin14.tif
index 6a84ce1c5c9..17499763838 100644
Binary files a/tests/data/inria/AerialImageDataset/test/images/austin14.tif and b/tests/data/inria/AerialImageDataset/test/images/austin14.tif differ
diff --git a/tests/data/inria/AerialImageDataset/test/images/austin15.tif b/tests/data/inria/AerialImageDataset/test/images/austin15.tif
new file mode 100644
index 00000000000..24e8699ab33
Binary files /dev/null and b/tests/data/inria/AerialImageDataset/test/images/austin15.tif differ
diff --git a/tests/data/inria/AerialImageDataset/test/images/austin16.tif b/tests/data/inria/AerialImageDataset/test/images/austin16.tif
new file mode 100644
index 00000000000..fed02e2bf58
Binary files /dev/null and b/tests/data/inria/AerialImageDataset/test/images/austin16.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin1.tif b/tests/data/inria/AerialImageDataset/train/gt/austin1.tif
index 9bf2873fff4..a5c37b51b74 100644
Binary files a/tests/data/inria/AerialImageDataset/train/gt/austin1.tif and b/tests/data/inria/AerialImageDataset/train/gt/austin1.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin2.tif b/tests/data/inria/AerialImageDataset/train/gt/austin2.tif
index b06a9363da4..a174dd76490 100644
Binary files a/tests/data/inria/AerialImageDataset/train/gt/austin2.tif and b/tests/data/inria/AerialImageDataset/train/gt/austin2.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin3.tif b/tests/data/inria/AerialImageDataset/train/gt/austin3.tif
index 2d134842907..33f6c57829a 100644
Binary files a/tests/data/inria/AerialImageDataset/train/gt/austin3.tif and b/tests/data/inria/AerialImageDataset/train/gt/austin3.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin4.tif b/tests/data/inria/AerialImageDataset/train/gt/austin4.tif
index 21cb217cef3..22763a3b2a2 100644
Binary files a/tests/data/inria/AerialImageDataset/train/gt/austin4.tif and b/tests/data/inria/AerialImageDataset/train/gt/austin4.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin5.tif b/tests/data/inria/AerialImageDataset/train/gt/austin5.tif
index 3a819af9b82..001e166dfe3 100644
Binary files a/tests/data/inria/AerialImageDataset/train/gt/austin5.tif and b/tests/data/inria/AerialImageDataset/train/gt/austin5.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin6.tif b/tests/data/inria/AerialImageDataset/train/gt/austin6.tif
new file mode 100644
index 00000000000..c1ac7192361
Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/gt/austin6.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/gt/austin7.tif b/tests/data/inria/AerialImageDataset/train/gt/austin7.tif
new file mode 100644
index 00000000000..3d7c349a475
Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/gt/austin7.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/images/austin1.tif b/tests/data/inria/AerialImageDataset/train/images/austin1.tif
index cabf7459159..fcbe57d16ed 100644
Binary files a/tests/data/inria/AerialImageDataset/train/images/austin1.tif and b/tests/data/inria/AerialImageDataset/train/images/austin1.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/images/austin2.tif b/tests/data/inria/AerialImageDataset/train/images/austin2.tif
index df55cf5b7bc..c0e881b86f0 100644
Binary files a/tests/data/inria/AerialImageDataset/train/images/austin2.tif and b/tests/data/inria/AerialImageDataset/train/images/austin2.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/images/austin3.tif b/tests/data/inria/AerialImageDataset/train/images/austin3.tif
index c99ea6c2637..116c31899e2 100644
Binary files a/tests/data/inria/AerialImageDataset/train/images/austin3.tif and b/tests/data/inria/AerialImageDataset/train/images/austin3.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/images/austin4.tif b/tests/data/inria/AerialImageDataset/train/images/austin4.tif
index 33dc4eefa32..3f01c634d53 100644
Binary files a/tests/data/inria/AerialImageDataset/train/images/austin4.tif and b/tests/data/inria/AerialImageDataset/train/images/austin4.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/images/austin5.tif b/tests/data/inria/AerialImageDataset/train/images/austin5.tif
index 2e973747a15..5fe89e7a317 100644
Binary files a/tests/data/inria/AerialImageDataset/train/images/austin5.tif and b/tests/data/inria/AerialImageDataset/train/images/austin5.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/images/austin6.tif b/tests/data/inria/AerialImageDataset/train/images/austin6.tif
new file mode 100644
index 00000000000..c4fc765b2f7
Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/images/austin6.tif differ
diff --git a/tests/data/inria/AerialImageDataset/train/images/austin7.tif b/tests/data/inria/AerialImageDataset/train/images/austin7.tif
new file mode 100644
index 00000000000..89967c769c3
Binary files /dev/null and b/tests/data/inria/AerialImageDataset/train/images/austin7.tif differ
diff --git a/tests/data/inria/NEW2-AerialImageDataset.zip b/tests/data/inria/NEW2-AerialImageDataset.zip
index 153d3b39959..a26f8e55fe8 100644
Binary files a/tests/data/inria/NEW2-AerialImageDataset.zip and b/tests/data/inria/NEW2-AerialImageDataset.zip differ
diff --git a/tests/data/inria/data.py b/tests/data/inria/data.py
index be240a43560..4e304947adc 100755
--- a/tests/data/inria/data.py
+++ b/tests/data/inria/data.py
@@ -81,10 +81,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
shutil.make_archive(
archive_path, "zip", root_dir=root, base_dir="AerialImageDataset"
)
- shutil.rmtree(folder_path)
return calculate_md5(f"{archive_path}.zip")
if __name__ == "__main__":
- md5_hash = generate_test_data(os.getcwd(), 5)
+ md5_hash = generate_test_data(os.getcwd(), 7)
print(md5_hash)
diff --git a/tests/data/mapinwild/data.py b/tests/data/mapinwild/data.py
new file mode 100644
index 00000000000..f6d089add93
--- /dev/null
+++ b/tests/data/mapinwild/data.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+
+import numpy as np
+import pandas as pd
+import rasterio
+from rasterio.crs import CRS
+from rasterio.transform import Affine
+
+SIZE = 32
+
+np.random.seed(0)
+
+meta = {
+ "driver": "GTiff",
+ "nodata": None,
+ "width": SIZE,
+ "height": SIZE,
+ "crs": CRS.from_epsg(32720),
+ "transform": Affine(10.0, 0.0, 612190.0, 0.0, -10.0, 7324250.0),
+}
+
+count = {
+ "ESA_WC": 1,
+ "VIIRS": 1,
+ "mask": 1,
+ "s1_part1": 2,
+ "s1_part2": 2,
+ "s2_temporal_subset_part1": 10,
+ "s2_temporal_subset_part2": 10,
+ "s2_autumn_part1": 10,
+ "s2_autumn_part2": 10,
+ "s2_spring_part1": 10,
+ "s2_spring_part2": 10,
+ "s2_summer_part1": 10,
+ "s2_summer_part2": 10,
+ "s2_winter_part1": 10,
+ "s2_winter_part2": 10,
+}
+dtype = {
+ "ESA_WC": np.uint8,
+ "VIIRS": np.float32,
+ "mask": np.byte,
+ "s1_part1": np.float64,
+ "s1_part2": np.float64,
+ "s2_temporal_subset_part1": np.uint16,
+ "s2_temporal_subset_part2": np.uint16,
+ "s2_autumn_part1": np.uint16,
+ "s2_autumn_part2": np.uint16,
+ "s2_spring_part1": np.uint16,
+ "s2_spring_part2": np.uint16,
+ "s2_summer_part1": np.uint16,
+ "s2_summer_part2": np.uint16,
+ "s2_winter_part1": np.uint16,
+ "s2_winter_part2": np.uint16,
+}
+stop = {
+ "ESA_WC": np.iinfo(np.uint8).max,
+ "VIIRS": np.finfo(np.float32).max,
+ "mask": np.iinfo(np.byte).max,
+ "s1_part1": np.finfo(np.float64).max,
+ "s1_part2": np.finfo(np.float64).max,
+ "s2_temporal_subset_part1": np.iinfo(np.uint16).max,
+ "s2_temporal_subset_part2": np.iinfo(np.uint16).max,
+ "s2_autumn_part1": np.iinfo(np.uint16).max,
+ "s2_autumn_part2": np.iinfo(np.uint16).max,
+ "s2_spring_part1": np.iinfo(np.uint16).max,
+ "s2_spring_part2": np.iinfo(np.uint16).max,
+ "s2_summer_part1": np.iinfo(np.uint16).max,
+ "s2_summer_part2": np.iinfo(np.uint16).max,
+ "s2_winter_part1": np.iinfo(np.uint16).max,
+ "s2_winter_part2": np.iinfo(np.uint16).max,
+}
+
+folder_path = os.path.join(os.getcwd(), "tests", "data", "mapinwild")
+
+dict_all = {
+ "s2_sum": ["s2_summer_part1", "s2_summer_part2"],
+ "s2_spr": ["s2_spring_part1", "s2_spring_part2"],
+ "s2_win": ["s2_winter_part1", "s2_winter_part2"],
+ "s2_aut": ["s2_autumn_part1", "s2_autumn_part2"],
+ "s1": ["s1_part1", "s1_part2"],
+ "s2_temp": ["s2_temporal_subset_part1", "s2_temporal_subset_part2"],
+}
+
+md5s = {}
+keys = count.keys()
+modality_download_list = list(count.keys())
+
+for source in modality_download_list:
+ directory = os.path.join(folder_path, source)
+
+ # Remove old data
+ if os.path.exists(directory):
+ shutil.rmtree(directory)
+ os.makedirs(directory, exist_ok=True)
+
+ # Random images
+ for i in range(1, 3):
+ filename = f"{i}.tif"
+ filepath = os.path.join(directory, filename)
+
+ meta["count"] = count[source]
+ meta["dtype"] = dtype[source]
+ with rasterio.open(filepath, "w", **meta) as f:
+ for j in range(1, count[source] + 1):
+ if meta["dtype"] is np.float32 or meta["dtype"] is np.float64:
+ data = np.random.randn(SIZE, SIZE).astype(dtype[source])
+
+ else:
+ data = np.random.randint(stop[source], size=(SIZE, SIZE)).astype(
+ dtype[source]
+ )
+ f.write(data, j)
+
+# Mimic the two-part structure of the dataset
+for key in dict_all.keys():
+ path_list = dict_all[key]
+ path_list_dir_p1 = os.path.join(folder_path, path_list[0])
+ path_list_dir_p2 = os.path.join(folder_path, path_list[1])
+ n_ims = len(os.listdir(path_list_dir_p1))
+
+ p1_list = os.listdir(path_list_dir_p1)
+ p2_list = os.listdir(path_list_dir_p2)
+
+ fh_idx = np.arange(0, n_ims / 2, dtype=int)
+ sh_idx = np.arange(n_ims / 2, n_ims, dtype=int)
+
+ for idx in sh_idx:
+ sh_del = os.path.join(path_list_dir_p1, p1_list[idx])
+ os.remove(sh_del)
+
+ for idx in fh_idx:
+ fh_del = os.path.join(path_list_dir_p2, p2_list[idx])
+ os.remove(fh_del)
+
+for i, source in zip(keys, modality_download_list):
+ directory = os.path.join(folder_path, source)
+ root = os.path.dirname(directory)
+
+ # Compress data
+ shutil.make_archive(directory, "zip", root_dir=root, base_dir=source)
+
+ # Compute checksums
+ with open(directory + ".zip", "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"{directory}: {md5}")
+ name = i + ".zip"
+ md5s[name] = md5
+
+tvt_split = pd.DataFrame(
+ [["1", "2", "3"], [np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]],
+ index=["0", "1", "2"],
+ columns=["train", "validation", "test"],
+)
+tvt_split.dropna()
+tvt_split.to_csv(os.path.join(folder_path, "split_IDs.csv"))
+
+with open(os.path.join(folder_path, "split_IDs.csv"), "rb") as f:
+ csv_md5 = hashlib.md5(f.read()).hexdigest()
diff --git a/tests/data/mapinwild/esa_wc/ESA_WC.zip b/tests/data/mapinwild/esa_wc/ESA_WC.zip
new file mode 100644
index 00000000000..e6223dccb13
Binary files /dev/null and b/tests/data/mapinwild/esa_wc/ESA_WC.zip differ
diff --git a/tests/data/mapinwild/mask/mask.zip b/tests/data/mapinwild/mask/mask.zip
new file mode 100644
index 00000000000..1a5300f24eb
Binary files /dev/null and b/tests/data/mapinwild/mask/mask.zip differ
diff --git a/tests/data/mapinwild/s1/s1_part1.zip b/tests/data/mapinwild/s1/s1_part1.zip
new file mode 100644
index 00000000000..7c77c5a130b
Binary files /dev/null and b/tests/data/mapinwild/s1/s1_part1.zip differ
diff --git a/tests/data/mapinwild/s1/s1_part2.zip b/tests/data/mapinwild/s1/s1_part2.zip
new file mode 100644
index 00000000000..aa98f262ede
Binary files /dev/null and b/tests/data/mapinwild/s1/s1_part2.zip differ
diff --git a/tests/data/mapinwild/s2_autumn/s2_autumn_part1.zip b/tests/data/mapinwild/s2_autumn/s2_autumn_part1.zip
new file mode 100644
index 00000000000..fcc0865d821
Binary files /dev/null and b/tests/data/mapinwild/s2_autumn/s2_autumn_part1.zip differ
diff --git a/tests/data/mapinwild/s2_autumn/s2_autumn_part2.zip b/tests/data/mapinwild/s2_autumn/s2_autumn_part2.zip
new file mode 100644
index 00000000000..fe877eff0a9
Binary files /dev/null and b/tests/data/mapinwild/s2_autumn/s2_autumn_part2.zip differ
diff --git a/tests/data/mapinwild/s2_spring/s2_spring_part1.zip b/tests/data/mapinwild/s2_spring/s2_spring_part1.zip
new file mode 100644
index 00000000000..e5c0d6634ee
Binary files /dev/null and b/tests/data/mapinwild/s2_spring/s2_spring_part1.zip differ
diff --git a/tests/data/mapinwild/s2_spring/s2_spring_part2.zip b/tests/data/mapinwild/s2_spring/s2_spring_part2.zip
new file mode 100644
index 00000000000..a5029983a0a
Binary files /dev/null and b/tests/data/mapinwild/s2_spring/s2_spring_part2.zip differ
diff --git a/tests/data/mapinwild/s2_summer/s2_summer_part1.zip b/tests/data/mapinwild/s2_summer/s2_summer_part1.zip
new file mode 100644
index 00000000000..2f74bd849d7
Binary files /dev/null and b/tests/data/mapinwild/s2_summer/s2_summer_part1.zip differ
diff --git a/tests/data/mapinwild/s2_summer/s2_summer_part2.zip b/tests/data/mapinwild/s2_summer/s2_summer_part2.zip
new file mode 100644
index 00000000000..6a7793cf4c6
Binary files /dev/null and b/tests/data/mapinwild/s2_summer/s2_summer_part2.zip differ
diff --git a/tests/data/mapinwild/s2_temporal_subset/s2_temporal_subset_part1.zip b/tests/data/mapinwild/s2_temporal_subset/s2_temporal_subset_part1.zip
new file mode 100644
index 00000000000..dd8dc02bfb0
Binary files /dev/null and b/tests/data/mapinwild/s2_temporal_subset/s2_temporal_subset_part1.zip differ
diff --git a/tests/data/mapinwild/s2_temporal_subset/s2_temporal_subset_part2.zip b/tests/data/mapinwild/s2_temporal_subset/s2_temporal_subset_part2.zip
new file mode 100644
index 00000000000..32741c37829
Binary files /dev/null and b/tests/data/mapinwild/s2_temporal_subset/s2_temporal_subset_part2.zip differ
diff --git a/tests/data/mapinwild/s2_winter/s2_winter_part1.zip b/tests/data/mapinwild/s2_winter/s2_winter_part1.zip
new file mode 100644
index 00000000000..0ad23a847d6
Binary files /dev/null and b/tests/data/mapinwild/s2_winter/s2_winter_part1.zip differ
diff --git a/tests/data/mapinwild/s2_winter/s2_winter_part2.zip b/tests/data/mapinwild/s2_winter/s2_winter_part2.zip
new file mode 100644
index 00000000000..95554417a0f
Binary files /dev/null and b/tests/data/mapinwild/s2_winter/s2_winter_part2.zip differ
diff --git a/tests/data/mapinwild/split_IDs/split_IDs.csv b/tests/data/mapinwild/split_IDs/split_IDs.csv
new file mode 100644
index 00000000000..b239b8ca9e8
--- /dev/null
+++ b/tests/data/mapinwild/split_IDs/split_IDs.csv
@@ -0,0 +1,2 @@
+,train,validation,test
+0,1,1,1
diff --git a/tests/data/mapinwild/viirs/VIIRS.zip b/tests/data/mapinwild/viirs/VIIRS.zip
new file mode 100644
index 00000000000..ecbfb609cff
Binary files /dev/null and b/tests/data/mapinwild/viirs/VIIRS.zip differ
diff --git a/tests/data/rwanda_field_boundary/data.py b/tests/data/rwanda_field_boundary/data.py
new file mode 100644
index 00000000000..7a23b385cf6
--- /dev/null
+++ b/tests/data/rwanda_field_boundary/data.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+
+import numpy as np
+import rasterio
+
+dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12")
+all_bands = ("B01", "B02", "B03", "B04")
+
+SIZE = 32
+NUM_SAMPLES = 5
+np.random.seed(0)
+
+
+def create_mask(fn: str) -> None:
+ profile = {
+ "driver": "GTiff",
+ "dtype": "uint8",
+ "nodata": 0.0,
+ "width": SIZE,
+ "height": SIZE,
+ "count": 1,
+ "crs": "epsg:3857",
+ "compress": "lzw",
+ "predictor": 2,
+ "transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
+ "blockysize": 32,
+ "tiled": False,
+ "interleave": "band",
+ }
+ with rasterio.open(fn, "w", **profile) as f:
+ f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1)
+
+
+def create_img(fn: str) -> None:
+ profile = {
+ "driver": "GTiff",
+ "dtype": "uint16",
+ "nodata": 0.0,
+ "width": SIZE,
+ "height": SIZE,
+ "count": 1,
+ "crs": "epsg:3857",
+ "compress": "lzw",
+ "predictor": 2,
+ "blockysize": 16,
+ "transform": rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0),
+ "tiled": False,
+ "interleave": "band",
+ }
+ with rasterio.open(fn, "w", **profile) as f:
+ f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1)
+
+
+if __name__ == "__main__":
+ # Train and test images
+ for split in ("train", "test"):
+ for i in range(NUM_SAMPLES):
+ for date in dates:
+ directory = os.path.join(
+ f"nasa_rwanda_field_boundary_competition_source_{split}",
+ f"nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}", # noqa: E501
+ )
+ os.makedirs(directory, exist_ok=True)
+ for band in all_bands:
+ create_img(os.path.join(directory, f"{band}.tif"))
+
+ # Create collections.json, this isn't used by the dataset but is checked to
+ # exist
+ with open(
+ f"nasa_rwanda_field_boundary_competition_source_{split}/collections.json",
+ "w",
+ ) as f:
+ f.write("Not used")
+
+ # Train labels
+ for i in range(NUM_SAMPLES):
+ directory = os.path.join(
+ "nasa_rwanda_field_boundary_competition_labels_train",
+ f"nasa_rwanda_field_boundary_competition_labels_train_{i:02d}",
+ )
+ os.makedirs(directory, exist_ok=True)
+ create_mask(os.path.join(directory, "raster_labels.tif"))
+
+ # Create directories and compute checksums
+ for filename in [
+ "nasa_rwanda_field_boundary_competition_source_train",
+ "nasa_rwanda_field_boundary_competition_source_test",
+ "nasa_rwanda_field_boundary_competition_labels_train",
+ ]:
+ shutil.make_archive(filename, "gztar", ".", filename)
+ # Compute checksums
+ with open(f"{filename}.tar.gz", "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"{filename}: {md5}")
diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz
new file mode 100644
index 00000000000..ffa98bb53d6
Binary files /dev/null and b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz differ
diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz
new file mode 100644
index 00000000000..a834f66bf38
Binary files /dev/null and b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz differ
diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz
new file mode 100644
index 00000000000..8239f70c200
Binary files /dev/null and b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz differ
diff --git a/tests/data/seasonet/data.py b/tests/data/seasonet/data.py
new file mode 100644
index 00000000000..6befc1fd4ca
--- /dev/null
+++ b/tests/data/seasonet/data.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import shutil
+
+import numpy as np
+import rasterio
+from rasterio.crs import CRS
+from rasterio.transform import Affine
+
+np.random.seed(0)
+
+meta = {
+ "driver": "GTiff",
+ "nodata": None,
+ "crs": CRS.from_epsg(32632),
+ "transform": Affine(10.0, 0.0, 664800.0, 0.0, -10.0, 5342400.0),
+ "compress": "zstd",
+}
+bands = ["10m_RGB", "10m_IR", "20m", "60m", "labels"]
+count = {"10m_RGB": 3, "10m_IR": 1, "20m": 6, "60m": 2, "labels": 1}
+dtype = {
+ "10m_RGB": np.uint16,
+ "10m_IR": np.uint16,
+ "20m": np.uint16,
+ "60m": np.uint16,
+ "labels": np.uint8,
+}
+size = {"10m_RGB": 120, "10m_IR": 120, "20m": 60, "60m": 20, "labels": 120}
+start = {"10m_RGB": 0, "10m_IR": 0, "20m": 0, "60m": 0, "labels": 1}
+stop = {
+ "10m_RGB": np.iinfo(np.uint16).max,
+ "10m_IR": np.iinfo(np.uint16).max,
+ "20m": np.iinfo(np.uint16).max,
+ "60m": np.iinfo(np.uint16).max,
+ "labels": 34,
+}
+
+meta_lines = [
+ "Index,Season,Grid,Latitude,Longitude,Satellite,Year,Month,Day,"
+ "Hour,Minute,Second,Clouds,Snow,Classes,SLRAUM,RTYP3,KTYP4,Path\n"
+]
+seasons = ["spring", "summer", "fall", "winter", "snow"]
+grids = [1, 2]
+name_comps = [
+ ["32UME", "2018", "04", "18", "T", "10", "40", "21", "53", "928425", "7", "503876"],
+ ["32TMT", "2019", "02", "14", "T", "10", "31", "29", "47", "793488", "7", "808487"],
+]
+index = 0
+for season in seasons:
+ # Remove old data
+ if os.path.exists(season):
+ shutil.rmtree(season)
+
+ archive = f"{season}.zip"
+
+ # Remove old data
+ if os.path.exists(archive):
+ os.remove(archive)
+
+ for grid, comp in zip(grids, name_comps):
+ file_name = f"{comp[0]}_{''.join(comp[1:8])}_{'_'.join(comp[8:])}"
+ dir = os.path.join(season, f"grid{grid}", file_name)
+ os.makedirs(dir)
+
+ # Random images
+ for band in bands:
+ meta["count"] = count[band]
+ meta["dtype"] = dtype[band]
+ meta["width"] = meta["height"] = size[band]
+ with rasterio.open(
+ os.path.join(dir, f"{file_name}_{band}.tif"), "w", **meta
+ ) as f:
+ for j in range(1, count[band] + 1):
+ data = np.random.randint(
+ start[band], stop[band], size=(size[band], size[band])
+ ).astype(dtype[band])
+ f.write(data, j)
+
+ # Generate meta.csv lines
+ meta_entries = [
+ index,
+ season.capitalize(),
+ grid,
+ f"{comp[8]}.{comp[9]}",
+ f"{comp[10]}.{comp[11]}",
+ "A",
+ comp[1],
+ comp[2],
+ comp[3],
+ comp[5],
+ comp[6],
+ comp[7],
+ 0.0,
+ 0.0,
+ "'2,3,12,15,17'",
+ 1,
+ 1,
+ 1,
+ dir,
+ ]
+ meta_lines.append(",".join(map(str, meta_entries)) + "\n")
+ index += 1
+
+ # Create archives
+ shutil.make_archive(season, "zip", ".", season)
+
+ # Compute checksums
+ with open(archive, "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"{season}: {repr(md5)}")
+
+# Write meta.csv
+with open("meta.csv", "w") as f:
+ f.writelines(meta_lines)
+
+# Compute checksums
+with open("meta.csv", "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"meta.csv: {repr(md5)}")
+
+os.makedirs("splits", exist_ok=True)
+
+for split in ["train", "val", "test"]:
+ filename = f"{split}.csv"
+
+ # Create file list
+ with open(os.path.join("splits", filename), "w") as f:
+ for i in range(index):
+ f.write(str(i) + "\n")
+
+shutil.make_archive("splits", "zip", ".", "splits")
+
+# Compute checksums
+with open("splits.zip", "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"splits: {repr(md5)}")
diff --git a/tests/data/seasonet/fall.zip b/tests/data/seasonet/fall.zip
new file mode 100644
index 00000000000..b30752fca77
Binary files /dev/null and b/tests/data/seasonet/fall.zip differ
diff --git a/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif
new file mode 100644
index 00000000000..4454f9c51ed
Binary files /dev/null and b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif differ
diff --git a/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif
new file mode 100644
index 00000000000..2add2712b59
Binary files /dev/null and b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif differ
diff --git a/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif
new file mode 100644
index 00000000000..ddf22ac829d
Binary files /dev/null and b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif differ
diff --git a/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif
new file mode 100644
index 00000000000..d9cda50cb0a
Binary files /dev/null and b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif differ
diff --git a/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif
new file mode 100644
index 00000000000..99647b1da03
Binary files /dev/null and b/tests/data/seasonet/fall/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif differ
diff --git a/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif
new file mode 100644
index 00000000000..03a4446f239
Binary files /dev/null and b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif differ
diff --git a/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif
new file mode 100644
index 00000000000..f2f828cf285
Binary files /dev/null and b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif differ
diff --git a/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif
new file mode 100644
index 00000000000..8485cb97d81
Binary files /dev/null and b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif differ
diff --git a/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif
new file mode 100644
index 00000000000..743ab48844f
Binary files /dev/null and b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif differ
diff --git a/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif
new file mode 100644
index 00000000000..382429ee46e
Binary files /dev/null and b/tests/data/seasonet/fall/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif differ
diff --git a/tests/data/seasonet/meta.csv b/tests/data/seasonet/meta.csv
new file mode 100644
index 00000000000..fa4d6d0a21e
--- /dev/null
+++ b/tests/data/seasonet/meta.csv
@@ -0,0 +1,11 @@
+Index,Season,Grid,Latitude,Longitude,Satellite,Year,Month,Day,Hour,Minute,Second,Clouds,Snow,Classes,SLRAUM,RTYP3,KTYP4,Path
+0,Spring,1,53.928425,7.503876,A,2018,04,18,10,40,21,0.0,0.0,"2,3,12,15,17",1,1,1,spring/grid1/32UME_20180418T104021_53_928425_7_503876
+1,Spring,2,47.793488,7.808487,A,2019,02,14,10,31,29,0.0,0.0,"2,3,12,15,17",1,1,1,spring/grid2/32TMT_20190214T103129_47_793488_7_808487
+2,Summer,1,53.928425,7.503876,A,2018,04,18,10,40,21,0.0,0.0,"2,3,12,15,17",1,1,1,summer/grid1/32UME_20180418T104021_53_928425_7_503876
+3,Summer,2,47.793488,7.808487,A,2019,02,14,10,31,29,0.0,0.0,"2,3,12,15,17",1,1,1,summer/grid2/32TMT_20190214T103129_47_793488_7_808487
+4,Fall,1,53.928425,7.503876,A,2018,04,18,10,40,21,0.0,0.0,"2,3,12,15,17",1,1,1,fall/grid1/32UME_20180418T104021_53_928425_7_503876
+5,Fall,2,47.793488,7.808487,A,2019,02,14,10,31,29,0.0,0.0,"2,3,12,15,17",1,1,1,fall/grid2/32TMT_20190214T103129_47_793488_7_808487
+6,Winter,1,53.928425,7.503876,A,2018,04,18,10,40,21,0.0,0.0,"2,3,12,15,17",1,1,1,winter/grid1/32UME_20180418T104021_53_928425_7_503876
+7,Winter,2,47.793488,7.808487,A,2019,02,14,10,31,29,0.0,0.0,"2,3,12,15,17",1,1,1,winter/grid2/32TMT_20190214T103129_47_793488_7_808487
+8,Snow,1,53.928425,7.503876,A,2018,04,18,10,40,21,0.0,0.0,"2,3,12,15,17",1,1,1,snow/grid1/32UME_20180418T104021_53_928425_7_503876
+9,Snow,2,47.793488,7.808487,A,2019,02,14,10,31,29,0.0,0.0,"2,3,12,15,17",1,1,1,snow/grid2/32TMT_20190214T103129_47_793488_7_808487
diff --git a/tests/data/seasonet/snow.zip b/tests/data/seasonet/snow.zip
new file mode 100644
index 00000000000..4642bbdf2a9
Binary files /dev/null and b/tests/data/seasonet/snow.zip differ
diff --git a/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif
new file mode 100644
index 00000000000..3db41931a7e
Binary files /dev/null and b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif differ
diff --git a/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif
new file mode 100644
index 00000000000..f0d65716c85
Binary files /dev/null and b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif differ
diff --git a/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif
new file mode 100644
index 00000000000..d69f2d74f21
Binary files /dev/null and b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif differ
diff --git a/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif
new file mode 100644
index 00000000000..7a851cbfe07
Binary files /dev/null and b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif differ
diff --git a/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif
new file mode 100644
index 00000000000..ba46245e72b
Binary files /dev/null and b/tests/data/seasonet/snow/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif differ
diff --git a/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif
new file mode 100644
index 00000000000..a67b83e2059
Binary files /dev/null and b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif differ
diff --git a/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif
new file mode 100644
index 00000000000..64d2981ea2d
Binary files /dev/null and b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif differ
diff --git a/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif
new file mode 100644
index 00000000000..6e8a841b682
Binary files /dev/null and b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif differ
diff --git a/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif
new file mode 100644
index 00000000000..c38d4a7a32e
Binary files /dev/null and b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif differ
diff --git a/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif
new file mode 100644
index 00000000000..5da046a16f2
Binary files /dev/null and b/tests/data/seasonet/snow/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif differ
diff --git a/tests/data/seasonet/splits.zip b/tests/data/seasonet/splits.zip
new file mode 100644
index 00000000000..dbd10acd67d
Binary files /dev/null and b/tests/data/seasonet/splits.zip differ
diff --git a/tests/data/seasonet/splits/test.csv b/tests/data/seasonet/splits/test.csv
new file mode 100644
index 00000000000..8b1acc12b63
--- /dev/null
+++ b/tests/data/seasonet/splits/test.csv
@@ -0,0 +1,10 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
diff --git a/tests/data/seasonet/splits/train.csv b/tests/data/seasonet/splits/train.csv
new file mode 100644
index 00000000000..8b1acc12b63
--- /dev/null
+++ b/tests/data/seasonet/splits/train.csv
@@ -0,0 +1,10 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
diff --git a/tests/data/seasonet/splits/val.csv b/tests/data/seasonet/splits/val.csv
new file mode 100644
index 00000000000..8b1acc12b63
--- /dev/null
+++ b/tests/data/seasonet/splits/val.csv
@@ -0,0 +1,10 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
diff --git a/tests/data/seasonet/spring.zip b/tests/data/seasonet/spring.zip
new file mode 100644
index 00000000000..6a1cd0535c2
Binary files /dev/null and b/tests/data/seasonet/spring.zip differ
diff --git a/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif
new file mode 100644
index 00000000000..62903e11e0a
Binary files /dev/null and b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif differ
diff --git a/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif
new file mode 100644
index 00000000000..8d6faf37776
Binary files /dev/null and b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif differ
diff --git a/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif
new file mode 100644
index 00000000000..981afae0807
Binary files /dev/null and b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif differ
diff --git a/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif
new file mode 100644
index 00000000000..f31756fec57
Binary files /dev/null and b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif differ
diff --git a/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif
new file mode 100644
index 00000000000..dc16e540721
Binary files /dev/null and b/tests/data/seasonet/spring/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif differ
diff --git a/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif
new file mode 100644
index 00000000000..1814a8e956c
Binary files /dev/null and b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif differ
diff --git a/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif
new file mode 100644
index 00000000000..47deecd61f7
Binary files /dev/null and b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif differ
diff --git a/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif
new file mode 100644
index 00000000000..9c45a93a96a
Binary files /dev/null and b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif differ
diff --git a/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif
new file mode 100644
index 00000000000..12b39d63a34
Binary files /dev/null and b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif differ
diff --git a/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif
new file mode 100644
index 00000000000..6f9b131ad90
Binary files /dev/null and b/tests/data/seasonet/spring/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif differ
diff --git a/tests/data/seasonet/summer.zip b/tests/data/seasonet/summer.zip
new file mode 100644
index 00000000000..180ba2bf063
Binary files /dev/null and b/tests/data/seasonet/summer.zip differ
diff --git a/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif
new file mode 100644
index 00000000000..a384ba122c2
Binary files /dev/null and b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif differ
diff --git a/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif
new file mode 100644
index 00000000000..8b24ff29416
Binary files /dev/null and b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif differ
diff --git a/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif
new file mode 100644
index 00000000000..f779ef1615f
Binary files /dev/null and b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif differ
diff --git a/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif
new file mode 100644
index 00000000000..198fa95c52d
Binary files /dev/null and b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif differ
diff --git a/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif
new file mode 100644
index 00000000000..62ef110e306
Binary files /dev/null and b/tests/data/seasonet/summer/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif differ
diff --git a/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif
new file mode 100644
index 00000000000..ee77e161fe6
Binary files /dev/null and b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif differ
diff --git a/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif
new file mode 100644
index 00000000000..4ef2245f489
Binary files /dev/null and b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif differ
diff --git a/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif
new file mode 100644
index 00000000000..41f2e355c1f
Binary files /dev/null and b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif differ
diff --git a/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif
new file mode 100644
index 00000000000..a8fad9b0f38
Binary files /dev/null and b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif differ
diff --git a/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif
new file mode 100644
index 00000000000..66ea885b1af
Binary files /dev/null and b/tests/data/seasonet/summer/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif differ
diff --git a/tests/data/seasonet/winter.zip b/tests/data/seasonet/winter.zip
new file mode 100644
index 00000000000..17b9df95bca
Binary files /dev/null and b/tests/data/seasonet/winter.zip differ
diff --git a/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif
new file mode 100644
index 00000000000..d4aad39666c
Binary files /dev/null and b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_IR.tif differ
diff --git a/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif
new file mode 100644
index 00000000000..8672dcdc9e6
Binary files /dev/null and b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_10m_RGB.tif differ
diff --git a/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif
new file mode 100644
index 00000000000..de24812de6b
Binary files /dev/null and b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_20m.tif differ
diff --git a/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif
new file mode 100644
index 00000000000..6d249979336
Binary files /dev/null and b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_60m.tif differ
diff --git a/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif
new file mode 100644
index 00000000000..41584a2f7ba
Binary files /dev/null and b/tests/data/seasonet/winter/grid1/32UME_20180418T104021_53_928425_7_503876/32UME_20180418T104021_53_928425_7_503876_labels.tif differ
diff --git a/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif
new file mode 100644
index 00000000000..54e90c702cf
Binary files /dev/null and b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_IR.tif differ
diff --git a/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif
new file mode 100644
index 00000000000..0d28c3e2fea
Binary files /dev/null and b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_10m_RGB.tif differ
diff --git a/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif
new file mode 100644
index 00000000000..24b48e6fd34
Binary files /dev/null and b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_20m.tif differ
diff --git a/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif
new file mode 100644
index 00000000000..ce355168067
Binary files /dev/null and b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_60m.tif differ
diff --git a/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif
new file mode 100644
index 00000000000..fa908a81e7c
Binary files /dev/null and b/tests/data/seasonet/winter/grid2/32TMT_20190214T103129_47_793488_7_808487/32TMT_20190214T103129_47_793488_7_808487_labels.tif differ
diff --git a/tests/data/skippd/2017_2019_images_pv_processed_forecast.hdf5 b/tests/data/skippd/2017_2019_images_pv_processed_forecast.hdf5
new file mode 100644
index 00000000000..6b47611e450
Binary files /dev/null and b/tests/data/skippd/2017_2019_images_pv_processed_forecast.hdf5 differ
diff --git a/tests/data/skippd/2017_2019_images_pv_processed_forecast.zip b/tests/data/skippd/2017_2019_images_pv_processed_forecast.zip
new file mode 100644
index 00000000000..d47d251da01
Binary files /dev/null and b/tests/data/skippd/2017_2019_images_pv_processed_forecast.zip differ
diff --git a/tests/data/skippd/dj417rh1007/2017_2019_images_pv_processed.hdf5 b/tests/data/skippd/2017_2019_images_pv_processed_nowcast.hdf5
similarity index 99%
rename from tests/data/skippd/dj417rh1007/2017_2019_images_pv_processed.hdf5
rename to tests/data/skippd/2017_2019_images_pv_processed_nowcast.hdf5
index 4caafc43891..e13e29638ed 100644
Binary files a/tests/data/skippd/dj417rh1007/2017_2019_images_pv_processed.hdf5 and b/tests/data/skippd/2017_2019_images_pv_processed_nowcast.hdf5 differ
diff --git a/tests/data/skippd/2017_2019_images_pv_processed_nowcast.zip b/tests/data/skippd/2017_2019_images_pv_processed_nowcast.zip
new file mode 100644
index 00000000000..2bf3d5d1ac0
Binary files /dev/null and b/tests/data/skippd/2017_2019_images_pv_processed_nowcast.zip differ
diff --git a/tests/data/skippd/data.py b/tests/data/skippd/data.py
index 86032939f70..e717c2ce025 100755
--- a/tests/data/skippd/data.py
+++ b/tests/data/skippd/data.py
@@ -4,8 +4,7 @@
# Licensed under the MIT License.
import hashlib
-import os
-import shutil
+import zipfile
from datetime import datetime, timedelta
import h5py
@@ -17,43 +16,62 @@
NUM_SAMPLES = 3
NUM_CHANNELS = 3
SIZE = 64
+TIME_STEPS = 16
np.random.seed(0)
-data_dir = "dj417rh1007"
-data_file = "2017_2019_images_pv_processed.hdf5"
+tasks = ["nowcast", "forecast"]
+data_file = "2017_2019_images_pv_processed_{}.hdf5"
splits = ["trainval", "test"]
+
# Create dataset file
-data = np.random.randint(
- RGB_MAX, size=(NUM_SAMPLES, SIZE, SIZE, NUM_CHANNELS), dtype=np.int16
-)
-labels = np.random.random(size=(NUM_SAMPLES))
-if __name__ == "__main__":
- # Remove old data
- if os.path.exists(data_dir):
- shutil.rmtree(data_dir)
+data = {
+ "nowcast": np.random.randint(
+ RGB_MAX, size=(NUM_SAMPLES, SIZE, SIZE, NUM_CHANNELS), dtype=np.int16
+ ),
+ "forecast": np.random.randint(
+ RGB_MAX,
+ size=(NUM_SAMPLES, TIME_STEPS, SIZE, SIZE, NUM_CHANNELS),
+ dtype=np.int16,
+ ),
+}
+
+
+labels = {
+ "nowcast": np.random.random(size=(NUM_SAMPLES)),
+ "forecast": np.random.random(size=(NUM_SAMPLES, TIME_STEPS)),
+}
- os.makedirs(data_dir)
- with h5py.File(os.path.join(data_dir, data_file), "w") as f:
+if __name__ == "__main__":
+ for task in tasks:
+ with h5py.File(data_file.format(task), "w") as f:
+ for split in splits:
+ grp = f.create_group(split)
+ grp.create_dataset("images_log", data=data[task])
+ grp.create_dataset("pv_log", data=labels[task])
+
+ # create time stamps
for split in splits:
- grp = f.create_group(split)
- grp.create_dataset("images_log", data=data)
- grp.create_dataset("pv_log", data=labels)
-
- # create time stamps
- for split in splits:
- time_stamps = np.array(
- [datetime.now() - timedelta(days=i) for i in range(NUM_SAMPLES)]
- )
- np.save(os.path.join(data_dir, f"times_{split}.npy"), time_stamps)
-
- # Compress data
- shutil.make_archive(data_dir, "zip", ".", data_dir)
-
- # Compute checksums
- with open(data_dir + ".zip", "rb") as f:
- md5 = hashlib.md5(f.read()).hexdigest()
- print(f"{data_dir}.zip: {md5}")
+ time_stamps = np.array(
+ [datetime.now() - timedelta(days=i) for i in range(NUM_SAMPLES)]
+ )
+ np.save(f"times_{split}_{task}.npy", time_stamps)
+
+ # Compress data
+ with zipfile.ZipFile(
+ data_file.format(task).replace(".hdf5", ".zip"), "w"
+ ) as zip:
+ for file in [
+ data_file.format(task),
+ f"times_trainval_{task}.npy",
+ f"times_test_{task}.npy",
+ ]:
+ zip.write(file, arcname=file)
+
+ # Compute checksums
+ with open(data_file.format(task).replace(".hdf5", ".zip"), "rb") as f:
+ md5 = hashlib.md5(f.read()).hexdigest()
+ print(f"{task}: {md5}")
diff --git a/tests/data/skippd/dj417rh1007.zip b/tests/data/skippd/dj417rh1007.zip
deleted file mode 100644
index 1bfcca8ca63..00000000000
Binary files a/tests/data/skippd/dj417rh1007.zip and /dev/null differ
diff --git a/tests/data/skippd/dj417rh1007/times_test.npy b/tests/data/skippd/times_test_forecast.npy
similarity index 81%
rename from tests/data/skippd/dj417rh1007/times_test.npy
rename to tests/data/skippd/times_test_forecast.npy
index 30119e9a2a4..296f9bfb65b 100644
Binary files a/tests/data/skippd/dj417rh1007/times_test.npy and b/tests/data/skippd/times_test_forecast.npy differ
diff --git a/tests/data/skippd/dj417rh1007/times_trainval.npy b/tests/data/skippd/times_test_nowcast.npy
similarity index 81%
rename from tests/data/skippd/dj417rh1007/times_trainval.npy
rename to tests/data/skippd/times_test_nowcast.npy
index 34676ac40d7..61b2ea91122 100644
Binary files a/tests/data/skippd/dj417rh1007/times_trainval.npy and b/tests/data/skippd/times_test_nowcast.npy differ
diff --git a/tests/data/skippd/times_trainval_forecast.npy b/tests/data/skippd/times_trainval_forecast.npy
new file mode 100644
index 00000000000..f0d680cef9c
Binary files /dev/null and b/tests/data/skippd/times_trainval_forecast.npy differ
diff --git a/tests/data/skippd/times_trainval_nowcast.npy b/tests/data/skippd/times_trainval_nowcast.npy
new file mode 100644
index 00000000000..d2fb5e94a00
Binary files /dev/null and b/tests/data/skippd/times_trainval_nowcast.npy differ
diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py
index 1fd6eb8983d..ef856eddee3 100644
--- a/tests/datamodules/test_geo.py
+++ b/tests/datamodules/test_geo.py
@@ -8,6 +8,7 @@
import torch
from _pytest.fixtures import SubRequest
from lightning.pytorch import Trainer
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
@@ -33,7 +34,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query}
- def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
+ def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()
@@ -72,7 +73,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
def __len__(self) -> int:
return self.length
- def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
+ def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()
diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py
new file mode 100644
index 00000000000..8e67152346a
--- /dev/null
+++ b/tests/datamodules/test_levircd.py
@@ -0,0 +1,78 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import shutil
+from pathlib import Path
+
+import pytest
+from lightning.pytorch import Trainer
+from pytest import MonkeyPatch
+
+import torchgeo.datasets.utils
+from torchgeo.datamodules import LEVIRCDPlusDataModule
+from torchgeo.datasets import LEVIRCDPlus
+
+
+def download_url(url: str, root: str, *args: str) -> None:
+ shutil.copy(url, root)
+
+
+class TestLEVIRCDPlusDataModule:
+ @pytest.fixture
+ def datamodule(
+ self, monkeypatch: MonkeyPatch, tmp_path: Path
+ ) -> LEVIRCDPlusDataModule:
+ monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
+ md5 = "1adf156f628aa32fb2e8fe6cada16c04"
+ monkeypatch.setattr(LEVIRCDPlus, "md5", md5)
+ url = os.path.join("tests", "data", "levircd", "LEVIR-CD+.zip")
+ monkeypatch.setattr(LEVIRCDPlus, "url", url)
+
+ root = str(tmp_path)
+ dm = LEVIRCDPlusDataModule(
+ root=root, download=True, num_workers=0, checksum=True, val_split_pct=0.5
+ )
+ dm.prepare_data()
+ dm.trainer = Trainer(accelerator="cpu", max_epochs=1)
+ return dm
+
+ def test_train_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None:
+ datamodule.setup("fit")
+ if datamodule.trainer:
+ datamodule.trainer.training = True
+ batch = next(iter(datamodule.train_dataloader()))
+ batch = datamodule.on_after_batch_transfer(batch, 0)
+ assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
+ assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8
+ assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
+ assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8
+ assert batch["image1"].shape[1] == 3
+ assert batch["image2"].shape[1] == 3
+
+ def test_val_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None:
+ datamodule.setup("validate")
+ if datamodule.trainer:
+ datamodule.trainer.validating = True
+ batch = next(iter(datamodule.val_dataloader()))
+ batch = datamodule.on_after_batch_transfer(batch, 0)
+ if datamodule.val_split_pct > 0.0:
+ assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
+ assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8
+ assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
+ assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8
+ assert batch["image1"].shape[1] == 3
+ assert batch["image2"].shape[1] == 3
+
+ def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None:
+ datamodule.setup("test")
+ if datamodule.trainer:
+ datamodule.trainer.testing = True
+ batch = next(iter(datamodule.test_dataloader()))
+ batch = datamodule.on_after_batch_transfer(batch, 0)
+ assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
+ assert batch["image1"].shape[0] == batch["mask"].shape[0] == 8
+ assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (256, 256)
+ assert batch["image2"].shape[0] == batch["mask"].shape[0] == 8
+ assert batch["image1"].shape[1] == 3
+ assert batch["image2"].shape[1] == 3
diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py
index 10e890d044b..0c009a8ec2f 100644
--- a/tests/datamodules/test_oscd.py
+++ b/tests/datamodules/test_oscd.py
@@ -8,10 +8,11 @@
from lightning.pytorch import Trainer
from torchgeo.datamodules import OSCDDataModule
+from torchgeo.datasets import OSCD
class TestOSCDDataModule:
- @pytest.fixture(params=["all", "rgb"])
+ @pytest.fixture(params=[OSCD.all_bands, OSCD.rgb_bands])
def datamodule(self, request: SubRequest) -> OSCDDataModule:
bands = request.param
root = os.path.join("tests", "data", "oscd")
@@ -34,12 +35,16 @@ def test_train_dataloader(self, datamodule: OSCDDataModule) -> None:
datamodule.trainer.training = True
batch = next(iter(datamodule.train_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)
- assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
- assert batch["image"].shape[0] == batch["mask"].shape[0] == 1
- if datamodule.bands == "all":
- assert batch["image"].shape[1] == 26
+ assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
+ assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1
+ assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
+ assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1
+ if datamodule.bands == OSCD.all_bands:
+ assert batch["image1"].shape[1] == 13
+ assert batch["image2"].shape[1] == 13
else:
- assert batch["image"].shape[1] == 6
+ assert batch["image1"].shape[1] == 3
+ assert batch["image2"].shape[1] == 3
def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
datamodule.setup("validate")
@@ -48,12 +53,16 @@ def test_val_dataloader(self, datamodule: OSCDDataModule) -> None:
batch = next(iter(datamodule.val_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)
if datamodule.val_split_pct > 0.0:
- assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
- assert batch["image"].shape[0] == batch["mask"].shape[0] == 1
- if datamodule.bands == "all":
- assert batch["image"].shape[1] == 26
+ assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
+ assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1
+ assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
+ assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1
+ if datamodule.bands == OSCD.all_bands:
+ assert batch["image1"].shape[1] == 13
+ assert batch["image2"].shape[1] == 13
else:
- assert batch["image"].shape[1] == 6
+ assert batch["image1"].shape[1] == 3
+ assert batch["image2"].shape[1] == 3
def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
datamodule.setup("test")
@@ -61,9 +70,13 @@ def test_test_dataloader(self, datamodule: OSCDDataModule) -> None:
datamodule.trainer.testing = True
batch = next(iter(datamodule.test_dataloader()))
batch = datamodule.on_after_batch_transfer(batch, 0)
- assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
- assert batch["image"].shape[0] == batch["mask"].shape[0] == 1
- if datamodule.bands == "all":
- assert batch["image"].shape[1] == 26
+ assert batch["image1"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
+ assert batch["image1"].shape[0] == batch["mask"].shape[0] == 1
+ assert batch["image2"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2)
+ assert batch["image2"].shape[0] == batch["mask"].shape[0] == 1
+ if datamodule.bands == OSCD.all_bands:
+ assert batch["image1"].shape[1] == 13
+ assert batch["image2"].shape[1] == 13
else:
- assert batch["image"].shape[1] == 6
+ assert batch["image1"].shape[1] == 3
+ assert batch["image2"].shape[1] == 3
diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py
index abbffbf47c5..bcbaba2500f 100644
--- a/tests/datasets/test_advance.py
+++ b/tests/datasets/test_advance.py
@@ -14,7 +14,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import ADVANCE
+from torchgeo.datasets import ADVANCE, DatasetNotFoundError
def download_url(url: str, root: str, *args: str) -> None:
@@ -68,7 +68,7 @@ def test_already_downloaded(self, dataset: ADVANCE) -> None:
ADVANCE(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ADVANCE(str(tmp_path))
def test_mock_missing_module(
@@ -81,6 +81,7 @@ def test_mock_missing_module(
dataset[0]
def test_plot(self, dataset: ADVANCE) -> None:
+ pytest.importorskip("scipy", minversion="1.6.2")
x = dataset[0].copy()
dataset.plot(x, suptitle="Test")
plt.close()
diff --git a/tests/datasets/test_agb_live_woody_density.py b/tests/datasets/test_agb_live_woody_density.py
index 7800aecc1b1..3e0bbbc2dc7 100644
--- a/tests/datasets/test_agb_live_woody_density.py
+++ b/tests/datasets/test_agb_live_woody_density.py
@@ -15,6 +15,7 @@
import torchgeo
from torchgeo.datasets import (
AbovegroundLiveWoodyBiomassDensity,
+ DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
@@ -52,14 +53,14 @@ def test_getitem(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
- def test_no_dataset(self) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in."):
- AbovegroundLiveWoodyBiomassDensity(root="/test")
+ def test_no_dataset(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ AbovegroundLiveWoodyBiomassDensity(str(tmp_path))
def test_already_downloaded(
self, dataset: AbovegroundLiveWoodyBiomassDensity
) -> None:
- AbovegroundLiveWoodyBiomassDensity(dataset.root)
+ AbovegroundLiveWoodyBiomassDensity(dataset.paths)
def test_and(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None:
ds = dataset & dataset
diff --git a/tests/datasets/test_astergdem.py b/tests/datasets/test_astergdem.py
index 25d0940b30d..dfd41e40409 100644
--- a/tests/datasets/test_astergdem.py
+++ b/tests/datasets/test_astergdem.py
@@ -11,7 +11,13 @@
import torch.nn as nn
from rasterio.crs import CRS
-from torchgeo.datasets import AsterGDEM, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+ AsterGDEM,
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ UnionDataset,
+)
class TestAsterGDEM:
@@ -26,8 +32,8 @@ def dataset(self, tmp_path: Path) -> AsterGDEM:
def test_datasetmissing(self, tmp_path: Path) -> None:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
- with pytest.raises(RuntimeError, match="Dataset not found in"):
- AsterGDEM(root=str(tmp_path))
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ AsterGDEM(str(tmp_path))
def test_getitem(self, dataset: AsterGDEM) -> None:
x = dataset[dataset.bounds]
diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py
index 7255e491b8a..6d81d6876be 100644
--- a/tests/datasets/test_benin_cashews.py
+++ b/tests/datasets/test_benin_cashews.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
-from torchgeo.datasets import BeninSmallHolderCashews
+from torchgeo.datasets import BeninSmallHolderCashews, DatasetNotFoundError
class Collection:
@@ -73,7 +73,7 @@ def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None:
BeninSmallHolderCashews(root=dataset.root, download=True, api_key="")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
BeninSmallHolderCashews(str(tmp_path))
def test_invalid_bands(self) -> None:
diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py
index 7dfea2548cb..a0e93952244 100644
--- a/tests/datasets/test_bigearthnet.py
+++ b/tests/datasets/test_bigearthnet.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import BigEarthNet
+from torchgeo.datasets import BigEarthNet, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -134,10 +134,7 @@ def test_already_downloaded_not_extracted(
)
def test_not_downloaded(self, tmp_path: Path) -> None:
- err = "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- with pytest.raises(RuntimeError, match=err):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
BigEarthNet(str(tmp_path))
def test_plot(self, dataset: BigEarthNet) -> None:
diff --git a/tests/datasets/test_biomassters.py b/tests/datasets/test_biomassters.py
new file mode 100644
index 00000000000..17dab2df03c
--- /dev/null
+++ b/tests/datasets/test_biomassters.py
@@ -0,0 +1,50 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+
+import os
+from itertools import product
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+from _pytest.fixtures import SubRequest
+
+from torchgeo.datasets import BioMassters, DatasetNotFoundError
+
+
+class TestBioMassters:
+ @pytest.fixture(
+ params=product(["train", "test"], [["S1"], ["S2"], ["S1", "S2"]], [True, False])
+ )
+ def dataset(self, request: SubRequest) -> BioMassters:
+ root = os.path.join("tests", "data", "biomassters")
+ split, sensors, as_time_series = request.param
+ return BioMassters(
+ root, split=split, sensors=sensors, as_time_series=as_time_series
+ )
+
+ def test_len_of_ds(self, dataset: BioMassters) -> None:
+ assert len(dataset) > 0
+
+ def test_invalid_split(self, dataset: BioMassters) -> None:
+ with pytest.raises(AssertionError):
+ BioMassters(dataset.root, split="foo")
+
+ def test_invalid_bands(self, dataset: BioMassters) -> None:
+ with pytest.raises(AssertionError):
+ BioMassters(dataset.root, sensors=["S3"])
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ BioMassters(str(tmp_path))
+
+ def test_plot(self, dataset: BioMassters) -> None:
+ dataset.plot(dataset[0], suptitle="Test")
+ plt.close()
+
+ sample = dataset[0]
+ if dataset.split == "train":
+ sample["prediction"] = sample["label"]
+ dataset.plot(sample)
+ plt.close()
diff --git a/tests/datasets/test_cbf.py b/tests/datasets/test_cbf.py
index eed08281697..f53023925b5 100644
--- a/tests/datasets/test_cbf.py
+++ b/tests/datasets/test_cbf.py
@@ -16,6 +16,7 @@
from torchgeo.datasets import (
BoundingBox,
CanadianBuildingFootprints,
+ DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
@@ -61,7 +62,7 @@ def test_or(self, dataset: CanadianBuildingFootprints) -> None:
assert isinstance(ds, UnionDataset)
def test_already_downloaded(self, dataset: CanadianBuildingFootprints) -> None:
- CanadianBuildingFootprints(root=dataset.root, download=True)
+ CanadianBuildingFootprints(dataset.paths, download=True)
def test_plot(self, dataset: CanadianBuildingFootprints) -> None:
query = dataset.bounds
@@ -75,7 +76,7 @@ def test_plot_prediction(self, dataset: CanadianBuildingFootprints) -> None:
dataset.plot(x, suptitle="Prediction")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CanadianBuildingFootprints(str(tmp_path))
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
diff --git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py
index e5badeeb0fc..47d0beb8d6a 100644
--- a/tests/datasets/test_cdl.py
+++ b/tests/datasets/test_cdl.py
@@ -15,7 +15,13 @@
from rasterio.crs import CRS
import torchgeo.datasets.utils
-from torchgeo.datasets import CDL, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+ CDL,
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ UnionDataset,
+)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -74,7 +80,7 @@ def test_full_year(self, dataset: CDL) -> None:
next(dataset.index.intersection(tuple(query)))
def test_already_extracted(self, dataset: CDL) -> None:
- CDL(root=dataset.root, years=[2020, 2021])
+ CDL(dataset.paths, years=[2020, 2021])
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "cdl", "*_30m_cdls.zip")
@@ -111,7 +117,7 @@ def test_plot_prediction(self, dataset: CDL) -> None:
plt.close()
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CDL(str(tmp_path))
def test_invalid_query(self, dataset: CDL) -> None:
diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py
index 3348c1f6c95..0692e0be00f 100644
--- a/tests/datasets/test_chesapeake.py
+++ b/tests/datasets/test_chesapeake.py
@@ -18,6 +18,7 @@
BoundingBox,
Chesapeake13,
ChesapeakeCVPR,
+ DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
@@ -59,7 +60,7 @@ def test_or(self, dataset: Chesapeake13) -> None:
assert isinstance(ds, UnionDataset)
def test_already_extracted(self, dataset: Chesapeake13) -> None:
- Chesapeake13(root=dataset.root, download=True)
+ Chesapeake13(dataset.paths, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
url = os.path.join(
@@ -70,7 +71,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
Chesapeake13(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Chesapeake13(str(tmp_path), checksum=True)
def test_plot(self, dataset: Chesapeake13) -> None:
@@ -141,7 +142,7 @@ def dataset(
)
monkeypatch.setattr(
ChesapeakeCVPR,
- "files",
+ "_files",
["de_1m_2013_extended-debuffered-test_tiles", "spatial_index.geojson"],
)
root = str(tmp_path)
@@ -193,7 +194,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
ChesapeakeCVPR(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ChesapeakeCVPR(str(tmp_path), checksum=True)
def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
diff --git a/tests/datasets/test_cloud_cover.py b/tests/datasets/test_cloud_cover.py
index 15fb078d49a..68e8511b3a9 100644
--- a/tests/datasets/test_cloud_cover.py
+++ b/tests/datasets/test_cloud_cover.py
@@ -12,7 +12,7 @@
import torch.nn as nn
from pytest import MonkeyPatch
-from torchgeo.datasets import CloudCoverDetection
+from torchgeo.datasets import CloudCoverDetection, DatasetNotFoundError
class Collection:
@@ -83,7 +83,7 @@ def test_already_downloaded(self, dataset: CloudCoverDetection) -> None:
CloudCoverDetection(root=dataset.root, split="test", download=True, api_key="")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CloudCoverDetection(str(tmp_path))
def test_plot(self, dataset: CloudCoverDetection) -> None:
diff --git a/tests/datasets/test_cms_mangrove_canopy.py b/tests/datasets/test_cms_mangrove_canopy.py
index 9a30a2a0283..1ebd33a1095 100644
--- a/tests/datasets/test_cms_mangrove_canopy.py
+++ b/tests/datasets/test_cms_mangrove_canopy.py
@@ -12,7 +12,12 @@
from pytest import MonkeyPatch
from rasterio.crs import CRS
-from torchgeo.datasets import CMSGlobalMangroveCanopy, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+ CMSGlobalMangroveCanopy,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ UnionDataset,
+)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -44,9 +49,9 @@ def test_getitem(self, dataset: CMSGlobalMangroveCanopy) -> None:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
- def test_no_dataset(self) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in."):
- CMSGlobalMangroveCanopy(root="/test")
+ def test_no_dataset(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ CMSGlobalMangroveCanopy(str(tmp_path))
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join(
@@ -65,7 +70,7 @@ def test_corrupted(self, tmp_path: Path) -> None:
) as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
- CMSGlobalMangroveCanopy(root=str(tmp_path), country="Angola", checksum=True)
+ CMSGlobalMangroveCanopy(str(tmp_path), country="Angola", checksum=True)
def test_invalid_country(self) -> None:
with pytest.raises(AssertionError):
diff --git a/tests/datasets/test_cowc.py b/tests/datasets/test_cowc.py
index 6742ecfb211..19f448f5a27 100644
--- a/tests/datasets/test_cowc.py
+++ b/tests/datasets/test_cowc.py
@@ -14,8 +14,7 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import COWCCounting, COWCDetection
-from torchgeo.datasets.cowc import COWC
+from torchgeo.datasets import COWC, COWCCounting, COWCDetection, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -78,7 +77,7 @@ def test_invalid_split(self) -> None:
COWCCounting(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
COWCCounting(str(tmp_path))
def test_plot(self, dataset: COWCCounting) -> None:
@@ -142,7 +141,7 @@ def test_invalid_split(self) -> None:
COWCDetection(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
COWCDetection(str(tmp_path))
def test_plot(self, dataset: COWCDetection) -> None:
diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py
index 638c63128c3..22667efbfd4 100644
--- a/tests/datasets/test_cv4a_kenya_crop_type.py
+++ b/tests/datasets/test_cv4a_kenya_crop_type.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
-from torchgeo.datasets import CV4AKenyaCropType
+from torchgeo.datasets import CV4AKenyaCropType, DatasetNotFoundError
class Collection:
@@ -84,7 +84,7 @@ def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None:
CV4AKenyaCropType(root=dataset.root, download=True, api_key="")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CV4AKenyaCropType(str(tmp_path))
def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None:
diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py
index 56788bab37f..6ab894c1fb7 100644
--- a/tests/datasets/test_cyclone.py
+++ b/tests/datasets/test_cyclone.py
@@ -14,7 +14,7 @@
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
-from torchgeo.datasets import TropicalCyclone
+from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone
class Collection:
@@ -80,7 +80,7 @@ def test_invalid_split(self) -> None:
TropicalCyclone(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
TropicalCyclone(str(tmp_path))
def test_plot(self, dataset: TropicalCyclone) -> None:
diff --git a/tests/datasets/test_deepglobelandcover.py b/tests/datasets/test_deepglobelandcover.py
index da243efc944..1ab9b70b2d1 100644
--- a/tests/datasets/test_deepglobelandcover.py
+++ b/tests/datasets/test_deepglobelandcover.py
@@ -12,7 +12,7 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
-from torchgeo.datasets import DeepGlobeLandCover
+from torchgeo.datasets import DatasetNotFoundError, DeepGlobeLandCover
class TestDeepGlobeLandCover:
@@ -55,12 +55,7 @@ def test_invalid_split(self) -> None:
DeepGlobeLandCover(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(
- RuntimeError,
- match="Dataset not found in `root`, either"
- + " specify a different `root` directory or manually download"
- + " the dataset to this directory.",
- ):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
DeepGlobeLandCover(str(tmp_path))
def test_plot(self, dataset: DeepGlobeLandCover) -> None:
diff --git a/tests/datasets/test_dfc2022.py b/tests/datasets/test_dfc2022.py
index 4b1a5221506..22caebcfb7b 100644
--- a/tests/datasets/test_dfc2022.py
+++ b/tests/datasets/test_dfc2022.py
@@ -12,7 +12,7 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
-from torchgeo.datasets import DFC2022
+from torchgeo.datasets import DFC2022, DatasetNotFoundError
class TestDFC2022:
@@ -74,7 +74,7 @@ def test_invalid_split(self) -> None:
DFC2022(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
DFC2022(str(tmp_path))
def test_plot(self, dataset: DFC2022) -> None:
diff --git a/tests/datasets/test_eddmaps.py b/tests/datasets/test_eddmaps.py
index 59b6c8f5721..a15adbeecaf 100644
--- a/tests/datasets/test_eddmaps.py
+++ b/tests/datasets/test_eddmaps.py
@@ -1,17 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
from pathlib import Path
-from typing import Any
import pytest
-from pytest import MonkeyPatch
-from torchgeo.datasets import BoundingBox, EDDMapS, IntersectionDataset, UnionDataset
-
-pytest.importorskip("pandas", minversion="1.1.3")
+from torchgeo.datasets import (
+ BoundingBox,
+ DatasetNotFoundError,
+ EDDMapS,
+ IntersectionDataset,
+ UnionDataset,
+)
class TestEDDMapS:
@@ -36,29 +37,9 @@ def test_or(self, dataset: EDDMapS) -> None:
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
EDDMapS(str(tmp_path))
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == "pandas":
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, "__import__", mocked_import)
-
- def test_mock_missing_module(
- self, dataset: EDDMapS, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match="pandas is not installed and is required to use this dataset",
- ):
- EDDMapS(dataset.root)
-
def test_invalid_query(self, dataset: EDDMapS) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
diff --git a/tests/datasets/test_enviroatlas.py b/tests/datasets/test_enviroatlas.py
index 53760f3a59f..da7641b47a8 100644
--- a/tests/datasets/test_enviroatlas.py
+++ b/tests/datasets/test_enviroatlas.py
@@ -16,6 +16,7 @@
import torchgeo.datasets.utils
from torchgeo.datasets import (
BoundingBox,
+ DatasetNotFoundError,
EnviroAtlas,
IntersectionDataset,
UnionDataset,
@@ -47,7 +48,7 @@ def dataset(
)
monkeypatch.setattr(
EnviroAtlas,
- "files",
+ "_files",
["pittsburgh_pa-2010_1m-train_tiles-debuffered", "spatial_index.geojson"],
)
root = str(tmp_path)
@@ -88,7 +89,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
EnviroAtlas(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
EnviroAtlas(str(tmp_path), checksum=True)
def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None:
diff --git a/tests/datasets/test_esri2020.py b/tests/datasets/test_esri2020.py
index 74f9200cb5a..1e01e0ac11d 100644
--- a/tests/datasets/test_esri2020.py
+++ b/tests/datasets/test_esri2020.py
@@ -13,7 +13,13 @@
from rasterio.crs import CRS
import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, Esri2020, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+ BoundingBox,
+ DatasetNotFoundError,
+ Esri2020,
+ IntersectionDataset,
+ UnionDataset,
+)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -47,7 +53,7 @@ def test_getitem(self, dataset: Esri2020) -> None:
assert isinstance(x["mask"], torch.Tensor)
def test_already_extracted(self, dataset: Esri2020) -> None:
- Esri2020(root=dataset.root, download=True)
+ Esri2020(dataset.paths, download=True)
def test_not_extracted(self, tmp_path: Path) -> None:
url = os.path.join(
@@ -57,10 +63,10 @@ def test_not_extracted(self, tmp_path: Path) -> None:
"io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip",
)
shutil.copy(url, tmp_path)
- Esri2020(root=str(tmp_path))
+ Esri2020(str(tmp_path))
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Esri2020(str(tmp_path), checksum=True)
def test_and(self, dataset: Esri2020) -> None:
diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py
index c386005f182..8ee695bbcab 100644
--- a/tests/datasets/test_etci2021.py
+++ b/tests/datasets/test_etci2021.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import ETCI2021
+from torchgeo.datasets import ETCI2021, DatasetNotFoundError
def download_url(url: str, root: str, *args: str) -> None:
@@ -77,7 +77,7 @@ def test_invalid_split(self) -> None:
ETCI2021(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ETCI2021(str(tmp_path))
def test_plot(self, dataset: ETCI2021) -> None:
diff --git a/tests/datasets/test_eudem.py b/tests/datasets/test_eudem.py
index fd04f541e9a..9816ea3c84d 100644
--- a/tests/datasets/test_eudem.py
+++ b/tests/datasets/test_eudem.py
@@ -12,7 +12,13 @@
from pytest import MonkeyPatch
from rasterio.crs import CRS
-from torchgeo.datasets import EUDEM, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+ EUDEM,
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ UnionDataset,
+)
class TestEUDEM:
@@ -33,21 +39,22 @@ def test_getitem(self, dataset: EUDEM) -> None:
assert isinstance(x["mask"], torch.Tensor)
def test_extracted_already(self, dataset: EUDEM) -> None:
- zipfile = os.path.join(dataset.root, "eu_dem_v11_E30N10.zip")
- shutil.unpack_archive(zipfile, dataset.root, "zip")
- EUDEM(dataset.root)
+ assert isinstance(dataset.paths, str)
+ zipfile = os.path.join(dataset.paths, "eu_dem_v11_E30N10.zip")
+ shutil.unpack_archive(zipfile, dataset.paths, "zip")
+ EUDEM(dataset.paths)
def test_no_dataset(self, tmp_path: Path) -> None:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
- with pytest.raises(RuntimeError, match="Dataset not found in"):
- EUDEM(root=str(tmp_path))
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ EUDEM(str(tmp_path))
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "eu_dem_v11_E30N10.zip"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
- EUDEM(root=str(tmp_path), checksum=True)
+ EUDEM(str(tmp_path), checksum=True)
def test_and(self, dataset: EUDEM) -> None:
ds = dataset & dataset
diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py
index c79b1f80bc6..5d92498b222 100644
--- a/tests/datasets/test_eurosat.py
+++ b/tests/datasets/test_eurosat.py
@@ -15,7 +15,7 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import EuroSAT, EuroSAT100
+from torchgeo.datasets import DatasetNotFoundError, EuroSAT, EuroSAT100
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -92,10 +92,7 @@ def test_already_downloaded_not_extracted(
EuroSAT(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
- err = "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- with pytest.raises(RuntimeError, match=err):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
EuroSAT(str(tmp_path))
def test_plot(self, dataset: EuroSAT) -> None:
diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py
index 0983444ad82..fcb7d4f7711 100644
--- a/tests/datasets/test_fair1m.py
+++ b/tests/datasets/test_fair1m.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import FAIR1M
+from torchgeo.datasets import FAIR1M, DatasetNotFoundError
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
@@ -120,7 +120,7 @@ def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None:
def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None:
shutil.rmtree(str(tmp_path))
- with pytest.raises(RuntimeError, match="Dataset not found in"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
FAIR1M(root=str(tmp_path), split=dataset.split)
def test_plot(self, dataset: FAIR1M) -> None:
diff --git a/tests/datasets/test_fire_risk.py b/tests/datasets/test_fire_risk.py
index 8e42cfe9742..76689bf9e82 100644
--- a/tests/datasets/test_fire_risk.py
+++ b/tests/datasets/test_fire_risk.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import FireRisk
+from torchgeo.datasets import DatasetNotFoundError, FireRisk
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -56,7 +56,7 @@ def test_already_downloaded_not_extracted(
FireRisk(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
FireRisk(str(tmp_path))
def test_plot(self, dataset: FireRisk) -> None:
diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py
index 8e333879f37..47caaebe5e3 100644
--- a/tests/datasets/test_forestdamage.py
+++ b/tests/datasets/test_forestdamage.py
@@ -12,7 +12,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import ForestDamage
+from torchgeo.datasets import DatasetNotFoundError, ForestDamage
def download_url(url: str, root: str, *args: str) -> None:
@@ -66,7 +66,7 @@ def test_corrupted(self, tmp_path: Path) -> None:
ForestDamage(root=str(tmp_path), checksum=True)
def test_not_found(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ForestDamage(str(tmp_path))
def test_plot(self, dataset: ForestDamage) -> None:
diff --git a/tests/datasets/test_gbif.py b/tests/datasets/test_gbif.py
index fd3a383fd18..bf6923a6bc2 100644
--- a/tests/datasets/test_gbif.py
+++ b/tests/datasets/test_gbif.py
@@ -1,17 +1,18 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
from pathlib import Path
-from typing import Any
import pytest
-from pytest import MonkeyPatch
-from torchgeo.datasets import GBIF, BoundingBox, IntersectionDataset, UnionDataset
-
-pytest.importorskip("pandas", minversion="1.1.3")
+from torchgeo.datasets import (
+ GBIF,
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ UnionDataset,
+)
class TestGBIF:
@@ -36,29 +37,9 @@ def test_or(self, dataset: GBIF) -> None:
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
GBIF(str(tmp_path))
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == "pandas":
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, "__import__", mocked_import)
-
- def test_mock_missing_module(
- self, dataset: GBIF, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match="pandas is not installed and is required to use this dataset",
- ):
- GBIF(dataset.root)
-
def test_invalid_query(self, dataset: GBIF) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py
index cf4ef25d880..31b140e91f2 100644
--- a/tests/datasets/test_geo.py
+++ b/tests/datasets/test_geo.py
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-
import os
import pickle
+from collections.abc import Iterable
from pathlib import Path
+from typing import Optional, Union
import pytest
import torch
@@ -15,6 +16,7 @@
from torchgeo.datasets import (
NAIP,
BoundingBox,
+ DatasetNotFoundError,
GeoDataset,
IntersectionDataset,
NonGeoClassificationDataset,
@@ -32,11 +34,13 @@ def __init__(
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
+ paths: Optional[Union[str, Iterable[str]]] = None,
) -> None:
super().__init__()
self.index.insert(0, tuple(bounds))
self._crs = crs
self.res = res
+ self.paths = paths or []
def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:
hits = self.index.intersection(tuple(query), objects=True)
@@ -151,6 +155,23 @@ def test_and_nongeo(self, dataset: GeoDataset) -> None:
):
dataset & ds2 # type: ignore[operator]
+ def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None:
+ paths = [str(tmp_path), str(tmp_path / "non_existing_file.tif")]
+ with pytest.warns(UserWarning, match="Path was ignored."):
+ assert len(CustomGeoDataset(paths=paths).files) == 0
+
+ def test_files_property_for_virtual_files(self) -> None:
+ # Tests only a subset of schemes and combinations.
+ paths = [
+ "file://directory/file.tif",
+ "zip://archive.zip!folder/file.tif",
+ "az://azure_bucket/prefix/file.tif",
+ "/vsiaz/azure_bucket/prefix/file.tif",
+ "zip+az://azure_bucket/prefix/archive.zip!folder_in_archive/file.tif",
+ "/vsizip//vsiaz/azure_bucket/prefix/archive.zip/folder_in_archive/file.tif",
+ ]
+ assert len(CustomGeoDataset(paths=paths).files) == len(paths)
+
class TestRasterDataset:
@pytest.fixture(params=zip([["R", "G", "B"], None], [True, False]))
@@ -178,6 +199,39 @@ def sentinel(self, request: SubRequest) -> Sentinel2:
cache = request.param[1]
return Sentinel2(root, bands=bands, transforms=transforms, cache=cache)
+ @pytest.mark.parametrize(
+ "paths",
+ [
+ # Single directory
+ os.path.join("tests", "data", "naip"),
+ # Multiple directories
+ [
+ os.path.join("tests", "data", "naip"),
+ os.path.join("tests", "data", "naip"),
+ ],
+ # Single file
+ os.path.join("tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"),
+ # Multiple files
+ (
+ os.path.join(
+ "tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"
+ ),
+ os.path.join(
+ "tests", "data", "naip", "m_3807511_ne_18_060_20190605.tif"
+ ),
+ ),
+ # Combination
+ {
+ os.path.join("tests", "data", "naip"),
+ os.path.join(
+ "tests", "data", "naip", "m_3807511_ne_18_060_20181104.tif"
+ ),
+ },
+ ],
+ )
+ def test_files(self, paths: Union[str, Iterable[str]]) -> None:
+ assert 1 <= len(NAIP(paths).files) <= 2
+
def test_getitem_single_file(self, naip: NAIP) -> None:
x = naip[naip.bounds]
assert isinstance(x, dict)
@@ -209,7 +263,7 @@ def test_invalid_query(self, sentinel: Sentinel2) -> None:
sentinel[query]
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
RasterDataset(str(tmp_path))
def test_no_all_bands(self) -> None:
@@ -274,7 +328,7 @@ def test_invalid_query(self, dataset: CustomVectorDataset) -> None:
dataset[query]
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="No VectorDataset data was found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
VectorDataset(str(tmp_path))
diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py
index 26d269c1354..e39619d8313 100644
--- a/tests/datasets/test_gid15.py
+++ b/tests/datasets/test_gid15.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import GID15
+from torchgeo.datasets import GID15, DatasetNotFoundError
def download_url(url: str, root: str, *args: str) -> None:
@@ -58,7 +58,7 @@ def test_invalid_split(self) -> None:
GID15(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
GID15(str(tmp_path))
def test_plot(self, dataset: GID15) -> None:
diff --git a/tests/datasets/test_globbiomass.py b/tests/datasets/test_globbiomass.py
index 2bc94bdc00c..5bffc3ff0a4 100644
--- a/tests/datasets/test_globbiomass.py
+++ b/tests/datasets/test_globbiomass.py
@@ -14,6 +14,7 @@
from torchgeo.datasets import (
BoundingBox,
+ DatasetNotFoundError,
GlobBiomass,
IntersectionDataset,
UnionDataset,
@@ -47,17 +48,17 @@ def test_getitem(self, dataset: GlobBiomass) -> None:
assert isinstance(x["mask"], torch.Tensor)
def test_already_extracted(self, dataset: GlobBiomass) -> None:
- GlobBiomass(root=dataset.root)
+ GlobBiomass(dataset.paths)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
GlobBiomass(str(tmp_path), checksum=True)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "N00E020_agb.zip"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
- GlobBiomass(root=str(tmp_path), checksum=True)
+ GlobBiomass(str(tmp_path), checksum=True)
def test_and(self, dataset: GlobBiomass) -> None:
ds = dataset & dataset
diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py
index addc383bc8f..a6192ab012e 100644
--- a/tests/datasets/test_idtrees.py
+++ b/tests/datasets/test_idtrees.py
@@ -16,9 +16,8 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import IDTReeS
+from torchgeo.datasets import DatasetNotFoundError, IDTReeS
-pytest.importorskip("pandas", minversion="1.1.3")
pytest.importorskip("laspy", minversion="2")
@@ -51,7 +50,7 @@ def dataset(
transforms = nn.Identity()
return IDTReeS(root, split, task, transforms, download=True, checksum=True)
- @pytest.fixture(params=["pandas", "laspy", "pyvista"])
+ @pytest.fixture(params=["laspy", "pyvista"])
def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
import_orig = builtins.__import__
package = str(request.param)
@@ -92,10 +91,7 @@ def test_already_downloaded(self, dataset: IDTReeS) -> None:
IDTReeS(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- err = "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- with pytest.raises(RuntimeError, match=err):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
IDTReeS(str(tmp_path))
def test_not_extracted(self, tmp_path: Path) -> None:
@@ -110,13 +106,13 @@ def test_mock_missing_module(
) -> None:
package = mock_missing_module
- if package in ["pandas", "laspy"]:
+ if package == "laspy":
with pytest.raises(
ImportError,
match=f"{package} is not installed and is required to use this dataset",
):
IDTReeS(dataset.root, dataset.split, dataset.task)
- elif package in ["pyvista"]:
+ elif package == "pyvista":
with pytest.raises(
ImportError,
match=f"{package} is not installed and is required to plot point cloud",
@@ -140,7 +136,8 @@ def test_plot(self, dataset: IDTReeS) -> None:
plt.close()
def test_plot_las(self, dataset: IDTReeS) -> None:
- pyvista = pytest.importorskip("pyvista", minversion="0.29")
+ pyvista = pytest.importorskip("pyvista", minversion="0.34.2")
+ pyvista.OFF_SCREEN = True
# Test point cloud without colors
point_cloud = dataset.plot_las(index=0)
diff --git a/tests/datasets/test_inaturalist.py b/tests/datasets/test_inaturalist.py
index 5bac9c4a5e7..49c87d83f77 100644
--- a/tests/datasets/test_inaturalist.py
+++ b/tests/datasets/test_inaturalist.py
@@ -1,23 +1,19 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
from pathlib import Path
-from typing import Any
import pytest
-from pytest import MonkeyPatch
from torchgeo.datasets import (
BoundingBox,
+ DatasetNotFoundError,
INaturalist,
IntersectionDataset,
UnionDataset,
)
-pytest.importorskip("pandas", minversion="1.1.3")
-
class TestINaturalist:
@pytest.fixture(scope="class")
@@ -41,29 +37,9 @@ def test_or(self, dataset: INaturalist) -> None:
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
INaturalist(str(tmp_path))
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == "pandas":
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, "__import__", mocked_import)
-
- def test_mock_missing_module(
- self, dataset: INaturalist, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match="pandas is not installed and is required to use this dataset",
- ):
- INaturalist(dataset.root)
-
def test_invalid_query(self, dataset: INaturalist) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
diff --git a/tests/datasets/test_inria.py b/tests/datasets/test_inria.py
index 1ccbd22af27..71739a0ec8b 100644
--- a/tests/datasets/test_inria.py
+++ b/tests/datasets/test_inria.py
@@ -11,16 +11,16 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
-from torchgeo.datasets import InriaAerialImageLabeling
+from torchgeo.datasets import DatasetNotFoundError, InriaAerialImageLabeling
class TestInriaAerialImageLabeling:
- @pytest.fixture(params=["train", "test"])
+ @pytest.fixture(params=["train", "val", "test"])
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch
) -> InriaAerialImageLabeling:
root = os.path.join("tests", "data", "inria")
- test_md5 = "478688944e4797c097d9387fd0b3f038"
+ test_md5 = "3ecbe95eb84aea064e455c4321546be1"
monkeypatch.setattr(InriaAerialImageLabeling, "md5", test_md5)
transforms = nn.Identity()
return InriaAerialImageLabeling(
@@ -38,13 +38,18 @@ def test_getitem(self, dataset: InriaAerialImageLabeling) -> None:
assert x["image"].ndim == 3
def test_len(self, dataset: InriaAerialImageLabeling) -> None:
- assert len(dataset) == 5
+ if dataset.split == "train":
+ assert len(dataset) == 2
+ elif dataset.split == "val":
+ assert len(dataset) == 5
+ elif dataset.split == "test":
+ assert len(dataset) == 7
def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None:
InriaAerialImageLabeling(root=dataset.root)
def test_not_downloaded(self, tmp_path: str) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
InriaAerialImageLabeling(str(tmp_path))
def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None:
diff --git a/tests/datasets/test_l7irish.py b/tests/datasets/test_l7irish.py
index 59610d78f62..8b0e9a0c64f 100644
--- a/tests/datasets/test_l7irish.py
+++ b/tests/datasets/test_l7irish.py
@@ -14,7 +14,13 @@
from rasterio.crs import CRS
import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, IntersectionDataset, L7Irish, UnionDataset
+from torchgeo.datasets import (
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ L7Irish,
+ UnionDataset,
+)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -58,7 +64,7 @@ def test_plot(self, dataset: L7Irish) -> None:
plt.close()
def test_already_extracted(self, dataset: L7Irish) -> None:
- L7Irish(root=dataset.root, download=True)
+ L7Irish(dataset.paths, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "l7irish", "*.tar.gz")
@@ -68,7 +74,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
L7Irish(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
L7Irish(str(tmp_path))
def test_plot_prediction(self, dataset: L7Irish) -> None:
@@ -88,7 +94,7 @@ def test_rgb_bands_absent_plot(self, dataset: L7Irish) -> None:
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"
):
- ds = L7Irish(root=dataset.root, bands=["B10", "B20", "B50"])
+ ds = L7Irish(dataset.paths, bands=["B10", "B20", "B50"])
x = ds[ds.bounds]
ds.plot(x, suptitle="Test")
plt.close()
diff --git a/tests/datasets/test_l8biome.py b/tests/datasets/test_l8biome.py
index 2b653c87061..ca57b70fa4e 100644
--- a/tests/datasets/test_l8biome.py
+++ b/tests/datasets/test_l8biome.py
@@ -14,7 +14,13 @@
from rasterio.crs import CRS
import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, IntersectionDataset, L8Biome, UnionDataset
+from torchgeo.datasets import (
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ L8Biome,
+ UnionDataset,
+)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -58,7 +64,7 @@ def test_plot(self, dataset: L8Biome) -> None:
plt.close()
def test_already_extracted(self, dataset: L8Biome) -> None:
- L8Biome(root=dataset.root, download=True)
+ L8Biome(dataset.paths, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join("tests", "data", "l8biome", "*.tar.gz")
@@ -68,7 +74,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
L8Biome(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
L8Biome(str(tmp_path))
def test_plot_prediction(self, dataset: L8Biome) -> None:
@@ -88,7 +94,7 @@ def test_rgb_bands_absent_plot(self, dataset: L8Biome) -> None:
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"
):
- ds = L8Biome(root=dataset.root, bands=["B1", "B2", "B5"])
+ ds = L8Biome(dataset.paths, bands=["B1", "B2", "B5"])
x = ds[ds.bounds]
ds.plot(x, suptitle="Test")
plt.close()
diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py
index aba9da8641c..e5bb366e5a2 100644
--- a/tests/datasets/test_landcoverai.py
+++ b/tests/datasets/test_landcoverai.py
@@ -14,7 +14,12 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import BoundingBox, LandCoverAI, LandCoverAIGeo
+from torchgeo.datasets import (
+ BoundingBox,
+ DatasetNotFoundError,
+ LandCoverAI,
+ LandCoverAIGeo,
+)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -40,7 +45,7 @@ def test_getitem(self, dataset: LandCoverAIGeo) -> None:
assert isinstance(x["mask"], torch.Tensor)
def test_already_extracted(self, dataset: LandCoverAIGeo) -> None:
- LandCoverAIGeo(root=dataset.root, download=True)
+ LandCoverAIGeo(dataset.root, download=True)
def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip")
@@ -49,7 +54,7 @@ def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> N
LandCoverAIGeo(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
LandCoverAIGeo(str(tmp_path))
def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None:
@@ -75,7 +80,7 @@ class TestLandCoverAI:
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> LandCoverAI:
- pytest.importorskip("cv2", minversion="4.4.0.46")
+ pytest.importorskip("cv2", minversion="4.4.0")
monkeypatch.setattr(torchgeo.datasets.landcoverai, "download_url", download_url)
md5 = "ff8998857cc8511f644d3f7d0f3688d0"
monkeypatch.setattr(LandCoverAI, "md5", md5)
@@ -106,7 +111,7 @@ def test_already_extracted(self, dataset: LandCoverAI) -> None:
LandCoverAI(root=dataset.root, download=True)
def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
- pytest.importorskip("cv2", minversion="4.4.0.46")
+ pytest.importorskip("cv2", minversion="4.4.0")
sha256 = "ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b"
monkeypatch.setattr(LandCoverAI, "sha256", sha256)
url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip")
@@ -115,7 +120,7 @@ def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> N
LandCoverAI(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
LandCoverAI(str(tmp_path))
def test_invalid_split(self) -> None:
diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py
index f4ce259f6ee..f85de17e2a1 100644
--- a/tests/datasets/test_landsat.py
+++ b/tests/datasets/test_landsat.py
@@ -12,7 +12,13 @@
from pytest import MonkeyPatch
from rasterio.crs import CRS
-from torchgeo.datasets import BoundingBox, IntersectionDataset, Landsat8, UnionDataset
+from torchgeo.datasets import (
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ Landsat8,
+ UnionDataset,
+)
class TestLandsat8:
@@ -52,7 +58,7 @@ def test_plot(self, dataset: Landsat8) -> None:
def test_plot_wrong_bands(self, dataset: Landsat8) -> None:
bands = ("SR_B1",)
- ds = Landsat8(root=dataset.root, bands=bands)
+ ds = Landsat8(dataset.paths, bands=bands)
x = dataset[dataset.bounds]
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"
@@ -60,7 +66,7 @@ def test_plot_wrong_bands(self, dataset: Landsat8) -> None:
ds.plot(x)
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="No Landsat8 data was found in "):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Landsat8(str(tmp_path))
def test_invalid_query(self, dataset: Landsat8) -> None:
diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py
index b4a46d43e62..cafe1ed8206 100644
--- a/tests/datasets/test_levircd.py
+++ b/tests/datasets/test_levircd.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import LEVIRCDPlus
+from torchgeo.datasets import DatasetNotFoundError, LEVIRCDPlus
def download_url(url: str, root: str, *args: str) -> None:
@@ -38,9 +38,11 @@ def dataset(
def test_getitem(self, dataset: LEVIRCDPlus) -> None:
x = dataset[0]
assert isinstance(x, dict)
- assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["image1"], torch.Tensor)
+ assert isinstance(x["image2"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)
- assert x["image"].shape[0] == 2
+ assert x["image1"].shape[0] == 3
+ assert x["image2"].shape[0] == 3
def test_len(self, dataset: LEVIRCDPlus) -> None:
assert len(dataset) == 2
@@ -53,7 +55,7 @@ def test_invalid_split(self) -> None:
LEVIRCDPlus(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
LEVIRCDPlus(str(tmp_path))
def test_plot(self, dataset: LEVIRCDPlus) -> None:
diff --git a/tests/datasets/test_loveda.py b/tests/datasets/test_loveda.py
index 666afce52ad..a368d711034 100644
--- a/tests/datasets/test_loveda.py
+++ b/tests/datasets/test_loveda.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import LoveDA
+from torchgeo.datasets import DatasetNotFoundError, LoveDA
def download_url(url: str, root: str, *args: str) -> None:
@@ -83,9 +83,7 @@ def test_invalid_scene(self) -> None:
LoveDA(scene=["garden"])
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(
- RuntimeError, match="Dataset not found at root directory or corrupted."
- ):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
LoveDA(str(tmp_path))
def test_plot(self, dataset: LoveDA) -> None:
diff --git a/tests/datasets/test_mapinwild.py b/tests/datasets/test_mapinwild.py
new file mode 100644
index 00000000000..90aa35b6aa2
--- /dev/null
+++ b/tests/datasets/test_mapinwild.py
@@ -0,0 +1,137 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import glob
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+from torch.utils.data import ConcatDataset
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import DatasetNotFoundError, MapInWild
+
+
+def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
+ shutil.copy(url, root)
+
+
+class TestMapInWild:
+ @pytest.fixture(params=["train", "validation", "test"])
+ def dataset(
+ self, tmp_path: Path, monkeypatch: MonkeyPatch, request: SubRequest
+ ) -> MapInWild:
+ monkeypatch.setattr(torchgeo.datasets.mapinwild, "download_url", download_url)
+
+ md5s = {
+ "ESA_WC.zip": "3a1e696353d238c50996958855da02fc",
+ "VIIRS.zip": "e8b0e230edb1183c02092357af83bd52",
+ "mask.zip": "15245bb6368d27dbb4bd16310f4604fa",
+ "s1_part1.zip": "e660da4175518af993b63644e44a9d03",
+ "s1_part2.zip": "620cf0a7d598a2893bc7642ad7ee6087",
+ "s2_autumn_part1.zip": "624b6cf0191c5e0bc0d51f92b568e676",
+ "s2_autumn_part2.zip": "f848c62b8de36f06f12fb6b1b065c7b6",
+ "s2_spring_part1.zip": "3296f3a7da7e485708dd16be91deb111",
+ "s2_spring_part2.zip": "d27e94387a59f0558fe142a791682861",
+ "s2_summer_part1.zip": "41d783706c3c1e4238556a772d3232fb",
+ "s2_summer_part2.zip": "3495c87b67a771cfac5153d1958daa0c",
+ "s2_temporal_subset_part1.zip": "06fa463888cb033011a06cf69f82273e",
+ "s2_temporal_subset_part2.zip": "93e5383adeeea27f00051ecf110fcef8",
+ "s2_winter_part1.zip": "617abe1c6ad8d38725aa27c9dcc38ceb",
+ "s2_winter_part2.zip": "4e40d7bb0eec4ddea0b7b00314239a49",
+ "split_IDs.csv": "ca22c3d30d0b62e001ed0c327c147127",
+ }
+
+ monkeypatch.setattr(MapInWild, "md5s", md5s)
+
+ urls = os.path.join("tests", "data", "mapinwild")
+ monkeypatch.setattr(MapInWild, "url", urls)
+
+ root = str(tmp_path)
+ split = request.param
+
+ transforms = nn.Identity()
+ modality = [
+ "mask",
+ "viirs",
+ "esa_wc",
+ "s2_winter",
+ "s1",
+ "s2_summer",
+ "s2_spring",
+ "s2_autumn",
+ "s2_temporal_subset",
+ ]
+ return MapInWild(
+ root,
+ modality=modality,
+ split=split,
+ transforms=transforms,
+ download=True,
+ checksum=True,
+ )
+
+ def test_getitem(self, dataset: MapInWild) -> None:
+ x = dataset[0]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["mask"], torch.Tensor)
+ assert x["image"].ndim == 3
+
+ def test_len(self, dataset: MapInWild) -> None:
+ assert len(dataset) == 1
+
+ def test_add(self, dataset: MapInWild) -> None:
+ ds = dataset + dataset
+ assert isinstance(ds, ConcatDataset)
+ assert len(ds) == 2
+
+ def test_invalid_split(self) -> None:
+ with pytest.raises(AssertionError):
+ MapInWild(split="foo")
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ MapInWild(root=str(tmp_path))
+
+ def test_downloaded_not_extracted(self, tmp_path: Path) -> None:
+ pathname = os.path.join("tests", "data", "mapinwild", "*", "*")
+ pathname_glob = glob.glob(pathname)
+ root = str(tmp_path)
+ for zipfile in pathname_glob:
+ shutil.copy(zipfile, root)
+ MapInWild(root, download=False)
+
+ def test_corrupted(self, tmp_path: Path) -> None:
+ pathname = os.path.join("tests", "data", "mapinwild", "**", "*.zip")
+ pathname_glob = glob.glob(pathname, recursive=True)
+ root = str(tmp_path)
+ for zipfile in pathname_glob:
+ shutil.copy(zipfile, root)
+ splitfile = os.path.join(
+ "tests", "data", "mapinwild", "split_IDs", "split_IDs.csv"
+ )
+ shutil.copy(splitfile, root)
+ with open(os.path.join(tmp_path, "mask.zip"), "w") as f:
+ f.write("bad")
+ with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
+ MapInWild(root=str(tmp_path), download=True, checksum=True)
+
+ def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None:
+ MapInWild(root=str(tmp_path), modality=dataset.modality, download=True)
+
+ def test_plot(self, dataset: MapInWild) -> None:
+ x = dataset[0].copy()
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+ dataset.plot(x, show_titles=False)
+ plt.close()
+ x["prediction"] = x["mask"].clone()
+ dataset.plot(x)
+ plt.close()
diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py
index 751567e28a8..1e94fd003d0 100644
--- a/tests/datasets/test_millionaid.py
+++ b/tests/datasets/test_millionaid.py
@@ -11,7 +11,7 @@
import torch.nn as nn
from _pytest.fixtures import SubRequest
-from torchgeo.datasets import MillionAID
+from torchgeo.datasets import DatasetNotFoundError, MillionAID
class TestMillionAID:
@@ -38,7 +38,7 @@ def test_len(self, dataset: MillionAID) -> None:
assert len(dataset) == 2
def test_not_found(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
MillionAID(str(tmp_path))
def test_not_extracted(self, tmp_path: Path) -> None:
diff --git a/tests/datasets/test_naip.py b/tests/datasets/test_naip.py
index 11e72938883..fe257ae2b78 100644
--- a/tests/datasets/test_naip.py
+++ b/tests/datasets/test_naip.py
@@ -10,7 +10,13 @@
import torch.nn as nn
from rasterio.crs import CRS
-from torchgeo.datasets import NAIP, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+ NAIP,
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ UnionDataset,
+)
class TestNAIP:
@@ -41,7 +47,7 @@ def test_plot(self, dataset: NAIP) -> None:
plt.close()
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="No NAIP data was found in "):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
NAIP(str(tmp_path))
def test_invalid_query(self, dataset: NAIP) -> None:
diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py
index 05c96c8dcb1..f475234ffe6 100644
--- a/tests/datasets/test_nasa_marine_debris.py
+++ b/tests/datasets/test_nasa_marine_debris.py
@@ -12,7 +12,7 @@
import torch.nn as nn
from pytest import MonkeyPatch
-from torchgeo.datasets import NASAMarineDebris
+from torchgeo.datasets import DatasetNotFoundError, NASAMarineDebris
class Collection:
@@ -90,10 +90,7 @@ def test_corrupted_new_download(
NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- err = "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- with pytest.raises(RuntimeError, match=err):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
NASAMarineDebris(str(tmp_path))
def test_plot(self, dataset: NASAMarineDebris) -> None:
diff --git a/tests/datasets/test_nlcd.py b/tests/datasets/test_nlcd.py
index 0f0f134e384..67dde52648d 100644
--- a/tests/datasets/test_nlcd.py
+++ b/tests/datasets/test_nlcd.py
@@ -13,7 +13,13 @@
from rasterio.crs import CRS
import torchgeo.datasets.utils
-from torchgeo.datasets import NLCD, BoundingBox, IntersectionDataset, UnionDataset
+from torchgeo.datasets import (
+ NLCD,
+ BoundingBox,
+ DatasetNotFoundError,
+ IntersectionDataset,
+ UnionDataset,
+)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -69,7 +75,7 @@ def test_or(self, dataset: NLCD) -> None:
assert isinstance(ds, UnionDataset)
def test_already_extracted(self, dataset: NLCD) -> None:
- NLCD(root=dataset.root, download=True, years=[2019])
+ NLCD(dataset.paths, download=True, years=[2019])
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join(
@@ -107,7 +113,7 @@ def test_plot_prediction(self, dataset: NLCD) -> None:
plt.close()
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
NLCD(str(tmp_path))
def test_invalid_query(self, dataset: NLCD) -> None:
diff --git a/tests/datasets/test_openbuildings.py b/tests/datasets/test_openbuildings.py
index d8df3d8dc34..65244962553 100644
--- a/tests/datasets/test_openbuildings.py
+++ b/tests/datasets/test_openbuildings.py
@@ -1,30 +1,27 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import json
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
+import pandas as pd
import pytest
import torch
import torch.nn as nn
-from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import (
BoundingBox,
+ DatasetNotFoundError,
IntersectionDataset,
OpenBuildings,
UnionDataset,
)
-pd = pytest.importorskip("pandas", minversion="1.1.3")
-
class TestOpenBuildings:
@pytest.fixture
@@ -41,31 +38,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> OpenBuildings:
monkeypatch.setattr(OpenBuildings, "md5s", md5s)
transforms = nn.Identity()
- return OpenBuildings(root=root, transforms=transforms)
-
- @pytest.fixture(params=["pandas"])
- def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
- import_orig = builtins.__import__
- package = str(request.param)
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == package:
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, "__import__", mocked_import)
- return package
-
- def test_mock_missing_module(
- self, dataset: OpenBuildings, mock_missing_module: str
- ) -> None:
- package = mock_missing_module
-
- with pytest.raises(
- ImportError,
- match=f"{package} is not installed and is required to use this dataset",
- ):
- OpenBuildings(root=dataset.root)
+ return OpenBuildings(root, transforms=transforms)
def test_no_shapes_to_rasterize(
self, dataset: OpenBuildings, tmp_path: Path
@@ -80,28 +53,15 @@ def test_no_shapes_to_rasterize(
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
- def test_no_building_data_found(self, tmp_path: Path) -> None:
- false_root = os.path.join(tmp_path, "empty")
- os.makedirs(false_root)
- shutil.copy(
- os.path.join("tests", "data", "openbuildings", "tiles.geojson"), false_root
- )
- with pytest.raises(
- RuntimeError, match="have manually downloaded the dataset as suggested "
- ):
- OpenBuildings(root=false_root)
+ def test_not_download(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ OpenBuildings(str(tmp_path))
def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "000_buildings.csv.gz"), "w") as f:
f.write("bad")
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
- OpenBuildings(dataset.root, checksum=True)
-
- def test_no_meta_data_found(self, tmp_path: Path) -> None:
- false_root = os.path.join(tmp_path, "empty")
- os.makedirs(false_root)
- with pytest.raises(FileNotFoundError, match="Meta data file"):
- OpenBuildings(root=false_root)
+ OpenBuildings(dataset.paths, checksum=True)
def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
# change meta data to another 'title_url' so that there is no match found
@@ -112,8 +72,8 @@ def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "tiles.geojson"), "w") as f:
json.dump(content, f)
- with pytest.raises(FileNotFoundError, match="data was found in"):
- OpenBuildings(dataset.root)
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ OpenBuildings(dataset.paths)
def test_getitem(self, dataset: OpenBuildings) -> None:
x = dataset[dataset.bounds]
diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py
index 78cbd30b8f7..82501f016d3 100644
--- a/tests/datasets/test_oscd.py
+++ b/tests/datasets/test_oscd.py
@@ -15,7 +15,7 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import OSCD
+from torchgeo.datasets import OSCD, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -23,7 +23,7 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
class TestOSCD:
- @pytest.fixture(params=zip(["all", "rgb"], ["train", "test"]))
+ @pytest.fixture(params=zip([OSCD.all_bands, OSCD.rgb_bands], ["train", "test"]))
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> OSCD:
@@ -72,15 +72,19 @@ def dataset(
def test_getitem(self, dataset: OSCD) -> None:
x = dataset[0]
assert isinstance(x, dict)
- assert isinstance(x["image"], torch.Tensor)
- assert x["image"].ndim == 3
+ assert isinstance(x["image1"], torch.Tensor)
+ assert x["image1"].ndim == 3
+ assert isinstance(x["image2"], torch.Tensor)
+ assert x["image2"].ndim == 3
assert isinstance(x["mask"], torch.Tensor)
assert x["mask"].ndim == 2
- if dataset.bands == "rgb":
- assert x["image"].shape[0] == 6
+ if dataset.bands == OSCD.rgb_bands:
+ assert x["image1"].shape[0] == 3
+ assert x["image2"].shape[0] == 3
else:
- assert x["image"].shape[0] == 26
+ assert x["image1"].shape[0] == 13
+ assert x["image2"].shape[0] == 13
def test_len(self, dataset: OSCD) -> None:
if dataset.split == "train":
@@ -103,9 +107,15 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
OSCD(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
OSCD(str(tmp_path))
def test_plot(self, dataset: OSCD) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()
+
+ def test_failed_plot(self, dataset: OSCD) -> None:
+ single_band_dataset = OSCD(root=dataset.root, bands=("B01",))
+ with pytest.raises(ValueError, match="RGB bands must be present"):
+ x = single_band_dataset[0].copy()
+ single_band_dataset.plot(x, suptitle="Test")
diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py
index 698d12487b5..1decc20e0c8 100644
--- a/tests/datasets/test_pastis.py
+++ b/tests/datasets/test_pastis.py
@@ -14,7 +14,7 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import PASTIS
+from torchgeo.datasets import PASTIS, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -80,7 +80,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
PASTIS(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
PASTIS(str(tmp_path))
def test_corrupted(self, tmp_path: Path) -> None:
diff --git a/tests/datasets/test_patternnet.py b/tests/datasets/test_patternnet.py
index 7e06264bf92..efab8bd7b31 100644
--- a/tests/datasets/test_patternnet.py
+++ b/tests/datasets/test_patternnet.py
@@ -12,7 +12,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import PatternNet
+from torchgeo.datasets import DatasetNotFoundError, PatternNet
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -52,10 +52,7 @@ def test_already_downloaded_not_extracted(
PatternNet(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
- err = "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- with pytest.raises(RuntimeError, match=err):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
PatternNet(str(tmp_path))
def test_plot(self, dataset: PatternNet) -> None:
diff --git a/tests/datasets/test_potsdam.py b/tests/datasets/test_potsdam.py
index 7502a3db63b..b803b15ea95 100644
--- a/tests/datasets/test_potsdam.py
+++ b/tests/datasets/test_potsdam.py
@@ -12,7 +12,7 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
-from torchgeo.datasets import Potsdam2D
+from torchgeo.datasets import DatasetNotFoundError, Potsdam2D
class TestPotsdam2D:
@@ -60,7 +60,7 @@ def test_invalid_split(self) -> None:
Potsdam2D(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Potsdam2D(str(tmp_path))
def test_plot(self, dataset: Potsdam2D) -> None:
diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py
index 212a7d645fc..b2f3a16eaef 100644
--- a/tests/datasets/test_reforestree.py
+++ b/tests/datasets/test_reforestree.py
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import matplotlib.pyplot as plt
import pytest
@@ -14,7 +12,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import ReforesTree
+from torchgeo.datasets import DatasetNotFoundError, ReforesTree
def download_url(url: str, root: str, *args: str) -> None:
@@ -24,7 +22,6 @@ def download_url(url: str, root: str, *args: str) -> None:
class TestReforesTree:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree:
- pytest.importorskip("pandas", minversion="1.1.3")
monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url)
data_dir = os.path.join("tests", "data", "reforestree")
@@ -54,32 +51,10 @@ def test_getitem(self, dataset: ReforesTree) -> None:
assert x["image"].ndim == 3
assert len(x["boxes"]) == 2
- @pytest.fixture
- def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
- import_orig = builtins.__import__
- package = "pandas"
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == package:
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, "__import__", mocked_import)
-
- def test_mock_missing_module(
- self, dataset: ReforesTree, mock_missing_module: None
- ) -> None:
- with pytest.raises(
- ImportError,
- match="pandas is not installed and is required to use this dataset",
- ):
- ReforesTree(root=dataset.root)
-
def test_len(self, dataset: ReforesTree) -> None:
assert len(dataset) == 2
def test_not_extracted(self, tmp_path: Path) -> None:
- pytest.importorskip("pandas", minversion="1.1.3")
url = os.path.join("tests", "data", "reforestree", "reforesTree.zip")
shutil.copy(url, tmp_path)
ReforesTree(root=str(tmp_path))
@@ -91,7 +66,7 @@ def test_corrupted(self, tmp_path: Path) -> None:
ReforesTree(root=str(tmp_path), checksum=True)
def test_not_found(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ReforesTree(str(tmp_path))
def test_plot(self, dataset: ReforesTree) -> None:
diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py
index b20a19ddbb4..099885deac8 100644
--- a/tests/datasets/test_resisc45.py
+++ b/tests/datasets/test_resisc45.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import RESISC45
+from torchgeo.datasets import RESISC45, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -78,10 +78,7 @@ def test_already_downloaded_not_extracted(
RESISC45(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
- err = "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- with pytest.raises(RuntimeError, match=err):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
RESISC45(str(tmp_path))
def test_plot(self, dataset: RESISC45) -> None:
diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py
new file mode 100644
index 00000000000..c0bfd71e452
--- /dev/null
+++ b/tests/datasets/test_rwanda_field_boundary.py
@@ -0,0 +1,140 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import glob
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+from torch.utils.data import ConcatDataset
+
+from torchgeo.datasets import DatasetNotFoundError, RwandaFieldBoundary
+
+
+class Collection:
+ def download(self, output_dir: str, **kwargs: str) -> None:
+ glob_path = os.path.join("tests", "data", "rwanda_field_boundary", "*.tar.gz")
+ for tarball in glob.iglob(glob_path):
+ shutil.copy(tarball, output_dir)
+
+
+def fetch(dataset_id: str, **kwargs: str) -> Collection:
+ return Collection()
+
+
+class TestRwandaFieldBoundary:
+ @pytest.fixture(params=["train", "test"])
+ def dataset(
+ self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+ ) -> RwandaFieldBoundary:
+ radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3")
+ monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
+ monkeypatch.setattr(
+ RwandaFieldBoundary, "number_of_patches_per_split", {"train": 5, "test": 5}
+ )
+ monkeypatch.setattr(
+ RwandaFieldBoundary,
+ "md5s",
+ {
+ "train_images": "af9395e2e49deefebb35fa65fa378ba3",
+ "test_images": "d104bb82323a39e7c3b3b7dd0156f550",
+ "train_labels": "6cceaf16a141cf73179253a783e7d51b",
+ },
+ )
+
+ root = str(tmp_path)
+ split = request.param
+ transforms = nn.Identity()
+ return RwandaFieldBoundary(
+ root, split, transforms=transforms, api_key="", download=True, checksum=True
+ )
+
+ def test_getitem(self, dataset: RwandaFieldBoundary) -> None:
+ x = dataset[0]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ if dataset.split == "train":
+ assert isinstance(x["mask"], torch.Tensor)
+ else:
+ assert "mask" not in x
+
+ def test_len(self, dataset: RwandaFieldBoundary) -> None:
+ assert len(dataset) == 5
+
+ def test_add(self, dataset: RwandaFieldBoundary) -> None:
+ ds = dataset + dataset
+ assert isinstance(ds, ConcatDataset)
+ assert len(ds) == 10
+
+ def test_needs_extraction(self, tmp_path: Path) -> None:
+ root = str(tmp_path)
+ for fn in [
+ "nasa_rwanda_field_boundary_competition_source_train.tar.gz",
+ "nasa_rwanda_field_boundary_competition_source_test.tar.gz",
+ "nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
+ ]:
+ url = os.path.join("tests", "data", "rwanda_field_boundary", fn)
+ shutil.copy(url, root)
+ RwandaFieldBoundary(root, checksum=False)
+
+ def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None:
+ RwandaFieldBoundary(root=dataset.root)
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ RwandaFieldBoundary(str(tmp_path))
+
+ def test_corrupted(self, tmp_path: Path) -> None:
+ for fn in [
+ "nasa_rwanda_field_boundary_competition_source_train.tar.gz",
+ "nasa_rwanda_field_boundary_competition_source_test.tar.gz",
+ "nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
+ ]:
+ with open(os.path.join(tmp_path, fn), "w") as f:
+ f.write("bad")
+ with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
+ RwandaFieldBoundary(root=str(tmp_path), checksum=True)
+
+ def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
+ radiant_mlhub = pytest.importorskip("radiant_mlhub", minversion="0.3")
+ monkeypatch.setattr(radiant_mlhub.Collection, "fetch", fetch)
+ monkeypatch.setattr(
+ RwandaFieldBoundary,
+ "md5s",
+ {"train_images": "bad", "test_images": "bad", "train_labels": "bad"},
+ )
+ root = str(tmp_path)
+ with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ RwandaFieldBoundary(root, "train", api_key="", download=True, checksum=True)
+
+ def test_no_api_key(self, tmp_path: Path) -> None:
+ with pytest.raises(RuntimeError, match="Must provide an API key to download"):
+ RwandaFieldBoundary(str(tmp_path), api_key=None, download=True)
+
+ def test_invalid_bands(self) -> None:
+ with pytest.raises(ValueError, match="is an invalid band name."):
+ RwandaFieldBoundary(bands=("foo", "bar"))
+
+ def test_plot(self, dataset: RwandaFieldBoundary) -> None:
+ x = dataset[0].copy()
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+ dataset.plot(x, show_titles=False)
+ plt.close()
+
+ if dataset.split == "train":
+ x["prediction"] = x["mask"].clone()
+ dataset.plot(x)
+ plt.close()
+
+ def test_failed_plot(self, dataset: RwandaFieldBoundary) -> None:
+ single_band_dataset = RwandaFieldBoundary(root=dataset.root, bands=("B01",))
+ with pytest.raises(ValueError, match="Dataset doesn't contain"):
+ x = single_band_dataset[0].copy()
+ single_band_dataset.plot(x, suptitle="Test")
diff --git a/tests/datasets/test_seasonet.py b/tests/datasets/test_seasonet.py
new file mode 100644
index 00000000000..6d7280537ab
--- /dev/null
+++ b/tests/datasets/test_seasonet.py
@@ -0,0 +1,194 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import glob
+import os
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pytest
+import torch
+import torch.nn as nn
+from _pytest.fixtures import SubRequest
+from pytest import MonkeyPatch
+from torch.utils.data import ConcatDataset
+
+import torchgeo.datasets.utils
+from torchgeo.datasets import DatasetNotFoundError, SeasoNet
+
+
+def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> None:
+ shutil.copy(url, root)
+ torchgeo.datasets.utils.check_integrity(
+ os.path.join(root, os.path.basename(url)), md5
+ )
+
+
+class TestSeasoNet:
+ @pytest.fixture(
+ params=zip(
+ ["train", "val", "test"],
+ [{"Spring"}, {"Summer", "Fall", "Winter", "Snow"}, SeasoNet.all_seasons],
+ [SeasoNet.all_bands, ["10m_IR", "10m_RGB", "60m"], ["10m_RGB"]],
+ [[1], [2], [1, 2]],
+ [1, 3, 5],
+ )
+ )
+ def dataset(
+ self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+ ) -> SeasoNet:
+ monkeypatch.setattr(torchgeo.datasets.seasonet, "download_url", download_url)
+ monkeypatch.setitem(
+ SeasoNet.metadata[0], "md5", "836a0896eba0e3005208f3fd180e429d"
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[1], "md5", "405656c8c19d822620bbb9f92e687337"
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[2], "md5", "dc0dda18de019a9f50a794b8b4060a3b"
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[3], "md5", "a70abca62e78eb1591555809dc81d91d"
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[4], "md5", "67651cc9095207e07ea4db1a71f0ebc2"
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[5], "md5", "576324ba1c32a7e9ba858f1c2577cf2a"
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[6], "md5", "48ff6e9e01fdd92379e5712e4f336ea8"
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[0],
+ "url",
+ os.path.join("tests", "data", "seasonet", "spring.zip"),
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[1],
+ "url",
+ os.path.join("tests", "data", "seasonet", "summer.zip"),
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[2],
+ "url",
+ os.path.join("tests", "data", "seasonet", "fall.zip"),
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[3],
+ "url",
+ os.path.join("tests", "data", "seasonet", "winter.zip"),
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[4],
+ "url",
+ os.path.join("tests", "data", "seasonet", "snow.zip"),
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[5],
+ "url",
+ os.path.join("tests", "data", "seasonet", "splits.zip"),
+ )
+ monkeypatch.setitem(
+ SeasoNet.metadata[6],
+ "url",
+ os.path.join("tests", "data", "seasonet", "meta.csv"),
+ )
+ root = str(tmp_path)
+ split, seasons, bands, grids, concat_seasons = request.param
+ transforms = nn.Identity()
+ return SeasoNet(
+ root=root,
+ split=split,
+ seasons=seasons,
+ bands=bands,
+ grids=grids,
+ concat_seasons=concat_seasons,
+ transforms=transforms,
+ download=True,
+ checksum=True,
+ )
+
+ def test_getitem(self, dataset: SeasoNet) -> None:
+ x = dataset[0]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["mask"], torch.Tensor)
+ assert x["image"].shape == (dataset.concat_seasons * dataset.channels, 120, 120)
+ assert x["mask"].shape == (120, 120)
+
+ def test_len(self, dataset: SeasoNet, request: SubRequest) -> None:
+ num_seasons = len(request.node.callspec.params["dataset"][1])
+ num_grids = len(request.node.callspec.params["dataset"][3])
+ if dataset.concat_seasons == 1:
+ assert len(dataset) == num_grids * num_seasons
+ else:
+ assert len(dataset) == num_grids
+
+ def test_add(self, dataset: SeasoNet, request: SubRequest) -> None:
+ ds = dataset + dataset
+ assert isinstance(ds, ConcatDataset)
+ num_seasons = len(request.node.callspec.params["dataset"][1])
+ num_grids = len(request.node.callspec.params["dataset"][3])
+ if dataset.concat_seasons == 1:
+ assert len(ds) == num_grids * num_seasons * 2
+ else:
+ assert len(ds) == num_grids * 2
+
+ def test_already_extracted(self, dataset: SeasoNet) -> None:
+ SeasoNet(root=dataset.root)
+
+ def test_already_downloaded(self, tmp_path: Path) -> None:
+ paths = os.path.join("tests", "data", "seasonet", "*.*")
+ root = str(tmp_path)
+ for path in glob.iglob(paths):
+ shutil.copy(path, root)
+ SeasoNet(root)
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
+ SeasoNet(str(tmp_path), download=False)
+
+ def test_out_of_bounds(self, dataset: SeasoNet) -> None:
+ with pytest.raises(IndexError):
+ dataset[5]
+
+ def test_invalid_seasons(self) -> None:
+ with pytest.raises(AssertionError):
+ SeasoNet(seasons=("Salt", "Pepper"))
+
+ def test_invalid_bands(self) -> None:
+ with pytest.raises(AssertionError):
+ SeasoNet(bands=["30s_TOMARS", "9in_NAILS"])
+
+ def test_invalid_grids(self) -> None:
+ with pytest.raises(AssertionError):
+ SeasoNet(grids=[42])
+
+ def test_invalid_split(self) -> None:
+ with pytest.raises(AssertionError):
+ SeasoNet(split="banana")
+
+ def test_invalid_concat(self) -> None:
+ with pytest.raises(AssertionError):
+ SeasoNet(seasons={"Spring", "Winter", "Snow"}, concat_seasons=4)
+
+ def test_plot(self, dataset: SeasoNet) -> None:
+ x = dataset[0]
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+ dataset.plot(x, show_titles=False)
+ plt.close()
+ dataset.plot(x, show_legend=False)
+ plt.close()
+ x["prediction"] = x["mask"].clone()
+ dataset.plot(x)
+ plt.close()
+
+ def test_plot_no_rgb(self) -> None:
+ with pytest.raises(ValueError, match="Dataset does not contain"):
+ root = os.path.join("tests", "data", "seasonet")
+ dataset = SeasoNet(root, bands=["10m_IR"])
+ x = dataset[0]
+ dataset.plot(x)
diff --git a/tests/datasets/test_seco.py b/tests/datasets/test_seco.py
index 743a90c2b7c..f89efec4b24 100644
--- a/tests/datasets/test_seco.py
+++ b/tests/datasets/test_seco.py
@@ -15,7 +15,7 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import SeasonalContrastS2
+from torchgeo.datasets import DatasetNotFoundError, SeasonalContrastS2
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -98,7 +98,7 @@ def test_invalid_band(self) -> None:
SeasonalContrastS2(bands=["A1steaksauce"])
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SeasonalContrastS2(str(tmp_path))
def test_plot(self, dataset: SeasonalContrastS2) -> None:
diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py
index 55ecb406bf2..f802105e0c6 100644
--- a/tests/datasets/test_sen12ms.py
+++ b/tests/datasets/test_sen12ms.py
@@ -12,7 +12,7 @@
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
-from torchgeo.datasets import SEN12MS
+from torchgeo.datasets import SEN12MS, DatasetNotFoundError
class TestSEN12MS:
@@ -65,10 +65,10 @@ def test_invalid_split(self) -> None:
SEN12MS(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SEN12MS(str(tmp_path), checksum=True)
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SEN12MS(str(tmp_path), checksum=False)
def test_check_integrity_light(self) -> None:
diff --git a/tests/datasets/test_sentinel.py b/tests/datasets/test_sentinel.py
index 2d6c42aa89e..f22a1c5fcc3 100644
--- a/tests/datasets/test_sentinel.py
+++ b/tests/datasets/test_sentinel.py
@@ -13,6 +13,7 @@
from torchgeo.datasets import (
BoundingBox,
+ DatasetNotFoundError,
IntersectionDataset,
Sentinel1,
Sentinel2,
@@ -64,7 +65,7 @@ def test_plot(self, dataset: Sentinel2) -> None:
plt.close()
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="No Sentinel1 data was found in "):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Sentinel1(str(tmp_path))
def test_empty_bands(self) -> None:
@@ -123,7 +124,7 @@ def test_or(self, dataset: Sentinel2) -> None:
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
- with pytest.raises(FileNotFoundError, match="No Sentinel2 data was found in "):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Sentinel2(str(tmp_path))
def test_plot(self, dataset: Sentinel2) -> None:
@@ -133,7 +134,7 @@ def test_plot(self, dataset: Sentinel2) -> None:
def test_plot_wrong_bands(self, dataset: Sentinel2) -> None:
bands = ["B02"]
- ds = Sentinel2(root=dataset.root, res=dataset.res, bands=bands)
+ ds = Sentinel2(dataset.paths, res=dataset.res, bands=bands)
x = dataset[dataset.bounds]
with pytest.raises(
ValueError, match="Dataset doesn't contain some of the RGB bands"
diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py
index a35d6fb9e59..392c3255eda 100644
--- a/tests/datasets/test_skippd.py
+++ b/tests/datasets/test_skippd.py
@@ -4,6 +4,7 @@
import builtins
import os
import shutil
+from itertools import product
from pathlib import Path
from typing import Any
@@ -15,7 +16,9 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import SKIPPD
+from torchgeo.datasets import SKIPPD, DatasetNotFoundError
+
+pytest.importorskip("h5py", minversion="3")
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -23,21 +26,32 @@ def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
class TestSKIPPD:
- @pytest.fixture(params=["trainval", "test"])
+ @pytest.fixture(params=product(["nowcast", "forecast"], ["trainval", "test"]))
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
) -> SKIPPD:
+ task, split = request.param
+
monkeypatch.setattr(torchgeo.datasets.skippd, "download_url", download_url)
- md5 = "1133b2de453a9674776abd7519af5051"
+ md5 = {
+ "nowcast": "6f5e54906927278b189f9281a2f54f39",
+ "forecast": "f3b5d7d5c28ba238144fa1e726c46969",
+ }
monkeypatch.setattr(SKIPPD, "md5", md5)
- url = os.path.join("tests", "data", "skippd", "dj417rh1007.zip")
+ url = os.path.join("tests", "data", "skippd", "{}")
monkeypatch.setattr(SKIPPD, "url", url)
monkeypatch.setattr(plt, "show", lambda *args: None)
root = str(tmp_path)
- split = request.param
transforms = nn.Identity()
- return SKIPPD(root, split, transforms, download=True, checksum=True)
+ return SKIPPD(
+ root=root,
+ task=task,
+ split=split,
+ transforms=transforms,
+ download=True,
+ checksum=True,
+ )
@pytest.fixture
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
@@ -62,11 +76,14 @@ def test_mock_missing_module(
def test_already_extracted(self, dataset: SKIPPD) -> None:
SKIPPD(root=dataset.root, download=True)
- def test_already_downloaded(self, tmp_path: Path) -> None:
- pathname = os.path.join("tests", "data", "skippd", "dj417rh1007.zip")
+ @pytest.mark.parametrize("task", ["nowcast", "forecast"])
+ def test_already_downloaded(self, tmp_path: Path, task: str) -> None:
+ pathname = os.path.join(
+ "tests", "data", "skippd", f"2017_2019_images_pv_processed_{task}.zip"
+ )
root = str(tmp_path)
shutil.copy(pathname, root)
- SKIPPD(root)
+ SKIPPD(root=root, task=task)
@pytest.mark.parametrize("index", [0, 1, 2])
def test_getitem(self, dataset: SKIPPD, index: int) -> None:
@@ -75,7 +92,10 @@ def test_getitem(self, dataset: SKIPPD, index: int) -> None:
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["label"], torch.Tensor)
assert isinstance(x["date"], str)
- assert x["image"].shape == (3, 64, 64)
+ if dataset.task == "nowcast":
+ assert x["image"].shape == (3, 64, 64)
+ else:
+ assert x["image"].shape == (48, 64, 64)
def test_len(self, dataset: SKIPPD) -> None:
assert len(dataset) == 3
@@ -85,7 +105,7 @@ def test_invalid_split(self) -> None:
SKIPPD(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SKIPPD(str(tmp_path))
def test_plot(self, dataset: SKIPPD) -> None:
@@ -93,6 +113,9 @@ def test_plot(self, dataset: SKIPPD) -> None:
plt.close()
sample = dataset[0]
- sample["prediction"] = sample["label"]
+ if dataset.task == "nowcast":
+ sample["prediction"] = sample["label"]
+ else:
+ sample["prediction"] = sample["label"][-1]
dataset.plot(sample)
plt.close()
diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py
index 5802d81b537..2e093f288c6 100644
--- a/tests/datasets/test_so2sat.py
+++ b/tests/datasets/test_so2sat.py
@@ -13,7 +13,7 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
-from torchgeo.datasets import So2Sat
+from torchgeo.datasets import DatasetNotFoundError, So2Sat
pytest.importorskip("h5py", minversion="3")
@@ -70,7 +70,7 @@ def test_invalid_bands(self) -> None:
So2Sat(bands=("OK", "BK"))
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
So2Sat(str(tmp_path))
def test_plot(self, dataset: So2Sat) -> None:
diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py
index 4b79ca3fa35..046b83cfba1 100644
--- a/tests/datasets/test_spacenet.py
+++ b/tests/datasets/test_spacenet.py
@@ -14,6 +14,7 @@
from pytest import MonkeyPatch
from torchgeo.datasets import (
+ DatasetNotFoundError,
SpaceNet1,
SpaceNet2,
SpaceNet3,
@@ -91,7 +92,7 @@ def test_already_downloaded(self, dataset: SpaceNet1) -> None:
SpaceNet1(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet1(str(tmp_path))
def test_plot(self, dataset: SpaceNet1) -> None:
@@ -147,7 +148,7 @@ def test_already_downloaded(self, dataset: SpaceNet2) -> None:
SpaceNet2(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet2(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet2) -> None:
@@ -207,7 +208,7 @@ def test_already_downloaded(self, dataset: SpaceNet3) -> None:
SpaceNet3(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet3(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet3) -> None:
@@ -271,7 +272,7 @@ def test_already_downloaded(self, dataset: SpaceNet4) -> None:
SpaceNet4(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet4(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
@@ -333,7 +334,7 @@ def test_already_downloaded(self, dataset: SpaceNet5) -> None:
SpaceNet5(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet5(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet5) -> None:
@@ -427,7 +428,7 @@ def test_already_downloaded(self, dataset: SpaceNet4) -> None:
SpaceNet7(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet7(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
diff --git a/tests/datasets/test_ssl4eo.py b/tests/datasets/test_ssl4eo.py
index e2b9b36feff..68b6df002b4 100644
--- a/tests/datasets/test_ssl4eo.py
+++ b/tests/datasets/test_ssl4eo.py
@@ -15,7 +15,7 @@
from torch.utils.data import ConcatDataset
import torchgeo
-from torchgeo.datasets import SSL4EOL, SSL4EOS12
+from torchgeo.datasets import SSL4EOL, SSL4EOS12, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -94,7 +94,7 @@ def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None:
SSL4EOL(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SSL4EOL(str(tmp_path))
def test_invalid_split(self) -> None:
@@ -155,7 +155,7 @@ def test_invalid_split(self) -> None:
SSL4EOS12(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SSL4EOS12(str(tmp_path))
def test_plot(self, dataset: SSL4EOS12) -> None:
diff --git a/tests/datasets/test_ssl4eo_benchmark.py b/tests/datasets/test_ssl4eo_benchmark.py
index 1cc1809f80d..0d5b3f94030 100644
--- a/tests/datasets/test_ssl4eo_benchmark.py
+++ b/tests/datasets/test_ssl4eo_benchmark.py
@@ -16,7 +16,13 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import CDL, NLCD, RasterDataset, SSL4EOLBenchmark
+from torchgeo.datasets import (
+ CDL,
+ NLCD,
+ DatasetNotFoundError,
+ RasterDataset,
+ SSL4EOLBenchmark,
+)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -137,7 +143,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
SSL4EOLBenchmark(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SSL4EOLBenchmark(str(tmp_path))
def test_plot(self, dataset: SSL4EOLBenchmark) -> None:
diff --git a/tests/datasets/test_sustainbench_crop_yield.py b/tests/datasets/test_sustainbench_crop_yield.py
index 04c056ed505..071f0c81a8f 100644
--- a/tests/datasets/test_sustainbench_crop_yield.py
+++ b/tests/datasets/test_sustainbench_crop_yield.py
@@ -13,7 +13,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import SustainBenchCropYield
+from torchgeo.datasets import DatasetNotFoundError, SustainBenchCropYield
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -71,7 +71,7 @@ def test_invalid_split(self) -> None:
SustainBenchCropYield(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SustainBenchCropYield(str(tmp_path))
def test_plot(self, dataset: SustainBenchCropYield) -> None:
diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py
index c4096276725..61c76f9cecd 100644
--- a/tests/datasets/test_ucmerced.py
+++ b/tests/datasets/test_ucmerced.py
@@ -14,7 +14,7 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import UCMerced
+from torchgeo.datasets import DatasetNotFoundError, UCMerced
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -81,10 +81,7 @@ def test_already_downloaded_not_extracted(
UCMerced(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
- err = "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- with pytest.raises(RuntimeError, match=err):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
UCMerced(str(tmp_path))
def test_plot(self, dataset: UCMerced) -> None:
diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py
index 45ff2dc7244..4c256ad5c25 100644
--- a/tests/datasets/test_usavars.py
+++ b/tests/datasets/test_usavars.py
@@ -1,11 +1,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import pytest
import torch
@@ -16,9 +14,7 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import USAVars
-
-pytest.importorskip("pandas", minversion="1.1.3")
+from torchgeo.datasets import DatasetNotFoundError, USAVars
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@@ -133,33 +129,9 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
USAVars(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
USAVars(str(tmp_path))
- @pytest.fixture(params=["pandas"])
- def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
- import_orig = builtins.__import__
- package = str(request.param)
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == package:
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, "__import__", mocked_import)
- return package
-
- def test_mock_missing_module(
- self, dataset: USAVars, mock_missing_module: str
- ) -> None:
- package = mock_missing_module
- if package == "pandas":
- with pytest.raises(
- ImportError,
- match=f"{package} is not installed and is required to use this dataset",
- ):
- USAVars(dataset.root)
-
def test_plot(self, dataset: USAVars) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()
diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py
index 0346d13a2b0..be5950fb071 100644
--- a/tests/datasets/test_utils.py
+++ b/tests/datasets/test_utils.py
@@ -18,10 +18,12 @@
import torch
from pytest import MonkeyPatch
from rasterio.crs import CRS
+from torch.utils.data import Dataset
import torchgeo.datasets.utils
from torchgeo.datasets.utils import (
BoundingBox,
+ DatasetNotFoundError,
concat_samples,
disambiguate_timestamp,
download_and_extract_archive,
@@ -36,6 +38,52 @@
)
+class TestDatasetNotFoundError:
+ def test_none(self) -> None:
+ ds: Dataset[Any] = Dataset()
+ match = "Dataset not found."
+ with pytest.raises(DatasetNotFoundError, match=match):
+ raise DatasetNotFoundError(ds)
+
+ def test_root(self) -> None:
+ ds: Dataset[Any] = Dataset()
+ ds.root = "foo" # type: ignore[attr-defined]
+ match = "Dataset not found in `root='foo'` and cannot be automatically "
+ match += "downloaded, either specify a different `root` or manually "
+ match += "download the dataset."
+ with pytest.raises(DatasetNotFoundError, match=match):
+ raise DatasetNotFoundError(ds)
+
+ def test_paths(self) -> None:
+ ds: Dataset[Any] = Dataset()
+ ds.paths = "foo" # type: ignore[attr-defined]
+ match = "Dataset not found in `paths='foo'` and cannot be automatically "
+ match += "downloaded, either specify a different `paths` or manually "
+ match += "download the dataset."
+ with pytest.raises(DatasetNotFoundError, match=match):
+ raise DatasetNotFoundError(ds)
+
+ def test_root_download(self) -> None:
+ ds: Dataset[Any] = Dataset()
+ ds.root = "foo" # type: ignore[attr-defined]
+ ds.download = False # type: ignore[attr-defined]
+ match = "Dataset not found in `root='foo'` and `download=False`, either "
+ match += "specify a different `root` or use `download=True` to automatically "
+ match += "download the dataset."
+ with pytest.raises(DatasetNotFoundError, match=match):
+ raise DatasetNotFoundError(ds)
+
+ def test_paths_download(self) -> None:
+ ds: Dataset[Any] = Dataset()
+ ds.paths = "foo" # type: ignore[attr-defined]
+ ds.download = False # type: ignore[attr-defined]
+ match = "Dataset not found in `paths='foo'` and `download=False`, either "
+ match += "specify a different `paths` or use `download=True` to automatically "
+ match += "download the dataset."
+ with pytest.raises(DatasetNotFoundError, match=match):
+ raise DatasetNotFoundError(ds)
+
+
@pytest.fixture
def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
@@ -48,7 +96,7 @@ def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
monkeypatch.setattr(builtins, "__import__", mocked_import)
-class Dataset:
+class MLHubDataset:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
"tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
@@ -66,8 +114,8 @@ def download(self, output_dir: str, **kwargs: str) -> None:
shutil.copy(tarball, output_dir)
-def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
- return Dataset()
+def fetch_dataset(dataset_id: str, **kwargs: str) -> MLHubDataset:
+ return MLHubDataset()
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
diff --git a/tests/datasets/test_vaihingen.py b/tests/datasets/test_vaihingen.py
index 56240b2aca6..fe34bccea08 100644
--- a/tests/datasets/test_vaihingen.py
+++ b/tests/datasets/test_vaihingen.py
@@ -12,7 +12,7 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
-from torchgeo.datasets import Vaihingen2D
+from torchgeo.datasets import DatasetNotFoundError, Vaihingen2D
class TestVaihingen2D:
@@ -69,7 +69,7 @@ def test_invalid_split(self) -> None:
Vaihingen2D(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Vaihingen2D(str(tmp_path))
def test_plot(self, dataset: Vaihingen2D) -> None:
diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py
index ce69db7ef81..805b84a3117 100644
--- a/tests/datasets/test_vhr10.py
+++ b/tests/datasets/test_vhr10.py
@@ -16,7 +16,7 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import VHR10
+from torchgeo.datasets import VHR10, DatasetNotFoundError
pytest.importorskip("pycocotools")
@@ -90,7 +90,7 @@ def test_invalid_split(self) -> None:
VHR10(split="train")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
VHR10(str(tmp_path))
def test_mock_missing_module(
diff --git a/tests/datasets/test_western_usa_live_fuel_moisture.py b/tests/datasets/test_western_usa_live_fuel_moisture.py
index 0015081f35f..3337965228e 100644
--- a/tests/datasets/test_western_usa_live_fuel_moisture.py
+++ b/tests/datasets/test_western_usa_live_fuel_moisture.py
@@ -1,19 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-import builtins
import os
import shutil
from pathlib import Path
-from typing import Any
import pytest
import torch
import torch.nn as nn
-from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
-from torchgeo.datasets import WesternUSALiveFuelMoisture
+from torchgeo.datasets import DatasetNotFoundError, WesternUSALiveFuelMoisture
class Collection:
@@ -68,33 +65,9 @@ def test_already_downloaded(self, tmp_path: Path) -> None:
WesternUSALiveFuelMoisture(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
WesternUSALiveFuelMoisture(str(tmp_path))
def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None:
with pytest.raises(AssertionError, match="Invalid input variable name."):
WesternUSALiveFuelMoisture(dataset.root, input_features=["foo"])
-
- @pytest.fixture(params=["pandas"])
- def mock_missing_module(self, monkeypatch: MonkeyPatch, request: SubRequest) -> str:
- import_orig = builtins.__import__
- package = str(request.param)
-
- def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
- if name == package:
- raise ImportError()
- return import_orig(name, *args, **kwargs)
-
- monkeypatch.setattr(builtins, "__import__", mocked_import)
- return package
-
- def test_mock_missing_module(
- self, dataset: WesternUSALiveFuelMoisture, mock_missing_module: str
- ) -> None:
- package = mock_missing_module
- if package == "pandas":
- with pytest.raises(
- ImportError,
- match=f"{package} is not installed and is required to use this dataset",
- ):
- WesternUSALiveFuelMoisture(dataset.root)
diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py
index 957de5a36d7..28292775a46 100644
--- a/tests/datasets/test_xview2.py
+++ b/tests/datasets/test_xview2.py
@@ -12,7 +12,7 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
-from torchgeo.datasets import XView2
+from torchgeo.datasets import DatasetNotFoundError, XView2
class TestXView2:
@@ -80,7 +80,7 @@ def test_invalid_split(self) -> None:
XView2(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
- with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
XView2(str(tmp_path))
def test_plot(self, dataset: XView2) -> None:
diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py
index 18b5a87eb65..27325672869 100644
--- a/tests/datasets/test_zuericrop.py
+++ b/tests/datasets/test_zuericrop.py
@@ -14,7 +14,7 @@
from pytest import MonkeyPatch
import torchgeo.datasets.utils
-from torchgeo.datasets import ZueriCrop
+from torchgeo.datasets import DatasetNotFoundError, ZueriCrop
pytest.importorskip("h5py", minversion="3")
@@ -79,10 +79,7 @@ def test_already_downloaded(self, dataset: ZueriCrop) -> None:
ZueriCrop(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
- err = "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- with pytest.raises(RuntimeError, match=err):
+ with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ZueriCrop(str(tmp_path))
def test_mock_missing_module(
diff --git a/tests/models/test_rcf.py b/tests/models/test_rcf.py
index 870b66bf495..f6d8091bc0d 100644
--- a/tests/models/test_rcf.py
+++ b/tests/models/test_rcf.py
@@ -1,25 +1,28 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
+import os
+
import pytest
import torch
+from torchgeo.datasets import EuroSAT
from torchgeo.models import RCF
class TestRCF:
def test_in_channels(self) -> None:
- model = RCF(in_channels=5, features=4, kernel_size=3)
+ model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
x = torch.randn(2, 5, 64, 64)
model(x)
- model = RCF(in_channels=3, features=4, kernel_size=3)
+ model = RCF(in_channels=3, features=4, kernel_size=3, mode="gaussian")
match = "to have 3 channels, but got 5 channels instead"
with pytest.raises(RuntimeError, match=match):
model(x)
def test_num_features(self) -> None:
- model = RCF(in_channels=5, features=4, kernel_size=3)
+ model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
x = torch.randn(2, 5, 64, 64)
y = model(x)
assert y.shape[1] == 4
@@ -29,14 +32,27 @@ def test_num_features(self) -> None:
assert y.shape[0] == 4
def test_untrainable(self) -> None:
- model = RCF(in_channels=5, features=4, kernel_size=3)
+ model = RCF(in_channels=5, features=4, kernel_size=3, mode="gaussian")
assert len(list(model.parameters())) == 0
def test_biases(self) -> None:
- model = RCF(features=24, bias=10)
+ model = RCF(features=24, bias=10, mode="gaussian")
assert torch.all(model.biases == 10)
def test_seed(self) -> None:
- weights1 = RCF(seed=1).weights
- weights2 = RCF(seed=1).weights
+ weights1 = RCF(seed=1, mode="gaussian").weights
+ weights2 = RCF(seed=1, mode="gaussian").weights
assert torch.allclose(weights1, weights2)
+
+ def test_empirical(self) -> None:
+ root = os.path.join("tests", "data", "eurosat")
+ ds = EuroSAT(root=root, bands=EuroSAT.rgb_bands, split="train")
+ model = RCF(
+ in_channels=3, features=4, kernel_size=3, mode="empirical", dataset=ds
+ )
+ model(torch.randn(2, 3, 8, 8))
+
+ def test_empirical_no_dataset(self) -> None:
+ match = "dataset must be provided when mode is 'empirical'"
+ with pytest.raises(ValueError, match=match):
+ RCF(mode="empirical", dataset=None)
diff --git a/tests/models/test_resnet.py b/tests/models/test_resnet.py
index f6b6eba0293..90318bb4b0c 100644
--- a/tests/models/test_resnet.py
+++ b/tests/models/test_resnet.py
@@ -47,7 +47,9 @@ def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta["in_chans"]
- sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)}
+ sample = {
+ "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
+ }
mocked_weights.transforms(sample)
@pytest.mark.slow
@@ -82,7 +84,9 @@ def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None:
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta["in_chans"]
- sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)}
+ sample = {
+ "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
+ }
mocked_weights.transforms(sample)
@pytest.mark.slow
diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py
index 37baf9d9c2f..ea4b509ca95 100644
--- a/tests/models/test_vit.py
+++ b/tests/models/test_vit.py
@@ -49,7 +49,9 @@ def test_vit_weights(self, mocked_weights: WeightsEnum) -> None:
def test_transforms(self, mocked_weights: WeightsEnum) -> None:
c = mocked_weights.meta["in_chans"]
- sample = {"image": torch.arange(c * 4 * 4, dtype=torch.float).view(c, 4, 4)}
+ sample = {
+ "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224)
+ }
mocked_weights.transforms(sample)
@pytest.mark.slow
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 00000000000..ab811e1cd6a
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,9 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import subprocess
+import sys
+
+
+def test_help() -> None:
+ subprocess.run([sys.executable, "-m", "torchgeo", "--help"], check=True)
diff --git a/tests/test_train.py b/tests/test_train.py
deleted file mode 100644
index c51e293c374..00000000000
--- a/tests/test_train.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-
-import os
-import re
-import subprocess
-import sys
-from pathlib import Path
-
-import pytest
-
-pytestmark = pytest.mark.slow
-
-
-def test_required_args() -> None:
- args = [sys.executable, "train.py"]
- ps = subprocess.run(args, capture_output=True)
- assert ps.returncode != 0
- assert b"ConfigKeyError" in ps.stderr
-
-
-def test_output_file(tmp_path: Path) -> None:
- output_file = tmp_path / "output"
- output_file.touch()
- args = [
- sys.executable,
- "train.py",
- "experiment.name=test",
- "program.output_dir=" + str(output_file),
- "experiment.task=test",
- ]
- ps = subprocess.run(args, capture_output=True)
- assert ps.returncode != 0
- assert b"NotADirectoryError" in ps.stderr
-
-
-def test_experiment_dir_not_empty(tmp_path: Path) -> None:
- output_dir = tmp_path / "output"
- experiment_dir = output_dir / "test"
- experiment_dir.mkdir(parents=True)
- experiment_file = experiment_dir / "foo"
- experiment_file.touch()
- args = [
- sys.executable,
- "train.py",
- "experiment.name=test",
- "program.output_dir=" + str(output_dir),
- "experiment.task=test",
- ]
- ps = subprocess.run(args, capture_output=True)
- assert ps.returncode != 0
- assert b"FileExistsError" in ps.stderr
-
-
-def test_overwrite_experiment_dir(tmp_path: Path) -> None:
- experiment_name = "test"
- output_dir = tmp_path / "output"
- data_dir = os.path.join("tests", "data", "cyclone")
- log_dir = tmp_path / "logs"
- experiment_dir = output_dir / experiment_name
- experiment_dir.mkdir(parents=True)
- experiment_file = experiment_dir / "foo"
- experiment_file.touch()
- args = [
- sys.executable,
- "train.py",
- "experiment.name=test",
- "program.output_dir=" + str(output_dir),
- "program.data_dir=" + data_dir,
- "program.log_dir=" + str(log_dir),
- "experiment.task=cyclone",
- "experiment.datamodule.root=" + data_dir,
- "program.overwrite=True",
- "trainer.accelerator=cpu",
- "trainer.max_epochs=1",
- ]
- ps = subprocess.run(args, capture_output=True, check=True)
- assert re.search(
- b"The experiment directory, .*, already exists, we might overwrite data in it!",
- ps.stdout,
- )
-
-
-def test_invalid_task(tmp_path: Path) -> None:
- output_dir = tmp_path / "output"
- args = [
- sys.executable,
- "train.py",
- "experiment.name=foo",
- "program.output_dir=" + str(output_dir),
- "experiment.task=foo",
- ]
- ps = subprocess.run(args, capture_output=True)
- assert ps.returncode != 0
- assert b"ValueError" in ps.stderr
-
-
-def test_missing_config_file(tmp_path: Path) -> None:
- output_dir = tmp_path / "output"
- config_file = tmp_path / "config.yaml"
- args = [
- sys.executable,
- "train.py",
- "experiment.name=test",
- "program.output_dir=" + str(output_dir),
- "experiment.task=test",
- "config_file=" + str(config_file),
- ]
- ps = subprocess.run(args, capture_output=True)
- assert ps.returncode != 0
- assert b"FileNotFoundError" in ps.stderr
-
-
-def test_config_file(tmp_path: Path) -> None:
- output_dir = tmp_path / "output"
- data_dir = os.path.join("tests", "data", "cyclone")
- log_dir = tmp_path / "logs"
- config_file = tmp_path / "config.yaml"
- config_file.write_text(
- f"""
-program:
- output_dir: {output_dir}
- data_dir: {data_dir}
- log_dir: {log_dir}
-experiment:
- name: test
- task: cyclone
- datamodule:
- root: {data_dir}
-trainer:
- accelerator: cpu
- max_epochs: 1
-"""
- )
- args = [sys.executable, "train.py", "config_file=" + str(config_file)]
- subprocess.run(args, check=True)
diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py
index 83ad3099ae6..235a1681a70 100644
--- a/tests/trainers/test_byol.py
+++ b/tests/trainers/test_byol.py
@@ -10,20 +10,16 @@
import torch
import torch.nn as nn
import torchvision
-from hydra.utils import instantiate
-from lightning.pytorch import Trainer
-from omegaconf import OmegaConf
from pytest import MonkeyPatch
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
+from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
-from .test_segmentation import SegmentationTestModel
-
def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
state_dict: dict[str, Any] = torch.load(url)
@@ -32,8 +28,8 @@ def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]:
class TestBYOL:
def test_custom_augment_fn(self) -> None:
- backbone = resnet18()
- layer = backbone.conv1
+ model = resnet18()
+ layer = model.conv1
new_layer = nn.Conv2d(
in_channels=4,
out_channels=layer.out_channels,
@@ -42,9 +38,9 @@ def test_custom_augment_fn(self) -> None:
padding=layer.padding,
bias=layer.bias,
).requires_grad_()
- backbone.conv1 = new_layer
+ model.conv1 = new_layer
augment_fn = SimCLRAugmentation((2, 2))
- BYOL(backbone, augment_fn=augment_fn)
+ BYOL(model, augment_fn=augment_fn)
class TestBYOLTask:
@@ -63,7 +59,7 @@ class TestBYOLTask:
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
- conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
+ config = os.path.join("tests", "conf", name + ".yaml")
if name.startswith("seco"):
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)
@@ -71,31 +67,20 @@ def test_trainer(
if name.startswith("ssl4eo_s12"):
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)
- # Instantiate datamodule
- datamodule = instantiate(conf.datamodule)
-
- # Instantiate model
- model = instantiate(conf.module)
- model.backbone = SegmentationTestModel(**conf.module)
-
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
- trainer.fit(model=model, datamodule=datamodule)
-
- @pytest.fixture
- def model_kwargs(self) -> dict[str, Any]:
- return {
- "backbone": "resnet18",
- "in_channels": 13,
- "loss": "ce",
- "num_classes": 10,
- "weights": None,
- }
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
@pytest.fixture
def weights(self) -> WeightsEnum:
@@ -117,41 +102,36 @@ def mocked_weights(
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
- def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
- model_kwargs["weights"] = checkpoint
+ def test_weight_file(self, checkpoint: str) -> None:
with pytest.warns(UserWarning):
- BYOLTask(**model_kwargs)
+ BYOLTask(model="resnet18", in_channels=13, weights=checkpoint)
- def test_weight_enum(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = mocked_weights
- BYOLTask(**model_kwargs)
+ def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
+ BYOLTask(
+ model=mocked_weights.meta["model"],
+ weights=mocked_weights,
+ in_channels=mocked_weights.meta["in_chans"],
+ )
- def test_weight_str(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = str(mocked_weights)
- BYOLTask(**model_kwargs)
+ def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
+ BYOLTask(
+ model=mocked_weights.meta["model"],
+ weights=str(mocked_weights),
+ in_channels=mocked_weights.meta["in_chans"],
+ )
@pytest.mark.slow
- def test_weight_enum_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = weights
- BYOLTask(**model_kwargs)
+ def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+ BYOLTask(
+ model=weights.meta["model"],
+ weights=weights,
+ in_channels=weights.meta["in_chans"],
+ )
@pytest.mark.slow
- def test_weight_str_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = str(weights)
- BYOLTask(**model_kwargs)
+ def test_weight_str_download(self, weights: WeightsEnum) -> None:
+ BYOLTask(
+ model=weights.meta["model"],
+ weights=str(weights),
+ in_channels=weights.meta["in_chans"],
+ )
diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py
index fb8b9c91d69..b5c8243c41a 100644
--- a/tests/trainers/test_classification.py
+++ b/tests/trainers/test_classification.py
@@ -10,9 +10,7 @@
import torch
import torch.nn as nn
import torchvision
-from hydra.utils import instantiate
from lightning.pytorch import Trainer
-from omegaconf import OmegaConf
from pytest import MonkeyPatch
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum
@@ -23,6 +21,7 @@
MisconfigurationException,
)
from torchgeo.datasets import BigEarthNet, EuroSAT
+from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
@@ -87,42 +86,33 @@ def test_trainer(
if name.startswith("so2sat"):
pytest.importorskip("h5py", minversion="3")
- conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
+ config = os.path.join("tests", "conf", name + ".yaml")
- # Instantiate datamodule
- datamodule = instantiate(conf.datamodule)
-
- # Instantiate model
monkeypatch.setattr(timm, "create_model", create_model)
- model = instantiate(conf.module)
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
- trainer.fit(model=model, datamodule=datamodule)
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
try:
- trainer.test(model=model, datamodule=datamodule)
+ main(["test"] + args)
except MisconfigurationException:
pass
try:
- trainer.predict(model=model, datamodule=datamodule)
+ main(["predict"] + args)
except MisconfigurationException:
pass
- @pytest.fixture
- def model_kwargs(self) -> dict[str, Any]:
- return {
- "model": "resnet18",
- "in_channels": 13,
- "loss": "ce",
- "num_classes": 10,
- "weights": None,
- }
-
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@@ -143,61 +133,59 @@ def mocked_weights(
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
- def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
- model_kwargs["weights"] = checkpoint
+ def test_weight_file(self, checkpoint: str) -> None:
with pytest.warns(UserWarning):
- ClassificationTask(**model_kwargs)
+ ClassificationTask(
+ model="resnet18", weights=checkpoint, in_channels=13, num_classes=10
+ )
- def test_weight_enum(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["model"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = mocked_weights
+ def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
with pytest.warns(UserWarning):
- ClassificationTask(**model_kwargs)
-
- def test_weight_str(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["model"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = str(mocked_weights)
+ ClassificationTask(
+ model=mocked_weights.meta["model"],
+ weights=mocked_weights,
+ in_channels=mocked_weights.meta["in_chans"],
+ num_classes=10,
+ )
+
+ def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
with pytest.warns(UserWarning):
- ClassificationTask(**model_kwargs)
+ ClassificationTask(
+ model=mocked_weights.meta["model"],
+ weights=str(mocked_weights),
+ in_channels=mocked_weights.meta["in_chans"],
+ num_classes=10,
+ )
@pytest.mark.slow
- def test_weight_enum_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["model"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = weights
- ClassificationTask(**model_kwargs)
+ def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+ ClassificationTask(
+ model=weights.meta["model"],
+ weights=weights,
+ in_channels=weights.meta["in_chans"],
+ num_classes=10,
+ )
@pytest.mark.slow
- def test_weight_str_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["model"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = str(weights)
- ClassificationTask(**model_kwargs)
+ def test_weight_str_download(self, weights: WeightsEnum) -> None:
+ ClassificationTask(
+ model=weights.meta["model"],
+ weights=str(weights),
+ in_channels=weights.meta["in_chans"],
+ num_classes=10,
+ )
- def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
- model_kwargs["loss"] = "invalid_loss"
+ def test_invalid_loss(self) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
- ClassificationTask(**model_kwargs)
+ ClassificationTask(model="resnet18", loss="invalid_loss")
- def test_no_rgb(
- self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool
- ) -> None:
+ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(EuroSATDataModule, "plot", plot)
datamodule = EuroSATDataModule(
root="tests/data/eurosat", batch_size=1, num_workers=0
)
- model = ClassificationTask(**model_kwargs)
+ model = ClassificationTask(model="resnet18", in_channels=13, num_classes=10)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
@@ -206,11 +194,11 @@ def test_no_rgb(
)
trainer.validate(model=model, datamodule=datamodule)
- def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None:
+ def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictClassificationDataModule(
root="tests/data/eurosat", batch_size=1, num_workers=0
)
- model = ClassificationTask(**model_kwargs)
+ model = ClassificationTask(model="resnet18", in_channels=13, num_classes=10)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
@@ -222,12 +210,8 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None
@pytest.mark.parametrize(
"model_name", ["resnet18", "efficientnetv2_s", "vit_base_patch16_384"]
)
- def test_freeze_backbone(
- self, model_name: str, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["freeze_backbone"] = True
- model_kwargs["model"] = model_name
- model = ClassificationTask(**model_kwargs)
+ def test_freeze_backbone(self, model_name: str) -> None:
+ model = ClassificationTask(model=model_name, freeze_backbone=True)
assert not all([param.requires_grad for param in model.model.parameters()])
assert all(
[param.requires_grad for param in model.model.get_classifier().parameters()]
@@ -241,56 +225,46 @@ class TestMultiLabelClassificationTask:
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
- conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
-
- # Instantiate datamodule
- datamodule = instantiate(conf.datamodule)
+ config = os.path.join("tests", "conf", name + ".yaml")
- # Instantiate model
monkeypatch.setattr(timm, "create_model", create_model)
- model = instantiate(conf.module)
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
- trainer.fit(model=model, datamodule=datamodule)
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
try:
- trainer.test(model=model, datamodule=datamodule)
+ main(["test"] + args)
except MisconfigurationException:
pass
try:
- trainer.predict(model=model, datamodule=datamodule)
+ main(["predict"] + args)
except MisconfigurationException:
pass
- @pytest.fixture
- def model_kwargs(self) -> dict[str, Any]:
- return {
- "model": "resnet18",
- "in_channels": 14,
- "loss": "bce",
- "num_classes": 19,
- "weights": None,
- }
-
- def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
- model_kwargs["loss"] = "invalid_loss"
+ def test_invalid_loss(self) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
- MultiLabelClassificationTask(**model_kwargs)
+ MultiLabelClassificationTask(model="resnet18", loss="invalid_loss")
- def test_no_rgb(
- self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool
- ) -> None:
+ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(BigEarthNetDataModule, "plot", plot)
datamodule = BigEarthNetDataModule(
root="tests/data/bigearthnet", batch_size=1, num_workers=0
)
- model = MultiLabelClassificationTask(**model_kwargs)
+ model = MultiLabelClassificationTask(
+ model="resnet18", in_channels=14, num_classes=19, loss="bce"
+ )
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
@@ -299,11 +273,13 @@ def test_no_rgb(
)
trainer.validate(model=model, datamodule=datamodule)
- def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None:
+ def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictMultiLabelClassificationDataModule(
root="tests/data/bigearthnet", batch_size=1, num_workers=0
)
- model = MultiLabelClassificationTask(**model_kwargs)
+ model = MultiLabelClassificationTask(
+ model="resnet18", in_channels=14, num_classes=19, loss="bce"
+ )
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py
index 77ac3a3d768..a05e1820dbf 100644
--- a/tests/trainers/test_detection.py
+++ b/tests/trainers/test_detection.py
@@ -8,16 +8,18 @@
import torch
import torch.nn as nn
import torchvision.models.detection
-from hydra.utils import instantiate
from lightning.pytorch import Trainer
-from omegaconf import OmegaConf
from pytest import MonkeyPatch
from torch.nn.modules import Module
from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule
from torchgeo.datasets import NASAMarineDebris
+from torchgeo.main import main
from torchgeo.trainers import ObjectDetectionTask
+# MAP metric requires pycocotools to be installed
+pytest.importorskip("pycocotools")
+
class PredictObjectDetectionDataModule(NASAMarineDebrisDataModule):
def setup(self, stage: str) -> None:
@@ -63,12 +65,8 @@ class TestObjectDetectionTask:
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool
) -> None:
- conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml"))
-
- # Instantiate datamodule
- datamodule = instantiate(conf.datamodule)
+ config = os.path.join("tests", "conf", name + ".yaml")
- # Instantiate model
monkeypatch.setattr(
torchvision.models.detection, "FasterRCNN", ObjectDetectionTestModel
)
@@ -78,54 +76,49 @@ def test_trainer(
monkeypatch.setattr(
torchvision.models.detection, "RetinaNet", ObjectDetectionTestModel
)
- conf.module.model = model_name
- model = instantiate(conf.module)
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
- trainer.fit(model=model, datamodule=datamodule)
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
try:
- trainer.test(model=model, datamodule=datamodule)
+ main(["test"] + args)
except MisconfigurationException:
pass
try:
- trainer.predict(model=model, datamodule=datamodule)
+ main(["predict"] + args)
except MisconfigurationException:
pass
- @pytest.fixture
- def model_kwargs(self) -> dict[Any, Any]:
- return {"model": "faster-rcnn", "backbone": "resnet18", "num_classes": 2}
-
- def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None:
- model_kwargs["model"] = "invalid_model"
+ def test_invalid_model(self) -> None:
match = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=match):
- ObjectDetectionTask(**model_kwargs)
+ ObjectDetectionTask(model="invalid_model")
- def test_invalid_backbone(self, model_kwargs: dict[Any, Any]) -> None:
- model_kwargs["backbone"] = "invalid_backbone"
+ def test_invalid_backbone(self) -> None:
match = "Backbone type 'invalid_backbone' is not valid."
with pytest.raises(ValueError, match=match):
- ObjectDetectionTask(**model_kwargs)
+ ObjectDetectionTask(backbone="invalid_backbone")
- def test_non_pretrained_backbone(self, model_kwargs: dict[Any, Any]) -> None:
- model_kwargs["pretrained"] = False
- ObjectDetectionTask(**model_kwargs)
+ def test_pretrained_backbone(self) -> None:
+ ObjectDetectionTask(backbone="resnet18", weights=True)
- def test_no_rgb(
- self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool
- ) -> None:
+ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(NASAMarineDebrisDataModule, "plot", plot)
datamodule = NASAMarineDebrisDataModule(
root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
)
- model = ObjectDetectionTask(**model_kwargs)
+ model = ObjectDetectionTask(backbone="resnet18", num_classes=2)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
@@ -134,11 +127,11 @@ def test_no_rgb(
)
trainer.validate(model=model, datamodule=datamodule)
- def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None:
+ def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictObjectDetectionDataModule(
root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0
)
- model = ObjectDetectionTask(**model_kwargs)
+ model = ObjectDetectionTask(backbone="resnet18", num_classes=2)
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
@@ -148,10 +141,8 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None
trainer.predict(model=model, datamodule=datamodule)
@pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"])
- def test_freeze_backbone(
- self, model_name: str, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["freeze_backbone"] = True
- model_kwargs["model"] = model_name
- model = ObjectDetectionTask(**model_kwargs)
+ def test_freeze_backbone(self, model_name: str) -> None:
+ model = ObjectDetectionTask(
+ model=model_name, backbone="resnet18", freeze_backbone=True
+ )
assert not all([param.requires_grad for param in model.model.parameters()])
diff --git a/tests/trainers/test_moco.py b/tests/trainers/test_moco.py
index ec5acf25bd9..a4d19fc98f2 100644
--- a/tests/trainers/test_moco.py
+++ b/tests/trainers/test_moco.py
@@ -9,14 +9,12 @@
import timm
import torch
import torchvision
-from hydra.utils import instantiate
-from lightning.pytorch import Trainer
-from omegaconf import OmegaConf
from pytest import MonkeyPatch
from torch.nn import Module
from torchvision.models._api import WeightsEnum
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
+from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import MoCoTask
@@ -48,7 +46,7 @@ class TestMoCoTask:
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
- conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
+ config = os.path.join("tests", "conf", name + ".yaml")
if name.startswith("seco"):
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)
@@ -56,21 +54,22 @@ def test_trainer(
if name.startswith("ssl4eo_s12"):
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)
- # Instantiate datamodule
- datamodule = instantiate(conf.datamodule)
-
- # Instantiate model
monkeypatch.setattr(timm, "create_model", create_model)
- model = instantiate(conf.module)
-
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
- trainer.fit(model=model, datamodule=datamodule)
+
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match="MoCo v1 uses a memory bank"):
@@ -105,49 +104,40 @@ def mocked_weights(
return weights
def test_weight_file(self, checkpoint: str) -> None:
- model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
- MoCoTask(**model_kwargs)
+ MoCoTask(model="resnet18", weights=checkpoint)
def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
- model_kwargs: dict[str, Any] = {
- "model": mocked_weights.meta["model"],
- "weights": mocked_weights,
- "in_channels": mocked_weights.meta["in_chans"],
- }
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
- MoCoTask(**model_kwargs)
+ MoCoTask(
+ model=mocked_weights.meta["model"],
+ weights=mocked_weights,
+ in_channels=mocked_weights.meta["in_chans"],
+ )
def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
- model_kwargs: dict[str, Any] = {
- "model": mocked_weights.meta["model"],
- "weights": str(mocked_weights),
- "in_channels": mocked_weights.meta["in_chans"],
- }
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
- MoCoTask(**model_kwargs)
+ MoCoTask(
+ model=mocked_weights.meta["model"],
+ weights=str(mocked_weights),
+ in_channels=mocked_weights.meta["in_chans"],
+ )
@pytest.mark.slow
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
- model_kwargs: dict[str, Any] = {
- "model": weights.meta["model"],
- "weights": weights,
- "in_channels": weights.meta["in_chans"],
- }
- match = "num classes .* != num classes in pretrained model"
- with pytest.warns(UserWarning, match=match):
- MoCoTask(**model_kwargs)
+ MoCoTask(
+ model=weights.meta["model"],
+ weights=weights,
+ in_channels=weights.meta["in_chans"],
+ )
@pytest.mark.slow
def test_weight_str_download(self, weights: WeightsEnum) -> None:
- model_kwargs: dict[str, Any] = {
- "model": weights.meta["model"],
- "weights": str(weights),
- "in_channels": weights.meta["in_chans"],
- }
- match = "num classes .* != num classes in pretrained model"
- with pytest.warns(UserWarning, match=match):
- MoCoTask(**model_kwargs)
+ MoCoTask(
+ model=weights.meta["model"],
+ weights=str(weights),
+ in_channels=weights.meta["in_chans"],
+ )
diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py
index 5bf3443e61b..3f13df0737e 100644
--- a/tests/trainers/test_regression.py
+++ b/tests/trainers/test_regression.py
@@ -11,15 +11,14 @@
import torch
import torch.nn as nn
import torchvision
-from hydra.utils import instantiate
from lightning.pytorch import Trainer
-from omegaconf import OmegaConf
from pytest import MonkeyPatch
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import MisconfigurationException, TropicalCycloneDataModule
from torchgeo.datasets import TropicalCyclone
+from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import PixelwiseRegressionTask, RegressionTask
@@ -56,55 +55,47 @@ def plot(*args: Any, **kwargs: Any) -> None:
raise ValueError
-def create_model(**kwargs: Any) -> Module:
- return PixelwiseRegressionTestModel(**kwargs)
-
-
class TestRegressionTask:
+ @classmethod
+ def create_model(*args: Any, **kwargs: Any) -> Module:
+ return RegressionTestModel(**kwargs)
+
@pytest.mark.parametrize(
"name", ["cowc_counting", "cyclone", "sustainbench_crop_yield", "skippd"]
)
- def test_trainer(self, name: str, fast_dev_run: bool) -> None:
- conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
-
- # Instantiate datamodule
- datamodule = instantiate(conf.datamodule)
-
- # Instantiate model
- model = instantiate(conf.module)
-
- model.model = RegressionTestModel(
- in_chans=conf.module.in_channels, num_classes=conf.module.num_outputs
- )
-
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
-
- trainer.fit(model=model, datamodule=datamodule)
+ def test_trainer(
+ self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
+ ) -> None:
+ if name == "skippd":
+ pytest.importorskip("h5py", minversion="3")
+
+ config = os.path.join("tests", "conf", name + ".yaml")
+
+ monkeypatch.setattr(timm, "create_model", self.create_model)
+
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
try:
- trainer.test(model=model, datamodule=datamodule)
+ main(["test"] + args)
except MisconfigurationException:
pass
try:
- trainer.predict(model=model, datamodule=datamodule)
+ main(["predict"] + args)
except MisconfigurationException:
pass
- @pytest.fixture
- def model_kwargs(self) -> dict[str, Any]:
- return {
- "model": "resnet18",
- "weights": None,
- "num_outputs": 1,
- "in_channels": 3,
- "loss": "mse",
- }
-
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@@ -125,55 +116,48 @@ def mocked_weights(
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
- def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
- model_kwargs["weights"] = checkpoint
+ def test_weight_file(self, checkpoint: str) -> None:
with pytest.warns(UserWarning):
- RegressionTask(**model_kwargs)
+ RegressionTask(model="resnet18", weights=checkpoint)
- def test_weight_enum(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["model"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = mocked_weights
+ def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
with pytest.warns(UserWarning):
- RegressionTask(**model_kwargs)
+ RegressionTask(
+ model=mocked_weights.meta["model"],
+ weights=mocked_weights,
+ in_channels=mocked_weights.meta["in_chans"],
+ )
- def test_weight_str(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["model"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = str(mocked_weights)
+ def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
with pytest.warns(UserWarning):
- RegressionTask(**model_kwargs)
+ RegressionTask(
+ model=mocked_weights.meta["model"],
+ weights=str(mocked_weights),
+ in_channels=mocked_weights.meta["in_chans"],
+ )
@pytest.mark.slow
- def test_weight_enum_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["model"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = weights
- RegressionTask(**model_kwargs)
+ def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+ RegressionTask(
+ model=weights.meta["model"],
+ weights=weights,
+ in_channels=weights.meta["in_chans"],
+ )
@pytest.mark.slow
- def test_weight_str_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["model"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = str(weights)
- RegressionTask(**model_kwargs)
+ def test_weight_str_download(self, weights: WeightsEnum) -> None:
+ RegressionTask(
+ model=weights.meta["model"],
+ weights=str(weights),
+ in_channels=weights.meta["in_chans"],
+ )
- def test_no_rgb(
- self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool
- ) -> None:
+ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(TropicalCycloneDataModule, "plot", plot)
datamodule = TropicalCycloneDataModule(
root="tests/data/cyclone", batch_size=1, num_workers=0
)
- model = RegressionTask(**model_kwargs)
+ model = RegressionTask(model="resnet18")
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
@@ -182,11 +166,11 @@ def test_no_rgb(
)
trainer.validate(model=model, datamodule=datamodule)
- def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None:
+ def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictRegressionDataModule(
root="tests/data/cyclone", batch_size=1, num_workers=0
)
- model = RegressionTask(**model_kwargs)
+ model = RegressionTask(model="resnet18")
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
@@ -195,21 +179,16 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None
)
trainer.predict(model=model, datamodule=datamodule)
- def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None:
- model_kwargs["loss"] = "invalid_loss"
+ def test_invalid_loss(self) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
- RegressionTask(**model_kwargs)
+ RegressionTask(model="resnet18", loss="invalid_loss")
@pytest.mark.parametrize(
"model_name", ["resnet18", "efficientnetv2_s", "vit_base_patch16_384"]
)
- def test_freeze_backbone(
- self, model_name: str, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["freeze_backbone"] = True
- model_kwargs["model"] = model_name
- model = RegressionTask(**model_kwargs)
+ def test_freeze_backbone(self, model_name: str) -> None:
+ model = RegressionTask(model=model_name, freeze_backbone=True)
assert not all([param.requires_grad for param in model.model.parameters()])
assert all(
[param.requires_grad for param in model.model.get_classifier().parameters()]
@@ -217,80 +196,46 @@ def test_freeze_backbone(
class TestPixelwiseRegressionTask:
- @pytest.mark.parametrize(
- "name,batch_size,loss,model_type",
- [
- ("inria", 1, "mse", "unet"),
- ("inria", 2, "mae", "deeplabv3+"),
- ("inria", 1, "mse", "fcn"),
- ],
- )
+ @classmethod
+ def create_model(*args: Any, **kwargs: Any) -> Module:
+ return PixelwiseRegressionTestModel(**kwargs)
+
+ @pytest.mark.parametrize("name", ["inria_unet", "inria_deeplab", "inria_fcn"])
def test_trainer(
- self,
- monkeypatch: MonkeyPatch,
- name: str,
- batch_size: int,
- loss: str,
- model_type: str,
- fast_dev_run: bool,
- model_kwargs: dict[str, Any],
+ self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
- conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
-
- # Instantiate datamodule
- conf.datamodule.batch_size = batch_size
- datamodule = instantiate(conf.datamodule)
-
- # Instantiate model
- monkeypatch.setattr(smp, "Unet", create_model)
- monkeypatch.setattr(smp, "DeepLabV3Plus", create_model)
- model_kwargs["model"] = model_type
- model_kwargs["loss"] = loss
-
- if model_type == "fcn":
- model_kwargs["num_filters"] = 2
-
- model = PixelwiseRegressionTask(**model_kwargs)
- model.model = PixelwiseRegressionTestModel(
- in_channels=model_kwargs["in_channels"]
- )
-
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
-
- trainer.fit(model=model, datamodule=datamodule)
+ config = os.path.join("tests", "conf", name + ".yaml")
+
+ monkeypatch.setattr(smp, "Unet", self.create_model)
+ monkeypatch.setattr(smp, "DeepLabV3Plus", self.create_model)
+
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
try:
- trainer.test(model=model, datamodule=datamodule)
+ main(["test"] + args)
except MisconfigurationException:
pass
try:
- trainer.predict(model=model, datamodule=datamodule)
+ main(["predict"] + args)
except MisconfigurationException:
pass
- def test_invalid_model(self, model_kwargs: dict[str, Any]) -> None:
- model_kwargs["model"] = "invalid_model"
+ def test_invalid_model(self) -> None:
match = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=match):
- PixelwiseRegressionTask(**model_kwargs)
-
- @pytest.fixture
- def model_kwargs(self) -> dict[str, Any]:
- return {
- "model": "unet",
- "backbone": "resnet18",
- "weights": None,
- "num_outputs": 1,
- "in_channels": 3,
- "loss": "mse",
- "learning_rate": 1e-3,
- "learning_rate_schedule_patience": 6,
- }
+ PixelwiseRegressionTask(model="invalid_model")
@pytest.fixture
def weights(self) -> WeightsEnum:
@@ -312,55 +257,51 @@ def mocked_weights(
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
- def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
- model_kwargs["weights"] = checkpoint
- PixelwiseRegressionTask(**model_kwargs)
+ def test_weight_file(self, checkpoint: str) -> None:
+ PixelwiseRegressionTask(model="unet", backbone="resnet18", weights=checkpoint)
- def test_weight_enum(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = mocked_weights
- PixelwiseRegressionTask(**model_kwargs)
+ def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
+ PixelwiseRegressionTask(
+ model="unet",
+ backbone=mocked_weights.meta["model"],
+ weights=mocked_weights,
+ in_channels=mocked_weights.meta["in_chans"],
+ )
- def test_weight_str(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = str(mocked_weights)
- PixelwiseRegressionTask(**model_kwargs)
+ def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
+ PixelwiseRegressionTask(
+ model="unet",
+ backbone=mocked_weights.meta["model"],
+ weights=str(mocked_weights),
+ in_channels=mocked_weights.meta["in_chans"],
+ )
@pytest.mark.slow
- def test_weight_enum_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = weights
- PixelwiseRegressionTask(**model_kwargs)
+ def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+ PixelwiseRegressionTask(
+ model="unet",
+ backbone=weights.meta["model"],
+ weights=weights,
+ in_channels=weights.meta["in_chans"],
+ )
@pytest.mark.slow
- def test_weight_str_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = str(weights)
- PixelwiseRegressionTask(**model_kwargs)
+ def test_weight_str_download(self, weights: WeightsEnum) -> None:
+ PixelwiseRegressionTask(
+ model="unet",
+ backbone=weights.meta["model"],
+ weights=str(weights),
+ in_channels=weights.meta["in_chans"],
+ )
+ @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
@pytest.mark.parametrize(
"backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"]
)
- @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
- def test_freeze_backbone(
- self, backbone: str, model_name: str, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["freeze_backbone"] = True
- model_kwargs["model"] = model_name
- model_kwargs["backbone"] = backbone
- model = PixelwiseRegressionTask(**model_kwargs)
+ def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
+ model = PixelwiseRegressionTask(
+ model=model_name, backbone=backbone, freeze_backbone=True
+ )
assert all(
[param.requires_grad is False for param in model.model.encoder.parameters()]
)
@@ -373,12 +314,10 @@ def test_freeze_backbone(
)
@pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
- def test_freeze_decoder(
- self, model_name: str, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["freeze_decoder"] = True
- model_kwargs["model"] = model_name
- model = PixelwiseRegressionTask(**model_kwargs)
+ def test_freeze_decoder(self, model_name: str) -> None:
+ model = PixelwiseRegressionTask(
+ model=model_name, backbone="resnet18", freeze_decoder=True
+ )
assert all(
[param.requires_grad is False for param in model.model.decoder.parameters()]
)
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index c29d37e24eb..bce3e5b0aaf 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -5,22 +5,20 @@
from pathlib import Path
from typing import Any, cast
-import numpy as np
import pytest
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
import torchvision
-from hydra.utils import instantiate
from lightning.pytorch import Trainer
-from omegaconf import OmegaConf
from pytest import MonkeyPatch
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum
from torchgeo.datamodules import MisconfigurationException, SEN12MSDataModule
from torchgeo.datasets import LandCoverAI
+from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import SemanticSegmentationTask
@@ -85,45 +83,34 @@ def test_trainer(
sha256 = "ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b"
monkeypatch.setattr(LandCoverAI, "sha256", sha256)
- conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
+ config = os.path.join("tests", "conf", name + ".yaml")
- # Instantiate datamodule
- datamodule = instantiate(conf.datamodule)
-
- # Instantiate model
monkeypatch.setattr(smp, "Unet", create_model)
monkeypatch.setattr(smp, "DeepLabV3Plus", create_model)
- model = instantiate(conf.module)
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
- trainer.fit(model=model, datamodule=datamodule)
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
try:
- trainer.test(model=model, datamodule=datamodule)
+ main(["test"] + args)
except MisconfigurationException:
pass
try:
- trainer.predict(model=model, datamodule=datamodule)
+ main(["predict"] + args)
except MisconfigurationException:
pass
- @pytest.fixture
- def model_kwargs(self) -> dict[Any, Any]:
- return {
- "model": "unet",
- "backbone": "resnet18",
- "weights": None,
- "in_channels": 3,
- "num_classes": 6,
- "loss": "ce",
- "ignore_index": 0,
- }
-
@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet18_Weights.SENTINEL2_ALL_MOCO
@@ -144,78 +131,62 @@ def mocked_weights(
monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load)
return weights
- def test_weight_file(self, model_kwargs: dict[str, Any], checkpoint: str) -> None:
- model_kwargs["weights"] = checkpoint
- SemanticSegmentationTask(**model_kwargs)
+ def test_weight_file(self, checkpoint: str) -> None:
+ SemanticSegmentationTask(backbone="resnet18", weights=checkpoint, num_classes=6)
- def test_weight_enum(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = mocked_weights
- SemanticSegmentationTask(**model_kwargs)
+ def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
+ SemanticSegmentationTask(
+ backbone=mocked_weights.meta["model"],
+ weights=mocked_weights,
+ in_channels=mocked_weights.meta["in_chans"],
+ )
- def test_weight_str(
- self, model_kwargs: dict[str, Any], mocked_weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = mocked_weights.meta["model"]
- model_kwargs["in_channels"] = mocked_weights.meta["in_chans"]
- model_kwargs["weights"] = str(mocked_weights)
- SemanticSegmentationTask(**model_kwargs)
+ def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
+ SemanticSegmentationTask(
+ backbone=mocked_weights.meta["model"],
+ weights=str(mocked_weights),
+ in_channels=mocked_weights.meta["in_chans"],
+ )
@pytest.mark.slow
- def test_weight_enum_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = weights
- SemanticSegmentationTask(**model_kwargs)
+ def test_weight_enum_download(self, weights: WeightsEnum) -> None:
+ SemanticSegmentationTask(
+ backbone=weights.meta["model"],
+ weights=weights,
+ in_channels=weights.meta["in_chans"],
+ )
@pytest.mark.slow
- def test_weight_str_download(
- self, model_kwargs: dict[str, Any], weights: WeightsEnum
- ) -> None:
- model_kwargs["backbone"] = weights.meta["model"]
- model_kwargs["in_channels"] = weights.meta["in_chans"]
- model_kwargs["weights"] = str(weights)
- SemanticSegmentationTask(**model_kwargs)
+ def test_weight_str_download(self, weights: WeightsEnum) -> None:
+ SemanticSegmentationTask(
+ backbone=weights.meta["model"],
+ weights=str(weights),
+ in_channels=weights.meta["in_chans"],
+ )
- def test_invalid_model(self, model_kwargs: dict[Any, Any]) -> None:
- model_kwargs["model"] = "invalid_model"
+ def test_invalid_model(self) -> None:
match = "Model type 'invalid_model' is not valid."
with pytest.raises(ValueError, match=match):
- SemanticSegmentationTask(**model_kwargs)
+ SemanticSegmentationTask(model="invalid_model")
- def test_invalid_loss(self, model_kwargs: dict[Any, Any]) -> None:
- model_kwargs["loss"] = "invalid_loss"
+ def test_invalid_loss(self) -> None:
match = "Loss type 'invalid_loss' is not valid."
with pytest.raises(ValueError, match=match):
- SemanticSegmentationTask(**model_kwargs)
-
- def test_invalid_ignoreindex(self, model_kwargs: dict[Any, Any]) -> None:
- model_kwargs["ignore_index"] = "0"
- match = "ignore_index must be an int or None"
- with pytest.raises(ValueError, match=match):
- SemanticSegmentationTask(**model_kwargs)
+ SemanticSegmentationTask(loss="invalid_loss")
- def test_ignoreindex_with_jaccard(self, model_kwargs: dict[Any, Any]) -> None:
- model_kwargs["loss"] = "jaccard"
- model_kwargs["ignore_index"] = 0
+ def test_ignoreindex_with_jaccard(self) -> None:
match = "ignore_index has no effect on training when loss='jaccard'"
with pytest.warns(UserWarning, match=match):
- SemanticSegmentationTask(**model_kwargs)
+ SemanticSegmentationTask(loss="jaccard", ignore_index=0)
- def test_no_rgb(
- self, monkeypatch: MonkeyPatch, model_kwargs: dict[Any, Any], fast_dev_run: bool
- ) -> None:
- model_kwargs["in_channels"] = 15
+ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(SEN12MSDataModule, "plot", plot)
datamodule = SEN12MSDataModule(
root="tests/data/sen12ms", batch_size=1, num_workers=0
)
- model = SemanticSegmentationTask(**model_kwargs)
+ model = SemanticSegmentationTask(
+ backbone="resnet18", in_channels=15, num_classes=6
+ )
trainer = Trainer(
accelerator="cpu",
fast_dev_run=fast_dev_run,
@@ -224,17 +195,14 @@ def test_no_rgb(
)
trainer.validate(model=model, datamodule=datamodule)
+ @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
@pytest.mark.parametrize(
"backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"]
)
- @pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
- def test_freeze_backbone(
- self, backbone: str, model_name: str, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["freeze_backbone"] = True
- model_kwargs["model"] = model_name
- model_kwargs["backbone"] = backbone
- model = SemanticSegmentationTask(**model_kwargs)
+ def test_freeze_backbone(self, model_name: str, backbone: str) -> None:
+ model = SemanticSegmentationTask(
+ model=model_name, backbone=backbone, freeze_backbone=True
+ )
assert all(
[param.requires_grad is False for param in model.model.encoder.parameters()]
)
@@ -247,12 +215,8 @@ def test_freeze_backbone(
)
@pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
- def test_freeze_decoder(
- self, model_name: str, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["freeze_decoder"] = True
- model_kwargs["model"] = model_name
- model = SemanticSegmentationTask(**model_kwargs)
+ def test_freeze_decoder(self, model_name: str) -> None:
+ model = SemanticSegmentationTask(model=model_name, freeze_decoder=True)
assert all(
[param.requires_grad is False for param in model.model.decoder.parameters()]
)
@@ -263,23 +227,3 @@ def test_freeze_decoder(
for param in model.model.segmentation_head.parameters()
]
)
-
- @pytest.mark.parametrize(
- "class_weights", [torch.tensor([1, 2, 3]), np.array([1, 2, 3]), [1, 2, 3]]
- )
- def test_classweights_valid(
- self, class_weights: Any, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["class_weights"] = class_weights
- sst = SemanticSegmentationTask(**model_kwargs)
- assert isinstance(sst.loss.weight, torch.Tensor)
- assert torch.equal(sst.loss.weight, torch.tensor([1.0, 2.0, 3.0]))
- assert sst.loss.weight.dtype == torch.float32
-
- @pytest.mark.parametrize("class_weights", [[], None])
- def test_classweights_empty(
- self, class_weights: Any, model_kwargs: dict[Any, Any]
- ) -> None:
- model_kwargs["class_weights"] = class_weights
- sst = SemanticSegmentationTask(**model_kwargs)
- assert sst.loss.weight is None
diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py
index 5fa8f15ea15..b7629cb3654 100644
--- a/tests/trainers/test_simclr.py
+++ b/tests/trainers/test_simclr.py
@@ -9,14 +9,12 @@
import timm
import torch
import torchvision
-from hydra.utils import instantiate
-from lightning.pytorch import Trainer
-from omegaconf import OmegaConf
from pytest import MonkeyPatch
from torch.nn import Module
from torchvision.models._api import WeightsEnum
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
+from torchgeo.main import main
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import SimCLRTask
@@ -48,7 +46,7 @@ class TestSimCLRTask:
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
- conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml"))
+ config = os.path.join("tests", "conf", name + ".yaml")
if name.startswith("seco"):
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)
@@ -56,21 +54,22 @@ def test_trainer(
if name.startswith("ssl4eo_s12"):
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)
- # Instantiate datamodule
- datamodule = instantiate(conf.datamodule)
-
- # Instantiate model
monkeypatch.setattr(timm, "create_model", create_model)
- model = instantiate(conf.module)
-
- # Instantiate trainer
- trainer = Trainer(
- accelerator="cpu",
- fast_dev_run=fast_dev_run,
- log_every_n_steps=1,
- max_epochs=1,
- )
- trainer.fit(model=model, datamodule=datamodule)
+
+ args = [
+ "--config",
+ config,
+ "--trainer.accelerator",
+ "cpu",
+ "--trainer.fast_dev_run",
+ str(fast_dev_run),
+ "--trainer.max_epochs",
+ "1",
+ "--trainer.log_every_n_steps",
+ "1",
+ ]
+
+ main(["fit"] + args)
def test_version_warnings(self) -> None:
with pytest.warns(UserWarning, match="SimCLR v1 only uses 2 layers"):
@@ -103,49 +102,40 @@ def mocked_weights(
return weights
def test_weight_file(self, checkpoint: str) -> None:
- model_kwargs: dict[str, Any] = {"model": "resnet18", "weights": checkpoint}
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
- SimCLRTask(**model_kwargs)
+ SimCLRTask(model="resnet18", weights=checkpoint)
def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
- model_kwargs: dict[str, Any] = {
- "model": mocked_weights.meta["model"],
- "weights": mocked_weights,
- "in_channels": mocked_weights.meta["in_chans"],
- }
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
- SimCLRTask(**model_kwargs)
+ SimCLRTask(
+ model=mocked_weights.meta["model"],
+ weights=mocked_weights,
+ in_channels=mocked_weights.meta["in_chans"],
+ )
def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
- model_kwargs: dict[str, Any] = {
- "model": mocked_weights.meta["model"],
- "weights": str(mocked_weights),
- "in_channels": mocked_weights.meta["in_chans"],
- }
match = "num classes .* != num classes in pretrained model"
with pytest.warns(UserWarning, match=match):
- SimCLRTask(**model_kwargs)
+ SimCLRTask(
+ model=mocked_weights.meta["model"],
+ weights=str(mocked_weights),
+ in_channels=mocked_weights.meta["in_chans"],
+ )
@pytest.mark.slow
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
- model_kwargs: dict[str, Any] = {
- "model": weights.meta["model"],
- "weights": weights,
- "in_channels": weights.meta["in_chans"],
- }
- match = "num classes .* != num classes in pretrained model"
- with pytest.warns(UserWarning, match=match):
- SimCLRTask(**model_kwargs)
+ SimCLRTask(
+ model=weights.meta["model"],
+ weights=weights,
+ in_channels=weights.meta["in_chans"],
+ )
@pytest.mark.slow
def test_weight_str_download(self, weights: WeightsEnum) -> None:
- model_kwargs: dict[str, Any] = {
- "model": weights.meta["model"],
- "weights": str(weights),
- "in_channels": weights.meta["in_chans"],
- }
- match = "num classes .* != num classes in pretrained model"
- with pytest.warns(UserWarning, match=match):
- SimCLRTask(**model_kwargs)
+ SimCLRTask(
+ model=weights.meta["model"],
+ weights=str(weights),
+ in_channels=weights.meta["in_chans"],
+ )
diff --git a/torchgeo/__init__.py b/torchgeo/__init__.py
index f09454f830d..0c1615db840 100644
--- a/torchgeo/__init__.py
+++ b/torchgeo/__init__.py
@@ -11,4 +11,4 @@
"""
__author__ = "Adam J. Stewart"
-__version__ = "0.5.0.dev0"
+__version__ = "0.6.0.dev0"
diff --git a/torchgeo/__main__.py b/torchgeo/__main__.py
new file mode 100644
index 00000000000..66e073ed1cb
--- /dev/null
+++ b/torchgeo/__main__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Command-line interface to TorchGeo."""
+
+from torchgeo.main import main
+
+main()
diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py
index 8761e850e18..66555c7b978 100644
--- a/torchgeo/datamodules/__init__.py
+++ b/torchgeo/datamodules/__init__.py
@@ -18,6 +18,7 @@
from .l7irish import L7IrishDataModule
from .l8biome import L8BiomeDataModule
from .landcoverai import LandCoverAIDataModule
+from .levircd import LEVIRCDPlusDataModule
from .loveda import LoveDADataModule
from .naip import NAIPChesapeakeDataModule
from .nasa_marine_debris import NASAMarineDebrisDataModule
@@ -56,6 +57,7 @@
"GID15DataModule",
"InriaAerialImageLabelingDataModule",
"LandCoverAIDataModule",
+ "LEVIRCDPlusDataModule",
"LoveDADataModule",
"NASAMarineDebrisDataModule",
"OSCDDataModule",
diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py
index f1d2418dffe..604c0d7e3ec 100644
--- a/torchgeo/datamodules/chesapeake.py
+++ b/torchgeo/datamodules/chesapeake.py
@@ -93,7 +93,7 @@ def __init__(
"""
# This is a rough estimate of how large of a patch we will need to sample in
# EPSG:3857 in order to guarantee a large enough patch in the local CRS.
- self.original_patch_size = patch_size * 2
+ self.original_patch_size = patch_size * 3
kwargs["transforms"] = _Transform(K.CenterCrop(patch_size))
super().__init__(
diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py
index b4267cfc50a..ccf2a90f691 100644
--- a/torchgeo/datamodules/eurosat.py
+++ b/torchgeo/datamodules/eurosat.py
@@ -10,41 +10,37 @@
from ..datasets import EuroSAT, EuroSAT100
from .geo import NonGeoDataModule
-MEAN = torch.tensor(
- [
- 1354.40546513,
- 1118.24399958,
- 1042.92983953,
- 947.62620298,
- 1199.47283961,
- 1999.79090914,
- 2369.22292565,
- 2296.82608323,
- 732.08340178,
- 12.11327804,
- 1819.01027855,
- 1118.92391149,
- 2594.14080798,
- ]
-)
-
-STD = torch.tensor(
- [
- 245.71762908,
- 333.00778264,
- 395.09249139,
- 593.75055589,
- 566.4170017,
- 861.18399006,
- 1086.63139075,
- 1117.98170791,
- 404.91978886,
- 4.77584468,
- 1002.58768311,
- 761.30323499,
- 1231.58581042,
- ]
-)
+MEAN = {
+ "B01": 1354.40546513,
+ "B02": 1118.24399958,
+ "B03": 1042.92983953,
+ "B04": 947.62620298,
+ "B05": 1199.47283961,
+ "B06": 1999.79090914,
+ "B07": 2369.22292565,
+ "B08": 2296.82608323,
+ "B8A": 732.08340178,
+ "B09": 12.11327804,
+ "B10": 1819.01027855,
+ "B11": 1118.92391149,
+ "B12": 2594.14080798,
+}
+
+STD = {
+ "B01": 245.71762908,
+ "B02": 333.00778264,
+ "B03": 395.09249139,
+ "B04": 593.75055589,
+ "B05": 566.4170017,
+ "B06": 861.18399006,
+ "B07": 1086.63139075,
+ "B08": 1117.98170791,
+ "B8A": 404.91978886,
+ "B09": 4.77584468,
+ "B10": 1002.58768311,
+ "B11": 761.30323499,
+ "B12": 1231.58581042,
+}
class EuroSATDataModule(NonGeoDataModule):
@@ -55,9 +51,6 @@ class EuroSATDataModule(NonGeoDataModule):
.. versionadded:: 0.2
"""
- mean = MEAN
- std = STD
-
def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
@@ -71,6 +64,10 @@ def __init__(
"""
super().__init__(EuroSAT, batch_size, num_workers, **kwargs)
+ bands = kwargs.get("bands", EuroSAT.all_band_names)
+ self.mean = torch.tensor([MEAN[b] for b in bands])
+ self.std = torch.tensor([STD[b] for b in bands])
+
class EuroSAT100DataModule(NonGeoDataModule):
"""LightningDataModule implementation for the EuroSAT100 dataset.
@@ -80,9 +77,6 @@ class EuroSAT100DataModule(NonGeoDataModule):
.. versionadded:: 0.5
"""
- mean = MEAN
- std = STD
-
def __init__(
self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
) -> None:
@@ -95,3 +89,7 @@ def __init__(
:class:`~torchgeo.datasets.EuroSAT100`.
"""
super().__init__(EuroSAT100, batch_size, num_workers, **kwargs)
+
+ bands = kwargs.get("bands", EuroSAT.all_band_names)
+ self.mean = torch.tensor([MEAN[b] for b in bands])
+ self.std = torch.tensor([STD[b] for b in bands])
diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py
index c0e1b714710..96da7683ba2 100644
--- a/torchgeo/datamodules/geo.py
+++ b/torchgeo/datamodules/geo.py
@@ -6,9 +6,9 @@
from typing import Any, Callable, Optional, Union, cast
import kornia.augmentation as K
-import matplotlib.pyplot as plt
import torch
from lightning.pytorch import LightningDataModule
+from matplotlib.figure import Figure
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, default_collate
@@ -141,7 +141,7 @@ def on_after_batch_transfer(
return batch
- def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
+ def plot(self, *args: Any, **kwargs: Any) -> Optional[Figure]:
"""Run the plot method of the validation dataset if one exists.
Should only be called during 'fit' or 'validate' stages as ``val_dataset``
@@ -154,10 +154,12 @@ def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
Returns:
A matplotlib Figure with the image, ground truth, and predictions.
"""
+ fig: Optional[Figure] = None
dataset = self.dataset or self.val_dataset
if dataset is not None:
if hasattr(dataset, "plot"):
- return dataset.plot(*args, **kwargs)
+ fig = dataset.plot(*args, **kwargs)
+ return fig
class GeoDataModule(BaseDataModule):
diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py
index 698273b9485..ca46c39dc6e 100644
--- a/torchgeo/datamodules/inria.py
+++ b/torchgeo/datamodules/inria.py
@@ -12,7 +12,6 @@
from ..transforms import AugmentationSequential
from ..transforms.transforms import _RandomNCrop
from .geo import NonGeoDataModule
-from .utils import dataset_split
class InriaAerialImageLabelingDataModule(NonGeoDataModule):
@@ -29,8 +28,6 @@ def __init__(
batch_size: int = 64,
patch_size: Union[tuple[int, int], int] = 64,
num_workers: int = 0,
- val_split_pct: float = 0.1,
- test_split_pct: float = 0.1,
**kwargs: Any,
) -> None:
"""Initialize a new InriaAerialImageLabelingDataModule instance.
@@ -40,16 +37,12 @@ def __init__(
patch_size: Size of each patch, either ``size`` or ``(height, width)``.
Should be a multiple of 32 for most segmentation architectures.
num_workers: Number of workers for parallel data loading.
- val_split_pct: Percentage of the dataset to use as a validation set.
- test_split_pct: Percentage of the dataset to use as a test set.
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.InriaAerialImageLabeling`.
"""
super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs)
self.patch_size = _to_tuple(patch_size)
- self.val_split_pct = val_split_pct
- self.test_split_pct = test_split_pct
self.train_aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
@@ -75,11 +68,10 @@ def setup(self, stage: str) -> None:
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
- if stage in ["fit", "validate", "test"]:
- self.dataset = InriaAerialImageLabeling(split="train", **self.kwargs)
- self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
- self.dataset, self.val_split_pct, self.test_split_pct
- )
+ if stage in ["fit"]:
+ self.train_dataset = InriaAerialImageLabeling(split="train", **self.kwargs)
+ if stage in ["fit", "validate"]:
+ self.val_dataset = InriaAerialImageLabeling(split="val", **self.kwargs)
if stage in ["predict"]:
# Test set masks are not public, use for prediction instead
self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs)
diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py
new file mode 100644
index 00000000000..b021d8c860b
--- /dev/null
+++ b/torchgeo/datamodules/levircd.py
@@ -0,0 +1,70 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""LEVIR-CD+ datamodule."""
+
+from typing import Any, Union
+
+import kornia.augmentation as K
+
+from torchgeo.datamodules.utils import dataset_split
+from torchgeo.samplers.utils import _to_tuple
+
+from ..datasets import LEVIRCDPlus
+from ..transforms import AugmentationSequential
+from ..transforms.transforms import _RandomNCrop
+from .geo import NonGeoDataModule
+
+
+class LEVIRCDPlusDataModule(NonGeoDataModule):
+ """LightningDataModule implementation for the LEVIR-CD+ dataset.
+
+ Uses the train/test splits from the dataset and further splits
+ the train split into train/val splits.
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(
+ self,
+ batch_size: int = 8,
+ patch_size: Union[tuple[int, int], int] = 256,
+ val_split_pct: float = 0.2,
+ num_workers: int = 0,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize a new LEVIRCDPlusDataModule instance.
+
+ Args:
+ batch_size: Size of each mini-batch.
+ patch_size: Size of each patch, either ``size`` or ``(height, width)``.
+ Should be a multiple of 32 for most segmentation architectures.
+ val_split_pct: Percentage of the dataset to use as a validation set.
+ num_workers: Number of workers for parallel data loading.
+ **kwargs: Additional keyword arguments passed to
+ :class:`~torchgeo.datasets.LEVIRCDPlus`.
+ """
+ super().__init__(LEVIRCDPlus, 1, num_workers, **kwargs)
+
+ self.patch_size = _to_tuple(patch_size)
+ self.val_split_pct = val_split_pct
+
+ self.aug = AugmentationSequential(
+ K.Normalize(mean=self.mean, std=self.std),
+ _RandomNCrop(self.patch_size, batch_size),
+ data_keys=["image1", "image2", "mask"],
+ )
+
+ def setup(self, stage: str) -> None:
+ """Set up datasets.
+
+ Args:
+ stage: Either 'fit', 'validate', 'test', or 'predict'.
+ """
+ if stage in ["fit", "validate"]:
+ self.dataset = LEVIRCDPlus(split="train", **self.kwargs)
+ self.train_dataset, self.val_dataset = dataset_split(
+ self.dataset, val_pct=self.val_split_pct
+ )
+ if stage in ["test"]:
+ self.test_dataset = LEVIRCDPlus(split="test", **self.kwargs)
diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py
index 28f64c5826c..1631734d094 100644
--- a/torchgeo/datamodules/naip.py
+++ b/torchgeo/datamodules/naip.py
@@ -6,7 +6,7 @@
from typing import Any, Optional, Union
import kornia.augmentation as K
-import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from ..datasets import NAIP, BoundingBox, Chesapeake13
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
@@ -95,7 +95,7 @@ def setup(self, stage: str) -> None:
self.dataset, self.patch_size, self.patch_size, test_roi
)
- def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
+ def plot(self, *args: Any, **kwargs: Any) -> Figure:
"""Run NAIP plot method.
Args:
diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py
index 748c4038091..19f34677065 100644
--- a/torchgeo/datamodules/oscd.py
+++ b/torchgeo/datamodules/oscd.py
@@ -7,7 +7,6 @@
import kornia.augmentation as K
import torch
-from einops import repeat
from ..datasets import OSCD
from ..samplers.utils import _to_tuple
@@ -16,6 +15,38 @@
from .geo import NonGeoDataModule
from .utils import dataset_split
+MEAN = {
+ "B01": 1583.0741,
+ "B02": 1374.3202,
+ "B03": 1294.1616,
+ "B04": 1325.6158,
+ "B05": 1478.7408,
+ "B06": 1933.0822,
+ "B07": 2166.0608,
+ "B08": 2076.4868,
+ "B8A": 2306.0652,
+ "B09": 690.9814,
+ "B10": 16.2360,
+ "B11": 2080.3347,
+ "B12": 1524.6930,
+}
+
+STD = {
+ "B01": 52.1937,
+ "B02": 83.4168,
+ "B03": 105.6966,
+ "B04": 151.1401,
+ "B05": 147.4615,
+ "B06": 115.9289,
+ "B07": 123.1974,
+ "B08": 114.6483,
+ "B8A": 141.4530,
+ "B09": 73.2758,
+ "B10": 4.8368,
+ "B11": 213.4821,
+ "B12": 179.4793,
+}
+
class OSCDDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the OSCD dataset.
@@ -26,42 +57,6 @@ class OSCDDataModule(NonGeoDataModule):
.. versionadded:: 0.2
"""
- mean = torch.tensor(
- [
- 1583.0741,
- 1374.3202,
- 1294.1616,
- 1325.6158,
- 1478.7408,
- 1933.0822,
- 2166.0608,
- 2076.4868,
- 2306.0652,
- 690.9814,
- 16.2360,
- 2080.3347,
- 1524.6930,
- ]
- )
-
- std = torch.tensor(
- [
- 52.1937,
- 83.4168,
- 105.6966,
- 151.1401,
- 147.4615,
- 115.9289,
- 123.1974,
- 114.6483,
- 141.4530,
- 73.2758,
- 4.8368,
- 213.4821,
- 179.4793,
- ]
- )
-
def __init__(
self,
batch_size: int = 64,
@@ -86,19 +81,14 @@ def __init__(
self.patch_size = _to_tuple(patch_size)
self.val_split_pct = val_split_pct
- self.bands = kwargs.get("bands", "all")
- if self.bands == "rgb":
- self.mean = self.mean[[3, 2, 1]]
- self.std = self.std[[3, 2, 1]]
-
- # Change detection, 2 images from different times
- self.mean = repeat(self.mean, "c -> (t c)", t=2)
- self.std = repeat(self.std, "c -> (t c)", t=2)
+ self.bands = kwargs.get("bands", OSCD.all_bands)
+ self.mean = torch.tensor([MEAN[b] for b in self.bands])
+ self.std = torch.tensor([STD[b] for b in self.bands])
self.aug = AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
_RandomNCrop(self.patch_size, batch_size),
- data_keys=["image", "mask"],
+ data_keys=["image1", "image2", "mask"],
)
def setup(self, stage: str) -> None:
diff --git a/torchgeo/datamodules/skippd.py b/torchgeo/datamodules/skippd.py
index 8367f776ff6..b76eb3e1e92 100644
--- a/torchgeo/datamodules/skippd.py
+++ b/torchgeo/datamodules/skippd.py
@@ -15,6 +15,8 @@ class SKIPPDDataModule(NonGeoDataModule):
Implements 80/20 train/val splits on train_val set.
See :func:`setup` for more details.
+
+ .. versionadded:: 0.5
"""
def __init__(
diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py
index e7cc9be05a3..ec3e85097e9 100644
--- a/torchgeo/datamodules/so2sat.py
+++ b/torchgeo/datamodules/so2sat.py
@@ -183,7 +183,9 @@ def __init__(
.. versionadded:: 0.5
The *val_split_pct* parameter, and the 'rgb' argument to *band_set*.
"""
- version = kwargs.get("version", "2")
+ # https://github.com/Lightning-AI/lightning/issues/18616
+ kwargs["version"] = str(kwargs.get("version", "2"))
+ version = kwargs["version"]
kwargs["bands"] = So2Sat.BAND_SETS[band_set]
self.val_split_pct = val_split_pct
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 59900a46a7c..0a2261aab3b 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -8,6 +8,7 @@
from .astergdem import AsterGDEM
from .benin_cashews import BeninSmallHolderCashews
from .bigearthnet import BigEarthNet
+from .biomassters import BioMassters
from .cbf import CanadianBuildingFootprints
from .cdl import CDL
from .chesapeake import (
@@ -72,6 +73,7 @@
)
from .levircd import LEVIRCDPlus
from .loveda import LoveDA
+from .mapinwild import MapInWild
from .millionaid import MillionAID
from .naip import NAIP
from .nasa_marine_debris import NASAMarineDebris
@@ -84,6 +86,8 @@
from .reforestree import ReforesTree
from .resisc45 import RESISC45
from .rioxr import RioXarrayDataset
+from .rwanda_field_boundary import RwandaFieldBoundary
+from .seasonet import SeasoNet
from .seco import SeasonalContrastS2
from .sen12ms import SEN12MS
from .sentinel import Sentinel, Sentinel1, Sentinel2
@@ -113,6 +117,7 @@
from .usavars import USAVars
from .utils import (
BoundingBox,
+ DatasetNotFoundError,
concat_samples,
merge_samples,
stack_samples,
@@ -173,6 +178,7 @@
"ADVANCE",
"BeninSmallHolderCashews",
"BigEarthNet",
+ "BioMassters",
"CloudCoverDetection",
"COWC",
"COWCCounting",
@@ -193,6 +199,7 @@
"LandCoverAI",
"LEVIRCDPlus",
"LoveDA",
+ "MapInWild",
"MillionAID",
"NASAMarineDebris",
"OSCD",
@@ -201,7 +208,9 @@
"Potsdam2D",
"RESISC45",
"ReforesTree",
+ "RwandaFieldBoundary",
"SeasonalContrastS2",
+ "SeasoNet",
"SEN12MS",
"SKIPPD",
"So2Sat",
@@ -247,4 +256,6 @@
"random_grid_cell_assignment",
"roi_split",
"time_series_split",
+ # Errors
+ "DatasetNotFoundError",
)
diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py
index e8d75b13fcf..3618db0fa9b 100644
--- a/torchgeo/datasets/advance.py
+++ b/torchgeo/datasets/advance.py
@@ -10,11 +10,12 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_and_extract_archive
+from .utils import DatasetNotFoundError, download_and_extract_archive
class ADVANCE(NonGeoDataset):
@@ -100,8 +101,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.transforms = transforms
@@ -111,10 +111,7 @@ def __init__(
self._download()
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
self.files = self._load_files(self.root)
self.classes = sorted({f["cls"] for f in self.files})
@@ -217,11 +214,7 @@ def _check_integrity(self) -> bool:
return True
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- AssertionError: if the checksum of split.py does not match
- """
+ """Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
return
@@ -236,7 +229,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py
index 4b7d4ac6462..6c5959a0660 100644
--- a/torchgeo/datasets/agb_live_woody_density.py
+++ b/torchgeo/datasets/agb_live_woody_density.py
@@ -3,16 +3,17 @@
"""Aboveground Live Woody Biomass Density dataset."""
-import glob
import json
import os
-from typing import Any, Callable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
-from .utils import download_url
+from .utils import DatasetNotFoundError, download_url
class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
@@ -43,10 +44,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
is_image = False
- url = (
- "https://opendata.arcgis.com/api/v3/datasets/3e8736c8866b458687"
- "e00d40c9f00bce_0/downloads/data?format=geojson&spatialRefId=4326"
- )
+ url = "https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326" # noqa: E501
base_filename = "Aboveground_Live_Woody_Biomass_Density.geojson"
@@ -58,7 +56,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -68,7 +66,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -79,48 +77,43 @@ def __init__(
cache: if True, cache file handle to speed up repeated sampling
Raises:
- FileNotFoundError: if no files are found in ``root``
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.download = download
self._verify()
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if dataset is missing
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
- pathname = os.path.join(self.root, self.filename_glob)
- if glob.glob(pathname):
+ if self.files:
return
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
def _download(self) -> None:
"""Download the dataset."""
- download_url(self.url, self.root, self.base_filename)
+ assert isinstance(self.paths, str)
+ download_url(self.url, self.paths, self.base_filename)
- with open(os.path.join(self.root, self.base_filename)) as f:
+ with open(os.path.join(self.paths, self.base_filename)) as f:
content = json.load(f)
for item in content["features"]:
download_url(
- item["properties"]["download"],
- self.root,
+ item["properties"]["Mg_px_1_download"],
+ self.paths,
item["properties"]["tile_id"] + ".tif",
)
@@ -129,7 +122,7 @@ def plot(
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py
index 0bc1d886457..99215f774b4 100644
--- a/torchgeo/datasets/astergdem.py
+++ b/torchgeo/datasets/astergdem.py
@@ -3,14 +3,14 @@
"""Aster Global Digital Elevation Model dataset."""
-import glob
-import os
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
+from .utils import DatasetNotFoundError
class AsterGDEM(RasterDataset):
@@ -46,7 +46,7 @@ class AsterGDEM(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -55,8 +55,8 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found, here the collection of
- individual zip files for each tile should be found
+ paths: one or more root directories to search or files to load, here
+ the collection of individual zip files for each tile should be found
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -66,38 +66,31 @@ def __init__(
cache: if True, cache file handle to speed up repeated sampling
Raises:
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if dataset is missing
+ DatasetNotFoundError: If dataset is not found.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self._verify()
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if dataset is missing
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exists
- pathname = os.path.join(self.root, self.filename_glob)
- if glob.glob(pathname):
+ if self.files:
return
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` "
- "either specify a different `root` directory or make sure you "
- "have manually downloaded dataset tiles as suggested in the documentation."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py
index 371c1985ab3..dccf29cedf0 100644
--- a/torchgeo/datasets/benin_cashews.py
+++ b/torchgeo/datasets/benin_cashews.py
@@ -13,11 +13,17 @@
import rasterio
import rasterio.features
import torch
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_radiant_mlhub_collection,
+ extract_archive,
+)
# TODO: read geospatial information from stac.json files
@@ -197,7 +203,7 @@ def __init__(
verbose: if True, print messages when new tiles are loaded
Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self._validate_bands(bands)
@@ -213,10 +219,7 @@ def __init__(
self._download(api_key)
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
# Calculate the indices that we will use over all tiles
self.chips_metadata = []
@@ -431,7 +434,7 @@ def plot(
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py
index 2118265c569..9a127248a8b 100644
--- a/torchgeo/datasets/bigearthnet.py
+++ b/torchgeo/datasets/bigearthnet.py
@@ -12,11 +12,17 @@
import numpy as np
import rasterio
import torch
+from matplotlib.figure import Figure
from rasterio.enums import Resampling
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive, sort_sentinel2_bands
+from .utils import (
+ DatasetNotFoundError,
+ download_url,
+ extract_archive,
+ sort_sentinel2_bands,
+)
class BigEarthNet(NonGeoDataset):
@@ -284,6 +290,9 @@ def __init__(
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits_metadata
assert bands in ["s1", "s2", "all"]
@@ -433,11 +442,7 @@ def _load_target(self, index: int) -> Tensor:
return target
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
keys = ["s1", "s2"] if self.bands == "all" else [self.bands]
urls = [self.metadata[k]["url"] for k in keys]
md5s = [self.metadata[k]["md5"] for k in keys]
@@ -477,11 +482,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
for url, filename, md5 in zip(urls, filenames, md5s):
@@ -533,7 +534,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py
new file mode 100644
index 00000000000..970c5594950
--- /dev/null
+++ b/torchgeo/datasets/biomassters.py
@@ -0,0 +1,288 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""BioMassters Dataset."""
+
+import os
+from collections.abc import Sequence
+from typing import Optional
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import rasterio
+import torch
+from matplotlib.figure import Figure
+from torch import Tensor
+
+from .geo import NonGeoDataset
+from .utils import DatasetNotFoundError, percentile_normalization
+
+
+class BioMassters(NonGeoDataset):
+ """BioMassters Dataset for Aboveground Biomass prediction.
+
+ Dataset intended for Aboveground Biomass (AGB) prediction
+ over Finnish forests based on Sentinel 1 and 2 data with
+ corresponding target AGB mask values generated by Light Detection
+ and Ranging (LiDAR).
+
+ Dataset Format:
+
+ * .tif files for Sentinel 1 and 2 data
+ * .tif file for pixel wise AGB target mask
+ * .csv files for metadata regarding features and targets
+
+ Dataset Features:
+
+ * 13,000 target AGB masks of size (256x256px)
+ * 12 months of data per target mask
+ * Sentinel 1 and Sentinel 2 data for each location
+ * Sentinel 1 available for every month
+ * Sentinel 2 available for almost every month
+ (not available for every month due to ESA aquisition halt over the region
+ during particular periods)
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://nascetti-a.github.io/BioMasster/
+
+ .. versionadded:: 0.5
+ """
+
+ valid_splits = ["train", "test"]
+ valid_sensors = ("S1", "S2")
+
+ metadata_filename = "The_BioMassters_-_features_metadata.csv.csv"
+
+ def __init__(
+ self,
+ root: str = "data",
+ split: str = "train",
+ sensors: Sequence[str] = ["S1", "S2"],
+ as_time_series: bool = False,
+ ) -> None:
+ """Initialize a new instance of BioMassters dataset.
+
+ If ``as_time_series=False`` (the default), each time step becomes its own
+ sample with the target being shared across multiple samples.
+
+ Args:
+ root: root directory where dataset can be found
+ split: train or test split
+ sensors: which sensors to consider for the sample, Sentinel 1 and/or
+ Sentinel 2 ('S1', 'S2')
+ as_time_series: whether or not to return all available
+ time-steps or just a single one for a given target location
+
+ Raises:
+ AssertionError: if ``split`` or ``sensors`` is invalid
+ DatasetNotFoundError: If dataset is not found.
+ """
+ self.root = root
+
+ assert (
+ split in self.valid_splits
+ ), f"Please choose one of the valid splits: {self.valid_splits}."
+ self.split = split
+
+ assert set(sensors).issubset(
+ set(self.valid_sensors)
+ ), f"Please choose a subset of valid sensors: {self.valid_sensors}."
+ self.sensors = sensors
+ self.as_time_series = as_time_series
+
+ self._verify()
+
+ # open metadata csv files
+ self.df = pd.read_csv(os.path.join(self.root, self.metadata_filename))
+
+ # filter sensors
+ self.df = self.df[self.df["satellite"].isin(self.sensors)]
+
+ # filter split
+ self.df = self.df[self.df["split"] == self.split]
+
+ # generate numerical month from filename since first month is September
+ # and has numerical index of 0
+ self.df["num_month"] = (
+ self.df["filename"]
+ .str.split("_", expand=True)[2]
+ .str.split(".", expand=True)[0]
+ .astype(int)
+ )
+
+ # set dataframe index depending on the task for easier indexing
+ if self.as_time_series:
+ self.df["num_index"] = self.df.groupby(["chip_id"]).ngroup()
+ else:
+ filter_df = (
+ self.df.groupby(["chip_id", "month"])["satellite"].count().reset_index()
+ )
+ filter_df = filter_df[filter_df["satellite"] == len(self.sensors)].drop(
+ "satellite", axis=1
+ )
+ # guarantee that each sample has corresponding number of images available
+ self.df = self.df.merge(filter_df, on=["chip_id", "month"], how="inner")
+
+ self.df["num_index"] = self.df.groupby(["chip_id", "month"]).ngroup()
+
+ def __getitem__(self, index: int) -> dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ data and labels at that index
+
+ Raises:
+ IndexError: if index is out of range of the dataset
+ """
+ sample_df = self.df[self.df["num_index"] == index].copy()
+
+ # sort by satellite and month to return correct order
+ sample_df.sort_values(
+ by=["satellite", "num_month"], inplace=True, ascending=True
+ )
+
+ filepaths = sample_df["filename"].tolist()
+ sample: dict[str, Tensor] = {}
+ for sens in self.sensors:
+ sens_filepaths = [fp for fp in filepaths if sens in fp]
+ sample[f"image_{sens}"] = self._load_input(sens_filepaths)
+
+ if self.split == "train":
+ sample["label"] = self._load_target(
+ sample_df["corresponding_agbm"].unique()[0]
+ )
+
+ return sample
+
+ def __len__(self) -> int:
+ """Return the length of the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return len(self.df["num_index"].unique())
+
+ def _load_input(self, filenames: list[str]) -> Tensor:
+ """Load the input imagery at the index.
+
+ Args:
+ filenames: list of filenames corresponding to input
+
+ Returns:
+ input image
+ """
+ filepaths = [
+ os.path.join(self.root, f"{self.split}_features", f) for f in filenames
+ ]
+ arr_list = [rasterio.open(fp).read() for fp in filepaths]
+ if self.as_time_series:
+ arr = np.stack(arr_list, axis=0)
+ else:
+ arr = np.concatenate(arr_list, axis=0)
+ return torch.tensor(arr.astype(np.int32))
+
+ def _load_target(self, filename: str) -> Tensor:
+ """Load the target mask at the index.
+
+ Args:
+ filename: filename of target to index
+
+ Returns:
+ target mask
+ """
+ with rasterio.open(os.path.join(self.root, "train_agbm", filename), "r") as src:
+ arr: "np.typing.NDArray[np.float_]" = src.read()
+
+ target = torch.from_numpy(arr).float()
+ return target
+
+ def _verify(self) -> None:
+ """Verify the integrity of the dataset."""
+ # Check if the extracted files already exist
+ exists = []
+
+ filenames = [f"{self.split}_features", self.metadata_filename]
+ for filename in filenames:
+ pathname = os.path.join(self.root, filename)
+ exists.append(os.path.exists(pathname))
+ if all(exists):
+ return
+
+ raise DatasetNotFoundError(self)
+
+ def plot(
+ self,
+ sample: dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ ) -> Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample return by :meth:`__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ suptitle: optional suptitle to use for figure
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+ """
+ ncols = len(self.sensors) + 1
+
+ showing_predictions = "prediction" in sample
+ if showing_predictions:
+ ncols += 1
+
+ fig, axs = plt.subplots(1, ncols=ncols, figsize=(5 * ncols, 10))
+ for idx, sens in enumerate(self.sensors):
+ img = sample[f"image_{sens}"].numpy()
+ if self.as_time_series:
+ # plot last time step
+ img = img[-1, ...]
+ if sens == "S2":
+ img = img[[2, 1, 0], ...]
+ img = percentile_normalization(img.transpose(1, 2, 0))
+ else:
+ co_polarization = img[0] # transmit == receive
+ cross_polarization = img[1] # transmit != receive
+ ratio = co_polarization / cross_polarization
+
+ # https://gis.stackexchange.com/a/400780/123758
+ co_polarization = np.clip(co_polarization / 0.3, a_min=0, a_max=1)
+ cross_polarization = np.clip(
+ cross_polarization / 0.05, a_min=0, a_max=1
+ )
+ ratio = np.clip(ratio / 25, a_min=0, a_max=1)
+
+ img = np.stack((co_polarization, cross_polarization, ratio), axis=-1)
+
+ axs[idx].imshow(img)
+ axs[idx].axis("off")
+ if show_titles:
+ axs[idx].set_title(sens)
+
+ if showing_predictions:
+ pred = axs[ncols - 2].imshow(
+ sample["prediction"].permute(1, 2, 0), cmap="YlGn"
+ )
+ plt.colorbar(pred, ax=axs[ncols - 2], fraction=0.046, pad=0.04)
+ axs[ncols - 2].axis("off")
+ if show_titles:
+ axs[ncols - 2].set_title("Prediction")
+
+ # plot target / only available in train set
+ if "label" in sample:
+ target = axs[-1].imshow(sample["label"].permute(1, 2, 0), cmap="YlGn")
+ plt.colorbar(target, ax=axs[-1], fraction=0.046, pad=0.04)
+ axs[-1].axis("Off")
+ if show_titles:
+ axs[-1].set_title("Target")
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+
+ return fig
diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py
index 625357c12ce..d3010e44dce 100644
--- a/torchgeo/datasets/cbf.py
+++ b/torchgeo/datasets/cbf.py
@@ -4,13 +4,15 @@
"""Canadian Building Footprints dataset."""
import os
-from typing import Any, Callable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import VectorDataset
-from .utils import check_integrity, download_and_extract_archive
+from .utils import DatasetNotFoundError, check_integrity, download_and_extract_archive
class CanadianBuildingFootprints(VectorDataset):
@@ -59,7 +61,7 @@ class CanadianBuildingFootprints(VectorDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.00001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -69,7 +71,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -79,23 +81,21 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if ``download=False`` and data is not found, or
- ``checksum=True`` and checksums don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.checksum = checksum
if download:
self._download()
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
- super().__init__(root, crs, res, transforms)
+ super().__init__(paths, crs, res, transforms)
def _check_integrity(self) -> bool:
"""Check integrity of dataset.
@@ -103,8 +103,9 @@ def _check_integrity(self) -> bool:
Returns:
True if dataset files are found and/or MD5s match, else False
"""
+ assert isinstance(self.paths, str)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
- filepath = os.path.join(self.root, prov_terr + ".zip")
+ filepath = os.path.join(self.paths, prov_terr + ".zip")
if not check_integrity(filepath, md5 if self.checksum else None):
return False
return True
@@ -114,11 +115,11 @@ def _download(self) -> None:
if self._check_integrity():
print("Files already downloaded and verified")
return
-
+ assert isinstance(self.paths, str)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
download_and_extract_archive(
self.url + prov_terr + ".zip",
- self.root,
+ self.paths,
md5=md5 if self.checksum else None,
)
@@ -127,7 +128,7 @@ def plot(
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py
index 324a15dcdf0..bad87db572a 100644
--- a/torchgeo/datasets/cdl.py
+++ b/torchgeo/datasets/cdl.py
@@ -3,16 +3,17 @@
"""CDL dataset."""
-import glob
import os
-from typing import Any, Callable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
import torch
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
-from .utils import BoundingBox, download_url, extract_archive
+from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class CDL(RasterDataset):
@@ -204,7 +205,7 @@ class CDL(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2022],
@@ -217,7 +218,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -233,11 +234,13 @@ def __init__(
Raises:
AssertionError: if ``years`` or ``classes`` are invalid
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
.. versionadded:: 0.5
The *years* and *classes* parameters.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
assert set(years) <= self.md5s.keys(), (
"CDL data product only exists for the following years: "
@@ -248,7 +251,7 @@ def __init__(
), f"Only the following classes are valid: {list(self.cmap.keys())}."
assert 0 in classes, "Classes must include the background class: 0"
- self.root = root
+ self.paths = paths
self.years = years
self.classes = classes
self.download = download
@@ -258,7 +261,7 @@ def __init__(
self._verify()
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
# Map chosen classes to ordinal numbers, all others mapped to background class
for v, k in enumerate(self.classes):
@@ -282,28 +285,17 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
return sample
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
- exists = []
- for year in self.years:
- filename_year = self.filename_glob.replace("*", str(year))
- pathname = os.path.join(self.root, "**", filename_year)
- for fname in glob.iglob(pathname, recursive=True):
- if not fname.endswith(".zip"):
- exists.append(True)
-
- if len(exists) == len(self.years):
+ if self.files:
return
# Check if the zip files have already been downloaded
exists = []
+ assert isinstance(self.paths, str)
for year in self.years:
pathname = os.path.join(
- self.root, self.zipfile_glob.replace("*", str(year))
+ self.paths, self.zipfile_glob.replace("*", str(year))
)
if os.path.exists(pathname):
exists.append(True)
@@ -316,11 +308,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -331,23 +319,24 @@ def _download(self) -> None:
for year in self.years:
download_url(
self.url.format(year),
- self.root,
+ self.paths,
md5=self.md5s[year] if self.checksum else None,
)
def _extract(self) -> None:
"""Extract the dataset."""
+ assert isinstance(self.paths, str)
for year in self.years:
zipfile_name = self.zipfile_glob.replace("*", str(year))
- pathname = os.path.join(self.root, zipfile_name)
- extract_archive(pathname, self.root)
+ pathname = os.path.join(self.paths, zipfile_name)
+ extract_archive(pathname, self.paths)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py
index a472c9b20bf..17b0ee8e74a 100644
--- a/torchgeo/datasets/chesapeake.py
+++ b/torchgeo/datasets/chesapeake.py
@@ -6,8 +6,8 @@
import abc
import os
import sys
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, cast
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, Union, cast
import fiona
import matplotlib.pyplot as plt
@@ -19,11 +19,12 @@
import shapely.ops
import torch
from matplotlib.colors import ListedColormap
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
from .geo import GeoDataset, RasterDataset
-from .utils import BoundingBox, download_url, extract_archive
+from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class Chesapeake(RasterDataset, abc.ABC):
@@ -88,7 +89,7 @@ def url(self) -> str:
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -99,7 +100,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -111,10 +112,12 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.download = download
self.checksum = checksum
@@ -131,30 +134,23 @@ def __init__(
)
self._cmap = ListedColormap(colors)
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted file already exists
- if os.path.exists(os.path.join(self.root, self.filename)):
+ if self.files:
return
# Check if the zip file has already been downloaded
- if os.path.exists(os.path.join(self.root, self.zipfile)):
+ assert isinstance(self.paths, str)
+ if os.path.exists(os.path.join(self.paths, self.zipfile)):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -162,18 +158,19 @@ def _verify(self) -> None:
def _download(self) -> None:
"""Download the dataset."""
- download_url(self.url, self.root, filename=self.zipfile, md5=self.md5)
+ download_url(self.url, self.paths, filename=self.zipfile, md5=self.md5)
def _extract(self) -> None:
"""Extract the dataset."""
- extract_archive(os.path.join(self.root, self.zipfile))
+ assert isinstance(self.paths, str)
+ extract_archive(os.path.join(self.paths, self.zipfile))
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
@@ -489,7 +486,7 @@ class ChesapeakeCVPR(GeoDataset):
)
# these are used to check the integrity of the dataset
- files = [
+ _files = [
"de_1m_2013_extended-debuffered-test_tiles",
"de_1m_2013_extended-debuffered-train_tiles",
"de_1m_2013_extended-debuffered-val_tiles",
@@ -556,9 +553,8 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
AssertionError: if ``splits`` or ``layers`` are not valid
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
for split in splits:
assert split in self.splits
@@ -688,17 +684,13 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
return sample
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, filename))
# Check if the extracted files already exist
- if all(map(exists, self.files)):
+ if all(map(exists, self._files)):
return
# Check if the zip files have already been downloaded
@@ -713,11 +705,7 @@ def exists(filename: str) -> bool:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -743,7 +731,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py
index a3c344bd944..63a6433f630 100644
--- a/torchgeo/datasets/cloud_cover.py
+++ b/torchgeo/datasets/cloud_cover.py
@@ -12,10 +12,16 @@
import numpy as np
import rasterio
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_radiant_mlhub_collection,
+ extract_archive,
+)
# TODO: read geospatial information from stac.json files
@@ -122,7 +128,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.split = split
@@ -136,10 +142,7 @@ def __init__(
self._download(api_key)
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
self.chip_paths = self._load_collections()
@@ -330,9 +333,6 @@ def _download(self, api_key: Optional[str] = None) -> None:
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
"""
if self._check_integrity():
print("Files already downloaded and verified")
@@ -355,7 +355,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py
index 1339691cf72..ac42c8d1ead 100644
--- a/torchgeo/datasets/cms_mangrove_canopy.py
+++ b/torchgeo/datasets/cms_mangrove_canopy.py
@@ -3,15 +3,15 @@
"""CMS Global Mangrove Canopy dataset."""
-import glob
import os
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
-from .utils import check_integrity, extract_archive
+from .utils import DatasetNotFoundError, check_integrity, extract_archive
class CMSGlobalMangroveCanopy(RasterDataset):
@@ -167,7 +167,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
measurement: str = "agb",
@@ -179,7 +179,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -192,11 +192,13 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if dataset is missing or checksum fails
AssertionError: if country or measurement arg are not str or invalid
+ DatasetNotFoundError: If dataset is not found.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.checksum = checksum
assert isinstance(country, str), "Country argument must be a str."
@@ -219,36 +221,29 @@ def __init__(
self._verify()
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
- pathname = os.path.join(self.root, "**", self.filename_glob)
- if glob.glob(pathname):
+ if self.files:
return
# Check if the zip file has already been downloaded
- pathname = os.path.join(self.root, self.zipfile)
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, self.zipfile)
if os.path.exists(pathname):
if self.checksum and not check_integrity(pathname, self.md5):
raise RuntimeError("Dataset found, but corrupted.")
self._extract()
return
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` "
- "either specify a different `root` directory or make sure you "
- "have manually downloaded the dataset as instructed in the documentation."
- )
+ raise DatasetNotFoundError(self)
def _extract(self) -> None:
"""Extract the dataset."""
- pathname = os.path.join(self.root, self.zipfile)
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, self.zipfile)
extract_archive(pathname)
def plot(
@@ -256,7 +251,7 @@ def plot(
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py
index e9d8dc02b7f..0e0518502ad 100644
--- a/torchgeo/datasets/cowc.py
+++ b/torchgeo/datasets/cowc.py
@@ -11,11 +11,12 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_and_extract_archive
+from .utils import DatasetNotFoundError, check_integrity, download_and_extract_archive
class COWC(NonGeoDataset, abc.ABC):
@@ -80,8 +81,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in ["train", "test"]
@@ -94,10 +94,7 @@ def __init__(
self._download()
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
self.images = []
self.targets = []
@@ -196,7 +193,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py
index 989b7dc2f76..1dc6a2c36a7 100644
--- a/torchgeo/datasets/cv4a_kenya_crop_type.py
+++ b/torchgeo/datasets/cv4a_kenya_crop_type.py
@@ -11,11 +11,17 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_radiant_mlhub_collection,
+ extract_archive,
+)
# TODO: read geospatial information from stac.json files
@@ -140,7 +146,7 @@ def __init__(
verbose: if True, print messages when new tiles are loaded
Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self._validate_bands(bands)
@@ -156,10 +162,7 @@ def __init__(
self._download(api_key)
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
# Calculate the indices that we will use over all tiles
self.chips_metadata = []
@@ -389,9 +392,6 @@ def _download(self, api_key: Optional[str] = None) -> None:
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
"""
if self._check_integrity():
print("Files already downloaded and verified")
@@ -411,7 +411,7 @@ def plot(
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py
index 4022f5a84f9..b0cc6f8f1c4 100644
--- a/torchgeo/datasets/cyclone.py
+++ b/torchgeo/datasets/cyclone.py
@@ -11,11 +11,17 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_radiant_mlhub_collection,
+ extract_archive,
+)
class TropicalCyclone(NonGeoDataset):
@@ -85,7 +91,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.md5s
@@ -98,10 +104,7 @@ def __init__(
self._download(api_key)
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
output_dir = "_".join([self.collection_id, split, "source"])
filename = os.path.join(root, output_dir, "collection.json")
@@ -205,9 +208,6 @@ def _download(self, api_key: Optional[str] = None) -> None:
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
"""
if self._check_integrity():
print("Files already downloaded and verified")
@@ -227,7 +227,7 @@ def plot(
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py
index 694da07f3d1..233c70cd049 100644
--- a/torchgeo/datasets/deepglobelandcover.py
+++ b/torchgeo/datasets/deepglobelandcover.py
@@ -15,6 +15,7 @@
from .geo import NonGeoDataset
from .utils import (
+ DatasetNotFoundError,
check_integrity,
draw_semantic_segmentation_masks,
extract_archive,
@@ -102,6 +103,9 @@ def __init__(
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ DatasetNotFoundError: If dataset is not found.
"""
assert split in self.splits
self.root = root
@@ -195,11 +199,7 @@ def _load_target(self, index: int) -> Tensor:
return tensor
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if checksum fails or the dataset is not downloaded
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
if os.path.exists(os.path.join(self.root, self.data_root)):
return
@@ -213,11 +213,7 @@ def _verify(self) -> None:
extract_archive(filepath)
return
- # Check if the user requested to download the dataset
- raise RuntimeError(
- "Dataset not found in `root`, either specify a different"
- + " `root` directory or manually download the dataset to this directory."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py
index f17a080e94c..5268ea4f0e5 100644
--- a/torchgeo/datasets/dfc2022.py
+++ b/torchgeo/datasets/dfc2022.py
@@ -13,11 +13,17 @@
import rasterio
import torch
from matplotlib import colors
+from matplotlib.figure import Figure
from rasterio.enums import Resampling
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, extract_archive, percentile_normalization
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ extract_archive,
+ percentile_normalization,
+)
class DFC2022(NonGeoDataset):
@@ -152,6 +158,7 @@ def __init__(
Raises:
AssertionError: if ``split`` is invalid
+ DatasetNotFoundError: If dataset is not found.
"""
assert split in self.metadata
self.root = root
@@ -257,11 +264,7 @@ def _load_target(self, path: str) -> Tensor:
return tensor
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if checksum fails or the dataset is not downloaded
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
exists = []
for split_info in self.metadata.values():
@@ -287,18 +290,14 @@ def _verify(self) -> None:
if all(exists):
return
- # Check if the user requested to download the dataset
- raise RuntimeError(
- "Dataset not found in `root` directory, either specify a different"
- + " `root` directory or manually download the dataset to this directory."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/eddmaps.py b/torchgeo/datasets/eddmaps.py
index 8dfa5a7c957..94e409c4d07 100644
--- a/torchgeo/datasets/eddmaps.py
+++ b/torchgeo/datasets/eddmaps.py
@@ -8,10 +8,11 @@
from typing import Any
import numpy as np
+import pandas as pd
from rasterio.crs import CRS
from .geo import GeoDataset
-from .utils import BoundingBox, disambiguate_timestamp
+from .utils import BoundingBox, DatasetNotFoundError, disambiguate_timestamp
class EDDMapS(GeoDataset):
@@ -34,11 +35,6 @@ class EDDMapS(GeoDataset):
Georgia - Center for Invasive Species and Ecosystem Health. Available online at
https://www.eddmaps.org/; last accessed *DATE*.
- .. note::
- This dataset requires the following additional library to be installed:
-
- * `pandas `_ to load CSV files
-
.. versionadded:: 0.3
"""
@@ -52,8 +48,7 @@ def __init__(self, root: str = "data") -> None:
root: root directory where dataset can be found
Raises:
- FileNotFoundError: if no files are found in ``root``
- ImportError: if pandas is not installed
+ DatasetNotFoundError: If dataset is not found.
"""
super().__init__()
@@ -61,14 +56,7 @@ def __init__(self, root: str = "data") -> None:
filepath = os.path.join(root, "mappings.csv")
if not os.path.exists(filepath):
- raise FileNotFoundError(f"Dataset not found in `root={self.root}`")
-
- try:
- import pandas as pd # noqa: F401
- except ImportError:
- raise ImportError(
- "pandas is not installed and is required to use this dataset"
- )
+ raise DatasetNotFoundError(self)
# Read CSV file
data = pd.read_csv(
diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py
index 8487389927f..f6842738abb 100644
--- a/torchgeo/datasets/enviroatlas.py
+++ b/torchgeo/datasets/enviroatlas.py
@@ -18,10 +18,11 @@
import shapely.ops
import torch
from matplotlib.colors import ListedColormap
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import GeoDataset
-from .utils import BoundingBox, download_url, extract_archive
+from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class EnviroAtlas(GeoDataset):
@@ -79,7 +80,7 @@ class EnviroAtlas(GeoDataset):
)
# these are used to check the integrity of the dataset
- files = [
+ _files = [
"austin_tx-2012_1m-test_tiles-debuffered",
"austin_tx-2012_1m-val5_tiles-debuffered",
"durham_nc-2012_1m-test_tiles-debuffered",
@@ -277,9 +278,8 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
AssertionError: if ``splits`` or ``layers`` are not valid
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
for split in splits:
assert split in self.splits
@@ -411,17 +411,13 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
return sample
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))
# Check if the extracted files already exist
- if all(map(exists, self.files)):
+ if all(map(exists, self._files)):
return
# Check if the zip files have already been downloaded
@@ -431,11 +427,7 @@ def exists(filename: str) -> bool:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -454,7 +446,7 @@ def plot(
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Note: only plots the "naip" and "lc" layers.
diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py
index 64c4cc82d33..2d26f24565a 100644
--- a/torchgeo/datasets/esri2020.py
+++ b/torchgeo/datasets/esri2020.py
@@ -5,13 +5,15 @@
import glob
import os
-from typing import Any, Callable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class Esri2020(RasterDataset):
@@ -66,7 +68,7 @@ class Esri2020(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -77,7 +79,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -89,41 +91,35 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.download = download
self.checksum = checksum
self._verify()
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted file already exists
- pathname = os.path.join(self.root, "**", self.filename_glob)
- if glob.glob(pathname):
+ if self.files:
return
# Check if the zip files have already been downloaded
- pathname = os.path.join(self.root, self.zipfile)
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, self.zipfile)
if glob.glob(pathname):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -131,18 +127,19 @@ def _verify(self) -> None:
def _download(self) -> None:
"""Download the dataset."""
- download_url(self.url, self.root, filename=self.zipfile, md5=self.md5)
+ download_url(self.url, self.paths, filename=self.zipfile, md5=self.md5)
def _extract(self) -> None:
"""Extract the dataset."""
- extract_archive(os.path.join(self.root, self.zipfile))
+ assert isinstance(self.paths, str)
+ extract_archive(os.path.join(self.paths, self.zipfile))
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py
index cdb42a6850d..7dfa50fb2ab 100644
--- a/torchgeo/datasets/etci2021.py
+++ b/torchgeo/datasets/etci2021.py
@@ -10,11 +10,12 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_and_extract_archive
+from .utils import DatasetNotFoundError, download_and_extract_archive
class ETCI2021(NonGeoDataset):
@@ -97,8 +98,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.metadata.keys()
@@ -111,10 +111,7 @@ def __init__(
self._download()
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
self.files = self._load_files(self.root, self.split)
@@ -242,11 +239,7 @@ def _check_integrity(self) -> bool:
return True
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- AssertionError: if the checksum of split.py does not match
- """
+ """Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
return
@@ -263,7 +256,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py
index 5a2134f84dc..ce82500ee7d 100644
--- a/torchgeo/datasets/eudem.py
+++ b/torchgeo/datasets/eudem.py
@@ -5,13 +5,15 @@
import glob
import os
-from typing import Any, Callable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
-from .utils import check_integrity, extract_archive
+from .utils import DatasetNotFoundError, check_integrity, extract_archive
class EUDEM(RasterDataset):
@@ -81,7 +83,7 @@ class EUDEM(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -91,8 +93,8 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found, here the collection of
- individual zip files for each tile should be found
+ paths: one or more root directories to search or files to load, here
+ the collection of individual zip files for each tile should be found
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -103,28 +105,27 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
+ DatasetNotFoundError: If dataset is not found.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.checksum = checksum
self._verify()
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted file already exists
- pathname = os.path.join(self.root, self.filename_glob)
- if glob.glob(pathname):
+ if self.files:
return
# Check if the zip files have already been downloaded
- pathname = os.path.join(self.root, self.zipfile_glob)
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, self.zipfile_glob)
if glob.glob(pathname):
for zipfile in glob.iglob(pathname):
filename = os.path.basename(zipfile)
@@ -133,18 +134,14 @@ def _verify(self) -> None:
extract_archive(zipfile)
return
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` "
- "either specify a different `root` directory or make sure you "
- "have manually downloaded the dataset as suggested in the documentation."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py
index 74a5de64536..319f10e1c65 100644
--- a/torchgeo/datasets/eurosat.py
+++ b/torchgeo/datasets/eurosat.py
@@ -10,10 +10,17 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
-from .utils import check_integrity, download_url, extract_archive, rasterio_loader
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_url,
+ extract_archive,
+ rasterio_loader,
+)
class EuroSAT(NonGeoClassificationDataset):
@@ -30,16 +37,16 @@ class EuroSAT(NonGeoClassificationDataset):
Dataset classes:
- * Industrial Buildings
- * Residential Buildings
* Annual Crop
- * Permanent Crop
- * River
- * Sea and Lake
+ * Forest
* Herbaceous Vegetation
* Highway
+ * Industrial Buildings
* Pasture
- * Forest
+ * Permanent Crop
+ * Residential Buildings
+ * River
+ * SeaLake
This dataset uses the train/val/test splits defined in the "In-domain representation
learning for remote sensing" paper:
@@ -72,18 +79,6 @@ class EuroSAT(NonGeoClassificationDataset):
"val": "95de90f2aa998f70a3b2416bfe0687b4",
"test": "7ae5ab94471417b6e315763121e67c5f",
}
- classes = [
- "Industrial Buildings",
- "Residential Buildings",
- "Annual Crop",
- "Permanent Crop",
- "River",
- "Sea and Lake",
- "Herbaceous Vegetation",
- "Highway",
- "Pasture",
- "Forest",
- ]
all_band_names = (
"B01",
@@ -94,7 +89,7 @@ class EuroSAT(NonGeoClassificationDataset):
"B06",
"B07",
"B08",
- "B08A",
+ "B8A",
"B09",
"B10",
"B11",
@@ -127,8 +122,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
.. versionadded:: 0.3
The *bands* parameter.
@@ -191,11 +185,7 @@ def _check_integrity(self) -> bool:
return integrity
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
filepath = os.path.join(self.root, self.base_dir)
if os.path.exists(filepath):
@@ -208,11 +198,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
self._download()
@@ -261,7 +247,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py
index b9f89cb081b..24cc7b97d2f 100644
--- a/torchgeo/datasets/fair1m.py
+++ b/torchgeo/datasets/fair1m.py
@@ -12,11 +12,12 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_url, extract_archive
+from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive
def parse_pascal_voc(path: str) -> dict[str, Any]:
@@ -243,8 +244,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found.
.. versionchanged:: 0.5
Added *split* and *download* parameters.
@@ -328,11 +328,7 @@ def _load_target(
return boxes, labels_tensor
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if checksum fails or the dataset is not found
- """
+ """Verify the integrity of the dataset."""
# Check if the directories already exist
exists = []
for directory in self.directories[self.split]:
@@ -361,18 +357,10 @@ def _verify(self) -> None:
self._download()
return
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
- """
+ """Download the dataset and extract it."""
paths = self.paths[self.split]
urls = self.urls[self.split]
md5s = self.md5s[self.split]
@@ -395,7 +383,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py
index 0d01534aa60..d51a4384f81 100644
--- a/torchgeo/datasets/fire_risk.py
+++ b/torchgeo/datasets/fire_risk.py
@@ -7,10 +7,11 @@
from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class FireRisk(NonGeoClassificationDataset):
@@ -83,7 +84,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
self.root = root
@@ -97,11 +98,7 @@ def __init__(
)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
path = os.path.join(self.root, self.directory)
if os.path.exists(path):
@@ -115,11 +112,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
self._download()
@@ -144,7 +137,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py
index 1cb6fa567d8..b74f3bcd3bc 100644
--- a/torchgeo/datasets/forestdamage.py
+++ b/torchgeo/datasets/forestdamage.py
@@ -12,11 +12,17 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_and_extract_archive, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_and_extract_archive,
+ extract_archive,
+)
def parse_pascal_voc(path: str) -> dict[str, Any]:
@@ -118,8 +124,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.transforms = transforms
@@ -236,21 +241,13 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory, either specify a different"
- + " `root` directory or manually download "
- + "the dataset to this directory."
- )
+ raise DatasetNotFoundError(self)
# else download the dataset
self._download()
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- AssertionError: if the checksum does not match
- """
+ """Download the dataset and extract it."""
download_and_extract_archive(
self.url,
self.root,
@@ -263,7 +260,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/gbif.py b/torchgeo/datasets/gbif.py
index 17e6952ab33..a34cc8ba685 100644
--- a/torchgeo/datasets/gbif.py
+++ b/torchgeo/datasets/gbif.py
@@ -10,10 +10,11 @@
from typing import Any
import numpy as np
+import pandas as pd
from rasterio.crs import CRS
from .geo import GeoDataset
-from .utils import BoundingBox
+from .utils import BoundingBox, DatasetNotFoundError
def _disambiguate_timestamps(
@@ -72,11 +73,6 @@ class GBIF(GeoDataset):
* https://www.gbif.org/citation-guidelines
- .. note::
- This dataset requires the following additional library to be installed:
-
- * `pandas `_ to load CSV files
-
.. versionadded:: 0.3
"""
@@ -90,8 +86,7 @@ def __init__(self, root: str = "data") -> None:
root: root directory where dataset can be found
Raises:
- FileNotFoundError: if no files are found in ``root``
- ImportError: if pandas is not installed
+ DatasetNotFoundError: If dataset is not found.
"""
super().__init__()
@@ -99,14 +94,7 @@ def __init__(self, root: str = "data") -> None:
files = glob.glob(os.path.join(root, "**.csv"))
if not files:
- raise FileNotFoundError(f"Dataset not found in `root={self.root}`")
-
- try:
- import pandas as pd # noqa: F401
- except ImportError:
- raise ImportError(
- "pandas is not installed and is required to use this dataset"
- )
+ raise DatasetNotFoundError(self)
# Read tab-delimited CSV file
data = pd.read_table(
diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py
index dc2b5fa1388..34db093e08c 100644
--- a/torchgeo/datasets/geo.py
+++ b/torchgeo/datasets/geo.py
@@ -9,8 +9,9 @@
import os
import re
import sys
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, cast
+import warnings
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, Union, cast
import fiona
import fiona.transform
@@ -29,7 +30,14 @@
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader as pil_loader
-from .utils import BoundingBox, concat_samples, disambiguate_timestamp, merge_samples
+from .utils import (
+ BoundingBox,
+ DatasetNotFoundError,
+ concat_samples,
+ disambiguate_timestamp,
+ merge_samples,
+ path_is_vsi,
+)
class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
@@ -72,9 +80,17 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC):
dataset = landsat7 | landsat8
"""
+ paths: Union[str, Iterable[str]]
_crs = CRS.from_epsg(4326)
_res = 0.0
+ #: Glob expression used to search for files.
+ #:
+ #: This expression should be specific enough that it will not pick up files from
+ #: other datasets. It should not include a file extension, as the dataset may be in
+ #: a different file format than what it was originally downloaded as.
+ filename_glob = "*"
+
# NOTE: according to the Python docs:
#
# * https://docs.python.org/3/library/exceptions.html#NotImplementedError
@@ -269,17 +285,42 @@ def res(self, new_res: float) -> None:
print(f"Converting {self.__class__.__name__} res from {self.res} to {new_res}")
self._res = new_res
+ @property
+ def files(self) -> set[str]:
+ """A list of all files in the dataset.
+
+ Returns:
+ All files in the dataset.
+
+ .. versionadded:: 0.5
+ """
+ # Make iterable
+ if isinstance(self.paths, str):
+ paths: Iterable[str] = [self.paths]
+ else:
+ paths = self.paths
+
+ # Using set to remove any duplicates if directories are overlapping
+ files: set[str] = set()
+ for path in paths:
+ if os.path.isdir(path):
+ pathname = os.path.join(path, "**", self.filename_glob)
+ files |= set(glob.iglob(pathname, recursive=True))
+ elif os.path.isfile(path) or path_is_vsi(path):
+ files.add(path)
+ else:
+ warnings.warn(
+ f"Could not find any relevant files for provided path '{path}'. "
+ f"Path was ignored.",
+ UserWarning,
+ )
+
+ return files
+
class RasterDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as raster files."""
- #: Glob expression used to search for files.
- #:
- #: This expression should be specific enough that it will not pick up files from
- #: other datasets. It should not include a file extension, as the dataset may be in
- #: a different file format than what it was originally downloaded as.
- filename_glob = "*"
-
#: Regular expression used to extract date from filename.
#:
#: The expression should use named groups. The expression may contain any number of
@@ -329,7 +370,7 @@ def dtype(self) -> torch.dtype:
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
@@ -339,7 +380,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -350,19 +391,21 @@ def __init__(
cache: if True, cache file handle to speed up repeated sampling
Raises:
- FileNotFoundError: if no files are found in ``root``
+ DatasetNotFoundError: If dataset is not found.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
super().__init__(transforms)
- self.root = root
+ self.paths = paths
self.bands = bands or self.all_bands
self.cache = cache
# Populate the dataset index
i = 0
- pathname = os.path.join(root, "**", self.filename_glob)
filename_regex = re.compile(self.filename_regex, re.VERBOSE)
- for filepath in glob.iglob(pathname, recursive=True):
+ for filepath in self.files:
match = re.match(filename_regex, os.path.basename(filepath))
if match is not None:
try:
@@ -396,10 +439,7 @@ def __init__(
i += 1
if i == 0:
- msg = f"No {self.__class__.__name__} data was found in `root='{self.root}'`"
- if self.bands:
- msg += f" with `bands={self.bands}`"
- raise FileNotFoundError(msg)
+ raise DatasetNotFoundError(self)
if not self.separate_files:
self.band_indexes = None
@@ -540,16 +580,9 @@ def _load_warp_file(self, filepath: str) -> DatasetReader:
class VectorDataset(GeoDataset):
"""Abstract base class for :class:`GeoDataset` stored as vector files."""
- #: Glob expression used to search for files.
- #:
- #: This expression should be specific enough that it will not pick up files from
- #: other datasets. It should not include a file extension, as the dataset may be in
- #: a different file format than what it was originally downloaded as.
- filename_glob = "*"
-
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.0001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -558,7 +591,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -568,20 +601,22 @@ def __init__(
rasterized into the mask
Raises:
- FileNotFoundError: if no files are found in ``root``
+ DatasetNotFoundError: If dataset is not found.
.. versionadded:: 0.4
The *label_name* parameter.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
super().__init__(transforms)
- self.root = root
+ self.paths = paths
self.label_name = label_name
# Populate the dataset index
i = 0
- pathname = os.path.join(root, "**", self.filename_glob)
- for filepath in glob.iglob(pathname, recursive=True):
+ for filepath in self.files:
try:
with fiona.open(filepath) as src:
if crs is None:
@@ -602,8 +637,7 @@ def __init__(
i += 1
if i == 0:
- msg = f"No {self.__class__.__name__} data was found in `root='{root}'`"
- raise FileNotFoundError(msg)
+ raise DatasetNotFoundError(self)
self._crs = crs
self._res = res
diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py
index a027bde4770..6fc2b520181 100644
--- a/torchgeo/datasets/gid15.py
+++ b/torchgeo/datasets/gid15.py
@@ -10,11 +10,12 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_and_extract_archive
+from .utils import DatasetNotFoundError, download_and_extract_archive
class GID15(NonGeoDataset):
@@ -104,8 +105,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
@@ -118,10 +118,7 @@ def __init__(
self._download()
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
self.files = self._load_files(self.root, self.split)
@@ -225,11 +222,7 @@ def _check_integrity(self) -> bool:
return True
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- AssertionError: if the checksum of split.py does not match
- """
+ """Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
return
@@ -241,9 +234,7 @@ def _download(self) -> None:
md5=self.md5 if self.checksum else None,
)
- def plot(
- self, sample: dict[str, Tensor], suptitle: Optional[str] = None
- ) -> plt.Figure:
+ def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py
index 8f7959116e3..c9da83b9bec 100644
--- a/torchgeo/datasets/globbiomass.py
+++ b/torchgeo/datasets/globbiomass.py
@@ -5,14 +5,16 @@
import glob
import os
-from typing import Any, Callable, Optional, cast
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union, cast
import matplotlib.pyplot as plt
import torch
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
-from .utils import BoundingBox, check_integrity, extract_archive
+from .utils import BoundingBox, DatasetNotFoundError, check_integrity, extract_archive
class GlobBiomass(RasterDataset):
@@ -117,7 +119,7 @@ class GlobBiomass(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
measurement: str = "agb",
@@ -128,7 +130,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -140,11 +142,13 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if dataset is missing or checksum fails
AssertionError: if measurement argument is invalid, or not a str
+ DatasetNotFoundError: If dataset is not found.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.checksum = checksum
assert isinstance(measurement, str), "Measurement argument must be a str."
@@ -160,7 +164,7 @@ def __init__(
self._verify()
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
@@ -199,18 +203,14 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
return sample
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted file already exists
- pathname = os.path.join(self.root, self.filename_glob)
- if glob.glob(pathname):
+ if self.files:
return
# Check if the zip files have already been downloaded
- pathname = os.path.join(self.root, self.zipfile_glob)
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, self.zipfile_glob)
if glob.glob(pathname):
for zipfile in glob.iglob(pathname):
filename = os.path.basename(zipfile)
@@ -219,18 +219,14 @@ def _verify(self) -> None:
extract_archive(zipfile)
return
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` "
- "either specify a different `root` directory or make sure you "
- "have manually downloaded the dataset as suggested in the documentation."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py
index d8af15b175e..916ad1af71f 100644
--- a/torchgeo/datasets/idtrees.py
+++ b/torchgeo/datasets/idtrees.py
@@ -10,15 +10,17 @@
import fiona
import matplotlib.pyplot as plt
import numpy as np
+import pandas as pd
import rasterio
import torch
+from matplotlib.figure import Figure
from rasterio.enums import Resampling
from torch import Tensor
from torchvision.ops import clip_boxes_to_image, remove_small_boxes
from torchvision.utils import draw_bounding_boxes
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class IDTReeS(NonGeoDataset):
@@ -163,7 +165,8 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- ImportError: if laspy or pandas are are not installed
+ ImportError: if laspy is not installed
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in ["train", "test"]
assert task in ["task1", "task2"]
@@ -178,12 +181,6 @@ def __init__(
self.num_classes = len(self.classes)
self._verify()
- try:
- import pandas as pd # noqa: F401
- except ImportError:
- raise ImportError(
- "pandas is not installed and is required to use this dataset"
- )
try:
import laspy # noqa: F401
except ImportError:
@@ -345,8 +342,6 @@ def _load(
Returns:
the image path, geometries, and labels
"""
- import pandas as pd
-
if self.split == "train":
directory = os.path.join(root, self.directories[self.split][0])
labels: pd.DataFrame = self._load_labels(directory)
@@ -373,8 +368,6 @@ def _load_labels(self, directory: str) -> Any:
Returns:
a pandas DataFrame containing the labels for each image
"""
- import pandas as pd
-
path_mapping = os.path.join(directory, "Field", "itc_rsFile.csv")
path_labels = os.path.join(directory, "Field", "train_data.csv")
df_mapping = pd.read_csv(path_mapping)
@@ -451,11 +444,7 @@ def _filter_boxes(
return boxes, labels
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
url = self.metadata[self.split]["url"]
md5 = self.metadata[self.split]["md5"]
filename = self.metadata[self.split]["filename"]
@@ -477,11 +466,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
download_url(
@@ -496,7 +481,7 @@ def plot(
show_titles: bool = True,
suptitle: Optional[str] = None,
hsi_indices: tuple[int, int, int] = (0, 1, 2),
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/inaturalist.py b/torchgeo/datasets/inaturalist.py
index ac4fa41bb40..6838a5cdf4b 100644
--- a/torchgeo/datasets/inaturalist.py
+++ b/torchgeo/datasets/inaturalist.py
@@ -8,10 +8,11 @@
import sys
from typing import Any
+import pandas as pd
from rasterio.crs import CRS
from .geo import GeoDataset
-from .utils import BoundingBox, disambiguate_timestamp
+from .utils import BoundingBox, DatasetNotFoundError, disambiguate_timestamp
class INaturalist(GeoDataset):
@@ -26,11 +27,6 @@ class INaturalist(GeoDataset):
* https://www.inaturalist.org/pages/help#cite
- .. note::
- This dataset requires the following additional library to be installed:
-
- * `pandas `_ to load CSV files
-
.. versionadded:: 0.3
"""
@@ -44,8 +40,7 @@ def __init__(self, root: str = "data") -> None:
root: root directory where dataset can be found
Raises:
- FileNotFoundError: if no files are found in ``root``
- ImportError: if pandas is not installed
+ DatasetNotFoundError: If dataset is not found.
"""
super().__init__()
@@ -53,14 +48,7 @@ def __init__(self, root: str = "data") -> None:
files = glob.glob(os.path.join(root, "**.csv"))
if not files:
- raise FileNotFoundError(f"Dataset not found in `root={self.root}`")
-
- try:
- import pandas as pd # noqa: F401
- except ImportError:
- raise ImportError(
- "pandas is not installed and is required to use this dataset"
- )
+ raise DatasetNotFoundError(self)
# Read CSV file
data = pd.read_csv(
diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py
index f05aa7e56f6..b3ab0a6fd9c 100644
--- a/torchgeo/datasets/inria.py
+++ b/torchgeo/datasets/inria.py
@@ -5,6 +5,7 @@
import glob
import os
+import re
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
@@ -15,7 +16,12 @@
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, extract_archive, percentile_normalization
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ extract_archive,
+ percentile_normalization,
+)
class InriaAerialImageLabeling(NonGeoDataset):
@@ -45,6 +51,9 @@ class InriaAerialImageLabeling(NonGeoDataset):
* https://doi.org/10.1109/IGARSS.2017.8127684
.. versionadded:: 0.3
+
+ .. versionchanged:: 0.5
+ Added support for a *val* split.
"""
directory = "AerialImageDataset"
@@ -62,17 +71,17 @@ def __init__(
Args:
root: root directory where dataset can be found
- split: train/test split
+ split: train/val/test split
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version.
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: if ``split`` is invalid
- RuntimeError: if dataset is missing
+ DatasetNotFoundError: If dataset is not found.
"""
self.root = root
- assert split in {"train", "test"}
+ assert split in {"train", "val", "test"}
self.split = split
self.transforms = transforms
self.checksum = checksum
@@ -90,15 +99,25 @@ def _load_files(self, root: str) -> list[dict[str, str]]:
list of dicts containing paths for each pair of image and label
"""
files = []
- root_dir = os.path.join(root, self.directory, self.split)
+ split = "train" if self.split in ["train", "val"] else "test"
+ root_dir = os.path.join(root, self.directory, split)
+ pattern = re.compile(r"([A-Za-z]+)(\d+)")
+
images = glob.glob(os.path.join(root_dir, "images", "*.tif"))
images = sorted(images)
- if self.split == "train":
+
+ if split == "train":
labels = glob.glob(os.path.join(root_dir, "gt", "*.tif"))
labels = sorted(labels)
for img, lbl in zip(images, labels):
- files.append({"image": img, "label": lbl})
+ if match := pattern.search(img):
+ idx = int(match.group(2))
+ # For validation, use the first 5 images of every location
+ if self.split == "train" and idx > 5:
+ files.append({"image": img, "label": lbl})
+ elif self.split == "val" and idx < 6:
+ files.append({"image": img, "label": lbl})
else:
for img in images:
files.append({"image": img})
@@ -171,11 +190,7 @@ def _verify(self) -> None:
archive_path = os.path.join(self.root, self.filename)
md5_hash = self.md5 if self.checksum else None
if not os.path.isfile(archive_path):
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` "
- "either specify a different `root` directory "
- "or download the dataset to this directory"
- )
+ raise DatasetNotFoundError(self)
if not check_integrity(archive_path, md5_hash):
raise RuntimeError("Dataset corrupted")
print("Extracting...")
diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py
index 3b34088d550..04805e85c3a 100644
--- a/torchgeo/datasets/l7irish.py
+++ b/torchgeo/datasets/l7irish.py
@@ -5,15 +5,16 @@
import glob
import os
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, cast
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, Union, cast
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
from .geo import RasterDataset
-from .utils import BoundingBox, download_url, extract_archive
+from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class L7Irish(RasterDataset):
@@ -90,7 +91,7 @@ class L7Irish(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = CRS.from_epsg(3857),
res: Optional[float] = None,
bands: Sequence[str] = all_bands,
@@ -102,7 +103,7 @@ def __init__(
"""Initialize a new L7Irish instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to EPSG:3857)
res: resolution of the dataset in units of CRS
@@ -115,43 +116,34 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
- self.root = root
+ self.paths = paths
self.download = download
self.checksum = checksum
self._verify()
super().__init__(
- root, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
+ paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
- pathname = os.path.join(self.root, "**", self.filename_glob)
- for fname in glob.iglob(pathname, recursive=True):
+ if self.files:
return
# Check if the tar.gz files have already been downloaded
- pathname = os.path.join(self.root, "*.tar.gz")
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, "*.tar.gz")
if glob.glob(pathname):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -161,12 +153,13 @@ def _download(self) -> None:
"""Download the dataset."""
for biome, md5 in self.md5s.items():
download_url(
- self.url.format(biome), self.root, md5=md5 if self.checksum else None
+ self.url.format(biome), self.paths, md5=md5 if self.checksum else None
)
def _extract(self) -> None:
"""Extract the dataset."""
- pathname = os.path.join(self.root, "*.tar.gz")
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, "*.tar.gz")
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)
@@ -223,7 +216,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py
index c7e4cf52315..e42eaf1b2a7 100644
--- a/torchgeo/datasets/l8biome.py
+++ b/torchgeo/datasets/l8biome.py
@@ -5,15 +5,16 @@
import glob
import os
-from collections.abc import Sequence
-from typing import Any, Callable, Optional, cast
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, Union, cast
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
from .geo import RasterDataset
-from .utils import BoundingBox, download_url, extract_archive
+from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class L8Biome(RasterDataset):
@@ -89,7 +90,7 @@ class L8Biome(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]],
crs: Optional[CRS] = CRS.from_epsg(3857),
res: Optional[float] = None,
bands: Sequence[str] = all_bands,
@@ -101,7 +102,7 @@ def __init__(
"""Initialize a new L8Biome instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to EPSG:3857)
res: resolution of the dataset in units of CRS
@@ -114,43 +115,34 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
- self.root = root
+ self.paths = paths
self.download = download
self.checksum = checksum
self._verify()
super().__init__(
- root, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
+ paths, crs=crs, res=res, bands=bands, transforms=transforms, cache=cache
)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
- pathname = os.path.join(self.root, "**", self.filename_glob)
- for fname in glob.iglob(pathname, recursive=True):
+ if self.files:
return
# Check if the tar.gz files have already been downloaded
- pathname = os.path.join(self.root, "*.tar.gz")
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, "*.tar.gz")
if glob.glob(pathname):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -160,12 +152,13 @@ def _download(self) -> None:
"""Download the dataset."""
for biome, md5 in self.md5s.items():
download_url(
- self.url.format(biome), self.root, md5=md5 if self.checksum else None
+ self.url.format(biome), self.paths, md5=md5 if self.checksum else None
)
def _extract(self) -> None:
"""Extract the dataset."""
- pathname = os.path.join(self.root, "*.tar.gz")
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, "*.tar.gz")
for tarfile in glob.iglob(pathname):
extract_archive(tarfile)
@@ -219,7 +212,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py
index 9203f71e8b6..8dec50b562e 100644
--- a/torchgeo/datasets/landcoverai.py
+++ b/torchgeo/datasets/landcoverai.py
@@ -13,13 +13,20 @@
import numpy as np
import torch
from matplotlib.colors import ListedColormap
+from matplotlib.figure import Figure
from PIL import Image
from rasterio.crs import CRS
from torch import Tensor
from torch.utils.data import Dataset
from .geo import NonGeoDataset, RasterDataset
-from .utils import BoundingBox, download_url, extract_archive, working_dir
+from .utils import (
+ BoundingBox,
+ DatasetNotFoundError,
+ download_url,
+ extract_archive,
+ working_dir,
+)
class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
@@ -83,8 +90,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.download = download
@@ -98,11 +104,7 @@ def __init__(
self._verify()
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
if self._verify_data():
return
@@ -114,11 +116,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -155,7 +153,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
@@ -221,20 +219,19 @@ def __init__(
"""Initialize a new LandCover.ai NonGeo dataset instance.
Args:
- root: root directory where dataset can be found
- crs: :term:`coordinate reference system (CRS)` to warp to
- (defaults to the CRS of the first file found)
- res: resolution of the dataset in units of CRS
- (defaults to the resolution of the first file found)
- transforms: a function/transform that takes input sample and its target as
- entry and returns a transformed version
- cache: if True, cache file handle to speed up repeated sampling
- download: if True, download dataset and store it in the root directory
- checksum: if True, check the MD5 of the downloaded files (may be slow)
+ root: root directory where dataset can be found
+ crs: :term:`coordinate reference system (CRS)` to warp to
+ (defaults to the CRS of the first file found)
+ res: resolution of the dataset in units of CRS
+ (defaults to the resolution of the first file found)
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ cache: if True, cache file handle to speed up repeated sampling
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
LandCoverAIBase.__init__(self, root, download, checksum)
RasterDataset.__init__(self, root, crs, res, transforms=transforms, cache=cache)
@@ -318,8 +315,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in ["train", "val", "test"]
diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py
index 2816cb9b506..c3b7a48f1d3 100644
--- a/torchgeo/datasets/landsat.py
+++ b/torchgeo/datasets/landsat.py
@@ -4,10 +4,11 @@
"""Landsat datasets."""
import abc
-from collections.abc import Sequence
-from typing import Any, Callable, Optional
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@@ -57,7 +58,7 @@ def default_bands(self) -> list[str]:
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
bands: Optional[Sequence[str]] = None,
@@ -67,7 +68,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -78,19 +79,22 @@ def __init__(
cache: if True, cache file handle to speed up repeated sampling
Raises:
- FileNotFoundError: if no files are found in ``root``
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
bands = bands or self.default_bands
self.filename_glob = self.filename_glob.format(bands[0])
- super().__init__(root, crs, res, bands, transforms, cache)
+ super().__init__(paths, crs, res, bands, transforms, cache)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py
index cf7349339ed..3b5028dcef9 100644
--- a/torchgeo/datasets/levircd.py
+++ b/torchgeo/datasets/levircd.py
@@ -10,11 +10,12 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_and_extract_archive
+from .utils import DatasetNotFoundError, download_and_extract_archive
class LEVIRCDPlus(NonGeoDataset):
@@ -71,8 +72,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
@@ -85,10 +85,7 @@ def __init__(
self._download()
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
self.files = self._load_files(self.root, self.directory, self.split)
@@ -105,9 +102,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
image1 = self._load_image(files["image1"])
image2 = self._load_image(files["image2"])
mask = self._load_target(files["mask"])
-
- image = torch.stack(tensors=[image1, image2], dim=0)
- sample = {"image": image, "mask": mask}
+ sample = {"image1": image1, "image2": image2, "mask": mask}
if self.transforms is not None:
sample = self.transforms(sample)
@@ -157,7 +152,7 @@ def _load_image(self, path: str) -> Tensor:
filename = os.path.join(path)
with Image.open(filename) as img:
array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
- tensor = torch.from_numpy(array)
+ tensor = torch.from_numpy(array).float()
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor
@@ -192,11 +187,7 @@ def _check_integrity(self) -> bool:
return True
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- AssertionError: if the checksum of split.py does not match
- """
+ """Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
return
@@ -213,7 +204,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
@@ -226,20 +217,34 @@ def plot(
.. versionadded:: 0.2
"""
- image1, image2, mask = (sample["image"][0], sample["image"][1], sample["mask"])
ncols = 3
+ def get_rgb(img: Tensor) -> "np.typing.NDArray[np.uint8]":
+ rgb_img = img.permute(1, 2, 0).float().numpy()
+ per02 = np.percentile(rgb_img, 2)
+ per98 = np.percentile(rgb_img, 98)
+ delta = per98 - per02
+ epsilon = 1e-7
+ norm_img: "np.typing.NDArray[np.uint8]" = (
+ np.clip((rgb_img - per02) / (delta + epsilon), 0, 1) * 255
+ ).astype(np.uint8)
+ return norm_img
+
+ image1 = get_rgb(sample["image1"])
+ image2 = get_rgb(sample["image2"])
+ mask = sample["mask"].numpy()
+
if "prediction" in sample:
prediction = sample["prediction"]
ncols += 1
fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))
- axs[0].imshow(image1.permute(1, 2, 0))
+ axs[0].imshow(image1)
axs[0].axis("off")
- axs[1].imshow(image2.permute(1, 2, 0))
+ axs[1].imshow(image2)
axs[1].axis("off")
- axs[2].imshow(mask)
+ axs[2].imshow(mask, cmap="gray")
axs[2].axis("off")
if "prediction" in sample:
diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py
index 22e0a23d57f..9c7e2aaff4e 100644
--- a/torchgeo/datasets/loveda.py
+++ b/torchgeo/datasets/loveda.py
@@ -10,11 +10,12 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_and_extract_archive
+from .utils import DatasetNotFoundError, download_and_extract_archive
class LoveDA(NonGeoDataset):
@@ -108,10 +109,8 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- AssertionError: if ``split`` argument is invalid
- AssertionError: if ``scene`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ AssertionError: if ``split`` or ``scene`` arguments are invalid
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
assert set(scene).intersection(
@@ -138,10 +137,7 @@ def __init__(
self._download()
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found at root directory or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
self.files = self._load_files(self.scene_paths, self.split)
@@ -248,11 +244,7 @@ def _check_integrity(self) -> bool:
return True
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- AssertionError: if the checksum of split.py does not match
- """
+ """Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
return
@@ -264,9 +256,7 @@ def _download(self) -> None:
md5=self.md5 if self.checksum else None,
)
- def plot(
- self, sample: dict[str, Tensor], suptitle: Optional[str] = None
- ) -> plt.Figure:
+ def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py
new file mode 100644
index 00000000000..5eaa426d230
--- /dev/null
+++ b/torchgeo/datasets/mapinwild.py
@@ -0,0 +1,402 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""MapInWild dataset."""
+
+import os
+import shutil
+from collections import defaultdict
+from typing import Callable, Optional
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import rasterio
+import torch
+from matplotlib.figure import Figure
+from torch import Tensor
+
+from .geo import NonGeoDataset
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_url,
+ extract_archive,
+ percentile_normalization,
+)
+
+
+class MapInWild(NonGeoDataset):
+ """MapInWild dataset.
+
+ The `MapInWild `__ dataset is
+ curated for the task of wilderness mapping on a pixel-level. MapInWild is a
+ multi-modal dataset and comprises various geodata acquired and formed from
+ different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season
+ Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging
+ Radiometer Suite NightTime Day/Night band. The dataset consists of 8144
+ images with the shape of 1920 × 1920 pixels. The images are weakly annotated
+ from the World Database of Protected Areas (WDPA).
+
+ Dataset features:
+
+ * 1018 areas globally sampled from the WDPA
+ * 10-Band Sentinel-2
+ * Dual-pol Sentinel-1
+ * ESA WorldCover Land Cover
+ * Visible Infrared Imaging Radiometer Suite NightTime Day/Night Band
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://ieeexplore.ieee.org/document/10089830
+
+ .. versionadded:: 0.5
+ """
+
+ url = "https://huggingface.co/datasets/burakekim/mapinwild/resolve/main/"
+
+ modality_urls = {
+ "esa_wc": {"esa_wc/ESA_WC.zip"},
+ "viirs": {"viirs/VIIRS.zip"},
+ "mask": {"mask/mask.zip"},
+ "s1": {"s1/s1_part1.zip", "s1/s1_part2.zip"},
+ "s2_temporal_subset": {
+ "s2_temporal_subset/s2_temporal_subset_part1.zip",
+ "s2_temporal_subset/s2_temporal_subset_part2.zip",
+ },
+ "s2_autumn": {"s2_autumn/s2_autumn_part1.zip", "s2_autumn/s2_autumn_part2.zip"},
+ "s2_spring": {"s2_spring/s2_spring_part1.zip", "s2_spring/s2_spring_part2.zip"},
+ "s2_summer": {"s2_summer/s2_summer_part1.zip", "s2_summer/s2_summer_part2.zip"},
+ "s2_winter": {"s2_winter/s2_winter_part1.zip", "s2_winter/s2_winter_part2.zip"},
+ "split_IDs": {"split_IDs/split_IDs.csv"},
+ }
+
+ md5s = {
+ "ESA_WC.zip": "72b2ee578fe10f0df85bdb7f19311c92",
+ "VIIRS.zip": "4eff014bae127fe536f8a5f17d89ecb4",
+ "mask.zip": "87c83a23a73998ad60d448d240b66225",
+ "s1_part1.zip": "d8a911f5c76b50eb0760b8f0047e4674",
+ "s1_part2.zip": "a30369d17c62d2af5aa52a4189590e3c",
+ "s2_temporal_subset_part1.zip": "78c2d05514458a036fe133f1e2f11d2a",
+ "s2_temporal_subset_part2.zip": "076cd3bd00eb5b7f5d80c9e0a0de0275",
+ "s2_autumn_part1.zip": "6ee7d1ac44b5107e3663636269aecf68",
+ "s2_autumn_part2.zip": "4fc5e1d5c772421dba553722433ac3b9",
+ "s2_spring_part1.zip": "2a89687d8fafa7fc7f5e641bfa97d472",
+ "s2_spring_part2.zip": "5845dcae0ab3cdc174b7c41edd4283a9",
+ "s2_summer_part1.zip": "73ca8291d3f4fb7533636220a816bb71",
+ "s2_summer_part2.zip": "5b5816bbd32987619bf72cde5cacd032",
+ "s2_winter_part1.zip": "ca958f7cd98e37cb59d6f3877573ee6d",
+ "s2_winter_part2.zip": "e7aacb0806d6d619b6abc408e6b09fdc",
+ "split_IDs.csv": "cb5c6c073702acee23544e1e6fe5856f",
+ }
+
+ mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)}
+
+ wc_cmap = {
+ 10: (0, 160, 0),
+ 20: (150, 100, 0),
+ 30: (255, 180, 0),
+ 40: (255, 255, 100),
+ 50: (195, 20, 0),
+ 60: (255, 245, 215),
+ 70: (255, 255, 255),
+ 80: (0, 70, 200),
+ 90: (0, 220, 130),
+ 95: (0, 150, 120),
+ 100: (255, 235, 175),
+ }
+
+ def __init__(
+ self,
+ root: str = "data",
+ modality: list[str] = ["mask", "esa_wc", "viirs", "s2_summer"],
+ split: str = "train",
+ transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new MapInWild dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ modality: the modality to download. Choose from: "mask", "esa_wc",
+ "viirs", "s1", "s2_temporal_subset", "s2_[season]".
+ split: one of "train", "validation", or "test"
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ AssertionError: if ``split`` argument is invalid
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+ """
+ assert split in ["train", "validation", "test"]
+
+ self.checksum = checksum
+ self.root = root
+ self.transforms = transforms
+ self.modality = modality
+ self.download = download
+
+ modality.append("split_IDs")
+ for mode in modality:
+ for modality_link in self.modality_urls[mode]:
+ modality_url = os.path.join(self.url, modality_link)
+ self._verify(
+ url=modality_url, md5=self.md5s[os.path.split(modality_link)[-1]]
+ )
+
+ # Merge modalities downloaded in two parts
+ if (
+ download
+ and mode not in os.listdir(self.root)
+ and len(self.modality_urls[mode]) == 2
+ ):
+ self._merge_parts(mode)
+
+ # Masks will be loaded seperately in the :meth:`__getitem__`
+ if "mask" in self.modality:
+ self.modality.remove("mask")
+
+ # Split IDs has been downloaded and is not needed in the list
+ if "split_IDs" in self.modality:
+ self.modality.remove("split_IDs")
+
+ if os.path.exists(os.path.join(self.root, "split_IDs.csv")):
+ split_dataframe = pd.read_csv(os.path.join(self.root, "split_IDs.csv"))
+ self.ids = split_dataframe[split].dropna().values.tolist()
+ self.ids = list(map(int, self.ids))
+
+ def __getitem__(self, index: int) -> dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ data and label at that index
+ """
+ list_modalities = []
+ id = self.ids[index]
+
+ mask = self._load_raster(id, "mask")
+ mask[mask != 0] = 1
+
+ for mode in self.modality:
+ mode = mode.upper() if mode in ["esa_wc", "viirs"] else mode
+ data = self._load_raster(id, mode)
+ list_modalities.append(data)
+
+ image = torch.cat(list_modalities, dim=0)
+
+ sample: dict[str, Tensor] = {"image": image, "mask": mask}
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def __len__(self) -> int:
+ """Return the number of data points in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return len(self.ids)
+
+ def _load_raster(self, filename: int, source: str) -> Tensor:
+ """Load a single raster image or target.
+
+ Args:
+ filename: name of the file to load
+ source: the directory of the modality
+
+ Returns:
+ the raster image or target
+ """
+ with rasterio.open(os.path.join(self.root, source, f"{filename}.tif")) as f:
+ raw_array = f.read()
+ array: "np.typing.NDArray[np.int_]" = np.stack(raw_array, axis=0)
+ if array.dtype == np.uint16:
+ array = array.astype(np.int32)
+ tensor = torch.from_numpy(array).float()
+ return tensor
+
+ def _verify(self, url: str, md5: Optional[str] = None) -> None:
+ """Verify the integrity of the dataset.
+
+ Args:
+ url: url to the file
+ md5: md5 of the file to be verified
+ """
+ modality_folder_name = url.split("/")[-1]
+ mod_fold_no_ext = modality_folder_name.split(".")[0]
+ modality_path = os.path.join(self.root, mod_fold_no_ext)
+ split_path = os.path.join(self.root, modality_folder_name)
+ if mod_fold_no_ext == "split_IDs":
+ modality_path = split_path
+
+ # Check if the files already exist
+ if os.path.exists(modality_path):
+ return
+
+ # Check if the zip files have already been downloaded, if so, extract
+ filepath = os.path.join(self.root, url.split("/")[-1])
+ if os.path.isfile(filepath) and filepath.endswith(".zip"):
+ if self.checksum and not check_integrity(filepath, md5):
+ raise RuntimeError("Dataset found, but corrupted.")
+ self._extract(url)
+ return
+
+ # Check if the user requested to download the dataset
+ if not self.download:
+ raise DatasetNotFoundError(self)
+
+ # Download the dataset
+ self._download(url, md5)
+ if not url.endswith(".csv"):
+ self._extract(url)
+
+ def _download(self, url: str, md5: Optional[str]) -> None:
+ """Downloads a modality.
+
+ Args:
+ url: download url of a modality
+ md5: md5 of a modality
+ """
+ download_url(
+ url,
+ self.root,
+ filename=os.path.split(url)[1],
+ md5=md5 if self.checksum else None,
+ )
+
+ def _extract(self, path: str) -> None:
+ """Extracts a modality.
+
+ Args:
+ path: path to the modality folder
+ """
+ filepath = os.path.join(self.root, os.path.split(path)[1])
+ extract_archive(filepath)
+
+ def _merge_parts(self, modality: str) -> None:
+ """Merge the modalities that are downloaded and extracted in two parts.
+
+ Args:
+ root: root directory where dataset can be found
+ modality: the filename of the modality
+ """
+ # Create a new folder named after the 'modality' variable
+ modality_folder = os.path.join(self.root, modality)
+ # Will not raise an error if the folder already exists
+ os.makedirs(modality_folder, exist_ok=True)
+
+ # List of source folders
+ source_folders = [
+ os.path.join(self.root, modality + "_part1"),
+ os.path.join(self.root, modality + "_part2"),
+ ]
+
+ # Move files from each source folder to the new 'modality' folder
+ for source_folder in source_folders:
+ for file_name in os.listdir(source_folder):
+ source = os.path.join(source_folder, file_name)
+ destination = os.path.join(modality_folder, file_name)
+ if os.path.isfile(source):
+ shutil.copy(source, destination) # Move files to 'modality' folder
+
+ def _convert_to_color(
+ self, arr_2d: Tensor, cmap: dict[int, tuple[int, int, int]]
+ ) -> "np.typing.NDArray[np.uint8]":
+ """Numeric labels to RGB-color encoding.
+
+ Args:
+ arr_2d: 2D array to be colorized
+ cmap: colormap to use when mapping the labels
+
+ Returns:
+ 3D colored image
+ """
+ arr_3d = np.zeros((arr_2d.shape[0], arr_2d.shape[1], 3), dtype=np.uint8)
+
+ for c, i in cmap.items():
+ m = arr_2d == c
+ arr_3d[m] = i
+ return arr_3d
+
+ def plot(
+ self,
+ sample: dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ ) -> Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample image-mask pair returned by :meth:`__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ suptitle: optional string to use as a suptitle
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+ """
+ modality_channels = defaultdict(lambda: 10, {"viirs": 1, "esa_wc": 1, "s1": 2})
+
+ start_idx = 0
+ split_images = {}
+
+ for modality in self.modality:
+ end_idx = start_idx + modality_channels[modality] # Start + n of channels
+ split_images[modality] = sample["image"][start_idx:end_idx, :, :] # Slicing
+ start_idx = end_idx # Update the iterator
+
+ # Prepare the mask
+ mask = sample["mask"].squeeze()
+ color_mask = self._convert_to_color(mask, cmap=self.mask_cmap)
+
+ num_subplots = len(split_images) + 1 # +1 for color_mask
+ showing_predictions = "prediction" in sample
+ if showing_predictions:
+ num_subplots += 1
+
+ fig, axs = plt.subplots(1, num_subplots, figsize=(num_subplots * 4, 5))
+
+ # Plot each modality in its respective axis
+ for i, (modality, image) in enumerate(split_images.items()):
+ ax = axs[i]
+ img = np.transpose(image, (1, 2, 0)).squeeze()
+ # Apply transformations based on modality type
+ if modality.startswith("s2"):
+ img = img[:, :, [4, 3, 2]]
+ if modality == "esa_wc":
+ img = self._convert_to_color(torch.as_tensor(img), cmap=self.wc_cmap)
+ if modality == "s1":
+ img = img[:, :, 0]
+
+ if not "esa_wc":
+ img = percentile_normalization(img)
+
+ ax.imshow(img)
+ if show_titles:
+ ax.set_title(modality)
+ ax.axis("off")
+
+ # Plot color_mask in its own axis
+ axs[len(split_images)].imshow(color_mask)
+ if show_titles:
+ axs[len(split_images)].set_title("Annotation")
+ axs[len(split_images)].axis("off")
+
+ # If available, plot predictions in a new axis
+ if showing_predictions:
+ prediction = sample["prediction"].squeeze()
+ color_predictions = self._convert_to_color(prediction, cmap=self.mask_cmap)
+ axs[-1].imshow(color_predictions, vmin=0, vmax=1, interpolation="none")
+ if show_titles:
+ axs[-1].set_title("Prediction")
+ axs[-1].axis("off")
+
+ plt.tight_layout()
+ return fig
diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py
index d3ecb7224ef..ed9a9c156ea 100644
--- a/torchgeo/datasets/millionaid.py
+++ b/torchgeo/datasets/millionaid.py
@@ -9,12 +9,13 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from torchgeo.datasets import NonGeoDataset
-from .utils import check_integrity, extract_archive
+from .utils import DatasetNotFoundError, check_integrity, extract_archive
class MillionAID(NonGeoDataset):
@@ -204,7 +205,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if dataset is not found
+ DatasetNotFoundError: If dataset is not found.
"""
self.root = root
self.transforms = transforms
@@ -325,18 +326,14 @@ def _verify(self) -> None:
extract_archive(filepath)
return
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` directory, either "
- "specify a different `root` directory or manually download "
- "the dataset to this directory."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py
index 027ad8aa409..f70ff47a259 100644
--- a/torchgeo/datasets/naip.py
+++ b/torchgeo/datasets/naip.py
@@ -6,6 +6,7 @@
from typing import Any, Optional
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from .geo import RasterDataset
@@ -52,7 +53,7 @@ def plot(
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py
index 3f124f42dba..a1637f46e7d 100644
--- a/torchgeo/datasets/nasa_marine_debris.py
+++ b/torchgeo/datasets/nasa_marine_debris.py
@@ -10,11 +10,17 @@
import numpy as np
import rasterio
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from torchvision.utils import draw_bounding_boxes
from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_radiant_mlhub_collection,
+ extract_archive,
+)
class NASAMarineDebris(NonGeoDataset):
@@ -76,6 +82,9 @@ def __init__(
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
checksum: if True, check the MD5 of the downloaded files (may be slow)
verbose: if True, print messages when new tiles are loaded
+
+ Raises:
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.transforms = transforms
@@ -174,11 +183,7 @@ def _load_files(self) -> list[dict[str, str]]:
return files
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
exists = [
os.path.exists(os.path.join(self.root, directory))
@@ -204,11 +209,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
for collection_id in self.collection_ids:
@@ -224,7 +225,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py
index d0fda73fc53..6c1243538f3 100644
--- a/torchgeo/datasets/nlcd.py
+++ b/torchgeo/datasets/nlcd.py
@@ -5,14 +5,16 @@
import glob
import os
-from typing import Any, Callable, Optional
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
import torch
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
-from .utils import BoundingBox, download_url, extract_archive
+from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class NLCD(RasterDataset):
@@ -105,7 +107,7 @@ class NLCD(RasterDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: Optional[float] = None,
years: list[int] = [2019],
@@ -118,7 +120,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -134,8 +136,7 @@ def __init__(
Raises:
AssertionError: if ``years`` or ``classes`` are invalid
- FileNotFoundError: if no files are found in ``root``
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert set(years) <= self.md5s.keys(), (
"NLCD data product only exists for the following years: "
@@ -146,7 +147,7 @@ def __init__(
), f"Only the following classes are valid: {list(self.cmap.keys())}."
assert 0 in classes, "Classes must include the background class: 0"
- self.root = root
+ self.paths = paths
self.years = years
self.classes = classes
self.download = download
@@ -156,7 +157,7 @@ def __init__(
self._verify()
- super().__init__(root, crs, res, transforms=transforms, cache=cache)
+ super().__init__(paths, crs, res, transforms=transforms, cache=cache)
# Map chosen classes to ordinal numbers, all others mapped to background class
for v, k in enumerate(self.classes):
@@ -180,29 +181,17 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
return sample
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
- exists = []
- for year in self.years:
- filename_year = self.filename_glob.replace("*", str(year), 1)
- pathname = os.path.join(self.root, "**", filename_year)
- if glob.glob(pathname, recursive=True):
- exists.append(True)
- else:
- exists.append(False)
-
- if all(exists):
+ if self.files:
return
# Check if the zip files have already been downloaded
exists = []
for year in self.years:
zipfile_year = self.zipfile_glob.replace("*", str(year), 1)
- pathname = os.path.join(self.root, "**", zipfile_year)
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, "**", zipfile_year)
if glob.glob(pathname, recursive=True):
exists.append(True)
self._extract()
@@ -214,11 +203,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -229,7 +214,7 @@ def _download(self) -> None:
for year in self.years:
download_url(
self.url.format(year),
- self.root,
+ self.paths,
md5=self.md5s[year] if self.checksum else None,
)
@@ -237,15 +222,16 @@ def _extract(self) -> None:
"""Extract the dataset."""
for year in self.years:
zipfile_name = self.zipfile_glob.replace("*", str(year), 1)
- pathname = os.path.join(self.root, "**", zipfile_name)
- extract_archive(glob.glob(pathname, recursive=True)[0], self.root)
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, "**", zipfile_name)
+ extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py
index 57e735b59a3..e05ea93220c 100644
--- a/torchgeo/datasets/openbuildings.py
+++ b/torchgeo/datasets/openbuildings.py
@@ -7,20 +7,23 @@
import json
import os
import sys
-from typing import Any, Callable, Optional, cast
+from collections.abc import Iterable
+from typing import Any, Callable, Optional, Union, cast
import fiona
import fiona.transform
import matplotlib.pyplot as plt
+import pandas as pd
import rasterio
import shapely
import shapely.wkt as wkt
import torch
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from rtree.index import Index, Property
from .geo import VectorDataset
-from .utils import BoundingBox, check_integrity
+from .utils import BoundingBox, DatasetNotFoundError, check_integrity
class OpenBuildings(VectorDataset):
@@ -203,7 +206,7 @@ class OpenBuildings(VectorDataset):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 0.0001,
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -212,7 +215,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -221,28 +224,24 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- FileNotFoundError: if no files are found in ``root``
+ DatasetNotFoundError: If dataset is not found.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
- self.root = root
+ self.paths = paths
self.res = res
self.checksum = checksum
- self.root = root
self.res = res
self.transforms = transforms
self._verify()
- try:
- import pandas as pd # noqa: F401
- except ImportError:
- raise ImportError(
- "pandas is not installed and is required to use this dataset"
- )
-
# Create an R-tree to index the dataset using the polygon centroid as bounds
self.index = Index(interleaved=False, properties=Property(dimension=3))
- with open(os.path.join(root, "tiles.geojson")) as f:
+ assert isinstance(self.paths, str)
+ with open(os.path.join(self.paths, "tiles.geojson")) as f:
data = json.load(f)
features = data["features"]
@@ -250,7 +249,7 @@ def __init__(
feature["properties"]["tile_url"].split("/")[-1] for feature in features
] # get csv filename
- polygon_files = glob.glob(os.path.join(self.root, self.zipfile_glob))
+ polygon_files = glob.glob(os.path.join(self.paths, self.zipfile_glob))
polygon_filenames = [f.split(os.sep)[-1] for f in polygon_files]
matched_features = [
@@ -279,15 +278,13 @@ def __init__(
coords = (minx, maxx, miny, maxy, mint, maxt)
filepath = os.path.join(
- self.root, feature["properties"]["tile_url"].split("/")[-1]
+ self.paths, feature["properties"]["tile_url"].split("/")[-1]
)
self.index.insert(i, coords, filepath)
i += 1
if i == 0:
- raise FileNotFoundError(
- f"No {self.__class__.__name__} data was found in '{self.root}'"
- )
+ raise DatasetNotFoundError(self)
self._crs = crs
self._source_crs = source_crs
@@ -349,8 +346,6 @@ def _filter_geometries(
List with all polygons from all hit filepaths
"""
- import pandas as pd
-
# We need to know the bounding box of the query in the source CRS
(minx, maxx), (miny, maxy) = fiona.transform.transform(
self._crs.to_dict(),
@@ -398,14 +393,10 @@ def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]:
return transformed
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if dataset is missing or checksum fails
- FileNotFoundError: if metadata file is not found in root
- """
+ """Verify the integrity of the dataset."""
# Check if the zip files have already been downloaded and checksum
- pathname = os.path.join(self.root, self.zipfile_glob)
+ assert isinstance(self.paths, str)
+ pathname = os.path.join(self.paths, self.zipfile_glob)
i = 0
for zipfile in glob.iglob(pathname):
filename = os.path.basename(zipfile)
@@ -416,25 +407,14 @@ def _verify(self) -> None:
if i != 0:
return
- # check if the metadata file has been downloaded
- if not os.path.exists(os.path.join(self.root, self.meta_data_filename)):
- raise FileNotFoundError(
- f"Meta data file {self.meta_data_filename} "
- f"not found in in `root={self.root}`."
- )
-
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` "
- "either specify a different `root` directory or make sure you "
- "have manually downloaded the dataset as suggested in the documentation."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py
index 6f5b7003d20..e60c4de6214 100644
--- a/torchgeo/datasets/oscd.py
+++ b/torchgeo/datasets/oscd.py
@@ -17,6 +17,7 @@
from .geo import NonGeoDataset
from .utils import (
+ DatasetNotFoundError,
download_url,
draw_semantic_segmentation_masks,
extract_archive,
@@ -78,11 +79,29 @@ class OSCD(NonGeoDataset):
colormap = ["blue"]
+ all_bands = (
+ "B01",
+ "B02",
+ "B03",
+ "B04",
+ "B05",
+ "B06",
+ "B07",
+ "B08",
+ "B8A",
+ "B09",
+ "B10",
+ "B11",
+ "B12",
+ )
+
+ rgb_bands = ("B04", "B03", "B02")
+
def __init__(
self,
root: str = "data",
split: str = "train",
- bands: str = "all",
+ bands: Sequence[str] = all_bands,
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
download: bool = False,
checksum: bool = False,
@@ -99,15 +118,15 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
- assert bands in ["rgb", "all"]
+ assert set(bands) <= set(self.all_bands)
+ self.bands = bands
+ self.all_band_indices = [self.all_bands.index(b) for b in self.bands]
self.root = root
self.split = split
- self.bands = bands
self.transforms = transforms
self.download = download
self.checksum = checksum
@@ -129,9 +148,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
image1 = self._load_image(files["images1"])
image2 = self._load_image(files["images2"])
mask = self._load_target(str(files["mask"]))
-
- image = torch.cat([image1, image2])
- sample = {"image": image, "mask": mask}
+ sample = {"image1": image1, "image2": image2, "mask": mask}
if self.transforms is not None:
sample = self.transforms(sample)
@@ -170,8 +187,8 @@ def get_image_paths(ind: int) -> list[str]:
)
images1, images2 = get_image_paths(1), get_image_paths(2)
- if self.bands == "rgb":
- images1, images2 = images1[1:4][::-1], images2[1:4][::-1]
+ images1 = [images1[i] for i in self.all_band_indices]
+ images2 = [images2[i] for i in self.all_band_indices]
with open(os.path.join(images_root, region, "dates.txt")) as f:
dates = tuple(
@@ -204,7 +221,7 @@ def _load_image(self, paths: Sequence[str]) -> Tensor:
with Image.open(path) as img:
images.append(np.array(img))
array: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0).astype(np.int_)
- tensor = torch.from_numpy(array)
+ tensor = torch.from_numpy(array).float()
return tensor
def _load_target(self, path: str) -> Tensor:
@@ -225,11 +242,7 @@ def _load_target(self, path: str) -> Tensor:
return tensor
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
pathname = os.path.join(self.root, "**", self.filename_glob)
for fname in glob.iglob(pathname, recursive=True):
@@ -244,11 +257,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -287,13 +296,21 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
+
+ Raises:
+ ValueError: If *bands* does not include all RGB bands.
"""
ncols = 2
- rgb_inds = [3, 2, 1] if self.bands == "all" else [0, 1, 2]
+ try:
+ rgb_indices = [self.bands.index(band) for band in self.rgb_bands]
+ except ValueError as e:
+ raise ValueError(
+ "RGB bands must be present to use `plot` with S2 imagery."
+ ) from e
def get_masked(img: Tensor) -> "np.typing.NDArray[np.uint8]":
- rgb_img = img[rgb_inds].float().numpy()
+ rgb_img = img[rgb_indices].float().numpy()
per02 = np.percentile(rgb_img, 2)
per98 = np.percentile(rgb_img, 98)
rgb_img = (np.clip((rgb_img - per02) / (per98 - per02), 0, 1) * 255).astype(
@@ -307,9 +324,8 @@ def get_masked(img: Tensor) -> "np.typing.NDArray[np.uint8]":
)
return array
- idx = sample["image"].shape[0] // 2
- image1 = get_masked(sample["image"][:idx])
- image2 = get_masked(sample["image"][idx:])
+ image1 = get_masked(sample["image1"])
+ image2 = get_masked(sample["image2"])
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
axs[0].imshow(image1)
axs[0].axis("off")
diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py
index f3e771c92e5..f551408b039 100644
--- a/torchgeo/datasets/pastis.py
+++ b/torchgeo/datasets/pastis.py
@@ -12,10 +12,11 @@
import numpy as np
import torch
from matplotlib.colors import ListedColormap
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_url, extract_archive
+from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive
class PASTIS(NonGeoDataset):
@@ -148,6 +149,9 @@ def __init__(
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert set(folds) <= set(range(6))
assert bands in ["s1a", "s1d", "s2"]
@@ -307,11 +311,7 @@ def _load_files(self) -> list[dict[str, str]]:
return files
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the directory already exists
path = os.path.join(self.root, self.directory)
if os.path.exists(path):
@@ -327,11 +327,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
self._download()
@@ -351,7 +347,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py
index a13e2431b26..876f14dd59b 100644
--- a/torchgeo/datasets/patternnet.py
+++ b/torchgeo/datasets/patternnet.py
@@ -7,10 +7,11 @@
from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class PatternNet(NonGeoClassificationDataset):
@@ -95,6 +96,9 @@ def __init__(
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.download = download
@@ -103,11 +107,7 @@ def __init__(
super().__init__(root=os.path.join(root, self.directory), transforms=transforms)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
filepath = os.path.join(self.root, self.directory)
if os.path.exists(filepath):
@@ -121,11 +121,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
self._download()
@@ -150,7 +146,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py
index 5443c807d22..782fd7ce87d 100644
--- a/torchgeo/datasets/potsdam.py
+++ b/torchgeo/datasets/potsdam.py
@@ -16,6 +16,7 @@
from .geo import NonGeoDataset
from .utils import (
+ DatasetNotFoundError,
check_integrity,
draw_semantic_segmentation_masks,
extract_archive,
@@ -133,6 +134,10 @@ def __init__(
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ AssertionError: If *split* is invalid.
+ DatasetNotFoundError: If dataset is not found.
"""
assert split in self.splits
self.root = root
@@ -209,11 +214,7 @@ def _load_target(self, index: int) -> Tensor:
return tensor
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if checksum fails or the dataset is not downloaded
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
if os.path.exists(os.path.join(self.root, self.image_root)):
return
@@ -233,11 +234,7 @@ def _verify(self) -> None:
if all(exists):
return
- # Check if the user requested to download the dataset
- raise RuntimeError(
- "Dataset not found in `root` directory, either specify a different"
- + " `root` directory or manually download the dataset to this directory."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py
index 1e50165e5b8..9ee8a82d617 100644
--- a/torchgeo/datasets/reforestree.py
+++ b/torchgeo/datasets/reforestree.py
@@ -10,12 +10,19 @@
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
+import pandas as pd
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_and_extract_archive, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_and_extract_archive,
+ extract_archive,
+)
class ReforesTree(NonGeoDataset):
@@ -76,8 +83,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.transforms = transforms
@@ -86,13 +92,6 @@ def __init__(
self._verify()
- try:
- import pandas as pd # noqa: F401
- except ImportError:
- raise ImportError(
- "pandas is not installed and is required to use this dataset"
- )
-
self.files = self._load_files(self.root)
self.annot_df = pd.read_csv(os.path.join(root, "mapping", "final_dataset.csv"))
@@ -178,11 +177,7 @@ def _load_target(self, filepath: str) -> tuple[Tensor, ...]:
return boxes, labels, agb
def _verify(self) -> None:
- """Checks the integrity of the dataset structure.
-
- Raises:
- RuntimeError: if dataset is not found in root or is corrupted
- """
+ """Checks the integrity of the dataset structure."""
filepaths = [os.path.join(self.root, dir) for dir in ["tiles", "mapping"]]
if all([os.path.exists(filepath) for filepath in filepaths]):
return
@@ -196,21 +191,13 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# else download the dataset
self._download()
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- AssertionError: if the checksum does not match
- """
+ """Download the dataset and extract it."""
download_and_extract_archive(
self.url,
self.root,
@@ -223,7 +210,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py
index 2c5e5871b79..07fdfd272a8 100644
--- a/torchgeo/datasets/resisc45.py
+++ b/torchgeo/datasets/resisc45.py
@@ -8,10 +8,11 @@
import matplotlib.pyplot as plt
import numpy as np
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class RESISC45(NonGeoClassificationDataset):
@@ -88,7 +89,6 @@ class RESISC45(NonGeoClassificationDataset):
If you use this dataset in your research, please cite the following paper:
* https://doi.org/10.1109/jproc.2017.2675998
-
"""
url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
@@ -107,53 +107,6 @@ class RESISC45(NonGeoClassificationDataset):
"val": "a0770cee4c5ca20b8c32bbd61e114805",
"test": "3dda9e4988b47eb1de9f07993653eb08",
}
- classes = [
- "airplane",
- "airport",
- "baseball_diamond",
- "basketball_court",
- "beach",
- "bridge",
- "chaparral",
- "church",
- "circular_farmland",
- "cloud",
- "commercial_area",
- "dense_residential",
- "desert",
- "forest",
- "freeway",
- "golf_course",
- "ground_track_field",
- "harbor",
- "industrial_area",
- "intersection",
- "island",
- "lake",
- "meadow",
- "medium_residential",
- "mobile_home_park",
- "mountain",
- "overpass",
- "palace",
- "parking_lot",
- "railway",
- "railway_station",
- "rectangular_farmland",
- "river",
- "roundabout",
- "runway",
- "sea_ice",
- "ship",
- "snowberg",
- "sparse_residential",
- "stadium",
- "storage_tank",
- "tennis_court",
- "terrace",
- "thermal_power_station",
- "wetland",
- ]
def __init__(
self,
@@ -172,6 +125,9 @@ def __init__(
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
self.root = root
@@ -192,11 +148,7 @@ def __init__(
)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
filepath = os.path.join(self.root, self.directory)
if os.path.exists(filepath):
@@ -210,11 +162,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
self._download()
@@ -246,7 +194,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py
new file mode 100644
index 00000000000..32f6a1625e2
--- /dev/null
+++ b/torchgeo/datasets/rwanda_field_boundary.py
@@ -0,0 +1,320 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Rwanda Field Boundary Competition dataset."""
+
+import os
+from collections.abc import Sequence
+from typing import Callable, Optional
+
+import matplotlib.pyplot as plt
+import numpy as np
+import rasterio
+import rasterio.features
+import torch
+from matplotlib.figure import Figure
+from torch import Tensor
+
+from .geo import NonGeoDataset
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_radiant_mlhub_collection,
+ extract_archive,
+)
+
+
+class RwandaFieldBoundary(NonGeoDataset):
+ r"""Rwanda Field Boundary Competition dataset.
+
+ This dataset contains field boundaries for smallholder farms in eastern Rwanda.
+ The Nasa Harvest program funded a team of annotators from TaQadam to label Planet
+ imagery for the 2021 growing season for the purpose of conducting the Rwanda Field
+ boundary detection Challenge. The dataset includes rasterized labeled field
+ boundaries and time series satellite imagery from Planet's NICFI program.
+ Planet's basemap imagery is provided for six months (March, April, August, October,
+ November and December). Note: only fields that were big enough to be differentiated
+ on the Planetscope imagery were labeled, only fields that were fully contained
+ within the chips were labeled. The paired dataset is provided in 256x256 chips for a
+ total of 70 tiles covering 1532 individual fields.
+
+ The labels are provided as binary semantic segmentation labels:
+
+ 0. No field-boundary
+ 1. Field-boundary
+
+ If you use this dataset in your research, please cite the following:
+
+ * https://doi.org/10.34911/RDNT.G580WW
+
+ .. note::
+
+ This dataset requires the following additional library to be installed:
+
+ * `radiant-mlhub `_ to download the
+ imagery and labels from the Radiant Earth MLHub
+
+ .. versionadded:: 0.5
+ """
+
+ dataset_id = "nasa_rwanda_field_boundary_competition"
+ collection_ids = [
+ "nasa_rwanda_field_boundary_competition_source_train",
+ "nasa_rwanda_field_boundary_competition_labels_train",
+ "nasa_rwanda_field_boundary_competition_source_test",
+ ]
+ number_of_patches_per_split = {"train": 57, "test": 13}
+
+ filenames = {
+ "train_images": "nasa_rwanda_field_boundary_competition_source_train.tar.gz",
+ "test_images": "nasa_rwanda_field_boundary_competition_source_test.tar.gz",
+ "train_labels": "nasa_rwanda_field_boundary_competition_labels_train.tar.gz",
+ }
+ md5s = {
+ "train_images": "1f9ec08038218e67e11f82a86849b333",
+ "test_images": "17bb0e56eedde2e7a43c57aa908dc125",
+ "train_labels": "10e4eb761523c57b6d3bdf9394004f5f",
+ }
+
+ dates = ("2021_03", "2021_04", "2021_08", "2021_10", "2021_11", "2021_12")
+
+ all_bands = ("B01", "B02", "B03", "B04")
+ rgb_bands = ("B03", "B02", "B01")
+
+ classes = ["No field-boundary", "Field-boundary"]
+
+ splits = ["train", "test"]
+
+ def __init__(
+ self,
+ root: str = "data",
+ split: str = "train",
+ bands: Sequence[str] = all_bands,
+ transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
+ download: bool = False,
+ api_key: Optional[str] = None,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new RwandaFieldBoundary instance.
+
+ Args:
+ root: root directory where dataset can be found
+ split: one of "train" or "test"
+ bands: the subset of bands to load
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory
+ api_key: a RadiantEarth MLHub API key to use for downloading the dataset
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+ """
+ self._validate_bands(bands)
+ assert split in self.splits
+ if download and api_key is None:
+ raise RuntimeError("Must provide an API key to download the dataset")
+ self.root = os.path.expanduser(root)
+ self.bands = bands
+ self.transforms = transforms
+ self.split = split
+ self.download = download
+ self.api_key = api_key
+ self.checksum = checksum
+ self._verify()
+
+ self.image_filenames: list[list[list[str]]] = []
+ self.mask_filenames: list[str] = []
+ for i in range(self.number_of_patches_per_split[split]):
+ dates = []
+ for date in self.dates:
+ patch = []
+ for band in self.bands:
+ fn = os.path.join(
+ self.root,
+ f"nasa_rwanda_field_boundary_competition_source_{split}",
+ f"nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}", # noqa: E501
+ f"{band}.tif",
+ )
+ patch.append(fn)
+ dates.append(patch)
+ self.image_filenames.append(dates)
+ self.mask_filenames.append(
+ os.path.join(
+ self.root,
+ f"nasa_rwanda_field_boundary_competition_labels_{split}",
+ f"nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}",
+ "raster_labels.tif",
+ )
+ )
+
+ def __getitem__(self, index: int) -> dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ a dict containing image, mask, transform, crs, and metadata at index.
+ """
+ img_fns = self.image_filenames[index]
+ mask_fn = self.mask_filenames[index]
+
+ imgs = []
+ for date_fns in img_fns:
+ bands = []
+ for band_fn in date_fns:
+ with rasterio.open(band_fn) as f:
+ bands.append(f.read(1).astype(np.int32))
+ imgs.append(bands)
+ img = torch.from_numpy(np.array(imgs))
+
+ sample = {"image": img}
+
+ if self.split == "train":
+ with rasterio.open(mask_fn) as f:
+ mask = f.read(1)
+ mask = torch.from_numpy(mask)
+ sample["mask"] = mask
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def __len__(self) -> int:
+ """Return the number of chips in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return len(self.image_filenames)
+
+ def _validate_bands(self, bands: Sequence[str]) -> None:
+ """Validate list of bands.
+
+ Args:
+ bands: user-provided sequence of bands to load
+
+ Raises:
+ ValueError: if an invalid band name is provided
+ """
+ for band in bands:
+ if band not in self.all_bands:
+ raise ValueError(f"'{band}' is an invalid band name.")
+
+ def _verify(self) -> None:
+ """Verify the integrity of the dataset."""
+ # Check if the subdirectories already exist and have the correct number of files
+ checks = []
+ for split, num_patches in self.number_of_patches_per_split.items():
+ path = os.path.join(
+ self.root, f"nasa_rwanda_field_boundary_competition_source_{split}"
+ )
+ if os.path.exists(path):
+ num_files = len(os.listdir(path))
+ # 6 dates + 1 collection.json file
+ checks.append(num_files == (num_patches * 6) + 1)
+ else:
+ checks.append(False)
+
+ if all(checks):
+ return
+
+ # Check if tar file already exists (if so then extract)
+ have_all_files = True
+ for group in ["train_images", "train_labels", "test_images"]:
+ filepath = os.path.join(self.root, self.filenames[group])
+ if os.path.exists(filepath):
+ if self.checksum and not check_integrity(filepath, self.md5s[group]):
+ raise RuntimeError("Dataset found, but corrupted.")
+ extract_archive(filepath)
+ else:
+ have_all_files = False
+ if have_all_files:
+ return
+
+ # Check if the user requested to download the dataset
+ if not self.download:
+ raise DatasetNotFoundError(self)
+
+ # Download and extract the dataset
+ self._download()
+
+ def _download(self) -> None:
+ """Download the dataset and extract it."""
+ for collection_id in self.collection_ids:
+ download_radiant_mlhub_collection(collection_id, self.root, self.api_key)
+
+ for group in ["train_images", "train_labels", "test_images"]:
+ filepath = os.path.join(self.root, self.filenames[group])
+ if self.checksum and not check_integrity(filepath, self.md5s[group]):
+ raise RuntimeError("Dataset not found or corrupted.")
+ extract_archive(filepath, self.root)
+
+ def plot(
+ self,
+ sample: dict[str, Tensor],
+ show_titles: bool = True,
+ time_step: int = 0,
+ suptitle: Optional[str] = None,
+ ) -> Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ time_step: time step at which to access image, beginning with 0
+ suptitle: optional string to use as a suptitle
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+
+ Raises:
+ ValueError: if the RGB bands are not included in ``self.bands``
+ """
+ rgb_indices = []
+ for band in self.rgb_bands:
+ if band in self.bands:
+ rgb_indices.append(self.bands.index(band))
+ else:
+ raise ValueError("Dataset doesn't contain some of the RGB bands")
+
+ num_time_points = sample["image"].shape[0]
+ assert time_step < num_time_points
+
+ image = np.rollaxis(sample["image"][time_step, rgb_indices].numpy(), 0, 3)
+ image = np.clip(image / 2000, 0, 1)
+
+ if "mask" in sample:
+ mask = sample["mask"].numpy()
+ else:
+ mask = np.zeros_like(image)
+
+ num_panels = 2
+ showing_predictions = "prediction" in sample
+ if showing_predictions:
+ predictions = sample["prediction"].numpy()
+ num_panels += 1
+
+ fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4))
+
+ axs[0].imshow(image)
+ axs[0].axis("off")
+ if show_titles:
+ axs[0].set_title(f"t={time_step}")
+
+ axs[1].imshow(mask, vmin=0, vmax=1, interpolation="none")
+ axs[1].axis("off")
+ if show_titles:
+ axs[1].set_title("Mask")
+
+ if showing_predictions:
+ axs[2].imshow(predictions, vmin=0, vmax=1, interpolation="none")
+ axs[2].axis("off")
+ if show_titles:
+ axs[2].set_title("Predictions")
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+ return fig
diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py
new file mode 100644
index 00000000000..6d61fa011f1
--- /dev/null
+++ b/torchgeo/datasets/seasonet.py
@@ -0,0 +1,482 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""SeasoNet dataset."""
+
+import os
+import random
+from collections.abc import Callable, Collection, Iterable
+from typing import Optional
+
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import rasterio
+import torch
+from matplotlib.colors import ListedColormap
+from matplotlib.figure import Figure
+from rasterio.enums import Resampling
+from torch import Tensor
+
+from .geo import NonGeoDataset
+from .utils import (
+ DatasetNotFoundError,
+ download_url,
+ extract_archive,
+ percentile_normalization,
+)
+
+
+class SeasoNet(NonGeoDataset):
+ """SeasoNet Semantic Segmentation dataset.
+
+ The `SeasoNet `__ dataset consists of
+ 1,759,830 multi-spectral Sentinel-2 image patches, taken from 519,547 unique
+ locations, covering the whole surface area of Germany. Annotations are
+ provided in the form of pixel-level land cover and land usage segmentation
+ masks from the German land cover model LBM-DE2018 with land cover classes
+ based on the CORINE Land Cover database (CLC) 2018. The set is split into
+ two overlapping grids, consisting of roughly 880,000 samples each, which are
+ shifted by half the patch size in both dimensions. The images in each of the
+ both grids themselves do not overlap.
+
+ Dataset format:
+
+ * images are 16-bit GeoTiffs, split into seperate files based on resolution
+ * images include 12 spectral bands with 10, 20 and 60 m per pixel resolutions
+ * masks are single-channel 8-bit GeoTiffs
+
+ Dataset classes:
+
+ 0. Continuous urban fabric
+ 1. Discontinuous urban fabric
+ 2. Industrial or commercial units
+ 3. Road and rail networks and associated land
+ 4. Port areas
+ 5. Airports
+ 6. Mineral extraction sites
+ 7. Dump sites
+ 8. Construction sites
+ 9. Green urban areas
+ 10. Sport and leisure facilities
+ 11. Non-irrigated arable land
+ 12. Vineyards
+ 13. Fruit trees and berry plantations
+ 14. Pastures
+ 15. Broad-leaved forest
+ 16. Coniferous forest
+ 17. Mixed forest
+ 18. Natural grasslands
+ 19. Moors and heathland
+ 20. Transitional woodland/shrub
+ 21. Beaches, dunes, sands
+ 22. Bare rock
+ 23. Sparsely vegetated areas
+ 24. Inland marshes
+ 25. Peat bogs
+ 26. Salt marshes
+ 27. Intertidal flats
+ 28. Water courses
+ 29. Water bodies
+ 30. Coastal lagoons
+ 31. Estuaries
+ 32. Sea and ocean
+
+ If you use this dataset in your research, please cite the following paper:
+
+ * https://doi.org/10.1109/IGARSS46834.2022.9884079
+
+ .. versionadded:: 0.5
+ """
+
+ metadata = [
+ {
+ "name": "spring",
+ "ext": ".zip",
+ "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip", # noqa: E501
+ "md5": "de4cdba7b6196aff624073991b187561",
+ },
+ {
+ "name": "summer",
+ "ext": ".zip",
+ "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip", # noqa: E501
+ "md5": "6a54d4e134d27ae4eb03f180ee100550",
+ },
+ {
+ "name": "fall",
+ "ext": ".zip",
+ "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip", # noqa: E501
+ "md5": "5f94920fe41a63c6bfbab7295f7d6b95",
+ },
+ {
+ "name": "winter",
+ "ext": ".zip",
+ "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip", # noqa: E501
+ "md5": "dc5e3e09e52ab5c72421b1e3186c9a48",
+ },
+ {
+ "name": "snow",
+ "ext": ".zip",
+ "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip", # noqa: E501
+ "md5": "e1b300994143f99ebb03f51d6ab1cbe6",
+ },
+ {
+ "name": "splits",
+ "ext": ".zip",
+ "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip", # noqa: E501
+ "md5": "e4ec4a18bc4efc828f0944a7cf4d5fed",
+ },
+ {
+ "name": "meta.csv",
+ "ext": "",
+ "url": "https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv", # noqa: E501
+ "md5": "43ea07974936a6bf47d989c32e16afe7",
+ },
+ ]
+ classes = [
+ "Continuous urban fabric",
+ "Discontinuous urban fabric",
+ "Industrial or commercial units",
+ "Road and rail networks and associated land",
+ "Port areas",
+ "Airports",
+ "Mineral extraction sites",
+ "Dump sites",
+ "Construction sites",
+ "Green urban areas",
+ "Sport and leisure facilities",
+ "Non-irrigated arable land",
+ "Vineyards",
+ "Fruit trees and berry plantations",
+ "Pastures",
+ "Broad-leaved forest",
+ "Coniferous forest",
+ "Mixed forest",
+ "Natural grasslands",
+ "Moors and heathland",
+ "Transitional woodland/shrub",
+ "Beaches, dunes, sands",
+ "Bare rock",
+ "Sparsely vegetated areas",
+ "Inland marshes",
+ "Peat bogs",
+ "Salt marshes",
+ "Intertidal flats",
+ "Water courses",
+ "Water bodies",
+ "Coastal lagoons",
+ "Estuaries",
+ "Sea and ocean",
+ ]
+ all_seasons = {"Spring", "Summer", "Fall", "Winter", "Snow"}
+ all_bands = ("10m_RGB", "10m_IR", "20m", "60m")
+ band_nums = {"10m_RGB": 3, "10m_IR": 1, "20m": 6, "60m": 2}
+ splits = ["train", "val", "test"]
+ cmap = {
+ 0: (230, 000, 77, 255),
+ 1: (255, 000, 000, 255),
+ 2: (204, 77, 242, 255),
+ 3: (204, 000, 000, 255),
+ 4: (230, 204, 204, 255),
+ 5: (230, 204, 230, 255),
+ 6: (166, 000, 204, 255),
+ 7: (166, 77, 000, 255),
+ 8: (255, 77, 255, 255),
+ 9: (255, 166, 255, 255),
+ 10: (255, 230, 255, 255),
+ 11: (255, 255, 168, 255),
+ 12: (230, 128, 000, 255),
+ 13: (242, 166, 77, 255),
+ 14: (230, 230, 77, 255),
+ 15: (128, 255, 000, 255),
+ 16: (000, 166, 000, 255),
+ 17: (77, 255, 000, 255),
+ 18: (204, 242, 77, 255),
+ 19: (166, 255, 128, 255),
+ 20: (166, 242, 000, 255),
+ 21: (230, 230, 230, 255),
+ 22: (204, 204, 204, 255),
+ 23: (204, 255, 204, 255),
+ 24: (166, 166, 255, 255),
+ 25: (77, 77, 255, 255),
+ 26: (204, 204, 255, 255),
+ 27: (166, 166, 230, 255),
+ 28: (000, 204, 242, 255),
+ 29: (128, 242, 230, 255),
+ 30: (000, 255, 166, 255),
+ 31: (166, 255, 230, 255),
+ 32: (230, 242, 255, 255),
+ }
+ image_size = (120, 120)
+
+ def __init__(
+ self,
+ root: str = "data",
+ split: str = "train",
+ seasons: Collection[str] = all_seasons,
+ bands: Iterable[str] = all_bands,
+ grids: Iterable[int] = [1, 2],
+ concat_seasons: int = 1,
+ transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new SeasoNet dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ split: one of "train", "val" or "test"
+ seasons: list of seasons to load
+ bands: list of bands to load
+ grids: which of the overlapping grids to load
+ concat_seasons: number of seasonal images to return per sample.
+ if 1, each seasonal image is returned as its own sample,
+ otherwise seasonal images are randomly picked from the seasons
+ specified in ``seasons`` and returned as stacked tensors
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ DatasetNotFoundError: If dataset is not found and *download* is False.
+ """
+ assert split in self.splits
+ assert set(seasons) <= self.all_seasons
+ assert set(bands) <= set(self.all_bands)
+ assert set(grids) <= {1, 2}
+ assert concat_seasons in range(1, len(seasons) + 1)
+
+ self.root = root
+ self.bands = bands
+ self.concat_seasons = concat_seasons
+ self.transforms = transforms
+ self.download = download
+ self.checksum = checksum
+
+ self._verify()
+
+ self.channels = 0
+ for b in bands:
+ self.channels += self.band_nums[b]
+
+ csv = pd.read_csv(os.path.join(self.root, "meta.csv"), index_col="Index")
+
+ if split is not None:
+ # Filter entries by split
+ split_csv = pd.read_csv(
+ os.path.join(self.root, f"splits/{split}.csv"), header=None
+ )[0]
+ csv = csv.iloc[split_csv]
+
+ # Filter entries by grids and seasons
+ csv = csv[csv["Grid"].isin(grids)]
+ csv = csv[csv["Season"].isin(seasons)]
+
+ # Replace relative data paths with absolute paths
+ csv["Path"] = csv["Path"].apply(
+ lambda p: [os.path.join(self.root, p, os.path.basename(p))]
+ )
+
+ if self.concat_seasons > 1:
+ # Group entries by location
+ self.files = csv.groupby(["Latitude", "Longitude"])
+ self.files = self.files["Path"].agg("sum")
+
+ # Remove entries with less than concat_seasons available seasons
+ self.files = self.files[
+ self.files.apply(lambda d: len(d) >= self.concat_seasons)
+ ]
+ else:
+ self.files = csv["Path"]
+
+ def __getitem__(self, index: int) -> dict[str, Tensor]:
+ """Return an index within the dataset.
+
+ Args:
+ index: index to return
+
+ Returns:
+ sample at that index containing the image with shape SCxHxW
+ and the mask with shape HxW, where ``S = self.concat_seasons``
+ """
+ image = self._load_image(index)
+ mask = self._load_target(index)
+ sample = {"image": image, "mask": mask}
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+ def __len__(self) -> int:
+ """Return the number of data points in the dataset.
+
+ Returns:
+ length of the dataset
+ """
+ return len(self.files)
+
+ def _load_image(self, index: int) -> Tensor:
+ """Load image(s) for a single location.
+
+ Args:
+ index: index to return
+
+ Returns:
+ the stacked seasonal images
+ """
+ paths = self.files.iloc[index]
+ if self.concat_seasons > 1:
+ paths = random.sample(paths, self.concat_seasons)
+ tensor = torch.empty(self.concat_seasons * self.channels, *self.image_size)
+ for img_idx, path in enumerate(paths):
+ bnd_idx = 0
+ for band in self.bands:
+ with rasterio.open(f"{path}_{band}.tif") as f:
+ array = f.read(
+ out_shape=[f.count] + list(self.image_size),
+ out_dtype="int32",
+ resampling=Resampling.bilinear,
+ )
+ image = torch.from_numpy(array).float()
+ c = img_idx * self.channels + bnd_idx
+ tensor[c : c + image.shape[0]] = image
+ bnd_idx += image.shape[0]
+ return tensor
+
+ def _load_target(self, index: int) -> Tensor:
+ """Load the target mask for a single location.
+
+ Args:
+ index: index to return
+
+ Returns:
+ the target mask
+ """
+ path = self.files.iloc[index][0]
+ with rasterio.open(f"{path}_labels.tif") as f:
+ array = f.read() - 1
+ tensor = torch.from_numpy(array).squeeze().long()
+ return tensor
+
+ def _verify(self) -> None:
+ """Verify the integrity of the dataset."""
+ # Check if all files already exist
+ if all(
+ os.path.exists(os.path.join(self.root, file_info["name"]))
+ for file_info in self.metadata
+ ):
+ return
+
+ # Check for downloaded files
+ missing = []
+ extractable = []
+ for file_info in self.metadata:
+ file_path = os.path.join(self.root, file_info["name"] + file_info["ext"])
+ if not os.path.exists(file_path):
+ missing.append(file_info)
+ elif file_info["ext"] == ".zip":
+ extractable.append(file_path)
+
+ # Check if the user requested to download the dataset
+ if missing and not self.download:
+ raise DatasetNotFoundError(self)
+
+ # Download missing files
+ for file_info in missing:
+ download_url(
+ file_info["url"],
+ self.root,
+ filename=file_info["name"] + file_info["ext"],
+ md5=file_info["md5"] if self.checksum else None,
+ )
+ if file_info["ext"] == ".zip":
+ extractable.append(os.path.join(self.root, file_info["name"] + ".zip"))
+
+ # Extract downloaded files
+ for file_path in extractable:
+ extract_archive(file_path)
+
+ def plot(
+ self,
+ sample: dict[str, Tensor],
+ show_titles: bool = True,
+ show_legend: bool = True,
+ suptitle: Optional[str] = None,
+ ) -> Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ show_legend: flag indicating whether to show a legend for
+ the segmentation masks
+ suptitle: optional string to use as a suptitle
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+
+ Raises:
+ ValueError: If *bands* does not contain all RGB bands.
+ """
+ if "10m_RGB" not in self.bands:
+ raise ValueError("Dataset does not contain RGB bands")
+
+ ncols = self.concat_seasons + 1
+
+ images, mask = sample["image"], sample["mask"]
+ show_predictions = "prediction" in sample
+ if show_predictions:
+ prediction = sample["prediction"]
+ ncols += 1
+
+ plt_cmap = ListedColormap(np.array(list(self.cmap.values())) / 255)
+
+ start = 0
+ for b in self.bands:
+ if b == "10m_RGB":
+ break
+ start += self.band_nums[b]
+ rgb_indices = [start + s * self.channels for s in range(self.concat_seasons)]
+
+ fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 4.5, 5))
+ fig.subplots_adjust(wspace=0.05)
+ for ax, index in enumerate(rgb_indices):
+ image = images[index : index + 3].permute(1, 2, 0).numpy()
+ image = percentile_normalization(image)
+ axs[ax].imshow(image)
+ axs[ax].axis("off")
+ if show_titles:
+ axs[ax].set_title(f"Image {ax+1}")
+
+ axs[ax + 1].imshow(mask, vmin=0, vmax=32, cmap=plt_cmap, interpolation="none")
+ axs[ax + 1].axis("off")
+ if show_titles:
+ axs[ax + 1].set_title("Mask")
+
+ if show_predictions:
+ axs[ax + 2].imshow(
+ prediction, vmin=0, vmax=32, cmap=plt_cmap, interpolation="none"
+ )
+ axs[ax + 2].axis("off")
+ if show_titles:
+ axs[ax + 2].set_title("Prediction")
+
+ if show_legend:
+ lgd = np.unique(mask)
+
+ if show_predictions:
+ lgd = np.union1d(lgd, np.unique(prediction))
+ patches = [
+ mpatches.Patch(color=plt_cmap(i), label=self.classes[i]) for i in lgd
+ ]
+ plt.legend(
+ handles=patches, bbox_to_anchor=(1.05, 1), borderaxespad=0, loc=2
+ )
+
+ if suptitle is not None:
+ plt.suptitle(suptitle, size="xx-large")
+
+ return fig
diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py
index 757f6765448..e6abd0e1d2b 100644
--- a/torchgeo/datasets/seco.py
+++ b/torchgeo/datasets/seco.py
@@ -11,11 +11,17 @@
import numpy as np
import rasterio
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive, percentile_normalization
+from .utils import (
+ DatasetNotFoundError,
+ download_url,
+ extract_archive,
+ percentile_normalization,
+)
class SeasonalContrastS2(NonGeoDataset):
@@ -93,8 +99,7 @@ def __init__(
Raises:
AssertionError: if ``version`` argument is invalid
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert version in self.metadata.keys()
assert seasons in range(5)
@@ -182,11 +187,7 @@ def _load_patch(self, root: str, subdir: str) -> Tensor:
return image
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
directory_path = os.path.join(
self.root, self.metadata[self.version]["directory"]
@@ -202,11 +203,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -232,7 +229,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py
index 4f49ff67a46..1a0f812c368 100644
--- a/torchgeo/datasets/sen12ms.py
+++ b/torchgeo/datasets/sen12ms.py
@@ -11,10 +11,11 @@
import numpy as np
import rasterio
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, percentile_normalization
+from .utils import DatasetNotFoundError, check_integrity, percentile_normalization
class SEN12MS(NonGeoDataset):
@@ -188,7 +189,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if data is not found in ``root``, or checksums don't match
+ DatasetNotFoundError: If dataset is not found.
"""
assert split in ["train", "test"]
@@ -203,12 +204,10 @@ def __init__(
self.transforms = transforms
self.checksum = checksum
- if checksum:
- if not self._check_integrity():
- raise RuntimeError("Dataset not found or corrupted.")
- else:
- if not self._check_integrity_light():
- raise RuntimeError("Dataset not found or corrupted.")
+ if (
+ checksum and not self._check_integrity()
+ ) or not self._check_integrity_light():
+ raise DatasetNotFoundError(self)
with open(os.path.join(self.root, split + "_list.txt")) as f:
self.ids = [line.rstrip() for line in f.readlines()]
@@ -317,7 +316,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py
index 2b7a32d285c..2c1eebe51a7 100644
--- a/torchgeo/datasets/sentinel.py
+++ b/torchgeo/datasets/sentinel.py
@@ -3,11 +3,12 @@
"""Sentinel datasets."""
-from collections.abc import Sequence
-from typing import Any, Callable, Optional
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
import torch
+from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@@ -139,7 +140,7 @@ class Sentinel1(Sentinel):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, list[str]] = "data",
crs: Optional[CRS] = None,
res: float = 10,
bands: Sequence[str] = ["VV", "VH"],
@@ -149,7 +150,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -161,7 +162,10 @@ def __init__(
Raises:
AssertionError: if ``bands`` is invalid
- FileNotFoundError: if no files are found in ``root``
+ DatasetNotFoundError: If dataset is not found.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*.
"""
assert len(bands) > 0, "'bands' cannot be an empty list"
assert len(bands) == len(set(bands)), "'bands' contains duplicate bands"
@@ -183,14 +187,14 @@ def __init__(
self.filename_glob = self.filename_glob.format(bands[0])
- super().__init__(root, crs, res, bands, transforms, cache)
+ super().__init__(paths, crs, res, bands, transforms, cache)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
@@ -292,7 +296,7 @@ class Sentinel2(Sentinel):
def __init__(
self,
- root: str = "data",
+ paths: Union[str, Iterable[str]] = "data",
crs: Optional[CRS] = None,
res: float = 10,
bands: Optional[Sequence[str]] = None,
@@ -302,7 +306,7 @@ def __init__(
"""Initialize a new Dataset instance.
Args:
- root: root directory where dataset can be found
+ paths: one or more root directories to search or files to load
crs: :term:`coordinate reference system (CRS)` to warp to
(defaults to the CRS of the first file found)
res: resolution of the dataset in units of CRS
@@ -313,20 +317,23 @@ def __init__(
cache: if True, cache file handle to speed up repeated sampling
Raises:
- FileNotFoundError: if no files are found in ``root``
+ DatasetNotFoundError: If dataset is not found.
+
+ .. versionchanged:: 0.5
+ *root* was renamed to *paths*
"""
bands = bands or self.all_bands
self.filename_glob = self.filename_glob.format(bands[0])
self.filename_regex = self.filename_regex.format(res)
- super().__init__(root, crs, res, bands, transforms, cache)
+ super().__init__(paths, crs, res, bands, transforms, cache)
def plot(
self,
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py
index 572066a6198..156b3f1568d 100644
--- a/torchgeo/datasets/skippd.py
+++ b/torchgeo/datasets/skippd.py
@@ -9,10 +9,12 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from einops import rearrange
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class SKIPPD(NonGeoDataset):
@@ -32,9 +34,19 @@ class SKIPPD(NonGeoDataset):
* fish-eye RGB images (64x64px)
* power output measurements from 30-kW rooftop PV array
* 1-min interval across 3 years (2017-2019)
+
+ Nowcast task:
+
* 349,372 images under the split key *trainval*
* 14,003 images under the split key *test*
+ Forecast task:
+
+ * 130,412 images under the split key *trainval*
+ * 2,462 images under the split key *test*
+ * consists of a concatenated RGB time-series of 16
+ time-steps
+
If you use this dataset in your research, please cite:
* https://doi.org/10.48550/arXiv.2207.00913
@@ -42,21 +54,26 @@ class SKIPPD(NonGeoDataset):
.. versionadded:: 0.5
"""
- url = "https://stacks.stanford.edu/object/dj417rh1007"
- md5 = "b38d0f322aaeb254445e2edd8bc5d012"
-
- img_file_name = "2017_2019_images_pv_processed.hdf5"
+ url = "https://huggingface.co/datasets/torchgeo/skippd/resolve/main/{}"
+ md5 = {
+ "forecast": "f4f3509ddcc83a55c433be9db2e51077",
+ "nowcast": "0000761d403e45bb5f86c21d3c69aa80",
+ }
- data_dir = "dj417rh1007"
+ data_file_name = "2017_2019_images_pv_processed_{}.hdf5"
+ zipfile_name = "2017_2019_images_pv_processed_{}.zip"
valid_splits = ["trainval", "test"]
+ valid_tasks = ["nowcast", "forecast"]
+
dateformat = "%m/%d/%Y, %H:%M:%S"
def __init__(
self,
root: str = "data",
split: str = "trainval",
+ task: str = "nowcast",
transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
download: bool = False,
checksum: bool = False,
@@ -66,21 +83,27 @@ def __init__(
Args:
root: root directory where dataset can be found
split: one of "trainval", or "test"
+ task: one fo "nowcast", or "forecast"
transforms: a function/transform that takes an input sample
and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 after downloading files (may be slow)
Raises:
- AssertionError: if ``countries`` contains invalid countries
+ AssertionError: if ``task`` or ``split`` is invalid
+ DatasetNotFoundError: If dataset is not found and *download* is False.
ImportError: if h5py is not installed
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
assert (
split in self.valid_splits
- ), f"Pleas choose one of these valid data splits {self.valid_splits}."
+ ), f"Please choose one of these valid data splits {self.valid_splits}."
self.split = split
+ assert (
+ task in self.valid_tasks
+ ), f"Please choose one of these valid tasks {self.valid_tasks}."
+ self.task = task
+
self.root = root
self.transforms = transforms
self.download = download
@@ -104,7 +127,7 @@ def __len__(self) -> int:
import h5py
with h5py.File(
- os.path.join(self.root, self.data_dir, self.img_file_name), "r"
+ os.path.join(self.root, self.data_file_name.format(self.task)), "r"
) as f:
num_datapoints: int = f[self.split]["pv_log"].shape[0]
@@ -139,12 +162,18 @@ def _load_image(self, index: int) -> Tensor:
import h5py
with h5py.File(
- os.path.join(self.root, self.data_dir, self.img_file_name), "r"
+ os.path.join(self.root, self.data_file_name.format(self.task)), "r"
) as f:
- arr = f[self.split]["images_log"][index, :, :, :]
+ arr = f[self.split]["images_log"][index]
+
+ # forecast has dimension [16, 64, 64, 3] but reshape to [48, 64, 64]
+ # https://github.com/yuhao-nie/Stanford-solar-forecasting-dataset/blob/main/models/SUNSET_forecast.ipynb
+ if self.task == "forecast":
+ arr = rearrange(arr, "t h w c-> (t c) h w")
+ else:
+ arr = rearrange(arr, "h w c -> c h w")
- # put channel first
- tensor = torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32)
+ tensor = torch.from_numpy(arr).to(torch.float32)
return tensor
def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:
@@ -159,14 +188,13 @@ def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:
import h5py
with h5py.File(
- os.path.join(self.root, self.data_dir, self.img_file_name), "r"
+ os.path.join(self.root, self.data_file_name.format(self.task)), "r"
) as f:
label = f[self.split]["pv_log"][index]
- path = os.path.join(self.root, self.data_dir, f"times_{self.split}.npy")
+ path = os.path.join(self.root, f"times_{self.split}_{self.task}.npy")
datestring = np.load(path, allow_pickle=True)[index].strftime(self.dateformat)
- # put channel first
features: dict[str, Union[str, Tensor]] = {
"label": torch.tensor(label, dtype=torch.float32),
"date": datestring,
@@ -174,51 +202,39 @@ def _load_features(self, index: int) -> dict[str, Union[str, Tensor]]:
return features
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
- pathname = os.path.join(self.root, self.data_dir)
+ pathname = os.path.join(self.root, self.data_file_name.format(self.task))
if os.path.exists(pathname):
return
# Check if the zip files have already been downloaded
- pathname = os.path.join(self.root, self.data_dir) + ".zip"
+ pathname = os.path.join(self.root, self.zipfile_name.format(self.task))
if os.path.exists(pathname):
self._extract()
return
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
self._extract()
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
- """
+ """Download the dataset and extract it."""
download_url(
- self.url,
+ self.url.format(self.zipfile_name.format(self.task)),
self.root,
- filename=self.data_dir,
- md5=self.md5 if self.checksum else None,
+ filename=self.zipfile_name.format(self.task),
+ md5=self.md5[self.task] if self.checksum else None,
)
self._extract()
def _extract(self) -> None:
"""Extract the dataset."""
- zipfile_path = os.path.join(self.root, self.data_dir) + ".zip"
+ zipfile_path = os.path.join(self.root, self.zipfile_name.format(self.task))
extract_archive(zipfile_path, self.root)
def plot(
@@ -226,9 +242,11 @@ def plot(
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
+ In the ``forecast`` task the latest image is plotted.
+
Args:
sample: a sample return by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
@@ -237,7 +255,13 @@ def plot(
Returns:
a matplotlib Figure with the rendered sample
"""
- image, label = sample["image"], sample["label"].item()
+ if self.task == "nowcast":
+ image, label = sample["image"].permute(1, 2, 0), sample["label"].item()
+ else:
+ image, label = (
+ sample["image"].permute(1, 2, 0).reshape(64, 64, 3, 16)[:, :, :, -1],
+ sample["label"][-1].item(),
+ )
showing_predictions = "prediction" in sample
if showing_predictions:
@@ -245,7 +269,7 @@ def plot(
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
- ax.imshow(image.permute(1, 2, 0) / 255)
+ ax.imshow(image / 255)
ax.axis("off")
if show_titles:
diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py
index 12a4a48dd6f..dff7276c65a 100644
--- a/torchgeo/datasets/so2sat.py
+++ b/torchgeo/datasets/so2sat.py
@@ -10,10 +10,11 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, percentile_normalization
+from .utils import DatasetNotFoundError, check_integrity, percentile_normalization
class So2Sat(NonGeoDataset):
@@ -207,7 +208,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if data is not found in ``root``, or checksums don't match
+ DatasetNotFoundError: If dataset is not found.
.. versionadded:: 0.3
The *bands* parameter.
@@ -256,7 +257,7 @@ def __init__(
self.fn = os.path.join(self.root, self.filenames_by_version[version][split])
if not self._check_integrity():
- raise RuntimeError("Dataset not found or corrupted.")
+ raise DatasetNotFoundError(self)
with h5py.File(self.fn, "r") as f:
self.size: int = f["label"].shape[0]
@@ -335,7 +336,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py
index e1f00347dfe..c6780e1971c 100644
--- a/torchgeo/datasets/spacenet.py
+++ b/torchgeo/datasets/spacenet.py
@@ -26,6 +26,7 @@
from .geo import NonGeoDataset
from .utils import (
+ DatasetNotFoundError,
check_integrity,
download_radiant_mlhub_collection,
download_radiant_mlhub_dataset,
@@ -98,7 +99,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` but dataset is missing
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.image = image # For testing
@@ -116,11 +117,7 @@ def __init__(
if to_be_downloaded:
if not download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use "
- "`download=True` to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
else:
self._download(to_be_downloaded, api_key)
@@ -283,9 +280,6 @@ def _download(self, collections: list[str], api_key: Optional[str] = None) -> No
Args:
collections: Collections to be downloaded
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
"""
for collection in collections:
download_radiant_mlhub_collection(collection, self.root, api_key)
@@ -421,7 +415,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` but dataset is missing
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
collections = ["sn1_AOI_1_RIO"]
assert image in {"rgb", "8band"}
@@ -541,7 +535,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` but dataset is missing
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert image in {"MS", "PAN", "PS-MS", "PS-RGB"}
super().__init__(
@@ -664,7 +658,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` but dataset is missing
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert image in {"MS", "PAN", "PS-MS", "PS-RGB"}
self.speed_mask = speed_mask
@@ -909,7 +903,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` but dataset is missing
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
collections = ["sn4_AOI_6_Atlanta"]
assert image in {"MS", "PAN", "PS-RGBNIR"}
@@ -1081,7 +1075,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` but dataset is missing
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
super().__init__(
root,
@@ -1205,7 +1199,7 @@ def __init__(
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
Raises:
- RuntimeError: if ``download=False`` but dataset is missing
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.image = image # For testing
@@ -1223,9 +1217,6 @@ def __download(self, api_key: Optional[str] = None) -> None:
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
"""
if os.path.exists(
os.path.join(
@@ -1307,7 +1298,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` but dataset is missing
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.split = split
@@ -1326,11 +1317,7 @@ def __init__(
if to_be_downloaded:
if not download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use "
- "`download=True` to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
else:
self._download(to_be_downloaded, api_key)
diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py
index 99b88976fa9..fa486c4aa79 100644
--- a/torchgeo/datasets/ssl4eo.py
+++ b/torchgeo/datasets/ssl4eo.py
@@ -12,10 +12,11 @@
import numpy as np
import rasterio
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_url, extract_archive
+from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive
class SSL4EO(NonGeoDataset):
@@ -179,7 +180,7 @@ def __init__(
Raises:
AssertionError: if any arguments are invalid
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.metadata
assert seasons in range(1, 5)
@@ -233,11 +234,7 @@ def __len__(self) -> int:
return len(self.scenes)
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
path = os.path.join(self.subdir, "00000*", "*", "all_bands.tif")
if glob.glob(path):
@@ -255,11 +252,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -293,7 +286,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
@@ -429,7 +422,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
- RuntimeError: if dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found.
"""
assert split in self.metadata
assert seasons in range(1, 5)
@@ -482,11 +475,7 @@ def __len__(self) -> int:
return 251079
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
directory_path = os.path.join(self.root, self.split)
if os.path.exists(directory_path):
@@ -500,7 +489,7 @@ def _verify(self) -> None:
if integrity:
self._extract()
else:
- raise RuntimeError(f"Dataset not found in `root={self.root}`")
+ raise DatasetNotFoundError(self)
def _extract(self) -> None:
"""Extract the dataset."""
@@ -512,7 +501,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py
index fb6fc48889d..a2cd92867b7 100644
--- a/torchgeo/datasets/ssl4eo_benchmark.py
+++ b/torchgeo/datasets/ssl4eo_benchmark.py
@@ -11,12 +11,13 @@
import numpy as np
import rasterio
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from .cdl import CDL
from .geo import NonGeoDataset
from .nlcd import NLCD
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class SSL4EOLBenchmark(NonGeoDataset):
@@ -130,7 +131,7 @@ def __init__(
Raises:
AssertionError: if any arguments are invalid
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert (
sensor in self.valid_sensors
@@ -189,11 +190,7 @@ def __init__(
self.ordinal_cmap[v] = torch.tensor(self.cmap[k])
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
img_pathname = os.path.join(self.root, self.img_dir_name, "**", "all_bands.tif")
exists = []
@@ -222,11 +219,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -332,7 +325,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py
index acb44a5ce29..5dca4e6d969 100644
--- a/torchgeo/datasets/sustainbench_crop_yield.py
+++ b/torchgeo/datasets/sustainbench_crop_yield.py
@@ -9,10 +9,11 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class SustainBenchCropYield(NonGeoDataset):
@@ -77,7 +78,7 @@ def __init__(
Raises:
AssertionError: if ``countries`` contains invalid countries or if ``split``
is invalid
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert set(countries).issubset(
self.valid_countries
@@ -185,11 +186,7 @@ def retrieve_collection(self) -> list[tuple[str, int]]:
return collection
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
pathname = os.path.join(self.root, self.dir)
if os.path.exists(pathname):
@@ -203,22 +200,14 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
self._extract()
def _download(self) -> None:
- """Download the dataset and extract it.
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
- """
+ """Download the dataset and extract it."""
download_url(
self.url,
self.root,
@@ -238,7 +227,7 @@ def plot(
band_idx: int = 0,
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py
index ce3966646c3..4a7867dc4a5 100644
--- a/torchgeo/datasets/ucmerced.py
+++ b/torchgeo/datasets/ucmerced.py
@@ -8,10 +8,11 @@
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as F
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
-from .utils import check_integrity, download_url, extract_archive
+from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive
class UCMerced(NonGeoClassificationDataset):
@@ -67,29 +68,6 @@ class UCMerced(NonGeoClassificationDataset):
md5 = "5b7ec56793786b6dc8a908e8854ac0e4"
base_dir = os.path.join("UCMerced_LandUse", "Images")
- classes = [
- "agricultural",
- "airplane",
- "baseballdiamond",
- "beach",
- "buildings",
- "chaparral",
- "denseresidential",
- "forest",
- "freeway",
- "golfcourse",
- "harbor",
- "intersection",
- "mediumresidential",
- "mobilehomepark",
- "overpass",
- "parkinglot",
- "river",
- "runway",
- "sparseresidential",
- "storagetanks",
- "tenniscourt",
- ]
splits = ["train", "val", "test"]
split_urls = {
@@ -122,8 +100,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
self.root = root
@@ -169,11 +146,7 @@ def _check_integrity(self) -> bool:
return integrity
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
filepath = os.path.join(self.root, self.base_dir)
if os.path.exists(filepath):
@@ -186,11 +159,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download and extract the dataset
self._download()
@@ -222,7 +191,7 @@ def plot(
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
@@ -236,6 +205,11 @@ def plot(
.. versionadded:: 0.2
"""
image = np.rollaxis(sample["image"].numpy(), 0, 3)
+
+ # Normalize the image if the max value is greater than 1
+ if image.max() > 1:
+ image = image.astype(np.float32) / 255.0 # Scale to [0, 1]
+
label = cast(int, sample["label"].item())
label_class = self.classes[label]
diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py
index 0361fd85b36..f33d268d4ce 100644
--- a/torchgeo/datasets/usavars.py
+++ b/torchgeo/datasets/usavars.py
@@ -10,13 +10,14 @@
import matplotlib.pyplot as plt
import numpy as np
+import pandas as pd
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_url, extract_archive
+from .utils import DatasetNotFoundError, download_url, extract_archive
class USAVars(NonGeoDataset):
@@ -105,9 +106,7 @@ def __init__(
Raises:
AssertionError: if invalid labels are provided
- ImportError: if pandas is not installed
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
@@ -124,13 +123,6 @@ def __init__(
self._verify()
- try:
- import pandas as pd # noqa: F401
- except ImportError:
- raise ImportError(
- "pandas is not installed and is required to use this dataset"
- )
-
self.files = self._load_files()
self.label_dfs = {
@@ -193,11 +185,7 @@ def _load_image(self, path: str) -> Tensor:
return tensor
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
pathname = os.path.join(self.root, "uar")
csv_pathname = os.path.join(self.root, "*.csv")
@@ -215,11 +203,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
self._download()
self._extract()
diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py
index 55d94abed6f..6446e7e64e1 100644
--- a/torchgeo/datasets/utils.py
+++ b/torchgeo/datasets/utils.py
@@ -23,11 +23,13 @@
import rasterio
import torch
from torch import Tensor
+from torch.utils.data import Dataset
from torchvision.datasets.utils import check_integrity, download_url
from torchvision.utils import draw_segmentation_masks
__all__ = (
"check_integrity",
+ "DatasetNotFoundError",
"download_url",
"download_and_extract_archive",
"extract_archive",
@@ -46,6 +48,49 @@
)
+class DatasetNotFoundError(FileNotFoundError):
+ """Raised when a dataset is requested but doesn't exist.
+
+ .. versionadded:: 0.6
+ """
+
+ def __init__(self, dataset: Dataset[object]) -> None:
+ """Intstantiate a new DatasetNotFoundError instance.
+
+ Args:
+ dataset: The dataset that was requested.
+ """
+ msg = "Dataset not found"
+
+ if hasattr(dataset, "root"):
+ var = "root"
+ val = dataset.root
+ elif hasattr(dataset, "paths"):
+ var = "paths"
+ val = dataset.paths
+ else:
+ super().__init__(f"{msg}.")
+ return
+
+ msg += f" in `{var}={val!r}` and "
+
+ if hasattr(dataset, "download") and not dataset.download:
+ msg += "`download=False`"
+ else:
+ msg += "cannot be automatically downloaded"
+
+ msg += f", either specify a different `{var}` or "
+
+ if hasattr(dataset, "download") and not dataset.download:
+ msg += "use `download=True` to automatically"
+ else:
+ msg += "manually"
+
+ msg += " download the dataset."
+
+ super().__init__(msg)
+
+
class _rarfile:
class RarFile:
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -737,3 +782,27 @@ def percentile_normalization(
(img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1
)
return img_normalized
+
+
+def path_is_vsi(path: str) -> bool:
+ """Checks if the given path is pointing to a Virtual File System.
+
+ .. note::
+ Does not check if the path exists, or if it is a dir or file.
+
+ VSI can for instance be Cloud Storage Blobs or zip-archives.
+ They will start with a prefix indicating this.
+ For examples of these, see references for the two accepted syntaxes.
+
+ * https://gdal.org/user/virtual_file_systems.html
+ * https://rasterio.readthedocs.io/en/latest/topics/datasets.html
+
+ Args:
+ path: string representing a directory or file
+
+ Returns:
+ True if path is on a virtual file system, else False
+
+ .. versionadded:: 0.6
+ """
+ return "://" in path or path.startswith("/vsi")
diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py
index 78370f31585..59ecfcda690 100644
--- a/torchgeo/datasets/vaihingen.py
+++ b/torchgeo/datasets/vaihingen.py
@@ -15,6 +15,7 @@
from .geo import NonGeoDataset
from .utils import (
+ DatasetNotFoundError,
check_integrity,
draw_semantic_segmentation_masks,
extract_archive,
@@ -132,6 +133,10 @@ def __init__(
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ AssertionError: If *split* is invalid.
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
self.root = root
@@ -210,11 +215,7 @@ def _load_target(self, index: int) -> Tensor:
return tensor
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if checksum fails or the dataset is not downloaded
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
if os.path.exists(os.path.join(self.root, self.image_root)):
return
@@ -234,11 +235,7 @@ def _verify(self) -> None:
if all(exists):
return
- # Check if the user requested to download the dataset
- raise RuntimeError(
- "Dataset not found in `root` directory, either specify a different"
- + " `root` directory or manually download the dataset to this directory."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py
index 4dbda994433..db0807ee930 100644
--- a/torchgeo/datasets/vhr10.py
+++ b/torchgeo/datasets/vhr10.py
@@ -10,11 +10,17 @@
import numpy as np
import torch
from matplotlib import patches
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, download_and_extract_archive, download_url
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ download_and_extract_archive,
+ download_url,
+)
def convert_coco_poly_to_mask(
@@ -199,8 +205,7 @@ def __init__(
Raises:
AssertionError: if ``split`` argument is invalid
ImportError: if ``split="positive"`` and pycocotools is not installed
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in ["positive", "negative"]
@@ -213,10 +218,7 @@ def __init__(
self._download()
if not self._check_integrity():
- raise RuntimeError(
- "Dataset not found or corrupted. "
- + "You can use download=True to download it"
- )
+ raise DatasetNotFoundError(self)
if split == "positive":
# Must be installed to parse annotations file
@@ -371,7 +373,7 @@ def plot(
show_feats: Optional[str] = "both",
box_alpha: float = 0.7,
mask_alpha: float = 0.7,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
@@ -394,13 +396,13 @@ def plot(
assert show_feats in {"boxes", "masks", "both"}
if self.split == "negative":
- plt.imshow(sample["image"].permute(1, 2, 0))
- axs = plt.gca()
- axs.axis("off")
+ fig, axs = plt.subplots(squeeze=False)
+ axs[0, 0].imshow(sample["image"].permute(1, 2, 0))
+ axs[0, 0].axis("off")
if suptitle is not None:
plt.suptitle(suptitle)
- return plt.gcf()
+ return fig
if show_feats != "boxes":
try:
@@ -437,11 +439,9 @@ def plot(
ncols += 1
# Display image
- fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
- if not isinstance(axs, np.ndarray):
- axs = [axs]
- axs[0].imshow(image)
- axs[0].axis("off")
+ fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 10))
+ axs[0, 0].imshow(image)
+ axs[0, 0].axis("off")
cm = plt.get_cmap("gist_rainbow")
for i in range(n_gt):
@@ -451,7 +451,7 @@ def plot(
# Add bounding boxes
x1, y1, x2, y2 = boxes[i]
if show_feats in {"boxes", "both"}:
- p = patches.Rectangle(
+ r = patches.Rectangle(
(x1, y1),
x2 - x1,
y2 - y1,
@@ -461,32 +461,32 @@ def plot(
edgecolor=color,
facecolor="none",
)
- axs[0].add_patch(p)
+ axs[0, 0].add_patch(r)
# Add labels
label = self.categories[class_num]
caption = label
- axs[0].text(
+ axs[0, 0].text(
x1, y1 - 8, caption, color="white", size=11, backgroundcolor="none"
)
# Add masks
if show_feats in {"masks", "both"} and "masks" in sample:
mask = masks[i]
- contours = find_contours(mask, 0.5)
+ contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(
verts, facecolor=color, alpha=mask_alpha, edgecolor="white"
)
- axs[0].add_patch(p)
+ axs[0, 0].add_patch(p)
if show_titles:
- axs[0].set_title("Ground Truth")
+ axs[0, 0].set_title("Ground Truth")
if show_predictions:
- axs[1].imshow(image)
- axs[1].axis("off")
+ axs[0, 1].imshow(image)
+ axs[0, 1].axis("off")
for i in range(n_pred):
score = prediction_scores[i]
if score < 0.5:
@@ -498,7 +498,7 @@ def plot(
if show_pred_boxes:
# Add bounding boxes
x1, y1, x2, y2 = prediction_boxes[i]
- p = patches.Rectangle(
+ r = patches.Rectangle(
(x1, y1),
x2 - x1,
y2 - y1,
@@ -508,12 +508,12 @@ def plot(
edgecolor=color,
facecolor="none",
)
- axs[1].add_patch(p)
+ axs[0, 1].add_patch(r)
# Add labels
label = self.categories[class_num]
caption = f"{label} {score:.3f}"
- axs[1].text(
+ axs[0, 1].text(
x1,
y1 - 8,
caption,
@@ -525,16 +525,16 @@ def plot(
# Add masks
if show_pred_masks:
mask = prediction_masks[i]
- contours = find_contours(mask, 0.5)
+ contours = find_contours(mask, 0.5) # type: ignore[no-untyped-call]
for verts in contours:
verts = np.fliplr(verts)
p = patches.Polygon(
verts, facecolor=color, alpha=mask_alpha, edgecolor="white"
)
- axs[1].add_patch(p)
+ axs[0, 1].add_patch(p)
if show_titles:
- axs[1].set_title("Prediction")
+ axs[0, 1].set_title("Prediction")
plt.tight_layout()
diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py
index 05e84bf17b4..d782cd2b50a 100644
--- a/torchgeo/datasets/western_usa_live_fuel_moisture.py
+++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py
@@ -8,11 +8,16 @@
import os
from typing import Any, Callable, Optional
+import pandas as pd
import torch
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_radiant_mlhub_collection, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ download_radiant_mlhub_collection,
+ extract_archive,
+)
class WesternUSALiveFuelMoisture(NonGeoDataset):
@@ -217,8 +222,7 @@ def __init__(
Raises:
AssertionError: if ``input_features`` contains invalid variable names
- ImportError: if pandas is not installed
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
super().__init__()
@@ -230,13 +234,6 @@ def __init__(
self._verify()
- try:
- import pandas as pd # noqa: F401
- except ImportError:
- raise ImportError(
- "pandas is not installed and is required to use this dataset"
- )
-
assert all(
input in self.all_variable_names for input in input_features
), "Invalid input variable name."
@@ -287,14 +284,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:
return sample
- def _load_data(self) -> "pd.DataFrame": # type: ignore[name-defined] # noqa: F821
+ def _load_data(self) -> pd.DataFrame:
"""Load data from individual files into pandas dataframe.
Returns:
the features and label
"""
- import pandas as pd
-
data_rows = []
for path in self.collection:
with open(path) as f:
@@ -309,11 +304,7 @@ def _load_data(self) -> "pd.DataFrame": # type: ignore[name-defined] # noqa: F8
return df
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the extracted files already exist
pathname = os.path.join(self.root, self.collection_id)
if os.path.exists(pathname):
@@ -327,11 +318,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -347,9 +334,6 @@ def _download(self, api_key: Optional[str] = None) -> None:
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-
- Raises:
- RuntimeError: if download doesn't work correctly or checksums don't match
"""
download_radiant_mlhub_collection(self.collection_id, self.root, api_key)
filename = os.path.join(self.root, self.collection_id) + ".tar.gz"
diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index 3b423247a52..55eaa6735c8 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -10,11 +10,17 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive
+from .utils import (
+ DatasetNotFoundError,
+ check_integrity,
+ draw_semantic_segmentation_masks,
+ extract_archive,
+)
class XView2(NonGeoDataset):
@@ -77,6 +83,10 @@ def __init__(
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ AssertionError: If *split* is invalid.
+ DatasetNotFoundError: If dataset is not found.
"""
assert split in self.metadata
self.root = root
@@ -180,11 +190,7 @@ def _load_target(self, path: str) -> Tensor:
return tensor
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if checksum fails or the dataset is not downloaded
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
exists = []
for split_info in self.metadata.values():
@@ -213,11 +219,7 @@ def _verify(self) -> None:
if all(exists):
return
- # Check if the user requested to download the dataset
- raise RuntimeError(
- "Dataset not found in `root` directory, either specify a different"
- + " `root` directory or manually download the dataset to this directory."
- )
+ raise DatasetNotFoundError(self)
def plot(
self,
@@ -225,7 +227,7 @@ def plot(
show_titles: bool = True,
suptitle: Optional[str] = None,
alpha: float = 0.5,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py
index 008f4e34d9d..2047b121adc 100644
--- a/torchgeo/datasets/zuericrop.py
+++ b/torchgeo/datasets/zuericrop.py
@@ -9,10 +9,11 @@
import matplotlib.pyplot as plt
import torch
+from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
-from .utils import download_url, percentile_normalization
+from .utils import DatasetNotFoundError, download_url, percentile_normalization
class ZueriCrop(NonGeoDataset):
@@ -80,8 +81,7 @@ def __init__(
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
- RuntimeError: if ``download=False`` and data is not found, or checksums
- don't match
+ DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self._validate_bands(bands)
self.band_indices = torch.tensor(
@@ -208,11 +208,7 @@ def _load_target(self, index: int) -> tuple[Tensor, Tensor, Tensor]:
return masks, boxes, labels
def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
+ """Verify the integrity of the dataset."""
# Check if the files already exist
exists = []
for filename in self.filenames:
@@ -224,11 +220,7 @@ def _verify(self) -> None:
# Check if the user requested to download the dataset
if not self.download:
- raise RuntimeError(
- "Dataset not found in `root` directory and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
+ raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@@ -250,6 +242,7 @@ def _validate_bands(self, bands: Sequence[str]) -> None:
Args:
bands: user-provided sequence of bands to load
+
Raises:
AssertionError: if ``bands`` is not a sequence
ValueError: if an invalid band name is provided
@@ -267,7 +260,7 @@ def plot(
time_step: int = 0,
show_titles: bool = True,
suptitle: Optional[str] = None,
- ) -> plt.Figure:
+ ) -> Figure:
"""Plot a sample from the dataset.
Args:
diff --git a/torchgeo/main.py b/torchgeo/main.py
new file mode 100644
index 00000000000..0b002cdc201
--- /dev/null
+++ b/torchgeo/main.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Command-line interface to TorchGeo."""
+
+import os
+
+from lightning.pytorch.cli import ArgsType, LightningCLI
+
+# Allows classes to be referenced using only the class name
+import torchgeo.datamodules # noqa: F401
+import torchgeo.trainers # noqa: F401
+from torchgeo.datamodules import BaseDataModule
+from torchgeo.trainers import BaseTask
+
+
+def main(args: ArgsType = None) -> None:
+ """Command-line interface to TorchGeo."""
+ # Taken from https://github.com/pangeo-data/cog-best-practices
+ rasterio_best_practices = {
+ "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
+ "AWS_NO_SIGN_REQUEST": "YES",
+ "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000",
+ "GDAL_SWATH_SIZE": "200000000",
+ "VSI_CURL_CACHE_SIZE": "200000000",
+ }
+ os.environ.update(rasterio_best_practices)
+
+ LightningCLI(
+ model_class=BaseTask,
+ datamodule_class=BaseDataModule,
+ seed_everything_default=0,
+ subclass_mode_model=True,
+ subclass_mode_data=True,
+ save_config_kwargs={"overwrite": True},
+ args=args,
+ )
diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py
index 3cfd274996b..59f42223cf1 100644
--- a/torchgeo/models/rcf.py
+++ b/torchgeo/models/rcf.py
@@ -5,21 +5,33 @@
from typing import Optional
+import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules import Module
+from ..datasets import NonGeoDataset
+
class RCF(Module):
"""This model extracts random convolutional features (RCFs) from its input.
- RCFs are used in Multi-task Observation using Satellite Imagery & Kitchen Sinks
- (MOSAIKS) method proposed in https://www.nature.com/articles/s41467-021-24638-z.
+ RCFs are used in the Multi-task Observation using Satellite Imagery & Kitchen Sinks
+ (MOSAIKS) method proposed in "A generalizable and accessible approach to machine
+ learning with global satellite imagery".
+
+ This class can operate in two modes, "gaussian" and "empirical". In "gaussian" mode,
+ the filters will be sampled from a Gaussian distribution, while in "empirical" mode,
+ the filters will be sampled from a dataset.
+
+ If you use this model in your research, please cite the following paper:
+
+ * https://www.nature.com/articles/s41467-021-24638-z
.. note::
- This Module is *not* trainable. It is only used as a feature extractor.
+ This Module is *not* trainable. It is only used as a feature extractor.
"""
weights: Tensor
@@ -32,6 +44,8 @@ def __init__(
kernel_size: int = 3,
bias: float = -1.0,
seed: Optional[int] = None,
+ mode: str = "gaussian",
+ dataset: Optional[NonGeoDataset] = None,
) -> None:
"""Initializes the RCF model.
@@ -41,21 +55,28 @@ def __init__(
.. versionadded:: 0.2
The *seed* parameter.
+ .. versionadded:: 0.5
+ The *mode* and *dataset* parameters.
+
Args:
in_channels: number of input channels
features: number of features to compute, must be divisible by 2
kernel_size: size of the kernel used to compute the RCFs
bias: bias of the convolutional layer
seed: random seed used to initialize the convolutional layer
+ mode: "empirical" or "gaussian"
+ dataset: a NonGeoDataset to sample from when mode is "empirical"
"""
super().__init__()
-
+ assert mode in ["empirical", "gaussian"]
+ if mode == "empirical" and dataset is None:
+ raise ValueError("dataset must be provided when mode is 'empirical'")
assert features % 2 == 0
+ num_patches = features // 2
- if seed is None:
- generator = None
- else:
- generator = torch.Generator().manual_seed(seed)
+ generator = torch.Generator()
+ if seed:
+ generator = generator.manual_seed(seed)
# We register the weight and bias tensors as "buffers". This does two things:
# makes them behave correctly when we call .to(...) on the module, and makes
@@ -64,7 +85,7 @@ def __init__(
self.register_buffer(
"weights",
torch.randn(
- features // 2,
+ num_patches,
in_channels,
kernel_size,
kernel_size,
@@ -73,9 +94,85 @@ def __init__(
),
)
self.register_buffer(
- "biases", torch.zeros(features // 2, requires_grad=False) + bias
+ "biases", torch.zeros(num_patches, requires_grad=False) + bias
+ )
+
+ if mode == "empirical":
+ assert dataset is not None
+ num_channels, height, width = dataset[0]["image"].shape
+ assert num_channels == in_channels
+ patches = np.zeros(
+ (num_patches, num_channels, kernel_size, kernel_size), dtype=np.float32
+ )
+ idxs = torch.randint(
+ 0, len(dataset), (num_patches,), generator=generator
+ ).numpy()
+ ys = torch.randint(
+ 0, height - kernel_size, (num_patches,), generator=generator
+ ).numpy()
+ xs = torch.randint(
+ 0, width - kernel_size, (num_patches,), generator=generator
+ ).numpy()
+
+ for i in range(num_patches):
+ img = dataset[idxs[i]]["image"]
+ patches[i] = img[
+ :, ys[i] : ys[i] + kernel_size, xs[i] : xs[i] + kernel_size
+ ]
+
+ patches = self._normalize(patches)
+ self.weights = torch.tensor(patches)
+
+ def _normalize(
+ self,
+ patches: "np.typing.NDArray[np.float32]",
+ min_divisor: float = 1e-8,
+ zca_bias: float = 0.001,
+ ) -> "np.typing.NDArray[np.float32]":
+ """Does ZCA whitening on a set of input patches.
+
+ Copied from https://github.com/Global-Policy-Lab/mosaiks-paper/blob/7efb09ed455505562d6bb04c2aaa242ef59f0a82/code/mosaiks/featurization.py#L120
+
+ Args:
+ patches: a numpy array of size (N, C, H, W)
+ min_divisor: a small number to guard against division by zero
+ zca_bias: bias term for ZCA whitening
+
+ Returns
+ a numpy array of size (N, C, H, W) containing the normalized patches
+
+ .. versionadded:: 0.5
+ """ # noqa: E501
+ n_patches = patches.shape[0]
+ orig_shape = patches.shape
+ patches = patches.reshape(patches.shape[0], -1)
+
+ # Zero mean every feature
+ patches = patches - np.mean(patches, axis=1, keepdims=True)
+
+ # Normalize
+ patch_norms = np.linalg.norm(patches, axis=1)
+
+ # Get rid of really small norms
+ patch_norms[np.where(patch_norms < min_divisor)] = 1
+
+ # Make features unit norm
+ patches = patches / patch_norms[:, np.newaxis]
+
+ patchesCovMat = 1.0 / n_patches * patches.T.dot(patches)
+
+ (E, V) = np.linalg.eig(patchesCovMat)
+
+ E += zca_bias
+ sqrt_zca_eigs = np.sqrt(E)
+ inv_sqrt_zca_eigs = np.diag(np.power(sqrt_zca_eigs, -1))
+ global_ZCA = V.dot(inv_sqrt_zca_eigs).dot(V.T)
+ patches_normalized: "np.typing.NDArray[np.float32]" = (
+ (patches).dot(global_ZCA).dot(global_ZCA.T)
)
+ return patches_normalized.reshape(orig_shape).astype("float32")
+
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the RCF model.
diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py
index 082899d572e..fb254c166e4 100644
--- a/torchgeo/models/resnet.py
+++ b/torchgeo/models/resnet.py
@@ -52,6 +52,13 @@
data_keys=["image"],
)
+# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
+_ssl4eo_l_transforms = AugmentationSequential(
+ K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
+ K.CenterCrop((224, 224)),
+ data_keys=["image"],
+)
+
# https://github.com/pytorch/vision/pull/6883
# https://github.com/pytorch/vision/pull/7107
# Can be removed once torchvision>=0.15 is required
@@ -67,6 +74,136 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
.. versionadded:: 0.4
"""
+ LANDSAT_TM_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_tm_toa_moco-1c691b4f.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_TM_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_tm_toa_simclr-d2d38ace.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_ETM_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_etm_toa_moco-bb88689c.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 9,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_ETM_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_etm_toa_simclr-4d813f79.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 9,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_ETM_SR_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_etm_sr_moco-4f078acd.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 6,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_ETM_SR_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_etm_sr_simclr-8e8543b4.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 6,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 11,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 11,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_OLI_SR_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_oli_sr_moco-660e82ed.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_OLI_SR_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet18_landsat_oli_sr_simclr-7bced5be.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
SENTINEL2_ALL_MOCO = Weights(
url="https://huggingface.co/torchgeo/resnet18_sentinel2_all_moco/resolve/main/resnet18_sentinel2_all_moco-59bfdff9.pth", # noqa: E501
transforms=_zhu_xlab_transforms,
@@ -94,7 +231,7 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_SECO = Weights(
- url="https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/main/resnet18_sentinel2_rgb_seco-9976a9cb.pth", # noqa: E501
+ url="https://huggingface.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/main/resnet18_sentinel2_rgb_seco-cefca942.pth", # noqa: E501
transforms=_seco_transforms,
meta={
"dataset": "SeCo Dataset",
@@ -129,6 +266,136 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
},
)
+ LANDSAT_TM_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_tm_toa_moco-ba1ce753.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "resnet50",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_TM_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_tm_toa_simclr-a1c93432.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "resnet50",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_ETM_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_etm_toa_moco-e9a84d5a.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 9,
+ "model": "resnet50",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_ETM_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_etm_toa_simclr-70b5575f.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 9,
+ "model": "resnet50",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_ETM_SR_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_etm_sr_moco-1266cde3.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 6,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_ETM_SR_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_etm_sr_simclr-e5d185d7.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 6,
+ "model": "resnet18",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 11,
+ "model": "resnet50",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 11,
+ "model": "resnet50",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_OLI_SR_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_oli_sr_moco-ff580dad.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "resnet50",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_OLI_SR_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/resnet50_landsat_oli_sr_simclr-94f78913.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "resnet50",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
SENTINEL1_ALL_MOCO = Weights(
url="https://huggingface.co/torchgeo/resnet50_sentinel1_all_moco/resolve/main/resnet50_sentinel1_all_moco-906e4356.pth", # noqa: E501
transforms=_zhu_xlab_transforms,
@@ -182,7 +449,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc]
)
SENTINEL2_RGB_SECO = Weights(
- url="https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/main/resnet50_sentinel2_rgb_seco-584035db.pth", # noqa: E501
+ url="https://huggingface.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/main/resnet50_sentinel2_rgb_seco-018bf397.pth", # noqa: E501
transforms=_seco_transforms,
meta={
"dataset": "SeCo Dataset",
@@ -220,7 +487,11 @@ def resnet18(
model: ResNet = timm.create_model("resnet18", *args, **kwargs)
if weights:
- model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+ missing_keys, unexpected_keys = model.load_state_dict(
+ weights.get_state_dict(progress=True), strict=False
+ )
+ assert set(missing_keys) <= {"fc.weight", "fc.bias"}
+ assert not unexpected_keys
return model
@@ -251,6 +522,10 @@ def resnet50(
model: ResNet = timm.create_model("resnet50", *args, **kwargs)
if weights:
- model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+ missing_keys, unexpected_keys = model.load_state_dict(
+ weights.get_state_dict(progress=True), strict=False
+ )
+ assert set(missing_keys) <= {"fc.weight", "fc.bias"}
+ assert not unexpected_keys
return model
diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py
index 7080257852c..99bd7c23acc 100644
--- a/torchgeo/models/vit.py
+++ b/torchgeo/models/vit.py
@@ -25,6 +25,13 @@
data_keys=["image"],
)
+# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501
+_ssl4eo_l_transforms = AugmentationSequential(
+ K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
+ K.CenterCrop((224, 224)),
+ data_keys=["image"],
+)
+
# https://github.com/pytorch/vision/pull/6883
# https://github.com/pytorch/vision/pull/7107
# Can be removed once torchvision>=0.15 is required
@@ -40,6 +47,136 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc]
.. versionadded:: 0.4
"""
+ LANDSAT_TM_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_tm_toa_moco-a1c967d8.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_TM_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_tm_toa_simclr-7c2d9799.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_ETM_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_etm_toa_moco-26d19bcf.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 9,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_ETM_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_etm_toa_simclr-34fb12cb.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 9,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_ETM_SR_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_etm_sr_moco-eaa4674e.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 6,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_ETM_SR_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_etm_sr_simclr-a14c466a.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 6,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 11,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 11,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
+ LANDSAT_OLI_SR_MOCO = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_oli_sr_moco-c9b8898d.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "moco",
+ },
+ )
+
+ LANDSAT_OLI_SR_SIMCLR = Weights(
+ url="https://huggingface.co/torchgeo/ssl4eo_landsat/resolve/main/vits16_landsat_oli_sr_simclr-4e8f6102.pth", # noqa: E501
+ transforms=_ssl4eo_l_transforms,
+ meta={
+ "dataset": "SSL4EO-L",
+ "in_chans": 7,
+ "model": "vit_small_patch16_224",
+ "publication": "https://arxiv.org/abs/2306.09424",
+ "repo": "https://github.com/microsoft/torchgeo",
+ "ssl_method": "simclr",
+ },
+ )
+
SENTINEL2_ALL_DINO = Weights(
url="https://huggingface.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/main/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth", # noqa: E501
transforms=_zhu_xlab_transforms,
@@ -94,6 +231,10 @@ def vit_small_patch16_224(
)
if weights:
- model.load_state_dict(weights.get_state_dict(progress=True), strict=False)
+ missing_keys, unexpected_keys = model.load_state_dict(
+ weights.get_state_dict(progress=True), strict=False
+ )
+ assert set(missing_keys) <= {"head.weight", "head.bias"}
+ assert not unexpected_keys
return model
diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py
index b39bc483b40..ec8d916a012 100644
--- a/torchgeo/trainers/__init__.py
+++ b/torchgeo/trainers/__init__.py
@@ -3,6 +3,7 @@
"""TorchGeo trainers."""
+from .base import BaseTask
from .byol import BYOLTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .detection import ObjectDetectionTask
@@ -12,13 +13,17 @@
from .simclr import SimCLRTask
__all__ = (
- "BYOLTask",
+ # Supervised
"ClassificationTask",
- "MoCoTask",
"MultiLabelClassificationTask",
"ObjectDetectionTask",
"PixelwiseRegressionTask",
"RegressionTask",
"SemanticSegmentationTask",
+ # Self-supervised
+ "BYOLTask",
+ "MoCoTask",
"SimCLRTask",
+ # Base classes
+ "BaseTask",
)
diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py
new file mode 100644
index 00000000000..3a44c047a31
--- /dev/null
+++ b/torchgeo/trainers/base.py
@@ -0,0 +1,78 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Base classes for all :mod:`torchgeo` trainers."""
+
+from abc import ABC, abstractmethod
+from collections.abc import Sequence
+from typing import Any, Optional, Union
+
+import lightning
+from lightning.pytorch import LightningModule
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+
+class BaseTask(LightningModule, ABC):
+ """Abstract base class for all TorchGeo trainers.
+
+ .. versionadded:: 0.5
+ """
+
+ #: Model to train.
+ model: Any
+
+ #: Performance metric to monitor in learning rate scheduler and callbacks.
+ monitor = "val_loss"
+
+ #: Whether the goal is to minimize or maximize the performance metric to monitor.
+ mode = "min"
+
+ def __init__(self, ignore: Optional[Union[Sequence[str], str]] = None) -> None:
+ """Initialize a new BaseTask instance.
+
+ Args:
+ ignore: Arguments to skip when saving hyperparameters.
+ """
+ super().__init__()
+ self.save_hyperparameters(ignore=ignore)
+ self.configure_losses()
+ self.configure_metrics()
+ self.configure_models()
+
+ def configure_losses(self) -> None:
+ """Initialize the loss criterion."""
+
+ def configure_metrics(self) -> None:
+ """Initialize the performance metrics."""
+
+ @abstractmethod
+ def configure_models(self) -> None:
+ """Initialize the model."""
+
+ def configure_optimizers(
+ self,
+ ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
+ """Initialize the optimizer and learning rate scheduler.
+
+ Returns:
+ Optimizer and learning rate scheduler.
+ """
+ optimizer = AdamW(self.parameters(), lr=self.hparams["lr"])
+ scheduler = ReduceLROnPlateau(optimizer, patience=self.hparams["patience"])
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
+ }
+
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
+ """Forward pass of the model.
+
+ Args:
+ args: Arguments to pass to model.
+ kwargs: Keyword arguments to pass to model.
+
+ Returns:
+ Output of the model.
+ """
+ return self.model(*args, **kwargs)
diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py
index 00315f4028c..68bdb6c9c43 100644
--- a/torchgeo/trainers/byol.py
+++ b/torchgeo/trainers/byol.py
@@ -1,23 +1,22 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-"""BYOL tasks."""
+"""BYOL trainer for self-supervised learning (SSL)."""
import os
-from typing import Any, Optional, cast
+from typing import Any, Optional, Union
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia import augmentation as K
-from lightning.pytorch import LightningModule
-from torch import Tensor, optim
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch import Tensor
from torchvision.models._api import WeightsEnum
from ..models import get_weight
from . import utils
+from .base import BaseTask
def normalized_mse(x: Tensor, y: Tensor) -> Tensor:
@@ -75,7 +74,8 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
an augmented batch of imagery
"""
- return cast(Tensor, self.augmentation(x))
+ z: Tensor = self.augmentation(x)
+ return z
class MLP(nn.Module):
@@ -108,7 +108,8 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
embedded version of the input
"""
- return cast(Tensor, self.mlp(x))
+ z: Tensor = self.mlp(x)
+ return z
class BackboneWrapper(nn.Module):
@@ -122,7 +123,7 @@ class BackboneWrapper(nn.Module):
* The forward call returns the output of the projection head
.. versionchanged 0.4: Name changed from *EncoderWrapper* to
- *BackboneWrapper*.
+ *BackboneWrapper*.
"""
def __init__(
@@ -270,7 +271,8 @@ def forward(self, x: Tensor) -> Tensor:
Returns:
output from the model
"""
- return cast(Tensor, self.predictor(self.backbone(x)))
+ z: Tensor = self.predictor(self.backbone(x))
+ return z
def update_target(self) -> None:
"""Method to update the "target" model weights."""
@@ -278,29 +280,59 @@ def update_target(self) -> None:
pt.data = self.beta * pt.data + (1 - self.beta) * p.data
-class BYOLTask(LightningModule):
- """Class for pre-training any PyTorch model using BYOL.
+class BYOLTask(BaseTask):
+ """BYOL: Bootstrap Your Own Latent.
- Supports any available `Timm model
- `_
- as an architecture choice. To see a list of available pretrained
- models, you can do:
+ Reference implementation:
- .. code-block:: python
+ * https://github.com/deepmind/deepmind-research/tree/master/byol
- import timm
- print(timm.list_models())
+ If you use this trainer in your research, please cite the following paper:
+
+ * https://arxiv.org/abs/2006.07733
"""
- def config_task(self) -> None:
- """Configures the task based on kwargs parameters passed to the constructor."""
- # Create model
- in_channels = self.hyperparams["in_channels"]
- weights = self.hyperparams["weights"]
+ monitor = "train_loss"
+
+ def __init__(
+ self,
+ model: str = "resnet50",
+ weights: Optional[Union[WeightsEnum, str, bool]] = None,
+ in_channels: int = 3,
+ lr: float = 1e-3,
+ patience: int = 10,
+ ) -> None:
+ """Initialize a new BYOLTask instance.
+
+ Args:
+ model: Name of the `timm
+ `__ model to use.
+ weights: Initial model weights. Either a weight enum, the string
+ representation of a weight enum, True for ImageNet weights, False
+ or None for random weights, or the path to a saved model state dict.
+ in_channels: Number of input channels to model.
+ lr: Learning rate for optimizer.
+ patience: Patience for learning rate scheduler.
+
+ .. versionchanged:: 0.4
+ *backbone_name* was renamed to *backbone*. Changed backbone support from
+ torchvision.models to timm.
+
+ .. versionchanged:: 0.5
+ *backbone*, *learning_rate*, and *learning_rate_schedule_patience* were
+ renamed to *model*, *lr*, and *patience*.
+ """
+ self.weights = weights
+ super().__init__(ignore="weights")
+
+ def configure_models(self) -> None:
+ """Initialize the model."""
+ weights = self.weights
+ in_channels: int = self.hparams["in_channels"]
+
+ # Create backbone
backbone = timm.create_model(
- self.hyperparams["backbone"],
- in_chans=in_channels,
- pretrained=weights is True,
+ self.hparams["model"], in_chans=in_channels, pretrained=weights is True
)
# Load weights
@@ -315,79 +347,25 @@ def config_task(self) -> None:
self.model = BYOL(backbone, in_channels=in_channels, image_size=(224, 224))
- def __init__(self, **kwargs: Any) -> None:
- """Initialize a LightningModule for pre-training a model with BYOL.
-
- Keyword Args:
- in_channels: Number of input channels to model
- backbone: Name of the timm model to use
- weights: Either a weight enum, the string representation of a weight enum,
- True for ImageNet weights, False or None for random weights,
- or the path to a saved model state dict.
- learning_rate: Learning rate for optimizer
- learning_rate_schedule_patience: Patience for learning rate scheduler
-
- Raises:
- ValueError: if kwargs arguments are invalid
-
- .. versionchanged:: 0.4
- The *backbone_name* parameter was renamed to *backbone*. Change backbone
- support from torchvision.models to timm.
- """
- super().__init__()
-
- # Creates `self.hparams` from kwargs
- self.save_hyperparameters()
- self.hyperparams = cast(dict[str, Any], self.hparams)
-
- self.config_task()
-
- def forward(self, *args: Any, **kwargs: Any) -> Any:
- """Forward pass of the model.
+ def training_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the training loss and additional metrics.
Args:
- x: tensor of data to run through the model
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- output from the model
- """
- return self.model(*args, **kwargs)
-
- def configure_optimizers(self) -> dict[str, Any]:
- """Initialize the optimizer and learning rate scheduler.
+ The loss tensor.
- Returns:
- learning rate dictionary.
- """
- optimizer_class = getattr(optim, self.hyperparams.get("optimizer", "Adam"))
- lr = self.hyperparams.get("learning_rate", 1e-4)
- weight_decay = self.hyperparams.get("weight_decay", 1e-6)
- optimizer = optimizer_class(self.parameters(), lr=lr, weight_decay=weight_decay)
-
- return {
- "optimizer": optimizer,
- "lr_scheduler": {
- "scheduler": ReduceLROnPlateau(
- optimizer,
- patience=self.hyperparams["learning_rate_schedule_patience"],
- ),
- "monitor": "train_loss",
- },
- }
-
- def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the training loss.
-
- Args:
- batch: the output of your DataLoader
-
- Returns:
- training loss
+ Raises:
+ AssertionError: If channel dimensions are incorrect.
"""
- batch = args[0]
x = batch["image"]
- in_channels = self.hyperparams["in_channels"]
+ in_channels = self.hparams["in_channels"]
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
if x.size(1) == in_channels:
@@ -409,16 +387,18 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
loss = torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))
- self.log("train_loss", loss, on_step=True, on_epoch=False)
+ self.log("train_loss", loss)
self.model.update_target()
return loss
- def validation_step(self, *args: Any, **kwargs: Any) -> None:
+ def validation_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> None:
"""No-op, does nothing."""
- def test_step(self, *args: Any, **kwargs: Any) -> None:
+ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""No-op, does nothing."""
- def predict_step(self, *args: Any, **kwargs: Any) -> None:
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""No-op, does nothing."""
diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py
index 29f96f2f9fd..76bea82118f 100644
--- a/torchgeo/trainers/classification.py
+++ b/torchgeo/trainers/classification.py
@@ -1,19 +1,17 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-"""Classification tasks."""
+"""Trainers for image classification."""
import os
-from typing import Any, cast
+from typing import Any, Optional, Union
import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn
-from lightning.pytorch import LightningModule
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor
-from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import (
MulticlassAccuracy,
@@ -27,30 +25,106 @@
from ..datasets import unbind_samples
from ..models import get_weight
from . import utils
+from .base import BaseTask
+
+
+class ClassificationTask(BaseTask):
+ """Image classification."""
+
+ def __init__(
+ self,
+ model: str = "resnet50",
+ weights: Optional[Union[WeightsEnum, str, bool]] = None,
+ in_channels: int = 3,
+ num_classes: int = 1000,
+ loss: str = "ce",
+ class_weights: Optional[Tensor] = None,
+ lr: float = 1e-3,
+ patience: int = 10,
+ freeze_backbone: bool = False,
+ ) -> None:
+ """Initialize a new ClassificationTask instance.
+ Args:
+ model: Name of the `timm
+ `__ model to use.
+ weights: Initial model weights. Either a weight enum, the string
+ representation of a weight enum, True for ImageNet weights, False
+ or None for random weights, or the path to a saved model state dict.
+ in_channels: Number of input channels to model.
+ num_classes: Number of prediction classes.
+ loss: One of 'ce', 'bce', 'jaccard', or 'focal'.
+ class_weights: Optional rescaling weight given to each
+ class and used with 'ce' loss.
+ lr: Learning rate for optimizer.
+ patience: Patience for learning rate scheduler.
+ freeze_backbone: Freeze the backbone network to linear probe
+ the classifier head.
+
+ .. versionchanged:: 0.4
+ *classification_model* was renamed to *model*.
+
+ .. versionadded:: 0.5
+ The *class_weights* and *freeze_backbone* parameters.
+
+ .. versionchanged:: 0.5
+ *learning_rate* and *learning_rate_schedule_patience* were renamed to
+ *lr* and *patience*.
+ """
+ self.weights = weights
+ super().__init__(ignore="weights")
-class ClassificationTask(LightningModule):
- """LightningModule for image classification.
+ def configure_losses(self) -> None:
+ """Initialize the loss criterion.
- Supports any available `Timm model
- `_
- as an architecture choice. To see a list of available
- models, you can do:
+ Raises:
+ ValueError: If *loss* is invalid.
+ """
+ loss: str = self.hparams["loss"]
+ if loss == "ce":
+ self.criterion: nn.Module = nn.CrossEntropyLoss(
+ weight=self.hparams["class_weights"]
+ )
+ elif loss == "bce":
+ self.criterion = nn.BCEWithLogitsLoss()
+ elif loss == "jaccard":
+ self.criterion = JaccardLoss(mode="multiclass")
+ elif loss == "focal":
+ self.criterion = FocalLoss(mode="multiclass", normalized=True)
+ else:
+ raise ValueError(f"Loss type '{loss}' is not valid.")
- .. code-block:: python
+ def configure_metrics(self) -> None:
+ """Initialize the performance metrics."""
+ metrics = MetricCollection(
+ {
+ "OverallAccuracy": MulticlassAccuracy(
+ num_classes=self.hparams["num_classes"], average="micro"
+ ),
+ "AverageAccuracy": MulticlassAccuracy(
+ num_classes=self.hparams["num_classes"], average="macro"
+ ),
+ "JaccardIndex": MulticlassJaccardIndex(
+ num_classes=self.hparams["num_classes"]
+ ),
+ "F1Score": MulticlassFBetaScore(
+ num_classes=self.hparams["num_classes"], beta=1.0, average="micro"
+ ),
+ }
+ )
+ self.train_metrics = metrics.clone(prefix="train_")
+ self.val_metrics = metrics.clone(prefix="val_")
+ self.test_metrics = metrics.clone(prefix="test_")
- import timm
- print(timm.list_models())
- """
+ def configure_models(self) -> None:
+ """Initialize the model."""
+ weights = self.weights
- def config_model(self) -> None:
- """Configures the model based on kwargs parameters passed to the constructor."""
# Create model
- weights = self.hyperparams["weights"]
self.model = timm.create_model(
- self.hyperparams["model"],
- num_classes=self.hyperparams["num_classes"],
- in_chans=self.hyperparams["in_channels"],
+ self.hparams["model"],
+ num_classes=self.hparams["num_classes"],
+ in_chans=self.hparams["in_channels"],
pretrained=weights is True,
)
@@ -65,139 +139,59 @@ def config_model(self) -> None:
self.model = utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
- if self.hyperparams.get("freeze_backbone", False):
+ if self.hparams["freeze_backbone"]:
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True
- def config_task(self) -> None:
- """Configures the task based on kwargs parameters passed to the constructor."""
- self.config_model()
-
- if self.hyperparams["loss"] == "ce":
- self.loss: nn.Module = nn.CrossEntropyLoss()
- elif self.hyperparams["loss"] == "jaccard":
- self.loss = JaccardLoss(mode="multiclass")
- elif self.hyperparams["loss"] == "focal":
- self.loss = FocalLoss(mode="multiclass", normalized=True)
- else:
- raise ValueError(f"Loss type '{self.hyperparams['loss']}' is not valid.")
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the LightningModule with a model and loss function.
-
- Keyword Args:
- model: Name of the classification model use
- loss: Name of the loss function, accepts 'ce', 'jaccard', or 'focal'
- weights: Either a weight enum, the string representation of a weight enum,
- True for ImageNet weights, False or None for random weights,
- or the path to a saved model state dict.
- num_classes: Number of prediction classes
- in_channels: Number of input channels to model
- learning_rate: Learning rate for optimizer
- learning_rate_schedule_patience: Patience for learning rate scheduler
- freeze_backbone: Freeze the backbone network to linear probe
- the classifier head
-
- .. versionchanged:: 0.4
- The *classification_model* parameter was renamed to *model*.
-
- .. versionadded:: 0.5
- The *freeze_backbone* parameter.
- """
- super().__init__()
-
- # Creates `self.hparams` from kwargs
- self.save_hyperparameters()
- self.hyperparams = cast(dict[str, Any], self.hparams)
-
- self.config_task()
-
- self.train_metrics = MetricCollection(
- {
- "OverallAccuracy": MulticlassAccuracy(
- num_classes=self.hyperparams["num_classes"], average="micro"
- ),
- "AverageAccuracy": MulticlassAccuracy(
- num_classes=self.hyperparams["num_classes"], average="macro"
- ),
- "JaccardIndex": MulticlassJaccardIndex(
- num_classes=self.hyperparams["num_classes"]
- ),
- "F1Score": MulticlassFBetaScore(
- num_classes=self.hyperparams["num_classes"],
- beta=1.0,
- average="micro",
- ),
- },
- prefix="train_",
- )
- self.val_metrics = self.train_metrics.clone(prefix="val_")
- self.test_metrics = self.train_metrics.clone(prefix="test_")
-
- def forward(self, *args: Any, **kwargs: Any) -> Any:
- """Forward pass of the model.
-
- Args:
- x: input image
-
- Returns:
- prediction
- """
- return self.model(*args, **kwargs)
-
- def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the training loss.
+ def training_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the training loss and additional metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- training loss
+ The loss tensor.
"""
- batch = args[0]
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
-
- loss = self.loss(y_hat, y)
-
- # by default, the train step logs every `log_every_n_steps` steps where
- # `log_every_n_steps` is a parameter to the `Trainer` object
- self.log("train_loss", loss, on_step=True, on_epoch=False)
+ loss: Tensor = self.criterion(y_hat, y)
+ self.log("train_loss", loss)
self.train_metrics(y_hat_hard, y)
+ self.log_dict(self.train_metrics)
- return cast(Tensor, loss)
+ return loss
- def on_train_epoch_end(self) -> None:
- """Logs epoch-level training metrics."""
- self.log_dict(self.train_metrics.compute())
- self.train_metrics.reset()
-
- def validation_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute validation loss and log example predictions.
+ def validation_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> None:
+ """Compute the validation loss and additional metrics.
Args:
- batch: the output of your DataLoader
- batch_idx: the index of this batch
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
- batch_idx = args[1]
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
-
- loss = self.loss(y_hat, y)
-
- self.log("val_loss", loss, on_step=False, on_epoch=True)
+ loss = self.criterion(y_hat, y)
+ self.log("val_loss", loss)
self.val_metrics(y_hat_hard, y)
+ self.log_dict(self.val_metrics)
if (
batch_idx < 10
and hasattr(self.trainer, "datamodule")
+ and hasattr(self.trainer.datamodule, "plot")
and self.logger
and hasattr(self.logger, "experiment")
and hasattr(self.logger.experiment, "add_figure")
@@ -209,176 +203,119 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig = datamodule.plot(sample)
- summary_writer = self.logger.experiment
- summary_writer.add_figure(
- f"image/{batch_idx}", fig, global_step=self.global_step
- )
- plt.close()
+ if fig:
+ summary_writer = self.logger.experiment
+ summary_writer.add_figure(
+ f"image/{batch_idx}", fig, global_step=self.global_step
+ )
+ plt.close()
except ValueError:
pass
- def on_validation_epoch_end(self) -> None:
- """Logs epoch level validation metrics."""
- self.log_dict(self.val_metrics.compute())
- self.val_metrics.reset()
-
- def test_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute test loss.
+ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+ """Compute the test loss and additional metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
-
- loss = self.loss(y_hat, y)
-
- # by default, the test and validation steps only log per *epoch*
- self.log("test_loss", loss, on_step=False, on_epoch=True)
+ loss = self.criterion(y_hat, y)
+ self.log("test_loss", loss)
self.test_metrics(y_hat_hard, y)
+ self.log_dict(self.test_metrics)
- def on_test_epoch_end(self) -> None:
- """Logs epoch level test metrics."""
- self.log_dict(self.test_metrics.compute())
- self.test_metrics.reset()
-
- def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the predictions.
+ def predict_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the predicted class probabilities.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- predicted softmax probabilities
+ Output predicted probabilities.
"""
- batch = args[0]
x = batch["image"]
y_hat: Tensor = self(x).softmax(dim=-1)
return y_hat
- def configure_optimizers(self) -> dict[str, Any]:
- """Initialize the optimizer and learning rate scheduler.
-
- Returns:
- learning rate dictionary
- """
- optimizer = torch.optim.AdamW(
- self.model.parameters(), lr=self.hyperparams["learning_rate"]
- )
- return {
- "optimizer": optimizer,
- "lr_scheduler": {
- "scheduler": ReduceLROnPlateau(
- optimizer,
- patience=self.hyperparams["learning_rate_schedule_patience"],
- ),
- "monitor": "val_loss",
- },
- }
-
class MultiLabelClassificationTask(ClassificationTask):
- """LightningModule for multi-label image classification."""
-
- def config_task(self) -> None:
- """Configures the task based on kwargs parameters passed to the constructor."""
- self.config_model()
-
- if self.hyperparams["loss"] == "bce":
- self.loss = nn.BCEWithLogitsLoss()
- else:
- raise ValueError(f"Loss type '{self.hyperparams['loss']}' is not valid.")
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the LightningModule with a model and loss function.
-
- Keyword Args:
- model: Name of the classification model use
- loss: Name of the loss function, currently only supports 'bce'
- weights: Either "random" or 'imagenet'
- num_classes: Number of prediction classes
- in_channels: Number of input channels to model
- learning_rate: Learning rate for optimizer
- learning_rate_schedule_patience: Patience for learning rate scheduler
- freeze_backbone: Freeze the backbone network to linear probe
- the classifier head
+ """Multi-label image classification."""
- .. versionchanged:: 0.4
- The *classification_model* parameter was renamed to *model*.
-
- .. versionadded:: 0.5
- The *freeze_backbone* parameter.
- """
- super().__init__(**kwargs)
-
- self.train_metrics = MetricCollection(
+ def configure_metrics(self) -> None:
+ """Initialize the performance metrics."""
+ metrics = MetricCollection(
{
"OverallAccuracy": MultilabelAccuracy(
- num_labels=self.hyperparams["num_classes"], average="micro"
+ num_labels=self.hparams["num_classes"], average="micro"
),
"AverageAccuracy": MultilabelAccuracy(
- num_labels=self.hyperparams["num_classes"], average="macro"
+ num_labels=self.hparams["num_classes"], average="macro"
),
"F1Score": MultilabelFBetaScore(
- num_labels=self.hyperparams["num_classes"],
- beta=1.0,
- average="micro",
+ num_labels=self.hparams["num_classes"], beta=1.0, average="micro"
),
- },
- prefix="train_",
+ }
)
- self.val_metrics = self.train_metrics.clone(prefix="val_")
- self.test_metrics = self.train_metrics.clone(prefix="test_")
+ self.train_metrics = metrics.clone(prefix="train_")
+ self.val_metrics = metrics.clone(prefix="val_")
+ self.test_metrics = metrics.clone(prefix="test_")
- def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the training loss.
+ def training_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the training loss and additional metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- training loss
+ The loss tensor.
"""
- batch = args[0]
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = torch.sigmoid(y_hat)
-
- loss = self.loss(y_hat, y.to(torch.float))
-
- # by default, the train step logs every `log_every_n_steps` steps where
- # `log_every_n_steps` is a parameter to the `Trainer` object
- self.log("train_loss", loss, on_step=True, on_epoch=False)
+ loss: Tensor = self.criterion(y_hat, y.to(torch.float))
+ self.log("train_loss", loss)
self.train_metrics(y_hat_hard, y)
+ self.log_dict(self.train_metrics)
- return cast(Tensor, loss)
+ return loss
- def validation_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute validation loss and log example predictions.
+ def validation_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> None:
+ """Compute the validation loss and additional metrics.
Args:
- batch: the output of your DataLoader
- batch_idx: the index of this batch
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
- batch_idx = args[1]
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = torch.sigmoid(y_hat)
-
- loss = self.loss(y_hat, y.to(torch.float))
-
- self.log("val_loss", loss, on_step=False, on_epoch=True)
+ loss = self.criterion(y_hat, y.to(torch.float))
+ self.log("val_loss", loss)
self.val_metrics(y_hat_hard, y)
+ self.log_dict(self.val_metrics)
if (
batch_idx < 10
and hasattr(self.trainer, "datamodule")
+ and hasattr(self.trainer.datamodule, "plot")
and self.logger
and hasattr(self.logger, "experiment")
and hasattr(self.logger.experiment, "add_figure")
@@ -390,40 +327,44 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig = datamodule.plot(sample)
- summary_writer = self.logger.experiment
- summary_writer.add_figure(
- f"image/{batch_idx}", fig, global_step=self.global_step
- )
+ if fig:
+ summary_writer = self.logger.experiment
+ summary_writer.add_figure(
+ f"image/{batch_idx}", fig, global_step=self.global_step
+ )
except ValueError:
pass
- def test_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute test loss.
+ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+ """Compute the test loss and additional metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
x = batch["image"]
y = batch["label"]
y_hat = self(x)
y_hat_hard = torch.sigmoid(y_hat)
-
- loss = self.loss(y_hat, y.to(torch.float))
-
- # by default, the test and validation steps only log per *epoch*
- self.log("test_loss", loss, on_step=False, on_epoch=True)
+ loss = self.criterion(y_hat, y.to(torch.float))
+ self.log("test_loss", loss)
self.test_metrics(y_hat_hard, y)
+ self.log_dict(self.test_metrics)
- def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the predictions.
+ def predict_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the predicted class probabilities.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
+
Returns:
- predicted sigmoid probabilities
+ Output predicted probabilities.
"""
- batch = args[0]
x = batch["image"]
y_hat = torch.sigmoid(self(x))
return y_hat
diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py
index ed58bf945a2..127c7bfc798 100644
--- a/torchgeo/trainers/detection.py
+++ b/torchgeo/trainers/detection.py
@@ -1,17 +1,15 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-"""Detection tasks."""
+"""Trainers for object detection."""
from functools import partial
-from typing import Any, cast
+from typing import Any, Optional
import matplotlib.pyplot as plt
import torch
import torchvision.models.detection
-from lightning.pytorch import LightningModule
from torch import Tensor
-from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models import resnet as R
@@ -21,6 +19,7 @@
from torchvision.ops import MultiScaleRoIAlign, feature_pyramid_network, misc
from ..datasets.utils import unbind_samples
+from .base import BaseTask
BACKBONE_LAT_DIM_MAP = {
"resnet18": 512,
@@ -47,48 +46,88 @@
}
-class ObjectDetectionTask(LightningModule):
- """LightningModule for object detection of images.
+class ObjectDetectionTask(BaseTask):
+ """Object detection.
- Currently, supports Faster R-CNN, FCOS, and RetinaNet models from
- `torchvision
- `_ with
- one of the following *backbone* arguments:
+ .. versionadded:: 0.4
+ """
- .. code-block:: python
+ monitor = "val_map"
+ mode = "max"
+
+ def __init__(
+ self,
+ model: str = "faster-rcnn",
+ backbone: str = "resnet50",
+ weights: Optional[bool] = None,
+ in_channels: int = 3,
+ num_classes: int = 1000,
+ trainable_layers: int = 3,
+ lr: float = 1e-3,
+ patience: int = 10,
+ freeze_backbone: bool = False,
+ ) -> None:
+ """Initialize a new ObjectDetectionTask instance.
- ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
- 'resnext50_32x4d','resnext101_32x8d', 'wide_resnet50_2',
- 'wide_resnet101_2']
+ Args:
+ model: Name of the `torchvision
+ `__
+ model to use. One of 'faster-rcnn', 'fcos', or 'retinanet'.
+ backbone: Name of the `torchvision
+ `__
+ backbone to use. One of 'resnet18', 'resnet34', 'resnet50',
+ 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', or 'wide_resnet101_2'.
+ weights: Initial model weights. True for ImageNet weights, False or None
+ for random weights.
+ in_channels: Number of input channels to model.
+ num_classes: Number of prediction classes.
+ trainable_layers: Number of trainable layers.
+ lr: Learning rate for optimizer.
+ patience: Patience for learning rate scheduler.
+ freeze_backbone: Freeze the backbone network to fine-tune the detection
+ head.
- .. versionadded:: 0.4
- """
+ .. versionchanged:: 0.4
+ *detection_model* was renamed to *model*.
+
+ .. versionadded:: 0.5
+ The *freeze_backbone* parameter.
- def config_task(self) -> None:
- """Configures the task based on kwargs parameters passed to the constructor."""
- backbone_pretrained = self.hyperparams.get("pretrained", True)
+ .. versionchanged:: 0.5
+ *pretrained*, *learning_rate*, and *learning_rate_schedule_patience* were
+ renamed to *weights*, *lr*, and *patience*.
+ """
+ super().__init__()
- if self.hyperparams["backbone"] in BACKBONE_LAT_DIM_MAP:
+ def configure_models(self) -> None:
+ """Initialize the model.
+
+ Raises:
+ ValueError: If *model* or *backbone* are invalid.
+ """
+ backbone: str = self.hparams["backbone"]
+ model: str = self.hparams["model"]
+ weights: Optional[bool] = self.hparams["weights"]
+ num_classes: int = self.hparams["num_classes"]
+ freeze_backbone: bool = self.hparams["freeze_backbone"]
+
+ if backbone in BACKBONE_LAT_DIM_MAP:
kwargs = {
- "backbone_name": self.hyperparams["backbone"],
- "trainable_layers": self.hyperparams.get("trainable_layers", 3),
+ "backbone_name": backbone,
+ "trainable_layers": self.hparams["trainable_layers"],
}
- if backbone_pretrained:
- kwargs["weights"] = BACKBONE_WEIGHT_MAP[self.hyperparams["backbone"]]
+ if weights:
+ kwargs["weights"] = BACKBONE_WEIGHT_MAP[backbone]
else:
kwargs["weights"] = None
- latent_dim = BACKBONE_LAT_DIM_MAP[self.hyperparams["backbone"]]
+ latent_dim = BACKBONE_LAT_DIM_MAP[backbone]
else:
- raise ValueError(
- f"Backbone type '{self.hyperparams['backbone']}' is not valid."
- )
+ raise ValueError(f"Backbone type '{backbone}' is not valid.")
- num_classes = self.hyperparams["num_classes"]
-
- if self.hyperparams["model"] == "faster-rcnn":
- backbone = resnet_fpn_backbone(**kwargs)
+ if model == "faster-rcnn":
+ model_backbone = resnet_fpn_backbone(**kwargs)
anchor_generator = AnchorGenerator(
sizes=((32), (64), (128), (256), (512)), aspect_ratios=((0.5, 1.0, 2.0))
)
@@ -97,40 +136,40 @@ def config_task(self) -> None:
featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
)
- if self.hyperparams.get("freeze_backbone", False):
- for param in backbone.parameters():
+ if freeze_backbone:
+ for param in model_backbone.parameters():
param.requires_grad = False
self.model = torchvision.models.detection.FasterRCNN(
- backbone,
+ model_backbone,
num_classes,
rpn_anchor_generator=anchor_generator,
box_roi_pool=roi_pooler,
)
- elif self.hyperparams["model"] == "fcos":
+ elif model == "fcos":
kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7(256, 256)
kwargs["norm_layer"] = (
- misc.FrozenBatchNorm2d if kwargs["weights"] else torch.nn.BatchNorm2d
+ misc.FrozenBatchNorm2d if weights else torch.nn.BatchNorm2d
)
- backbone = resnet_fpn_backbone(**kwargs)
+ model_backbone = resnet_fpn_backbone(**kwargs)
anchor_generator = AnchorGenerator(
sizes=((8,), (16,), (32,), (64,), (128,), (256,)),
aspect_ratios=((1.0,), (1.0,), (1.0,), (1.0,), (1.0,), (1.0,)),
)
- if self.hyperparams.get("freeze_backbone", False):
- for param in backbone.parameters():
+ if freeze_backbone:
+ for param in model_backbone.parameters():
param.requires_grad = False
self.model = torchvision.models.detection.FCOS(
- backbone, num_classes, anchor_generator=anchor_generator
+ model_backbone, num_classes, anchor_generator=anchor_generator
)
- elif self.hyperparams["model"] == "retinanet":
+ elif model == "retinanet":
kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7(
latent_dim, 256
)
- backbone = resnet_fpn_backbone(**kwargs)
+ model_backbone = resnet_fpn_backbone(**kwargs)
anchor_sizes = (
(16, 20, 25),
@@ -144,75 +183,44 @@ def config_task(self) -> None:
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
head = RetinaNetHead(
- backbone.out_channels,
+ model_backbone.out_channels,
anchor_generator.num_anchors_per_location()[0],
num_classes,
norm_layer=partial(torch.nn.GroupNorm, 32),
)
- if self.hyperparams.get("freeze_backbone", False):
- for param in backbone.parameters():
+ if freeze_backbone:
+ for param in model_backbone.parameters():
param.requires_grad = False
self.model = torchvision.models.detection.RetinaNet(
- backbone, num_classes, anchor_generator=anchor_generator, head=head
+ model_backbone,
+ num_classes,
+ anchor_generator=anchor_generator,
+ head=head,
)
else:
- raise ValueError(f"Model type '{self.hyperparams['model']}' is not valid.")
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the LightningModule with a model and loss function.
-
- Keyword Args:
- model: Name of the detection model type to use
- backbone: Name of the model backbone to use
- in_channels: Number of channels in input image
- num_classes: Number of semantic classes to predict
- learning_rate: Learning rate for optimizer
- learning_rate_schedule_patience: Patience for learning rate scheduler
- freeze_backbone: Freeze the backbone network to fine-tune the detection head
-
- Raises:
- ValueError: if kwargs arguments are invalid
-
- .. versionchanged:: 0.4
- The *detection_model* parameter was renamed to *model*.
-
- .. versionadded:: 0.5
- The *freeze_backbone* parameter.
- """
- super().__init__()
- # Creates `self.hparams` from kwargs
- self.save_hyperparameters()
- self.hyperparams = cast(dict[str, Any], self.hparams)
-
- self.config_task()
+ raise ValueError(f"Model type '{model}' is not valid.")
+ def configure_metrics(self) -> None:
+ """Initialize the performance metrics."""
metrics = MetricCollection([MeanAveragePrecision()])
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")
- def forward(self, *args: Any, **kwargs: Any) -> Any:
- """Forward pass of the model.
-
- Args:
- x: tensor of data to run through the model
-
- Returns:
- output from the model
- """
- return self.model(*args, **kwargs)
-
- def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the training loss.
+ def training_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the training loss.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- training loss
+ The loss tensor.
"""
- batch = args[0]
x = batch["image"]
batch_size = x.shape[0]
y = [
@@ -220,21 +228,20 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
for i in range(batch_size)
]
loss_dict = self(x, y)
- train_loss = sum(loss_dict.values())
-
+ train_loss: Tensor = sum(loss_dict.values())
self.log_dict(loss_dict)
+ return train_loss
- return cast(Tensor, train_loss)
-
- def validation_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute validation loss and log example predictions.
+ def validation_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> None:
+ """Compute the validation metrics.
Args:
- batch: the output of your DataLoader
- batch_idx: the index of this batch
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
- batch_idx = args[1]
x = batch["image"]
batch_size = x.shape[0]
y = [
@@ -242,12 +249,17 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
for i in range(batch_size)
]
y_hat = self(x)
+ metrics = self.val_metrics(y_hat, y)
- self.val_metrics.update(y_hat, y)
+ # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
+ metrics.pop("val_classes", None)
+
+ self.log_dict(metrics)
if (
batch_idx < 10
and hasattr(self.trainer, "datamodule")
+ and hasattr(self.trainer.datamodule, "plot")
and self.logger
and hasattr(self.logger, "experiment")
and hasattr(self.logger.experiment, "add_figure")
@@ -264,31 +276,23 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
sample["image"] *= 255
sample["image"] = sample["image"].to(torch.uint8)
fig = datamodule.plot(sample)
- summary_writer = self.logger.experiment
- summary_writer.add_figure(
- f"image/{batch_idx}", fig, global_step=self.global_step
- )
- plt.close()
+ if fig:
+ summary_writer = self.logger.experiment
+ summary_writer.add_figure(
+ f"image/{batch_idx}", fig, global_step=self.global_step
+ )
+ plt.close()
except ValueError:
pass
- def on_validation_epoch_end(self) -> None:
- """Logs epoch level validation metrics."""
- metrics = self.val_metrics.compute()
-
- # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
- metrics.pop("val_classes", None)
-
- self.log_dict(metrics)
- self.val_metrics.reset()
-
- def test_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute test MAP.
+ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+ """Compute the test metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
x = batch["image"]
batch_size = x.shape[0]
y = [
@@ -296,50 +300,26 @@ def test_step(self, *args: Any, **kwargs: Any) -> None:
for i in range(batch_size)
]
y_hat = self(x)
-
- self.test_metrics.update(y_hat, y)
-
- def on_test_epoch_end(self) -> None:
- """Logs epoch level test metrics."""
- metrics = self.test_metrics.compute()
+ metrics = self.test_metrics(y_hat, y)
# https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
metrics.pop("test_classes", None)
self.log_dict(metrics)
- self.test_metrics.reset()
- def predict_step(self, *args: Any, **kwargs: Any) -> list[dict[str, Tensor]]:
- """Compute and return the predictions.
+ def predict_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> list[dict[str, Tensor]]:
+ """Compute the predicted bounding boxes.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- list of predicted boxes, labels and scores
+ Output predicted probabilities.
"""
- batch = args[0]
x = batch["image"]
y_hat: list[dict[str, Tensor]] = self(x)
return y_hat
-
- def configure_optimizers(self) -> dict[str, Any]:
- """Initialize the optimizer and learning rate scheduler.
-
- Returns:
- learning rate dictionary
- """
- optimizer = torch.optim.Adam(
- self.model.parameters(), lr=self.hyperparams["learning_rate"]
- )
- return {
- "optimizer": optimizer,
- "lr_scheduler": {
- "scheduler": ReduceLROnPlateau(
- optimizer,
- mode="max",
- patience=self.hyperparams["learning_rate_schedule_patience"],
- ),
- "monitor": "val_map",
- },
- }
diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py
index f5425390bdd..d2621a8da74 100644
--- a/torchgeo/trainers/moco.py
+++ b/torchgeo/trainers/moco.py
@@ -6,9 +6,10 @@
import os
import warnings
from collections.abc import Sequence
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Union
import kornia.augmentation as K
+import lightning
import timm
import torch
import torch.nn as nn
@@ -17,7 +18,6 @@
from lightly.models.modules import MoCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.utils.scheduler import cosine_schedule
-from lightning import LightningModule
from torch import Tensor
from torch.optim import SGD, AdamW, Optimizer
from torch.optim.lr_scheduler import (
@@ -32,6 +32,7 @@
from ..models import get_weight
from . import utils
+from .base import BaseTask
try:
from torch.optim.lr_scheduler import LRScheduler
@@ -118,7 +119,7 @@ def moco_augmentations(
return aug1, aug2
-class MoCoTask(LightningModule):
+class MoCoTask(BaseTask):
"""MoCo: Momentum Contrast.
Reference implementations:
@@ -135,6 +136,8 @@ class MoCoTask(LightningModule):
.. versionadded:: 0.5
"""
+ monitor = "train_loss"
+
def __init__(
self,
model: str = "resnet50",
@@ -160,7 +163,8 @@ def __init__(
"""Initialize a new MoCoTask instance.
Args:
- model: Name of the timm model to use.
+ model: Name of the `timm
+ `__ model to use.
weights: Initial model weights. Either a weight enum, the string
representation of a weight enum, True for ImageNet weights, False
or None for random weights, or the path to a saved model state dict.
@@ -198,8 +202,6 @@ def __init__(
Warns:
UserWarning: If hyperparameters do not match MoCo version requested.
"""
- super().__init__()
-
# Validate hyperparameters
assert version in range(1, 4)
if version == 1:
@@ -216,13 +218,32 @@ def __init__(
if memory_bank_size > 0:
warnings.warn("MoCo v3 does not use a memory bank")
- self.save_hyperparameters(ignore=["augmentation1", "augmentation2"])
+ self.weights = weights
+ super().__init__(ignore=["weights", "augmentation1", "augmentation2"])
grayscale_weights = grayscale_weights or torch.ones(in_channels)
aug1, aug2 = moco_augmentations(version, size, grayscale_weights)
self.augmentation1 = augmentation1 or aug1
self.augmentation2 = augmentation2 or aug2
+ def configure_losses(self) -> None:
+ """Initialize the loss criterion."""
+ self.criterion = NTXentLoss(
+ self.hparams["temperature"],
+ self.hparams["memory_bank_size"],
+ self.hparams["gather_distributed"],
+ )
+
+ def configure_models(self) -> None:
+ """Initialize the model."""
+ model: str = self.hparams["model"]
+ weights = self.weights
+ in_channels: int = self.hparams["in_channels"]
+ version: int = self.hparams["version"]
+ layers: int = self.hparams["layers"]
+ hidden_dim: int = self.hparams["hidden_dim"]
+ output_dim: int = self.hparams["output_dim"]
+
# Create backbone
self.backbone = timm.create_model(
model, in_chans=in_channels, num_classes=0, pretrained=weights is True
@@ -258,12 +279,54 @@ def __init__(
output_dim, hidden_dim, output_dim, num_layers=2, batch_norm=batch_norm
)
- # Define loss function
- self.criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed)
-
# Initialize moving average of output
self.avg_output_std = 0.0
+ def configure_optimizers(
+ self,
+ ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
+ """Initialize the optimizer and learning rate scheduler.
+
+ Returns:
+ Optimizer and learning rate scheduler.
+ """
+ if self.hparams["version"] == 3:
+ optimizer: Optimizer = AdamW(
+ params=self.parameters(),
+ lr=self.hparams["lr"],
+ weight_decay=self.hparams["weight_decay"],
+ )
+ warmup_epochs = 40
+ max_epochs = 200
+ if self.trainer and self.trainer.max_epochs:
+ max_epochs = self.trainer.max_epochs
+ scheduler: LRScheduler = SequentialLR(
+ optimizer,
+ schedulers=[
+ LinearLR(
+ optimizer,
+ start_factor=1 / warmup_epochs,
+ total_iters=warmup_epochs,
+ ),
+ CosineAnnealingLR(optimizer, T_max=max_epochs),
+ ],
+ milestones=[warmup_epochs],
+ )
+ else:
+ optimizer = SGD(
+ params=self.parameters(),
+ lr=self.hparams["lr"],
+ momentum=self.hparams["momentum"],
+ weight_decay=self.hparams["weight_decay"],
+ )
+ scheduler = MultiStepLR(
+ optimizer=optimizer, milestones=self.hparams["schedule"]
+ )
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
+ }
+
def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
"""Forward pass of the model.
@@ -271,15 +334,15 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
x: Mini-batch of images.
Returns:
- Output from the model and backbone
+ Output of the model and backbone
"""
- h = self.backbone(x)
+ h: Tensor = self.backbone(x)
q = h
if self.hparams["version"] > 1:
q = self.projection_head(q)
if self.hparams["version"] == 3:
q = self.prediction_head(q)
- return cast(Tensor, q), cast(Tensor, h)
+ return q, h
def forward_momentum(self, x: Tensor) -> Tensor:
"""Forward pass of the momentum model.
@@ -290,10 +353,10 @@ def forward_momentum(self, x: Tensor) -> Tensor:
Returns:
Output from the momentum model.
"""
- k = self.backbone_momentum(x)
+ k: Tensor = self.backbone_momentum(x)
if self.hparams["version"] > 1:
k = self.projection_head_momentum(k)
- return cast(Tensor, k)
+ return k
def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
@@ -330,7 +393,7 @@ def training_step(
with torch.no_grad():
update_momentum(self.backbone, self.backbone_momentum, m)
k = self.forward_momentum(x2)
- loss = self.criterion(q, k)
+ loss: Tensor = self.criterion(q, k)
elif self.hparams["version"] == 2:
q, h1 = self.forward(x1)
with torch.no_grad():
@@ -360,7 +423,7 @@ def training_step(
self.log("train_ssl_std", self.avg_output_std)
self.log("train_loss", loss)
- return cast(Tensor, loss)
+ return loss
def validation_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
@@ -372,43 +435,3 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""No-op, does nothing."""
-
- def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
- """Initialize the optimizer and learning rate scheduler.
-
- Returns:
- Optimizer and learning rate scheduler.
- """
- if self.hparams["version"] == 3:
- optimizer: Optimizer = AdamW(
- params=self.parameters(),
- lr=self.hparams["lr"],
- weight_decay=self.hparams["weight_decay"],
- )
- warmup_epochs = 40
- max_epochs = 200
- if self.trainer and self.trainer.max_epochs:
- max_epochs = self.trainer.max_epochs
- lr_scheduler: LRScheduler = SequentialLR(
- optimizer,
- schedulers=[
- LinearLR(
- optimizer,
- start_factor=1 / warmup_epochs,
- total_iters=warmup_epochs,
- ),
- CosineAnnealingLR(optimizer, T_max=max_epochs),
- ],
- milestones=[warmup_epochs],
- )
- else:
- optimizer = SGD(
- params=self.parameters(),
- lr=self.hparams["lr"],
- momentum=self.hparams["momentum"],
- weight_decay=self.hparams["weight_decay"],
- )
- lr_scheduler = MultiStepLR(
- optimizer=optimizer, milestones=self.hparams["schedule"]
- )
- return [optimizer], [lr_scheduler]
diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py
index ce62a64947a..b540ceecddc 100644
--- a/torchgeo/trainers/regression.py
+++ b/torchgeo/trainers/regression.py
@@ -1,51 +1,121 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-"""Regression tasks."""
+"""Trainers for regression."""
import os
-from typing import Any, cast
+from typing import Any, Optional, Union
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
-from lightning.pytorch import LightningModule
from torch import Tensor
-from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchvision.models._api import WeightsEnum
from ..datasets import unbind_samples
from ..models import FCN, get_weight
from . import utils
+from .base import BaseTask
+
+
+class RegressionTask(BaseTask):
+ """Regression."""
+
+ target_key = "label"
+
+ def __init__(
+ self,
+ model: str = "resnet50",
+ backbone: str = "resnet50",
+ weights: Optional[Union[WeightsEnum, str, bool]] = None,
+ in_channels: int = 3,
+ num_outputs: int = 1,
+ num_filters: int = 3,
+ loss: str = "mse",
+ lr: float = 1e-3,
+ patience: int = 10,
+ freeze_backbone: bool = False,
+ freeze_decoder: bool = False,
+ ) -> None:
+ """Initialize a new RegressionTask instance.
+ Args:
+ model: Name of the
+ `timm `__ or
+ `smp `__ model to use.
+ backbone: Name of the
+ `timm `__ or
+ `smp `__ backbone
+ to use. Only applicable to PixelwiseRegressionTask.
+ weights: Initial model weights. Either a weight enum, the string
+ representation of a weight enum, True for ImageNet weights, False
+ or None for random weights, or the path to a saved model state dict.
+ in_channels: Number of input channels to model.
+ num_outputs: Number of prediction outputs.
+ num_filters: Number of filters. Only applicable when model='fcn'.
+ loss: One of 'mse' or 'mae'.
+ lr: Learning rate for optimizer.
+ patience: Patience for learning rate scheduler.
+ freeze_backbone: Freeze the backbone network to linear probe
+ the regression head. Does not support FCN models.
+ freeze_decoder: Freeze the decoder network to linear probe
+ the regression head. Does not support FCN models.
+ Only applicable to PixelwiseRegressionTask.
-class RegressionTask(LightningModule):
- """LightningModule for training models on regression datasets.
+ .. versionchanged:: 0.4
+ Change regression model support from torchvision.models to timm
- Supports any available `Timm model
- `_
- as an architecture choice. To see a list of available
- models, you can do:
+ .. versionadded:: 0.5
+ The *freeze_backbone* and *freeze_decoder* parameters.
- .. code-block:: python
+ .. versionchanged:: 0.5
+ *learning_rate* and *learning_rate_schedule_patience* were renamed to
+ *lr* and *patience*.
+ """
+ self.weights = weights
+ super().__init__(ignore="weights")
- import timm
- print(timm.list_models())
- """
+ def configure_losses(self) -> None:
+ """Initialize the loss criterion.
+
+ Raises:
+ ValueError: If *loss* is invalid.
+ """
+ loss: str = self.hparams["loss"]
+ if loss == "mse":
+ self.criterion: nn.Module = nn.MSELoss()
+ elif loss == "mae":
+ self.criterion = nn.L1Loss()
+ else:
+ raise ValueError(
+ f"Loss type '{loss}' is not valid. "
+ "Currently, supports 'mse' or 'mae' loss."
+ )
- target_key: str = "label"
+ def configure_metrics(self) -> None:
+ """Initialize the performance metrics."""
+ metrics = MetricCollection(
+ {
+ "RMSE": MeanSquaredError(squared=False),
+ "MSE": MeanSquaredError(squared=True),
+ "MAE": MeanAbsoluteError(),
+ }
+ )
+ self.train_metrics = metrics.clone(prefix="train_")
+ self.val_metrics = metrics.clone(prefix="val_")
+ self.test_metrics = metrics.clone(prefix="test_")
- def config_model(self) -> None:
- """Configures the model based on kwargs parameters."""
+ def configure_models(self) -> None:
+ """Initialize the model."""
# Create model
- weights = self.hyperparams["weights"]
+ weights = self.weights
self.model = timm.create_model(
- self.hyperparams["model"],
- num_classes=self.hyperparams["num_outputs"],
- in_chans=self.hyperparams["in_channels"],
+ self.hparams["model"],
+ num_classes=self.hparams["num_outputs"],
+ in_chans=self.hparams["in_channels"],
pretrained=weights is True,
)
@@ -60,131 +130,63 @@ def config_model(self) -> None:
self.model = utils.load_state_dict(self.model, state_dict)
# Freeze backbone and unfreeze classifier head
- if self.hyperparams.get("freeze_backbone", False):
+ if self.hparams["freeze_backbone"]:
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True
- def config_task(self) -> None:
- """Configures the task based on kwargs parameters."""
- self.config_model()
-
- self.loss: nn.Module
- if self.hyperparams["loss"] == "mse":
- self.loss = nn.MSELoss()
- elif self.hyperparams["loss"] == "mae":
- self.loss = nn.L1Loss()
- else:
- raise ValueError(
- f"Loss type '{self.hyperparams['loss']}' is not valid. "
- f"Currently, supports 'mse' or 'mae' loss."
- )
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize a new LightningModule for training simple regression models.
-
- Keyword Args:
- model: Name of the timm model to use
- weights: Either a weight enum, the string representation of a weight enum,
- True for ImageNet weights, False or None for random weights,
- or the path to a saved model state dict.
- num_outputs: Number of prediction outputs
- in_channels: Number of input channels to model
- learning_rate: Learning rate for optimizer
- learning_rate_schedule_patience: Patience for learning rate scheduler
- freeze_backbone: Freeze the backbone network to linear probe
- the regression head. Does not support FCN models.
- freeze_decoder: Freeze the decoder network to linear probe
- the regression head. Does not support FCN models.
- Only applicable to PixelwiseRegressionTask.
-
- .. versionchanged:: 0.4
- Change regression model support from torchvision.models to timm
-
- .. versionadded:: 0.5
- The *freeze_backbone* and *freeze_decoder* parameters.
- """
- super().__init__()
-
- # Creates `self.hparams` from kwargs
- self.save_hyperparameters()
- self.hyperparams = cast(dict[str, Any], self.hparams)
- self.config_task()
-
- self.train_metrics = MetricCollection(
- {
- "RMSE": MeanSquaredError(squared=False),
- "MSE": MeanSquaredError(squared=True),
- "MAE": MeanAbsoluteError(),
- },
- prefix="train_",
- )
- self.val_metrics = self.train_metrics.clone(prefix="val_")
- self.test_metrics = self.train_metrics.clone(prefix="test_")
-
- def forward(self, *args: Any, **kwargs: Any) -> Any:
- """Forward pass of the model.
-
- Args:
- x: tensor of data to run through the model
-
- Returns:
- output from the model
- """
- return self.model(*args, **kwargs)
-
- def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the training loss.
+ def training_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the training loss and additional metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- training loss
+ The loss tensor.
"""
- batch = args[0]
x = batch["image"]
- y = batch[self.target_key]
+ # TODO: remove .to(...) once we have a real pixelwise regression dataset
+ y = batch[self.target_key].to(torch.float)
y_hat = self(x)
-
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)
-
- loss: Tensor = self.loss(y_hat, y.to(torch.float))
- self.log("train_loss", loss) # logging to TensorBoard
- self.train_metrics(y_hat, y.to(torch.float))
+ loss: Tensor = self.criterion(y_hat, y)
+ self.log("train_loss", loss)
+ self.train_metrics(y_hat, y)
+ self.log_dict(self.train_metrics)
return loss
- def on_train_epoch_end(self) -> None:
- """Logs epoch-level training metrics."""
- self.log_dict(self.train_metrics.compute())
- self.train_metrics.reset()
-
- def validation_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute validation loss and log example predictions.
+ def validation_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> None:
+ """Compute the validation loss and additional metrics.
Args:
- batch: the output of your DataLoader
- batch_idx: the index of this batch
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
- batch_idx = args[1]
x = batch["image"]
- y = batch[self.target_key]
+ # TODO: remove .to(...) once we have a real pixelwise regression dataset
+ y = batch[self.target_key].to(torch.float)
y_hat = self(x)
-
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)
-
- loss = self.loss(y_hat, y.to(torch.float))
+ loss = self.criterion(y_hat, y)
self.log("val_loss", loss)
- self.val_metrics(y_hat, y.to(torch.float))
+ self.val_metrics(y_hat, y)
+ self.log_dict(self.val_metrics)
if (
batch_idx < 10
and hasattr(self.trainer, "datamodule")
+ and hasattr(self.trainer.datamodule, "plot")
and self.logger
and hasattr(self.logger, "experiment")
and hasattr(self.logger.experiment, "add_figure")
@@ -199,120 +201,91 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig = datamodule.plot(sample)
- summary_writer = self.logger.experiment
- summary_writer.add_figure(
- f"image/{batch_idx}", fig, global_step=self.global_step
- )
- plt.close()
+ if fig:
+ summary_writer = self.logger.experiment
+ summary_writer.add_figure(
+ f"image/{batch_idx}", fig, global_step=self.global_step
+ )
+ plt.close()
except ValueError:
pass
- def on_validation_epoch_end(self) -> None:
- """Logs epoch level validation metrics."""
- self.log_dict(self.val_metrics.compute())
- self.val_metrics.reset()
-
- def test_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute test loss.
+ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+ """Compute the test loss and additional metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
x = batch["image"]
- y = batch[self.target_key]
+ # TODO: remove .to(...) once we have a real pixelwise regression dataset
+ y = batch[self.target_key].to(torch.float)
y_hat = self(x)
-
if y_hat.ndim != y.ndim:
y = y.unsqueeze(dim=1)
-
- loss = self.loss(y_hat, y.to(torch.float))
+ loss = self.criterion(y_hat, y)
self.log("test_loss", loss)
- self.test_metrics(y_hat, y.to(torch.float))
+ self.test_metrics(y_hat, y)
+ self.log_dict(self.test_metrics)
- def on_test_epoch_end(self) -> None:
- """Logs epoch level test metrics."""
- self.log_dict(self.test_metrics.compute())
- self.test_metrics.reset()
-
- def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the predictions.
+ def predict_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the predicted regression values.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
+
Returns:
- predicted values
+ Output predicted probabilities.
"""
- batch = args[0]
x = batch["image"]
y_hat: Tensor = self(x)
return y_hat
- def configure_optimizers(self) -> dict[str, Any]:
- """Initialize the optimizer and learning rate scheduler.
-
- Returns:
- learning rate dictionary
- """
- optimizer = torch.optim.AdamW(
- self.model.parameters(), lr=self.hyperparams["learning_rate"]
- )
- return {
- "optimizer": optimizer,
- "lr_scheduler": {
- "scheduler": ReduceLROnPlateau(
- optimizer,
- patience=self.hyperparams["learning_rate_schedule_patience"],
- ),
- "monitor": "val_loss",
- },
- }
-
class PixelwiseRegressionTask(RegressionTask):
"""LightningModule for pixelwise regression of images.
- Supports `Segmentation Models Pytorch
- `_
- as an architecture choice in combination with any of these
- `TIMM backbones `_.
-
.. versionadded:: 0.5
"""
- target_key: str = "mask"
+ target_key = "mask"
- def config_model(self) -> None:
- """Configures the model based on kwargs parameters."""
- weights = self.hyperparams["weights"]
+ def configure_models(self) -> None:
+ """Initialize the model."""
+ weights = self.weights
- if self.hyperparams["model"] == "unet":
+ if self.hparams["model"] == "unet":
self.model = smp.Unet(
- encoder_name=self.hyperparams["backbone"],
+ encoder_name=self.hparams["backbone"],
encoder_weights="imagenet" if weights is True else None,
- in_channels=self.hyperparams["in_channels"],
+ in_channels=self.hparams["in_channels"],
classes=1,
)
- elif self.hyperparams["model"] == "deeplabv3+":
+ elif self.hparams["model"] == "deeplabv3+":
self.model = smp.DeepLabV3Plus(
- encoder_name=self.hyperparams["backbone"],
+ encoder_name=self.hparams["backbone"],
encoder_weights="imagenet" if weights is True else None,
- in_channels=self.hyperparams["in_channels"],
+ in_channels=self.hparams["in_channels"],
classes=1,
)
- elif self.hyperparams["model"] == "fcn":
+ elif self.hparams["model"] == "fcn":
self.model = FCN(
- in_channels=self.hyperparams["in_channels"],
+ in_channels=self.hparams["in_channels"],
classes=1,
- num_filters=self.hyperparams["num_filters"],
+ num_filters=self.hparams["num_filters"],
)
else:
raise ValueError(
- f"Model type '{self.hyperparams['model']}' is not valid. "
- f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+ f"Model type '{self.hparams['model']}' is not valid. "
+ "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)
- if self.hyperparams["model"] != "fcn":
+ if self.hparams["model"] != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
@@ -323,15 +296,17 @@ def config_model(self) -> None:
self.model.encoder.load_state_dict(state_dict)
# Freeze backbone
- if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
- "model"
- ] in ["unet", "deeplabv3+"]:
+ if self.hparams.get("freeze_backbone", False) and self.hparams["model"] in [
+ "unet",
+ "deeplabv3+",
+ ]:
for param in self.model.encoder.parameters():
param.requires_grad = False
# Freeze decoder
- if self.hyperparams.get("freeze_decoder", False) and self.hyperparams[
- "model"
- ] in ["unet", "deeplabv3+"]:
+ if self.hparams.get("freeze_decoder", False) and self.hparams["model"] in [
+ "unet",
+ "deeplabv3+",
+ ]:
for param in self.model.decoder.parameters():
param.requires_grad = False
diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py
index e0497de1b9c..9ee51d8262c 100644
--- a/torchgeo/trainers/segmentation.py
+++ b/torchgeo/trainers/segmentation.py
@@ -1,19 +1,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-"""Segmentation tasks."""
+"""Trainers for semantic segmentation."""
import os
import warnings
-from typing import Any, cast
+from typing import Any, Optional, Union
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
-import torch
import torch.nn as nn
-from lightning.pytorch import LightningModule
from torch import Tensor
-from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum
@@ -21,74 +18,170 @@
from ..datasets.utils import unbind_samples
from ..models import FCN, get_weight
from . import utils
+from .base import BaseTask
+
+
+class SemanticSegmentationTask(BaseTask):
+ """Semantic Segmentation."""
+
+ def __init__(
+ self,
+ model: str = "unet",
+ backbone: str = "resnet50",
+ weights: Optional[Union[WeightsEnum, str, bool]] = None,
+ in_channels: int = 3,
+ num_classes: int = 1000,
+ num_filters: int = 3,
+ loss: str = "ce",
+ class_weights: Optional[Tensor] = None,
+ ignore_index: Optional[int] = None,
+ lr: float = 1e-3,
+ patience: int = 10,
+ freeze_backbone: bool = False,
+ freeze_decoder: bool = False,
+ ) -> None:
+ """Inititalize a new SemanticSegmentationTask instance.
+ Args:
+ model: Name of the
+ `smp `__ model to use.
+ backbone: Name of the `timm
+ `__ or `smp
+ `__ backbone to use.
+ weights: Initial model weights. Either a weight enum, the string
+ representation of a weight enum, True for ImageNet weights, False or
+ None for random weights, or the path to a saved model state dict. FCN
+ model does not support pretrained weights. Pretrained ViT weight enums
+ are not supported yet.
+ in_channels: Number of input channels to model.
+ num_classes: Number of prediction classes.
+ num_filters: Number of filters. Only applicable when model='fcn'.
+ loss: Name of the loss function, currently supports
+ 'ce', 'jaccard' or 'focal' loss.
+ class_weights: Optional rescaling weight given to each
+ class and used with 'ce' loss.
+ ignore_index: Optional integer class index to ignore in the loss and
+ metrics.
+ lr: Learning rate for optimizer.
+ patience: Patience for learning rate scheduler.
+ freeze_backbone: Freeze the backbone network to fine-tune the
+ decoder and segmentation head.
+ freeze_decoder: Freeze the decoder network to linear probe
+ the segmentation head.
-class SemanticSegmentationTask(LightningModule):
- """LightningModule for semantic segmentation of images.
+ Warns:
+ UserWarning: When loss='jaccard' and ignore_index is specified.
- Supports `Segmentation Models Pytorch
- `_
- as an architecture choice in combination with any of these
- `TIMM backbones `_.
- """
+ .. versionchanged:: 0.3
+ *ignore_zeros* was renamed to *ignore_index*.
- def config_task(self) -> None:
- """Configures the task based on kwargs parameters passed to the constructor."""
- weights = self.hyperparams["weights"]
+ .. versionchanged:: 0.4
+ *segmentation_model*, *encoder_name*, and *encoder_weights*
+ were renamed to *model*, *backbone*, and *weights*.
- if self.hyperparams["model"] == "unet":
- self.model = smp.Unet(
- encoder_name=self.hyperparams["backbone"],
- encoder_weights="imagenet" if weights is True else None,
- in_channels=self.hyperparams["in_channels"],
- classes=self.hyperparams["num_classes"],
+ .. versionadded: 0.5
+ The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters.
+
+ .. versionchanged:: 0.5
+ The *weights* parameter now supports WeightEnums and checkpoint paths.
+ *learning_rate* and *learning_rate_schedule_patience* were renamed to
+ *lr* and *patience*.
+ """
+ if ignore_index is not None and loss == "jaccard":
+ warnings.warn(
+ "ignore_index has no effect on training when loss='jaccard'",
+ UserWarning,
)
- elif self.hyperparams["model"] == "deeplabv3+":
- self.model = smp.DeepLabV3Plus(
- encoder_name=self.hyperparams["backbone"],
- encoder_weights="imagenet" if weights is True else None,
- in_channels=self.hyperparams["in_channels"],
- classes=self.hyperparams["num_classes"],
+
+ self.weights = weights
+ super().__init__(ignore="weights")
+
+ def configure_losses(self) -> None:
+ """Initialize the loss criterion.
+
+ Raises:
+ ValueError: If *loss* is invalid.
+ """
+ loss: str = self.hparams["loss"]
+ ignore_index = self.hparams["ignore_index"]
+ if loss == "ce":
+ ignore_value = -1000 if ignore_index is None else ignore_index
+ self.criterion = nn.CrossEntropyLoss(
+ ignore_index=ignore_value, weight=self.hparams["class_weights"]
)
- elif self.hyperparams["model"] == "fcn":
- self.model = FCN(
- in_channels=self.hyperparams["in_channels"],
- classes=self.hyperparams["num_classes"],
- num_filters=self.hyperparams["num_filters"],
+ elif loss == "jaccard":
+ self.criterion = smp.losses.JaccardLoss(
+ mode="multiclass", classes=self.hparams["num_classes"]
+ )
+ elif loss == "focal":
+ self.criterion = smp.losses.FocalLoss(
+ "multiclass", ignore_index=ignore_index, normalized=True
)
else:
raise ValueError(
- f"Model type '{self.hyperparams['model']}' is not valid. "
- f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
+ f"Loss type '{loss}' is not valid. "
+ "Currently, supports 'ce', 'jaccard' or 'focal' loss."
)
- if self.hyperparams["loss"] == "ce":
- ignore_value = -1000 if self.ignore_index is None else self.ignore_index
+ def configure_metrics(self) -> None:
+ """Initialize the performance metrics."""
+ num_classes: int = self.hparams["num_classes"]
+ ignore_index: Optional[int] = self.hparams["ignore_index"]
+ metrics = MetricCollection(
+ [
+ MulticlassAccuracy(
+ num_classes=num_classes,
+ ignore_index=ignore_index,
+ multidim_average="global",
+ average="micro",
+ ),
+ MulticlassJaccardIndex(
+ num_classes=num_classes, ignore_index=ignore_index, average="micro"
+ ),
+ ]
+ )
+ self.train_metrics = metrics.clone(prefix="train_")
+ self.val_metrics = metrics.clone(prefix="val_")
+ self.test_metrics = metrics.clone(prefix="test_")
- class_weights = None
- if isinstance(self.class_weights, torch.Tensor):
- class_weights = self.class_weights.to(dtype=torch.float32)
- elif hasattr(self.class_weights, "__array__") or self.class_weights:
- class_weights = torch.tensor(self.class_weights, dtype=torch.float32)
+ def configure_models(self) -> None:
+ """Initialize the model.
- self.loss = nn.CrossEntropyLoss(
- ignore_index=ignore_value, weight=class_weights
+ Raises:
+ ValueError: If *model* is invalid.
+ """
+ model: str = self.hparams["model"]
+ backbone: str = self.hparams["backbone"]
+ weights = self.weights
+ in_channels: int = self.hparams["in_channels"]
+ num_classes: int = self.hparams["num_classes"]
+ num_filters: int = self.hparams["num_filters"]
+
+ if model == "unet":
+ self.model = smp.Unet(
+ encoder_name=backbone,
+ encoder_weights="imagenet" if weights is True else None,
+ in_channels=in_channels,
+ classes=num_classes,
)
- elif self.hyperparams["loss"] == "jaccard":
- self.loss = smp.losses.JaccardLoss(
- mode="multiclass", classes=self.hyperparams["num_classes"]
+ elif model == "deeplabv3+":
+ self.model = smp.DeepLabV3Plus(
+ encoder_name=backbone,
+ encoder_weights="imagenet" if weights is True else None,
+ in_channels=in_channels,
+ classes=num_classes,
)
- elif self.hyperparams["loss"] == "focal":
- self.loss = smp.losses.FocalLoss(
- "multiclass", ignore_index=self.ignore_index, normalized=True
+ elif model == "fcn":
+ self.model = FCN(
+ in_channels=in_channels, classes=num_classes, num_filters=num_filters
)
else:
raise ValueError(
- f"Loss type '{self.hyperparams['loss']}' is not valid. "
- f"Currently, supports 'ce', 'jaccard' or 'focal' loss."
+ f"Model type '{model}' is not valid. "
+ "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)
- if self.hyperparams["model"] != "fcn":
+ if model != "fcn":
if weights and weights is not True:
if isinstance(weights, WeightsEnum):
state_dict = weights.get_state_dict(progress=True)
@@ -99,161 +192,61 @@ def config_task(self) -> None:
self.model.encoder.load_state_dict(state_dict)
# Freeze backbone
- if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
- "model"
- ] in ["unet", "deeplabv3+"]:
+ if self.hparams["freeze_backbone"] and model in ["unet", "deeplabv3+"]:
for param in self.model.encoder.parameters():
param.requires_grad = False
# Freeze decoder
- if self.hyperparams.get("freeze_decoder", False) and self.hyperparams[
- "model"
- ] in ["unet", "deeplabv3+"]:
+ if self.hparams["freeze_decoder"] and model in ["unet", "deeplabv3+"]:
for param in self.model.decoder.parameters():
param.requires_grad = False
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the LightningModule with a model and loss function.
-
- Keyword Args:
- model: Name of the segmentation model type to use
- backbone: Name of the timm backbone to use
- weights: Either a weight enum, the string representation of a weight enum,
- True for ImageNet weights, False or None for random weights,
- or the path to a saved model state dict. FCN model does not support
- pretrained weights. Pretrained ViT weight enums are not supported yet.
- in_channels: Number of channels in input image
- num_classes: Number of semantic classes to predict
- loss: Name of the loss function, currently supports
- 'ce', 'jaccard' or 'focal' loss
- class_weights: Optional rescaling weight given to each
- class and used with 'ce' loss
- ignore_index: Optional integer class index to ignore in the loss and metrics
- learning_rate: Learning rate for optimizer
- learning_rate_schedule_patience: Patience for learning rate scheduler
- freeze_backbone: Freeze the backbone network to fine-tune the
- decoder and segmentation head
- freeze_decoder: Freeze the decoder network to linear probe
- the segmentation head
-
- Raises:
- ValueError: if kwargs arguments are invalid
-
- .. versionchanged:: 0.3
- The *ignore_zeros* parameter was renamed to *ignore_index*.
-
- .. versionchanged:: 0.4
- The *segmentation_model* parameter was renamed to *model*,
- *encoder_name* renamed to *backbone*, and
- *encoder_weights* to *weights*.
-
- .. versionadded: 0.5
- The *class_weights*, *freeze_backbone*,
- and *freeze_decoder* parameters.
-
- .. versionchanged:: 0.5
- The *weights* parameter now supports WeightEnums and checkpoint paths.
-
- """
- super().__init__()
-
- # Creates `self.hparams` from kwargs
- self.save_hyperparameters()
- self.hyperparams = cast(dict[str, Any], self.hparams)
-
- if not isinstance(kwargs["ignore_index"], (int, type(None))):
- raise ValueError("ignore_index must be an int or None")
- if (kwargs["ignore_index"] is not None) and (kwargs["loss"] == "jaccard"):
- warnings.warn(
- "ignore_index has no effect on training when loss='jaccard'",
- UserWarning,
- )
- self.ignore_index = kwargs["ignore_index"]
- self.class_weights = kwargs.get("class_weights", None)
-
- self.config_task()
-
- self.train_metrics = MetricCollection(
- [
- MulticlassAccuracy(
- num_classes=self.hyperparams["num_classes"],
- ignore_index=self.ignore_index,
- multidim_average="global",
- average="micro",
- ),
- MulticlassJaccardIndex(
- num_classes=self.hyperparams["num_classes"],
- ignore_index=self.ignore_index,
- average="micro",
- ),
- ],
- prefix="train_",
- )
- self.val_metrics = self.train_metrics.clone(prefix="val_")
- self.test_metrics = self.train_metrics.clone(prefix="test_")
-
- def forward(self, *args: Any, **kwargs: Any) -> Any:
- """Forward pass of the model.
-
- Args:
- x: tensor of data to run through the model
-
- Returns:
- output from the model
- """
- return self.model(*args, **kwargs)
-
- def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the training loss.
+ def training_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the training loss and additional metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- training loss
+ The loss tensor.
"""
- batch = args[0]
x = batch["image"]
y = batch["mask"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
-
- loss = self.loss(y_hat, y)
-
- # by default, the train step logs every `log_every_n_steps` steps where
- # `log_every_n_steps` is a parameter to the `Trainer` object
- self.log("train_loss", loss, on_step=True, on_epoch=False)
+ loss: Tensor = self.criterion(y_hat, y)
+ self.log("train_loss", loss)
self.train_metrics(y_hat_hard, y)
+ self.log_dict(self.train_metrics)
+ return loss
- return cast(Tensor, loss)
-
- def on_train_epoch_end(self) -> None:
- """Logs epoch level training metrics."""
- self.log_dict(self.train_metrics.compute())
- self.train_metrics.reset()
-
- def validation_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute validation loss and log example predictions.
+ def validation_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> None:
+ """Compute the validation loss and additional metrics.
Args:
- batch: the output of your DataLoader
- batch_idx: the index of this batch
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
- batch_idx = args[1]
x = batch["image"]
y = batch["mask"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
-
- loss = self.loss(y_hat, y)
-
- self.log("val_loss", loss, on_step=False, on_epoch=True)
+ loss = self.criterion(y_hat, y)
+ self.log("val_loss", loss)
self.val_metrics(y_hat_hard, y)
+ self.log_dict(self.val_metrics)
if (
batch_idx < 10
and hasattr(self.trainer, "datamodule")
+ and hasattr(self.trainer.datamodule, "plot")
and self.logger
and hasattr(self.logger, "experiment")
and hasattr(self.logger.experiment, "add_figure")
@@ -265,77 +258,45 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig = datamodule.plot(sample)
- summary_writer = self.logger.experiment
- summary_writer.add_figure(
- f"image/{batch_idx}", fig, global_step=self.global_step
- )
- plt.close()
+ if fig:
+ summary_writer = self.logger.experiment
+ summary_writer.add_figure(
+ f"image/{batch_idx}", fig, global_step=self.global_step
+ )
+ plt.close()
except ValueError:
pass
- def on_validation_epoch_end(self) -> None:
- """Logs epoch level validation metrics."""
- self.log_dict(self.val_metrics.compute())
- self.val_metrics.reset()
-
- def test_step(self, *args: Any, **kwargs: Any) -> None:
- """Compute test loss.
+ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
+ """Compute the test loss and additional metrics.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
"""
- batch = args[0]
x = batch["image"]
y = batch["mask"]
y_hat = self(x)
y_hat_hard = y_hat.argmax(dim=1)
-
- loss = self.loss(y_hat, y)
-
- # by default, the test and validation steps only log per *epoch*
- self.log("test_loss", loss, on_step=False, on_epoch=True)
+ loss = self.criterion(y_hat, y)
+ self.log("test_loss", loss)
self.test_metrics(y_hat_hard, y)
+ self.log_dict(self.test_metrics)
- def on_test_epoch_end(self) -> None:
- """Logs epoch level test metrics."""
- self.log_dict(self.test_metrics.compute())
- self.test_metrics.reset()
-
- def predict_step(self, *args: Any, **kwargs: Any) -> Tensor:
- """Compute and return the predictions.
-
- By default, this will loop over images in a dataloader and aggregate
- predictions into a list. This may not be desirable if you have many images
- or large images which could cause out of memory errors. In this case
- it's recommended to override this with a custom predict_step.
+ def predict_step(
+ self, batch: Any, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Compute the predicted class probabilities.
Args:
- batch: the output of your DataLoader
+ batch: The output of your DataLoader.
+ batch_idx: Integer displaying index of this batch.
+ dataloader_idx: Index of the current dataloader.
Returns:
- predicted softmax probabilities
+ Output predicted probabilities.
"""
- batch = args[0]
x = batch["image"]
y_hat: Tensor = self(x).softmax(dim=1)
return y_hat
-
- def configure_optimizers(self) -> dict[str, Any]:
- """Initialize the optimizer and learning rate scheduler.
-
- Returns:
- learning rate dictionary
- """
- optimizer = torch.optim.Adam(
- self.model.parameters(), lr=self.hyperparams["learning_rate"]
- )
- return {
- "optimizer": optimizer,
- "lr_scheduler": {
- "scheduler": ReduceLROnPlateau(
- optimizer,
- patience=self.hyperparams["learning_rate_schedule_patience"],
- ),
- "monitor": "val_loss",
- },
- }
diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py
index 1dd546f7c26..a889be1c96f 100644
--- a/torchgeo/trainers/simclr.py
+++ b/torchgeo/trainers/simclr.py
@@ -5,18 +5,18 @@
import os
import warnings
-from typing import Any, Optional, Union, cast
+from typing import Any, Optional, Union
import kornia.augmentation as K
+import lightning
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
-from lightning import LightningModule
from torch import Tensor
-from torch.optim import Adam, Optimizer
+from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torchvision.models._api import WeightsEnum
@@ -24,11 +24,7 @@
from ..models import get_weight
from . import utils
-
-try:
- from torch.optim.lr_scheduler import LRScheduler
-except ImportError:
- from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from .base import BaseTask
def simclr_augmentations(size: int, weights: Tensor) -> nn.Module:
@@ -57,7 +53,7 @@ def simclr_augmentations(size: int, weights: Tensor) -> nn.Module:
)
-class SimCLRTask(LightningModule):
+class SimCLRTask(BaseTask):
"""SimCLR: a simple framework for contrastive learning of visual representations.
Reference implementation:
@@ -72,6 +68,8 @@ class SimCLRTask(LightningModule):
.. versionadded:: 0.5
"""
+ monitor = "train_loss"
+
def __init__(
self,
model: str = "resnet50",
@@ -93,7 +91,8 @@ def __init__(
"""Initialize a new SimCLRTask instance.
Args:
- model: Name of the timm model to use.
+ model: Name of the `timm
+ `__ model to use.
weights: Initial model weights. Either a weight enum, the string
representation of a weight enum, True for ImageNet weights, False
or None for random weights, or the path to a saved model state dict.
@@ -122,8 +121,6 @@ def __init__(
Warns:
UserWarning: If hyperparameters do not match SimCLR version requested.
"""
- super().__init__()
-
# Validate hyperparameters
assert version in range(1, 3)
if version == 1:
@@ -137,16 +134,34 @@ def __init__(
if memory_bank_size == 0:
warnings.warn("SimCLR v2 uses a memory bank")
- self.save_hyperparameters(ignore=["augmentations"])
+ self.weights = weights
+ super().__init__(ignore=["weights", "augmentations"])
grayscale_weights = grayscale_weights or torch.ones(in_channels)
self.augmentations = augmentations or simclr_augmentations(
size, grayscale_weights
)
+ def configure_losses(self) -> None:
+ """Initialize the loss criterion."""
+ self.criterion = NTXentLoss(
+ self.hparams["temperature"],
+ self.hparams["memory_bank_size"],
+ self.hparams["gather_distributed"],
+ )
+
+ def configure_models(self) -> None:
+ """Initialize the model."""
+ weights = self.weights
+ hidden_dim: int = self.hparams["hidden_dim"]
+ output_dim: int = self.hparams["output_dim"]
+
# Create backbone
self.backbone = timm.create_model(
- model, in_chans=in_channels, num_classes=0, pretrained=weights is True
+ self.hparams["model"],
+ in_chans=self.hparams["in_channels"],
+ num_classes=0,
+ pretrained=weights is True,
)
# Load weights
@@ -167,12 +182,9 @@ def __init__(
output_dim = input_dim
self.projection_head = SimCLRProjectionHead(
- input_dim, hidden_dim, output_dim, layers
+ input_dim, hidden_dim, output_dim, self.hparams["layers"]
)
- # Define loss function
- self.criterion = NTXentLoss(temperature, memory_bank_size, gather_distributed)
-
# Initialize moving average of output
self.avg_output_std = 0.0
@@ -187,11 +199,11 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
x: Mini-batch of images.
Returns:
- Output from the model and backbone.
+ Output of the model and backbone.
"""
- h = self.backbone(x) # shape of batch_size x num_features
+ h: Tensor = self.backbone(x) # shape of batch_size x num_features
z = self.projection_head(h)
- return cast(Tensor, z), cast(Tensor, h)
+ return z, h
def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
@@ -205,10 +217,13 @@ def training_step(
Returns:
The loss tensor.
+
+ Raises:
+ AssertionError: If channel dimensions are incorrect.
"""
x = batch["image"]
- in_channels = self.hparams["in_channels"]
+ in_channels: int = self.hparams["in_channels"]
assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
if x.size(1) == in_channels:
@@ -225,7 +240,7 @@ def training_step(
z1, h1 = self(x1)
z2, h2 = self(x2)
- loss = self.criterion(z1, z2)
+ loss: Tensor = self.criterion(z1, z2)
# Calculate the mean normalized standard deviation over features dimensions.
# If this is << 1 / sqrt(h1.shape[1]), then the model is not learning anything.
@@ -238,7 +253,7 @@ def training_step(
self.log("train_ssl_std", self.avg_output_std)
self.log("train_loss", loss)
- return cast(Tensor, loss)
+ return loss
def validation_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
@@ -253,7 +268,9 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""No-op, does nothing."""
- def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
+ def configure_optimizers(
+ self,
+ ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
"""Initialize the optimizer and learning rate scheduler.
Returns:
@@ -272,7 +289,7 @@ def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
warmup_epochs = 10
else:
warmup_epochs = int(max_epochs * 0.05)
- lr_scheduler = SequentialLR(
+ scheduler = SequentialLR(
optimizer,
schedulers=[
LinearLR(optimizer, total_iters=warmup_epochs),
@@ -280,4 +297,7 @@ def configure_optimizers(self) -> tuple[list[Optimizer], list[LRScheduler]]:
],
milestones=[warmup_epochs],
)
- return [optimizer], [lr_scheduler]
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
+ }
diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py
index ee332053d0d..d5c7cd97d39 100644
--- a/torchgeo/transforms/transforms.py
+++ b/torchgeo/transforms/transforms.py
@@ -42,7 +42,7 @@ def __init__(
keys: list[str] = []
for key in data_keys:
- if key == "image":
+ if key.startswith("image"):
keys.append("input")
elif key == "boxes":
keys.append("bbox")
diff --git a/train.py b/train.py
deleted file mode 100755
index 2284402d581..00000000000
--- a/train.py
+++ /dev/null
@@ -1,153 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License.
-
-"""torchgeo model training script."""
-
-import os
-from typing import cast
-
-import lightning.pytorch as pl
-from hydra.utils import instantiate
-from lightning.pytorch import LightningDataModule, LightningModule, Trainer
-from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
-from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
-from omegaconf import DictConfig, OmegaConf
-
-from torchgeo.datamodules import MisconfigurationException
-from torchgeo.trainers import BYOLTask, MoCoTask, ObjectDetectionTask, SimCLRTask
-
-
-def set_up_omegaconf() -> DictConfig:
- """Loads program arguments from either YAML config files or command line arguments.
-
- This method loads defaults/a schema from "conf/defaults.yaml" as well as potential
- arguments from the command line. If one of the command line arguments is
- "config_file", then we additionally read arguments from that YAML file. One of the
- config file based arguments or command line arguments must specify task.name. The
- task.name value is used to grab a task specific defaults from its respective
- trainer. The final configuration is given as merge(task_defaults, defaults,
- config file, command line). The merge() works from the first argument to the last,
- replacing existing values with newer values. Additionally, if any values are
- merged into task_defaults without matching types, then there will be a runtime
- error.
-
- Returns:
- an OmegaConf DictConfig containing all the validated program arguments
-
- Raises:
- FileNotFoundError: when ``config_file`` does not exist
- """
- conf = OmegaConf.load("conf/defaults.yaml")
- command_line_conf = OmegaConf.from_cli()
-
- if "config_file" in command_line_conf:
- config_fn = command_line_conf.config_file
- if not os.path.isfile(config_fn):
- raise FileNotFoundError(f"config_file={config_fn} is not a valid file")
-
- user_conf = OmegaConf.load(config_fn)
- conf = OmegaConf.merge(conf, user_conf)
-
- conf = OmegaConf.merge( # Merge in any arguments passed via the command line
- conf, command_line_conf
- )
- conf = cast(DictConfig, conf) # convince mypy that everything is alright
- return conf
-
-
-def main(conf: DictConfig) -> None:
- """Main training loop."""
- experiment_name = (
- f"{conf.datamodule._target_.lower()}_{conf.module._target_.lower()}"
- )
- if os.path.isfile(conf.program.output_dir):
- raise NotADirectoryError("`program.output_dir` must be a directory")
- os.makedirs(conf.program.output_dir, exist_ok=True)
-
- experiment_dir = os.path.join(conf.program.output_dir, experiment_name)
- os.makedirs(experiment_dir, exist_ok=True)
-
- if len(os.listdir(experiment_dir)) > 0:
- if conf.program.overwrite:
- print(
- f"WARNING! The experiment directory, {experiment_dir}, already exists, "
- + "we might overwrite data in it!"
- )
- else:
- raise FileExistsError(
- f"The experiment directory, {experiment_dir}, already exists and isn't "
- + "empty. We don't want to overwrite any existing results, exiting..."
- )
-
- with open(os.path.join(experiment_dir, "config.yaml"), "w") as f:
- OmegaConf.save(config=conf, f=f)
-
- # Define module and datamodule
- datamodule: LightningDataModule = instantiate(conf.datamodule)
- task: LightningModule = instantiate(conf.module)
-
- # Define callbacks
- tb_logger = TensorBoardLogger(conf.program.log_dir, name=experiment_name)
- csv_logger = CSVLogger(conf.program.log_dir, name=experiment_name)
-
- if isinstance(task, ObjectDetectionTask):
- monitor_metric = "val_map"
- mode = "max"
- elif isinstance(task, (BYOLTask, MoCoTask, SimCLRTask)):
- monitor_metric = "train_loss"
- mode = "min"
- else:
- monitor_metric = "val_loss"
- mode = "min"
-
- checkpoint_callback = ModelCheckpoint(
- monitor=monitor_metric,
- filename=f"checkpoint-{{epoch:02d}}-{{{monitor_metric}:.2f}}",
- dirpath=experiment_dir,
- save_top_k=1,
- save_last=True,
- mode=mode,
- )
- early_stopping_callback = EarlyStopping(
- monitor=monitor_metric, min_delta=0.00, patience=18, mode=mode
- )
-
- # Define trainer
- trainer: Trainer = instantiate(
- conf.trainer,
- callbacks=[checkpoint_callback, early_stopping_callback],
- logger=[tb_logger, csv_logger],
- default_root_dir=experiment_dir,
- )
-
- # Train
- trainer.fit(model=task, datamodule=datamodule)
-
- # Test
- try:
- trainer.test(ckpt_path="best", datamodule=datamodule)
- except MisconfigurationException:
- pass
-
-
-if __name__ == "__main__":
- # Taken from https://github.com/pangeo-data/cog-best-practices
- _rasterio_best_practices = {
- "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
- "AWS_NO_SIGN_REQUEST": "YES",
- "GDAL_MAX_RAW_BLOCK_CACHE_SIZE": "200000000",
- "GDAL_SWATH_SIZE": "200000000",
- "VSI_CURL_CACHE_SIZE": "200000000",
- }
- os.environ.update(_rasterio_best_practices)
-
- conf = set_up_omegaconf()
-
- # Set random seed for reproducibility
- # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.utilities.seed.html#pytorch_lightning.utilities.seed.seed_everything
- pl.seed_everything(conf.program.seed)
-
- # Main training procedure
- main(conf)