From 256009780bf873d2b1a9282bc1dda359a2186490 Mon Sep 17 00:00:00 2001 From: Carl Pearson Date: Fri, 10 May 2024 13:45:52 -0600 Subject: [PATCH] mpi: 1D contiguous alltoall in Impl --- src/KokkosComm_collective.hpp | 1 + src/impl/KokkosComm_alltoall.hpp | 116 +++++++++++++++++++++++++++++++ unit_tests/CMakeLists.txt | 3 +- unit_tests/test_alltoall.cpp | 91 ++++++++++++++++++++++++ 4 files changed, 210 insertions(+), 1 deletion(-) create mode 100644 src/impl/KokkosComm_alltoall.hpp create mode 100644 unit_tests/test_alltoall.cpp diff --git a/src/KokkosComm_collective.hpp b/src/KokkosComm_collective.hpp index 45bc6964..92ccf50b 100644 --- a/src/KokkosComm_collective.hpp +++ b/src/KokkosComm_collective.hpp @@ -19,6 +19,7 @@ #include #include "KokkosComm_concepts.hpp" +#include "KokkosComm_alltoall.hpp" #include "KokkosComm_reduce.hpp" namespace KokkosComm { diff --git a/src/impl/KokkosComm_alltoall.hpp b/src/impl/KokkosComm_alltoall.hpp new file mode 100644 index 00000000..f97f33d4 --- /dev/null +++ b/src/impl/KokkosComm_alltoall.hpp @@ -0,0 +1,116 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#pragma once + +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#pragma once + +#include + +#include "KokkosComm_pack_traits.hpp" +#include "KokkosComm_traits.hpp" + +// impl +#include "KokkosComm_include_mpi.hpp" +#include "KokkosComm_types.hpp" + +namespace KokkosComm::Impl { +template +void alltoall(const ExecSpace &space, const SendView &sv, const size_t sendCount, const RecvView &rv, + const size_t recvCount, MPI_Comm comm) { + Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall"); + + using ST = KokkosComm::Traits; + using RT = KokkosComm::Traits; + using SendScalar = typename SendView::value_type; + using RecvScalar = typename RecvView::value_type; + + static_assert(ST::rank() <= 1, "alltoall for SendView::rank > 1 not supported"); + static_assert(RT::rank() <= 1, "alltoall for RecvView::rank > 1 not supported"); + + if (KokkosComm::PackTraits::needs_pack(sv) || KokkosComm::PackTraits::needs_pack(rv)) { + throw std::runtime_error("alltoall for non-contiguous views not implemented"); + } else { + int size; + MPI_Comm_size(comm, &size); + + if (sendCount * size > ST::extent(sv, 0)) { + std::stringstream ss; + ss << "alltoall sendCount * communicator size (" << sendCount << " * " << size + << ") is greater than send view size"; + throw std::runtime_error(ss.str()); + } + if (recvCount * size > RT::extent(rv, 0)) { + std::stringstream ss; + ss << "alltoall recvCount * communicator size (" << recvCount << " * " << size + << ") is greater than recv view size"; + throw std::runtime_error(ss.str()); + } + + MPI_Alltoall(ST::data_handle(sv), sendCount, mpi_type_v, RT::data_handle(rv), recvCount, + mpi_type_v, comm); + } + + Kokkos::Tools::popRegion(); +} + +// in-place alltoall +template +void alltoall(const ExecSpace &space, const RecvView &rv, const size_t recvCount, MPI_Comm comm) { + Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall"); + + using RT = KokkosComm::Traits; + using RecvScalar = typename RecvView::value_type; + + static_assert(RT::rank() <= 1, "alltoall for RecvView::rank > 1 not supported"); + + if (KokkosComm::PackTraits::needs_pack(rv)) { + throw std::runtime_error("alltoall for non-contiguous views not implemented"); + } else { + int size; + MPI_Comm_size(comm, &size); + + if (recvCount * size > RT::extent(rv, 0)) { + std::stringstream ss; + ss << "alltoall recvCount * communicator size (" << recvCount << " * " << size + << ") is greater than recv view size"; + throw std::runtime_error(ss.str()); + } + + MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/, RT::data_handle(rv), recvCount, + mpi_type_v, comm); + } + + Kokkos::Tools::popRegion(); +} + +} // namespace KokkosComm::Impl diff --git a/unit_tests/CMakeLists.txt b/unit_tests/CMakeLists.txt index 029ff87d..cc69bb08 100644 --- a/unit_tests/CMakeLists.txt +++ b/unit_tests/CMakeLists.txt @@ -45,9 +45,10 @@ target_link_libraries(test-mpi MPI::MPI_CXX) add_executable(test-main test_main.cpp test_gtest_mpi.cpp test_isendrecv.cpp - test_reduce.cpp test_sendrecv.cpp test_barrier.cpp + test_alltoall.cpp + test_reduce.cpp ) target_link_libraries(test-main KokkosComm::KokkosComm gtest) if(KOKKOSCOMM_ENABLE_TESTS) diff --git a/unit_tests/test_alltoall.cpp b/unit_tests/test_alltoall.cpp new file mode 100644 index 00000000..61397734 --- /dev/null +++ b/unit_tests/test_alltoall.cpp @@ -0,0 +1,91 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#include + +#include "KokkosComm.hpp" + +namespace { + +template +class Alltoall : public testing::Test { + public: + using Scalar = T; +}; + +using ScalarTypes = ::testing::Types, Kokkos::complex>; +TYPED_TEST_SUITE(Alltoall, ScalarTypes); + +TYPED_TEST(Alltoall, 1D_contig) { + using TestScalar = typename TestFixture::Scalar; + + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + + const int nContrib = 10; + + Kokkos::View sv("sv", size * nContrib); + Kokkos::View rv("rv", size * nContrib); + + // fill send buffer + Kokkos::parallel_for( + sv.extent(0), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; }); + + KokkosComm::Impl::alltoall(Kokkos::DefaultExecutionSpace(), sv, nContrib, rv, nContrib, MPI_COMM_WORLD); + + int errs; + Kokkos::parallel_reduce( + rv.extent(0), + KOKKOS_LAMBDA(const int &i, int &lsum) { + const int src = i / nContrib; // who sent this data + const int j = rank * nContrib + (i % nContrib); // what index i was at the source + lsum += rv(i) != src + j; + }, + errs); + EXPECT_EQ(errs, 0); +} + +TYPED_TEST(Alltoall, 1D_inplace_contig) { + using TestScalar = typename TestFixture::Scalar; + + int rank, size; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &size); + + const int nContrib = 10; + + Kokkos::View rv("rv", size * nContrib); + + // fill send buffer + Kokkos::parallel_for( + rv.extent(0), KOKKOS_LAMBDA(const int i) { rv(i) = rank + i; }); + + KokkosComm::Impl::alltoall(Kokkos::DefaultExecutionSpace(), rv, nContrib, MPI_COMM_WORLD); + + int errs; + Kokkos::parallel_reduce( + rv.extent(0), + KOKKOS_LAMBDA(const int &i, int &lsum) { + const int src = i / nContrib; // who sent this data + const int j = rank * nContrib + (i % nContrib); // what index i was at the source + lsum += rv(i) != src + j; + }, + errs); + EXPECT_EQ(errs, 0); +} + +} // namespace