Skip to content

Commit

Permalink
Csrilu numeric boost (#139)
Browse files Browse the repository at this point in the history
* completing implementation of csrilu0_numericBoost in hipsparse

* clang formatting

* adding missing bsrilu numeric boost calls

* adding testing code for numeric boost

* clang formatting

* finishing hipsparse code for numeric boost

* clang formatting

* removing comments

* trigger CI
  • Loading branch information
jsandham authored Sep 23, 2020
1 parent c5e4633 commit 39bdb97
Show file tree
Hide file tree
Showing 11 changed files with 566 additions and 104 deletions.
80 changes: 80 additions & 0 deletions clients/common/hipsparse_template_specialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3469,6 +3469,46 @@ namespace hipsparse
}
#endif

template <>
hipsparseStatus_t hipsparseXbsrilu02_numericBoost(hipsparseHandle_t handle,
bsrilu02Info_t info,
int enable_boost,
double* tol,
float* boost_val)
{
return hipsparseSbsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val);
}

template <>
hipsparseStatus_t hipsparseXbsrilu02_numericBoost(hipsparseHandle_t handle,
bsrilu02Info_t info,
int enable_boost,
double* tol,
double* boost_val)
{
return hipsparseDbsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val);
}

template <>
hipsparseStatus_t hipsparseXbsrilu02_numericBoost(hipsparseHandle_t handle,
bsrilu02Info_t info,
int enable_boost,
double* tol,
hipComplex* boost_val)
{
return hipsparseCbsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val);
}

template <>
hipsparseStatus_t hipsparseXbsrilu02_numericBoost(hipsparseHandle_t handle,
bsrilu02Info_t info,
int enable_boost,
double* tol,
hipDoubleComplex* boost_val)
{
return hipsparseZbsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val);
}

template <>
hipsparseStatus_t hipsparseXbsrilu02_bufferSize(hipsparseHandle_t handle,
hipsparseDirection_t dirA,
Expand Down Expand Up @@ -3797,6 +3837,46 @@ namespace hipsparse
pBuffer);
}

template <>
hipsparseStatus_t hipsparseXcsrilu02_numericBoost(hipsparseHandle_t handle,
csrilu02Info_t info,
int enable_boost,
double* tol,
float* boost_val)
{
return hipsparseScsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val);
}

template <>
hipsparseStatus_t hipsparseXcsrilu02_numericBoost(hipsparseHandle_t handle,
csrilu02Info_t info,
int enable_boost,
double* tol,
double* boost_val)
{
return hipsparseDcsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val);
}

template <>
hipsparseStatus_t hipsparseXcsrilu02_numericBoost(hipsparseHandle_t handle,
csrilu02Info_t info,
int enable_boost,
double* tol,
hipComplex* boost_val)
{
return hipsparseCcsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val);
}

template <>
hipsparseStatus_t hipsparseXcsrilu02_numericBoost(hipsparseHandle_t handle,
csrilu02Info_t info,
int enable_boost,
double* tol,
hipDoubleComplex* boost_val)
{
return hipsparseZcsrilu02_numericBoost(handle, info, enable_boost, tol, boost_val);
}

template <>
hipsparseStatus_t hipsparseXcsrilu02_bufferSize(hipsparseHandle_t handle,
int m,
Expand Down
8 changes: 8 additions & 0 deletions clients/include/hipsparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,10 @@ namespace hipsparse
void* pBuffer);
#endif

template <typename T>
hipsparseStatus_t hipsparseXbsrilu02_numericBoost(
hipsparseHandle_t handle, bsrilu02Info_t info, int enable_boost, double* tol, T* boost_val);

template <typename T>
hipsparseStatus_t hipsparseXbsrilu02_bufferSize(hipsparseHandle_t handle,
hipsparseDirection_t dirA,
Expand Down Expand Up @@ -564,6 +568,10 @@ namespace hipsparse
hipsparseSolvePolicy_t policy,
void* pBuffer);

