From 2807606201ada9fb7df23e16e4c9f763cd61a54e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20B=C3=BCchner?= Date: Tue, 9 Apr 2024 08:47:44 +0200 Subject: [PATCH] Refactor edge cases in EAR computation - better input handling - do not raise exceptions but log warnings - handle outputs interally better - update tests accordinlyg - depricate test for negative landmarks (which are now valid) --- .../facial_features/features/ear_feature.py | 45 +++++++++++++------ tests/test_earfeature.py | 24 +++++----- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/src/jefapato/facial_features/features/ear_feature.py b/src/jefapato/facial_features/features/ear_feature.py index 8308bac..99d0c1c 100644 --- a/src/jefapato/facial_features/features/ear_feature.py +++ b/src/jefapato/facial_features/features/ear_feature.py @@ -12,7 +12,7 @@ logger = structlog.get_logger() -def ear_score(eye: np.ndarray) -> float: +def ear_score(eye: np.ndarray) -> tuple[float, bool]: """ Compute the EAR Score for eye landmarks @@ -35,6 +35,7 @@ def ear_score(eye: np.ndarray) -> float: Returns: float: The computed EAR score, which should be between 0 and 1 + bool: A flag indicating if the EAR score is valid """ if eye is None: raise ValueError("eye must not be None") @@ -43,25 +44,31 @@ def ear_score(eye: np.ndarray) -> float: raise TypeError(f"eye must be a numpy array, but got {type(eye)}") if eye.shape != (6, 2) and eye.shape != (6, 3): # allow for 3D landmarks - raise ValueError(f"eye must be a 6x2 array, but got {eye.shape}") + raise ValueError(f"eye must be a 6x2 or 6x3 array, but got {eye.shape}") # check that no value is negative if np.any(eye < 0): - # raise ValueError(f"eye must not contain negative values, but got {eye}") - logger.warning(f"eye must not contain negative values, but got {eye}") + # This can be the case if parts of the face are not inside the image + # but the predictor tries to estimate the rough location. + # Thus we just log a warning and continue + logger.warning(f"Eye landmarks must not contain negative values, but got {eye}") - # dont forget the 0-index A = distance.euclidean(eye[1], eye[5]) B = distance.euclidean(eye[2], eye[4]) C = distance.euclidean(eye[0], eye[3]) ratio = (A + B) / (2.0 * C) - if ratio > 1.002: # allow for some rounding errors - # raise ValueError(f"EAR score must be between 0 and 1, but got {ratio}, check your landmarks order") - logger.warning("EAR score must be between 0 and 1, but got {ratio}, check your landmarks order") - ratio = 1.0 - return ratio + compute_valid = True + if ratio > 1.0: + logger.warning(f"EAR score must be between 0 and 1, but got {ratio}") + ratio = 1.0 + compute_valid = False + if ratio < 0.0: + logger.warning(f"EAR score must be between 0 and 1, but got {ratio}") + ratio = 0.0 + compute_valid = False + return ratio, compute_valid @dataclasses.dataclass @@ -161,8 +168,13 @@ def compute(self, features: np.ndarray, valid: bool) -> EAR_Data: return EAR_Data(1.0, 1.0, False, lm_l, lm_r) ear_valid = not (np.allclose(np.zeros_like(lm_l), lm_l) and np.allclose(np.zeros_like(lm_r), lm_r)) - ear_l = ear_score(lm_l) if ear_valid else 1.0 - ear_r = ear_score(lm_r) if ear_valid else 1.0 + ear_l, ear_l_c = ear_score(lm_l) + ear_r, ear_r_c = ear_score(lm_r) + + if not ear_l_c or not ear_r_c: + logger.warning(f"EAR score computation is not valid for left: {ear_l} and right: {ear_r}") + + ear_valid = ear_valid and ear_l_c and ear_r_c return EAR_Data(ear_l, ear_r, ear_valid, lm_l, lm_r) @@ -195,6 +207,11 @@ def compute(self, features: np.ndarray, valid: bool) -> EAR_Data: return EAR_Data(1.0, 1.0, False, lm_l, lm_r) ear_valid = not (np.allclose(np.zeros_like(lm_l), lm_l) and np.allclose(np.zeros_like(lm_r), lm_r)) - ear_l = ear_score(lm_l) if ear_valid else 1.0 - ear_r = ear_score(lm_r) if ear_valid else 1.0 + ear_l, ear_l_c = ear_score(lm_l) + ear_r, ear_r_c = ear_score(lm_r) + + if not ear_l_c or not ear_r_c: + logger.warning(f"EAR score computation is not valid for left: {ear_l} and right: {ear_r}") + + ear_valid = ear_valid and ear_l_c and ear_r_c return EAR_Data(ear_l, ear_r, ear_valid, lm_l, lm_r) \ No newline at end of file diff --git a/tests/test_earfeature.py b/tests/test_earfeature.py index b31e519..07b0a74 100644 --- a/tests/test_earfeature.py +++ b/tests/test_earfeature.py @@ -28,28 +28,30 @@ def test_ear_score_inputs(): with pytest.raises(ValueError): features.ear_score(input_data) - # Test case 5: Invalid input range - input_data = np.array([[-2, 2], [1,4], [4,4], [5,2], [4, 0], [1, 0]]) - with pytest.raises(ValueError): - features.ear_score(input_data) + # TEST DEPRIECATED + # Reason: negative values are possible (but rare) and should not raise an error + # # Test case 5: Invalid input range + # input_data = np.array([[-2, 2], [1,4], [4,4], [5,2], [4, 0], [1, 0]]) + # with pytest.raises(ValueError): + # features.ear_score(input_data) - # Test case 6: Out of range + # Test case 6: Out of range, thus not valid output input_data = np.array([[0, 2], [1,10], [4,10], [5,2], [4, 0], [1, 0]]) - with pytest.raises(ValueError): - features.ear_score(input_data) - + score, valid = features.ear_score(input_data) + assert not valid + def test_ear_score(): # Test case 1: Valid input of sphere input_data = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]) expected_output = 1.0 - assert features.ear_score(input_data) == approx(expected_output) + assert features.ear_score(input_data)[0] == approx(expected_output) # Test case 2: Valid input input_data = np.array([[0, 2], [1,4], [4,4], [5,2], [4, 0], [1, 0]]) expected_output = (4 + 4) / 10 - assert features.ear_score(input_data) == approx(expected_output) + assert features.ear_score(input_data)[0] == approx(expected_output) # Test case 3: Valid input input_data = np.array([[0, 2], [1,3], [4,3], [5,2], [4, 1], [1, 1]]) expected_output = (2 + 2) / 10 - assert features.ear_score(input_data) == approx(expected_output) + assert features.ear_score(input_data)[0] == approx(expected_output)