Skip to content

Commit

Permalink
Merge pull request #53 from Imageomics/29-custom-bins
Browse files Browse the repository at this point in the history
Add binning to custom label prediction
  • Loading branch information
johnbradley authored Oct 16, 2024
2 parents 44818f0 + ae33abf commit d8035ef
Show file tree
Hide file tree
Showing 6 changed files with 221 additions and 33 deletions.
55 changes: 51 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,29 @@ fish 2.932403668845507e-12
bear 1.0
```

### Predict from a list of classes with binning
```python
from bioclip import CustomLabelsBinningClassifier
classifier = CustomLabelsBinningClassifier(cls_to_bin={
'dog': 'small',
'fish': 'small',
'bear': 'big',
})
predictions = classifier.predict("Ursus-arctos.jpeg")
for prediction in predictions:
print(prediction["classification"], prediction["score"])
```
Output:
```console
big 0.99992835521698
small 7.165559509303421e-05
```

## Command Line Usage
```
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT] [--rank {kingdom,phylum,class,order,family,genus,species}] [--k K] [--cls CLS] [--device DEVICE] image_file [image_file ...]
bioclip predict [-h] [--format {table,csv}] [--output OUTPUT]
[--rank {kingdom,phylum,class,order,family,genus,species} | --cls CLS | --bins BINS]
[--k K] [--device DEVICE] image_file [image_file ...]
bioclip embed [-h] [--device=DEVICE] [--output=OUTPUT] [IMAGE_FILE...]
Commands:
Expand All @@ -117,9 +137,13 @@ Arguments:
Options:
-h --help
--format=FORMAT format of the output (table or csv) for predict mode [default: csv]
--rank=RANK rank of the classification (kingdom, phylum, class, order, family, genus, species) [default: species]
--k=K number of top predictions to show [default: 5]
--cls=CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed.
--rank {kingdom,phylum,class,order,family,genus,species}
rank of the classification, default: species (when)
--cls CLS classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the
--rank and --bins arguments are not allowed.
--bins BINS path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls
argument is not allowed.
--k K number of top predictions to show, default: 5
--device=DEVICE device to use matrix math (cpu or cuda or mps) [default: cpu]
--output=OUTFILE print output to file OUTFILE [default: stdout]
```
Expand Down Expand Up @@ -195,6 +219,29 @@ Ursus-arctos.jpeg,bird,3.051998476166773e-08
Ursus-arctos.jpeg,bear,0.9999998807907104
```

### Predict from a binning CSV
Create predictions for 3 classes (cat, bird, and bear) with 2 bins (one, two) for image `Ursus-arctos.jpeg`:

Create a CSV file named `bins.csv` with the following contents:
```
cls,bin
cat,one
bird,one
bear,two
```
The names of the columns do not matter. The first column values will be used as the classes. The second column values will be used for bin names.

Run predict command:
```console
bioclip predict --bins bins.csv Ursus-arctos.jpeg
```

Output:
```
Ursus-arctos.jpeg,two,0.9999998807907104
Ursus-arctos.jpeg,one,7.633736487377973e-08
```

### Create embeddings

#### Create embedding for an image
Expand Down
4 changes: 2 additions & 2 deletions src/bioclip/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: 2024-present John Bradley <[email protected]>
#
# SPDX-License-Identifier: MIT
from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
from bioclip.predict import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier

__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier"]
__all__ = ["TreeOfLifeClassifier", "Rank", "CustomLabelsClassifier", "CustomLabelsBinningClassifier"]
34 changes: 24 additions & 10 deletions src/bioclip/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier
from bioclip import TreeOfLifeClassifier, Rank, CustomLabelsClassifier, CustomLabelsBinningClassifier
from .predict import BIOCLIP_MODEL_STR
import open_clip as oc
import os
Expand Down Expand Up @@ -32,17 +32,32 @@ def write_results_to_file(df, format, outfile):
raise ValueError(f"Invalid format: {format}")


def parse_bins_csv(bins_path):
if not os.path.exists(bins_path):
raise FileNotFoundError(f"File not found: {bins_path}")
bin_df = pd.read_csv(bins_path, index_col=0)
if len(bin_df.columns) == 0:
raise ValueError("CSV file must have at least two columns.")
return bin_df[bin_df.columns[0]].to_dict()


