Skip to content

Commit

Permalink
Reduce duplication of spmm and spmv compute stage
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Jul 15, 2024
1 parent 1262248 commit 85eefd0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 63 deletions.
55 changes: 23 additions & 32 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,53 +195,44 @@ sycl::event spmm(sycl::queue& queue, oneapi::mkl::transpose opA, oneapi::mkl::tr
detail::throw_incompatible_container(__func__);
}
fallback_alg_if_needed(alg, opA, opB);
auto compute_functor = [=](CusparseScopedContextHandler& sc, void* workspace_ptr) {
auto [cu_handle, cu_stream] = sc.get_handle_and_stream(queue);
auto cu_a = A_handle->backend_handle;
auto cu_b = B_handle->backend_handle;
auto cu_c = C_handle->backend_handle;
auto type = A_handle->value_container.data_type;
auto cu_op_a = get_cuda_operation(type, opA);
auto cu_op_b = get_cuda_operation(type, opB);
auto cu_type = get_cuda_value_type(type);
auto cu_alg = get_cuda_spmm_alg(alg);
set_pointer_mode(cu_handle, queue, alpha);
auto status = cusparseSpMM(cu_handle, cu_op_a, cu_op_b, alpha, cu_a, cu_b, beta, cu_c,
cu_type, cu_alg, workspace_ptr);
check_status(status, __func__);
CUDA_ERROR_FUNC(cuStreamSynchronize, cu_stream);
};
if (A_handle->all_use_buffer() && spmm_descr->temp_buffer_size > 0) {
// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
auto functor = [=](CusparseScopedContextHandler& sc,
sycl::accessor<std::uint8_t> workspace_acc) {
auto [cu_handle, cu_stream] = sc.get_handle_and_stream(queue);
auto functor_buffer = [=](CusparseScopedContextHandler& sc,
sycl::accessor<std::uint8_t> workspace_acc) {
auto workspace_ptr = sc.get_mem(workspace_acc);
auto cu_a = A_handle->backend_handle;
auto cu_b = B_handle->backend_handle;
auto cu_c = C_handle->backend_handle;
auto type = A_handle->value_container.data_type;
auto cu_op_a = get_cuda_operation(type, opA);
auto cu_op_b = get_cuda_operation(type, opB);
auto cu_type = get_cuda_value_type(type);
auto cu_alg = get_cuda_spmm_alg(alg);
auto status = cusparseSpMM(cu_handle, cu_op_a, cu_op_b, alpha, cu_a, cu_b, beta, cu_c,
cu_type, cu_alg, workspace_ptr);
check_status(status, __func__);
CUDA_ERROR_FUNC(cuStreamSynchronize, cu_stream);
compute_functor(sc, workspace_ptr);
};
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(
spmm_descr->workspace.get_buffer<std::uint8_t>());
return dispatch_submit<true>(__func__, queue, dependencies, functor, A_handle,
return dispatch_submit<true>(__func__, queue, dependencies, functor_buffer, A_handle,
workspace_placeholder_acc, B_handle, C_handle);
}
else {
// The same dispatch_submit can be used for USM or buffers if no
// workspace accessor is needed, workspace_ptr will be a nullptr in the
// latter case.
auto workspace_ptr = spmm_descr->workspace.usm_ptr;
auto functor = [=](CusparseScopedContextHandler& sc) {
auto [cu_handle, cu_stream] = sc.get_handle_and_stream(queue);
auto cu_a = A_handle->backend_handle;
auto cu_b = B_handle->backend_handle;
auto cu_c = C_handle->backend_handle;
auto type = A_handle->value_container.data_type;
auto cu_op_a = get_cuda_operation(type, opA);
auto cu_op_b = get_cuda_operation(type, opB);
auto cu_type = get_cuda_value_type(type);
auto cu_alg = get_cuda_spmm_alg(alg);
set_pointer_mode(cu_handle, queue, alpha);
auto status = cusparseSpMM(cu_handle, cu_op_a, cu_op_b, alpha, cu_a, cu_b, beta, cu_c,
cu_type, cu_alg, workspace_ptr);
check_status(status, __func__);
CUDA_ERROR_FUNC(cuStreamSynchronize, cu_stream);
auto functor_usm = [=](CusparseScopedContextHandler& sc) {
compute_functor(sc, workspace_ptr);
};
return dispatch_submit(__func__, queue, dependencies, functor, A_handle, B_handle,
return dispatch_submit(__func__, queue, dependencies, functor_usm, A_handle, B_handle,
C_handle);
}
}
Expand Down
55 changes: 24 additions & 31 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,36 +195,8 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
if (A_handle->all_use_buffer() != spmv_descr->workspace.use_buffer()) {
detail::throw_incompatible_container(__func__);
}
if (A_handle->all_use_buffer() && spmv_descr->temp_buffer_size > 0) {
// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
auto functor = [=](CusparseScopedContextHandler &sc,
sycl::accessor<std::uint8_t> workspace_acc) {
auto [cu_handle, cu_stream] = sc.get_handle_and_stream(queue);
auto workspace_ptr = sc.get_mem(workspace_acc);
auto cu_a = A_handle->backend_handle;
auto cu_x = x_handle->backend_handle;
auto cu_y = y_handle->backend_handle;
auto type = A_handle->value_container.data_type;
auto cu_op = get_cuda_operation(type, opA);
auto cu_type = get_cuda_value_type(type);
auto cu_alg = get_cuda_spmv_alg(alg);
auto status = cusparseSpMV(cu_handle, cu_op, alpha, cu_a, cu_x, beta, cu_y, cu_type,
cu_alg, workspace_ptr);
check_status(status, __func__);
CUDA_ERROR_FUNC(cuStreamSynchronize, cu_stream);
};
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(
spmv_descr->workspace.get_buffer<std::uint8_t>());
return dispatch_submit<true>(__func__, queue, dependencies, functor, A_handle,
workspace_placeholder_acc, x_handle, y_handle);
}
else {
// The same dispatch_submit can be used for USM or buffers if no
// workspace accessor is needed, workspace_ptr will be a nullptr in the
// latter case.
auto workspace_ptr = spmv_descr->workspace.usm_ptr;
auto functor = [=](CusparseScopedContextHandler &sc) {
auto compute_functor =
[=](CusparseScopedContextHandler &sc, void *workspace_ptr) {
auto [cu_handle, cu_stream] = sc.get_handle_and_stream(queue);
auto cu_a = A_handle->backend_handle;
auto cu_x = x_handle->backend_handle;
Expand All @@ -249,7 +221,28 @@ sycl::event spmv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
check_status(status, __func__);
CUDA_ERROR_FUNC(cuStreamSynchronize, cu_stream);
};
return dispatch_submit(__func__, queue, dependencies, functor, A_handle, x_handle,
if (A_handle->all_use_buffer() && spmv_descr->temp_buffer_size > 0) {
// The accessor can only be bound to the cgh if the buffer size is
// greater than 0
auto functor_buffer = [=](CusparseScopedContextHandler &sc,
sycl::accessor<std::uint8_t> workspace_acc) {
auto workspace_ptr = sc.get_mem(workspace_acc);
compute_functor(sc, workspace_ptr);
};
sycl::accessor<std::uint8_t, 1> workspace_placeholder_acc(
spmv_descr->workspace.get_buffer<std::uint8_t>());
return dispatch_submit<true>(__func__, queue, dependencies, functor_buffer, A_handle,
workspace_placeholder_acc, x_handle, y_handle);
}
else {
// The same dispatch_submit can be used for USM or buffers if no
// workspace accessor is needed, workspace_ptr will be a nullptr in the
// latter case.
auto workspace_ptr = spmv_descr->workspace.usm_ptr;
auto functor_usm = [=](CusparseScopedContextHandler &sc) {
compute_functor(sc, workspace_ptr);
};
return dispatch_submit(__func__, queue, dependencies, functor_usm, A_handle, x_handle,
y_handle);
}
}
Expand Down

0 comments on commit 85eefd0

Please sign in to comment.