diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 4d6db9638ed7..88d98d4645d0 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -171,12 +171,18 @@ class AllocationAnalysis {
     }
   }
 
+  template <BufferT::BufferKind T>
+  void maybeAddScratchBuffer(Operation *op, unsigned bytes) {
+    if (bytes > 0)
+      allocation->addBuffer<T>(op, bytes);
+  }
+
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
     if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
       ReduceOpHelper helper(reduceOp);
       unsigned bytes = helper.getScratchSizeInBytes();
-      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
       auto srcTy = cvtLayout.getSrc().getType().cast<RankedTensorType>();
       auto dstTy = cvtLayout.getResult().getType().cast<RankedTensorType>();
@@ -200,7 +206,7 @@
           srcTy.getElementType().isa<triton::PointerType>()
               ? elems * kPtrBitWidth / 8
               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
-      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
       auto value = op->getOperand(0);
       // only scalar requires scratch memory
@@ -217,7 +223,7 @@
             elemTy.isa<triton::PointerType>()
                 ? elems * kPtrBitWidth / 8
                 : elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
-        allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+        maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
       }
     } else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
       auto value = op->getOperand(0);
@@ -229,13 +235,13 @@
       auto bytes = elemTy.isa<triton::PointerType>()
                        ? elems * kPtrBitWidth / 8
                        : elems * elemTy.getIntOrFloatBitWidth() / 8;
-      allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes);
     } else if (auto callOp = dyn_cast<triton::CallOp>(op)) {
       auto callable = callOp.resolveCallable();
       auto funcOp = dyn_cast<FunctionOpInterface>(callable);
       auto *funcAlloc = &(*funcAllocMap)[funcOp];
       auto bytes = funcAlloc->getSharedMemorySize();
-      allocation->addBuffer<BufferT::BufferKind::Virtual>(op, bytes);
+      maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes);
     }
   }
 
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index 6cb791de41c1..92f7d9b6aaca 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -60,9 +60,10 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
   auto argLayout = getSrcLayout();
   auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
-  // if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
-  //     triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
-  //   return {{1, 1}, {1, 1}};
+
+  // that case doesn't need inter-warp communication
+  if (isFastReduction() && triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
+    return {{0, 0}, {0, 0}};
 
   /// shared memory block0
   smemShapes[0] = convertType<unsigned>(getSrcShape());
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 087ce718849c..f55014ca419f 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -349,18 +349,20 @@ struct ReduceOpConversion
     unsigned elems = product<unsigned>(smemShapes[0]);
     unsigned maxElems = std::max(elems, product<unsigned>(smemShapes[1]));
 
-    SmallVector<Value> smemBases(op.getNumOperands());
-    smemBases[0] = bitcast(
-        getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
-    for (unsigned i = 1; i < op.getNumOperands(); ++i) {
-      smemBases[i] =
-          bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
-                  elemPtrTys[i]);
-    }
-
     unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
     unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
 
+    SmallVector<Value> smemBases(op.getNumOperands());
+    if (sizeInterWarps > 1) {
+      smemBases[0] = bitcast(
+          getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]);
+      for (unsigned i = 1; i < op.getNumOperands(); ++i) {
+        smemBases[i] =
+            bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(maxElems)),
+                    elemPtrTys[i]);
+      }
+    }
+
     unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
     auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
     auto srcValues = unpackInputs(loc, op, adaptor, rewriter);
@@ -418,6 +420,7 @@
     Value zero = i32_val(0);
     Value laneZero = icmp_eq(laneIdAxis, zero);
 
+    std::map<SmallVector<unsigned>, SmallVector<Value>> finalAccs;
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> acc = it.second;
@@ -440,8 +443,13 @@
         accumulate(rewriter, *combineOp, acc, shfl, false);
       }
 
+      if (sizeInterWarps == 1) {
+        finalAccs[key] = acc;
+        continue;
+      }
+
       SmallVector<Value> writeIdx = indices[key];
-      writeIdx[axis] = (sizeInterWarps == 1) ? zero : warpIdAxis;
+      writeIdx[axis] = warpIdAxis;
       Value writeOffset =
           linearize(rewriter, loc, writeIdx, smemShapes[0], order);
       for (unsigned i = 0; i < op.getNumOperands(); ++i) {
@@ -450,6 +458,30 @@
       }
     }
 
