diff --git a/focoos/runtime.py b/focoos/runtime.py
index 61d4491..d00df15 100644
--- a/focoos/runtime.py
+++ b/focoos/runtime.py
@@ -56,6 +56,15 @@
 logger = get_logger()
 
 
+def get_postprocess_fn(task: FocoosTask):
+    if task == FocoosTask.INSTANCE_SEGMENTATION:
+        return instance_postprocess
+    elif task == FocoosTask.SEMSEG:
+        return semseg_postprocess
+    else:
+        return det_postprocess
+
+
 def det_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
     """
     Postprocesses the output of an object detection model and filters detections
@@ -108,6 +117,26 @@ def semseg_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_t
     )
 
 
+def instance_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
+    """
+    Postprocesses the output of an instance segmentation model and filters detections
+    based on a confidence threshold.
+    """
+    cls_ids, mask, confs = out[0][0], out[1][0], out[2][0]
+    high_conf_indices = np.where(confs > conf_threshold)[0]
+
+    masks = mask[high_conf_indices].astype(bool)
+    cls_ids = cls_ids[high_conf_indices].astype(int)
+    confs = confs[high_conf_indices].astype(float)
+    return sv.Detections(
+        mask=masks,
+        # xyxy is required from supervision
+        xyxy=np.zeros(shape=(len(high_conf_indices), 4), dtype=np.uint8),
+        class_id=cls_ids,
+        confidence=confs,
+    )
+
+
 class BaseRuntime:
     def __init__(self, model_path: str, opts: Any, model_metadata: ModelMetadata):
         pass
@@ -136,7 +165,8 @@ def __init__(self, model_path: str, opts: OnnxRuntimeOpts, model_metadata: Model
         self.name = Path(model_path).stem
         self.opts = opts
         self.model_metadata = model_metadata
-        self.postprocess_fn = det_postprocess if model_metadata.task == FocoosTask.DETECTION else semseg_postprocess
+
+        self.postprocess_fn = get_postprocess_fn(model_metadata.task)
 
         # Setup session options
         options = ort.SessionOptions()
@@ -264,7 +294,7 @@ def __init__(
         self.logger = get_logger(name="TorchscriptEngine")
         self.logger.info(f"🔧 [torchscript] Device: {self.device}")
         self.opts = opts
-        self.postprocess_fn = det_postprocess if model_metadata.task == FocoosTask.DETECTION else semseg_postprocess
+        self.postprocess_fn = get_postprocess_fn(model_metadata.task)
 
         map_location = None if torch.cuda.is_available() else "cpu"
 
diff --git a/focoos/utils/vision.py b/focoos/utils/vision.py
index 56ec437..c778bee 100644
--- a/focoos/utils/vision.py
+++ b/focoos/utils/vision.py
@@ -109,7 +109,6 @@ def scale_mask(mask: np.ndarray, target_shape: tuple) -> np.ndarray:
     """
     # Calculate scale factors for height and width
     scale_factors = (target_shape[0] / mask.shape[0], target_shape[1] / mask.shape[1])
-
     # Resize the mask using zoom with nearest-neighbor interpolation (order=0)
     scaled_mask = zoom(mask, scale_factors, order=0) > 0.5
 
@@ -123,6 +122,9 @@ def scale_detections(detections: sv.Detections, in_shape: tuple, out_shape: tupl
     x_ratio = out_shape[0] / in_shape[0]
     y_ratio = out_shape[1] / in_shape[1]
     detections.xyxy = detections.xyxy * np.array([x_ratio, y_ratio, x_ratio, y_ratio])
+
+    if detections.mask is not None:
+        detections.mask = np.array([scale_mask(m, out_shape) for m in detections.mask])
     return detections
 
 