Skip to content

Commit

Permalink
mpi: 1D contiguous alltoall in Impl
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed May 15, 2024
1 parent 0ecdea0 commit 2560097
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/KokkosComm_collective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <Kokkos_Core.hpp>

#include "KokkosComm_concepts.hpp"
#include "KokkosComm_alltoall.hpp"
#include "KokkosComm_reduce.hpp"

namespace KokkosComm {
Expand Down
116 changes: 116 additions & 0 deletions src/impl/KokkosComm_alltoall.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#pragma once

#include <Kokkos_Core.hpp>

#include "KokkosComm_pack_traits.hpp"
#include "KokkosComm_traits.hpp"

// impl
#include "KokkosComm_include_mpi.hpp"
#include "KokkosComm_types.hpp"

namespace KokkosComm::Impl {
template <KokkosExecutionSpace ExecSpace, KokkosView SendView, KokkosView RecvView>
void alltoall(const ExecSpace &space, const SendView &sv, const size_t sendCount, const RecvView &rv,
const size_t recvCount, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall");

using ST = KokkosComm::Traits<SendView>;
using RT = KokkosComm::Traits<RecvView>;
using SendScalar = typename SendView::value_type;
using RecvScalar = typename RecvView::value_type;

static_assert(ST::rank() <= 1, "alltoall for SendView::rank > 1 not supported");
static_assert(RT::rank() <= 1, "alltoall for RecvView::rank > 1 not supported");

if (KokkosComm::PackTraits<SendView>::needs_pack(sv) || KokkosComm::PackTraits<RecvView>::needs_pack(rv)) {
throw std::runtime_error("alltoall for non-contiguous views not implemented");
} else {
int size;
MPI_Comm_size(comm, &size);

if (sendCount * size > ST::extent(sv, 0)) {
std::stringstream ss;
ss << "alltoall sendCount * communicator size (" << sendCount << " * " << size
<< ") is greater than send view size";
throw std::runtime_error(ss.str());
}
if (recvCount * size > RT::extent(rv, 0)) {
std::stringstream ss;
ss << "alltoall recvCount * communicator size (" << recvCount << " * " << size
<< ") is greater than recv view size";
throw std::runtime_error(ss.str());
}

MPI_Alltoall(ST::data_handle(sv), sendCount, mpi_type_v<SendScalar>, RT::data_handle(rv), recvCount,
mpi_type_v<RecvScalar>, comm);
}

Kokkos::Tools::popRegion();
}

// in-place alltoall
template <KokkosExecutionSpace ExecSpace, KokkosView RecvView>
void alltoall(const ExecSpace &space, const RecvView &rv, const size_t recvCount, MPI_Comm comm) {
Kokkos::Tools::pushRegion("KokkosComm::Impl::alltoall");

using RT = KokkosComm::Traits<RecvView>;
using RecvScalar = typename RecvView::value_type;

static_assert(RT::rank() <= 1, "alltoall for RecvView::rank > 1 not supported");

if (KokkosComm::PackTraits<RecvView>::needs_pack(rv)) {
throw std::runtime_error("alltoall for non-contiguous views not implemented");
} else {
int size;
MPI_Comm_size(comm, &size);

if (recvCount * size > RT::extent(rv, 0)) {
std::stringstream ss;
ss << "alltoall recvCount * communicator size (" << recvCount << " * " << size
<< ") is greater than recv view size";
throw std::runtime_error(ss.str());
}

MPI_Alltoall(MPI_IN_PLACE, 0 /*ignored*/, MPI_BYTE /*ignored*/, RT::data_handle(rv), recvCount,
mpi_type_v<RecvScalar>, comm);
}

Kokkos::Tools::popRegion();
}

} // namespace KokkosComm::Impl
3 changes: 2 additions & 1 deletion unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ target_link_libraries(test-mpi MPI::MPI_CXX)
add_executable(test-main test_main.cpp
test_gtest_mpi.cpp
test_isendrecv.cpp
test_reduce.cpp
test_sendrecv.cpp
test_barrier.cpp
test_alltoall.cpp
test_reduce.cpp
)
target_link_libraries(test-main KokkosComm::KokkosComm gtest)
if(KOKKOSCOMM_ENABLE_TESTS)
Expand Down
91 changes: 91 additions & 0 deletions unit_tests/test_alltoall.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <gtest/gtest.h>

#include "KokkosComm.hpp"

namespace {

template <typename T>
class Alltoall : public testing::Test {
public:
using Scalar = T;
};

using ScalarTypes = ::testing::Types<int, int64_t, float, double, Kokkos::complex<float>, Kokkos::complex<double>>;
TYPED_TEST_SUITE(Alltoall, ScalarTypes);

TYPED_TEST(Alltoall, 1D_contig) {
using TestScalar = typename TestFixture::Scalar;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);

const int nContrib = 10;

Kokkos::View<TestScalar *> sv("sv", size * nContrib);
Kokkos::View<TestScalar *> rv("rv", size * nContrib);

// fill send buffer
Kokkos::parallel_for(
sv.extent(0), KOKKOS_LAMBDA(const int i) { sv(i) = rank + i; });

KokkosComm::Impl::alltoall(Kokkos::DefaultExecutionSpace(), sv, nContrib, rv, nContrib, MPI_COMM_WORLD);

int errs;
Kokkos::parallel_reduce(
rv.extent(0),
KOKKOS_LAMBDA(const int &i, int &lsum) {
const int src = i / nContrib; // who sent this data
const int j = rank * nContrib + (i % nContrib); // what index i was at the source
lsum += rv(i) != src + j;
},
errs);
EXPECT_EQ(errs, 0);
}

TYPED_TEST(Alltoall, 1D_inplace_contig) {
using TestScalar = typename TestFixture::Scalar;

int rank, size;
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);

const int nContrib = 10;

Kokkos::View<TestScalar *> rv("rv", size * nContrib);

// fill send buffer
Kokkos::parallel_for(
rv.extent(0), KOKKOS_LAMBDA(const int i) { rv(i) = rank + i; });

KokkosComm::Impl::alltoall(Kokkos::DefaultExecutionSpace(), rv, nContrib, MPI_COMM_WORLD);

int errs;
Kokkos::parallel_reduce(
rv.extent(0),
KOKKOS_LAMBDA(const int &i, int &lsum) {
const int src = i / nContrib; // who sent this data
const int j = rank * nContrib + (i % nContrib); // what index i was at the source
lsum += rv(i) != src + j;
},
errs);
EXPECT_EQ(errs, 0);
}

} // namespace

0 comments on commit 2560097

Please sign in to comment.