Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add instance segmentation postprocessing support #75

Merged
merged 1 commit into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions focoos/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@
logger = get_logger()


def get_postprocess_fn(task: FocoosTask):
if task == FocoosTask.INSTANCE_SEGMENTATION:
return instance_postprocess
elif task == FocoosTask.SEMSEG:
return semseg_postprocess
else:
return det_postprocess


def det_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
"""
Postprocesses the output of an object detection model and filters detections
Expand Down Expand Up @@ -108,6 +117,26 @@ def semseg_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_t
)


def instance_postprocess(out: List[np.ndarray], im0_shape: Tuple[int, int], conf_threshold: float) -> sv.Detections:
"""
Postprocesses the output of an instance segmentation model and filters detections
based on a confidence threshold.
"""
cls_ids, mask, confs = out[0][0], out[1][0], out[2][0]
high_conf_indices = np.where(confs > conf_threshold)[0]

masks = mask[high_conf_indices].astype(bool)
cls_ids = cls_ids[high_conf_indices].astype(int)
confs = confs[high_conf_indices].astype(float)
return sv.Detections(
mask=masks,
# xyxy is required from supervision
xyxy=np.zeros(shape=(len(high_conf_indices), 4), dtype=np.uint8),
class_id=cls_ids,
confidence=confs,
)


class BaseRuntime:
def __init__(self, model_path: str, opts: Any, model_metadata: ModelMetadata):
pass
Expand Down Expand Up @@ -136,7 +165,8 @@ def __init__(self, model_path: str, opts: OnnxRuntimeOpts, model_metadata: Model
self.name = Path(model_path).stem
self.opts = opts
self.model_metadata = model_metadata
self.postprocess_fn = det_postprocess if model_metadata.task == FocoosTask.DETECTION else semseg_postprocess

self.postprocess_fn = get_postprocess_fn(model_metadata.task)

# Setup session options
options = ort.SessionOptions()
Expand Down Expand Up @@ -264,7 +294,7 @@ def __init__(
self.logger = get_logger(name="TorchscriptEngine")
self.logger.info(f"🔧 [torchscript] Device: {self.device}")
self.opts = opts
self.postprocess_fn = det_postprocess if model_metadata.task == FocoosTask.DETECTION else semseg_postprocess
self.postprocess_fn = get_postprocess_fn(model_metadata.task)

map_location = None if torch.cuda.is_available() else "cpu"

Expand Down
4 changes: 3 additions & 1 deletion focoos/utils/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def scale_mask(mask: np.ndarray, target_shape: tuple) -> np.ndarray:
"""
# Calculate scale factors for height and width
scale_factors = (target_shape[0] / mask.shape[0], target_shape[1] / mask.shape[1])

# Resize the mask using zoom with nearest-neighbor interpolation (order=0)
scaled_mask = zoom(mask, scale_factors, order=0) > 0.5

Expand All @@ -123,6 +122,9 @@ def scale_detections(detections: sv.Detections, in_shape: tuple, out_shape: tupl
x_ratio = out_shape[0] / in_shape[0]
y_ratio = out_shape[1] / in_shape[1]
detections.xyxy = detections.xyxy * np.array([x_ratio, y_ratio, x_ratio, y_ratio])

if detections.mask is not None:
detections.mask = np.array([scale_mask(m, out_shape) for m in detections.mask])
return detections


Expand Down