diff --git a/include/dr/mhp/algorithms/md_for_each.hpp b/include/dr/mhp/algorithms/md_for_each.hpp index 45f25e7bcc..6928ab62f2 100644 --- a/include/dr/mhp/algorithms/md_for_each.hpp +++ b/include/dr/mhp/algorithms/md_for_each.hpp @@ -18,15 +18,25 @@ namespace dr::mhp::__detail { -template constexpr size_t argument_count( R(*f)(Types ...)) -{ - return sizeof...(Types); -} +struct any { + template operator T() const noexcept { + return std::declval(); + } +}; +template +concept one_argument = requires(F &f) { + { f(Arg1{}) }; }; -namespace dr::mhp { +template +concept two_arguments = requires(F &f) { + { f(Arg1{}, Arg2{}) }; +}; +}; // namespace dr::mhp::__detail + +namespace dr::mhp { namespace detail = dr::__detail; @@ -104,9 +114,9 @@ void stencil_for_each(auto op, is_mdspan_view auto &&...drs) { barrier(); } - /// Collective for_each on distributed range -template void for_each(auto op, is_mdspan_view auto &&...drs) { +template +void for_each(F op, is_mdspan_view auto &&...drs) { auto ranges = std::tie(drs...); auto &&dr0 = std::get<0>(ranges); if (rng::empty(dr0)) { @@ -121,7 +131,7 @@ template void for_each(auto op, is_mdspan_view auto &&...drs) { // If local if (dr::ranges::rank(seg0) == default_comm().rank()) { auto origin = seg0.origin(); - + // make a tuple of mdspans auto operand_mdspans = detail::tuple_transform( segs, [](auto &&seg) { return seg.mdspan(); }); @@ -135,7 +145,19 @@ template void for_each(auto op, is_mdspan_view auto &&...drs) { operand_mdspans, [index](auto mdspan) -> decltype(auto) { return mdspan(index[0], index[1]); }); - op(references); + static_assert( + std::invocable || + std::invocable); + if constexpr (std::invocable) { + op(references); + } else { + auto global_index = index; + for (std::size_t i = 0; i < rng::size(global_index); i++) { + global_index[i] += origin[i]; + } + + op(global_index, references); + } }; // TODO: Extend sycl_utils.hpp to handle ranges > 1D. It uses @@ -156,17 +178,21 @@ template void for_each(auto op, is_mdspan_view auto &&...drs) { auto references = detail::tie_transform( operand_mdspans, [index](auto mdspan) -> decltype(auto) { return mdspan(index); }); - if constexpr (__detail::argument_count(op) == 1) { + static_assert( + std::invocable || + std::invocable); + if constexpr (std::invocable) { op(references); - } else if constexpr (__detail::argument_count(op) == 2) { + } else if constexpr (std::invocable) { auto global_index = index; - for (std::size_t i = 0; i < global_index.size(); i++) { + for (std::size_t i = 0; i < rng::size(global_index); i++) { global_index[i] += origin[i]; } op(global_index, references); } else { - static_assert(false); + assert(false); } }; detail::mdspan_foreach( diff --git a/test/gtest/mhp/mdstar.cpp b/test/gtest/mhp/mdstar.cpp index c45ec8458a..b9dc948469 100644 --- a/test/gtest/mhp/mdstar.cpp +++ b/test/gtest/mhp/mdstar.cpp @@ -414,22 +414,19 @@ TEST_F(MdForeach, 3ops) { } TEST_F(MdForeach, Indexed) { - xhp::distributed_mdarray a(extents2d); - xhp::distributed_mdarray b(extents2d); - auto mda = a.mdspan(); - auto mdb = b.mdspan(); - xhp::iota(a, 100); - xhp::iota(b, 200); - auto op = [](auto index, auto v) { - auto &[o1, o2] = v; - o1 = index[0]; - o2 = index[1]; + xhp::distributed_mdarray dist(extents2d); + auto op = [l = ydim](auto index, auto v) { + auto &[o] = v; + o = index[0] * l + index[1]; }; - xhp::for_each(op, a, b); - EXPECT_EQ(mda(1, 2), 1); - EXPECT_EQ(mdb(1, 2), 2); - EXPECT_EQ(mda(xdim - 1, ydim - 1), mdb(xdim - 1, ydim - 1)); + xhp::for_each(op, dist); + for (std::size_t i = 0; i < xdim; i++) { + for (std::size_t j = 0; j < ydim; j++) { + EXPECT_EQ(dist.mdspan()(i, j), i * ydim + j) + << fmt::format("i: {} j: {}\n", i, j); + } + } } using MdStencilForeach = Mdspan;