Skip to content

Commit

Permalink
[BACKEND] Refactor how thread/lane/warp IDs are created (NFC) (#5906)
Browse files Browse the repository at this point in the history
Warp specialization will cause these to become relative to the current
warpgroup, so funnel all the code through a set of common APIs.
  • Loading branch information
Mogball authored Feb 14, 2025
1 parent 65c4294 commit dd17cfb
Show file tree
Hide file tree
Showing 16 changed files with 88 additions and 115 deletions.
64 changes: 27 additions & 37 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,10 @@ bool isConstantZero(Value v);

namespace mlir::triton {

// Returns the CTA-level thread ID, cast from the GPU dialect's index-typed
// thread-id op down to an i32 value.
inline Value getThreadId(OpBuilder &rewriter, Location loc) {
  auto threadIdx =
      rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
  Type i32Type = rewriter.getIntegerType(32);
  return rewriter.create<arith::IndexCastOp>(loc, i32Type, threadIdx);
}

struct TritonLLVMOpBuilder {
TritonLLVMOpBuilder(const Location &loc, OpBuilder &builder)
TritonLLVMOpBuilder(Location loc, OpBuilder &builder)
: loc(loc), builder(&builder) {}

// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// Operators
template <typename... Args> LLVM::SIToFPOp inttofloat(Args &&...args) {
Expand Down Expand Up @@ -282,7 +275,6 @@ struct TritonLLVMOpBuilder {
Value i16_val(int64_t val) { return int_val(16, val); }
Value i32_val(int64_t val) { return int_val(32, val); }
Value i64_val(int64_t val) { return int_val(64, val); }
Value tid_val() { return getThreadId(*builder, loc); }

Location loc;
OpBuilder *builder;
Expand Down Expand Up @@ -657,6 +649,20 @@ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,

} // namespace LLVM

// -----------------------------------------------------------------------
// Hardware Indices
// -----------------------------------------------------------------------

// Returns CTA level thread ID.
Value getThreadId(OpBuilder &rewriter, Location loc);

// Get the lane ID, which is index of the thread within its warp.
Value getLaneId(OpBuilder &rewriter, Location loc, unsigned threadsPerWarp);

// Get the lane ID and warp ID.
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc,
unsigned threadsPerWarp);

// -----------------------------------------------------------------------
// Shared memory utilities
// -----------------------------------------------------------------------
Expand Down Expand Up @@ -721,11 +727,9 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter,
RankedTensorType type) {
MLIRContext *ctx = rewriter.getContext();
auto shape = type.getShape();
Value threadId = getThreadId(rewriter, loc);
auto b = TritonLLVMOpBuilder(loc, rewriter);
Value warpSize = b.i32_val(triton::gpu::getWarpSize(blockedLayout));
Value laneId = b.urem(threadId, warpSize);
Value warpId = b.udiv(threadId, warpSize);
auto [laneId, warpId] =
getLaneAndWarpId(rewriter, loc, triton::gpu::getWarpSize(blockedLayout));
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
Expand Down Expand Up @@ -784,10 +788,7 @@ emitBaseIndexWithinCTAForMmaLayoutV2V3(Location loc, RewriterBase &rewriter,
warpsPerCTA.push_back(b.i32_val(_warpsPerCTA[i]));
auto shapePerCTA = getShapePerCTA(mmaLayout, shape);

Value threadId = getThreadId(rewriter, loc);
Value warpSize = b.i32_val(32);
Value laneId = b.urem(threadId, warpSize);
Value warpId = b.udiv(threadId, warpSize);
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc, /*warpSize=*/32);

uint32_t repM =
(_warpsPerCTA[rank - 2] * instrShape[rank - 2]) / shapePerCTA[rank - 2];
Expand Down Expand Up @@ -849,15 +850,13 @@ emitBaseIndexForMfmaLayout(Location loc, RewriterBase &rewriter,
assert((mDim == nDim && (mDim == 32 || mDim == 16 || mDim == 4)) ||
(mDim == 64 && nDim == 4) || (mDim == 4 && nDim == 64));

Value threadId = getThreadId(rewriter, loc);
Value warpSize = b.i32_val(triton::gpu::getWarpSize(mfmaLayout));
Value effectiveWarpSize = warpSize;
auto [laneId, warpId] =
getLaneAndWarpId(rewriter, loc, triton::gpu::getWarpSize(mfmaLayout));
if (mDim == 4 && nDim == 4) {
const int uniqueValuesPerWarp = 4;
effectiveWarpSize = b.i32_val(uniqueValuesPerWarp);
constexpr int uniqueValuesPerWarp = 4;
laneId = b.urem(laneId, b.i32_val(uniqueValuesPerWarp));
}
Value laneId = b.urem(threadId, effectiveWarpSize);
Value warpId = b.udiv(threadId, warpSize);

SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, _warpsPerCTA,
triton::gpu::getWarpOrder(mfmaLayout));
Expand Down Expand Up @@ -975,13 +974,10 @@ emitBaseIndexForWmmaLayout(Location loc, RewriterBase &rewriter,
warpsPerCTA.push_back(b.i32_val(_warpsPerCTA[i]));
auto mnkDim = AMDWmmaEncodingAttr::getMNKDimPerInstr();

Value threadId = getThreadId(rewriter, loc);
Value warpSize = b.i32_val(triton::gpu::getWarpSize(wmmaLayout));
Value laneId =
b.urem(threadId, b.i32_val(triton::gpu::getWarpSize(wmmaLayout) / 2));
Value threadIdPerWarp = b.urem(threadId, warpSize);
unsigned warpSize = triton::gpu::getWarpSize(wmmaLayout);
auto [threadIdPerWarp, warpId] = getLaneAndWarpId(rewriter, loc, warpSize);
Value laneId = b.urem(threadIdPerWarp, b.i32_val(warpSize / 2));

Value warpId = b.udiv(threadId, warpSize);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, _warpsPerCTA,
triton::gpu::getWarpOrder(wmmaLayout));
Expand Down Expand Up @@ -1146,12 +1142,6 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
return idx;
}

// Emit code to compute the (laneId, warpId, blockId) for the current thread.
std::tuple</*laneId=*/Value, /*warpId=*/Value, /*blockId=*/Value>
emitHardwareTuple(Location loc, RewriterBase &rewriter,
const TargetInfoBase &target, bool withCTAOffset,
unsigned threadsPerWarp);

// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
//
Expand Down
10 changes: 3 additions & 7 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
StringAttr kOffset = str_attr("offset");
StringAttr kIteration = str_attr("iteration");

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = b.i32_val(srcLayout.getInDimSize(kLane));
Value laneId = b.urem(threadId, threadsPerWarp);
Value warpId = b.udiv(threadId, threadsPerWarp);
auto [laneId, warpId] =
getLaneAndWarpId(rewriter, loc, srcLayout.getInDimSize(kLane));

auto scratchConfig =
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
Expand Down Expand Up @@ -662,9 +660,7 @@ void ConvertLayoutOpUsingLinearLayoutsConversion::transferWithinWarp(
unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> shflOuts(Cp.getInDimSize(kRegister));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = b.i32_val(Cp.getInDimSize(kLane));
Value laneId = b.urem(threadId, threadsPerWarp);
Value laneId = getLaneId(rewriter, loc, Cp.getInDimSize(kLane));

// Emit one shuffle per destination register.
for (int i : llvm::seq(shflOuts.size())) {
Expand Down
6 changes: 3 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,9 @@ void GatherOpConversion::emitWarpLocalGather(
SmallVector<Value> idxValues =
unpackLLElements(loc, adaptor.getIndices(), rewriter);

auto [laneId, warpId, blockId] =
emitHardwareTuple(loc, rewriter, targetInfo, /*withCTAOffset=*/true,
srcLayout.getInDimSize(kLane));
auto [laneId, warpId] =
getLaneAndWarpId(rewriter, loc, srcLayout.getInDimSize(kLane));
Value blockId = targetInfo.getClusterCTAId(rewriter, loc);

unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister);
assert(srcRegsPerThread == srcValues.size());
Expand Down
6 changes: 2 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
using namespace mlir;
using namespace mlir::triton;

// Floor of the base-2 logarithm; yields 0 for any input <= 1.
static int log2Int(int64_t num) {
  int shifts = 0;
  while (num > 1) {
    num /= 2;
    ++shifts;
  }
  return shifts;
}

// Compute a histogram within a warp. This uses an algorithm by @apgoucher
// that does the following:
// Create a ballot for each bit of the bin index (there
Expand All @@ -22,8 +20,8 @@ static SmallVector<Value> computeWarpLevelHistogram(
assert(numBins % numThreadPerWarp == 0 &&
"numBins must be divisible by numThreadPerWarp");
Value zero = b.i32_val(0);
int numBits = log2Int(numBins);
int numBitsLaneId = log2Int(numThreadPerWarp);
int numBits = llvm::Log2_64(numBins);
int numBitsLaneId = llvm::Log2_64(numThreadPerWarp);
unsigned numElementsPerThreads = triton::gpu::getTotalElemsPerThread(srcType);
unsigned numThreadWithUniqueData =
triton::gpu::getThreadsPerWarpWithUniqueData(srcType.getEncoding(),
Expand Down
7 changes: 2 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,11 @@ struct ReduceOpConversion
triton::ReduceOp op = helper.getOperation();
Location loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
Value threadId = getThreadId(rewriter, loc);
auto srcLayout =
mlir::cast<DistributedEncodingTrait>(helper.getSrcLayout());
auto mod = op.getOperation()->getParentOfType<ModuleOp>();
Value warpSize =
b.i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));
Value warpId = b.udiv(threadId, warpSize);
Value laneId = b.urem(threadId, warpSize);
auto [laneId, warpId] = getLaneAndWarpId(
rewriter, loc, triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));
unsigned axis = op.getAxis();
auto smemShape = helper.getScratchRepShape();

Expand Down
50 changes: 30 additions & 20 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,27 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
return outIndices;
}

std::tuple<Value, Value, Value> emitHardwareTuple(Location loc,
RewriterBase &rewriter,
const TargetInfoBase &target,
bool withCTAOffset,
unsigned threadsPerWarpCst) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = b.i32_val(threadsPerWarpCst);
Value laneId = b.urem(threadId, threadsPerWarp);
Value warpId = b.udiv(threadId, threadsPerWarp);
Value blockId =
withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
return {laneId, warpId, blockId};
// Materializes the CTA-level thread ID: emits gpu.thread_id along the x
// dimension (index-typed) and casts the result to i32.
Value getThreadId(OpBuilder &rewriter, Location loc) {
  auto threadIdx =
      rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
  return rewriter.create<arith::IndexCastOp>(loc, i32_ty, threadIdx);
}

// Returns the lane ID, i.e. the thread's index within its warp, computed as
// threadId % threadsPerWarp.
Value getLaneId(OpBuilder &rewriter, Location loc, unsigned threadsPerWarp) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Value threadId = getThreadId(rewriter, loc);
  Value warpSizeVal = b.i32_val(threadsPerWarp);
  return b.urem(threadId, warpSizeVal);
}

// Decomposes the CTA-level thread ID into a (laneId, warpId) pair for the
// given warp size: laneId = tid % warpSize, warpId = tid / warpSize.
std::pair<Value, Value> getLaneAndWarpId(OpBuilder &rewriter, Location loc,
                                         unsigned warpSize) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Value threadId = getThreadId(rewriter, loc);
  Value nThreadsPerWarp = b.i32_val(warpSize);
  Value lane = b.urem(threadId, nThreadsPerWarp);
  Value warp = b.udiv(threadId, nThreadsPerWarp);
  return {lane, warp};
}

