Skip to content

Commit

Permalink
Merge pull request #1612 from gchq/feature/BAI-1500-create-a-model-sc…
Browse files Browse the repository at this point in the history
…an-rest-application

BAI-1500 ModelScan REST API
  • Loading branch information
PE39806 authored Nov 13, 2024
2 parents 795c6c6 + aa0a223 commit 9df2d26
Show file tree
Hide file tree
Showing 10 changed files with 365 additions and 0 deletions.
10 changes: 10 additions & 0 deletions lib/modelscan_api/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Environments
.env
*env/
*ENV/
*env.bak/
60 changes: 60 additions & 0 deletions lib/modelscan_api/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
ci:
autoupdate_commit_msg: 'chore: update pre-commit hooks'
autofix_commit_msg: 'style: pre-commit fixes'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-merge-conflict
- id: check-symlinks
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: mixed-line-ending
- id: requirements-txt-fixer
- id: trailing-whitespace

- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
args: ['-a', 'from __future__ import annotations']

- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py37-plus]

- repo: https://github.com/hadialqattan/pycln
rev: v2.4.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
stages: [manual]

- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
hooks:
- id: codespell

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: python-check-blanket-noqa
- id: python-check-blanket-type-ignore
- id: python-no-log-warn
- id: python-no-eval
- id: python-use-type-annotations
- id: rst-backticks
- id: rst-directive-colons
- id: rst-inline-touching-normal

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.11.0
hooks:
- id: black
args: [--line-length=120]
66 changes: 66 additions & 0 deletions lib/modelscan_api/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# ModelScan

