Skip to content

Commit

Permalink
Introduce sorted_by_rows property
Browse files Browse the repository at this point in the history
  • Loading branch information
Rbiessy committed Sep 23, 2024
1 parent 77d19e8 commit d23b24d
Show file tree
Hide file tree
Showing 19 changed files with 239 additions and 183 deletions.
9 changes: 6 additions & 3 deletions docs/domains/sparse_linear_algebra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,17 @@ cuSPARSE backend

Currently known limitations:

- The COO format requires the indices to be sorted by row. See the `cuSPARSE
documentation
<https://docs.nvidia.com/cuda/cusparse/index.html#coordinate-coo>`_. Sparse
operations using matrices with the COO format without the property
``matrix_property::sorted_by_rows`` or ``matrix_property::sorted`` will throw
an ``oneapi::mkl::unimplemented`` exception.
- Using ``spmm`` with the algorithm ``spmm_alg::csr_alg3`` and an ``opA`` other
than ``transpose::nontrans`` or an ``opB`` ``transpose::conjtrans`` will throw
an ``oneapi::mkl::unimplemented`` exception.
- Using ``spmv`` with a ``type_view`` other than ``matrix_descr::general`` will
throw an ``oneapi::mkl::unimplemented`` exception.
- The COO format requires the indices to be sorted by row. See the `cuSPARSE
documentation
<https://docs.nvidia.com/cuda/cusparse/index.html#coordinate-coo>`_.


Operation algorithms mapping
Expand Down
1 change: 1 addition & 0 deletions include/oneapi/mkl/sparse_blas/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace sparse {
enum class matrix_property {
symmetric,
sorted,
sorted_by_rows,
};

enum class spmm_alg {
Expand Down
16 changes: 8 additions & 8 deletions src/sparse_blas/backends/cusparse/cusparse_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ void init_coo_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64
CUSPARSE_ERR_FUNC(cusparseCreateCoo, &cu_smhandle, num_rows, num_cols, nnz,
sc.get_mem(row_acc), sc.get_mem(col_acc), sc.get_mem(val_acc),
cuda_index_type, cuda_index_base, cuda_value_type);
*p_smhandle = new matrix_handle(cu_smhandle, row_ind, col_ind, val, num_rows, num_cols,
nnz, index);
*p_smhandle = new matrix_handle(cu_smhandle, row_ind, col_ind, val, detail::sparse_format::COO,
num_rows, num_cols, nnz, index);
});
});
event.wait_and_throw();
Expand All @@ -284,8 +284,8 @@ void init_coo_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64
cusparseSpMatDescr_t cu_smhandle;
CUSPARSE_ERR_FUNC(cusparseCreateCoo, &cu_smhandle, num_rows, num_cols, nnz, row_ind,
col_ind, val, cuda_index_type, cuda_index_base, cuda_value_type);
*p_smhandle = new matrix_handle(cu_smhandle, row_ind, col_ind, val, num_rows, num_cols,
nnz, index);
*p_smhandle = new matrix_handle(cu_smhandle, row_ind, col_ind, val, detail::sparse_format::COO,
num_rows, num_cols, nnz, index);
});
});
event.wait_and_throw();
Expand Down Expand Up @@ -388,8 +388,8 @@ void init_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64
CUSPARSE_ERR_FUNC(cusparseCreateCsr, &cu_smhandle, num_rows, num_cols, nnz,
sc.get_mem(row_acc), sc.get_mem(col_acc), sc.get_mem(val_acc),
cuda_index_type, cuda_index_type, cuda_index_base, cuda_value_type);
*p_smhandle = new matrix_handle(cu_smhandle, row_ptr, col_ind, val, num_rows, num_cols,
nnz, index);
*p_smhandle = new matrix_handle(cu_smhandle, row_ptr, col_ind, val, detail::sparse_format::CSR,
num_rows, num_cols, nnz, index);
});
});
event.wait_and_throw();
Expand All @@ -410,8 +410,8 @@ void init_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64
CUSPARSE_ERR_FUNC(cusparseCreateCsr, &cu_smhandle, num_rows, num_cols, nnz, row_ptr,
col_ind, val, cuda_index_type, cuda_index_type, cuda_index_base,
cuda_value_type);
*p_smhandle = new matrix_handle(cu_smhandle, row_ptr, col_ind, val, num_rows, num_cols,
nnz, index);
*p_smhandle = new matrix_handle(cu_smhandle, row_ptr, col_ind, val, detail::sparse_format::CSR,
num_rows, num_cols, nnz, index);
});
});
event.wait_and_throw();
Expand Down
27 changes: 20 additions & 7 deletions src/sparse_blas/backends/cusparse/cusparse_handles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,33 @@ struct dense_matrix_handle : public detail::generic_dense_matrix_handle<cusparse
struct matrix_handle : public detail::generic_sparse_handle<cusparseSpMatDescr_t> {
template <typename fpType, typename intType>
matrix_handle(cusparseSpMatDescr_t cu_descr, intType* row_ptr, intType* col_ptr,
fpType* value_ptr, std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz,
oneapi::mkl::index_base index)
fpType* value_ptr, detail::sparse_format format, std::int64_t num_rows,
std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index)
: detail::generic_sparse_handle<cusparseSpMatDescr_t>(
cu_descr, row_ptr, col_ptr, value_ptr, num_rows, num_cols, nnz, index) {}
cu_descr, row_ptr, col_ptr, value_ptr, format, num_rows, num_cols, nnz, index) {}

