Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fixes #573

Merged
merged 4 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions builder/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,10 @@ class reuse_shadows : public stmt_mutator {

void visit(const allocate* op) override { visit_buffer_decl(op); }
void visit(const make_buffer* op) override { visit_buffer_decl(op); }
void visit(const constant_buffer* op) override { visit_buffer_decl(op, false); }
void visit(const constant_buffer* op) override {
// Constant buffers are not mutable, because the raw_buffer object we use is not allocated by a declaration.
visit_buffer_decl(op, false);
}

void visit(const crop_buffer* op) override { visit_buffer_mutator(op); }
void visit(const crop_dim* op) override { visit_buffer_mutator(op); }
Expand Down Expand Up @@ -1405,9 +1408,7 @@ class node_canonicalizer : public node_mutator {

} // namespace

expr canonicalize_nodes(const expr& e) {
return node_canonicalizer().mutate(e);
}
expr canonicalize_nodes(const expr& e) { return node_canonicalizer().mutate(e); }
stmt canonicalize_nodes(const stmt& s) {
scoped_trace trace("canonicalize_nodes");
return node_canonicalizer().mutate(s);
Expand Down
4 changes: 3 additions & 1 deletion builder/slide_and_fold_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ class slide_and_fold : public stmt_mutator {
return true;
}

loop_info() = default;

loop_info(node_context& ctx, var sym, std::size_t loop_id, expr orig_min, interval_expr bounds, expr step,
int max_workers)
: sym(sym), orig_min(orig_min), bounds(bounds), step(step), max_workers(max_workers),
Expand All @@ -243,7 +245,7 @@ class slide_and_fold : public stmt_mutator {
symbol_map<modulus_remainder<index_t>>& current_expr_alignment() { return *loops.back().expr_alignment; }

slide_and_fold(node_context& ctx) : ctx(ctx), x(ctx.insert_unique("_x")) {
loops.emplace_back(ctx, var(), loop_counter++, expr(), interval_expr::none(), expr(), loop::serial);
loops.emplace_back(loop_info());
}

stmt mutate(const stmt& s) override {
Expand Down
11 changes: 11 additions & 0 deletions builder/substitute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,17 @@ void substitutor::visit(const make_buffer* op) {
}
exit_decls();
}
void substitutor::visit(const constant_buffer* op) {
var sym = enter_decl(op->sym);
stmt body = sym.defined() ? mutate(op->body) : op->body;
sym = sym.defined() ? sym : op->sym;
if (sym == op->sym && body.same_as(op->body)) {
set_result(op);
} else {
set_result(constant_buffer::make(sym, op->value, std::move(body)));
}
exit_decls();
}

void substitutor::visit(const slice_buffer* op) {
var src = visit_symbol(op->src);
Expand Down
1 change: 1 addition & 0 deletions builder/substitute.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class substitutor : public node_mutator {
void visit(const loop* op) override;
void visit(const allocate* op) override;
void visit(const make_buffer* op) override;
void visit(const constant_buffer* op) override;
void visit(const slice_buffer* op) override;
void visit(const slice_dim* op) override;
void visit(const crop_buffer* op) override;
Expand Down
2 changes: 2 additions & 0 deletions runtime/evaluate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ class evaluator {
SLINKY_NO_INLINE index_t eval_loop_parallel(const loop* op) {
interval bounds = eval(op->bounds);
index_t step = eval(op->step, 1);
assert(step != 0);
std::atomic<index_t> result = 0;
std::size_t n = ceil_div(bounds.max - bounds.min + 1, step);
context.reserve(op->sym.id + 1);
Expand All @@ -431,6 +432,7 @@ class evaluator {
SLINKY_NO_INLINE index_t eval_loop_serial(const loop* op) {
interval bounds = eval(op->bounds);
index_t step = eval(op->step, 1);
assert(step != 0);
// TODO(https://github.com/dsharlet/slinky/issues/3): We don't get a reference to context[op->sym] here
// because the context could grow and invalidate the reference. This could be fixed by having evaluate
// fully traverse the expression to find the max var, and pre-allocate the context up front. It's
Expand Down
Loading