template <typename T>
hipsparseStatus_t hipsparseXcsrilu02_numericBoost(
hipsparseHandle_t handle, csrilu02Info_t info, int enable_boost, double* tol, T* boost_val);

template <typename T>
hipsparseStatus_t hipsparseXcsrilu02_bufferSize(hipsparseHandle_t handle,
int m,
Expand Down
88 changes: 71 additions & 17 deletions clients/include/testing_bsrilu02.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

#include <cmath>
#include <hipsparse.h>
#include <iostream>
#include <string>

using namespace hipsparse;
Expand Down Expand Up @@ -67,13 +66,17 @@ void testing_bsrilu02_bad_arg(void)
auto dval_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * safe_size), device_free};
auto dbuffer_managed
= hipsparse_unique_ptr{device_malloc(sizeof(char) * safe_size), device_free};
auto dboost_tol_managed = hipsparse_unique_ptr{device_malloc(sizeof(double)), device_free};
auto dboost_val_managed = hipsparse_unique_ptr{device_malloc(sizeof(T)), device_free};

int* dptr = (int*)dptr_managed.get();
int* dcol = (int*)dcol_managed.get();
T* dval = (T*)dval_managed.get();
void* dbuffer = (void*)dbuffer_managed.get();
int* dptr = (int*)dptr_managed.get();
int* dcol = (int*)dcol_managed.get();
T* dval = (T*)dval_managed.get();
void* dbuffer = (void*)dbuffer_managed.get();
double* dboost_tol = (double*)dboost_tol_managed.get();
T* dboost_val = (T*)dboost_val_managed.get();

if(!dval || !dptr || !dcol || !dbuffer)
if(!dval || !dptr || !dcol || !dbuffer || !dboost_tol || !dboost_val)
{
PRINT_IF_HIP_ERROR(hipErrorOutOfMemory);
return;
Expand Down Expand Up @@ -139,6 +142,40 @@ void testing_bsrilu02_bad_arg(void)
verify_hipsparse_status_invalid_handle(status);
}

// testing hipsparseXbsrilu02_numericBoost

// testing for(nullptr == handle)
{
hipsparseHandle_t handle_null = nullptr;

status = hipsparseXbsrilu02_numericBoost(handle_null, info, 1, dboost_tol, dboost_val);
verify_hipsparse_status_invalid_handle(status);
}

// testing for(nullptr == info)
{
bsrilu02Info_t info_null = nullptr;

status = hipsparseXbsrilu02_numericBoost(handle, info_null, 1, dboost_tol, dboost_val);
verify_hipsparse_status_invalid_pointer(status, "Error: info is nullptr");
}

// testing for(nullptr == dboost_tol)
{
double* boost_tol_null = nullptr;

status = hipsparseXbsrilu02_numericBoost(handle, info, 1, boost_tol_null, dboost_val);
verify_hipsparse_status_invalid_pointer(status, "Error: boost_tol is nullptr");
}

// testing for(nullptr == dboost_val)
{
T* boost_val_null = nullptr;

status = hipsparseXbsrilu02_numericBoost(handle, info, 1, dboost_tol, boost_val_null);
verify_hipsparse_status_invalid_pointer(status, "Error: boost_val is nullptr");
}

// testing hipsparseXbsrilu02_analysis

// testing for(nullptr == dptr)
Expand Down Expand Up @@ -289,6 +326,9 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus)
int safe_size = 100;
int m = argus.M;
int block_dim = argus.block_dim;
int boost = argus.numericboost;
double boost_tol = argus.boosttol;
T boost_val = make_DataType<T>(argus.boostval, argus.boostvali);
hipsparseDirection_t dir = argus.dirA;
hipsparseIndexBase_t idx_base = argus.idx_base;
hipsparseSolvePolicy_t policy = HIPSPARSE_SOLVE_POLICY_USE_LEVEL;
Expand Down Expand Up @@ -476,17 +516,21 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus)
auto dcsr_val_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * nnz), device_free};
auto dbsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (mb + 1)), device_free};
auto boost_tol_managed = hipsparse_unique_ptr{device_malloc(sizeof(double)), device_free};
auto boost_val_managed = hipsparse_unique_ptr{device_malloc(sizeof(T)), device_free};

