Skip to content

Commit

Permalink
Add binning to custom label prediction
Browse files Browse the repository at this point in the history
Allow users to group custom label predictions into bins.

Fixes #29
  • Loading branch information
johnbradley committed Oct 9, 2024
1 parent 7f2041b commit a658c25
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 33 deletions.
55 changes: 51 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,29 @@ fish 2.932403668845507e-12
bear 1.0
```

### Predict from a list of classes with binning
```python
from bioclip import CustomLabelsBinningClassifier
classifier = CustomLabelsBinningClassifier(cls_to_bin={
'dog': 'small',
'fish': 'small',
'bear': 'big',
})
predictions = classifier.predict("Ursus-arctos.jpeg")
for prediction in predictions:
print(prediction["classification"], prediction["score"])
```
Output:
```console
big 0.99992835521698
small 7.165559509303421e-05
```

## Command Line Usage
```
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...]
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT]
[--rank {kingdom,phylum,class,order,family,genus,species} | --cls CLS | --bins BINS]
[--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...]
bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...]
Commands:
Expand All @@ -117,9 +137,13 @@ Arguments:
Options:
-h --help
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
--k=K number of top predictions to show [default: 5]
--cls=CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed.
--rank {kingdom,phylum,class,order,family,genus,species}
rank of the classification, default: species (when)
--cls CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the
--rank and --bins arguments are not allowed.
--bins BINS path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls and
--bins arguments are not allowed.
--k K number of top predictions to show, default: 5
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
--output=OUTFILE print output to file OUTFILE [default: stdout]
```
Expand Down Expand Up @@ -195,6 +219,29 @@ Ursus-arctos.jpeg,bird,3.051998476166773e-08
Ursus-arctos.jpeg,bear,0.9999998807907104
```

### Predict from a binning CSV
Create predictions for 3 classes (cat, bird, and bear) with 2 bins (one, two) for image `Ursus-arctos.jpeg`:

Create a CSV file named `bins.csv` with the following contents:
```
cls,bin
cat,one
bird,one
bear,two
```
The names of the columns do not matter. The first column values will be used as the classes. The second column values will be used for bin names.

Run predict command:
```console
bioclip predict --bins bins.csv Ursus-arctos.jpeg
```

Output:
```
Ursus-arctos.jpeg,two,0.9999998807907104
Ursus-arctos.jpeg,one,7.633736487377973e-08
```

### Create embeddings

#### Create embedding for an image
Expand Down
4 changes: 2 additions & 2 deletions src/bioclip/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: 2024-present John Bradley <[email protected]>
#
# SPDX-License-Identifier: MIT
from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier

__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier"]
__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier", "CustomLabelsBinningClassifier"]
34 changes: 24 additions & 10 deletions src/bioclip/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier
from .predict import BIOCLIP_MODEL_STR
import open_clip as oc
import os
Expand Down Expand Up @@ -32,17 +32,32 @@ def write_results_to_file(df, format, outfile):
raise ValueError(f"Invalid format: {format}")


def parse_bins_csv(bins_path):
if not os.path.exists(bins_path):
raise FileNotFoundError(f"File not found: {bins_path}")
bin_df = pd.read_csv(bins_path, index_col=0)
if len(bin_df.columns) == 0:
raise ValueError("CSV file must have at least two columns.")
return bin_df[bin_df.columns[0]].to_dict()


def predict(image_file: list[str],
format: str,
output: str,
cls_str: str,
rank: Rank,
bins_path: str,
k: int,
**kwargs):
if cls_str:
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs)
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
elif bins_path:
cls_to_bin = parse_bins_csv(bins_path)
classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs)
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
else:
classifier = TreeOfLifeClassifier(**kwargs)
predictions = classifier.predict(image_paths=image_file, rank=rank, k=k)
Expand Down Expand Up @@ -81,11 +96,13 @@ def create_parser():
predict_parser.add_argument('image_file', nargs='+', help='input image file(s)')
predict_parser.add_argument('--format', choices=['table', 'csv'], default='csv', help='format of the output, default: csv')
predict_parser.add_argument('--output', **output_arg)
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
cls_group = predict_parser.add_mutually_exclusive_group(required=False)
cls_group.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
help='rank of the classification, default: species (when)')
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank and --bins arguments are not allowed."
cls_group.add_argument('--cls', help=cls_help)
cls_group.add_argument('--bins', help='path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls and --bins arguments are not allowed.')
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed."
predict_parser.add_argument('--cls', help=cls_help)

