Skip to content

Commit

Permalink
Fix the order of the function during the IR construction. (#608)
Browse files Browse the repository at this point in the history
Apparently, the test which checked compute_at was disabled and was
failing at head after I enabled it. The issue is that we first were
visiting the inner loop which might violate the assumption that the
function are added in the order of their occurance in the final IR (this
is important so we can merge lifetimes correctly).
  • Loading branch information
vksnk authored Feb 25, 2025
1 parent 01e3778 commit 9e47253
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 86 deletions.
152 changes: 72 additions & 80 deletions builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,17 +559,12 @@ class pipeline_builder {

struct statement_with_range {
stmt body;
int start;
int end;
int start = 0;
int end = 0;
// Stores every allocation inside of the range.
std::set<var> allocations;

static statement_with_range merge(const statement_with_range& a, const statement_with_range& b) {
if (a.start > b.end) {
// In case the range order is reversed.
return merge(b, a);
}

assert(a.end + 1 == b.start);
statement_with_range r;
r.body = block::make({a.body, b.body});
Expand Down Expand Up @@ -691,73 +686,73 @@ class pipeline_builder {
return simplify(bounds);
}

statement_with_range make_loop(
statement_with_range body, const func* base_f, const func::loop_info& loop = func::loop_info()) {
// Creates a loop body for a given function including all function bodies computed inside of the loops.
// It may recursively call itself if there are nested loops, it's assumed that loops are produced
// starting from the outer one.
statement_with_range make_loop(const func* base_f, int loop_index) {
const func::loop_info& loop = base_f->loops()[loop_index];
assert(loop.defined());
loop_id here = {base_f, loop.var};

body = build(body, base_f, here);

if (loop.defined()) {
// Find which buffers are used inside of the body.
// TODO(vksnk): recomputing this seems really wasteful, we can should be
// able to maintain the list of buffers as we build the IR.
std::vector<var> buffer_used = find_buffer_dependencies(body.body);

// Add crops for the used buffers using previously inferred bounds.
// Input syms should be the innermost.
for (const auto& i : input_syms_) {
var sym = i.first;
if (!allocation_bounds_[sym]) continue;
if (!std::binary_search(buffer_used.begin(), buffer_used.end(), sym)) continue;
body.body = crop_buffer::make(sym, sym, *allocation_bounds_[sym], body.body);
statement_with_range body = build(base_f, here);

if (loop_index > 0) {
statement_with_range inner_loop = make_loop(base_f, loop_index - 1);
assert(body.body.defined() || inner_loop.body.defined());
if (body.body.defined() && inner_loop.body.defined()) {
body = statement_with_range::merge(body, inner_loop);
} else if (!body.body.defined() && inner_loop.body.defined()) {
body = inner_loop;
}
}
// Find which buffers are used inside of the body.
// TODO(vksnk): recomputing this seems really wasteful, we can should be
// able to maintain the list of buffers as we build the IR.
std::vector<var> buffer_used = find_buffer_dependencies(body.body);

// Followed by intermediate buffers in the reverse topological order
// (i.e. the outermost buffers are closer to the outputs of the pipeline).
for (auto i = order_.rbegin(); i != order_.rend(); ++i) {
const func* f = *i;
// Add crops for the used buffers using previously inferred bounds.
// Input syms should be the innermost.
for (const auto& i : input_syms_) {
var sym = i.first;
if (!allocation_bounds_[sym]) continue;
if (!std::binary_search(buffer_used.begin(), buffer_used.end(), sym)) continue;
body.body = crop_buffer::make(sym, sym, *allocation_bounds_[sym], body.body);
}

if (f == base_f) {
// Don't really need to emit buffer_crop for base_f, because they will
// have crop_dim anyway.
continue;
}
for (const func::output& o : f->outputs()) {
const buffer_expr_ptr& b = o.buffer;
if (!inferred_bounds_[b->sym()]) continue;
if (!std::binary_search(buffer_used.begin(), buffer_used.end(), b->sym())) continue;
body.body = crop_buffer::make(b->sym(), b->sym(), *inferred_bounds_[b->sym()], body.body);
}
}
// Followed by intermediate buffers in the reverse topological order
// (i.e. the outermost buffers are closer to the outputs of the pipeline).
for (auto i = order_.rbegin(); i != order_.rend(); ++i) {
const func* f = *i;

// The loop body is done, and we have an actual loop to make here. Crop the body.
body.body = crop_for_loop(body.body, base_f, loop);
// And make the actual loop.
expr loop_step = sanitizer_.mutate(loop.step);
interval_expr loop_bounds = get_loop_bounds(base_f, loop);
// Make sure that a loop variable is unique.
std::string loop_var_name;
if (base_f->outputs().size() == 1) {
loop_var_name = ctx.name(base_f->outputs()[0].sym()) + ".";
if (f == base_f) {
// Don't really need to emit buffer_crop for base_f, because they will
// have crop_dim anyway.
continue;
}
for (const func::output& o : f->outputs()) {
const buffer_expr_ptr& b = o.buffer;
if (!inferred_bounds_[b->sym()]) continue;
if (!std::binary_search(buffer_used.begin(), buffer_used.end(), b->sym())) continue;
body.body = crop_buffer::make(b->sym(), b->sym(), *inferred_bounds_[b->sym()], body.body);
}
loop_var_name += ctx.name(loop.sym());
var loop_var = ctx.insert_unique(loop_var_name);
body.body = substitute(body.body, loop.sym(), loop_var);
body.body = loop::make(loop_var, loop.max_workers, loop_bounds, loop_step, body.body);
}

return body;
}

// Generate the loops that we want to be explicit.
// Returns generated statement as well as the lifetime range covered by it.
statement_with_range make_loops(const func* f) {
statement_with_range result;
for (const auto& loop : f->loops()) {
result = make_loop(result, f, loop);
// The loop body is done, and we have an actual loop to make here. Crop the body.
body.body = crop_for_loop(body.body, base_f, loop);
// And make the actual loop.
expr loop_step = sanitizer_.mutate(loop.step);
interval_expr loop_bounds = get_loop_bounds(base_f, loop);
// Make sure that a loop variable is unique.
std::string loop_var_name;
if (base_f->outputs().size() == 1) {
loop_var_name = ctx.name(base_f->outputs()[0].sym()) + ".";
}
loop_var_name += ctx.name(loop.sym());
var loop_var = ctx.insert_unique(loop_var_name);
body.body = substitute(body.body, loop.sym(), loop_var);
body.body = loop::make(loop_var, loop.max_workers, loop_bounds, loop_step, body.body);

return result;
return body;
}

void compute_allocation_bounds() {
Expand Down Expand Up @@ -997,19 +992,19 @@ class pipeline_builder {

const std::vector<var>& external_symbols() const { return sanitizer_.external; }

// This function works together with the produce() and make_loops() functions
// This function works together with the produce() and make_loop() functions
// to build an initial IR. The high-level approach is the following:
// * the `build()` function looks through the list of func's
// to find funcs which need to be produced or allocated at given
// loop level `at`. If func need to be produced it calls the
// `produce()` function which actually produces the body of the
// func. If func has loops it calls the 'make_loops()' func to produce
// func. If func has loops it calls the 'make_loop()' func to produce
// corresponding loops.
// * the `produce()` for a given func produces it's body.
// * the `make_loops()` will produce the necessary loops defined for the function.
// * the `make_loop()` will produce the necessary loops defined for the function.
// For each of the new loops, the `build()` is called for the case when there
// are func which need to be produced in that new loop.
statement_with_range build(const statement_with_range& body, const func* base_f, const loop_id& at) {
statement_with_range build(const func* base_f, const loop_id& at) {
symbol_map<var> uncropped_subs;
std::vector<statement_with_range> results;
// Build the functions computed at this loop level.
Expand All @@ -1023,7 +1018,10 @@ class pipeline_builder {
assert(realize_at != realization_levels_.end());

if (compute_at->second == at && !f->loops().empty()) {
statement_with_range f_body = make_loops(f);
// Generate the loops that we want to be explicit by recursively calling make_loop starting
// from the outer loop.
statement_with_range f_body = make_loop(f, f->loops().size() - 1);

// This is a special case for the buffers which are produced and consumed inside
// of this loop. In this case we simply wrap loop body with corresponding allocations.
if (candidates_for_allocation_[at].size() > old_candidates.size() + 1) {
Expand Down Expand Up @@ -1165,18 +1163,12 @@ class pipeline_builder {
iteration_count++;
}

assert(!results.empty() || body.body.defined());
if (results.empty()) return {};

statement_with_range result;
if (results.empty()) {
result = body;
} else {
result = results.front();
for (std::size_t ix = 1; ix < results.size(); ix++) {
result = statement_with_range::merge(result, results[ix]);
}
if (body.body.defined()) {
result = statement_with_range::merge(result, body);
}
result = results.front();
for (std::size_t ix = 1; ix < results.size(); ix++) {
result = statement_with_range::merge(result, results[ix]);
}

// Add all remaining allocations at this loop level. The allocations can be added in any order. This order enables
Expand Down Expand Up @@ -1332,7 +1324,7 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
pipeline_builder builder(ctx, inputs, outputs);

stmt result;
result = builder.build({}, nullptr, loop_id()).body;
result = builder.build(nullptr, loop_id()).body;
result = builder.add_input_checks(result);
result = builder.make_buffers(result);
result = builder.define_sanitized_replacements(result);
Expand Down
8 changes: 2 additions & 6 deletions builder/test/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,7 @@ class softmax : public testing::TestWithParam<std::tuple<int, int, bool, int>> {
auto split_factors = testing::Values(0, 1, 4);

INSTANTIATE_TEST_SUITE_P(mode, softmax,
testing::Combine(split_factors, split_factors, testing::Values(false), testing::Values(0)),
test_params_to_string<softmax::ParamType>);

INSTANTIATE_TEST_SUITE_P(compute_at, softmax,
testing::Combine(testing::Values(1), testing::Values(1), testing::Values(false), testing::Values(0)),
testing::Combine(split_factors, split_factors, testing::Values(false, true), testing::Values(0)),
test_params_to_string<softmax::ParamType>);

INSTANTIATE_TEST_SUITE_P(with_copy, softmax,
Expand Down Expand Up @@ -178,7 +174,7 @@ TEST_P(softmax, pipeline) {
pass0.loops(loops);
pass4.loops(loops);

if (use_compute_at) {
if (use_compute_at && split_b > 0) {
pass1.compute_at({&pass4, b});
}

Expand Down

0 comments on commit 9e47253

Please sign in to comment.