diff --git a/include/dr/mhp/algorithms/md_for_each.hpp b/include/dr/mhp/algorithms/md_for_each.hpp index 4107f3440e..6928ab62f2 100644 --- a/include/dr/mhp/algorithms/md_for_each.hpp +++ b/include/dr/mhp/algorithms/md_for_each.hpp @@ -16,6 +16,26 @@ #include #include +namespace dr::mhp::__detail { + +struct any { + template operator T() const noexcept { + return std::declval(); + } +}; + +template +concept one_argument = requires(F &f) { + { f(Arg1{}) }; +}; + +template +concept two_arguments = requires(F &f) { + { f(Arg1{}, Arg2{}) }; +}; + +}; // namespace dr::mhp::__detail + namespace dr::mhp { namespace detail = dr::__detail; @@ -95,7 +115,8 @@ void stencil_for_each(auto op, is_mdspan_view auto &&...drs) { } /// Collective for_each on distributed range -template void for_each(auto op, is_mdspan_view auto &&...drs) { +template +void for_each(F op, is_mdspan_view auto &&...drs) { auto ranges = std::tie(drs...); auto &&dr0 = std::get<0>(ranges); if (rng::empty(dr0)) { @@ -109,6 +130,8 @@ template void for_each(auto op, is_mdspan_view auto &&...drs) { // If local if (dr::ranges::rank(seg0) == default_comm().rank()) { + auto origin = seg0.origin(); + // make a tuple of mdspans auto operand_mdspans = detail::tuple_transform( segs, [](auto &&seg) { return seg.mdspan(); }); @@ -122,7 +145,19 @@ template void for_each(auto op, is_mdspan_view auto &&...drs) { operand_mdspans, [index](auto mdspan) -> decltype(auto) { return mdspan(index[0], index[1]); }); - op(references); + static_assert( + std::invocable || + std::invocable); + if constexpr (std::invocable) { + op(references); + } else { + auto global_index = index; + for (std::size_t i = 0; i < rng::size(global_index); i++) { + global_index[i] += origin[i]; + } + + op(global_index, references); + } }; // TODO: Extend sycl_utils.hpp to handle ranges > 1D. It uses @@ -143,7 +178,22 @@ template void for_each(auto op, is_mdspan_view auto &&...drs) { auto references = detail::tie_transform( operand_mdspans, [index](auto mdspan) -> decltype(auto) { return mdspan(index); }); - op(references); + static_assert( + std::invocable || + std::invocable); + if constexpr (std::invocable) { + op(references); + } else if constexpr (std::invocable) { + auto global_index = index; + for (std::size_t i = 0; i < rng::size(global_index); i++) { + global_index[i] += origin[i]; + } + + op(global_index, references); + } else { + assert(false); + } }; detail::mdspan_foreach( mdspan0.extents(), invoke_index); diff --git a/test/gtest/mhp/CMakeLists.txt b/test/gtest/mhp/CMakeLists.txt index ae6a8dd331..d720875cfa 100644 --- a/test/gtest/mhp/CMakeLists.txt +++ b/test/gtest/mhp/CMakeLists.txt @@ -56,7 +56,7 @@ add_executable( add_executable(mhp-quick-test mhp-tests.cpp - ../common/reduce.cpp + mdstar.cpp ) # cmake-format: on diff --git a/test/gtest/mhp/mdstar.cpp b/test/gtest/mhp/mdstar.cpp index bce9c503cd..b9dc948469 100644 --- a/test/gtest/mhp/mdstar.cpp +++ b/test/gtest/mhp/mdstar.cpp @@ -413,6 +413,22 @@ TEST_F(MdForeach, 3ops) { c.mdspan()); } +TEST_F(MdForeach, Indexed) { + xhp::distributed_mdarray dist(extents2d); + auto op = [l = ydim](auto index, auto v) { + auto &[o] = v; + o = index[0] * l + index[1]; + }; + + xhp::for_each(op, dist); + for (std::size_t i = 0; i < xdim; i++) { + for (std::size_t j = 0; j < ydim; j++) { + EXPECT_EQ(dist.mdspan()(i, j), i * ydim + j) + << fmt::format("i: {} j: {}\n", i, j); + } + } +} + using MdStencilForeach = Mdspan; TEST_F(MdStencilForeach, 2ops) {