Skip to content

Commit

Permalink
Revert "reverted 5.6 api change (#371)" (#375)
Browse files Browse the repository at this point in the history
This reverts commit 88d58a8.
  • Loading branch information
YvanMokwinski authored Aug 11, 2023
1 parent 88d58a8 commit b480dd9
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 43 deletions.
4 changes: 0 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# Change Log for hipSPARSE

## hipSPARSE 2.3.7 for ROCm 5.6.1
### Bugfix
- Reverted an undocumented API change in hipSPARSE 2.3.6 that affected hipsparseSpSV_solve function

## hipSPARSE 2.3.6 for ROCm 5.6.0
### Added
- Added SpGEMM algorithms
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ endif()


# Setup version
rocm_setup_version(VERSION 2.3.7)
rocm_setup_version(VERSION 2.3.6)
set(hipsparse_SOVERSION 0.1)

# hipSPARSE library
Expand Down
18 changes: 9 additions & 9 deletions clients/include/testing_spsv_coo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,23 @@ void testing_spsv_coo_bad_arg(void)

// SpSV solve
verify_hipsparse_status_invalid_handle(
hipsparseSpSV_solve(nullptr, transA, &alpha, A, x, y, dataType, alg, descr, dbuf));
hipsparseSpSV_solve(nullptr, transA, &alpha, A, x, y, dataType, alg, descr));
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, nullptr, A, x, y, dataType, alg, descr, dbuf),
hipsparseSpSV_solve(handle, transA, nullptr, A, x, y, dataType, alg, descr),
"Error: alpha is nullptr");
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, &alpha, nullptr, x, y, dataType, alg, descr, dbuf),
hipsparseSpSV_solve(handle, transA, &alpha, nullptr, x, y, dataType, alg, descr),
"Error: A is nullptr");
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, &alpha, A, nullptr, y, dataType, alg, descr, dbuf),
hipsparseSpSV_solve(handle, transA, &alpha, A, nullptr, y, dataType, alg, descr),
"Error: x is nullptr");
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, &alpha, A, x, nullptr, dataType, alg, descr, dbuf),
hipsparseSpSV_solve(handle, transA, &alpha, A, x, nullptr, dataType, alg, descr),
"Error: y is nullptr");
#if(!defined(CUDART_VERSION))
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, &alpha, A, x, y, dataType, alg, descr, nullptr),
"Error: dbuf is nullptr");
hipsparseSpSV_solve(handle, transA, &alpha, A, x, y, dataType, alg, nullptr),
"Error: descr is nullptr");
#endif

// Destruct
Expand Down Expand Up @@ -307,12 +307,12 @@ hipsparseStatus_t testing_spsv_coo(void)
// HIPSPARSE pointer mode host
CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST));
CHECK_HIPSPARSE_ERROR(
hipsparseSpSV_solve(handle, transA, &h_alpha, A, x, y1, typeT, alg, descr, buffer));
hipsparseSpSV_solve(handle, transA, &h_alpha, A, x, y1, typeT, alg, descr));

// HIPSPARSE pointer mode device
CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_DEVICE));
CHECK_HIPSPARSE_ERROR(
hipsparseSpSV_solve(handle, transA, d_alpha, A, x, y2, typeT, alg, descr, buffer));
hipsparseSpSV_solve(handle, transA, d_alpha, A, x, y2, typeT, alg, descr));

// copy output from device to CPU
CHECK_HIP_ERROR(hipMemcpy(hy_1.data(), dy_1, sizeof(T) * m, hipMemcpyDeviceToHost));
Expand Down
18 changes: 9 additions & 9 deletions clients/include/testing_spsv_csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,23 +139,23 @@ void testing_spsv_csr_bad_arg(void)

// SpSV solve
verify_hipsparse_status_invalid_handle(
hipsparseSpSV_solve(nullptr, transA, &alpha, A, x, y, dataType, alg, descr, dbuf));
hipsparseSpSV_solve(nullptr, transA, &alpha, A, x, y, dataType, alg, descr));
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, nullptr, A, x, y, dataType, alg, descr, dbuf),
hipsparseSpSV_solve(handle, transA, nullptr, A, x, y, dataType, alg, descr),
"Error: alpha is nullptr");
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, &alpha, nullptr, x, y, dataType, alg, descr, dbuf),
hipsparseSpSV_solve(handle, transA, &alpha, nullptr, x, y, dataType, alg, descr),
"Error: A is nullptr");
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, &alpha, A, nullptr, y, dataType, alg, descr, dbuf),
hipsparseSpSV_solve(handle, transA, &alpha, A, nullptr, y, dataType, alg, descr),
"Error: x is nullptr");
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, &alpha, A, x, nullptr, dataType, alg, descr, dbuf),
hipsparseSpSV_solve(handle, transA, &alpha, A, x, nullptr, dataType, alg, descr),
"Error: y is nullptr");
#if(!defined(CUDART_VERSION))
verify_hipsparse_status_invalid_pointer(
hipsparseSpSV_solve(handle, transA, &alpha, A, x, y, dataType, alg, descr, nullptr),
"Error: dbuf is nullptr");
hipsparseSpSV_solve(handle, transA, &alpha, A, x, y, dataType, alg, nullptr),
"Error: descr is nullptr");
#endif

