From e4714675d88f5f602ce91aadf94ac8c334f7f3d7 Mon Sep 17 00:00:00 2001
From: Maksim Levental
Date: Tue, 25 Feb 2025 21:36:28 -0500
Subject: [PATCH] [AMD] incorporate tl.assume into RangeAnalysis

---
 test/TritonGPU/amd/amd-range-analysis.mlir    | 19 +++++++
 .../amd/include/Analysis/RangeAnalysis.h      | 16 +++---
 .../amd/lib/Analysis/RangeAnalysis.cpp        | 75 ++++++++++++++++++-
 .../ConvertToBufferOps.cpp                    |  5 +-
 .../lib/Analysis/TestAMDRangeAnalysis.cpp     |  2 +-
 5 files changed, 106 insertions(+), 11 deletions(-)

diff --git a/test/TritonGPU/amd/amd-range-analysis.mlir b/test/TritonGPU/amd/amd-range-analysis.mlir
index 359f3bbe8996..76275be9559c 100644
--- a/test/TritonGPU/amd/amd-range-analysis.mlir
+++ b/test/TritonGPU/amd/amd-range-analysis.mlir
@@ -929,3 +929,22 @@ module attributes {"ttg.num-warps" = 4 : i32} {
     tt.return %11 : tensor<1024xf32>
   }
 }
+
+// -----
+
+// CHECK-LABEL:   tt.func @DynamicKBound
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func @DynamicKBound(%K: i32) {
+    // expected-remark@+2 {{unsigned : [1024, 1024] signed : [1024, 1024]}}
+    // expected-remark@+1 {{non-neg}}
+    %c1024_i32 = arith.constant 1024 : i32
+    // expected-remark@+2 {{unsigned : [128, 128] signed : [128, 128]}}
+    // expected-remark@+1 {{non-neg}}
+    %c128 = arith.constant 128 : i32
+    %cmp = arith.cmpi sle, %K, %c128 : i32
+    llvm.intr.assume %cmp : i1
+    // expected-remark@+1 {{unsigned : [-1, -1] signed : [-1, -1]}}
+    %condtest = arith.cmpi sle, %K, %c1024_i32 : i32
+    tt.return
+  }
+}
diff --git a/third_party/amd/include/Analysis/RangeAnalysis.h b/third_party/amd/include/Analysis/RangeAnalysis.h
index cce31969b1cb..2b792f42dd3f 100644
--- a/third_party/amd/include/Analysis/RangeAnalysis.h
+++ b/third_party/amd/include/Analysis/RangeAnalysis.h
@@ -26,12 +26,9 @@ namespace mlir::triton::AMD {
 /// See visitRegionSuccessors.
 struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
   using dataflow::IntegerRangeAnalysis::IntegerRangeAnalysis;
-
-  llvm::SmallDenseMap<LoopLikeOpInterface, int64_t> loopTripCounts;
-  llvm::SmallDenseMap<
-      std::pair<LoopLikeOpInterface, dataflow::IntegerValueRangeLattice *>,
-      int64_t>
-      loopVisits;
+  TritonIntegerRangeAnalysis(DataFlowSolver &solver,
+                             const DenseSet<Value> &assumptions)
+      : dataflow::IntegerRangeAnalysis(solver), assumptions(assumptions) {}
 
   void setToEntryState(dataflow::IntegerValueRangeLattice *lattice) override;
 
@@ -72,6 +69,13 @@ struct TritonIntegerRangeAnalysis : dataflow::IntegerRangeAnalysis {
       ProgramPoint *point, RegionBranchOpInterface branch,
       RegionBranchPoint successor,
       ArrayRef<dataflow::AbstractSparseLattice *> abstractLattices) override;
+
+  llvm::SmallDenseMap<LoopLikeOpInterface, int64_t> loopTripCounts;
+  llvm::SmallDenseMap<
+      std::pair<LoopLikeOpInterface, dataflow::IntegerValueRangeLattice *>,
+      int64_t>
+      loopVisits;
+  llvm::DenseSet<Value> assumptions;
 };
 
 // TODO(max): remove after we catch up to
