From 879dc69ca6c261d5f04428bd5c94e30ef55ff0f5 Mon Sep 17 00:00:00 2001 From: Henri Menke Date: Mon, 9 Oct 2023 11:59:58 +0200 Subject: [PATCH] Cover some more MPI_Win_* API surface --- c++/mpi/mpi.hpp | 104 +++++++++++++++++++++++++++++++--------- test/c++/mpi_window.cpp | 69 +++++++++++++++++++++++++- 2 files changed, 149 insertions(+), 24 deletions(-) diff --git a/c++/mpi/mpi.hpp b/c++/mpi/mpi.hpp index a121fdc2..49247f05 100644 --- a/c++/mpi/mpi.hpp +++ b/c++/mpi/mpi.hpp @@ -368,63 +368,120 @@ namespace mpi { window& operator=(window const&) = delete; window& operator=(window &&) = delete; - explicit window(communicator &c, BaseType *base, MPI_Aint size = 0) { + /// Create a window over an existing local memory buffer + explicit window(communicator &c, BaseType *base, MPI_Aint size = 0) noexcept { MPI_Win_create(base, size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &win); } + /// Create a window and allocate memory for a local memory buffer + explicit window(communicator &c, MPI_Aint size = 0) noexcept { + void *baseptr = nullptr; + MPI_Win_allocate(size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &baseptr, &win); + } + ~window() { if (win != MPI_WIN_NULL) { MPI_Win_free(&win); } } - operator MPI_Win() const { return win; }; - operator MPI_Win*() { return &win; }; + explicit operator MPI_Win() const noexcept { return win; }; + explicit operator MPI_Win*() noexcept { return &win; }; - void fence(int assert = 0) const { + /// Synchronization routine in active target RMA. It opens and closes an access epoch. + void fence(int assert = 0) const noexcept { MPI_Win_fence(assert, win); } + /// Complete all outstanding RMA operations at both the origin and the target + void flush(int rank = -1) const noexcept { + if (rank < 0) { + MPI_Win_flush_all(win); + } else { + MPI_Win_flush(rank, win); + } + } + + /// Synchronize the private and public copies of the window + void sync() const noexcept { + MPI_Win_sync(win); + } + + /// Starts an RMA access epoch locking access to a particular or all ranks in the window + void lock(int rank = -1, int lock_type = MPI_LOCK_SHARED, int assert = 0) const noexcept { + if (rank < 0) { + MPI_Win_lock_all(assert, win); + } else { + MPI_Win_lock(lock_type, rank, assert, win); + } + } + + /// Completes an RMA access epoch started by a call to lock() + void unlock(int rank = -1) const noexcept { + if (rank < 0) { + MPI_Win_unlock_all(win); + } else { + MPI_Win_unlock(rank, win); + } + } + + /// Load data from a remote memory window. template std::enable_if_t && has_mpi_type, void> - get(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const { - MPI_Datatype origin_datatype = mpi_type::get(); - MPI_Datatype target_datatype = mpi_type::get(); - int target_count_ = target_count < 0 ? origin_count : target_count; - MPI_Get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win); + get(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const noexcept { + MPI_Datatype origin_datatype = mpi_type::get(); + MPI_Datatype target_datatype = mpi_type::get(); + int target_count_ = target_count < 0 ? origin_count : target_count; + MPI_Get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win); }; + /// Store data to a remote memory window. template std::enable_if_t && has_mpi_type, void> - put(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const { - MPI_Datatype origin_datatype = mpi_type::get(); - MPI_Datatype target_datatype = mpi_type::get(); - int target_count_ = target_count < 0 ? origin_count : target_count; - MPI_Put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win); + put(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const noexcept { + MPI_Datatype origin_datatype = mpi_type::get(); + MPI_Datatype target_datatype = mpi_type::get(); + int target_count_ = target_count < 0 ? origin_count : target_count; + MPI_Put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win); }; - void* get_attr(int win_keyval) const { + /// Accumulate data into target process through remote memory access. + template + std::enable_if_t && has_mpi_type, void> + accumulate(OriginType const *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1, MPI_Op op = MPI_SUM) const noexcept { + MPI_Datatype origin_datatype = mpi_type::get(); + MPI_Datatype target_datatype = mpi_type::get(); + int target_count_ = target_count < 0 ? origin_count : target_count; + MPI_Accumulate(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, op, win); + } + + /// Obtains the value of a window attribute. + void* get_attr(int win_keyval) const noexcept { int flag; void *attribute_val; MPI_Win_get_attr(win, win_keyval, &attribute_val, &flag); assert(flag); return attribute_val; } - BaseType* base() const { return static_cast(get_attr(MPI_WIN_BASE)); } - MPI_Aint size() const { return *static_cast(get_attr(MPI_WIN_SIZE)); } - int disp_unit() const { return *static_cast(get_attr(MPI_WIN_DISP_UNIT)); } + + // Expose some commonly used attributes + BaseType* base() const noexcept { return static_cast(get_attr(MPI_WIN_BASE)); } + MPI_Aint size() const noexcept { return *static_cast(get_attr(MPI_WIN_SIZE)); } + int disp_unit() const noexcept { return *static_cast(get_attr(MPI_WIN_DISP_UNIT)); } }; /// The shared_window class template class shared_window : public window { public: - shared_window(shared_communicator& c, MPI_Aint size) { + /// Create a window and allocate memory for a shared memory buffer + shared_window(shared_communicator& c, MPI_Aint size) noexcept { void* baseptr = nullptr; MPI_Win_allocate_shared(size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &baseptr, &(this->win)); } - std::tuple query(int rank = MPI_PROC_NULL) const { + /// Query a shared memory window + std::tuple query(int rank = MPI_PROC_NULL) const noexcept { MPI_Aint size = 0; int disp_unit = 0; void *baseptr = nullptr; @@ -432,9 +489,10 @@ namespace mpi { return {size, disp_unit, baseptr}; } - MPI_Aint size(int rank = MPI_PROC_NULL) const { return std::get<0>(query(rank)) / sizeof(BaseType); } - int disp_unit(int rank = MPI_PROC_NULL) const { return std::get<1>(query(rank)); } - BaseType* base(int rank = MPI_PROC_NULL) const { return static_cast(std::get<2>(query(rank))); } + // Override the commonly used attributes of the window base class + BaseType* base(int rank = MPI_PROC_NULL) const noexcept { return static_cast(std::get<2>(query(rank))); } + MPI_Aint size(int rank = MPI_PROC_NULL) const noexcept { return std::get<0>(query(rank)) / sizeof(BaseType); } + int disp_unit(int rank = MPI_PROC_NULL) const noexcept { return std::get<1>(query(rank)); } }; /* ----------------------------------------------------------- diff --git a/test/c++/mpi_window.cpp b/test/c++/mpi_window.cpp index 7aacd3b7..e9b7902f 100644 --- a/test/c++/mpi_window.cpp +++ b/test/c++/mpi_window.cpp @@ -19,6 +19,13 @@ #include #include +// Test cases are adapted from slides and exercises of the HLRS course: +// Introduction to the Message Passing Interface (MPI) +// Authors: Joel Malard, Alan Simpson, (EPCC) +// Rolf Rabenseifner, Traugott Streicher, Tobias Haas (HLRS) +// https://fs.hlrs.de/projects/par/par_prog_ws/pdf/mpi_3.1_rab.pdf +// https://fs.hlrs.de/projects/par/par_prog_ws/practical/MPI31single.tar.gz + TEST(MPI_Window, SharedCommunicator) { mpi::communicator world; [[maybe_unused]] auto shm = world.split_shared(); @@ -68,7 +75,7 @@ TEST(MPI_Window, RingOneSidedPut) { EXPECT_EQ(sum, (size * (size - 1)) / 2); } -TEST(MPI_Window, RingOneSidedAllowShared) { +TEST(MPI_Window, RingOneSidedAllocShared) { mpi::communicator world; auto shm = world.split_shared(); int const rank_shm = shm.rank(); @@ -91,6 +98,66 @@ TEST(MPI_Window, RingOneSidedAllowShared) { EXPECT_EQ(sum, (size_shm * (size_shm - 1)) / 2); } +TEST(MPI_Window, RingOneSidedStoreWinAllocSharedSignal) { + mpi::communicator world; + auto shm = world.split_shared(); + + int const rank_shm = shm.rank(); + int const size_shm = shm.size(); + int const right = (rank_shm+1) % size_shm; + int const left = (rank_shm-1+size_shm) % size_shm; + + mpi::shared_window win{shm, 1}; + int *rcv_buf_ptr = win.base(rank_shm); + win.lock(); + + int sum = 0; + int snd_buf = rank_shm; + + MPI_Request rq; + MPI_Status status; + int snd_dummy, rcv_dummy; + + for(int i = 0; i < size_shm; ++i) { + // ... The local Win_syncs are needed to sync the processor and real memory. + // ... The following pair of syncs is needed that the read-write-rule is fulfilled. + win.sync(); + + // ... tag=17: posting to left that rcv_buf is exposed to left, i.e., + // the left process is now allowed to store data into the local rcv_buf + MPI_Irecv(&rcv_dummy, 0, MPI_INT, right, 17, shm.get(), &rq); + MPI_Send (&snd_dummy, 0, MPI_INT, left, 17, shm.get()); + MPI_Wait(&rq, &status); + + win.sync(); + + // MPI_Put(&snd_buf, 1, MPI_INT, right, (MPI_Aint) 0, 1, MPI_INT, win); + // ... is substited by (with offset "right-my_rank" to store into right neigbor's rcv_buf): + *(rcv_buf_ptr+(right-rank_shm)) = snd_buf; + + + // ... The following pair of syncs is needed that the write-read-rule is fulfilled. + win.sync(); + + // ... The following communication synchronizes the processors in the way + // that the origin processor has finished the store + // before the target processor starts to load the data. + // ... tag=18: posting to right that rcv_buf was stored from left + MPI_Irecv(&rcv_dummy, 0, MPI_INT, left, 18, shm.get(), &rq); + MPI_Send (&snd_dummy, 0, MPI_INT, right, 18, shm.get()); + MPI_Wait(&rq, &status); + + win.sync(); + + snd_buf = *rcv_buf_ptr; + sum += *rcv_buf_ptr; + } + + EXPECT_EQ(sum, (size_shm * (size_shm - 1)) / 2); + + win.unlock(); +} + TEST(MPI_Window, SharedArray) { mpi::communicator world; auto shm = world.split_shared();