Skip to content

Commit

Permalink
Throw unimplemented for some cases with csr_alg3
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 23, 2024
1 parent 050f22a commit 77d19e8
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
3 changes: 3 additions & 0 deletions docs/domains/sparse_linear_algebra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ cuSPARSE backend

Currently known limitations:

- Using ``spmm`` with the algorithm ``spmm_alg::csr_alg3`` and an ``opA`` other
than ``transpose::nontrans`` or an ``opB`` ``transpose::conjtrans`` will throw
an ``oneapi::mkl::unimplemented`` exception.
- Using ``spmv`` with a ``type_view`` other than ``matrix_descr::general`` will
throw an ``oneapi::mkl::unimplemented`` exception.
- The COO format requires the indices to be sorted by row. See the `cuSPARSE
Expand Down
37 changes: 21 additions & 16 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,21 @@ inline auto get_cuda_spmm_alg(spmm_alg alg) {
}
}

inline void fallback_alg_if_needed(spmm_alg& alg, oneapi::mkl::transpose opA,
oneapi::mkl::transpose opB) {
if (alg == spmm_alg::csr_alg3 &&
(opA != oneapi::mkl::transpose::nontrans || opB == oneapi::mkl::transpose::conjtrans)) {
// Avoid warnings printed on std::cerr
alg = spmm_alg::default_alg;
void check_valid_spmm(const std::string& function_name, oneapi::mkl::transpose opA,
oneapi::mkl::transpose opB, matrix_view A_view, matrix_handle_t A_handle,
dense_matrix_handle_t B_handle, dense_matrix_handle_t C_handle,
bool is_alpha_host_accessible, bool is_beta_host_accessible, spmm_alg alg) {
detail::check_valid_spmm_common(function_name, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
if (alg == spmm_alg::csr_alg3 && opA != oneapi::mkl::transpose::nontrans) {
throw mkl::unimplemented(
"sparse_blas", function_name,
"The backend does not support spmm with the algorithm `spmm_alg::csr_alg3` if `opA` is not `transpose::nontrans`.");
}
if (alg == spmm_alg::csr_alg3 && opB == oneapi::mkl::transpose::conjtrans) {
throw mkl::unimplemented(
"sparse_blas", function_name,
"The backend does not support spmm with the algorithm `spmm_alg::csr_alg3` if `opB` is `transpose::conjtrans`.");
}
}

Expand All @@ -87,9 +96,8 @@ void spmm_buffer_size(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mk
std::size_t& temp_buffer_size) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
fallback_alg_if_needed(alg, opA, opB);
check_valid_spmm(__func__, opA, opB, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible, alg);
auto functor = [=, &temp_buffer_size](CusparseScopedContextHandler& sc) {
auto cu_handle = sc.get_handle(queue);
auto cu_a = A_handle->backend_handle;
Expand All @@ -116,8 +124,8 @@ inline void common_spmm_optimize(oneapi::mkl::transpose opA, oneapi::mkl::transp
matrix_handle_t A_handle, dense_matrix_handle_t B_handle,
bool is_beta_host_accessible, dense_matrix_handle_t C_handle,
spmm_alg alg, spmm_descr_t spmm_descr) {
detail::check_valid_spmm_common("spmm_optimize", A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
check_valid_spmm("spmm_optimize", opA, opB, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible, alg);
if (!spmm_descr->buffer_size_called) {
throw mkl::uninitialized("sparse_blas", "spmm_optimize",
"spmm_buffer_size must be called before spmm_optimize.");
Expand Down Expand Up @@ -168,7 +176,6 @@ void spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::
// cusparseSpMM_preprocess cannot be called if the workspace is empty
return;
}
fallback_alg_if_needed(alg, opA, opB);
auto functor = [=](CusparseScopedContextHandler& sc,
sycl::accessor<std::uint8_t> workspace_acc) {
auto cu_handle = sc.get_handle(queue);
Expand Down Expand Up @@ -200,7 +207,6 @@ sycl::event spmm_optimize(sycl::queue& queue, oneapi::mkl::transpose opA,
// cusparseSpMM_preprocess cannot be called if the workspace is empty
return detail::collapse_dependencies(queue, dependencies);
}
fallback_alg_if_needed(alg, opA, opB);
auto functor = [=](CusparseScopedContextHandler& sc) {
auto cu_handle = sc.get_handle(queue);
spmm_optimize_impl(cu_handle, opA, opB, alpha, A_handle, B_handle, beta, C_handle, alg,
Expand All @@ -217,8 +223,8 @@ sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
const std::vector<sycl::event>& dependencies) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
bool is_beta_host_accessible = detail::is_ptr_accessible_on_host(queue, beta);
detail::check_valid_spmm_common(__func__, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
check_valid_spmm(__func__, opA, opB, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible, alg);
if (A_handle->all_use_buffer() != spmm_descr->workspace.use_buffer()) {
detail::throw_incompatible_container(__func__);
}
Expand All @@ -235,7 +241,6 @@ sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
CHECK_DESCR_MATCH(spmm_descr, C_handle, "spmm_optimize");
CHECK_DESCR_MATCH(spmm_descr, alg, "spmm_optimize");

fallback_alg_if_needed(alg, opA, opB);
auto compute_functor = [=](CusparseScopedContextHandler& sc, void* workspace_ptr) {
auto [cu_handle, cu_stream] = sc.get_handle_and_stream(queue);
auto cu_a = A_handle->backend_handle;
Expand Down

0 comments on commit 77d19e8

Please sign in to comment.