Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend LineZone to filter out miscounts #1540

Merged
merged 9 commits into from
Nov 6, 2024
211 changes: 143 additions & 68 deletions supervision/detection/line_zone.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import math
import warnings
from collections import Counter
from collections import Counter, defaultdict, deque
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
from typing import Any, Deque, Dict, Iterable, List, Literal, Optional, Tuple

import cv2
import numpy as np
import numpy.typing as npt

from supervision.config import CLASS_NAME_DATA_FIELD
from supervision.detection.core import Detections
Expand Down Expand Up @@ -74,6 +75,7 @@ def __init__(
Position.BOTTOM_LEFT,
Position.BOTTOM_RIGHT,
),
crossing_acceptance_threshold: int = 1,
):
"""
Args:
Expand All @@ -84,10 +86,18 @@ def __init__(
to consider when deciding on whether the detection
has passed the line counter or not. By default, this
contains the four corners of the detection's bounding box
crossing_acceptance_threshold (int): Detection needs to be seen
on the other side of the line for this many frames to be
considered as having crossed the line. This is useful when
dealing with unstable bounding boxes or when detections
may linger on the line.
"""
self.vector = Vector(start=start, end=end)
self.limits = self.calculate_region_of_interest_limits(vector=self.vector)
self.tracker_state: Dict[str, bool] = {}
self.limits = self._calculate_region_of_interest_limits(vector=self.vector)
self.crossing_history_length = max(2, crossing_acceptance_threshold + 1)
self.crossing_state_history: Dict[int, Deque[bool]] = defaultdict(
lambda: deque(maxlen=self.crossing_history_length)
)
self._in_count_per_class: Counter = Counter()
self._out_count_per_class: Counter = Counter()
self.triggering_anchors = triggering_anchors
Expand Down Expand Up @@ -127,8 +137,82 @@ def out_count_per_class(self) -> Dict[int, int]:
"""
return dict(self._out_count_per_class)

def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
"""
Update the `in_count` and `out_count` based on the objects that cross the line.

Args:
detections (Detections): A list of detections for which to update the
counts.

Returns:
A tuple of two boolean NumPy arrays. The first array indicates which
detections have crossed the line from outside to inside. The second
array indicates which detections have crossed the line from inside to
outside.
"""
crossed_in = np.full(len(detections), False)
crossed_out = np.full(len(detections), False)

if len(detections) == 0:
return crossed_in, crossed_out

if detections.tracker_id is None:
warnings.warn(
"Line zone counting skipped. LineZone requires tracker_id. Refer to "
"https://supervision.roboflow.com/latest/trackers for more "
"information.",
category=SupervisionWarnings,
)
return crossed_in, crossed_out

self._update_class_id_to_name(detections)

in_limits, has_any_left_trigger, has_any_right_trigger = (
self._compute_anchor_sides(detections)
)

class_ids: List[Optional[int]] = (
list(detections.class_id)
if detections.class_id is not None
else [None] * len(detections)
)

for i, (class_id, tracker_id) in enumerate(
zip(class_ids, detections.tracker_id)
):
if not in_limits[i]:
continue

if has_any_left_trigger[i] and has_any_right_trigger[i]:
continue

tracker_state: bool = has_any_left_trigger[i]
crossing_history = self.crossing_state_history[tracker_id]
crossing_history.append(tracker_state)

if len(crossing_history) < self.crossing_history_length:
continue

# TODO: Account for incorrect class_id.
# Most likely this would involve indexing self.crossing_state_history
# with (tracker_id, class_id).

oldest_state = crossing_history[0]
if crossing_history.count(oldest_state) > 1:
continue

if tracker_state:
self._in_count_per_class[class_id] += 1
crossed_in[i] = True
else:
self._out_count_per_class[class_id] += 1
crossed_out[i] = True

return crossed_in, crossed_out

@staticmethod
def calculate_region_of_interest_limits(vector: Vector) -> Tuple[Vector, Vector]:
def _calculate_region_of_interest_limits(vector: Vector) -> Tuple[Vector, Vector]:
magnitude = vector.magnitude

if magnitude == 0:
Expand Down Expand Up @@ -159,40 +243,45 @@ def calculate_region_of_interest_limits(vector: Vector) -> Tuple[Vector, Vector]
)
return start_region_limit, end_region_limit

@staticmethod
def is_point_in_limits(point: Point, limits: Tuple[Vector, Vector]) -> bool:
cross_product_1 = limits[0].cross_product(point)
cross_product_2 = limits[1].cross_product(point)
return (cross_product_1 > 0) == (cross_product_2 > 0)

def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
def _compute_anchor_sides(
self, detections: Detections
) -> Tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
"""
Update the `in_count` and `out_count` based on the objects that cross the line.
Find if detections' anchors are within the limit of the line
zone and which anchors are on its left and right side.

