Skip to content

Commit

Permalink
Add prototype for shared allocator
Browse files Browse the repository at this point in the history
  • Loading branch information
hmenke committed Oct 9, 2023
1 parent 879dc69 commit 87ead7d
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 8 deletions.
54 changes: 54 additions & 0 deletions c++/mpi/allocator.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2023 Simons Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Authors: Philipp Dumitrescu, Olivier Parcollet, Nils Wentzell

#pragma once
#include "mpi.hpp"
#include <algorithm>
#include <memory>

namespace mpi {

template <class T>
class shared_allocator {
shared_communicator shm = communicator{}.split_shared();
std::shared_ptr<std::vector<shared_window<T>>> blocks = std::make_shared<std::vector<shared_window<T>>>();
public:
using value_type = T;

shared_allocator() = default;
explicit shared_allocator(shared_communicator const &shm) noexcept : shm{shm} {}

[[nodiscard]] shared_communicator get_communicator() { return shm; }

[[nodiscard]] auto get_window(T *p) {
return std::find_if(blocks->begin(), blocks->end(), [p](shared_window<T> const &win) {
return win.base(0) == p;
});
}

[[nodiscard]] T* allocate(std::size_t n) {
shared_window<T> &win = blocks->emplace_back(shm, shm.rank() == 0 ? n : 0);
return win.base(0);
}

void deallocate(T *p, std::size_t) {
blocks->erase(std::remove_if(blocks->begin(), blocks->end(), [p](shared_window<T> const &win) {
return win.base(0) == p;
}), blocks->end());
}
};

} // namespace mpi
25 changes: 18 additions & 7 deletions c++/mpi/mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <complex>
#include <type_traits>
#include <algorithm>
#include <utility>
#include <unistd.h>

/// Library namespace
Expand Down Expand Up @@ -364,9 +365,15 @@ namespace mpi {
public:
window() = default;
window(window const&) = delete;
window(window &&) = delete;
window(window &&other) noexcept : win{std::exchange(other.win, MPI_WIN_NULL)} {}
window& operator=(window const&) = delete;
window& operator=(window &&) = delete;
window& operator=(window &&rhs) noexcept {
if (this != std::addressof(rhs)) {
this->free();
this->win = std::exchange(rhs.win, MPI_WIN_NULL);
}
return *this;
}

/// Create a window over an existing local memory buffer
explicit window(communicator &c, BaseType *base, MPI_Aint size = 0) noexcept {
Expand All @@ -379,15 +386,17 @@ namespace mpi {
MPI_Win_allocate(size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &baseptr, &win);
}

~window() {
~window() { free(); }

explicit operator MPI_Win() const noexcept { return win; };
explicit operator MPI_Win*() noexcept { return &win; };

void free() noexcept {
if (win != MPI_WIN_NULL) {
MPI_Win_free(&win);
}
}

explicit operator MPI_Win() const noexcept { return win; };
explicit operator MPI_Win*() noexcept { return &win; };

/// Synchronization routine in active target RMA. It opens and closes an access epoch.
void fence(int assert = 0) const noexcept {
MPI_Win_fence(assert, win);
Expand Down Expand Up @@ -474,8 +483,10 @@ namespace mpi {
template <class BaseType>
class shared_window : public window<BaseType> {
public:
shared_window() = default;

/// Create a window and allocate memory for a shared memory buffer
shared_window(shared_communicator& c, MPI_Aint size) noexcept {
explicit shared_window(shared_communicator& c, MPI_Aint size) noexcept {
void* baseptr = nullptr;
MPI_Win_allocate_shared(size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &baseptr, &(this->win));
}
Expand Down
2 changes: 1 addition & 1 deletion test/c++/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ file(GLOB_RECURSE all_tests RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cpp)
# List of all no mpi tests
file(GLOB_RECURSE nompi_tests RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cpp)
# remove custom mpi test, as this one explicitly uses MPI
list(REMOVE_ITEM nompi_tests mpi_custom.cpp mpi_monitor.cpp mpi_window.cpp)
list(REMOVE_ITEM nompi_tests mpi_custom.cpp mpi_monitor.cpp mpi_window.cpp mpi_allocator.cpp)

# ========= OpenMP Dependency ==========

Expand Down
45 changes: 45 additions & 0 deletions test/c++/mpi_allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) 2023 Simons Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Authors: Philipp Dumitrescu, Olivier Parcollet, Nils Wentzell

#include <mpi/mpi.hpp>
#include <mpi/allocator.hpp>
#include <mpi/vector.hpp>
#include <gtest/gtest.h>
#include <numeric>

TEST(MPI_Allocator, SharedAllocator) {
mpi::shared_allocator<int> alloc;
int *p = alloc.allocate(1);
alloc.deallocate(p, 1);
}

TEST(MPI_Allocator, SharedAllocatorVector) {
std::vector<int, mpi::shared_allocator<int>> v(128);
auto shm = v.get_allocator().get_communicator();

// Fill the vector in parallel
v.get_allocator().get_window(v.data())->fence();
auto slice = itertools::chunk_range(0, v.size(), shm.size(), shm.rank());
for (auto i = slice.first; i < slice.second; ++i) {
v.at(i) = i;
}
v.get_allocator().get_window(v.data())->fence();

int const sum = std::accumulate(v.begin(), v.end(), int{0});
EXPECT_EQ(sum, v.size() * (v.size() - 1) / 2);
}

MPI_TEST_MAIN;

0 comments on commit 87ead7d

Please sign in to comment.