diff --git a/clients/common/hipsparse_template_specialization.cpp b/clients/common/hipsparse_template_specialization.cpp index 47cd4c9d..5b275ab1 100644 --- a/clients/common/hipsparse_template_specialization.cpp +++ b/clients/common/hipsparse_template_specialization.cpp @@ -3469,6 +3469,46 @@ namespace hipsparse } #endif + template <> + hipsparseStatus_t hipsparseXbsrilu02_numericBoost(hipsparseHandle_t handle, + bsrilu02Info_t info, + int enable_boost, + double* tol, + float* boost_val) + { + return hipsparseSbsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val); + } + + template <> + hipsparseStatus_t hipsparseXbsrilu02_numericBoost(hipsparseHandle_t handle, + bsrilu02Info_t info, + int enable_boost, + double* tol, + double* boost_val) + { + return hipsparseDbsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val); + } + + template <> + hipsparseStatus_t hipsparseXbsrilu02_numericBoost(hipsparseHandle_t handle, + bsrilu02Info_t info, + int enable_boost, + double* tol, + hipComplex* boost_val) + { + return hipsparseCbsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val); + } + + template <> + hipsparseStatus_t hipsparseXbsrilu02_numericBoost(hipsparseHandle_t handle, + bsrilu02Info_t info, + int enable_boost, + double* tol, + hipDoubleComplex* boost_val) + { + return hipsparseZbsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val); + } + template <> hipsparseStatus_t hipsparseXbsrilu02_bufferSize(hipsparseHandle_t handle, hipsparseDirection_t dirA, @@ -3797,6 +3837,46 @@ namespace hipsparse pBuffer); } + template <> + hipsparseStatus_t hipsparseXcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + float* boost_val) + { + return hipsparseScsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val); + } + + template <> + hipsparseStatus_t hipsparseXcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + double* boost_val) + { + return hipsparseDcsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val); + } + + template <> + hipsparseStatus_t hipsparseXcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + hipComplex* boost_val) + { + return hipsparseCcsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val); + } + + template <> + hipsparseStatus_t hipsparseXcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + hipDoubleComplex* boost_val) + { + return hipsparseZcsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val); + } + template <> hipsparseStatus_t hipsparseXcsrilu02_bufferSize(hipsparseHandle_t handle, int m, diff --git a/clients/include/hipsparse.hpp b/clients/include/hipsparse.hpp index 1e5a44af..4f769054 100644 --- a/clients/include/hipsparse.hpp +++ b/clients/include/hipsparse.hpp @@ -523,6 +523,10 @@ namespace hipsparse void* pBuffer); #endif + template + hipsparseStatus_t hipsparseXbsrilu02_numericBoost( + hipsparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double* tol, T* boost_val); + template hipsparseStatus_t hipsparseXbsrilu02_bufferSize(hipsparseHandle_t handle, hipsparseDirection_t dirA, @@ -564,6 +568,10 @@ namespace hipsparse hipsparseSolvePolicy_t policy, void* pBuffer); + template + hipsparseStatus_t hipsparseXcsrilu02_numericBoost( + hipsparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, T* boost_val); + template hipsparseStatus_t hipsparseXcsrilu02_bufferSize(hipsparseHandle_t handle, int m, diff --git a/clients/include/testing_bsrilu02.hpp b/clients/include/testing_bsrilu02.hpp index c63c9a36..0d49a62c 100644 --- a/clients/include/testing_bsrilu02.hpp +++ b/clients/include/testing_bsrilu02.hpp @@ -32,7 +32,6 @@ #include #include -#include #include using namespace hipsparse; @@ -67,13 +66,17 @@ void testing_bsrilu02_bad_arg(void) auto dval_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * safe_size), device_free}; auto dbuffer_managed = hipsparse_unique_ptr{device_malloc(sizeof(char) * safe_size), device_free}; + auto dboost_tol_managed = hipsparse_unique_ptr{device_malloc(sizeof(double)), device_free}; + auto dboost_val_managed = hipsparse_unique_ptr{device_malloc(sizeof(T)), device_free}; - int* dptr = (int*)dptr_managed.get(); - int* dcol = (int*)dcol_managed.get(); - T* dval = (T*)dval_managed.get(); - void* dbuffer = (void*)dbuffer_managed.get(); + int* dptr = (int*)dptr_managed.get(); + int* dcol = (int*)dcol_managed.get(); + T* dval = (T*)dval_managed.get(); + void* dbuffer = (void*)dbuffer_managed.get(); + double* dboost_tol = (double*)dboost_tol_managed.get(); + T* dboost_val = (T*)dboost_val_managed.get(); - if(!dval || !dptr || !dcol || !dbuffer) + if(!dval || !dptr || !dcol || !dbuffer || !dboost_tol || !dboost_val) { PRINT_IF_HIP_ERROR(hipErrorOutOfMemory); return; @@ -139,6 +142,40 @@ void testing_bsrilu02_bad_arg(void) verify_hipsparse_status_invalid_handle(status); } + // testing hipsparseXbsrilu02_numericBoost + + // testing for(nullptr == handle) + { + hipsparseHandle_t handle_null = nullptr; + + status = hipsparseXbsrilu02_numericBoost(handle_null, info, 1, dboost_tol, dboost_val); + verify_hipsparse_status_invalid_handle(status); + } + + // testing for(nullptr == info) + { + bsrilu02Info_t info_null = nullptr; + + status = hipsparseXbsrilu02_numericBoost(handle, info_null, 1, dboost_tol, dboost_val); + verify_hipsparse_status_invalid_pointer(status, "Error: info is nullptr"); + } + + // testing for(nullptr == dboost_tol) + { + double* boost_tol_null = nullptr; + + status = hipsparseXbsrilu02_numericBoost(handle, info, 1, boost_tol_null, dboost_val); + verify_hipsparse_status_invalid_pointer(status, "Error: boost_tol is nullptr"); + } + + // testing for(nullptr == dboost_val) + { + T* boost_val_null = nullptr; + + status = hipsparseXbsrilu02_numericBoost(handle, info, 1, dboost_tol, boost_val_null); + verify_hipsparse_status_invalid_pointer(status, "Error: boost_val is nullptr"); + } + // testing hipsparseXbsrilu02_analysis // testing for(nullptr == dptr) @@ -289,6 +326,9 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus) int safe_size = 100; int m = argus.M; int block_dim = argus.block_dim; + int boost = argus.numericboost; + double boost_tol = argus.boosttol; + T boost_val = make_DataType(argus.boostval, argus.boostvali); hipsparseDirection_t dir = argus.dirA; hipsparseIndexBase_t idx_base = argus.idx_base; hipsparseSolvePolicy_t policy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; @@ -476,17 +516,21 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus) auto dcsr_val_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * nnz), device_free}; auto dbsr_row_ptr_managed = hipsparse_unique_ptr{device_malloc(sizeof(int) * (mb + 1)), device_free}; + auto boost_tol_managed = hipsparse_unique_ptr{device_malloc(sizeof(double)), device_free}; + auto boost_val_managed = hipsparse_unique_ptr{device_malloc(sizeof(T)), device_free}; - int* dcsr_row_ptr = (int*)dcsr_row_ptr_managed.get(); - int* dcsr_col_ind = (int*)dcsr_col_ind_managed.get(); - T* dcsr_val = (T*)dcsr_val_managed.get(); - int* dbsr_row_ptr = (int*)dbsr_row_ptr_managed.get(); + int* dcsr_row_ptr = (int*)dcsr_row_ptr_managed.get(); + int* dcsr_col_ind = (int*)dcsr_col_ind_managed.get(); + T* dcsr_val = (T*)dcsr_val_managed.get(); + int* dbsr_row_ptr = (int*)dbsr_row_ptr_managed.get(); + double* dboost_tol = (double*)boost_tol_managed.get(); + T* dboost_val = (T*)boost_val_managed.get(); - if(!dcsr_val || !dcsr_row_ptr || !dcsr_col_ind || !dbsr_row_ptr) + if(!dcsr_val || !dcsr_row_ptr || !dcsr_col_ind || !dbsr_row_ptr || !dboost_tol || !dboost_val) { - verify_hipsparse_status_success( - HIPSPARSE_STATUS_ALLOC_FAILED, - "!dcsr_val || !dcsr_row_ptr || !dcsr_col_ind || !dbsr_row_ptr"); + verify_hipsparse_status_success(HIPSPARSE_STATUS_ALLOC_FAILED, + "!dcsr_val || !dcsr_row_ptr || !dcsr_col_ind || " + "!dbsr_row_ptr || !dboost_tol || !dboost_val"); return HIPSPARSE_STATUS_ALLOC_FAILED; } @@ -604,6 +648,9 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus) hipsparseStatus_t status_solve_1; hipsparseStatus_t status_solve_2; + CHECK_HIP_ERROR(hipMemcpy(dboost_tol, &boost_tol, sizeof(double), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dboost_val, &boost_val, sizeof(T), hipMemcpyHostToDevice)); + // bsrilu02 analysis - host mode CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST)); CHECK_HIPSPARSE_ERROR(hipsparseXbsrilu02_analysis(handle, @@ -652,6 +699,8 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus) // bsrilu02 solve - host mode CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST)); + CHECK_HIPSPARSE_ERROR( + hipsparseXbsrilu02_numericBoost(handle, info, boost, &boost_tol, &boost_val)); CHECK_HIPSPARSE_ERROR(hipsparseXbsrilu02(handle, dir, mb, @@ -675,6 +724,8 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus) // bsrilu02 solve - device mode CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_DEVICE)); + CHECK_HIPSPARSE_ERROR( + hipsparseXbsrilu02_numericBoost(handle, info, boost, dboost_tol, dboost_val)); CHECK_HIPSPARSE_ERROR(hipsparseXbsrilu02(handle, dir, mb, @@ -713,7 +764,7 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus) CHECK_HIP_ERROR( hipMemcpy(&h_solve_pivot_2, d_solve_pivot_2, sizeof(int), hipMemcpyDeviceToHost)); - // Host csrilu02 + // Host bsrilu02 double cpu_time_used = get_time_us(); int numerical_pivot; @@ -726,7 +777,10 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus) hbsr_val, idx_base, &structural_pivot, - &numerical_pivot); + &numerical_pivot, + boost, + boost_tol, + boost_val); cpu_time_used = get_time_us() - cpu_time_used; @@ -761,4 +815,4 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus) return HIPSPARSE_STATUS_SUCCESS; } -#endif // TESTING_BSRILU02_HPP +#endif // TESTING_BSRILU02_HPP \ No newline at end of file diff --git a/clients/include/testing_csrilu02.hpp b/clients/include/testing_csrilu02.hpp index 16c6c277..898303b5 100644 --- a/clients/include/testing_csrilu02.hpp +++ b/clients/include/testing_csrilu02.hpp @@ -64,13 +64,17 @@ void testing_csrilu02_bad_arg(void) auto dval_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * safe_size), device_free}; auto dbuffer_managed = hipsparse_unique_ptr{device_malloc(sizeof(char) * safe_size), device_free}; + auto dboost_tol_managed = hipsparse_unique_ptr{device_malloc(sizeof(double)), device_free}; + auto dboost_val_managed = hipsparse_unique_ptr{device_malloc(sizeof(T)), device_free}; - int* dptr = (int*)dptr_managed.get(); - int* dcol = (int*)dcol_managed.get(); - T* dval = (T*)dval_managed.get(); - void* dbuffer = (void*)dbuffer_managed.get(); + int* dptr = (int*)dptr_managed.get(); + int* dcol = (int*)dcol_managed.get(); + T* dval = (T*)dval_managed.get(); + void* dbuffer = (void*)dbuffer_managed.get(); + double* dboost_tol = (double*)dboost_tol_managed.get(); + T* dboost_val = (T*)dboost_val_managed.get(); - if(!dval || !dptr || !dcol || !dbuffer) + if(!dval || !dptr || !dcol || !dbuffer || !dboost_tol || !dboost_val) { PRINT_IF_HIP_ERROR(hipErrorOutOfMemory); return; @@ -136,6 +140,40 @@ void testing_csrilu02_bad_arg(void) verify_hipsparse_status_invalid_handle(status); } + // testing hipsparseXcsrilu02_numericBoost + + // testing for(nullptr == handle) + { + hipsparseHandle_t handle_null = nullptr; + + status = hipsparseXcsrilu02_numericBoost(handle_null, info, 1, dboost_tol, dboost_val); + verify_hipsparse_status_invalid_handle(status); + } + + // testing for(nullptr == info) + { + csrilu02Info_t info_null = nullptr; + + status = hipsparseXcsrilu02_numericBoost(handle, info_null, 1, dboost_tol, dboost_val); + verify_hipsparse_status_invalid_pointer(status, "Error: info is nullptr"); + } + + // testing for(nullptr == dboost_tol) + { + double* boost_tol_null = nullptr; + + status = hipsparseXcsrilu02_numericBoost(handle, info, 1, boost_tol_null, dboost_val); + verify_hipsparse_status_invalid_pointer(status, "Error: boost_tol is nullptr"); + } + + // testing for(nullptr == dboost_val) + { + T* boost_val_null = nullptr; + + status = hipsparseXcsrilu02_numericBoost(handle, info, 1, dboost_tol, boost_val_null); + verify_hipsparse_status_invalid_pointer(status, "Error: boost_val is nullptr"); + } + // testing hipsparseXcsrilu02_analysis // testing for(nullptr == dptr) @@ -285,6 +323,9 @@ hipsparseStatus_t testing_csrilu02(Arguments argus) { int safe_size = 100; int m = argus.M; + int boost = argus.numericboost; + double boost_tol = argus.boosttol; + T boost_val = make_DataType(argus.boostval, argus.boostvali); hipsparseIndexBase_t idx_base = argus.idx_base; hipsparseSolvePolicy_t policy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL; std::string binfile = ""; @@ -457,20 +498,27 @@ hipsparseStatus_t testing_csrilu02(Arguments argus) } // Allocate memory on device - auto dptr_managed = hipsparse_unique_ptr{device_malloc(sizeof(int) * (m + 1)), device_free}; - auto dcol_managed = hipsparse_unique_ptr{device_malloc(sizeof(int) * nnz), device_free}; - auto dval_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * nnz), device_free}; + auto dptr_managed = hipsparse_unique_ptr{device_malloc(sizeof(int) * (m + 1)), device_free}; + auto dcol_managed = hipsparse_unique_ptr{device_malloc(sizeof(int) * nnz), device_free}; + auto dval1_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * nnz), device_free}; + auto dval2_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * nnz), device_free}; auto d_position_managed = hipsparse_unique_ptr{device_malloc(sizeof(int)), device_free}; - - int* dptr = (int*)dptr_managed.get(); - int* dcol = (int*)dcol_managed.get(); - T* dval = (T*)dval_managed.get(); - int* d_position = (int*)d_position_managed.get(); - - if(!dval || !dptr || !dcol || !d_position) - { - verify_hipsparse_status_success(HIPSPARSE_STATUS_ALLOC_FAILED, - "!dval || !dptr || !dcol || !d_position"); + auto boost_tol_managed = hipsparse_unique_ptr{device_malloc(sizeof(double)), device_free}; + auto boost_val_managed = hipsparse_unique_ptr{device_malloc(sizeof(T)), device_free}; + + int* dptr = (int*)dptr_managed.get(); + int* dcol = (int*)dcol_managed.get(); + T* dval1 = (T*)dval1_managed.get(); + T* dval2 = (T*)dval2_managed.get(); + int* d_position = (int*)d_position_managed.get(); + double* dboost_tol = (double*)boost_tol_managed.get(); + T* dboost_val = (T*)boost_val_managed.get(); + + if(!dval1 || !dval2 || !dptr || !dcol || !d_position || !dboost_tol || !dboost_val) + { + verify_hipsparse_status_success( + HIPSPARSE_STATUS_ALLOC_FAILED, + "!dval1 || !dval2|| !dptr || !dcol || !d_position || !dboost_tol || !dboost_val"); return HIPSPARSE_STATUS_ALLOC_FAILED; } @@ -478,11 +526,12 @@ hipsparseStatus_t testing_csrilu02(Arguments argus) CHECK_HIP_ERROR( hipMemcpy(dptr, hcsr_row_ptr.data(), sizeof(int) * (m + 1), hipMemcpyHostToDevice)); CHECK_HIP_ERROR(hipMemcpy(dcol, hcsr_col_ind.data(), sizeof(int) * nnz, hipMemcpyHostToDevice)); - CHECK_HIP_ERROR(hipMemcpy(dval, hcsr_val.data(), sizeof(T) * nnz, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dval1, hcsr_val.data(), sizeof(T) * nnz, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dval2, hcsr_val.data(), sizeof(T) * nnz, hipMemcpyHostToDevice)); // Obtain csrilu02 buffer size CHECK_HIPSPARSE_ERROR( - hipsparseXcsrilu02_bufferSize(handle, m, nnz, descr, dval, dptr, dcol, info, &size)); + hipsparseXcsrilu02_bufferSize(handle, m, nnz, descr, dval1, dptr, dcol, info, &size)); // Allocate buffer on the device auto dbuffer_managed = hipsparse_unique_ptr{device_malloc(sizeof(char) * size), device_free}; @@ -495,39 +544,53 @@ hipsparseStatus_t testing_csrilu02(Arguments argus) return HIPSPARSE_STATUS_ALLOC_FAILED; } - // csrilu02 analysis - CHECK_HIPSPARSE_ERROR(hipsparseXcsrilu02_analysis( - handle, m, nnz, descr, dval, dptr, dcol, info, policy, dbuffer)); - if(argus.unit_check) { - CHECK_HIPSPARSE_ERROR( - hipsparseXcsrilu02(handle, m, nnz, descr, dval, dptr, dcol, info, policy, dbuffer)); + CHECK_HIP_ERROR(hipMemcpy(dboost_tol, &boost_tol, sizeof(double), hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(dboost_val, &boost_val, sizeof(T), hipMemcpyHostToDevice)); // Pointer mode host CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST)); - + CHECK_HIPSPARSE_ERROR(hipsparseXcsrilu02_analysis( + handle, m, nnz, descr, dval1, dptr, dcol, info, policy, dbuffer)); + CHECK_HIPSPARSE_ERROR( + hipsparseXcsrilu02_numericBoost(handle, info, boost, &boost_tol, &boost_val)); + CHECK_HIPSPARSE_ERROR( + hipsparseXcsrilu02(handle, m, nnz, descr, dval1, dptr, dcol, info, policy, dbuffer)); int hposition_1; hipsparseStatus_t pivot_status_1; pivot_status_1 = hipsparseXcsrilu02_zeroPivot(handle, info, &hposition_1); // Pointer mode device CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_DEVICE)); - + CHECK_HIPSPARSE_ERROR(hipsparseXcsrilu02_analysis( + handle, m, nnz, descr, dval2, dptr, dcol, info, policy, dbuffer)); + CHECK_HIPSPARSE_ERROR( + hipsparseXcsrilu02_numericBoost(handle, info, boost, dboost_tol, dboost_val)); + CHECK_HIPSPARSE_ERROR( + hipsparseXcsrilu02(handle, m, nnz, descr, dval2, dptr, dcol, info, policy, dbuffer)); + int hposition_2; hipsparseStatus_t pivot_status_2; pivot_status_2 = hipsparseXcsrilu02_zeroPivot(handle, info, d_position); + CHECK_HIP_ERROR(hipMemcpy(&hposition_2, d_position, sizeof(int), hipMemcpyDeviceToHost)); // Copy output from device to CPU - int hposition_2; - std::vector result(nnz); - CHECK_HIP_ERROR(hipMemcpy(result.data(), dval, sizeof(T) * nnz, hipMemcpyDeviceToHost)); - CHECK_HIP_ERROR(hipMemcpy(&hposition_2, d_position, sizeof(int), hipMemcpyDeviceToHost)); + std::vector result1(nnz); + std::vector result2(nnz); + CHECK_HIP_ERROR(hipMemcpy(result1.data(), dval1, sizeof(T) * nnz, hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR(hipMemcpy(result2.data(), dval2, sizeof(T) * nnz, hipMemcpyDeviceToHost)); // Host csrilu02 double cpu_time_used = get_time_us(); - int position_gold - = csrilu0(m, hcsr_row_ptr.data(), hcsr_col_ind.data(), hcsr_val.data(), idx_base); + int position_gold = csrilu0(m, + hcsr_row_ptr.data(), + hcsr_col_ind.data(), + hcsr_val.data(), + idx_base, + boost, + boost_tol, + boost_val); cpu_time_used = get_time_us() - cpu_time_used; @@ -549,10 +612,12 @@ hipsparseStatus_t testing_csrilu02(Arguments argus) } #if defined(__HIP_PLATFORM_HCC__) - unit_check_general(1, nnz, 1, hcsr_val.data(), result.data()); + unit_check_general(1, nnz, 1, hcsr_val.data(), result1.data()); + unit_check_general(1, nnz, 1, hcsr_val.data(), result2.data()); #elif defined(__HIP_PLATFORM_NVCC__) // do weaker check for cusparse - unit_check_near(1, nnz, 1, hcsr_val.data(), result.data()); + unit_check_near(1, nnz, 1, hcsr_val.data(), result1.data()); + unit_check_near(1, nnz, 1, hcsr_val.data(), result2.data()); #endif } diff --git a/clients/include/testing_csrilusv.hpp b/clients/include/testing_csrilusv.hpp index dc8ea89c..76835aa4 100644 --- a/clients/include/testing_csrilusv.hpp +++ b/clients/include/testing_csrilusv.hpp @@ -156,8 +156,14 @@ hipsparseStatus_t testing_csrilusv(Arguments argus) CHECK_HIP_ERROR(hipMemcpy(&hposition_2, d_position, sizeof(int), hipMemcpyDeviceToHost)); // Compute host reference csrilu0 - int position_gold - = csrilu0(m, hcsr_row_ptr.data(), hcsr_col_ind.data(), hcsr_val.data(), idx_base); + int position_gold = csrilu0(m, + hcsr_row_ptr.data(), + hcsr_col_ind.data(), + hcsr_val.data(), + idx_base, + false, + 0.0, + make_DataType(0.0, 0.0)); // Check zero pivot results unit_check_general(1, 1, 1, &position_gold, &hposition_1); diff --git a/clients/include/utility.hpp b/clients/include/utility.hpp index 615de71a..45e35d2f 100644 --- a/clients/include/utility.hpp +++ b/clients/include/utility.hpp @@ -1949,7 +1949,14 @@ inline void host_bsrmm(int Mb, } template -int csrilu0(int m, const int* ptr, const int* col, T* val, hipsparseIndexBase_t idx_base) +int csrilu0(int m, + const int* ptr, + const int* col, + T* val, + hipsparseIndexBase_t idx_base, + bool boost, + double boost_tol, + T boost_val) { // pointer of upper part of each row std::vector diag_offset(m); @@ -1977,30 +1984,38 @@ int csrilu0(int m, const int* ptr, const int* col, T* val, hipsparseIndexBase_t // if nnz entry is in lower matrix if(col[j] - idx_base < ai) { - int col_j = col[j] - idx_base; int diag_j = diag_offset[col_j]; - if(val[diag_j] != make_DataType(0.0)) - { - // multiplication factor - val[j] = val[j] / val[diag_j]; + T diag_val = val[diag_j]; - // loop over upper offset pointer and do linear combination for nnz entry - for(int k = diag_j + 1; k < ptr[col_j + 1] - idx_base; ++k) + if(boost) + { + diag_val = (boost_tol >= testing_abs(diag_val)) ? boost_val : diag_val; + val[diag_j] = diag_val; + } + else + { + // Check for numeric pivot + if(diag_val == make_DataType(0.0)) { - // if nnz at this position do linear combination - if(nnz_entries[col[k] - idx_base] != 0) - { - int idx = nnz_entries[col[k] - idx_base]; - val[idx] = testing_fma(testing_neg(val[j]), val[k], val[idx]); - } + // Numerical zero diagonal + return col_j + idx_base; } } - else + + // multiplication factor + val[j] = val[j] / diag_val; + + // loop over upper offset pointer and do linear combination for nnz entry + for(int k = diag_j + 1; k < ptr[col_j + 1] - idx_base; ++k) { - // Numerical zero diagonal - return col_j + idx_base; + // if nnz at this position do linear combination + if(nnz_entries[col[k] - idx_base] != 0) + { + int idx = nnz_entries[col[k] - idx_base]; + val[idx] = testing_fma(testing_neg(val[j]), val[k], val[idx]); + } } } else if(col[j] - idx_base == ai) @@ -2042,7 +2057,10 @@ inline void host_bsrilu02(hipsparseDirection_t dir, std::vector& bsr_val, hipsparseIndexBase_t base, int* struct_pivot, - int* numeric_pivot) + int* numeric_pivot, + bool boost, + double boost_tol, + T boost_val) { // Initialize pivots *struct_pivot = mb + 1; @@ -2173,11 +2191,19 @@ inline void host_bsrilu02(hipsparseDirection_t dir, { T diag = bsr_val[BSR_IND(j, bi, bi, dir)]; - // Check for numeric pivot - if(diag == make_DataType(0)) + if(boost) { - *numeric_pivot = std::min(*numeric_pivot, bsr_col_ind[j]); - continue; + diag = (boost_tol >= testing_abs(diag)) ? boost_val : diag; + bsr_val[BSR_IND(j, bi, bi, dir)] = diag; + } + else + { + // Check for numeric pivot + if(diag == make_DataType(0)) + { + *numeric_pivot = std::min(*numeric_pivot, bsr_col_ind[j]); + continue; + } } // Process all rows within the BSR block after bi-th row @@ -4142,6 +4168,11 @@ class Arguments int ell_width = 0; int temp = 0; + int numericboost; + double boosttol; + double boostval; + double boostvali; + std::string filename = ""; Arguments& operator=(const Arguments& rhs) @@ -4184,6 +4215,11 @@ class Arguments this->ell_width = rhs.ell_width; this->temp = rhs.temp; + this->numericboost = rhs.numericboost; + this->boosttol = rhs.boosttol; + this->boostval = rhs.boostval; + this->boostvali = rhs.boostvali; + this->filename = rhs.filename; return *this; diff --git a/clients/tests/test_bsrilu02.cpp b/clients/tests/test_bsrilu02.cpp index de604bf3..5e2eb82f 100644 --- a/clients/tests/test_bsrilu02.cpp +++ b/clients/tests/test_bsrilu02.cpp @@ -29,14 +29,19 @@ #include #include -typedef hipsparseIndexBase_t base; -typedef hipsparseDirection_t dir; -typedef std::tuple bsrilu02_tuple; -typedef std::tuple bsrilu02_bin_tuple; +typedef hipsparseIndexBase_t base; +typedef hipsparseDirection_t dir; +typedef std::tuple bsrilu02_tuple; +typedef std::tuple bsrilu02_bin_tuple; int bsrilu02_M_range[] = {-1, 0, 50, 426}; int bsrilu02_dim_range[] = {-1, 0, 1, 3, 5, 9}; +int bsrilu02_boost_range[] = {0}; +double bsrilu02_boost_tol_range[] = {1.1}; +double bsrilu02_boost_val_range[] = {0.3}; +double bsrilu02_boost_vali_range[] = {0.2}; + base bsrilu02_idxbase_range[] = {HIPSPARSE_INDEX_BASE_ZERO, HIPSPARSE_INDEX_BASE_ONE}; dir bsrilu02_dir_range[] = {HIPSPARSE_DIRECTION_ROW, HIPSPARSE_DIRECTION_COLUMN}; @@ -72,25 +77,33 @@ class parameterized_bsrilu02_bin : public testing::TestWithParam(tup); - arg.block_dim = std::get<1>(tup); - arg.dirA = std::get<2>(tup); - arg.idx_base = std::get<3>(tup); - arg.timing = 0; + arg.M = std::get<0>(tup); + arg.block_dim = std::get<1>(tup); + arg.numericboost = std::get<2>(tup); + arg.boosttol = std::get<3>(tup); + arg.boostval = std::get<4>(tup); + arg.boostvali = std::get<5>(tup); + arg.dirA = std::get<6>(tup); + arg.idx_base = std::get<7>(tup); + arg.timing = 0; return arg; } Arguments setup_bsrilu02_arguments(bsrilu02_bin_tuple tup) { Arguments arg; - arg.M = -99; - arg.block_dim = std::get<0>(tup); - arg.dirA = std::get<1>(tup); - arg.idx_base = std::get<2>(tup); - arg.timing = 0; + arg.M = -99; + arg.block_dim = std::get<0>(tup); + arg.numericboost = std::get<1>(tup); + arg.boosttol = std::get<2>(tup); + arg.boostval = std::get<3>(tup); + arg.boostvali = std::get<4>(tup); + arg.dirA = std::get<5>(tup); + arg.idx_base = std::get<6>(tup); + arg.timing = 0; // Determine absolute path of test matrix - std::string bin_file = std::get<3>(tup); + std::string bin_file = std::get<7>(tup); // Get current executables absolute path char path_exe[PATH_MAX]; @@ -167,12 +180,20 @@ INSTANTIATE_TEST_CASE_P(bsrilu02, parameterized_bsrilu02, testing::Combine(testing::ValuesIn(bsrilu02_M_range), testing::ValuesIn(bsrilu02_dim_range), + testing::ValuesIn(bsrilu02_boost_range), + testing::ValuesIn(bsrilu02_boost_tol_range), + testing::ValuesIn(bsrilu02_boost_val_range), + testing::ValuesIn(bsrilu02_boost_vali_range), testing::ValuesIn(bsrilu02_dir_range), testing::ValuesIn(bsrilu02_idxbase_range))); INSTANTIATE_TEST_CASE_P(bsrilu02_bin, parameterized_bsrilu02_bin, testing::Combine(testing::ValuesIn(bsrilu02_dim_range), + testing::ValuesIn(bsrilu02_boost_range), + testing::ValuesIn(bsrilu02_boost_tol_range), + testing::ValuesIn(bsrilu02_boost_val_range), + testing::ValuesIn(bsrilu02_boost_vali_range), testing::ValuesIn(bsrilu02_dir_range), testing::ValuesIn(bsrilu02_idxbase_range), testing::ValuesIn(bsrilu02_bin))); diff --git a/clients/tests/test_csrilu02.cpp b/clients/tests/test_csrilu02.cpp index 359d7800..1de99ae2 100644 --- a/clients/tests/test_csrilu02.cpp +++ b/clients/tests/test_csrilu02.cpp @@ -30,11 +30,15 @@ #include #include -typedef hipsparseIndexBase_t base; -typedef std::tuple csrilu02_tuple; -typedef std::tuple csrilu02_bin_tuple; +typedef hipsparseIndexBase_t base; +typedef std::tuple csrilu02_tuple; +typedef std::tuple csrilu02_bin_tuple; -int csrilu02_M_range[] = {-1, 0, 50, 647}; +int csrilu02_M_range[] = {-1, 0, 50, 647}; +int csrilu02_boost_range[] = {0, 1}; +double csrilu02_boost_tol_range[] = {0.5}; +double csrilu02_boost_val_range[] = {0.3, 2.0}; +double csrilu02_boost_vali_range[] = {0.2, 1.0}; base csrilu02_idxbase_range[] = {HIPSPARSE_INDEX_BASE_ZERO, HIPSPARSE_INDEX_BASE_ONE}; @@ -79,21 +83,29 @@ class parameterized_csrilu02_bin : public testing::TestWithParam(tup); - arg.idx_base = std::get<1>(tup); - arg.timing = 0; + arg.M = std::get<0>(tup); + arg.numericboost = std::get<1>(tup); + arg.boosttol = std::get<2>(tup); + arg.boostval = std::get<3>(tup); + arg.boostvali = std::get<4>(tup); + arg.idx_base = std::get<5>(tup); + arg.timing = 0; return arg; } Arguments setup_csrilu02_arguments(csrilu02_bin_tuple tup) { Arguments arg; - arg.M = -99; - arg.idx_base = std::get<0>(tup); - arg.timing = 0; + arg.M = -99; + arg.numericboost = std::get<0>(tup); + arg.boosttol = std::get<1>(tup); + arg.boostval = std::get<2>(tup); + arg.boostvali = std::get<3>(tup); + arg.idx_base = std::get<4>(tup); + arg.timing = 0; // Determine absolute path of test matrix - std::string bin_file = std::get<1>(tup); + std::string bin_file = std::get<5>(tup); // Get current executables absolute path char path_exe[PATH_MAX]; @@ -169,9 +181,17 @@ TEST_P(parameterized_csrilu02_bin, csrilu02_bin_double) INSTANTIATE_TEST_CASE_P(csrilu02, parameterized_csrilu02, testing::Combine(testing::ValuesIn(csrilu02_M_range), + testing::ValuesIn(csrilu02_boost_range), + testing::ValuesIn(csrilu02_boost_tol_range), + testing::ValuesIn(csrilu02_boost_val_range), + testing::ValuesIn(csrilu02_boost_vali_range), testing::ValuesIn(csrilu02_idxbase_range))); INSTANTIATE_TEST_CASE_P(csrilu02_bin, parameterized_csrilu02_bin, - testing::Combine(testing::ValuesIn(csrilu02_idxbase_range), + testing::Combine(testing::ValuesIn(csrilu02_boost_range), + testing::ValuesIn(csrilu02_boost_tol_range), + testing::ValuesIn(csrilu02_boost_val_range), + testing::ValuesIn(csrilu02_boost_vali_range), + testing::ValuesIn(csrilu02_idxbase_range), testing::ValuesIn(csrilu02_bin))); diff --git a/library/include/hipsparse.h b/library/include/hipsparse.h index d98267c7..cb1d720b 100644 --- a/library/include/hipsparse.h +++ b/library/include/hipsparse.h @@ -2540,6 +2540,31 @@ HIPSPARSE_EXPORT hipsparseStatus_t hipsparseXcsrilu02_zeroPivot(hipsparseHandle_t handle, csrilu02Info_t info, int* position); +HIPSPARSE_EXPORT +hipsparseStatus_t hipsparseScsrilu02_numericBoost( + hipsparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, float* boost_val); + +HIPSPARSE_EXPORT +hipsparseStatus_t hipsparseDcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + double* boost_val); + +HIPSPARSE_EXPORT +hipsparseStatus_t hipsparseCcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + hipComplex* boost_val); + +HIPSPARSE_EXPORT +hipsparseStatus_t hipsparseZcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + hipDoubleComplex* boost_val); + HIPSPARSE_EXPORT hipsparseStatus_t hipsparseScsrilu02_bufferSize(hipsparseHandle_t handle, int m, diff --git a/library/src/hcc_detail/hipsparse.cpp b/library/src/hcc_detail/hipsparse.cpp index a90ea8c7..0c59827f 100644 --- a/library/src/hcc_detail/hipsparse.cpp +++ b/library/src/hcc_detail/hipsparse.cpp @@ -65,6 +65,31 @@ extern "C" { } \ } +// Functions needed for hipsparse to match cuda API but which not part of rocsparse backend API +extern rocsparse_status rocsparse_dsbsrilu0_numeric_boost(rocsparse_handle handle, + rocsparse_mat_info info, + int enable_boost, + const double* boost_tol, + const float* boost_val); + +extern rocsparse_status rocsparse_dcbsrilu0_numeric_boost(rocsparse_handle handle, + rocsparse_mat_info info, + int enable_boost, + const double* boost_tol, + const rocsparse_float_complex* boost_val); + +extern rocsparse_status rocsparse_dscsrilu0_numeric_boost(rocsparse_handle handle, + rocsparse_mat_info info, + int enable_boost, + const double* boost_tol, + const float* boost_val); + +extern rocsparse_status rocsparse_dccsrilu0_numeric_boost(rocsparse_handle handle, + rocsparse_mat_info info, + int enable_boost, + const double* boost_tol, + const rocsparse_float_complex* boost_val); + hipsparseStatus_t hipErrorToHIPSPARSEStatus(hipError_t status) { switch(status) @@ -5062,6 +5087,48 @@ hipsparseStatus_t return HIPSPARSE_STATUS_SUCCESS; } +hipsparseStatus_t hipsparseSbsrilu02_numericBoost( + hipsparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double* tol, float* boost_val) +{ + return rocSPARSEStatusToHIPStatus(rocsparse_dsbsrilu0_numeric_boost( + (rocsparse_handle)handle, (rocsparse_mat_info)info, enable_boost, tol, boost_val)); +} + +hipsparseStatus_t hipsparseDbsrilu02_numericBoost( + hipsparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double* tol, double* boost_val) +{ + return rocSPARSEStatusToHIPStatus(rocsparse_dbsrilu0_numeric_boost( + (rocsparse_handle)handle, (rocsparse_mat_info)info, enable_boost, tol, boost_val)); +} + +hipsparseStatus_t hipsparseCbsrilu02_numericBoost(hipsparseHandle_t handle, + bsrilu02Info_t info, + int enable_boost, + double* tol, + hipComplex* boost_val) +{ + return rocSPARSEStatusToHIPStatus( + rocsparse_dcbsrilu0_numeric_boost((rocsparse_handle)handle, + (rocsparse_mat_info)info, + enable_boost, + tol, + (rocsparse_float_complex*)boost_val)); +} + +hipsparseStatus_t hipsparseZbsrilu02_numericBoost(hipsparseHandle_t handle, + bsrilu02Info_t info, + int enable_boost, + double* tol, + hipDoubleComplex* boost_val) +{ + return rocSPARSEStatusToHIPStatus( + rocsparse_zbsrilu0_numeric_boost((rocsparse_handle)handle, + (rocsparse_mat_info)info, + enable_boost, + tol, + (rocsparse_double_complex*)boost_val)); +} + hipsparseStatus_t hipsparseSbsrilu02_bufferSize(hipsparseHandle_t handle, hipsparseDirection_t dirA, int mb, @@ -5487,6 +5554,48 @@ hipsparseStatus_t return HIPSPARSE_STATUS_SUCCESS; } +hipsparseStatus_t hipsparseScsrilu02_numericBoost( + hipsparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, float* boost_val) +{ + return rocSPARSEStatusToHIPStatus(rocsparse_dscsrilu0_numeric_boost( + (rocsparse_handle)handle, (rocsparse_mat_info)info, enable_boost, tol, boost_val)); +} + +hipsparseStatus_t hipsparseDcsrilu02_numericBoost( + hipsparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, double* boost_val) +{ + return rocSPARSEStatusToHIPStatus(rocsparse_dcsrilu0_numeric_boost( + (rocsparse_handle)handle, (rocsparse_mat_info)info, enable_boost, tol, boost_val)); +} + +hipsparseStatus_t hipsparseCcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + hipComplex* boost_val) +{ + return rocSPARSEStatusToHIPStatus( + rocsparse_dccsrilu0_numeric_boost((rocsparse_handle)handle, + (rocsparse_mat_info)info, + enable_boost, + tol, + (rocsparse_float_complex*)boost_val)); +} + +hipsparseStatus_t hipsparseZcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + hipDoubleComplex* boost_val) +{ + return rocSPARSEStatusToHIPStatus( + rocsparse_zcsrilu0_numeric_boost((rocsparse_handle)handle, + (rocsparse_mat_info)info, + enable_boost, + tol, + (rocsparse_double_complex*)boost_val)); +} + hipsparseStatus_t hipsparseScsrilu02_bufferSize(hipsparseHandle_t handle, int m, int nnz, diff --git a/library/src/nvcc_detail/hipsparse.cpp b/library/src/nvcc_detail/hipsparse.cpp index b802f91b..e1469315 100644 --- a/library/src/nvcc_detail/hipsparse.cpp +++ b/library/src/nvcc_detail/hipsparse.cpp @@ -4770,6 +4770,44 @@ hipsparseStatus_t cusparseXcsrilu02_zeroPivot((cusparseHandle_t)handle, (csrilu02Info_t)info, position)); } +hipsparseStatus_t hipsparseScsrilu02_numericBoost( + hipsparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, float* boost_val) +{ + return hipCUSPARSEStatusToHIPStatus(cusparseScsrilu02_numericBoost( + (cusparseHandle_t)handle, (csrilu02Info_t)info, enable_boost, tol, boost_val)); +} + +hipsparseStatus_t hipsparseDcsrilu02_numericBoost( + hipsparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, double* boost_val) +{ + return hipCUSPARSEStatusToHIPStatus(cusparseDcsrilu02_numericBoost( + (cusparseHandle_t)handle, (csrilu02Info_t)info, enable_boost, tol, boost_val)); +} + +hipsparseStatus_t hipsparseCcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + hipComplex* boost_val) +{ + return hipCUSPARSEStatusToHIPStatus(cusparseCcsrilu02_numericBoost( + (cusparseHandle_t)handle, (csrilu02Info_t)info, enable_boost, tol, (cuComplex*)boost_val)); +} + +hipsparseStatus_t hipsparseZcsrilu02_numericBoost(hipsparseHandle_t handle, + csrilu02Info_t info, + int enable_boost, + double* tol, + hipDoubleComplex* boost_val) +{ + return hipCUSPARSEStatusToHIPStatus( + cusparseZcsrilu02_numericBoost((cusparseHandle_t)handle, + (csrilu02Info_t)info, + enable_boost, + tol, + (cuDoubleComplex*)boost_val)); +} + hipsparseStatus_t hipsparseScsrilu02_bufferSize(hipsparseHandle_t handle, int m, int nnz,