diff --git a/docs/index.rst b/docs/index.rst index ced959493a8..22c203ce9df 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,6 +30,7 @@ torchgeo tutorials/getting_started tutorials/custom_raster_dataset + tutorials/sits_dataset tutorials/transforms tutorials/indices tutorials/trainers diff --git a/docs/tutorials/sits_dataset.ipynb b/docs/tutorials/sits_dataset.ipynb new file mode 100644 index 00000000000..874e9bf7567 --- /dev/null +++ b/docs/tutorials/sits_dataset.ipynb @@ -0,0 +1,422 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to use a SITS dataset\n", + "\n", + "This notebook shows how to create a Satellite Image Time Series (SITS) dataset. The idea is that with this type of sampling, the sample image that is being returned has the shape `[batch, dates, channels, height, width]`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Additional requirements" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install plotly planetary_computer pystac_client tqdm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare our data\n", + "\n", + "For this example we are using the RGB bands of the same Sentinel2 tile for 5 different dates. Note that right now we are selecting a specific orbit to ensure that our data covers the same spatial extent." 
# Search the Planetary Computer STAC API for the scenes used in this tutorial
# and download the bands needed for the RGB examples.
import os
import tempfile
from urllib.parse import urlparse

import planetary_computer
import pystac_client

from torchgeo.datasets.utils import download_url

# Bands fetched for every matching scene.
BANDS_TO_DOWNLOAD = ('B02', 'B03', 'B04')

# The `modifier` is applied by the client to the items it returns, so the
# asset hrefs from the search below should already be signed and can be
# downloaded directly (NOTE(review): confirm against the pystac-client docs).
catalog = pystac_client.Client.open(
    'https://planetarycomputer.microsoft.com/api/stac/v1',
    modifier=planetary_computer.sign_inplace,
)
area_of_interest = {
    'type': 'Polygon',
    'coordinates': [
        [
            [-148.56536865234375, 60.80072385643073],
            [-147.44338989257812, 60.80072385643073],
            [-147.44338989257812, 61.18363894915102],
            [-148.56536865234375, 61.18363894915102],
            [-148.56536865234375, 60.80072385643073],
        ]
    ],
}
time_of_interest = '2019-06-01/2019-10-01'
search = catalog.search(
    collections=['sentinel-2-l2a'],
    intersects=area_of_interest,
    datetime=time_of_interest,
    query={'eo:cloud_cover': {'lt': 13}, 'sat:relative_orbit': {'eq': 143}},
)

# Check how many items were returned
items = search.item_collection()
print(f'{len(items)} items found')

# Download data
root = os.path.join(tempfile.gettempdir(), 'sentinel')

# Iterate the search results directly instead of re-fetching each item through
# `item.links[3]`, which silently depends on the order of the STAC links.
for item in items:
    for band in BANDS_TO_DOWNLOAD:
        asset_href = item.assets[band].href
        # Keep only the file name; the query string carries the access token.
        filename = os.path.basename(urlparse(asset_href).path)
        download_url(asset_href, root, filename)
class Sentinel2(RasterDataset):
    """Sentinel-2 dataset that returns a single image per query.

    Each band is stored in its own file (``separate_files = True``) and
    ``return_as_ts = False`` keeps the usual output without a time dimension.
    """

    filename_glob = 'T*.tif'
    # The named groups ``date`` and ``band`` are what the dataset extracts
    # from each file name; ``date`` is parsed with ``date_format``.
    filename_regex = r'.*(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
    date_format = '%Y%m%dT%H%M%S'
    is_image = True
    separate_files = True
    return_as_ts = False
    # Only these three bands are downloaded in this tutorial.
    all_bands = ('B02', 'B03', 'B04')
    rgb_bands = ('B04', 'B03', 'B02')

    def plot(self, sample, show=True):
        """Plot the RGB bands of a sample.

        Args:
            sample: A dictionary with an ``'image'`` tensor that is indexed
                by band along its first dimension.
            show: Whether to display the figure immediately.

        Returns:
            The plotly figure.
        """
        # Find the correct band index order
        rgb_indices = [self.all_bands.index(band) for band in self.rgb_bands]

        # Reorder to channels-last and rescale into [0, 1] for display
        image = sample['image'][rgb_indices].permute(1, 2, 0)
        image = torch.clamp(image / 10000, min=0, max=1).numpy()

        # Plot the image
        fig = px.imshow(image)
        if show:
            fig.show()
        return fig