+    if (sizeInterWarps == 1) {
+      SmallVector<Value> results(op.getNumOperands());
+      for (unsigned i = 0; i < op.getNumOperands(); ++i) {
+        if (auto resultTy =
+                op.getResult()[i].getType().dyn_cast<RankedTensorType>()) {
+          auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
+          unsigned resultElems = getTotalElemsPerThread(resultTy);
+          SmallVector<SmallVector<unsigned>> resultOffset =
+              emitOffsetForLayout(resultLayout, resultTy);
+          SmallVector<Value> resultVals;
+          for (int j = 0; j < resultElems; j++) {
+            auto key = resultOffset[j];
+            key.insert(key.begin() + axis, 0);
+            resultVals.push_back(finalAccs[key][i]);
+          }
+          results[i] = getTypeConverter()->packLLElements(loc, resultVals,
+                                                          rewriter, resultTy);
+        } else
+          results[i] = finalAccs.begin()->second[i];
+      }
+      rewriter.replaceOp(op, results);
+      return success();
+    }
+
     barrier();
 
     // The second round of shuffle reduction
@@ -508,9 +540,6 @@
       }
     }
 
-    // We could avoid this barrier in some of the layouts, however this is not
-    // the general case.
-    // TODO: optimize the barrier in case the layouts are accepted.
     barrier();
 
     // set output values
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
index 335e8539ac49..c0369d4c9ad3 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -213,17 +213,23 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
     // of shared memory and append it to the operands of the callOp.
     auto loc = callOp.getLoc();
     auto caller = callOp->getParentOfType<FunctionOpInterface>();
-    auto base = allocation.getFunctionSharedMemoryBase(caller);
-    auto *funcAllocation = allocation.getFuncData(caller);
-    auto bufferId = funcAllocation->getBufferId(callOp);
-    auto offset = funcAllocation->getOffset(bufferId);
     auto ptrTy = LLVM::LLVMPointerType::get(
         this->getTypeConverter()->convertType(rewriter.getI8Type()),
         NVVM::kSharedMemorySpace);
-    auto offsetValue = gep(ptrTy, base, i32_val(offset));
     auto promotedOperands = this->getTypeConverter()->promoteOperands(
         callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
         adaptor.getOperands(), rewriter);
+    auto base = allocation.getFunctionSharedMemoryBase(caller);
+    auto *funcAllocation = allocation.getFuncData(caller);
+    auto bufferId = funcAllocation->getBufferId(callOp);
+    // function doesn't have a shared mem buffer
+    if (bufferId == (size_t)-1) {
+      promotedOperands.push_back(base);
+      return promotedOperands;
+    }
+    // function has a shared mem buffer
+    auto offset = funcAllocation->getOffset(bufferId);
+    auto offsetValue = gep(ptrTy, base, i32_val(offset));
     promotedOperands.push_back(offsetValue);
     return promotedOperands;
   }
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 759f082dea51..6443a15250be 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -113,7 +113,7 @@ def check_type_supported(dtype):
 class MmaLayout:
     def __init__(self, version, warps_per_cta):
         self.version = version
-        self.warps_per_cta = str(warps_per_cta)
+        self.warps_per_cta = warps_per_cta
 
     def __str__(self):
         return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
@@ -121,10 +121,10 @@ def __str__(self):
 
 class BlockedLayout:
     def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
-        self.sz_per_thread = str(size_per_thread)
-        self.threads_per_warp = str(threads_per_warp)
-        self.warps_per_cta = str(warps_per_cta)
-        self.order = str(order)
+        self.sz_per_thread = size_per_thread
+        self.threads_per_warp = threads_per_warp
+        self.warps_per_cta = warps_per_cta
+        self.order = order
 
     def __str__(self):
         return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
@@ -1959,7 +1959,6 @@ def kernel(X, stride_xm, stride_xk,
         out_dtype = tl.float16
     else:
         out_dtype = tl.float32
-
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                          y_tri, y_tri.stride(0), y_tri.stride(1),
                          w_tri, w_tri.stride(0), w_tri.stride(1),
@@ -1974,6 +1973,14 @@ def kernel(X, stride_xm, stride_xk,
                          CHAIN_DOT=epilogue == 'chain-dot',
                          ALLOW_TF32=allow_tf32,
                          num_warps=num_warps)
+    if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
+        ptx = pgm.asm["ptx"]
+        start = ptx.find("shfl.sync")
+        end = ptx.find("cvt.rn.f16.f32")
+        red_code = ptx[start:end]
+        assert len(red_code) > 0
+        assert "shared" not in red_code
+        assert "bar.sync" not in red_code
     # torch result
     if in_dtype == 'int8':
         z_ref = np.matmul(x.astype(np.float32),