diff --git a/focoos/local_model.py b/focoos/local_model.py index 4e655ef..71f8753 100644 --- a/focoos/local_model.py +++ b/focoos/local_model.py @@ -39,6 +39,7 @@ from focoos.runtime import BaseRuntime, load_runtime from focoos.utils.logger import get_logger from focoos.utils.vision import ( + get_postprocess_fn, image_preprocess, scale_detections, sv_to_fai_detections, @@ -99,6 +100,7 @@ def __init__( # Load metadata and set model reference self.metadata: ModelMetadata = self._read_metadata() self.model_ref = self.metadata.ref + self.postprocess_fn = get_postprocess_fn(self.metadata.task) # Initialize annotation utilities self.label_annotator = sv.LabelAnnotator(text_padding=10, border_radius=10) @@ -137,6 +139,9 @@ def _annotate(self, im: np.ndarray, detections: sv.Detections) -> np.ndarray: Returns: np.ndarray: The annotated image with bounding boxes or masks. """ + if len(detections.xyxy) == 0: + logger.warning("No detections found, skipping annotation") + return im classes = self.metadata.classes labels = [ f"{classes[int(class_id)] if classes is not None else str(class_id)}: {confid * 100:.0f}%" @@ -189,11 +194,16 @@ def infer( t0 = perf_counter() im1, im0 = image_preprocess(image, resize=resize) t1 = perf_counter() - detections = self.runtime(im1.astype(np.float32), threshold) + detections = self.runtime(im1.astype(np.float32)) + t2 = perf_counter() + + detections = self.postprocess_fn( + out=detections, im0_shape=(im0.shape[1], im0.shape[0]), conf_threshold=threshold + ) + if resize: detections = scale_detections(detections, (resize, resize), (im0.shape[1], im0.shape[0])) - logger.debug(f"Inference time: {t2 - t1:.3f} seconds") out = sv_to_fai_detections(detections, classes=self.metadata.classes) t3 = perf_counter() @@ -205,6 +215,10 @@ def infer( im = None if annotate: im = self._annotate(im0, detections) + + logger.debug( + f"Found {len(detections)} detections. Inference time: {(t2 - t1) * 1000:.0f}ms, preprocess: {(t1 - t0) * 1000:.0f}ms, postprocess: {(t3 - t2) * 1000:.0f}ms" + ) return FocoosDetections(detections=out, latency=latency), im def benchmark(self, iterations: int, size: int) -> LatencyMetrics: diff --git a/focoos/remote_model.py b/focoos/remote_model.py index 7817f56..a25147b 100644 --- a/focoos/remote_model.py +++ b/focoos/remote_model.py @@ -224,6 +224,10 @@ def _annotate(self, im: np.ndarray, detections: sv.Detections) -> np.ndarray: Returns: np.ndarray: The annotated image as a NumPy array. """ + + if len(detections.xyxy) == 0: + logger.warning("No detections found, skipping annotation") + return im classes = self.metadata.classes if classes is not None: labels = [ @@ -291,11 +295,13 @@ def infer( ) t1 = time.time() if res.status_code == 200: - logger.debug(f"Inference time: {t1 - t0:.3f} seconds") detections = FocoosDetections( detections=[FocoosDet.from_json(d) for d in res.json().get("detections", [])], latency=res.json().get("latency", None), ) + logger.debug( + f"Found {len(detections.detections)} detections. Inference Request time: {(t1 - t0) * 1000:.0f}ms" + ) preview = None if annotate: im0 = image_loader(image) diff --git a/focoos/runtime.py b/focoos/runtime.py index 17a00c0..2cf9542 100644 --- a/focoos/runtime.py +++ b/focoos/runtime.py @@ -20,12 +20,10 @@ from abc import abstractmethod from pathlib import Path from time import perf_counter -from typing import Any, List, Tuple +from typing import Any import numpy as np -from focoos.utils.vision import mask_to_xyxy - try: import torch @@ -40,11 +38,9 @@ except ImportError: ORT_AVAILABLE = False -import supervision as sv # from supervision.detection.utils import mask_to_xyxy from focoos.ports import ( - FocoosTask, LatencyMetrics, ModelMetadata, OnnxRuntimeOpts, @@ -59,93 +55,12 @@ logger = get_logger() -def get_postprocess_fn(task: FocoosTask): - if task == FocoosTask.INSTANCE_SEGMENTATION: - return instance_postprocess - elif task == FocoosTask.SEMSEG: - return semseg_postprocess - else: - return det_postprocess - - -def det_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections: - """ - Postprocesses the output of an object detection model and filters detections - based on a confidence threshold. - - Args: - out (List[np.ndarray]): The output of the detection model. - im0_shape (Tuple[int, int]): The original shape of the input image (height, width). - conf_threshold (float): The confidence threshold for filtering detections. - - Returns: - sv.Detections: A sv.Detections object containing the filtered bounding boxes, class ids, and confidences. - """ - cls_ids, boxes, confs = out - boxes[:, 0::2] *= im0_shape[1] - boxes[:, 1::2] *= im0_shape[0] - high_conf_indices = (confs > conf_threshold).nonzero() - - return sv.Detections( - xyxy=boxes[high_conf_indices].astype(int), - class_id=cls_ids[high_conf_indices].astype(int), - confidence=confs[high_conf_indices].astype(float), - ) - - -def semseg_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections: - """ - Postprocesses the output of a semantic segmentation model and filters based - on a confidence threshold. - - Args: - out (List[np.ndarray]): The output of the semantic segmentation model. - conf_threshold (float): The confidence threshold for filtering detections. - - Returns: - sv.Detections: A sv.Detections object containing the masks, class ids, and confidences. - """ - cls_ids, mask, confs = out[0][0], out[1][0], out[2][0] - masks = np.equal(mask, np.arange(len(cls_ids))[:, None, None]) - high_conf_indices = np.where(confs > conf_threshold)[0] - masks = masks[high_conf_indices].astype(bool) - cls_ids = cls_ids[high_conf_indices].astype(int) - confs = confs[high_conf_indices].astype(float) - xyxy = mask_to_xyxy(masks) - return sv.Detections( - mask=masks, - xyxy=xyxy, - class_id=cls_ids, - confidence=confs, - ) - - -def instance_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections: - """ - Postprocesses the output of an instance segmentation model and filters detections - based on a confidence threshold. - """ - cls_ids, mask, confs = out[0][0], out[1][0], out[2][0] - high_conf_indices = np.where(confs > conf_threshold)[0] - - masks = mask[high_conf_indices].astype(bool) - cls_ids = cls_ids[high_conf_indices].astype(int) - confs = confs[high_conf_indices].astype(float) - xyxy = mask_to_xyxy(masks) - return sv.Detections( - mask=masks, - xyxy=xyxy, - class_id=cls_ids, - confidence=confs, - ) - - class BaseRuntime: def __init__(self, model_path: str, opts: Any, model_metadata: ModelMetadata): pass @abstractmethod - def __call__(self, im: np.ndarray, conf_threshold: float) -> sv.Detections: + def __call__(self, im: np.ndarray) -> np.ndarray: pass @abstractmethod @@ -169,8 +84,6 @@ def __init__(self, model_path: str, opts: OnnxRuntimeOpts, model_metadata: Model self.opts = opts self.model_metadata = model_metadata - self.postprocess_fn = get_postprocess_fn(model_metadata.task) - # Setup session options options = ort.SessionOptions() options.log_severity_level = 0 if opts.verbose else 2 @@ -241,12 +154,12 @@ def _warmup(self): self.logger.info("⏱️ [onnxruntime] Warmup done") - def __call__(self, im: np.ndarray, conf_threshold: float) -> sv.Detections: + def __call__(self, im: np.ndarray) -> list[np.ndarray]: """Run inference and return detections.""" input_name = self.ort_sess.get_inputs()[0].name out_name = [output.name for output in self.ort_sess.get_outputs()] out = self.ort_sess.run(out_name, {input_name: im}) - return self.postprocess_fn(out=out, im0_shape=(im.shape[2], im.shape[3]), conf_threshold=conf_threshold) + return out def benchmark(self, iterations=20, size=640) -> LatencyMetrics: """Benchmark model latency.""" @@ -297,7 +210,6 @@ def __init__( self.logger = get_logger(name="TorchscriptEngine") self.logger.info(f"🔧 [torchscript] Device: {self.device}") self.opts = opts - self.postprocess_fn = get_postprocess_fn(model_metadata.task) map_location = None if torch.cuda.is_available() else "cpu" @@ -312,12 +224,12 @@ def __init__( self.model(np_image) self.logger.info("⏱️ [torchscript] WARMUP DONE") - def __call__(self, im: np.ndarray, conf_threshold: float) -> sv.Detections: + def __call__(self, im: np.ndarray) -> list[np.ndarray]: """Run inference and return detections.""" with torch.no_grad(): torch_image = torch.from_numpy(im).to(self.device, dtype=torch.float32) res = self.model(torch_image) - return self.postprocess_fn([r.cpu().numpy() for r in res], (im.shape[2], im.shape[3]), conf_threshold) + return [r.cpu().numpy() for r in res] def benchmark(self, iterations=20, size=640) -> LatencyMetrics: """Benchmark model latency.""" diff --git a/focoos/utils/vision.py b/focoos/utils/vision.py index 88b1081..da39e32 100644 --- a/focoos/utils/vision.py +++ b/focoos/utils/vision.py @@ -9,7 +9,7 @@ from scipy.ndimage import zoom from typing_extensions import Buffer -from focoos.ports import FocoosDet, FocoosDetections +from focoos.ports import FocoosDet, FocoosDetections, FocoosTask def index_to_class(class_ids: list[int], classes: list[str]) -> list[str]: @@ -50,8 +50,6 @@ def image_loader(im: Union[bytes, str, Path, np.ndarray, Image.Image]) -> np.nda elif isinstance(im, Buffer): byte_array = np.frombuffer(im, dtype=np.uint8) cv_image = cv2.imdecode(byte_array, cv2.IMREAD_COLOR) - else: - raise ValueError(f"Unsupported image type: {type(im)}") return cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB) @@ -149,7 +147,7 @@ def fai_detections_to_sv(inference_output: FocoosDetections, im0_shape: tuple) - class_id = np.array([d.cls_id for d in inference_output.detections]) confidence = np.array([d.conf for d in inference_output.detections]) if xyxy.shape[0] == 0: - xyxy = np.empty((0, 4)) + xyxy = np.zeros((0, 4)) _masks = [] if len(inference_output.detections) > 0 and inference_output.detections[0].mask: _masks = [np.zeros(im0_shape, dtype=bool) for _ in inference_output.detections] @@ -227,8 +225,14 @@ def sv_to_fai_detections(detections: sv.Detections, classes: Optional[list[str]] res = [] for xyxy, mask, conf, cls_id, _, _ in detections: if mask is not None: - cropped_mask = mask[int(xyxy[1]) : int(xyxy[3]), int(xyxy[0]) : int(xyxy[2])] + x1, y1, x2, y2 = map(int, xyxy) + x1 = max(x1 - 1, 0) + y1 = max(y1 - 1, 0) + x2 = min(x2 + 2, mask.shape[1]) + y2 = min(y2 + 2, mask.shape[0]) + cropped_mask = mask[y1:y2, x1:x2] mask = binary_mask_to_base64(cropped_mask) + det = FocoosDet( cls_id=int(cls_id) if cls_id is not None else None, bbox=[int(x) for x in xyxy], @@ -240,7 +244,7 @@ def sv_to_fai_detections(detections: sv.Detections, classes: Optional[list[str]] return res -def mask_to_xyxy(masks: np.ndarray) -> np.ndarray: +def masks_to_xyxy(masks: np.ndarray) -> np.ndarray: """ Converts a 3D `np.array` of 2D bool masks into a 2D `np.array` of bounding boxes. @@ -267,3 +271,112 @@ def mask_to_xyxy(masks: np.ndarray) -> np.ndarray: xyxy[i, :] = [x_min, y_min, x_max, y_max] return xyxy + + +def get_postprocess_fn(task: FocoosTask): + if task == FocoosTask.INSTANCE_SEGMENTATION: + return instance_postprocess + elif task == FocoosTask.SEMSEG: + return semseg_postprocess + else: + return det_postprocess + + +def det_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections: + """ + Postprocesses the output of an object detection model and filters detections + based on a confidence threshold. + + Args: + out (List[np.ndarray]): The output of the detection model. + im0_shape (Tuple[int, int]): The original shape of the input image (height, width). + conf_threshold (float): The confidence threshold for filtering detections. + + Returns: + sv.Detections: A sv.Detections object containing the filtered bounding boxes, class ids, and confidences. + """ + cls_ids, boxes, confs = out + boxes[:, 0::2] *= im0_shape[1] + boxes[:, 1::2] *= im0_shape[0] + high_conf_indices = (confs > conf_threshold).nonzero() + + return sv.Detections( + xyxy=boxes[high_conf_indices].astype(int), + class_id=cls_ids[high_conf_indices].astype(int), + confidence=confs[high_conf_indices].astype(float), + ) + + +def semseg_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections: + """ + Postprocesses the output of a semantic segmentation model and filters based + on a confidence threshold, removing empty masks. + + Args: + out (List[np.ndarray]): The output of the semantic segmentation model. + conf_threshold (float): The confidence threshold for filtering detections. + + Returns: + sv.Detections: A sv.Detections object containing the non-empty masks, class ids, and confidences. + """ + cls_ids, mask, confs = out[0][0], out[1][0], out[2][0] + masks = np.equal(mask, np.arange(len(cls_ids))[:, None, None]) + high_conf_indices = confs > conf_threshold + masks = masks[high_conf_indices] + cls_ids = cls_ids[high_conf_indices] + confs = confs[high_conf_indices] + + if len(masks.shape) != 3: + return sv.Detections( + mask=None, + xyxy=np.zeros((0, 4)), + class_id=None, + confidence=None, + ) + # Filter out empty masks + non_empty_mask_indices = np.any(masks, axis=(1, 2)) + masks = masks[non_empty_mask_indices] + cls_ids = cls_ids[non_empty_mask_indices] + confs = confs[non_empty_mask_indices] + xyxy = masks_to_xyxy(masks) + return sv.Detections( + mask=masks, + # xyxy is required from supervision + xyxy=xyxy, + class_id=cls_ids, + confidence=confs, + ) + + +def instance_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections: + """ + Postprocesses the output of an instance segmentation model and filters detections + based on a confidence threshold. + """ + cls_ids, mask, confs = out[0][0], out[1][0], out[2][0] + high_conf_indices = np.where(confs > conf_threshold)[0] + masks = mask[high_conf_indices].astype(bool) + cls_ids = cls_ids[high_conf_indices].astype(int) + confs = confs[high_conf_indices].astype(float) + if len(masks.shape) != 3: + return sv.Detections( + mask=None, + xyxy=np.zeros((0, 4)), + class_id=None, + confidence=None, + ) + + # Filter out empty masks + non_empty_mask_indices = np.any(masks, axis=(1, 2)) + masks = masks[non_empty_mask_indices] + cls_ids = cls_ids[non_empty_mask_indices] + confs = confs[non_empty_mask_indices] + xyxy = masks_to_xyxy(masks) + + return sv.Detections( + mask=masks, + # xyxy is required from supervision + xyxy=xyxy, + class_id=cls_ids, + confidence=confs, + ) diff --git a/tests/test_local_model.py b/tests/test_local_model.py index aa54def..fb45853 100644 --- a/tests/test_local_model.py +++ b/tests/test_local_model.py @@ -139,6 +139,11 @@ def mock_sv_detections() -> sv.Detections: ) +@pytest.fixture +def mock_runtime_detections() -> list[np.ndarray]: + return [np.array([[2, 8, 16, 32], [4, 10, 18, 34]]), np.array([0, 1]), np.array([0.8, 0.9])] + + def test_annotate_detection_metadata_classes_none( image_ndarray: np.ndarray, mock_local_model_onnx: LocalModel, mock_sv_detections ): @@ -175,6 +180,7 @@ def mock_infer_setup( mock_local_model: LocalModel, image_ndarray: np.ndarray, mock_sv_detections: sv.Detections, + mock_runtime_detections: list[np.ndarray], mock_focoos_detections: FocoosDetections, annotate: bool, ): @@ -191,6 +197,10 @@ def mock_infer_setup( mock_sv_to_focoos_detections = mocker.patch("focoos.local_model.sv_to_fai_detections") mock_sv_to_focoos_detections.return_value = mock_focoos_detections.detections + # mock postprocess + mock_postprocess = mocker.patch.object(mock_local_model, "postprocess_fn") + mock_postprocess.return_value = mock_sv_detections + # Mock _annotate mock_annotate = mocker.patch.object(mock_local_model, "_annotate", autospec=True) if annotate: @@ -201,9 +211,9 @@ def mock_infer_setup( # Mock runtime class MockRuntime(MagicMock): def __call__(self, *args, **kwargs): - return mock_sv_detections + return mock_runtime_detections - mock_runtime_call = mocker.patch.object(MockRuntime, "__call__", return_value=mock_sv_detections) + mock_runtime_call = mocker.patch.object(MockRuntime, "__call__", return_value=mock_runtime_detections) mock_local_model.runtime = MockRuntime(spec=ONNXRuntime) return ( @@ -222,6 +232,7 @@ def test_infer_onnx( image_ndarray, mock_sv_detections, mock_focoos_detections, + mock_runtime_detections, annotate, ): # Arrange @@ -230,6 +241,7 @@ def test_infer_onnx( mock_local_model_onnx, image_ndarray, mock_sv_detections, + mock_runtime_detections, mock_focoos_detections, annotate, ) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index e7ce8b4..60df09e 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -1,22 +1,16 @@ import pathlib from unittest.mock import MagicMock -import numpy as np import pytest -import supervision as sv from pytest_mock import MockerFixture -from focoos.ports import FocoosTask, ModelMetadata, OnnxRuntimeOpts, RuntimeTypes, TorchscriptRuntimeOpts +from focoos.ports import ModelMetadata, OnnxRuntimeOpts, RuntimeTypes, TorchscriptRuntimeOpts from focoos.runtime import ( ORT_AVAILABLE, TORCH_AVAILABLE, ONNXRuntime, TorchscriptRuntime, - det_postprocess, - get_postprocess_fn, - instance_postprocess, load_runtime, - semseg_postprocess, ) @@ -56,70 +50,6 @@ def test_onnx_import(): assert ort is not None, "ONNX Runtime should be properly imported" -def test_det_post_process(): - cls_ids = np.array([1, 2, 3]) - boxes = np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]) - confs = np.array([0.8, 0.9, 0.7]) - out = [cls_ids, boxes, confs] - - im0_shape = (640, 480) - conf_threshold = 0.75 - sv_detections = det_postprocess(out, im0_shape, conf_threshold) - - np.testing.assert_array_equal(sv_detections.xyxy, np.array([[48, 128, 144, 256], [240, 384, 336, 512]])) - assert sv_detections.class_id is not None - np.testing.assert_array_equal(sv_detections.class_id, np.array([1, 2])) - assert sv_detections.confidence is not None - np.testing.assert_array_equal(sv_detections.confidence, np.array([0.8, 0.9])) - - -def test_semseg_postprocess(): - cls_ids = np.array([1, 2, 3]) - mask = np.array( - [ - [0, 1, 1, 2], - [0, 1, 2, 2], - [0, 0, 1, 2], - ] - ) - confs = np.array([0.7, 0.9, 0.8]) - out = [ - np.expand_dims(cls_ids, axis=0), - np.expand_dims(mask, axis=0), - np.expand_dims(confs, axis=0), - ] - - im0_shape = (3, 4) - conf_threshold = 0.75 - - sv_detections = semseg_postprocess(out, im0_shape, conf_threshold) - - # Expected masks - expected_masks = np.array( - [ - [ - [False, True, True, False], - [False, True, False, False], - [False, False, True, False], - ], # Class 1 - [ - [False, False, False, True], - [False, False, True, True], - [False, False, False, True], - ], # Class 2 - ] - ) - - # Assertions - assert sv_detections.mask is not None - np.testing.assert_array_equal(sv_detections.mask, expected_masks) - assert sv_detections.class_id is not None - np.testing.assert_array_equal(sv_detections.class_id, np.array([2, 3])) - assert sv_detections.confidence is not None - np.testing.assert_array_equal(sv_detections.confidence, np.array([0.9, 0.8])) - assert sv_detections.xyxy.shape == (2, 4) - - @pytest.mark.parametrize( "runtime_type, expected_opts", [ @@ -235,59 +165,3 @@ def test_load_unavailable_runtime(mocker: MockerFixture): load_runtime(RuntimeTypes.TORCHSCRIPT_32, "fake_model_path", MagicMock(spec=ModelMetadata), 2) with pytest.raises(ImportError): load_runtime(RuntimeTypes.ONNX_CUDA32, "fake_model_path", MagicMock(spec=ModelMetadata), 2) - - -def test_get_postprocess_fn(): - """ - Test the get_postprocess_fn function to ensure it returns - the correct postprocessing function for each task. - """ - # Test detection task - det_fn = get_postprocess_fn(FocoosTask.DETECTION) - assert det_fn == det_postprocess, "Detection task should return det_postprocess function" - - # Test instance segmentation task - instance_fn = get_postprocess_fn(FocoosTask.INSTANCE_SEGMENTATION) - assert instance_fn == instance_postprocess, "Instance segmentation task should return instance_postprocess function" - - # Test semantic segmentation task - semseg_fn = get_postprocess_fn(FocoosTask.SEMSEG) - assert semseg_fn == semseg_postprocess, "Semantic segmentation task should return semseg_postprocess function" - - # Test all FocoosTask values to ensure no exceptions - for task in FocoosTask: - fn = get_postprocess_fn(task) - assert callable(fn), f"Postprocess function for {task} should be callable" - - -def test_instance_postprocess(): - """Test instance segmentation postprocessing""" - cls_ids = np.array([0, 1]) - masks = np.zeros((2, 100, 100)) - masks[0, 10:30, 10:30] = 1 - masks[1, 40:60, 40:60] = 1 - confs = np.array([0.95, 0.85]) - out = [[cls_ids], [masks], [confs]] - - result = instance_postprocess(out, (100, 100), 0.8) - - assert isinstance(result, sv.Detections) - assert len(result) == 2 - assert result.mask is not None - assert result.xyxy is not None - assert result.class_id is not None - assert result.confidence is not None - - -def test_confidence_threshold_filtering(): - """Test that confidence threshold filtering works correctly""" - out = [ - np.array([0, 1, 2]), # cls_ids - np.array([[0.1, 0.1, 0.3, 0.3], [0.4, 0.4, 0.6, 0.6], [0.7, 0.7, 0.9, 0.9]]), # boxes - np.array([0.95, 0.55, 0.85]), # confs - ] - - result = det_postprocess(out, (100, 100), conf_threshold=0.8) - - assert len(result) == 2 # Should only keep detections with conf > 0.8 - assert all(conf > 0.8 for conf in result.confidence) diff --git a/tests/utils/test_vision.py b/tests/utils/test_vision.py index 9655bf4..290c223 100644 --- a/tests/utils/test_vision.py +++ b/tests/utils/test_vision.py @@ -4,18 +4,22 @@ import numpy as np import supervision as sv -from focoos.ports import FocoosDet +from focoos.ports import FocoosDet, FocoosTask from focoos.utils.vision import ( base64mask_to_mask, binary_mask_to_base64, class_to_index, + det_postprocess, fai_detections_to_sv, + get_postprocess_fn, image_loader, image_preprocess, index_to_class, - mask_to_xyxy, + instance_postprocess, + masks_to_xyxy, scale_detections, scale_mask, + semseg_postprocess, sv_to_fai_detections, ) @@ -180,28 +184,144 @@ def test_sv_to_focoos_detections(sv_detections: sv.Detections): assert isinstance(result_focoos_detection.mask, str), "Mask should be a string" -def test_mask_to_xyxy(): +def test_masks_to_xyxy(): # Basic case: a single mask with one active pixel mask1 = np.zeros((1, 5, 5), dtype=bool) mask1[0, 2, 3] = True # One active pixel at (2,3) - assert np.array_equal(mask_to_xyxy(mask1), np.array([[3, 2, 3, 2]])) + assert np.array_equal(masks_to_xyxy(mask1), np.array([[3, 2, 3, 2]])) # Case with a rectangle mask2 = np.zeros((1, 5, 5), dtype=bool) mask2[0, 1:4, 2:5] = True # Rectangle between (1,2) and (3,4) - assert np.array_equal(mask_to_xyxy(mask2), np.array([[2, 1, 4, 3]])) + assert np.array_equal(masks_to_xyxy(mask2), np.array([[2, 1, 4, 3]])) # Case with multiple masks masks = np.zeros((2, 5, 5), dtype=bool) masks[0, 1:4, 2:5] = True # First rectangle masks[1, 0:3, 1:4] = True # Second rectangle expected = np.array([[2, 1, 4, 3], [1, 0, 3, 2]]) - assert np.array_equal(mask_to_xyxy(masks), expected) - - # Case with an empty mask - empty_mask = np.zeros((1, 5, 5), dtype=bool) - assert np.array_equal(mask_to_xyxy(empty_mask), np.array([[0, 0, 0, 0]])) + assert np.array_equal(masks_to_xyxy(masks), expected) # Case with a mask covering the entire image full_mask = np.ones((1, 5, 5), dtype=bool) - assert np.array_equal(mask_to_xyxy(full_mask), np.array([[0, 0, 4, 4]])) + assert np.array_equal(masks_to_xyxy(full_mask), np.array([[0, 0, 4, 4]])) + + +def test_det_post_process(): + cls_ids = np.array([1, 2, 3]) + boxes = np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]) + confs = np.array([0.8, 0.9, 0.7]) + out = [cls_ids, boxes, confs] + + im0_shape = (640, 480) + conf_threshold = 0.75 + sv_detections = det_postprocess(out, im0_shape, conf_threshold) + + np.testing.assert_array_equal(sv_detections.xyxy, np.array([[48, 128, 144, 256], [240, 384, 336, 512]])) + assert sv_detections.class_id is not None + np.testing.assert_array_equal(sv_detections.class_id, np.array([1, 2])) + assert sv_detections.confidence is not None + np.testing.assert_array_equal(sv_detections.confidence, np.array([0.8, 0.9])) + + +def test_semseg_postprocess(): + cls_ids = np.array([1, 2, 3]) + mask = np.array( + [ + [0, 1, 1, 2], + [0, 1, 2, 2], + [0, 0, 1, 2], + ] + ) + confs = np.array([0.7, 0.9, 0.8]) + out = [ + np.expand_dims(cls_ids, axis=0), + np.expand_dims(mask, axis=0), + np.expand_dims(confs, axis=0), + ] + + im0_shape = (3, 4) + conf_threshold = 0.75 + + sv_detections = semseg_postprocess(out, im0_shape, conf_threshold) + + # Expected masks + expected_masks = np.array( + [ + [ + [False, True, True, False], + [False, True, False, False], + [False, False, True, False], + ], # Class 1 + [ + [False, False, False, True], + [False, False, True, True], + [False, False, False, True], + ], # Class 2 + ] + ) + + # Assertions + assert sv_detections.mask is not None + np.testing.assert_array_equal(sv_detections.mask, expected_masks) + assert sv_detections.class_id is not None + np.testing.assert_array_equal(sv_detections.class_id, np.array([2, 3])) + assert sv_detections.confidence is not None + np.testing.assert_array_equal(sv_detections.confidence, np.array([0.9, 0.8])) + assert sv_detections.xyxy.shape == (2, 4) + + +def test_get_postprocess_fn(): + """ + Test the get_postprocess_fn function to ensure it returns + the correct postprocessing function for each task. + """ + # Test detection task + det_fn = get_postprocess_fn(FocoosTask.DETECTION) + assert det_fn == det_postprocess, "Detection task should return det_postprocess function" + + # Test instance segmentation task + instance_fn = get_postprocess_fn(FocoosTask.INSTANCE_SEGMENTATION) + assert instance_fn == instance_postprocess, "Instance segmentation task should return instance_postprocess function" + + # Test semantic segmentation task + semseg_fn = get_postprocess_fn(FocoosTask.SEMSEG) + assert semseg_fn == semseg_postprocess, "Semantic segmentation task should return semseg_postprocess function" + + # Test all FocoosTask values to ensure no exceptions + for task in FocoosTask: + fn = get_postprocess_fn(task) + assert callable(fn), f"Postprocess function for {task} should be callable" + + +def test_instance_postprocess(): + """Test instance segmentation postprocessing""" + cls_ids = np.array([0, 1]) + masks = np.zeros((2, 100, 100)) + masks[0, 10:30, 10:30] = 1 + masks[1, 40:60, 40:60] = 1 + confs = np.array([0.95, 0.85]) + out = [[cls_ids], [masks], [confs]] + + result = instance_postprocess(out, (100, 100), 0.8) + + assert isinstance(result, sv.Detections) + assert len(result) == 2 + assert result.mask is not None + assert result.xyxy is not None + assert result.class_id is not None + assert result.confidence is not None + + +def test_confidence_threshold_filtering(): + """Test that confidence threshold filtering works correctly""" + out = [ + np.array([0, 1, 2]), # cls_ids + np.array([[0.1, 0.1, 0.3, 0.3], [0.4, 0.4, 0.6, 0.6], [0.7, 0.7, 0.9, 0.9]]), # boxes + np.array([0.95, 0.55, 0.85]), # confs + ] + + result = det_postprocess(out, (100, 100), conf_threshold=0.8) + + assert len(result) == 2 # Should only keep detections with conf > 0.8 + assert all(conf > 0.8 for conf in result.confidence)