From b480dd9f700f01705df44ed3b0624186dbd3c69d Mon Sep 17 00:00:00 2001 From: Yvan Mokwinski Date: Fri, 11 Aug 2023 13:59:04 -0600 Subject: [PATCH] Revert "reverted 5.6 api change (#371)" (#375) This reverts commit 88d58a85ab1be5cb2ed2950abe515ded22798de7. --- CHANGELOG.md | 4 --- CMakeLists.txt | 2 +- clients/include/testing_spsv_coo.hpp | 18 +++++----- clients/include/testing_spsv_csr.hpp | 18 +++++----- library/include/hipsparse.h | 3 +- library/src/hcc_detail/hipsparse.cpp | 52 ++++++++++++++++++--------- library/src/nvcc_detail/hipsparse.cpp | 3 +- 7 files changed, 57 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ae41fec..5ee89fd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,5 @@ # Change Log for hipSPARSE -## hipSPARSE 2.3.7 for ROCm 5.6.1 -### Bugfix -- Reverted an undocumented API change in hipSPARSE 2.3.6 that affected hipsparseSpSV_solve function - ## hipSPARSE 2.3.6 for ROCm 5.6.0 ### Added - Added SpGEMM algorithms diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c4ab1a2..97a256d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,7 @@ endif() # Setup version -rocm_setup_version(VERSION 2.3.7) +rocm_setup_version(VERSION 2.3.6) set(hipsparse_SOVERSION 0.1) # hipSPARSE library diff --git a/clients/include/testing_spsv_coo.hpp b/clients/include/testing_spsv_coo.hpp index b47840fd..f07c8dc7 100644 --- a/clients/include/testing_spsv_coo.hpp +++ b/clients/include/testing_spsv_coo.hpp @@ -138,23 +138,23 @@ void testing_spsv_coo_bad_arg(void) // SpSV solve verify_hipsparse_status_invalid_handle( - hipsparseSpSV_solve(nullptr, transA, &alpha, A, x, y, dataType, alg, descr, dbuf)); + hipsparseSpSV_solve(nullptr, transA, &alpha, A, x, y, dataType, alg, descr)); verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, nullptr, A, x, y, dataType, alg, descr, dbuf), + hipsparseSpSV_solve(handle, transA, nullptr, A, x, y, dataType, alg, descr), "Error: alpha is nullptr"); verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, &alpha, nullptr, x, y, dataType, alg, descr, dbuf), + hipsparseSpSV_solve(handle, transA, &alpha, nullptr, x, y, dataType, alg, descr), "Error: A is nullptr"); verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, &alpha, A, nullptr, y, dataType, alg, descr, dbuf), + hipsparseSpSV_solve(handle, transA, &alpha, A, nullptr, y, dataType, alg, descr), "Error: x is nullptr"); verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, &alpha, A, x, nullptr, dataType, alg, descr, dbuf), + hipsparseSpSV_solve(handle, transA, &alpha, A, x, nullptr, dataType, alg, descr), "Error: y is nullptr"); #if(!defined(CUDART_VERSION)) verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, &alpha, A, x, y, dataType, alg, descr, nullptr), - "Error: dbuf is nullptr"); + hipsparseSpSV_solve(handle, transA, &alpha, A, x, y, dataType, alg, nullptr), + "Error: descr is nullptr"); #endif // Destruct @@ -307,12 +307,12 @@ hipsparseStatus_t testing_spsv_coo(void) // HIPSPARSE pointer mode host CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST)); CHECK_HIPSPARSE_ERROR( - hipsparseSpSV_solve(handle, transA, &h_alpha, A, x, y1, typeT, alg, descr, buffer)); + hipsparseSpSV_solve(handle, transA, &h_alpha, A, x, y1, typeT, alg, descr)); // HIPSPARSE pointer mode device CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_DEVICE)); CHECK_HIPSPARSE_ERROR( - hipsparseSpSV_solve(handle, transA, d_alpha, A, x, y2, typeT, alg, descr, buffer)); + hipsparseSpSV_solve(handle, transA, d_alpha, A, x, y2, typeT, alg, descr)); // copy output from device to CPU CHECK_HIP_ERROR(hipMemcpy(hy_1.data(), dy_1, sizeof(T) * m, hipMemcpyDeviceToHost)); diff --git a/clients/include/testing_spsv_csr.hpp b/clients/include/testing_spsv_csr.hpp index ff65fa3a..8265baa8 100644 --- a/clients/include/testing_spsv_csr.hpp +++ b/clients/include/testing_spsv_csr.hpp @@ -139,23 +139,23 @@ void testing_spsv_csr_bad_arg(void) // SpSV solve verify_hipsparse_status_invalid_handle( - hipsparseSpSV_solve(nullptr, transA, &alpha, A, x, y, dataType, alg, descr, dbuf)); + hipsparseSpSV_solve(nullptr, transA, &alpha, A, x, y, dataType, alg, descr)); verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, nullptr, A, x, y, dataType, alg, descr, dbuf), + hipsparseSpSV_solve(handle, transA, nullptr, A, x, y, dataType, alg, descr), "Error: alpha is nullptr"); verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, &alpha, nullptr, x, y, dataType, alg, descr, dbuf), + hipsparseSpSV_solve(handle, transA, &alpha, nullptr, x, y, dataType, alg, descr), "Error: A is nullptr"); verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, &alpha, A, nullptr, y, dataType, alg, descr, dbuf), + hipsparseSpSV_solve(handle, transA, &alpha, A, nullptr, y, dataType, alg, descr), "Error: x is nullptr"); verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, &alpha, A, x, nullptr, dataType, alg, descr, dbuf), + hipsparseSpSV_solve(handle, transA, &alpha, A, x, nullptr, dataType, alg, descr), "Error: y is nullptr"); #if(!defined(CUDART_VERSION)) verify_hipsparse_status_invalid_pointer( - hipsparseSpSV_solve(handle, transA, &alpha, A, x, y, dataType, alg, descr, nullptr), - "Error: dbuf is nullptr"); + hipsparseSpSV_solve(handle, transA, &alpha, A, x, y, dataType, alg, nullptr), + "Error: descr is nullptr"); #endif // Destruct @@ -300,12 +300,12 @@ hipsparseStatus_t testing_spsv_csr(void) // HIPSPARSE pointer mode host CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST)); CHECK_HIPSPARSE_ERROR( - hipsparseSpSV_solve(handle, transA, &h_alpha, A, x, y1, typeT, alg, descr, buffer)); + hipsparseSpSV_solve(handle, transA, &h_alpha, A, x, y1, typeT, alg, descr)); // HIPSPARSE pointer mode device CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_DEVICE)); CHECK_HIPSPARSE_ERROR( - hipsparseSpSV_solve(handle, transA, d_alpha, A, x, y2, typeT, alg, descr, buffer)); + hipsparseSpSV_solve(handle, transA, d_alpha, A, x, y2, typeT, alg, descr)); // copy output from device to CPU CHECK_HIP_ERROR(hipMemcpy(hy_1.data(), dy_1, sizeof(T) * m, hipMemcpyDeviceToHost)); diff --git a/library/include/hipsparse.h b/library/include/hipsparse.h index d28b53dc..c72f4b78 100644 --- a/library/include/hipsparse.h +++ b/library/include/hipsparse.h @@ -10394,8 +10394,7 @@ hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle, const hipsparseDnVecDescr_t y, hipDataType computeType, hipsparseSpSVAlg_t alg, - hipsparseSpSVDescr_t spsvDescr, - void* externalBuffer); + hipsparseSpSVDescr_t spsvDescr); #endif /*! \ingroup generic_module diff --git a/library/src/hcc_detail/hipsparse.cpp b/library/src/hcc_detail/hipsparse.cpp index f3820957..98187524 100644 --- a/library/src/hcc_detail/hipsparse.cpp +++ b/library/src/hcc_detail/hipsparse.cpp @@ -14259,15 +14259,25 @@ hipsparseStatus_t hipsparseSDDMM_preprocess(hipsparseHandle_t handle, tempBuffer)); } +struct hipsparseSpSVDescr +{ + void* externalBuffer{}; +}; + hipsparseStatus_t hipsparseSpSV_createDescr(hipsparseSpSVDescr_t* descr) { - // Do nothing + *descr = new hipsparseSpSVDescr; return HIPSPARSE_STATUS_SUCCESS; } hipsparseStatus_t hipsparseSpSV_destroyDescr(hipsparseSpSVDescr_t descr) { - // Do nothing + if(descr != nullptr) + { + descr->externalBuffer = nullptr; + delete descr; + } + return HIPSPARSE_STATUS_SUCCESS; } @@ -14306,17 +14316,24 @@ hipsparseStatus_t hipsparseSpSV_analysis(hipsparseHandle_t handle, hipsparseSpSVDescr_t spsvDescr, void* externalBuffer) { - return rocSPARSEStatusToHIPStatus(rocsparse_spsv((rocsparse_handle)handle, - hipOperationToHCCOperation(opA), - alpha, - (const rocsparse_spmat_descr)matA, - (const rocsparse_dnvec_descr)x, - (const rocsparse_dnvec_descr)y, - hipDataTypeToHCCDataType(computeType), - hipSpSVAlgToHCCSpSVAlg(alg), - rocsparse_spsv_stage_preprocess, - nullptr, - externalBuffer)); + + if(spsvDescr == nullptr) + { + return HIPSPARSE_STATUS_INVALID_VALUE; + } + RETURN_IF_ROCSPARSE_ERROR(rocsparse_spsv((rocsparse_handle)handle, + hipOperationToHCCOperation(opA), + alpha, + (const rocsparse_spmat_descr)matA, + (const rocsparse_dnvec_descr)x, + (const rocsparse_dnvec_descr)y, + hipDataTypeToHCCDataType(computeType), + hipSpSVAlgToHCCSpSVAlg(alg), + rocsparse_spsv_stage_preprocess, + nullptr, + externalBuffer)); + spsvDescr->externalBuffer = externalBuffer; + return HIPSPARSE_STATUS_SUCCESS; } hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle, @@ -14327,9 +14344,12 @@ hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle, const hipsparseDnVecDescr_t y, hipDataType computeType, hipsparseSpSVAlg_t alg, - hipsparseSpSVDescr_t spsvDescr, - void* externalBuffer) + hipsparseSpSVDescr_t spsvDescr) { + if(spsvDescr == nullptr) + { + return HIPSPARSE_STATUS_INVALID_VALUE; + } return rocSPARSEStatusToHIPStatus(rocsparse_spsv((rocsparse_handle)handle, hipOperationToHCCOperation(opA), alpha, @@ -14340,7 +14360,7 @@ hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle, hipSpSVAlgToHCCSpSVAlg(alg), rocsparse_spsv_stage_compute, nullptr, - externalBuffer)); + spsvDescr->externalBuffer)); } hipsparseStatus_t hipsparseSpSM_createDescr(hipsparseSpSMDescr_t* descr) diff --git a/library/src/nvcc_detail/hipsparse.cpp b/library/src/nvcc_detail/hipsparse.cpp index 75a6e1b9..7d35f0d4 100644 --- a/library/src/nvcc_detail/hipsparse.cpp +++ b/library/src/nvcc_detail/hipsparse.cpp @@ -12102,8 +12102,7 @@ hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle, const hipsparseDnVecDescr_t y, hipDataType computeType, hipsparseSpSVAlg_t alg, - hipsparseSpSVDescr_t spsvDescr, - void* externalBuffer) + hipsparseSpSVDescr_t spsvDescr) { return hipCUSPARSEStatusToHIPStatus(cusparseSpSV_solve((cusparseHandle_t)handle, hipOperationToCudaOperation(opA),