Skip to content

Commit

Permalink
support for passing global index in a mdspan for_each (#600)
Browse files Browse the repository at this point in the history
  • Loading branch information
rscohn2 authored Oct 27, 2023
1 parent 1a415cb commit c961a53
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 4 deletions.
56 changes: 53 additions & 3 deletions include/dr/mhp/algorithms/md_for_each.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,26 @@
#include <dr/detail/tuple_utils.hpp>
#include <dr/mhp/global.hpp>

namespace dr::mhp::__detail {

struct any {
template <typename T> operator T() const noexcept {
return std::declval<T>();
}
};

template <typename F, typename Arg1>
concept one_argument = requires(F &f) {
{ f(Arg1{}) };
};

template <typename F, typename Arg1, typename Arg2>
concept two_arguments = requires(F &f) {
{ f(Arg1{}, Arg2{}) };
};

}; // namespace dr::mhp::__detail

namespace dr::mhp {

namespace detail = dr::__detail;
Expand Down Expand Up @@ -95,7 +115,8 @@ void stencil_for_each(auto op, is_mdspan_view auto &&...drs) {
}

/// Collective for_each on distributed range
template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
template <typename F, typename... Ts>
void for_each(F op, is_mdspan_view auto &&...drs) {
auto ranges = std::tie(drs...);
auto &&dr0 = std::get<0>(ranges);
if (rng::empty(dr0)) {
Expand All @@ -109,6 +130,8 @@ template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {

// If local
if (dr::ranges::rank(seg0) == default_comm().rank()) {
auto origin = seg0.origin();

// make a tuple of mdspans
auto operand_mdspans = detail::tuple_transform(
segs, [](auto &&seg) { return seg.mdspan(); });
Expand All @@ -122,7 +145,19 @@ template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
operand_mdspans, [index](auto mdspan) -> decltype(auto) {
return mdspan(index[0], index[1]);
});
op(references);
static_assert(
std::invocable<F, decltype(references)> ||
std::invocable<F, decltype(index), decltype(references)>);
if constexpr (std::invocable<F, decltype(references)>) {
op(references);
} else {
auto global_index = index;
for (std::size_t i = 0; i < rng::size(global_index); i++) {
global_index[i] += origin[i];
}

op(global_index, references);
}
};

// TODO: Extend sycl_utils.hpp to handle ranges > 1D. It uses
Expand All @@ -143,7 +178,22 @@ template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
auto references = detail::tie_transform(
operand_mdspans,
[index](auto mdspan) -> decltype(auto) { return mdspan(index); });
op(references);
static_assert(
std::invocable<F, decltype(references)> ||
std::invocable<F, decltype(index), decltype(references)>);
if constexpr (std::invocable<F, decltype(references)>) {
op(references);
} else if constexpr (std::invocable<F, decltype(index),
decltype(references)>) {
auto global_index = index;
for (std::size_t i = 0; i < rng::size(global_index); i++) {
global_index[i] += origin[i];
}

op(global_index, references);
} else {
assert(false);
}
};
detail::mdspan_foreach<mdspan0.rank(), decltype(invoke_index)>(
mdspan0.extents(), invoke_index);
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/mhp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ add_executable(

add_executable(mhp-quick-test
mhp-tests.cpp
../common/reduce.cpp
mdstar.cpp
)
# cmake-format: on

Expand Down
16 changes: 16 additions & 0 deletions test/gtest/mhp/mdstar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,22 @@ TEST_F(MdForeach, 3ops) {
c.mdspan());
}

TEST_F(MdForeach, Indexed) {
xhp::distributed_mdarray<T, 2> dist(extents2d);
auto op = [l = ydim](auto index, auto v) {
auto &[o] = v;
o = index[0] * l + index[1];
};

xhp::for_each(op, dist);
for (std::size_t i = 0; i < xdim; i++) {
for (std::size_t j = 0; j < ydim; j++) {
EXPECT_EQ(dist.mdspan()(i, j), i * ydim + j)
<< fmt::format("i: {} j: {}\n", i, j);
}
}
}

using MdStencilForeach = Mdspan;

TEST_F(MdStencilForeach, 2ops) {
Expand Down

0 comments on commit c961a53

Please sign in to comment.