bypass pack/unpack for halo
rscohn2 committed Oct 9, 2023
1 parent 4dd4a55 · commit 89e8cc8
Showing 2 changed files with 33 additions and 23 deletions.
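
In outline: span_group previously always staged halo data through a separate buffer (pack before send, unpack after receive). After this commit, buffering defaults to off and is enabled only for SYCL shared-USM allocations; in the common case the communicator reads and writes the halo's own contiguous storage. The sketch below illustrates the pattern; the Group type and its members are illustrative stand-ins, not the repo's class:

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative model of the span_group change (not the repo's class).
// The halo region is contiguous, so a message can usually be sent from,
// and received into, `data` directly; the staging buffer is used only
// when `buffered` is set (e.g. for SYCL shared-USM allocations).
struct Group {
  double *data = nullptr; // contiguous halo region owned by the container
  std::size_t size = 0;   // element count
  bool buffered = false;  // opt-in staging, off by default after this commit

  std::vector<double> staging;

  Group(double *d, std::size_t n, bool b) : data(d), size(n), buffered(b) {
    if (buffered)
      staging.resize(n); // only allocate a staging buffer when needed
  }

  void pack() { // halo -> staging, only when required
    if (buffered)
      std::copy(data, data + size, staging.begin());
  }
  void unpack() { // staging -> halo, only when required
    if (buffered)
      std::copy(staging.begin(), staging.end(), data);
  }
  // Pointer handed to the transport: staging if buffered, otherwise the
  // halo region itself -- the "bypass" this commit introduces.
  double *data_pointer() { return buffered ? staging.data() : data; }
  std::size_t data_size() const { return size; }
};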
benchmarks/gbench/mhp/mhp-bench.cpp — 3 additions & 1 deletion

@@ -33,7 +33,8 @@ void dr_init() {
   if (options.count("sycl")) {
     sycl::queue q = dr::mhp::select_queue();
     benchmark::AddCustomContext("device_info", device_info(q.get_device()));
-    dr::mhp::init(q);
+    dr::mhp::init(q, options.count("device-memory") ? sycl::usm::alloc::device
+                                                    : sycl::usm::alloc::shared);
     return;
   }
 #endif
@@ -81,6 +82,7 @@ int main(int argc, char *argv[]) {
     ("stencil-steps", "Default steps for stencil", cxxopts::value<std::size_t>()->default_value("10"))
     ("vector-size", "Default vector size", cxxopts::value<std::size_t>()->default_value("100000000"))
     ("context", "Additional google benchmark context", cxxopts::value<std::vector<std::string>>())
+    ("device-memory", "Use device memory")
     ("weak-scaling", "Scale the vector size by the number of ranks", cxxopts::value<bool>()->default_value("false"))
     ;
   // clang-format on
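
The new device-memory switch makes the SYCL benchmarks allocate distributed data in device USM instead of the default shared USM. An invocation might look like the following (the MPI launcher and binary path are illustrative, not taken from the repo's docs):

  mpirun -n 2 ./mhp-bench --sycl --device-memory --vector-size 100000000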
include/dr/mhp/halo.hpp — 30 additions & 22 deletions

@@ -296,45 +296,53 @@ template <typename T, typename Memory = default_memory<T>> class span_group {
   T *buffer = nullptr;
   std::size_t request_index = 0;
   bool receive = false;
-  bool buffered = true;
+  bool buffered = false;

   span_group(T *data, std::size_t size, std::size_t rank, communicator::tag tag,
              const Memory &memory)
-      : data_(data, size), rank_(rank), tag_(tag), memory_(memory) {}
+      : data_(data, size), rank_(rank), tag_(tag), memory_(memory) {
+#ifdef SYCL_LANGUAGE_VERSION
+    if (use_sycl() && sycl_mem_kind() == sycl::usm::alloc::shared) {
+      buffered = true;
+    }
+#endif
+  }

   span_group(std::span<T> data, std::size_t rank, communicator::tag tag)
       : data_(data), rank_(rank), tag_(tag) {}

-  void unpack(const auto &op) {
-    if (mhp::use_sycl()) {
-      __detail::sycl_copy(buffer, buffer + rng::size(data_), data_.data());
-    } else {
-      for (std::size_t i = 0; i < rng::size(data_); i++) {
-        data_[i] = op(data_[i], buffer[i]);
-      }
-    }
-  }
-
   void unpack() {
-    if (mhp::use_sycl()) {
-      __detail::sycl_copy(buffer, buffer + rng::size(data_), data_.data());
-    } else {
-      std::copy(buffer, buffer + rng::size(data_), data_.data());
+    if (buffered) {
+      fmt::print("Skipping buffering");
+      if (mhp::use_sycl()) {
+        __detail::sycl_copy(buffer, buffer + rng::size(data_), data_.data());
+      } else {
+        std::copy(buffer, buffer + rng::size(data_), data_.data());
+      }
     }
   }

   void pack() {
-    if (mhp::use_sycl()) {
-      __detail::sycl_copy(data_.data(), data_.data() + rng::size(data_),
-                          buffer);
-    } else {
-      std::copy(data_.begin(), data_.end(), buffer);
+    if (buffered) {
+      if (mhp::use_sycl()) {
+        __detail::sycl_copy(data_.data(), data_.data() + rng::size(data_),
+                            buffer);
+      } else {
+        std::copy(data_.begin(), data_.end(), buffer);
+      }
     }
   }
-  std::size_t buffer_size() { return rng::size(data_); }

-  T *data_pointer() { return buffer; }
+  std::size_t data_size() { return rng::size(data_); }
+
+  T *data_pointer() {
+    if (buffered) {
+      return buffer;
+    } else {
+      return data_.data();
+    }
+  }

   std::size_t rank() { return rank_; }
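
Putting it together: on the direct path, pack() and unpack() degenerate to no-ops while the transport reads and writes the halo storage itself via data_pointer(). A simulated round trip using the illustrative Group sketch above, with std::memcpy standing in for the real MPI send/receive in halo_impl:

#include <cstring>

// One simulated halo exchange between two neighbors, reusing the Group
// sketch above; std::memcpy stands in for the MPI traffic.
void exchange(Group &sender, Group &receiver) {
  sender.pack(); // no-op when sender.buffered == false
  std::memcpy(receiver.data_pointer(), sender.data_pointer(),
              sender.data_size() * sizeof(double));
  receiver.unpack(); // no-op when receiver.buffered == false
}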
