Skip to content

Commit

Permalink
Support for passing global index to function in mdspan for_each
Browse files Browse the repository at this point in the history
  • Loading branch information
rscohn2 committed Oct 27, 2023
1 parent 1a415cb commit 340658b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
26 changes: 25 additions & 1 deletion include/dr/mhp/algorithms/md_for_each.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@
#include <dr/detail/tuple_utils.hpp>
#include <dr/mhp/global.hpp>

namespace dr::mhp::__detail {

template <typename R, typename ... Types> constexpr size_t argument_count( R(*f)(Types ...))
{
return sizeof...(Types);
}

};

namespace dr::mhp {


namespace detail = dr::__detail;

/// Collective for_each on distributed range
Expand Down Expand Up @@ -94,6 +104,7 @@ void stencil_for_each(auto op, is_mdspan_view auto &&...drs) {
barrier();
}


/// Collective for_each on distributed range
template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
auto ranges = std::tie(drs...);
Expand All @@ -109,6 +120,8 @@ template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {

// If local
if (dr::ranges::rank(seg0) == default_comm().rank()) {
auto origin = seg0.origin();

// make a tuple of mdspans
auto operand_mdspans = detail::tuple_transform(
segs, [](auto &&seg) { return seg.mdspan(); });
Expand Down Expand Up @@ -143,7 +156,18 @@ template <typename... Ts> void for_each(auto op, is_mdspan_view auto &&...drs) {
auto references = detail::tie_transform(
operand_mdspans,
[index](auto mdspan) -> decltype(auto) { return mdspan(index); });
op(references);
if constexpr (__detail::argument_count(op) == 1) {
op(references);
} else if constexpr (__detail::argument_count(op) == 2) {
auto global_index = index;
for (std::size_t i = 0; i < global_index.size(); i++) {
global_index[i] += origin[i];
}

op(global_index, references);
} else {
static_assert(false);
}
};
detail::mdspan_foreach<mdspan0.rank(), decltype(invoke_index)>(
mdspan0.extents(), invoke_index);
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/mhp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ add_executable(

add_executable(mhp-quick-test
mhp-tests.cpp
../common/reduce.cpp
mdstar.cpp
)
# cmake-format: on

Expand Down
19 changes: 19 additions & 0 deletions test/gtest/mhp/mdstar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,25 @@ TEST_F(MdForeach, 3ops) {
c.mdspan());
}

TEST_F(MdForeach, Indexed) {
xhp::distributed_mdarray<T, 2> a(extents2d);
xhp::distributed_mdarray<T, 2> b(extents2d);
auto mda = a.mdspan();
auto mdb = b.mdspan();
xhp::iota(a, 100);
xhp::iota(b, 200);
auto op = [](auto index, auto v) {
auto &[o1, o2] = v;
o1 = index[0];
o2 = index[1];
};

xhp::for_each(op, a, b);
EXPECT_EQ(mda(1, 2), 1);
EXPECT_EQ(mdb(1, 2), 2);
EXPECT_EQ(mda(xdim - 1, ydim - 1), mdb(xdim - 1, ydim - 1));
}

using MdStencilForeach = Mdspan;

TEST_F(MdStencilForeach, 2ops) {
Expand Down

0 comments on commit 340658b

Please sign in to comment.