template <typename fpType, typename intType>
matrix_handle(cusparseSpMatDescr_t cu_descr, const sycl::buffer<intType, 1> row_buffer,
const sycl::buffer<intType, 1> col_buffer,
const sycl::buffer<fpType, 1> value_buffer, std::int64_t num_rows,
std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index)
: detail::generic_sparse_handle<cusparseSpMatDescr_t>(
cu_descr, row_buffer, col_buffer, value_buffer, num_rows, num_cols, nnz, index) {}
const sycl::buffer<fpType, 1> value_buffer, detail::sparse_format format,
std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz,
oneapi::mkl::index_base index)
: detail::generic_sparse_handle<cusparseSpMatDescr_t>(cu_descr, row_buffer, col_buffer,
value_buffer, format, num_rows,
num_cols, nnz, index) {}
};

inline void check_valid_matrix_properties(const std::string& function_name,
matrix_handle_t sm_handle) {
if (sm_handle->format == detail::sparse_format::COO &&
!(sm_handle->has_matrix_property(matrix_property::sorted_by_rows) ||
sm_handle->has_matrix_property(matrix_property::sorted))) {
throw mkl::unimplemented(
"sparse_blas", function_name,
"The backend does not support unsorted COO format. Use `set_matrix_property` to set the property `matrix_property::sorted_by_rows` or `matrix_property::sorted`");
}
}

} // namespace oneapi::mkl::sparse

