diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index a96bb4f389f..4c61a9d7067 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -42,6 +42,11 @@ EnviroAtlas .. autoclass:: EnviroAtlas +Esri2020 +^^^^^^^^ + +.. autoclass:: Esri2020 + Landsat ^^^^^^^ diff --git a/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip b/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip new file mode 100644 index 00000000000..217161b7a15 Binary files /dev/null and b/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip differ diff --git a/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01/00A_20200101-20210101.tif b/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01/00A_20200101-20210101.tif new file mode 100644 index 00000000000..669e3f47342 Binary files /dev/null and b/tests/data/esri2020/io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01/00A_20200101-20210101.tif differ diff --git a/tests/datasets/test_esri2020.py b/tests/datasets/test_esri2020.py new file mode 100644 index 00000000000..cf9fda2f960 --- /dev/null +++ b/tests/datasets/test_esri2020.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path +from typing import Generator + +import pytest +import torch +import torch.nn as nn +from _pytest.monkeypatch import MonkeyPatch +from rasterio.crs import CRS + +import torchgeo.datasets.utils +from torchgeo.datasets import BoundingBox, Esri2020, IntersectionDataset, UnionDataset + + +def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: + shutil.copy(url, root) + + +class TestEsri2020: + @pytest.fixture + def dataset( + self, monkeypatch: Generator[MonkeyPatch, None, None], tmp_path: Path + ) -> Esri2020: + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.esri2020, "download_url", download_url + ) + zipfile = "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip" + monkeypatch.setattr(Esri2020, "zipfile", zipfile) # type: ignore[attr-defined] + + md5 = "4932855fcd00735a34b74b1f87db3df0" + monkeypatch.setattr(Esri2020, "md5", md5) # type: ignore[attr-defined] + url = os.path.join( + "tests", + "data", + "esri2020", + "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip", + ) + monkeypatch.setattr(Esri2020, "url", url) # type: ignore[attr-defined] + root = str(tmp_path) + transforms = nn.Identity() # type: ignore[attr-defined] + return Esri2020(root, transforms=transforms, download=True, checksum=True) + + def test_already_downloaded(self, tmp_path: Path) -> None: + url = os.path.join( + "tests", + "data", + "esri2020", + "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip", + ) + root = str(tmp_path) + shutil.copy(url, root) + Esri2020(root) + + def test_getitem(self, dataset: Esri2020) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x["crs"], CRS) + assert isinstance(x["mask"], torch.Tensor) + + def test_already_extracted(self, dataset: Esri2020) -> None: + Esri2020(root=dataset.root, download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + Esri2020(str(tmp_path), checksum=True) + + def test_and(self, dataset: Esri2020) -> None: + ds = dataset & dataset + assert isinstance(ds, IntersectionDataset) + + def test_or(self, dataset: Esri2020) -> None: + ds = dataset | dataset + assert isinstance(ds, UnionDataset) + + def test_plot(self, dataset: Esri2020) -> None: + query = dataset.bounds + x = dataset[query] + dataset.plot(x["mask"]) + + def test_url(self) -> None: + ds = Esri2020(os.path.join("tests", "data", "esri2020")) + assert "ai4edataeuwest.blob.core.windows.net" in ds.url + + def test_invalid_query(self, dataset: Esri2020) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match="query: .* not found in index with bounds:" + ): + dataset[query] diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 9250920c301..e99e857eda4 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -26,6 +26,7 @@ from .cyclone import TropicalCycloneWindEstimation from .dfc2022 import DFC2022 from .enviroatlas import EnviroAtlas +from .esri2020 import Esri2020 from .etci2021 import ETCI2021 from .eurosat import EuroSAT from .fair1m import FAIR1M @@ -96,6 +97,7 @@ "ChesapeakeVA", "ChesapeakeWV", "ChesapeakeCVPR", + "Esri2020", "Landsat", "Landsat1", "Landsat2", diff --git a/torchgeo/datasets/esri2020.py b/torchgeo/datasets/esri2020.py new file mode 100644 index 00000000000..4d8c651a6a0 --- /dev/null +++ b/torchgeo/datasets/esri2020.py @@ -0,0 +1,138 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Esri 2020 Land Cover Dataset.""" + +import abc +import glob +import os +from typing import Any, Callable, Dict, Optional + +from rasterio.crs import CRS + +from .geo import RasterDataset +from .utils import download_url, extract_archive + + +class Esri2020(RasterDataset, abc.ABC): + """Esri 2020 Land Cover Dataset. + + The `Esri 2020 Land Cover dataset + `_ + consists of a global single band land use/land cover map derived from ESA + Sentinel-2 imagery at 10m resolution with a total of 10 classes. + It was published in July 2021 and used the Universal Transverse Mercator (UTM) + projection. This dataset only contains labels, no raw satellite imagery. + + The 10 classes are: + + 0. No Data + 1. Water + 2. Trees + 3. Grass + 4. Flooded Vegetation + 5. Crops + 6. Scrub/Shrub + 7. Built Area + 8. Bare Ground + 9. Snow/Ice + 10. Clouds + + A more detailed explanation of the invidual classes can be found + `here `_. + + If you use this dataset please cite the following paper: + + * https://ieeexplore.ieee.org/document/9553499 + + .. versionadded:: 0.3 + """ + + is_image = False + filename_glob = "*_20200101-20210101.*" + filename_regex = r"""^ + (?P[0-9][0-9][A-Z]) + _(?P\d{8}) + -(?P\d{8}) + """ + + zipfile = "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip" + md5 = "4932855fcd00735a34b74b1f87db3df0" + + url = ( + "https://ai4edataeuwest.blob.core.windows.net/io-lulc/" + "io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip" + ) + + def __init__( + self, + root: str = "data", + crs: Optional[CRS] = None, + res: Optional[float] = None, + transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + cache: bool = True, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new Dataset instance. + + Args: + root: root directory where dataset can be found + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + res: resolution of the dataset in units of CRS + (defaults to the resolution of the first file found) + transforms: a function/transform that takes an input sample + and returns a transformed version + cache: if True, cache file handle to speed up repeated sampling + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + FileNotFoundError: if no files are found in ``root`` + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + """ + self.root = root + self.download = download + self.checksum = checksum + + self._verify() + + super().__init__(root, crs, res, transforms, cache) + + def _verify(self) -> None: + """Verify the integrity of the dataset. + + Raises: + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + """ + # Check if the extracted file already exists + pathname = os.path.join(self.root, "**", self.filename_glob) + if glob.glob(pathname): + return + + # Check if the zip files have already been downloaded + pathname = os.path.join(self.root, self.zipfile) + if glob.glob(pathname): + self._extract() + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + f"Dataset not found in `root={self.root}` and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automaticaly download the dataset." + ) + + # Download the dataset + self._download() + self._extract() + + def _download(self) -> None: + """Download the dataset.""" + download_url(self.url, self.root, filename=self.zipfile, md5=self.md5) + + def _extract(self) -> None: + """Extract the dataset.""" + extract_archive(os.path.join(self.root, self.zipfile))