Skip to content

Commit

Permalink
Support contiguous folded dimensions that aren't the innermost loop
Browse files Browse the repository at this point in the history
  • Loading branch information
dsharlet committed Feb 13, 2025
1 parent afc5627 commit 712757d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 51 deletions.
89 changes: 38 additions & 51 deletions runtime/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,11 @@ SLINKY_ALWAYS_INLINE inline bool use_nonlinear_loop(span<const raw_buffer*, Bufs
template <std::size_t BufsSize = dynamic_extent>
struct for_each_loop;

// These functions may modify the second parameter in-place.
// These functions may modify the first parameter in-place.
// TODO: The last parameter (slice_extent) is unused by `for_each_element_callback`. Find a way to eliminate the
// parameter in that case. Adding a template parameter to everything affected to control this is painful.
template <std::size_t BufsSize>
using for_each_loop_impl = void (*)(mutable_span<void*, BufsSize>, const for_each_loop<BufsSize>*);
using for_each_loop_impl = void (*)(mutable_span<void*, BufsSize>, const for_each_loop<BufsSize>*, index_t);

template <std::size_t BufsSize>
struct for_each_loop {
Expand All @@ -503,91 +505,76 @@ struct for_each_loop {
};
};

template <typename F>
struct callback {
F f;
index_t slice_extent;
};

// Returns the size in bytes of one `for_each_loop` entry for `bufs_size` buffers: `sizeof(for_each_loop<>)` with the
// size of its `dims` member removed, plus the size of that member multiplied by `bufs_size`.
std::ptrdiff_t sizeof_for_each_loop(std::size_t bufs_size) {
  const std::size_t dims_bytes = sizeof(for_each_loop<>::dims);
  const std::size_t fixed_bytes = sizeof(for_each_loop<>) - dims_bytes;
  return fixed_bytes + dims_bytes * bufs_size;
}

// We store a plan for a parallel for loop in a structure of the following layout, for N buffers and rank R loops:
// for_each_loop<N> loops[R];
// callback<F> f;
// F f;
//
// We can't make a simple struct for this, because N and R are not necessarily compile-time constants.
template <typename F>
SLINKY_ALWAYS_INLINE inline std::size_t size_of_plan(std::size_t bufs_size, std::size_t rank) {
return sizeof_for_each_loop(bufs_size) * std::max<std::size_t>(1, rank) + sizeof(callback<F>);
return sizeof_for_each_loop(bufs_size) * std::max<std::size_t>(1, rank) + sizeof(F);
}

// Compile-time dispatch to either for_each_contiguous_slice_callback or for_each_element_callback.
SLINKY_ALWAYS_INLINE inline void call_f(const callback<for_each_element_callback>& f, void** bases, index_t extent,
const index_t* strides, std::false_type = {}) {
f.f(bases, extent, strides);
SLINKY_ALWAYS_INLINE inline void call_f(
for_each_element_callback f, void** bases, index_t extent, const index_t* strides, index_t slice_extent) {
assert(slice_extent == 1);
f(bases, extent, strides);
}
SLINKY_ALWAYS_INLINE inline void call_f(const callback<for_each_element_callback>& f, void** bases, index_t extent,
const index_t* strides, std::true_type) {
f.f(bases, extent, strides);
}
SLINKY_ALWAYS_INLINE inline void call_f(const callback<for_each_contiguous_slice_callback>& f, void** bases,
index_t extent, const index_t* strides, std::false_type = {}) {
f.f(f.slice_extent, bases, extent, strides);
}
// If we have a contiguous callback, and we know the slice is contiguous, we can adjust the contiguous slice extent
// instead of calling f in a loop.
SLINKY_ALWAYS_INLINE inline void call_f(const callback<for_each_contiguous_slice_callback>& f, void** bases,
index_t extent, const index_t* strides, std::true_type) {
f.f(f.slice_extent * extent, bases, 1, nullptr);
SLINKY_ALWAYS_INLINE inline void call_f(
for_each_contiguous_slice_callback f, void** bases, index_t extent, const index_t* strides, index_t slice_extent) {
f(slice_extent, bases, extent, strides);
}

template <typename F, std::size_t BufsSize>
void for_each_impl_call_f(mutable_span<void*, BufsSize> bases, const for_each_loop<BufsSize>* loop) {
const callback<F>& f =
*reinterpret_cast<const callback<F>*>(offset_bytes_non_null(loop, sizeof_for_each_loop(bases.size())));
void for_each_impl_call_f(
mutable_span<void*, BufsSize> bases, const for_each_loop<BufsSize>* loop, index_t slice_extent) {
const F& f = *reinterpret_cast<const F*>(offset_bytes_non_null(loop, sizeof_for_each_loop(bases.size())));

call_f(f, bases.data(), loop->extent, loop->strides);
call_f(f, bases.data(), loop->extent, loop->strides, slice_extent);
}

template <std::size_t BufsSize>
SLINKY_NO_STACK_PROTECTOR void call_impl_linear(
index_t extent, mutable_span<void*, BufsSize> bases, const for_each_loop<BufsSize>* loop, const index_t* strides) {
SLINKY_NO_STACK_PROTECTOR void call_impl_linear(index_t extent, mutable_span<void*, BufsSize> bases,
const for_each_loop<BufsSize>* loop, const index_t* strides, index_t slice_extent) {
assert(extent >= 1);

for_each_loop_impl<BufsSize> impl = loop->impl;

void** bases_i = SLINKY_ALLOCA(void*, bases.size());
for (;;) {
copy_small_n(bases.data(), bases.size(), bases_i);
impl({bases_i, bases.size()}, loop);
impl({bases_i, bases.size()}, loop, slice_extent);
if (SLINKY_UNLIKELY(--extent <= 0)) break;
increment_bases<BufsSize>(bases.size(), bases.data(), strides);
}
}

template <typename F, std::size_t BufsSize>
SLINKY_NO_STACK_PROTECTOR void for_each_impl_linear(
mutable_span<void*, BufsSize> bases, const for_each_loop<BufsSize>* loop) {
mutable_span<void*, BufsSize> bases, const for_each_loop<BufsSize>* loop, index_t slice_extent) {
const index_t* strides = loop->strides;
index_t extent = loop->extent;
assert(extent >= 1);

loop = offset_bytes_non_null(loop, sizeof_for_each_loop(bases.size()));

call_impl_linear(extent, bases, loop, strides);
call_impl_linear(extent, bases, loop, strides, slice_extent);
}

template <typename F, std::size_t BufsSize, bool CallF, typename Contiguous = std::false_type>
template <typename F, std::size_t BufsSize, bool CallF, bool Contiguous>
SLINKY_NO_STACK_PROTECTOR void for_each_impl_nonlinear(
mutable_span<void*, BufsSize> bases, const for_each_loop<BufsSize>* loop) {
mutable_span<void*, BufsSize> bases, const for_each_loop<BufsSize>* loop, index_t slice_extent) {
const dim* const* dims = loop->dims;
const index_t fold_factor = loop->fold_factor;

loop = offset_bytes_non_null(loop, sizeof_for_each_loop(bases.size()));

const callback<F>& f = *reinterpret_cast<const callback<F>*>(loop);
const F& f = *reinterpret_cast<const F*>(loop);

index_t* strides = SLINKY_ALLOCA(index_t, bases.size());
for (std::size_t n = 0; n < bases.size(); ++n) {
Expand All @@ -599,7 +586,8 @@ SLINKY_NO_STACK_PROTECTOR void for_each_impl_nonlinear(

// To handle non-linear loops, we process an interval [min, max] in blocks of the `fold_factor`, within which we can
// compute the base pointers linearly from the strides, if the buffers are fully in-bounds or out-of-bounds.
// We need to handle buffers going in and out of bounds too, so we break the blocks into smaller chunks at those boundaries.
// We need to handle buffers going in and out of bounds too, so we break the blocks into smaller chunks at those
// boundaries.
auto run_one_fold = [&](index_t min, index_t max) {
for (index_t i = min; i <= max;) {
index_t max_i = max;
Expand All @@ -624,12 +612,13 @@ SLINKY_NO_STACK_PROTECTOR void for_each_impl_nonlinear(
}
}
}
index_t extent_i = max_i - i + 1;
index_t extent_i = Contiguous ? 1 : max_i - i + 1;
index_t slice_extent_i = (Contiguous ? (max_i - i + 1) : 1) * slice_extent;
if (CallF) {
// If the next step is to call f, do that eagerly here to avoid an extra call.
call_f(f, bases_i, extent_i, strides, Contiguous());
call_f(f, bases_i, extent_i, strides, slice_extent_i);
} else {
call_impl_linear<BufsSize>(extent_i, {bases_i, bases.size()}, loop, strides);
call_impl_linear<BufsSize>(extent_i, {bases_i, bases.size()}, loop, strides, slice_extent_i);
}
i = max_i + 1;
}
Expand Down Expand Up @@ -685,11 +674,12 @@ SLINKY_NO_STACK_PROTECTOR SLINKY_ALWAYS_INLINE inline void for_each_impl(span<co
} else if (buf_dim.max() < buf_dim.min() || use_nonlinear_loop(bufs, d)) {
// extent > 1 and there is a folded dimension in one of the buffers, or we need to crop one of the buffers, or the
// loops are empty.
loop->impl = for_each_impl_nonlinear<F, BufsSize, false>;
if (SkipContiguous && is_contiguous_slice(bufs, d)) {
inner_impl = for_each_impl_nonlinear<F, BufsSize, true, std::true_type>;
loop->impl = for_each_impl_nonlinear<F, BufsSize, false, true>;
inner_impl = for_each_impl_nonlinear<F, BufsSize, true, true>;
} else {
inner_impl = for_each_impl_nonlinear<F, BufsSize, true, std::false_type>;
loop->impl = for_each_impl_nonlinear<F, BufsSize, false, false>;
inner_impl = for_each_impl_nonlinear<F, BufsSize, true, false>;
}
extent = 1;

Expand Down Expand Up @@ -762,20 +752,17 @@ SLINKY_NO_STACK_PROTECTOR SLINKY_ALWAYS_INLINE inline void for_each_impl(span<co
}
if (loop == outer_loop) {
// There are no loops, so just call f directly. This edge case is handled separately because the branch below
// assumes there is at least one loop.
call_f(callback<F>{f, slice_extent}, bases, 1, nullptr);
call_f(f, bases, 1, nullptr, slice_extent);
} else {
// Put the callback at the end of the plan, where the inner loop expects to find it.
reinterpret_cast<callback<F>*>(loop)->f = f;
if (SkipContiguous) {
reinterpret_cast<callback<F>*>(loop)->slice_extent = slice_extent;
}
*reinterpret_cast<F*>(loop) = f;

// We need to replace the implementation of the last loop.
for_each_loop<BufsSize>* inner_loop = offset_bytes_non_null(loop, -sizeof_for_each_loop(bufs.size()));
inner_loop->impl = inner_impl;

// Run the outer loop.
outer_loop->impl({bases, bufs.size()}, outer_loop);
outer_loop->impl({bases, bufs.size()}, outer_loop, slice_extent);
}
}

Expand Down
15 changes: 15 additions & 0 deletions runtime/test/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,21 @@ TEST(buffer, for_each_contiguous_folded_innermost) {
ASSERT_TRUE(is_filled_buffer(buf, 7));
}

// Exercises for_each_contiguous_slice on a buffer whose folded dimension has been swapped out of dim 0 after the
// strides were initialized.
TEST(buffer, for_each_contiguous_folded_innermost_dim_1) {
  buffer<char, 3> buf({10, 20});

  // Fold dim 0 and compute strides first, then exchange dims 0 and 1 so the folded dimension ends up as dim 1.
  buf.dim(0).set_fold_factor(4);
  buf.init_strides();
  std::swap(buf.dim(0), buf.dim(1));
  buf.allocate();

  // Fill every slice handed to the callback with 7 and count how many times the callback runs.
  int num_slices = 0;
  for_each_contiguous_slice(buf, [&num_slices](index_t extent, char* begin) {
    std::fill_n(begin, extent, 7);
    num_slices++;
  });

  // NOTE(review): 60 is presumably 20 * 3, i.e. each extent-10 row split into 4 + 4 + 2 by the fold factor of 4
  // -- confirm against the fold semantics.
  ASSERT_EQ(num_slices, 60);
  ASSERT_TRUE(is_filled_buffer(buf, 7));
}

TEST(buffer, for_each_contiguous_cropped) {
buffer<char, 1> src({10});
buffer<char, 1> dst({10});
Expand Down

0 comments on commit 712757d

Please sign in to comment.