Skip to content

Commit

Permalink
implement proton gpu global scratch alloc to llvm lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
CRobeck committed Feb 25, 2025
1 parent fa08afa commit 592a7f1
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 0 deletions.
12 changes: 12 additions & 0 deletions test/Proton/Conversion/protongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Lowering test for proton_gpu.global_scratch_alloc through the
// --convert-triton-gpu-to-llvm pipeline of triton-opt.
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s --dump-input-context 20

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
// The source kernel takes two arguments; the lowered llvm.func is expected to
// carry two additional !llvm.ptr<1> arguments after them.
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>)
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// The allocation is expected to become a getelementptr at offset 0 based on
// the last function argument (%arg3).
// CHECK: %0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: %1 = llvm.getelementptr %arg3[%0] : (!llvm.ptr<1>, i32)
%1 = proton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr<i32>
// CHECK: llvm.return
tt.return
}
} // end module
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

/// Adds the conversion pattern for the proton GPU global-scratch-alloc op to
/// `patterns`. `typeConverter`, `targetInfo` and `benefit` are forwarded to
/// the pattern being registered.
void populateGlobalScratchAllocOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateProtonOpPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef TRITON_CONVERSION_PROTONGPU_TO_LLVM_UTILITY_H
#define TRITON_CONVERSION_PROTONGPU_TO_LLVM_UTILITY_H

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton;

namespace mlir {
namespace triton {
namespace proton {
namespace gpu {

/// Returns a pointer into the global scratch buffer of `funcOp`.
///
/// The buffer base is taken to be the last argument of `funcOp`, so `funcOp`
/// must have at least one argument.
///
/// \param loc         Location attached to any op created here.
/// \param rewriter    Rewriter used to build the address computation.
/// \param funcOp      Function whose last argument is the scratch base.
/// \param allocOffset Optional offset, in units of i8, added to the base.
///                    When omitted (null), the base pointer itself is
///                    returned and no op is created.
/// \returns The base pointer, or a GEP of type `!llvm.ptr<1>` off the base.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit violates the ODR.
inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
                                 FunctionOpInterface funcOp,
                                 Value allocOffset = {}) {
  Value gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
  // A null offset must not be handed to the GEP builder as an index.
  if (!allocOffset)
    return gmemBase;

  auto b = TritonLLVMOpBuilder(loc, rewriter);
  auto *ctx = rewriter.getContext();
  return b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase,
               allocOffset);
}
} // namespace gpu
} // namespace proton
} // namespace triton
} // namespace mlir
#endif
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_triton_library(ProtonGPUToLLVM
RecordOpToLLVM.cpp
ProtonGPUToLLVM.cpp
GlobalScratchAllocOpToLLVM.cpp

LINK_LIBS PUBLIC
ProtonIR
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

#include "third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/PatternProtonOpToLLVM.h"
#include "third_party/proton/dialect/include/Conversion/ProtonGPUToLLVM/Utility.h"
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
#include "third_party/proton/dialect/include/Dialect/ProtonGPU/IR/Dialect.h"

namespace mlir {
// Forward declaration only: the definition is not part of this file. It is
// called below to turn the amended triton::FuncOp into an LLVM::LLVMFuncOp,
// and reports failure through the FailureOr return value.
FailureOr<LLVM::LLVMFuncOp>
convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &converter);
}

namespace {
using namespace mlir;
using namespace mlir::triton;

/// Builds a `triton::FuncOp` that mirrors `funcOp` but takes one extra
/// trailing argument: a pointer in address space 1 (`!llvm.ptr<1>`).
///
/// The body of `funcOp` is moved (not cloned) into the new function, so
/// `funcOp` is left without a body and is expected to be erased by the caller.
/// The new op is created at the rewriter's current insertion point.
///
/// \param funcOp     Function to amend; its region gains the extra block
///                   argument before being moved.
/// \param rewriter   Rewriter used to create the new function and move the
///                   region.
/// \param targetInfo Currently unused.
/// \returns The newly created function carrying the extra argument.
triton::FuncOp amendFuncOp(LLVM::LLVMFuncOp funcOp,
                           ConversionPatternRewriter &rewriter,
                           const TargetInfoBase &targetInfo) {
  auto *ctx = funcOp->getContext();
  auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1);

  // New signature: the original inputs followed by the extra pointer.
  auto amendedInputTy = llvm::to_vector(funcOp.getArgumentTypes());
  amendedInputTy.push_back(globalPtrTy);
  auto amendedFuncTy =
      FunctionType::get(ctx, amendedInputTy, funcOp.getResultTypes());
  auto amendedFuncOp = rewriter.create<triton::FuncOp>(
      funcOp.getLoc(), funcOp.getName(), amendedFuncTy);

