From 83a3eb14b3d5ed1f340b84e23936bac4fa5cf391 Mon Sep 17 00:00:00 2001
From: Pawel Szczerbuk
Date: Wed, 26 Feb 2025 17:55:32 -0800
Subject: [PATCH] Marking loop as scheduled with max_stages

---
 .../TritonGPU/Transforms/PipeliningUtility.h  |  1 +
 .../Transforms/Pipeliner/Schedule.cpp         | 14 ++++---
 .../Pipeliner/SoftwarePipeliner.cpp           |  1 +
 test/TritonGPU/pipeline-lower-loop.mlir       | 42 +++++++++----------
 test/TritonGPU/pipeline-schedule-loop.mlir    |  1 +
 5 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
index bb6e4ffdae72..a005e790a6c7 100644
--- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
+++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
@@ -15,6 +15,7 @@ static const char *kDisallowAccMultiBufferAttrName =
     "tt.disallow_acc_multi_buffer";
 static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
+static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
 static const char *kLatencyAttrName = "tt.latency";
 
 bool loopHasDistGreaterThanOne(scf::ForOp forOp);
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
index 748b1b38e9e4..301d6382fc2c 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp
@@ -172,12 +172,11 @@ static std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp) {
 
 static std::optional<int> tryGetMaxStage(scf::ForOp &forOp) {
   std::optional<int> maxStage = std::nullopt;
-  for (auto &op : forOp.getBody()->without_terminator()) {
-    if (!op.hasAttr(mlir::triton::kLoopStageAttrName) ||
-        !op.hasAttr(mlir::triton::kLoopClusterAttrName))
-      continue;
-    auto [stage, _] = getStageCluster(&op);
-    maxStage = maxStage ? (stage > *maxStage ? stage : *maxStage) : stage;
+  if (forOp->hasAttr(mlir::triton::kScheduledMaxStageAttrName)) {
+    return forOp
+        ->getAttrOfType<IntegerAttr>(mlir::triton::kScheduledMaxStageAttrName)
+        .getValue()
+        .getSExtValue();
   }
   return maxStage;
 }
@@ -187,6 +186,9 @@ void tt::CoarseSchedule::serialize(scf::ForOp &forOp) {
   for (auto [op, stage, cluster] : getOpsInOrder(forOp)) {
     setStageCluster(op, stage, *cluster);
   }
+  forOp->setAttr(mlir::triton::kScheduledMaxStageAttrName,
+                 IntegerAttr::get(IntegerType::get(forOp.getContext(), 32),
+                                  numStages - 1));
 }
 
 // Create a CoarseSchedule based on forOp's .
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp
index cd05d3513a00..dbf8e01fee46 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp
@@ -72,6 +72,7 @@ static void removeAttributes(ModuleOp moduleOp) {
   moduleOp->walk([&](Operation *op) {
     op->removeAttr(mlir::triton::kLoopStageAttrName);
     op->removeAttr(mlir::triton::kLoopClusterAttrName);
+    op->removeAttr(mlir::triton::kScheduledMaxStageAttrName);
   });
 }
 
diff --git a/test/TritonGPU/pipeline-lower-loop.mlir b/test/TritonGPU/pipeline-lower-loop.mlir
index 68196862af19..a0f7179256e9 100644
--- a/test/TritonGPU/pipeline-lower-loop.mlir
+++ b/test/TritonGPU/pipeline-lower-loop.mlir
@@ -51,7 +51,7 @@ tt.func @one_dep_async(%lb : index, %ub : index, %step : index,
   scf.for %iv = %lb to %ub step %step : index {
     %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A>
     "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -75,7 +75,7 @@ tt.func @different_use_stages(%lb : index, %ub : index, %step : index,
     %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A>
     "use1"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
     "use2"(%a) {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 3 : i32}
   tt.return
 }
 }
@@ -106,7 +106,7 @@ tt.func @used_by_if_yield(%lb : index, %ub : index, %step : index,
       scf.yield %init_a : tensor<128x32xf16, #A>
     } {loop.cluster = 0 : i32, loop.stage = 2 : i32}
     "use"(%a_if) {loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 3 : i32}
   tt.return
 }
 }
