Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Implement the new AlignIterSpaceTactic #70649

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/hlir/framework/pir/trivial_op_impl.h"
#include "paddle/cinn/ir/group_schedule/config/schedule_config_manager.h"
#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h"
#include "paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.h"
Expand All @@ -33,6 +34,7 @@ void DynamicShapeGroupScheduler::Init() {
VLOG(4) << "original group func body: \n"
<< ir_sch_->GetModule().GetExprs()[0];
InitBuckets();
tactics_.emplace_back(CreateAlignIterSpaceTactic());
tactics_.emplace_back(CreateTileBroadcastTactic());
tactics_.emplace_back(CreateTileFirstGeneralTactic());
tactics_.emplace_back(CreateComputeInlineTactic());
lshpku marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
227 changes: 172 additions & 55 deletions paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
// Copyright (c) 2025 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -13,89 +13,206 @@
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h"

#include <algorithm>
#include <numeric>

#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/integer_set.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"

namespace cinn {
namespace ir {
namespace {

/**
* Reorder the loops according to the memory-consistent order of input or output
* to make memory access as coalesced as possible.
*
* This tactic uses different alignment policies for Reduce and Trivial:
* 1) Reduce: align with the input, because after reduction, the output data is
* significantly smaller than the input data, so it's more critical to make
* input coalesced.
* 2) Trivial: align with the output, because discrete writes incur higher costs
* than discrete reads for the same volume of data due to the hardware design
* of cache. Therefore, we should ensure coalesced writes in priority.
*
 * Note: we reorder spatial and reduce loops separately, because we need to
* maintain the relative order between spatial and reduce loops, so as for later
* tactics to work properly. Thus, we use two lists sp_loop_perm & rd_loop_perm
* to record the permutation of spatial and reduce loops respectively.
*
*
* Examples:
* 1. Reduce
* Input:
* for (i, 0, 8): # S
* for (j, 0, 32): # S
* for (k, 0, 128): # R
* for (a, 0, 256): # R
* var_1[i, j] += var_0[j, a, k, i]
* Analysis:
* We align Reduce to the input `var_0[j, a, k, i]`. In the indices of var_0,
* the mapping from each index to the loop index is:
* indices[0] = j => loops[1] # S
* indices[1] = a => loops[3] # R
* indices[2] = k => loops[2] # R
* indices[3] = i => loops[0] # S
* To make the indices of var_0 consistent with its original memory layout, we
* need to permute the loops in the order {1, 3, 2, 0}. However, as we reorder
 * spatial and reduce loops separately, we split the permutation into sp & rd,
* getting sp_loop_perm = {1, 0} and rd_loop_perm = {3, 2}.
* Output:
* for (j, 0, 32): # S
* for (i, 0, 8): # S
* for (a, 0, 256): # R
* for (k, 0, 128): # R
* var_1[i, j] += var_0[j, a, k, i]
*
* 2. Trivial
* Input:
* for (i, 0, 32):
* for (j, 0, 128):
* for (k, 0, 256):
* var_1[k, i, j] = exp(var_0[j, i, k])
* Analysis:
* We align Trivial to the output `var_1[k, i, j]`. In the indices of var_1,
* the mapping from each index to the loop index is:
* indices[0] = k => loops[2]
* indices[1] = i => loops[0]
* indices[2] = j => loops[1]
* Like example 1, we should permute the loops in the order {2, 0, 1}. As this
* graph doesn't contain reduce loops, all we get is sp_loop_perm = {2, 0, 1},
* and rd_loop_perm = {}.
* Output:
* for (k, 0, 256):
* for (i, 0, 32):
* for (j, 0, 128):
* var_1[k, i, j] = exp(var_0[j, i, k])
*/
class AlignIterSpaceTactic final : public ScheduleTactic {
public:
void Init(ScheduleContext* context) override;
void Init(ScheduleContext* context, ir::IRSchedule* sch) override;

void Apply(ir::IRSchedule* sch, const std::string& block_id) override;

std::string TacticName() const override { return "AlignIterSpaceTactic"; }

private:
/**
* Get the common memory-consistent order of loops according to the outputs.
* Returns null if not all outputs share the same order.
*/
std::vector<int> GetCommonOutputLoopPerm(ir::IRSchedule* sch);

private:
ScheduleContext* context_;

// The permutation of spatial and reduce loops, in order to achieve the
// memory-consistent alignment.
std::vector<int> sp_loop_perm_;
std::vector<int> rd_loop_perm_;
};

void AlignIterSpaceTactic::Init(ScheduleContext* context) {
void AlignIterSpaceTactic::Init(ScheduleContext* context, ir::IRSchedule* sch) {
context_ = context;
}
sp_loop_perm_.clear();
rd_loop_perm_.clear();

void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch,
const std::string& block_id) {
ir::Expr block = sch->GetBlock(block_id);
auto& loop_strides = context_->config.base_info->loop_strides;
auto& reduce_axis = context_->config.base_info->reduce_axis;
std::set<int> reduce_axis_set(reduce_axis.begin(), reduce_axis.end());

std::vector<ir::Expr> loops = sch->GetLoops(block_id);
ir::Expr src_total_extent{1};
for (const auto& loop : loops) {
src_total_extent = src_total_extent * loop.As<ir::For>()->extent;
}
ir::Expr target_sp_extent{1};
for (const auto& iter : context_->iter_space_info.sp_space) {
target_sp_extent = target_sp_extent * std::get<0>(iter);
if (!loop_strides.empty()) {
// If this is a Reduce, calculate the loop_perm by sorting the loops in the
// descending order of their strides according to the input, then split the
// loop_perm into sp_loop_perm & rd_loop_perm.
std::vector<int> loop_perm(loop_strides.size());
std::iota(loop_perm.begin(), loop_perm.end(), 0);
std::stable_sort(loop_perm.begin(), loop_perm.end(), [&](int a, int b) {
return loop_strides[a] > loop_strides[b];
});

for (int axis : loop_perm) {
if (reduce_axis_set.count(axis) > 0) {
rd_loop_perm_.push_back(axis);
} else if (loop_strides[axis] != 0) {
sp_loop_perm_.push_back(axis);
}
}
} else {
// If this is a Trivial, calculate the sp_loop_perm according to the output.
sp_loop_perm_ = GetCommonOutputLoopPerm(sch);
}
ir::Expr target_total_extent = ir_utils::IRCopy(target_sp_extent);
for (const auto& iter : context_->iter_space_info.rb_space) {
target_total_extent = target_total_extent * std::get<0>(iter);

VLOG(4) << "AlignIterSpaceTactic:\n"
<< "sp_loop_perm: " << utils::Join(sp_loop_perm_, ", ") << "\n"
<< "rd_loop_perm: " << utils::Join(rd_loop_perm_, ", ");
}

/**
 * Build a map from each loop's iteration variable to its depth (index) in
 * the loop nest, so that store indices expressed in loop vars can be
 * translated back to loop positions.
 */
std::unordered_map<ir::Var, int> GetLoopVarToIndex(
    const std::vector<ir::Expr>& loops) {
  std::unordered_map<ir::Var, int> loop_var2index;
  loop_var2index.reserve(loops.size());
  // Use size_t for the index to avoid a signed/unsigned comparison with
  // loops.size(); the stored value stays int to keep the interface unchanged.
  for (size_t i = 0; i < loops.size(); ++i) {
    auto* node = loops[i].As<ir::For>();
    loop_var2index[node->loop_var] = static_cast<int>(i);
  }
  return loop_var2index;
}

common::cas_intervals_t var_intervals;
common::SymbolicExprAnalyzer symbolic_expr_analyzer(var_intervals);
std::optional<bool> total_extent_eq =
symbolic_expr_analyzer.ProveEQ(src_total_extent, target_total_extent);
bool need_reorder = false;
for (int i = 0; i < context_->iter_space_info.rb_last_order.size(); ++i) {
if (context_->iter_space_info.rb_last_order[i] != i) {
need_reorder = true;
break;
}
/**
 * Check whether this is an effective permutation.
 * A permutation is ineffective if it's entirely in ascending order, in which
 * case applying it would reorder the loops to themselves.
 */
bool IsPermutationEffective(const std::vector<int>& perm) {
  // A permutation is effective iff some adjacent pair strictly descends,
  // which is exactly what !std::is_sorted detects. This also fixes the
  // signed/unsigned comparison (int vs perm.size()) of the hand-rolled loop.
  return !std::is_sorted(perm.begin(), perm.end());
}

if (total_extent_eq.has_value() && total_extent_eq.value()) {
if (need_reorder) {
sch->Reorder(block_id, context_->iter_space_info.rb_last_order);
}
if (context_->iter_space_info.sp_space.size() < loops.size() - 1) {
loops = sch->GetLoops(block_id);

// Align the loop in the current block that needs to be aligned with the
// reduce loop in iter_space_info
std::vector<ir::Expr> rb_loops(
loops.end() - context_->iter_space_info.rb_space.size(), loops.end());
sch->Fuse(rb_loops);
std::vector<int> AlignIterSpaceTactic::GetCommonOutputLoopPerm(
ir::IRSchedule* sch) {
std::vector<int> common_loop_perm;

for (auto& block : sch->GetAllBlocks()) {
std::string block_id = ir::analyzer::GetBlockName(block);
if (context_->output_names.count(block_id) == 0) continue;

auto store = ir::analyzer::GetStoreOfSBlock(block);
auto& indices = store.As<ir::Store>()->indices;
std::unordered_map<ir::Var, ir::Expr> iter_var2iter_value =
ir::analyzer::GetIterVarToValueOfSBlock(block);
std::unordered_map<ir::Var, int> loop_var2index =
GetLoopVarToIndex(sch->GetLoops(block));

std::vector<int> loop_perm;
for (auto& index : indices) {
if (index.is_constant()) continue;
if (!index.is_var()) return {};
ir::Expr iter_value = iter_var2iter_value[index.as_var_ref()];
if (!iter_value.is_var()) return {};
ir::Expr loop_var = iter_value.as_var_ref();
loop_perm.push_back(loop_var2index[loop_var]);
}
if (context_->iter_space_info.sp_space.size() > 1) {
// Align the loop in the current block that needs to be aligned with the
// spatial loop in iter_space_info
loops = sch->GetLoops(block_id);
std::vector<ir::Expr> sp_loops(
loops.begin(),
loops.end() - context_->iter_space_info.rb_space.size());
sch->Fuse(sp_loops);

if (common_loop_perm.empty()) {
common_loop_perm = std::move(loop_perm);
} else if (common_loop_perm != loop_perm) {
return {};
}
} else {
sch->Fuse(loops);
}

return common_loop_perm;
}

// Apply the precomputed permutations to this block's loop nest: first the
// spatial permutation, then the reduce permutation. Skips reduce_init blocks,
// and skips any permutation that is already in ascending order.
void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch,
                                 const std::string& block_id) {
  if (ir::IsReduceInitTensorName(block_id)) return;
  // Spatial loops are reordered before reduce loops; each permutation is
  // applied only if it actually changes the loop order.
  for (const std::vector<int>* perm : {&sp_loop_perm_, &rd_loop_perm_}) {
    if (IsPermutationEffective(*perm)) {
      sch->Reorder(block_id, *perm);
    }
  }
}

} // namespace

// Factory for the file-local AlignIterSpaceTactic; this is the only symbol
// exposed to the scheduler, which owns the returned tactic.
std::unique_ptr<ScheduleTactic> CreateAlignIterSpaceTactic() {
return std::make_unique<AlignIterSpaceTactic>();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
// Copyright (c) 2025 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -11,10 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h"

namespace cinn {
Expand Down
44 changes: 0 additions & 44 deletions paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class TileFirstGeneralTactic final : public ScheduleTactic {
std::string TacticName() const override { return "TileFirstGeneralTactic"; }

private:
void AlignToReduceInput(ir::IRSchedule* sch, const std::string& block_id);
void MergeFlattenAxis(ir::IRSchedule* sch, const std::string& block_id);
void MergeDiscreteFlattenAxis(ir::IRSchedule* sch,
const std::string& block_id);
Expand Down Expand Up @@ -128,11 +127,6 @@ void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
if (!can_apply_) return;
if (ir::IsReduceInitTensorName(block_id)) return;

AlignToReduceInput(sch, block_id);
VLOG(6) << "After AlignToReduceInput on block: [" << block_id
<< "], loop nest:\n"
<< sch->GetLoops(block_id)[0];

if (UseContinuousDataTile(context_->config)) {
VLOG(4) << "Using ApplyContinuousDataTile";
ApplyContinuousDataTile(sch, block_id);
Expand Down Expand Up @@ -293,44 +287,6 @@ void TileFirstGeneralTactic::ApplyContinuousDataTile(
SetReduceType(sch, block_id);
}

// Reorder the block's loops so that memory access to the reduce input is as
// coalesced as possible: loops are sorted by descending stride (per
// loop_strides from the schedule config), with all reduce loops placed after
// all spatial loops.
void TileFirstGeneralTactic::AlignToReduceInput(ir::IRSchedule* sch,
const std::string& block_id) {
const auto& loop_strides = context_->config.base_info->loop_strides;
// Empty loop_strides means this group is not a Reduce; nothing to align.
if (loop_strides.empty()) {
return;
}

std::vector<ir::Expr> loops = sch->GetLoops(block_id);
std::vector<int64_t> loop_perm(loops.size());
std::iota(loop_perm.begin(), loop_perm.end(), 0);

// True iff `axis` is one of the group's reduce axes.
const auto IsReduce = [&](int64_t axis) {
auto& reduce_axis = context_->config.base_info->reduce_axis;
return std::find(reduce_axis.begin(), reduce_axis.end(), axis) !=
reduce_axis.end();
};

// Sort: spatial axes first, reduce axes last; within each group, order by
// descending stride so the innermost loop has the smallest stride.
std::sort(loop_perm.begin(), loop_perm.end(), [&](int64_t a, int64_t b) {
if (IsReduce(a) == IsReduce(b)) {
return loop_strides[a] > loop_strides[b];
}
return IsReduce(b);
});
VLOG(4) << "loop_perm: " << utils::Join(loop_perm, ", ");

// Reorder S/R loops separately, otherwise reduce_init will be de-inlined.
// Spatial loops with stride 0 are dropped from the reorder list (their
// position doesn't affect the access pattern).
std::vector<Expr> sp_loops, rd_loops;
for (auto i : loop_perm) {
if (IsReduce(i)) {
rd_loops.push_back(loops[i]);
} else if (loop_strides[i] != 0) {
sp_loops.push_back(loops[i]);
}
}
sch->Reorder(sp_loops);
sch->Reorder(rd_loops);
}

void TileFirstGeneralTactic::MergeFlattenAxis(ir::IRSchedule* sch,
const std::string& block_id) {
if (vec_flatten_axis_.size() >= 2) {
Expand Down
Loading