#endif // _ONEMKL_SRC_SPARSE_BLAS_BACKENDS_CUSPARSE_HANDLES_HPP_
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ void check_valid_spmm(const std::string& function_name, oneapi::mkl::transpose o
bool is_alpha_host_accessible, bool is_beta_host_accessible, spmm_alg alg) {
detail::check_valid_spmm_common(function_name, A_view, A_handle, B_handle, C_handle,
is_alpha_host_accessible, is_beta_host_accessible);
check_valid_matrix_properties(function_name, A_handle);
if (alg == spmm_alg::csr_alg3 && opA != oneapi::mkl::transpose::nontrans) {
throw mkl::unimplemented(
"sparse_blas", function_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void check_valid_spmv(const std::string &function_name, oneapi::mkl::transpose o
bool is_beta_host_accessible) {
detail::check_valid_spmv_common(function_name, opA, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible, is_beta_host_accessible);
check_valid_matrix_properties(function_name, A_handle);
if (A_view.type_view != matrix_descr::general) {
throw mkl::unimplemented(
"sparse_blas", function_name,
Expand Down
18 changes: 12 additions & 6 deletions src/sparse_blas/backends/cusparse/operations/cusparse_spsv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,20 @@ inline auto get_cuda_spsv_alg(spsv_alg /*alg*/) {
return CUSPARSE_SPSV_ALG_DEFAULT;
}

void check_valid_spsv(const std::string &function_name, matrix_view A_view,
matrix_handle_t A_handle, dense_vector_handle_t x_handle,
dense_vector_handle_t y_handle, bool is_alpha_host_accessible) {
detail::check_valid_spsv_common(function_name, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
check_valid_matrix_properties(function_name, A_handle);
}

void spsv_buffer_size(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alpha,
matrix_view A_view, matrix_handle_t A_handle, dense_vector_handle_t x_handle,
dense_vector_handle_t y_handle, spsv_alg alg, spsv_descr_t spsv_descr,
std::size_t &temp_buffer_size) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
detail::check_valid_spsv_common(__func__, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible);
auto functor = [=, &temp_buffer_size](CusparseScopedContextHandler &sc) {
auto cu_handle = sc.get_handle(queue);
auto cu_a = A_handle->backend_handle;
Expand All @@ -101,8 +108,8 @@ inline void common_spsv_optimize(oneapi::mkl::transpose opA, bool is_alpha_host_
matrix_view A_view, matrix_handle_t A_handle,
dense_vector_handle_t x_handle, dense_vector_handle_t y_handle,
spsv_alg alg, spsv_descr_t spsv_descr) {
detail::check_valid_spsv_common("spsv_optimize", A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
check_valid_spsv("spsv_optimize", A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
if (!spsv_descr->buffer_size_called) {
throw mkl::uninitialized("sparse_blas", "spsv_optimize",
"spsv_buffer_size must be called before spsv_optimize.");
Expand Down Expand Up @@ -202,8 +209,7 @@ sycl::event spsv(sycl::queue &queue, oneapi::mkl::transpose opA, const void *alp
dense_vector_handle_t y_handle, spsv_alg alg, spsv_descr_t spsv_descr,
const std::vector<sycl::event> &dependencies) {
bool is_alpha_host_accessible = detail::is_ptr_accessible_on_host(queue, alpha);
detail::check_valid_spsv_common(__func__, A_view, A_handle, x_handle, y_handle,
is_alpha_host_accessible);
check_valid_spsv(__func__, A_view, A_handle, x_handle, y_handle, is_alpha_host_accessible);
if (A_handle->all_use_buffer() != spsv_descr->workspace.use_buffer()) {
detail::throw_incompatible_container(__func__);
}
Expand Down
16 changes: 8 additions & 8 deletions src/sparse_blas/backends/mkl_common/mkl_handles.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ void init_coo_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p
sycl::buffer<intType, 1> col_ind, sycl::buffer<fpType, 1> val) {
oneapi::mkl::sparse::matrix_handle_t mkl_handle;
oneapi::mkl::sparse::init_matrix_handle(&mkl_handle);
auto internal_smhandle = new detail::sparse_matrix_handle(mkl_handle, row_ind, col_ind, val,
num_rows, num_cols, nnz, index);
auto internal_smhandle = new detail::sparse_matrix_handle(
mkl_handle, row_ind, col_ind, val, detail::sparse_format::COO, num_rows, num_cols, nnz, index);
// The backend handle must use the buffers from the internal handle as they will be kept alive until the handle is released.
oneapi::mkl::sparse::set_coo_data(queue, mkl_handle, static_cast<intType>(num_rows),
static_cast<intType>(num_cols), static_cast<intType>(nnz),
Expand All @@ -127,8 +127,8 @@ void init_coo_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p
fpType *val) {
oneapi::mkl::sparse::matrix_handle_t mkl_handle;
oneapi::mkl::sparse::init_matrix_handle(&mkl_handle);
auto internal_smhandle = new detail::sparse_matrix_handle(mkl_handle, row_ind, col_ind, val,
num_rows, num_cols, nnz, index);
auto internal_smhandle = new detail::sparse_matrix_handle(
mkl_handle, row_ind, col_ind, val, detail::sparse_format::COO, num_rows, num_cols, nnz, index);
auto event = oneapi::mkl::sparse::set_coo_data(
queue, mkl_handle, static_cast<intType>(num_rows), static_cast<intType>(num_cols),
static_cast<intType>(nnz), index, row_ind, col_ind, val);
Expand Down Expand Up @@ -189,8 +189,8 @@ void init_csr_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p
sycl::buffer<intType, 1> col_ind, sycl::buffer<fpType, 1> val) {
oneapi::mkl::sparse::matrix_handle_t mkl_handle;
oneapi::mkl::sparse::init_matrix_handle(&mkl_handle);
auto internal_smhandle = new detail::sparse_matrix_handle(mkl_handle, row_ptr, col_ind, val,
num_rows, num_cols, nnz, index);
auto internal_smhandle = new detail::sparse_matrix_handle(
mkl_handle, row_ptr, col_ind, val, detail::sparse_format::CSR, num_rows, num_cols, nnz, index);
// The backend deduces nnz from row_ptr.
// The backend handle must use the buffers from the internal handle as they will be kept alive until the handle is released.
oneapi::mkl::sparse::set_csr_data(queue, mkl_handle, static_cast<intType>(num_rows),
Expand All @@ -208,8 +208,8 @@ void init_csr_matrix(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p
fpType *val) {
oneapi::mkl::sparse::matrix_handle_t mkl_handle;
oneapi::mkl::sparse::init_matrix_handle(&mkl_handle);
auto internal_smhandle = new detail::sparse_matrix_handle(mkl_handle, row_ptr, col_ind, val,
num_rows, num_cols, nnz, index);
auto internal_smhandle = new detail::sparse_matrix_handle(
mkl_handle, row_ptr, col_ind, val, detail::sparse_format::CSR, num_rows, num_cols, nnz, index);
// The backend deduces nnz from row_ptr.
auto event = oneapi::mkl::sparse::set_csr_data(
queue, mkl_handle, static_cast<intType>(num_rows), static_cast<intType>(num_cols), index,
Expand Down
20 changes: 16 additions & 4 deletions src/sparse_blas/generic_container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ struct generic_dense_matrix_handle : public detail::generic_dense_handle<Backend
}
};

enum class sparse_format { CSR, COO };

/// Generic sparse_matrix_handle used by all backends
template <typename BackendHandleT>
struct generic_sparse_handle {
Expand All @@ -214,6 +216,7 @@ struct generic_sparse_handle {
generic_container col_container;
generic_container value_container;

sparse_format format;
std::int64_t num_rows;
std::int64_t num_cols;
std::int64_t nnz;
Expand All @@ -223,12 +226,13 @@ struct generic_sparse_handle {

template <typename fpType, typename intType>
generic_sparse_handle(BackendHandleT backend_handle, intType* row_ptr, intType* col_ptr,
fpType* value_ptr, std::int64_t num_rows, std::int64_t num_cols,
std::int64_t nnz, oneapi::mkl::index_base index)
fpType* value_ptr, sparse_format format, std::int64_t num_rows,
std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index)
: backend_handle(backend_handle),
row_container(generic_container(row_ptr)),
col_container(generic_container(col_ptr)),
value_container(generic_container(value_ptr)),
format(format),
num_rows(num_rows),
num_cols(num_cols),
nnz(nnz),
Expand All @@ -239,12 +243,14 @@ struct generic_sparse_handle {
template <typename fpType, typename intType>
generic_sparse_handle(BackendHandleT backend_handle, const sycl::buffer<intType, 1> row_buffer,
const sycl::buffer<intType, 1> col_buffer,
const sycl::buffer<fpType, 1> value_buffer, std::int64_t num_rows,
std::int64_t num_cols, std::int64_t nnz, oneapi::mkl::index_base index)
const sycl::buffer<fpType, 1> value_buffer, sparse_format format,
std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz,
oneapi::mkl::index_base index)
: backend_handle(backend_handle),
row_container(row_buffer),
col_container(col_buffer),
value_container(value_buffer),
format(format),
num_rows(num_rows),
num_cols(num_cols),
nnz(nnz),
Expand All @@ -266,6 +272,11 @@ struct generic_sparse_handle {
}

void set_matrix_property(matrix_property property) {
if (format == sparse_format::CSR && property == matrix_property::sorted_by_rows) {
throw mkl::invalid_argument(
"sparse_blas", "set_matrix_property",
"Property `matrix_property::sorted_by_rows` is not compatible with CSR format.");
}
properties_mask |= matrix_property_to_mask(property);
}

Expand All @@ -278,6 +289,7 @@ struct generic_sparse_handle {
switch (property) {
case matrix_property::symmetric: return 1 << 0;
case matrix_property::sorted: return 1 << 1;
case matrix_property::sorted_by_rows: return 1 << 2;
default:
throw oneapi::mkl::invalid_argument(
"sparse_blas", "set_matrix_property",
Expand Down
Loading

0 comments on commit d23b24d

Please sign in to comment.