Skip to content

Commit

Permalink
Merge branch 'main' into vksnk/constants
Browse files Browse the repository at this point in the history
  • Loading branch information
vksnk committed Jan 30, 2025
2 parents 3d6434d + 8b4d18a commit 36c0bd9
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 13 deletions.
37 changes: 26 additions & 11 deletions runtime/evaluate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,23 @@ class evaluator {
return result;
}

SLINKY_NO_STACK_PROTECTOR index_t eval_unshadowed(const crop_buffer* op) {
// The operation is not shadowed. Make a clone and use eval_shadowed on the clone.
const raw_buffer* src_buf = reinterpret_cast<raw_buffer*>(context.lookup(op->src));
assert(src_buf);

raw_buffer sym_buf = *src_buf;
sym_buf.dims = SLINKY_ALLOCA(dim, src_buf->rank);
internal::copy_small_n(src_buf->dims, src_buf->rank, sym_buf.dims);
for (std::size_t d = 0; d < op->bounds.size(); ++d) {
const slinky::dim& dim = sym_buf.dims[d];
interval bounds = eval(op->bounds[d], {dim.min(), dim.max()});
sym_buf.crop(d, bounds.min, bounds.max);
}

return eval_with_value(op->body, op->sym, reinterpret_cast<index_t>(&sym_buf));
}

index_t eval_shadowed(const crop_dim* op) {
raw_buffer* buffer = reinterpret_cast<raw_buffer*>(context.lookup(op->sym));
assert(buffer);
Expand All @@ -599,21 +616,19 @@ class evaluator {
return result;
}

template <typename T>
SLINKY_NO_STACK_PROTECTOR index_t eval_unshadowed(const T* op) {
SLINKY_NO_STACK_PROTECTOR index_t eval_unshadowed(const crop_dim* op) {
// The operation is not shadowed. Make a clone and use eval_shadowed on the clone.
raw_buffer* src_buf = reinterpret_cast<raw_buffer*>(context.lookup(op->src));
const raw_buffer* src_buf = reinterpret_cast<raw_buffer*>(context.lookup(op->src));
assert(src_buf);

raw_buffer clone = *src_buf;
clone.dims = SLINKY_ALLOCA(dim, src_buf->rank);
internal::copy_small_n(src_buf->dims, src_buf->rank, clone.dims);
raw_buffer sym_buf = *src_buf;
sym_buf.dims = SLINKY_ALLOCA(dim, src_buf->rank);
internal::copy_small_n(src_buf->dims, src_buf->rank, sym_buf.dims);
slinky::dim& dim = sym_buf.dims[op->dim];
interval bounds = eval(op->bounds, {dim.min(), dim.max()});
sym_buf.crop(op->dim, bounds.min, bounds.max);

context.reserve(op->sym.id + 1);
index_t old_value = context.set(op->sym, reinterpret_cast<index_t>(&clone));
index_t result = eval_shadowed(op);
context.set(op->sym, old_value);
return result;
return eval_with_value(op->body, op->sym, reinterpret_cast<index_t>(&sym_buf));
}

template <typename T>
Expand Down
8 changes: 6 additions & 2 deletions runtime/test/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ void for_each_index(span<const dim> dims, int d, index_t* is, const F& f) {

template <typename F>
SLINKY_NO_STACK_PROTECTOR void for_each_index(span<const dim> dims, const F& f) {
index_t* i = SLINKY_ALLOCA(index_t, dims.size());
for_each_index(dims, dims.size() - 1, i, f);
if (dims.empty()) {
f(span<const index_t>{});
} else {
index_t* i = SLINKY_ALLOCA(index_t, dims.size());
for_each_index(dims, dims.size() - 1, i, f);
}
}
template <typename F>
void for_each_index(const raw_buffer& buf, const F& f) {
Expand Down
9 changes: 9 additions & 0 deletions runtime/test/evaluate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,15 @@ TEST(evaluate, crop_dim) {
buffer<void, 2> buf({10, 20});
buf.allocate();
ctx[x] = reinterpret_cast<index_t>(&buf);
buffer<void, 1> y_buf({3});
ctx[y] = reinterpret_cast<index_t>(&y_buf);

auto buf_before = buf;

evaluate(crop_dim::make(x, x, 0, {1, 3}, make_check(x, {3, 20})), ctx);
evaluate(crop_dim::make(y, x, 0, {1, 3}, block::make({make_check(x, {10, 20}), make_check(y, {3, 20})})), ctx);
evaluate(
crop_dim::make(y, x, 0, buffer_bounds(y, 0), block::make({make_check(x, {10, 20}), make_check(y, {3, 20})})), ctx);
ASSERT_EQ(buf_before, buf);
}

Expand All @@ -151,13 +155,18 @@ TEST(evaluate, crop_buffer) {
buffer<void, 4> buf({10, 20, 30, 40});
buf.allocate();
ctx[x] = reinterpret_cast<index_t>(&buf);
buffer<void, 4> y_buf({3, 4, 5, 6});
ctx[y] = reinterpret_cast<index_t>(&y_buf);

auto buf_before = buf;

evaluate(crop_buffer::make(x, x, {{1, 3}, {}, {2, 5}}, make_check(x, {3, 20, 4, 40})), ctx);
evaluate(crop_buffer::make(y, x, {{1, 3}, {}, {2, 5}},
block::make({make_check(x, {10, 20, 30, 40}), make_check(y, {3, 20, 4, 40})})),
ctx);
evaluate(crop_buffer::make(y, x, {buffer_bounds(y, 0), buffer_bounds(y, 1)},
block::make({make_check(x, {10, 20, 30, 40}), make_check(y, {3, 4, 30, 40})})),
ctx);
ASSERT_EQ(buf_before, buf);
}

Expand Down

0 comments on commit 36c0bd9

Please sign in to comment.