Skip to content

Commit

Permalink
Remove tags from core interface
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Aug 12, 2024
1 parent 595794a commit dc74642
Show file tree
Hide file tree
Showing 10 changed files with 42 additions and 40 deletions.
12 changes: 4 additions & 8 deletions docs/api/core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Point-to-point

.. cpp:namespace:: KokkosComm

.. cpp:function:: template <KokkosView SendView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport> Req<TRANSPORT> send(Handle<ExecSpace, TRANSPORT> &h, SendView &sv, int dest, int tag)
.. cpp:function:: template <KokkosView SendView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport> Req<TRANSPORT> send(Handle<ExecSpace, TRANSPORT> &h, SendView &sv, int dest)

Initiates a non-blocking send operation.

Expand All @@ -20,11 +20,10 @@ Point-to-point
:param h: A handle to the execution space and transport mechanism.
:param sv: The Kokkos view to send.
:param dest: The destination rank.
:param tag: The message tag.

:return: A request object for the non-blocking send operation.

.. cpp:function:: template <KokkosView SendView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport> Req<TRANSPORT> send(SendView &sv, int dest, int tag)
.. cpp:function:: template <KokkosView SendView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport> Req<TRANSPORT> send(SendView &sv, int dest)

Initiates a non-blocking send operation using a default handle.

Expand All @@ -37,7 +36,6 @@ Point-to-point

:param sv: The Kokkos view to send.
:param dest: The destination rank.
:param tag: The message tag.

:return: A request object for the non-blocking send operation.

Expand All @@ -48,7 +46,7 @@ Point-to-point



.. cpp:function:: template <KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport> Req<TRANSPORT> recv(Handle<ExecSpace, TRANSPORT> &h, RecvView &rv, int src, int tag)
.. cpp:function:: template <KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport> Req<TRANSPORT> recv(Handle<ExecSpace, TRANSPORT> &h, RecvView &rv, int src)

Initiates a non-blocking receive operation.

Expand All @@ -62,7 +60,6 @@ Point-to-point
:param h: A handle to the execution space and transport mechanism.
:param rv: The Kokkos view where the received data will be stored.
:param src: The source rank from which to receive data.
:param tag: The message tag to identify the communication.

:return: A request object of type `Req<TRANSPORT>` representing the non-blocking receive operation.

Expand All @@ -77,7 +74,7 @@ Point-to-point



.. cpp:function:: template <KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport> Req<TRANSPORT> recv(RecvView &rv, int src, int tag)
.. cpp:function:: template <KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace, Transport TRANSPORT = DefaultTransport> Req<TRANSPORT> recv(RecvView &rv, int src)

Initiates a non-blocking receive operation using a default handle.

Expand All @@ -90,7 +87,6 @@ Point-to-point

:param rv: The Kokkos view where the received data will be stored.
:param src: The source rank from which to receive data.
:param tag: The message tag to identify the communication.

:return: A request object of type `Req<TRANSPORT>` representing the non-blocking receive operation.

