diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 00000000000..86276ecaaea --- /dev/null +++ b/test.ipynb @@ -0,0 +1,1195 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Populating index\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 16/16 [00:00<00:00, 76.25it/s]\n" + ] + } + ], + "source": [ + "from torchgeo.datasets import Sentinel2\n", + "\n", + "data_dir = r\"tests\\data\\sentinel2\"\n", + "\n", + "ds = Sentinel2(data_dir, bands=[\"B02\", \"B03\", \"B04\", \"B08\"], cache=False, res=10)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BoundingBox(minx=399960.0, maxx=401240.0, miny=4498720.0, maxy=4500000.0, mint=1555079321.0, maxt=1649927271.999999)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.bounds" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 4, 13, 13])\n", + "[[datetime.datetime(2022, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 14, 11, 7, 51), datetime.datetime(2022, 4, 14, 11, 7, 51)]]\n" + ] + } + ], + "source": [ + "from torchgeo.datasets.utils import BoundingBox\n", + "full_t_query = BoundingBox(minx=399960.0, maxx=400088.0, miny=4498720.0, maxy=4498848.0, mint=1555079321.0, maxt=1649927271.999999)\n", + "sample = ds[[full_t_query]]\n", + "print(sample[\"image\"].shape)\n", + "print(sample[\"dates\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2, 4, 13, 13])\n", + "[[datetime.datetime(2019, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 14, 11, 7, 51)], [datetime.datetime(2022, 4, 12, 16, 28, 41), datetime.datetime(2022, 4, 14, 11, 7, 51)]]\n" + ] + } + ], + "source": [ + "multi_t_query = [BoundingBox(minx=399960.0, maxx=400088.0, miny=4498720.0, maxy=4498848.0, mint=1555079321.0, maxt=1605264929),\n", + " BoundingBox(minx=399960.0, maxx=400088.0, miny=4498720.0, maxy=4498848.0, mint=1605264929.0, maxt=1649927272),\n", + " ]\n", + "sample = ds[multi_t_query]\n", + "print(sample[\"image\"].shape)\n", + "print(sample[\"dates\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import plotly.express as px\n", + "\n", + "def plot(\n", + " sample: dict,\n", + " indices_to_plot,\n", + " show = False,\n", + " **kwargs,\n", + "):\n", + " \"\"\"Plots the image data from the given sample.\n", + "\n", + " Args:\n", + " sample (dict): A dictionary containing the image data returned by self.__get_item__. Should contain the key \"image\".\n", + " indices_to_plot (list, optional): A list of indices to plot. If not provided, the method will use the RGB bands defined in `self.rgb_bands`.\n", + " show (bool, optional): Whether to display the plot. Defaults to False.\n", + " **kwargs (dict): Additional keyword arguments to be passed to `px.imshow`.\n", + "\n", + " Returns:\n", + " fig: The plotly figure object.\n", + " \"\"\"\n", + " image = sample[\"image\"]\n", + "\n", + " # Reorder and rescale the image\n", + " if (sample[\"image\"].ndim == 4) and (sample[\"image\"].shape[0] > 1):\n", + " # Shape of image = [d, c, h, w]\n", + " image = image[:, indices_to_plot, :, :].permute(0, 2, 3, 1)\n", + " if image.shape[-1] == 1:\n", + " image = image.squeeze(-1)\n", + " image = torch.clamp(image / 10000, min=0, max=1).numpy()\n", + "\n", + " fig = px.imshow(\n", + " image, animation_frame=0, labels={\"animation_frame\": \"Date\"}, **kwargs\n", + " )\n", + " # Todo, currently taking the first date, need to handle multiple dates\n", + " date_labels = [\n", + " dates[0].strftime(\"%m/%d/%Y, %H:%M:%S\") for dates in sample[\"dates\"]\n", + " ]\n", + " for i, label in enumerate(date_labels):\n", + " fig.layout.sliders[0].steps[i].label = label\n", + "\n", + " else:\n", + " image = image[indices_to_plot].permute(1, 2, 0)\n", + " image = torch.clamp(image / 10000, min=0, max=1).numpy()\n", + "\n", + " # Plot the image\n", + " fig = px.imshow(image, **kwargs)\n", + "\n", + " fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)\n", + " if show:\n", + " fig.show()\n", + " return fig" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "x: %{x}
y: %{y}", + "name": "0", + "source": "", + "type": "image", + "xaxis": "x", + "yaxis": "y" + } + ], + "frames": [ + { + "data": [ + { + "name": "0", + "source": "", + "type": "image" + } + ], + "layout": { + "margin": { + "t": 60 + } + }, + "name": "0" + }, + { + "data": [ + { + "name": "1", + "source": "", + "type": "image" + } + ], + "layout": { + "margin": { + "t": 60 + } + }, + "name": "1" + } + ], + "layout": { + "margin": { + "t": 60 + }, + "sliders": [ + { + "active": 0, + "currentvalue": { + "prefix": "Date=" + }, + "len": 0.9, + "pad": { + "b": 10, + "t": 60 + }, + "steps": [ + { + "args": [ + [ + "0" + ], + { + "frame": { + "duration": 0, + "redraw": true + }, + "fromcurrent": true, + "mode": "immediate", + "transition": { + "duration": 0, + "easing": "linear" + } + } + ], + "label": "04/12/2019, 16:28:41", + "method": "animate" + }, + { + "args": [ + [ + "1" + ], + { + "frame": { + "duration": 0, + "redraw": true + }, + "fromcurrent": true, + "mode": "immediate", + "transition": { + "duration": 0, + "easing": "linear" + } + } + ], + "label": "04/12/2022, 16:28:41", + "method": "animate" + } + ], + "x": 0.1, + "xanchor": "left", + "y": 0, + "yanchor": "top" + } + ], + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "updatemenus": [ + { + "buttons": [ + { + "args": [ + null, + { + "frame": { + "duration": 500, + "redraw": true + }, + "fromcurrent": true, + "mode": "immediate", + "transition": { + "duration": 500, + "easing": "linear" + } + } + ], + "label": "▶", + "method": "animate" + }, + { + "args": [ + [ + null + ], + { + "frame": { + "duration": 0, + "redraw": true + }, + "fromcurrent": true, + "mode": "immediate", + "transition": { + "duration": 0, + "easing": "linear" + } + } + ], + "label": "◼", + "method": "animate" + } + ], + "direction": "left", + "pad": { + "r": 10, + "t": 70 + }, + "showactive": false, + "type": "buttons", + "x": 0.1, + "xanchor": "right", + "y": 0, + "yanchor": "top" + } + ], + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "showticklabels": false + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "showticklabels": false + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot(sample, show=False, indices_to_plot=[2, 1, 0])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cca", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 26a035d427d..9ea6de07746 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -12,11 +12,15 @@ import sys import warnings from collections.abc import Callable, Iterable, Sequence +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from re import Pattern from typing import Any, ClassVar, cast import fiona import fiona.transform import numpy as np +import pandas as pd import pyproj import rasterio import rasterio.merge @@ -31,6 +35,7 @@ from torch.utils.data import Dataset from torchvision.datasets import ImageFolder from torchvision.datasets.folder import default_loader as pil_loader +from tqdm import tqdm from .errors import DatasetNotFoundError from .utils import ( @@ -125,7 +130,7 @@ def __init__( self.index = Index(interleaved=False, properties=Property(dimension=3)) @abc.abstractmethod - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: + def __getitem__(self, query: Iterable[BoundingBox]) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. Args: @@ -377,6 +382,9 @@ class RasterDataset(GeoDataset): #: Color map for the dataset, used for plotting cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {} + #: Nodata value for the dataset + nodata_value: int | None = None + @property def dtype(self) -> torch.dtype: """The dtype of the dataset (overrides the dtype of the data file via a cast). @@ -420,6 +428,7 @@ def __init__( bands: Sequence[str] | None = None, transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, cache: bool = True, + drop_nodata: bool = True, ) -> None: """Initialize a new RasterDataset instance. @@ -433,6 +442,7 @@ def __init__( transforms: a function/transform that takes an input sample and returns a transformed version cache: if True, cache file handle to speed up repeated sampling + drop_nodata: Drop the sample if any pixel contains nodata value. Raises: DatasetNotFoundError: If dataset is not found. @@ -445,50 +455,10 @@ def __init__( self.paths = paths self.bands = bands or self.all_bands self.cache = cache + self.drop_nodata = drop_nodata - # Populate the dataset index - i = 0 - filename_regex = re.compile(self.filename_regex, re.VERBOSE) - for filepath in self.files: - match = re.match(filename_regex, os.path.basename(filepath)) - if match is not None: - try: - with rasterio.open(filepath) as src: - # See if file has a color map - if len(self.cmap) == 0: - try: - self.cmap = src.colormap(1) # type: ignore[misc] - except ValueError: - pass - - if crs is None: - crs = src.crs - - with WarpedVRT(src, crs=crs) as vrt: - minx, miny, maxx, maxy = vrt.bounds - if res is None: - res = vrt.res[0] - except rasterio.errors.RasterioIOError: - # Skip files that rasterio is unable to read - continue - else: - mint = self.mint - maxt = self.maxt - if 'date' in match.groupdict(): - date = match.group('date') - mint, maxt = disambiguate_timestamp(date, self.date_format) - elif 'start' in match.groupdict() and 'stop' in match.groupdict(): - start = match.group('start') - stop = match.group('stop') - mint, _ = disambiguate_timestamp(start, self.date_format) - _, maxt = disambiguate_timestamp(stop, self.date_format) - - coords = (minx, maxx, miny, maxy, mint, maxt) - self.index.insert(i, coords, filepath) - i += 1 - - if i == 0: - raise DatasetNotFoundError(self) + crs, res = self.try_set_metadata(crs, res) + self._populate_index(crs) if not self.separate_files: self.band_indexes = None @@ -505,50 +475,255 @@ def __init__( raise AssertionError(msg) self._crs = cast(CRS, crs) - self._res = cast(float, res) + self._res = res - def __getitem__(self, query: BoundingBox) -> dict[str, Any]: - """Retrieve image/mask and metadata indexed by query. + def try_set_metadata(self, crs: CRS, res: float | None) -> tuple[CRS, float]: + """Try to set the CRS, resolution and cmap from the first file in the dataset. Args: - query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + crs: The coordinate reference system (CRS) to use. + res: The resolution of the dataset in units of CRS. Returns: - sample of image/mask and metadata at that index + tuple: The coordinate reference system (CRS) and resolution of the dataset. + """ + with rasterio.open(self.files[0]) as src: + # See if file has a color map + if len(self.cmap) == 0: + try: + self.cmap = src.colormap(1) # type: ignore[misc] + except ValueError: + pass + + if crs is None: + crs = src.crs + if self.nodata_value is None: + src_nodata_value = src.nodata + if src_nodata_value is not None: + self.nodata_value = src_nodata_value + elif self.drop_nodata: + raise ValueError( + 'drop_nodata is True but nodata is not set in the dataset and could not be read from the file.' + ) - Raises: - IndexError: if query is not found in the index + with WarpedVRT(src, crs=crs) as vrt: + if res is None: + res = vrt.res[0] + return crs, res + + def _get_bounds(self, filepath: str, crs: CRS) -> tuple[tuple[float], str]: + """Retrieves the bounds for a given file path and coordinate reference system (CRS). + + Args: + filepath (str): The path to the file. + crs (str): The coordinate reference system (CRS) to use. + + Returns: + tuple[tuple[float], str]: A tuple containing the bbox coordinates and the filepath. + The bbox coordinates are represented as a tuple of floats in the following order: + (minx, maxx, miny, maxy, mint, maxt). + """ + filename_regex = re.compile(self.filename_regex, re.VERBOSE) + + try: + with rasterio.open(filepath) as src: + with WarpedVRT(src, crs=crs) as vrt: + minx, miny, maxx, maxy = vrt.bounds + match = re.match(filename_regex, os.path.basename(filepath)) + if not match: + raise ValueError(f'No match found for {os.path.basename(filepath)}') + except rasterio.errors.RasterioIOError as e: + raise FileNotFoundError(f'Error reading {filepath}') from e + else: + mint = self.mint + maxt = self.maxt + if 'date' in match.groupdict(): + date = match.group('date') + mint, maxt = disambiguate_timestamp(date, self.date_format) + elif 'start' in match.groupdict() and 'stop' in match.groupdict(): + start = match.group('start') + stop = match.group('stop') + mint, _ = disambiguate_timestamp(start, self.date_format) + _, maxt = disambiguate_timestamp(stop, self.date_format) + else: + # TODO: Optionally, revert to no_date option if date is not found + pass + + bbox = ( + float(minx), + float(maxx), + float(miny), + float(maxy), + float(mint), + float(maxt), + ) + return bbox, filepath + + def _compile_and_check_filename_regex(self) -> Pattern: + """Compiles and checks the filename whether a valid regex pattern is supplied. + + Returns: + Pattern: The compiled regex pattern. + + """ + if 'band' not in self.filename_regex and self.separate_files: + raise ValueError( + 'The term `band` is not in the filename_regex, but separate_files=True. At least provide a regex pattern to distinguish bands.' + ) + return re.compile(self.filename_regex, re.VERBOSE) + + def _populate_index(self, crs: CRS) -> None: + """Populates the dataset index by retrieving index parameters for each filepath in the dataset paths. + + This method uses a ThreadPoolExecutor to concurrently retrieve index parameters for each file that matches the + filename regex. The retrieved parameters are then inserted into the dataset index. + + Args: + crs (str): The coordinate reference system used for warping while opening the file. + + Returns: + None + """ + print('Populating index') + filename_regex = self._compile_and_check_filename_regex() + + # Populate the dataset index + def has_match(filepath: str) -> bool: + """Check if the given filepath matches the specified filename regex and its band is included in `self.bands`. + + Args: + filepath (str): The path to the file to be checked. + + Returns: + bool: True if the filepath matches the filename regex and conditions, False otherwise. + """ + match = re.match(filename_regex, os.path.basename(filepath)) + if match is not None: + if self.separate_files: + return match.group('band') in self.bands + else: + return True + else: + return False + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [ + executor.submit(self._get_bounds, filepath, crs) + for filepath in self.files + if has_match(filepath) + ] + i = 0 + for f in tqdm(as_completed(futures), total=len(futures)): + bbox, filepath = f.result() + self.index.insert(i, bbox, filepath) + i += 1 + + # TODO: Sequential version: choose which to use. + # i = 0 + # for filepath in self.files: + # if has_match(filepath): + # bbox, filepath = self._get_bounds(filepath, crs) + # self.index.insert(i, bbox, filepath) + # i += 1 + + if i == 0: + raise DatasetNotFoundError(self) + + def _get_regex_groups_as_df(self, filepaths: list[str]) -> pd.DataFrame: + """Extracts the regex metadata from a list of filepaths. + + Args: + filepaths (list): A list of filepaths. + + Returns: + pandas.DataFrame: A DataFrame containing the extracted file metadata. + """ + filename_regex = re.compile(self.filename_regex, re.VERBOSE) + file_metadata = [] + for filepath in filepaths: + match = re.match(filename_regex, os.path.basename(filepath)) + if match: + meta = match.groupdict() + meta.update({'filepath': filepath}) + file_metadata.append(meta) + + return pd.DataFrame(file_metadata) + + def __merge_single_bbox( + self, query: BoundingBox + ) -> tuple[torch.Tensor | None, list[str]]: + """Merge all files that intersect with a single bounding box. + + Args: + query: (BoundingBox) Bounds of the query + + Returns: + tuple[torch.Tensor, list[str]]: A tuple containing the merged tensor and the list of dates that produced that tensor. """ hits = self.index.intersection(tuple(query), objects=True) filepaths = cast(list[str], [hit.object for hit in hits]) - if not filepaths: raise IndexError( f'query: {query} not found in index with bounds: {self.bounds}' ) + file_df = self._get_regex_groups_as_df(filepaths) + if self.separate_files: - data_list: list[Tensor] = [] - filename_regex = re.compile(self.filename_regex, re.VERBOSE) - for band in self.bands: - band_filepaths = [] - for filepath in filepaths: - filename = os.path.basename(filepath) - directory = os.path.dirname(filepath) - match = re.match(filename_regex, filename) - if match: - if 'band' in match.groupdict(): - start = match.start('band') - end = match.end('band') - filename = filename[:start] + band + filename[end:] - filepath = os.path.join(directory, filename) - band_filepaths.append(filepath) - data_list.append(self._merge_files(band_filepaths, query)) - data = torch.cat(data_list) + grouped = file_df.groupby(['band']).agg(list) + res_for_bbox = [] + for band, filepaths in grouped.sort_values('band')['filepath'].items(): + single_bbox_single_band = self._merge_files(filepaths, query) + res_for_bbox.append(single_bbox_single_band) + + res_single_bbox = ( + torch.cat(res_for_bbox).unsqueeze(0) + if len(res_for_bbox) == len(self.bands) + else None + ) else: - data = self._merge_files(filepaths, query, self.band_indexes) + res_single_bbox = self._merge_files(filepaths, query, self.band_indexes) + # TODO ideally, we want feedback from rasterio.merge.merge to know which dates were merged + dates = file_df['date'].unique().tolist() # TODO what if no dates? + dates = [datetime.strptime(date, self.date_format) for date in dates] + + if res_single_bbox is not None: + # Check if res_single_date contains nodata values and only append if it doesn't + if not self.drop_nodata or not torch.any( + res_single_bbox == self.nodata_value + ): + return res_single_bbox, dates + return None, [] + + def __merge_query( + self, query: Iterable[BoundingBox] + ) -> tuple[torch.Tensor, list[str]]: + res = [] + valid_dates = [] + for bbox in query: + res_single_bbox, dates = self.__merge_single_bbox(bbox) + if res_single_bbox is not None: + res.append(res_single_bbox) + valid_dates.append(dates) + return torch.cat(res), valid_dates + + def __getitem__(self, query: BoundingBox | Iterable[BoundingBox]) -> dict[str, Any]: + """Retrieve image/mask and metadata indexed by query. + + Args: + query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + + Returns: + sample of image/mask and metadata at that index + + Raises: + IndexError: if query is not found in the index + """ + if isinstance(query, BoundingBox): + query = [query] + data, valid_dates = self.__merge_query(query) - sample = {'crs': self.crs, 'bounds': query} + sample = {'crs': self.crs, 'bounds': query, 'dates': valid_dates} data = data.to(self.dtype) if self.is_image: diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py index 79637931adb..b7d09b891a5 100644 --- a/torchgeo/datasets/sentinel.py +++ b/torchgeo/datasets/sentinel.py @@ -266,7 +266,7 @@ class Sentinel2(Sentinel): # https://sentinels.copernicus.eu/web/sentinel/user-guides/sentinel-2-msi/naming-convention # https://sentinel.esa.int/documents/247904/685211/Sentinel-2-MSI-L2A-Product-Format-Specifications.pdf - filename_glob = 'T*_*_{}*.*' + filename_glob = 'T*_*_*.*' filename_regex = r""" ^T(?P\d{{2}}[A-Z]{{3}}) _(?P\d{{8}}T\d{{6}}) @@ -295,6 +295,7 @@ class Sentinel2(Sentinel): rgb_bands = ('B04', 'B03', 'B02') separate_files = True + nodata_value = 0 def __init__( self, @@ -325,7 +326,7 @@ def __init__( *root* was renamed to *paths* """ bands = bands or self.all_bands - self.filename_glob = self.filename_glob.format(bands[0]) + # self.filename_glob = self.filename_glob.format(bands[0]) self.filename_regex = self.filename_regex.format(res) super().__init__(paths, crs, res, bands, transforms, cache)