fix: separate postprocessing from runtime inference logic and handle null masks results #79

Merged · 2 commits · Feb 25, 2025 · Changes from 1 commit
18 changes: 16 additions & 2 deletions focoos/local_model.py
@@ -39,6 +39,7 @@
from focoos.runtime import BaseRuntime, load_runtime
from focoos.utils.logger import get_logger
from focoos.utils.vision import (
get_postprocess_fn,
image_preprocess,
scale_detections,
sv_to_fai_detections,
@@ -99,6 +100,7 @@ def __init__(
# Load metadata and set model reference
self.metadata: ModelMetadata = self._read_metadata()
self.model_ref = self.metadata.ref
self.postprocess_fn = get_postprocess_fn(self.metadata.task)

# Initialize annotation utilities
self.label_annotator = sv.LabelAnnotator(text_padding=10, border_radius=10)
@@ -137,6 +139,9 @@ def _annotate(self, im: np.ndarray, detections: sv.Detections) -> np.ndarray:
Returns:
np.ndarray: The annotated image with bounding boxes or masks.
"""
if len(detections.xyxy) == 0:
logger.warning("No detections found, skipping annotation")
return im
classes = self.metadata.classes
labels = [
f"{classes[int(class_id)] if classes is not None else str(class_id)}: {confid * 100:.0f}%"
@@ -189,11 +194,16 @@ def infer(
t0 = perf_counter()
im1, im0 = image_preprocess(image, resize=resize)
t1 = perf_counter()
detections = self.runtime(im1.astype(np.float32), threshold)
detections = self.runtime(im1.astype(np.float32))

t2 = perf_counter()

detections = self.postprocess_fn(
out=detections, im0_shape=(im0.shape[1], im0.shape[0]), conf_threshold=threshold
)

if resize:
detections = scale_detections(detections, (resize, resize), (im0.shape[1], im0.shape[0]))
logger.debug(f"Inference time: {t2 - t1:.3f} seconds")

out = sv_to_fai_detections(detections, classes=self.metadata.classes)
t3 = perf_counter()
@@ -205,6 +215,10 @@
im = None
if annotate:
im = self._annotate(im0, detections)

logger.debug(
f"Found {len(detections)} detections. Inference time: {(t2 - t1) * 1000:.0f}ms, preprocess: {(t1 - t0) * 1000:.0f}ms, postprocess: {(t3 - t2) * 1000:.0f}ms"
)
return FocoosDetections(detections=out, latency=latency), im

def benchmark(self, iterations: int, size: int) -> LatencyMetrics:
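Taken together, the local_model.py changes do two things: LocalModel.infer now calls the runtime without a confidence threshold, receives the raw output arrays, and applies the task-specific postprocess function itself; and _annotate returns the image untouched when there are no detections, so a null-mask result never reaches the annotators. A minimal sketch of the new flow (the helper name infer_flow is hypothetical, standing in for the relevant part of infer; the real method also records latency and converts to FocoosDetections):

    import numpy as np
    import supervision as sv

    def infer_flow(runtime, postprocess_fn, im1: np.ndarray, im0: np.ndarray, threshold: float) -> sv.Detections:
        # The runtime is now a thin wrapper: raw model outputs, no filtering.
        raw_out = runtime(im1.astype(np.float32))
        # Task-specific postprocessing applies the confidence threshold and
        # builds the sv.Detections object from the raw arrays.
        return postprocess_fn(
            out=raw_out,
            im0_shape=(im0.shape[1], im0.shape[0]),  # (width, height) of the original image
            conf_threshold=threshold,
        )

Keeping the threshold out of the runtime call lets the ONNX and TorchScript backends share one contract and keeps confidence filtering in a single place.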
4 changes: 3 additions & 1 deletion focoos/remote_model.py
@@ -291,11 +291,13 @@ def infer(
)
t1 = time.time()
if res.status_code == 200:
logger.debug(f"Inference time: {t1 - t0:.3f} seconds")
detections = FocoosDetections(
detections=[FocoosDet.from_json(d) for d in res.json().get("detections", [])],
latency=res.json().get("latency", None),
)
logger.debug(
f"Found {len(detections.detections)} detections. Inference Request time: {(t1 - t0) * 1000:.0f}ms"
)
preview = None
if annotate:
im0 = image_loader(image)
100 changes: 6 additions & 94 deletions focoos/runtime.py
@@ -20,12 +20,10 @@
from abc import abstractmethod
from pathlib import Path
from time import perf_counter
from typing import Any, List, Tuple
from typing import Any

import numpy as np

from focoos.utils.vision import mask_to_xyxy

try:
import torch

@@ -40,11 +38,9 @@
except ImportError:
ORT_AVAILABLE = False

import supervision as sv

# from supervision.detection.utils import mask_to_xyxy
from focoos.ports import (
FocoosTask,
LatencyMetrics,
ModelMetadata,
OnnxRuntimeOpts,
@@ -59,93 +55,12 @@
logger = get_logger()


def get_postprocess_fn(task: FocoosTask):
if task == FocoosTask.INSTANCE_SEGMENTATION:
return instance_postprocess
elif task == FocoosTask.SEMSEG:
return semseg_postprocess
else:
return det_postprocess


def det_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
"""
Postprocesses the output of an object detection model and filters detections
based on a confidence threshold.

Args:
out (List[np.ndarray]): The output of the detection model.
im0_shape (Tuple[int, int]): The original shape of the input image (height, width).
conf_threshold (float): The confidence threshold for filtering detections.

Returns:
sv.Detections: A sv.Detections object containing the filtered bounding boxes, class ids, and confidences.
"""
cls_ids, boxes, confs = out
boxes[:, 0::2] *= im0_shape[1]
boxes[:, 1::2] *= im0_shape[0]
high_conf_indices = (confs > conf_threshold).nonzero()

return sv.Detections(
xyxy=boxes[high_conf_indices].astype(int),
class_id=cls_ids[high_conf_indices].astype(int),
confidence=confs[high_conf_indices].astype(float),
)


def semseg_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
"""
Postprocesses the output of a semantic segmentation model and filters based
on a confidence threshold.

Args:
out (List[np.ndarray]): The output of the semantic segmentation model.
conf_threshold (float): The confidence threshold for filtering detections.

Returns:
sv.Detections: A sv.Detections object containing the masks, class ids, and confidences.
"""
cls_ids, mask, confs = out[0][0], out[1][0], out[2][0]
masks = np.equal(mask, np.arange(len(cls_ids))[:, None, None])
high_conf_indices = np.where(confs > conf_threshold)[0]
masks = masks[high_conf_indices].astype(bool)
cls_ids = cls_ids[high_conf_indices].astype(int)
confs = confs[high_conf_indices].astype(float)
xyxy = mask_to_xyxy(masks)
return sv.Detections(
mask=masks,
xyxy=xyxy,
class_id=cls_ids,
confidence=confs,
)


def instance_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
"""
Postprocesses the output of an instance segmentation model and filters detections
based on a confidence threshold.
"""
cls_ids, mask, confs = out[0][0], out[1][0], out[2][0]
high_conf_indices = np.where(confs > conf_threshold)[0]

masks = mask[high_conf_indices].astype(bool)
cls_ids = cls_ids[high_conf_indices].astype(int)
confs = confs[high_conf_indices].astype(float)
xyxy = mask_to_xyxy(masks)
return sv.Detections(
mask=masks,
xyxy=xyxy,
class_id=cls_ids,
confidence=confs,
)


class BaseRuntime:
def __init__(self, model_path: str, opts: Any, model_metadata: ModelMetadata):
pass

@abstractmethod
def __call__(self, im: np.ndarray, conf_threshold: float) -> sv.Detections:
def __call__(self, im: np.ndarray) -> np.ndarray:
pass

@abstractmethod
Expand All @@ -169,8 +84,6 @@ def __init__(self, model_path: str, opts: OnnxRuntimeOpts, model_metadata: Model
self.opts = opts
self.model_metadata = model_metadata

self.postprocess_fn = get_postprocess_fn(model_metadata.task)

# Setup session options
options = ort.SessionOptions()
options.log_severity_level = 0 if opts.verbose else 2
@@ -241,12 +154,12 @@ def _warmup(self):

self.logger.info("⏱️ [onnxruntime] Warmup done")

def __call__(self, im: np.ndarray, conf_threshold: float) -> sv.Detections:
def __call__(self, im: np.ndarray) -> list[np.ndarray]:
"""Run inference and return detections."""
input_name = self.ort_sess.get_inputs()[0].name
out_name = [output.name for output in self.ort_sess.get_outputs()]
out = self.ort_sess.run(out_name, {input_name: im})
return self.postprocess_fn(out=out, im0_shape=(im.shape[2], im.shape[3]), conf_threshold=conf_threshold)
return out

def benchmark(self, iterations=20, size=640) -> LatencyMetrics:
"""Benchmark model latency."""
@@ -297,7 +210,6 @@ def __init__(
self.logger = get_logger(name="TorchscriptEngine")
self.logger.info(f"🔧 [torchscript] Device: {self.device}")
self.opts = opts
self.postprocess_fn = get_postprocess_fn(model_metadata.task)

map_location = None if torch.cuda.is_available() else "cpu"

@@ -312,12 +224,12 @@ def __init__(
self.model(np_image)
self.logger.info("⏱️ [torchscript] WARMUP DONE")

def __call__(self, im: np.ndarray, conf_threshold: float) -> sv.Detections:
def __call__(self, im: np.ndarray) -> list[np.ndarray]:
"""Run inference and return detections."""
with torch.no_grad():
torch_image = torch.from_numpy(im).to(self.device, dtype=torch.float32)
res = self.model(torch_image)
return self.postprocess_fn([r.cpu().numpy() for r in res], (im.shape[2], im.shape[3]), conf_threshold)
return [r.cpu().numpy() for r in res]

def benchmark(self, iterations=20, size=640) -> LatencyMetrics:
"""Benchmark model latency."""
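With the postprocess helpers removed, both engines share the slimmer contract __call__(im: np.ndarray) -> list[np.ndarray]. Judging by the new import in local_model.py, get_postprocess_fn and the three task-specific postprocess functions presumably moved to focoos/utils/vision.py unchanged. A sketch of composing a runtime with postprocessing under the new contract (the helper name run_with_postprocess is hypothetical; the import path is assumed from the local_model.py diff above):

    import numpy as np
    import supervision as sv

    from focoos.utils.vision import get_postprocess_fn  # assumed new home of the helpers

    def run_with_postprocess(runtime, metadata, im: np.ndarray, conf_threshold: float = 0.5) -> sv.Detections:
        raw = runtime(im.astype(np.float32))  # list[np.ndarray], no filtering applied
        postprocess_fn = get_postprocess_fn(metadata.task)
        # Mirrors the call the ONNX runtime used to make internally: im is NCHW,
        # so (im.shape[2], im.shape[3]) is (height, width).
        return postprocess_fn(out=raw, im0_shape=(im.shape[2], im.shape[3]), conf_threshold=conf_threshold)

The raw arrays are exactly what the session produces, so confidence filtering and mask decoding now happen in one place, outside the runtime.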