diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index f465d76f976..2127a1d22b7 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -1,4 +1,6 @@ -# Double quote -> single quote +# Prettier: double quote -> single quote +6a5aaf4b93507072d40dcd78114893362c4eaf6e +# Ruff: double quote -> single quote b09122f3e4a9cb422f6747bf33eca02993f67549 # Prettier bd9c75798eede1a4b7d7ecd6203179d3cb5e54dd diff --git a/.codecov.yml b/.github/codecov.yml similarity index 100% rename from .codecov.yml rename to .github/codecov.yml diff --git a/.github/dependabot.yml b/.github/dependabot.yml index eb0571076dc..947fd950632 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,35 +1,36 @@ # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file version: 2 updates: - - package-ecosystem: "devcontainers" - directory: "/.devcontainer" + - package-ecosystem: 'devcontainers' + directory: '/.devcontainer' schedule: - interval: "weekly" - - package-ecosystem: "github-actions" - directory: "/" + interval: 'weekly' + - package-ecosystem: 'github-actions' + directory: '/' schedule: - interval: "weekly" - - package-ecosystem: "pip" - directory: "/requirements" + interval: 'weekly' + - package-ecosystem: 'pip' + directory: '/requirements' schedule: - interval: "daily" + interval: 'daily' groups: # torchvision pins torch, must update in unison torch: patterns: - - "torch" - - "torchvision" + - 'torch' + - 'torchvision' ignore: # setuptools releases new versions almost daily - - dependency-name: "setuptools" - update-types: ["version-update:semver-patch"] + - dependency-name: 'setuptools' + update-types: ['version-update:semver-patch'] # sphinx 6 is incompatible with pytorch-sphinx-theme # https://github.com/pytorch/pytorch_sphinx_theme/issues/175 - - dependency-name: "sphinx" - versions: ">=6" + - dependency-name: 'sphinx' + versions: '>=6' # segmentation-models-pytorch pins timm, must update in unison - - dependency-name: "timm" - - package-ecosystem: "npm" - directory: "/requirements" + - dependency-name: 'timm' + - package-ecosystem: 'npm' + directory: '/' schedule: - interval: "weekly" + interval: 'weekly' + versioning-strategy: 'lockfile-only' diff --git a/.github/labeler.yml b/.github/labeler.yml index 9e366aa7787..8b5b2212e85 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,48 +1,48 @@ # TorchGeo modules datamodules: - changed-files: - - any-glob-to-any-file: "torchgeo/datamodules/**" + - any-glob-to-any-file: 'torchgeo/datamodules/**' datasets: - changed-files: - - any-glob-to-any-file: "torchgeo/datasets/**" + - any-glob-to-any-file: 'torchgeo/datasets/**' losses: - changed-files: - - any-glob-to-any-file: "torchgeo/losses/**" + - any-glob-to-any-file: 'torchgeo/losses/**' models: - changed-files: - - any-glob-to-any-file: "torchgeo/models/**" + - any-glob-to-any-file: 'torchgeo/models/**' samplers: - changed-files: - - any-glob-to-any-file: "torchgeo/samplers/**" + - any-glob-to-any-file: 'torchgeo/samplers/**' trainers: - changed-files: - - any-glob-to-any-file: "torchgeo/trainers/**" + - any-glob-to-any-file: 'torchgeo/trainers/**' transforms: - changed-files: - - any-glob-to-any-file: "torchgeo/transforms/**" + - any-glob-to-any-file: 'torchgeo/transforms/**' # Other dependencies: - changed-files: - any-glob-to-any-file: - - "pyproject.toml" - - "requirements/**" - - ".github/dependabot.yml" + - 'pyproject.toml' + - 'requirements/**' + - '.github/dependabot.yml' documentation: - changed-files: - any-glob-to-any-file: - - "docs/**" - - "*.md" - - ".github/*.md" - - ".readthedocs.yaml" + - 'docs/**' + - '*.md' + - '.github/*.md' + - '.readthedocs.yaml' scripts: - changed-files: - any-glob-to-any-file: - - "torchgeo/__main__.py" - - "torchgeo/main.py" - - "experiments/**" + - 'torchgeo/__main__.py' + - 'torchgeo/main.py' + - 'experiments/**' testing: - changed-files: - any-glob-to-any-file: - - "tests/**" - - ".github/workflows/**" + - 'tests/**' + - '.github/workflows/**' diff --git a/.github/release.yml b/.github/release.yml index 76caf0a585e..8ca5b59679c 100644 --- a/.github/release.yml +++ b/.github/release.yml @@ -1,7 +1,7 @@ changelog: exclude: authors: - - dependabot[bot] + - app/dependabot categories: - title: Backwards-incompatible changes labels: diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml new file mode 100644 index 00000000000..3af8fd64935 --- /dev/null +++ b/.github/workflows/deploy.yaml @@ -0,0 +1,45 @@ +name: deploy +on: + release: + types: + - published +jobs: + build: + name: build + runs-on: ubuntu-latest + steps: + - name: Clone repo + uses: actions/checkout@v4.2.2 + - name: Set up python + uses: actions/setup-python@v5.3.0 + with: + python-version: '3.12' + - name: Install pip dependencies + run: pip install build + - name: List pip dependencies + run: pip list + - name: Build project + run: python3 -m build + - name: Upload artifacts + uses: actions/upload-artifact@v4.5.0 + with: + name: pypi-dist + path: dist/ + pypi: + name: pypi + needs: + - build + environment: + name: pypi + url: https://pypi.org/p/torchgeo + permissions: + id-token: write + runs-on: ubuntu-latest + steps: + - name: Download artifacts + uses: actions/download-artifact@v4.1.8 + with: + name: pypi-dist + path: dist/ + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@v1.12.3 diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index cfffdb4b850..c5c0dbb85f4 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -1,4 +1,4 @@ -name: "labeler" +name: 'labeler' on: - pull_request_target jobs: @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Add label uses: actions/labeler@v5.0.0 with: diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 68cdbc9c693..301ce26efe8 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -6,23 +6,27 @@ on: pull_request: branches: - release** +defaults: + run: + shell: bash jobs: integration: name: integration runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up python - uses: actions/setup-python@v5.1.0 + id: setup-python + uses: actions/setup-python@v5.3.0 with: - python-version: "3.12" + python-version: '3.12' - name: Cache dependencies - uses: actions/cache@v4.0.2 + uses: actions/cache@v4.2.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-integration + key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }} - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | @@ -40,21 +44,22 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up python - uses: actions/setup-python@v5.1.0 + id: setup-python + uses: actions/setup-python@v5.3.0 with: - python-version: "3.12" + python-version: '3.12' - name: Cache dependencies - uses: actions/cache@v4.0.2 + uses: actions/cache@v4.2.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-tutorials + key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('pyproject.toml') }}-tutorials - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | - pip install .[docs,tests] planetary_computer pystac + pip install .[docs,tests] planetary_computer pystac tensorboard pip cache purge - name: List pip dependencies run: pip list diff --git a/.github/workflows/style.yaml b/.github/workflows/style.yaml index 5133ce2bd49..73b26cac561 100644 --- a/.github/workflows/style.yaml +++ b/.github/workflows/style.yaml @@ -8,23 +8,27 @@ on: branches: - main - release** +defaults: + run: + shell: bash jobs: mypy: name: mypy runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up python - uses: actions/setup-python@v5.1.0 + id: setup-python + uses: actions/setup-python@v5.3.0 with: - python-version: "3.12" + python-version: '3.12' - name: Cache dependencies - uses: actions/cache@v4.0.2 + uses: actions/cache@v4.2.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/datasets.txt') }}-${{ hashFiles('requirements/style.txt') }}-${{ hashFiles('requirements/tests.txt') }} + key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/required.txt', 'requirements/datasets.txt', 'requirements/style.txt', 'requirements/tests.txt') }} - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | @@ -39,17 +43,18 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up python - uses: actions/setup-python@v5.1.0 + id: setup-python + uses: actions/setup-python@v5.3.0 with: - python-version: "3.12" + python-version: '3.12' - name: Cache dependencies - uses: actions/cache@v4.0.2 + uses: actions/cache@v4.2.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/style.txt') }} + key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/style.txt') }} - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | @@ -66,21 +71,21 @@ jobs: runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up nodejs - uses: actions/setup-node@v4.0.2 + uses: actions/setup-node@v4.1.0 with: - node-version: "20" - cache: "npm" - cache-dependency-path: "requirements/package-lock.json" + node-version: '20' + cache: 'npm' + cache-dependency-path: 'package-lock.json' - name: Installing prettier run: | - npm install requirements/ + npm install npm cache clean --force - name: List npm dependencies run: npm ls --all - name: Run prettier formatting - run: npx prettier . --check + run: npx prettier --check . concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.head.label || github.head_ref || github.ref }} cancel-in-progress: true diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 7a7a0980c73..a44a6cf7669 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -8,43 +8,34 @@ on: branches: - main - release** +defaults: + run: + shell: bash jobs: latest: name: latest runs-on: ${{ matrix.os }} - env: - MPLBACKEND: Agg strategy: matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.10", "3.11", "3.12"] + python-version: ['3.10', '3.11', '3.12'] steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up python - uses: actions/setup-python@v5.1.0 + id: setup-python + uses: actions/setup-python@v5.3.0 with: python-version: ${{ matrix.python-version }} - name: Cache dependencies - uses: actions/cache@v4.0.2 + uses: actions/cache@v4.2.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/datasets.txt') }}-${{ hashFiles('requirements/tests.txt') }} + key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/required.txt', 'requirements/datasets.txt', 'requirements/tests.txt') }} if: ${{ runner.os != 'macOS' }} - name: Setup headless display for pyvista - uses: pyvista/setup-headless-display-action@v2 - - name: Install apt dependencies (Linux) - run: | - sudo apt-get update - sudo apt-get install unrar - if: ${{ runner.os == 'Linux' }} - - name: Install brew dependencies (macOS) - run: brew install rar - if: ${{ runner.os == 'macOS' }} - - name: Install choco dependencies (Windows) - run: choco install 7zip - if: ${{ runner.os == 'Windows' }} + uses: pyvista/setup-headless-display-action@v3 - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | @@ -54,69 +45,63 @@ jobs: run: pip list - name: Run pytest checks run: | - pytest --cov=torchgeo --cov-report=xml --durations=10 + pytest --cov --cov-report=xml python3 -m torchgeo --help - name: Report coverage - uses: codecov/codecov-action@v4.4.1 + uses: codecov/codecov-action@v5.1.2 with: token: ${{ secrets.CODECOV_TOKEN }} minimum: name: minimum runs-on: ubuntu-latest - env: - MPLBACKEND: Agg steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up python - uses: actions/setup-python@v5.1.0 + id: setup-python + uses: actions/setup-python@v5.3.0 with: - python-version: "3.10" + python-version: '3.10' - name: Cache dependencies - uses: actions/cache@v4.0.2 + uses: actions/cache@v4.2.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/min-reqs.old') }}-${{ hashFiles('requirements/mins-cons.old') }} + key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/min-reqs.old') }} - name: Setup headless display for pyvista - uses: pyvista/setup-headless-display-action@v2 - - name: Install apt dependencies (Linux) - run: | - sudo apt-get update - sudo apt-get install unrar + uses: pyvista/setup-headless-display-action@v3 - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | - pip install -r requirements/min-reqs.old -c requirements/min-cons.old + pip install -r requirements/min-reqs.old pip cache purge - name: List pip dependencies run: pip list - name: Run pytest checks run: | - pytest --cov=torchgeo --cov-report=xml --durations=10 + pytest --cov --cov-report=xml python3 -m torchgeo --help - name: Report coverage - uses: codecov/codecov-action@v4.4.1 + uses: codecov/codecov-action@v5.1.2 with: token: ${{ secrets.CODECOV_TOKEN }} datasets: name: datasets runs-on: ubuntu-latest - env: - MPLBACKEND: Agg steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up python - uses: actions/setup-python@v5.1.0 + id: setup-python + uses: actions/setup-python@v5.3.0 with: - python-version: "3.12" + python-version: '3.12' - name: Cache dependencies - uses: actions/cache@v4.0.2 + uses: actions/cache@v4.2.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/tests.txt') }} + key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/required.txt', 'requirements/tests.txt') }} - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | @@ -126,10 +111,10 @@ jobs: run: pip list - name: Run pytest checks run: | - pytest --cov=torchgeo --cov-report=xml --durations=10 + pytest --cov --cov-report=xml python3 -m torchgeo --help - name: Report coverage - uses: codecov/codecov-action@v4.4.1 + uses: codecov/codecov-action@v5.1.2 with: token: ${{ secrets.CODECOV_TOKEN }} concurrency: diff --git a/.github/workflows/tutorials.yaml b/.github/workflows/tutorials.yaml index eeff9acbd03..de0da385b08 100644 --- a/.github/workflows/tutorials.yaml +++ b/.github/workflows/tutorials.yaml @@ -10,27 +10,31 @@ on: - main paths: - docs/tutorials/** +defaults: + run: + shell: bash jobs: notebooks: name: notebooks runs-on: ubuntu-latest steps: - name: Clone repo - uses: actions/checkout@v4.1.6 + uses: actions/checkout@v4.2.2 - name: Set up python - uses: actions/setup-python@v5.1.0 + id: setup-python + uses: actions/setup-python@v5.3.0 with: - python-version: "3.12" + python-version: '3.12' - name: Cache dependencies - uses: actions/cache@v4.0.2 + uses: actions/cache@v4.2.0 id: cache with: path: ${{ env.pythonLocation }} - key: ${{ env.pythonLocation }}-${{ hashFiles('requirements/required.txt') }}-${{ hashFiles('requirements/docs.txt') }}-${{ hashFiles('requirements/tests.txt') }}-tutorials + key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/required.txt', 'requirements/docs.txt', 'requirements/tests.txt') }}-tutorials - name: Install pip dependencies if: steps.cache.outputs.cache-hit != 'true' run: | - pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac + pip install -r requirements/required.txt -r requirements/docs.txt -r requirements/tests.txt planetary_computer pystac tensorboard . pip cache purge - name: List pip dependencies run: pip list diff --git a/.gitignore b/.gitignore index 3a017d0ee6e..180c27c47b2 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ # Node stuff: node_modules/ -/*.json # Spack .spack-env/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 164803fb8bd..e09d665c3ff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.5 + rev: v0.8.0 hooks: - id: ruff types_or: @@ -13,7 +13,7 @@ repos: - python - jupyter - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.9.0 + rev: v1.13.0 hooks: - id: mypy args: @@ -24,14 +24,15 @@ repos: - einops>=0.6.0 - kornia>=0.6.9 - lightning>=2.0.9 - - matplotlib>=3.8.1 + - matplotlib>=3.9.2 - numpy>=1.22 - - pillow>=10.3.0 + - pillow>=10.4.0 - pytest>=6.1.2 - pyvista>=0.34.2 - scikit-image>=0.22.0 - torch>=2.3 - torchmetrics>=0.10 + - torchvision>=0.18 exclude: (build|data|dist|logo|logs|output)/ - repo: https://github.com/pre-commit/mirrors-prettier rev: v3.1.0 diff --git a/.prettierignore b/.prettierignore index 790f9904d8b..6377adc4849 100644 --- a/.prettierignore +++ b/.prettierignore @@ -1,2 +1,7 @@ -# Ignore artifacts: +# Ignore artifacts tests/data/*/** + +# Automatically igored by git, but not by prettier +.mypy_cache +.pytest_cache +.ruff_cache diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 2c11def62d9..36f098eaf89 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-lts-latest tools: - python: "3.12" + python: '3.12' # Configuration of the Python environment to be used python: diff --git a/CITATION.cff b/CITATION.cff index 64524eaa24b..a8cc2bccab3 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -1,50 +1,50 @@ # https://github.com/citation-file-format/citation-file-format/blob/main/schema-guide.md # Can be validated using `cffconvert --validate` authors: - - family-names: "Stewart" - given-names: "Adam J." - - family-names: "Robinson" - given-names: "Caleb" - - family-names: "Corley" - given-names: "Isaac A." - - family-names: "Ortiz" - given-names: "Anthony" - - family-names: "Lavista Ferres" - given-names: "Juan M." - - family-names: "Banerjee" - given-names: "Arindam" -cff-version: "1.2.0" -message: "If you use this software, please cite it using the metadata from this file." + - family-names: 'Stewart' + given-names: 'Adam J.' + - family-names: 'Robinson' + given-names: 'Caleb' + - family-names: 'Corley' + given-names: 'Isaac A.' + - family-names: 'Ortiz' + given-names: 'Anthony' + - family-names: 'Lavista Ferres' + given-names: 'Juan M.' + - family-names: 'Banerjee' + given-names: 'Arindam' +cff-version: '1.2.0' +message: 'If you use this software, please cite it using the metadata from this file.' preferred-citation: authors: - - family-names: "Stewart" - given-names: "Adam J." - - family-names: "Robinson" - given-names: "Caleb" - - family-names: "Corley" - given-names: "Isaac A." - - family-names: "Ortiz" - given-names: "Anthony" - - family-names: "Lavista Ferres" - given-names: "Juan M." - - family-names: "Banerjee" - given-names: "Arindam" - collection-title: "Proceedings of the 30th International Conference on Advances in Geographic Information Systems" - collection-type: "proceedings" + - family-names: 'Stewart' + given-names: 'Adam J.' + - family-names: 'Robinson' + given-names: 'Caleb' + - family-names: 'Corley' + given-names: 'Isaac A.' + - family-names: 'Ortiz' + given-names: 'Anthony' + - family-names: 'Lavista Ferres' + given-names: 'Juan M.' + - family-names: 'Banerjee' + given-names: 'Arindam' + collection-title: 'Proceedings of the 30th International Conference on Advances in Geographic Information Systems' + collection-type: 'proceedings' conference: - city: "Seattle" + city: 'Seattle' name: "SIGSPATIAL '22" - region: "Washington" - doi: "10.1145/3557915.3560953" + region: 'Washington' + doi: '10.1145/3557915.3560953' end: 12 - isbn: "9781450395298" + isbn: '9781450395298' month: 11 number: 19 publisher: - name: "Association for Computing Machinery" + name: 'Association for Computing Machinery' start: 1 - title: "TorchGeo: Deep Learning With Geospatial Data" - type: "conference-paper" - url: "https://dl.acm.org/doi/10.1145/3557915.3560953" + title: 'TorchGeo: Deep Learning With Geospatial Data' + type: 'conference-paper' + url: 'https://dl.acm.org/doi/10.1145/3557915.3560953' year: 2022 -title: "TorchGeo: Deep Learning With Geospatial Data" +title: 'TorchGeo: Deep Learning With Geospatial Data' diff --git a/README.md b/README.md index 38883a212d7..b5fd835430b 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,8 @@ Testing: The recommended way to install TorchGeo is with [pip](https://pip.pypa.io/): -```console -$ pip install torchgeo +```sh +pip install torchgeo ``` For [conda](https://docs.conda.io/) and [spack](https://spack.io/) installation instructions, see the [documentation](https://torchgeo.readthedocs.io/en/stable/user/installation.html). @@ -72,15 +72,15 @@ Many remote sensing applications involve working with [_geospatial datasets_](ht In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of [Landsat](https://www.usgs.gov/landsat-missions) and [Cropland Data Layer (CDL)](https://data.nal.usda.gov/dataset/cropscape-cropland-data-layer) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset including all images from both Landsat 7 and 8 data by taking the union between these two datasets. ```python -landsat7 = Landsat7(root="...", bands=["B1", ..., "B7"]) -landsat8 = Landsat8(root="...", bands=["B2", ..., "B8"]) +landsat7 = Landsat7(paths="...", bands=["B1", ..., "B7"]) +landsat8 = Landsat8(paths="...", bands=["B2", ..., "B8"]) landsat = landsat7 | landsat8 ``` Next, we take the intersection between this dataset and the CDL dataset. We want to take the intersection instead of the union to ensure that we only sample from regions that have both Landsat and CDL data. Note that we can automatically download and checksum CDL data. Also note that each of these datasets may contain files in different coordinate reference systems (CRS) or resolutions, but TorchGeo automatically ensures that a matching CRS and resolution is used. ```python -cdl = CDL(root="...", download=True, checksum=True) +cdl = CDL(paths="...", download=True, checksum=True) dataset = landsat & cdl ``` @@ -192,7 +192,7 @@ trainer.fit(model=task, datamodule=datamodule) TorchGeo also supports command-line interface training using [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html). It can be invoked in two ways: -```console +```sh # If torchgeo has been installed torchgeo # If torchgeo has been installed, or if it has been cloned to the current directory @@ -201,7 +201,7 @@ python3 -m torchgeo It supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages: -```console +```sh # See valid stages torchgeo --help # See valid trainer options @@ -220,7 +220,7 @@ trainer: model: class_path: ClassificationTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 13 num_classes: 10 data: @@ -233,13 +233,13 @@ data: we can see the script in action: -```console +```sh # Train and validate a model torchgeo fit --config config.yaml # Validate-only torchgeo validate --config config.yaml # Calculate and report test accuracy -torchgeo test --config config.yaml ckpt_path=... +torchgeo test --config config.yaml --ckpt_path=... ``` It can also be imported and used in a Python script if you need to extend it to add new features: diff --git a/docs/_static/badge-height.css b/docs/_static/badge-height.css new file mode 100644 index 00000000000..b349d8587bd --- /dev/null +++ b/docs/_static/badge-height.css @@ -0,0 +1,5 @@ +/* https://github.com/pytorch/pytorch_sphinx_theme/issues/140 */ +.tutorial-badge { + height: 50px !important; + width: auto !important; +} diff --git a/docs/_static/button-width.css b/docs/_static/button-width.css deleted file mode 100644 index edf1c62c977..00000000000 --- a/docs/_static/button-width.css +++ /dev/null @@ -1,4 +0,0 @@ -.colabbadge { - height: 50px !important; - width: auto !important; -} diff --git a/docs/api/agnostic_pretrained_weights.csv b/docs/api/agnostic_pretrained_weights.csv deleted file mode 100644 index b4e4e935b57..00000000000 --- a/docs/api/agnostic_pretrained_weights.csv +++ /dev/null @@ -1,3 +0,0 @@ -Weight,Source,Citation,License,m-bigearthnet,m-forestnet,m-brick-kiln,m-pv4ger,m-so2sat,m-eurosat,m-pv4ger-seg,m-nz-cattle,m-NeonTree,m-cashew-plant,m-SA-crop,m-chesapeake -DOFABase16_Weights.DOFA_MAE,`link `__,`link `__,CC-BY-4.0,63.8,45.3,94.7,96.9,52.1,92.2,94.7,81.6,58.6,48.3,31.3,65.4 -DOFALarge16_Weights.DOFA_MAE,`link `__,`link `__,CC-BY-4.0,64.4,47.4,95.1,97.3,59.3,93.8,95.0,81.7,59.1,53.8,32.1,66.3 diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 98160b41bad..fdcef5450d1 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -57,6 +57,16 @@ BigEarthNet .. autoclass:: BigEarthNetDataModule +CaBuAr +^^^^^^ + +.. autoclass:: CaBuArDataModule + +CaFFe +^^^^^ + +.. autoclass:: CaFFeDataModule + ChaBuD ^^^^^^ @@ -72,6 +82,11 @@ Deep Globe Land Cover Challenge .. autoclass:: DeepGlobeLandCoverDataModule +Digital Typhoon +^^^^^^^^^^^^^^^ + +.. autoclass:: DigitalTyphoonDataModule + ETCI2021 Flood Detection ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -89,16 +104,31 @@ FAIR1M .. autoclass:: FAIR1MDataModule +Fields Of The World +^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: FieldsOfTheWorldDataModule + FireRisk ^^^^^^^^ .. autoclass:: FireRiskDataModule +GeoNRW +^^^^^^ + +.. autoclass:: GeoNRWDataModule + GID-15 ^^^^^^ .. autoclass:: GID15DataModule +HySpecNet-11k +^^^^^^^^^^^^^ + +.. autoclass:: HySpecNet11kDataModule + Inria Aerial Image Labeling ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -108,6 +138,7 @@ LandCover.ai ^^^^^^^^^^^^ .. autoclass:: LandCoverAIDataModule +.. autoclass:: LandCoverAI100DataModule LEVIR-CD ^^^^^^^^ @@ -167,7 +198,9 @@ So2Sat SpaceNet ^^^^^^^^ +.. autoclass:: SpaceNetBaseDataModule .. autoclass:: SpaceNet1DataModule +.. autoclass:: SpaceNet6DataModule SSL4EO ^^^^^^ @@ -185,6 +218,11 @@ SustainBench Crop Yield .. autoclass:: SustainBenchCropYieldDataModule +TreeSatAI +^^^^^^^^^ + +.. autoclass:: TreeSatAIDataModule + Tropical Cyclone ^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8b5149ade43..868074ba635 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -16,7 +16,7 @@ Geospatial Datasets :widths: 30 15 20 36 20 15 :header-rows: 1 :align: center - :file: geo_datasets.csv + :file: datasets/geo_datasets.csv Aboveground Woody Biomass ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -47,8 +47,6 @@ Chesapeake Land Cover ^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: Chesapeake -.. autoclass:: Chesapeake7 -.. autoclass:: Chesapeake13 .. autoclass:: ChesapeakeDC .. autoclass:: ChesapeakeDE .. autoclass:: ChesapeakeMD @@ -193,11 +191,11 @@ Non-geospatial Datasets :class:`NonGeoDataset` is designed for datasets that lack geospatial information. These datasets can still be combined using :class:`ConcatDataset `. -.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection +.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection, IC = image captioning :widths: 15 7 15 20 12 11 12 15 13 :header-rows: 1 :align: center - :file: non_geo_datasets.csv + :file: datasets/non_geo_datasets.csv ADVANCE ^^^^^^^ @@ -219,6 +217,16 @@ BioMassters .. autoclass:: BioMassters +CaBuAr +^^^^^^ + +.. autoclass:: CaBuAr + +CaFFe +^^^^^ + +.. autoclass:: CaFFe + ChaBuD ^^^^^^ @@ -256,6 +264,12 @@ DFC2022 .. autoclass:: DFC2022 + +Digital Typhoon +^^^^^^^^^^^^^^^ + +.. autoclass:: DigitalTyphoon + ETCI2021 Flood Detection ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -273,6 +287,11 @@ FAIR1M .. autoclass:: FAIR1M +Fields Of The World +^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: FieldsOfTheWorld + FireRisk ^^^^^^^^ @@ -283,11 +302,21 @@ Forest Damage .. autoclass:: ForestDamage +GeoNRW +^^^^^^^ + +.. autoclass:: GeoNRW + GID-15 ^^^^^^ .. autoclass:: GID15 +HySpecNet-11k +^^^^^^^^^^^^^ + +.. autoclass:: HySpecNet11k + IDTReeS ^^^^^^^ @@ -302,6 +331,7 @@ LandCover.ai ^^^^^^^^^^^^ .. autoclass:: LandCoverAI +.. autoclass:: LandCoverAI100 LEVIR-CD ^^^^^^^^ @@ -329,6 +359,11 @@ Million-AID .. autoclass:: MillionAID +MMEarth +^^^^^^^^ + +.. autoclass:: MMEarth + NASA Marine Debris ^^^^^^^^^^^^^^^^^^ @@ -374,6 +409,11 @@ Rwanda Field Boundary .. autoclass:: RwandaFieldBoundary +SatlasPretrain +^^^^^^^^^^^^^^ + +.. autoclass:: SatlasPretrain + Seasonal Contrast ^^^^^^^^^^^^^^^^^ @@ -394,6 +434,11 @@ SKIPP'D .. autoclass:: SKIPPD +SkyScript +^^^^^^^^^ + +.. autoclass:: SkyScript + So2Sat ^^^^^^ @@ -410,6 +455,7 @@ SpaceNet .. autoclass:: SpaceNet5 .. autoclass:: SpaceNet6 .. autoclass:: SpaceNet7 +.. autoclass:: SpaceNet8 SSL4EO ^^^^^^ @@ -428,6 +474,11 @@ SustainBench Crop Yield .. autoclass:: SustainBenchCropYield +TreeSatAI +^^^^^^^^^ + +.. autoclass:: TreeSatAI + Tropical Cyclone ^^^^^^^^^^^^^^^^ diff --git a/docs/api/geo_datasets.csv b/docs/api/datasets/geo_datasets.csv similarity index 93% rename from docs/api/geo_datasets.csv rename to docs/api/datasets/geo_datasets.csv index 7cdcaf5cfca..4bb5788609e 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/datasets/geo_datasets.csv @@ -4,7 +4,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `Airphen`_,Imagery,Airphen,-,"1,280x960",0.047--0.09 `Aster Global DEM`_,DEM,Aster,"public domain","3,601x3,601",30 `Canadian Building Footprints`_,Geometries,Bing Imagery,"ODbL-1.0",-,- -`Chesapeake Land Cover`_,"Imagery, Masks",NAIP,"CC-BY-4.0",-,1 +`Chesapeake Land Cover`_,"Imagery, Masks",NAIP,"CC0-1.0",-,1 `Global Mangrove Distribution`_,Masks,"Remote Sensing, In Situ Measurements","public domain",-,3 `Cropland Data Layer`_,Masks,Landsat,"public domain",-,30 `EDDMapS`_,Points,Citizen Scientists,-,-,- @@ -20,7 +20,7 @@ Dataset,Type,Source,License,Size (px),Resolution (m) `L8 Biome`_,"Imagery, Masks",Landsat,"CC0-1.0","8,900x8,900","15, 30" `LandCover.ai Geo`_,"Imagery, Masks",Aerial,"CC-BY-NC-SA-4.0","4,200--9,500",0.25--0.5 `Landsat`_,Imagery,Landsat,"public domain","8,900x8,900",30 -`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",1 +`NAIP`_,Imagery,Aerial,"public domain","6,100x7,600",0.3--2 `NCCM`_,Masks,Sentinel-2,"CC-BY-4.0",-,10 `NLCD`_,Masks,Landsat,"public domain",-,30 `Open Buildings`_,Geometries,"Maxar, CNES/Airbus","CC-BY-4.0 OR ODbL-1.0",-,- diff --git a/docs/api/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv similarity index 78% rename from docs/api/non_geo_datasets.csv rename to docs/api/datasets/non_geo_datasets.csv index 2dac9021daa..7d7a17a4b94 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -3,6 +3,8 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Benin Cashew Plantations`_,S,Airbus Pléiades,"CC-BY-4.0",70,6,"1,122x1,186",10,MSI `BigEarthNet`_,C,Sentinel-1/2,"CDLA-Permissive-1.0","590,326",19--43,120x120,10,"SAR, MSI" `BioMassters`_,R,Sentinel-1/2 and Lidar,"CC-BY-4.0",,,256x256, 10, "SAR, MSI" +`CaBuAr`_,CD,Sentinel-2,"OpenRAIL",424,2,512x512,20,MSI +`CaFFe`_,S,"Sentinel-1, TerraSAR-X, TanDEM-X, ENVISAT, ERS-1/2, ALOS PALSAR, and RADARSAT-1","CC-BY-4.0","19092","2 or 4","512x512",6-20,"SAR" `ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI `Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI `COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB @@ -10,12 +12,16 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI `DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB `DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB +`Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared `ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR `EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI `FAIR1M`_,OD,Gaofen/Google Earth,"CC-BY-NC-SA-3.0","15,000",37,"1,024x1,024",0.3--0.8,RGB +`Fields Of The World`_,"S,I",Sentinel-2,"Various","70795","2,3",256x256,10,MSI `FireRisk`_,C,NAIP Aerial,"CC-BY-NC-4.0","91,872",7,"320x320",1,RGB `Forest Damage`_,OD,Drone imagery,"CDLA-Permissive-1.0","1,543",4,"1,500x1,500",,RGB +`GeoNRW`_,S,Aerial,"CC-BY-4.0","7,783",11,"1,000x1,000",1,"RGB, DEM" `GID-15`_,S,Gaofen-2,-,150,15,"6,800x7,200",3,RGB +`HySpecNet-11k`_,-,EnMAP,CC0-1.0,11k,-,128,30,HSI `IDTReeS`_,"OD,C",Aerial,"CC-BY-4.0",591,33,200x200,0.1--1,RGB `Inria Aerial Image Labeling`_,S,Aerial,-,360,2,"5,000x5,000",0.3,RGB `LandCover.ai`_,S,Aerial,"CC-BY-NC-SA-4.0","10,674",5,512x512,0.25--0.5,RGB @@ -24,19 +30,22 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `LoveDA`_,S,Google Earth,"CC-BY-NC-SA-4.0","5,987",7,"1,024x1,024",0.3,RGB `MapInWild`_,S,"Sentinel-1/2, ESA WorldCover, NOAA VIIRS DNB","CC-BY-4.0",1018,1,1920x1920,10--463.83,"SAR, MSI, 2020_Map, avg_rad" `Million-AID`_,C,Google Earth,-,1M,51--73,,0.5--153,RGB +`MMEarth`_,"C, S","Aster, Sentinel, ERA5","CC-BY-4.0","100K--1M",,"128x128 or 64x64",10,MSI `NASA Marine Debris`_,OD,PlanetScope,"Apache-2.0",707,1,256x256,3,RGB `OSCD`_,CD,Sentinel-2,"CC-BY-4.0",24,2,"40--1,180",60,MSI `PASTIS`_,I,Sentinel-1/2,"CC-BY-4.0","2,433",19,128x128xT,10,MSI -`PatternNet`_,C,Google Earth,-,"30,400",38,256x256,0.06--5,RGB +`PatternNet`_,C,Google Earth,"CC-BY-4.0","30,400",38,256x256,0.06--5,RGB `Potsdam`_,S,Aerial,-,38,6,"6,000x6,000",0.05,MSI `QuakeSet`_,"C, R",Sentinel-1,"OpenRAIL","3,327",2,512x512,10,SAR `ReforesTree`_,"OD, R",Aerial,"CC-BY-4.0",100,6,"4,000x4,000",0.02,RGB -`RESISC45`_,C,Google Earth,"CC-BY-NC-4.0","31,500",45,256x256,0.2--30,RGB +`RESISC45`_,C,Google Earth,-,"31,500",45,256x256,0.2--30,RGB `Rwanda Field Boundary`_,S,Planetscope,"NICFI AND CC-BY-4.0",70,2,256x256,4.7,RGB + NIR +`SatlasPretrain`_,"C, R, S, I, OD","NAIP, Landsat, Sentinel",ESA AND CC0-1.0 AND ODbL-1.0 AND CC-BY-4.0,302M,137,512,0.6--30,"SAR, MSI" `Seasonal Contrast`_,T,Sentinel-2,"CC-BY-4.0",100K--1M,-,264x264,10,MSI `SeasoNet`_,S,Sentinel-2,"CC-BY-4.0","1,759,830",33,120x120,10,MSI `SEN12MS`_,S,"Sentinel-1/2, MODIS","CC-BY-4.0","180,662",33,256x256,10,"SAR, MSI" `SKIPP'D`_,R,"Fish-eye","CC-BY-4.0","363,375",-,64x64,-,RGB +`SkyScript`_,IC,"NAIP, orthophotos, Planet SkySat, Sentinel-2, Landsat 8--9",MIT,5.2M,-,100--1000,0.1--30,RGB `So2Sat`_,C,Sentinel-1/2,"CC-BY-4.0","400,673",17,32x32,10,"SAR, MSI" `SpaceNet`_,I,WorldView-2/3 Planet Lab Dove,"CC-BY-SA-4.0","1,889--28,728",2,102--900,0.5--4,MSI `SSL4EO`_-L,T,Landsat,"CC0-1.0",1M,-,264x264,30,MSI @@ -44,6 +53,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `SSL4EO-L Benchmark`_,S,Lansat & CDL,"CC0-1.0",25K,134,264x264,30,MSI `SSL4EO-L Benchmark`_,S,Lansat & NLCD,"CC0-1.0",25K,17,264x264,30,MSI `SustainBench Crop Yield`_,R,MODIS,"CC-BY-SA-4.0",11k,-,32x32,-,MSI +`TreeSatAI`_,"C, R, S","Aerial, Sentinel-1/2",CC-BY-4.0,50K,"12, 15, 20","6, 20, 304","0.2, 10","CIR, MSI, SAR" `Tropical Cyclone`_,R,GOES 8--16,"CC-BY-4.0","108,110",-,256x256,4K--8K,MSI `UC Merced`_,C,USGS National Map,"public domain","2,100",21,256x256,0.3,RGB `USAVars`_,R,NAIP Aerial,"CC-BY-4.0",100K,-,-,4,"RGB, NIR" diff --git a/docs/api/misc_pretrained_weights.csv b/docs/api/misc_pretrained_weights.csv deleted file mode 100644 index 4088b91b99f..00000000000 --- a/docs/api/misc_pretrained_weights.csv +++ /dev/null @@ -1,2 +0,0 @@ -Weight,Channels,Source,Citation,License -ResNet50_Weights.FMOW_RGB_GASSL, 3,`link `__,`link `__,- diff --git a/docs/api/models.rst b/docs/api/models.rst index bfc51cb41ed..5d8aeb84097 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -10,6 +10,15 @@ Change Star .. autoclass:: ChangeStarFarSeg .. autoclass:: ChangeMixin +CROMA +^^^^^ + +.. autoclass:: CROMA +.. autofunction:: croma_base +.. autofunction:: croma_large +.. autoclass:: CROMABase_Weights +.. autoclass:: CROMALarge_Weights + DOFA ^^^^ @@ -47,13 +56,23 @@ ResNet .. autofunction:: resnet18 .. autofunction:: resnet50 +.. autofunction:: resnet152 .. autoclass:: ResNet18_Weights .. autoclass:: ResNet50_Weights +.. autoclass:: ResNet152_Weights + +Scale-MAE +^^^^^^^^^ + +.. autofunction:: ScaleMAE +.. autoclass:: ScaleMAELarge16_Weights Swin Transformer ^^^^^^^^^^^^^^^^^^ +.. autofunction:: swin_v2_t .. autofunction:: swin_v2_b +.. autoclass:: Swin_V2_T_Weights .. autoclass:: Swin_V2_B_Weights Vision Transformer @@ -74,36 +93,38 @@ Utility Functions Pretrained Weights ^^^^^^^^^^^^^^^^^^ +TorchGeo provides a number of pre-trained models and backbones, allowing you to perform transfer learning on small datasets without training a new model from scratch or relying on ImageNet weights. Depending on the satellite/sensor where your data comes from, choose from the following pre-trained weights based on which one has the best performance metrics. + Sensor-Agnostic --------------- -These weights can be used with imagery from any satellite/sensor. +These weights can be used with imagery from any satellite/sensor. In addition to the usual performance metrics, there are also additional columns for dynamic spatial (resolution), temporal (time span), and/or spectral (wavelength) support, either via their training data (implicit) or via their model architecture (explicit). .. csv-table:: - :widths: 45 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 + :widths: 45 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 :header-rows: 1 :align: center - :file: agnostic_pretrained_weights.csv + :file: weights/agnostic.csv -NAIP ----- +Landsat +------- .. csv-table:: - :widths: 45 10 10 10 10 + :widths: 65 10 10 10 10 10 10 10 10 10 :header-rows: 1 :align: center - :file: naip_pretrained_weights.csv + :file: weights/landsat.csv -Landsat -------- +NAIP +---- .. csv-table:: - :widths: 65 10 10 10 10 10 10 10 10 10 + :widths: 45 10 10 10 10 :header-rows: 1 :align: center - :file: landsat_pretrained_weights.csv + :file: weights/naip.csv Sentinel-1 @@ -113,7 +134,7 @@ Sentinel-1 :widths: 45 10 10 10 10 :header-rows: 1 :align: center - :file: sentinel1_pretrained_weights.csv + :file: weights/sentinel1.csv Sentinel-2 @@ -123,13 +144,4 @@ Sentinel-2 :widths: 45 10 10 10 10 15 10 10 10 :header-rows: 1 :align: center - :file: sentinel2_pretrained_weights.csv - -Other Data Sources ------------------- - -.. csv-table:: - :widths: 45 10 10 10 1 - :header-rows: 1 - :align: center - :file: misc_pretrained_weights.csv + :file: weights/sentinel2.csv diff --git a/docs/api/sentinel2_pretrained_weights.csv b/docs/api/sentinel2_pretrained_weights.csv deleted file mode 100644 index 48869d42a30..00000000000 --- a/docs/api/sentinel2_pretrained_weights.csv +++ /dev/null @@ -1,12 +0,0 @@ -Weight,Channels,Source,Citation,License,BigEarthNet,EuroSAT,So2Sat,OSCD -ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",,,, -ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,"CC-BY-4.0",,,, -ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,"Apache-2.0",87.27,93.14,,46.94 -ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,"CC-BY-4.0",90.7,99.1,63.6, -ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",91.8,99.1,60.9, -ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,"CC-BY-4.0",,, -ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,"Apache-2.0",87.81,,, -ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,"CC-BY-4.0",90.5,99.0,62.2, -ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",89.9,98.6,61.6, -Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS,3,`link `__,`link `__,"ODC-BY",,,, -Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS,9,`link `__,`link `__,"ODC-BY",,,, diff --git a/docs/api/weights/agnostic.csv b/docs/api/weights/agnostic.csv new file mode 100644 index 00000000000..93f1d1c5c26 --- /dev/null +++ b/docs/api/weights/agnostic.csv @@ -0,0 +1,6 @@ +Weight,Source,Citation,License,Spatial,Temporal,Spectral,m-bigearthnet,m-forestnet,m-brick-kiln,m-pv4ger,m-so2sat,m-eurosat,m-pv4ger-seg,m-nz-cattle,m-NeonTree,m-cashew-plant,m-SA-crop,m-chesapeake +CROMA,`link `__,`link `__,CC-BY-4.0,implicit,-,implicit,,,,,,,,,,,, +DOFABase16_Weights.DOFA_MAE,`link `__,`link `__,CC-BY-4.0,implicit,-,explicit,65.7,50.9,95.8,96.9,55.1,93.9,94.5,81.4,58.8,51.5,33.0,65.3 +DOFALarge16_Weights.DOFA_MAE,`link `__,`link `__,CC-BY-4.0,implicit,-,explicit,67.5,54.6,96.9,97.3,60.1,97.1,95.0,81.8,59.4,56.9,32.1,66.3 +ResNet50_Weights.FMOW_RGB_GASSL,`link `__,`link `__,-,implicit,-,-,,,,,,,,,,,, +ScaleMAE_ViTLarge16_Weights.FMOW_RGB_SCALEMAE,`link `__,`link `__,CC-BY-NC-4.0,explicit,-,-,,,,,,,,,,, diff --git a/docs/api/landsat_pretrained_weights.csv b/docs/api/weights/landsat.csv similarity index 96% rename from docs/api/landsat_pretrained_weights.csv rename to docs/api/weights/landsat.csv index faf3c286dc5..bfc90b651e0 100644 --- a/docs/api/landsat_pretrained_weights.csv +++ b/docs/api/weights/landsat.csv @@ -29,4 +29,5 @@ ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,"CC0-1.0",63.65,46.68,60.01,43.17 ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link `__,`link `__,"CC0-1.0",66.81,50.16,64.17,47.24 ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link `__,`link `__,"CC0-1.0",65.04,48.20,62.61,45.46 -Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,`link `__,`link `__,"ODC-BY",,,, +Swin_V2_B_Weights.LANDSAT_SI_SATLAS,8--9,11,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.LANDSAT_MI_SATLAS,8--9,11,`link `__,`link `__,ODC-BY,,,, diff --git a/docs/api/naip_pretrained_weights.csv b/docs/api/weights/naip.csv similarity index 54% rename from docs/api/naip_pretrained_weights.csv rename to docs/api/weights/naip.csv index e8e8ef14b8b..7dfe84d21dc 100644 --- a/docs/api/naip_pretrained_weights.csv +++ b/docs/api/weights/naip.csv @@ -1,2 +1,3 @@ Weight,Channels,Source,Citation,License -Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link `__,`link `__,"ODC-BY" +Swin_V2_B_Weights.NAIP_RGB_MI_SATLAS,3,`link `__,`link `__,ODC-BY +Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link `__,`link `__,ODC-BY diff --git a/docs/api/sentinel1_pretrained_weights.csv b/docs/api/weights/sentinel1.csv similarity index 52% rename from docs/api/sentinel1_pretrained_weights.csv rename to docs/api/weights/sentinel1.csv index 05d623ccb10..82ed045f149 100644 --- a/docs/api/sentinel1_pretrained_weights.csv +++ b/docs/api/weights/sentinel1.csv @@ -1,3 +1,5 @@ Weight,Channels,Source,Citation,License +ResNet50_Weights.SENTINEL1_ALL_DECUR, 2,`link `__,`link `__,"Apache-2.0" ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link `__,`link `__,"CC-BY-4.0" -Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link `__,`link `__,"ODC-BY" +Swin_V2_B_Weights.SENTINEL1_MI_SATLAS,2,`link `__,`link `__,ODC-BY +Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link `__,`link `__,ODC-BY diff --git a/docs/api/weights/sentinel2.csv b/docs/api/weights/sentinel2.csv new file mode 100644 index 00000000000..e583cfc891b --- /dev/null +++ b/docs/api/weights/sentinel2.csv @@ -0,0 +1,27 @@ +Weight,Channels,Source,Citation,License,BigEarthNet,EuroSAT,So2Sat,OSCD +ResNet18_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",,,, +ResNet18_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,"CC-BY-4.0",,,, +ResNet18_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,"Apache-2.0",87.27,93.14,,46.94 +ResNet50_Weights.SENTINEL2_ALL_DECUR,13,`link `__,`link `__,"Apache-2.0",,,, +ResNet50_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,"CC-BY-4.0",90.7,99.1,63.6, +ResNet50_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",91.8,99.1,60.9, +ResNet50_Weights.SENTINEL2_MI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +ResNet50_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +ResNet50_Weights.SENTINEL2_SI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +ResNet50_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link `__,`link `__,"CC-BY-4.0",,, +ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link `__,`link `__,"Apache-2.0",87.81,,, +ResNet152_Weights.SENTINEL2_MI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +ResNet152_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +ResNet152_Weights.SENTINEL2_SI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +ResNet152_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link `__,`link `__,"CC-BY-4.0",90.5,99.0,62.2, +ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link `__,`link `__,"CC-BY-4.0",89.9,98.6,61.6, +Swin_V2_T_Weights.SENTINEL2_MI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +Swin_V2_T_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +Swin_V2_T_Weights.SENTINEL2_SI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +Swin_V2_T_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.SENTINEL2_MI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.SENTINEL2_MI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.SENTINEL2_SI_MS_SATLAS,9,`link `__,`link `__,ODC-BY,,,, +Swin_V2_B_Weights.SENTINEL2_SI_RGB_SATLAS,3,`link `__,`link `__,ODC-BY,,,, diff --git a/docs/conf.py b/docs/conf.py index 36af7d3d275..df1de398185 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath('..')) -import torchgeo # noqa: E402 +import torchgeo # -- Project information ----------------------------------------------------- @@ -50,8 +50,7 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build'] -# Sphinx 3.0+ required for: -# autodoc_typehints_description_target = "documented" +# Sphinx 4.0+ required for autodoc_typehints_description_traget needs_sphinx = '4.0' nitpicky = True @@ -60,6 +59,8 @@ ('py:class', 'fiona.model.Feature'), ('py:class', 'kornia.augmentation._2d.intensity.base.IntensityAugmentationBase2D'), ('py:class', 'kornia.augmentation.base._AugmentationBase'), + ('py:class', 'lightning.pytorch.utilities.types.LRSchedulerConfig'), + ('py:class', 'lightning.pytorch.utilities.types.OptimizerConfig'), ('py:class', 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig'), ('py:class', 'segmentation_models_pytorch.base.model.SegmentationModel'), ('py:class', 'timm.models.resnet.ResNet'), @@ -93,7 +94,7 @@ html_favicon = os.path.join('..', 'logo', 'favicon.ico') html_static_path = ['_static'] -html_css_files = ['button-width.css', 'notebook-prompt.css', 'table-scroll.css'] +html_css_files = ['badge-height.css', 'notebook-prompt.css', 'table-scroll.css'] # -- Extension configuration ------------------------------------------------- @@ -127,40 +128,8 @@ # nbsphinx nbsphinx_execute = 'never' -# TODO: branch/tag should change depending on which version of docs you look at -# TODO: width option of image directive is broken, see: -# https://github.com/pytorch/pytorch_sphinx_theme/issues/140 -nbsphinx_prolog = """ -{% set host = "https://colab.research.google.com" %} -{% set repo = "microsoft/torchgeo" %} -{% set urlpath = "docs/" ~ env.docname ~ ".ipynb" %} -{% if "dev" in env.config.release %} - {% set branch = "main" %} -{% else %} - {% set branch = "releases/v" ~ env.config.version %} -{% endif %} - -.. image:: {{ host }}/assets/colab-badge.svg - :class: colabbadge - :alt: Open in Colab - :target: {{ host }}/github/{{ repo }}/blob/{{ branch }}/{{ urlpath }} - -{% set host = "https://pccompute.westeurope.cloudapp.azure.com" %} -{% set host = host ~ "/compute/hub/user-redirect/git-pull" %} -{% set repo = "https%3A%2F%2Fgithub.com%2Fmicrosoft%2Ftorchgeo" %} -{% set urlpath = "tree%2Ftorchgeo%2Fdocs%2F" %} -{% set urlpath = urlpath ~ env.docname | replace("/", "%2F") ~ ".ipynb" %} -{% if "dev" in env.config.release %} - {% set branch = "main" %} -{% else %} - {% set branch = "releases%2Fv" ~ env.config.version %} -{% endif %} - -.. image:: https://img.shields.io/badge/-Open%20on%20Planetary%20Computer-blue - :class: colabbadge - :alt: Open on Planetary Computer - :target: {{ host }}?repo={{ repo }}&urlpath={{ urlpath }}&branch={{ branch }} -""" +with open(os.path.join('tutorials', 'prolog.rst.jinja')) as f: + nbsphinx_prolog = f.read() # Disables requirejs in nbsphinx to enable compatibility with the pytorch_sphinx_theme # See more information here https://github.com/spatialaudio/nbsphinx/issues/599 diff --git a/docs/index.rst b/docs/index.rst index ced959493a8..02ba03a8cc5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,6 +12,15 @@ torchgeo user/glossary user/alternatives +.. toctree:: + :maxdepth: 2 + :caption: Tutorials + + tutorials/getting_started + tutorials/basic_usage + tutorials/case_studies + tutorials/customization + .. toctree:: :maxdepth: 2 :caption: Package Reference @@ -24,17 +33,6 @@ torchgeo api/trainers api/transforms -.. toctree:: - :maxdepth: 1 - :caption: Tutorials - - tutorials/getting_started - tutorials/custom_raster_dataset - tutorials/transforms - tutorials/indices - tutorials/trainers - tutorials/pretrained_weights - .. toctree:: :maxdepth: 1 :caption: PyTorch Libraries diff --git a/docs/tutorials/basic_usage.rst b/docs/tutorials/basic_usage.rst new file mode 100644 index 00000000000..51eb77a85ca --- /dev/null +++ b/docs/tutorials/basic_usage.rst @@ -0,0 +1,20 @@ +Basic Usage +=========== + +The following tutorials introduce the basic concepts and components of TorchGeo: + +* `Transforms `_: Preprocessing and data augmentation transforms for geospatial data +* `Spectral Indices `_: Visualizing and appending spectral indices +* `Pretrained Weights `_: Models and pretrained weights +* `Lightning Trainers `_: PyTorch Lightning data modules and trainers +* `Command-Line Interface `_: TorchGeo's command-line interface + +.. toctree:: + :hidden: + :maxdepth: 1 + + transforms + indices + pretrained_weights + trainers + cli diff --git a/docs/tutorials/case_studies.rst b/docs/tutorials/case_studies.rst new file mode 100644 index 00000000000..aee09a94265 --- /dev/null +++ b/docs/tutorials/case_studies.rst @@ -0,0 +1,14 @@ +Case Studies +============ + +The following case studies present end-to-end workflows for common use cases of geospatial machine learning: + +* `Earth Surface Water `_: A workflow for mapping surface water, including lakes and rivers + +Do you have a use case that is missing from this list? Please open a pull request to add tutorials for your own use cases. + +.. toctree:: + :hidden: + :maxdepth: 1 + + earth_surface_water diff --git a/docs/tutorials/cli.ipynb b/docs/tutorials/cli.ipynb new file mode 100644 index 00000000000..1424b13291a --- /dev/null +++ b/docs/tutorials/cli.ipynb @@ -0,0 +1,292 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "16421d50-8d7a-4972-b06f-160fd890cc86", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "id": "e563313d", + "metadata": {}, + "source": [ + "# Command-Line Interface\n", + "\n", + "_Written by: Adam J. Stewart_\n", + "\n", + "TorchGeo provides a command-line interface based on [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface." + ] + }, + { + "cell_type": "markdown", + "id": "8c1f4156", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, we install TorchGeo. In addition to the Python library, this also installs a `torchgeo` executable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f0d31a8", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "cell_type": "markdown", + "id": "7801ab8b-0ee3-40ac-88c2-4bdc29bb4e1b", + "metadata": {}, + "source": [ + "## Subcommands\n", + "\n", + "The `torchgeo` command has a number of *subcommands* that can be run. The `--help` flag can be used to list them." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6ccac4e-7f20-4aa8-b851-27234ffd259f", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo --help" + ] + }, + { + "cell_type": "markdown", + "id": "19ee017d-0d8f-41c6-8e7c-68495c7e62b6", + "metadata": {}, + "source": [ + "## Trainer\n", + "\n", + "Below, we run `--help` on the `fit` subcommand to see what options are available to us. `fit` is used to train and validate a model, and we can customize many aspects of the training process." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afe1dc9d-4cee-43b0-ae30-200c64d3401a", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo fit --help" + ] + }, + { + "cell_type": "markdown", + "id": "b437860c-b406-4150-b30b-8aa895eebfcd", + "metadata": {}, + "source": [ + "## Model\n", + "\n", + "We must first select an `nn.Module` model architecture to train and a `lightning.pytorch.LightningModule` trainer to train it. We will experiment with the `ClassificationTask` trainer and see what options we can customize. Any of TorchGeo's builtin trainers, or trainers written by the user, can be used in this way." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cd9bbd0-17c9-4e87-b10d-ea846c39bc24", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo fit --model.help ClassificationTask" + ] + }, + { + "cell_type": "markdown", + "id": "3daacd8d-64f4-4357-bdf3-759295a14224", + "metadata": {}, + "source": [ + "## Data\n", + "\n", + "We must also select a `Dataset` we would like to train on and a `lightning.pytorch.LightningDataModule` we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the `--help` flag to see what options are available for the `EuroSAT100` dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "136eb59f-6662-44af-82e9-c55bdb3f17ac", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo fit --data.help EuroSAT100DataModule" + ] + }, + { + "cell_type": "markdown", + "id": "8039cb67-ee18-4b41-8bf5-0e939493f5bb", + "metadata": {}, + "source": [ + "## Config\n", + "\n", + "Now that we have seen all important configuration options, we can put them together in a YAML file. LightingCLI supports YAML, JSON, and command-line configuration. While we will write this file using Python in this tutorial, normally this file would be written in your favorite text editor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e25c8efb-ed8c-4795-862c-bfb84cc84e1f", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", + "config = f\"\"\"\n", + "trainer:\n", + " max_epochs: 1\n", + " default_root_dir: '{root}'\n", + "model:\n", + " class_path: ClassificationTask\n", + " init_args:\n", + " model: 'resnet18'\n", + " in_channels: 13\n", + " num_classes: 10\n", + "data:\n", + " class_path: EuroSAT100DataModule\n", + " init_args:\n", + " batch_size: 8\n", + " dict_kwargs:\n", + " root: '{root}'\n", + " download: true\n", + "\"\"\"\n", + "os.makedirs(root, exist_ok=True)\n", + "with open(os.path.join(root, 'config.yaml'), 'w') as f:\n", + " f.write(config)" + ] + }, + { + "cell_type": "markdown", + "id": "a661b8d7-2dc9-4a30-8842-bd52d130e080", + "metadata": {}, + "source": [ + "This YAML file has three sections:\n", + "\n", + "* trainer: Arguments to pass to the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)\n", + "* model: Arguments to pass to the task\n", + "* data: Arguments to pass to the data module\n", + "\n", + "The `class_path` gives the class to instantiate, `init_args` lists standard arguments, and `dict_kwargs` lists keyword arguments." + ] + }, + { + "cell_type": "markdown", + "id": "e132f933-4edf-42bb-b585-e0d8ceb65eab", + "metadata": {}, + "source": [ + "## Training\n", + "\n", + "We can now train our model like so." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f84b0739-c9e7-4057-8864-98ab69a11f64", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo fit --config {root}/config.yaml" + ] + }, + { + "cell_type": "markdown", + "id": "cb1557f1-6cc0-46da-909c-836911acb248", + "metadata": {}, + "source": [ + "## Validation\n", + "\n", + "Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9cbb4f4-1879-4ae7-bae4-2c24d49a4a61", + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "\n", + "checkpoint = glob.glob(\n", + " os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')\n", + ")[0]\n", + "\n", + "!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}" + ] + }, + { + "cell_type": "markdown", + "id": "ba816fc3-5cac-4cbc-a6ef-effc6c9faa61", + "metadata": {}, + "source": [ + "## Testing\n", + "\n", + "After finishing our hyperparameter tuning, we can calculate and report the final test performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1faa997-9f81-4847-94fc-5a8bb7687369", + "metadata": {}, + "outputs": [], + "source": [ + "!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}" + ] + }, + { + "cell_type": "markdown", + "id": "f5383d30-8f76-44a2-8366-e6fcbd1e6042", + "metadata": {}, + "source": [ + "## Additional Reading\n", + "\n", + "Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials:\n", + "\n", + "* [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "execution": { + "timeout": 1200 + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/tutorials/contribute_datamodule.ipynb b/docs/tutorials/contribute_datamodule.ipynb new file mode 100644 index 00000000000..3182213c510 --- /dev/null +++ b/docs/tutorials/contribute_datamodule.ipynb @@ -0,0 +1,250 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Contribute a New DataModule" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_Written by: Nils Lehmann_\n", + "\n", + "TorchGeo provides Lightning `DataModules` and trainers to faciliate easy and scalabel model training based on simple configuration files. Essentially, a `DataModule` implements the logic for splitting a dataset into train, validation and test splits for reproducability, wrapping them in PyTorch `DataLoaders` and apply augmentations to batches of data. This tutorial will outline a guide to adding a new datamodule to TorchGeo. It is often easy to do so alongside a new dataset and will make the dataset directly useable for a Lightning training and evaluation pipeline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding the datamodule" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Adding a datamodule to TorchGeo consists of roughly four parts:\n", + "\n", + "1. a `dataset_name.py` file under `torchgeo/datamodules` that implements the split logic and defines augmentation\n", + "2. a `dataset_name.yaml` file under `tests/configs` that defines arguments to directly test the datamodule with the appropriate task\n", + "3. add the above yaml file to the list of files to be tested in the corresponding `test_{task}.py` file under `tests/trainers`\n", + "4. an entry to the documentation page file `datamodules.rst` under `docs/api/`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The datamodule `dataset_name.py` file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The vast majority of new DataModules can inherit from one of the base classes that take care of the majority of the work. The goal of the dataset specific DataModule is to specify how the dataset should be split into train/val/test and any augmentations that should be applied to batches of data.\n", + "\n", + "\n", + "```python\n", + "\n", + "\"\"\"NewDatasetDataModule datamodule.\"\"\"\n", + "\n", + "import os\n", + "from typing import Any\n", + "\n", + "import kornia.augmentation as K\n", + "import torch\n", + "from torch.utils.data import Subset\n", + "\n", + "from .geo import NonGeoDataModule\n", + "from .utils import group_shuffle_split\n", + "\n", + "\n", + "# We follow the convention of appending the dataset_name with \"DataModule\"\n", + "class NewDatasetDataModule(NonGeoDataModule):\n", + " \"\"\"LightningDataModule implementation for the NewDataset dataset.\n", + "\n", + " Make a comment here about how the dataset is split into train/val/test.\n", + "\n", + " You can also add any other comments or references that are helpful to \n", + " understand implementation decisions\n", + "\n", + " .. versionadded:: for example 0.7\n", + " \"\"\"\n", + " # you can define channelwise normalization statistics that will be applied\n", + " # to data batches, which is usually crucial for training stability and decent performance\n", + " mean = torch.Tensor([0.5, 0.4, 0.3])\n", + " std = torch.Tensor([1.5, 1.4, 1.3])\n", + "\n", + " def __init__(\n", + " self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any\n", + " ) -> None:\n", + " \"\"\"Initialize a new NewDatasetModule instance.\n", + "\n", + " Args:\n", + " batch_size: Size of each mini-batch.\n", + " num_workers: Number of workers for parallel data loading.\n", + " size: resize images of input size 1000x1000 to size x size\n", + " **kwargs: Additional keyword arguments passed to\n", + " :class:`~torchgeo.datasets.NewDataset`.\n", + " \"\"\"\n", + " # in the init method of the base class the dataset will be instantiated with **kwargs\n", + " super().__init__(NewDatasetName, batch_size, num_workers, **kwargs)\n", + "\n", + " # you can specify a series of Kornia augmentations that will be\n", + " # applied to a batch of training data in `on_after_batch_transfer` in the NonGeoDataModule base class\n", + " self.train_aug = K.AugmentationSequential(\n", + " K.Resize(size),\n", + " K.Normalize(self.mean, self.std),\n", + " K.RandomHorizontalFlip(p=0.5),\n", + " K.RandomVerticalFlip(p=0.5),\n", + " data_keys=None,\n", + " keepdim=True,\n", + " )\n", + "\n", + " # you can also define specific augmentations for other experiment phases, if not specified\n", + " # self.aug Augmentations will be applied\n", + " self.aug = K.AugmentationSequential(\n", + " K.Normalize(self.mean, self.std),\n", + " K.Resize(size), data_keys=None, keepdim=True\n", + " )\n", + "\n", + " self.size = size\n", + "\n", + " # setup defines how the dataset should be split\n", + " # this could either be predefined from the dataset authors or\n", + " # done in a prescribed way if some or no splits are specified\n", + " def setup(self, stage: str) -> None:\n", + " \"\"\"Set up datasets.\n", + "\n", + " Args:\n", + " stage: Either 'fit', 'validate', 'test', or 'predict'.\n", + " \"\"\"\n", + " if stage in ['fit', 'validate']:\n", + " dataset = NewDatasetName(split='train', **self.kwargs)\n", + " # perhaps the dataset contains some geographical metadata based on which you would create reproducible random\n", + " # splits\n", + " grouping_paths = [os.path.dirname(path) for path in dataset.file_list]\n", + " train_indices, val_indices = group_shuffle_split(\n", + " grouping_paths, test_size=0.2, random_state=0\n", + " )\n", + " self.train_dataset = Subset(dataset, train_indices)\n", + " self.val_dataset = Subset(dataset, val_indices)\n", + " if stage in ['test']:\n", + " self.test_dataset = NewDatasetName(split='test', **self.kwargs)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Linters\n", + "\n", + "See the [linter docs](https://torchgeo.readthedocs.io/en/stable/user/contributing.html#linters) for an overview of linters that TorchGeo employs and how to apply them during commits for example. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Unit tests" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TorchGeo maintains a test coverage of 100%. This means, that every line of code written within the torchgeo directory is being called by some unit test. For new datasets, we commonly write a separate test file, however, for datamodules we would like to test them directly with one of the task trainers. To do this, you simply need to define a `config.yaml` file and add it to the list of files to be tested by a task. For example, if you added a new datamodule for image segmentation you would write a config file that should look something like this:\n", + "\n", + "```yaml\n", + "model:\n", + " class_path: SemanticSegmentationTask\n", + " init_args:\n", + " loss: 'ce'\n", + " model: 'unet'\n", + " backbone: 'resnet18'\n", + " in_channels: 3 # number of input channels for the dataset\n", + " num_classes: 7 # number of segmentation models\n", + " num_filters: 1 # a smaller model version for faster unit tests\n", + " ignore_index: null # one can ignore certain classes during the loss computation\n", + "data:\n", + " class_path: NewDatasetNameDataModule # arguments to the DataModule above you wrote\n", + " init_args:\n", + " batch_size: 1 # \n", + " dict_kwargs:\n", + " root: 'tests/data/deepglobelandcover' # necessary arguments for the underlying dataset class that the datamodule builds on\n", + "```\n", + "\n", + "The yaml file should \"simulate\" how you would use this datamodule for an actual experiment. Add this file with `dataset_name.yaml` to the `tests/conf` directory." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Final Checklist" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This final checklist might provide a useful overview of the individual parts discussed in this tutorial. You definitely do not need to check all boxes, before submitting a PR. If you have any questions feel free to ask in the Slack channel or open a PR already such that maintainers or other community members can answer specific questions or give pointers. If you want to run your PR as a work of progress, such that the CI tests are run against your code while you work on ticking more boxes you can also convert the PR to a draft on the right side." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- The datamodule implementation\n", + " - define training/val/test split\n", + " - if there are dataset specific augmentations, implement and reference them\n", + " - add microsoft copyright notice to top of the file\n", + "- The config test file\n", + " - select the appropriate task, if the dataset supports multiple ones, you can create one for each task\n", + " - correct arguments such as the number of targets (classes)\n", + " - add the config file to the list of files to be tested in the corresponding `test_{task}.py` file under `tests/trainers`\n", + "- Unit Tests\n", + " - 100% test coverage\n", + "- Documentation\n", + " - an entry to the documentation page file `datamodules.rst` under `docs/api/`" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/tutorials/contribute_non_geo_dataset.ipynb b/docs/tutorials/contribute_non_geo_dataset.ipynb new file mode 100644 index 00000000000..2ea23365a41 --- /dev/null +++ b/docs/tutorials/contribute_non_geo_dataset.ipynb @@ -0,0 +1,491 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Contribute a New Non-Geospatial Dataset\n", + "\n", + "_Written by: Nils Lehmann_\n", + "\n", + "Open-source datasets have significantly accelerated machine learning research. Geospatial machine learning datasets can be particularly complex to work with compared to more standard RGB-based vision datasets. To spare the community from having to repeatly implement data loading logic over and over, TorchGeo provides dozens of built-in datasets such that they can be downloaded and ready for use in a PyTorch framework with a single line of code. This tutorial will show how you can add a new non-geospatial dataset to this growing collection. \n", + "\n", + "As a reminder, TorchGeo differentiates between two types of datasets: geospatial and non-geospatial datasets. Non-geospatial datasets are integer indexed, like the datasets one might be familar with from torchvision, while geospatial datasets are indexed via spatiotemporal bounding boxes. Non-geospatial datasets can still return geospatial and other metadata and should be specific to the remote sensing domain. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, we install TorchGeo and its dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Where to start\n", + "\n", + "There are many types of remote sensing datasets. [Satellite-Image-Deep-Learning](https://github.com/satellite-image-deep-learning/datasets) maintains a list of many of these datasets, as well as links to other similar curated lists.\n", + "\n", + "Two aspects that will make it a lot easier to add the dataset are whether or not the dataset can be easily downloaded and whether or the dataset comes with a Github repository and publication that outlines how the authors intend the dataset to be used. These are not necessariy criteria, and sometimes it might be even more worthwhile to add a dataset without an existing code base, precisely because the marginal contribution to the community might be greater since a use of the dataset does not necessitate writing the loading implementation from scratch." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding the dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once you have identified a dataset that you would like to add to TorchGeo, you could identify in what application category it might roughly fall in. For example, a segmentation dataset based on a collection of *.png* files, versus a classification dataset based on pre-defined image chips in *.tif* files. In the later case, if you find that the dataset contains *.tif* files that have very large pixel sizes, such that loading a single file might be costly, consider adding the dataset as a geospatial dataset for easier indexing. Once, you have identified the \"task\" such as segmentation vs classification and the dataset format, see whether a dataset of the same or similar category exists in TorchGeo already. All datasets inherit from a `NonGeoDataset` or `GeoDataset` base class that provides an outline for the implementation logic as well as additional utility functions that should be reused. This reduces code duplication and makes it easier to unit test datasets.\n", + "\n", + "Adding a dataset to TorchGeo consists of roughly four steps:\n", + "\n", + "1. a `dataset_name.py` file itself that implements the logic of the dataset\n", + "2. a `data.py` file that creates dummy data in the same structure and format as the original dataset for unit tests\n", + "3. a `test_dataset_name.py` file that implements unit tests for the dataset\n", + "4. an entry to the documentation page files: `non_geo_datasets.csv` and `datasets.rst`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `dataset_name.py` file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This file implements the logic to load a sample from the dataset as well as downloading the dataset automatically if possible. The new dataset inherits from a base class and the documentation string (docstring) of the class should contain:\n", + "\n", + "* a short summary of the dataset\n", + "* outline the features, such as the task the dataset is designed to solve\n", + "* outline the format the dataset comes in, e.g., file types, pixel dimensions, etc.\n", + "* a proper reference to the dataset such as a link to the paper so users can adequately cite the dataset when using it\n", + "* if required, a note about additional dependencies that are not part of TorchGeo's required dependencies\n", + "\n", + "The dataset implementation itself should contain:\n", + "\n", + "* a method to create an index structure the dataset can iterate over to load samples. This index structure also defines the length (`__len__`) of the dataset, i.e. how many individual samples can be loaded from the dataset\n", + "* a `__getitem__` method that takes an integer index argument, loads a sample of the dataset, and returns its components in a dictionary\n", + "* a `_verify` method that checks whether the dataset can be found on the filesystem, has already been downloaded and only needs to be extracted, or downloads and extracts the dataset from the web\n", + "* a `plot` method that can visually display a single sample of the dataset\n", + "\n", + "The code below attempts to roughly outline the parts required for a new `NonGeoDataset`. Specifics are of course very dependent on the type of dataset you want to add, but this template and other existing datasets should give you a decent starting point." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections.abc import Callable\n", + "\n", + "from matplotlib.pyplot import Figure\n", + "from torch import Tensor\n", + "\n", + "from torchgeo.datasets import NonGeoDataset\n", + "from torchgeo.datasets.utils import Path\n", + "\n", + "\n", + "class MyNewDataset(NonGeoDataset):\n", + " \"\"\"MyNewDataset.\n", + "\n", + " Short summary of the dataset and link to its homepage.\n", + "\n", + " Dataset features:\n", + "\n", + " * number of classes\n", + " * sensors\n", + " * area covered\n", + " * etc.\n", + "\n", + " Dataset format:\n", + "\n", + " * what file format and shape the input data comes in\n", + " * what file format and shape the target data comes in\n", + " * possible metadata files\n", + "\n", + " If you use this dataset in your research, please cite the following paper:\n", + "\n", + " * URL of publication or citation information\n", + "\n", + " .. versionadded: next TorchGeo minor release version, e.g., 1.0\n", + " \"\"\"\n", + "\n", + " # In this part of the code you can define class attributes such as a list of\n", + " # class names, color maps, url and checksums for data download, and other\n", + " # attributes that one might require repeatedly in the subsequent class methods.\n", + "\n", + " def __init__(\n", + " self,\n", + " root: Path = 'data',\n", + " split: str = 'train',\n", + " transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,\n", + " download: bool = False,\n", + " ) -> None:\n", + " \"\"\"Initialize the dataset.\n", + "\n", + " The init parameters can include additional arguments, such as an option to\n", + " select specific image bands, data modalities, or other arguments that give\n", + " greater control over data loading. They should all have reasonable defaults.\n", + "\n", + " Args:\n", + " root: root directory where dataset can be found\n", + " split: one of \"train\", \"val\", or \"test\"\n", + " transforms: a function/transform that takes input sample and its target as\n", + " entry and returns a transformed version\n", + " download: if True, download dataset and store it in the root directory\n", + " \"\"\"\n", + "\n", + " def __len__(self) -> int:\n", + " \"\"\"The length of the dataset.\n", + "\n", + " This is the total number of samples per epoch, and is used to define the\n", + " maximum allow index that can be passed to `__getitem__`.\n", + " \"\"\"\n", + "\n", + " def __getitem__(self, index: int) -> dict[str, Tensor]:\n", + " \"\"\"A single sample from the dataset.\n", + "\n", + " Load a single input image and target label or mask, and return it in a\n", + " dictionary.\n", + " \"\"\"\n", + "\n", + " def plot(self) -> Figure:\n", + " \"\"\"Plot a sample of the dataset for visualization purposes.\n", + "\n", + " This might involve selecting the RGB bands, using a colormap to display a mask,\n", + " adding a legend with class labels, etc.\n", + " \"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `data.py` file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `data.py` file is placed under `tests/data/dataset_name/` directory and creates a smaller dummy dataset that replicates the features and formats of the actual full datasets for unit tests. This is needed to keep the tests fast (we don't have time or storage space to download the real dataset) and to comply with the dataset license. \n", + "\n", + "The script should:\n", + "\n", + "* replicate the directory structure and file names\n", + "* replicate the file format, data type, and range of values\n", + "* use the same compression scheme to simulate downloading the dataset\n", + "\n", + "This is usually highly dependent on the dataset format and structure the new dataset comes in. You should always look for a similar dataset first and use that as a reference. However, below is an outline of the usual building blocks of a `data.py` script, for example an image segmentation dataset with 10 classes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "import tempfile\n", + "\n", + "import numpy as np\n", + "from PIL import Image\n", + "\n", + "# Define the root directory and subdirectories\n", + "# Normally this would be the current directory (tests/data/my_new_dataset)\n", + "root_dir = os.path.join(tempfile.gettempdir(), 'my_new_dataset')\n", + "sub_dirs = ['image', 'target']\n", + "splits = ['train', 'val', 'test']\n", + "\n", + "image_file_names = ['sample_1.png', 'sample_2.png', 'sample_3.png']\n", + "\n", + "IMG_SIZE = 32\n", + "\n", + "\n", + "# Function to create dummy input images\n", + "def create_input_image(path: str, shape: tuple[int], pixel_values: list[int]) -> None:\n", + " data = np.random.choice(pixel_values, size=shape, replace=True).astype(np.uint8)\n", + " img = Image.fromarray(data)\n", + " img.save(path)\n", + "\n", + "\n", + "# Function to create dummy targets\n", + "def create_target_images(split: str, filename: str) -> None:\n", + " target_pixel_values = range(10)\n", + " path = os.path.join(root_dir, 'target', split, filename)\n", + " create_input_image(path, (IMG_SIZE, IMG_SIZE), target_pixel_values)\n", + "\n", + "\n", + "# Create a new clean version when re-running the script\n", + "if os.path.exists(root_dir):\n", + " shutil.rmtree(root_dir)\n", + "\n", + "# Create the directory structure\n", + "for sub_dir in sub_dirs:\n", + " for split in splits:\n", + " os.makedirs(os.path.join(root_dir, sub_dir, split), exist_ok=True)\n", + "\n", + "# Create dummy data for all splits and filenames\n", + "for split in splits:\n", + " for filename in image_file_names:\n", + " create_input_image(\n", + " os.path.join(root_dir, 'image', split, filename),\n", + " (IMG_SIZE, IMG_SIZE),\n", + " range(2**16),\n", + " )\n", + " create_target_images(split, filename.replace('_', '_target_'))\n", + "\n", + "# Zip directory\n", + "shutil.make_archive(root_dir, 'zip', '.', root_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The `test_dataset_name.py` file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `test_dataset_name.py` file is placed under the `tests/datasets/` directory. This file implements the unit tests for the dataset, such that every line of code in `dataset_name.py` is tested. The logic of the individual test cases will likely be very similar to existing test files so you can look at those to to see how you can test the individual parts of the dataset logic." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import shutil\n", + "from pathlib import Path\n", + "\n", + "import pytest\n", + "import torch\n", + "import torch.nn as nn\n", + "from _pytest.fixtures import SubRequest\n", + "from matplotlib import pyplot as plt\n", + "from pytest import MonkeyPatch\n", + "\n", + "from torchgeo.datasets import DatasetNotFoundError\n", + "\n", + "\n", + "def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:\n", + " shutil.copy(url, root)\n", + "\n", + "\n", + "class TestMyNewDataset:\n", + " # pytest fixtures can be used to define variables to test different argument\n", + " # configurations to test, for example the different splits of the dataset\n", + " # or subselection of modalities/bands\n", + " @pytest.fixture(params=['train', 'val', 'test'])\n", + " def dataset(\n", + " self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest\n", + " ) -> MyNewDataset:\n", + " # monkeypatch can overwrite the class attributes defined above the __init__\n", + " # method and use the specific unit tests settings to mock behavior\n", + "\n", + " split: str = request.param\n", + " transforms = nn.Identity()\n", + " return MyNewDataset(tmp_path, split=split, transforms=transforms, download=True)\n", + "\n", + " def test_getitem(self, dataset: MyNewDataset) -> None:\n", + " # Retrieve a sample and check some of the desired properties\n", + " x = dataset[0]\n", + " assert isinstance(x, dict)\n", + " assert isinstance(x['image'], torch.Tensor)\n", + " assert isinstance(x['label'], torch.Tensor)\n", + "\n", + " # For all additional class arguments, check behavior for invalid parameters\n", + " def test_invalid_split(self) -> None:\n", + " with pytest.raises(AssertionError):\n", + " MyNewDataset(foo='bar')\n", + "\n", + " # Test the length of the dataset, this should coincide with the dummy data\n", + " def test_len(self, dataset: MyNewDataset) -> None:\n", + " assert len(dataset) == 2\n", + "\n", + " # Test the logic when the dataset is already downloaded\n", + " def test_already_downloaded(self, dataset: MyNewDataset, tmp_path: Path) -> None:\n", + " MyNewDataset(root=tmp_path, download=True)\n", + "\n", + " # Test the logic when the dataset is already downloaded but not extracted\n", + " def test_already_downloaded_not_extracted(\n", + " self, dataset: MyNewDataset, tmp_path: Path\n", + " ) -> None:\n", + " shutil.rmtree(dataset.root)\n", + " download_url(dataset.url, root=tmp_path)\n", + " MyNewDataset(root=tmp_path, download=False)\n", + "\n", + " # Test the logic when the dataset is not downloaded\n", + " def test_not_downloaded(self, tmp_path: Path) -> None:\n", + " with pytest.raises(DatasetNotFoundError, match='Dataset not found'):\n", + " MyNewDataset(tmp_path)\n", + "\n", + " # Test the plotting method through something like the following\n", + " def test_plot(self, dataset: MyNewDataset) -> None:\n", + " x = dataset[0].copy()\n", + " x['prediction'] = x['label'].clone()\n", + " dataset.plot(x, suptitle='Test')\n", + " plt.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Documentation Entries" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The entry point for new and experienced users of domain libraries is often the dedicated documentation page that accompanies a Github repository. TorchGeo uses the popular [Sphinx](https://www.sphinx-doc.org/en/master/) framework to build its documentation. To display the documentation strings you have written in `dataset_name.py` on the actual documentation page, you need to create an entry in `docs/api/datasets.rst` in alphabetical order:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```rst\n", + "Dataset Name\n", + "^^^^^^^^^^^^\n", + "\n", + ".. autoclass:: MyNewDataset\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additionally, add a row in the `non_geo_datasets.csv` file under `docs/api/datasets` to include the dataset in the overview table." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Linters\n", + "\n", + "See the [linter docs](https://torchgeo.readthedocs.io/en/stable/user/contributing.html#linters) for an overview of linters that TorchGeo employs and how to apply them during commits." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test Coverage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TorchGeo maintains a test coverage of 100%. This means, that every line of code written within the torchgeo directory is being run by some unit test. The [testing docs](https://torchgeo.readthedocs.io/en/stable/user/contributing.html#tests) provide instructions on how you can test the coverage locally for the `dataset_new.py` file that you are adding." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Final Checklist" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This final checklist might provide a useful overview of the individual parts discussed in this tutorial. You definitely do not need to check all boxes, before submitting a PR. If you have any questions feel free to ask on Slack or open a PR already such that maintainers or other community members can answer specific questions or give pointers. If you want to run your PR as a work of progress, such that the CI tests are run against your code while you work on ticking more boxes you can also convert the PR to a draft on the right side menu." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- Dataset implementation in `dataset_name.py`\n", + " - Class docstring containing:\n", + " - Summary intro\n", + " - Dataset features\n", + " - Dataset format\n", + " - Link to publication\n", + " - `versionadded` tag\n", + " - if applicable a note on additional dependencies\n", + " - all class methods have docstrings\n", + " - all class methods have argument and return type hints, mypy (the tool that checks type hints) can be confusing at the beginning so don't hesitate to ask for help\n", + " - if dataset is on GitHub or Huggingface, url link should contain the commit hash\n", + " - checksum added\n", + " - plot method that can display a single sample from the dataset (you can add the resulting figure in your PR description)\n", + " - add the dataset to `torchgeo/datastes/__init__.py`\n", + " - Add the copyright at the top of the file\n", + "- Dummy data script `data.py`\n", + " - replicate directory structure\n", + " - replicate naming of directory and files\n", + " - for image based datasets, use a small size, like 32x32\n", + "- Unit tests `test_dataset_name.py`\n", + " - 100% test coverage \n", + "- Documentation with `non_geo_datasets.csv` and `datasets.rst`\n", + " - entry in `datasets.rst`\n", + " - entry in `non_geo_datasets.csv`\n", + " - documentation displays properly, this can be checked locally or via the GitHub CI tests under `docs/readthedocs.org:torchgeo`" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index 1da51ec620d..f171c31363b 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -1,14 +1,15 @@ { "cells": [ { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { "id": "iiqWbXISOEAQ" }, + "outputs": [], "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." ] }, { @@ -19,6 +20,8 @@ "source": [ "# Custom Raster Datasets\n", "\n", + "_Written by: Ritwik Gupta_\n", + "\n", "In this tutorial, we'll describe how to write a custom dataset in TorchGeo. There are many types of datasets that you may encounter, from image data, to segmentation masks, to point labels. We'll focus on the most common type of dataset: a raster file containing an image or mask. Let's get started!" ] }, @@ -332,6 +335,10 @@ "\n", "If your data only contains model inputs (such as images), use `is_image = True`. If your data only contains ground truth model outputs (such as segmentation masks), use `is_image = False` instead.\n", "\n", + "Consequently, the sample returned by the dataset/data loader will use the \"image\" key if *is_image* is True, otherwise it will use the \"mask\" key.\n", + "\n", + "For datasets with both model inputs and outputs, the recommended approach is to use 2 `RasterDataset` instances and combine them using an `IntersectionDataset`. See L7 Irish, L8 Biome, and I/O Bench for examples of this in `torchgeo/datasets`.\n", + "\n", "### `dtype`\n", "\n", "Defaults to float32 for `is_image == True` and long for `is_image == False`. This is what you want for 99% of datasets, but can be overridden for tasks like pixel-wise regression (where the target mask should be float32).\n", @@ -369,8 +376,8 @@ " date_format = '%Y%m%dT%H%M%S'\n", " is_image = True\n", " separate_files = True\n", - " all_bands = ['B02', 'B03', 'B04', 'B08']\n", - " rgb_bands = ['B04', 'B03', 'B02']" + " all_bands = ('B02', 'B03', 'B04', 'B08')\n", + " rgb_bands = ('B04', 'B03', 'B02')" ] }, { @@ -432,8 +439,8 @@ " date_format = '%Y%m%dT%H%M%S'\n", " is_image = True\n", " separate_files = True\n", - " all_bands = ['B02', 'B03', 'B04', 'B08']\n", - " rgb_bands = ['B04', 'B03', 'B02']\n", + " all_bands = ('B02', 'B03', 'B04', 'B08')\n", + " rgb_bands = ('B04', 'B03', 'B02')\n", "\n", " def plot(self, sample):\n", " # Find the correct band index order\n", @@ -474,10 +481,9 @@ }, "outputs": [], "source": [ - "torch.manual_seed(1)\n", - "\n", "dataset = Sentinel2(root)\n", - "sampler = RandomGeoSampler(dataset, size=4096, length=3)\n", + "g = torch.Generator().manual_seed(1)\n", + "sampler = RandomGeoSampler(dataset, size=4096, length=3, generator=g)\n", "dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)\n", "\n", "for batch in dataloader:\n", @@ -516,12 +522,12 @@ "outputs": [], "source": [ "class Downloadable(RasterDataset):\n", - " def __init__(self, root, crs, res, transforms, cache, download=False):\n", - " super().__init__(root, crs, res, transforms, cache)\n", - "\n", + " def __init__(self, paths, crs, res, bands, transforms, cache, download=False):\n", " if download:\n", " # download the dataset\n", - " ..." + " ...\n", + "\n", + " super().__init__(paths, crs, res, bands, transforms, cache)" ] }, { @@ -559,9 +565,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.13.0" } }, "nbformat": 4, - "nbformat_minor": 1 + "nbformat_minor": 4 } diff --git a/docs/tutorials/customization.rst b/docs/tutorials/customization.rst new file mode 100644 index 00000000000..07b0870607b --- /dev/null +++ b/docs/tutorials/customization.rst @@ -0,0 +1,20 @@ +Customization +============= + +Is TorchGeo missing a dataset or model you need? Would you like to modify the default augmentations for a data module or extend a builtin trainer? + +The following tutorials will teach you how to customize TorchGeo to meet your needs: + +* `Custom Non-Geospatial Datasets `_: How to create and contribute a new NonGeoDataset +* `Custom Raster Datasets `_: How to create a new RasterDataset +* `Custom Data Module `_: How to create and contribute a new DataModule + +TorchGeo is a community-driven open source library. If there is a feature missing that you would like to add, please open a pull request to add it. See the ref:`contributing` guidelines to get started. + +.. toctree:: + :hidden: + :maxdepth: 1 + + contribute_non_geo_dataset + custom_raster_dataset + contribute_datamodule diff --git a/docs/tutorials/earth_surface_water.ipynb b/docs/tutorials/earth_surface_water.ipynb new file mode 100644 index 00000000000..f65aedbf726 --- /dev/null +++ b/docs/tutorials/earth_surface_water.ipynb @@ -0,0 +1,888 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MAR190Aszv8r" + }, + "source": [ + "# Earth Water Surface\n", + "\n", + "_Written by: Mauricio Cordeiro_" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0lZeGJNTz1y5" + }, + "source": [ + "## Introduction" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VXG-oXGUz39X" + }, + "source": [ + "The objective of this tutorial is to go through the Earth Water Surface dataset and cover the following topics:
\n", + "\n", + "* Creating RasterDatasets, DataLoaders and Samplers for images and masks;\n", + "* Intersection Dataset;\n", + "* Normalizing the data;\n", + "* Creating spectral indices;\n", + "* Creating the segmentation model (DeepLabV3);\n", + "* Loss function and metrics; and\n", + "* Training loop.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JTt_Ysyl4El5" + }, + "source": [ + "## Environment" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-U9kWwoL4GqT" + }, + "source": [ + "For the environment, we will install the torchgeo and scikit-learn packages." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "A14-syGAFahE", + "outputId": "d5b0ac1d-5cc2-4532-b2e6-7134638e9389" + }, + "outputs": [], + "source": [ + "%pip install torchgeo scikit-learn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import tempfile\n", + "from collections.abc import Callable, Iterable\n", + "from pathlib import Path\n", + "\n", + "import kornia.augmentation as K\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import rasterio as rio\n", + "import torch\n", + "from sklearn.metrics import jaccard_score\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils\n", + "from torchgeo.samplers import RandomGeoSampler, Units\n", + "from torchgeo.transforms import indices" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "73l1wc-fFWuU" + }, + "source": [ + "## Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "yY7mCdoq00Yo" + }, + "source": [ + "The dataset we will use is the Earth Surface Water dataset [1] (licensed under Creative Commons Attribution 4.0 International Public License), which has patches from different parts of the world (Figure below) and its corresponding water masks. The dataset uses optical imagery from Sentinel-2 satellite with 10m of spatial resolution.\n", + "\n", + "![Image1](https://raw.githubusercontent.com/xinluo2018/WatNet/main/figures/dataset.png)\n", + "\n", + "[1] Xin Luo. (2021). Earth Surface Water Dataset [Data set]. Zenodo. https://doi.org/10.5281/zenodo.5205674\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download and extract dataset to a temp folder\n", + "tmp_path = Path(tempfile.gettempdir()) / 'surface_water/'\n", + "utils.download_and_extract_archive(\n", + " 'https://hf.co/datasets/cordmaur/earth_surface_water/resolve/main/earth_surface_water.zip',\n", + " tmp_path,\n", + ")\n", + "\n", + "# Set the root to the extracted folder\n", + "root = tmp_path / 'dset-s2'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "abnT63f1GOh8" + }, + "source": [ + "## Creating the Datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7ke5sQ1_4nEq" + }, + "source": [ + "Now that we have the original dataset already uncompressed in Colab’s environment, we can prepare it to be loaded into a neural network. For that, we will create an instance of the RasterDataset class, provided by TorchGeo, and point to the specific directory, using the following commands. The `scale` function will apply the `1e-4` scale necessary to get the Sentinel-2 values in reflectance. Once the datasets are created, we can combine images with masks (labels) using the `&` operator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iXiPVrjXGSes" + }, + "outputs": [], + "source": [ + "def scale(item: dict):\n", + " item['image'] = item['image'] / 10000\n", + " return item" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yWApUYVYZ1Lr" + }, + "outputs": [], + "source": [ + "train_imgs = RasterDataset(\n", + " paths=(root / 'tra_scene').as_posix(), crs='epsg:3395', res=10, transforms=scale\n", + ")\n", + "train_msks = RasterDataset(\n", + " paths=(root / 'tra_truth').as_posix(), crs='epsg:3395', res=10\n", + ")\n", + "\n", + "valid_imgs = RasterDataset(\n", + " paths=(root / 'val_scene').as_posix(), crs='epsg:3395', res=10, transforms=scale\n", + ")\n", + "valid_msks = RasterDataset(\n", + " paths=(root / 'val_truth').as_posix(), crs='epsg:3395', res=10\n", + ")\n", + "\n", + "# IMPORTANT\n", + "train_msks.is_image = False\n", + "valid_msks.is_image = False\n", + "\n", + "train_dset = train_imgs & train_msks\n", + "valid_dset = valid_imgs & valid_msks\n", + "\n", + "# Create the samplers\n", + "\n", + "train_sampler = RandomGeoSampler(train_imgs, size=512, length=130, units=Units.PIXELS)\n", + "valid_sampler = RandomGeoSampler(valid_imgs, size=512, length=64, units=Units.PIXELS)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MaTZ03eJ5FGa" + }, + "source": [ + "Note that we are specifying the CRS (Coordinate Reference System) to EPSG:3395. TorchGeo requires that all the images are loaded in the same CRS. However, the patches in the dataset are in different UTM projections and the default behavior of TorchGeo is to use the first CRS found as its default. In this case, we have to inform a CRS that is able to cope with these different regions around the globe. To minimize the deformations due to the huge differences in latitude (I can create a history specific for this purpose) within the patches, I have selected World Mercator as the main CRS for the project. Figure 3 shows the world projected in World Mercator CRS.\n", + "\n", + "\n", + "![Image2](https://miro.medium.com/max/4800/1*sUdRKEfIAbm79jpB3bCShQ.webp)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TqNOU7WaOJ2t" + }, + "source": [ + "### Understanding the sampler" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8RFrF3bTOSJn" + }, + "source": [ + "To create training patches that can be fed into a neural network from our dataset, we need to select samples of fixed sizes. TorchGeo has many samplers, but here we will use the `RandomGeoSampler` class. Basically, the sampler selects random bounding boxes of fixed size that belongs to the original image. Then, these bounding boxes are used in the `RasterDataset` to query the portion of the image we want. Here is an exmple using the previously created samplers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "r8VGVIWNPI_W", + "outputId": "5b779cd3-25bc-4ec4-e29d-99391e906a4d" + }, + "outputs": [], + "source": [ + "bbox = next(iter(train_sampler))\n", + "bbox" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KFCssvlqPI87", + "outputId": "4bb06a55-375e-4c14-d790-2429d4f14d0a" + }, + "outputs": [], + "source": [ + "sample = train_dset[bbox]\n", + "sample.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "G-gF-8k1PI6N", + "outputId": "60002cb9-8d8c-4e1e-ba18-51edf4191ea3" + }, + "outputs": [], + "source": [ + "sample['image'].shape, sample['mask'].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dGiqU2TcPcTq" + }, + "source": [ + "Notice we have now patches of same size (..., 512 x 512)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TBNfy4X5Pn-G" + }, + "source": [ + "## Creating Dataloaders" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a97dvCYVP5D5" + }, + "source": [ + "Creating a `DataLoader` in TorchGeo is very straightforward, just like it is with Pytorch (we are actually using the same class). Note below that we are also using the same samplers already defined. Additionally we inform the dataset that the dataloader will use to pull data from, the batch_size (number of samples in each batch) and a collate function that specifies how to “concatenate” the multiple samples into one single batch.\n", + "\n", + "Finally, we can iterate through the dataloader to grab batches from it. To test it, we will get the first batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IhWa0SYfav2V", + "outputId": "6e80207e-30cf-46b8-d507-e15c0619aa9d" + }, + "outputs": [], + "source": [ + "# Adjust the batch size according to your GPU memory\n", + "train_dataloader = DataLoader(\n", + " train_dset, sampler=train_sampler, batch_size=4, collate_fn=stack_samples\n", + ")\n", + "valid_dataloader = DataLoader(\n", + " valid_dset, sampler=valid_sampler, batch_size=4, collate_fn=stack_samples\n", + ")\n", + "\n", + "train_batch = next(iter(train_dataloader))\n", + "valid_batch = next(iter(valid_dataloader))\n", + "train_batch.keys(), valid_batch.keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o7VRAzpkQIvr" + }, + "source": [ + "## Batch Visualization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "78IKqlWUQMTM" + }, + "source": [ + "Now that we can draw batches from our datasets, let’s create a function to display the batches.\n", + "\n", + "The function `plot_batch` will will check automatically the number of items in the batch and if there are masks associated to arrange the output grid accordingly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zynH3etxQQmG" + }, + "outputs": [], + "source": [ + "def plot_imgs(\n", + " images: Iterable, axs: Iterable, chnls: list[int] = [2, 1, 0], bright: float = 3.0\n", + "):\n", + " for img, ax in zip(images, axs):\n", + " arr = torch.clamp(bright * img, min=0, max=1).numpy()\n", + " rgb = arr.transpose(1, 2, 0)[:, :, chnls]\n", + " ax.imshow(rgb)\n", + " ax.axis('off')\n", + "\n", + "\n", + "def plot_msks(masks: Iterable, axs: Iterable):\n", + " for mask, ax in zip(masks, axs):\n", + " ax.imshow(mask.squeeze().numpy(), cmap='Blues')\n", + " ax.axis('off')\n", + "\n", + "\n", + "def plot_batch(\n", + " batch: dict,\n", + " bright: float = 3.0,\n", + " cols: int = 4,\n", + " width: int = 5,\n", + " chnls: list[int] = [2, 1, 0],\n", + "):\n", + " # Get the samples and the number of items in the batch\n", + " samples = unbind_samples(batch.copy())\n", + "\n", + " # if batch contains images and masks, the number of images will be doubled\n", + " n = 2 * len(samples) if ('image' in batch) and ('mask' in batch) else len(samples)\n", + "\n", + " # calculate the number of rows in the grid\n", + " rows = n // cols + (1 if n % cols != 0 else 0)\n", + "\n", + " # create a grid\n", + " _, axs = plt.subplots(rows, cols, figsize=(cols * width, rows * width))\n", + "\n", + " if ('image' in batch) and ('mask' in batch):\n", + " # plot the images on the even axis\n", + " plot_imgs(\n", + " images=map(lambda x: x['image'], samples),\n", + " axs=axs.reshape(-1)[::2],\n", + " chnls=chnls,\n", + " bright=bright,\n", + " )\n", + "\n", + " # plot the masks on the odd axis\n", + " plot_msks(masks=map(lambda x: x['mask'], samples), axs=axs.reshape(-1)[1::2])\n", + "\n", + " else:\n", + " if 'image' in batch:\n", + " plot_imgs(\n", + " images=map(lambda x: x['image'], samples),\n", + " axs=axs.reshape(-1),\n", + " chnls=chnls,\n", + " bright=bright,\n", + " )\n", + "\n", + " elif 'mask' in batch:\n", + " plot_msks(masks=map(lambda x: x['mask'], samples), axs=axs.reshape(-1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 672 + }, + "id": "0S4DZa8aQd8Z", + "outputId": "755d6f7a-d4fd-4cab-ec1d-c293255a7e8c" + }, + "outputs": [], + "source": [ + "plot_batch(train_batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SkGQoaWlQVFC" + }, + "source": [ + "## Data Standardization and Spectral Indices" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mvce_wKMQ3pT" + }, + "source": [ + "Normally, machine learning methods (deep learning included) benefit from feature scaling. That means standard deviation around 1 and zero mean, by applying the following formula:
\n", + "$X'=\\frac{X-Mean}{\\text{Standard deviation}}$\n", + "\n", + "To do that, we need to first find the mean and standard deviation for each one of the 6s channels in the dataset.\n", + "\n", + "Let’s define a function calculate these statistics and write its results in the variables mean and std. We will use our previously installed rasterio package to open the images and perform a simple average over the statistics for each batch/channel. For the standard deviation, this method is an approximation. For a more precise calculation, please refer to: http://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.htm." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IZ5yXZPzNIdh" + }, + "outputs": [], + "source": [ + "def calc_statistics(dset: RasterDataset):\n", + " \"\"\"\n", + " Calculate the statistics (mean and std) for the entire dataset\n", + " Warning: This is an approximation. The correct value should take into account the\n", + " mean for the whole dataset for computing individual stds.\n", + " For correctness I suggest checking: http://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html\n", + " \"\"\"\n", + "\n", + " # To avoid loading the entire dataset in memory, we will loop through each img\n", + " # The filenames will be retrieved from the dataset's rtree index\n", + " files = [\n", + " item.object for item in dset.index.intersection(dset.index.bounds, objects=True)\n", + " ]\n", + "\n", + " # Reseting statistics\n", + " accum_mean = 0\n", + " accum_std = 0\n", + "\n", + " for file in files:\n", + " img = rio.open(file).read() / 10000 # type: ignore\n", + " accum_mean += img.reshape((img.shape[0], -1)).mean(axis=1)\n", + " accum_std += img.reshape((img.shape[0], -1)).std(axis=1)\n", + "\n", + " # at the end, we shall have 2 vectors with lenght n=chnls\n", + " # we will average them considering the number of images\n", + " return accum_mean / len(files), accum_std / len(files)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4VtubMAPxXYq" + }, + "outputs": [], + "source": [ + "# Calculate the statistics (Mean and std) for the dataset\n", + "mean, std = calc_statistics(train_imgs)\n", + "\n", + "# Please, note that we will create spectral indices using the raw (non-normalized) data. Then, when normalizing, the sensors will have more channels (the indices) that should not be normalized.\n", + "# To solve this, we will add the indices to the 0's to the mean vector and 1's to the std vectors\n", + "mean = np.concat([mean, [0, 0, 0]])\n", + "std = np.concat([std, [1, 1, 1]])\n", + "\n", + "norm = K.Normalize(mean=mean, std=std)\n", + "\n", + "tfms = torch.nn.Sequential(\n", + " indices.AppendNDWI(index_green=1, index_nir=3),\n", + " indices.AppendNDWI(index_green=1, index_nir=5),\n", + " indices.AppendNDVI(index_nir=3, index_red=2),\n", + " norm,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hVsDAOaVRhiN", + "outputId": "b128640d-b3bc-4cf6-82e1-187b5df2a164" + }, + "outputs": [], + "source": [ + "transformed_img = tfms(train_batch['image'])\n", + "print(transformed_img.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HA-jTEKeRuA4" + }, + "source": [ + "Note that our transformed batch has now 9 channels, instead of 6." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hWfUOS1RR14g" + }, + "source": [ + "> Important: the normalize method we created will apply the normalization just to the original bands and it will ignore the previously appended indices. That’s important to avoid errors due to distinct shapes between the batch and the mean and std vectors." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JfNmPcwWSNFv" + }, + "source": [ + "## Segmentation Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bMmXdpZWSSNR" + }, + "source": [ + "For the semantic segmentation model, we are going to use a predefined architecture that is available in Pytorch. Looking at list (https://pytorch.org/vision/stable/models.html#semantic-segmentation) it is possible to note 3 models available for semantic segmentation, but one (LRASPP) is intended for mobile applications. In our tutorial, we will use the DeepLabV3 model.\n", + "\n", + "Here, we will create a DeepLabV3 model for 2 classes. In this case, I will skip the pretrained weights, as the weights represent another domain (not water segmentation from multispectral imagery)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OzhOtV3IyubJ", + "outputId": "b7e57773-1f4c-4140-9039-b0e84e59aa58" + }, + "outputs": [], + "source": [ + "from torchvision.models.segmentation import deeplabv3_resnet50\n", + "\n", + "model = deeplabv3_resnet50(weights=None, num_classes=2)\n", + "model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AvsK_LXuSmnL" + }, + "source": [ + "The first thing we have to pay attention in the model architecture is the number of channels expected in the first convolution (Conv2d), that is defined as 3. That’s because the model is prepared to work with RGB images. After the first convolution, the 3 channels will produce 64 channels in lower resolution, and so on. As we have now 9 channels, we will change this first processing layer to adapt correctly to our model. We can do this by replacing the first convolutional layer for a new one, by following the commands. Finally, we check a mock batch can pass through the model and provide the output with 2 channels (water / no_water) as desired." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xYPdGBPOSfHJ", + "outputId": "7e9267b7-a4e7-428d-a0cd-9292dd2b1923" + }, + "outputs": [], + "source": [ + "backbone = model.get_submodule('backbone')\n", + "\n", + "conv = torch.nn.modules.conv.Conv2d(\n", + " in_channels=9,\n", + " out_channels=64,\n", + " kernel_size=(7, 7),\n", + " stride=(2, 2),\n", + " padding=(3, 3),\n", + " bias=False,\n", + ")\n", + "backbone.register_module('conv1', conv)\n", + "\n", + "pred = model(torch.randn(3, 9, 512, 512))\n", + "pred['out'].shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PvvGWtMrTdCE" + }, + "source": [ + "## Training Loop" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YuUPhlzxTftk" + }, + "source": [ + "The training function should receive the number of epochs, the model, the dataloaders, the loss function (to be optimized) the accuracy function (to assess the results), the optimizer (that will adjust the parameters of the model in the correct direction) and the transformations to be applied to each batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Check if GPU is available\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "device" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sFlsT8glSq3x" + }, + "outputs": [], + "source": [ + "def train_loop(\n", + " epochs: int,\n", + " train_dl: DataLoader,\n", + " val_dl: DataLoader | None,\n", + " model: torch.nn.Module,\n", + " loss_fn: Callable,\n", + " optimizer: torch.optim.Optimizer,\n", + " acc_fns: list | None = None,\n", + " batch_tfms: Callable | None = None,\n", + "):\n", + " # size = len(dataloader.dataset)\n", + " cuda_model = model.to(device)\n", + "\n", + " for epoch in range(epochs):\n", + " accum_loss = 0\n", + " for batch in train_dl:\n", + " if batch_tfms is not None:\n", + " X = batch_tfms(batch['image']).to(device)\n", + " else:\n", + " X = batch['image'].to(device)\n", + "\n", + " y = batch['mask'].type(torch.long).to(device)\n", + " pred = cuda_model(X)['out']\n", + " loss = loss_fn(pred, y)\n", + "\n", + " # BackProp\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " # update the accum loss\n", + " accum_loss += float(loss) / len(train_dl)\n", + "\n", + " # Testing against the validation dataset\n", + " if acc_fns is not None and val_dl is not None:\n", + " # reset the accuracies metrics\n", + " acc = [0.0] * len(acc_fns)\n", + "\n", + " with torch.no_grad():\n", + " for batch in val_dl:\n", + " if batch_tfms is not None:\n", + " X = batch_tfms(batch['image']).to(device)\n", + " else:\n", + " X = batch['image'].type(torch.float32).to(device)\n", + "\n", + " y = batch['mask'].type(torch.long).to(device)\n", + "\n", + " pred = cuda_model(X)['out']\n", + "\n", + " for i, acc_fn in enumerate(acc_fns):\n", + " acc[i] = float(acc[i] + acc_fn(pred, y) / len(val_dl))\n", + "\n", + " # at the end of the epoch, print the errors, etc.\n", + " print(\n", + " f'Epoch {epoch}: Train Loss={accum_loss:.5f} - Accs={[round(a, 3) for a in acc]}'\n", + " )\n", + " else:\n", + " print(f'Epoch {epoch}: Train Loss={accum_loss:.5f}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kUIPlgndUB_9" + }, + "source": [ + "## Loss and Accuracy Functions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xs5LOs8LUS7j" + }, + "source": [ + "For the loss function, normally the Cross Entropy Loss should work, but it requires the mask to have shape (N, d1, d2). In this case, we will need to squeeze our second dimension manually." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "O8K3e5NwTrYg" + }, + "outputs": [], + "source": [ + "def oa(pred, y):\n", + " flat_y = y.squeeze()\n", + " flat_pred = pred.argmax(dim=1)\n", + " acc = torch.count_nonzero(flat_y == flat_pred) / torch.numel(flat_y)\n", + " return acc\n", + "\n", + "\n", + "def iou(pred, y):\n", + " flat_y = y.cpu().numpy().squeeze()\n", + " flat_pred = pred.argmax(dim=1).detach().cpu().numpy()\n", + " return jaccard_score(flat_y.reshape(-1), flat_pred.reshape(-1), zero_division=1.0)\n", + "\n", + "\n", + "def loss(p, t):\n", + " return torch.nn.functional.cross_entropy(p, t.squeeze())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UW6YfmnyUa1A" + }, + "source": [ + "## Training" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xajn1unSUh3h" + }, + "source": [ + "> To train the model it is important to have CUDA GPUs available. In Colab, it can be done by changing the runtime type and re-running the notebook. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "moU9ol78UUqP", + "outputId": "a1b5af92-b7fe-43c0-c4cd-f92906a46fb3" + }, + "outputs": [], + "source": [ + "# adjust number of epochs depending on the device\n", + "if torch.cuda.is_available():\n", + " num_epochs = 2\n", + "else:\n", + " # if GPU is not available, just make 1 pass and limit the size of the datasets\n", + " num_epochs = 1\n", + "\n", + " # by limiting the length of the sampler we limit the iterations in each epoch\n", + " train_dataloader.sampler.length = 8\n", + " valid_dataloader.sampler.length = 8\n", + "\n", + "# train the model\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.01)\n", + "train_loop(\n", + " num_epochs,\n", + " train_dataloader,\n", + " valid_dataloader,\n", + " model,\n", + " loss,\n", + " optimizer,\n", + " acc_fns=[oa, iou],\n", + " batch_tfms=tfms,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Additional Reading" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tutorial is also available as a 3 parts Medium story:
https://medium.com/towards-data-science/artificial-intelligence-for-geospatial-analysis-with-pytorchs-torchgeo-part-1-52d17e409f09" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "authorship_tag": "ABX9TyPk0gtwHzQoTqfC6uudCTRe", + "include_colab_link": true, + "provenance": [], + "toc_visible": true + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/tutorials/geospatial.ipynb b/docs/tutorials/geospatial.ipynb new file mode 100644 index 00000000000..880ad1ab998 --- /dev/null +++ b/docs/tutorials/geospatial.ipynb @@ -0,0 +1,355 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "45973fd5-6259-4e03-9501-02ee96f3f870", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "id": "9478ed9a", + "metadata": { + "id": "NdrXRgjU7Zih" + }, + "source": [ + "# Introduction to Geospatial Data\n", + "\n", + "_Written by: Adam J. Stewart_\n", + "\n", + "In this tutorial, we introduce the challenges of working with geospatial data, especially remote sensing imagery. This is not meant to discourage practicioners, but to elucidate why existing computer vision domain libraries like torchvision are insufficient for working with multispectral satellite imagery." + ] + }, + { + "cell_type": "markdown", + "id": "4cc902a5-0a06-4b02-af47-31b124da8193", + "metadata": {}, + "source": [ + "## Common Modalities\n", + "\n", + "Geospatial data come in a wide variety of common modalities. Below, we dive into each modality and discuss what makes it unique." + ] + }, + { + "cell_type": "markdown", + "id": "7d02bf4d-e979-4d41-bf70-e1b5a73bac2f", + "metadata": {}, + "source": [ + "### Tabular data\n", + "\n", + "Many geospatial datasets, especially those collected by in-situ sensors, are distributed in tabular format. For example, imagine weather or air quality stations that distribute example data like:\n", + "\n", + "| Latitude | Longitude | Temperature | Pressure | PM$_{2.5}$ | O$_3$ | CO |\n", + "| -------: | --------: | ----------: | -------: | ---------: | ----: | -----: |\n", + "| 40.7128 | 74.0060 | 1 | 1025 | 20.0 | 4 | 473.9 |\n", + "| 37.7749 | 122.4194 | 11 | 1021 | 21.4 | 6 | 1259.5 |\n", + "| ... | ... | ... | ... | ... | ... | ... |\n", + "| 41.8781 | 87.6298 | -1 | 1024 | 14.5 | 30 | - |\n", + "| 25.7617 | 80.1918 | 17 | 1026 | 5.0 | - | - |\n", + "\n", + "This kind of data is relatively easy to load and integrate into a machine learning pipeline. The following models work well for tabular data:\n", + "\n", + "* Multi-Layer Perceptrons (MLPs): for unstructured data\n", + "* Recurrent Neural Networks (RNNs): for time-series data\n", + "* Graph Neural Networks (GNNs): for ungridded geospatial data\n", + "\n", + "Note that it is not uncommon for there to be missing values (as is the case for air pollutants in some cities) due to missing or faulty sensors. Data imputation may be required to fill in these missing values. Also make sure all values are converted to a common set of units." + ] + }, + { + "cell_type": "markdown", + "id": "b0076503-57d4-4803-b7ba-dc6b96dd5cf8", + "metadata": {}, + "source": [ + "### Multispectral\n", + "\n", + "Although traditional computer vision datasets are typically restricted to red-green-blue (RGB) images, remote sensing satellites typically capture 3–15 different spectral bands with wavelengths far outside of the visible spectrum. Mathematically speaking, each image will be formatted as:\n", + "\n", + "$$ x \\in \\mathbb{R}^{C \\times H \\times W},$$\n", + "\n", + "where:\n", + "\n", + "* $C$ is the number of spectral bands (color channels),\n", + "* $H$ is the height of each image (in pixels), and\n", + "* $W$ is the width of each image (in pixels).\n", + "\n", + "Below, we see a false-color composite created using spectral channels outside of the visible spectrum (such as near-infrared):\n", + "\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "03de4b08-3941-4b76-9eb1-11a57a7c9684", + "metadata": {}, + "source": [ + "### Hyperspectral\n", + "\n", + "While multispectral images are often limited to 3–15 disjoint spectral bands, hyperspectral sensors capture hundreds of spectral bands to approximate the continuous color spectrum. These images often present a particular challenge to convolutional neural networks (CNNs) due to the sheer data volume, and require either small image patches (decreased $H$ and $W$) or dimensionality reduction (decreased $C$) in order to avoid out-of-memory errors on the GPU.\n", + "\n", + "Below, we see a hyperspectral data cube, with each color channel visualized along the $z$-axis:\n", + "\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "20c83407-0496-45af-b0e0-2c615b0d9a03", + "metadata": {}, + "source": [ + "### Radar\n", + "\n", + "Passive sensors (ones that do not emit light) are limited by daylight hours and cloud-free conditions. Active sensors such as radar emit polarized microwave pulses and measure the time it takes for the signal to reflect or scatter off of objects. This allows radar satellites to operate at night and in adverse weather conditions. The images captured by these sensors are stored as complex numbers, with a real (amplitude) and imaginary (phase) component, making it difficult to integrate them into machine learning pipelines.\n", + "\n", + "Radar is commonly used in meteorology (Doppler radar) and geophysics (ground penetrating radar). By attaching a radar antenna to a moving satellite, a larger effective aperature is created, increasing the spatial resolution of the captured image. This technique is known as synthetic aperature radar (SAR), and has many common applications in geodesy, flood mapping, and glaciology. Finally, by comparing the phases of multiple SAR snapshots of a single location at different times, we can analyze minute changes in surface elevation, in a technique known as Interferometric Synthetic Aperature Radar (InSAR). Below, we see an interferogram of earthquake deformation:\n", + "\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "53b27a38-64cc-43e7-91da-1244fb9dd416", + "metadata": {}, + "source": [ + "### Lidar\n", + "\n", + "Similar to radar, lidar is another active remote sensing method that replaces microwave pulses with lasers. By measuring the time it takes light to reflect off of an object and return to the sensor, we can generate a 3D point cloud mapping object structures. Mathematically, our dataset would then become:\n", + "\n", + "$$D = \\left\\{\\left(x^{(i)}, y^{(i)}, z^{(i)}\\right)\\right\\}_{i=1}^N$$\n", + "\n", + "This technology is frequently used in several different application domains:\n", + "\n", + "* Meteorology: clouds, aerosols\n", + "* Geodesy: surveying, archaeology\n", + "* Forestry: tree height, biomass density\n", + "\n", + "Below, we see a 3D point cloud captured for a city:\n", + "\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "da27ea73-2e3d-43d8-ba0b-cdac92ab2f81", + "metadata": {}, + "source": [ + "## Resolution\n", + "\n", + "Remote sensing data comes in a number of spatial, temporal, and spectral resolutions.\n", + "\n", + "
\n", + "Warning: In computer vision, resolution usually refers to the dimensions of an image (in pixels). In remote sensing, resolution instead refers to the dimensions of each pixel (in meters). Throughout this tutorial, we will use the latter definition unless otherwise specified.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "bcd51946-62eb-44db-bf3b-bd2df5308ab6", + "metadata": {}, + "source": [ + "### Spatial resolution\n", + "\n", + "Choosing the right data for your application is often controlled by the resolution of the imagery. Spatial resolution, also called ground sample distance (GSD), is the size of each pixel as measured on the Earth's surface. While the exact definitions change as satellites become better, approximate ranges of resolution include:\n", + "\n", + "| Category | Resolution | Examples |\n", + "| -------: | ---------: | :------: |\n", + "| Low resolution | > 30 m | MODIS (250 m–1 km), GOES-16 (500 m–2 km) |\n", + "| Medium resolution | 5–30 m | Sentinel-2 (10–60 m), Landsat-9 (15–100 m) |\n", + "| High resolution | 1–5 m | Planet Dove (3–5 m), RapidEye (5 m) |\n", + "| Very high resolution | < 1 m | Maxar WorldView-3 (0.3 m), QuickBird (0.6 m) |\n", + "\n", + "It is not uncommon for a single sensor to capture high resolution panchromatic bands, medium resolution visible bands, and low resolution thermal bands. It is also possible for pixels to be non-square, as is the case for OCO-2. All bands must be resampled to the same resolution for use in machine learning pipelines." + ] + }, + { + "cell_type": "markdown", + "id": "bce349d9-8b5e-48e4-800a-c2fbb1c343cb", + "metadata": {}, + "source": [ + "### Temporal resolution\n", + "\n", + "For time-series applications, it is also important to think about the repeat period of the satellite you want to use. Depending on the orbit of the satellite, imagery can be anywhere from biweekly (for polar, sun-synchronous orbits) to continuous (for geostationary orbits). The latter is common for global Earth observation missions, while the latter is common for weather and communications satellites. Below, we see an illustration of a geostationary orbit:\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "Due to partial overlap in orbit paths and intermittent cloud cover, satellite image time series (SITS) are often of irregular length and irregular spacing. This can be especially challenging for naïve time-series models to handle." + ] + }, + { + "cell_type": "markdown", + "id": "6c278653-d14b-44c8-971b-9851b7515b0f", + "metadata": {}, + "source": [ + "### Spectral resolution\n", + "\n", + "It is also important to consider the spectral resolution of a sensor, including both the number of spectral bands and the bandwidth that is captured. Different downstream applications require different spectral bands, and there is often a tradeoff between additional spectral bands and higher spatial resolution. The following figure compares the wavelengths captured by sensors onboard different satellites:\n", + "\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "c79a050a-5389-4fb4-8501-3f1be067a166", + "metadata": {}, + "source": [ + "## Preprocessing\n", + "\n", + "Geospatial data also has unique preprocessing requirements that necessitate experience working with a variety of tools like GDAL, the geospatial data abstraction library. GDAL support ~160 raster drivers and ~80 vector drivers, allowing users to reproject, resample, and rasterize data from a variety of specialty file formats." + ] + }, + { + "cell_type": "markdown", + "id": "5cadbfeb-40da-455f-8b9a-bb5a983aaa8b", + "metadata": {}, + "source": [ + "### Reprojection\n", + "\n", + "The Earth is three dimensional, but images are two dimensional. This requires a *projection* to map the 3D surface onto a 2D image, and a *coordinate reference system* (CRS) to map each point back to a specific latitude/longitude. Below, we see examples of a few common projections:\n", + "\n", + "
\n", + "
\n", + " \n", + "
Mercator
\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + " \n", + "
Albers Equal Area
\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + " \n", + "
Interrupted Goode Homolosine
\n", + "
\n", + "
\n", + "\n", + "There are literally thousands of different projections out there, and every dataset (or even different images within a single dataset) can have different projections. Even if you correctly georeference images during indexing, if you forget to project them to a common CRS, you can end up with rotated images with nodata values around them, and the images will not be pixel-aligned.\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "We can use a command like:\n", + "\n", + "```\n", + "$ gdalwarp -s_srs EPSG:5070 -t_srs EPSG:4326 src.tif dst.tif\n", + "```\n", + "\n", + "to reproject a file from one CRS to another." + ] + }, + { + "cell_type": "markdown", + "id": "b0362b1e-1c47-4884-a414-699db82acb6e", + "metadata": {}, + "source": [ + "### Resampling\n", + "\n", + "As previously mentioned, each dataset may have its own unique spatial resolution, and even separate bands (channels) in a single image may have different resolutions. All data (including input images and target masks for semantic segmentation) must be resampled to the same resolution. This can be done using GDAL like so:\n", + "\n", + "```\n", + "$ gdalwarp -tr 30 30 src.tif dst.tif\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "ddbd742a-11b4-4fb9-a27a-58f5f14e2982", + "metadata": {}, + "source": [ + "Just because two files have the same resolution does not mean that they have *target-aligned pixels* (TAP). Our goal is that every input pixel is perfectly aligned with every expected output pixel, but differences in geolocation can result in masks that are offset by half a pixel from the input image. We can ensure TAP by adding the `-tap` flag:\n", + "\n", + "```\n", + "$ gdalwarp -tr 30 30 -tap src.tif dst.tif\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "c0c499d9-e266-4619-863f-e416cc823c58", + "metadata": {}, + "source": [ + "### Rasterization\n", + "\n", + "Not all geospatial data is raster data. Many files come in vector format, including points, lines, and polygons.\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "Of course, semantic segmentation requires these polygon masks to be converted to raster masks. This process is called rasterization, and can be performed like so:\n", + "\n", + "```\n", + "$ gdal_rasterize -tr 30 30 -a BUILDING_HEIGHT -l buildings buildings.shp buildings.tif\n", + "```\n", + "\n", + "Above, we set the resolution to 30 m/pixel and use the `BUILDING_HEIGHT` attribute of the `buildings` layer as the burn-in value.\n" + ] + }, + { + "cell_type": "markdown", + "id": "a3acc64e-8dc0-46b4-a677-ecb9723d4f56", + "metadata": {}, + "source": [ + "## Additional Reading\n", + "\n", + "Luckily, TorchGeo can handle most preprocessing for us. If you would like to learn more about working with geospatial data, including how to manually do the above tasks, the following additional reading may be useful:\n", + "\n", + "* [GDAL documentation](https://gdal.org/en/stable/index.html)\n", + "* [rasterio documentation](https://rasterio.readthedocs.io/en/stable/index.html)\n", + "* [Guide to GeoTIFF compression and optimization with GDAL](https://kokoalberti.com/articles/geotiff-compression-optimization-guide/)\n", + "* [A survival guide to Landsat preprocessing](https://doi.org/10.1002/ecy.1730)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "getting_started.ipynb", + "provenance": [] + }, + "execution": { + "timeout": 1200 + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/tutorials/getting_started.ipynb b/docs/tutorials/getting_started.ipynb deleted file mode 100644 index 1b1982711b3..00000000000 --- a/docs/tutorials/getting_started.ipynb +++ /dev/null @@ -1,322 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "35303546", - "metadata": {}, - "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." - ] - }, - { - "cell_type": "markdown", - "id": "9478ed9a", - "metadata": { - "id": "NdrXRgjU7Zih" - }, - "source": [ - "# Getting Started\n", - "\n", - "In this tutorial, we demonstrate some of the basic features of TorchGeo and show how easy it is to use if you're already familiar with other PyTorch domain libraries like torchvision.\n", - "\n", - "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." - ] - }, - { - "cell_type": "markdown", - "id": "34f10e9f", - "metadata": { - "id": "lCqHTGRYBZcz" - }, - "source": [ - "## Setup\n", - "\n", - "First, we install TorchGeo." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "019092f0", - "metadata": {}, - "outputs": [], - "source": [ - "%pip install torchgeo" - ] - }, - { - "cell_type": "markdown", - "id": "4db9f791", - "metadata": { - "id": "dV0NLHfGBMWl" - }, - "source": [ - "## Imports\n", - "\n", - "Next, we import TorchGeo and any other libraries we need." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3d92b0f1", - "metadata": { - "id": "entire-albania" - }, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "\n", - "from torch.utils.data import DataLoader\n", - "\n", - "from torchgeo.datasets import NAIP, ChesapeakeDE, stack_samples\n", - "from torchgeo.datasets.utils import download_url\n", - "from torchgeo.samplers import RandomGeoSampler" - ] - }, - { - "cell_type": "markdown", - "id": "7f26e4b8", - "metadata": { - "id": "5rLknZxrBEMz" - }, - "source": [ - "## Datasets\n", - "\n", - "For this tutorial, we'll be using imagery from the [National Agriculture Imagery Program (NAIP)](https://catalog.data.gov/dataset/national-agriculture-imagery-program-naip) and labels from the [Chesapeake Bay High-Resolution Land Cover Project](https://www.chesapeakeconservancy.org/conservation-innovation-center/high-resolution-data/land-cover-data-project/). First, we manually download a few NAIP tiles and create a PyTorch Dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4a39af46", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 232, - "referenced_widgets": [ - "d00a2177bf4b4b8191bfc8796f0e749f", - "17d6b81aec50455989276b595457cc7f", - "06ccd130058b432dbfa025c102eaeb27", - "6bc5b9872b574cb5aa6ebd1d44e7a71f", - "f7746f028f874a85b6101185fc9a8efc", - "f7ef78d6f87a4a2685788e395525fa7c", - "5b2450e316e64b4ba432c78b63275124", - "d3bbd6112f144c77bc68e5f2a7a355ff", - "a0300b1252cd4da5a798b55c15f8f5fd", - "793c2851b6464b398f7b4d2f2f509722", - "8dd61c8479d74c95a55de147e04446b3", - "b57d15e6c32b4fff8994ae67320972f6", - "9a34f8907a264232adf6b0d0543461dd", - "e680eda3c84c440083e2959f04431bea", - "a073e33fd9ae4125822fc17971233770", - "87faaa32454a42939d3bd405e726228c", - "b3d4c9c99bec4e69a199e45920d52ce4", - "a215f3310ea543d1a8991f57ec824872", - "569f60397fd6440d825e8afb83b4e1ae", - "b7f604d2ba4e4328a451725973fa755f", - "737fa148dfae49a18cc0eabbe05f2d0f", - "0b6613adbcc74165a9d9f74988af366e", - "b25f274c737d4212b3ffeedb2372ba22", - "ef0fc75ff5044171be942a6b3ba0c2da", - "612d84013a6e4890a48eb229f6431233", - "9a689285370646ab800155432ea042a5", - "014ed48a23234e8b81dd7ac4dbf95817", - "93c536a27b024728a00486b1f68b4dde", - "8a8538a91a74439b81e3f7c6516763e3", - "caf540562b484594bab8d6210dd7c2c1", - "99cd2e65fb104380953745f2e0a93fac", - "c5b818707bb64c5a865236a46399cea2", - "54f5db9555c44efa9370cbb7ab58e142", - "1d83b20dbb9c4c6a9d5c100fe4770ba4", - "c51b2400ca9442a9a9e0712d5778cd9a", - "bd2e44a8eb1a4c19a32da5a1edd647d1", - "0f9feea4b8344a7f8054c9417150825e", - "31acb7a1ca8940078e1aacd72e547f47", - "0d0ca8d64d3e4c2f88d87342808dd677", - "54402c5f8df34b83b95c94104b26e2c6", - "910b98584fa74bb5ad308fe770f5b40e", - "b2dce834ee044d69858389178b493a2b", - "237f2e31bcfe476baafae8d922877e07", - "43ac7d95481b4ea3866feef6ace2f043" - ] - }, - "id": "e3138ac3", - "outputId": "11589c46-eee6-455d-839b-390f2934d834" - }, - "outputs": [], - "source": [ - "naip_root = os.path.join(tempfile.gettempdir(), 'naip')\n", - "naip_url = (\n", - " 'https://naipeuwest.blob.core.windows.net/naip/v002/de/2018/de_060cm_2018/38075/'\n", - ")\n", - "tiles = [\n", - " 'm_3807511_ne_18_060_20181104.tif',\n", - " 'm_3807511_se_18_060_20181104.tif',\n", - " 'm_3807512_nw_18_060_20180815.tif',\n", - " 'm_3807512_sw_18_060_20180815.tif',\n", - "]\n", - "for tile in tiles:\n", - " download_url(naip_url + tile, naip_root)\n", - "\n", - "naip = NAIP(naip_root)" - ] - }, - { - "cell_type": "markdown", - "id": "e25bad40", - "metadata": { - "id": "HQVji2B22Qfu" - }, - "source": [ - "Next, we tell TorchGeo to automatically download the corresponding Chesapeake labels." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "689bb2b0", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "2Ah34KAw2biY", - "outputId": "03b7bdf0-78c1-4a13-ac56-59de740d7f59" - }, - "outputs": [], - "source": [ - "chesapeake_root = os.path.join(tempfile.gettempdir(), 'chesapeake')\n", - "os.makedirs(chesapeake_root, exist_ok=True)\n", - "chesapeake = ChesapeakeDE(chesapeake_root, crs=naip.crs, res=naip.res, download=True)" - ] - }, - { - "cell_type": "markdown", - "id": "56f2d78b", - "metadata": { - "id": "OWUhlfpD22IX" - }, - "source": [ - "Finally, we create an IntersectionDataset so that we can automatically sample from both GeoDatasets simultaneously." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "daefbc4d", - "metadata": { - "id": "WXxy8F8l2-aC" - }, - "outputs": [], - "source": [ - "dataset = naip & chesapeake" - ] - }, - { - "cell_type": "markdown", - "id": "ded44652", - "metadata": { - "id": "yF_R54Yf3EUd" - }, - "source": [ - "## Sampler\n", - "\n", - "Unlike typical PyTorch Datasets, TorchGeo GeoDatasets are indexed using lat/long/time bounding boxes. This requires us to use a custom GeoSampler instead of the default sampler/batch_sampler that comes with PyTorch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b8a0d99c", - "metadata": { - "id": "RLczuU293itT" - }, - "outputs": [], - "source": [ - "sampler = RandomGeoSampler(dataset, size=1000, length=10)" - ] - }, - { - "cell_type": "markdown", - "id": "5b8c1c52", - "metadata": { - "id": "OWa-mmYd8S6K" - }, - "source": [ - "## DataLoader\n", - "\n", - "Now that we have a Dataset and Sampler, we can combine these into a single DataLoader." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "96faa142", - "metadata": { - "id": "jfx-9ZmU8ZTc" - }, - "outputs": [], - "source": [ - "dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)" - ] - }, - { - "cell_type": "markdown", - "id": "64ae63f7", - "metadata": { - "id": "HZIfqqW58oZe" - }, - "source": [ - "## Training\n", - "\n", - "Other than that, the rest of the training pipeline is the same as it is for torchvision." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8a2b44f8", - "metadata": { - "id": "7sGmNvBy8uIg" - }, - "outputs": [], - "source": [ - "for sample in dataloader:\n", - " image = sample['image']\n", - " target = sample['mask']" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "getting_started.ipynb", - "provenance": [] - }, - "execution": { - "timeout": 1200 - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/tutorials/getting_started.rst b/docs/tutorials/getting_started.rst new file mode 100644 index 00000000000..8049a95a20f --- /dev/null +++ b/docs/tutorials/getting_started.rst @@ -0,0 +1,18 @@ +Getting Started +=============== + +New to deep learning or remote sensing? First time using PyTorch or TorchGeo? You've come to the right place. + +The following tutorials will teach you enough to get started: + +* `Introduction to PyTorch `_: A brief overview of deep learning with PyTorch +* `Introduction to Geospatial Data `_: A brief overview of the challenges of working with geospatial data +* `Introduction to TorchGeo `_: A brief overview of the design of TorchGeo + +.. toctree:: + :hidden: + :maxdepth: 1 + + pytorch + geospatial + torchgeo diff --git a/docs/tutorials/indices.ipynb b/docs/tutorials/indices.ipynb index 30576609ac8..5695cd51b0d 100644 --- a/docs/tutorials/indices.ipynb +++ b/docs/tutorials/indices.ipynb @@ -1,14 +1,15 @@ { "cells": [ { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { "id": "DYndcZst_kdr" }, + "outputs": [], "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." ] }, { @@ -17,7 +18,9 @@ "id": "ZKIkyiLScf9P" }, "source": [ - "# Indices" + "# Spectral Indices\n", + "\n", + "_Written by: Isaac A. Corley_" ] }, { @@ -374,9 +377,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.13.0" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/docs/tutorials/pretrained_weights.ipynb b/docs/tutorials/pretrained_weights.ipynb index 28b354efee6..3a7023f0ee9 100644 --- a/docs/tutorials/pretrained_weights.ipynb +++ b/docs/tutorials/pretrained_weights.ipynb @@ -1,14 +1,15 @@ { "cells": [ { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { "id": "p63J-QmUrMN-" }, + "outputs": [], "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." ] }, { @@ -19,6 +20,8 @@ "source": [ "# Pretrained Weights\n", "\n", + "_Written by: Nils Lehmann_\n", + "\n", "In this tutorial, we demonstrate some available pretrained weights in TorchGeo. The implementation follows torchvisions' recently introduced [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/). We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images.\n", "\n", "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." @@ -147,9 +150,11 @@ "source": [ "## Weights\n", "\n", - "Available pretrained weights are listed on the model documentation [page](https://torchgeo.readthedocs.io/en/stable/api/models.html). While some weights only accept RGB channel input, some weights have been pretrained on Sentinel 2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel 2 data.\n", + "Pretrained weights for `torchgeo.models` are available and sorted by satellite or sensor type: sensor-agnostic, Landsat, NAIP, Sentinel-1, and Sentinel-2. Refer to the [model documentation](https://torchgeo.readthedocs.io/en/stable/api/models.html#pretrained-weights) for a complete list of weights. Choose from the provided pre-trained weights based on your specific use case.\n", + "\n", + "While some weights only accept RGB channel input, some weights have been pretrained on Sentinel-2 imagery with 13 input channels and can hence prove useful for transfer learning tasks involving Sentinel-2 data.\n", "\n", - "To access these weights you can do the following:" + "To use these weights, you can load them as follows:" ] }, { @@ -169,7 +174,16 @@ "id": "EIpnXuXgrMOM" }, "source": [ - "This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer object as well as the training logic." + "This set of weights is a torchvision `WeightEnum` and holds information such as the download url link or additional meta data. TorchGeo takes care of the downloading and initialization of models with a desired set of weights. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`torchgeo.trainers` provides specialized task classes that simplify training workflows for common geospatial tasks. Depending on your objective, you can select the appropriate trainer class, such as `ClassificationTask` for classification, `SemanticSegmentationTask` for semantic segmentation, or other task-specific trainers. Check the [trainers documentation](https://torchgeo.readthedocs.io/en/stable/api/trainers.html) for more information.\n", + "\n", + "Given that EuroSAT is a classification dataset, we can use a `ClassificationTask` object that holds the model and optimizer as well as the training logic." ] }, { @@ -215,7 +229,7 @@ "id": "dWidC6vDrMON" }, "source": [ - "If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/rwightman/pytorch-image-models) model with pretrained weights from TorchGeo as follows:" + "If you do not want to utilize the `ClassificationTask` functionality for your experiments, you can also just create a [timm](https://github.com/huggingface/pytorch-image-models) model with pretrained weights from TorchGeo as follows:" ] }, { @@ -495,7 +509,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.13.0" }, "vscode": { "interpreter": { @@ -504,5 +518,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/docs/tutorials/prolog.rst.jinja b/docs/tutorials/prolog.rst.jinja new file mode 100644 index 00000000000..2381495107b --- /dev/null +++ b/docs/tutorials/prolog.rst.jinja @@ -0,0 +1,31 @@ +{# Macros #} +{% macro image(badge, class, alt, target) %} +.. image:: {{ badge }} + :class: {{ class }} + :alt: {{ alt }} + :target: {{ target }} +{% endmacro %} + +{# Global variables #} +{% if "dev" in env.config.release %} + {% set branch = "main" %} +{% else %} + {% set branch = "releases/v" ~ env.config.version %} +{% endif %} +{% set class = "tutorial-badge" %} +{% set path = "/microsoft/torchgeo/blob/" ~ branch ~ "/docs/" ~ env.docname ~ ".ipynb" %} + +{# Lightning Studio #} +{% set badge = "https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" %} +{% set alt = "Open in Studio" %} +{% set repo_url = "https://github.com" ~ path %} +{% set target = "https://lightning.ai/new?repo_url=" ~ repo_url | urlencode %} + +{{ image(badge, class, alt, target) }} + +{# Google Colab #} +{% set badge = "https://colab.research.google.com/assets/colab-badge.svg" %} +{% set alt = "Open in Colab" %} +{% set target = "https://colab.research.google.com/github" ~ path %} + +{{ image(badge, class, alt, target) }} diff --git a/docs/tutorials/pytorch.ipynb b/docs/tutorials/pytorch.ipynb new file mode 100644 index 00000000000..b66d3cc95b8 --- /dev/null +++ b/docs/tutorials/pytorch.ipynb @@ -0,0 +1,507 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "45973fd5-6259-4e03-9501-02ee96f3f870", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "id": "9478ed9a", + "metadata": { + "id": "NdrXRgjU7Zih" + }, + "source": [ + "# Introduction to PyTorch\n", + "\n", + "_Written by: Adam J. Stewart_\n", + "\n", + "In this tutorial, we introduce the basics of deep learning with PyTorch. Understanding deep learning terminology and the training and evaluation pipeline in PyTorch is essential to using TorchGeo." + ] + }, + { + "cell_type": "markdown", + "id": "34f10e9f", + "metadata": { + "id": "lCqHTGRYBZcz" + }, + "source": [ + "## Setup\n", + "\n", + "First, we install TorchGeo and all of its dependencies, including PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "019092f0", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "cell_type": "markdown", + "id": "4db9f791", + "metadata": { + "id": "dV0NLHfGBMWl" + }, + "source": [ + "## Imports\n", + "\n", + "Next, we import PyTorch, TorchGeo, and any other libraries we need. We also manually set the random seed to ensure the reproducibility of our experiments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d92b0f1", + "metadata": { + "id": "entire-albania" + }, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "import kornia.augmentation as K\n", + "import torch\n", + "from torch import nn, optim\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from torchgeo.datasets import EuroSAT100\n", + "from torchgeo.models import ResNet18_Weights, resnet18\n", + "\n", + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "markdown", + "id": "9d13c2db-e5d4-4d83-846b-a2c32774bb44", + "metadata": {}, + "source": [ + "## Definitions\n", + "\n", + "If this is your first introduction to deep learning (DL), a natural question might be \"what _is_ deep learning?\". You may also be curious how it relates to other similar buzz words, including artificial intelligence (AI) and machine learning (ML). We can define these terms as follows:\n", + "\n", + "* AI: when machines exhibit human intelligence\n", + "* ML: when machines learn from example\n", + "* DL: when machines learn using neural networks\n", + "\n", + "In this definition, DL is a subset of ML, and ML is a subset of AI. Some common examples of models and applications of these include:\n", + "\n", + "* AI: Minimax, A*, Deep Blue, video game AI\n", + "* ML: OLS, SVM, $k$-means, spam filtering\n", + "* DL: MLP, CNN, ChatGPT, self-driving cars\n", + "\n", + "In this tutorial, we will specifically focus on deep learning, but many of the same concepts are shared with machine learning." + ] + }, + { + "cell_type": "markdown", + "id": "7f26e4b8", + "metadata": { + "id": "5rLknZxrBEMz" + }, + "source": [ + "## Datasets\n", + "\n", + "In order to learn by example, we first need examples. In machine learning, we construct datasets of the form:\n", + "\n", + "$$D = \\left\\{\\left(x^{(i)}, y^{(i)}\\right)\\right\\}_{i=1}^N$$\n", + "\n", + "Written in English, dataset $D$ is composed of $N$ pairs of inputs $x$ and expected outputs $y$. $x$ and $y$ can be tabular data, images, text, or any other object that can be represented mathematically.\n", + "\n", + "![EuroSAT](https://github.com/phelber/EuroSAT/blob/master/eurosat-overview.png?raw=true)\n", + "\n", + "In this tutorial (and many later tutorials), we will use EuroSAT100, a toy dataset composed of 100 images from the [EuroSAT](https://github.com/phelber/EuroSAT) dataset. EuroSAT is a popular image classification dataset with multispectral images from the Sentinel-2 satellites. Each image is classified into one of ten categories or \"classes\":\n", + "\n", + "0. Annual Crop\n", + "1. Forest\n", + "2. Herbaceous Vegetation\n", + "3. Highway\n", + "4. Industrial Buildings\n", + "5. Pasture\n", + "6. Permanent Crop\n", + "7. Residential Buildings\n", + "8. River\n", + "9. Sea & Lake\n", + "\n", + "We can load this dataset and visualize the RGB bands of some example $(x, y)$ pairs like so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa0c5a0c-ac4c-44c5-9fb7-fe4be07a0f01", + "metadata": {}, + "outputs": [], + "source": [ + "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", + "dataset = EuroSAT100(root, download=True)\n", + "\n", + "for i in torch.randint(len(dataset), (10,)):\n", + " sample = dataset[i]\n", + " dataset.plot(sample)" + ] + }, + { + "cell_type": "markdown", + "id": "f89e20ae-d3b6-4f05-a83f-f7034dd9862f", + "metadata": {}, + "source": [ + "In machine learning, we not only want to train a model, but also evaluate its performance on unseen data. Oftentimes, our dataset is split into three separate subsets:\n", + "\n", + "* train: for training the model *parameters*\n", + "* val: for validating the model *hyperparameters*\n", + "* test: for testing the model *performance*\n", + "\n", + "Parameters are the actual model weights, while hyperparameters are things like model width or learning rate that are chosen by the user. We can initialize datasets for all three splits like so:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4785cddb-9821-4a2a-aa86-c08ffb6f2ebc", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = EuroSAT100(root, split='train')\n", + "val_dataset = EuroSAT100(root, split='val')\n", + "test_dataset = EuroSAT100(root, split='test')" + ] + }, + { + "cell_type": "markdown", + "id": "3e92d5be-8400-4c8a-83b0-314a672f22d1", + "metadata": {}, + "source": [ + "## Data Loaders\n", + "\n", + "While our dataset objects know how to load a single $(x, y)$ pair, machine learning often operates on what are called *mini-batches* of data. We can pass our above datasets to a PyTorch DataLoader object to construct these mini-batches:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8909c035-cbe9-49b6-8380-360914093f9a", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 10\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)\n", + "val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False)\n", + "test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False)" + ] + }, + { + "cell_type": "markdown", + "id": "e7162d06-8814-4680-8192-aff279e70049", + "metadata": {}, + "source": [ + "## Transforms\n", + "\n", + "There are two categories of transforms a user may want to apply to their data:\n", + "\n", + "* Preprocessing: required to make data \"ML-ready\"\n", + "* Data augmentation: designed to artificially inflate the size of the dataset\n", + "\n", + "Preprocessing transforms such as normalization and one-hot encodings are applied to both training and evaluation data. Data augmentation transforms such as random flip and rotation are typically only performed during training. Below, we initialize transforms for both using the [Kornia](https://kornia.readthedocs.io/en/latest/augmentation.html) library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "efc0152c-e7e4-4f06-9418-9d2c5dd803c3", + "metadata": {}, + "outputs": [], + "source": [ + "preprocess = K.Normalize(0, 10000)\n", + "augment = K.ImageSequential(K.RandomHorizontalFlip(), K.RandomVerticalFlip())" + ] + }, + { + "cell_type": "markdown", + "id": "e80cda68-19ea-4cd5-a3bc-cf8fcf8147af", + "metadata": {}, + "source": [ + "## Model\n", + "\n", + "Our goal is to learn some function $f$ that can map between input $x$ and expected output $y$. Mathematically, this can be expressed as:\n", + "\n", + "$$x \\overset{f}{\\mapsto} y, \\quad y = f(x)$$\n", + "\n", + "Since our $x$ in this case is an image, we choose to use ResNet-18, a popular *convolutional neural network* (CNN). We also initialize our model with weights that have been pre-trained on Sentinel-2 imagery so we don't have to start from scratch. This process is known as *transfer learning*." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7cda7a8-2cd6-46a0-a1c2-bafc751a23f2", + "metadata": {}, + "outputs": [], + "source": [ + "model = resnet18(ResNet18_Weights.SENTINEL2_ALL_MOCO)" + ] + }, + { + "cell_type": "markdown", + "id": "a12b9e9b-26cc-43f4-a517-31b805862df5", + "metadata": {}, + "source": [ + "## Loss Function\n", + "\n", + "If $y$ is our expected output (also called \"ground truth\") and $\\hat{y}$ is our predicted output, our goal is to minimize the difference between $y$ and $\\hat{y}$. This difference is referred to as *error* or *loss*, and the loss function tells us how big of a mistake we made. For regression tasks, a simple mean squared error is sufficient:\n", + "\n", + "$$L(y, \\hat{y}) = \\left(y - \\hat{y}\\right)^2$$\n", + "\n", + "For classification tasks, such as EuroSAT, we instead use a negative log-likelihood:\n", + "\n", + "$$L_c(y, \\hat{y}) = - \\sum_{c=1}^C \\mathbb{1}_{y=\\hat{y}}\\log{p_c}$$\n", + "\n", + "where $\\mathbb{1}$ is the indicator function and $p_c$ is the probability with which the model predicts class $c$. By normalizing this over the log probability of all classes, we get the cross-entropy loss." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a0a699e-9bb3-4a06-91d2-401dd048ba66", + "metadata": {}, + "outputs": [], + "source": [ + "loss_fn = nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "markdown", + "id": "7743edf6-5fec-494d-8842-6cf8b45a2289", + "metadata": {}, + "source": [ + "## Optimizer\n", + "\n", + "In order to minimize our loss, we compute the gradient of the loss function with respect to model parameters $\\theta$. We then take a small step $\\alpha$ (also called the *learning rate*) in the direction of the negative gradient to update our model parameters in a process called *backpropagation*:\n", + "\n", + "$$\\theta \\leftarrow \\theta - \\alpha \\nabla_\\theta L(y, \\hat{y})$$\n", + "\n", + "When done one image or one mini-batch at a time, this is known as *stochastic gradient descent* (SGD)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15b0c17b-db53-41b2-96aa-3b732684b4cd", + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = optim.SGD(model.parameters(), lr=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "7efbe79d-a9a0-4a23-b2f3-21b4ea0af7bd", + "metadata": {}, + "source": [ + "## Device\n", + "\n", + "If you peak into the internals of deep learning models, you'll notice that most of it is actually linear algebra. This linear algebra is extremely easy to parallelize, and therefore can run very quickly on a GPU. We now transfer our model and all data to the GPU (if one is available):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a006a71f-0802-49b3-bd45-ddd524ae36a4", + "metadata": {}, + "outputs": [], + "source": [ + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "model = model.to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "7af95903-79c9-4a61-a7a4-2d41c884fba0", + "metadata": {}, + "source": [ + "## Training\n", + "\n", + "We finally have all the basic components we need to train our ResNet-18 model on the EuroSAT100 dataset. During training, we set the model to train mode, then iterate over all mini-batches in the dataset. During the forward pass, we ask the model $f$ to predict $\\hat{y}$ given $x$. We then calculate the loss accrued by these predictions. During the backward pass, we backpropagate our gradients to update all model weights." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d235772d-475e-42e7-bc7c-f50729ee0e22", + "metadata": {}, + "outputs": [], + "source": [ + "def train(dataloader):\n", + " model.train()\n", + " total_loss = 0\n", + " for batch in dataloader:\n", + " x = batch['image'].to(device)\n", + " y = batch['label'].to(device)\n", + "\n", + " # Forward pass\n", + " y_hat = model(x)\n", + " loss = loss_fn(y_hat, y)\n", + " total_loss += loss.item()\n", + "\n", + " # Backward pass\n", + " loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + "\n", + " print(f'Loss: {total_loss:.2f}')" + ] + }, + { + "cell_type": "markdown", + "id": "1fd82312-cd17-4886-bcb6-8e42633e5009", + "metadata": {}, + "source": [ + "## Evaluation\n", + "\n", + "Once the model is trained, we need to evaluate its performance on unseen data. To do this, we set the model to evaluation mode, then iterate over all mini-batches in the dataset. Note that we also disable the computation of gradients, since we do not need to backpropagate them. Finally, we compute the number of correctly classified images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bddce3b-ed2f-4a5c-b3c6-2f1a3a51c2d9", + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate(dataloader):\n", + " model.eval()\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for batch in dataloader:\n", + " x = batch['image'].to(device)\n", + " y = batch['label'].to(device)\n", + "\n", + " # Forward pass\n", + " y_hat = model(x)\n", + " correct += (y_hat.argmax(1) == y).type(torch.float).sum().item()\n", + "\n", + " correct /= len(dataloader.dataset)\n", + " print(f'Accuracy: {correct:.0%}')" + ] + }, + { + "cell_type": "markdown", + "id": "f62a54e5-897d-476c-8d84-381993dbabbd", + "metadata": {}, + "source": [ + "## Putting It All Together\n", + "\n", + "In machine learning, we typically iterate over our datasets multiple times. Each full pass through the dataset is called an *epoch*. The following hyperparameter controls the number of epoch for which we train our model, and can be modified to train the model for longer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb5dc7e8-6cb3-4457-83ad-7fa5aef8ea0c", + "metadata": { + "nbmake": { + "mock": { + "epochs": 1 + } + } + }, + "outputs": [], + "source": [ + "epochs = 100" + ] + }, + { + "cell_type": "markdown", + "id": "f53526e6-54a3-43f7-a377-dca298730387", + "metadata": {}, + "source": [ + "During each epoch, we train the model on our training dataset, then evaluate its performance on the validation dataset. The goal is for training loss to decrease and validation accuracy to increase, although you should expect noise in the training process. Generally, you want to train the model until the validation accuracy starts to plateau or even decrease." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97601568-ba75-443d-81cf-494956b2924c", + "metadata": {}, + "outputs": [], + "source": [ + "for epoch in range(epochs):\n", + " print(f'Epoch: {epoch}')\n", + " train(train_dataloader)\n", + " evaluate(val_dataloader)" + ] + }, + { + "cell_type": "markdown", + "id": "e130fc89-0823-4814-85f8-d4416d6df395", + "metadata": {}, + "source": [ + "Finally, we evaluate our performance on the test dataset. Note that we are only training our model on a toy dataset consisting of 100 images. If we instead trained on the full dataset (replace `EuroSAT100` with `EuroSAT` in the above code), we would likely get much higher performance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cd0bd25-e19a-4b26-94a1-fe9a544e8afd", + "metadata": {}, + "outputs": [], + "source": [ + "evaluate(test_dataloader)" + ] + }, + { + "cell_type": "markdown", + "id": "a3acc64e-8dc0-46b4-a677-ecb9723d4f56", + "metadata": {}, + "source": [ + "## Additional Reading\n", + "\n", + "If you are new to machine learning and overwhelmed by all of the above terminology, or would like to gain a better understanding of some of the math that goes into machine learning, I would highly recommend a formal machine learning or deep learning course. The following official PyTorch tutorials are also worth exploring:\n", + "\n", + "* [PyTorch: Learn the Basics](https://pytorch.org/tutorials/beginner/basics/intro.html)\n", + "* [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)\n", + "* [Transfer Learning for Computer Vision](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "getting_started.ipynb", + "provenance": [] + }, + "execution": { + "timeout": 1200 + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/tutorials/torchgeo.ipynb b/docs/tutorials/torchgeo.ipynb new file mode 100644 index 00000000000..cc6e17c6170 --- /dev/null +++ b/docs/tutorials/torchgeo.ipynb @@ -0,0 +1,445 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "45973fd5-6259-4e03-9501-02ee96f3f870", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "id": "9478ed9a", + "metadata": { + "id": "NdrXRgjU7Zih" + }, + "source": [ + "# Introduction to TorchGeo\n", + "\n", + "_Written by: Adam J. Stewart_\n", + "\n", + "Now that we've seen the basics of PyTorch and the challenges of working with geospatial data, let's see how TorchGeo addresses these challenges." + ] + }, + { + "cell_type": "markdown", + "id": "34f10e9f", + "metadata": { + "id": "lCqHTGRYBZcz" + }, + "source": [ + "## Setup\n", + "\n", + "First, we install TorchGeo and all of its dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "019092f0", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torchgeo" + ] + }, + { + "cell_type": "markdown", + "id": "4db9f791", + "metadata": { + "id": "dV0NLHfGBMWl" + }, + "source": [ + "## Imports\n", + "\n", + "Next, we import TorchGeo and any other libraries we need." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d92b0f1", + "metadata": { + "id": "entire-albania" + }, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "from datetime import datetime\n", + "\n", + "from matplotlib import pyplot as plt\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from torchgeo.datasets import CDL, BoundingBox, Landsat7, Landsat8, stack_samples\n", + "from torchgeo.datasets.utils import download_and_extract_archive\n", + "from torchgeo.samplers import GridGeoSampler, RandomGeoSampler" + ] + }, + { + "cell_type": "markdown", + "id": "b813beba-62ad-430c-96e5-1d81bef1e244", + "metadata": {}, + "source": [ + "## Motivation\n", + "\n", + "Let's start with a common task in geospatial machine learning to motivate us: land cover mapping. Imagine you have a collection of imagery and a land cover layer or *mask* you would like to learn to predict. In machine learning, this pixelwise classification process is referred to as *semantic segmentation*.\n", + "\n", + "More concretely, imagine you would like to combine a set of Landsat 7 and 8 scenes with the Cropland Data Layer (CDL). This presents a number of challenges for a typical machine learning pipeline:\n", + "\n", + "* We may have hundreds of partially overlapping Landsat images that need to be mosaiced together\n", + "* We have a single CDL mask covering the entire continental US\n", + "* Neither the Landsat input or CDL output will have the same geospatial bounds\n", + "* Landsat is multispectral, and may have a different resolution for each spectral band\n", + "* Landsat 7 and 8 have a different number of spectral bands\n", + "* Landsat and CDL may have a differerent CRS\n", + "* Every single Landsat file may be in a different CRS (e.g., multiple UTM zones)\n", + "* We may have multiple years of input and output data, and need to ensure matching time spans\n", + "\n", + "We can't have a dataset of length 1, and it isn't obvious what to do when the number, bounds, and size of input images differ from the output masks. Furthermore, each image is far too large to pass to a neural network. \n", + "\n", + "Traditionally, people either performed classification on a single pixel at a time or curated their own benchmark dataset. This works fine for training, but isn't really useful for inference. What we would really like to be able to do is sample small pixel-aligned pairs of input images and output masks from the region of overlap between both datasets. This exact situation is illustrated in the following figure:\n", + "\n", + "![Landsat CDL intersection](https://github.com/microsoft/torchgeo/blob/main/images/geodataset.png?raw=true)\n", + "\n", + "Now, let's see what features TorchGeo has to support this kind of use case." + ] + }, + { + "cell_type": "markdown", + "id": "41119706-0722-4fd0-85a7-787bb12bbab8", + "metadata": {}, + "source": [ + "## Datasets\n", + "\n", + "Geospatial data comes in a wide variety of formats. TorchGeo has two separate classes of datasets to deal with this dataset diversity:\n", + "\n", + "* `NonGeoDataset`: for curated benchmark datasets, where geospatial metadata is either missing or unnecessary\n", + "* `GeoDataset`: for uncurated raster and vector data layers, where geospatial metadata is critical for merging datasets\n", + "\n", + "We have already seen the former in the Introduction to PyTorch tutorial, as `EuroSAT100` is a subclass of `NonGeoDataset`. In this tutorial, we will focus on the latter and its advantages for working with uncurated data." + ] + }, + { + "cell_type": "markdown", + "id": "914b39c6-3373-4ae8-b9ea-d377e73e9fbe", + "metadata": {}, + "source": [ + "### Landsat\n", + "\n", + "First, let's start with our Landsat imagery. We will download a couple of Landsat 7 and 8 scenes, then pass them to builtin TorchGeo datasets for each." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48d2a61d-16bb-4809-9da0-3bd369bff070", + "metadata": {}, + "outputs": [], + "source": [ + "landsat_root = os.path.join(tempfile.gettempdir(), 'landsat')\n", + "\n", + "url = 'https://hf.co/datasets/torchgeo/tutorials/resolve/ff30b729e3cbf906148d69a4441cc68023898924/'\n", + "landsat7_url = url + 'LE07_L2SP_022032_20230725_20230820_02_T1.tar.gz'\n", + "landsat8_url = url + 'LC08_L2SP_023032_20230831_20230911_02_T1.tar.gz'\n", + "\n", + "download_and_extract_archive(landsat7_url, landsat_root)\n", + "download_and_extract_archive(landsat8_url, landsat_root)\n", + "\n", + "landsat7_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']\n", + "landsat8_bands = ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']\n", + "\n", + "landsat7 = Landsat7(paths=landsat_root, bands=landsat7_bands)\n", + "landsat8 = Landsat8(paths=landsat_root, bands=landsat8_bands)\n", + "\n", + "print(landsat7)\n", + "print(landsat8)\n", + "\n", + "print(landsat7.crs)\n", + "print(landsat8.crs)" + ] + }, + { + "cell_type": "markdown", + "id": "ce12838a-1010-46cb-bcca-6379f9e327ac", + "metadata": {}, + "source": [ + "The following details are worth noting:\n", + "\n", + "* We ignore the \"coastal blue\" band of Landsat 8 because it does not exist in Landsat 7\n", + "* Even though all files are stored in the same directory, the datasets know which files to include\n", + "* `paths` can be a directory to recursively search, a list of local files, or even a list of remote cloud assets" + ] + }, + { + "cell_type": "markdown", + "id": "a51c5df2-5543-41ae-a9cf-254e29b6bdfd", + "metadata": {}, + "source": [ + "### CDL\n", + "\n", + "Next, let's do the same for the CDL dataset. We are using a smaller cropped version of this dataset to make the download faster." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "909d233b-b212-48f1-b910-3065f8fcf083", + "metadata": {}, + "outputs": [], + "source": [ + "cdl_root = os.path.join(tempfile.gettempdir(), 'cdl')\n", + "\n", + "cdl_url = url + '2023_30m_cdls.zip'\n", + "\n", + "download_and_extract_archive(cdl_url, cdl_root)\n", + "\n", + "cdl = CDL(paths=cdl_root)\n", + "\n", + "print(cdl)\n", + "print(cdl.crs)" + ] + }, + { + "cell_type": "markdown", + "id": "571a6512-494f-401a-bf2f-599f28b2fad5", + "metadata": {}, + "source": [ + "Again, the following details are worth noting:\n", + "\n", + "* We could actually ask the `CDL` dataset to download our data for us by adding `download=True`\n", + "* All datasets have different spatial extents\n", + "* All datasets have different CRSs" + ] + }, + { + "cell_type": "markdown", + "id": "4a15b938-3277-46bc-86e4-a5d7f57e838a", + "metadata": {}, + "source": [ + "### Composing datasets\n", + "\n", + "We would like to be able to intelligently combine all three datasets in order to train a land cover mapping model. This requires us to create a virtual mosaic of all Landsat scenes, regardless of overlap. This can be done by taking the *union* of both datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b5adace-d7c9-4c27-9e53-ae532b081046", + "metadata": {}, + "outputs": [], + "source": [ + "landsat = landsat7 | landsat8\n", + "print(landsat)\n", + "print(landsat.crs)" + ] + }, + { + "cell_type": "markdown", + "id": "ddac6f18-36de-4241-a150-0ee50d0f40dd", + "metadata": {}, + "source": [ + "Similarly, we only want to sample from locations with both input imagery and output masks, not locations with only one or the other. We can achieve this by taking the *intersection* of both datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6dd9f067-0e00-47ac-8bc1-6e7cd9e41e4d", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = landsat & cdl\n", + "print(dataset)\n", + "print(dataset.crs)" + ] + }, + { + "cell_type": "markdown", + "id": "48d2afbe-aab8-415e-a0df-fdb0d5209a49", + "metadata": {}, + "source": [ + "Note that all datasets now have the same CRS. When you run this code, you should notice it happen very quickly. TorchGeo hasn't actually created a mosaic yet or reprojected anything, it will do this on the fly for us." + ] + }, + { + "cell_type": "markdown", + "id": "4df7ee26-2c11-4e70-b113-e633fbbc2cd9", + "metadata": {}, + "source": [ + "### Spatiotemporal indexing\n", + "\n", + "How did we do this? TorchGeo uses a data structure called an *R-tree* to store the spatiotemporal bounding box of every file in the dataset. \n", + "\n", + "![R-tree](https://raw.githubusercontent.com/davidmoten/davidmoten.github.io/master/resources/rtree-3d/plot2.png)\n", + "\n", + "TorchGeo extracts the spatial bounding box from the metadata of each file, and the timestamp from the filename. This geospatial and geotemporal metadata allows us to efficiently compute the intersection or union of two datasets. It also lets us quickly retrieve an image and corresponding mask for a particular location in space and time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3992c571-0a6f-4d28-a2dc-e5915c00901e", + "metadata": {}, + "outputs": [], + "source": [ + "size = 256\n", + "\n", + "xmin = 925000\n", + "xmax = xmin + size * 30\n", + "ymin = 4470000\n", + "ymax = ymin + size * 30\n", + "tmin = datetime(2023, 1, 1).timestamp()\n", + "tmax = datetime(2023, 12, 31).timestamp()\n", + "\n", + "bbox = BoundingBox(xmin, xmax, ymin, ymax, tmin, tmax)\n", + "sample = dataset[bbox]\n", + "\n", + "landsat8.plot(sample)\n", + "cdl.plot(sample)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "bc591543-6d74-47b3-8c24-feada66d0a38", + "metadata": {}, + "source": [ + "TorchGeo uses *windowed-reading* to only read the blocks of memory needed to load a small patch from a large raster tile. It also automatically reprojects all data to the same CRS and resolution (from the first dataset). This can be controlled by explicitly passing `crs` or `res` to the dataset." + ] + }, + { + "cell_type": "markdown", + "id": "e2e4221e-dfb7-4966-96a6-e52400ae266c", + "metadata": {}, + "source": [ + "## Samplers\n", + "\n", + "The above `BoundingBox` makes it easy to index into complex datasets consisting of hundreds of files. However, it is a bit cumbersome to manually construct these queries every time, especially if we want thousands or even millions of bounding boxes. Luckily, TorchGeo provides a `GeoSampler` class to construct these for us." + ] + }, + { + "cell_type": "markdown", + "id": "47a7423d-32a9-40ae-be62-d54805835b19", + "metadata": {}, + "source": [ + "### Random sampling\n", + "\n", + "Usually, at training time, we want the largest possible dataset we can muster. For curated benchmark datasets like `EuroSAT100`, we achieved this by applying data augmentation to artificially inflate the size and diversity of our dataset. For `GeoDataset` objects, we can achieve this using random sampling. It doesn't matter if two or more of our images have partial overlap, as long as they bring unique pixels that help our model learn. \n", + "\n", + "TorchGeo provides a `RandomGeoSampler` to achieve this. We just tell the sampler how large we want each image patch to be (in pixel coordinates or CRS units) and, optionally, the number of image patches per epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36a60164-aa88-4773-a38f-d40960f4bfb2", + "metadata": {}, + "outputs": [], + "source": [ + "train_sampler = RandomGeoSampler(dataset, size=size, length=1000)\n", + "next(iter(train_sampler))" + ] + }, + { + "cell_type": "markdown", + "id": "b6d35b26-edae-46dc-b232-878421faa84d", + "metadata": {}, + "source": [ + "### Gridded sampling\n", + "\n", + "At evaluation time, this actually becomes a problem. We want to make sure we aren't making multiple predictions for the same location. We also want to make sure we don't miss any locations. To achieve this, TorchGeo also provides a `GridGeoSampler`. We can tell the sampler the size of each image patch and the stride of our sliding window." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33340c1a-756f-4ffe-ae3d-c2307fc98d07", + "metadata": {}, + "outputs": [], + "source": [ + "test_sampler = GridGeoSampler(dataset, size=size, stride=size)\n", + "next(iter(test_sampler))" + ] + }, + { + "cell_type": "markdown", + "id": "b9806919-6520-4da6-9eb3-3e1e6a10498e", + "metadata": {}, + "source": [ + "## Data Loaders\n", + "\n", + "All of these abstractions (`GeoDataset` and `GeoSampler`) are fully compatible with all of the rest of PyTorch. We can simply pass them to a data loader like below. Note that we also need the `stack_samples` collation function to convert a list of samples to a mini-batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd44d29d-b7c0-4617-bb94-d41a14e8f54a", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataloader = DataLoader(\n", + " dataset, batch_size=128, sampler=train_sampler, collate_fn=stack_samples\n", + ")\n", + "test_dataloader = DataLoader(\n", + " dataset, batch_size=128, sampler=test_sampler, collate_fn=stack_samples\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e46e8453-df25-4265-a85b-75dce7dea047", + "metadata": {}, + "source": [ + "Now that we have working data loaders, we can copy-n-paste our training code from the Introduction to PyTorch tutorial. We only need to change our model to one designed for semantic segmentation, such as a U-Net. Every other line of code would be identical to how you would do this in your normal PyTorch workflow." + ] + }, + { + "cell_type": "markdown", + "id": "a3acc64e-8dc0-46b4-a677-ecb9723d4f56", + "metadata": {}, + "source": [ + "## Additional Reading\n", + "\n", + "TorchGeo has plenty of other tutorials and documentation. If you would like to get more insight into the design of TorchGeo, the following external resources are also helpful:\n", + "\n", + "* [TorchGeo: Deep Learning With Geospatial Data](https://arxiv.org/abs/2111.08872)\n", + "* [Geospatial deep learning with TorchGeo](https://pytorch.org/blog/geospatial-deep-learning-with-torchgeo/)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "getting_started.ipynb", + "provenance": [] + }, + "execution": { + "timeout": 1200 + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/tutorials/trainers.ipynb b/docs/tutorials/trainers.ipynb index 5de3937a026..0394f3faebc 100644 --- a/docs/tutorials/trainers.ipynb +++ b/docs/tutorials/trainers.ipynb @@ -1,15 +1,16 @@ { "cells": [ { - "cell_type": "markdown", - "id": "b13c2251", + "cell_type": "code", + "execution_count": null, + "id": "16421d50-8d7a-4972-b06f-160fd890cc86", "metadata": { "id": "b13c2251" }, + "outputs": [], "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." ] }, { @@ -21,6 +22,8 @@ "source": [ "# Lightning Trainers\n", "\n", + "_Written by: Caleb Robinson_\n", + "\n", "In this tutorial, we demonstrate TorchGeo trainers to train and test a model. We will use the [EuroSAT](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) dataset throughout this tutorial. Specifically, a subset containing only 100 images. We will train models to predict land cover classes.\n", "\n", "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." @@ -328,7 +331,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.13.0" } }, "nbformat": 4, diff --git a/docs/tutorials/transforms.ipynb b/docs/tutorials/transforms.ipynb index a7de9f32c69..689b2eebd33 100644 --- a/docs/tutorials/transforms.ipynb +++ b/docs/tutorials/transforms.ipynb @@ -1,14 +1,15 @@ { "cells": [ { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { "id": "DYndcZst_kdr" }, + "outputs": [], "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT License." ] }, { @@ -20,6 +21,13 @@ "# Transforms" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "_Written by: Isaac A. Corley_" + ] + }, { "cell_type": "markdown", "metadata": { @@ -93,7 +101,7 @@ "from torch.utils.data import DataLoader\n", "\n", "from torchgeo.datasets import EuroSAT100\n", - "from torchgeo.transforms import AugmentationSequential, indices" + "from torchgeo.transforms import indices" ] }, { @@ -208,11 +216,11 @@ " 'B06': 'Vegetation Red Edge 2',\n", " 'B07': 'Vegetation Red Edge 3',\n", " 'B08': 'NIR 1',\n", - " 'B8A': 'NIR 2',\n", " 'B09': 'Water Vapour',\n", " 'B10': 'SWIR 1',\n", " 'B11': 'SWIR 2',\n", " 'B12': 'SWIR 3',\n", + " 'B8A': 'NIR 2',\n", "}" ] }, @@ -408,7 +416,7 @@ "id": "p28C8cTGE3dP" }, "source": [ - "Transforms are able to operate across batches of samples and singular samples. This allows them to be used inside the dataset itself or externally, chained together with other transform operations using `nn.Sequential`. " + "`torchgeo.transforms` work seamlessly with both singular samples and batches of data. They can be applied within datasets or externally and combined with other transforms using `nn.Sequential`. Built for multispectral imagery, they are fully compatible with `torchvision.transforms` and `kornia.augmentation`." ] }, { @@ -429,13 +437,24 @@ "print(x.dtype, x.min(), x.max())" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Appending Indices\n", + "\n", + "`torchgeo.transforms` support appending indices to a specified channel dimension.\n", + "\n", + "For detailed usage of all available transforms, refer to the [transforms documentation](https://torchgeo.readthedocs.io/en/stable/api/transforms.html)." + ] + }, { "cell_type": "markdown", "metadata": { "id": "KRjb-u0EEmDf" }, "source": [ - "Indices can also be computed on batches of images and appended as an additional band to the specified channel dimension. Notice how the number of channels increases from 13 -> 14." + "The following example shows how indices can be computed on batches of images and appended as an additional band to the specified channel dimension. Notice how the number of channels increases from 13 -> 14." ] }, { @@ -500,7 +519,9 @@ "id": "w4ZbjxPyHoiB" }, "source": [ - "It's even possible to chain indices along with augmentations from Kornia for a single callable during training." + "It's even possible to chain indices along with augmentations from Kornia for a single callable during training.\n", + "\n", + "When using Kornia with a dictionary input, you must explicitly set `data_keys=None` during the creation of the augmentation pipeline." ] }, { @@ -515,7 +536,7 @@ }, "outputs": [], "source": [ - "transforms = AugmentationSequential(\n", + "transforms = K.AugmentationSequential(\n", " MinMaxNormalize(mins, maxs),\n", " indices.AppendNDBI(index_swir=11, index_nir=7),\n", " indices.AppendNDSI(index_green=3, index_swir=11),\n", @@ -523,7 +544,7 @@ " indices.AppendNDWI(index_green=2, index_nir=7),\n", " K.RandomHorizontalFlip(p=0.5),\n", " K.RandomVerticalFlip(p=0.5),\n", - " data_keys=['image'],\n", + " data_keys=None,\n", ")\n", "\n", "batch = next(dataloader)\n", @@ -569,7 +590,7 @@ "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", - "transforms = AugmentationSequential(\n", + "transforms = K.AugmentationSequential(\n", " MinMaxNormalize(mins, maxs),\n", " indices.AppendNDBI(index_swir=11, index_nir=7),\n", " indices.AppendNDSI(index_green=3, index_swir=11),\n", @@ -580,10 +601,10 @@ " K.RandomAffine(degrees=(0, 90), p=0.25),\n", " K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n", " K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n", - " data_keys=['image'],\n", + " data_keys=None,\n", ")\n", "\n", - "transforms_gpu = AugmentationSequential(\n", + "transforms_gpu = K.AugmentationSequential(\n", " MinMaxNormalize(mins.to(device), maxs.to(device)),\n", " indices.AppendNDBI(index_swir=11, index_nir=7),\n", " indices.AppendNDSI(index_green=3, index_swir=11),\n", @@ -594,7 +615,7 @@ " K.RandomAffine(degrees=(0, 90), p=0.25),\n", " K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0), p=0.25),\n", " K.RandomResizedCrop(size=(512, 512), scale=(0.8, 1.0), p=0.25),\n", - " data_keys=['image'],\n", + " data_keys=None,\n", ").to(device)\n", "\n", "\n", @@ -664,7 +685,7 @@ }, "outputs": [], "source": [ - "transforms = AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=['image'])\n", + "transforms = K.AugmentationSequential(MinMaxNormalize(mins, maxs), data_keys=None)\n", "dataset = EuroSAT100(root, transforms=transforms)" ] }, @@ -689,6 +710,18 @@ "print(f\"Class Label: {dataset.classes[sample['label']]}\")\n", "image.resize((256, 256), resample=Image.BILINEAR)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Additional Reading\n", + "\n", + "To learn more about preprocessing and data augmentation transforms, the following external resources may be helpful:\n", + "\n", + "* [Kornia augmentations](https://kornia.readthedocs.io/en/latest/augmentation.html)\n", + "* [torchvision transforms](https://pytorch.org/vision/main/transforms.html)" + ] } ], "metadata": { @@ -717,9 +750,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.13.0" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 } diff --git a/docs/user/alternatives.rst b/docs/user/alternatives.rst index a87bc2b8b2c..dd241f55f1c 100644 --- a/docs/user/alternatives.rst +++ b/docs/user/alternatives.rst @@ -13,7 +13,7 @@ When deciding which library is most useful to you, it is worth considering the f Software is a living, breathing organism and is constantly undergoing change. If any of the above information is incorrect or out of date, or if you want to add a new project to this list, please open a PR! - *Last updated: 25 December 2023* + *Last updated: 30 November 2024* Features -------- @@ -22,7 +22,7 @@ Features .. csv-table:: :align: center - :file: features.csv + :file: metrics/features.csv :header-rows: 1 :widths: auto @@ -55,7 +55,7 @@ These are metrics that can be scraped from GitHub. .. csv-table:: :align: right - :file: github.csv + :file: metrics/github.csv :header-rows: 1 :widths: auto @@ -75,7 +75,7 @@ These are metrics that can be scraped from GitHub. **Commits**: The number of commits on the main development branch. This is another metric for how active development has been. However, this can vary a lot depending on whether PRs are merged with or without squashing first. -**Core SLOCs**: The number of source lines of code in the core library, excluding empty lines and comments. This tells you how large the library is, and how long it would take someone to write something like it themselves. We use `scc `_ to compute SLOCs. +**Core SLOCs**: The number of source lines of code in the core library, excluding empty lines and comments. This tells you how large the library is, and how long it would take someone to write something like it themselves. We use `scc `_ to compute SLOCs and exclude markup languages from the count. **Test SLOCs**: The number of source lines of code in the testing suite, excluding empty lines and comments. This tells you how well tested the project is. A good goal to strive for is a similar amount of code for testing as there is in the core library itself. @@ -86,17 +86,17 @@ These are metrics that can be scraped from GitHub. Downloads --------- -These are download metrics for the project. Note that these numbers can be artificially inflated by installs during continuous integration. They give you a better idea of the number of projects that depend on a library than the number of users of that library. +These are download metrics for the project. Note that these numbers can be artificially inflated by mirrors and installs during continuous integration. They give you a better idea of the number of projects that depend on a library than the number of users of that library. .. csv-table:: :align: right - :file: downloads.csv + :file: metrics/downloads.csv :header-rows: 1 :widths: auto -**PyPI Downloads**: The number of downloads from the Python Packaging Index. PyPI download metrics are computed by `PePy `_. +**PyPI Downloads**: The number of downloads from the Python Packaging Index. PyPI download metrics are computed by `PyPI Stats `_ and `PePy `_. -**CRAN Downloads**: The number of downloads from the Comprehensive R Archive Network. CRAN download metrics are computed by `Meta CRAN `_. +**CRAN Downloads**: The number of downloads from the Comprehensive R Archive Network. CRAN download metrics are computed by `Meta CRAN `_ and `DataScienceMeta `_. **Conda Downloads**: The number of downloads from Conda Forge. Conda download metrics are computed by `Conda Forge `_. @@ -113,5 +113,5 @@ These are download metrics for the project. Note that these numbers can be artif .. _DeepForest: https://github.com/weecology/DeepForest .. _SITS: https://github.com/e-sensing/sits .. _segment-geospatial: https://github.com/opengeos/segment-geospatial -.. _GeoTorchAI: https://github.com/wherobots/GeoTorchAI -.. _Moonshine: https://github.com/moonshinelabs-ai/moonshine +.. _TerraTorch: https://github.com/IBM/terratorch +.. _scikit-eo: https://github.com/yotarazona/scikit-eo diff --git a/docs/user/contributing.rst b/docs/user/contributing.rst index 5e44212d884..033fcecdea2 100644 --- a/docs/user/contributing.rst +++ b/docs/user/contributing.rst @@ -1,3 +1,5 @@ +.. _contributing: + Contributing ============ @@ -52,11 +54,11 @@ Tests TorchGeo uses `GitHub Actions `_ for Continuous Integration. We run a suite of unit tests on every commit to ensure that pull requests don't break anything. If you submit a pull request that adds or modifies any Python code, we require unit tests for that code before the pull request can be merged. -For example, if you add a new dataset in ``torchgeo/datasets/foo.py``, you'll need to create corresponding unit tests in ``tests/datasets/test_foo.py``. The easiest way to do this is to find unit tests for similar datasets and modify them for your dataset. These tests can then be run with `pytest `_: +For example, if you add a new dataset in ``torchgeo/datasets/foo.py``, you'll need to create corresponding unit tests in ``tests/datasets/test_foo.py``. The easiest way to do this is to find unit tests for similar datasets and modify them for your dataset. These tests can then be run with `pytest `_: .. code-block:: console - $ pytest --cov=torchgeo/datasets --cov-report=term-missing tests/datasets/test_foo.py + $ pytest --cov=torchgeo.datasets.foo tests/datasets/test_foo.py ========================= test session starts ========================= platform darwin -- Python 3.10.11, pytest-6.2.4, py-1.9.0, pluggy-0.13.0 rootdir: ~/torchgeo, configfile: pyproject.toml @@ -65,21 +67,19 @@ For example, if you add a new dataset in ``torchgeo/datasets/foo.py``, you'll ne tests/datasets/test_foo.py ....... [100%] - ---------- coverage: platform darwin, python 3.10.11-final-0 ----------- + --------- coverage: platform darwin, python 3.10.11-final-0 ----------- Name Stmts Miss Cover Missing ----------------------------------------------------------------------- - torchgeo/datasets/__init__.py 26 0 100% torchgeo/datasets/foo.py 177 62 65% 376-403, 429-496, 504-509 - ... ----------------------------------------------------------------------- - TOTAL 1709 920 46% + TOTAL 177 62 65% ========================== 7 passed in 6.20s ========================== From this output, you can see that all tests pass, but many lines of code in ``torchgeo/datasets/foo.py`` are not being tested, including 376--403, 429--496, etc. In order for this pull request to be merged, additional tests will need to be added until there is 100% test coverage. -These tests require `pytest `_ and `pytest-cov `_ to be installed. +These tests require `pytest `_ and `pytest-cov `_ to be installed. .. note:: If you add a new dataset, the tests will require some form of data to run. This data should be stored in ``tests/data/``. Please don't include real data, as this may violate the license the data is distributed under, and can involve very large file sizes. Instead, create fake data examples using the instructions found `here `__. @@ -88,18 +88,18 @@ These tests require `pytest `_ and `pytest-cov `_ compliant and maintain a high-quality codebase, we use a couple of linting tools: +In order to remain `PEP-8 `_ compliant and maintain a high-quality codebase, we use a few linting tools: * `ruff `_ for code formatting -* `mypy `_ for static type analysis +* `mypy `_ for static type analysis * `prettier `_ for code formatting These tools should be used from the root of the project to ensure that our configuration files are found. Ruff is relatively easy to use, and will automatically fix most issues it encounters: .. code-block:: console - $ ruff check $ ruff format + $ ruff check Mypy won't fix your code for you, but will warn you about potential issues with your code: @@ -111,21 +111,11 @@ Mypy won't fix your code for you, but will warn you about potential issues with If you've never used mypy before or aren't familiar with `Python type hints `_, this check can be particularly daunting. Don't hesitate to ask for help with resolving any of these warnings on your pull request. -Prettier is a code formatter that helps to ensure consistent code style across a project. It supports various languages. Follow these steps to install Prettier: - -1. Install Node.js: Prettier is a Node.js module, so you need to have Node.js installed on your system. You can download and install Node.js from the `Node.js official website `_. -2. Install Prettier: Use the following command to install the Prettier module in your project: - -.. code-block:: console - - $ npm install prettier --no-save - - -3. Run Prettier: Use the following command to run Prettier formating: +Prettier is a code formatter that helps to ensure consistent code style across a project. It supports various languages. .. code-block:: console - $ npx prettier . --write + $ prettier --write . You can also use `git pre-commit hooks `_ to automatically run these checks before each commit. pre-commit is a tool that automatically runs linters locally, so that you don't have to remember to run them manually and then have your code flagged by CI. You can set up pre-commit with: @@ -142,7 +132,7 @@ Now, every time you run ``git commit``, pre-commit will run and let you know if Documentation ------------- -All of our documentation is hosted on `Read the Docs `_. If you make non-trivial changes to the documentation, it helps to build the documentation yourself locally. To do this, make sure the dependencies are installed: +All of our documentation is hosted on `Read the Docs `_. If you make non-trivial changes to the documentation, it helps to build the documentation yourself locally. To do this, make sure the dependencies are installed: .. code-block:: console @@ -164,7 +154,7 @@ The resulting HTML files can be found in ``_build/html``. Open ``index.html`` in Tutorials --------- -TorchGeo has a number of tutorials included in the documentation that can be run in `Google Colab `_. These Jupyter notebooks are tested before each release to make sure that they still run properly. To test these locally, install `pytest `_ and `nbmake `_ and run: +TorchGeo has a number of tutorials included in the documentation that can be run in `Lightning Studios `_ and `Google Colab `_. These Jupyter notebooks are tested before each release to make sure that they still run properly. To test these locally, install `pytest `_ and `nbmake `_ and run: .. code-block:: console @@ -177,13 +167,14 @@ Datasets A major component of TorchGeo is the large collection of :mod:`torchgeo.datasets` that have been implemented. Adding new datasets to this list is a great way to contribute to the library. A brief checklist to follow when implementing a new dataset: * Implement the dataset extending either :class:`~torchgeo.datasets.GeoDataset` or :class:`~torchgeo.datasets.NonGeoDataset` -* Add the dataset definition to ``torchgeo/datasets/__init__.py`` -* Add a ``data.py`` script to ``tests/data//`` that generates test data with the same directory structure/file naming conventions as the new dataset -* Add appropriate tests with 100% test coverage to ``tests/datasets/`` +* Add the dataset definition to ``torchgeo/datasets/foo.py``, where *foo* is the name of the dataset +* Add an import alias to this dataset in ``torchgeo/datasets/__init__.py`` +* Add a ``tests/data/foo/data.py`` script that generates fake test data with the same directory structure/file naming conventions as the real dataset +* Add appropriate tests with 100% test coverage to ``tests/datasets/test_foo.py`` * Add the dataset to ``docs/api/datasets.rst`` -* Add the dataset metadata to either ``docs/api/geo_datasets.csv`` or ``docs/api/non_geo_datasets.csv`` +* Add the dataset metadata to either ``docs/api/datasets/geo_datasets.csv`` or ``docs/api/datasets/non_geo_datasets.csv`` -A good way to get started is by looking at some of the existing implementations that are most closely related to the dataset that you are implementing (e.g. if you are implementing a semantic segmentation dataset, looking at the LandCover.ai dataset implementation would be a good starting point). +A good way to get started is by looking at some of the existing implementations that are most closely related to the dataset that you are implementing (e.g., if you are implementing a semantic segmentation dataset, looking at the LandCover.ai dataset implementation would be a good starting point). I/O Benchmarking ---------------- diff --git a/docs/user/downloads.csv b/docs/user/downloads.csv deleted file mode 100644 index e82a1048255..00000000000 --- a/docs/user/downloads.csv +++ /dev/null @@ -1,10 +0,0 @@ -Library,PyPI/CRAN Last Week,PyPI/CRAN Last Month,PyPI/CRAN All Time,Conda All Time,Total All Time -`TorchGeo`_,"4,227","15,709","172,020","14,994","187,014" -`eo-learn`_,542,"2,224","119,657","29,806","149,463" -`Raster Vision`_,286,"1,446","53,029","2,211","55,240" -`PaddleRS`_,18,56,"1,196",0,"1,196" -`DeepForest`_,"2,429","11,092","686,997","48,368","735,365" -`SITS`_,128,728,"8,438","54,587","63,025" -`segment-geospatial`_,987,"6,088","64,067","8,226","72,293" -`GeoTorchAI`_,40,150,"2,210",0,"2,210" -`Moonshine`_,61,275,"6,243",0,"6,243" diff --git a/docs/user/github.csv b/docs/user/github.csv deleted file mode 100644 index 32a55c1c977..00000000000 --- a/docs/user/github.csv +++ /dev/null @@ -1,10 +0,0 @@ -Library,Contributors,Forks,Watchers,Stars,Issues,PRs,Releases,Commits,Core SLOCs,Test SLOCs,Test Coverage,License -`TorchGeo`_,56,243,44,"2,010",366,"1,315",10,"1,720","28,942","15,488",100%,MIT -`eo-learn`_,40,288,45,"1,058",159,622,40,"2,439","8,135","5,915",92%,MIT -`Raster Vision`_,30,375,74,"1,950",687,"1,220",19,"3,416","20,965","8,339",86%,Apache-2.0 -`PaddleRS`_,22,83,12,311,68,112,1,643,"21,859","2,156",48%,Apache-2.0 -`DeepForest`_,13,154,15,409,368,206,43,650,"2,375","1,149",86%,MIT -`SITS`_,13,74,29,415,564,495,41,"5,720","22,770","6,162",95%,GPL-2.0 -`segment-geospatial`_,11,244,51,"2,453",102,82,22,156,"5,355",92,22%,MIT -`GeoTorchAI`_,4,31,13,431,22,20,1,207,"6,153",550,38%,AGPL-3.0 -`Moonshine`_,1,1,4,121,2,5,1,48,245,56,69%,MIT diff --git a/docs/user/installation.rst b/docs/user/installation.rst index 3dd5545b774..9a96af95534 100644 --- a/docs/user/installation.rst +++ b/docs/user/installation.rst @@ -1,7 +1,7 @@ Installation ============ -TorchGeo is simple and easy to install. We support installation using the `pip `_, `conda `_, and `spack `_ package managers. +TorchGeo is simple and easy to install. We support installation using the `pip `_, `conda `_, and `spack `_ package managers. pip --- @@ -34,7 +34,7 @@ By default, only required dependencies are installed. TorchGeo has a number of o $ pip install torchgeo[style,tests] $ pip install torchgeo[all] -See the ``pyproject.toml`` for a complete list of options. See the `pip documentation `_ for more details. +See the ``pyproject.toml`` for a complete list of options. See the `pip documentation `_ for more details. conda ----- @@ -82,4 +82,4 @@ Optional dependencies can be installed by enabling build variants: $ spack install py-torchgeo+datasets $ spack install py-torchgeo+style+tests -Run ``spack info py-torchgeo`` for a complete list of variants. See the `spack documentation `_ for more details. +Run ``spack info py-torchgeo`` for a complete list of variants. See the `spack documentation `_ for more details. diff --git a/docs/user/metrics/downloads.csv b/docs/user/metrics/downloads.csv new file mode 100644 index 00000000000..c7049b95d89 --- /dev/null +++ b/docs/user/metrics/downloads.csv @@ -0,0 +1,10 @@ +Library,PyPI/CRAN Last Week,PyPI/CRAN Last Month,PyPI/CRAN All Time,Conda All Time,Total All Time +`TorchGeo`_,"8,435","30,948","311,897","25,174","337,071" +`eo-learn`_,309,"2,370","156,309","40,325","196,634" +`Raster Vision`_,"9,198","31,588","115,670","3,968","119,638" +`PaddleRS`_,16,53,"2,029",0,"2,029" +`segment-geospatial`_,"1,956","10,689","157,443","26,576","184,019" +`DeepForest`_,767,"13,925","827,339","71,367","898,706" +`TerraTorch`_,318,"1,322","7,037",0,"7,037" +`SITS`_,120,539,"14,618","78,976","91,743" +`scikit-eo`_,115,717,"14,700",0,"14,700" diff --git a/docs/user/features.csv b/docs/user/metrics/features.csv similarity index 70% rename from docs/user/features.csv rename to docs/user/metrics/features.csv index f9dc1b00dbe..67e26cd710e 100644 --- a/docs/user/features.csv +++ b/docs/user/metrics/features.csv @@ -1,10 +1,10 @@ Library,ML Backend,I/O Backend,Spatial Backend,Transform Backend,Datasets,Weights,CLI,Reprojection,STAC,Time-Series -`TorchGeo`_,PyTorch,"GDAL, h5py, laspy, OpenCV, pandas, pillow, scipy",R-tree,Kornia,71,43,✅,✅,❌,🚧 +`TorchGeo`_,PyTorch,"GDAL, h5py, laspy, OpenCV, pandas, pillow, scipy",R-tree,Kornia,92,69,✅,✅,❌,🚧 `eo-learn`_,scikit-learn,"GDAL, OpenCV, pandas",geopandas,numpy,0,0,❌,✅,❌,🚧 `Raster Vision`_,"PyTorch, TensorFlow*","GDAL, OpenCV, pandas, pillow, scipy, xarray",STAC,Albumentations,0,6,✅,✅,✅,✅ `PaddleRS`_,PaddlePaddle,"GDAL, OpenCV",shapely,numpy,7,14,🚧,✅,❌,🚧 -`DeepForest`_,PyTorch,"GDAL, OpenCV, pandas, pillow, scipy",R-tree,Albumentations,0,2,❌,❌,❌,❌ -`SITS`_,R Torch,GDAL,-,tidyverse,22,0,❌,✅,✅,✅ `segment-geospatial`_,PyTorch,"GDAL, OpenCV, pandas",geopandas,numpy,0,0,❌,✅,❌,❌ -`GeoTorchAI`_,PyTorch,"GDAL, pandas, xarray",Sedona,numpy,14,0,❌,❌,❌,🚧 -`Moonshine`_,PyTorch,-,-,numpy,0,3,❌,❌,❌,❌ +`DeepForest`_,PyTorch,"GDAL, OpenCV, pandas, pillow, scipy",R-tree,Albumentations,0,4,❌,❌,❌,❌ +`TerraTorch`_,PyTorch,"GDAL, h5py, pandas, xarray",R-tree,Albumentations,22,1,✅,✅,❌,🚧 +`SITS`_,R Torch,GDAL,-,tidyverse,22,0,❌,✅,✅,✅ +`scikit-eo`_,"scikit-learn, TensorFlow","pandas, scipy, numpy, rasterio","geopandas",numpy,0,0,❌,❌,❌,🚧 diff --git a/docs/user/metrics/github.csv b/docs/user/metrics/github.csv new file mode 100644 index 00000000000..96c678e6c3e --- /dev/null +++ b/docs/user/metrics/github.csv @@ -0,0 +1,10 @@ +Library,Contributors,Forks,Watchers,Stars,Issues,PRs,Releases,Commits,Core SLOCs,Test SLOCs,Test Coverage,License +`TorchGeo`_,76,352,51,"2,790",446,"1,860",13,"2,193","29,305","17,294",100%,MIT +`eo-learn`_,40,299,46,"1,131",160,640,45,"2,472","7,497","5,872",92%,MIT +`Raster Vision`_,32,388,71,"2,090",701,"1,430",23,"3,614","21,734","8,792",90%,Apache-2.0 +`PaddleRS`_,23,91,13,400,93,116,3,644,"20,679","3,239",48%,Apache-2.0 +`segment-geospatial`_,20,316,61,"3,078",150,136,38,229,"6,845",92,22%,MIT +`DeepForest`_,17,176,17,524,439,351,47,938,"3,320","1,886",86%,MIT +`TerraTorch`_,16,24,12,171,78,185,8,606,"14,933","2,077",44%,Apache-2.0 +`SITS`_,14,78,28,483,654,590,44,"6,244","24,284","8,697",94%,GPL-2.0 +`scikit-eo`_,7,20,9,192,24,13,17,510,"1,617","170",37%,Apache-2.0 diff --git a/experiments/ssl4eo/download_ssl4eo.py b/experiments/ssl4eo/download_ssl4eo.py index d93413bdd25..283e6e39c82 100755 --- a/experiments/ssl4eo/download_ssl4eo.py +++ b/experiments/ssl4eo/download_ssl4eo.py @@ -12,7 +12,7 @@ ### match and download pre-sampled locations python download_ssl4eo.py \ --save-path ./data \ - --collection COPERNICUS/S2 \ + --collection COPERNICUS/S2_HARMONIZED \ --meta-cloud-name CLOUDY_PIXEL_PERCENTAGE \ --cloud-pct 20 \ --dates 2021-12-21 2021-09-22 2021-06-21 2021-03-20 \ @@ -125,7 +125,7 @@ def filter_collection( if filtered.size().getInfo() == 0: raise ee.EEException( - f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' # noqa: E501 + f'ImageCollection.filter: No suitable images found in ({coords[1]:.4f}, {coords[0]:.4f}) between {period[0]} and {period[1]}.' ) return filtered @@ -319,7 +319,10 @@ def update(self, delta: int = 1) -> int: ) # collection properties parser.add_argument( - '--collection', type=str, default='COPERNICUS/S2', help='GEE collection name' + '--collection', + type=str, + default='COPERNICUS/S2_HARMONIZED', + help='GEE collection name', ) parser.add_argument('--qa-band', type=str, default='QA60', help='qa band name') parser.add_argument( @@ -517,7 +520,7 @@ def worker(idx: int) -> None: print(f'Downloaded {count} images in {time.time() - start_time:.3f}s.') else: if args.debug: - print('no suitable image for location %d.' % (idx)) + print(f'no suitable image for location {idx}.') # add to existing checked locations with open(ext_path, 'a') as f: diff --git a/experiments/ssl4eo/landsat/conf/l7irish.yaml b/experiments/ssl4eo/landsat/conf/l7irish.yaml index 91b1cbea15d..84c282fb951 100644 --- a/experiments/ssl4eo/landsat/conf/l7irish.yaml +++ b/experiments/ssl4eo/landsat/conf/l7irish.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 9 num_classes: 5 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -20,4 +20,4 @@ data: patch_size: 224 num_workers: 16 dict_kwargs: - paths: "data/l7irish" + paths: 'data/l7irish' diff --git a/experiments/ssl4eo/landsat/conf/l8biome.yaml b/experiments/ssl4eo/landsat/conf/l8biome.yaml index 728073a56fa..41de287cf54 100644 --- a/experiments/ssl4eo/landsat/conf/l8biome.yaml +++ b/experiments/ssl4eo/landsat/conf/l8biome.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 11 num_classes: 5 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -20,4 +20,4 @@ data: patch_size: 224 num_workers: 16 dict_kwargs: - paths: "data/l8biome" + paths: 'data/l8biome' diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_cdl.yaml index 93062f9942f..efcf09b01c2 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_cdl.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_cdl.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 6 num_classes: 18 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "etm_sr" - product: "cdl" + root: 'data/ssl4eo_benchmark' + sensor: 'etm_sr' + product: 'cdl' classes: - 0 - 1 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml index 718b3281ed0..bff225c5ba0 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_sr_nlcd.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 6 num_classes: 14 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "etm_sr" - product: "nlcd" + root: 'data/ssl4eo_benchmark' + sensor: 'etm_sr' + product: 'nlcd' classes: - 0 - 11 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_cdl.yaml index f10a0508ffe..15e1412d8b7 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_cdl.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_cdl.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 9 num_classes: 18 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "etm_toa" - product: "cdl" + root: 'data/ssl4eo_benchmark' + sensor: 'etm_toa' + product: 'cdl' classes: - 0 - 1 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml index 52a6107c096..29ecf722209 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_etm_toa_nlcd.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 9 num_classes: 14 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "etm_toa" - product: "nlcd" + root: 'data/ssl4eo_benchmark' + sensor: 'etm_toa' + product: 'nlcd' classes: - 0 - 11 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_cdl.yaml index 669e1221944..e289a69c46a 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_cdl.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_cdl.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 7 num_classes: 18 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "oli_sr" - product: "cdl" + root: 'data/ssl4eo_benchmark' + sensor: 'oli_sr' + product: 'cdl' classes: - 0 - 1 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml index 81f3283f5b9..ab1ed68dc52 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_sr_nlcd.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 7 num_classes: 14 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "oli_sr" - product: "nlcd" + root: 'data/ssl4eo_benchmark' + sensor: 'oli_sr' + product: 'nlcd' classes: - 0 - 11 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml index 876e25184c7..7066db83df8 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_cdl.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 11 num_classes: 18 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "oli_tirs_toa" - product: "cdl" + root: 'data/ssl4eo_benchmark' + sensor: 'oli_tirs_toa' + product: 'cdl' classes: - 0 - 1 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml index 77f3f6eae45..b403e63c434 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_oli_tirs_toa_nlcd.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 11 num_classes: 14 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "oli_tirs_toa" - product: "nlcd" + root: 'data/ssl4eo_benchmark' + sensor: 'oli_tirs_toa' + product: 'nlcd' classes: - 0 - 11 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_cdl.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_cdl.yaml index 7adf1e46f97..b359d5e390b 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_cdl.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_cdl.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 7 num_classes: 18 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "tm_toa" - product: "cdl" + root: 'data/ssl4eo_benchmark' + sensor: 'tm_toa' + product: 'cdl' classes: - 0 - 1 diff --git a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml index b52fee6c6ca..41a908de96d 100644 --- a/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml +++ b/experiments/ssl4eo/landsat/conf/ssl4eo_benchmark_tm_toa_nlcd.yaml @@ -4,12 +4,12 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' weights: null in_channels: 7 num_classes: 14 - loss: "ce" + loss: 'ce' ignore_index: 0 lr: 1e-3 patience: 6 @@ -19,9 +19,9 @@ data: batch_size: 64 num_workers: 16 dict_kwargs: - root: "data/ssl4eo_benchmark" - sensor: "tm_toa" - product: "nlcd" + root: 'data/ssl4eo_benchmark' + sensor: 'tm_toa' + product: 'nlcd' classes: - 0 - 11 diff --git a/experiments/ssl4eo/landsat/plot_landsat_bands.py b/experiments/ssl4eo/landsat/plot_landsat_bands.py index 2c72e7f887f..9f15a29a6d0 100755 --- a/experiments/ssl4eo/landsat/plot_landsat_bands.py +++ b/experiments/ssl4eo/landsat/plot_landsat_bands.py @@ -46,7 +46,7 @@ df = df.iloc[::-1] fig, ax = plt.subplots(figsize=(5.5, args.fig_height)) -ax1, ax2 = fig.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [3, 1]}) # type: ignore[misc] +ax1, ax2 = fig.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [3, 1]}) sensor_names: list[str] = [] sensor_ylocs: list[float] = [] @@ -125,8 +125,8 @@ # Labels ax.set_xlabel(r'Wavelength (\textmu m)') -ax.set_xticks([0], labels=[0]) -ax.set_yticks([0], labels=[0]) +ax.set_xticks([0], labels=['0']) +ax.set_yticks([0], labels=['0']) ax.tick_params(colors='w') ax.spines[['bottom', 'left', 'top', 'right']].set_visible(False) diff --git a/experiments/ssl4eo/landsat/plot_landsat_timeline.py b/experiments/ssl4eo/landsat/plot_landsat_timeline.py index ef855f01eee..83b868c4035 100755 --- a/experiments/ssl4eo/landsat/plot_landsat_timeline.py +++ b/experiments/ssl4eo/landsat/plot_landsat_timeline.py @@ -88,10 +88,10 @@ } xranges = [(start, end - start) for start, end in working[satellite]] - ax.broken_barh(xranges, hatch=None, **kwargs) + ax.broken_barh(xranges, hatch=None, **kwargs) # type: ignore[arg-type] xranges = [(start, end - start) for start, end in failing[satellite]] - ax.broken_barh(xranges, hatch='////', **kwargs) + ax.broken_barh(xranges, hatch='////', **kwargs) # type: ignore[arg-type] # Label xmin = global_xmax @@ -127,16 +127,16 @@ 'verticalalignment': 'center_baseline', } - ax.text(x, horizontalalignment=horizontalalignment, **kwargs) + ax.text(x, horizontalalignment=horizontalalignment, **kwargs) # type: ignore[arg-type] yticks.append(ymin + args.bar_height / 2) ymin += args.bar_height + args.bar_sep ax.xaxis_date() -ax.set_xlim(global_xmin, global_xmax) +ax.set_xlim(global_xmin, global_xmax) # type: ignore[arg-type] ax.set_ylabel('Landsat Mission') ax.set_yticks(yticks) -ax.set_yticklabels(range(9, 0, -1)) +ax.set_yticklabels(map(str, range(9, 0, -1))) ax.tick_params(axis='both', which='both', top=False, right=False) ax.spines[['top', 'right']].set_visible(False) diff --git a/experiments/ssl4eo/plot_example_predictions.py b/experiments/ssl4eo/plot_example_predictions.py index 596c8ea304d..96fe9103b9c 100755 --- a/experiments/ssl4eo/plot_example_predictions.py +++ b/experiments/ssl4eo/plot_example_predictions.py @@ -63,13 +63,9 @@ data = sample[key] if key == 'image': data = data[[2, 1, 0]].permute(1, 2, 0).numpy().astype('uint8') - Image.fromarray(data, 'RGB').save( # type: ignore[no-untyped-call] - f'{path}/{key}.png' - ) + Image.fromarray(data, 'RGB').save(f'{path}/{key}.png') else: data = data * 255 / 4 data = data.numpy().astype('uint8').squeeze() - Image.fromarray(data, 'L').save( # type: ignore[no-untyped-call] - f'{path}/{key}.png' - ) + Image.fromarray(data, 'L').save(f'{path}/{key}.png') i += 1 diff --git a/experiments/ssl4eo/sample_ssl4eo.py b/experiments/ssl4eo/sample_ssl4eo.py index 68d1056df55..69f82f10283 100755 --- a/experiments/ssl4eo/sample_ssl4eo.py +++ b/experiments/ssl4eo/sample_ssl4eo.py @@ -47,7 +47,7 @@ def get_world_cities( download_root: str = 'world_cities', size: int = 10000 ) -> pd.DataFrame: - url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' # noqa: E501 + url = 'https://simplemaps.com/static/data/world-cities/basic/simplemaps_worldcities_basicv1.71.zip' filename = 'worldcities.csv' download_and_extract_archive(url, download_root) cols = ['city', 'lat', 'lng', 'population'] diff --git a/experiments/torchgeo/conf/chesapeake_cvpr.yaml b/experiments/torchgeo/conf/chesapeake_cvpr.yaml index da2e012ed05..43c07480293 100644 --- a/experiments/torchgeo/conf/chesapeake_cvpr.yaml +++ b/experiments/torchgeo/conf/chesapeake_cvpr.yaml @@ -4,9 +4,9 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' weights: null lr: 1e-3 patience: 6 @@ -18,15 +18,15 @@ data: class_path: ChesapeakeCVPRDataModule init_args: train_splits: - - "de-train" + - 'de-train' val_splits: - - "de-val" + - 'de-val' test_splits: - - "de-test" + - 'de-test' batch_size: 200 patch_size: 256 num_workers: 4 class_set: ${model.init_args.num_classes} use_prior_labels: False dict_kwargs: - root: "data/chesapeake/cvpr" + root: 'data/chesapeake/cvpr' diff --git a/experiments/torchgeo/conf/cowc_counting.yaml b/experiments/torchgeo/conf/cowc_counting.yaml index 481ba40cd97..eaf328062be 100644 --- a/experiments/torchgeo/conf/cowc_counting.yaml +++ b/experiments/torchgeo/conf/cowc_counting.yaml @@ -16,4 +16,4 @@ data: batch_size: 64 num_workers: 4 dict_kwargs: - root: "data/cowc_counting" + root: 'data/cowc_counting' diff --git a/experiments/torchgeo/conf/etci2021.yaml b/experiments/torchgeo/conf/etci2021.yaml index c3f0ae487ca..8d42dc1e633 100644 --- a/experiments/torchgeo/conf/etci2021.yaml +++ b/experiments/torchgeo/conf/etci2021.yaml @@ -4,9 +4,9 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' weights: true lr: 1e-3 patience: 6 @@ -19,4 +19,4 @@ data: batch_size: 32 num_workers: 4 dict_kwargs: - root: "data/etci2021" + root: 'data/etci2021' diff --git a/experiments/torchgeo/conf/eurosat.yaml b/experiments/torchgeo/conf/eurosat.yaml index 6e788273aa6..f9c10e2ff70 100644 --- a/experiments/torchgeo/conf/eurosat.yaml +++ b/experiments/torchgeo/conf/eurosat.yaml @@ -4,8 +4,8 @@ trainer: model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' lr: 1e-3 patience: 6 weights: null @@ -17,4 +17,4 @@ data: batch_size: 128 num_workers: 4 dict_kwargs: - root: "data/eurosat" + root: 'data/eurosat' diff --git a/experiments/torchgeo/conf/landcoverai.yaml b/experiments/torchgeo/conf/landcoverai.yaml index e9ef4df66cf..d026f92c66f 100644 --- a/experiments/torchgeo/conf/landcoverai.yaml +++ b/experiments/torchgeo/conf/landcoverai.yaml @@ -4,9 +4,9 @@ trainer: model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' weights: true lr: 1e-3 patience: 6 @@ -20,4 +20,4 @@ data: batch_size: 32 num_workers: 4 dict_kwargs: - root: "data/landcoverai" + root: 'data/landcoverai' diff --git a/experiments/torchgeo/conf/resisc45.yaml b/experiments/torchgeo/conf/resisc45.yaml index 8a9d34c4ede..fde75ff8844 100644 --- a/experiments/torchgeo/conf/resisc45.yaml +++ b/experiments/torchgeo/conf/resisc45.yaml @@ -4,8 +4,8 @@ trainer: model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' lr: 1e-3 patience: 6 weights: null @@ -17,4 +17,4 @@ data: batch_size: 128 num_workers: 4 dict_kwargs: - root: "data/resisc45" + root: 'data/resisc45' diff --git a/experiments/torchgeo/conf/so2sat.yaml b/experiments/torchgeo/conf/so2sat.yaml index 1b9e7144263..ad2da906431 100644 --- a/experiments/torchgeo/conf/so2sat.yaml +++ b/experiments/torchgeo/conf/so2sat.yaml @@ -4,8 +4,8 @@ trainer: model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' lr: 1e-3 patience: 6 weights: null @@ -16,6 +16,6 @@ data: init_args: batch_size: 128 num_workers: 4 - band_set: "all" + band_set: 'all' dict_kwargs: - root: "data/so2sat" + root: 'data/so2sat' diff --git a/experiments/torchgeo/conf/ucmerced.yaml b/experiments/torchgeo/conf/ucmerced.yaml index 2a4d8786422..abafd7eaff7 100644 --- a/experiments/torchgeo/conf/ucmerced.yaml +++ b/experiments/torchgeo/conf/ucmerced.yaml @@ -4,8 +4,8 @@ trainer: model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' weights: null lr: 1e-3 patience: 6 @@ -17,4 +17,4 @@ data: batch_size: 128 num_workers: 4 dict_kwargs: - root: "data/ucmerced" + root: 'data/ucmerced' diff --git a/hubconf.py b/hubconf.py index f174d76256e..9362e056127 100644 --- a/hubconf.py +++ b/hubconf.py @@ -12,7 +12,10 @@ dofa_large_patch16_224, resnet18, resnet50, + resnet152, + scalemae_large_patch16, swin_v2_b, + swin_v2_t, vit_small_patch16_224, ) @@ -21,8 +24,11 @@ 'dofa_large_patch16_224', 'resnet18', 'resnet50', + 'resnet152', + 'scalemae_large_patch16', 'swin_v2_b', + 'swin_v2_t', 'vit_small_patch16_224', ) -dependencies = ['timm'] +dependencies = ['timm', 'torchvision'] diff --git a/requirements/package-lock.json b/package-lock.json similarity index 59% rename from requirements/package-lock.json rename to package-lock.json index d2d97509fbb..b57536fde23 100644 --- a/requirements/package-lock.json +++ b/package-lock.json @@ -1,11 +1,17 @@ { + "name": "torchgeo", "lockfileVersion": 3, "requires": true, "packages": { + "": { + "dependencies": { + "prettier": ">=3" + } + }, "node_modules/prettier": { - "version": "3.2.5", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.2.5.tgz", - "integrity": "sha512-3/GWa9aOC0YeD7LUfvOG2NiDyhOWRvt1k+rcKhOuYnMY24iiCphgneUfJDyFXd6rZCAnuLBv6UeAULtrhT/F4A==", + "version": "3.4.2", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.4.2.tgz", + "integrity": "sha512-e9MewbtFo+Fevyuxn/4rrcDAaq0IYxPGLvObpQjiZBMAzB9IGmzlnG9RZy3FFas+eBMu2vA0CszMeduow5dIuQ==", "bin": { "prettier": "bin/prettier.cjs" }, diff --git a/requirements/package.json b/package.json similarity index 51% rename from requirements/package.json rename to package.json index ebfe31dde86..d6e569c647e 100644 --- a/requirements/package.json +++ b/package.json @@ -2,6 +2,9 @@ "name": "torchgeo", "private": "true", "dependencies": { - "prettier": ">=3.0.0" + "prettier": ">=3" + }, + "prettier": { + "singleQuote": true } } diff --git a/pyproject.toml b/pyproject.toml index 9707f472d8e..a4e125f5a31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,14 +40,16 @@ dependencies = [ "einops>=0.3", # fiona 1.8.21+ required for Python 3.10 wheels "fiona>=1.8.21", - # kornia 0.7.2+ required for dict support in AugmentationSequential - "kornia>=0.7.2", - # lightly 1.4.4+ required for MoCo v3 support + # kornia 0.7.4+ required for AugmentationSequential support for unknown keys + "kornia>=0.7.4", + # lightly 1.4.5+ required for LARS optimizer # lightly 1.4.26 is incompatible with the version of timm required by smp # https://github.com/microsoft/torchgeo/issues/1824 - "lightly>=1.4.4,!=1.4.26", + "lightly>=1.4.5,!=1.4.26", # lightning 2+ required for LightningCLI args + sys.argv support - "lightning[pytorch-extra]>=2", + # lightning 2.3 contains known bugs related to YAML parsing + # https://github.com/Lightning-AI/pytorch-lightning/issues/19977 + "lightning[pytorch-extra]>=2,!=2.3.*,!=2.5.0", # matplotlib 3.5+ required for Python 3.10 wheels "matplotlib>=3.5", # numpy 1.21.2+ required by Python 3.10 wheels @@ -59,7 +61,9 @@ dependencies = [ # pyproj 3.3+ required for Python 3.10 wheels "pyproj>=3.3", # rasterio 1.3+ required for Python 3.10 wheels - "rasterio>=1.3", + # rasterio 1.4.0-1.4.2 lack support for merging WarpedVRT objects + # https://github.com/rasterio/rasterio/issues/3196 + "rasterio>=1.3,!=1.4.0,!=1.4.1,!=1.4.2", # rtree 1+ required for Python 3.10 wheels "rtree>=1", # segmentation-models-pytorch 0.2+ required for smp.losses module @@ -85,20 +89,16 @@ datasets = [ "laspy>=2", # opencv-python 4.5.4+ required for Python 3.10 wheels "opencv-python>=4.5.4", + # pandas 2+ required for parquet extra + "pandas[parquet]>=2", # pycocotools 2.0.7+ required for wheels "pycocotools>=2.0.7", # pyvista 0.34.2+ required to avoid ImportError in CI "pyvista>=0.34.2", - # radiant-mlhub 0.3+ required for newer tqdm support required by lightning - "radiant-mlhub>=0.3", - # rarfile 4+ required for wheels - "rarfile>=4", # scikit-image 0.19+ required for Python 3.10 wheels "scikit-image>=0.19", # scipy 1.7.2+ required for Python 3.10 wheels "scipy>=1.7.2", - # zipfile-deflate64 0.2+ required for Python 3.10 wheels - "zipfile-deflate64>=0.2", ] docs = [ # ipywidgets 7+ required by nbsphinx @@ -107,7 +107,7 @@ docs = [ "nbsphinx>=0.8.5", # release versions missing files, must install from master "pytorch-sphinx-theme", - # sphinx 4+ required for autodoc_typehints_description_target = documented + # sphinx 4+ required for autodoc_typehints_description_target # sphinx 6+ is incompatible with pytorch-sphinx-theme # https://github.com/pytorch/pytorch_sphinx_theme/issues/175 "sphinx>=4,<6", @@ -115,8 +115,8 @@ docs = [ style = [ # mypy 0.900+ required for pyproject.toml support "mypy>=0.900", - # ruff 0.2+ required for [ruff.lint] - "ruff>=0.2", + # ruff 0.8+ required for removal of ANN101, ANN102 + "ruff>=0.8", ] tests = [ # nbmake 1.3.3+ required for variable mocking @@ -137,35 +137,54 @@ torchgeo = "torchgeo.main:main" Homepage = "https://github.com/microsoft/torchgeo" Documentation = "https://torchgeo.readthedocs.io" +# https://coverage.readthedocs.io/en/latest/config.html [tool.coverage.report] # Ignore warnings for overloads # https://github.com/nedbat/coveragepy/issues/970#issuecomment-612602180 -exclude_lines = [ - "pragma: no cover", +exclude_also = [ "@overload", ] +show_missing = true +[tool.coverage.run] +source_pkgs = ["torchgeo"] + +# https://mypy.readthedocs.io/en/stable/config_file.html [tool.mypy] -python_version = "3.10" +# Import discovery ignore_missing_imports = true -show_error_codes = true -exclude = "(build|data|dist|docs/src|images|logo|logs|output)/" +exclude = "(build|data|dist|docs/.*|images|logo|.*logs|output|requirements)/" -# Strict -warn_unused_configs = true +# Disallow dynamic typing (TODO: work in progress) +disallow_any_unimported = false +disallow_any_expr = false +disallow_any_decorated = false +disallow_any_explicit = false disallow_any_generics = true disallow_subclassing_any = true + +# Untyped definitions and calls disallow_untyped_calls = true disallow_untyped_defs = true disallow_incomplete_defs = true -check_untyped_defs = true disallow_untyped_decorators = true -no_implicit_optional = true + +# Configuring warnings warn_redundant_casts = true warn_unused_ignores = true +warn_no_return = true warn_return_any = true -no_implicit_reexport = true +warn_unreachable = true + +# Miscellaneous strictness flags strict_equality = true +strict = true + +# Configuring error messages +pretty = true + +# Miscellaneous +warn_unused_configs = true [tool.pytest.ini_options] # Skip slow tests by default @@ -178,7 +197,7 @@ filterwarnings = [ # https://github.com/pytorch/vision/pull/5898 "ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:torchvision.transforms.functional_pil", "ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:torchvision.transforms._functional_pil", - # https://github.com/rwightman/pytorch-image-models/pull/1256 + # https://github.com/huggingface/pytorch-image-models/pull/1256 "ignore:.* is deprecated and will be removed in Pillow 10:DeprecationWarning:timm.data", # https://github.com/pytorch/pytorch/issues/72906 # https://github.com/pytorch/pytorch/pull/69823 @@ -215,12 +234,24 @@ filterwarnings = [ "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", "ignore:pkg_resources is deprecated as an API.:DeprecationWarning:lightning_utilities.core.imports", "ignore:Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated:DeprecationWarning:jsonargparse", + "ignore:`ModuleAvailableCache` is a special case of `RequirementCache`.:DeprecationWarning:lightning.fabric.plugins.environments.xla", # https://github.com/pytorch/pytorch/issues/110549 "ignore:allow_ops_in_compiled_graph failed to import torch:ImportWarning:einops", # https://github.com/rr-/docstring_parser/pull/82 "ignore:ast.* is deprecated and will be removed in Python 3.14:DeprecationWarning:docstring_parser.attrdoc", # https://github.com/python/cpython/pull/102953 - "ignore:Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata:DeprecationWarning:tarfile", + "ignore:Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata:DeprecationWarning:torchgeo.datasets.utils", + "ignore:Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata:DeprecationWarning:torchgeo.datasets.digital_typhoon", + "ignore:Python 3.14 will, by default, filter extracted tar archives and reject files or modify their metadata:DeprecationWarning:torchvision.datasets.utils", + # https://github.com/kornia/kornia/pull/2967 + "ignore:`torch.cuda.amp.custom_fwd\\(args...\\)` is deprecated.:FutureWarning:kornia.feature.lightglue", + # https://github.com/kornia/kornia/pull/2981 + "ignore:torch.is_autocast_cpu_enabled\\(\\) is deprecated.:DeprecationWarning:kornia.utils.helpers", + # https://github.com/pytorch/pytorch/pull/129239 + "ignore:You are using `torch.load` with `weights_only=False`:FutureWarning", + # https://github.com/pytorch/pytorch/issues/136264 + "ignore:__array__ implementation doesn't accept a copy keyword:DeprecationWarning", + "ignore:__array_wrap__ must accept context and return_scalar arguments:DeprecationWarning", # Expected warnings # Lightning warns us about using num_workers=0, but it's faster on macOS @@ -266,10 +297,11 @@ quote-style = "single" skip-magic-trailing-comma = true [tool.ruff.lint] -extend-select = ["D", "I", "UP"] +extend-select = ["ANN", "D", "I", "RUF", "UP"] +ignore = ["ANN401"] [tool.ruff.lint.per-file-ignores] -"docs/**" = ["D"] +"docs/**" = ["ANN", "D"] "experiments/**" = ["D"] "tests/**" = ["D"] diff --git a/requirements/datasets.txt b/requirements/datasets.txt index 1c95448b161..11a82e9276c 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -1,11 +1,9 @@ # datasets -h5py==3.11.0 -laspy==2.5.3 -opencv-python==4.9.0.80 -pycocotools==2.0.7 -pyvista==0.43.8 -radiant-mlhub==0.4.1 -rarfile==4.2 -scikit-image==0.23.2 -scipy==1.13.1 -zipfile-deflate64==0.2.0 +h5py==3.12.1 +laspy==2.5.4 +opencv-python==4.10.0.84 +pandas[parquet]==2.2.3 +pycocotools==2.0.8 +pyvista==0.44.2 +scikit-image==0.25.0 +scipy==1.14.1 diff --git a/requirements/docs.txt b/requirements/docs.txt index 181d7c91e9e..2903d4b4c97 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,4 +1,4 @@ # docs -ipywidgets==8.1.3 -nbsphinx==0.9.4 +ipywidgets==8.1.5 +nbsphinx==0.9.6 sphinx==5.3.0 diff --git a/requirements/min-cons.old b/requirements/min-cons.old deleted file mode 100644 index 42347517237..00000000000 --- a/requirements/min-cons.old +++ /dev/null @@ -1,8 +0,0 @@ -# https://github.com/Lightning-AI/utilities/pull/147 -lightning-utilities<0.10 -# https://github.com/pypa/pip/issues/11760 -nbconvert<6 -# https://github.com/fatiando/pooch/issues/252 -pooch<1.5 -# https://github.com/jquast/wcwidth/issues/36 -wcwidth!=0.2.1 diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index d3e72bb5b83..a4a57e10968 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -4,8 +4,8 @@ setuptools==61.0.0 # install einops==0.3.0 fiona==1.8.21 -kornia==0.7.2 -lightly==1.4.4 +kornia==0.7.4 +lightly==1.4.5 lightning[pytorch-extra]==2.0.0 matplotlib==3.5.0 numpy==1.21.2 @@ -26,12 +26,11 @@ h5py==3.6.0 laspy==2.0.0 opencv-python==4.5.4.58 pycocotools==2.0.7 +pyarrow==15.0.0 # Remove when we upgrade min version of pandas to `pandas[parquet]>=2` pyvista==0.34.2 -radiant-mlhub==0.3.0 -rarfile==4.0 scikit-image==0.19.0 scipy==1.7.2 -zipfile-deflate64==0.2.0 +vtk==9.3.1 # PyVista is not yet compatible with VTK 9.4+ # tests pytest==7.3.0 diff --git a/requirements/required.txt b/requirements/required.txt index bc8fb922367..bbe4e855a74 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -1,22 +1,22 @@ # setup -setuptools==70.0.0 +setuptools==75.6.0 # install einops==0.8.0 -fiona==1.9.6 -kornia==0.7.2 -lightly==1.5.4 -lightning[pytorch-extra]==2.2.5 -matplotlib==3.9.0 -numpy==1.26.4 -pandas==2.2.2 -pillow==10.3.0 -pyproj==3.6.1 -rasterio==1.3.10 -rtree==1.2.0 -segmentation-models-pytorch==0.3.3 -shapely==2.0.4 -timm==0.9.2 -torch==2.3.0 -torchmetrics==1.4.0.post0 -torchvision==0.18.0 +fiona==1.10.1 +kornia==0.7.4 +lightly==1.5.15 +lightning[pytorch-extra]==2.5.0.post0 +matplotlib==3.10.0 +numpy==2.2.1 +pandas==2.2.3 +pillow==11.1.0 +pyproj==3.7.0 +rasterio==1.4.3 +rtree==1.3.0 +segmentation-models-pytorch==0.3.4 +shapely==2.0.6 +timm==0.9.7 +torch==2.5.1 +torchmetrics==1.6.1 +torchvision==0.20.1 diff --git a/requirements/style.txt b/requirements/style.txt index 82bf7c8d526..cb4ef80189d 100644 --- a/requirements/style.txt +++ b/requirements/style.txt @@ -1,3 +1,3 @@ # style -mypy==1.10.0 -ruff==0.4.6 +mypy==1.14.1 +ruff==0.8.5 diff --git a/requirements/tests.txt b/requirements/tests.txt index c4770eef925..f79028e4de2 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,4 @@ # tests -nbmake==1.5.3 -pytest==8.2.1 -pytest-cov==5.0.0 +nbmake==1.5.5 +pytest==8.3.4 +pytest-cov==6.0.0 diff --git a/tests/conf/agrifieldnet.yaml b/tests/conf/agrifieldnet.yaml index 42f2550910a..70c0498f252 100644 --- a/tests/conf/agrifieldnet.yaml +++ b/tests/conf/agrifieldnet.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 12 num_classes: 14 num_filters: 1 @@ -14,4 +14,4 @@ data: batch_size: 2 patch_size: 16 dict_kwargs: - paths: "tests/data/agrifieldnet" + paths: 'tests/data/agrifieldnet' diff --git a/tests/conf/bigearthnet_all.yaml b/tests/conf/bigearthnet_all.yaml index 2eba9c471d2..d24a2af4442 100644 --- a/tests/conf/bigearthnet_all.yaml +++ b/tests/conf/bigearthnet_all.yaml @@ -1,8 +1,8 @@ model: class_path: MultiLabelClassificationTask init_args: - loss: "bce" - model: "resnet18" + loss: 'bce' + model: 'resnet18' in_channels: 14 num_classes: 19 data: @@ -10,7 +10,6 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/bigearthnet" - bands: "all" + root: 'tests/data/bigearthnet' + bands: 'all' num_classes: 19 - download: true diff --git a/tests/conf/bigearthnet_s1.yaml b/tests/conf/bigearthnet_s1.yaml index b93d54ce1eb..78f01ebb81a 100644 --- a/tests/conf/bigearthnet_s1.yaml +++ b/tests/conf/bigearthnet_s1.yaml @@ -1,8 +1,8 @@ model: class_path: MultiLabelClassificationTask init_args: - loss: "bce" - model: "resnet18" + loss: 'bce' + model: 'resnet18' in_channels: 2 num_classes: 19 data: @@ -10,7 +10,6 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/bigearthnet" - bands: "s1" + root: 'tests/data/bigearthnet' + bands: 's1' num_classes: 19 - download: true diff --git a/tests/conf/bigearthnet_s2.yaml b/tests/conf/bigearthnet_s2.yaml index d00085a4879..e1afa68c126 100644 --- a/tests/conf/bigearthnet_s2.yaml +++ b/tests/conf/bigearthnet_s2.yaml @@ -1,8 +1,8 @@ model: class_path: MultiLabelClassificationTask init_args: - loss: "bce" - model: "resnet18" + loss: 'bce' + model: 'resnet18' in_channels: 12 num_classes: 19 data: @@ -10,7 +10,6 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/bigearthnet" - bands: "s2" + root: 'tests/data/bigearthnet' + bands: 's2' num_classes: 19 - download: true diff --git a/tests/conf/cabuar.yaml b/tests/conf/cabuar.yaml new file mode 100644 index 00000000000..0a9ba151269 --- /dev/null +++ b/tests/conf/cabuar.yaml @@ -0,0 +1,16 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: 'ce' + model: 'unet' + backbone: 'resnet18' + in_channels: 24 + num_classes: 2 + num_filters: 1 + ignore_index: null +data: + class_path: CaBuArDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/cabuar' diff --git a/tests/conf/chabud.yaml b/tests/conf/chabud.yaml index 1a708673326..3cff070bf59 100644 --- a/tests/conf/chabud.yaml +++ b/tests/conf/chabud.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 24 num_classes: 2 num_filters: 1 @@ -13,5 +13,4 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/chabud" - download: true + root: 'tests/data/chabud' diff --git a/tests/conf/chesapeake_cvpr_5.yaml b/tests/conf/chesapeake_cvpr_5.yaml index 2b499a7fa4c..72494f73a0a 100644 --- a/tests/conf/chesapeake_cvpr_5.yaml +++ b/tests/conf/chesapeake_cvpr_5.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "jaccard" - model: "unet" - backbone: "resnet50" + loss: 'jaccard' + model: 'unet' + backbone: 'resnet50' in_channels: 4 num_classes: 5 num_filters: 1 @@ -12,15 +12,14 @@ data: class_path: ChesapeakeCVPRDataModule init_args: train_splits: - - "de-test" + - 'de-test' val_splits: - - "de-test" + - 'de-test' test_splits: - - "de-test" + - 'de-test' batch_size: 2 patch_size: 64 class_set: 5 use_prior_labels: False dict_kwargs: - root: "tests/data/chesapeake/cvpr" - download: true + root: 'tests/data/chesapeake/cvpr' diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml index a5e2e5bb506..bc8fd4504ac 100644 --- a/tests/conf/chesapeake_cvpr_7.yaml +++ b/tests/conf/chesapeake_cvpr_7.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 4 num_classes: 7 num_filters: 1 @@ -12,15 +12,14 @@ data: class_path: ChesapeakeCVPRDataModule init_args: train_splits: - - "de-test" + - 'de-test' val_splits: - - "de-test" + - 'de-test' test_splits: - - "de-test" + - 'de-test' batch_size: 2 patch_size: 64 class_set: 7 use_prior_labels: False dict_kwargs: - root: "tests/data/chesapeake/cvpr" - download: true + root: 'tests/data/chesapeake/cvpr' diff --git a/tests/conf/chesapeake_cvpr_prior_byol.yaml b/tests/conf/chesapeake_cvpr_prior_byol.yaml index 5018f0e4c33..6198c612324 100644 --- a/tests/conf/chesapeake_cvpr_prior_byol.yaml +++ b/tests/conf/chesapeake_cvpr_prior_byol.yaml @@ -2,20 +2,19 @@ model: class_path: BYOLTask init_args: in_channels: 4 - model: "resnet18" + model: 'resnet18' data: class_path: ChesapeakeCVPRDataModule init_args: train_splits: - - "de-test" + - 'de-test' val_splits: - - "de-test" + - 'de-test' test_splits: - - "de-test" + - 'de-test' batch_size: 2 patch_size: 64 class_set: 5 use_prior_labels: True dict_kwargs: - root: "tests/data/chesapeake/cvpr" - download: true + root: 'tests/data/chesapeake/cvpr' diff --git a/tests/conf/chesapeake_cvpr_prior_moco.yaml b/tests/conf/chesapeake_cvpr_prior_moco.yaml index 918288d90d6..29d1c5c3862 100644 --- a/tests/conf/chesapeake_cvpr_prior_moco.yaml +++ b/tests/conf/chesapeake_cvpr_prior_moco.yaml @@ -1,21 +1,20 @@ model: class_path: MoCoTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 4 data: class_path: ChesapeakeCVPRDataModule init_args: train_splits: - - "de-test" + - 'de-test' val_splits: - - "de-test" + - 'de-test' test_splits: - - "de-test" + - 'de-test' batch_size: 2 patch_size: 64 class_set: 5 use_prior_labels: True dict_kwargs: - root: "tests/data/chesapeake/cvpr" - download: false + root: 'tests/data/chesapeake/cvpr' diff --git a/tests/conf/chesapeake_cvpr_prior_simclr.yaml b/tests/conf/chesapeake_cvpr_prior_simclr.yaml index 9f21527cfae..8868c7c9dac 100644 --- a/tests/conf/chesapeake_cvpr_prior_simclr.yaml +++ b/tests/conf/chesapeake_cvpr_prior_simclr.yaml @@ -1,7 +1,7 @@ model: class_path: SimCLRTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 4 version: 1 layers: 2 @@ -10,15 +10,14 @@ data: class_path: ChesapeakeCVPRDataModule init_args: train_splits: - - "de-test" + - 'de-test' val_splits: - - "de-test" + - 'de-test' test_splits: - - "de-test" + - 'de-test' batch_size: 2 patch_size: 64 class_set: 5 use_prior_labels: True dict_kwargs: - root: "tests/data/chesapeake/cvpr" - download: false + root: 'tests/data/chesapeake/cvpr' diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index b247b20cdd9..21f4430331c 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -4,11 +4,10 @@ model: model: resnet18 num_outputs: 1 in_channels: 3 - loss: "mse" + loss: 'mse' data: class_path: COWCCountingDataModule init_args: batch_size: 1 dict_kwargs: - root: "tests/data/cowc_counting" - download: true + root: 'tests/data/cowc_counting' diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index a0c435e9549..2b81705bca6 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -1,14 +1,13 @@ model: class_path: RegressionTask init_args: - model: "resnet18" + model: 'resnet18' num_outputs: 1 in_channels: 3 - loss: "mse" + loss: 'mse' data: class_path: TropicalCycloneDataModule init_args: batch_size: 1 dict_kwargs: - root: "tests/data/cyclone" - download: true + root: 'tests/data/cyclone' diff --git a/tests/conf/deepglobelandcover.yaml b/tests/conf/deepglobelandcover.yaml index 08a29843fdc..e2e8c642477 100644 --- a/tests/conf/deepglobelandcover.yaml +++ b/tests/conf/deepglobelandcover.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 3 num_classes: 7 num_filters: 1 @@ -15,4 +15,4 @@ data: patch_size: 2 val_split_pct: 0.5 dict_kwargs: - root: "tests/data/deepglobelandcover" + root: 'tests/data/deepglobelandcover' diff --git a/tests/conf/digital_typhoon_id.yaml b/tests/conf/digital_typhoon_id.yaml new file mode 100644 index 00000000000..ea9e206ba75 --- /dev/null +++ b/tests/conf/digital_typhoon_id.yaml @@ -0,0 +1,18 @@ +model: + class_path: RegressionTask + init_args: + model: 'resnet18' + num_outputs: 1 + in_channels: 3 + loss: 'mse' +data: + class_path: DigitalTyphoonDataModule + init_args: + batch_size: 1 + split_by: 'typhoon_id' + dict_kwargs: + root: 'tests/data/digital_typhoon' + download: true + min_feature_value: + wind: 10 + sequence_length: 3 diff --git a/tests/conf/digital_typhoon_time.yaml b/tests/conf/digital_typhoon_time.yaml new file mode 100644 index 00000000000..b281be7e3c7 --- /dev/null +++ b/tests/conf/digital_typhoon_time.yaml @@ -0,0 +1,18 @@ +model: + class_path: RegressionTask + init_args: + model: 'resnet18' + num_outputs: 1 + in_channels: 3 + loss: 'mse' +data: + class_path: DigitalTyphoonDataModule + init_args: + batch_size: 1 + split_by: 'time' + dict_kwargs: + root: 'tests/data/digital_typhoon' + download: true + min_feature_value: + wind: 10 + sequence_length: 3 diff --git a/tests/conf/etci2021.yaml b/tests/conf/etci2021.yaml index bdd08948433..1606a870c52 100644 --- a/tests/conf/etci2021.yaml +++ b/tests/conf/etci2021.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 6 num_classes: 2 ignore_index: 0 @@ -12,5 +12,4 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/etci2021" - download: true + root: 'tests/data/etci2021' diff --git a/tests/conf/eurosat.yaml b/tests/conf/eurosat.yaml index 365b46aa776..a715bb0f7fe 100644 --- a/tests/conf/eurosat.yaml +++ b/tests/conf/eurosat.yaml @@ -1,8 +1,8 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' in_channels: 13 num_classes: 2 data: @@ -10,5 +10,4 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/eurosat" - download: true + root: 'tests/data/eurosat' diff --git a/tests/conf/eurosat100.yaml b/tests/conf/eurosat100.yaml index 0981e380548..870658c61a3 100644 --- a/tests/conf/eurosat100.yaml +++ b/tests/conf/eurosat100.yaml @@ -1,8 +1,8 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' lr: 1e-3 patience: 6 weights: null @@ -13,5 +13,4 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/eurosat" - download: true + root: 'tests/data/eurosat' diff --git a/tests/conf/eurosatspatial.yaml b/tests/conf/eurosatspatial.yaml index 0bcfa7126d4..a9f4c8d7f4b 100644 --- a/tests/conf/eurosatspatial.yaml +++ b/tests/conf/eurosatspatial.yaml @@ -1,8 +1,8 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' lr: 1e-3 patience: 6 weights: null @@ -13,5 +13,4 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/eurosat" - download: true + root: 'tests/data/eurosat' diff --git a/tests/conf/fire_risk.yaml b/tests/conf/fire_risk.yaml index b4ff3467c04..eeaddf11529 100644 --- a/tests/conf/fire_risk.yaml +++ b/tests/conf/fire_risk.yaml @@ -1,8 +1,8 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' in_channels: 3 num_classes: 5 data: @@ -10,5 +10,4 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/fire_risk" - download: false + root: 'tests/data/fire_risk' diff --git a/tests/conf/ftw.yaml b/tests/conf/ftw.yaml new file mode 100644 index 00000000000..7335d5ab1bf --- /dev/null +++ b/tests/conf/ftw.yaml @@ -0,0 +1,16 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: 'ce' + model: 'unet' + backbone: 'resnet18' + in_channels: 8 + num_classes: 2 + num_filters: 1 + ignore_index: null +data: + class_path: FieldsOfTheWorldDataModule + init_args: + batch_size: 1 + dict_kwargs: + root: 'tests/data/ftw' diff --git a/tests/conf/geonrw.yaml b/tests/conf/geonrw.yaml new file mode 100644 index 00000000000..9b06b7833bf --- /dev/null +++ b/tests/conf/geonrw.yaml @@ -0,0 +1,16 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: 'ce' + model: 'unet' + backbone: 'resnet18' + in_channels: 3 + num_classes: 11 + num_filters: 1 + ignore_index: null +data: + class_path: GeoNRWDataModule + init_args: + batch_size: 1 + dict_kwargs: + root: 'tests/data/geonrw' diff --git a/tests/conf/gid15.yaml b/tests/conf/gid15.yaml index 057a56696b2..f6ae0b3b231 100644 --- a/tests/conf/gid15.yaml +++ b/tests/conf/gid15.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 3 num_classes: 16 num_filters: 1 @@ -15,5 +15,4 @@ data: patch_size: 2 val_split_pct: 0.5 dict_kwargs: - root: "tests/data/gid15" - download: true + root: 'tests/data/gid15' diff --git a/tests/conf/hyspecnet_byol.yaml b/tests/conf/hyspecnet_byol.yaml new file mode 100644 index 00000000000..5c0fa31d609 --- /dev/null +++ b/tests/conf/hyspecnet_byol.yaml @@ -0,0 +1,11 @@ +model: + class_path: BYOLTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/conf/hyspecnet_moco.yaml b/tests/conf/hyspecnet_moco.yaml new file mode 100644 index 00000000000..732b83912c1 --- /dev/null +++ b/tests/conf/hyspecnet_moco.yaml @@ -0,0 +1,11 @@ +model: + class_path: MoCoTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/conf/hyspecnet_simclr.yaml b/tests/conf/hyspecnet_simclr.yaml new file mode 100644 index 00000000000..d16e8209326 --- /dev/null +++ b/tests/conf/hyspecnet_simclr.yaml @@ -0,0 +1,11 @@ +model: + class_path: SimCLRTask + init_args: + model: 'resnet18' + in_channels: 202 +data: + class_path: HySpecNet11kDataModule + init_args: + batch_size: 2 + dict_kwargs: + root: 'tests/data/hyspecnet' diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml index 4fbd3ded072..01e97d3b503 100644 --- a/tests/conf/inria.yaml +++ b/tests/conf/inria.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 3 num_classes: 2 ignore_index: null @@ -13,4 +13,4 @@ data: batch_size: 1 patch_size: 2 dict_kwargs: - root: "tests/data/inria" + root: 'tests/data/inria' diff --git a/tests/conf/inria_deeplab.yaml b/tests/conf/inria_deeplab.yaml index e16ba15abe3..07b4f05c81b 100644 --- a/tests/conf/inria_deeplab.yaml +++ b/tests/conf/inria_deeplab.yaml @@ -1,14 +1,14 @@ model: class_path: PixelwiseRegressionTask init_args: - model: "deeplabv3+" - backbone: "resnet18" + model: 'deeplabv3+' + backbone: 'resnet18' in_channels: 3 - loss: "mae" + loss: 'mae' data: class_path: InriaAerialImageLabelingDataModule init_args: batch_size: 1 patch_size: 2 dict_kwargs: - root: "tests/data/inria" + root: 'tests/data/inria' diff --git a/tests/conf/inria_fcn.yaml b/tests/conf/inria_fcn.yaml index 692db059dbf..e9b78c3e0c9 100644 --- a/tests/conf/inria_fcn.yaml +++ b/tests/conf/inria_fcn.yaml @@ -1,14 +1,14 @@ model: class_path: PixelwiseRegressionTask init_args: - model: "fcn" - backbone: "resnet18" + model: 'fcn' + backbone: 'resnet18' in_channels: 3 - loss: "mae" + loss: 'mae' data: class_path: InriaAerialImageLabelingDataModule init_args: batch_size: 1 patch_size: 2 dict_kwargs: - root: "tests/data/inria" + root: 'tests/data/inria' diff --git a/tests/conf/inria_unet.yaml b/tests/conf/inria_unet.yaml index ded50ffe79c..5ebf54bead9 100644 --- a/tests/conf/inria_unet.yaml +++ b/tests/conf/inria_unet.yaml @@ -1,14 +1,14 @@ model: class_path: PixelwiseRegressionTask init_args: - model: "unet" - backbone: "resnet18" + model: 'unet' + backbone: 'resnet18' in_channels: 3 - loss: "mae" + loss: 'mae' data: class_path: InriaAerialImageLabelingDataModule init_args: batch_size: 1 patch_size: 2 dict_kwargs: - root: "tests/data/inria" + root: 'tests/data/inria' diff --git a/tests/conf/io_preprocessed.yaml b/tests/conf/io_preprocessed.yaml index f98f5e22ec0..33e6a62da86 100644 --- a/tests/conf/io_preprocessed.yaml +++ b/tests/conf/io_preprocessed.yaml @@ -3,11 +3,10 @@ model: data: class_path: IOBenchDataModule dict_kwargs: - root: "data/io" - split: "preprocessed" - download: true + root: 'data/io' + split: 'preprocessed' checksum: true trainer: max_epochs: 1 num_sanity_val_steps: 0 - profiler: "simple" + profiler: 'simple' diff --git a/tests/conf/io_raw.yaml b/tests/conf/io_raw.yaml index d8bd8cf1b6a..de0e6f211ab 100644 --- a/tests/conf/io_raw.yaml +++ b/tests/conf/io_raw.yaml @@ -3,11 +3,10 @@ model: data: class_path: IOBenchDataModule dict_kwargs: - root: "data/io" - split: "raw" - download: true + root: 'data/io' + split: 'raw' checksum: true trainer: max_epochs: 1 num_sanity_val_steps: 0 - profiler: "simple" + profiler: 'simple' diff --git a/tests/conf/iobench.yaml b/tests/conf/iobench.yaml index 888212f67a5..cffb007775a 100644 --- a/tests/conf/iobench.yaml +++ b/tests/conf/iobench.yaml @@ -6,4 +6,4 @@ data: batch_size: 2 patch_size: 16 dict_kwargs: - root: "tests/data/iobench" + root: 'tests/data/iobench' diff --git a/tests/conf/l7irish.yaml b/tests/conf/l7irish.yaml index fc67fb8e1cc..18219d68ae8 100644 --- a/tests/conf/l7irish.yaml +++ b/tests/conf/l7irish.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 9 num_classes: 5 num_filters: 1 @@ -15,5 +15,4 @@ data: patch_size: 32 length: 5 dict_kwargs: - paths: "tests/data/l7irish" - download: true + paths: 'tests/data/l7irish' diff --git a/tests/conf/l8biome.yaml b/tests/conf/l8biome.yaml index f33b4b36464..85f5b09948a 100644 --- a/tests/conf/l8biome.yaml +++ b/tests/conf/l8biome.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 11 num_classes: 5 num_filters: 1 @@ -15,5 +15,4 @@ data: patch_size: 32 length: 5 dict_kwargs: - paths: "tests/data/l8biome" - download: true + paths: 'tests/data/l8biome' diff --git a/tests/conf/landcoverai.yaml b/tests/conf/landcoverai.yaml index afdf3631ebb..c27d936844d 100644 --- a/tests/conf/landcoverai.yaml +++ b/tests/conf/landcoverai.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 3 num_classes: 5 num_filters: 1 @@ -13,5 +13,4 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/landcoverai" - download: true + root: 'tests/data/landcoverai' diff --git a/tests/conf/landcoverai100.yaml b/tests/conf/landcoverai100.yaml new file mode 100644 index 00000000000..f6461851fa3 --- /dev/null +++ b/tests/conf/landcoverai100.yaml @@ -0,0 +1,16 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: 'ce' + model: 'unet' + backbone: 'resnet18' + in_channels: 3 + num_classes: 5 + num_filters: 1 + ignore_index: null +data: + class_path: LandCoverAI100DataModule + init_args: + batch_size: 1 + dict_kwargs: + root: 'tests/data/landcoverai' diff --git a/tests/conf/loveda.yaml b/tests/conf/loveda.yaml index 44745a6d929..3107db88c81 100644 --- a/tests/conf/loveda.yaml +++ b/tests/conf/loveda.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 3 num_classes: 8 num_filters: 1 @@ -13,5 +13,4 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/loveda" - download: true + root: 'tests/data/loveda' diff --git a/tests/conf/naipchesapeake.yaml b/tests/conf/naipchesapeake.yaml index 4b13865f1bd..09eb2fca53c 100644 --- a/tests/conf/naipchesapeake.yaml +++ b/tests/conf/naipchesapeake.yaml @@ -1,11 +1,11 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "deeplabv3+" - backbone: "resnet34" + loss: 'ce' + model: 'deeplabv3+' + backbone: 'resnet18' in_channels: 4 - num_classes: 14 + num_classes: 128 num_filters: 1 ignore_index: null data: @@ -14,6 +14,5 @@ data: batch_size: 2 patch_size: 32 dict_kwargs: - naip_paths: "tests/data/naip" - chesapeake_paths: "tests/data/chesapeake/BAYWIDE" - chesapeake_download: true + naip_paths: 'tests/data/naip' + chesapeake_paths: 'tests/data/chesapeake/lulc' diff --git a/tests/conf/nasa_marine_debris.yaml b/tests/conf/nasa_marine_debris.yaml index a0f30127414..92546a592ec 100644 --- a/tests/conf/nasa_marine_debris.yaml +++ b/tests/conf/nasa_marine_debris.yaml @@ -1,13 +1,12 @@ model: class_path: ObjectDetectionTask init_args: - model: "faster-rcnn" - backbone: "resnet18" + model: 'faster-rcnn' + backbone: 'resnet18' num_classes: 2 data: class_path: NASAMarineDebrisDataModule init_args: batch_size: 1 dict_kwargs: - root: "tests/data/nasa_marine_debris" - download: true + root: 'tests/data/nasa_marine_debris' diff --git a/tests/conf/potsdam2d.yaml b/tests/conf/potsdam2d.yaml index 362ec81815d..31ea7441950 100644 --- a/tests/conf/potsdam2d.yaml +++ b/tests/conf/potsdam2d.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 4 num_classes: 6 num_filters: 1 @@ -15,4 +15,4 @@ data: patch_size: 2 val_split_pct: 0.5 dict_kwargs: - root: "tests/data/potsdam" + root: 'tests/data/potsdam' diff --git a/tests/conf/quakeset.yaml b/tests/conf/quakeset.yaml index 9d54e1b6d4f..0f1df7baa4b 100644 --- a/tests/conf/quakeset.yaml +++ b/tests/conf/quakeset.yaml @@ -1,8 +1,8 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' in_channels: 4 num_classes: 2 data: @@ -10,5 +10,4 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/quakeset" - download: false + root: 'tests/data/quakeset' diff --git a/tests/conf/resisc45.yaml b/tests/conf/resisc45.yaml index 86deb432f65..567987aea54 100644 --- a/tests/conf/resisc45.yaml +++ b/tests/conf/resisc45.yaml @@ -1,8 +1,8 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' in_channels: 3 num_classes: 3 data: @@ -10,5 +10,4 @@ data: init_args: batch_size: 1 dict_kwargs: - root: "tests/data/resisc45" - download: true + root: 'tests/data/resisc45' diff --git a/tests/conf/seco_byol_1.yaml b/tests/conf/seco_byol_1.yaml index 9744f31a282..64440125cea 100644 --- a/tests/conf/seco_byol_1.yaml +++ b/tests/conf/seco_byol_1.yaml @@ -2,16 +2,11 @@ model: class_path: BYOLTask init_args: in_channels: 3 - model: "resnet18" + model: 'resnet18' data: class_path: SeasonalContrastS2DataModule init_args: batch_size: 2 dict_kwargs: - root: "tests/data/seco" + root: 'tests/data/seco' seasons: 1 - # https://github.com/Lightning-AI/lightning/issues/18616 - bands: - - "B4" - - "B3" - - "B2" diff --git a/tests/conf/seco_byol_2.yaml b/tests/conf/seco_byol_2.yaml index cb87dc1dbd0..521e03f1771 100644 --- a/tests/conf/seco_byol_2.yaml +++ b/tests/conf/seco_byol_2.yaml @@ -2,16 +2,11 @@ model: class_path: BYOLTask init_args: in_channels: 3 - model: "resnet18" + model: 'resnet18' data: class_path: SeasonalContrastS2DataModule init_args: batch_size: 2 dict_kwargs: - root: "tests/data/seco" + root: 'tests/data/seco' seasons: 2 - # https://github.com/Lightning-AI/lightning/issues/18616 - bands: - - "B4" - - "B3" - - "B2" diff --git a/tests/conf/seco_moco_1.yaml b/tests/conf/seco_moco_1.yaml index 97972245748..55d06e8173d 100644 --- a/tests/conf/seco_moco_1.yaml +++ b/tests/conf/seco_moco_1.yaml @@ -1,7 +1,7 @@ model: class_path: MoCoTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 3 version: 1 weight_decay: 1e-4 @@ -13,10 +13,5 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/seco" + root: 'tests/data/seco' seasons: 1 - # https://github.com/Lightning-AI/lightning/issues/18616 - bands: - - "B4" - - "B3" - - "B2" diff --git a/tests/conf/seco_moco_2.yaml b/tests/conf/seco_moco_2.yaml index 2e98abb6f66..919a10e3d4f 100644 --- a/tests/conf/seco_moco_2.yaml +++ b/tests/conf/seco_moco_2.yaml @@ -1,7 +1,7 @@ model: class_path: MoCoTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 3 version: 2 layers: 2 @@ -16,10 +16,5 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/seco" + root: 'tests/data/seco' seasons: 2 - # https://github.com/Lightning-AI/lightning/issues/18616 - bands: - - "B4" - - "B3" - - "B2" diff --git a/tests/conf/seco_simclr_1.yaml b/tests/conf/seco_simclr_1.yaml index 5f8ec279c6e..45df100ff4f 100644 --- a/tests/conf/seco_simclr_1.yaml +++ b/tests/conf/seco_simclr_1.yaml @@ -1,7 +1,7 @@ model: class_path: SimCLRTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 3 version: 1 layers: 2 @@ -14,10 +14,5 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/seco" + root: 'tests/data/seco' seasons: 1 - # https://github.com/Lightning-AI/lightning/issues/18616 - bands: - - "B4" - - "B3" - - "B2" diff --git a/tests/conf/seco_simclr_2.yaml b/tests/conf/seco_simclr_2.yaml index 9af2632de7c..8a4531e30f4 100644 --- a/tests/conf/seco_simclr_2.yaml +++ b/tests/conf/seco_simclr_2.yaml @@ -1,7 +1,7 @@ model: class_path: SimCLRTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 3 version: 2 layers: 4 @@ -14,10 +14,5 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/seco" + root: 'tests/data/seco' seasons: 2 - # https://github.com/Lightning-AI/lightning/issues/18616 - bands: - - "B4" - - "B3" - - "B2" diff --git a/tests/conf/sen12ms_all.yaml b/tests/conf/sen12ms_all.yaml index 3f83fa55085..d3aadd7d2f1 100644 --- a/tests/conf/sen12ms_all.yaml +++ b/tests/conf/sen12ms_all.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 15 num_classes: 11 ignore_index: null @@ -11,6 +11,6 @@ data: class_path: SEN12MSDataModule init_args: batch_size: 1 - band_set: "all" + band_set: 'all' dict_kwargs: - root: "tests/data/sen12ms" + root: 'tests/data/sen12ms' diff --git a/tests/conf/sen12ms_s1.yaml b/tests/conf/sen12ms_s1.yaml index 7e536d9e35a..ce790ca65de 100644 --- a/tests/conf/sen12ms_s1.yaml +++ b/tests/conf/sen12ms_s1.yaml @@ -1,10 +1,10 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "focal" - model: "fcn" + loss: 'focal' + model: 'fcn' num_filters: 1 - backbone: "resnet18" + backbone: 'resnet18' in_channels: 2 num_classes: 11 ignore_index: null @@ -12,6 +12,6 @@ data: class_path: SEN12MSDataModule init_args: batch_size: 1 - band_set: "s1" + band_set: 's1' dict_kwargs: - root: "tests/data/sen12ms" + root: 'tests/data/sen12ms' diff --git a/tests/conf/sen12ms_s2_all.yaml b/tests/conf/sen12ms_s2_all.yaml index b98d59d0c7f..9bcd1de1cf7 100644 --- a/tests/conf/sen12ms_s2_all.yaml +++ b/tests/conf/sen12ms_s2_all.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 13 num_classes: 11 ignore_index: null @@ -11,6 +11,6 @@ data: class_path: SEN12MSDataModule init_args: batch_size: 1 - band_set: "s2-all" + band_set: 's2-all' dict_kwargs: - root: "tests/data/sen12ms" + root: 'tests/data/sen12ms' diff --git a/tests/conf/sen12ms_s2_reduced.yaml b/tests/conf/sen12ms_s2_reduced.yaml index 770efaa6549..9e6ca651e64 100644 --- a/tests/conf/sen12ms_s2_reduced.yaml +++ b/tests/conf/sen12ms_s2_reduced.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 6 num_classes: 11 ignore_index: null @@ -11,6 +11,6 @@ data: class_path: SEN12MSDataModule init_args: batch_size: 1 - band_set: "s2-reduced" + band_set: 's2-reduced' dict_kwargs: - root: "tests/data/sen12ms" + root: 'tests/data/sen12ms' diff --git a/tests/conf/sentinel2_cdl.yaml b/tests/conf/sentinel2_cdl.yaml index 9cb192bd819..5a2928eaf99 100644 --- a/tests/conf/sentinel2_cdl.yaml +++ b/tests/conf/sentinel2_cdl.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 13 num_classes: 134 num_filters: 1 @@ -14,5 +14,5 @@ data: batch_size: 2 patch_size: 16 dict_kwargs: - cdl_paths: "tests/data/cdl" - sentinel2_paths: "tests/data/sentinel2" + cdl_paths: 'tests/data/cdl' + sentinel2_paths: 'tests/data/sentinel2' diff --git a/tests/conf/sentinel2_eurocrops.yaml b/tests/conf/sentinel2_eurocrops.yaml index b3633d3590d..760e8fd718a 100644 --- a/tests/conf/sentinel2_eurocrops.yaml +++ b/tests/conf/sentinel2_eurocrops.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 13 num_classes: 3 num_filters: 1 @@ -13,5 +13,5 @@ data: batch_size: 2 patch_size: 16 dict_kwargs: - sentinel2_paths: "tests/data/sentinel2" - eurocrops_paths: "tests/data/eurocrops" + sentinel2_paths: 'tests/data/sentinel2' + eurocrops_paths: 'tests/data/eurocrops' diff --git a/tests/conf/sentinel2_nccm.yaml b/tests/conf/sentinel2_nccm.yaml index 0244455863d..97af8efa06b 100644 --- a/tests/conf/sentinel2_nccm.yaml +++ b/tests/conf/sentinel2_nccm.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 13 num_classes: 5 num_filters: 1 @@ -14,5 +14,5 @@ data: batch_size: 2 patch_size: 16 dict_kwargs: - nccm_paths: "tests/data/nccm" - sentinel2_paths: "tests/data/sentinel2" + nccm_paths: 'tests/data/nccm' + sentinel2_paths: 'tests/data/sentinel2' diff --git a/tests/conf/sentinel2_south_america_soybean.yaml b/tests/conf/sentinel2_south_america_soybean.yaml index 7fe95704950..1c788d1f197 100644 --- a/tests/conf/sentinel2_south_america_soybean.yaml +++ b/tests/conf/sentinel2_south_america_soybean.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "deeplabv3+" - backbone: "resnet18" + loss: 'ce' + model: 'deeplabv3+' + backbone: 'resnet18' in_channels: 13 num_classes: 2 num_filters: 1 @@ -13,5 +13,5 @@ data: batch_size: 2 patch_size: 16 dict_kwargs: - south_america_soybean_paths: "tests/data/south_america_soybean" - sentinel2_paths: "tests/data/sentinel2" + south_america_soybean_paths: 'tests/data/south_america_soybean' + sentinel2_paths: 'tests/data/sentinel2' diff --git a/tests/conf/skippd.yaml b/tests/conf/skippd.yaml index 15a531437c5..3d40f203b70 100644 --- a/tests/conf/skippd.yaml +++ b/tests/conf/skippd.yaml @@ -1,15 +1,14 @@ model: class_path: RegressionTask init_args: - model: "resnet18" + model: 'resnet18' num_outputs: 1 in_channels: 3 - loss: "mse" + loss: 'mse' data: class_path: SKIPPDDataModule init_args: batch_size: 1 val_split_pct: 0.4 dict_kwargs: - root: "tests/data/skippd" - download: true + root: 'tests/data/skippd' diff --git a/tests/conf/so2sat_all.yaml b/tests/conf/so2sat_all.yaml index c728c9d7179..736c98d2d18 100644 --- a/tests/conf/so2sat_all.yaml +++ b/tests/conf/so2sat_all.yaml @@ -1,15 +1,15 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' in_channels: 18 num_classes: 17 data: class_path: So2SatDataModule init_args: batch_size: 1 - band_set: "all" + band_set: 'all' dict_kwargs: - root: "tests/data/so2sat" - version: "2" + root: 'tests/data/so2sat' + version: '2' diff --git a/tests/conf/so2sat_rgb.yaml b/tests/conf/so2sat_rgb.yaml index 66e1e223561..840cae6534a 100644 --- a/tests/conf/so2sat_rgb.yaml +++ b/tests/conf/so2sat_rgb.yaml @@ -1,16 +1,16 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' in_channels: 3 num_classes: 17 data: class_path: So2SatDataModule init_args: batch_size: 1 - band_set: "rgb" + band_set: 'rgb' val_split_pct: 0.5 dict_kwargs: - root: "tests/data/so2sat" - version: "3_random" + root: 'tests/data/so2sat' + version: '3_random' diff --git a/tests/conf/so2sat_s1.yaml b/tests/conf/so2sat_s1.yaml index df7a9cb1ea9..cd15c91c248 100644 --- a/tests/conf/so2sat_s1.yaml +++ b/tests/conf/so2sat_s1.yaml @@ -1,15 +1,15 @@ model: class_path: ClassificationTask init_args: - loss: "focal" - model: "resnet18" + loss: 'focal' + model: 'resnet18' in_channels: 8 num_classes: 17 data: class_path: So2SatDataModule init_args: batch_size: 1 - band_set: "s1" + band_set: 's1' dict_kwargs: - root: "tests/data/so2sat" - version: "2" + root: 'tests/data/so2sat' + version: '2' diff --git a/tests/conf/so2sat_s2.yaml b/tests/conf/so2sat_s2.yaml index fb41099e60e..828b9f04fc5 100644 --- a/tests/conf/so2sat_s2.yaml +++ b/tests/conf/so2sat_s2.yaml @@ -1,14 +1,14 @@ model: class_path: ClassificationTask init_args: - loss: "jaccard" - model: "resnet18" + loss: 'jaccard' + model: 'resnet18' in_channels: 10 num_classes: 17 data: class_path: So2SatDataModule init_args: batch_size: 1 - band_set: "s2" + band_set: 's2' dict_kwargs: - root: "tests/data/so2sat" + root: 'tests/data/so2sat' diff --git a/tests/conf/southafricacroptype.yaml b/tests/conf/southafricacroptype.yaml index cfd4f8dfecf..d409e535e9d 100644 --- a/tests/conf/southafricacroptype.yaml +++ b/tests/conf/southafricacroptype.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 12 num_classes: 10 num_filters: 1 @@ -14,4 +14,4 @@ data: batch_size: 2 patch_size: 16 dict_kwargs: - paths: "tests/data/south_africa_crop_type" + paths: 'tests/data/south_africa_crop_type' diff --git a/tests/conf/spacenet1.yaml b/tests/conf/spacenet1.yaml index 0da6cd24c4c..6f47d88d35b 100644 --- a/tests/conf/spacenet1.yaml +++ b/tests/conf/spacenet1.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 3 num_classes: 3 num_filters: 1 @@ -12,8 +12,7 @@ data: class_path: SpaceNet1DataModule init_args: batch_size: 1 - val_split_pct: 0.33 - test_split_pct: 0.33 + val_split_pct: 0.34 + test_split_pct: 0.34 dict_kwargs: - root: "tests/data/spacenet" - download: true + root: 'tests/data/spacenet/spacenet1' diff --git a/tests/conf/spacenet6.yaml b/tests/conf/spacenet6.yaml new file mode 100644 index 00000000000..c017102d3df --- /dev/null +++ b/tests/conf/spacenet6.yaml @@ -0,0 +1,19 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: 'ce' + model: 'unet' + backbone: 'resnet18' + in_channels: 4 + num_classes: 3 + num_filters: 1 + ignore_index: null +data: + class_path: SpaceNet6DataModule + init_args: + batch_size: 1 + val_split_pct: 0.34 + test_split_pct: 0.34 + dict_kwargs: + root: 'tests/data/spacenet/spacenet6' + image: 'SAR-Intensity' diff --git a/tests/conf/ssl4eo_l_benchmark_cdl.yaml b/tests/conf/ssl4eo_l_benchmark_cdl.yaml index a4a4a7b9203..842c3577970 100644 --- a/tests/conf/ssl4eo_l_benchmark_cdl.yaml +++ b/tests/conf/ssl4eo_l_benchmark_cdl.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 7 num_classes: 134 num_filters: 1 @@ -13,6 +13,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo_benchmark_landsat" - sensor: "tm_toa" - product: "cdl" + root: 'tests/data/ssl4eo_benchmark_landsat' + sensor: 'tm_toa' + product: 'cdl' diff --git a/tests/conf/ssl4eo_l_benchmark_nlcd.yaml b/tests/conf/ssl4eo_l_benchmark_nlcd.yaml index 89475a091b0..040fab72a53 100644 --- a/tests/conf/ssl4eo_l_benchmark_nlcd.yaml +++ b/tests/conf/ssl4eo_l_benchmark_nlcd.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 6 num_classes: 17 num_filters: 1 @@ -13,6 +13,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo_benchmark_landsat" - sensor: "etm_sr" - product: "nlcd" + root: 'tests/data/ssl4eo_benchmark_landsat' + sensor: 'etm_sr' + product: 'nlcd' diff --git a/tests/conf/ssl4eo_l_byol_1.yaml b/tests/conf/ssl4eo_l_byol_1.yaml index ed78b7fae37..71151f52d28 100644 --- a/tests/conf/ssl4eo_l_byol_1.yaml +++ b/tests/conf/ssl4eo_l_byol_1.yaml @@ -2,12 +2,12 @@ model: class_path: BYOLTask init_args: in_channels: 7 - model: "resnet18" + model: 'resnet18' data: class_path: SSL4EOLDataModule init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/l" - split: "tm_toa" + root: 'tests/data/ssl4eo/l' + split: 'tm_toa' seasons: 1 diff --git a/tests/conf/ssl4eo_l_byol_2.yaml b/tests/conf/ssl4eo_l_byol_2.yaml index 6e1c6ab060d..3a2f3df24d9 100644 --- a/tests/conf/ssl4eo_l_byol_2.yaml +++ b/tests/conf/ssl4eo_l_byol_2.yaml @@ -2,12 +2,12 @@ model: class_path: BYOLTask init_args: in_channels: 6 - model: "resnet18" + model: 'resnet18' data: class_path: SSL4EOLDataModule init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/l" - split: "etm_sr" + root: 'tests/data/ssl4eo/l' + split: 'etm_sr' seasons: 2 diff --git a/tests/conf/ssl4eo_l_moco_1.yaml b/tests/conf/ssl4eo_l_moco_1.yaml index 023f1ff9b66..d15bf8b15b9 100644 --- a/tests/conf/ssl4eo_l_moco_1.yaml +++ b/tests/conf/ssl4eo_l_moco_1.yaml @@ -1,7 +1,7 @@ model: class_path: MoCoTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 9 version: 1 weight_decay: 1e-4 @@ -19,6 +19,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/l" - split: "etm_toa" + root: 'tests/data/ssl4eo/l' + split: 'etm_toa' seasons: 1 diff --git a/tests/conf/ssl4eo_l_moco_2.yaml b/tests/conf/ssl4eo_l_moco_2.yaml index 3edf6a52487..3eac4a5cb0d 100644 --- a/tests/conf/ssl4eo_l_moco_2.yaml +++ b/tests/conf/ssl4eo_l_moco_2.yaml @@ -1,7 +1,7 @@ model: class_path: MoCoTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 11 version: 2 layers: 2 @@ -16,6 +16,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/l" - split: "oli_tirs_toa" + root: 'tests/data/ssl4eo/l' + split: 'oli_tirs_toa' seasons: 2 diff --git a/tests/conf/ssl4eo_l_simclr_1.yaml b/tests/conf/ssl4eo_l_simclr_1.yaml index b705579173f..c6249900dd4 100644 --- a/tests/conf/ssl4eo_l_simclr_1.yaml +++ b/tests/conf/ssl4eo_l_simclr_1.yaml @@ -1,7 +1,7 @@ model: class_path: SimCLRTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 7 version: 1 layers: 2 @@ -14,6 +14,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/l" - split: "oli_sr" + root: 'tests/data/ssl4eo/l' + split: 'oli_sr' seasons: 1 diff --git a/tests/conf/ssl4eo_l_simclr_2.yaml b/tests/conf/ssl4eo_l_simclr_2.yaml index 7310bba9e95..63ebfc2c111 100644 --- a/tests/conf/ssl4eo_l_simclr_2.yaml +++ b/tests/conf/ssl4eo_l_simclr_2.yaml @@ -1,7 +1,7 @@ model: class_path: SimCLRTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 7 version: 2 layers: 3 @@ -14,6 +14,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/l" - split: "tm_toa" + root: 'tests/data/ssl4eo/l' + split: 'tm_toa' seasons: 2 diff --git a/tests/conf/ssl4eo_s12_byol_1.yaml b/tests/conf/ssl4eo_s12_byol_1.yaml index ccdf4b5736d..aeba268fb89 100644 --- a/tests/conf/ssl4eo_s12_byol_1.yaml +++ b/tests/conf/ssl4eo_s12_byol_1.yaml @@ -2,12 +2,12 @@ model: class_path: BYOLTask init_args: in_channels: 2 - model: "resnet18" + model: 'resnet18' data: class_path: SSL4EOS12DataModule init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/s12" - split: "s1" + root: 'tests/data/ssl4eo/s12' + split: 's1' seasons: 1 diff --git a/tests/conf/ssl4eo_s12_byol_2.yaml b/tests/conf/ssl4eo_s12_byol_2.yaml index 6368e8fdefe..a9f602a1294 100644 --- a/tests/conf/ssl4eo_s12_byol_2.yaml +++ b/tests/conf/ssl4eo_s12_byol_2.yaml @@ -2,12 +2,12 @@ model: class_path: BYOLTask init_args: in_channels: 13 - model: "resnet18" + model: 'resnet18' data: class_path: SSL4EOS12DataModule init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/s12" - split: "s2c" + root: 'tests/data/ssl4eo/s12' + split: 's2c' seasons: 2 diff --git a/tests/conf/ssl4eo_s12_moco_1.yaml b/tests/conf/ssl4eo_s12_moco_1.yaml index 513d5ae0842..29678057c33 100644 --- a/tests/conf/ssl4eo_s12_moco_1.yaml +++ b/tests/conf/ssl4eo_s12_moco_1.yaml @@ -1,7 +1,7 @@ model: class_path: MoCoTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 12 version: 1 weight_decay: 1e-4 @@ -13,6 +13,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/s12" - split: "s2a" + root: 'tests/data/ssl4eo/s12' + split: 's2a' seasons: 1 diff --git a/tests/conf/ssl4eo_s12_moco_2.yaml b/tests/conf/ssl4eo_s12_moco_2.yaml index 71d8ee43dc7..e20ea55363e 100644 --- a/tests/conf/ssl4eo_s12_moco_2.yaml +++ b/tests/conf/ssl4eo_s12_moco_2.yaml @@ -1,7 +1,7 @@ model: class_path: MoCoTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 2 version: 2 layers: 2 @@ -16,6 +16,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/s12" - split: "s1" + root: 'tests/data/ssl4eo/s12' + split: 's1' seasons: 2 diff --git a/tests/conf/ssl4eo_s12_simclr_1.yaml b/tests/conf/ssl4eo_s12_simclr_1.yaml index 94444be5cc9..b6316034b2b 100644 --- a/tests/conf/ssl4eo_s12_simclr_1.yaml +++ b/tests/conf/ssl4eo_s12_simclr_1.yaml @@ -1,7 +1,7 @@ model: class_path: SimCLRTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 13 version: 1 layers: 2 @@ -14,6 +14,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/s12" - split: "s2c" + root: 'tests/data/ssl4eo/s12' + split: 's2c' seasons: 1 diff --git a/tests/conf/ssl4eo_s12_simclr_2.yaml b/tests/conf/ssl4eo_s12_simclr_2.yaml index 7d88a3713ba..50bcbddf8a2 100644 --- a/tests/conf/ssl4eo_s12_simclr_2.yaml +++ b/tests/conf/ssl4eo_s12_simclr_2.yaml @@ -1,7 +1,7 @@ model: class_path: SimCLRTask init_args: - model: "resnet18" + model: 'resnet18' in_channels: 12 version: 2 layers: 3 @@ -14,6 +14,6 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ssl4eo/s12" - split: "s2a" + root: 'tests/data/ssl4eo/s12' + split: 's2a' seasons: 2 diff --git a/tests/conf/sustainbench_crop_yield.yaml b/tests/conf/sustainbench_crop_yield.yaml index ba6e65af105..83a44daa305 100644 --- a/tests/conf/sustainbench_crop_yield.yaml +++ b/tests/conf/sustainbench_crop_yield.yaml @@ -1,14 +1,13 @@ model: class_path: RegressionTask init_args: - model: "resnet18" + model: 'resnet18' num_outputs: 1 in_channels: 9 - loss: "mse" + loss: 'mse' data: class_path: SustainBenchCropYieldDataModule init_args: batch_size: 1 dict_kwargs: - root: "tests/data/sustainbench_crop_yield" - download: true + root: 'tests/data/sustainbench_crop_yield' diff --git a/tests/conf/treesatai.yaml b/tests/conf/treesatai.yaml new file mode 100644 index 00000000000..e605b688b82 --- /dev/null +++ b/tests/conf/treesatai.yaml @@ -0,0 +1,13 @@ +model: + class_path: MultiLabelClassificationTask + init_args: + model: 'resnet18' + in_channels: 19 + num_classes: 15 + loss: 'bce' +data: + class_path: TreeSatAIDataModule + init_args: + batch_size: 1 + dict_kwargs: + root: 'tests/data/treesatai' diff --git a/tests/conf/ucmerced.yaml b/tests/conf/ucmerced.yaml index d9c8752f1ec..051112f540b 100644 --- a/tests/conf/ucmerced.yaml +++ b/tests/conf/ucmerced.yaml @@ -1,8 +1,8 @@ model: class_path: ClassificationTask init_args: - loss: "ce" - model: "resnet18" + loss: 'ce' + model: 'resnet18' in_channels: 3 num_classes: 2 data: @@ -10,5 +10,4 @@ data: init_args: batch_size: 2 dict_kwargs: - root: "tests/data/ucmerced" - download: true + root: 'tests/data/ucmerced' diff --git a/tests/conf/vaihingen2d.yaml b/tests/conf/vaihingen2d.yaml index 00404756ace..d0a4e6b948a 100644 --- a/tests/conf/vaihingen2d.yaml +++ b/tests/conf/vaihingen2d.yaml @@ -1,9 +1,9 @@ model: class_path: SemanticSegmentationTask init_args: - loss: "ce" - model: "unet" - backbone: "resnet18" + loss: 'ce' + model: 'unet' + backbone: 'resnet18' in_channels: 3 num_classes: 7 num_filters: 1 @@ -15,4 +15,4 @@ data: patch_size: 2 val_split_pct: 0.5 dict_kwargs: - root: "tests/data/vaihingen" + root: 'tests/data/vaihingen' diff --git a/tests/conf/vhr10.yaml b/tests/conf/vhr10.yaml index 0ea0909b971..9faab1187a6 100644 --- a/tests/conf/vhr10.yaml +++ b/tests/conf/vhr10.yaml @@ -1,8 +1,8 @@ model: class_path: ObjectDetectionTask init_args: - model: "faster-rcnn" - backbone: "resnet50" + model: 'faster-rcnn' + backbone: 'resnet50' num_classes: 11 lr: 2.5e-5 patience: 10 @@ -13,5 +13,4 @@ data: num_workers: 0 patch_size: 4 dict_kwargs: - root: "tests/data/vhr10" - download: true + root: 'tests/data/vhr10' diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000000..d55a972ced1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from pathlib import Path +from typing import Any + +import matplotlib +import pytest +import torch +import torchvision +from pytest import MonkeyPatch + + +def load(*args: Any, progress: bool = False, **kwargs: Any) -> Any: + return torch.load(*args, **kwargs) + + +@pytest.fixture +def load_state_dict_from_url(monkeypatch: MonkeyPatch) -> None: + monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) + + +@pytest.fixture(autouse=True, scope='session') +def matplotlib_backend() -> None: + matplotlib.use('agg') + + +@pytest.fixture(autouse=True) +def torch_hub(tmp_path: Path) -> None: + torch.hub.set_dir(tmp_path) # type: ignore[no-untyped-call] diff --git a/tests/data/README.md b/tests/data/README.md index 1d95c728d6d..312bce2a8e8 100644 --- a/tests/data/README.md +++ b/tests/data/README.md @@ -20,7 +20,7 @@ with rio.open(os.path.join(ROOT, FILENAME), "r") as src: dtype = src.profile["dtype"] Z = np.random.randint(np.iinfo(dtype).max, size=(SIZE, SIZE), dtype=dtype) with rio.open(FILENAME, "w", **src.profile) as dst: - for i in dst.profile.indexes: + for i in dst.indexes: dst.write(Z, i) ``` diff --git a/tests/data/cabuar/512x512.hdf5 b/tests/data/cabuar/512x512.hdf5 new file mode 100644 index 00000000000..5d8f16529bb Binary files /dev/null and b/tests/data/cabuar/512x512.hdf5 differ diff --git a/tests/data/cabuar/chabud_test.h5 b/tests/data/cabuar/chabud_test.h5 new file mode 100644 index 00000000000..5408b9d27fc Binary files /dev/null and b/tests/data/cabuar/chabud_test.h5 differ diff --git a/tests/data/cabuar/data.py b/tests/data/cabuar/data.py new file mode 100644 index 00000000000..9be447d816d --- /dev/null +++ b/tests/data/cabuar/data.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import random + +import h5py +import numpy as np + +# Sentinel-2 is 12-bit with range 0-4095 +SENTINEL2_MAX = 4096 + +NUM_CHANNELS = 12 +NUM_CLASSES = 2 +SIZE = 32 + +np.random.seed(0) +random.seed(0) + +filenames = ['512x512.hdf5', 'chabud_test.h5'] +fold_mapping = {'train': [1, 2, 3, 4], 'val': [0], 'test': ['chabud']} + +uris = [ + 'feb08801-64b1-4d11-a3fc-0efaad1f4274_0', + 'e4d4dbcb-dd92-40cf-a7fe-fda8dd35f367_1', + '9fc8c1f4-1858-47c3-953e-1dc8b179a', + '3a1358a2-6155-445a-a269-13bebd9741a8_0', + '2f8e659c-f457-4527-a57f-bffc3bbe0baa_0', + '299ee670-19b1-4a76-bef3-34fd55580711_1', + '05cfef86-3e27-42be-a0cb-a61fe2f89e40_0', + '0328d12a-4ad8-4504-8ac5-70089db10b4e_1', + '04800581-b540-4f9b-9df8-7ee433e83f46_0', + '108ae2a9-d7d6-42f7-b89a-90bb75c23ccb_0', + '29413474-04b8-4bb1-8b89-fd640023d4a6_0', + '43f2e60a-73b4-4f33-b99e-319d892fcab4_0', +] +folds = random.choices(fold_mapping['train'], k=4) + [0] * 4 + ['chabud'] * 4 +files = ['512x512.hdf5'] * 8 + ['chabud_test.h5'] * 4 + +# Remove old data +for filename in filenames: + if os.path.exists(filename): + os.remove(filename) + +# Create dataset file +data = np.random.randint( + SENTINEL2_MAX, size=(SIZE, SIZE, NUM_CHANNELS), dtype=np.uint16 +) +gt = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE, 1), dtype=np.uint16) + +for filename, uri, fold in zip(files, uris, folds): + with h5py.File(filename, 'a') as f: + sample = f.create_group(uri) + sample.attrs.create( + name='fold', data=np.int64(fold) if fold != 'chabud' else fold + ) + sample.create_dataset + sample.create_dataset('pre_fire', data=data) + sample.create_dataset('post_fire', data=data) + sample.create_dataset('mask', data=gt) + +# Compute checksums +for filename in filenames: + with open(filename, 'rb') as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f'{filename} md5: {md5}') diff --git a/tests/data/caffe/caffe.zip b/tests/data/caffe/caffe.zip new file mode 100644 index 00000000000..dd4fb2e0d8f Binary files /dev/null and b/tests/data/caffe/caffe.zip differ diff --git a/tests/data/caffe/caffe/fronts/test/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png b/tests/data/caffe/caffe/fronts/test/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png new file mode 100644 index 00000000000..1954382b5d9 Binary files /dev/null and b/tests/data/caffe/caffe/fronts/test/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/fronts/test/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png b/tests/data/caffe/caffe/fronts/test/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png new file mode 100644 index 00000000000..48900577899 Binary files /dev/null and b/tests/data/caffe/caffe/fronts/test/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/fronts/test/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png b/tests/data/caffe/caffe/fronts/test/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png new file mode 100644 index 00000000000..db00b2c14f6 Binary files /dev/null and b/tests/data/caffe/caffe/fronts/test/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/caffe/fronts/train/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png b/tests/data/caffe/caffe/fronts/train/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png new file mode 100644 index 00000000000..4831fa31432 Binary files /dev/null and b/tests/data/caffe/caffe/fronts/train/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/fronts/train/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png b/tests/data/caffe/caffe/fronts/train/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png new file mode 100644 index 00000000000..fd51b45f7cf Binary files /dev/null and b/tests/data/caffe/caffe/fronts/train/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/fronts/train/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png b/tests/data/caffe/caffe/fronts/train/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png new file mode 100644 index 00000000000..2d0de1b1a58 Binary files /dev/null and b/tests/data/caffe/caffe/fronts/train/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/caffe/fronts/val/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png b/tests/data/caffe/caffe/fronts/val/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png new file mode 100644 index 00000000000..e319dd8c0ea Binary files /dev/null and b/tests/data/caffe/caffe/fronts/val/Crane_2002-11-09_ERS_20_2_061_front__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/fronts/val/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png b/tests/data/caffe/caffe/fronts/val/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png new file mode 100644 index 00000000000..4dbce280254 Binary files /dev/null and b/tests/data/caffe/caffe/fronts/val/Crane_2007-09-22_ENVISAT_20_1_467_front__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/fronts/val/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png b/tests/data/caffe/caffe/fronts/val/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png new file mode 100644 index 00000000000..2539d230fa0 Binary files /dev/null and b/tests/data/caffe/caffe/fronts/val/JAC_2015-12-23_TSX_6_1_005_front__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/caffe/sar_images/test/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png b/tests/data/caffe/caffe/sar_images/test/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png new file mode 100644 index 00000000000..86bcc910bc3 Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/test/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/sar_images/test/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png b/tests/data/caffe/caffe/sar_images/test/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png new file mode 100644 index 00000000000..fd776c41faf Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/test/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/sar_images/test/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png b/tests/data/caffe/caffe/sar_images/test/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png new file mode 100644 index 00000000000..1a2eeb76f2d Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/test/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/caffe/sar_images/train/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png b/tests/data/caffe/caffe/sar_images/train/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png new file mode 100644 index 00000000000..2ff2c5afa6f Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/train/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/sar_images/train/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png b/tests/data/caffe/caffe/sar_images/train/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png new file mode 100644 index 00000000000..f5626b7e3bd Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/train/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/sar_images/train/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png b/tests/data/caffe/caffe/sar_images/train/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png new file mode 100644 index 00000000000..3d29c6958fc Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/train/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/caffe/sar_images/val/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png b/tests/data/caffe/caffe/sar_images/val/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png new file mode 100644 index 00000000000..d77aad8c720 Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/val/Crane_2002-11-09_ERS_20_2_061__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/sar_images/val/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png b/tests/data/caffe/caffe/sar_images/val/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png new file mode 100644 index 00000000000..f01815a9b8a Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/val/Crane_2007-09-22_ENVISAT_20_1_467__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/sar_images/val/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png b/tests/data/caffe/caffe/sar_images/val/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png new file mode 100644 index 00000000000..5c8fe97a308 Binary files /dev/null and b/tests/data/caffe/caffe/sar_images/val/JAC_2015-12-23_TSX_6_1_005__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/caffe/zones/test/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png b/tests/data/caffe/caffe/zones/test/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png new file mode 100644 index 00000000000..6c53942e414 Binary files /dev/null and b/tests/data/caffe/caffe/zones/test/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/zones/test/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png b/tests/data/caffe/caffe/zones/test/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png new file mode 100644 index 00000000000..e993e05ed88 Binary files /dev/null and b/tests/data/caffe/caffe/zones/test/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/zones/test/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png b/tests/data/caffe/caffe/zones/test/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png new file mode 100644 index 00000000000..43c8139d220 Binary files /dev/null and b/tests/data/caffe/caffe/zones/test/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/caffe/zones/train/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png b/tests/data/caffe/caffe/zones/train/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png new file mode 100644 index 00000000000..51a239bd2da Binary files /dev/null and b/tests/data/caffe/caffe/zones/train/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/zones/train/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png b/tests/data/caffe/caffe/zones/train/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png new file mode 100644 index 00000000000..48d63bf59de Binary files /dev/null and b/tests/data/caffe/caffe/zones/train/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/zones/train/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png b/tests/data/caffe/caffe/zones/train/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png new file mode 100644 index 00000000000..5a6a1e1a63b Binary files /dev/null and b/tests/data/caffe/caffe/zones/train/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/caffe/zones/val/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png b/tests/data/caffe/caffe/zones/val/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png new file mode 100644 index 00000000000..c0f840d16c0 Binary files /dev/null and b/tests/data/caffe/caffe/zones/val/Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png differ diff --git a/tests/data/caffe/caffe/zones/val/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png b/tests/data/caffe/caffe/zones/val/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png new file mode 100644 index 00000000000..0ef56eb78ab Binary files /dev/null and b/tests/data/caffe/caffe/zones/val/Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png differ diff --git a/tests/data/caffe/caffe/zones/val/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png b/tests/data/caffe/caffe/zones/val/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png new file mode 100644 index 00000000000..f93781f2f4e Binary files /dev/null and b/tests/data/caffe/caffe/zones/val/JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png differ diff --git a/tests/data/caffe/data.py b/tests/data/caffe/data.py new file mode 100644 index 00000000000..51caaf37b3b --- /dev/null +++ b/tests/data/caffe/data.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil + +import numpy as np +from PIL import Image + +# Define the root directory and subdirectories +root_dir = 'caffe' +sub_dirs = ['zones', 'sar_images', 'fronts'] +splits = ['train', 'val', 'test'] + +zone_file_names = [ + 'Crane_2002-11-09_ERS_20_2_061_zones__93_102_0_0_0.png', + 'Crane_2007-09-22_ENVISAT_20_1_467_zones__93_102_8_1024_0.png', + 'JAC_2015-12-23_TSX_6_1_005_zones__57_49_195_384_1024.png', +] + +IMG_SIZE = 32 + + +# Function to create dummy images +def create_dummy_image(path: str, shape: tuple[int], pixel_values: list[int]) -> None: + data = np.random.choice(pixel_values, size=shape, replace=True).astype(np.uint8) + img = Image.fromarray(data) + img.save(path) + + +def create_zone_images(split: str, filename: str) -> None: + zone_pixel_values = [0, 64, 127, 254] + path = os.path.join(root_dir, 'zones', split, filename) + create_dummy_image(path, (IMG_SIZE, IMG_SIZE), zone_pixel_values) + + +def create_sar_images(split: str, filename: str) -> None: + sar_pixel_values = range(256) + path = os.path.join(root_dir, 'sar_images', split, filename) + create_dummy_image(path, (IMG_SIZE, IMG_SIZE), sar_pixel_values) + + +def create_front_images(split: str, filename: str) -> None: + front_pixel_values = [0, 255] + path = os.path.join(root_dir, 'fronts', split, filename) + create_dummy_image(path, (IMG_SIZE, IMG_SIZE), front_pixel_values) + + +if os.path.exists(root_dir): + shutil.rmtree(root_dir) + +# Create the directory structure +for sub_dir in sub_dirs: + for split in splits: + os.makedirs(os.path.join(root_dir, sub_dir, split), exist_ok=True) + +# Create dummy data for all splits and filenames +for split in splits: + for filename in zone_file_names: + create_zone_images(split, filename) + create_sar_images(split, filename.replace('_zones_', '_')) + create_front_images(split, filename.replace('_zones_', '_front_')) + +# zip and compute md5 +shutil.make_archive(root_dir, 'zip', '.', root_dir) + + +def md5(fname: str) -> str: + hash_md5 = hashlib.md5() + with open(fname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +md5sum = md5('caffe.zip') +print(f'MD5 checksum: {md5sum}') diff --git a/tests/data/cbf/data.py b/tests/data/cbf/data.py index 6dc5c457f47..07eaf5f0a6e 100755 --- a/tests/data/cbf/data.py +++ b/tests/data/cbf/data.py @@ -9,7 +9,7 @@ import shutil -def create_geojson(): +def create_geojson() -> dict[object, object]: geojson = { 'type': 'FeatureCollection', 'crs': { diff --git a/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif deleted file mode 100644 index 8deb2ecb596..00000000000 Binary files a/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.tif and /dev/null differ diff --git a/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip b/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip deleted file mode 100644 index af6c5db2430..00000000000 Binary files a/tests/data/chesapeake/BAYWIDE/Baywide_13Class_20132014.zip and /dev/null differ diff --git a/tests/data/chesapeake/BAYWIDE/data.py b/tests/data/chesapeake/lulc/data.py similarity index 54% rename from tests/data/chesapeake/BAYWIDE/data.py rename to tests/data/chesapeake/lulc/data.py index dbf807ce0fe..b828baaa180 100755 --- a/tests/data/chesapeake/BAYWIDE/data.py +++ b/tests/data/chesapeake/lulc/data.py @@ -4,8 +4,7 @@ # Licensed under the MIT License. import hashlib -import os -import subprocess +import shutil import numpy as np import rasterio @@ -13,7 +12,6 @@ from rasterio.transform import Affine SIZE = 128 # image width/height -NUM_CLASSES = 14 np.random.seed(0) @@ -41,24 +39,50 @@ AXIS["Easting",EAST], AXIS["Northing",NORTH]] """ -cmap = { - 0: (0, 0, 0, 255), - 1: (0, 197, 255, 255), - 2: (0, 168, 132, 255), - 3: (38, 115, 0, 255), - 4: (76, 230, 0, 255), - 5: (163, 255, 115, 255), - 6: (255, 170, 0, 255), - 7: (255, 0, 0, 255), - 8: (156, 156, 156, 255), - 9: (0, 0, 0, 255), - 10: (115, 115, 0, 255), - 11: (230, 230, 0, 255), - 12: (255, 255, 115, 255), - 13: (197, 0, 255, 255), -} +values = [ + 11, + 12, + 13, + 14, + 15, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 41, + 42, + 51, + 52, + 53, + 54, + 55, + 56, + 62, + 63, + 64, + 65, + 72, + 73, + 74, + 75, + 83, + 84, + 85, + 91, + 92, + 93, + 94, + 95, + 127, +] + meta = { 'driver': 'GTiff', 'dtype': 'uint8', @@ -70,26 +94,18 @@ 'transform': Affine(1.0, 0.0, 1303555.0000000005, 0.0, -1.0, 2535064.999999998), } -# Remove old data -if os.path.exists(f'{filename}.tif'): - os.remove(f'{filename}.tif') +for state in ['dc', 'de', 'md', 'ny', 'pa', 'va', 'wv']: + filename = f'{state}_lulc_2018_2022-Edition' -# Create raster file -with rasterio.open(f'{filename}.tif', 'w', **meta) as f: - data = np.random.randint(NUM_CLASSES, size=(SIZE, SIZE), dtype=np.uint8) - f.write(data, 1) - f.write_colormap(1, cmap) + # Create raster file + with rasterio.open(f'{filename}.tif', 'w', **meta) as f: + data = np.random.choice(values, size=(SIZE, SIZE)) + f.write(data, 1) -# Create zip file -# 7z required to create a zip file using the proprietary DEFLATE64 compression algorithm -# https://github.com/brianhelba/zipfile-deflate64/issues/19#issuecomment-1006077294 -subprocess.run( - ['7z', 'a', f'{filename}.zip', '-mm=DEFLATE64', f'{filename}.tif'], - capture_output=True, - check=True, -) + # Compress file + shutil.make_archive(filename, 'zip', '.', filename + '.tif') -# Compute checksums -with open(f'{filename}.zip', 'rb') as f: - md5 = hashlib.md5(f.read()).hexdigest() - print(repr(md5)) + # Compute checksums + with open(f'{filename}.zip', 'rb') as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(state, repr(md5)) diff --git a/tests/data/chesapeake/lulc/dc_lulc_2018_2022-Edition.tif b/tests/data/chesapeake/lulc/dc_lulc_2018_2022-Edition.tif new file mode 100644 index 00000000000..ca32088164f Binary files /dev/null and b/tests/data/chesapeake/lulc/dc_lulc_2018_2022-Edition.tif differ diff --git a/tests/data/chesapeake/lulc/dc_lulc_2018_2022-Edition.zip b/tests/data/chesapeake/lulc/dc_lulc_2018_2022-Edition.zip new file mode 100644 index 00000000000..9cadbeeadc3 Binary files /dev/null and b/tests/data/chesapeake/lulc/dc_lulc_2018_2022-Edition.zip differ diff --git a/tests/data/chesapeake/lulc/de_lulc_2018_2022-Edition.tif b/tests/data/chesapeake/lulc/de_lulc_2018_2022-Edition.tif new file mode 100644 index 00000000000..21bf4f29452 Binary files /dev/null and b/tests/data/chesapeake/lulc/de_lulc_2018_2022-Edition.tif differ diff --git a/tests/data/chesapeake/lulc/de_lulc_2018_2022-Edition.zip b/tests/data/chesapeake/lulc/de_lulc_2018_2022-Edition.zip new file mode 100644 index 00000000000..eb2a3aa3d5a Binary files /dev/null and b/tests/data/chesapeake/lulc/de_lulc_2018_2022-Edition.zip differ diff --git a/tests/data/chesapeake/lulc/md_lulc_2018_2022-Edition.tif b/tests/data/chesapeake/lulc/md_lulc_2018_2022-Edition.tif new file mode 100644 index 00000000000..2135a007504 Binary files /dev/null and b/tests/data/chesapeake/lulc/md_lulc_2018_2022-Edition.tif differ diff --git a/tests/data/chesapeake/lulc/md_lulc_2018_2022-Edition.zip b/tests/data/chesapeake/lulc/md_lulc_2018_2022-Edition.zip new file mode 100644 index 00000000000..a27942225c4 Binary files /dev/null and b/tests/data/chesapeake/lulc/md_lulc_2018_2022-Edition.zip differ diff --git a/tests/data/chesapeake/lulc/ny_lulc_2018_2022-Edition.tif b/tests/data/chesapeake/lulc/ny_lulc_2018_2022-Edition.tif new file mode 100644 index 00000000000..0288f062aec Binary files /dev/null and b/tests/data/chesapeake/lulc/ny_lulc_2018_2022-Edition.tif differ diff --git a/tests/data/chesapeake/lulc/ny_lulc_2018_2022-Edition.zip b/tests/data/chesapeake/lulc/ny_lulc_2018_2022-Edition.zip new file mode 100644 index 00000000000..fee8fcc8665 Binary files /dev/null and b/tests/data/chesapeake/lulc/ny_lulc_2018_2022-Edition.zip differ diff --git a/tests/data/chesapeake/lulc/pa_lulc_2018_2022-Edition.tif b/tests/data/chesapeake/lulc/pa_lulc_2018_2022-Edition.tif new file mode 100644 index 00000000000..ef61f53183f Binary files /dev/null and b/tests/data/chesapeake/lulc/pa_lulc_2018_2022-Edition.tif differ diff --git a/tests/data/chesapeake/lulc/pa_lulc_2018_2022-Edition.zip b/tests/data/chesapeake/lulc/pa_lulc_2018_2022-Edition.zip new file mode 100644 index 00000000000..4afe96a5328 Binary files /dev/null and b/tests/data/chesapeake/lulc/pa_lulc_2018_2022-Edition.zip differ diff --git a/tests/data/chesapeake/lulc/va_lulc_2018_2022-Edition.tif b/tests/data/chesapeake/lulc/va_lulc_2018_2022-Edition.tif new file mode 100644 index 00000000000..00f09fae24f Binary files /dev/null and b/tests/data/chesapeake/lulc/va_lulc_2018_2022-Edition.tif differ diff --git a/tests/data/chesapeake/lulc/va_lulc_2018_2022-Edition.zip b/tests/data/chesapeake/lulc/va_lulc_2018_2022-Edition.zip new file mode 100644 index 00000000000..795cae8e293 Binary files /dev/null and b/tests/data/chesapeake/lulc/va_lulc_2018_2022-Edition.zip differ diff --git a/tests/data/chesapeake/lulc/wv_lulc_2018_2022-Edition.tif b/tests/data/chesapeake/lulc/wv_lulc_2018_2022-Edition.tif new file mode 100644 index 00000000000..217d64e1eb6 Binary files /dev/null and b/tests/data/chesapeake/lulc/wv_lulc_2018_2022-Edition.tif differ diff --git a/tests/data/chesapeake/lulc/wv_lulc_2018_2022-Edition.zip b/tests/data/chesapeake/lulc/wv_lulc_2018_2022-Edition.zip new file mode 100644 index 00000000000..bb902000363 Binary files /dev/null and b/tests/data/chesapeake/lulc/wv_lulc_2018_2022-Edition.zip differ diff --git a/tests/data/cropharvest/data.py b/tests/data/cropharvest/data.py index 5bf85d21f84..58ad7824657 100755 --- a/tests/data/cropharvest/data.py +++ b/tests/data/cropharvest/data.py @@ -24,7 +24,7 @@ ] -def create_geojson(): +def create_geojson() -> dict[object, object]: geojson = { 'type': 'FeatureCollection', 'crs': {}, diff --git a/tests/data/cv4a_kenya_crop_type/FieldIds.csv b/tests/data/cv4a_kenya_crop_type/FieldIds.csv new file mode 100644 index 00000000000..04ff33b2500 --- /dev/null +++ b/tests/data/cv4a_kenya_crop_type/FieldIds.csv @@ -0,0 +1,5 @@ +train,test +1,2 +3,4 +5 +6 diff --git a/tests/data/cv4a_kenya_crop_type/data.py b/tests/data/cv4a_kenya_crop_type/data.py new file mode 100755 index 00000000000..e55ffa45b64 --- /dev/null +++ b/tests/data/cv4a_kenya_crop_type/data.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +from PIL import Image + +DTYPE = np.float32 +SIZE = 2 + +np.random.seed(0) + +all_bands = ( + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', + 'CLD', +) + +for tile in range(1): + directory = os.path.join('data', str(tile)) + os.makedirs(directory, exist_ok=True) + + arr = np.random.randint(np.iinfo(np.int32).max, size=(SIZE, SIZE), dtype=np.int32) + img = Image.fromarray(arr) + img.save(os.path.join(directory, f'{tile}_field_id.tif')) + + arr = np.random.randint(np.iinfo(np.uint8).max, size=(SIZE, SIZE), dtype=np.uint8) + img = Image.fromarray(arr) + img.save(os.path.join(directory, f'{tile}_label.tif')) + + for date in ['20190606']: + directory = os.path.join(directory, date) + os.makedirs(directory, exist_ok=True) + + for band in all_bands: + arr = np.random.rand(SIZE, SIZE).astype(DTYPE) * np.finfo(DTYPE).max + img = Image.fromarray(arr) + img.save(os.path.join(directory, f'{tile}_{band}_{date}.tif')) diff --git a/tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif b/tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif new file mode 100644 index 00000000000..f72a6772091 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/0_field_id.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/0_label.tif b/tests/data/cv4a_kenya_crop_type/data/0/0_label.tif new file mode 100644 index 00000000000..c0555ad107b Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/0_label.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B01_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B01_20190606.tif new file mode 100644 index 00000000000..1311d977f55 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B01_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B02_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B02_20190606.tif new file mode 100644 index 00000000000..ad41e11ea7b Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B02_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B03_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B03_20190606.tif new file mode 100644 index 00000000000..294e70e13f8 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B03_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B04_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B04_20190606.tif new file mode 100644 index 00000000000..704c8dfc23d Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B04_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B05_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B05_20190606.tif new file mode 100644 index 00000000000..a0aa5478a3a Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B05_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B06_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B06_20190606.tif new file mode 100644 index 00000000000..834e92f43b5 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B06_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B07_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B07_20190606.tif new file mode 100644 index 00000000000..58f58df0767 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B07_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B08_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B08_20190606.tif new file mode 100644 index 00000000000..f534bde3167 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B08_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B09_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B09_20190606.tif new file mode 100644 index 00000000000..b931b7189b0 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B09_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif new file mode 100644 index 00000000000..ea661cbc40e Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B11_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B12_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B12_20190606.tif new file mode 100644 index 00000000000..017b1714532 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B12_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif new file mode 100644 index 00000000000..1e3f7ce38b8 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_B8A_20190606.tif differ diff --git a/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_CLD_20190606.tif b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_CLD_20190606.tif new file mode 100644 index 00000000000..1ec85420866 Binary files /dev/null and b/tests/data/cv4a_kenya_crop_type/data/0/20190606/0_CLD_20190606.tif differ diff --git a/tests/data/cyclone/data.py b/tests/data/cyclone/data.py new file mode 100755 index 00000000000..2ea0f7a425a --- /dev/null +++ b/tests/data/cyclone/data.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +import pandas as pd +from PIL import Image + +DTYPE = np.uint8 +SIZE = 2 + +np.random.seed(0) + +for split in ['train', 'test']: + os.makedirs(split, exist_ok=True) + + filename = split + if split == 'train': + filename = 'training' + + features = pd.read_csv(f'{filename}_set_features.csv') + for image_id, _, _, ocean in features.values: + size = (SIZE, SIZE) + if ocean % 2 == 0: + size = (SIZE * 2, SIZE * 2, 3) + + arr = np.random.randint(np.iinfo(DTYPE).max, size=size, dtype=DTYPE) + img = Image.fromarray(arr) + img.save(os.path.join(split, f'{image_id}.jpg')) diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz deleted file mode 100644 index cbfa3779d9a..00000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json deleted file mode 100644 index a5692a66e5e..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "links": [ - { - "href": "nasa_tropical_storm_competition_test_labels_a_000/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_labels_b_001/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_labels_c_002/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_labels_d_003/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_labels_e_004/stac.json", - "rel": "item" - } - ] -} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz deleted file mode 100644 index 7a8162fafdf..00000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json deleted file mode 100644 index 97c44e9907a..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "links": [ - { - "href": "nasa_tropical_storm_competition_test_source_a_000/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_source_b_001/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_source_c_002/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_source_d_003/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_source_e_004/stac.json", - "rel": "item" - } - ] -} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json deleted file mode 100644 index 83438ddffa4..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "a", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json deleted file mode 100644 index 13f4a63afaa..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "b", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json deleted file mode 100644 index d8671e26416..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "c", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json deleted file mode 100644 index a6eebd660e0..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "d", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json deleted file mode 100644 index 90267dc6f1f..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "e", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz deleted file mode 100644 index 83f9138674e..00000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json deleted file mode 100644 index 834d293998a..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "links": [ - { - "href": "nasa_tropical_storm_competition_train_labels_a_000/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_labels_b_001/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_labels_c_002/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_labels_d_003/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_labels_e_004/stac.json", - "rel": "item" - } - ] -} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json deleted file mode 100644 index e59bae96dc9..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz deleted file mode 100644 index b3f019e97c7..00000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json deleted file mode 100644 index a03e0c77a19..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "links": [ - { - "href": "nasa_tropical_storm_competition_train_source_a_000/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_source_b_001/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_source_c_002/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_source_d_003/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_source_e_004/stac.json", - "rel": "item" - } - ] -} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json deleted file mode 100644 index 83438ddffa4..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "a", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg deleted file mode 100644 index 79c38f2a929..00000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json deleted file mode 100644 index 13f4a63afaa..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "b", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json deleted file mode 100644 index d8671e26416..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "c", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg deleted file mode 100644 index 79c38f2a929..00000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json deleted file mode 100644 index a6eebd660e0..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "d", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg deleted file mode 100644 index 79c38f2a929..00000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json deleted file mode 100644 index 90267dc6f1f..00000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "e", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg deleted file mode 100644 index 79c38f2a929..00000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg and /dev/null differ diff --git a/tests/data/cyclone/test/aaa_000.jpg b/tests/data/cyclone/test/aaa_000.jpg new file mode 100644 index 00000000000..f4d039da97c Binary files /dev/null and b/tests/data/cyclone/test/aaa_000.jpg differ diff --git a/tests/data/cyclone/test/bbb_111.jpg b/tests/data/cyclone/test/bbb_111.jpg new file mode 100644 index 00000000000..0d8e7a84a23 Binary files /dev/null and b/tests/data/cyclone/test/bbb_111.jpg differ diff --git a/tests/data/cyclone/test/ccc_222.jpg b/tests/data/cyclone/test/ccc_222.jpg new file mode 100644 index 00000000000..ebd3ba67c09 Binary files /dev/null and b/tests/data/cyclone/test/ccc_222.jpg differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg b/tests/data/cyclone/test/ddd_333.jpg similarity index 74% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg rename to tests/data/cyclone/test/ddd_333.jpg index 79c38f2a929..575d5a5c69f 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg and b/tests/data/cyclone/test/ddd_333.jpg differ diff --git a/tests/data/cyclone/test/eee_444.jpg b/tests/data/cyclone/test/eee_444.jpg new file mode 100644 index 00000000000..0cd10728e84 Binary files /dev/null and b/tests/data/cyclone/test/eee_444.jpg differ diff --git a/tests/data/cyclone/test_set_features.csv b/tests/data/cyclone/test_set_features.csv new file mode 100644 index 00000000000..dce291b0b5c --- /dev/null +++ b/tests/data/cyclone/test_set_features.csv @@ -0,0 +1,6 @@ +Image ID,Storm ID,Relative Time,Ocean +aaa_000,aaa,0,0 +bbb_111,bbb,1,1 +ccc_222,ccc,2,2 +ddd_333,ddd,3,3 +eee_444,eee,4,4 diff --git a/tests/data/cyclone/test_set_labels.csv b/tests/data/cyclone/test_set_labels.csv new file mode 100644 index 00000000000..8aa2d7c7f67 --- /dev/null +++ b/tests/data/cyclone/test_set_labels.csv @@ -0,0 +1,6 @@ +Image ID,Wind Speed +aaa_000,0 +bbb_111,1 +ccc_222,2 +ddd_333,3 +eee_444,4 diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg b/tests/data/cyclone/train/fff_555.jpg similarity index 73% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg rename to tests/data/cyclone/train/fff_555.jpg index 79c38f2a929..15225859b03 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg and b/tests/data/cyclone/train/fff_555.jpg differ diff --git a/tests/data/cyclone/train/ggg_666.jpg b/tests/data/cyclone/train/ggg_666.jpg new file mode 100644 index 00000000000..3065b52a80b Binary files /dev/null and b/tests/data/cyclone/train/ggg_666.jpg differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg b/tests/data/cyclone/train/hhh_777.jpg similarity index 75% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg rename to tests/data/cyclone/train/hhh_777.jpg index 79c38f2a929..877ac76c481 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg and b/tests/data/cyclone/train/hhh_777.jpg differ diff --git a/tests/data/cyclone/train/iii_888.jpg b/tests/data/cyclone/train/iii_888.jpg new file mode 100644 index 00000000000..731128b8a0c Binary files /dev/null and b/tests/data/cyclone/train/iii_888.jpg differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg b/tests/data/cyclone/train/jjj_999.jpg similarity index 75% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg rename to tests/data/cyclone/train/jjj_999.jpg index 79c38f2a929..8fda5ace924 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg and b/tests/data/cyclone/train/jjj_999.jpg differ diff --git a/tests/data/cyclone/training_set_features.csv b/tests/data/cyclone/training_set_features.csv new file mode 100644 index 00000000000..56df786e8ef --- /dev/null +++ b/tests/data/cyclone/training_set_features.csv @@ -0,0 +1,6 @@ +Image ID,Storm ID,Relative Time,Ocean +fff_555,fff,5,5 +ggg_666,ggg,6,6 +hhh_777,hhh,7,7 +iii_888,iii,8,8 +jjj_999,jjj,9,9 diff --git a/tests/data/cyclone/training_set_labels.csv b/tests/data/cyclone/training_set_labels.csv new file mode 100644 index 00000000000..5a8bbabce8c --- /dev/null +++ b/tests/data/cyclone/training_set_labels.csv @@ -0,0 +1,6 @@ +Image ID,Wind Speed +fff_555,5 +ggg_666,6 +hhh_777,7 +iii_888,8 +jjj_999,9 diff --git a/tests/data/dfc2022/data.py b/tests/data/dfc2022/data.py index 39d41f5d945..67323262452 100755 --- a/tests/data/dfc2022/data.py +++ b/tests/data/dfc2022/data.py @@ -19,36 +19,36 @@ train_set = [ { - 'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 - 'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', # noqa: E501 + 'image': 'labeled_train/Nantes_Saint-Nazaire/BDORTHO/44-2013-0295-6713-LA93-0M50-E080.tif', + 'dem': 'labeled_train/Nantes_Saint-Nazaire/RGEALTI/44-2013-0295-6713-LA93-0M50-E080_RGEALTI.tif', + 'target': 'labeled_train/Nantes_Saint-Nazaire/UrbanAtlas/44-2013-0295-6713-LA93-0M50-E080_UA2012.tif', }, { - 'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 - 'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', # noqa: E501 + 'image': 'labeled_train/Nice/BDORTHO/06-2014-1007-6318-LA93-0M50-E080.tif', + 'dem': 'labeled_train/Nice/RGEALTI/06-2014-1007-6318-LA93-0M50-E080_RGEALTI.tif', + 'target': 'labeled_train/Nice/UrbanAtlas/06-2014-1007-6318-LA93-0M50-E080_UA2012.tif', }, ] unlabeled_set = [ { - 'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'image': 'unlabeled_train/Calais_Dunkerque/BDORTHO/59-2012-0650-7077-LA93-0M50-E080.tif', + 'dem': 'unlabeled_train/Calais_Dunkerque/RGEALTI/59-2012-0650-7077-LA93-0M50-E080_RGEALTI.tif', }, { - 'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'image': 'unlabeled_train/LeMans/BDORTHO/72-2013-0469-6789-LA93-0M50-E080.tif', + 'dem': 'unlabeled_train/LeMans/RGEALTI/72-2013-0469-6789-LA93-0M50-E080_RGEALTI.tif', }, ] val_set = [ { - 'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'image': 'val/Marseille_Martigues/BDORTHO/13-2014-0900-6268-LA93-0M50-E080.tif', + 'dem': 'val/Marseille_Martigues/RGEALTI/13-2014-0900-6268-LA93-0M50-E080_RGEALTI.tif', }, { - 'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', # noqa: E501 - 'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', # noqa: E501 + 'image': 'val/Clermont-Ferrand/BDORTHO/63-2013-0711-6530-LA93-0M50-E080.tif', + 'dem': 'val/Clermont-Ferrand/RGEALTI/63-2013-0711-6530-LA93-0M50-E080_RGEALTI.tif', }, ] diff --git a/tests/data/digital_typhoon/WP.tar.gz b/tests/data/digital_typhoon/WP.tar.gz new file mode 100644 index 00000000000..3d707e3a5b5 Binary files /dev/null and b/tests/data/digital_typhoon/WP.tar.gz differ diff --git a/tests/data/digital_typhoon/WP.tar.gzaa b/tests/data/digital_typhoon/WP.tar.gzaa new file mode 100644 index 00000000000..3d707e3a5b5 Binary files /dev/null and b/tests/data/digital_typhoon/WP.tar.gzaa differ diff --git a/tests/data/digital_typhoon/WP.tar.gzab b/tests/data/digital_typhoon/WP.tar.gzab new file mode 100644 index 00000000000..3d707e3a5b5 Binary files /dev/null and b/tests/data/digital_typhoon/WP.tar.gzab differ diff --git a/tests/data/digital_typhoon/WP/aux_data.csv b/tests/data/digital_typhoon/WP/aux_data.csv new file mode 100644 index 00000000000..81864dd076c --- /dev/null +++ b/tests/data/digital_typhoon/WP/aux_data.csv @@ -0,0 +1,26 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +0,0.h5,1979,12,25,6,3,-55.81114066899345,76.6995939240727,973.8743108424701,44.98399850309952,66,71,75,137,25,95,1,1,1.h5,mask_40,89.87979469874404 +0,1.h5,1979,12,25,7,3,-33.621634184914114,-25.860702927919903,903.8203398162416,6.3832427352565,230,28,61,111,72,4,1,0,2.h5,mask_40,55.86768840838465 +0,2.h5,1979,12,25,8,3,72.02964248591297,-47.48138416430828,982.76724331446,0.027966770724696666,342,76,23,337,49,19,1,0,3.h5,mask_49,55.18449786430531 +0,3.h5,1979,12,25,9,2,55.920575184851316,13.989913225833078,906.0181106433341,51.01642134825744,330,90,52,258,44,65,1,1,4.h5,mask_86,15.969129252036707 +0,4.h5,1979,12,25,10,2,-43.28994147714503,-161.94483446959413,903.9366550400755,16.7093617045847,242,62,99,132,63,0,1,1,5.h5,mask_66,70.21971067939033 +1,0.h5,1988,1,22,10,2,-33.37129190053344,-115.29637290040873,948.0758912152131,51.11399505734963,118,15,67,232,63,86,1,1,1.h5,mask_15,30.245077213336646 +1,1.h5,1988,1,22,11,2,74.93228846926493,70.74999801636073,910.1992664115785,60.8348103266534,266,41,67,48,44,16,1,0,2.h5,mask_90,42.30390416164944 +1,2.h5,1988,1,22,12,2,-27.931601464223597,-141.3019006863473,961.5531323907394,18.35497901874176,19,61,24,295,50,26,1,1,3.h5,mask_67,60.35785307941444 +1,3.h5,1988,1,22,13,3,-27.166703710913154,-27.976214499674484,904.1165949703977,9.081723951290567,144,43,66,22,32,48,0,1,4.h5,mask_3,80.04417033291257 +1,4.h5,1988,1,22,14,2,47.51657289770864,-138.58539565379158,950.9654977977864,86.18819130981862,175,75,89,42,19,70,0,1,5.h5,mask_96,0.44001778199053154 +2,0.h5,1998,8,23,22,2,71.11037770397022,-170.05883586527145,902.757696015989,64.83605229043086,308,32,54,249,94,13,1,0,1.h5,mask_87,97.96789767456457 +2,1.h5,1998,8,23,23,2,-45.9880469141837,-153.85203885662787,956.1578736191437,95.77226625568278,230,17,58,214,72,21,1,0,2.h5,mask_66,48.1513473689529 +2,2.h5,1998,8,24,0,4,-88.778300647409,-78.43060469893915,958.764771469677,17.97662971655637,127,41,19,138,89,36,1,1,3.h5,mask_57,76.31799924098371 +2,3.h5,1998,8,24,1,2,-49.56689955810804,-120.3389762632577,986.4933451650326,49.259894810485605,333,90,28,51,45,99,1,0,4.h5,mask_92,65.60333971250041 +2,4.h5,1998,8,24,2,3,-52.55231579306487,80.06217230886841,997.4333837891787,48.25976623703225,63,7,13,71,55,58,1,1,5.h5,mask_73,50.634737551399034 +3,0.h5,1997,4,24,16,4,-61.81374526076493,60.62026564332362,900.1093638487514,94.66595722320622,189,70,67,249,12,58,0,1,1.h5,mask_93,99.77561346276104 +3,1.h5,1997,4,24,17,3,35.596382297289026,-117.20301531275722,925.1366339770796,34.46028512732848,55,55,74,11,0,49,1,1,2.h5,mask_11,5.726401727423658 +3,2.h5,1997,4,24,18,1,68.16880747309938,30.42194122117013,955.7265683876137,96.55057639044118,217,22,60,6,18,9,1,1,3.h5,mask_63,58.982331802755375 +3,3.h5,1997,4,24,19,3,-5.491619122910365,141.83240318855258,922.5486496962513,89.2199247408618,49,26,14,245,95,84,1,0,4.h5,mask_38,76.01607012923168 +3,4.h5,1997,4,24,20,4,4.052162855787202,21.732867986138842,990.5791999912764,98.40094253121877,158,86,11,28,11,81,0,0,5.h5,mask_12,75.84036894650622 +4,0.h5,1984,6,16,14,3,53.238650326925125,-54.63854263302531,934.2198641027621,18.697921579520305,212,16,42,91,90,56,1,1,1.h5,mask_72,78.93081269669048 +4,1.h5,1984,6,16,15,2,-56.222689844694024,-6.8726887962189664,912.6113238303491,61.286246561868666,60,81,2,198,64,76,1,0,2.h5,mask_64,24.039173626000288 +4,2.h5,1984,6,16,16,2,-4.285643464886363,95.66534210331434,962.0580147775602,86.01251389789185,281,81,5,228,18,94,0,0,3.h5,mask_66,89.89080488339964 +4,3.h5,1984,6,16,17,2,89.15893201203946,124.94143678744513,997.342814284227,84.00590505469005,242,28,61,132,80,29,0,0,4.h5,mask_77,4.839048143310343 +4,4.h5,1984,6,16,18,1,-46.31233638346047,21.77073986978661,932.8378121656477,26.18973887839292,294,76,57,252,99,27,1,0,5.h5,mask_65,89.74882055138497 diff --git a/tests/data/digital_typhoon/WP/image/0/0.h5 b/tests/data/digital_typhoon/WP/image/0/0.h5 new file mode 100644 index 00000000000..235ea1897f3 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/0/1.h5 b/tests/data/digital_typhoon/WP/image/0/1.h5 new file mode 100644 index 00000000000..98ece1b9351 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/0/2.h5 b/tests/data/digital_typhoon/WP/image/0/2.h5 new file mode 100644 index 00000000000..40cd7317d40 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/0/3.h5 b/tests/data/digital_typhoon/WP/image/0/3.h5 new file mode 100644 index 00000000000..6f2be498621 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/0/4.h5 b/tests/data/digital_typhoon/WP/image/0/4.h5 new file mode 100644 index 00000000000..731298cdd32 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/0/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/0.h5 b/tests/data/digital_typhoon/WP/image/1/0.h5 new file mode 100644 index 00000000000..d6009711570 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/1.h5 b/tests/data/digital_typhoon/WP/image/1/1.h5 new file mode 100644 index 00000000000..3f636ec3afc Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/2.h5 b/tests/data/digital_typhoon/WP/image/1/2.h5 new file mode 100644 index 00000000000..71acdc32c82 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/3.h5 b/tests/data/digital_typhoon/WP/image/1/3.h5 new file mode 100644 index 00000000000..65b76ff2f32 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/1/4.h5 b/tests/data/digital_typhoon/WP/image/1/4.h5 new file mode 100644 index 00000000000..df52fb412fd Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/1/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/0.h5 b/tests/data/digital_typhoon/WP/image/2/0.h5 new file mode 100644 index 00000000000..d391fab0d71 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/1.h5 b/tests/data/digital_typhoon/WP/image/2/1.h5 new file mode 100644 index 00000000000..7b80f60255b Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/2.h5 b/tests/data/digital_typhoon/WP/image/2/2.h5 new file mode 100644 index 00000000000..c108210a0e5 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/3.h5 b/tests/data/digital_typhoon/WP/image/2/3.h5 new file mode 100644 index 00000000000..2f1f14b9a51 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/2/4.h5 b/tests/data/digital_typhoon/WP/image/2/4.h5 new file mode 100644 index 00000000000..4e0fcb578fd Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/2/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/0.h5 b/tests/data/digital_typhoon/WP/image/3/0.h5 new file mode 100644 index 00000000000..d04cc4f79c0 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/1.h5 b/tests/data/digital_typhoon/WP/image/3/1.h5 new file mode 100644 index 00000000000..65ac943c680 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/2.h5 b/tests/data/digital_typhoon/WP/image/3/2.h5 new file mode 100644 index 00000000000..1ab8b197980 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/3.h5 b/tests/data/digital_typhoon/WP/image/3/3.h5 new file mode 100644 index 00000000000..9fcab04f7d1 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/3/4.h5 b/tests/data/digital_typhoon/WP/image/3/4.h5 new file mode 100644 index 00000000000..ccd6e248d55 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/3/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/0.h5 b/tests/data/digital_typhoon/WP/image/4/0.h5 new file mode 100644 index 00000000000..64fc6bd1b53 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/0.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/1.h5 b/tests/data/digital_typhoon/WP/image/4/1.h5 new file mode 100644 index 00000000000..1d66c74238f Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/1.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/2.h5 b/tests/data/digital_typhoon/WP/image/4/2.h5 new file mode 100644 index 00000000000..7353050bcd1 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/2.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/3.h5 b/tests/data/digital_typhoon/WP/image/4/3.h5 new file mode 100644 index 00000000000..f7185764d80 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/3.h5 differ diff --git a/tests/data/digital_typhoon/WP/image/4/4.h5 b/tests/data/digital_typhoon/WP/image/4/4.h5 new file mode 100644 index 00000000000..3fde7973bb3 Binary files /dev/null and b/tests/data/digital_typhoon/WP/image/4/4.h5 differ diff --git a/tests/data/digital_typhoon/WP/metadata/0.csv b/tests/data/digital_typhoon/WP/metadata/0.csv new file mode 100644 index 00000000000..df1c40f5fba --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/0.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +0,0.h5,1979,12,25,6,3,-55.81114066899345,76.6995939240727,973.8743108424701,44.98399850309952,66,71,75,137,25,95,1,1,1.h5,mask_40,89.87979469874404 +0,1.h5,1979,12,25,7,3,-33.621634184914114,-25.860702927919903,903.8203398162416,6.3832427352565,230,28,61,111,72,4,1,0,2.h5,mask_40,55.86768840838465 +0,2.h5,1979,12,25,8,3,72.02964248591297,-47.48138416430828,982.76724331446,0.027966770724696666,342,76,23,337,49,19,1,0,3.h5,mask_49,55.18449786430531 +0,3.h5,1979,12,25,9,2,55.920575184851316,13.989913225833078,906.0181106433341,51.01642134825744,330,90,52,258,44,65,1,1,4.h5,mask_86,15.969129252036707 +0,4.h5,1979,12,25,10,2,-43.28994147714503,-161.94483446959413,903.9366550400755,16.7093617045847,242,62,99,132,63,0,1,1,5.h5,mask_66,70.21971067939033 diff --git a/tests/data/digital_typhoon/WP/metadata/1.csv b/tests/data/digital_typhoon/WP/metadata/1.csv new file mode 100644 index 00000000000..3dd6e71bcc1 --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/1.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +1,0.h5,1988,1,22,10,2,-33.37129190053344,-115.29637290040873,948.0758912152131,51.11399505734963,118,15,67,232,63,86,1,1,1.h5,mask_15,30.245077213336646 +1,1.h5,1988,1,22,11,2,74.93228846926493,70.74999801636073,910.1992664115785,60.8348103266534,266,41,67,48,44,16,1,0,2.h5,mask_90,42.30390416164944 +1,2.h5,1988,1,22,12,2,-27.931601464223597,-141.3019006863473,961.5531323907394,18.35497901874176,19,61,24,295,50,26,1,1,3.h5,mask_67,60.35785307941444 +1,3.h5,1988,1,22,13,3,-27.166703710913154,-27.976214499674484,904.1165949703977,9.081723951290567,144,43,66,22,32,48,0,1,4.h5,mask_3,80.04417033291257 +1,4.h5,1988,1,22,14,2,47.51657289770864,-138.58539565379158,950.9654977977864,86.18819130981862,175,75,89,42,19,70,0,1,5.h5,mask_96,0.44001778199053154 diff --git a/tests/data/digital_typhoon/WP/metadata/2.csv b/tests/data/digital_typhoon/WP/metadata/2.csv new file mode 100644 index 00000000000..9f43c8edca2 --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/2.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +2,0.h5,1998,8,23,22,2,71.11037770397022,-170.05883586527145,902.757696015989,64.83605229043086,308,32,54,249,94,13,1,0,1.h5,mask_87,97.96789767456457 +2,1.h5,1998,8,23,23,2,-45.9880469141837,-153.85203885662787,956.1578736191437,95.77226625568278,230,17,58,214,72,21,1,0,2.h5,mask_66,48.1513473689529 +2,2.h5,1998,8,24,0,4,-88.778300647409,-78.43060469893915,958.764771469677,17.97662971655637,127,41,19,138,89,36,1,1,3.h5,mask_57,76.31799924098371 +2,3.h5,1998,8,24,1,2,-49.56689955810804,-120.3389762632577,986.4933451650326,49.259894810485605,333,90,28,51,45,99,1,0,4.h5,mask_92,65.60333971250041 +2,4.h5,1998,8,24,2,3,-52.55231579306487,80.06217230886841,997.4333837891787,48.25976623703225,63,7,13,71,55,58,1,1,5.h5,mask_73,50.634737551399034 diff --git a/tests/data/digital_typhoon/WP/metadata/3.csv b/tests/data/digital_typhoon/WP/metadata/3.csv new file mode 100644 index 00000000000..6144192b6ea --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/3.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +3,0.h5,1997,4,24,16,4,-61.81374526076493,60.62026564332362,900.1093638487514,94.66595722320622,189,70,67,249,12,58,0,1,1.h5,mask_93,99.77561346276104 +3,1.h5,1997,4,24,17,3,35.596382297289026,-117.20301531275722,925.1366339770796,34.46028512732848,55,55,74,11,0,49,1,1,2.h5,mask_11,5.726401727423658 +3,2.h5,1997,4,24,18,1,68.16880747309938,30.42194122117013,955.7265683876137,96.55057639044118,217,22,60,6,18,9,1,1,3.h5,mask_63,58.982331802755375 +3,3.h5,1997,4,24,19,3,-5.491619122910365,141.83240318855258,922.5486496962513,89.2199247408618,49,26,14,245,95,84,1,0,4.h5,mask_38,76.01607012923168 +3,4.h5,1997,4,24,20,4,4.052162855787202,21.732867986138842,990.5791999912764,98.40094253121877,158,86,11,28,11,81,0,0,5.h5,mask_12,75.84036894650622 diff --git a/tests/data/digital_typhoon/WP/metadata/4.csv b/tests/data/digital_typhoon/WP/metadata/4.csv new file mode 100644 index 00000000000..c2267e37fe9 --- /dev/null +++ b/tests/data/digital_typhoon/WP/metadata/4.csv @@ -0,0 +1,6 @@ +id,image_path,year,month,day,hour,grade,lat,lng,pressure,wind,dir50,long50,short50,dir30,long30,short30,landfall,intp,file_1,mask_1,mask_1_pct +4,0.h5,1984,6,16,14,3,53.238650326925125,-54.63854263302531,934.2198641027621,18.697921579520305,212,16,42,91,90,56,1,1,1.h5,mask_72,78.93081269669048 +4,1.h5,1984,6,16,15,2,-56.222689844694024,-6.8726887962189664,912.6113238303491,61.286246561868666,60,81,2,198,64,76,1,0,2.h5,mask_64,24.039173626000288 +4,2.h5,1984,6,16,16,2,-4.285643464886363,95.66534210331434,962.0580147775602,86.01251389789185,281,81,5,228,18,94,0,0,3.h5,mask_66,89.89080488339964 +4,3.h5,1984,6,16,17,2,89.15893201203946,124.94143678744513,997.342814284227,84.00590505469005,242,28,61,132,80,29,0,0,4.h5,mask_77,4.839048143310343 +4,4.h5,1984,6,16,18,1,-46.31233638346047,21.77073986978661,932.8378121656477,26.18973887839292,294,76,57,252,99,27,1,0,5.h5,mask_65,89.74882055138497 diff --git a/tests/data/digital_typhoon/data.py b/tests/data/digital_typhoon/data.py new file mode 100644 index 00000000000..bd5346e5d54 --- /dev/null +++ b/tests/data/digital_typhoon/data.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil + +import h5py +import numpy as np +import pandas as pd +from torchvision.datasets.utils import calculate_md5 + +# Define the root directory +root = 'WP' +IMAGE_SIZE = 32 +NUM_TYHOON_IDS = 5 +NUM_IMAGES_PER_ID = 5 +CHUNK_SIZE = 2**12 + +# If the root directory exists, remove it +if os.path.exists(root): + shutil.rmtree(root) + +# Create the 'image' and 'metadata' directories +os.makedirs(os.path.join(root, 'image')) +os.makedirs(os.path.join(root, 'metadata')) + +# For each typhoon_id +all_dfs = [] +for typhoon_id in range(NUM_TYHOON_IDS): + # Create a directory under 'root/image/typhoon_id/' + os.makedirs(os.path.join(root, 'image', str(typhoon_id)), exist_ok=True) + + # Create dummy .h5 files + image_paths_per_typhoon = [] + for image_id in range(NUM_IMAGES_PER_ID): + image_file_name = f'{image_id}.h5' + with h5py.File( + os.path.join(root, 'image', str(typhoon_id), image_file_name), 'w' + ) as hf: + hf.create_dataset('Infrared', data=np.random.rand(IMAGE_SIZE, IMAGE_SIZE)) + image_paths_per_typhoon.append(image_file_name) + + start_time = pd.Timestamp( + year=np.random.randint(1978, 2022), + month=np.random.randint(1, 13), + day=np.random.randint(1, 29), + hour=np.random.randint(0, 24), + ) + times = pd.date_range(start=start_time, periods=NUM_IMAGES_PER_ID, freq='H') + df = pd.DataFrame( + { + 'id': np.repeat(typhoon_id, NUM_IMAGES_PER_ID), + 'image_path': image_paths_per_typhoon, + 'year': times.year, + 'month': times.month, + 'day': times.day, + 'hour': times.hour, + 'grade': np.random.randint(1, 5, NUM_IMAGES_PER_ID), + 'lat': np.random.uniform(-90, 90, NUM_IMAGES_PER_ID), + 'lng': np.random.uniform(-180, 180, NUM_IMAGES_PER_ID), + 'pressure': np.random.uniform(900, 1000, NUM_IMAGES_PER_ID), + 'wind': np.random.uniform(0, 100, NUM_IMAGES_PER_ID), + 'dir50': np.random.randint(0, 360, NUM_IMAGES_PER_ID), + 'long50': np.random.randint(0, 100, NUM_IMAGES_PER_ID), + 'short50': np.random.randint(0, 100, NUM_IMAGES_PER_ID), + 'dir30': np.random.randint(0, 360, NUM_IMAGES_PER_ID), + 'long30': np.random.randint(0, 100, NUM_IMAGES_PER_ID), + 'short30': np.random.randint(0, 100, NUM_IMAGES_PER_ID), + 'landfall': np.random.randint(0, 2, NUM_IMAGES_PER_ID), + 'intp': np.random.randint(0, 2, NUM_IMAGES_PER_ID), + 'file_1': [f'{idx}.h5' for idx in range(1, NUM_IMAGES_PER_ID + 1)], + 'mask_1': [ + 'mask_' + str(i) for i in np.random.randint(1, 100, NUM_IMAGES_PER_ID) + ], + 'mask_1_pct': np.random.uniform(0, 100, NUM_IMAGES_PER_ID), + } + ) + + # Save the DataFrame to corresponding typhoon id as metadata + df.to_csv(os.path.join(root, 'metadata', f'{typhoon_id}.csv'), index=False) + + all_dfs.append(df) + +# Save the aux_data.csv +aux_data = pd.concat(all_dfs) +aux_data.to_csv(os.path.join(root, 'aux_data.csv'), index=False) + + +# Create tarball +shutil.make_archive(root, 'gztar', '.', root) + +# simulate multiple tar files +path = f'{root}.tar.gz' +paths = [] +with open(path, 'rb') as f: + # Write the entire tarball to gzaa + split = f'{path}aa' + with open(split, 'wb') as g: + g.write(f.read()) + paths.append(split) + +# Create gzab as a copy of gzaa +shutil.copy2(f'{path}aa', f'{path}ab') +paths.append(f'{path}ab') + + +# Calculate the md5sum of the tar file +for path in paths: + print(f'{path}: {calculate_md5(path)}') diff --git a/tests/data/eurocrops/data.py b/tests/data/eurocrops/data.py index 4128407a3e1..064ecb34052 100755 --- a/tests/data/eurocrops/data.py +++ b/tests/data/eurocrops/data.py @@ -18,7 +18,7 @@ SIZE = 1280 -def create_data_file(dataname): +def create_data_file(dataname: str) -> None: schema = {'geometry': 'Polygon', 'properties': {'EC_hcat_c': 'str'}} with fiona.open( dataname, 'w', crs=CRS.from_epsg(32616), driver='ESRI Shapefile', schema=schema @@ -33,7 +33,7 @@ def create_data_file(dataname): shpfile.write({'geometry': mapping(polygon), 'properties': properties}) -def create_csv(fname): +def create_csv(fname: str) -> None: with open(fname, 'w') as f: writer = csv.DictWriter(f, fieldnames=['HCAT2_code']) writer.writeheader() diff --git a/tests/data/eurosat/EuroSAT100.zip b/tests/data/eurosat/EuroSAT100.zip new file mode 100644 index 00000000000..ed2eb18d324 Binary files /dev/null and b/tests/data/eurosat/EuroSAT100.zip differ diff --git a/tests/data/eurosat/eurosat-100-test.txt b/tests/data/eurosat/eurosat-100-test.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-100-test.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-100-train.txt b/tests/data/eurosat/eurosat-100-train.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-100-train.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-100-val.txt b/tests/data/eurosat/eurosat-100-val.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-100-val.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-spatial-test.txt b/tests/data/eurosat/eurosat-spatial-test.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-spatial-test.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-spatial-train.txt b/tests/data/eurosat/eurosat-spatial-train.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-spatial-train.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/eurosat/eurosat-spatial-val.txt b/tests/data/eurosat/eurosat-spatial-val.txt new file mode 100644 index 00000000000..debeff4c852 --- /dev/null +++ b/tests/data/eurosat/eurosat-spatial-val.txt @@ -0,0 +1,2 @@ +AnnualCrop_1.tif +Forest_1.tif diff --git a/tests/data/ftw/austria.zip b/tests/data/ftw/austria.zip new file mode 100644 index 00000000000..e8b01db1b11 Binary files /dev/null and b/tests/data/ftw/austria.zip differ diff --git a/tests/data/ftw/data.py b/tests/data/ftw/data.py new file mode 100755 index 00000000000..8ffff19d6ad --- /dev/null +++ b/tests/data/ftw/data.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil +import zipfile + +import numpy as np +import pandas as pd +import rasterio +from affine import Affine + +np.random.seed(0) + +country = 'austria' +SIZE = 32 +num_samples = {'train': 2, 'val': 2, 'test': 2} +BASE_PROFILE = { + 'driver': 'GTiff', + 'dtype': 'uint16', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'count': 4, + 'crs': 'EPSG:4326', + 'transform': Affine(5.4e-05, 0.0, 0, 0.0, 5.4e-05, 0), + 'blockxsize': SIZE, + 'blockysize': SIZE, + 'tiled': True, + 'interleave': 'pixel', +} + + +def create_image(fn: str) -> None: + os.makedirs(os.path.dirname(fn), exist_ok=True) + + profile = BASE_PROFILE.copy() + + data = np.random.randint(0, 20000, size=(4, SIZE, SIZE), dtype=np.uint16) + with rasterio.open(fn, 'w', **profile) as dst: + dst.write(data) + + +def create_mask(fn: str, min_val: int, max_val: int) -> None: + os.makedirs(os.path.dirname(fn), exist_ok=True) + + profile = BASE_PROFILE.copy() + profile['dtype'] = 'uint8' + profile['nodata'] = 0 + profile['count'] = 1 + + data = np.random.randint(min_val, max_val, size=(1, SIZE, SIZE), dtype=np.uint8) + with rasterio.open(fn, 'w', **profile) as dst: + dst.write(data) + + +if __name__ == '__main__': + i = 0 + cols = {'aoi_id': [], 'split': []} + for split, n in num_samples.items(): + for j in range(n): + aoi = f'g_{i}' + cols['aoi_id'].append(aoi) + cols['split'].append(split) + + create_image(os.path.join(country, 's2_images', 'window_a', f'{aoi}.tif')) + create_image(os.path.join(country, 's2_images', 'window_b', f'{aoi}.tif')) + + create_mask( + os.path.join(country, 'label_masks', 'semantic_2class', f'{aoi}.tif'), + 0, + 1, + ) + create_mask( + os.path.join(country, 'label_masks', 'semantic_3class', f'{aoi}.tif'), + 0, + 2, + ) + create_mask( + os.path.join(country, 'label_masks', 'instance', f'{aoi}.tif'), 0, 100 + ) + + i += 1 + + # Create an extra train file to test for missing other files + aoi = f'g_{i}' + cols['aoi_id'].append(aoi) + cols['split'].append(split) + create_image(os.path.join(country, 's2_images', 'window_a', f'{aoi}.tif')) + + # Write parquet index + df = pd.DataFrame(cols) + df.to_parquet(os.path.join(country, f'chips_{country}.parquet')) + + # archive to zip + with zipfile.ZipFile(f'{country}.zip', 'w') as zipf: + for root, _, files in os.walk(country): + for file in files: + output_fn = os.path.join(root, file) + zipf.write(output_fn, os.path.relpath(output_fn, country)) + + shutil.rmtree(country) + + # Compute checksums + with open(f'{country}.zip', 'rb') as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f'{md5}') diff --git a/tests/data/geonrw/aachen/0_0_dem.tif b/tests/data/geonrw/aachen/0_0_dem.tif new file mode 100644 index 00000000000..2a58132c7ea Binary files /dev/null and b/tests/data/geonrw/aachen/0_0_dem.tif differ diff --git a/tests/data/geonrw/aachen/0_0_rgb.jp2 b/tests/data/geonrw/aachen/0_0_rgb.jp2 new file mode 100644 index 00000000000..57993cab5e7 Binary files /dev/null and b/tests/data/geonrw/aachen/0_0_rgb.jp2 differ diff --git a/tests/data/geonrw/aachen/0_0_seg.tif b/tests/data/geonrw/aachen/0_0_seg.tif new file mode 100644 index 00000000000..09c9e760a5d Binary files /dev/null and b/tests/data/geonrw/aachen/0_0_seg.tif differ diff --git a/tests/data/geonrw/aachen/1_1_dem.tif b/tests/data/geonrw/aachen/1_1_dem.tif new file mode 100644 index 00000000000..7c877668374 Binary files /dev/null and b/tests/data/geonrw/aachen/1_1_dem.tif differ diff --git a/tests/data/geonrw/aachen/1_1_rgb.jp2 b/tests/data/geonrw/aachen/1_1_rgb.jp2 new file mode 100644 index 00000000000..83376c763a0 Binary files /dev/null and b/tests/data/geonrw/aachen/1_1_rgb.jp2 differ diff --git a/tests/data/geonrw/aachen/1_1_seg.tif b/tests/data/geonrw/aachen/1_1_seg.tif new file mode 100644 index 00000000000..f215324aae5 Binary files /dev/null and b/tests/data/geonrw/aachen/1_1_seg.tif differ diff --git a/tests/data/geonrw/bergisch/0_0_dem.tif b/tests/data/geonrw/bergisch/0_0_dem.tif new file mode 100644 index 00000000000..bb394b53943 Binary files /dev/null and b/tests/data/geonrw/bergisch/0_0_dem.tif differ diff --git a/tests/data/geonrw/bergisch/0_0_rgb.jp2 b/tests/data/geonrw/bergisch/0_0_rgb.jp2 new file mode 100644 index 00000000000..eba6da32503 Binary files /dev/null and b/tests/data/geonrw/bergisch/0_0_rgb.jp2 differ diff --git a/tests/data/geonrw/bergisch/0_0_seg.tif b/tests/data/geonrw/bergisch/0_0_seg.tif new file mode 100644 index 00000000000..7a8ccea16a3 Binary files /dev/null and b/tests/data/geonrw/bergisch/0_0_seg.tif differ diff --git a/tests/data/geonrw/bergisch/1_1_dem.tif b/tests/data/geonrw/bergisch/1_1_dem.tif new file mode 100644 index 00000000000..8bb4a9c3985 Binary files /dev/null and b/tests/data/geonrw/bergisch/1_1_dem.tif differ diff --git a/tests/data/geonrw/bergisch/1_1_rgb.jp2 b/tests/data/geonrw/bergisch/1_1_rgb.jp2 new file mode 100644 index 00000000000..e9682e25d93 Binary files /dev/null and b/tests/data/geonrw/bergisch/1_1_rgb.jp2 differ diff --git a/tests/data/geonrw/bergisch/1_1_seg.tif b/tests/data/geonrw/bergisch/1_1_seg.tif new file mode 100644 index 00000000000..a06d56ca172 Binary files /dev/null and b/tests/data/geonrw/bergisch/1_1_seg.tif differ diff --git a/tests/data/geonrw/bielefeld/0_0_dem.tif b/tests/data/geonrw/bielefeld/0_0_dem.tif new file mode 100644 index 00000000000..3eb1be253ed Binary files /dev/null and b/tests/data/geonrw/bielefeld/0_0_dem.tif differ diff --git a/tests/data/geonrw/bielefeld/0_0_rgb.jp2 b/tests/data/geonrw/bielefeld/0_0_rgb.jp2 new file mode 100644 index 00000000000..3ff83e0b8c2 Binary files /dev/null and b/tests/data/geonrw/bielefeld/0_0_rgb.jp2 differ diff --git a/tests/data/geonrw/bielefeld/0_0_seg.tif b/tests/data/geonrw/bielefeld/0_0_seg.tif new file mode 100644 index 00000000000..77d81e79f7f Binary files /dev/null and b/tests/data/geonrw/bielefeld/0_0_seg.tif differ diff --git a/tests/data/geonrw/bielefeld/1_1_dem.tif b/tests/data/geonrw/bielefeld/1_1_dem.tif new file mode 100644 index 00000000000..7d4369753b2 Binary files /dev/null and b/tests/data/geonrw/bielefeld/1_1_dem.tif differ diff --git a/tests/data/geonrw/bielefeld/1_1_rgb.jp2 b/tests/data/geonrw/bielefeld/1_1_rgb.jp2 new file mode 100644 index 00000000000..9f045f99166 Binary files /dev/null and b/tests/data/geonrw/bielefeld/1_1_rgb.jp2 differ diff --git a/tests/data/geonrw/bielefeld/1_1_seg.tif b/tests/data/geonrw/bielefeld/1_1_seg.tif new file mode 100644 index 00000000000..3e8b5759632 Binary files /dev/null and b/tests/data/geonrw/bielefeld/1_1_seg.tif differ diff --git a/tests/data/geonrw/data.py b/tests/data/geonrw/data.py new file mode 100644 index 00000000000..8c9a06dcb3d --- /dev/null +++ b/tests/data/geonrw/data.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil +import tarfile + +import numpy as np +from PIL import Image + +# Constants +IMAGE_SIZE = (100, 100) +TRAIN_CITIES = ['aachen', 'bergisch', 'bielefeld'] +TEST_CITIES = ['duesseldorf'] +CLASSES = [ + 'background', + 'forest', + 'water', + 'agricultural', + 'residential,commercial,industrial', + 'grassland,swamp,shrubbery', + 'railway,trainstation', + 'highway,squares', + 'airport,shipyard', + 'roads', + 'buildings', +] +NUM_SAMPLES_PER_CITY = 2 + + +def create_directories(cities: list[str]) -> None: + for city in cities: + if os.path.exists(city): + shutil.rmtree(city) + os.makedirs(city, exist_ok=True) + + +def generate_dummy_data(cities: list[str]) -> None: + for city in cities: + for i in range(NUM_SAMPLES_PER_CITY): + utm_coords = f'{i}_{i}' + rgb_image = np.random.randint(0, 256, (*IMAGE_SIZE, 3), dtype=np.uint8) + dem_image = np.random.randint(0, 256, IMAGE_SIZE, dtype=np.uint8) + seg_image = np.random.randint(0, len(CLASSES), IMAGE_SIZE, dtype=np.uint8) + + Image.fromarray(rgb_image).save(os.path.join(city, f'{utm_coords}_rgb.jp2')) + Image.fromarray(dem_image).save(os.path.join(city, f'{utm_coords}_dem.tif')) + Image.fromarray(seg_image).save(os.path.join(city, f'{utm_coords}_seg.tif')) + + +def create_tarball(output_filename: str, source_dirs: list[str]) -> None: + with tarfile.open(output_filename, 'w:gz') as tar: + for source_dir in source_dirs: + tar.add(source_dir, arcname=os.path.basename(source_dir)) + + +def calculate_md5(filename: str) -> str: + hash_md5 = hashlib.md5() + with open(filename, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +# Main function +def main() -> None: + train_cities = TRAIN_CITIES + test_cities = TEST_CITIES + + create_directories(train_cities) + create_directories(test_cities) + + generate_dummy_data(train_cities) + generate_dummy_data(test_cities) + + tarball_name = 'nrw_dataset.tar.gz' + create_tarball(tarball_name, train_cities + test_cities) + + md5sum = calculate_md5(tarball_name) + print(f'MD5 checksum: {md5sum}') + + +if __name__ == '__main__': + main() diff --git a/tests/data/geonrw/duesseldorf/0_0_dem.tif b/tests/data/geonrw/duesseldorf/0_0_dem.tif new file mode 100644 index 00000000000..16b41e9946f Binary files /dev/null and b/tests/data/geonrw/duesseldorf/0_0_dem.tif differ diff --git a/tests/data/geonrw/duesseldorf/0_0_rgb.jp2 b/tests/data/geonrw/duesseldorf/0_0_rgb.jp2 new file mode 100644 index 00000000000..79d0b9a6569 Binary files /dev/null and b/tests/data/geonrw/duesseldorf/0_0_rgb.jp2 differ diff --git a/tests/data/geonrw/duesseldorf/0_0_seg.tif b/tests/data/geonrw/duesseldorf/0_0_seg.tif new file mode 100644 index 00000000000..79c5052fcfd Binary files /dev/null and b/tests/data/geonrw/duesseldorf/0_0_seg.tif differ diff --git a/tests/data/geonrw/duesseldorf/1_1_dem.tif b/tests/data/geonrw/duesseldorf/1_1_dem.tif new file mode 100644 index 00000000000..1dce40cc3c0 Binary files /dev/null and b/tests/data/geonrw/duesseldorf/1_1_dem.tif differ diff --git a/tests/data/geonrw/duesseldorf/1_1_rgb.jp2 b/tests/data/geonrw/duesseldorf/1_1_rgb.jp2 new file mode 100644 index 00000000000..e49ba9a0ced Binary files /dev/null and b/tests/data/geonrw/duesseldorf/1_1_rgb.jp2 differ diff --git a/tests/data/geonrw/duesseldorf/1_1_seg.tif b/tests/data/geonrw/duesseldorf/1_1_seg.tif new file mode 100644 index 00000000000..830e86fd017 Binary files /dev/null and b/tests/data/geonrw/duesseldorf/1_1_seg.tif differ diff --git a/tests/data/geonrw/nrw_dataset.tar.gz b/tests/data/geonrw/nrw_dataset.tar.gz new file mode 100644 index 00000000000..1014f5a447c Binary files /dev/null and b/tests/data/geonrw/nrw_dataset.tar.gz differ diff --git a/tests/data/hyspecnet/data.py b/tests/data/hyspecnet/data.py new file mode 100755 index 00000000000..3b4b701106e --- /dev/null +++ b/tests/data/hyspecnet/data.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil + +import numpy as np +import rasterio +from rasterio import Affine +from rasterio.crs import CRS + +SIZE = 32 +DTYPE = 'int16' + +np.random.seed(0) + +# Tile name purposefully shortened to avoid Windows git filename length limit. +tiles = ['ENMAP01_20221103T162438Z'] +patches = ['Y01460273_X05670694', 'Y01460273_X06950822'] + +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'nodata': -32768.0, + 'width': SIZE, + 'height': SIZE, + 'count': 224, + 'crs': CRS.from_epsg(32618), + 'transform': Affine(30.0, 0.0, 691845.0, 0.0, -30.0, 4561935.0), + 'blockysize': 3, + 'tiled': False, + 'compress': 'deflate', + 'interleave': 'band', +} + +root = 'hyspecnet-11k' +path = os.path.join(root, 'splits', 'easy') +os.makedirs(path, exist_ok=True) +for tile in tiles: + for patch in patches: + # Split CSV + path = os.path.join(tile, f'{tile}-{patch}', f'{tile}-{patch}-DATA.npy') + for split in ['train', 'val', 'test']: + with open(os.path.join(root, 'splits', 'easy', f'{split}.csv'), 'a+') as f: + f.write(f'{path}\n') + + # Spectral image + path = os.path.join(root, 'patches', path) + os.makedirs(os.path.dirname(path), exist_ok=True) + path = path.replace('DATA.npy', 'SPECTRAL_IMAGE.TIF') + Z = np.random.randint( + np.iinfo(DTYPE).min, np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE + ) + with rasterio.open(path, 'w', **profile) as src: + for i in range(1, profile['count'] + 1): + src.write(Z, i) + +shutil.make_archive(f'{root}-01', 'gztar', '.', os.path.join(root, 'patches')) +shutil.make_archive(f'{root}-splits', 'gztar', '.', os.path.join(root, 'splits')) diff --git a/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz b/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz new file mode 100644 index 00000000000..b5a5ec766a5 Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k-01.tar.gz differ diff --git a/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz b/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz new file mode 100644 index 00000000000..152f71c040f Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k-splits.tar.gz differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF new file mode 100644 index 00000000000..498bf304fa1 Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-SPECTRAL_IMAGE.TIF differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF new file mode 100644 index 00000000000..5142ff4fbcf Binary files /dev/null and b/tests/data/hyspecnet/hyspecnet-11k/patches/ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-SPECTRAL_IMAGE.TIF differ diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/test.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/train.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv new file mode 100644 index 00000000000..14393bce82a --- /dev/null +++ b/tests/data/hyspecnet/hyspecnet-11k/splits/easy/val.csv @@ -0,0 +1,2 @@ +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X05670694/ENMAP01_20221103T162438Z-Y01460273_X05670694-DATA.npy +ENMAP01_20221103T162438Z/ENMAP01_20221103T162438Z-Y01460273_X06950822/ENMAP01_20221103T162438Z-Y01460273_X06950822-DATA.npy diff --git a/tests/data/levircd/levircd/A/test_0.png b/tests/data/levircd/levircd/A/test_0.png new file mode 100644 index 00000000000..26d40b5926b Binary files /dev/null and b/tests/data/levircd/levircd/A/test_0.png differ diff --git a/tests/data/levircd/levircd/A/test_1.png b/tests/data/levircd/levircd/A/test_1.png new file mode 100644 index 00000000000..832b7d2982b Binary files /dev/null and b/tests/data/levircd/levircd/A/test_1.png differ diff --git a/tests/data/levircd/levircd/A/train_0.png b/tests/data/levircd/levircd/A/train_0.png new file mode 100644 index 00000000000..23b42403ac7 Binary files /dev/null and b/tests/data/levircd/levircd/A/train_0.png differ diff --git a/tests/data/levircd/levircd/A/train_1.png b/tests/data/levircd/levircd/A/train_1.png new file mode 100644 index 00000000000..7bbe1a0d54c Binary files /dev/null and b/tests/data/levircd/levircd/A/train_1.png differ diff --git a/tests/data/levircd/levircd/A/val_0.png b/tests/data/levircd/levircd/A/val_0.png new file mode 100644 index 00000000000..9447aae0cd3 Binary files /dev/null and b/tests/data/levircd/levircd/A/val_0.png differ diff --git a/tests/data/levircd/levircd/A/val_1.png b/tests/data/levircd/levircd/A/val_1.png new file mode 100644 index 00000000000..5e71e532c83 Binary files /dev/null and b/tests/data/levircd/levircd/A/val_1.png differ diff --git a/tests/data/levircd/levircd/B/test_0.png b/tests/data/levircd/levircd/B/test_0.png new file mode 100644 index 00000000000..5f76dd3ebfc Binary files /dev/null and b/tests/data/levircd/levircd/B/test_0.png differ diff --git a/tests/data/levircd/levircd/B/test_1.png b/tests/data/levircd/levircd/B/test_1.png new file mode 100644 index 00000000000..efba091309f Binary files /dev/null and b/tests/data/levircd/levircd/B/test_1.png differ diff --git a/tests/data/levircd/levircd/B/train_0.png b/tests/data/levircd/levircd/B/train_0.png new file mode 100644 index 00000000000..b9a80349ecd Binary files /dev/null and b/tests/data/levircd/levircd/B/train_0.png differ diff --git a/tests/data/levircd/levircd/B/train_1.png b/tests/data/levircd/levircd/B/train_1.png new file mode 100644 index 00000000000..ab2d146e891 Binary files /dev/null and b/tests/data/levircd/levircd/B/train_1.png differ diff --git a/tests/data/levircd/levircd/B/val_0.png b/tests/data/levircd/levircd/B/val_0.png new file mode 100644 index 00000000000..3cbec8e6276 Binary files /dev/null and b/tests/data/levircd/levircd/B/val_0.png differ diff --git a/tests/data/levircd/levircd/B/val_1.png b/tests/data/levircd/levircd/B/val_1.png new file mode 100644 index 00000000000..87701be980d Binary files /dev/null and b/tests/data/levircd/levircd/B/val_1.png differ diff --git a/tests/data/levircd/levircd/data.py b/tests/data/levircd/levircd/data.py index 41b56c903f6..173a0762d15 100644 --- a/tests/data/levircd/levircd/data.py +++ b/tests/data/levircd/levircd/data.py @@ -5,7 +5,6 @@ import hashlib import os -import shutil import zipfile import numpy as np @@ -32,8 +31,11 @@ def create_mask(path: str) -> None: directories = ['A', 'B', 'label'] for split, filename in zip(splits, filenames): + if os.path.exists(filename): + os.remove(filename) + for directory in directories: - os.mkdir(directory) + os.makedirs(directory, exist_ok=True) for i in range(2): path = os.path.join('A', f'{split}_{i}.png') @@ -51,9 +53,6 @@ def create_mask(path: str) -> None: for file in os.listdir(directory): f.write(os.path.join(directory, file)) - for directory in directories: - shutil.rmtree(directory) - # compute checksum with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() diff --git a/tests/data/levircd/levircd/label/test_0.png b/tests/data/levircd/levircd/label/test_0.png new file mode 100644 index 00000000000..fe0b9615c0d Binary files /dev/null and b/tests/data/levircd/levircd/label/test_0.png differ diff --git a/tests/data/levircd/levircd/label/test_1.png b/tests/data/levircd/levircd/label/test_1.png new file mode 100644 index 00000000000..22e587217c9 Binary files /dev/null and b/tests/data/levircd/levircd/label/test_1.png differ diff --git a/tests/data/levircd/levircd/label/train_0.png b/tests/data/levircd/levircd/label/train_0.png new file mode 100644 index 00000000000..94812c1971a Binary files /dev/null and b/tests/data/levircd/levircd/label/train_0.png differ diff --git a/tests/data/levircd/levircd/label/train_1.png b/tests/data/levircd/levircd/label/train_1.png new file mode 100644 index 00000000000..94812c1971a Binary files /dev/null and b/tests/data/levircd/levircd/label/train_1.png differ diff --git a/tests/data/levircd/levircd/label/val_0.png b/tests/data/levircd/levircd/label/val_0.png new file mode 100644 index 00000000000..94812c1971a Binary files /dev/null and b/tests/data/levircd/levircd/label/val_0.png differ diff --git a/tests/data/levircd/levircd/label/val_1.png b/tests/data/levircd/levircd/label/val_1.png new file mode 100644 index 00000000000..aea7f5ff8ad Binary files /dev/null and b/tests/data/levircd/levircd/label/val_1.png differ diff --git a/tests/data/levircd/levircd/test.zip b/tests/data/levircd/levircd/test.zip index 7629cd144a0..01d572f0322 100644 Binary files a/tests/data/levircd/levircd/test.zip and b/tests/data/levircd/levircd/test.zip differ diff --git a/tests/data/levircd/levircd/train.zip b/tests/data/levircd/levircd/train.zip index dee1f4057e6..fd99e32b1c0 100644 Binary files a/tests/data/levircd/levircd/train.zip and b/tests/data/levircd/levircd/train.zip differ diff --git a/tests/data/levircd/levircd/val.zip b/tests/data/levircd/levircd/val.zip index 00d77468f42..4481f35cb41 100644 Binary files a/tests/data/levircd/levircd/val.zip and b/tests/data/levircd/levircd/val.zip differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+.zip b/tests/data/levircd/levircdplus/LEVIR-CD+.zip index 0aba7587268..2e9641e1f2b 100644 Binary files a/tests/data/levircd/levircdplus/LEVIR-CD+.zip and b/tests/data/levircd/levircdplus/LEVIR-CD+.zip differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/test/A/00.png b/tests/data/levircd/levircdplus/LEVIR-CD+/test/A/00.png new file mode 100644 index 00000000000..9447aae0cd3 Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/test/A/00.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/test/A/01.png b/tests/data/levircd/levircdplus/LEVIR-CD+/test/A/01.png new file mode 100644 index 00000000000..5e71e532c83 Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/test/A/01.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/test/B/00.png b/tests/data/levircd/levircdplus/LEVIR-CD+/test/B/00.png new file mode 100644 index 00000000000..3cbec8e6276 Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/test/B/00.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/test/B/01.png b/tests/data/levircd/levircdplus/LEVIR-CD+/test/B/01.png new file mode 100644 index 00000000000..87701be980d Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/test/B/01.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/test/label/00.png b/tests/data/levircd/levircdplus/LEVIR-CD+/test/label/00.png new file mode 100644 index 00000000000..94812c1971a Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/test/label/00.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/test/label/01.png b/tests/data/levircd/levircdplus/LEVIR-CD+/test/label/01.png new file mode 100644 index 00000000000..aea7f5ff8ad Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/test/label/01.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/train/A/00.png b/tests/data/levircd/levircdplus/LEVIR-CD+/train/A/00.png new file mode 100644 index 00000000000..23b42403ac7 Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/train/A/00.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/train/A/01.png b/tests/data/levircd/levircdplus/LEVIR-CD+/train/A/01.png new file mode 100644 index 00000000000..7bbe1a0d54c Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/train/A/01.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/train/B/00.png b/tests/data/levircd/levircdplus/LEVIR-CD+/train/B/00.png new file mode 100644 index 00000000000..b9a80349ecd Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/train/B/00.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/train/B/01.png b/tests/data/levircd/levircdplus/LEVIR-CD+/train/B/01.png new file mode 100644 index 00000000000..ab2d146e891 Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/train/B/01.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/train/label/00.png b/tests/data/levircd/levircdplus/LEVIR-CD+/train/label/00.png new file mode 100644 index 00000000000..94812c1971a Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/train/label/00.png differ diff --git a/tests/data/levircd/levircdplus/LEVIR-CD+/train/label/01.png b/tests/data/levircd/levircdplus/LEVIR-CD+/train/label/01.png new file mode 100644 index 00000000000..94812c1971a Binary files /dev/null and b/tests/data/levircd/levircdplus/LEVIR-CD+/train/label/01.png differ diff --git a/tests/data/levircd/levircdplus/data.py b/tests/data/levircd/levircdplus/data.py index 5ea6296e91b..ad86c5f1af9 100644 --- a/tests/data/levircd/levircdplus/data.py +++ b/tests/data/levircd/levircdplus/data.py @@ -57,5 +57,3 @@ def create_mask(path: str) -> None: with open(f'{root}.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() print(f'{root}.zip: {md5}') - - shutil.rmtree(root) diff --git a/tests/data/mmearth/data.py b/tests/data/mmearth/data.py new file mode 100644 index 00000000000..45961fa1cfd --- /dev/null +++ b/tests/data/mmearth/data.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import json +import os +import shutil +from copy import deepcopy +from datetime import datetime, timedelta + +import h5py +import numpy as np + +meta_dummy_dict = { + 'S2_DATE': '2018-07-16', + 'S2_type': 'l1c', + 'CRS': 'EPSG:32721', + 'lat': -14.499441524746077, + 'lon': -56.98355999998649, +} + +num_tiles = 10 + +meta_id_strings = [str(i) for i in range(num_tiles)] + +modalities = { + 'aster': {'bands': 2, 'dtype': np.int16}, + 'biome': {'bands': 14, 'dtype': np.uint8}, + 'canopy_height_eth': {'bands': 2, 'dtype': np.int8}, + 'dynamic_world': {'bands': 1, 'dtype': np.uint8}, + 'eco_region': {'bands': 846, 'dtype': np.uint16}, + 'era5': {'bands': 12, 'dtype': np.float32}, + 'esa_worldcover': {'bands': 1, 'dtype': np.uint8}, + 'sentinel1': {'bands': 8, 'dtype': np.float32}, + 'sentinel2': {'bands': 13, 'dtype': np.uint16}, + 'sentinel2_cloudmask': {'bands': 1, 'dtype': np.uint16}, + 'sentinel2_cloudprod': {'bands': 1, 'dtype': np.uint16}, + 'sentinel2_scl': {'bands': 1, 'dtype': np.uint16}, +} + +all_modality_bands = { + 'sentinel2': [ + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8A', + 'B8', + 'B9', + 'B10', + 'B11', + 'B12', + ], + 'sentinel2_cloudmask': ['QA60'], + 'sentinel2_cloudprod': ['MSK_CLDPRB'], + 'sentinel2_scl': ['SCL'], + 'sentinel1_asc': ['VV', 'VH', 'HH', 'HV'], + 'sentinel1_desc': ['VV', 'VH', 'HH', 'HV'], + 'aster': ['b1', 'slope'], # elevation and slope + 'era5': [ + 'prev_temperature_2m', # previous month avg temp + 'prev_temperature_2m_min', # previous month min temp + 'prev_temperature_2m_max', # previous month max temp + 'prev_total_precipitation_sum', # previous month total precip + 'curr_temperature_2m', # current month avg temp + 'curr_temperature_2m_min', # current month min temp + 'curr_temperature_2m_max', # current month max temp + 'curr_total_precipitation_sum', # current month total precip + '0_temperature_2m_mean', # year avg temp + '1_temperature_2m_min_min', # year min temp + '2_temperature_2m_max_max', # year max temp + '3_total_precipitation_sum_sum', # year total precip + ], + 'dynamic_world': ['label'], + 'canopy_height_eth': ['height', 'std'], + 'lat': ['sin', 'cos'], + 'lon': ['sin', 'cos'], + 'biome': ['biome'], + 'eco_region': ['eco_region'], + 'month': ['sin_month', 'cos_month'], + 'esa_worldcover': ['Map'], +} + + +def create_hd5f(dataset_name: str, px_dim: tuple[int]) -> list[dict[str, str]]: + # Create the HDF5 file + with h5py.File(f'{dataset_name}.h5', 'w') as h5file: + # Create datasets for each modality + for modality, modal_info in modalities.items(): + bands = modal_info['bands'] + if modality in ['era5', 'eco_region', 'biome']: + h5file.create_dataset( + modality, (num_tiles, bands), dtype=modal_info['dtype'] + ) + else: + h5file.create_dataset( + modality, (num_tiles, bands, *px_dim), dtype=modal_info['dtype'] + ) + + # Create datasets for metadata + h5file.create_dataset('lat', (num_tiles, 2), dtype=np.float32) + h5file.create_dataset('lon', (num_tiles, 2), dtype=np.float32) + h5file.create_dataset('month', (num_tiles, 2), dtype=np.int32) + h5file.create_dataset( + 'metadata', + (num_tiles,), + dtype=np.dtype([('meta_id', 'S10'), ('S2_type', 'S3')]), + ) + + # Populate the datasets with sample data + tile_info = {} + for i in range(num_tiles): + for modality in modalities: + if modality == 'dynamic_world': + old_values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + data = np.random.choice(old_values, size=(bands, *px_dim)) + elif modality == 'esa_worldcover': + old_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100, 255] + data = np.random.choice(old_values, size=(bands, *px_dim)) + elif modality == 'era5': + # only vector not image data + data = np.random.random(size=(bands,)) + elif modality in ['biome', 'eco_region']: + data = np.random.randint(0, 2, size=(bands,)) + elif modality == 'sentinel2': + data = np.random.randint(0, 65535, size=(bands, *px_dim)) + elif modality in ['aster', 'canopy_height_eth', 'sentinel1']: + data = np.random.random(size=(bands, *px_dim)) + elif modality in [ + 'sentinel2_cloudmask', + 'sentinel2_cloudprod', + 'sentinel2_scl', + ]: + data = np.random.randint(0, 2, size=(bands, *px_dim)) + + data = data.astype(modal_info['dtype']) + h5file[modality][i] = data + + # add other data for lat, lon, month + h5file['lat'][i] = np.random.random(size=(2,)) + h5file['lon'][i] = np.random.random(size=(2,)) + h5file['month'][i] = np.random.random(size=(2,)) + + # Assign S2_type and store in metadata + S2_type = np.random.choice(['l1c', 'l2a']).encode('utf-8') + meta_id = str(i).encode('utf-8') + h5file['metadata'][i] = (meta_id, S2_type) + + # Collect tile info for JSON file + tile_meta = meta_dummy_dict.copy() + + tile_meta['S2_type'] = S2_type.decode('utf-8') + # in all_Modality_bands era5 contains the data instead `prev` and `curr` prefixes + date_str = tile_meta['S2_DATE'] + date_obj = datetime.strptime(date_str, '%Y-%m-%d') + curr_month_str = date_obj.strftime('%Y%m') + prev_month_obj = date_obj.replace(day=1) - timedelta(days=1) + prev_month_str = prev_month_obj.strftime('%Y%m') + curr_sample_bands = deepcopy(all_modality_bands) + curr_sample_bands['era5'] = [ + b.replace('curr', curr_month_str).replace('prev', prev_month_str) + for b in curr_sample_bands['era5'] + ] + tile_meta['BANDS'] = curr_sample_bands + tile_info[str(i)] = tile_meta + + return tile_info + + +extra_band_stats = { + 'sentinel2_l1c': {'bands': 13, 'dtype': np.uint16}, + 'sentinel2_l2a': {'bands': 13, 'dtype': np.uint16}, + 'lat': {'bands': 2, 'dtype': np.float32}, + 'lon': {'bands': 2, 'dtype': np.float32}, + 'month': {'bands': 2, 'dtype': np.float32}, +} + +band_modalities = { + k: v + for k, v in {**modalities, **extra_band_stats}.items() + if k not in {'biome', 'eco_region', 'dynamic_world', 'esa_worldcover'} +} + +# Create JSON files for band stats and splits +# sentinel 2 has l1c and l2a but there is only a common sentinel 2 data entry +band_stats = { + modality: { + 'mean': np.random.random(size=(mod_info['bands'])).tolist(), + 'std': np.random.random(size=(mod_info['bands'])).tolist(), + 'min': np.random.random(size=(mod_info['bands'])).tolist(), + 'max': np.random.random(size=(mod_info['bands'])).tolist(), + } + for modality, mod_info in band_modalities.items() +} + +train_split = num_tiles +val_split = 0 +test_split = 0 + +splits = { + 'train': list(range(train_split)), + 'val': list(range(train_split, train_split + val_split)), + 'test': list(range(train_split + val_split, num_tiles)), +} + +if __name__ == '__main__': + filenames = { + 'MMEarth': {'dirname': 'data_1M_v001', 'px_dim': (128, 128)}, + 'MMEarth64': {'dirname': 'data_1M_v001_64', 'px_dim': (64, 64)}, + 'MMEarth100k': {'dirname': 'data_100k_v001', 'px_dim': (128, 128)}, + } + for key, vals in filenames.items(): + dirname = vals['dirname'] + # remove existing files + if os.path.exists(dirname): + shutil.rmtree(dirname) + + # create directory + os.makedirs(dirname) + tile_info = create_hd5f(os.path.join(dirname, dirname), vals['px_dim']) + + print(f'{key} data file and JSON files created successfully.') + + with open(os.path.join(dirname, f'{dirname}_splits.json'), 'w') as f: + json.dump(splits, f, indent=4) + + with open(os.path.join(dirname, f'{dirname}_band_stats.json'), 'w') as f: + json.dump(band_stats, f, indent=4) + + with open(os.path.join(dirname, f'{dirname}_tile_info.json'), 'w') as f: + json.dump(tile_info, f, indent=4) diff --git a/tests/data/mmearth/data_100k_v001/data_100k_v001.h5 b/tests/data/mmearth/data_100k_v001/data_100k_v001.h5 new file mode 100644 index 00000000000..c485a0faa62 Binary files /dev/null and b/tests/data/mmearth/data_100k_v001/data_100k_v001.h5 differ diff --git a/tests/data/mmearth/data_100k_v001/data_100k_v001_band_stats.json b/tests/data/mmearth/data_100k_v001/data_100k_v001_band_stats.json new file mode 100644 index 00000000000..501667e4839 --- /dev/null +++ b/tests/data/mmearth/data_100k_v001/data_100k_v001_band_stats.json @@ -0,0 +1,420 @@ +{ + "aster": { + "mean": [ + 0.34133172608321716, + 0.3059512737624116 + ], + "std": [ + 0.3465348008910826, + 0.14108695274821736 + ], + "min": [ + 0.8418094294546998, + 0.4742174200974866 + ], + "max": [ + 0.56738806029585, + 0.0518313995381231 + ] + }, + "canopy_height_eth": { + "mean": [ + 0.854532719112457, + 0.48863801930320394 + ], + "std": [ + 0.5895142273813204, + 0.1380733622865845 + ], + "min": [ + 0.7537277848083938, + 0.20478855446904576 + ], + "max": [ + 0.5045161659636557, + 0.5376684828821884 + ] + }, + "era5": { + "mean": [ + 0.4417867806655783, + 0.18400642123926858, + 0.11974228279177279, + 0.9522889638018397, + 0.9273662674296557, + 0.8755178421266646, + 0.606034251540829, + 0.30760754028836534, + 0.6040509112467255, + 0.6765954694705612, + 0.6691595591399268, + 0.5760865666368172 + ], + "std": [ + 0.5142377087804115, + 0.2701723743576415, + 0.8413069700552763, + 0.23868021272203077, + 0.5615458693574323, + 0.7949644871571033, + 0.26212481323891657, + 0.7322482538861085, + 0.1995248437867745, + 0.42723767485667563, + 0.739198522837161, + 0.8092830064036739 + ], + "min": [ + 0.14533112908329815, + 0.23840001563382995, + 0.09261877533368601, + 0.10812791898965746, + 0.3602589294337053, + 0.41608271321516976, + 0.40824824209496946, + 0.4362332517942743, + 0.6458086696919946, + 0.2873520751891693, + 0.1946008373600201, + 0.3371402501790228 + ], + "max": [ + 0.9619147643696027, + 0.6002844111029695, + 0.34438509909726867, + 0.5211044855925113, + 0.249727288970654, + 0.07768059753391432, + 0.8934236930498343, + 0.8550867273916366, + 0.34905292318622505, + 0.07599362043189295, + 0.3695837636892234, + 0.8599690826993232 + ] + }, + "sentinel1": { + "mean": [ + 0.4602361303699314, + 0.9803602949980195, + 0.6286630558858189, + 0.8546244471280615, + 0.3908955820387353, + 0.15722620842791302, + 0.5954830179122328, + 0.8116450473795687 + ], + "std": [ + 0.03964016383304825, + 0.2701027934269321, + 0.3164522549613331, + 0.09860183113067111, + 0.1335076195305025, + 0.6380811967697871, + 0.5940489208142838, + 0.90153692977137 + ], + "min": [ + 0.44493594515658574, + 0.18478926184346423, + 0.2860240951390637, + 0.9376102612207217, + 0.9249907883844413, + 0.7000425768046851, + 0.3974535731475711, + 0.2996108322023431 + ], + "max": [ + 0.6430863691662376, + 0.9639089581632254, + 0.11634161184104996, + 0.753747780295231, + 0.4158525831196007, + 0.5988102320036879, + 0.10986853662090668, + 0.0600516168930747 + ] + }, + "sentinel2": { + "mean": [ + 0.572429320063415, + 0.15567923224572222, + 0.18809706032097528, + 0.8513440458791045, + 0.4678999223480048, + 0.050053414311246325, + 0.03783582407238084, + 0.2677522946476404, + 0.05453320208593193, + 0.5979956410404416, + 0.49602815159537084, + 0.988465511898549, + 0.6396682346061375 + ], + "std": [ + 0.788144262779709, + 0.8657320673010912, + 0.5279649775889855, + 0.3519159907818131, + 0.42634341564905587, + 0.7545521069496844, + 0.1962002041789851, + 0.7059625691340591, + 0.5931227904116899, + 0.9725044299059084, + 0.5405521502367713, + 0.2843034778768231, + 0.31920824614985277 + ], + "min": [ + 0.2720562009507226, + 0.5899353156966084, + 0.3934572906331085, + 0.44543431690993573, + 0.7278364898053944, + 0.02060665070965617, + 0.38574185899879954, + 0.6467951673496654, + 0.09562009477216771, + 0.7774338666717099, + 0.8432355577315033, + 0.4368636724686574, + 0.43488985400118574 + ], + "max": [ + 0.5900761314218557, + 0.36518105262763567, + 0.025620224680206638, + 0.5735969386962791, + 0.7634711203974548, + 0.1736244550922521, + 0.6024088499995152, + 0.9342662339896931, + 0.03710445086723202, + 0.1890352011946118, + 0.28380920040594426, + 0.08168516136465487, + 0.13526257707976375 + ] + }, + "sentinel2_cloudmask": { + "mean": [ + 0.6570709089318469 + ], + "std": [ + 0.5657620804780292 + ], + "min": [ + 0.9670225671155827 + ], + "max": [ + 0.5486983844030023 + ] + }, + "sentinel2_cloudprod": { + "mean": [ + 0.6891626967636988 + ], + "std": [ + 0.4094519969523073 + ], + "min": [ + 0.18725260491655094 + ], + "max": [ + 0.07180021957746674 + ] + }, + "sentinel2_scl": { + "mean": [ + 0.6780711668782042 + ], + "std": [ + 0.4943563461327216 + ], + "min": [ + 0.72302837101946 + ], + "max": [ + 0.28749332478382883 + ] + }, + "sentinel2_l1c": { + "mean": [ + 0.21099016187905117, + 0.5890058125196053, + 0.3870387069065061, + 0.40632422729999684, + 0.09220072185564243, + 0.05179158725809463, + 0.3472011267218935, + 0.27714371744503874, + 0.8667033333340239, + 0.42299347757834715, + 0.21100068056443366, + 0.9402893951577577, + 0.3890143754610127 + ], + "std": [ + 0.9129275727157, + 0.27695516423511546, + 0.6574105342764129, + 0.3857889836668025, + 0.4733288194932791, + 0.7763859293169395, + 0.969951792165023, + 0.7683755050895299, + 0.7736738677488465, + 0.6231553439174615, + 0.8681139667570541, + 0.693870549161861, + 0.07153957606497696 + ], + "min": [ + 0.7774302874038522, + 0.5237210940430268, + 0.48160697988637924, + 0.40412832766833284, + 0.49783101469118285, + 0.1676681532899118, + 0.8610056792509986, + 0.2652839446267331, + 0.9325651272132277, + 0.563023094265321, + 0.2869457262128843, + 0.6022487049661519, + 0.13539449396850844 + ], + "max": [ + 0.3394905584222998, + 0.6912694198479455, + 0.9365463758014783, + 0.026939601415270298, + 0.5290840296268874, + 0.38007307086114506, + 0.8005140940419264, + 0.7775367379319111, + 0.5736020267695333, + 0.9672861900139044, + 0.5859121986439549, + 0.8918748335743096, + 0.8098629367248834 + ] + }, + "sentinel2_l2a": { + "mean": [ + 0.4319213025299248, + 0.47014764209420445, + 0.10854844936417318, + 0.3565311102195149, + 0.035159148875477664, + 0.9947423748438694, + 0.6998282309520572, + 0.7089475988524567, + 0.6559450071993304, + 0.5583110883126653, + 0.9159743145429701, + 0.8343679900271499, + 0.7655093634482485 + ], + "std": [ + 0.12423175444317092, + 0.9912849566181509, + 0.3951297176601042, + 0.8104237474502085, + 0.7201051485011062, + 0.13586708888652077, + 0.7374687030638306, + 0.18741797127758675, + 0.16046499702755812, + 0.2749311810960794, + 0.13799794859023207, + 0.852581184239024, + 0.2925724204650476 + ], + "min": [ + 0.6710722460441257, + 0.6301086524595431, + 0.7368624384973665, + 0.13933868140865313, + 0.3705067827935764, + 0.7957615986693085, + 0.16723862032125847, + 0.20743892979117518, + 0.6662554693908289, + 0.9305180256466181, + 0.6165542799694995, + 0.9436576994737303, + 0.5915822101257956 + ], + "max": [ + 0.5530057895839687, + 0.12760001304721147, + 0.4562998709662902, + 0.04654611423548116, + 0.2540205560580904, + 0.15138539441364263, + 0.26367052218377185, + 0.6596795765749286, + 0.27285099411653047, + 0.47125521126252945, + 0.5939409956768125, + 0.2847412892997587, + 0.11228964358173976 + ] + }, + "lat": { + "mean": [ + 0.35663113563250803, + 0.9664439016211125 + ], + "std": [ + 0.5843606387551367, + 0.20984876015034148 + ], + "min": [ + 0.09962346810982947, + 0.8432152033355034 + ], + "max": [ + 0.16308304708635868, + 0.22022458984219218 + ] + }, + "lon": { + "mean": [ + 0.08778981307315648, + 0.5574407869891105 + ], + "std": [ + 0.6788585171009821, + 0.9327195921283604 + ], + "min": [ + 0.037812728275171015, + 0.7791613393176342 + ], + "max": [ + 0.45824364356139435, + 0.282148611369736 + ] + }, + "month": { + "mean": [ + 0.6768511662230008, + 0.020069115332411624 + ], + "std": [ + 0.40045110232925263, + 0.8656439167267811 + ], + "min": [ + 0.5073524073801928, + 0.0917181048136515 + ], + "max": [ + 0.6822690079185049, + 0.01508976602253198 + ] + } +} \ No newline at end of file diff --git a/tests/data/mmearth/data_100k_v001/data_100k_v001_splits.json b/tests/data/mmearth/data_100k_v001/data_100k_v001_splits.json new file mode 100644 index 00000000000..dffb6c43b69 --- /dev/null +++ b/tests/data/mmearth/data_100k_v001/data_100k_v001_splits.json @@ -0,0 +1,16 @@ +{ + "train": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9 + ], + "val": [], + "test": [] +} \ No newline at end of file diff --git a/tests/data/mmearth/data_100k_v001/data_100k_v001_tile_info.json b/tests/data/mmearth/data_100k_v001/data_100k_v001_tile_info.json new file mode 100644 index 00000000000..3700c94d789 --- /dev/null +++ b/tests/data/mmearth/data_100k_v001/data_100k_v001_tile_info.json @@ -0,0 +1,912 @@ +{ + "0": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "1": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "2": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "3": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "4": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "5": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "6": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "7": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "8": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "9": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + } +} \ No newline at end of file diff --git a/tests/data/mmearth/data_1M_v001/data_1M_v001.h5 b/tests/data/mmearth/data_1M_v001/data_1M_v001.h5 new file mode 100644 index 00000000000..1e37c8a005d Binary files /dev/null and b/tests/data/mmearth/data_1M_v001/data_1M_v001.h5 differ diff --git a/tests/data/mmearth/data_1M_v001/data_1M_v001_band_stats.json b/tests/data/mmearth/data_1M_v001/data_1M_v001_band_stats.json new file mode 100644 index 00000000000..501667e4839 --- /dev/null +++ b/tests/data/mmearth/data_1M_v001/data_1M_v001_band_stats.json @@ -0,0 +1,420 @@ +{ + "aster": { + "mean": [ + 0.34133172608321716, + 0.3059512737624116 + ], + "std": [ + 0.3465348008910826, + 0.14108695274821736 + ], + "min": [ + 0.8418094294546998, + 0.4742174200974866 + ], + "max": [ + 0.56738806029585, + 0.0518313995381231 + ] + }, + "canopy_height_eth": { + "mean": [ + 0.854532719112457, + 0.48863801930320394 + ], + "std": [ + 0.5895142273813204, + 0.1380733622865845 + ], + "min": [ + 0.7537277848083938, + 0.20478855446904576 + ], + "max": [ + 0.5045161659636557, + 0.5376684828821884 + ] + }, + "era5": { + "mean": [ + 0.4417867806655783, + 0.18400642123926858, + 0.11974228279177279, + 0.9522889638018397, + 0.9273662674296557, + 0.8755178421266646, + 0.606034251540829, + 0.30760754028836534, + 0.6040509112467255, + 0.6765954694705612, + 0.6691595591399268, + 0.5760865666368172 + ], + "std": [ + 0.5142377087804115, + 0.2701723743576415, + 0.8413069700552763, + 0.23868021272203077, + 0.5615458693574323, + 0.7949644871571033, + 0.26212481323891657, + 0.7322482538861085, + 0.1995248437867745, + 0.42723767485667563, + 0.739198522837161, + 0.8092830064036739 + ], + "min": [ + 0.14533112908329815, + 0.23840001563382995, + 0.09261877533368601, + 0.10812791898965746, + 0.3602589294337053, + 0.41608271321516976, + 0.40824824209496946, + 0.4362332517942743, + 0.6458086696919946, + 0.2873520751891693, + 0.1946008373600201, + 0.3371402501790228 + ], + "max": [ + 0.9619147643696027, + 0.6002844111029695, + 0.34438509909726867, + 0.5211044855925113, + 0.249727288970654, + 0.07768059753391432, + 0.8934236930498343, + 0.8550867273916366, + 0.34905292318622505, + 0.07599362043189295, + 0.3695837636892234, + 0.8599690826993232 + ] + }, + "sentinel1": { + "mean": [ + 0.4602361303699314, + 0.9803602949980195, + 0.6286630558858189, + 0.8546244471280615, + 0.3908955820387353, + 0.15722620842791302, + 0.5954830179122328, + 0.8116450473795687 + ], + "std": [ + 0.03964016383304825, + 0.2701027934269321, + 0.3164522549613331, + 0.09860183113067111, + 0.1335076195305025, + 0.6380811967697871, + 0.5940489208142838, + 0.90153692977137 + ], + "min": [ + 0.44493594515658574, + 0.18478926184346423, + 0.2860240951390637, + 0.9376102612207217, + 0.9249907883844413, + 0.7000425768046851, + 0.3974535731475711, + 0.2996108322023431 + ], + "max": [ + 0.6430863691662376, + 0.9639089581632254, + 0.11634161184104996, + 0.753747780295231, + 0.4158525831196007, + 0.5988102320036879, + 0.10986853662090668, + 0.0600516168930747 + ] + }, + "sentinel2": { + "mean": [ + 0.572429320063415, + 0.15567923224572222, + 0.18809706032097528, + 0.8513440458791045, + 0.4678999223480048, + 0.050053414311246325, + 0.03783582407238084, + 0.2677522946476404, + 0.05453320208593193, + 0.5979956410404416, + 0.49602815159537084, + 0.988465511898549, + 0.6396682346061375 + ], + "std": [ + 0.788144262779709, + 0.8657320673010912, + 0.5279649775889855, + 0.3519159907818131, + 0.42634341564905587, + 0.7545521069496844, + 0.1962002041789851, + 0.7059625691340591, + 0.5931227904116899, + 0.9725044299059084, + 0.5405521502367713, + 0.2843034778768231, + 0.31920824614985277 + ], + "min": [ + 0.2720562009507226, + 0.5899353156966084, + 0.3934572906331085, + 0.44543431690993573, + 0.7278364898053944, + 0.02060665070965617, + 0.38574185899879954, + 0.6467951673496654, + 0.09562009477216771, + 0.7774338666717099, + 0.8432355577315033, + 0.4368636724686574, + 0.43488985400118574 + ], + "max": [ + 0.5900761314218557, + 0.36518105262763567, + 0.025620224680206638, + 0.5735969386962791, + 0.7634711203974548, + 0.1736244550922521, + 0.6024088499995152, + 0.9342662339896931, + 0.03710445086723202, + 0.1890352011946118, + 0.28380920040594426, + 0.08168516136465487, + 0.13526257707976375 + ] + }, + "sentinel2_cloudmask": { + "mean": [ + 0.6570709089318469 + ], + "std": [ + 0.5657620804780292 + ], + "min": [ + 0.9670225671155827 + ], + "max": [ + 0.5486983844030023 + ] + }, + "sentinel2_cloudprod": { + "mean": [ + 0.6891626967636988 + ], + "std": [ + 0.4094519969523073 + ], + "min": [ + 0.18725260491655094 + ], + "max": [ + 0.07180021957746674 + ] + }, + "sentinel2_scl": { + "mean": [ + 0.6780711668782042 + ], + "std": [ + 0.4943563461327216 + ], + "min": [ + 0.72302837101946 + ], + "max": [ + 0.28749332478382883 + ] + }, + "sentinel2_l1c": { + "mean": [ + 0.21099016187905117, + 0.5890058125196053, + 0.3870387069065061, + 0.40632422729999684, + 0.09220072185564243, + 0.05179158725809463, + 0.3472011267218935, + 0.27714371744503874, + 0.8667033333340239, + 0.42299347757834715, + 0.21100068056443366, + 0.9402893951577577, + 0.3890143754610127 + ], + "std": [ + 0.9129275727157, + 0.27695516423511546, + 0.6574105342764129, + 0.3857889836668025, + 0.4733288194932791, + 0.7763859293169395, + 0.969951792165023, + 0.7683755050895299, + 0.7736738677488465, + 0.6231553439174615, + 0.8681139667570541, + 0.693870549161861, + 0.07153957606497696 + ], + "min": [ + 0.7774302874038522, + 0.5237210940430268, + 0.48160697988637924, + 0.40412832766833284, + 0.49783101469118285, + 0.1676681532899118, + 0.8610056792509986, + 0.2652839446267331, + 0.9325651272132277, + 0.563023094265321, + 0.2869457262128843, + 0.6022487049661519, + 0.13539449396850844 + ], + "max": [ + 0.3394905584222998, + 0.6912694198479455, + 0.9365463758014783, + 0.026939601415270298, + 0.5290840296268874, + 0.38007307086114506, + 0.8005140940419264, + 0.7775367379319111, + 0.5736020267695333, + 0.9672861900139044, + 0.5859121986439549, + 0.8918748335743096, + 0.8098629367248834 + ] + }, + "sentinel2_l2a": { + "mean": [ + 0.4319213025299248, + 0.47014764209420445, + 0.10854844936417318, + 0.3565311102195149, + 0.035159148875477664, + 0.9947423748438694, + 0.6998282309520572, + 0.7089475988524567, + 0.6559450071993304, + 0.5583110883126653, + 0.9159743145429701, + 0.8343679900271499, + 0.7655093634482485 + ], + "std": [ + 0.12423175444317092, + 0.9912849566181509, + 0.3951297176601042, + 0.8104237474502085, + 0.7201051485011062, + 0.13586708888652077, + 0.7374687030638306, + 0.18741797127758675, + 0.16046499702755812, + 0.2749311810960794, + 0.13799794859023207, + 0.852581184239024, + 0.2925724204650476 + ], + "min": [ + 0.6710722460441257, + 0.6301086524595431, + 0.7368624384973665, + 0.13933868140865313, + 0.3705067827935764, + 0.7957615986693085, + 0.16723862032125847, + 0.20743892979117518, + 0.6662554693908289, + 0.9305180256466181, + 0.6165542799694995, + 0.9436576994737303, + 0.5915822101257956 + ], + "max": [ + 0.5530057895839687, + 0.12760001304721147, + 0.4562998709662902, + 0.04654611423548116, + 0.2540205560580904, + 0.15138539441364263, + 0.26367052218377185, + 0.6596795765749286, + 0.27285099411653047, + 0.47125521126252945, + 0.5939409956768125, + 0.2847412892997587, + 0.11228964358173976 + ] + }, + "lat": { + "mean": [ + 0.35663113563250803, + 0.9664439016211125 + ], + "std": [ + 0.5843606387551367, + 0.20984876015034148 + ], + "min": [ + 0.09962346810982947, + 0.8432152033355034 + ], + "max": [ + 0.16308304708635868, + 0.22022458984219218 + ] + }, + "lon": { + "mean": [ + 0.08778981307315648, + 0.5574407869891105 + ], + "std": [ + 0.6788585171009821, + 0.9327195921283604 + ], + "min": [ + 0.037812728275171015, + 0.7791613393176342 + ], + "max": [ + 0.45824364356139435, + 0.282148611369736 + ] + }, + "month": { + "mean": [ + 0.6768511662230008, + 0.020069115332411624 + ], + "std": [ + 0.40045110232925263, + 0.8656439167267811 + ], + "min": [ + 0.5073524073801928, + 0.0917181048136515 + ], + "max": [ + 0.6822690079185049, + 0.01508976602253198 + ] + } +} \ No newline at end of file diff --git a/tests/data/mmearth/data_1M_v001/data_1M_v001_splits.json b/tests/data/mmearth/data_1M_v001/data_1M_v001_splits.json new file mode 100644 index 00000000000..dffb6c43b69 --- /dev/null +++ b/tests/data/mmearth/data_1M_v001/data_1M_v001_splits.json @@ -0,0 +1,16 @@ +{ + "train": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9 + ], + "val": [], + "test": [] +} \ No newline at end of file diff --git a/tests/data/mmearth/data_1M_v001/data_1M_v001_tile_info.json b/tests/data/mmearth/data_1M_v001/data_1M_v001_tile_info.json new file mode 100644 index 00000000000..f89fae376e3 --- /dev/null +++ b/tests/data/mmearth/data_1M_v001/data_1M_v001_tile_info.json @@ -0,0 +1,912 @@ +{ + "0": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "1": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "2": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "3": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "4": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "5": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "6": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "7": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "8": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "9": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + } +} \ No newline at end of file diff --git a/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64.h5 b/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64.h5 new file mode 100644 index 00000000000..1e4908001b2 Binary files /dev/null and b/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64.h5 differ diff --git a/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_band_stats.json b/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_band_stats.json new file mode 100644 index 00000000000..501667e4839 --- /dev/null +++ b/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_band_stats.json @@ -0,0 +1,420 @@ +{ + "aster": { + "mean": [ + 0.34133172608321716, + 0.3059512737624116 + ], + "std": [ + 0.3465348008910826, + 0.14108695274821736 + ], + "min": [ + 0.8418094294546998, + 0.4742174200974866 + ], + "max": [ + 0.56738806029585, + 0.0518313995381231 + ] + }, + "canopy_height_eth": { + "mean": [ + 0.854532719112457, + 0.48863801930320394 + ], + "std": [ + 0.5895142273813204, + 0.1380733622865845 + ], + "min": [ + 0.7537277848083938, + 0.20478855446904576 + ], + "max": [ + 0.5045161659636557, + 0.5376684828821884 + ] + }, + "era5": { + "mean": [ + 0.4417867806655783, + 0.18400642123926858, + 0.11974228279177279, + 0.9522889638018397, + 0.9273662674296557, + 0.8755178421266646, + 0.606034251540829, + 0.30760754028836534, + 0.6040509112467255, + 0.6765954694705612, + 0.6691595591399268, + 0.5760865666368172 + ], + "std": [ + 0.5142377087804115, + 0.2701723743576415, + 0.8413069700552763, + 0.23868021272203077, + 0.5615458693574323, + 0.7949644871571033, + 0.26212481323891657, + 0.7322482538861085, + 0.1995248437867745, + 0.42723767485667563, + 0.739198522837161, + 0.8092830064036739 + ], + "min": [ + 0.14533112908329815, + 0.23840001563382995, + 0.09261877533368601, + 0.10812791898965746, + 0.3602589294337053, + 0.41608271321516976, + 0.40824824209496946, + 0.4362332517942743, + 0.6458086696919946, + 0.2873520751891693, + 0.1946008373600201, + 0.3371402501790228 + ], + "max": [ + 0.9619147643696027, + 0.6002844111029695, + 0.34438509909726867, + 0.5211044855925113, + 0.249727288970654, + 0.07768059753391432, + 0.8934236930498343, + 0.8550867273916366, + 0.34905292318622505, + 0.07599362043189295, + 0.3695837636892234, + 0.8599690826993232 + ] + }, + "sentinel1": { + "mean": [ + 0.4602361303699314, + 0.9803602949980195, + 0.6286630558858189, + 0.8546244471280615, + 0.3908955820387353, + 0.15722620842791302, + 0.5954830179122328, + 0.8116450473795687 + ], + "std": [ + 0.03964016383304825, + 0.2701027934269321, + 0.3164522549613331, + 0.09860183113067111, + 0.1335076195305025, + 0.6380811967697871, + 0.5940489208142838, + 0.90153692977137 + ], + "min": [ + 0.44493594515658574, + 0.18478926184346423, + 0.2860240951390637, + 0.9376102612207217, + 0.9249907883844413, + 0.7000425768046851, + 0.3974535731475711, + 0.2996108322023431 + ], + "max": [ + 0.6430863691662376, + 0.9639089581632254, + 0.11634161184104996, + 0.753747780295231, + 0.4158525831196007, + 0.5988102320036879, + 0.10986853662090668, + 0.0600516168930747 + ] + }, + "sentinel2": { + "mean": [ + 0.572429320063415, + 0.15567923224572222, + 0.18809706032097528, + 0.8513440458791045, + 0.4678999223480048, + 0.050053414311246325, + 0.03783582407238084, + 0.2677522946476404, + 0.05453320208593193, + 0.5979956410404416, + 0.49602815159537084, + 0.988465511898549, + 0.6396682346061375 + ], + "std": [ + 0.788144262779709, + 0.8657320673010912, + 0.5279649775889855, + 0.3519159907818131, + 0.42634341564905587, + 0.7545521069496844, + 0.1962002041789851, + 0.7059625691340591, + 0.5931227904116899, + 0.9725044299059084, + 0.5405521502367713, + 0.2843034778768231, + 0.31920824614985277 + ], + "min": [ + 0.2720562009507226, + 0.5899353156966084, + 0.3934572906331085, + 0.44543431690993573, + 0.7278364898053944, + 0.02060665070965617, + 0.38574185899879954, + 0.6467951673496654, + 0.09562009477216771, + 0.7774338666717099, + 0.8432355577315033, + 0.4368636724686574, + 0.43488985400118574 + ], + "max": [ + 0.5900761314218557, + 0.36518105262763567, + 0.025620224680206638, + 0.5735969386962791, + 0.7634711203974548, + 0.1736244550922521, + 0.6024088499995152, + 0.9342662339896931, + 0.03710445086723202, + 0.1890352011946118, + 0.28380920040594426, + 0.08168516136465487, + 0.13526257707976375 + ] + }, + "sentinel2_cloudmask": { + "mean": [ + 0.6570709089318469 + ], + "std": [ + 0.5657620804780292 + ], + "min": [ + 0.9670225671155827 + ], + "max": [ + 0.5486983844030023 + ] + }, + "sentinel2_cloudprod": { + "mean": [ + 0.6891626967636988 + ], + "std": [ + 0.4094519969523073 + ], + "min": [ + 0.18725260491655094 + ], + "max": [ + 0.07180021957746674 + ] + }, + "sentinel2_scl": { + "mean": [ + 0.6780711668782042 + ], + "std": [ + 0.4943563461327216 + ], + "min": [ + 0.72302837101946 + ], + "max": [ + 0.28749332478382883 + ] + }, + "sentinel2_l1c": { + "mean": [ + 0.21099016187905117, + 0.5890058125196053, + 0.3870387069065061, + 0.40632422729999684, + 0.09220072185564243, + 0.05179158725809463, + 0.3472011267218935, + 0.27714371744503874, + 0.8667033333340239, + 0.42299347757834715, + 0.21100068056443366, + 0.9402893951577577, + 0.3890143754610127 + ], + "std": [ + 0.9129275727157, + 0.27695516423511546, + 0.6574105342764129, + 0.3857889836668025, + 0.4733288194932791, + 0.7763859293169395, + 0.969951792165023, + 0.7683755050895299, + 0.7736738677488465, + 0.6231553439174615, + 0.8681139667570541, + 0.693870549161861, + 0.07153957606497696 + ], + "min": [ + 0.7774302874038522, + 0.5237210940430268, + 0.48160697988637924, + 0.40412832766833284, + 0.49783101469118285, + 0.1676681532899118, + 0.8610056792509986, + 0.2652839446267331, + 0.9325651272132277, + 0.563023094265321, + 0.2869457262128843, + 0.6022487049661519, + 0.13539449396850844 + ], + "max": [ + 0.3394905584222998, + 0.6912694198479455, + 0.9365463758014783, + 0.026939601415270298, + 0.5290840296268874, + 0.38007307086114506, + 0.8005140940419264, + 0.7775367379319111, + 0.5736020267695333, + 0.9672861900139044, + 0.5859121986439549, + 0.8918748335743096, + 0.8098629367248834 + ] + }, + "sentinel2_l2a": { + "mean": [ + 0.4319213025299248, + 0.47014764209420445, + 0.10854844936417318, + 0.3565311102195149, + 0.035159148875477664, + 0.9947423748438694, + 0.6998282309520572, + 0.7089475988524567, + 0.6559450071993304, + 0.5583110883126653, + 0.9159743145429701, + 0.8343679900271499, + 0.7655093634482485 + ], + "std": [ + 0.12423175444317092, + 0.9912849566181509, + 0.3951297176601042, + 0.8104237474502085, + 0.7201051485011062, + 0.13586708888652077, + 0.7374687030638306, + 0.18741797127758675, + 0.16046499702755812, + 0.2749311810960794, + 0.13799794859023207, + 0.852581184239024, + 0.2925724204650476 + ], + "min": [ + 0.6710722460441257, + 0.6301086524595431, + 0.7368624384973665, + 0.13933868140865313, + 0.3705067827935764, + 0.7957615986693085, + 0.16723862032125847, + 0.20743892979117518, + 0.6662554693908289, + 0.9305180256466181, + 0.6165542799694995, + 0.9436576994737303, + 0.5915822101257956 + ], + "max": [ + 0.5530057895839687, + 0.12760001304721147, + 0.4562998709662902, + 0.04654611423548116, + 0.2540205560580904, + 0.15138539441364263, + 0.26367052218377185, + 0.6596795765749286, + 0.27285099411653047, + 0.47125521126252945, + 0.5939409956768125, + 0.2847412892997587, + 0.11228964358173976 + ] + }, + "lat": { + "mean": [ + 0.35663113563250803, + 0.9664439016211125 + ], + "std": [ + 0.5843606387551367, + 0.20984876015034148 + ], + "min": [ + 0.09962346810982947, + 0.8432152033355034 + ], + "max": [ + 0.16308304708635868, + 0.22022458984219218 + ] + }, + "lon": { + "mean": [ + 0.08778981307315648, + 0.5574407869891105 + ], + "std": [ + 0.6788585171009821, + 0.9327195921283604 + ], + "min": [ + 0.037812728275171015, + 0.7791613393176342 + ], + "max": [ + 0.45824364356139435, + 0.282148611369736 + ] + }, + "month": { + "mean": [ + 0.6768511662230008, + 0.020069115332411624 + ], + "std": [ + 0.40045110232925263, + 0.8656439167267811 + ], + "min": [ + 0.5073524073801928, + 0.0917181048136515 + ], + "max": [ + 0.6822690079185049, + 0.01508976602253198 + ] + } +} \ No newline at end of file diff --git a/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_splits.json b/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_splits.json new file mode 100644 index 00000000000..dffb6c43b69 --- /dev/null +++ b/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_splits.json @@ -0,0 +1,16 @@ +{ + "train": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9 + ], + "val": [], + "test": [] +} \ No newline at end of file diff --git a/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_tile_info.json b/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_tile_info.json new file mode 100644 index 00000000000..ad0abfb43c7 --- /dev/null +++ b/tests/data/mmearth/data_1M_v001_64/data_1M_v001_64_tile_info.json @@ -0,0 +1,912 @@ +{ + "0": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "1": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "2": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "3": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "4": { + "S2_DATE": "2018-07-16", + "S2_type": "l1c", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "5": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "6": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "7": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "8": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + }, + "9": { + "S2_DATE": "2018-07-16", + "S2_type": "l2a", + "CRS": "EPSG:32721", + "lat": -14.499441524746077, + "lon": -56.98355999998649, + "BANDS": { + "sentinel2": [ + "B1", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8A", + "B8", + "B9", + "B10", + "B11", + "B12" + ], + "sentinel2_cloudmask": [ + "QA60" + ], + "sentinel2_cloudprod": [ + "MSK_CLDPRB" + ], + "sentinel2_scl": [ + "SCL" + ], + "sentinel1_asc": [ + "VV", + "VH", + "HH", + "HV" + ], + "sentinel1_desc": [ + "VV", + "VH", + "HH", + "HV" + ], + "aster": [ + "b1", + "slope" + ], + "era5": [ + "201806_temperature_2m", + "201806_temperature_2m_min", + "201806_temperature_2m_max", + "201806_total_precipitation_sum", + "201807_temperature_2m", + "201807_temperature_2m_min", + "201807_temperature_2m_max", + "201807_total_precipitation_sum", + "0_temperature_2m_mean", + "1_temperature_2m_min_min", + "2_temperature_2m_max_max", + "3_total_precipitation_sum_sum" + ], + "dynamic_world": [ + "label" + ], + "canopy_height_eth": [ + "height", + "std" + ], + "lat": [ + "sin", + "cos" + ], + "lon": [ + "sin", + "cos" + ], + "biome": [ + "biome" + ], + "eco_region": [ + "eco_region" + ], + "month": [ + "sin_month", + "cos_month" + ], + "esa_worldcover": [ + "Map" + ] + } + } +} \ No newline at end of file diff --git a/tests/data/nasa_marine_debris/data.py b/tests/data/nasa_marine_debris/data.py new file mode 100755 index 00000000000..a782dea3d27 --- /dev/null +++ b/tests/data/nasa_marine_debris/data.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +import rasterio as rio +from rasterio import Affine +from rasterio.crs import CRS + +SIZE = 32 +DTYPE = np.uint8 + +np.random.seed(0) + +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'width': SIZE, + 'height': SIZE, + 'count': 3, + 'crs': CRS.from_epsg(4326), + 'transform': Affine( + 2.1457672119140625e-05, + 0.0, + -87.626953125, + 0.0, + -2.0629065249348766e-05, + 15.977172621632805, + ), +} + +os.makedirs('source', exist_ok=True) +os.makedirs('labels', exist_ok=True) + +files = [ + '20160928_153233_0e16_16816-29821-16', + '20160928_153233_0e16_16816-29824-16', + '20160928_153233_0e16_16816-29825-16', + '20160928_153233_0e16_16816-29828-16', + '20160928_153233_0e16_16816-29829-16', +] +for file in files: + with rio.open(os.path.join('source', f'{file}.tif'), 'w', **profile) as f: + for i in range(1, 4): + Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE) + f.write(Z, i) + + count = np.random.randint(5) + x = np.random.randint(SIZE, size=count) + y = np.random.randint(SIZE, size=count) + dx = np.random.randint(5, size=count) + dy = np.random.randint(5, size=count) + label = np.ones(count) + Z = np.stack([x, y, x + dx, y + dy, label], axis=-1) + np.save(os.path.join('labels', f'{file}.npy'), Z) diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29821-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29821-16.npy new file mode 100644 index 00000000000..104f61ed701 Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29821-16.npy differ diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29824-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29824-16.npy new file mode 100644 index 00000000000..4de2e99e4ff Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29824-16.npy differ diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29825-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29825-16.npy new file mode 100644 index 00000000000..e90c5b706e6 Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29825-16.npy differ diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29828-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29828-16.npy new file mode 100644 index 00000000000..99aa9231262 Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29828-16.npy differ diff --git a/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29829-16.npy b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29829-16.npy new file mode 100644 index 00000000000..577edba2c81 Binary files /dev/null and b/tests/data/nasa_marine_debris/labels/20160928_153233_0e16_16816-29829-16.npy differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels.tar.gz b/tests/data/nasa_marine_debris/nasa_marine_debris_labels.tar.gz deleted file mode 100644 index 3a4c8edbeeb..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels.tar.gz and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29821-16/pixel_bounds.npy b/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29821-16/pixel_bounds.npy deleted file mode 100755 index eeaaf46f294..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29821-16/pixel_bounds.npy and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29824-16/pixel_bounds.npy b/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29824-16/pixel_bounds.npy deleted file mode 100755 index eeaaf46f294..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29824-16/pixel_bounds.npy and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29825-16/pixel_bounds.npy b/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29825-16/pixel_bounds.npy deleted file mode 100755 index f559349de8c..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29825-16/pixel_bounds.npy and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29828-16/pixel_bounds.npy b/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29828-16/pixel_bounds.npy deleted file mode 100755 index eeaaf46f294..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_labels/nasa_marine_debris_labels_20160928_153233_0e16_16816-29828-16/pixel_bounds.npy and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source.tar.gz b/tests/data/nasa_marine_debris/nasa_marine_debris_source.tar.gz deleted file mode 100644 index b1a3b53d413..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source.tar.gz and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29821-16/image_geotiff.tif b/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29821-16/image_geotiff.tif deleted file mode 100644 index 471c657e5f6..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29821-16/image_geotiff.tif and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29824-16/image_geotiff.tif b/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29824-16/image_geotiff.tif deleted file mode 100644 index c472caabbee..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29824-16/image_geotiff.tif and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29825-16/image_geotiff.tif b/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29825-16/image_geotiff.tif deleted file mode 100644 index d6fd058b202..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29825-16/image_geotiff.tif and /dev/null differ diff --git a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29828-16/image_geotiff.tif b/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29828-16/image_geotiff.tif deleted file mode 100644 index 65ff7677ab6..00000000000 Binary files a/tests/data/nasa_marine_debris/nasa_marine_debris_source/nasa_marine_debris_source_20160928_153233_0e16_16816-29828-16/image_geotiff.tif and /dev/null differ diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29821-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29821-16.tif new file mode 100644 index 00000000000..66b8ae8e123 Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29821-16.tif differ diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29824-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29824-16.tif new file mode 100644 index 00000000000..2b1609165ae Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29824-16.tif differ diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29825-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29825-16.tif new file mode 100644 index 00000000000..1366468ec1c Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29825-16.tif differ diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29828-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29828-16.tif new file mode 100644 index 00000000000..f8d75a064ad Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29828-16.tif differ diff --git a/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29829-16.tif b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29829-16.tif new file mode 100644 index 00000000000..85c99a80c64 Binary files /dev/null and b/tests/data/nasa_marine_debris/source/20160928_153233_0e16_16816-29829-16.tif differ diff --git a/tests/data/nccm/data.py b/tests/data/nccm/data.py index 9dda733f4b3..f763439ef1b 100644 --- a/tests/data/nccm/data.py +++ b/tests/data/nccm/data.py @@ -17,7 +17,7 @@ files = ['CDL2017_clip.tif', 'CDL2018_clip1.tif', 'CDL2019_clip.tif'] -def create_file(path: str, dtype: str): +def create_file(path: str, dtype: str) -> None: """Create the testing file.""" profile = { 'driver': 'GTiff', diff --git a/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604/nlcd_2011_land_cover_l48_20210604.img b/tests/data/nlcd/Annual_NLCD_LndCov_2011_CU_C1V0.tif similarity index 100% rename from tests/data/nlcd/nlcd_2011_land_cover_l48_20210604/nlcd_2011_land_cover_l48_20210604.img rename to tests/data/nlcd/Annual_NLCD_LndCov_2011_CU_C1V0.tif diff --git a/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604/nlcd_2019_land_cover_l48_20210604.img b/tests/data/nlcd/Annual_NLCD_LndCov_2019_CU_C1V0.tif similarity index 100% rename from tests/data/nlcd/nlcd_2019_land_cover_l48_20210604/nlcd_2019_land_cover_l48_20210604.img rename to tests/data/nlcd/Annual_NLCD_LndCov_2019_CU_C1V0.tif diff --git a/tests/data/nlcd/data.py b/tests/data/nlcd/data.py index 072b6637500..e37d1189810 100755 --- a/tests/data/nlcd/data.py +++ b/tests/data/nlcd/data.py @@ -5,7 +5,6 @@ import hashlib import os -import shutil import numpy as np import rasterio @@ -16,8 +15,6 @@ np.random.seed(0) -dir = 'nlcd_{}_land_cover_l48_20210604' - years = [2011, 2019] wkt = """ @@ -43,7 +40,7 @@ """ -def create_file(path: str, dtype: str): +def create_file(path: str, dtype: str) -> None: """Create the testing file.""" profile = { 'driver': 'GTiff', @@ -67,21 +64,12 @@ def create_file(path: str, dtype: str): if __name__ == '__main__': for year in years: - year_dir = dir.format(year) - # Remove old data - if os.path.isdir(year_dir): - shutil.rmtree(year_dir) - - os.makedirs(os.path.join(os.getcwd(), year_dir)) - - zip_filename = year_dir + '.zip' - filename = year_dir + '.img' - create_file(os.path.join(year_dir, filename), dtype='int8') - - # Compress data - shutil.make_archive(year_dir, 'zip', '.', year_dir) + filename = os.path.join( + 'tests', 'data', 'nlcd', 'Annual_NLCD_LndCov_{}_CU_C1V0.tif' + ).format(year) + create_file(filename, dtype='int8') # Compute checksums - with open(zip_filename, 'rb') as f: + with open(filename, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f'{zip_filename}: {md5}') + print(f'{filename}: {md5}') diff --git a/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip b/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip deleted file mode 100644 index 6e5ad22538f..00000000000 Binary files a/tests/data/nlcd/nlcd_2011_land_cover_l48_20210604.zip and /dev/null differ diff --git a/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip b/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip deleted file mode 100644 index 0021cc86961..00000000000 Binary files a/tests/data/nlcd/nlcd_2019_land_cover_l48_20210604.zip and /dev/null differ diff --git a/tests/data/openbuildings/data.py b/tests/data/openbuildings/data.py index 8babe48758d..2113a431408 100755 --- a/tests/data/openbuildings/data.py +++ b/tests/data/openbuildings/data.py @@ -15,7 +15,7 @@ SIZE = 0.05 -def create_meta_data_file(zipfilename): +def create_meta_data_file(zipfilename: str) -> dict[object, object]: meta_data = { 'type': 'FeatureCollection', 'features': [ @@ -38,7 +38,7 @@ def create_meta_data_file(zipfilename): return meta_data -def create_csv_data_row(lat, long): +def create_csv_data_row(lat: float, long: float) -> dict[object, object]: width, height = SIZE / 10, SIZE / 10 minx = long - 0.5 * width maxx = long + 0.5 * width @@ -59,7 +59,7 @@ def create_csv_data_row(lat, long): return data_row -def create_buildings_data(): +def create_buildings_data() -> list[dict[object, object]]: fourth = SIZE / 4 # pandas df dict_data = [ diff --git a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz b/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz deleted file mode 100644 index 1c642bf9c73..00000000000 Binary files a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_labels.tar.gz and /dev/null differ diff --git a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz b/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz deleted file mode 100644 index f5e0e289137..00000000000 Binary files a/tests/data/ref_african_crops_kenya_02/ref_african_crops_kenya_02_source.tar.gz and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/data.py b/tests/data/ref_cloud_cover_detection_challenge_v1/data.py index e8a771e0fa5..1523af6a14e 100755 --- a/tests/data/ref_cloud_cover_detection_challenge_v1/data.py +++ b/tests/data/ref_cloud_cover_detection_challenge_v1/data.py @@ -1,275 +1,42 @@ +#!/usr/bin/env python3 + # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import os -from datetime import datetime as dt -from pathlib import Path import numpy as np import rasterio -from pystac import ( - Asset, - CatalogType, - Collection, - Extent, - Item, - Link, - MediaType, - SpatialExtent, - TemporalExtent, -) -from pystac.extensions.eo import Band, EOExtension -from pystac.extensions.label import ( - LabelClasses, - LabelCount, - LabelExtension, - LabelOverview, - LabelType, -) +from rasterio import Affine from rasterio.crs import CRS -from rasterio.transform import Affine - -np.random.seed(0) -SIZE = 512 -BANDS = ['B02', 'B03', 'B04', 'B08'] +SIZE = 2 +DTYPE = np.uint16 -SOURCE_COLLECTION_ID = 'ref_cloud_cover_detection_challenge_v1_test_source' -SOURCE_ITEM_ID = 'ref_cloud_cover_detection_challenge_v1_test_source_aaaa' -LABEL_COLLECTION_ID = 'ref_cloud_cover_detection_challenge_v1_test_labels' -LABEL_ITEM_ID = 'ref_cloud_cover_detection_challenge_v1_test_labels_aaaa' +np.random.seed(0) -# geometry used by both source and label items -TEST_GEOMETRY = { - 'type': 'Polygon', - 'coordinates': [ - [ - [137.86580132892396, -29.52744848758255], - [137.86450090473795, -29.481297003404038], - [137.91724642199793, -29.48015007212528], - [137.9185707094313, -29.526299409555623], - [137.86580132892396, -29.52744848758255], - ] - ], +splits = {'train': 'public', 'test': 'private'} +chip_ids = ['aaaa'] +all_bands = ['B02', 'B03', 'B04', 'B08'] +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(32753), + 'transform': Affine(10.0, 0.0, 777760.0, 0.0, -10.0, 6735270.0), } - -# bbox used by both source and label items -TEST_BBOX = [ - 137.86450090473795, - -29.52744848758255, - 137.9185707094313, - -29.48015007212528, -] - -# sentinel-2 bands for EO extension -S2_BANDS = [ - Band.create(name='B02', common_name='blue', description='Blue'), - Band.create(name='B03', common_name='green', description='Green'), - Band.create(name='B04', common_name='red', description='Red'), - Band.create(name='B08', common_name='nir', description='NIR'), -] - -# class map for overviews -CLASS_COUNT_MAP = {'0': 'no cloud', '1': 'cloud'} - -# define the spatial and temporal extent of collections -TEST_EXTENT = Extent( - spatial=SpatialExtent( - bboxes=[ - [ - -80.05464265420176, - -53.31380701212582, - 151.75593282192196, - 35.199126843018696, - ] - ] - ), - temporal=TemporalExtent( - intervals=[ - [ - dt.strptime('2018-02-18', '%Y-%m-%d'), - dt.strptime('2020-09-13', '%Y-%m-%d'), - ] - ] - ), -) - - -def create_raster(path: str, dtype: str, num_channels: int, collection: str) -> None: - if not os.path.exists(os.path.split(path)[0]): - Path(os.path.split(path)[0]).mkdir(parents=True) - - profile = {} - profile['driver'] = 'GTiff' - profile['dtype'] = dtype - profile['count'] = num_channels - profile['crs'] = CRS.from_epsg(32753) - profile['transform'] = Affine(1.0, 0.0, 777760.0, 0.0, -10.0, 6735270.0) - profile['height'] = SIZE - profile['width'] = SIZE - profile['compress'] = 'lzw' - profile['predictor'] = 2 - - if collection == 'source': - if 'float' in profile['dtype']: - Z = np.random.randn(SIZE, SIZE).astype(profile['dtype']) - else: - Z = np.random.randint( - np.iinfo(profile['dtype']).max, - size=(SIZE, SIZE), - dtype=profile['dtype'], - ) - elif collection == 'labels': - Z = np.random.randint(0, 2, (SIZE, SIZE)).astype(profile['dtype']) - - with rasterio.open(path, 'w', **profile) as src: - for i in range(1, profile['count'] + 1): - src.write(Z, i) - - -def create_source_item() -> Item: - # instantiate source Item - test_source_item = Item( - id=SOURCE_ITEM_ID, - geometry=TEST_GEOMETRY, - bbox=TEST_BBOX, - datetime=dt.strptime('2020-06-03', '%Y-%m-%d'), - properties={}, - ) - - # add Asset with EO Extension for each S2 band - for band in BANDS: - img_path = os.path.join( - os.getcwd(), SOURCE_COLLECTION_ID, SOURCE_ITEM_ID, f'{band}.tif' - ) - image_asset = Asset(href=img_path, media_type=MediaType.GEOTIFF) - eo_asset_ext = EOExtension.ext(image_asset) - - for s2_band in S2_BANDS: - if s2_band.name == band: - eo_asset_ext.apply(bands=[s2_band]) - test_source_item.add_asset(key=band, asset=image_asset) - - eo_image_ext = EOExtension.ext(test_source_item, add_if_missing=True) - eo_image_ext.apply(bands=S2_BANDS) - - return test_source_item - - -def get_class_label_list(overview: LabelOverview) -> LabelClasses: - label_list = [d['name'] for d in overview.properties['counts']] - label_classes = LabelClasses.create(classes=label_list, name='labels') - return label_classes - - -def get_item_class_overview(label_type: LabelType, asset_path: str) -> LabelOverview: - """Takes a path to an asset based on type and returns the class label - overview object - - Args: - label_type: LabelType - the type of label, either RASTER or VECTOR - asset_path: str - path to the asset to read in either a raster image or - geojson vector - - Returns: - overview: LabelOverview - the STAC LabelOverview object containing label classes - - """ - - count_list = [] - - img_arr = rasterio.open(asset_path).read() - value_count = np.unique(img_arr.flatten(), return_counts=True) - - for ix, classy in enumerate(value_count[0]): - if classy > 0: - label_count = LabelCount.create( - name=CLASS_COUNT_MAP[str(int(classy))], count=int(value_count[1][ix]) - ) - count_list.append(label_count) - - overview = LabelOverview(properties={}) - overview.apply(property_key='labels', counts=count_list) - - return overview - - -def create_label_item() -> Item: - # instantiate label Item - test_label_item = Item( - id=LABEL_ITEM_ID, - geometry=TEST_GEOMETRY, - bbox=TEST_BBOX, - datetime=dt.strptime('2020-06-03', '%Y-%m-%d'), - properties={}, - ) - - label_overview = get_item_class_overview(LabelType.RASTER, label_path) - label_list = get_class_label_list(label_overview) - - label_ext = LabelExtension.ext(test_label_item, add_if_missing=True) - label_ext.apply( - label_description='Sentinel-2 Cloud Cover Segmentation Test Labels', - label_type=LabelType.RASTER, - label_classes=[label_list], - label_overviews=[label_overview], - ) - - label_asset = Asset(href=label_path, media_type=MediaType.GEOTIFF) - test_label_item.add_asset(key='labels', asset=label_asset) - - return test_label_item - - -if __name__ == '__main__': - # create a geotiff for each s2 band - for b in BANDS: - tif_path = os.path.join( - os.getcwd(), SOURCE_COLLECTION_ID, SOURCE_ITEM_ID, f'{b}.tif' - ) - create_raster(tif_path, 'uint8', 1, 'source') - - # create a geotiff for label - label_path = os.path.join( - os.getcwd(), LABEL_COLLECTION_ID, LABEL_ITEM_ID, 'labels.tif' - ) - create_raster(label_path, 'uint8', 1, 'labels') - - # instantiate the source Collection - test_source_collection = Collection( - id=SOURCE_COLLECTION_ID, - description='Test Source Collection for Torchgo Cloud Cover Detection Dataset', - extent=TEST_EXTENT, - catalog_type=CatalogType.RELATIVE_PUBLISHED, - license='CC-BY-4.0', - ) - - source_item = create_source_item() - test_source_collection.add_item(source_item) - - test_source_collection.normalize_hrefs( - os.path.join(os.getcwd(), SOURCE_COLLECTION_ID) - ) - test_source_collection.make_all_asset_hrefs_relative() - test_source_collection.save(catalog_type=CatalogType.SELF_CONTAINED) - - # instantiate the label Collection - test_label_collection = Collection( - id=LABEL_COLLECTION_ID, - description='Test Label Collection for Torchgo Cloud Cover Detection Dataset', - extent=TEST_EXTENT, - catalog_type=CatalogType.RELATIVE_PUBLISHED, - license='CC-BY-4.0', - ) - - label_item = create_label_item() - label_item.add_link( - Link(rel='source', target=source_item, media_type=MediaType.GEOTIFF) - ) - test_label_collection.add_item(label_item) - - test_label_collection.normalize_hrefs( - os.path.join(os.getcwd(), LABEL_COLLECTION_ID) - ) - test_label_collection.make_all_asset_hrefs_relative() - test_label_collection.save(catalog_type=CatalogType.SELF_CONTAINED) +Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE) + +for split, directory in splits.items(): + for chip_id in chip_ids: + path = os.path.join(directory, f'{split}_features', chip_id) + os.makedirs(path, exist_ok=True) + for band in all_bands: + with rasterio.open(os.path.join(path, f'{band}.tif'), 'w', **profile) as f: + f.write(Z, 1) + path = os.path.join(directory, f'{split}_labels') + os.makedirs(path, exist_ok=True) + with rasterio.open(os.path.join(path, f'{chip_id}.tif'), 'w', **profile) as f: + f.write(Z, 1) diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B02.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B02.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B02.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B03.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B03.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B03.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B04.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B04.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B04.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B08.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B08.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_features/aaaa/B08.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_labels/aaaa.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_labels/aaaa.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_labels/aaaa.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_metadata.csv b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_metadata.csv new file mode 100644 index 00000000000..17c43ad2cef --- /dev/null +++ b/tests/data/ref_cloud_cover_detection_challenge_v1/private/test_metadata.csv @@ -0,0 +1,2 @@ +chip_id,location,datetime +aaaa,Australia - Central East,2024-06-013T00:00:00Z diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B02.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B02.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B02.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B03.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B03.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B03.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B04.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B04.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B04.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B08.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B08.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_features/aaaa/B08.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_labels/aaaa.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_labels/aaaa.tif new file mode 100644 index 00000000000..79ce7c0a3bf Binary files /dev/null and b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_labels/aaaa.tif differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_metadata.csv b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_metadata.csv new file mode 100644 index 00000000000..17c43ad2cef --- /dev/null +++ b/tests/data/ref_cloud_cover_detection_challenge_v1/public/train_metadata.csv @@ -0,0 +1,2 @@ +chip_id,location,datetime +aaaa,Australia - Central East,2024-06-013T00:00:00Z diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz deleted file mode 100644 index 8aa7da7a185..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/collection.json b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/collection.json deleted file mode 100644 index 3f50cb6b5db..00000000000 --- a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/collection.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "type": "Collection", - "id": "ref_cloud_cover_detection_challenge_v1_test_labels", - "stac_version": "1.0.0", - "description": "Test Label Collection for Torchgo Cloud Cover Detection Dataset", - "links": [ - { - "rel": "root", - "href": "./collection.json", - "type": "application/json" - }, - { - "rel": "item", - "href": "./ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa.json", - "type": "application/json" - } - ], - "stac_extensions": [], - "extent": { - "spatial": { - "bbox": [ - [ - -80.05464265420176, - -53.31380701212582, - 151.75593282192196, - 35.199126843018696 - ] - ] - }, - "temporal": { - "interval": [ - [ - "2018-02-18T00:00:00Z", - "2020-09-13T00:00:00Z" - ] - ] - } - }, - "license": "CC-BY-4.0" -} \ No newline at end of file diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/labels.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/labels.tif deleted file mode 100644 index 0181ba4f573..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/labels.tif and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa.json b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa.json deleted file mode 100644 index 7633d8b46d1..00000000000 --- a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_labels/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa/ref_cloud_cover_detection_challenge_v1_test_labels_aaaa.json +++ /dev/null @@ -1,95 +0,0 @@ -{ - "type": "Feature", - "stac_version": "1.0.0", - "id": "ref_cloud_cover_detection_challenge_v1_test_labels_aaaa", - "properties": { - "label:description": "Sentinel-2 Cloud Cover Segmentation Test Labels", - "label:type": "raster", - "label:properties": null, - "label:classes": [ - { - "classes": [ - "cloud" - ], - "name": "labels" - } - ], - "label:overviews": [ - { - "property_key": "labels", - "counts": [ - { - "name": "cloud", - "count": 130696 - } - ] - } - ], - "datetime": "2020-06-03T00:00:00Z" - }, - "geometry": { - "type": "Polygon", - "coordinates": [ - [ - [ - 137.86580132892396, - -29.52744848758255 - ], - [ - 137.86450090473795, - -29.481297003404038 - ], - [ - 137.91724642199793, - -29.48015007212528 - ], - [ - 137.9185707094313, - -29.526299409555623 - ], - [ - 137.86580132892396, - -29.52744848758255 - ] - ] - ] - }, - "links": [ - { - "rel": "source", - "href": "../../ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json", - "type": "image/tiff; application=geotiff" - }, - { - "rel": "root", - "href": "../collection.json", - "type": "application/json" - }, - { - "rel": "collection", - "href": "../collection.json", - "type": "application/json" - }, - { - "rel": "parent", - "href": "../collection.json", - "type": "application/json" - } - ], - "assets": { - "labels": { - "href": "./labels.tif", - "type": "image/tiff; application=geotiff" - } - }, - "bbox": [ - 137.86450090473795, - -29.52744848758255, - 137.9185707094313, - -29.48015007212528 - ], - "stac_extensions": [ - "https://stac-extensions.github.io/label/v1.0.1/schema.json" - ], - "collection": "ref_cloud_cover_detection_challenge_v1_test_labels" -} \ No newline at end of file diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source.tar.gz b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source.tar.gz deleted file mode 100644 index b65bd0849b5..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source.tar.gz and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/collection.json b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/collection.json deleted file mode 100644 index 3f19bed1e2b..00000000000 --- a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/collection.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "type": "Collection", - "id": "ref_cloud_cover_detection_challenge_v1_test_source", - "stac_version": "1.0.0", - "description": "Test Source Collection for Torchgo Cloud Cover Detection Dataset", - "links": [ - { - "rel": "root", - "href": "./collection.json", - "type": "application/json" - }, - { - "rel": "item", - "href": "./ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json", - "type": "application/json" - } - ], - "stac_extensions": [], - "extent": { - "spatial": { - "bbox": [ - [ - -80.05464265420176, - -53.31380701212582, - 151.75593282192196, - 35.199126843018696 - ] - ] - }, - "temporal": { - "interval": [ - [ - "2018-02-18T00:00:00Z", - "2020-09-13T00:00:00Z" - ] - ] - } - }, - "license": "CC-BY-4.0" -} \ No newline at end of file diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B02.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B02.tif deleted file mode 100644 index 5f23bc2b562..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B02.tif and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B03.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B03.tif deleted file mode 100644 index f143ae2c3fb..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B03.tif and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B04.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B04.tif deleted file mode 100644 index b1d91415d52..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B04.tif and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B08.tif b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B08.tif deleted file mode 100644 index 111b1d7af26..00000000000 Binary files a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/B08.tif and /dev/null differ diff --git a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json b/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json deleted file mode 100644 index aebe445f34a..00000000000 --- a/tests/data/ref_cloud_cover_detection_challenge_v1/ref_cloud_cover_detection_challenge_v1_test_source/ref_cloud_cover_detection_challenge_v1_test_source_aaaa/ref_cloud_cover_detection_challenge_v1_test_source_aaaa.json +++ /dev/null @@ -1,130 +0,0 @@ -{ - "type": "Feature", - "stac_version": "1.0.0", - "id": "ref_cloud_cover_detection_challenge_v1_test_source_aaaa", - "properties": { - "eo:bands": [ - { - "name": "B02", - "common_name": "blue", - "description": "Blue" - }, - { - "name": "B03", - "common_name": "green", - "description": "Green" - }, - { - "name": "B04", - "common_name": "red", - "description": "Red" - }, - { - "name": "B08", - "common_name": "nir", - "description": "NIR" - } - ], - "datetime": "2020-06-03T00:00:00Z" - }, - "geometry": { - "type": "Polygon", - "coordinates": [ - [ - [ - 137.86580132892396, - -29.52744848758255 - ], - [ - 137.86450090473795, - -29.481297003404038 - ], - [ - 137.91724642199793, - -29.48015007212528 - ], - [ - 137.9185707094313, - -29.526299409555623 - ], - [ - 137.86580132892396, - -29.52744848758255 - ] - ] - ] - }, - "links": [ - { - "rel": "root", - "href": "../collection.json", - "type": "application/json" - }, - { - "rel": "collection", - "href": "../collection.json", - "type": "application/json" - }, - { - "rel": "parent", - "href": "../collection.json", - "type": "application/json" - } - ], - "assets": { - "B02": { - "href": "./B02.tif", - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B02", - "common_name": "blue", - "description": "Blue" - } - ] - }, - "B03": { - "href": "./B03.tif", - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B03", - "common_name": "green", - "description": "Green" - } - ] - }, - "B04": { - "href": "./B04.tif", - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B04", - "common_name": "red", - "description": "Red" - } - ] - }, - "B08": { - "href": "./B08.tif", - "type": "image/tiff; application=geotiff", - "eo:bands": [ - { - "name": "B08", - "common_name": "nir", - "description": "NIR" - } - ] - } - }, - "bbox": [ - 137.86450090473795, - -29.52744848758255, - 137.9185707094313, - -29.48015007212528 - ], - "stac_extensions": [ - "https://stac-extensions.github.io/eo/v1.0.0/schema.json" - ], - "collection": "ref_cloud_cover_detection_challenge_v1_test_source" -} \ No newline at end of file diff --git a/tests/data/resisc45/NWPU-RESISC45.rar b/tests/data/resisc45/NWPU-RESISC45.rar deleted file mode 100644 index 246ebc01075..00000000000 Binary files a/tests/data/resisc45/NWPU-RESISC45.rar and /dev/null differ diff --git a/tests/data/resisc45/NWPU-RESISC45.zip b/tests/data/resisc45/NWPU-RESISC45.zip new file mode 100644 index 00000000000..e000bc2e690 Binary files /dev/null and b/tests/data/resisc45/NWPU-RESISC45.zip differ diff --git a/tests/data/rwanda_field_boundary/data.py b/tests/data/rwanda_field_boundary/data.py old mode 100644 new mode 100755 index a3522e8c962..bf9954e8935 --- a/tests/data/rwanda_field_boundary/data.py +++ b/tests/data/rwanda_field_boundary/data.py @@ -3,99 +3,46 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import hashlib import os -import shutil import numpy as np import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') all_bands = ('B01', 'B02', 'B03', 'B04') SIZE = 32 -NUM_SAMPLES = 5 +DTYPE = np.uint16 +NUM_SAMPLES = 1 np.random.seed(0) - -def create_mask(fn: str) -> None: - profile = { - 'driver': 'GTiff', - 'dtype': 'uint8', - 'nodata': 0.0, - 'width': SIZE, - 'height': SIZE, - 'count': 1, - 'crs': 'epsg:3857', - 'compress': 'lzw', - 'predictor': 2, - 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), - 'blockysize': 32, - 'tiled': False, - 'interleave': 'band', - } - with rasterio.open(fn, 'w', **profile) as f: - f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint8), 1) - - -def create_img(fn: str) -> None: - profile = { - 'driver': 'GTiff', - 'dtype': 'uint16', - 'nodata': 0.0, - 'width': SIZE, - 'height': SIZE, - 'count': 1, - 'crs': 'epsg:3857', - 'compress': 'lzw', - 'predictor': 2, - 'blockysize': 16, - 'transform': rasterio.Affine(10.0, 0.0, 0.0, 0.0, -10.0, 0.0), - 'tiled': False, - 'interleave': 'band', - } - with rasterio.open(fn, 'w', **profile) as f: - f.write(np.random.randint(0, 2, size=(SIZE, SIZE), dtype=np.uint16), 1) - - -if __name__ == '__main__': - # Train and test images - for split in ('train', 'test'): - for i in range(NUM_SAMPLES): - for date in dates: - directory = os.path.join( - f'nasa_rwanda_field_boundary_competition_source_{split}', - f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501 - ) - os.makedirs(directory, exist_ok=True) - for band in all_bands: - create_img(os.path.join(directory, f'{band}.tif')) - - # Create collections.json, this isn't used by the dataset but is checked to - # exist - with open( - f'nasa_rwanda_field_boundary_competition_source_{split}/collections.json', - 'w', - ) as f: - f.write('Not used') - - # Train labels - for i in range(NUM_SAMPLES): - directory = os.path.join( - 'nasa_rwanda_field_boundary_competition_labels_train', - f'nasa_rwanda_field_boundary_competition_labels_train_{i:02d}', - ) - os.makedirs(directory, exist_ok=True) - create_mask(os.path.join(directory, 'raster_labels.tif')) - - # Create directories and compute checksums - for filename in [ - 'nasa_rwanda_field_boundary_competition_source_train', - 'nasa_rwanda_field_boundary_competition_source_test', - 'nasa_rwanda_field_boundary_competition_labels_train', - ]: - shutil.make_archive(filename, 'gztar', '.', filename) - # Compute checksums - with open(f'{filename}.tar.gz', 'rb') as f: - md5 = hashlib.md5(f.read()).hexdigest() - print(f'{filename}: {md5}') +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(3857), + 'transform': Affine( + 4.77731426716, 0.0, 3374518.037700199, 0.0, -4.77731426716, -168438.54642526805 + ), +} +Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE) + +for sample in range(NUM_SAMPLES): + for split in ['train', 'test']: + for date in dates: + path = os.path.join('source', split, date) + os.makedirs(path, exist_ok=True) + for band in all_bands: + file = os.path.join(path, f'{sample:02}_{band}.tif') + with rasterio.open(file, 'w', **profile) as src: + src.write(Z, 1) + + path = os.path.join('labels', 'train') + os.makedirs(path, exist_ok=True) + file = os.path.join(path, f'{sample:02}.tif') + with rasterio.open(file, 'w', **profile) as src: + src.write(Z, 1) diff --git a/tests/data/rwanda_field_boundary/labels/train/00.tif b/tests/data/rwanda_field_boundary/labels/train/00.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/labels/train/00.tif differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz deleted file mode 100644 index ffa98bb53d6..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_labels_train.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz deleted file mode 100644 index a834f66bf38..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_test.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz b/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz deleted file mode 100644 index 8239f70c200..00000000000 Binary files a/tests/data/rwanda_field_boundary/nasa_rwanda_field_boundary_competition_source_train.tar.gz and /dev/null differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_03/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_04/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_08/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_10/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_11/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/test/2021_12/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_03/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_04/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_08/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_10/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_11/00_B04.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B01.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B02.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B03.tif differ diff --git a/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif new file mode 100644 index 00000000000..bd39a26d5e3 Binary files /dev/null and b/tests/data/rwanda_field_boundary/source/train/2021_12/00_B04.tif differ diff --git a/tests/data/satlas/data.py b/tests/data/satlas/data.py new file mode 100755 index 00000000000..1661f6e425f --- /dev/null +++ b/tests/data/satlas/data.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import json +import os +import shutil + +from PIL import Image + +SIZE = 32 +landsat_size = { + 'b1': SIZE // 2, + 'b2': SIZE // 2, + 'b3': SIZE // 2, + 'b4': SIZE // 2, + 'b5': SIZE // 2, + 'b6': SIZE // 2, + 'b7': SIZE // 2, + 'b8': SIZE, + 'b9': SIZE // 2, + 'b10': SIZE // 2, + 'b11': SIZE // 4, + 'b12': SIZE // 4, +} + +index = [[7149, 3246], [1234, 5678]] +good_images = [ + [7149, 3246, '2022-03'], + [1234, 5678, '2022-03'], + [7149, 3246, 'm_3808245_se_17_1_20110801'], + [1234, 5678, 'm_3808245_se_17_1_20110801'], + [7149, 3246, '2022-01'], + [1234, 5678, '2022-01'], + [7149, 3246, 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235'], + [1234, 5678, 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235'], +] +times = { + '2022-03': '2022-03-01T00:00:00+00:00', + 'm_3808245_se_17_1_20110801': '2011-08-01T12:00:00+00:00', + '2022-01': '2022-01-01T00:00:00+00:00', + 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235': '2022-03-09T06:02:35+00:00', +} + +FILENAME_HIERARCHY = dict[str, 'FILENAME_HIERARCHY'] | list[str] +filenames: FILENAME_HIERARCHY = { + 'landsat': {'2022-03': list(f'b{i}' for i in range(1, 12))}, + 'naip': {'m_3808245_se_17_1_20110801': ['tci', 'ir']}, + 'sentinel1': {'2022-01': ['vh', 'vv']}, + 'sentinel2': { + 'S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235': [ + 'tci', + 'b05', + 'b06', + 'b07', + 'b08', + 'b11', + 'b12', + ] + }, +} + + +def create_files(path: str) -> None: + os.makedirs(path, exist_ok=True) + for col, row in index: + band = os.path.basename(path) + mode = 'RGB' if band == 'tci' else 'L' + size = SIZE + if 'landsat' in path: + size = landsat_size[band] + img = Image.new(mode, (size, size)) + img.save(os.path.join(path, f'{col}_{row}.png')) + + +def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None: + if isinstance(hierarchy, dict): + # Recursive case + for key, value in hierarchy.items(): + path = os.path.join(directory, key) + create_directory(path, value) + else: + # Base case + for value in hierarchy: + path = os.path.join(directory, value) + create_files(path) + + +if __name__ == '__main__': + create_directory('.', filenames) + + col, row = index[0] + path = os.path.join('static', f'{col}_{row}') + os.makedirs(path, exist_ok=True) + img = Image.new('L', (SIZE, SIZE)) + img.save(os.path.join(path, 'land_cover.png')) + + os.makedirs('metadata', exist_ok=True) + with open(os.path.join('metadata', 'train_lowres.json'), 'w') as f: + json.dump(index, f) + + with open(os.path.join('metadata', 'good_images_lowres_all.json'), 'w') as f: + json.dump(good_images, f) + + with open(os.path.join('metadata', 'image_times.json'), 'w') as f: + json.dump(times, f) + + for path in os.listdir('.'): + if os.path.isdir(path): + shutil.make_archive(path, 'tar', '.', path) diff --git a/tests/data/satlas/landsat.tar b/tests/data/satlas/landsat.tar new file mode 100644 index 00000000000..f21ba5980d5 Binary files /dev/null and b/tests/data/satlas/landsat.tar differ diff --git a/tests/data/satlas/landsat/2022-03/b1/1234_5678.png b/tests/data/satlas/landsat/2022-03/b1/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b1/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b1/7149_3246.png b/tests/data/satlas/landsat/2022-03/b1/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b1/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b10/1234_5678.png b/tests/data/satlas/landsat/2022-03/b10/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b10/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b10/7149_3246.png b/tests/data/satlas/landsat/2022-03/b10/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b10/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b11/1234_5678.png b/tests/data/satlas/landsat/2022-03/b11/1234_5678.png new file mode 100644 index 00000000000..a7ff273b877 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b11/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b11/7149_3246.png b/tests/data/satlas/landsat/2022-03/b11/7149_3246.png new file mode 100644 index 00000000000..a7ff273b877 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b11/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b2/1234_5678.png b/tests/data/satlas/landsat/2022-03/b2/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b2/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b2/7149_3246.png b/tests/data/satlas/landsat/2022-03/b2/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b2/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b3/1234_5678.png b/tests/data/satlas/landsat/2022-03/b3/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b3/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b3/7149_3246.png b/tests/data/satlas/landsat/2022-03/b3/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b3/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b4/1234_5678.png b/tests/data/satlas/landsat/2022-03/b4/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b4/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b4/7149_3246.png b/tests/data/satlas/landsat/2022-03/b4/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b4/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b5/1234_5678.png b/tests/data/satlas/landsat/2022-03/b5/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b5/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b5/7149_3246.png b/tests/data/satlas/landsat/2022-03/b5/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b5/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b6/1234_5678.png b/tests/data/satlas/landsat/2022-03/b6/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b6/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b6/7149_3246.png b/tests/data/satlas/landsat/2022-03/b6/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b6/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b7/1234_5678.png b/tests/data/satlas/landsat/2022-03/b7/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b7/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b7/7149_3246.png b/tests/data/satlas/landsat/2022-03/b7/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b7/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b8/1234_5678.png b/tests/data/satlas/landsat/2022-03/b8/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b8/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b8/7149_3246.png b/tests/data/satlas/landsat/2022-03/b8/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b8/7149_3246.png differ diff --git a/tests/data/satlas/landsat/2022-03/b9/1234_5678.png b/tests/data/satlas/landsat/2022-03/b9/1234_5678.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b9/1234_5678.png differ diff --git a/tests/data/satlas/landsat/2022-03/b9/7149_3246.png b/tests/data/satlas/landsat/2022-03/b9/7149_3246.png new file mode 100644 index 00000000000..1c5ee4e26cf Binary files /dev/null and b/tests/data/satlas/landsat/2022-03/b9/7149_3246.png differ diff --git a/tests/data/satlas/metadata.tar b/tests/data/satlas/metadata.tar new file mode 100644 index 00000000000..da55bab052a Binary files /dev/null and b/tests/data/satlas/metadata.tar differ diff --git a/tests/data/satlas/metadata/good_images_lowres_all.json b/tests/data/satlas/metadata/good_images_lowres_all.json new file mode 100644 index 00000000000..32f86878307 --- /dev/null +++ b/tests/data/satlas/metadata/good_images_lowres_all.json @@ -0,0 +1 @@ +[[7149, 3246, "2022-03"], [1234, 5678, "2022-03"], [7149, 3246, "m_3808245_se_17_1_20110801"], [1234, 5678, "m_3808245_se_17_1_20110801"], [7149, 3246, "2022-01"], [1234, 5678, "2022-01"], [7149, 3246, "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235"], [1234, 5678, "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235"]] \ No newline at end of file diff --git a/tests/data/satlas/metadata/image_times.json b/tests/data/satlas/metadata/image_times.json new file mode 100644 index 00000000000..9028902e0a3 --- /dev/null +++ b/tests/data/satlas/metadata/image_times.json @@ -0,0 +1 @@ +{"2022-03": "2022-03-01T00:00:00+00:00", "m_3808245_se_17_1_20110801": "2011-08-01T12:00:00+00:00", "2022-01": "2022-01-01T00:00:00+00:00", "S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235": "2022-03-09T06:02:35+00:00"} \ No newline at end of file diff --git a/tests/data/satlas/metadata/train_lowres.json b/tests/data/satlas/metadata/train_lowres.json new file mode 100644 index 00000000000..af40dbffd64 --- /dev/null +++ b/tests/data/satlas/metadata/train_lowres.json @@ -0,0 +1 @@ +[[7149, 3246], [1234, 5678]] \ No newline at end of file diff --git a/tests/data/satlas/naip.tar b/tests/data/satlas/naip.tar new file mode 100644 index 00000000000..c77db0c064c Binary files /dev/null and b/tests/data/satlas/naip.tar differ diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/1234_5678.png differ diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/ir/7149_3246.png differ diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png new file mode 100644 index 00000000000..1655bc2ca09 Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/1234_5678.png differ diff --git a/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png new file mode 100644 index 00000000000..1655bc2ca09 Binary files /dev/null and b/tests/data/satlas/naip/m_3808245_se_17_1_20110801/tci/7149_3246.png differ diff --git a/tests/data/satlas/sentinel1.tar b/tests/data/satlas/sentinel1.tar new file mode 100644 index 00000000000..75585130401 Binary files /dev/null and b/tests/data/satlas/sentinel1.tar differ diff --git a/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png b/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vh/1234_5678.png differ diff --git a/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png b/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vh/7149_3246.png differ diff --git a/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png b/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vv/1234_5678.png differ diff --git a/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png b/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel1/2022-01/vv/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2.tar b/tests/data/satlas/sentinel2.tar new file mode 100644 index 00000000000..aa3122a90c8 Binary files /dev/null and b/tests/data/satlas/sentinel2.tar differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b05/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b06/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b07/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b08/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b11/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/b12/7149_3246.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png new file mode 100644 index 00000000000..1655bc2ca09 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/1234_5678.png differ diff --git a/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png new file mode 100644 index 00000000000..1655bc2ca09 Binary files /dev/null and b/tests/data/satlas/sentinel2/S2A_MSIL1C_20220309T032601_N0400_R018_T48RYR_20220309T060235/tci/7149_3246.png differ diff --git a/tests/data/satlas/static.tar b/tests/data/satlas/static.tar new file mode 100644 index 00000000000..decccd5ca7a Binary files /dev/null and b/tests/data/satlas/static.tar differ diff --git a/tests/data/satlas/static/7149_3246/land_cover.png b/tests/data/satlas/static/7149_3246/land_cover.png new file mode 100644 index 00000000000..c1620c85534 Binary files /dev/null and b/tests/data/satlas/static/7149_3246/land_cover.png differ diff --git a/tests/data/seasonet/data.py b/tests/data/seasonet/data.py index 68fa8ffe397..e3197ddde12 100644 --- a/tests/data/seasonet/data.py +++ b/tests/data/seasonet/data.py @@ -112,7 +112,7 @@ # Compute checksums with open(archive, 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f'{season}: {repr(md5)}') + print(f'{season}: {md5!r}') # Write meta.csv with open('meta.csv', 'w') as f: @@ -121,7 +121,7 @@ # Compute checksums with open('meta.csv', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f'meta.csv: {repr(md5)}') + print(f'meta.csv: {md5!r}') os.makedirs('splits', exist_ok=True) @@ -138,4 +138,4 @@ # Compute checksums with open('splits.zip', 'rb') as f: md5 = hashlib.md5(f.read()).hexdigest() - print(f'splits: {repr(md5)}') + print(f'splits: {md5!r}') diff --git a/tests/data/skyscript/SkyScript_test_30K_filtered_by_CLIP_openai.csv b/tests/data/skyscript/SkyScript_test_30K_filtered_by_CLIP_openai.csv new file mode 100644 index 00000000000..ed35a3ee1ee --- /dev/null +++ b/tests/data/skyscript/SkyScript_test_30K_filtered_by_CLIP_openai.csv @@ -0,0 +1,3 @@ +filepath,title,title_multi_objects,similarity_CLIP_openai +images6/w779523169_CH_18.jpg,"a satellite image of a beautiful house I will never be able to afford","a satellite image of a beautiful house, surrounded by a yard",0.1 +images7/w602363451_US_21.jpg,"a satellite image of the last mall in the world","a satellite image of a mall; surrounded by a parking lot",0.2 diff --git a/tests/data/skyscript/SkyScript_train_top30pct_filtered_by_CLIP_openai.csv b/tests/data/skyscript/SkyScript_train_top30pct_filtered_by_CLIP_openai.csv new file mode 100644 index 00000000000..eeff57b9670 --- /dev/null +++ b/tests/data/skyscript/SkyScript_train_top30pct_filtered_by_CLIP_openai.csv @@ -0,0 +1,3 @@ +filepath,title,title_multi_objects,similarity_CLIP_openai +images2/w779523169_CH_18.jpg,"a satellite image of a beautiful house I will never be able to afford","a satellite image of a beautiful house, surrounded by a yard",0.1 +images3/w602363451_US_21.jpg,"a satellite image of the last mall in the world","a satellite image of a mall; surrounded by a parking lot",0.2 diff --git a/tests/data/skyscript/SkyScript_val_5K_filtered_by_CLIP_openai.csv b/tests/data/skyscript/SkyScript_val_5K_filtered_by_CLIP_openai.csv new file mode 100644 index 00000000000..89f8af99c59 --- /dev/null +++ b/tests/data/skyscript/SkyScript_val_5K_filtered_by_CLIP_openai.csv @@ -0,0 +1,3 @@ +filepath,title,title_multi_objects,similarity_CLIP_openai +images4/w779523169_CH_18.jpg,"a satellite image of a beautiful house I will never be able to afford","a satellite image of a beautiful house, surrounded by a yard",0.1 +images5/w602363451_US_21.jpg,"a satellite image of the last mall in the world","a satellite image of a mall; surrounded by a parking lot",0.2 diff --git a/tests/data/skyscript/data.py b/tests/data/skyscript/data.py new file mode 100755 index 00000000000..86a67e0ef8f --- /dev/null +++ b/tests/data/skyscript/data.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import os +import random +import shutil + +import pandas as pd +from PIL import Image + +SIZE = 32 + +random.seed(0) + +for csv in glob.iglob('*.csv'): + captions = pd.read_csv(csv) + for jpg in captions['filepath']: + os.makedirs(os.path.dirname(jpg), exist_ok=True) + width = random.randrange(SIZE) + height = random.randrange(SIZE) + img = Image.new('RGB', (width, height)) + img.save(jpg) + +for directory in [f'images{i}' for i in range(2, 8)]: + shutil.make_archive(directory, 'zip', '.', directory) diff --git a/tests/data/skyscript/images2.zip b/tests/data/skyscript/images2.zip new file mode 100644 index 00000000000..ba6b19e99a4 Binary files /dev/null and b/tests/data/skyscript/images2.zip differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg b/tests/data/skyscript/images2/w779523169_CH_18.jpg similarity index 86% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg rename to tests/data/skyscript/images2/w779523169_CH_18.jpg index 77c95fe8774..e90f8457cbd 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg and b/tests/data/skyscript/images2/w779523169_CH_18.jpg differ diff --git a/tests/data/skyscript/images3.zip b/tests/data/skyscript/images3.zip new file mode 100644 index 00000000000..68623ed4bf6 Binary files /dev/null and b/tests/data/skyscript/images3.zip differ diff --git a/tests/data/skyscript/images3/w602363451_US_21.jpg b/tests/data/skyscript/images3/w602363451_US_21.jpg new file mode 100644 index 00000000000..45d07eefaa8 Binary files /dev/null and b/tests/data/skyscript/images3/w602363451_US_21.jpg differ diff --git a/tests/data/skyscript/images4.zip b/tests/data/skyscript/images4.zip new file mode 100644 index 00000000000..ef431ea2618 Binary files /dev/null and b/tests/data/skyscript/images4.zip differ diff --git a/tests/data/skyscript/images4/w779523169_CH_18.jpg b/tests/data/skyscript/images4/w779523169_CH_18.jpg new file mode 100644 index 00000000000..6ff1359cfdd Binary files /dev/null and b/tests/data/skyscript/images4/w779523169_CH_18.jpg differ diff --git a/tests/data/skyscript/images5.zip b/tests/data/skyscript/images5.zip new file mode 100644 index 00000000000..ac727708c5b Binary files /dev/null and b/tests/data/skyscript/images5.zip differ diff --git a/tests/data/skyscript/images5/w602363451_US_21.jpg b/tests/data/skyscript/images5/w602363451_US_21.jpg new file mode 100644 index 00000000000..e68769e09d6 Binary files /dev/null and b/tests/data/skyscript/images5/w602363451_US_21.jpg differ diff --git a/tests/data/skyscript/images6.zip b/tests/data/skyscript/images6.zip new file mode 100644 index 00000000000..9d89f997f41 Binary files /dev/null and b/tests/data/skyscript/images6.zip differ diff --git a/tests/data/skyscript/images6/w779523169_CH_18.jpg b/tests/data/skyscript/images6/w779523169_CH_18.jpg new file mode 100644 index 00000000000..59eeade739a Binary files /dev/null and b/tests/data/skyscript/images6/w779523169_CH_18.jpg differ diff --git a/tests/data/skyscript/images7.zip b/tests/data/skyscript/images7.zip new file mode 100644 index 00000000000..190c9bd8312 Binary files /dev/null and b/tests/data/skyscript/images7.zip differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg b/tests/data/skyscript/images7/w602363451_US_21.jpg similarity index 89% rename from tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg rename to tests/data/skyscript/images7/w602363451_US_21.jpg index 77c95fe8774..cedfb0cc7bb 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg and b/tests/data/skyscript/images7/w602363451_US_21.jpg differ diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index 40dfe87ea37..11e7d0db3d8 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -17,7 +17,7 @@ files = ['SouthAmerica_Soybean_2002.tif', 'SouthAmerica_Soybean_2021.tif'] -def create_file(path: str, dtype: str): +def create_file(path: str, dtype: str) -> None: """Create the testing file.""" profile = { 'driver': 'GTiff', diff --git a/tests/data/spacenet/data.py b/tests/data/spacenet/data.py deleted file mode 100755 index 3e0e6faf497..00000000000 --- a/tests/data/spacenet/data.py +++ /dev/null @@ -1,281 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import os -import shutil -from collections import OrderedDict -from typing import cast - -import fiona -import numpy as np -import rasterio -from rasterio.crs import CRS -from rasterio.transform import Affine -from torchvision.datasets.utils import calculate_md5 - -from torchgeo.datasets import ( - SpaceNet, - SpaceNet1, - SpaceNet2, - SpaceNet3, - SpaceNet4, - SpaceNet5, - SpaceNet6, - SpaceNet7, -) - -transform = Affine(0.3, 0.0, 616500.0, 0.0, -0.3, 3345000.0) -crs = CRS.from_epsg(4326) - -img_count = { - 'MS.tif': 8, - 'PAN.tif': 1, - 'PS-MS.tif': 8, - 'PS-RGB.tif': 3, - 'PS-RGBNIR.tif': 4, - 'RGB.tif': 3, - 'RGBNIR.tif': 4, - 'SAR-Intensity.tif': 1, - 'mosaic.tif': 3, - '8Band.tif': 8, -} - - -sn4_catalog = [ - '10300100023BC100', - '10300100036D5200', - '1030010003BDDC00', - '1030010003CD4300', -] -sn4_angles = [8, 30, 52, 53] - -sn4_imgdirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-nadir{}_catid_{}' -sn4_lbldirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-labels' -sn4_emptyimgdirname = ( - 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-nadir53_' - + 'catid_1030010003CD4300' -) -sn4_emptylbldirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-labels' - - -datasets = [SpaceNet1, SpaceNet2, SpaceNet3, SpaceNet4, SpaceNet5, SpaceNet6, SpaceNet7] - - -def create_test_image(img_dir: str, imgs: list[str]) -> list[list[float]]: - """Create test image - - Args: - img_dir (str): Name of image directory - imgs (List[str]): List of images to be created - - Returns: - List[List[float]]: Boundary coordinates - """ - for img in imgs: - imgpath = os.path.join(img_dir, img) - Z = np.arange(4, dtype='uint16').reshape(2, 2) - count = img_count[img] - with rasterio.open( - imgpath, - 'w', - driver='GTiff', - height=Z.shape[0], - width=Z.shape[1], - count=count, - dtype=Z.dtype, - crs=crs, - transform=transform, - ) as dst: - for i in range(1, dst.count + 1): - dst.write(Z, i) - - tim = rasterio.open(imgpath) - slice_index = [[1, 1], [1, 2], [2, 2], [2, 1], [1, 1]] - return [list(tim.transform * p) for p in slice_index] - - -def create_test_label( - lbldir: str, - lblname: str, - coords: list[list[float]], - det_type: str, - empty: bool = False, - diff_crs: bool = False, -) -> None: - """Create test label - - Args: - lbldir (str): Name of label directory - lblname (str): Name of label file - coords (List[Tuple[float, float]]): Boundary coordinates - det_type (str): Type of dataset. Must be either buildings or roads. - empty (bool, optional): Creates empty label file if True. Defaults to False. - diff_crs (bool, optional): Assigns EPSG:3857 as CRS instead of - default EPSG:4326. Defaults to False. - """ - if empty: - # Creates a new file - with open(os.path.join(lbldir, lblname), 'w'): - pass - return - - if det_type == 'buildings': - meta_properties = OrderedDict() - geom = 'Polygon' - rec = { - 'type': 'Feature', - 'id': '0', - 'properties': OrderedDict(), - 'geometry': {'type': 'Polygon', 'coordinates': [coords]}, - } - else: - meta_properties = OrderedDict( - [ - ('heading', 'str'), - ('lane_number', 'str'), - ('one_way_ty', 'str'), - ('paved', 'str'), - ('road_id', 'int'), - ('road_type', 'str'), - ('origarea', 'int'), - ('origlen', 'float'), - ('partialDec', 'int'), - ('truncated', 'int'), - ('bridge_type', 'str'), - ('inferred_speed_mph', 'float'), - ('inferred_speed_mps', 'float'), - ] - ) - geom = 'LineString' - - dummy_vals = {'str': 'a', 'float': 45.0, 'int': 0} - ROAD_DICT = [(k, dummy_vals[v]) for k, v in meta_properties.items()] - rec = { - 'type': 'Feature', - 'id': '0', - 'properties': OrderedDict(ROAD_DICT), - 'geometry': {'type': 'LineString', 'coordinates': [coords[0], coords[2]]}, - } - - meta = { - 'driver': 'GeoJSON', - 'schema': {'properties': meta_properties, 'geometry': geom}, - 'crs': {'init': 'epsg:4326'}, - } - if diff_crs: - meta['crs'] = {'init': 'epsg:3857'} - out_file = os.path.join(lbldir, lblname) - with fiona.open(out_file, 'w', **meta) as dst: - dst.write(rec) - - -def main() -> None: - ROOT_DIR = os.path.dirname(os.path.realpath(__file__)) - - for dataset in datasets: - collections = list(dataset.collection_md5_dict.keys()) - for collection in collections: - dataset = cast(SpaceNet, dataset) - if dataset.dataset_id == 'spacenet4': - num_samples = 4 - elif collection == 'sn5_AOI_7_Moscow' or collection not in [ - 'sn5_AOI_8_Mumbai', - 'sn7_test_source', - ]: - num_samples = 3 - elif collection == 'sn5_AOI_8_Mumbai': - num_samples = 3 - else: - num_samples = 1 - - for sample in range(num_samples): - out_dir = os.path.join(ROOT_DIR, collection) - if collection == 'sn6_AOI_11_Rotterdam': - out_dir = os.path.join(ROOT_DIR, 'spacenet6', collection) - - # Create img dir - if dataset.dataset_id == 'spacenet4': - assert num_samples == 4 - if sample != 3: - imgdirname = sn4_imgdirname.format( - sn4_angles[sample], sn4_catalog[sample] - ) - lbldirname = sn4_lbldirname - else: - imgdirname = sn4_emptyimgdirname.format( - sn4_angles[sample], sn4_catalog[sample] - ) - lbldirname = sn4_emptylbldirname - else: - imgdirname = f'{collection}_img{sample + 1}' - lbldirname = f'{collection}_img{sample + 1}-labels' - - imgdir = os.path.join(out_dir, imgdirname) - os.makedirs(imgdir, exist_ok=True) - bounds = create_test_image(imgdir, list(dataset.imagery.values())) - - # Create lbl dir - lbldir = os.path.join(out_dir, lbldirname) - os.makedirs(lbldir, exist_ok=True) - det_type = 'roads' if dataset in [SpaceNet3, SpaceNet5] else 'buildings' - if dataset.dataset_id == 'spacenet4' and sample == 3: - # Creates an empty file - create_test_label( - lbldir, dataset.label_glob, bounds, det_type, empty=True - ) - else: - create_test_label(lbldir, dataset.label_glob, bounds, det_type) - - if collection == 'sn5_AOI_8_Mumbai': - if sample == 1: - create_test_label( - lbldir, dataset.label_glob, bounds, det_type, empty=True - ) - if sample == 2: - create_test_label( - lbldir, dataset.label_glob, bounds, det_type, diff_crs=True - ) - - if collection == 'sn1_AOI_1_RIO' and sample == 1: - create_test_label( - lbldir, dataset.label_glob, bounds, det_type, diff_crs=True - ) - - if collection not in [ - 'sn2_AOI_2_Vegas', - 'sn3_AOI_5_Khartoum', - 'sn4_AOI_6_Atlanta', - 'sn5_AOI_8_Mumbai', - 'sn6_AOI_11_Rotterdam', - 'sn7_train_source', - ]: - # Create collection.json - with open( - os.path.join(ROOT_DIR, collection, 'collection.json'), 'w' - ): - pass - if collection == 'sn6_AOI_11_Rotterdam': - # Create collection.json - with open( - os.path.join( - ROOT_DIR, 'spacenet6', collection, 'collection.json' - ), - 'w', - ): - pass - - # Create archive - if collection == 'sn6_AOI_11_Rotterdam': - break - archive_path = os.path.join(ROOT_DIR, collection) - shutil.make_archive( - archive_path, 'gztar', root_dir=ROOT_DIR, base_dir=collection - ) - shutil.rmtree(out_dir) - print(f'{collection}: {calculate_md5(f"{archive_path}.tar.gz")}') - - -if __name__ == '__main__': - main() diff --git a/tests/data/spacenet/sn1_AOI_1_RIO.tar.gz b/tests/data/spacenet/sn1_AOI_1_RIO.tar.gz deleted file mode 100644 index 5b731c0383c..00000000000 Binary files a/tests/data/spacenet/sn1_AOI_1_RIO.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1-labels/labels.geojson b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1-labels/labels.geojson deleted file mode 100644 index 0a418938820..00000000000 --- a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1-labels/labels.geojson +++ /dev/null @@ -1,7 +0,0 @@ -{ -"type": "FeatureCollection", -"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } }, -"features": [ -{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } } -] -} diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/8Band.tif b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/8Band.tif deleted file mode 100644 index 9383f0cdde9..00000000000 Binary files a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/8Band.tif and /dev/null differ diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/RGB.tif b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/RGB.tif deleted file mode 100644 index 022510c2df5..00000000000 Binary files a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/RGB.tif and /dev/null differ diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2-labels/labels.geojson b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2-labels/labels.geojson deleted file mode 100644 index a1174201e20..00000000000 --- a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2-labels/labels.geojson +++ /dev/null @@ -1,7 +0,0 @@ -{ -"type": "FeatureCollection", -"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:EPSG::3857" } }, -"features": [ -{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } } -] -} diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/8Band.tif b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/8Band.tif deleted file mode 100644 index 9383f0cdde9..00000000000 Binary files a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/8Band.tif and /dev/null differ diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/RGB.tif b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/RGB.tif deleted file mode 100644 index 022510c2df5..00000000000 Binary files a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/RGB.tif and /dev/null differ diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3-labels/labels.geojson b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3-labels/labels.geojson deleted file mode 100644 index 0a418938820..00000000000 --- a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3-labels/labels.geojson +++ /dev/null @@ -1,7 +0,0 @@ -{ -"type": "FeatureCollection", -"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } }, -"features": [ -{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } } -] -} diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/8Band.tif b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/8Band.tif deleted file mode 100644 index 9383f0cdde9..00000000000 Binary files a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/8Band.tif and /dev/null differ diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/RGB.tif b/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/RGB.tif deleted file mode 100644 index 022510c2df5..00000000000 Binary files a/tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/RGB.tif and /dev/null differ diff --git a/tests/data/spacenet/sn2_AOI_2_Vegas.tar.gz b/tests/data/spacenet/sn2_AOI_2_Vegas.tar.gz deleted file mode 100644 index c7cbcd5b4fc..00000000000 Binary files a/tests/data/spacenet/sn2_AOI_2_Vegas.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn2_AOI_3_Paris.tar.gz b/tests/data/spacenet/sn2_AOI_3_Paris.tar.gz deleted file mode 100644 index 6c26bde44d7..00000000000 Binary files a/tests/data/spacenet/sn2_AOI_3_Paris.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn2_AOI_4_Shanghai.tar.gz b/tests/data/spacenet/sn2_AOI_4_Shanghai.tar.gz deleted file mode 100644 index b7d8ba655e1..00000000000 Binary files a/tests/data/spacenet/sn2_AOI_4_Shanghai.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn2_AOI_5_Khartoum.tar.gz b/tests/data/spacenet/sn2_AOI_5_Khartoum.tar.gz deleted file mode 100644 index b0a4b29ca34..00000000000 Binary files a/tests/data/spacenet/sn2_AOI_5_Khartoum.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn3_AOI_2_Vegas.tar.gz b/tests/data/spacenet/sn3_AOI_2_Vegas.tar.gz deleted file mode 100644 index 0e17e78befc..00000000000 Binary files a/tests/data/spacenet/sn3_AOI_2_Vegas.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz b/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz deleted file mode 100644 index 3a960eca9af..00000000000 Binary files a/tests/data/spacenet/sn3_AOI_3_Paris.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn3_AOI_4_Shanghai.tar.gz b/tests/data/spacenet/sn3_AOI_4_Shanghai.tar.gz deleted file mode 100644 index f3b479e43b1..00000000000 Binary files a/tests/data/spacenet/sn3_AOI_4_Shanghai.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz b/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz deleted file mode 100644 index f3a1b809291..00000000000 Binary files a/tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz b/tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz deleted file mode 100644 index a1e0c8a6910..00000000000 Binary files a/tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn5_AOI_7_Moscow.tar.gz b/tests/data/spacenet/sn5_AOI_7_Moscow.tar.gz deleted file mode 100644 index c666f1b837c..00000000000 Binary files a/tests/data/spacenet/sn5_AOI_7_Moscow.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn5_AOI_8_Mumbai.tar.gz b/tests/data/spacenet/sn5_AOI_8_Mumbai.tar.gz deleted file mode 100644 index 4f6b9cd7ad4..00000000000 Binary files a/tests/data/spacenet/sn5_AOI_8_Mumbai.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn7_test_source.tar.gz b/tests/data/spacenet/sn7_test_source.tar.gz deleted file mode 100644 index e411894fb3a..00000000000 Binary files a/tests/data/spacenet/sn7_test_source.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn7_train_labels.tar.gz b/tests/data/spacenet/sn7_train_labels.tar.gz deleted file mode 100644 index b3f583771c0..00000000000 Binary files a/tests/data/spacenet/sn7_train_labels.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/sn7_train_source.tar.gz b/tests/data/spacenet/sn7_train_source.tar.gz deleted file mode 100644 index e847fd070a4..00000000000 Binary files a/tests/data/spacenet/sn7_train_source.tar.gz and /dev/null differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img1.tif b/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img1.tif new file mode 100644 index 00000000000..beab6f5fd84 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img1.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img2.tif b/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img2.tif new file mode 100644 index 00000000000..beab6f5fd84 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img2.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img3.tif b/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img3.tif new file mode 100644 index 00000000000..beab6f5fd84 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img3.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img4.tif b/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img4.tif new file mode 100644 index 00000000000..beab6f5fd84 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/3band/3band_AOI_1_RIO_img4.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img1.tif b/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img1.tif new file mode 100644 index 00000000000..cfb2a40c34f Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img1.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img2.tif b/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img2.tif new file mode 100644 index 00000000000..cfb2a40c34f Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img2.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img3.tif b/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img3.tif new file mode 100644 index 00000000000..cfb2a40c34f Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img3.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img4.tif b/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img4.tif new file mode 100644 index 00000000000..cfb2a40c34f Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/8band/8band_AOI_1_RIO_img4.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/SN1_buildings_test_AOI_1_Rio_3band.tar.gz b/tests/data/spacenet/spacenet1/SN1_buildings/test/SN1_buildings_test_AOI_1_Rio_3band.tar.gz new file mode 100644 index 00000000000..34d22b20d09 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/SN1_buildings_test_AOI_1_Rio_3band.tar.gz differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/test/SN1_buildings_test_AOI_1_Rio_8band.tar.gz b/tests/data/spacenet/spacenet1/SN1_buildings/test/SN1_buildings_test_AOI_1_Rio_8band.tar.gz new file mode 100644 index 00000000000..a7da7248e5d Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/test/SN1_buildings_test_AOI_1_Rio_8band.tar.gz differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img1.tif b/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img1.tif new file mode 100644 index 00000000000..beab6f5fd84 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img1.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img2.tif b/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img2.tif new file mode 100644 index 00000000000..beab6f5fd84 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img2.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img3.tif b/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img3.tif new file mode 100644 index 00000000000..beab6f5fd84 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img3.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img4.tif b/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img4.tif new file mode 100644 index 00000000000..beab6f5fd84 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/3band/3band_AOI_1_RIO_img4.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img1.tif b/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img1.tif new file mode 100644 index 00000000000..cfb2a40c34f Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img1.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img2.tif b/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img2.tif new file mode 100644 index 00000000000..cfb2a40c34f Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img2.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img3.tif b/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img3.tif new file mode 100644 index 00000000000..cfb2a40c34f Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img3.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img4.tif b/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img4.tif new file mode 100644 index 00000000000..cfb2a40c34f Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/8band/8band_AOI_1_RIO_img4.tif differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_3band.tar.gz b/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_3band.tar.gz new file mode 100644 index 00000000000..40ef8731526 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_3band.tar.gz differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_8band.tar.gz b/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_8band.tar.gz new file mode 100644 index 00000000000..6eaebc3a9ad Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_8band.tar.gz differ diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_geojson_buildings.tar.gz b/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_geojson_buildings.tar.gz new file mode 100644 index 00000000000..ed4f60826e0 Binary files /dev/null and b/tests/data/spacenet/spacenet1/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_geojson_buildings.tar.gz differ diff --git a/tests/data/spacenet/sn1_AOI_1_RIO/collection.json b/tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img1.geojson similarity index 100% rename from tests/data/spacenet/sn1_AOI_1_RIO/collection.json rename to tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img1.geojson diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img2.geojson b/tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img2.geojson new file mode 100644 index 00000000000..dabcc4ce308 --- /dev/null +++ b/tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img2.geojson @@ -0,0 +1 @@ +{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[-43.7720361, -22.922229499999958, 0.0], [-43.772064, -22.9222724, 0.0], [-43.77210239999994, -22.922247399999947, 0.0], [-43.772074499999974, -22.9222046, 0.0], [-43.7720361, -22.922229499999958, 0.0]]]}}]} \ No newline at end of file diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/collection.json b/tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img3.geojson similarity index 100% rename from tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/collection.json rename to tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img3.geojson diff --git a/tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img4.geojson b/tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img4.geojson new file mode 100644 index 00000000000..dabcc4ce308 --- /dev/null +++ b/tests/data/spacenet/spacenet1/SN1_buildings/train/geojson/Geo_AOI_1_RIO_img4.geojson @@ -0,0 +1 @@ +{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[-43.7720361, -22.922229499999958, 0.0], [-43.772064, -22.9222724, 0.0], [-43.77210239999994, -22.922247399999947, 0.0], [-43.772074499999974, -22.9222046, 0.0], [-43.7720361, -22.922229499999958, 0.0]]]}}]} \ No newline at end of file diff --git a/tests/data/spacenet/spacenet1/data.py b/tests/data/spacenet/spacenet1/data.py new file mode 100755 index 00000000000..f8187583cc8 --- /dev/null +++ b/tests/data/spacenet/spacenet1/data.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import json +import os +import shutil +from typing import Any + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine + +SIZE = 2 + +NUM_SAMPLES = 4 + +dataset_id = 'SN1_buildings' + +profile = { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'width': SIZE, + 'height': SIZE, + 'crs': CRS.from_epsg(4326), + 'transform': Affine( + 4.489235388119662e-06, + 0.0, + -43.7732462563, + 0.0, + -4.486127586210932e-06, + -22.9214851954, + ), +} + +np.random.seed(0) +Z = np.random.randint(np.iinfo('uint8').max, size=(SIZE, SIZE), dtype='uint8') + + +def create_directories(base_path: str, band_counts: list[int]) -> None: + for count in band_counts: + os.makedirs(os.path.join(base_path, f'{count}band'), exist_ok=True) + + +def generate_geotiff_files( + base_path: str, band_counts: list[int], profile: dict[str, Any], Z: np.ndarray +) -> None: + for count in band_counts: + for i in range(1, NUM_SAMPLES + 1): + path = os.path.join( + base_path, f'{count}band', f'{count}band_AOI_1_RIO_img{i}.tif' + ) + profile['count'] = count + with rasterio.open(path, 'w', **profile) as src: + for j in range(1, count + 1): + src.write(Z, j) + + +def generate_geojson_files(base_path: str, geojson: dict[str, Any]) -> None: + os.makedirs(os.path.join(base_path, 'geojson'), exist_ok=True) + for i in range(1, NUM_SAMPLES + 1): + path = os.path.join(base_path, 'geojson', f'Geo_AOI_1_RIO_img{i}.geojson') + with open(path, 'w') as src: + if i % 2 == 0: + json.dump(geojson, src) + + +def compute_md5(file_path: str) -> str: + hash_md5 = hashlib.md5() + with open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +# Generate dummy GeoJSON files for building footprints +geojson = { + 'type': 'FeatureCollection', + 'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}}, + 'features': [ + { + 'type': 'Feature', + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ + [ + [-43.7720361, -22.922229499999958, 0.0], + [-43.772064, -22.9222724, 0.0], + [-43.772102399999937, -22.922247399999947, 0.0], + [-43.772074499999974, -22.9222046, 0.0], + [-43.7720361, -22.922229499999958, 0.0], + ] + ], + }, + } + ], +} + +# Remove existing data if it exists +if os.path.exists(dataset_id): + shutil.rmtree(dataset_id) + +train_base_path = os.path.join(dataset_id, 'train') +test_base_path = os.path.join(dataset_id, 'test') + +# Create directories and generate dummy GeoTIFF files for train dataset +create_directories(train_base_path, [3, 8]) +generate_geotiff_files(train_base_path, [3, 8], profile, Z) +generate_geojson_files(train_base_path, geojson) + +# Create directories and generate dummy GeoTIFF files for test dataset (only 3band and 8band) +create_directories(test_base_path, [3, 8]) +generate_geotiff_files(test_base_path, [3, 8], profile, Z) + +# Create tarballs for train and test datasets +tarball_specs = { + 'train': { + '3band': 'SN1_buildings_train_AOI_1_Rio_3band', + '8band': 'SN1_buildings_train_AOI_1_Rio_8band', + 'geojson': 'SN1_buildings_train_AOI_1_Rio_geojson_buildings', + }, + 'test': { + '3band': 'SN1_buildings_test_AOI_1_Rio_3band', + '8band': 'SN1_buildings_test_AOI_1_Rio_8band', + }, +} + +for split, specs in tarball_specs.items(): + for subdir, tarball_name in specs.items(): + tarball_path = os.path.join(dataset_id, split, tarball_name) + shutil.make_archive( + tarball_path, + 'gztar', + root_dir=os.path.join(dataset_id, split), + base_dir=subdir, + ) + +# Compute and print MD5 checksums for the generated tarballs +print('MD5 Checksums for Train Dataset:') +train_tarballs = [ + 'SN1_buildings_train_AOI_1_Rio_3band.tar.gz', + 'SN1_buildings_train_AOI_1_Rio_8band.tar.gz', + 'SN1_buildings_train_AOI_1_Rio_geojson_buildings.tar.gz', +] +for tarball in train_tarballs: + tarball_path = os.path.join(dataset_id, 'train', tarball) + if os.path.exists(tarball_path): + print(f'{tarball}: {compute_md5(tarball_path)}') + +print('\nMD5 Checksums for Test Dataset:') +test_tarballs = [ + 'SN1_buildings_test_AOI_1_Rio_3band.tar.gz', + 'SN1_buildings_test_AOI_1_Rio_8band.tar.gz', +] +for tarball in test_tarballs: + tarball_path = os.path.join(dataset_id, 'test', tarball) + if os.path.exists(tarball_path): + print(f'{tarball}: {compute_md5(tarball_path)}') diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/test/SN6_buildings_AOI_11_Rotterdam_test.tar.gz b/tests/data/spacenet/spacenet6/SN6_buildings/test/SN6_buildings_AOI_11_Rotterdam_test.tar.gz new file mode 100644 index 00000000000..d787b796d94 Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/test/SN6_buildings_AOI_11_Rotterdam_test.tar.gz differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_1.tif b/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_1.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_1.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_2.tif b/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_2.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_2.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_3.tif b/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_3.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_3.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_4.tif b/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_4.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/test/test_public/AOI_11_Rotterdam/SAR-Intensity/SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_4.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/SN6_buildings_AOI_11_Rotterdam_train.tar.gz b/tests/data/spacenet/spacenet6/SN6_buildings/train/SN6_buildings_AOI_11_Rotterdam_train.tar.gz new file mode 100644 index 00000000000..a863cfbe026 Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/SN6_buildings_AOI_11_Rotterdam_train.tar.gz differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_1.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_1.tif new file mode 100644 index 00000000000..650664c080f Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_1.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_2.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_2.tif new file mode 100644 index 00000000000..650664c080f Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_2.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_3.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_3.tif new file mode 100644 index 00000000000..650664c080f Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_3.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_4.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_4.tif new file mode 100644 index 00000000000..650664c080f Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PAN/SN6_Train_AOI_11_Rotterdam_PAN_20190804111224_20190804111453_tile_4.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_1.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_1.tif new file mode 100644 index 00000000000..96950caa289 Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_1.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_2.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_2.tif new file mode 100644 index 00000000000..96950caa289 Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_2.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_3.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_3.tif new file mode 100644 index 00000000000..96950caa289 Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_3.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_4.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_4.tif new file mode 100644 index 00000000000..96950caa289 Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGB/SN6_Train_AOI_11_Rotterdam_PS-RGB_20190804111224_20190804111453_tile_4.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_1.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_1.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_1.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_2.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_2.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_2.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_3.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_3.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_3.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_4.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_4.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/PS-RGBNIR/SN6_Train_AOI_11_Rotterdam_PS-RGBNIR_20190804111224_20190804111453_tile_4.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_1.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_1.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_1.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_2.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_2.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_2.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_3.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_3.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_3.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_4.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_4.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/RGBNIR/SN6_Train_AOI_11_Rotterdam_RGBNIR_20190804111224_20190804111453_tile_4.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_1.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_1.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_1.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_2.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_2.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_2.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_3.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_3.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_3.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_4.tif b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_4.tif new file mode 100644 index 00000000000..cce10b4555a Binary files /dev/null and b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/SAR-Intensity/SN6_Train_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_4.tif differ diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_1.geojson b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_1.geojson new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_2.geojson b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_2.geojson new file mode 100644 index 00000000000..7ae5755f9f7 --- /dev/null +++ b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_2.geojson @@ -0,0 +1 @@ +{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[4.47917, 51.9225, 0.0], [4.4792, 51.92255, 0.0], [4.47925, 51.92252, 0.0], [4.47922, 51.92247, 0.0], [4.47917, 51.9225, 0.0]]]}}]} \ No newline at end of file diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_3.geojson b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_3.geojson new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_4.geojson b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_4.geojson new file mode 100644 index 00000000000..7ae5755f9f7 --- /dev/null +++ b/tests/data/spacenet/spacenet6/SN6_buildings/train/train/AOI_11_Rotterdam/geojson_buildings/SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_4.geojson @@ -0,0 +1 @@ +{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[4.47917, 51.9225, 0.0], [4.4792, 51.92255, 0.0], [4.47925, 51.92252, 0.0], [4.47922, 51.92247, 0.0], [4.47917, 51.9225, 0.0]]]}}]} \ No newline at end of file diff --git a/tests/data/spacenet/spacenet6/data.py b/tests/data/spacenet/spacenet6/data.py new file mode 100644 index 00000000000..47f2b59510f --- /dev/null +++ b/tests/data/spacenet/spacenet6/data.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import json +import os +import shutil +from typing import Any + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine + +SIZE = 2 + +NUM_SAMPLES = 4 + +dataset_id = 'SN6_buildings' + +profile = { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'width': SIZE, + 'height': SIZE, + 'crs': CRS.from_epsg(4326), + 'transform': Affine( + 4.489235388119662e-06, 0.0, 4.47917, 0.0, -4.486127586210932e-06, 51.9225 + ), +} + +np.random.seed(0) +Z = np.random.randint(np.iinfo('uint8').max, size=(SIZE, SIZE), dtype='uint8') + +# Define the types of imagery for SpaceNet6 +imagery_types = ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'] +imagery_channels = { + 'PAN': 1, + 'PS-RGB': 3, + 'PS-RGBNIR': 4, + 'RGBNIR': 4, + 'SAR-Intensity': 4, +} + + +def create_directories(base_path: str, imagery_types: list[str]) -> None: + for imagery_type in imagery_types: + os.makedirs(os.path.join(base_path, imagery_type), exist_ok=True) + + +def generate_geotiff_files( + base_path: str, + imagery_types: str, + imagery_channels: int, + profile: dict[str, Any], + Z: np.ndarray, + test: bool = False, +) -> None: + for imagery_type in imagery_types: + for i in range(1, NUM_SAMPLES + 1): + if test and imagery_type == 'SAR-Intensity': + path = os.path.join( + base_path, + f'SN6_Test_Public_AOI_11_Rotterdam_SAR-Intensity_20190804111224_20190804111453_tile_{i}.tif', + ) + else: + path = os.path.join( + base_path, + imagery_type, + f'SN6_Train_AOI_11_Rotterdam_{imagery_type}_20190804111224_20190804111453_tile_{i}.tif', + ) + profile['count'] = imagery_channels[imagery_type] + with rasterio.open(path, 'w', **profile) as src: + for j in range(1, profile['count'] + 1): + src.write(Z, j) + + +def generate_geojson_files(base_path: str, geojson: dict[str, Any]) -> None: + os.makedirs(os.path.join(base_path, 'geojson_buildings'), exist_ok=True) + for i in range(1, NUM_SAMPLES + 1): + path = os.path.join( + base_path, + 'geojson_buildings', + f'SN6_Train_AOI_11_Rotterdam_Buildings_20190804111224_20190804111453_tile_{i}.geojson', + ) + with open(path, 'w') as src: + if i % 2 == 0: + json.dump(geojson, src) + + +def compute_md5(file_path: str) -> str: + hash_md5 = hashlib.md5() + with open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +# Generate dummy GeoJSON files for building footprints +geojson = { + 'type': 'FeatureCollection', + 'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}}, + 'features': [ + { + 'type': 'Feature', + 'geometry': { + 'type': 'Polygon', + 'coordinates': [ + [ + [4.47917, 51.9225, 0.0], + [4.47920, 51.92255, 0.0], + [4.47925, 51.92252, 0.0], + [4.47922, 51.92247, 0.0], + [4.47917, 51.9225, 0.0], + ] + ], + }, + } + ], +} + +# Remove existing data if it exists +if os.path.exists(dataset_id): + shutil.rmtree(dataset_id) + +train_base_path = os.path.join(dataset_id, 'train/train/AOI_11_Rotterdam') +test_base_path = os.path.join( + dataset_id, 'test/test_public/AOI_11_Rotterdam/SAR-Intensity' +) + +# Create directories and generate dummy GeoTIFF files for train dataset +create_directories(train_base_path, imagery_types) +generate_geotiff_files(train_base_path, imagery_types, imagery_channels, profile, Z) +generate_geojson_files(train_base_path, geojson) + +# Create directories and generate dummy GeoTIFF files for test dataset (only SAR-Intensity) +os.makedirs(test_base_path, exist_ok=True) +generate_geotiff_files( + test_base_path, ['SAR-Intensity'], imagery_channels, profile, Z, test=True +) + +# Create tarballs for train and test datasets +shutil.make_archive( + os.path.join(dataset_id, 'train', 'SN6_buildings_AOI_11_Rotterdam_train'), + 'gztar', + root_dir=os.path.join(dataset_id, 'train'), + base_dir='train', +) +shutil.make_archive( + os.path.join(dataset_id, 'test', 'SN6_buildings_AOI_11_Rotterdam_test'), + 'gztar', + root_dir=os.path.join(dataset_id, 'test'), + base_dir='test_public', +) + +# Compute and print MD5 checksums for the generated tarballs +print('MD5 Checksums for Train Dataset:') +train_tarball_path = os.path.join( + dataset_id, 'train', 'SN6_buildings_AOI_11_Rotterdam_train.tar.gz' +) +if os.path.exists(train_tarball_path): + print(f'Train: {compute_md5(train_tarball_path)}') + +print('\nMD5 Checksums for Test Dataset:') +test_tarball_path = os.path.join( + dataset_id, 'test', 'SN6_buildings_AOI_11_Rotterdam_test.tar.gz' +) +if os.path.exists(test_tarball_path): + print(f'Test: {compute_md5(test_tarball_path)}') diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1-labels/labels.geojson b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1-labels/labels.geojson deleted file mode 100644 index 0a418938820..00000000000 --- a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1-labels/labels.geojson +++ /dev/null @@ -1,7 +0,0 @@ -{ -"type": "FeatureCollection", -"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } }, -"features": [ -{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } } -] -} diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PAN.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PAN.tif deleted file mode 100644 index a9aef1da576..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PAN.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PS-RGB.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PS-RGB.tif deleted file mode 100644 index 022510c2df5..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PS-RGB.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PS-RGBNIR.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PS-RGBNIR.tif deleted file mode 100644 index daadc4a2e39..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/PS-RGBNIR.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/RGBNIR.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/RGBNIR.tif deleted file mode 100644 index daadc4a2e39..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/RGBNIR.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/SAR-Intensity.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/SAR-Intensity.tif deleted file mode 100644 index a9aef1da576..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img1/SAR-Intensity.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2-labels/labels.geojson b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2-labels/labels.geojson deleted file mode 100644 index 0a418938820..00000000000 --- a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2-labels/labels.geojson +++ /dev/null @@ -1,7 +0,0 @@ -{ -"type": "FeatureCollection", -"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } }, -"features": [ -{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } } -] -} diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PAN.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PAN.tif deleted file mode 100644 index a9aef1da576..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PAN.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PS-RGB.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PS-RGB.tif deleted file mode 100644 index 022510c2df5..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PS-RGB.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PS-RGBNIR.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PS-RGBNIR.tif deleted file mode 100644 index daadc4a2e39..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/PS-RGBNIR.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/RGBNIR.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/RGBNIR.tif deleted file mode 100644 index daadc4a2e39..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/RGBNIR.tif and /dev/null differ diff --git a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/SAR-Intensity.tif b/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/SAR-Intensity.tif deleted file mode 100644 index a9aef1da576..00000000000 Binary files a/tests/data/spacenet/spacenet6/sn6_AOI_11_Rotterdam/sn6_AOI_11_Rotterdam_img2/SAR-Intensity.tif and /dev/null differ diff --git a/tests/data/ssl4eo_benchmark_landsat/data.py b/tests/data/ssl4eo_benchmark_landsat/data.py index d0b32779b6a..177ed7d7954 100755 --- a/tests/data/ssl4eo_benchmark_landsat/data.py +++ b/tests/data/ssl4eo_benchmark_landsat/data.py @@ -172,7 +172,7 @@ def create_mask_directory( create_mask(path.replace('all_bands', f'{mask_product}_{year}')) -def create_tarballs(directories) -> None: +def create_tarballs(directories: str) -> None: for directory in directories: # Create tarballs shutil.make_archive(directory, 'gztar', '.', directory) diff --git a/tests/data/technoserve-cashew-benin/data.py b/tests/data/technoserve-cashew-benin/data.py new file mode 100755 index 00000000000..7d0ec9d58bb --- /dev/null +++ b/tests/data/technoserve-cashew-benin/data.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +import rasterio +from rasterio.crs import CRS +from rasterio.transform import Affine + +DTYPE = np.uint16 +SIZE = 2 + +np.random.seed(0) + +dates = ('00_20191105',) +all_bands = ( + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', + 'CLD', +) +profile = { + 'driver': 'GTiff', + 'dtype': DTYPE, + 'width': SIZE, + 'height': SIZE, + 'count': 1, + 'crs': CRS.from_epsg(32631), + 'transform': Affine( + 10.002549584378608, + 0.0, + 440853.29890114715, + 0.0, + -9.99842989423825, + 1012804.082877621, + ), +} + +for date in dates: + os.makedirs(os.path.join('imagery', '00', date), exist_ok=True) + for band in all_bands: + Z = np.random.randint(np.iinfo(DTYPE).max, size=(SIZE, SIZE), dtype=DTYPE) + path = os.path.join('imagery', '00', date, f'{date}_{band}_10m.tif') + with rasterio.open(path, 'w', **profile) as src: + src.write(Z, 1) diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B01_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B01_10m.tif new file mode 100644 index 00000000000..e459a3a490d Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B01_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B02_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B02_10m.tif new file mode 100644 index 00000000000..fd4ca7ce56b Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B02_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B03_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B03_10m.tif new file mode 100644 index 00000000000..33b458bef62 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B03_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B04_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B04_10m.tif new file mode 100644 index 00000000000..76ca0fbd89d Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B04_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B05_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B05_10m.tif new file mode 100644 index 00000000000..a73de74ec33 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B05_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B06_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B06_10m.tif new file mode 100644 index 00000000000..65d8ef98d17 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B06_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B07_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B07_10m.tif new file mode 100644 index 00000000000..558bbd08853 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B07_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B08_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B08_10m.tif new file mode 100644 index 00000000000..532a7d37cf6 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B08_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B09_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B09_10m.tif new file mode 100644 index 00000000000..7111bb5dbba Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B09_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B11_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B11_10m.tif new file mode 100644 index 00000000000..68106c1669b Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B11_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B12_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B12_10m.tif new file mode 100644 index 00000000000..4ea3767ce4c Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B12_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B8A_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B8A_10m.tif new file mode 100644 index 00000000000..6f7df54f0b3 Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_B8A_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_CLD_10m.tif b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_CLD_10m.tif new file mode 100644 index 00000000000..41f05d9ba2f Binary files /dev/null and b/tests/data/technoserve-cashew-benin/imagery/00/00_20191105/00_20191105_CLD_10m.tif differ diff --git a/tests/data/technoserve-cashew-benin/labels/00.geojson b/tests/data/technoserve-cashew-benin/labels/00.geojson new file mode 100644 index 00000000000..ba92dace006 --- /dev/null +++ b/tests/data/technoserve-cashew-benin/labels/00.geojson @@ -0,0 +1,8 @@ +{ +"type": "FeatureCollection", +"name": "cashew_benin", +"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:EPSG::32631" } }, +"features": [ +{ "type": "Feature", "properties": { "OBJECTID": 1, "class": 1, "Shape_Leng": 367629.52331100003, "Shape_Area": 16997542.377500001, "class_name": "Well-managed planatation" }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 447131.214800000190735, 1001286.6359 ], [ 447166.272199999541044, 1001285.31299999915 ], [ 447196.037899999879301, 1001285.31299999915 ], [ 447244.324400000274181, 1001283.3286 ], [ 447256.230700000189245, 1001282.00569999963 ], [ 447253.58490000013262, 1001248.9327 ], [ 447254.907800000160933, 1001228.4275 ], [ 447252.923700000159442, 1001212.05179999955 ], [ 447087.555899999978, 1001212.333799999207 ], [ 447082.266800000332296, 1001241.656599999988 ], [ 447076.97510000038892, 1001256.208799999207 ], [ 447074.329300000332296, 1001286.6359 ], [ 447131.214800000190735, 1001286.6359 ] ] ] } } +] +} diff --git a/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif new file mode 100644 index 00000000000..7df6abd74c0 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif new file mode 100644 index 00000000000..967876aa69c Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Alnus_spec._5_13114_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Alnus_spec._5_13114_WEFL_NLF.tif new file mode 100644 index 00000000000..36c6c049001 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Alnus_spec._5_13114_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif new file mode 100644 index 00000000000..48b36565180 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Picea_abies_2_46896_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Picea_abies_2_46896_WEFL_NLF.tif new file mode 100644 index 00000000000..3fb7ec2f4b7 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Picea_abies_2_46896_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Picea_abies_3_46636_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Picea_abies_3_46636_WEFL_NLF.tif new file mode 100644 index 00000000000..fc3913038bf Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Picea_abies_3_46636_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif b/tests/data/treesatai/aerial/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif new file mode 100644 index 00000000000..2fdb09a25a2 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Quercus_petraea_2_84375_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Quercus_petraea_2_84375_WEFL_NLF.tif new file mode 100644 index 00000000000..f7e0af9eb85 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Quercus_petraea_2_84375_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Quercus_petraea_5_80549_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Quercus_petraea_5_80549_WEFL_NLF.tif new file mode 100644 index 00000000000..52889605c84 Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Quercus_petraea_5_80549_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial/60m/Quercus_rubra_1_92184_WEFL_NLF.tif b/tests/data/treesatai/aerial/60m/Quercus_rubra_1_92184_WEFL_NLF.tif new file mode 100644 index 00000000000..cffd19dbffe Binary files /dev/null and b/tests/data/treesatai/aerial/60m/Quercus_rubra_1_92184_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/aerial_60m_acer_pseudoplatanus.zip b/tests/data/treesatai/aerial_60m_acer_pseudoplatanus.zip new file mode 100644 index 00000000000..b24e8514895 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_acer_pseudoplatanus.zip differ diff --git a/tests/data/treesatai/aerial_60m_alnus_spec.zip b/tests/data/treesatai/aerial_60m_alnus_spec.zip new file mode 100644 index 00000000000..15cb0ecb3e2 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_alnus_spec.zip differ diff --git a/tests/data/treesatai/aerial_60m_fagus_sylvatica.zip b/tests/data/treesatai/aerial_60m_fagus_sylvatica.zip new file mode 100644 index 00000000000..42716c30c93 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_fagus_sylvatica.zip differ diff --git a/tests/data/treesatai/aerial_60m_picea_abies.zip b/tests/data/treesatai/aerial_60m_picea_abies.zip new file mode 100644 index 00000000000..33baaf54215 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_picea_abies.zip differ diff --git a/tests/data/treesatai/aerial_60m_pseudotsuga_menziesii.zip b/tests/data/treesatai/aerial_60m_pseudotsuga_menziesii.zip new file mode 100644 index 00000000000..23a3636a759 Binary files /dev/null and b/tests/data/treesatai/aerial_60m_pseudotsuga_menziesii.zip differ diff --git a/tests/data/treesatai/aerial_60m_quercus_petraea.zip b/tests/data/treesatai/aerial_60m_quercus_petraea.zip new file mode 100644 index 00000000000..268ee1134ac Binary files /dev/null and b/tests/data/treesatai/aerial_60m_quercus_petraea.zip differ diff --git a/tests/data/treesatai/aerial_60m_quercus_rubra.zip b/tests/data/treesatai/aerial_60m_quercus_rubra.zip new file mode 100644 index 00000000000..4552c6fc66c Binary files /dev/null and b/tests/data/treesatai/aerial_60m_quercus_rubra.zip differ diff --git a/tests/data/treesatai/data.py b/tests/data/treesatai/data.py new file mode 100755 index 00000000000..dac5337cff8 --- /dev/null +++ b/tests/data/treesatai/data.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import json +import os +import random +import shutil +import zipfile + +import numpy as np +import rasterio +from rasterio import Affine +from rasterio.crs import CRS + +SIZE = 32 + +random.seed(0) +np.random.seed(0) + +classes = ( + 'Abies', + 'Acer', + 'Alnus', + 'Betula', + 'Cleared', + 'Fagus', + 'Fraxinus', + 'Larix', + 'Picea', + 'Pinus', + 'Populus', + 'Prunus', + 'Pseudotsuga', + 'Quercus', + 'Tilia', +) + +species = ( + 'Acer_pseudoplatanus', + 'Alnus_spec', + 'Fagus_sylvatica', + 'Picea_abies', + 'Pseudotsuga_menziesii', + 'Quercus_petraea', + 'Quercus_rubra', +) + +profile = { + 'aerial': { + 'driver': 'GTiff', + 'dtype': 'uint8', + 'nodata': None, + 'width': SIZE, + 'height': SIZE, + 'count': 4, + 'crs': CRS.from_epsg(25832), + 'transform': Affine( + 0.19999999999977022, 0.0, 552245.4, 0.0, -0.19999999999938728, 5728215.0 + ), + }, + 's1': { + 'driver': 'GTiff', + 'dtype': 'float32', + 'nodata': -9999.0, + 'width': SIZE // 16, + 'height': SIZE // 16, + 'count': 3, + 'crs': CRS.from_epsg(32632), + 'transform': Affine(10.0, 0.0, 552245.0, 0.0, -10.0, 5728215.0), + }, + 's2': { + 'driver': 'GTiff', + 'dtype': 'uint16', + 'nodata': None, + 'width': SIZE // 16, + 'height': SIZE // 16, + 'count': 12, + 'crs': CRS.from_epsg(32632), + 'transform': Affine(10.0, 0.0, 552241.6565, 0.0, -10.0, 5728211.6251), + }, +} + +multi_labels = {} +for split in ['train', 'test']: + with open(f'{split}_filenames.lst') as f: + for filename in f: + filename = filename.strip() + for sensor in ['aerial', 's1', 's2']: + kwargs = profile[sensor] + directory = os.path.join(sensor, '60m') + os.makedirs(directory, exist_ok=True) + if 'int' in kwargs['dtype']: + Z = np.random.randint( + np.iinfo(kwargs['dtype']).min, + np.iinfo(kwargs['dtype']).max, + size=(kwargs['height'], kwargs['width']), + dtype=kwargs['dtype'], + ) + else: + Z = np.random.rand(kwargs['height'], kwargs['width']) + + path = os.path.join(directory, filename) + with rasterio.open(path, 'w', **kwargs) as src: + for i in range(1, kwargs['count'] + 1): + src.write(Z, i) + + k = random.randrange(1, 4) + labels = random.choices(classes, k=k) + pcts = np.random.rand(k) + pcts /= np.sum(pcts) + multi_labels[filename] = list(map(list, zip(labels, map(float, pcts)))) + +os.makedirs('labels', exist_ok=True) +path = os.path.join('labels', 'TreeSatBA_v9_60m_multi_labels.json') +with open(path, 'w') as f: + json.dump(multi_labels, f) + +for sensor in ['s1', 's2', 'labels']: + shutil.make_archive(sensor, 'zip', '.', sensor) + +for spec in species: + path = f'aerial_60m_{spec}.zip'.lower() + with zipfile.ZipFile(path, 'w') as f: + for path in glob.iglob(os.path.join('aerial', '60m', f'{spec}_*.tif')): + filename = os.path.split(path)[-1] + f.write(path, arcname=filename) diff --git a/tests/data/treesatai/labels.zip b/tests/data/treesatai/labels.zip new file mode 100644 index 00000000000..24a773a5ef5 Binary files /dev/null and b/tests/data/treesatai/labels.zip differ diff --git a/tests/data/treesatai/labels/TreeSatBA_v9_60m_multi_labels.json b/tests/data/treesatai/labels/TreeSatBA_v9_60m_multi_labels.json new file mode 100644 index 00000000000..e9f9a12a37b --- /dev/null +++ b/tests/data/treesatai/labels/TreeSatBA_v9_60m_multi_labels.json @@ -0,0 +1 @@ +{"Picea_abies_3_46636_WEFL_NLF.tif": [["Prunus", 0.20692122963708826], ["Fraxinus", 0.7930787703629117]], "Pseudotsuga_menziesii_1_339575_BI_NLF.tif": [["Tilia", 0.4243067837573989], ["Larix", 0.5756932162426011]], "Quercus_rubra_1_92184_WEFL_NLF.tif": [["Tilia", 0.5816157697641007], ["Fagus", 0.4183842302358993]], "Fagus_sylvatica_9_29995_WEFL_NLF.tif": [["Larix", 1.0]], "Quercus_petraea_5_80549_WEFL_NLF.tif": [["Alnus", 0.5749721529276662], ["Acer", 0.4250278470723338]], "Acer_pseudoplatanus_3_5758_WEFL_NLF.tif": [["Tilia", 0.8430361090251272], ["Larix", 0.1569638909748729]], "Alnus_spec._5_13114_WEFL_NLF.tif": [["Pseudotsuga", 0.17881149698366108], ["Quercus", 0.38732907538618866], ["Cleared", 0.4338594276301503]], "Quercus_petraea_2_84375_WEFL_NLF.tif": [["Acer", 0.3909090505343164], ["Pseudotsuga", 0.2628926194326892], ["Cleared", 0.34619833003299444]], "Picea_abies_2_46896_WEFL_NLF.tif": [["Acer", 0.4953810312272686], ["Fraxinus", 0.0006659055704136941], ["Pinus", 0.5039530632023177]], "Acer_pseudoplatanus_4_6058_WEFL_NLF.tif": [["Tilia", 1.0]]} \ No newline at end of file diff --git a/tests/data/treesatai/s1.zip b/tests/data/treesatai/s1.zip new file mode 100644 index 00000000000..052d0dc5553 Binary files /dev/null and b/tests/data/treesatai/s1.zip differ diff --git a/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif new file mode 100644 index 00000000000..e3180fbed8e Binary files /dev/null and b/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif new file mode 100644 index 00000000000..0d8403f3f3b Binary files /dev/null and b/tests/data/treesatai/s1/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Alnus_spec._5_13114_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Alnus_spec._5_13114_WEFL_NLF.tif new file mode 100644 index 00000000000..5f73542d330 Binary files /dev/null and b/tests/data/treesatai/s1/60m/Alnus_spec._5_13114_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif new file mode 100644 index 00000000000..343126b9235 Binary files /dev/null and b/tests/data/treesatai/s1/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Picea_abies_2_46896_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Picea_abies_2_46896_WEFL_NLF.tif new file mode 100644 index 00000000000..b15947f122c Binary files /dev/null and b/tests/data/treesatai/s1/60m/Picea_abies_2_46896_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Picea_abies_3_46636_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Picea_abies_3_46636_WEFL_NLF.tif new file mode 100644 index 00000000000..c9878414adf Binary files /dev/null and b/tests/data/treesatai/s1/60m/Picea_abies_3_46636_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif b/tests/data/treesatai/s1/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif new file mode 100644 index 00000000000..00ba9b03129 Binary files /dev/null and b/tests/data/treesatai/s1/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Quercus_petraea_2_84375_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Quercus_petraea_2_84375_WEFL_NLF.tif new file mode 100644 index 00000000000..2e4898fb55d Binary files /dev/null and b/tests/data/treesatai/s1/60m/Quercus_petraea_2_84375_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Quercus_petraea_5_80549_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Quercus_petraea_5_80549_WEFL_NLF.tif new file mode 100644 index 00000000000..0562717348c Binary files /dev/null and b/tests/data/treesatai/s1/60m/Quercus_petraea_5_80549_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s1/60m/Quercus_rubra_1_92184_WEFL_NLF.tif b/tests/data/treesatai/s1/60m/Quercus_rubra_1_92184_WEFL_NLF.tif new file mode 100644 index 00000000000..db825c3ff27 Binary files /dev/null and b/tests/data/treesatai/s1/60m/Quercus_rubra_1_92184_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2.zip b/tests/data/treesatai/s2.zip new file mode 100644 index 00000000000..eb5dabc8c98 Binary files /dev/null and b/tests/data/treesatai/s2.zip differ diff --git a/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif new file mode 100644 index 00000000000..9d182f62584 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_3_5758_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif new file mode 100644 index 00000000000..d61c7b7a20b Binary files /dev/null and b/tests/data/treesatai/s2/60m/Acer_pseudoplatanus_4_6058_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Alnus_spec._5_13114_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Alnus_spec._5_13114_WEFL_NLF.tif new file mode 100644 index 00000000000..660f23905de Binary files /dev/null and b/tests/data/treesatai/s2/60m/Alnus_spec._5_13114_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif new file mode 100644 index 00000000000..bf8c659fb45 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Fagus_sylvatica_9_29995_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Picea_abies_2_46896_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Picea_abies_2_46896_WEFL_NLF.tif new file mode 100644 index 00000000000..7bd25b4c837 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Picea_abies_2_46896_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Picea_abies_3_46636_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Picea_abies_3_46636_WEFL_NLF.tif new file mode 100644 index 00000000000..b62e8364578 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Picea_abies_3_46636_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif b/tests/data/treesatai/s2/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif new file mode 100644 index 00000000000..938c8528c28 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Pseudotsuga_menziesii_1_339575_BI_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Quercus_petraea_2_84375_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Quercus_petraea_2_84375_WEFL_NLF.tif new file mode 100644 index 00000000000..69603a72ae3 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Quercus_petraea_2_84375_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Quercus_petraea_5_80549_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Quercus_petraea_5_80549_WEFL_NLF.tif new file mode 100644 index 00000000000..affe18983a6 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Quercus_petraea_5_80549_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/s2/60m/Quercus_rubra_1_92184_WEFL_NLF.tif b/tests/data/treesatai/s2/60m/Quercus_rubra_1_92184_WEFL_NLF.tif new file mode 100644 index 00000000000..ccd44d2b692 Binary files /dev/null and b/tests/data/treesatai/s2/60m/Quercus_rubra_1_92184_WEFL_NLF.tif differ diff --git a/tests/data/treesatai/test_filenames.lst b/tests/data/treesatai/test_filenames.lst new file mode 100644 index 00000000000..9d81989c444 --- /dev/null +++ b/tests/data/treesatai/test_filenames.lst @@ -0,0 +1 @@ +Acer_pseudoplatanus_4_6058_WEFL_NLF.tif diff --git a/tests/data/treesatai/train_filenames.lst b/tests/data/treesatai/train_filenames.lst new file mode 100644 index 00000000000..9a92169b832 --- /dev/null +++ b/tests/data/treesatai/train_filenames.lst @@ -0,0 +1,9 @@ +Picea_abies_3_46636_WEFL_NLF.tif +Pseudotsuga_menziesii_1_339575_BI_NLF.tif +Quercus_rubra_1_92184_WEFL_NLF.tif +Fagus_sylvatica_9_29995_WEFL_NLF.tif +Quercus_petraea_5_80549_WEFL_NLF.tif +Acer_pseudoplatanus_3_5758_WEFL_NLF.tif +Alnus_spec._5_13114_WEFL_NLF.tif +Quercus_petraea_2_84375_WEFL_NLF.tif +Picea_abies_2_46896_WEFL_NLF.tif diff --git a/tests/data/ts_cashew_benin/ts_cashew_benin_labels.tar.gz b/tests/data/ts_cashew_benin/ts_cashew_benin_labels.tar.gz deleted file mode 100644 index 5a9d7d22a18..00000000000 Binary files a/tests/data/ts_cashew_benin/ts_cashew_benin_labels.tar.gz and /dev/null differ diff --git a/tests/data/ts_cashew_benin/ts_cashew_benin_source.tar.gz b/tests/data/ts_cashew_benin/ts_cashew_benin_source.tar.gz deleted file mode 100644 index 1e94b5526b0..00000000000 Binary files a/tests/data/ts_cashew_benin/ts_cashew_benin_source.tar.gz and /dev/null differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset.rar b/tests/data/vhr10/NWPU VHR-10 dataset.rar deleted file mode 100644 index 0f836ac8e17..00000000000 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset.rar and /dev/null differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset.zip b/tests/data/vhr10/NWPU VHR-10 dataset.zip new file mode 100644 index 00000000000..e8e722caa13 Binary files /dev/null and b/tests/data/vhr10/NWPU VHR-10 dataset.zip differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/001.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/001.jpg index 46177fa63c7..137d73e35c2 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/001.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/001.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/002.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/002.jpg index 4325ad6a369..15f488eeb4a 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/002.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/002.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/003.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/003.jpg index 90e1622ff16..473bd3114ab 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/003.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/003.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/004.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/004.jpg index 15795298cc6..6896e11c67a 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/004.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/004.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/005.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/005.jpg index 542306c1bf2..bf61c978774 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/005.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/negative image set/005.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/001.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/001.jpg index 46177fa63c7..137d73e35c2 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/001.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/001.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/002.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/002.jpg index 4325ad6a369..15f488eeb4a 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/002.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/002.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/003.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/003.jpg index 90e1622ff16..473bd3114ab 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/003.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/003.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/004.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/004.jpg index 15795298cc6..6896e11c67a 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/004.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/004.jpg differ diff --git a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/005.jpg b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/005.jpg index 542306c1bf2..bf61c978774 100644 Binary files a/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/005.jpg and b/tests/data/vhr10/NWPU VHR-10 dataset/positive image set/005.jpg differ diff --git a/tests/data/vhr10/data.py b/tests/data/vhr10/data.py index 44e60966c3b..25aa672c214 100755 --- a/tests/data/vhr10/data.py +++ b/tests/data/vhr10/data.py @@ -4,13 +4,13 @@ import json import os import shutil -import subprocess import numpy as np from PIL import Image from torchvision.datasets.utils import calculate_md5 ANNOTATION_FILE = {'images': [], 'annotations': []} +DIRECTORY = 'NWPU VHR-10 dataset' def write_data(path: str, img: np.ndarray) -> None: @@ -20,7 +20,7 @@ def write_data(path: str, img: np.ndarray) -> None: def generate_test_data(root: str, n_imgs: int = 3) -> str: - folder_path = os.path.join(root, 'NWPU VHR-10 dataset') + folder_path = os.path.join(root, DIRECTORY) pos_img_dir = os.path.join(folder_path, 'positive image set') neg_img_dir = os.path.join(folder_path, 'negative image set') ann_file = os.path.join(folder_path, 'annotations.json') @@ -65,16 +65,9 @@ def generate_test_data(root: str, n_imgs: int = 3) -> str: with open(ann_file2, 'w') as j: json.dump(ANNOTATION_FILE, j) - # Create rar file - subprocess.run( - ['rar', 'a', 'NWPU VHR-10 dataset.rar', '-m5', 'NWPU VHR-10 dataset'], - capture_output=True, - check=True, - ) - + shutil.make_archive(DIRECTORY, 'zip', '.', DIRECTORY) annotations_md5 = calculate_md5(ann_file) - archive_md5 = calculate_md5('NWPU VHR-10 dataset.rar') - shutil.rmtree(folder_path) + archive_md5 = calculate_md5(f'{DIRECTORY}.zip') return f'archive md5: {archive_md5}, annotation md5: {annotations_md5}' diff --git a/tests/data/western_usa_live_fuel_moisture/data.py b/tests/data/western_usa_live_fuel_moisture/data.py index 44fc8717b47..6cad40a7c78 100755 --- a/tests/data/western_usa_live_fuel_moisture/data.py +++ b/tests/data/western_usa_live_fuel_moisture/data.py @@ -3,10 +3,8 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import hashlib import json import os -import shutil NUM_SAMPLES = 3 @@ -159,65 +157,9 @@ 'geometry': {'type': 'Point', 'coordinates': [-115.8855556, 42.44111111]}, } -STAC = { - 'assets': { - 'documentation': { - 'href': '../_common/documentation.pdf', - 'type': 'application/pdf', - }, - 'labels': {'href': 'labels.geojson', 'type': 'application/geo+json'}, - 'training_features_descriptions': { - 'href': '../_common/training_features_descriptions.csv', - 'title': 'Training Features Descriptions', - 'type': 'text/csv', - }, - }, - 'bbox': [-115.8855556, 42.44111111, -115.8855556, 42.44111111], - 'collection': 'su_sar_moisture_content', - 'geometry': {'coordinates': [-115.8855556, 42.44111111], 'type': 'Point'}, - 'id': 'su_sar_moisture_content_0001', - 'links': [ - {'href': '../collection.json', 'rel': 'collection'}, - {'href': '../collection.json', 'rel': 'parent'}, - ], - 'properties': { - 'datetime': '2015-06-30T00:00:00Z', - 'label:description': '', - 'label:properties': ['percent(t)'], - 'label:type': 'vector', - }, - 'stac_extensions': ['label'], - 'stac_version': '1.0.0-beta.2', - 'type': 'Feature', -} - -def create_file(path: str) -> None: - label_path = os.path.join(path, 'labels.geojson') - with open(label_path, 'w') as f: +os.makedirs(data_dir, exist_ok=True) +for i in range(1, NUM_SAMPLES + 1): + filename = os.path.join(data_dir, f'feature_{i:04}.geojson') + with open(filename, 'w') as f: json.dump(LABELS, f) - - stac_path = os.path.join(path, 'stac.json') - with open(stac_path, 'w') as f: - json.dump(STAC, f) - - -if __name__ == '__main__': - # Remove old data - if os.path.isdir(data_dir): - shutil.rmtree(data_dir) - - os.makedirs(os.path.join(os.getcwd(), data_dir)) - - for i in range(NUM_SAMPLES): - sample_dir = os.path.join(data_dir, data_dir + f'_{i}') - os.makedirs(sample_dir) - create_file(sample_dir) - - # Compress data - shutil.make_archive(data_dir, 'gztar', '.', data_dir) - - # Compute checksums - with open(data_dir + '.tar.gz', 'rb') as f: - md5 = hashlib.md5(f.read()).hexdigest() - print(f'{data_dir}.tar.gz: {md5}') diff --git a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content.tar.gz b/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content.tar.gz deleted file mode 100644 index fe38d993d67..00000000000 Binary files a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content.tar.gz and /dev/null differ diff --git a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_0/labels.geojson b/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/feature_0001.geojson similarity index 100% rename from tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_0/labels.geojson rename to tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/feature_0001.geojson diff --git a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_1/labels.geojson b/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/feature_0002.geojson similarity index 100% rename from tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_1/labels.geojson rename to tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/feature_0002.geojson diff --git a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_2/labels.geojson b/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/feature_0003.geojson similarity index 100% rename from tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_2/labels.geojson rename to tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/feature_0003.geojson diff --git a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_0/stac.json b/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_0/stac.json deleted file mode 100644 index 469f98574d9..00000000000 --- a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_0/stac.json +++ /dev/null @@ -1 +0,0 @@ -{"assets": {"documentation": {"href": "../_common/documentation.pdf", "type": "application/pdf"}, "labels": {"href": "labels.geojson", "type": "application/geo+json"}, "training_features_descriptions": {"href": "../_common/training_features_descriptions.csv", "title": "Training Features Descriptions", "type": "text/csv"}}, "bbox": [-115.8855556, 42.44111111, -115.8855556, 42.44111111], "collection": "su_sar_moisture_content", "geometry": {"coordinates": [-115.8855556, 42.44111111], "type": "Point"}, "id": "su_sar_moisture_content_0001", "links": [{"href": "../collection.json", "rel": "collection"}, {"href": "../collection.json", "rel": "parent"}], "properties": {"datetime": "2015-06-30T00:00:00Z", "label:description": "", "label:properties": ["percent(t)"], "label:type": "vector"}, "stac_extensions": ["label"], "stac_version": "1.0.0-beta.2", "type": "Feature"} \ No newline at end of file diff --git a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_1/stac.json b/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_1/stac.json deleted file mode 100644 index 469f98574d9..00000000000 --- a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_1/stac.json +++ /dev/null @@ -1 +0,0 @@ -{"assets": {"documentation": {"href": "../_common/documentation.pdf", "type": "application/pdf"}, "labels": {"href": "labels.geojson", "type": "application/geo+json"}, "training_features_descriptions": {"href": "../_common/training_features_descriptions.csv", "title": "Training Features Descriptions", "type": "text/csv"}}, "bbox": [-115.8855556, 42.44111111, -115.8855556, 42.44111111], "collection": "su_sar_moisture_content", "geometry": {"coordinates": [-115.8855556, 42.44111111], "type": "Point"}, "id": "su_sar_moisture_content_0001", "links": [{"href": "../collection.json", "rel": "collection"}, {"href": "../collection.json", "rel": "parent"}], "properties": {"datetime": "2015-06-30T00:00:00Z", "label:description": "", "label:properties": ["percent(t)"], "label:type": "vector"}, "stac_extensions": ["label"], "stac_version": "1.0.0-beta.2", "type": "Feature"} \ No newline at end of file diff --git a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_2/stac.json b/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_2/stac.json deleted file mode 100644 index 469f98574d9..00000000000 --- a/tests/data/western_usa_live_fuel_moisture/su_sar_moisture_content/su_sar_moisture_content_2/stac.json +++ /dev/null @@ -1 +0,0 @@ -{"assets": {"documentation": {"href": "../_common/documentation.pdf", "type": "application/pdf"}, "labels": {"href": "labels.geojson", "type": "application/geo+json"}, "training_features_descriptions": {"href": "../_common/training_features_descriptions.csv", "title": "Training Features Descriptions", "type": "text/csv"}}, "bbox": [-115.8855556, 42.44111111, -115.8855556, 42.44111111], "collection": "su_sar_moisture_content", "geometry": {"coordinates": [-115.8855556, 42.44111111], "type": "Point"}, "id": "su_sar_moisture_content_0001", "links": [{"href": "../collection.json", "rel": "collection"}, {"href": "../collection.json", "rel": "parent"}], "properties": {"datetime": "2015-06-30T00:00:00Z", "label:description": "", "label:properties": ["percent(t)"], "label:type": "vector"}, "stac_extensions": ["label"], "stac_version": "1.0.0-beta.2", "type": "Feature"} \ No newline at end of file diff --git a/tests/datamodules/test_caffe.py b/tests/datamodules/test_caffe.py new file mode 100644 index 00000000000..3048cdd85fa --- /dev/null +++ b/tests/datamodules/test_caffe.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import matplotlib.pyplot as plt +import pytest + +from torchgeo.datamodules import CaFFeDataModule + + +class TestCaFFeDataModule: + @pytest.fixture + def datamodule(self) -> CaFFeDataModule: + root = os.path.join('tests', 'data', 'caffe') + batch_size = 2 + num_workers = 0 + dm = CaFFeDataModule(root=root, batch_size=batch_size, num_workers=num_workers) + return dm + + def test_train_dataloader(self, datamodule: CaFFeDataModule) -> None: + datamodule.setup('fit') + next(iter(datamodule.train_dataloader())) + + def test_val_dataloader(self, datamodule: CaFFeDataModule) -> None: + datamodule.setup('validate') + next(iter(datamodule.val_dataloader())) + + def test_test_dataloader(self, datamodule: CaFFeDataModule) -> None: + datamodule.setup('test') + next(iter(datamodule.test_dataloader())) + + def test_plot(self, datamodule: CaFFeDataModule) -> None: + datamodule.setup('validate') + batch = next(iter(datamodule.val_dataloader())) + sample = { + 'image': batch['image'][0], + 'mask_zones': batch['mask_zones'][0], + 'mask_front': batch['mask_front'][0], + } + datamodule.plot(sample) + plt.close() diff --git a/tests/datamodules/test_digital_typhoon.py b/tests/datamodules/test_digital_typhoon.py new file mode 100644 index 00000000000..0ecd85f5ec7 --- /dev/null +++ b/tests/datamodules/test_digital_typhoon.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Test Digital Typhoon Datamodule.""" + +import os + +import pytest + +from torchgeo.datamodules import DigitalTyphoonDataModule +from torchgeo.datasets.digital_typhoon import DigitalTyphoon, _SampleSequenceDict + +pytest.importorskip('h5py', minversion='3.6') + + +class TestDigitalTyphoonDataModule: + def test_invalid_param_config(self) -> None: + with pytest.raises(AssertionError, match='Please choose from'): + DigitalTyphoonDataModule( + root=os.path.join('tests', 'data', 'digital_typhoon'), + split_by='invalid', + batch_size=2, + num_workers=0, + ) + + @pytest.mark.parametrize('split_by', ['time', 'typhoon_id']) + def test_split_dataset(self, split_by: str) -> None: + dm = DigitalTyphoonDataModule( + root=os.path.join('tests', 'data', 'digital_typhoon'), + split_by=split_by, + batch_size=2, + num_workers=0, + ) + dataset = DigitalTyphoon(root=os.path.join('tests', 'data', 'digital_typhoon')) + train_indices, val_indices = dm._split_dataset(dataset.sample_sequences) + train_sequences, val_sequences = ( + [dataset.sample_sequences[i] for i in train_indices], + [dataset.sample_sequences[i] for i in val_indices], + ) + + if split_by == 'time': + + def find_max_time_per_id( + split_sequences: list[_SampleSequenceDict], + ) -> dict[str, int]: + # Find the maximum value of each id in train_sequences + max_values: dict[str, int] = {} + for seq in split_sequences: + id: str = str(seq['id']) + value: int = max(seq['seq_id']) + if id not in max_values or value > max_values[id]: + max_values[id] = value + return max_values + + train_max_values = find_max_time_per_id(train_sequences) + val_max_values = find_max_time_per_id(val_sequences) + # Assert that each max value in train_max_values is lower + # than in val_max_values for each key id + for id, max_value in train_max_values.items(): + assert ( + id not in val_max_values or max_value < val_max_values[id] + ), f'Max value for id {id} in train is not lower than in validation.' + else: + train_ids = {seq['id'] for seq in train_sequences} + val_ids = {seq['id'] for seq in val_sequences} + + # Assert that the intersection between train_ids and val_ids is empty + assert ( + len(train_ids & val_ids) == 0 + ), 'Train and validation datasets have overlapping ids.' diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index 8380ce242b8..4e5431c684f 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -31,8 +31,8 @@ def __init__( self.res = 1 def __getitem__(self, query: BoundingBox) -> dict[str, Any]: - image = torch.arange(3 * 2 * 2).view(3, 2, 2) - return {'image': image, 'crs': CRS.from_epsg(4326), 'bbox': query} + image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2) + return {'image': image, 'crs': CRS.from_epsg(4326), 'bounds': query} def plot(self, *args: Any, **kwargs: Any) -> Figure: return plt.figure() @@ -68,7 +68,7 @@ def __init__( self.length = length def __getitem__(self, index: int) -> dict[str, Tensor]: - return {'image': torch.arange(3 * 2 * 2).view(3, 2, 2)} + return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)} def __len__(self) -> int: return self.length diff --git a/tests/datamodules/test_levircd.py b/tests/datamodules/test_levircd.py index 4a026f62caf..ccec9af3c8b 100644 --- a/tests/datamodules/test_levircd.py +++ b/tests/datamodules/test_levircd.py @@ -2,23 +2,14 @@ # Licensed under the MIT License. import os -import shutil -from pathlib import Path import pytest import torchvision.transforms.functional as F from lightning.pytorch import Trainer -from pytest import MonkeyPatch from torch import Tensor from torchvision.transforms import InterpolationMode -import torchgeo.datasets.utils from torchgeo.datamodules import LEVIRCDDataModule, LEVIRCDPlusDataModule -from torchgeo.datasets import LEVIRCD, LEVIRCDPlus - - -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) def transforms(sample: dict[str, Tensor]) -> dict[str, Tensor]: @@ -44,23 +35,10 @@ def transforms(sample: dict[str, Tensor]) -> dict[str, Tensor]: class TestLEVIRCDPlusDataModule: @pytest.fixture - def datamodule( - self, monkeypatch: MonkeyPatch, tmp_path: Path - ) -> LEVIRCDPlusDataModule: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) - md5 = '0ccca34310bfe7096dadfbf05b0d180f' - monkeypatch.setattr(LEVIRCDPlus, 'md5', md5) - url = os.path.join('tests', 'data', 'levircd', 'levircdplus', 'LEVIR-CD+.zip') - monkeypatch.setattr(LEVIRCDPlus, 'url', url) - - root = str(tmp_path) + def datamodule(self) -> LEVIRCDPlusDataModule: + root = os.path.join('tests', 'data', 'levircd', 'levircdplus') dm = LEVIRCDPlusDataModule( - root=root, - download=True, - num_workers=0, - checksum=True, - val_split_pct=0.5, - transforms=transforms, + root=root, num_workers=0, val_split_pct=0.5, transforms=transforms ) dm.prepare_data() dm.trainer = Trainer(accelerator='cpu', max_epochs=1) @@ -113,36 +91,9 @@ def test_test_dataloader(self, datamodule: LEVIRCDPlusDataModule) -> None: class TestLEVIRCDDataModule: @pytest.fixture - def datamodule(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LEVIRCDDataModule: - directory = os.path.join('tests', 'data', 'levircd', 'levircd') - splits = { - 'train': { - 'url': os.path.join(directory, 'train.zip'), - 'filename': 'train.zip', - 'md5': '7c2e24b3072095519f1be7eb01fae4ff', - }, - 'val': { - 'url': os.path.join(directory, 'val.zip'), - 'filename': 'val.zip', - 'md5': '5c320223ba88b6fc8ff9d1feebc3b84e', - }, - 'test': { - 'url': os.path.join(directory, 'test.zip'), - 'filename': 'test.zip', - 'md5': '021db72d4486726d6a0702563a617b32', - }, - } - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) - monkeypatch.setattr(LEVIRCD, 'splits', splits) - - root = str(tmp_path) - dm = LEVIRCDDataModule( - root=root, - download=True, - num_workers=0, - checksum=True, - transforms=transforms, - ) + def datamodule(self) -> LEVIRCDDataModule: + root = os.path.join('tests', 'data', 'levircd', 'levircd') + dm = LEVIRCDDataModule(root=root, num_workers=0, transforms=transforms) dm.prepare_data() dm.trainer = Trainer(accelerator='cpu', max_epochs=1) return dm diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview.py similarity index 100% rename from tests/datamodules/test_xview2.py rename to tests/datamodules/test_xview.py diff --git a/tests/datasets/aws b/tests/datasets/aws new file mode 120000 index 00000000000..b5147b964d7 --- /dev/null +++ b/tests/datasets/aws @@ -0,0 +1 @@ +aws.py \ No newline at end of file diff --git a/tests/datasets/aws.bat b/tests/datasets/aws.bat new file mode 100644 index 00000000000..a5e88609e7f --- /dev/null +++ b/tests/datasets/aws.bat @@ -0,0 +1,6 @@ +REM Copyright (c) Microsoft Corporation. All rights reserved. +REM Licensed under the MIT License. + +@ECHO OFF + +python3 tests\datasets\aws.py %* diff --git a/tests/datasets/aws.py b/tests/datasets/aws.py new file mode 100755 index 00000000000..187dda206bb --- /dev/null +++ b/tests/datasets/aws.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Basic mock-up of the AWS CLI.""" + +import argparse +import shutil + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + s3 = subparsers.add_parser('s3') + subsubparsers = s3.add_subparsers() + cp = subsubparsers.add_parser('cp') + cp.add_argument('source') + cp.add_argument('destination') + args, _ = parser.parse_known_args() + shutil.copy(args.source, args.destination) diff --git a/tests/datasets/azcopy b/tests/datasets/azcopy deleted file mode 100755 index 1f74b4c4d0b..00000000000 --- a/tests/datasets/azcopy +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -"""Basic mock-up of the azcopy CLI.""" - -import argparse -import shutil - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers() - copy = subparsers.add_parser('copy') - copy.add_argument('source') - copy.add_argument('destination') - copy.add_argument('--recursive', default='false') - sync = subparsers.add_parser('sync') - sync.add_argument('source') - sync.add_argument('destination') - sync.add_argument('--recursive', default='true') - args, _ = parser.parse_known_args() - - if args.recursive == 'true': - shutil.copytree(args.source, args.destination, dirs_exist_ok=True) - else: - shutil.copy(args.source, args.destination) diff --git a/tests/datasets/azcopy b/tests/datasets/azcopy new file mode 120000 index 00000000000..2081c776877 --- /dev/null +++ b/tests/datasets/azcopy @@ -0,0 +1 @@ +azcopy.py \ No newline at end of file diff --git a/tests/datasets/azcopy.bat b/tests/datasets/azcopy.bat new file mode 100644 index 00000000000..11ea5c45b2f --- /dev/null +++ b/tests/datasets/azcopy.bat @@ -0,0 +1,6 @@ +REM Copyright (c) Microsoft Corporation. All rights reserved. +REM Licensed under the MIT License. + +@ECHO OFF + +python3 tests\datasets\azcopy.py %* diff --git a/tests/datasets/azcopy.py b/tests/datasets/azcopy.py new file mode 100755 index 00000000000..1f74b4c4d0b --- /dev/null +++ b/tests/datasets/azcopy.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Basic mock-up of the azcopy CLI.""" + +import argparse +import shutil + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + copy = subparsers.add_parser('copy') + copy.add_argument('source') + copy.add_argument('destination') + copy.add_argument('--recursive', default='false') + sync = subparsers.add_parser('sync') + sync.add_argument('source') + sync.add_argument('destination') + sync.add_argument('--recursive', default='true') + args, _ = parser.parse_known_args() + + if args.recursive == 'true': + shutil.copytree(args.source, args.destination, dirs_exist_ok=True) + else: + shutil.copy(args.source, args.destination) diff --git a/tests/datasets/conftest.py b/tests/datasets/conftest.py index 3f59d69581b..82e8e84bb26 100644 --- a/tests/datasets/conftest.py +++ b/tests/datasets/conftest.py @@ -2,11 +2,40 @@ # Licensed under the MIT License. import os +import shutil +from typing import Any import pytest +import torchvision.datasets.utils +from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets.utils import Executable, which +import torchgeo.datasets.utils +from torchgeo.datasets.utils import Executable, Path, which + + +def copy(url: str, root: Path, *args: Any, **kwargs: Any) -> None: + os.makedirs(root, exist_ok=True) + shutil.copy(url, root) + + +@pytest.fixture(autouse=True) +def download_url(monkeypatch: MonkeyPatch, request: SubRequest) -> None: + monkeypatch.setattr(torchvision.datasets.utils, 'download_url', copy) + monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', copy) + _, filename = os.path.split(request.path) + module = filename[5:-3] + try: + monkeypatch.setattr(f'torchgeo.datasets.{module}.download_url', copy) + except AttributeError: + pass + + +@pytest.fixture +def aws(monkeypatch: MonkeyPatch) -> Executable: + path = os.path.dirname(os.path.realpath(__file__)) + monkeypatch.setenv('PATH', path, prepend=os.pathsep) + return which('aws') @pytest.fixture diff --git a/tests/datasets/test_advance.py b/tests/datasets/test_advance.py index f2a34b89f4c..12d20c0fd76 100644 --- a/tests/datasets/test_advance.py +++ b/tests/datasets/test_advance.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -11,20 +10,14 @@ import torch.nn as nn from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import ADVANCE, DatasetNotFoundError pytest.importorskip('scipy', minversion='1.7.2') -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestADVANCE: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'advance') urls = [ os.path.join(data_dir, 'ADVANCE_vision.zip'), @@ -33,7 +26,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ADVANCE: md5s = ['43acacecebecd17a82bc2c1e719fd7e4', '039b7baa47879a8a4e32b9dd8287f6ad'] monkeypatch.setattr(ADVANCE, 'urls', urls) monkeypatch.setattr(ADVANCE, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ADVANCE(root, transforms, download=True, checksum=True) @@ -57,7 +50,7 @@ def test_already_downloaded(self, dataset: ADVANCE) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ADVANCE(str(tmp_path)) + ADVANCE(tmp_path) def test_plot(self, dataset: ADVANCE) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_agb_live_woody_density.py b/tests/datasets/test_agb_live_woody_density.py index 3b4c0636c1d..a3991ccd410 100644 --- a/tests/datasets/test_agb_live_woody_density.py +++ b/tests/datasets/test_agb_live_woody_density.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,7 +11,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo from torchgeo.datasets import ( AbovegroundLiveWoodyBiomassDensity, DatasetNotFoundError, @@ -21,19 +19,12 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestAbovegroundLiveWoodyBiomassDensity: @pytest.fixture def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> AbovegroundLiveWoodyBiomassDensity: transforms = nn.Identity() - monkeypatch.setattr( - torchgeo.datasets.agb_live_woody_density, 'download_url', download_url - ) url = os.path.join( 'tests', 'data', @@ -42,7 +33,7 @@ def dataset( ) monkeypatch.setattr(AbovegroundLiveWoodyBiomassDensity, 'url', url) - root = str(tmp_path) + root = tmp_path return AbovegroundLiveWoodyBiomassDensity( root, transforms=transforms, download=True ) @@ -58,7 +49,7 @@ def test_len(self, dataset: AbovegroundLiveWoodyBiomassDensity) -> None: def test_no_dataset(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - AbovegroundLiveWoodyBiomassDensity(str(tmp_path)) + AbovegroundLiveWoodyBiomassDensity(tmp_path) def test_already_downloaded( self, dataset: AbovegroundLiveWoodyBiomassDensity diff --git a/tests/datasets/test_agrifieldnet.py b/tests/datasets/test_agrifieldnet.py index 6608dc7a1bb..a9857bdeb27 100644 --- a/tests/datasets/test_agrifieldnet.py +++ b/tests/datasets/test_agrifieldnet.py @@ -8,6 +8,7 @@ import pytest import torch import torch.nn as nn +from pytest import MonkeyPatch from rasterio.crs import CRS from torchgeo.datasets import ( @@ -18,14 +19,18 @@ RGBBandsMissingError, UnionDataset, ) +from torchgeo.datasets.utils import Executable class TestAgriFieldNet: @pytest.fixture - def dataset(self) -> AgriFieldNet: - path = os.path.join('tests', 'data', 'agrifieldnet') + def dataset( + self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> AgriFieldNet: + url = os.path.join('tests', 'data', 'agrifieldnet') + monkeypatch.setattr(AgriFieldNet, 'url', url) transforms = nn.Identity() - return AgriFieldNet(paths=path, transforms=transforms) + return AgriFieldNet(tmp_path, transforms=transforms, download=True) def test_getitem(self, dataset: AgriFieldNet) -> None: x = dataset[dataset.bounds] @@ -50,7 +55,7 @@ def test_already_downloaded(self, dataset: AgriFieldNet) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - AgriFieldNet(str(tmp_path)) + AgriFieldNet(tmp_path) def test_plot(self, dataset: AgriFieldNet) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_airphen.py b/tests/datasets/test_airphen.py index 3c60fb090f7..9b3618f9e8b 100644 --- a/tests/datasets/test_airphen.py +++ b/tests/datasets/test_airphen.py @@ -52,7 +52,7 @@ def test_plot(self, dataset: Airphen) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Airphen(str(tmp_path)) + Airphen(tmp_path) def test_invalid_query(self, dataset: Airphen) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_astergdem.py b/tests/datasets/test_astergdem.py index 7f1aeaa4cd6..abcf822eb4e 100644 --- a/tests/datasets/test_astergdem.py +++ b/tests/datasets/test_astergdem.py @@ -25,7 +25,7 @@ class TestAsterGDEM: def dataset(self, tmp_path: Path) -> AsterGDEM: zipfile = os.path.join('tests', 'data', 'astergdem', 'astergdem.zip') shutil.unpack_archive(zipfile, tmp_path, 'zip') - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return AsterGDEM(root, transforms=transforms) @@ -33,7 +33,7 @@ def test_datasetmissing(self, tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - AsterGDEM(str(tmp_path)) + AsterGDEM(tmp_path) def test_getitem(self, dataset: AsterGDEM) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_benin_cashews.py b/tests/datasets/test_benin_cashews.py index 1e960527a84..795e5c6d521 100644 --- a/tests/datasets/test_benin_cashews.py +++ b/tests/datasets/test_benin_cashews.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -18,44 +16,22 @@ DatasetNotFoundError, RGBBandsMissingError, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join('tests', 'data', 'ts_cashew_benin', '*.tar.gz') - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestBeninSmallHolderCashews: @pytest.fixture def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path + self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path ) -> BeninSmallHolderCashews: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - source_md5 = '255efff0f03bc6322470949a09bc76db' - labels_md5 = 'ed2195d93ca6822d48eb02bc3e81c127' - monkeypatch.setitem(BeninSmallHolderCashews.image_meta, 'md5', source_md5) - monkeypatch.setitem(BeninSmallHolderCashews.target_meta, 'md5', labels_md5) - monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('2019_11_05',)) - root = str(tmp_path) + url = os.path.join('tests', 'data', 'technoserve-cashew-benin') + monkeypatch.setattr(BeninSmallHolderCashews, 'url', url) + monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('20191105',)) + monkeypatch.setattr(BeninSmallHolderCashews, 'tile_height', 2) + monkeypatch.setattr(BeninSmallHolderCashews, 'tile_width', 2) + root = tmp_path transforms = nn.Identity() - bands = BeninSmallHolderCashews.all_bands - - return BeninSmallHolderCashews( - root, - transforms=transforms, - bands=bands, - download=True, - api_key='', - checksum=True, - verbose=True, - ) + return BeninSmallHolderCashews(root, transforms=transforms, download=True) def test_getitem(self, dataset: BeninSmallHolderCashews) -> None: x = dataset[0] @@ -66,25 +42,22 @@ def test_getitem(self, dataset: BeninSmallHolderCashews) -> None: assert isinstance(x['y'], torch.Tensor) def test_len(self, dataset: BeninSmallHolderCashews) -> None: - assert len(dataset) == 72 + assert len(dataset) == 1 def test_add(self, dataset: BeninSmallHolderCashews) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 144 + assert len(ds) == 2 def test_already_downloaded(self, dataset: BeninSmallHolderCashews) -> None: - BeninSmallHolderCashews(root=dataset.root, download=True, api_key='') + BeninSmallHolderCashews(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - BeninSmallHolderCashews(str(tmp_path)) + BeninSmallHolderCashews(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - BeninSmallHolderCashews(bands=['B01', 'B02']) # type: ignore[arg-type] - - with pytest.raises(ValueError, match='is an invalid band name.'): BeninSmallHolderCashews(bands=('foo', 'bar')) def test_plot(self, dataset: BeninSmallHolderCashews) -> None: diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 82a3655626f..c93240e21f1 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -12,14 +12,9 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import BigEarthNet, DatasetNotFoundError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestBigEarthNet: @pytest.fixture( params=zip(['all', 's1', 's2'], [43, 19, 19], ['train', 'val', 'test']) @@ -27,7 +22,6 @@ class TestBigEarthNet: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> BigEarthNet: - monkeypatch.setattr(torchgeo.datasets.bigearthnet, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'bigearthnet') metadata = { 's1': { @@ -63,7 +57,7 @@ def dataset( monkeypatch.setattr(BigEarthNet, 'metadata', metadata) monkeypatch.setattr(BigEarthNet, 'splits_metadata', splits_metadata) bands, num_classes, split = request.param - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return BigEarthNet( root, split, bands, num_classes, transforms, download=True, checksum=True @@ -95,7 +89,7 @@ def test_len(self, dataset: BigEarthNet) -> None: def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None: BigEarthNet( - root=str(tmp_path), + root=tmp_path, bands=dataset.bands, split=dataset.split, num_classes=dataset.num_classes, @@ -112,21 +106,21 @@ def test_already_downloaded_not_extracted( shutil.rmtree( os.path.join(dataset.root, dataset.metadata['s2']['directory']) ) - download_url(dataset.metadata['s1']['url'], root=str(tmp_path)) - download_url(dataset.metadata['s2']['url'], root=str(tmp_path)) + shutil.copy(dataset.metadata['s1']['url'], tmp_path) + shutil.copy(dataset.metadata['s2']['url'], tmp_path) elif dataset.bands == 's1': shutil.rmtree( os.path.join(dataset.root, dataset.metadata['s1']['directory']) ) - download_url(dataset.metadata['s1']['url'], root=str(tmp_path)) + shutil.copy(dataset.metadata['s1']['url'], tmp_path) else: shutil.rmtree( os.path.join(dataset.root, dataset.metadata['s2']['directory']) ) - download_url(dataset.metadata['s2']['url'], root=str(tmp_path)) + shutil.copy(dataset.metadata['s2']['url'], tmp_path) BigEarthNet( - root=str(tmp_path), + root=tmp_path, bands=dataset.bands, split=dataset.split, num_classes=dataset.num_classes, @@ -135,7 +129,7 @@ def test_already_downloaded_not_extracted( def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - BigEarthNet(str(tmp_path)) + BigEarthNet(tmp_path) def test_plot(self, dataset: BigEarthNet) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_biomassters.py b/tests/datasets/test_biomassters.py index 8a853145da8..f9ea246ae73 100644 --- a/tests/datasets/test_biomassters.py +++ b/tests/datasets/test_biomassters.py @@ -37,7 +37,7 @@ def test_invalid_bands(self, dataset: BioMassters) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - BioMassters(str(tmp_path)) + BioMassters(tmp_path) def test_plot(self, dataset: BioMassters) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_cabuar.py b/tests/datasets/test_cabuar.py new file mode 100644 index 00000000000..967f43ee4d3 --- /dev/null +++ b/tests/datasets/test_cabuar.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +from itertools import product +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import CaBuAr, DatasetNotFoundError + +pytest.importorskip('h5py', minversion='3.6') + + +class TestCaBuAr: + @pytest.fixture( + params=product([CaBuAr.all_bands, CaBuAr.rgb_bands], ['train', 'val', 'test']) + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> CaBuAr: + data_dir = os.path.join('tests', 'data', 'cabuar') + urls = ( + os.path.join(data_dir, '512x512.hdf5'), + os.path.join(data_dir, 'chabud_test.h5'), + ) + monkeypatch.setattr(CaBuAr, 'urls', urls) + bands, split = request.param + root = tmp_path + transforms = nn.Identity() + return CaBuAr( + root=root, + split=split, + bands=bands, + transforms=transforms, + download=True, + checksum=True, + ) + + def test_getitem(self, dataset: CaBuAr) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + + # Image tests + assert x['image'].ndim == 3 + + if dataset.bands == CaBuAr.rgb_bands: + assert x['image'].shape[0] == 2 * 3 + elif dataset.bands == CaBuAr.all_bands: + assert x['image'].shape[0] == 2 * 12 + + # Mask tests: + assert x['mask'].ndim == 2 + + def test_len(self, dataset: CaBuAr) -> None: + assert len(dataset) == 4 + + def test_already_downloaded(self, dataset: CaBuAr) -> None: + CaBuAr(root=dataset.root, download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + CaBuAr(tmp_path) + + def test_invalid_bands(self) -> None: + with pytest.raises(AssertionError): + CaBuAr(bands=('OK', 'BK')) + + def test_plot(self, dataset: CaBuAr) -> None: + dataset.plot(dataset[0], suptitle='Test') + plt.close() + + sample = dataset[0] + sample['prediction'] = sample['mask'].clone() + dataset.plot(sample, suptitle='prediction') + plt.close() + + def test_plot_rgb(self, dataset: CaBuAr) -> None: + dataset = CaBuAr(root=dataset.root, bands=('B02',)) + with pytest.raises(ValueError, match="doesn't contain some of the RGB bands"): + dataset.plot(dataset[0], suptitle='Single Band') + + def test_invalid_split(self, dataset: CaBuAr) -> None: + with pytest.raises(AssertionError): + CaBuAr(dataset.root, split='foo') diff --git a/tests/datasets/test_caffe.py b/tests/datasets/test_caffe.py new file mode 100644 index 00000000000..afdf0a1a5dc --- /dev/null +++ b/tests/datasets/test_caffe.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import CaFFe, DatasetNotFoundError + + +class TestCaFFe: + @pytest.fixture(params=['train', 'val', 'test']) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> CaFFe: + md5 = '73c0aba603c356b2cce9ebf952fb7be0' + monkeypatch.setattr(CaFFe, 'md5', md5) + url = os.path.join('tests', 'data', 'caffe', 'caffe.zip') + monkeypatch.setattr(CaFFe, 'url', url) + root = tmp_path + split = request.param + transforms = nn.Identity() + return CaFFe(root, split, transforms, download=True, checksum=True) + + def test_getitem(self, dataset: CaFFe) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert x['image'].shape[0] == 1 + assert isinstance(x['mask_zones'], torch.Tensor) + assert x['image'].shape[-2:] == x['mask_zones'].shape[-2:] + + def test_len(self, dataset: CaFFe) -> None: + if dataset.split == 'train': + assert len(dataset) == 3 + else: + assert len(dataset) == 3 + + def test_already_downloaded(self, dataset: CaFFe) -> None: + CaFFe(root=dataset.root) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + filename = 'caffe.zip' + dir = os.path.join('tests', 'data', 'caffe') + shutil.copyfile( + os.path.join(dir, filename), os.path.join(str(tmp_path), filename) + ) + CaFFe(root=str(tmp_path)) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + CaFFe(split='foo') + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + CaFFe(tmp_path) + + def test_plot(self, dataset: CaFFe) -> None: + dataset.plot(dataset[0], suptitle='Test') + plt.close() + + sample = dataset[0] + sample['prediction'] = torch.clone(sample['mask_zones']) + dataset.plot(sample, suptitle='Prediction') + plt.close() diff --git a/tests/datasets/test_cbf.py b/tests/datasets/test_cbf.py index 4287cd9673d..17adb1961cc 100644 --- a/tests/datasets/test_cbf.py +++ b/tests/datasets/test_cbf.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,7 +11,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, CanadianBuildingFootprints, @@ -22,16 +20,11 @@ ) -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestCanadianBuildingFootprints: @pytest.fixture def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path ) -> CanadianBuildingFootprints: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) monkeypatch.setattr( CanadianBuildingFootprints, 'provinces_territories', ['Alberta'] ) @@ -41,7 +34,7 @@ def dataset( url = os.path.join('tests', 'data', 'cbf') + os.sep monkeypatch.setattr(CanadianBuildingFootprints, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return CanadianBuildingFootprints( root, res=0.1, transforms=transforms, download=True, checksum=True @@ -80,7 +73,7 @@ def test_plot_prediction(self, dataset: CanadianBuildingFootprints) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CanadianBuildingFootprints(str(tmp_path)) + CanadianBuildingFootprints(tmp_path) def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None: query = BoundingBox(2, 2, 2, 2, 2, 2) diff --git a/tests/datasets/test_cdl.py b/tests/datasets/test_cdl.py index 19ae64514c2..c68bc9deb61 100644 --- a/tests/datasets/test_cdl.py +++ b/tests/datasets/test_cdl.py @@ -14,7 +14,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( CDL, BoundingBox, @@ -24,15 +23,9 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestCDL: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL: - monkeypatch.setattr(torchgeo.datasets.cdl, 'download_url', download_url) - md5s = { 2023: '3fbd3eecf92b8ce1ae35060ada463c6d', 2022: '826c6fd639d9cdd94a44302fbc5b76c3', @@ -41,7 +34,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CDL: url = os.path.join('tests', 'data', 'cdl', '{}_30m_cdls.zip') monkeypatch.setattr(CDL, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return CDL( root, @@ -87,7 +80,7 @@ def test_already_extracted(self, dataset: CDL) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'cdl', '*_30m_cdls.zip') - root = str(tmp_path) + root = tmp_path for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) CDL(root, years=[2023, 2022]) @@ -97,7 +90,7 @@ def test_invalid_year(self, tmp_path: Path) -> None: AssertionError, match='CDL data product only exists for the following years:', ): - CDL(str(tmp_path), years=[1996]) + CDL(tmp_path, years=[1996]) def test_invalid_classes(self) -> None: with pytest.raises(AssertionError): @@ -121,7 +114,7 @@ def test_plot_prediction(self, dataset: CDL) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CDL(str(tmp_path)) + CDL(tmp_path) def test_invalid_query(self, dataset: CDL) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_chabud.py b/tests/datasets/test_chabud.py index 074674a1733..fed0aed5087 100644 --- a/tests/datasets/test_chabud.py +++ b/tests/datasets/test_chabud.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,29 +11,23 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import ChaBuD, DatasetNotFoundError pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, os.path.join(root, filename)) - - class TestChaBuD: @pytest.fixture(params=zip([ChaBuD.all_bands, ChaBuD.rgb_bands], ['train', 'val'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> ChaBuD: - monkeypatch.setattr(torchgeo.datasets.chabud, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'chabud') url = os.path.join(data_dir, 'train_eval.hdf5') md5 = '1bec048beeb87a865c53f40ab418aa75' monkeypatch.setattr(ChaBuD, 'url', url) monkeypatch.setattr(ChaBuD, 'md5', md5) bands, split = request.param - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ChaBuD( root=root, @@ -70,7 +63,7 @@ def test_already_downloaded(self, dataset: ChaBuD) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ChaBuD(str(tmp_path)) + ChaBuD(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_chesapeake.py b/tests/datasets/test_chesapeake.py index 814c6997d32..a6bcdd81b6f 100644 --- a/tests/datasets/test_chesapeake.py +++ b/tests/datasets/test_chesapeake.py @@ -13,88 +13,76 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, - Chesapeake13, ChesapeakeCVPR, + ChesapeakeDC, DatasetNotFoundError, IntersectionDataset, UnionDataset, ) -pytest.importorskip('zipfile_deflate64') - -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - -class TestChesapeake13: +class TestChesapeakeDC: @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Chesapeake13: - monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url) - md5 = 'fe35a615b8e749b21270472aa98bb42c' - monkeypatch.setattr(Chesapeake13, 'md5', md5) + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ChesapeakeDC: url = os.path.join( - 'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip' + 'tests', + 'data', + 'chesapeake', + 'lulc', + '{state}_lulc_{year}_2022-Edition.zip', ) - monkeypatch.setattr(Chesapeake13, 'url', url) + monkeypatch.setattr(ChesapeakeDC, 'url', url) + md5s = {2018: '35c644f13ccdb1baf62adf85cb8c7e48'} + monkeypatch.setattr(ChesapeakeDC, 'md5s', md5s) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) transforms = nn.Identity() - return Chesapeake13(root, transforms=transforms, download=True, checksum=True) + return ChesapeakeDC( + tmp_path, transforms=transforms, download=True, checksum=True + ) - def test_getitem(self, dataset: Chesapeake13) -> None: + def test_getitem(self, dataset: ChesapeakeDC) -> None: x = dataset[dataset.bounds] assert isinstance(x, dict) assert isinstance(x['crs'], CRS) assert isinstance(x['mask'], torch.Tensor) - def test_len(self, dataset: Chesapeake13) -> None: + def test_len(self, dataset: ChesapeakeDC) -> None: assert len(dataset) == 1 - def test_and(self, dataset: Chesapeake13) -> None: + def test_and(self, dataset: ChesapeakeDC) -> None: ds = dataset & dataset assert isinstance(ds, IntersectionDataset) - def test_or(self, dataset: Chesapeake13) -> None: + def test_or(self, dataset: ChesapeakeDC) -> None: ds = dataset | dataset assert isinstance(ds, UnionDataset) - def test_already_extracted(self, dataset: Chesapeake13) -> None: - Chesapeake13(dataset.paths, download=True) + def test_already_extracted(self, dataset: ChesapeakeDC) -> None: + ChesapeakeDC(dataset.paths, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: url = os.path.join( - 'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip' + 'tests', 'data', 'chesapeake', 'lulc', 'dc_lulc_2018_2022-Edition.zip' ) - root = str(tmp_path) - shutil.copy(url, root) - Chesapeake13(root) + shutil.copy(url, tmp_path) + ChesapeakeDC(tmp_path) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Chesapeake13(str(tmp_path), checksum=True) + ChesapeakeDC(tmp_path, checksum=True) - def test_plot(self, dataset: Chesapeake13) -> None: + def test_plot(self, dataset: ChesapeakeDC) -> None: query = dataset.bounds x = dataset[query] dataset.plot(x, suptitle='Test') plt.close() - - def test_plot_prediction(self, dataset: Chesapeake13) -> None: - query = dataset.bounds - x = dataset[query] x['prediction'] = x['mask'].clone() dataset.plot(x, suptitle='Prediction') plt.close() - def test_url(self) -> None: - ds = Chesapeake13(os.path.join('tests', 'data', 'chesapeake', 'BAYWIDE')) - assert 'cicwebresources.blob.core.windows.net' in ds.url - - def test_invalid_query(self, dataset: Chesapeake13) -> None: + def test_invalid_query(self, dataset: ChesapeakeDC) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) with pytest.raises( IndexError, match='query: .* not found in index with bounds:' @@ -114,7 +102,6 @@ class TestChesapeakeCVPR: def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> ChesapeakeCVPR: - monkeypatch.setattr(torchgeo.datasets.chesapeake, 'download_url', download_url) monkeypatch.setattr( ChesapeakeCVPR, 'md5s', @@ -148,7 +135,7 @@ def dataset( '_files', ['de_1m_2013_extended-debuffered-test_tiles', 'spatial_index.geojson'], ) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ChesapeakeCVPR( root, @@ -180,7 +167,7 @@ def test_already_extracted(self, dataset: ChesapeakeCVPR) -> None: ChesapeakeCVPR(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - root = str(tmp_path) + root = tmp_path shutil.copy( os.path.join( 'tests', 'data', 'chesapeake', 'cvpr', 'cvpr_chesapeake_landcover.zip' @@ -201,7 +188,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ChesapeakeCVPR(str(tmp_path), checksum=True) + ChesapeakeCVPR(tmp_path, checksum=True) def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) @@ -225,6 +212,9 @@ def test_plot(self, dataset: ChesapeakeCVPR) -> None: plt.close() dataset.plot(x, show_titles=False) plt.close() - x['prediction'] = x['mask'][:, :, 0].clone().unsqueeze(2) + if x['mask'].ndim == 2: + x['prediction'] = x['mask'].clone() + else: + x['prediction'] = x['mask'][0, :, :].clone() dataset.plot(x) plt.close() diff --git a/tests/datasets/test_cloud_cover.py b/tests/datasets/test_cloud_cover.py index e1dc89483c4..c2ed31bf108 100644 --- a/tests/datasets/test_cloud_cover.py +++ b/tests/datasets/test_cloud_cover.py @@ -1,15 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt import pytest import torch import torch.nn as nn +from _pytest.fixtures import SubRequest from pytest import MonkeyPatch from torchgeo.datasets import ( @@ -17,62 +16,30 @@ DatasetNotFoundError, RGBBandsMissingError, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join( - 'tests', 'data', 'ref_cloud_cover_detection_challenge_v1', '*.tar.gz' - ) - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestCloudCoverDetection: - @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CloudCoverDetection: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - - test_image_meta = { - 'filename': 'ref_cloud_cover_detection_challenge_v1_test_source.tar.gz', - 'md5': '542e64a6e39b53c84c6462ec1b989e43', - } - monkeypatch.setitem(CloudCoverDetection.image_meta, 'test', test_image_meta) - - test_target_meta = { - 'filename': 'ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz', - 'md5': 'e8d41de08744a9845e74fca1eee3d1d3', - } - monkeypatch.setitem(CloudCoverDetection.target_meta, 'test', test_target_meta) - - root = str(tmp_path) - split = 'test' + @pytest.fixture(params=['train', 'test']) + def dataset( + self, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, + request: SubRequest, + ) -> CloudCoverDetection: + url = os.path.join('tests', 'data', 'ref_cloud_cover_detection_challenge_v1') + monkeypatch.setattr(CloudCoverDetection, 'url', url) + root = tmp_path + split = request.param transforms = nn.Identity() - return CloudCoverDetection( - root=root, - transforms=transforms, - split=split, - download=True, - api_key='', - checksum=True, + root=root, split=split, transforms=transforms, download=True ) def test_invalid_band(self, dataset: CloudCoverDetection) -> None: - invalid_bands = ['B09'] - with pytest.raises(ValueError): - CloudCoverDetection( - root=dataset.root, - split='test', - download=False, - api_key='', - bands=invalid_bands, - ) + with pytest.raises(AssertionError): + CloudCoverDetection(root=dataset.root, split=dataset.split, bands=['B09']) def test_getitem(self, dataset: CloudCoverDetection) -> None: x = dataset[0] @@ -84,28 +51,23 @@ def test_len(self, dataset: CloudCoverDetection) -> None: assert len(dataset) == 1 def test_already_downloaded(self, dataset: CloudCoverDetection) -> None: - CloudCoverDetection(root=dataset.root, split='test', download=True, api_key='') + CloudCoverDetection(root=dataset.root, split=dataset.split, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CloudCoverDetection(str(tmp_path)) + CloudCoverDetection(tmp_path) def test_plot(self, dataset: CloudCoverDetection) -> None: - dataset.plot(dataset[0], suptitle='Test') - plt.close() - sample = dataset[0] + dataset.plot(sample, suptitle='Test') + plt.close() sample['prediction'] = sample['mask'].clone() dataset.plot(sample, suptitle='Pred') plt.close() def test_plot_rgb(self, dataset: CloudCoverDetection) -> None: dataset = CloudCoverDetection( - root=dataset.root, - split='test', - bands=list(['B08']), - download=True, - api_key='', + root=dataset.root, split=dataset.split, bands=['B08'], download=True ) with pytest.raises( RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' diff --git a/tests/datasets/test_cms_mangrove_canopy.py b/tests/datasets/test_cms_mangrove_canopy.py index ce12795a07e..ce6f59796a7 100644 --- a/tests/datasets/test_cms_mangrove_canopy.py +++ b/tests/datasets/test_cms_mangrove_canopy.py @@ -20,10 +20,6 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestCMSGlobalMangroveCanopy: @pytest.fixture def dataset( @@ -54,7 +50,7 @@ def test_len(self, dataset: CMSGlobalMangroveCanopy) -> None: def test_no_dataset(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CMSGlobalMangroveCanopy(str(tmp_path)) + CMSGlobalMangroveCanopy(tmp_path) def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( @@ -63,7 +59,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: 'cms_mangrove_canopy', 'CMS_Global_Map_Mangrove_Canopy_1665.zip', ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) CMSGlobalMangroveCanopy(root, country='Angola') @@ -73,7 +69,7 @@ def test_corrupted(self, tmp_path: Path) -> None: ) as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - CMSGlobalMangroveCanopy(str(tmp_path), country='Angola', checksum=True) + CMSGlobalMangroveCanopy(tmp_path, country='Angola', checksum=True) def test_invalid_country(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_cowc.py b/tests/datasets/test_cowc.py index f454569d5b7..7ee61a0e985 100644 --- a/tests/datasets/test_cowc.py +++ b/tests/datasets/test_cowc.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -13,14 +12,9 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import COWC, COWCCounting, COWCDetection, DatasetNotFoundError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestCOWC: def test_not_implemented(self) -> None: with pytest.raises(TypeError, match="Can't instantiate abstract class"): @@ -32,7 +26,6 @@ class TestCOWCCounting: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> COWC: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) base_url = os.path.join('tests', 'data', 'cowc_counting') + os.sep monkeypatch.setattr(COWCCounting, 'base_url', base_url) md5s = [ @@ -46,7 +39,7 @@ def dataset( '0a4daed8c5f6c4e20faa6e38636e4346', ] monkeypatch.setattr(COWCCounting, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return COWCCounting(root, split, transforms, download=True, checksum=True) @@ -78,7 +71,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - COWCCounting(str(tmp_path)) + COWCCounting(tmp_path) def test_plot(self, dataset: COWCCounting) -> None: x = dataset[0].copy() @@ -96,7 +89,6 @@ class TestCOWCDetection: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> COWC: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) base_url = os.path.join('tests', 'data', 'cowc_detection') + os.sep monkeypatch.setattr(COWCDetection, 'base_url', base_url) md5s = [ @@ -110,7 +102,7 @@ def dataset( 'dccc2257e9c4a9dde2b4f84769804046', ] monkeypatch.setattr(COWCDetection, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return COWCDetection(root, split, transforms, download=True, checksum=True) @@ -142,7 +134,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - COWCDetection(str(tmp_path)) + COWCDetection(tmp_path) def test_plot(self, dataset: COWCDetection) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_cropharvest.py b/tests/datasets/test_cropharvest.py index 2ad82fca137..3d77ac2b5fe 100644 --- a/tests/datasets/test_cropharvest.py +++ b/tests/datasets/test_cropharvest.py @@ -11,20 +11,14 @@ import torch.nn as nn from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import CropHarvest, DatasetNotFoundError pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, filename: str, md5: str) -> None: - shutil.copy(url, os.path.join(root, filename)) - - class TestCropHarvest: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: - monkeypatch.setattr(torchgeo.datasets.cropharvest, 'download_url', download_url) monkeypatch.setitem( CropHarvest.file_dict['features'], 'md5', 'ef6f4f00c0b3b50ed8380b0044928572' ) @@ -42,7 +36,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: os.path.join('tests', 'data', 'cropharvest', 'labels.geojson'), ) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() dataset = CropHarvest(root, transforms, download=True, checksum=True) @@ -61,16 +55,16 @@ def test_len(self, dataset: CropHarvest) -> None: assert len(dataset) == 5 def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None: - CropHarvest(root=str(tmp_path), download=False) + CropHarvest(root=tmp_path, download=False) def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None: feature_path = os.path.join(tmp_path, 'features') shutil.rmtree(feature_path) - CropHarvest(root=str(tmp_path), download=True) + CropHarvest(root=tmp_path, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CropHarvest(str(tmp_path)) + CropHarvest(tmp_path) def test_plot(self, dataset: CropHarvest) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_cv4a_kenya_crop_type.py b/tests/datasets/test_cv4a_kenya_crop_type.py index ad0e26ed03d..e6309844054 100644 --- a/tests/datasets/test_cv4a_kenya_crop_type.py +++ b/tests/datasets/test_cv4a_kenya_crop_type.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -18,44 +16,23 @@ DatasetNotFoundError, RGBBandsMissingError, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join( - 'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz' - ) - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestCV4AKenyaCropType: @pytest.fixture - def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - source_md5 = '7f4dcb3f33743dddd73f453176308bfb' - labels_md5 = '95fc59f1d94a85ec00931d4d1280bec9' - monkeypatch.setitem(CV4AKenyaCropType.image_meta, 'md5', source_md5) - monkeypatch.setitem(CV4AKenyaCropType.target_meta, 'md5', labels_md5) - monkeypatch.setattr( - CV4AKenyaCropType, 'tile_names', ['ref_african_crops_kenya_02_tile_00'] - ) + def dataset( + self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> CV4AKenyaCropType: + url = os.path.join('tests', 'data', 'cv4a_kenya_crop_type') + monkeypatch.setattr(CV4AKenyaCropType, 'url', url) + monkeypatch.setattr(CV4AKenyaCropType, 'tiles', list(map(str, range(1)))) monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606']) - root = str(tmp_path) + monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2) + monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2) + root = tmp_path transforms = nn.Identity() - return CV4AKenyaCropType( - root, - transforms=transforms, - download=True, - api_key='', - checksum=True, - verbose=True, - ) + return CV4AKenyaCropType(root, transforms=transforms, download=True) def test_getitem(self, dataset: CV4AKenyaCropType) -> None: x = dataset[0] @@ -66,60 +43,34 @@ def test_getitem(self, dataset: CV4AKenyaCropType) -> None: assert isinstance(x['y'], torch.Tensor) def test_len(self, dataset: CV4AKenyaCropType) -> None: - assert len(dataset) == 345 + assert len(dataset) == 1 def test_add(self, dataset: CV4AKenyaCropType) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 690 - - def test_get_splits(self, dataset: CV4AKenyaCropType) -> None: - train_field_ids, test_field_ids = dataset.get_splits() - assert isinstance(train_field_ids, list) - assert isinstance(test_field_ids, list) - assert len(train_field_ids) == 18 - assert len(test_field_ids) == 9 - assert 336 in train_field_ids - assert 336 not in test_field_ids - assert 4793 in test_field_ids - assert 4793 not in train_field_ids + assert len(ds) == 2 def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None: - CV4AKenyaCropType(root=dataset.root, download=True, api_key='') + CV4AKenyaCropType(root=dataset.root, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - CV4AKenyaCropType(str(tmp_path)) - - def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None: - with pytest.raises(AssertionError): - dataset._load_label_tile('foo') - - with pytest.raises(AssertionError): - dataset._load_all_image_tiles('foo', ('B01', 'B02')) - - with pytest.raises(AssertionError): - dataset._load_single_image_tile('foo', '20190606', ('B01', 'B02')) + CV4AKenyaCropType(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(AssertionError): - CV4AKenyaCropType(bands=['B01', 'B02']) # type: ignore[arg-type] - - with pytest.raises(ValueError, match='is an invalid band name.'): CV4AKenyaCropType(bands=('foo', 'bar')) def test_plot(self, dataset: CV4AKenyaCropType) -> None: - dataset.plot(dataset[0], time_step=0, suptitle='Test') - plt.close() - sample = dataset[0] + dataset.plot(sample, time_step=0, suptitle='Test') + plt.close() sample['prediction'] = sample['mask'].clone() dataset.plot(sample, time_step=0, suptitle='Pred') plt.close() def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None: dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(['B01'])) - with pytest.raises( - RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' - ): - dataset.plot(dataset[0], time_step=0, suptitle='Single Band') + match = 'Dataset does not contain some of the RGB bands' + with pytest.raises(RGBBandsMissingError, match=match): + dataset.plot(dataset[0]) diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index d165b064a90..8dfd39b3c9f 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -15,52 +13,33 @@ from torch.utils.data import ConcatDataset from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - for tarball in glob.iglob(os.path.join('tests', 'data', 'cyclone', '*.tar.gz')): - shutil.copy(tarball, output_dir) - - -def fetch(collection_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestTropicalCyclone: @pytest.fixture(params=['train', 'test']) def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + self, + request: SubRequest, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, ) -> TropicalCyclone: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - md5s = { - 'train': { - 'source': '2b818e0a0873728dabf52c7054a0ce4c', - 'labels': 'c3c2b6d02c469c5519f4add4f9132712', - }, - 'test': { - 'source': 'bc07c519ddf3ce88857435ddddf98a16', - 'labels': '3ca4243eff39b87c73e05ec8db1824bf', - }, - } - monkeypatch.setattr(TropicalCyclone, 'md5s', md5s) - monkeypatch.setattr(TropicalCyclone, 'size', 1) - root = str(tmp_path) + url = os.path.join('tests', 'data', 'cyclone') + monkeypatch.setattr(TropicalCyclone, 'url', url) + monkeypatch.setattr(TropicalCyclone, 'size', 2) + root = tmp_path split = request.param transforms = nn.Identity() - return TropicalCyclone( - root, split, transforms, download=True, api_key='', checksum=True - ) + return TropicalCyclone(root, split, transforms, download=True) @pytest.mark.parametrize('index', [0, 1]) def test_getitem(self, dataset: TropicalCyclone, index: int) -> None: x = dataset[index] assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['storm_id'], str) - assert isinstance(x['relative_time'], int) - assert isinstance(x['ocean'], int) + assert isinstance(x['relative_time'], torch.Tensor) + assert isinstance(x['ocean'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) assert x['image'].shape == (3, dataset.size, dataset.size) @@ -73,7 +52,7 @@ def test_add(self, dataset: TropicalCyclone) -> None: assert len(ds) == 10 def test_already_downloaded(self, dataset: TropicalCyclone) -> None: - TropicalCyclone(root=dataset.root, download=True, api_key='') + TropicalCyclone(root=dataset.root, download=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -81,13 +60,12 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - TropicalCyclone(str(tmp_path)) + TropicalCyclone(tmp_path) def test_plot(self, dataset: TropicalCyclone) -> None: - dataset.plot(dataset[0], suptitle='Test') - plt.close() - sample = dataset[0] + dataset.plot(sample, suptitle='Test') + plt.close() sample['prediction'] = sample['label'] dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_deepglobelandcover.py b/tests/datasets/test_deepglobelandcover.py index 5e845958668..2ea779a98fa 100644 --- a/tests/datasets/test_deepglobelandcover.py +++ b/tests/datasets/test_deepglobelandcover.py @@ -39,16 +39,14 @@ def test_len(self, dataset: DeepGlobeLandCover) -> None: def test_extract(self, tmp_path: Path) -> None: root = os.path.join('tests', 'data', 'deepglobelandcover') filename = 'data.zip' - shutil.copyfile( - os.path.join(root, filename), os.path.join(str(tmp_path), filename) - ) - DeepGlobeLandCover(root=str(tmp_path)) + shutil.copyfile(os.path.join(root, filename), os.path.join(tmp_path, filename)) + DeepGlobeLandCover(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'data.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - DeepGlobeLandCover(root=str(tmp_path), checksum=True) + DeepGlobeLandCover(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -56,7 +54,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - DeepGlobeLandCover(str(tmp_path)) + DeepGlobeLandCover(tmp_path) def test_plot(self, dataset: DeepGlobeLandCover) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_dfc2022.py b/tests/datasets/test_dfc2022.py index d353da5e274..4d40aa2e442 100644 --- a/tests/datasets/test_dfc2022.py +++ b/tests/datasets/test_dfc2022.py @@ -61,13 +61,13 @@ def test_extract(self, tmp_path: Path) -> None: os.path.join('tests', 'data', 'dfc2022', 'val.zip'), os.path.join(tmp_path, 'val.zip'), ) - DFC2022(root=str(tmp_path)) + DFC2022(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'labeled_train.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - DFC2022(root=str(tmp_path), checksum=True) + DFC2022(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -75,7 +75,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - DFC2022(str(tmp_path)) + DFC2022(tmp_path) def test_plot(self, dataset: DFC2022) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_digital_typhoon.py b/tests/datasets/test_digital_typhoon.py new file mode 100644 index 00000000000..c3df283ec35 --- /dev/null +++ b/tests/datasets/test_digital_typhoon.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import DatasetNotFoundError, DigitalTyphoon + +pytest.importorskip('h5py', minversion='3.6') + + +class TestDigitalTyphoon: + @pytest.fixture( + params=[ + (3, {'wind': 0}, {'pressure': 1500}), + (3, {'pressure': 0}, {'wind': 100}), + ] + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> DigitalTyphoon: + sequence_length, min_features, max_features = request.param + + url = os.path.join('tests', 'data', 'digital_typhoon', 'WP.tar.gz{0}') + monkeypatch.setattr(DigitalTyphoon, 'url', url) + + md5sums = { + 'aa': '692ea3796c9bc9ef1e0ab6f2b8bc51ad', + 'ab': '692ea3796c9bc9ef1e0ab6f2b8bc51ad', + } + monkeypatch.setattr(DigitalTyphoon, 'md5sums', md5sums) + root = tmp_path + + transforms = nn.Identity() + return DigitalTyphoon( + root=root, + sequence_length=sequence_length, + min_feature_value=min_features, + max_feature_value=max_features, + transforms=transforms, + download=True, + checksum=True, + ) + + def test_len(self, dataset: DigitalTyphoon) -> None: + assert len(dataset) == 15 + + @pytest.mark.parametrize('index', [0, 1]) + def test_getitem(self, dataset: DigitalTyphoon, index: int) -> None: + x = dataset[index] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert x['image'].min() >= 0 and x['image'].max() <= 1 + assert isinstance(x['label'], torch.Tensor) + + def test_already_downloaded(self, dataset: DigitalTyphoon) -> None: + DigitalTyphoon(root=dataset.root) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + root = os.path.join('tests', 'data', 'digital_typhoon') + filenames = ['WP.tar.gzaa', 'WP.tar.gzab'] + for filename in filenames: + shutil.copyfile(os.path.join(root, filename), tmp_path / filename) + DigitalTyphoon(root=str(tmp_path)) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + DigitalTyphoon(root=str(tmp_path)) + + def test_plot(self, dataset: DigitalTyphoon) -> None: + dataset.plot(dataset[0], suptitle='Test') + plt.close() + + sample = dataset[0] + sample['prediction'] = sample['label'] + dataset.plot(sample) + plt.close() diff --git a/tests/datasets/test_eddmaps.py b/tests/datasets/test_eddmaps.py index 364e988aba3..1a1e805e13f 100644 --- a/tests/datasets/test_eddmaps.py +++ b/tests/datasets/test_eddmaps.py @@ -38,7 +38,7 @@ def test_or(self, dataset: EDDMapS) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EDDMapS(str(tmp_path)) + EDDMapS(tmp_path) def test_invalid_query(self, dataset: EDDMapS) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_enviroatlas.py b/tests/datasets/test_enviroatlas.py index 11ac3b93436..9534032ed45 100644 --- a/tests/datasets/test_enviroatlas.py +++ b/tests/datasets/test_enviroatlas.py @@ -13,7 +13,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, DatasetNotFoundError, @@ -24,10 +23,6 @@ from torchgeo.samplers import RandomGeoSampler -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestEnviroAtlas: @pytest.fixture( params=[ @@ -39,7 +34,6 @@ class TestEnviroAtlas: def dataset( self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path ) -> EnviroAtlas: - monkeypatch.setattr(torchgeo.datasets.enviroatlas, 'download_url', download_url) monkeypatch.setattr(EnviroAtlas, 'md5', '071ec65c611e1d4915a5247bffb5ad87') monkeypatch.setattr( EnviroAtlas, @@ -51,7 +45,7 @@ def dataset( '_files', ['pittsburgh_pa-2010_1m-train_tiles-debuffered', 'spatial_index.geojson'], ) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return EnviroAtlas( root, @@ -85,7 +79,7 @@ def test_already_extracted(self, dataset: EnviroAtlas) -> None: EnviroAtlas(root=dataset.root, download=True) def test_already_downloaded(self, tmp_path: Path) -> None: - root = str(tmp_path) + root = tmp_path shutil.copy( os.path.join('tests', 'data', 'enviroatlas', 'enviroatlas_lotp.zip'), root ) @@ -93,7 +87,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EnviroAtlas(str(tmp_path), checksum=True) + EnviroAtlas(tmp_path, checksum=True) def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_esri2020.py b/tests/datasets/test_esri2020.py index 3fe1207b6b4..29de898e32e 100644 --- a/tests/datasets/test_esri2020.py +++ b/tests/datasets/test_esri2020.py @@ -12,7 +12,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, DatasetNotFoundError, @@ -22,14 +21,9 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestEsri2020: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Esri2020: - monkeypatch.setattr(torchgeo.datasets.esri2020, 'download_url', download_url) zipfile = 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip' monkeypatch.setattr(Esri2020, 'zipfile', zipfile) @@ -42,7 +36,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> Esri2020: 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip', ) monkeypatch.setattr(Esri2020, 'url', url) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return Esri2020(root, transforms=transforms, download=True, checksum=True) @@ -66,11 +60,11 @@ def test_not_extracted(self, tmp_path: Path) -> None: 'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip', ) shutil.copy(url, tmp_path) - Esri2020(str(tmp_path)) + Esri2020(tmp_path) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Esri2020(str(tmp_path), checksum=True) + Esri2020(tmp_path, checksum=True) def test_and(self, dataset: Esri2020) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py index 0cf4029921d..b093e60684f 100644 --- a/tests/datasets/test_etci2021.py +++ b/tests/datasets/test_etci2021.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,20 +11,14 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import ETCI2021, DatasetNotFoundError -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestETCI2021: @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> ETCI2021: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'etci2021') metadata = { 'train': { @@ -48,7 +41,7 @@ def dataset( }, } monkeypatch.setattr(ETCI2021, 'metadata', metadata) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return ETCI2021(root, split, transforms, download=True, checksum=True) @@ -78,7 +71,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ETCI2021(str(tmp_path)) + ETCI2021(tmp_path) def test_plot(self, dataset: ETCI2021) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_eudem.py b/tests/datasets/test_eudem.py index e984dd9e079..c41d36c8301 100644 --- a/tests/datasets/test_eudem.py +++ b/tests/datasets/test_eudem.py @@ -28,7 +28,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> EUDEM: monkeypatch.setattr(EUDEM, 'md5s', md5s) zipfile = os.path.join('tests', 'data', 'eudem', 'eu_dem_v11_E30N10.zip') shutil.copy(zipfile, tmp_path) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return EUDEM(root, transforms=transforms) @@ -42,7 +42,7 @@ def test_len(self, dataset: EUDEM) -> None: assert len(dataset) == 1 def test_extracted_already(self, dataset: EUDEM) -> None: - assert isinstance(dataset.paths, str) + assert isinstance(dataset.paths, Path) zipfile = os.path.join(dataset.paths, 'eu_dem_v11_E30N10.zip') shutil.unpack_archive(zipfile, dataset.paths, 'zip') EUDEM(dataset.paths) @@ -51,13 +51,13 @@ def test_no_dataset(self, tmp_path: Path) -> None: shutil.rmtree(tmp_path) os.makedirs(tmp_path) with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EUDEM(str(tmp_path)) + EUDEM(tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'eu_dem_v11_E30N10.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - EUDEM(str(tmp_path), checksum=True) + EUDEM(tmp_path, checksum=True) def test_and(self, dataset: EUDEM) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_eurocrops.py b/tests/datasets/test_eurocrops.py index e716bbb783a..3b2d4fc63f7 100644 --- a/tests/datasets/test_eurocrops.py +++ b/tests/datasets/test_eurocrops.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -13,7 +12,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, DatasetNotFoundError, @@ -23,18 +21,12 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestEuroCrops: @pytest.fixture(params=[None, ['1000000010'], ['1000000000'], ['2000000000']]) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> EuroCrops: classes = request.param - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) - monkeypatch.setattr(torchgeo.datasets.eurocrops, 'download_url', download_url) monkeypatch.setattr( EuroCrops, 'zenodo_files', [('AA.zip', 'b2ef5cac231294731c1dfea47cba544d')] ) @@ -42,7 +34,7 @@ def dataset( base_url = os.path.join('tests', 'data', 'eurocrops') + os.sep monkeypatch.setattr(EuroCrops, 'base_url', base_url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return EuroCrops( root, classes=classes, transforms=transforms, download=True, checksum=True @@ -81,7 +73,7 @@ def test_plot_prediction(self, dataset: EuroCrops) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EuroCrops(str(tmp_path)) + EuroCrops(tmp_path) def test_invalid_query(self, dataset: EuroCrops) -> None: query = BoundingBox(200, 200, 200, 200, 2, 2) @@ -91,5 +83,5 @@ def test_invalid_query(self, dataset: EuroCrops) -> None: dataset[query] def test_integrity_error(self, dataset: EuroCrops) -> None: - dataset.zenodo_files = [('AA.zip', 'invalid')] + dataset.zenodo_files = (('AA.zip', 'invalid'),) assert not dataset._check_integrity() diff --git a/tests/datasets/test_eurosat.py b/tests/datasets/test_eurosat.py index 99841c26031..282ff581931 100644 --- a/tests/datasets/test_eurosat.py +++ b/tests/datasets/test_eurosat.py @@ -14,7 +14,6 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import ( DatasetNotFoundError, EuroSAT, @@ -24,10 +23,6 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestEuroSAT: @pytest.fixture( params=product([EuroSAT, EuroSATSpatial, EuroSAT100], ['train', 'val', 'test']) @@ -37,35 +32,10 @@ def dataset( ) -> EuroSAT: base_class: type[EuroSAT] = request.param[0] split: str = request.param[1] - monkeypatch.setattr(torchgeo.datasets.eurosat, 'download_url', download_url) - md5 = 'aa051207b0547daba0ac6af57808d68e' - monkeypatch.setattr(base_class, 'md5', md5) - url = os.path.join('tests', 'data', 'eurosat', 'EuroSATallBands.zip') + url = os.path.join('tests', 'data', 'eurosat') + os.sep monkeypatch.setattr(base_class, 'url', url) - monkeypatch.setattr(base_class, 'filename', 'EuroSATallBands.zip') - monkeypatch.setattr( - base_class, - 'split_urls', - { - 'train': os.path.join('tests', 'data', 'eurosat', 'eurosat-train.txt'), - 'val': os.path.join('tests', 'data', 'eurosat', 'eurosat-val.txt'), - 'test': os.path.join('tests', 'data', 'eurosat', 'eurosat-test.txt'), - }, - ) - monkeypatch.setattr( - base_class, - 'split_md5s', - { - 'train': '4af60a00fdfdf8500572ae5360694b71', - 'val': '4af60a00fdfdf8500572ae5360694b71', - 'test': '4af60a00fdfdf8500572ae5360694b71', - }, - ) - root = str(tmp_path) transforms = nn.Identity() - return base_class( - root=root, split=split, transforms=transforms, download=True, checksum=True - ) + return base_class(tmp_path, split=split, transforms=transforms, download=True) def test_getitem(self, dataset: EuroSAT) -> None: x = dataset[0] @@ -90,18 +60,18 @@ def test_add(self, dataset: EuroSAT) -> None: assert len(ds) == 4 def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None: - EuroSAT(root=str(tmp_path), download=True) + type(dataset)(tmp_path) def test_already_downloaded_not_extracted( self, dataset: EuroSAT, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - download_url(dataset.url, root=str(tmp_path)) - EuroSAT(root=str(tmp_path), download=False) + shutil.copy(dataset.url + dataset.filename, tmp_path) + type(dataset)(tmp_path) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - EuroSAT(str(tmp_path)) + EuroSAT(tmp_path) def test_plot(self, dataset: EuroSAT) -> None: x = dataset[0].copy() @@ -114,7 +84,7 @@ def test_plot(self, dataset: EuroSAT) -> None: plt.close() def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None: - dataset = EuroSAT(root=str(tmp_path), bands=('B03',)) + dataset = type(dataset)(tmp_path, bands=('B03',)) with pytest.raises( RGBBandsMissingError, match='Dataset does not contain some of the RGB bands' ): diff --git a/tests/datasets/test_fair1m.py b/tests/datasets/test_fair1m.py index 38db23974d3..3ff3f66733f 100644 --- a/tests/datasets/test_fair1m.py +++ b/tests/datasets/test_fair1m.py @@ -12,15 +12,9 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import FAIR1M, DatasetNotFoundError -def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None: - os.makedirs(root, exist_ok=True) - shutil.copy(url, os.path.join(root, filename)) - - class TestFAIR1M: test_root = os.path.join('tests', 'data', 'fair1m') @@ -28,7 +22,6 @@ class TestFAIR1M: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> FAIR1M: - monkeypatch.setattr(torchgeo.datasets.fair1m, 'download_url', download_url) urls = { 'train': ( os.path.join(self.test_root, 'train', 'part1', 'images.zip'), @@ -65,7 +58,7 @@ def dataset( } monkeypatch.setattr(FAIR1M, 'urls', urls) monkeypatch.setattr(FAIR1M, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return FAIR1M(root, split, transforms, download=True, checksum=True) @@ -89,7 +82,7 @@ def test_len(self, dataset: FAIR1M) -> None: assert len(dataset) == 4 def test_already_downloaded(self, dataset: FAIR1M, tmp_path: Path) -> None: - FAIR1M(root=str(tmp_path), split=dataset.split, download=True) + FAIR1M(root=tmp_path, split=dataset.split, download=True) def test_already_downloaded_not_extracted( self, dataset: FAIR1M, tmp_path: Path @@ -98,11 +91,11 @@ def test_already_downloaded_not_extracted( for filepath, url in zip( dataset.paths[dataset.split], dataset.urls[dataset.split] ): - output = os.path.join(str(tmp_path), filepath) + output = os.path.join(tmp_path, filepath) os.makedirs(os.path.dirname(output), exist_ok=True) - download_url(url, root=os.path.dirname(output), filename=output) + shutil.copy(url, output) - FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True) + FAIR1M(root=tmp_path, split=dataset.split, checksum=True) def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None: md5s = tuple(['randomhash'] * len(FAIR1M.md5s[dataset.split])) @@ -111,17 +104,17 @@ def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None: for filepath, url in zip( dataset.paths[dataset.split], dataset.urls[dataset.split] ): - output = os.path.join(str(tmp_path), filepath) + output = os.path.join(tmp_path, filepath) os.makedirs(os.path.dirname(output), exist_ok=True) - download_url(url, root=os.path.dirname(output), filename=output) + shutil.copy(url, output) with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True) + FAIR1M(root=tmp_path, split=dataset.split, checksum=True) def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None: - shutil.rmtree(str(tmp_path)) + shutil.rmtree(tmp_path) with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - FAIR1M(root=str(tmp_path), split=dataset.split) + FAIR1M(root=tmp_path, split=dataset.split) def test_plot(self, dataset: FAIR1M) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_fire_risk.py b/tests/datasets/test_fire_risk.py index e3f235c464d..472a6213ffc 100644 --- a/tests/datasets/test_fire_risk.py +++ b/tests/datasets/test_fire_risk.py @@ -12,25 +12,19 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, FireRisk -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestFireRisk: @pytest.fixture(params=['train', 'val']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> FireRisk: - monkeypatch.setattr(torchgeo.datasets.fire_risk, 'download_url', download_url) url = os.path.join('tests', 'data', 'fire_risk', 'FireRisk.zip') md5 = 'db22106d61b10d855234b4a74db921ac' monkeypatch.setattr(FireRisk, 'md5', md5) monkeypatch.setattr(FireRisk, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return FireRisk(root, split, transforms, download=True, checksum=True) @@ -46,18 +40,18 @@ def test_len(self, dataset: FireRisk) -> None: assert len(dataset) == 5 def test_already_downloaded(self, dataset: FireRisk, tmp_path: Path) -> None: - FireRisk(root=str(tmp_path), download=True) + FireRisk(root=tmp_path, download=True) def test_already_downloaded_not_extracted( self, dataset: FireRisk, tmp_path: Path ) -> None: shutil.rmtree(os.path.dirname(dataset.root)) - download_url(dataset.url, root=str(tmp_path)) - FireRisk(root=str(tmp_path), download=False) + shutil.copy(dataset.url, tmp_path) + FireRisk(root=tmp_path, download=False) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - FireRisk(str(tmp_path)) + FireRisk(tmp_path) def test_plot(self, dataset: FireRisk) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_forestdamage.py b/tests/datasets/test_forestdamage.py index 39aae73026a..64acb20b5b3 100644 --- a/tests/datasets/test_forestdamage.py +++ b/tests/datasets/test_forestdamage.py @@ -11,27 +11,18 @@ import torch.nn as nn from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, ForestDamage -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestForestDamage: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ForestDamage: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'forestdamage') - url = os.path.join(data_dir, 'Data_Set_Larch_Casebearer.zip') - md5 = '52d82ac38899e6e6bb40aacda643ee15' - monkeypatch.setattr(ForestDamage, 'url', url) monkeypatch.setattr(ForestDamage, 'md5', md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ForestDamage( root=root, transforms=transforms, download=True, checksum=True @@ -57,17 +48,17 @@ def test_not_extracted(self, tmp_path: Path) -> None: 'tests', 'data', 'forestdamage', 'Data_Set_Larch_Casebearer.zip' ) shutil.copy(url, tmp_path) - ForestDamage(root=str(tmp_path)) + ForestDamage(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'Data_Set_Larch_Casebearer.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - ForestDamage(root=str(tmp_path), checksum=True) + ForestDamage(root=tmp_path, checksum=True) def test_not_found(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ForestDamage(str(tmp_path)) + ForestDamage(tmp_path) def test_plot(self, dataset: ForestDamage) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_ftw.py b/tests/datasets/test_ftw.py new file mode 100644 index 00000000000..1a3d7130795 --- /dev/null +++ b/tests/datasets/test_ftw.py @@ -0,0 +1,90 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from itertools import product +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from torch.utils.data import ConcatDataset + +from torchgeo.datasets import DatasetNotFoundError, FieldsOfTheWorld + +pytest.importorskip('pyarrow') + + +class TestFieldsOfTheWorld: + @pytest.fixture( + params=product(['train', 'val', 'test'], ['2-class', '3-class', 'instance']) + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> FieldsOfTheWorld: + split, task = request.param + + monkeypatch.setattr(FieldsOfTheWorld, 'valid_countries', ['austria']) + monkeypatch.setattr( + FieldsOfTheWorld, + 'country_to_md5', + {'austria': '1cf9593c9bdceeaba21bbcb24d35816c'}, + ) + base_url = os.path.join('tests', 'data', 'ftw') + '/' + monkeypatch.setattr(FieldsOfTheWorld, 'base_url', base_url) + root = tmp_path + transforms = nn.Identity() + return FieldsOfTheWorld( + root, + split, + task, + countries='austria', + transforms=transforms, + download=True, + checksum=True, + ) + + def test_getitem(self, dataset: FieldsOfTheWorld) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + + def test_len(self, dataset: FieldsOfTheWorld) -> None: + assert len(dataset) == 2 + + def test_add(self, dataset: FieldsOfTheWorld) -> None: + ds = dataset + dataset + assert isinstance(ds, ConcatDataset) + assert len(ds) == 4 + + def test_already_extracted(self, dataset: FieldsOfTheWorld) -> None: + FieldsOfTheWorld(root=dataset.root, download=True) + + def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: + url = os.path.join('tests', 'data', 'ftw', 'austria.zip') + root = tmp_path + shutil.copy(url, root) + FieldsOfTheWorld(root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + FieldsOfTheWorld(tmp_path) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + FieldsOfTheWorld(split='foo') + + def test_plot(self, dataset: FieldsOfTheWorld) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle='Test') + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x['prediction'] = x['mask'].clone() + dataset.plot(x) + plt.close() diff --git a/tests/datasets/test_gbif.py b/tests/datasets/test_gbif.py index 35426d18b03..8c64d614c30 100644 --- a/tests/datasets/test_gbif.py +++ b/tests/datasets/test_gbif.py @@ -38,7 +38,7 @@ def test_or(self, dataset: GBIF) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - GBIF(str(tmp_path)) + GBIF(tmp_path) def test_invalid_query(self, dataset: GBIF) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index e3b11e7fc2a..71e07f6928b 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. + +import math import os import pickle import sys @@ -36,7 +38,7 @@ def __init__( bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5), crs: CRS = CRS.from_epsg(4087), res: float = 1, - paths: str | Iterable[str] | None = None, + paths: str | os.PathLike[str] | Iterable[str | os.PathLike[str]] | None = None, ) -> None: super().__init__() self.index.insert(0, tuple(bounds)) @@ -70,7 +72,7 @@ class CustomVectorDataset(VectorDataset): class CustomSentinelDataset(Sentinel2): - all_bands: list[str] = [] + all_bands: tuple[str, ...] = () separate_files = False @@ -172,7 +174,7 @@ def test_and_nongeo(self, dataset: GeoDataset) -> None: dataset & ds2 # type: ignore[operator] def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None: - paths = [str(tmp_path), str(tmp_path / 'non_existing_file.tif')] + paths = [tmp_path, tmp_path / 'non_existing_file.tif'] with pytest.warns(UserWarning, match='Path was ignored.'): assert len(CustomGeoDataset(paths=paths).files) == 0 @@ -203,16 +205,37 @@ def test_files_property_deterministic(self) -> None: CustomGeoDataset(paths=paths1).files == CustomGeoDataset(paths=paths2).files ) + def test_files_property_mix_str_and_pathlib(self, tmp_path: Path) -> None: + foo = tmp_path / 'foo.txt' + bar = tmp_path / 'bar.txt' + foo.touch() + bar.touch() + ds = CustomGeoDataset(paths=[str(foo), bar]) + assert ds.files == [str(bar), str(foo)] + class TestRasterDataset: + naip_dir = os.path.join('tests', 'data', 'naip') + s2_dir = os.path.join( + 'tests', + 'data', + 'sentinel2', + 'S2A_MSIL2A_20220414T110751_N0400_R108_T26EMU_20220414T165533.SAFE', + 'GRANULE', + 'L2A_T26EMU_A035569_20220414T110747', + 'IMG_DATA', + 'R10m', + ) + @pytest.fixture(params=zip([['R', 'G', 'B'], None], [True, False])) def naip(self, request: SubRequest) -> NAIP: - root = os.path.join('tests', 'data', 'naip') bands = request.param[0] crs = CRS.from_epsg(4087) transforms = nn.Identity() cache = request.param[1] - return NAIP(root, crs=crs, bands=bands, transforms=transforms, cache=cache) + return NAIP( + self.naip_dir, crs=crs, bands=bands, transforms=transforms, cache=cache + ) @pytest.fixture( params=zip( @@ -234,34 +257,55 @@ def sentinel(self, request: SubRequest) -> Sentinel2: 'paths', [ # Single directory - os.path.join('tests', 'data', 'naip'), + naip_dir, # Multiple directories - [ - os.path.join('tests', 'data', 'naip'), - os.path.join('tests', 'data', 'naip'), - ], - # Single file - os.path.join('tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif'), + [naip_dir, naip_dir], # Multiple files ( - os.path.join( - 'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif' - ), - os.path.join( - 'tests', 'data', 'naip', 'm_3807511_ne_18_060_20190605.tif' - ), + os.path.join(naip_dir, 'm_3807511_ne_18_060_20181104.tif'), + os.path.join(naip_dir, 'm_3807511_ne_18_060_20190605.tif'), ), # Combination - { - os.path.join('tests', 'data', 'naip'), - os.path.join( - 'tests', 'data', 'naip', 'm_3807511_ne_18_060_20181104.tif' - ), - }, + {naip_dir, os.path.join(naip_dir, 'm_3807511_ne_18_060_20181104.tif')}, ], ) def test_files(self, paths: str | Iterable[str]) -> None: - assert 1 <= len(NAIP(paths).files) <= 2 + assert len(NAIP(paths).files) == 2 + + @pytest.mark.parametrize( + 'paths', + [ + # Single directory + s2_dir, + # Multiple directories + [s2_dir, s2_dir], + # Multiple files (single band) + [ + os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'), + ], + # Multiple files (multiple bands) + [ + os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20190414T110751_B03_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20190414T110751_B02_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20220414T110751_B03_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20220414T110751_B02_10m.jp2'), + ], + # Combination + [ + s2_dir, + os.path.join(s2_dir, 'T26EMU_20190414T110751_B04_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20220414T110751_B04_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20220414T110751_B03_10m.jp2'), + os.path.join(s2_dir, 'T26EMU_20220414T110751_B02_10m.jp2'), + ], + ], + ) + @pytest.mark.filterwarnings('ignore:Could not find any relevant files') + def test_files_separate(self, paths: str | Iterable[str]) -> None: + assert len(Sentinel2(paths, bands=Sentinel2.rgb_bands).files) == 2 def test_getitem_single_file(self, naip: NAIP) -> None: x = naip[naip.bounds] @@ -277,6 +321,11 @@ def test_getitem_separate_files(self, sentinel: Sentinel2) -> None: assert isinstance(x['image'], torch.Tensor) assert len(sentinel.bands) == x['image'].shape[0] + def test_reprojection(self, naip: NAIP) -> None: + naip2 = NAIP(naip.paths, crs='EPSG:4326') + assert naip.crs != naip2.crs + assert not math.isclose(naip.res, naip2.res) + @pytest.mark.parametrize('dtype', ['uint16', 'uint32']) def test_getitem_uint_dtype(self, dtype: str) -> None: root = os.path.join('tests', 'data', 'raster', dtype) @@ -311,11 +360,11 @@ def test_invalid_query(self, sentinel: Sentinel2) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - RasterDataset(str(tmp_path)) + RasterDataset(tmp_path) def test_no_all_bands(self) -> None: root = os.path.join('tests', 'data', 'sentinel2') - bands = ['B04', 'B03', 'B02'] + bands = ('B04', 'B03', 'B02') transforms = nn.Identity() cache = True msg = ( @@ -380,7 +429,7 @@ def test_invalid_query(self, dataset: CustomVectorDataset) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - VectorDataset(str(tmp_path)) + VectorDataset(tmp_path) class TestNonGeoDataset: diff --git a/tests/datasets/test_geonrw.py b/tests/datasets/test_geonrw.py new file mode 100644 index 00000000000..f4613365eba --- /dev/null +++ b/tests/datasets/test_geonrw.py @@ -0,0 +1,74 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import DatasetNotFoundError, GeoNRW + + +class TestGeoNRW: + @pytest.fixture(params=['train', 'test']) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> GeoNRW: + md5 = '6ffc014d4b345bba3076e8d76ab481fa' + monkeypatch.setattr(GeoNRW, 'md5', md5) + url = os.path.join('tests', 'data', 'geonrw', 'nrw_dataset.tar.gz') + monkeypatch.setattr(GeoNRW, 'url', url) + monkeypatch.setattr(GeoNRW, 'train_list', ['aachen', 'bergisch', 'bielefeld']) + monkeypatch.setattr(GeoNRW, 'test_list', ['duesseldorf']) + root = tmp_path + split = request.param + transforms = nn.Identity() + return GeoNRW(root, split, transforms, download=True, checksum=True) + + def test_getitem(self, dataset: GeoNRW) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert x['image'].shape[0] == 3 + assert isinstance(x['mask'], torch.Tensor) + assert x['image'].shape[-2:] == x['mask'].shape[-2:] + + def test_len(self, dataset: GeoNRW) -> None: + if dataset.split == 'train': + assert len(dataset) == 6 + else: + assert len(dataset) == 2 + + def test_already_downloaded(self, dataset: GeoNRW) -> None: + GeoNRW(root=dataset.root) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + filename = 'nrw_dataset.tar.gz' + dir = os.path.join('tests', 'data', 'geonrw') + shutil.copyfile( + os.path.join(dir, filename), os.path.join(str(tmp_path), filename) + ) + GeoNRW(root=str(tmp_path)) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + GeoNRW(split='foo') + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + GeoNRW(tmp_path) + + def test_plot(self, dataset: GeoNRW) -> None: + dataset.plot(dataset[0], suptitle='Test') + plt.close() + + sample = dataset[0] + sample['prediction'] = torch.clone(sample['mask']) + dataset.plot(sample, suptitle='Prediction') + plt.close() diff --git a/tests/datasets/test_gid15.py b/tests/datasets/test_gid15.py index 9c0358fb08b..3ff01695ba1 100644 --- a/tests/datasets/test_gid15.py +++ b/tests/datasets/test_gid15.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,25 +11,19 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import GID15, DatasetNotFoundError -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestGID15: @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> GID15: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) md5 = '3d5b1373ef9a3084ec493b9b2056fe07' monkeypatch.setattr(GID15, 'md5', md5) url = os.path.join('tests', 'data', 'gid15', 'gid-15.zip') monkeypatch.setattr(GID15, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return GID15(root, split, transforms, download=True, checksum=True) @@ -59,7 +52,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - GID15(str(tmp_path)) + GID15(tmp_path) def test_plot(self, dataset: GID15) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_globbiomass.py b/tests/datasets/test_globbiomass.py index 2e31b7b2222..5940b7113fd 100644 --- a/tests/datasets/test_globbiomass.py +++ b/tests/datasets/test_globbiomass.py @@ -37,7 +37,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> GlobBiomass: } monkeypatch.setattr(GlobBiomass, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return GlobBiomass(root, transforms=transforms, checksum=True) @@ -55,13 +55,13 @@ def test_already_extracted(self, dataset: GlobBiomass) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - GlobBiomass(str(tmp_path), checksum=True) + GlobBiomass(tmp_path, checksum=True) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'N00E020_agb.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - GlobBiomass(str(tmp_path), checksum=True) + GlobBiomass(tmp_path, checksum=True) def test_and(self, dataset: GlobBiomass) -> None: ds = dataset & dataset diff --git a/tests/datasets/test_hyspecnet.py b/tests/datasets/test_hyspecnet.py new file mode 100644 index 00000000000..1e5a646cee6 --- /dev/null +++ b/tests/datasets/test_hyspecnet.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch.nn as nn +from pytest import MonkeyPatch +from torch import Tensor + +from torchgeo.datasets import DatasetNotFoundError, HySpecNet11k, RGBBandsMissingError + +root = os.path.join('tests', 'data', 'hyspecnet') +md5s = {'hyspecnet-11k-01.tar.gz': '', 'hyspecnet-11k-splits.tar.gz': ''} + + +class TestHySpecNet11k: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch) -> HySpecNet11k: + monkeypatch.setattr(HySpecNet11k, 'url', root + os.sep) + monkeypatch.setattr(HySpecNet11k, 'md5s', md5s) + transforms = nn.Identity() + return HySpecNet11k(root, transforms=transforms) + + def test_getitem(self, dataset: HySpecNet11k) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], Tensor) + + def test_len(self, dataset: HySpecNet11k) -> None: + assert len(dataset) == 2 + + def test_download(self, dataset: HySpecNet11k, tmp_path: Path) -> None: + HySpecNet11k(tmp_path, download=True) + + def test_extract(self, dataset: HySpecNet11k, tmp_path: Path) -> None: + for file in glob.iglob(os.path.join(root, '*.tar.gz')): + shutil.copy(file, tmp_path) + HySpecNet11k(tmp_path) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + HySpecNet11k(tmp_path) + + def test_plot(self, dataset: HySpecNet11k) -> None: + x = dataset[0] + dataset.plot(x, suptitle='Test') + plt.close() + + def test_plot_rgb(self, dataset: HySpecNet11k) -> None: + dataset = HySpecNet11k(root=dataset.root, bands=(1, 2, 3)) + match = 'Dataset does not contain some of the RGB bands' + with pytest.raises(RGBBandsMissingError, match=match): + dataset.plot(dataset[0]) diff --git a/tests/datasets/test_idtrees.py b/tests/datasets/test_idtrees.py index a4c05580b58..5fd858ac04f 100644 --- a/tests/datasets/test_idtrees.py +++ b/tests/datasets/test_idtrees.py @@ -13,22 +13,16 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, IDTReeS pytest.importorskip('laspy', minversion='2') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestIDTReeS: @pytest.fixture(params=zip(['train', 'test', 'test'], ['task1', 'task1', 'task2'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> IDTReeS: - monkeypatch.setattr(torchgeo.datasets.idtrees, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'idtrees') metadata = { 'train': { @@ -44,7 +38,7 @@ def dataset( } split, task = request.param monkeypatch.setattr(IDTReeS, 'metadata', metadata) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return IDTReeS(root, split, task, transforms, download=True, checksum=True) @@ -77,11 +71,11 @@ def test_already_downloaded(self, dataset: IDTReeS) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - IDTReeS(str(tmp_path)) + IDTReeS(tmp_path) def test_not_extracted(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'idtrees', '*.zip') - root = str(tmp_path) + root = tmp_path for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) IDTReeS(root) diff --git a/tests/datasets/test_inaturalist.py b/tests/datasets/test_inaturalist.py index 0f9a5424875..a1e255d7745 100644 --- a/tests/datasets/test_inaturalist.py +++ b/tests/datasets/test_inaturalist.py @@ -38,7 +38,7 @@ def test_or(self, dataset: INaturalist) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - INaturalist(str(tmp_path)) + INaturalist(tmp_path) def test_invalid_query(self, dataset: INaturalist) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_inria.py b/tests/datasets/test_inria.py index 21bcb1a900d..41ba41dee69 100644 --- a/tests/datasets/test_inria.py +++ b/tests/datasets/test_inria.py @@ -50,7 +50,7 @@ def test_already_downloaded(self, dataset: InriaAerialImageLabeling) -> None: def test_not_downloaded(self, tmp_path: str) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - InriaAerialImageLabeling(str(tmp_path)) + InriaAerialImageLabeling(tmp_path) def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None: InriaAerialImageLabeling.md5 = 'randommd5hash123' diff --git a/tests/datasets/test_iobench.py b/tests/datasets/test_iobench.py index 747d9ed1464..48aca59cad3 100644 --- a/tests/datasets/test_iobench.py +++ b/tests/datasets/test_iobench.py @@ -13,7 +13,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, DatasetNotFoundError, @@ -24,19 +23,14 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestIOBench: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> IOBench: - monkeypatch.setattr(torchgeo.datasets.iobench, 'download_url', download_url) md5 = 'e82398add7c35896a31c4398c608ef83' url = os.path.join('tests', 'data', 'iobench', '{}.tar.gz') monkeypatch.setattr(IOBench, 'url', url) monkeypatch.setitem(IOBench.md5s, 'preprocessed', md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return IOBench(root, transforms=transforms, download=True, checksum=True) @@ -68,14 +62,14 @@ def test_already_extracted(self, dataset: IOBench) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'iobench', '*.tar.gz') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) IOBench(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - IOBench(str(tmp_path)) + IOBench(tmp_path) def test_invalid_query(self, dataset: IOBench) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_l7irish.py b/tests/datasets/test_l7irish.py index f760ae89058..88ee559c5e9 100644 --- a/tests/datasets/test_l7irish.py +++ b/tests/datasets/test_l7irish.py @@ -14,7 +14,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, DatasetNotFoundError, @@ -25,14 +24,9 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestL7Irish: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L7Irish: - monkeypatch.setattr(torchgeo.datasets.l7irish, 'download_url', download_url) md5s = { 'austral': '0485d6045f6b508068ef8daf9e5a5326', 'boreal': '5798f32545d7166564c4c4429357b840', @@ -41,7 +35,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L7Irish: url = os.path.join('tests', 'data', 'l7irish', '{}.tar.gz') monkeypatch.setattr(L7Irish, 'url', url) monkeypatch.setattr(L7Irish, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return L7Irish(root, transforms=transforms, download=True, checksum=True) @@ -75,14 +69,14 @@ def test_already_extracted(self, dataset: L7Irish) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'l7irish', '*.tar.gz') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) L7Irish(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - L7Irish(str(tmp_path)) + L7Irish(tmp_path) def test_plot_prediction(self, dataset: L7Irish) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_l8biome.py b/tests/datasets/test_l8biome.py index d00cebb131a..0d3fd3b1c44 100644 --- a/tests/datasets/test_l8biome.py +++ b/tests/datasets/test_l8biome.py @@ -14,7 +14,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, DatasetNotFoundError, @@ -25,14 +24,9 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestL8Biome: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L8Biome: - monkeypatch.setattr(torchgeo.datasets.l8biome, 'download_url', download_url) md5s = { 'barren': '29c9910adbc89677389f210226fb163d', 'forest': 'b7dbb82fb2c22cbb03389d8828d73713', @@ -41,7 +35,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> L8Biome: url = os.path.join('tests', 'data', 'l8biome', '{}.tar.gz') monkeypatch.setattr(L8Biome, 'url', url) monkeypatch.setattr(L8Biome, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return L8Biome(root, transforms=transforms, download=True, checksum=True) @@ -75,14 +69,14 @@ def test_already_extracted(self, dataset: L8Biome) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) L8Biome(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - L8Biome(str(tmp_path)) + L8Biome(tmp_path) def test_plot_prediction(self, dataset: L8Biome) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index 7c81f257250..cda68604599 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -3,6 +3,7 @@ import os import shutil +from itertools import product from pathlib import Path import matplotlib.pyplot as plt @@ -13,28 +14,23 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, DatasetNotFoundError, LandCoverAI, + LandCoverAI100, LandCoverAIGeo, ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestLandCoverAIGeo: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LandCoverAIGeo: - monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url) md5 = 'ff8998857cc8511f644d3f7d0f3688d0' monkeypatch.setattr(LandCoverAIGeo, 'md5', md5) url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') monkeypatch.setattr(LandCoverAIGeo, 'url', url) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return LandCoverAIGeo(root, transforms=transforms, download=True, checksum=True) @@ -49,13 +45,13 @@ def test_already_extracted(self, dataset: LandCoverAIGeo) -> None: def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') - root = str(tmp_path) + root = tmp_path shutil.copy(url, root) LandCoverAIGeo(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LandCoverAIGeo(str(tmp_path)) + LandCoverAIGeo(tmp_path) def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) @@ -78,21 +74,25 @@ def test_plot(self, dataset: LandCoverAIGeo) -> None: class TestLandCoverAI: pytest.importorskip('cv2', minversion='4.5.4') - @pytest.fixture(params=['train', 'val', 'test']) + @pytest.fixture( + params=product([LandCoverAI100, LandCoverAI], ['train', 'val', 'test']) + ) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LandCoverAI: - monkeypatch.setattr(torchgeo.datasets.landcoverai, 'download_url', download_url) + base_class: type[LandCoverAI] = request.param[0] + split: str = request.param[1] md5 = 'ff8998857cc8511f644d3f7d0f3688d0' - monkeypatch.setattr(LandCoverAI, 'md5', md5) + monkeypatch.setattr(base_class, 'md5', md5) url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') - monkeypatch.setattr(LandCoverAI, 'url', url) + monkeypatch.setattr(base_class, 'url', url) sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' - monkeypatch.setattr(LandCoverAI, 'sha256', sha256) - root = str(tmp_path) - split = request.param + monkeypatch.setattr(base_class, 'sha256', sha256) + if base_class == LandCoverAI100: + monkeypatch.setattr(base_class, 'filename', 'landcover.ai.v1.zip') + root = tmp_path transforms = nn.Identity() - return LandCoverAI(root, split, transforms, download=True, checksum=True) + return base_class(root, split, transforms, download=True, checksum=True) def test_getitem(self, dataset: LandCoverAI) -> None: x = dataset[0] @@ -115,13 +115,13 @@ def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> N sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' monkeypatch.setattr(LandCoverAI, 'sha256', sha256) url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip') - root = str(tmp_path) + root = tmp_path shutil.copy(url, root) LandCoverAI(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LandCoverAI(str(tmp_path)) + LandCoverAI(tmp_path) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_landsat.py b/tests/datasets/test_landsat.py index 51e6a6fef24..7bae5cba57f 100644 --- a/tests/datasets/test_landsat.py +++ b/tests/datasets/test_landsat.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn from _pytest.fixtures import SubRequest -from pytest import MonkeyPatch from rasterio.crs import CRS from torchgeo.datasets import ( @@ -29,7 +28,7 @@ class TestLandsat8: ['SR_B4', 'SR_B3', 'SR_B2', 'SR_QA_AEROSOL'], ] ) - def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> Landsat8: + def dataset(self, request: SubRequest) -> Landsat8: root = os.path.join('tests', 'data', 'landsat8') bands = request.param transforms = nn.Identity() @@ -71,7 +70,7 @@ def test_plot_wrong_bands(self, dataset: Landsat8) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Landsat8(str(tmp_path)) + Landsat8(tmp_path) def test_invalid_query(self, dataset: Landsat8) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_levircd.py b/tests/datasets/test_levircd.py index cbec555746c..1105d7e8569 100644 --- a/tests/datasets/test_levircd.py +++ b/tests/datasets/test_levircd.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,14 +11,9 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import LEVIRCD, DatasetNotFoundError, LEVIRCDPlus -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestLEVIRCD: @pytest.fixture(params=['train', 'val', 'test']) def dataset( @@ -43,9 +37,8 @@ def dataset( 'md5': '021db72d4486726d6a0702563a617b32', }, } - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) monkeypatch.setattr(LEVIRCD, 'splits', splits) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return LEVIRCD(root, split, transforms, download=True, checksum=True) @@ -71,7 +64,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LEVIRCD(str(tmp_path)) + LEVIRCD(tmp_path) def test_plot(self, dataset: LEVIRCD) -> None: dataset.plot(dataset[0], suptitle='Test') @@ -88,12 +81,11 @@ class TestLEVIRCDPlus: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LEVIRCDPlus: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) md5 = '0ccca34310bfe7096dadfbf05b0d180f' monkeypatch.setattr(LEVIRCDPlus, 'md5', md5) url = os.path.join('tests', 'data', 'levircd', 'levircdplus', 'LEVIR-CD+.zip') monkeypatch.setattr(LEVIRCDPlus, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return LEVIRCDPlus(root, split, transforms, download=True, checksum=True) @@ -119,7 +111,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LEVIRCDPlus(str(tmp_path)) + LEVIRCDPlus(tmp_path) def test_plot(self, dataset: LEVIRCDPlus) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_loveda.py b/tests/datasets/test_loveda.py index be36bec2f1e..580c0529021 100644 --- a/tests/datasets/test_loveda.py +++ b/tests/datasets/test_loveda.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,20 +11,14 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, LoveDA -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestLoveDA: @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> LoveDA: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) md5 = '3d5b1373ef9a3084ec493b9b2056fe07' info_dict = { @@ -48,7 +41,7 @@ def dataset( monkeypatch.setattr(LoveDA, 'info_dict', info_dict) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return LoveDA( @@ -84,7 +77,7 @@ def test_invalid_scene(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - LoveDA(str(tmp_path)) + LoveDA(tmp_path) def test_plot(self, dataset: LoveDA) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_mapinwild.py b/tests/datasets/test_mapinwild.py index aff7d200099..2980499e19c 100644 --- a/tests/datasets/test_mapinwild.py +++ b/tests/datasets/test_mapinwild.py @@ -14,21 +14,14 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, MapInWild -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestMapInWild: @pytest.fixture(params=['train', 'validation', 'test']) def dataset( self, tmp_path: Path, monkeypatch: MonkeyPatch, request: SubRequest ) -> MapInWild: - monkeypatch.setattr(torchgeo.datasets.mapinwild, 'download_url', download_url) - md5s = { 'ESA_WC.zip': '3a1e696353d238c50996958855da02fc', 'VIIRS.zip': 'e8b0e230edb1183c02092357af83bd52', @@ -53,7 +46,7 @@ def dataset( urls = os.path.join('tests', 'data', 'mapinwild') monkeypatch.setattr(MapInWild, 'url', urls) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() @@ -98,12 +91,12 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - MapInWild(root=str(tmp_path)) + MapInWild(root=tmp_path) def test_downloaded_not_extracted(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'mapinwild', '*', '*') pathname_glob = glob.glob(pathname) - root = str(tmp_path) + root = tmp_path for zipfile in pathname_glob: shutil.copy(zipfile, root) MapInWild(root, download=False) @@ -111,7 +104,7 @@ def test_downloaded_not_extracted(self, tmp_path: Path) -> None: def test_corrupted(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'mapinwild', '**', '*.zip') pathname_glob = glob.glob(pathname, recursive=True) - root = str(tmp_path) + root = tmp_path for zipfile in pathname_glob: shutil.copy(zipfile, root) splitfile = os.path.join( @@ -121,10 +114,10 @@ def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'mask.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - MapInWild(root=str(tmp_path), download=True, checksum=True) + MapInWild(root=tmp_path, download=True, checksum=True) def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None: - MapInWild(root=str(tmp_path), modality=dataset.modality, download=True) + MapInWild(root=tmp_path, modality=dataset.modality, download=True) def test_plot(self, dataset: MapInWild) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py index 349006ce248..8b1dcef988a 100644 --- a/tests/datasets/test_millionaid.py +++ b/tests/datasets/test_millionaid.py @@ -39,18 +39,18 @@ def test_len(self, dataset: MillionAID) -> None: def test_not_found(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - MillionAID(str(tmp_path)) + MillionAID(tmp_path) def test_not_extracted(self, tmp_path: Path) -> None: url = os.path.join('tests', 'data', 'millionaid', 'train.zip') shutil.copy(url, tmp_path) - MillionAID(str(tmp_path)) + MillionAID(tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'train.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - MillionAID(str(tmp_path), checksum=True) + MillionAID(tmp_path, checksum=True) def test_plot(self, dataset: MillionAID) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_mmearth.py b/tests/datasets/test_mmearth.py new file mode 100644 index 00000000000..c25c2a1dece --- /dev/null +++ b/tests/datasets/test_mmearth.py @@ -0,0 +1,144 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest + +from torchgeo.datasets import DatasetNotFoundError, MMEarth + +pytest.importorskip('h5py', minversion='3.6') + +data_dir_dict = { + 'MMEarth': os.path.join('tests', 'data', 'mmearth', 'data_1M_v001'), + 'MMEarth64': os.path.join('tests', 'data', 'mmearth', 'data_1M_v001_64'), + 'MMEarth100k': os.path.join('tests', 'data', 'mmearth', 'data_100k_v001'), +} + + +class TestMMEarth: + @pytest.fixture(params=['MMEarth', 'MMEarth64', 'MMEarth100k']) + def dataset(self, tmp_path: Path, request: SubRequest) -> MMEarth: + root = tmp_path + subset = request.param + shutil.copytree(data_dir_dict[subset], root / Path(data_dir_dict[subset]).name) + transforms = nn.Identity() + return MMEarth(root, subset=subset, transforms=transforms) + + def test_getitem(self, dataset: MMEarth) -> None: + x = dataset[0] + assert isinstance(x, dict) + for modality in dataset.modalities: + modality_name = dataset.modality_category_name.get(modality, '') + modality + assert modality_name in x + assert isinstance(x[modality_name], torch.Tensor) + assert x[modality_name].shape[0] == len(dataset.modality_bands[modality]) + + def test_subset_modalities(self, dataset: MMEarth) -> None: + specified_modalities = ['sentinel2', 'dynamic_world'] + dataset = MMEarth( + dataset.root, subset=dataset.subset, modalities=specified_modalities + ) + x = dataset[0] + assert isinstance(x, dict) + + for modality in dataset.modalities: + modality_name = dataset.modality_category_name.get(modality, '') + modality + if modality in specified_modalities: + assert modality_name in x + else: + assert modality_name not in x + + def test_dataset_not_found(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + MMEarth(tmp_path) + + def test_invalid_modalities(self, dataset: MMEarth) -> None: + with pytest.raises(ValueError, match='is an invalid modality'): + MMEarth(dataset.root, subset=dataset.subset, modalities=['invalid']) + + def test_invalid_modality_bands_modality_name(self, dataset: MMEarth) -> None: + with pytest.raises(ValueError, match='is an invalid modality name'): + MMEarth( + dataset.root, + subset=dataset.subset, + modality_bands={'invalid': ['invalid']}, + ) + + def test_invalid_modality_bands(self, dataset: MMEarth) -> None: + with pytest.raises(ValueError, match='is an invalid band name for modality'): + MMEarth( + dataset.root, + subset=dataset.subset, + modality_bands={'sentinel2': ['invalid']}, + ) + + @pytest.mark.parametrize( + 'modality_bands, modalities', + [ + ({'sentinel2': ['B2', 'B3']}, ['sentinel2']), + ( + {'sentinel1_asc': ['VV'], 'sentinel1_desc': ['VH']}, + ['sentinel1_asc', 'sentinel1_desc'], + ), + ], + ) + def test_subset_modaliy_bands( + self, + dataset: MMEarth, + modality_bands: dict[str, list[str]], + modalities: list[str], + ) -> None: + dataset = MMEarth( + dataset.root, + subset=dataset.subset, + modalities=modalities, + modality_bands=modality_bands, + ) + x = dataset[0] + assert isinstance(x, dict) + + for modality in dataset.modalities: + modality_name = dataset.modality_category_name.get(modality, '') + modality + if modality in modality_bands: + assert modality_name in x + assert x[modality_name].shape[0] == len(modality_bands[modality]) + else: + assert modality_name not in x + + def test_sentinel1_asc_desc(self, dataset: MMEarth) -> None: + modality_bands = {'sentinel1_asc': ['VV'], 'sentinel1_desc': ['VH']} + dataset = MMEarth( + dataset.root, + subset=dataset.subset, + modalities=['sentinel1_asc', 'sentinel1_desc'], + modality_bands=modality_bands, + ) + x = dataset[0] + assert isinstance(x, dict) + + for modality in dataset.modalities: + modality_name = dataset.modality_category_name.get(modality, '') + modality + if modality in modality_bands: + assert modality_name in x + assert x[modality_name].shape[0] == len(modality_bands[modality]) + else: + assert modality_name not in x + + @pytest.mark.parametrize('normalization_mode', ['z-score', 'min-max']) + def test_normalization_mode( + self, dataset: MMEarth, normalization_mode: str + ) -> None: + dataset = MMEarth( + dataset.root, subset=dataset.subset, normalization_mode=normalization_mode + ) + x = dataset[0] + assert isinstance(x, dict) + + def test_len(self, dataset: MMEarth) -> None: + assert len(dataset) >= 2 diff --git a/tests/datasets/test_naip.py b/tests/datasets/test_naip.py index 580b309b432..ea54ec881a3 100644 --- a/tests/datasets/test_naip.py +++ b/tests/datasets/test_naip.py @@ -51,7 +51,7 @@ def test_plot(self, dataset: NAIP) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - NAIP(str(tmp_path)) + NAIP(tmp_path) def test_invalid_query(self, dataset: NAIP) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index 588cd89174a..e2787b51b7d 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -13,41 +11,18 @@ from pytest import MonkeyPatch from torchgeo.datasets import DatasetNotFoundError, NASAMarineDebris - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join('tests', 'data', 'nasa_marine_debris', '*.tar.gz') - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(collection_id: str, **kwargs: str) -> Collection: - return Collection() - - -class Collection_corrupted: - def download(self, output_dir: str, **kwargs: str) -> None: - filenames = NASAMarineDebris.filenames - for filename in filenames: - with open(os.path.join(output_dir, filename), 'w') as f: - f.write('bad') - - -def fetch_corrupted(collection_id: str, **kwargs: str) -> Collection_corrupted: - return Collection_corrupted() +from torchgeo.datasets.utils import Executable class TestNASAMarineDebris: - @pytest.fixture() - def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NASAMarineDebris: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - md5s = ['6f4f0d2313323950e45bf3fc0c09b5de', '540cf1cf4fd2c13b609d0355abe955d7'] - monkeypatch.setattr(NASAMarineDebris, 'md5s', md5s) - root = str(tmp_path) + @pytest.fixture + def dataset( + self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> NASAMarineDebris: + url = os.path.join('tests', 'data', 'nasa_marine_debris') + monkeypatch.setattr(NASAMarineDebris, 'url', url) transforms = nn.Identity() - return NASAMarineDebris(root, transforms, download=True, checksum=True) + return NASAMarineDebris(tmp_path, transforms, download=True) def test_getitem(self, dataset: NASAMarineDebris) -> None: x = dataset[0] @@ -58,40 +33,16 @@ def test_getitem(self, dataset: NASAMarineDebris) -> None: assert x['boxes'].shape[-1] == 4 def test_len(self, dataset: NASAMarineDebris) -> None: - assert len(dataset) == 4 + assert len(dataset) == 5 def test_already_downloaded( self, dataset: NASAMarineDebris, tmp_path: Path ) -> None: - NASAMarineDebris(root=str(tmp_path), download=True) - - def test_already_downloaded_not_extracted( - self, dataset: NASAMarineDebris, tmp_path: Path - ) -> None: - shutil.rmtree(dataset.root) - os.makedirs(str(tmp_path), exist_ok=True) - Collection().download(output_dir=str(tmp_path)) - NASAMarineDebris(root=str(tmp_path), download=False) - - def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None: - filenames = NASAMarineDebris.filenames - for filename in filenames: - with open(os.path.join(tmp_path, filename), 'w') as f: - f.write('bad') - with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'): - NASAMarineDebris(root=str(tmp_path), download=False, checksum=True) - - def test_corrupted_new_download( - self, tmp_path: Path, monkeypatch: MonkeyPatch - ) -> None: - with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'): - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_corrupted) - NASAMarineDebris(root=str(tmp_path), download=True, checksum=True) + NASAMarineDebris(tmp_path, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - NASAMarineDebris(str(tmp_path)) + NASAMarineDebris(tmp_path) def test_plot(self, dataset: NASAMarineDebris) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_nccm.py b/tests/datasets/test_nccm.py index 8def40c4c4e..cb4dfa5c4ef 100644 --- a/tests/datasets/test_nccm.py +++ b/tests/datasets/test_nccm.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,7 +11,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( NCCM, BoundingBox, @@ -22,14 +20,9 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestNCCM: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM: - monkeypatch.setattr(torchgeo.datasets.nccm, 'download_url', download_url) md5s = { 2017: 'ae5c390d0ffb8970d544b8a09142759f', 2018: '0d453bdb8ea5b7318c33e62513760580', @@ -43,7 +36,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NCCM: } monkeypatch.setattr(NCCM, 'urls', urls) transforms = nn.Identity() - root = str(tmp_path) + root = tmp_path return NCCM(root, transforms=transforms, download=True, checksum=True) def test_getitem(self, dataset: NCCM) -> None: @@ -84,7 +77,7 @@ def test_plot_prediction(self, dataset: NCCM) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - NCCM(str(tmp_path)) + NCCM(tmp_path) def test_invalid_query(self, dataset: NCCM) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_nlcd.py b/tests/datasets/test_nlcd.py index c24220100b1..6f8934a2996 100644 --- a/tests/datasets/test_nlcd.py +++ b/tests/datasets/test_nlcd.py @@ -12,7 +12,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( NLCD, BoundingBox, @@ -22,27 +21,19 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestNLCD: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> NLCD: - monkeypatch.setattr(torchgeo.datasets.nlcd, 'download_url', download_url) - md5s = { - 2011: '99546a3b89a0dddbe4e28e661c79984e', - 2019: 'a4008746f15720b8908ddd357a75fded', + 2011: '3346297a3cb53c9bd1c7e03b2e6e2d74', + 2019: 'a307cdaa1add9dae05efe02fec4c33bb', } monkeypatch.setattr(NLCD, 'md5s', md5s) - url = os.path.join( - 'tests', 'data', 'nlcd', 'nlcd_{}_land_cover_l48_20210604.zip' - ) + url = os.path.join('tests', 'data', 'nlcd', 'Annual_NLCD_LndCov_{}_CU_C1V0.tif') monkeypatch.setattr(NLCD, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return NLCD( root, @@ -82,9 +73,9 @@ def test_already_extracted(self, dataset: NLCD) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( - 'tests', 'data', 'nlcd', 'nlcd_2019_land_cover_l48_20210604.zip' + 'tests', 'data', 'nlcd', 'Annual_NLCD_LndCov_2019_CU_C1V0.tif' ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) NLCD(root, years=[2019]) @@ -93,7 +84,7 @@ def test_invalid_year(self, tmp_path: Path) -> None: AssertionError, match='NLCD data product only exists for the following years:', ): - NLCD(str(tmp_path), years=[1996]) + NLCD(tmp_path, years=[1984]) def test_invalid_classes(self) -> None: with pytest.raises(AssertionError): @@ -117,7 +108,7 @@ def test_plot_prediction(self, dataset: NLCD) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - NLCD(str(tmp_path)) + NLCD(tmp_path) def test_invalid_query(self, dataset: NLCD) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_openbuildings.py b/tests/datasets/test_openbuildings.py index 38610ee7195..a322e4e715e 100644 --- a/tests/datasets/test_openbuildings.py +++ b/tests/datasets/test_openbuildings.py @@ -26,7 +26,7 @@ class TestOpenBuildings: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> OpenBuildings: - root = str(tmp_path) + root = tmp_path shutil.copy( os.path.join('tests', 'data', 'openbuildings', 'tiles.geojson'), root ) @@ -55,7 +55,7 @@ def test_no_shapes_to_rasterize( def test_not_download(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - OpenBuildings(str(tmp_path)) + OpenBuildings(tmp_path) def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None: with open(os.path.join(tmp_path, '000_buildings.csv.gz'), 'w') as f: diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index cd1c80a443b..711392f7fc4 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -14,20 +14,14 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import OSCD, DatasetNotFoundError, RGBBandsMissingError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestOSCD: @pytest.fixture(params=zip([OSCD.all_bands, OSCD.rgb_bands], ['train', 'test'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> OSCD: - monkeypatch.setattr(torchgeo.datasets.oscd, 'download_url', download_url) md5s = { 'Onera Satellite Change Detection dataset - Images.zip': ( 'fb4e3f54c3a31fd3f21f98cad4ddfb74' @@ -63,7 +57,7 @@ def dataset( monkeypatch.setattr(OSCD, 'urls', urls) bands, split = request.param - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return OSCD( root, split, bands, transforms=transforms, download=True, checksum=True @@ -101,14 +95,14 @@ def test_already_extracted(self, dataset: OSCD) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'oscd', '*Onera*.zip') - root = str(tmp_path) + root = tmp_path for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) OSCD(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - OSCD(str(tmp_path)) + OSCD(tmp_path) def test_plot(self, dataset: OSCD) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_pastis.py b/tests/datasets/test_pastis.py index 62ff5f913e6..be327e1628d 100644 --- a/tests/datasets/test_pastis.py +++ b/tests/datasets/test_pastis.py @@ -13,14 +13,9 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import PASTIS, DatasetNotFoundError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestPASTIS: @pytest.fixture( params=[ @@ -32,13 +27,11 @@ class TestPASTIS: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> PASTIS: - monkeypatch.setattr(torchgeo.datasets.pastis, 'download_url', download_url) - md5 = '135a29fb8221241dde14f31579c07f45' monkeypatch.setattr(PASTIS, 'md5', md5) url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip') monkeypatch.setattr(PASTIS, 'url', url) - root = str(tmp_path) + root = tmp_path folds = request.param['folds'] bands = request.param['bands'] mode = request.param['mode'] @@ -75,19 +68,19 @@ def test_already_extracted(self, dataset: PASTIS) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip') - root = str(tmp_path) + root = tmp_path shutil.copy(url, root) PASTIS(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - PASTIS(str(tmp_path)) + PASTIS(tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'PASTIS-R.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - PASTIS(root=str(tmp_path), checksum=True) + PASTIS(root=tmp_path, checksum=True) def test_invalid_fold(self) -> None: with pytest.raises(AssertionError): diff --git a/tests/datasets/test_patternnet.py b/tests/datasets/test_patternnet.py index 915d7388bad..e4c18bdba59 100644 --- a/tests/datasets/test_patternnet.py +++ b/tests/datasets/test_patternnet.py @@ -11,23 +11,17 @@ import torch.nn as nn from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, PatternNet -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestPatternNet: @pytest.fixture(params=['train', 'test']) def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> PatternNet: - monkeypatch.setattr(torchgeo.datasets.patternnet, 'download_url', download_url) md5 = '5649754c78219a2c19074ff93666cc61' monkeypatch.setattr(PatternNet, 'md5', md5) url = os.path.join('tests', 'data', 'patternnet', 'PatternNet.zip') monkeypatch.setattr(PatternNet, 'url', url) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return PatternNet(root, transforms, download=True, checksum=True) @@ -42,18 +36,18 @@ def test_len(self, dataset: PatternNet) -> None: assert len(dataset) == 2 def test_already_downloaded(self, dataset: PatternNet, tmp_path: Path) -> None: - PatternNet(root=str(tmp_path), download=True) + PatternNet(root=tmp_path, download=True) def test_already_downloaded_not_extracted( self, dataset: PatternNet, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - download_url(dataset.url, root=str(tmp_path)) - PatternNet(root=str(tmp_path), download=False) + shutil.copy(dataset.url, tmp_path) + PatternNet(root=tmp_path, download=False) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - PatternNet(str(tmp_path)) + PatternNet(tmp_path) def test_plot(self, dataset: PatternNet) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_potsdam.py b/tests/datasets/test_potsdam.py index 4529d937690..9de329686d0 100644 --- a/tests/datasets/test_potsdam.py +++ b/tests/datasets/test_potsdam.py @@ -43,9 +43,9 @@ def test_extract(self, tmp_path: Path) -> None: root = os.path.join('tests', 'data', 'potsdam') for filename in ['4_Ortho_RGBIR.zip', '5_Labels_all.zip']: shutil.copyfile( - os.path.join(root, filename), os.path.join(str(tmp_path), filename) + os.path.join(root, filename), os.path.join(tmp_path, filename) ) - Potsdam2D(root=str(tmp_path)) + Potsdam2D(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, '4_Ortho_RGBIR.zip'), 'w') as f: @@ -53,7 +53,7 @@ def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, '5_Labels_all.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - Potsdam2D(root=str(tmp_path), checksum=True) + Potsdam2D(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -61,7 +61,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Potsdam2D(str(tmp_path)) + Potsdam2D(tmp_path) def test_plot(self, dataset: Potsdam2D) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_prisma.py b/tests/datasets/test_prisma.py index 89ab52c7275..d43af61e97f 100644 --- a/tests/datasets/test_prisma.py +++ b/tests/datasets/test_prisma.py @@ -50,7 +50,7 @@ def test_plot(self, dataset: PRISMA) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - PRISMA(str(tmp_path)) + PRISMA(tmp_path) def test_invalid_query(self, dataset: PRISMA) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_quakeset.py b/tests/datasets/test_quakeset.py index 636e9d7a666..fbb6ea29234 100644 --- a/tests/datasets/test_quakeset.py +++ b/tests/datasets/test_quakeset.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -12,27 +11,21 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, QuakeSet pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestQuakeSet: @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> QuakeSet: - monkeypatch.setattr(torchgeo.datasets.quakeset, 'download_url', download_url) url = os.path.join('tests', 'data', 'quakeset', 'earthquakes.h5') md5 = '127d0d6a1f82d517129535f50053a4c9' monkeypatch.setattr(QuakeSet, 'md5', md5) monkeypatch.setattr(QuakeSet, 'url', url) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return QuakeSet( @@ -50,11 +43,11 @@ def test_len(self, dataset: QuakeSet) -> None: assert len(dataset) == 8 def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None: - QuakeSet(root=str(tmp_path), download=True) + QuakeSet(root=tmp_path, download=True) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - QuakeSet(str(tmp_path)) + QuakeSet(tmp_path) def test_plot(self, dataset: QuakeSet) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_reforestree.py b/tests/datasets/test_reforestree.py index 092e7cf2f1f..c0ab375d9f6 100644 --- a/tests/datasets/test_reforestree.py +++ b/tests/datasets/test_reforestree.py @@ -11,27 +11,18 @@ import torch.nn as nn from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, ReforesTree -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestReforesTree: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ReforesTree: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'reforestree') - url = os.path.join(data_dir, 'reforesTree.zip') - md5 = '387e04dbbb0aa803f72bd6d774409648' - monkeypatch.setattr(ReforesTree, 'url', url) monkeypatch.setattr(ReforesTree, 'md5', md5) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ReforesTree( root=root, transforms=transforms, download=True, checksum=True @@ -57,17 +48,17 @@ def test_len(self, dataset: ReforesTree) -> None: def test_not_extracted(self, tmp_path: Path) -> None: url = os.path.join('tests', 'data', 'reforestree', 'reforesTree.zip') shutil.copy(url, tmp_path) - ReforesTree(root=str(tmp_path)) + ReforesTree(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, 'reforesTree.zip'), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - ReforesTree(root=str(tmp_path), checksum=True) + ReforesTree(root=tmp_path, checksum=True) def test_not_found(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ReforesTree(str(tmp_path)) + ReforesTree(tmp_path) def test_plot(self, dataset: ReforesTree) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_resisc45.py b/tests/datasets/test_resisc45.py index d52d2d01194..adcd59dc004 100644 --- a/tests/datasets/test_resisc45.py +++ b/tests/datasets/test_resisc45.py @@ -12,25 +12,15 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import RESISC45, DatasetNotFoundError -pytest.importorskip('rarfile', minversion='4') - - -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - class TestRESISC45: @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> RESISC45: - monkeypatch.setattr(torchgeo.datasets.resisc45, 'download_url', download_url) - md5 = '5895dea3757ba88707d52f5521c444d3' - monkeypatch.setattr(RESISC45, 'md5', md5) - url = os.path.join('tests', 'data', 'resisc45', 'NWPU-RESISC45.rar') + url = os.path.join('tests', 'data', 'resisc45', 'NWPU-RESISC45.zip') monkeypatch.setattr(RESISC45, 'url', url) monkeypatch.setattr( RESISC45, @@ -52,7 +42,7 @@ def dataset( 'test': '7760b1960c9a3ff46fb985810815e14d', }, ) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return RESISC45(root, split, transforms, download=True, checksum=True) @@ -68,18 +58,18 @@ def test_len(self, dataset: RESISC45) -> None: assert len(dataset) == 9 def test_already_downloaded(self, dataset: RESISC45, tmp_path: Path) -> None: - RESISC45(root=str(tmp_path), download=True) + RESISC45(root=tmp_path, download=True) def test_already_downloaded_not_extracted( self, dataset: RESISC45, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - download_url(dataset.url, root=str(tmp_path)) - RESISC45(root=str(tmp_path), download=False) + shutil.copy(dataset.url, tmp_path) + RESISC45(root=tmp_path, download=False) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - RESISC45(str(tmp_path)) + RESISC45(tmp_path) def test_plot(self, dataset: RESISC45) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_rwanda_field_boundary.py b/tests/datasets/test_rwanda_field_boundary.py index 6f83b12a93d..d08532e7507 100644 --- a/tests/datasets/test_rwanda_field_boundary.py +++ b/tests/datasets/test_rwanda_field_boundary.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -19,45 +17,26 @@ RGBBandsMissingError, RwandaFieldBoundary, ) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join('tests', 'data', 'rwanda_field_boundary', '*.tar.gz') - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch(dataset_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestRwandaFieldBoundary: @pytest.fixture(params=['train', 'test']) def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + self, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, + request: SubRequest, ) -> RwandaFieldBoundary: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - monkeypatch.setattr( - RwandaFieldBoundary, 'number_of_patches_per_split', {'train': 5, 'test': 5} - ) - monkeypatch.setattr( - RwandaFieldBoundary, - 'md5s', - { - 'train_images': 'af9395e2e49deefebb35fa65fa378ba3', - 'test_images': 'd104bb82323a39e7c3b3b7dd0156f550', - 'train_labels': '6cceaf16a141cf73179253a783e7d51b', - }, - ) + url = os.path.join('tests', 'data', 'rwanda_field_boundary') + monkeypatch.setattr(RwandaFieldBoundary, 'url', url) + monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1}) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() - return RwandaFieldBoundary( - root, split, transforms=transforms, api_key='', download=True, checksum=True - ) + return RwandaFieldBoundary(root, split, transforms=transforms, download=True) def test_getitem(self, dataset: RwandaFieldBoundary) -> None: x = dataset[0] @@ -69,60 +48,22 @@ def test_getitem(self, dataset: RwandaFieldBoundary) -> None: assert 'mask' not in x def test_len(self, dataset: RwandaFieldBoundary) -> None: - assert len(dataset) == 5 + assert len(dataset) == 1 def test_add(self, dataset: RwandaFieldBoundary) -> None: ds = dataset + dataset assert isinstance(ds, ConcatDataset) - assert len(ds) == 10 - - def test_needs_extraction(self, tmp_path: Path) -> None: - root = str(tmp_path) - for fn in [ - 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - ]: - url = os.path.join('tests', 'data', 'rwanda_field_boundary', fn) - shutil.copy(url, root) - RwandaFieldBoundary(root, checksum=False) + assert len(ds) == 2 def test_already_downloaded(self, dataset: RwandaFieldBoundary) -> None: RwandaFieldBoundary(root=dataset.root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - RwandaFieldBoundary(str(tmp_path)) - - def test_corrupted(self, tmp_path: Path) -> None: - for fn in [ - 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - ]: - with open(os.path.join(tmp_path, fn), 'w') as f: - f.write('bad') - with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - RwandaFieldBoundary(root=str(tmp_path), checksum=True) - - def test_failed_download(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - monkeypatch.setattr( - RwandaFieldBoundary, - 'md5s', - {'train_images': 'bad', 'test_images': 'bad', 'train_labels': 'bad'}, - ) - root = str(tmp_path) - with pytest.raises(RuntimeError, match='Dataset not found or corrupted.'): - RwandaFieldBoundary(root, 'train', api_key='', download=True, checksum=True) - - def test_no_api_key(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match='Must provide an API key to download'): - RwandaFieldBoundary(str(tmp_path), api_key=None, download=True) + RwandaFieldBoundary(tmp_path) def test_invalid_bands(self) -> None: - with pytest.raises(ValueError, match='is an invalid band name.'): + with pytest.raises(AssertionError): RwandaFieldBoundary(bands=('foo', 'bar')) def test_plot(self, dataset: RwandaFieldBoundary) -> None: diff --git a/tests/datasets/test_satlas.py b/tests/datasets/test_satlas.py new file mode 100644 index 00000000000..7c10f55bd7b --- /dev/null +++ b/tests/datasets/test_satlas.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch.nn as nn +from pytest import MonkeyPatch +from torch import Tensor + +from torchgeo.datasets import DatasetNotFoundError, SatlasPretrain +from torchgeo.datasets.utils import Executable + + +class TestSatlasPretrain: + @pytest.fixture + def dataset( + self, aws: Executable, monkeypatch: MonkeyPatch, tmp_path: Path + ) -> SatlasPretrain: + url = os.path.join('tests', 'data', 'satlas', '') + monkeypatch.setattr(SatlasPretrain, 'url', url) + images = ('landsat', 'naip', 'sentinel1', 'sentinel2') + products = (*images, 'static', 'metadata') + tarballs = {product: (f'{product}.tar',) for product in products} + monkeypatch.setattr(SatlasPretrain, 'tarballs', tarballs) + transforms = nn.Identity() + return SatlasPretrain( + tmp_path, images=images, transforms=transforms, download=True + ) + + @pytest.mark.parametrize('index', [0, 1]) + def test_getitem(self, dataset: SatlasPretrain, index: int) -> None: + x = dataset[index] + assert isinstance(x, dict) + for image in dataset.images: + assert isinstance(x[f'image_{image}'], Tensor) + assert isinstance(x[f'time_{image}'], Tensor) + for label in dataset.labels: + assert isinstance(x[f'mask_{label}'], Tensor) + + def test_len(self, dataset: SatlasPretrain) -> None: + assert len(dataset) == 2 + + def test_already_downloaded(self, dataset: SatlasPretrain) -> None: + shutil.rmtree(os.path.join(dataset.root, 'landsat')) + SatlasPretrain(root=dataset.root, download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + SatlasPretrain(tmp_path) + + def test_plot(self, dataset: SatlasPretrain) -> None: + x = dataset[0] + x['prediction_land_cover'] = x['mask_land_cover'] + dataset.plot(x, suptitle='Test') + plt.close() diff --git a/tests/datasets/test_seasonet.py b/tests/datasets/test_seasonet.py index 9178dcb1217..93e5a99fc59 100644 --- a/tests/datasets/test_seasonet.py +++ b/tests/datasets/test_seasonet.py @@ -14,17 +14,9 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, SeasoNet -def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - torchgeo.datasets.utils.check_integrity( - os.path.join(root, os.path.basename(url)), md5 - ) - - class TestSeasoNet: @pytest.fixture( params=zip( @@ -38,7 +30,6 @@ class TestSeasoNet: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SeasoNet: - monkeypatch.setattr(torchgeo.datasets.seasonet, 'download_url', download_url) monkeypatch.setitem( SeasoNet.metadata[0], 'md5', '836a0896eba0e3005208f3fd180e429d' ) @@ -95,7 +86,7 @@ def dataset( 'url', os.path.join('tests', 'data', 'seasonet', 'meta.csv'), ) - root = str(tmp_path) + root = tmp_path split, seasons, bands, grids, concat_seasons = request.param transforms = nn.Identity() return SeasoNet( @@ -141,14 +132,14 @@ def test_already_extracted(self, dataset: SeasoNet) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: paths = os.path.join('tests', 'data', 'seasonet', '*.*') - root = str(tmp_path) + root = tmp_path for path in glob.iglob(paths): shutil.copy(path, root) SeasoNet(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SeasoNet(str(tmp_path), download=False) + SeasoNet(tmp_path, download=False) def test_out_of_bounds(self, dataset: SeasoNet) -> None: with pytest.raises(IndexError): diff --git a/tests/datasets/test_seco.py b/tests/datasets/test_seco.py index ed273a8810c..1e7570808ea 100644 --- a/tests/datasets/test_seco.py +++ b/tests/datasets/test_seco.py @@ -14,7 +14,6 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import ( DatasetNotFoundError, RGBBandsMissingError, @@ -22,10 +21,6 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestSeasonalContrastS2: @pytest.fixture( params=zip( @@ -37,7 +32,6 @@ class TestSeasonalContrastS2: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SeasonalContrastS2: - monkeypatch.setattr(torchgeo.datasets.seco, 'download_url', download_url) monkeypatch.setitem( SeasonalContrastS2.metadata['100k'], 'url', @@ -56,7 +50,7 @@ def dataset( monkeypatch.setitem( SeasonalContrastS2.metadata['1m'], 'md5', '3bb3fcf90f5de7d5781ce0cb85fd20af' ) - root = str(tmp_path) + root = tmp_path version, seasons, bands = request.param transforms = nn.Identity() return SeasonalContrastS2( @@ -88,7 +82,7 @@ def test_already_extracted(self, dataset: SeasonalContrastS2) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'seco', '*.zip') - root = str(tmp_path) + root = tmp_path for zipfile in glob.iglob(pathname): shutil.copy(zipfile, root) SeasonalContrastS2(root) @@ -103,7 +97,7 @@ def test_invalid_band(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SeasonalContrastS2(str(tmp_path)) + SeasonalContrastS2(tmp_path) def test_plot(self, dataset: SeasonalContrastS2) -> None: x = dataset[0] diff --git a/tests/datasets/test_sen12ms.py b/tests/datasets/test_sen12ms.py index 5732eaf18dd..b7ff8e8e978 100644 --- a/tests/datasets/test_sen12ms.py +++ b/tests/datasets/test_sen12ms.py @@ -66,10 +66,10 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SEN12MS(str(tmp_path), checksum=True) + SEN12MS(tmp_path, checksum=True) with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SEN12MS(str(tmp_path), checksum=False) + SEN12MS(tmp_path, checksum=False) def test_check_integrity_light(self) -> None: root = os.path.join('tests', 'data', 'sen12ms') diff --git a/tests/datasets/test_sentinel.py b/tests/datasets/test_sentinel.py index 28cf6eb1e67..ee4933b44f7 100644 --- a/tests/datasets/test_sentinel.py +++ b/tests/datasets/test_sentinel.py @@ -70,7 +70,7 @@ def test_plot(self, dataset: Sentinel2) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Sentinel1(str(tmp_path)) + Sentinel1(tmp_path) def test_empty_bands(self) -> None: with pytest.raises(AssertionError, match="'bands' cannot be an empty list"): @@ -132,7 +132,7 @@ def test_or(self, dataset: Sentinel2) -> None: def test_no_data(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Sentinel2(str(tmp_path)) + Sentinel2(tmp_path) def test_plot(self, dataset: Sentinel2) -> None: x = dataset[dataset.bounds] diff --git a/tests/datasets/test_skippd.py b/tests/datasets/test_skippd.py index d4deb975b3b..68f4e889df8 100644 --- a/tests/datasets/test_skippd.py +++ b/tests/datasets/test_skippd.py @@ -13,25 +13,17 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import SKIPPD, DatasetNotFoundError pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestSKIPPD: @pytest.fixture(params=product(['nowcast', 'forecast'], ['trainval', 'test'])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SKIPPD: task, split = request.param - - monkeypatch.setattr(torchgeo.datasets.skippd, 'download_url', download_url) - md5 = { 'nowcast': '6f5e54906927278b189f9281a2f54f39', 'forecast': 'f3b5d7d5c28ba238144fa1e726c46969', @@ -40,7 +32,7 @@ def dataset( url = os.path.join('tests', 'data', 'skippd', '{}') monkeypatch.setattr(SKIPPD, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return SKIPPD( root=root, @@ -59,7 +51,7 @@ def test_already_downloaded(self, tmp_path: Path, task: str) -> None: pathname = os.path.join( 'tests', 'data', 'skippd', f'2017_2019_images_pv_processed_{task}.zip' ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) SKIPPD(root=root, task=task) @@ -84,7 +76,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SKIPPD(str(tmp_path)) + SKIPPD(tmp_path) def test_plot(self, dataset: SKIPPD) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_skyscript.py b/tests/datasets/test_skyscript.py new file mode 100644 index 00000000000..f937ea48332 --- /dev/null +++ b/tests/datasets/test_skyscript.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path + +import pytest +import torch.nn as nn +from matplotlib import pyplot as plt +from pytest import MonkeyPatch +from torch import Tensor + +from torchgeo.datasets import DatasetNotFoundError, SkyScript + + +class TestSkyScript: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SkyScript: + url = os.path.join('tests', 'data', 'skyscript', '{}') + monkeypatch.setattr(SkyScript, 'url', url) + transforms = nn.Identity() + return SkyScript(tmp_path, transforms=transforms, download=True) + + def test_getitem(self, dataset: SkyScript) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], Tensor) + assert isinstance(x['caption'], str) + + def test_len(self, dataset: SkyScript) -> None: + assert len(dataset) == 2 + + def test_already_downloaded(self, dataset: SkyScript) -> None: + shutil.rmtree(os.path.join(dataset.root, 'images2')) + SkyScript(dataset.root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + SkyScript(tmp_path) + + def test_plot(self, dataset: SkyScript) -> None: + x = dataset[0] + x['prediction'] = x['caption'] + dataset.plot(x, suptitle='Test') + plt.close() diff --git a/tests/datasets/test_so2sat.py b/tests/datasets/test_so2sat.py index 1caf86b6c30..bc88662c16a 100644 --- a/tests/datasets/test_so2sat.py +++ b/tests/datasets/test_so2sat.py @@ -58,7 +58,7 @@ def test_invalid_bands(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - So2Sat(str(tmp_path)) + So2Sat(tmp_path) def test_plot(self, dataset: So2Sat) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_south_africa_crop_type.py b/tests/datasets/test_south_africa_crop_type.py index 75b014e2227..e274ed3a442 100644 --- a/tests/datasets/test_south_africa_crop_type.py +++ b/tests/datasets/test_south_africa_crop_type.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch from rasterio.crs import CRS from torchgeo.datasets import ( @@ -19,15 +20,25 @@ SouthAfricaCropType, UnionDataset, ) +from torchgeo.datasets.utils import Executable class TestSouthAfricaCropType: @pytest.fixture(params=[SouthAfricaCropType.s1_bands, SouthAfricaCropType.s2_bands]) - def dataset(self, request: SubRequest) -> SouthAfricaCropType: - path = os.path.join('tests', 'data', 'south_africa_crop_type') + def dataset( + self, + request: SubRequest, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, + ) -> SouthAfricaCropType: + url = os.path.join('tests', 'data', 'south_africa_crop_type') + monkeypatch.setattr(SouthAfricaCropType, 'url', url) bands = request.param transforms = nn.Identity() - return SouthAfricaCropType(path, bands=bands, transforms=transforms) + return SouthAfricaCropType( + tmp_path, bands=bands, transforms=transforms, download=True + ) def test_getitem(self, dataset: SouthAfricaCropType) -> None: x = dataset[dataset.bounds] @@ -52,7 +63,7 @@ def test_already_downloaded(self, dataset: SouthAfricaCropType) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SouthAfricaCropType(str(tmp_path)) + SouthAfricaCropType(tmp_path) def test_plot(self) -> None: path = os.path.join('tests', 'data', 'south_africa_crop_type') diff --git a/tests/datasets/test_south_america_soybean.py b/tests/datasets/test_south_america_soybean.py index c119dc2749b..f04bf14b221 100644 --- a/tests/datasets/test_south_america_soybean.py +++ b/tests/datasets/test_south_america_soybean.py @@ -11,7 +11,6 @@ from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import ( BoundingBox, DatasetNotFoundError, @@ -21,23 +20,16 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestSouthAmericaSoybean: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SouthAmericaSoybean: - monkeypatch.setattr( - torchgeo.datasets.south_america_soybean, 'download_url', download_url - ) transforms = nn.Identity() url = os.path.join( 'tests', 'data', 'south_america_soybean', 'SouthAmerica_Soybean_{}.tif' ) monkeypatch.setattr(SouthAmericaSoybean, 'url', url) - root = str(tmp_path) + root = tmp_path return SouthAmericaSoybean( paths=root, years=[2002, 2021], @@ -70,7 +62,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( 'tests', 'data', 'south_america_soybean', 'SouthAmerica_Soybean_2002.tif' ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) SouthAmericaSoybean(root) @@ -89,7 +81,7 @@ def test_plot_prediction(self, dataset: SouthAmericaSoybean) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SouthAmericaSoybean(str(tmp_path)) + SouthAmericaSoybean(tmp_path) def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_spacenet.py b/tests/datasets/test_spacenet.py index 2676af497fd..36d0d57c24b 100644 --- a/tests/datasets/test_spacenet.py +++ b/tests/datasets/test_spacenet.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os import shutil from pathlib import Path @@ -13,434 +12,80 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets import ( - DatasetNotFoundError, - SpaceNet1, - SpaceNet2, - SpaceNet3, - SpaceNet4, - SpaceNet5, - SpaceNet6, - SpaceNet7, -) +from torchgeo.datasets import DatasetNotFoundError, SpaceNet, SpaceNet1, SpaceNet6 +from torchgeo.datasets.utils import Executable -TEST_DATA_DIR = 'tests/data/spacenet' -radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - -class Collection: - def __init__(self, collection_id: str) -> None: - self.collection_id = collection_id - - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join(TEST_DATA_DIR, '*.tar.gz') - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -class Dataset: - def __init__(self, dataset_id: str) -> None: - self.dataset_id = dataset_id - - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join(TEST_DATA_DIR, 'spacenet*') - for directory in glob.iglob(glob_path): - dataset_name = os.path.basename(directory) - output_dir = os.path.join(output_dir, dataset_name) - shutil.copytree(directory, output_dir) - - -def fetch_collection(collection_id: str, **kwargs: str) -> Collection: - return Collection(collection_id) - - -def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset: - return Dataset(dataset_id) - - -class TestSpaceNet1: - @pytest.fixture(params=['rgb', '8band']) +class TestSpaceNet: + @pytest.fixture(params=[SpaceNet1, SpaceNet6]) def dataset( - self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path - ) -> SpaceNet1: - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) - test_md5 = {'sn1_AOI_1_RIO': '127a523561987110f008e8c9815ce807'} - - # Refer https://github.com/python/mypy/issues/1032 - monkeypatch.setattr(SpaceNet1, 'collection_md5_dict', test_md5) - root = str(tmp_path) - transforms = nn.Identity() - return SpaceNet1( - root, image=request.param, transforms=transforms, download=True, api_key='' + self, + request: SubRequest, + aws: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, + ) -> SpaceNet: + dataset_class: type[SpaceNet] = request.param + url = os.path.join( + 'tests', + 'data', + 'spacenet', + dataset_class.__name__.lower(), + '{dataset_id}', + 'train', + '{tarball}', ) - - def test_getitem(self, dataset: SpaceNet1) -> None: - x = dataset[0] - dataset[1] - assert isinstance(x, dict) - assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['mask'], torch.Tensor) - if dataset.image == 'rgb': - assert x['image'].shape[0] == 3 - else: - assert x['image'].shape[0] == 8 - - def test_len(self, dataset: SpaceNet1) -> None: - assert len(dataset) == 3 - - def test_already_downloaded(self, dataset: SpaceNet1) -> None: - SpaceNet1(root=dataset.root, download=True) - - def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet1(str(tmp_path)) - - def test_plot(self, dataset: SpaceNet1) -> None: - x = dataset[0].copy() - x['prediction'] = x['mask'] - dataset.plot(x, suptitle='Test') - plt.close() - dataset.plot(x, show_titles=False) - plt.close() - - -class TestSpaceNet2: - @pytest.fixture(params=['PAN', 'MS', 'PS-MS', 'PS-RGB']) - def dataset( - self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path - ) -> SpaceNet2: - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) - test_md5 = { - 'sn2_AOI_2_Vegas': '131048686ba21a45853c05f227f40b7f', - 'sn2_AOI_3_Paris': '62242fd198ee32b59f0178cf656e1513', - 'sn2_AOI_4_Shanghai': '563b0817ecedd8ff3b3e4cb2991bf3fb', - 'sn2_AOI_5_Khartoum': 'e4185a2e9a12cf7b3d0cd1db6b3e0f06', - } - - monkeypatch.setattr(SpaceNet2, 'collection_md5_dict', test_md5) - root = str(tmp_path) + monkeypatch.setattr(dataset_class, 'url', url) transforms = nn.Identity() - return SpaceNet2( - root, - image=request.param, - collections=['sn2_AOI_2_Vegas', 'sn2_AOI_5_Khartoum'], - transforms=transforms, - download=True, - api_key='', - ) + return dataset_class(tmp_path, transforms=transforms, download=True) - def test_getitem(self, dataset: SpaceNet2) -> None: - x = dataset[0] + @pytest.mark.parametrize('index', [0, 1]) + def test_getitem(self, dataset: SpaceNet, index: int) -> None: + x = dataset[index] assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) assert isinstance(x['mask'], torch.Tensor) - if dataset.image == 'PS-RGB': - assert x['image'].shape[0] == 3 - elif dataset.image in ['MS', 'PS-MS']: - assert x['image'].shape[0] == 8 - else: - assert x['image'].shape[0] == 1 - def test_len(self, dataset: SpaceNet2) -> None: + def test_len(self, dataset: SpaceNet) -> None: assert len(dataset) == 4 - def test_already_downloaded(self, dataset: SpaceNet2) -> None: - SpaceNet2(root=dataset.root, download=True) - - def test_not_downloaded(self, tmp_path: Path) -> None: + def test_already_extracted(self, dataset: SpaceNet) -> None: + dataset.__class__(root=dataset.root) + + def test_already_downloaded(self, dataset: SpaceNet) -> None: + if dataset.dataset_id == 'SN1_buildings': + base_dir = os.path.join(dataset.root, dataset.dataset_id, dataset.split) + elif dataset.dataset_id == 'SN6_buildings': + base_dir = os.path.join( + dataset.root, + dataset.dataset_id, + dataset.split, + dataset.split, + 'AOI_11_Rotterdam', + ) + for product in dataset.valid_images['train'] + list(dataset.valid_masks): + dir = os.path.join(base_dir, product) + shutil.rmtree(dir) + dataset.__class__(root=dataset.root) + + def test_not_downloaded(self, tmp_path: Path, dataset: SpaceNet) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet2(str(tmp_path)) - - def test_collection_checksum(self, dataset: SpaceNet2) -> None: - dataset.collection_md5_dict['sn2_AOI_2_Vegas'] = 'randommd5hash123' - with pytest.raises(RuntimeError, match='Collection sn2_AOI_2_Vegas corrupted'): - SpaceNet2(root=dataset.root, download=True, checksum=True) + dataset.__class__(root=os.path.join(tmp_path, 'dummy')) - def test_plot(self, dataset: SpaceNet2) -> None: - x = dataset[0].copy() - x['prediction'] = x['mask'] - dataset.plot(x, suptitle='Test') - plt.close() - dataset.plot(x, show_titles=False) - plt.close() - - -class TestSpaceNet3: - @pytest.fixture(params=zip(['PAN', 'MS'], [False, True])) - def dataset( - self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path - ) -> SpaceNet3: - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) - test_md5 = { - 'sn3_AOI_3_Paris': '93452c68da11dd6b57dc83dba43c2c9d', - 'sn3_AOI_5_Khartoum': '7c9d96810198bf101cbaf54f7a5e8b3b', - } - - monkeypatch.setattr(SpaceNet3, 'collection_md5_dict', test_md5) - root = str(tmp_path) - transforms = nn.Identity() - return SpaceNet3( - root, - image=request.param[0], - speed_mask=request.param[1], - collections=['sn3_AOI_3_Paris', 'sn3_AOI_5_Khartoum'], - transforms=transforms, - download=True, - api_key='', - ) - - def test_getitem(self, dataset: SpaceNet3) -> None: - # Iterate over all elements to maximize coverage - samples = [dataset[i] for i in range(len(dataset))] - x = samples[0] - assert isinstance(x, dict) - assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['mask'], torch.Tensor) - if dataset.image == 'MS': - assert x['image'].shape[0] == 8 - else: - assert x['image'].shape[0] == 1 - - def test_len(self, dataset: SpaceNet3) -> None: - assert len(dataset) == 4 - - def test_already_downloaded(self, dataset: SpaceNet3) -> None: - SpaceNet3(root=dataset.root, download=True) - - def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet3(str(tmp_path)) - - def test_collection_checksum(self, dataset: SpaceNet3) -> None: - dataset.collection_md5_dict['sn3_AOI_5_Khartoum'] = 'randommd5hash123' - with pytest.raises( - RuntimeError, match='Collection sn3_AOI_5_Khartoum corrupted' - ): - SpaceNet3(root=dataset.root, download=True, checksum=True) - - def test_plot(self, dataset: SpaceNet3) -> None: - x = dataset[0].copy() - x['prediction'] = x['mask'] - dataset.plot(x, suptitle='Test') - plt.close() - dataset.plot(x, show_titles=False) - plt.close() - dataset.plot({'image': x['image']}) - plt.close() - - -class TestSpaceNet4: - @pytest.fixture(params=['PAN', 'MS', 'PS-RGBNIR']) - def dataset( - self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path - ) -> SpaceNet4: - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) - test_md5 = {'sn4_AOI_6_Atlanta': '097a76a2319b7ba34dac1722862fc93b'} - - test_angles = ['nadir', 'off-nadir', 'very-off-nadir'] - - monkeypatch.setattr(SpaceNet4, 'collection_md5_dict', test_md5) - root = str(tmp_path) - transforms = nn.Identity() - return SpaceNet4( - root, - image=request.param, - angles=test_angles, - transforms=transforms, - download=True, - api_key='', - ) - - def test_getitem(self, dataset: SpaceNet4) -> None: - # Get image-label pair with empty label to - # ensure coverage - x = dataset[2] - assert isinstance(x, dict) - assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['mask'], torch.Tensor) - if dataset.image == 'PS-RGBNIR': - assert x['image'].shape[0] == 4 - elif dataset.image == 'MS': - assert x['image'].shape[0] == 8 - else: - assert x['image'].shape[0] == 1 - - def test_len(self, dataset: SpaceNet4) -> None: - assert len(dataset) == 4 - - def test_already_downloaded(self, dataset: SpaceNet4) -> None: - SpaceNet4(root=dataset.root, download=True) - - def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet4(str(tmp_path)) - - def test_collection_checksum(self, dataset: SpaceNet4) -> None: - dataset.collection_md5_dict['sn4_AOI_6_Atlanta'] = 'randommd5hash123' - with pytest.raises( - RuntimeError, match='Collection sn4_AOI_6_Atlanta corrupted' - ): - SpaceNet4(root=dataset.root, download=True, checksum=True) - - def test_plot(self, dataset: SpaceNet4) -> None: - x = dataset[0].copy() - x['prediction'] = x['mask'] - dataset.plot(x, suptitle='Test') - plt.close() - dataset.plot(x, show_titles=False) - plt.close() - - -class TestSpaceNet5: - @pytest.fixture(params=zip(['PAN', 'MS'], [False, True])) - def dataset( - self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path - ) -> SpaceNet5: - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) - test_md5 = { - 'sn5_AOI_7_Moscow': '5c511dd31eea739cc1f81ef5962f3d56', - 'sn5_AOI_8_Mumbai': 'e00452b87bbe87feaef65f373be3978e', - } - - monkeypatch.setattr(SpaceNet5, 'collection_md5_dict', test_md5) - root = str(tmp_path) - transforms = nn.Identity() - return SpaceNet5( - root, - image=request.param[0], - speed_mask=request.param[1], - collections=['sn5_AOI_7_Moscow', 'sn5_AOI_8_Mumbai'], - transforms=transforms, - download=True, - api_key='', - ) - - def test_getitem(self, dataset: SpaceNet5) -> None: - # Iterate over all elements to maximize coverage - samples = [dataset[i] for i in range(len(dataset))] - x = samples[0] - assert isinstance(x, dict) - assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['mask'], torch.Tensor) - if dataset.image == 'MS': - assert x['image'].shape[0] == 8 - else: - assert x['image'].shape[0] == 1 - - def test_len(self, dataset: SpaceNet5) -> None: - assert len(dataset) == 5 - - def test_already_downloaded(self, dataset: SpaceNet5) -> None: - SpaceNet5(root=dataset.root, download=True) - - def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet5(str(tmp_path)) - - def test_collection_checksum(self, dataset: SpaceNet5) -> None: - dataset.collection_md5_dict['sn5_AOI_8_Mumbai'] = 'randommd5hash123' - with pytest.raises(RuntimeError, match='Collection sn5_AOI_8_Mumbai corrupted'): - SpaceNet5(root=dataset.root, download=True, checksum=True) - - def test_plot(self, dataset: SpaceNet5) -> None: - x = dataset[0].copy() - x['prediction'] = x['mask'] - dataset.plot(x, suptitle='Test') - plt.close() + def test_plot(self, dataset: SpaceNet) -> None: + x = dataset[0] dataset.plot(x, show_titles=False) plt.close() - dataset.plot({'image': x['image']}) - plt.close() - - -class TestSpaceNet6: - @pytest.fixture(params=['PAN', 'RGBNIR', 'PS-RGB', 'PS-RGBNIR', 'SAR-Intensity']) - def dataset( - self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path - ) -> SpaceNet6: - monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset) - root = str(tmp_path) - transforms = nn.Identity() - return SpaceNet6( - root, image=request.param, transforms=transforms, download=True, api_key='' - ) - - def test_getitem(self, dataset: SpaceNet6) -> None: - x = dataset[0] - assert isinstance(x, dict) - assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['mask'], torch.Tensor) - if dataset.image == 'PS-RGB': - assert x['image'].shape[0] == 3 - elif dataset.image in ['RGBNIR', 'PS-RGBNIR']: - assert x['image'].shape[0] == 4 - else: - assert x['image'].shape[0] == 1 - - def test_len(self, dataset: SpaceNet6) -> None: - assert len(dataset) == 2 - - def test_already_downloaded(self, dataset: SpaceNet6) -> None: - SpaceNet6(root=dataset.root, download=True) - - def test_plot(self, dataset: SpaceNet6) -> None: - x = dataset[0].copy() x['prediction'] = x['mask'] dataset.plot(x, suptitle='Test') plt.close() - dataset.plot(x, show_titles=False) - plt.close() + def test_image_id(self, monkeypatch: MonkeyPatch, dataset: SpaceNet) -> None: + file_regex = r'global_monthly_(\d+.*\d+)' + monkeypatch.setattr(dataset, 'file_regex', file_regex) + dataset._image_id('global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160.tif') -class TestSpaceNet7: - @pytest.fixture(params=['train', 'test']) - def dataset( - self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path - ) -> SpaceNet7: - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) - test_md5 = { - 'sn7_train_source': '197bfa8842a40b09b6837b824a6370e0', - 'sn7_train_labels': '625ad8a989a5105bc766a53e53df4d0e', - 'sn7_test_source': '461f59eb21bb4f416c867f5037dfceeb', - } - - monkeypatch.setattr(SpaceNet7, 'collection_md5_dict', test_md5) - root = str(tmp_path) - transforms = nn.Identity() - return SpaceNet7( - root, split=request.param, transforms=transforms, download=True, api_key='' - ) - - def test_getitem(self, dataset: SpaceNet7) -> None: - x = dataset[0] - assert isinstance(x, dict) - assert isinstance(x['image'], torch.Tensor) - if dataset.split == 'train': - assert isinstance(x['mask'], torch.Tensor) - - def test_len(self, dataset: SpaceNet7) -> None: - if dataset.split == 'train': - assert len(dataset) == 2 - else: - assert len(dataset) == 1 - - def test_already_downloaded(self, dataset: SpaceNet4) -> None: - SpaceNet7(root=dataset.root, download=True) - - def test_not_downloaded(self, tmp_path: Path) -> None: - with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SpaceNet7(str(tmp_path)) - - def test_collection_checksum(self, dataset: SpaceNet4) -> None: - dataset.collection_md5_dict['sn7_train_source'] = 'randommd5hash123' - with pytest.raises(RuntimeError, match='Collection sn7_train_source corrupted'): - SpaceNet7(root=dataset.root, download=True, checksum=True) - - def test_plot(self, dataset: SpaceNet7) -> None: - x = dataset[0].copy() - if dataset.split == 'train': - x['prediction'] = x['mask'] - dataset.plot(x, suptitle='Test') - plt.close() - dataset.plot(x, show_titles=False) - plt.close() + def test_list_files(self, monkeypatch: MonkeyPatch, dataset: SpaceNet) -> None: + directory_glob = os.path.join('**', 'AOI_{aoi}_*', '{product}') + monkeypatch.setattr(dataset, 'directory_glob', directory_glob) + dataset._list_files(aoi=1) diff --git a/tests/datasets/test_ssl4eo.py b/tests/datasets/test_ssl4eo.py index ad45798946e..32597f02b20 100644 --- a/tests/datasets/test_ssl4eo.py +++ b/tests/datasets/test_ssl4eo.py @@ -14,21 +14,14 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo from torchgeo.datasets import SSL4EOL, SSL4EOS12, DatasetNotFoundError -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestSSL4EOL: @pytest.fixture(params=zip(SSL4EOL.metadata.keys(), [1, 1, 2, 2, 4])) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SSL4EOL: - monkeypatch.setattr(torchgeo.datasets.ssl4eo, 'download_url', download_url) - url = os.path.join('tests', 'data', 'ssl4eo', 'l', 'ssl4eo_l_{0}.tar.gz{1}') monkeypatch.setattr(SSL4EOL, 'url', url) @@ -61,7 +54,7 @@ def dataset( } monkeypatch.setattr(SSL4EOL, 'checksums', checksums) - root = str(tmp_path) + root = tmp_path split, seasons = request.param transforms = nn.Identity() return SSL4EOL(root, split, seasons, transforms, download=True, checksum=True) @@ -88,14 +81,14 @@ def test_already_extracted(self, dataset: SSL4EOL) -> None: def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'ssl4eo', 'l', '*.tar.gz*') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) SSL4EOL(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SSL4EOL(str(tmp_path)) + SSL4EOL(tmp_path) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -148,7 +141,7 @@ def test_extract(self, tmp_path: Path) -> None: os.path.join('tests', 'data', 'ssl4eo', 's12', filename), tmp_path / filename, ) - SSL4EOS12(str(tmp_path)) + SSL4EOS12(tmp_path) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -156,7 +149,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SSL4EOS12(str(tmp_path)) + SSL4EOS12(tmp_path) def test_plot(self, dataset: SSL4EOS12) -> None: sample = dataset[0] diff --git a/tests/datasets/test_ssl4eo_benchmark.py b/tests/datasets/test_ssl4eo_benchmark.py index db1d36f73b0..fbf818294ef 100644 --- a/tests/datasets/test_ssl4eo_benchmark.py +++ b/tests/datasets/test_ssl4eo_benchmark.py @@ -15,7 +15,6 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import ( CDL, NLCD, @@ -25,10 +24,6 @@ ) -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestSSL4EOLBenchmark: @pytest.fixture( params=product( @@ -40,11 +35,7 @@ class TestSSL4EOLBenchmark: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SSL4EOLBenchmark: - monkeypatch.setattr( - torchgeo.datasets.ssl4eo_benchmark, 'download_url', download_url - ) - root = str(tmp_path) - + root = tmp_path url = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '{}.tar.gz') monkeypatch.setattr(SSL4EOLBenchmark, 'url', url) @@ -140,14 +131,14 @@ def test_already_extracted(self, dataset: SSL4EOLBenchmark) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '*.tar.gz') - root = str(tmp_path) + root = tmp_path for tarfile in glob.iglob(pathname): shutil.copy(tarfile, root) SSL4EOLBenchmark(root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SSL4EOLBenchmark(str(tmp_path)) + SSL4EOLBenchmark(tmp_path) def test_plot(self, dataset: SSL4EOLBenchmark) -> None: sample = dataset[0] diff --git a/tests/datasets/test_sustainbench_crop_yield.py b/tests/datasets/test_sustainbench_crop_yield.py index 36e746aaf92..550c8b8adc1 100644 --- a/tests/datasets/test_sustainbench_crop_yield.py +++ b/tests/datasets/test_sustainbench_crop_yield.py @@ -12,29 +12,20 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, SustainBenchCropYield -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestSustainBenchCropYield: @pytest.fixture(params=['train', 'dev', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> SustainBenchCropYield: - monkeypatch.setattr( - torchgeo.datasets.sustainbench_crop_yield, 'download_url', download_url - ) - md5 = '7a5591794e14dd73d2b747cd2244acbc' monkeypatch.setattr(SustainBenchCropYield, 'md5', md5) url = os.path.join('tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip') monkeypatch.setattr(SustainBenchCropYield, 'url', url) monkeypatch.setattr(plt, 'show', lambda *args: None) - root = str(tmp_path) + root = tmp_path split = request.param countries = ['argentina', 'brazil', 'usa'] transforms = nn.Identity() @@ -49,7 +40,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join( 'tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip' ) - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) SustainBenchCropYield(root) @@ -72,7 +63,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - SustainBenchCropYield(str(tmp_path)) + SustainBenchCropYield(tmp_path) def test_plot(self, dataset: SustainBenchCropYield) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_treesatai.py b/tests/datasets/test_treesatai.py new file mode 100644 index 00000000000..7788adb9791 --- /dev/null +++ b/tests/datasets/test_treesatai.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import glob +import os +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch.nn as nn +from pytest import MonkeyPatch +from torch import Tensor + +from torchgeo.datasets import DatasetNotFoundError, TreeSatAI + +root = os.path.join('tests', 'data', 'treesatai') +md5s = { + 'aerial_60m_acer_pseudoplatanus.zip': '', + 'labels.zip': '', + 's1.zip': '', + 's2.zip': '', + 'test_filenames.lst': '', + 'train_filenames.lst': '', +} + + +class TestTreeSatAI: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch) -> TreeSatAI: + monkeypatch.setattr(TreeSatAI, 'url', root + os.sep) + monkeypatch.setattr(TreeSatAI, 'md5s', md5s) + transforms = nn.Identity() + return TreeSatAI(root, transforms=transforms) + + def test_getitem(self, dataset: TreeSatAI) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['label'], Tensor) + for sensor in dataset.sensors: + assert isinstance(x[f'image_{sensor}'], Tensor) + + def test_len(self, dataset: TreeSatAI) -> None: + assert len(dataset) == 9 + + def test_download(self, dataset: TreeSatAI, tmp_path: Path) -> None: + TreeSatAI(tmp_path, download=True) + + def test_extract(self, dataset: TreeSatAI, tmp_path: Path) -> None: + for file in glob.iglob(os.path.join(root, '*.*')): + shutil.copy(file, tmp_path) + TreeSatAI(tmp_path) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + TreeSatAI(tmp_path) + + def test_plot(self, dataset: TreeSatAI) -> None: + x = dataset[0] + x['prediction'] = x['label'] + dataset.plot(x) + plt.close() diff --git a/tests/datasets/test_ucmerced.py b/tests/datasets/test_ucmerced.py index bedeb588c66..10fa46a3c01 100644 --- a/tests/datasets/test_ucmerced.py +++ b/tests/datasets/test_ucmerced.py @@ -13,48 +13,19 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, UCMerced -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestUCMerced: @pytest.fixture(params=['train', 'val', 'test']) def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> UCMerced: - monkeypatch.setattr(torchgeo.datasets.ucmerced, 'download_url', download_url) - md5 = 'a42ef8779469d196d8f2971ee135f030' - monkeypatch.setattr(UCMerced, 'md5', md5) - url = os.path.join('tests', 'data', 'ucmerced', 'UCMerced_LandUse.zip') + url = os.path.join('tests', 'data', 'ucmerced') + os.sep monkeypatch.setattr(UCMerced, 'url', url) - monkeypatch.setattr( - UCMerced, - 'split_urls', - { - 'train': os.path.join( - 'tests', 'data', 'ucmerced', 'uc_merced-train.txt' - ), - 'val': os.path.join('tests', 'data', 'ucmerced', 'uc_merced-val.txt'), - 'test': os.path.join('tests', 'data', 'ucmerced', 'uc_merced-test.txt'), - }, - ) - monkeypatch.setattr( - UCMerced, - 'split_md5s', - { - 'train': 'a01fa9f13333bb176fc1bfe26ff4c711', - 'val': 'a01fa9f13333bb176fc1bfe26ff4c711', - 'test': 'a01fa9f13333bb176fc1bfe26ff4c711', - }, - ) - root = str(tmp_path) split = request.param transforms = nn.Identity() - return UCMerced(root, split, transforms, download=True, checksum=True) + return UCMerced(tmp_path, split, transforms, download=True) def test_getitem(self, dataset: UCMerced) -> None: x = dataset[0] @@ -71,18 +42,18 @@ def test_add(self, dataset: UCMerced) -> None: assert len(ds) == 8 def test_already_downloaded(self, dataset: UCMerced, tmp_path: Path) -> None: - UCMerced(root=str(tmp_path), download=True) + UCMerced(tmp_path) def test_already_downloaded_not_extracted( self, dataset: UCMerced, tmp_path: Path ) -> None: shutil.rmtree(dataset.root) - download_url(dataset.url, root=str(tmp_path)) - UCMerced(root=str(tmp_path), download=False) + shutil.copy(dataset.url + dataset.filename, tmp_path) + UCMerced(tmp_path) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - UCMerced(str(tmp_path)) + UCMerced(tmp_path) def test_plot(self, dataset: UCMerced) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_usavars.py b/tests/datasets/test_usavars.py index 0566a1f3153..9e324cb1ee3 100644 --- a/tests/datasets/test_usavars.py +++ b/tests/datasets/test_usavars.py @@ -13,14 +13,9 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, USAVars -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestUSAVars: @pytest.fixture( params=zip( @@ -35,8 +30,6 @@ class TestUSAVars: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> USAVars: - monkeypatch.setattr(torchgeo.datasets.usavars, 'download_url', download_url) - md5 = 'b504580a00bdc27097d5421dec50481b' monkeypatch.setattr(USAVars, 'md5', md5) @@ -73,7 +66,7 @@ def dataset( } monkeypatch.setattr(USAVars, 'split_metadata', split_metadata) - root = str(tmp_path) + root = tmp_path split, labels = request.param transforms = nn.Identity() @@ -109,7 +102,7 @@ def test_already_extracted(self, dataset: USAVars) -> None: def test_already_downloaded(self, tmp_path: Path) -> None: pathname = os.path.join('tests', 'data', 'usavars', 'uar.zip') - root = str(tmp_path) + root = tmp_path shutil.copy(pathname, root) csvs = [ 'elevation.csv', @@ -130,7 +123,7 @@ def test_already_downloaded(self, tmp_path: Path) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - USAVars(str(tmp_path)) + USAVars(tmp_path) def test_plot(self, dataset: USAVars) -> None: dataset.plot(dataset[0], suptitle='Test') diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index c53bfbed0fc..141709ceba8 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -1,12 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import math import os import pickle import re -import shutil import sys from datetime import datetime from pathlib import Path @@ -15,20 +13,14 @@ import numpy as np import pytest import torch -from pytest import MonkeyPatch from rasterio.crs import CRS -import torchgeo.datasets.utils from torchgeo.datasets import BoundingBox, DependencyNotFoundError from torchgeo.datasets.utils import ( Executable, array_to_tensor, concat_samples, disambiguate_timestamp, - download_and_extract_archive, - download_radiant_mlhub_collection, - download_radiant_mlhub_dataset, - extract_archive, lazy_import, merge_samples, percentile_normalization, @@ -39,86 +31,6 @@ ) -class MLHubDataset: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join( - 'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz' - ) - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - glob_path = os.path.join( - 'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz' - ) - for tarball in glob.iglob(glob_path): - shutil.copy(tarball, output_dir) - - -def fetch_dataset(dataset_id: str, **kwargs: str) -> MLHubDataset: - return MLHubDataset() - - -def fetch_collection(collection_id: str, **kwargs: str) -> Collection: - return Collection() - - -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - -@pytest.mark.parametrize( - 'src', - [ - os.path.join('cowc_detection', 'COWC_Detection_Columbus_CSUAV_AFRL.tbz'), - os.path.join('cowc_detection', 'COWC_test_list_detection.txt.bz2'), - os.path.join('vhr10', 'NWPU VHR-10 dataset.rar'), - os.path.join('landcoverai', 'landcover.ai.v1.zip'), - os.path.join('chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip'), - os.path.join('sen12ms', 'ROIs1158_spring_lc.tar.gz'), - ], -) -def test_extract_archive(src: str, tmp_path: Path) -> None: - if src.endswith('.rar'): - pytest.importorskip('rarfile', minversion='4') - if src.startswith('chesapeake'): - pytest.importorskip('zipfile_deflate64') - extract_archive(os.path.join('tests', 'data', src), str(tmp_path)) - - -def test_unsupported_scheme() -> None: - with pytest.raises( - RuntimeError, match='src file has unknown archival/compression scheme' - ): - extract_archive('foo.bar') - - -def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) - download_and_extract_archive( - os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'), - str(tmp_path), - ) - - -def test_download_radiant_mlhub_dataset( - tmp_path: Path, monkeypatch: MonkeyPatch -) -> None: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset) - download_radiant_mlhub_dataset('', str(tmp_path)) - - -def test_download_radiant_mlhub_collection( - tmp_path: Path, monkeypatch: MonkeyPatch -) -> None: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection) - download_radiant_mlhub_collection('', str(tmp_path)) - - class TestBoundingBox: def test_repr_str(self) -> None: bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4) @@ -457,6 +369,18 @@ def test_invalid_t(self) -> None: datetime(2021, 9, 13, 17, 21, 53, 123).timestamp(), datetime(2021, 9, 13, 17, 21, 53, 123).timestamp(), ), + ( + '2021-09-13%2017:21:53', + '%Y-%m-%d%%20%H:%M:%S', + datetime(2021, 9, 13, 17, 21, 53, 0).timestamp(), + datetime(2021, 9, 13, 17, 21, 53, 999999).timestamp(), + ), + ( + '2021%m', + '%Y%%m', + datetime(2021, 1, 1, 0, 0, 0, 0).timestamp(), + datetime(2021, 12, 31, 23, 59, 59, 999999).timestamp(), + ), ], ) def test_disambiguate_timestamp( @@ -597,7 +521,7 @@ def test_lazy_import_missing(name: str) -> None: def test_azcopy(tmp_path: Path, azcopy: Executable) -> None: source = os.path.join('tests', 'data', 'cyclone') azcopy('sync', source, tmp_path, '--recursive=true') - assert os.path.exists(tmp_path / 'nasa_tropical_storm_competition_test_labels') + assert os.path.exists(tmp_path / 'test') def test_which() -> None: diff --git a/tests/datasets/test_vaihingen.py b/tests/datasets/test_vaihingen.py index e4b36b99edd..7b6cbe878f9 100644 --- a/tests/datasets/test_vaihingen.py +++ b/tests/datasets/test_vaihingen.py @@ -49,9 +49,9 @@ def test_extract(self, tmp_path: Path) -> None: ] for filename in filenames: shutil.copyfile( - os.path.join(root, filename), os.path.join(str(tmp_path), filename) + os.path.join(root, filename), os.path.join(tmp_path, filename) ) - Vaihingen2D(root=str(tmp_path)) + Vaihingen2D(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: filenames = [ @@ -62,7 +62,7 @@ def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, filename), 'w') as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - Vaihingen2D(root=str(tmp_path), checksum=True) + Vaihingen2D(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -70,7 +70,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - Vaihingen2D(str(tmp_path)) + Vaihingen2D(tmp_path) def test_plot(self, dataset: Vaihingen2D) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_vhr10.py b/tests/datasets/test_vhr10.py index dee46c1db88..aa0920d69e7 100644 --- a/tests/datasets/test_vhr10.py +++ b/tests/datasets/test_vhr10.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -13,15 +12,9 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset -import torchgeo.datasets.utils from torchgeo.datasets import VHR10, DatasetNotFoundError pytest.importorskip('pycocotools') -pytest.importorskip('rarfile', minversion='4') - - -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) class TestVHR10: @@ -29,17 +22,15 @@ class TestVHR10: def dataset( self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest ) -> VHR10: - monkeypatch.setattr(torchgeo.datasets.vhr10, 'download_url', download_url) - monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url) - url = os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.rar') + url = os.path.join('tests', 'data', 'vhr10', 'NWPU VHR-10 dataset.zip') monkeypatch.setitem(VHR10.image_meta, 'url', url) - md5 = '92769845cae6a4e8c74bfa1a0d1d4a80' + md5 = '497cb7e19a12c7d5abbefe8eac71d22d' monkeypatch.setitem(VHR10.image_meta, 'md5', md5) url = os.path.join('tests', 'data', 'vhr10', 'annotations.json') monkeypatch.setitem(VHR10.target_meta, 'url', url) md5 = '567c4cd8c12624864ff04865de504c58' monkeypatch.setitem(VHR10.target_meta, 'md5', md5) - root = str(tmp_path) + root = tmp_path split = request.param transforms = nn.Identity() return VHR10(root, split, transforms, download=True, checksum=True) @@ -78,7 +69,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - VHR10(str(tmp_path)) + VHR10(tmp_path) def test_plot(self, dataset: VHR10) -> None: pytest.importorskip('skimage', minversion='0.19') diff --git a/tests/datasets/test_western_usa_live_fuel_moisture.py b/tests/datasets/test_western_usa_live_fuel_moisture.py index e2c9120ae02..d71f18263fa 100644 --- a/tests/datasets/test_western_usa_live_fuel_moisture.py +++ b/tests/datasets/test_western_usa_live_fuel_moisture.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import pytest @@ -11,41 +10,23 @@ from pytest import MonkeyPatch from torchgeo.datasets import DatasetNotFoundError, WesternUSALiveFuelMoisture - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - tarball_path = os.path.join( - 'tests', - 'data', - 'western_usa_live_fuel_moisture', - 'su_sar_moisture_content.tar.gz', - ) - shutil.copy(tarball_path, output_dir) - - -def fetch(collection_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestWesternUSALiveFuelMoisture: @pytest.fixture def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path + self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path ) -> WesternUSALiveFuelMoisture: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - md5 = 'ecbc9269dd27c4efe7aa887960054351' - monkeypatch.setattr(WesternUSALiveFuelMoisture, 'md5', md5) - root = str(tmp_path) + url = os.path.join('tests', 'data', 'western_usa_live_fuel_moisture') + monkeypatch.setattr(WesternUSALiveFuelMoisture, 'url', url) transforms = nn.Identity() return WesternUSALiveFuelMoisture( - root, transforms=transforms, download=True, api_key='', checksum=True + tmp_path, transforms=transforms, download=True ) - @pytest.mark.parametrize('index', [0, 1, 2]) - def test_getitem(self, dataset: WesternUSALiveFuelMoisture, index: int) -> None: - x = dataset[index] + def test_getitem(self, dataset: WesternUSALiveFuelMoisture) -> None: + x = dataset[0] assert isinstance(x, dict) assert isinstance(x['input'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) @@ -53,21 +34,9 @@ def test_getitem(self, dataset: WesternUSALiveFuelMoisture, index: int) -> None: def test_len(self, dataset: WesternUSALiveFuelMoisture) -> None: assert len(dataset) == 3 - def test_already_downloaded(self, tmp_path: Path) -> None: - pathname = os.path.join( - 'tests', - 'data', - 'western_usa_live_fuel_moisture', - 'su_sar_moisture_content.tar.gz', - ) - root = str(tmp_path) - shutil.copy(pathname, root) - WesternUSALiveFuelMoisture(root) + def test_already_downloaded(self, dataset: WesternUSALiveFuelMoisture) -> None: + WesternUSALiveFuelMoisture(dataset.root) def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - WesternUSALiveFuelMoisture(str(tmp_path)) - - def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None: - with pytest.raises(AssertionError, match='Invalid input variable name.'): - WesternUSALiveFuelMoisture(dataset.root, input_features=['foo']) + WesternUSALiveFuelMoisture(tmp_path) diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview.py similarity index 96% rename from tests/datasets/test_xview2.py rename to tests/datasets/test_xview.py index 7689acf5f78..c54b597fadf 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview.py @@ -61,7 +61,7 @@ def test_extract(self, tmp_path: Path) -> None: ), os.path.join(tmp_path, 'test_images_labels_targets.tar.gz'), ) - XView2(root=str(tmp_path)) + XView2(root=tmp_path) def test_corrupted(self, tmp_path: Path) -> None: with open( @@ -73,7 +73,7 @@ def test_corrupted(self, tmp_path: Path) -> None: ) as f: f.write('bad') with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'): - XView2(root=str(tmp_path), checksum=True) + XView2(root=tmp_path, checksum=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -81,7 +81,7 @@ def test_invalid_split(self) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - XView2(str(tmp_path)) + XView2(tmp_path) def test_plot(self, dataset: XView2) -> None: x = dataset[0].copy() diff --git a/tests/datasets/test_zuericrop.py b/tests/datasets/test_zuericrop.py index bea0d9e8519..e985b10ada6 100644 --- a/tests/datasets/test_zuericrop.py +++ b/tests/datasets/test_zuericrop.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -11,20 +10,14 @@ import torch.nn as nn from pytest import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriCrop pytest.importorskip('h5py', minversion='3.6') -def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: - shutil.copy(url, root) - - class TestZueriCrop: @pytest.fixture def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ZueriCrop: - monkeypatch.setattr(torchgeo.datasets.zuericrop, 'download_url', download_url) data_dir = os.path.join('tests', 'data', 'zuericrop') urls = [ os.path.join(data_dir, 'ZueriCrop.hdf5'), @@ -33,7 +26,7 @@ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ZueriCrop: md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b'] monkeypatch.setattr(ZueriCrop, 'urls', urls) monkeypatch.setattr(ZueriCrop, 'md5s', md5s) - root = str(tmp_path) + root = tmp_path transforms = nn.Identity() return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True) @@ -67,7 +60,7 @@ def test_already_downloaded(self, dataset: ZueriCrop) -> None: def test_not_downloaded(self, tmp_path: Path) -> None: with pytest.raises(DatasetNotFoundError, match='Dataset not found'): - ZueriCrop(str(tmp_path)) + ZueriCrop(tmp_path) def test_invalid_bands(self) -> None: with pytest.raises(ValueError): diff --git a/tests/models/test_api.py b/tests/models/test_api.py index 1aecf8341fa..c5a56d1808a 100644 --- a/tests/models/test_api.py +++ b/tests/models/test_api.py @@ -13,7 +13,10 @@ DOFALarge16_Weights, ResNet18_Weights, ResNet50_Weights, + ResNet152_Weights, + ScaleMAELarge16_Weights, Swin_V2_B_Weights, + Swin_V2_T_Weights, ViTSmall16_Weights, dofa_base_patch16_224, dofa_large_patch16_224, @@ -23,7 +26,10 @@ list_models, resnet18, resnet50, + resnet152, + scalemae_large_patch16, swin_v2_b, + swin_v2_t, vit_small_patch16_224, ) @@ -32,6 +38,9 @@ dofa_large_patch16_224, resnet18, resnet50, + resnet152, + scalemae_large_patch16, + swin_v2_t, swin_v2_b, vit_small_patch16_224, ] @@ -40,6 +49,9 @@ DOFALarge16_Weights, ResNet18_Weights, ResNet50_Weights, + ResNet152_Weights, + ScaleMAELarge16_Weights, + Swin_V2_T_Weights, Swin_V2_B_Weights, ViTSmall16_Weights, ] @@ -68,3 +80,8 @@ def test_get_weight(enum: WeightsEnum) -> None: def test_list_models() -> None: models = [builder.__name__ for builder in builders] assert set(models) == set(list_models()) + + +def test_invalid_model() -> None: + with pytest.raises(ValueError, match='bad_model is not a valid WeightsEnum'): + get_weight('bad_model') diff --git a/tests/models/test_croma.py b/tests/models/test_croma.py new file mode 100644 index 00000000000..83f63770339 --- /dev/null +++ b/tests/models/test_croma.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from pathlib import Path + +import pytest +import torch +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from torchvision.models._api import WeightsEnum + +from torchgeo.models import ( + CROMA, + CROMABase_Weights, + CROMALarge_Weights, + croma_base, + croma_large, +) + + +def save_model(model: torch.nn.Module, path: Path) -> None: + state_dict = { + 's1_encoder': model.s1_encoder.state_dict(), + 's1_GAP_FFN': model.s1_GAP_FFN.state_dict(), + 's2_encoder': model.s2_encoder.state_dict(), + 's2_GAP_FFN': model.s2_GAP_FFN.state_dict(), + 'joint_encoder': model.joint_encoder.state_dict(), + } + torch.save(state_dict, path) + + +class TestCROMA: + @pytest.mark.parametrize('modalities', [['sar'], ['optical'], ['sar', 'optical']]) + def test_croma(self, modalities: list[str]) -> None: + batch_size = 2 + model = CROMA(modalities=modalities) + if 'sar' in modalities: + sar_images = torch.randn( + [batch_size, 2, model.image_size, model.image_size] + ) + else: + sar_images = None + if 'optical' in modalities: + optical_images = torch.randn( + [batch_size, 12, model.image_size, model.image_size] + ) + else: + optical_images = None + out = model(sar_images, optical_images) + for modality in modalities: + assert f'{modality}_encodings' in out + if set(modalities) == {'sar', 'optical'}: + assert 'joint_encodings' in out + + +class TestCROMABase: + @pytest.fixture(params=[*CROMABase_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = croma_base() + save_model(model, path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_croma(self) -> None: + croma_base() + + def test_croma_weights(self, mocked_weights: WeightsEnum) -> None: + croma_base(weights=mocked_weights) + + @pytest.mark.slow + def test_croma_download(self, weights: WeightsEnum) -> None: + croma_base(weights=weights) + + +class TestCROMALarge: + @pytest.fixture(params=[*CROMALarge_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = croma_large() + save_model(model, path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_croma(self) -> None: + croma_large() + + def test_croma_weights(self, mocked_weights: WeightsEnum) -> None: + croma_large(weights=mocked_weights) + + @pytest.mark.slow + def test_croma_download(self, weights: WeightsEnum) -> None: + croma_large(weights=weights) diff --git a/tests/models/test_dofa.py b/tests/models/test_dofa.py index 34f1701d7b1..d6be97c2100 100644 --- a/tests/models/test_dofa.py +++ b/tests/models/test_dofa.py @@ -2,11 +2,9 @@ # Licensed under the MIT License. from pathlib import Path -from typing import Any import pytest import torch -import torchvision from _pytest.fixtures import SubRequest from pytest import MonkeyPatch from torchvision.models._api import WeightsEnum @@ -22,11 +20,6 @@ ) -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict - - class TestDOFA: @pytest.mark.parametrize( 'wavelengths', @@ -86,7 +79,11 @@ def weights(self, request: SubRequest) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = dofa_base_patch16_224() @@ -95,7 +92,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_dofa(self) -> None: @@ -123,7 +119,11 @@ def weights(self, request: SubRequest) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = dofa_large_patch16_224() @@ -132,7 +132,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_dofa(self) -> None: diff --git a/tests/models/test_resnet.py b/tests/models/test_resnet.py index ea5397e6099..17cfe520ad9 100644 --- a/tests/models/test_resnet.py +++ b/tests/models/test_resnet.py @@ -2,22 +2,22 @@ # Licensed under the MIT License. from pathlib import Path -from typing import Any import pytest import timm import torch -import torchvision from _pytest.fixtures import SubRequest from pytest import MonkeyPatch from torchvision.models._api import WeightsEnum -from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 - - -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict +from torchgeo.models import ( + ResNet18_Weights, + ResNet50_Weights, + ResNet152_Weights, + resnet18, + resnet50, + resnet152, +) class TestResNet18: @@ -27,7 +27,11 @@ def weights(self, request: SubRequest) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model('resnet18', in_chans=weights.meta['in_chans']) @@ -36,7 +40,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_resnet(self) -> None: @@ -45,10 +48,14 @@ def test_resnet(self) -> None: def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: resnet18(weights=mocked_weights) + def test_bands(self, mocked_weights: WeightsEnum) -> None: + if 'bands' in mocked_weights.meta: + assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans'] + def test_transforms(self, mocked_weights: WeightsEnum) -> None: c = mocked_weights.meta['in_chans'] sample = { - 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) } mocked_weights.transforms(sample) @@ -64,7 +71,11 @@ def weights(self, request: SubRequest) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model('resnet50', in_chans=weights.meta['in_chans']) @@ -73,7 +84,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_resnet(self) -> None: @@ -82,13 +92,61 @@ def test_resnet(self) -> None: def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: resnet50(weights=mocked_weights) + def test_bands(self, mocked_weights: WeightsEnum) -> None: + if 'bands' in mocked_weights.meta: + assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans'] + def test_transforms(self, mocked_weights: WeightsEnum) -> None: c = mocked_weights.meta['in_chans'] sample = { - 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) } mocked_weights.transforms(sample) @pytest.mark.slow def test_resnet_download(self, weights: WeightsEnum) -> None: resnet50(weights=weights) + + +class TestResNet152: + @pytest.fixture(params=[*ResNet152_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = timm.create_model('resnet152', in_chans=weights.meta['in_chans']) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_resnet(self) -> None: + resnet152() + + def test_resnet_weights(self, mocked_weights: WeightsEnum) -> None: + resnet152(weights=mocked_weights) + + def test_bands(self, mocked_weights: WeightsEnum) -> None: + if 'bands' in mocked_weights.meta: + assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans'] + + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = mocked_weights.meta['in_chans'] + sample = { + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) + } + mocked_weights.transforms(sample) + + @pytest.mark.slow + def test_resnet_download(self, weights: WeightsEnum) -> None: + resnet152(weights=weights) diff --git a/tests/models/test_scale_mae.py b/tests/models/test_scale_mae.py new file mode 100644 index 00000000000..f901534110f --- /dev/null +++ b/tests/models/test_scale_mae.py @@ -0,0 +1,63 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from pathlib import Path + +import pytest +import torch +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from torchvision.models._api import WeightsEnum + +from torchgeo.models import ScaleMAELarge16_Weights, scalemae_large_patch16 + + +class TestScaleMAE: + @pytest.fixture(params=[*ScaleMAELarge16_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = scalemae_large_patch16() + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_scalemae(self) -> None: + scalemae_large_patch16() + + def test_scalemae_forward_pass(self) -> None: + model = scalemae_large_patch16(img_size=64, num_classes=2) + x = torch.randn(1, 3, 64, 64) + y = model(x) + assert y.shape == (1, 2) + + def test_scalemae_weights(self, mocked_weights: WeightsEnum) -> None: + scalemae_large_patch16(weights=mocked_weights) + + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = mocked_weights.meta['in_chans'] + sample = { + 'image': torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + } + mocked_weights.transforms(sample) + + def test_scalemae_weights_diff_image_size( + self, mocked_weights: WeightsEnum + ) -> None: + scalemae_large_patch16(weights=mocked_weights, img_size=256) + + @pytest.mark.slow + def test_scalemae_download(self, weights: WeightsEnum) -> None: + scalemae_large_patch16(weights=weights) diff --git a/tests/models/test_swin.py b/tests/models/test_swin.py index 489b3642ce7..4ae0f08002b 100644 --- a/tests/models/test_swin.py +++ b/tests/models/test_swin.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. from pathlib import Path -from typing import Any import pytest import torch @@ -11,12 +10,56 @@ from pytest import MonkeyPatch from torchvision.models._api import WeightsEnum -from torchgeo.models import Swin_V2_B_Weights, swin_v2_b +from torchgeo.models import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict +class TestSwin_V2_T: + @pytest.fixture(params=[*Swin_V2_T_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, + ) -> WeightsEnum: + path = tmp_path / f'{weights}.pth' + model = torchvision.models.swin_v2_t() + num_channels = weights.meta['in_chans'] + out_channels = model.features[0][0].out_channels + model.features[0][0] = torch.nn.Conv2d( + num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4) + ) + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, 'url', str(path)) + except AttributeError: + monkeypatch.setattr(weights, 'url', str(path)) + return weights + + def test_swin_v2_t(self) -> None: + swin_v2_t() + + def test_swin_v2_t_weights(self, mocked_weights: WeightsEnum) -> None: + swin_v2_t(weights=mocked_weights) + + def test_bands(self, mocked_weights: WeightsEnum) -> None: + if 'bands' in mocked_weights.meta: + assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans'] + + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = mocked_weights.meta['in_chans'] + sample = { + 'image': torch.arange(c * 256 * 256, dtype=torch.float).view(c, 256, 256) + } + mocked_weights.transforms(sample) + + @pytest.mark.slow + def test_swin_v2_t_download(self, weights: WeightsEnum) -> None: + swin_v2_t(weights=weights) class TestSwin_V2_B: @@ -26,16 +69,24 @@ def weights(self, request: SubRequest) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = torchvision.models.swin_v2_b() + num_channels = weights.meta['in_chans'] + out_channels = model.features[0][0].out_channels + model.features[0][0] = torch.nn.Conv2d( + num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4) + ) torch.save(model.state_dict(), path) try: monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_swin_v2_b(self) -> None: @@ -44,6 +95,10 @@ def test_swin_v2_b(self) -> None: def test_swin_v2_b_weights(self, mocked_weights: WeightsEnum) -> None: swin_v2_b(weights=mocked_weights) + def test_bands(self, mocked_weights: WeightsEnum) -> None: + if 'bands' in mocked_weights.meta: + assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans'] + def test_transforms(self, mocked_weights: WeightsEnum) -> None: c = mocked_weights.meta['in_chans'] sample = { diff --git a/tests/models/test_vit.py b/tests/models/test_vit.py index b69e2398996..4ae0e47bfbc 100644 --- a/tests/models/test_vit.py +++ b/tests/models/test_vit.py @@ -2,12 +2,10 @@ # Licensed under the MIT License. from pathlib import Path -from typing import Any import pytest import timm import torch -import torchvision from _pytest.fixtures import SubRequest from pytest import MonkeyPatch from torchvision.models._api import WeightsEnum @@ -15,11 +13,6 @@ from torchgeo.models import ViTSmall16_Weights, vit_small_patch16_224 -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict - - class TestViTSmall16: @pytest.fixture(params=[*ViTSmall16_Weights]) def weights(self, request: SubRequest) -> WeightsEnum: @@ -27,7 +20,11 @@ def weights(self, request: SubRequest) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model( @@ -38,7 +35,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_vit(self) -> None: @@ -47,6 +43,10 @@ def test_vit(self) -> None: def test_vit_weights(self, mocked_weights: WeightsEnum) -> None: vit_small_patch16_224(weights=mocked_weights) + def test_bands(self, mocked_weights: WeightsEnum) -> None: + if 'bands' in mocked_weights.meta: + assert len(mocked_weights.meta['bands']) == mocked_weights.meta['in_chans'] + def test_transforms(self, mocked_weights: WeightsEnum) -> None: c = mocked_weights.meta['in_chans'] sample = { diff --git a/tests/samplers/test_batch.py b/tests/samplers/test_batch.py index 59c8aaa00be..199239a0e79 100644 --- a/tests/samplers/test_batch.py +++ b/tests/samplers/test_batch.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -144,6 +145,17 @@ def test_weighted_sampling(self) -> None: for bbox in batch: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + generator1 = torch.Generator().manual_seed(0) + generator2 = torch.Generator().manual_seed(0) + sampler1 = RandomBatchGeoSampler(ds, 1, 1, generator=generator1) + sampler2 = RandomBatchGeoSampler(ds, 1, 1, generator=generator2) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 1416368098a..e2c829f1b9e 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -6,6 +6,7 @@ from itertools import product import pytest +import torch from _pytest.fixtures import SubRequest from rasterio.crs import CRS from torch.utils.data import DataLoader @@ -139,6 +140,17 @@ def test_weighted_sampling(self) -> None: for bbox in sampler: assert bbox == BoundingBox(0, 10, 0, 10, 0, 10) + def test_random_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + generator1 = torch.Generator().manual_seed(0) + generator2 = torch.Generator().manual_seed(0) + sampler1 = RandomGeoSampler(ds, 1, 1, generator=generator1) + sampler2 = RandomGeoSampler(ds, 1, 1, generator=generator2) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( @@ -288,6 +300,18 @@ def test_point_data(self) -> None: for _ in sampler: continue + def test_shuffle_seed(self) -> None: + ds = CustomGeoDataset() + ds.index.insert(0, (0, 10, 0, 10, 0, 10)) + ds.index.insert(1, (0, 11, 0, 11, 0, 11)) + generator1 = torch.Generator().manual_seed(0) + generator2 = torch.Generator().manual_seed(0) + sampler1 = PreChippedGeoSampler(ds, shuffle=True, generator=generator1) + sampler2 = PreChippedGeoSampler(ds, shuffle=True, generator=generator2) + sample1 = next(iter(sampler1)) + sample2 = next(iter(sampler2)) + assert sample1 == sample2 + @pytest.mark.slow @pytest.mark.parametrize('num_workers', [0, 1, 2]) def test_dataloader( diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 64143759a3c..808bf937220 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -3,13 +3,11 @@ import os from pathlib import Path -from typing import Any import pytest import timm import torch import torch.nn as nn -import torchvision from pytest import MonkeyPatch from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum @@ -21,11 +19,6 @@ from torchgeo.trainers.byol import BYOL, SimCLRAugmentation -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict - - class TestBYOL: def test_custom_augment_fn(self) -> None: model = resnet18() @@ -48,6 +41,7 @@ class TestBYOLTask: 'name', [ 'chesapeake_cvpr_prior_byol', + 'hyspecnet_byol', 'seco_byol_1', 'seco_byol_2', 'ssl4eo_l_byol_1', @@ -80,7 +74,7 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) @pytest.fixture def weights(self) -> WeightsEnum: @@ -88,7 +82,11 @@ def weights(self) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model( @@ -99,7 +97,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index cd437f9faed..e2e2d9bb3e5 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -9,7 +9,6 @@ import timm import torch import torch.nn as nn -import torchvision from lightning.pytorch import Trainer from pytest import MonkeyPatch from torch.nn.modules import Module @@ -56,11 +55,6 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return ClassificationTestModel(**kwargs) -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict - - def plot(*args: Any, **kwargs: Any) -> None: return None @@ -109,13 +103,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass @@ -125,7 +119,11 @@ def weights(self) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model( @@ -136,7 +134,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: @@ -240,7 +237,7 @@ def test_freeze_backbone(self, model_name: str) -> None: class TestMultiLabelClassificationTask: @pytest.mark.parametrize( - 'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2'] + 'name', ['bigearthnet_all', 'bigearthnet_s1', 'bigearthnet_s2', 'treesatai'] ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool @@ -262,13 +259,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 035bdacc260..742cda3c371 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -97,13 +97,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass diff --git a/tests/trainers/test_iobench.py b/tests/trainers/test_iobench.py index f67d19582ac..0fbde73bdc8 100644 --- a/tests/trainers/test_iobench.py +++ b/tests/trainers/test_iobench.py @@ -27,12 +27,12 @@ def test_trainer(self, name: str, fast_dev_run: bool) -> None: '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass diff --git a/tests/trainers/test_moco.py b/tests/trainers/test_moco.py index ba3b5641d6b..002944b929e 100644 --- a/tests/trainers/test_moco.py +++ b/tests/trainers/test_moco.py @@ -8,7 +8,6 @@ import pytest import timm import torch -import torchvision from pytest import MonkeyPatch from torch.nn import Module from torchvision.models._api import WeightsEnum @@ -25,16 +24,12 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return ClassificationTestModel(**kwargs) -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict - - class TestMoCoTask: @pytest.mark.parametrize( 'name', [ 'chesapeake_cvpr_prior_moco', + 'hyspecnet_moco', 'seco_moco_1', 'seco_moco_2', 'ssl4eo_l_moco_1', @@ -69,7 +64,7 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) def test_version_warnings(self) -> None: with pytest.warns(UserWarning, match='MoCo v1 uses a memory bank'): @@ -89,7 +84,11 @@ def weights(self) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model( @@ -100,7 +99,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index c62c808c72f..f4089283242 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -10,7 +10,6 @@ import timm import torch import torch.nn as nn -import torchvision from lightning.pytorch import Trainer from pytest import MonkeyPatch from torch.nn.modules import Module @@ -46,11 +45,6 @@ def setup(self, stage: str) -> None: self.predict_dataset = TropicalCyclone(split='test', **self.kwargs) -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict - - def plot(*args: Any, **kwargs: Any) -> None: return None @@ -65,12 +59,20 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return RegressionTestModel(**kwargs) @pytest.mark.parametrize( - 'name', ['cowc_counting', 'cyclone', 'sustainbench_crop_yield', 'skippd'] + 'name', + [ + 'cowc_counting', + 'cyclone', + 'digital_typhoon_id', + 'digital_typhoon_time', + 'sustainbench_crop_yield', + 'skippd', + ], ) def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: - if name == 'skippd': + if name in ['skippd', 'digital_typhoon_id', 'digital_typhoon_time']: pytest.importorskip('h5py', minversion='3.6') config = os.path.join('tests', 'conf', name + '.yaml') @@ -90,13 +92,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass @@ -106,7 +108,11 @@ def weights(self) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model( @@ -117,7 +123,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: @@ -240,13 +245,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass @@ -261,7 +266,11 @@ def weights(self) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model( @@ -272,7 +281,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index d8b207d5d2d..4bdd966a1bb 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -10,7 +10,6 @@ import timm import torch import torch.nn as nn -import torchvision from lightning.pytorch import Trainer from pytest import MonkeyPatch from torch.nn.modules import Module @@ -38,11 +37,6 @@ def create_model(**kwargs: Any) -> Module: return SegmentationTestModel(**kwargs) -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict - - def plot(*args: Any, **kwargs: Any) -> None: return None @@ -56,16 +50,20 @@ class TestSemanticSegmentationTask: 'name', [ 'agrifieldnet', + 'cabuar', 'chabud', 'chesapeake_cvpr_5', 'chesapeake_cvpr_7', 'deepglobelandcover', 'etci2021', + 'ftw', + 'geonrw', 'gid15', 'inria', 'l7irish', 'l8biome', 'landcoverai', + 'landcoverai100', 'loveda', 'naipchesapeake', 'potsdam2d', @@ -79,6 +77,7 @@ class TestSemanticSegmentationTask: 'sentinel2_south_america_soybean', 'southafricacroptype', 'spacenet1', + 'spacenet6', 'ssl4eo_l_benchmark_cdl', 'ssl4eo_l_benchmark_nlcd', 'vaihingen2d', @@ -88,15 +87,15 @@ def test_trainer( self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool ) -> None: match name: - case 'chabud': + case 'chabud' | 'cabuar': pytest.importorskip('h5py', minversion='3.6') + case 'ftw': + pytest.importorskip('pyarrow') case 'landcoverai': sha256 = ( 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b' ) monkeypatch.setattr(LandCoverAI, 'sha256', sha256) - case 'naipchesapeake': - pytest.importorskip('zipfile_deflate64') config = os.path.join('tests', 'conf', name + '.yaml') @@ -116,13 +115,13 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) try: - main(['test'] + args) + main(['test', *args]) except MisconfigurationException: pass try: - main(['predict'] + args) + main(['predict', *args]) except MisconfigurationException: pass @@ -132,7 +131,11 @@ def weights(self) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model( @@ -143,7 +146,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: diff --git a/tests/trainers/test_simclr.py b/tests/trainers/test_simclr.py index b3cbee1fcab..3924b6e3785 100644 --- a/tests/trainers/test_simclr.py +++ b/tests/trainers/test_simclr.py @@ -8,7 +8,6 @@ import pytest import timm import torch -import torchvision from pytest import MonkeyPatch from torch.nn import Module from torchvision.models._api import WeightsEnum @@ -25,16 +24,12 @@ def create_model(*args: Any, **kwargs: Any) -> Module: return ClassificationTestModel(**kwargs) -def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: - state_dict: dict[str, Any] = torch.load(url) - return state_dict - - class TestSimCLRTask: @pytest.mark.parametrize( 'name', [ 'chesapeake_cvpr_prior_simclr', + 'hyspecnet_simclr', 'seco_simclr_1', 'seco_simclr_2', 'ssl4eo_l_simclr_1', @@ -69,7 +64,7 @@ def test_trainer( '1', ] - main(['fit'] + args) + main(['fit', *args]) def test_version_warnings(self) -> None: with pytest.warns(UserWarning, match='SimCLR v1 only uses 2 layers'): @@ -87,7 +82,11 @@ def weights(self) -> WeightsEnum: @pytest.fixture def mocked_weights( - self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + self, + tmp_path: Path, + monkeypatch: MonkeyPatch, + weights: WeightsEnum, + load_state_dict_from_url: None, ) -> WeightsEnum: path = tmp_path / f'{weights}.pth' model = timm.create_model( @@ -98,7 +97,6 @@ def mocked_weights( monkeypatch.setattr(weights.value, 'url', str(path)) except AttributeError: monkeypatch.setattr(weights, 'url', str(path)) - monkeypatch.setattr(torchvision.models._api, 'load_state_dict_from_url', load) return weights def test_weight_file(self, checkpoint: str) -> None: diff --git a/tests/transforms/test_color.py b/tests/transforms/test_color.py index 2e271f89bc9..b235f7195f2 100644 --- a/tests/transforms/test_color.py +++ b/tests/transforms/test_color.py @@ -1,11 +1,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import kornia.augmentation as K import pytest import torch from torch import Tensor -from torchgeo.transforms import AugmentationSequential, RandomGrayscale +from torchgeo.transforms import RandomGrayscale @pytest.fixture @@ -33,12 +34,13 @@ def batch() -> dict[str, Tensor]: ], ) def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None: - aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image']) + aug = K.AugmentationSequential( + RandomGrayscale(weights, p=1), keepdim=True, data_keys=None + ) output = aug(sample) assert output['image'].shape == sample['image'].shape - assert output['image'].sum() == sample['image'].sum() for i in range(1, 3): - assert torch.allclose(output['image'][0, 0], output['image'][0, i]) + assert torch.allclose(output['image'][0], output['image'][i]) @pytest.mark.parametrize( @@ -50,9 +52,8 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> ], ) def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None: - aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image']) + aug = K.AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=None) output = aug(batch) assert output['image'].shape == batch['image'].shape - assert output['image'].sum() == batch['image'].sum() for i in range(1, 3): assert torch.allclose(output['image'][0, 0], output['image'][0, i]) diff --git a/tests/transforms/test_indices.py b/tests/transforms/test_indices.py index 3d83f857304..9e6f54e48c4 100644 --- a/tests/transforms/test_indices.py +++ b/tests/transforms/test_indices.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import kornia.augmentation as K import pytest import torch from torch import Tensor @@ -20,7 +21,6 @@ AppendRBNDVI, AppendSWI, AppendTriBandNormalizedDifferenceIndex, - AugmentationSequential, ) @@ -42,9 +42,8 @@ def batch() -> dict[str, Tensor]: def test_append_index_sample(sample: dict[str, Tensor]) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential( - AppendNormalizedDifferenceIndex(index_a=0, index_b=1), - data_keys=['image', 'mask'], + aug = K.AugmentationSequential( + AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None ) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w) @@ -52,9 +51,8 @@ def test_append_index_sample(sample: dict[str, Tensor]) -> None: def test_append_index_batch(batch: dict[str, Tensor]) -> None: b, c, h, w = batch['image'].shape - aug = AugmentationSequential( - AppendNormalizedDifferenceIndex(index_a=0, index_b=1), - data_keys=['image', 'mask'], + aug = K.AugmentationSequential( + AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None ) output = aug(batch) assert output['image'].shape == (b, c + 1, h, w) @@ -62,9 +60,9 @@ def test_append_index_batch(batch: dict[str, Tensor]) -> None: def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None: b, c, h, w = batch['image'].shape - aug = AugmentationSequential( + aug = K.AugmentationSequential( AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2), - data_keys=['image', 'mask'], + data_keys=None, ) output = aug(batch) assert output['image'].shape == (b, c + 1, h, w) @@ -88,7 +86,7 @@ def test_append_normalized_difference_indices( sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex ) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential(index(0, 1), data_keys=['image', 'mask']) + aug = K.AugmentationSequential(index(0, 1), data_keys=None) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w) @@ -98,6 +96,6 @@ def test_append_tri_band_normalized_difference_indices( sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex ) -> None: c, h, w = sample['image'].shape - aug = AugmentationSequential(index(0, 1, 2), data_keys=['image', 'mask']) + aug = K.AugmentationSequential(index(0, 1, 2), data_keys=None) output = aug(sample) assert output['image'].shape == (1, c + 1, h, w) diff --git a/torchgeo/__init__.py b/torchgeo/__init__.py index 21e4dc5ee3f..fb5ca15dd9c 100644 --- a/torchgeo/__init__.py +++ b/torchgeo/__init__.py @@ -3,12 +3,12 @@ """TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data. -This library is part of the `PyTorch `_ project. PyTorch is an open -source machine learning framework. +This library is part of the `PyTorch `_ project. PyTorch is an +open source machine learning framework. The :mod:`torchgeo` package consists of popular datasets, model architectures, and common image transformations for geospatial data. """ __author__ = 'Adam J. Stewart' -__version__ = '0.6.0.dev0' +__version__ = '0.7.0.dev0' diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index a22a581dea3..6dd7231e3df 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -5,22 +5,28 @@ from .agrifieldnet import AgriFieldNetDataModule from .bigearthnet import BigEarthNetDataModule +from .cabuar import CaBuArDataModule +from .caffe import CaFFeDataModule from .chabud import ChaBuDDataModule from .chesapeake import ChesapeakeCVPRDataModule from .cowc import COWCCountingDataModule from .cyclone import TropicalCycloneDataModule from .deepglobelandcover import DeepGlobeLandCoverDataModule +from .digital_typhoon import DigitalTyphoonDataModule from .etci2021 import ETCI2021DataModule from .eurosat import EuroSAT100DataModule, EuroSATDataModule, EuroSATSpatialDataModule from .fair1m import FAIR1MDataModule from .fire_risk import FireRiskDataModule +from .ftw import FieldsOfTheWorldDataModule from .geo import BaseDataModule, GeoDataModule, NonGeoDataModule +from .geonrw import GeoNRWDataModule from .gid15 import GID15DataModule +from .hyspecnet import HySpecNet11kDataModule from .inria import InriaAerialImageLabelingDataModule from .iobench import IOBenchDataModule from .l7irish import L7IrishDataModule from .l8biome import L8BiomeDataModule -from .landcoverai import LandCoverAIDataModule +from .landcoverai import LandCoverAI100DataModule, LandCoverAIDataModule from .levircd import LEVIRCDDataModule, LEVIRCDPlusDataModule from .loveda import LoveDADataModule from .naip import NAIPChesapeakeDataModule @@ -38,10 +44,11 @@ from .skippd import SKIPPDDataModule from .so2sat import So2SatDataModule from .southafricacroptype import SouthAfricaCropTypeDataModule -from .spacenet import SpaceNet1DataModule +from .spacenet import SpaceNet1DataModule, SpaceNet6DataModule, SpaceNetBaseDataModule from .ssl4eo import SSL4EOLDataModule, SSL4EOS12DataModule from .ssl4eo_benchmark import SSL4EOLBenchmarkDataModule from .sustainbench_crop_yield import SustainBenchCropYieldDataModule +from .treesatai import TreeSatAIDataModule from .ucmerced import UCMercedDataModule from .usavars import USAVarsDataModule from .utils import MisconfigurationException @@ -50,59 +57,65 @@ from .xview import XView2DataModule __all__ = ( - # GeoDataset 'AgriFieldNetDataModule', - 'ChesapeakeCVPRDataModule', - 'IOBenchDataModule', - 'L7IrishDataModule', - 'L8BiomeDataModule', - 'NAIPChesapeakeDataModule', - 'Sentinel2CDLDataModule', - 'Sentinel2EuroCropsDataModule', - 'Sentinel2NCCMDataModule', - 'Sentinel2SouthAmericaSoybeanDataModule', - 'SouthAfricaCropTypeDataModule', - # NonGeoDataset + 'BaseDataModule', 'BigEarthNetDataModule', - 'ChaBuDDataModule', 'COWCCountingDataModule', + 'CaBuArDataModule', + 'CaFFeDataModule', + 'ChaBuDDataModule', + 'ChesapeakeCVPRDataModule', 'DeepGlobeLandCoverDataModule', + 'DigitalTyphoonDataModule', 'ETCI2021DataModule', + 'EuroSAT100DataModule', 'EuroSATDataModule', 'EuroSATSpatialDataModule', - 'EuroSAT100DataModule', 'FAIR1MDataModule', + 'FieldsOfTheWorldDataModule', 'FireRiskDataModule', 'GID15DataModule', + 'GeoDataModule', + 'GeoNRWDataModule', + 'HySpecNet11kDataModule', + 'IOBenchDataModule', 'InriaAerialImageLabelingDataModule', - 'LandCoverAIDataModule', + 'L7IrishDataModule', + 'L8BiomeDataModule', 'LEVIRCDDataModule', 'LEVIRCDPlusDataModule', + 'LandCoverAI100DataModule', + 'LandCoverAIDataModule', 'LoveDADataModule', + 'MisconfigurationException', + 'NAIPChesapeakeDataModule', 'NASAMarineDebrisDataModule', + 'NonGeoDataModule', 'OSCDDataModule', 'Potsdam2DDataModule', 'QuakeSetDataModule', 'RESISC45DataModule', - 'SeasonalContrastS2DataModule', 'SEN12MSDataModule', 'SKIPPDDataModule', - 'So2SatDataModule', - 'SpaceNet1DataModule', 'SSL4EOLBenchmarkDataModule', 'SSL4EOLDataModule', 'SSL4EOS12DataModule', + 'SeasonalContrastS2DataModule', + 'Sentinel2CDLDataModule', + 'Sentinel2EuroCropsDataModule', + 'Sentinel2NCCMDataModule', + 'Sentinel2SouthAmericaSoybeanDataModule', + 'So2SatDataModule', + 'SouthAfricaCropTypeDataModule', + 'SpaceNet1DataModule', + 'SpaceNet6DataModule', + 'SpaceNetBaseDataModule', 'SustainBenchCropYieldDataModule', + 'TreeSatAIDataModule', 'TropicalCycloneDataModule', 'UCMercedDataModule', 'USAVarsDataModule', - 'Vaihingen2DDataModule', 'VHR10DataModule', + 'Vaihingen2DDataModule', 'XView2DataModule', - # Base classes - 'BaseDataModule', - 'GeoDataModule', - 'NonGeoDataModule', - # Utilities - 'MisconfigurationException', ) diff --git a/torchgeo/datamodules/agrifieldnet.py b/torchgeo/datamodules/agrifieldnet.py index bed6365d4a2..cbb8af25356 100644 --- a/torchgeo/datamodules/agrifieldnet.py +++ b/torchgeo/datamodules/agrifieldnet.py @@ -12,7 +12,6 @@ from ..datasets import AgriFieldNet, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,12 +48,13 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, @@ -74,7 +74,11 @@ def setup(self, stage: str) -> None: if stage in ['fit']: self.train_batch_sampler = RandomBatchGeoSampler( - self.train_dataset, self.patch_size, self.batch_size, self.length + self.train_dataset, + self.patch_size, + self.batch_size, + self.length, + generator=generator, ) if stage in ['fit', 'validate']: self.val_sampler = GridGeoSampler( diff --git a/torchgeo/datamodules/cabuar.py b/torchgeo/datamodules/cabuar.py new file mode 100644 index 00000000000..2ce459bceae --- /dev/null +++ b/torchgeo/datamodules/cabuar.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""CaBuAr datamodule.""" + +from typing import Any + +import torch +from einops import repeat + +from ..datasets import CaBuAr +from .geo import NonGeoDataModule + + +class CaBuArDataModule(NonGeoDataModule): + """LightningDataModule implementation for the CaBuAr dataset. + + Uses the train/val/test splits from the dataset + + .. versionadded:: 0.6 + """ + + # min/max values computed on train set using 2/98 percentiles + min = torch.tensor( + [0.0, 1.0, 73.0, 39.0, 46.0, 25.0, 26.0, 21.0, 17.0, 1.0, 20.0, 21.0] + ) + max = torch.tensor( + [ + 1926.0, + 2174.0, + 2527.0, + 2950.0, + 3237.0, + 3717.0, + 4087.0, + 4271.0, + 4290.0, + 4219.0, + 4568.0, + 3753.0, + ] + ) + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a new CaBuArDataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.CaBuAr`. + """ + bands = kwargs.get('bands', CaBuAr.all_bands) + band_indices = [CaBuAr.all_bands.index(b) for b in bands] + mins = self.min[band_indices] + maxs = self.max[band_indices] + + # Change detection, 2 images from different times + mins = repeat(mins, 'c -> (t c)', t=2) + maxs = repeat(maxs, 'c -> (t c)', t=2) + + self.mean = mins + self.std = maxs - mins + + super().__init__(CaBuAr, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/caffe.py b/torchgeo/datamodules/caffe.py new file mode 100644 index 00000000000..a58136df30b --- /dev/null +++ b/torchgeo/datamodules/caffe.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""CaFFe datamodule.""" + +from typing import Any + +import kornia.augmentation as K +import torch + +from ..datasets import CaFFe +from .geo import NonGeoDataModule + + +class CaFFeDataModule(NonGeoDataModule): + """LightningDataModule implementation for the CaFFe dataset. + + Implements the default splits that come with the dataset. + + .. versionadded:: 0.7 + """ + + mean = torch.Tensor([0.5517]) + std = torch.Tensor([11.8478]) + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, size: int = 512, **kwargs: Any + ) -> None: + """Initialize a new CaFFeDataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + size: resize images of input size 512x512 to size x size + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.CaFFe`. + """ + super().__init__(CaFFe, batch_size, num_workers, **kwargs) + + self.size = size + + self.train_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.Resize(size), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + data_keys=None, + keepdim=True, + ) + + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.Resize(size), + data_keys=None, + keepdim=True, + ) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 37a0d32edbd..41e944e1af5 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -6,49 +6,14 @@ from typing import Any import kornia.augmentation as K -import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from torch import Tensor from ..datasets import ChesapeakeCVPR from ..samplers import GridGeoSampler, RandomBatchGeoSampler -from ..transforms import AugmentationSequential from .geo import GeoDataModule -class _Transform(nn.Module): - """Version of AugmentationSequential designed for samples, not batches.""" - - def __init__(self, aug: nn.Module) -> None: - """Initialize a new _Transform instance. - - Args: - aug: Augmentation to apply. - """ - super().__init__() - self.aug = aug - - def forward(self, sample: dict[str, Any]) -> dict[str, Any]: - """Apply the augmentation. - - Args: - sample: Input sample. - - Returns: - Augmented sample. - """ - for key in ['image', 'mask']: - dtype = sample[key].dtype - # All inputs must be float - sample[key] = sample[key].float() - sample[key] = self.aug(sample[key]) - sample[key] = sample[key].to(dtype) - # Kornia adds batch dimension - sample[key] = rearrange(sample[key], '() c h w -> c h w') - return sample - - class ChesapeakeCVPRDataModule(GeoDataModule): """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. @@ -94,7 +59,9 @@ def __init__( # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. self.original_patch_size = patch_size * 3 - kwargs['transforms'] = _Transform(K.CenterCrop(patch_size)) + kwargs['transforms'] = K.AugmentationSequential( + K.CenterCrop(patch_size), data_keys=None, keepdim=True + ) super().__init__( ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs @@ -122,8 +89,8 @@ def __init__( else: self.layers = ['naip-new', 'lc'] - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 39021fc2acc..e9af302094b 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -43,18 +43,11 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ['fit', 'validate']: - self.dataset = TropicalCyclone(split='train', **self.kwargs) - - storm_ids = [] - for item in self.dataset.collection: - storm_id = item['href'].split('/')[0].split('_')[-2] - storm_ids.append(storm_id) - + dataset = TropicalCyclone(split='train', **self.kwargs) train_indices, val_indices = group_shuffle_split( - storm_ids, test_size=0.2, random_state=0 + dataset.features['Storm ID'], test_size=0.2, random_state=0 ) - - self.train_dataset = Subset(self.dataset, train_indices) - self.val_dataset = Subset(self.dataset, val_indices) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) if stage in ['test']: self.test_dataset = TropicalCyclone(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 80ea6052cb7..b3ab2d687b5 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -11,7 +11,6 @@ from ..datasets import DeepGlobeLandCover from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -46,10 +45,11 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/digital_typhoon.py b/torchgeo/datamodules/digital_typhoon.py new file mode 100644 index 00000000000..ce799bf3d52 --- /dev/null +++ b/torchgeo/datamodules/digital_typhoon.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Digital Typhoon Data Module.""" + +import copy +from collections import defaultdict +from typing import Any + +from torch.utils.data import Subset + +from ..datasets import DigitalTyphoon +from ..datasets.digital_typhoon import _SampleSequenceDict +from .geo import NonGeoDataModule +from .utils import group_shuffle_split + + +class DigitalTyphoonDataModule(NonGeoDataModule): + """Digital Typhoon Data Module. + + .. versionadded:: 0.6 + """ + + valid_split_types = ('time', 'typhoon_id') + + def __init__( + self, + split_by: str = 'time', + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new DigitalTyphoonDataModule instance. + + Args: + split_by: Either 'time' or 'typhoon_id', which decides how to split + the dataset for train, val, test + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.DigitalTyphoon`. + + """ + super().__init__(DigitalTyphoon, batch_size, num_workers, **kwargs) + + assert ( + split_by in self.valid_split_types + ), f'Please choose from {self.valid_split_types}' + self.split_by = split_by + + def _split_dataset( + self, sample_sequences: list[_SampleSequenceDict] + ) -> tuple[list[int], list[int]]: + """Split dataset into two parts. + + Args: + sample_sequences: List of sample sequence dictionaries to be split + + Returns: + a tuple of the subset datasets + """ + if self.split_by == 'time': + # split dataset such that only unseen future time steps of storms + # are contained in validation + grouped_sequences = defaultdict(list) + for idx, seq in enumerate(sample_sequences): + grouped_sequences[seq['id']].append((idx, seq['seq_id'])) + + train_indices = [] + val_indices = [] + + for id, sequences in grouped_sequences.items(): + split_idx = int(len(sequences) * 0.8) + train_sequences = sequences[:split_idx] + val_sequences = sequences[split_idx:] + train_indices.extend([idx for idx, _ in train_sequences]) + val_indices.extend([idx for idx, _ in val_sequences]) + + else: + # split dataset such that the id of storms is mutually exclusive + train_indices, val_indices = group_shuffle_split( + [x['id'] for x in sample_sequences], train_size=0.8, random_state=0 + ) + + return train_indices, val_indices + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + self.dataset = DigitalTyphoon(**self.kwargs) + + all_sample_sequences = copy.deepcopy(self.dataset.sample_sequences) + + train_indices, test_indices = self._split_dataset(self.dataset.sample_sequences) + + if stage in ['fit', 'validate']: + # Randomly split train into train and validation sets + index_mapping = { + new_index: original_index + for new_index, original_index in enumerate(train_indices) + } + train_sequences = [all_sample_sequences[i] for i in train_indices] + train_indices, val_indices = self._split_dataset(train_sequences) + train_indices = [index_mapping[i] for i in train_indices] + val_indices = [index_mapping[i] for i in val_indices] + + # Create train val subset dataset + self.train_dataset = Subset(self.dataset, train_indices) + self.val_dataset = Subset(self.dataset, val_indices) + + if stage in ['test']: + self.test_dataset = Subset(self.dataset, test_indices) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index 845edd243dd..6b8b9115d1c 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -19,11 +19,11 @@ 'B06': 2130.3491, 'B07': 2524.0549, 'B08': 2454.1938, - 'B8A': 785.4963, - 'B09': 12.4639, - 'B10': 1969.9224, - 'B11': 1206.2421, - 'B12': 2779.4104, + 'B09': 785.4963, + 'B10': 12.4639, + 'B11': 1969.9224, + 'B12': 1206.2421, + 'B8A': 2779.4104, } SPATIAL_STD = { @@ -35,11 +35,11 @@ 'B06': 806.8271, 'B07': 1022.6378, 'B08': 1065.4312, - 'B8A': 410.5831, - 'B09': 4.8878, - 'B10': 958.4751, - 'B11': 740.6196, - 'B12': 1157.2896, + 'B09': 410.5831, + 'B10': 4.8878, + 'B11': 958.4751, + 'B12': 740.6196, + 'B8A': 1157.2896, } MEAN = { @@ -51,11 +51,11 @@ 'B06': 1999.79090914, 'B07': 2369.22292565, 'B08': 2296.82608323, - 'B8A': 732.08340178, - 'B09': 12.11327804, - 'B10': 1819.01027855, - 'B11': 1118.92391149, - 'B12': 2594.14080798, + 'B09': 732.08340178, + 'B10': 12.11327804, + 'B11': 1819.01027855, + 'B12': 1118.92391149, + 'B8A': 2594.14080798, } STD = { @@ -67,11 +67,11 @@ 'B06': 861.18399006, 'B07': 1086.63139075, 'B08': 1117.98170791, - 'B8A': 404.91978886, - 'B09': 4.77584468, - 'B10': 1002.58768311, - 'B11': 761.30323499, - 'B12': 1231.58581042, + 'B09': 404.91978886, + 'B10': 4.77584468, + 'B11': 1002.58768311, + 'B12': 761.30323499, + 'B8A': 1231.58581042, } @@ -94,11 +94,10 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.EuroSAT`. """ - super().__init__(EuroSAT, batch_size, num_workers, **kwargs) - bands = kwargs.get('bands', EuroSAT.all_band_names) self.mean = torch.tensor([MEAN[b] for b in bands]) self.std = torch.tensor([STD[b] for b in bands]) + super().__init__(EuroSAT, batch_size, num_workers, **kwargs) class EuroSATSpatialDataModule(NonGeoDataModule): @@ -120,11 +119,10 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.EuroSATSpatial`. """ - super().__init__(EuroSATSpatial, batch_size, num_workers, **kwargs) - bands = kwargs.get('bands', EuroSAT.all_band_names) self.mean = torch.tensor([SPATIAL_MEAN[b] for b in bands]) self.std = torch.tensor([SPATIAL_STD[b] for b in bands]) + super().__init__(EuroSATSpatial, batch_size, num_workers, **kwargs) class EuroSAT100DataModule(NonGeoDataModule): @@ -146,8 +144,7 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.EuroSAT100`. """ - super().__init__(EuroSAT100, batch_size, num_workers, **kwargs) - bands = kwargs.get('bands', EuroSAT.all_band_names) self.mean = torch.tensor([MEAN[b] for b in bands]) self.std = torch.tensor([STD[b] for b in bands]) + super().__init__(EuroSAT100, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/fire_risk.py b/torchgeo/datamodules/fire_risk.py index 1a0d6c7c047..d317981cff3 100644 --- a/torchgeo/datamodules/fire_risk.py +++ b/torchgeo/datamodules/fire_risk.py @@ -8,7 +8,6 @@ import kornia.augmentation as K from ..datasets import FireRisk -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -30,7 +29,7 @@ def __init__( :class:`~torchgeo.datasets.FireRisk`. """ super().__init__(FireRisk, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -38,7 +37,8 @@ def __init__( K.RandomSharpness(p=0.5), K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/ftw.py b/torchgeo/datamodules/ftw.py new file mode 100644 index 00000000000..a197a789c48 --- /dev/null +++ b/torchgeo/datamodules/ftw.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""FTW datamodule.""" + +from typing import Any + +import kornia.augmentation as K +import torch + +from ..datasets import FieldsOfTheWorld +from .geo import NonGeoDataModule + + +class FieldsOfTheWorldDataModule(NonGeoDataModule): + """LightningDataModule implementation for the FTW dataset. + + .. versionadded:: 0.7 + """ + + mean = torch.tensor([0]) + std = torch.tensor([3000]) + + def __init__( + self, + train_countries: list[str] = ['austria'], + val_countries: list[str] = ['austria'], + test_countries: list[str] = ['austria'], + batch_size: int = 64, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new FTWDataModule instance. + + Args: + train_countries: List of countries to use for training. + val_countries: List of countries to use for validation. + test_countries: List of countries to use for testing. + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.FieldsOfTheWorld`. + + Raises: + AssertionError: If 'countries' are specified in kwargs + """ + assert ( + 'countries' not in kwargs + ), "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs" + + super().__init__(FieldsOfTheWorld, batch_size, num_workers, **kwargs) + + self.train_countries = train_countries + self.val_countries = val_countries + self.test_countries = test_countries + + self.train_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + data_keys=None, + keepdim=True, + ) + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True + ) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', or 'test'. + """ + if stage in ['fit', 'validate']: + self.train_dataset = FieldsOfTheWorld( + split='train', countries=self.train_countries, **self.kwargs + ) + self.val_dataset = FieldsOfTheWorld( + split='val', countries=self.val_countries, **self.kwargs + ) + if stage in ['test']: + self.test_dataset = FieldsOfTheWorld( + split='test', countries=self.test_countries, **self.kwargs + ) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 5f77c0c4d6b..8721ea6e7f6 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -20,7 +20,6 @@ GridGeoSampler, RandomBatchGeoSampler, ) -from ..transforms import AugmentationSequential from .utils import MisconfigurationException @@ -70,9 +69,10 @@ def __init__( # Data augmentation Transform = Callable[[dict[str, Tensor]], dict[str, Tensor]] - self.aug: Transform = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image'] + self.aug: Transform = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) + self.train_aug: Transform | None = None self.val_aug: Transform | None = None self.test_aug: Transform | None = None @@ -286,6 +286,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: batch_sampler=batch_sampler, num_workers=self.num_workers, collate_fn=self.collate_fn, + persistent_workers=self.num_workers > 0, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: @@ -353,7 +354,7 @@ def transfer_batch_to_device( """ # Non-Tensor values cannot be moved to a device del batch['crs'] - del batch['bbox'] + del batch['bounds'] batch = super().transfer_batch_to_device(batch, device, dataloader_idx) return batch @@ -429,6 +430,7 @@ def _dataloader_factory(self, split: str) -> DataLoader[dict[str, Tensor]]: shuffle=split == 'train', num_workers=self.num_workers, collate_fn=self.collate_fn, + persistent_workers=self.num_workers > 0, ) def train_dataloader(self) -> DataLoader[dict[str, Tensor]]: diff --git a/torchgeo/datamodules/geonrw.py b/torchgeo/datamodules/geonrw.py new file mode 100644 index 00000000000..5283b0ed7ac --- /dev/null +++ b/torchgeo/datamodules/geonrw.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""GeoNRW datamodule.""" + +import os +from typing import Any + +import kornia.augmentation as K +from torch.utils.data import Subset + +from ..datasets import GeoNRW +from .geo import NonGeoDataModule +from .utils import group_shuffle_split + + +class GeoNRWDataModule(NonGeoDataModule): + """LightningDataModule implementation for the GeoNRW dataset. + + Implements 80/20 train/val splits based on city locations. + See :func:`setup` for more details. + + .. versionadded:: 0.6 + """ + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any + ) -> None: + """Initialize a new GeoNRWDataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + size: resize images of input size 1000x1000 to size x size + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.GeoNRW`. + """ + super().__init__(GeoNRW, batch_size, num_workers, **kwargs) + + self.train_aug = K.AugmentationSequential( + K.Resize(size), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + data_keys=None, + keepdim=True, + ) + + self.aug = K.AugmentationSequential( + K.Resize(size), data_keys=None, keepdim=True + ) + + self.size = size + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ['fit', 'validate']: + dataset = GeoNRW(split='train', **self.kwargs) + city_paths = [os.path.dirname(path) for path in dataset.file_list] + train_indices, val_indices = group_shuffle_split( + city_paths, test_size=0.2, random_state=0 + ) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) + if stage in ['test']: + self.test_dataset = GeoNRW(split='test', **self.kwargs) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index fc4e802c148..d33c55ec829 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -11,7 +11,6 @@ from ..datasets import GID15 from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -48,15 +47,17 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = self.val_aug = AugmentationSequential( + self.train_aug = self.val_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.predict_aug = AugmentationSequential( + self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/hyspecnet.py b/torchgeo/datamodules/hyspecnet.py new file mode 100644 index 00000000000..3e508ef11a7 --- /dev/null +++ b/torchgeo/datamodules/hyspecnet.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""HySpecNet datamodule.""" + +from typing import Any + +import torch + +from ..datasets import HySpecNet11k +from .geo import NonGeoDataModule + + +class HySpecNet11kDataModule(NonGeoDataModule): + """LightningDataModule implementation for the HySpecNet11k dataset. + + .. versionadded:: 0.7 + """ + + # https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb + mean = torch.tensor(0) + std = torch.tensor(10000) + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a new HySpecNet11kDataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.HySpecNet11k`. + """ + super().__init__(HySpecNet11k, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 39e8ede22c5..797f5484b6a 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -9,7 +9,6 @@ from ..datasets import InriaAerialImageLabeling from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -44,22 +43,25 @@ def __init__( self.patch_size = _to_tuple(patch_size) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.predict_aug = AugmentationSequential( + self.predict_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/l7irish.py b/torchgeo/datamodules/l7irish.py index 35408feddbb..3a70446f90f 100644 --- a/torchgeo/datamodules/l7irish.py +++ b/torchgeo/datamodules/l7irish.py @@ -12,7 +12,6 @@ from ..datasets import L7Irish, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,12 +48,13 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, diff --git a/torchgeo/datamodules/l8biome.py b/torchgeo/datamodules/l8biome.py index ddc802a5ce3..cf0415b34c9 100644 --- a/torchgeo/datamodules/l8biome.py +++ b/torchgeo/datamodules/l8biome.py @@ -12,7 +12,6 @@ from ..datasets import L8Biome, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,12 +48,13 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index d775cf21a04..9ed2a4d34d6 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -1,14 +1,13 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""LandCover.ai datamodule.""" +"""LandCover.ai datamodules.""" from typing import Any import kornia.augmentation as K -from ..datasets import LandCoverAI -from ..transforms import AugmentationSequential +from ..datasets import LandCoverAI, LandCoverAI100 from .geo import NonGeoDataModule @@ -31,15 +30,42 @@ def __init__( """ super().__init__(LandCoverAI, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), K.RandomSharpness(p=0.5), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True + ) + + +class LandCoverAI100DataModule(NonGeoDataModule): + """LightningDataModule implementation for the LandCoverAI100 dataset. + + Uses the train/val/test splits from the dataset. + + .. versionadded:: 0.7 + """ + + def __init__( + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + ) -> None: + """Initialize a new LandCoverAI100DataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.LandCoverAI100`. + """ + super().__init__(LandCoverAI100, batch_size, num_workers, **kwargs) + + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) diff --git a/torchgeo/datamodules/levircd.py b/torchgeo/datamodules/levircd.py index 8488c2e58b3..0e3a124dc94 100644 --- a/torchgeo/datamodules/levircd.py +++ b/torchgeo/datamodules/levircd.py @@ -11,7 +11,6 @@ from ..datasets import LEVIRCD, LEVIRCDPlus from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -43,18 +42,17 @@ def __init__( self.patch_size = _to_tuple(patch_size) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image1', 'image2', 'mask'], + data_keys=None, + keepdim=True, ) - self.val_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=['image1', 'image2', 'mask'], + self.val_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - self.test_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=['image1', 'image2', 'mask'], + self.test_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) @@ -91,18 +89,17 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image1', 'image2', 'mask'], + data_keys=None, + keepdim=True, ) - self.val_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=['image1', 'image2', 'mask'], + self.val_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) - self.test_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=['image1', 'image2', 'mask'], + self.test_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index b414cc0991e..0520d0264ad 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -8,9 +8,18 @@ import kornia.augmentation as K from matplotlib.figure import Figure -from ..datasets import NAIP, BoundingBox, Chesapeake13 +from ..datasets import ( + NAIP, + BoundingBox, + ChesapeakeDC, + ChesapeakeDE, + ChesapeakeMD, + ChesapeakeNY, + ChesapeakePA, + ChesapeakeVA, + ChesapeakeWV, +) from ..samplers import GridGeoSampler, RandomBatchGeoSampler -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -37,7 +46,7 @@ def __init__( num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.NAIP` (prefix keys with ``naip_``) and - :class:`~torchgeo.datasets.Chesapeake13` + :class:`~torchgeo.datasets.Chesapeake` (prefix keys with ``chesapeake_``). """ self.naip_kwargs = {} @@ -49,16 +58,11 @@ def __init__( self.chesapeake_kwargs[key[11:]] = val super().__init__( - Chesapeake13, - batch_size, - patch_size, - length, - num_workers, - **self.chesapeake_kwargs, + NAIP, batch_size, patch_size, length, num_workers, **self.naip_kwargs ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: @@ -67,9 +71,16 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.chesapeake = Chesapeake13(**self.chesapeake_kwargs) self.naip = NAIP(**self.naip_kwargs) - self.dataset = self.chesapeake & self.naip + dc = ChesapeakeDC(**self.chesapeake_kwargs) + de = ChesapeakeDE(**self.chesapeake_kwargs) + md = ChesapeakeMD(**self.chesapeake_kwargs) + ny = ChesapeakeNY(**self.chesapeake_kwargs) + pa = ChesapeakePA(**self.chesapeake_kwargs) + va = ChesapeakeVA(**self.chesapeake_kwargs) + wv = ChesapeakeWV(**self.chesapeake_kwargs) + self.chesapeake = dc | de | md | ny | pa | va | wv + self.dataset = self.naip & self.chesapeake roi = self.dataset.bounds midx = roi.minx + (roi.maxx - roi.minx) / 2 diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 630ef635198..8db1dd7061a 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -11,40 +11,39 @@ from ..datasets import OSCD from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule MEAN = { - 'B01': 1583.0741, - 'B02': 1374.3202, - 'B03': 1294.1616, - 'B04': 1325.6158, - 'B05': 1478.7408, - 'B06': 1933.0822, - 'B07': 2166.0608, - 'B08': 2076.4868, - 'B8A': 2306.0652, - 'B09': 690.9814, - 'B10': 16.2360, - 'B11': 2080.3347, - 'B12': 1524.6930, + 'B01': 1565.696044921875, + 'B02': 1351.3319091796875, + 'B03': 1257.1082763671875, + 'B04': 1254.932861328125, + 'B05': 1388.689208984375, + 'B06': 1827.6710205078125, + 'B07': 2050.2744140625, + 'B08': 1963.4619140625, + 'B8A': 2182.680908203125, + 'B09': 629.837646484375, + 'B10': 14.855598449707031, + 'B11': 1909.8394775390625, + 'B12': 1379.6024169921875, } STD = { - 'B01': 52.1937, - 'B02': 83.4168, - 'B03': 105.6966, - 'B04': 151.1401, - 'B05': 147.4615, - 'B06': 115.9289, - 'B07': 123.1974, - 'B08': 114.6483, - 'B8A': 141.4530, - 'B09': 73.2758, - 'B10': 4.8368, - 'B11': 213.4821, - 'B12': 179.4793, + 'B01': 263.7977600097656, + 'B02': 394.5567321777344, + 'B03': 508.9673767089844, + 'B04': 726.4053344726562, + 'B05': 686.6111450195312, + 'B06': 730.0204467773438, + 'B07': 822.0133056640625, + 'B08': 842.5917358398438, + 'B8A': 895.7645263671875, + 'B09': 314.8407287597656, + 'B10': 9.417905807495117, + 'B11': 984.9249267578125, + 'B12': 844.7711181640625, } @@ -85,10 +84,11 @@ def __init__( self.mean = torch.tensor([MEAN[b] for b in self.bands]) self.std = torch.tensor([STD[b] for b in self.bands]) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image1', 'image2', 'mask'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 8011382c769..7a5495a4458 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -11,7 +11,6 @@ from ..datasets import Potsdam2D from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -48,10 +47,11 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/quakeset.py b/torchgeo/datamodules/quakeset.py index 1a3e19a5122..03c677138e9 100644 --- a/torchgeo/datamodules/quakeset.py +++ b/torchgeo/datamodules/quakeset.py @@ -9,7 +9,6 @@ import torch from ..datasets import QuakeSet -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -34,9 +33,10 @@ def __init__( :class:`~torchgeo.datasets.QuakeSet`. """ super().__init__(QuakeSet, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=['image'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index e88e139f481..e279478f8d0 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -9,7 +9,6 @@ import torch from ..datasets import RESISC45 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -36,7 +35,7 @@ def __init__( """ super().__init__(RESISC45, batch_size, num_workers, **kwargs) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -44,5 +43,6 @@ def __init__( K.RandomSharpness(p=0.5), K.RandomErasing(p=0.1), K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/seco.py b/torchgeo/datamodules/seco.py index f1ed2346164..ecfeb04b288 100644 --- a/torchgeo/datamodules/seco.py +++ b/torchgeo/datamodules/seco.py @@ -10,7 +10,6 @@ from einops import repeat from ..datasets import SeasonalContrastS2 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -37,7 +36,7 @@ def __init__( seasons = kwargs.get('seasons', 1) # Normalization only available for RGB dataset, defined here: - # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501 + # https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py if bands == SeasonalContrastS2.rgb_bands: _min = torch.tensor([3, 2, 0]) _max = torch.tensor([88, 103, 129]) @@ -49,11 +48,12 @@ def __init__( _mean = repeat(_mean, 'c -> (t c)', t=seasons) _std = repeat(_std, 'c -> (t c)', t=seasons) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=_min, std=_max - _min), K.Normalize(mean=torch.tensor(0), std=1 / torch.tensor(255)), K.Normalize(mean=_mean, std=_std), - data_keys=['image'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_cdl.py b/torchgeo/datamodules/sentinel2_cdl.py index 97c3d05392e..91af34b0ef1 100644 --- a/torchgeo/datamodules/sentinel2_cdl.py +++ b/torchgeo/datamodules/sentinel2_cdl.py @@ -13,7 +13,6 @@ from ..datasets import CDL, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -63,19 +62,20 @@ def __init__( CDL, batch_size, patch_size, length, num_workers, **self.cdl_kwargs ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_eurocrops.py b/torchgeo/datamodules/sentinel2_eurocrops.py index 4e0893e4f8d..8f34c2598ef 100644 --- a/torchgeo/datamodules/sentinel2_eurocrops.py +++ b/torchgeo/datamodules/sentinel2_eurocrops.py @@ -13,7 +13,6 @@ from ..datasets import EuroCrops, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -64,19 +63,20 @@ def __init__( **self.eurocrops_kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_nccm.py b/torchgeo/datamodules/sentinel2_nccm.py index 91b4f936fdc..34fde0f3153 100644 --- a/torchgeo/datamodules/sentinel2_nccm.py +++ b/torchgeo/datamodules/sentinel2_nccm.py @@ -13,7 +13,6 @@ from ..datasets import NCCM, Sentinel2, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -63,19 +62,20 @@ def __init__( NCCM, batch_size, patch_size, length, num_workers, **self.nccm_kwargs ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py index e3363e857f5..d3deff9e823 100644 --- a/torchgeo/datamodules/sentinel2_south_america_soybean.py +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -14,7 +14,6 @@ from ..datasets import Sentinel2, SouthAmericaSoybean, random_grid_cell_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -62,19 +61,20 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 64701cef519..99126698ffa 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -3,7 +3,7 @@ """So2Sat datamodule.""" -from typing import Any +from typing import Any, ClassVar import torch from torch import Generator, Tensor @@ -21,7 +21,7 @@ class So2SatDataModule(NonGeoDataModule): "train" set and use the "test" set as the test set. """ - means_per_version: dict[str, Tensor] = { + means_per_version: ClassVar[dict[str, Tensor]] = { '2': torch.tensor( [ -0.00003591224260, @@ -91,7 +91,7 @@ class So2SatDataModule(NonGeoDataModule): } means_per_version['3_culture_10'] = means_per_version['2'] - stds_per_version: dict[str, Tensor] = { + stds_per_version: ClassVar[dict[str, Tensor]] = { '2': torch.tensor( [ 0.17555201, diff --git a/torchgeo/datamodules/southafricacroptype.py b/torchgeo/datamodules/southafricacroptype.py index 3f44bb61471..37fdef5e7db 100644 --- a/torchgeo/datamodules/southafricacroptype.py +++ b/torchgeo/datamodules/southafricacroptype.py @@ -12,7 +12,6 @@ from ..datasets import SouthAfricaCropType, random_bbox_assignment from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -49,19 +48,20 @@ def __init__( **kwargs, ) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), data_keys=['image', 'mask'] + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=None, keepdim=True ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index dad903bd717..7353efbbaec 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -10,57 +10,43 @@ from torch import Tensor from torch.utils.data import random_split -from ..datasets import SpaceNet1 -from ..transforms import AugmentationSequential +from ..datasets import SpaceNet, SpaceNet1, SpaceNet6 from .geo import NonGeoDataModule -class SpaceNet1DataModule(NonGeoDataModule): - """LightningDataModule implementation for the SpaceNet1 dataset. +class SpaceNetBaseDataModule(NonGeoDataModule): + """LightningDataModule implementation for the SpaceNet datasets. - Randomly splits into train/val/test. + Randomly splits the train split into train/val/test. The test split does not have labels, + and is only used for prediction. - .. versionadded:: 0.4 + .. versionadded:: 0.7 """ def __init__( self, + spacenet_ds_class: type[SpaceNet], batch_size: int = 64, num_workers: int = 0, val_split_pct: float = 0.1, test_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a new SpaceNet1DataModule instance. + """Initialize a new SpaceNetBaseDataModule instance. Args: + spacenet_ds_class: The SpaceNet dataset class to use. batch_size: Size of each mini-batch. - num_workers: Number of workers for parallel data loading. val_split_pct: Percentage of the dataset to use as a validation set. test_split_pct: Percentage of the dataset to use as a test set. - **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.SpaceNet1`. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to the SpaceNet dataset. """ - super().__init__(SpaceNet1, batch_size, num_workers, **kwargs) + super().__init__(spacenet_ds_class, batch_size, num_workers, **kwargs) self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - - self.train_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - K.PadTo((448, 448)), - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), - data_keys=['image', 'mask'], - ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - K.PadTo((448, 448)), - data_keys=['image', 'mask'], - ) + self.spacenet_ds_class = spacenet_ds_class def setup(self, stage: str) -> None: """Set up datasets. @@ -68,17 +54,22 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.dataset = SpaceNet1(**self.kwargs) - generator = torch.Generator().manual_seed(0) - self.train_dataset, self.val_dataset, self.test_dataset = random_split( - self.dataset, - [ - 1 - self.val_split_pct - self.test_split_pct, - self.val_split_pct, - self.test_split_pct, - ], - generator, - ) + if stage in ['fit', 'validate', 'test']: + self.dataset = self.spacenet_ds_class(split='train', **self.kwargs) + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset, self.test_dataset = random_split( + self.dataset, + [ + 1 - self.val_split_pct - self.test_split_pct, + self.val_split_pct, + self.test_split_pct, + ], + generator, + ) + + # test split in SpaceNet does not have labels + if stage in ['predict']: + self.predict_dataset = self.spacenet_ds_class(split='test', **self.kwargs) def on_after_batch_transfer( self, batch: dict[str, Tensor], dataloader_idx: int @@ -95,6 +86,95 @@ def on_after_batch_transfer( # We add 1 to the mask to map the current {background, building} labels to # the values {1, 2}. This is necessary because we add 0 padding to the # mask that we want to ignore in the loss function. - batch['mask'] += 1 + if 'mask' in batch: + batch['mask'] += 1 return super().on_after_batch_transfer(batch, dataloader_idx) + + +class SpaceNet1DataModule(SpaceNetBaseDataModule): + """LightningDataModule implementation for the SpaceNet1 dataset. + + Randomly splits the train split into train/val/test. The test split does not have labels, + and is only used for prediction. + + .. versionadded:: 0.4 + """ + + def __init__( + self, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.1, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a new SpaceNet1DataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.SpaceNet1`. + """ + super().__init__( + SpaceNet1, batch_size, num_workers, val_split_pct, test_split_pct, **kwargs + ) + + self.train_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.PadTo((448, 448)), + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + data_keys=None, + keepdim=True, + ) + self.aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.PadTo((448, 448)), + data_keys=None, + keepdim=True, + ) + + self.predict_aug = K.AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.PadTo((448, 448)), + data_keys=None, + ) + + +class SpaceNet6DataModule(SpaceNetBaseDataModule): + """LightningDataModule implementation for the SpaceNet6 dataset. + + Randomly splits the train split into train/val/test. The test split does not have labels, + and is only used for prediction. + + .. versionadded:: 0.7 + """ + + def __init__( + self, + batch_size: int = 64, + num_workers: int = 0, + val_split_pct: float = 0.1, + test_split_pct: float = 0.2, + **kwargs: Any, + ) -> None: + """Initialize a new SpaceNet6DataModule instance. + + Args: + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.SpaceNet6`. + """ + super().__init__( + SpaceNet6, batch_size, num_workers, val_split_pct, test_split_pct, **kwargs + ) diff --git a/torchgeo/datamodules/ssl4eo.py b/torchgeo/datamodules/ssl4eo.py index 6ad558dcf87..f0b1ecdee46 100644 --- a/torchgeo/datamodules/ssl4eo.py +++ b/torchgeo/datamodules/ssl4eo.py @@ -45,7 +45,7 @@ class SSL4EOS12DataModule(NonGeoDataModule): .. versionadded:: 0.5 """ - # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 + # https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 mean = torch.tensor(0) std = torch.tensor(10000) diff --git a/torchgeo/datamodules/ssl4eo_benchmark.py b/torchgeo/datamodules/ssl4eo_benchmark.py index c9eb1d2e315..02e5de917dd 100644 --- a/torchgeo/datamodules/ssl4eo_benchmark.py +++ b/torchgeo/datamodules/ssl4eo_benchmark.py @@ -10,7 +10,6 @@ from ..datasets import SSL4EOLBenchmark from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -40,23 +39,26 @@ def __init__( self.patch_size = _to_tuple(patch_size) - self.train_aug = AugmentationSequential( + self.train_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), K.RandomVerticalFlip(p=0.5), K.RandomHorizontalFlip(p=0.5), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, extra_args={ DataKey.MASK: {'resample': Resample.NEAREST, 'align_corners': None} }, ) - self.val_aug = AugmentationSequential( + self.val_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) - self.test_aug = AugmentationSequential( + self.test_aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.CenterCrop(self.patch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/treesatai.py b/torchgeo/datamodules/treesatai.py new file mode 100644 index 00000000000..3db24b4724a --- /dev/null +++ b/torchgeo/datamodules/treesatai.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TreeSatAI datamodules.""" + +from typing import Any + +import kornia.augmentation as K +import torch +from torch import Tensor +from torch.utils.data import random_split + +from ..datasets import TreeSatAI +from ..samplers.utils import _to_tuple +from .geo import NonGeoDataModule + +# https://git.tu-berlin.de/rsim/treesat_benchmark/-/blob/master/configs/multimodal/AllModes_Xformer_ResnetScratch_v8.json +means = { + 'aerial': [ + 151.26809261440323, + 93.1159469148246, + 85.05016794624635, + 81.0471576353153, + ], + 's1': [-6.933713050794077, -12.628564056094067, 0.47448312147709354], + 's2': [ + 231.43385024546893, + 376.94788434611434, + 241.03688288984037, + 2809.8421354087955, + 616.5578221193639, + 2104.3826773960823, + 2695.083864757169, + 2969.868417923599, + 1306.0814241837832, + 587.0608264363341, + 249.1888624097736, + 2950.2294375352285, + ], +} +stds = { + 'aerial': [ + 48.70879149145466, + 33.59622314610158, + 28.000497087051126, + 33.683983599997724, + ], + 's1': [87.8762246957811, 47.03070478433704, 1.297291303623673], + 's2': [ + 123.16515044781909, + 139.78991338362886, + 140.6154081184225, + 786.4508872594147, + 202.51268536579394, + 530.7255451201194, + 710.2650071967689, + 777.4421400779165, + 424.30312334282684, + 247.21468849049668, + 122.80062680549261, + 702.7404237034002, + ], +} + + +class TreeSatAIDataModule(NonGeoDataModule): + """LightningDataModule implementation for the TreeSatAI dataset. + + .. versionadded:: 0.7 + """ + + def __init__( + self, + batch_size: int = 64, + patch_size: int | tuple[int, int] = 304, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new TreeSatAIDataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.TreeSatAI`. + """ + super().__init__(TreeSatAI, batch_size, num_workers, **kwargs) + + self.patch_size = _to_tuple(patch_size) + self.sensors = kwargs.get('sensors', TreeSatAI.all_sensors) + + self.train_aug = K.AugmentationSequential( + K.RandomVerticalFlip(p=0.5), + K.RandomHorizontalFlip(p=0.5), + K.Resize(self.patch_size), + data_keys=None, + keepdim=True, + ) + self.aug = K.AugmentationSequential( + K.Resize(self.patch_size), data_keys=None, keepdim=True + ) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + # Convert 90-10 train-test split to 80-10-10 train-val-test split + train_val_dataset = TreeSatAI(split='train', **self.kwargs) + self.test_dataset = TreeSatAI(split='test', **self.kwargs) + generator = torch.Generator().manual_seed(0) + self.train_dataset, self.val_dataset = random_split( + train_val_dataset, + [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], + generator=generator, + ) + + def on_after_batch_transfer( + self, batch: dict[str, Tensor], dataloader_idx: int + ) -> dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + batch = super().on_after_batch_transfer(batch, dataloader_idx) + + images = [] + for sensor in self.sensors: + aug = K.Normalize(mean=means[sensor], std=stds[sensor], keepdim=True) + batch[f'image_{sensor}'] = aug(batch[f'image_{sensor}']) + images.append(batch[f'image_{sensor}']) + + batch['image'] = torch.cat(images, dim=1) + + return batch diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 59bb49444ee..6bb3e70eab2 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -8,7 +8,6 @@ import kornia.augmentation as K from ..datasets import UCMerced -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -31,8 +30,9 @@ def __init__( """ super().__init__(UCMerced, batch_size, num_workers, **kwargs) - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), K.Resize(size=256), - data_keys=['image'], + data_keys=None, + keepdim=True, ) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 98bc4945e95..4fead8c85c8 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -11,7 +11,6 @@ from ..datasets import Vaihingen2D from ..samplers.utils import _to_tuple -from ..transforms import AugmentationSequential from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule @@ -48,10 +47,11 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.aug = AugmentationSequential( + self.aug = K.AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, batch_size), - data_keys=['image', 'mask'], + data_keys=None, + keepdim=True, ) def setup(self, stage: str) -> None: diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index e1daabf3ed9..f7e1a9267d1 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -11,13 +11,13 @@ from .benin_cashews import BeninSmallHolderCashews from .bigearthnet import BigEarthNet from .biomassters import BioMassters +from .cabuar import CaBuAr +from .caffe import CaFFe from .cbf import CanadianBuildingFootprints from .cdl import CDL from .chabud import ChaBuD from .chesapeake import ( Chesapeake, - Chesapeake7, - Chesapeake13, ChesapeakeCVPR, ChesapeakeDC, ChesapeakeDE, @@ -35,6 +35,7 @@ from .cyclone import TropicalCyclone from .deepglobelandcover import DeepGlobeLandCover from .dfc2022 import DFC2022 +from .digital_typhoon import DigitalTyphoon from .eddmaps import EDDMapS from .enviroatlas import EnviroAtlas from .errors import DatasetNotFoundError, DependencyNotFoundError, RGBBandsMissingError @@ -46,6 +47,7 @@ from .fair1m import FAIR1M from .fire_risk import FireRisk from .forestdamage import ForestDamage +from .ftw import FieldsOfTheWorld from .gbif import GBIF from .geo import ( GeoDataset, @@ -56,15 +58,17 @@ UnionDataset, VectorDataset, ) +from .geonrw import GeoNRW from .gid15 import GID15 from .globbiomass import GlobBiomass +from .hyspecnet import HySpecNet11k from .idtrees import IDTReeS from .inaturalist import INaturalist from .inria import InriaAerialImageLabeling from .iobench import IOBench from .l7irish import L7Irish from .l8biome import L8Biome -from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo +from .landcoverai import LandCoverAI, LandCoverAI100, LandCoverAIBase, LandCoverAIGeo from .landsat import ( Landsat, Landsat1, @@ -82,6 +86,7 @@ from .loveda import LoveDA from .mapinwild import MapInWild from .millionaid import MillionAID +from .mmearth import MMEarth from .naip import NAIP from .nasa_marine_debris import NASAMarineDebris from .nccm import NCCM @@ -96,11 +101,13 @@ from .reforestree import ReforesTree from .resisc45 import RESISC45 from .rwanda_field_boundary import RwandaFieldBoundary +from .satlas import SatlasPretrain from .seasonet import SeasoNet from .seco import SeasonalContrastS2 from .sen12ms import SEN12MS from .sentinel import Sentinel, Sentinel1, Sentinel2 from .skippd import SKIPPD +from .skyscript import SkyScript from .so2sat import So2Sat from .south_africa_crop_type import SouthAfricaCropType from .south_america_soybean import SouthAmericaSoybean @@ -113,6 +120,7 @@ SpaceNet5, SpaceNet6, SpaceNet7, + SpaceNet8, ) from .splits import ( random_bbox_assignment, @@ -124,6 +132,7 @@ from .ssl4eo import SSL4EO, SSL4EOL, SSL4EOS12 from .ssl4eo_benchmark import SSL4EOLBenchmark from .sustainbench_crop_yield import SustainBenchCropYield +from .treesatai import TreeSatAI from .ucmerced import UCMerced from .usavars import USAVars from .utils import ( @@ -140,16 +149,47 @@ from .zuericrop import ZueriCrop __all__ = ( - # GeoDataset + 'ADVANCE', + 'CDL', + 'COWC', + 'DFC2022', + 'ETCI2021', + 'EUDEM', + 'FAIR1M', + 'GBIF', + 'GID15', + 'LEVIRCD', + 'NAIP', + 'NCCM', + 'NLCD', + 'OSCD', + 'PASTIS', + 'PRISMA', + 'RESISC45', + 'SEN12MS', + 'SKIPPD', + 'SSL4EO', + 'SSL4EOL', + 'SSL4EOS12', + 'VHR10', 'AbovegroundLiveWoodyBiomassDensity', 'AgriFieldNet', 'Airphen', 'AsterGDEM', + 'BeninSmallHolderCashews', + 'BigEarthNet', + 'BioMassters', + 'BoundingBox', + 'CMSGlobalMangroveCanopy', + 'COWCCounting', + 'COWCDetection', + 'CV4AKenyaCropType', + 'CaBuAr', + 'CaFFe', 'CanadianBuildingFootprints', - 'CDL', + 'ChaBuD', 'Chesapeake', - 'Chesapeake7', - 'Chesapeake13', + 'ChesapeakeCVPR', 'ChesapeakeDC', 'ChesapeakeDE', 'ChesapeakeMD', @@ -157,19 +197,37 @@ 'ChesapeakePA', 'ChesapeakeVA', 'ChesapeakeWV', - 'ChesapeakeCVPR', - 'CMSGlobalMangroveCanopy', + 'CloudCoverDetection', 'CropHarvest', + 'DatasetNotFoundError', + 'DeepGlobeLandCover', + 'DependencyNotFoundError', + 'DigitalTyphoon', 'EDDMapS', + 'EnviroAtlas', 'Esri2020', 'EuroCrops', - 'EUDEM', - 'GBIF', + 'EuroSAT', + 'EuroSAT100', + 'EuroSATSpatial', + 'FieldsOfTheWorld', + 'FireRisk', + 'ForestDamage', + 'GeoDataset', + 'GeoNRW', 'GlobBiomass', + 'HySpecNet11k', + 'IDTReeS', 'INaturalist', 'IOBench', + 'InriaAerialImageLabeling', + 'IntersectionDataset', 'L7Irish', 'L8Biome', + 'LEVIRCDBase', + 'LEVIRCDPlus', + 'LandCoverAI', + 'LandCoverAI100', 'LandCoverAIBase', 'LandCoverAIGeo', 'Landsat', @@ -183,61 +241,32 @@ 'Landsat7', 'Landsat8', 'Landsat9', - 'NAIP', - 'NCCM', - 'NLCD', - 'OpenBuildings', - 'PRISMA', - 'Sentinel', - 'Sentinel1', - 'Sentinel2', - 'SouthAfricaCropType', - 'SouthAmericaSoybean', - # NonGeoDataset - 'ADVANCE', - 'BeninSmallHolderCashews', - 'BigEarthNet', - 'BioMassters', - 'ChaBuD', - 'CloudCoverDetection', - 'COWC', - 'COWCCounting', - 'COWCDetection', - 'CV4AKenyaCropType', - 'DeepGlobeLandCover', - 'DFC2022', - 'EnviroAtlas', - 'ETCI2021', - 'EuroSAT', - 'EuroSATSpatial', - 'EuroSAT100', - 'FAIR1M', - 'FireRisk', - 'ForestDamage', - 'GID15', - 'IDTReeS', - 'InriaAerialImageLabeling', - 'LandCoverAI', - 'LEVIRCD', - 'LEVIRCDBase', - 'LEVIRCDPlus', 'LoveDA', + 'MMEarth', 'MapInWild', 'MillionAID', 'NASAMarineDebris', - 'OSCD', - 'PASTIS', + 'NonGeoClassificationDataset', + 'NonGeoDataset', + 'OpenBuildings', 'PatternNet', 'Potsdam2D', 'QuakeSet', - 'RESISC45', + 'RGBBandsMissingError', + 'RasterDataset', 'ReforesTree', 'RwandaFieldBoundary', - 'SeasonalContrastS2', + 'SSL4EOLBenchmark', + 'SatlasPretrain', 'SeasoNet', - 'SEN12MS', - 'SKIPPD', + 'SeasonalContrastS2', + 'Sentinel', + 'Sentinel1', + 'Sentinel2', + 'SkyScript', 'So2Sat', + 'SouthAfricaCropType', + 'SouthAmericaSoybean', 'SpaceNet', 'SpaceNet1', 'SpaceNet2', @@ -246,42 +275,26 @@ 'SpaceNet5', 'SpaceNet6', 'SpaceNet7', - 'SSL4EO', - 'SSL4EOLBenchmark', - 'SSL4EOL', - 'SSL4EOS12', + 'SpaceNet8', 'SustainBenchCropYield', + 'TreeSatAI', 'TropicalCyclone', 'UCMerced', 'USAVars', + 'UnionDataset', 'Vaihingen2D', - 'VHR10', + 'VectorDataset', 'WesternUSALiveFuelMoisture', 'XView2', 'XView2DistShift' 'ZueriCrop', - # Base classes - 'GeoDataset', - 'IntersectionDataset', - 'NonGeoClassificationDataset', - 'NonGeoDataset', - 'RasterDataset', - 'UnionDataset', - 'VectorDataset', - # Utilities - 'BoundingBox', 'concat_samples', 'merge_samples', - 'stack_samples', - 'unbind_samples', - # Splits 'random_bbox_assignment', 'random_bbox_splitting', 'random_grid_cell_assignment', 'roi_split', + 'stack_samples', 'time_series_split', - # Errors - 'DatasetNotFoundError', - 'DependencyNotFoundError', - 'RGBBandsMissingError', + 'unbind_samples', ) diff --git a/torchgeo/datasets/advance.py b/torchgeo/datasets/advance.py index be0e5996b17..c9fcea22a01 100644 --- a/torchgeo/datasets/advance.py +++ b/torchgeo/datasets/advance.py @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive, lazy_import +from .utils import Path, download_and_extract_archive, lazy_import class ADVANCE(NonGeoDataset): @@ -63,14 +63,14 @@ class ADVANCE(NonGeoDataset): * `scipy `_ to load the audio files to tensors """ - urls = [ - 'https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1', - 'https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1', - ] - filenames = ['ADVANCE_vision.zip', 'ADVANCE_sound.zip'] - md5s = ['a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31'] - directories = ['vision', 'sound'] - classes = [ + urls = ( + 'https://zenodo.org/records/3828124/files/ADVANCE_vision.zip?download=1', + 'https://zenodo.org/records/3828124/files/ADVANCE_sound.zip?download=1', + ) + filenames = ('ADVANCE_vision.zip', 'ADVANCE_sound.zip') + md5s = ('a9e8748219ef5864d3b5a8979a67b471', 'a2d12f2d2a64f5c3d3a9d8c09aaf1c31') + directories = ('vision', 'sound') + classes: tuple[str, ...] = ( 'airport', 'beach', 'bridge', @@ -84,11 +84,11 @@ class ADVANCE(NonGeoDataset): 'sparse shrub land', 'sports land', 'train station', - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -119,7 +119,7 @@ def __init__( raise DatasetNotFoundError(self) self.files = self._load_files(self.root) - self.classes = sorted({f['cls'] for f in self.files}) + self.classes = tuple(sorted({f['cls'] for f in self.files})) self.class_to_idx: dict[str, int] = {c: i for i, c in enumerate(self.classes)} def __getitem__(self, index: int) -> dict[str, Tensor]: @@ -151,7 +151,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -169,7 +169,7 @@ def _load_files(self, root: str) -> list[dict[str, str]]: ] return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -185,7 +185,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target audio for a single image. Args: diff --git a/torchgeo/datasets/agb_live_woody_density.py b/torchgeo/datasets/agb_live_woody_density.py index e9a8ac844b9..aaef8db9751 100644 --- a/torchgeo/datasets/agb_live_woody_density.py +++ b/torchgeo/datasets/agb_live_woody_density.py @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import download_url +from .utils import Path, download_url class AbovegroundLiveWoodyBiomassDensity(RasterDataset): @@ -45,7 +45,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): is_image = False - url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' # noqa: E501 + url = 'https://opendata.arcgis.com/api/v3/datasets/e4bdbe8d6d8d4e32ace7d36a4aec7b93_0/downloads/data?format=geojson&spatialRefId=4326' base_filename = 'Aboveground_Live_Woody_Biomass_Density.geojson' @@ -57,7 +57,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -105,7 +105,7 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) download_url(self.url, self.paths, self.base_filename) with open(os.path.join(self.paths, self.base_filename)) as f: diff --git a/torchgeo/datasets/agrifieldnet.py b/torchgeo/datasets/agrifieldnet.py index 8be13a170cf..3624c1e193e 100644 --- a/torchgeo/datasets/agrifieldnet.py +++ b/torchgeo/datasets/agrifieldnet.py @@ -6,7 +6,7 @@ import os import re from collections.abc import Callable, Iterable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import torch @@ -14,9 +14,9 @@ from rasterio.crs import CRS from torch import Tensor -from .errors import RGBBandsMissingError +from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import RasterDataset -from .utils import BoundingBox +from .utils import BoundingBox, Path, which class AgriFieldNet(RasterDataset): @@ -51,28 +51,37 @@ class AgriFieldNet(RasterDataset): Dataset classes: - 0 - No-Data - 1 - Wheat - 2 - Mustard - 3 - Lentil - 4 - No Crop/Fallow - 5 - Green pea - 6 - Sugarcane - 8 - Garlic - 9 - Maize - 13 - Gram - 14 - Coriander - 15 - Potato - 16 - Berseem - 36 - Rice + * 0. No-Data + * 1. Wheat + * 2. Mustard + * 3. Lentil + * 4. No Crop/Fallow + * 5. Green pea + * 6. Sugarcane + * 8. Garlic + * 9. Maize + * 13. Gram + * 14. Coriander + * 15. Potato + * 16. Berseem + * 36. Rice If you use this dataset in your research, please cite the following dataset: * https://doi.org/10.34911/rdnt.wu92p1 + .. note:: + + This dataset requires the following additional library to be installed: + + * `azcopy `_: to download the + dataset from Source Cooperative. + .. versionadded:: 0.6 """ + url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_agrifieldnet_competition_v1' + filename_glob = 'ref_agrifieldnet_competition_v1_source_*_{}_10m.*' filename_regex = r""" ^ref_agrifieldnet_competition_v1_source_ @@ -80,8 +89,8 @@ class AgriFieldNet(RasterDataset): _(?PB[0-9A-Z]{2})_10m """ - rgb_bands = ['B04', 'B03', 'B02'] - all_bands = [ + rgb_bands = ('B04', 'B03', 'B02') + all_bands = ( 'B01', 'B02', 'B03', @@ -94,9 +103,9 @@ class AgriFieldNet(RasterDataset): 'B09', 'B11', 'B12', - ] + ) - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 255), 1: (255, 211, 0, 255), 2: (255, 37, 37, 255), @@ -115,12 +124,13 @@ class AgriFieldNet(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, classes: list[int] = list(cmap.keys()), bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, cache: bool = True, + download: bool = False, ) -> None: """Initialize a new AgriFieldNet dataset instance. @@ -134,9 +144,10 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version cache: if True, cache the dataset in memory + download: if True, download dataset and store it in the root directory Raises: - DatasetNotFoundError: If dataset is not found. + DatasetNotFoundError: If dataset is not found and *download* is False. """ assert ( set(classes) <= self.cmap.keys() @@ -144,17 +155,19 @@ def __init__( assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths - self.classes = classes - self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype) - self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8) + self.download = download self.filename_glob = self.filename_glob.format(bands[0]) + self._verify() + super().__init__( paths=paths, crs=crs, bands=bands, transforms=transforms, cache=cache ) # Map chosen classes to ordinal numbers, all others mapped to background class - for v, k in enumerate(self.classes): + self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype) + self.ordinal_cmap = torch.zeros((len(classes), 4), dtype=torch.uint8) + for v, k in enumerate(classes): self.ordinal_map[k] = v self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) @@ -167,7 +180,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: Returns: data, label, and field ids at that index """ - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) hits = self.index.intersection(tuple(query), objects=True) filepaths = cast(list[str], [hit.object for hit in hits]) @@ -207,7 +220,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample = { 'crs': self.crs, - 'bbox': query, + 'bounds': query, 'image': image.float(), 'mask': mask.long(), } @@ -217,6 +230,26 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if self.files: + return + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + assert isinstance(self.paths, str | os.PathLike) + os.makedirs(self.paths, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', f'{self.url}', self.paths, '--recursive=true') + def plot( self, sample: dict[str, Tensor], diff --git a/torchgeo/datasets/airphen.py b/torchgeo/datasets/airphen.py index 3b0caf607ed..12b8c38141c 100644 --- a/torchgeo/datasets/airphen.py +++ b/torchgeo/datasets/airphen.py @@ -40,8 +40,8 @@ class Airphen(RasterDataset): # Each camera measures a custom set of spectral bands chosen at purchase time. # Hiphen offers 8 bands to choose from, sorted from short to long wavelength. - all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8'] - rgb_bands = ['B4', 'B3', 'B1'] + all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8') + rgb_bands = ('B4', 'B3', 'B1') def plot( self, diff --git a/torchgeo/datasets/astergdem.py b/torchgeo/datasets/astergdem.py index 479d0f79ce8..c4ef23061b8 100644 --- a/torchgeo/datasets/astergdem.py +++ b/torchgeo/datasets/astergdem.py @@ -12,6 +12,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset +from .utils import Path class AsterGDEM(RasterDataset): @@ -47,7 +48,7 @@ class AsterGDEM(RasterDataset): def __init__( self, - paths: str | list[str] = 'data', + paths: Path | list[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, diff --git a/torchgeo/datasets/benin_cashews.py b/torchgeo/datasets/benin_cashews.py index 6682f42b7b7..4dd1ae927de 100644 --- a/torchgeo/datasets/benin_cashews.py +++ b/torchgeo/datasets/benin_cashews.py @@ -5,7 +5,7 @@ import json import os -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import lru_cache import matplotlib.pyplot as plt @@ -19,10 +19,9 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import Path, which -# TODO: read geospatial information from stac.json files class BeninSmallHolderCashews(NonGeoDataset): r"""Smallholder Cashew Plantations in Benin dataset. @@ -30,8 +29,8 @@ class BeninSmallHolderCashews(NonGeoDataset): in the center of Benin. Each pixel is classified for Well-managed plantation, Poorly-managed plantation, No plantation and other classes. The labels are generated using a combination of ground data collection with a handheld GPS device, - and final corrections based on Airbus Pléiades imagery. See - `this website `__ for dataset details. + and final corrections based on Airbus Pléiades imagery. See `this website + `__ for dataset details. Specifically, the data consists of Sentinel 2 imagery from a 120 km\ :sup:`2`\ area in the center of Benin over 71 points in time from 11/05/2019 to 10/30/2020 @@ -47,97 +46,88 @@ class BeninSmallHolderCashews(NonGeoDataset): If you use this dataset in your research, please cite the following: - * https://doi.org/10.34911/rdnt.hfv20i + * https://beta.source.coop/technoserve/cashews-benin/ .. note:: This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. """ - dataset_id = 'ts_cashew_benin' - collection_ids = ['ts_cashew_benin_source', 'ts_cashew_benin_labels'] - image_meta = { - 'filename': 'ts_cashew_benin_source.tar.gz', - 'md5': '957272c86e518a925a4e0d90dab4f92d', - } - target_meta = { - 'filename': 'ts_cashew_benin_labels.tar.gz', - 'md5': 'f9d3f0c671427d852fae9b52a0ae0051', - } + url = 'https://radiantearth.blob.core.windows.net/mlhub/technoserve-cashew-benin' dates = ( - '2019_11_05', - '2019_11_10', - '2019_11_15', - '2019_11_20', - '2019_11_30', - '2019_12_05', - '2019_12_10', - '2019_12_15', - '2019_12_20', - '2019_12_25', - '2019_12_30', - '2020_01_04', - '2020_01_09', - '2020_01_14', - '2020_01_19', - '2020_01_24', - '2020_01_29', - '2020_02_08', - '2020_02_13', - '2020_02_18', - '2020_02_23', - '2020_02_28', - '2020_03_04', - '2020_03_09', - '2020_03_14', - '2020_03_19', - '2020_03_24', - '2020_03_29', - '2020_04_03', - '2020_04_08', - '2020_04_13', - '2020_04_18', - '2020_04_23', - '2020_04_28', - '2020_05_03', - '2020_05_08', - '2020_05_13', - '2020_05_18', - '2020_05_23', - '2020_05_28', - '2020_06_02', - '2020_06_07', - '2020_06_12', - '2020_06_17', - '2020_06_22', - '2020_06_27', - '2020_07_02', - '2020_07_07', - '2020_07_12', - '2020_07_17', - '2020_07_22', - '2020_07_27', - '2020_08_01', - '2020_08_06', - '2020_08_11', - '2020_08_16', - '2020_08_21', - '2020_08_26', - '2020_08_31', - '2020_09_05', - '2020_09_10', - '2020_09_15', - '2020_09_20', - '2020_09_25', - '2020_09_30', - '2020_10_10', - '2020_10_15', - '2020_10_20', - '2020_10_25', - '2020_10_30', + '20191105', + '20191110', + '20191115', + '20191120', + '20191130', + '20191205', + '20191210', + '20191215', + '20191220', + '20191225', + '20191230', + '20200104', + '20200109', + '20200114', + '20200119', + '20200124', + '20200129', + '20200208', + '20200213', + '20200218', + '20200223', + '20200228', + '20200304', + '20200309', + '20200314', + '20200319', + '20200324', + '20200329', + '20200403', + '20200408', + '20200413', + '20200418', + '20200423', + '20200428', + '20200503', + '20200508', + '20200513', + '20200518', + '20200523', + '20200528', + '20200602', + '20200607', + '20200612', + '20200617', + '20200622', + '20200627', + '20200702', + '20200707', + '20200712', + '20200717', + '20200722', + '20200727', + '20200801', + '20200806', + '20200811', + '20200816', + '20200821', + '20200826', + '20200831', + '20200905', + '20200910', + '20200915', + '20200920', + '20200925', + '20200930', + '20201010', + '20201015', + '20201020', + '20201025', + '20201030', ) all_bands = ( @@ -157,7 +147,7 @@ class BeninSmallHolderCashews(NonGeoDataset): ) rgb_bands = ('B04', 'B03', 'B02') - classes = [ + classes = ( 'No data', 'Well-managed planatation', 'Poorly-managed planatation', @@ -165,7 +155,7 @@ class BeninSmallHolderCashews(NonGeoDataset): 'Residential', 'Background', 'Uncertain', - ] + ) # Same for all tiles tile_height = 1186 @@ -173,15 +163,12 @@ class BeninSmallHolderCashews(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', chip_size: int = 256, stride: int = 128, - bands: tuple[str, ...] = all_bands, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, - verbose: bool = False, ) -> None: """Initialize a new Benin Smallholder Cashew Plantations Dataset instance. @@ -194,36 +181,31 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - verbose: if True, print messages when new tiles are loaded Raises: + AssertionError: If *bands* is invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ - self._validate_bands(bands) + assert set(bands) <= set(self.all_bands) self.root = root self.chip_size = chip_size self.stride = stride self.bands = bands self.transforms = transforms - self.checksum = checksum - self.verbose = verbose - - if download: - self._download(api_key) + self.download = download - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self._verify() # Calculate the indices that we will use over all tiles self.chips_metadata = [] - for y in list(range(0, self.tile_height - self.chip_size, stride)) + [ - self.tile_height - self.chip_size + for y in [ + *list(range(0, self.tile_height - self.chip_size, stride)), + self.tile_height - self.chip_size, ]: - for x in list(range(0, self.tile_width - self.chip_size, stride)) + [ - self.tile_width - self.chip_size + for x in [ + *list(range(0, self.tile_width - self.chip_size, stride)), + self.tile_width - self.chip_size, ]: self.chips_metadata.append((y, x)) @@ -238,7 +220,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: """ y, x = self.chips_metadata[index] - img, transform, crs = self._load_all_imagery(self.bands) + img, transform, crs = self._load_all_imagery() labels = self._load_mask(transform) img = img[:, :, y : y + self.chip_size, x : x + self.chip_size] @@ -266,92 +248,55 @@ def __len__(self) -> int: """ return len(self.chips_metadata) - def _validate_bands(self, bands: tuple[str, ...]) -> None: - """Validate list of bands. - - Args: - bands: user-provided tuple of bands to load - - Raises: - AssertionError: if ``bands`` is not a tuple - ValueError: if an invalid band name is provided - """ - assert isinstance(bands, tuple), 'The list of bands must be a tuple' - for band in bands: - if band not in self.all_bands: - raise ValueError(f"'{band}' is an invalid band name.") - @lru_cache(maxsize=128) - def _load_all_imagery( - self, bands: tuple[str, ...] = all_bands - ) -> tuple[Tensor, rasterio.Affine, CRS]: + def _load_all_imagery(self) -> tuple[Tensor, rasterio.Affine, CRS]: """Load all the imagery (across time) for the dataset. - Optionally allows for subsetting of the bands that are loaded. - - Args: - bands: tuple of bands to load - Returns: imagery of shape (70, number of bands, 1186, 1122) where 70 is the number of points in time, 1186 is the tile height, and 1122 is the tile width rasterio affine transform, mapping pixel coordinates to geo coordinates coordinate reference system of transform """ - if self.verbose: - print('Loading all imagery') - img = torch.zeros( len(self.dates), - len(bands), + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32, ) for date_index, date in enumerate(self.dates): - single_scene, transform, crs = self._load_single_scene(date, self.bands) + single_scene, transform, crs = self._load_single_scene(date) img[date_index] = single_scene return img, transform, crs @lru_cache(maxsize=128) - def _load_single_scene( - self, date: str, bands: tuple[str, ...] - ) -> tuple[Tensor, rasterio.Affine, CRS]: + def _load_single_scene(self, date: str) -> tuple[Tensor, rasterio.Affine, CRS]: """Load the imagery for a single date. - Optionally allows for subsetting of the bands that are loaded. - Args: date: date of the imagery to load - bands: bands to load Returns: Tensor containing a single image tile, rasterio affine transform, mapping pixel coordinates to geo coordinates, and coordinate reference system of transform. - - Raises: - AssertionError: if ``date`` is invalid """ - assert date in self.dates - - if self.verbose: - print(f'Loading imagery at {date}') - img = torch.zeros( - len(bands), self.tile_height, self.tile_width, dtype=torch.float32 + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32 ) for band_index, band_name in enumerate(self.bands): filepath = os.path.join( self.root, - 'ts_cashew_benin_source', - f'ts_cashew_benin_source_00_{date}', - f'{band_name}.tif', + 'imagery', + '00', + f'00_{date}', + f'00_{date}_{band_name}_10m.tif', ) with rasterio.open(filepath) as src: - transform = src.transform # same transform for every bands + transform = src.transform # same transform for every band crs = src.crs array = src.read().astype(np.float32) img[band_index] = torch.from_numpy(array) @@ -362,10 +307,7 @@ def _load_single_scene( def _load_mask(self, transform: rasterio.Affine) -> Tensor: """Rasterizes the dataset's labels (in geojson format).""" # Create a mask layer out of the geojson - mask_geojson_fn = os.path.join( - self.root, 'ts_cashew_benin_labels', '_common', 'labels.geojson' - ) - with open(mask_geojson_fn) as f: + with open(os.path.join(self.root, 'labels', '00.geojson')) as f: geojson = json.load(f) labels = [ @@ -385,44 +327,24 @@ def _load_mask(self, transform: rasterio.Affine) -> Tensor: mask = torch.from_numpy(mask_data).long() return mask - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - images: bool = check_integrity( - os.path.join(self.root, self.image_meta['filename']), - self.image_meta['md5'] if self.checksum else None, - ) - - targets: bool = check_integrity( - os.path.join(self.root, self.target_meta['filename']), - self.target_meta['md5'] if self.checksum else None, - ) - - return images and targets - - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. - - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - RuntimeError: if download doesn't work correctly or checksums don't match - """ - if self._check_integrity(): - print('Files already downloaded and verified') + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if os.path.exists(os.path.join(self.root, 'labels', '00.geojson')): return - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() - image_archive_path = os.path.join(self.root, self.image_meta['filename']) - target_archive_path = os.path.join(self.root, self.target_meta['filename']) - for fn in [image_archive_path, target_archive_path]: - extract_archive(fn, self.root) + def _download(self) -> None: + """Download the dataset.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') def plot( self, @@ -454,9 +376,6 @@ def plot( else: raise RGBBandsMissingError() - num_time_points = sample['image'].shape[0] - assert time_step < num_time_points - image = np.rollaxis(sample['image'][time_step, rgb_indices].numpy(), 0, 3) image = np.clip(image / 3000, 0, 1) mask = sample['mask'].numpy() diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 075af089785..38669cd6ff1 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -7,6 +7,7 @@ import json import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -18,7 +19,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, sort_sentinel2_bands +from .utils import Path, download_url, extract_archive, sort_sentinel2_bands class BigEarthNet(NonGeoDataset): @@ -124,9 +125,9 @@ class BigEarthNet(NonGeoDataset): * https://doi.org/10.1109/IGARSS.2019.8900532 - """ # noqa: E501 + """ - class_sets = { + class_sets: ClassVar[dict[int, list[str]]] = { 19: [ 'Urban fabric', 'Industrial or commercial units', @@ -197,7 +198,7 @@ class BigEarthNet(NonGeoDataset): ], } - label_converter = { + label_converter: ClassVar[dict[int, int]] = { 0: 0, 1: 0, 2: 1, @@ -232,32 +233,32 @@ class BigEarthNet(NonGeoDataset): 42: 18, } - splits_metadata = { + splits_metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { - 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false', # noqa: E501 + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/train.csv?inline=false', 'filename': 'bigearthnet-train.csv', 'md5': '623e501b38ab7b12fe44f0083c00986d', }, 'val': { - 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false', # noqa: E501 + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/val.csv?inline=false', 'filename': 'bigearthnet-val.csv', 'md5': '22efe8ed9cbd71fa10742ff7df2b7978', }, 'test': { - 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false', # noqa: E501 + 'url': 'https://git.tu-berlin.de/rsim/BigEarthNet-MM_19-classes_models/-/raw/9a5be07346ab0884b2d9517475c27ef9db9b5104/splits/test.csv?inline=false', 'filename': 'bigearthnet-test.csv', 'md5': '697fb90677e30571b9ac7699b7e5b432', }, } - metadata = { + metadata: ClassVar[dict[str, dict[str, str]]] = { 's1': { - 'url': 'https://bigearth.net/downloads/BigEarthNet-S1-v1.0.tar.gz', + 'url': 'https://zenodo.org/records/12687186/files/BigEarthNet-S1-v1.0.tar.gz', 'md5': '94ced73440dea8c7b9645ee738c5a172', 'filename': 'BigEarthNet-S1-v1.0.tar.gz', 'directory': 'BigEarthNet-S1-v1.0', }, 's2': { - 'url': 'https://bigearth.net/downloads/BigEarthNet-S2-v1.0.tar.gz', + 'url': 'https://zenodo.org/records/12687186/files/BigEarthNet-S2-v1.0.tar.gz', 'md5': '5a64e9ce38deb036a435a7b59494924c', 'filename': 'BigEarthNet-S2-v1.0.tar.gz', 'directory': 'BigEarthNet-v1.0', @@ -267,7 +268,7 @@ class BigEarthNet(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: str = 'all', num_classes: int = 19, @@ -486,7 +487,7 @@ def _verify(self) -> None: filepath = os.path.join(self.root, filename) self._extract(filepath) - def _download(self, url: str, filename: str, md5: str) -> None: + def _download(self, url: str, filename: Path, md5: str) -> None: """Download the dataset. Args: @@ -499,13 +500,13 @@ def _download(self, url: str, filename: str, md5: str) -> None: url, self.root, filename=filename, md5=md5 if self.checksum else None ) - def _extract(self, filepath: str) -> None: + def _extract(self, filepath: Path) -> None: """Extract the dataset. Args: filepath: path to file to be extracted """ - if not filepath.endswith('.csv'): + if not str(filepath).endswith('.csv'): extract_archive(filepath) def _onehot_labels_to_names( diff --git a/torchgeo/datasets/biomassters.py b/torchgeo/datasets/biomassters.py index bb975c8002b..70a53a4220a 100644 --- a/torchgeo/datasets/biomassters.py +++ b/torchgeo/datasets/biomassters.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import percentile_normalization +from .utils import Path, percentile_normalization class BioMassters(NonGeoDataset): @@ -40,7 +40,7 @@ class BioMassters(NonGeoDataset): * Sentinel 1 and Sentinel 2 data for each location * Sentinel 1 available for every month * Sentinel 2 available for almost every month - (not available for every month due to ESA aquisition halt over the region + (not available for every month due to ESA acquisition halt over the region during particular periods) If you use this dataset in your research, please cite the following paper: @@ -50,14 +50,14 @@ class BioMassters(NonGeoDataset): .. versionadded:: 0.5 """ - valid_splits = ['train', 'test'] + valid_splits = ('train', 'test') valid_sensors = ('S1', 'S2') metadata_filename = 'The_BioMassters_-_features_metadata.csv.csv' def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', sensors: Sequence[str] = ['S1', 'S2'], as_time_series: bool = False, @@ -167,7 +167,7 @@ def __len__(self) -> int: """ return len(self.df['num_index'].unique()) - def _load_input(self, filenames: list[str]) -> Tensor: + def _load_input(self, filenames: list[Path]) -> Tensor: """Load the input imagery at the index. Args: @@ -186,7 +186,7 @@ def _load_input(self, filenames: list[str]) -> Tensor: arr = np.concatenate(arr_list, axis=0) return torch.tensor(arr.astype(np.int32)) - def _load_target(self, filename: str) -> Tensor: + def _load_target(self, filename: Path) -> Tensor: """Load the target mask at the index. Args: @@ -196,7 +196,7 @@ def _load_target(self, filename: str) -> Tensor: target mask """ with rasterio.open(os.path.join(self.root, 'train_agbm', filename), 'r') as src: - arr: np.typing.NDArray[np.float_] = src.read() + arr: np.typing.NDArray[np.float64] = src.read() target = torch.from_numpy(arr).float() return target diff --git a/torchgeo/datasets/cabuar.py b/torchgeo/datasets/cabuar.py new file mode 100644 index 00000000000..69ca818a70e --- /dev/null +++ b/torchgeo/datasets/cabuar.py @@ -0,0 +1,303 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""CaBuAr dataset.""" + +import os +from collections.abc import Callable +from typing import ClassVar + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_url, lazy_import, percentile_normalization + + +class CaBuAr(NonGeoDataset): + """CaBuAr dataset. + + `CaBuAr `__ + is a dataset for Change detection for Burned area Delineation and part of + the splits are used for the ChaBuD ECML-PKDD 2023 Discovery Challenge. + + Dataset features: + + * Sentinel-2 multispectral imagery + * binary masks of burned areas + * 12 multispectral bands + * 424 pairs of pre and post images with 20 m per pixel resolution (512x512 px) + + Dataset format: + + * single hdf5 dataset containing images and masks + + Dataset classes: + + 0. no change + 1. burned area + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.1109/MGRS.2023.3292467 + + .. note:: + + This dataset requires the following additional library to be installed: + + * `h5py `_ to load the dataset + + .. versionadded:: 0.6 + """ + + all_bands = ( + 'B01', + 'B02', + 'B03', + 'B04', + 'B05', + 'B06', + 'B07', + 'B08', + 'B8A', + 'B09', + 'B11', + 'B12', + ) + rgb_bands = ('B04', 'B03', 'B02') + folds: ClassVar[dict[str, list[object]]] = { + 'train': [1, 2, 3, 4], + 'val': [0], + 'test': ['chabud'], + } + urls = ( + 'https://huggingface.co/datasets/DarthReca/california_burned_areas/resolve/main/raw/patched/512x512.hdf5', + 'https://huggingface.co/datasets/DarthReca/california_burned_areas/resolve/main/raw/patched/chabud_test.h5', + ) + filenames = ('512x512.hdf5', 'chabud_test.h5') + md5s = ('15d78fb825f9a81dad600db828d22c08', 'a70bb7e4a2788657c2354c4c3d9296fe') + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + bands: tuple[str, ...] = all_bands, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new CaBuAr dataset instance. + + Args: + root: root directory where dataset can be found + split: one of "train", "val", "test" + bands: the subset of bands to load + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If ``split`` or ``bands`` arguments are invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + DependencyNotFoundError: If h5py is not installed. + """ + lazy_import('h5py') + + assert split in self.folds + assert set(bands) <= set(self.all_bands) + + # Set the file index based on the split + file_index = 1 if split == 'test' else 0 + + self.root = root + self.split = split + self.bands = bands + self.transforms = transforms + self.download = download + self.checksum = checksum + self.filepath = os.path.join(root, self.filenames[file_index]) + self.band_indices = [self.all_bands.index(b) for b in bands] + + self._verify() + + self.uuids = self._load_uuids() + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + sample containing image and mask + """ + image = self._load_image(index) + mask = self._load_target(index) + + sample = {'image': image, 'mask': mask} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.uuids) + + def _load_uuids(self) -> list[str]: + """Return the image uuids for the given split. + + Returns: + the image uuids + """ + h5py = lazy_import('h5py') + uuids = [] + with h5py.File(self.filepath, 'r') as f: + for k, v in f.items(): + if v.attrs['fold'] in self.folds[self.split] and 'pre_fire' in v.keys(): + uuids.append(k) + return sorted(uuids) + + def _load_image(self, index: int) -> Tensor: + """Load a single image. + + Args: + index: index to return + + Returns: + the image + """ + h5py = lazy_import('h5py') + uuid = self.uuids[index] + with h5py.File(self.filepath, 'r') as f: + pre_array = f[uuid]['pre_fire'][:] + post_array = f[uuid]['post_fire'][:] + + # index specified bands and concatenate + pre_array = pre_array[..., self.band_indices] + post_array = post_array[..., self.band_indices] + array = np.concatenate([pre_array, post_array], axis=-1).astype(np.float32) + + tensor = torch.from_numpy(array) + # Convert from HxWxC to CxHxW + tensor = tensor.permute((2, 0, 1)) + return tensor + + def _load_target(self, index: int) -> Tensor: + """Load the target mask for a single image. + + Args: + index: index to return + + Returns: + the target mask + """ + h5py = lazy_import('h5py') + uuid = self.uuids[index] + with h5py.File(self.filepath, 'r') as f: + array = f[uuid]['mask'][:].astype(np.int32).squeeze(axis=-1) + + tensor = torch.from_numpy(array) + tensor = tensor.to(torch.long) + return tensor + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + exists = [] + for filename in self.filenames: + filepath = os.path.join(self.root, filename) + exists.append(os.path.exists(filepath)) + + if all(exists): + return + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + for url, filename, md5 in zip(self.urls, self.filenames, self.md5s): + filepath = os.path.join(self.root, filename) + if not os.path.exists(filepath): + download_url( + url, + self.root, + filename=filename, + md5=md5 if self.checksum else None, + ) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + rgb_indices = [] + for band in self.rgb_bands: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise ValueError("Dataset doesn't contain some of the RGB bands") + + mask = sample['mask'].numpy() + image_pre = sample['image'][: len(self.bands)][rgb_indices].numpy() + image_post = sample['image'][len(self.bands) :][rgb_indices].numpy() + image_pre = percentile_normalization(image_pre) + image_post = percentile_normalization(image_post) + + ncols = 3 + + showing_predictions = 'prediction' in sample + if showing_predictions: + prediction = sample['prediction'] + ncols += 1 + + fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) + + axs[0].imshow(np.transpose(image_pre, (1, 2, 0))) + axs[0].axis('off') + axs[1].imshow(np.transpose(image_post, (1, 2, 0))) + axs[1].axis('off') + axs[2].imshow(mask) + axs[2].axis('off') + + if showing_predictions: + axs[3].imshow(prediction) + axs[3].axis('off') + + if show_titles: + axs[0].set_title('Image Pre') + axs[1].set_title('Image Post') + axs[2].set_title('Mask') + if showing_predictions: + axs[3].set_title('Prediction') + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/caffe.py b/torchgeo/datasets/caffe.py new file mode 100644 index 00000000000..84ce4ac55d7 --- /dev/null +++ b/torchgeo/datasets/caffe.py @@ -0,0 +1,297 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""CaFFe dataset.""" + +import glob +import os +import textwrap +from collections.abc import Callable +from typing import ClassVar + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.colors import ListedColormap +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_and_extract_archive, extract_archive + + +class CaFFe(NonGeoDataset): + """CaFFe (CAlving Fronts and where to Find thEm) dataset. + + The `CaFFe `__ dataset is a + semantic segmentation dataset of marine-terminating glaciers. + + Dataset features: + + * 13,090 train, 2,241 validation, and 3,761 test images + * varying spatial resolution of 6-20m + * paired binary calving front segmentation masks + * paired multi-class land cover segmentation masks + + Dataset format: + + * images are single-channel pngs with dimension 512x512 + * segmentation masks are single-channel pngs + + Dataset classes: + + 0. N/A + 1. rock + 2. glacier + 3. ocean/ice melange + + If you use this dataset in your research, please cite the following paper: + + * https://essd.copernicus.org/articles/14/4287/2022/ + + .. versionadded:: 0.7 + """ + + valid_splits = ('train', 'val', 'test') + + zipfilename = 'caffe.zip' + + data_dir = 'caffe' + + image_dir = 'sar_images' + + mask_dirs = ('fronts', 'zones') + + url = 'https://huggingface.co/datasets/torchgeo/caffe/resolve/cc96e8418981ce0f03afc9beace6422fdd7142c4/caffe.zip' + + md5 = '9a92fd6f05af74fbc41602595a55df0d' + + px_class_values_zones: ClassVar[dict[int, str]] = { + 0: 'N/A', + 64: 'rock', + 127: 'glacier', + 254: 'ocean/ice melange', + } + + zone_class_colors = ('black', 'brown', 'lightgray', 'blue') + zone_cmap = ListedColormap(zone_class_colors) + + px_class_values_fronts: ClassVar[dict[int, str]] = {0: 'no front', 255: 'front'} + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new instance of CaFFe dataset. + + Args: + root: root directory where dataset can be found + split: one of "train", "val", or "test" + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: if ``split`` argument is invalid + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert split in self.valid_splits, f'split must be one of {self.valid_splits}' + + self.root = root + self.split = split + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + self.fpaths = glob.glob( + os.path.join( + self.root, + self.zipfilename.replace('.zip', ''), + self.mask_dirs[1], + self.split, + '*.png', + ) + ) + + self.ordinal_map_zones = torch.zeros( + max(self.px_class_values_zones.keys()) + 1, dtype=torch.long + ) + for ordinal, px_class in enumerate(self.px_class_values_zones.keys()): + self.ordinal_map_zones[px_class] = ordinal + + self.ordinal_map_fronts = torch.zeros( + max(self.px_class_values_fronts.keys()) + 1, dtype=torch.long + ) + for ordinal, px_class in enumerate(self.px_class_values_fronts.keys()): + self.ordinal_map_fronts[px_class] = ordinal + + def __len__(self) -> int: + """Return the number of images in the dataset.""" + return len(self.fpaths) + + def __getitem__(self, idx: int) -> dict[str, Tensor]: + """Return the image and mask at the given index. + + Args: + idx: index of the image and mask to return + + Returns: + dict: a dict containing the image and mask + """ + zones_filename = os.path.basename(self.fpaths[idx]) + img_filename = zones_filename.replace('_zones_', '_') + front_filename = zones_filename.replace('_zones_', '_front_') + + def read_tensor(path: str) -> Tensor: + return torch.from_numpy(np.array(Image.open(path))) + + img_path = os.path.join( + self.root, self.data_dir, self.image_dir, self.split, img_filename + ) + img = read_tensor(img_path).unsqueeze(0).float() + + front_mask = read_tensor( + os.path.join( + self.root, self.data_dir, self.mask_dirs[0], self.split, front_filename + ) + ).long() + + zone_mask = read_tensor( + os.path.join( + self.root, self.data_dir, self.mask_dirs[1], self.split, zones_filename + ) + ).long() + + zone_mask = self.ordinal_map_zones[zone_mask] + front_mask = self.ordinal_map_fronts[front_mask] + + sample = {'image': img, 'mask_front': front_mask, 'mask_zones': zone_mask} + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + exists = [] + if os.path.exists( + os.path.join( + self.root, + self.zipfilename.replace('.zip', ''), + self.image_dir, + self.split, + ) + ): + exists.append(True) + else: + exists.append(False) + + for mask_dir in self.mask_dirs: + if os.path.exists( + os.path.join( + self.root, + self.zipfilename.replace('.zip', ''), + mask_dir, + self.split, + ) + ): + exists.append(True) + else: + exists.append(False) + + if all(exists): + return + + # check download of zipfile + if os.path.exists(os.path.join(self.root, self.zipfilename)): + self._extract() + return + + if not self.download: + raise DatasetNotFoundError(self) + + self._download() + + def _download(self) -> None: + """Download the dataset.""" + download_and_extract_archive( + self.url, + self.root, + filename=self.zipfilename, + md5=self.md5 if self.checksum else None, + ) + + def _extract(self) -> None: + """Extract the dataset.""" + extract_archive(os.path.join(self.root, self.zipfilename), self.root) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`CaFFe.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + if 'prediction' in sample: + ncols = 4 + else: + ncols = 3 + fig, axs = plt.subplots(1, ncols, figsize=(15, 5)) + + axs[0].imshow(sample['image'].permute(1, 2, 0).numpy()) + axs[0].axis('off') + + axs[1].imshow(sample['mask_front'].numpy(), cmap='gray') + axs[1].axis('off') + + unique_classes = np.unique(sample['mask_zones'].numpy()) + axs[2].imshow(sample['mask_zones'].numpy(), cmap=self.zone_cmap) + axs[2].axis('off') + + handles = [ + mpatches.Patch( + color=self.zone_cmap(ordinal), + label='\n'.join( + textwrap.wrap(self.px_class_values_zones[px_class], width=10) + ), + ) + for ordinal, px_class in enumerate(self.px_class_values_zones.keys()) + if ordinal in unique_classes + ] + axs[2].legend(handles=handles, loc='upper right', bbox_to_anchor=(1.4, 1)) + + if show_titles: + axs[0].set_title('Image') + axs[1].set_title('Front Mask') + axs[2].set_title('Zone Mask') + + if 'prediction' in sample: + axs[3].imshow(sample['prediction'].numpy(), cmap='gray') + axs[3].axis('off') + if show_titles: + axs[3].set_title('Prediction') + + if suptitle: + fig.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/cbf.py b/torchgeo/datasets/cbf.py index 2c8105b21f8..3c986eb44c1 100644 --- a/torchgeo/datasets/cbf.py +++ b/torchgeo/datasets/cbf.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import check_integrity, download_and_extract_archive +from .utils import Path, check_integrity, download_and_extract_archive class CanadianBuildingFootprints(VectorDataset): @@ -29,7 +29,7 @@ class CanadianBuildingFootprints(VectorDataset): # https://github.com/microsoft/CanadianBuildingFootprints/issues/11 url = 'https://usbuildingdata.blob.core.windows.net/canadian-buildings-v2/' - provinces_territories = [ + provinces_territories = ( 'Alberta', 'BritishColumbia', 'Manitoba', @@ -43,8 +43,8 @@ class CanadianBuildingFootprints(VectorDataset): 'Quebec', 'Saskatchewan', 'YukonTerritory', - ] - md5s = [ + ) + md5s = ( '8b4190424e57bb0902bd8ecb95a9235b', 'fea05d6eb0006710729c675de63db839', 'adf11187362624d68f9c69aaa693c46f', @@ -58,11 +58,11 @@ class CanadianBuildingFootprints(VectorDataset): '9ff4417ae00354d39a0cf193c8df592c', 'a51078d8e60082c7d3a3818240da6dd5', 'c11f3bd914ecabd7cac2cb2871ec0261', - ] + ) def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.00001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -104,7 +104,7 @@ def _check_integrity(self) -> bool: Returns: True if dataset files are found and/or MD5s match, else False """ - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): filepath = os.path.join(self.paths, prov_terr + '.zip') if not check_integrity(filepath, md5 if self.checksum else None): @@ -116,7 +116,7 @@ def _download(self) -> None: if self._check_integrity(): print('Files already downloaded and verified') return - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) for prov_terr, md5 in zip(self.provinces_territories, self.md5s): download_and_extract_archive( self.url + prov_terr + '.zip', diff --git a/torchgeo/datasets/cdl.py b/torchgeo/datasets/cdl.py index b2e43d7a1d4..0b0f6ac5b3d 100644 --- a/torchgeo/datasets/cdl.py +++ b/torchgeo/datasets/cdl.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import torch @@ -14,14 +14,14 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive class CDL(RasterDataset): """Cropland Data Layer (CDL) dataset. The `Cropland Data Layer - `__, hosted on + `__, hosted on `CropScape `_, provides a raster, geo-referenced, crop-specific land cover map for the continental United States. The CDL also includes a crop mask layer and planting frequency layers, as well as @@ -36,8 +36,8 @@ class CDL(RasterDataset): If you use this dataset in your research, please cite it using the following format: - * https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#Section1_14.0 - """ # noqa: E501 + * https://www.nass.usda.gov/Research_and_Science/Cropland/sarsfaqs2.php#what.1 + """ filename_glob = '*_30m_cdls.tif' filename_regex = r""" @@ -48,8 +48,8 @@ class CDL(RasterDataset): date_format = '%Y' is_image = False - url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip' # noqa: E501 - md5s = { + url = 'https://www.nass.usda.gov/Research_and_Science/Cropland/Release/datasets/{}_30m_cdls.zip' + md5s: ClassVar[dict[int, str]] = { 2023: '8c7685d6278d50c554f934b16a6076b7', 2022: '754cf50670cdfee511937554785de3e6', 2021: '27606eab08fe975aa138baad3e5dfcd8', @@ -68,7 +68,7 @@ class CDL(RasterDataset): 2008: '0610f2f17ab60a9fbb3baeb7543993a4', } - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 255), 1: (255, 211, 0, 255), 2: (255, 37, 37, 255), @@ -207,7 +207,7 @@ class CDL(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2023], @@ -294,7 +294,7 @@ def _verify(self) -> None: # Check if the zip files have already been downloaded exists = [] - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) for year in self.years: pathname = os.path.join( self.paths, self.zipfile_glob.replace('*', str(year)) @@ -327,7 +327,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) for year in self.years: zipfile_name = self.zipfile_glob.replace('*', str(year)) pathname = os.path.join(self.paths, zipfile_name) diff --git a/torchgeo/datasets/chabud.py b/torchgeo/datasets/chabud.py index 905c2d8496e..ba773607a54 100644 --- a/torchgeo/datasets/chabud.py +++ b/torchgeo/datasets/chabud.py @@ -4,7 +4,8 @@ """ChaBuD dataset.""" import os -from collections.abc import Callable +from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -14,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, lazy_import, percentile_normalization +from .utils import Path, download_url, lazy_import, percentile_normalization class ChaBuD(NonGeoDataset): @@ -53,7 +54,7 @@ class ChaBuD(NonGeoDataset): .. versionadded:: 0.6 """ - all_bands = [ + all_bands = ( 'B01', 'B02', 'B03', @@ -66,18 +67,18 @@ class ChaBuD(NonGeoDataset): 'B09', 'B11', 'B12', - ] - rgb_bands = ['B04', 'B03', 'B02'] - folds = {'train': [1, 2, 3, 4], 'val': [0]} - url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5' # noqa: E501 + ) + rgb_bands = ('B04', 'B03', 'B02') + folds: ClassVar[dict[str, list[int]]] = {'train': [1, 2, 3, 4], 'val': [0]} + url = 'https://hf.co/datasets/chabud-team/chabud-ecml-pkdd2023/resolve/de222d434e26379aa3d4f3dd1b2caf502427a8b2/train_eval.hdf5' filename = 'train_eval.hdf5' md5 = '15d78fb825f9a81dad600db828d22c08' def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', - bands: list[str] = all_bands, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 55dddd02cfd..459d096043c 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -"""Chesapeake Bay High-Resolution Land Cover Project datasets.""" +"""Cheasapeake Bay Program Land Use/Land Cover Data Project datasets.""" -import abc +import glob import os import sys +from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import fiona import matplotlib.pyplot as plt @@ -26,72 +27,106 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset, RasterDataset from .nlcd import NLCD -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive -class Chesapeake(RasterDataset, abc.ABC): +class Chesapeake(RasterDataset, ABC): """Abstract base class for all Chesapeake datasets. - `Chesapeake Bay High-Resolution Land Cover Project - `_ - dataset. - - This dataset was collected by the Chesapeake Conservancy's Conservation Innovation - Center (CIC) in partnership with the University of Vermont and WorldView Solutions, - Inc. It consists of one-meter resolution land cover information for the Chesapeake - Bay watershed (~100,000 square miles of land). + `Chesapeake Bay Land Use and Land Cover (LULC) Database 2022 Edition + `_ + + The Chesapeake Bay Land Use and Land Cover Database (LULC) facilitates + characterization of the landscape and land change for and between discrete time + periods. The database was developed by the University of Vermont's Spatial Analysis + Laboratory in cooperation with Chesapeake Conservancy (CC) and U.S. Geological + Survey (USGS) as part of a 6-year Cooperative Agreement between Chesapeake + Conservancy and the U.S. Environmental Protection Agency (EPA) and a separate + Interagency Agreement between the USGS and EPA to provide geospatial support to the + Chesapeake Bay Program Office. + + The database contains one-meter 13-class Land Cover (LC) and 54-class Land Use/Land + Cover (LULC) for all counties within or adjacent to the Chesapeake Bay watershed for + 2013/14 and 2017/18, depending on availability of National Agricultural Imagery + Program (NAIP) imagery for each state. Additionally, 54 LULC classes are generalized + into 18 LULC classes for ease of visualization and communication of LULC trends. LC + change between discrete time periods, detected by spectral changes in NAIP imagery + and LiDAR, represents changes between the 12 land cover classes. LULC change uses LC + change to identify where changes are happening and then LC is translated to LULC to + represent transitions between the 54 LULC classes. The LULCC data is represented as + a LULC class change transition matrix which provides users acres of change between + multiple classes. It is organized by 18x18 and 54x54 LULC classes. The Chesapeake + Bay Water (CBW) indicates raster tabulations were performed for only areas that fall + inside the CBW boundary e.g., if user is interested in CBW portion of a county then + they will use LULC Matrix CBW. Conversely, if they are interested change transitions + across the entire county, they will use LULC Matrix. + + If you use this dataset in your research, please cite the following: + + * https://doi.org/10.5066/P981GV1L """ + url = 'https://hf.co/datasets/torchgeo/chesapeake/resolve/1e0370eda6a24d93af4153745e54fd383d015bf5/{state}_lulc_{year}_2022-Edition.zip' + filename_glob = '{state}_lulc_*_2022-Edition.tif' + filename_regex = r'^{state}_lulc_(?P\d{{4}})_2022-Edition\.tif$' + date_format = '%Y' is_image = False - # subclasses use the 13 class cmap by default - cmap = { - 0: (0, 0, 0, 0), - 1: (0, 197, 255, 255), - 2: (0, 168, 132, 255), - 3: (38, 115, 0, 255), - 4: (76, 230, 0, 255), - 5: (163, 255, 115, 255), - 6: (255, 170, 0, 255), - 7: (255, 0, 0, 255), - 8: (156, 156, 156, 255), - 9: (0, 0, 0, 255), - 10: (115, 115, 0, 255), - 11: (230, 230, 0, 255), - 12: (255, 255, 115, 255), - 13: (197, 0, 255, 255), - } - - @property - @abc.abstractmethod - def base_folder(self) -> str: - """Parent directory of dataset in URL.""" - - @property - @abc.abstractmethod - def filename(self) -> str: - """Filename to find/store dataset in.""" - @property - @abc.abstractmethod - def zipfile(self) -> str: - """Name of zipfile in download URL.""" + @abstractmethod + def md5s(self) -> dict[int, str]: + """Mapping between data year and zip file MD5.""" @property - @abc.abstractmethod - def md5(self) -> str: - """MD5 checksum to verify integrity of dataset.""" - - @property - def url(self) -> str: - """URL to download dataset from.""" - url = 'https://cicwebresources.blob.core.windows.net/chesapeakebaylandcover' - url += f'/{self.base_folder}/{self.zipfile}' - return url + def state(self) -> str: + """State abbreviation.""" + return self.__class__.__name__[-2:].lower() + + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { + 11: (0, 92, 230, 255), + 12: (0, 92, 230, 255), + 13: (0, 92, 230, 255), + 14: (0, 92, 230, 255), + 15: (0, 92, 230, 255), + 21: (0, 0, 0, 255), + 22: (235, 6, 2, 255), + 23: (89, 89, 89, 255), + 24: (138, 138, 136, 255), + 25: (138, 138, 136, 255), + 26: (138, 138, 136, 255), + 27: (115, 115, 0, 255), + 28: (233, 255, 190, 255), + 29: (255, 255, 115, 255), + 41: (38, 115, 0, 255), + 42: (56, 168, 0, 255), + 51: (255, 255, 115, 255), + 52: (255, 255, 115, 255), + 53: (255, 255, 115, 255), + 54: (170, 255, 0, 255), + 55: (170, 255, 0, 255), + 56: (170, 255, 0, 255), + 62: (77, 209, 148, 255), + 63: (77, 209, 148, 255), + 64: (56, 168, 0, 255), + 65: (38, 115, 0, 255), + 72: (186, 245, 217, 255), + 73: (186, 245, 217, 255), + 74: (56, 168, 0, 255), + 75: (38, 115, 0, 255), + 83: (255, 211, 127, 255), + 84: (255, 211, 127, 255), + 85: (255, 211, 127, 255), + 91: (0, 168, 132, 255), + 92: (0, 168, 132, 255), + 93: (0, 168, 132, 255), + 94: (56, 168, 0, 255), + 95: (38, 115, 0, 255), + 127: (255, 255, 255, 255), + } def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -99,7 +134,7 @@ def __init__( download: bool = False, checksum: bool = False, ) -> None: - """Initialize a new Dataset instance. + """Initialize a new Chesapeake instance. Args: paths: one or more root directories to search or files to load @@ -119,23 +154,15 @@ def __init__( .. versionchanged:: 0.5 *root* was renamed to *paths*. """ + self.filename_glob = self.filename_glob.format(state=self.state) + self.filename_regex = self.filename_regex.format(state=self.state) + self.paths = paths self.download = download self.checksum = checksum self._verify() - colors = [] - for i in range(len(self.cmap)): - colors.append( - ( - self.cmap[i][0] / 255.0, - self.cmap[i][1] / 255.0, - self.cmap[i][2] / 255.0, - ) - ) - self._cmap = ListedColormap(colors) - super().__init__(paths, crs, res, transforms=transforms, cache=cache) def _verify(self) -> None: @@ -145,8 +172,8 @@ def _verify(self) -> None: return # Check if the zip file has already been downloaded - assert isinstance(self.paths, str) - if os.path.exists(os.path.join(self.paths, self.zipfile)): + assert isinstance(self.paths, str | os.PathLike) + if glob.glob(os.path.join(self.paths, '**', '*.zip'), recursive=True): self._extract() return @@ -160,12 +187,16 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" - download_url(self.url, self.paths, filename=self.zipfile, md5=self.md5) + for year, md5 in self.md5s.items(): + url = self.url.format(state=self.state, year=year) + print(url) + download_url(url, self.paths, md5=md5 if self.checksum else None) def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) - extract_archive(os.path.join(self.paths, self.zipfile)) + assert isinstance(self.paths, str | os.PathLike) + for file in glob.iglob(os.path.join(self.paths, '**', '*.zip'), recursive=True): + extract_archive(file) def plot( self, @@ -187,48 +218,32 @@ def plot( Method now takes a sample dict, not a Tensor. Additionally, possible to show subplot titles and/or use a custom suptitle. """ + cmap = torch.zeros(max(self.cmap) + 1, 4, dtype=torch.uint8) + for key, value in self.cmap.items(): + cmap[key] = torch.tensor(value) + mask = sample['mask'].squeeze(0) + mask = cmap[mask] ncols = 1 showing_predictions = 'prediction' in sample if showing_predictions: pred = sample['prediction'].squeeze(0) + pred = cmap[pred] ncols = 2 - fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4 * ncols, 4)) + fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(4 * ncols, 4)) - if showing_predictions: - axs[0].imshow( - mask, - vmin=0, - vmax=self._cmap.N - 1, - cmap=self._cmap, - interpolation='none', - ) - axs[0].axis('off') - axs[1].imshow( - pred, - vmin=0, - vmax=self._cmap.N - 1, - cmap=self._cmap, - interpolation='none', - ) - axs[1].axis('off') - if show_titles: - axs[0].set_title('Mask') - axs[1].set_title('Prediction') + axs[0, 0].imshow(mask) + axs[0, 0].axis('off') + if show_titles: + axs[0, 0].set_title('Mask') - else: - axs.imshow( - mask, - vmin=0, - vmax=self._cmap.N - 1, - cmap=self._cmap, - interpolation='none', - ) - axs.axis('off') + if showing_predictions: + axs[0, 1].imshow(pred) + axs[0, 1].axis('off') if show_titles: - axs.set_title('Mask') + axs[0, 1].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) @@ -236,159 +251,67 @@ def plot( return fig -class Chesapeake7(Chesapeake): - """Complete 7-class dataset. - - This version of the dataset is composed of 7 classes: - - 0. No Data: Background values - 1. Water: All areas of open water including ponds, rivers, and lakes - 2. Tree Canopy and Shrubs: All woody vegetation including trees and shrubs - 3. Low Vegetation: Plant material less than 2 meters in height including lawns - 4. Barren: Areas devoid of vegetation consisting of natural earthen material - 5. Impervious Surfaces: Human-constructed surfaces less than 2 meters in height - 6. Impervious Roads: Impervious surfaces that are used for transportation - 7. Aberdeen Proving Ground: U.S. Army facility with no labels - """ - - base_folder = 'BAYWIDE' - filename = 'Baywide_7class_20132014.tif' - filename_glob = filename - zipfile = 'Baywide_7Class_20132014.zip' - md5 = '61a4e948fb2551840b6557ef195c2084' - - cmap = { - 0: (0, 0, 0, 0), - 1: (0, 197, 255, 255), - 2: (38, 115, 0, 255), - 3: (163, 255, 115, 255), - 4: (255, 170, 0, 255), - 5: (156, 156, 156, 255), - 6: (0, 0, 0, 255), - 7: (197, 0, 255, 255), - } - - -class Chesapeake13(Chesapeake): - """Complete 13-class dataset. - - This version of the dataset is composed of 13 classes: - - 0. No Data: Background values - 1. Water: All areas of open water including ponds, rivers, and lakes - 2. Wetlands: Low vegetation areas located along marine or estuarine regions - 3. Tree Canopy: Deciduous and evergreen woody vegetation over 3-5 meters in height - 4. Shrubland: Heterogeneous woody vegetation including shrubs and young trees - 5. Low Vegetation: Plant material less than 2 meters in height including lawns - 6. Barren: Areas devoid of vegetation consisting of natural earthen material - 7. Structures: Human-constructed objects made of impervious materials - 8. Impervious Surfaces: Human-constructed surfaces less than 2 meters in height - 9. Impervious Roads: Impervious surfaces that are used for transportation - 10. Tree Canopy over Structures: Tree cover overlapping impervious structures - 11. Tree Canopy over Impervious Surfaces: Tree cover overlapping impervious surfaces - 12. Tree Canopy over Impervious Roads: Tree cover overlapping impervious roads - 13. Aberdeen Proving Ground: U.S. Army facility with no labels - """ - - base_folder = 'BAYWIDE' - filename = 'Baywide_13Class_20132014.tif' - filename_glob = filename - zipfile = 'Baywide_13Class_20132014.zip' - md5 = '7e51118923c91e80e6e268156d25a4b9' - - class ChesapeakeDC(Chesapeake): """This subset of the dataset contains data only for Washington, D.C.""" - base_folder = 'DC' - filename = os.path.join('DC_11001', 'DC_11001.img') - filename_glob = filename - zipfile = 'DC_11001.zip' - md5 = 'ed06ba7570d2955e8857d7d846c53b06' + md5s: ClassVar[dict[int, str]] = { + 2013: '9f1df21afbb9d5c0fcf33af7f6750a7f', + 2017: 'c45e4af2950e1c93ecd47b61af296d9b', + } class ChesapeakeDE(Chesapeake): """This subset of the dataset contains data only for Delaware.""" - base_folder = 'DE' - filename = 'DE_STATEWIDE.tif' - filename_glob = filename - zipfile = '_DE_STATEWIDE.zip' - md5 = '5e12eff3b6950c01092c7e480b38e544' + md5s: ClassVar[dict[int, str]] = { + 2013: '5850d96d897babba85610658aeb5951a', + 2018: 'ee94c8efeae423d898677104117bdebc', + } class ChesapeakeMD(Chesapeake): - """This subset of the dataset contains data only for Maryland. - - .. note:: - - This dataset requires the following additional library to be installed: - - * `zipfile-deflate64 `_ to extract - the proprietary deflate64 compressed zip file. - """ + """This subset of the dataset contains data only for Maryland.""" - base_folder = 'MD' - filename = 'MD_STATEWIDE.tif' - filename_glob = filename - zipfile = '_MD_STATEWIDE.zip' - md5 = '40c7cd697a887f2ffdb601b5c114e567' + md5s: ClassVar[dict[int, str]] = { + 2013: '9c3ca5040668d15284c1bd64b7d6c7a0', + 2018: '0647530edf8bec6e60f82760dcc7db9c', + } class ChesapeakeNY(Chesapeake): - """This subset of the dataset contains data only for New York. - - .. note:: - - This dataset requires the following additional library to be installed: + """This subset of the dataset contains data only for New York.""" - * `zipfile-deflate64 `_ to extract - the proprietary deflate64 compressed zip file. - """ - - base_folder = 'NY' - filename = 'NY_STATEWIDE.tif' - filename_glob = filename - zipfile = '_NY_STATEWIDE.zip' - md5 = '1100078c526616454ef2e508affda915' + md5s: ClassVar[dict[int, str]] = { + 2013: '38a29b721610ba661a7f8b6ec71a48b7', + 2017: '4c1b1a50fd9368cd7b8b12c4d80c63f3', + } class ChesapeakePA(Chesapeake): """This subset of the dataset contains data only for Pennsylvania.""" - base_folder = 'PA' - filename = 'PA_STATEWIDE.tif' - filename_glob = filename - zipfile = '_PA_STATEWIDE.zip' - md5 = '20a2a857c527a4dbadd6beed8b47e5ab' + md5s: ClassVar[dict[int, str]] = { + 2013: '86febd603a120a49ef7d23ef486152a3', + 2017: 'b11d92e4471e8cb887c790d488a338c1', + } class ChesapeakeVA(Chesapeake): - """This subset of the dataset contains data only for Virginia. + """This subset of the dataset contains data only for Virginia.""" - .. note:: - - This dataset requires the following additional library to be installed: - - * `zipfile-deflate64 `_ to extract - the proprietary deflate64 compressed zip file. - """ - - base_folder = 'VA' - filename = 'CIC2014_VA_STATEWIDE.tif' - filename_glob = filename - zipfile = '_VA_STATEWIDE.zip' - md5 = '6f2c97deaf73bb3e1ea9b21bd7a3fc8e' + md5s: ClassVar[dict[int, str]] = { + 2014: '49c9700c71854eebd00de24d8488eb7c', + 2018: '51731c8b5632978bfd1df869ea10db5b', + } class ChesapeakeWV(Chesapeake): """This subset of the dataset contains data only for West Virginia.""" - base_folder = 'WV' - filename = 'WV_STATEWIDE.tif' - filename_glob = filename - zipfile = '_WV_STATEWIDE.zip' - md5 = '350621ea293651fbc557a1c3e3c64cc3' + md5s: ClassVar[dict[int, str]] = { + 2014: '32fea42fae147bd58a83e3ea6cccfb94', + 2018: '80f25dcba72e39685ab33215c5d97292', + } class ChesapeakeCVPR(GeoDataset): @@ -406,23 +329,23 @@ class ChesapeakeCVPR(GeoDataset): additional layer of data to this dataset containing a prior over the Chesapeake Bay land cover classes generated from the NLCD land cover labels. For more information about this layer see `the dataset documentation - `_. + `_. If you use this dataset in your research, please cite the following paper: * https://doi.org/10.1109/cvpr.2019.01301 """ - subdatasets = ['base', 'prior_extension'] - urls = { - 'base': 'https://lilablobssc.blob.core.windows.net/lcmcvpr2019/cvpr_chesapeake_landcover.zip', # noqa: E501 - 'prior_extension': 'https://zenodo.org/record/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1', # noqa: E501 + subdatasets = ('base', 'prior_extension') + urls: ClassVar[dict[str, str]] = { + 'base': 'https://lilawildlife.blob.core.windows.net/lila-wildlife/lcmcvpr2019/cvpr_chesapeake_landcover.zip', + 'prior_extension': 'https://zenodo.org/records/5866525/files/cvpr_chesapeake_landcover_prior_extension.zip?download=1', } - filenames = { + filenames: ClassVar[dict[str, str]] = { 'base': 'cvpr_chesapeake_landcover.zip', 'prior_extension': 'cvpr_chesapeake_landcover_prior_extension.zip', } - md5s = { + md5s: ClassVar[dict[str, str]] = { 'base': '1225ccbb9590e9396875f221e5031514', 'prior_extension': '402f41d07823c8faf7ea6960d7c4e17a', } @@ -430,7 +353,7 @@ class ChesapeakeCVPR(GeoDataset): crs = CRS.from_epsg(3857) res = 1 - lc_cmap = { + lc_cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 0), 1: (0, 197, 255, 255), 2: (38, 115, 0, 255), @@ -450,7 +373,7 @@ class ChesapeakeCVPR(GeoDataset): ] ) - valid_layers = [ + valid_layers = ( 'naip-new', 'naip-old', 'landsat-leaf-on', @@ -459,8 +382,8 @@ class ChesapeakeCVPR(GeoDataset): 'lc', 'buildings', 'prior_from_cooccurrences_101_31_no_osm_no_buildings', - ] - states = ['de', 'md', 'va', 'wv', 'pa', 'ny'] + ) + states = ('de', 'md', 'va', 'wv', 'pa', 'ny') splits = ( [f'{state}-train' for state in states] + [f'{state}-val' for state in states] @@ -468,7 +391,7 @@ class ChesapeakeCVPR(GeoDataset): ) # these are used to check the integrity of the dataset - _files = [ + _files = ( 'de_1m_2013_extended-debuffered-test_tiles', 'de_1m_2013_extended-debuffered-train_tiles', 'de_1m_2013_extended-debuffered-val_tiles', @@ -488,18 +411,18 @@ class ChesapeakeCVPR(GeoDataset): 'wv_1m_2014_extended-debuffered-train_tiles', 'wv_1m_2014_extended-debuffered-val_tiles', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_buildings.tif', - 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif', # noqa: E501 - 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif', # noqa: E501 + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-off.tif', + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_landsat-leaf-on.tif', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_lc.tif', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-new.tif', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_naip-old.tif', 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_nlcd.tif', - 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501 + 'wv_1m_2014_extended-debuffered-val_tiles/m_3708035_ne_17_1_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', 'spatial_index.geojson', - ] + ) p_src_crs = pyproj.CRS('epsg:3857') - p_transformers = { + p_transformers: ClassVar[dict[str, CRS]] = { 'epsg:26917': pyproj.Transformer.from_crs( p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True ).transform, @@ -510,7 +433,7 @@ class ChesapeakeCVPR(GeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', splits: Sequence[str] = ['de-train'], layers: Sequence[str] = ['naip-new', 'lc'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -553,13 +476,11 @@ def __init__( lc_colors = np.zeros((max(self.lc_cmap.keys()) + 1, 4)) lc_colors[list(self.lc_cmap.keys())] = list(self.lc_cmap.values()) - lc_colors = lc_colors[:, :3] / 255 - self._lc_cmap = ListedColormap(lc_colors) + self._lc_cmap = ListedColormap(lc_colors[:, :3] / 255) nlcd_colors = np.zeros((max(NLCD.cmap.keys()) + 1, 4)) nlcd_colors[list(NLCD.cmap.keys())] = list(NLCD.cmap.values()) - nlcd_colors = nlcd_colors[:, :3] / 255 - self._nlcd_cmap = ListedColormap(nlcd_colors) + self._nlcd_cmap = ListedColormap(nlcd_colors[:, :3] / 255) # Add all tiles into the index in epsg:3857 based on the included geojson mint: float = 0 @@ -587,7 +508,7 @@ def __init__( 'lc': row['properties']['lc'], 'nlcd': row['properties']['nlcd'], 'buildings': row['properties']['buildings'], - 'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn, # noqa: E501 + 'prior_from_cooccurrences_101_31_no_osm_no_buildings': prior_fn, }, ) @@ -606,7 +527,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) filepaths = cast(list[dict[str, str]], [hit.object for hit in hits]) - sample = {'image': [], 'mask': [], 'crs': self.crs, 'bbox': query} + sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query} if len(filepaths) == 0: raise IndexError( @@ -658,7 +579,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample['mask'] = np.concatenate(sample['mask'], axis=0) sample['image'] = torch.from_numpy(sample['image']).float() - sample['mask'] = torch.from_numpy(sample['mask']).long() + sample['mask'] = torch.from_numpy(sample['mask']).long().squeeze(0) if self.transforms is not None: sample = self.transforms(sample) @@ -668,7 +589,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" - def exists(filename: str) -> bool: + def exists(filename: Path) -> bool: return os.path.exists(os.path.join(self.root, filename)) # Check if the extracted files already exist diff --git a/torchgeo/datasets/cloud_cover.py b/torchgeo/datasets/cloud_cover.py index 2aeea01c568..e0ca0045e33 100644 --- a/torchgeo/datasets/cloud_cover.py +++ b/torchgeo/datasets/cloud_cover.py @@ -3,13 +3,13 @@ """Cloud Cover Detection Challenge dataset.""" -import json import os from collections.abc import Callable, Sequence -from typing import Any +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np +import pandas as pd import rasterio import torch from matplotlib.figure import Figure @@ -17,18 +17,16 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import Path, which -# TODO: read geospatial information from stac.json files class CloudCoverDetection(NonGeoDataset): - """Cloud Cover Detection Challenge dataset. + """Sentinel-2 Cloud Cover Segmentation Dataset. - This training dataset was generated as part of a - `crowdsourcing competition + This training dataset was generated as part of a `crowdsourcing competition `_ on DrivenData.org, and - later on was validated using a team of expert annotators. See - `this website `__ + later on was validated using a team of expert annotators. See `this website + `__ for dataset details. The dataset consists of Sentinel-2 satellite imagery and corresponding cloudy @@ -51,96 +49,52 @@ class CloudCoverDetection(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionadded:: 0.4 """ - collection_ids = [ - 'ref_cloud_cover_detection_challenge_v1_train_source', - 'ref_cloud_cover_detection_challenge_v1_train_labels', - 'ref_cloud_cover_detection_challenge_v1_test_source', - 'ref_cloud_cover_detection_challenge_v1_test_labels', - ] - - image_meta = { - 'train': { - 'filename': 'ref_cloud_cover_detection_challenge_v1_train_source.tar.gz', - 'md5': '32cfe38e313bcedc09dca3f0f9575eea', - }, - 'test': { - 'filename': 'ref_cloud_cover_detection_challenge_v1_test_source.tar.gz', - 'md5': '6c67edae18716598d47298f24992db6c', - }, - } - - target_meta = { - 'train': { - 'filename': 'ref_cloud_cover_detection_challenge_v1_train_labels.tar.gz', - 'md5': '695dfb1034924c10fbb17f9293815671', - }, - 'test': { - 'filename': 'ref_cloud_cover_detection_challenge_v1_test_labels.tar.gz', - 'md5': 'ec2b42bb43e9a03a01ae096f9e09db9c', - }, - } - - collection_names = { - 'train': [ - 'ref_cloud_cover_detection_challenge_v1_train_source', - 'ref_cloud_cover_detection_challenge_v1_train_labels', - ], - 'test': [ - 'ref_cloud_cover_detection_challenge_v1_test_source', - 'ref_cloud_cover_detection_challenge_v1_test_labels', - ], - } - - band_names = ['B02', 'B03', 'B04', 'B08'] - - rgb_bands = ['B04', 'B03', 'B02'] + url = 'https://radiantearth.blob.core.windows.net/mlhub/ref_cloud_cover_detection_challenge_v1/final' + all_bands = ('B02', 'B03', 'B04', 'B08') + rgb_bands = ('B04', 'B03', 'B02') + splits: ClassVar[dict[str, str]] = {'train': 'public', 'test': 'private'} def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', - bands: Sequence[str] = band_names, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: - """Initiatlize a new Cloud Cover Detection Dataset instance. + """Initiatlize a CloudCoverDetection instance. Args: root: root directory where dataset can be found - split: train/val/test split to load + split: 'train' or 'test' bands: the subset of bands to load transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: + AssertionError: If *split* or *bands* are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ + assert split in self.splits + assert set(bands) <= set(self.all_bands) + self.root = root self.split = split - self.transforms = transforms - self.checksum = checksum - - self._validate_bands(bands) self.bands = bands + self.transforms = transforms + self.download = download - if download: - self._download(api_key) - - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self.csv = os.path.join(self.root, self.split, f'{self.split}_metadata.csv') + self._verify() - self.chip_paths = self._load_collections() + self.metadata = pd.read_csv(self.csv) def __len__(self) -> int: """Return the number of items in the dataset. @@ -148,7 +102,7 @@ def __len__(self) -> int: Returns: length of dataset in integer """ - return len(self.chip_paths) + return len(self.metadata) def __getitem__(self, index: int) -> dict[str, Tensor]: """Returns a sample from dataset. @@ -159,192 +113,65 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at given index """ - image = self._load_image(index) - label = self._load_target(index) - sample: dict[str, Tensor] = {'image': image, 'mask': label} + chip_id = self.metadata.iat[index, 0] + image = self._load_image(chip_id) + label = self._load_target(chip_id) + sample = {'image': image, 'mask': label} if self.transforms is not None: sample = self.transforms(sample) return sample - def _load_image(self, index: int) -> Tensor: + def _load_image(self, chip_id: str) -> Tensor: """Load all source images for a chip. Args: - index: position of the indexed chip + chip_id: ID of the chip. Returns: a tensor of stacked source image data """ - source_asset_paths = self.chip_paths[index]['source'] + path = os.path.join(self.root, self.split, f'{self.split}_features', chip_id) images = [] - for path in source_asset_paths: - with rasterio.open(path) as image_data: - image_array = image_data.read(1).astype(np.int32) - images.append(image_array) - image_stack: np.typing.NDArray[np.int_] = np.stack(images, axis=0) - image_tensor = torch.from_numpy(image_stack) - return image_tensor - - def _load_target(self, index: int) -> Tensor: + for band in self.bands: + with rasterio.open(os.path.join(path, f'{band}.tif')) as src: + images.append(src.read(1).astype(np.float32)) + return torch.from_numpy(np.stack(images, axis=0)) + + def _load_target(self, chip_id: str) -> Tensor: """Load label image for a chip. Args: - index: position of the indexed chip + chip_id: ID of the chip. Returns: a tensor of the label image data """ - label_asset_path = self.chip_paths[index]['target'][0] - with rasterio.open(label_asset_path) as target_data: - target_img = target_data.read(1).astype(np.int32) - - target_array: np.typing.NDArray[np.int_] = np.array(target_img) - target_tensor = torch.from_numpy(target_array) - return target_tensor - - @staticmethod - def _read_json_data(object_path: str) -> Any: - """Loads a JSON file. - - Args: - object_path: string path to the JSON file - - Returns: - json_data: JSON object / dictionary - - """ - with open(object_path) as read_contents: - json_data = json.load(read_contents) - return json_data - - def _load_items(self, item_json: str) -> dict[str, list[str]]: - """Loads the label item and corresponding source items. - - Args: - item_json: a string path to the item JSON file on disk - - Returns: - a dictionary with paths to the source and target TIF filenames - """ - item_meta = {} - - label_data = self._read_json_data(item_json) - label_asset_path = os.path.join( - os.path.split(item_json)[0], label_data['assets']['labels']['href'] - ) - item_meta['target'] = [label_asset_path] - - source_item_hrefs = [] - for link in label_data['links']: - if link['rel'] == 'source': - source_item_hrefs.append( - os.path.join(self.root, link['href'].replace('../../', '')) - ) - - source_item_hrefs = sorted(source_item_hrefs) - source_item_paths = [] - - for item_href in source_item_hrefs: - source_item_path = os.path.split(item_href)[0] - source_data = self._read_json_data(item_href) - source_item_assets = [] - for asset_key, asset_value in source_data['assets'].items(): - if asset_key in self.bands: - source_item_assets.append( - os.path.join(source_item_path, asset_value['href']) - ) - source_item_assets = sorted(source_item_assets) - for source_item_asset in source_item_assets: - source_item_paths.append(source_item_asset) - - item_meta['source'] = source_item_paths - return item_meta - - def _load_collections(self) -> list[dict[str, Any]]: - """Loads the paths to source and label assets for each collection. - - Returns: - a dictionary with lists of filepaths to all assets for each chip/item - - Raises: - RuntimeError if collection.json is not found in the uncompressed dataset - """ - indexed_chips = [] - label_collection: list[str] = [] - for c in self.collection_names[self.split]: - if 'label' in c: - label_collection.append(c) - label_collection_path = os.path.join(self.root, label_collection[0]) - label_collection_json = os.path.join(label_collection_path, 'collection.json') - - label_collection_item_hrefs = [] - for link in self._read_json_data(label_collection_json)['links']: - if link['rel'] == 'item': - label_collection_item_hrefs.append(link['href']) - - label_collection_item_hrefs = sorted(label_collection_item_hrefs) - - for label_href in label_collection_item_hrefs: - label_json = os.path.join(label_collection_path, label_href) - indexed_item = self._load_items(label_json) - indexed_chips.append(indexed_item) - - return indexed_chips - - def _validate_bands(self, bands: Sequence[str]) -> None: - """Validate list of bands. - - Args: - bands: user-provided tuple of bands to load - - Raises: - ValueError: if an invalid band name is provided - """ - for band in bands: - if band not in self.band_names: - raise ValueError(f"'{band}' is an invalid band name.") - - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - images: bool = check_integrity( - os.path.join(self.root, self.image_meta[self.split]['filename']), - self.image_meta[self.split]['md5'] if self.checksum else None, - ) - - targets: bool = check_integrity( - os.path.join(self.root, self.target_meta[self.split]['filename']), - self.target_meta[self.split]['md5'] if self.checksum else None, - ) - - return images and targets + path = os.path.join(self.root, self.split, f'{self.split}_labels') + with rasterio.open(os.path.join(path, f'{chip_id}.tif')) as src: + return torch.from_numpy(src.read(1).astype(np.int64)) + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if os.path.exists(self.csv): + return - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if self._check_integrity(): - print('Files already downloaded and verified') - return + # Download the dataset + self._download() - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) - - image_archive_path = os.path.join( - self.root, self.image_meta[self.split]['filename'] - ) - target_archive_path = os.path.join( - self.root, self.target_meta[self.split]['filename'] - ) - for fn in [image_archive_path, target_archive_path]: - extract_archive(fn, self.root) + def _download(self) -> None: + """Download the dataset.""" + directory = os.path.join(self.root, self.split) + os.makedirs(directory, exist_ok=True) + url = f'{self.url}/{self.splits[self.split]}' + azcopy = which('azcopy') + azcopy('sync', url, directory, '--recursive=true') def plot( self, diff --git a/torchgeo/datasets/cms_mangrove_canopy.py b/torchgeo/datasets/cms_mangrove_canopy.py index a1bfd0da56e..f9db256238d 100644 --- a/torchgeo/datasets/cms_mangrove_canopy.py +++ b/torchgeo/datasets/cms_mangrove_canopy.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import check_integrity, extract_archive +from .utils import Path, check_integrity, extract_archive class CMSGlobalMangroveCanopy(RasterDataset): @@ -24,7 +24,7 @@ class CMSGlobalMangroveCanopy(RasterDataset): consists of a single band map at 30m resolution of either aboveground biomass (agb), basal area weighted height (hba95), or maximum canopy height (hmax95). - The dataset needs to be manually dowloaded from the above link, where you can make + The dataset needs to be manually downloaded from the above link, where you can make an account and subsequently download the dataset. .. versionadded:: 0.3 @@ -41,7 +41,7 @@ class CMSGlobalMangroveCanopy(RasterDataset): zipfile = 'CMS_Global_Map_Mangrove_Canopy_1665.zip' md5 = '3e7f9f23bf971c25e828b36e6c5496e3' - all_countries = [ + all_countries = ( 'AndamanAndNicobar', 'Angola', 'Anguilla', @@ -163,13 +163,13 @@ class CMSGlobalMangroveCanopy(RasterDataset): 'VirginIslandsUs', 'WallisAndFutuna', 'Yemen', - ] + ) - measurements = ['agb', 'hba95', 'hmax95'] + measurements = ('agb', 'hba95', 'hmax95') def __init__( self, - paths: str | list[str] = 'data', + paths: Path | list[Path] = 'data', crs: CRS | None = None, res: float | None = None, measurement: str = 'agb', @@ -228,7 +228,7 @@ def _verify(self) -> None: return # Check if the zip file has already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile) if os.path.exists(pathname): if self.checksum and not check_integrity(pathname, self.md5): @@ -240,7 +240,7 @@ def _verify(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile) extract_archive(pathname) diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 838123d39b2..fa97fa87037 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive +from .utils import Path, check_integrity, download_and_extract_archive class COWC(NonGeoDataset, abc.ABC): @@ -50,12 +50,12 @@ def base_url(self) -> str: @property @abc.abstractmethod - def filenames(self) -> list[str]: + def filenames(self) -> tuple[str, ...]: """List of files to download.""" @property @abc.abstractmethod - def md5s(self) -> list[str]: + def md5s(self) -> tuple[str, ...]: """List of MD5 checksums of files to download.""" @property @@ -65,7 +65,7 @@ def filename(self) -> str: def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -239,7 +239,7 @@ class COWCCounting(COWC): base_url = ( 'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/counting/' ) - filenames = [ + filenames = ( 'COWC_train_list_64_class.txt.bz2', 'COWC_test_list_64_class.txt.bz2', 'COWC_Counting_Toronto_ISPRS.tbz', @@ -248,8 +248,8 @@ class COWCCounting(COWC): 'COWC_Counting_Vaihingen_ISPRS.tbz', 'COWC_Counting_Columbus_CSUAV_AFRL.tbz', 'COWC_Counting_Utah_AGRC.tbz', - ] - md5s = [ + ) + md5s = ( '187543d20fa6d591b8da51136e8ef8fb', '930cfd6e160a7b36db03146282178807', 'bc2613196dfa93e66d324ae43e7c1fdb', @@ -258,7 +258,7 @@ class COWCCounting(COWC): '4009c1e420566390746f5b4db02afdb9', 'daf8033c4e8ceebbf2c3cac3fabb8b10', '777ec107ed2a3d54597a739ce74f95ad', - ] + ) filename = 'COWC_{}_list_64_class.txt' @@ -268,7 +268,7 @@ class COWCDetection(COWC): base_url = ( 'https://gdo152.llnl.gov/cowc/download/cowc/datasets/patch_sets/detection/' ) - filenames = [ + filenames = ( 'COWC_train_list_detection.txt.bz2', 'COWC_test_list_detection.txt.bz2', 'COWC_Detection_Toronto_ISPRS.tbz', @@ -277,8 +277,8 @@ class COWCDetection(COWC): 'COWC_Detection_Vaihingen_ISPRS.tbz', 'COWC_Detection_Columbus_CSUAV_AFRL.tbz', 'COWC_Detection_Utah_AGRC.tbz', - ] - md5s = [ + ) + md5s = ( 'c954a5a3dac08c220b10cfbeec83893c', 'c6c2d0a78f12a2ad88b286b724a57c1a', '11af24f43b198b0f13c8e94814008a48', @@ -287,7 +287,7 @@ class COWCDetection(COWC): '23945d5b22455450a938382ccc2a8b27', 'f40522dc97bea41b10117d4a5b946a6f', '195da7c9443a939a468c9f232fd86ee3', - ] + ) filename = 'COWC_{}_list_detection.txt' diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py index 400b5ceb63c..bb3e4b3f3c5 100644 --- a/torchgeo/datasets/cropharvest.py +++ b/torchgeo/datasets/cropharvest.py @@ -7,6 +7,7 @@ import json import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -17,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, lazy_import +from .utils import Path, download_url, extract_archive, lazy_import class CropHarvest(NonGeoDataset): @@ -55,7 +56,7 @@ class CropHarvest(NonGeoDataset): """ # https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py - all_bands = [ + all_bands = ( 'VV', 'VH', 'B2', @@ -74,12 +75,12 @@ class CropHarvest(NonGeoDataset): 'elevation', 'slope', 'NDVI', - ] - rgb_bands = ['B4', 'B3', 'B2'] + ) + rgb_bands = ('B4', 'B3', 'B2') features_url = 'https://zenodo.org/records/7257688/files/features.tar.gz?download=1' labels_url = 'https://zenodo.org/records/7257688/files/labels.geojson?download=1' - file_dict = { + file_dict: ClassVar[dict[str, dict[str, str]]] = { 'features': { 'url': features_url, 'filename': 'features.tar.gz', @@ -96,7 +97,7 @@ class CropHarvest(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -157,7 +158,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_features(self, root: str) -> list[dict[str, str]]: + def _load_features(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -181,7 +182,7 @@ def _load_features(self, root: str) -> list[dict[str, str]]: files.append(dict(chip=chip_path, index=index, dataset=dataset)) return files - def _load_labels(self, root: str) -> pd.DataFrame: + def _load_labels(self, root: Path) -> pd.DataFrame: """Return the paths of the files in the dataset. Args: @@ -196,7 +197,7 @@ def _load_labels(self, root: str) -> pd.DataFrame: df = pd.json_normalize(data['features']) return df - def _load_array(self, path: str) -> Tensor: + def _load_array(self, path: Path) -> Tensor: """Load an individual single pixel time series. Args: diff --git a/torchgeo/datasets/cv4a_kenya_crop_type.py b/torchgeo/datasets/cv4a_kenya_crop_type.py index a532c1539c4..2248dab4292 100644 --- a/torchgeo/datasets/cv4a_kenya_crop_type.py +++ b/torchgeo/datasets/cv4a_kenya_crop_type.py @@ -3,9 +3,8 @@ """CV4A Kenya Crop Type dataset.""" -import csv import os -from collections.abc import Callable +from collections.abc import Callable, Sequence from functools import lru_cache import matplotlib.pyplot as plt @@ -17,16 +16,23 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import Path, which -# TODO: read geospatial information from stac.json files class CV4AKenyaCropType(NonGeoDataset): - """CV4A Kenya Crop Type dataset. + """CV4A Kenya Crop Type Competition dataset. - Used in a competition in the Computer NonGeo for Agriculture (CV4A) workshop in - ICLR 2020. See `this website `__ - for dataset details. + The `CV4A Kenya Crop Type Competition + `__ + dataset was produced as part of the Crop Type Detection competition at the + Computer Vision for Agriculture (CV4A) Workshop at the ICLR 2020 conference. + The objective of the competition was to create a machine learning model to + classify fields by crop type from images collected during the growing season + by the Sentinel-2 satellites. + + See the `dataset documentation + `__ + for details. Consists of 4 tiles of Sentinel 2 imagery from 13 different points in time. @@ -54,30 +60,13 @@ class CV4AKenyaCropType(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. """ - collection_ids = [ - 'ref_african_crops_kenya_02_labels', - 'ref_african_crops_kenya_02_source', - ] - image_meta = { - 'filename': 'ref_african_crops_kenya_02_source.tar.gz', - 'md5': '9c2004782f6dc83abb1bf45ba4d0da46', - } - target_meta = { - 'filename': 'ref_african_crops_kenya_02_labels.tar.gz', - 'md5': '93949abd0ae82ba564f5a933cefd8215', - } - - tile_names = [ - 'ref_african_crops_kenya_02_tile_00', - 'ref_african_crops_kenya_02_tile_01', - 'ref_african_crops_kenya_02_tile_02', - 'ref_african_crops_kenya_02_tile_03', - ] - dates = [ + url = 'https://radiantearth.blob.core.windows.net/mlhub/kenya-crop-challenge' + tiles = tuple(map(str, range(4))) + dates = ( '20190606', '20190701', '20190706', @@ -91,8 +80,8 @@ class CV4AKenyaCropType(NonGeoDataset): '20190924', '20191004', '20191103', - ] - band_names = ( + ) + all_bands = ( 'B01', 'B02', 'B03', @@ -107,8 +96,7 @@ class CV4AKenyaCropType(NonGeoDataset): 'B12', 'CLD', ) - - rgb_bands = ['B04', 'B03', 'B02'] + rgb_bands = ('B04', 'B03', 'B02') # Same for all tiles tile_height = 3035 @@ -116,15 +104,12 @@ class CV4AKenyaCropType(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', chip_size: int = 256, stride: int = 128, - bands: tuple[str, ...] = band_names, + bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, - verbose: bool = False, ) -> None: """Initialize a new CV4A Kenya Crop Type Dataset instance. @@ -137,37 +122,32 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - verbose: if True, print messages when new tiles are loaded Raises: + AssertionError: If *bands* are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ - self._validate_bands(bands) + assert set(bands) <= set(self.all_bands) self.root = root self.chip_size = chip_size self.stride = stride self.bands = bands self.transforms = transforms - self.checksum = checksum - self.verbose = verbose + self.download = download - if download: - self._download(api_key) - - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self._verify() # Calculate the indices that we will use over all tiles self.chips_metadata = [] - for tile_index in range(len(self.tile_names)): - for y in list(range(0, self.tile_height - self.chip_size, stride)) + [ - self.tile_height - self.chip_size + for tile_index in range(len(self.tiles)): + for y in [ + *list(range(0, self.tile_height - self.chip_size, stride)), + self.tile_height - self.chip_size, ]: - for x in list(range(0, self.tile_width - self.chip_size, stride)) + [ - self.tile_width - self.chip_size + for x in [ + *list(range(0, self.tile_width - self.chip_size, stride)), + self.tile_width - self.chip_size, ]: self.chips_metadata.append((tile_index, y, x)) @@ -181,10 +161,10 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data, labels, field ids, and metadata at that index """ tile_index, y, x = self.chips_metadata[index] - tile_name = self.tile_names[tile_index] + tile = self.tiles[tile_index] - img = self._load_all_image_tiles(tile_name, self.bands) - labels, field_ids = self._load_label_tile(tile_name) + img = self._load_all_image_tiles(tile) + labels, field_ids = self._load_label_tile(tile) img = img[:, :, y : y + self.chip_size, x : x + self.chip_size] labels = labels[y : y + self.chip_size, x : x + self.chip_size] @@ -213,193 +193,94 @@ def __len__(self) -> int: return len(self.chips_metadata) @lru_cache(maxsize=128) - def _load_label_tile(self, tile_name: str) -> tuple[Tensor, Tensor]: + def _load_label_tile(self, tile: str) -> tuple[Tensor, Tensor]: """Load a single _tile_ of labels and field_ids. Args: - tile_name: name of tile to load + tile: name of tile to load Returns: tuple of labels and field ids - - Raises: - AssertionError: if ``tile_name`` is invalid """ - assert tile_name in self.tile_names + directory = os.path.join(self.root, 'data', tile) - if self.verbose: - print(f'Loading labels/field_ids for {tile_name}') - - directory = os.path.join( - self.root, 'ref_african_crops_kenya_02_labels', tile_name + '_label' - ) - - with Image.open(os.path.join(directory, 'labels.tif')) as img: + with Image.open(os.path.join(directory, f'{tile}_label.tif')) as img: array: np.typing.NDArray[np.int_] = np.array(img) labels = torch.from_numpy(array) - with Image.open(os.path.join(directory, 'field_ids.tif')) as img: + with Image.open(os.path.join(directory, f'{tile}_field_id.tif')) as img: array = np.array(img) field_ids = torch.from_numpy(array) - return (labels, field_ids) - - def _validate_bands(self, bands: tuple[str, ...]) -> None: - """Validate list of bands. - - Args: - bands: user-provided tuple of bands to load - - Raises: - AssertionError: if ``bands`` is not a tuple - ValueError: if an invalid band name is provided - """ - assert isinstance(bands, tuple), 'The list of bands must be a tuple' - for band in bands: - if band not in self.band_names: - raise ValueError(f"'{band}' is an invalid band name.") + return labels, field_ids @lru_cache(maxsize=128) - def _load_all_image_tiles( - self, tile_name: str, bands: tuple[str, ...] = band_names - ) -> Tensor: + def _load_all_image_tiles(self, tile: str) -> Tensor: """Load all the imagery (across time) for a single _tile_. Optionally allows for subsetting of the bands that are loaded. Args: - tile_name: name of tile to load - bands: tuple of bands to load + tile: name of tile to load Returns: imagery of shape (13, number of bands, 3035, 2016) where 13 is the number of - points in time, 3035 is the tile height, and 2016 is the tile width - - Raises: - AssertionError: if ``tile_name`` is invalid + points in time, 3035 is the tile height, and 2016 is the tile width """ - assert tile_name in self.tile_names - - if self.verbose: - print(f'Loading all imagery for {tile_name}') - img = torch.zeros( len(self.dates), - len(bands), + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32, ) for date_index, date in enumerate(self.dates): - img[date_index] = self._load_single_image_tile(tile_name, date, self.bands) + img[date_index] = self._load_single_image_tile(tile, date) return img @lru_cache(maxsize=128) - def _load_single_image_tile( - self, tile_name: str, date: str, bands: tuple[str, ...] - ) -> Tensor: + def _load_single_image_tile(self, tile: str, date: str) -> Tensor: """Load the imagery for a single tile for a single date. - Optionally allows for subsetting of the bands that are loaded. - Args: - tile_name: name of tile to load + tile: name of tile to load date: date of tile to load - bands: bands to load Returns: array containing a single image tile - - Raises: - AssertionError: if ``tile_name`` or ``date`` is invalid """ - assert tile_name in self.tile_names - assert date in self.dates - - if self.verbose: - print(f'Loading imagery for {tile_name} at {date}') - + directory = os.path.join(self.root, 'data', tile, date) img = torch.zeros( - len(bands), self.tile_height, self.tile_width, dtype=torch.float32 + len(self.bands), self.tile_height, self.tile_width, dtype=torch.float32 ) for band_index, band_name in enumerate(self.bands): - filepath = os.path.join( - self.root, - 'ref_african_crops_kenya_02_source', - f'{tile_name}_{date}', - f'{band_name}.tif', - ) + filepath = os.path.join(directory, f'{tile}_{band_name}_{date}.tif') with Image.open(filepath) as band_img: array: np.typing.NDArray[np.int_] = np.array(band_img) img[band_index] = torch.from_numpy(array) return img - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - images: bool = check_integrity( - os.path.join(self.root, self.image_meta['filename']), - self.image_meta['md5'] if self.checksum else None, - ) - - targets: bool = check_integrity( - os.path.join(self.root, self.target_meta['filename']), - self.target_meta['md5'] if self.checksum else None, - ) - - return images and targets - - def get_splits(self) -> tuple[list[int], list[int]]: - """Get the field_ids for the train/test splits from the dataset directory. - - Returns: - list of training field_ids and list of testing field_ids - """ - train_field_ids = [] - test_field_ids = [] - splits_fn = os.path.join( - self.root, - 'ref_african_crops_kenya_02_labels', - '_common', - 'field_train_test_ids.csv', - ) - - with open(splits_fn, newline='') as f: - reader = csv.reader(f) - - # Skip header row - next(reader) - - for row in reader: - train_field_ids.append(int(row[0])) - if row[1]: - test_field_ids.append(int(row[1])) - - return train_field_ids, test_field_ids - - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. - - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if self._check_integrity(): - print('Files already downloaded and verified') + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if os.path.exists(os.path.join(self.root, 'FieldIds.csv')): return - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() - image_archive_path = os.path.join(self.root, self.image_meta['filename']) - target_archive_path = os.path.join(self.root, self.target_meta['filename']) - for fn in [image_archive_path, target_archive_path]: - extract_archive(fn, self.root) + def _download(self) -> None: + """Download the dataset.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') def plot( self, @@ -439,13 +320,7 @@ def plot( image, mask = sample['image'], sample['mask'] - assert time_step <= image.shape[0] - 1, ( - 'The specified time step' - f' does not exist, image only contains {image.shape[0]} time' - ' instances.' - ) - - image = image[time_step, rgb_indices, :, :] + image = image[time_step, rgb_indices] fig, axs = plt.subplots(nrows=1, ncols=n_cols, figsize=(10, n_cols * 5)) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index eccca9d7314..2a21832703a 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -3,7 +3,6 @@ """Tropical Cyclone Wind Estimation Competition dataset.""" -import json import os from collections.abc import Callable from functools import lru_cache @@ -11,6 +10,7 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd import torch from matplotlib.figure import Figure from PIL import Image @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import Path, which class TropicalCyclone(NonGeoDataset): @@ -26,10 +26,9 @@ class TropicalCyclone(NonGeoDataset): A collection of tropical storms in the Atlantic and East Pacific Oceans from 2000 to 2019 with corresponding maximum sustained surface wind speed. This dataset is split - into training and test categories for the purpose of a competition. - - See https://www.drivendata.org/competitions/72/predict-wind-speeds/ for more - information about the competition. + into training and test categories for the purpose of a competition. Read more about + the competition here: + https://www.drivendata.org/competitions/72/predict-wind-speeds/. If you use this dataset in your research, please cite the following paper: @@ -39,43 +38,27 @@ class TropicalCyclone(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionchanged:: 0.4 Class name changed from TropicalCycloneWindEstimation to TropicalCyclone to be consistent with TropicalCycloneDataModule. """ - collection_id = 'nasa_tropical_storm_competition' - collection_ids = [ - 'nasa_tropical_storm_competition_train_source', - 'nasa_tropical_storm_competition_test_source', - 'nasa_tropical_storm_competition_train_labels', - 'nasa_tropical_storm_competition_test_labels', - ] - md5s = { - 'train': { - 'source': '97e913667a398704ea8d28196d91dad6', - 'labels': '97d02608b74c82ffe7496a9404a30413', - }, - 'test': { - 'source': '8d88099e4b310feb7781d776a6e1dcef', - 'labels': 'd910c430f90153c1f78a99cbc08e7bd0', - }, - } + url = ( + 'https://radiantearth.blob.core.windows.net/mlhub/nasa-tropical-storm-challenge' + ) size = 366 def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: - """Initialize a new Tropical Cyclone Wind Estimation Competition Dataset. + """Initialize a new TropicalCyclone instance. Args: root: root directory where dataset can be found @@ -83,30 +66,26 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert split in self.md5s + assert split in {'train', 'test'} self.root = root self.split = split self.transforms = transforms - self.checksum = checksum + self.download = download - if download: - self._download(api_key) + self.filename = f'{split}_set' + if split == 'train': + self.filename = f'{split}ing_set' - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self._verify() - output_dir = '_'.join([self.collection_id, split, 'source']) - filename = os.path.join(root, output_dir, 'collection.json') - with open(filename) as f: - self.collection = json.load(f)['links'] + self.features = pd.read_csv(os.path.join(root, f'{self.filename}_features.csv')) + self.labels = pd.read_csv(os.path.join(root, f'{self.filename}_labels.csv')) def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. @@ -117,15 +96,14 @@ def __getitem__(self, index: int) -> dict[str, Any]: Returns: data, labels, field ids, and metadata at that index """ - source_id = os.path.split(self.collection[index]['href'])[0] - directory = os.path.join( - self.root, - '_'.join([self.collection_id, self.split, '{0}']), - source_id.replace('source', '{0}'), - ) + sample = { + 'relative_time': torch.tensor(self.features.iat[index, 2]), + 'ocean': torch.tensor(self.features.iat[index, 3]), + 'label': torch.tensor(self.labels.iat[index, 1]), + } - sample: dict[str, Any] = {'image': self._load_image(directory)} - sample.update(self._load_features(directory)) + image_id = self.labels.iat[index, 0] + sample['image'] = self._load_image(image_id) if self.transforms is not None: sample = self.transforms(sample) @@ -138,19 +116,19 @@ def __len__(self) -> int: Returns: length of the dataset """ - return len(self.collection) + return len(self.labels) @lru_cache - def _load_image(self, directory: str) -> Tensor: + def _load_image(self, image_id: str) -> Tensor: """Load a single image. Args: - directory: directory containing image + image_id: Filename of the image. Returns: the image """ - filename = os.path.join(directory.format('source'), 'image.jpg') + filename = os.path.join(self.root, self.split, f'{image_id}.jpg') with Image.open(filename) as img: if img.height != self.size or img.width != self.size: # Moved in PIL 9.1.0 @@ -164,61 +142,30 @@ def _load_image(self, directory: str) -> Tensor: tensor = tensor.permute((2, 0, 1)).float() return tensor - def _load_features(self, directory: str) -> dict[str, Any]: - """Load features for a single image. - - Args: - directory: directory containing image - - Returns: - the features - """ - filename = os.path.join(directory.format('source'), 'features.json') - with open(filename) as f: - features: dict[str, Any] = json.load(f) - - filename = os.path.join(directory.format('labels'), 'labels.json') - with open(filename) as f: - features.update(json.load(f)) - - features['relative_time'] = int(features['relative_time']) - features['ocean'] = int(features['ocean']) - features['label'] = torch.tensor(int(features['wind_speed'])).float() - - return features - - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - for split, resources in self.md5s.items(): - for resource_type, md5 in resources.items(): - filename = '_'.join([self.collection_id, split, resource_type]) - filename = os.path.join(self.root, filename + '.tar.gz') - if not check_integrity(filename, md5 if self.checksum else None): - return False - return True - - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. - - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if self._check_integrity(): - print('Files already downloaded and verified') + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv'] + exists = [os.path.exists(os.path.join(self.root, file)) for file in files] + if all(exists): return - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) - for split, resources in self.md5s.items(): - for resource_type in resources: - filename = '_'.join([self.collection_id, split, resource_type]) - filename = os.path.join(self.root, filename) + '.tar.gz' - extract_archive(filename, self.root) + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + directory = os.path.join(self.root, self.split) + os.makedirs(directory, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', f'{self.url}/{self.split}', directory, '--recursive=true') + files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv'] + for file in files: + azcopy('copy', f'{self.url}/{file}', self.root) def plot( self, diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index a986e43d308..47ef6f15ce3 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -16,6 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -60,7 +61,7 @@ class DeepGlobeLandCover(NonGeoDataset): If you use this dataset in your research, please cite the following paper: - * https://arxiv.org/pdf/1805.06561.pdf + * https://arxiv.org/pdf/1805.06561 .. note:: @@ -73,13 +74,13 @@ class DeepGlobeLandCover(NonGeoDataset): $ unzip deepglobe2018-landcover-segmentation-traindataset.zip .. versionadded:: 0.3 - """ # noqa: E501 + """ filename = 'data.zip' data_root = 'data' md5 = 'f32684b0b2bf6f8d604cd359a399c061' - splits = ['train', 'test'] - classes = [ + splits = ('train', 'test') + classes = ( 'Urban land', 'Agriculture land', 'Rangeland', @@ -87,8 +88,8 @@ class DeepGlobeLandCover(NonGeoDataset): 'Water', 'Barren land', 'Unknown', - ] - colormap = [ + ) + colormap = ( (0, 255, 255), (255, 255, 0), (255, 0, 255), @@ -96,11 +97,11 @@ class DeepGlobeLandCover(NonGeoDataset): (0, 0, 255), (255, 255, 255), (0, 0, 0), - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, @@ -245,12 +246,15 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample['image'], sample['mask'], alpha=alpha, colors=self.colormap + sample['image'], sample['mask'], alpha=alpha, colors=list(self.colormap) ) if 'prediction' in sample: ncols += 1 image2 = draw_semantic_segmentation_masks( - sample['image'], sample['prediction'], alpha=alpha, colors=self.colormap + sample['image'], + sample['prediction'], + alpha=alpha, + colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) diff --git a/torchgeo/datasets/dfc2022.py b/torchgeo/datasets/dfc2022.py index 697ddeb0fb5..6886ea6b1d3 100644 --- a/torchgeo/datasets/dfc2022.py +++ b/torchgeo/datasets/dfc2022.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -18,7 +19,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, extract_archive, percentile_normalization +from .utils import Path, check_integrity, extract_archive, percentile_normalization class DFC2022(NonGeoDataset): @@ -42,7 +43,7 @@ class DFC2022(NonGeoDataset): * DEMs collected from the `IGN RGE ALTI database `_ * Labels collected from the - `UrbanAtlas 2012 database `_ + `UrbanAtlas 2012 database `_ * Data collected from 19 regions in France Dataset format: @@ -75,9 +76,9 @@ class DFC2022(NonGeoDataset): * https://doi.org/10.1007/s10994-020-05943-y .. versionadded:: 0.3 - """ # noqa: E501 + """ - classes = [ + classes = ( 'No information', 'Urban fabric', 'Industrial, commercial, public, military, private and transport units', @@ -94,8 +95,8 @@ class DFC2022(NonGeoDataset): 'Wetlands', 'Water', 'Clouds and Shadows', - ] - colormap = [ + ) + colormap = ( '#231F20', '#DB5F57', '#DB9757', @@ -112,8 +113,8 @@ class DFC2022(NonGeoDataset): '#579BDB', '#0062FF', '#231F20', - ] - metadata = { + ) + metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'filename': 'labeled_train.zip', 'md5': '2e87d6a218e466dd0566797d7298c7a9', @@ -137,7 +138,7 @@ class DFC2022(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, @@ -224,7 +225,7 @@ def _load_files(self) -> list[dict[str, str]]: return files - def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor: + def _load_image(self, path: Path, shape: Sequence[int] | None = None) -> Tensor: """Load a single image. Args: @@ -235,13 +236,13 @@ def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor: the image """ with rasterio.open(path) as f: - array: np.typing.NDArray[np.float_] = f.read( + array: np.typing.NDArray[np.float64] = f.read( out_shape=shape, out_dtype='float32', resampling=Resampling.bilinear ) tensor = torch.from_numpy(array) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: @@ -306,7 +307,7 @@ def plot( ncols = 2 image = sample['image'][:3] image = image.to(torch.uint8) - image = image.permute(1, 2, 0).numpy() + image_arr = image.permute(1, 2, 0).numpy() dem = sample['image'][-1].numpy() dem = percentile_normalization(dem, lower=0, upper=100, axis=(0, 1)) @@ -325,7 +326,7 @@ def plot( fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 10)) - axs[0].imshow(image) + axs[0].imshow(image_arr) axs[0].axis('off') axs[1].imshow(dem) axs[1].axis('off') diff --git a/torchgeo/datasets/digital_typhoon.py b/torchgeo/datasets/digital_typhoon.py new file mode 100644 index 00000000000..42bb4caa1bd --- /dev/null +++ b/torchgeo/datasets/digital_typhoon.py @@ -0,0 +1,458 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Digital Typhoon dataset.""" + +import glob +import os +import tarfile +from collections.abc import Callable, Sequence +from typing import Any, ClassVar, TypedDict + +import matplotlib.pyplot as plt +import pandas as pd +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_url, lazy_import, percentile_normalization + + +class _SampleSequenceDict(TypedDict): + """Sample sequence dictionary.""" + + id: str + seq_id: list[int] + + +class DigitalTyphoon(NonGeoDataset): + """Digital Typhoon Dataset for Analysis Task. + + This dataset contains typhoon-centered images, derived from hourly infrared channel + images captured by meteorological satellites. It incorporates data from multiple + generations of the Himawari weather satellite, dating back to 1978. These images + have been transformed into brightness temperatures and adjusted for varying + satellite sensor readings, yielding a consistent spatio-temporal dataset that + covers over four decades. + + See `the Digital Typhoon website + `_ + for more information about the dataset. + + Dataset features: + + * infrared channel images from the Himawari weather satellite (512x512 px) + at 5km spatial resolution + * auxiliary features such as wind speed, pressure, and more that can be used + for regression or classification tasks + * 1,099 typhoons and 189,364 images + + Dataset format: + + * hdf5 files containing the infrared channel images + * .csv files containing the metadata for each image + + If you use this dataset in your research, please cite the following papers: + + * https://doi.org/10.20783/DIAS.664 + + .. versionadded:: 0.6 + """ + + valid_tasks = ('classification', 'regression') + aux_file_name = 'aux_data.csv' + + valid_features = ( + 'year', + 'month', + 'day', + 'hour', + 'grade', + 'lat', + 'lng', + 'pressure', + 'wind', + 'dir50', + 'long50', + 'short50', + 'dir30', + 'long30', + 'short30', + 'landfall', + 'intp', + ) + + url = 'https://hf.co/datasets/torchgeo/digital_typhoon/resolve/cf2f9ef89168d31cb09e42993d35b068688fe0df/WP.tar.gz{0}' + + md5sums: ClassVar[dict[str, str]] = { + 'aa': '3af98052aed17e0ddb1e94caca2582e2', + 'ab': '2c5d25455ac8aef1de33fe6456ab2c8d', + } + + min_input_clamp = 170.0 + max_input_clamp = 300.0 + + data_root = 'WP' + + def __init__( + self, + root: Path = 'data', + task: str = 'regression', + features: Sequence[str] = ['wind'], + targets: Sequence[str] = ['wind'], + sequence_length: int = 3, + min_feature_value: dict[str, float] | None = None, + max_feature_value: dict[str, float] | None = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Digital Typhoon dataset instance. + + Args: + root: root directory where dataset can be found + task: whether to load 'regression' or 'classification' labels + features: which auxiliary features to return + targets: which auxiliary features to use as targets + sequence_length: length of the sequence to return + min_feature_value: minimum value for each feature + max_feature_value: maximum value for each feature + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If any arguments are invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + DependencyNotFoundError: If h5py is not installed. + """ + lazy_import('h5py') + self.root = root + self.transforms = transforms + self.download = download + self.checksum = checksum + self.sequence_length = sequence_length + + self.min_feature_value = min_feature_value + self.max_feature_value = max_feature_value + + assert ( + task in self.valid_tasks + ), f'Please choose one of {self.valid_tasks}, you provided {task}.' + self.task = task + + assert set(features).issubset(set(self.valid_features)) + self.features = features + + assert set(targets).issubset(set(self.valid_features)) + self.targets = targets + + self._verify() + + self.aux_df = pd.read_csv( + os.path.join(root, self.data_root, self.aux_file_name) + ) + self.aux_df['datetime'] = pd.to_datetime( + self.aux_df[['year', 'month', 'day', 'hour']] + ) + + self.aux_df = self.aux_df.sort_values(['year', 'month', 'day', 'hour']) + self.aux_df['seq_id'] = self.aux_df.groupby(['id']).cumcount() + + self.aux_df.columns = [str(col) for col in self.aux_df.columns] + + # Compute the hour difference between consecutive images per typhoon id + self.aux_df['hour_diff_consecutive'] = ( + self.aux_df.sort_values(['id', 'datetime']) + .groupby('id')['datetime'] + .diff() + .dt.total_seconds() + / 3600 + ) + + # Compute the hour difference between the first and second entry + self.aux_df['hour_diff_to_next'] = ( + self.aux_df.groupby('id')['datetime'] + .shift(-1) + .sub(self.aux_df['datetime']) + .abs() + .dt.total_seconds() + / 3600 + ) + + self.aux_df['hour_diff'] = self.aux_df['hour_diff_consecutive'].combine_first( + self.aux_df['hour_diff_to_next'] + ) + self.aux_df.drop( + ['hour_diff_consecutive', 'hour_diff_to_next'], axis=1, inplace=True + ) + + # 0 hour difference is for the last time step of each typhoon sequence and want + # to keep only images that have max 1 hour difference + self.aux_df = self.aux_df[self.aux_df['hour_diff'] <= 1] + # Filter out all ids that only have less than sequence_length entries + self.aux_df = self.aux_df.groupby('id').filter( + lambda x: len(x) >= self.sequence_length + ) + + # Filter aux_df according to min_target_value + if self.min_feature_value is not None: + for feature, min_value in self.min_feature_value.items(): + self.aux_df = self.aux_df[self.aux_df[feature] >= min_value] + + # Filter aux_df according to max_target_value + if self.max_feature_value is not None: + for feature, max_value in self.max_feature_value.items(): + self.aux_df = self.aux_df[self.aux_df[feature] <= max_value] + + # collect target mean and std for each target + self.target_mean: dict[str, float] = self.aux_df[self.targets].mean().to_dict() + self.target_std: dict[str, float] = self.aux_df[self.targets].std().to_dict() + + def _get_subsequences(df: pd.DataFrame, k: int) -> list[dict[str, list[int]]]: + """Generate all possible subsequences of length k for a given group. + + Args: + df: grouped dataframe of a single typhoon + k: length of the subsequences to generate + + Returns: + list of all possible subsequences of length k for a given typhoon id + """ + min_seq_id = df['seq_id'].min() + max_seq_id = df['seq_id'].max() + + # generate possible subsquences of length k for group + subsequences = [ + {'id': df['id'].iloc[0], 'seq_id': list(range(i, i + k))} + for i in range(min_seq_id, max_seq_id - k + 2) + ] + return [ + subseq + for subseq in subsequences + if set(subseq['seq_id']).issubset(df['seq_id']) + ] + + self.sample_sequences: list[_SampleSequenceDict] = [ + item + for sublist in self.aux_df.groupby('id')[['seq_id', 'id']] + .apply(_get_subsequences, k=self.sequence_length) + .tolist() + for item in sublist + ] + + def __getitem__(self, index: int) -> dict[str, Any]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data, labels, and metadata at that index + """ + sample_entry = self.sample_sequences[index] + sample_df = self.aux_df[ + (self.aux_df['id'] == sample_entry['id']) + & (self.aux_df['seq_id'].isin(sample_entry['seq_id'])) + ] + + sample = {'image': self._load_image(sample_df)} + # load features of the last image in the sequence + sample.update( + self._load_features( + os.path.join( + self.root, + self.data_root, + 'metadata', + str(sample_df.iloc[-1]['id']) + '.csv', + ), + sample_df.iloc[-1]['image_path'], + ) + ) + + # torchgeo expects a single label + sample['label'] = torch.Tensor([sample[target] for target in self.targets]) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.sample_sequences) + + def _load_image(self, sample_df: pd.DataFrame) -> Tensor: + """Load a single image. + + Args: + sample_df: df holding all information necessary to load the + consecutive images in the sequence + + Returns: + concatenation of all images in the sequence over channel dimension + """ + + def load_image_tensor(id: str, filepath: str) -> Tensor: + """Load a single image tensor from a h5 file. + + Args: + id: typhoon id + filepath: path to the h5 file + + Returns: + image tensor + """ + h5py = lazy_import('h5py') + + full_path = os.path.join(self.root, self.data_root, 'image', id, filepath) + with h5py.File(full_path, 'r') as h5f: + # tensor with added channel dimension + tensor = torch.from_numpy(h5f['Infrared'][:]).unsqueeze(0) + + # follow normalization procedure + # https://github.com/kitamoto-lab/benchmarks/blob/1bdbefd7c570cb1bdbdf9e09f9b63f7c22bbdb27/analysis/regression/FrameDatamodule.py#L94 + tensor = torch.clamp(tensor, self.min_input_clamp, self.max_input_clamp) + tensor = (tensor - self.min_input_clamp) / ( + self.max_input_clamp - self.min_input_clamp + ) + return tensor + + # tensor of shape [sequence_length, height, width] + tensor = torch.cat( + [ + load_image_tensor(str(id), filepath) + for id, filepath in zip(sample_df['id'], sample_df['image_path']) + ] + ).float() + return tensor + + def _load_features(self, filepath: str, image_path: str) -> dict[str, Any]: + """Load features for the corresponding image. + + Args: + filepath: path of the feature file to load + image_path: image path for the unique image for which to retrieve features + + Returns: + features for image + """ + feature_df = pd.read_csv(filepath) + feature_df = feature_df[feature_df['file_1'] == image_path] + feature_dict = { + name: torch.tensor(feature_df[name].item()).float() + for name in self.features + } + # normalize the targets for regression + if self.task == 'regression': + for feature, mean in self.target_mean.items(): + feature_dict[feature] = ( + feature_dict[feature] - mean + ) / self.target_std[feature] + return feature_dict + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the extracted files already exist + exists = [] + path = os.path.join(self.root, self.data_root, 'image', '*', '*.h5') + if glob.glob(path): + exists.append(True) + else: + exists.append(False) + + # check if aux.csv file exists + exists.append( + os.path.exists(os.path.join(self.root, self.data_root, self.aux_file_name)) + ) + if all(exists): + return + + # Check if the tar.gz files have already been downloaded + exists = [] + for suffix in self.md5sums.keys(): + path = os.path.join(self.root, f'{self.data_root}.tar.gz{suffix}') + exists.append(os.path.exists(path)) + + if all(exists): + self._extract() + return + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download amd extract the dataset + self._download() + self._extract() + + def _download(self) -> None: + """Download the dataset.""" + for suffix, md5 in self.md5sums.items(): + download_url( + self.url.format(suffix), self.root, md5=md5 if self.checksum else None + ) + + def _extract(self) -> None: + """Extract the dataset.""" + # Extract tarball + for suffix in self.md5sums.keys(): + with tarfile.open( + os.path.join(self.root, f'{self.data_root}.tar.gz{suffix}') + ) as tar: + tar.extractall(path=self.root) + + def plot( + self, + sample: dict[str, Any], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample return by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + image, label = sample['image'], sample['label'] + + image = percentile_normalization(image) + + showing_predictions = 'prediction' in sample + if showing_predictions: + prediction = sample['prediction'] + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + + ax.imshow(image.permute(1, 2, 0)) + ax.axis('off') + + if show_titles: + title_dict = { + label_name: label[idx].item() + for idx, label_name in enumerate(self.targets) + } + title = f'Label: {title_dict}' + if showing_predictions: + title_dict = { + label_name: prediction[idx].item() + for idx, label_name in enumerate(self.targets) + } + title += f'\nPrediction: {title_dict}' + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/eddmaps.py b/torchgeo/datasets/eddmaps.py index f30b75bcbeb..d3a046993a1 100644 --- a/torchgeo/datasets/eddmaps.py +++ b/torchgeo/datasets/eddmaps.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, disambiguate_timestamp +from .utils import BoundingBox, Path, disambiguate_timestamp class EDDMapS(GeoDataset): @@ -42,7 +42,7 @@ class EDDMapS(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = 'data') -> None: + def __init__(self, root: Path = 'data') -> None: """Initialize a new Dataset instance. Args: @@ -100,6 +100,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {'crs': self.crs, 'bbox': bboxes} + sample = {'crs': self.crs, 'bounds': bboxes} return sample diff --git a/torchgeo/datasets/enviroatlas.py b/torchgeo/datasets/enviroatlas.py index 0ca0a9bafe3..b8af4aef70e 100644 --- a/torchgeo/datasets/enviroatlas.py +++ b/torchgeo/datasets/enviroatlas.py @@ -6,7 +6,7 @@ import os import sys from collections.abc import Callable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import fiona import matplotlib.pyplot as plt @@ -23,7 +23,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive class EnviroAtlas(GeoDataset): @@ -47,16 +47,16 @@ class EnviroAtlas(GeoDataset): .. versionadded:: 0.3 """ - url = 'https://zenodo.org/record/5778193/files/enviroatlas_lotp.zip?download=1' + url = 'https://zenodo.org/records/5778193/files/enviroatlas_lotp.zip?download=1' filename = 'enviroatlas_lotp.zip' md5 = 'bfe601be21c7c001315fc6154be8ef14' crs = CRS.from_epsg(3857) res = 1 - valid_prior_layers = ['prior', 'prior_no_osm_no_buildings'] + valid_prior_layers = ('prior', 'prior_no_osm_no_buildings') - valid_layers = [ + valid_layers = ( 'naip', 'nlcd', 'roads', @@ -65,14 +65,15 @@ class EnviroAtlas(GeoDataset): 'waterbodies', 'buildings', 'lc', - ] + valid_prior_layers + *valid_prior_layers, + ) - cities = [ + cities = ( 'pittsburgh_pa-2010_1m', 'durham_nc-2012_1m', 'austin_tx-2012_1m', 'phoenix_az-2010_1m', - ] + ) splits = ( [f'{state}-train' for state in cities[:1]] + [f'{state}-val' for state in cities[:1]] @@ -81,7 +82,7 @@ class EnviroAtlas(GeoDataset): ) # these are used to check the integrity of the dataset - _files = [ + _files = ( 'austin_tx-2012_1m-test_tiles-debuffered', 'austin_tx-2012_1m-val5_tiles-debuffered', 'durham_nc-2012_1m-test_tiles-debuffered', @@ -100,13 +101,13 @@ class EnviroAtlas(GeoDataset): 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_d_water.tif', 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_e_buildings.tif', 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_h_highres_labels.tif', - 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif', # noqa: E501 - 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', # noqa: E501 + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31.tif', + 'austin_tx-2012_1m-test_tiles-debuffered/3009726_sw_prior_from_cooccurrences_101_31_no_osm_no_buildings.tif', 'spatial_index.geojson', - ] + ) p_src_crs = pyproj.CRS('epsg:3857') - p_transformers = { + p_transformers: ClassVar[dict[str, CRS]] = { 'epsg:26917': pyproj.Transformer.from_crs( p_src_crs, pyproj.CRS('epsg:26917'), always_xy=True ).transform, @@ -222,7 +223,7 @@ class EnviroAtlas(GeoDataset): dtype=np.uint8, ) - highres_classes = [ + highres_classes = ( 'Unclassified', 'Water', 'Impervious Surface', @@ -234,7 +235,7 @@ class EnviroAtlas(GeoDataset): 'Orchards', 'Woody Wetlands', 'Emergent Wetlands', - ] + ) highres_cmap = ListedColormap( [ [1.00000000, 1.00000000, 1.00000000], @@ -253,7 +254,7 @@ class EnviroAtlas(GeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', splits: Sequence[str] = ['pittsburgh_pa-2010_1m-train'], layers: Sequence[str] = ['naip', 'prior'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -347,7 +348,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: hits = self.index.intersection(tuple(query), objects=True) filepaths = cast(list[dict[str, str]], [hit.object for hit in hits]) - sample = {'image': [], 'mask': [], 'crs': self.crs, 'bbox': query} + sample = {'image': [], 'mask': [], 'crs': self.crs, 'bounds': query} if len(filepaths) == 0: raise IndexError( @@ -414,7 +415,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" - def exists(filename: str) -> bool: + def exists(filename: Path) -> bool: return os.path.exists(os.path.join(self.root, 'enviroatlas_lotp', filename)) # Check if the extracted files already exist diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py index 04157d28bab..197272d6b48 100644 --- a/torchgeo/datasets/esri2020.py +++ b/torchgeo/datasets/esri2020.py @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class Esri2020(RasterDataset): @@ -41,7 +41,7 @@ class Esri2020(RasterDataset): 9. Snow/Ice 10. Clouds - A more detailed explanation of the invidual classes can be found + A more detailed explanation of the individual classes can be found `here `_. If you use this dataset please cite the following paper: @@ -69,7 +69,7 @@ class Esri2020(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -112,7 +112,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile) if glob.glob(pathname): self._extract() @@ -132,7 +132,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) extract_archive(os.path.join(self.paths, self.zipfile)) def plot( diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 44ab7007f9f..7855c8bb3cf 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -16,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import Path, download_and_extract_archive class ETCI2021(NonGeoDataset): @@ -56,9 +57,9 @@ class ETCI2021(NonGeoDataset): the ETCI competition. """ - bands = ['VV', 'VH'] - masks = ['flood', 'water_body'] - metadata = { + bands = ('VV', 'VH') + masks = ('flood', 'water_body') + metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'filename': 'train.zip', 'md5': '1e95792fe0f6e3c9000abdeab2a8ab0f', @@ -81,7 +82,7 @@ class ETCI2021(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -152,7 +153,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -193,7 +194,7 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -210,7 +211,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: diff --git a/torchgeo/datasets/eudem.py b/torchgeo/datasets/eudem.py index 9dc431ec1f6..5a9af7f6fa3 100644 --- a/torchgeo/datasets/eudem.py +++ b/torchgeo/datasets/eudem.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -14,19 +14,14 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import check_integrity, extract_archive +from .utils import Path, check_integrity, extract_archive class EUDEM(RasterDataset): """European Digital Elevation Model (EU-DEM) Dataset. - The `EU-DEM - `__ - dataset is a Digital Elevation Model of reference for the entire European region. - The dataset can be downloaded from this `website - `_ - after making an account. A dataset factsheet is available - `here `__. + `EU-DEM `__ + is a Digital Elevation Model of reference for the entire European region. Dataset features: @@ -40,10 +35,6 @@ class EUDEM(RasterDataset): * DEMs are single-channel tif files - If you use this dataset in your research, please give credit to: - - * `Copernicus `_ - .. versionadded:: 0.3 """ @@ -52,7 +43,7 @@ class EUDEM(RasterDataset): zipfile_glob = 'eu_dem_v11_*[A-Z0-9].zip' filename_regex = '(?P[eudem_v11]{10})_(?P[A-Z0-9]{6})' - md5s = { + md5s: ClassVar[dict[str, str]] = { 'eu_dem_v11_E00N20.zip': '96edc7e11bc299b994e848050d6be591', 'eu_dem_v11_E10N00.zip': 'e14be147ac83eddf655f4833d55c1571', 'eu_dem_v11_E10N10.zip': '2eb5187e4d827245b33768404529c709', @@ -84,7 +75,7 @@ class EUDEM(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -125,7 +116,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile_glob) if glob.glob(pathname): for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/eurocrops.py b/torchgeo/datasets/eurocrops.py index daa1987e3c8..5f438143c87 100644 --- a/torchgeo/datasets/eurocrops.py +++ b/torchgeo/datasets/eurocrops.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import check_integrity, download_and_extract_archive, download_url +from .utils import Path, check_integrity, download_and_extract_archive, download_url class EuroCrops(VectorDataset): @@ -30,8 +30,9 @@ class EuroCrops(VectorDataset): is tagged with a "EC_hcat_n" attribute indicating the harmonized crop name grown within the polygon in the year associated with the shapefile. - If you use this dataset in your research, please follow the citation guidelines at - https://github.com/maja601/EuroCrops#reference. + If you use this dataset in your research, please follow the citation guidelines at: + + * https://github.com/maja601/EuroCrops#reference. .. versionadded:: 0.6 """ @@ -60,7 +61,7 @@ class EuroCrops(VectorDataset): date_format = '%Y' # Filename and md5 of files in this dataset on zenodo. - zenodo_files = [ + zenodo_files: tuple[tuple[str, str], ...] = ( ('AT_2021.zip', '490241df2e3d62812e572049fc0c36c5'), ('BE_VLG_2021.zip', 'ac4b9e12ad39b1cba47fdff1a786c2d7'), ('DE_LS_2021.zip', '6d94e663a3ff7988b32cb36ea24a724f'), @@ -80,11 +81,11 @@ class EuroCrops(VectorDataset): # Year is unknown for Romania portion (ny = no year). # We skip since it is inconsistent with the rest of the data. # ("RO_ny.zip", "648e1504097765b4b7f825decc838882"), - ] + ) def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS = CRS.from_epsg(4326), res: float = 0.00001, classes: list[str] | None = None, @@ -138,7 +139,7 @@ def _check_integrity(self) -> bool: if self.files and not self.checksum: return True - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) filepath = os.path.join(self.paths, self.hcat_fname) if not check_integrity(filepath, self.hcat_md5 if self.checksum else None): @@ -155,7 +156,7 @@ def _download(self) -> None: if self._check_integrity(): print('Files already downloaded and verified') return - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) download_url( self.base_url + self.hcat_fname, self.paths, @@ -177,7 +178,7 @@ def _load_class_map(self, classes: list[str] | None) -> None: (defaults to all classes) """ if not classes: - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) classes = [] filepath = os.path.join(self.paths, self.hcat_fname) with open(filepath) as f: @@ -243,10 +244,12 @@ def plot( fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(4, 4)) - def apply_cmap(arr: 'np.typing.NDArray[Any]') -> 'np.typing.NDArray[np.float_]': + def apply_cmap( + arr: 'np.typing.NDArray[Any]', + ) -> 'np.typing.NDArray[np.float64]': # Color 0 as black, while applying default color map for the class indices. cmap = plt.get_cmap('viridis') - im: np.typing.NDArray[np.float_] = cmap(arr / len(self.class_map)) + im: np.typing.NDArray[np.float64] = cmap(arr / len(self.class_map)) im[arr == 0] = 0 return im diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 9e6bc5a8909..5caef0ee4ee 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable, Sequence -from typing import cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoClassificationDataset -from .utils import check_integrity, download_url, extract_archive, rasterio_loader +from .utils import Path, check_integrity, download_url, extract_archive, rasterio_loader class EuroSAT(NonGeoClassificationDataset): @@ -41,7 +41,7 @@ class EuroSAT(NonGeoClassificationDataset): * Permanent Crop * Residential Buildings * River - * SeaLake + * Sea & Lake This dataset uses the train/val/test splits defined in the "In-domain representation learning for remote sensing" paper: @@ -54,7 +54,7 @@ class EuroSAT(NonGeoClassificationDataset): * https://ieeexplore.ieee.org/document/8519248 """ - url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSATallBands.zip' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/1ce6f1bfb56db63fd91b6ecc466ea67f2509774c/' filename = 'EuroSATallBands.zip' md5 = '5ac12b3b2557aa56e1826e981e8e200e' @@ -63,13 +63,13 @@ class EuroSAT(NonGeoClassificationDataset): 'ds', 'images', 'remote_sensing', 'otherDatasets', 'sentinel_2', 'tif' ) - splits = ['train', 'val', 'test'] - split_urls = { - 'train': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-train.txt', # noqa: E501 - 'val': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-val.txt', # noqa: E501 - 'test': 'https://storage.googleapis.com/remote_sensing_representations/eurosat-test.txt', # noqa: E501 + splits = ('train', 'val', 'test') + split_filenames: ClassVar[dict[str, str]] = { + 'train': 'eurosat-train.txt', + 'val': 'eurosat-val.txt', + 'test': 'eurosat-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': '908f142e73d6acdf3f482c5e80d851b1', 'val': '95de90f2aa998f70a3b2416bfe0687b4', 'test': '7ae5ab94471417b6e315763121e67c5f', @@ -84,20 +84,23 @@ class EuroSAT(NonGeoClassificationDataset): 'B06', 'B07', 'B08', - 'B8A', 'B09', 'B10', 'B11', 'B12', + 'B8A', ) rgb_bands = ('B04', 'B03', 'B02') - BAND_SETS = {'all': all_band_names, 'rgb': rgb_bands} + BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = { + 'all': all_band_names, + 'rgb': rgb_bands, + } def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -138,11 +141,11 @@ def __init__( self._verify() valid_fns = set() - with open(os.path.join(self.root, f'eurosat-{split}.txt')) as f: + with open(os.path.join(self.root, self.split_filenames[split])) as f: for fn in f: valid_fns.add(fn.strip().replace('.jpg', '.tif')) - def is_in_split(x: str) -> bool: + def is_in_split(x: Path) -> bool: return os.path.basename(x) in valid_fns super().__init__( @@ -204,16 +207,12 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" download_url( - self.url, - self.root, - filename=self.filename, - md5=self.md5 if self.checksum else None, + self.url + self.filename, self.root, md5=self.md5 if self.checksum else None ) for split in self.splits: download_url( - self.split_urls[split], + self.url + self.split_filenames[split], self.root, - filename=f'eurosat-{split}.txt', md5=self.split_md5s[split] if self.checksum else None, ) @@ -302,12 +301,12 @@ class EuroSATSpatial(EuroSAT): .. versionadded:: 0.6 """ - split_urls = { - 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-train.txt', - 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-val.txt', - 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/1c11c73a87b40b0485d103231a97829991b8e22f/eurosat-spatial-test.txt', + split_filenames: ClassVar[dict[str, str]] = { + 'train': 'eurosat-spatial-train.txt', + 'val': 'eurosat-spatial-val.txt', + 'test': 'eurosat-spatial-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': '7be3254be39f23ce4d4d144290c93292', 'val': 'acf392290050bb3df790dc8fc0ebf193', 'test': '5ec1733f9c16116bf0aa2d921fc613ef', @@ -325,16 +324,15 @@ class EuroSAT100(EuroSAT): .. versionadded:: 0.5 """ - url = 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/EuroSAT100.zip' # noqa: E501 filename = 'EuroSAT100.zip' md5 = 'c21c649ba747e86eda813407ef17d596' - split_urls = { - 'train': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-train.txt', # noqa: E501 - 'val': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-val.txt', # noqa: E501 - 'test': 'https://hf.co/datasets/torchgeo/eurosat/resolve/06fd1b090bceecc0ce724cd21578ba7a6664fe8d/eurosat-test.txt', # noqa: E501 + split_filenames: ClassVar[dict[str, str]] = { + 'train': 'eurosat-100-train.txt', + 'val': 'eurosat-100-val.txt', + 'test': 'eurosat-100-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': '033d0c23e3a75e3fa79618b0e35fe1c7', 'val': '3e3f8b3c344182b8d126c4cc88f3f215', 'test': 'f908f151b950f270ad18e61153579794', diff --git a/torchgeo/datasets/fair1m.py b/torchgeo/datasets/fair1m.py index e3476c97128..d58968eaa19 100644 --- a/torchgeo/datasets/fair1m.py +++ b/torchgeo/datasets/fair1m.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import Any, cast +from typing import Any, ClassVar, cast from xml.etree.ElementTree import Element, parse import matplotlib.patches as patches @@ -19,10 +19,10 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import Path, check_integrity, download_url, extract_archive -def parse_pascal_voc(path: str) -> dict[str, Any]: +def parse_pascal_voc(path: Path) -> dict[str, Any]: """Read a PASCAL VOC annotation file. Args: @@ -119,7 +119,7 @@ class FAIR1M(NonGeoDataset): .. versionadded:: 0.2 """ - classes = { + classes: ClassVar[dict[str, dict[str, Any]]] = { 'Passenger Ship': {'id': 0, 'category': 'Ship'}, 'Motorboat': {'id': 1, 'category': 'Ship'}, 'Fishing Boat': {'id': 2, 'category': 'Ship'}, @@ -159,12 +159,12 @@ class FAIR1M(NonGeoDataset): 'Bridge': {'id': 36, 'category': 'Road'}, } - filename_glob = { + filename_glob: ClassVar[dict[str, str]] = { 'train': os.path.join('train', '**', 'images', '*.tif'), 'val': os.path.join('validation', 'images', '*.tif'), 'test': os.path.join('test', 'images', '*.tif'), } - directories = { + directories: ClassVar[dict[str, tuple[str, ...]]] = { 'train': ( os.path.join('train', 'part1', 'images'), os.path.join('train', 'part1', 'labelXml'), @@ -175,9 +175,9 @@ class FAIR1M(NonGeoDataset): os.path.join('validation', 'images'), os.path.join('validation', 'labelXml'), ), - 'test': (os.path.join('test', 'images')), + 'test': (os.path.join('test', 'images'),), } - paths = { + paths: ClassVar[dict[str, tuple[str, ...]]] = { 'train': ( os.path.join('train', 'part1', 'images.zip'), os.path.join('train', 'part1', 'labelXml.zip'), @@ -194,7 +194,7 @@ class FAIR1M(NonGeoDataset): os.path.join('test', 'images2.zip'), ), } - urls = { + urls: ClassVar[dict[str, tuple[str, ...]]] = { 'train': ( 'https://drive.google.com/file/d/1LWT_ybL-s88Lzg9A9wHpj0h2rJHrqrVf', 'https://drive.google.com/file/d/1CnOuS8oX6T9JMqQnfFsbmf7U38G6Vc8u', @@ -211,7 +211,7 @@ class FAIR1M(NonGeoDataset): 'https://drive.google.com/file/d/1oUc25FVf8Zcp4pzJ31A1j1sOLNHu63P0', ), } - md5s = { + md5s: ClassVar[dict[str, tuple[str, ...]]] = { 'train': ( 'a460fe6b1b5b276bf856ce9ac72d6568', '80f833ff355f91445c92a0c0c1fa7414', @@ -230,7 +230,7 @@ class FAIR1M(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -279,7 +279,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: sample = {'image': image} if self.split != 'test': - label_path = path.replace(self.image_root, self.label_root) + label_path = str(path).replace(self.image_root, self.label_root) label_path = label_path.replace('.tif', '.xml') voc = parse_pascal_voc(label_path) boxes, labels = self._load_target(voc['points'], voc['labels']) @@ -298,7 +298,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: diff --git a/torchgeo/datasets/fire_risk.py b/torchgeo/datasets/fire_risk.py index 9a5033c8aab..9370488f503 100644 --- a/torchgeo/datasets/fire_risk.py +++ b/torchgeo/datasets/fire_risk.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class FireRisk(NonGeoClassificationDataset): @@ -55,8 +55,8 @@ class FireRisk(NonGeoClassificationDataset): md5 = 'a77b9a100d51167992ae8c51d26198a6' filename = 'FireRisk.zip' directory = 'FireRisk' - splits = ['train', 'val'] - classes = [ + splits = ('train', 'val') + classes = ( 'High', 'Low', 'Moderate', @@ -64,11 +64,11 @@ class FireRisk(NonGeoClassificationDataset): 'Very_High', 'Very_Low', 'Water', - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/forestdamage.py b/torchgeo/datasets/forestdamage.py index 1cbae17f961..9c3de28a2b5 100644 --- a/torchgeo/datasets/forestdamage.py +++ b/torchgeo/datasets/forestdamage.py @@ -19,10 +19,10 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, extract_archive +from .utils import Path, check_integrity, download_and_extract_archive, extract_archive -def parse_pascal_voc(path: str) -> dict[str, Any]: +def parse_pascal_voc(path: Path) -> dict[str, Any]: """Read a PASCAL VOC annotation file. Args: @@ -96,17 +96,14 @@ class ForestDamage(NonGeoDataset): .. versionadded:: 0.3 """ - classes = ['other', 'H', 'LD', 'HD'] - url = ( - 'https://lilablobssc.blob.core.windows.net/larch-casebearer/' - 'Data_Set_Larch_Casebearer.zip' - ) + classes = ('other', 'H', 'LD', 'HD') + url = 'https://lilablobssc.blob.core.windows.net/larch-casebearer/Data_Set_Larch_Casebearer.zip' data_dir = 'Data_Set_Larch_Casebearer' md5 = '907815bcc739bff89496fac8f8ce63d7' def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -164,7 +161,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -187,7 +184,7 @@ def _load_files(self, root: str) -> list[dict[str, str]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -209,7 +206,7 @@ def _load_target( """Load the target mask for a single image. Args: - bboxes: list of bbox coordinats [xmin, ymin, xmax, ymax] + bboxes: list of bbox coordinates [xmin, ymin, xmax, ymax] labels_list: list of class labels Returns: @@ -220,11 +217,7 @@ def _load_target( return boxes, labels def _verify(self) -> None: - """Checks the integrity of the dataset structure. - - Returns: - True if the dataset directories are found, else False - """ + """Verify the integrity of the dataset.""" filepath = os.path.join(self.root, self.data_dir) if os.path.isdir(filepath): return diff --git a/torchgeo/datasets/ftw.py b/torchgeo/datasets/ftw.py new file mode 100644 index 00000000000..7d4d92273d8 --- /dev/null +++ b/torchgeo/datasets/ftw.py @@ -0,0 +1,362 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Fields Of The World dataset.""" + +import os +from collections.abc import Callable, Sequence +from typing import ClassVar + +import einops +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import rasterio +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, array_to_tensor, download_and_extract_archive, extract_archive + + +class FieldsOfTheWorld(NonGeoDataset): + """Fields Of The World dataset. + + The `Fields Of The World `__ + datataset is a semantic and instance segmentation dataset for delineating field + boundaries. + + Dataset features: + + * 70462 patches across 24 countries + * Each country has a train, val, and test split + * Semantic segmentations masks with and without the field boundary class + * Instance segmentation masks + + Dataset format: + + * images are four-channel GeoTIFFs with dimension 256x256 + * segmentation masks (both two and three class) are single-channel GeoTIFFs + * instance masks are single-channel GeoTIFFs + + Dataset classes: + + 1. background + 2. field + 3. field-boundary (three-class only) + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.48550/arXiv.2409.16252 + + .. versionadded:: 0.7 + """ + + splits = ('train', 'val', 'test') + targets = ('2-class', '3-class', 'instance') + + valid_countries = ( + 'austria', + 'belgium', + 'brazil', + 'cambodia', + 'corsica', + 'croatia', + 'denmark', + 'estonia', + 'finland', + 'france', + 'germany', + 'india', + 'kenya', + 'latvia', + 'lithuania', + 'luxembourg', + 'netherlands', + 'portugal', + 'rwanda', + 'slovakia', + 'slovenia', + 'south_africa', + 'spain', + 'sweden', + 'vietnam', + ) + + base_url = 'https://data.source.coop/kerner-lab/fields-of-the-world-archive/' + + country_to_md5: ClassVar[dict[str, str]] = { + 'austria': '35604e3e3e78b4469e443bc756e19d26', + 'belgium': '111a9048e15391c947bc778e576e99b4', + 'brazil': '2ba96f9f01f37ead1435406c3f2b7c63', + 'cambodia': '581e9b8dae9713e4d03459bcec3c0bd0', + 'corsica': '0b38846063a98a31747fdeaf1ba03980', + 'croatia': 'dc5d33e19ae9e587c97f8f4c9852c87e', + 'denmark': 'ec817210b06351668cacdbd1a8fb9471', + 'estonia': 'b9c89e559e3c7d53a724e7f32ccf88ea', + 'finland': '23f853d6cbaea5a3596d1d38cc27fd65', + 'france': 'f05314f148642ff72d8bea903c01802d', + 'germany': 'd57a7ed203b9cf89c709aab29d687cee', + 'india': '361a688507e2e5cc7ca7138be01a5b80', + 'kenya': '80ca0335b25440379f99b7011dfbdfa2', + 'latvia': '6eeaaa57cdf18f25497f84e854a86d42', + 'lithuania': '0a2f4ab3309633e2de121d936e0763ba', + 'luxembourg': '5a8357eae364cca836b87827b3c6a3d3', + 'netherlands': '3afc61d184aab5c4fd6beaecf2b6c0a9', + 'portugal': '10485b747e1d8c082d33c73d032a7e05', + 'rwanda': '087ce56bbf06b32571ef27ff67bac43b', + 'slovakia': 'f66a0294491086d4c49dc4a804446e50', + 'slovenia': '6fa3ae3920bcc2c890a0d74435d9d29b', + 'south_africa': 'b7f1412d69922e8551cf91081401ec8d', + 'spain': '908bbf29597077c2c6954c439fe8265f', + 'sweden': '4b07726c421981bb2019e8900023393e', + 'vietnam': '32e1cacebcb2da656d40ab8522eb6737', + } + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + target: str = '2-class', + countries: str | Sequence[str] = ['austria'], + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Fields Of The World dataset instance. + + Args: + root: root directory where dataset can be found + split: one of "train", "val", or "test" + target: one of "2-class", "3-class", or "instance" specifying which kind of + target mask to load + countries: which set of countries to load data from + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If any arguments are invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert split in self.splits + assert target in self.targets + if isinstance(countries, str): + countries = [countries] + assert set(countries) <= set(self.valid_countries) + + self.root = root + self.split = split + self.target = target + self.countries = countries + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + self.files = self._load_files() + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + image and mask at that index with image of dimension 3x1024x1024 + and mask of dimension 1024x1024 + """ + win_a_fn = self.files[index]['win_a'] + win_b_fn = self.files[index]['win_b'] + mask_fn = self.files[index]['mask'] + + win_a = self._load_image(win_a_fn) + win_b = self._load_image(win_b_fn) + mask = self._load_target(mask_fn) + + image = torch.cat((win_a, win_b), dim=0) + sample = {'image': image, 'mask': mask} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of datapoints in the dataset. + + Returns: + length of dataset + """ + return len(self.files) + + def _load_files(self) -> list[dict[str, str]]: + """Return the paths of the files in the dataset. + + Returns: + a dictionary with "win_a", "win_b", and "mask" keys containing lists of + file paths + """ + files = [] + for country in self.countries: + df = pd.read_parquet( + os.path.join(self.root, country, f'chips_{country}.parquet') + ) + aois = df[df['split'] == self.split]['aoi_id'].values + + for aoi in aois: + if self.target == 'instance': + subdir = 'instance' + elif self.target == '2-class': + subdir = 'semantic_2class' + elif self.target == '3-class': + subdir = 'semantic_3class' + + win_a_fn = os.path.join( + self.root, country, 's2_images', 'window_a', f'{aoi}.tif' + ) + win_b_fn = os.path.join( + self.root, country, 's2_images', 'window_b', f'{aoi}.tif' + ) + + # there are 333 AOIs that are missing imagery across the dataset + if not (os.path.exists(win_a_fn) and os.path.exists(win_b_fn)): + continue + + sample = { + 'win_a': win_a_fn, + 'win_b': win_b_fn, + 'mask': os.path.join( + self.root, country, 'label_masks', subdir, f'{aoi}.tif' + ), + } + files.append(sample) + + return files + + def _load_image(self, path: Path) -> Tensor: + """Load a single image. + + Args: + path: path to the image + + Returns: + the loaded image + """ + filename = os.path.join(path) + with rasterio.open(filename) as f: + array: np.typing.NDArray[np.int_] = f.read() + tensor = array_to_tensor(array).float() + return tensor + + def _load_target(self, path: Path) -> Tensor: + """Load a single mask corresponding to image. + + Args: + path: path to the mask + + Returns: + the mask of the image + """ + filename = os.path.join(path) + with rasterio.open(filename) as f: + array: np.typing.NDArray[np.int_] = f.read(1) + tensor = torch.from_numpy(array).long() + return tensor + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + for country in self.countries: + if self._verify_data(country): + continue + + filename = f'{country}.zip' + pathname = os.path.join(self.root, filename) + if os.path.exists(pathname): + extract_archive(pathname, os.path.join(self.root, country)) + continue + + if not self.download: + raise DatasetNotFoundError(self) + + download_and_extract_archive( + self.base_url + filename, + os.path.join(self.root, country), + filename=filename, + md5=self.country_to_md5[country] if self.checksum else None, + ) + + def _verify_data(self, country: str) -> bool: + """Verify that data for a country is extracted. + + Args: + country: the country to check + + Returns: + True if the dataset directories and split files are found, else False + """ + for entry in ['label_masks', 's2_images', f'chips_{country}.parquet']: + if not os.path.exists(os.path.join(self.root, country, entry)): + return False + + return True + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample return by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15, 5)) + + win_a = einops.rearrange(sample['image'][0:3], 'c h w -> h w c') + win_b = einops.rearrange(sample['image'][4:7], 'c h w -> h w c') + mask = sample['mask'] + + win_a = torch.clip(win_a / 3000, 0, 1) + win_b = torch.clip(win_b / 3000, 0, 1) + + axs[0].imshow(win_a) + axs[0].set_title('Window A') + axs[1].imshow(win_b) + axs[1].set_title('Window B') + if self.target == 'instance': + unique_vals = sorted(np.unique(mask)) + for i, val in enumerate(unique_vals): + mask[mask == val] = i + bg_mask = mask == 0 + mask = (mask % 9) + 1 + mask[bg_mask] = 0 + axs[2].imshow(mask, vmin=0, vmax=10, cmap='tab10', interpolation='none') + axs[2].set_title('Instance mask') + elif self.target == '2-class': + axs[2].imshow(mask, vmin=0, vmax=2, cmap='gray', interpolation='none') + axs[2].set_title('2-class mask') + elif self.target == '3-class': + axs[2].imshow(mask, vmin=0, vmax=2, cmap='gray', interpolation='none') + axs[2].set_title('3-class mask') + for ax in axs: + ax.axis('off') + + if not show_titles: + for ax in axs: + ax.set_title('') + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/gbif.py b/torchgeo/datasets/gbif.py index 259abe481ad..3e8cfb6c883 100644 --- a/torchgeo/datasets/gbif.py +++ b/torchgeo/datasets/gbif.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox +from .utils import BoundingBox, Path def _disambiguate_timestamps( @@ -80,7 +80,7 @@ class GBIF(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = 'data') -> None: + def __init__(self, root: Path = 'data') -> None: """Initialize a new Dataset instance. Args: @@ -137,6 +137,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {'crs': self.crs, 'bbox': bboxes} + sample = {'crs': self.crs, 'bounds': bboxes} return sample diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index d44242d8130..26a035d427d 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -4,6 +4,7 @@ """Base classes for all :mod:`torchgeo` datasets.""" import abc +import fnmatch import functools import glob import os @@ -11,7 +12,7 @@ import sys import warnings from collections.abc import Callable, Iterable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import fiona import fiona.transform @@ -34,6 +35,7 @@ from .errors import DatasetNotFoundError from .utils import ( BoundingBox, + Path, array_to_tensor, concat_samples, disambiguate_timestamp, @@ -84,7 +86,7 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): dataset = landsat7 | landsat8 """ - paths: str | Iterable[str] + paths: Path | Iterable[Path] _crs = CRS.from_epsg(4326) _res = 0.0 @@ -205,7 +207,7 @@ def __setstate__( self, state: tuple[ dict[Any, Any], - list[tuple[int, tuple[float, float, float, float, float, float], str]], + list[tuple[int, tuple[float, float, float, float, float, float], Path]], ], ) -> None: """Define how to unpickle an instance. @@ -297,8 +299,8 @@ def files(self) -> list[str]: .. versionadded:: 0.5 """ # Make iterable - if isinstance(self.paths, str): - paths: Iterable[str] = [self.paths] + if isinstance(self.paths, str | os.PathLike): + paths: Iterable[Path] = [self.paths] else: paths = self.paths @@ -308,8 +310,10 @@ def files(self) -> list[str]: if os.path.isdir(path): pathname = os.path.join(path, '**', self.filename_glob) files |= set(glob.iglob(pathname, recursive=True)) - elif os.path.isfile(path) or path_is_vsi(path): - files.add(path) + elif (os.path.isfile(path) or path_is_vsi(path)) and fnmatch.fnmatch( + str(path), f'*{self.filename_glob}' + ): + files.add(str(path)) elif not hasattr(self, 'download'): warnings.warn( f"Could not find any relevant files for provided path '{path}'. " @@ -348,7 +352,7 @@ class RasterDataset(GeoDataset): #: Minimum timestamp if not in filename mint: float = 0 - #: Maximum timestmap if not in filename + #: Maximum timestamp if not in filename maxt: float = sys.maxsize #: True if the dataset only contains model inputs (such as images). False if the @@ -357,21 +361,21 @@ class RasterDataset(GeoDataset): #: The sample returned by the dataset/data loader will use the "image" key if #: *is_image* is True, otherwise it will use the "mask" key. #: - #: For datasets with both model inputs and outputs, a custom - #: :func:`~RasterDataset.__getitem__` method must be implemented. + #: For datasets with both model inputs and outputs, the recommended approach is + #: to use 2 `RasterDataset` instances and combine them using an `IntersectionDataset`. is_image = True #: True if data is stored in a separate file for each band, else False. separate_files = False #: Names of all available bands in the dataset - all_bands: list[str] = [] + all_bands: tuple[str, ...] = () #: Names of RGB bands in the dataset, used for plotting - rgb_bands: list[str] = [] + rgb_bands: tuple[str, ...] = () #: Color map for the dataset, used for plotting - cmap: dict[int, tuple[int, int, int, int]] = {} + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {} @property def dtype(self) -> torch.dtype: @@ -410,7 +414,7 @@ def resampling(self) -> Resampling: def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, @@ -453,17 +457,17 @@ def __init__( # See if file has a color map if len(self.cmap) == 0: try: - self.cmap = src.colormap(1) + self.cmap = src.colormap(1) # type: ignore[misc] except ValueError: pass if crs is None: crs = src.crs - if res is None: - res = src.res[0] with WarpedVRT(src, crs=crs) as vrt: minx, miny, maxx, maxy = vrt.bounds + if res is None: + res = vrt.res[0] except rasterio.errors.RasterioIOError: # Skip files that rasterio is unable to read continue @@ -544,13 +548,13 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: else: data = self._merge_files(filepaths, query, self.band_indexes) - sample = {'crs': self.crs, 'bbox': query} + sample = {'crs': self.crs, 'bounds': query} data = data.to(self.dtype) if self.is_image: sample['image'] = data else: - sample['mask'] = data + sample['mask'] = data.squeeze(0) if self.transforms is not None: sample = self.transforms(sample) @@ -587,7 +591,7 @@ def _merge_files( return tensor @functools.lru_cache(maxsize=128) - def _cached_load_warp_file(self, filepath: str) -> DatasetReader: + def _cached_load_warp_file(self, filepath: Path) -> DatasetReader: """Cached version of :meth:`_load_warp_file`. Args: @@ -598,7 +602,7 @@ def _cached_load_warp_file(self, filepath: str) -> DatasetReader: """ return self._load_warp_file(filepath) - def _load_warp_file(self, filepath: str) -> DatasetReader: + def _load_warp_file(self, filepath: Path) -> DatasetReader: """Load and warp a file to the correct CRS and resolution. Args: @@ -649,7 +653,7 @@ def dtype(self) -> torch.dtype: def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.0001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -774,7 +778,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: masks = array_to_tensor(masks) masks = masks.to(self.dtype) - sample = {'mask': masks, 'crs': self.crs, 'bbox': query} + sample = {'mask': masks, 'crs': self.crs, 'bounds': query} if self.transforms is not None: sample = self.transforms(sample) @@ -846,10 +850,10 @@ class NonGeoClassificationDataset(NonGeoDataset, ImageFolder): # type: ignore[m def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, - loader: Callable[[str], Any] | None = pil_loader, - is_valid_file: Callable[[str], bool] | None = None, + loader: Callable[[Path], Any] | None = pil_loader, + is_valid_file: Callable[[Path], bool] | None = None, ) -> None: """Initialize a new NonGeoClassificationDataset instance. diff --git a/torchgeo/datasets/geonrw.py b/torchgeo/datasets/geonrw.py new file mode 100644 index 00000000000..50e05bad0fb --- /dev/null +++ b/torchgeo/datasets/geonrw.py @@ -0,0 +1,346 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""GeoNRW dataset.""" + +import os +from collections.abc import Callable +from glob import glob +from typing import ClassVar + +import matplotlib +import matplotlib.cm +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor +from torchvision import transforms + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_and_extract_archive, extract_archive + + +class GeoNRW(NonGeoDataset): + """GeoNRW dataset. + + This datasets contains RGB, DEM and segmentation label data from North Rhine-Westphalia, Germany. + + Dataset features: + + * 7298 training and 485 test samples + * RGB images, 1000x1000px normalized to [0, 1] + * DEM images, unnormalized + * segmentation labels + + Dataset format: + + * RGB images are three-channel jp2 + * DEM images are single-channel tif + * segmentation labels are single-channel tif + + Dataset classes: + + 0. background + 1. forest + 2. water + 3. agricultural + 4. residential,commercial,industrial + 5. grassland,swamp,shrubbery + 6. railway,trainstation + 7. highway,squares + 8. airport,shipyard + 9. roads + 10. buildings + + Additional information about the dataset can be found `on this site `__. + + If you use this dataset in your research, please cite the following paper: + + * https://ieeexplore.ieee.org/document/9406194 + + + .. versionadded:: 0.6 + """ + + # Splits taken from https://github.com/gbaier/geonrw/blob/ecfcdbca8cfaaeb490a9c6916980f385b9f3941a/pytorch/nrw.py#L48 + + splits = ('train', 'test') + + train_list: tuple[str, ...] = ( + 'aachen', + 'bergisch', + 'bielefeld', + 'bochum', + 'bonn', + 'borken', + 'bottrop', + 'coesfeld', + 'dortmund', + 'dueren', + 'duisburg', + 'ennepetal', + 'erftstadt', + 'essen', + 'euskirchen', + 'gelsenkirchen', + 'guetersloh', + 'hagen', + 'hamm', + 'heinsberg', + 'herford', + 'hoexter', + 'kleve', + 'koeln', + 'krefeld', + 'leverkusen', + 'lippetal', + 'lippstadt', + 'lotte', + 'moenchengladbach', + 'moers', + 'muelheim', + 'muenster', + 'oberhausen', + 'paderborn', + 'recklinghausen', + 'remscheid', + 'siegen', + 'solingen', + 'wuppertal', + ) + + test_list: tuple[str, ...] = ('duesseldorf', 'herne', 'neuss') + + classes = ( + 'background', + 'forest', + 'water', + 'agricultural', + 'residential,commercial,industrial', + 'grassland,swamp,shrubbery', + 'railway,trainstation', + 'highway,squares', + 'airport,shipyard', + 'roads', + 'buildings', + ) + + colormap = mcolors.ListedColormap( + [ + '#000000', # matplotlib black for background + '#2ca02c', # matplotlib green for forest + '#1f77b4', # matplotlib blue for water + '#8c564b', # matplotlib brown for agricultural + '#7f7f7f', # matplotlib gray residential_commercial_industrial + '#bcbd22', # matplotlib olive for grassland_swamp_shrubbery + '#ff7f0e', # matplotlib orange for railway_trainstation + '#9467bd', # matplotlib purple for highway_squares + '#17becf', # matplotlib cyan for airport_shipyard + '#d62728', # matplotlib red for roads + '#e377c2', # matplotlib pink for buildings + ] + ) + + readers: ClassVar[dict[str, Callable[[str], Image.Image]]] = { + 'rgb': lambda path: Image.open(path).convert('RGB'), + 'dem': lambda path: Image.open(path).copy(), + 'seg': lambda path: Image.open(path).convert('I;16'), + } + + modality_filenames: ClassVar[dict[str, Callable[[list[str]], str]]] = { + 'rgb': lambda utm_coords: '{}_{}_rgb.jp2'.format(*utm_coords), + 'dem': lambda utm_coords: '{}_{}_dem.tif'.format(*utm_coords), + 'seg': lambda utm_coords: '{}_{}_seg.tif'.format(*utm_coords), + } + + modalities: tuple[str, ...] = ('rgb', 'dem', 'seg') + + url = 'https://hf.co/datasets/torchgeo/geonrw/resolve/3cb6bdf2a615b9e526c7dcff85fd1f20728081b7/{}' + + filename = 'nrw_dataset.tar.gz' + md5 = 'd56ab50098d5452c33d08ff4e99ce281' + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize the GeoNRW dataset. + + Args: + root: root directory where dataset can be found + split: one of "train", or "test" + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: if ``split`` argument is invalid + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert split in self.splits, f'split must be one of {self.splits}' + + self.root = root + self.split = split + self.transforms = transforms + self.download = download + self.checksum = checksum + + self.city_names = self.test_list if split == 'test' else self.train_list + + self._verify() + + self.file_list = self._get_file_list() + + def _get_file_list(self) -> list[str]: + """Get a list of files for cities in the dataset split. + + Returns: + list of filenames in the dataset split + """ + file_list: list[str] = [] + for cn in self.city_names: + pattern = os.path.join(self.root, cn, '*rgb.jp2') + file_list.extend(glob(pattern)) + return sorted(file_list) + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.file_list) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + to_tensor = transforms.ToTensor() + + path: str = self.file_list[index] + utm_coords = os.path.basename(path).split('_')[:2] + base_dir = os.path.dirname(path) + + sample: dict[str, Tensor] = {} + for modality in self.modalities: + modality_path = os.path.join( + base_dir, self.modality_filenames[modality](utm_coords) + ) + sample[modality] = to_tensor(self.readers[modality](modality_path)) + + # rename to torchgeo standard keys + sample['image'] = sample.pop('rgb').float() + sample['mask'] = sample.pop('seg').long().squeeze(0) + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # check if city names directories exist + all_exist = all( + os.path.exists(os.path.join(self.root, cn)) for cn in self.city_names + ) + if all_exist: + return + + # Check if the tar file has been downloaded + if os.path.exists(os.path.join(self.root, self.filename)): + extract_archive(os.path.join(self.root, self.filename), self.root) + return + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + download_and_extract_archive( + self.url.format(self.filename), + download_root=self.root, + md5=self.md5 if self.checksum else None, + ) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional suptitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + showing_predictions = 'prediction' in sample + ncols = 3 + if showing_predictions: + prediction = sample['prediction'].long() + ncols += 1 + + fig, axs = plt.subplots( + nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True + ) + + axs[0].imshow(sample['image'].permute(1, 2, 0)) + axs[0].axis('off') + axs[1].imshow(sample['dem'].squeeze(0), cmap='gray') + axs[1].axis('off') + axs[2].imshow( + sample['mask'].squeeze(0), + self.colormap, + vmin=0, + vmax=10, + interpolation='none', + ) + axs[2].axis('off') + + if showing_predictions: + axs[3].imshow( + prediction.squeeze(0), + self.colormap, + vmin=0, + vmax=10, + interpolation='none', + ) + + # show classes in legend + if show_titles: + patches = [matplotlib.patches.Patch(color=c) for c in self.colormap.colors] # type: ignore + axs[2].legend( + patches, self.classes, loc='center left', bbox_to_anchor=(1, 0.5) + ) + + if show_titles: + axs[0].set_title('RGB Image') + axs[1].set_title('DEM') + axs[2].set_title('Labels') + + if suptitle is not None: + fig.suptitle(suptitle, y=0.8) + + fig.tight_layout() + + return fig diff --git a/torchgeo/datasets/gid15.py b/torchgeo/datasets/gid15.py index 329d488e94d..b42e6e58df6 100644 --- a/torchgeo/datasets/gid15.py +++ b/torchgeo/datasets/gid15.py @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import Path, download_and_extract_archive class GID15(NonGeoDataset): @@ -66,8 +66,8 @@ class GID15(NonGeoDataset): md5 = '615682bf659c3ed981826c6122c10c83' filename = 'gid-15.zip' directory = 'GID' - splits = ['train', 'val', 'test'] - classes = [ + splits = ('train', 'val', 'test') + classes = ( 'background', 'industrial_land', 'urban_residential', @@ -84,11 +84,11 @@ class GID15(NonGeoDataset): 'river', 'lake', 'pond', - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -154,7 +154,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -178,7 +178,7 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -195,7 +195,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)).float() return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: diff --git a/torchgeo/datasets/globbiomass.py b/torchgeo/datasets/globbiomass.py index 17091b6cc3d..c214fbba205 100644 --- a/torchgeo/datasets/globbiomass.py +++ b/torchgeo/datasets/globbiomass.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Iterable -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import torch @@ -15,7 +15,13 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, check_integrity, disambiguate_timestamp, extract_archive +from .utils import ( + BoundingBox, + Path, + check_integrity, + disambiguate_timestamp, + extract_archive, +) class GlobBiomass(RasterDataset): @@ -66,9 +72,9 @@ class GlobBiomass(RasterDataset): is_image = False dtype = torch.float32 # pixelwise regression - measurements = ['agb', 'gsv'] + measurements = ('agb', 'gsv') - md5s = { + md5s: ClassVar[dict[str, str]] = { 'N00E020_agb.zip': 'bd83a3a4c143885d1962bde549413be6', 'N00E020_gsv.zip': 'da5ddb88e369df2d781a0c6be008ae79', 'N00E060_agb.zip': '85eaca95b939086cc528e396b75bd097', @@ -131,7 +137,7 @@ class GlobBiomass(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, measurement: str = 'agb', @@ -195,12 +201,12 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask = self._merge_files(filepaths, query) - std_error_paths = [f.replace('.tif', '_err.tif') for f in filepaths] + std_error_paths = [str(f).replace('.tif', '_err.tif') for f in filepaths] std_err_mask = self._merge_files(std_error_paths, query) mask = torch.cat((mask, std_err_mask), dim=0) - sample = {'mask': mask, 'crs': self.crs, 'bbox': query} + sample = {'mask': mask, 'crs': self.crs, 'bounds': query} if self.transforms is not None: sample = self.transforms(sample) @@ -214,7 +220,7 @@ def _verify(self) -> None: return # Check if the zip files have already been downloaded - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, f'*_{self.measurement}.zip') if glob.glob(pathname): for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/hyspecnet.py b/torchgeo/datasets/hyspecnet.py new file mode 100644 index 00000000000..412ea504b24 --- /dev/null +++ b/torchgeo/datasets/hyspecnet.py @@ -0,0 +1,229 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""HySpecNet dataset.""" + +import os +from collections.abc import Callable, Sequence +from typing import ClassVar + +import rasterio as rio +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError, RGBBandsMissingError +from .geo import NonGeoDataset +from .utils import Path, download_url, extract_archive, percentile_normalization + +# https://git.tu-berlin.de/rsim/hyspecnet-tools/-/blob/main/tif_to_npy.ipynb +invalid_channels = [ + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 160, + 161, + 162, + 163, + 164, + 165, + 166, +] +valid_channels_ids = [c + 1 for c in range(224) if c not in invalid_channels] + + +class HySpecNet11k(NonGeoDataset): + """HySpecNet-11k dataset. + + `HySpecNet-11k `__ is a large-scale + benchmark dataset for hyperspectral image compression and self-supervised learning. + It is made up of 11,483 nonoverlapping image patches acquired by the + `EnMAP satellite `_. Each patch is a portion of 128 x 128 + pixels with 224 spectral bands and with a ground sample distance of 30 m. + + To construct HySpecNet-11k, a total of 250 EnMAP tiles acquired during the routine + operation phase between 2 November 2022 and 9 November 2022 were considered. The + considered tiles are associated with less than 10% cloud and snow cover. The tiles + were radiometrically, geometrically and atmospherically corrected (L2A water & land + product). Then, the tiles were divided into nonoverlapping image patches. The + cropped patches at the borders of the tiles were eliminated. As a result, more than + 45 patches per tile are obtained, resulting in 11,483 patches for the full dataset. + + We provide predefined splits obtained by randomly dividing HySpecNet into: + + #. a training set that includes 70% of the patches, + #. a validation set that includes 20% of the patches, and + #. a test set that includes 10% of the patches. + + Depending on the way that we used for splitting the dataset, we define two + different splits: + + #. an easy split, where patches from the same tile can be present in different sets + (patchwise splitting); and + #. a hard split, where all patches from one tile belong to the same set + (tilewise splitting). + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/2306.00385 + + .. versionadded:: 0.7 + """ + + url = 'https://hf.co/datasets/torchgeo/hyspecnet/resolve/13e110422a6925cbac0f11edff610219b9399227/' + md5s: ClassVar[dict[str, str]] = { + 'hyspecnet-11k-01.tar.gz': '974aae9197006727b42ec81796049efe', + 'hyspecnet-11k-02.tar.gz': 'f80574485f835b8a263b6c64076c0c62', + 'hyspecnet-11k-03.tar.gz': '6bc1de573f97fa4a75b79719b9270cb3', + 'hyspecnet-11k-04.tar.gz': '2463dc10653cb8be10d44951307c5e7d', + 'hyspecnet-11k-05.tar.gz': '16c1bd9e684673e741c0849bd015c988', + 'hyspecnet-11k-06.tar.gz': '8eef16b67d71af6eb4bc836d294fe3c4', + 'hyspecnet-11k-07.tar.gz': 'f61f0e7d6b05c861e69026b09130a5d6', + 'hyspecnet-11k-08.tar.gz': '19d390bc9e61b85e7d765f3077984976', + 'hyspecnet-11k-09.tar.gz': '197ff47befe5b9de88be5e1321c5ce5d', + 'hyspecnet-11k-10.tar.gz': '9e674cca126a9d139d6584be148d4bac', + 'hyspecnet-11k-splits.tar.gz': '94fad9e3c979c612c29a045406247d6c', + } + + all_bands = valid_channels_ids + rgb_bands = (43, 28, 10) + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + strategy: str = 'easy', + bands: Sequence[int] = all_bands, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new HySpecNet11k instance. + + Args: + root: Root directory where dataset can be found. + split: One of 'train', 'val', or 'test'. + strategy: Either 'easy' for patchwise splitting or 'hard' for tilewise + splitting. + bands: Bands to return. + transforms: A function/transform that takes input sample and its target as + entry and returns a transformed version. + download: If True, download dataset and store it in the root directory. + checksum: If True, check the MD5 of the downloaded files (may be slow). + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + self.root = root + self.split = split + self.strategy = strategy + self.bands = bands + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + path = os.path.join(root, 'hyspecnet-11k', 'splits', strategy, f'{split}.csv') + with open(path) as f: + self.files = f.read().strip().split('\n') + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + Length of the dataset. + """ + return len(self.files) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: Index to return. + + Returns: + Data and label at that index. + """ + file = self.files[index].replace('DATA.npy', 'SPECTRAL_IMAGE.TIF') + with rio.open(os.path.join(self.root, 'hyspecnet-11k', 'patches', file)) as src: + sample = {'image': torch.tensor(src.read(self.bands).astype('float32'))} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the extracted files already exist + exists = [] + for directory in ['patches', 'splits']: + path = os.path.join(self.root, 'hyspecnet-11k', directory) + exists.append(os.path.isdir(path)) + + if all(exists): + return + + for file, md5 in self.md5s.items(): + # Check if the file has already been downloaded + path = os.path.join(self.root, file) + if os.path.isfile(path): + extract_archive(path) + continue + + # Check if the user requested to download the dataset + if self.download: + url = self.url + file + download_url(url, self.root, md5=md5 if self.checksum else None) + extract_archive(path) + continue + + raise DatasetNotFoundError(self) + + def plot(self, sample: dict[str, Tensor], suptitle: str | None = None) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by :meth:`__getitem__`. + suptitle: optional string to use as a suptitle + + Returns: + A matplotlib Figure with the rendered sample. + + Raises: + RGBBandsMissingError: If *bands* does not include all RGB bands. + """ + rgb_indices = [] + for band in self.rgb_bands: + if band in self.bands: + rgb_indices.append(self.bands.index(band)) + else: + raise RGBBandsMissingError() + + image = sample['image'][rgb_indices].cpu().numpy() + image = rearrange(image, 'c h w -> h w c') + image = percentile_normalization(image) + + fig, ax = plt.subplots() + ax.imshow(image) + ax.axis('off') + + if suptitle: + fig.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/idtrees.py b/torchgeo/datasets/idtrees.py index 4dd067244df..28e890dc69f 100644 --- a/torchgeo/datasets/idtrees.py +++ b/torchgeo/datasets/idtrees.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import Any, cast, overload +from typing import Any, ClassVar, cast, overload import fiona import matplotlib.pyplot as plt @@ -22,7 +22,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, lazy_import +from .utils import Path, download_url, extract_archive, lazy_import class IDTReeS(NonGeoDataset): @@ -100,7 +100,7 @@ class IDTReeS(NonGeoDataset): .. versionadded:: 0.2 """ - classes = { + classes: ClassVar[dict[str, str]] = { 'ACPE': 'Acer pensylvanicum L.', 'ACRU': 'Acer rubrum L.', 'ACSA3': 'Acer saccharum Marshall', @@ -135,24 +135,27 @@ class IDTReeS(NonGeoDataset): 'ROPS': 'Robinia pseudoacacia L.', 'TSCA': 'Tsuga canadensis (L.) Carriere', } - metadata = { + metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { - 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_train_v2.zip?download=1', # noqa: E501 + 'url': 'https://zenodo.org/records/3934932/files/IDTREES_competition_train_v2.zip?download=1', 'md5': '5ddfa76240b4bb6b4a7861d1d31c299c', 'filename': 'IDTREES_competition_train_v2.zip', }, 'test': { - 'url': 'https://zenodo.org/record/3934932/files/IDTREES_competition_test_v2.zip?download=1', # noqa: E501 + 'url': 'https://zenodo.org/records/3934932/files/IDTREES_competition_test_v2.zip?download=1', 'md5': 'b108931c84a70f2a38a8234290131c9b', 'filename': 'IDTREES_competition_test_v2.zip', }, } - directories = {'train': ['train'], 'test': ['task1', 'task2']} + directories: ClassVar[dict[str, list[str]]] = { + 'train': ['train'], + 'test': ['task1', 'task2'], + } image_size = (200, 200) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', task: str = 'task1', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -240,7 +243,7 @@ def __len__(self) -> int: """ return len(self.images) - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a tiff file. Args: @@ -254,7 +257,7 @@ def _load_image(self, path: str) -> Tensor: tensor = torch.from_numpy(array) return tensor - def _load_las(self, path: str) -> Tensor: + def _load_las(self, path: Path) -> Tensor: """Load a single point cloud. Args: @@ -269,7 +272,7 @@ def _load_las(self, path: str) -> Tensor: tensor = torch.from_numpy(array) return tensor - def _load_boxes(self, path: str) -> Tensor: + def _load_boxes(self, path: Path) -> Tensor: """Load object bounding boxes. Args: @@ -313,7 +316,7 @@ def _load_boxes(self, path: str) -> Tensor: tensor = torch.tensor(boxes) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load target label for a single sample. Args: @@ -333,7 +336,7 @@ def _load_target(self, path: str) -> Tensor: return tensor def _load( - self, root: str + self, root: Path ) -> tuple[list[str], dict[int, dict[str, Any]] | None, Any]: """Load files, geometries, and labels. @@ -360,7 +363,7 @@ def _load( return images, geoms, labels - def _load_labels(self, directory: str) -> Any: + def _load_labels(self, directory: Path) -> Any: """Load the csv files containing the labels. Args: @@ -380,7 +383,7 @@ def _load_labels(self, directory: str) -> Any: df.reset_index() return df - def _load_geometries(self, directory: str) -> dict[int, dict[str, Any]]: + def _load_geometries(self, directory: Path) -> dict[int, dict[str, Any]]: """Load the shape files containing the geometries. Args: diff --git a/torchgeo/datasets/inaturalist.py b/torchgeo/datasets/inaturalist.py index 478b60a1c10..bb5cfe3c8df 100644 --- a/torchgeo/datasets/inaturalist.py +++ b/torchgeo/datasets/inaturalist.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import GeoDataset -from .utils import BoundingBox, disambiguate_timestamp +from .utils import BoundingBox, Path, disambiguate_timestamp class INaturalist(GeoDataset): @@ -26,7 +26,7 @@ class INaturalist(GeoDataset): If you use an iNaturalist dataset in your research, please cite it according to: - * https://www.inaturalist.org/pages/help#cite + * https://help.inaturalist.org/en/support/solutions/articles/151000170344-how-should-i-cite-inaturalist- .. versionadded:: 0.3 """ @@ -34,7 +34,7 @@ class INaturalist(GeoDataset): res = 0 _crs = CRS.from_epsg(4326) # Lat/Lon - def __init__(self, root: str = 'data') -> None: + def __init__(self, root: Path = 'data') -> None: """Initialize a new Dataset instance. Args: @@ -107,6 +107,6 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: f'query: {query} not found in index with bounds: {self.bounds}' ) - sample = {'crs': self.crs, 'bbox': bboxes} + sample = {'crs': self.crs, 'bounds': bboxes} return sample diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 5b3db228499..3b2a4348a96 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, extract_archive, percentile_normalization +from .utils import Path, check_integrity, extract_archive, percentile_normalization class InriaAerialImageLabeling(NonGeoDataset): @@ -59,7 +59,7 @@ class InriaAerialImageLabeling(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, checksum: bool = False, @@ -86,7 +86,7 @@ def __init__( self._verify() self.files = self._load_files(root) - def _load_files(self, root: str) -> list[dict[str, str]]: + def _load_files(self, root: Path) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -121,7 +121,7 @@ def _load_files(self, root: str) -> list[dict[str, str]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -135,7 +135,7 @@ def _load_image(self, path: str) -> Tensor: tensor = torch.from_numpy(array).float() return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Loads the target mask. Args: diff --git a/torchgeo/datasets/iobench.py b/torchgeo/datasets/iobench.py index a0ee246065a..608a9ccc17a 100644 --- a/torchgeo/datasets/iobench.py +++ b/torchgeo/datasets/iobench.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Sequence -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset from .landsat import Landsat9 -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class IOBench(IntersectionDataset): @@ -40,9 +40,9 @@ class IOBench(IntersectionDataset): .. versionadded:: 0.6 """ - url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/io/resolve/c9d9d268cf0b61335941bdc2b6963bf16fc3a6cf/{}.tar.gz' - md5s = { + md5s: ClassVar[dict[str, str]] = { 'original': 'e3a908a0fd1c05c1af2f4c65724d59b3', 'raw': 'e9603990441007ce7bba73bb8ba7d217', 'preprocessed': '9801f1240b238cb17525c865e413d1fd', @@ -50,11 +50,11 @@ class IOBench(IntersectionDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'preprocessed', crs: CRS | None = None, res: float | None = None, - bands: Sequence[str] | None = Landsat9.default_bands + ['SR_QA_AEROSOL'], + bands: Sequence[str] | None = [*Landsat9.default_bands, 'SR_QA_AEROSOL'], classes: list[int] = [0], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, diff --git a/torchgeo/datasets/l7irish.py b/torchgeo/datasets/l7irish.py index 7153738b391..d39f225ed75 100644 --- a/torchgeo/datasets/l7irish.py +++ b/torchgeo/datasets/l7irish.py @@ -7,7 +7,7 @@ import os import re from collections.abc import Callable, Iterable, Sequence -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import torch @@ -18,7 +18,13 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset, RasterDataset -from .utils import BoundingBox, disambiguate_timestamp, download_url, extract_archive +from .utils import ( + BoundingBox, + Path, + disambiguate_timestamp, + download_url, + extract_archive, +) class L7IrishImage(RasterDataset): @@ -36,8 +42,8 @@ class L7IrishImage(RasterDataset): """ date_format = '%Y%m%d' is_image = True - rgb_bands = ['B30', 'B20', 'B10'] - all_bands = ['B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80'] + rgb_bands = ('B30', 'B20', 'B10') + all_bands = ('B10', 'B20', 'B30', 'B40', 'B50', 'B61', 'B62', 'B70', 'B80') class L7IrishMask(RasterDataset): @@ -52,7 +58,7 @@ class L7IrishMask(RasterDataset): _newmask2015\.TIF$ """ is_image = False - classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud'] + classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud') ordinal_map = torch.zeros(256, dtype=torch.long) ordinal_map[64] = 1 ordinal_map[128] = 2 @@ -61,7 +67,7 @@ class L7IrishMask(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, @@ -151,11 +157,11 @@ class L7Irish(IntersectionDataset): * https://www.sciencebase.gov/catalog/item/573ccf18e4b0dae0d5e4b109 .. versionadded:: 0.5 - """ # noqa: E501 + """ - url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/l7irish/resolve/6807e0b22eca7f9a8a3903ea673b31a115837464/{}.tar.gz' - md5s = { + md5s: ClassVar[dict[str, str]] = { 'austral': '0a34770b992a62abeb88819feb192436', 'boreal': 'b7cfdd689a3c2fd2a8d572e1c10ed082', 'mid_latitude_north': 'c40abe5ad2487f8ab021cfb954982faa', @@ -169,7 +175,7 @@ class L7Irish(IntersectionDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = CRS.from_epsg(3857), res: float | None = None, bands: Sequence[str] = L7IrishImage.all_bands, @@ -222,7 +228,7 @@ def _merge_dataset_indices(self) -> None: def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - if not isinstance(self.paths, str): + if not isinstance(self.paths, str | os.PathLike): return for classname in [L7IrishImage, L7IrishMask]: @@ -255,7 +261,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, '*.tar.gz') for tarfile in glob.iglob(pathname): extract_archive(tarfile) diff --git a/torchgeo/datasets/l8biome.py b/torchgeo/datasets/l8biome.py index c200b5c63bc..e53c403b713 100644 --- a/torchgeo/datasets/l8biome.py +++ b/torchgeo/datasets/l8biome.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable, Iterable, Sequence -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import torch @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import IntersectionDataset, RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url, extract_archive class L8BiomeImage(RasterDataset): @@ -35,8 +35,8 @@ class L8BiomeImage(RasterDataset): """ date_format = '%Y%j' is_image = True - rgb_bands = ['B4', 'B3', 'B2'] - all_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11'] + rgb_bands = ('B4', 'B3', 'B2') + all_bands = ('B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B9', 'B10', 'B11') class L8BiomeMask(RasterDataset): @@ -56,7 +56,7 @@ class L8BiomeMask(RasterDataset): """ date_format = '%Y%j' is_image = False - classes = ['Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud'] + classes = ('Fill', 'Cloud Shadow', 'Clear', 'Thin Cloud', 'Cloud') ordinal_map = torch.zeros(256, dtype=torch.long) ordinal_map[64] = 1 ordinal_map[128] = 2 @@ -115,11 +115,11 @@ class L8Biome(IntersectionDataset): * https://doi.org/10.1016/j.rse.2017.03.026 .. versionadded:: 0.5 - """ # noqa: E501 + """ - url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/l8biome/resolve/f76df19accce34d2acc1878d88b9491bc81f94c8/{}.tar.gz' - md5s = { + md5s: ClassVar[dict[str, str]] = { 'barren': '0eb691822d03dabd4f5ea8aadd0b41c3', 'forest': '4a5645596f6bb8cea44677f746ec676e', 'grass_crops': 'a69ed5d6cb227c5783f026b9303cdd3c', @@ -132,7 +132,7 @@ class L8Biome(IntersectionDataset): def __init__( self, - paths: str | Iterable[str], + paths: Path | Iterable[Path], crs: CRS | None = CRS.from_epsg(3857), res: float | None = None, bands: Sequence[str] = L8BiomeImage.all_bands, @@ -173,7 +173,7 @@ def __init__( def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist - if not isinstance(self.paths, str): + if not isinstance(self.paths, str | os.PathLike): return for classname in [L8BiomeImage, L8BiomeMask]: @@ -206,7 +206,7 @@ def _download(self) -> None: def _extract(self) -> None: """Extract the dataset.""" - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, '*.tar.gz') for tarfile in glob.iglob(pathname): extract_archive(tarfile) diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 970a45eb1cd..d2c0acf88e7 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -9,7 +9,7 @@ import os from collections.abc import Callable from functools import lru_cache -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -23,7 +23,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset, RasterDataset -from .utils import BoundingBox, download_url, extract_archive, working_dir +from .utils import BoundingBox, Path, download_url, extract_archive, working_dir class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): @@ -64,8 +64,8 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): url = 'https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip' filename = 'landcover.ai.v1.zip' md5 = '3268c89070e8734b4e91d531c0617e03' - classes = ['Background', 'Building', 'Woodland', 'Water', 'Road'] - cmap = { + classes = ('Background', 'Building', 'Woodland', 'Water', 'Road') + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 0), 1: (97, 74, 74, 255), 2: (38, 115, 0, 255), @@ -74,7 +74,7 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC): } def __init__( - self, root: str = 'data', download: bool = False, checksum: bool = False + self, root: Path = 'data', download: bool = False, checksum: bool = False ) -> None: """Initialize a new LandCover.ai dataset instance. @@ -95,8 +95,7 @@ def __init__( lc_colors = np.zeros((max(self.cmap.keys()) + 1, 4)) lc_colors[list(self.cmap.keys())] = list(self.cmap.values()) - lc_colors = lc_colors[:, :3] / 255 - self._lc_cmap = ListedColormap(lc_colors) + self._lc_cmap = ListedColormap(lc_colors[:, :3] / 255) self._verify() @@ -205,7 +204,7 @@ class LandCoverAIGeo(LandCoverAIBase, RasterDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', crs: CRS | None = None, res: float | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -255,7 +254,9 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """ hits = self.index.intersection(tuple(query), objects=True) img_filepaths = cast(list[str], [hit.object for hit in hits]) - mask_filepaths = [path.replace('images', 'masks') for path in img_filepaths] + mask_filepaths = [ + str(path).replace('images', 'masks') for path in img_filepaths + ] if not img_filepaths: raise IndexError( @@ -266,7 +267,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: mask = self._merge_files(mask_filepaths, query, self.band_indexes) sample = { 'crs': self.crs, - 'bbox': query, + 'bounds': query, 'image': img.float(), 'mask': mask.long(), } @@ -294,7 +295,7 @@ class LandCoverAI(LandCoverAIBase, NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -399,10 +400,26 @@ def _extract(self) -> None: super()._extract() # Generate train/val/test splits - # Always check the sha256 of this file before executing - # to avoid malicious code injection - with working_dir(self.root): - with open('split.py') as f: - split = f.read().encode('utf-8') - assert hashlib.sha256(split).hexdigest() == self.sha256 - exec(split) + # Always check the sha256 of this file before executing to avoid malicious code injection + # The LandCoverAI100 dataset doesn't contain split.py, so only run if split.py exists + if os.path.exists(os.path.join(self.root, 'split.py')): + with working_dir(self.root): + with open('split.py') as f: + split = f.read().encode('utf-8') + assert hashlib.sha256(split).hexdigest() == self.sha256 + exec(split) + + +class LandCoverAI100(LandCoverAI): + """Subset of LandCoverAI containing only 100 images. + + Intended for tutorials and demonstrations, not for benchmarking. + + Maintains the same file structure, classes, and train-val-test split. + + .. versionadded:: 0.7 + """ + + url = 'https://huggingface.co/datasets/torchgeo/landcoverai/resolve/5cdf9299bd6c1232506cf79373df01f6e6596b50/landcoverai100.zip' + filename = 'landcoverai100.zip' + md5 = '66eb33b5a0cabb631836ce0a4eafb7cd' diff --git a/torchgeo/datasets/landsat.py b/torchgeo/datasets/landsat.py index aee28c1224d..8fb33b7c9cc 100644 --- a/torchgeo/datasets/landsat.py +++ b/torchgeo/datasets/landsat.py @@ -13,6 +13,7 @@ from .errors import RGBBandsMissingError from .geo import RasterDataset +from .utils import Path class Landsat(RasterDataset, abc.ABC): @@ -32,7 +33,7 @@ class Landsat(RasterDataset, abc.ABC): * `Surface Temperature `_ * `Surface Reflectance `_ * `U.S. Analysis Ready Data `_ - """ # noqa: E501 + """ # https://www.usgs.gov/landsat-missions/landsat-collection-2 filename_regex = r""" @@ -54,12 +55,12 @@ class Landsat(RasterDataset, abc.ABC): @property @abc.abstractmethod - def default_bands(self) -> list[str]: + def default_bands(self) -> tuple[str, ...]: """Bands to load by default.""" def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, bands: Sequence[str] | None = None, @@ -144,8 +145,8 @@ class Landsat1(Landsat): filename_glob = 'LM01_*_{}.*' - default_bands = ['B4', 'B5', 'B6', 'B7'] - rgb_bands = ['B6', 'B5', 'B4'] + default_bands = ('B4', 'B5', 'B6', 'B7') + rgb_bands = ('B6', 'B5', 'B4') class Landsat2(Landsat1): @@ -165,8 +166,8 @@ class Landsat4MSS(Landsat): filename_glob = 'LM04_*_{}.*' - default_bands = ['B1', 'B2', 'B3', 'B4'] - rgb_bands = ['B3', 'B2', 'B1'] + default_bands = ('B1', 'B2', 'B3', 'B4') + rgb_bands = ('B3', 'B2', 'B1') class Landsat4TM(Landsat): @@ -174,8 +175,8 @@ class Landsat4TM(Landsat): filename_glob = 'LT04_*_{}.*' - default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] - rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1'] + default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7') + rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1') class Landsat5MSS(Landsat4MSS): @@ -195,8 +196,8 @@ class Landsat7(Landsat): filename_glob = 'LE07_*_{}.*' - default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] - rgb_bands = ['SR_B3', 'SR_B2', 'SR_B1'] + default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7') + rgb_bands = ('SR_B3', 'SR_B2', 'SR_B1') class Landsat8(Landsat): @@ -204,11 +205,11 @@ class Landsat8(Landsat): filename_glob = 'LC08_*_{}.*' - default_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] - rgb_bands = ['SR_B4', 'SR_B3', 'SR_B2'] + default_bands = ('SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7') + rgb_bands = ('SR_B4', 'SR_B3', 'SR_B2') class Landsat9(Landsat8): - """Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" # noqa: E501 + """Landsat 9 Operational Land Imager (OLI-2) and Thermal Infrared Sensor (TIRS-2).""" filename_glob = 'LC09_*_{}.*' diff --git a/torchgeo/datasets/levircd.py b/torchgeo/datasets/levircd.py index 67209f603a8..fdff569dc19 100644 --- a/torchgeo/datasets/levircd.py +++ b/torchgeo/datasets/levircd.py @@ -7,6 +7,7 @@ import glob import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -17,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive, percentile_normalization +from .utils import Path, download_and_extract_archive, percentile_normalization class LEVIRCDBase(NonGeoDataset, abc.ABC): @@ -26,12 +27,12 @@ class LEVIRCDBase(NonGeoDataset, abc.ABC): .. versionadded:: 0.6 """ - splits: list[str] | dict[str, dict[str, str]] - directories = ['A', 'B', 'label'] + splits: ClassVar[tuple[str, ...] | dict[str, dict[str, str]]] + directories = ('A', 'B', 'label') def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -94,7 +95,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -111,7 +112,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: @@ -183,7 +184,7 @@ def plot( return fig @abc.abstractmethod - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -237,7 +238,7 @@ class LEVIRCD(LEVIRCDBase): .. versionadded:: 0.6 """ - splits = { + splits: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'url': 'https://drive.google.com/file/d/18GuoCuBn48oZKAlEo-LrNwABrFhVALU-', 'filename': 'train.zip', @@ -255,7 +256,7 @@ class LEVIRCD(LEVIRCDBase): }, } - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -328,17 +329,15 @@ class LEVIRCDPlus(LEVIRCDBase): If you use this dataset in your research, please cite the following paper: * https://arxiv.org/abs/2107.09244 - - .. versionchanged:: 0.6 """ url = 'https://drive.google.com/file/d/1JamSsxiytXdzAIk6VDVWfc-OsX-81U81' md5 = '1adf156f628aa32fb2e8fe6cada16c04' filename = 'LEVIR-CD+.zip' directory = 'LEVIR-CD+' - splits = ['train', 'test'] + splits = ('train', 'test') - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 58f3876a09b..93b6b18e455 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -5,7 +5,8 @@ import glob import os -from collections.abc import Callable +from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -16,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_and_extract_archive +from .utils import Path, download_and_extract_archive class LoveDA(NonGeoDataset): @@ -57,28 +58,28 @@ class LoveDA(NonGeoDataset): .. versionadded:: 0.2 """ - scenes = ['urban', 'rural'] - splits = ['train', 'val', 'test'] + scenes = ('urban', 'rural') + splits = ('train', 'val', 'test') - info_dict = { + info_dict: ClassVar[dict[str, dict[str, str]]] = { 'train': { - 'url': 'https://zenodo.org/record/5706578/files/Train.zip?download=1', + 'url': 'https://zenodo.org/records/5706578/files/Train.zip?download=1', 'filename': 'Train.zip', 'md5': 'de2b196043ed9b4af1690b3f9a7d558f', }, 'val': { - 'url': 'https://zenodo.org/record/5706578/files/Val.zip?download=1', + 'url': 'https://zenodo.org/records/5706578/files/Val.zip?download=1', 'filename': 'Val.zip', 'md5': '84cae2577468ff0b5386758bb386d31d', }, 'test': { - 'url': 'https://zenodo.org/record/5706578/files/Test.zip?download=1', + 'url': 'https://zenodo.org/records/5706578/files/Test.zip?download=1', 'filename': 'Test.zip', 'md5': 'a489be0090465e01fb067795d24e6b47', }, } - classes = [ + classes = ( 'background', 'building', 'road', @@ -87,13 +88,13 @@ class LoveDA(NonGeoDataset): 'forest', 'agriculture', 'no-data', - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', - scene: list[str] = ['urban', 'rural'], + scene: Sequence[str] = ['urban', 'rural'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -197,7 +198,7 @@ def _load_files(self, scene_paths: list[str], split: str) -> list[dict[str, str] return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -214,7 +215,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load a single mask corresponding to image. Args: diff --git a/torchgeo/datasets/mapinwild.py b/torchgeo/datasets/mapinwild.py index 882ec260fef..cd294014318 100644 --- a/torchgeo/datasets/mapinwild.py +++ b/torchgeo/datasets/mapinwild.py @@ -7,6 +7,7 @@ import shutil from collections import defaultdict from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -19,6 +20,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, download_url, extract_archive, @@ -35,7 +37,7 @@ class MapInWild(NonGeoDataset): different RS sensors over 1018 locations: dual-pol Sentinel-1, four-season Sentinel-2 with 10 bands, ESA WorldCover map, and Visible Infrared Imaging Radiometer Suite NightTime Day/Night band. The dataset consists of 8144 - images with the shape of 1920 × 1920 pixels. The images are weakly annotated + images with the shape of 1920 x 1920 pixels. The images are weakly annotated from the World Database of Protected Areas (WDPA). Dataset features: @@ -53,9 +55,9 @@ class MapInWild(NonGeoDataset): .. versionadded:: 0.5 """ - url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' # noqa: E501 + url = 'https://hf.co/datasets/burakekim/mapinwild/resolve/d963778e31e7e0ed2329c0f4cbe493be532f0e71/' - modality_urls = { + modality_urls: ClassVar[dict[str, set[str]]] = { 'esa_wc': {'esa_wc/ESA_WC.zip'}, 'viirs': {'viirs/VIIRS.zip'}, 'mask': {'mask/mask.zip'}, @@ -71,7 +73,7 @@ class MapInWild(NonGeoDataset): 'split_IDs': {'split_IDs/split_IDs.csv'}, } - md5s = { + md5s: ClassVar[dict[str, str]] = { 'ESA_WC.zip': '72b2ee578fe10f0df85bdb7f19311c92', 'VIIRS.zip': '4eff014bae127fe536f8a5f17d89ecb4', 'mask.zip': '87c83a23a73998ad60d448d240b66225', @@ -90,9 +92,12 @@ class MapInWild(NonGeoDataset): 'split_IDs.csv': 'cb5c6c073702acee23544e1e6fe5856f', } - mask_cmap = {1: (0, 153, 0), 0: (255, 255, 255)} + mask_cmap: ClassVar[dict[int, tuple[int, int, int]]] = { + 1: (0, 153, 0), + 0: (255, 255, 255), + } - wc_cmap = { + wc_cmap: ClassVar[dict[int, tuple[int, int, int]]] = { 10: (0, 160, 0), 20: (150, 100, 0), 30: (255, 180, 0), @@ -108,7 +113,7 @@ class MapInWild(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', modality: list[str] = ['mask', 'esa_wc', 'viirs', 's2_summer'], split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -155,7 +160,7 @@ def __init__( ): self._merge_parts(mode) - # Masks will be loaded seperately in the :meth:`__getitem__` + # Masks will be loaded separately in the :meth:`__getitem__` if 'mask' in self.modality: self.modality.remove('mask') @@ -205,7 +210,7 @@ def __len__(self) -> int: """ return len(self.ids) - def _load_raster(self, filename: int, source: str) -> Tensor: + def _load_raster(self, filename: int, source: Path) -> Tensor: """Load a single raster image or target. Args: @@ -272,7 +277,7 @@ def _download(self, url: str, md5: str | None) -> None: md5=md5 if self.checksum else None, ) - def _extract(self, path: str) -> None: + def _extract(self, path: Path) -> None: """Extracts a modality. Args: diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 46eabbe19e9..b7111da1962 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -17,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, extract_archive +from .utils import Path, check_integrity, extract_archive class MillionAID(NonGeoDataset): @@ -48,7 +48,7 @@ class MillionAID(NonGeoDataset): .. versionadded:: 0.3 """ - multi_label_categories = [ + multi_label_categories = ( 'agriculture_land', 'airport_area', 'apartment', @@ -122,9 +122,9 @@ class MillionAID(NonGeoDataset): 'wind_turbine', 'woodland', 'works', - ] + ) - multi_class_categories = [ + multi_class_categories = ( 'apartment', 'apron', 'bare_land', @@ -176,21 +176,21 @@ class MillionAID(NonGeoDataset): 'wastewater_plant', 'wind_turbine', 'works', - ] + ) - md5s = { + md5s: ClassVar[dict[str, str]] = { 'train': '1b40503cafa9b0601653ca36cd788852', 'test': '51a63ee3eeb1351889eacff349a983d8', } - filenames = {'train': 'train.zip', 'test': 'test.zip'} + filenames: ClassVar[dict[str, str]] = {'train': 'train.zip', 'test': 'test.zip'} - tasks = ['multi-class', 'multi-label'] - splits = ['train', 'test'] + tasks = ('multi-class', 'multi-label') + splits = ('train', 'test') def __init__( self, - root: str = 'data', + root: Path = 'data', task: str = 'multi-class', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -252,7 +252,7 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: return sample - def _load_files(self, root: str) -> list[dict[str, Any]]: + def _load_files(self, root: Path) -> list[dict[str, Any]]: """Return the paths of the files in the dataset. Args: @@ -295,7 +295,7 @@ def _load_files(self, root: str) -> list[dict[str, Any]]: return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -312,11 +312,7 @@ def _load_image(self, path: str) -> Tensor: return tensor def _verify(self) -> None: - """Checks the integrity of the dataset structure. - - Returns: - True if the dataset directories are found, else False - """ + """Verify the integrity of the dataset.""" filepath = os.path.join(self.root, self.split) if os.path.isdir(filepath): return diff --git a/torchgeo/datasets/mmearth.py b/torchgeo/datasets/mmearth.py new file mode 100644 index 00000000000..f363276c40a --- /dev/null +++ b/torchgeo/datasets/mmearth.py @@ -0,0 +1,620 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""MMEarth Dataset.""" + +import json +import os +from collections.abc import Callable, Sequence +from datetime import datetime, timedelta +from typing import Any, ClassVar, cast + +import numpy as np +import torch +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, lazy_import + + +class MMEarth(NonGeoDataset): + """MMEarth dataset. + + There are three different versions of the dataset, that vary in image size + and the number of tiles: + + * MMEarth: 128x128 px, 1.2M tiles, 579 GB + * MMEarth64: 64x64 px, 1.2M tiles, 162 GB + * MMEarth100k: 128x128 px, 100K tiles, 48 GB + + The dataset consists of 12 modalities: + + * Aster: elevation and slope + * Biome: 14 terrestrial ecosystem categories + * ETH Canopy Height: Canopy height and standard deviation + * Dynamic World: 9 landcover categories + * Ecoregion: 846 ecoregion categories + * ERA5: Climate reanalysis data for temperature mean, min, and max of [year, month, previous month] + and precipitation total of [year, month, previous month] (counted as separate modalities) + * ESA World Cover: 11 landcover categories + * Sentinel-1: VV, VH, HV, HH for ascending/descending orbit + * Sentinel-2: multi-spectral B1-B12 for L1C/L2A products + * Geolocation: cyclic encoding of latitude and longitude + * Date: cyclic encoding of month + + Additionally, there are three masks available as modalities: + + * Sentinel-2 Cloudmask: Sentinel-2 cloud mask + * Sentinel-2 Cloud probability: Sentinel-2 cloud probability + * Sentinel-2 SCL: Sentinel-2 scene classification + + that are synchronized across tiles. + + Dataset format: + + * Dataset in single HDF5 file + * JSON files for band statistics, splits, and tile information + + For additional information, as well as bash scripts to + download the data, please refer to the + `official repository `_. + + If you use this dataset in your research, please cite the following paper: + + * https://arxiv.org/abs/2405.02771 + + .. note:: + + This dataset requires the following additional library to be installed: + + * `h5py `_ to load the dataset + + .. versionadded:: 0.7 + """ + + subsets = ('MMEarth', 'MMEarth64', 'MMEarth100k') + + filenames: ClassVar[dict[str, str]] = { + 'MMEarth': 'data_1M_v001', + 'MMEarth64': 'data_1M_v001_64', + 'MMEarth100k': 'data_100k_v001', + } + + all_modalities = ( + 'aster', + 'biome', + 'canopy_height_eth', + 'dynamic_world', + 'eco_region', + 'era5', + 'esa_worldcover', + 'sentinel1_asc', + 'sentinel1_desc', + 'sentinel2', + 'sentinel2_cloudmask', + 'sentinel2_cloudprod', + 'sentinel2_scl', + ) + + # See https://github.com/vishalned/MMEarth-train/blob/8d6114e8e3ccb5ca5d98858e742dac24350b64fd/MODALITIES.py#L108C1-L160C2 + all_modality_bands: ClassVar[dict[str, list[str]]] = { + 'sentinel2': [ + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8A', + 'B8', + 'B9', + 'B10', + 'B11', + 'B12', + ], + 'sentinel2_cloudmask': ['QA60'], + 'sentinel2_cloudprod': ['MSK_CLDPRB'], + 'sentinel2_scl': ['SCL'], + 'sentinel1_asc': ['VV', 'VH', 'HH', 'HV'], + 'sentinel1_desc': ['VV', 'VH', 'HH', 'HV'], + 'aster': ['b1', 'slope'], # elevation and slope + 'era5': [ + 'prev_temperature_2m', # previous month avg temp + 'prev_temperature_2m_min', # previous month min temp + 'prev_temperature_2m_max', # previous month max temp + 'prev_total_precipitation_sum', # previous month total precip + 'curr_temperature_2m', # current month avg temp + 'curr_temperature_2m_min', # current month min temp + 'curr_temperature_2m_max', # current month max temp + 'curr_total_precipitation_sum', # current month total precip + '0_temperature_2m_mean', # year avg temp + '1_temperature_2m_min_min', # year min temp + '2_temperature_2m_max_max', # year max temp + '3_total_precipitation_sum_sum', # year total precip + ], + 'dynamic_world': ['label'], + 'canopy_height_eth': ['height', 'std'], + 'lat': ['sin', 'cos'], + 'lon': ['sin', 'cos'], + 'biome': ['biome'], + 'eco_region': ['eco_region'], + 'month': ['sin_month', 'cos_month'], + 'esa_worldcover': ['Map'], + } + + # See https://github.com/vishalned/MMEarth-train/blob/8d6114e8e3ccb5ca5d98858e742dac24350b64fd/MODALITIES.py#L36 + no_data_vals: ClassVar[dict[str, int | float]] = { + 'sentinel2': 0, + 'sentinel2_cloudmask': 65535, + 'sentinel2_cloudprod': 65535, + 'sentinel2_scl': 255, + 'sentinel1_asc': float('-inf'), + 'sentinel1_desc': float('-inf'), + 'aster': float('-inf'), + 'canopy_height_eth': 255, + 'dynamic_world': 0, + 'esa_worldcover': 255, + 'lat': float('-inf'), + 'lon': float('-inf'), + 'month': float('-inf'), + 'era5': float('inf'), + 'biome': 255, + 'eco_region': 65535, + } + + norm_modes = ('z-score', 'min-max') + + modality_category_name: ClassVar[dict[str, str]] = { + 'sentinel1_asc': 'image_', + 'sentinel1_desc': 'image_', + 'sentinel2': 'image_', + 'sentinel2_cloudmask': 'mask_', + 'sentinel2_cloudprod': 'mask_', + 'sentinel2_scl': 'mask_', + 'aster': 'image_', + 'era5': '', + 'canopy_height_eth': 'image_', + 'dynamic_world': 'mask_', + 'esa_worldcover': 'mask_', + } + + def __init__( + self, + root: Path = 'data', + subset: str = 'MMEarth', + modalities: Sequence[str] = all_modalities, + modality_bands: dict[str, list[str]] | None = None, + normalization_mode: str = 'z-score', + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + ) -> None: + """Initialize the MMEarth dataset. + + Args: + root: root directory where dataset can be found + subset: one of "MMEarth", "MMEarth64", or "MMEarth100k" + modalities: list of modalities to load + modality_bands: dictionary of modality bands, see + normalization_mode: one of "z-score" or "min-max" + transforms: a function/transform that takes input sample dictionary + and returns a transformed version + + Raises: + AssertionError: if *normalization_mode* or *subset* + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + lazy_import('h5py') + + assert ( + normalization_mode in self.norm_modes + ), f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}' + assert ( + subset in self.subsets + ), f'Invalid dataset version: {subset}, please choose from {self.subsets}' + + self._validate_modalities(modalities) + self.modalities = modalities + if modality_bands is None: + modality_bands = { + modality: self.all_modality_bands[modality] for modality in modalities + } + self._validate_modality_bands(modality_bands) + self.modality_bands = modality_bands + + self.root = root + self.subset = subset + self.normalization_mode = normalization_mode + self.split = 'train' + self.transforms = transforms + + self.dataset_filename = f'{self.filenames[subset]}.h5' + self.band_stats_filename = f'{self.filenames[subset]}_band_stats.json' + self.splits_filename = f'{self.filenames[subset]}_splits.json' + self.tile_info_filename = f'{self.filenames[subset]}_tile_info.json' + + self._verify() + + self.indices = self._load_indices() + self.band_stats = self._load_normalization_stats() + self.tile_info = self._load_tile_info() + + def _verify(self) -> None: + """Verify the dataset.""" + data_dir = os.path.join(self.root, self.filenames[self.subset]) + + exists = [ + os.path.exists(os.path.join(data_dir, f)) + for f in [ + self.dataset_filename, + self.band_stats_filename, + self.splits_filename, + self.tile_info_filename, + ] + ] + if not all(exists): + raise DatasetNotFoundError(self) + + def _load_indices(self) -> list[int]: + """Load the indices for the dataset split. + + Returns: + list of indices + """ + with open( + os.path.join(self.root, self.filenames[self.subset], self.splits_filename) + ) as f: + split_indices: dict[str, list[int]] = json.load(f) + + return split_indices[self.split] + + def _load_normalization_stats(self) -> dict[str, dict[str, float]]: + """Load normalization statistics for each band. + + Returns: + dictionary containing the normalization statistics + """ + with open( + os.path.join( + self.root, self.filenames[self.subset], self.band_stats_filename + ) + ) as f: + band_stats = json.load(f) + + return cast(dict[str, dict[str, float]], band_stats) + + def _load_tile_info(self) -> dict[str, dict[str, str]]: + """Load tile information. + + Returns: + dictionary containing tile information + """ + with open( + os.path.join( + self.root, self.filenames[self.subset], self.tile_info_filename + ) + ) as f: + tile_info = json.load(f) + + return cast(dict[str, dict[str, str]], tile_info) + + def _validate_modalities(self, modalities: Sequence[str]) -> None: + """Validate list of modalities. + + Args: + modalities: user-provided sequence of modalities to load + + Raises: + AssertionError: if ``modalities`` is not a sequence or an + invalid modality name is provided + """ + # validate modalities + assert isinstance(modalities, Sequence), "'modalities' must be a sequence" + if not set(modalities) <= set(self.all_modalities): + raise ValueError( + f'{set(modalities) - set(self.all_modalities)} is an invalid modality.' + ) + + def _validate_modality_bands(self, modality_bands: dict[str, list[str]]) -> None: + """Validate modality bands. + + Args: + modality_bands: user-provided dictionary of modality bands + + Raises: + AssertionError: if ``modality_bands`` is not a dictionary + ValueError: if an invalid modality name is provided + ValueError: if modality bands are invalid + """ + assert isinstance(modality_bands, dict), "'modality_bands' must be a dictionary" + # validate modality bands + for key, vals in modality_bands.items(): + # check that the modality name is also specified in modalities + if key not in self.modalities: + raise ValueError(f"'{key}' is an invalid modality name.") + for val in vals: + if val not in self.all_modality_bands[key]: + raise ValueError( + f"'{val}' is an invalid band name for modality '{key}'." + ) + + def __getitem__(self, index: int) -> dict[str, Any]: + """Return a sample from the dataset. + + Normalization is applied to the data with chosen ``normalization_mode``. + In addition to the modalities, the sample contains the following raw metadata: + + * lat: latitude + * lon: longitude + * date: date + * crs: coordinate reference system + * tile_id: tile identifier + + Args: + index: index to return + + Returns: + dictionary containing the modalities and metadata + of the sample + """ + ds_index = self.indices[index] + + # expose sample retrieval to separate function to allow for different index sampling strategies + # in subclasses + sample = self._retrieve_sample(ds_index) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def get_sample_specific_band_names( + self, tile_info: dict[str, Any] + ) -> dict[str, list[str]]: + """Retrieve the sample specific band names. + + Args: + tile_info: tile information for a sample + + Returns: + dictionary containing the specific band names for each modality + """ + date_str = tile_info['S2_DATE'] + date_obj = datetime.strptime(date_str, '%Y-%m-%d') + curr_month_str = date_obj.strftime('%Y%m') + # set to first day of month and subtract one day to get previous month + prev_month_obj = date_obj.replace(day=1) - timedelta(days=1) + prev_month_str = prev_month_obj.strftime('%Y%m') + + specific_modality_bands = {} + for modality, bands in self.modality_bands.items(): + if modality == 'era5': + # replace date with the 'prev' and 'curr' strings for generality + bands = [band.replace(prev_month_str, 'prev') for band in bands] + bands = [band.replace(curr_month_str, 'curr') for band in bands] + specific_modality_bands[modality] = bands + + return specific_modality_bands + + def get_intersection_dict(self, tile_info: dict[str, Any]) -> dict[str, list[str]]: + """Get intersection of requested and available bands. + + Args: + tile_info: tile information for a sample + + Returns: + Dictionary with intersected keys and lists. + """ + sample_specific_band_names = self.get_sample_specific_band_names(tile_info) + # used the chosen modality bands to get the intersection with available bands + intersection_dict = {} + for modality in self.all_modalities: + if modality in sample_specific_band_names: + intersected_list = [ + band + for band in self.all_modality_bands[modality] + if band in sample_specific_band_names[modality] + ] + if intersected_list: + intersection_dict[modality] = intersected_list + + return intersection_dict + + def _retrieve_sample(self, ds_index: int) -> dict[str, Any]: + """Retrieve a sample from the dataset. + + Args: + ds_index: index inside the hdf5 dataset file + + Returns: + dictionary containing the modalities and metadata + of the sample + """ + h5py = lazy_import('h5py') + sample: dict[str, Any] = {} + with h5py.File( + os.path.join(self.root, self.filenames[self.subset], self.dataset_filename), + 'r', + ) as f: + name = f['metadata'][ds_index][0].decode('utf-8') + tile_info: dict[str, Any] = self.tile_info[name] + # need to find the intersection of requested and available bands + intersection_dict = self.get_intersection_dict(tile_info) + for modality, bands in intersection_dict.items(): + if 'sentinel1' in modality: + data = f['sentinel1'][ds_index][:] + else: + data = f[modality][ds_index][:] + + tensor = self._preprocess_modality(data, modality, tile_info, bands) + modality_name = self.modality_category_name.get(modality, '') + modality + sample[modality_name] = tensor + + # add the sensor and bands actually available + sample['avail_bands'] = intersection_dict + + # add additional metadata to the sample + sample['lat'] = tile_info['lat'] + sample['lon'] = tile_info['lon'] + sample['date'] = tile_info['S2_DATE'] + sample['crs'] = tile_info['CRS'] + sample['tile_id'] = name + + return sample + + def _select_indices_for_modality( + self, modality: str, bands: list[str] + ) -> list[int]: + """Select bands for a modality. + + Args: + modality: modality name + bands: bands aviailable for the modality + + Returns: + list of band indices + """ + # need to handle sentinel1 descending separately, because ascending + # and descending are stored under the same modality + if modality == 'sentinel1_desc': + indices = [ + self.all_modality_bands['sentinel1_desc'].index(band) + 4 + for band in bands + ] + # the modality is called sentinel2 but has different bands stats for l1c and l2a + # but common indices + elif modality in ['sentinel2_l1c', 'sentinel2_l2a']: + indices = [ + self.all_modality_bands['sentinel2'].index(band) for band in bands + ] + else: + indices = [self.all_modality_bands[modality].index(band) for band in bands] + return indices + + def _preprocess_modality( + self, + data: 'np.typing.NDArray[Any]', + modality: str, + tile_info: dict[str, Any], + bands: list[str], + ) -> Tensor: + """Preprocess a single modality. + + Args: + data: data to process + modality: modality name + tile_info: tile information + bands: available bands for the modality + + Returns: + processed data + """ + # band selection for modality + indices = self._select_indices_for_modality(modality, bands) + data = data[indices, ...] + + # See https://github.com/vishalned/MMEarth-train/blob/8d6114e8e3ccb5ca5d98858e742dac24350b64fd/mmearth_dataset.py#L69 + if modality == 'dynamic_world': + # first replace 0 with nan then assign new labels to have 0-index classes + data = np.where(data == self.no_data_vals[modality], np.nan, data) + old_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, np.nan] + new_values = [0, 1, 2, 3, 4, 5, 6, 7, 8, np.nan] + for old, new in zip(old_values, new_values): + data = np.where(data == old, new, data) + + # need to replace nan with a no-data value and get long tensor + # maybe also 255 like esa_worldcover + tensor = torch.from_numpy(data) + + elif modality == 'esa_worldcover': + old_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100, 255] + new_values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 255] + for old, new in zip(old_values, new_values): + data = np.where(data == old, new, data) + + # currently no-data value is still 255 + tensor = torch.from_numpy(data).long() + + elif modality in [ + 'aster', + 'canopy_height_eth', + 'sentinel1_asc', + 'sentinel1_desc', + 'sentinel2', + 'era5', + 'lat', + 'lon', + 'month', + ]: + data = data.astype(np.float32) + # See https://github.com/vishalned/MMEarth-train/blob/8d6114e8e3ccb5ca5d98858e742dac24350b64fd/mmearth_dataset.py#L88 + # the modality is called sentinel2 but has different bands stats for l1c and l2a + if modality == 'sentinel2': + modality_ = ( + 'sentinel2_l2a' + if tile_info['S2_type'] == 'l2a' + else 'sentinel2_l1c' + ) + else: + modality_ = modality + data = self._normalize_modality(data, modality_, bands) + data = np.where(data == self.no_data_vals[modality], np.nan, data) + tensor = torch.from_numpy(data).float() + elif modality in ['biome', 'eco_region']: + data = data.astype(np.int32) + # no data value also 255 for biome and 65535 for eco_region + tensor = torch.from_numpy(data).long() + elif modality in [ + 'sentinel2_cloudmask', + 'sentinel2_cloudprod', + 'sentinel2_scl', + ]: + tensor = torch.from_numpy(data.astype(np.int32)).long() + + # TODO: tensor might still contain nans, how to handle this? + return tensor + + def _normalize_modality( + self, data: 'np.typing.NDArray[Any]', modality: str, bands: list[str] + ) -> 'np.typing.NDArray[np.float64]': + """Normalize a single modality. + + Args: + data: data to normalize + modality: modality name + bands: available bands for the modality + + Returns: + normalized data + """ + indices = self._select_indices_for_modality(modality, bands) + + if 'sentinel1' in modality: + modality = 'sentinel1' + + if self.normalization_mode == 'z-score': + mean = np.array(self.band_stats[modality]['mean'])[indices, ...] + std = np.array(self.band_stats[modality]['std'])[indices, ...] + if data.ndim == 3: + data = (data - mean[:, None, None]) / std[:, None, None] + else: + data = (data - mean) / std + elif self.normalization_mode == 'min-max': + min_val = np.array(self.band_stats[modality]['min'])[indices, ...] + max_val = np.array(self.band_stats[modality]['max'])[indices, ...] + if data.ndim == 3: + data = (data - min_val[:, None, None]) / ( + max_val[:, None, None] - min_val[:, None, None] + ) + else: + data = (data - min_val) / (max_val - min_val) + + return data + + def __len__(self) -> int: + """Return the length of the dataset. + + Returns: + length of the dataset + """ + return len(self.indices) diff --git a/torchgeo/datasets/naip.py b/torchgeo/datasets/naip.py index d8185782367..326dccd6d72 100644 --- a/torchgeo/datasets/naip.py +++ b/torchgeo/datasets/naip.py @@ -45,8 +45,8 @@ class NAIP(RasterDataset): """ # Plotting - all_bands = ['R', 'G', 'B', 'NIR'] - rgb_bands = ['R', 'G', 'B'] + all_bands = ('R', 'G', 'B', 'NIR') + rgb_bands = ('R', 'G', 'B') def plot( self, diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 6d7a2dcdaa5..9c57e290406 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -3,6 +3,7 @@ """NASA Marine Debris dataset.""" +import glob import os from collections.abc import Callable @@ -16,13 +17,13 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import Path, which class NASAMarineDebris(NonGeoDataset): """NASA Marine Debris dataset. - The `NASA Marine Debris `__ + The `NASA Marine Debris `__ dataset is a dataset for detection of floating marine debris in satellite imagery. Dataset features: @@ -47,26 +48,19 @@ class NASAMarineDebris(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionadded:: 0.2 """ - collection_ids = ['nasa_marine_debris_source', 'nasa_marine_debris_labels'] - directories = ['nasa_marine_debris_source', 'nasa_marine_debris_labels'] - filenames = ['nasa_marine_debris_source.tar.gz', 'nasa_marine_debris_labels.tar.gz'] - md5s = ['fe8698d1e68b3f24f0b86b04419a797d', 'd8084f5a72778349e07ac90ec1e1d990'] - class_label = 'marine_debris' + url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa-marine-debris' def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, - verbose: bool = False, ) -> None: """Initialize a new NASA Marine Debris Dataset instance. @@ -75,9 +69,6 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - verbose: if True, print messages when new tiles are loaded Raises: DatasetNotFoundError: If dataset is not found and *download* is False. @@ -85,11 +76,11 @@ def __init__( self.root = root self.transforms = transforms self.download = download - self.api_key = api_key - self.checksum = checksum - self.verbose = verbose + self._verify() - self.files = self._load_files() + + self.source = sorted(glob.glob(os.path.join(self.root, 'source', '*.tif'))) + self.labels = sorted(glob.glob(os.path.join(self.root, 'labels', '*.npy'))) def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. @@ -100,15 +91,21 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and labels at that index """ - image = self._load_image(self.files[index]['image']) - boxes = self._load_target(self.files[index]['target']) - sample = {'image': image, 'boxes': boxes} + with rasterio.open(self.source[index]) as source: + image = torch.from_numpy(source.read()).float() + + labels = np.load(self.labels[index]) + + # Boxes contain unnecessary value of 1 after xyxy coords + boxes = torch.from_numpy(labels[:, :4]) # Filter invalid boxes - w_check = (sample['boxes'][:, 2] - sample['boxes'][:, 0]) > 0 - h_check = (sample['boxes'][:, 3] - sample['boxes'][:, 1]) > 0 + w_check = (boxes[:, 2] - boxes[:, 0]) > 0 + h_check = (boxes[:, 3] - boxes[:, 1]) > 0 indices = w_check & h_check - sample['boxes'] = sample['boxes'][indices] + boxes = boxes[indices] + + sample = {'image': image, 'boxes': boxes} if self.transforms is not None: sample = self.transforms(sample) @@ -121,85 +118,13 @@ def __len__(self) -> int: Returns: length of the dataset """ - return len(self.files) - - def _load_image(self, path: str) -> Tensor: - """Load a single image. - - Args: - path: path to the image - - Returns: - the image - """ - with rasterio.open(path) as f: - array = f.read() - tensor = torch.from_numpy(array).float() - return tensor - - def _load_target(self, path: str) -> Tensor: - """Load the target bounding boxes for a single image. - - Args: - path: path to the labels - - Returns: - the target boxes - """ - array = np.load(path) - # boxes contain unecessary value of 1 after xyxy coords - array = array[:, :4] - tensor = torch.from_numpy(array) - return tensor - - def _load_files(self) -> list[dict[str, str]]: - """Load a image and label files. - - Returns: - list of dicts containing image and label files - """ - image_root = os.path.join(self.root, self.directories[0]) - target_root = os.path.join(self.root, self.directories[1]) - image_folders = sorted( - f for f in os.listdir(image_root) if not f.endswith('json') - ) - - files = [] - for folder in image_folders: - files.append( - { - 'image': os.path.join(image_root, folder, 'image_geotiff.tif'), - 'target': os.path.join( - target_root, - folder.replace('source', 'labels'), - 'pixel_bounds.npy', - ), - } - ) - return files + return len(self.source) def _verify(self) -> None: """Verify the integrity of the dataset.""" - # Check if the files already exist - exists = [ - os.path.exists(os.path.join(self.root, directory)) - for directory in self.directories - ] - if all(exists): - return - - # Check if zip file already exists (if so then extract) - exists = [] - for filename, md5 in zip(self.filenames, self.md5s): - filepath = os.path.join(self.root, filename) - if os.path.exists(filepath): - if self.checksum and not check_integrity(filepath, md5): - raise RuntimeError('Dataset checksum mismatch.') - exists.append(True) - extract_archive(filepath) - else: - exists.append(False) - + # Check if the directories already exist + dirs = ['source', 'labels'] + exists = [os.path.exists(os.path.join(self.root, d)) for d in dirs] if all(exists): return @@ -207,14 +132,14 @@ def _verify(self) -> None: if not self.download: raise DatasetNotFoundError(self) - # Download and extract the dataset - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, self.api_key) - for filename, md5 in zip(self.filenames, self.md5s): - filepath = os.path.join(self.root, filename) - if self.checksum and not check_integrity(filepath, md5): - raise RuntimeError('Dataset checksum mismatch.') - extract_archive(filepath) + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') def plot( self, @@ -238,25 +163,25 @@ def plot( image = sample['image'] if 'boxes' in sample and len(sample['boxes']): image = draw_bounding_boxes(image=sample['image'], boxes=sample['boxes']) - image = image.permute((1, 2, 0)).numpy() + image_arr = image.permute((1, 2, 0)).numpy() if 'prediction_boxes' in sample and len(sample['prediction_boxes']): ncols += 1 preds = draw_bounding_boxes( image=sample['image'], boxes=sample['prediction_boxes'] ) - preds = preds.permute((1, 2, 0)).numpy() + preds_arr = preds.permute((1, 2, 0)).numpy() fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) if ncols < 2: - axs.imshow(image) + axs.imshow(image_arr) axs.axis('off') if show_titles: axs.set_title('Ground Truth') else: - axs[0].imshow(image) + axs[0].imshow(image_arr) axs[0].axis('off') - axs[1].imshow(preds) + axs[1].imshow(preds_arr) axs[1].axis('off') if show_titles: diff --git a/torchgeo/datasets/nccm.py b/torchgeo/datasets/nccm.py index 68e0566e28a..96633b2e35b 100644 --- a/torchgeo/datasets/nccm.py +++ b/torchgeo/datasets/nccm.py @@ -4,7 +4,7 @@ """Northeastern China Crop Map Dataset.""" from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import torch @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, download_url +from .utils import BoundingBox, Path, download_url class NCCM(RasterDataset): @@ -57,23 +57,23 @@ class NCCM(RasterDataset): date_format = '%Y' is_image = False - urls = { + urls: ClassVar[dict[int, str]] = { 2019: 'https://figshare.com/ndownloader/files/25070540', 2018: 'https://figshare.com/ndownloader/files/25070624', 2017: 'https://figshare.com/ndownloader/files/25070582', } - md5s = { + md5s: ClassVar[dict[int, str]] = { 2019: '0d062bbd42e483fdc8239d22dba7020f', 2018: 'b3bb4894478d10786aa798fb11693ec1', 2017: 'd047fbe4a85341fa6248fd7e0badab6c', } - fnames = { + fnames: ClassVar[dict[int, str]] = { 2019: 'CDL2019_clip.tif', 2018: 'CDL2018_clip1.tif', 2017: 'CDL2017_clip.tif', } - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 255, 0, 255), 1: (255, 0, 0, 255), 2: (255, 255, 0, 255), @@ -83,7 +83,7 @@ class NCCM(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2019], diff --git a/torchgeo/datasets/nlcd.py b/torchgeo/datasets/nlcd.py index 13c6883801d..501fd6db8f9 100644 --- a/torchgeo/datasets/nlcd.py +++ b/torchgeo/datasets/nlcd.py @@ -3,10 +3,9 @@ """NLCD dataset.""" -import glob import os from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import torch @@ -15,20 +14,18 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import BoundingBox, download_url, extract_archive +from .utils import BoundingBox, Path, download_url class NLCD(RasterDataset): - """National Land Cover Database (NLCD) dataset. + """Annual National Land Cover Database (NLCD) dataset. - The `NLCD dataset - `_ - is a land cover product that covers the United States and Puerto Rico. The current - implementation supports maps for the continental United States only. The product is - a joint effort between the United States Geological Survey + The `Annual NLCD products + `_ + is an annual land cover product for the conterminous U.S. initially covering the period + from 1985 to 2023. The product is a joint effort between the United States Geological Survey (`USGS `_) and the Multi-Resolution Land Characteristics - Consortium (`MRLC `_) which released the first product - in 2001 with new updates every five years since then. + Consortium (`MRLC `_). The dataset contains the following 17 classes: @@ -57,36 +54,63 @@ class NLCD(RasterDataset): * single channel .img file with integer class labels - If you use this dataset in your research, please use the corresponding citation: + If you use this dataset in your research, please cite the following paper: - * 2001: https://doi.org/10.5066/P9MZGHLF - * 2006: https://doi.org/10.5066/P9HBR9V3 - * 2011: https://doi.org/10.5066/P97S2IID - * 2016: https://doi.org/10.5066/P96HHBIE - * 2019: https://doi.org/10.5066/P9KZCM54 + * https://doi.org/10.5066/P94UXNTS .. versionadded:: 0.5 - """ # noqa: E501 + """ - filename_glob = 'nlcd_*_land_cover_l48_*.img' - filename_regex = ( - r'nlcd_(?P\d{4})_land_cover_l48_(?P\d{8})\.img' - ) - zipfile_glob = 'nlcd_*_land_cover_l48_*.zip' + filename_glob = 'Annual_NLCD_LndCov_*_CU_C1V0.tif' + filename_regex = r'Annual_NLCD_LndCov_(?P\d{4})_CU_C1V0\.tif' date_format = '%Y' is_image = False - url = 'https://s3-us-west-2.amazonaws.com/mrlc/nlcd_{}_land_cover_l48_20210604.zip' - - md5s = { - 2001: '538166a4d783204764e3df3b221fc4cd', - 2006: '67454e7874a00294adb9442374d0c309', - 2011: 'ea524c835d173658eeb6fa3c8e6b917b', - 2016: '452726f6e3bd3f70d8ca2476723d238a', - 2019: '82851c3f8105763b01c83b4a9e6f3961', + url = 'https://s3-us-west-2.amazonaws.com/mrlc/Annual_NLCD_LndCov_{}_CU_C1V0.tif' + + md5s: ClassVar[dict[int, str]] = { + 1985: 'a2e1c5f0b34e9b15a63a9dc10e8d3ec2', + 1986: 'da1d08ca51ac43abc14711c8d6139f1d', + 1987: '2cb85e8f077c227605cd7bac62a72a75', + 1988: 'b20fb987cc30926d2d125d045e02626d', + 1989: 'dbe851cbea34d0a57c2a94eb745a1267', + 1990: '1927e0e040b9ff513ff039749b64919b', + 1991: 'eca73474843d6c58693eba62d70e507c', + 1992: '8beda41ba79000f55a8e9358ba3fa5a4', + 1993: '1a023552967cdac1111e9968ea62c879', + 1994: 'acc30ce4f6cdd78af5f7887d17ac4de3', + 1995: 'f728e8fc231b2e8e74a14201f500543a', + 1996: 'd2580904244f89b20d6258150fbf4161', + 1997: 'fec4e08032e162f2cc7dbe019d042609', + 1998: '87ea19434de96ea99cd5d7991042816c', + 1999: 'd4133737f20e75f3bd3a5baa32a668da', + 2000: 'e20b61bb2e7f4034a33c9fd536798a01', + 2001: 'b1f46ace9aedd17a89efab489cb67bc3', + 2002: '57bf60d7cd473096af3bb125391bde63', + 2003: '5e346854da9abf739152e85fee4c7aff', + 2004: '13136f271f53a454358eb7ec12bda686', + 2005: 'f00b66b57a23eb49a077e88704964a91', + 2006: '074ba90de5e62a37a5f001b7572f6baa', + 2007: 'cdef29a191cf165baaae80857ce5a980', + 2008: 'da907c76a1f12739333148504fd111c9', + 2009: '47890b306b875e681990b3db0c709da3', + 2010: '9a81f405f9e2f45d581078afd53c2d4b', + 2011: '13f4ef40b204aa1108dc0599d9546701', + 2012: '66b33146f9a9d9491be10c59c51e3e33', + 2013: 'f8d230f7dea493c47fbc74984ff856cc', + 2014: '68eb07ce86c1f7c2546ec43c2f9f7029', + 2015: 'f5a1b59fe54a70752f544c06cb965be4', + 2016: 'f0c2e74824fc281a57821e28e2c7fe6e', + 2017: 'a0aa8be0ed7d637f0f88f26d3742b20e', + 2018: 'a01f31547837ff1dfec1aba07b89bbec', + 2019: 'fa738201cddc1393dac4383b6ce2561a', + 2020: 'aa8f51690c7b01f3b3b413be9a7c36d6', + 2021: '47fc1794a64704a918b6ad586df4267c', + 2022: '11359748229e138cde971947864104a4', + 2023: '498ff8a512d32fe905720796fdb7fd52', } - cmap = { + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 0), 11: (70, 107, 159, 255), 12: (209, 222, 248, 255), @@ -108,10 +132,10 @@ class NLCD(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, - years: list[int] = [2019], + years: list[int] = [2023], classes: list[int] = list(cmap.keys()), transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, @@ -183,19 +207,14 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" - # Check if the extracted files already exist - if self.files: - return - - # Check if the zip files have already been downloaded + # Check if the TIFF files for the specified years have already been downloaded exists = [] for year in self.years: - zipfile_year = self.zipfile_glob.replace('*', str(year), 1) - assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, '**', zipfile_year) - if glob.glob(pathname, recursive=True): + filename_year = self.filename_glob.replace('*', str(year), 1) + assert isinstance(self.paths, str | os.PathLike) + pathname = os.path.join(self.paths, filename_year) + if os.path.exists(pathname): exists.append(True) - self._extract() else: exists.append(False) @@ -208,7 +227,6 @@ def _verify(self) -> None: # Download the dataset self._download() - self._extract() def _download(self) -> None: """Download the dataset.""" @@ -219,14 +237,6 @@ def _download(self) -> None: md5=self.md5s[year] if self.checksum else None, ) - def _extract(self) -> None: - """Extract the dataset.""" - for year in self.years: - zipfile_name = self.zipfile_glob.replace('*', str(year), 1) - assert isinstance(self.paths, str) - pathname = os.path.join(self.paths, '**', zipfile_name) - extract_archive(glob.glob(pathname, recursive=True)[0], self.paths) - def plot( self, sample: dict[str, Any], diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index ec1650ed88f..292dc274c32 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -8,7 +8,7 @@ import os import sys from collections.abc import Callable, Iterable -from typing import Any, cast +from typing import Any, ClassVar, cast import fiona import fiona.transform @@ -24,14 +24,14 @@ from .errors import DatasetNotFoundError from .geo import VectorDataset -from .utils import BoundingBox, check_integrity +from .utils import BoundingBox, Path, check_integrity class OpenBuildings(VectorDataset): r"""Open Buildings dataset. The `Open Buildings - `__ dataset + `__ dataset consists of computer generated building detections across the African continent. Dataset features: @@ -47,10 +47,10 @@ class OpenBuildings(VectorDataset): * meta data geojson file The data can be downloaded from `here - `__. Additionally, the - `meta data geometry file - `_ also needs to be - placed in `root` as `tiles.geojson`. + `__. + Additionally, the `meta data geometry file + `_ + also needs to be placed in `root` as `tiles.geojson`. If you use this dataset in your research, please cite the following technical report: @@ -60,7 +60,7 @@ class OpenBuildings(VectorDataset): .. versionadded:: 0.3 """ - md5s = { + md5s: ClassVar[dict[str, str]] = { '025_buildings.csv.gz': '41db2572bfd08628d01475a2ee1a2f17', '04f_buildings.csv.gz': '3232c1c6d45c1543260b77e5689fc8b1', '05b_buildings.csv.gz': '4fc57c63bbbf9a21a3902da7adc3a670', @@ -207,7 +207,7 @@ class OpenBuildings(VectorDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 0.0001, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -241,7 +241,7 @@ def __init__( # Create an R-tree to index the dataset using the polygon centroid as bounds self.index = Index(interleaved=False, properties=Property(dimension=3)) - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) with open(os.path.join(self.paths, 'tiles.geojson')) as f: data = json.load(f) @@ -327,7 +327,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: else: masks = torch.zeros(size=(1, round(height), round(width))) - sample = {'mask': masks, 'crs': self.crs, 'bbox': query} + sample = {'mask': masks, 'crs': self.crs, 'bounds': query} if self.transforms is not None: sample = self.transforms(sample) @@ -397,7 +397,7 @@ def _wkt_fiona_geom_transform(self, x: str) -> dict[str, Any]: def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the zip files have already been downloaded and checksum - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) pathname = os.path.join(self.paths, self.zipfile_glob) i = 0 for zipfile in glob.iglob(pathname): diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index b2ad8aef275..28f7714a7c6 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -17,6 +18,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset from .utils import ( + Path, download_url, draw_semantic_segmentation_masks, extract_archive, @@ -49,7 +51,7 @@ class OSCD(NonGeoDataset): .. versionadded:: 0.2 """ - urls = { + urls: ClassVar[dict[str, str]] = { 'Onera Satellite Change Detection dataset - Images.zip': ( 'https://partage.imt.fr/index.php/s/gKRaWgRnLMfwMGo/download' ), @@ -60,7 +62,7 @@ class OSCD(NonGeoDataset): 'https://partage.imt.fr/index.php/s/gpStKn4Mpgfnr63/download' ), } - md5s = { + md5s: ClassVar[dict[str, str]] = { 'Onera Satellite Change Detection dataset - Images.zip': ( 'c50d4a2941da64e03a47ac4dec63d915' ), @@ -74,9 +76,9 @@ class OSCD(NonGeoDataset): zipfile_glob = '*Onera*.zip' filename_glob = '*Onera*' - splits = ['train', 'test'] + splits = ('train', 'test') - colormap = ['blue'] + colormap = ('blue',) all_bands = ( 'B01', @@ -98,7 +100,7 @@ class OSCD(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -207,7 +209,7 @@ def get_image_paths(ind: int) -> list[str]: return regions - def _load_image(self, paths: Sequence[str]) -> Tensor: + def _load_image(self, paths: Sequence[Path]) -> Tensor: """Load a single image. Args: @@ -224,7 +226,7 @@ def _load_image(self, paths: Sequence[str]) -> Tensor: tensor = torch.from_numpy(array).float() return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: @@ -318,7 +320,7 @@ def get_masked(img: Tensor) -> 'np.typing.NDArray[np.uint8]': torch.from_numpy(rgb_img), sample['mask'], alpha=alpha, - colors=self.colormap, + colors=list(self.colormap), ) return array diff --git a/torchgeo/datasets/pastis.py b/torchgeo/datasets/pastis.py index 0b7629bcec5..06f716a9ffb 100644 --- a/torchgeo/datasets/pastis.py +++ b/torchgeo/datasets/pastis.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable, Sequence +from typing import ClassVar import fiona import matplotlib.pyplot as plt @@ -16,7 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import Path, check_integrity, download_url, extract_archive class PASTIS(NonGeoDataset): @@ -70,7 +71,7 @@ class PASTIS(NonGeoDataset): .. versionadded:: 0.5 """ - classes = [ + classes = ( 'background', # all non-agricultural land 'meadow', 'soft_winter_wheat', @@ -91,8 +92,8 @@ class PASTIS(NonGeoDataset): 'mixed_cereal', 'sorghum', 'void_label', # for parcels mostly outside their patch - ] - cmap = { + ) + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 255), 1: (174, 199, 232, 255), 2: (255, 127, 14, 255), @@ -116,9 +117,9 @@ class PASTIS(NonGeoDataset): } directory = 'PASTIS-R' filename = 'PASTIS-R.zip' - url = 'https://zenodo.org/record/5735646/files/PASTIS-R.zip?download=1' + url = 'https://zenodo.org/records/5735646/files/PASTIS-R.zip?download=1' md5 = '4887513d6c2d2b07fa935d325bd53e09' - prefix = { + prefix: ClassVar[dict[str, str]] = { 's2': os.path.join('DATA_S2', 'S2_'), 's1a': os.path.join('DATA_S1A', 'S1A_'), 's1d': os.path.join('DATA_S1D', 'S1D_'), @@ -128,7 +129,7 @@ class PASTIS(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', folds: Sequence[int] = (1, 2, 3, 4, 5), bands: str = 's2', mode: str = 'semantic', @@ -232,7 +233,7 @@ def _load_semantic_targets(self, index: int) -> Tensor: Returns: the target mask """ - # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # noqa: E501 + # See https://github.com/VSainteuf/pastis-benchmark/blob/main/code/dataloader.py#L201 # even though the mask file is 3 bands, we just select the first band array = np.load(self.files[index]['semantic'])[0].astype(np.uint8) tensor = torch.from_numpy(array).long() diff --git a/torchgeo/datasets/patternnet.py b/torchgeo/datasets/patternnet.py index f34afa443df..a9385049872 100644 --- a/torchgeo/datasets/patternnet.py +++ b/torchgeo/datasets/patternnet.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class PatternNet(NonGeoClassificationDataset): @@ -78,14 +78,14 @@ class PatternNet(NonGeoClassificationDataset): * https://doi.org/10.1016/j.isprsjprs.2018.01.004 """ - url = 'https://drive.google.com/file/d/127lxXYqzO6Bd0yZhvEbgIfz95HaEnr9K' + url = 'https://hf.co/datasets/torchgeo/PatternNet/resolve/2dbd901b00e301967a5c5146b25454f5d3455ad0/PatternNet.zip' md5 = '96d54b3224c5350a98d55d5a7e6984ad' filename = 'PatternNet.zip' directory = os.path.join('PatternNet', 'images') def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, diff --git a/torchgeo/datasets/potsdam.py b/torchgeo/datasets/potsdam.py index 943a489217a..51f1ebd0441 100644 --- a/torchgeo/datasets/potsdam.py +++ b/torchgeo/datasets/potsdam.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -17,6 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -53,12 +55,12 @@ class Potsdam2D(NonGeoDataset): * https://doi.org/10.5194/isprsannals-I-3-293-2012 .. versionadded:: 0.2 - """ # noqa: E501 + """ - filenames = ['4_Ortho_RGBIR.zip', '5_Labels_all.zip'] - md5s = ['c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db'] + filenames = ('4_Ortho_RGBIR.zip', '5_Labels_all.zip') + md5s = ('c4a8f7d8c7196dd4eba4addd0aae10c1', 'cf7403c1a97c0d279414db') image_root = '4_Ortho_RGBIR' - splits = { + splits: ClassVar[dict[str, list[str]]] = { 'train': [ 'top_potsdam_2_10', 'top_potsdam_2_11', @@ -102,26 +104,26 @@ class Potsdam2D(NonGeoDataset): 'top_potsdam_7_13', ], } - classes = [ + classes = ( 'Clutter/background', 'Impervious surfaces', 'Building', 'Low Vegetation', 'Tree', 'Car', - ] - colormap = [ + ) + colormap = ( (255, 0, 0), (255, 255, 255), (0, 0, 255), (0, 255, 255), (0, 255, 0), (255, 255, 0), - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, @@ -256,7 +258,7 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap + sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap) ) if 'prediction' in sample: ncols += 1 @@ -264,7 +266,7 @@ def plot( sample['image'][:3], sample['prediction'], alpha=alpha, - colors=self.colormap, + colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index ce5d9a3bd2c..811d79cff08 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any, cast +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, lazy_import, percentile_normalization +from .utils import Path, download_url, lazy_import, percentile_normalization class QuakeSet(NonGeoDataset): @@ -61,12 +61,16 @@ class QuakeSet(NonGeoDataset): filename = 'earthquakes.h5' url = 'https://hf.co/datasets/DarthReca/quakeset/resolve/bead1d25fb9979dbf703f9ede3e8b349f73b29f7/earthquakes.h5' md5 = '76fc7c76b7ca56f4844d852e175e1560' - splits = {'train': 'train', 'val': 'validation', 'test': 'test'} - classes = ['unaffected_area', 'earthquake_affected_area'] + splits: ClassVar[dict[str, str]] = { + 'train': 'train', + 'val': 'validation', + 'test': 'test', + } + classes = ('unaffected_area', 'earthquake_affected_area') def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, diff --git a/torchgeo/datasets/reforestree.py b/torchgeo/datasets/reforestree.py index 28c1d0135f0..1c46c450191 100644 --- a/torchgeo/datasets/reforestree.py +++ b/torchgeo/datasets/reforestree.py @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_and_extract_archive, extract_archive +from .utils import Path, check_integrity, download_and_extract_archive, extract_archive class ReforesTree(NonGeoDataset): @@ -56,15 +56,15 @@ class ReforesTree(NonGeoDataset): .. versionadded:: 0.3 """ - classes = ['other', 'banana', 'cacao', 'citrus', 'fruit', 'timber'] - url = 'https://zenodo.org/record/6813783/files/reforesTree.zip?download=1' + classes = ('other', 'banana', 'cacao', 'citrus', 'fruit', 'timber') + url = 'https://zenodo.org/records/6813783/files/reforesTree.zip?download=1' md5 = 'f6a4a1d8207aeaa5fbab7b21b683a302' zipfilename = 'reforesTree.zip' def __init__( self, - root: str = 'data', + root: Path = 'data', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -124,7 +124,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str) -> list[str]: + def _load_files(self, root: Path) -> list[str]: """Return the paths of the files in the dataset. Args: @@ -137,7 +137,7 @@ def _load_files(self, root: str) -> list[str]: return image_paths - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -153,7 +153,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, filepath: str) -> tuple[Tensor, ...]: + def _load_target(self, filepath: Path) -> tuple[Tensor, ...]: """Load boxes and labels for a single image. Args: diff --git a/torchgeo/datasets/resisc45.py b/torchgeo/datasets/resisc45.py index cd5adff76c8..fd33b634fde 100644 --- a/torchgeo/datasets/resisc45.py +++ b/torchgeo/datasets/resisc45.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -14,7 +14,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class RESISC45(NonGeoClassificationDataset): @@ -93,18 +93,18 @@ class RESISC45(NonGeoClassificationDataset): * https://doi.org/10.1109/jproc.2017.2675998 """ - url = 'https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv' - md5 = 'd824acb73957502b00efd559fc6cfbbb' - filename = 'NWPU-RESISC45.rar' + url = 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/NWPU-RESISC45.zip' + md5 = '75206b2e16446591afa88e2628744886' + filename = 'NWPU-RESISC45.zip' directory = 'NWPU-RESISC45' - splits = ['train', 'val', 'test'] - split_urls = { - 'train': 'https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt', # noqa: E501 - 'val': 'https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt', # noqa: E501 - 'test': 'https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt', # noqa: E501 + splits = ('train', 'val', 'test') + split_urls: ClassVar[dict[str, str]] = { + 'train': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-train.txt', + 'val': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-val.txt', + 'test': 'https://hf.co/datasets/torchgeo/resisc45/resolve/a826b44d938a883185f11ebe3d512d38b464312f/resisc45-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': 'b5a4c05a37de15e4ca886696a85c403e', 'val': 'a0770cee4c5ca20b8c32bbd61e114805', 'test': '3dda9e4988b47eb1de9f07993653eb08', @@ -112,7 +112,7 @@ class RESISC45(NonGeoClassificationDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -142,7 +142,7 @@ def __init__( for fn in f: valid_fns.add(fn.strip()) - def is_in_split(x: str) -> bool: + def is_in_split(x: Path) -> bool: return os.path.basename(x) in valid_fns super().__init__( diff --git a/torchgeo/datasets/rwanda_field_boundary.py b/torchgeo/datasets/rwanda_field_boundary.py index 9439e525ab3..48968f007f5 100644 --- a/torchgeo/datasets/rwanda_field_boundary.py +++ b/torchgeo/datasets/rwanda_field_boundary.py @@ -3,24 +3,27 @@ """Rwanda Field Boundary Competition dataset.""" +import glob import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np import rasterio import rasterio.features import torch +from einops import rearrange from matplotlib.figure import Figure from torch import Tensor from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import Path, which class RwandaFieldBoundary(NonGeoDataset): - r"""Rwanda Field Boundary Competition dataset. + """Rwanda Field Boundary Competition dataset. This dataset contains field boundaries for smallholder farms in eastern Rwanda. The Nasa Harvest program funded a team of annotators from TaQadam to label Planet @@ -46,49 +49,27 @@ class RwandaFieldBoundary(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionadded:: 0.5 """ - dataset_id = 'nasa_rwanda_field_boundary_competition' - collection_ids = [ - 'nasa_rwanda_field_boundary_competition_source_train', - 'nasa_rwanda_field_boundary_competition_labels_train', - 'nasa_rwanda_field_boundary_competition_source_test', - ] - number_of_patches_per_split = {'train': 57, 'test': 13} - - filenames = { - 'train_images': 'nasa_rwanda_field_boundary_competition_source_train.tar.gz', - 'test_images': 'nasa_rwanda_field_boundary_competition_source_test.tar.gz', - 'train_labels': 'nasa_rwanda_field_boundary_competition_labels_train.tar.gz', - } - md5s = { - 'train_images': '1f9ec08038218e67e11f82a86849b333', - 'test_images': '17bb0e56eedde2e7a43c57aa908dc125', - 'train_labels': '10e4eb761523c57b6d3bdf9394004f5f', - } + url = 'https://radiantearth.blob.core.windows.net/mlhub/nasa_rwanda_field_boundary_competition' + splits: ClassVar[dict[str, int]] = {'train': 57, 'test': 13} dates = ('2021_03', '2021_04', '2021_08', '2021_10', '2021_11', '2021_12') - all_bands = ('B01', 'B02', 'B03', 'B04') rgb_bands = ('B03', 'B02', 'B01') - - classes = ['No field-boundary', 'Field-boundary'] - - splits = ['train', 'test'] + classes = ('No field-boundary', 'Field-boundary') def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = all_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: """Initialize a new RwandaFieldBoundary instance. @@ -99,49 +80,29 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: + AssertionError: If *split* or *bands* are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ - self._validate_bands(bands) assert split in self.splits - if download and api_key is None: - raise RuntimeError('Must provide an API key to download the dataset') + assert set(bands) <= set(self.all_bands) + self.root = root + self.split = split self.bands = bands self.transforms = transforms - self.split = split self.download = download - self.api_key = api_key - self.checksum = checksum + self._verify() - self.image_filenames: list[list[list[str]]] = [] - self.mask_filenames: list[str] = [] - for i in range(self.number_of_patches_per_split[split]): - dates = [] - for date in self.dates: - patch = [] - for band in self.bands: - fn = os.path.join( - self.root, - f'nasa_rwanda_field_boundary_competition_source_{split}', - f'nasa_rwanda_field_boundary_competition_source_{split}_{i:02d}_{date}', # noqa: E501 - f'{band}.tif', - ) - patch.append(fn) - dates.append(patch) - self.image_filenames.append(dates) - self.mask_filenames.append( - os.path.join( - self.root, - f'nasa_rwanda_field_boundary_competition_labels_{split}', - f'nasa_rwanda_field_boundary_competition_labels_{split}_{i:02d}', - 'raster_labels.tif', - ) - ) + def __len__(self) -> int: + """Return the number of chips in the dataset. + + Returns: + length of the dataset + """ + return self.splits[self.split] def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. @@ -150,83 +111,34 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: index: index to return Returns: - a dict containing image, mask, transform, crs, and metadata at index. + a dict containing image and mask at index. """ - img_fns = self.image_filenames[index] - mask_fn = self.mask_filenames[index] - - imgs = [] - for date_fns in img_fns: - bands = [] - for band_fn in date_fns: - with rasterio.open(band_fn) as f: - bands.append(f.read(1).astype(np.int32)) - imgs.append(bands) - img = torch.from_numpy(np.array(imgs)) - - sample = {'image': img} + images = [] + for date in self.dates: + patches = [] + for band in self.bands: + path = os.path.join(self.root, 'source', self.split, date) + with rasterio.open(os.path.join(path, f'{index:02}_{band}.tif')) as src: + patches.append(src.read(1).astype(np.float32)) + images.append(patches) + sample = {'image': torch.from_numpy(np.array(images))} if self.split == 'train': - with rasterio.open(mask_fn) as f: - mask = f.read(1) - mask = torch.from_numpy(mask) - sample['mask'] = mask + path = os.path.join(self.root, 'labels', self.split) + with rasterio.open(os.path.join(path, f'{index:02}.tif')) as src: + sample['mask'] = torch.from_numpy(src.read(1).astype(np.int64)) if self.transforms is not None: sample = self.transforms(sample) return sample - def __len__(self) -> int: - """Return the number of chips in the dataset. - - Returns: - length of the dataset - """ - return len(self.image_filenames) - - def _validate_bands(self, bands: Sequence[str]) -> None: - """Validate list of bands. - - Args: - bands: user-provided sequence of bands to load - - Raises: - ValueError: if an invalid band name is provided - """ - for band in bands: - if band not in self.all_bands: - raise ValueError(f"'{band}' is an invalid band name.") - def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the subdirectories already exist and have the correct number of files - checks = [] - for split, num_patches in self.number_of_patches_per_split.items(): - path = os.path.join( - self.root, f'nasa_rwanda_field_boundary_competition_source_{split}' - ) - if os.path.exists(path): - num_files = len(os.listdir(path)) - # 6 dates + 1 collection.json file - checks.append(num_files == (num_patches * 6) + 1) - else: - checks.append(False) - - if all(checks): - return - - # Check if tar file already exists (if so then extract) - have_all_files = True - for group in ['train_images', 'train_labels', 'test_images']: - filepath = os.path.join(self.root, self.filenames[group]) - if os.path.exists(filepath): - if self.checksum and not check_integrity(filepath, self.md5s[group]): - raise RuntimeError('Dataset found, but corrupted.') - extract_archive(filepath) - else: - have_all_files = False - if have_all_files: + path = os.path.join(self.root, 'source', self.split, '*', '*.tif') + expected = len(self.dates) * self.splits[self.split] * len(self.all_bands) + if len(glob.glob(path)) == expected: return # Check if the user requested to download the dataset @@ -237,15 +149,10 @@ def _verify(self) -> None: self._download() def _download(self) -> None: - """Download the dataset and extract it.""" - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, self.api_key) - - for group in ['train_images', 'train_labels', 'test_images']: - filepath = os.path.join(self.root, self.filenames[group]) - if self.checksum and not check_integrity(filepath, self.md5s[group]): - raise RuntimeError('Dataset not found or corrupted.') - extract_archive(filepath, self.root) + """Download the dataset.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') def plot( self, @@ -275,41 +182,33 @@ def plot( else: raise RGBBandsMissingError() - num_time_points = sample['image'].shape[0] - assert time_step < num_time_points - - image = np.rollaxis(sample['image'][time_step, rgb_indices].numpy(), 0, 3) - image = np.clip(image / 2000, 0, 1) - - if 'mask' in sample: - mask = sample['mask'].numpy() - else: - mask = np.zeros_like(image) - - num_panels = 2 - showing_predictions = 'prediction' in sample - if showing_predictions: - predictions = sample['prediction'].numpy() - num_panels += 1 + ncols = 1 + for key in ('mask', 'prediction'): + if key in sample: + ncols += 1 - fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4)) + fig, axs = plt.subplots(ncols=ncols, squeeze=False) - axs[0].imshow(image) - axs[0].axis('off') + image = torch.clamp(sample['image'][time_step, rgb_indices] / 2000, 0, 1) + image = rearrange(image, 'c h w -> h w c') + axs[0, 0].imshow(image) + axs[0, 0].axis('off') if show_titles: - axs[0].set_title(f't={time_step}') + axs[0, 0].set_title(f't={time_step}') - axs[1].imshow(mask, vmin=0, vmax=1, interpolation='none') - axs[1].axis('off') - if show_titles: - axs[1].set_title('Mask') + if 'mask' in sample: + axs[0, 1].imshow(sample['mask']) + axs[0, 1].axis('off') + if show_titles: + axs[0, 1].set_title('Mask') - if showing_predictions: - axs[2].imshow(predictions, vmin=0, vmax=1, interpolation='none') - axs[2].axis('off') + if 'prediction' in sample: + axs[0, 2].imshow(sample['prediction']) + axs[0, 2].axis('off') if show_titles: - axs[2].set_title('Predictions') + axs[0, 2].set_title('Prediction') if suptitle is not None: - plt.suptitle(suptitle) + fig.suptitle(suptitle) + return fig diff --git a/torchgeo/datasets/satlas.py b/torchgeo/datasets/satlas.py new file mode 100644 index 00000000000..f52a7fef98a --- /dev/null +++ b/torchgeo/datasets/satlas.py @@ -0,0 +1,770 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SatlasPretrain dataset.""" + +import os +from collections.abc import Callable, Iterable +from typing import ClassVar, TypedDict + +import numpy as np +import pandas as pd +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, check_integrity, extract_archive, which + + +class _Task(TypedDict, total=False): + BackgroundInvalid: bool + categories: list[str] + colors: list[list[int]] + type: str + + +# https://github.com/allenai/satlas/blob/main/satlas/model/dataset.py +TASKS: dict[str, _Task] = { + 'polyline_bin_segment': { + 'type': 'bin_segment', + 'categories': [ + 'airport_runway', + 'airport_taxiway', + 'raceway', + 'road', + 'railway', + 'river', + ], + 'colors': [ + [255, 255, 255], # (white) airport_runway + [192, 192, 192], # (light grey) airport_taxiway + [160, 82, 45], # (sienna) raceway + [255, 255, 255], # (white) road + [144, 238, 144], # (light green) railway + [0, 0, 255], # (blue) river + ], + }, + 'bin_segment': { + 'type': 'bin_segment', + 'categories': [ + 'aquafarm', + 'lock', + 'dam', + 'solar_farm', + 'power_plant', + 'gas_station', + 'park', + 'parking_garage', + 'parking_lot', + 'landfill', + 'quarry', + 'stadium', + 'airport', + 'airport_runway', + 'airport_taxiway', + 'airport_apron', + 'airport_hangar', + 'airstrip', + 'airport_terminal', + 'ski_resort', + 'theme_park', + 'storage_tank', + 'silo', + 'track', + 'raceway', + 'wastewater_plant', + 'road', + 'railway', + 'river', + 'water_park', + 'pier', + 'water_tower', + 'street_lamp', + 'traffic_signals', + 'power_tower', + 'power_substation', + 'building', + 'bridge', + 'road_motorway', + 'road_trunk', + 'road_primary', + 'road_secondary', + 'road_tertiary', + 'road_residential', + 'road_service', + 'road_track', + 'road_pedestrian', + ], + 'colors': [ + [32, 178, 170], # (light sea green) aquafarm + [0, 255, 255], # (cyan) lock + [173, 216, 230], # (light blue) dam + [255, 0, 255], # (magenta) solar farm + [255, 165, 0], # (orange) power plant + [128, 128, 0], # (olive) gas station + [0, 255, 0], # (green) park + [47, 79, 79], # (dark slate gray) parking garage + [128, 0, 0], # (maroon) parking lot + [165, 42, 42], # (brown) landfill + [128, 128, 128], # (grey) quarry + [255, 215, 0], # (gold) stadium + [255, 105, 180], # (pink) airport + [255, 255, 255], # (white) airport_runway + [192, 192, 192], # (light grey) airport_taxiway + [128, 0, 128], # (purple) airport_apron + [0, 128, 0], # (dark green) airport_hangar + [248, 248, 255], # (ghost white) airstrip + [240, 230, 140], # (khaki) airport_terminal + [192, 192, 192], # (silver) ski_resort + [0, 96, 0], # (dark green) theme_park + [95, 158, 160], # (cadet blue) storage_tank + [205, 133, 63], # (peru) silo + [154, 205, 50], # (yellow green) track + [160, 82, 45], # (sienna) raceway + [218, 112, 214], # (orchid) wastewater_plant + [255, 255, 255], # (white) road + [144, 238, 144], # (light green) railway + [0, 0, 255], # (blue) river + [255, 240, 245], # (lavender blush) water_park + [65, 105, 225], # (royal blue) pier + [238, 130, 238], # (violet) water_tower + [75, 0, 130], # (indigo) street_lamp + [233, 150, 122], # (dark salmon) traffic_signals + [255, 255, 0], # (yellow) power_tower + [255, 255, 0], # (yellow) power_substation + [255, 0, 0], # (red) building + [64, 64, 64], # (dark grey) bridge + [255, 255, 255], # (white) road_motorway + [255, 255, 255], # (white) road_trunk + [255, 255, 255], # (white) road_primary + [255, 255, 255], # (white) road_secondary + [255, 255, 255], # (white) road_tertiary + [255, 255, 255], # (white) road_residential + [255, 255, 255], # (white) road_service + [255, 255, 255], # (white) road_track + [255, 255, 255], # (white) road_pedestrian + ], + }, + 'land_cover': { + 'type': 'segment', + 'BackgroundInvalid': True, + 'categories': [ + 'background', + 'water', + 'developed', + 'tree', + 'shrub', + 'grass', + 'crop', + 'bare', + 'snow', + 'wetland', + 'mangroves', + 'moss', + ], + 'colors': [ + [0, 0, 0], # unknown + [0, 0, 255], # (blue) water + [255, 0, 0], # (red) developed + [0, 192, 0], # (dark green) tree + [200, 170, 120], # (brown) shrub + [0, 255, 0], # (green) grass + [255, 255, 0], # (yellow) crop + [128, 128, 128], # (grey) bare + [255, 255, 255], # (white) snow + [0, 255, 255], # (cyan) wetland + [255, 0, 255], # (pink) mangroves + [128, 0, 128], # (purple) moss + ], + }, + 'tree_cover': {'type': 'regress', 'BackgroundInvalid': True}, + 'crop_type': { + 'type': 'segment', + 'BackgroundInvalid': True, + 'categories': [ + 'invalid', + 'rice', + 'grape', + 'corn', + 'sugarcane', + 'tea', + 'hop', + 'wheat', + 'soy', + 'barley', + 'oats', + 'rye', + 'cassava', + 'potato', + 'sunflower', + 'asparagus', + 'coffee', + ], + 'colors': [ + [0, 0, 0], # unknown + [0, 0, 255], # (blue) rice + [255, 0, 0], # (red) grape + [255, 255, 0], # (yellow) corn + [0, 255, 0], # (green) sugarcane + [128, 0, 128], # (purple) tea + [255, 0, 255], # (pink) hop + [0, 128, 0], # (dark green) wheat + [255, 255, 255], # (white) soy + [128, 128, 128], # (grey) barley + [165, 42, 42], # (brown) oats + [0, 255, 255], # (cyan) rye + [128, 0, 0], # (maroon) cassava + [173, 216, 230], # (light blue) potato + [128, 128, 0], # (olive) sunflower + [0, 128, 0], # (dark green) asparagus + [92, 64, 51], # (dark brown) coffee + ], + }, + 'point': { + 'type': 'detect', + 'categories': [ + 'background', + 'wind_turbine', + 'lighthouse', + 'mineshaft', + 'aerialway_pylon', + 'helipad', + 'fountain', + 'toll_booth', + 'chimney', + 'communications_tower', + 'flagpole', + 'petroleum_well', + 'water_tower', + 'offshore_wind_turbine', + 'offshore_platform', + 'power_tower', + ], + 'colors': [ + [0, 0, 0], + [0, 255, 255], # (cyan) wind_turbine + [0, 255, 0], # (green) lighthouse + [255, 255, 0], # (yellow) mineshaft + [0, 0, 255], # (blue) pylon + [173, 216, 230], # (light blue) helipad + [128, 0, 128], # (purple) fountain + [255, 255, 255], # (white) toll_booth + [0, 128, 0], # (dark green) chimney + [128, 128, 128], # (grey) communications_tower + [165, 42, 42], # (brown) flagpole + [128, 0, 0], # (maroon) petroleum_well + [255, 165, 0], # (orange) water_tower + [255, 255, 0], # (yellow) offshore_wind_turbine + [255, 0, 0], # (red) offshore_platform + [255, 0, 255], # (magenta) power_tower + ], + }, + 'rooftop_solar_panel': { + 'type': 'detect', + 'categories': ['background', 'rooftop_solar_panel'], + 'colors': [ + [0, 0, 0], + [255, 255, 0], # (yellow) rooftop_solar_panel + ], + }, + 'building': { + 'type': 'instance', + 'categories': ['background', 'ms_building'], + 'colors': [ + [0, 0, 0], + [255, 255, 0], # (yellow) building + ], + }, + 'polygon': { + 'type': 'instance', + 'categories': [ + 'background', + 'aquafarm', + 'lock', + 'dam', + 'solar_farm', + 'power_plant', + 'gas_station', + 'park', + 'parking_garage', + 'parking_lot', + 'landfill', + 'quarry', + 'stadium', + 'airport', + 'airport_apron', + 'airport_hangar', + 'airport_terminal', + 'ski_resort', + 'theme_park', + 'storage_tank', + 'silo', + 'track', + 'wastewater_plant', + 'power_substation', + 'pier', + 'crop', + 'water_park', + ], + 'colors': [ + [0, 0, 0], + [255, 255, 0], # (yellow) aquafarm + [0, 255, 255], # (cyan) lock + [0, 255, 0], # (green) dam + [0, 0, 255], # (blue) solar_farm + [255, 0, 0], # (red) power_plant + [128, 0, 128], # (purple) gas_station + [255, 255, 255], # (white) park + [0, 128, 0], # (dark green) parking_garage + [128, 128, 128], # (grey) parking_lot + [165, 42, 42], # (brown) landfill + [128, 0, 0], # (maroon) quarry + [255, 165, 0], # (orange) stadium + [255, 105, 180], # (pink) airport + [192, 192, 192], # (silver) airport_apron + [173, 216, 230], # (light blue) airport_hangar + [32, 178, 170], # (light sea green) airport_terminal + [255, 0, 255], # (magenta) ski_resort + [128, 128, 0], # (olive) theme_park + [47, 79, 79], # (dark slate gray) storage_tank + [255, 215, 0], # (gold) silo + [192, 192, 192], # (light grey) track + [240, 230, 140], # (khaki) wastewater_plant + [154, 205, 50], # (yellow green) power_substation + [255, 165, 0], # (orange) pier + [0, 192, 0], # (middle green) crop + [0, 192, 0], # (middle green) water_park + ], + }, + 'wildfire': { + 'type': 'bin_segment', + 'categories': ['fire_retardant', 'burned'], + 'colors': [ + [255, 0, 0], # (red) fire retardant + [128, 128, 128], # (grey) burned area + ], + }, + 'smoke': {'type': 'classification', 'categories': ['no', 'partial', 'yes']}, + 'snow': {'type': 'classification', 'categories': ['no', 'partial', 'yes']}, + 'dem': {'type': 'regress', 'BackgroundInvalid': True}, + 'airplane': { + 'type': 'detect', + 'categories': ['background', 'airplane'], + 'colors': [ + [0, 0, 0], # (black) background + [255, 0, 0], # (red) airplane + ], + }, + 'vessel': { + 'type': 'detect', + 'categories': ['background', 'vessel'], + 'colors': [ + [0, 0, 0], # (black) background + [255, 0, 0], # (red) vessel + ], + }, + 'water_event': { + 'type': 'segment', + 'BackgroundInvalid': True, + 'categories': ['invalid', 'background', 'water_event'], + 'colors': [ + [0, 0, 0], # (black) invalid + [0, 255, 0], # (green) background + [0, 0, 255], # (blue) water_event + ], + }, + 'park_sport': { + 'type': 'classification', + 'categories': [ + 'american_football', + 'badminton', + 'baseball', + 'basketball', + 'cricket', + 'rugby', + 'soccer', + 'tennis', + 'volleyball', + ], + }, + 'park_type': { + 'type': 'classification', + 'categories': ['park', 'pitch', 'golf_course', 'cemetery'], + }, + 'power_plant_type': { + 'type': 'classification', + 'categories': ['oil', 'nuclear', 'coal', 'gas'], + }, + 'quarry_resource': { + 'type': 'classification', + 'categories': ['sand', 'gravel', 'clay', 'coal', 'peat'], + }, + 'track_sport': { + 'type': 'classification', + 'categories': ['running', 'cycling', 'horse'], + }, + 'road_type': { + 'type': 'classification', + 'categories': [ + 'motorway', + 'trunk', + 'primary', + 'secondary', + 'tertiary', + 'residential', + 'service', + 'track', + 'pedestrian', + ], + }, + 'cloud': { + 'type': 'bin_segment', + 'categories': ['background', 'cloud', 'shadow'], + 'colors': [ + [0, 255, 0], # (green) not clouds or shadows + [255, 255, 255], # (white) clouds + [128, 128, 128], # (grey) shadows + ], + 'BackgroundInvalid': True, + }, + 'flood': { + 'type': 'bin_segment', + 'categories': ['background', 'water'], + 'colors': [ + [0, 255, 0], # (green) background + [0, 0, 255], # (blue) water + ], + 'BackgroundInvalid': True, + }, +} + + +class SatlasPretrain(NonGeoDataset): + """SatlasPretrain dataset. + + `SatlasPretrain `_ is a large-scale pre-training + dataset for tasks that involve understanding satellite images. Regularly-updated + satellite data is publicly available for much of the Earth through sources such as + Sentinel-2 and NAIP, and can inform numerous applications from tackling illegal + deforestation to monitoring marine infrastructure. However, developing automatic + computer vision systems to parse these images requires a huge amount of manual + labeling of training data. By combining over 30 TB of satellite images with 137 + label categories, SatlasPretrain serves as an effective pre-training dataset that + greatly reduces the effort needed to develop robust models for downstream satellite + image applications. + + Reference implementation: + + * https://github.com/allenai/satlas/blob/main/satlas/model/dataset.py + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.48550/arXiv.2211.15660 + + .. versionadded:: 0.7 + + .. note:: + This dataset requires the following additional library to be installed: + + * `AWS CLI `_: to download the dataset from AWS. + """ + + # https://github.com/allenai/satlas/blob/main/satlaspretrain_urls.txt + url = 's3://ai2-public-datasets/satlas/' + tarballs: ClassVar[dict[str, tuple[str, ...]]] = { + 'landsat': ('satlas-dataset-v1-landsat.tar',), + 'naip': ( + 'satlas-dataset-v1-naip-2011.tar', + 'satlas-dataset-v1-naip-2012.tar', + 'satlas-dataset-v1-naip-2013.tar', + 'satlas-dataset-v1-naip-2014.tar', + 'satlas-dataset-v1-naip-2015.tar', + 'satlas-dataset-v1-naip-2016.tar', + 'satlas-dataset-v1-naip-2017.tar', + 'satlas-dataset-v1-naip-2018.tar', + 'satlas-dataset-v1-naip-2019.tar', + 'satlas-dataset-v1-naip-2020.tar', + ), + 'sentinel1': ('satlas-dataset-v1-sentinel1-new.tar',), + 'sentinel2': ( + 'satlas-dataset-v1-sentinel2-a.tar', + 'satlas-dataset-v1-sentinel2-b.tar', + ), + 'static': ('satlas-dataset-v1-labels-static.tar',), + 'dynamic': ('satlas-dataset-v1-labels-dynamic.tar',), + 'metadata': ('satlas-dataset-v1-metadata.tar',), + } + md5s: ClassVar[dict[str, tuple[str, ...]]] = { + 'landsat': ('89ea5e8974826c071908392827780a06',), + 'naip': ( + '523736842994861054f04b97c4d90bfb', + '636b9a3b08be0e40d098cb7b5e655b57', + '69e2b1052b1d2d465322a24cf7207a16', + '38999aea424d403ad60e1398443636aa', + '97f4855072a8a406a4bfbe94c5f7311c', + '9ba3c626b23e6d26749a323eaedc7c0a', + 'e4aba3d198dedfe1524a9338e85794aa', + '74191a36d841b0b9b5d5cbae9a92ad71', + '55b110cc6f734bf88793306d49f1c415', + '97fc8414334987c59593d574f112a77e', + ), + 'sentinel1': ('3d88a0a10df6ab0aa50db2ba4c475048',), + 'sentinel2': ( + '7e1c6a1e322807fb11df8c0c062545ca', + '6636b8ecf2fff1d6723ecfef55a4876d', + ), + 'static': ('4e38c2573bc78cf1f0d7267e432cb42c',), + 'dynamic': ('4503ae687948e7d2cb7ade0083f77a8a',), + 'metadata': ('6b9ac5a4f9a1ee88a271d28f12854607',), + } + + # NOTE: 'tci' is RGB (b04-b02), not BGR (b02-b04) + bands: ClassVar[dict[str, tuple[str, ...]]] = { + 'landsat': tuple(f'b{i}' for i in range(1, 12)), + 'naip': ('tci', 'ir'), + 'sentinel1': ('vh', 'vv'), + 'sentinel2': ('tci', 'b05', 'b06', 'b07', 'b08', 'b11', 'b12'), + } + + chip_size = 512 + + def __init__( + self, + root: Path = 'data', + split: str = 'train_lowres', + good_images: str = 'good_images_lowres_all', + image_times: str = 'image_times', + images: Iterable[str] = ('sentinel1', 'sentinel2', 'landsat'), + labels: Iterable[str] = ('land_cover',), + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new SatlasPretrain instance. + + Args: + root: Root directory where dataset can be found. + split: Metadata split to load. + good_images: Metadata mapping between col/row and directory. + image_times: Metadata mapping between directory and ISO time. + images: List of image products. + labels: List of label products. + transforms: A function/transform that takes input sample and its target as + entry and returns a transformed version. + download: If True, download dataset and store it in the root directory. + checksum: If True, check the MD5 of the downloaded files (may be slow). + + Raises: + AssertionError: If *images* is invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert set(images) <= set(self.bands.keys()) + + self.root = root + self.images = images + self.labels = labels + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + # Read metadata files + self.split = pd.read_json( + os.path.join(root, 'metadata', f'{split}.json'), typ='frame' + ) + self.good_images = pd.read_json( + os.path.join(root, 'metadata', f'{good_images}.json'), typ='frame' + ) + self.image_times = pd.read_json( + os.path.join(root, 'metadata', f'{image_times}.json'), typ='series' + ) + + self.split.columns = ['col', 'row'] + self.good_images.columns = ['col', 'row', 'directory'] + self.good_images = self.good_images.groupby(['col', 'row']) + + def __len__(self) -> int: + """Return the number of locations in the dataset. + + Returns: + Length of the dataset + """ + return len(self.split) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: Index to return. + + Returns: + Data and label at that index. + """ + col, row = self.split.iloc[index] + directories = self.good_images.get_group((col, row))['directory'] + + sample: dict[str, Tensor] = {} + + for image in self.images: + self._load_image(sample, image, col, row, directories) + + for label in self.labels: + self._load_label(sample, label, col, row) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _load_image( + self, + sample: dict[str, Tensor], + image: str, + col: int, + row: int, + directories: pd.Series, + ) -> None: + """Load a single image. + + Args: + sample: Dataset sample to populate. + image: Image product. + col: Web Mercator column. + row: Web Mercator row. + directories: Directories that may contain the image. + """ + # Moved in PIL 9.1.0 + try: + resample = Image.Resampling.BILINEAR + except AttributeError: + resample = Image.BILINEAR # type: ignore[attr-defined] + + # Find directories that match image product + good_directories: list[str] = [] + for directory in directories: + path = os.path.join(self.root, image, directory) + if os.path.isdir(path): + good_directories.append(directory) + + # Choose a random timestamp + idx = torch.randint(len(good_directories), (1,)) + directory = good_directories[idx] + time = self.image_times[directory].timestamp() + sample[f'time_{image}'] = torch.tensor(time) + + # Load all bands + channels = [] + for band in self.bands[image]: + path = os.path.join(self.root, image, directory, band, f'{col}_{row}.png') + with Image.open(path) as img: + img = img.resize((self.chip_size, self.chip_size), resample=resample) + array = np.atleast_3d(np.array(img, dtype=np.float32)) + channels.append(torch.tensor(array)) + raster = rearrange(torch.cat(channels, dim=-1), 'h w c -> c h w') + sample[f'image_{image}'] = raster + + def _load_label( + self, sample: dict[str, Tensor], label: str, col: int, row: int + ) -> None: + """Load a single label. + + Args: + sample: Dataset sample to populate. + label: Label product. + col: Web Mercator column. + row: Web Mercator row. + """ + path = os.path.join(self.root, 'static', f'{col}_{row}', f'{label}.png') + if os.path.isfile(path): + with Image.open(path) as img: + raster = torch.tensor(np.array(img, dtype=np.int64)) + else: + raster = torch.zeros(self.chip_size, self.chip_size, dtype=torch.long) + sample[f'mask_{label}'] = raster + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + products = [*self.images, 'metadata'] + if self.labels: + products.append('static') + + for product in products: + # Check if the extracted directory already exists + if os.path.isdir(os.path.join(self.root, product)): + continue + + tarballs = self.tarballs[product] + md5s = self.md5s[product] + for tarball, md5 in zip(tarballs, md5s): + path = os.path.join(self.root, tarball) + + # Check if the tarball has already been downloaded + if os.path.isfile(path): + extract_archive(path) + continue + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download and extract the tarball + aws = which('aws') + aws('s3', 'cp', self.url + tarball, self.root) + check_integrity(path, md5 if self.checksum else None) + extract_archive(path) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by :meth:`__getitem__`. + show_titles: Flag indicating whether to show titles above each panel. + suptitle: Optional string to use as a suptitle. + + Returns: + A matplotlib Figure with the rendered sample. + """ + images = [] + titles = [] + for key, value in sample.items(): + match key.split('_', 1): + case ['image', 'landsat']: + images.append(rearrange(value[[3, 2, 1]], 'c h w -> h w c') / 255) + titles.append('Landsat 8/9') + case ['image', 'naip']: + images.append(rearrange(value[:3], 'c h w -> h w c') / 255) + titles.append('NAIP') + case ['image', 'sentinel1']: + images.extend([value[0] / 255, value[1] / 255]) + titles.extend(['Sentinel-1 VH', 'Sentinel-1 VV']) + case ['image', 'sentinel2']: + images.append(rearrange(value[:3], 'c h w -> h w c') / 255) + titles.append('Sentinel-2') + case ['mask' | 'prediction', label]: + cmap = torch.tensor(TASKS[label]['colors']) + images.append(cmap[value]) + titles.append(label.replace('_', ' ').capitalize()) + + fig, ax = plt.subplots(ncols=len(images), squeeze=False) + for i, (image, title) in enumerate(zip(images, titles)): + ax[0, i].imshow(image) + ax[0, i].axis('off') + + if show_titles: + ax[0, i].set_title(title) + + if suptitle is not None: + fig.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/seasonet.py b/torchgeo/datasets/seasonet.py index 1bd59487af3..3e47a8ec491 100644 --- a/torchgeo/datasets/seasonet.py +++ b/torchgeo/datasets/seasonet.py @@ -6,6 +6,7 @@ import os import random from collections.abc import Callable, Collection, Iterable +from typing import ClassVar import matplotlib.patches as mpatches import matplotlib.pyplot as plt @@ -20,7 +21,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, percentile_normalization +from .utils import Path, download_url, extract_archive, percentile_normalization class SeasoNet(NonGeoDataset): @@ -38,7 +39,7 @@ class SeasoNet(NonGeoDataset): Dataset format: - * images are 16-bit GeoTiffs, split into seperate files based on resolution + * images are 16-bit GeoTiffs, split into separate files based on resolution * images include 12 spectral bands with 10, 20 and 60 m per pixel resolutions * masks are single-channel 8-bit GeoTiffs @@ -85,51 +86,51 @@ class SeasoNet(NonGeoDataset): .. versionadded:: 0.5 """ - metadata = [ + metadata = ( { 'name': 'spring', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/spring.zip', 'md5': 'de4cdba7b6196aff624073991b187561', }, { 'name': 'summer', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/summer.zip', 'md5': '6a54d4e134d27ae4eb03f180ee100550', }, { 'name': 'fall', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/fall.zip', 'md5': '5f94920fe41a63c6bfbab7295f7d6b95', }, { 'name': 'winter', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/winter.zip', 'md5': 'dc5e3e09e52ab5c72421b1e3186c9a48', }, { 'name': 'snow', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/snow.zip', 'md5': 'e1b300994143f99ebb03f51d6ab1cbe6', }, { 'name': 'splits', 'ext': '.zip', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/splits.zip', 'md5': 'e4ec4a18bc4efc828f0944a7cf4d5fed', }, { 'name': 'meta.csv', 'ext': '', - 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', # noqa: E501 + 'url': 'https://zenodo.org/api/files/e2288446-9ee8-4b2e-ae76-cd80366a40e1/meta.csv', 'md5': '43ea07974936a6bf47d989c32e16afe7', }, - ] - classes = [ + ) + classes = ( 'Continuous urban fabric', 'Discontinuous urban fabric', 'Industrial or commercial units', @@ -163,12 +164,17 @@ class SeasoNet(NonGeoDataset): 'Coastal lagoons', 'Estuaries', 'Sea and ocean', - ] - all_seasons = {'Spring', 'Summer', 'Fall', 'Winter', 'Snow'} + ) + all_seasons = frozenset({'Spring', 'Summer', 'Fall', 'Winter', 'Snow'}) all_bands = ('10m_RGB', '10m_IR', '20m', '60m') - band_nums = {'10m_RGB': 3, '10m_IR': 1, '20m': 6, '60m': 2} - splits = ['train', 'val', 'test'] - cmap = { + band_nums: ClassVar[dict[str, int]] = { + '10m_RGB': 3, + '10m_IR': 1, + '20m': 6, + '60m': 2, + } + splits = ('train', 'val', 'test') + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (230, 000, 77, 255), 1: (255, 000, 000, 255), 2: (204, 77, 242, 255), @@ -207,7 +213,7 @@ class SeasoNet(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', seasons: Collection[str] = all_seasons, bands: Iterable[str] = all_bands, @@ -331,7 +337,7 @@ def _load_image(self, index: int) -> Tensor: for band in self.bands: with rasterio.open(f'{path}_{band}.tif') as f: array = f.read( - out_shape=[f.count] + list(self.image_size), + out_shape=[f.count, *list(self.image_size)], out_dtype='int32', resampling=Resampling.bilinear, ) diff --git a/torchgeo/datasets/seco.py b/torchgeo/datasets/seco.py index 2aa23f3ba75..c67fecb9c8e 100644 --- a/torchgeo/datasets/seco.py +++ b/torchgeo/datasets/seco.py @@ -5,7 +5,8 @@ import os import random -from collections.abc import Callable +from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -17,7 +18,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, percentile_normalization +from .utils import Path, download_url, extract_archive, percentile_normalization class SeasonalContrastS2(NonGeoDataset): @@ -34,10 +35,10 @@ class SeasonalContrastS2(NonGeoDataset): If you use this dataset in your research, please cite the following paper: - * https://arxiv.org/pdf/2103.16607.pdf + * https://arxiv.org/pdf/2103.16607 """ - all_bands = [ + all_bands = ( 'B1', 'B2', 'B3', @@ -50,18 +51,18 @@ class SeasonalContrastS2(NonGeoDataset): 'B9', 'B11', 'B12', - ] - rgb_bands = ['B4', 'B3', 'B2'] + ) + rgb_bands = ('B4', 'B3', 'B2') - metadata = { + metadata: ClassVar[dict[str, dict[str, str]]] = { '100k': { - 'url': 'https://zenodo.org/record/4728033/files/seco_100k.zip?download=1', + 'url': 'https://zenodo.org/records/4728033/files/seco_100k.zip?download=1', 'md5': 'ebf2d5e03adc6e657f9a69a20ad863e0', 'filename': 'seco_100k.zip', 'directory': 'seasonal_contrast_100k', }, '1m': { - 'url': 'https://zenodo.org/record/4728033/files/seco_1m.zip?download=1', + 'url': 'https://zenodo.org/records/4728033/files/seco_1m.zip?download=1', 'md5': '187963d852d4d3ce6637743ec3a4bd9e', 'filename': 'seco_1m.zip', 'directory': 'seasonal_contrast_1m', @@ -70,10 +71,10 @@ class SeasonalContrastS2(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', version: str = '100k', seasons: int = 1, - bands: list[str] = rgb_bands, + bands: Sequence[str] = rgb_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, @@ -147,7 +148,7 @@ def __len__(self) -> int: """ return (10**5 if self.version == '100k' else 10**6) // 5 - def _load_patch(self, root: str, subdir: str) -> Tensor: + def _load_patch(self, root: Path, subdir: Path) -> Tensor: """Load a single image patch. Args: @@ -169,7 +170,7 @@ def _load_patch(self, root: str, subdir: str) -> Tensor: # what could be sped up throughout later. There is also a potential # slowdown here from converting to/from a PIL Image just to resize. # https://gist.github.com/calebrob6/748045ac8d844154067b2eefa47de92f - pil_image = Image.fromarray(band_data) # type: ignore[no-untyped-call] + pil_image = Image.fromarray(band_data) # Moved in PIL 9.1.0 try: resample = Image.Resampling.BILINEAR diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 07f49f964c1..20183f06421 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -15,7 +16,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, percentile_normalization +from .utils import Path, check_integrity, percentile_normalization class SEN12MS(NonGeoDataset): @@ -63,9 +64,9 @@ class SEN12MS(NonGeoDataset): or manually downloaded from https://dataserv.ub.tum.de/s/m1474000 and https://github.com/schmitt-muc/SEN12MS/tree/master/splits. This download will likely take several hours. - """ # noqa: E501 + """ - BAND_SETS: dict[str, tuple[str, ...]] = { + BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = { 'all': ( 'VV', 'VH', @@ -120,9 +121,9 @@ class SEN12MS(NonGeoDataset): 'B12', ) - rgb_bands = ['B04', 'B03', 'B02'] + rgb_bands = ('B04', 'B03', 'B02') - filenames = [ + filenames = ( 'ROIs1158_spring_lc.tar.gz', 'ROIs1158_spring_s1.tar.gz', 'ROIs1158_spring_s2.tar.gz', @@ -137,16 +138,16 @@ class SEN12MS(NonGeoDataset): 'ROIs2017_winter_s2.tar.gz', 'train_list.txt', 'test_list.txt', - ] - light_filenames = [ + ) + light_filenames = ( 'ROIs1158_spring', 'ROIs1868_summer', 'ROIs1970_fall', 'ROIs2017_winter', 'train_list.txt', 'test_list.txt', - ] - md5s = [ + ) + md5s = ( '6e2e8fa8b8cba77ddab49fd20ff5c37b', 'fba019bb27a08c1db96b31f718c34d79', 'd58af2c15a16f376eb3308dc9b685af2', @@ -161,11 +162,11 @@ class SEN12MS(NonGeoDataset): '3807545661288dcca312c9c538537b63', '0a68d4e1eb24f128fccdb930000b2546', 'c7faad064001e646445c4c634169484d', - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 2d4dedb50fb..79637931adb 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -13,6 +13,7 @@ from .errors import RGBBandsMissingError from .geo import RasterDataset +from .utils import Path class Sentinel(RasterDataset): @@ -24,7 +25,7 @@ class Sentinel(RasterDataset): If you use this dataset in your research, please cite it using the following format: - * https://asf.alaska.edu/data-sets/sar-data-sets/sentinel-1/sentinel-1-how-to-cite/ + * https://asf.alaska.edu/datasets/daac/sentinel-1/ """ @@ -32,7 +33,7 @@ class Sentinel1(Sentinel): r"""Sentinel-1 dataset. The `Sentinel-1 mission - `_ comprises a + `_ comprises a constellation of two polar-orbiting satellites, operating day and night performing C-band synthetic aperture radar imaging, enabling them to acquire imagery regardless of the weather. @@ -49,16 +50,16 @@ class Sentinel1(Sentinel): Product Types: * `Level-0 - `_: + `_: Raw (RAW) * `Level-1 - `_: + `_: Single Look Complex (SLC) * `Level-1 - `_: + `_: Ground Range Detected (GRD) * `Level-2 - `_: + `_: Ocean (OCN) Polarizations: @@ -71,13 +72,13 @@ class Sentinel1(Sentinel): Acquisition Modes: * `Stripmap (SM) - `_ + `_ * `Interferometric Wide (IW) swath - `_ + `_ * `Extra Wide (EW) swatch - `_ + `_ * `Wave (WV) - `_ + `_ .. note:: At the moment, this dataset only supports the GRD product type. Data must be @@ -136,12 +137,12 @@ class Sentinel1(Sentinel): \. """ date_format = '%Y%m%dT%H%M%S' - all_bands = ['HH', 'HV', 'VV', 'VH'] + all_bands = ('HH', 'HV', 'VV', 'VH') separate_files = True def __init__( self, - paths: str | list[str] = 'data', + paths: Path | list[Path] = 'data', crs: CRS | None = None, res: float = 10, bands: Sequence[str] = ['VV', 'VH'], @@ -254,7 +255,7 @@ class Sentinel2(Sentinel): """Sentinel-2 dataset. The `Copernicus Sentinel-2 mission - `_ comprises a + `_ comprises a constellation of two polar-orbiting satellites placed in the same sun-synchronous orbit, phased at 180° to each other. It aims at monitoring variability in land surface conditions, and its wide swath width (290 km) and high revisit time (10 days @@ -276,7 +277,7 @@ class Sentinel2(Sentinel): date_format = '%Y%m%dT%H%M%S' # https://gisgeography.com/sentinel-2-bands-combinations/ - all_bands = [ + all_bands: tuple[str, ...] = ( 'B01', 'B02', 'B03', @@ -290,14 +291,14 @@ class Sentinel2(Sentinel): 'B10', 'B11', 'B12', - ] - rgb_bands = ['B04', 'B03', 'B02'] + ) + rgb_bands = ('B04', 'B03', 'B02') separate_files = True def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float = 10, bands: Sequence[str] | None = None, diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 0d111ae15b9..3accf32d2af 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import numpy as np @@ -16,7 +16,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive, lazy_import +from .utils import Path, download_url, extract_archive, lazy_import class SKIPPD(NonGeoDataset): @@ -62,8 +62,8 @@ class SKIPPD(NonGeoDataset): .. versionadded:: 0.5 """ - url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' # noqa: E501 - md5 = { + url = 'https://hf.co/datasets/torchgeo/skippd/resolve/a16c7e200b4618cd93be3143cdb973e3f21498fa/{}' + md5: ClassVar[dict[str, str]] = { 'forecast': 'f4f3509ddcc83a55c433be9db2e51077', 'nowcast': '0000761d403e45bb5f86c21d3c69aa80', } @@ -71,15 +71,15 @@ class SKIPPD(NonGeoDataset): data_file_name = '2017_2019_images_pv_processed_{}.hdf5' zipfile_name = '2017_2019_images_pv_processed_{}.zip' - valid_splits = ['trainval', 'test'] + valid_splits = ('trainval', 'test') - valid_tasks = ['nowcast', 'forecast'] + valid_tasks = ('nowcast', 'forecast') dateformat = '%m/%d/%Y, %H:%M:%S' def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'trainval', task: str = 'nowcast', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, @@ -91,7 +91,7 @@ def __init__( Args: root: root directory where dataset can be found split: one of "trainval", or "test" - task: one fo "nowcast", or "forecast" + task: one of "nowcast", or "forecast" transforms: a function/transform that takes an input sample and returns a transformed version download: if True, download dataset and store it in the root directory diff --git a/torchgeo/datasets/skyscript.py b/torchgeo/datasets/skyscript.py new file mode 100644 index 00000000000..2c5c4e0fa98 --- /dev/null +++ b/torchgeo/datasets/skyscript.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""SkyScript dataset.""" + +import os +from collections.abc import Callable +from typing import Any, ClassVar + +import numpy as np +import pandas as pd +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_and_extract_archive, download_url, extract_archive + + +class SkyScript(NonGeoDataset): + """SkyScript dataset. + + `SkyScript `__ is a large and + semantically diverse image-text dataset for remote sensing images. It contains + 5.2 million remote sensing image-text pairs in total, covering more than 29K + distinct semantic tags. + + If you use this dataset in your research, please cite it using the following format: + + * https://arxiv.org/abs/2312.12856 + + .. versionadded:: 0.6 + """ + + url = 'https://opendatasharing.s3.us-west-2.amazonaws.com/SkyScript/{}' + + image_dirs = tuple(f'images{i}' for i in range(2, 8)) + image_md5s = ( + 'fbfb5f7aa1731f4106fc3ffbd608100a', + 'ad4fd9fdb9622d1ea360210cb222f2bd', + 'aeeb41e830304c74b14b5ffc1fc8e8c3', + '02ee7e0e59f9ac1c87b678a155e1f1df', + '350475f1e7fb996152fa16db891b4142', + '5e2fbf3e9262b36e30b458ec9a1df625', + ) + + #: Can be modified in subclasses to change train/val/test split + caption_files: ClassVar[dict[str, str]] = { + 'train': 'SkyScript_train_top30pct_filtered_by_CLIP_openai.csv', + 'val': 'SkyScript_val_5K_filtered_by_CLIP_openai.csv', + 'test': 'SkyScript_test_30K_filtered_by_CLIP_openai.csv', + } + caption_md5s: ClassVar[dict[str, str]] = { + 'train': '05b362e43a852667b5374c9a5ae53f8e', + 'val': 'c8d278fd29b754361989d5e7a6608f69', + 'test': '0135d9b49ce6751360912a4353e809dc', + } + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new SkyScript instance. + + Args: + root: Root directory where dataset can be found. + split: One of 'train', 'val', 'test'. + transforms: A function/transform that takes input sample and its target as + entry and returns a transformed version. + download: If True, download dataset and store it in the root directory. + checksum: If True, check the MD5 of the downloaded files (may be slow). + + Raises: + AssertionError: If *split* is invalid. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert split in self.caption_files + + self.root = root + self.split = split + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + self.captions = pd.read_csv(os.path.join(self.root, self.caption_files[split])) + + def __len__(self) -> int: + """Return the number of images in the dataset. + + Returns: + Length of the dataset. + """ + return len(self.captions) + + def __getitem__(self, index: int) -> dict[str, Any]: + """Return an index within the dataset. + + Args: + index: Index to return. + + Returns: + A dict containing image and caption at index. + """ + filepath, title = self.captions.iloc[index][:2] + + with Image.open(os.path.join(self.root, filepath)) as img: + array = np.array(img, dtype=np.float32) + array = rearrange(array, 'h w c -> c h w') + image = torch.from_numpy(array) + + sample = {'image': image, 'caption': title} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + md5: str | None + for directory, md5 in zip(self.image_dirs, self.image_md5s): + # Check if the extracted files already exist + if os.path.isdir(os.path.join(self.root, directory)): + continue + + # Check if the zip files have already been downloaded + if os.path.isfile(os.path.join(self.root, f'{directory}.zip')): + extract_archive(os.path.join(self.root, f'{directory}.zip')) + continue + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + url = self.url.format(f'{directory}.zip') + md5 = md5 if self.checksum else None + download_and_extract_archive(url, self.root, md5=md5) + + # Download the caption file + if self.download: + url = self.url.format(self.caption_files[self.split]) + md5 = self.caption_md5s[self.split] if self.checksum else None + download_url(url, self.root, md5=md5) + + def plot( + self, + sample: dict[str, Any], + show_titles: bool = True, + suptitle: str | None = None, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`RasterDataset.__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + fig, ax = plt.subplots() + + image = rearrange(sample['image'], 'c h w -> h w c') / 255 + ax.imshow(image) + ax.axis('off') + + if show_titles: + title = sample['caption'] + if 'prediction' in sample: + title += '\n' + sample['prediction'] + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 3003031399b..4840a48e468 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable, Sequence -from typing import cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import check_integrity, lazy_import, percentile_normalization +from .utils import Path, check_integrity, lazy_import, percentile_normalization class So2Sat(NonGeoDataset): @@ -103,10 +103,10 @@ class So2Sat(NonGeoDataset): This dataset requires the following additional library to be installed: * ``_ to load the dataset - """ # noqa: E501 + """ - versions = ['2', '3_random', '3_block', '3_culture_10'] - filenames_by_version = { + versions = ('2', '3_random', '3_block', '3_culture_10') + filenames_by_version: ClassVar[dict[str, dict[str, str]]] = { '2': { 'train': 'training.h5', 'validation': 'validation.h5', @@ -119,7 +119,7 @@ class So2Sat(NonGeoDataset): 'test': 'culture_10/testing.h5', }, } - md5s_by_version = { + md5s_by_version: ClassVar[dict[str, dict[str, str]]] = { '2': { 'train': '702bc6a9368ebff4542d791e53469244', 'validation': '71cfa6795de3e22207229d06d6f8775d', @@ -139,7 +139,7 @@ class So2Sat(NonGeoDataset): }, } - classes = [ + classes = ( 'Compact high rise', 'Compact mid rise', 'Compact low rise', @@ -157,7 +157,7 @@ class So2Sat(NonGeoDataset): 'Bare rock or paved', 'Bare soil or sand', 'Water', - ] + ) all_s1_band_names = ( 'S1_B1', @@ -183,9 +183,9 @@ class So2Sat(NonGeoDataset): ) all_band_names = all_s1_band_names + all_s2_band_names - rgb_bands = ['S2_B04', 'S2_B03', 'S2_B02'] + rgb_bands = ('S2_B04', 'S2_B03', 'S2_B02') - BAND_SETS = { + BAND_SETS: ClassVar[dict[str, tuple[str, ...]]] = { 'all': all_band_names, 's1': all_s1_band_names, 's2': all_s2_band_names, @@ -194,7 +194,7 @@ class So2Sat(NonGeoDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', version: str = '2', split: str = 'train', bands: Sequence[str] = BAND_SETS['all'], diff --git a/torchgeo/datasets/south_africa_crop_type.py b/torchgeo/datasets/south_africa_crop_type.py index 48dcbd8529a..a8643873c5b 100644 --- a/torchgeo/datasets/south_africa_crop_type.py +++ b/torchgeo/datasets/south_africa_crop_type.py @@ -5,8 +5,8 @@ import os import re -from collections.abc import Callable, Iterable -from typing import Any, cast +from collections.abc import Callable, Iterable, Sequence +from typing import Any, ClassVar, cast import matplotlib.pyplot as plt import torch @@ -14,9 +14,9 @@ from rasterio.crs import CRS from torch import Tensor -from .errors import RGBBandsMissingError +from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import RasterDataset -from .utils import BoundingBox +from .utils import BoundingBox, Path, which class SouthAfricaCropType(RasterDataset): @@ -59,9 +59,17 @@ class SouthAfricaCropType(RasterDataset): "Crop Type Classification Dataset for Western Cape, South Africa", Version 1.0, Radiant MLHub, https://doi.org/10.34911/rdnt.j0co8q + .. note:: + This dataset requires the following additional library to be installed: + + * `azcopy `_: to download the + dataset from Source Cooperative. + .. versionadded:: 0.6 """ + url = 'https://radiantearth.blob.core.windows.net/mlhub/ref-south-africa-crops-competition-v1' + filename_glob = '*_07_*_{}_10m.*' filename_regex = r""" ^(?P\d+) @@ -70,9 +78,9 @@ class SouthAfricaCropType(RasterDataset): _10m """ date_format = '%Y_%m_%d' - rgb_bands = ['B04', 'B03', 'B02'] - s1_bands = ['VH', 'VV'] - s2_bands = [ + rgb_bands = ('B04', 'B03', 'B02') + s1_bands = ('VH', 'VV') + s2_bands = ( 'B01', 'B02', 'B03', @@ -85,9 +93,9 @@ class SouthAfricaCropType(RasterDataset): 'B09', 'B11', 'B12', - ] - all_bands: list[str] = s1_bands + s2_bands - cmap = { + ) + all_bands = s1_bands + s2_bands + cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = { 0: (0, 0, 0, 255), 1: (255, 211, 0, 255), 2: (255, 37, 37, 255), @@ -102,11 +110,12 @@ class SouthAfricaCropType(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, - classes: list[int] = list(cmap.keys()), - bands: list[str] = s2_bands, + classes: Sequence[int] = list(cmap.keys()), + bands: Sequence[str] = s2_bands, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, ) -> None: """Initialize a new South Africa Crop Type dataset instance. @@ -117,6 +126,7 @@ def __init__( bands: the subset of bands to load transforms: a function/transform that takes input sample and its target as entry and returns a transformed version + download: if True, download dataset and store it in the root directory Raises: DatasetNotFoundError: If dataset is not found and *download* is False. @@ -127,15 +137,17 @@ def __init__( assert 0 in classes, 'Classes must include the background class: 0' self.paths = paths - self.classes = classes - self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype) - self.ordinal_cmap = torch.zeros((len(self.classes), 4), dtype=torch.uint8) + self.download = download self.filename_glob = self.filename_glob.format(bands[0]) + self._verify() + super().__init__(paths=paths, crs=crs, bands=bands, transforms=transforms) # Map chosen classes to ordinal numbers, all others mapped to background class - for v, k in enumerate(self.classes): + self.ordinal_map = torch.zeros(max(self.cmap.keys()) + 1, dtype=self.dtype) + self.ordinal_cmap = torch.zeros((len(classes), 4), dtype=torch.uint8) + for v, k in enumerate(classes): self.ordinal_map[k] = v self.ordinal_cmap[v] = torch.tensor(self.cmap[k]) @@ -148,7 +160,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: Returns: data and labels at that index """ - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) # Get all files matching the given query hits = self.index.intersection(tuple(query), objects=True) @@ -211,11 +223,11 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: ) mask_filepaths.append(file_path) - mask = self._merge_files(mask_filepaths, query) + mask = self._merge_files(mask_filepaths, query).squeeze(0) sample = { 'crs': self.crs, - 'bbox': query, + 'bounds': query, 'image': image.float(), 'mask': mask.long(), } @@ -225,6 +237,26 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: return sample + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + if self.files: + return + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + assert isinstance(self.paths, str | os.PathLike) + os.makedirs(self.paths, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', f'{self.url}', self.paths, '--recursive=true') + def plot( self, sample: dict[str, Tensor], diff --git a/torchgeo/datasets/south_america_soybean.py b/torchgeo/datasets/south_america_soybean.py index fc28c229370..adbde74d6cb 100644 --- a/torchgeo/datasets/south_america_soybean.py +++ b/torchgeo/datasets/south_america_soybean.py @@ -3,8 +3,9 @@ """South America Soybean Dataset.""" +import os from collections.abc import Callable, Iterable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -12,7 +13,7 @@ from .errors import DatasetNotFoundError from .geo import RasterDataset -from .utils import download_url +from .utils import Path, download_url class SouthAmericaSoybean(RasterDataset): @@ -46,7 +47,7 @@ class SouthAmericaSoybean(RasterDataset): is_image = False url = 'https://glad.umd.edu/projects/AnnualClassMapsV1/SouthAmerica_Soybean_{}.tif' - md5s = { + md5s: ClassVar[dict[int, str]] = { 2021: 'edff3ada13a1a9910d1fe844d28ae4f', 2020: '0709dec807f576c9707c8c7e183db31', 2019: '441836493bbcd5e123cff579a58f5a4f', @@ -72,7 +73,7 @@ class SouthAmericaSoybean(RasterDataset): def __init__( self, - paths: str | Iterable[str] = 'data', + paths: Path | Iterable[Path] = 'data', crs: CRS | None = None, res: float | None = None, years: list[int] = [2021], @@ -112,7 +113,7 @@ def _verify(self) -> None: # Check if the extracted files already exist if self.files: return - assert isinstance(self.paths, str) + assert isinstance(self.paths, str | os.PathLike) # Check if the user requested to download the dataset if not self.download: diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 2a5283dfb5b..21a31657f58 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -3,24 +3,23 @@ """SpaceNet datasets.""" -import abc -import copy import glob -import math import os import re +from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any +from typing import Any, ClassVar import fiona import matplotlib.pyplot as plt import numpy as np import rasterio as rio import torch -from fiona.errors import FionaValueError +from fiona.errors import FionaError, FionaValueError from fiona.transform import transform_geom from matplotlib.figure import Figure from rasterio.crs import CRS +from rasterio.enums import Resampling from rasterio.features import rasterize from rasterio.transform import Affine from torch import Tensor @@ -28,15 +27,15 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, - download_radiant_mlhub_collection, - download_radiant_mlhub_dataset, extract_archive, percentile_normalization, + which, ) -class SpaceNet(NonGeoDataset, abc.ABC): +class SpaceNet(NonGeoDataset, ABC): """Abstract base class for the SpaceNet datasets. The `SpaceNet `__ datasets are a set of @@ -48,103 +47,116 @@ class SpaceNet(NonGeoDataset, abc.ABC): The SpaceNet datasets require the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `AWS CLI `_: to download the dataset from AWS. """ + url = 's3://spacenet-dataset/spacenet/{dataset_id}/tarballs/{tarball}' + directory_glob = os.path.join('**', 'AOI_{aoi}_*', '{product}') + image_glob = '*.tif' + mask_glob = '*.geojson' + file_regex = r'_img(\d+)\.' + chip_size: ClassVar[dict[str, tuple[int, int]]] = {} + + cities: ClassVar[dict[int, str]] = { + 1: 'Rio', + 2: 'Vegas', + 3: 'Paris', + 4: 'Shanghai', + 5: 'Khartoum', + 6: 'Atlanta', + 7: 'Moscow', + 8: 'Mumbai', + 9: 'San Juan', + 10: 'Dar Es Salaam', + 11: 'Rotterdam', + } + @property - @abc.abstractmethod + @abstractmethod def dataset_id(self) -> str: """Dataset ID.""" @property - @abc.abstractmethod - def imagery(self) -> dict[str, str]: - """Mapping of image identifier and filename.""" + @abstractmethod + def tarballs(self) -> dict[str, dict[int, list[str]]]: + """Mapping of tarballs[split][aoi] = [tarballs].""" @property - @abc.abstractmethod - def label_glob(self) -> str: - """Label filename.""" + @abstractmethod + def md5s(self) -> dict[str, dict[int, list[str]]]: + """Mapping of md5s[split][aoi] = [md5s].""" @property - @abc.abstractmethod - def collection_md5_dict(self) -> dict[str, str]: - """Mapping of collection id and md5 checksum.""" + @abstractmethod + def valid_aois(self) -> dict[str, list[int]]: + """Mapping of valid_aois[split] = [aois].""" @property - @abc.abstractmethod - def chip_size(self) -> dict[str, tuple[int, int]]: - """Mapping of images and their chip size.""" + @abstractmethod + def valid_images(self) -> dict[str, list[str]]: + """Mapping of valid_images[split] = [images].""" + + @property + @abstractmethod + def valid_masks(self) -> tuple[str, ...]: + """List of valid masks.""" def __init__( self, - root: str, - image: str, - collections: list[str] = [], + root: Path = 'data', + split: str = 'train', + aois: list[int] = [], + image: str | None = None, + mask: str | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: str | None = None, checksum: bool = False, ) -> None: """Initialize a new SpaceNet Dataset instance. Args: root: root directory where dataset can be found + split: 'train' or 'test' split + aois: areas of interest image: image selection - collections: collection selection + mask: mask selection transforms: a function/transform that takes input sample and its target as entry and returns a transformed version. download: if True, download dataset and store it in the root directory. - api_key: a RadiantEarth MLHub API key to use for downloading the dataset checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: + AssertionError: If any invalid arguments are passed. DatasetNotFoundError: If dataset is not found and *download* is False. """ self.root = root - self.image = image # For testing - - if collections: - for collection in collections: - assert collection in self.collection_md5_dict - - self.collections = collections or list(self.collection_md5_dict.keys()) - self.filename = self.imagery[image] + self.split = split + self.aois = aois or self.valid_aois[split] + self.image = image or self.valid_images[split][0] + self.mask = mask or self.valid_masks[0] self.transforms = transforms + self.download = download self.checksum = checksum - to_be_downloaded = self._check_integrity() + assert self.split in {'train', 'test'} + assert set(self.aois) <= set(self.valid_aois[split]) + assert self.image in self.valid_images[split] + assert self.mask in self.valid_masks - if to_be_downloaded: - if not download: - raise DatasetNotFoundError(self) - else: - self._download(to_be_downloaded, api_key) + self._verify() - self.files = self._load_files(root) - - def _load_files(self, root: str) -> list[dict[str, str]]: - """Return the paths of the files in the dataset. + if self.split == 'train': + assert len(self.images) == len(self.masks) - Args: - root: root dir of dataset + def __len__(self) -> int: + """Return the number of samples in the dataset. Returns: - list of dicts containing paths for each pair of image and label + length of the dataset """ - files = [] - for collection in self.collections: - images = glob.glob(os.path.join(root, collection, '*', self.filename)) - images = sorted(images) - for imgpath in images: - lbl_path = os.path.join( - f'{os.path.dirname(imgpath)}-labels', self.label_glob - ) - files.append({'image_path': imgpath, 'label_path': lbl_path}) - return files + return len(self.images) - def _load_image(self, path: str) -> tuple[Tensor, Affine, CRS]: + def _load_image(self, path: Path) -> tuple[Tensor, Affine, CRS]: """Load a single image. Args: @@ -153,14 +165,16 @@ def _load_image(self, path: str) -> tuple[Tensor, Affine, CRS]: Returns: the image """ - filename = os.path.join(path) - with rio.open(filename) as img: - array = img.read().astype(np.int32) - tensor = torch.from_numpy(array).float() + with rio.open(path) as img: + out_shape = (img.count, img.height, img.width) + if self.image in self.chip_size: + out_shape = (img.count, *self.chip_size[self.image]) + array = img.read(out_shape=out_shape, resampling=Resampling.bilinear) + tensor = torch.from_numpy(array.astype(np.float32)) return tensor, img.transform, img.crs def _load_mask( - self, path: str, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] + self, path: Path, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] ) -> Tensor: """Rasterizes the dataset's labels (in geojson format). @@ -176,43 +190,32 @@ def _load_mask( try: with fiona.open(path) as src: vector_crs = CRS(src.crs) - if raster_crs == vector_crs: - labels = [feature['geometry'] for feature in src] - else: - labels = [ - transform_geom( - vector_crs.to_string(), - raster_crs.to_string(), - feature['geometry'], - ) - for feature in src - ] - except FionaValueError: + labels = [ + transform_geom( + vector_crs.to_string(), + raster_crs.to_string(), + feature['geometry'], + ) + for feature in src + if feature['geometry'] + ] + except (FionaError, FionaValueError): + # Empty geojson files, geometries that cannot be transformed (SN7) labels = [] - if not labels: - mask_data = np.zeros(shape=shape) - else: - mask_data = rasterize( + if labels: + mask = rasterize( labels, out_shape=shape, fill=0, # nodata value transform=tfm, all_touched=False, - dtype=np.uint8, + dtype=np.int64, ) + else: + mask = np.zeros(shape=shape, dtype=np.int64) - mask = torch.from_numpy(mask_data).long() - - return mask - - def __len__(self) -> int: - """Return the number of samples in the dataset. - - Returns: - length of the dataset - """ - return len(self.files) + return torch.from_numpy(mask) def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. @@ -223,79 +226,117 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: Returns: data and label at that index """ - files = self.files[index] - img, tfm, raster_crs = self._load_image(files['image_path']) + image_path = self.images[index] + img, tfm, raster_crs = self._load_image(image_path) h, w = img.shape[1:] - mask = self._load_mask(files['label_path'], tfm, raster_crs, (h, w)) + sample = {'image': img} - ch, cw = self.chip_size[self.image] - sample = {'image': img[:, :ch, :cw], 'mask': mask[:ch, :cw]} + if self.split == 'train': + mask_path = self.masks[index] + mask = self._load_mask(mask_path, tfm, raster_crs, (h, w)) + sample['mask'] = mask if self.transforms is not None: sample = self.transforms(sample) return sample - def _check_integrity(self) -> list[str]: - """Checks the integrity of the dataset structure. + def _image_id(self, path: str) -> list[Any]: + """Return the image ID. + + Args: + path: An image or mask filepath. Returns: - List of collections to be downloaded + A list of integers. """ - # Check if collections exist - missing_collections = [] - for collection in self.collections: - stacpath = os.path.join(self.root, collection, 'collection.json') - - if not os.path.exists(stacpath): - missing_collections.append(collection) - - if not missing_collections: - return [] - - to_be_downloaded = [] - for collection in missing_collections: - archive_path = os.path.join(self.root, f'{collection}.tar.gz') - if os.path.exists(archive_path): - print(f'Found {collection} archive') - if ( - self.checksum - and check_integrity( - archive_path, self.collection_md5_dict[collection] - ) - or not self.checksum - ): - print('Extracting...') - extract_archive(archive_path) - else: - print(f'Collection {collection} is corrupted') - to_be_downloaded.append(collection) - else: - print(f'{collection} not found') - to_be_downloaded.append(collection) - - return to_be_downloaded - - def _download(self, collections: list[str], api_key: str | None = None) -> None: - """Download the dataset and extract it. + keys: list[Any] = [] + if match := re.search(self.file_regex, path): + for key in match.group(1).split('_'): + try: + keys.append(int(key)) + except ValueError: + keys.append(key) + + return keys + + def _list_files(self, aoi: int) -> tuple[list[str], list[str]]: + """List all files in a particular AOI. Args: - collections: Collections to be downloaded - api_key: a RadiantEarth MLHub API key to use for downloading the dataset + aoi: Area of interest. + + Returns: + Lists of image and mask files. """ - for collection in collections: - download_radiant_mlhub_collection(collection, self.root, api_key) - archive_path = os.path.join(self.root, f'{collection}.tar.gz') - if ( - not self.checksum - or not check_integrity( - archive_path, self.collection_md5_dict[collection] - ) - ) and self.checksum: - raise RuntimeError(f'Collection {collection} corrupted') + # Produce a list of files + kwargs = {} + if '{aoi}' in self.directory_glob: + kwargs['aoi'] = aoi - print('Extracting...') - extract_archive(archive_path) + product_glob = os.path.join( + self.root, self.dataset_id, self.split, self.directory_glob + ) + image_glob = product_glob.format(product=self.image, **kwargs) + mask_glob = product_glob.format(product=self.mask, **kwargs) + images = glob.glob(os.path.join(image_glob, self.image_glob), recursive=True) + masks = glob.glob(os.path.join(mask_glob, self.mask_glob), recursive=True) + + # Sort files based on image ID + images.sort(key=self._image_id) + masks.sort(key=self._image_id) + + # Remove images missing masks (SN3) or duplicate images (SN8) + if self.split == 'train': + images_iter = iter(images) + images = [] + for mask in masks: + mask_id = self._image_id(mask) + for image in images_iter: + image_id = self._image_id(image) + if image_id == mask_id: + images.append(image) + break + + return images, masks + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + self.images = [] + self.masks = [] + root = os.path.join(self.root, self.dataset_id, self.split) + os.makedirs(root, exist_ok=True) + for aoi in self.aois: + # Check if the extracted files already exist + images, masks = self._list_files(aoi) + if images: + self.images.extend(images) + self.masks.extend(masks) + continue + + # Check if the tarball has already been downloaded + for tarball, md5 in zip( + self.tarballs[self.split][aoi], self.md5s[self.split][aoi] + ): + if os.path.exists(os.path.join(root, tarball)): + extract_archive(os.path.join(root, tarball), root) + continue + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download the dataset + url = self.url.format(dataset_id=self.dataset_id, tarball=tarball) + aws = which('aws') + aws('s3', 'cp', url, root) + check_integrity( + os.path.join(root, tarball), md5 if self.checksum else None + ) + extract_archive(os.path.join(root, tarball), root) + images, masks = self._list_files(aoi) + self.images.extend(images) + self.masks.extend(masks) def plot( self, @@ -315,11 +356,7 @@ def plot( .. versionadded:: 0.2 """ - # image can be 1 channel or >3 channels - if sample['image'].shape[0] == 1: - image = np.rollaxis(sample['image'].numpy(), 0, 3) - else: - image = np.rollaxis(sample['image'][:3].numpy(), 0, 3) + image = np.rollaxis(sample['image'][:3].numpy(), 0, 3) image = percentile_normalization(image, axis=(0, 1)) ncols = 1 @@ -334,25 +371,23 @@ def plot( prediction = sample['prediction'].numpy() ncols += 1 - fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) - if not isinstance(axs, np.ndarray): - axs = [axs] - axs[0].imshow(image) - axs[0].axis('off') + fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 8, 8)) + axs[0, 0].imshow(image) + axs[0, 0].axis('off') if show_titles: - axs[0].set_title('Image') + axs[0, 0].set_title('Image') if show_mask: - axs[1].imshow(mask, interpolation='none') - axs[1].axis('off') + axs[0, 1].imshow(mask, interpolation='none') + axs[0, 1].axis('off') if show_titles: - axs[1].set_title('Label') + axs[0, 1].set_title('Label') if show_predictions: - axs[2].imshow(prediction, interpolation='none') - axs[2].axis('off') + axs[0, 2].imshow(prediction, interpolation='none') + axs[0, 2].axis('off') if show_titles: - axs[2].set_title('Prediction') + axs[0, 2].set_title('Prediction') if suptitle is not None: plt.suptitle(suptitle) @@ -387,43 +422,47 @@ class SpaceNet1(SpaceNet): If you use this dataset in your research, please cite the following paper: * https://arxiv.org/abs/1807.01232 - """ - dataset_id = 'spacenet1' - imagery = {'rgb': 'RGB.tif', '8band': '8Band.tif'} - chip_size = {'rgb': (406, 438), '8band': (101, 110)} - label_glob = 'labels.geojson' - collection_md5_dict = {'sn1_AOI_1_RIO': 'e6ea35331636fa0c036c04b3d1cbf226'} - - def __init__( - self, - root: str = 'data', - image: str = 'rgb', - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, - download: bool = False, - api_key: str | None = None, - checksum: bool = False, - ) -> None: - """Initialize a new SpaceNet 1 Dataset instance. - - Args: - root: root directory where dataset can be found - image: image selection which must be "rgb" or "8band" - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version. - download: if True, download dataset and store it in the root directory. - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - - Raises: - DatasetNotFoundError: If dataset is not found and *download* is False. - """ - collections = ['sn1_AOI_1_RIO'] - assert image in {'rgb', '8band'} - super().__init__( - root, image, collections, transforms, download, api_key, checksum - ) + directory_glob = '{product}' + dataset_id = 'SN1_buildings' + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 1: [ + 'SN1_buildings_train_AOI_1_Rio_3band.tar.gz', + 'SN1_buildings_train_AOI_1_Rio_8band.tar.gz', + 'SN1_buildings_train_AOI_1_Rio_geojson_buildings.tar.gz', + ] + }, + 'test': { + 1: [ + 'SN1_buildings_test_AOI_1_Rio_3band.tar.gz', + 'SN1_buildings_test_AOI_1_Rio_8band.tar.gz', + ] + }, + } + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 1: [ + '279e334a2120ecac70439ea246174516', + '6440a9eedbd7c4fe9741875135362c8c', + 'b6e02fbd727f252ea038abe4f77a77b3', + ] + }, + 'test': { + 1: ['18283d78b21c239bc1831f3bf1d2c996', '732b3a40603b76e80aac84e002e2b3e8'] + }, + } + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [1], 'test': [1]} + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['3band', '8band'], + 'test': ['3band', '8band'], + } + valid_masks = ('geojson',) + chip_size: ClassVar[dict[str, tuple[int, int]]] = { + '3band': (406, 439), + '8band': (102, 110), + } class SpaceNet2(SpaceNet): @@ -486,63 +525,47 @@ class SpaceNet2(SpaceNet): If you use this dataset in your research, please cite the following paper: * https://arxiv.org/abs/1807.01232 - """ - dataset_id = 'spacenet2' - collection_md5_dict = { - 'sn2_AOI_2_Vegas': 'a5a8de355290783b88ac4d69c7ef0694', - 'sn2_AOI_3_Paris': '8299186b7bbfb9a256d515bad1b7f146', - 'sn2_AOI_4_Shanghai': '4e3e80f2f437faca10ca2e6e6df0ef99', - 'sn2_AOI_5_Khartoum': '8070ff9050f94cd9f0efe9417205d7c3', + dataset_id = 'SN2_buildings' + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 2: ['SN2_buildings_train_AOI_2_Vegas.tar.gz'], + 3: ['SN2_buildings_train_AOI_3_Paris.tar.gz'], + 4: ['SN2_buildings_train_AOI_4_Shanghai.tar.gz'], + 5: ['SN2_buildings_train_AOI_5_Khartoum.tar.gz'], + }, + 'test': { + 2: ['AOI_2_Vegas_Test_public.tar.gz'], + 3: ['AOI_3_Paris_Test_public.tar.gz'], + 4: ['AOI_4_Shanghai_Test_public.tar.gz'], + 5: ['AOI_5_Khartoum_Test_public.tar.gz'], + }, } - - imagery = { - 'MS': 'MS.tif', - 'PAN': 'PAN.tif', - 'PS-MS': 'PS-MS.tif', - 'PS-RGB': 'PS-RGB.tif', + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 2: ['307da318bc43aaf9481828f92eda9126'], + 3: ['4db469e3e4e7bf025368ad730aec0888'], + 4: ['986129eecd3e842ebc2063d43b407adb'], + 5: ['462b4bf0466c945d708befabd4d9115b'], + }, + 'test': { + 2: ['d45405afd6629e693e2f9168b1291ea3'], + 3: ['2eaee95303e88479246e4ee2f2279b7f'], + 4: ['f51dc51fa484dc7fb89b3697bd15a950'], + 5: ['037d7be10530f0dd1c43d4ef79f3236e'], + }, } - chip_size = { - 'MS': (162, 162), - 'PAN': (650, 650), - 'PS-MS': (650, 650), - 'PS-RGB': (650, 650), + valid_aois: ClassVar[dict[str, list[int]]] = { + 'train': [2, 3, 4, 5], + 'test': [2, 3, 4, 5], } - label_glob = 'label.geojson' - - def __init__( - self, - root: str = 'data', - image: str = 'PS-RGB', - collections: list[str] = [], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, - download: bool = False, - api_key: str | None = None, - checksum: bool = False, - ) -> None: - """Initialize a new SpaceNet 2 Dataset instance. - - Args: - root: root directory where dataset can be found - image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"] - collections: collection selection which must be a subset of: - [sn2_AOI_2_Vegas, sn2_AOI_3_Paris, sn2_AOI_4_Shanghai, - sn2_AOI_5_Khartoum]. If unspecified, all collections will be - used. - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version - download: if True, download dataset and store it in the root directory. - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - - Raises: - DatasetNotFoundError: If dataset is not found and *download* is False. - """ - assert image in {'MS', 'PAN', 'PS-MS', 'PS-RGB'} - super().__init__( - root, image, collections, transforms, download, api_key, checksum - ) + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], + 'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], + } + valid_masks = (os.path.join('geojson', 'buildings'),) + chip_size: ClassVar[dict[str, tuple[int, int]]] = {'MUL': (163, 163)} class SpaceNet3(SpaceNet): @@ -609,200 +632,56 @@ class SpaceNet3(SpaceNet): .. versionadded:: 0.3 """ - dataset_id = 'spacenet3' - collection_md5_dict = { - 'sn3_AOI_2_Vegas': '8ce7e6abffb8849eb88885035f061ee8', - 'sn3_AOI_3_Paris': '90b9ebd64cd83dc8d3d4773f45050d8f', - 'sn3_AOI_4_Shanghai': '3ea291df34548962dfba8b5ed37d700c', - 'sn3_AOI_5_Khartoum': 'b8d549ac9a6d7456c0f7a8e6de23d9f9', + dataset_id = 'SN3_roads' + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 2: [ + 'SN3_roads_train_AOI_2_Vegas.tar.gz', + 'SN3_roads_train_AOI_2_Vegas_geojson_roads_speed.tar.gz', + ], + 3: [ + 'SN3_roads_train_AOI_3_Paris.tar.gz', + 'SN3_roads_train_AOI_3_Paris_geojson_roads_speed.tar.gz', + ], + 4: [ + 'SN3_roads_train_AOI_4_Shanghai.tar.gz', + 'SN3_roads_train_AOI_4_Shanghai_geojson_roads_speed.tar.gz', + ], + 5: [ + 'SN3_roads_train_AOI_5_Khartoum.tar.gz', + 'SN3_roads_train_AOI_5_Khartoum_geojson_roads_speed.tar.gz', + ], + }, + 'test': { + 2: ['SN3_roads_test_public_AOI_2_Vegas.tar.gz'], + 3: ['SN3_roads_test_public_AOI_3_Paris.tar.gz'], + 4: ['SN3_roads_test_public_AOI_4_Shanghai.tar.gz'], + 5: ['SN3_roads_test_public_AOI_5_Khartoum.tar.gz'], + }, } - - imagery = { - 'MS': 'MS.tif', - 'PAN': 'PAN.tif', - 'PS-MS': 'PS-MS.tif', - 'PS-RGB': 'PS-RGB.tif', + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 2: ['06317255b5e0c6df2643efd8a50f22ae', '4acf7846ed8121db1319345cfe9fdca9'], + 3: ['c13baf88ee10fe47870c303223cabf82', 'abc8199d4c522d3a14328f4f514702ad'], + 4: ['ef3de027c3da734411d4333bee9c273b', 'f1db36bd17b2be2281f5f7d369e9e25d'], + 5: ['46f327b550076f87babb5f7b43f27c68', 'd969693760d59401a84bd9215375a636'], + }, + 'test': { + 2: ['e9eb2220888ba38cab175fc6db6799a2'], + 3: ['21098cfe471dba6208c92b37b8203ae9'], + 4: ['2e7438b870ffd33d4453366db1c5b317'], + 5: ['f367c79fa0fc1d38e63a0fdd065ed957'], + }, } - chip_size = { - 'MS': (325, 325), - 'PAN': (1300, 1300), - 'PS-MS': (1300, 1300), - 'PS-RGB': (1300, 1300), + valid_aois: ClassVar[dict[str, list[int]]] = { + 'train': [2, 3, 4, 5], + 'test': [2, 3, 4, 5], } - label_glob = 'labels.geojson' - - def __init__( - self, - root: str = 'data', - image: str = 'PS-RGB', - speed_mask: bool | None = False, - collections: list[str] = [], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, - download: bool = False, - api_key: str | None = None, - checksum: bool = False, - ) -> None: - """Initialize a new SpaceNet 3 Dataset instance. - - Args: - root: root directory where dataset can be found - image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"] - speed_mask: use multi-class speed mask (created by binning roads at - 10 mph increments) as label if true, else use binary mask - collections: collection selection which must be a subset of: - [sn3_AOI_2_Vegas, sn3_AOI_3_Paris, sn3_AOI_4_Shanghai, - sn3_AOI_5_Khartoum]. If unspecified, all collections will be - used. - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version - download: if True, download dataset and store it in the root directory. - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - - Raises: - DatasetNotFoundError: If dataset is not found and *download* is False. - """ - assert image in {'MS', 'PAN', 'PS-MS', 'PS-RGB'} - self.speed_mask = speed_mask - super().__init__( - root, image, collections, transforms, download, api_key, checksum - ) - - def _load_mask( - self, path: str, tfm: Affine, raster_crs: CRS, shape: tuple[int, int] - ) -> Tensor: - """Rasterizes the dataset's labels (in geojson format). - - Args: - path: path to the label - tfm: transform of corresponding image - raster_crs: CRS of raster file - shape: shape of corresponding image - - Returns: - Tensor: label tensor - """ - min_speed_bin = 1 - max_speed_bin = 65 - speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1) - bin_size_mph = 10.0 - speed_cls_arr: np.typing.NDArray[np.int_] = np.array( - [math.ceil(s / bin_size_mph) for s in speed_arr_bin] - ) - - try: - with fiona.open(path) as src: - vector_crs = CRS(src.crs) - labels = [] - - for feature in src: - if raster_crs != vector_crs: - geom = transform_geom( - vector_crs.to_string(), - raster_crs.to_string(), - feature['geometry'], - ) - else: - geom = feature['geometry'] - - if self.speed_mask: - val = speed_cls_arr[ - int(feature['properties']['inferred_speed_mph']) - 1 - ] - else: - val = 1 - - labels.append((geom, val)) - - except FionaValueError: - labels = [] - - if not labels: - mask_data = np.zeros(shape=shape) - else: - mask_data = rasterize( - labels, - out_shape=shape, - fill=0, # nodata value - transform=tfm, - all_touched=False, - dtype=np.uint8, - ) - - mask = torch.from_numpy(mask_data).long() - return mask - - def plot( - self, - sample: dict[str, Tensor], - show_titles: bool = True, - suptitle: str | None = None, - ) -> Figure: - """Plot a sample from the dataset. - - Args: - sample: a sample returned by :meth:`SpaceNet.__getitem__` - show_titles: flag indicating whether to show titles above each panel - suptitle: optional string to use as a suptitle - - Returns: - a matplotlib Figure with the rendered sample - - """ - # image can be 1 channel or >3 channels - if sample['image'].shape[0] == 1: - image = np.rollaxis(sample['image'].numpy(), 0, 3) - else: - image = np.rollaxis(sample['image'][:3].numpy(), 0, 3) - image = percentile_normalization(image, axis=(0, 1)) - - ncols = 1 - show_mask = 'mask' in sample - show_predictions = 'prediction' in sample - - if show_mask: - mask = sample['mask'].numpy() - ncols += 1 - - if show_predictions: - prediction = sample['prediction'].numpy() - ncols += 1 - - fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 8, 8)) - if not isinstance(axs, np.ndarray): - axs = [axs] - axs[0].imshow(image) - axs[0].axis('off') - if show_titles: - axs[0].set_title('Image') - - if show_mask: - if self.speed_mask: - cmap = copy.copy(plt.get_cmap('autumn_r')) - cmap.set_under(color='black') - axs[1].imshow(mask, vmin=0.1, vmax=7, cmap=cmap, interpolation='none') - else: - axs[1].imshow(mask, cmap='Greys_r', interpolation='none') - axs[1].axis('off') - if show_titles: - axs[1].set_title('Label') - - if show_predictions: - if self.speed_mask: - cmap = copy.copy(plt.get_cmap('autumn_r')) - cmap.set_under(color='black') - axs[2].imshow( - prediction, vmin=0.1, vmax=7, cmap=cmap, interpolation='none' - ) - else: - axs[2].imshow(prediction, cmap='Greys_r', interpolation='none') - axs[2].axis('off') - if show_titles: - axs[2].set_title('Prediction') - - if suptitle is not None: - plt.suptitle(suptitle) - return fig + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['MS', 'PS-MS', 'PAN', 'PS-RGB'], + 'test': ['MUL', 'MUL-PanSharpen', 'PAN', 'RGB-PanSharpen'], + } + valid_masks: tuple[str, ...] = ('geojson_roads', 'geojson_roads_speed') class SpaceNet4(SpaceNet): @@ -836,136 +715,87 @@ class SpaceNet4(SpaceNet): If you use this dataset in your research, please cite the following paper: * https://arxiv.org/abs/1903.12239 - """ - dataset_id = 'spacenet4' - collection_md5_dict = {'sn4_AOI_6_Atlanta': 'c597d639cba5257927a97e3eff07b753'} - - imagery = {'MS': 'MS.tif', 'PAN': 'PAN.tif', 'PS-RGBNIR': 'PS-RGBNIR.tif'} - chip_size = {'MS': (225, 225), 'PAN': (900, 900), 'PS-RGBNIR': (900, 900)} - label_glob = 'labels.geojson' - - angle_catalog_map = { - 'nadir': [ - '1030010003D22F00', - '10300100023BC100', - '1030010003993E00', - '1030010003CAF100', - '1030010002B7D800', - '10300100039AB000', - '1030010002649200', - '1030010003C92000', - '1030010003127500', - '103001000352C200', - '103001000307D800', - ], - 'off-nadir': [ - '1030010003472200', - '1030010003315300', - '10300100036D5200', - '103001000392F600', - '1030010003697400', - '1030010003895500', - '1030010003832800', - ], - 'very-off-nadir': [ - '10300100035D1B00', - '1030010003CCD700', - '1030010003713C00', - '10300100033C5200', - '1030010003492700', - '10300100039E6200', - '1030010003BDDC00', - '1030010003CD4300', - '1030010003193D00', - ], + directory_glob = os.path.join('**', '{product}') + file_regex = r'_(\d+_\d+)\.' + dataset_id = 'SN4_buildings' + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 6: [ + 'Atlanta_nadir7_catid_1030010003D22F00.tar.gz', + 'Atlanta_nadir8_catid_10300100023BC100.tar.gz', + 'Atlanta_nadir10_catid_1030010003993E00.tar.gz', + 'Atlanta_nadir10_catid_1030010003CAF100.tar.gz', + 'Atlanta_nadir13_catid_1030010002B7D800.tar.gz', + 'Atlanta_nadir14_catid_10300100039AB000.tar.gz', + 'Atlanta_nadir16_catid_1030010002649200.tar.gz', + 'Atlanta_nadir19_catid_1030010003C92000.tar.gz', + 'Atlanta_nadir21_catid_1030010003127500.tar.gz', + 'Atlanta_nadir23_catid_103001000352C200.tar.gz', + 'Atlanta_nadir25_catid_103001000307D800.tar.gz', + 'Atlanta_nadir27_catid_1030010003472200.tar.gz', + 'Atlanta_nadir29_catid_1030010003315300.tar.gz', + 'Atlanta_nadir30_catid_10300100036D5200.tar.gz', + 'Atlanta_nadir32_catid_103001000392F600.tar.gz', + 'Atlanta_nadir34_catid_1030010003697400.tar.gz', + 'Atlanta_nadir36_catid_1030010003895500.tar.gz', + 'Atlanta_nadir39_catid_1030010003832800.tar.gz', + 'Atlanta_nadir42_catid_10300100035D1B00.tar.gz', + 'Atlanta_nadir44_catid_1030010003CCD700.tar.gz', + 'Atlanta_nadir46_catid_1030010003713C00.tar.gz', + 'Atlanta_nadir47_catid_10300100033C5200.tar.gz', + 'Atlanta_nadir49_catid_1030010003492700.tar.gz', + 'Atlanta_nadir50_catid_10300100039E6200.tar.gz', + 'Atlanta_nadir52_catid_1030010003BDDC00.tar.gz', + 'Atlanta_nadir53_catid_1030010003193D00.tar.gz', + 'Atlanta_nadir53_catid_1030010003CD4300.tar.gz', + 'geojson.tar.gz', + ] + }, + 'test': {6: ['SN4_buildings_AOI_6_Atlanta_test_public.tar.gz']}, } - - def __init__( - self, - root: str = 'data', - image: str = 'PS-RGBNIR', - angles: list[str] = [], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, - download: bool = False, - api_key: str | None = None, - checksum: bool = False, - ) -> None: - """Initialize a new SpaceNet 4 Dataset instance. - - Args: - root: root directory where dataset can be found - image: image selection which must be in ["MS", "PAN", "PS-RGBNIR"] - angles: angle selection which must be in ["nadir", "off-nadir", - "very-off-nadir"] - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version - download: if True, download dataset and store it in the root directory. - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - - Raises: - DatasetNotFoundError: If dataset is not found and *download* is False. - """ - collections = ['sn4_AOI_6_Atlanta'] - assert image in {'MS', 'PAN', 'PS-RGBNIR'} - self.angles = angles - if self.angles: - for angle in self.angles: - assert angle in self.angle_catalog_map.keys() - super().__init__( - root, image, collections, transforms, download, api_key, checksum - ) - - def _load_files(self, root: str) -> list[dict[str, str]]: - """Return the paths of the files in the dataset. - - Args: - root: root dir of dataset - - Returns: - list of dicts containing paths for each pair of image and label - """ - files = [] - nadir = [] - offnadir = [] - veryoffnadir = [] - images = glob.glob(os.path.join(root, self.collections[0], '*', self.filename)) - images = sorted(images) - - catalog_id_pattern = re.compile(r'(_[A-Z0-9])\w+$') - for imgpath in images: - imgdir = os.path.basename(os.path.dirname(imgpath)) - match = catalog_id_pattern.search(imgdir) - assert match is not None, 'Invalid image directory' - catalog_id = match.group()[1:] - - lbl_dir = os.path.dirname(imgpath).split('-nadir')[0] - - lbl_path = os.path.join(f'{lbl_dir}-labels', self.label_glob) - assert os.path.exists(lbl_path) - - _file = {'image_path': imgpath, 'label_path': lbl_path} - if catalog_id in self.angle_catalog_map['very-off-nadir']: - veryoffnadir.append(_file) - elif catalog_id in self.angle_catalog_map['off-nadir']: - offnadir.append(_file) - elif catalog_id in self.angle_catalog_map['nadir']: - nadir.append(_file) - - angle_file_map = { - 'nadir': nadir, - 'off-nadir': offnadir, - 'very-off-nadir': veryoffnadir, - } - - if not self.angles: - files.extend(nadir + offnadir + veryoffnadir) - else: - for angle in self.angles: - files.extend(angle_file_map[angle]) - return files + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 6: [ + 'd41ab6ec087b07e1e046c55d1fa5754b', + '72f04a7c0c34dd4595c181ee1ae6cb4c', + '89559f42ac11a8de570cef9802a577ad', + '5489ac756249c336ea506ef0acb3c09d', + 'bd9ed231cedd8631683ea51ea0602de1', + 'c497a8a448ed7ccdf63e7706507c0603', + '45d54eeecefdc60aa38320be6f29a17c', + '611528c0188bbc7e9cdf98609c6b0c49', + '532fbf1ca73d3d2e8b03c585f61b7316', + '538f48429b0968b6cfad97eb61fa8de1', + '3c48e94bc6d9e66e27c3a9bc8d35d65d', + 'b78cdf951e7bf4fedbe9259abd1e047a', + 'f307ce3c623d12d5a2fd5acb1e0607e0', + '9a17574332cd5513d68a0bcc9c607bdd', + 'fe905ca809f7bd2ceef75bde23c326f3', + 'd9f2e4a5c8462f6f9f7d5c573d9a1dc6', + 'f9425ff38dc82bf0e8f25a6287ff1ad1', + '7a6005d6fd972d5ce04caf9b42b36897', + '7c5aa16bb64cacf766cf88f89b3093bd', + '8f7e959eb0156ad2dfb0b966a1de06a9', + '62c4babcbe70034b7deb7c14d5ff61c2', + '8001d75f67534edf6932242324b8c1a7', + 'bc299cb5de432b5f5a1ce65a3bdb0abc', + 'd7640eda7c4efaf825665e853037bec9', + 'd4e1931551e9d3c6fd9bf1d8adfd07a0', + 'b313e23ead8fe6e2c8671a49f2c9de37', + '3bd8f07ad57bff841d0cf91c91c6f5ed', + '2556339e26a09e57559452eb240ef29c', + ] + }, + 'test': {6: ['0ec3874bfc19aed63b33ac47b039aace']}, + } + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [6], 'test': [6]} + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['MS', 'PAN', 'Pan-Sharpen'], + 'test': ['MS', 'PAN', 'Pan-Sharpen'], + } + valid_masks = (os.path.join('geojson', 'spacenet-buildings'),) class SpaceNet5(SpaceNet3): @@ -1030,66 +860,28 @@ class SpaceNet5(SpaceNet3): .. versionadded:: 0.2 """ - dataset_id = 'spacenet5' - collection_md5_dict = { - 'sn5_AOI_7_Moscow': 'b18107f878152fe7e75444373c320cba', - 'sn5_AOI_8_Mumbai': '1f1e2b3c26fbd15bfbcdbb6b02ae051c', + file_regex = r'_chip(\d+)\.' + dataset_id = 'SN5_roads' + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 7: ['SN5_roads_train_AOI_7_Moscow.tar.gz'], + 8: ['SN5_roads_train_AOI_8_Mumbai.tar.gz'], + }, + 'test': {9: ['SN5_roads_test_public_AOI_9_San_Juan.tar.gz']}, } - - imagery = { - 'MS': 'MS.tif', - 'PAN': 'PAN.tif', - 'PS-MS': 'PS-MS.tif', - 'PS-RGB': 'PS-RGB.tif', + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 7: ['03082d01081a6d8df2bc5a9645148d2a'], + 8: ['1ee20ba781da6cb7696eef9a95a5bdcc'], + }, + 'test': {9: ['fc45afef219dfd3a20f2d4fc597f6882']}, } - chip_size = { - 'MS': (325, 325), - 'PAN': (1300, 1300), - 'PS-MS': (1300, 1300), - 'PS-RGB': (1300, 1300), + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [7, 8], 'test': [9]} + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['MS', 'PAN', 'PS-MS', 'PS-RGB'], + 'test': ['MS', 'PAN', 'PS-MS', 'PS-RGB'], } - label_glob = 'labels.geojson' - - def __init__( - self, - root: str = 'data', - image: str = 'PS-RGB', - speed_mask: bool | None = False, - collections: list[str] = [], - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, - download: bool = False, - api_key: str | None = None, - checksum: bool = False, - ) -> None: - """Initialize a new SpaceNet 5 Dataset instance. - - Args: - root: root directory where dataset can be found - image: image selection which must be in ["MS", "PAN", "PS-MS", "PS-RGB"] - speed_mask: use multi-class speed mask (created by binning roads at - 10 mph increments) as label if true, else use binary mask - collections: collection selection which must be a subset of: - [sn5_AOI_7_Moscow, sn5_AOI_8_Mumbai]. If unspecified, all - collections will be used. - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version - download: if True, download dataset and store it in the root directory. - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - - Raises: - DatasetNotFoundError: If dataset is not found and *download* is False. - """ - super().__init__( - root, - image, - speed_mask, - collections, - transforms, - download, - api_key, - checksum, - ) + valid_masks = ('geojson_roads_speed',) class SpaceNet6(SpaceNet): @@ -1152,84 +944,25 @@ class SpaceNet6(SpaceNet): * https://arxiv.org/abs/2004.06500 - .. note:: - - This dataset requires the following additional library to be installed: - - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub - .. versionadded:: 0.4 """ - dataset_id = 'spacenet6' - collections = ['sn6_AOI_11_Rotterdam'] - # This is actually the metadata hash - collection_md5_dict = {'sn6_AOI_11_Rotterdam': '66f7312218fec67a1e0b3b02b22c95cc'} - imagery = { - 'PAN': 'PAN.tif', - 'RGBNIR': 'RGBNIR.tif', - 'PS-RGB': 'PS-RGB.tif', - 'PS-RGBNIR': 'PS-RGBNIR.tif', - 'SAR-Intensity': 'SAR-Intensity.tif', + file_regex = r'_tile_(\d+)\.' + dataset_id = 'SN6_buildings' + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': {11: ['SN6_buildings_AOI_11_Rotterdam_train.tar.gz']}, + 'test': {11: ['SN6_buildings_AOI_11_Rotterdam_test_public.tar.gz']}, } - chip_size = { - 'PAN': (900, 900), - 'RGBNIR': (450, 450), - 'PS-RGB': (900, 900), - 'PS-RGBNIR': (900, 900), - 'SAR-Intensity': (900, 900), + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': {11: ['10ca26d2287716e3b6ef0cf0ad9f946e']}, + 'test': {11: ['a07823a5e536feeb8bb6b6f0cb43cf05']}, } - label_glob = 'labels.geojson' - - def __init__( - self, - root: str = 'data', - image: str = 'PS-RGB', - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, - download: bool = False, - api_key: str | None = None, - ) -> None: - """Initialize a new SpaceNet 6 Dataset instance. - - Args: - root: root directory where dataset can be found - image: image selection which must be in ["PAN", "RGBNIR", - "PS-RGB", "PS-RGBNIR", "SAR-Intensity"] - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version - download: if True, download dataset and store it in the root directory. - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - - Raises: - DatasetNotFoundError: If dataset is not found and *download* is False. - """ - self.root = root - self.image = image # For testing - - self.filename = self.imagery[image] - self.transforms = transforms - - if download: - self.__download(api_key) - - self.files = self._load_files(os.path.join(root, self.dataset_id)) - - def __download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. - - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if os.path.exists( - os.path.join( - self.root, self.dataset_id, self.collections[0], 'collection.json' - ) - ): - print('Files already downloaded and verified') - return - - download_radiant_mlhub_dataset(self.dataset_id, self.root, api_key) + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [11], 'test': [11]} + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['PAN', 'PS-RGB', 'PS-RGBNIR', 'RGBNIR', 'SAR-Intensity'], + 'test': ['SAR-Intensity'], + } + valid_masks = ('geojson_buildings',) class SpaceNet7(SpaceNet): @@ -1237,7 +970,7 @@ class SpaceNet7(SpaceNet): `SpaceNet 7 `_ is a dataset which consist of medium resolution (4.0m) satellite imagery mosaics acquired from - Planet Labs’ Dove constellation between 2017 and 2020. It includes ≈ 24 + Planet Labs' Dove constellation between 2017 and 2020. It includes ≈ 24 images (one per month) covering > 100 unique geographies, and comprises > 40,000 km2 of imagery and exhaustive polygon labels of building footprints therein, totaling over 11M individual annotations. @@ -1268,111 +1001,69 @@ class SpaceNet7(SpaceNet): .. versionadded:: 0.2 """ - dataset_id = 'spacenet7' - collection_md5_dict = { - 'sn7_train_source': '9f8cc109d744537d087bd6ff33132340', - 'sn7_train_labels': '16f873e3f0f914d95a916fb39b5111b5', - 'sn7_test_source': 'e97914f58e962bba3e898f08a14f83b2', + directory_glob = os.path.join('**', '{product}') + mask_glob = '*_Buildings.geojson' + file_regex = r'global_monthly_(\d+.*\d+)' + dataset_id = 'SN7_buildings' + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': {0: ['SN7_buildings_train.tar.gz']}, + 'test': {0: ['SN7_buildings_test_public.tar.gz']}, + } + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': {0: ['6eda13b9c28f6f5cdf00a7e8e218c1b1']}, + 'test': {0: ['b3bde95a0f8f32f3bfeba49464b9bc97']}, + } + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]} + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['images', 'images_masked'], + 'test': ['images_masked'], + } + valid_masks = ('labels', 'labels_match', 'labels_match_pix') + chip_size: ClassVar[dict[str, tuple[int, int]]] = { + 'images': (1024, 1024), + 'images_masked': (1024, 1024), } - imagery = {'img': 'mosaic.tif'} - chip_size = {'img': (1023, 1023)} - - label_glob = 'labels.geojson' - - def __init__( - self, - root: str = 'data', - split: str = 'train', - transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, - download: bool = False, - api_key: str | None = None, - checksum: bool = False, - ) -> None: - """Initialize a new SpaceNet 7 Dataset instance. - - Args: - root: root directory where dataset can be found - split: split selection which must be in ["train", "test"] - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version - download: if True, download dataset and store it in the root directory. - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) - - Raises: - DatasetNotFoundError: If dataset is not found and *download* is False. - """ - self.root = root - self.split = split - self.filename = self.imagery['img'] - self.transforms = transforms - self.checksum = checksum - - assert split in {'train', 'test'}, 'Invalid split' - - if split == 'test': - self.collections = ['sn7_test_source'] - else: - self.collections = ['sn7_train_source', 'sn7_train_labels'] - - to_be_downloaded = self._check_integrity() - - if to_be_downloaded: - if not download: - raise DatasetNotFoundError(self) - else: - self._download(to_be_downloaded, api_key) - - self.files = self._load_files(root) - - def _load_files(self, root: str) -> list[dict[str, str]]: - """Return the paths of the files in the dataset. - Args: - root: root dir of dataset +class SpaceNet8(SpaceNet): + r"""SpaceNet8: Flood Detection Challenge Using Multiclass Segmentation. - Returns: - list of dicts containing paths for images and labels (if train split) - """ - files = [] - if self.split == 'train': - imgs = sorted( - glob.glob(os.path.join(root, 'sn7_train_source', '*', self.filename)) - ) - lbls = sorted( - glob.glob(os.path.join(root, 'sn7_train_labels', '*', self.label_glob)) - ) - for img, lbl in zip(imgs, lbls): - files.append({'image_path': img, 'label_path': lbl}) - else: - imgs = sorted( - glob.glob(os.path.join(root, 'sn7_test_source', '*', self.filename)) - ) - for img in imgs: - files.append({'image_path': img}) - return files + `SpaceNet 8 `_ is a dataset focusing on + infrastructure and flood mapping related to hurricanes and heavy rains that cause + route obstructions and significant damage. - def __getitem__(self, index: int) -> dict[str, Tensor]: - """Return an index within the dataset. - - Args: - index: index to return + If you use this dataset in your research, please cite the following paper: - Returns: - data at that index - """ - files = self.files[index] - img, tfm, raster_crs = self._load_image(files['image_path']) - h, w = img.shape[1:] + * https://openaccess.thecvf.com/content/CVPR2022W/EarthVision/html/Hansch_SpaceNet_8\_-_The_Detection_of_Flooded_Roads_and_Buildings_CVPRW_2022_paper.html - ch, cw = self.chip_size['img'] - sample = {'image': img[:, :ch, :cw]} - if self.split == 'train': - mask = self._load_mask(files['label_path'], tfm, raster_crs, (h, w)) - sample['mask'] = mask[:ch, :cw] - - if self.transforms is not None: - sample = self.transforms(sample) + .. versionadded:: 0.6 + """ - return sample + directory_glob = '{product}' + file_regex = r'(\d+_\d+_\d+)\.' + dataset_id = 'SN8_floods' + tarballs: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 0: [ + 'Germany_Training_Public.tar.gz', + 'Louisiana-East_Training_Public.tar.gz', + ] + }, + 'test': {0: ['Louisiana-West_Test_Public.tar.gz']}, + } + md5s: ClassVar[dict[str, dict[int, list[str]]]] = { + 'train': { + 0: ['81383a9050b93e8f70c8557d4568e8a2', 'fa40ae3cf6ac212c90073bf93d70bd95'] + }, + 'test': {0: ['d41d8cd98f00b204e9800998ecf8427e']}, + } + valid_aois: ClassVar[dict[str, list[int]]] = {'train': [0], 'test': [0]} + valid_images: ClassVar[dict[str, list[str]]] = { + 'train': ['PRE-event', 'POST-event'], + 'test': ['PRE-event', 'POST-event'], + } + valid_masks = ('annotations',) + chip_size: ClassVar[dict[str, tuple[int, int]]] = { + 'PRE-event': (1300, 1300), + 'POST-event': (1300, 1300), + } diff --git a/torchgeo/datasets/ssl4eo.py b/torchgeo/datasets/ssl4eo.py index ce6558f52c4..b2840afb865 100644 --- a/torchgeo/datasets/ssl4eo.py +++ b/torchgeo/datasets/ssl4eo.py @@ -7,7 +7,7 @@ import os import random from collections.abc import Callable -from typing import TypedDict +from typing import ClassVar, TypedDict import matplotlib.pyplot as plt import numpy as np @@ -18,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import Path, check_integrity, download_url, extract_archive class SSL4EO(NonGeoDataset): @@ -93,13 +93,13 @@ class SSL4EOL(NonGeoDataset): * https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html .. versionadded:: 0.5 - """ # noqa: E501 + """ class _Metadata(TypedDict): num_bands: int rgb_bands: list[int] - metadata: dict[str, _Metadata] = { + metadata: ClassVar[dict[str, _Metadata]] = { 'tm_toa': {'num_bands': 7, 'rgb_bands': [2, 1, 0]}, 'etm_toa': {'num_bands': 9, 'rgb_bands': [2, 1, 0]}, 'etm_sr': {'num_bands': 6, 'rgb_bands': [2, 1, 0]}, @@ -107,8 +107,8 @@ class _Metadata(TypedDict): 'oli_sr': {'num_bands': 7, 'rgb_bands': [3, 2, 1]}, } - url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' # noqa: E501 - checksums = { + url = 'https://hf.co/datasets/torchgeo/ssl4eo_l/resolve/e2467887e6a6bcd7547d9d5999f8d9bc3323dc31/{0}/ssl4eo_l_{0}.tar.gz{1}' + checksums: ClassVar[dict[str, dict[str, str]]] = { 'tm_toa': { 'aa': '553795b8d73aa253445b1e67c5b81f11', 'ab': 'e9e0739b5171b37d16086cb89ab370e8', @@ -162,7 +162,7 @@ class _Metadata(TypedDict): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'oli_sr', seasons: int = 1, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -357,7 +357,7 @@ class _Metadata(TypedDict): md5: str bands: list[str] - metadata: dict[str, _Metadata] = { + metadata: ClassVar[dict[str, _Metadata]] = { 's1': { 'filename': 's1.tar.gz', 'md5': '51ee23b33eb0a2f920bda25225072f3a', @@ -404,7 +404,7 @@ class _Metadata(TypedDict): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 's2c', seasons: int = 1, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, diff --git a/torchgeo/datasets/ssl4eo_benchmark.py b/torchgeo/datasets/ssl4eo_benchmark.py index 7d9edcaecb4..13c5a8474c4 100644 --- a/torchgeo/datasets/ssl4eo_benchmark.py +++ b/torchgeo/datasets/ssl4eo_benchmark.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -18,7 +19,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .nlcd import NLCD -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class SSL4EOLBenchmark(NonGeoDataset): @@ -46,16 +47,16 @@ class SSL4EOLBenchmark(NonGeoDataset): * https://proceedings.neurips.cc/paper_files/paper/2023/hash/bbf7ee04e2aefec136ecf60e346c2e61-Abstract-Datasets_and_Benchmarks.html .. versionadded:: 0.5 - """ # noqa: E501 + """ - url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/ssl4eo-l-benchmark/resolve/da96ae2b04cb509710b72fce9131c2a3d5c211c2/{}.tar.gz' - valid_sensors = ['tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr'] - valid_products = ['cdl', 'nlcd'] - valid_splits = ['train', 'val', 'test'] + valid_sensors = ('tm_toa', 'etm_toa', 'etm_sr', 'oli_tirs_toa', 'oli_sr') + valid_products = ('cdl', 'nlcd') + valid_splits = ('train', 'val', 'test') image_root = 'ssl4eo_l_{}_benchmark' - img_md5s = { + img_md5s: ClassVar[dict[str, str]] = { 'tm_toa': '8e3c5bcd56d3780a442f1332013b8d15', 'etm_toa': '1b051c7fe4d61c581b341370c9e76f1f', 'etm_sr': '34a24fa89a801654f8d01e054662c8cd', @@ -63,14 +64,14 @@ class SSL4EOLBenchmark(NonGeoDataset): 'oli_sr': '0700cd15cc2366fe68c2f8c02fa09a15', } - mask_dir_dict = { + mask_dir_dict: ClassVar[dict[str, str]] = { 'tm_toa': 'ssl4eo_l_tm_{}', 'etm_toa': 'ssl4eo_l_etm_{}', 'etm_sr': 'ssl4eo_l_etm_{}', 'oli_tirs_toa': 'ssl4eo_l_oli_{}', 'oli_sr': 'ssl4eo_l_oli_{}', } - mask_md5s = { + mask_md5s: ClassVar[dict[str, dict[str, str]]] = { 'tm': { 'cdl': '3d676770ffb56c7e222a7192a652a846', 'nlcd': '261149d7614fcfdcb3be368eefa825c7', @@ -85,7 +86,7 @@ class SSL4EOLBenchmark(NonGeoDataset): }, } - year_dict = { + year_dict: ClassVar[dict[str, int]] = { 'tm_toa': 2011, 'etm_toa': 2019, 'etm_sr': 2019, @@ -93,7 +94,7 @@ class SSL4EOLBenchmark(NonGeoDataset): 'oli_sr': 2019, } - rgb_indices = { + rgb_indices: ClassVar[dict[str, list[int]]] = { 'tm_toa': [2, 1, 0], 'etm_toa': [2, 1, 0], 'etm_sr': [2, 1, 0], @@ -101,13 +102,16 @@ class SSL4EOLBenchmark(NonGeoDataset): 'oli_sr': [3, 2, 1], } - split_percentages = [0.7, 0.15, 0.15] + split_percentages = (0.7, 0.15, 0.15) - cmaps = {'nlcd': NLCD.cmap, 'cdl': CDL.cmap} + cmaps: ClassVar[dict[str, dict[int, tuple[int, int, int, int]]]] = { + 'nlcd': NLCD.cmap, + 'cdl': CDL.cmap, + } def __init__( self, - root: str = 'data', + root: Path = 'data', sensor: str = 'oli_sr', product: str = 'cdl', split: str = 'train', @@ -297,7 +301,7 @@ def retrieve_sample_collection(self) -> list[tuple[str, str]]: sample_collection.append((img_path, mask_path)) return sample_collection - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load the input image. Args: @@ -310,17 +314,17 @@ def _load_image(self, path: str) -> Tensor: image = torch.from_numpy(src.read()).float() return image - def _load_mask(self, path: str) -> Tensor: + def _load_mask(self, path: Path) -> Tensor: """Load the mask. Args: path: path to mask - Retuns: + Returns: mask """ with rasterio.open(path) as src: - mask = torch.from_numpy(src.read()).long() + mask = torch.from_numpy(src.read(1)).long() mask = self.ordinal_map[mask] return mask diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 8eb410297e9..eec9be57ab3 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class SustainBenchCropYield(NonGeoDataset): @@ -45,21 +45,21 @@ class SustainBenchCropYield(NonGeoDataset): * https://doi.org/10.1609/aaai.v31i1.11172 .. versionadded:: 0.5 - """ # noqa: E501 + """ - valid_countries = ['usa', 'brazil', 'argentina'] + valid_countries = ('usa', 'brazil', 'argentina') md5 = '362bad07b51a1264172b8376b39d1fc9' - url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' # noqa: E501 + url = 'https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link' dir = 'soybeans' - valid_splits = ['train', 'dev', 'test'] + valid_splits = ('train', 'dev', 'test') def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', countries: list[str] = ['usa'], transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, diff --git a/torchgeo/datasets/treesatai.py b/torchgeo/datasets/treesatai.py new file mode 100644 index 00000000000..5f55f158361 --- /dev/null +++ b/torchgeo/datasets/treesatai.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""TreeSatAI datasets.""" + +import json +import os +from collections.abc import Callable, Sequence +from typing import ClassVar + +import rasterio as rio +import torch +from einops import rearrange +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import Path, download_url, extract_archive, percentile_normalization + + +class TreeSatAI(NonGeoDataset): + """TreeSatAI Benchmark Archive. + + `TreeSatAI Benchmark Archive `_ is a + multi-sensor, multi-label dataset for tree species classification in remote + sensing. It was created by combining labels from the federal forest inventory of + Lower Saxony, Germany with 20 cm Color-Infrared (CIR) and 10 m Sentinel imagery. + + The TreeSatAI Benchmark Archive contains: + + * 50,381 image triplets (aerial, Sentinel-1, Sentinel-2) + * synchronized time steps and locations + * all original spectral bands/polarizations from the sensors + * 20 species classes (single labels) + * 12 age classes (single labels) + * 15 genus classes (multi labels) + * 60 m and 200 m patches + * fixed split for train (90%) and test (10%) data + * additional single labels such as English species name, genus, + forest stand type, foliage type, land cover + + If you use this dataset in your research, please cite the following paper: + + * https://doi.org/10.5194/essd-15-681-2023 + + .. versionadded:: 0.7 + """ + + url = 'https://zenodo.org/records/6780578/files/' + md5s: ClassVar[dict[str, str]] = { + 'aerial_60m_abies_alba.zip': '4298b1c9fbf6d0d85f7aa208ff5fe0c9', + 'aerial_60m_acer_pseudoplatanus.zip': '7c31d7ddea841f6509deece8f984a79e', + 'aerial_60m_alnus_spec.zip': '34ea107f43c6172c6d2652dbf26306af', + 'aerial_60m_betula_spec.zip': '69de9373739a027692a823846434fa0c', + 'aerial_60m_cleared.zip': '8dffbb2f6aad17ef83721cffa5b52d96', + 'aerial_60m_fagus_sylvatica.zip': '77b277e69e90bfbd3c5fd15a73d228fe', + 'aerial_60m_fraxinus_excelsior.zip': '9a88a8e6821f8a54ded950de9238831f', + 'aerial_60m_larix_decidua.zip': 'aa0bc5b091b099018a078536ef429031', + 'aerial_60m_larix_kaempferi.zip': '429df073f69f8bbf60aef765e1c925ba', + 'aerial_60m_picea_abies.zip': 'edb9b1bc9a5a7b405f4cbb0d71cedf54', + 'aerial_60m_pinus_nigra.zip': '96bf1798ef82f712ea46c2963ddb7083', + 'aerial_60m_pinus_strobus.zip': '0ff818c6d31f59b8488880e49b300c7a', + 'aerial_60m_pinus_sylvestris.zip': '298cbaac4d9f07a204e1e74e8446798d', + 'aerial_60m_populus_spec.zip': '46fcff76b119cc24f3caf938a0bb433a', + 'aerial_60m_prunus_spec.zip': 'fb1c570d3ea925a049630224ccb354bc', + 'aerial_60m_pseudotsuga_menziesii.zip': '2d05511ceabf4037b869eca928f3c04e', + 'aerial_60m_quercus_petraea.zip': '31f573fb0419b2b453ed7da1c4d2a298', + 'aerial_60m_quercus_robur.zip': 'bcd90506509de26692c043f4c8d73af0', + 'aerial_60m_quercus_rubra.zip': '71d8495725ed1b4f27d9e382409fcc5e', + 'aerial_60m_tilia_spec.zip': 'f81558c9c7189ac8a257d041ee43c1c9', + 'geojson.zip': 'aa749718f3cb76c1dfc9cddc2ed201db', + 'labels.zip': '656f1b68ec9ab70afd02bb127b75bb24', + 's1.zip': 'bed4fc8cb65da46a24ec1bc6cea2763c', + 's2.zip': '453ba69056aa33a3c6b97afb7b6afadb', + 'test_filenames.lst': '2166903d947f0025f61e342da466f917', + 'train_filenames.lst': 'a1a0148e8120b0268f76d2e98a68436f', + } + + # Genus-level classes (species-level labels also exist) + classes = ( + 'Abies', # fir + 'Acer', # maple + 'Alnus', # alder + 'Betula', # birch + 'Cleared', # none + 'Fagus', # beech + 'Fraxinus', # ash + 'Larix', # larch + 'Picea', # spruce + 'Pinus', # pine + 'Populus', # poplar + 'Prunus', # cherry + 'Pseudotsuga', # Douglas fir + 'Quercus', # oak + 'Tilia', # linden + ) + + # https://zenodo.org/records/6780578/files/220629_doc_TreeSatAI_benchmark_archive.pdf + all_sensors = ('aerial', 's1', 's2') + all_bands: ClassVar[dict[str, list[str]]] = { + 'aerial': ['IR', 'G', 'B', 'R'], + 's1': ['VV', 'VH', 'VV/VH'], + 's2': [ + 'B02', + 'B03', + 'B04', + 'B08', + 'B05', + 'B06', + 'B07', + 'B8A', + 'B11', + 'B12', + 'B01', + 'B09', + ], + } + rgb_bands: ClassVar[dict[str, list[str]]] = { + 'aerial': ['R', 'G', 'B'], + 's1': ['VV', 'VH', 'VV/VH'], + 's2': ['B04', 'B03', 'B02'], + } + + def __init__( + self, + root: Path = 'data', + split: str = 'train', + sensors: Sequence[str] = all_sensors, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new TreeSatAI instance. + + Args: + root: Root directory where dataset can be found. + split: Either 'train' or 'test'. + sensors: One or more of 'aerial', 's1', and/or 's2'. + transforms: A function/transform that takes input sample and its target as + entry and returns a transformed version. + download: If True, download dataset and store it in the root directory. + checksum: If True, check the MD5 of the downloaded files (may be slow). + + Raises: + AssertionError: If invalid *sensors* are chosen. + DatasetNotFoundError: If dataset is not found and *download* is False. + """ + assert set(sensors) <= set(self.all_sensors) + + self.root = root + self.split = split + self.sensors = sensors + self.transforms = transforms + self.download = download + self.checksum = checksum + + self._verify() + + path = os.path.join(self.root, f'{split}_filenames.lst') + with open(path) as f: + self.files = f.read().strip().split('\n') + + path = os.path.join(self.root, 'labels', 'TreeSatBA_v9_60m_multi_labels.json') + with open(path) as f: + self.labels = json.load(f) + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + Length of the dataset. + """ + return len(self.files) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: Index to return. + + Returns: + Data and label at that index. + """ + file = self.files[index] + label = torch.zeros(len(self.classes)) + for genus, _ in self.labels[file]: + i = self.classes.index(genus) + label[i] = 1 + + sample = {'label': label} + for directory in self.sensors: + with rio.open(os.path.join(self.root, directory, '60m', file)) as f: + sample[f'image_{directory}'] = torch.tensor(f.read().astype('float32')) + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the extracted files already exist + exists = [] + for directory in self.sensors: + exists.append(os.path.isdir(os.path.join(self.root, directory))) + + if all(exists): + return + + for file, md5 in self.md5s.items(): + # Check if the file has already been downloaded + if os.path.isfile(os.path.join(self.root, file)): + self._extract(file) + continue + + # Check if the user requested to download the dataset + if self.download: + url = self.url + file + download_url(url, self.root, md5=md5 if self.checksum else None) + self._extract(file) + continue + + raise DatasetNotFoundError(self) + + def _extract(self, file: str) -> None: + """Extract file. + + Args: + file: The file to extract. + """ + if not file.endswith('.zip'): + return + + to_path = self.root + if file.startswith('aerial'): + to_path = os.path.join(self.root, 'aerial', '60m') + + extract_archive(os.path.join(self.root, file), to_path) + + def plot(self, sample: dict[str, Tensor], show_titles: bool = True) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: A sample returned by :meth:`__getitem__`. + show_titles: Flag indicating whether to show titles above each panel. + + Returns: + A matplotlib Figure with the rendered sample. + """ + fig, ax = plt.subplots(ncols=len(self.sensors), squeeze=False) + + for i, sensor in enumerate(self.sensors): + image = sample[f'image_{sensor}'].cpu().numpy() + bands = [self.all_bands[sensor].index(b) for b in self.rgb_bands[sensor]] + image = rearrange(image[bands], 'c h w -> h w c') + image = percentile_normalization(image) + ax[0, i].imshow(image) + ax[0, i].axis('off') + + if show_titles: + ax[0, i].set_title(sensor) + + if show_titles: + label = self._multilabel_to_string(sample['label']) + suptitle = f'Label: ({label})' + + if 'prediction' in sample: + prediction = self._multilabel_to_string(sample['prediction']) + suptitle += f'\nPrediction: ({prediction})' + + fig.suptitle(suptitle) + + fig.tight_layout() + return fig + + def _multilabel_to_string(self, multilabel: Tensor) -> str: + """Convert a tensor of multilabel class probabilities to human readable format. + + Args: + multilabel: A tensor of multilabel class probabilities. + + Returns: + Class names and percentages sorted by percentage. + """ + labels: list[tuple[str, float]] = [] + for i, pct in enumerate(multilabel.cpu().numpy()): + if pct > 0.001: + labels.append((self.classes[i], pct)) + + labels.sort(key=lambda label: label[1], reverse=True) + return ', '.join([f'{genus}: {pct:.1%}' for genus, pct in labels]) diff --git a/torchgeo/datasets/ucmerced.py b/torchgeo/datasets/ucmerced.py index 686fb3b96f8..e7e830ddb1a 100644 --- a/torchgeo/datasets/ucmerced.py +++ b/torchgeo/datasets/ucmerced.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import cast +from typing import ClassVar, cast import matplotlib.pyplot as plt import numpy as np @@ -15,7 +15,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoClassificationDataset -from .utils import check_integrity, download_url, extract_archive +from .utils import Path, check_integrity, download_url, extract_archive class UCMerced(NonGeoClassificationDataset): @@ -66,19 +66,19 @@ class UCMerced(NonGeoClassificationDataset): * https://dl.acm.org/doi/10.1145/1869790.1869829 """ - url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/d0af6e2eeea2322af86078068bd83337148a2149/UCMerced_LandUse.zip' # noqa: E501 + url = 'https://hf.co/datasets/torchgeo/ucmerced/resolve/7c5ef3454d9b1cccfa7ccde0c01fc8f00a45909a/' filename = 'UCMerced_LandUse.zip' md5 = '5b7ec56793786b6dc8a908e8854ac0e4' base_dir = os.path.join('UCMerced_LandUse', 'Images') - splits = ['train', 'val', 'test'] - split_urls = { - 'train': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-train.txt', # noqa: E501 - 'val': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-val.txt', # noqa: E501 - 'test': 'https://storage.googleapis.com/remote_sensing_representations/uc_merced-test.txt', # noqa: E501 + splits = ('train', 'val', 'test') + split_filenames: ClassVar[dict[str, str]] = { + 'train': 'uc_merced-train.txt', + 'val': 'uc_merced-val.txt', + 'test': 'uc_merced-test.txt', } - split_md5s = { + split_md5s: ClassVar[dict[str, str]] = { 'train': 'f2fb12eb2210cfb53f93f063a35ff374', 'val': '11ecabfc52782e5ea6a9c7c0d263aca0', 'test': '046aff88472d8fc07c4678d03749e28d', @@ -86,7 +86,7 @@ class UCMerced(NonGeoClassificationDataset): def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, @@ -113,11 +113,11 @@ def __init__( self._verify() valid_fns = set() - with open(os.path.join(self.root, f'uc_merced-{split}.txt')) as f: + with open(os.path.join(self.root, self.split_filenames[split])) as f: for fn in f: valid_fns.add(fn.strip()) - def is_in_split(x: str) -> bool: + def is_in_split(x: Path) -> bool: return os.path.basename(x) in valid_fns super().__init__( @@ -173,16 +173,12 @@ def _verify(self) -> None: def _download(self) -> None: """Download the dataset.""" download_url( - self.url, - self.root, - filename=self.filename, - md5=self.md5 if self.checksum else None, + self.url + self.filename, self.root, md5=self.md5 if self.checksum else None ) for split in self.splits: download_url( - self.split_urls[split], + self.url + self.split_filenames[split], self.root, - filename=f'uc_merced-{split}.txt', md5=self.split_md5s[split] if self.checksum else None, ) diff --git a/torchgeo/datasets/usavars.py b/torchgeo/datasets/usavars.py index b955e8ded68..db13059bfa7 100644 --- a/torchgeo/datasets/usavars.py +++ b/torchgeo/datasets/usavars.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable, Sequence +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -17,7 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_url, extract_archive +from .utils import Path, download_url, extract_archive class USAVars(NonGeoDataset): @@ -49,12 +50,12 @@ class USAVars(NonGeoDataset): .. versionadded:: 0.3 """ - data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' # noqa: E501 + data_url = 'https://hf.co/datasets/torchgeo/usavars/resolve/01377abfaf50c0cc8548aaafb79533666bbf288f/{}' dirname = 'uar' md5 = '677e89fd20e5dd0fe4d29b61827c2456' - label_urls = { + label_urls: ClassVar[dict[str, str]] = { 'housing': data_url.format('housing.csv'), 'income': data_url.format('income.csv'), 'roads': data_url.format('roads.csv'), @@ -64,7 +65,7 @@ class USAVars(NonGeoDataset): 'treecover': data_url.format('treecover.csv'), } - split_metadata = { + split_metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'url': data_url.format('train_split.txt'), 'filename': 'train_split.txt', @@ -82,11 +83,11 @@ class USAVars(NonGeoDataset): }, } - ALL_LABELS = ['treecover', 'elevation', 'population'] + ALL_LABELS = ('treecover', 'elevation', 'population') def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', labels: Sequence[str] = ALL_LABELS, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, @@ -170,7 +171,7 @@ def _load_files(self) -> list[str]: files = f.read().splitlines() return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: diff --git a/torchgeo/datasets/utils.py b/torchgeo/datasets/utils.py index 12ed2aad40b..7ddbe08e597 100644 --- a/torchgeo/datasets/utils.py +++ b/torchgeo/datasets/utils.py @@ -6,183 +6,42 @@ # https://github.com/sphinx-doc/sphinx/issues/11327 from __future__ import annotations -import bz2 import collections import contextlib -import gzip import importlib -import lzma import os import shutil import subprocess import sys -import tarfile -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, cast, overload +from typing import Any, TypeAlias, cast, overload import numpy as np import rasterio import torch from torch import Tensor -from torchvision.datasets.utils import check_integrity, download_url +from torchvision.datasets.utils import ( + check_integrity, + download_and_extract_archive, + download_url, + extract_archive, +) from torchvision.utils import draw_segmentation_masks from .errors import DependencyNotFoundError # Only include import redirects -__all__ = ('check_integrity', 'download_url') +__all__ = ( + 'check_integrity', + 'download_and_extract_archive', + 'download_url', + 'extract_archive', +) -class _rarfile: - class RarFile: - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.args = args - self.kwargs = kwargs - - def __enter__(self) -> Any: - rarfile = lazy_import('rarfile') - # TODO: catch exception for when rarfile is installed but not - # unrar/unar/bsdtar - return rarfile.RarFile(*self.args, **self.kwargs) - - def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: - pass - - -class _zipfile: - class ZipFile: - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.args = args - self.kwargs = kwargs - - def __enter__(self) -> Any: - try: - # Supports normal zip files, proprietary deflate64 compression algorithm - import zipfile_deflate64 as zipfile - except ImportError: - # Only supports normal zip files - # https://github.com/python/mypy/issues/1153 - import zipfile - - return zipfile.ZipFile(*self.args, **self.kwargs) - - def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: - pass - - -def extract_archive(src: str, dst: str | None = None) -> None: - """Extract an archive. - - Args: - src: file to be extracted - dst: directory to extract to (defaults to dirname of ``src``) - - Raises: - RuntimeError: if src file has unknown archival/compression scheme - """ - if dst is None: - dst = os.path.dirname(src) - - suffix_and_extractor: list[tuple[str | tuple[str, ...], Any]] = [ - ('.rar', _rarfile.RarFile), - ( - ('.tar', '.tar.gz', '.tar.bz2', '.tar.xz', '.tgz', '.tbz2', '.tbz', '.txz'), - tarfile.open, - ), - ('.zip', _zipfile.ZipFile), - ] - - for suffix, extractor in suffix_and_extractor: - if src.endswith(suffix): - with extractor(src, 'r') as f: - f.extractall(dst) - return - - suffix_and_decompressor: list[tuple[str, Any]] = [ - ('.bz2', bz2.open), - ('.gz', gzip.open), - ('.xz', lzma.open), - ] - - for suffix, decompressor in suffix_and_decompressor: - if src.endswith(suffix): - dst = os.path.join(dst, os.path.basename(src).replace(suffix, '')) - with decompressor(src, 'rb') as sf, open(dst, 'wb') as df: - df.write(sf.read()) - return - - raise RuntimeError('src file has unknown archival/compression scheme') - - -def download_and_extract_archive( - url: str, - download_root: str, - extract_root: str | None = None, - filename: str | None = None, - md5: str | None = None, -) -> None: - """Download and extract an archive. - - Args: - url: URL to download - download_root: directory to download to - extract_root: directory to extract to (defaults to ``download_root``) - filename: download filename (defaults to basename of ``url``) - md5: checksum for download verification - """ - download_root = os.path.expanduser(download_root) - if extract_root is None: - extract_root = download_root - if not filename: - filename = os.path.basename(url) - - download_url(url, download_root, filename, md5) - - archive = os.path.join(download_root, filename) - print(f'Extracting {archive} to {extract_root}') - extract_archive(archive, extract_root) - - -def download_radiant_mlhub_dataset( - dataset_id: str, download_root: str, api_key: str | None = None -) -> None: - """Download a dataset from Radiant Earth. - - Args: - dataset_id: the ID of the dataset to fetch - download_root: directory to download to - api_key: the API key to use for all requests from the session. Can also be - passed in via the ``MLHUB_API_KEY`` environment variable, or configured in - ``~/.mlhub/profiles``. - - Raises: - DependencyNotFoundError: If radiant_mlhub is not installed. - """ - radiant_mlhub = lazy_import('radiant_mlhub') - dataset = radiant_mlhub.Dataset.fetch(dataset_id, api_key=api_key) - dataset.download(output_dir=download_root, api_key=api_key) - - -def download_radiant_mlhub_collection( - collection_id: str, download_root: str, api_key: str | None = None -) -> None: - """Download a collection from Radiant Earth. - - Args: - collection_id: the ID of the collection to fetch - download_root: directory to download to - api_key: the API key to use for all requests from the session. Can also be - passed in via the ``MLHUB_API_KEY`` environment variable, or configured in - ``~/.mlhub/profiles``. - - Raises: - DependencyNotFoundError: If radiant_mlhub is not installed. - """ - radiant_mlhub = lazy_import('radiant_mlhub') - collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key) - collection.download(output_dir=download_root, api_key=api_key) +Path: TypeAlias = str | os.PathLike[str] @dataclass(frozen=True) @@ -224,13 +83,12 @@ def __post_init__(self) -> None: f"Bounding box is invalid: 'mint={self.mint}' > 'maxt={self.maxt}'" ) - # https://github.com/PyCQA/pydocstyle/issues/525 @overload - def __getitem__(self, key: int) -> float: # noqa: D105 + def __getitem__(self, key: int) -> float: pass @overload - def __getitem__(self, key: slice) -> list[float]: # noqa: D105 + def __getitem__(self, key: slice) -> list[float]: pass def __getitem__(self, key: int | slice) -> float | list[float]: @@ -410,7 +268,7 @@ class Executable: .. versionadded:: 0.6 """ - def __init__(self, name: str) -> None: + def __init__(self, name: Path) -> None: """Initialize a new Executable instance. Args: @@ -429,7 +287,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess[byt The completed process. """ kwargs['check'] = True - return subprocess.run((self.name,) + args, **kwargs) + return subprocess.run((self.name, *args), **kwargs) def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]: @@ -449,8 +307,8 @@ def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]: (mint, maxt) tuple for indexing """ mint = datetime.strptime(date_str, format) + format = format.replace('%%', '') - # TODO: This doesn't correctly handle literal `%%` characters in format # TODO: May have issues with time zones, UTC vs. local time, and DST # TODO: This is really tedious, is there a better way to do this? @@ -488,7 +346,7 @@ def disambiguate_timestamp(date_str: str, format: str) -> tuple[float, float]: @contextlib.contextmanager -def working_dir(dirname: str, create: bool = False) -> Iterator[None]: +def working_dir(dirname: Path, create: bool = False) -> Iterator[None]: """Context manager for changing directories. Args: @@ -507,7 +365,9 @@ def working_dir(dirname: str, create: bool = False) -> Iterator[None]: os.chdir(cwd) -def _list_dict_to_dict_list(samples: Iterable[dict[Any, Any]]) -> dict[Any, list[Any]]: +def _list_dict_to_dict_list( + samples: Iterable[Mapping[Any, Any]], +) -> dict[Any, list[Any]]: """Convert a list of dictionaries to a dictionary of lists. Args: @@ -518,14 +378,18 @@ def _list_dict_to_dict_list(samples: Iterable[dict[Any, Any]]) -> dict[Any, list .. versionadded:: 0.2 """ - collated = collections.defaultdict(list) + collated: dict[Any, list[Any]] = dict() for sample in samples: for key, value in sample.items(): + if key not in collated: + collated[key] = [] collated[key].append(value) return collated -def _dict_list_to_list_dict(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: +def _dict_list_to_list_dict( + sample: Mapping[Any, Sequence[Any]], +) -> list[dict[Any, Any]]: """Convert a dictionary of lists to a list of dictionaries. Args: @@ -545,7 +409,7 @@ def _dict_list_to_list_dict(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, return uncollated -def stack_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: +def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: """Stack a list of samples along a new axis. Useful for forming a mini-batch of samples to pass to @@ -566,7 +430,7 @@ def stack_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: return collated -def concat_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: +def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: """Concatenate a list of samples along an existing axis. Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`. @@ -588,7 +452,7 @@ def concat_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: return collated -def merge_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: +def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]: """Merge a list of samples. Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`. @@ -613,7 +477,7 @@ def merge_samples(samples: Iterable[dict[Any, Any]]) -> dict[Any, Any]: return collated -def unbind_samples(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: +def unbind_samples(sample: MutableMapping[Any, Any]) -> list[dict[Any, Any]]: """Reverse of :func:`stack_samples`. Useful for turning a mini-batch of samples into a list of samples. These individual @@ -633,7 +497,7 @@ def unbind_samples(sample: dict[Any, Sequence[Any]]) -> list[dict[Any, Any]]: return _dict_list_to_list_dict(sample) -def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]: +def rasterio_loader(path: Path) -> np.typing.NDArray[np.int_]: """Load an image file using rasterio. Args: @@ -649,7 +513,7 @@ def rasterio_loader(path: str) -> np.typing.NDArray[np.int_]: return array -def sort_sentinel2_bands(x: str) -> str: +def sort_sentinel2_bands(x: Path) -> str: """Sort Sentinel-2 band files in the correct order.""" x = os.path.basename(x).split('_')[-1] x = os.path.splitext(x)[0] @@ -674,7 +538,7 @@ def draw_semantic_segmentation_masks( colors: list of RGB int tuples, or color strings e.g. red, #FF00FF Returns: - a version of ``image`` overlayed with the colors given by ``mask`` and + a version of ``image`` overlaid with the colors given by ``mask`` and ``colors`` """ classes = torch.from_numpy(np.arange(len(colors) if colors else 0, dtype=np.uint8)) @@ -687,7 +551,7 @@ def draw_semantic_segmentation_masks( def rgb_to_mask( - rgb: np.typing.NDArray[np.uint8], colors: list[tuple[int, int, int]] + rgb: np.typing.NDArray[np.uint8], colors: Sequence[tuple[int, int, int]] ) -> np.typing.NDArray[np.uint8]: """Converts an RGB colormap mask to a integer mask. @@ -744,7 +608,7 @@ def percentile_normalization( return img_normalized -def path_is_vsi(path: str) -> bool: +def path_is_vsi(path: Path) -> bool: """Checks if the given path is pointing to a Virtual File System. .. note:: @@ -758,14 +622,14 @@ def path_is_vsi(path: str) -> bool: * https://rasterio.readthedocs.io/en/latest/topics/datasets.html Args: - path: string representing a directory or file + path: a directory or file Returns: True if path is on a virtual file system, else False .. versionadded:: 0.6 """ - return '://' in path or path.startswith('/vsi') + return '://' in str(path) or str(path).startswith('/vsi') def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor: @@ -831,7 +695,7 @@ def lazy_import(name: str) -> Any: raise DependencyNotFoundError(msg) from None -def which(name: str) -> Executable: +def which(name: Path) -> Executable: """Search for executable *name*. Args: @@ -845,8 +709,8 @@ def which(name: str) -> Executable: .. versionadded:: 0.6 """ - if shutil.which(name): - return Executable(name) + if cmd := shutil.which(name): + return Executable(cmd) else: msg = f'{name} is not installed and is required to use this dataset.' raise DependencyNotFoundError(msg) from None diff --git a/torchgeo/datasets/vaihingen.py b/torchgeo/datasets/vaihingen.py index 6276dcb87cf..2c671ca27ac 100644 --- a/torchgeo/datasets/vaihingen.py +++ b/torchgeo/datasets/vaihingen.py @@ -5,6 +5,7 @@ import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -16,6 +17,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, draw_semantic_segmentation_masks, extract_archive, @@ -54,15 +56,15 @@ class Vaihingen2D(NonGeoDataset): * https://doi.org/10.5194/isprsannals-I-3-293-2012 .. versionadded:: 0.2 - """ # noqa: E501 + """ - filenames = [ + filenames = ( 'ISPRS_semantic_labeling_Vaihingen.zip', 'ISPRS_semantic_labeling_Vaihingen_ground_truth_COMPLETE.zip', - ] - md5s = ['462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277'] + ) + md5s = ('462b8dca7b6fa9eaf729840f0cdfc7f3', '4802dd6326e2727a352fb735be450277') image_root = 'top' - splits = { + splits: ClassVar[dict[str, list[str]]] = { 'train': [ 'top_mosaic_09cm_area1.tif', 'top_mosaic_09cm_area11.tif', @@ -101,26 +103,26 @@ class Vaihingen2D(NonGeoDataset): 'top_mosaic_09cm_area29.tif', ], } - classes = [ + classes = ( 'Clutter/background', 'Impervious surfaces', 'Building', 'Low Vegetation', 'Tree', 'Car', - ] - colormap = [ + ) + colormap = ( (255, 0, 0), (255, 255, 255), (0, 0, 255), (0, 255, 255), (0, 255, 0), (255, 255, 0), - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, @@ -257,7 +259,7 @@ def plot( """ ncols = 1 image1 = draw_semantic_segmentation_masks( - sample['image'][:3], sample['mask'], alpha=alpha, colors=self.colormap + sample['image'][:3], sample['mask'], alpha=alpha, colors=list(self.colormap) ) if 'prediction' in sample: ncols += 1 @@ -265,7 +267,7 @@ def plot( sample['image'][:3], sample['prediction'], alpha=alpha, - colors=self.colormap, + colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) diff --git a/torchgeo/datasets/vhr10.py b/torchgeo/datasets/vhr10.py index b1aae5d2a30..9adc2f44e9e 100644 --- a/torchgeo/datasets/vhr10.py +++ b/torchgeo/datasets/vhr10.py @@ -5,7 +5,7 @@ import os from collections.abc import Callable -from typing import Any +from typing import Any, ClassVar import matplotlib.pyplot as plt import numpy as np @@ -18,6 +18,7 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import ( + Path, check_integrity, download_and_extract_archive, download_url, @@ -151,26 +152,24 @@ class VHR10(NonGeoDataset): .. note:: - This dataset requires the following additional libraries to be installed: + This dataset requires the following additional library to be installed: * `pycocotools `_ to load the ``annotations.json`` file for the "positive" image set - * `rarfile `_ to extract the dataset, - which is stored in a RAR file """ - image_meta = { - 'url': 'https://drive.google.com/file/d/1--foZ3dV5OCsqXQXT84UeKtrAqc5CkAE', - 'filename': 'NWPU VHR-10 dataset.rar', - 'md5': 'd30a7ff99d92123ebb0b3a14d9102081', + image_meta: ClassVar[dict[str, str]] = { + 'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/NWPU%20VHR-10%20dataset.zip', + 'filename': 'NWPU VHR-10 dataset.zip', + 'md5': '6add6751469c12dd8c8d6223064c6c4d', } - target_meta = { - 'url': 'https://raw.githubusercontent.com/chaozhong2010/VHR-10_dataset_coco/ce0ba0f5f6a0737031f1cbe05e785ddd5ef05bd7/NWPU%20VHR-10_dataset_coco/annotations.json', # noqa: E501 + target_meta: ClassVar[dict[str, str]] = { + 'url': 'https://hf.co/datasets/torchgeo/vhr10/resolve/7e7968ad265dadc4494e0ca4a079e0b63dc6f3f8/annotations.json', 'filename': 'annotations.json', 'md5': '7c76ec50c17a61bb0514050d20f22c08', } - categories = [ + categories = ( 'background', 'airplane', 'ships', @@ -182,11 +181,11 @@ class VHR10(NonGeoDataset): 'harbor', 'bridge', 'vehicle', - ] + ) def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'positive', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, @@ -196,7 +195,7 @@ def __init__( Args: root: root directory where dataset can be found - split: one of "postive" or "negative" + split: one of "positive" or "negative" transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index 10602efc256..fe51f6ade8f 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -6,16 +6,15 @@ import glob import json import os -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any import pandas as pd import torch -from torch import Tensor from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import download_radiant_mlhub_collection, extract_archive +from .utils import Path, which class WesternUSALiveFuelMoisture(NonGeoDataset): @@ -25,7 +24,7 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): (mass of water in vegetation) and remotely sensed variables in the western United States. It contains 2615 datapoints and 138 variables. For more details see the - `dataset page `_. + `dataset page `_. Dataset Format: @@ -44,19 +43,17 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionadded:: 0.5 """ - collection_id = 'su_sar_moisture_content' - - md5 = 'a6c0721f06a3a0110b7d1243b18614f0' + url = 'https://radiantearth.blob.core.windows.net/mlhub/su-sar-moisture-content' label_name = 'percent(t)' - all_variable_names = [ + all_variable_names = ( # "date", 'slope(t)', 'elevation(t)', @@ -196,16 +193,14 @@ class WesternUSALiveFuelMoisture(NonGeoDataset): 'vh_vv(t-3)', 'lat', 'lon', - ] + ) def __init__( self, - root: str = 'data', - input_features: list[str] = all_variable_names, + root: Path = 'data', + input_features: Iterable[str] = all_variable_names, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: """Initialize a new Western USA Live Fuel Moisture Dataset. @@ -215,42 +210,22 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: AssertionError: if ``input_features`` contains invalid variable names DatasetNotFoundError: If dataset is not found and *download* is False. """ - super().__init__() + assert set(input_features) <= set(self.all_variable_names) self.root = root + self.input_features = input_features self.transforms = transforms - self.checksum = checksum self.download = download - self.api_key = api_key self._verify() - assert all( - input in self.all_variable_names for input in input_features - ), 'Invalid input variable name.' - self.input_features = input_features - - self.collection = self._retrieve_collection() - self.dataframe = self._load_data() - def _retrieve_collection(self) -> list[str]: - """Retrieve dataset collection that maps samples to paths. - - Returns: - list of sample paths - """ - return glob.glob( - os.path.join(self.root, self.collection_id, '**', 'labels.geojson') - ) - def __len__(self) -> int: """Return the number of data points in the dataset. @@ -270,7 +245,7 @@ def __getitem__(self, index: int) -> dict[str, Any]: """ data = self.dataframe.iloc[index, :] - sample: dict[str, Tensor] = { + sample = { 'input': torch.tensor( data.drop([self.label_name]).values, dtype=torch.float32 ), @@ -289,7 +264,7 @@ def _load_data(self) -> pd.DataFrame: the features and label """ data_rows = [] - for path in self.collection: + for path in sorted(self.files): with open(path) as f: content = json.load(f) data_dict = content['properties'] @@ -297,21 +272,16 @@ def _load_data(self) -> pd.DataFrame: data_dict['lat'] = content['geometry']['coordinates'][1] data_rows.append(data_dict) - df: pd.DataFrame = pd.DataFrame(data_rows) - df = df[self.input_features + [self.label_name]] + df = pd.DataFrame(data_rows) + df = df[[*self.input_features, self.label_name]] return df def _verify(self) -> None: """Verify the integrity of the dataset.""" - # Check if the extracted files already exist - pathname = os.path.join(self.root, self.collection_id) - if os.path.exists(pathname): - return - - # Check if the zip files have already been downloaded - pathname = os.path.join(self.root, self.collection_id) + '.tar.gz' - if os.path.exists(pathname): - self._extract() + # Check if the files already exist + file_glob = os.path.join(self.root, '**', 'feature_*.geojson') + self.files = glob.glob(file_glob, recursive=True) + if self.files: return # Check if the user requested to download the dataset @@ -320,19 +290,10 @@ def _verify(self) -> None: # Download the dataset self._download() - self._extract() - - def _extract(self) -> None: - """Extract the dataset.""" - pathname = os.path.join(self.root, self.collection_id) + '.tar.gz' - extract_archive(pathname, self.root) + self.files = glob.glob(file_glob, recursive=True) - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. - - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - download_radiant_mlhub_collection(self.collection_id, self.root, api_key) - filename = os.path.join(self.root, self.collection_id) + '.tar.gz' - extract_archive(filename, self.root) + def _download(self) -> None: + """Download the dataset and extract it.""" + os.makedirs(self.root, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', self.url, self.root, '--recursive=true') diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 9854d18458f..9f7e9c4be72 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -6,6 +6,7 @@ import glob import os from collections.abc import Callable +from typing import ClassVar import matplotlib.pyplot as plt import numpy as np @@ -16,7 +17,12 @@ from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive +from .utils import ( + Path, + check_integrity, + draw_semantic_segmentation_masks, + extract_archive, +) class XView2(NonGeoDataset): @@ -49,7 +55,7 @@ class XView2(NonGeoDataset): .. versionadded:: 0.2 """ - metadata = { + metadata: ClassVar[dict[str, dict[str, str]]] = { 'train': { 'filename': 'train_images_labels_targets.tar.gz', 'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16', @@ -61,12 +67,12 @@ class XView2(NonGeoDataset): 'directory': 'test', }, } - classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed'] - colormap = ['green', 'blue', 'orange', 'red'] + classes = ('background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed') + colormap = ('green', 'blue', 'orange', 'red') def __init__( self, - root: str = 'data', + root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, @@ -127,7 +133,7 @@ def __len__(self) -> int: """ return len(self.files) - def _load_files(self, root: str, split: str) -> list[dict[str, str]]: + def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: """Return the paths of the files in the dataset. Args: @@ -152,7 +158,7 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: files.append(dict(image1=image1, image2=image2, mask1=mask1, mask2=mask2)) return files - def _load_image(self, path: str) -> Tensor: + def _load_image(self, path: Path) -> Tensor: """Load a single image. Args: @@ -169,7 +175,7 @@ def _load_image(self, path: str) -> Tensor: tensor = tensor.permute((2, 0, 1)) return tensor - def _load_target(self, path: str) -> Tensor: + def _load_target(self, path: Path) -> Tensor: """Load the target mask for a single image. Args: @@ -237,10 +243,16 @@ def plot( """ ncols = 2 image1 = draw_semantic_segmentation_masks( - sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap + sample['image'][0], + sample['mask'][0], + alpha=alpha, + colors=list(self.colormap), ) image2 = draw_semantic_segmentation_masks( - sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap + sample['image'][1], + sample['mask'][1], + alpha=alpha, + colors=list(self.colormap), ) if 'prediction' in sample: # NOTE: this assumes predictions are made for post ncols += 1 @@ -248,7 +260,7 @@ def plot( sample['image'][1], sample['prediction'], alpha=alpha, - colors=self.colormap, + colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) diff --git a/torchgeo/datasets/zuericrop.py b/torchgeo/datasets/zuericrop.py index e1a5f4d2870..2928dc58a70 100644 --- a/torchgeo/datasets/zuericrop.py +++ b/torchgeo/datasets/zuericrop.py @@ -13,7 +13,7 @@ from .errors import DatasetNotFoundError, RGBBandsMissingError from .geo import NonGeoDataset -from .utils import download_url, lazy_import, percentile_normalization +from .utils import Path, download_url, lazy_import, percentile_normalization class ZueriCrop(NonGeoDataset): @@ -52,19 +52,19 @@ class ZueriCrop(NonGeoDataset): * `h5py `_ to load the dataset """ - urls = [ + urls = ( 'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download', - 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', # noqa: E501 - ] - md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b'] - filenames = ['ZueriCrop.hdf5', 'labels.csv'] + 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv', + ) + md5s = ('1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b') + filenames = ('ZueriCrop.hdf5', 'labels.csv') band_names = ('NIR', 'B03', 'B02', 'B04', 'B05', 'B06', 'B07', 'B11', 'B12') - rgb_bands = ['B04', 'B03', 'B02'] + rgb_bands = ('B04', 'B03', 'B02') def __init__( self, - root: str = 'data', + root: Path = 'data', bands: Sequence[str] = band_names, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, diff --git a/torchgeo/main.py b/torchgeo/main.py index b403d4fa50c..48e84b6e8cf 100644 --- a/torchgeo/main.py +++ b/torchgeo/main.py @@ -8,7 +8,7 @@ from lightning.pytorch.cli import ArgsType, LightningCLI # Allows classes to be referenced using only the class name -import torchgeo.datamodules # noqa: F401 +import torchgeo.datamodules import torchgeo.trainers # noqa: F401 from torchgeo.datamodules import BaseDataModule from torchgeo.trainers import BaseTask diff --git a/torchgeo/models/__init__.py b/torchgeo/models/__init__.py index 327a343bde3..539be67180a 100644 --- a/torchgeo/models/__init__.py +++ b/torchgeo/models/__init__.py @@ -5,6 +5,7 @@ from .api import get_model, get_model_weights, get_weight, list_models from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg +from .croma import CROMA, CROMABase_Weights, CROMALarge_Weights, croma_base, croma_large from .dofa import ( DOFA, DOFABase16_Weights, @@ -18,39 +19,56 @@ from .fcn import FCN from .fcsiam import FCSiamConc, FCSiamDiff from .rcf import RCF -from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 -from .swin import Swin_V2_B_Weights, swin_v2_b +from .resnet import ( + ResNet18_Weights, + ResNet50_Weights, + ResNet152_Weights, + resnet18, + resnet50, + resnet152, +) +from .scale_mae import ScaleMAE, ScaleMAELarge16_Weights, scalemae_large_patch16 +from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t from .vit import ViTSmall16_Weights, vit_small_patch16_224 __all__ = ( - # models - 'ChangeMixin', - 'ChangeStar', - 'ChangeStarFarSeg', + 'CROMA', 'DOFA', - 'dofa_small_patch16_224', - 'dofa_base_patch16_224', - 'dofa_large_patch16_224', - 'dofa_huge_patch16_224', - 'FarSeg', 'FCN', - 'FCSiamConc', - 'FCSiamDiff', 'RCF', - 'resnet18', - 'resnet50', - 'swin_v2_b', - 'vit_small_patch16_224', - # weights + 'CROMABase_Weights', + 'CROMALarge_Weights', + 'ChangeMixin', + 'ChangeStar', + 'ChangeStarFarSeg', 'DOFABase16_Weights', 'DOFALarge16_Weights', - 'ResNet50_Weights', + 'FCSiamConc', + 'FCSiamDiff', + 'FarSeg', 'ResNet18_Weights', + 'ResNet50_Weights', + 'ResNet152_Weights', + 'ScaleMAE', + 'ScaleMAELarge16_Weights', 'Swin_V2_B_Weights', + 'Swin_V2_T_Weights', 'ViTSmall16_Weights', - # utilities + 'croma_base', + 'croma_large', + 'dofa_base_patch16_224', + 'dofa_huge_patch16_224', + 'dofa_large_patch16_224', + 'dofa_small_patch16_224', 'get_model', 'get_model_weights', 'get_weight', 'list_models', + 'resnet18', + 'resnet50', + 'resnet152', + 'scalemae_large_patch16', + 'swin_v2_b', + 'swin_v2_t', + 'vit_small_patch16_224', ) diff --git a/torchgeo/models/api.py b/torchgeo/models/api.py index 9e214d9d04e..b5b058726b2 100644 --- a/torchgeo/models/api.py +++ b/torchgeo/models/api.py @@ -8,7 +8,7 @@ * https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/ * https://pytorch.org/vision/stable/models.html * https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py -""" # noqa: E501 +""" from collections.abc import Callable from typing import Any @@ -22,8 +22,16 @@ dofa_base_patch16_224, dofa_large_patch16_224, ) -from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50 -from .swin import Swin_V2_B_Weights, swin_v2_b +from .resnet import ( + ResNet18_Weights, + ResNet50_Weights, + ResNet152_Weights, + resnet18, + resnet50, + resnet152, +) +from .scale_mae import ScaleMAELarge16_Weights, scalemae_large_patch16 +from .swin import Swin_V2_B_Weights, Swin_V2_T_Weights, swin_v2_b, swin_v2_t from .vit import ViTSmall16_Weights, vit_small_patch16_224 _model = { @@ -31,21 +39,30 @@ 'dofa_large_patch16_224': dofa_large_patch16_224, 'resnet18': resnet18, 'resnet50': resnet50, + 'resnet152': resnet152, + 'scalemae_large_patch16': scalemae_large_patch16, + 'swin_v2_t': swin_v2_t, 'swin_v2_b': swin_v2_b, 'vit_small_patch16_224': vit_small_patch16_224, } -_model_weights = { +_model_weights: dict[str | Callable[..., nn.Module], WeightsEnum] = { dofa_base_patch16_224: DOFABase16_Weights, dofa_large_patch16_224: DOFALarge16_Weights, resnet18: ResNet18_Weights, resnet50: ResNet50_Weights, + resnet152: ResNet152_Weights, + scalemae_large_patch16: ScaleMAELarge16_Weights, + swin_v2_t: Swin_V2_T_Weights, swin_v2_b: Swin_V2_B_Weights, vit_small_patch16_224: ViTSmall16_Weights, 'dofa_base_patch16_224': DOFABase16_Weights, 'dofa_large_patch16_224': DOFALarge16_Weights, 'resnet18': ResNet18_Weights, 'resnet50': ResNet50_Weights, + 'resnet152': ResNet152_Weights, + 'scalemae_large_patch16': ScaleMAELarge16_Weights, + 'swin_v2_t': Swin_V2_T_Weights, 'swin_v2_b': Swin_V2_B_Weights, 'vit_small_patch16_224': ViTSmall16_Weights, } @@ -92,8 +109,17 @@ def get_weight(name: str) -> WeightsEnum: Returns: The requested weight enum. + + Raises: + ValueError: If *name* is not a valid WeightsEnum. """ - return eval(name) + for weight_name, weight_enum in _model_weights.items(): + if isinstance(weight_name, str): + for sub_weight_enum in weight_enum: + if name == str(sub_weight_enum): + return sub_weight_enum + + raise ValueError(f'{name} is not a valid WeightsEnum') def list_models() -> list[str]: diff --git a/torchgeo/models/changestar.py b/torchgeo/models/changestar.py index 9d8da16e793..f2177a1fddb 100644 --- a/torchgeo/models/changestar.py +++ b/torchgeo/models/changestar.py @@ -29,7 +29,7 @@ def __init__( inner_channels: int = 16, num_convs: int = 4, scale_factor: float = 4.0, - ): + ) -> None: """Initializes a new ChangeMixin module. Args: diff --git a/torchgeo/models/croma.py b/torchgeo/models/croma.py new file mode 100644 index 00000000000..475c32fd3a9 --- /dev/null +++ b/torchgeo/models/croma.py @@ -0,0 +1,644 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +# Code based on https://github.com/antofuller/CROMA under MIT License + +"""CROMA model.""" + +import itertools +import math +from collections.abc import Sequence +from typing import Any + +import torch +from einops import rearrange +from torch import Tensor, einsum, nn +from torchvision.models._api import Weights, WeightsEnum + + +class CROMA(nn.Module): + """Pretrained CROMA model. + + Corresponds to the pretrained CROMA model found in the CROMA repository: + + * https://github.com/antofuller/CROMA/blob/main/pretrain_croma.py + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2311.00566 + """ + + valid_modalities = ('sar', 'optical') + + def __init__( + self, + modalities: Sequence[str] = ['sar', 'optical'], + encoder_dim: int = 768, + encoder_depth: int = 12, + num_heads: int = 16, + patch_size: int = 8, + image_size: int = 120, + ) -> None: + """Initialize the CROMA model. + + Args: + modalities: List of modalities used during forward pass, list can contain + 'sar', 'optical', or both. + encoder_dim: Dimension of the encoder. + encoder_depth: Depth of the encoder. + num_heads: Number of heads for the multi-head attention, should be power of 2. + patch_size: Size of the patches. + image_size: Size of the input images, CROMA was trained on 120x120 images, + must be a multiple of 8. + + Raises: + AssertionError: If any arguments are not valid. + """ + super().__init__() + for modality in modalities: + assert ( + modality in self.valid_modalities + ), f'{modality} is not a valid modality' + + assert image_size % 8 == 0, 'image_size must be a multiple of 8' + assert num_heads % 2 == 0, 'num_heads must be a power of 2' + + self.modalities = modalities + self.encoder_dim = encoder_dim + self.encoder_depth = encoder_depth + self.num_heads = num_heads + self.patch_size = patch_size + self.image_size = image_size + + self.num_patches = int((image_size / 8) ** 2) + self.s1_channels = 2 # fixed at 2 SAR backscatter channels + self.s2_channels = 12 # fixed at 12 multispectral optical channels + + self.attn_bias = get_2dalibi( + num_heads=self.num_heads, num_patches=self.num_patches + ) + + def initialize_encoder( + encoder_dim: int, encoder_depth: int, in_channels: int + ) -> tuple[nn.Module, nn.Module]: + """Initialize the encoder and GAP-FFN for a given modality. + + Args: + encoder_dim: Dimension of the encoder. + encoder_depth: Depth of the encoder. + in_channels: Number of input channels. + + Returns: + Tuple containing the encoder and GAP-FFN. + """ + encoder = ViT(dim=encoder_dim, depth=encoder_depth, in_channels=in_channels) + gap_ffn = nn.Sequential( + nn.LayerNorm(encoder_dim), + nn.Linear(encoder_dim, int(4 * encoder_dim)), + nn.GELU(), + nn.Linear(int(4 * encoder_dim), encoder_dim), + ) + return encoder, gap_ffn + + if 'sar' in modalities: + self.s1_encoder, self.s1_GAP_FFN = initialize_encoder( + encoder_dim, int(encoder_depth / 2), self.s1_channels + ) + if 'optical' in modalities: + self.s2_encoder, self.s2_GAP_FFN = initialize_encoder( + encoder_dim, encoder_depth, self.s2_channels + ) + if set(self.modalities) == {'sar', 'optical'}: + self.joint_encoder = BaseTransformerCrossAttn( + dim=encoder_dim, depth=int(encoder_depth / 2), num_heads=num_heads + ) + + def forward( + self, x_sar: Tensor | None = None, x_optical: Tensor | None = None + ) -> dict[str, Tensor]: + """Forward pass of the CROMA model. + + Args: + x_sar: Input mini-batch of SAR images [B, 2, H, W]. + x_optical: Input mini-batch of optical images [B, 12, H, W]. + """ + return_dict: dict[str, Tensor] = {} + + if 'sar' in self.modalities and x_sar is not None: + sar_encodings = self.s1_encoder(imgs=x_sar, attn_bias=self.attn_bias) + sar_GAP = self.s1_GAP_FFN(sar_encodings.mean(dim=1)) + return_dict['sar_encodings'] = sar_encodings + return_dict['sar_GAP'] = sar_GAP + + if 'optical' in self.modalities and x_optical is not None: + optical_encodings = self.s2_encoder( + imgs=x_optical, attn_bias=self.attn_bias + ) + optical_GAP = self.s2_GAP_FFN(optical_encodings.mean(dim=1)) + return_dict['optical_encodings'] = optical_encodings + return_dict['optical_GAP'] = optical_GAP + + if set(self.modalities) == {'sar', 'optical'}: + joint_encodings = self.joint_encoder( + x=sar_encodings, + context=optical_encodings, + relative_position_bias=self.attn_bias, + ) + joint_GAP = joint_encodings.mean(dim=1) + return_dict['joint_encodings'] = joint_encodings + return_dict['joint_GAP'] = joint_GAP + + return return_dict + + +def get_2dalibi(num_heads: int, num_patches: int) -> Tensor: + """Get 2D relative position bias for the attention layer. + + Args: + num_heads: Number of heads for the multi-head attention. + num_patches: Number of patches. + + Returns: + 2D relative position bias tensor. + """ + # inspired by: https://github.com/ofirpress/attention_with_linear_biases + points = list( + itertools.product( + range(int(math.sqrt(num_patches))), range(int(math.sqrt(num_patches))) + ) + ) + + def get_slopes(n: int) -> list[float]: + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + slopes = torch.Tensor(get_slopes(num_heads)).unsqueeze(1) + idxs = [] + for p1 in points: + for p2 in points: + dist = math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) + idxs.append(dist * slopes * -1) + all_bias = torch.cat(idxs, dim=1) + return all_bias.view(1, num_heads, num_patches, num_patches) + + +class FFN(nn.Module): + """Feed-forward network for the transformer.""" + + def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0) -> None: + """Initialize the feed-forward network. + + Args: + dim: Dimension of the input. + mult: Multiplier for the inner dimension of the feed-forward network. + dropout: Dropout probability + """ + super().__init__() + inner_dim = int(dim * mult) + + self.net = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + ) + self.input_norm = nn.LayerNorm(dim) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass of the feed-forward network. + + Args: + x: Input tensor. + + Returns: + Output tensor. + """ + x = self.input_norm(x) + x = self.net(x) + return x + + +class Attention(nn.Module): + """Multi-head attention layer for the transformer.""" + + def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.0) -> None: + """Initialize the multi-head attention layer. + + Args: + dim: Dimension of the input. + num_heads: Number of heads for the multi-head attention. + dropout: Dropout probability. + """ + super().__init__() + self.num_heads = num_heads + assert dim % num_heads == 0, 'dim must be evenly divisible by num_heads' + dim_head = int(dim / num_heads) + self.scale = dim_head**-0.5 + + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + self.to_out = nn.Linear(dim, dim) + self.input_norm = nn.LayerNorm(dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor, relative_position_bias: Tensor) -> Tensor: + """Forward pass of the multi-head attention layer. + + Args: + x: Input tensor. + relative_position_bias: Relative position bias tensor. + + Returns: + Output tensor. + """ + x = self.input_norm(x) + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map( + lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), (q, k, v) + ) + + attention_scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + attention_scores = attention_scores + relative_position_bias + + attn = attention_scores.softmax(dim=-1) + attn = self.dropout(attn) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + + +class CrossAttention(nn.Module): + """Cross-attention layer for the transformer.""" + + def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.0) -> None: + """Initialize the cross-attention layer. + + Args: + dim: Dimension of the input. + num_heads: Number of heads for the multi-head attention. + dropout: Dropout probability. + + Raises: + AssertionError: If the dimension is not evenly divisible by the number of heads. + """ + super().__init__() + self.num_heads = num_heads + assert dim % num_heads == 0, 'dim must be evenly divisible by num_heads' + dim_head = int(dim / num_heads) + self.scale = dim_head**-0.5 + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_k = nn.Linear(dim, dim, bias=False) + self.to_v = nn.Linear(dim, dim, bias=False) + + self.to_out = nn.Linear(dim, dim) + self.input_norm = nn.LayerNorm(dim) + self.dropout = nn.Dropout(dropout) + + def forward( + self, x: Tensor, context: Tensor, relative_position_bias: Tensor + ) -> Tensor: + """Forward pass of the cross-attention layer. + + Args: + x: Input tensor. + context: Context tensor. + relative_position_bias: Relative position bias tensor. + + Returns: + Output tensor. + """ + x = self.input_norm(x) + context = self.input_norm(context) + + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map( + lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), (q, k, v) + ) + + attention_scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + attention_scores = attention_scores + relative_position_bias + + attn = attention_scores.softmax(dim=-1) + attn = self.dropout(attn) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + out = self.to_out(out) + return out + + +class BaseTransformer(nn.Module): + """Base transformer model.""" + + def __init__( + self, + dim: int, + depth: int, + num_heads: int = 8, + attn_dropout: float = 0.0, + ff_dropout: float = 0.0, + ff_mult: int = 4, + final_norm: bool = True, + ) -> None: + """Initialize the base transformer model. + + Args: + dim: Dimension of the input. + depth: Depth of the transformer. + num_heads: Number of heads for the multi-head attention. + attn_dropout: Dropout probability for the attention layer. + ff_dropout: Dropout probability for the feed-forward network. + ff_mult: Multiplier for the inner dimension of the feed-forward network. + final_norm: Whether to apply a final layer normalization. + """ + super().__init__() + self.final_norm = final_norm + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout), + FFN(dim=dim, mult=ff_mult, dropout=ff_dropout), + ] + ) + ) + + if self.final_norm: + self.norm_out = nn.LayerNorm(dim) + + def forward(self, x: Tensor, relative_position_bias: Tensor) -> Tensor: + """Forward pass of the base transformer model. + + Args: + x: Input tensor. + relative_position_bias: whether to use relative position bias. + """ + for self_attn, ffn in self.layers: + x = self_attn(x, relative_position_bias) + x + x = ffn(x) + x + + x = self.norm_out(x) if self.final_norm else x + return x + + +class BaseTransformerCrossAttn(nn.Module): + """Base transformer model with cross-attention.""" + + def __init__( + self, + dim: int, + depth: int, + num_heads: int = 8, + attn_dropout: float = 0.0, + ff_dropout: float = 0.0, + ff_mult: int = 4, + ) -> None: + """Initialize the base transformer model with cross-attention. + + Args: + dim: Dimension of the input. + depth: Depth of the transformer. + num_heads: Number of heads for the multi-head attention. + attn_dropout: Dropout probability for the attention layer. + ff_dropout: Dropout probability for the feed-forward network. + ff_mult: Multiplier for the inner dimension of the feed-forward network. + """ + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout), + CrossAttention( + dim=dim, num_heads=num_heads, dropout=attn_dropout + ), + FFN(dim=dim, mult=ff_mult, dropout=ff_dropout), + ] + ) + ) + + self.norm_out = nn.LayerNorm(dim) + + def forward( + self, x: Tensor, context: Tensor, relative_position_bias: Tensor + ) -> Tensor: + """Forward pass of the base transformer model with cross-attention. + + Args: + x: Input tensor. + context: Context tensor. + relative_position_bias: Relative position bias tensor. + + Returns: + Output tensor. + """ + for self_attn, cross_attn, ffn in self.layers: + x = self_attn(x, relative_position_bias) + x + x = cross_attn(x, context, relative_position_bias) + x + x = ffn(x) + x + + x = self.norm_out(x) + return x + + +class ViT(nn.Module): + """Vision Transformer model.""" + + def __init__(self, dim: int, depth: int, in_channels: int) -> None: + """Initialize the vision transformer model. + + Args: + dim: Dimension of the input. + depth: Depth of the transformer. + in_channels: Number of input channels. + """ + super().__init__() + self.depth = depth + self.in_channels = in_channels + self.dim = dim + self.num_heads = 16 # always 16, for base and large models + self.patch_size = 8 # always 8, for base and large models + + pixels_per_patch = int(self.patch_size * self.patch_size * in_channels) + self.linear_input = nn.Linear(pixels_per_patch, self.dim) + self.transformer = BaseTransformer( + dim=self.dim, depth=self.depth, num_heads=self.num_heads + ) + + def forward(self, imgs: Tensor, attn_bias: Tensor) -> Tensor: + """Forward pass of the vision transformer model. + + Args: + imgs: Input tensor. + attn_bias: Relative position bias tensor. + + Returns: + Output tensor. + """ + x = rearrange( + imgs, + 'b c (h i) (w j) -> b (h w) (c i j)', + i=self.patch_size, + j=self.patch_size, + ) + # x is shape -> (bsz, num_patches, self.channels*self.patch_size*self.patch_size) + + x = self.linear_input(x) + x = self.transformer(x, relative_position_bias=attn_bias) + return x + + +class CROMABase_Weights(WeightsEnum): # type: ignore[misc] + """CROMA base model weights. + + .. versionadded:: 0.7 + """ + + CROMA_VIT = Weights( + url='https://hf.co/torchgeo/croma/resolve/387883f08af79d777167519c57cd826eda89a16f/CROMA_base-0238d814.pt', + transforms=None, + meta={ + 'dataset': 'SSL4EO', + 'model': 'vit', + 'publication': 'https://arxiv.org/abs/2311.00566', + 'repo': 'https://github.com/antofuller/CROMA', + 'ssl_method': 'croma', + }, + ) + + +class CROMALarge_Weights(WeightsEnum): # type: ignore[misc] + """CROMA large model weights. + + .. versionadded:: 0.7 + """ + + CROMA_VIT = Weights( + url='https://huggingface.co/torchgeo/croma/resolve/92cb1a0f4e34c6c01558baf070197c01255382f6/CROMA_large-921e69ad.pt', + transforms=None, + meta={ + 'dataset': 'SSL4EO', + 'model': 'vit', + 'publication': 'https://arxiv.org/abs/2311.00566', + 'repo': 'https://github.com/antofuller/CROMA', + 'ssl_method': 'croma', + }, + ) + + +def load_weights(model: CROMA, weights: WeightsEnum) -> None: + """Load weights from a WeightsEnum object. + + Args: + model: Model to load the weights into. + weights: Weights to load. + + Raises: + AssertionError: If there are missing or unexpected keys. + """ + state_dict = weights.get_state_dict(progress=True) + missing_keys, unexpected_keys = [], [] + + if 'sar' in model.modalities: + miss_key, unexp_key = model.s1_encoder.load_state_dict( + state_dict['s1_encoder'], strict=False + ) + missing_keys.extend(miss_key) + unexpected_keys.extend(unexp_key) + miss_key, unexp_key = model.s1_GAP_FFN.load_state_dict( + state_dict['s1_GAP_FFN'], strict=False + ) + missing_keys.extend(miss_key) + unexpected_keys.extend(unexp_key) + + if 'optical' in model.modalities: + miss_key, unexp_key = model.s2_encoder.load_state_dict( + state_dict['s2_encoder'], strict=False + ) + missing_keys.extend(miss_key) + unexpected_keys.extend(unexp_key) + miss_key, unexp_key = model.s2_GAP_FFN.load_state_dict( + state_dict['s2_GAP_FFN'], strict=False + ) + missing_keys.extend(miss_key) + unexpected_keys.extend(unexp_key) + + if set(model.modalities) == {'sar', 'optical'}: + miss_key, unexp_key = model.joint_encoder.load_state_dict( + state_dict['joint_encoder'], strict=False + ) + missing_keys.extend(miss_key) + unexpected_keys.extend(unexp_key) + + assert not missing_keys, f'Missing keys: {missing_keys}' + assert not unexpected_keys, f'Unexpected keys: {unexpected_keys}' + + +def croma_base( + weights: CROMABase_Weights | None = None, *args: Any, **kwargs: Any +) -> CROMA: + """CROMA base model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2311.00566 + + .. versionadded:: 0.7 + + Args: + weights: Pretrained weights to load. + *args: Additional arguments to pass to :class:CROMA.` + **kwargs: Additional keyword arguments to pass to :class:CROMA.` + + Returns: + CROMA base model. + """ + kwargs |= { + 'encoder_dim': 768, + 'encoder_depth': 12, + 'num_heads': 16, + 'patch_size': 8, + } + model = CROMA(*args, **kwargs) + if weights: + load_weights(model, weights) + return model + + +def croma_large( + weights: CROMALarge_Weights | None = None, *args: Any, **kwargs: Any +) -> CROMA: + """CROMA large model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2311.00566 + + .. versionadded:: 0.7 + + Args: + weights: Pretrained weights to load. + *args: Additional arguments to pass to :class:CROMA.` + **kwargs: Additional keyword arguments to pass to :class:CROMA.` + + Returns: + CROMA large model. + """ + kwargs |= { + 'encoder_dim': 1024, + 'encoder_depth': 24, + 'num_heads': 16, + 'patch_size': 8, + } + model = CROMA(*args, **kwargs) + if weights: + load_weights(model, weights) + return model diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py index 32f0be01a61..7184429aff7 100644 --- a/torchgeo/models/dofa.py +++ b/torchgeo/models/dofa.py @@ -65,6 +65,13 @@ def __init__( num_layers: Number of layers. """ super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_layers = num_layers + encoder_layer = nn.TransformerEncoderLayer( d_model=input_dim, nhead=num_heads, @@ -203,8 +210,10 @@ def forward(self, x: Tensor, wavelengths: Tensor) -> tuple[Tensor, Tensor]: weight, bias = self.weight_generator(waves) # 3x3x3 dynamic_weight = weight.view( - self.embed_dim, inplanes, self.kernel_size, self.kernel_size - ) # 3xoutdx16x16 + inplanes, self.kernel_size, self.kernel_size, self.embed_dim + ) + dynamic_weight = dynamic_weight.permute([3, 0, 1, 2]) + if bias is not None: bias = bias.view([self.embed_dim]) * self.scaler @@ -265,8 +274,17 @@ def __init__( """ super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.drop_rate = drop_rate + self.embed_dim = embed_dim + self.depth = depth + self.num_heads = num_heads self.dynamic_embed_dim = dynamic_embed_dim + self.num_classes = num_classes self.global_pool = global_pool + self.mlp_ratio = mlp_ratio + if self.global_pool: norm_layer = norm_layer embed_dim = embed_dim @@ -384,7 +402,7 @@ class DOFABase16_Weights(WeightsEnum): # type: ignore[misc] """ DOFA_MAE = Weights( - url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_base_patch16_224-7cc0f413.pth', # noqa: E501 + url='https://hf.co/torchgeo/dofa/resolve/b8db318b64a90b9e085ec04ba8851233c5893666/dofa_base_patch16_224-a0275954.pth', transforms=_dofa_transforms, meta={ 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', @@ -403,7 +421,7 @@ class DOFALarge16_Weights(WeightsEnum): # type: ignore[misc] """ DOFA_MAE = Weights( - url='https://hf.co/torchgeo/dofa/resolve/ade8745c5ec6eddfe15d8c03421e8cb8f21e66ff/dofa_large_patch16_224-fbd47fa9.pth', # noqa: E501 + url='https://hf.co/torchgeo/dofa/resolve/b8db318b64a90b9e085ec04ba8851233c5893666/dofa_large_patch16_224-0ff904d3.pth', transforms=_dofa_transforms, meta={ 'dataset': 'SatlasPretrain, Five-Billion-Pixels, HySpecNet-11k', @@ -426,7 +444,7 @@ def dofa_small_patch16_224(*args: Any, **kwargs: Any) -> DOFA: Args: *args: Additional arguments to pass to :class:`DOFA`. - **kwargs: Additional keywork arguments to pass to :class:`DOFA`. + **kwargs: Additional keyword arguments to pass to :class:`DOFA`. Returns: A DOFA small 16 model. @@ -450,7 +468,7 @@ def dofa_base_patch16_224( Args: weights: Pre-trained model weights to use. *args: Additional arguments to pass to :class:`DOFA`. - **kwargs: Additional keywork arguments to pass to :class:`DOFA`. + **kwargs: Additional keyword arguments to pass to :class:`DOFA`. Returns: A DOFA base 16 model. @@ -488,7 +506,7 @@ def dofa_large_patch16_224( Args: weights: Pre-trained model weights to use. *args: Additional arguments to pass to :class:`DOFA`. - **kwargs: Additional keywork arguments to pass to :class:`DOFA`. + **kwargs: Additional keyword arguments to pass to :class:`DOFA`. Returns: A DOFA large 16 model. @@ -523,7 +541,7 @@ def dofa_huge_patch16_224(*args: Any, **kwargs: Any) -> DOFA: Args: *args: Additional arguments to pass to :class:`DOFA`. - **kwargs: Additional keywork arguments to pass to :class:`DOFA`. + **kwargs: Additional keyword arguments to pass to :class:`DOFA`. Returns: A DOFA huge 16 model. diff --git a/torchgeo/models/farseg.py b/torchgeo/models/farseg.py index f57f59cde5f..dc25858abf6 100644 --- a/torchgeo/models/farseg.py +++ b/torchgeo/models/farseg.py @@ -36,7 +36,7 @@ class FarSeg(Module): If you use this model in your research, please cite the following paper: - * https://arxiv.org/pdf/2011.09766.pdf + * https://arxiv.org/pdf/2011.09766 """ def __init__( diff --git a/torchgeo/models/fcsiam.py b/torchgeo/models/fcsiam.py index e1fa9a3a69e..8114cccab2b 100644 --- a/torchgeo/models/fcsiam.py +++ b/torchgeo/models/fcsiam.py @@ -32,7 +32,7 @@ def __init__( in_channels: int = 3, classes: int = 1, activation: str | Callable[[Tensor], Tensor] | None = None, - ): + ) -> None: """Initialize a new FCSiamConc model. Args: diff --git a/torchgeo/models/rcf.py b/torchgeo/models/rcf.py index 82545afa339..cc330a0192c 100644 --- a/torchgeo/models/rcf.py +++ b/torchgeo/models/rcf.py @@ -78,7 +78,7 @@ def __init__( # We register the weight and bias tensors as "buffers". This does two things: # makes them behave correctly when we call .to(...) on the module, and makes - # them explicitely _not_ Parameters of the model (which might get updated) if + # them explicitly _not_ Parameters of the model (which might get updated) if # a user tries to train with this model. self.register_buffer( 'weights', @@ -140,7 +140,7 @@ def _normalize( a numpy array of size (N, C, H, W) containing the normalized patches .. versionadded:: 0.5 - """ # noqa: E501 + """ n_patches = patches.shape[0] orig_shape = patches.shape patches = patches.reshape(patches.shape[0], -1) diff --git a/torchgeo/models/resnet.py b/torchgeo/models/resnet.py index 05d95ceb8f1..7bbeab22dc8 100644 --- a/torchgeo/models/resnet.py +++ b/torchgeo/models/resnet.py @@ -11,18 +11,137 @@ from timm.models import ResNet from torchvision.models._api import Weights, WeightsEnum -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 -# Normalization either by 10K or channel-wise with band statistics -_zhu_xlab_transforms = K.AugmentationSequential( +from .swin import ( + _satlas_bands, + _satlas_sentinel2_bands, + _satlas_sentinel2_transforms, + _satlas_transforms, +) + +# https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LT05_C02_T1_TOA +_landsat_tm_toa_bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7'] + +# https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LE07_C02_T1_TOA +_landsat_etm_toa_bands = [ + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6_VCID_1', + 'B6_VCID_2', + 'B7', + 'B8', +] + +# https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LE07_C02_T1_L2 +_landsat_etm_sr_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7'] + +# https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LC08_C02_T1_TOA +_landsat_oli_tirs_toa_bands = [ + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8', + 'B9', + 'B10', + 'B11', +] + +# https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LC08_C02_T1_L2 +_landsat_oli_sr_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'] + +# https://github.com/zhu-xlab/SSL4EO-S12/blob/main/src/download_data/convert_rgb.py +_sentinel2_toa_bands = [ + 'B1', + 'B2', + 'B3', + 'B4', + 'B5', + 'B6', + 'B7', + 'B8', + 'B8a', + 'B9', + 'B10', + 'B11', + 'B12', +] + +# https://github.com/zhu-xlab/SSL4EO-S12/blob/main/src/download_data/convert_rgb.py +_sentinel2_rgb_bands = ['B4', 'B3', 'B2'] + +# https://github.com/zhu-xlab/SSL4EO-S12/blob/main/src/download_data/convert_rgb.py +_sentinel1_bands = ['VV', 'VH'] + +# https://github.com/zhu-xlab/DeCUR/blob/f190e9a3895ef645c005c8c2fce287ffa5a937e3/src/transfer_classification_BE/linear_BE_resnet.py#L286 +# Normalization by channel-wise band statistics +_mean_s1 = torch.tensor([-12.59, -20.26]) +_std_s1 = torch.tensor([5.26, 5.91]) +_ssl4eo_s12_transforms_s1 = K.AugmentationSequential( + K.Resize(256), + K.CenterCrop(224), + K.Normalize(mean=_mean_s1, std=_std_s1), + data_keys=None, +) + +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 +# Normalization either by 10K (for S2 uint16 input) or channel-wise with band statistics +_ssl4eo_s12_transforms_s2_10k = K.AugmentationSequential( K.Resize(256), K.CenterCrop(224), K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)), data_keys=None, ) +_mean_s2 = torch.tensor( + [ + 1612.9, + 1397.6, + 1322.3, + 1373.1, + 1561.0, + 2108.4, + 2390.7, + 2318.7, + 2581.0, + 837.7, + 22.0, + 2195.2, + 1537.4, + ] +) +_std_s2 = torch.tensor( + [ + 791.0, + 854.3, + 878.7, + 1144.9, + 1127.5, + 1164.2, + 1276.0, + 1249.5, + 1345.9, + 577.5, + 47.5, + 1340.0, + 1142.9, + ] +) +_ssl4eo_s12_transforms_s2_stats = K.AugmentationSequential( + K.Resize(256), + K.CenterCrop(224), + K.Normalize(mean=_mean_s2, std=_std_s2), + data_keys=None, +) + # Normalization only available for RGB dataset, defined here: -# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py # noqa: E501 +# https://github.com/ServiceNow/seasonal-contrast/blob/8285173ec205b64bc3e53b880344dd6c3f79fa7a/datasets/seco_dataset.py _min = torch.tensor([3, 2, 0]) _max = torch.tensor([88, 103, 129]) _mean = torch.tensor([0.485, 0.456, 0.406]) @@ -37,7 +156,7 @@ ) # Normalization only available for RGB dataset, defined here: -# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 # noqa: E501 +# https://github.com/sustainlab-group/geography-aware-ssl/blob/main/moco_fmow/main_moco_geo%2Btp.py#L287 _mean = torch.tensor([0.485, 0.456, 0.406]) _std = torch.tensor([0.229, 0.224, 0.225]) _gassl_transforms = K.AugmentationSequential( @@ -47,7 +166,7 @@ data_keys=None, ) -# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501 +# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 _ssl4eo_l_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.CenterCrop((224, 224)), @@ -61,16 +180,16 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] - """ResNet18 weights. + """ResNet-18 weights. - For `timm `_ + For `timm `_ *resnet18* implementation. .. versionadded:: 0.4 """ LANDSAT_TM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_moco-1c691b4f.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -79,11 +198,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_tm_toa_bands, }, ) LANDSAT_TM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -92,11 +212,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_tm_toa_bands, }, ) LANDSAT_ETM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_moco-bb88689c.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -105,11 +226,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_etm_toa_bands, }, ) LANDSAT_ETM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_toa_simclr-4d813f79.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -118,11 +240,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_etm_toa_bands, }, ) LANDSAT_ETM_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_moco-4f078acd.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -131,11 +254,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_etm_sr_bands, }, ) LANDSAT_ETM_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_etm_sr_simclr-8e8543b4.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -144,11 +268,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_etm_sr_bands, }, ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_moco-a3002f51.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -157,11 +282,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_oli_tirs_toa_bands, }, ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_tirs_toa_simclr-b0635cc6.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -170,11 +296,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_oli_tirs_toa_bands, }, ) LANDSAT_OLI_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_moco-660e82ed.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -183,11 +310,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_oli_sr_bands, }, ) LANDSAT_OLI_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_oli_sr_simclr-7bced5be.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -196,12 +324,13 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_oli_sr_bands, }, ) SENTINEL2_ALL_MOCO = Weights( - url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', # noqa: E501 - transforms=_zhu_xlab_transforms, + url='https://hf.co/torchgeo/resnet18_sentinel2_all_moco/resolve/5b8cddc9a14f3844350b7f40b85bcd32aed75918/resnet18_sentinel2_all_moco-59bfdff9.pth', + transforms=_ssl4eo_s12_transforms_s2_10k, meta={ 'dataset': 'SSL4EO-S12', 'in_chans': 13, @@ -209,12 +338,13 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'moco', + 'bands': _sentinel2_toa_bands, }, ) SENTINEL2_RGB_MOCO = Weights( - url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', # noqa: E501 - transforms=_zhu_xlab_transforms, + url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_moco/resolve/e1c032e7785fd0625224cdb6699aa138bb304eec/resnet18_sentinel2_rgb_moco-e3a335e3.pth', + transforms=_ssl4eo_s12_transforms_s2_10k, meta={ 'dataset': 'SSL4EO-S12', 'in_chans': 3, @@ -222,11 +352,12 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'moco', + 'bands': _sentinel2_rgb_bands, }, ) SENTINEL2_RGB_SECO = Weights( - url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet18_sentinel2_rgb_seco/resolve/f8dcee692cf7142163b55a5c197d981fe0e717a0/resnet18_sentinel2_rgb_seco-cefca942.pth', transforms=_seco_transforms, meta={ 'dataset': 'SeCo Dataset', @@ -235,21 +366,22 @@ class ResNet18_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2103.16607', 'repo': 'https://github.com/ServiceNow/seasonal-contrast', 'ssl_method': 'seco', + 'bands': _sentinel2_rgb_bands, }, ) class ResNet50_Weights(WeightsEnum): # type: ignore[misc] - """ResNet50 weights. + """ResNet-50 weights. - For `timm `_ + For `timm `_ *resnet50* implementation. .. versionadded:: 0.4 """ FMOW_RGB_GASSL = Weights( - url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet50_fmow_rgb_gassl/resolve/fe8a91026cf9104f1e884316b8e8772d7af9052c/resnet50_fmow_rgb_gassl-da43d987.pth', transforms=_gassl_transforms, meta={ 'dataset': 'fMoW Dataset', @@ -262,7 +394,7 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] ) LANDSAT_TM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_moco-ba1ce753.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -271,11 +403,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_tm_toa_bands, }, ) LANDSAT_TM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_tm_toa_simclr-a1c93432.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -284,11 +417,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_tm_toa_bands, }, ) LANDSAT_ETM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_moco-e9a84d5a.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -297,11 +431,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_etm_toa_bands, }, ) LANDSAT_ETM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_toa_simclr-70b5575f.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -310,11 +445,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_etm_toa_bands, }, ) LANDSAT_ETM_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_moco-1266cde3.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -323,11 +459,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_etm_sr_bands, }, ) LANDSAT_ETM_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_etm_sr_simclr-e5d185d7.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -336,11 +473,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_etm_sr_bands, }, ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_moco-de7f5e0f.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -349,11 +487,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_oli_tirs_toa_bands, }, ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_tirs_toa_simclr-030cebfe.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -362,11 +501,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_oli_tirs_toa_bands, }, ) LANDSAT_OLI_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_moco-ff580dad.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -375,11 +515,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_oli_sr_bands, }, ) LANDSAT_OLI_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet50_landsat_oli_sr_simclr-94f78913.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -388,12 +529,27 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_oli_sr_bands, + }, + ) + + SENTINEL1_ALL_DECUR = Weights( + url='https://huggingface.co/torchgeo/decur/resolve/9328eeb90c686a88b30f8526ed757b4bc0f12027/rn50_ssl4eo-s12_sar_decur_ep100-f0e69ba2.pth', + transforms=_ssl4eo_s12_transforms_s1, + meta={ + 'dataset': 'SSL4EO-S12', + 'in_chans': 2, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2309.05300', + 'repo': 'https://github.com/zhu-xlab/DeCUR', + 'ssl_method': 'decur', + 'bands': _sentinel1_bands, }, ) SENTINEL1_ALL_MOCO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth', # noqa: E501 - transforms=_zhu_xlab_transforms, + url='https://hf.co/torchgeo/resnet50_sentinel1_all_moco/resolve/e79862c667853c10a709bdd77ea8ffbad0e0f1cf/resnet50_sentinel1_all_moco-906e4356.pth', + transforms=_ssl4eo_s12_transforms_s1, meta={ 'dataset': 'SSL4EO-S12', 'in_chans': 2, @@ -401,12 +557,27 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'moco', + 'bands': _sentinel1_bands, + }, + ) + + SENTINEL2_ALL_DECUR = Weights( + url='https://huggingface.co/torchgeo/decur/resolve/eba7ae5945d482a4319be046d34b552db5dd9950/rn50_ssl4eo-s12_ms_decur_ep100-fc6b09ff.pth', + transforms=_ssl4eo_s12_transforms_s2_10k, + meta={ + 'dataset': 'SSL4EO-S12', + 'in_chans': 13, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2309.05300', + 'repo': 'https://github.com/zhu-xlab/DeCUR', + 'ssl_method': 'decur', + 'bands': _sentinel2_toa_bands, }, ) SENTINEL2_ALL_DINO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth', # noqa: E501 - transforms=_zhu_xlab_transforms, + url='https://hf.co/torchgeo/resnet50_sentinel2_all_dino/resolve/d7f14bf5530d70ac69d763e58e77e44dbecfec7c/resnet50_sentinel2_all_dino-d6c330e9.pth', + transforms=_ssl4eo_s12_transforms_s2_10k, meta={ 'dataset': 'SSL4EO-S12', 'in_chans': 13, @@ -414,12 +585,13 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'dino', + 'bands': _sentinel2_toa_bands, }, ) SENTINEL2_ALL_MOCO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth', # noqa: E501 - transforms=_zhu_xlab_transforms, + url='https://hf.co/torchgeo/resnet50_sentinel2_all_moco/resolve/da4f3c9dbe09272eb902f3b37f46635fa4726879/resnet50_sentinel2_all_moco-df8b932e.pth', + transforms=_ssl4eo_s12_transforms_s2_10k, meta={ 'dataset': 'SSL4EO-S12', 'in_chans': 13, @@ -427,12 +599,39 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'moco', + 'bands': _sentinel2_toa_bands, + }, + ) + + SENTINEL2_MI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_ms-da5413d2.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_MI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_rgb-e79bb7fe.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_bands, }, ) SENTINEL2_RGB_MOCO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', # noqa: E501 - transforms=_zhu_xlab_transforms, + url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_moco/resolve/efd9723b59a88e9dc1420dc1e96afb25b0630a3c/resnet50_sentinel2_rgb_moco-2b57ba8b.pth', + transforms=_ssl4eo_s12_transforms_s2_10k, meta={ 'dataset': 'SSL4EO-S12', 'in_chans': 3, @@ -440,11 +639,12 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'moco', + 'bands': _sentinel2_rgb_bands, }, ) SENTINEL2_RGB_SECO = Weights( - url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth', # noqa: E501 + url='https://hf.co/torchgeo/resnet50_sentinel2_rgb_seco/resolve/fbd07b02a8edb8fc1035f7957160deed4321c145/resnet50_sentinel2_rgb_seco-018bf397.pth', transforms=_seco_transforms, meta={ 'dataset': 'SeCo Dataset', @@ -453,6 +653,95 @@ class ResNet50_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2103.16607', 'repo': 'https://github.com/ServiceNow/seasonal-contrast', 'ssl_method': 'seco', + 'bands': _sentinel2_rgb_bands, + }, + ) + + SENTINEL2_SI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_si_ms-1f454cc6.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_SI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_si_rgb-45fc6972.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_bands, + }, + ) + + +class ResNet152_Weights(WeightsEnum): # type: ignore[misc] + """ResNet-152 weights. + + For `timm `_ + *resnet152* implementation. + + .. versionadded:: 0.6 + """ + + SENTINEL2_MI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_mi_ms-fd35b4bb.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_MI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_mi_rgb-67563ac5.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_bands, + }, + ) + + SENTINEL2_SI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_si_ms-4500c6cb.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_SI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet152_si_rgb-f4d24c3c.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'resnet50', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlaspretrain_models', + 'bands': _satlas_bands, }, ) @@ -464,14 +753,14 @@ def resnet18( If you use this model in your research, please cite the following paper: - * https://arxiv.org/pdf/1512.03385.pdf + * https://arxiv.org/pdf/1512.03385 .. versionadded:: 0.4 Args: weights: Pre-trained model weights to use. *args: Additional arguments to pass to :func:`timm.create_model` - **kwargs: Additional keywork arguments to pass to :func:`timm.create_model` + **kwargs: Additional keyword arguments to pass to :func:`timm.create_model` Returns: A ResNet-18 model. @@ -498,7 +787,7 @@ def resnet50( If you use this model in your research, please cite the following paper: - * https://arxiv.org/pdf/1512.03385.pdf + * https://arxiv.org/pdf/1512.03385 .. versionchanged:: 0.4 Switched to multi-weight support API. @@ -506,7 +795,7 @@ def resnet50( Args: weights: Pre-trained model weights to use. *args: Additional arguments to pass to :func:`timm.create_model`. - **kwargs: Additional keywork arguments to pass to :func:`timm.create_model`. + **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`. Returns: A ResNet-50 model. @@ -524,3 +813,37 @@ def resnet50( assert not unexpected_keys return model + + +def resnet152( + weights: ResNet152_Weights | None = None, *args: Any, **kwargs: Any +) -> ResNet: + """ResNet-152 model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/pdf/1512.03385 + + .. versionadded:: 0.6 + + Args: + weights: Pre-trained model weights to use. + *args: Additional arguments to pass to :func:`timm.create_model`. + **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`. + + Returns: + A ResNet-152 model. + """ + if weights: + kwargs['in_chans'] = weights.meta['in_chans'] + + model: ResNet = timm.create_model('resnet152', *args, **kwargs) + + if weights: + missing_keys, unexpected_keys = model.load_state_dict( + weights.get_state_dict(progress=True), strict=False + ) + assert set(missing_keys) <= {'fc.weight', 'fc.bias'} + assert not unexpected_keys + + return model diff --git a/torchgeo/models/scale_mae.py b/torchgeo/models/scale_mae.py new file mode 100644 index 00000000000..7dd689e0f1e --- /dev/null +++ b/torchgeo/models/scale_mae.py @@ -0,0 +1,251 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Pre-trained Scale-MAE models.""" + +from collections import OrderedDict +from functools import partial +from typing import Any + +import kornia.augmentation as K +import torch +import torch.nn as nn +from timm.models.vision_transformer import VisionTransformer +from torch import Tensor +from torchvision.models._api import Weights, WeightsEnum + +_mean = torch.tensor([0.485, 0.456, 0.406]) +_std = torch.tensor([0.229, 0.224, 0.225]) +_scale_mae_transforms = K.AugmentationSequential( + K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), + K.Normalize(mean=_mean, std=_std), + data_keys=None, +) + + +def get_2d_sincos_pos_embed_with_resolution( + embed_dim: int, grid_size: int, res: Tensor, cls_token: bool = False +) -> Tensor: + """Generate spatial resolution specific 2D positional embeddings. + + Args: + embed_dim: Dimension of the positional embeddings. + grid_size: Height (ph) and width (pw) of the image patches. + res: Spatial resolution tensor of shape (N,) of the image. + cls_token: Increase positional embedding size by 1 for class token. + + Returns: + pos_embed: Spatial resolution aware positional embeddings (Ph * Pw, D). + """ + device, dtype = res.device, res.dtype + grid_h = torch.arange(grid_size, dtype=dtype, device=device) + grid_w = torch.arange(grid_size, dtype=dtype, device=device) + grid: Tensor = torch.stack(torch.meshgrid(grid_w, grid_h, indexing='xy'), dim=0) + grid = torch.einsum('chw,n->cnhw', grid, res) + _, n, h, w = grid.shape + pos_embed = get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid) + pos_embed = pos_embed.reshape(n, h * w, embed_dim) + if cls_token: + pos_embed = torch.cat( + [torch.zeros([n, 1, embed_dim], dtype=dtype, device=device), pos_embed], + dim=1, + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid_torch(embed_dim: int, grid: Tensor) -> Tensor: + """Generate 2D sin-cos positional embedding from grid. + + Args: + embed_dim: Dimension of the positional embeddings. + grid: Tensor representing the image patch grid (C, N, Ph, Pw) + + Returns: + emb: 2D sin-cos positional embeddings (Ph * Pw, D). + """ + assert embed_dim % 2 == 0 + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[0]) + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid[1]) + emb = torch.cat([emb_h, emb_w], dim=1) + return emb + + +def get_1d_sincos_pos_embed_from_grid_torch(embed_dim: int, pos: Tensor) -> Tensor: + """Generate 1D sin-cos positional embedding from grid dimension. + + Args: + embed_dim: Dimension of the positional embeddings. + pos: Tensor of positions to be encoded (M,). + + Returns: + emb: 1D sin-cos positional embeddings (M, D). + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=pos.dtype, device=pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + pos = pos.reshape(-1) + out = torch.einsum('m,d->md', pos, omega) + emb_sin = torch.sin(out) + emb_cos = torch.cos(out) + emb = torch.cat([emb_sin, emb_cos], dim=1) + return emb + + +class ScaleMAE(VisionTransformer): # type: ignore[misc] + """Custom Vision Transformer for Scale-MAE with GSD positional embeddings. + + This is a ViT encoder only model of the Scale-MAE architecture with GSD positional embeddings. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2212.14532 + """ + + def __init__(self, res: float = 1.0, *args: Any, **kwargs: Any) -> None: + """Initialize a new ScaleMAE model. + + Args: + res: Spatial resolution of the image in meters. + *args: Additional arguments to + pass to :class:`timm.models.vision_transformer.VisionTransformer`. + **kwargs: Additional keyword arguments to + pass to :class:`timm.models.vision_transformer.VisionTransformer`. + """ + super().__init__(*args, **kwargs) + + self.res = res + + # Scale MAE uses resolution specific positional embeddings + self.pos_embed.requires_grad = False + + def _pos_embed(self, x: Tensor) -> Tensor: + """Apply GSD positional embeddings to the input tensor.""" + res = torch.tensor(self.res, dtype=x.dtype, device=x.device) + res = res.repeat(x.shape[0]) + pos_embed = ( + get_2d_sincos_pos_embed_with_resolution( + self.embed_dim, + int(self.patch_embed.num_patches**0.5), + res, + cls_token=True, + ) + .to(x.dtype) + .to(x.device) + ) + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + x = x + pos_embed + x = self.pos_drop(x) + return x + + +def interpolate_pos_embed( + model: ScaleMAE, state_dict: OrderedDict[str, Tensor] +) -> OrderedDict[str, Tensor]: + """Interpolate the positional embeddings if image size is different than pretrained image size. + + Args: + model: ScaleMAE model. + state_dict: Pretrained model state dict. + + Returns: + state_dict: State dict with interpolated positional embeddings. + """ + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + f'Interpolating positional embeddings from {orig_size}x{orig_size} to {new_size}x{new_size}' + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + + return state_dict + + +# https://github.com/pytorch/vision/pull/6883 +# https://github.com/pytorch/vision/pull/7107 +# Can be removed once torchvision>=0.15 is required +Weights.__deepcopy__ = lambda *args, **kwargs: args[0] + + +class ScaleMAELarge16_Weights(WeightsEnum): # type: ignore[misc] + """Scale-MAE Large patch size 16 weights. + + .. versionadded:: 0.6 + """ + + FMOW_RGB = Weights( + url='https://hf.co/torchgeo/vit_large_patch16_224_fmow_rgb_scalemae/resolve/9dc7f569424baeb780698352cf6e87638c882123/vit_large_patch16_224_fmow_rgb_scalemae-98ed9821.pth', + transforms=_scale_mae_transforms, + meta={ + 'dataset': 'fMoW', + 'in_chans': 3, + 'img_size': 224, + 'model': 'vit_large_patch16', + 'publication': 'https://arxiv.org/abs/2212.14532', + 'repo': 'https://github.com/bair-climate-initiative/scale-mae', + }, + ) + + +def scalemae_large_patch16( + weights: ScaleMAELarge16_Weights | None = None, *args: Any, **kwargs: Any +) -> ScaleMAE: + """Scale-MAE Large model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2212.14532 + + .. versionadded:: 0.6 + + Args: + weights: Pre-trained model weights to use. + *args: Additional arguments to + pass to :class:`ScaleMAE`. + **kwargs: Additional keyword arguments to + pass to :class:`ScaleMAE`. + + Returns: + A Scale-MAE Large patch16 model. + """ + model = ScaleMAE( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + *args, + **kwargs, + ) + + if weights: + state_dict = weights.get_state_dict(progress=True) + + if 'img_size' in kwargs and weights.meta['img_size'] != kwargs['img_size']: + state_dict = interpolate_pos_embed(model, state_dict) + + model.load_state_dict(state_dict, strict=False) + + return model diff --git a/torchgeo/models/swin.py b/torchgeo/models/swin.py index 21df8dedd91..dacf3a6f0c7 100644 --- a/torchgeo/models/swin.py +++ b/torchgeo/models/swin.py @@ -8,35 +8,38 @@ import kornia.augmentation as K import torch import torchvision -from kornia.contrib import Lambda from torchvision.models import SwinTransformer from torchvision.models._api import Weights, WeightsEnum -# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501 -# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). -# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501 -# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501 +import torchgeo.transforms.transforms as T + +# All Satlas transforms include: +# https://github.com/allenai/satlas/blob/main/satlas/cmd/model/train.py#L49 +# +# Information about sensor-specific normalization can be found at: +# https://github.com/allenai/satlas/blob/main/Normalization.md + +_satlas_bands = ('B04', 'B03', 'B02') _satlas_transforms = K.AugmentationSequential( - K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=None + K.CenterCrop(256), + K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), + data_keys=None, ) -# Satlas uses the TCI product for Sentinel-2 RGB, which is in the range (0, 255). -# See details: https://github.com/allenai/satlas/blob/main/Normalization.md#sentinel-2-images. # noqa: E501 -# Satlas Sentinel-2 multispectral imagery has first 3 bands divided by 255 and the following 6 bands by 8160, both clipped to (0, 1). # noqa: E501 -_std = torch.tensor( - [255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0] -) # noqa: E501 -_mean = torch.zeros_like(_std) -_sentinel2_ms_satlas_transforms = K.AugmentationSequential( - K.Normalize(mean=_mean, std=_std), - K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))), +_satlas_sentinel2_bands = (*_satlas_bands, 'B05', 'B06', 'B07', 'B08', 'B11', 'B12') +_std = torch.tensor([255, 255, 255, 8160, 8160, 8160, 8160, 8160, 8160]) +_satlas_sentinel2_transforms = K.AugmentationSequential( + K.CenterCrop(256), + K.Normalize(mean=torch.tensor(0), std=_std), + T._Clamp(p=1, min=0, max=1), data_keys=None, ) -# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501 -_landsat_satlas_transforms = K.AugmentationSequential( +_satlas_landsat_bands = tuple(f'B{i:02}' for i in range(1, 12)) +_satlas_landsat_transforms = K.AugmentationSequential( + K.CenterCrop(256), K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)), - K.ImageSequential(Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0))), + T._Clamp(p=1, min=0, max=1), data_keys=None, ) @@ -46,6 +49,68 @@ Weights.__deepcopy__ = lambda *args, **kwargs: args[0] +class Swin_V2_T_Weights(WeightsEnum): # type: ignore[misc] + """Swin Transformer v2 Tiny weights. + + For `torchvision `_ + *swin_v2_t* implementation. + + .. versionadded:: 0.6 + """ + + SENTINEL2_MI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_mi_ms-d8c659e3.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'swin_v2_t', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_MI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_mi_rgb-424d91f4.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'swin_v2_t', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_bands, + }, + ) + + SENTINEL2_SI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_si_ms-bc68e396.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'swin_v2_t', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_SI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swint_si_rgb-0c1a96e0.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'swin_v2_t', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_bands, + }, + ) + + class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] """Swin Transformer v2 Base weights. @@ -55,82 +120,176 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc] .. versionadded:: 0.6 """ - NAIP_RGB_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/aerial_swinb_si.pth', # noqa: E501 + NAIP_RGB_MI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/aerial_swinb_mi-326d69e1.pth', transforms=_satlas_transforms, meta={ - 'dataset': 'Satlas', + 'dataset': 'SatlasPretrain', 'in_chans': 3, 'model': 'swin_v2_b', 'publication': 'https://arxiv.org/abs/2211.15660', 'repo': 'https://github.com/allenai/satlas', + 'bands': ('R', 'G', 'B'), }, ) - SENTINEL2_RGB_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_rgb.pth', # noqa: E501 + NAIP_RGB_SI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/aerial_swinb_si-e4169eb1.pth', transforms=_satlas_transforms, meta={ - 'dataset': 'Satlas', + 'dataset': 'SatlasPretrain', 'in_chans': 3, 'model': 'swin_v2_b', 'publication': 'https://arxiv.org/abs/2211.15660', 'repo': 'https://github.com/allenai/satlas', + 'bands': ('R', 'G', 'B'), }, ) - SENTINEL2_MS_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel2_swinb_si_ms.pth', # noqa: E501 - transforms=_sentinel2_ms_satlas_transforms, + LANDSAT_MI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/landsat_swinb_mi-6b4a1cda.pth', + transforms=_satlas_landsat_transforms, meta={ - 'dataset': 'Satlas', - 'in_chans': 9, + 'dataset': 'SatlasPretrain', + 'in_chans': 11, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_landsat_bands, + }, + ) + + LANDSAT_SI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/landsat_swinb_si-4af978f6.pth', + transforms=_satlas_landsat_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 11, 'model': 'swin_v2_b', 'publication': 'https://arxiv.org/abs/2211.15660', 'repo': 'https://github.com/allenai/satlas', - 'bands': ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B11', 'B12'], + 'bands': _satlas_landsat_bands, + }, + ) + + SENTINEL1_MI_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel1_swinb_mi-f6c43d97.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 2, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': ('VH', 'VV'), }, ) SENTINEL1_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/sentinel1_swinb_si.pth', # noqa: E501 + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel1_swinb_si-3981c153.pth', transforms=_satlas_transforms, meta={ - 'dataset': 'Satlas', + 'dataset': 'SatlasPretrain', 'in_chans': 2, 'model': 'swin_v2_b', 'publication': 'https://arxiv.org/abs/2211.15660', 'repo': 'https://github.com/allenai/satlas', - 'bands': ['VH', 'VV'], + 'bands': ('VH', 'VV'), }, ) - LANDSAT_SI_SATLAS = Weights( - url='https://hf.co/allenai/satlas-pretrain/resolve/daa578a4be36573d9791bf51dcd0420b8dc75732/landsat_swinb_si.pth', # noqa: E501 - transforms=_landsat_satlas_transforms, + SENTINEL2_MI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_mi_ms-39c86721.pth', + transforms=_satlas_sentinel2_transforms, meta={ - 'dataset': 'Satlas', - 'in_chans': 11, + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_MI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_mi_rgb-4efa210c.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_bands, + }, + ) + + SENTINEL2_SI_MS_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_si_ms-fe22a12c.pth', + transforms=_satlas_sentinel2_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 9, + 'model': 'swin_v2_b', + 'publication': 'https://arxiv.org/abs/2211.15660', + 'repo': 'https://github.com/allenai/satlas', + 'bands': _satlas_sentinel2_bands, + }, + ) + + SENTINEL2_SI_RGB_SATLAS = Weights( + url='https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_swinb_si_rgb-156a98d5.pth', + transforms=_satlas_transforms, + meta={ + 'dataset': 'SatlasPretrain', + 'in_chans': 3, 'model': 'swin_v2_b', 'publication': 'https://arxiv.org/abs/2211.15660', 'repo': 'https://github.com/allenai/satlas', - 'bands': [ - 'B01', - 'B02', - 'B03', - 'B04', - 'B05', - 'B06', - 'B07', - 'B08', - 'B09', - 'B10', - 'B11', - ], # noqa: E501 + 'bands': _satlas_bands, }, ) +def swin_v2_t( + weights: Swin_V2_T_Weights | None = None, *args: Any, **kwargs: Any +) -> SwinTransformer: + """Swin Transformer v2 tiny model. + + If you use this model in your research, please cite the following paper: + + * https://arxiv.org/abs/2111.09883 + + .. versionadded:: 0.6 + + Args: + weights: Pre-trained model weights to use. + *args: Additional arguments to + pass to :class:`torchvision.models.swin_transformer.SwinTransformer`. + **kwargs: Additional keyword arguments to + pass to :class:`torchvision.models.swin_transformer.SwinTransformer`. + + Returns: + A Swin Transformer Tiny model. + """ + model: SwinTransformer = torchvision.models.swin_v2_t(weights=None, *args, **kwargs) + + if weights: + num_channels = weights.meta['in_chans'] + out_channels = model.features[0][0].out_channels + # https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py#L27 + model.features[0][0] = torch.nn.Conv2d( + num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4) + ) + missing_keys, unexpected_keys = model.load_state_dict( + weights.get_state_dict(progress=True), strict=False + ) + assert set(missing_keys) <= set() + assert not unexpected_keys + + return model + + def swin_v2_b( weights: Swin_V2_B_Weights | None = None, *args: Any, **kwargs: Any ) -> SwinTransformer: @@ -146,7 +305,7 @@ def swin_v2_b( weights: Pre-trained model weights to use. *args: Additional arguments to pass to :class:`torchvision.models.swin_transformer.SwinTransformer`. - **kwargs: Additional keywork arguments to + **kwargs: Additional keyword arguments to pass to :class:`torchvision.models.swin_transformer.SwinTransformer`. Returns: @@ -155,6 +314,16 @@ def swin_v2_b( model: SwinTransformer = torchvision.models.swin_v2_b(weights=None, *args, **kwargs) if weights: - model.load_state_dict(weights.get_state_dict(progress=True), strict=False) + num_channels = weights.meta['in_chans'] + out_channels = model.features[0][0].out_channels + # https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py#L27 + model.features[0][0] = torch.nn.Conv2d( + num_channels, out_channels, kernel_size=(4, 4), stride=(4, 4) + ) + missing_keys, unexpected_keys = model.load_state_dict( + weights.get_state_dict(progress=True), strict=False + ) + assert set(missing_keys) <= set() + assert not unexpected_keys return model diff --git a/torchgeo/models/vit.py b/torchgeo/models/vit.py index a81ac13d48a..3c876ed3fe7 100644 --- a/torchgeo/models/vit.py +++ b/torchgeo/models/vit.py @@ -11,8 +11,17 @@ from timm.models.vision_transformer import VisionTransformer from torchvision.models._api import Weights, WeightsEnum -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 # noqa: E501 -# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # noqa: E501 +from .resnet import ( + _landsat_etm_sr_bands, + _landsat_etm_toa_bands, + _landsat_oli_sr_bands, + _landsat_oli_tirs_toa_bands, + _landsat_tm_toa_bands, + _sentinel2_toa_bands, +) + +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167 +# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97 # Normalization either by 10K or channel-wise with band statistics _zhu_xlab_transforms = K.AugmentationSequential( K.Resize(256), @@ -21,7 +30,7 @@ data_keys=None, ) -# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 # noqa: E501 +# https://github.com/microsoft/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43 _ssl4eo_l_transforms = K.AugmentationSequential( K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), K.CenterCrop((224, 224)), @@ -37,14 +46,14 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] """Vision Transformer Small Patch Size 16 weights. - For `timm `_ + For `timm `_ *vit_small_patch16_224* implementation. .. versionadded:: 0.4 """ LANDSAT_TM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -53,11 +62,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_tm_toa_bands, }, ) LANDSAT_TM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -66,11 +76,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_tm_toa_bands, }, ) LANDSAT_ETM_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -79,11 +90,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_etm_toa_bands, }, ) LANDSAT_ETM_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -92,11 +104,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_etm_toa_bands, }, ) LANDSAT_ETM_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -105,11 +118,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_etm_sr_bands, }, ) LANDSAT_ETM_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -118,11 +132,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_etm_sr_bands, }, ) LANDSAT_OLI_TIRS_TOA_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -131,11 +146,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_oli_tirs_toa_bands, }, ) LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -144,11 +160,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_oli_tirs_toa_bands, }, ) LANDSAT_OLI_SR_MOCO = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -157,11 +174,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'moco', + 'bands': _landsat_oli_sr_bands, }, ) LANDSAT_OLI_SR_SIMCLR = Weights( - url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth', # noqa: E501 + url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth', transforms=_ssl4eo_l_transforms, meta={ 'dataset': 'SSL4EO-L', @@ -170,11 +188,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2306.09424', 'repo': 'https://github.com/microsoft/torchgeo', 'ssl_method': 'simclr', + 'bands': _landsat_oli_sr_bands, }, ) SENTINEL2_ALL_DINO = Weights( - url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth', # noqa: E501 + url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -183,11 +202,12 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'dino', + 'bands': _sentinel2_toa_bands, }, ) SENTINEL2_ALL_MOCO = Weights( - url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth', # noqa: E501 + url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth', transforms=_zhu_xlab_transforms, meta={ 'dataset': 'SSL4EO-S12', @@ -196,6 +216,7 @@ class ViTSmall16_Weights(WeightsEnum): # type: ignore[misc] 'publication': 'https://arxiv.org/abs/2211.07044', 'repo': 'https://github.com/zhu-xlab/SSL4EO-S12', 'ssl_method': 'moco', + 'bands': _sentinel2_toa_bands, }, ) @@ -214,7 +235,7 @@ def vit_small_patch16_224( Args: weights: Pre-trained model weights to use. *args: Additional arguments to pass to :func:`timm.create_model`. - **kwargs: Additional keywork arguments to pass to :func:`timm.create_model`. + **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`. Returns: A ViT small 16 model. diff --git a/torchgeo/samplers/__init__.py b/torchgeo/samplers/__init__.py index ba995c9c782..e9cad8a8f82 100644 --- a/torchgeo/samplers/__init__.py +++ b/torchgeo/samplers/__init__.py @@ -9,18 +9,13 @@ from .utils import get_random_bounding_box, tile_to_chips __all__ = ( - # Samplers + 'BatchGeoSampler', + 'GeoSampler', 'GridGeoSampler', 'PreChippedGeoSampler', - 'RandomGeoSampler', - # Batch samplers 'RandomBatchGeoSampler', - # Base classes - 'GeoSampler', - 'BatchGeoSampler', - # Utilities + 'RandomGeoSampler', + 'Units', 'get_random_bounding_box', 'tile_to_chips', - # Constants - 'Units', ) diff --git a/torchgeo/samplers/batch.py b/torchgeo/samplers/batch.py index 22726f74b2c..686b458ce24 100644 --- a/torchgeo/samplers/batch.py +++ b/torchgeo/samplers/batch.py @@ -8,6 +8,7 @@ import torch from rtree.index import Index, Property +from torch import Generator from torch.utils.data import Sampler from ..datasets import BoundingBox, GeoDataset @@ -70,6 +71,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -86,6 +88,9 @@ def __init__( .. versionchanged:: 0.4 ``length`` parameter is now optional, a reasonable default will be used + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from size: dimensions of each :term:`patch` @@ -97,9 +102,11 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: pseudo-random number generator (PRNG). """ super().__init__(dataset, roi) self.size = _to_tuple(size) + self.generator = generator if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) @@ -144,7 +151,9 @@ def __iter__(self) -> Iterator[list[BoundingBox]]: # Choose random indices within that tile batch = [] for _ in range(self.batch_size): - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) batch.append(bounding_box) yield batch diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index 094142cb647..6fa4331c4b7 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -5,9 +5,11 @@ import abc from collections.abc import Callable, Iterable, Iterator +from functools import partial import torch from rtree.index import Index, Property +from torch import Generator from torch.utils.data import Sampler from ..datasets import BoundingBox, GeoDataset @@ -72,6 +74,7 @@ def __init__( length: int | None = None, roi: BoundingBox | None = None, units: Units = Units.PIXELS, + generator: Generator | None = None, ) -> None: """Initialize a new Sampler instance. @@ -88,6 +91,9 @@ def __init__( .. versionchanged:: 0.4 ``length`` parameter is now optional, a reasonable default will be used + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from size: dimensions of each :term:`patch` @@ -98,6 +104,7 @@ def __init__( roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) units: defines if ``size`` is in pixel or CRS units + generator: pseudo-random number generator (PRNG). """ super().__init__(dataset, roi) self.size = _to_tuple(size) @@ -105,6 +112,7 @@ def __init__( if units == Units.PIXELS: self.size = (self.size[0] * self.res, self.size[1] * self.res) + self.generator = generator self.length = 0 self.hits = [] areas = [] @@ -142,7 +150,9 @@ def __iter__(self) -> Iterator[BoundingBox]: bounds = BoundingBox(*hit.bounds) # Choose a random index within that tile - bounding_box = get_random_bounding_box(bounds, self.size, self.res) + bounding_box = get_random_bounding_box( + bounds, self.size, self.res, self.generator + ) yield bounding_box @@ -270,20 +280,30 @@ class PreChippedGeoSampler(GeoSampler): """ def __init__( - self, dataset: GeoDataset, roi: BoundingBox | None = None, shuffle: bool = False + self, + dataset: GeoDataset, + roi: BoundingBox | None = None, + shuffle: bool = False, + generator: torch.Generator | None = None, ) -> None: """Initialize a new Sampler instance. .. versionadded:: 0.3 + .. versionadded:: 0.7 + The *generator* parameter. + Args: dataset: dataset to index from roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) shuffle: if True, reshuffle data at every epoch + generator: pseudo-random number generator (PRNG) used in combination with shuffle. + """ super().__init__(dataset, roi) self.shuffle = shuffle + self.generator = generator self.hits = [] for hit in self.index.intersection(tuple(self.roi), objects=True): @@ -297,7 +317,7 @@ def __iter__(self) -> Iterator[BoundingBox]: """ generator: Callable[[int], Iterable[int]] = range if self.shuffle: - generator = torch.randperm + generator = partial(torch.randperm, generator=self.generator) for idx in generator(len(self)): yield BoundingBox(*self.hits[idx].bounds) diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index a1fca673a3a..48ad760f928 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -7,6 +7,7 @@ from typing import overload import torch +from torch import Generator from ..datasets import BoundingBox @@ -35,7 +36,10 @@ def _to_tuple(value: tuple[float, float] | float) -> tuple[float, float]: def get_random_bounding_box( - bounds: BoundingBox, size: tuple[float, float] | float, res: float + bounds: BoundingBox, + size: tuple[float, float] | float, + res: float, + generator: Generator | None = None, ) -> BoundingBox: """Returns a random bounding box within a given bounding box. @@ -46,10 +50,14 @@ def get_random_bounding_box( * a ``tuple`` of two floats - in which case, the first *float* is used for the height dimension, and the second *float* for the width dimension + .. versionadded:: 0.7 + The *generator* parameter. + Args: bounds: the larger bounding box to sample from size: the size of the bounding box to sample res: the resolution of the image + generator: pseudo-random number generator (PRNG). Returns: randomly sampled bounding box from the extent of the input @@ -64,8 +72,8 @@ def get_random_bounding_box( miny = bounds.miny # Use an integer multiple of res to avoid resampling - minx += int(torch.rand(1).item() * width) * res - miny += int(torch.rand(1).item() * height) * res + minx += int(torch.rand(1, generator=generator).item() * width) * res + miny += int(torch.rand(1, generator=generator).item() * height) * res maxx = minx + t_size[1] maxy = miny + t_size[0] diff --git a/torchgeo/trainers/__init__.py b/torchgeo/trainers/__init__.py index be4fb4a03db..ee69bff0021 100644 --- a/torchgeo/trainers/__init__.py +++ b/torchgeo/trainers/__init__.py @@ -14,19 +14,15 @@ from .simclr import SimCLRTask __all__ = ( - # Supervised + 'BYOLTask', + 'BaseTask', 'ClassificationTask', + 'IOBenchTask', + 'MoCoTask', 'MultiLabelClassificationTask', 'ObjectDetectionTask', 'PixelwiseRegressionTask', 'RegressionTask', 'SemanticSegmentationTask', - # Self-supervised - 'BYOLTask', - 'MoCoTask', 'SimCLRTask', - # Base classes - 'BaseTask', - # Other - 'IOBenchTask', ) diff --git a/torchgeo/trainers/base.py b/torchgeo/trainers/base.py index 1f50ad0ab58..02628ba08cb 100644 --- a/torchgeo/trainers/base.py +++ b/torchgeo/trainers/base.py @@ -19,6 +19,9 @@ class BaseTask(LightningModule, ABC): .. versionadded:: 0.5 """ + #: Parameters to ignore when saving hyperparameters. + ignore: Sequence[str] | str | None = 'weights' + #: Model to train. model: Any @@ -28,14 +31,14 @@ class BaseTask(LightningModule, ABC): #: Whether the goal is to minimize or maximize the performance metric to monitor. mode = 'min' - def __init__(self, ignore: Sequence[str] | str | None = None) -> None: + def __init__(self) -> None: """Initialize a new BaseTask instance. Args: ignore: Arguments to skip when saving hyperparameters. """ super().__init__() - self.save_hyperparameters(ignore=ignore) + self.save_hyperparameters(ignore=self.ignore) self.configure_models() self.configure_losses() self.configure_metrics() @@ -52,7 +55,7 @@ def configure_metrics(self) -> None: def configure_optimizers( self, - ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig': + ) -> 'lightning.pytorch.utilities.types.OptimizerLRScheduler': """Initialize the optimizer and learning rate scheduler. Returns: diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index 35243eaa545..18df10e02f0 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -66,7 +66,7 @@ def __init__(self, image_size: tuple[int, int] = (256, 256)) -> None: ) def forward(self, x: Tensor) -> Tensor: - """Applys SimCLR augmentations to the input tensor. + """Applies SimCLR augmentations to the input tensor. Args: x: a batch of imagery @@ -122,8 +122,8 @@ class BackboneWrapper(nn.Module): * The output of the encoding layer is passed through the projection head * The forward call returns the output of the projection head - .. versionchanged 0.4: Name changed from *EncoderWrapper* to - *BackboneWrapper*. + .. versionchanged:: 0.4 + Name changed from *EncoderWrapper* to *BackboneWrapper*. """ def __init__( @@ -137,7 +137,7 @@ def __init__( Args: model: model to encode - projection_size: size of the ouput layer of the projector MLP + projection_size: size of the output layer of the projector MLP hidden_size: size of hidden layer of the projector MLP layer: layer from model to project """ @@ -286,7 +286,7 @@ class BYOLTask(BaseTask): Reference implementation: - * https://github.com/deepmind/deepmind-research/tree/master/byol + * https://github.com/google-deepmind/deepmind-research/tree/master/byol If you use this trainer in your research, please cite the following paper: @@ -324,7 +324,7 @@ def __init__( renamed to *model*, *lr*, and *patience*. """ self.weights = weights - super().__init__(ignore='weights') + super().__init__() def configure_models(self) -> None: """Initialize the model.""" diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index cc293099519..2e2766419a5 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -73,7 +73,7 @@ class and used with 'ce' loss. *lr* and *patience*. """ self.weights = weights - super().__init__(ignore='weights') + super().__init__() def configure_models(self) -> None: """Initialize the model.""" diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 24a7b4e8b90..3d970abdae0 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -53,6 +53,7 @@ class ObjectDetectionTask(BaseTask): .. versionadded:: 0.4 """ + ignore = None monitor = 'val_map' mode = 'max' @@ -82,7 +83,7 @@ def __init__( weights: Initial model weights. True for ImageNet weights, False or None for random weights. in_channels: Number of input channels to model. - num_classes: Number of prediction classes. + num_classes: Number of prediction classes (including the background). trainable_layers: Number of trainable layers. lr: Learning rate for optimizer. patience: Patience for learning rate scheduler. diff --git a/torchgeo/trainers/iobench.py b/torchgeo/trainers/iobench.py index c8826a1dce5..c7be263dac9 100644 --- a/torchgeo/trainers/iobench.py +++ b/torchgeo/trainers/iobench.py @@ -24,7 +24,7 @@ def configure_models(self) -> None: def configure_optimizers( self, - ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig': + ) -> 'lightning.pytorch.utilities.types.OptimizerLRScheduler': """Initialize the optimizer. Returns: diff --git a/torchgeo/trainers/moco.py b/torchgeo/trainers/moco.py index 73646c3868a..ce35855c12f 100644 --- a/torchgeo/trainers/moco.py +++ b/torchgeo/trainers/moco.py @@ -136,6 +136,7 @@ class MoCoTask(BaseTask): .. versionadded:: 0.5 """ + ignore = ('weights', 'augmentation1', 'augmentation2') monitor = 'train_loss' def __init__( @@ -219,7 +220,7 @@ def __init__( warnings.warn('MoCo v3 does not use a memory bank') self.weights = weights - super().__init__(ignore=['weights', 'augmentation1', 'augmentation2']) + super().__init__() grayscale_weights = grayscale_weights or torch.ones(in_channels) aug1, aug2 = moco_augmentations(version, size, grayscale_weights) @@ -292,7 +293,7 @@ def configure_losses(self) -> None: def configure_optimizers( self, - ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig': + ) -> 'lightning.pytorch.utilities.types.OptimizerLRScheduler': """Initialize the optimizer and learning rate scheduler. Returns: diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 86c3423c656..0381316050b 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -77,7 +77,7 @@ def __init__( *lr* and *patience*. """ self.weights = weights - super().__init__(ignore='weights') + super().__init__() def configure_models(self) -> None: """Initialize the model.""" diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index dad26635c64..f8e519fa493 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -54,7 +54,7 @@ def __init__( model does not support pretrained weights. Pretrained ViT weight enums are not supported yet. in_channels: Number of input channels to model. - num_classes: Number of prediction classes. + num_classes: Number of prediction classes (including the background). num_filters: Number of filters. Only applicable when model='fcn'. loss: Name of the loss function, currently supports 'ce', 'jaccard' or 'focal' loss. @@ -77,7 +77,7 @@ class and used with 'ce' loss. were renamed to *model*, *backbone*, and *weights*. .. versionadded:: 0.5 - The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters. + The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters. .. versionchanged:: 0.5 The *weights* parameter now supports WeightEnums and checkpoint paths. @@ -85,10 +85,10 @@ class and used with 'ce' loss. *lr* and *patience*. .. versionchanged:: 0.6 - The *ignore_index* parameter now works for jaccard loss. + The *ignore_index* parameter now works for jaccard loss. """ self.weights = weights - super().__init__(ignore='weights') + super().__init__() def configure_models(self) -> None: """Initialize the model. diff --git a/torchgeo/trainers/simclr.py b/torchgeo/trainers/simclr.py index ba9443e9191..a0625f26ebb 100644 --- a/torchgeo/trainers/simclr.py +++ b/torchgeo/trainers/simclr.py @@ -15,8 +15,8 @@ import torch.nn.functional as F from lightly.loss import NTXentLoss from lightly.models.modules import SimCLRProjectionHead +from lightly.utils.lars import LARS from torch import Tensor -from torch.optim import Adam from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from torchvision.models._api import WeightsEnum @@ -68,6 +68,7 @@ class SimCLRTask(BaseTask): .. versionadded:: 0.5 """ + ignore = ('weights', 'augmentations') monitor = 'train_loss' def __init__( @@ -80,6 +81,7 @@ def __init__( hidden_dim: int | None = None, output_dim: int | None = None, lr: float = 4.8, + momentum: float = 0.9, weight_decay: float = 1e-4, temperature: float = 0.07, memory_bank_size: int = 64000, @@ -90,6 +92,9 @@ def __init__( ) -> None: """Initialize a new SimCLRTask instance. + .. versionadded:: 0.6 + The *momentum* parameter. + Args: model: Name of the `timm `__ model to use. @@ -104,6 +109,7 @@ def __init__( output_dim: Number of output dimensions in projection head (defaults to output dimension of model). lr: Learning rate (0.3 x batch_size / 256 is recommended). + momentum: Momentum factor. weight_decay: Weight decay coefficient (1e-6 for v1, 1e-4 for v2). temperature: Temperature used in NT-Xent loss. memory_bank_size: Size of memory bank (0 for v1, 64K for v2). @@ -135,7 +141,7 @@ def __init__( warnings.warn('SimCLR v2 uses a memory bank') self.weights = weights - super().__init__(ignore=['weights', 'augmentations']) + super().__init__() grayscale_weights = grayscale_weights or torch.ones(in_channels) self.augmentations = augmentations or simclr_augmentations( @@ -280,16 +286,19 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> N def configure_optimizers( self, - ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig': + ) -> 'lightning.pytorch.utilities.types.OptimizerLRScheduler': """Initialize the optimizer and learning rate scheduler. + .. versionchanged:: 0.6 + Changed from Adam to LARS optimizer. + Returns: Optimizer and learning rate scheduler. """ - # Original paper uses LARS optimizer, but this is not defined in PyTorch - optimizer = Adam( + optimizer = LARS( self.parameters(), lr=self.hparams['lr'], + momentum=self.hparams['momentum'], weight_decay=self.hparams['weight_decay'], ) max_epochs = 200 diff --git a/torchgeo/transforms/indices.py b/torchgeo/transforms/indices.py index d04385d6ad1..ac3fee56fff 100644 --- a/torchgeo/transforms/indices.py +++ b/torchgeo/transforms/indices.py @@ -74,7 +74,7 @@ class AppendNBR(AppendNormalizedDifferenceIndex): If you use this index in your research, please cite the following paper: - * https://www.sciencebase.gov/catalog/item/4f4e4b20e4b07f02db6abb36 + * https://www.yumpu.com/en/document/view/24226870/the-normalized-burn-ratio-and-relationships-to-burn-severity-/7 .. versionadded:: 0.2 """ diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 87484e75730..d8f80bdcaac 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -8,7 +8,7 @@ import kornia.augmentation as K import torch from einops import rearrange -from kornia.contrib import Lambda, extract_tensor_patches +from kornia.contrib import extract_tensor_patches from kornia.geometry import crop_by_indices from kornia.geometry.boxes import Boxes from torch import Tensor @@ -25,7 +25,7 @@ class AugmentationSequential(Module): def __init__( self, - *args: K.base._AugmentationBase | K.ImageSequential | Lambda, + *args: K.base._AugmentationBase | K.ImageSequential, data_keys: list[str], **kwargs: Any, ) -> None: @@ -53,9 +53,9 @@ def __init__( else: keys.append(key) - self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] # noqa: E501 + self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Any]) -> dict[str, Any]: """Perform augmentations and update data dict. Args: @@ -99,7 +99,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # Convert boxes to default [N, 4] if 'boxes' in batch: - batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') # type:ignore[assignment] + batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy') # Torchmetrics does not support masks with a channel dimension if 'mask' in batch and batch['mask'].shape[1] == 1: @@ -272,3 +272,54 @@ def apply_transform( out = rearrange(out, 'b t c h w -> (b t) c h w') return out + + +class _Clamp(K.IntensityAugmentationBase2D): + """Clamp images to a specific range.""" + + def __init__( + self, + p: float = 0.5, + p_batch: float = 1, + min: float = 0, + max: float = 1, + same_on_batch: bool = False, + keepdim: bool = False, + ) -> None: + """Initialize a new _Clamp instance. + + Args: + p: Probability for applying an augmentation. This param controls the + augmentation probabilities element-wise for a batch. + p_batch: Probability for applying an augmentation to a batch. This param + controls the augmentation probabilities batch-wise. + min: Minimum value to clamp to. + max: Maximum value to clamp to. + same_on_batch: Apply the same transformation across the batch. + keepdim: Whether to keep the output shape the same as input ``True`` + or broadcast it to the batch form ``False``. + """ + super().__init__( + p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim + ) + self.flags = {'min': min, 'max': max} + + def apply_transform( + self, + input: Tensor, + params: dict[str, Tensor], + flags: dict[str, Any], + transform: Tensor | None = None, + ) -> Tensor: + """Apply the transform. + + Args: + input: the input tensor + params: generated parameters + flags: static parameters + transform: the geometric transformation tensor + + Returns: + the augmented input + """ + return torch.clamp(input, self.flags['min'], self.flags['max'])