Skip to content

Commit

Permalink
added filter + tracking tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tobias Grosse-Puppendahl committed Aug 19, 2020
1 parent ef3d7d7 commit 290d36e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 13 deletions.
11 changes: 7 additions & 4 deletions simple_filters/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,17 @@ def update(self, states):

states = np.array(states)

# check if array is 2d, otherwise make it so
# check if the states array is 2d, otherwise make it so
if len(states.shape) == 1:
states = np.array([states])

# set initial properties
number_of_states = states.shape[0]
# check if the states array is empty
if states.size == 0:
number_of_states = 0
else:
number_of_states = states.shape[0]

number_of_tracked_objects = len(self.__tracked_objects)

objects_to_match = [i for i in range(0, number_of_tracked_objects)]
states_to_match = [i for i in range(0, number_of_states)]

Expand Down
42 changes: 33 additions & 9 deletions tests/test_single_object_tracking.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from unittest import TestCase

from ..simple_filters import Tracker, TrackedObject, Filter, PolynomialFilterStrategy
from ..simple_filters import Tracker, TrackedObject, Filter, PolynomialFilterStrategy, DummyFilterStrategy

import pytest
import numpy as np

class TestSingleStepSingleObjectTracking(TestCase):

def setUp(self):
strategy = PolynomialFilterStrategy()
filter_prototype = Filter(strategy, history_size=1)
self.tracker = Tracker(filter_prototype, distance_threshold=1.)

def test_main(self):
def test_new_obj(self):
strategy = DummyFilterStrategy()
filter_prototype = Filter(strategy, history_size=5)
tracker = Tracker(filter_prototype, distance_threshold=1.)

states = [
np.array([1.0, 1.0]),
np.array([1.5, 1.5]),
Expand All @@ -26,8 +25,33 @@ def test_main(self):

for i, (state, expected_tracking_id) in enumerate(zip(states, expected_tracking_ids)):
print("timestep", i)
self.tracker.update(state)
tracked_state = self.tracker.get_tracked_objects()
tracker.update(state)
tracked_state = tracker.get_tracked_objects()

self.assertEqual(1, len(tracked_state))
self.assertEqual(tracked_state[0].id, expected_tracking_id)

def test_interpolate_object_with_ttl(self):
strategy = PolynomialFilterStrategy(poly_degree=1, reject_outliers=False)
filter_prototype = Filter(strategy, history_size=3)
tracker = Tracker(filter_prototype, distance_threshold=1., time_to_live=1)

states = [
np.array([[1.0, 1.0]]),
np.array([[1.5, 1.5]]),
np.array([[2.0, 2.0]]),
np.array([]),
np.array([[3.0, 3.0]]),
np.array([[3.5, 3.5]]),
np.array([[4.0, 4.0]])
]

expected_tracking_ids = [1, 1, 1, 1, 1, 1, 1]

for i, (state, expected_tracking_id) in enumerate(zip(states, expected_tracking_ids)):
print("timestep", i)
tracker.update(state)
tracked_state = tracker.get_tracked_objects()

self.assertEqual(1, len(tracked_state))
self.assertEqual(tracked_state[0].id, expected_tracking_id)

0 comments on commit 290d36e

Please sign in to comment.