Skip to content

Commit

Permalink
Cover some more MPI_Win_* API surface
Browse files Browse the repository at this point in the history
  • Loading branch information
hmenke committed Oct 9, 2023
1 parent 1738e29 commit 879dc69
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 24 deletions.
104 changes: 81 additions & 23 deletions c++/mpi/mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,73 +368,131 @@ namespace mpi {
window& operator=(window const&) = delete;
window& operator=(window &&) = delete;

explicit window(communicator &c, BaseType *base, MPI_Aint size = 0) {
/// Create a window over an existing local memory buffer.
/// \param c communicator whose ranks collectively create the window
/// \param base pointer to the locally exposed memory region
/// \param size number of BaseType elements exposed (passed to MPI as size * sizeof(BaseType) bytes)
/// NOTE(review): the displacement unit is alignof(BaseType), not sizeof(BaseType) —
/// confirm this is intended, since target_disp offsets in get/put scale by it.
explicit window(communicator &c, BaseType *base, MPI_Aint size = 0) noexcept {
  MPI_Win_create(base, size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &win);
}

/// Create a window, letting MPI allocate the local memory buffer.
/// The allocated base address is not stored here; it can be retrieved
/// afterwards through get_attr(MPI_WIN_BASE) / base().
explicit window(communicator &c, MPI_Aint size = 0) noexcept {
  void *allocated_base = nullptr;
  MPI_Win_allocate(size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &allocated_base, &win);
}

/// Collectively free the window unless the handle is already null.
~window() {
  if (win != MPI_WIN_NULL) MPI_Win_free(&win);
}

operator MPI_Win() const { return win; };
operator MPI_Win*() { return &win; };
/// Explicit access to the underlying MPI window handle.
explicit operator MPI_Win() const noexcept { return win; };
/// Explicit access to the address of the underlying handle.
explicit operator MPI_Win*() noexcept { return &win; };

void fence(int assert = 0) const {
/// Active-target RMA synchronization: closes the current access epoch
/// and opens the next one. `assert` is an MPI optimization hint (0 = none).
void fence(int assert = 0) const noexcept { MPI_Win_fence(assert, win); }

/// Complete all outstanding RMA operations at both the origin and the target.
/// A non-negative rank flushes only that rank; a negative rank flushes all.
void flush(int rank = -1) const noexcept {
  if (rank >= 0) {
    MPI_Win_flush(rank, win);
  } else {
    MPI_Win_flush_all(win);
  }
}

/// Synchronize the private and public copies of the window memory.
void sync() const noexcept { MPI_Win_sync(win); }

/// Start an RMA access epoch. A non-negative rank locks only that rank
/// (with the given lock_type); a negative rank locks all ranks in the window.
/// `assert` is an MPI optimization hint (0 = none).
void lock(int rank = -1, int lock_type = MPI_LOCK_SHARED, int assert = 0) const noexcept {
  if (rank >= 0) {
    MPI_Win_lock(lock_type, rank, assert, win);
  } else {
    MPI_Win_lock_all(assert, win);
  }
}

/// End an RMA access epoch opened by lock(). A non-negative rank unlocks
/// only that rank; a negative rank ends a lock-all epoch.
void unlock(int rank = -1) const noexcept {
  if (rank >= 0) {
    MPI_Win_unlock(rank, win);
  } else {
    MPI_Win_unlock_all(win);
  }
}

/// Load data from a remote memory window.
template <typename TargetType = BaseType, typename OriginType>
std::enable_if_t<has_mpi_type<OriginType> && has_mpi_type<TargetType>, void>
get(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const {
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
int target_count_ = target_count < 0 ? origin_count : target_count;
MPI_Get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
get(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const noexcept {
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
int target_count_ = target_count < 0 ? origin_count : target_count;
MPI_Get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
};

/// Store data to a remote memory window.
template <typename TargetType = BaseType, typename OriginType>
std::enable_if_t<has_mpi_type<OriginType> && has_mpi_type<TargetType>, void>
put(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const {
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
int target_count_ = target_count < 0 ? origin_count : target_count;
MPI_Put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
put(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const noexcept {
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
int target_count_ = target_count < 0 ? origin_count : target_count;
MPI_Put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
};

void* get_attr(int win_keyval) const {
/// Accumulate data into the target process through remote memory access,
/// combining it with the data present at the target using `op` (default MPI_SUM).
/// When target_count is negative, origin_count is used for the target as well.
template <typename TargetType = BaseType, typename OriginType>
std::enable_if_t<has_mpi_type<OriginType> && has_mpi_type<TargetType>, void>
accumulate(OriginType const *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1, MPI_Op op = MPI_SUM) const noexcept {
  MPI_Datatype dt_origin = mpi_type<OriginType>::get();
  MPI_Datatype dt_target = mpi_type<TargetType>::get();
  int n_target = (target_count < 0) ? origin_count : target_count;
  MPI_Accumulate(origin_addr, origin_count, dt_origin, target_rank, target_disp, n_target, dt_target, op, win);
}

/// Obtain the value of a window attribute (e.g. MPI_WIN_BASE).
/// Returns nullptr if the attribute is not set on the window.
void* get_attr(int win_keyval) const noexcept {
  // Zero-initialize: if MPI_Win_get_attr does not set the attribute and
  // NDEBUG disables the assert, we must not return an indeterminate pointer.
  int flag = 0;
  void *attribute_val = nullptr;
  MPI_Win_get_attr(win, win_keyval, &attribute_val, &flag);
  assert(flag); // the predefined keyvals queried below are always set
  return attribute_val;
}
BaseType* base() const { return static_cast<BaseType*>(get_attr(MPI_WIN_BASE)); }
MPI_Aint size() const { return *static_cast<MPI_Aint*>(get_attr(MPI_WIN_SIZE)); }
int disp_unit() const { return *static_cast<int*>(get_attr(MPI_WIN_DISP_UNIT)); }

// Expose some commonly used attributes
/// Pointer to the start of the locally exposed window memory.
BaseType* base() const noexcept { return static_cast<BaseType*>(get_attr(MPI_WIN_BASE)); }
/// Size of the local window in bytes (as created: size * sizeof(BaseType)).
MPI_Aint size() const noexcept { return *static_cast<MPI_Aint*>(get_attr(MPI_WIN_SIZE)); }
/// Displacement unit of the window in bytes.
int disp_unit() const noexcept { return *static_cast<int*>(get_attr(MPI_WIN_DISP_UNIT)); }
};

/// The shared_window class
template <class BaseType>
class shared_window : public window<BaseType> {
public:
shared_window(shared_communicator& c, MPI_Aint size) {
/// Create a window backed by a freshly allocated shared memory buffer
/// spanning the ranks of the shared communicator.
shared_window(shared_communicator& c, MPI_Aint size) noexcept {
  void *shm_base = nullptr;
  MPI_Win_allocate_shared(size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &shm_base, &(this->win));
}

std::tuple<MPI_Aint, int, void*> query(int rank = MPI_PROC_NULL) const {
/// Query size, displacement unit and base pointer of the shared memory
/// segment attached to `rank` (MPI_PROC_NULL queries this process' view).
std::tuple<MPI_Aint, int, void*> query(int rank = MPI_PROC_NULL) const noexcept {
  MPI_Aint n_bytes = 0;
  int unit = 0;
  void *ptr = nullptr;
  MPI_Win_shared_query(this->win, rank, &n_bytes, &unit, &ptr);
  return {n_bytes, unit, ptr};
}

MPI_Aint size(int rank = MPI_PROC_NULL) const { return std::get<0>(query(rank)) / sizeof(BaseType); }
int disp_unit(int rank = MPI_PROC_NULL) const { return std::get<1>(query(rank)); }
BaseType* base(int rank = MPI_PROC_NULL) const { return static_cast<BaseType*>(std::get<2>(query(rank))); }
// Override the commonly used attributes of the window base class
/// Base pointer of the shared memory segment of the given rank.
BaseType* base(int rank = MPI_PROC_NULL) const noexcept { return static_cast<BaseType*>(std::get<2>(query(rank))); }
/// Size of the segment in number of BaseType elements (bytes / sizeof).
/// NOTE(review): this differs from the base-class size(), which returns bytes — confirm intended.
MPI_Aint size(int rank = MPI_PROC_NULL) const noexcept { return std::get<0>(query(rank)) / sizeof(BaseType); }
/// Displacement unit of the segment in bytes.
int disp_unit(int rank = MPI_PROC_NULL) const noexcept { return std::get<1>(query(rank)); }
};

/* -----------------------------------------------------------
Expand Down
69 changes: 68 additions & 1 deletion test/c++/mpi_window.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
#include <gtest/gtest.h>
#include <numeric>

// Test cases are adapted from slides and exercises of the HLRS course:
// Introduction to the Message Passing Interface (MPI)
// Authors: Joel Malard, Alan Simpson, (EPCC)
// Rolf Rabenseifner, Traugott Streicher, Tobias Haas (HLRS)
// https://fs.hlrs.de/projects/par/par_prog_ws/pdf/mpi_3.1_rab.pdf
// https://fs.hlrs.de/projects/par/par_prog_ws/practical/MPI31single.tar.gz

TEST(MPI_Window, SharedCommunicator) {
mpi::communicator world;
[[maybe_unused]] auto shm = world.split_shared();
Expand Down Expand Up @@ -68,7 +75,7 @@ TEST(MPI_Window, RingOneSidedPut) {
EXPECT_EQ(sum, (size * (size - 1)) / 2);
}

TEST(MPI_Window, RingOneSidedAllowShared) {
TEST(MPI_Window, RingOneSidedAllocShared) {
mpi::communicator world;
auto shm = world.split_shared();
int const rank_shm = shm.rank();
Expand All @@ -91,6 +98,66 @@ TEST(MPI_Window, RingOneSidedAllowShared) {
EXPECT_EQ(sum, (size_shm * (size_shm - 1)) / 2);
}

// Ring communication over a shared-memory window: each rank repeatedly stores
// its value directly into the right neighbor's buffer (plain store instead of
// MPI_Put) and signals readiness with zero-byte messages. After size_shm
// rounds every rank has accumulated the sum 0 + 1 + ... + (size_shm - 1).
TEST(MPI_Window, RingOneSidedStoreWinAllocSharedSignal) {
  mpi::communicator world;
  auto shm = world.split_shared();

  int const rank_shm = shm.rank();
  int const size_shm = shm.size();
  int const right = (rank_shm+1) % size_shm;
  int const left = (rank_shm-1+size_shm) % size_shm;

  // One-element shared window; every rank can address its neighbors' buffers.
  mpi::shared_window<int> win{shm, 1};
  int *rcv_buf_ptr = win.base(rank_shm);
  win.lock();

  int sum = 0;
  int snd_buf = rank_shm;

  MPI_Request rq;
  MPI_Status status;
  int snd_dummy, rcv_dummy;

  for(int i = 0; i < size_shm; ++i) {
    // ... The local Win_syncs are needed to sync the processor and real memory.
    // ... The following pair of syncs is needed so that the read-write-rule is fulfilled.
    win.sync();

    // ... tag=17: posting to left that rcv_buf is exposed to left, i.e.,
    //     the left process is now allowed to store data into the local rcv_buf
    MPI_Irecv(&rcv_dummy, 0, MPI_INT, right, 17, shm.get(), &rq);
    MPI_Send (&snd_dummy, 0, MPI_INT, left, 17, shm.get());
    MPI_Wait(&rq, &status);

    win.sync();

    // MPI_Put(&snd_buf, 1, MPI_INT, right, (MPI_Aint) 0, 1, MPI_INT, win);
    // ... is substituted by a direct store (with offset "right-my_rank" to
    //     store into the right neighbor's rcv_buf):
    *(rcv_buf_ptr+(right-rank_shm)) = snd_buf;


    // ... The following pair of syncs is needed so that the write-read-rule is fulfilled.
    win.sync();

    // ... The following communication synchronizes the processors in the way
    //     that the origin processor has finished the store
    //     before the target processor starts to load the data.
    // ... tag=18: posting to right that rcv_buf was stored from left
    MPI_Irecv(&rcv_dummy, 0, MPI_INT, left, 18, shm.get(), &rq);
    MPI_Send (&snd_dummy, 0, MPI_INT, right, 18, shm.get());
    MPI_Wait(&rq, &status);

    win.sync();

    // Pass the received value along the ring and accumulate it.
    snd_buf = *rcv_buf_ptr;
    sum += *rcv_buf_ptr;
  }

  EXPECT_EQ(sum, (size_shm * (size_shm - 1)) / 2);

  win.unlock();
}

TEST(MPI_Window, SharedArray) {
mpi::communicator world;
auto shm = world.split_shared();
Expand Down

0 comments on commit 879dc69

Please sign in to comment.