From 2d997a991cf38f1b4f838b213edb4b8720b34b58 Mon Sep 17 00:00:00 2001 From: malphil Date: Wed, 17 Apr 2024 14:06:28 -0600 Subject: [PATCH] Early return in kokkos kernels sparse trsv in the event numRows == 0 --- .../sparse/impl/KokkosSparse_trsv_impl.hpp | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/packages/kokkos-kernels/sparse/impl/KokkosSparse_trsv_impl.hpp b/packages/kokkos-kernels/sparse/impl/KokkosSparse_trsv_impl.hpp index 9adb029d12ab..6d43c55dc6d9 100644 --- a/packages/kokkos-kernels/sparse/impl/KokkosSparse_trsv_impl.hpp +++ b/packages/kokkos-kernels/sparse/impl/KokkosSparse_trsv_impl.hpp @@ -185,6 +185,8 @@ struct TrsvWrap { const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; const lno_t numVecs = X.extent(1); @@ -212,6 +214,8 @@ struct TrsvWrap { static void lowerTriSolveCsr(RangeMultiVectorType X, const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; const lno_t numVecs = X.extent(1); @@ -255,6 +259,8 @@ struct TrsvWrap { const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; const lno_t numVecs = X.extent(1); @@ -305,6 +311,8 @@ struct TrsvWrap { static void upperTriSolveCsr(RangeMultiVectorType X, const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; const lno_t numVecs = X.extent(1); @@ -372,6 +380,8 @@ struct TrsvWrap { const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numCols = A.numCols(); const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; @@ -423,6 +433,8 @@ struct TrsvWrap { static void upperTriSolveCsc(RangeMultiVectorType X, const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numCols = A.numCols(); const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; @@ -482,6 +494,8 @@ struct TrsvWrap { const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numCols = A.numCols(); const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; @@ -511,6 +525,8 @@ struct TrsvWrap { const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numCols = A.numCols(); const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; @@ -563,6 +579,8 @@ struct TrsvWrap { const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numCols = A.numCols(); const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; @@ -621,6 +639,8 @@ struct TrsvWrap { static void lowerTriSolveCsc(RangeMultiVectorType X, const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numCols = A.numCols(); const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; @@ -658,6 +678,8 @@ struct TrsvWrap { const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numCols = A.numCols(); const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows; @@ -687,6 +709,8 @@ struct TrsvWrap { const CrsMatrixType& A, DomainMultiVectorType Y) { const lno_t numRows = A.numRows(); + if(numRows == 0) return; + const lno_t numCols = A.numCols(); const lno_t numPointRows = A.numPointRows(); const lno_t block_size = numPointRows / numRows;