@@ -124,7 +124,7 @@ tt.func @dist1_load(%lb : index, %ub : index, %step : index,
    %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A>
    "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
    scf.yield %a : tensor<128x32xf16, #A>
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -142,7 +142,7 @@ tt.func @one_dep_sync(%lb : index, %ub : index, %step : index,
   scf.for %iv = %lb to %ub step %step : index {
     %a = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<1x!tt.ptr, #A>
     "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<1xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -183,7 +183,7 @@ tt.func @one_dep_local_alloc(%lb : index, %ub : index, %step : index,
    %a_alloc = ttg.local_alloc %a {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable>
    %a_load = ttg.local_load %a_alloc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory, mutable> -> tensor<128x32xf16, #A>
    "use"(%a_load) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> ()
-  }
+  } {tt.scheduled_max_stage = 2 : i32}
   tt.return
 }
 }
@@ -214,7 +214,7 @@ tt.func @one_load_group(%lb : index, %ub : index, %step : index,
    %b = tt.load %a_ptr_init {loop.cluster = 2 : i32, loop.stage = 0 : i32} : tensor<128x32x!tt.ptr, #A>
    "use1"(%a){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> ()
"use2"(%b){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> () - } + } {tt.scheduled_max_stage = 2 : i32} tt.return } } @@ -255,7 +255,7 @@ tt.func @two_load_groups(%lb : index, %ub : index, %step : index, "use1"(%a){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> () "use2"(%b){loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> () "use3"(%c){loop.cluster = 0 : i32, loop.stage = 3 : i32} : (tensor<128x32xf32, #A>) -> () - } + } {tt.scheduled_max_stage = 3 : i32} tt.return } } @@ -304,7 +304,7 @@ tt.func @dependent_loads(%lb : index, %ub : index, %step : index, %b = "pointerize"(%a) {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> tensor<128x32x!tt.ptr, #A> %c = tt.load %b {loop.cluster = 2 : i32, loop.stage = 2 : i32} : tensor<128x32x!tt.ptr, #A> "use1"(%c){loop.cluster = 0 : i32, loop.stage = 4 : i32} : (tensor<128x32xf32, #A>) -> () - } + } {tt.scheduled_max_stage = 4 : i32} tt.return } } @@ -361,7 +361,7 @@ tt.func @dependent_loads_asymmetric(%lb : index, %ub : index, %step : index, %b = "pointerize"(%a) {loop.cluster = 2 : i32, loop.stage = 2 : i32} : (tensor<128x32xf32, #A>) -> tensor<128x32x!tt.ptr, #A> %c = tt.load %b {loop.cluster = 2 : i32, loop.stage = 2 : i32} : tensor<128x32x!tt.ptr, #A> "use1"(%c){loop.cluster = 0 : i32, loop.stage = 5 : i32} : (tensor<128x32xf32, #A>) -> () - } + } {tt.scheduled_max_stage = 5 : i32} tt.return } } @@ -379,7 +379,7 @@ tt.func @unused_load(%lb : index, %ub : index, %step : index, // CHECK: dummy %a = tt.load %a_ptr_init {loop.cluster = 0 : i32, loop.stage = 1 : i32} : tensor<128x32x!tt.ptr, #A> "dummy"() : () -> () - } + } {tt.scheduled_max_stage = 1 : i32} tt.return } } @@ -434,7 +434,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma> scf.yield %acc_res : tensor<128x128xf32, #mma> - } + } {tt.scheduled_max_stage = 2 : i32} %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> tt.return %res_f16 : tensor<128x128xf16, #mma> } @@ -489,7 +489,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory> %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared1, #ttg.shared_memory> -> tensor<128x128xf32, #mma> scf.yield %acc_res : tensor<128x128xf32, #mma> - } + } {tt.scheduled_max_stage = 2 : i32} tt.return %res : tensor<128x128xf32, #mma> } } @@ -555,7 +555,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * 
!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma> scf.yield %acc_res : tensor<128x128xf32, #mma> - } + } {tt.scheduled_max_stage = 2 : i32} %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> tt.return %res_f16 : tensor<128x128xf16, #mma> } @@ -614,7 +614,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm, %true, %true {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, i1, i1) -> () %acc_res = ttng.tmem_load %acc_tm {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked> scf.yield %acc_res : tensor<128x128xf32, #blocked> - } + } {tt.scheduled_max_stage = 2 : i32} %res_f16 = arith.truncf %res : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> tt.return %res_f16 : tensor<128x128xf16, #blocked> } @@ -669,7 +669,7 @@ tt.func @tma_load_lowering(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step : index { %a = tt.experimental_descriptor_load %desc[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () - } + } {tt.scheduled_max_stage = 2 : i32} tt.return } } @@ -725,7 +725,7 @@ tt.func @tma_gather_lowering(%lb : index, %ub : index, %step : index, scf.for %iv = %lb to %ub step %step : index { %a = tt.experimental_descriptor_gather %desc[%x, %y] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : (!tt.tensordesc>, tensor<32xi32, #offsets>, i32) -> tensor<32x128xf32, #A> "use"(%a) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x128xf32, #A>) -> () - } + } {tt.scheduled_max_stage = 2 : i32} tt.return } } @@ -760,7 +760,7 @@ tt.func @tma_reuse_barrier(%lb : index, %ub : index, %step : index, "use2"(%b) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () %c = tt.experimental_descriptor_load %descC[%offs, %offs] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x32xf16, #A> "use3"(%c) {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x32xf16, #A>) -> () - } + } {tt.scheduled_max_stage = 2 : i32} tt.return } } @@ -798,7 +798,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %B_sh = ttg.local_alloc %B {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> %acc_res = ttng.warp_group_dot %A_sh, %B_sh, %acc {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> * !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory> -> tensor<128x128xf32, #mma> scf.yield %acc_res : tensor<128x128xf32, #mma> - } + } {tt.scheduled_max_stage = 2 : i32} %res_f16 = arith.truncf %res : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> tt.return %res_f16 : tensor<128x128xf16, #mma> } @@ -833,7 +833,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ scf.for %iv = %lb to %ub step %step : index { %desc = tt.make_tensor_descriptor %A, [%shape_x, %shape_y], [%strides_x, %strides_y] {loop.cluster = 0 : i32, loop.stage = 1 : i32} : , > "use"(%desc) {loop.cluster = 0 : i32, loop.stage = 1 : i32} : 
(!tt.tensordesc>) -> () - } + } {tt.scheduled_max_stage = 1 : i32} tt.return } } @@ -879,7 +879,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_sc_sh, %B_sc_sh, %true, %true lhs = e5m2 rhs = e5m2 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, i1, i1) -> () %acc_res = ttng.tmem_load %acc_tm {loop.cluster = 0 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked> scf.yield %acc_res : tensor<128x128xf32, #blocked> - } + } {tt.scheduled_max_stage = 2 : i32} tt.return %res : tensor<128x128xf32, #blocked> } } diff --git a/test/TritonGPU/pipeline-schedule-loop.mlir b/test/TritonGPU/pipeline-schedule-loop.mlir index 2d04b7462258..ad46c3871d0d 100644 --- a/test/TritonGPU/pipeline-schedule-loop.mlir +++ b/test/TritonGPU/pipeline-schedule-loop.mlir @@ -21,6 +21,7 @@ tt.func @one_dep(%lb : index, %ub : index, %step : index, %res = arith.addf %acc, %a : tensor<128x32xf16, #A> scf.yield %res : tensor<128x32xf16, #A> } + // CHECK: tt.scheduled_max_stage tt.return %loop#0 : tensor<128x32xf16, #A> }