This directory provides all of the necessary functionality to interact with
[modelscan](https://github.com/protectai/modelscan/tree/main) as a REST API.

> ModelScan is an open source project from
> [Protect AI](https://protectai.com/?utm_campaign=Homepage&utm_source=ModelScan%20GitHub%20Page&utm_medium=cta&utm_content=Open%20Source)
> that scans models to determine if they contain unsafe code. It is the first model scanning tool to support multiple
> model formats. ModelScan currently supports: H5, Pickle, and SavedModel formats. This protects you when using PyTorch,
> TensorFlow, Keras, Sklearn, XGBoost, with more on the way.

## Setup

Create and activate a virtual environment

```bash
python3 -m venv modelscan-venv
source modelscan-venv/bin/activate
```

Install the required pip packages

```bash
pip install -r requirements.txt
```

## Usage

Optionally create a `.env` file to set or override any of the settings declared in `bailo_modelscan_api/config.py` (for example `download_dir` or `block_size`), including sensitive values.

Run:

```bash
fastapi run bailo_modelscan_api/main.py
```

Connect via the local endpoint: `http://127.0.0.1:8000`

View the swagger docs: `http://127.0.0.1:8000/docs`

## Development

### Install dev packages

If you are already working on Bailo, you may be prompted to overwrite Husky. Follow the instructions given by the Git CLI.

```bash
pip install -r requirements-dev.txt
pre-commit install
```

### Tests

To run the tests:

```bash
pytest
```

### Running

To run in [dev mode](https://fastapi.tiangolo.com/fastapi-cli/#fastapi-dev):

```bash
fastapi dev bailo_modelscan_api/main.py
```
5 changes: 5 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

import logging

# Attach a no-op handler to the package logger; this package does not configure any other handler itself.
logging.getLogger(__name__).addHandler(logging.NullHandler())
37 changes: 37 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Configuration settings for FastAPI app.
"""

from __future__ import annotations

import logging
from typing import Any

from modelscan.settings import DEFAULT_SETTINGS
from pydantic_settings import BaseSettings, SettingsConfigDict

logger = logging.getLogger(__name__)


class Settings(BaseSettings):
    """Basic settings object for the FastAPI app.

    Every field below is a default value that can be set or overwritten through the dotenv file named in
    ``model_config``.
    """

    # Metadata describing the API: its name, summary, long description and version.
    app_name: str = "Bailo ModelScan API"
    app_summary: str = "REST API wrapper for ModelScan package for use with Bailo."
    app_description: str = """
Bailo ModelScan API allows for easy programmatic interfacing with ProtectAI's ModelScan package to scan and detect potential threats within files stored in Bailo.
You can upload files and view modelscan's result."""
    app_version: str = "1.0.0"
    # Directory path setting; "." is the default.
    download_dir: str = "."
    # Settings dictionary for modelscan; defaults to modelscan's own DEFAULT_SETTINGS.
    modelscan_settings: dict[str, Any] = DEFAULT_SETTINGS
    # Block size setting (an integer, 1024 by default).
    block_size: int = 1024

    # Load in a dotenv file to set/overwrite any properties with potentially sensitive values
    model_config = SettingsConfigDict(env_file=".env")


# Build a module-level Settings instance at import time so it can be imported directly from this module.
logger.info("Instantiating settings.")
settings = Settings()
21 changes: 21 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Common utilities used by the FastAPI app.
"""

from __future__ import annotations

import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def parse_path(path: str | Path | None) -> Path:
    """Normalise a path-like value into an absolute ``Path``.

    :param path: System path to parse. ``None`` and ``"."`` both mean the current working directory.
    :return: An absolute Path representation of the path parameter.
    """
    logger.info("Parsing path.")
    # Guard clause: an unspecified path and "." resolve to the current working directory.
    if path is None or path == ".":
        return Path.cwd()
    return Path(path).absolute()
113 changes: 113 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""FastAPI app.
"""

from __future__ import annotations

import logging
from functools import lru_cache
from http import HTTPStatus
from pathlib import Path
from typing import Any

import uvicorn
from bailo_modelscan_api.config import Settings
from bailo_modelscan_api.dependencies import parse_path
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, UploadFile
from modelscan.modelscan import ModelScan


logger = logging.getLogger(__name__)


@lru_cache
def get_settings() -> Settings:
    """Build the application settings once and reuse them on every later call.

    The ``lru_cache`` decorator memoises the result, so ``Settings`` is only instantiated on the first call.

    :return: Evaluated Settings from config file.
    """
    loaded_settings = Settings()
    return loaded_settings


# Instantiate FastAPI app with various dependencies.
# The API metadata is taken from the cached settings, and get_settings is also passed as an app-level dependency.
app = FastAPI(
    title=get_settings().app_name,
    summary=get_settings().app_summary,
    description=get_settings().app_description,
    version=get_settings().app_version,
    dependencies=[Depends(get_settings)],
)

# Instantiating ModelScan
# A single module-level instance, configured from the settings, is reused by the endpoint below.
modelscan = ModelScan(settings=get_settings().modelscan_settings)


@app.post(
    "/scan/file",
    summary="Upload and scan a file",
    description="Upload a file which is scanned by ModelScan and return the result of the scan",
    status_code=HTTPStatus.OK,
    response_description="The result from ModelScan",
)
def scan_file(in_file: UploadFile, background_tasks: BackgroundTasks) -> dict[str, Any]:
    """API endpoint to upload and scan a file using modelscan.

    The upload is written into the configured download directory, scanned, and then removed again. Only the
    final component of the client-supplied filename is used, so the file can never be written (or deleted)
    outside of the download directory.

    :param in_file: uploaded file to be scanned
    :param background_tasks: FastAPI object to perform background tasks once the function has already returned.
    :raises HTTPException: failure to process the uploaded file in any way
    :return: `modelscan.scan` results
    """
    logger.info("Called the API endpoint to scan an uploaded file")

    # Stays None until a destination has been chosen, so the cleanup in `finally` knows whether a file may exist.
    pathlib_path: Path | None = None
    scan_succeeded = False

    try:
        # The filename comes from the client, so keep only its last path component. Joining it unmodified would
        # allow names such as "../../x" or "/etc/x" to escape the download directory.
        safe_filename = Path(str(in_file.filename)).name if in_file.filename else ""
        # ".." survives `.name` unchanged and must be rejected explicitly.
        if not safe_filename.strip() or safe_filename == "..":
            raise HTTPException(
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                detail="An error occurred while extracting the uploaded file's name.",
            )
        pathlib_path = parse_path(get_settings().download_dir) / safe_filename

        # Write the streamed in_file to disk.
        # This is a bit silly as modelscan will ultimately load this back into memory, but modelscan
        # doesn't currently support streaming.
        try:
            with open(pathlib_path, "wb") as out_file:
                while content := in_file.file.read(get_settings().block_size):
                    out_file.write(content)
        except OSError as exception:
            logger.exception("Failed writing the file to the disk.")
            raise HTTPException(
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                detail=f"An error occurred while trying to write the uploaded file to the disk: {exception}",
            ) from exception

        # Scan the uploaded file.
        result = modelscan.scan(pathlib_path)
        scan_succeeded = True

        # Finally, return the result.
        return result

    except HTTPException:
        # Re-raise HTTPExceptions.
        logger.exception("Re-raising HTTPException.")
        raise

    except Exception as exception:
        logger.exception("An unexpected error occurred.")
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail=f"An error occurred: {exception}",
        ) from exception

    finally:
        if pathlib_path is not None:
            logger.info("Cleaning up downloaded file.")
            if scan_succeeded:
                # Clean up the downloaded file as a background task to allow returning sooner.
                background_tasks.add_task(Path.unlink, pathlib_path, missing_ok=True)
            else:
                # Tasks added to background_tasks are not run when the request ends in an exception, so the
                # file has to be removed immediately or it would be left behind on disk.
                try:
                    pathlib_path.unlink(missing_ok=True)
                except OSError:
                    logger.exception("An error occurred while trying to cleanup the downloaded file.")


# Allow the module to be executed directly as a script, serving the app through uvicorn.
if __name__ == "__main__":
    logger.info("Starting the application programmatically.")
    uvicorn.run(app)
46 changes: 46 additions & 0 deletions lib/modelscan_api/bailo_modelscan_api/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Test for the main.py file.
"""

from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient

from .dependencies import parse_path
from .main import app, get_settings

# Shared test client bound to the FastAPI app; every test in this module sends its requests through it.
client = TestClient(app)


@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file(mock_scan: Mock):
    """A valid upload is answered with HTTP 200 and scanned exactly once."""
    mock_scan.return_value = {}
    # Empty payload: the scan itself is mocked, so only the upload plumbing is exercised.
    upload = ("foo.h5", rb"", "application/x-hdf5")

    response = client.post("/scan/file", files={"in_file": upload})

    assert response.status_code == 200
    mock_scan.assert_called_once()


@patch("modelscan.modelscan.ModelScan.scan")
def test_scan_file_exception(mock_scan: Mock):
    """An exception raised by the scan is reported as HTTP 500 with the error text in the detail."""
    mock_scan.side_effect = Exception("Mocked error!")
    files = {"in_file": ("foo.h5", rb"", "application/x-hdf5")}

    try:
        response = client.post("/scan/file", files=files)

        assert response.status_code == 500
        assert response.json() == {"detail": "An error occurred: Mocked error!"}
        mock_scan.assert_called_once()
    finally:
        # Manually cleanup as FastAPI won't trigger background_tasks on Exception due to using TestClient.
        # Done in `finally` so the uploaded file is removed even when one of the assertions above fails.
        (parse_path(get_settings().download_dir) / "foo.h5").unlink(missing_ok=True)


def test_scan_file_filename_missing():
    """An upload whose filename is only whitespace is rejected with the filename error."""
    blank_name_upload = (" ", rb"", "application/x-hdf5")
    expected_body = {"detail": "An error occurred while extracting the uploaded file's name."}

    response = client.post("/scan/file", files={"in_file": blank_name_upload})

    assert response.status_code == 500
    assert response.json() == expected_body
3 changes: 3 additions & 0 deletions lib/modelscan_api/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-r requirements.txt
pre_commit==4.0.1
pytest==8.3.3
4 changes: 4 additions & 0 deletions lib/modelscan_api/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
fastapi[standard]==0.115.4
modelscan[tensorflow,h5py]==0.8.1
pydantic_settings==2.6.1
uvicorn==0.32.0

0 comments on commit 9df2d26

Please sign in to comment.