Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the order of the function during the IR construction. #608

Merged
merged 7 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 72 additions & 80 deletions builder/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,17 +559,12 @@ class pipeline_builder {

struct statement_with_range {
stmt body;
int start;
int end;
int start = 0;
int end = 0;
// Stores every allocation inside of the range.
std::set<var> allocations;

static statement_with_range merge(const statement_with_range& a, const statement_with_range& b) {
if (a.start > b.end) {
// In case the range order is reversed.
return merge(b, a);
}

assert(a.end + 1 == b.start);
statement_with_range r;
r.body = block::make({a.body, b.body});
Expand Down Expand Up @@ -691,73 +686,73 @@ class pipeline_builder {
return simplify(bounds);
}

statement_with_range make_loop(
statement_with_range body, const func* base_f, const func::loop_info& loop = func::loop_info()) {
// Creates a loop body for a given function, including all function bodies computed inside of the loops.
// It may recursively call itself if there are nested loops; it's assumed that loops are produced
// starting from the outer one.
statement_with_range make_loop(const func* base_f, int loop_index) {
const func::loop_info& loop = base_f->loops()[loop_index];
assert(loop.defined());
loop_id here = {base_f, loop.var};

body = build(body, base_f, here);

if (loop.defined()) {
// Find which buffers are used inside of the body.
// TODO(vksnk): recomputing this seems really wasteful, we should be
// able to maintain the list of buffers as we build the IR.
std::vector<var> buffer_used = find_buffer_dependencies(body.body);

// Add crops for the used buffers using previously inferred bounds.
// Input syms should be the innermost.
for (const auto& i : input_syms_) {
var sym = i.first;
if (!allocation_bounds_[sym]) continue;
if (!std::binary_search(buffer_used.begin(), buffer_used.end(), sym)) continue;
body.body = crop_buffer::make(sym, sym, *allocation_bounds_[sym], body.body);
statement_with_range body = build(base_f, here);

if (loop_index > 0) {
statement_with_range inner_loop = make_loop(base_f, loop_index - 1);
assert(body.body.defined() || inner_loop.body.defined());
if (body.body.defined() && inner_loop.body.defined()) {
body = statement_with_range::merge(body, inner_loop);
} else if (!body.body.defined() && inner_loop.body.defined()) {
body = inner_loop;
}
}
// Find which buffers are used inside of the body.
// TODO(vksnk): recomputing this seems really wasteful, we should be
// able to maintain the list of buffers as we build the IR.
std::vector<var> buffer_used = find_buffer_dependencies(body.body);

// Followed by intermediate buffers in the reverse topological order
// (i.e. the outermost buffers are closer to the outputs of the pipeline).
for (auto i = order_.rbegin(); i != order_.rend(); ++i) {
const func* f = *i;
// Add crops for the used buffers using previously inferred bounds.
// Input syms should be the innermost.
for (const auto& i : input_syms_) {
var sym = i.first;
if (!allocation_bounds_[sym]) continue;
if (!std::binary_search(buffer_used.begin(), buffer_used.end(), sym)) continue;
body.body = crop_buffer::make(sym, sym, *allocation_bounds_[sym], body.body);
}

if (f == base_f) {
// Don't really need to emit buffer_crop for base_f, because they will
// have crop_dim anyway.
continue;
}
for (const func::output& o : f->outputs()) {
const buffer_expr_ptr& b = o.buffer;
if (!inferred_bounds_[b->sym()]) continue;
if (!std::binary_search(buffer_used.begin(), buffer_used.end(), b->sym())) continue;
body.body = crop_buffer::make(b->sym(), b->sym(), *inferred_bounds_[b->sym()], body.body);
}
}
// Followed by intermediate buffers in the reverse topological order
// (i.e. the outermost buffers are closer to the outputs of the pipeline).
for (auto i = order_.rbegin(); i != order_.rend(); ++i) {
const func* f = *i;

// The loop body is done, and we have an actual loop to make here. Crop the body.
body.body = crop_for_loop(body.body, base_f, loop);
// And make the actual loop.
expr loop_step = sanitizer_.mutate(loop.step);
interval_expr loop_bounds = get_loop_bounds(base_f, loop);
// Make sure that a loop variable is unique.
std::string loop_var_name;
if (base_f->outputs().size() == 1) {
loop_var_name = ctx.name(base_f->outputs()[0].sym()) + ".";
if (f == base_f) {
// Don't really need to emit buffer_crop for base_f, because they will
// have crop_dim anyway.
continue;
}
for (const func::output& o : f->outputs()) {
const buffer_expr_ptr& b = o.buffer;
if (!inferred_bounds_[b->sym()]) continue;
if (!std::binary_search(buffer_used.begin(), buffer_used.end(), b->sym())) continue;
body.body = crop_buffer::make(b->sym(), b->sym(), *inferred_bounds_[b->sym()], body.body);
}
loop_var_name += ctx.name(loop.sym());
var loop_var = ctx.insert_unique(loop_var_name);
body.body = substitute(body.body, loop.sym(), loop_var);
body.body = loop::make(loop_var, loop.max_workers, loop_bounds, loop_step, body.body);
}

return body;
}

// Generate the loops that we want to be explicit.
// Returns generated statement as well as the lifetime range covered by it.
statement_with_range make_loops(const func* f) {
statement_with_range result;
for (const auto& loop : f->loops()) {
result = make_loop(result, f, loop);
// The loop body is done, and we have an actual loop to make here. Crop the body.
body.body = crop_for_loop(body.body, base_f, loop);
// And make the actual loop.
expr loop_step = sanitizer_.mutate(loop.step);
interval_expr loop_bounds = get_loop_bounds(base_f, loop);
// Make sure that a loop variable is unique.
std::string loop_var_name;
if (base_f->outputs().size() == 1) {
loop_var_name = ctx.name(base_f->outputs()[0].sym()) + ".";
}
loop_var_name += ctx.name(loop.sym());
var loop_var = ctx.insert_unique(loop_var_name);
body.body = substitute(body.body, loop.sym(), loop_var);
body.body = loop::make(loop_var, loop.max_workers, loop_bounds, loop_step, body.body);

return result;
return body;
}

void compute_allocation_bounds() {
Expand Down Expand Up @@ -997,19 +992,19 @@ class pipeline_builder {

const std::vector<var>& external_symbols() const { return sanitizer_.external; }

// This function works together with the produce() and make_loops() functions
// This function works together with the produce() and make_loop() functions
// to build an initial IR. The high-level approach is the following:
// * the `build()` function looks through the list of func's
// to find funcs which need to be produced or allocated at given
// loop level `at`. If a func needs to be produced, it calls the
// `produce()` function which actually produces the body of the
// func. If func has loops it calls the 'make_loops()' func to produce
// func. If func has loops it calls the 'make_loop()' func to produce
// corresponding loops.
// * the `produce()` for a given func produces its body.
// * the `make_loops()` will produce the necessary loops defined for the function.
// * the `make_loop()` will produce the necessary loops defined for the function.
// For each of the new loops, the `build()` is called for the case when there
// are funcs which need to be produced in that new loop.
statement_with_range build(const statement_with_range& body, const func* base_f, const loop_id& at) {
statement_with_range build(const func* base_f, const loop_id& at) {
symbol_map<var> uncropped_subs;
std::vector<statement_with_range> results;
// Build the functions computed at this loop level.
Expand All @@ -1023,7 +1018,10 @@ class pipeline_builder {
assert(realize_at != realization_levels_.end());

if (compute_at->second == at && !f->loops().empty()) {
statement_with_range f_body = make_loops(f);
// Generate the loops that we want to be explicit by recursively calling make_loop starting
// from the outer loop.
statement_with_range f_body = make_loop(f, f->loops().size() - 1);

// This is a special case for the buffers which are produced and consumed inside
// of this loop. In this case we simply wrap loop body with corresponding allocations.
if (candidates_for_allocation_[at].size() > old_candidates.size() + 1) {
Expand Down Expand Up @@ -1165,18 +1163,12 @@ class pipeline_builder {
iteration_count++;
}

assert(!results.empty() || body.body.defined());
if (results.empty()) return {};

statement_with_range result;
if (results.empty()) {
result = body;
} else {
result = results.front();
for (std::size_t ix = 1; ix < results.size(); ix++) {
result = statement_with_range::merge(result, results[ix]);
}
if (body.body.defined()) {
result = statement_with_range::merge(result, body);
}
result = results.front();
for (std::size_t ix = 1; ix < results.size(); ix++) {
result = statement_with_range::merge(result, results[ix]);
}

// Add all remaining allocations at this loop level. The allocations can be added in any order. This order enables
Expand Down Expand Up @@ -1332,7 +1324,7 @@ stmt build_pipeline(node_context& ctx, const std::vector<buffer_expr_ptr>& input
pipeline_builder builder(ctx, inputs, outputs);

stmt result;
result = builder.build({}, nullptr, loop_id()).body;
result = builder.build(nullptr, loop_id()).body;
result = builder.add_input_checks(result);
result = builder.make_buffers(result);
result = builder.define_sanitized_replacements(result);
Expand Down
8 changes: 2 additions & 6 deletions builder/test/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,7 @@ class softmax : public testing::TestWithParam<std::tuple<int, int, bool, int>> {
auto split_factors = testing::Values(0, 1, 4);

INSTANTIATE_TEST_SUITE_P(mode, softmax,
testing::Combine(split_factors, split_factors, testing::Values(false), testing::Values(0)),
test_params_to_string<softmax::ParamType>);

INSTANTIATE_TEST_SUITE_P(compute_at, softmax,
testing::Combine(testing::Values(1), testing::Values(1), testing::Values(false), testing::Values(0)),
testing::Combine(split_factors, split_factors, testing::Values(false, true), testing::Values(0)),
test_params_to_string<softmax::ParamType>);

INSTANTIATE_TEST_SUITE_P(with_copy, softmax,
Expand Down Expand Up @@ -178,7 +174,7 @@ TEST_P(softmax, pipeline) {
pass0.loops(loops);
pass4.loops(loops);

if (use_compute_at) {
if (use_compute_at && split_b > 0) {
pass1.compute_at({&pass4, b});
}

Expand Down
Loading