diff --git a/third_party/amd/lib/Analysis/RangeAnalysis.cpp b/third_party/amd/lib/Analysis/RangeAnalysis.cpp
index 3828ccbdfe2e..c293a6ca8e60 100644
--- a/third_party/amd/lib/Analysis/RangeAnalysis.cpp
+++ b/third_party/amd/lib/Analysis/RangeAnalysis.cpp
@@ -108,8 +108,79 @@ namespace mlir::triton::AMD {
 
+/// Joins into `lattice` the widest range for its anchor, narrowed by a member
+/// of `assumptions` when the anchor feeds an assumed `arith.cmpi` (eq or a
+/// signed ordering) against a positive constant.
 void TritonIntegerRangeAnalysis::setToEntryState(
     dataflow::IntegerValueRangeLattice *lattice) {
-  propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
-                                  lattice->getAnchor())));
+  Value anchor = lattice->getAnchor();
+  IntegerValueRange range = IntegerValueRange::getMaxRange(anchor);
+  // Narrow the entry range with the first usable assumption: an assumed
+  // `arith.cmpi` that compares the anchor against a positive constant.
+  for (Operation *user : anchor.getUsers()) {
+    // dyn_cast, not cast: nothing here guarantees `assumptions` only holds
+    // cmpi results.
+    auto cmpOp = llvm::dyn_cast<arith::CmpIOp>(user);
+    if (!cmpOp || !assumptions.contains(cmpOp.getResult()))
+      continue;
+    // Index and tensor anchors are not IntegerType; leave them at max range.
+    auto operTy = llvm::dyn_cast<IntegerType>(anchor.getType());
+    if (!operTy)
+      break;
+    bool anchorIsLhs = cmpOp.getLhs() == anchor;
+    std::optional<int64_t> constValue = getConstantIntValue(
+        getAsOpFoldResult(anchorIsLhs ? cmpOp.getRhs() : cmpOp.getLhs()));
+    if (constValue.value_or(-1) <= 0)
+      continue;
+
+    // Every predicate handled below is a signed comparison (or eq), so the
+    // bounds are signed no matter what `operTy.isSigned()` reports.
+    unsigned bitWidth = operTy.getWidth();
+    APInt apVal(bitWidth, static_cast<uint64_t>(*constValue),
+                /*isSigned=*/true);
+    APInt min = APInt::getSignedMinValue(bitWidth);
+    APInt max = APInt::getSignedMaxValue(bitWidth);
+    auto signedRange = [](const APInt &lo, const APInt &hi) {
+      return mlir::ConstantIntRanges::range(lo, hi, /*isSigned=*/true);
+    };
+
+    // `K` is the anchor and `C` the constant in the comments below.
+    std::optional<mlir::ConstantIntRanges> assumed;
+    switch (cmpOp.getPredicate()) {
+    case arith::CmpIPredicate::eq:
+      assumed = mlir::ConstantIntRanges::constant(apVal);
+      break;
+    case arith::CmpIPredicate::sge:
+      // K >= C gives [C, max]; C >= K gives [min, C].
+      assumed =
+          anchorIsLhs ? signedRange(apVal, max) : signedRange(min, apVal);
+      break;
+    case arith::CmpIPredicate::sgt:
+      // K > C gives [C + 1, max]; C > K gives [min, C - 1].
+      if (!anchorIsLhs)
+        assumed = signedRange(min, apVal - 1);
+      else if (apVal != max) // K > max is unsatisfiable; C + 1 would wrap.
+        assumed = signedRange(apVal + 1, max);
+      break;
+    case arith::CmpIPredicate::sle:
+      // K <= C gives [min, C]; C <= K gives [C, max].
+      assumed =
+          anchorIsLhs ? signedRange(min, apVal) : signedRange(apVal, max);
+      break;
+    case arith::CmpIPredicate::slt:
+      // K < C gives [min, C - 1]; C < K gives [C + 1, max].
+      if (anchorIsLhs)
+        assumed = signedRange(min, apVal - 1);
+      else if (apVal != max) // max < K is unsatisfiable; C + 1 would wrap.
+        assumed = signedRange(apVal + 1, max);
+      break;
+    default:
+      break;
+    }
+    if (assumed) {
+      range = IntegerValueRange(*assumed);
+      break;
+    }
+  }
+  propagateIfChanged(lattice, lattice->join(range));
 }
 
 LogicalResult TritonIntegerRangeAnalysis::visitOperation(
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
index 95d978e56913..d7462023176f 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp
@@ -491,8 +491,9 @@ class TritonAMDGPUConvertToBufferOpsPass
     // Collect assumptions in the function
     DenseSet<Value> assumptions;
     mod.walk([&](LLVM::AssumeOp op) {
-      if (op->getOperand(0).getDefiningOp<arith::CmpIOp>())
-        assumptions.insert(op->getOperand(0));
+      auto oper = op->getOperand(0);
+      if (oper.getDefiningOp<arith::CmpIOp>())
+        assumptions.insert(oper);
     });
     LLVM_DEBUG({
       DBGS() << "Number of assumptions found: " << assumptions.size() << "\n";
diff --git a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp
index 82c2e09f5384..4e71a86c637a 100644
--- a/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp
+++ b/third_party/amd/test/lib/Analysis/TestAMDRangeAnalysis.cpp
@@ -46,7 +46,7 @@ struct TestAMDRangeAnalysisPass
     });
 
     std::shared_ptr<DataFlowSolver> solver = createDataFlowSolver();
-    solver->load<AMD::TritonIntegerRangeAnalysis>();
+    solver->load<AMD::TritonIntegerRangeAnalysis>(assumptions);
     if (failed(solver->initializeAndRun(getOperation())))
       return signalPassFailure();
 