From 529e3ebd0bef4ab6c5126e05d22e86b3e0424c8a Mon Sep 17 00:00:00 2001 From: Tobias Grosse-Puppendahl Date: Sun, 23 Aug 2020 11:38:14 +0200 Subject: [PATCH] bug fixes --- simple_filters/tracker.py | 51 ++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/simple_filters/tracker.py b/simple_filters/tracker.py index 00ea1f2..b376c56 100644 --- a/simple_filters/tracker.py +++ b/simple_filters/tracker.py @@ -79,8 +79,8 @@ def update(self, states): number_of_states = states.shape[0] number_of_tracked_objects = len(self.__tracked_objects) - objects_to_match = [i for i in range(0, number_of_tracked_objects)] - states_to_match = [i for i in range(0, number_of_states)] + objects_to_match = list([i for i in range(0, number_of_tracked_objects)]) + states_to_match = list([i for i in range(0, number_of_states)]) ## Build the distance matrix and match objects # We build a matrix that contains the distances of the tracked objects @@ -101,40 +101,47 @@ def update(self, states): # Clearly, this is not the optimal solution (but fast), since it doesn't # optimize for the cumulative distance of pairs # Whenever a minimum was found, we invalidate this part in the distance matrix - min_distances = np.min(distance_matrix, axis=1) - min_distances = np.sort(min_distances) - - for d in min_distances: + while True: + min_distance = np.min(distance_matrix) + index = np.argwhere(distance_matrix == min_distance)[0] + # when the distance threshold is exceeded, it means we are seeing new objects, # or existing ones shall be removed - if d > self.distance_threshold: + if min_distance > self.distance_threshold: break - index = np.argwhere(distance_matrix == d)[0] t = index[0] # just to avoid confusion, this is the object s = index[1] # and this the state - distance_matrix[t, :] = np.inf # invalidate this part of the distance matrix - - objects_to_match.remove(t) - states_to_match.remove(s) - self.__tracked_objects[t].update(states[s]) - - ## Add objects - # now go through all unmatched objects and create new objects - for i in states_to_match: - self.object_counter += 1 - added_object = TrackedObject(self.object_counter, deepcopy(self.__filter_prototype)) - added_object.update(states[i]) - self.__tracked_objects.append(added_object) + distance_matrix[t, s] = np.inf # invalidate this part of the distance matrix + + # if the state has not previously been associated to a tracked object + if t in objects_to_match and s in states_to_match: + objects_to_match.remove(t) + states_to_match.remove(s) + self.__tracked_objects[t].update(states[s]) ## Delete objects # Remove an object that has not been seen when its time-to-live is exceeded + # We are using two steps for deletion, because in the first step we are addressing by index and we + # don't want to mess up the list indexing + removals = [] for i in objects_to_match: tracked_object = self.__tracked_objects[i] tracked_object.time_to_live += 1 if tracked_object.time_to_live > self.time_to_live: - self.__tracked_objects.remove(tracked_object) + removals.append(tracked_object) else: # update the object with the next predicted state tracked_object.update(tracked_object.eval(time=1)) + + for tracked_object in removals: + self.__tracked_objects.remove(tracked_object) + + ## Add objects + # now go through all unmatched objects and create new objects + for i in states_to_match: + self.object_counter += 1 + added_object = TrackedObject(self.object_counter, deepcopy(self.__filter_prototype)) + added_object.update(states[i]) + self.__tracked_objects.append(added_object)