Skip to content

Commit

Permalink
feat: filter invalid bounding boxes (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
hrnn authored Nov 19, 2024
1 parent 535e7c7 commit 7967a00
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 4 deletions.
87 changes: 87 additions & 0 deletions tests/helpers/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,90 @@ def test_filter_bbox_predictions():
"sheep",
],
}


def test_filter_valid_bboxes():
    """Boxes that are well ordered and inside the image come back unchanged."""
    predictions = dict(
        bboxes=[[10, 20, 30, 40], [50, 60, 70, 80]],
        labels=["sheep", "sheep"],
    )

    # Both boxes lie inside the 100x100 image, so the result must equal the input.
    assert filter_bbox_predictions(predictions, (100, 100)) == predictions


def test_filter_wrong_order():
    """A box whose max corner precedes its min corner is dropped with its label."""
    # First box has x1 > x2 and y1 > y2; the second one is well formed.
    swapped_corners = [326.0, 362.6, 225.8, 262.9]
    well_formed = [50, 60, 70, 80]
    predictions = {
        "bboxes": [swapped_corners, well_formed],
        "labels": ["sheep", "sheep"],
    }
    expected = {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]}

    assert filter_bbox_predictions(predictions, (835, 453)) == expected


def test_filter_invalid_bboxes_negative_coords():
    """A box with a negative coordinate is dropped together with its label."""
    # x1 = -10 falls outside the image on the left side.
    negative_x1 = [-10, 20, 30, 40]
    inside = [50, 60, 70, 80]
    predictions = {
        "bboxes": [negative_x1, inside],
        "labels": ["sheep", "sheep"],
    }
    expected = {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]}

    assert filter_bbox_predictions(predictions, (100, 100)) == expected


def test_filter_invalid_bboxes_out_of_bounds():
    """A box extending past the image width is dropped together with its label."""
    # x2 = 110 exceeds the image width of 100.
    too_wide = [10, 20, 110, 40]
    inside = [50, 60, 70, 80]
    predictions = {
        "bboxes": [too_wide, inside],
        "labels": ["sheep", "sheep"],
    }
    expected = {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]}

    assert filter_bbox_predictions(predictions, (100, 100)) == expected


def test_filter_invalid_bboxes_mixed_valid_invalid():
    """Valid and invalid boxes interleaved: only the valid ones survive.

    Every box carries its own label so the assertion also checks that the
    labels stay aligned with the boxes that remain after filtering.
    """
    predictions = {
        "bboxes": [
            [10, 20, 30, 40],  # valid
            [-10, 20, 30, 40],  # invalid: negative x1
            [50, 60, 70, 80],  # valid
            [110, 20, 120, 40],  # invalid: lies beyond the image width
        ],
        # One label per box; the fixture previously had only two labels for
        # four boxes, which made the expected labels meaningless.
        "labels": ["sheep", "sheep", "sheep", "sheep"],
    }
    image_size = (100, 100)

    results = filter_bbox_predictions(predictions, image_size)
    assert results == {
        "bboxes": [[10, 20, 30, 40], [50, 60, 70, 80]],
        "labels": ["sheep", "sheep"],
    }


def test_filter_invalid_bboxes():
    """Run one scenario per kind of invalid box and compare with its expectation."""
    # Each entry pairs an input prediction with the output expected for it.
    scenarios = [
        (  # both boxes valid: nothing is removed
            {"bboxes": [[10, 20, 30, 40], [50, 60, 70, 80]], "labels": ["sheep", "sheep"]},
            {"bboxes": [[10, 20, 30, 40], [50, 60, 70, 80]], "labels": ["sheep", "sheep"]},
        ),
        (  # first box has a negative x1
            {"bboxes": [[-10, 20, 30, 40], [50, 60, 70, 80]], "labels": ["sheep", "sheep"]},
            {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]},
        ),
        (  # first box has x2 beyond the image width
            {"bboxes": [[10, 20, 110, 40], [50, 60, 70, 80]], "labels": ["sheep", "sheep"]},
            {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]},
        ),
        (  # second box has x2 < x1
            {"bboxes": [[10, 20, 30, 40], [50, 60, 40, 70]], "labels": ["sheep", "sheep"]},
            {"bboxes": [[10, 20, 30, 40]], "labels": ["sheep"]},
        ),
        (  # second box has y2 == y1 (zero height)
            {"bboxes": [[10, 20, 30, 40], [50, 60, 70, 60]], "labels": ["sheep", "sheep"]},
            {"bboxes": [[10, 20, 30, 40]], "labels": ["sheep"]},
        ),
    ]
    frame = (100, 100)

    for prediction, expected in scenarios:
        assert filter_bbox_predictions(prediction, frame) == expected