Assumes:
* At least 1 detection is provided
* Detections have `tracker_id`

The limit is defined as the region between the two lines,
perpendicular to the line zone, and passing through its start
and end points, as shown below:

Limits:
```
| IN ↑
| |
OUT o---LINE---o OUT
| |
↓ IN |
```

Args:
detections (Detections): A list of detections for which to update the
counts.
detections (Detections): The detections to check.

Returns:
A tuple of two boolean NumPy arrays. The first array indicates which
detections have crossed the line from outside to inside. The second
array indicates which detections have crossed the line from inside to
outside.
result (Tuple[np.ndarray, np.ndarray, np.ndarray]):
All 3 arrays are boolean arrays of shape (N, ) where N is the
number of detections. The first array, `in_limits`, indicates
if the detection's anchor is within the line zone limits.
The second array, `has_any_left_trigger`, indicates if the
detection's anchor is on the left side of the line zone.
The third array, `has_any_right_trigger`, indicates if the
detection's anchor is on the right side of the line zone.
"""
crossed_in = np.full(len(detections), False)
crossed_out = np.full(len(detections), False)

if len(detections) == 0:
return crossed_in, crossed_out

if detections.tracker_id is None:
warnings.warn(
"Line zone counting skipped. LineZone requires tracker_id. Refer to "
"https://supervision.roboflow.com/latest/trackers for more "
"information.",
category=SupervisionWarnings,
)
return crossed_in, crossed_out
assert len(detections) > 0
assert detections.tracker_id is not None

all_anchors = np.array(
[
Expand All @@ -203,52 +292,38 @@ def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:

cross_products_1 = cross_product(all_anchors, self.limits[0])
cross_products_2 = cross_product(all_anchors, self.limits[1])

# Works because limit vectors are pointing in opposite directions
in_limits = (cross_products_1 > 0) == (cross_products_2 > 0)
in_limits = np.all(in_limits, axis=0)

triggers = cross_product(all_anchors, self.vector) < 0
has_any_left_trigger = np.any(triggers, axis=0)
has_any_right_trigger = np.any(~triggers, axis=0)
is_uniformly_triggered = ~(has_any_left_trigger & has_any_right_trigger)

class_ids = (
list(detections.class_id)
if detections.class_id is not None
else [None] * len(detections)
)
tracker_ids = list(detections.tracker_id)

if CLASS_NAME_DATA_FIELD in detections.data:
class_names = detections.data[CLASS_NAME_DATA_FIELD]
for class_id, class_name in zip(class_ids, class_names):
if class_id is None:
class_name = "No class"
self.class_id_to_name[class_id] = class_name

for i, (class_ids, tracker_id) in enumerate(zip(class_ids, tracker_ids)):
if not in_limits[i]:
continue

if not is_uniformly_triggered[i]:
continue
return in_limits, has_any_left_trigger, has_any_right_trigger

tracker_state = has_any_left_trigger[i]
if tracker_id not in self.tracker_state:
self.tracker_state[tracker_id] = tracker_state
continue
def _update_class_id_to_name(self, detections: Detections) -> None:
"""
Update the attribute keeping track of which class
IDs correspond to which class names.

if self.tracker_state.get(tracker_id) == tracker_state:
continue
Assumes that class_names are only provided when class_ids are.
"""
class_names = detections.data.get(CLASS_NAME_DATA_FIELD)
assert class_names is None or detections.class_id is not None

self.tracker_state[tracker_id] = tracker_state
if tracker_state:
self._in_count_per_class[class_ids] += 1
crossed_in[i] = True
else:
self._out_count_per_class[class_ids] += 1
crossed_out[i] = True
if detections.class_id is None:
return

return crossed_in, crossed_out
if class_names is None:
new_names = {class_id: str(class_id) for class_id in detections.class_id}
else:
new_names = {
class_id: class_name
for class_id, class_name in zip(detections.class_id, class_names)
}
self.class_id_to_name.update(new_names)


class LineZoneAnnotator:
Expand Down
Loading