diff --git a/tests/helpers/test_filters.py b/tests/helpers/test_filters.py index a8ca605d..3821b000 100644 --- a/tests/helpers/test_filters.py +++ b/tests/helpers/test_filters.py @@ -190,3 +190,90 @@ def test_filter_bbox_predictions(): "sheep", ], } + + +def test_filter_valid_bboxes(): + predictions = { + "bboxes": [[10, 20, 30, 40], [50, 60, 70, 80]], + "labels": ["sheep", "sheep"], + } + image_size = (100, 100) + + results = filter_bbox_predictions(predictions, image_size) + assert results == predictions + + +def test_filter_wrong_order(): + predictions = { + "bboxes": [[326.0, 362.6, 225.8, 262.9], [50, 60, 70, 80]], + "labels": ["sheep", "sheep"], + } + image_size = (835, 453) + + results = filter_bbox_predictions(predictions, image_size) + assert results == {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]} + + +def test_filter_invalid_bboxes_negative_coords(): + predictions = { + "bboxes": [[-10, 20, 30, 40], [50, 60, 70, 80]], + "labels": ["sheep", "sheep"], + } + image_size = (100, 100) + + results = filter_bbox_predictions(predictions, image_size) + assert results == {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]} + + +def test_filter_invalid_bboxes_out_of_bounds(): + predictions = { + "bboxes": [[10, 20, 110, 40], [50, 60, 70, 80]], + "labels": ["sheep", "sheep"], + } + image_size = (100, 100) + + results = filter_bbox_predictions(predictions, image_size) + assert results == {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]} + + +def test_filter_invalid_bboxes_mixed_valid_invalid(): + predictions = { + "bboxes": [ + [10, 20, 30, 40], + [-10, 20, 30, 40], + [50, 60, 70, 80], + [110, 20, 120, 40], + ], + "labels": ["sheep", "sheep"], + } + image_size = (100, 100) + + results = filter_bbox_predictions(predictions, image_size) + assert results == { + "bboxes": [[10, 20, 30, 40], [50, 60, 70, 80]], + "labels": ["sheep"], + } + + +def test_filter_invalid_bboxes(): + predictions_list = [ + {"bboxes": [[10, 20, 30, 40], [50, 60, 70, 80]], "labels": ["sheep", "sheep"]}, + {"bboxes": [[-10, 20, 30, 40], [50, 60, 70, 80]], "labels": ["sheep", "sheep"]}, + {"bboxes": [[10, 20, 110, 40], [50, 60, 70, 80]], "labels": ["sheep", "sheep"]}, + {"bboxes": [[10, 20, 30, 40], [50, 60, 40, 70]], "labels": ["sheep", "sheep"]}, + {"bboxes": [[10, 20, 30, 40], [50, 60, 70, 60]], "labels": ["sheep", "sheep"]}, + ] + + image_size = (100, 100) + + expected_results = [ + {"bboxes": [[10, 20, 30, 40], [50, 60, 70, 80]], "labels": ["sheep", "sheep"]}, + {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]}, + {"bboxes": [[50, 60, 70, 80]], "labels": ["sheep"]}, + {"bboxes": [[10, 20, 30, 40]], "labels": ["sheep"]}, + {"bboxes": [[10, 20, 30, 40]], "labels": ["sheep"]}, + ] + + for idx, prediction in enumerate(predictions_list): + results = filter_bbox_predictions(prediction, image_size) + assert results == expected_results[idx] diff --git a/tests/models/test_owlv2.py b/tests/models/test_owlv2.py index 17907cb2..c20d3c43 100644 --- a/tests/models/test_owlv2.py +++ b/tests/models/test_owlv2.py @@ -43,8 +43,8 @@ def test_owlv2_removing_extra_bbox(shared_model): assert len(response) == 1 item = response[0] - assert len(item["bboxes"]) == 42 - assert len([label == "egg" for label in item["labels"]]) == 42 + assert len(item["bboxes"]) == 40 + assert len([label == "egg" for label in item["labels"]]) == 40 def test_owlv2_image_with_nms(shared_model): diff --git a/vision_agent_tools/helpers/filters.py b/vision_agent_tools/helpers/filters.py index db350b2c..c0c35f2d 100644 --- a/vision_agent_tools/helpers/filters.py +++ b/vision_agent_tools/helpers/filters.py @@ -17,10 +17,19 @@ def filter_bbox_predictions( ) -> dict[str, Any]: new_preds = {} - # Remove the whole image bounding box if it is predicted - bboxes_to_remove = _remove_whole_image_bbox(predictions, image_size, bboxes_key) + # Remove invalid bboxes, other filters rely on well formed bboxes + bboxes_to_remove = _filter_invalid_bboxes( + predictions=predictions, + image_size=image_size, + bboxes_key=bboxes_key, + ) + new_preds = _remove_bboxes(predictions, bboxes_to_remove) + # Remove the whole image bounding box if it is predicted + bboxes_to_remove = _remove_whole_image_bbox(new_preds, image_size, bboxes_key) + new_preds = _remove_bboxes(new_preds, bboxes_to_remove) + # Apply a dummy agnostic Non-Maximum Suppression (NMS) to get rid of any # overlapping predictions on the same object bboxes_to_remove = _dummy_agnostic_nms(new_preds, nms_threshold, bboxes_key) @@ -183,3 +192,32 @@ def _contains(box_a, box_b): and x_max_a >= x_max_b and y_max_a >= y_max_b ) + + +def _filter_invalid_bboxes( + predictions: dict[str, Any], + image_size: tuple[int, int], + bboxes_key: str = "bboxes", +) -> list[int]: + """Filters out invalid bounding boxes from the given predictions and + returns a list of indices of invalid boxes. + + Args: + predictions: A dictionary containing 'bboxes' and 'labels' keys. + image_size: A tuple representing the image width and height. + bboxes_key: The key for bounding boxes in the predictions dictionary. + + Returns: + A list of indices of invalid bounding boxes. + """ + width, height = image_size + + invalid_indices = [] + + for idx, bbox in enumerate(predictions[bboxes_key]): + x1, y1, x2, y2 = bbox + if not (0 <= x1 < x2 <= width and 0 <= y1 < y2 <= height): + invalid_indices.append(idx) + _LOGGER.warning(f"Removing invalid bbox {bbox}") + + return invalid_indices