Skip to content

Commit

Permalink
[AMD] incorporate tl.assume into RangeAnalysis
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Feb 26, 2025
1 parent 852c05f commit e471467
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 11 deletions.
19 changes: 19 additions & 0 deletions test/TritonGPU/amd/amd-range-analysis.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -929,3 +929,22 @@ module attributes {"ttg.num-warps" = 4 : i32} {
tt.return %11 : tensor<1024xf32>
}
}

// -----

// CHECK-LABEL: tt.func @DynamicKBound
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func @DynamicKBound(%K: i32) {
// expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}}
// expected-remark@+1 {{non-neg}}
%c1024_i32 = arith.constant 1024 : i32
// expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}}
// expected-remark@+1 {{non-neg}}
%c128 = arith.constant 128 : i32
%cmp = arith.cmpi sle, %K, %c128 : i32
llvm.intr.assume %cmp : i1
// expected-remark@+1 {{unsigned : [-1, -1] signed : [-1, -1]}}
%condtest = arith.cmpi sle, %K, %c1024_i32 : i32
tt.return
}
}
16 changes: 10 additions & 6 deletions third_party/amd/include/Analysis/RangeAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,9 @@ namespace mlir::triton::AMD {
/// See visitRegionSuccessors.
struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis;

llvm::SmallDenseMap<LoopLikeOpInterface, int64_t> loopTripCounts;
llvm::SmallDenseMap<
std::pair<LoopLikeOpInterface, dataflow::IntegerValueRangeLattice *>,
int64_t>
loopVisits;
TritonIntegerRangeAnalysis(DataFlowSolver &solver,
DenseSet<Value> &assumptions)
: dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions) {}

void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override;

Expand Down Expand Up @@ -72,6 +69,13 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint successor,
ArrayRef<dataflow::AbstractSparseLattice *> abstractLattices) override;

llvm::SmallDenseMap<LoopLikeOpInterface, int64_t> loopTripCounts;
llvm::SmallDenseMap<
std::pair<LoopLikeOpInterface, dataflow::IntegerValueRangeLattice *>,
int64_t>
loopVisits;
llvm::DenseSet<Value> assumptions;
};

// TODO(max): remove after we catch up to
Expand Down
54 changes: 52 additions & 2 deletions third_party/amd/lib/Analysis/RangeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,58 @@ namespace mlir::triton::AMD {

void TritonIntegerRangeAnalysis::setToEntryState(
dataflow::IntegerValueRangeLattice *lattice) {
propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
lattice->getAnchor())));
auto anchor = lattice->getAnchor();
auto range = IntegerValueRange::getMaxRange(anchor);
auto maybeCmpOpUser = llvm::find_if(anchor.getUsers(), [&](Operation *op) {
return op->getNumResults() == 1 && assumptions.contains(op->getResult(0));
});
if (maybeCmpOpUser != anchor.getUsers().end()) {
auto cmpOp = llvm::cast<arith::CmpIOp>(*maybeCmpOpUser);
bool anchorIsLhs = cmpOp.getLhs() == anchor;
auto maybeConstantIntValue = getConstantIntValue(
getAsOpFoldResult(anchorIsLhs ? cmpOp.getRhs() : cmpOp.getLhs()));
if (auto constValue = maybeConstantIntValue; constValue.value_or(-1) > 0) {
auto operTy = llvm::cast<IntegerType>(anchor.getType());
unsigned bitWidth = operTy.getWidth();
bool isSigned = operTy.isSigned();
APInt apVal = {bitWidth, static_cast<uint64_t>(*constValue), isSigned};
APInt min = APInt::getMinValue(bitWidth);
APInt max = APInt::getMaxValue(bitWidth);

switch (cmpOp.getPredicate()) {
case arith::CmpIPredicate::eq:
range = mlir::ConstantIntRanges::constant(apVal);
break;
case arith::CmpIPredicate::sge:
if (anchorIsLhs)
range = mlir::ConstantIntRanges::range(apVal, max, isSigned);
else
range = mlir::ConstantIntRanges::range(min, apVal, isSigned);
break;
case arith::CmpIPredicate::sgt:
if (anchorIsLhs)
range = mlir::ConstantIntRanges::range(apVal + 1, max, isSigned);
else
range = mlir::ConstantIntRanges::range(min, apVal + 1, isSigned);
break;
case arith::CmpIPredicate::sle:
if (anchorIsLhs)
range = mlir::ConstantIntRanges::range(min, apVal, isSigned);
else
range = mlir::ConstantIntRanges::range(apVal, max, isSigned);
break;
case arith::CmpIPredicate::slt:
if (anchorIsLhs)
range = mlir::ConstantIntRanges::range(min, apVal - 1, isSigned);
else
range = mlir::ConstantIntRanges::range(apVal - 1, max, isSigned);
break;
default:
break;
}
}
}
propagateIfChanged(lattice, lattice->join(range));
}

LogicalResult TritonIntegerRangeAnalysis::visitOperation(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,9 @@ class TritonAMDGPUConvertToBufferOpsPass
// Collect assumptions in the function
DenseSet<Value> assumptions;
mod.walk([&](LLVM::AssumeOp op) {
if (op->getOperand(0).getDefiningOp<arith::CmpIOp>())
assumptions.insert(op->getOperand(0));
auto oper = op->getOperand(0);
if (oper.getDefiningOp<arith::CmpIOp>())
assumptions.insert(oper);
});
LLVM_DEBUG({
DBGS() << "Number of assumptions found: " << assumptions.size() << "\n";
Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct TestAMDRangeAnalysisPass
});

std::shared_ptr<DataFlowSolver> solver = createDataFlowSolver();
solver->load<AMD::TritonIntegerRangeAnalysis>();
solver->load<AMD::TritonIntegerRangeAnalysis>(assumptions);
if (failed(solver->initializeAndRun(getOperation())))
return signalPassFailure();

Expand Down

0 comments on commit e471467

Please sign in to comment.