diff --git a/perf_tests/test_2dhalo.cpp b/perf_tests/test_2dhalo.cpp index 073eb683..fe2c7faa 100644 --- a/perf_tests/test_2dhalo.cpp +++ b/perf_tests/test_2dhalo.cpp @@ -82,14 +82,12 @@ void benchmark_2dhalo(benchmark::State &state) { const int ry = rank / rs; if (rank < rs * rs) { - auto mode = KokkosComm::DefaultCommMode(); auto space = Kokkos::DefaultExecutionSpace(); // grid of elements, each with 3 properties, and a radius-1 halo grid_type grid("", nx + 2, ny + 2, nprops); while (state.KeepRunning()) { - do_iteration(state, MPI_COMM_WORLD, - send_recv, mode, space, nx, - ny, rx, ry, rs, grid); + do_iteration(state, MPI_COMM_WORLD, send_recv, space, nx, ny, rx, ry, + rs, grid); } } else { while (state.KeepRunning()) { diff --git a/perf_tests/test_osu_latency.cpp b/perf_tests/test_osu_latency.cpp index 68a6a248..2758f993 100644 --- a/perf_tests/test_osu_latency.cpp +++ b/perf_tests/test_osu_latency.cpp @@ -63,7 +63,6 @@ void benchmark_osu_latency_Kokkos_Comm_mpi_sendrecv(benchmark::State &state) { state.SkipWithError("benchmark_osu_latency_KokkosComm needs exactly 2 ranks"); } - auto mode = KokkosComm::DefaultCommMode(); auto space = Kokkos::DefaultExecutionSpace(); using view_type = Kokkos::View; view_type a("A", state.range(0)); diff --git a/perf_tests/test_sendrecv.cpp b/perf_tests/test_sendrecv.cpp index ae39a747..11fb220e 100644 --- a/perf_tests/test_sendrecv.cpp +++ b/perf_tests/test_sendrecv.cpp @@ -18,14 +18,14 @@ #include "KokkosComm.hpp" -template -void send_recv(benchmark::State &, MPI_Comm comm, const Mode &mode, const Space &space, int rank, const View &v) { +template +void send_recv(benchmark::State &, MPI_Comm comm, const Space &space, int rank, const View &v) { if (0 == rank) { - KokkosComm::mpi::send(space, v, 1, 0, comm); + KokkosComm::mpi::send(space, v, 1, 0, comm, Mode{}); KokkosComm::mpi::recv(space, v, 1, 0, comm); } else if (1 == rank) { KokkosComm::mpi::recv(space, v, 0, 0, comm); - KokkosComm::mpi::send(space, v, 0, 0, comm); + KokkosComm::mpi::send(space, v, 0, 0, comm, Mode{}); } } @@ -39,15 +39,13 @@ void benchmark_sendrecv(benchmark::State &state) { using Scalar = double; - auto mode = KokkosComm::DefaultCommMode(); + using Mode = KokkosComm::mpi::DefaultCommMode; auto space = Kokkos::DefaultExecutionSpace(); using view_type = Kokkos::View; view_type a("", 1000000); while (state.KeepRunning()) { - do_iteration(state, MPI_COMM_WORLD, - send_recv, mode, space, rank, - a); + do_iteration(state, MPI_COMM_WORLD, send_recv, space, rank, a); } state.SetBytesProcessed(sizeof(Scalar) * state.iterations() * a.size() * 2); diff --git a/src/KokkosComm_collective.hpp b/src/KokkosComm_collective.hpp index b61b5f14..3e2f6e6a 100644 --- a/src/KokkosComm_collective.hpp +++ b/src/KokkosComm_collective.hpp @@ -25,7 +25,8 @@ namespace KokkosComm { -template +template void barrier(Handle &&h) { Impl::Barrier{std::forward>(h)}; } diff --git a/src/KokkosComm_fwd.hpp b/src/KokkosComm_fwd.hpp index a9dcbcbc..354056b1 100644 --- a/src/KokkosComm_fwd.hpp +++ b/src/KokkosComm_fwd.hpp @@ -33,7 +33,8 @@ using FallbackCommunicationSpace = Mpi; template class Req; -template +template class Handle; namespace Impl { @@ -44,7 +45,8 @@ struct Recv; template struct Send; -template +template struct Barrier; } // namespace Impl diff --git a/src/mpi/KokkosComm_mpi_commmode.hpp b/src/mpi/KokkosComm_mpi_commmode.hpp index 489bfe01..d0249d9d 100644 --- a/src/mpi/KokkosComm_mpi_commmode.hpp +++ b/src/mpi/KokkosComm_mpi_commmode.hpp @@ -16,26 +16,52 @@ #pragma once -namespace KokkosComm::mpi { -// Scoped enumeration to specify the communication mode of a sending operation. +#include + // See section 3.4 of the MPI standard for a complete specification. -enum class CommMode { - // Default mode: lets the user override the send operations behavior at - // compile-time. E.g., this can be set to mode "Synchronous" for debug - // builds by defining KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE. - Default, - // Standard mode: MPI implementation decides whether outgoing messages will - // be buffered. Send operations can be started whether or not a matching - // receive has been started. They may complete before a matching receive is - // started. Standard mode is non-local: successful completion of the send - // operation may depend on the occurrence of a matching receive. - Standard, - // Ready mode: Send operations may be started only if the matching receive is - // already started. - Ready, - // Synchronous mode: Send operations complete successfully only if a matching - // receive is started, and the receive operation has started to receive the - // message sent. - Synchronous, -}; + +namespace KokkosComm::mpi { +// Standard mode: MPI implementation decides whether outgoing messages will +// be buffered. Send operations can be started whether or not a matching +// receive has been started. They may complete before a matching receive is +// started. Standard mode is non-local: successful completion of the send +// operation may depend on the occurrence of a matching receive. +struct CommModeStandard {}; + +// Ready mode: Send operations may be started only if the matching receive is +// already started. +struct CommModeReady {}; + +// Synchronous mode: Send operations complete successfully only if a matching +// receive is started, and the receive operation has started to receive the +// message sent. +struct CommModeSynchronous {}; + +// Default mode: lets the user override the send operations behavior at +// compile-time. E.g., this can be set to mode "Synchronous" for debug +// builds by defining KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE. +#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE +using DefaultCommMode = CommModeSynchronous; +#else +using DefaultCommMode = CommModeStandard; +#endif + +template +struct is_communication_mode : std::false_type {}; + +template <> +struct is_communication_mode : std::true_type {}; + +template <> +struct is_communication_mode : std::true_type {}; + +template <> +struct is_communication_mode : std::true_type {}; + +template +inline constexpr bool is_communication_mode_v = is_communication_mode::value; + +template +concept CommunicationMode = is_communication_mode_v; + } // namespace KokkosComm::mpi \ No newline at end of file diff --git a/src/mpi/KokkosComm_mpi_isend.hpp b/src/mpi/KokkosComm_mpi_isend.hpp index eea987bd..ccea4578 100644 --- a/src/mpi/KokkosComm_mpi_isend.hpp +++ b/src/mpi/KokkosComm_mpi_isend.hpp @@ -26,22 +26,18 @@ namespace KokkosComm { namespace Impl { -template -Req isend_impl(Handle &h, const SendView &sv, int dest, int tag) { +template +Req isend_impl(Handle &h, const SendView &sv, int dest, int tag, SendMode) { auto mpi_isend_fn = [](void *mpi_view, int mpi_count, MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag, MPI_Comm mpi_comm, MPI_Request *mpi_req) { - if constexpr (SendMode == mpi::CommMode::Standard) { + if constexpr (std::is_same_v) { MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req); - } else if constexpr (SendMode == mpi::CommMode::Ready) { + } else if constexpr (std::is_same_v) { MPI_Irsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req); - } else if constexpr (SendMode == mpi::CommMode::Synchronous) { + } else if constexpr (std::is_same_v) { MPI_Issend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req); - } else if constexpr (SendMode == mpi::CommMode::Default) { -#ifdef KOKKOSCOMM_FORCE_SYNCHRONOUS_MODE - MPI_Issend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req); -#else - MPI_Isend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm, mpi_req); -#endif + } else { + static_assert(std::is_void_v, "unexpected communication mode"); } }; @@ -67,7 +63,7 @@ Req isend_impl(Handle &h, const SendView &sv, int dest, int template struct Send { static Req execute(Handle &h, const SendView &sv, int dest, int tag) { - return isend_impl(h, sv, dest, tag); + return isend_impl(h, sv, dest, tag, mpi::DefaultCommMode{}); } }; @@ -75,9 +71,14 @@ struct Send { namespace mpi { -template +template +Req isend(Handle &h, const SendView &sv, int dest, int tag, SendMode) { + return KokkosComm::Impl::isend_impl(h, sv, dest, tag, SendMode{}); +} + +template Req isend(Handle &h, const SendView &sv, int dest, int tag) { - return KokkosComm::Impl::isend_impl(h, sv, dest, tag); + return isend(h, sv, dest, tag, DefaultCommMode{}); } template diff --git a/src/mpi/KokkosComm_mpi_send.hpp b/src/mpi/KokkosComm_mpi_send.hpp index e335ced1..3ac1a4c3 100644 --- a/src/mpi/KokkosComm_mpi_send.hpp +++ b/src/mpi/KokkosComm_mpi_send.hpp @@ -24,19 +24,21 @@ namespace KokkosComm::mpi { -template -void send(const SendMode &, const SendView &sv, int dest, int tag, MPI_Comm comm) { +template +void send(const SendView &sv, int dest, int tag, MPI_Comm comm, SendMode) { Kokkos::Tools::pushRegion("KokkosComm::Impl::send"); using KCT = typename KokkosComm::Traits; auto mpi_send_fn = [](void *mpi_view, int mpi_count, MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag, MPI_Comm mpi_comm) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { MPI_Send(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { MPI_Rsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { MPI_Ssend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); + } else { + static_assert(std::is_void_v, "unexpected communication mode"); } }; @@ -50,25 +52,22 @@ void send(const SendMode &, const SendView &sv, int dest, int tag, MPI_Comm comm Kokkos::Tools::popRegion(); } -template -void send(const SendView &sv, int dest, int tag, MPI_Comm comm) { - send(KokkosComm::DefaultCommMode(), sv, dest, tag, comm); -} - -template -void send(const SendMode &, const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm) { +template +void send(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm, SendMode) { Kokkos::Tools::pushRegion("KokkosComm::Impl::send"); using Packer = typename KokkosComm::PackTraits::packer_type; auto mpi_send_fn = [](void *mpi_view, int mpi_count, MPI_Datatype mpi_datatype, int mpi_dest, int mpi_tag, MPI_Comm mpi_comm) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { MPI_Send(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { MPI_Rsend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { MPI_Ssend(mpi_view, mpi_count, mpi_datatype, mpi_dest, mpi_tag, mpi_comm); + } else { + static_assert(std::is_void_v, "unexpected communication mode"); } }; @@ -84,4 +83,9 @@ void send(const SendMode &, const ExecSpace &space, const SendView &sv, int dest Kokkos::Tools::popRegion(); } +template +void send(const ExecSpace &space, const SendView &sv, int dest, int tag, MPI_Comm comm) { + send(space, sv, dest, tag, comm, DefaultCommMode{}); +} + } // namespace KokkosComm::mpi diff --git a/unit_tests/mpi/test_isendrecv.cpp b/unit_tests/mpi/test_isendrecv.cpp index f2ffbbf4..bc03ef8f 100644 --- a/unit_tests/mpi/test_isendrecv.cpp +++ b/unit_tests/mpi/test_isendrecv.cpp @@ -21,6 +21,8 @@ namespace { +using namespace KokkosComm::mpi; + template class IsendRecv : public testing::Test { public: @@ -31,9 +33,9 @@ using ScalarTypes = ::testing::Types, Kokkos::complex, int, unsigned, int64_t, size_t>; TYPED_TEST_SUITE(IsendRecv, ScalarTypes); -template +template void isend_comm_mode_1d_contig() { - if (IsendMode == KokkosComm::mpi::CommMode::Ready) { + if constexpr (std::is_same_v) { GTEST_SKIP() << "Skipping test for ready-mode send"; } @@ -48,7 +50,7 @@ void isend_comm_mode_1d_contig() { int dst = 1; Kokkos::parallel_for( a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; }); - KokkosComm::Req req = KokkosComm::mpi::isend(h, a, dst, 0); + KokkosComm::Req req = KokkosComm::mpi::isend(h, a, dst, 0, IsendMode{}); KokkosComm::wait(req); } else if (1 == h.rank()) { int src = 0; @@ -60,9 +62,9 @@ void isend_comm_mode_1d_contig() { } } -template +template void isend_comm_mode_1d_noncontig() { - if (IsendMode == KokkosComm::mpi::CommMode::Ready) { + if constexpr (std::is_same_v) { GTEST_SKIP() << "Skipping test for ready-mode send"; } @@ -79,7 +81,7 @@ void isend_comm_mode_1d_noncontig() { int dst = 1; Kokkos::parallel_for( a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; }); - KokkosComm::Req req = KokkosComm::mpi::isend(h, a, dst, 0); + KokkosComm::Req req = KokkosComm::mpi::isend(h, a, dst, 0, IsendMode{}); KokkosComm::wait(req); } else if (1 == h.rank()) { int src = 0; @@ -92,27 +94,25 @@ void isend_comm_mode_1d_noncontig() { } TYPED_TEST(IsendRecv, 1D_contig_standard) { - isend_comm_mode_1d_contig(); + isend_comm_mode_1d_contig(); } -TYPED_TEST(IsendRecv, 1D_contig_ready) { - isend_comm_mode_1d_contig(); -} +TYPED_TEST(IsendRecv, 1D_contig_ready) { isend_comm_mode_1d_contig(); } TYPED_TEST(IsendRecv, 1D_contig_synchronous) { - isend_comm_mode_1d_contig(); + isend_comm_mode_1d_contig(); } TYPED_TEST(IsendRecv, 1D_noncontig_standard) { - isend_comm_mode_1d_noncontig(); + isend_comm_mode_1d_noncontig(); } TYPED_TEST(IsendRecv, 1D_noncontig_ready) { - isend_comm_mode_1d_noncontig(); + isend_comm_mode_1d_noncontig(); } TYPED_TEST(IsendRecv, 1D_noncontig_synchronous) { - isend_comm_mode_1d_noncontig(); + isend_comm_mode_1d_noncontig(); } } // namespace diff --git a/unit_tests/mpi/test_sendrecv.cpp b/unit_tests/mpi/test_sendrecv.cpp index 306d2301..19a9e455 100644 --- a/unit_tests/mpi/test_sendrecv.cpp +++ b/unit_tests/mpi/test_sendrecv.cpp @@ -21,6 +21,8 @@ namespace { +using namespace KokkosComm::mpi; + template class MpiSendRecv : public testing::Test { public: @@ -30,9 +32,9 @@ class MpiSendRecv : public testing::Test { using ScalarTypes = ::testing::Types, Kokkos::complex>; TYPED_TEST_SUITE(MpiSendRecv, ScalarTypes); -template +template void send_comm_mode_1d_contig() { - if (SendMode == KokkosComm::mpi::CommMode::Ready) { + if constexpr (std::is_same_v) { GTEST_SKIP() << "Skipping test for ready-mode send"; } @@ -49,7 +51,7 @@ void send_comm_mode_1d_contig() { int dst = 1; Kokkos::parallel_for( a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; }); - KokkosComm::mpi::send(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD); + KokkosComm::mpi::send(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD, SendMode{}); } else if (1 == rank) { int src = 0; KokkosComm::mpi::recv(Kokkos::DefaultExecutionSpace(), a, src, 0, MPI_COMM_WORLD); @@ -60,9 +62,9 @@ void send_comm_mode_1d_contig() { } } -template +template void send_comm_mode_1d_noncontig() { - if (SendMode == KokkosComm::mpi::CommMode::Ready) { + if constexpr (std::is_same_v) { GTEST_SKIP() << "Skipping test for ready-mode send"; } @@ -77,7 +79,7 @@ void send_comm_mode_1d_noncontig() { int dst = 1; Kokkos::parallel_for( a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; }); - KokkosComm::mpi::send(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD); + KokkosComm::mpi::send(Kokkos::DefaultExecutionSpace(), a, dst, 0, MPI_COMM_WORLD, SendMode{}); } else if (1 == rank) { int src = 0; KokkosComm::mpi::recv(Kokkos::DefaultExecutionSpace(), a, src, 0, MPI_COMM_WORLD); @@ -89,27 +91,25 @@ void send_comm_mode_1d_noncontig() { } TYPED_TEST(MpiSendRecv, 1D_contig_standard) { - send_comm_mode_1d_contig(); + send_comm_mode_1d_contig(); } -TYPED_TEST(MpiSendRecv, 1D_contig_ready) { - send_comm_mode_1d_contig(); -} +TYPED_TEST(MpiSendRecv, 1D_contig_ready) { send_comm_mode_1d_contig(); } TYPED_TEST(MpiSendRecv, 1D_contig_synchronous) { - send_comm_mode_1d_contig(); + send_comm_mode_1d_contig(); } TYPED_TEST(MpiSendRecv, 1D_noncontig_standard) { - send_comm_mode_1d_noncontig(); + send_comm_mode_1d_noncontig(); } TYPED_TEST(MpiSendRecv, 1D_noncontig_ready) { - send_comm_mode_1d_noncontig(); + send_comm_mode_1d_noncontig(); } TYPED_TEST(MpiSendRecv, 1D_noncontig_synchronous) { - send_comm_mode_1d_noncontig(); + send_comm_mode_1d_noncontig(); } } // namespace