-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1612 from gchq/feature/BAI-1500-create-a-model-sc…
…an-rest-application BAI-1500 ModelScan REST API
- Loading branch information
Showing
10 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# Environments | ||
.env | ||
*env/ | ||
*ENV/ | ||
*env.bak/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
ci: | ||
autoupdate_commit_msg: 'chore: update pre-commit hooks' | ||
autofix_commit_msg: 'style: pre-commit fixes' | ||
|
||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.5.0 | ||
hooks: | ||
- id: check-added-large-files | ||
- id: check-case-conflict | ||
- id: check-merge-conflict | ||
- id: check-symlinks | ||
- id: check-yaml | ||
- id: debug-statements | ||
- id: end-of-file-fixer | ||
- id: mixed-line-ending | ||
- id: requirements-txt-fixer | ||
- id: trailing-whitespace | ||
|
||
- repo: https://github.com/PyCQA/isort | ||
rev: 5.12.0 | ||
hooks: | ||
- id: isort | ||
args: ['-a', 'from __future__ import annotations'] | ||
|
||
- repo: https://github.com/asottile/pyupgrade | ||
rev: v3.15.0 | ||
hooks: | ||
- id: pyupgrade | ||
args: [--py37-plus] | ||
|
||
- repo: https://github.com/hadialqattan/pycln | ||
rev: v2.4.0 | ||
hooks: | ||
- id: pycln | ||
args: [--config=pyproject.toml] | ||
stages: [manual] | ||
|
||
- repo: https://github.com/codespell-project/codespell | ||
rev: v2.2.6 | ||
hooks: | ||
- id: codespell | ||
|
||
- repo: https://github.com/pre-commit/pygrep-hooks | ||
rev: v1.10.0 | ||
hooks: | ||
- id: python-check-blanket-noqa | ||
- id: python-check-blanket-type-ignore | ||
- id: python-no-log-warn | ||
- id: python-no-eval | ||
- id: python-use-type-annotations | ||
- id: rst-backticks | ||
- id: rst-directive-colons | ||
- id: rst-inline-touching-normal | ||
|
||
- repo: https://github.com/psf/black-pre-commit-mirror | ||
rev: 23.11.0 | ||
hooks: | ||
- id: black | ||
args: [--line-length=120] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# ModelScan | ||
|
||
This directory provides all of the necessary functionality to interact with | ||
[modelscan](https://github.com/protectai/modelscan/tree/main) as a REST API. | ||
|
||
> ModelScan is an open source project from | ||
> [Protect AI](https://protectai.com/?utm_campaign=Homepage&utm_source=ModelScan%20GitHub%20Page&utm_medium=cta&utm_content=Open%20Source) | ||
> that scans models to determine if they contain unsafe code. It is the first model scanning tool to support multiple | ||
> model formats. ModelScan currently supports: H5, Pickle, and SavedModel formats. This protects you when using PyTorch, | ||
> TensorFlow, Keras, Sklearn, XGBoost, with more on the way. | ||
## Setup | ||
|
||
Create and activate a virtual environment | ||
|
||
```bash | ||
python3 -m venv modelscan-venv | ||
source modelscan-venv/bin/activate | ||
``` | ||
|
||
Install the required pip packages | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Usage | ||
|
||
Create and populate a `.env` file to set or override any configuration variables, including sensitive properties. | ||
|
||
Run: | ||
|
||
```bash | ||
fastapi run bailo_modelscan_api/main.py | ||
``` | ||
|
||
Connect via the local endpoint: `http://127.0.0.1:8000` | ||
|
||
View the swagger docs: `http://127.0.0.1:8000/docs` | ||
|
||
## Development | ||
|
||
### Install dev packages | ||
|
||
If you are already working on Bailo, you may be prompted to overwrite Husky. Follow the instructions given by the Git CLI. | ||
|
||
```bash | ||
pip install -r requirements-dev.txt | ||
pre-commit install | ||
``` | ||
|
||
### Tests | ||
|
||
To run the tests: | ||
|
||
```bash | ||
pytest | ||
``` | ||
|
||
### Running | ||
|
||
To run in [dev mode](https://fastapi.tiangolo.com/fastapi-cli/#fastapi-dev): | ||
|
||
```bash | ||
fastapi dev bailo_modelscan_api/main.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
|
||
# Attach a no-op handler to this package's logger; the importing application stays in charge of
# configuring real handlers and log levels.
logging.getLogger(__name__).addHandler(logging.NullHandler())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Configuration settings for FastAPI app. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import Any | ||
|
||
from modelscan.settings import DEFAULT_SETTINGS | ||
from pydantic_settings import BaseSettings, SettingsConfigDict | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Settings(BaseSettings):
    """Configuration values for the Bailo ModelScan FastAPI app.

    Every field has a default declared here. ``model_config`` names a ``.env`` file from which values can be
    loaded to set or override those defaults.
    """

    # Descriptive metadata for the application: name, summary, long description and version.
    app_name: str = "Bailo ModelScan API"
    app_summary: str = "REST API wrapper for ModelScan package for use with Bailo."
    app_description: str = """
Bailo ModelScan API allows for easy programmatic interfacing with ProtectAI's ModelScan package to scan and detect potential threats within files stored in Bailo.
You can upload files and view modelscan's result."""
    app_version: str = "1.0.0"
    # Directory used for files that are written to disk; "." is the default.
    download_dir: str = "."
    # Settings dictionary for ModelScan; defaults to modelscan's own DEFAULT_SETTINGS.
    modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS
    # Chunk size used when reading file content.
    # NOTE(review): presumably bytes, as the upload is read in binary mode - confirm.
    block_size: int = 1024

    # Load in a dotenv file to set/overwrite any properties with potentially sensitive values
    model_config = SettingsConfigDict(env_file=".env")
|
||
|
||
# Create a module-level Settings instance as soon as this module is imported.
logger.info("Instantiating settings.")
settings = Settings()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Common utilities used by the FastAPI app. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from pathlib import Path | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def parse_path(path: str | Path | None) -> Path:
    """Normalise a path-like value into an absolute ``Path``.

    :param path: System path to convert. ``None`` and ``"."`` both resolve to the current working directory.
    :return: The absolute ``Path`` equivalent of ``path``.
    """
    logger.info("Parsing path.")

    # Guard clause: an unspecified path and the literal "." both mean "where the process is running".
    if path is None or path == ".":
        return Path.cwd()

    return Path(path).absolute()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
"""FastAPI app. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from functools import lru_cache | ||
from http import HTTPStatus | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import uvicorn | ||
from bailo_modelscan_api.config import Settings | ||
from bailo_modelscan_api.dependencies import parse_path | ||
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, UploadFile | ||
from modelscan.modelscan import ModelScan | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@lru_cache
def get_settings() -> Settings:
    """Build the application settings once and hand back the same object on every later call.

    ``lru_cache`` memoises the result, so ``Settings()`` is only evaluated on the first call.

    :return: The cached ``Settings`` instance.
    """
    loaded_settings = Settings()
    return loaded_settings
|
||
|
||
# get_settings() is cached, so resolve it once and reuse the same object for every argument below.
_settings = get_settings()

# Instantiate FastAPI app with various dependencies.
app = FastAPI(
    title=_settings.app_name,
    summary=_settings.app_summary,
    description=_settings.app_description,
    version=_settings.app_version,
    dependencies=[Depends(get_settings)],
)

# Instantiating ModelScan
modelscan = ModelScan(settings=_settings.modelscan_settings)
|
||
|
||
@app.post(
    "/scan/file",
    summary="Upload and scan a file",
    description="Upload a file which is scanned by ModelScan and return the result of the scan",
    status_code=HTTPStatus.OK,
    response_description="The result from ModelScan",
)
def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[str, Any]:
    """API endpoint to upload and scan a file using modelscan.

    The upload is written into the configured download directory, scanned with modelscan and then removed
    again. Only the final component of the client-supplied filename is used, so the file can never be
    written outside of the download directory.

    :param in_file: uploaded file to be scanned
    :param background_tasks: FastAPI object to perform background tasks once the function has already returned.
    :raises HTTPException: failure to process the uploaded file in any way
    :return: `modelscan.scan` results
    """
    logger.info("Called the API endpoint to scan an uploaded file")

    # Stays None until a destination has been chosen, so the cleanup in `finally` knows whether there is
    # anything to remove without having to rely on catching UnboundLocalError.
    pathlib_path: Path | None = None
    scan_succeeded = False

    try:
        app_settings = get_settings()

        # The filename is controlled by the client. Reduce it to its last path component so that values
        # such as "../../x" or "/etc/x" cannot escape the download directory.
        safe_name = Path(str(in_file.filename or "")).name
        if not safe_name.strip() or safe_name in {".", ".."}:
            # Status code and message are unchanged so that existing clients keep working.
            raise HTTPException(
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                detail="An error occurred while extracting the uploaded file's name.",
            )
        # NOTE(review): two concurrent uploads using the same filename would share this path - confirm
        # whether that can happen in the deployment.
        pathlib_path = parse_path(app_settings.download_dir) / safe_name

        # Write the streamed in_file to disk.
        # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan
        # doesn't currently support streaming.
        try:
            with open(pathlib_path, "wb") as out_file:
                while content := in_file.file.read(app_settings.block_size):
                    out_file.write(content)
        except OSError as exception:
            logger.exception("Failed writing the file to the disk.")
            raise HTTPException(
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                detail=f"An error occurred while trying to write the uploaded file to the disk: {exception}",
            ) from exception

        # Scan the uploaded file.
        result = modelscan.scan(pathlib_path)

        # Finally, return the result.
        scan_succeeded = True
        return result

    except HTTPException:
        # Re-raise HTTPExceptions.
        logger.exception("Re-raising HTTPException.")
        raise

    except Exception as exception:
        logger.exception("An unexpected error occurred.")
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail=f"An error occurred: {exception}",
        ) from exception

    finally:
        if pathlib_path is not None:
            logger.info("Cleaning up downloaded file.")
            if scan_succeeded:
                # Clean up the downloaded file as a background task to allow returning sooner.
                background_tasks.add_task(pathlib_path.unlink, missing_ok=True)
            else:
                # Background tasks are tied to a normally returned response, so they are not relied upon
                # when an exception is propagating; remove the file immediately instead.
                try:
                    pathlib_path.unlink(missing_ok=True)
                except OSError:
                    logger.exception("An error occurred while trying to cleanup the downloaded file.")
|
||
|
||
# When this module is executed as a script, serve the app with uvicorn (no host or port overrides are passed).
if __name__ == "__main__":
    logger.info("Starting the application programmatically.")
    uvicorn.run(app)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
"""Test for the main.py file. | ||
""" | ||
|
||
from pathlib import Path | ||
from unittest.mock import Mock, patch | ||
from fastapi.testclient import TestClient | ||
|
||
from .dependencies import parse_path | ||
from .main import app, get_settings | ||
|
||
client = TestClient(app) | ||
|
||
|
||
@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file(mock_scan: Mock):
    """A valid upload is answered with HTTP 200 and scanned exactly once."""
    mock_scan.return_value = {}
    upload = {"in_file": ("foo.h5", b"", "application/x-hdf5")}

    result = client.post("/scan/file", files=upload)

    assert result.status_code == 200
    mock_scan.assert_called_once()
|
||
|
||
@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file_exception(mock_scan: Mock):
    """A failure inside ``ModelScan.scan`` is reported as HTTP 500 with the error text in ``detail``."""
    mock_scan.side_effect = Exception("Mocked error!")
    files = {"in_file": ("foo.h5", rb"", "application/x-hdf5")}
    # Location the endpoint writes the upload to; resolved up front so it can always be removed below.
    uploaded_path = Path.joinpath(parse_path(get_settings().download_dir), "foo.h5")

    try:
        response = client.post("/scan/file", files=files)

        assert response.status_code == 500
        assert response.json() == {"detail": "An error occurred: Mocked error!"}
        mock_scan.assert_called_once()
    finally:
        # Manually cleanup as the background task that removes the upload is not triggered when the
        # request ends in an exception. Doing this in `finally` also covers a failing assertion above.
        uploaded_path.unlink(missing_ok=True)
|
||
|
||
def test_scan_file_filename_missing():
    """An upload whose filename is only whitespace is rejected with HTTP 500 and a descriptive message."""
    upload = {"in_file": (" ", b"", "application/x-hdf5")}
    expected_body = {"detail": "An error occurred while extracting the uploaded file's name."}

    result = client.post("/scan/file", files=upload)

    assert result.status_code == 500
    assert result.json() == expected_body
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
-r requirements.txt | ||
pre_commit==4.0.1 | ||
pytest==8.3.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
fastapi[standard]==0.115.4 | ||
modelscan[tensorflow,h5py]==0.8.1 | ||
pydantic_settings==2.6.1 | ||
uvicorn==0.32.0 |