Skip to content

Commit

Permalink
safe size of csrmv, csr2bsr, bsr2csr, gebsr2csr and csr2gebsr test mu…
Browse files Browse the repository at this point in the history
…st be at least max(100, max(nrow, ncol)) (#312)
  • Loading branch information
ntrost57 authored Feb 27, 2023
1 parent 7a49d87 commit fde9c7d
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ endif()


# Setup version
rocm_setup_version(VERSION 2.3.5)
rocm_setup_version(VERSION 2.3.6)
set(hipsparse_SOVERSION 0.1)

# hipSPARSE library
Expand Down
9 changes: 6 additions & 3 deletions clients/include/testing_bsr2csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ hipsparseStatus_t testing_bsr2csr(Arguments argus)

// When in testing mode, M == N == -99 indicates that we are testing with a real
// matrix from cise.ufl.edu
int safe_size = 100;
int safe_size = std::max(100, std::max(m, n));
if(m == -99 && n == -99 && argus.timing == 0)
{
binfile = argus.filename;
Expand Down Expand Up @@ -330,13 +330,13 @@ hipsparseStatus_t testing_bsr2csr(Arguments argus)
if(mb <= 0 || nb <= 0 || block_dim <= 0)
{
auto dbsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dbsr_col_ind_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dbsr_val_managed
= hipsparse_unique_ptr{device_malloc(sizeof(T) * safe_size), device_free};
auto dcsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dcsr_col_ind_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dcsr_val_managed
Expand All @@ -349,6 +349,9 @@ hipsparseStatus_t testing_bsr2csr(Arguments argus)
int* dcsr_col_ind = (int*)dcsr_col_ind_managed.get();
T* dcsr_val = (T*)dcsr_val_managed.get();

// row pointer array must be valid
CHECK_HIP_ERROR(hipMemset(dbsr_row_ptr, 0, sizeof(int) * (safe_size + 1)));

if(!dbsr_row_ptr || !dbsr_col_ind || !dbsr_val || !dcsr_row_ptr || !dcsr_col_ind
|| !dcsr_val)
{
Expand Down
9 changes: 6 additions & 3 deletions clients/include/testing_csr2bsr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ hipsparseStatus_t testing_csr2bsr(Arguments argus)

// When in testing mode, M == N == -99 indicates that we are testing with a real
// matrix from cise.ufl.edu
int safe_size = 100;
int safe_size = std::max(100, std::max(m, n));
if(m == -99 && n == -99 && argus.timing == 0)
{
binfile = argus.filename;
Expand Down Expand Up @@ -450,13 +450,13 @@ hipsparseStatus_t testing_csr2bsr(Arguments argus)
return HIPSPARSE_STATUS_SUCCESS;
#endif
auto dcsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dcsr_col_ind_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dcsr_val_managed
= hipsparse_unique_ptr{device_malloc(sizeof(T) * safe_size), device_free};
auto dbsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dbsr_col_ind_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dbsr_val_managed
Expand All @@ -469,6 +469,9 @@ hipsparseStatus_t testing_csr2bsr(Arguments argus)
int* dbsr_col_ind = (int*)dbsr_col_ind_managed.get();
T* dbsr_val = (T*)dbsr_val_managed.get();

// row pointer need to be valid
CHECK_HIP_ERROR(hipMemset(dcsr_row_ptr, 0, sizeof(int) * (safe_size + 1)));

if(!dcsr_row_ptr || !dcsr_col_ind || !dcsr_val || !dbsr_row_ptr || !dbsr_col_ind
|| !dbsr_val)
{
Expand Down
9 changes: 6 additions & 3 deletions clients/include/testing_csr2gebsr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ hipsparseStatus_t testing_csr2gebsr(Arguments argus)

// When in testing mode, M == N == -99 indicates that we are testing with a real
// matrix from cise.ufl.edu
int safe_size = 100;
int safe_size = std::max(100, std::max(m, n));
if(m == -99 && n == -99 && argus.timing == 0)
{
binfile = argus.filename;
Expand Down Expand Up @@ -488,13 +488,13 @@ hipsparseStatus_t testing_csr2gebsr(Arguments argus)
return HIPSPARSE_STATUS_SUCCESS;
#endif
auto dcsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dcsr_col_ind_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dcsr_val_managed
= hipsparse_unique_ptr{device_malloc(sizeof(T) * safe_size), device_free};
auto dbsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dbsr_col_ind_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dbsr_val_managed
Expand All @@ -510,6 +510,9 @@ hipsparseStatus_t testing_csr2gebsr(Arguments argus)
T* dbsr_val = (T*)dbsr_val_managed.get();
void* dbuffer = dbuffer_managed.get();

// row pointer must be valid
CHECK_HIP_ERROR(hipMemset(dcsr_row_ptr, 0, sizeof(int) * (safe_size + 1)));

if(!dcsr_row_ptr || !dcsr_col_ind || !dcsr_val || !dbsr_row_ptr || !dbsr_col_ind
|| !dbsr_val || !dbuffer)
{
Expand Down
7 changes: 5 additions & 2 deletions clients/include/testing_csrmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ void testing_csrmv_bad_arg(void)
template <typename T>
hipsparseStatus_t testing_csrmv(Arguments argus)
{
int safe_size = 100;
int nrow = argus.M;
int ncol = argus.N;
int safe_size = std::max(100, std::max(nrow, ncol));
T h_alpha = make_DataType<T>(argus.alpha);
T h_beta = make_DataType<T>(argus.beta);
hipsparseOperation_t transA = argus.transA;
Expand Down Expand Up @@ -203,7 +203,7 @@ hipsparseStatus_t testing_csrmv(Arguments argus)
return HIPSPARSE_STATUS_SUCCESS;
#endif
auto dptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dcol_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dval_managed = hipsparse_unique_ptr{device_malloc(sizeof(T) * safe_size), device_free};
Expand All @@ -216,6 +216,9 @@ hipsparseStatus_t testing_csrmv(Arguments argus)
T* dx = (T*)dx_managed.get();
T* dy = (T*)dy_managed.get();

// row pointer should be valid
CHECK_HIP_ERROR(hipMemset(dptr, 0, sizeof(int) * (safe_size + 1)));

if(!dval || !dptr || !dcol || !dx || !dy)
{
verify_hipsparse_status_success(HIPSPARSE_STATUS_ALLOC_FAILED,
Expand Down
9 changes: 6 additions & 3 deletions clients/include/testing_gebsr2csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ hipsparseStatus_t testing_gebsr2csr(Arguments argus)

// When in testing mode, M == N == -99 indicates that we are testing with a real
// matrix from cise.ufl.edu
int safe_size = 100;
int safe_size = std::max(100, std::max(m, n));
if(m == -99 && n == -99 && argus.timing == 0)
{
binfile = argus.filename;
Expand Down Expand Up @@ -355,13 +355,13 @@ hipsparseStatus_t testing_gebsr2csr(Arguments argus)
if(mb <= 0 || nb <= 0 || row_block_dim <= 0 || col_block_dim <= 0)
{
auto dbsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dbsr_col_ind_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dbsr_val_managed
= hipsparse_unique_ptr{device_malloc(sizeof(T) * safe_size), device_free};
auto dcsr_row_ptr_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
= hipsparse_unique_ptr{device_malloc(sizeof(int) * (safe_size + 1)), device_free};
auto dcsr_col_ind_managed
= hipsparse_unique_ptr{device_malloc(sizeof(int) * safe_size), device_free};
auto dcsr_val_managed
Expand All @@ -374,6 +374,9 @@ hipsparseStatus_t testing_gebsr2csr(Arguments argus)
int* dcsr_col_ind = (int*)dcsr_col_ind_managed.get();
T* dcsr_val = (T*)dcsr_val_managed.get();

// row pointer must be valid
CHECK_HIP_ERROR(hipMemset(dbsr_row_ptr, 0, sizeof(int) * (safe_size + 1)));

if(!dbsr_row_ptr || !dbsr_col_ind || !dbsr_val || !dcsr_row_ptr || !dcsr_col_ind
|| !dcsr_val)
{
Expand Down

0 comments on commit fde9c7d

Please sign in to comment.