test: enhance runtime and postprocessing test coverage
CuriousDolphin committed Feb 21, 2025
1 parent 4a44fd1 commit c6ee176
Showing 2 changed files with 290 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/test_local_model.py
@@ -216,7 +216,7 @@ def __call__(self, *args, **kwargs):


@pytest.mark.parametrize("annotate", [(False, None)])
def test_infer_(
def test_infer_onnx(
mocker,
mock_local_model_onnx,
image_ndarray,
292 changes: 289 additions & 3 deletions tests/test_runtime.py
@@ -1,12 +1,68 @@
import pathlib
from unittest.mock import MagicMock
from datetime import datetime
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import supervision as sv
from pytest_mock import MockerFixture

from focoos.ports import ModelMetadata, OnnxRuntimeOpts, RuntimeTypes, TorchscriptRuntimeOpts
from focoos.runtime import ONNXRuntime, TorchscriptRuntime, det_postprocess, load_runtime, semseg_postprocess
from focoos.ports import (
FocoosTask,
LatencyMetrics,
ModelMetadata,
ModelStatus,
OnnxRuntimeOpts,
RuntimeTypes,
TorchscriptRuntimeOpts,
)
from focoos.runtime import (
ORT_AVAILABLE,
TORCH_AVAILABLE,
ONNXRuntime,
TorchscriptRuntime,
det_postprocess,
get_postprocess_fn,
instance_postprocess,
load_runtime,
semseg_postprocess,
)


def test_runtime_availability():
"""
Test the runtime availability flags.
These flags should be boolean values indicating whether
PyTorch and ONNX Runtime are available in the environment.
"""
# Check that the flags are boolean
assert isinstance(TORCH_AVAILABLE, bool), "TORCH_AVAILABLE should be a boolean"
assert isinstance(ORT_AVAILABLE, bool), "ORT_AVAILABLE should be a boolean"

# At least one runtime should be available for the library to work
assert TORCH_AVAILABLE or ORT_AVAILABLE, "At least one runtime (PyTorch or ONNX Runtime) must be available"


@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not available")
def test_torch_import():
"""
Test PyTorch import when available.
This test is skipped if PyTorch is not installed.
"""
import torch

assert torch is not None, "PyTorch should be properly imported"


@pytest.mark.skipif(not ORT_AVAILABLE, reason="ONNX Runtime not available")
def test_onnx_import():
"""
Test ONNX Runtime import when available.
This test is skipped if ONNX Runtime is not installed.
"""
import onnxruntime as ort

assert ort is not None, "ONNX Runtime should be properly imported"


def test_det_post_process():
@@ -188,3 +244,233 @@ def test_load_unavailable_runtime(mocker: MockerFixture):
load_runtime(RuntimeTypes.TORCHSCRIPT_32, "fake_model_path", MagicMock(spec=ModelMetadata), 2)
with pytest.raises(ImportError):
load_runtime(RuntimeTypes.ONNX_CUDA32, "fake_model_path", MagicMock(spec=ModelMetadata), 2)


def test_get_postprocess_fn():
"""
Test the get_postprocess_fn function to ensure it returns
the correct postprocessing function for each task.
"""
# Test detection task
det_fn = get_postprocess_fn(FocoosTask.DETECTION)
assert det_fn == det_postprocess, "Detection task should return det_postprocess function"

# Test instance segmentation task
instance_fn = get_postprocess_fn(FocoosTask.INSTANCE_SEGMENTATION)
assert instance_fn == instance_postprocess, "Instance segmentation task should return instance_postprocess function"

# Test semantic segmentation task
semseg_fn = get_postprocess_fn(FocoosTask.SEMSEG)
assert semseg_fn == semseg_postprocess, "Semantic segmentation task should return semseg_postprocess function"

# Test all FocoosTask values to ensure no exceptions
for task in FocoosTask:
fn = get_postprocess_fn(task)
assert callable(fn), f"Postprocess function for {task} should be callable"


@pytest.fixture
def detection_output():
"""Fixture for detection model output"""
cls_ids = np.array([0, 1, 2])
boxes = np.array([[0.1, 0.1, 0.3, 0.3], [0.4, 0.4, 0.6, 0.6], [0.7, 0.7, 0.9, 0.9]])
confs = np.array([0.95, 0.85, 0.75])
return [cls_ids, boxes, confs]


@pytest.fixture
def segmentation_output():
"""Fixture for segmentation model output"""
cls_ids = np.array([0, 1, 2])
mask = np.zeros((1, 100, 100))
mask[0, 10:30, 10:30] = 1 # Class 1 mask
mask[0, 40:60, 40:60] = 2 # Class 2 mask
confs = np.array([0.95, 0.85, 0.75])
return [[cls_ids], [mask], [confs]]


def test_det_postprocess2():
"""Test detection postprocessing"""
out = [
np.array([0, 1]), # cls_ids
np.array([[0.1, 0.1, 0.3, 0.3], [0.4, 0.4, 0.6, 0.6]]), # boxes
np.array([0.95, 0.85]), # confs
]
im0_shape = (100, 100)
conf_threshold = 0.8

result = det_postprocess(out, im0_shape, conf_threshold)

assert isinstance(result, sv.Detections)
assert len(result) == 2 # Should keep both detections above threshold
assert result.class_id.dtype == int
assert result.confidence.dtype == float
assert result.xyxy.dtype == int


def test_semseg_postprocess2():
"""Test semantic segmentation postprocessing"""
cls_ids = np.array([0, 1])
mask = np.zeros((1, 100, 100))
mask[0, 10:30, 10:30] = 1
confs = np.array([0.95, 0.85])
out = [[cls_ids], [mask], [confs]]

result = semseg_postprocess(out, (100, 100), 0.8)

assert isinstance(result, sv.Detections)
assert len(result) == 2
assert result.mask is not None
assert result.xyxy is not None
assert result.class_id is not None
assert result.confidence is not None


def test_instance_postprocess():
"""Test instance segmentation postprocessing"""
cls_ids = np.array([0, 1])
masks = np.zeros((2, 100, 100))
masks[0, 10:30, 10:30] = 1
masks[1, 40:60, 40:60] = 1
confs = np.array([0.95, 0.85])
out = [[cls_ids], [masks], [confs]]

result = instance_postprocess(out, (100, 100), 0.8)

assert isinstance(result, sv.Detections)
assert len(result) == 2
assert result.mask is not None
assert result.xyxy is not None
assert result.class_id is not None
assert result.confidence is not None


def test_confidence_threshold_filtering():
"""Test that confidence threshold filtering works correctly"""
out = [
np.array([0, 1, 2]), # cls_ids
np.array([[0.1, 0.1, 0.3, 0.3], [0.4, 0.4, 0.6, 0.6], [0.7, 0.7, 0.9, 0.9]]), # boxes
np.array([0.95, 0.55, 0.85]), # confs
]

result = det_postprocess(out, (100, 100), conf_threshold=0.8)

assert len(result) == 2 # Should only keep detections with conf > 0.8
assert all(conf > 0.8 for conf in result.confidence)


@pytest.fixture
def mock_torch():
"""Mock torch and its required components"""
with patch("focoos.runtime.torch") as mock_torch:
# Mock device
mock_device = MagicMock()
mock_torch.device.return_value = mock_device
mock_torch.cuda.is_available.return_value = True

# Mock model
mock_model = MagicMock()
mock_torch.jit.load.return_value = mock_model
mock_model.to.return_value = mock_model

# Mock tensor operations
mock_torch.from_numpy.return_value = MagicMock()
mock_torch.rand.return_value = MagicMock()

yield mock_torch


@pytest.fixture
def runtime_opts():
"""Fixture for TorchscriptRuntime options"""
return TorchscriptRuntimeOpts(warmup_iter=2)


@pytest.fixture
def model_metadata():
"""Fixture for model metadata"""
return ModelMetadata(
task=FocoosTask.DETECTION,
classes=["class1", "class2"],
ref="test_ref",
name="test_name",
owner_ref="test_owner_ref",
focoos_model="test_focoos_model",
created_at=datetime.now(),
updated_at=datetime.now(),
status=ModelStatus.TRAINING_COMPLETED,
)


def test_torchscript_runtime_init(mock_torch, runtime_opts, model_metadata, tmp_path):
"""Test TorchscriptRuntime initialization"""
model_path = tmp_path / "model.pt"
model_path.write_bytes(b"dummy model")

TorchscriptRuntime(str(model_path), runtime_opts, model_metadata)

# Check if torch.jit.load was called with correct arguments
mock_torch.jit.load.assert_called_once_with(str(model_path), map_location=None)

# Check if model was moved to correct device
mock_model = mock_torch.jit.load.return_value
mock_model.to.assert_called_once()

# Check warmup iterations
assert mock_torch.rand.call_count == 1 # One call for warmup input
assert mock_model.call_count == 2 # Two warmup iterations


def test_torchscript_runtime_inference(mock_torch, runtime_opts, model_metadata, tmp_path):
"""Test TorchscriptRuntime inference"""
model_path = tmp_path / "model.pt"
model_path.write_bytes(b"dummy model")

# Mock model output
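# (three mocked tensors whose .cpu().numpy() results mimic the cls_ids / boxes / confs
# arrays consumed by det_postprocess)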
mock_output = [
MagicMock(cpu=lambda: MagicMock(numpy=lambda: np.array([0, 1, 2]))),
MagicMock(cpu=lambda: MagicMock(numpy=lambda: np.array([[0.1, 0.1, 0.2, 0.2]]))),
MagicMock(cpu=lambda: MagicMock(numpy=lambda: np.array([0.9]))),
]
mock_torch.jit.load.return_value.return_value = mock_output

runtime = TorchscriptRuntime(str(model_path), runtime_opts, model_metadata)

# Create dummy input
input_tensor = np.random.rand(1, 3, 640, 640).astype(np.float32)
result = runtime(input_tensor, conf_threshold=0.5)

# Check if model was called with correct input
mock_torch.from_numpy.assert_called_once()
assert isinstance(result, sv.Detections)


def test_torchscript_runtime_benchmark(mock_torch, runtime_opts, model_metadata, tmp_path):
"""Test TorchscriptRuntime benchmark"""
model_path = tmp_path / "model.pt"
model_path.write_bytes(b"dummy model")

runtime = TorchscriptRuntime(str(model_path), runtime_opts, model_metadata)
metrics = runtime.benchmark(iterations=3, size=320)

assert isinstance(metrics, LatencyMetrics)
assert metrics.engine == "torchscript"
assert metrics.im_size == 320
assert isinstance(metrics.fps, int)
assert isinstance(metrics.mean, float)
assert isinstance(metrics.max, float)
assert isinstance(metrics.min, float)
assert isinstance(metrics.std, float)


@pytest.mark.parametrize("cuda_available", [True, False])
def test_torchscript_runtime_device_selection(mock_torch, runtime_opts, model_metadata, tmp_path, cuda_available):
"""Test device selection based on CUDA availability"""
mock_torch.cuda.is_available.return_value = cuda_available
model_path = tmp_path / "model.pt"
model_path.write_bytes(b"dummy model")

TorchscriptRuntime(str(model_path), runtime_opts, model_metadata)

expected_device = "cuda" if cuda_available else "cpu"
mock_torch.device.assert_called_once_with(expected_device)
