Skip to content

Commit

Permalink
Throw exceptions on unsorted CSR format + tests this
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Dec 2, 2024
1 parent 633c655 commit b6518c9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ int run_sparse_matrix_vector_multiply_example(const sycl::device& dev) {
oneapi::mkl::sparse::init_csr_matrix(main_queue, &A_handle, nrows, nrows, nnz,
oneapi::mkl::index_base::zero, ia, ja, a);

// rocSPARSE backend requires that the property sorted is set when using matrices in CSR format.
// Setting this property is also the best practice to get best performance.
oneapi::mkl::sparse::set_matrix_property(main_queue, A_handle,
oneapi::mkl::sparse::matrix_property::sorted);

// Create and initialize dense vector handles
oneapi::mkl::sparse::dense_vector_handle_t x_handle = nullptr;
oneapi::mkl::sparse::dense_vector_handle_t y_handle = nullptr;
Expand Down
6 changes: 6 additions & 0 deletions src/sparse_blas/backends/rocsparse/rocsparse_handles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ struct matrix_handle : public detail::generic_sparse_handle<rocsparse_spmat_desc
"sparse_blas", function_name,
"The backend does not support unsorted COO format. Use `set_matrix_property` to set the property `matrix_property::sorted`");
}
if (this->format == detail::sparse_format::CSR &&
!this->has_matrix_property(matrix_property::sorted)) {
throw mkl::unimplemented(
"sparse_blas", function_name,
"The backend does not support unsorted CSR format. Use `set_matrix_property` to set the property `matrix_property::sorted`");
}
}

void mark_used() {
Expand Down
18 changes: 4 additions & 14 deletions tests/unit_tests/sparse_blas/include/test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,6 @@ inline std::set<oneapi::mkl::sparse::matrix_property> get_default_matrix_propert
inline std::vector<std::set<oneapi::mkl::sparse::matrix_property>>
get_all_matrix_properties_combinations(sycl::queue queue, sparse_matrix_format_t format) {
auto vendor_id = oneapi::mkl::get_device_id(queue);
if (vendor_id == oneapi::mkl::device::nvidiagpu && format == sparse_matrix_format_t::COO) {
// Ensure all the sets have the sorted or sorted_by_rows properties
return { { oneapi::mkl::sparse::matrix_property::sorted },
{ oneapi::mkl::sparse::matrix_property::sorted_by_rows,
oneapi::mkl::sparse::matrix_property::symmetric },
{ oneapi::mkl::sparse::matrix_property::sorted,
oneapi::mkl::sparse::matrix_property::symmetric } };
}
if (vendor_id == oneapi::mkl::device::amdgpu &&
(format == sparse_matrix_format_t::COO || format == sparse_matrix_format_t::CSR)) {
return { { oneapi::mkl::sparse::matrix_property::sorted,
oneapi::mkl::sparse::matrix_property::symmetric } };
}

std::vector<std::set<oneapi::mkl::sparse::matrix_property>> properties_combinations{
{ oneapi::mkl::sparse::matrix_property::sorted },
{ oneapi::mkl::sparse::matrix_property::symmetric },
Expand All @@ -99,6 +85,10 @@ get_all_matrix_properties_combinations(sycl::queue queue, sparse_matrix_format_t
if (format == sparse_matrix_format_t::COO) {
properties_combinations.push_back({ oneapi::mkl::sparse::matrix_property::sorted_by_rows });
}
if (vendor_id == oneapi::mkl::device::nvidiagpu || vendor_id == oneapi::mkl::device::amdgpu) {
// Test without any properties set since for backends for which this is not the default behavior
properties_combinations.push_back({});
}
return properties_combinations;
}

Expand Down

0 comments on commit b6518c9

Please sign in to comment.