From dc7464219c0643eda48668df8fa98bc99238e170 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Mon, 12 Aug 2024 09:15:01 -0600 Subject: [PATCH] Remove tags from core interface --- docs/api/core.rst | 12 ++++-------- docs/api/core_recv.cpp | 2 +- docs/api/core_send.cpp | 5 ++--- perf_tests/test_2dhalo.cpp | 18 +++++++++--------- perf_tests/test_osu_latency.cpp | 4 ++-- src/KokkosComm_point_to_point.hpp | 16 ++++++++-------- src/mpi/KokkosComm_mpi_irecv.hpp | 7 ++++--- src/mpi/KokkosComm_mpi_isend.hpp | 5 +++-- src/mpi/impl/KokkosComm_tags.hpp | 5 +++++ unit_tests/test_sendrecv.cpp | 8 ++++---- 10 files changed, 42 insertions(+), 40 deletions(-) create mode 100644 src/mpi/impl/KokkosComm_tags.hpp diff --git a/docs/api/core.rst b/docs/api/core.rst index e2fd6b88..743ea41b 100644 --- a/docs/api/core.rst +++ b/docs/api/core.rst @@ -6,7 +6,7 @@ Point-to-point .. cpp:namespace:: KokkosComm -.. cpp:function:: template Req send(Handle &h, SendView &sv, int dest, int tag) +.. cpp:function:: template Req send(Handle &h, SendView &sv, int dest) Initiates a non-blocking send operation. @@ -20,11 +20,10 @@ Point-to-point :param h: A handle to the execution space and transport mechanism. :param sv: The Kokkos view to send. :param dest: The destination rank. - :param tag: The message tag. :return: A request object for the non-blocking send operation. -.. cpp:function:: template Req send(SendView &sv, int dest, int tag) +.. cpp:function:: template Req send(SendView &sv, int dest) Initiates a non-blocking send operation using a default handle. @@ -37,7 +36,6 @@ Point-to-point :param sv: The Kokkos view to send. :param dest: The destination rank. - :param tag: The message tag. :return: A request object for the non-blocking send operation. @@ -48,7 +46,7 @@ Point-to-point -.. cpp:function:: template Req recv(Handle &h, RecvView &rv, int src, int tag) +.. cpp:function:: template Req recv(Handle &h, RecvView &rv, int src) Initiates a non-blocking receive operation. @@ -62,7 +60,6 @@ Point-to-point :param h: A handle to the execution space and transport mechanism. :param rv: The Kokkos view where the received data will be stored. :param src: The source rank from which to receive data. - :param tag: The message tag to identify the communication. :return: A request object of type `Req` representing the non-blocking receive operation. @@ -77,7 +74,7 @@ Point-to-point -.. cpp:function:: template Req recv(RecvView &rv, int src, int tag) +.. cpp:function:: template Req recv(RecvView &rv, int src) Initiates a non-blocking receive operation using a default handle. @@ -90,7 +87,6 @@ Point-to-point :param rv: The Kokkos view where the received data will be stored. :param src: The source rank from which to receive data. - :param tag: The message tag to identify the communication. :return: A request object of type `Req` representing the non-blocking receive operation. diff --git a/docs/api/core_recv.cpp b/docs/api/core_recv.cpp index 32e2599f..8b949bca 100644 --- a/docs/api/core_recv.cpp +++ b/docs/api/core_recv.cpp @@ -1,4 +1,4 @@ Handle<> handle; Kokkos::View recv_view("recv_view", 100); -auto req = recv(handle, recv_view, 1/*src*/, 0/*tag*/); +auto req = recv(handle, recv_view, 1/*src*/); KokkosComm::wait(req); \ No newline at end of file diff --git a/docs/api/core_send.cpp b/docs/api/core_send.cpp index 516c0797..b23d28fe 100644 --- a/docs/api/core_send.cpp +++ b/docs/api/core_send.cpp @@ -14,16 +14,15 @@ Kokkos::parallel_for("fill_data", Kokkos::RangePolicy(0, 100), KOKKOS // Destination rank and message tag int dest = 1; -int tag = 42; // Create a handle KokkosComm::Handle<> handle; // Same as Handle // Initiate a non-blocking send with a handle -auto req1 = send(handle, data, dest, tag); +auto req1 = send(handle, data, dest); // Initiate a non-blocking send with a default handle -auto req2 = send(data, dest, tag); +auto req2 = send(data, dest); // Wait for the requests to complete (assuming a wait function exists) KokkosComm::wait(req1); diff --git a/perf_tests/test_2dhalo.cpp b/perf_tests/test_2dhalo.cpp index fe2c7faa..3c67e901 100644 --- a/perf_tests/test_2dhalo.cpp +++ b/perf_tests/test_2dhalo.cpp @@ -50,15 +50,15 @@ void send_recv(benchmark::State &, MPI_Comm comm, const Space &space, int nx, in std::vector> reqs; // std::cerr << get_rank(rx, ry) << " -> " << get_rank(xp1, ry) << "\n"; - reqs.push_back(KokkosComm::send(h, xp1_s, get_rank(xp1, ry), 0)); - reqs.push_back(KokkosComm::send(h, xm1_s, get_rank(xm1, ry), 1)); - reqs.push_back(KokkosComm::send(h, yp1_s, get_rank(rx, yp1), 2)); - reqs.push_back(KokkosComm::send(h, ym1_s, get_rank(rx, ym1), 3)); - - reqs.push_back(KokkosComm::recv(h, xm1_r, get_rank(xm1, ry), 0)); - reqs.push_back(KokkosComm::recv(h, xp1_r, get_rank(xp1, ry), 1)); - reqs.push_back(KokkosComm::recv(h, ym1_r, get_rank(rx, ym1), 2)); - reqs.push_back(KokkosComm::recv(h, yp1_r, get_rank(rx, yp1), 3)); + reqs.push_back(KokkosComm::send(h, xp1_s, get_rank(xp1, ry))); + reqs.push_back(KokkosComm::send(h, xm1_s, get_rank(xm1, ry))); + reqs.push_back(KokkosComm::send(h, yp1_s, get_rank(rx, yp1))); + reqs.push_back(KokkosComm::send(h, ym1_s, get_rank(rx, ym1))); + + reqs.push_back(KokkosComm::recv(h, xm1_r, get_rank(xm1, ry))); + reqs.push_back(KokkosComm::recv(h, xp1_r, get_rank(xp1, ry))); + reqs.push_back(KokkosComm::recv(h, ym1_r, get_rank(rx, ym1))); + reqs.push_back(KokkosComm::recv(h, yp1_r, get_rank(rx, yp1))); // wait for comm KokkosComm::wait_all(reqs); diff --git a/perf_tests/test_osu_latency.cpp b/perf_tests/test_osu_latency.cpp index 2758f993..a6628ed0 100644 --- a/perf_tests/test_osu_latency.cpp +++ b/perf_tests/test_osu_latency.cpp @@ -24,9 +24,9 @@ template void osu_latency_Kokkos_Comm_sendrecv(benchmark::State &, MPI_Comm, KokkosComm::Handle<> &h, const View &v) { if (h.rank() == 0) { - KokkosComm::wait(KokkosComm::send(h, v, 1, 1)); + KokkosComm::wait(KokkosComm::send(h, v, 1)); } else if (h.rank() == 1) { - KokkosComm::wait(KokkosComm::recv(h, v, 0, 1)); + KokkosComm::wait(KokkosComm::recv(h, v, 0)); } } diff --git a/src/KokkosComm_point_to_point.hpp b/src/KokkosComm_point_to_point.hpp index d9577d48..e53b7181 100644 --- a/src/KokkosComm_point_to_point.hpp +++ b/src/KokkosComm_point_to_point.hpp @@ -25,26 +25,26 @@ namespace KokkosComm { template -Req recv(Handle &h, RecvView &rv, int src, int tag) { - return Impl::Recv::execute(h, rv, src, tag); +Req recv(Handle &h, RecvView &rv, int src) { + return Impl::Recv::execute(h, rv, src); } template -Req recv(RecvView &rv, int src, int tag) { - return recv(Handle{}, rv, src, tag); +Req recv(RecvView &rv, int src) { + return recv(Handle{}, rv, src); } template -Req send(Handle &h, SendView &sv, int dest, int tag) { - return Impl::Send::execute(h, sv, dest, tag); +Req send(Handle &h, SendView &sv, int dest) { + return Impl::Send::execute(h, sv, dest); } template -Req send(SendView &sv, int dest, int tag) { - return send(Handle{}, sv, dest, tag); +Req send(SendView &sv, int dest) { + return send(Handle{}, sv, dest); } } // namespace KokkosComm diff --git a/src/mpi/KokkosComm_mpi_irecv.hpp b/src/mpi/KokkosComm_mpi_irecv.hpp index 3504f185..1881957c 100644 --- a/src/mpi/KokkosComm_mpi_irecv.hpp +++ b/src/mpi/KokkosComm_mpi_irecv.hpp @@ -17,6 +17,7 @@ #pragma once #include "KokkosComm_mpi.hpp" +#include "impl/KokkosComm_tags.hpp" namespace KokkosComm { @@ -24,7 +25,7 @@ namespace Impl { // Recv implementation for Mpi template struct Recv { - static Req execute(Handle &h, const RecvView &rv, int src, int tag) { + static Req execute(Handle &h, const RecvView &rv, int src) { using KCT = KokkosComm::Traits; using KCPT = KokkosComm::PackTraits; using Packer = typename KCPT::packer_type; @@ -35,13 +36,13 @@ struct Recv { Req req; if (KokkosComm::is_contiguous(rv)) { space.fence("fence before irecv"); - MPI_Irecv(KokkosComm::data_handle(rv), 1, view_mpi_type(rv), src, tag, h.mpi_comm(), + MPI_Irecv(KokkosComm::data_handle(rv), 1, view_mpi_type(rv), src, POINTTOPOINT_TAG, h.mpi_comm(), &req.mpi_request()); // TODO: probably best to just use the scalar type req.extend_view_lifetime(rv); } else { Args args = Packer::allocate_packed_for(space, "TODO", rv); space.fence("fence before irecv"); - MPI_Irecv(args.view.data(), args.count, args.datatype, src, tag, h.mpi_comm(), &req.mpi_request()); + MPI_Irecv(args.view.data(), args.count, args.datatype, src, POINTTOPOINT_TAG, h.mpi_comm(), &req.mpi_request()); req.extend_view_lifetime(rv); // implicitly extends args.view lifetime since lambda holds a copy req.call_after_mpi_wait([=]() { Packer::unpack_into(space, rv, args.view); }); diff --git a/src/mpi/KokkosComm_mpi_isend.hpp b/src/mpi/KokkosComm_mpi_isend.hpp index ccea4578..ef8f542d 100644 --- a/src/mpi/KokkosComm_mpi_isend.hpp +++ b/src/mpi/KokkosComm_mpi_isend.hpp @@ -20,6 +20,7 @@ #include "KokkosComm_mpi.hpp" #include "impl/KokkosComm_types.hpp" +#include "impl/KokkosComm_tags.hpp" #include "KokkosComm_mpi_commmode.hpp" namespace KokkosComm { @@ -62,8 +63,8 @@ Req isend_impl(Handle &h, const SendView &sv, int dest, int // Implementation of KokkosComm::Send template struct Send { - static Req execute(Handle &h, const SendView &sv, int dest, int tag) { - return isend_impl(h, sv, dest, tag, mpi::DefaultCommMode{}); + static Req execute(Handle &h, const SendView &sv, int dest) { + return isend_impl(h, sv, dest, POINTTOPOINT_TAG, mpi::DefaultCommMode{}); } }; diff --git a/src/mpi/impl/KokkosComm_tags.hpp b/src/mpi/impl/KokkosComm_tags.hpp new file mode 100644 index 00000000..e1230bb0 --- /dev/null +++ b/src/mpi/impl/KokkosComm_tags.hpp @@ -0,0 +1,5 @@ +#pragma once + +namespace KokkosComm::Impl { +constexpr int POINTTOPOINT_TAG = 17; +} \ No newline at end of file diff --git a/unit_tests/test_sendrecv.cpp b/unit_tests/test_sendrecv.cpp index bbea7214..ac18e63a 100644 --- a/unit_tests/test_sendrecv.cpp +++ b/unit_tests/test_sendrecv.cpp @@ -46,10 +46,10 @@ void test_1d(const View1D &a) { int dst = 1; Kokkos::parallel_for( a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; }); - KokkosComm::wait(KokkosComm::send(h, a, dst, 0)); + KokkosComm::wait(KokkosComm::send(h, a, dst)); } else if (1 == h.rank()) { int src = 0; - KokkosComm::wait(KokkosComm::recv(h, a, src, 0)); + KokkosComm::wait(KokkosComm::recv(h, a, src)); int errs; Kokkos::parallel_reduce( a.extent(0), KOKKOS_LAMBDA(const int &i, int &lsum) { lsum += a(i) != Scalar(i); }, errs); @@ -74,10 +74,10 @@ void test_2d(const View2D &a) { int dst = 1; Kokkos::parallel_for( policy, KOKKOS_LAMBDA(int i, int j) { a(i, j) = i * a.extent(0) + j; }); - KokkosComm::wait(KokkosComm::send(h, a, dst, 0)); + KokkosComm::wait(KokkosComm::send(h, a, dst)); } else if (1 == h.rank()) { int src = 0; - KokkosComm::wait(KokkosComm::recv(h, a, src, 0)); + KokkosComm::wait(KokkosComm::recv(h, a, src)); int errs; Kokkos::parallel_reduce( policy, KOKKOS_LAMBDA(int i, int j, int &lsum) { lsum += a(i, j) != Scalar(i * a.extent(0) + j); }, errs);