From 31e8c012852282ab442bcd2aa59194ec46f5debc Mon Sep 17 00:00:00 2001
From: Shuhao Liang <50269654+lshpku@users.noreply.github.com>
Date: Wed, 8 Jan 2025 16:31:23 +0800
Subject: [PATCH] [CINN] Implement the new AlignIterSpaceTactic (#70649)

---
 .../dy_shape_group_scheduler.cc           |   2 +
 .../tactic/align_iter_space_tactic.cc     | 227 +++++++++++++-----
 .../tactic/align_iter_space_tactic.h      |   4 +-
 .../tactic/tile_first_general_tactic.cc   |  44 ----
 4 files changed, 175 insertions(+), 102 deletions(-)

diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
index e533e35c67663b..758464d5d21857 100644
--- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
+++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
@@ -16,6 +16,7 @@
 #include "paddle/cinn/common/cas.h"
 #include "paddle/cinn/hlir/framework/pir/trivial_op_impl.h"
 #include "paddle/cinn/ir/group_schedule/config/schedule_config_manager.h"
+#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
 #include "paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.h"
@@ -33,6 +34,7 @@ void DynamicShapeGroupScheduler::Init() {
   VLOG(4) << "original group func body: \n"
           << ir_sch_->GetModule().GetExprs()[0];
   InitBuckets();
+  tactics_.emplace_back(CreateAlignIterSpaceTactic());
   tactics_.emplace_back(CreateTileBroadcastTactic());
   tactics_.emplace_back(CreateTileFirstGeneralTactic());
   tactics_.emplace_back(CreateComputeInlineTactic());
diff --git a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc
index dcc72e4a217d82..3476755d2460be 100644
--- a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc
+++ b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+// Copyright (c) 2025 CINN Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -13,89 +13,206 @@
 // limitations under the License.
 
 #include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
-#include "paddle/cinn/common/cas.h"
-#include "paddle/cinn/common/integer_set.h"
-#include "paddle/cinn/ir/ir.h"
 #include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
-#include "paddle/cinn/ir/op/ir_operators.h"
-#include "paddle/cinn/ir/utils/ir_copy.h"
 
 namespace cinn {
 namespace ir {
+namespace {
+/**
+ * Reorder the loops according to the memory-consistent order of the input or
+ * output to make memory access as coalesced as possible.
+ *
+ * This tactic uses different alignment policies for Reduce and Trivial:
+ * 1) Reduce: align with the input, because after reduction, the output data is
+ *    significantly smaller than the input data, so it's more critical to make
+ *    the input access coalesced.
+ * 2) Trivial: align with the output, because discrete writes incur higher
+ *    costs than discrete reads for the same volume of data due to the hardware
+ *    design of the cache. Therefore, we should prioritize coalesced writes.
+ *
+ * Note: we reorder spatial and reduce loops separately, because we need to
+ * maintain the relative order between spatial and reduce loops so that later
+ * tactics can work properly.
+ * Thus, we use two lists, sp_loop_perm & rd_loop_perm, to record the
+ * permutation of spatial and reduce loops respectively.
+ *
+ * Examples:
+ * 1. Reduce
+ *    Input:
+ *      for (i, 0, 8):          # S
+ *        for (j, 0, 32):       # S
+ *          for (k, 0, 128):    # R
+ *            for (a, 0, 256):  # R
+ *              var_1[i, j] += var_0[j, a, k, i]
+ *    Analysis:
+ *      We align Reduce to the input `var_0[j, a, k, i]`. In the indices of
+ *      var_0, the mapping from each index to the loop index is:
+ *        indices[0] = j => loops[1]  # S
+ *        indices[1] = a => loops[3]  # R
+ *        indices[2] = k => loops[2]  # R
+ *        indices[3] = i => loops[0]  # S
+ *      To make the indices of var_0 consistent with its original memory
+ *      layout, we need to permute the loops in the order {1, 3, 2, 0}.
+ *      However, as we reorder spatial and reduce loops separately, we split
+ *      the permutation into sp & rd, getting sp_loop_perm = {1, 0} and
+ *      rd_loop_perm = {3, 2}.
+ *    Output:
+ *      for (j, 0, 32):         # S
+ *        for (i, 0, 8):        # S
+ *          for (a, 0, 256):    # R
+ *            for (k, 0, 128):  # R
+ *              var_1[i, j] += var_0[j, a, k, i]
+ *
+ * 2. Trivial
+ *    Input:
+ *      for (i, 0, 32):
+ *        for (j, 0, 128):
+ *          for (k, 0, 256):
+ *            var_1[k, i, j] = exp(var_0[j, i, k])
+ *    Analysis:
+ *      We align Trivial to the output `var_1[k, i, j]`. In the indices of
+ *      var_1, the mapping from each index to the loop index is:
+ *        indices[0] = k => loops[2]
+ *        indices[1] = i => loops[0]
+ *        indices[2] = j => loops[1]
+ *      As in example 1, we should permute the loops in the order {2, 0, 1}.
+ *      Since this graph doesn't contain reduce loops, all we get is
+ *      sp_loop_perm = {2, 0, 1} and rd_loop_perm = {}.
+ *    Output:
+ *      for (k, 0, 256):
+ *        for (i, 0, 32):
+ *          for (j, 0, 128):
+ *            var_1[k, i, j] = exp(var_0[j, i, k])
+ */
 class AlignIterSpaceTactic final : public ScheduleTactic {
  public:
-  void Init(ScheduleContext* context) override;
+  void Init(ScheduleContext* context, ir::IRSchedule* sch) override;
   void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
 
   std::string TacticName() const override { return "AlignIterSpaceTactic"; }
 
+ private:
+  /**
+   * Get the common memory-consistent order of loops according to the outputs.
+   * Returns an empty permutation if not all outputs share the same order.
+   */
+  std::vector<int> GetCommonOutputLoopPerm(ir::IRSchedule* sch);
+
  private:
   ScheduleContext* context_;
+
+  // The permutation of spatial and reduce loops, in order to achieve the
+  // memory-consistent alignment.
+  std::vector<int> sp_loop_perm_;
+  std::vector<int> rd_loop_perm_;
 };
 
-void AlignIterSpaceTactic::Init(ScheduleContext* context) {
+void AlignIterSpaceTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
   context_ = context;
-}
+  sp_loop_perm_.clear();
+  rd_loop_perm_.clear();
 
-void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch,
-                                 const std::string& block_id) {
-  ir::Expr block = sch->GetBlock(block_id);
+  auto& loop_strides = context_->config.base_info->loop_strides;
+  auto& reduce_axis = context_->config.base_info->reduce_axis;
+  std::set<int64_t> reduce_axis_set(reduce_axis.begin(), reduce_axis.end());
 
-  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
-  ir::Expr src_total_extent{1};
-  for (const auto& loop : loops) {
-    src_total_extent = src_total_extent * loop.As<ir::For>()->extent;
-  }
-  ir::Expr target_sp_extent{1};
-  for (const auto& iter : context_->iter_space_info.sp_space) {
-    target_sp_extent = target_sp_extent * std::get<0>(iter);
+  if (!loop_strides.empty()) {
+    // If this is a Reduce, calculate the loop_perm by sorting the loops in the
+    // descending order of their strides according to the input, then split the
+    // loop_perm into sp_loop_perm & rd_loop_perm.
+    std::vector<int> loop_perm(loop_strides.size());
+    std::iota(loop_perm.begin(), loop_perm.end(), 0);
+    std::stable_sort(loop_perm.begin(), loop_perm.end(), [&](int a, int b) {
+      return loop_strides[a] > loop_strides[b];
+    });
+
+    for (int axis : loop_perm) {
+      if (reduce_axis_set.count(axis) > 0) {
+        rd_loop_perm_.push_back(axis);
+      } else if (loop_strides[axis] != 0) {
+        sp_loop_perm_.push_back(axis);
+      }
+    }
+  } else {
+    // If this is a Trivial, calculate the sp_loop_perm according to the output.
+    sp_loop_perm_ = GetCommonOutputLoopPerm(sch);
   }
-  ir::Expr target_total_extent = ir_utils::IRCopy(target_sp_extent);
-  for (const auto& iter : context_->iter_space_info.rb_space) {
-    target_total_extent = target_total_extent * std::get<0>(iter);
+
+  VLOG(4) << "AlignIterSpaceTactic:\n"
+          << "sp_loop_perm: " << utils::Join(sp_loop_perm_, ", ") << "\n"
+          << "rd_loop_perm: " << utils::Join(rd_loop_perm_, ", ");
+}
+
+std::unordered_map<ir::Var, int> GetLoopVarToIndex(
+    const std::vector<ir::Expr>& loops) {
+  std::unordered_map<ir::Var, int> loop_var2index;
+  for (int i = 0; i < loops.size(); ++i) {
+    auto* node = loops[i].As<ir::For>();
+    loop_var2index[node->loop_var] = i;
   }
+  return loop_var2index;
+}
 
-  common::cas_intervals_t var_intervals;
-  common::SymbolicExprAnalyzer symbolic_expr_analyzer(var_intervals);
-  std::optional<bool> total_extent_eq =
-      symbolic_expr_analyzer.ProveEQ(src_total_extent, target_total_extent);
-  bool need_reorder = false;
-  for (int i = 0; i < context_->iter_space_info.rb_last_order.size(); ++i) {
-    if (context_->iter_space_info.rb_last_order[i] != i) {
-      need_reorder = true;
-      break;
-    }
+/**
+ * Check whether this is an effective permutation.
+ * A permutation is ineffective if it's entirely in ascending order.
+ */
+bool IsPermutationEffective(const std::vector<int>& perm) {
+  for (int i = 1; i < perm.size(); ++i) {
+    if (perm[i - 1] > perm[i]) return true;
   }
+  return false;
+}
 
-  if (total_extent_eq.has_value() && total_extent_eq.value()) {
-    if (need_reorder) {
-      sch->Reorder(block_id, context_->iter_space_info.rb_last_order);
-    }
-    if (context_->iter_space_info.sp_space.size() < loops.size() - 1) {
-      loops = sch->GetLoops(block_id);
-
-      // Align the loop in the current block that needs to be aligned with the
-      // reduce loop in iter_space_info
-      std::vector<ir::Expr> rb_loops(
-          loops.end() - context_->iter_space_info.rb_space.size(), loops.end());
-      sch->Fuse(rb_loops);
+std::vector<int> AlignIterSpaceTactic::GetCommonOutputLoopPerm(
+    ir::IRSchedule* sch) {
+  std::vector<int> common_loop_perm;
+
+  for (auto& block : sch->GetAllBlocks()) {
+    std::string block_id = ir::analyzer::GetBlockName(block);
+    if (context_->output_names.count(block_id) == 0) continue;
+
+    auto store = ir::analyzer::GetStoreOfSBlock(block);
+    auto& indices = store.As<ir::Store>()->indices;
+    std::unordered_map<ir::Var, ir::Expr> iter_var2iter_value =
+        ir::analyzer::GetIterVarToValueOfSBlock(block);
+    std::unordered_map<ir::Var, int> loop_var2index =
+        GetLoopVarToIndex(sch->GetLoops(block));
+
+    std::vector<int> loop_perm;
+    for (auto& index : indices) {
+      if (index.is_constant()) continue;
+      if (!index.is_var()) return {};
+      ir::Expr iter_value = iter_var2iter_value[index.as_var_ref()];
+      if (!iter_value.is_var()) return {};
+      ir::Var loop_var = iter_value.as_var_ref();
+      loop_perm.push_back(loop_var2index[loop_var]);
     }
-    if (context_->iter_space_info.sp_space.size() > 1) {
-      // Align the loop in the current block that needs to be aligned with the
-      // spatial loop in iter_space_info
-      loops = sch->GetLoops(block_id);
-      std::vector<ir::Expr> sp_loops(
-          loops.begin(),
-          loops.end() - context_->iter_space_info.rb_space.size());
-      sch->Fuse(sp_loops);
+
+    if (common_loop_perm.empty()) {
+      common_loop_perm = std::move(loop_perm);
+    } else if (common_loop_perm != loop_perm) {
+      return {};
     }
-  } else {
-    sch->Fuse(loops);
   }
+
+  return common_loop_perm;
 }
 
+void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch,
+                                 const std::string& block_id) {
+  if (ir::IsReduceInitTensorName(block_id)) return;
+  if (IsPermutationEffective(sp_loop_perm_)) {
+    sch->Reorder(block_id, sp_loop_perm_);
+  }
+  if (IsPermutationEffective(rd_loop_perm_)) {
+    sch->Reorder(block_id, rd_loop_perm_);
+  }
+}
+
+}  // namespace
+
 std::unique_ptr<ScheduleTactic> CreateAlignIterSpaceTactic() {
   return std::make_unique<AlignIterSpaceTactic>();
 }
diff --git a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h
index 2ac65d114c7f51..12891818120712 100644
--- a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h
+++ b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2023 CINN Authors. All Rights Reserved.
+// Copyright (c) 2025 CINN Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -11,10 +11,8 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
 #pragma once
 
-#include 
 #include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"
 
 namespace cinn {
diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
index e71e0052a3803c..1022c97420e7cc 100644
--- a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
+++ b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
@@ -48,7 +48,6 @@ class TileFirstGeneralTactic final : public ScheduleTactic {
   std::string TacticName() const override { return "TileFirstGeneralTactic"; }
 
  private:
-  void AlignToReduceInput(ir::IRSchedule* sch, const std::string& block_id);
   void MergeFlattenAxis(ir::IRSchedule* sch, const std::string& block_id);
   void MergeDiscreteFlattenAxis(ir::IRSchedule* sch,
                                 const std::string& block_id);
@@ -128,11 +127,6 @@ void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
   if (!can_apply_) return;
   if (ir::IsReduceInitTensorName(block_id)) return;
 
-  AlignToReduceInput(sch, block_id);
-  VLOG(6) << "After AlignToReduceInput on block: [" << block_id
-          << "], loop nest:\n"
-          << sch->GetLoops(block_id)[0];
-
   if (UseContinuousDataTile(context_->config)) {
     VLOG(4) << "Using ApplyContinuousDataTile";
     ApplyContinuousDataTile(sch, block_id);
@@ -293,44 +287,6 @@ void TileFirstGeneralTactic::ApplyContinuousDataTile(
   SetReduceType(sch, block_id);
 }
 
-void TileFirstGeneralTactic::AlignToReduceInput(ir::IRSchedule* sch,
-                                                const std::string& block_id) {
-  const auto& loop_strides = context_->config.base_info->loop_strides;
-  if (loop_strides.empty()) {
-    return;
-  }
-
-  std::vector<ir::Expr> loops = sch->GetLoops(block_id);
-  std::vector<int64_t> loop_perm(loops.size());
-  std::iota(loop_perm.begin(), loop_perm.end(), 0);
-
-  const auto IsReduce = [&](int64_t axis) {
-    auto& reduce_axis = context_->config.base_info->reduce_axis;
-    return std::find(reduce_axis.begin(), reduce_axis.end(), axis) !=
-           reduce_axis.end();
-  };
-
-  std::sort(loop_perm.begin(), loop_perm.end(), [&](int64_t a, int64_t b) {
-    if (IsReduce(a) == IsReduce(b)) {
-      return loop_strides[a] > loop_strides[b];
-    }
-    return IsReduce(b);
-  });
-  VLOG(4) << "loop_perm: " << utils::Join(loop_perm, ", ");
-
-  // Reorder S/R loops seperately, otherwise reduce_init will be de-inlined.
-  std::vector<ir::Expr> sp_loops, rd_loops;
-  for (auto i : loop_perm) {
-    if (IsReduce(i)) {
-      rd_loops.push_back(loops[i]);
-    } else if (loop_strides[i] != 0) {
-      sp_loops.push_back(loops[i]);
-    }
-  }
-  sch->Reorder(sp_loops);
-  sch->Reorder(rd_loops);
-}
-
 void TileFirstGeneralTactic::MergeFlattenAxis(ir::IRSchedule* sch,
                                               const std::string& block_id) {
   if (vec_flatten_axis_.size() >= 2) {
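The stride-descending sort in the Reduce path of Init() is easiest to follow on concrete numbers. The standalone C++ sketch below is illustration only, not part of the patch, and deliberately avoids CINN types: the stride constants are hand-derived assumptions for the access var_0[j, a, k, i] with loop extents {i: 8, j: 32, k: 128, a: 256} from Example 1 in the comment, standing in for config.base_info->loop_strides.

// illustrative sketch; mirrors the Reduce branch of AlignIterSpaceTactic::Init
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <set>
#include <vector>

int main() {
  // Per-loop strides of the reduce input var_0[j, a, k, i], indexed by loop
  // position {i, j, k, a}: i -> 1, j -> 256*128*8, k -> 8, a -> 128*8.
  std::vector<int64_t> loop_strides = {1, 262144, 8, 1024};
  std::set<int> reduce_axis_set = {2, 3};  // k and a are reduce loops.

  // Sort loop indices by descending stride of the input access.
  std::vector<int> loop_perm(loop_strides.size());
  std::iota(loop_perm.begin(), loop_perm.end(), 0);
  std::stable_sort(loop_perm.begin(), loop_perm.end(), [&](int a, int b) {
    return loop_strides[a] > loop_strides[b];
  });

  // Split into spatial and reduce permutations so that the relative order
  // between spatial and reduce loops is preserved.
  std::vector<int> sp_loop_perm, rd_loop_perm;
  for (int axis : loop_perm) {
    if (reduce_axis_set.count(axis) > 0) {
      rd_loop_perm.push_back(axis);
    } else if (loop_strides[axis] != 0) {
      sp_loop_perm.push_back(axis);
    }
  }

  // Prints: sp_loop_perm = 1 0 and rd_loop_perm = 3 2, matching Example 1.
  std::cout << "sp_loop_perm =";
  for (int v : sp_loop_perm) std::cout << ' ' << v;
  std::cout << "\nrd_loop_perm =";
  for (int v : rd_loop_perm) std::cout << ' ' << v;
  std::cout << '\n';
  return 0;
}

The output reproduces the permutation derived in the comment; applying the two permutations separately via Reorder is what keeps spatial loops ahead of reduce loops so that later tactics can still recognize the S/R structure.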