Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
ejhusom committed Jan 24, 2025
1 parent e220f67 commit 0e96bb2
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
17 changes: 10 additions & 7 deletions src/cluster_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ def find_segments(labels):

segments = []

if len(labels) == 0:
return np.array(segments)

current_label = labels[0]
current_length = 1
start_idx = 0
Expand All @@ -393,13 +396,8 @@ def find_segments(labels):
for i in range(1, len(labels)):
if labels[i] == current_label:
current_length += 1
if i == len(labels) - 1:
end_idx = i
segments.append(
[segment_idx, current_label, current_length, start_idx, end_idx]
)
else:
end_idx = i
end_idx = i - 1
segments.append(
[segment_idx, current_label, current_length, start_idx, end_idx]
)
Expand All @@ -408,8 +406,13 @@ def find_segments(labels):
current_length = 1
start_idx = i

return np.array(segments)
# Append the last segment
end_idx = len(labels) - 1
segments.append(
[segment_idx, current_label, current_length, start_idx, end_idx]
)

return np.array(segments)

def create_event_log(labels, identifier="",
feature_vector_timestamps=None):
Expand Down
39 changes: 36 additions & 3 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import yaml

sys.path.append("src/")
import cluster
import train
import cluster_utils


class TestUDAVA(unittest.TestCase):
Expand All @@ -31,8 +32,8 @@ class TestUDAVA(unittest.TestCase):
def test_find_segments(self):
"""Test whether find_segments() returns expected results."""

labels = [0, 0, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2]
segments = cluster.find_segments(labels)
labels = np.array([0, 0, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2])
segments = cluster_utils.find_segments(labels)

expected_segments = np.array(
[[0, 0, 2, 0, 1], [1, 1, 3, 2, 4], [2, 0, 4, 5, 8], [3, 2, 3, 9, 11]]
Expand All @@ -43,6 +44,38 @@ def test_find_segments(self):

np.testing.assert_array_equal(segments, expected_segments)

def test_find_segments_single_label(self):
"""Test find_segments() with a single label."""

labels = np.array([1, 1, 1, 1, 1])
segments = cluster_utils.find_segments(labels)

expected_segments = np.array([[0, 1, 5, 0, 4]])

np.testing.assert_array_equal(segments, expected_segments)

def test_find_segments_alternating_labels(self):
"""Test find_segments() with alternating labels."""

labels = np.array([0, 1, 0, 1, 0])
segments = cluster_utils.find_segments(labels)

expected_segments = np.array(
[[0, 0, 1, 0, 0], [1, 1, 1, 1, 1], [2, 0, 1, 2, 2], [3, 1, 1, 3, 3], [4, 0, 1, 4, 4]]
)

np.testing.assert_array_equal(segments, expected_segments)

def test_find_segments_empty(self):
"""Test find_segments() with an empty array."""

labels = np.array([])
segments = cluster_utils.find_segments(labels)

expected_segments = np.array([])

np.testing.assert_array_equal(segments, expected_segments)


if __name__ == "__main__":

Expand Down

0 comments on commit 0e96bb2

Please sign in to comment.