Skip to content

Commit

Permalink
Test symmetric with complex types and hermitian and conjtrans with re…
Browse files Browse the repository at this point in the history
…al types
  • Loading branch information
Rbiessy committed Jul 9, 2024
1 parent 9b9548a commit 82566e5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 47 deletions.
7 changes: 2 additions & 5 deletions tests/unit_tests/sparse_blas/include/test_spmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ void test_helper_with_format_with_transpose(

/**
* Helper function to test combination of transpose vals.
* Only test \p conjtrans if \p fpType is complex.
*
* @tparam fpType Complex or scalar, single or double precision type
* @tparam testFunctorI32 Test functor for fpType and int32
Expand All @@ -223,10 +222,8 @@ void test_helper_with_format(
const std::vector<oneapi::mkl::sparse::spmm_alg> &non_default_algorithms, int &num_passed,
int &num_skipped) {
std::vector<oneapi::mkl::transpose> transpose_vals{ oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::trans };
if (complex_info<fpType>::is_complex) {
transpose_vals.push_back(oneapi::mkl::transpose::conjtrans);
}
oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::conjtrans };
for (auto transpose_A : transpose_vals) {
for (auto transpose_B : transpose_vals) {
test_helper_with_format_with_transpose<fpType>(
Expand Down
61 changes: 51 additions & 10 deletions tests/unit_tests/sparse_blas/include/test_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
* The test functions will use different sizes if the configuration implies a symmetric matrix.
*/
template <typename fpType, typename testFunctorI32, typename testFunctorI64>
void test_helper_with_format(
void test_helper_with_format_with_transpose(
testFunctorI32 test_functor_i32, testFunctorI64 test_functor_i64, sycl::device *dev,
sparse_matrix_format_t format,
const std::vector<oneapi::mkl::sparse::spmv_alg> &non_default_algorithms,
Expand Down Expand Up @@ -153,22 +153,37 @@ void test_helper_with_format(
no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
if (transpose_val != oneapi::mkl::transpose::conjtrans) {
// Lower symmetric or hermitian
// Do not test conjtrans with symmetric or hermitian views as no backend supports it.
// Lower symmetric
oneapi::mkl::sparse::matrix_view symmetric_view(
complex_info<fpType>::is_complex ? oneapi::mkl::sparse::matrix_descr::hermitian
: oneapi::mkl::sparse::matrix_descr::symmetric);
oneapi::mkl::sparse::matrix_descr::symmetric);
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero,
transpose_val, fp_one, fp_zero, default_alg, symmetric_view,
no_properties, no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
// Upper symmetric or hermitian
// Upper symmetric
symmetric_view.uplo_view = oneapi::mkl::uplo::upper;
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero,
transpose_val, fp_one, fp_zero, default_alg, symmetric_view,
no_properties, no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
// Lower hermitian
oneapi::mkl::sparse::matrix_view hermitian_view(
oneapi::mkl::sparse::matrix_descr::hermitian);
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero,
transpose_val, fp_one, fp_zero, default_alg, hermitian_view,
no_properties, no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
// Upper hermitian
hermitian_view.uplo_view = oneapi::mkl::uplo::upper;
EXPECT_TRUE_OR_FUTURE_SKIP(
test_functor_i32(dev, format, nrows_A, ncols_A, density_A_matrix, index_zero,
transpose_val, fp_one, fp_zero, default_alg, hermitian_view,
no_properties, no_reset_data, no_scalars_on_device),
num_passed, num_skipped);
}
// Test other algorithms
for (auto alg : non_default_algorithms) {
Expand All @@ -188,31 +203,57 @@ void test_helper_with_format(
}
}

/**
* Helper function to test combination of transpose vals.
*
* @tparam fpType Complex or scalar, single or double precision type
* @tparam testFunctorI32 Test functor for fpType and int32
* @tparam testFunctorI64 Test functor for fpType and int64
* @param dev Device to test
* @param format Sparse matrix format to use
* @param non_default_algorithms Algorithms compatible with the given format, other than default_alg
* @param num_passed Increase the number of configurations passed
* @param num_skipped Increase the number of configurations skipped
*/
template <typename fpType, typename testFunctorI32, typename testFunctorI64>
void test_helper_with_format(
testFunctorI32 test_functor_i32, testFunctorI64 test_functor_i64, sycl::device *dev,
sparse_matrix_format_t format,
const std::vector<oneapi::mkl::sparse::spmv_alg> &non_default_algorithms, int &num_passed,
int &num_skipped) {
std::vector<oneapi::mkl::transpose> transpose_vals{ oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::conjtrans };
for (auto transpose_A : transpose_vals) {
test_helper_with_format_with_transpose<fpType>(test_functor_i32, test_functor_i64, dev,
format, non_default_algorithms, transpose_A,
num_passed, num_skipped);
}
}

/**
* Helper function to test multiple sparse matrix format and choose valid algorithms.
*
* @tparam fpType Complex or scalar, single or double precision type
* @tparam testFunctorI32 Test functor for fpType and int32
* @tparam testFunctorI64 Test functor for fpType and int64
* @param dev Device to test
* @param transpose_val Transpose value for the input matrix
* @param num_passed Increase the number of configurations passed
* @param num_skipped Increase the number of configurations skipped
*/
template <typename fpType, typename testFunctorI32, typename testFunctorI64>
void test_helper(testFunctorI32 test_functor_i32, testFunctorI64 test_functor_i64,
sycl::device *dev, oneapi::mkl::transpose transpose_val, int &num_passed,
int &num_skipped) {
sycl::device *dev, int &num_passed, int &num_skipped) {
test_helper_with_format<fpType>(
test_functor_i32, test_functor_i64, dev, sparse_matrix_format_t::CSR,
{ oneapi::mkl::sparse::spmv_alg::no_optimize_alg, oneapi::mkl::sparse::spmv_alg::csr_alg1,
oneapi::mkl::sparse::spmv_alg::csr_alg2, oneapi::mkl::sparse::spmv_alg::csr_alg3 },
transpose_val, num_passed, num_skipped);
num_passed, num_skipped);
test_helper_with_format<fpType>(
test_functor_i32, test_functor_i64, dev, sparse_matrix_format_t::COO,
{ oneapi::mkl::sparse::spmv_alg::no_optimize_alg, oneapi::mkl::sparse::spmv_alg::coo_alg1,
oneapi::mkl::sparse::spmv_alg::coo_alg2 },
transpose_val, num_passed, num_skipped);
num_passed, num_skipped);
}

/// Compute spmv reference as a dense operation
Expand Down
20 changes: 4 additions & 16 deletions tests/unit_tests/sparse_blas/source/sparse_spmv_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ TEST_P(SparseSpmvBufferTests, RealSinglePrecision) {
using fpType = float;
int num_passed = 0, num_skipped = 0;
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::nontrans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::trans, num_passed, num_skipped);
num_passed, num_skipped);
if (num_skipped > 0) {
// Mark that some tests were skipped
GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped
Expand All @@ -199,9 +197,7 @@ TEST_P(SparseSpmvBufferTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());
int num_passed = 0, num_skipped = 0;
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::nontrans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::trans, num_passed, num_skipped);
num_passed, num_skipped);
if (num_skipped > 0) {
// Mark that some tests were skipped
GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped
Expand All @@ -213,11 +209,7 @@ TEST_P(SparseSpmvBufferTests, ComplexSinglePrecision) {
using fpType = std::complex<float>;
int num_passed = 0, num_skipped = 0;
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::nontrans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::trans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::conjtrans, num_passed, num_skipped);
num_passed, num_skipped);
if (num_skipped > 0) {
// Mark that some tests were skipped
GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped
Expand All @@ -230,11 +222,7 @@ TEST_P(SparseSpmvBufferTests, ComplexDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());
int num_passed = 0, num_skipped = 0;
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::nontrans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::trans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::conjtrans, num_passed, num_skipped);
num_passed, num_skipped);
if (num_skipped > 0) {
// Mark that some tests were skipped
GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped
Expand Down
20 changes: 4 additions & 16 deletions tests/unit_tests/sparse_blas/source/sparse_spmv_usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,7 @@ TEST_P(SparseSpmvUsmTests, RealSinglePrecision) {
using fpType = float;
int num_passed = 0, num_skipped = 0;
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::nontrans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::trans, num_passed, num_skipped);
num_passed, num_skipped);
if (num_skipped > 0) {
// Mark that some tests were skipped
GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped
Expand All @@ -248,9 +246,7 @@ TEST_P(SparseSpmvUsmTests, RealDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());
int num_passed = 0, num_skipped = 0;
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::nontrans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::trans, num_passed, num_skipped);
num_passed, num_skipped);
if (num_skipped > 0) {
// Mark that some tests were skipped
GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped
Expand All @@ -262,11 +258,7 @@ TEST_P(SparseSpmvUsmTests, ComplexSinglePrecision) {
using fpType = std::complex<float>;
int num_passed = 0, num_skipped = 0;
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::nontrans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::trans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::conjtrans, num_passed, num_skipped);
num_passed, num_skipped);
if (num_skipped > 0) {
// Mark that some tests were skipped
GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped
Expand All @@ -279,11 +271,7 @@ TEST_P(SparseSpmvUsmTests, ComplexDoublePrecision) {
CHECK_DOUBLE_ON_DEVICE(GetParam());
int num_passed = 0, num_skipped = 0;
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::nontrans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::trans, num_passed, num_skipped);
test_helper<fpType>(test_spmv<fpType, int32_t>, test_spmv<fpType, std::int64_t>, GetParam(),
oneapi::mkl::transpose::conjtrans, num_passed, num_skipped);
num_passed, num_skipped);
if (num_skipped > 0) {
// Mark that some tests were skipped
GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped
Expand Down

0 comments on commit 82566e5

Please sign in to comment.