From 07a88dc1423ebefb6cf97e1f5894e539c0c51347 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Fri, 8 Nov 2024 10:51:23 +0000 Subject: [PATCH] BAI-1500 add error checking to modelscan API --- .../bailo_modelscan_api/config.py | 5 +- .../bailo_modelscan_api/dependencies.py | 8 ++ lib/modelscan_api/bailo_modelscan_api/main.py | 86 +++++++++++++------ 3 files changed, 72 insertions(+), 27 deletions(-) diff --git a/lib/modelscan_api/bailo_modelscan_api/config.py b/lib/modelscan_api/bailo_modelscan_api/config.py index 38d5b640a..5b75aaeca 100644 --- a/lib/modelscan_api/bailo_modelscan_api/config.py +++ b/lib/modelscan_api/bailo_modelscan_api/config.py @@ -1,3 +1,6 @@ +"""Configuration settings for FastAPI app. +""" + from typing import Any from modelscan.settings import DEFAULT_SETTINGS @@ -11,7 +14,7 @@ class Settings(BaseSettings): """ app_name: str = "Bailo ModelScan API" - download_dir: str = "." # TODO: use this + download_dir: str = "." modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS block_size: int = 1024 bailo_client_url: str = "http://localhost:8080/" diff --git a/lib/modelscan_api/bailo_modelscan_api/dependencies.py b/lib/modelscan_api/bailo_modelscan_api/dependencies.py index c5fa42ed4..0d3864ccb 100644 --- a/lib/modelscan_api/bailo_modelscan_api/dependencies.py +++ b/lib/modelscan_api/bailo_modelscan_api/dependencies.py @@ -1,3 +1,6 @@ +"""Common utilities used by the FastAPI app. +""" + from pathlib import Path from typing import Union from requests import Response @@ -11,6 +14,11 @@ def __init__(self, response: Response, path: Path) -> None: def parse_path(path: Union[str, Path, None]) -> Path: + """Ensure that a path is consistently represented as a Path. + + :param path: System path to parse. Defaults to the file's current working directory if unspecified. + :return: An absolute Path representation of the path parameter. + """ if path is None: path = "." return Path().cwd() if path == "." else Path(path).absolute() diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 8407710d4..477ac07ff 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -1,9 +1,14 @@ +"""FastAPI app. +""" + from email.message import Message from functools import lru_cache +from http import HTTPStatus from pathlib import Path from bailo import Client -from fastapi import BackgroundTasks, FastAPI +from bailo.core.exceptions import BailoException +from fastapi import BackgroundTasks, FastAPI, HTTPException from modelscan.modelscan import ModelScan from requests import Response import uvicorn @@ -42,7 +47,13 @@ def get_file(model_id: str, file_id: str) -> Response: :param file_id: Unique file ID :return: The unique file ID """ - return bailo_client.get_download_file(model_id, file_id) + try: + return bailo_client.get_download_file(model_id, file_id) + except BailoException as exception: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred while trying to connect to the Bailo client: {exception}", + ) def download_file(model_id: str, file_id: str, path: str | None = None) -> ResponsePath: @@ -50,28 +61,42 @@ def download_file(model_id: str, file_id: str, path: str | None = None) -> Respo :param model_id: Unique model ID :param file_id: Unique file ID - :param path: _description_ + :param path: The directory to write the downloaded file to :return: The unique file ID """ pathlib_path = parse_path(path) - # TODO: try/catch as bailo_client may be bad (e.g. auth errors) res = get_file(model_id, file_id) if not res.ok: - # TODO: properly error - raise Exception + raise HTTPException(status_code=res.status_code, detail=res.text) - if (content_disposition := res.headers.get("Content-Disposition")) is not None: - # parse to get filename + try: + # Parse to get the filename (we mainly care about the file's extension as modelscan uses that). + content_disposition = res.headers["Content-Disposition"] msg = Message() msg["content-disposition"] = content_disposition - if (filename := msg.get_filename()) is not None: + # None and empty strings both evaluate to false. + if filename := msg.get_filename(): pathlib_path = Path.joinpath(pathlib_path, str(filename)) - # TODO: else fail + else: + raise ValueError("Cannot have an empty filename") + except (ValueError, KeyError) as exception: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred while extracting the downloaded file's name.", + ) - with open(pathlib_path, "wb") as f: - for data in res.iter_content(get_settings().block_size): - f.write(data) + try: + # Write the streamed response to disk. + # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan doesn't currently support streaming. + with open(pathlib_path, "wb") as f: + for data in res.iter_content(get_settings().block_size): + f.write(data) + except OSError as exception: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred while trying to write the downloaded file to the disk: {exception}\n{type(exception)}", + ) return ResponsePath(res, pathlib_path) @@ -92,22 +117,31 @@ def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks): :param background_tasks: FastAPI object to perform background tasks once the function has already returned. :return: The model_id, file_id, and results object from modelscan. """ - # Make sure that we have the file that is being checked - # Ideally we would just get this abd pass the streamed response to modelscan, but currently modelscan only reads from files rather than in-memory objects - file_response = download_file(model_id, file_id) - if not file_response.response.ok: - # TODO: error properly - return file_response - - # Scan the downloaded file. try: + # Ideally we would just get this and pass the streamed response to modelscan, but currently modelscan only reads from files rather than in-memory objects. + file_response = download_file(model_id, file_id, get_settings().download_dir) + # No need to check the responses's status_code as download_file already does this. + + # Scan the downloaded file. result = modelscan.scan(file_response.path) - # TODO: catch and handle errors + + # Finally, return the result. + return {"model_id": model_id, "file_id": file_id, "result": result} + except HTTPException: + # Re-raise HTTPExceptions. + raise + except Exception as exception: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred: {exception}", + ) finally: - # Clean up the downloaded file as a background task to allow returning sooner. - background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True) - # Finally, return the result. - return {"model_id": model_id, "file_id": file_id, "result": result} + try: + # Clean up the downloaded file as a background task to allow returning sooner. + background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True) + except: + # file_response may not be defined if download_file failed. + pass if __name__ == "__main__":