class Sentinel2SITS(RasterDataset):
    """Sentinel-2 dataset that returns a satellite image time series.

    Identical to a regular per-band raster dataset except that
    ``return_as_ts = True``, so samples carry a leading date dimension and a
    ``'dates'`` entry.
    """

    filename_glob = 'T*.tif'
    # The named groups ``date`` and ``band`` are what the dataset extracts
    # from each file name; ``date`` is parsed with ``date_format``.
    filename_regex = r'.*(?P<date>\d{8}T\d{6})_(?P<band>B0[\d])'
    date_format = '%Y%m%dT%H%M%S'
    is_image = True
    separate_files = True
    all_bands = ('B02', 'B03', 'B04')
    rgb_bands = ('B04', 'B03', 'B02')
    return_as_ts = True

    def plot(self, sample, indices_to_plot=None, show=False, **kwargs):
        """Plot the image data from the given sample.

        Args:
            sample (dict): A dictionary containing the image data returned by
                ``self.__getitem__``. When ``return_as_ts`` is True it must
                also contain ``'dates'``.
            indices_to_plot (list, optional): Band indices to plot. If not
                provided, the RGB bands defined in ``self.rgb_bands`` are used
                when all bands are loaded, otherwise the first band.
            show (bool, optional): Whether to display the plot. Defaults to
                False.
            **kwargs: Additional keyword arguments passed to
                ``plotly.express.imshow``.

        Returns:
            fig: The plotly figure object.
        """
        if indices_to_plot:
            indices = indices_to_plot
        elif self.bands == self.all_bands:
            # Find the correct band index order
            indices = [self.all_bands.index(band) for band in self.rgb_bands]
        else:
            logging.info('No indices to plot provided, using first band by default')
            indices = [0]

        logging.info('Plotting bands: %s', [self.bands[i] for i in indices])

        image = sample['image']

        # Reorder to channels-last
        if self.return_as_ts:
            # Shape of image = [d, c, h, w] -> [d, h, w, c]
            image = image[:, indices, :, :].permute(0, 2, 3, 1)
        else:
            # Shape of image = [c, h, w] -> [h, w, c]
            image = image[indices].permute(1, 2, 0)

        # A single selected band is plotted as a 2D image, so drop the
        # trailing channel axis in both modes.
        if image.shape[-1] == 1:
            image = image.squeeze(-1)
        image = torch.clamp(image / 5000, min=0, max=1).numpy()

        if self.return_as_ts:
            # One animation frame per date
            fig = px.imshow(
                image, animation_frame=0, labels={'animation_frame': 'Date'}, **kwargs
            )
            date_labels = [
                date.strftime('%m/%d/%Y, %H:%M:%S') for date in sample['dates']
            ]
            # Replace the default frame numbers on the slider with the dates
            for i, label in enumerate(date_labels):
                fig.layout.sliders[0].steps[i].label = label
        else:
            fig = px.imshow(image, **kwargs)

        fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
        if show:
            fig.show()
        return fig