predict_parser.add_argument('--device', **device_arg)
predict_parser.add_argument('--model', **model_arg)
Expand Down Expand Up @@ -115,11 +132,7 @@ def create_parser():
def parse_args(input_args=None):
args = create_parser().parse_args(input_args)
if args.command == 'predict':
if args.cls:
# custom class list mode
if args.rank:
raise ValueError("Cannot use --cls with --rank")
else:
if not args.cls and not args.bins:
# tree of life class list mode
if args.model or args.pretrained:
raise ValueError("Custom model or checkpoints currently not supported for Tree-of-Life prediction")
Expand Down Expand Up @@ -155,6 +168,7 @@ def main():
output=args.output,
cls_str=cls_str,
rank=args.rank,
bins_path=args.bins,
k=args.k,
device=args.device,
model_str=args.model,
Expand All @@ -167,7 +181,7 @@ def main():
for model_str in oc.list_models():
print(f"\t{model_str}")
else:
raise ValueError("Invalid command")
create_parser().print_help()


if __name__ == '__main__':
Expand Down
40 changes: 33 additions & 7 deletions src/bioclip/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,39 @@ def predict(self, image_paths: List[str] | str, k: int = None) -> dict[str, floa
img_probs = probs[image_path]
if not k or k > len(self.classes):
k = len(self.classes)
topk = img_probs.topk(k)
for i, prob in zip(topk.indices, topk.values):
result.append({
PRED_FILENAME_KEY: image_path,
PRED_CLASSICATION_KEY: self.classes[i],
PRED_SCORE_KEY: prob.item()
})
result.extend(self.group_probs(image_path, img_probs, k))
return result

def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
result = []
topk = img_probs.topk(k)
for i, prob in zip(topk.indices, topk.values):
result.append({
PRED_FILENAME_KEY: image_path,
PRED_CLASSICATION_KEY: self.classes[i],
PRED_SCORE_KEY: prob.item()
})
return result


class CustomLabelsBinningClassifier(CustomLabelsClassifier):
def __init__(self, cls_to_bin: dict, **kwargs):
super().__init__(cls_ary=cls_to_bin.keys(), **kwargs)
self.cls_to_bin = cls_to_bin

def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
result = []
output = collections.defaultdict(float)
for i in range(len(self.classes)):
name = self.cls_to_bin[self.classes[i]]
output[name] += img_probs[i]
topk_names = heapq.nlargest(k, output, key=output.get)
for name in topk_names:
result.append({
PRED_FILENAME_KEY: image_path,
PRED_CLASSICATION_KEY: name,
PRED_SCORE_KEY: output[name].item()
})
return result


Expand Down
67 changes: 58 additions & 9 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from unittest.mock import mock_open, patch
import argparse
from bioclip.__main__ import parse_args, Rank, create_classes_str, main
import pandas as pd
from bioclip.__main__ import parse_args, Rank, create_classes_str, main, parse_bins_csv


class TestParser(unittest.TestCase):
Expand All @@ -15,6 +16,7 @@ def test_parse_args(self):
self.assertEqual(args.rank, Rank.SPECIES)
self.assertEqual(args.k, 5)
self.assertEqual(args.cls, None)
self.assertEqual(args.bins, None)
self.assertEqual(args.device, 'cpu')

args = parse_args(['predict', 'image.jpg', 'image2.png'])
Expand All @@ -41,12 +43,29 @@ def test_parse_args(self):
self.assertEqual(args.rank, None) # default ignored for the --cls variation
self.assertEqual(args.k, None)
self.assertEqual(args.cls, 'class1,class2')
self.assertEqual(args.bins, None)
self.assertEqual(args.device, 'cuda')

# test binning version of predict
args = parse_args(['predict', 'image.jpg', '--format', 'table', '--output', 'output.csv', '--bins', 'bins.csv', '--device', 'cuda'])
self.assertEqual(args.command, 'predict')
self.assertEqual(args.image_file, ['image.jpg'])
self.assertEqual(args.format, 'table')
self.assertEqual(args.output, 'output.csv')
self.assertEqual(args.rank, None) # default ignored for the --cls variation
self.assertEqual(args.k, None)
self.assertEqual(args.cls, None)
self.assertEqual(args.bins, 'bins.csv')
self.assertEqual(args.device, 'cuda')

# test error when using --cls with --rank
with self.assertRaises(ValueError):
with self.assertRaises(SystemExit):
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--rank', 'genus'])

# test error when using --cls with --bins
with self.assertRaises(SystemExit):
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--bins', 'somefile.csv', 'genus'])

# not an error when using --cls with --k
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
self.assertEqual(args.k, 10)
Expand Down Expand Up @@ -77,10 +96,10 @@ def test_create_classes_str(self):
def test_predict_no_class(self, mock_parse_args, mock_predict):
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls=None, device='cpu',
model=None, pretrained=None)
model=None, pretrained=None, bins=None)
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES, k=5,
device='cpu', model_str=None, pretrained_str=None)
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES,
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
Expand All @@ -89,10 +108,10 @@ def test_predict_class_list(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = False
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls='dog,fish,bird',
device='cpu', model=None, pretrained=None)
device='cpu', model=None, pretrained=None, bins=None)
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
k=5, device='cpu', model_str=None, pretrained_str=None)
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
Expand All @@ -101,8 +120,38 @@ def test_predict_class_file(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = True
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls='somefile.txt',
device='cpu', model=None, pretrained=None)
device='cpu', model=None, pretrained=None, bins=None)
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
k=5, device='cpu', model_str=None, pretrained_str=None)
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
@patch('bioclip.__main__.os')
def test_predict_bins(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = True
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=None, k=5, cls=None,
device='cpu', model=None, pretrained=None,
bins='some.csv')
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=None,
bins_path='some.csv', k=5, device='cpu', model_str=None, pretrained_str=None)
@patch('bioclip.__main__.os.path')
def test_parse_bins_csv_file_missing(self, mock_path):
mock_path.exists.return_value = False
with self.assertRaises(FileNotFoundError) as raised_exception:
parse_bins_csv("somefile.csv")
self.assertEqual(str(raised_exception.exception), 'File not found: somefile.csv')

