Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DO NOT MERGE Tpetra: Transfer and fill complete refactor #13491

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 23 additions & 67 deletions packages/tpetra/core/src/Tpetra_CrsMatrix_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "KokkosBlas1_scal.hpp"
#include "KokkosSparse_getDiagCopy.hpp"
#include "KokkosSparse_spmv.hpp"
#include "Kokkos_StdAlgorithms.hpp"

#include <memory>
#include <sstream>
Expand Down Expand Up @@ -8301,59 +8302,43 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
<< std::endl;
std::cerr << os.str ();
}
// Make sure that host has the latest version, since we're
// using the version on host. If host has the latest
// version, syncing to host does nothing.
destMat->numExportPacketsPerLID_.sync_host ();
Teuchos::ArrayView<const size_t> numExportPacketsPerLID =
getArrayViewFromDualView (destMat->numExportPacketsPerLID_);
destMat->numImportPacketsPerLID_.sync_host ();
Teuchos::ArrayView<size_t> numImportPacketsPerLID =
getArrayViewFromDualView (destMat->numImportPacketsPerLID_);

destMat->numExportPacketsPerLID_.sync_device();
auto numExportPacketsPerLID = destMat->numExportPacketsPerLID_.view_device();
auto numImportPacketsPerLID = destMat->numImportPacketsPerLID_.view_device();
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Calling 3-arg doReversePostsAndWaits"
<< std::endl;
std::cerr << os.str ();
}
Distor.doReversePostsAndWaits(destMat->numExportPacketsPerLID_.view_host(), 1,
destMat->numImportPacketsPerLID_.view_host());
Distor.doReversePostsAndWaits(numExportPacketsPerLID, 1, numImportPacketsPerLID);
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Finished 3-arg doReversePostsAndWaits"
<< std::endl;
std::cerr << os.str ();
}

size_t totalImportPackets = 0;
for (Array_size_type i = 0; i < numImportPacketsPerLID.size (); ++i) {
totalImportPackets += numImportPacketsPerLID[i];
}
size_t totalImportPackets = Kokkos::Experimental::reduce(typename Node::execution_space(), numImportPacketsPerLID);

// Reallocation MUST go before setting the modified flag,
// because it may clear out the flags.
destMat->reallocImportsIfNeeded (totalImportPackets, verbose,
verbosePrefix.get ());
destMat->imports_.modify_host ();
auto hostImports = destMat->imports_.view_host();
// This is a legacy host pack/unpack path, so use the host
// version of exports_.
destMat->exports_.sync_host ();
auto hostExports = destMat->exports_.view_host();
auto deviceImports = destMat->imports_.view_device();
auto deviceExports = destMat->exports_.view_device();
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Calling 4-arg doReversePostsAndWaits"
os << *verbosePrefix << "Calling 4-arg doReversePostsAndWaitsKokkos"
<< std::endl;
std::cerr << os.str ();
}
Distor.doReversePostsAndWaits (hostExports,
numExportPacketsPerLID,
hostImports,
numImportPacketsPerLID);
destMat->imports_.sync_device();
Distor.doReversePostsAndWaitsKokkos (deviceExports, numExportPacketsPerLID, deviceImports, numImportPacketsPerLID);
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Finished 4-arg doReversePostsAndWaits"
os << *verbosePrefix << "Finished 4-arg doReversePostsAndWaitsKokkos"
<< std::endl;
std::cerr << os.str ();
}
Expand Down Expand Up @@ -8396,58 +8381,43 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
<< std::endl;
std::cerr << os.str ();
}
// Make sure that host has the latest version, since we're
// using the version on host. If host has the latest
// version, syncing to host does nothing.
destMat->numExportPacketsPerLID_.sync_host ();
Teuchos::ArrayView<const size_t> numExportPacketsPerLID =
getArrayViewFromDualView (destMat->numExportPacketsPerLID_);
destMat->numImportPacketsPerLID_.sync_host ();
Teuchos::ArrayView<size_t> numImportPacketsPerLID =
getArrayViewFromDualView (destMat->numImportPacketsPerLID_);
destMat->numExportPacketsPerLID_.sync_device ();
auto numExportPacketsPerLID = destMat->numExportPacketsPerLID_.view_device();
auto numImportPacketsPerLID = destMat->numImportPacketsPerLID_.view_device();
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Calling 3-arg doPostsAndWaits"
<< std::endl;
std::cerr << os.str ();
}
Distor.doPostsAndWaits(destMat->numExportPacketsPerLID_.view_host(), 1,
destMat->numImportPacketsPerLID_.view_host());
Distor.doPostsAndWaits(numExportPacketsPerLID, 1, numImportPacketsPerLID);
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Finished 3-arg doPostsAndWaits"
<< std::endl;
std::cerr << os.str ();
}