def predict(image_file: list[str],
format: str,
output: str,
cls_str: str,
rank: Rank,
bins_path: str,
k: int,
**kwargs):
if cls_str:
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs)
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
elif bins_path:
cls_to_bin = parse_bins_csv(bins_path)
classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs)
predictions = classifier.predict(image_paths=image_file, k=k)
write_results(predictions, format, output)
else:
classifier = TreeOfLifeClassifier(**kwargs)
predictions = classifier.predict(image_paths=image_file, rank=rank, k=k)
Expand Down Expand Up @@ -81,11 +96,13 @@ def create_parser():
predict_parser.add_argument('image_file', nargs='+', help='input image file(s)')
predict_parser.add_argument('--format', choices=['table', 'csv'], default='csv', help='format of the output, default: csv')
predict_parser.add_argument('--output', **output_arg)
predict_parser.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
cls_group = predict_parser.add_mutually_exclusive_group(required=False)
cls_group.add_argument('--rank', choices=['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'],
help='rank of the classification, default: species (when)')
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank and --bins arguments are not allowed."
cls_group.add_argument('--cls', help=cls_help)
cls_group.add_argument('--bins', help='path to CSV file with two columns with the first being classes and second being bin names, when specified the --cls argument is not allowed.')
predict_parser.add_argument('--k', type=int, help='number of top predictions to show, default: 5')
cls_help = "classes to predict: either a comma separated list or a path to a text file of classes (one per line), when specified the --rank argument is not allowed."
predict_parser.add_argument('--cls', help=cls_help)

predict_parser.add_argument('--device', **device_arg)
predict_parser.add_argument('--model', **model_arg)
Expand Down Expand Up @@ -115,11 +132,7 @@ def create_parser():
def parse_args(input_args=None):
args = create_parser().parse_args(input_args)
if args.command == 'predict':
if args.cls:
# custom class list mode
if args.rank:
raise ValueError("Cannot use --cls with --rank")
else:
if not args.cls and not args.bins:
# tree of life class list mode
if args.model or args.pretrained:
raise ValueError("Custom model or checkpoints currently not supported for Tree-of-Life prediction")
Expand Down Expand Up @@ -155,6 +168,7 @@ def main():
output=args.output,
cls_str=cls_str,
rank=args.rank,
bins_path=args.bins,
k=args.k,
device=args.device,
model_str=args.model,
Expand All @@ -167,7 +181,7 @@ def main():
for model_str in oc.list_models():
print(f"\t{model_str}")
else:
raise ValueError("Invalid command")
create_parser().print_help()