@patch('bioclip.__main__.pd')
@patch('bioclip.__main__.os.path')
def test_parse_bins_csv(self, mock_path, mock_pd):
mock_path.exists.return_value = True
data = {'bin': ['a', 'b']}
mock_pd.read_csv.return_value = pd.DataFrame(data=data, index=['dog', 'cat'])
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
cls_to_bin = parse_bins_csv("somefile.csv")
self.assertEqual(cls_to_bin, {'cat': 'b', 'dog': 'a'})
24 changes: 23 additions & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
from bioclip.predict import TreeOfLifeClassifier, Rank
from bioclip.predict import CustomLabelsClassifier
from bioclip.predict import CustomLabelsBinningClassifier
import os
import torch

Expand Down Expand Up @@ -81,13 +82,34 @@ def test_custom_labels_classifier_ary_multiple(self):
{'file_name': EXAMPLE_CAT_IMAGE2, 'classification': 'dog', 'score': unittest.mock.ANY},
])


def test_predict_with_rgba_image(self):
# Ensure that the classifier can handle RGBA images
classifier = TreeOfLifeClassifier()
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2], rank=Rank.SPECIES)
self.assertEqual(len(prediction_ary), 5)

def test_predict_with_bins(self):
classifier = CustomLabelsBinningClassifier(cls_to_bin={
'cat': 'one',
'mouse': 'two',
'fish': 'two',
})
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2])
self.assertEqual(len(prediction_ary), 2)
self.assertEqual(prediction_ary[0]['file_name'], EXAMPLE_CAT_IMAGE2)
names = set([pred['classification'] for pred in prediction_ary])
self.assertEqual(names, set(['one', 'two']))

classifier = CustomLabelsBinningClassifier(cls_to_bin={
'cat': 'one',
'mouse': 'two',
'fish': 'three',
})
prediction_ary = classifier.predict(image_paths=[EXAMPLE_CAT_IMAGE2])
self.assertEqual(len(prediction_ary), 3)
self.assertEqual(prediction_ary[0]['file_name'], EXAMPLE_CAT_IMAGE2)
names = set([pred['classification'] for pred in prediction_ary])
self.assertEqual(names, set(['one', 'two', 'three']))

class TestEmbed(unittest.TestCase):
def test_get_image_features(self):
Expand Down

0 comments on commit a658c25

Please sign in to comment.