[BACKEND] no longer uses shared mem or barriers for single-warp reductions (triton-lang#1915)

Zero-byte shared memory requests no longer materialize empty allocation buffers; those empty buffers could lead to unnecessary barriers.

Note: the ReduceOp code has become quite messy and will require some cleanup.
ptillet authored and oplavsic committed Aug 15, 2023
1 parent 398d2c7 commit 4215086
Showing 5 changed files with 81 additions and 32 deletions.
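
To make the rationale concrete, here is a minimal toy model (plain C++; all names are illustrative and none of this is Triton code) of the behavior being fixed: barrier placement keys off whether an op owns a scratch buffer at all, so even a zero-byte buffer registered for a single-warp reduction would force a barrier.

#include <cstdio>
#include <vector>

struct Buffer {
  unsigned bytes;
};

// Pre-fix behavior: any registered buffer, including an empty one,
// makes the analysis conservatively insert a barrier around the op.
bool needsBarrier(const std::vector<Buffer> &scratch) {
  return !scratch.empty();
}

int main() {
  std::vector<Buffer> buffers;
  unsigned bytes = 0;                          // single-warp reduction: scratch size is zero
  buffers.push_back({bytes});                  // old: addBuffer called unconditionally
  std::printf("%d\n", needsBarrier(buffers));  // 1: spurious barrier
  buffers.clear();
  if (bytes > 0)                               // new: only register non-empty buffers
    buffers.push_back({bytes});
  std::printf("%d\n", needsBarrier(buffers));  // 0: no barrier
  return 0;
}

The guard at the end is the same idea the diff below introduces as maybeAddScratchBuffer.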
lib/Analysis/Allocation.cpp (11 additions, 5 deletions)

@@ -171,12 +171,18 @@ class AllocationAnalysis {
     }
   }
 
+  template <BufferT::BufferKind T>
+  void maybeAddScratchBuffer(Operation *op, unsigned bytes) {
+    if (bytes > 0)
+      allocation->addBuffer<T>(op, bytes);
+  }
+
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
     if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
       ReduceOpHelper helper(reduceOp);
       unsigned bytes = helper.getScratchSizeInBytes();
-      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
       auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
       auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
@@ -200,7 +206,7 @@ class AllocationAnalysis {
           srcTy.getElementType().isa<triton::PointerType>()
               ? elems * kPtrBitWidth / 8
               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
-      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
       auto value = op->getOperand(0);
       // only scalar requires scratch memory
@@ -217,7 +223,7 @@ class AllocationAnalysis {
             elemTy.isa<triton::PointerType>()
                 ? elems * kPtrBitWidth / 8
                 : elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
-        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+        maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
       }
     } else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
       auto value = op->getOperand(0);
@@ -229,13 +235,13 @@ class AllocationAnalysis {
       auto bytes = elemTy.isa<triton::PointerType>()
                        ? elems * kPtrBitWidth / 8
                        : elems * elemTy.getIntOrFloatBitWidth() / 8;
-      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
       auto callable = callOp.resolveCallable();
       auto funcOp = dyn_cast<FunctionOpInterface>(callable);
       auto *funcAlloc = &(*funcAllocMap)[funcOp];
       auto bytes = funcAlloc->getSharedMemorySize();
-      allocation->addBuffer<BufferT::BufferKind::Virtual>(op, bytes);
+      maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes);
     }
   }
lib/Analysis/Utility.cpp (4 additions, 3 deletions)

@@ -60,9 +60,10 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
 
   auto argLayout = getSrcLayout();
   auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
-  // if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
-  //     triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
-  //   return {{1, 1}, {1, 1}};
+
+  // that case doesn't need inter-warp communication
+  if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
+    return {{0, 0}, {0, 0}};
 
   /// shared memory block0
   smemShapes[0] = convertType<unsigned>(getSrcShape());
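
Why the single-warp case needs no scratch memory at all: within one warp, the reduction completes with register shuffles alone. A minimal sketch, modeling the butterfly pattern of __shfl_xor_sync with plain arrays (illustrative only, not Triton's lowering):

#include <cstdio>

int main() {
  constexpr int kWarpSize = 32;
  int lane[kWarpSize];
  for (int i = 0; i < kWarpSize; ++i)
    lane[i] = i;  // each lane contributes its own value

  // One butterfly step per power of two: lane i accumulates lane i ^ offset.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    int shfl[kWarpSize];
    for (int i = 0; i < kWarpSize; ++i)
      shfl[i] = lane[i ^ offset];
    for (int i = 0; i < kWarpSize; ++i)
      lane[i] += shfl[i];
  }

  // Every lane now holds the full sum; no shared memory, no bar.sync.
  std::printf("lane 0 holds %d (expected %d)\n", lane[0], 31 * 32 / 2);
  return 0;
}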
lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp (42 additions, 13 deletions)

@@ -349,18 +349,20 @@ struct ReduceOpConversion
     unsigned elems = product<unsigned>(smemShapes[0]);
     unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
 
-    SmallVector<Value> smemBases(op.getNumOperands());
-    smemBases[0] = bitcast(
-        getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
-    for (unsigned i = 1; i < op.getNumOperands(); ++i) {
-      smemBases[i] =
-          bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
-                  elemPtrTys[i]);
-    }
-
     unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
     unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
 
+    SmallVector<Value> smemBases(op.getNumOperands());
+    if (sizeInterWarps > 1) {
+      smemBases[0] = bitcast(
+          getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
+      for (unsigned i = 1; i < op.getNumOperands(); ++i) {
+        smemBases[i] =
+            bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
+                    elemPtrTys[i]);
+      }
+    }
+
     unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
     auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
     auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
@@ -418,6 +420,7 @@ struct ReduceOpConversion
     Value zero = i32_val(0);
     Value laneZero = icmp_eq(laneIdAxis, zero);
 
+    std::map<SmallVector<unsigned>, SmallVector<Value>> finalAccs;
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> acc = it.second;
@@ -440,8 +443,13 @@ struct ReduceOpConversion
         accumulate(rewriter, *combineOp, acc, shfl, false);
       }
 
+      if (sizeInterWarps == 1) {
+        finalAccs[key] = acc;
+        continue;
+      }
+
       SmallVector<Value> writeIdx = indices[key];
-      writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
+      writeIdx[axis] = warpIdAxis;
       Value writeOffset =
           linearize(rewriter, loc, writeIdx, smemShapes[0], order);
       for (unsigned i = 0; i < op.getNumOperands(); ++i) {
@@ -450,6 +458,30 @@ struct ReduceOpConversion
       }
     }
 
+    if (sizeInterWarps == 1) {
+      SmallVector<Value> results(op.getNumOperands());
+      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+        if (auto resultTy =
+                op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
+          auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
+          unsigned resultElems = getTotalElemsPerThread(resultTy);
+          SmallVector<SmallVector<unsigned>> resultOffset =
+              emitOffsetForLayout(resultLayout, resultTy);
+          SmallVector<Value> resultVals;
+          for (int j = 0; j < resultElems; j++) {
+            auto key = resultOffset[j];
+            key.insert(key.begin() + axis, 0);
+            resultVals.push_back(finalAccs[key][i]);
+          }
+          results[i] = getTypeConverter()->packLLElements(loc, resultVals,
+                                                          rewriter, resultTy);
+        } else
+          results[i] = finalAccs.begin()->second[i];
+      }
+      rewriter.replaceOp(op, results);
+      return success();
+    }
+
     barrier();
 
     // The second round of shuffle reduction
@@ -508,9 +540,6 @@ struct ReduceOpConversion
       }
     }
 
-    // We could avoid this barrier in some of the layouts, however this is not
-    // the general case.
-    // TODO: optimize the barrier in case the layouts are accepted.
     barrier();
 
     // set output values
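
The new early-exit path above consumes finalAccs directly instead of round-tripping through shared memory. A simplified model of the key manipulation it performs (int accumulators stand in for LLVM Values; the names mirror the diff but the code itself is illustrative):

#include <cstdio>
#include <map>
#include <vector>

int main() {
  using Key = std::vector<unsigned>;
  unsigned axis = 1;  // the reduced dimension
  // finalAccs as left by the intra-warp shuffle phase, keyed by full offsets.
  std::map<Key, int> finalAccs = {{{0, 0}, 10}, {{1, 0}, 20}};

  // Offsets of the sliced result layout: one coordinate short of the key,
  // because the reduced dimension has been sliced away.
  std::vector<Key> resultOffset = {{0}, {1}};
  for (Key key : resultOffset) {
    key.insert(key.begin() + axis, 0);  // pin the reduced axis to 0
    std::printf("result[%u] = %d\n", key[0], finalAccs[key]);
  }
  return 0;
}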
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp (11 additions, 5 deletions)

@@ -213,17 +213,23 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
     // of shared memory and append it to the operands of the callOp.
     auto loc = callOp.getLoc();
     auto caller = callOp->getParentOfType<FunctionOpInterface>();
-    auto base = allocation.getFunctionSharedMemoryBase(caller);
-    auto *funcAllocation = allocation.getFuncData(caller);
-    auto bufferId = funcAllocation->getBufferId(callOp);
-    auto offset = funcAllocation->getOffset(bufferId);
     auto ptrTy = LLVM::LLVMPointerType::get(
         this->getTypeConverter()->convertType(rewriter.getI8Type()),
         NVVM::kSharedMemorySpace);
-    auto offsetValue = gep(ptrTy, base, i32_val(offset));
     auto promotedOperands = this->getTypeConverter()->promoteOperands(
         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
         adaptor.getOperands(), rewriter);
+    auto base = allocation.getFunctionSharedMemoryBase(caller);
+    auto *funcAllocation = allocation.getFuncData(caller);
+    auto bufferId = funcAllocation->getBufferId(callOp);
+    // function doesn't have a shared mem buffer
+    if (bufferId == (size_t)-1) {
+      promotedOperands.push_back(base);
+      return promotedOperands;
+    }
+    // function has a shared mem buffer
+    auto offset = funcAllocation->getOffset(bufferId);
+    auto offsetValue = gep(ptrTy, base, i32_val(offset));
+    promotedOperands.push_back(offsetValue);
     return promotedOperands;
   }
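
A sketch of the call-site behavior this change produces (plain pointers stand in for LLVM values; the (size_t)-1 sentinel mirrors the check in the diff, though the name kInvalidBufferId is assumed for illustration):

#include <cstddef>
#include <cstdio>

constexpr size_t kInvalidBufferId = static_cast<size_t>(-1);

// Returns the shared-memory pointer appended to the callee's operands:
// the raw base when the call op owns no scratch buffer, else base + offset.
char *promotedSmemOperand(char *base, size_t bufferId, size_t offset) {
  if (bufferId == kInvalidBufferId)
    return base;  // callee allocates no scratch space
  return base + offset;
}

int main() {
  char smem[1024];
  std::printf("no buffer:     +%td\n",
              promotedSmemOperand(smem, kInvalidBufferId, 0) - smem);
  std::printf("buffer at 128: +%td\n",
              promotedSmemOperand(smem, 7, 128) - smem);
  return 0;
}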
python/test/unit/language/test_core.py (13 additions, 6 deletions)

@@ -113,18 +113,18 @@ def check_type_supported(dtype):
 class MmaLayout:
     def __init__(self, version, warps_per_cta):
         self.version = version
-        self.warps_per_cta = str(warps_per_cta)
+        self.warps_per_cta = warps_per_cta
 
     def __str__(self):
         return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
 
 
 class BlockedLayout:
     def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
-        self.sz_per_thread = str(size_per_thread)
-        self.threads_per_warp = str(threads_per_warp)
-        self.warps_per_cta = str(warps_per_cta)
-        self.order = str(order)
+        self.sz_per_thread = size_per_thread
+        self.threads_per_warp = threads_per_warp
+        self.warps_per_cta = warps_per_cta
+        self.order = order
 
     def __str__(self):
         return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
@@ -1959,7 +1959,6 @@ def kernel(X, stride_xm, stride_xk,
         out_dtype = tl.float16
     else:
         out_dtype = tl.float32
-
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                          y_tri, y_tri.stride(0), y_tri.stride(1),
                          w_tri, w_tri.stride(0), w_tri.stride(1),
@@ -1974,6 +1973,14 @@ def kernel(X, stride_xm, stride_xk,
                          CHAIN_DOT=epilogue == 'chain-dot',
                          ALLOW_TF32=allow_tf32,
                          num_warps=num_warps)
+    if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
+        ptx = pgm.asm["ptx"]
+        start = ptx.find("shfl.sync")
+        end = ptx.find("cvt.rn.f16.f32")
+        red_code = ptx[start:end]
+        assert len(red_code) > 0
+        assert "shared" not in red_code
+        assert "bar.sync" not in red_code
     # torch result
     if in_dtype == 'int8':
         z_ref = np.matmul(x.astype(np.float32),
