diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h
index 8ce76e9e44cb..55061dacefec 100644
--- a/bin/RegisterTritonDialects.h
+++ b/bin/RegisterTritonDialects.h
@@ -50,6 +50,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::gpu::registerAllocateSharedMemoryPass();
   mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();
   mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass();
+  mlir::triton::registerConvertWarpSpecializeToLLVM();
   mlir::triton::registerConvertTritonGPUToLLVMPass();
   mlir::triton::registerConvertNVGPUToLLVMPass();
   mlir::registerLLVMDIScope();
@@ -71,7 +72,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
   mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
 
-  // TODO: register Triton & TritonGPU passes
   registry.insert<...>();
diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ ... @@
+// Wraps an IRRewriter together with TritonLLVMOpBuilder, providing an
+// implicit location for op creation.
+struct TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
+  // Construct the rewriter with an implicit location, forwarding the
+  // remaining arguments to the IRRewriter constructor.
+  template <typename... Args>
+  TritonLLVMIRRewriter(Location loc, Args &&...args)
+      : IRRewriter(std::forward<Args>(args)...),
+        TritonLLVMOpBuilder(loc, *this) {}
+
+  // Get the implicit location.
+  Location getLoc() const { return loc; }
+  // Set the implicit location used to build ops.
+  void setLoc(Location loc) { this->loc = loc; }
+
+  // Wrapper for op creation that passes an implicit location.
+  template <typename OpTy, typename... Args>
+  OpTy create(Args &&...args) {
+    return OpBuilder::create<OpTy>(loc, std::forward<Args>(args)...);
+  }
+};
 } // namespace mlir::triton
 
 // Types
@@ -548,8 +572,10 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
                                  const TargetInfoBase &target, Operation *op) {
   auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
                                           target.getSharedAddressSpace());
-  FunctionOpInterface func =
-      op->template getParentOfType<FunctionOpInterface>();
+  auto func = op->template getParentOfType<FunctionOpInterface>();
+  if (!func)
+    func = cast<FunctionOpInterface>(op);
+
   assert(op->hasAttr("allocation.offset"));
   size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
                       .getValue()
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index c98d87d65f2d..36cd929e9c11 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -385,6 +385,11 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
   let extraClassDeclaration = [{
     RegionRange getPartitionRegions();
+
+    // Get the size and alignment of the capture list.
+    std::pair<uint64_t, uint64_t> getCaptureSizeAlign();
+    // Get the total number of extra warps required.
+    unsigned getTotalPartitionWarps();
   }];
 
   let hasVerifier = 1;
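Aside: a minimal sketch of how the new implicit-location rewriter is meant to be used. The `anchor` op and the constant being built are illustrative, not part of this diff; the constructor call mirrors the usage that appears later in `ConvertWarpSpecializeToLLVM.cpp`.

```cpp
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

// Sketch only: build ops without repeating a Location at every call site.
void example(mlir::Operation *anchor) {
  mlir::triton::TritonLLVMIRRewriter b(anchor->getLoc(), anchor);
  // `create` forwards b.getLoc() implicitly.
  auto c1 = b.create<mlir::LLVM::ConstantOp>(b.getI32Type(),
                                             b.getI32IntegerAttr(1));
  (void)c1;
}
```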
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 9d2be5b5fa0b..79dce758d8ac 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -283,6 +283,24 @@ class AllocationAnalysis {
                             scratchAlignment);
       return;
     }
+    if (auto ws = dyn_cast<gpu::WarpSpecializeOp>(op)) {
+      // `ttg.warp_specialize` needs memory to pass its explicit captures.
+      // Pack the captures like a struct.
+      auto [captureSize, captureAlign] = ws.getCaptureSizeAlign();
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, captureSize,
+                                                          captureAlign);
+      return;
+    }
+    if (auto func = dyn_cast<FunctionOpInterface>(op)) {
+      unsigned numWarpIndices = 0;
+      // Warp specialization communicates states over shared memory to each
+      // warp. Add space for an i8 for each warpgroup warp.
+      func.walk([&](gpu::WarpSpecializeOp op) {
+        numWarpIndices = std::max(numWarpIndices, op.getTotalPartitionWarps());
+      });
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, numWarpIndices);
+      return;
+    }
     unsigned bytes = scratchSizeGetter(op);
     maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                         scratchAlignment);
@@ -374,10 +392,19 @@ class AllocationAnalysis {
     // Analyze liveness of scratch buffers and virtual buffers.
     auto processScratchMemory = [&](const auto &container) {
       for (auto [op, buffer] : container) {
+        // Buffers owned by the function are assumed live for the whole
+        // function. This memory is used for warp specialization codegen.
+        // FIXME: Spooky-action-at-a-distance. Find a better way to model this.
+        if (op == operation) {
+          bufferRange.insert(
+              {buffer, Interval(size_t(), std::numeric_limits<size_t>::max())});
+          continue;
+        }
+
         // Any scratch memory's live range is the current operation's live
         // range.
-        bufferRange.insert({buffer, Interval(operationId.lookup(op),
-                                             operationId.lookup(op) + 1)});
+        bufferRange.insert(
+            {buffer, Interval(operationId.at(op), operationId.at(op) + 1)});
         LLVM_DEBUG({
           llvm::dbgs() << "-- buffer " << buffer->id << "; value: ";
           op->dump();
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index a81f391723ec..126e3fc310df 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1295,8 +1295,15 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
 void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) {
   std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
   AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
-  if (failed(solver->initializeAndRun(funcOp)))
+  WalkResult result = funcOp.walk([&](Operation *op) {
+    if (op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+        failed(solver->initializeAndRun(op)))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+  if (result.wasInterrupted())
     return;
+
   auto *axisInfoMap = getFuncData(funcOp);
   auto updateAxisInfoMap = [&](Value value) {
     auto axisInfo = analysis->getLatticeElement(value)->getValue();
diff --git a/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp b/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp
index df9ae0f40538..c4aadba79d92 100644
--- a/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp
@@ -23,7 +23,7 @@ struct AllocateWarpGroups
     int maxExtraWarps = 0;
     mod.walk([&](WarpSpecializeOp op) {
       ArrayRef<int32_t> arr = op.getPartitionNumWarps();
-      int req = std::accumulate(arr.begin(), arr.end(), 0, std::plus{});
+      int req = op.getTotalPartitionWarps();
       maxExtraWarps = std::max(maxExtraWarps, req);
 
       // Allocate the start IDs such that the largest warpgroups have lower
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index c8efcc3f675e..5f0368401a16 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -1,5 +1,4 @@
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
-#include "mlir/Interfaces/DataLayoutInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
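To make the warp arithmetic concrete: `ttg.total-num-warps` is the default warp count plus the maximum extra-warp requirement over all `ttg.warp_specialize` ops. Partitions of a single op run concurrently (so they sum), while different ops execute at different times (so they share the same extra warps). A hedged sketch with illustrative names, not the pass's actual code:

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Mirrors the intent of AllocateWarpGroups (illustrative only).
int totalNumWarps(int defaultNumWarps,
                  const std::vector<std::vector<int>> &partitionsPerWsOp) {
  int maxExtraWarps = 0;
  for (const auto &partitions : partitionsPerWsOp) {
    // Partitions of one `ttg.warp_specialize` run concurrently: sum them.
    int req = std::accumulate(partitions.begin(), partitions.end(), 0);
    // Different ops run at different times: they reuse the extra warps.
    maxExtraWarps = std::max(maxExtraWarps, req);
  }
  return defaultNumWarps + maxExtraWarps;
}
// totalNumWarps(4, {{4, 2, 1}}) == 11, matching "ttg.total-num-warps" = 11
// in the lowering tests below.
```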
diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
index 055577cfaf5f..7e9d35e1a685 100644
--- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
@@ -168,6 +168,9 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
     // Set an attribute for reqntidx, it could be used in latter LLVM codegen
     // for `nvvm.annotation` metadata.
     int numWarps = triton::gpu::lookupNumWarps(funcOp);
+    if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType<IntegerAttr>(
+            "ttg.total-num-warps"))
+      numWarps = totalNumWarps.getInt();
 
     newFuncOp->setAttr("nvvm.reqntid",
                        rewriter.getDenseI32ArrayAttr(32 * numWarps));
diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp
index 6df6aed87f79..b118e941b3f9 100644
--- a/lib/Dialect/TritonGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -1,4 +1,5 @@
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/DebugStringHelper.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
@@ -770,4 +771,36 @@ LogicalResult WarpYieldOp::verify() {
   return success();
 }
 
+// Get the size of a scalar type when stored in shared memory.
+// TODO: Generalize this as needed.
+static size_t getSharedMemorySize(Type type) {
+  if (isa<IntegerType, FloatType>(type))
+    return llvm::divideCeil(type.getIntOrFloatBitWidth(), 8);
+  if (isa<PointerType>(type))
+    return 8;
+  if (auto desc = dyn_cast<MemDescType>(type)) {
+    if (!isa<SharedMemorySpaceAttr>(desc.getMemorySpace()))
+      return 8;
+    return 8 + desc.getRank() * 4;
+  }
+  llvm::report_fatal_error(
+      Twine("shared memory size for scalar type is unspecified: ") +
+      mlir::debugString(type));
+}
+
+std::pair<uint64_t, uint64_t> WarpSpecializeOp::getCaptureSizeAlign() {
+  uint64_t captureSize = 0;
+  // Tightly pack the captures in memory.
+  for (Type type : getOperandTypes()) {
+    captureSize += getSharedMemorySize(type);
+  }
+  // Align the captures to 8 bytes.
+  return {captureSize, 8};
+}
+
+unsigned WarpSpecializeOp::getTotalPartitionWarps() {
+  ArrayRef<int32_t> numWarps = getPartitionNumWarps();
+  return std::accumulate(numWarps.begin(), numWarps.end(), 0);
+}
+
 } // namespace mlir::triton::gpu
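Worked example of the packing rule above, matching the `tightly_packed_captures` allocation test later in this diff. The helper name is hypothetical, not the in-tree API:

```cpp
#include <cstdint>
#include <initializer_list>

// Hypothetical re-statement of the rule: captures are laid out back to back
// with no inter-field padding, and the whole block is 8-byte aligned.
uint64_t packedCaptureSize(std::initializer_list<uint64_t> fieldSizes) {
  uint64_t size = 0;
  for (uint64_t s : fieldSizes)
    size += s; // tightly packed: no padding between fields
  return size;
}
// packedCaptureSize({1, 8}) == 9  // an (i8, i64) capture list -> "size = 9"
// A pointer capture costs 8 bytes -> "size = 8" in the tests below.
```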
"test_warp_specialize_basic_ir.ttir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + input = torch.empty(2, dtype=torch.int32, device='cuda') + kernel[(1, 1, 1)](input) + assert input[0] == 42 + assert input[1] == 5555 + + +@pytest.mark.skipif(not is_cuda(), reason="warp specialization is only supported on NVIDIA") +def test_warpgroup_reduction(tmp_path: pathlib.Path): + + def template(i, num_warps, in_ptr, out_ptr): + return f""" + %range = tt.make_range {{end = {(i+1)*256} : i32, start = {i*256} : i32}} : tensor<256xi32, #blocked{num_warps}> + %splatted = tt.splat {in_ptr} : !tt.ptr -> tensor<256x!tt.ptr, #blocked{num_warps}> + %ptrs = tt.addptr %splatted, %range : tensor<256x!tt.ptr, #blocked{num_warps}>, tensor<256xi32, #blocked{num_warps}> + %input = tt.load %ptrs : tensor<256x!tt.ptr, #blocked{num_warps}> + %result = "tt.reduce"(%input) ({{ + ^bb0(%lhs: i32, %rhs: i32): + %result = arith.addi %lhs, %rhs : i32 + tt.reduce.return %result : i32 + }}) {{axis = 0 : i32}} : (tensor<256xi32, #blocked{num_warps}>) -> i32 + %offset = arith.constant {i} : i32 + %output = tt.addptr {out_ptr}, %offset : !tt.ptr, i32 + tt.store %output, %result : !tt.ptr + """ + + ir = """ + #blocked4 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + #blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> + + module attributes {"ttg.num-warps" = 4 : i32} { + + tt.func @kernel(%arg0: !tt.ptr, %arg1: !tt.ptr) { + ttg.warp_specialize(%arg0, %arg1) + default { + """ + template(0, 4, "%arg0", "%arg1") + """ + ttg.warp_yield + } + partition0(%arg2: !tt.ptr, %arg3: !tt.ptr) num_warps(4) { + """ + template(1, 4, "%arg2", "%arg3") + """ + ttg.warp_return + } + partition1(%arg4: !tt.ptr, %arg5: !tt.ptr) num_warps(2) { + """ + template(2, 2, "%arg4", "%arg5") + """ + ttg.warp_return + } + partition2(%arg6: !tt.ptr, %arg7: !tt.ptr) num_warps(1) { + """ + template(3, 1, "%arg6", "%arg7") + """ + ttg.warp_return + } : (!tt.ptr, !tt.ptr) -> () + tt.return + } + + } + """ + + temp_file = tmp_path / "test_warpgroup_reduction.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + input = torch.arange(1024, dtype=torch.int32, device='cuda') + output = torch.empty(4, dtype=torch.int32, device='cuda') + kernel[(1, 1, 1)](input, output) + assert output[0] == torch.arange(0, 256).sum() + assert output[1] == torch.arange(256, 512).sum() + assert output[2] == torch.arange(512, 768).sum() + assert output[3] == torch.arange(768, 1024).sum() diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 1ac0e3acb919..0ee013584326 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -628,7 +628,8 @@ tt.func @scan_alloc(%x : tensor<8x16xf32, #AL>) { } // expected-remark @below {{warp_specialize_default_region}} -// expected-remark @below {{size = 32}} +// expected-remark @below {{size = 33}} +// expected-remark @below {{offset = 32, size = 1}} tt.func @warp_specialize_default_region() { // expected-remark @below {{offset = 0, size = 16}} %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable> @@ -647,7 +648,8 @@ tt.func @warp_specialize_default_region() { } // expected-remark @below {{nonoverlapping_liveness_in_default_region}} -// expected-remark @below {{size = 32}} +// expected-remark @below {{size = 33}} +// 
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
index 1ac0e3acb919..0ee013584326 100644
--- a/test/Analysis/test-allocation.mlir
+++ b/test/Analysis/test-allocation.mlir
@@ -628,7 +628,8 @@ tt.func @scan_alloc(%x : tensor<8x16xf32, #AL>) {
 }
 
 // expected-remark @below {{warp_specialize_default_region}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 33}}
+// expected-remark @below {{offset = 32, size = 1}}
 tt.func @warp_specialize_default_region() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
@@ -647,7 +648,8 @@
 }
 
 // expected-remark @below {{nonoverlapping_liveness_in_default_region}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 33}}
+// expected-remark @below {{offset = 32, size = 1}}
 tt.func @nonoverlapping_liveness_in_default_region() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
@@ -670,7 +672,8 @@
 }
 
 // expected-remark @below {{overlapping_liveness_in_default_region}}
-// expected-remark @below {{size = 48}}
+// expected-remark @below {{size = 49}}
+// expected-remark @below {{offset = 48, size = 1}}
 tt.func @overlapping_liveness_in_default_region() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
@@ -693,7 +696,8 @@
 }
 
 // expected-remark @below {{alias_through_default_outputs}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 33}}
+// expected-remark @below {{offset = 32, size = 1}}
 tt.func @alias_through_default_outputs() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
@@ -711,7 +715,8 @@
 }
 
 // expected-remark @below {{implicit_capture_liveness}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 33}}
+// expected-remark @below {{offset = 32, size = 1}}
 tt.func @implicit_capture_liveness() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
@@ -729,12 +734,14 @@
 }
 
 // expected-remark @below {{implicit_and_explicit_capture_liveness}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 45}}
+// expected-remark @below {{offset = 44, size = 1}}
 tt.func @implicit_and_explicit_capture_liveness() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
   // expected-remark @below {{offset = 16, size = 16}}
   %1 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
+  // expected-remark @below {{offset = 32, size = 12}}
   ttg.warp_specialize(%1)
   default {
     "use"(%0) : (!ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>) -> ()
@@ -747,10 +754,12 @@
 }
 
 // expected-remark @below {{explicit_capture_liveness}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 33}}
+// expected-remark @below {{offset = 32, size = 1}}
 tt.func @explicit_capture_liveness() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
+  // expected-remark @below {{offset = 16, size = 12}}
   ttg.warp_specialize(%0)
   default {
     // expected-remark @below {{offset = 16, size = 16}}
@@ -764,7 +773,8 @@
 }
 
 // expected-remark @below {{implicit_capture_liveness_default}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 33}}
+// expected-remark @below {{offset = 32, size = 1}}
 tt.func @implicit_capture_liveness_default() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
@@ -784,7 +794,8 @@
 }
 
 // expected-remark @below {{liveness_in_partition}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 36}}
+// expected-remark @below {{offset = 32, size = 4}}
 tt.func @liveness_in_partition() {
   ttg.warp_specialize()
   default {
@@ -802,7 +813,8 @@
 }
 
 // expected-remark @below {{aliasing_in_partition}}
-// expected-remark @below {{size = 32}}
+// expected-remark @below {{size = 36}}
+// expected-remark @below {{offset = 32, size = 4}}
 tt.func @aliasing_in_partition() {
   ttg.warp_specialize()
   default {
@@ -822,7 +834,8 @@
 }
 
 // expected-remark @below {{partition_region_interference}}
-// expected-remark @below {{size = 80}}
+// expected-remark @below {{size = 88}}
+// expected-remark @below {{offset = 80, size = 8}}
 tt.func @partition_region_interference() {
   // expected-remark @below {{offset = 0, size = 16}}
   %0 = ttg.local_alloc : () -> !ttg.memdesc<2xi64, #A_SHARED, #smem, mutable>
@@ -852,7 +865,8 @@
 }
 
 // expected-remark @below {{two_different_ws}}
-// expected-remark @below {{size = 16}}
+// expected-remark @below {{size = 17}}
+// expected-remark @below {{offset = 16, size = 1}}
 tt.func @two_different_ws() {
   ttg.warp_specialize()
   default {
@@ -875,4 +889,26 @@
   tt.return
 }
 
+// expected-remark @below {{ptr_allocation_datalayout}}
+// expected-remark @below {{size = 8}}
+tt.func @ptr_allocation_datalayout(%arg0: !tt.ptr<i32>) {
+  // expected-remark @below {{offset = 0, size = 8}}
+  ttg.warp_specialize(%arg0)
+  default {
+    ttg.warp_yield
+  } : (!tt.ptr<i32>) -> ()
+  tt.return
+}
+
+// expected-remark @below {{tightly_packed_captures}}
+// expected-remark @below {{size = 9}}
+tt.func @tightly_packed_captures(%arg0: i8, %arg1: i64) {
+  // expected-remark @below {{offset = 0, size = 9}}
+  ttg.warp_specialize(%arg0, %arg1)
+  default {
+    ttg.warp_yield
+  } : (i8, i64) -> ()
+  tt.return
+}
+
 }
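The updated remark values decompose directly from the two scratch allocations added in `Allocation.cpp`: one i8 state slot per extra warp at the end of shared memory, plus a packed capture block per `warp_specialize`. A sketch of the arithmetic for `@implicit_and_explicit_capture_liveness`:

```cpp
// Illustrative check of the expected-remark values above.
static_assert(8 + 1 * 4 == 12, "rank-1 shared memdesc capture footprint");
static_assert(32 + 12 == 44, "captures start after the two 16B local_allocs");
static_assert(44 + 1 == 45, "plus one warp-state byte -> size = 45");
```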
diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir
index a1016d78e19c..d042bb9cb20c 100644
--- a/test/Analysis/test-membar.mlir
+++ b/test/Analysis/test-membar.mlir
@@ -1,5 +1,7 @@
 // RUN: triton-opt %s -split-input-file --convert-scf-to-cf --allocate-shared-memory -test-print-membar | FileCheck %s --check-prefix=CHECK --check-prefix=CF
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory -test-print-membar | FileCheck %s --check-prefix=CHECK --check-prefix=SCF
+// RUN: triton-opt %s -split-input-file --convert-scf-to-cf --allocate-shared-memory -test-print-membar | FileCheck %s --check-prefix=CHECK --check-prefix=CF
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory -test-print-membar | FileCheck %s --check-prefix=CHECK --check-prefix=SCF
 
 #AL = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #sliceAd0 = #ttg.slice<{dim = 0, parent = #AL}>
@@ -847,22 +849,23 @@ tt.func @warp_specialize_isolated_regions(%arg0: tensor<1xi64>) {
   ttg.local_load %0 : !ttg.memdesc<1xi64, #layout, #smem, mutable> -> tensor<1xi64>
 
   // CHECK-NEXT: warp_specialize
-  ttg.warp_specialize(%arg0)
+  ttg.warp_specialize()
   default {
     ttg.warp_yield
   }
   // CHECK: partition0
-  partition0(%arg1: tensor<1xi64>) num_warps(4) {
-    // CHECK-NEXT: local_alloc
+  partition0() num_warps(4) {
+    %cst = arith.constant dense<0> : tensor<1xi64>
+    // CHECK: local_alloc
     %1 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #layout, #smem, mutable>
     // CHECK-NEXT: local_store
-    ttg.local_store %arg1, %1 : tensor<1xi64> -> !ttg.memdesc<1xi64, #layout, #smem, mutable>
+    ttg.local_store %cst, %1 : tensor<1xi64> -> !ttg.memdesc<1xi64, #layout, #smem, mutable>
     // CHECK-NEXT: barrier
     // CHECK-NEXT: local_load
     ttg.local_load %1 : !ttg.memdesc<1xi64, #layout, #smem, mutable> -> tensor<1xi64>
     // CHECK-NEXT: warp_return
     ttg.warp_return
-  } : (tensor<1xi64>) -> ()
+  } : () -> ()
 
   tt.return
 }
diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir
index 62763e4f056c..5c677a05f7d3 100644
--- a/test/Conversion/triton_to_tritongpu.mlir
+++ b/test/Conversion/triton_to_tritongpu.mlir
@@ -152,3 +152,25 @@ tt.func @ub_poison() {
   %0 = ub.poison : tensor<128x64xf16>
   tt.return
 }
+
+// -----
+
+#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
+
+module attributes {"ttg.num-warps" = 4 : i32} {
+
+// CHECK-LABEL: @partition_axis_info
+tt.func @partition_axis_info(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
+  ttg.warp_specialize(%arg0)
+  default {
+    ttg.warp_yield
+  }
+  partition0(%arg2: !tt.ptr<i32>) num_warps(2) {
+    %splatted = tt.splat %arg2 : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>, #blocked2>
+    %input = tt.load %splatted : tensor<256x!tt.ptr<i32>, #blocked2>
+    ttg.warp_return
+  } : (!tt.ptr<i32>) -> ()
+  tt.return
+}
+
+}
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index d5fbcfa81f4a..e017c1ed1652 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -1,7 +1,7 @@
 // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s --dump-input-context 20
 
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
-  // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>)
+  // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>)
   // Here the 128 comes from the 4 in module attribute multiples 32
   // CHECK: nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 128>
   tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
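Before the new lowering tests: the emitted code implements a simple state machine. Each extra warp spins in a "switch loop," waiting on barrier 1 for the default warp group to publish an i8 state per warp, then dispatching to a partition body or exiting. A hedged pseudo-code sketch of that protocol (the helpers are stand-ins for emitted PTX, not real APIs):

```cpp
#include <cstdint>

void barrierSync(int idx);       // stand-in for "barrier.sync <idx> ;"
void runPartition(int8_t state); // stand-in for a partition region body

// Sketch of the per-worker-warp switch loop the tests below check.
void workerWarpLoop(volatile const int8_t *states, int relWarpId,
                    int8_t exitState) {
  for (;;) {
    barrierSync(1);              // released by the default warp group
    int8_t state = states[relWarpId];
    if (state == exitState)
      return;                    // the function epilogue signals this state
    runPartition(state);         // body ends with another barrier.sync 1
    // loop back and wait for the next warp_specialize region
  }
}
```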
diff --git a/test/Conversion/warp_specialize_to_llvm.mlir b/test/Conversion/warp_specialize_to_llvm.mlir
new file mode 100644
index 000000000000..a0356038f627
--- /dev/null
+++ b/test/Conversion/warp_specialize_to_llvm.mlir
@@ -0,0 +1,626 @@
+// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -convert-warp-specialize-to-llvm | FileCheck %s
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 11 : i32} {
+
+llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+
+// CHECK-LABEL: @rewrite_barriers
+llvm.func @rewrite_barriers() attributes {allocation.offset = 32 : i32} {
+  // CHECK: barrier.sync.aligned 2, 128
+  // CHECK: barrier.sync.aligned 3, 64
+  // CHECK: barrier.sync.aligned 4, 32
+
+  // CHECK: bb{{[0-9]+}}:
+  // CHECK-NEXT: barrier.sync.aligned 0, 128
+  nvvm.barrier0
+  ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4, 8, 10>}
+  default {
+    // CHECK: barrier.sync.aligned 0, 128
+    nvvm.barrier0
+    ttg.warp_yield
+  }
+  partition0() num_warps(4) {
+    nvvm.barrier0
+    ttg.warp_return
+  }
+  partition1() num_warps(2) {
+    nvvm.barrier0
+    ttg.warp_return
+  }
+  partition2() num_warps(1) {
+    nvvm.barrier0
+    ttg.warp_return
+  } : () -> ()
+  // CHECK: barrier.sync.aligned 0, 128
+  nvvm.barrier0
+  llvm.return
+}
+
+}
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 11 : i32} {
+
+llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+
+// CHECK-LABEL: @generate_switch_loop
+llvm.func @generate_switch_loop() attributes {allocation.offset = 32 : i32} {
+  // CHECK-NEXT: [[TIDX:%.*]] = nvvm.read.ptx.sreg.tid.x
+  // CHECK-NEXT: [[C32:%.*]] = llvm.mlir.constant(32 : i32)
+  // CHECK-NEXT: [[WID:%.*]] = llvm.udiv [[TIDX]], [[C32]]
+  // CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
+  // CHECK-NEXT: [[C31:%.*]] = llvm.mlir.constant(31 : i32)
+  // CHECK-NEXT: [[CNEG1:%.*]] = llvm.mlir.constant(-1 : i32)
+  // CHECK-NEXT: [[WARP_ID:%.*]] = nvvm.shfl.sync idx [[CNEG1]], [[WID]], [[C0]], [[C31]]
+  // CHECK-NEXT: [[C4:%.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK-NEXT: [[IS_DEFAULT:%.*]] = llvm.icmp "ult" [[WARP_ID]], [[C4]]
+  // CHECK-NEXT: llvm.cond_br [[IS_DEFAULT]], [[BODY:\^.*]], [[SWITCH_LOOP:\^.*]]
+
+  // CHECK: [[SWITCH_LOOP]]:
+  // CHECK-NEXT: "barrier.sync 1 ;"
+  // CHECK-NEXT: [[C32:%.*]] = llvm.mlir.constant(32 : i32)
+  // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C32]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+  // CHECK-NEXT: [[C4:%.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK-NEXT: [[REL_WID:%.*]] = llvm.sub [[WARP_ID]], [[C4]]
+
+  // CHECK-NEXT: [[STATE_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][[[REL_WID]]]
+  // CHECK-NEXT: [[STATE:%.*]] = llvm.load [[STATE_PTR]]
+  // CHECK-NEXT: llvm.switch [[STATE]] : i8, [[DEFAULT:\^.*]] [
+  // CHECK-NEXT:   0: [[PARTITION0:\^.*]],
+  // CHECK-NEXT:   1: [[PARTITION1:\^.*]],
+  // CHECK-NEXT:   2: [[PARTITION2:\^.*]],
+  // CHECK-NEXT:   3: [[EXIT:\^.*]]
+
+  // CHECK: [[DEFAULT]]:
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[SWITCH_LOOP]]
+
+  // CHECK: [[EXIT]]:
+  // CHECK-NEXT: llvm.return
+
+  // CHECK: [[PARTITION0]]:
+  // CHECK: barrier.sync 1 ;
+  // CHECK-NEXT: "partition0"
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[SWITCH_LOOP]]
+
+  // CHECK: [[PARTITION1]]:
+  // CHECK: barrier.sync 1 ;
+  // CHECK-NEXT: "partition1"
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[SWITCH_LOOP]]
+
+  // CHECK: [[PARTITION2]]:
+  // CHECK: barrier.sync 1 ;
+  // CHECK-NEXT: "partition2"
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[SWITCH_LOOP]]
+  // CHECK: [[BODY]]:
+  // CHECK-NEXT: "before"
+  // CHECK-NEXT: [[C32:%.*]] = llvm.mlir.constant(32 : i32)
+  // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C32]]]
+
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][1]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][2]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][3]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][4]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][5]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][6]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+
+  // CHECK: barrier.sync 1 ;
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[DEFAULT_PARTITION:\^.*]]
+  // CHECK: [[DEFAULT_PARTITION]]:
+  // CHECK-NEXT: "default"
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[AFTER:\^.*]]
+  "before"() : () -> ()
+  ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4, 8, 10>}
+  default {
+    "default"() : () -> ()
+    ttg.warp_yield
+  }
+  partition0() num_warps(4) {
+    "partition0"() : () -> ()
+    ttg.warp_return
+  }
+  partition1() num_warps(2) {
+    "partition1"() : () -> ()
+    ttg.warp_return
+  }
+  partition2() num_warps(1) {
+    "partition2"() : () -> ()
+    ttg.warp_return
+  } : () -> ()
+  // CHECK: [[AFTER]]:
+  // CHECK-NEXT: "after"
+
+  // CHECK-NEXT: [[C32:%.*]] = llvm.mlir.constant(32 : i32)
+  // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C32]]]
+
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][1]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][2]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][3]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][4]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][5]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][6]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.return
+  "after"() : () -> ()
+  llvm.return
+}
+
+}
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 8 : i32} {
+
+llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+
+// CHECK-LABEL: @pass_captures
+llvm.func @pass_captures(%arg0: i32, %arg1: i64) attributes {allocation.offset = 32 : i32} {
+  // CHECK: ^bb4:
+  // CHECK-NEXT: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
+  // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C0]]]
+
+  // CHECK-NEXT: [[ARG0_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.struct<packed (i32, i64)>
+  // CHECK-NEXT: [[ARG0:%.*]] = llvm.load [[ARG0_PTR]] {alignment = 1 : i64}
+  // CHECK-NEXT: [[ARG1_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0, 1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.struct<packed (i32, i64)>
+  // CHECK-NEXT: [[ARG1:%.*]] = llvm.load [[ARG1_PTR]] {alignment = 1 : i64}
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: "use"([[ARG0]], [[ARG1]])
+  // CHECK-NEXT: barrier.sync 1 ;
+
+  // CHECK: ^bb5:
+  // CHECK: llvm.mlir.addressof @global_smem
+  // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32)
+  // CHECK-NEXT: [[SMEM_ADDR:%.*]] = llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: [[SMEM_BASE:%.*]] = llvm.getelementptr [[SMEM_ADDR]][[[C0]]]
+  // CHECK-NEXT: [[ARG0_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.struct<packed (i32, i64)>
+  // CHECK-NEXT: llvm.store %arg0, [[ARG0_PTR]] {alignment = 1 : i64}
+  // CHECK-NEXT: [[ARG1_PTR:%.*]] = llvm.getelementptr [[SMEM_BASE]][0, 1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.struct<packed (i32, i64)>
+  // CHECK-NEXT: llvm.store %arg1, [[ARG1_PTR]] {alignment = 1 : i64}
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: barrier.sync 1 ;
+  ttg.warp_specialize(%arg0, %arg1) attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4>}
+  default {
+    ttg.warp_yield
+  }
+  partition0(%arg2: i32, %arg3: i64) num_warps(4) {
+    "use"(%arg2, %arg3) : (i32, i64) -> ()
+    ttg.warp_return
+  } : (i32, i64) -> ()
+  llvm.return
+}
+
+}
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 18 : i32} {
+
+llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+
+// CHECK-LABEL: @partition_warpid_order
+llvm.func @partition_warpid_order() attributes {allocation.offset = 32 : i32} {
+  // CHECK: llvm.switch
+  // CHECK-NEXT:   0: [[PARTITION0:\^.*]],
+  // CHECK-NEXT:   1: [[PARTITION1:\^.*]],
+  // CHECK-NEXT:   2: [[PARTITION2:\^.*]],
+  // CHECK-NEXT:   3: [[EXIT:\^.*]]
+
+  // CHECK: [[PARTITION0]]:
+  // CHECK: "ws0_partition0"
+  // CHECK: [[PARTITION1]]:
+  // CHECK: "ws0_partition1"
+  // CHECK: [[PARTITION2]]:
+  // CHECK: "ws0_partition2"
+
+  // CHECK: llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: getelementptr
+
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[8]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[9]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[10]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[11]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[12]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[13]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 6, 4, 10>}
+  default {
+    "ws0_default"() : () -> ()
+    ttg.warp_yield
+  }
+  partition0() num_warps(4) {
+    "ws0_partition0"() : () -> ()
+    ttg.warp_return
+  }
+  partition1() num_warps(2) {
+    "ws0_partition1"() : () -> ()
+    ttg.warp_return
+  }
+  partition2() num_warps(8) {
+    "ws0_partition2"() : () -> ()
+    ttg.warp_return
+  } : () -> ()
+  llvm.return
+}
+
+}
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 12 : i32} {
+
+llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+
+// CHECK-LABEL: @multiple_specialize
+llvm.func @multiple_specialize() attributes {allocation.offset = 32 : i32} {
+  // CHECK: llvm.switch
+  // CHECK-NEXT:   0: [[WS0_PARTITION0:\^.*]],
+  // CHECK-NEXT:   1: [[WS0_PARTITION1:\^.*]],
+  // CHECK-NEXT:   2: [[WS0_PARTITION2:\^.*]],
+  // CHECK-NEXT:   3: [[WS1_PARTITION0:\^.*]],
+  // CHECK-NEXT:   4: [[WS1_PARTITION1:\^.*]],
+  // CHECK-NEXT:   5: [[WS3_PARTITION0:\^.*]],
+  // CHECK-NEXT:   6: [[EXIT:\^.*]]
+
+  // CHECK: [[WS0_PARTITION0]]:
+  // CHECK: "ws0_partition0"
+  // CHECK: [[WS0_PARTITION1]]:
+  // CHECK: "ws0_partition1"
+  // CHECK: [[WS0_PARTITION2]]:
+  // CHECK: "ws0_partition2"
+  // CHECK: [[WS1_PARTITION0]]:
+  // CHECK: "ws1_partition0"
+  // CHECK: [[WS1_PARTITION1]]:
+  // CHECK: "ws1_partition1"
+  // CHECK: [[WS3_PARTITION0]]:
+  // CHECK: "ws3_partition0"
+
+  // CHECK: llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: getelementptr
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(0 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(2 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK: barrier.sync 1 ;
+  // CHECK: "ws0_default"
+  ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4, 8, 10>}
+  default {
+    "ws0_default"() : () -> ()
+    ttg.warp_yield
+  }
+  partition0() num_warps(4) {
+    "ws0_partition0"() : () -> ()
+    ttg.warp_return
+  }
+  partition1() num_warps(2) {
+    "ws0_partition1"() : () -> ()
+    ttg.warp_return
+  }
+  partition2() num_warps(1) {
+    "ws0_partition2"() : () -> ()
+    ttg.warp_return
+  } : () -> ()
+
+  // CHECK: llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: getelementptr
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(4 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(4 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(4 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(4 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(3 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK: barrier.sync 1 ;
+  // CHECK: "ws1_default"
+  ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 8, 4>}
+  default {
+    "ws1_default"() : () -> ()
+    ttg.warp_yield
+  }
+  partition0() num_warps(4) {
+    "ws1_partition0"() : () -> ()
+    ttg.warp_return
+  }
+  partition1() num_warps(4) {
+    "ws1_partition1"() : () -> ()
+    ttg.warp_return
+  } : () -> ()
+
+  // CHECK: llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: getelementptr
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(-1 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK: barrier.sync 1 ;
+  // CHECK: "ws2_default"
+  ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32>}
+  default {
+    "ws2_default"() : () -> ()
+    ttg.warp_yield
+  } : () -> ()
+
+  // CHECK: llvm.mlir.addressof @global_smem
+  // CHECK-NEXT: getelementptr
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[0]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] =
 llvm.mlir.constant(5 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[1]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[2]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[3]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[4]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[5]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[6]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK-NEXT: [[CX:%.*]] = llvm.mlir.constant(5 : i8)
+  // CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr %{{[0-9]+}}[7]
+  // CHECK-NEXT: llvm.store [[CX]], [[PTR]]
+  // CHECK: barrier.sync 1 ;
+  // CHECK: "ws3_default"
+  ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4>}
+  default {
+    "ws3_default"() : () -> ()
+    ttg.warp_yield
+  }
+  partition0() num_warps(8) {
+    "ws3_partition0"() : () -> ()
+    ttg.warp_return
+  } : () -> ()
+  llvm.return
+}
+
+}
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 8 : i32} {
+
+llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+
+// CHECK-LABEL: @cfg
+llvm.func @cfg() attributes {allocation.offset = 32 : i32} {
+  // CHECK: [[SWITCH_LOOP:\^bb1]]:
+  // CHECK: llvm.switch
+  // CHECK-NEXT:   0: [[PARTITION:\^.*]],
+  // CHECK-NEXT:   1: [[EXIT:\^.*]]
+
+  // CHECK: [[PARTITION]]:
+  // CHECK: barrier.sync 1 ;
+  // CHECK-NEXT: "something"()[[[A:\^.*]], [[B:\^.*]]]
+  // CHECK: [[A]]:
+  // CHECK-NEXT: "A"
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[SWITCH_LOOP]]
+  // CHECK: [[B]]:
+  // CHECK-NEXT: "B"
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[SWITCH_LOOP]]
+
+  // CHECK: barrier.sync 1 ;
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK: llvm.br [[DEFAULT:\^.*]]
+  // CHECK: [[DEFAULT]]:
+  // CHECK-NEXT: "something"()[[[A:\^.*]], [[B:\^.*]]]
+  // CHECK: [[A]]:
+  // CHECK-NEXT: "A"
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[AFTER:\^.*]]
+  // CHECK: [[B]]:
+  // CHECK-NEXT: "B"
+  // CHECK-NEXT: barrier.sync 1 ;
+  // CHECK-NEXT: llvm.br [[AFTER]]
+  ttg.warp_specialize() attributes {allocation.offset = 0 : i32, warpGroupStartIds = array<i32: 4>}
+  default {
+    "something"()[^A, ^B] : () -> ()
+  ^A:
+    "A"() : () -> ()
+    ttg.warp_yield
+  ^B:
+    "B"() : () -> ()
+    ttg.warp_yield
+  }
+  partition0() num_warps(4) {
+    "something"()[^A, ^B] : () -> ()
+  ^A:
+    "A"() : () -> ()
+    ttg.warp_return
+  ^B:
+    "B"() : () -> ()
+    ttg.warp_return
+  } : () -> ()
+  llvm.return
+}
+
+}
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32, "ttg.total-num-warps" = 8 : i32} {
+
+llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+
+// CHECK-LABEL: @no_captures
+llvm.func @no_captures() attributes {allocation.offset = 0 : i32} {
+  ttg.warp_specialize() attributes {warpGroupStartIds = array<i32: 4>}
+  default {
+    ttg.warp_yield
+  }
+  partition0() num_warps(4) {
+    ttg.warp_return
+  } : () -> ()
+  llvm.return
+}
+
+}
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.total-num-warps" = 6 : i32} { + +llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8> + +// CHECK-LABEL: @type_conversion_results +// CHECK-NOT: !tt.ptr +// CHECK-NOT: unrealized_conversion_cast +llvm.func @type_conversion_results(%arg0: !llvm.ptr<1>) attributes {allocation.offset = 0 : i32} { + %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.ptr<1> to !tt.ptr + %1 = ttg.warp_specialize(%0) attributes {allocation.offset = 0 : i32, warpGroupStartIds = array} + default { + // CHECK: llvm.br [[AFTER:\^.*]](%arg0 : !llvm.ptr<1>) + ttg.warp_yield %0 : !tt.ptr + } + partition0(%arg1: !tt.ptr) num_warps(2) { + %3 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to !llvm.ptr<1> + %4 = llvm.load %3 : !llvm.ptr<1> -> i32 + ttg.warp_return + } : (!tt.ptr) -> !tt.ptr + // CHECK: [[AFTER]]([[OUT:%.*]]: !llvm.ptr<1>): + %2 = builtin.unrealized_conversion_cast %1 : !tt.ptr to !llvm.ptr<1> + // CHECK-NEXT: "use"([[OUT]]) + "use"(%2) : (!llvm.ptr<1>) -> () + llvm.return +} + +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index cb471803222e..6f87ecb6796d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -102,8 +102,6 @@ struct ConvertTritonAMDGPUToLLVM // Lower functions { - mlir::LowerToLLVMOptions option(context); - TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMFunctionConversionTarget funcTarget(*context); RewritePatternSet funcPatterns(context); mlir::triton::populateFuncOpConversionPattern( @@ -122,8 +120,6 @@ struct ConvertTritonAMDGPUToLLVM // Convert call and ret ops { - mlir::LowerToLLVMOptions option(context); - TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo); TritonLLVMFunctionConversionTarget funcTarget(*context); RewritePatternSet funcPatterns(context); if (failed( diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 02a77919680d..a0e31212a3ba 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -306,6 +306,7 @@ def make_llir(self, src, metadata, options, capability): nvidia.passes.ttnvgpuir.add_lower_mma(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.ttgpuir.add_allocate_warp_groups(pm) passes.convert.add_scf_to_cf(pm) passes.ttgpuir.add_allocate_shared_memory(pm) nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) @@ -314,6 +315,7 @@ def make_llir(self, src, metadata, options, capability): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) + nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) @@ -347,9 +349,9 @@ def make_llir(self, src, metadata, options, capability): # Get some metadata # warp-specialization mutates num_warps - num_warp_groups = src.get_int_attr("ttg.num-warp-groups-per-cta") - if num_warp_groups is not None: - metadata["num_warps"] *= num_warp_groups + total_num_warps = src.get_int_attr("ttg.total-num-warps") + if total_num_warps is not None: + metadata["num_warps"] = total_num_warps metadata["shared"] = src.get_int_attr("ttg.shared") metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size") metadata["global_scratch_size"] = 
src.get_int_attr("ttg.global_scratch_memory_size") diff --git a/third_party/nvidia/include/NVGPUToLLVM/Passes.td b/third_party/nvidia/include/NVGPUToLLVM/Passes.td index 364ed4601f94..345e6408cf67 100644 --- a/third_party/nvidia/include/NVGPUToLLVM/Passes.td +++ b/third_party/nvidia/include/NVGPUToLLVM/Passes.td @@ -3,7 +3,6 @@ include "mlir/Pass/PassBase.td" - def ConvertNVGPUToLLVM : Pass<"convert-nv-gpu-to-llvm", "mlir::ModuleOp"> { let summary = "Convert NVGPU to LLVM"; let description = [{ @@ -17,4 +16,4 @@ def ConvertNVGPUToLLVM : Pass<"convert-nv-gpu-to-llvm", "mlir::ModuleOp"> { "mlir::triton::nvgpu::NVGPUDialect"]; } -#endif +#endif // NVGPU_CONVERSION_PASSES diff --git a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td index 6cdb657c59aa..ea753f5780dd 100644 --- a/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td +++ b/third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td @@ -31,4 +31,15 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" ]; } -#endif +def ConvertWarpSpecializeToLLVM : Pass<"convert-warp-specialize-to-llvm", "mlir::ModuleOp"> { + let summary = "lower `ttg.warp_specialize` to LLVM"; + let description = [{ + The `convert-warp-specialize-to-llvm` pass performs codegen for warp + specialization. It is a function-level transformation that rewrites + warp-specialized kernels by using shared memory and barriers to communicate + states between the default warpgroup and the worker warps. + }]; + let dependentDialects = ["mlir::LLVM::LLVMDialect", "mlir::NVVM::NVVMDialect"]; +} + +#endif // TRITONGPU_CONVERSION_PASSES diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt index fef1b916498e..82842f516003 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonNVIDIAGPUToLLVM ConvertLayoutOpToLLVM.cpp + ConvertWarpSpecializeToLLVM.cpp MemoryOpToLLVM.cpp DotOpToLLVM/MMAv2.cpp DotOpToLLVM/MMAv5.cpp diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp new file mode 100644 index 000000000000..050fe2f4d2f5 --- /dev/null +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp @@ -0,0 +1,429 @@ +#include "TargetInfo.h" +#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton { +#define GEN_PASS_DEF_CONVERTWARPSPECIALIZETOLLVM +#include "TritonNVIDIAGPUToLLVM/Passes.h.inc" +} // namespace mlir::triton + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +//===----------------------------------------------------------------------===// +// convertOpTypes +//===----------------------------------------------------------------------===// + +static void convertOpTypes(Operation *op, const TypeConverter &typeConverter) { + 
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp
new file mode 100644
index 000000000000..050fe2f4d2f5
--- /dev/null
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp
@@ -0,0 +1,429 @@
+#include "TargetInfo.h"
+#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
+#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+
+namespace mlir::triton {
+#define GEN_PASS_DEF_CONVERTWARPSPECIALIZETOLLVM
+#include "TritonNVIDIAGPUToLLVM/Passes.h.inc"
+} // namespace mlir::triton
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace mlir::triton::gpu;
+
+//===----------------------------------------------------------------------===//
+// convertOpTypes
+//===----------------------------------------------------------------------===//
+
+static void convertOpTypes(Operation *op, const TypeConverter &typeConverter) {
+  ImplicitLocOpBuilder b(op->getLoc(), op);
+  SmallVector<Value> operands = llvm::to_vector(op->getOperands());
+  for (Value &operand : operands) {
+    Type type = typeConverter.convertType(operand.getType());
+    if (type != operand.getType()) {
+      operand =
+          b.create<UnrealizedConversionCastOp>(type, operand).getResult(0);
+    }
+  }
+  op->setOperands(operands);
+
+  for (Region &region : op->getRegions()) {
+    b.setInsertionPointToStart(&region.front());
+    for (BlockArgument arg : llvm::to_vector(region.getArguments())) {
+      Type type = typeConverter.convertType(arg.getType());
+      BlockArgument newArg = region.addArgument(type, arg.getLoc());
+      auto cast = b.create<UnrealizedConversionCastOp>(arg.getType(), newArg);
+      arg.replaceAllUsesWith(cast.getResult(0));
+      region.eraseArgument(0);
+    }
+  }
+
+  SmallVector<Type> resultTypes;
+  (void)typeConverter.convertTypes(op->getResultTypes(), resultTypes);
+  if (TypeRange(resultTypes) == op->getResultTypes())
+    return;
+  OperationState state(op->getLoc(), op->getName(), op->getOperands(),
+                       resultTypes, op->getAttrs());
+  for (Region &region : op->getRegions())
+    state.addRegion()->takeBody(region);
+  b.setInsertionPoint(op);
+  Operation *newOp = b.create(state);
+
+  for (auto [i, result, type] :
+       llvm::enumerate(newOp->getResults(), op->getResultTypes())) {
+    auto cast = b.create<UnrealizedConversionCastOp>(type, result);
+    op->getResult(i).replaceAllUsesWith(cast.getResult(0));
+  }
+  op->erase();
+}
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+// Reserve one barrier for the default warp group and one for the switch loop,
+// which handles both start and end synchronization.
+enum BarrierIndex {
+  kDefaultWarpGroupBarrierIdx,
+  kSwitchLoopBarrierIdx,
+
+  kNumReservedBarriers,
+  kNumBarriers = 16
+};
+
+static void createBarrier(TritonLLVMIRRewriter &b, unsigned barIdx,
+                          std::optional<unsigned> numThreads, bool aligned) {
+  assert(barIdx < kNumBarriers && "not enough barriers");
+
+  PTXBuilder ptxBuilder;
+  std::string ptxString;
+  llvm::raw_string_ostream os(ptxString);
+  os << "barrier.sync";
+  if (aligned)
+    os << ".aligned";
+  os << ' ' << barIdx;
+  if (numThreads)
+    os << ", " << *numThreads;
+
+  (*ptxBuilder.create<>(ptxString))();
+  ptxBuilder.launch(b, b.getLoc(), void_ty(b.getContext()));
+}
+
+//===----------------------------------------------------------------------===//
+// lowerWarpSpecialize
+//===----------------------------------------------------------------------===//
+
+// Assign hardware barriers to each warp group and rewrite warp group barriers
+// into `barrier.sync` instructions. There is a maximum number of barriers.
+static LogicalResult rewriteWarpGroupBarriers(LLVM::LLVMFuncOp func,
+                                              ArrayRef<WarpSpecializeOp> wsOps,
+                                              unsigned threadsPerWarp,
+                                              unsigned defaultWarpGroupSize) {
+  // HACK: Turn all `nvvm.barrier0` ops into warp group barriers.
+  func.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    // Walk into default regions but not partition regions.
+    if (isa<WarpSpecializePartitionsOp>(op))
+      return WalkResult::skip();
+
+    if (auto bar = dyn_cast<NVVM::Barrier0Op>(op)) {
+      TritonLLVMIRRewriter b(bar.getLoc(), bar);
+      createBarrier(b, /*barIdx=*/0, defaultWarpGroupSize, /*aligned=*/true);
+      bar.erase();
+      return WalkResult::advance();
+    }
+    return WalkResult::advance();
+  });
+
+  // Each partition executes simultaneously, so each will get a different
+  // barrier ID, but note this means there is a maximum of 16 barriers.
+  for (WarpSpecializeOp op : wsOps) {
+    for (auto [idx, partition] : llvm::enumerate(op.getPartitionRegions())) {
+      unsigned barIdx = idx + kNumReservedBarriers;
+      if (barIdx >= kNumBarriers) {
+        return func.emitError("cannot support more than ")
+               << (kNumBarriers - kNumReservedBarriers)
+               << " warp group partitions";
+      }
+      unsigned warpGroupSize = threadsPerWarp * op.getPartitionNumWarps()[idx];
+      partition->walk([&](NVVM::Barrier0Op bar) {
+        TritonLLVMIRRewriter b(bar.getLoc(), bar);
+        createBarrier(b, barIdx, warpGroupSize, /*aligned=*/true);
+        bar.erase();
+      });
+    }
+  }
+
+  return success();
+}
+
+static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
+                                    const NVIDIA::TargetInfo &targetInfo) {
+  TritonLLVMIRRewriter b(ws.getLoc(), ws.getContext());
+
+  for (Region *partition : ws.getPartitionRegions()) {
+    // Load the explicit captures from shared memory and replace the block
+    // args if there are any.
+    b.setInsertionPointToStart(&partition->front());
+    if (partition->getNumArguments()) {
+      auto captureType = LLVM::LLVMStructType::getLiteral(
+          b.getContext(), llvm::to_vector(partition->getArgumentTypes()),
+          /*isPacked=*/true);
+      Value capturePtr =
+          LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, ws);
+      LLVM::LLVMPointerType ptrTy = ptr_ty(b.getContext(), 3);
+      for (auto [i, arg] :
+           llvm::zip(llvm::seq<int32_t>(partition->getNumArguments()),
+                     partition->getArguments())) {
+        Value ptr = b.gep(ptrTy, captureType, capturePtr,
+                          ArrayRef<LLVM::GEPArg>{0, i});
+        // Each thread in the warp group needs a copy of the value.
+        Value value = b.load(arg.getType(), ptr, /*align=*/1);
+        arg.replaceAllUsesWith(value);
+      }
+      partition->front().eraseArguments([](auto) { return true; });
+    }
+
+    // The shared memory is only live for the entry into the region, so put
+    // another barrier here.
+    createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                  /*aligned=*/false);
+
+    // Rewrite all warp returns.
+    partition->walk([&](WarpReturnOp op) {
+      b.setInsertionPoint(op);
+      createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                    /*aligned=*/false);
+      b.replaceOpWithNewOp<LLVM::BrOp>(op, switchLoop);
+    });
+  }
+}
+
+static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
+                                         const NVIDIA::TargetInfo &targetInfo) {
+  SmallVector<WarpSpecializeOp> wsOps;
+  func.walk([&](WarpSpecializeOp op) { wsOps.push_back(op); });
+  // Nothing to do. This kernel is not warp specialized.
+  if (wsOps.empty())
+    return success();
+
+  // Before lowering away `ttg.warp_specialize`, lower warp group barriers.
+  auto module = cast<ModuleOp>(func->getParentOp());
+  unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+  unsigned defaultNumWarps = lookupNumWarps(func);
+  unsigned defaultWarpGroupSize = threadsPerWarp * defaultNumWarps;
+  if (failed(rewriteWarpGroupBarriers(func, wsOps, threadsPerWarp,
+                                      defaultWarpGroupSize)))
+    return failure();
+
+  MLIRContext *ctx = func.getContext();
+  TritonLLVMIRRewriter b(func.getLoc(), ctx);
+  Builder rewriter(ctx);
+
+  // Generate the function header.
+  Block *entry = &func.getBody().front();
+  SmallVector<Location> argLocs = llvm::to_vector(llvm::map_range(
+      func.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
+  Block *header = b.createBlock(entry, func.getArgumentTypes(), argLocs);
+  Block *switchLoop = b.createBlock(entry);
+  b.setInsertionPointToStart(header);
+
+  // This is the absolute thread ID.
+  Value tid = b.create<NVVM::ThreadIdXOp>(i32_ty);
+  Value wid = b.udiv(tid, b.i32_val(threadsPerWarp));
+  // Tell PTXAS this value is warp-uniform.
+
+static LogicalResult lowerWarpSpecialize(LLVM::LLVMFuncOp func,
+                                         const NVIDIA::TargetInfo &targetInfo) {
+  SmallVector<WarpSpecializeOp> wsOps;
+  func.walk([&](WarpSpecializeOp op) { wsOps.push_back(op); });
+  // Nothing to do. This kernel is not warp specialized.
+  if (wsOps.empty())
+    return success();
+
+  // Before lowering away `ttg.warp_specialize`, lower warp group barriers.
+  auto module = cast<ModuleOp>(func->getParentOp());
+  unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+  unsigned defaultNumWarps = lookupNumWarps(func);
+  unsigned defaultWarpGroupSize = threadsPerWarp * defaultNumWarps;
+  if (failed(rewriteWarpGroupBarriers(func, wsOps, threadsPerWarp,
+                                      defaultWarpGroupSize)))
+    return failure();
+
+  MLIRContext *ctx = func.getContext();
+  TritonLLVMIRRewriter b(func.getLoc(), ctx);
+  Builder rewriter(ctx);
+
+  // Generate the function header.
+  Block *entry = &func.getBody().front();
+  SmallVector<Location> argLocs = llvm::to_vector(llvm::map_range(
+      func.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
+  Block *header = b.createBlock(entry, func.getArgumentTypes(), argLocs);
+  Block *switchLoop = b.createBlock(entry);
+  b.setInsertionPointToStart(header);
+
+  // This is the absolute thread ID.
+  Value tid = b.create<NVVM::ThreadIdXOp>(i32_ty);
+  Value wid = b.udiv(tid, b.i32_val(threadsPerWarp));
+  // Tell PTXAS this value is warp-uniform.
+  wid = targetInfo.shuffleIdx(b, b.getLoc(), wid, 0);
+  Value isDefault = b.icmp_ult(wid, b.i32_val(defaultNumWarps));
+  b.create<LLVM::CondBrOp>(isDefault, entry, switchLoop);
+
+  // Forward arguments from the header into the old entry block.
+  for (auto [arg, oldArg] :
+       llvm::zip(header->getArguments(), entry->getArguments()))
+    oldArg.replaceAllUsesWith(arg);
+  entry->eraseArguments([](auto) { return true; });
+
+  // Generate the switch loop.
+  auto totalNumWarpsAttr =
+      module->getAttrOfType<IntegerAttr>("ttg.total-num-warps");
+  if (!totalNumWarpsAttr) {
+    return mlir::emitError(module.getLoc(),
+                           "module missing 'ttg.total-num-warps' attribute");
+  }
+  unsigned totalNumThreads = totalNumWarpsAttr.getInt() * threadsPerWarp;
+
+  // ^switchLoop:
+  //   barrier.sync 1
+  //   %state_ptr = getelementptr (ptr @shared), <offset>
+  //   %rel_wid = sub %wid, <default_num_warps>
+  b.setInsertionPointToStart(switchLoop);
+  createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                /*aligned=*/false);
+  Value statePtr = LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func);
+  Value relWid = b.sub(wid, b.i32_val(defaultNumWarps));
+
+  // The default warp group will populate the state pointer with the state ID
+  // for all warps.
+  // %warp_state_ptr = getelementptr ptr %state_ptr[%rel_wid]
+  // %warp_state = load i8 %warp_state_ptr
+  LLVM::LLVMPointerType ptrTy = ptr_ty(ctx, 3);
+  Value warpStatePtr = b.gep(ptrTy, i8_ty, statePtr, relWid);
+  // All threads in a warp read the same shared memory address, which causes
+  // no bank conflicts and is cheaper than a predicated load.
+  Value warpState = b.load(i8_ty, warpStatePtr);
+
+  // Pull the partition regions out. Switch based on the state ID to the
+  // right partition.
+  SmallVector<Block *> partitionBlocks;
+  SmallVector<int32_t> partitionStates;
+  int32_t partitionStateCounter = 0;
+  // This is the data the default warp group writes to the state pointer
+  // before entering each `warp_specialize` region: a map from warp ID to the
+  // state ID used in the switch.
+  int32_t maxNumWarps = totalNumWarpsAttr.getInt() - defaultNumWarps;
+  SmallVector<SmallVector<int32_t>> warpToState(
+      wsOps.size(), SmallVector<int32_t>(maxNumWarps, -1));
+  for (auto [op, stateMap] : llvm::zip(wsOps, warpToState)) {
+    rewritePartitionRegions(op, switchLoop, targetInfo);
+    for (auto [partition, partitionNumWarps, startId] :
+         llvm::zip(op.getPartitionRegions(), op.getPartitionNumWarps(),
+                   *op.getWarpGroupStartIds())) {
+      partitionStates.push_back(partitionStateCounter++);
+      partitionBlocks.push_back(&partition->front());
+      for (int32_t &stateId : MutableArrayRef(stateMap).slice(
+               startId - defaultNumWarps, partitionNumWarps))
+        stateId = partitionStates.back();
+    }
+  }
+  if (partitionStateCounter > std::numeric_limits<uint8_t>::max()) {
+    return mlir::emitError(func.getLoc(),
+                           "FIXME: too many warp group partitions");
+  }
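
To make the `warpToState` bookkeeping concrete, here is a tiny simulation with made-up warp counts (four default warps, one `warp_specialize` with partitions of two warps and one warp starting at warp IDs 4 and 6):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t defaultNumWarps = 4, maxNumWarps = 3;
  // -1 marks extra warps that do not participate in any partition.
  std::vector<int32_t> stateMap(maxNumWarps, -1);
  struct Partition { int32_t startId, numWarps, state; };
  const std::vector<Partition> partitions = {{4, 2, 0}, {6, 1, 1}};
  // Mirror the stateMap fill in the diff: each partition claims the slots of
  // the warps it owns, indexed relative to the default warp group.
  for (const Partition &p : partitions)
    for (int32_t w = 0; w < p.numWarps; ++w)
      stateMap[p.startId - defaultNumWarps + w] = p.state;
  for (int32_t s : stateMap)
    std::printf("%d ", s); // prints: 0 0 1
}
```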
+
+  // Splice them in reverse order so the IR is easier to read.
+  Region::BlockListType &funcBlocks = func.getBody().getBlocks();
+  for (Block *block : llvm::reverse(partitionBlocks)) {
+    Region *region = block->getParent();
+    funcBlocks.splice(std::next(switchLoop->getIterator()),
+                      region->getBlocks());
+  }
+
+  // Default destination.
+  Block *defaultBlock = new Block;
+  funcBlocks.insert(std::next(switchLoop->getIterator()), defaultBlock);
+  b.setInsertionPointToStart(defaultBlock);
+  createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                /*aligned=*/false);
+  createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                /*aligned=*/false);
+  b.create<LLVM::BrOp>(switchLoop);
+
+  // Exit state.
+  Block *switchExit = new Block;
+  funcBlocks.insert(std::next(defaultBlock->getIterator()), switchExit);
+  partitionBlocks.push_back(switchExit);
+  partitionStates.push_back(partitionStateCounter);
+
+  // Create the switch.
+  b.setInsertionPointToEnd(switchLoop);
+  SmallVector<APInt> caseValues;
+  for (int32_t state : partitionStates)
+    caseValues.push_back(APInt(8, state));
+  b.create<LLVM::SwitchOp>(warpState, defaultBlock, ValueRange(), caseValues,
+                           partitionBlocks,
+                           SmallVector<ValueRange>(partitionBlocks.size()));
+
+  // Now add synchronization around the default regions.
+  for (auto [ws, stateMap] : llvm::zip(wsOps, warpToState)) {
+    Block *before = ws->getBlock();
+    Block *after = b.splitBlock(before, ws->getIterator());
+    b.setInsertionPointToEnd(before);
+    Value statePtr =
+        LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func);
+    for (auto [i, state] : llvm::enumerate(stateMap)) {
+      b.store(b.i8_val(state),
+              b.gep(ptrTy, i8_ty, statePtr, LLVM::GEPArg(i)));
+    }
+
+    // Store the captures, if there are any.
+    if (ws.getNumOperands()) {
+      auto captureType = LLVM::LLVMStructType::getLiteral(
+          b.getContext(), llvm::to_vector(ws.getOperandTypes()),
+          /*isPacked=*/true);
+      Value capturePtr =
+          LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, ws);
+      for (auto [i, arg] : llvm::zip(llvm::seq<int32_t>(ws.getNumOperands()),
+                                     ws.getOperands())) {
+        Value ptr = b.gep(ptrTy, captureType, capturePtr,
+                          ArrayRef<LLVM::GEPArg>{0, i});
+        b.store(arg, ptr, /*align=*/1);
+      }
+    }
+
+    // The first barrier releases the waiting warp groups. The second barrier
+    // ensures they have finished reading the captures before the shared
+    // memory is reclaimed upon entry into the region.
+    createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                  /*aligned=*/false);
+    createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                  /*aligned=*/false);
+    b.create<LLVM::BrOp>(&ws.getDefaultRegion().front());
+
+    ws.getDefaultRegion().walk([&](WarpYieldOp op) {
+      b.setInsertionPoint(op);
+      createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                    /*aligned=*/false);
+      b.replaceOpWithNewOp<LLVM::BrOp>(op, op.getOperands(), after);
+    });
+    after->getParent()->getBlocks().splice(after->getIterator(),
+                                           ws.getDefaultRegion().getBlocks());
+
+    // Replace the results.
+    auto outputs = after->addArguments(
+        ws.getResultTypes(),
+        SmallVector<Location>(ws.getNumResults(), ws.getLoc()));
+    ws.replaceAllUsesWith(outputs);
+    ws.erase();
+  }
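
The paired `barrier.sync` calls implement a release/consume handshake: the first wakes the waiting warp groups, the second keeps the default group from reclaiming the capture memory until everyone has read it. As a host-side analogy only, with `std::thread` and `std::barrier` standing in for warp groups and hardware barriers:

```cpp
#include <barrier>
#include <cstdio>
#include <thread>

int main() {
  std::barrier sync(2); // two participants: default group and one worker
  int capture = 0;
  std::thread worker([&] {
    sync.arrive_and_wait();  // wait to be released
    int local = capture;     // read the capture
    sync.arrive_and_wait();  // signal the capture has been consumed
    std::printf("worker saw %d\n", local);
  });
  capture = 42;            // store the captures
  sync.arrive_and_wait();  // first barrier: release the worker
  sync.arrive_and_wait();  // second barrier: safe to reuse the memory
  capture = -1;            // capture storage may now be clobbered
  worker.join();
}
```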
+
+  // Signal all warp groups to exit.
+  func.walk([&](LLVM::ReturnOp op) {
+    TritonLLVMIRRewriter b(op.getLoc(), op);
+    Value statePtr =
+        LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, func);
+    Value cst = b.i8_val(partitionStateCounter);
+    for (int32_t i : llvm::seq(maxNumWarps))
+      b.store(cst, b.gep(ptrTy, i8_ty, statePtr, LLVM::GEPArg(i)));
+    createBarrier(b, kSwitchLoopBarrierIdx, /*numThreads=*/std::nullopt,
+                  /*aligned=*/false);
+  });
+  b.setInsertionPointToStart(switchExit);
+  b.create<LLVM::ReturnOp>(ValueRange());
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ConvertWarpSpecializeToLLVM
+    : public mlir::triton::impl::ConvertWarpSpecializeToLLVMBase<
+          ConvertWarpSpecializeToLLVM> {
+  void runOnOperation() override {
+    ModuleOp mod = getOperation();
+    // FIXME: Assume warp specialization only happens on Blackwell.
+    NVIDIA::TargetInfo targetInfo(/*computeCapability=*/100,
+                                  /*ptxVersion=*/100);
+
+    // Convert types and clean up unrealized conversions.
+    mlir::LowerToLLVMOptions option(&getContext());
+    option.overrideIndexBitwidth(32);
+    TritonGPUToLLVMTypeConverter typeConverter(&getContext(), option,
+                                               targetInfo);
+    mod.walk([&](Operation *op) {
+      if (isa<WarpSpecializeOp, WarpYieldOp>(op))
+        convertOpTypes(op, typeConverter);
+    });
+    RewritePatternSet patterns(&getContext());
+    UnrealizedConversionCastOp::getCanonicalizationPatterns(patterns,
+                                                            &getContext());
+    if (failed(applyPatternsGreedily(mod, std::move(patterns))))
+      return signalPassFailure();
+
+    SmallVector<LLVM::LLVMFuncOp> kernels;
+    for (auto func : mod.getOps<LLVM::LLVMFuncOp>()) {
+      if (func.isPublic())
+        kernels.push_back(func);
+    }
+    for (LLVM::LLVMFuncOp kernel : kernels)
+      if (failed(lowerWarpSpecialize(kernel, targetInfo)))
+        return signalPassFailure();
+  }
+};
+} // namespace
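
For completeness, the new pass would be scheduled after the main TritonGPU-to-LLVM conversion, which deliberately leaves the warp-specialize ops legal (see the conversion-target change below). Only the zero-argument factory `createConvertWarpSpecializeToLLVM` comes from this diff; the header path and pipeline position in this sketch are assumptions:

```cpp
#include "mlir/Pass/PassManager.h"
#include "triton/Conversion/TritonGPUToLLVM/Passes.h" // assumed header path

// Run the warp-specialize lowering once the rest of the module is LLVM.
static void addWarpSpecializeLowering(mlir::PassManager &pm) {
  pm.addPass(mlir::triton::createConvertWarpSpecializeToLLVM());
}
```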
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
index c6c61d0c8884..65353b5c37ea 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp
@@ -58,6 +58,12 @@ class TritonLLVMConversionTarget : public ConversionTarget {
     addIllegalDialect();
     addIllegalDialect();
     addLegalOp<mlir::UnrealizedConversionCastOp>();
+
+    // Warp specialization is lowered later.
+    addLegalOp<triton::gpu::WarpSpecializeOp>();
+    addLegalOp<triton::gpu::WarpSpecializePartitionsOp>();
+    addLegalOp<triton::gpu::WarpYieldOp>();
+    addLegalOp<triton::gpu::WarpReturnOp>();
   }
 };

@@ -80,13 +86,16 @@ struct ConvertTritonGPUToLLVM
     ModuleMembarAnalysis membarPass(&allocation);
     membarPass.run();

+    mlir::LowerToLLVMOptions option(context);
+    option.overrideIndexBitwidth(32);
+    TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo);
+
     // Lower functions
     TritonLLVMFunctionConversionTarget funcTarget(*context);
     RewritePatternSet funcPatterns(context);
-    TritonGPUToLLVMTypeConverter funcTypeConverter(context, targetInfo);
     mlir::triton::populateFuncOpConversionPattern(
-        funcTypeConverter, funcPatterns, targetInfo, patternBenefitDefault);
-    mlir::cf::populateControlFlowToLLVMConversionPatterns(funcTypeConverter,
+        typeConverter, funcPatterns, targetInfo, patternBenefitDefault);
+    mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
                                                           funcPatterns);
     if (failed(
             applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
@@ -95,9 +104,6 @@ struct ConvertTritonGPUToLLVM
     // initSharedMemory is run before the conversion of call and ret ops,
     // because the call op has to know the shared memory base address of each
     // function
-    mlir::LowerToLLVMOptions option(context);
-    option.overrideIndexBitwidth(32);
-    TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo);
     initSharedMemory(typeConverter);

     ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
diff --git a/third_party/nvidia/triton_nvidia.cc b/third_party/nvidia/triton_nvidia.cc
index 713416484635..0b280f789314 100644
--- a/third_party/nvidia/triton_nvidia.cc
+++ b/third_party/nvidia/triton_nvidia.cc
@@ -38,6 +38,8 @@ void init_triton_nvidia_passes_ttnvgpuir(py::module &&m) {
                      mlir::createTritonNvidiaGPUPromoteLHSToTMemPass);
   ADD_PASS_WRAPPER_0("add_nvgpu_to_llvm",
                      mlir::triton::createConvertNVGPUToLLVMPass);
+  ADD_PASS_WRAPPER_0("add_warp_specialize_to_llvm",
+                     mlir::triton::createConvertWarpSpecializeToLLVM);
   ADD_PASS_WRAPPER_0("add_allocate_tensor_memory",
                      mlir::createTensorMemoryAllocationPass);
   ADD_PASS_WRAPPER_0("add_lower_mma",