From 32e9568f21cbe9b9babf0458df31b310cf4d8f19 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Mon, 13 May 2024 09:51:57 -0600 Subject: [PATCH] mpi: in-place alltoall --- src/impl/KokkosComm_alltoall.hpp | 36 +++++++++++++++++++++++++++++++- unit_tests/test_alltoall.cpp | 31 +++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/src/impl/KokkosComm_alltoall.hpp b/src/impl/KokkosComm_alltoall.hpp index 1bcfe092..a05cf8e5 100644 --- a/src/impl/KokkosComm_alltoall.hpp +++ b/src/impl/KokkosComm_alltoall.hpp @@ -88,4 +88,38 @@ void alltoall(const ExecSpace &space, const SendView &sv, Kokkos::Tools::popRegion(); } -} // namespace KokkosComm::Impl \ No newline at end of file + +// in-place alltoall +template +void alltoall(const ExecSpace &space, const RecvView &rv, + const size_t recvCount, MPI_Comm comm) { + Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall"); + + using RT = KokkosComm::Traits; + using RecvScalar = typename RecvView::value_type; + + static_assert(RT::rank() <= 1, + "alltoall for RecvView::rank > 1 not supported"); + + if (KokkosComm::PackTraits::needs_pack(rv)) { + throw std::runtime_error( + "alltoall for non-contiguous views not implemented"); + } else { + int size; + MPI_Comm_size(comm, &size); + + if (recvCount * size > RT::extent(rv, 0)) { + std::stringstream ss; + ss << "alltoall recvCount * communicator size (" << recvCount << " * " + << size << ") is greater than recv view size"; + throw std::runtime_error(ss.str()); + } + + MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/, + RT::data_handle(rv), recvCount, mpi_type_v, comm); + } + + Kokkos::Tools::popRegion(); +} + +} // namespace KokkosComm::Impl diff --git a/unit_tests/test_alltoall.cpp b/unit_tests/test_alltoall.cpp index a2542a7a..253357bd 100644 --- a/unit_tests/test_alltoall.cpp +++ b/unit_tests/test_alltoall.cpp @@ -63,4 +63,35 @@ TYPED_TEST(Alltoall, 1D_contig) { EXPECT_EQ(errs, 0); } +TYPED_TEST(Alltoall, 1D_inplace_contig) { + using TestScalar = typename TestFixture::Scalar; + + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + + const int nContrib = 10; + + Kokkos::View rv("rv", size * nContrib); + + // fill send buffer + Kokkos::parallel_for( + rv.extent(0), KOKKOS_LAMBDA(const int i) { rv(i) = rank + i; }); + + KokkosComm::Impl::alltoall(Kokkos::DefaultExecutionSpace(), rv, nContrib, + MPI_COMM_WORLD); + + int errs; + Kokkos::parallel_reduce( + rv.extent(0), + KOKKOS_LAMBDA(const int &i, int &lsum) { + const int src = i / nContrib; // who sent this data + const int j = + rank * nContrib + (i % nContrib); // what index i was at the source + lsum += rv(i) != src + j; + }, + errs); + EXPECT_EQ(errs, 0); +} + } // namespace