From f3f6e266d713ff455c73e44b990370aa8952e211 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Thu, 7 Nov 2024 16:10:26 +0000 Subject: [PATCH 1/7] BAI-1500 create initial modelscan Rest API --- lib/modelscan_api/.gitignore | 10 ++ lib/modelscan_api/README.md | 35 ++++++ .../bailo_modelscan_api/__init__.py | 0 .../bailo_modelscan_api/config.py | 22 ++++ .../bailo_modelscan_api/dependencies.py | 16 +++ lib/modelscan_api/bailo_modelscan_api/main.py | 115 ++++++++++++++++++ lib/modelscan_api/requirements.txt | 5 + 7 files changed, 203 insertions(+) create mode 100644 lib/modelscan_api/.gitignore create mode 100644 lib/modelscan_api/README.md create mode 100644 lib/modelscan_api/bailo_modelscan_api/__init__.py create mode 100644 lib/modelscan_api/bailo_modelscan_api/config.py create mode 100644 lib/modelscan_api/bailo_modelscan_api/dependencies.py create mode 100644 lib/modelscan_api/bailo_modelscan_api/main.py create mode 100644 lib/modelscan_api/requirements.txt diff --git a/lib/modelscan_api/.gitignore b/lib/modelscan_api/.gitignore new file mode 100644 index 000000000..7e8e8134e --- /dev/null +++ b/lib/modelscan_api/.gitignore @@ -0,0 +1,10 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Environments +.env +*env/ +*ENV/ +*env.bak/ diff --git a/lib/modelscan_api/README.md b/lib/modelscan_api/README.md new file mode 100644 index 000000000..a226da7e1 --- /dev/null +++ b/lib/modelscan_api/README.md @@ -0,0 +1,35 @@ +# ModelScan + +This directory provides all of the necessary functionality to interact with +[modelscan](https://github.com/protectai/modelscan/tree/main) as an API. + +> ModelScan is an open source project from +> [Protect AI](https://protectai.com/?utm_campaign=Homepage&utm_source=ModelScan%20GitHub%20Page&utm_medium=cta&utm_content=Open%20Source) +> that scans models to determine if they contain unsafe code. 
It is the first model scanning tool to support multiple +> model formats. ModelScan currently supports: H5, Pickle, and SavedModel formats. This protects you when using PyTorch, +> TensorFlow, Keras, Sklearn, XGBoost, with more on the way. + +## Setup + +Create and activate a virtual environment + +```bash +python3 -m venv modelscan-venv +source modelscan-venv/bin/activate +``` + +Install the required pip packages + +```bash +pip install -r requirements.txt +``` + +## Usage + +Create and populate a `.env` file to override and set any variables, including sensitive properties. + +Run: + +```bash +fastapi dev bailo_modelscan_api/main.py +``` diff --git a/lib/modelscan_api/bailo_modelscan_api/__init__.py b/lib/modelscan_api/bailo_modelscan_api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/modelscan_api/bailo_modelscan_api/config.py b/lib/modelscan_api/bailo_modelscan_api/config.py new file mode 100644 index 000000000..38d5b640a --- /dev/null +++ b/lib/modelscan_api/bailo_modelscan_api/config.py @@ -0,0 +1,22 @@ +from typing import Any + +from modelscan.settings import DEFAULT_SETTINGS +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Basic settings object for the FastAPI app. + + :param BaseSettings: Default template object. + """ + + app_name: str = "Bailo ModelScan API" + download_dir: str = "." 
# TODO: use this + modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS + block_size: int = 1024 + bailo_client_url: str = "http://localhost:8080/" + + model_config = SettingsConfigDict(env_file=".env") + + +settings = Settings() diff --git a/lib/modelscan_api/bailo_modelscan_api/dependencies.py b/lib/modelscan_api/bailo_modelscan_api/dependencies.py new file mode 100644 index 000000000..c5fa42ed4 --- /dev/null +++ b/lib/modelscan_api/bailo_modelscan_api/dependencies.py @@ -0,0 +1,16 @@ +from pathlib import Path +from typing import Union +from requests import Response + + +class ResponsePath: + + def __init__(self, response: Response, path: Path) -> None: + self.response = response + self.path = path + + +def parse_path(path: Union[str, Path, None]) -> Path: + if path is None: + path = "." + return Path().cwd() if path == "." else Path(path).absolute() diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py new file mode 100644 index 000000000..8407710d4 --- /dev/null +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -0,0 +1,115 @@ +from email.message import Message +from functools import lru_cache +from pathlib import Path + +from bailo import Client +from fastapi import BackgroundTasks, FastAPI +from modelscan.modelscan import ModelScan +from requests import Response +import uvicorn + +from bailo_modelscan_api.dependencies import ResponsePath, parse_path +from bailo_modelscan_api.config import Settings + + +# Instantiate FastAPI app with various dependencies. +app = FastAPI() + + +@lru_cache +def get_settings() -> Settings: + """Fast way to only load settings from dotenv once. + + :return: Evaluated Settings from config file. + """ + return Settings() + + +# Instantiating the PkiAgent(), if using. 
+# agent = PkiAgent(cert='', key='', auth='') + +# Instantiating the Bailo client +bailo_client = Client(get_settings().bailo_client_url) + +# Instantiating ModelScan +modelscan = ModelScan(settings=get_settings().modelscan_settings) + + +def get_file(model_id: str, file_id: str) -> Response: + """Get a specific file by its id. + + :param model_id: Unique model ID + :param file_id: Unique file ID + :return: The unique file ID + """ + return bailo_client.get_download_file(model_id, file_id) + + +def download_file(model_id: str, file_id: str, path: str | None = None) -> ResponsePath: + """Get and download a specific file by its id. + + :param model_id: Unique model ID + :param file_id: Unique file ID + :param path: _description_ + :return: The unique file ID + """ + pathlib_path = parse_path(path) + + # TODO: try/catch as bailo_client may be bad (e.g. auth errors) + res = get_file(model_id, file_id) + if not res.ok: + # TODO: properly error + raise Exception + + if (content_disposition := res.headers.get("Content-Disposition")) is not None: + # parse to get filename + msg = Message() + msg["content-disposition"] = content_disposition + if (filename := msg.get_filename()) is not None: + pathlib_path = Path.joinpath(pathlib_path, str(filename)) + # TODO: else fail + + with open(pathlib_path, "wb") as f: + for data in res.iter_content(get_settings().block_size): + f.write(data) + + return ResponsePath(res, pathlib_path) + + +# TODO: don't keep this, but it is useful for testing things work +@app.get("/") +async def read_root(): + return {"message": "Hello world!"} + + +# TODO: define return schema +@app.get("/scan/{model_id}/{file_id}") +def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks): + """Scan the specific file for a given model. + + :param model_id: Unique model ID + :param file_id: Unique file ID + :param background_tasks: FastAPI object to perform background tasks once the function has already returned. 
+ :return: The model_id, file_id, and results object from modelscan. + """ + # Make sure that we have the file that is being checked + # Ideally we would just get this abd pass the streamed response to modelscan, but currently modelscan only reads from files rather than in-memory objects + file_response = download_file(model_id, file_id) + if not file_response.response.ok: + # TODO: error properly + return file_response + + # Scan the downloaded file. + try: + result = modelscan.scan(file_response.path) + # TODO: catch and handle errors + finally: + # Clean up the downloaded file as a background task to allow returning sooner. + background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True) + # Finally, return the result. + return {"model_id": model_id, "file_id": file_id, "result": result} + + +if __name__ == "__main__": + # Start the app programmatically. + uvicorn.run(app) diff --git a/lib/modelscan_api/requirements.txt b/lib/modelscan_api/requirements.txt new file mode 100644 index 000000000..9fbfb70af --- /dev/null +++ b/lib/modelscan_api/requirements.txt @@ -0,0 +1,5 @@ +bailo==2.5.0 +fastapi[standard]==0.115.4 +modelscan==0.8.1 +pydantic_settings==2.6.1 +uvicorn==0.32.0 From 07a88dc1423ebefb6cf97e1f5894e539c0c51347 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Fri, 8 Nov 2024 10:51:23 +0000 Subject: [PATCH 2/7] BAI-1500 add error checking to modelscan API --- .../bailo_modelscan_api/config.py | 5 +- .../bailo_modelscan_api/dependencies.py | 8 ++ lib/modelscan_api/bailo_modelscan_api/main.py | 86 +++++++++++++------ 3 files changed, 72 insertions(+), 27 deletions(-) diff --git a/lib/modelscan_api/bailo_modelscan_api/config.py b/lib/modelscan_api/bailo_modelscan_api/config.py index 38d5b640a..5b75aaeca 100644 --- a/lib/modelscan_api/bailo_modelscan_api/config.py +++ b/lib/modelscan_api/bailo_modelscan_api/config.py @@ -1,3 +1,6 @@ +"""Configuration settings for FastAPI app. 
+""" + from typing import Any from modelscan.settings import DEFAULT_SETTINGS @@ -11,7 +14,7 @@ class Settings(BaseSettings): """ app_name: str = "Bailo ModelScan API" - download_dir: str = "." # TODO: use this + download_dir: str = "." modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS block_size: int = 1024 bailo_client_url: str = "http://localhost:8080/" diff --git a/lib/modelscan_api/bailo_modelscan_api/dependencies.py b/lib/modelscan_api/bailo_modelscan_api/dependencies.py index c5fa42ed4..0d3864ccb 100644 --- a/lib/modelscan_api/bailo_modelscan_api/dependencies.py +++ b/lib/modelscan_api/bailo_modelscan_api/dependencies.py @@ -1,3 +1,6 @@ +"""Common utilities used by the FastAPI app. +""" + from pathlib import Path from typing import Union from requests import Response @@ -11,6 +14,11 @@ def __init__(self, response: Response, path: Path) -> None: def parse_path(path: Union[str, Path, None]) -> Path: + """Ensure that a path is consistently represented as a Path. + + :param path: System path to parse. Defaults to the file's current working directory if unspecified. + :return: An absolute Path representation of the path parameter. + """ if path is None: path = "." return Path().cwd() if path == "." else Path(path).absolute() diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 8407710d4..477ac07ff 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -1,9 +1,14 @@ +"""FastAPI app. 
+""" + from email.message import Message from functools import lru_cache +from http import HTTPStatus from pathlib import Path from bailo import Client -from fastapi import BackgroundTasks, FastAPI +from bailo.core.exceptions import BailoException +from fastapi import BackgroundTasks, FastAPI, HTTPException from modelscan.modelscan import ModelScan from requests import Response import uvicorn @@ -42,7 +47,13 @@ def get_file(model_id: str, file_id: str) -> Response: :param file_id: Unique file ID :return: The unique file ID """ - return bailo_client.get_download_file(model_id, file_id) + try: + return bailo_client.get_download_file(model_id, file_id) + except BailoException as exception: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred while trying to connect to the Bailo client: {exception}", + ) def download_file(model_id: str, file_id: str, path: str | None = None) -> ResponsePath: @@ -50,28 +61,42 @@ def download_file(model_id: str, file_id: str, path: str | None = None) -> Respo :param model_id: Unique model ID :param file_id: Unique file ID - :param path: _description_ + :param path: The directory to write the downloaded file to :return: The unique file ID """ pathlib_path = parse_path(path) - # TODO: try/catch as bailo_client may be bad (e.g. auth errors) res = get_file(model_id, file_id) if not res.ok: - # TODO: properly error - raise Exception + raise HTTPException(status_code=res.status_code, detail=res.text) - if (content_disposition := res.headers.get("Content-Disposition")) is not None: - # parse to get filename + try: + # Parse to get the filename (we mainly care about the file's extension as modelscan uses that). + content_disposition = res.headers["Content-Disposition"] msg = Message() msg["content-disposition"] = content_disposition - if (filename := msg.get_filename()) is not None: + # None and empty strings both evaluate to false. 
+ if filename := msg.get_filename(): pathlib_path = Path.joinpath(pathlib_path, str(filename)) - # TODO: else fail + else: + raise ValueError("Cannot have an empty filename") + except (ValueError, KeyError) as exception: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred while extracting the downloaded file's name.", + ) - with open(pathlib_path, "wb") as f: - for data in res.iter_content(get_settings().block_size): - f.write(data) + try: + # Write the streamed response to disk. + # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan doesn't currently support streaming. + with open(pathlib_path, "wb") as f: + for data in res.iter_content(get_settings().block_size): + f.write(data) + except OSError as exception: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred while trying to write the downloaded file to the disk: {exception}\n{type(exception)}", + ) return ResponsePath(res, pathlib_path) @@ -92,22 +117,31 @@ def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks): :param background_tasks: FastAPI object to perform background tasks once the function has already returned. :return: The model_id, file_id, and results object from modelscan. """ - # Make sure that we have the file that is being checked - # Ideally we would just get this abd pass the streamed response to modelscan, but currently modelscan only reads from files rather than in-memory objects - file_response = download_file(model_id, file_id) - if not file_response.response.ok: - # TODO: error properly - return file_response - - # Scan the downloaded file. try: + # Ideally we would just get this and pass the streamed response to modelscan, but currently modelscan only reads from files rather than in-memory objects. 
+ file_response = download_file(model_id, file_id, get_settings().download_dir) + # No need to check the responses's status_code as download_file already does this. + + # Scan the downloaded file. result = modelscan.scan(file_response.path) - # TODO: catch and handle errors + + # Finally, return the result. + return {"model_id": model_id, "file_id": file_id, "result": result} + except HTTPException: + # Re-raise HTTPExceptions. + raise + except Exception as exception: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred: {exception}", + ) finally: - # Clean up the downloaded file as a background task to allow returning sooner. - background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True) - # Finally, return the result. - return {"model_id": model_id, "file_id": file_id, "result": result} + try: + # Clean up the downloaded file as a background task to allow returning sooner. + background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True) + except: + # file_response may not be defined if download_file failed. 
+ pass if __name__ == "__main__": From 49f7111570ae77b1317e658cc326ea77f27c0114 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Mon, 11 Nov 2024 09:59:22 +0000 Subject: [PATCH 3/7] BAI-1500 add logging and pre-commit for development --- lib/modelscan_api/.pre-commit-config.yaml | 60 +++++++++++++++++++ lib/modelscan_api/README.md | 15 +++++ .../bailo_modelscan_api/__init__.py | 5 ++ .../bailo_modelscan_api/config.py | 6 ++ .../bailo_modelscan_api/dependencies.py | 11 +++- lib/modelscan_api/bailo_modelscan_api/main.py | 37 +++++++----- 6 files changed, 117 insertions(+), 17 deletions(-) create mode 100644 lib/modelscan_api/.pre-commit-config.yaml diff --git a/lib/modelscan_api/.pre-commit-config.yaml b/lib/modelscan_api/.pre-commit-config.yaml new file mode 100644 index 000000000..d957f567f --- /dev/null +++ b/lib/modelscan_api/.pre-commit-config.yaml @@ -0,0 +1,60 @@ +ci: + autoupdate_commit_msg: 'chore: update pre-commit hooks' + autofix_commit_msg: 'style: pre-commit fixes' + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + - id: requirements-txt-fixer + - id: trailing-whitespace + + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ['-a', 'from __future__ import annotations'] + + - repo: https://github.com/asottile/pyupgrade + rev: v3.15.0 + hooks: + - id: pyupgrade + args: [--py37-plus] + + - repo: https://github.com/hadialqattan/pycln + rev: v2.4.0 + hooks: + - id: pycln + args: [--config=pyproject.toml] + stages: [manual] + + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + + - repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.10.0 + hooks: + - id: python-check-blanket-noqa + - id: 
python-check-blanket-type-ignore + - id: python-no-log-warn + - id: python-no-eval + - id: python-use-type-annotations + - id: rst-backticks + - id: rst-directive-colons + - id: rst-inline-touching-normal + + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.11.0 + hooks: + - id: black + args: [--line-length=120] diff --git a/lib/modelscan_api/README.md b/lib/modelscan_api/README.md index a226da7e1..d500f3636 100644 --- a/lib/modelscan_api/README.md +++ b/lib/modelscan_api/README.md @@ -33,3 +33,18 @@ Run: ```bash fastapi dev bailo_modelscan_api/main.py ``` + +Connect via the local endpoint (development only): `http://127.0.0.1:8000` + +View the swagger docs: `http://127.0.0.1:8000/docs` + +## Development + +### Install and add pre-commit + +If already working on Bailo you may be prompted to overwrite Husky. Follow the instructions given by Git CLI. + +```bash +pip install pre-commit +pre-commit install +``` diff --git a/lib/modelscan_api/bailo_modelscan_api/__init__.py b/lib/modelscan_api/bailo_modelscan_api/__init__.py index e69de29bb..77453f6de 100644 --- a/lib/modelscan_api/bailo_modelscan_api/__init__.py +++ b/lib/modelscan_api/bailo_modelscan_api/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +import logging + +logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/lib/modelscan_api/bailo_modelscan_api/config.py b/lib/modelscan_api/bailo_modelscan_api/config.py index 5b75aaeca..0e487e23e 100644 --- a/lib/modelscan_api/bailo_modelscan_api/config.py +++ b/lib/modelscan_api/bailo_modelscan_api/config.py @@ -1,11 +1,16 @@ """Configuration settings for FastAPI app. """ +from __future__ import annotations + +import logging from typing import Any from modelscan.settings import DEFAULT_SETTINGS from pydantic_settings import BaseSettings, SettingsConfigDict +logger = logging.getLogger(__name__) + class Settings(BaseSettings): """Basic settings object for the FastAPI app. 
@@ -22,4 +27,5 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_file=".env") +logger.info("Instantiating settings.") settings = Settings() diff --git a/lib/modelscan_api/bailo_modelscan_api/dependencies.py b/lib/modelscan_api/bailo_modelscan_api/dependencies.py index 0d3864ccb..a7875d87d 100644 --- a/lib/modelscan_api/bailo_modelscan_api/dependencies.py +++ b/lib/modelscan_api/bailo_modelscan_api/dependencies.py @@ -1,24 +1,29 @@ """Common utilities used by the FastAPI app. """ +from __future__ import annotations + +import logging from pathlib import Path -from typing import Union + from requests import Response +logger = logging.getLogger(__name__) -class ResponsePath: +class ResponsePath: def __init__(self, response: Response, path: Path) -> None: self.response = response self.path = path -def parse_path(path: Union[str, Path, None]) -> Path: +def parse_path(path: str | Path | None) -> Path: """Ensure that a path is consistently represented as a Path. :param path: System path to parse. Defaults to the file's current working directory if unspecified. :return: An absolute Path representation of the path parameter. """ + logger.info("Parsing path.") if path is None: path = "." return Path().cwd() if path == "." else Path(path).absolute() diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 477ac07ff..0396be316 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -1,21 +1,24 @@ """FastAPI app. 
""" +from __future__ import annotations + +import logging from email.message import Message from functools import lru_cache from http import HTTPStatus from pathlib import Path +import uvicorn from bailo import Client from bailo.core.exceptions import BailoException +from bailo_modelscan_api.config import Settings +from bailo_modelscan_api.dependencies import ResponsePath, parse_path from fastapi import BackgroundTasks, FastAPI, HTTPException from modelscan.modelscan import ModelScan from requests import Response -import uvicorn - -from bailo_modelscan_api.dependencies import ResponsePath, parse_path -from bailo_modelscan_api.config import Settings +logger = logging.getLogger(__name__) # Instantiate FastAPI app with various dependencies. app = FastAPI() @@ -47,9 +50,11 @@ def get_file(model_id: str, file_id: str) -> Response: :param file_id: Unique file ID :return: The unique file ID """ + logger.info("Fetching specified file from the bailo client.") try: return bailo_client.get_download_file(model_id, file_id) except BailoException as exception: + logger.exception("Failed to get the specified file from the bailo client.") raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"An error occurred while trying to connect to the Bailo client: {exception}", @@ -64,10 +69,12 @@ def download_file(model_id: str, file_id: str, path: str | None = None) -> Respo :param path: The directory to write the downloaded file to :return: The unique file ID """ + logger.info("Downloading file from bailo client.") pathlib_path = parse_path(path) res = get_file(model_id, file_id) if not res.ok: + logger.exception('The bailo client did not return an "ok" response.') raise HTTPException(status_code=res.status_code, detail=res.text) try: @@ -81,6 +88,7 @@ def download_file(model_id: str, file_id: str, path: str | None = None) -> Respo else: raise ValueError("Cannot have an empty filename") except (ValueError, KeyError) as exception: + logger.exception("Failed to extract 
key information.") raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"An error occurred while extracting the downloaded file's name.", @@ -88,11 +96,13 @@ def download_file(model_id: str, file_id: str, path: str | None = None) -> Respo try: # Write the streamed response to disk. - # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan doesn't currently support streaming. + # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan + # doesn't currently support streaming. with open(pathlib_path, "wb") as f: for data in res.iter_content(get_settings().block_size): f.write(data) except OSError as exception: + logger.exception("Failed writing the file to the disk.") raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"An error occurred while trying to write the downloaded file to the disk: {exception}\n{type(exception)}", @@ -101,12 +111,6 @@ def download_file(model_id: str, file_id: str, path: str | None = None) -> Respo return ResponsePath(res, pathlib_path) -# TODO: don't keep this, but it is useful for testing things work -@app.get("/") -async def read_root(): - return {"message": "Hello world!"} - - # TODO: define return schema @app.get("/scan/{model_id}/{file_id}") def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks): @@ -117,8 +121,10 @@ def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks): :param background_tasks: FastAPI object to perform background tasks once the function has already returned. :return: The model_id, file_id, and results object from modelscan. """ + logger.info("Called the API endpoint to scan a specific file") try: - # Ideally we would just get this and pass the streamed response to modelscan, but currently modelscan only reads from files rather than in-memory objects. 
+ # Ideally we would just get this and pass the streamed response to modelscan, but currently modelscan + # only reads from files rather than in-memory objects. file_response = download_file(model_id, file_id, get_settings().download_dir) # No need to check the responses's status_code as download_file already does this. @@ -129,8 +135,10 @@ def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks): return {"model_id": model_id, "file_id": file_id, "result": result} except HTTPException: # Re-raise HTTPExceptions. + logger.exception("Re-raising HTTPException.") raise except Exception as exception: + logger.exception("An unexpected error occurred.") raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"An error occurred: {exception}", @@ -138,12 +146,13 @@ def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks): finally: try: # Clean up the downloaded file as a background task to allow returning sooner. + logger.info("Cleaning up downloaded file.") background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True) - except: + except Exception: # file_response may not be defined if download_file failed. pass if __name__ == "__main__": - # Start the app programmatically. 
+ logger.info("Starting the application programmatically.") uvicorn.run(app) From 0cf1f1eb61f5874fca19750f6f45c566ea09af03 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:05:30 +0000 Subject: [PATCH 4/7] BAI-1500 rewrite endpoint to accept an uploaded file rather than do an additional API call --- .../bailo_modelscan_api/config.py | 8 +- .../bailo_modelscan_api/dependencies.py | 8 - lib/modelscan_api/bailo_modelscan_api/main.py | 151 ++++++------------ lib/modelscan_api/requirements.txt | 3 +- 4 files changed, 60 insertions(+), 110 deletions(-) diff --git a/lib/modelscan_api/bailo_modelscan_api/config.py b/lib/modelscan_api/bailo_modelscan_api/config.py index 0e487e23e..c29d155cb 100644 --- a/lib/modelscan_api/bailo_modelscan_api/config.py +++ b/lib/modelscan_api/bailo_modelscan_api/config.py @@ -19,11 +19,17 @@ class Settings(BaseSettings): """ app_name: str = "Bailo ModelScan API" + app_summary: str = "REST API wrapper for ModelScan package for use with Bailo." + app_description: str = """ + Bailo ModelScan API allows for easy programmatic interfacing with ProtectAI's ModelScan package to scan and detect potential threats within files stored in Bailo. + + You can upload files and view modelscan's result.""" + app_version: str = "1.0.0" download_dir: str = "." 
modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS block_size: int = 1024 - bailo_client_url: str = "http://localhost:8080/" + # Load in a dotenv file to set/overwrite any properties with potentially sensitive values model_config = SettingsConfigDict(env_file=".env") diff --git a/lib/modelscan_api/bailo_modelscan_api/dependencies.py b/lib/modelscan_api/bailo_modelscan_api/dependencies.py index a7875d87d..b8ee898ff 100644 --- a/lib/modelscan_api/bailo_modelscan_api/dependencies.py +++ b/lib/modelscan_api/bailo_modelscan_api/dependencies.py @@ -6,17 +6,9 @@ import logging from pathlib import Path -from requests import Response - logger = logging.getLogger(__name__) -class ResponsePath: - def __init__(self, response: Response, path: Path) -> None: - self.response = response - self.path = path - - def parse_path(path: str | Path | None) -> Path: """Ensure that a path is consistently represented as a Path. diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 0396be316..9610b25b5 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -4,24 +4,19 @@ from __future__ import annotations import logging -from email.message import Message from functools import lru_cache from http import HTTPStatus from pathlib import Path +from typing import Any import uvicorn -from bailo import Client -from bailo.core.exceptions import BailoException from bailo_modelscan_api.config import Settings -from bailo_modelscan_api.dependencies import ResponsePath, parse_path -from fastapi import BackgroundTasks, FastAPI, HTTPException +from bailo_modelscan_api.dependencies import parse_path +from fastapi import BackgroundTasks, FastAPI, HTTPException, UploadFile from modelscan.modelscan import ModelScan -from requests import Response -logger = logging.getLogger(__name__) -# Instantiate FastAPI app with various dependencies. 
-app = FastAPI() +logger = logging.getLogger(__name__) @lru_cache @@ -33,124 +28,82 @@ def get_settings() -> Settings: return Settings() -# Instantiating the PkiAgent(), if using. -# agent = PkiAgent(cert='', key='', auth='') - -# Instantiating the Bailo client -bailo_client = Client(get_settings().bailo_client_url) +# Instantiate FastAPI app with various dependencies. +app = FastAPI( + title=get_settings().app_name, + summary=get_settings().app_summary, + description=get_settings().app_description, + version=get_settings().app_version, +) # Instantiating ModelScan modelscan = ModelScan(settings=get_settings().modelscan_settings) -def get_file(model_id: str, file_id: str) -> Response: - """Get a specific file by its id. - - :param model_id: Unique model ID - :param file_id: Unique file ID - :return: The unique file ID - """ - logger.info("Fetching specified file from the bailo client.") - try: - return bailo_client.get_download_file(model_id, file_id) - except BailoException as exception: - logger.exception("Failed to get the specified file from the bailo client.") - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"An error occurred while trying to connect to the Bailo client: {exception}", - ) - - -def download_file(model_id: str, file_id: str, path: str | None = None) -> ResponsePath: - """Get and download a specific file by its id. +@app.post( + "/scan/file", + summary="Upload and scan a file", + description="Upload a file which is scanned by ModelScan and return the result", + status_code=HTTPStatus.OK, + response_description="The result from ModelScan", +) +def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[str, Any]: + """API endpoint to upload and scan a file using modelscan. 
- :param model_id: Unique model ID - :param file_id: Unique file ID - :param path: The directory to write the downloaded file to - :return: The unique file ID + :param in_file: uploaded file to be scanned + :param background_tasks: FastAPI object to perform background tasks once the function has already returned. + :raises HTTPException: failure to process the uploaded file in any way + :return: `modelscan.scan` results """ - logger.info("Downloading file from bailo client.") - pathlib_path = parse_path(path) - - res = get_file(model_id, file_id) - if not res.ok: - logger.exception('The bailo client did not return an "ok" response.') - raise HTTPException(status_code=res.status_code, detail=res.text) - + logger.info("Called the API endpoint to scan an uploaded file") try: - # Parse to get the filename (we mainly care about the file's extension as modelscan uses that). - content_disposition = res.headers["Content-Disposition"] - msg = Message() - msg["content-disposition"] = content_disposition - # None and empty strings both evaluate to false. - if filename := msg.get_filename(): - pathlib_path = Path.joinpath(pathlib_path, str(filename)) + if in_file.filename: + pathlib_path = Path.joinpath(parse_path(get_settings().download_dir), str(in_file.filename)) else: - raise ValueError("Cannot have an empty filename") - except (ValueError, KeyError) as exception: - logger.exception("Failed to extract key information.") - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"An error occurred while extracting the downloaded file's name.", - ) + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail="An error occurred while extracting the uploaded file's name.", + ) - try: - # Write the streamed response to disk. + # Write the streamed in_file to disk. # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan # doesn't currently support streaming. 
- with open(pathlib_path, "wb") as f: - for data in res.iter_content(get_settings().block_size): - f.write(data) - except OSError as exception: - logger.exception("Failed writing the file to the disk.") - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail=f"An error occurred while trying to write the downloaded file to the disk: {exception}\n{type(exception)}", - ) - - return ResponsePath(res, pathlib_path) - - -# TODO: define return schema -@app.get("/scan/{model_id}/{file_id}") -def scan(model_id: str, file_id: str, background_tasks: BackgroundTasks): - """Scan the specific file for a given model. - - :param model_id: Unique model ID - :param file_id: Unique file ID - :param background_tasks: FastAPI object to perform background tasks once the function has already returned. - :return: The model_id, file_id, and results object from modelscan. - """ - logger.info("Called the API endpoint to scan a specific file") - try: - # Ideally we would just get this and pass the streamed response to modelscan, but currently modelscan - # only reads from files rather than in-memory objects. - file_response = download_file(model_id, file_id, get_settings().download_dir) - # No need to check the responses's status_code as download_file already does this. - - # Scan the downloaded file. - result = modelscan.scan(file_response.path) + try: + with open(pathlib_path, "wb") as out_file: + while content := in_file.file.read(get_settings().block_size): + out_file.write(content) + except OSError as exception: + logger.exception("Failed writing the file to the disk.") + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + detail=f"An error occurred while trying to write the uploaded file to the disk: {exception}", + ) from exception + + # Scan the uploaded file. + result = modelscan.scan(pathlib_path) # Finally, return the result. 
- return {"model_id": model_id, "file_id": file_id, "result": result} + return result + except HTTPException: # Re-raise HTTPExceptions. logger.exception("Re-raising HTTPException.") raise + except Exception as exception: logger.exception("An unexpected error occurred.") raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"An error occurred: {exception}", - ) + ) from exception + finally: try: # Clean up the downloaded file as a background task to allow returning sooner. logger.info("Cleaning up downloaded file.") - background_tasks.add_task(Path.unlink, file_response.path, missing_ok=True) + background_tasks.add_task(Path.unlink, pathlib_path, missing_ok=True) except Exception: - # file_response may not be defined if download_file failed. - pass + logger.exception("An error occurred while trying to cleanup the downloaded file.") if __name__ == "__main__": diff --git a/lib/modelscan_api/requirements.txt b/lib/modelscan_api/requirements.txt index 9fbfb70af..583d6efcb 100644 --- a/lib/modelscan_api/requirements.txt +++ b/lib/modelscan_api/requirements.txt @@ -1,5 +1,4 @@ -bailo==2.5.0 fastapi[standard]==0.115.4 -modelscan==0.8.1 +modelscan[tensorflow,h5py]==0.8.1 pydantic_settings==2.6.1 uvicorn==0.32.0 From b5087fde126733bd6904411c721f07a817ac7f8f Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:13:52 +0000 Subject: [PATCH 5/7] BAI-1500 improve docs --- lib/modelscan_api/README.md | 10 ++++++++-- lib/modelscan_api/bailo_modelscan_api/main.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/modelscan_api/README.md b/lib/modelscan_api/README.md index d500f3636..62c196287 100644 --- a/lib/modelscan_api/README.md +++ b/lib/modelscan_api/README.md @@ -31,10 +31,10 @@ Create and populate a `.env` file to override and set any variables, including s Run: ```bash -fastapi dev bailo_modelscan_api/main.py +fastapi run bailo_modelscan_api/main.py ``` -Connect via the 
local endpoint (development only): `http://127.0.0.1:8000` +Connect via the local endpoint: `http://127.0.0.1:8000` View the swagger docs: `http://127.0.0.1:8000/docs` @@ -48,3 +48,9 @@ If already working on Bailo you may be prompted to overwrite Husky. Follow the i pip install pre-commit pre-commit install ``` + +To run in dev mode: + +```bash +fastapi dev bailo_modelscan_api/main.py +``` diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 9610b25b5..be8d1a1d4 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -43,7 +43,7 @@ def get_settings() -> Settings: @app.post( "/scan/file", summary="Upload and scan a file", - description="Upload a file which is scanned by ModelScan and return the result", + description="Upload a file which is scanned by ModelScan and return the result of the scan", status_code=HTTPStatus.OK, response_description="The result from ModelScan", ) From 9829d63738d168193caa7f4be75989fcaa6e4642 Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Tue, 12 Nov 2024 09:29:53 +0000 Subject: [PATCH 6/7] BAI-1500 add pytest tests for main application --- lib/modelscan_api/README.md | 18 ++++++-- lib/modelscan_api/bailo_modelscan_api/main.py | 5 +- .../bailo_modelscan_api/test_main.py | 46 +++++++++++++++++++ lib/modelscan_api/requirements-dev.txt | 3 ++ 4 files changed, 66 insertions(+), 6 deletions(-) create mode 100644 lib/modelscan_api/bailo_modelscan_api/test_main.py create mode 100644 lib/modelscan_api/requirements-dev.txt diff --git a/lib/modelscan_api/README.md b/lib/modelscan_api/README.md index 62c196287..73bfd23f6 100644 --- a/lib/modelscan_api/README.md +++ b/lib/modelscan_api/README.md @@ -1,7 +1,7 @@ # ModelScan This directory provides all of the necessary functionality to interact with -[modelscan](https://github.com/protectai/modelscan/tree/main) as an API. 
+[modelscan](https://github.com/protectai/modelscan/tree/main) as a REST API. > ModelScan is an open source project from > [Protect AI](https://protectai.com/?utm_campaign=Homepage&utm_source=ModelScan%20GitHub%20Page&utm_medium=cta&utm_content=Open%20Source) @@ -40,16 +40,26 @@ View the swagger docs: `http://127.0.0.1:8000/docs` ## Development -### Install and add pre-commit +### Install dev packages If already working on Bailo you may be prompted to overwrite Husky. Follow the instructions given by Git CLI. ```bash -pip install pre-commit +pip install -r requirements-dev.txt pre-commit install ``` -To run in dev mode: +### Tests + +To run the tests: + +```bash +pytest +``` + +### Running + +To run in [dev mode](https://fastapi.tiangolo.com/fastapi-cli/#fastapi-dev): ```bash fastapi dev bailo_modelscan_api/main.py diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index be8d1a1d4..288049e15 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -57,7 +57,7 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st """ logger.info("Called the API endpoint to scan an uploaded file") try: - if in_file.filename: + if in_file.filename and str(in_file.filename).strip(): pathlib_path = Path.joinpath(parse_path(get_settings().download_dir), str(in_file.filename)) else: raise HTTPException( @@ -102,7 +102,8 @@ def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[st # Clean up the downloaded file as a background task to allow returning sooner. logger.info("Cleaning up downloaded file.") background_tasks.add_task(Path.unlink, pathlib_path, missing_ok=True) - except Exception: + except UnboundLocalError: + # pathlib_path may not exist. 
logger.exception("An error occurred while trying to cleanup the downloaded file.") diff --git a/lib/modelscan_api/bailo_modelscan_api/test_main.py b/lib/modelscan_api/bailo_modelscan_api/test_main.py new file mode 100644 index 000000000..b96bb0fd0 --- /dev/null +++ b/lib/modelscan_api/bailo_modelscan_api/test_main.py @@ -0,0 +1,46 @@ +"""Test for the main.py file. +""" + +from pathlib import Path +from unittest.mock import Mock, patch +from fastapi.testclient import TestClient + +from .dependencies import parse_path +from .main import app, get_settings + +client = TestClient(app) + + +@patch("modelscan.modelscan.ModelScan.scan") +def test_scan_file(mock_scan: Mock): + mock_scan.return_value = {} + files = {"in_file": ("foo.h5", rb"", "application/x-hdf5")} + + response = client.post("/scan/file", files=files) + + assert response.status_code == 200 + mock_scan.assert_called_once() + + +@patch("modelscan.modelscan.ModelScan.scan") +def test_scan_file_exception(mock_scan: Mock): + mock_scan.side_effect = Exception("Mocked error!") + files = {"in_file": ("foo.h5", rb"", "application/x-hdf5")} + + response = client.post("/scan/file", files=files) + + assert response.status_code == 500 + assert response.json() == {"detail": "An error occurred: Mocked error!"} + mock_scan.assert_called_once() + + # Manually clean up as FastAPI won't trigger background_tasks on Exception due to using TestClient.
+ Path.unlink(Path.joinpath(parse_path(get_settings().download_dir), "foo.h5"), missing_ok=True) + + +def test_scan_file_filename_missing(): + files = {"in_file": (" ", rb"", "application/x-hdf5")} + + response = client.post("/scan/file", files=files) + + assert response.status_code == 500 + assert response.json() == {"detail": "An error occurred while extracting the uploaded file's name."} diff --git a/lib/modelscan_api/requirements-dev.txt b/lib/modelscan_api/requirements-dev.txt new file mode 100644 index 000000000..02942b08d --- /dev/null +++ b/lib/modelscan_api/requirements-dev.txt @@ -0,0 +1,3 @@ +-r requirements.txt +pre_commit==4.0.1 +pytest==8.3.3 From aa0a2231f96f22410ed6bb72f582d958886e04ed Mon Sep 17 00:00:00 2001 From: PE39806 <185931318+PE39806@users.noreply.github.com> Date: Tue, 12 Nov 2024 09:33:51 +0000 Subject: [PATCH 7/7] BAI-1500 add missing get_settings Depends --- lib/modelscan_api/bailo_modelscan_api/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/modelscan_api/bailo_modelscan_api/main.py b/lib/modelscan_api/bailo_modelscan_api/main.py index 288049e15..e91049358 100644 --- a/lib/modelscan_api/bailo_modelscan_api/main.py +++ b/lib/modelscan_api/bailo_modelscan_api/main.py @@ -12,7 +12,7 @@ import uvicorn from bailo_modelscan_api.config import Settings from bailo_modelscan_api.dependencies import parse_path -from fastapi import BackgroundTasks, FastAPI, HTTPException, UploadFile +from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, UploadFile from modelscan.modelscan import ModelScan @@ -34,6 +34,7 @@ def get_settings() -> Settings: summary=get_settings().app_summary, description=get_settings().app_description, version=get_settings().app_version, + dependencies=[Depends(get_settings)], ) # Instantiating ModelScan