Skip to content

Commit

Permalink
implemented IOU for BaseBx and added unittests (#1)
Browse files Browse the repository at this point in the history
# Main commits
* implemented intersection-over-union (IOU) for `BaseBx`
* added unittests for all modules
* Implemented classmethod and `bbx()` for `BaseBx` class to convert all types to `BaseBx`
* `ops` now handles all type conversions (json-array, list-array)
* bug fixes, best caught: 
   * `BaseBx` method `xywh()` flipped `w` and `h`
   * read keys in order of `voc_keys` for json annotations)
* updated README.md and nbs/

## Squashed commit messages: 
* move voc_keys to anchor.py

* add new corrected annots_rand.json, annots_iou.json

* Implemented iou and fixed bbx.xywh() bug, inserted classmethod for bbx

* function to make json, list -> array

function to calculate the intersecting box dimension.

* adding testing scripts

* adding testing scripts

* comment fixed

* add tests for all modules

* update README.md
  • Loading branch information
thatgeeman authored Jan 18, 2022
1 parent ccb18ca commit bf37cb7
Show file tree
Hide file tree
Showing 12 changed files with 371 additions and 158 deletions.
9 changes: 3 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
[![PyPI version](https://badge.fury.io/py/pybx.svg)](https://badge.fury.io/py/pybx)
[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thatgeeman/pybx/blob/master/nbs/pybx_walkthrough.ipynb)

*WIP*

A simple python package to generate anchor
(aka default/prior) boxes for object detection
tasks. Calculated anchor boxes are in `pascal_voc` format by default.
Expand Down Expand Up @@ -67,13 +65,12 @@ to [Visualising anchor boxes](data/README.md).
- [x] Companion notebook
- [x] Update with new Class methods
- [x] Integrate MultiBx into anchor.bx()
- [ ] IOU check (return best overlap boxes)
- [ ] Return masks
- [x] IOU calcultaion
- [x] Unit tests
- [x] Specific tests
- [x] `feature_sz` of different aspect ratios
- [x] `image_sz` of different aspect ratios
- [ ] Move to setup.py
- [ ] Generate docs
- [ ] Generate docs `sphinx`
- [ ] clean docstrings


23 changes: 23 additions & 0 deletions data/annots_iou.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[
{
"x_min": 20.0,
"y_min": 10.0,
"x_max": 70.0,
"y_max": 80.0,
"label": "b1"
},
{
"x_min": 50.0,
"y_min": 60.0,
"x_max": 120.0,
"y_max": 150.0,
"label": "b2"
},
{
"x_min": 50.0,
"y_min": 60.0,
"x_max": 70.0,
"y_max": 80.0,
"label": "int"
}
]
4 changes: 2 additions & 2 deletions data/annots_rand.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
{
"x_min": 50.0,
"y_min": 70.0,
"y_max": 100.0,
"x_max": 120.0,
"y_max": 100.0,
"label": "rand1"
},
{
"x_min": 150.0,
"y_min": 200.0,
"y_max": 240.0,
"x_max": 250.0,
"y_max": 240.0,
"label": "rand2"
}
]
2 changes: 0 additions & 2 deletions src/pybx/anchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from .ops import __ops__, get_op, named_idx

voc_keys = ['x_min', 'y_min', 'x_max', 'y_max', 'label']


def get_edges(image_sz: tuple, feature_sz: tuple, op='noop'):
"""
Expand Down
114 changes: 40 additions & 74 deletions src/pybx/basics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import numpy as np
from fastcore.basics import concat, store_attr

from .anchor import voc_keys
from .ops import mul, sub
from .ops import mul, sub, intersection_box, make_array, NoIntersection, voc_keys

__all__ = ['mbx', 'MultiBx', 'BaseBx', 'JsonBx', 'ListBx']
__all__ = ['bbx', 'mbx', 'MultiBx', 'BaseBx', 'JsonBx', 'ListBx']


def bbx(coords=None, labels=None):
"""
interface to the BaseBx class and all of its attributes
MultiBx wraps the coordinates and labels exposing many validation methods
:param coords: coordinates in list/array/json format
:param labels: labels in list format or keep intentionally None (also None for json)
:return: BaseBx object
"""
return BaseBx.basebx(coords, labels)


def mbx(coords=None, labels=None):
Expand All @@ -21,19 +31,29 @@ def mbx(coords=None, labels=None):
class BaseBx:
def __init__(self, coords, label=''):
store_attr('coords, label')
coords_ = coords[::-1] # reverse
self.w = sub(*coords_[::2])
self.h = sub(*coords_[1::2])
self.w = sub(*coords[::2][::-1])
self.h = sub(*coords[1::2][::-1])

def area(self):
return abs(mul(self.w, self.h))

def iou(self, other):
if not isinstance(other, BaseBx):
other = BaseBx.basebx(other)
if self.valid():
try:
int_box = BaseBx.basebx(intersection_box(self.coords, other.coords))
except NoIntersection:
return 0.0
int_area = int_box.area()
union_area = other.area() + self.area() - int_area
return int_area / union_area
return 0.0

def valid(self):
# TODO: more validations here
v_area = bool(self.area()) # False if 0
# TODO: v_ratio
v_all = [v_area]
return False if False in v_all else True
v_all = np.array([v_area])
return True if v_all.all() else False

def values(self):
return [*self.coords, self.label]
Expand All @@ -49,6 +69,15 @@ def make_2d(self):
labels = [self.label]
return coords, labels

@classmethod
def basebx(cls, coords, label: list = None):
if not isinstance(coords, np.ndarray):
try:
coords, label = make_array(coords)
except ValueError:
coords = make_array(coords)
return cls(coords, label)


class MultiBx:
def __init__(self, coords, label: list = None):
Expand Down Expand Up @@ -129,72 +158,9 @@ def jsonbx(cls, coords, label=None):
r = []
for i, c in enumerate(coords):
assert isinstance(c, dict), f'expected b of type dict, got {type(c)}'
c_ = list(c.values())
c_ = [c[k] for k in voc_keys] # read in order
l_ = c_[-1] if len(c_) > 4 else '' if label is None else label[i]
l.append(l_)
r.append(c_[:-1] if len(c_) > 4 else c_)
coords = np.array(r)
return cls(coords, label=l)


# deprecated
class BxIter:
def __init__(self, coords: np.ndarray, x_max=-1.0, y_max=-1.0, clip_only=False):
"""
returns an iterator that validates the coordinates calculated.
:param coords: ndarray of box coordinates
:param x_max: max dimension along x
:param y_max: max dimension along y
:param clip_only: whether to apply only np.clip with validate
clip_only cuts boxes that bleed outside limits
and forgo other validation ops
"""
if not isinstance(coords, np.ndarray):
coords = np.array(coords)
self.coords = coords.clip(0, max(x_max, y_max))
# clip_only cuts boxes that bleed outside limits
store_attr('x_max, y_max, clip_only')

def __iter__(self):
self.index = 0
return self

def __next__(self):
try:
c = self.coords[self.index]
if not self.clip_only:
self.validate_edge(c)
except IndexError:
raise StopIteration
self.index += 1
return c

def validate_edge(self, c):
"""
return next only if the x_min and y_min
# TODO: more tests, check if asp ratio changed as an indicator
does not flow outside the image, but:
- while might keep point (1,1,1,1) or line (0,0,1,0) | (0,0,0,1) boxes!
either maybe undesirable.
:param c: pass a box
:return: call for next iterator if conditions not met
"""
x1, y1 = c[:2]
if (x1 >= self.x_max) or (y1 >= self.y_max):
self.index += 1
return self.__next__()

def to_array(self, cast_fn=np.asarray):
"""
return all validated coords as np.ndarray
:return: array of coordinates, specify get_as torch.tensor for Tensor
"""
# TODO: fix UserWarning directly casting a numpy to tensor is too slow (torch>10)
return cast_fn([c for c in self.coords])

def to_records(self, cast_fn=list):
"""
return all validated coords as records (list of dicts)
:return: array of coordinates, specify get_as dict for json
"""
return cast_fn(dict(zip(voc_keys, [*c, f'a{i}'])) for i, c in enumerate(self.coords))
41 changes: 40 additions & 1 deletion src/pybx/ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
from fastcore.foundation import L

__ops__ = ['add', 'sub', 'noop']
__ops__ = ['add', 'sub', 'mul', 'noop']
voc_keys = ['x_min', 'y_min', 'x_max', 'y_max', 'label']


def add(x, y):
Expand Down Expand Up @@ -30,6 +31,24 @@ def get_op(op: str):
return eval(op, globals())


def make_array(x):
if isinstance(x, dict):
try:
x = [x[k] for k in voc_keys]
except TypeError:
x = [x[k] for k in voc_keys[:-1]]
# now dict made into a list too
if isinstance(x, list):
if len(x) > 4:
return np.asarray(x[:4]), x[-1]
else:
return np.asarray(x)
elif isinstance(x, np.ndarray):
return x
else:
raise NotImplementedError


def named_idx(x: np.ndarray, sfx: str):
"""
return a list of string indices matching the array
Expand All @@ -40,3 +59,23 @@ def named_idx(x: np.ndarray, sfx: str):
"""
idx = np.arange(0, x.shape[0]).tolist()
return L([sfx + i.__str__() for i in idx])


def intersection_box(b1: np.ndarray, b2: np.ndarray):
"""
return the intersection box given two boxes
:param b1:
:param b2:
:return:
"""
if not isinstance(b1, np.ndarray):
raise TypeError('expected ndarrays')
top_edge = np.max(np.vstack([b1, b2]), axis=0)[:2]
bot_edge = np.min(np.vstack([b1, b2]), axis=0)[2:]
if (bot_edge > top_edge).all():
return np.hstack([top_edge, bot_edge])
raise NoIntersection


class NoIntersection(Exception):
pass
43 changes: 43 additions & 0 deletions tests/test_anchor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json
import unittest

import numpy as np

from pybx import anchor

np.random.seed(1)

params = {
"feature_szs": [(2, 2), (3, 3), (4, 4)],
"asp_ratios": [1 / 2., 1., 2.],
"feature_sz": (2, 2),
"asp_ratio": 1 / 2.,
"image_sz": (10, 10, 3),
"data_dir": '../data',
}

results = {
"bx_b": 236.8933982822018,
"bx_l": 'a_2x2_0.5_8',
"bxs_b": 3703.086279536432,
"bxs_l": 'a_4x4_2.0_24',
"scaled_ans": (9.0, 6.0),
}


class AnchorTestCase(unittest.TestCase):
def test_bx(self):
b, l_ = anchor.bx(params["image_sz"], params["feature_sz"], params["asp_ratio"])
self.assertIn(results["bx_l"], l_, 'label not matching')
self.assertEqual(len(b), len(l_))
self.assertEqual(b.sum(), results["bx_b"], 'sum not matching') # add assertion here

def test_bxs(self):
b, l_ = anchor.bxs(params["image_sz"], params["feature_szs"], params["asp_ratios"])
self.assertIn(results["bxs_l"], l_, 'label not matching')
self.assertEqual(len(b), len(l_))
self.assertEqual(b.sum(), results["bxs_b"], 'sum not matching') # add assertion here


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit bf37cb7

Please sign in to comment.