Skip to content

Commit

Permalink
Recursively fuse siblings
Browse files: browse the repository at this point in the history
  • Loading branch information
dsharlet committed Jan 23, 2025
1 parent 6077789 commit 04122d0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
27 changes: 18 additions & 9 deletions builder/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -936,25 +936,34 @@ class sibling_fuser : public stmt_mutator {

public:
void visit(const block* op) override {
if (op->stmts.empty()) {
set_result(op);
return;
}
std::vector<stmt> result;
result.reserve(op->stmts.size());
bool changed = false;
for (const stmt& s : op->stmts) {
result.push_back(mutate(s));
changed = changed || !result.back().same_as(s);
}

// TODO: This currently only looks for immediately adjacent nodes that can be fused. We can also try to fuse
// ops with intervening ops, but this isn't obviously a simplification, and in the case of allocations, may
// increase peak memory usage.
for (std::size_t i = 0; i + 1 < result.size();) {
if (fuse(result[i], result[i + 1])) {
result.erase(result.begin() + i + 1);
result.push_back(op->stmts.front());
bool changed = false;
auto mutate_back = [&]() {
stmt m = mutate(result.back());
if (!m.same_as(result.back())) {
result.back() = std::move(m);
changed = true;
}
};
for (std::size_t i = 1; i < op->stmts.size(); ++i) {
if (!fuse(result.back(), op->stmts[i])) {
mutate_back();
result.push_back(op->stmts[i]);
} else {
++i;
changed = true;
}
}
mutate_back();

if (changed) {
set_result(block::make(std::move(result)));
Expand Down
7 changes: 7 additions & 0 deletions builder/test/optimizations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ node_context symbols;
// Named variables registered with the `symbols` context; the tests below pass them
// to constructors such as `use_buffer`, `allocate::make` and `crop_dim::make`.
var x(symbols, "x");
var y(symbols, "y");
var z(symbols, "z");
var w(symbols, "w");

// Matcher `matches(expected)` for use with ASSERT_THAT: succeeds when `match(arg, expected)`
// returns true for the value under test (`arg`).
// NOTE(review): `match` is declared elsewhere; presumably a structural comparison of IR nodes — confirm.
MATCHER_P(matches, expected, "") { return match(arg, expected); }

Expand Down Expand Up @@ -68,6 +69,12 @@ TEST(optimizations, fuse_siblings) {
use_buffer(z),
allocate::make(y, memory_type::heap, 1, {}, use_buffer(y)),
})));
ASSERT_THAT(fuse_siblings(block::make({
crop_dim::make(x, y, 0, {0, 10}, crop_dim::make(z, x, 1, {0, 10}, use_buffer(z))),
crop_dim::make(z, y, 0, {0, 10}, crop_dim::make(w, z, 1, {0, 10}, use_buffer(w))),
})),
matches(crop_dim::make(
x, y, 0, {0, 10}, crop_dim::make(z, x, 1, {0, 10}, block::make({use_buffer(z), use_buffer(z)})))));
}

TEST(optimizations, optimize_symbols) {
Expand Down

0 comments on commit 04122d0

Please sign in to comment.