Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync report names across ranks #25

Merged
merged 4 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 85 additions & 66 deletions src/library/implementation_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ struct Implementation {
// Forwards the request to register a communicator named `comm_name` to the
// implementation selected through the TImpl template parameter.
static void add_communicator(const std::string& comm_name) {
    TImpl::add_communicator(comm_name);
}
// Forwards `local_reports` to TImpl::sync_reports and returns whatever list of
// report names that implementation produces.
static std::vector<std::string> sync_reports(const std::vector<std::string>& local_reports) {
return TImpl::sync_reports(local_reports);
}
static std::vector<std::string> sync_populations(
const std::string& comm_name, const std::vector<std::string>& local_populations) {
return TImpl::sync_populations(comm_name, local_populations);
Expand Down Expand Up @@ -108,6 +111,72 @@ static std::string add_extension(const std::string& report_name) {

#ifdef SONATA_REPORT_HAVE_MPI

// Packs a list of strings into one contiguous char buffer in which every
// entry is followed by its terminating '\0'. An empty list yields an empty
// buffer; an empty string contributes a single '\0'.
static std::vector<char> serialize(const std::vector<std::string>& strings) {
    std::vector<char> packed;
    for (const std::string& entry : strings) {
        const std::size_t start = packed.size();
        // Make room for the characters plus the trailing null-char.
        packed.resize(start + entry.size() + 1);
        strcpy(packed.data() + start, entry.c_str());
    }
    return packed;
}

// Splits a buffer of '\0'-separated strings (the layout produced by
// serialize) back into a vector of strings, preserving order and empty
// entries.
//
// The scan for each terminator is bounded by the end of the buffer, so a
// trailing entry that lacks its '\0' is returned as-is instead of being read
// past the end of the vector.
static std::vector<std::string> deserialize(const std::vector<char>& strings) {
    std::vector<std::string> result;
    auto cursor = strings.begin();
    while (cursor != strings.end()) {
        const auto terminator = std::find(cursor, strings.end(), '\0');
        result.emplace_back(cursor, terminator);
        if (terminator == strings.end()) {
            break;  // Unterminated tail: nothing left to parse
        }
        cursor = terminator + 1;  // Skip the null-char
    }
    return result;
}

// Builds the union of the strings held by every rank of `comm` and returns it
// on every rank.
//
// The local strings are serialized, gathered on rank 0, de-duplicated there
// through a std::set (so the result is also sorted), and the merged list is
// broadcast back. The function issues MPI collectives (Gather, Gatherv,
// Bcast) on `comm`, so every rank of that communicator has to call it.
static std::vector<std::string> sync_strings(const MPI_Comm comm,
const std::vector<std::string>& strings) {
auto buffer = serialize(strings);
auto buffer_size = static_cast<int>(buffer.size());

int nranks, local_rank;
MPI_Comm_rank(comm, &local_rank);
MPI_Comm_size(comm, &nranks);
// Only the root (rank 0) needs per-rank sizes and displacements
std::vector<int> counts((local_rank == 0) ? nranks : 0);
std::vector<int> displs((local_rank == 0) ? nranks : 0);

// Root learns how many bytes each rank is going to send
MPI_Gather(&buffer_size, 1, MPI_INT, counts.data(), 1, MPI_INT, 0, comm);
if (local_rank == 0) {
// Inclusive prefix sum shifted right by one: displs[i] becomes the offset of
// rank i's chunk (the vector ends up with one extra trailing element)
std::partial_sum(counts.begin(), counts.end(), displs.begin());
displs.insert(displs.begin(), 0); // To begin with offset=0

// Grow the root buffer to hold every rank's bytes; the root's own data
// stays at the front, which matches displs[0] == 0
const auto buffer_sizes = std::accumulate(counts.begin(), counts.end(), 0);
buffer.resize(buffer_sizes);
}

// The root's contribution is already in the receive buffer, hence MPI_IN_PLACE
auto send_buffer_ptr = (local_rank == 0) ? MPI_IN_PLACE : buffer.data();
MPI_Gatherv(send_buffer_ptr,
buffer_size,
MPI_CHAR,
buffer.data(),
counts.data(),
displs.data(),
MPI_CHAR,
0,
comm);
if (local_rank == 0) {
const auto buffer_str = deserialize(buffer);
// Eliminate duplicated strings (std::set also orders them)
std::set<std::string> buffer_set(buffer_str.begin(), buffer_str.end());

buffer = serialize(std::vector<std::string>(buffer_set.begin(), buffer_set.end()));
buffer_size = static_cast<int>(buffer.size());
}

// Send the merged list to everyone: first its size, then its bytes
MPI_Bcast(&buffer_size, 1, MPI_INT, 0, comm);
buffer.resize(buffer_size);
MPI_Bcast(buffer.data(), buffer_size, MPI_CHAR, 0, comm);

// Return the vector of synced strings
return deserialize(buffer);
}

static MPI_Comm get_Comm(const std::string& comm_name) {
if (SonataReport::communicators_.find(comm_name) != SonataReport::communicators_.end()) {
// Found
Expand All @@ -129,86 +198,33 @@ struct ParallelImplementation {
int num_reports = report_names.size();
MPI_Comm_split(MPI_COMM_WORLD, num_reports == 0, 0, &SonataReport::has_nodes_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say it's worth having a ticket about not referring to MPI_COMM_WORLD anywhere in the parallel MPI library, i.e. when libsonatareport is initialised, the communicator should (almost always) be passed as an argument.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have created the issue as good for beginners :)
#27


std::vector<std::string> global_report_names = sync_reports(report_names);
for (const auto& report : global_report_names) {
bool has_report = std::find(report_names.begin(), report_names.end(), report) !=
report_names.end();
add_communicator(report, has_report);
}

return global_rank;
};

// Intentionally empty: this implementation performs no work on close.
static void close() {}

static void add_communicator(const std::string& comm_name) {
std::hash<std::string> hasher;
size_t comm_hash = hasher(comm_name);
static void add_communicator(const std::string& comm_name, bool has_report = true) {
if (SonataReport::communicators_.find(comm_name) == SonataReport::communicators_.end()) {
MPI_Comm_split(SonataReport::has_nodes_,
comm_hash,
has_report,
0,
&SonataReport::communicators_[comm_name]);
}
};
// Merges the report names known to each rank of SonataReport::has_nodes_ and
// returns the combined list produced by sync_strings.
static std::vector<std::string> sync_reports(const std::vector<std::string>& local_reports) {
    return sync_strings(SonataReport::has_nodes_, local_reports);
}

// Merges the population names known to each rank of the communicator
// registered under `comm_name` and returns the combined list.
//
// The gather / de-duplicate / broadcast logic lives in sync_strings, which is
// shared with sync_reports; this function only resolves the communicator.
static std::vector<std::string> sync_populations(
    const std::string& comm_name, const std::vector<std::string>& local_populations) {
    return sync_strings(get_Comm(comm_name), local_populations);
}

static hid_t prepare_write(const std::string& report_name) {
Expand Down Expand Up @@ -352,6 +368,9 @@ struct SerialImplementation {
};
// Intentionally empty: this implementation performs no work on close.
static void close() {}
// Intentionally empty: the communicator name is accepted and ignored.
static void add_communicator(const std::string& /*comm_name*/) {}
// Returns a copy of `local_reports` unchanged; no merging is performed here.
static std::vector<std::string> sync_reports(const std::vector<std::string>& local_reports) {
    std::vector<std::string> reports(local_reports);
    return reports;
}
static std::vector<std::string> sync_populations(
const std::string& /*comm_name*/, const std::vector<std::string>& local_populations) {
return local_populations;
Expand Down
1 change: 0 additions & 1 deletion src/library/report.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ std::shared_ptr<Node> Report::get_node(const std::string& population_name, uint6
}

int Report::prepare_dataset() {
Implementation::add_communicator(report_name_);
file_handler_ = Implementation::prepare_write(report_name_);

std::vector<std::string> local_populations;
Expand Down
1 change: 1 addition & 0 deletions src/library/sonatareport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ void SonataReport::prepare_datasets() {

void SonataReport::create_spikefile(const std::string& output_dir, const std::string& filename) {
std::string report_name = output_dir + "/" + filename;
Implementation::add_communicator(report_name);
spike_data_ = std::make_unique<SonataData>(report_name);
}

Expand Down