Skip to content

Commit

Permalink
Refactor edge cases in EAR computation
Browse files Browse the repository at this point in the history
- better input handling
- do not raise exceptions but log warnings
- handle outputs interally better
- update tests accordinlyg
- depricate test for negative landmarks (which are now valid)
  • Loading branch information
Timozen committed Apr 9, 2024
1 parent d743527 commit 2807606
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 25 deletions.
45 changes: 31 additions & 14 deletions src/jefapato/facial_features/features/ear_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

logger = structlog.get_logger()

def ear_score(eye: np.ndarray) -> float:
def ear_score(eye: np.ndarray) -> tuple[float, bool]:
"""
Compute the EAR Score for eye landmarks
Expand All @@ -35,6 +35,7 @@ def ear_score(eye: np.ndarray) -> float:
Returns:
float: The computed EAR score, which should be between 0 and 1
bool: A flag indicating if the EAR score is valid
"""
if eye is None:
raise ValueError("eye must not be None")
Expand All @@ -43,25 +44,31 @@ def ear_score(eye: np.ndarray) -> float:
raise TypeError(f"eye must be a numpy array, but got {type(eye)}")

if eye.shape != (6, 2) and eye.shape != (6, 3): # allow for 3D landmarks
raise ValueError(f"eye must be a 6x2 array, but got {eye.shape}")
raise ValueError(f"eye must be a 6x2 or 6x3 array, but got {eye.shape}")

# check that no value is negative
if np.any(eye < 0):
# raise ValueError(f"eye must not contain negative values, but got {eye}")
logger.warning(f"eye must not contain negative values, but got {eye}")
# This can be the case if parts of the face are not inside the image
# but the predictor tries to estimate the rough location.
# Thus we just log a warning and continue
logger.warning(f"Eye landmarks must not contain negative values, but got {eye}")


# dont forget the 0-index
A = distance.euclidean(eye[1], eye[5])
B = distance.euclidean(eye[2], eye[4])
C = distance.euclidean(eye[0], eye[3])

ratio = (A + B) / (2.0 * C)
if ratio > 1.002: # allow for some rounding errors
# raise ValueError(f"EAR score must be between 0 and 1, but got {ratio}, check your landmarks order")
logger.warning("EAR score must be between 0 and 1, but got {ratio}, check your landmarks order")
ratio = 1.0
return ratio
compute_valid = True
if ratio > 1.0:
logger.warning(f"EAR score must be between 0 and 1, but got {ratio}")
ratio = 1.0
compute_valid = False
if ratio < 0.0:
logger.warning(f"EAR score must be between 0 and 1, but got {ratio}")
ratio = 0.0
compute_valid = False
return ratio, compute_valid


@dataclasses.dataclass
Expand Down Expand Up @@ -161,8 +168,13 @@ def compute(self, features: np.ndarray, valid: bool) -> EAR_Data:
return EAR_Data(1.0, 1.0, False, lm_l, lm_r)

ear_valid = not (np.allclose(np.zeros_like(lm_l), lm_l) and np.allclose(np.zeros_like(lm_r), lm_r))
ear_l = ear_score(lm_l) if ear_valid else 1.0
ear_r = ear_score(lm_r) if ear_valid else 1.0
ear_l, ear_l_c = ear_score(lm_l)
ear_r, ear_r_c = ear_score(lm_r)

if not ear_l_c or not ear_r_c:
logger.warning(f"EAR score computation is not valid for left: {ear_l} and right: {ear_r}")

ear_valid = ear_valid and ear_l_c and ear_r_c
return EAR_Data(ear_l, ear_r, ear_valid, lm_l, lm_r)


Expand Down Expand Up @@ -195,6 +207,11 @@ def compute(self, features: np.ndarray, valid: bool) -> EAR_Data:
return EAR_Data(1.0, 1.0, False, lm_l, lm_r)

ear_valid = not (np.allclose(np.zeros_like(lm_l), lm_l) and np.allclose(np.zeros_like(lm_r), lm_r))
ear_l = ear_score(lm_l) if ear_valid else 1.0
ear_r = ear_score(lm_r) if ear_valid else 1.0
ear_l, ear_l_c = ear_score(lm_l)
ear_r, ear_r_c = ear_score(lm_r)

if not ear_l_c or not ear_r_c:
logger.warning(f"EAR score computation is not valid for left: {ear_l} and right: {ear_r}")

ear_valid = ear_valid and ear_l_c and ear_r_c
return EAR_Data(ear_l, ear_r, ear_valid, lm_l, lm_r)
24 changes: 13 additions & 11 deletions tests/test_earfeature.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,30 @@ def test_ear_score_inputs():
with pytest.raises(ValueError):
features.ear_score(input_data)

# Test case 5: Invalid input range
input_data = np.array([[-2, 2], [1,4], [4,4], [5,2], [4, 0], [1, 0]])
with pytest.raises(ValueError):
features.ear_score(input_data)
# TEST DEPRIECATED
# Reason: negative values are possible (but rare) and should not raise an error
# # Test case 5: Invalid input range
# input_data = np.array([[-2, 2], [1,4], [4,4], [5,2], [4, 0], [1, 0]])
# with pytest.raises(ValueError):
# features.ear_score(input_data)

# Test case 6: Out of range
# Test case 6: Out of range, thus not valid output
input_data = np.array([[0, 2], [1,10], [4,10], [5,2], [4, 0], [1, 0]])
with pytest.raises(ValueError):
features.ear_score(input_data)
score, valid = features.ear_score(input_data)
assert not valid

def test_ear_score():
# Test case 1: Valid input of sphere
input_data = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
expected_output = 1.0
assert features.ear_score(input_data) == approx(expected_output)
assert features.ear_score(input_data)[0] == approx(expected_output)

# Test case 2: Valid input
input_data = np.array([[0, 2], [1,4], [4,4], [5,2], [4, 0], [1, 0]])
expected_output = (4 + 4) / 10
assert features.ear_score(input_data) == approx(expected_output)
assert features.ear_score(input_data)[0] == approx(expected_output)

# Test case 3: Valid input
input_data = np.array([[0, 2], [1,3], [4,3], [5,2], [4, 1], [1, 1]])
expected_output = (2 + 2) / 10
assert features.ear_score(input_data) == approx(expected_output)
assert features.ear_score(input_data)[0] == approx(expected_output)

0 comments on commit 2807606

Please sign in to comment.