SmallVector<SmallVector<Value>>
Expand All @@ -187,8 +195,10 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");

auto [laneId, warpId, blockId] = emitHardwareTuple(
loc, rewriter, target, withCTAOffset, ll.getInDimSize(kLane));
auto [laneId, warpId] =
getLaneAndWarpId(rewriter, loc, ll.getInDimSize(kLane));
Value blockId =
withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
unsigned rank = shape.size();
SmallVector<SmallVector<Value>> ret;
// Linear layout function is split in two parts below:
Expand Down Expand Up @@ -407,9 +417,10 @@ bool emitTransferBetweenRegistersAndShared(
maxVecElems.value_or(std::numeric_limits<int>::max()));

auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1;
auto [laneId, warpId, blockId] =
emitHardwareTuple(loc, rewriter, target, withCTAOffset,
regToSharedLayout.getInDimSize(kLane));
auto [laneId, warpId] =
getLaneAndWarpId(rewriter, loc, regToSharedLayout.getInDimSize(kLane));
Value blockId =
withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);

// For kernels with a single CTA, `allocSharedLayout.sublayout(S("block"),
// outDims) == 0`. We need to take out the "block" dimension in order to use
Expand Down Expand Up @@ -915,8 +926,7 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
auto instrShape = mmaLayout.getInstrShape();
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
auto [laneId, warpId, blockId] = emitHardwareTuple(
loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32);
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc, /*warpSize=*/32);
// TODO: fix the bug in MMAEncodingAttr document
SmallVector<Value> multiDimWarpId(2);
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct ConvertLayoutOpMFMAToDotOpConversion
Value c48 = b.i32_val(48);
Value c64 = b.i32_val(64);