Expand Down
2 changes: 1 addition & 1 deletion docs/api/core_recv.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Handle<> handle;
Kokkos::View<double*> recv_view("recv_view", 100);
auto req = recv(handle, recv_view, 1/*src*/, 0/*tag*/);
auto req = recv(handle, recv_view, 1/*src*/);
KokkosComm::wait(req);
5 changes: 2 additions & 3 deletions docs/api/core_send.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ Kokkos::parallel_for("fill_data", Kokkos::RangePolicy<ExecSpace>(0, 100), KOKKOS

// Destination rank
int dest = 1;
int tag = 42;

// Create a handle
KokkosComm::Handle<> handle; // Same as Handle<Execspace, Transport>

// Initiate a non-blocking send with a handle
auto req1 = send(handle, data, dest, tag);
auto req1 = send(handle, data, dest);

// Initiate a non-blocking send with a default handle
auto req2 = send(data, dest, tag);
auto req2 = send(data, dest);

// Wait for the requests to complete (assuming a wait function exists)
KokkosComm::wait(req1);
Expand Down
18 changes: 9 additions & 9 deletions perf_tests/test_2dhalo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ void send_recv(benchmark::State &, MPI_Comm comm, const Space &space, int nx, in

std::vector<KokkosComm::Req<>> reqs;
// std::cerr << get_rank(rx, ry) << " -> " << get_rank(xp1, ry) << "\n";
reqs.push_back(KokkosComm::send(h, xp1_s, get_rank(xp1, ry), 0));
reqs.push_back(KokkosComm::send(h, xm1_s, get_rank(xm1, ry), 1));
reqs.push_back(KokkosComm::send(h, yp1_s, get_rank(rx, yp1), 2));
reqs.push_back(KokkosComm::send(h, ym1_s, get_rank(rx, ym1), 3));

reqs.push_back(KokkosComm::recv(h, xm1_r, get_rank(xm1, ry), 0));
reqs.push_back(KokkosComm::recv(h, xp1_r, get_rank(xp1, ry), 1));
reqs.push_back(KokkosComm::recv(h, ym1_r, get_rank(rx, ym1), 2));
reqs.push_back(KokkosComm::recv(h, yp1_r, get_rank(rx, yp1), 3));
reqs.push_back(KokkosComm::send(h, xp1_s, get_rank(xp1, ry)));
reqs.push_back(KokkosComm::send(h, xm1_s, get_rank(xm1, ry)));
reqs.push_back(KokkosComm::send(h, yp1_s, get_rank(rx, yp1)));
reqs.push_back(KokkosComm::send(h, ym1_s, get_rank(rx, ym1)));

reqs.push_back(KokkosComm::recv(h, xm1_r, get_rank(xm1, ry)));
reqs.push_back(KokkosComm::recv(h, xp1_r, get_rank(xp1, ry)));
reqs.push_back(KokkosComm::recv(h, ym1_r, get_rank(rx, ym1)));
reqs.push_back(KokkosComm::recv(h, yp1_r, get_rank(rx, yp1)));

// wait for comm
KokkosComm::wait_all(reqs);
Expand Down
4 changes: 2 additions & 2 deletions perf_tests/test_osu_latency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
template <typename Space, typename View>
void osu_latency_Kokkos_Comm_sendrecv(benchmark::State &, MPI_Comm, KokkosComm::Handle<> &h, const View &v) {
if (h.rank() == 0) {
KokkosComm::wait(KokkosComm::send(h, v, 1, 1));
KokkosComm::wait(KokkosComm::send(h, v, 1));
} else if (h.rank() == 1) {
KokkosComm::wait(KokkosComm::recv(h, v, 0, 1));
KokkosComm::wait(KokkosComm::recv(h, v, 0));
}
}

Expand Down
16 changes: 8 additions & 8 deletions src/KokkosComm_point_to_point.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,26 @@ namespace KokkosComm {

template <KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
CommunicationSpace CommSpace = DefaultCommunicationSpace>
Req<CommSpace> recv(Handle<ExecSpace, CommSpace> &h, RecvView &rv, int src, int tag) {
return Impl::Recv<RecvView, ExecSpace, CommSpace>::execute(h, rv, src, tag);
Req<CommSpace> recv(Handle<ExecSpace, CommSpace> &h, RecvView &rv, int src) {
return Impl::Recv<RecvView, ExecSpace, CommSpace>::execute(h, rv, src);
}

/// Initiates a non-blocking receive into `rv` from rank `src` using a
/// default-constructed handle.
///
/// @tparam RecvView  Kokkos view type that receives the data.
/// @tparam ExecSpace Execution space of the handle that is created internally.
/// @tparam CommSpace Communication space of the handle and of the returned request.
/// @param rv  View the received data is stored into.
/// @param src Rank to receive from.
/// @return The request produced by the handle-taking `recv` overload.
template <KokkosView RecvView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
          CommunicationSpace CommSpace = DefaultCommunicationSpace>
Req<CommSpace> recv(RecvView &rv, int src) {
  // The handle-taking overload accepts `Handle<ExecSpace, CommSpace> &`; a
  // temporary cannot bind to a non-const lvalue reference, so the handle must
  // be a named object for this call to compile.
  Handle<ExecSpace, CommSpace> h{};
  return recv<RecvView, ExecSpace, CommSpace>(h, rv, src);
}

template <KokkosView SendView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
CommunicationSpace CommSpace = DefaultCommunicationSpace>
Req<CommSpace> send(Handle<ExecSpace, CommSpace> &h, SendView &sv, int dest, int tag) {
return Impl::Send<SendView, ExecSpace, CommSpace>::execute(h, sv, dest, tag);
Req<CommSpace> send(Handle<ExecSpace, CommSpace> &h, SendView &sv, int dest) {
return Impl::Send<SendView, ExecSpace, CommSpace>::execute(h, sv, dest);
}

/// Initiates a non-blocking send of `sv` to rank `dest` using a
/// default-constructed handle.
///
/// @tparam SendView  Kokkos view type holding the data to send.
/// @tparam ExecSpace Execution space of the handle that is created internally.
/// @tparam CommSpace Communication space of the handle and of the returned request.
/// @param sv   View to send.
/// @param dest Destination rank.
/// @return The request produced by the handle-taking `send` overload.
template <KokkosView SendView, KokkosExecutionSpace ExecSpace = Kokkos::DefaultExecutionSpace,
          CommunicationSpace CommSpace = DefaultCommunicationSpace>
Req<CommSpace> send(SendView &sv, int dest) {
  // The handle-taking overload accepts `Handle<ExecSpace, CommSpace> &`; a
  // temporary cannot bind to a non-const lvalue reference, so the handle must
  // be a named object for this call to compile.
  Handle<ExecSpace, CommSpace> h{};
  return send<SendView, ExecSpace, CommSpace>(h, sv, dest);
}

} // namespace KokkosComm
7 changes: 4 additions & 3 deletions src/mpi/KokkosComm_mpi_irecv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
#pragma once

#include "KokkosComm_mpi.hpp"
#include "impl/KokkosComm_tags.hpp"

namespace KokkosComm {

namespace Impl {
// Recv implementation for Mpi
template <KokkosExecutionSpace ExecSpace, KokkosView RecvView>
struct Recv<RecvView, ExecSpace, Mpi> {
static Req<Mpi> execute(Handle<ExecSpace, Mpi> &h, const RecvView &rv, int src, int tag) {
static Req<Mpi> execute(Handle<ExecSpace, Mpi> &h, const RecvView &rv, int src) {
using KCT = KokkosComm::Traits<RecvView>;
using KCPT = KokkosComm::PackTraits<RecvView>;
using Packer = typename KCPT::packer_type;
Expand All @@ -35,13 +36,13 @@ struct Recv<RecvView, ExecSpace, Mpi> {
Req<Mpi> req;
if (KokkosComm::is_contiguous(rv)) {
space.fence("fence before irecv");
MPI_Irecv(KokkosComm::data_handle(rv), 1, view_mpi_type(rv), src, tag, h.mpi_comm(),
MPI_Irecv(KokkosComm::data_handle(rv), 1, view_mpi_type(rv), src, POINTTOPOINT_TAG, h.mpi_comm(),
&req.mpi_request()); // TODO: probably best to just use the scalar type
req.extend_view_lifetime(rv);
} else {
Args args = Packer::allocate_packed_for(space, "TODO", rv);
space.fence("fence before irecv");
MPI_Irecv(args.view.data(), args.count, args.datatype, src, tag, h.mpi_comm(), &req.mpi_request());
MPI_Irecv(args.view.data(), args.count, args.datatype, src, POINTTOPOINT_TAG, h.mpi_comm(), &req.mpi_request());
req.extend_view_lifetime(rv);
// implicitly extends args.view lifetime since lambda holds a copy
req.call_after_mpi_wait([=]() { Packer::unpack_into(space, rv, args.view); });
Expand Down
5 changes: 3 additions & 2 deletions src/mpi/KokkosComm_mpi_isend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "KokkosComm_mpi.hpp"
#include "impl/KokkosComm_types.hpp"
#include "impl/KokkosComm_tags.hpp"
#include "KokkosComm_mpi_commmode.hpp"

namespace KokkosComm {
Expand Down Expand Up @@ -62,8 +63,8 @@ Req<Mpi> isend_impl(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int
// Implementation of KokkosComm::Send
template <KokkosExecutionSpace ExecSpace, KokkosView SendView>
struct Send<SendView, ExecSpace, Mpi> {
static Req<Mpi> execute(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest, int tag) {
return isend_impl<ExecSpace, SendView>(h, sv, dest, tag, mpi::DefaultCommMode{});
static Req<Mpi> execute(Handle<ExecSpace, Mpi> &h, const SendView &sv, int dest) {
return isend_impl<ExecSpace, SendView>(h, sv, dest, POINTTOPOINT_TAG, mpi::DefaultCommMode{});
}
};

Expand Down
5 changes: 5 additions & 0 deletions src/mpi/impl/KokkosComm_tags.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#pragma once

namespace KokkosComm::Impl {
/// Fixed message tag used by the Mpi backend for point-to-point traffic.
///
/// The public send/recv interface does not take a tag; the Mpi Send and Recv
/// implementations pass this constant wherever MPI expects one, so both sides
/// of a transfer always agree on the tag.
///
/// Declared `inline` so that every translation unit including this header
/// refers to the same entity instead of a per-TU internal-linkage copy.
inline constexpr int POINTTOPOINT_TAG = 17;
}  // namespace KokkosComm::Impl
8 changes: 4 additions & 4 deletions unit_tests/test_sendrecv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ void test_1d(const View1D &a) {
int dst = 1;
Kokkos::parallel_for(
a.extent(0), KOKKOS_LAMBDA(const int i) { a(i) = i; });
KokkosComm::wait(KokkosComm::send(h, a, dst, 0));
KokkosComm::wait(KokkosComm::send(h, a, dst));
} else if (1 == h.rank()) {
int src = 0;
KokkosComm::wait(KokkosComm::recv(h, a, src, 0));
KokkosComm::wait(KokkosComm::recv(h, a, src));
int errs;
Kokkos::parallel_reduce(
a.extent(0), KOKKOS_LAMBDA(const int &i, int &lsum) { lsum += a(i) != Scalar(i); }, errs);
Expand All @@ -74,10 +74,10 @@ void test_2d(const View2D &a) {
int dst = 1;
Kokkos::parallel_for(
policy, KOKKOS_LAMBDA(int i, int j) { a(i, j) = i * a.extent(0) + j; });
KokkosComm::wait(KokkosComm::send(h, a, dst, 0));
KokkosComm::wait(KokkosComm::send(h, a, dst));
} else if (1 == h.rank()) {
int src = 0;
KokkosComm::wait(KokkosComm::recv(h, a, src, 0));
KokkosComm::wait(KokkosComm::recv(h, a, src));
int errs;
Kokkos::parallel_reduce(
policy, KOKKOS_LAMBDA(int i, int j, int &lsum) { lsum += a(i, j) != Scalar(i * a.extent(0) + j); }, errs);
Expand Down

0 comments on commit dc74642

Please sign in to comment.