diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8431b218040..d2a1c2f06ec 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -206,6 +206,11 @@ COWC .. autoclass:: COWCCounting .. autoclass:: COWCDetection +CropHarvest +^^^^^^^^^^^ + +.. autoclass:: CropHarvest + Kenya Crop Type ^^^^^^^^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index 146b81bcb9e..a34b918b5ec 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -6,6 +6,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI `Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI `COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB +`CropHarvest`_,"C","Sentinel-1/2, SRTM, ERA5","CC-BY-SA-4.0","70,213",351,1x1,10,"SAR, MSI, SRTM" `Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI `DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB `DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB @@ -49,4 +50,4 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `VHR-10`_,I,"Google Earth, Vaihingen","MIT",800,10,"358--1,728",0.08--2,RGB `Western USA Live Fuel Moisture`_,R,"Landsat8, Sentinel-1","CC-BY-NC-ND-4.0",2615,-,-,-,- `xView2`_,CD,Maxar,"CC-BY-NC-SA-4.0","3,732",4,"1,024x1,024",0.8,RGB -`ZueriCrop`_,"I, T",Sentinel-2,-,116K,48,24x24,10,MSI \ No newline at end of file +`ZueriCrop`_,"I, T",Sentinel-2,-,116K,48,24x24,10,MSI diff --git a/tests/data/cropharvest/data.py b/tests/data/cropharvest/data.py new file mode 100755 index 00000000000..4b882717c83 --- /dev/null +++ b/tests/data/cropharvest/data.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import json +import os +import shutil + +import h5py +import numpy as np + +SIZE = 32 + +np.random.seed(0) + +PATHS = [ + os.path.join("cropharvest", "features", "arrays", "0_TestDataset1.h5"), + os.path.join("cropharvest", "features", "arrays", "1_TestDataset1.h5"), + os.path.join("cropharvest", "features", "arrays", "2_TestDataset1.h5"), + os.path.join("cropharvest", "features", "arrays", "0_TestDataset2.h5"), + os.path.join("cropharvest", "features", "arrays", "1_TestDataset2.h5"), +] + + +def create_geojson(): + geojson = { + "type": "FeatureCollection", + "crs": {}, + "features": [ + { + "type": "Feature", + "properties": { + "dataset": "TestDataset1", + "index": 0, + "is_crop": 1, + "label": "soybean", + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] + ], + }, + }, + { + "type": "Feature", + "properties": { + "dataset": "TestDataset1", + "index": 0, + "is_crop": 1, + "label": "alfalfa", + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] + ], + }, + }, + { + "type": "Feature", + "properties": { + "dataset": "TestDataset1", + "index": 1, + "is_crop": 1, + "label": None, + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] + ], + }, + }, + { + "type": "Feature", + "properties": { + "dataset": "TestDataset2", + "index": 2, + "is_crop": 1, + "label": "maize", + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] + ], + }, + }, + { + "type": "Feature", + "properties": { + "dataset": "TestDataset2", + "index": 1, + "is_crop": 0, + "label": None, + }, + "geometry": { + "type": "Polygon", + "coordinates": [ + [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]] + ], + }, + }, + ], + } + return geojson + + +def create_file(path: str) -> None: + Z = np.random.randint(4000, size=(12, 18), dtype=np.int64) + with h5py.File(path, "w") as f: + f.create_dataset("array", data=Z) + + +if __name__ == "__main__": + directory = "cropharvest" + + # remove old data + to_remove = [ + os.path.join(directory, "features"), + os.path.join(directory, "features.tar.gz"), + os.path.join(directory, "labels.geojson"), + ] + for path in to_remove: + if os.path.isdir(path): + shutil.rmtree(path) + + label_path = os.path.join(directory, "labels.geojson") + geojson = create_geojson() + os.makedirs(os.path.dirname(label_path), exist_ok=True) + + with open(label_path, "w") as f: + json.dump(geojson, f) + + for path in PATHS: + os.makedirs(os.path.dirname(path), exist_ok=True) + create_file(path) + + # compress data + source_dir = os.path.join(directory, "features") + shutil.make_archive(source_dir, "gztar", directory, "features") + + # compute checksum + with open(label_path, "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"{label_path}: {md5}") + + with open(os.path.join(directory, "features.tar.gz"), "rb") as f: + md5 = hashlib.md5(f.read()).hexdigest() + print(f"zipped features: {md5}") diff --git a/tests/data/cropharvest/features.tar.gz b/tests/data/cropharvest/features.tar.gz new file mode 100644 index 00000000000..8f9e4a3aa8c Binary files /dev/null and b/tests/data/cropharvest/features.tar.gz differ diff --git a/tests/data/cropharvest/features/arrays/0_TestDataset1.h5 b/tests/data/cropharvest/features/arrays/0_TestDataset1.h5 new file mode 100644 index 00000000000..a0fd2a58d9e Binary files /dev/null and b/tests/data/cropharvest/features/arrays/0_TestDataset1.h5 differ diff --git a/tests/data/cropharvest/features/arrays/0_TestDataset2.h5 b/tests/data/cropharvest/features/arrays/0_TestDataset2.h5 new file mode 100644 index 00000000000..cf48cb56679 Binary files /dev/null and b/tests/data/cropharvest/features/arrays/0_TestDataset2.h5 differ diff --git a/tests/data/cropharvest/features/arrays/1_TestDataset1.h5 b/tests/data/cropharvest/features/arrays/1_TestDataset1.h5 new file mode 100644 index 00000000000..e134234af78 Binary files /dev/null and b/tests/data/cropharvest/features/arrays/1_TestDataset1.h5 differ diff --git a/tests/data/cropharvest/features/arrays/1_TestDataset2.h5 b/tests/data/cropharvest/features/arrays/1_TestDataset2.h5 new file mode 100644 index 00000000000..fd57acc26b5 Binary files /dev/null and b/tests/data/cropharvest/features/arrays/1_TestDataset2.h5 differ diff --git a/tests/data/cropharvest/features/arrays/2_TestDataset1.h5 b/tests/data/cropharvest/features/arrays/2_TestDataset1.h5 new file mode 100644 index 00000000000..de087375033 Binary files /dev/null and b/tests/data/cropharvest/features/arrays/2_TestDataset1.h5 differ diff --git a/tests/data/cropharvest/labels.geojson b/tests/data/cropharvest/labels.geojson new file mode 100644 index 00000000000..5536f27a90c --- /dev/null +++ b/tests/data/cropharvest/labels.geojson @@ -0,0 +1 @@ +{"type": "FeatureCollection", "crs": {}, "features": [{"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 0, "is_crop": 1, "label": "soybean"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 0, "is_crop": 1, "label": "alfalfa"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 1, "is_crop": 1, "label": null}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset2", "index": 2, "is_crop": 1, "label": "maize"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset2", "index": 1, "is_crop": 0, "label": null}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}]} \ No newline at end of file diff --git a/tests/datasets/test_cropharvest.py b/tests/datasets/test_cropharvest.py new file mode 100644 index 00000000000..ccefd4325b5 --- /dev/null +++ b/tests/datasets/test_cropharvest.py @@ -0,0 +1,100 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import builtins +import os +import shutil +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from pytest import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import CropHarvest, DatasetNotFoundError + +pytest.importorskip("h5py", minversion="3") + + +def download_url(url: str, root: str, filename: str, md5: str) -> None: + shutil.copy(url, os.path.join(root, filename)) + + +class TestCropHarvest: + @pytest.fixture + def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None: + import_orig = builtins.__import__ + + def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "h5py": + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest: + monkeypatch.setattr(torchgeo.datasets.cropharvest, "download_url", download_url) + monkeypatch.setitem( + CropHarvest.file_dict["features"], "md5", "ef6f4f00c0b3b50ed8380b0044928572" + ) + monkeypatch.setitem( + CropHarvest.file_dict["labels"], "md5", "1d93b6bfcec7b6797b75acbd9d284b92" + ) + monkeypatch.setitem( + CropHarvest.file_dict["features"], + "url", + os.path.join("tests", "data", "cropharvest", "features.tar.gz"), + ) + monkeypatch.setitem( + CropHarvest.file_dict["labels"], + "url", + os.path.join("tests", "data", "cropharvest", "labels.geojson"), + ) + + root = str(tmp_path) + transforms = nn.Identity() + + dataset = CropHarvest(root, transforms, download=True, checksum=True) + return dataset + + def test_getitem(self, dataset: CropHarvest) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["array"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert x["array"].shape == (12, 18) + y = dataset[2] + assert y["label"] == 1 + + def test_len(self, dataset: CropHarvest) -> None: + assert len(dataset) == 5 + + def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None: + CropHarvest(root=str(tmp_path), download=False) + + def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None: + feature_path = os.path.join(tmp_path, "features") + shutil.rmtree(feature_path) + CropHarvest(root=str(tmp_path), download=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match="Dataset not found"): + CropHarvest(str(tmp_path)) + + def test_plot(self, dataset: CropHarvest) -> None: + x = dataset[0].copy() + dataset.plot(x, subtitle="Test") + plt.close() + + def test_mock_missing_module( + self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None + ) -> None: + with pytest.raises( + ImportError, + match="h5py is not installed and is required to use this dataset", + ): + CropHarvest(root=str(tmp_path), download=True)[0] diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index effca5b5d91..0dcaf3a53c2 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -28,6 +28,7 @@ from .cloud_cover import CloudCoverDetection from .cms_mangrove_canopy import CMSGlobalMangroveCanopy from .cowc import COWC, COWCCounting, COWCDetection +from .cropharvest import CropHarvest from .cv4a_kenya_crop_type import CV4AKenyaCropType from .cyclone import TropicalCyclone from .deepglobelandcover import DeepGlobeLandCover @@ -150,6 +151,7 @@ "ChesapeakeWV", "ChesapeakeCVPR", "CMSGlobalMangroveCanopy", + "CropHarvest", "EDDMapS", "Esri2020", "EUDEM", diff --git a/torchgeo/datasets/cropharvest.py b/torchgeo/datasets/cropharvest.py new file mode 100644 index 00000000000..08d69f2fc82 --- /dev/null +++ b/torchgeo/datasets/cropharvest.py @@ -0,0 +1,318 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""CropHarvest datasets.""" + +import glob +import json +import os +from typing import Callable, Optional + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +from matplotlib.figure import Figure +from torch import Tensor + +from .geo import NonGeoDataset +from .utils import DatasetNotFoundError, download_url, extract_archive + + +class CropHarvest(NonGeoDataset): + """CropHarvest dataset. + + `CropHarvest `__ is a + crop classification dataset. + + Dataset features: + + * single pixel time series with crop-type labels + * 18 bands per image over 12 months + + Dataset format: + + * arrays are 12x18 with 18 bands over 12 months + + Dataset properties: + + 1. is_crop - whether or not a single pixel contains cropland + 2. classification_label - optional field identifying a specific crop type + 3. dataset - source dataset for the imagery + 4. lat - latitude + 5. lon - longitude + + If you use this dataset in your research, please cite the following paper: + + * https://openreview.net/forum?id=JtjzUXPEaCu + + This dataset requires the following additional library to be installed: + + * `h5py `_ to load the dataset + + .. versionadded:: 0.6 + """ + + # https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py + all_bands = [ + "VV", + "VH", + "B2", + "B3", + "B4", + "B5", + "B6", + "B7", + "B8", + "B8A", + "B9", + "B11", + "B12", + "temperature_2m", + "total_precipitation", + "elevation", + "slope", + "NDVI", + ] + rgb_bands = ["B4", "B3", "B2"] + + features_url = "https://zenodo.org/records/7257688/files/features.tar.gz?download=1" + labels_url = "https://zenodo.org/records/7257688/files/labels.geojson?download=1" + file_dict = { + "features": { + "url": features_url, + "filename": "features.tar.gz", + "extracted_filename": os.path.join("features", "arrays"), + "md5": "cad4df655c75caac805a80435e46ee3e", + }, + "labels": { + "url": labels_url, + "filename": "labels.geojson", + "extracted_filename": "labels.geojson", + "md5": "bf7bae6812fc7213481aff6a2e34517d", + }, + } + + def __init__( + self, + root: str = "data", + transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new CropHarvest dataset instance. + + Args: + root: root directory where dataset can be found + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + DatasetNotFoundError: If dataset is not found and *download* is False. + ImportError: If h5py is not installed + """ + try: + import h5py # noqa: F401 + except ImportError: + raise ImportError( + "h5py is not installed and is required to use this dataset" + ) + + self.root = root + self.transforms = transforms + self.checksum = checksum + self.download = download + + self._verify() + + self.files = self._load_features(self.root) + self.labels = self._load_labels(self.root) + self.classes = self.labels["properties.label"].unique() + self.classes = self.classes[self.classes != np.array(None)] + self.classes = np.insert(self.classes, 0, ["None", "Other"]) + + def __getitem__(self, index: int) -> dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + single pixel time-series array and label at that index + """ + files = self.files[index] + data = self._load_array(files["chip"]) + + label = self._load_label(files["index"], files["dataset"]) + sample = {"array": data, "label": label} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.files) + + def _load_features(self, root: str) -> list[dict[str, str]]: + """Return the paths of the files in the dataset. + + Args: + root: root dir of dataset + + Returns: + list of dicts containing path for each of hd5 single pixel time series and + its key for associated data + """ + files = [] + chips = glob.glob( + os.path.join(root, self.file_dict["features"]["extracted_filename"], "*.h5") + ) + chips = sorted(os.path.basename(chip) for chip in chips) + for chip in chips: + chip_path = os.path.join( + root, self.file_dict["features"]["extracted_filename"], chip + ) + index = chip.split("_")[0] + dataset = chip.split("_")[1][:-3] + files.append(dict(chip=chip_path, index=index, dataset=dataset)) + return files + + def _load_labels(self, root: str) -> pd.DataFrame: + """Return the paths of the files in the dataset. + + Args: + root: root dir of dataset + + Returns: + pandas dataframe containing label data for each feature + """ + filename = self.file_dict["labels"]["extracted_filename"] + with open(os.path.join(root, filename), encoding="utf8") as f: + data = json.load(f) + df = pd.json_normalize(data["features"]) + return df + + def _load_array(self, path: str) -> Tensor: + """Load an individual single pixel time series. + + Args: + path: path to the image + + Returns: + the image + """ + import h5py + + filename = os.path.join(path) + with h5py.File(filename, "r") as f: + array = f.get("array")[()] + tensor = torch.from_numpy(array) + return tensor + + def _load_label(self, idx: str, dataset: str) -> Tensor: + """Load the crop-type label for a single pixel time series. + + Args: + idx: sample index in labels.geojson + dataset: dataset name to query labels.geojson + + Returns: + the crop-type label + """ + index = int(idx) + row = self.labels[ + (self.labels["properties.index"] == index) + & (self.labels["properties.dataset"] == dataset) + ] + row = row.to_dict(orient="records")[0] + label = "None" + if row["properties.label"]: + label = row["properties.label"] + elif row["properties.is_crop"] == 1: + label = "Other" + + return torch.tensor(np.where(self.classes == label)[0][0]) + + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if feature files already exist + feature_path = os.path.join( + self.root, self.file_dict["features"]["extracted_filename"] + ) + feature_path_zip = os.path.join( + self.root, self.file_dict["features"]["filename"] + ) + label_path = os.path.join( + self.root, self.file_dict["labels"]["extracted_filename"] + ) + # Check if labels exist + if os.path.exists(label_path): + # Check if features exist + if os.path.exists(feature_path): + return + # Check if features are downloaded in zip format + if os.path.exists(feature_path_zip): + self._extract() + return + + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) + + # Download and extract the dataset + self._download() + self._extract() + + def _download(self) -> None: + """Download the dataset and extract it.""" + features_path = os.path.join(self.file_dict["features"]["filename"]) + download_url( + self.file_dict["features"]["url"], + self.root, + filename=features_path, + md5=self.file_dict["features"]["md5"] if self.checksum else None, + ) + + download_url( + self.file_dict["labels"]["url"], + self.root, + filename=os.path.join(self.file_dict["labels"]["filename"]), + md5=self.file_dict["labels"]["md5"] if self.checksum else None, + ) + + def _extract(self) -> None: + """Extract the dataset.""" + features_path = os.path.join(self.root, self.file_dict["features"]["filename"]) + extract_archive(features_path) + + def plot(self, sample: dict[str, Tensor], subtitle: Optional[str] = None) -> Figure: + """Plot a sample from the dataset using bands for Agriculture RGB composite. + + Args: + sample: a sample returned by :meth:`__getitem__` + suptitle: optional subtitle to use for figure + + Returns: + a matplotlib Figure with the rendered sample + """ + fig, axs = plt.subplots() + bands = [self.all_bands.index(band) for band in self.rgb_bands] + rgb = np.array(sample["array"])[:, bands] / 3000 + axs.imshow(rgb[None, ...]) + axs.set_title(f'Crop type: {self.classes[sample["label"]]}') + axs.set_xticks(np.arange(12)) + axs.set_xticklabels(np.arange(12) + 1) + axs.set_yticks([]) + axs.set_xlabel("Month") + if subtitle is not None: + plt.suptitle(subtitle) + + return fig