// Destruct
Expand Down Expand Up @@ -300,12 +300,12 @@ hipsparseStatus_t testing_spsv_csr(void)
// HIPSPARSE pointer mode host
CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_HOST));
CHECK_HIPSPARSE_ERROR(
hipsparseSpSV_solve(handle, transA, &h_alpha, A, x, y1, typeT, alg, descr, buffer));
hipsparseSpSV_solve(handle, transA, &h_alpha, A, x, y1, typeT, alg, descr));

// HIPSPARSE pointer mode device
CHECK_HIPSPARSE_ERROR(hipsparseSetPointerMode(handle, HIPSPARSE_POINTER_MODE_DEVICE));
CHECK_HIPSPARSE_ERROR(
hipsparseSpSV_solve(handle, transA, d_alpha, A, x, y2, typeT, alg, descr, buffer));
hipsparseSpSV_solve(handle, transA, d_alpha, A, x, y2, typeT, alg, descr));

// copy output from device to CPU
CHECK_HIP_ERROR(hipMemcpy(hy_1.data(), dy_1, sizeof(T) * m, hipMemcpyDeviceToHost));
Expand Down
3 changes: 1 addition & 2 deletions library/include/hipsparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -10394,8 +10394,7 @@ hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle,
const hipsparseDnVecDescr_t y,
hipDataType computeType,
hipsparseSpSVAlg_t alg,
hipsparseSpSVDescr_t spsvDescr,
void* externalBuffer);
hipsparseSpSVDescr_t spsvDescr);
#endif

/*! \ingroup generic_module
Expand Down
52 changes: 36 additions & 16 deletions library/src/hcc_detail/hipsparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14259,15 +14259,25 @@ hipsparseStatus_t hipsparseSDDMM_preprocess(hipsparseHandle_t handle,
tempBuffer));
}

struct hipsparseSpSVDescr
{
void* externalBuffer{};
};

hipsparseStatus_t hipsparseSpSV_createDescr(hipsparseSpSVDescr_t* descr)
{
// Do nothing
*descr = new hipsparseSpSVDescr;
return HIPSPARSE_STATUS_SUCCESS;
}

hipsparseStatus_t hipsparseSpSV_destroyDescr(hipsparseSpSVDescr_t descr)
{
// Do nothing
if(descr != nullptr)
{
descr->externalBuffer = nullptr;
delete descr;
}

return HIPSPARSE_STATUS_SUCCESS;
}

Expand Down Expand Up @@ -14306,17 +14316,24 @@ hipsparseStatus_t hipsparseSpSV_analysis(hipsparseHandle_t handle,
hipsparseSpSVDescr_t spsvDescr,
void* externalBuffer)
{
return rocSPARSEStatusToHIPStatus(rocsparse_spsv((rocsparse_handle)handle,
hipOperationToHCCOperation(opA),
alpha,
(const rocsparse_spmat_descr)matA,
(const rocsparse_dnvec_descr)x,
(const rocsparse_dnvec_descr)y,
hipDataTypeToHCCDataType(computeType),
hipSpSVAlgToHCCSpSVAlg(alg),
rocsparse_spsv_stage_preprocess,
nullptr,
externalBuffer));

if(spsvDescr == nullptr)
{
return HIPSPARSE_STATUS_INVALID_VALUE;
}
RETURN_IF_ROCSPARSE_ERROR(rocsparse_spsv((rocsparse_handle)handle,
hipOperationToHCCOperation(opA),
alpha,
(const rocsparse_spmat_descr)matA,
(const rocsparse_dnvec_descr)x,
(const rocsparse_dnvec_descr)y,
hipDataTypeToHCCDataType(computeType),
hipSpSVAlgToHCCSpSVAlg(alg),
rocsparse_spsv_stage_preprocess,
nullptr,
externalBuffer));
spsvDescr->externalBuffer = externalBuffer;
return HIPSPARSE_STATUS_SUCCESS;
}

hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle,
Expand All @@ -14327,9 +14344,12 @@ hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle,
const hipsparseDnVecDescr_t y,
hipDataType computeType,
hipsparseSpSVAlg_t alg,
hipsparseSpSVDescr_t spsvDescr,
void* externalBuffer)
hipsparseSpSVDescr_t spsvDescr)
{
if(spsvDescr == nullptr)
{
return HIPSPARSE_STATUS_INVALID_VALUE;
}
return rocSPARSEStatusToHIPStatus(rocsparse_spsv((rocsparse_handle)handle,
hipOperationToHCCOperation(opA),
alpha,
Expand All @@ -14340,7 +14360,7 @@ hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle,
hipSpSVAlgToHCCSpSVAlg(alg),
rocsparse_spsv_stage_compute,
nullptr,
externalBuffer));
spsvDescr->externalBuffer));
}

hipsparseStatus_t hipsparseSpSM_createDescr(hipsparseSpSMDescr_t* descr)
Expand Down
3 changes: 1 addition & 2 deletions library/src/nvcc_detail/hipsparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12102,8 +12102,7 @@ hipsparseStatus_t hipsparseSpSV_solve(hipsparseHandle_t handle,
const hipsparseDnVecDescr_t y,
hipDataType computeType,
hipsparseSpSVAlg_t alg,
hipsparseSpSVDescr_t spsvDescr,
void* externalBuffer)
hipsparseSpSVDescr_t spsvDescr)
{
return hipCUSPARSEStatusToHIPStatus(cusparseSpSV_solve((cusparseHandle_t)handle,
hipOperationToCudaOperation(opA),
Expand Down

0 comments on commit b480dd9

Please sign in to comment.