From 9e472533d38d0e075fc5066eefd81b3985309692 Mon Sep 17 00:00:00 2001 From: Volodymyr Kysenko Date: Tue, 25 Feb 2025 13:21:22 -0800 Subject: [PATCH] Fix the order of the function during the IR construction. (#608) Apparently, the test which checked compute_at was disabled and was failing at head after I enabled it. The issue is that we first were visiting the inner loop which might violate the assumption that the function are added in the order of their occurance in the final IR (this is important so we can merge lifetimes correctly). --- builder/pipeline.cc | 152 +++++++++++++++++++--------------------- builder/test/softmax.cc | 8 +-- 2 files changed, 74 insertions(+), 86 deletions(-) diff --git a/builder/pipeline.cc b/builder/pipeline.cc index 242ca47e..b5849d80 100644 --- a/builder/pipeline.cc +++ b/builder/pipeline.cc @@ -559,17 +559,12 @@ class pipeline_builder { struct statement_with_range { stmt body; - int start; - int end; + int start = 0; + int end = 0; // Stores every allocation inside of the range. std::set allocations; static statement_with_range merge(const statement_with_range& a, const statement_with_range& b) { - if (a.start > b.end) { - // In case the range order is reversed. - return merge(b, a); - } - assert(a.end + 1 == b.start); statement_with_range r; r.body = block::make({a.body, b.body}); @@ -691,73 +686,73 @@ class pipeline_builder { return simplify(bounds); } - statement_with_range make_loop( - statement_with_range body, const func* base_f, const func::loop_info& loop = func::loop_info()) { + // Creates a loop body for a given function including all function bodies computed inside of the loops. + // It may recursively call itself if there are nested loops, it's assumed that loops are produced + // starting from the outer one. + statement_with_range make_loop(const func* base_f, int loop_index) { + const func::loop_info& loop = base_f->loops()[loop_index]; + assert(loop.defined()); loop_id here = {base_f, loop.var}; - body = build(body, base_f, here); - - if (loop.defined()) { - // Find which buffers are used inside of the body. - // TODO(vksnk): recomputing this seems really wasteful, we can should be - // able to maintain the list of buffers as we build the IR. - std::vector buffer_used = find_buffer_dependencies(body.body); - - // Add crops for the used buffers using previously inferred bounds. - // Input syms should be the innermost. - for (const auto& i : input_syms_) { - var sym = i.first; - if (!allocation_bounds_[sym]) continue; - if (!std::binary_search(buffer_used.begin(), buffer_used.end(), sym)) continue; - body.body = crop_buffer::make(sym, sym, *allocation_bounds_[sym], body.body); + statement_with_range body = build(base_f, here); + + if (loop_index > 0) { + statement_with_range inner_loop = make_loop(base_f, loop_index - 1); + assert(body.body.defined() || inner_loop.body.defined()); + if (body.body.defined() && inner_loop.body.defined()) { + body = statement_with_range::merge(body, inner_loop); + } else if (!body.body.defined() && inner_loop.body.defined()) { + body = inner_loop; } + } + // Find which buffers are used inside of the body. + // TODO(vksnk): recomputing this seems really wasteful, we can should be + // able to maintain the list of buffers as we build the IR. + std::vector buffer_used = find_buffer_dependencies(body.body); - // Followed by intermediate buffers in the reverse topological order - // (i.e. the outermost buffers are closer to the outputs of the pipeline). - for (auto i = order_.rbegin(); i != order_.rend(); ++i) { - const func* f = *i; + // Add crops for the used buffers using previously inferred bounds. + // Input syms should be the innermost. + for (const auto& i : input_syms_) { + var sym = i.first; + if (!allocation_bounds_[sym]) continue; + if (!std::binary_search(buffer_used.begin(), buffer_used.end(), sym)) continue; + body.body = crop_buffer::make(sym, sym, *allocation_bounds_[sym], body.body); + } - if (f == base_f) { - // Don't really need to emit buffer_crop for base_f, because they will - // have crop_dim anyway. - continue; - } - for (const func::output& o : f->outputs()) { - const buffer_expr_ptr& b = o.buffer; - if (!inferred_bounds_[b->sym()]) continue; - if (!std::binary_search(buffer_used.begin(), buffer_used.end(), b->sym())) continue; - body.body = crop_buffer::make(b->sym(), b->sym(), *inferred_bounds_[b->sym()], body.body); - } - } + // Followed by intermediate buffers in the reverse topological order + // (i.e. the outermost buffers are closer to the outputs of the pipeline). + for (auto i = order_.rbegin(); i != order_.rend(); ++i) { + const func* f = *i; - // The loop body is done, and we have an actual loop to make here. Crop the body. - body.body = crop_for_loop(body.body, base_f, loop); - // And make the actual loop. - expr loop_step = sanitizer_.mutate(loop.step); - interval_expr loop_bounds = get_loop_bounds(base_f, loop); - // Make sure that a loop variable is unique. - std::string loop_var_name; - if (base_f->outputs().size() == 1) { - loop_var_name = ctx.name(base_f->outputs()[0].sym()) + "."; + if (f == base_f) { + // Don't really need to emit buffer_crop for base_f, because they will + // have crop_dim anyway. + continue; + } + for (const func::output& o : f->outputs()) { + const buffer_expr_ptr& b = o.buffer; + if (!inferred_bounds_[b->sym()]) continue; + if (!std::binary_search(buffer_used.begin(), buffer_used.end(), b->sym())) continue; + body.body = crop_buffer::make(b->sym(), b->sym(), *inferred_bounds_[b->sym()], body.body); } - loop_var_name += ctx.name(loop.sym()); - var loop_var = ctx.insert_unique(loop_var_name); - body.body = substitute(body.body, loop.sym(), loop_var); - body.body = loop::make(loop_var, loop.max_workers, loop_bounds, loop_step, body.body); } - return body; - } - - // Generate the loops that we want to be explicit. - // Returns generated statement as well as the lifetime range covered by it. - statement_with_range make_loops(const func* f) { - statement_with_range result; - for (const auto& loop : f->loops()) { - result = make_loop(result, f, loop); + // The loop body is done, and we have an actual loop to make here. Crop the body. + body.body = crop_for_loop(body.body, base_f, loop); + // And make the actual loop. + expr loop_step = sanitizer_.mutate(loop.step); + interval_expr loop_bounds = get_loop_bounds(base_f, loop); + // Make sure that a loop variable is unique. + std::string loop_var_name; + if (base_f->outputs().size() == 1) { + loop_var_name = ctx.name(base_f->outputs()[0].sym()) + "."; } + loop_var_name += ctx.name(loop.sym()); + var loop_var = ctx.insert_unique(loop_var_name); + body.body = substitute(body.body, loop.sym(), loop_var); + body.body = loop::make(loop_var, loop.max_workers, loop_bounds, loop_step, body.body); - return result; + return body; } void compute_allocation_bounds() { @@ -997,19 +992,19 @@ class pipeline_builder { const std::vector& external_symbols() const { return sanitizer_.external; } - // This function works together with the produce() and make_loops() functions + // This function works together with the produce() and make_loop() functions // to build an initial IR. The high-level approach is the following: // * the `build()` function looks through the list of func's // to find funcs which need to be produced or allocated at given // loop level `at`. If func need to be produced it calls the // `produce()` function which actually produces the body of the - // func. If func has loops it calls the 'make_loops()' func to produce + // func. If func has loops it calls the 'make_loop()' func to produce // corresponding loops. // * the `produce()` for a given func produces it's body. - // * the `make_loops()` will produce the necessary loops defined for the function. + // * the `make_loop()` will produce the necessary loops defined for the function. // For each of the new loops, the `build()` is called for the case when there // are func which need to be produced in that new loop. - statement_with_range build(const statement_with_range& body, const func* base_f, const loop_id& at) { + statement_with_range build(const func* base_f, const loop_id& at) { symbol_map uncropped_subs; std::vector results; // Build the functions computed at this loop level. @@ -1023,7 +1018,10 @@ class pipeline_builder { assert(realize_at != realization_levels_.end()); if (compute_at->second == at && !f->loops().empty()) { - statement_with_range f_body = make_loops(f); + // Generate the loops that we want to be explicit by recursively calling make_loop starting + // from the outer loop. + statement_with_range f_body = make_loop(f, f->loops().size() - 1); + // This is a special case for the buffers which are produced and consumed inside // of this loop. In this case we simply wrap loop body with corresponding allocations. if (candidates_for_allocation_[at].size() > old_candidates.size() + 1) { @@ -1165,18 +1163,12 @@ class pipeline_builder { iteration_count++; } - assert(!results.empty() || body.body.defined()); + if (results.empty()) return {}; + statement_with_range result; - if (results.empty()) { - result = body; - } else { - result = results.front(); - for (std::size_t ix = 1; ix < results.size(); ix++) { - result = statement_with_range::merge(result, results[ix]); - } - if (body.body.defined()) { - result = statement_with_range::merge(result, body); - } + result = results.front(); + for (std::size_t ix = 1; ix < results.size(); ix++) { + result = statement_with_range::merge(result, results[ix]); } // Add all remaining allocations at this loop level. The allocations can be added in any order. This order enables @@ -1332,7 +1324,7 @@ stmt build_pipeline(node_context& ctx, const std::vector& input pipeline_builder builder(ctx, inputs, outputs); stmt result; - result = builder.build({}, nullptr, loop_id()).body; + result = builder.build(nullptr, loop_id()).body; result = builder.add_input_checks(result); result = builder.make_buffers(result); result = builder.define_sanitized_replacements(result); diff --git a/builder/test/softmax.cc b/builder/test/softmax.cc index ca0e04d1..a3535c32 100644 --- a/builder/test/softmax.cc +++ b/builder/test/softmax.cc @@ -103,11 +103,7 @@ class softmax : public testing::TestWithParam> { auto split_factors = testing::Values(0, 1, 4); INSTANTIATE_TEST_SUITE_P(mode, softmax, - testing::Combine(split_factors, split_factors, testing::Values(false), testing::Values(0)), - test_params_to_string); - -INSTANTIATE_TEST_SUITE_P(compute_at, softmax, - testing::Combine(testing::Values(1), testing::Values(1), testing::Values(false), testing::Values(0)), + testing::Combine(split_factors, split_factors, testing::Values(false, true), testing::Values(0)), test_params_to_string); INSTANTIATE_TEST_SUITE_P(with_copy, softmax, @@ -178,7 +174,7 @@ TEST_P(softmax, pipeline) { pass0.loops(loops); pass4.loops(loops); - if (use_compute_at) { + if (use_compute_at && split_b > 0) { pass1.compute_at({&pass4, b}); }