size_t totalImportPackets = 0;
for (Array_size_type i = 0; i < numImportPacketsPerLID.size (); ++i) {
totalImportPackets += numImportPacketsPerLID[i];
}
size_t totalImportPackets = Kokkos::Experimental::reduce(typename Node::execution_space(), numImportPacketsPerLID);

// Reallocation MUST go before setting the modified flag,
// because it may clear out the flags.
destMat->reallocImportsIfNeeded (totalImportPackets, verbose,
verbosePrefix.get ());
destMat->imports_.modify_host ();
auto hostImports = destMat->imports_.view_host();
// This is a legacy host pack/unpack path, so use the host
// version of exports_.
destMat->exports_.sync_host ();
auto hostExports = destMat->exports_.view_host();
auto deviceImports = destMat->imports_.view_device();
auto deviceExports = destMat->exports_.view_device();
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Calling 4-arg doPostsAndWaits"
os << *verbosePrefix << "Calling 4-arg doPostsAndWaitsKokkos"
<< std::endl;
std::cerr << os.str ();
}
Distor.doPostsAndWaits (hostExports,
numExportPacketsPerLID,
hostImports,
numImportPacketsPerLID);
destMat->imports_.sync_device ();
Distor.doPostsAndWaitsKokkos (deviceExports, numExportPacketsPerLID, deviceImports, numImportPacketsPerLID);
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Finished 4-arg doPostsAndWaits"
os << *verbosePrefix << "Finished 4-arg doPostsAndWaitsKokkos"
<< std::endl;
std::cerr << os.str ();
}
Expand Down Expand Up @@ -8494,12 +8464,6 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
Teuchos::Array<int> RemotePids;
if (runOnHost) {
Teuchos::Array<int> TargetPids;
// Backwards compatibility measure. We'll use this again below.

// TODO JHU Need to track down why numImportPacketsPerLID_ has not been corrently marked as modified on host (which it has been)
// TODO JHU somewhere above, e.g., call to Distor.doPostsAndWaits().
// TODO JHU This only becomes apparent as we begin to convert TAFC to run on device.
destMat->numImportPacketsPerLID_.modify_host(); //FIXME

# ifdef HAVE_TPETRA_MMM_TIMINGS
RCP<TimeMonitor> tmCopySPRdata = rcp(new TimeMonitor(*TimeMonitor::getNewTimer(prefix + std::string("TAFC unpack-count-resize + copy same-perm-remote data"))));
Expand Down Expand Up @@ -8691,14 +8655,6 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
} else {
// run on device


// Backwards compatibility measure. We'll use this again below.

// TODO JHU Need to track down why numImportPacketsPerLID_ has not been corrently marked as modified on host (which it has been)
// TODO JHU somewhere above, e.g., call to Distor.doPostsAndWaits().
// TODO JHU This only becomes apparent as we begin to convert TAFC to run on device.
destMat->numImportPacketsPerLID_.modify_host(); //FIXME

# ifdef HAVE_TPETRA_MMM_TIMINGS
RCP<TimeMonitor> tmCopySPRdata = rcp(new TimeMonitor(*TimeMonitor::getNewTimer(prefix + std::string("TAFC unpack-count-resize + copy same-perm-remote data"))));
# endif
Expand Down
Loading