Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Support the migration of cusparse<T>csrgemm2 related API #2643

Open
wants to merge 11 commits into
base: SYCLomatic
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,41 @@ struct csrgemm_args_info_hash {
return std::hash<std::string>{}(ss.str());
}
};

using csrgemm2_args_info =
std::tuple<int, int, int,
const std::shared_ptr<matrix_info>, const void *, const int *,
const int *,
const std::shared_ptr<matrix_info>, const void *, const int *,
const int *,
const std::shared_ptr<matrix_info>, const void *, const int *,
const int *,
const std::shared_ptr<matrix_info>,
const int *>;
struct csrgemm2_args_info_hash {
size_t operator()(const csrgemm2_args_info &args) const {
std::stringstream ss;
ss << std::get<0>(args) << ":";
ss << std::get<1>(args) << ":";
ss << std::get<2>(args) << ":";
ss << std::get<3>(args).get() << ":";
ss << std::get<4>(args) << ":";
ss << std::get<5>(args) << ":";
ss << std::get<6>(args) << ":";
ss << std::get<7>(args).get() << ":";
ss << std::get<8>(args) << ":";
ss << std::get<9>(args) << ":";
ss << std::get<10>(args) << ":";
ss << std::get<11>(args).get() << ":";
ss << std::get<12>(args) << ":";
ss << std::get<12>(args) << ":";
ss << std::get<14>(args) << ":";
ss << std::get<15>(args).get() << ":";
ss << std::get<16>(args) << ":";
return std::hash<std::string>{}(ss.str());
}
};