Sentinel2SITS(root)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Instantiate sampler\n", + "We are again using a random sampler. However, since we are sampling SITS, the random sampler returns for every random spatial location the full SITS series." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "sits_sampler = RandomGeoSampler(sits_dataset, size=(2000, 2000), length=2)\n", + "\n", + "sits_dataloader = DataLoader(\n", + " sits_dataset,\n", + " sampler=sits_sampler,\n", + " batch_size=1,\n", + " collate_fn=stack_samples,\n", + " num_workers=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for batch in sits_dataloader:\n", + " for sample in unbind_samples(batch):\n", + " sits_dataset.plot(sample, show=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Accessing the dates of each sample" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sample.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sample['image'].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sample['dates']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sample['crs']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "biovers_bmi", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": 
"ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 8233480443a..53219516e6c 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -12,11 +12,13 @@ import sys import warnings from collections.abc import Callable, Iterable, Sequence +from datetime import datetime from typing import Any, ClassVar, cast import fiona import fiona.transform import numpy as np +import pandas as pd import pyproj import rasterio import rasterio.merge @@ -97,6 +99,9 @@ class GeoDataset(Dataset[dict[str, Any]], abc.ABC): #: a different file format than what it was originally downloaded as. filename_glob = '*' + # Whether to return the dataset as a Timeseries, this will add another dimension to the dataset + return_as_ts = False + # NOTE: according to the Python docs: # # * https://docs.python.org/3/library/exceptions.html#NotImplementedError @@ -368,6 +373,9 @@ class RasterDataset(GeoDataset): #: True if data is stored in a separate file for each band, else False. separate_files = False + # Whether to return imagery as SITS or not + return_as_ts: bool = False + #: Names of all available bands in the dataset all_bands: tuple[str, ...] = () @@ -482,6 +490,10 @@ def __init__( stop = match.group('stop') mint, _ = disambiguate_timestamp(start, self.date_format) _, maxt = disambiguate_timestamp(stop, self.date_format) + elif self.return_as_ts: + warnings.warn( + 'return_as_ts = True, but no date could be found from filename_regex' + ) coords = (minx, maxx, miny, maxy, mint, maxt) self.index.insert(i, coords, filepath) @@ -507,6 +519,26 @@ def __init__( self._crs = cast(CRS, crs) self._res = cast(float, res) + def _get_regex_groups_as_df(self, filepaths: list[str]) -> pd.DataFrame: + """Extracts the regex metadata from a list of filepaths. + + Args: + filepaths (list): A list of filepaths. + + Returns: + pandas.DataFrame: A DataFrame containing the extracted file metadata. 
+ """ + filename_regex = re.compile(self.filename_regex, re.VERBOSE) + file_metadata = [] + for filepath in filepaths: + match = re.match(filename_regex, filepath) + if match: + meta = match.groupdict() + meta.update({'filepath': filepath}) + file_metadata.append(meta) + + return pd.DataFrame(file_metadata) + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve image/mask and metadata indexed by query. @@ -527,28 +559,56 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: f'query: {query} not found in index with bounds: {self.bounds}' ) + self.file_df = self._get_regex_groups_as_df(filepaths) + + res = [] if self.separate_files: - data_list: list[Tensor] = [] - filename_regex = re.compile(self.filename_regex, re.VERBOSE) - for band in self.bands: - band_filepaths = [] - for filepath in filepaths: - filename = os.path.basename(filepath) - directory = os.path.dirname(filepath) - match = re.match(filename_regex, filename) - if match: - if 'band' in match.groupdict(): - start = match.start('band') - end = match.end('band') - filename = filename[:start] + band + filename[end:] - filepath = os.path.join(directory, filename) - band_filepaths.append(filepath) - data_list.append(self._merge_files(band_filepaths, query)) - data = torch.cat(data_list) + if self.return_as_ts: + grouped = self.file_df.groupby(['date', 'band']).agg(list) + all_dates = np.sort(self.file_df.date.unique()) + for date in all_dates: + res_per_date = [ + self._merge_files(filepaths, query) + for band, filepaths in grouped.sort_values('band') + .loc[date, 'filepath'] + .items() + ] + res.append( + torch.cat(res_per_date).unsqueeze(0) + if len(res_per_date) > 0 + else torch.tensor([]) + ) + data = torch.cat(res) if len(res) > 0 else torch.tensor([]) + else: + res_per_band = [ + self._merge_files(filepaths, query) + for band, filepaths in self.file_df.groupby(['band']) + .agg(list) + .loc[:, 'filepath'] + .items() + ] + data = ( + torch.cat(res_per_band) + if 
len(res_per_band) > 0 + else torch.tensor([]) + ) else: - data = self._merge_files(filepaths, query, self.band_indexes) + if self.return_as_ts: + all_dates = self.file_df.date.unique() + grouped = self.file_df.groupby(['date']).agg(list) + res = [ + self._merge_files(grouped.loc[date, 'filepath'], query).unsqueeze(0) + for date in all_dates + ] + data = torch.cat(res) if len(res) > 0 else torch.tensor([]) + else: + data = self._merge_files(filepaths, query, self.band_indexes) sample = {'crs': self.crs, 'bounds': query} + if self.return_as_ts: + sample['dates'] = [ + datetime.strptime(date, self.date_format) for date in all_dates + ] data = data.to(self.dtype) if self.is_image: @@ -982,6 +1042,7 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError('IntersectionDataset only supports GeoDatasets') + self.return_as_ts = dataset1.return_as_ts or dataset2.return_as_ts self.crs = dataset1.crs self.res = dataset1.res @@ -1142,6 +1203,7 @@ def __init__( if not isinstance(ds, GeoDataset): raise ValueError('UnionDataset only supports GeoDatasets') + self.return_as_ts = dataset1.return_as_ts and dataset2.return_as_ts self.crs = dataset1.crs self.res = dataset1.res diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index ea943db3d53..2f62ea5581e 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -33,6 +33,7 @@ def __init__(self, dataset: GeoDataset, roi: BoundingBox | None = None) -> None: roi: region of interest to sample from (minx, maxx, miny, maxy, mint, maxt) (defaults to the bounds of ``dataset.index``) """ + self.dataset = dataset if roi is None: self.index = dataset.index roi = BoundingBox(*self.index.bounds) @@ -144,7 +145,15 @@ def __iter__(self) -> Iterator[BoundingBox]: # Choose a random tile, weighted by area idx = torch.multinomial(self.areas, 1) hit = self.hits[idx] - bounds = BoundingBox(*hit.bounds) + + bounds = hit.bounds + if self.dataset.return_as_ts: + mint = self.index.bounds.mint + maxt = 
self.index.bounds.maxt + bounds[-2] = mint + bounds[-1] = maxt + + bounds = BoundingBox(*bounds) # Choose a random index within that tile bounding_box = get_random_bounding_box( @@ -238,8 +247,13 @@ def __iter__(self) -> Iterator[BoundingBox]: for hit in self.hits: bounds = BoundingBox(*hit.bounds) rows, cols = tile_to_chips(bounds, self.size, self.stride) - mint = bounds.mint - maxt = bounds.maxt + + if self.dataset.return_as_ts: + mint = self.index.bounds.mint + maxt = self.index.bounds.maxt + else: + mint = bounds.mint + maxt = bounds.maxt # For each row... for i in range(rows): @@ -314,7 +328,18 @@ def __iter__(self) -> Iterator[BoundingBox]: generator = partial(torch.randperm, generator=self.generator) for idx in generator(len(self)): - yield BoundingBox(*self.hits[idx].bounds) + bounding_box = self.hits[idx].bounds + + if self.dataset.return_as_ts: + mint = self.index.bounds.mint + maxt = self.index.bounds.maxt + else: + mint = bounding_box.mint + maxt = bounding_box.maxt + + bounding_box[-2] = mint + bounding_box[-1] = maxt + yield BoundingBox(*bounding_box) def __len__(self) -> int: """Return the number of samples over the ROI.