Skip to content

Commit

Permalink
Create custom mpi types for serializable custom types
Browse files Browse the repository at this point in the history
- Add Serializable concept
- Add tests for serializable MPI datatypes
- Add doc strings

Co-authored-by: Thomas Hahn <[email protected]>
  • Loading branch information
Wentzell and Thoemi09 committed Feb 20, 2025
1 parent a22f9a8 commit 442dac8
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 18 deletions.
111 changes: 105 additions & 6 deletions c++/mpi/datatypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ namespace mpi {
D(unsigned long long, MPI_UNSIGNED_LONG_LONG);
#undef D

/**
* @brief Specialization of mpi::mpi_type for enum types.
* @tparam E C++ enum type.
*/
template <typename E>
requires(std::is_enum_v<E>)
struct mpi_type<E> : mpi_type<std::underlying_type_t<E>> {};

/**
* @brief Specialization of mpi::mpi_type for `const` types.
* @tparam T C++ type.
Expand All @@ -94,6 +102,28 @@ namespace mpi {
*/
template <typename T> constexpr bool has_mpi_type<T, std::void_t<decltype(mpi_type<T>::get())>> = true;

namespace detail {

// Helper struct to check if member types are mpi-serializable, i.e. have an associated mpi_type
struct serialize_checker {
template <typename T>
void operator&(T &)
requires(has_mpi_type<T>)
{}
};

} // namespace detail

/**
* @brief A concept that checks if objects of a type can be serialized and deserialized.
* @tparam T Type to check.
*/
template <typename T>
concept Serializable = requires(const T ac, T a, detail::serialize_checker ar) {
{ ac.serialize(ar) } -> std::same_as<void>;
{ a.deserialize(ar) } -> std::same_as<void>;
};

/**
* @brief Create a new `MPI_Datatype` from a tuple.
*
Expand Down Expand Up @@ -135,8 +165,11 @@ namespace mpi {
* @brief Specialization of mpi::mpi_type for std::tuple.
* @tparam Ts Tuple element types.
*/
template <typename... T> struct mpi_type<std::tuple<T...>> {
[[nodiscard]] static MPI_Datatype get() noexcept { return get_mpi_type(std::tuple<T...>{}); }
template <typename... Ts> struct mpi_type<std::tuple<Ts...>> {
[[nodiscard]] static MPI_Datatype get() noexcept {
static MPI_Datatype type = get_mpi_type(std::tuple<Ts...>{});
return type;
}
};

/**
Expand All @@ -156,15 +189,81 @@ namespace mpi {
* auto tie_data(foo f) {
* return std::tie(f.x, f.y);
* }
* @endcode
*
* @tparam U Type to be converted to an `MPI_Datatype`.
*/
template <typename U>
requires(not Serializable<U>) and requires(U u) { tie_data(u); }
struct mpi_type<U> {
[[nodiscard]] static MPI_Datatype get() noexcept {
static MPI_Datatype type = get_mpi_type(tie_data(U{}));
return type;
}
};

namespace detail {

// Archive helper class to obtain MPI custom type info using references to class members.
struct mpi_archive {
std::vector<int> block_lengths{};
std::vector<MPI_Aint> displacements{};
std::vector<MPI_Datatype> types{};
MPI_Aint base_address{};

// Constructor sets the base address of the object.
explicit mpi_archive(const void *base) { MPI_Get_address(base, &base_address); }

// Overloaded operator& to process members to set the block lengths, displacements and MPI types.
template <typename T>
requires(has_mpi_type<T>)
mpi_archive &operator&(const T &member) {
types.push_back(mpi_type<T>::get());
MPI_Aint address{};
MPI_Get_address(&member, &address);
displacements.push_back(MPI_Aint_diff(address, base_address));
block_lengths.push_back(1);
return *this;
}
};

} // namespace detail

/**
* @brief Create an `MPI_Datatype` from a serializable type.
*
* @details It is assumed that the type has a member function `serialize`
* which feeds all its class members into an archive using the `operator&`.
*
* // provide a specialization of mpi_type
* template <> struct mpi::mpi_type<foo> : mpi::mpi_type_from_tie<foo> {};
* @code{.cpp}
* // type to use for MPI communication
* struct foo {
* double x;
* int y;
* void serialize(auto& ar) const { ar & x & y; }
* };
* @endcode
*
* @tparam T Type to be converted to an `MPI_Datatype`.
*/
template <typename T> struct mpi_type_from_tie {
[[nodiscard]] static MPI_Datatype get() noexcept { return get_mpi_type(tie_data(T{})); }
template <Serializable T> [[nodiscard]] MPI_Datatype get_mpi_type(const T &obj) {
detail::mpi_archive ar(&obj);
obj.serialize(ar);
MPI_Datatype mpi_type{};
MPI_Type_create_struct(static_cast<int>(ar.block_lengths.size()), ar.block_lengths.data(), ar.displacements.data(), ar.types.data(), &mpi_type);
MPI_Type_commit(&mpi_type);
return mpi_type;
}

/**
* @brief Specialization of mpi::mpi_type for serializable types.
* @tparam S Serializable type.
*/
template <Serializable S> struct mpi_type<S> {
[[nodiscard]] static MPI_Datatype get() noexcept {
static MPI_Datatype type = get_mpi_type(S{});
return type;
}
};

/** @} */
Expand Down
1 change: 1 addition & 0 deletions c++/mpi/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <stdexcept>
#include <string>
#include <type_traits>

namespace mpi {

Expand Down
11 changes: 8 additions & 3 deletions doc/DoxygenLayout.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,20 @@
<tab type="usergroup" url="@ref mpi::mpi_type" title="mpi_type">
<tab type="user" url="@ref mpi::mpi_type< bool >" title="mpi_type<bool>"/>
<tab type="user" url="@ref mpi::mpi_type< char >" title="mpi_type<char>"/>
<tab type="user" url="@ref mpi::mpi_type< const T >" title="mpi_type<const T>"/>
<tab type="user" url="@ref mpi::mpi_type< double >" title="mpi_type<double>"/>
<tab type="user" url="@ref mpi::mpi_type< E >" title="mpi_type<E>"/>
<tab type="user" url="@ref mpi::mpi_type< float >" title="mpi_type<float>"/>
<tab type="user" url="@ref mpi::mpi_type< int >" title="mpi_type<int>"/>
<tab type="user" url="@ref mpi::mpi_type< long >" title="mpi_type<long>"/>
<tab type="user" url="@ref mpi::mpi_type< long long >" title="mpi_type<long long>"/>
<tab type="user" url="@ref mpi::mpi_type< double >" title="mpi_type<double>"/>
<tab type="user" url="@ref mpi::mpi_type< float >" title="mpi_type<float>"/>
<tab type="user" url="@ref mpi::mpi_type< S >" title="mpi_type<S>"/>
<tab type="user" url="@ref mpi::mpi_type< std::complex< double > >" title="mpi_type<std::complex<double>>"/>
<tab type="user" url="@ref mpi::mpi_type< std::tuple< Ts... > >" title="mpi_type<std::tuple>"/>
<tab type="user" url="@ref mpi::mpi_type< U >" title="mpi_type<U>"/>
<tab type="user" url="@ref mpi::mpi_type< unsigned int >" title="mpi_type<unsigned int>"/>
<tab type="user" url="@ref mpi::mpi_type< unsigned long >" title="mpi_type<unsigned long>"/>
<tab type="user" url="@ref mpi::mpi_type< unsigned long long >" title="mpi_type<unsigned long long>"/>
<tab type="user" url="@ref mpi::mpi_type< std::tuple< Ts... > >" title="mpi_type<std::tuple>"/>
</tab>
</tab>
<tab type="user" url="@ref coll_comm" title="Collective MPI communication"/>
Expand All @@ -56,6 +60,7 @@
</tab>
<tab type="usergroup" url="@ref utilities" title="Utilities">
<tab type="user" url="@ref mpi::contiguous_sized_range" title="contiguous_sized_range"/>
<tab type="user" url="@ref mpi::Serializable" title="Serializable"/>
</tab>
<tab type="filelist" visible="yes" title="" intro=""/>
</tab>
Expand Down
8 changes: 3 additions & 5 deletions doc/ex3.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

[TOC]

In this example, we show how to use mpi::mpi_type_from_tie, mpi::map_C_function and mpi::map_add to register a new MPI datatype and to define MPI operations for it.
In this example, we show how to register a new MPI datatype and how to use mpi::map_C_function and mpi::map_add to
define MPI operations for it.

```cpp
#include <mpi/mpi.hpp>
Expand All @@ -19,14 +20,11 @@ inline my_complex operator+(const my_complex& z1, const my_complex& z2) {
return { z1.real + z2.real, z1.imag + z2.imag };
}

// define a tie_data function for mpi_type_from_tie
// define a tie_data function for my_complex to make it MPI compatible
inline auto tie_data(const my_complex& z) {
return std::tie(z.real, z.imag);
}

// register my_complex as an MPI type
template <> struct mpi::mpi_type<my_complex> : mpi::mpi_type_from_tie<my_complex> {};

int main(int argc, char *argv[]) {
// initialize MPI environment
mpi::environment env(argc, argv);
Expand Down
84 changes: 80 additions & 4 deletions test/c++/mpi_custom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ struct custom_cplx {
// tie the data (used to construct the custom MPI type)
inline auto tie_data(custom_cplx z) { return std::tie(z.real, z.imag); }

// specialize mpi_type for custom_cplx
template <> struct mpi::mpi_type<custom_cplx> : mpi::mpi_type_from_tie<custom_cplx> {};

// stand-alone add function (the same as the operator+ above)
custom_cplx add(custom_cplx const &x, custom_cplx const &y) { return x + y; }

Expand Down Expand Up @@ -131,9 +128,88 @@ TEST(MPI, TupleMPIDatatypes) {

using type5 = std::tuple<int, double, char, custom_cplx, bool>;
type5 tup5;
if (rank == root) { tup5 = std::make_tuple(100, 3.1314, 'r', custom_cplx{1.0, 2.0}, false); }
if (rank == root) { tup5 = std::make_tuple(100, 3.1314, 'r', custom_cplx{.real = 1.0, .imag = 2.0}, false); }
mpi::broadcast(tup5, world, root);
EXPECT_EQ(tup5, std::make_tuple(100, 3.1314, 'r', custom_cplx{1.0, 2.0}, false));
}

// a simple struct representing a complex number that is serializable
struct serializable_cplx {
double real{}, imag{};

// add two serializable_cplx objects
serializable_cplx operator+(serializable_cplx z) const {
z.real += real;
z.imag += imag;
return z;
}

// default equal-to operator
bool operator==(const serializable_cplx &) const = default;

// serialize the object
void serialize(auto &ar) const { ar & real & imag; }
void deserialize(auto &ar) { ar & real & imag; }
};

// a simple struct that contains a serializable type and is serializable itself
struct serializable_container {
serializable_cplx z1;
custom_cplx z2;

// add two serializable_container objects
serializable_container operator+(serializable_container z) const {
z.z1 = z.z1 + z1;
z.z2 = z.z2 + z2;
return z;
}

// default equal-to operator
bool operator==(const serializable_container &) const = default;

// serialize the object
void serialize(auto &ar) const { ar & z1 & z2; }
void deserialize(auto &ar) { ar & z1 & z2; }
};

// check Serializable concept
static_assert(mpi::Serializable<serializable_cplx>);
static_assert(mpi::Serializable<serializable_container>);

TEST(MPI, SerializableMPIDatatypes) {
mpi::communicator world;
int rank = world.rank();
int root = 0;

// check broadcast
auto z_exp = serializable_cplx{.real = 1.0, .imag = 2.0};
auto z = (rank == root ? z_exp : serializable_cplx{});
mpi::broadcast(z, world, root);
EXPECT_EQ(z, z_exp);

// check all_reduce
auto z_red = mpi::all_reduce(z, world, mpi::map_add<serializable_cplx>());
EXPECT_DOUBLE_EQ(z_exp.real * world.size(), z_red.real);
EXPECT_DOUBLE_EQ(z.imag * world.size(), z_red.imag);
}

TEST(MPI, SerializableOfSerializableMPIDatatypes) {
mpi::communicator world;
int rank = world.rank();
int root = 0;

// check broadcast
auto c_exp = serializable_container{.z1 = {.real = 1.0, .imag = 2.0}, .z2 = {.real = 3.0, .imag = 4.0}};
auto c = (rank == root ? c_exp : serializable_container{});
mpi::broadcast(c, world, root);
EXPECT_EQ(c, c_exp);

// check all_reduce
auto c_red = mpi::all_reduce(c, world, mpi::map_add<serializable_container>());
EXPECT_DOUBLE_EQ(c_exp.z1.real * world.size(), c_red.z1.real);
EXPECT_DOUBLE_EQ(c_exp.z1.imag * world.size(), c_red.z1.imag);
EXPECT_DOUBLE_EQ(c_exp.z2.real * world.size(), c_red.z2.real);
EXPECT_DOUBLE_EQ(c_exp.z2.imag * world.size(), c_red.z2.imag);
}

MPI_TEST_MAIN;

0 comments on commit 442dac8

Please sign in to comment.