diff --git a/src/khiva/matrix.cpp b/src/khiva/matrix.cpp index 16a88f1..715bbba 100644 --- a/src/khiva/matrix.cpp +++ b/src/khiva/matrix.cpp @@ -552,7 +552,7 @@ void findBestNOccurrences(af::array q, af::array t, long n, af::array &distances af::sort(sortedDistances, sortedIndexes, distancesGlobal); - indexes = sortedIndexes(af::seq(n), af::span, af::span).as(t.type()); + indexes = sortedIndexes(af::seq(n), af::span, af::span); distances = sortedDistances(af::seq(n), af::span, af::span).as(t.type()); } diff --git a/test/matrixTest.cpp b/test/matrixTest.cpp index 559ab74..04b3600 100644 --- a/test/matrixTest.cpp +++ b/test/matrixTest.cpp @@ -192,13 +192,10 @@ void findBestNOccurrences() { ASSERT_EQ(distance.dims(), af::dim4(1, 1, 2, 1)); ASSERT_EQ(index.dims(), af::dim4(1, 1, 2, 1)); - distance = distance.as(f32); - index = index.as(s64); - ASSERT_NEAR(distance(0, 0, 0, 0).scalar(), expectedDistance, 1e-2); ASSERT_NEAR(distance(0, 0, 1, 0).scalar(), expectedDistance, 1e-2); - ASSERT_EQ(index(0, 0, 0, 0).scalar(), expectedIndex); - ASSERT_EQ(index(0, 0, 0, 0).scalar(), expectedIndex); + ASSERT_EQ(index(0, 0, 0, 0).scalar(), expectedIndex); + ASSERT_EQ(index(0, 0, 1, 0).scalar(), expectedIndex); } void findBestNOccurrencesMultipleQueries() {