Skip to content

Commit

Permalink
Merge pull request #12688 from ndellingwood/tpetra-spadd-api-update
Browse files Browse the repository at this point in the history
tpetra: update spadd_* api for KOKKOSKERNELS_VERSION >= 40299
  • Loading branch information
ndellingwood authored Jan 29, 2024
2 parents 7536a9e + 2946371 commit 4ac842a
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
6 changes: 5 additions & 1 deletion packages/tpetra/core/ext/TpetraExt_MatrixMatrix_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ struct AddKernels
/// \param Browptrs Row pointers array for B
/// \param Bcolinds Column indices array for B
/// \param scalarB Scaling factor for B
/// \param numGlobalCols The global size of the column map
/// \param[Out] Cvals Values array for C (allocated inside function)
/// \param[Out] Crowptrs Row pointers array for C (allocated inside function)
/// \param[Out] Ccolinds Column indices array for C (allocated inside function)
Expand All @@ -692,6 +693,9 @@ struct AddKernels
const row_ptrs_array_const& Browptrs,
const col_inds_array& Bcolinds,
const impl_scalar_type scalarB,
#if KOKKOSKERNELS_VERSION >= 40299
GlobalOrdinal numGlobalCols,
#endif
values_array& Cvals,
row_ptrs_array& Crowptrs,
col_inds_array& Ccolinds);
Expand Down Expand Up @@ -728,7 +732,7 @@ struct AddKernels
/// \param Browptrs Row pointers array for B
/// \param Bcolinds Column indices array for B
/// \param scalarB Scaling factor for B
/// \param globalNumCols The global size of the column map
/// \param numGlobalCols The global size of the column map
/// \param[Out] Cvals Values array for C (allocated inside function)
/// \param[Out] Crowptrs Row pointers array for C (allocated inside function)
/// \param[Out] Ccolinds Column indices array for C (allocated inside function)
Expand Down
38 changes: 35 additions & 3 deletions packages/tpetra/core/ext/TpetraExt_MatrixMatrix_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,11 @@ add (const Scalar& alpha,
<< "Call AddKern::addSorted(...)" << std::endl;
std::cerr << os.str ();
}
#if KOKKOSKERNELS_VERSION >= 40299
AddKern::addSorted(Avals, Arowptrs, Acolinds, alpha, Bvals, Browptrs, Bcolinds, beta, Aprime->getGlobalNumCols(), vals, rowptrs, colinds);
#else
AddKern::addSorted(Avals, Arowptrs, Acolinds, alpha, Bvals, Browptrs, Bcolinds, beta, vals, rowptrs, colinds);
#endif
}
else
{
Expand Down Expand Up @@ -3709,6 +3713,9 @@ addSorted(
const typename MMdetails::AddKernels<SC, LO, GO, NO>::row_ptrs_array_const& Browptrs,
const typename MMdetails::AddKernels<SC, LO, GO, NO>::col_inds_array& Bcolinds,
const typename MMdetails::AddKernels<SC, LO, GO, NO>::impl_scalar_type scalarB,
#if KOKKOSKERNELS_VERSION >= 40299
GO numGlobalCols,
#endif
typename MMdetails::AddKernels<SC, LO, GO, NO>::values_array& Cvals,
typename MMdetails::AddKernels<SC, LO, GO, NO>::row_ptrs_array& Crowptrs,
typename MMdetails::AddKernels<SC, LO, GO, NO>::col_inds_array& Ccolinds)
Expand All @@ -3725,7 +3732,11 @@ addSorted(
auto MM = rcp(new TimeMonitor(*TimeMonitor::getNewTimer("TpetraExt::MatrixMatrix::add() sorted symbolic")));
#endif
KokkosSparse::Experimental::spadd_symbolic
(&handle, Arowptrs, Acolinds, Browptrs, Bcolinds, Crowptrs);
(&handle,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, numGlobalCols,
#endif
Arowptrs, Acolinds, Browptrs, Bcolinds, Crowptrs);
//KokkosKernels requires values to be zeroed
Cvals = values_array("C values", addHandle->get_c_nnz());
Ccolinds = col_inds_array(Kokkos::ViewAllocateWithoutInitializing("C colinds"), addHandle->get_c_nnz());
Expand All @@ -3734,6 +3745,9 @@ addSorted(
MM = rcp(new TimeMonitor(*TimeMonitor::getNewTimer("TpetraExt::MatrixMatrix::add() sorted numeric")));
#endif
KokkosSparse::Experimental::spadd_numeric(&handle,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, numGlobalCols,
#endif
Arowptrs, Acolinds, Avals, scalarA,
Browptrs, Bcolinds, Bvals, scalarB,
Crowptrs, Ccolinds, Cvals);
Expand All @@ -3750,7 +3764,11 @@ addUnsorted(
const typename MMdetails::AddKernels<SC, LO, GO, NO>::row_ptrs_array_const& Browptrs,
const typename MMdetails::AddKernels<SC, LO, GO, NO>::col_inds_array& Bcolinds,
const typename MMdetails::AddKernels<SC, LO, GO, NO>::impl_scalar_type scalarB,
#if KOKKOSKERNELS_VERSION >= 40299
GO numGlobalCols,
#else
GO /* numGlobalCols */,
#endif
typename MMdetails::AddKernels<SC, LO, GO, NO>::values_array& Cvals,
typename MMdetails::AddKernels<SC, LO, GO, NO>::row_ptrs_array& Crowptrs,
typename MMdetails::AddKernels<SC, LO, GO, NO>::col_inds_array& Ccolinds)
Expand All @@ -3768,7 +3786,11 @@ addUnsorted(
auto MM = rcp(new TimeMonitor(*TimeMonitor::getNewTimer("TpetraExt::MatrixMatrix::add() unsorted symbolic")));
#endif
KokkosSparse::Experimental::spadd_symbolic
(&handle, Arowptrs, Acolinds, Browptrs, Bcolinds, Crowptrs);
(&handle,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, numGlobalCols,
#endif
Arowptrs, Acolinds, Browptrs, Bcolinds, Crowptrs);
//Cvals must be zeroed out
Cvals = values_array("C values", addHandle->get_c_nnz());
Ccolinds = col_inds_array(Kokkos::ViewAllocateWithoutInitializing("C colinds"), addHandle->get_c_nnz());
Expand All @@ -3777,6 +3799,9 @@ addUnsorted(
MM = rcp(new TimeMonitor(*TimeMonitor::getNewTimer("TpetraExt::MatrixMatrix::add() unsorted kernel: unsorted numeric")));
#endif
KokkosSparse::Experimental::spadd_numeric(&handle,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, numGlobalCols,
#endif
Arowptrs, Acolinds, Avals, scalarA,
Browptrs, Bcolinds, Bvals, scalarB,
Crowptrs, Ccolinds, Cvals);
Expand Down Expand Up @@ -3850,14 +3875,21 @@ convertToGlobalAndAdd(
auto nrows = Arowptrs.extent(0) - 1;
Crowptrs = row_ptrs_array(Kokkos::ViewAllocateWithoutInitializing("C row ptrs"), nrows + 1);
KokkosSparse::Experimental::spadd_symbolic
(&handle, Arowptrs, AcolindsConverted, Browptrs, BcolindsConverted, Crowptrs);
(&handle,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, A.numCols(),
#endif
Arowptrs, AcolindsConverted, Browptrs, BcolindsConverted, Crowptrs);
Cvals = values_array("C values", addHandle->get_c_nnz());
Ccolinds = global_col_inds_array(Kokkos::ViewAllocateWithoutInitializing("C colinds"), addHandle->get_c_nnz());
#ifdef HAVE_TPETRA_MMM_TIMINGS
MM = Teuchos::null;
MM = rcp(new TimeMonitor(*TimeMonitor::getNewTimer("TpetraExt::MatrixMatrix::add() diff col map kernel: unsorted numeric")));
#endif
KokkosSparse::Experimental::spadd_numeric(&handle,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, A.numCols(),
#endif
Arowptrs, AcolindsConverted, Avals, scalarA,
Browptrs, BcolindsConverted, Bvals, scalarB,
Crowptrs, Ccolinds, Cvals);
Expand Down
12 changes: 10 additions & 2 deletions packages/tpetra/core/src/Tpetra_CrsGraphTransposer_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ namespace Tpetra {
global_col_inds_array globalColindsSym;

KokkosSparse::Experimental::spadd_symbolic
(&handle, rowptrs, colindsConverted, rowptrsT, colindsTConverted, rowptrsSym);
(&handle,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, graph->getGlobalNumCols(),
#endif
rowptrs, colindsConverted, rowptrsT, colindsTConverted, rowptrsSym);
globalColindsSym = global_col_inds_array(Kokkos::ViewAllocateWithoutInitializing("global colinds sym"), addHandle->get_c_nnz());

UnsortedNumericIndicesOnlyFunctor<
Expand Down Expand Up @@ -325,7 +329,11 @@ namespace Tpetra {
auto addHandle = handle.get_spadd_handle();

KokkosSparse::Experimental::spadd_symbolic
(&handle, rowptrs, colinds, rowptrsT, colindsT, rowptrsSym);
(&handle,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, graph->getGlobalNumCols(),
#endif
rowptrs, colinds, rowptrsT, colindsT, rowptrsSym);
colindsSym = col_inds_array(Kokkos::ViewAllocateWithoutInitializing("C colinds"), addHandle->get_c_nnz());

if (sorted) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2154,6 +2154,9 @@ TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(Tpetra_MatMat, threaded_add_sorted, SC, LO, GO
Tpetra::MMdetails::AddKernels<SC, LO, GO, NT>::addSorted(
valsCRS[0], rowptrsCRS[0], colindsCRS[0], one,
valsCRS[1], rowptrsCRS[1], colindsCRS[1], one,
#if KOKKOSKERNELS_VERSION >= 40299
nrows, // assumes square matrices
#endif
valsCRS[2], rowptrsCRS[2], colindsCRS[2]);

ExecSpace().fence();
Expand Down

0 comments on commit 4ac842a

Please sign in to comment.