Skip to content

Commit

Permalink
BAI-1502 change default modelscan api download_dir to be a tempdir for security
Browse files Browse the repository at this point in the history
  • Loading branch information
PE39806 committed Nov 15, 2024
1 parent 27b43ca commit 6ff92e5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 32 deletions.
7 changes: 2 additions & 5 deletions lib/modelscan_api/bailo_modelscan_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,10 @@ class Settings(BaseSettings):
You can upload files and view modelscan's result."""
app_version: str = "1.0.0"
download_dir: str = "."
# download_dir is used if it is set to a non-empty value; otherwise a temporary directory is used.
download_dir: str | None = None
modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS
block_size: int = 1024

# Load in a dotenv file to set/overwrite any properties with potentially sensitive values
model_config = SettingsConfigDict(env_file=".env")


logger.info("Instantiating settings.")
settings = Settings()
61 changes: 34 additions & 27 deletions lib/modelscan_api/bailo_modelscan_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

from __future__ import annotations

from contextlib import nullcontext
import logging
from functools import lru_cache
from http import HTTPStatus
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

import uvicorn
Expand Down Expand Up @@ -73,33 +75,37 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st
"""
logger.info("Called the API endpoint to scan an uploaded file")
try:
if in_file.filename and str(in_file.filename).strip():
pathlib_path = Path.joinpath(parse_path(get_settings().download_dir), str(in_file.filename))
else:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="An error occurred while extracting the uploaded file's name.",
)

# Write the streamed in_file to disk.
# This is a bit silly as modelscan will ultimately load this back into memory, but modelscan
# doesn't currently support streaming.
try:
with open(pathlib_path, "wb") as out_file:
while content := in_file.file.read(get_settings().block_size):
out_file.write(content)
except OSError as exception:
logger.exception("Failed writing the file to the disk.")
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"An error occurred while trying to write the uploaded file to the disk: {exception}",
) from exception

# Scan the uploaded file.
result = modelscan.scan(pathlib_path)

# Finally, return the result.
return result
# Use Settings' download_dir if it is set; otherwise use a temporary directory.
with (
TemporaryDirectory() if not get_settings().download_dir else nullcontext(get_settings().download_dir)
) as download_dir:
if in_file.filename and str(in_file.filename).strip():
pathlib_path = Path.joinpath(parse_path(download_dir), str(in_file.filename))
else:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="An error occurred while extracting the uploaded file's name.",
)

# Write the streamed in_file to disk.
# This is a bit silly as modelscan will ultimately load this back into memory, but modelscan
# doesn't currently support streaming.
try:
with open(pathlib_path, "wb") as out_file:
while content := in_file.file.read(get_settings().block_size):
out_file.write(content)
except OSError as exception:
logger.exception("Failed writing the file to the disk.")
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"An error occurred while trying to write the uploaded file to the disk: {exception}",
) from exception

# Scan the uploaded file.
result = modelscan.scan(pathlib_path)

# Finally, return the result.
return result

except HTTPException:
# Re-raise HTTPExceptions.
Expand All @@ -116,6 +122,7 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st
finally:
try:
# Clean up the downloaded file as a background task to allow returning sooner.
# If a temporary directory is used, the file is removed along with it anyway, but if Settings' download_dir is set then this explicit cleanup is required.
logger.info("Cleaning up downloaded file.")
background_tasks.add_task(Path.unlink, pathlib_path, missing_ok=True)
except UnboundLocalError:
Expand Down
8 changes: 8 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient

from .config import Settings
from .dependencies import parse_path
from .main import app, get_settings

client = TestClient(app)


def get_settings_override():
    """Return the Settings instance used by the tests in place of the default.

    The returned settings pin ``download_dir`` to ``"."`` (the current
    working directory); every other field keeps its default value.
    """
    test_settings = Settings(download_dir=".")
    return test_settings


# Register the test settings factory as the replacement for get_settings on the app.
# NOTE(review): dependency_overrides only affects dependencies resolved through
# Depends(); the visible scan_file body calls get_settings() directly, so confirm
# that this override actually reaches that code path.
app.dependency_overrides[get_settings] = get_settings_override


def test_health():
response = client.get("/health")

Expand Down

0 comments on commit 6ff92e5

Please sign in to comment.