  // Add the matching block argument, then hand the body over to the new op.
  auto &region = funcOp.getBody();
  region.addArgument(globalPtrTy, amendedFuncOp.getLoc());
  rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
                              amendedFuncOp.end());

  // Carry over the per-argument attributes. The list must have exactly one
  // entry per argument of the amended function, so the new trailing argument
  // gets an empty dictionary.
  if (auto argAttrs = funcOp.getAllArgAttrs()) {
    SmallVector<Attribute> newArgAttrs(argAttrs.begin(), argAttrs.end());
    newArgAttrs.resize(amendedInputTy.size(), DictionaryAttr::get(ctx));
    amendedFuncOp.setAllArgAttrs(newArgAttrs);
  }
  return amendedFuncOp;
}

/// Lowers `proton::gpu::GlobalScratchAllocOp` to an LLVM pointer into a
/// global scratch buffer.
///
/// The enclosing `LLVM::LLVMFuncOp` is rebuilt with one extra trailing
/// `!llvm.ptr<1>` argument (see `amendFuncOp`), and the op is replaced by a
/// `getelementptr` off that argument, emitted at the start of the new
/// function's entry block.
///
/// NOTE(review): every matched op rebuilds the function and appends another
/// argument, so this presumably expects a single alloc per function — confirm.
struct GlobalScratchAllocOpConversion
    : public ConvertOpToLLVMPattern<proton::gpu::GlobalScratchAllocOp> {
  explicit GlobalScratchAllocOpConversion(LLVMTypeConverter &typeConverter,
                                          const TargetInfoBase &targetInfo,
                                          PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<proton::gpu::GlobalScratchAllocOp>(
            typeConverter, benefit),
        targetInfo(targetInfo) {}

  /// Rewrites `op` as described on the struct.
  ///
  /// \returns failure when `op` is not nested in an `LLVM::LLVMFuncOp`, or
  ///          when the amended function cannot be converted; success
  ///          otherwise.
  LogicalResult
  matchAndRewrite(proton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query everything needed from `op` and its parents before any of them is
    // scheduled for erasure.
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!funcOp)
      return rewriter.notifyMatchFailure(
          op, "expected to be nested in an llvm.func");
    auto moduleOp = op->getParentOfType<ModuleOp>();
    auto *ctx = funcOp->getContext();

    // Rebuild the function with the extra scratch-pointer argument.
    rewriter.setInsertionPointToStart(moduleOp.getBody());
    auto amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo);
    Location loc = amendedFuncOp.getLoc();
    FailureOr<LLVM::LLVMFuncOp> maybeNewFuncOp =
        mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter,
                                        *getTypeConverter());
    if (failed(maybeNewFuncOp)) {
      return failure();
    }

    LLVM::LLVMFuncOp newFuncOp = *maybeNewFuncOp;
    newFuncOp->setAttr("nvvm.kernel",
                       rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
    newFuncOp.setLinkage(LLVM::Linkage::External);
    // Both predecessors of `newFuncOp` are obsolete now.
    rewriter.eraseOp(funcOp);
    rewriter.eraseOp(amendedFuncOp);

    // Compute the scratch pointer off the newly added (last) argument.
    rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    Value gmemBase = newFuncOp.getArgument(newFuncOp.getNumArguments() - 1);
    // TODO: implement this offset value
    auto opOffset = 0;
    // The result keeps the base pointer's type (and thus its address space);
    // the offset is expressed in bytes.
    Value scratchPtr =
        rewriter.create<LLVM::GEPOp>(loc, gmemBase.getType(),
                                     rewriter.getIntegerType(8), gmemBase,
                                     b.i32_val(opOffset));
    // Replace rather than erase so that users of the op's result are rewired
    // to the computed pointer.
    rewriter.replaceOp(op, scratchPtr);
    return success();
  }

protected:
  const TargetInfoBase &targetInfo;
};

} // namespace

// Registers the GlobalScratchAllocOp -> LLVM conversion pattern in `patterns`,
// forwarding the type converter, target info and benefit to its constructor.
void mlir::triton::proton::populateGlobalScratchAllocOpToLLVMPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  using ScratchAllocPattern = GlobalScratchAllocOpConversion;
  patterns.add<ScratchAllocPattern>(typeConverter, targetInfo, benefit);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ void populateProtonOpPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit) {
populateGlobalScratchAllocOpToLLVMPattern(typeConverter, patterns, targetInfo,
benefit);
populateRecordOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit);
}

Expand Down

0 comments on commit 592a7f1

Please sign in to comment.