diff --git a/include/dr/mhp/algorithms/md_for_each.hpp b/include/dr/mhp/algorithms/md_for_each.hpp index 4107f3440e..45f25e7bcc 100644 --- a/include/dr/mhp/algorithms/md_for_each.hpp +++ b/include/dr/mhp/algorithms/md_for_each.hpp @@ -16,8 +16,18 @@ #include #include +namespace dr::mhp::__detail { + +template constexpr size_t argument_count( R(*f)(Types ...)) +{ + return sizeof...(Types); +} + +}; + namespace dr::mhp { + namespace detail = dr::__detail; /// Collective for_each on distributed range @@ -94,6 +104,7 @@ void stencil_for_each(auto op, is_mdspan_view auto &&...drs) { barrier(); } + /// Collective for_each on distributed range template void for_each(auto op, is_mdspan_view auto &&...drs) { auto ranges = std::tie(drs...); @@ -109,6 +120,8 @@ template void for_each(auto op, is_mdspan_view auto &&...drs) { // If local if (dr::ranges::rank(seg0) == default_comm().rank()) { + auto origin = seg0.origin(); + // make a tuple of mdspans auto operand_mdspans = detail::tuple_transform( segs, [](auto &&seg) { return seg.mdspan(); }); @@ -143,7 +156,18 @@ template void for_each(auto op, is_mdspan_view auto &&...drs) { auto references = detail::tie_transform( operand_mdspans, [index](auto mdspan) -> decltype(auto) { return mdspan(index); }); - op(references); + if constexpr (__detail::argument_count(op) == 1) { + op(references); + } else if constexpr (__detail::argument_count(op) == 2) { + auto global_index = index; + for (std::size_t i = 0; i < global_index.size(); i++) { + global_index[i] += origin[i]; + } + + op(global_index, references); + } else { + static_assert(false); + } }; detail::mdspan_foreach( mdspan0.extents(), invoke_index); diff --git a/test/gtest/mhp/CMakeLists.txt b/test/gtest/mhp/CMakeLists.txt index ae6a8dd331..d720875cfa 100644 --- a/test/gtest/mhp/CMakeLists.txt +++ b/test/gtest/mhp/CMakeLists.txt @@ -56,7 +56,7 @@ add_executable( add_executable(mhp-quick-test mhp-tests.cpp - ../common/reduce.cpp + mdstar.cpp ) # cmake-format: on diff --git a/test/gtest/mhp/mdstar.cpp b/test/gtest/mhp/mdstar.cpp index bce9c503cd..c45ec8458a 100644 --- a/test/gtest/mhp/mdstar.cpp +++ b/test/gtest/mhp/mdstar.cpp @@ -413,6 +413,25 @@ TEST_F(MdForeach, 3ops) { c.mdspan()); } +TEST_F(MdForeach, Indexed) { + xhp::distributed_mdarray a(extents2d); + xhp::distributed_mdarray b(extents2d); + auto mda = a.mdspan(); + auto mdb = b.mdspan(); + xhp::iota(a, 100); + xhp::iota(b, 200); + auto op = [](auto index, auto v) { + auto &[o1, o2] = v; + o1 = index[0]; + o2 = index[1]; + }; + + xhp::for_each(op, a, b); + EXPECT_EQ(mda(1, 2), 1); + EXPECT_EQ(mdb(1, 2), 2); + EXPECT_EQ(mda(xdim - 1, ydim - 1), mdb(xdim - 1, ydim - 1)); +} + using MdStencilForeach = Mdspan; TEST_F(MdStencilForeach, 2ops) {