diff --git a/test/Proton/Conversion/protongpu_to_llvm.mlir b/test/Proton/Conversion/protongpu_to_llvm.mlir new file mode 100644 index 000000000000..3123cec32b01 --- /dev/null +++ b/test/Proton/Conversion/protongpu_to_llvm.mlir @@ -0,0 +1,12 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s --dump-input-context 20 + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) + tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { + // CHECK: %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %1 = llvm.getelementptr %arg3[%0] : (!llvm.ptr<1>, i32) + %1 = proton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr + // CHECK: llvm.return + tt.return + } +} // end module diff --git a/third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonOpToLLVM.h b/third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonOpToLLVM.h index f7ec7028d880..39e4fc11ee85 100644 --- a/third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonOpToLLVM.h +++ b/third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonOpToLLVM.h @@ -12,6 +12,11 @@ void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateGlobalScratchAllocOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + void populateProtonOpPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, diff --git a/third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/Utility.h b/third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/Utility.h new file mode 100644 index 000000000000..f49fe5bd5865 --- /dev/null +++ b/third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/Utility.h @@ -0,0 +1,36 @@ +#ifndef TRITON_CONVERSION_PROTONGPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_PROTONGPU_TO_LLVM_UTILITY_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton { +namespace proton { +namespace gpu { + +Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter, + FunctionOpInterface funcOp, Value allocOffset = {}) { + auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1); + ModuleOp mod = funcOp.getOperation()->getParentOfType(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto *ctx = rewriter.getContext(); + Value bufferPtr = b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, + gmemBase, allocOffset); + return bufferPtr; +} +} // namespace gpu +} // namespace proton +} // namespace triton +} // namespace mlir +#endif diff --git a/third_party/proton/dialect/lib/ProtonGPUToLLVM/CMakeLists.txt b/third_party/proton/dialect/lib/ProtonGPUToLLVM/CMakeLists.txt index 2cdd3d5cb888..17f388268b7a 100644 --- a/third_party/proton/dialect/lib/ProtonGPUToLLVM/CMakeLists.txt +++ b/third_party/proton/dialect/lib/ProtonGPUToLLVM/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(ProtonGPUToLLVM RecordOpToLLVM.cpp ProtonGPUToLLVM.cpp + GlobalScratchAllocOpToLLVM.cpp LINK_LIBS PUBLIC ProtonIR diff --git a/third_party/proton/dialect/lib/ProtonGPUToLLVM/GlobalScratchAllocOpToLLVM.cpp b/third_party/proton/dialect/lib/ProtonGPUToLLVM/GlobalScratchAllocOpToLLVM.cpp new file mode 100644 index 000000000000..e53b1bce60e3 --- /dev/null +++ b/third_party/proton/dialect/lib/ProtonGPUToLLVM/GlobalScratchAllocOpToLLVM.cpp @@ -0,0 +1,111 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +#include "third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonOpToLLVM.h" +#include "third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/Utility.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" +#include "third_party/proton/dialect/include/Dialect/ProtonGPU/IR/Dialect.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { +using namespace mlir; +using namespace mlir::triton; + +triton::FuncOp amendFuncOp(LLVM::LLVMFuncOp funcOp, + ConversionPatternRewriter &rewriter, + const TargetInfoBase &targetInfo) { + auto moduleOp = funcOp->getParentOfType(); + Location loc = moduleOp->getLoc(); + auto ctx = funcOp->getContext(); + auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1); + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector(funcOp.getArgumentTypes()); + unsigned oldNumArgs = amendedInputTy.size(); + amendedInputTy.push_back(globalPtrTy); + auto amendedFuncTy = + FunctionType::get(ctx, amendedInputTy, funcOp.getResultTypes()); + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy); + auto ®ion = funcOp.getBody(); + region.addArgument(globalPtrTy, amendedFuncOp.getLoc()); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + IRMapping mapper; + if (auto argAttrs = funcOp.getAllArgAttrs()) { + SmallVector newArgAttrs; + newArgAttrs.reserve(amendedInputTy.size()); + for (unsigned i = 0; i != oldNumArgs; ++i) + if (!mapper.contains(funcOp.getArgument(i))) + newArgAttrs.push_back(argAttrs[i]); + amendedFuncOp.setAllArgAttrs(newArgAttrs); + } + return amendedFuncOp; +} + +struct GlobalScratchAllocOpConversion + : public ConvertOpToLLVMPattern { + explicit GlobalScratchAllocOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern( + typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(proton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.eraseOp(op); + auto funcOp = op->getParentOfType(); + auto moduleOp = op->getParentOfType(); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + auto amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo); + FailureOr maybeNewFuncOp = + mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter, + *getTypeConverter()); + if (failed(maybeNewFuncOp)) { + return failure(); + } + + LLVM::LLVMFuncOp newFuncOp = *maybeNewFuncOp; + auto ctx = funcOp->getContext(); + newFuncOp->setAttr("nvvm.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + newFuncOp.setLinkage(LLVM::Linkage::External); + rewriter.eraseOp(funcOp); + rewriter.eraseOp(amendedFuncOp); + rewriter.setInsertionPointToStart(&newFuncOp.getBody().front()); + auto loc = amendedFuncOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + auto gmemBase = newFuncOp.getArgument(newFuncOp.getNumArguments() - 1); + // TODO: implement this offset value + auto opOffset = 0; + auto llvmPointerType = LLVM::LLVMPointerType::get(ctx); + rewriter.create(loc, llvmPointerType, llvmPointerType, + gmemBase, b.i32_val(opOffset)); + return success(); + } + +protected: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +void mlir::triton::proton::populateGlobalScratchAllocOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, + benefit); +} diff --git a/third_party/proton/dialect/lib/ProtonGPUToLLVM/ProtonGPUToLLVM.cpp b/third_party/proton/dialect/lib/ProtonGPUToLLVM/ProtonGPUToLLVM.cpp index a24719bbbf95..2af9d863ae92 100644 --- a/third_party/proton/dialect/lib/ProtonGPUToLLVM/ProtonGPUToLLVM.cpp +++ b/third_party/proton/dialect/lib/ProtonGPUToLLVM/ProtonGPUToLLVM.cpp @@ -15,6 +15,8 @@ void populateProtonOpPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit) { + populateGlobalScratchAllocOpToLLVMPattern(typeConverter, patterns, targetInfo, + benefit); populateRecordOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); }