4 changes: 2 additions & 2 deletions tests/models/test_owlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def test_owlv2_removing_extra_bbox(shared_model):

assert len(response) == 1
item = response[0]
assert len(item["bboxes"]) == 42
assert len([label == "egg" for label in item["labels"]]) == 42
assert len(item["bboxes"]) == 40
assert len([label == "egg" for label in item["labels"]]) == 40


def test_owlv2_image_with_nms(shared_model):
Expand Down
42 changes: 40 additions & 2 deletions vision_agent_tools/helpers/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@ def filter_bbox_predictions(
) -> dict[str, Any]:
new_preds = {}

# Remove the whole image bounding box if it is predicted
bboxes_to_remove = _remove_whole_image_bbox(predictions, image_size, bboxes_key)
# Remove invalid bboxes, other filters rely on well formed bboxes
bboxes_to_remove = _filter_invalid_bboxes(
predictions=predictions,
image_size=image_size,
bboxes_key=bboxes_key,
)

new_preds = _remove_bboxes(predictions, bboxes_to_remove)

# Remove the whole image bounding box if it is predicted
bboxes_to_remove = _remove_whole_image_bbox(new_preds, image_size, bboxes_key)
new_preds = _remove_bboxes(new_preds, bboxes_to_remove)

# Apply a dummy agnostic Non-Maximum Suppression (NMS) to get rid of any
# overlapping predictions on the same object
bboxes_to_remove = _dummy_agnostic_nms(new_preds, nms_threshold, bboxes_key)
Expand Down Expand Up @@ -183,3 +192,32 @@ def _contains(box_a, box_b):
and x_max_a >= x_max_b
and y_max_a >= y_max_b
)


def _filter_invalid_bboxes(
    predictions: dict[str, Any],
    image_size: tuple[int, int],
    bboxes_key: str = "bboxes",
) -> list[int]:
    """Return the indices of bounding boxes that are malformed or out of bounds.

    Nothing is removed from ``predictions`` here; the function only reports
    which entries are invalid. A box ``[x1, y1, x2, y2]`` is considered valid
    only when ``0 <= x1 < x2 <= width`` and ``0 <= y1 < y2 <= height``. A box
    that cannot be unpacked into exactly four values is reported as invalid
    as well, instead of raising.

    Args:
        predictions: A dictionary holding the bounding boxes under
            ``bboxes_key``.
        image_size: A tuple representing the image width and height.
        bboxes_key: The key for bounding boxes in the predictions dictionary.

    Returns:
        A list of indices of invalid bounding boxes, in ascending order.
    """
    width, height = image_size

    invalid_indices = []

    for idx, bbox in enumerate(predictions[bboxes_key]):
        try:
            x1, y1, x2, y2 = bbox
        except (TypeError, ValueError):
            # Wrong number of coordinates, or not a sequence at all: the box
            # is malformed, which is exactly what this check should flag.
            is_valid = False
        else:
            # Strict "<" between the corners rejects zero-area and swapped
            # boxes; the outer bounds keep the box inside the image.
            is_valid = 0 <= x1 < x2 <= width and 0 <= y1 < y2 <= height

        if not is_valid:
            invalid_indices.append(idx)
            _LOGGER.warning(f"Removing invalid bbox {bbox}")

    return invalid_indices

0 comments on commit 7967a00

Please sign in to comment.