From e7710d5a4606fd5d8dcb89153f2d724dc3809209 Mon Sep 17 00:00:00 2001
From: erwei-xilinx
Date: Fri, 25 Oct 2024 07:57:11 +0800
Subject: [PATCH] AIRSegmentLoopFusion: Fixups on `affine::DelinearizeIndexOp`
 and rank reduction (#752)

* Support affine::DelinearizeIndexOp in scf::ForOp's iv chain
* Post-shrinkage subview mutation now supports rank reduction
* Unit test
* Remove #include
---
 .../Transform/AIRDependencyScheduleOpt.cpp    |  7 ++-
 mlir/lib/Util/Util.cpp                        | 40 +++++++++----
 .../segment_loop_fusion.mlir                  | 60 +++++++++++++++++++
 3 files changed, 92 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
index e37401e65..cc2d866e1 100644
--- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
+++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
@@ -4559,9 +4559,10 @@ struct ShrinkMemrefSizesByAccessPattern
       auto shrunkMemrefType = MemRefType::get(overall_access_bounds, elemType,
                                               nullptr, memorySpace);
       MemRefType inferredSubViewOutputTy =
-          llvm::cast<MemRefType>(memref::SubViewOp::inferResultType(
-              shrunkMemrefType, subViewOp.getStaticOffsets(),
-              subViewOp.getStaticSizes(), subViewOp.getStaticStrides()));
+          llvm::cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
+              subViewOp.getType().getShape(), shrunkMemrefType,
+              subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
+              subViewOp.getStaticStrides()));
       // Case 1: static size mismatches the shrunk shape.
       for (unsigned i = 0; i < static_sizes.size(); i++) {
         if (static_sizes[i] < 0) {
diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp
index 058c205df..ce7add3cf 100644
--- a/mlir/lib/Util/Util.cpp
+++ b/mlir/lib/Util/Util.cpp
@@ -1242,28 +1242,44 @@ static void updateAccessPatternByScfForNest(
         &pattern,
     SmallVector<Value> indices, OpBuilder builder) {
   auto loc = builder.getUnknownLoc();
-  auto updateWrapAndStride = [&](Value index, int i) {
-    if (auto scfForOp = scf::getForInductionVarOwner(index)) {
-      std::get<1>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
-          loc, *air::getStaticScfForTripCountAsInt(scfForOp));
-      std::get<2>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
-          loc, (*getConstantIntValue(scfForOp.getStep())) *
-                   (*getConstantIntValue(std::get<2>(pattern)[i])));
-
-      scfForOp.getStep();
+  auto updateWrapAndStride = [&](int stepSize, int tripCount, int i) {
+    std::get<1>(pattern)[i] =
+        builder.create<arith::ConstantIndexOp>(loc, tripCount);
+    std::get<2>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
+        loc, stepSize * (*getConstantIntValue(std::get<2>(pattern)[i])));
+  };
+  // Infer data access pattern's sizes from parent scf.for loop and any affine
+  // op applied on the induction variable
+  auto inferDataAccessSizes = [](scf::ForOp scfForOp, air::ExecuteOp execOp,
+                                 Value index) {
+    int scfForTripCount = *air::getStaticScfForTripCountAsInt(scfForOp);
+    // If scf.for's iv applies affine::DelinearizeIndexOp
+    if (auto delinearizeOp =
+            dyn_cast<affine::AffineDelinearizeIndexOp>(execOp.getChildOp())) {
+      int resIdx =
+          llvm::find(execOp.getResults(), index) - execOp.getResults().begin();
+      scfForTripCount = *getConstantIntValue(delinearizeOp.getBasis()[resIdx]);
     }
+    return scfForTripCount;
   };
   int dim = -1;
   for (auto index : indices) {
     dim++;
     if (getConstantIntValue(index))
       continue;
-    updateWrapAndStride(index, dim);
+    if (auto scfForOp = scf::getForInductionVarOwner(index))
+      updateWrapAndStride(*getConstantIntValue(scfForOp.getStep()),
+                          *air::getStaticScfForTripCountAsInt(scfForOp), dim);
     if (!index.getDefiningOp())
       continue;
-    if (auto execOp = dyn_cast<air::ExecuteOp>(index.getDefiningOp()))
+    if (auto execOp = dyn_cast<air::ExecuteOp>(index.getDefiningOp())) {
       for (auto oper : execOp.getChildOp()->getOperands())
-        updateWrapAndStride(oper, dim);
+        if (auto scfForOp = scf::getForInductionVarOwner(oper)) {
+          int scfForTripCount = inferDataAccessSizes(scfForOp, execOp, index);
+          updateWrapAndStride(*getConstantIntValue(scfForOp.getStep()),
+                              scfForTripCount, dim);
+        }
+    }
   }
 }
 
diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir
index 28abbfa7f..6f5989eae 100644
--- a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir
+++ b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir
@@ -934,3 +934,63 @@ func.func @func10(%arg0: memref<8x512xi32>, %arg1: memref<256x512xi32>, %arg2: m
   }
   return
 }
+
+// Affine::DelinearizeIndexOp support; rank-reduced memref::SubViewOp.
+
+// CHECK-LABEL: func.func @func11
+// CHECK: air.herd
+// CHECK: %[[SUBVIEW0:.*]] = memref.subview{{.*}} : memref<16x16x4x4xf32, 1 : i32> to memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview{{.*}} : memref<1x16x4xf32, 2 : i32> to memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview{{.*}} : memref<1x1x16x16x4x4xbf16, 2 : i32> to memref<1x1x4x4xbf16, strided<[4096, 4096, 4, 1], offset: ?>, 2 : i32>
+// CHECK: linalg.generic{{.*}} ins(%[[SUBVIEW0]], %[[SUBVIEW1]] {{.*}}outs(%[[SUBVIEW2]]
+
+#map17 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+#map18 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+func.func @func11(%arg0: memref<512x512xbf16>, %arg1: memref<512x16384xbf16>, %arg2: memref<512xf32>, %arg3: memref<512x16384xbf16>) {
+  %c4 = arith.constant 4 : index
+  %c128 = arith.constant 128 : index
+  %0 = air.launch async (%arg4, %arg5) in (%arg6=%c4, %arg7=%c128) attributes {id = 1 : i32} {
+    %1 = air.segment @matmul_elementwise_bf16_dispatch_0_matmul_512x16384x512_bf16xbf16xf32_0 async attributes {id = 2 : i32} {
+      %c2 = arith.constant 2 : index
+      %async_token, %results = air.execute -> (memref<2x2x16x16x4x4xbf16, 2 : i32>) {
+        %alloc = memref.alloc() : memref<2x2x16x16x4x4xbf16, 2 : i32>
+        air.execute_terminator %alloc : memref<2x2x16x16x4x4xbf16, 2 : i32>
+      }
+      %async_token_0, %results_1 = air.execute -> (memref<1x16x4xf32, 2 : i32>) {
+        %alloc = memref.alloc() : memref<1x16x4xf32, 2 : i32>
+        air.execute_terminator %alloc : memref<1x16x4xf32, 2 : i32>
+      }
+      %async_token_2, %results_3 = air.execute -> (memref<16x16x4x4xf32, 1 : i32>) {
+        %alloc = memref.alloc() : memref<16x16x4x4xf32, 1 : i32>
+        air.execute_terminator %alloc : memref<16x16x4x4xf32, 1 : i32>
+      }
+      %2 = air.herd @herd_0 async tile (%arg8, %arg9) in (%arg10=%c2, %arg11=%c2) args(%arg12=%results_3, %arg13=%results_1, %arg14=%results) : memref<16x16x4x4xf32, 1 : i32>, memref<1x16x4xf32, 2 : i32>, memref<2x2x16x16x4x4xbf16, 2 : i32> {
+        %c16 = arith.constant 16 : index
+        %c0 = arith.constant 0 : index
+        %c1 = arith.constant 1 : index
+        %c256 = arith.constant 256 : index
+        %3 = air.wait_all async
+        %4 = scf.for %arg15 = %c0 to %c256 step %c1 iter_args(%arg16 = %3) -> (!air.async.token) {
+          %async_token_4, %results_5:2 = air.execute [%arg16] -> (index, index) {
+            %6:2 = affine.delinearize_index %arg15 into (%c16, %c16) : index, index
+            air.execute_terminator %6#0, %6#1 : index, index
+          }
+          %subview = memref.subview %arg12[%results_5#0, %results_5#1, 0, 0] [1, 1, 4, 4] [1, 1, 1, 1] : memref<16x16x4x4xf32, 1 : i32> to memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>
+          %subview_6 = memref.subview %arg13[0, %results_5#1, 0] [1, 1, 4] [1, 1, 1] : memref<1x16x4xf32, 2 : i32> to memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>
+          %subview_7 = memref.subview %arg14[%arg8, %arg9, %results_5#0, %results_5#1, 0, 0] [1, 1, 1, 1, 4, 4] [1, 1, 1, 1, 1, 1] : memref<2x2x16x16x4x4xbf16, 2 : i32> to memref<1x1x4x4xbf16, strided<[256, 16, 4, 1], offset: ?>, 2 : i32>
+          %async_token_8 = air.execute [%arg16] {
+            linalg.generic {indexing_maps = [#map17, #map18, #map17], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview_6 : memref<1x1x4x4xf32, strided<[256, 16, 4, 1], offset: ?>, 1 : i32>, memref<1x4xf32, strided<[4, 1], offset: ?>, 2 : i32>) outs(%subview_7 : memref<1x1x4x4xbf16, strided<[256, 16, 4, 1], offset: ?>, 2 : i32>) {
+            ^bb0(%in: f32, %in_9: f32, %out: bf16):
+              %6 = arith.addf %in, %in_9 : f32
+              %7 = arith.truncf %6 : f32 to bf16
+              linalg.yield %7 : bf16
+            }
+          }
+          %5 = air.wait_all async [%async_token_4, %async_token_8]
+          scf.yield %5 : !air.async.token
+        }
+      }
+    }
+  }
+  return
+}
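
Note on the access-pattern arithmetic in the Util.cpp hunk (editor's illustration, not part of the patch): when an scf.for induction variable is delinearized before it indexes a memref, the wrap of each resulting access-pattern dimension comes from the matching delinearize basis element rather than the loop's full trip count, while the stride is the loop step multiplied by the dimension's existing stride. The standalone C++ sketch below mirrors that logic with hypothetical names and example values (WrapStride, inferDataAccessSize, the placeholder strides); it does not use the MLIR API and is not the pass's own code.

// Editor's sketch: plain C++ mirroring the wrap/stride computation done by
// updateWrapAndStride and inferDataAccessSizes in mlir/lib/Util/Util.cpp.
#include <cstdint>
#include <iostream>
#include <vector>

struct WrapStride {
  int64_t wrap;   // loop-derived repeat count for this dimension
  int64_t stride; // loop step scaled by the dimension's existing stride
};

// Wrap is the (possibly delinearized) trip count; stride is step * existing.
WrapStride updateWrapAndStride(int64_t stepSize, int64_t tripCount,
                               int64_t existingStride) {
  return {tripCount, stepSize * existingStride};
}

// If the induction variable feeds an affine.delinearize_index, the wrap for
// result `resIdx` is the matching basis element, not the full trip count.
int64_t inferDataAccessSize(int64_t scfForTripCount,
                            const std::vector<int64_t> &basis, int resIdx) {
  return basis.empty() ? scfForTripCount : basis[resIdx];
}

int main() {
  // Example shaped like @func11: scf.for %i = 0 to 256 step 1, with
  // affine.delinearize_index %i into (16, 16).
  const int64_t tripCount = 256, step = 1;
  const std::vector<int64_t> basis = {16, 16};
  // Pre-existing strides are illustrative placeholders.
  const std::vector<int64_t> existingStride = {16, 1};
  for (int resIdx = 0; resIdx < 2; ++resIdx) {
    WrapStride ws = updateWrapAndStride(
        step, inferDataAccessSize(tripCount, basis, resIdx),
        existingStride[resIdx]);
    std::cout << "result #" << resIdx << ": wrap=" << ws.wrap
              << ", stride=" << ws.stride << "\n";
  }
  // Prints: result #0: wrap=16, stride=16
  //         result #1: wrap=16, stride=1
}

Without the basis-aware wrap, each delinearized index would be treated as spanning the full 256-iteration trip count, overestimating the region the loop actually touches.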