diff --git a/openeo/metadata.py b/openeo/metadata.py index 31d97513f..af15a913d 100644 --- a/openeo/metadata.py +++ b/openeo/metadata.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import pystac import warnings from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union @@ -522,3 +523,63 @@ def _repr_html_(self): def __str__(self) -> str: bands = self.band_names if self.has_band_dimension() else "no bands dimension" return f"CollectionMetadata({self.extent} - {bands} - {self.dimension_names()})" + + +def metadata_from_stac(url: str) -> CubeMetadata: + """ + Reads the band metadata a static STAC catalog or a STAC API Collection and returns it as a :py:class:`CubeMetadata` + + :param url: The URL to a static STAC catalog (STAC Item, STAC Collection, or STAC Catalog) or a specific STAC API Collection + :return: A :py:class:`CubeMetadata` containing the DataCube band metadata from the url. + """ + + def get_band_metadata(eo_bands_location: dict) -> List[Band]: + return [ + Band(name=band["name"], common_name=band.get("common_name"), wavelength_um=band.get("center_wavelength")) + for band in eo_bands_location.get("eo:bands", []) + ] + + def get_band_names(bands: List[Band]) -> List[str]: + return [band.name for band in bands] + + def is_band_asset(asset: pystac.Asset) -> bool: + return "eo:bands" in asset.extra_fields + + stac_object = pystac.read_file(href=url) + + bands = [] + collection = None + + if isinstance(stac_object, pystac.Item): + item = stac_object + if "eo:bands" in item.properties: + eo_bands_location = item.properties + elif item.get_collection() is not None: + collection = item.get_collection() + eo_bands_location = item.get_collection().summaries.lists + else: + eo_bands_location = {} + bands = get_band_metadata(eo_bands_location) + + elif isinstance(stac_object, pystac.Collection): + collection = stac_object + bands = get_band_metadata(collection.summaries.lists) + + # Summaries is not a required field in a STAC collection, so also check the assets + for itm in collection.get_items(): + band_assets = {asset_id: asset for asset_id, asset in itm.get_assets().items() if is_band_asset(asset)} + + for asset in band_assets.values(): + asset_bands = get_band_metadata(asset.extra_fields) + for asset_band in asset_bands: + if asset_band.name not in get_band_names(bands): + bands.append(asset_band) + + else: + assert isinstance(stac_object, pystac.Catalog) + catalog = stac_object + bands = get_band_metadata(catalog.extra_fields.get("summaries", {})) + + band_dimension = BandDimension(name="bands", bands=bands) + metadata = CubeMetadata(dimensions=[band_dimension]) + return metadata diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 8737855df..7840e1f66 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -27,7 +27,14 @@ from openeo.internal.jupyter import VisualDict, VisualList from openeo.internal.processes.builder import ProcessBuilderBase from openeo.internal.warnings import deprecated, legacy_alias -from openeo.metadata import Band, BandDimension, CollectionMetadata, SpatialDimension, TemporalDimension +from openeo.metadata import ( + Band, + BandDimension, + CollectionMetadata, + SpatialDimension, + TemporalDimension, + metadata_from_stac, +) from openeo.rest import ( CapabilitiesException, OpenEoApiError, @@ -1361,6 +1368,10 @@ def load_stac( prop: build_child_callback(pred, parent_parameters=["value"]) for prop, pred in properties.items() } cube = self.datacube_from_process(process_id="load_stac", **arguments) + try: + cube.metadata = metadata_from_stac(url) + except Exception: + _log.warning(f"Failed to extract cube metadata from STAC URL {url}", exc_info=True) return cube def load_ml_model(self, id: Union[str, BatchJob]) -> MlModel: diff --git a/setup.py b/setup.py index 73f7ef8c8..b1810709f 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ "numpy>=1.17.0", "xarray>=0.12.3", "pandas>0.20.0", + "pystac", "deprecated>=1.2.12", 'oschmod>=0.3.12; sys_platform == "win32"', "importlib_resources; python_version<'3.9'", diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 28ac81dc7..9a23fd9f0 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -2,6 +2,7 @@ from typing import List +import json import pytest from openeo.metadata import ( @@ -14,6 +15,7 @@ MetadataException, SpatialDimension, TemporalDimension, + metadata_from_stac, ) @@ -782,3 +784,57 @@ def filter_bbox(self, bbox): assert isinstance(new, MyCubeMetadata) assert orig.bbox is None assert new.bbox == (1, 2, 3, 4) + + +@pytest.mark.parametrize( + "test_stac, expected", + [ + ( + { + "type": "Collection", + "id": "test-collection", + "stac_version": "1.0.0", + "description": "Test collection", + "links": [], + "title": "Test Collection", + "extent": { + "spatial": {"bbox": [[-180.0, -90.0, 180.0, 90.0]]}, + "temporal": {"interval": [["2020-01-01T00:00:00Z", "2020-01-10T00:00:00Z"]]}, + }, + "license": "proprietary", + "summaries": {"eo:bands": [{"name": "B01"}, {"name": "B02"}]}, + }, + ["B01", "B02"], + ), + ( + { + "type": "Catalog", + "id": "test-catalog", + "stac_version": "1.0.0", + "description": "Test Catalog", + "links": [], + }, + [], + ), + ( + { + "type": "Feature", + "stac_version": "1.0.0", + "id": "test-item", + "properties": {"datetime": "2020-05-22T00:00:00Z", "eo:bands": [{"name": "SCL"}, {"name": "B08"}]}, + "geometry": {"coordinates": [[[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]], "type": "Polygon"}, + "links": [], + "assets": {}, + "bbox": [0, 1, 0, 1], + "stac_extensions": [], + }, + ["SCL", "B08"], + ), + ], +) +def test_metadata_from_stac(tmp_path, test_stac, expected): + + path = tmp_path / "stac.json" + path.write_text(json.dumps(test_stac)) + metadata = metadata_from_stac(path) + assert metadata.band_names == expected