Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rocblas] Use enqueue_native_command ext if available #581

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions src/blas/backends/rocblas/rocblas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ inline void copy_batch(Func func, sycl::queue &queue, int64_t n, sycl::buffer<T,
auto x_ = sc.get_mem<rocDataType *>(x_acc);
auto y_ = sc.get_mem<rocDataType *>(y_acc);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, x_, incx, stridex, y_, incy, stridey,
rocblas_native_func(func, err, handle, n, x_, incx, stridex, y_, incy, stridey,
batch_size);
});
});
Expand Down Expand Up @@ -123,7 +123,7 @@ inline void axpy_batch(Func func, sycl::queue &queue, int64_t n, T alpha, sycl::
auto x_ = sc.get_mem<rocDataType *>(x_acc);
auto y_ = sc.get_mem<rocDataType *>(y_acc);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, (rocDataType *)&alpha, x_, incx, stridex,
rocblas_native_func(func, err, handle, n, (rocDataType *)&alpha, x_, incx, stridex,
y_, incy, stridey, batch_size);
});
});
Expand Down Expand Up @@ -163,7 +163,7 @@ inline void gemv_batch(Func func, sycl::queue &queue, transpose trans, int64_t m
auto x_ = sc.get_mem<const rocDataType *>(x_acc);
auto y_ = sc.get_mem<rocDataType *>(y_acc);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans), m, n,
rocblas_native_func(func, err, handle, get_rocblas_operation(trans), m, n,
(rocDataType *)&alpha, a_, lda, stridea, x_, incx, stridex,
(rocDataType *)&beta, y_, incy, stridey, batch_size);
});
Expand Down Expand Up @@ -205,7 +205,7 @@ inline void dgmm_batch(Func func, sycl::queue &queue, side left_right, int64_t m
auto x_ = sc.get_mem<const rocDataType *>(x_acc);
auto c_ = sc.get_mem<rocDataType *>(c_acc);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), m, n, a_,
rocblas_native_func(func, err, handle, get_rocblas_side_mode(left_right), m, n, a_,
lda, stridea, x_, incx, stridex, c_, ldc, stridec, batch_size);
});
});
Expand Down Expand Up @@ -253,7 +253,7 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran
auto c_ = sc.get_mem<rocTypeC *>(c_acc);

rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(rocblas_gemm_strided_batched_ex, err, handle,
rocblas_native_func(rocblas_gemm_strided_batched_ex, err, handle,
get_rocblas_operation(transa), get_rocblas_operation(transb), m,
n, k, &alpha, a_, get_rocblas_datatype<rocTypeA>(), lda,
stridea, b_, get_rocblas_datatype<rocTypeB>(), ldb, strideb,
Expand Down Expand Up @@ -320,7 +320,7 @@ inline void trsm_batch(Func func, sycl::queue &queue, side left_right, uplo uppe
auto a_ = sc.get_mem<const rocDataType *>(a_acc);
auto b_ = sc.get_mem<rocDataType *>(b_acc);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
rocblas_native_func(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, stridea, b_, ldb, strideb,
Expand Down Expand Up @@ -362,7 +362,7 @@ inline void syrk_batch(Func func, sycl::queue &queue, uplo upper_lower, transpos
auto a_ = sc.get_mem<const rocDataType *>(a_acc);
auto c_ = sc.get_mem<rocDataType *>(c_acc);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_fill_mode(upper_lower),
rocblas_native_func(func, err, handle, get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), n, k, (rocDataType *)&alpha, a_,
lda, stridea, (rocDataType *)&beta, c_, ldc, stridec,
batch_size);
Expand Down Expand Up @@ -406,7 +406,7 @@ inline void omatcopy_batch(Func func, sycl::queue &queue, transpose trans, int64
auto a_ = sc.get_mem<const rocDataType *>(a_acc);
auto b_ = sc.get_mem<rocDataType *>(b_acc);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans),
rocblas_native_func(func, err, handle, get_rocblas_operation(trans),
get_rocblas_operation(trans), new_m, new_n,
(rocDataType *)&alpha, a_, lda, stridea, (rocDataType *)&beta,
nullptr, lda, stridea, b_, ldb, strideb, batch_size);
Expand Down Expand Up @@ -474,7 +474,7 @@ inline void omatadd_batch(Func func, sycl::queue &queue, transpose transa, trans
auto b_ = sc.get_mem<const rocDataType *>(b_acc);
auto c_ = sc.get_mem<rocDataType *>(c_acc);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa),
rocblas_native_func(func, err, handle, get_rocblas_operation(transa),
get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_,
lda, stridea, (rocDataType *)&beta, b_, ldb, strideb, c_, ldc,
stridec, batch_size);
Expand Down Expand Up @@ -520,7 +520,7 @@ inline sycl::event copy_batch(Func func, sycl::queue &queue, int64_t *n, const T
for (int64_t i = 0; i < group_count; i++) {
auto **x_ = reinterpret_cast<const rocDataType **>(x);
auto **y_ = reinterpret_cast<rocDataType **>(y);
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, (int)n[i], x_ + offset, (int)incx[i],
rocblas_native_func(func, err, handle, (int)n[i], x_ + offset, (int)incx[i],
y_ + offset, (int)incy[i], (int)group_size[i]);
offset += group_size[i];
}
Expand Down Expand Up @@ -560,7 +560,7 @@ inline sycl::event copy_batch(Func func, sycl::queue &queue, int64_t n, const T
auto x_ = reinterpret_cast<const rocDataType *>(x);
auto y_ = reinterpret_cast<rocDataType *>(y);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, x_, incx, stridex, y_, incy, stridey,
rocblas_native_func(func, err, handle, n, x_, incx, stridex, y_, incy, stridey,
batch_size);
});
});
Expand Down Expand Up @@ -602,7 +602,7 @@ inline sycl::event axpy_batch(Func func, sycl::queue &queue, int64_t *n, T *alph
for (int64_t i = 0; i < group_count; i++) {
auto **x_ = reinterpret_cast<const rocDataType **>(x);
auto **y_ = reinterpret_cast<rocDataType **>(y);
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, (int)n[i], (rocDataType *)&alpha[i],
rocblas_native_func(func, err, handle, (int)n[i], (rocDataType *)&alpha[i],
x_ + offset, (int)incx[i], y_ + offset, (int)incy[i],
(int)group_size[i]);
offset += group_size[i];
Expand Down Expand Up @@ -643,7 +643,7 @@ inline sycl::event axpy_batch(Func func, sycl::queue &queue, int64_t n, T alpha,
auto x_ = reinterpret_cast<const rocDataType *>(x);
auto y_ = reinterpret_cast<rocDataType *>(y);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, (rocDataType *)&alpha, x_, incx, stridex,
rocblas_native_func(func, err, handle, n, (rocDataType *)&alpha, x_, incx, stridex,
y_, incy, stridey, batch_size);
});
});
Expand Down Expand Up @@ -684,7 +684,7 @@ inline sycl::event gemv_batch(Func func, sycl::queue &queue, transpose trans, in
auto x_ = reinterpret_cast<const rocDataType *>(x);
auto y_ = reinterpret_cast<rocDataType *>(y);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans), m, n,
rocblas_native_func(func, err, handle, get_rocblas_operation(trans), m, n,
(rocDataType *)&alpha, a_, lda, stridea, x_, incx, stridex,
(rocDataType *)&beta, y_, incy, stridey, batch_size);
});
Expand Down Expand Up @@ -731,7 +731,7 @@ inline sycl::event gemv_batch(Func func, sycl::queue &queue, transpose *trans, i
auto **a_ = reinterpret_cast<const rocDataType **>(a);
auto **x_ = reinterpret_cast<const rocDataType **>(x);
auto **y_ = reinterpret_cast<rocDataType **>(y);
ROCBLAS_ERROR_FUNC_SYNC(
rocblas_native_func(
func, err, handle, get_rocblas_operation(trans[i]), (int)m[i], (int)n[i],
(rocDataType *)&alpha[i], a_ + offset, (int)lda[i], x_ + offset, (int)incx[i],
(rocDataType *)&beta[i], y_ + offset, (int)incy[i], (int)group_size[i]);
Expand Down Expand Up @@ -776,7 +776,7 @@ inline sycl::event dgmm_batch(Func func, sycl::queue &queue, side left_right, in
auto x_ = reinterpret_cast<const rocDataType *>(x);
auto c_ = reinterpret_cast<rocDataType *>(c);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), m, n, a_,
rocblas_native_func(func, err, handle, get_rocblas_side_mode(left_right), m, n, a_,
lda, stridea, x_, incx, stridex, c_, ldc, stridec, batch_size);
});
});
Expand Down Expand Up @@ -821,7 +821,7 @@ inline sycl::event dgmm_batch(Func func, sycl::queue &queue, side *left_right, i
auto **a_ = reinterpret_cast<const rocDataType **>(a);
auto **x_ = reinterpret_cast<const rocDataType **>(x);
auto **c_ = reinterpret_cast<rocDataType **>(c);
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right[i]),
rocblas_native_func(func, err, handle, get_rocblas_side_mode(left_right[i]),
(int)m[i], (int)n[i], a_ + offset, (int)lda[i], x_ + offset,
(int)incx[i], c_ + offset, (int)ldc[i], (int)group_size[i]);
offset += group_size[i];
Expand Down Expand Up @@ -873,7 +873,7 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra
auto b_ = reinterpret_cast<const rocTypeB *>(b);
auto c_ = reinterpret_cast<rocTypeC *>(c);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(rocblas_gemm_strided_batched_ex, err, handle,
rocblas_native_func(rocblas_gemm_strided_batched_ex, err, handle,
get_rocblas_operation(transa), get_rocblas_operation(transb), m,
n, k, &alpha, a_, get_rocblas_datatype<rocTypeA>(), lda,
stridea, b_, get_rocblas_datatype<rocTypeB>(), ldb, strideb,
Expand Down Expand Up @@ -953,7 +953,7 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr
auto **a_ = reinterpret_cast<const rocTypeA **>(a);
auto **b_ = reinterpret_cast<const rocTypeB **>(b);
auto **c_ = reinterpret_cast<rocTypeC **>(c);
ROCBLAS_ERROR_FUNC_SYNC(
rocblas_native_func(
rocblas_gemm_batched_ex, err, handle, get_rocblas_operation(transa[i]),
get_rocblas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], &alpha[i],
a_ + offset, get_rocblas_datatype<rocTypeA>(), (int)lda[i], b_ + offset,
Expand Down Expand Up @@ -1025,7 +1025,7 @@ inline sycl::event trsm_batch(Func func, sycl::queue &queue, side left_right, up
auto a_ = reinterpret_cast<const rocDataType *>(a);
auto b_ = reinterpret_cast<rocDataType *>(b);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right),
rocblas_native_func(func, err, handle, get_rocblas_side_mode(left_right),
get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag),
m, n, (rocDataType *)&alpha, a_, lda, stridea, b_, ldb, strideb,
Expand Down Expand Up @@ -1072,7 +1072,7 @@ inline sycl::event trsm_batch(Func func, sycl::queue &queue, side *left_right, u
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const rocDataType **>(a);
auto **b_ = reinterpret_cast<rocDataType **>(b);
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right[i]),
rocblas_native_func(func, err, handle, get_rocblas_side_mode(left_right[i]),
get_rocblas_fill_mode(upper_lower[i]),
get_rocblas_operation(trans[i]),
get_rocblas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
Expand Down Expand Up @@ -1123,7 +1123,7 @@ inline sycl::event syrk_batch(Func func, sycl::queue &queue, uplo *upper_lower,
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const rocDataType **>(a);
auto **c_ = reinterpret_cast<rocDataType **>(c);
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_fill_mode(upper_lower[i]),
rocblas_native_func(func, err, handle, get_rocblas_fill_mode(upper_lower[i]),
get_rocblas_operation(trans[i]), (int)n[i], (int)k[i],
(rocDataType *)&alpha[i], a_ + offset, (int)lda[i],
(rocDataType *)&beta[i], c_ + offset, (int)ldc[i],
Expand Down Expand Up @@ -1168,7 +1168,7 @@ inline sycl::event syrk_batch(Func func, sycl::queue &queue, uplo upper_lower, t
auto a_ = reinterpret_cast<const rocDataType *>(a);
auto c_ = reinterpret_cast<rocDataType *>(c);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_fill_mode(upper_lower),
rocblas_native_func(func, err, handle, get_rocblas_fill_mode(upper_lower),
get_rocblas_operation(trans), n, k, (rocDataType *)&alpha, a_,
lda, stridea, (rocDataType *)&beta, c_, ldc, stridec,
batch_size);
Expand Down Expand Up @@ -1216,7 +1216,7 @@ inline sycl::event omatcopy_batch(Func func, sycl::queue &queue, transpose trans
auto a_ = reinterpret_cast<const rocDataType *>(a);
auto b_ = reinterpret_cast<rocDataType *>(b);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans),
rocblas_native_func(func, err, handle, get_rocblas_operation(trans),
get_rocblas_operation(trans), new_m, new_n,
(rocDataType *)&alpha, a_, lda, stridea, (rocDataType *)&beta,
nullptr, lda, stridea, b_, ldb, strideb, batch_size);
Expand Down Expand Up @@ -1286,7 +1286,7 @@ inline sycl::event omatadd_batch(Func func, sycl::queue &queue, transpose transa
auto b_ = reinterpret_cast<const rocDataType *>(b);
auto c_ = reinterpret_cast<rocDataType *>(c);
rocblas_status err;
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa),
rocblas_native_func(func, err, handle, get_rocblas_operation(transa),
get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_,
lda, stridea, (rocDataType *)&beta, b_, ldb, strideb, c_, ldc,
stridec, batch_size);
Expand Down Expand Up @@ -1338,7 +1338,7 @@ inline sycl::event omatcopy_batch(Func func, sycl::queue &queue, transpose *tran
const auto new_m = trans[i] == oneapi::mkl::transpose::nontrans ? m[i] : n[i];
const auto new_n = trans[i] == oneapi::mkl::transpose::nontrans ? n[i] : m[i];

ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans[i]),
rocblas_native_func(func, err, handle, get_rocblas_operation(trans[i]),
get_rocblas_operation(trans[i]), (int)new_m, (int)new_n,
(rocDataType *)&alpha[i], a_ + offset, (int)lda[i],
(rocDataType *)&beta, nullptr, (int)lda[i], b_ + offset,
Expand Down
10 changes: 10 additions & 0 deletions src/blas/backends/rocblas/rocblas_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,16 @@ class hip_error : virtual public std::runtime_error {
hipError_t hip_err; \
HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId);

template <class Func, class... Types>
inline void rocblas_native_func(Func func, rocblas_status err,
rocblas_handle handle, Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
ROCBLAS_ERROR_FUNC(func, err, handle, args...)
#else
ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, args...)
#endif
};

inline rocblas_operation get_rocblas_operation(oneapi::mkl::transpose trn) {
switch (trn) {
case oneapi::mkl::transpose::nontrans: return rocblas_operation_none;
Expand Down
Loading
Loading