int* dcsr_row_ptr = (int*)dcsr_row_ptr_managed.get();
int* dcsr_col_ind = (int*)dcsr_col_ind_managed.get();
T* dcsr_val = (T*)dcsr_val_managed.get();
int* dbsr_row_ptr = (int*)dbsr_row_ptr_managed.get();
int* dcsr_row_ptr = (int*)dcsr_row_ptr_managed.get();
int* dcsr_col_ind = (int*)dcsr_col_ind_managed.get();
T* dcsr_val = (T*)dcsr_val_managed.get();
int* dbsr_row_ptr = (int*)dbsr_row_ptr_managed.get();
double* dboost_tol = (double*)boost_tol_managed.get();
T* dboost_val = (T*)boost_val_managed.get();

if(!dcsr_val || !dcsr_row_ptr || !dcsr_col_ind || !dbsr_row_ptr)
if(!dcsr_val || !dcsr_row_ptr || !dcsr_col_ind || !dbsr_row_ptr || !dboost_tol || !dboost_val)
{
verify_hipsparse_status_success(
HIPSPARSE_STATUS_ALLOC_FAILED,
"!dcsr_val || !dcsr_row_ptr || !dcsr_col_ind || !dbsr_row_ptr");
verify_hipsparse_status_success(HIPSPARSE_STATUS_ALLOC_FAILED,
"!dcsr_val || !dcsr_row_ptr || !dcsr_col_ind || "
"!dbsr_row_ptr || !dboost_tol || !dboost_val");
return HIPSPARSE_STATUS_ALLOC_FAILED;
}

Expand Down Expand Up @@ -604,6 +648,9 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus)
hipsparseStatus_t status_solve_1;
hipsparseStatus_t status_solve_2;

CHECK_HIP_ERROR(hipMemcpy(dboost_tol, &boost_tol, sizeof(double), hipMemcpyHostToDevice));
CHECK_HIP_ERROR(hipMemcpy(dboost_val, &boost_val, sizeof(T), hipMemcpyHostToDevice));

// bsrilu02 analysis - host mode
CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST));
CHECK_HIPSPARSE_ERROR(hipsparseXbsrilu02_analysis(handle,
Expand Down Expand Up @@ -652,6 +699,8 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus)

// bsrilu02 solve - host mode
CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST));
CHECK_HIPSPARSE_ERROR(
hipsparseXbsrilu02_numericBoost(handle, info, boost, &boost_tol, &boost_val));
CHECK_HIPSPARSE_ERROR(hipsparseXbsrilu02(handle,
dir,
mb,
Expand All @@ -675,6 +724,8 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus)

// bsrilu02 solve - device mode
CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_DEVICE));
CHECK_HIPSPARSE_ERROR(
hipsparseXbsrilu02_numericBoost(handle, info, boost, dboost_tol, dboost_val));
CHECK_HIPSPARSE_ERROR(hipsparseXbsrilu02(handle,
dir,
mb,
Expand Down Expand Up @@ -713,7 +764,7 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus)
CHECK_HIP_ERROR(
hipMemcpy(&h_solve_pivot_2, d_solve_pivot_2, sizeof(int), hipMemcpyDeviceToHost));

// Host csrilu02
// Host bsrilu02
double cpu_time_used = get_time_us();

int numerical_pivot;
Expand All @@ -726,7 +777,10 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus)
hbsr_val,
idx_base,
&structural_pivot,
&numerical_pivot);
&numerical_pivot,
boost,
boost_tol,
boost_val);

cpu_time_used = get_time_us() - cpu_time_used;

Expand Down Expand Up @@ -761,4 +815,4 @@ hipsparseStatus_t testing_bsrilu02(Arguments argus)
return HIPSPARSE_STATUS_SUCCESS;
}

#endif // TESTING_BSRILU02_HPP
#endif // TESTING_BSRILU02_HPP
Loading

0 comments on commit 39bdb97

Please sign in to comment.