Skip to content

Commit

Permalink
More name refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
jgfouca committed Sep 29, 2023
1 parent c259032 commit 50593a6
Showing 1 changed file with 45 additions and 66 deletions.
111 changes: 45 additions & 66 deletions packages/shylu/shylu_node/fastilu/src/shylu_fastilu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,8 +992,8 @@ class FastILUPrec
// Re-block A_. At this point, aHost and A_ are unblocked. The host stuff isn't
// used anymore after this, so just reblock A_
#ifdef FASTILU_ONE_TO_ONE_UNBLOCKED
if (blockCrsSize > 1) {
reblock(aRowMap_, aRowIdx_, aColIdx_, aVal_, blockCrsSize);
if (m_blockCrsSize > 1) {
reblock(aRowMap_, aRowIdx_, aColIdx_, aVal_, m_blockCrsSize);
}
#endif

Expand Down Expand Up @@ -1196,7 +1196,7 @@ class FastILUPrec
// create a map from A to U (sorted)
auto nnzU = m_uRowMapHost[m_nRows];
m_a2uMap = OrdinalArray("a2uMap", nnzU);
auto a2uMap_ = Kokkos::create_mirror_view(m_a2uMap);
auto a2uMap = Kokkos::create_mirror_view(m_a2uMap);
for (Ordinal i = 0; i < m_nRows; i++)
{
for (Ordinal k = m_aRowMapHost[i]; k < m_aRowMapHost[i+1]; k++)
Expand All @@ -1206,12 +1206,12 @@ class FastILUPrec
if (row <= col)
{
Ordinal pos = m_uRowMapHost[col];
a2uMap_(pos) = k;
a2uMap(pos) = k;
m_uRowMapHost[col]++;
}
}
}
Kokkos::deep_copy(m_a2uMap, a2uMap_);
Kokkos::deep_copy(m_a2uMap, a2uMap);
// shift back pointer
for (Ordinal i = m_nRows; i > 0; i--)
{
Expand Down Expand Up @@ -1363,67 +1363,66 @@ class FastILUPrec
public:
//Constructor
//TODO: Use a Teuchos::ParameterList object
FastILUPrec(bool skipSortMatrix_, OrdinalArray &aRowMapIn_, OrdinalArray &aColIdxIn_, ScalarArray &aValIn_, Ordinal nRow_,
FastILU::SpTRSV sptrsv_algo_, Ordinal nFact_, Ordinal nTrisol_, Ordinal level_, Scalar omega_, Scalar shift_,
Ordinal guessFlag_, Ordinal blkSzILU_, Ordinal blkSz_, Ordinal blockCrsSize_ = 1)
FastILUPrec(bool skipSortMatrix, OrdinalArray &aRowMapIn, OrdinalArray &aColIdxIn, ScalarArray &aValIn, Ordinal nRows,
FastILU::SpTRSV sptrsv_algo, Ordinal nFact, Ordinal nTrisol, Ordinal level, Scalar omega, Scalar shift,
Ordinal guessFlag, Ordinal blkSzILU, Ordinal blkSz, Ordinal blockCrsSize = 1)
{
m_nRows = nRow_;
m_sptrsv_algo = sptrsv_algo_;
m_nFact = nFact_;
m_nTrisol = nTrisol_;
m_nRows = nRows;
m_sptrsv_algo = sptrsv_algo;
m_nFact = nFact;
m_nTrisol = nTrisol;

m_useMetis = false;

m_computeTime = 0.0;
m_applyTime = 0.0;
m_initTime = 0.0;
//icFlag = icFlag_;
m_level = level_;
m_level = level;

// mirror & deep-copy the input matrix
m_skipSortMatrix = skipSortMatrix_;
m_aRowMapIn = aRowMapIn_;
m_aColIdxIn = aColIdxIn_;
m_aValIn = aValIn_;

m_omega = omega_;
m_guessFlag = guessFlag_;
m_shift = shift_;
m_blkSzILU = blkSzILU_;
m_blkSz = blkSz_;
m_blockCrsSize = blockCrsSize_;
m_skipSortMatrix = skipSortMatrix;
m_aRowMapIn = aRowMapIn;
m_aColIdxIn = aColIdxIn;
m_aValIn = aValIn;

m_omega = omega;
m_guessFlag = guessFlag;
m_shift = shift;
m_blkSzILU = blkSzILU;
m_blkSz = blkSz;
m_blockCrsSize = blockCrsSize;
m_doUnitDiag_TRSV = true; // perform TRSV with unit diagonals
m_sptrsv_KKSpMV = true; // use Kokkos-Kernels SpMV for Fast SpTRSV

if (!BlockCrsEnabled) {
assert(blockCrsSize_ == 1);
assert(blockCrsSize == 1);
}

const Scalar one = STS::one();
m_onesVector = ScalarArray("onesVector", nRow_);
m_onesVector = ScalarArray("onesVector", nRows);
Kokkos::deep_copy(m_onesVector, one);

m_diagFact = RealArray("diagFact", nRow_ * blockCrsSize_);
m_diagElems = ScalarArray("diagElems", nRow_ * blockCrsSize_);
m_xOld = ScalarArray("xOld", nRow_ * blockCrsSize_);
m_xTemp = ScalarArray("xTemp", nRow_ * blockCrsSize_);
m_diagFact = RealArray("diagFact", nRows * blockCrsSize);
m_diagElems = ScalarArray("diagElems", nRows * blockCrsSize);
m_xOld = ScalarArray("xOld", nRows * blockCrsSize);
m_xTemp = ScalarArray("xTemp", nRows * blockCrsSize);

m_aRowMapInHost = OrdinalArrayMirror("aRowMapHost", m_aRowMapIn.size());
m_aColIdxInHost = OrdinalArrayMirror("aColIdxHost", m_aColIdxIn.size());
m_aValInHost = ScalarArrayMirror("aValHost", m_aValIn.size());
Kokkos::deep_copy(m_aRowMapInHost, aRowMapIn_);
Kokkos::deep_copy(m_aColIdxInHost, aColIdxIn_);
Kokkos::deep_copy(m_aValInHost, aValIn_);
Kokkos::deep_copy(m_aRowMapInHost, aRowMapIn);
Kokkos::deep_copy(m_aColIdxInHost, aColIdxIn);
Kokkos::deep_copy(m_aValInHost, aValIn);
#ifdef FASTILU_ONE_TO_ONE_UNBLOCKED
if (blockCrsSize > 1) {
unblock(aRowMapHost, aColIdxHost, aValHost, blockCrsSize);
if (m_blockCrsSize > 1) {
unblock(m_aRowMapHost, m_aColIdxHost, m_aValHost, m_blockCrsSize);
}
#endif

if ((m_level > 0) && (m_guessFlag != 0))
{
m_initGuessPrec = Teuchos::rcp(new FastPrec(skipSortMatrix_, aRowMapIn_, aColIdxIn_, aValIn_, nRow_, sptrsv_algo_,
3, 5, level_-1, omega_, shift_, guessFlag_, blkSzILU_, blkSz_, blockCrsSize_));
m_initGuessPrec = Teuchos::rcp(new FastPrec(skipSortMatrix, aRowMapIn, aColIdxIn, aValIn, nRows, sptrsv_algo,
3, 5, level-1, omega, shift, guessFlag, blkSzILU, blkSz, blockCrsSize));
}
}

Expand Down Expand Up @@ -1695,25 +1694,25 @@ class FastILUPrec

// set Metis pre-ordering
template<class MetisArrayHost>
void setMetisPerm(MetisArrayHost permMetis_, MetisArrayHost ipermMetis_)
void setMetisPerm(MetisArrayHost permMetis, MetisArrayHost ipermMetis)
{
Ordinal nRows_ = permMetis_.size();
Ordinal nRows_ = permMetis.size();
if (m_nRows > 0) {
m_permMetis = OrdinalArray("permMetis", nRows_);
m_ipermMetis = OrdinalArray("ipermMetis", nRows_);

m_permMetisHost = Kokkos::create_mirror_view(m_permMetis);
m_ipermMetisHost = Kokkos::create_mirror_view(m_ipermMetis);
for (Ordinal i = 0; i < nRows_; i++) {
m_permMetisHost(i) = permMetis_(i);
m_ipermMetisHost(i) = ipermMetis_(i);
m_permMetisHost(i) = permMetis(i);
m_ipermMetisHost(i) = ipermMetis(i);
}
Kokkos::deep_copy(m_permMetis, m_permMetisHost);
Kokkos::deep_copy(m_ipermMetis, m_ipermMetisHost);
}
if ((m_level > 0) && (m_guessFlag != 0))
{
m_initGuessPrec->setMetisPerm(permMetis_, ipermMetis_);
m_initGuessPrec->setMetisPerm(permMetis, ipermMetis);
}
m_useMetis = true;
}
Expand Down Expand Up @@ -1762,25 +1761,6 @@ class FastILUPrec
//initialize L, U, A patterns
symbolicILU_common();

#ifdef SHYLU_DEBUG
Ordinal nnzU = uRowMap[nRows];
MemoryPrimeFunctorN<Ordinal, Scalar, ExecSpace> copyFunc1(aRowMap, lRowMap, uRowMap, diagElems);
MemoryPrimeFunctorNnzCsr<Ordinal, Scalar, ExecSpace> copyFunc4(uColIdx, uVal);

Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace>(0, nRows), copyFunc1);
Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace>(0, nnzU), copyFunc4);