Value threadId = b.tid_val();
Value threadId = getThreadId(rewriter, loc);
Value laneId = b.urem(threadId, c64);

Value mask0 = b.icmp_slt(laneId, c32);
Expand Down
10 changes: 1 addition & 9 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,6 @@ struct DotOpMFMAConversionHelper {
: mfmaLayout(mfmaLayout), rewriter(rewriter),
typeConverter(typeConverter), loc(loc), ctx(mfmaLayout.getContext()) {}

Value getThreadId() const {
auto llvmIndexTy = typeConverter->getIndexType();
auto tid = rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
}

Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB,
Value valC) const {
auto b = TritonLLVMOpBuilder(loc, rewriter);
Expand Down Expand Up @@ -124,8 +117,7 @@ struct DotOpMFMAConversionHelper {
return acc;
constexpr int warpSize = 64;
int subBlockSize = warpSize / numSubBlocks;
Value laneId = getThreadId();
laneId = b.and_(laneId, b.i32_val(warpSize - 1));
Value laneId = getLaneId(rewriter, loc, warpSize);
auto vecTy = dyn_cast<VectorType>(acc.getType());
auto elemType = vecTy.getElementType();
assert(elemType.getIntOrFloatBitWidth() == 32);
Expand Down
6 changes: 3 additions & 3 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto tensorTy = dyn_cast<RankedTensorType>(valueTy);
Value mask = b.int_val(1, 1);
auto tid = b.tid_val();
auto tid = getThreadId(rewriter, loc);
auto clusterCTAId = targetInfo.getClusterCTAId(rewriter, loc);
if (tensorTy) {
auto layout = tensorTy.getEncoding();
Expand Down Expand Up @@ -1062,7 +1062,7 @@ struct AtomicCASOpConversion

// Fill entry block with global memory barrier and conditional branch.
rewriter.setInsertionPointToEnd(curBlock);
auto tid = b.tid_val();
auto tid = getThreadId(rewriter, loc);
Value pred = b.icmp_eq(tid, b.i32_val(i));
rewriter.create<LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);

Expand Down Expand Up @@ -1329,7 +1329,7 @@ struct AtomicRMWOpConversion
numElems = tensorTy.getNumElements();
}
Value mask = b.int_val(1, 1);
auto tid = b.tid_val();
auto tid = getThreadId(rewriter, loc);
mask = b.and_(mask, b.icmp_slt(b.mul(tid, b.i32_val(elemsPerThread)),
b.i32_val(numElems)));
if (useDppForPackedF16)
Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct LocalLoadOpConversion
: SharedToDotOperandWMMA::convertLayout;
res = sharedToDotConvert(dotOperandLayout.getOpIdx(), rewriter, loc, src,
dotOperandLayout, smemObj, typeConverter,
b.tid_val());
getThreadId(rewriter, loc));
} else {
assert(false && "unsupported layout found");
}
Expand Down
5 changes: 1 addition & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,7 @@ class UpcastMXFPOpPattern

