Improved communication in eq distribution
Xewar313 committed Nov 4, 2024
1 parent 0a1a4dc commit 2beec18
Showing 2 changed files with 39 additions and 17 deletions.
4 changes: 2 additions & 2 deletions include/dr/detail/communicator.hpp
@@ -106,9 +106,9 @@ class communicator {
i_all_gather(&src, rng::data(dst), 1, req);
}

- void gatherv(const void *src, int *counts, int *offsets, void *dst,
+ void gatherv(const void *src, long long *counts, long *offsets, void *dst,
std::size_t root) const {
- MPI_Gatherv(src, counts[rank()], MPI_BYTE, dst, counts, offsets, MPI_BYTE,
+ MPI_Gatherv_c(src, counts[rank()], MPI_BYTE, dst, counts, offsets, MPI_BYTE,
root, mpi_comm_);
}
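MPI_Gatherv_c is the MPI 4.0 large-count variant of MPI_Gatherv: its receive counts are MPI_Count and its displacements MPI_Aint, which is why the wrapper's parameters change from int to long long counts and long offsets (assuming the usual mappings of those MPI typedefs). Below is a minimal, standalone sketch of a variable-size gather through this call, assuming an MPI 4.0 implementation; the element type and per-rank sizes are illustrative only.

#include <cstddef>
#include <mpi.h>
#include <numeric>
#include <vector>

// Sketch only: each rank sends rank + 1 doubles to rank 0 via the
// large-count gather that the wrapper above now delegates to.
int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  std::vector<double> local(rank + 1, static_cast<double>(rank));

  // Counts (MPI_Count) and displacements (MPI_Aint), here in units of
  // MPI_DOUBLE; the wrapper above passes MPI_BYTE, so its counts and
  // offsets are byte sizes instead.
  std::vector<MPI_Count> counts(size);
  std::vector<MPI_Aint> displs(size, 0);
  for (int i = 0; i < size; ++i)
    counts[i] = i + 1;
  for (int i = 1; i < size; ++i)
    displs[i] = displs[i - 1] + counts[i - 1];

  std::vector<double> gathered;
  if (rank == 0)
    gathered.resize(static_cast<std::size_t>(
        std::accumulate(counts.begin(), counts.end(), MPI_Count{0})));

  MPI_Gatherv_c(local.data(), counts[rank], MPI_DOUBLE, gathered.data(),
                counts.data(), displs.data(), MPI_DOUBLE, 0, MPI_COMM_WORLD);

  MPI_Finalize();
  return 0;
}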

52 changes: 37 additions & 15 deletions include/dr/mp/containers/matrix_formats/csr_eq_distribution.hpp
@@ -55,7 +55,7 @@ class csr_eq_distribution {
}
auto vals_len = shape_[1];
auto size = row_sizes_[rank];
- auto res_col_len = max_row_size_;
+ auto res_col_len = row_sizes_[default_comm().rank()];
if (dr::mp::use_sycl()) {
auto localVals = dr::__detail::direct_iterator(
dr::mp::local_segment(*vals_data_).begin());
@@ -65,7 +65,7 @@
auto real_segment_size =
std::min(nnz_ - rank * segment_size_, segment_size_);
auto local_data = rows_data_;
- auto division = std::max(real_segment_size / 100, max_row_size_ * 10);
+ auto division = std::max(real_segment_size / 100, row_sizes_[default_comm().rank()] * 10);
auto one_computation_size =
(real_segment_size + division - 1) / division;
auto row_size = row_size_;
@@ -141,12 +141,12 @@
auto local_gemv_and_collect(std::size_t root, C &res, T* vals, std::size_t vals_width) const {
assert(res.size() == shape_.first * vals_width);
__detail::allocator<T> alloc;
- auto res_alloc = alloc.allocate(max_row_size_ * vals_width);
+ auto res_alloc = alloc.allocate(row_sizes_[default_comm().rank()] * vals_width);
if (use_sycl()) {
- sycl_queue().fill(res_alloc, 0, max_row_size_ * vals_width).wait();
+ sycl_queue().fill(res_alloc, 0, row_sizes_[default_comm().rank()] * vals_width).wait();
}
else {
- std::fill(res_alloc, res_alloc + max_row_size_ * vals_width, 0);
+ std::fill(res_alloc, res_alloc + row_sizes_[default_comm().rank()] * vals_width, 0);
}

// auto begin = std::chrono::high_resolution_clock::now();
@@ -156,9 +156,14 @@
// auto size = std::min(segment_size_, shape_[0] - segment_size_ * default_comm().rank());
// fmt::print("rows gemv time {} {} {}\n", duration * 1000, size, default_comm().rank());

+ // begin = std::chrono::high_resolution_clock::now();
gather_gemv_vector(root, res, res_alloc, vals_width);
+ // end = std::chrono::high_resolution_clock::now();
+ // duration = std::chrono::duration<double>(end - begin).count();
+ // size = std::min(segment_size_, shape_[0] - segment_size_ * default_comm().rank());
+ // fmt::print("rows gather time {} {} {}\n", duration * 1000, size, default_comm().rank());
fence();
- alloc.deallocate(res_alloc, max_row_size_ * vals_width);
+ alloc.deallocate(res_alloc, row_sizes_[default_comm().rank()] * vals_width);
}

private:
@@ -168,14 +173,25 @@
void gather_gemv_vector(std::size_t root, C &res, A &partial_res, std::size_t vals_width) const {
auto communicator = default_comm();
__detail::allocator<T> alloc;
+ long long* counts = new long long[communicator.size()];
+ for (auto i = 0; i < communicator.size(); i++) {
+   counts[i] = row_sizes_[i] * sizeof(T) * vals_width;
+ }
+
if (communicator.rank() == root) {
- auto gathered_res = alloc.allocate(max_row_size_ * communicator.size() * vals_width);
- communicator.gather(partial_res, gathered_res, max_row_size_ * vals_width, root);
+ long* offsets = new long[communicator.size()];
+ offsets[0] = 0;
+ for (auto i = 0; i < communicator.size() - 1; i++) {
+   offsets[i + 1] = offsets[i] + counts[i];
+ }
+ auto gathered_res = alloc.allocate(max_row_size_ * vals_width);
+ communicator.gatherv(partial_res, counts, offsets, gathered_res, root);
+ // communicator.gather(partial_res, gathered_res, max_row_size_ * vals_width, root);
T* gathered_res_host;

if (use_sycl()) {
- gathered_res_host = new T[max_row_size_ * communicator.size() * vals_width];
- __detail::sycl_copy(gathered_res, gathered_res_host, max_row_size_ * communicator.size() * vals_width);
+ gathered_res_host = new T[max_row_size_ * vals_width];
+ __detail::sycl_copy(gathered_res, gathered_res_host, max_row_size_ * vals_width);
}
else {
gathered_res_host = gathered_res;
@@ -185,12 +201,15 @@

// auto begin = std::chrono::high_resolution_clock::now();
for (auto k = 0; k < vals_width; k++) {
+ auto current_offset = 0;
for (auto i = 0; i < communicator.size(); i++) {
auto first_row = row_offsets_[i];
auto last_row = row_offsets_[i] + row_sizes_[i];
auto row_size = row_sizes_[i];
for (auto j = first_row; j < last_row; j++) {
- res[j + k * shape_[1]] += gathered_res_host[vals_width * max_row_size_ * i + k * max_row_size_ + j - first_row];
+ res[j + k * shape_[1]] += gathered_res_host[vals_width * current_offset + k * row_size + j - first_row];
}
+ current_offset += row_sizes_[i];
}
}

@@ -200,11 +219,14 @@
if (use_sycl()) {
delete[] gathered_res_host;
}
- alloc.deallocate(gathered_res, max_row_size_ * communicator.size() * vals_width);
+ delete[] offsets;
+ alloc.deallocate(gathered_res, max_row_size_ * vals_width);
} else {
- communicator.gather(partial_res, static_cast<T *>(nullptr), max_row_size_ * vals_width,
-                     root);
+ // communicator.gather(partial_res, static_cast<T *>(nullptr), max_row_size_ * vals_width,
+ //                     root);
+ communicator.gatherv(partial_res, counts, nullptr, nullptr, root);
}
+ delete[] counts;
}

std::size_t get_row_size(std::size_t rank) { return row_sizes_[rank]; }
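For reference, here is a standalone sketch (not the library's code) of the buffer layout the unpacking loop above assumes after the switch to gatherv: rank i's partial result occupies row_sizes[i] * vals_width values placed at the exclusive prefix sum of the earlier ranks' row counts, and within that block the value for output column k and local row r sits at k * row_sizes[i] + r. The names row_sizes, row_offsets and total_rows mirror the class members but are plain parameters here, and the column stride of the result is taken to be the total row count (the assumed intent of the indexing).

#include <cstddef>
#include <vector>

// Sketch only: rebuild the full result from per-rank blocks that were
// gathered back-to-back with no padding between ranks.
std::vector<double> unpack(const std::vector<double> &gathered,
                           const std::vector<std::size_t> &row_sizes,
                           const std::vector<std::size_t> &row_offsets,
                           std::size_t total_rows, std::size_t vals_width) {
  std::vector<double> res(total_rows * vals_width, 0.0);
  for (std::size_t k = 0; k < vals_width; ++k) {
    std::size_t current_offset = 0; // rows of gathered already consumed
    for (std::size_t i = 0; i < row_sizes.size(); ++i) {
      for (std::size_t r = 0; r < row_sizes[i]; ++r)
        res[row_offsets[i] + r + k * total_rows] +=
            gathered[vals_width * current_offset + k * row_sizes[i] + r];
      current_offset += row_sizes[i];
    }
  }
  return res;
}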
@@ -260,7 +282,7 @@
row_sizes_.push_back(higher_limit - lower_limit);
row_information[i] = lower_limit;
row_information[default_comm().size() + i] = higher_limit - lower_limit;
- max_row_size_ = std::max(max_row_size_, row_sizes_.back());
+ max_row_size_ = max_row_size_ + row_sizes_.back();
}
row_information[default_comm().size() * 2] = max_row_size_;
default_comm().bcast(row_information, sizeof(std::size_t) * row_info_size,
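A note on this last hunk: after the change max_row_size_ is no longer the largest per-rank row count but a running total of all of them, so the max_row_size_ * vals_width allocations above size the gathered buffer for every rank's partial result at once. A tiny illustration of the assumed semantics:

#include <cstddef>
#include <numeric>
#include <vector>

// Illustration only: for row_sizes = {3, 4, 2} the old code kept 4 (the max),
// while the new code accumulates 9 (the total), matching a gathered buffer
// that concatenates every rank's rows.
std::size_t total_rows(const std::vector<std::size_t> &row_sizes) {
  return std::accumulate(row_sizes.begin(), row_sizes.end(), std::size_t{0});
}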
