Skip to content

Commit

Permalink
test(runtime): add unit test for det_postprocess function to validate…
Browse files Browse the repository at this point in the history
… output structure and values
  • Loading branch information
giuseppeambrosio97 committed Jan 7, 2025
1 parent ddb3fd4 commit dc755ba
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tests/test_runtime.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,31 @@
import pathlib
from unittest.mock import MagicMock

import numpy as np
import pytest
from pytest_mock import MockerFixture

from focoos.ports import ModelMetadata, OnnxEngineOpts, RuntimeTypes
from focoos.runtime import ONNXRuntime, get_runtime
from focoos.runtime import ONNXRuntime, det_postprocess, get_runtime


def test_det_post_process():
cls_ids = np.array([1, 2, 3])
boxes = np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]])
confs = np.array([0.8, 0.9, 0.7])
out = [cls_ids, boxes, confs]

im0_shape = (640, 480)
conf_threshold = 0.75
sv_detections = det_postprocess(out, im0_shape, conf_threshold)

np.testing.assert_array_equal(
sv_detections.xyxy, np.array([[48, 128, 144, 256], [240, 384, 336, 512]])
)
assert sv_detections.class_id is not None
np.testing.assert_array_equal(sv_detections.class_id, np.array([1, 2]))
assert sv_detections.confidence is not None
np.testing.assert_array_equal(sv_detections.confidence, np.array([0.8, 0.9]))


@pytest.mark.parametrize(
Expand Down

0 comments on commit dc755ba

Please sign in to comment.