if __name__ == '__main__':
Expand Down
43 changes: 36 additions & 7 deletions src/bioclip/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import open_clip as oc
import torch.nn.functional as F
import numpy as np
import pandas as pd
import collections
import heapq
import PIL.Image
Expand Down Expand Up @@ -253,13 +254,41 @@ def predict(self, image_paths: List[str] | str, k: int = None) -> dict[str, floa
img_probs = probs[image_path]
if not k or k > len(self.classes):
k = len(self.classes)
topk = img_probs.topk(k)
for i, prob in zip(topk.indices, topk.values):
result.append({
PRED_FILENAME_KEY: image_path,
PRED_CLASSICATION_KEY: self.classes[i],
PRED_SCORE_KEY: prob.item()
})
result.extend(self.group_probs(image_path, img_probs, k))
return result

def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
result = []
topk = img_probs.topk(k)
for i, prob in zip(topk.indices, topk.values):
result.append({
PRED_FILENAME_KEY: image_path,
PRED_CLASSICATION_KEY: self.classes[i],
PRED_SCORE_KEY: prob.item()
})
return result


class CustomLabelsBinningClassifier(CustomLabelsClassifier):
def __init__(self, cls_to_bin: dict, **kwargs):
super().__init__(cls_ary=cls_to_bin.keys(), **kwargs)
self.cls_to_bin = cls_to_bin
if any([pd.isna(x) or not x for x in cls_to_bin.values()]):
raise ValueError("Empty, null, or nan are not allowed for bin values.")

def group_probs(self, image_path: str, img_probs: torch.Tensor, k: int = None) -> List[dict[str, float]]:
result = []
output = collections.defaultdict(float)
for i in range(len(self.classes)):
name = self.cls_to_bin[self.classes[i]]
output[name] += img_probs[i]
topk_names = heapq.nlargest(k, output, key=output.get)
for name in topk_names:
result.append({
PRED_FILENAME_KEY: image_path,
PRED_CLASSICATION_KEY: name,
PRED_SCORE_KEY: output[name].item()
})
return result


Expand Down
67 changes: 58 additions & 9 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from unittest.mock import mock_open, patch
import argparse
from bioclip.__main__ import parse_args, Rank, create_classes_str, main
import pandas as pd
from bioclip.__main__ import parse_args, Rank, create_classes_str, main, parse_bins_csv


class TestParser(unittest.TestCase):
Expand All @@ -15,6 +16,7 @@ def test_parse_args(self):
self.assertEqual(args.rank, Rank.SPECIES)
self.assertEqual(args.k, 5)
self.assertEqual(args.cls, None)
self.assertEqual(args.bins, None)
self.assertEqual(args.device, 'cpu')

args = parse_args(['predict', 'image.jpg', 'image2.png'])
Expand All @@ -41,12 +43,29 @@ def test_parse_args(self):
self.assertEqual(args.rank, None) # default ignored for the --cls variation
self.assertEqual(args.k, None)
self.assertEqual(args.cls, 'class1,class2')
self.assertEqual(args.bins, None)
self.assertEqual(args.device, 'cuda')

# test binning version of predict
args = parse_args(['predict', 'image.jpg', '--format', 'table', '--output', 'output.csv', '--bins', 'bins.csv', '--device', 'cuda'])
self.assertEqual(args.command, 'predict')
self.assertEqual(args.image_file, ['image.jpg'])
self.assertEqual(args.format, 'table')
self.assertEqual(args.output, 'output.csv')
self.assertEqual(args.rank, None) # default ignored for the --cls variation
self.assertEqual(args.k, None)
self.assertEqual(args.cls, None)
self.assertEqual(args.bins, 'bins.csv')
self.assertEqual(args.device, 'cuda')

# test error when using --cls with --rank
with self.assertRaises(ValueError):
with self.assertRaises(SystemExit):
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--rank', 'genus'])

# test error when using --cls with --bins
with self.assertRaises(SystemExit):
parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--bins', 'somefile.csv'])

# not an error when using --cls with --k
args = parse_args(['predict', 'image.jpg', '--cls', 'class1,class2', '--k', '10'])
self.assertEqual(args.k, 10)
Expand Down Expand Up @@ -77,10 +96,10 @@ def test_create_classes_str(self):
def test_predict_no_class(self, mock_parse_args, mock_predict):
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls=None, device='cpu',
model=None, pretrained=None)
model=None, pretrained=None, bins=None)
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES, k=5,
device='cpu', model_str=None, pretrained_str=None)
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=Rank.SPECIES,
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
Expand All @@ -89,10 +108,10 @@ def test_predict_class_list(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = False
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls='dog,fish,bird',
device='cpu', model=None, pretrained=None)
device='cpu', model=None, pretrained=None, bins=None)
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
k=5, device='cpu', model_str=None, pretrained_str=None)
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
Expand All @@ -101,8 +120,38 @@ def test_predict_class_file(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = True
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=Rank.SPECIES, k=5, cls='somefile.txt',
device='cpu', model=None, pretrained=None)
device='cpu', model=None, pretrained=None, bins=None)
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str='dog,fish,bird', rank=Rank.SPECIES,
k=5, device='cpu', model_str=None, pretrained_str=None)
bins_path=None, k=5, device='cpu', model_str=None, pretrained_str=None)

@patch('bioclip.__main__.predict')
@patch('bioclip.__main__.parse_args')
@patch('bioclip.__main__.os')
def test_predict_bins(self, mock_os, mock_parse_args, mock_predict):
mock_os.path.exists.return_value = True
mock_parse_args.return_value = argparse.Namespace(command='predict', image_file='image.jpg', format='csv',
output='stdout', rank=None, k=5, cls=None,
device='cpu', model=None, pretrained=None,
bins='some.csv')
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
main()
mock_predict.assert_called_with('image.jpg', format='csv', output='stdout', cls_str=None, rank=None,
bins_path='some.csv', k=5, device='cpu', model_str=None, pretrained_str=None)
@patch('bioclip.__main__.os.path')
def test_parse_bins_csv_file_missing(self, mock_path):
mock_path.exists.return_value = False
with self.assertRaises(FileNotFoundError) as raised_exception:
parse_bins_csv("somefile.csv")
self.assertEqual(str(raised_exception.exception), 'File not found: somefile.csv')

@patch('bioclip.__main__.pd')
@patch('bioclip.__main__.os.path')
def test_parse_bins_csv(self, mock_path, mock_pd):
mock_path.exists.return_value = True
data = {'bin': ['a', 'b']}
mock_pd.read_csv.return_value = pd.DataFrame(data=data, index=['dog', 'cat'])
with patch("builtins.open", mock_open(read_data='dog\nfish\nbird')) as mock_file:
cls_to_bin = parse_bins_csv("somefile.csv")
self.assertEqual(cls_to_bin, {'cat': 'b', 'dog': 'a'})
Loading

0 comments on commit d8035ef

Please sign in to comment.