Skip to content

Commit

Permalink
BAI-1500 add error checking to modelscan API
Browse files Browse the repository at this point in the history
  • Loading branch information
PE39806 committed Nov 8, 2024
1 parent f3f6e26 commit 07a88dc
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 27 deletions.
5 changes: 4 additions & 1 deletion lib/modelscan_api/bailo_modelscan_api/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Configuration settings for FastAPI app.
"""

from typing import Any

from modelscan.settings import DEFAULT_SETTINGS
Expand All @@ -11,7 +14,7 @@ class Settings(BaseSettings):
"""

app_name: str = "Bailo ModelScan API"
download_dir: str = "." # TODO: use this
download_dir: str = "."
modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS
block_size: int = 1024
bailo_client_url: str = "http://localhost:8080/"
Expand Down
8 changes: 8 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""Common utilities used by the FastAPI app.
"""

from pathlib import Path
from typing import Union
from requests import Response
Expand All @@ -11,6 +14,11 @@ def __init__(self, response: Response, path: Path) -> None:


def parse_path(path: Union[str, Path, None]) -> Path:
"""Ensure that a path is consistently represented as a Path.
:param path: System path to parse. Defaults to the file's current working directory if unspecified.
:return: An absolute Path representation of the path parameter.
"""
if path is None:
path = "."
return Path().cwd() if path == "." else Path(path).absolute()
86 changes: 60 additions & 26 deletions lib/modelscan_api/bailo_modelscan_api/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""FastAPI app.
"""

from email.message import Message
from functools import lru_cache
from http import HTTPStatus
from pathlib import Path

from bailo import Client
from fastapi import BackgroundTasks, FastAPI
from bailo.core.exceptions import BailoException
from fastapi import BackgroundTasks, FastAPI, HTTPException
from modelscan.modelscan import ModelScan
from requests import Response
import uvicorn
Expand Down Expand Up @@ -42,36 +47,56 @@ def get_file(model_id: str, file_id: str) -> Response:
:param file_id: Unique file ID
:return: The unique file ID
"""
return bailo_client.get_download_file(model_id, file_id)
try:
return bailo_client.get_download_file(model_id, file_id)
except BailoException as exception:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"An error occurred while trying to connect to the Bailo client: {exception}",
)


def download_file(model_id: str, file_id: str, path: str | None = None) -> ResponsePath:
"""Get and download a specific file by its id.
:param model_id: Unique model ID
:param file_id: Unique file ID
:param path: _description_
:param path: The directory to write the downloaded file to
:return: The unique file ID
"""
pathlib_path = parse_path(path)

# TODO: try/catch as bailo_client may be bad (e.g. auth errors)
res = get_file(model_id, file_id)
if not res.ok:
# TODO: properly error
raise Exception
raise HTTPException(status_code=res.status_code, detail=res.text)

if (content_disposition := res.headers.get("Content-Disposition")) is not None:
# parse to get filename
try:
# Parse to get the filename (we mainly care about the file's extension as modelscan uses that).
content_disposition = res.headers["Content-Disposition"]
msg = Message()
msg["content-disposition"] = content_disposition
if (filename := msg.get_filename()) is not None:
# None and empty strings both evaluate to false.
if filename := msg.get_filename():
pathlib_path = Path.joinpath(pathlib_path, str(filename))
# TODO: else fail
else:
raise ValueError("Cannot have an empty filename")
except (ValueError, KeyError) as exception:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"An error occurred while extracting the downloaded file's name.",
)

with open(pathlib_path, "wb") as f:
for data in res.iter_content(get_settings().block_size):
f.write(data)
try:
# Write the streamed response to disk.
# This is a bit silly as modelscan will ultimately load this back into memory, but modelscan doesn't currently support streaming.
with open(pathlib_path, "wb") as f:
for data in res.iter_content(get_settings().block_size):
f.write(data)
except OSError as exception:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"An error occurred while trying to write the downloaded file to the disk: {exception}\n{type(exception)}",
)

return ResponsePath(res, pathlib_path)

Expand All @@ -92,22 +117,31 @@ def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks):
:param background_tasks: FastAPI object to perform background tasks once the function has already returned.
:return: The model_id, file_id, and results object from modelscan.
"""
# Make sure that we have the file that is being checked
# Ideally we would just get this abd pass the streamed response to modelscan, but currently modelscan only reads from files rather than in-memory objects
file_response = download_file(model_id, file_id)
if not file_response.response.ok:
# TODO: error properly
return file_response

# Scan the downloaded file.
try:
# Ideally we would just get this and pass the streamed response to modelscan, but currently modelscan only reads from files rather than in-memory objects.
file_response = download_file(model_id, file_id, get_settings().download_dir)
# No need to check the responses's status_code as download_file already does this.

# Scan the downloaded file.
result = modelscan.scan(file_response.path)
# TODO: catch and handle errors

# Finally, return the result.
return {"model_id": model_id, "file_id": file_id, "result": result}
except HTTPException:
# Re-raise HTTPExceptions.
raise
except Exception as exception:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"An error occurred: {exception}",
)
finally:
# Clean up the downloaded file as a background task to allow returning sooner.
background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True)
# Finally, return the result.
return {"model_id": model_id, "file_id": file_id, "result": result}
try:
# Clean up the downloaded file as a background task to allow returning sooner.
background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True)
except:
# file_response may not be defined if download_file failed.
pass


if __name__ == "__main__":
Expand Down

0 comments on commit 07a88dc

Please sign in to comment.