diff --git a/api/api/routers/datasets/download_dataset.py b/api/api/routers/datasets/download_dataset.py
index bc4ebe1d..353c783e 100755
--- a/api/api/routers/datasets/download_dataset.py
+++ b/api/api/routers/datasets/download_dataset.py
@@ -5,7 +5,10 @@
 from ..auth import get_current_user
 from ...src.models import User
-from ...src.usecases.datasets import download_dataset_file, download_stac_catalog, generate_presigned_url
+from ...src.usecases.datasets import (
+    download_dataset_file,
+    download_stac_catalog,
+)
 from .responses import download_dataset_responses as responses
 
 router = APIRouter()
@@ -57,23 +60,3 @@ async def download_stac_Catalog(
     except Exception as e:
         logger.exception("datasets:download")
         raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
-
-
-# @router.get(
-#     "/{dataset_id}/url/{filename:path}",
-#     summary="Retrieve a presigend get url",
-#     responses=responses,
-# )
-# async def generate_a_presigned_url(
-#     dataset_id: str = Path(..., description="ID of the dataset to download"),
-#     filename: str = Path(
-#         ..., description="Filename or path to the file to download from the dataset"
-#     ),  # podría ser un path... a/b/c/file.txt
-#     version: int = Query(None, description="Version of the dataset to download"),
-#     user: User = Depends(get_current_user),
-# ):
-#     try:
-#         return generate_presigned_url(dataset_id, filename, version)
-#     except Exception as e:
-#         logger.exception("datasets:download")
-#         raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
\ No newline at end of file
diff --git a/api/api/routers/models/create_model.py b/api/api/routers/models/create_model.py
index 0764ed4c..45f75545 100755
--- a/api/api/routers/models/create_model.py
+++ b/api/api/routers/models/create_model.py
@@ -6,7 +6,7 @@
 from ..auth import get_current_user
 from ...src.models import User
-from ...src.usecases.models import create_model, create_model_version
+from ...src.usecases.models import create_model, create_model_version, create_stac_model
 from .responses import create_model_responses, version_model_responses
 
 router = APIRouter()
@@ -42,9 +42,16 @@ def create(
         raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
 
 
-@router.post("/version/{model_id}", summary="Get the version of a model", responses=version_model_responses)
-def version_model(model_id: str = Path(..., description="The ID of the model"),
-                  user: User = Depends(get_current_user)):
+@router.post(
+    "/version/{model_id}",
+    summary="Create a new version of a model",
+    responses=version_model_responses,
+)
+def version_model(
+    model_id: str = Path(..., description="The ID of the model"),
+    user: User = Depends(get_current_user),
+):
     """
     Get the version of a model. 
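+    This creates a new version entry for the model and returns it.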
""" @@ -54,3 +60,20 @@ def version_model(model_id: str = Path(..., description="The ID of the model"), except Exception as e: logger.exception("models:version") raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + + +class CreateSTACModelBody(BaseModel): + name: str + + +@router.post("/stac", summary="Create a new stac model") +def create_stac( + body: CreateSTACModelBody, + user: User = Depends(get_current_user), +): + try: + model_id = create_stac_model(user, body.name) + return {"model_id": model_id} + except Exception as e: + logger.exception("datasets:ingest") + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) diff --git a/api/api/routers/models/download_model.py b/api/api/routers/models/download_model.py index c3d32a5f..5edd5666 100755 --- a/api/api/routers/models/download_model.py +++ b/api/api/routers/models/download_model.py @@ -5,49 +5,55 @@ from ..auth import get_current_user from ...src.models import User -from ...src.usecases.models import download_model_file +from ...src.usecases.models import download_model_file, download_stac_catalog from .responses import download_model_responses router = APIRouter() logger = logging.getLogger(__name__) -@router.get("/{model_id}/download/{filename:path}", summary="Download a model", responses=download_model_responses) +@router.get( + "/{model_id}/download/{filename:path}", + summary="Download a model", + responses=download_model_responses, +) async def download_model( model_id: str = Path(..., description="ID of the model to download"), - filename: str = Path(..., description="Filename or path to the file to download from the model"), # podría ser un path... a/b/c/file.txt + filename: str = Path( + ..., description="Filename or path to the file to download from the model" + ), # podría ser un path... a/b/c/file.txt version: int = Query(None, description="Version of the model to download"), user: User = Depends(get_current_user), ): """ Download an entire model or a specific model file from the EOTDL. 
""" + # try: + data_stream, object_info, _filename = download_model_file( + model_id, filename, user, version + ) + response_headers = { + "Content-Disposition": f'attachment; filename="{filename}"', + "Content-Type": object_info.content_type, + "Content-Length": str(object_info.size), + } + return StreamingResponse( + data_stream(model_id, _filename), + headers=response_headers, + media_type=object_info.content_type, + ) + # except Exception as e: + # logger.exception("models:download") + # raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + + +@router.get("/{model_id}/download") +async def download_stac_Catalog( + model_id: str, + user: User = Depends(get_current_user), +): try: - data_stream, object_info, _filename = download_model_file( - model_id, filename, user, version - ) - response_headers = { - "Content-Disposition": f'attachment; filename="{filename}"', - "Content-Type": object_info.content_type, - "Content-Length": str(object_info.size), - } - return StreamingResponse( - data_stream(model_id, _filename), - headers=response_headers, - media_type=object_info.content_type, - ) + return download_stac_catalog(model_id, user) except Exception as e: logger.exception("models:download") raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) - - -# @router.get("/{model_id}/download") -# async def download_stac_Catalog( -# model_id: str, -# user: User = Depends(get_current_user), -# ): -# try: -# return download_stac_catalog(model_id, user) -# except Exception as e: -# logger.exception("models:download") -# raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) diff --git a/api/api/routers/models/ingest_model.py b/api/api/routers/models/ingest_model.py index 55067015..622517d8 100755 --- a/api/api/routers/models/ingest_model.py +++ b/api/api/routers/models/ingest_model.py @@ -2,12 +2,15 @@ from fastapi import APIRouter, status, Depends, File, Form, UploadFile, Path, Query import logging from typing import List +from pydantic import BaseModel from ..auth import get_current_user from ...src.models import User from ...src.usecases.models import ( ingest_model_files_batch, add_files_batch_to_model_version, + ingest_stac, + ingest_model_file, ) from .responses import ingest_files_responses @@ -15,6 +18,39 @@ logger = logging.getLogger(__name__) +@router.post( + "/{model_id}", + summary="Ingest file to a model", + responses=ingest_files_responses, +) +async def ingest_files( + model_id: str = Path(..., description="ID of the model"), + version: int = Query(None, description="Version of the dataset"), + file: UploadFile = File(..., description="file to ingest"), + checksum: str = Form( + ..., + description="checksum of the file to ingest, calculated with SHA-1", + ), + user: User = Depends(get_current_user), +): + """ + Batch ingest of files to an existing dataset. The batch file must be a compressed file (.zip). + The checksums are calculated using the SHA-1 checksums algorithm. + """ + try: + model_id, model_name, filename = await ingest_model_file( + file, model_id, checksum, user, version + ) + return { + "model_id": model_id, + "model_name": model_name, + "filename": filename, + } + except Exception as e: + logger.exception("datasets:ingest") + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + + @router.post( "/{model_id}/batch", summary="Batch ingest files to a model", @@ -36,18 +72,18 @@ async def ingest_files_batch( Batch ingest of files to an existing model. 
     The batch file must be a compressed file (.zip).
     The checksums are calculated using the SHA-1 checksums algorithm.
     """
-    # try:
-    model_id, model_name, filenames = await ingest_model_files_batch(
-        batch, model_id, checksums, user, version
-    )
-    return {
-        "model_id": model_id,
-        "model_name": model_name,
-        "filenames": filenames,
-    }
-    # except Exception as e:
-    #     logger.exception("models:ingest")
-    #     raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
+    try:
+        model_id, model_name, filenames = await ingest_model_files_batch(
+            batch, model_id, checksums, user, version
+        )
+        return {
+            "model_id": model_id,
+            "model_name": model_name,
+            "filenames": filenames,
+        }
+    except Exception as e:
+        logger.exception("models:ingest")
+        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
 
 
 @router.post(
@@ -81,46 +117,18 @@ def ingest_existing_file(
         raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
 
 
-# except Exception as e:
-#     logger.exception("models:ingest")
-#     raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
+class IngestSTACBody(BaseModel):
+    stac: dict  # STAC catalog as a JSON-serializable dict
 
-# class IngestSTACBody(BaseModel):
-#     stac: dict  # json as string
-
-
-# @router.put("/stac/{model_id}")
-# async def ingest_stac_catalog(
-#     model_id: str,
-#     body: IngestSTACBody,
-#     user: User = Depends(get_current_user),
-# ):
-#     try:
-#         return ingest_stac(body.stac, model_id, user)
-#     except Exception as e:
-#         logger.exception("models:ingest_url")
-#         raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
-
-# class IngestURLBody(BaseModel):
-#     url: str
-
-
-# @router.post("/{model_id}/url")
-# async def ingest_url(
-#     model_id: str,
-#     body: IngestURLBody,
-#     user: User = Depends(get_current_user),
-# ):
-#     # try:
-#     model_id, model_name, file_name = await ingest_file_url(
-#         body.url, model_id, user
-#     )
-#     return {
-#         "model_id": model_id,
-#         "model_name": model_name,
-#         "file_name": file_name,
-#     }
-#     # except Exception as e:
-#     #     logger.exception("models:ingest")
-#     #     raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
+@router.put("/stac/{model_id}")
+def ingest_stac_catalog(
+    model_id: str,
+    body: IngestSTACBody,
+    user: User = Depends(get_current_user),
+):
+    try:
+        return ingest_stac(body.stac, model_id, user)
+    except Exception as e:
+        logger.exception("models:ingest_stac")
+        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e))
diff --git a/api/api/src/models/__init__.py b/api/api/src/models/__init__.py
index 17ddad2c..33c9fe2b 100755
--- a/api/api/src/models/__init__.py
+++ b/api/api/src/models/__init__.py
@@ -3,5 +3,5 @@
 from .files import File, Files, UploadingFile, Folder, UploadingFile
 from .tag import Tag
 from .usage import Usage, Limits
-from .model import Model
+from .model import Model, STACModel
 from .verison import Version
diff --git a/api/api/src/models/model.py b/api/api/src/models/model.py
index 855532fb..6eb0e201 100755
--- a/api/api/src/models/model.py
+++ b/api/api/src/models/model.py
@@ -36,3 +36,22 @@ def check_source_is_url(cls, source):
     if not source.startswith("http") and not source.startswith("https"):
         raise ValueError("source must be a valid url")
     return source
+
+
+class STACModel(BaseModel):
+    uid: str
+    id: str
+    name: str
+    description: str = ""
+    tags: List[str] = []
+    createdAt: datetime = Field(default_factory=datetime.now)
+    updatedAt: datetime = Field(default_factory=datetime.now)
+    likes: int = 0
+    downloads: int = 0
+    quality: int = 1
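+    # Q1 by default; ingest_stac raises this to 2 when the STAC item declares "mlm:name"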
+ size: int = 0 + catalog: dict = {} + items: dict = {} + versions: List[Version] = [] + files: str diff --git a/api/api/src/usecases/datasets/__init__.py b/api/api/src/usecases/datasets/__init__.py index a604f8c5..3a895fbe 100755 --- a/api/api/src/usecases/datasets/__init__.py +++ b/api/api/src/usecases/datasets/__init__.py @@ -12,7 +12,11 @@ add_files_batch_to_dataset_version, ingest_stac, ) # ingest_file_url -from .download_dataset import download_dataset_file, download_stac_catalog, generate_presigned_url +from .download_dataset import ( + download_dataset_file, + download_stac_catalog, + generate_presigned_url, +) from .update_dataset import toggle_like_dataset, update_dataset from .delete_dataset import delete_dataset diff --git a/api/api/src/usecases/datasets/ingest_file.py b/api/api/src/usecases/datasets/ingest_file.py index 65abadfd..08620f01 100755 --- a/api/api/src/usecases/datasets/ingest_file.py +++ b/api/api/src/usecases/datasets/ingest_file.py @@ -99,7 +99,7 @@ def ingest_stac(stac, dataset_id, user): values = gpd.GeoDataFrame.from_features(stac["features"], crs="4326") # ??? # values.to_csv("/tmp/iepa.csv") catalog = values[values["type"] == "Catalog"] - items = values.drop_duplicates(subset='geometry') + items = values.drop_duplicates(subset="geometry") items = items[items["type"] == "Feature"] # convert to geojson items = json.loads(items.to_json()) @@ -140,29 +140,3 @@ def ingest_stac(stac, dataset_id, user): dataset.updatedAt = datetime.now() repo.update_dataset(dataset.id, dataset.model_dump()) return dataset - - -# async def ingest_dataset_file(file, dataset_id, version, parent, checksum, user): -# dataset = retrieve_owned_dataset(dataset_id, user.uid) -# versions = [v.version_id for v in dataset.versions] -# if not version in versions: -# raise DatasetVersionDoesNotExistError() -# filename, file_size = await ingest_file( -# file, version, parent, dataset_id, checksum, dataset.quality, dataset.files -# ) -# version = [v for v in dataset.versions if v.version_id == version][0] -# version.size += file_size # for Q0+ will add, so put to 0 before if necessary -# dataset.updatedAt = datetime.now() -# dataset_db_repo = DatasetsDBRepo() -# dataset_db_repo.update_dataset(dataset.id, dataset.dict()) -# return dataset.id, dataset.name, filename - - -def ingest_file_url(): - # TODO - return - # def get_file_name(self, file): - # return file.split("/")[-1] - - # def persist_file(self, file, dataset_id, filename): - # return os_repo.persist_file_url(file, dataset_id, filename) diff --git a/api/api/src/usecases/models/__init__.py b/api/api/src/usecases/models/__init__.py index d8871e99..3142605a 100755 --- a/api/api/src/usecases/models/__init__.py +++ b/api/api/src/usecases/models/__init__.py @@ -4,9 +4,9 @@ retrieve_models_leaderboard, retrieve_popular_models, ) -from .create_model import create_model +from .create_model import create_model, create_stac_model from .create_model_version import create_model_version -from .download_model import download_model_file +from .download_model import download_model_file, download_stac_catalog from .upload_large_file import ( generate_upload_id, ingest_model_chunk, @@ -16,6 +16,8 @@ from .ingest_file import ( ingest_model_files_batch, add_files_batch_to_model_version, -) # , ingest_stac, ingest_file_url + ingest_stac, + ingest_model_file, +) from .delete_model import delete_model diff --git a/api/api/src/usecases/models/create_model.py b/api/api/src/usecases/models/create_model.py index b7e8b57e..d81283e6 100755 --- 
a/api/api/src/usecases/models/create_model.py +++ b/api/api/src/usecases/models/create_model.py @@ -1,12 +1,12 @@ -from ...models import Model, Files +from ...models import Model, Files, STACModel from ...errors import ( ModelAlreadyExistsError, ModelDoesNotExistError, ) -from ...repos import ModelsDBRepo +from ...repos import ModelsDBRepo, GeoDBRepo from .retrieve_model import retrieve_model_by_name -from ..user import check_user_can_create_model # , retrieve_user_credentials +from ..user import check_user_can_create_model, retrieve_user_credentials def create_model(user, name, authors, source, license): @@ -33,22 +33,24 @@ def create_model(user, name, authors, source, license): return model.id -# def create_stac_dataset(user, name): -# repo = ModelsDBRepo() -# credentials = retrieve_user_credentials(user) -# geodb_repo = GeoDBRepo(credentials) # validate credentials -# try: -# retrieve_dataset_by_name(name) -# raise ModelAlreadyExistsError() -# except ModelDoesNotExistError: -# check_user_can_create_dataset(user) -# id = repo.generate_id() -# # do we manage files as well or delegate to geodb? -# dataset = STACDataset( -# uid=user.uid, -# id=id, -# name=name, -# ) -# repo.persist_dataset(dataset.model_dump(), dataset.id) -# repo.increase_user_dataset_count(user.uid) -# return dataset.id +def create_stac_model(user, name): + repo = ModelsDBRepo() + credentials = retrieve_user_credentials(user) + geodb_repo = GeoDBRepo(credentials) # validate credentials + try: + retrieve_model_by_name(name) + raise ModelAlreadyExistsError() + except ModelDoesNotExistError: + check_user_can_create_model(user) + id, files_id = repo.generate_id(), repo.generate_id() + files = Files(id=files_id, dataset=id) + model = STACModel( + uid=user.uid, + id=id, + name=name, + files=files_id, + ) + repo.persist_files(files.model_dump(), files.id) + repo.persist_model(model.model_dump(), model.id) + repo.increase_user_model_count(user.uid) + return model.id diff --git a/api/api/src/usecases/models/download_model.py b/api/api/src/usecases/models/download_model.py index f7911015..676dfe73 100755 --- a/api/api/src/usecases/models/download_model.py +++ b/api/api/src/usecases/models/download_model.py @@ -1,33 +1,44 @@ -from ...repos import OSRepo +import json +import prometheus_client + +from ...repos import OSRepo, GeoDBRepo, FilesDBRepo from .retrieve_model import retrieve_model +from ..user import retrieve_user_credentials +from ..datasets.download_dataset import eotdl_api_downloaded_bytes def download_model_file(model_id, filename, user, version=None): os_repo = OSRepo() - retrieve_model(model_id) - # check_user_can_download_model(user) - # TODO: if no version is provided, download most recent file ? 
-    data_stream = os_repo.data_stream
+    if version is None:  # retrieve latest version
+        version = retrieve_latest_file_version(model_id, filename)
+
+    async def track_download_volume(*args, **kwargs):
+        async for data in os_repo.data_stream(*args, **kwargs):
+            eotdl_api_downloaded_bytes.labels(user.email).inc(len(data))
+            yield data
+
     filename = f"{filename}_{version}"
     object_info = os_repo.object_info(model_id, filename)
-    return data_stream, object_info, filename
-
-
-def download_stac_catalog():
-    # TODO
-    return
-    # def __call__(self, inputs: Inputs) -> Outputs:
-    #     # check if model exists and user is owner
-    #     data = self.db_repo.retrieve("models", inputs.model_id)
-    #     if not data:
-    #         raise modelDoesNotExistError()
-    #     model = STACmodel(**data)
-    #     if model.uid != inputs.user.uid:
-    #         raise modelDoesNotExistError()
-    #     # retrieve from geodb
-    #     credentials = self.retrieve_user_credentials(inputs.user)
-    #     self.geodb_repo = self.geodb_repo(credentials)
-    #     gdf = self.geodb_repo.retrieve(inputs.model_id)
-    #     # report usage
-    #     self.db_repo.increase_counter("models", "id", model.id, "downloads")
-    #     return self.Outputs(stac=json.loads(gdf.to_json()))
+    return track_download_volume, object_info, filename
+
+
+def download_stac_catalog(model_id, user):
+    # check if model exists
+    model = retrieve_model(model_id)
+    # retrieve from geodb
+    credentials = retrieve_user_credentials(user)
+    geodb_repo = GeoDBRepo(credentials)
+    gdf = geodb_repo.retrieve(model_id)
+    # TODO: report usage
+    return json.loads(gdf.to_json())
+
+
+def retrieve_latest_file_version(model_id, filename):
+    files_repo = FilesDBRepo()
+    model = retrieve_model(model_id)
+    files = files_repo.retrieve_file(model.files, filename)
+    if not files or "files" not in files:
+        raise Exception("File does not exist")
+    file = sorted(files["files"], key=lambda x: x["version"])[-1]
+    version = file["version"]
+    return version
diff --git a/api/api/src/usecases/models/ingest_file.py b/api/api/src/usecases/models/ingest_file.py
index e65a520b..73012d5e 100755
--- a/api/api/src/usecases/models/ingest_file.py
+++ b/api/api/src/usecases/models/ingest_file.py
@@ -2,11 +2,36 @@
 import zipfile
 import io
 import os
+import geopandas as gpd
+import json
 
 from .retrieve_model import retrieve_owned_model
 from ...errors import ModelVersionDoesNotExistError
-from ...repos import ModelsDBRepo
+from ...repos import ModelsDBRepo, GeoDBRepo
 from ..files import ingest_file, ingest_existing_file
+from ..utils.stac import STACDataFrame
+from ..user import retrieve_user_credentials
+
+
+async def ingest_model_file(file, model_id, checksum, user, version):
+    model = retrieve_owned_model(model_id, user.uid)
+    versions = [v.version_id for v in model.versions]
+    if version not in versions:
+        raise ModelVersionDoesNotExistError()
+    file_size = await ingest_file(
+        file.filename,
+        file.file,
+        version,
+        model_id,
+        checksum,
+        model.files,
+    )
+    version = [v for v in model.versions if v.version_id == version][0]
+    version.size += file_size
+    model.updatedAt = datetime.now()
+    model_db_repo = ModelsDBRepo()
+    model_db_repo.update_model(model.id, model.model_dump())
+    return model.id, model.name, file.filename
 
 
 async def ingest_model_files_batch(batch, model_id, checksums, user, version):
@@ -64,55 +89,37 @@ def add_files_batch_to_model_version(filenames, checksums, model_id, version, us
     return model.id, model.name, filenames
 
 
-# async def ingest_model_file(file, model_id, version, parent, checksum, user):
-#     model = retrieve_owned_model(model_id, user.uid)
-#     versions = 
[v.version_id for v in model.versions]
-#     if not version in versions:
-#         raise modelVersionDoesNotExistError()
-#     filename, file_size = await ingest_file(
-#         file, version, parent, model_id, checksum, model.quality, model.files
-#     )
-#     version = [v for v in model.versions if v.version_id == version][0]
-#     version.size += file_size  # for Q0+ will add, so put to 0 before if necessary
-#     model.updatedAt = datetime.now()
-#     model_db_repo = modelsDBRepo()
-#     model_db_repo.update_model(model.id, model.dict())
-#     return model.id, model.name, filename
-
-
-def ingest_file_url():
-    # TODO
-    return
-    # def get_file_name(self, file):
-    #     return file.split("/")[-1]
-
-    # def persist_file(self, file, model_id, filename):
-    #     return os_repo.persist_file_url(file, model_id, filename)
-
-
-def ingest_stac():
-    # TODO
-    return
-    # # check if model exists
-    # data = db_repo.retrieve("models", model)
-    # if not data:
-    #     raise modelDoesNotExistError()
-    # model = STACmodel(**data)
-    # # check user owns model
-    # if model.uid != user.uid:
-    #     raise modelDoesNotExistError()
-    # # TODO: check all assets exist in os
-    # # ingest to geodb
-    # credentials = retrieve_user_credentials(user)
-    # geodb_repo = geodb_repo(credentials)
-    # catalog = geodb_repo.insert(model, stac)
-    # # the catalog should contain all the info we want to show in the UI
-    # model.catalog = catalog
-    # keys = list(catalog.keys())
-    # if "ml-model:name" in keys:
-    #     model.quality = 2
-    # # TODO: compute and report automatic qa metrics
-    # # TODO: validate Q2 model, not only check name
-    # # TODO: validate Q1 model with required fields/extensions (author, license)
-    # db_repo.update("models", model, model.model_dump())
-    # return Outputs(model=model)
+def ingest_stac(stac, model_id, user):
+    # check if model exists
+    model = retrieve_owned_model(model_id, user.uid)
+    # TODO: check all assets exist in os
+    # generate catalog
+    values = gpd.GeoDataFrame.from_features(stac["features"], crs="4326")  # ???
+    # values.to_csv("/tmp/iepa.csv")
+    print(values)
+    catalog = values[values["type"] == "Catalog"]
+    items = values.drop_duplicates(subset="geometry")
+    items = items[items["type"] == "Feature"]
+    assert len(catalog) == 1, "STAC catalog must have exactly one root catalog"
+    assert len(items) == 1, "Only one item is allowed"
+    catalog = json.loads(catalog.to_json())["features"][0]["properties"]
+    item = json.loads(items.to_json())["features"][0]["properties"]
+    print(item)
+    model_quality = 1
+    # TODO: validate Q2 model, not only check name
+    if "mlm:name" in item["properties"]:
+        model_quality = 2
+        # compute metrics like we do for Q2 models ?
+    print("quality", model_quality)
+    # ingest to geodb
+    credentials = retrieve_user_credentials(user)
+    geodb_repo = GeoDBRepo(credentials)
+    geodb_repo.insert(model.id, values)
+    # the catalog should contain all the info we want to show in the UI
+    model.catalog = catalog  # careful! 
this is not the same as the model catalog + model.items = item + model.quality = model_quality + repo = ModelsDBRepo() + model.updatedAt = datetime.now() + repo.update_model(model.id, model.model_dump()) + return model diff --git a/api/api/src/usecases/models/retrieve_model.py b/api/api/src/usecases/models/retrieve_model.py index 0f07db03..c5c7dd8c 100755 --- a/api/api/src/usecases/models/retrieve_model.py +++ b/api/api/src/usecases/models/retrieve_model.py @@ -1,4 +1,4 @@ -from ...models import Model +from ...models import Model, STACModel from ...errors import ModelDoesNotExistError, UserUnauthorizedError from ...repos import ModelsDBRepo from ..files import retrieve_files @@ -7,7 +7,7 @@ def retrieve(data): if data is None: raise ModelDoesNotExistError() - return Model(**data) + return Model(**data) if data["quality"] == 0 else STACModel(**data) def retrieve_model(model_id): diff --git a/api/api/src/usecases/models/retrieve_models.py b/api/api/src/usecases/models/retrieve_models.py index abdc9275..0a185a91 100755 --- a/api/api/src/usecases/models/retrieve_models.py +++ b/api/api/src/usecases/models/retrieve_models.py @@ -1,5 +1,5 @@ from ...repos import ModelsDBRepo -from ...models import Model +from ...models import Model, STACModel def retrieve_models(match=None, limit=None): @@ -7,7 +7,10 @@ def retrieve_models(match=None, limit=None): data = repo.retrieve_models(match, limit) models = [] for d in data: - models.append(Model(**d)) + if d["quality"] == 0: + models.append(Model(**d)) + else: + models.append(STACModel(**d)) return models @@ -25,5 +28,8 @@ def retrieve_popular_models(limit): data = repo.retrieve_popular_models(limit) models = [] for d in data: - models.append(Model(**d)) + if d["quality"] == 0: + models.append(Model(**d)) + else: + models.append(STACModel(**d)) return models diff --git a/eotdl/eotdl/models/download.py b/eotdl/eotdl/models/download.py index 1cf7c69c..975bee79 100755 --- a/eotdl/eotdl/models/download.py +++ b/eotdl/eotdl/models/download.py @@ -5,8 +5,9 @@ from ..auth import with_auth from .retrieve import retrieve_model, retrieve_model_files from ..shared import calculate_checksum -from ..repos import FilesAPIRepo +from ..repos import FilesAPIRepo, ModelsAPIRepo from .metadata import generate_metadata +from ..curation.stac import STACDataFrame @with_auth @@ -46,20 +47,6 @@ def download_model( if model["quality"] == 0: if file: raise NotImplementedError("Downloading a specific file is not implemented") - # files = [f for f in model["files"] if f["name"] == file] - # if not files: - # raise Exception(f"File {file} not found") - # if len(files) > 1: - # raise Exception(f"Multiple files with name {file} found") - # dst_path = download( - # model, - # model["id"], - # file, - # files[0]["checksum"], - # download_path, - # user, - # ) - # return Outputs(dst_path=dst_path) model_files = retrieve_model_files(model["id"], version) repo = FilesAPIRepo() for file in tqdm(model_files, disable=verbose, unit="file"): @@ -74,41 +61,38 @@ def download_model( file_version, endpoint="models", ) - # if calculate_checksum(dst_path) != checksum: - # logger(f"Checksum for {file} does not match") + if verbose: + logger("Generating README.md ...") + generate_metadata(download_path, model) else: - raise NotImplementedError("Downloading a STAC model is not implemented") - # logger("Downloading STAC metadata...") - # gdf, error = repo.download_stac( - # model["id"], - # user["id_token"], - # ) - # if error: - # raise Exception(error) - # df = STACDataFrame(gdf) - # # df.geometry = 
df.geometry.apply(lambda x: Polygon() if x is None else x) - # path = path - # if path is None: - # path = download_base_path + "/" + model["name"] - # df.to_stac(path) - # # download assets - # if assets: - # logger("Downloading assets...") - # df = df.dropna(subset=["assets"]) - # for row in tqdm(df.iterrows(), total=len(df)): - # id = row[1]["stac_id"] - # # print(row[1]["links"]) - # for k, v in row[1]["assets"].items(): - # href = v["href"] - # repo.download_file_url( - # href, f"{path}/assets/{id}", user["id_token"] - # ) - # else: - # logger("To download assets, set assets=True or -a in the CLI.") - # return Outputs(dst_path=path) - if verbose: - logger("Generating README.md ...") - generate_metadata(download_path, model) + if verbose: + logger("Downloading STAC metadata...") + repo = ModelsAPIRepo() + gdf, error = repo.download_stac( + model["id"], + user, + ) + if error: + raise Exception(error) + df = STACDataFrame(gdf) + # df.geometry = df.geometry.apply(lambda x: Polygon() if x is None else x) + df.to_stac(download_path) + # download assets + if assets: + if verbose: + logger("Downloading assets...") + repo = FilesAPIRepo() + df = df.dropna(subset=["assets"]) + for row in tqdm(df.iterrows(), total=len(df)): + for k, v in row[1]["assets"].items(): + href = v["href"] + _, filename = href.split("/download/") + # will overwrite assets with same name :( + repo.download_file_url( + href, filename, f"{download_path}/assets", user + ) + else: + logger("To download assets, set assets=True or -a in the CLI.") if verbose: logger("Done") return download_path diff --git a/eotdl/eotdl/models/ingest.py b/eotdl/eotdl/models/ingest.py index ef6d1ca0..883ad0b1 100755 --- a/eotdl/eotdl/models/ingest.py +++ b/eotdl/eotdl/models/ingest.py @@ -2,13 +2,16 @@ import yaml import frontmatter import markdown +from tqdm import tqdm +import json from ..auth import with_auth from .metadata import Metadata, generate_metadata -from ..repos import ModelsAPIRepo +from ..repos import ModelsAPIRepo, FilesAPIRepo from ..shared import calculate_checksum -from ..files import ingest_files +from ..files import ingest_files, create_new_version from .update import update_model +from ..curation.stac import STACDataFrame def ingest_model( @@ -17,8 +20,8 @@ def ingest_model( path = Path(path) if not path.is_dir(): raise Exception("Path must be a folder") - # if "catalog.json" in [f.name for f in path.iterdir()]: - # return ingest_stac(path / "catalog.json", logger) + if "catalog.json" in [f.name for f in path.iterdir()]: + return ingest_stac(path / "catalog.json", logger) return ingest_folder(path, verbose, logger, force_metadata_update, sync_metadata) @@ -101,3 +104,64 @@ def check_metadata( generate_metadata(str(folder), dataset) return False return False + + +def retrieve_stac_model(model_name, user): + repo = ModelsAPIRepo() + data, error = repo.retrieve_model(model_name) + # print(data, error) + if data and data["uid"] != user["uid"]: + raise Exception("Model already exists.") + if error and error == "Model doesn't exist": + # create model + data, error = repo.create_stac_model(model_name, user) + # print(data, error) + if error: + raise Exception(error) + data["id"] = data["model_id"] + return data["id"] + + +@with_auth +def ingest_stac(stac_catalog, logger=None, user=None): + repo, files_repo = ModelsAPIRepo(), FilesAPIRepo() + # load catalog + logger("Loading STAC catalog...") + df = STACDataFrame.from_stac_file(stac_catalog) + catalog = df[df["type"] == "Catalog"] + assert len(catalog) == 1, "STAC catalog must 
have exactly one root catalog"
+    model_name = catalog.id.iloc[0]
+    # retrieve model (create if it doesn't exist)
+    model_id = retrieve_stac_model(model_name, user)
+    # create new version
+    version = create_new_version(repo, model_id, user)
+    logger("New version created, version: " + str(version))
+    df2 = df.dropna(subset=["assets"])
+    for row in tqdm(df2.iterrows(), total=len(df2)):
+        try:
+            for k, v in row[1]["assets"].items():
+                data, error = files_repo.ingest_file(
+                    v["href"],
+                    model_id,
+                    user,
+                    calculate_checksum(v["href"]),  # is always absolute?
+                    "models",
+                    version,
+                )
+                if error:
+                    raise Exception(error)
+                file_url = (
+                    f"{repo.url}models/{data['model_id']}/download/{data['filename']}"
+                )
+                df.loc[row[0], "assets"][k]["href"] = file_url
+        except Exception as e:
+            logger(f"Error uploading asset {row[0]}: {e}")
+            break
+    # ingest the STAC catalog into geodb
+    logger("Ingesting STAC catalog...")
+    data, error = repo.ingest_stac(json.loads(df.to_json()), model_id, user)
+    if error:
+        # TODO: delete all assets that were uploaded
+        raise Exception(error)
+    logger("Done")
+    return
diff --git a/eotdl/eotdl/repos/ModelsAPIRepo.py b/eotdl/eotdl/repos/ModelsAPIRepo.py
index f5214678..be30a59e 100755
--- a/eotdl/eotdl/repos/ModelsAPIRepo.py
+++ b/eotdl/eotdl/repos/ModelsAPIRepo.py
@@ -1,4 +1,5 @@
 import requests
+import geopandas as gpd
 
 from ..repos import APIRepo
 
@@ -53,3 +54,27 @@ def update_model(
             headers=self.generate_headers(user),
         )
         return self.format_response(response)
+
+    def create_stac_model(self, name, user):
+        response = requests.post(
+            self.url + "models/stac",
+            json={"name": name},
+            headers=self.generate_headers(user),
+        )
+        return self.format_response(response)
+
+    def ingest_stac(self, stac_json, model_id, user):
+        response = requests.put(
+            self.url + f"models/stac/{model_id}",
+            json={"stac": stac_json},
+            headers=self.generate_headers(user),
+        )
+        return self.format_response(response)
+
+    def download_stac(self, model_id, user):
+        url = self.url + "models/" + model_id + "/download"
+        headers = self.generate_headers(user)
+        response = requests.get(url, headers=headers)
+        if response.status_code != 200:
+            return None, response.json()["detail"]
+        return gpd.GeoDataFrame.from_features(response.json()["features"]), None
diff --git a/tutorials/notebooks/05_q2_model.ipynb b/tutorials/notebooks/05_q2_model.ipynb
new file mode 100644
index 00000000..bb3608b5
--- /dev/null
+++ b/tutorials/notebooks/05_q2_model.ipynb
@@ -0,0 +1,472 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Q2 ML Models"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Like training datasets, ML Models in EOTDL are categorized into different [quality levels](https://eotdl.com/docs/datasets/quality), which in turn determine the range of functionality available for each model.\n",
+    "\n",
+    "In this tutorial you will learn about Q2 models: models with STAC metadata and the ML-Model extension. Models with STAC metadata but without the ML-Model extension are qualified as Q1. "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## STAC Spec"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For Q2 ML Models we rely on the [ML-Model](https://github.com/crim-ca/mlm-extension) STAC extension. Here we develop the required metadata for the [RoadSegmentation](https://www.eotdl.com/models/RoadSegmentation) Q0 model on EOTDL.",
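+    "\n",
+    "Once the catalog built below is saved to disk, it can be uploaded with `ingest_model`, which detects a `catalog.json` at the root of the folder and runs the STAC ingestion path instead of the Q0 one. A minimal sketch, assuming the catalog is saved under `data/RoadSegmentation/STAC` as done at the end of this notebook:\n",
+    "\n",
+    "```python\n",
+    "from eotdl.models import ingest_model\n",
+    "\n",
+    "# detects catalog.json, uploads the referenced assets, then ingests the STAC metadata into geodb\n",
+    "ingest_model('data/RoadSegmentation/STAC')\n",
+    "```"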
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2024.05.02'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import eotdl\n", + "\n", + "eotdl.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model `RoadSegmentation v2` already exists at data/RoadSegmentation/v2. To force download, use force=True or -f in the CLI.\n" + ] + }, + { + "data": { + "text/plain": [ + "'data/RoadSegmentation/v2'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from eotdl.models import download_model\n", + "\n", + "try:\n", + "\tpath = download_model('RoadSegmentation', path=\"data\", version=2)\n", + "except Exception as e:\n", + "\tprint(e)\n", + "\tpath = 'data/RoadSegmentation/v2'\n", + "\n", + "path" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['README.md', 'model.onnx']" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os \n", + "\n", + "os.listdir(path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our goal is to provide STAC metadata to run `model.onnx` on any inference processor that implements the ML-Model STAC extension. From the official repo:\n", + "\n", + "> The STAC Machine Learning Model (MLM) Extension provides a standard set of fields to describe machine learning models trained on overhead imagery and enable running model inference.\n", + ">\n", + "> The main objectives of the extension are:\n", + ">\n", + "> 1. to enable building model collections that can be searched alongside associated STAC datasets\n", + "> 2. record all necessary bands, parameters, modeling artifact locations, and high-level processing steps to deploy an inference service.\n", + ">\n", + ">Specifically, this extension records the following information to make ML models searchable and reusable:\n", + ">\n", + "> 1. Sensor band specifications\n", + "> 2. Model input transforms including resize and normalization\n", + "> 3. Model output shape, data type, and its semantic interpretation\n", + "> 4. An optional, flexible description of the runtime environment to be able to run the model\n", + "> 5. Scientific references\n", + "\n", + "Let's start with a generic `catalog` for our model." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import pystac\n", + "\n", + "# current directory + 'data/RoadSegmentation/STAC'\n", + "root_href = os.path.join(os.getcwd(), 'data/RoadSegmentation/STAC')\n", + "\n", + "catalog = pystac.Catalog(id='RoadSegmentationQ2', description='Catalog for the Road Segmentation Q2 ML Model')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's create a `collection` for our model." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + " \n", + "
\n", + "
" + ], + "text/plain": [ + ">" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pystac\n", + "from datetime import datetime\n", + "\n", + "# Create a new Collection\n", + "collection = pystac.Collection(\n", + " id='model',\n", + " description='Collection for the Road Segmentation Q2 ML Model',\n", + " extent=pystac.Extent(\n", + " spatial=pystac.SpatialExtent([[-180, -90, 180, 90]]), # dummy extent\n", + " temporal=pystac.TemporalExtent([[datetime(2020, 1, 1), None]]) # dummy extent\n", + " ),\n", + "\t# extra_fields={\n", + " # 'stac_extensions': ['https://crim-ca.github.io/mlm-extension/v1.2.0/schema.json']\n", + " # }\n", + ")\n", + "\n", + "# Add the Collection to the Catalog\n", + "catalog.add_child(collection)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And finally, an `item` to describe the model itself with the extension." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "
\n", + "
\n", + " \n", + "
\n", + "
" + ], + "text/plain": [ + ">" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create a new Item\n", + "item = pystac.Item(\n", + " id='model',\n", + " geometry={ # dummy geometry\n", + " \"type\": \"Point\",\n", + " \"coordinates\": [125.6, 10.1]\n", + " },\n", + " bbox=[125.6, 10.1, 125.6, 10.1], # dummy bbox\n", + " datetime=datetime.utcnow(), # dummy datetime\n", + " properties={ \n", + "\t\t\"mlm:name\": \"model.onnx\", # name of the asset ? otherwise, how can we know which asset to use ?\n", + "\t\t\"mlm:framework\": \"ONNX\", # only framework support for now\n", + "\t\t\"mlm:architecture\": \"U-Net\",\n", + "\t\t\"mlm:tasks\": [\"segmentation\"], # https://github.com/crim-ca/mlm-extension?tab=readme-ov-file#task-enum\n", + "\t\t\"mlm:input\": { # https://github.com/crim-ca/mlm-extension?tab=readme-ov-file#model-input-object\n", + "\t\t\t\"name\": \"RGB statellite image (HR)\",\n", + "\t\t\t\"bands\": [\n", + "\t\t\t\t\"red\",\n", + "\t\t\t\t\"green\",\n", + "\t\t\t\t\"blue\"\n", + "\t\t\t],\n", + "\t\t\t\"input\": { # https://github.com/crim-ca/mlm-extension?tab=readme-ov-file#input-structure-object\n", + "\t\t\t\t\"shape\": [\n", + "\t\t\t\t\t-1,\n", + "\t\t\t\t\t3,\n", + "\t\t\t\t\t-1, # should be divisble by 16\n", + "\t\t\t\t\t-1 # should be divisble by 16\n", + "\t\t\t\t],\n", + "\t\t\t\t\"dim_order\": [\n", + "\t\t\t\t\t\"batch\",\n", + "\t\t\t\t\t\"channel\",\n", + "\t\t\t\t\t\"height\",\n", + "\t\t\t\t\t\"width\"\n", + "\t\t\t\t],\n", + "\t\t\t\t\"data_type\": \"float32\",\n", + " # we should add here the resize to nearest divisible by 16\n", + "\t\t\t\t# \"pre_processing_function\": { # https://github.com/crim-ca/mlm-extension?tab=readme-ov-file#processing-expression\n", + "\t\t\t\t# \t\"format\": \n", + "\t\t\t\t# \t\"expression\": \n", + "\t\t\t\t# }\n", + "\t\t\t\t\"description\": \"Model trained with 1024x1024 RGB HR images, but can work with other dimensions as long as they are divisible by 16\"\n", + "\t\t\t}\n", + "\t\t},\n", + "\t\t\"mlm:output\": {\n", + "\t\t\t\"name\": \"road binary mask\",\n", + "\t\t\t\"tasks\": [\"segmentation\"], # redundant ?\n", + "\t\t\t\"result\": { # https://github.com/crim-ca/mlm-extension?tab=readme-ov-file#result-structure-object\n", + "\t\t\t\t\"shape\": [-1, -1, -1],\n", + "\t\t\t\t\"dim_order\": [\n", + "\t\t\t\t\t\"batch\",\n", + "\t\t\t\t\t\"height\",\n", + "\t\t\t\t\t\"width\"\n", + "\t\t\t\t],\n", + "\t\t\t\t\"data_type\": \"uint8\",\n", + "\t\t\t\t\"description\": \"Binary mask of the road segmentation. 
1 for road, 0 for background\",\n", + "\t\t\t\t# \"post_processing_function\": { # https://github.com/crim-ca/mlm-extension?tab=readme-ov-file#processing-expression\n", + "\t\t\t\t# }\n", + "\t\t\t},\n", + "\t\t},\n", + "\t}, \n", + " stac_extensions=['https://crim-ca.github.io/mlm-extension/v1.2.0/schema.json']\n", + ")\n", + "\n", + "# Add the Item to the Collection\n", + "collection.add_item(item)\n", + "\n", + "# Save the Catalog to a file\n", + "# catalog.normalize_and_save(root_href=root_href, catalog_type=pystac.CatalogType.SELF_CONTAINED)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model weights are added as an asset to the item" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# Create an Asset\n", + "model_asset = pystac.Asset(\n", + " href=os.path.abspath('data/RoadSegmentation/v2/model.onnx'), \n", + ")\n", + "\n", + "# Add the Asset to the Item\n", + "item.add_asset('model', model_asset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we validate and save the metadata" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# Validate the Catalog\n", + "\n", + "# catalog.validate_all()\n", + "\n", + "catalog.normalize_and_save(root_href=root_href, catalog_type=pystac.CatalogType.SELF_CONTAINED)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "eotdl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}