diff --git a/lib/modelscan_api/bailo_modelscan_api/config.py b/lib/modelscan_api/bailo_modelscan_api/config.py index c29d155cb..608da39d9 100644 --- a/lib/modelscan_api/bailo_modelscan_api/config.py +++ b/lib/modelscan_api/bailo_modelscan_api/config.py @@ -25,13 +25,10 @@ class Settings(BaseSettings): You can upload files and view modelscan's result.""" app_version: str = "1.0.0" - download_dir: str = "." + # download_dir is used if it is set (truthy); otherwise a temporary directory is used. + download_dir: str | None = None modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS block_size: int = 1024 # Load in a dotenv file to set/overwrite any properties with potentially sensitive values model_config = SettingsConfigDict(env_file=".env") - - -logger.info("Instantiating settings.") -settings = Settings() diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 68e76e908..929484a84 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -3,10 +3,12 @@ from __future__ import annotations +from contextlib import nullcontext import logging from functools import lru_cache from http import HTTPStatus from pathlib import Path +from tempfile import TemporaryDirectory from typing import Any import uvicorn @@ -73,33 +75,37 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st """ logger.info("Called the API endpoint to scan an uploaded file") try: - if in_file.filename and str(in_file.filename).strip(): - pathlib_path = Path.joinpath(parse_path(get_settings().download_dir), str(in_file.filename)) - else: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="An error occurred while extracting the uploaded file's name.", - ) - - # Write the streamed in_file to disk. - # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan - # doesn't currently support streaming. 
- try: - with open(pathlib_path, "wb") as out_file: - while content := in_file.file.read(get_settings().block_size): - out_file.write(content) - except OSError as exception: - logger.exception("Failed writing the file to the disk.") - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"An error occurred while trying to write the uploaded file to the disk: {exception}", - ) from exception - - # Scan the uploaded file. - result = modelscan.scan(pathlib_path) - - # Finally, return the result. - return result + # Use Settings' download_dir if it is set; otherwise use a temporary directory. + with ( + TemporaryDirectory() if not get_settings().download_dir else nullcontext(get_settings().download_dir) + ) as download_dir: + if in_file.filename and str(in_file.filename).strip(): + pathlib_path = Path.joinpath(parse_path(download_dir), str(in_file.filename)) + else: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="An error occurred while extracting the uploaded file's name.", + ) + + # Write the streamed in_file to disk. + # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan + # doesn't currently support streaming. + try: + with open(pathlib_path, "wb") as out_file: + while content := in_file.file.read(get_settings().block_size): + out_file.write(content) + except OSError as exception: + logger.exception("Failed writing the file to the disk.") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred while trying to write the uploaded file to the disk: {exception}", + ) from exception + + # Scan the uploaded file. + result = modelscan.scan(pathlib_path) + + # Finally, return the result. + return result except HTTPException: # Re-raise HTTPExceptions. @@ -116,6 +122,7 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st finally: try: # Clean up the downloaded file as a background task to allow returning sooner. 
+ # If using a temporary dir then this would happen anyway, but if Settings' download_dir is set then this cleanup is required. logger.info("Cleaning up downloaded file.") background_tasks.add_task(Path.unlink, pathlib_path, missing_ok=True) except UnboundLocalError: diff --git a/lib/modelscan_api/bailo_modelscan_api/test_main.py b/lib/modelscan_api/bailo_modelscan_api/test_main.py index 4843c444f..590e95c36 100644 --- a/lib/modelscan_api/bailo_modelscan_api/test_main.py +++ b/lib/modelscan_api/bailo_modelscan_api/test_main.py @@ -5,12 +5,20 @@ from unittest.mock import Mock, patch from fastapi.testclient import TestClient +from .config import Settings from .dependencies import parse_path from .main import app, get_settings client = TestClient(app) +def get_settings_override(): + return Settings(download_dir=".") + + +app.dependency_overrides[get_settings] = get_settings_override + + def test_health(): response = client.get("/health")