[Backend] Codegen for ttg.warp_specialize #5968

Merged
87 commits merged on Feb 27, 2025
Commits (87)
4e7b21b
wip num-warps
Mogball Feb 11, 2025
c31cdc3
[TritonGPU] Refactor `numWarps` lookups to be from the op (NFC)
Mogball Feb 12, 2025
9c72799
fmt
Mogball Feb 12, 2025
519edc8
Merge remote-tracking branch 'origin/main' into mogball/ws
Mogball Feb 12, 2025
6455147
clean up thread ID access
Mogball Feb 12, 2025
d77d445
start refactoring
Mogball Feb 13, 2025
4434fe4
[BACKEND] Refactor how thread/lane/warp IDs are created (NFC)
Mogball Feb 13, 2025
d601651
Merge remote-tracking branch 'origin/main' into mogball/ws1
Mogball Feb 13, 2025
0ce599b
start
Mogball Feb 13, 2025
a92a438
remove emitHardwareTuple
Mogball Feb 13, 2025
9b8e05d
Merge branch 'mogball/ws1' into mogball/ws2
Mogball Feb 13, 2025
f133abb
x
Mogball Feb 13, 2025
84ae862
Merge remote-tracking branch 'origin/main' into mogball/ws1
Mogball Feb 13, 2025
280f10e
merge main
Mogball Feb 13, 2025
91b9af6
Merge branch 'mogball/ws1' into mogball/ws2
Mogball Feb 13, 2025
4583b53
add ttg.warp_specialize op
Mogball Feb 13, 2025
d65a38d
clean up warp attrs
Mogball Feb 13, 2025
384b321
add tests for layout invariants
Mogball Feb 14, 2025
a86fe07
more tests
Mogball Feb 14, 2025
b7e2a8c
add op documentations
Mogball Feb 14, 2025
6bb80ef
fix pass defs
Mogball Feb 14, 2025
db724f8
x
Mogball Feb 14, 2025
2e13070
Merge remote-tracking branch 'origin/main' into mogball/ws2
Mogball Feb 14, 2025
15240b1
Merge branch 'mogball/ws2' into mogball/ws3
Mogball Feb 14, 2025
90343df
finish writing pass
Mogball Feb 14, 2025
a409ef0
add test
Mogball Feb 14, 2025
67fe223
more refactoring
Mogball Feb 14, 2025
cbfce0e
relative threadid
Mogball Feb 14, 2025
81700bd
rewrite allocation tests
Mogball Feb 14, 2025
db30ead
x
Mogball Feb 15, 2025
76901cd
misc
Mogball Feb 18, 2025
c294b29
fix dialect registration
Mogball Feb 18, 2025
5a50883
small fix
Mogball Feb 18, 2025
3c39512
it works
Mogball Feb 19, 2025
966df1e
Merge remote-tracking branch 'origin/main' into mogball/ws3
Mogball Feb 19, 2025
facb459
tests for allocation
Mogball Feb 19, 2025
30cb5cb
add another test
Mogball Feb 19, 2025
5ec20bb
format
Mogball Feb 19, 2025
b661bd9
fix build
Mogball Feb 19, 2025
8678d08
fmt again
Mogball Feb 19, 2025
f4d8405
fix
Mogball Feb 19, 2025
b70eebf
tmem allocation support
Mogball Feb 19, 2025
5e5c94b
make it work with membar
Mogball Feb 19, 2025
401747b
fix test
Mogball Feb 19, 2025
841057c
cleanup
Mogball Feb 19, 2025
0b6ec55
start writing the codegen
Mogball Feb 20, 2025
1b74c7c
remove recursion
Mogball Feb 20, 2025
8286517
fixme for warp id alignment
Mogball Feb 20, 2025
750c3f6
bufferId is size_t
Mogball Feb 20, 2025
4d095b2
warp yield returnlike
Mogball Feb 20, 2025
be01587
AsyncRegions op trait
Mogball Feb 20, 2025
6c0afc0
Merge branch 'mogball/ws3' into mogball/ws4
Mogball Feb 20, 2025
360079f
teach membar about RegionBranchOpInterface
Mogball Feb 20, 2025
9f19bab
skip scf tests
Mogball Feb 20, 2025
4463849
Merge branch 'mogball/ws3' into mogball/ws4
Mogball Feb 20, 2025
69556fe
add test for cfg
Mogball Feb 20, 2025
43d5f86
Merge remote-tracking branch 'origin/main' into mogball/ws3
Mogball Feb 20, 2025
0350627
it also works on scf now
Mogball Feb 20, 2025
e3088db
Merge branch 'mogball/ws3' into mogball/ws4
Mogball Feb 20, 2025
30eb5f2
undo
Mogball Feb 20, 2025
4d6a164
x
Mogball Feb 20, 2025
5293e40
plumbing data layout
Mogball Feb 21, 2025
f8e0999
still writing lowering
Mogball Feb 21, 2025
4747bc7
sort of works
Mogball Feb 24, 2025
847acee
review comments
Mogball Feb 24, 2025
d4fe5a5
Merge branch 'mogball/ws3' into mogball/ws4
Mogball Feb 24, 2025
2909356
fix minor bugs, add tests
Mogball Feb 24, 2025
f3d3762
more tests
Mogball Feb 24, 2025
caa229a
fix memalloc tests
Mogball Feb 24, 2025
3bcd2ce
fix crash when no captures
Mogball Feb 24, 2025
8ec256f
ptr datalayout
Mogball Feb 24, 2025
70c84d9
fix type conversion
Mogball Feb 25, 2025
915fc2d
integration tests
Mogball Feb 25, 2025
124d6c9
another integration test
Mogball Feb 25, 2025
08a3cee
Merge remote-tracking branch 'origin/main' into mogball/ws3
Mogball Feb 25, 2025
39cb456
Merge branch 'mogball/ws3' into mogball/ws4
Mogball Feb 25, 2025
4d20668
fix things
Mogball Feb 25, 2025
6b78ebd
fmt
Mogball Feb 25, 2025
3155489
document virtual block
Mogball Feb 25, 2025
a6c0ac7
Merge branch 'mogball/ws3' into mogball/ws4
Mogball Feb 25, 2025
c32e75e
fudge some unifomity
Mogball Feb 25, 2025
20acb1d
skip test on AMD
Mogball Feb 25, 2025
d9a27c9
run on A100 anyways
Mogball Feb 25, 2025
ab92c09
Merge remote-tracking branch 'origin/main' into mogball/ws4
Mogball Feb 25, 2025
095963d
strip out DataLayout plumbing
Mogball Feb 26, 2025
db5df4a
add desc
Mogball Feb 26, 2025
e92ecff
fmt
Mogball Feb 26, 2025
2 changes: 1 addition & 1 deletion bin/RegisterTritonDialects.h
@@ -50,6 +50,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::gpu::registerAllocateSharedMemoryPass();
mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();
mlir::triton::gpu::registerTritonGPUGlobalScratchAllocationPass();
mlir::triton::registerConvertWarpSpecializeToLLVM();
mlir::triton::registerConvertTritonGPUToLLVMPass();
mlir::triton::registerConvertNVGPUToLLVMPass();
mlir::registerLLVMDIScope();
@@ -71,7 +72,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

// TODO: register Triton & TritonGPU passes
registry
.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
30 changes: 28 additions & 2 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -6,6 +6,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
@@ -279,6 +280,29 @@ struct TritonLLVMOpBuilder {
Location loc;
OpBuilder *builder;
};

// This builder combines an IRRewriter and a TritonLLVMOpBuilder into one,
// making it easy to create operations with an implicit location and create LLVM
// operations with shorthands.
class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
public:
// Create a builder with an implicit location. Arguments are forwarded to
// IRRewriter's constructor.
template <typename... Args>
TritonLLVMIRRewriter(Location loc, Args &&...args)
: IRRewriter(std::forward<Args>(args)...),
TritonLLVMOpBuilder(loc, *this) {}

// Get the implicit location.
Location getLoc() const { return loc; }
// Set the implicit location used to build ops.
void setLoc(Location loc) { this->loc = loc; }

// Wrapper for op creation that passes an implicit location.
template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
return OpBuilder::create<OpTy>(loc, std::forward<Args>(args)...);
}
};
} // namespace mlir::triton

// Types
@@ -548,8 +572,10 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
const TargetInfoBase &target, Operation *op) {
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
target.getSharedAddressSpace());
FunctionOpInterface func =
op->template getParentOfType<FunctionOpInterface>();
auto func = op->template getParentOfType<FunctionOpInterface>();
if (!func)
func = cast<FunctionOpInterface>(op);

assert(op->hasAttr("allocation.offset"));
size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
.getValue()
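
To make the new combined builder concrete, here is a minimal usage sketch (illustrative only, not part of this diff; `emitZeroNextTo` and the constant it builds are placeholders). The location passed to the constructor becomes the implicit location that `create<OpTy>(...)` forwards to every op it builds.

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

// Build an i32 zero constant next to an existing op without threading the
// location through each builder call.
static void emitZeroNextTo(mlir::Operation *op) {
  mlir::triton::TritonLLVMIRRewriter rw(op->getLoc(), op->getContext());
  rw.setInsertionPoint(op);
  // The implicit location from the constructor is forwarded by create<>.
  mlir::Value zero = rw.create<mlir::LLVM::ConstantOp>(
      rw.getI32Type(), rw.getI32IntegerAttr(0));
  (void)zero;
}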
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -385,6 +385,11 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [

let extraClassDeclaration = [{
RegionRange getPartitionRegions();

// Get the size and alignment of the capture list.
std::pair<uint64_t, uint64_t> getCaptureSizeAlign();
// Get the total number of extra warps required.
unsigned getTotalPartitionWarps();
}];

let hasVerifier = 1;
31 changes: 29 additions & 2 deletions lib/Analysis/Allocation.cpp
@@ -283,6 +283,24 @@ class AllocationAnalysis {
scratchAlignment);
return;
}
if (auto ws = dyn_cast<gpu::WarpSpecializeOp>(op)) {
// `ttg.warp_specialize` needs memory to pass its explicit captures. Pack
// the captures like a struct.
auto [captureSize, captureAlign] = ws.getCaptureSizeAlign();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, captureSize,
captureAlign);
return;
}
if (auto func = dyn_cast<FunctionOpInterface>(op)) {
unsigned numWarpIndices = 0;
// Warp specialization communicates states over shared memory to each
// warp. Add space for an i8 for each warpgroup warp.
func.walk([&](gpu::WarpSpecializeOp op) {
numWarpIndices = std::max(numWarpIndices, op.getTotalPartitionWarps());
});
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, numWarpIndices);
return;
}
unsigned bytes = scratchSizeGetter(op);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
@@ -374,10 +392,19 @@ class AllocationAnalysis {
// Analyze liveness of scratch buffers and virtual buffers.
auto processScratchMemory = [&](const auto &container) {
for (auto [op, buffer] : container) {
// Buffers owned by the function are assumed live for the whole
// function. This memory is used for warp specialization codegen.
// FIXME: Spooky-action-at-a-distance. Find a better way to model this.
if (op == operation) {
bufferRange.insert(
{buffer, Interval(size_t(), std::numeric_limits<size_t>::max())});
continue;
}

// Any scratch memory's live range is the current operation's live
// range.
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
bufferRange.insert(
{buffer, Interval(operationId.at(op), operationId.at(op) + 1)});
LLVM_DEBUG({
llvm::dbgs() << "-- buffer " << buffer->id << "; value: ";
op->dump();
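
For a concrete reading of these two scratch buffers (illustrative, based on the `test_warpgroup_reduction` kernel added later in this diff): its partitions request 4 + 2 + 1 = 7 extra warps, so the function-level buffer reserves 7 bytes, one i8 warp index per partition warp, while the two pointer captures of the `ttg.warp_specialize` op are packed into 2 × 8 = 16 bytes of 8-byte-aligned scratch.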
9 changes: 8 additions & 1 deletion lib/Analysis/AxisInfo.cpp
@@ -1295,8 +1295,15 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) {
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(funcOp)))
WalkResult result = funcOp.walk([&](Operation *op) {
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
failed(solver->initializeAndRun(op)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (result.wasInterrupted())
return;

auto *axisInfoMap = getFuncData(funcOp);
auto updateAxisInfoMap = [&](Value value) {
auto axisInfo = analysis->getLatticeElement(value)->getValue();
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AllocateWarpGroups.cpp
@@ -23,7 +23,7 @@ struct AllocateWarpGroups
int maxExtraWarps = 0;
mod.walk([&](WarpSpecializeOp op) {
ArrayRef<int32_t> arr = op.getPartitionNumWarps();
int req = std::accumulate(arr.begin(), arr.end(), 0, std::plus{});
int req = op.getTotalPartitionWarps();
maxExtraWarps = std::max(maxExtraWarps, req);

// Allocate the start IDs such that the largest warpgroups have lower
1 change: 0 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -1,5 +1,4 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
@@ -168,6 +168,9 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
// Set an attribute for reqntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
int numWarps = triton::gpu::lookupNumWarps(funcOp);
if (auto totalNumWarps = funcOp.getParentOp()->getAttrOfType<IntegerAttr>(
"ttg.total-num-warps"))
numWarps = totalNumWarps.getInt();
newFuncOp->setAttr("nvvm.reqntid",
rewriter.getDenseI32ArrayAttr(32 * numWarps));

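
Illustrative arithmetic, assuming the reduction test later in this diff and that `ttg.total-num-warps` is the default warp count plus the extra partition warps: 4 default warps plus 4 + 2 + 1 partition warps gives 11 total warps, so `nvvm.reqntid` would be set to 32 × 11 = 352 threads.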
33 changes: 33 additions & 0 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -1,4 +1,5 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/DebugStringHelper.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
@@ -770,4 +771,36 @@ LogicalResult WarpYieldOp::verify() {
return success();
}

// Get the size of a scalar type when stored in shared memory.
// TODO: Generalize this as needed.
static size_t getSharedMemorySize(Type type) {
if (isa<IntegerType, FloatType>(type))
return llvm::divideCeil(type.getIntOrFloatBitWidth(), 8);
if (isa<PointerType>(type))
return 8;
if (auto desc = dyn_cast<MemDescType>(type)) {
if (!isa<SharedMemorySpaceAttr>(desc.getMemorySpace()))
return 8;
return 8 + desc.getRank() * 4;
}
llvm::report_fatal_error(
Twine("shared memory size for scalar type is unspecified: ") +
mlir::debugString(type));
}

std::pair<uint64_t, uint64_t> WarpSpecializeOp::getCaptureSizeAlign() {
uint64_t captureSize = 0;
// Tightly pack the captures in memory.
for (Type type : getOperandTypes()) {
captureSize += getSharedMemorySize(type);
}
// Align the captures to 8 bytes.
return {captureSize, 8};
}

unsigned WarpSpecializeOp::getTotalPartitionWarps() {
ArrayRef<int32_t> numWarps = getPartitionNumWarps();
return std::accumulate(numWarps.begin(), numWarps.end(), 0);
}

} // namespace mlir::triton::gpu
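
As a standalone sketch of the packing rules above (assumptions: 8-byte pointers and the 8 + rank*4 layout for shared-memory descriptors, mirroring `getSharedMemorySize`; this is not code from the PR):

#include <cstdint>
#include <utility>
#include <vector>

// Just enough information about one capture to reproduce the byte-size rules.
struct CaptureDesc {
  unsigned bitWidth = 0; // integer/float scalars
  bool isPointer = false;
  int memDescRank = -1;  // >= 0 for a shared-memory descriptor
};

static uint64_t captureByteSize(const CaptureDesc &c) {
  if (c.isPointer)
    return 8;
  if (c.memDescRank >= 0)
    return 8 + 4u * static_cast<unsigned>(c.memDescRank);
  return (c.bitWidth + 7) / 8; // divideCeil(bitWidth, 8)
}

// Mirror of WarpSpecializeOp::getCaptureSizeAlign: tightly pack the captures
// and report an 8-byte alignment for the whole blob.
static std::pair<uint64_t, uint64_t>
captureSizeAlign(const std::vector<CaptureDesc> &captures) {
  uint64_t size = 0;
  for (const CaptureDesc &c : captures)
    size += captureByteSize(c);
  return {size, 8};
}

// Example: {i32, !tt.ptr, rank-2 shared memdesc} packs to 4 + 8 + 16 = 28 bytes.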
2 changes: 2 additions & 0 deletions python/src/passes.cc
@@ -63,6 +63,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createTritonGPURemoveLayoutConversions);
ADD_PASS_WRAPPER_0("add_reduce_data_duplication",
createTritonGPUReduceDataDuplication);
ADD_PASS_WRAPPER_0("add_allocate_warp_groups",
createTritonGPUAllocateWarpGroups);
ADD_PASS_WRAPPER_0("add_allocate_shared_memory", createAllocateSharedMemory);
ADD_PASS_WRAPPER_0("add_allocate_global_scratch_memory",
createTritonGPUGlobalScratchAllocationPass);
103 changes: 103 additions & 0 deletions python/test/unit/language/test_warp_specialization.py
@@ -0,0 +1,103 @@
import torch
import pytest
import pathlib
import triton

from triton._internal_testing import is_cuda


@pytest.mark.skipif(not is_cuda(), reason="warp specialization is only supported on NVIDIA")
def test_warp_specialize_basic_ir(tmp_path: pathlib.Path):
ir = """
tt.func @kernel(%arg0: !tt.ptr<i32>) {
%c42_i32 = arith.constant 42 : i32
gpu.barrier
ttg.warp_specialize(%arg0)
default {
tt.store %arg0, %c42_i32 : !tt.ptr<i32>
gpu.barrier
ttg.warp_yield
}
partition0(%arg1: !tt.ptr<i32>) num_warps(1) {
%c5555_i32 = arith.constant 5555 : i32
%c1_i32 = arith.constant 1 : i32
gpu.barrier
%ptr = tt.addptr %arg1, %c1_i32 : !tt.ptr<i32>, i32
tt.store %ptr, %c5555_i32 : !tt.ptr<i32>
ttg.warp_return
} : (!tt.ptr<i32>) -> ()
tt.return
}
"""

temp_file = tmp_path / "test_warp_specialize_basic_ir.ttir"
temp_file.write_text(ir)
kernel = triton.compile(str(temp_file))

input = torch.empty(2, dtype=torch.int32, device='cuda')
kernel[(1, 1, 1)](input)
assert input[0] == 42
assert input[1] == 5555


@pytest.mark.skipif(not is_cuda(), reason="warp specialization is only supported on NVIDIA")
def test_warpgroup_reduction(tmp_path: pathlib.Path):

def template(i, num_warps, in_ptr, out_ptr):
return f"""
%range = tt.make_range {{end = {(i+1)*256} : i32, start = {i*256} : i32}} : tensor<256xi32, #blocked{num_warps}>
%splatted = tt.splat {in_ptr} : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>, #blocked{num_warps}>
%ptrs = tt.addptr %splatted, %range : tensor<256x!tt.ptr<i32>, #blocked{num_warps}>, tensor<256xi32, #blocked{num_warps}>
%input = tt.load %ptrs : tensor<256x!tt.ptr<i32>, #blocked{num_warps}>
%result = "tt.reduce"(%input) ({{
^bb0(%lhs: i32, %rhs: i32):
%result = arith.addi %lhs, %rhs : i32
tt.reduce.return %result : i32
}}) {{axis = 0 : i32}} : (tensor<256xi32, #blocked{num_warps}>) -> i32
%offset = arith.constant {i} : i32
%output = tt.addptr {out_ptr}, %offset : !tt.ptr<i32>, i32
tt.store %output, %result : !tt.ptr<i32>
"""

ir = """
#blocked4 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>

module attributes {"ttg.num-warps" = 4 : i32} {

tt.func @kernel(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
ttg.warp_specialize(%arg0, %arg1)
default {
""" + template(0, 4, "%arg0", "%arg1") + """
ttg.warp_yield
}
partition0(%arg2: !tt.ptr<i32>, %arg3: !tt.ptr<i32>) num_warps(4) {
""" + template(1, 4, "%arg2", "%arg3") + """
ttg.warp_return
}
partition1(%arg4: !tt.ptr<i32>, %arg5: !tt.ptr<i32>) num_warps(2) {
""" + template(2, 2, "%arg4", "%arg5") + """
ttg.warp_return
}
partition2(%arg6: !tt.ptr<i32>, %arg7: !tt.ptr<i32>) num_warps(1) {
""" + template(3, 1, "%arg6", "%arg7") + """
ttg.warp_return
} : (!tt.ptr<i32>, !tt.ptr<i32>) -> ()
tt.return
}

}
"""

temp_file = tmp_path / "test_warpgroup_reduction.ttgir"
temp_file.write_text(ir)
kernel = triton.compile(str(temp_file))

input = torch.arange(1024, dtype=torch.int32, device='cuda')
output = torch.empty(4, dtype=torch.int32, device='cuda')
kernel[(1, 1, 1)](input, output)
assert output[0] == torch.arange(0, 256).sum()
assert output[1] == torch.arange(256, 512).sum()
assert output[2] == torch.arange(512, 768).sum()
assert output[3] == torch.arange(768, 1024).sum()