From 032fa41a45847cdc00119ed3bdd5bc0adab9c938 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Mon, 3 Feb 2025 22:31:25 -0800 Subject: [PATCH 1/7] [LAYOUTS] Don't hoist into ifs outside of loops (#5801) Hoisting layout conversions into ifs relies on the assumption that the if infrequently executes, but this assumption only makes sense in a loop. A single top-level if in a kernel either executes or it doesn't, and if the hoist is incorrect, it can lead to a slowdown. --- .../Transforms/RemoveLayoutConversions.cpp | 81 ++++++------------- test/TritonGPU/combine.mlir | 20 +++-- 2 files changed, 33 insertions(+), 68 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 0514ffc5dc..99b60c07bd 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -1301,27 +1301,19 @@ void LayoutRematerialization::hoistConvertIntoConditionals( // These are the conditional edges above which conversions should be hoisted. // The value represents the `scf.if` op result and the operand represents the // edge into one of the branches. - SmallVector> hoistAbove; + SmallVector> hoistAbove; // The list of `scf.if` op results in the slice that are not rematerializable. // Hoisting is terminated at these values. SmallVector terminals; - // Process the whole backward slice in subslices that stop at each condtional. - // This is so we can apply more specific rules about when to hoist. - struct Subslice { - OpResult v; - OpOperand *edge; - SetVector slice; - DenseMap layout; - }; - SmallVector subslices; - - // Check a value in the subslice. - auto visitValue = [&](OpResult v) { + // This loop recurses through the subslices of the backwards dependencies, so + // re-query the size of `slice`. + for (unsigned i = 0; i != slice.size(); ++i) { + Value v = slice[i]; auto ifOp = v.getDefiningOp(); if (!ifOp) - return; + continue; Attribute rootLayout = layout.at(v); unsigned resIdx = cast(v).getResultNumber(); @@ -1350,66 +1342,41 @@ void LayoutRematerialization::hoistConvertIntoConditionals( slice.insert(elseSlice.begin(), elseSlice.end()); layout.insert(thenLayout.begin(), thenLayout.end()); layout.insert(elseLayout.begin(), elseLayout.end()); - return; + continue; } // If propagation across both edges failed, then this conditional // terminates backwards rematerialization. if (failed(thenResult) && failed(elseResult)) { - terminals.push_back(v); - return; + terminals.push_back(cast(v)); + continue; + } + + // Only hoist into conditionals inside loops. The assumption is that an if + // inside a loop executes fewer than the total number of loop iterations, + // making this hoist profitable. + if (!isa(ifOp->getParentOp())) { + terminals.push_back(cast(v)); + continue; } // The layout conversion can be rematerialized along one edge but not the // other. We can hoist the conversion into the other branch. Push this // into the subslice list for analysis. if (succeeded(thenResult)) { - subslices.push_back( - {v, &elseRes, std::move(thenSlice), std::move(thenLayout)}); + hoistAbove.emplace_back(v, &elseRes); + slice.insert(thenSlice.begin(), thenSlice.end()); + layout.insert(thenLayout.begin(), thenLayout.end()); } else { - subslices.push_back( - {v, &thenRes, std::move(elseSlice), std::move(elseLayout)}); - } - }; - - // Process the whole slice in subslices. 
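To make the profitability argument concrete, here is a hypothetical Triton kernel (illustrative only, not taken from this patch; the names and the `k % 64` condition are made up) with the shape of code the heuristic still targets: the branch sits inside a loop, so a layout conversion hoisted into the rarely taken `if` body runs on only a fraction of the iterations, whereas a top-level `if` in a kernel gives no such bound.

```python
# Hypothetical example, not from this PR: an `if` nested in a loop, the only
# case where hoisting a layout conversion into the branch is still allowed.
import triton
import triton.language as tl


@triton.jit
def conditional_accumulate(x_ptr, extra_ptr, out_ptr, K, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    for k in range(K):
        acc += tl.load(x_ptr + k * BLOCK + offs)
        if k % 64 == 0:
            # Taken on roughly one of every 64 iterations, so a convert_layout
            # hoisted into this branch executes far less often than the loop.
            acc += tl.load(extra_ptr + k // 64 * BLOCK + offs)
    tl.store(out_ptr + offs, acc)
```

At the TTGIR level the branch above lowers to an `scf.if` nested directly inside an `scf.for`, which is the parent-op condition `hoistConvertIntoConditionals` now checks before allowing the hoist.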
- unsigned i = 0; - bool isLoneHoist = false; - do { - // Visit values in the current subslice. - for (; i != slice.size(); ++i) { - if (auto v = dyn_cast(slice[i])) - visitValue(v); - } - // Check the next chunk of subslices. When a condtional is marked as being - // valid to be hoisted across, we have to recurse on a new subslice rooted - // at the corresopnding yield operand. - // - // Hoist across condtionals when: - // 1. The conditional is directly inside a loop. - // 2. The whole slice contains only one conditional. - for (auto &[v, edge, subslice, layouts] : subslices) { - bool oneHoist = false; - if (isa(v.getDefiningOp()->getParentOp()) || - (oneHoist = subslices.size() == 1 && hoistAbove.empty())) { - isLoneHoist |= oneHoist; - hoistAbove.push_back({v, edge}); - // Recurse on the subslice. - slice.insert(subslice.begin(), subslice.end()); - layout.insert(layouts.begin(), layouts.end()); - } else { - terminals.push_back(v); - } + hoistAbove.emplace_back(v, &thenRes); + slice.insert(elseSlice.begin(), elseSlice.end()); + layout.insert(elseLayout.begin(), elseLayout.end()); } - subslices.clear(); - } while (i != slice.size()); + } // Exit early if there is nothing to do. if (hoistAbove.empty()) return; - // Check if this is a lone hoist. There should be no other terminals. - if (isLoneHoist && !terminals.empty()) - return; // Rematerialize failed hoists right before the condtional, and hoist those // that succeeded into the branch and then rewrite the slice. diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 95b167d223..6798cd85af 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2856,27 +2856,25 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-LABEL: @hoist_one_conditional tt.func @hoist_one_conditional( %arg0: i1, - %arg1: tensor<128x32x!tt.ptr, #blocked>, - %arg2: tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, - %arg3: tensor<128x128xf32, #mma> -) -> tensor<128x128xf32, #mma> { + %arg1: tensor<128x32x!tt.ptr, #blocked> +) -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> { - // CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op + // CHECK: arith.constant {{.*}} tensor<128x32xf32, #blocked> %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked> // CHECK: scf.if %0 = scf.if %arg0 -> (tensor<128x32xf32, #blocked>) { // CHECK-NEXT: [[RES:%.*]] = tt.load %3 = tt.load %arg1 : tensor<128x32x!tt.ptr, #blocked> - // CHECK-NEXT: ttg.convert_layout [[RES]] - // CHECK-NEXT: yield + // CHECK-NEXT: yield [[RES]] scf.yield %3 : tensor<128x32xf32, #blocked> } else { scf.yield %cst : tensor<128x32xf32, #blocked> } - // CHECK-NOT: ttg.convert_layout - %1 = ttg.convert_layout %0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> - %2 = tt.dot %1, %arg2, %arg3 : tensor<128x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> - tt.return %2 : tensor<128x128xf32, #mma> + // CHECK: [[TRUNC:%.*]] = arith.truncf + %1 = arith.truncf %0 : tensor<128x32xf32, #blocked> to tensor<128x32xf16, #blocked> + // CHECK-NEXT: convert_layout [[TRUNC]] + %2 = ttg.convert_layout %1 : tensor<128x32xf16, #blocked> -> tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + tt.return %2 : tensor<128x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> } // 
CHECK-LABEL: @hoist_multiple_conditional From 547fba0a25a4642c7954d1e3bf17d345addcfd01 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 4 Feb 2025 07:20:14 -0500 Subject: [PATCH 2/7] [BACKEND] Disable `ldmatrix.trans` for fp8 (#5800) Since we load data in the column major format with `ldmatrix.trans`, pre-blackwell hardware seems difficult to support the transpose case. --- python/test/unit/language/test_core.py | 69 +++++++++++++++++++ test/Conversion/tritongpu_to_llvm.mlir | 2 - .../TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 2 +- 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a1dbe74711..93ae3bd35a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5828,6 +5828,75 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: assert torch.equal(z, x) +dot_layouts = [ + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=0, k_width=4), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=1, k_width=4), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=0, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=2), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), op_idx=0, k_width=1), + DotOperandLayout(parent=MmaLayout([2, 0], [4, 1], [1, 1], [1, 1], [1, 0], [16, 8]), op_idx=1, k_width=1), +] + +shared_layouts = [ + SharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]), + SharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]), + SharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N", [[16, 32]]) +@pytest.mark.parametrize("dtype", ['float16', 'float8e5', 'float32']) +@pytest.mark.parametrize("shared_layout", shared_layouts) +@pytest.mark.parametrize("dist_layout", filter_layouts(dot_layouts)) +def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, tmp_path: pathlib.Path): + if dtype == "float32": + mlir_dtype = "f32" + elif dtype == "float16": + mlir_dtype = "f16" + elif dtype == "float8e5": + mlir_dtype = "f8E5M2" + + layouts = f""" + #dist = {dist_layout} + #shared = {shared_layout} + #smem = #ttg.shared_memory + """ + ir = layouts + f""" + module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> + %2 = tt.splat %arg0 : !tt.ptr<{mlir_dtype}> -> tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %3 = tt.splat %arg1 : !tt.ptr<{mlir_dtype}> -> tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<{M}x1xi32, #dist> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #dist> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #ttg.slice<{{dim = 0, parent = #dist}}>> -> tensor<1x{N}xi32, #dist> + %7 = tt.broadcast %6 
: tensor<1x{N}xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #dist> -> tensor<{M}x{N}xi32, #dist> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #dist> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>, tensor<{M}x{N}xi32, #dist> + %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + %12 = ttg.local_alloc %11 : (tensor<{M}x{N}x{mlir_dtype}, #dist>) -> !ttg.memdesc<{M}x{N}x{mlir_dtype}, #shared, #smem> + %13 = ttg.local_load %12 : !ttg.memdesc<{M}x{N}x{mlir_dtype}, #shared, #smem> -> tensor<{M}x{N}x{mlir_dtype}, #dist> + %14 = tt.addptr %3, %9 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist>, tensor<{M}x{N}xi32, #dist> + tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<{mlir_dtype}>, #dist> + tt.return + }} +}} +""" + + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) + z = torch.empty_like(x, device=device) + + temp_file = tmp_path / "test_local_load_store_dot.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](x, z) + assert torch.equal(z, x) + + mma_layouts = [ MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), # simple 4 warps case diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index aacb2053c0..c29d7657ae 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -974,8 +974,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} { %AA = ttg.local_alloc %A : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> %BB = ttg.local_alloc %B : (tensor<16x16xf8E5M2, #blocked0>) -> !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> // CHECK: nvgpu.ldmatrix %{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)> - // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> - // CHECK: nvgpu.ldmatrix %{{.*}} {trans} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> // CHECK-NOT: nvgpu.ldmatrix %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_a> %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf8E5M2, #shared0, #smem> -> tensor<16x16xf8E5M2, #dot_operand_b> diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 587d42a3e0..f4b411f875 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -54,7 +54,7 @@ struct LocalLoadOpConversion auto shape = srcTy.getShape(); // Limitation 2 [TODO: remove]: Only support 2d matrices now but we should // be able to support 3D minor changes - canUseLdmatrix &= (bitwidth <= 16 || !needTrans) && shape.size() <= 2; + canUseLdmatrix &= (bitwidth == 16 || !needTrans) && shape.size() <= 2; // Limitation 3: Minimum tile size (8)x(8x16bits) canUseLdmatrix &= shape[kOrder] >= (8 * 16 / bitwidth) && shape[nonKOrder] >= 8; From d85b664f136c7475a6469249c8e0fc95f9166c2c Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 4 Feb 2025 07:29:47 -0500 Subject: [PATCH 3/7] [PROTON] Fix incorrect tmp_path initialization (#5803) --- third_party/proton/test/test_profile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index e52c67f0bd..a673c1da61 100644 --- 
a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -210,7 +210,6 @@ def foo(x, size: tl.constexpr, y): @pytest.mark.parametrize("context", ["shadow", "python"]) def test_hook_gpu_kernel(tmp_path: pathlib.Path, context: str): - tmp_path = pathlib.Path("./") def metadata_fn(grid: tuple, metadata: NamedTuple, args: dict): x = args["x"] From ebb99b1d38b3f97ae7719a058e482d12e6c9bf52 Mon Sep 17 00:00:00 2001 From: pawelszczerbuk <153013546+pawelszczerbuk@users.noreply.github.com> Date: Tue, 4 Feb 2025 08:18:35 -0800 Subject: [PATCH 4/7] [PIPELINE] Remove outer loop pipelining transformation (#5766) Simplify pipelining by removing outer loop pipelining transformation. Performance benefits of it are smaller than pipelining fused persistent loops, while making the pipeliner harder to maintain and refactor. --- .../TritonGPU/Transforms/CMakeLists.txt | 1 - .../Pipeliner/OuterLoopPipeline.cpp | 131 ------------------ .../Pipeliner/SoftwarePipeliner.cpp | 23 --- test/TritonGPU/loop-pipeline-hopper.mlir | 11 +- test/TritonGPU/loop-pipeline.mlir | 41 ++---- 5 files changed, 13 insertions(+), 194 deletions(-) delete mode 100644 lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index fcba58fbcc..bdad398575 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -11,7 +11,6 @@ add_triton_library(TritonGPUTransforms OptimizeThreadLocality.cpp Pipeliner/AssignLatencies.cpp Pipeliner/MatmulLoopPipeline.cpp - Pipeliner/OuterLoopPipeline.cpp Pipeliner/PipelineExpander.cpp Pipeliner/TestPipelineAssignLatencies.cpp Pipeliner/TestPipelineScheduleLoop.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp deleted file mode 100644 index d8a34f6946..0000000000 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/OuterLoopPipeline.cpp +++ /dev/null @@ -1,131 +0,0 @@ -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" -#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" -#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" - -using namespace mlir; -namespace tt = mlir::triton; -namespace ttg = mlir::triton::gpu; - -// create the schedule for a matmul loop. This is ad hoc based on how we know -// matmul loops should be pipelined and is not a generic scheduler. -static std::vector> -createSchedule(scf::ForOp forOp, int numStages) { - SmallVector insertOps; - for (Operation &op : forOp.getBody()->without_terminator()) { - if (isa(op)) - insertOps.emplace_back(&op); - } - DenseSet insertAndDeps; - for (Operation *op : insertOps) { - tt::addDep(op, insertAndDeps, true); - } - - DenseSet epilogue; - bool foundLoop = false; - for (Operation &op : forOp.getBody()->without_terminator()) { - if (insertAndDeps.count(&op)) - continue; - if (isa(op)) - foundLoop = true; - if (isa(op)) - continue; - if (foundLoop) - epilogue.insert(&op); - } - - std::vector> schedule; - // Schedule stage 1 first. - tt::addOps(forOp, 1, schedule, [&](Operation *op) { - return insertAndDeps.count(op) == 0 && epilogue.count(op) == 0; - }); - - // Then Schedule stage 0. 
- tt::addOps(forOp, 0, schedule, - [&](Operation *op) { return insertAndDeps.count(op); }); - - // Then schedule the epilogue in stage 1 - tt::addOps(forOp, 1, schedule, - [&](Operation *op) { return epilogue.count(op); }); - return schedule; -} - -// pre-process the loop by hosting allocations/deallocation out of the -// loop. -static void hoistAllocAndConst(scf::ForOp forOp) { - SmallVector toHoist; - for (Operation &op : forOp.getBody()->without_terminator()) { - if (auto allocOp = dyn_cast(op)) { - // We hoist the allocOp only if it is created by the inner loop - // pipelining. - if (!allocOp.getSrc()) - toHoist.push_back(&op); - } else if (isa(op)) { - toHoist.push_back(&op); - } - } - for (Operation *op : toHoist) { - op->moveBefore(forOp); - auto allocOp = dyn_cast(op); - if (!allocOp) - continue; - for (Operation *user : allocOp->getUsers()) { - if (auto dealloc = dyn_cast(user)) { - dealloc->moveAfter(forOp); - } - } - } -} - -static bool preCondition(scf::ForOp forOp) { - // Check if there is a dependency from the loop to the async copy op. In this - // case we cannot pipeline the async copy. - SmallVector insertOps; - int numForOps = 0; - for (Operation &op : forOp.getBody()->without_terminator()) { - if (isa(op)) - insertOps.emplace_back(&op); - if (isa(op)) - numForOps++; - } - if (insertOps.empty() || numForOps != 1) - return false; - DenseSet insertAndDeps; - for (Operation *op : insertOps) { - tt::addDep(op, insertAndDeps, true); - } - // If there is a recurrence containing both the async and the for op we cannot - // pipeline. - for (Operation *op : insertAndDeps) { - if (isa(op)) - return false; - } - return true; -} - -bool mlir::triton::getOuterLoopSchedule( - scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { - assert(numStages == 2 && "only support 2 stage pipelining for now"); - // 1. Check precondition, we cannot have a recurrence involving async cp ops - if (!preCondition(forOp)) - return false; - - // 2. pre-process the loop by hosting allocations. - hoistAllocAndConst(forOp); - - // 3. Create the final schedule for the kernel loop. This will dictate the - // stages and order of operations to the pipeline expander. - std::vector> schedule = - createSchedule(forOp, numStages); - - // 4. Fill out the pipeline options. - options.getScheduleFn = - [schedule](scf::ForOp forOp, - std::vector> &s) { - s = std::move(schedule); - }; - options.peelEpilogue = false; - options.predicateFn = mlir::triton::predicateOp; - options.supportDynamicLoops = true; - return true; -} diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp index 33e7a6437a..a75f91d716 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp @@ -33,19 +33,6 @@ namespace gpu { #define GEN_PASS_DEF_TRITONGPUPIPELINE #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" -static void tryAndPipelineOuterLoop(scf::ForOp forOp) { - mlir::triton::PipeliningOption options; - bool foundSchedule = false; - // Limit 2 stages to not require extra shared memory. 
- foundSchedule = getOuterLoopSchedule(forOp, /*numStage=*/2, options); - if (!foundSchedule) - return; - IRRewriter rewriter(forOp->getContext()); - rewriter.setInsertionPoint(forOp); - FailureOr newForOp = - mlir::triton::pipelineForLoop(rewriter, forOp, options); -} - static scf::ForOp pipelineLoop(scf::ForOp forOp, int numStages) { mlir::triton::PipeliningOption options; @@ -92,17 +79,12 @@ struct PipelinePass : public impl::TritonGPUPipelineBase { if (loops.empty()) return; - llvm::SmallSetVector outerLoops; llvm::SmallVector pipelinedLoops; for (scf::ForOp forOp : loops) { - auto outerLoop = dyn_cast(forOp->getParentOp()); int loopNumStages = getNumStagesOrDefault(forOp); scf::ForOp pipelinedFor = pipelineLoop(forOp, loopNumStages); if (pipelinedFor != nullptr) pipelinedLoops.push_back(pipelinedFor); - if (pipelinedFor != nullptr && outerLoop && - getNumStagesOrDefault(outerLoop) > 1) - outerLoops.insert(outerLoop); } // There is a hard dependency between load pipelining and the TC05MMA @@ -122,11 +104,6 @@ struct PipelinePass : public impl::TritonGPUPipelineBase { if (applyPatternsGreedily(getOperation(), std::move(patterns)).failed()) return signalPassFailure(); - // Try to pipeline the outer loop to overlap the prologue and epilogue of - // the inner loop. - for (scf::ForOp outerLoop : outerLoops) - tryAndPipelineOuterLoop(outerLoop); - // Re-collect loop ops loops.clear(); getOperation()->walk([&](scf::ForOp forOp) { diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index fd10ab5aea..c9d926660e 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -111,6 +111,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = ttg.local_alloc // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] @@ -121,7 +122,6 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: ttg.async_copy_global_to_local // CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] // CHECK: ttg.async_copy_global_to_local -// CHECK: scf.for // CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], // CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], // CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} @@ -144,15 +144,6 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] // CHECK: ttg.async_wait {num = 0 : i32} -// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: ttg.async_copy_global_to_local -// CHECK scf.yield module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.func @matmul_loop_nested(%lb : 
index, %ub : index, %step : index, %A : !tt.ptr {tt.divisibility = 16 : i32}, diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 49f18939fb..73ab2f158f 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -178,6 +178,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32 +// CHECK: scf.for // CHECK: %[[ABUFFER:.*]] = ttg.local_alloc // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] @@ -188,7 +189,6 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK: ttg.async_copy_global_to_local // CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] // CHECK: ttg.async_copy_global_to_local -// CHECK: scf.for // CHECK-DAG: %[[A0:.*]] = ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], // CHECK-DAG: %[[B0:.*]] = ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], // CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} @@ -211,14 +211,6 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // CHECK-DAG: ttg.async_wait {{.*}} {num = 2 : i32} // CHECK: scf.yield {{.*}}, %[[INS_IDX_3]], %[[EXT_IDX_3]], %[[NEXT_A]], %[[NEXT_B]] // CHECK: ttg.async_wait {num = 0 : i32} -// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[ABUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: ttg.async_copy_global_to_local -// CHECK: ttg.memdesc_subview %[[BBUFFER]][%[[CONSTANT_1]], %[[CONSTANT_0]], %[[CONSTANT_0]]] -// CHECK: ttg.async_copy_global_to_local // CHECK scf.yield // AMD-LABEL: tt.func @matmul_loop_nested @@ -945,17 +937,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- // CHECK-LABEL: nested_loops -// CHECK: ttg.local_alloc // CHECK: scf.for -// CHECK-NOT: ttg.local_alloc -// CHECK: scf.for -// CHECK: scf.yield -// CHECK: ttg.async_wait {num = 0 : i32} +// CHECK: ttg.local_alloc // CHECK: ttg.async_copy_global_to_local // CHECK: ttg.async_commit_group // CHECK: ttg.async_copy_global_to_local // CHECK: ttg.async_commit_group -// CHECK: scf.yield +// CHECK: scf.for +// CHECK: scf.yield +// CHECK: ttg.async_wait {num = 0 : i32} // AMD-LABEL: tt.func public @nested_loops // AMD: scf.for @@ -1283,18 +1273,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @nested_loops // CHECK: tt.addptr %{{.*}}, {{.*}} // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} -// CHECK: %[[BUFFER_1:.*]] = ttg.local_alloc -// CHECK: %[[SUBVIEW_1:.*]] = ttg.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_1:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] -// CHECK: ttg.async_commit_group %[[ASYNC_COPY_1]] -// CHECK: %[[SUBVIEW_2:.*]] = ttg.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_2:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] -// CHECK: ttg.async_commit_group %[[ASYNC_COPY_2]] // CHECK: scf.for // CHECK: %[[LOAD_1:.*]] = tt.load %[[NEXT_BUFFER_1]] // CHECK: %[[BUFFER_2:.*]] = ttg.local_alloc %[[LOAD_1]] // 
CHECK: %[[TRANS:.*]] = ttg.memdesc_trans %[[BUFFER_2]] // CHECK: %[[LOCAL_LOAD_1:.*]] = ttg.local_load %[[TRANS]] +// CHECK: %[[BUFFER_1:.*]] = ttg.local_alloc : () +// CHECK: %[[SUBVIEW_1:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_1:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_1]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_1]] +// CHECK: %[[SUBVIEW_2:.*]] = ttg.memdesc_subview %[[BUFFER_1]] +// CHECK: %[[ASYNC_COPY_2:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_2]] +// CHECK: ttg.async_commit_group %[[ASYNC_COPY_2]] // CHECK: ttg.async_wait // CHECK: ttg.memdesc_subview %[[BUFFER_1]] // CHECK: scf.for @@ -1305,13 +1295,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: %[[ASYNC_COPY_3:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_4]] // CHECK: ttg.async_commit_group %[[ASYNC_COPY_3]] // CHECK: ttg.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[SUBVIEW_6:.*]] = ttg.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_4:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_6]] mask -// CHECK: %[[COMMIT_1:.*]] = ttg.async_commit_group %[[ASYNC_COPY_4]] -// CHECK: %[[SUBVIEW_7:.*]] = ttg.memdesc_subview %[[BUFFER_1]] -// CHECK: %[[ASYNC_COPY_5:.*]] = ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]], %[[SUBVIEW_7]] mask -// CHECK: %[[COMMIT_2:.*]] = ttg.async_commit_group %[[ASYNC_COPY_5]] -// CHECK: scf.yield %[[COMMIT_1]], %[[COMMIT_2]] // CHECK: ttg.local_dealloc %[[BUFFER_1]] // AMD-LABEL: tt.func public @nested_loops From b3524fab6aa54ec581211bfd8e5ff8983233e602 Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:12:43 +0000 Subject: [PATCH 5/7] [LAYOUTS] Fix TransOp::fold (#5807) The folder for TransOp was a bit too aggresive. Sometimes it would change the representation of a layout for an equivalent one, and that's not allowed in the current state of things. We move the optimisation we had to a different canonicalizer. 
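As a quick orientation before the C++ change below, here is a minimal Python model of the resulting rule (a sketch only; the function and type names are illustrative, not the MLIR API):

```python
# Illustrative model of the split between TransOp::fold and the
# CanonicalizeConvertFromTranspose pattern; not the MLIR C++ API.
def simplify_identity_transpose(order, src_type, result_type):
    """Decide what to do with a transpose whose permutation is the identity."""
    if list(order) != list(range(len(order))):
        return "keep"  # a real permutation: nothing to simplify here
    if src_type == result_type:
        return "fold"  # the folder may return the source value directly
    # Structurally equivalent but differently represented layouts: fold() must
    # not change the result type, so the canonicalizer rewrites the op into a
    # (trivial) convert_layout instead.
    return "convert_layout"


assert simplify_identity_transpose((0, 1), "blocked", "blocked") == "fold"
assert simplify_identity_transpose((0, 1), "linear_a", "linear_b") == "convert_layout"
assert simplify_identity_transpose((1, 0), "linear_a", "linear_b") == "keep"
```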
--- lib/Dialect/Triton/IR/Ops.cpp | 8 +++++++- lib/Dialect/TritonGPU/IR/Ops.cpp | 8 ++++++++ test/TritonGPU/canonicalize.mlir | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 2232aff13e..c0484e7b53 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -191,7 +191,13 @@ void StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, OpFoldResult TransOp::fold(FoldAdaptor adaptor) { // transpose(x, order=[0, 1, ...]) -> x if (isIota(getOrder())) { - return getSrc(); + // If the source and result types are the same, we can return the source + // If their layout is different (even if structurally equivalent), we need + // to insert a convert_layout in between as otherwise ::fold complains + // We do this in CanonicalizeConvertFromTranspose + if (getSrc().getType() == getType()) { + return getSrc(); + } } // transpose(transpose(x)) -> transpose(x) diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 64a8f8cc59..c3d8ff4940 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -85,6 +85,14 @@ struct CanonicalizeConvertFromTranspose mlir::LogicalResult matchAndRewrite(triton::TransOp op, PatternRewriter &rewriter) const override { + // transpose(x, order=[0, 1, ...]) -> x + // We turn it into a (trivial) convert_layout that may be folded away + if (isIota(op.getOrder())) { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getSrc()); + return success(); + } + // If the layouts are structurally the same, the convert is trivial auto convert = op.getSrc().getDefiningOp(); if (!convert || !isConvertTrivial(convert)) diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index 3c2d60a243..7af051dca5 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -199,3 +199,19 @@ tt.func @infer_trans(%arg0: tensor<32x32xf32, #linear>) -> tensor<32x32xf32, #bl } } + +// ----- + +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16, 16]}> +#dot_t = #ttg.linear<{register = [[1, 0], [0, 8], [8, 0], [16, 0], [32, 0], [64, 0], [128, 0], [0, 64], [0, 128]], lane = [[2, 0], [4, 0], [0, 1], [0, 2], [0, 4]], warp = [[0, 16], [0, 32]], block = []}> +#dot_linear = #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128], [64, 0], [128, 0]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = []}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @simplify_trans_trans + tt.func public @simplify_trans_trans(%arg0: tensor<256x256xf32, #dot_linear>) -> tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> { + // CHECK-NEXT: ttg.convert_layout + %a = tt.trans %arg0 {order=array} : tensor<256x256xf32, #dot_linear> -> tensor<256x256xf32, #dot_t> + %b = tt.trans %a {order=array} : tensor<256x256xf32, #dot_t> -> tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + tt.return %b : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + } +} From 716a5218908ec40a4b09a17ebce7b02d05cd64be Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Tue, 4 Feb 2025 12:19:51 -0500 Subject: [PATCH 6/7] [TEST] Use a fresh triton cache dir for warning tests (#5809) If we expect a warning, we need to use a fresh cache dir; otherwise, no warning will 
be thrown when the cache is hit. Also take out interpreter-related code from this file.

---
 python/test/unit/language/test_compile_errors.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py
index 7dcc8457a2..3e168e2bb5 100644
--- a/python/test/unit/language/test_compile_errors.py
+++ b/python/test/unit/language/test_compile_errors.py
@@ -7,7 +7,7 @@ import triton.language as tl
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 import traceback
 
-from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi300
 
 
 def format_exception(type, value, tb):
@@ -362,7 +362,7 @@ def kernel():
 
 
 @pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
-def test_fp8_support(dtype):
+def test_fp8_support(fresh_triton_cache, dtype):
     warning_dtypes = []
     supported_dtypes = [tl.float8e5]
     if is_cuda():
@@ -375,8 +375,6 @@ def test_fp8_support(dtype):
     elif is_hip():
         if is_hip_mi300():
             supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16]
-    elif is_interpreter():
-        supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15]
 
     @triton.jit
     def dtype_kernel(dtype: tl.constexpr):

From bf333c2ca0ac776446f2fec357b114b53ad2f32c Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Thu, 6 Feb 2025 14:44:13 +0000
Subject: [PATCH 7/7] [TEST] Do not run test_local_load_store_dot

All `dot_layouts` are filtered, which impacts the number of skipped tests and affects our pass rate.

Signed-off-by: Whitney Tsang
---
 scripts/test-triton.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh
index 03ab688098..bdf405f8b1 100755
--- a/scripts/test-triton.sh
+++ b/scripts/test-triton.sh
@@ -189,7 +189,7 @@ run_core_tests() {
   ensure_spirv_dis
 
   TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=language \
-    pytest -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
+    pytest -k "not test_local_load_store_dot" -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu language/ --ignore=language/test_line_info.py --ignore=language/test_subprocess.py
 
   TRITON_DISABLE_LINE_INFO=1 TRITON_TEST_SUITE=subprocess \
     pytest -vvv -n ${PYTEST_MAX_PROCESSES:-8} --device xpu language/test_subprocess.py
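As a closing illustration of the cache-hit pitfall patch 6 describes, here is a self-contained toy model (not Triton's actual compilation cache or test code) of why a warning-asserting test must start from a fresh cache directory: the warning is emitted during compilation, and a cache hit skips compilation entirely.

```python
# Toy model only: mimics "warn once at compile time, then serve from cache".
import warnings

_cache = {}


def compile_with_cache(key):
    if key in _cache:  # cache hit: no compilation, hence no warning
        return _cache[key]
    warnings.warn(f"{key} is not natively supported; emulating")
    _cache[key] = object()
    return _cache[key]


def test_warns_only_on_cold_cache():
    with warnings.catch_warnings(record=True) as first:
        warnings.simplefilter("always")
        compile_with_cache("float8e4b15")
    with warnings.catch_warnings(record=True) as second:
        warnings.simplefilter("always")
        compile_with_cache("float8e4b15")
    # The second run hits the cache and emits nothing, which is why warning
    # tests need a fresh cache dir (or the fresh_triton_cache fixture).
    assert len(first) == 1 and len(second) == 0


if __name__ == "__main__":
    test_warns_only_on_cold_cache()
```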