//Note that the following is a temporary measure
//to ensure that memory resides on the device.
Ordinal nnzL = lRowMap[nRows];
Ordinal nnzA = aRowMap[nRows];
MemoryPrimeFunctorNnzCoo<Ordinal, Scalar, ExecSpace> copyFunc2(aColIdx, aRowIdx, aVal);
MemoryPrimeFunctorNnzCsr<Ordinal, Scalar, ExecSpace> copyFunc3(lColIdx, lVal);

Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace>(0, nRows), copyFunc1);
Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace>(0, nnzA), copyFunc2);
Kokkos::parallel_for(Kokkos::RangePolicy<ExecSpace>(0, nnzL), copyFunc3);
#endif
ExecSpace().fence(); //Fence so that init time is accurate
m_initTime = timer.seconds();
}
Expand Down Expand Up @@ -1882,14 +1862,14 @@ class FastILUPrec
// setup L solve
khL.create_sptrsv_handle(algo, m_nRows, true);
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
KokkosSparse::Experimental::sptrsv_symbolic(&khL, lRowMap, lColIdx, lVal);
KokkosSparse::Experimental::sptrsv_symbolic(&khL, m_lRowMap, m_lColIdx, m_lVal);
#else
KokkosSparse::Experimental::sptrsv_symbolic(&khL, m_lRowMap, m_lColIdx);
#endif
// setup U solve
khU.create_sptrsv_handle(algo, m_nRows, false);
#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSPARSE)
KokkosSparse::Experimental::sptrsv_symbolic(&khU, utRowMap, utColIdx, utVal);
KokkosSparse::Experimental::sptrsv_symbolic(&khU, m_utRowMap, m_utColIdx, m_utVal);
#else
KokkosSparse::Experimental::sptrsv_symbolic(&khU, m_utRowMap, m_utColIdx);
#endif
Expand Down Expand Up @@ -2164,8 +2144,7 @@ class FastILUPrec
}
//Only fencing here so that apply time is accurate
ExecSpace().fence();
double t = timer.seconds();
m_applyTime = t;
m_applyTime = timer.seconds();
}

Ordinal getNFact() const
Expand Down

0 comments on commit 50593a6

Please sign in to comment.