Merge pull request #268 from ROCmSoftwarePlatform/improve_reduce_for_fa
[CHERRY-PICKED FROM UPSTREAM][BACKEND] no longer uses shared mem or barriers for single-warp reductions (openai#1915)
jayfurmanek authored Aug 21, 2023
2 parents d86b19f + d0b7793 commit fa42931
Showing 6 changed files with 83 additions and 33 deletions.
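The upstream idea: when the inter-warp size along the reduced axis is 1, the whole reduction completes with register shuffles inside a single warp, so the shared-memory staging buffer and both bar.sync round trips can be dropped. A minimal standalone CUDA sketch of that pattern (illustrative only, not Triton code; ROCm has the analogous __shfl_xor intrinsic):

  // Butterfly (xor) reduction across the 32 lanes of one warp:
  // no shared memory, no __syncthreads()/bar.sync.
  #include <cstdio>

  __global__ void warpReduceSum(const float *in, float *out) {
    float v = in[threadIdx.x];
    for (int offset = 16; offset > 0; offset >>= 1)
      v += __shfl_xor_sync(0xffffffffu, v, offset);
    if (threadIdx.x == 0)
      *out = v; // after the loop every lane holds the full sum
  }

  int main() {
    float h_in[32], h_out = 0.0f, *d_in, *d_out;
    for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warpReduceSum<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("%f\n", h_out); // prints 32.000000
    return 0;
  }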
16 changes: 11 additions & 5 deletions lib/Analysis/Allocation.cpp
@@ -171,12 +171,18 @@ class AllocationAnalysis {
}
}

template <BufferT::BufferKind T>
void maybeAddScratchBuffer(Operation *op, unsigned bytes) {
if (bytes > 0)
allocation->addBuffer<T>(op, bytes);
}

/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
@@ -200,7 +206,7 @@ class AllocationAnalysis {
srcTy.getElementType().isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
auto value = op->getOperand(0);
// only a scalar value requires scratch memory
@@ -217,7 +223,7 @@ class AllocationAnalysis {
elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
}
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
auto value = op->getOperand(0);
@@ -229,13 +235,13 @@ class AllocationAnalysis {
auto bytes = elemTy.isa<triton::PointerType>()
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
auto *funcAlloc = &(*funcAllocMap)[funcOp];
auto bytes = funcAlloc->getSharedMemorySize();
allocation->addBuffer<BufferT::BufferKind::Virtual>(op, bytes);
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes);
}
}

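The new maybeAddScratchBuffer helper turns zero-byte scratch requests into no-ops, which is what lets a single-warp reduction (whose scratch size is now 0, per the lib/Analysis/Utility.cpp change below) avoid registering any shared-memory buffer at all. A toy model of the guard, with hypothetical names rather than the real Triton API:

  #include <cassert>
  #include <cstdio>
  #include <vector>

  struct ToyAllocation {
    std::vector<unsigned> buffers;
    void addBuffer(unsigned bytes) { buffers.push_back(bytes); }
    void maybeAddBuffer(unsigned bytes) {
      if (bytes > 0) // the essential change: 0-byte scratch is dropped
        addBuffer(bytes);
    }
  };

  int main() {
    ToyAllocation a;
    a.maybeAddBuffer(0);   // single-warp reduction: nothing registered
    a.maybeAddBuffer(128); // multi-warp reduction: scratch registered
    assert(a.buffers.size() == 1 && a.buffers[0] == 128);
    printf("registered %zu buffer(s)\n", a.buffers.size());
    return 0;
  }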
7 changes: 4 additions & 3 deletions lib/Analysis/Utility.cpp
@@ -60,9 +60,10 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {

auto argLayout = getSrcLayout();
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
// if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
// triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
// return {{1, 1}, {1, 1}};

// that case doesn't need inter-warp communication
if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
return {{0, 0}, {0, 0}};

/// shared memory block0
smemShapes[0] = convertType<unsigned>(getSrcShape());
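Returning {{0, 0}, {0, 0}} makes both scratch shapes empty, so the element count, and with it getScratchSizeInBytes(), comes out to 0 and the maybeAddScratchBuffer guard above drops the request. A hedged sketch of that arithmetic (it mirrors, but is not, the Triton helpers):

  #include <algorithm>
  #include <cstdio>
  #include <vector>

  static unsigned product(const std::vector<unsigned> &shape) {
    unsigned p = 1;
    for (unsigned d : shape) p *= d;
    return p;
  }

  int main() {
    std::vector<std::vector<unsigned>> cfg = {{0, 0}, {0, 0}};
    unsigned elems = std::max(product(cfg[0]), product(cfg[1])); // 0
    unsigned bytes = elems * sizeof(float); // 0 for f32 accumulators
    printf("scratch bytes = %u\n", bytes);  // 0 -> no buffer registered
    return 0;
  }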
55 changes: 42 additions & 13 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -349,18 +349,20 @@ struct ReduceOpConversion
unsigned elems = product<unsigned>(smemShapes[0]);
unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));

SmallVector<Value> smemBases(op.getNumOperands());
smemBases[0] = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
smemBases[i] =
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
elemPtrTys[i]);
}

unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();

SmallVector<Value> smemBases(op.getNumOperands());
if (sizeInterWarps > 1) {
smemBases[0] = bitcast(
getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
for (unsigned i = 1; i < op.getNumOperands(); ++i) {
smemBases[i] =
bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
elemPtrTys[i]);
}
}

unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
@@ -418,6 +420,7 @@ struct ReduceOpConversion
Value zero = i32_val(0);
Value laneZero = icmp_eq(laneIdAxis, zero);

std::map<SmallVector<unsigned>, SmallVector<Value>> finalAccs;
for (auto it : accs) {
const SmallVector<unsigned> &key = it.first;
SmallVector<Value> acc = it.second;
@@ -440,8 +443,13 @@
accumulate(rewriter, *combineOp, acc, shfl, false);
}

if (sizeInterWarps == 1) {
finalAccs[key] = acc;
continue;
}

SmallVector<Value> writeIdx = indices[key];
writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
writeIdx[axis] = warpIdAxis;
Value writeOffset =
linearize(rewriter, loc, writeIdx, smemShapes[0], order);
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
@@ -450,6 +458,30 @@
}
}

if (sizeInterWarps == 1) {
SmallVector<Value> results(op.getNumOperands());
for (unsigned i = 0; i < op.getNumOperands(); ++i) {
if (auto resultTy =
op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
unsigned resultElems = getTotalElemsPerThread(resultTy);
SmallVector<SmallVector<unsigned>> resultOffset =
emitOffsetForLayout(resultLayout, resultTy);
SmallVector<Value> resultVals;
for (int j = 0; j < resultElems; j++) {
auto key = resultOffset[j];
key.insert(key.begin() + axis, 0);
resultVals.push_back(finalAccs[key][i]);
}
results[i] = getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
} else
results[i] = finalAccs.begin()->second[i];
}
rewriter.replaceOp(op, results);
return success();
}

barrier();

// The second round of shuffle reduction
@@ -508,9 +540,6 @@
}
}

// We could avoid this barrier in some of the layouts, however this is not
// the general case.
// TODO: optimize the barrier in case the layouts are accepted.
barrier();

// set output values
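In the new sizeInterWarps == 1 fast path, the per-key accumulators produced by the intra-warp shuffles are packed directly into the op's results; the reduced axis is re-inserted as 0 into each result offset to look up the right accumulator in finalAccs. A toy model of that key reconstruction (names are illustrative, not the Triton API):

  #include <cstdio>
  #include <map>
  #include <vector>

  int main() {
    using Key = std::vector<unsigned>;
    const unsigned axis = 1; // the reduced dimension
    // Accumulators are keyed by the full-rank offset, e.g. {row, col}.
    std::map<Key, float> finalAccs = {{{0, 0}, 1.0f}, {{1, 0}, 2.0f}};
    // Result offsets live in the sliced (rank-1) layout: {0} and {1}.
    for (Key key : {Key{0}, Key{1}}) {
      key.insert(key.begin() + axis, 0); // {0} -> {0, 0}, {1} -> {1, 0}
      printf("acc = %f\n", finalAccs[key]);
    }
    return 0;
  }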
16 changes: 11 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -213,17 +213,23 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
// of shared memory and append it to the operands of the callOp.
auto loc = callOp.getLoc();
auto caller = callOp->getParentOfType<FunctionOpInterface>();
auto base = allocation.getFunctionSharedMemoryBase(caller);
auto *funcAllocation = allocation.getFuncData(caller);
auto bufferId = funcAllocation->getBufferId(callOp);
auto offset = funcAllocation->getOffset(bufferId);
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()),
NVVM::kSharedMemorySpace);
auto offsetValue = gep(ptrTy, base, i32_val(offset));
auto promotedOperands = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter);
auto base = allocation.getFunctionSharedMemoryBase(caller);
auto *funcAllocation = allocation.getFuncData(caller);
auto bufferId = funcAllocation->getBufferId(callOp);
// function doesn't have a shared mem buffer
if (bufferId == (size_t)-1) {
promotedOperands.push_back(base);
return promotedOperands;
}
// function has a shared mem buffer
auto offset = funcAllocation->getOffset(bufferId);
auto offsetValue = gep(ptrTy, base, i32_val(offset));
promotedOperands.push_back(offsetValue);
return promotedOperands;
}
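The call lowering now treats a callee without any shared-memory buffer (the bufferId == -1 sentinel) specially: the caller forwards its raw shared-memory base instead of computing base + offset. A toy model of that branch, with hypothetical names:

  #include <cstddef>
  #include <cstdio>

  char *promoteSharedBase(char *base, size_t bufferId, unsigned offset) {
    if (bufferId == (size_t)-1)
      return base;        // callee uses no scratch: pass the base through
    return base + offset; // callee's buffer lives at base + offset
  }

  int main() {
    char smem[256];
    printf("%p\n", (void *)promoteSharedBase(smem, (size_t)-1, 64)); // == smem
    printf("%p\n", (void *)promoteSharedBase(smem, 0, 64));          // smem + 64
    return 0;
  }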
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -139,6 +139,7 @@ Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
}
auto swait = builder.create("s_waitcnt lgkmcnt(0)");
(*swait)();
return builder.launch(rewriter, loc, val.getType(), true);
#else
PTXBuilder builder;
auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
@@ -148,8 +149,8 @@
auto *cOpr = builder.newConstantOperand("0x1f");
auto *maskOpr = builder.newConstantOperand("0xffffffff");
shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
#endif
return builder.launch(rewriter, loc, val.getType(), false);
#endif
}

Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
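This fix gives the GCN path its own return inside the #if branch instead of falling through to the launch call after #endif, whose trailing flag was only right for the PTX path. A stripped-down model of the control-flow change (the macro name and the meaning of the boolean flag are assumptions here, not the exact Triton ones):

  #include <cstdio>

  int launch(bool hasSideEffect) { return hasSideEffect ? 1 : 2; }

  int shflSyncModel() {
  #ifdef USE_ROCM
    return launch(true);  // GCN path now returns inside its own branch
  #else
    return launch(false); // PTX path keeps the flag it always used
  #endif
  }

  int main() {
    printf("%d\n", shflSyncModel());
    return 0;
  }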
19 changes: 13 additions & 6 deletions python/test/unit/language/test_core.py
@@ -113,18 +113,18 @@ def check_type_supported(dtype):
class MmaLayout:
def __init__(self, version, warps_per_cta):
self.version = version
self.warps_per_cta = str(warps_per_cta)
self.warps_per_cta = warps_per_cta

def __str__(self):
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"


class BlockedLayout:
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
self.sz_per_thread = str(size_per_thread)
self.threads_per_warp = str(threads_per_warp)
self.warps_per_cta = str(warps_per_cta)
self.order = str(order)
self.sz_per_thread = size_per_thread
self.threads_per_warp = threads_per_warp
self.warps_per_cta = warps_per_cta
self.order = order

def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
@@ -1959,7 +1959,6 @@ def kernel(X, stride_xm, stride_xk,
out_dtype = tl.float16
else:
out_dtype = tl.float32

pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
y_tri, y_tri.stride(0), y_tri.stride(1),
w_tri, w_tri.stride(0), w_tri.stride(1),
@@ -1974,6 +1973,14 @@
CHAIN_DOT=epilogue == 'chain-dot',
ALLOW_TF32=allow_tf32,
num_warps=num_warps)
if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
ptx = pgm.asm["ptx"]
start = ptx.find("shfl.sync")
end = ptx.find("cvt.rn.f16.f32")
red_code = ptx[start:end]
assert len(red_code) > 0
assert "shared" not in red_code
assert "bar.sync" not in red_code
# torch result
if in_dtype == 'int8':
z_ref = np.matmul(x.astype(np.float32),
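The new assertions scan the PTX between the first shfl.sync and the following f32 -> f16 convert and require that this reduction region touch neither shared memory nor bar.sync. A self-contained C++ rendering of the same check against a made-up PTX snippet (the test itself stays in Python):

  #include <cassert>
  #include <string>

  int main() {
    std::string ptx = "shfl.sync.bfly.b32 ... add.f32 ... cvt.rn.f16.f32 ...";
    auto start = ptx.find("shfl.sync");
    auto end = ptx.find("cvt.rn.f16.f32");
    std::string red = ptx.substr(start, end - start);
    assert(!red.empty());
    assert(red.find("shared") == std::string::npos);   // no ld.shared/st.shared
    assert(red.find("bar.sync") == std::string::npos); // no barriers
    return 0;
  }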