#ifdef __INTEL_MKL__ // The oneMKL Interfaces Project does not support this.
template <typename handle_t> class handle_manager {
public:
Expand All @@ -58,7 +93,7 @@ template <typename handle_t> class handle_manager {
void init(sycl::queue *q) {
_q = q;
_h = new handle_t;
_init_func(_h);
_init_func(*_q, _h);
}
handle_t &get_handle() { return *_h; }
void add_dependency(sycl::event e) { _deps.push_back(e); }
Expand All @@ -68,7 +103,7 @@ template <typename handle_t> class handle_manager {
sycl::queue *_q = nullptr;

private:
using init_func_t = std::function<void(handle_t *)>;
using init_func_t = std::function<void(sycl::queue &, handle_t *)>;
using rel_func_t = std::function<sycl::event(
sycl::queue &, handle_t *, const std::vector<sycl::event> &dependencies)>;
handle_t *_h = nullptr;
Expand All @@ -79,16 +114,33 @@ template <typename handle_t> class handle_manager {
template <>
inline handle_manager<oneapi::mkl::sparse::matrix_handle_t>::init_func_t
handle_manager<oneapi::mkl::sparse::matrix_handle_t>::_init_func =
oneapi::mkl::sparse::init_matrix_handle;
[](sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p_desc) {
oneapi::mkl::sparse::init_matrix_handle(p_desc);
};
template <>
inline handle_manager<oneapi::mkl::sparse::matrix_handle_t>::rel_func_t
handle_manager<oneapi::mkl::sparse::matrix_handle_t>::_rel_func =
oneapi::mkl::sparse::release_matrix_handle;

template <>
inline handle_manager<oneapi::mkl::sparse::omatadd_descr_t>::init_func_t
handle_manager<oneapi::mkl::sparse::omatadd_descr_t>::_init_func =
oneapi::mkl::sparse::init_omatadd_descr;
template <>
inline handle_manager<oneapi::mkl::sparse::omatadd_descr_t>::rel_func_t
handle_manager<oneapi::mkl::sparse::omatadd_descr_t>::_rel_func =
[](sycl::queue &queue, oneapi::mkl::sparse::omatadd_descr_t *p_desc,
const std::vector<sycl::event> &dependencies) -> sycl::event {
return oneapi::mkl::sparse::release_omatadd_descr(queue, *p_desc,
dependencies);
};

template <>
inline handle_manager<oneapi::mkl::sparse::matmat_descr_t>::init_func_t
handle_manager<oneapi::mkl::sparse::matmat_descr_t>::_init_func =
oneapi::mkl::sparse::init_matmat_descr;
[](sycl::queue &queue, oneapi::mkl::sparse::matmat_descr_t *p_desc) {
oneapi::mkl::sparse::init_matmat_descr(p_desc);
};
template <>
inline handle_manager<oneapi::mkl::sparse::matmat_descr_t>::rel_func_t
handle_manager<oneapi::mkl::sparse::matmat_descr_t>::_rel_func =
Expand Down
231 changes: 231 additions & 0 deletions clang/runtime/dpct-rt/include/dpct/sparse_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,29 @@ class descriptor {
detail::csrgemm_args_info_hash>
_csrgemm_info_map;

struct matmat2_info {
detail::matrix_handle_manager matrix_handle_c1;
detail::matrix_handle_manager matrix_handle_d;
detail::matrix_handle_manager matrix_handle_c;
detail::handle_manager<oneapi::mkl::sparse::omatadd_descr_t> omatadd_desc;
bool is_empty() { return omatadd_desc.is_empty(); }
void init(sycl::queue *q_ptr) {
omatadd_desc.init(q_ptr);
matrix_handle_c1.init(q_ptr);
matrix_handle_d.init(q_ptr);
matrix_handle_c.init(q_ptr);
}
};
std::unordered_map<detail::csrgemm2_args_info, matmat2_info,
detail::csrgemm2_args_info_hash> &
get_csrgemm2_info_map() {
return _csrgemm2_info_map;
}

std::unordered_map<detail::csrgemm2_args_info, matmat2_info,
detail::csrgemm2_args_info_hash>
_csrgemm2_info_map;

template <typename T>
friend void
csrgemm_nnz(descriptor *desc, oneapi::mkl::transpose trans_a,
Expand All @@ -578,6 +601,29 @@ class descriptor {
const int *row_ptr_b, const int *col_ind_b,
const std::shared_ptr<matrix_info> info_c, T *val_c,
const int *row_ptr_c, int *col_ind_c);

template <typename T>
friend void
csrgemm2_nnz(descriptor *desc, int m, int n, int k,
const std::shared_ptr<matrix_info> info_a, int nnz_a,
const T *val_a, const int *row_ptr_a, const int *col_ind_a,
const std::shared_ptr<matrix_info> info_b, int nnz_b,
const T *val_b, const int *row_ptr_b, const int *col_ind_b,
const std::shared_ptr<matrix_info> info_d, int nnz_d,
const T *val_d, const int *row_ptr_d, const int *col_ind_d,
const std::shared_ptr<matrix_info> info_c, int *row_ptr_c,
int *nnz_ptr);
template <typename T>
friend void csrgemm2(descriptor *desc, int m, int n, int k, const T *alpha,
const std::shared_ptr<matrix_info> info_a, const T *val_a,
const int *row_ptr_a, const int *col_ind_a,
const std::shared_ptr<matrix_info> info_b, const T *val_b,
const int *row_ptr_b, const int *col_ind_b,
const T *beta,
const std::shared_ptr<matrix_info> info_d, const T *val_d,
const int *row_ptr_d, const int *col_ind_d,
const std::shared_ptr<matrix_info> info_c, T *val_c,
const int *row_ptr_c, int *col_ind_c);
#endif
::dpct::cs::queue_ptr _queue_ptr = &::dpct::cs::get_default_queue();
};
Expand Down Expand Up @@ -1354,6 +1400,191 @@ void csrgemm(descriptor_ptr desc, oneapi::mkl::transpose trans_a,
info.matmat_desc.add_dependency(e);
desc->get_csrgemm_info_map().erase(args);
}

template <typename T>
void csrgemm2_nnz(descriptor_ptr desc, int m, int n, int k,
const std::shared_ptr<matrix_info> info_a, int nnz_a,
const T *val_a, const int *row_ptr_a, const int *col_ind_a,
const std::shared_ptr<matrix_info> info_b, int nnz_b,
const T *val_b, const int *row_ptr_b, const int *col_ind_b,
const std::shared_ptr<matrix_info> info_d, int nnz_d,
const T *val_d, const int *row_ptr_d, const int *col_ind_d,
const std::shared_ptr<matrix_info> info_c, int *row_ptr_c,
int *nnz_ptr) {
using Ty = typename ::dpct::detail::lib_data_traits_t<T>;
sycl::queue &queue = desc->get_queue();
detail::csrgemm2_args_info args(
m, n, k, info_a, val_a, row_ptr_a, col_ind_a, info_b, val_b, row_ptr_b,
col_ind_b, info_d, val_d, row_ptr_d, col_ind_d, info_c, row_ptr_c);
auto &info = desc->get_csrgemm2_info_map()[args];

int *row_ptr_c1 = (int *)::dpct::cs::malloc((m + 1) * sizeof(int), queue);
if (info.is_empty()) {
info.init(&queue);
info.matrix_handle_c1.set_matrix_data<Ty>(
m, n, oneapi::mkl::index_base::zero, row_ptr_c1, nullptr, nullptr);
info.matrix_handle_d.set_matrix_data<Ty>(m, n, info_d->get_index_base(),
row_ptr_d, col_ind_d, val_d);
// In the future, oneMKL will allow nullptr to be passed in for row_ptr_c in
// the initial calls before matmat. But currently, it needs an array of
// length row_number + 1.
info.matrix_handle_c.set_matrix_data<Ty>(m, n, info_c->get_index_base(),
row_ptr_c, nullptr, nullptr);
}

oneapi::mkl::sparse::matrix_handle_t a = nullptr;
oneapi::mkl::sparse::init_matrix_handle(&a);
auto data_row_ptr_a = dpct::detail::get_memory<int>(row_ptr_a);
auto data_col_ind_a = dpct::detail::get_memory<int>(col_ind_a);
auto data_val_a = dpct::detail::get_memory<Ty>(val_a);
oneapi::mkl::sparse::set_csr_data(queue, a, m, k, info_a->get_index_base(),
data_row_ptr_a, data_col_ind_a, data_val_a);

oneapi::mkl::sparse::matrix_handle_t b = nullptr;
oneapi::mkl::sparse::init_matrix_handle(&b);
auto data_row_ptr_b = dpct::detail::get_memory<int>(row_ptr_a);
auto data_col_ind_b = dpct::detail::get_memory<int>(col_ind_a);
auto data_val_b = dpct::detail::get_memory<Ty>(val_a);
oneapi::mkl::sparse::set_csr_data(queue, b, k, n, info_b->get_index_base(),
data_row_ptr_b, data_col_ind_b, data_val_b);

oneapi::mkl::sparse::matmat_descr_t matmat_desc = nullptr;
oneapi::mkl::sparse::init_matmat_descr(&matmat_desc);
oneapi::mkl::sparse::set_matmat_data(
matmat_desc, oneapi::mkl::sparse::matrix_view_descr::general,
oneapi::mkl::transpose::nontrans,
oneapi::mkl::sparse::matrix_view_descr::general,
oneapi::mkl::transpose::nontrans,
oneapi::mkl::sparse::matrix_view_descr::general);

#ifdef DPCT_USM_LEVEL_NONE
#define __MATMAT(STEP, NNZ_C1) \
oneapi::mkl::sparse::matmat(queue, a, b, info.matrix_handle_c1.get_handle(), \
STEP, matmat_desc, NNZ_C1, nullptr)
#else
#define __MATMAT(STEP, NNZ_C1) \
oneapi::mkl::sparse::matmat(queue, a, b, info.matrix_handle_c1.get_handle(), \
STEP, matmat_desc, NNZ_C1, nullptr, {})
#endif

__MATMAT(oneapi::mkl::sparse::matmat_request::work_estimation, nullptr);
queue.wait();

__MATMAT(oneapi::mkl::sparse::matmat_request::compute, nullptr);

int nnz_c1_int = 0;
#ifdef DPCT_USM_LEVEL_NONE
sycl::buffer<std::int64_t, 1> nnz_buf_c1(1);
__MATMAT(oneapi::mkl::sparse::matmat_request::get_nnz, &nnz_buf_c1);
nnz_c1_int = nnz_buf_c.get_host_access(sycl::read_only)[0];
#else
std::int64_t *nnz_c1 = sycl::malloc_host<std::int64_t>(1, queue);
__MATMAT(oneapi::mkl::sparse::matmat_request::get_nnz, nnz_c1);
queue.wait();
nnz_c1_int = *nnz_c1;
#endif

int *col_ind_c1 = (int *)::dpct::cs::malloc(nnz_c1_int * sizeof(int), queue);
Ty *val_c1 = (Ty *)::dpct::cs::malloc(nnz_c1_int * sizeof(Ty), queue);
info.matrix_handle_c1.set_matrix_data<Ty>(m, n, oneapi::mkl::index_base::zero,
row_ptr_c1, col_ind_c1, val_c1);

__MATMAT(oneapi::mkl::sparse::matmat_request::finalize, nullptr);
#undef __MATMAT

queue.wait();
oneapi::mkl::sparse::release_matmat_descr(&matmat_desc);

std::int64_t ws_size = 0;
oneapi::mkl::sparse::omatadd_buffer_size(
queue, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
info.matrix_handle_c1.get_handle(), info.matrix_handle_d.get_handle(),
info.matrix_handle_c.get_handle(),
oneapi::mkl::sparse::omatadd_alg::default_alg,
info.omatadd_desc.get_handle(), ws_size);

void *ws = ::dpct::cs::malloc(ws_size, queue);

oneapi::mkl::sparse::omatadd_analyze(
queue, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
info.matrix_handle_c1.get_handle(), info.matrix_handle_d.get_handle(),
info.matrix_handle_c.get_handle(),
oneapi::mkl::sparse::omatadd_alg::default_alg,
info.omatadd_desc.get_handle(), ws);

std::int64_t c_nnz = 0;
oneapi::mkl::sparse::omatadd_get_nnz(
queue, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
info.matrix_handle_c1.get_handle(), info.matrix_handle_d.get_handle(),
info.matrix_handle_c.get_handle(),
oneapi::mkl::sparse::omatadd_alg::default_alg,
info.omatadd_desc.get_handle(), c_nnz);

int c_nnz_int = c_nnz;
if (nnz_ptr) {
::dpct::cs::memcpy(::dpct::cs::get_default_queue(), nnz_ptr, &c_nnz_int,
sizeof(int))
.wait();
}
if (info_c->get_index_base() == oneapi::mkl::index_base::one) {
c_nnz_int++;
}
::dpct::cs::memcpy(::dpct::cs::get_default_queue(), row_ptr_c + m, &c_nnz_int,
sizeof(int))
.wait();

oneapi::mkl::sparse::release_matmat_descr(&matmat_desc);
oneapi::mkl::sparse::release_matrix_handle(queue, &a);
oneapi::mkl::sparse::release_matrix_handle(queue, &b);
queue.wait();
}

template <typename T>
void csrgemm2(descriptor_ptr desc, int m, int n, int k, const T *alpha,
const std::shared_ptr<matrix_info> info_a, const T *val_a,
const int *row_ptr_a, const int *col_ind_a,
const std::shared_ptr<matrix_info> info_b, const T *val_b,
const int *row_ptr_b, const int *col_ind_b, const T *beta,
const std::shared_ptr<matrix_info> info_d, const T *val_d,
const int *row_ptr_d, const int *col_ind_d,
const std::shared_ptr<matrix_info> info_c, T *val_c,
const int *row_ptr_c, int *col_ind_c) {
using Ty = typename ::dpct::detail::lib_data_traits_t<T>;
sycl::queue &queue = desc->get_queue();
auto alpha_value =
dpct::detail::get_value(reinterpret_cast<const Ty *>(alpha), queue);
auto beta_value =
dpct::detail::get_value(reinterpret_cast<const Ty *>(beta), queue);

detail::csrgemm2_args_info args(
m, n, k, info_a, val_a, row_ptr_a, col_ind_a, info_b, val_b, row_ptr_b,
col_ind_b, info_d, val_d, row_ptr_d, col_ind_d, info_c, row_ptr_c);
auto &info = desc->get_csrgemm2_info_map()[args];
if (info.is_empty()) {
throw std::runtime_error("csrgemm2_nnz is not invoked previously.");
}

info.matrix_handle_c.set_matrix_data<Ty>(m, n, info_c->get_index_base(),
row_ptr_c, col_ind_c, val_c);
sycl::event e;
#ifndef DPCT_USM_LEVEL_NONE
e =
#endif
oneapi::mkl::sparse::omatadd(
queue, oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans, alpha_value,
info.matrix_handle_c1.get_handle(), beta_value,
info.matrix_handle_d.get_handle(), info.matrix_handle_c.get_handle(),
oneapi::mkl::sparse::omatadd_alg::default_alg,
info.omatadd_desc.get_handle());

info.matrix_handle_c1.add_dependency(e);
info.matrix_handle_d.add_dependency(e);
info.matrix_handle_c.add_dependency(e);
info.omatadd_desc.add_dependency(e);
desc->get_csrgemm2_info_map().erase(args);
}

#endif
} // namespace dpct::sparse

Expand Down
Loading