Skip to content

Commit

Permalink
Early return in Kokkos Kernels sparse trsv when numRows == 0
Browse files Browse the repository at this point in the history
  • Loading branch information
malphil committed Apr 17, 2024
1 parent 23ccc58 commit 2d997a9
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions packages/kokkos-kernels/sparse/impl/KokkosSparse_trsv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ struct TrsvWrap {
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
const lno_t numVecs = X.extent(1);
Expand Down Expand Up @@ -212,6 +214,8 @@ struct TrsvWrap {
static void lowerTriSolveCsr(RangeMultiVectorType X, const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
const lno_t numVecs = X.extent(1);
Expand Down Expand Up @@ -255,6 +259,8 @@ struct TrsvWrap {
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
const lno_t numVecs = X.extent(1);
Expand Down Expand Up @@ -305,6 +311,8 @@ struct TrsvWrap {
static void upperTriSolveCsr(RangeMultiVectorType X, const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
const lno_t numVecs = X.extent(1);
Expand Down Expand Up @@ -372,6 +380,8 @@ struct TrsvWrap {
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -423,6 +433,8 @@ struct TrsvWrap {
static void upperTriSolveCsc(RangeMultiVectorType X, const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -482,6 +494,8 @@ struct TrsvWrap {
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -511,6 +525,8 @@ struct TrsvWrap {
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -563,6 +579,8 @@ struct TrsvWrap {
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -621,6 +639,8 @@ struct TrsvWrap {
static void lowerTriSolveCsc(RangeMultiVectorType X, const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -658,6 +678,8 @@ struct TrsvWrap {
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down Expand Up @@ -687,6 +709,8 @@ struct TrsvWrap {
const CrsMatrixType& A,
DomainMultiVectorType Y) {
const lno_t numRows = A.numRows();
if(numRows == 0) return;

const lno_t numCols = A.numCols();
const lno_t numPointRows = A.numPointRows();
const lno_t block_size = numPointRows / numRows;
Expand Down

0 comments on commit 2d997a9

Please sign in to comment.