-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from groundlight/dev
Added upload and evaluation scripts with simple instructions
- Loading branch information
Showing
8 changed files
with
1,745 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
cover/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
.pybuilder/ | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
# For a library or package, you might want to ignore these files since the code is | ||
# intended to run in multiple environments; otherwise, check them in: | ||
# .python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# poetry | ||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. | ||
# This is especially recommended for binary packages to ensure reproducibility, and is more | ||
# commonly ignored for libraries. | ||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control | ||
#poetry.lock | ||
|
||
# pdm | ||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. | ||
#pdm.lock | ||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it | ||
# in version control. | ||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control | ||
.pdm.toml | ||
.pdm-python | ||
.pdm-build/ | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# pytype static type analyzer | ||
.pytype/ | ||
|
||
# Cython debug symbols | ||
cython_debug/ | ||
|
||
# misc | ||
.DS_Store | ||
*.pem | ||
|
||
# PyCharm | ||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can | ||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore | ||
# and can be added to the global gitignore or merged into this file. For a more nuclear | ||
# option (not recommended) you can uncomment the following to ignore the entire idea folder. | ||
#.idea/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,91 @@ | ||
# model-evaluation-tool | ||
A simple tool for evaluating the performance of your Groundlight ML model | ||
# Model Evaluation Tool | ||
A simple tool for evaluating the performance of your Groundlight Binary ML model. | ||
|
||
This script provides a simple way for users to do an independent evaluation of the ML's performance. Note that this is not the recommended way of using our service, as this only evaluates ML performance and not the combined performance of our ML + escalation system. However, the balanced accuracy results from `evaluate.py` should fall within the bounds of Projected ML Accuracy shown on our website, if the train and evaluation dataset that the user provided are well randomized. | ||
|
||
## Installation | ||
|
||
The dependencies for this script can be installed by either using poetry (recommended) or `requirements.txt`. | ||
|
||
Using poetry | ||
|
||
```bash | ||
poetry install | ||
``` | ||
|
||
Using `requirements.txt` | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Usage | ||
|
||
### Setting Up Your Account | ||
|
||
To train a ML model, make sure to create a binary detector on the [Online Dashboard](https://dashboard.groundlight.ai/). | ||
|
||
You will also need to create an API Token to start uploading images to the account. You can go [here](https://dashboard.groundlight.ai/reef/my-account/api-tokens) to create one. | ||
|
||
After you have created your API token, add the token to your terminal as an variable: | ||
|
||
```bash | ||
export GROUNDLIIGHT_API_TOKEN="YOUR_API_TOKEN" | ||
``` | ||
|
||
### Formatting Dataset | ||
|
||
This script assumes your custom image dataset is structured in the following format: | ||
|
||
```bash | ||
└── dataset | ||
├── dataset.csv | ||
└── images | ||
├── 1.jpg | ||
├── 10.jpg | ||
├── 11.jpg | ||
├── 12.jpg | ||
├── 13.jpg | ||
├── 14.jpg | ||
``` | ||
|
||
The `dataset.csv` file should have two columns: image_name and label (YES/NO), for example: | ||
|
||
```bash | ||
1.jpg,YES | ||
11.jpg,NO | ||
12.jpg,YES | ||
13.jpg,YES | ||
14.jpg,NO | ||
``` | ||
|
||
The corresponding image file should be placed inside the `images` folder. | ||
|
||
### Training the Detector | ||
|
||
To train the ML model for a detector, simply run the script `train.py` with the following arguments: | ||
|
||
```bash | ||
poetry run python train.py --detector-name NAME_OF_THE_DETECTOR --detector-query QUERY_OF_THE_DETECTOR --dataset PATH_TO_DATASET_TRAIN_FOLDER | ||
``` | ||
|
||
Optionally, set the `--delay` argument to prevent going over the throttling limit of your account. | ||
|
||
### Evaluate the Detector | ||
|
||
To evaluate the ML model performance for a detector, simply run the script `evaluate.py` with the following arguments: | ||
|
||
```bash | ||
poetry run python evaluate.py --detector-id YOUR_DETECTOR_ID --dataset PATH_TO_DATASET_TEST_FOLDER | ||
``` | ||
|
||
Optionally, set the `--delay` argument to prevent going over the throttling limit of your account. | ||
|
||
The evaluation script will output the following information: | ||
|
||
``` | ||
Number of Correct ML Predictions | ||
Average Confidence | ||
Balanced Accuracy | ||
Precision | ||
Recall | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
A script to evaluate the accuracy of a detector on a given dataset. | ||
It will upload the images to the detector and compare the predicted labels with the ground truth labels. | ||
You can specify the delay between uploads. | ||
""" | ||
|
||
import argparse | ||
import os | ||
import PIL | ||
import time | ||
import PIL.Image | ||
import pandas as pd | ||
import logging | ||
|
||
from groundlight import Groundlight, Detector, BinaryClassificationResult | ||
from tqdm.auto import tqdm | ||
|
||
logger = logging.getLogger(__name__) | ||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
def upload_image(gl: Groundlight, detector: Detector, image: PIL) -> BinaryClassificationResult: | ||
""" | ||
Upload a image with a label to a detector. | ||
Args: | ||
gl: The Groundlight object. | ||
detector: The detector to upload to. | ||
image: The image to upload. | ||
Returns: | ||
The predicted label (YES/NO). | ||
""" | ||
|
||
# Convert image to jpg if not already | ||
if image.format != "JPEG": | ||
image = image.convert("RGB") | ||
|
||
# Use ask_ml to upload the image and then return the result | ||
iq = gl.ask_ml(detector=detector, image=image) | ||
return iq.result | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(description="Evaluate the accuracy of a detector on a given dataset.") | ||
parser.add_argument("--detector-id", type=str, required=True, help="The ID of the detector to evaluate.") | ||
parser.add_argument("--dataset", type=str, required=True, help="The folder containing the dataset.csv and images folder") | ||
parser.add_argument("--delay", type=float, required=False, default=0.1, help="The delay between uploads.") | ||
args = parser.parse_args() | ||
|
||
gl = Groundlight() | ||
detector = gl.get_detector(args.detector_id) | ||
|
||
# Load the dataset from the CSV file and images from the images folder | ||
# The CSV file should have two columns: image_name and label (YES/NO) | ||
|
||
dataset = pd.read_csv(os.path.join(args.dataset, "dataset.csv")) | ||
images = os.listdir(os.path.join(args.dataset, "images")) | ||
|
||
logger.info(f"Evaluating {len(dataset)} images on detector {detector.name} with delay {args.delay}.") | ||
|
||
# Record the number of correct predictions | ||
# Also record the number of TP, TN, FP, FN for calculating balanced accuracy, precision, and recall | ||
true_positives = 0 | ||
true_negatives = 0 | ||
false_positives = 0 | ||
false_negatives = 0 | ||
total_processed = 0 | ||
average_confidence = 0 | ||
|
||
for image_name, label in tqdm(dataset.values): | ||
if image_name not in images: | ||
logger.warning(f"Image {image_name} not found in images folder.") | ||
continue | ||
|
||
if label not in ["YES", "NO"]: | ||
logger.warning(f"Invalid label {label} for image {image_name}. Skipping.") | ||
continue | ||
|
||
image = PIL.Image.open(os.path.join(args.dataset, "images", image_name)) | ||
result = upload_image(gl=gl, detector=detector, image=image) | ||
|
||
if result.label == "YES" and label == "YES": | ||
true_positives += 1 | ||
elif result.label == "NO" and label == "NO": | ||
true_negatives += 1 | ||
elif result.label == "YES" and label == "NO": | ||
false_positives += 1 | ||
elif result.label == "NO" and label == "YES": | ||
false_negatives += 1 | ||
|
||
average_confidence += result.confidence | ||
total_processed += 1 | ||
|
||
time.sleep(args.delay) | ||
|
||
# Calculate the accuracy, precision, and recall | ||
balanced_accuracy = (true_positives / (true_positives + false_negatives) + true_negatives / (true_negatives + false_positives)) / 2 | ||
precision = true_positives / (true_positives + false_positives) | ||
recall = true_positives / (true_positives + false_negatives) | ||
|
||
logger.info(f"Processed {total_processed} images.") | ||
logger.info(f"Average Confidence: {average_confidence / total_processed:.2f}") | ||
logger.info(f"Balanced Accuracy: {balanced_accuracy:.2f}") | ||
logger.info(f"Precision: {precision:.2f}") | ||
logger.info(f"Recall: {recall:.2f}") |
Oops, something went wrong.