From 95d9bcbff44e221d07d4fc080e36d421bbc3d8fe Mon Sep 17 00:00:00 2001 From: Wannes Meert Date: Tue, 4 Jun 2019 11:55:51 +0200 Subject: [PATCH] fix --- dtaidistance/dtw.py | 4 ++-- tests/test_dtw.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dtaidistance/dtw.py b/dtaidistance/dtw.py index 069eb210..c5550e2e 100644 --- a/dtaidistance/dtw.py +++ b/dtaidistance/dtw.py @@ -467,13 +467,13 @@ def distances_array_to_matrix(dists, nb_series, block=None): def distance_array_index(a, b, nb_series): if a == b: - return 0 + raise ValueError("Distance between the same series is not available.") if a > b: a, b = b, a idx = 0 for r in range(a): idx += nb_series - r - 1 - idx += b + idx += b - a - 1 return idx diff --git a/tests/test_dtw.py b/tests/test_dtw.py index 071dd1c6..7ff294e6 100644 --- a/tests/test_dtw.py +++ b/tests/test_dtw.py @@ -40,9 +40,9 @@ def test_condensed_index1(): +-----------------------------+ """ - assert dtw.distance_array_index(3, 2, 5) == 9 - assert dtw.distance_array_index(2, 3, 5) == 9 - assert dtw.distance_array_index(1, 5, 5) == 8 + assert dtw.distance_array_index(3, 2, 6) == 9 + assert dtw.distance_array_index(2, 3, 6) == 9 + assert dtw.distance_array_index(1, 5, 6) == 8 def test_distance1_a(): @@ -156,6 +156,7 @@ def test_distance_matrix_block(): if __name__ == "__main__": # test_distance1_a() - test_distance_matrix2_e() + # test_distance_matrix2_e() # run_distance_matrix_block(parallel=True, use_c=True, use_nogil=False) # test_expected_length1() + test_condensed_index1()