Skip to content

Commit

Permalink
support for mdspan for_each that passes index
Browse files Browse the repository at this point in the history
  • Loading branch information
rscohn2 committed Oct 27, 2023
1 parent 340658b commit 376d142
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 27 deletions.
52 changes: 39 additions & 13 deletions include/dr/mhp/algorithms/md_for_each.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,25 @@

namespace dr::mhp::__detail {

template <typename R, typename ... Types> constexpr size_t argument_count( R(*f)(Types ...))
{
return sizeof...(Types);
}
struct any {
template <typename T> operator T() const noexcept {
return std::declval<T>();
}
};

template <typename F, typename Arg1>
concept one_argument = requires(F &f) {
{ f(Arg1{}) };
};

namespace dr::mhp {
template <typename F, typename Arg1, typename Arg2>
concept two_arguments = requires(F &f) {
{ f(Arg1{}, Arg2{}) };
};

}; // namespace dr::mhp::__detail

namespace dr::mhp {

namespace detail = dr::__detail;

Expand Down Expand Up @@ -104,9 +114,9 @@ void stencil_for_each(auto op, is_mdspan_view auto &&...drs) {
barrier();
}


/// Collective for_each on distributed range
template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
template <typename F, typename... Ts>
void for_each(F op, is_mdspan_view auto &&...drs) {
auto ranges = std::tie(drs...);
auto &&dr0 = std::get<0>(ranges);
if (rng::empty(dr0)) {
Expand All @@ -121,7 +131,7 @@ template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
// If local
if (dr::ranges::rank(seg0) == default_comm().rank()) {
auto origin = seg0.origin();

// make a tuple of mdspans
auto operand_mdspans = detail::tuple_transform(
segs, [](auto &&seg) { return seg.mdspan(); });
Expand All @@ -135,7 +145,19 @@ template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
operand_mdspans, [index](auto mdspan) -> decltype(auto) {
return mdspan(index[0], index[1]);
});
op(references);
static_assert(
std::invocable<F, decltype(references)> ||
std::invocable<F, decltype(index), decltype(references)>);
if constexpr (std::invocable<F, decltype(references)>) {
op(references);
} else {
auto global_index = index;
for (std::size_t i = 0; i < rng::size(global_index); i++) {
global_index[i] += origin[i];
}

op(global_index, references);
}
};

// TODO: Extend sycl_utils.hpp to handle ranges > 1D. It uses
Expand All @@ -156,17 +178,21 @@ template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
auto references = detail::tie_transform(
operand_mdspans,
[index](auto mdspan) -> decltype(auto) { return mdspan(index); });
if constexpr (__detail::argument_count(op) == 1) {
static_assert(
std::invocable<F, decltype(references)> ||
std::invocable<F, decltype(index), decltype(references)>);
if constexpr (std::invocable<F, decltype(references)>) {
op(references);
} else if constexpr (__detail::argument_count(op) == 2) {
} else if constexpr (std::invocable<F, decltype(index),
decltype(references)>) {
auto global_index = index;
for (std::size_t i = 0; i < global_index.size(); i++) {
for (std::size_t i = 0; i < rng::size(global_index); i++) {
global_index[i] += origin[i];
}

op(global_index, references);
} else {
static_assert(false);
assert(false);
}
};
detail::mdspan_foreach<mdspan0.rank(), decltype(invoke_index)>(
Expand Down
25 changes: 11 additions & 14 deletions test/gtest/mhp/mdstar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,22 +414,19 @@ TEST_F(MdForeach, 3ops) {
}

TEST_F(MdForeach, Indexed) {
xhp::distributed_mdarray<T, 2> a(extents2d);
xhp::distributed_mdarray<T, 2> b(extents2d);
auto mda = a.mdspan();
auto mdb = b.mdspan();
xhp::iota(a, 100);
xhp::iota(b, 200);
auto op = [](auto index, auto v) {
auto &[o1, o2] = v;
o1 = index[0];
o2 = index[1];
xhp::distributed_mdarray<T, 2> dist(extents2d);
auto op = [l = ydim](auto index, auto v) {
auto &[o] = v;
o = index[0] * l + index[1];
};

xhp::for_each(op, a, b);
EXPECT_EQ(mda(1, 2), 1);
EXPECT_EQ(mdb(1, 2), 2);
EXPECT_EQ(mda(xdim - 1, ydim - 1), mdb(xdim - 1, ydim - 1));
xhp::for_each(op, dist);
for (std::size_t i = 0; i < xdim; i++) {
for (std::size_t j = 0; j < ydim; j++) {
EXPECT_EQ(dist.mdspan()(i, j), i * ydim + j)
<< fmt::format("i: {} j: {}\n", i, j);
}
}
}

using MdStencilForeach = Mdspan;
Expand Down

0 comments on commit 376d142

Please sign in to comment.