Skip to content

Commit

Permalink
Revert throwing unsupported for spsv + no_optimize_alg
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 30, 2024
1 parent 00e5ced commit 342380e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 11 deletions.
4 changes: 2 additions & 2 deletions docs/domains/sparse_linear_algebra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ Currently known limitations:
a ``oneapi::mkl::unimplemented`` exception.
- Using ``spmv`` with a ``type_view`` other than ``matrix_descr::general`` will
throw a ``oneapi::mkl::unimplemented`` exception.
- Using ``spsv`` with the algorithm ``spsv_alg::no_optimize_alg`` will throw a
``oneapi::mkl::unimplemented`` exception.
- Using ``spsv`` with the algorithm ``spsv_alg::no_optimize_alg`` may still
perform some mandatory preprocessing.
- oneMKL Interface does not provide a way to use non-default algorithms without
calling preprocess functions such as ``cusparseSpMM_preprocess`` or
``cusparseSpMV_preprocess``. Feel free to create an issue if this is needed.
Expand Down
15 changes: 6 additions & 9 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,18 @@ inline auto get_cuda_spsv_alg(spsv_alg /*alg*/) {

void check_valid_spsv(const std::string &function_name, matrix_view A_view,
matrix_handle_t A_handle, dense_vector_handle_t x_handle,
dense_vector_handle_t y_handle, spsv_alg alg, bool is_alpha_host_accessible) {
dense_vector_handle_t y_handle, bool is_alpha_host_accessible) {
detail::check_valid_spsv_common(function_name, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
check_valid_matrix_properties(function_name, A_handle);
if (alg == spsv_alg::no_optimize_alg) {
throw mkl::unimplemented(
"sparse_blas", function_name,
"The backend does not support the algorithm ``spsv_alg::no_optimize_alg``.");
}
}

void spsv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
matrix_view A_view, matrix_handle_t A_handle, dense_vector_handle_t x_handle,
dense_vector_handle_t y_handle, spsv_alg alg, spsv_descr_t spsv_descr,
std::size_t &temp_buffer_size) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, alg, is_alpha_host_accessible);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible);
auto functor = [=, &temp_buffer_size](CusparseScopedContextHandler &sc) {
auto cu_handle = sc.get_handle(queue);
auto cu_a = A_handle->backend_handle;
Expand All @@ -133,7 +128,7 @@ inline void common_spsv_optimize(oneapi::mkl::transpose opA, bool is_alpha_host_
matrix_view A_view, matrix_handle_t A_handle,
dense_vector_handle_t x_handle, dense_vector_handle_t y_handle,
spsv_alg alg, spsv_descr_t spsv_descr) {
check_valid_spsv("spsv_optimize", A_view, A_handle, x_handle, y_handle, alg,
check_valid_spsv("spsv_optimize", A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
if (!spsv_descr->buffer_size_called) {
throw mkl::uninitialized("sparse_blas", "spsv_optimize",
Expand Down Expand Up @@ -178,6 +173,7 @@ void spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const void *a
}
common_spsv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle, y_handle, alg,
spsv_descr);
// Ignore spsv_alg::no_optimize_alg as this step is mandatory for cuSPARSE
// Copy the buffer to extend its lifetime until the descriptor is free'd.
spsv_descr->workspace.set_buffer_untyped(workspace);

Expand Down Expand Up @@ -215,6 +211,7 @@ sycl::event spsv_optimize(sycl::queue &queue, oneapi::mkl::transpose opA, const
}
common_spsv_optimize(opA, is_alpha_host_accessible, A_view, A_handle, x_handle, y_handle, alg,
spsv_descr);
// Ignore spsv_alg::no_optimize_alg as this step is mandatory for cuSPARSE
auto functor = [=](CusparseScopedContextHandler &sc) {
auto cu_handle = sc.get_handle(queue);
spsv_optimize_impl(cu_handle, opA, alpha, A_view, A_handle, x_handle, y_handle, alg,
Expand All @@ -229,7 +226,7 @@ sycl::event spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
dense_vector_handle_t y_handle, spsv_alg alg, spsv_descr_t spsv_descr,
const std::vector<sycl::event> &dependencies) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, alg, is_alpha_host_accessible);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible);
if (A_handle->all_use_buffer() != spsv_descr->workspace.use_buffer()) {
detail::throw_incompatible_container(__func__);
}
Expand Down

0 comments on commit 342380e

Please sign in to comment.