int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
Value warpSize = b.i32_val(numThreads);
Value tid = b.tid_val();
Value warpId = b.udiv(tid, warpSize);
Value laneId = b.urem(tid, warpSize);
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc, numThreads);

bool useFp16 = op.getType().getElementType().isF16();
if (isPacked) {
Expand Down
12 changes: 2 additions & 10 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,8 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
}

auto mod = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
Value threadId =
rewriter.create<::mlir::gpu::ThreadIdOp>(loc, ::mlir::gpu::Dimension::x);
threadId = rewriter.create<arith::IndexCastOp>(loc, i32_ty, threadId);
Value threadId = getThreadId(rewriter, loc);

unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
Value warpSize = b.i32_val(iWarpSize);
Value laneId = b.urem(threadId, warpSize);
Expand All @@ -131,13 +130,6 @@ static Value shuffleCommonImpl(Location loc, RewriterBase &rewriter,
switch (mode) {
case ShflKind::bfly:
if (strideInt > 16) {
Value threadId =
rewriter
.create<UnrealizedConversionCastOp>(
loc, TypeRange{i32_ty},
ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)})
.getResult(0);
Value stride = b.i32_val(32);
Value lineId = b.xor_(threadId, stride);
return bpermute(lineId);
Expand Down
3 changes: 3 additions & 0 deletions third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,9 @@ class WarpIdOpPattern : public OpRewritePattern<ttn::WarpIdOp> {

Value threadId = rewriter.create<NVVM::ThreadIdXOp>(loc, i32_ty);
Value warpId = b.udiv(threadId, b.i32_val(32));
// This indicates to PTXAS that the result and its derived values are
// uniform across the warp. For example, if a branch condition derives from
// this value, it can be proven to be non-divergent.
warpId = LLVM::NVIDIA::shuffleIdx(loc, rewriter, warpId, 0);
rewriter.replaceOp(op, warpId);
return success();
Expand Down
Loading

0 comments on commit dd17cfb

Please sign in to comment.