From 27b7e1aa95b0ea1baa43dda0c6e5ecfc38bff106 Mon Sep 17 00:00:00 2001 From: somehow6 Date: Thu, 26 Sep 2024 20:43:33 +0800 Subject: [PATCH 01/17] [Midend] Enhancements and Optimizations for batch matmul and convolution [Examples] Added MLIRLinalg Examples for Various Optimization Options. fixed thirdparty. --- .../MLIRLinalg/linalg-batch-matmul-dync.mlir | 67 ++++ .../MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir | 96 +++++ .../linalg-depthwise_conv_2d_nhwc_hwc.mlir | 71 ++++ examples/MLIRLinalg/makefile | 49 +++ midend/lib/Conversion/CMakeLists.txt | 1 + .../ConvOptimization/CMakeLists.txt | 2 + .../ConvOptimization/ConvNhwcFhwcOptimize.cpp | 276 ++++++++++++++ .../ConvNhwcFhwcOptimizeTile.cpp | 342 +++++++++++++++++ .../GEMMPointwiseConv2DNhwcHwcf.cpp | 15 +- .../DepthwiseConvOptimization/CMakeLists.txt | 3 + .../DepthwiseConvNhwcHwc.cpp | 331 ++++++++++++++++ .../BatchMatMulSCFOptimize.cpp | 281 ++++++++++++++ .../BatchMatMulTileOptimize.cpp | 353 ++++++++++++++++++ .../MatMulOptimization/CMakeLists.txt | 12 +- tools/buddy-opt/CMakeLists.txt | 3 + tools/buddy-opt/buddy-opt.cpp | 22 +- 16 files changed, 1913 insertions(+), 11 deletions(-) create mode 100644 examples/MLIRLinalg/linalg-batch-matmul-dync.mlir create mode 100644 examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir create mode 100644 examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir create mode 100644 midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp create mode 100644 midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp create mode 100644 midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt create mode 100644 midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp create mode 100644 midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp create mode 100644 midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp diff --git a/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir new file mode 100644 index 0000000000..1b910e4a3e --- /dev/null +++ b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir @@ -0,0 +1,67 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + + // Definition for the batch matrix multiplication function + func.func @buddy_batchmatmul_f32(%A: memref, %B: memref, %C: memref) { + linalg.batch_matmul + ins(%A, %B: memref, memref) + outs(%C: memref) + return + } + + func.func @main(){ + // Set up dims. + %cBatch = arith.constant 10:index + %cM = arith.constant 2 : index + %cN = arith.constant 5 : index + %cK = arith.constant 4 : index + + // Set Init Value. 
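+    // A is filled with 1.0 and B with 2.0, so after the batch matmul every
+    // element of C should be K * 1.0 * 2.0 = 4 * 2.0 = 8.0.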
+ %cf1 = arith.constant 1.0 : f32 + %cf2 = arith.constant 2.0 : f32 + %c0 = arith.constant 0.0 : f32 + + %A = memref.alloc(%cBatch,%cM, %cK) : memref + %B = memref.alloc(%cBatch,%cK, %cN) : memref + %C = memref.alloc(%cBatch,%cM, %cN) : memref + + linalg.fill + ins(%cf1 : f32) + outs(%A:memref) + + linalg.fill + ins(%cf2 : f32) + outs(%B:memref) + + linalg.fill + ins(%c0 : f32) + outs(%C:memref) + + call @buddy_batchmatmul_f32(%A, %B, %C) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C = memref.cast %C : memref to memref<*xf32> + call @printMemrefF32(%print_C) : (memref<*xf32>) -> () + + memref.dealloc %C : memref + memref.dealloc %B : memref + memref.dealloc %A : memref + return + } +} diff --git a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir new file mode 100644 index 0000000000..2c8cc171ec --- /dev/null +++ b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir @@ -0,0 +1,96 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func @alloc_2d_filled_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %arg5 = %c0 to %arg0 step %c1 { + scf.for %arg6 = %c0 to %arg1 step %c1 { + scf.for %arg7 = %c0 to %arg2 step %c1 { + scf.for %arg8 = %c0 to %arg3 step %c1 { + %iarg8=arith.index_cast %arg8 : index to i32 + %loopf= arith.sitofp %iarg8 : i32 to f32 + memref.store %loopf, %0[%arg5, %arg6, %arg7, %arg8] : memref + } + } + } + } + return %0 : memref + } + func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_2d_nhwc_fhwc ins(%arg0, %arg1 : memref, memref) outs(%arg2 : memref) + return + } + func.func @main() { + // Intput(image, filter) and output value. + %cst = arith.constant 0.500000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + + %current_image_n = arith.constant 2 : index + %current_image_c = arith.constant 18 : index + %current_image_h = arith.constant 8 : index + %current_image_w = arith.constant 8 : index + + %current_filter_f = arith.constant 2 : index + %current_filter_c = arith.constant 18 : index + %current_filter_h = arith.constant 4 : index + %current_filter_w = arith.constant 4 : index + + %current_output_n = arith.constant 2 : index + %current_output_c = arith.constant 2 : index + %current_output_h = arith.constant 5 : index + %current_output_w = arith.constant 5 : index + + // Image. + %image = call @alloc_2d_filled_f32(%current_image_n,%current_image_h, %current_image_w, %current_image_c, %cst) : (index, index, index, index, f32) -> memref + // Filter. 
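+    // The filter is allocated in FHWC layout (F leading) to match
+    // linalg.conv_2d_nhwc_fhwc, which takes an NHWC image and an FHWC filter.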
+ %filter = call @alloc_2d_filled_f32(%current_filter_f, %current_filter_h, %current_filter_w,%current_filter_c, %cst) : (index, index, index, index, f32) -> memref + // Output. + %output = call @alloc_2d_filled_f32(%current_output_n, %current_output_h, %current_output_w,%current_output_c, %cst_0) : (index, index, index, index, f32) -> memref + + call @conv_2d_nhwc_fhwc(%image, %filter, %output) : (memref, memref, memref) -> () + + %3 = memref.cast %output : memref to memref<*xf32> + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [2, 2, 4, 4] strides = [32, 16, 4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-SAME: [ + // CHECK-COUNT-3: [32, 32, 32, 32], + // CHECK-NEXT: [32, 32, 32, 32] + // CHECK-SAME: ], + // CHECK-NEXT: [ + // CHECK-COUNT-3: [32, 32, 32, 32], + // CHECK-NEXT: [32, 32, 32, 32] + // CHECK-SAME: ] + // CHECK-SAME: ], + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-COUNT-3: [32, 32, 32, 32], + // CHECK-NEXT: [32, 32, 32, 32] + // CHECK-SAME: ], + // CHECK-NEXT: [ + // CHECK-COUNT-3: [32, 32, 32, 32], + // CHECK-NEXT: [32, 32, 32, 32] + // CHECK-SAME: ] + // CHECK-SAME: ] + // CHECK-SAME: ] + call @printMemrefF32(%3) : (memref<*xf32>) -> () + + memref.dealloc %output : memref + memref.dealloc %image : memref + memref.dealloc %filter : memref + return + } +} + diff --git a/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir new file mode 100644 index 0000000000..510835a271 --- /dev/null +++ b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir @@ -0,0 +1,71 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + + func.func @depthwise_conv_2d_nhwc_hwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.depthwise_conv_2d_nhwc_hwc + {dilations = dense<[1,1]> : tensor<2xi64>, strides = dense<[1,1]> : tensor<2xi64>} + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + return + } + + func.func @main() { + // Constants for input image, filter, and output sizes. + %cst = arith.constant 0.500000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cf1 = arith.constant 1.0 : f32 + + %image_n = arith.constant 2 : index + %image_h = arith.constant 8 : index + %image_w = arith.constant 8 : index + %image_c = arith.constant 18 : index + + %filter_h = arith.constant 4 : index + %filter_w = arith.constant 4 : index + %filter_c = arith.constant 18 : index + + %output_n = arith.constant 2 : index + %output_h = arith.constant 5 : index + %output_w = arith.constant 5 : index + %output_c = arith.constant 18 : index + + %image = memref.alloc(%image_n,%image_h,%image_w,%image_c) : memref + %filter = memref.alloc(%filter_h,%filter_w,%filter_c) : memref + %output = memref.alloc(%output_n,%output_h,%output_w,%output_c) : memref + + // Allocate and fill image, filter, and output. 
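+    // All three buffers are filled with 1.0; since the op accumulates into
+    // the output, every result element becomes 1.0 + FH * FW = 17.0.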
+ linalg.fill + ins(%cf1 : f32) + outs(%image:memref) + + linalg.fill + ins(%cf1 : f32) + outs(%filter:memref) + linalg.fill + ins(%cf1 : f32) + outs(%output:memref) + + // Call depthwise convolution. + call @depthwise_conv_2d_nhwc_hwc(%image, %filter, %output) : (memref, memref, memref) -> () + + %output_cast = memref.cast %output : memref to memref<*xf32> + + // Print the output. + call @printMemrefF32(%output_cast) : (memref<*xf32>) -> () + + // Deallocate memory. + memref.dealloc %output : memref + memref.dealloc %image : memref + memref.dealloc %filter : memref + return + } +} diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index ffd6888cc2..e257022013 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -60,6 +60,45 @@ linalg-conv2d-tiling-run: -convert-func-to-llvm -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} +linalg-conv2d_nhwc_fhwc-optimize-lower: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ + -conv-nhwc-fhwc-optimize="vec-size=16" \ + -o ./log.mlir + +linalg-conv2d_nhwc_fhwc-tile-optimize-lower: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ + -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ + -o ./log.mlir + +linalg-conv2d_nhwc_fhwc-optimize-run: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir ${MLIR_OPT_OPTIONS} \ + -conv-nhwc-fhwc-optimize="vec-size=16" \ + -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-conv2d_nhwc_fhwc-tile-optimize-run: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir ${MLIR_OPT_OPTIONS} \ + -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ + -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-depthwise_conv_2d_nhwc_hwc-optimize-lower: + @${BUDDY_OPT} linalg-depthwise_conv_2d_nhwc_hwc.mlir \ + -depthwise-conv-nhwc-hwc-optimize="vec-size=16" \ + -o ./log.mlir + +linalg-depthwise_conv_2d_nhwc_hwc-optimize-run: + @${BUDDY_OPT} linalg-depthwise_conv_2d_nhwc_hwc.mlir \ + -depthwise-conv-nhwc-hwc-optimize="vec-size=16" \ + -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + linalg-generic-lower: @${MLIR_OPT} ./linalg-generic.mlir \ -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ @@ -177,6 +216,16 @@ linalg-batch-matmul-optimize-lower: -batchmatmul-optimize="vector-size=64" \ -o ./log.mlir +linalg-batch-matmul-tile-optimize-lower: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2" \ + -o ./log.mlir + +linalg-batch-matmul-scf-optimize-lower: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-scf-optimize="vector-size=64" \ + -o 
./log.mlir + linalg-batch-matmul-optimize-translate: @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -batchmatmul-optimize="vector-size=64" \ diff --git a/midend/lib/Conversion/CMakeLists.txt b/midend/lib/Conversion/CMakeLists.txt index 99254e4104..cfe12a8d6c 100644 --- a/midend/lib/Conversion/CMakeLists.txt +++ b/midend/lib/Conversion/CMakeLists.txt @@ -14,3 +14,4 @@ add_subdirectory(LowerLinalgToGemmini) add_subdirectory(SchedulingOnDevices) add_subdirectory(LowerSche) add_subdirectory(FuncBufferize) +add_subdirectory(DepthwiseConvOptimization) diff --git a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt index fc88a92ef6..336c95a303 100644 --- a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt @@ -1,3 +1,5 @@ add_mlir_library(ConvOptimization ConvOptimize.cpp + ConvNhwcFhwcOptimize.cpp + ConvNhwcFhwcOptimizeTile.cpp ) diff --git a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp new file mode 100644 index 0000000000..e4bc67e361 --- /dev/null +++ b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp @@ -0,0 +1,276 @@ +//====- ConvNhwcFhwcOptimize.cpp----------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Conv2DNhwcFhwcOp optimize. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class ConvNhwcFhwcOptimizePattern : public ConversionPattern { +public: + explicit ConvNhwcFhwcOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::Conv2DNhwcFhwcOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. 
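+    // The strides attribute is optional; when present, its first and last
+    // elements give the height and width strides (defaulting to 1 otherwise).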
+ if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. + if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC + Value IC = rewriter.create(loc, input, 3); // IC + Value FH = rewriter.create(loc, filter, 1); // FH + Value FW = rewriter.create(loc, filter, 2); // FW + + // clang format off + // Step 1: Create outer most loops. + // Create the scf::ForallOp operation For N,OH,OW,OC + auto outputForAllOp = rewriter.create( + loc, SmallVector({N, OH, OW, OC}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + Value ivOH = loopIndices[1]; // Index for the second dimension OH + Value ivOW = loopIndices[2]; // Index for the third dimension OW + Value ivOC = loopIndices[3]; // Index for the third dimension OC + + Value addRes = nestedBuilder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + // IC + auto forOp = nestedBuilder.create( + nestedLoc, c0, IC, vecSizeValue, ValueRange{addRes}, + [&](OpBuilder &builder, Location loc, Value ivIC, + ValueRange iargs) { + Value tVec; + if (isa(elemTy)) { + tVec = builder.create(loc, vecTy, + zeroElementType); + } else { + tVec = builder.create(loc, vecTy, + zeroElementType); + } + + Value remainLen = builder.create( + loc, + AffineMap::get(2, 1, {-d0 + s0, d1}, builder.getContext()), + ValueRange{ivIC, vecSizeValue, IC}); + Value remainMask = builder.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{remainLen}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = ivFW; + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, ivIC}); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{ivOC, rowFilter, columnFilter, + ivIC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); 
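+                  // Horizontally reduce the accumulated vector to a scalar,
+                  // masking off the lanes beyond IC in the final iteration.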
+ auto reduceVecOp = builder.create( + loc, vector::CombiningKind::ADD, forOp.getResult(0)); + auto maskedOp = + cast(mlir::vector::maskOperation( + builder, reduceVecOp, remainMask)); + Value reduceVec = maskedOp->getResult(0); + Value addNext; + if (isa(elemTy)) { + addNext = + builder.create(loc, iargs[0], reduceVec); + } else { + addNext = + builder.create(loc, iargs[0], reduceVec); + } + builder.create(loc, ValueRange{addNext}); + }); + + nestedBuilder.create( + loc, forOp.getResult(0), output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ConvNhwcFhwcOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class ConvNhwcFhwcOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvNhwcFhwcOptimizePass) + StringRef getArgument() const final { return "conv-nhwc-fhwc-optimize"; } + StringRef getDescription() const final { + return "Conv2d NHWC FHWC optimize."; + } + ConvNhwcFhwcOptimizePass() = default; + ConvNhwcFhwcOptimizePass(const ConvNhwcFhwcOptimizePass &) {} + explicit ConvNhwcFhwcOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void ConvNhwcFhwcOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerConvNhwcFhwcOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp new file mode 100644 index 0000000000..db812aceb7 --- /dev/null +++ b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp @@ -0,0 +1,342 @@ +//====- ConvNhwcFhwcOptimizeTile.cpp------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Conv2DNhwcFhwcOp tile optimize. 
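+// It tiles the output height, width, and channel dimensions before
+// vectorizing along the input channels.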
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class ConvNhwcFhwcTileOptimizePattern : public ConversionPattern { +public: + explicit ConvNhwcFhwcTileOptimizePattern(MLIRContext *context, + int64_t vecSizeParam, + int64_t tilingOHParam, + int64_t tilingOWParam, + int64_t tilingOCParam) + : ConversionPattern(linalg::Conv2DNhwcFhwcOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + tilingOH = tilingOHParam; + tilingOW = tilingOWParam; + tilingOC = tilingOCParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. + if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. + if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC + Value IC = rewriter.create(loc, input, 3); // IC + Value FH = rewriter.create(loc, filter, 1); // FH + Value FW = rewriter.create(loc, filter, 2); // FW + + auto tilingUpperBound = + AffineMap::get(2, 1, {d0 + d1, s0}, rewriter.getContext()); + + Value stepOH = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOH)), OH); + Value stepOW = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOW)), OW); + Value stepOC = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOC)), OC); + + // clang format off + // Step 1: Create outer most loops. 
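+    // The outer forall steps over tile origins by (1, stepOH, stepOW, stepOC);
+    // a nested forall then visits each element inside the tile.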
+ // Create the scf::ForallOp operation For N,OH,OW,OC + rewriter.create( + loc, SmallVector{c0, c0, c0, c0}, + SmallVector({N, OH, OW, OC}), + SmallVector({c1, stepOH, stepOW, stepOC}), + ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + + Value ubOH = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[1], stepOH, + OH}); // ub for the second dimension OH + Value ubOW = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[2], stepOW, + OW}); // ub for the second dimension OW + Value ubOC = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[3], stepOC, + OC}); // ub for the second dimension OC + + rewriter.create( + loc, + SmallVector{loopIndices[1], loopIndices[2], + loopIndices[3]}, + SmallVector({ubOH, ubOW, ubOC}), + SmallVector({c1, c1, c1}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivOH = loopIndices[0]; // Index for the first dimension OH + Value ivOW = loopIndices[1]; // Index for the first dimension OW + Value ivOC = loopIndices[2]; // Index for the first dimension OC + + Value addRes = nestedBuilder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + // IC + auto forOp = nestedBuilder.create( + nestedLoc, c0, IC, vecSizeValue, ValueRange{addRes}, + [&](OpBuilder &builder, Location loc, Value ivIC, + ValueRange iargs) { + Value tVec; + if (isa(elemTy)) { + tVec = builder.create( + loc, vecTy, zeroElementType); + } else { + tVec = builder.create(loc, vecTy, + zeroElementType); + } + + Value remainLen = builder.create( + loc, + AffineMap::get(2, 1, {-d0 + s0, d1}, + builder.getContext()), + ValueRange{ivIC, vecSizeValue, IC}); + Value remainMask = builder.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{remainLen}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, + Value ivFW, ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get(2, 0, + d0 * strWidth + + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, + ivIC}); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{ivOC, rowFilter, columnFilter, + ivIC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = + builder.create(loc, iVec, + fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create( + loc, ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + auto reduceVecOp = builder.create( + loc, vector::CombiningKind::ADD, forOp.getResult(0)); + auto maskedOp = + cast(mlir::vector::maskOperation( + builder, reduceVecOp, remainMask)); + Value reduceVec = maskedOp->getResult(0); + Value addNext; + if (isa(elemTy)) { + addNext = 
builder.create(loc, iargs[0], + reduceVec); + } else { + addNext = builder.create(loc, iargs[0], + reduceVec); + } + builder.create(loc, ValueRange{addNext}); + }); + + nestedBuilder.create( + loc, forOp.getResult(0), output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + nestedBuilder.create(nestedLoc); + }); + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecSize; + int64_t tilingOH; + int64_t tilingOW; + int64_t tilingOC; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ConvNhwcFhwcTileOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class ConvNhwcFhwcTileOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvNhwcFhwcTileOptimizePass) + StringRef getArgument() const final { return "conv-nhwc-fhwc-tile-optimize"; } + StringRef getDescription() const final { + return "Conv2d NHWC FHWC optimize with Tile."; + } + ConvNhwcFhwcTileOptimizePass() = default; + ConvNhwcFhwcTileOptimizePass(const ConvNhwcFhwcTileOptimizePass &) {} + explicit ConvNhwcFhwcTileOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; + Option tilingOH{*this, "tiling-height", + llvm::cl::desc("tiling the output height."), + llvm::cl::init(1)}; + Option tilingOW{*this, "tiling-width", + llvm::cl::desc("tiling the output width."), + llvm::cl::init(1)}; + Option tilingOC{*this, "tiling-channel", + llvm::cl::desc("tiling the output channel."), + llvm::cl::init(1)}; +}; +} // end anonymous namespace. 
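+// Legalize the destination dialects and rewrite matched convolutions through
+// a partial conversion.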
+ +void ConvNhwcFhwcTileOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize, tilingOH, + tilingOW, tilingOC); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerConvNhwcFhwcTileOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp b/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp index 55c876dd63..918a1388d6 100644 --- a/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp +++ b/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp @@ -122,8 +122,7 @@ class GEMMPointwiseConvPattern : public ConversionPattern { namespace { class PointwiseConvToGemmPass - : public PassWrapper> { + : public PassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PointwiseConvToGemmPass) StringRef getArgument() const final { return "pointwise-conv-to-gemm"; } @@ -144,14 +143,20 @@ class PointwiseConvToGemmPass void PointwiseConvToGemmPass::runOnOperation() { MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); ConversionTarget target(*context); - target.addLegalDialect(); + target + .addLegalDialect(); target.addLegalOp(); target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); } namespace mlir { diff --git a/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt b/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt new file mode 100644 index 0000000000..8493e2a60a --- /dev/null +++ b/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt @@ -0,0 +1,3 @@ +add_mlir_library(DepthwiseConvOptimization + DepthwiseConvNhwcHwc.cpp + ) diff --git a/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp b/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp new file mode 100644 index 0000000000..04bf76f769 --- /dev/null +++ b/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp @@ -0,0 +1,331 @@ +//====- DepthwiseConvNhwcHwc.cpp +//--------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the DepthwiseConvNhwcHwc optimize. 
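+// The channel dimension is vectorized in vec-size chunks, and a masked
+// epilogue handles the OC % vec-size tail.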
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class DepthwiseConv2DNhwcHwcOptimizePattern : public ConversionPattern { +public: + explicit DepthwiseConv2DNhwcHwcOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::DepthwiseConv2DNhwcHwcOp::getOperationName(), + 1, context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. + if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. + if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + Value zeroElementTypeVec; + if (isa(elemTy)) { + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + } else { + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + } + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC/FC/IC + + Value applyOC = rewriter.create( + loc, AffineMap::get(1, 0, d0.floorDiv(vecSize) * vecSize), OC); + Value tailLength = rewriter.create( + loc, AffineMap::get(1, 0, d0 % vecSize), ValueRange{OC}); + Value maskVector = rewriter.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{tailLength}); + + Value FH = rewriter.create(loc, filter, 0); // FH + Value FW = rewriter.create(loc, filter, 1); // FW + + // clang format off + // Step 1: Create outer most loops. 
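+    // The reduction stays per-channel here, so whole vectors of output
+    // channels are accumulated in registers and stored back directly.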
+ // Create the scf::ForallOp operation For N,OH,OW + auto outputForAllOp = rewriter.create( + loc, SmallVector({N, OH, OW}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + Value ivOH = loopIndices[1]; // Index for the second dimension OH + Value ivOW = loopIndices[2]; // Index for the third dimension OW + // OC + nestedBuilder.create( + nestedLoc, c0, applyOC, vecSizeValue, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value ivOC, + ValueRange iargs) { + Value tVec = builder.create( + loc, vecTy, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, ivOC}); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{rowFilter, columnFilter, ivOC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + builder.create( + loc, forOp.getResult(0), output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + + builder.create(loc, ValueRange{std::nullopt}); + }); + + // applyOC + Value condition = nestedBuilder.create( + loc, arith::CmpIPredicate::sgt, tailLength, c0); + nestedBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value tVec = builder.create( + loc, vecTy, output, ValueRange{ivN, ivOH, ivOW, applyOC}, + maskVector, zeroElementTypeVec); + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, applyOC}, + maskVector, zeroElementTypeVec); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{rowFilter, columnFilter, applyOC}, + maskVector, zeroElementTypeVec); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + 
loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + builder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, applyOC}, + maskVector, forOp.getResult(0)); + builder.create(loc, ValueRange{std::nullopt}); + }); + + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// DepthwiseConv2DNhwcHwcOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class DepthwiseConv2DNhwcHwcOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + DepthwiseConv2DNhwcHwcOptimizePass) + StringRef getArgument() const final { + return "depthwise-conv-nhwc-hwc-optimize"; + } + StringRef getDescription() const final { + return "Depthwise Conv2d NHWC HWC optimize."; + } + DepthwiseConv2DNhwcHwcOptimizePass() = default; + DepthwiseConv2DNhwcHwcOptimizePass( + const DepthwiseConv2DNhwcHwcOptimizePass &) {} + explicit DepthwiseConv2DNhwcHwcOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void DepthwiseConv2DNhwcHwcOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerDepthwiseConv2DNhwcHwcOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp new file mode 100644 index 0000000000..a3d079be22 --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp @@ -0,0 +1,281 @@ +//===- BatchMatMulOptimize.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the batchmatmul scf vectorization optimization. 
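+// Each element of A is broadcast and multiply-accumulated against
+// vector-size-wide slices of the matching B row; a masked pass covers the
+// remaining bCol % vector-size columns.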
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; +using namespace affine; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class BatchMatMuSCFOptimizePattern : public ConversionPattern { +private: + int64_t vecSize; + +public: + explicit BatchMatMuSCFOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Retrieve input tensors A, B, and C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Acquire the element type of input tensors. + Type elementType = A.getType().cast().getElementType(); + + // Define constants. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value cVecSize = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr d2 = rewriter.getAffineDimExpr(2); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); + + const Value zeroElementType = rewriter.create( + loc, rewriter.getZeroAttr(elementType)); + + // Get dimensions of input tensors. + Value batch = rewriter.create(loc, A, 0); + Value aRow = rewriter.create(loc, A, 1); + Value bCol = rewriter.create(loc, B, 2); + Value bRow = rewriter.create(loc, B, 1); + + VectorType vecTy = VectorType::get({vecSize}, elementType); + Value zeroElementTypeVec; + if (isa(elementType)) + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + else + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + // Calculate the length of the tail, which might not fit in a + // vector. + Value tailLength = rewriter.create( + loc, AffineMap::get(1, 0, d0 % vecSize), ValueRange{bCol}); + + // Generate a mask vector based on the tail length. 
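+    // The mask enables the first tailLength lanes and disables the rest.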
+ Value maskVector = rewriter.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{tailLength}); + + Value ApplyBCol = rewriter.create( + loc, AffineMap::get(1, 0, d0.floorDiv(vecSize) * vecSize), bCol); + + rewriter.create( + loc, SmallVector({c0}), + SmallVector({batch}), + SmallVector({c1}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &builder, Location loc, ValueRange loopIndices) { + Value loopVarBatchIdx = loopIndices[0]; + builder.create( + loc, c0, aRow, c1, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value loopVarRowOfA, + ValueRange iargs) { + builder.create( + loc, c0, bRow, c1, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value loopVarRowOfB, + ValueRange iargs) { + Value aElement = builder.create( + loc, A, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarRowOfB}); + Value aVec = builder.create( + loc, vecTy, aElement); + builder.create( + loc, c0, ApplyBCol, cVecSize, + ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, + Value loopVarColOfB, ValueRange iargs) { + Value bVec = builder.create( + loc, vecTy, B, + ValueRange{loopVarBatchIdx, loopVarRowOfB, + loopVarColOfB}); + + Value cVec = builder.create( + loc, vecTy, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarColOfB}); + Value computedVec; + + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, bVec); + computedVec = builder.create( + loc, mulVec, cVec); + } else { + computedVec = builder.create( + loc, aVec, bVec, cVec); + } + builder.create( + loc, computedVec, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarColOfB}); + builder.create( + loc, ValueRange{std::nullopt}); + }); + Value condition = builder.create( + loc, arith::CmpIPredicate::sgt, tailLength, c0); + builder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + Value bVec = builder.create( + loc, vecTy, B, + ValueRange{loopVarBatchIdx, loopVarRowOfB, + ApplyBCol}, + maskVector, zeroElementTypeVec); + + Value cVec = builder.create( + loc, vecTy, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + ApplyBCol}, + maskVector, zeroElementTypeVec); + + Value computedVec; + + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, bVec); + computedVec = builder.create( + loc, mulVec, cVec); + } else { + computedVec = builder.create( + loc, aVec, bVec, cVec); + } + + builder.create( + loc, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + ApplyBCol}, + maskVector, computedVec); + builder.create(loc); + }); + builder.create(loc, + ValueRange{std::nullopt}); + }); + builder.create(loc, ValueRange{std::nullopt}); + }); + + builder.create(loc); + }); + + rewriter.eraseOp(op); + return success(); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// BatchMatMuSCFOptimize +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg pooling operations to mixture of +/// Affine + Vector operations. 
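+/// Specifically, linalg.batch_matmul is rewritten into SCF loops carrying
+/// vector operations.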
+namespace { +class BatchMatMuSCFOptimize + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BatchMatMuSCFOptimize) + StringRef getArgument() const final { return "batchmatmul-scf-optimize"; } + StringRef getDescription() const final { + return "BatchMatMul SCF Optimization."; + } + BatchMatMuSCFOptimize() = default; + BatchMatMuSCFOptimize(const BatchMatMuSCFOptimize &) {} + explicit BatchMatMuSCFOptimize(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vector-size", + llvm::cl::desc("Strip mining size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void BatchMatMuSCFOptimize::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} +// add to buddy-opt.cpp +namespace mlir { +namespace buddy { +void registerBatchMatMuSCFOptimize() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp new file mode 100644 index 0000000000..91d10c6456 --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp @@ -0,0 +1,353 @@ +//===- BatchMatMulOptimize.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the batchmatmul tile optimization. 
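+// The batch dimension runs in an affine.parallel loop, while M and N are
+// register-tiled into kernel-m rows by kernel-n vectors of vec-size columns.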
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; +using namespace affine; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class BatchMatMulTileOptimizePattern : public ConversionPattern { +private: + int64_t vecSize, kernelM, kernelN; + +public: + explicit BatchMatMulTileOptimizePattern(MLIRContext *context, + int64_t vecSizeParam, + int64_t kernelMParam, + int64_t kernelNParam) + : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + kernelM = kernelMParam; + kernelN = kernelNParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Retrieve input tensors A, B, and C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Acquire the element type of input tensors. + Type elementType = A.getType().cast().getElementType(); + ShapedType ATy = A.getType().cast(); + + // Define constants. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr d2 = rewriter.getAffineDimExpr(2); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + const AffineExpr s1 = rewriter.getAffineSymbolExpr(1); + const AffineExpr s2 = rewriter.getAffineSymbolExpr(2); + + const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); + + // Get dimensions of input tensors. + Value batch = rewriter.create(loc, A, 0); + Value M = rewriter.create(loc, A, 1); // aRow + Value K = rewriter.create(loc, B, 1); // bRow + Value N = rewriter.create(loc, B, 2); // bCol + + SmallVector reducedValues = llvm::to_vector<4>( + llvm::map_range(ArrayRef{}, + [](const LoopReduction &red) { return red.value; })); + + // Configs + int64_t kNLen = vecSize * kernelN; + + // Create the primary parallel batch level loop. + AffineParallelOp parallelBatchLoop = + rewriter.create( + loc, ValueRange(reducedValues).getTypes(), ValueRange{batch}, + ArrayRef{ + rewriter.getNamedAttr("lowerBoundsGroups", + rewriter.getI32TensorAttr({1})), + rewriter.getNamedAttr("upperBoundsGroups", + rewriter.getI32TensorAttr({1})), + rewriter.getNamedAttr( + "lowerBoundsMap", + AffineMapAttr::get(AffineMap::get(0, 0, {zeroAffine}, + rewriter.getContext()))), + rewriter.getNamedAttr("upperBoundsMap", + AffineMapAttr::get(AffineMap::get( + 1, 0, {d0}, rewriter.getContext()))), + rewriter.getNamedAttr("reductions", rewriter.getArrayAttr({})), + rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr({1}))}); + + // Create the loop body for the parallel loop. 
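+    // The body block is constructed by hand and attached to the
+    // affine.parallel op once the loop nest below has been emitted into it.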
+ Block *loopBody = new Block(); + rewriter.setInsertionPointToStart(loopBody); + loopBody->addArgument(rewriter.getIndexType(), loc); + Value loopVarBatchIdx = loopBody->getArguments()[0]; + + // Prefetching data from tensor 'A' for better cache utilization. + rewriter.create( + loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()), + ArrayRef{loopVarBatchIdx, M, K}, false, 3, true); + + // build loop body + affine::buildAffineLoopNest( + rewriter, loc, {c0}, {N}, kNLen, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + auto ivJ = ivRange.front(); + affine::buildAffineLoopNest( + builder, loc, {c0}, {M}, kernelM, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + Value ivI = ivRange.front(); + SmallVector cptrs; + + const VectorType vTy = + VectorType::get(vecSize, ATy.getElementType()); + + for (int i = 0; i < kernelM; i++) { + Value fixedIV = builder.create( + loc, + AffineMap::get(1, 1, {d0 + i, s0 - 1}, + builder.getContext()), + SmallVector{ivI, M}); + MemRefType resTy = MemRefType::get( + ATy.getShape(), ATy.getElementType(), + AffineMap::get(3, 3, d1 * s2 + d0 * s1 + s0 + d2)); + auto cptr = builder.create( + loc, resTy, C, + SmallVector{loopVarBatchIdx, fixedIV, c0}, + SmallVector{c1, c1, N}, + SmallVector{c1, c1, c1}); + cptrs.push_back(cptr); + } + affine::buildAffineLoopNest( + builder, loc, {c0}, {K}, 1, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + Value ivK = ivRange.front(); + SmallVector bs; + + for (int j = 0; j < kernelN; j++) { + Value fixedJV = ivJ; + if (j != 0) { + fixedJV = builder.create( + loc, AffineMap::get(1, 0, d0 + j * vecSize), ivJ); + } + bs.push_back(builder.create( + loc, vTy, B, + ValueRange{loopVarBatchIdx, ivK, fixedJV})); + } + + for (int i = 0; i < kernelM; ++i) { + Value fixedIV = ivI; + if (i != 0) { + fixedIV = builder.create( + loc, + AffineMap::get(1, 0, {d0 + i}, + builder.getContext()), + SmallVector{ivI}); + } + affine::AffineIfOp mBranchingOp = + builder.create( + loc, + IntegerSet::get(1, 1, {-d0 + s0 - 1}, {false}), + ValueRange{fixedIV, M}, false); + OpBuilder mTrueBranchBuilder = + mBranchingOp.getThenBodyBuilder(); + Value ksubAElement = + mTrueBranchBuilder.create( + loc, A, + ValueRange{loopVarBatchIdx, fixedIV, ivK}); + + for (int j = 0; j < kernelN; j++) { + Value fixedJV = ivJ; + if (j != 0) { + fixedJV = + mTrueBranchBuilder + .create( + loc, + AffineMap::get(1, 0, d0 + j * vecSize), + ivJ); + } + Value vecC = mTrueBranchBuilder.create( + loc, vTy, cptrs[i], ValueRange{c0, c0, fixedJV}); + if (isa(elementType)) { + Value vecA = + mTrueBranchBuilder.create( + loc, vTy, ksubAElement); + Value vecMul = + mTrueBranchBuilder.create( + loc, vTy, vecA, bs[j]); + vecC = mTrueBranchBuilder.create( + loc, vTy, vecMul, vecC); + } else { + Value vecA = + mTrueBranchBuilder.create( + loc, vTy, ksubAElement); + vecC = mTrueBranchBuilder.create( + loc, vTy, vecA, bs[j], vecC); + } + // store vecC + Value tailLength = + mTrueBranchBuilder.create( + loc, AffineMap::get(2, 0, -d0 + d1), + ValueRange{fixedJV, N}); + affine::AffineIfOp nBranchingOp = + mTrueBranchBuilder.create( + loc, + IntegerSet::get(1, 0, {-vecSize + d0}, + {false}), + ValueRange{tailLength}, true); + // Calculate the length of the tail, which might not + // fit in a vector. 
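+                // If at least one full vector of columns remains (then
+                // branch), emit a plain vector store; otherwise (else branch)
+                // fall back to a masked store over the tail elements.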
+                  OpBuilder nTrueBranchBuilder =
+                      nBranchingOp.getThenBodyBuilder();
+                  nTrueBranchBuilder.create(
+                      loc, vecC, cptrs[i], ValueRange{c0, c0, fixedJV});
+                  OpBuilder nFalseBranchBuilder =
+                      nBranchingOp.getElseBodyBuilder();
+                  // Generate a mask vector based on the tail length.
+                  Value maskVector =
+                      nFalseBranchBuilder.create(
+                          loc,
+                          VectorType::get({vecSize},
+                                          rewriter.getI1Type()),
+                          ValueRange{tailLength});
+                  nFalseBranchBuilder.create(
+                      loc, cptrs[i], ValueRange{c0, c0, fixedJV},
+                      maskVector, vecC);
+                }
+              }
+            });
+          });
+        });
+
+    rewriter.create(loc);
+
+    // Finalize the loop and erase the original operation.
+    parallelBatchLoop.getRegion().push_back(loopBody);
+    rewriter.setInsertionPointAfter(parallelBatchLoop);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// BatchMatMulTileOptimizePass
+//===----------------------------------------------------------------------===//
+
+/// This is a partial lowering of linalg batch matmul operations to a mixture
+/// of Affine + Vector operations.
+namespace {
+class BatchMatMulTileOptimizePass
+    : public PassWrapper> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BatchMatMulTileOptimizePass)
+  StringRef getArgument() const final { return "batchmatmul-tile-optimize"; }
+  StringRef getDescription() const final {
+    return "BatchMatMul Tile Optimization.";
+  }
+  BatchMatMulTileOptimizePass() = default;
+  BatchMatMulTileOptimizePass(const BatchMatMulTileOptimizePass &) {}
+  explicit BatchMatMulTileOptimizePass(int64_t vecSizeParam,
+                                       int64_t kernelMParam,
+                                       int64_t kernelNParam) {
+    vecSize = vecSizeParam;
+    kernelM = kernelMParam;
+    kernelN = kernelNParam;
+  }
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert();
+  }
+
+  Option vecSize{*this, "vec-size",
+                          llvm::cl::desc("Strip mining size."),
+                          llvm::cl::init(16)};
+
+  Option kernelM{*this, "kernel-m",
+                          llvm::cl::desc("Tile size along the M dimension."),
+                          llvm::cl::init(4)};
+
+  Option kernelN{*this, "kernel-n",
+                          llvm::cl::desc("Tile size along the N dimension, "
+                                         "in units of vec-size."),
+                          llvm::cl::init(2)};
+};
+} // end anonymous namespace.
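+
+// A typical invocation, mirroring the examples/MLIRLinalg makefile targets
+// (the input file name is a placeholder):
+//
+//   buddy-opt input.mlir \
+//     -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2"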
+ +void BatchMatMulTileOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize, kernelM, + kernelN); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} +// add to buddy-opt.cpp +namespace mlir { +namespace buddy { +void registerBatchMatMulTileOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt index 8e726863eb..2803af674e 100644 --- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt @@ -1,8 +1,10 @@ add_mlir_library(MatMulOptimization - BatchMatMulOptimize.cpp MatMulOptimize.cpp MatMulVectorization.cpp MatMulParallelVectorization.cpp + BatchMatMulOptimize.cpp + BatchMatMulTileOptimize.cpp + BatchMatMulSCFOptimize.cpp LINK_LIBS PUBLIC BuddyUtils ) @@ -11,6 +13,14 @@ add_mlir_library(BatchMatMulOptimization BatchMatMulOptimize.cpp ) +add_mlir_library(BatchMatMulTileOptimization + BatchMatMulTileOptimize.cpp +) + +add_mlir_library(BatchMatMulSCFOptimization + BatchMatMulSCFOptimize.cpp +) + add_mlir_library(MatMulParallelVectorization MatMulParallelVectorization.cpp ) diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt index 24bcde9359..94109d28d7 100644 --- a/tools/buddy-opt/CMakeLists.txt +++ b/tools/buddy-opt/CMakeLists.txt @@ -26,9 +26,12 @@ target_link_libraries(buddy-opt LowerRVVPass MatMulOptimization BatchMatMulOptimization + BatchMatMulTileOptimization + BatchMatMulSCFOptimization MatMulParallelVectorization TransposeOptimization ConvOptimization + DepthwiseConvOptimization VectorExp LowerVectorExpPass BuddyGemmini diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index bea9513b5e..a40fda18f8 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -40,31 +40,37 @@ #include "DAP/DAPOps.h" #include "DIP/DIPDialect.h" #include "DIP/DIPOps.h" -#include "RVV/RVVDialect.h" -#include "VectorExp/VectorExpDialect.h" -#include "VectorExp/VectorExpOps.h" #include "Gemmini/GemminiDialect.h" #include "Gemmini/GemminiOps.h" +#include "RVV/RVVDialect.h" #include "Sche/ScheDialect.h" #include "Sche/ScheOps.h" +#include "VectorExp/VectorExpDialect.h" +#include "VectorExp/VectorExpOps.h" namespace mlir { namespace buddy { void registerConvVectorizationPass(); void registerPointwiseConvToGemmPass(); +void registerPointwiseConvToGemmForNhwcFhwcPass(); void registerPoolingVectorizationPass(); void registerLowerBudPass(); void registerLowerDIPPass(); +void registerBatchMatMulOptimizePass(); +void registerBatchMatMulTileOptimizePass(); +void registerBatchMatMuSCFOptimize(); void registerLowerDAPPass(); void registerExtendDAPPass(); void registerDAPVectorizePass(); void registerLowerRVVPass(); -void registerBatchMatMulOptimizePass(); void registerMatMulOptimizePass(); void registerMatMulVectorizationPass(); void registerMatMulParallelVectorizationPass(); void registerTransposeOptimizationPass(); void registerConvOptimizePass(); +void registerConvNhwcFhwcOptimizePass(); +void registerConvNhwcFhwcTileOptimizePass(); +void registerDepthwiseConv2DNhwcHwcOptimizePass(); void registerLowerVectorExpPass(); void 
registerLowerGemminiPass(); void registerLowerLinalgToGemminiPass(); @@ -78,6 +84,7 @@ int main(int argc, char **argv) { // Register all MLIR passes. mlir::registerAllPasses(); mlir::buddy::registerPointwiseConvToGemmPass(); + // mlir::buddy::registerPointwiseConvToGemmForNhwcFhwcPass(); // Register Vectorization of Convolution. mlir::buddy::registerConvVectorizationPass(); // Register Vectorization of Pooling. @@ -95,11 +102,16 @@ int main(int argc, char **argv) { // Register Several Optimize Pass. mlir::buddy::registerMatMulOptimizePass(); + mlir::buddy::registerBatchMatMulOptimizePass(); + mlir::buddy::registerBatchMatMulTileOptimizePass(); + mlir::buddy::registerBatchMatMuSCFOptimize(); mlir::buddy::registerMatMulVectorizationPass(); mlir::buddy::registerMatMulParallelVectorizationPass(); - mlir::buddy::registerBatchMatMulOptimizePass(); mlir::buddy::registerTransposeOptimizationPass(); mlir::buddy::registerConvOptimizePass(); + mlir::buddy::registerConvNhwcFhwcOptimizePass(); + mlir::buddy::registerConvNhwcFhwcTileOptimizePass(); + mlir::buddy::registerDepthwiseConv2DNhwcHwcOptimizePass(); mlir::buddy::registerDeviceSchedulePass(); mlir::buddy::registerLowerSchePass(); mlir::buddy::registerFuncBufferizeDynamicOffsetPass(); From fd9059c840802beeab1d9714ed8e634348c88201 Mon Sep 17 00:00:00 2001 From: somehow6 Date: Thu, 10 Oct 2024 19:05:50 +0800 Subject: [PATCH 02/17] [Midend] add description for conv-nhwc-fhwc-tile-optimize --- .../ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp index db812aceb7..41e1c066ee 100644 --- a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp +++ b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp @@ -303,14 +303,14 @@ class ConvNhwcFhwcTileOptimizePass Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), llvm::cl::init(16)}; Option tilingOH{*this, "tiling-height", - llvm::cl::desc("tiling the output height."), + llvm::cl::desc("number of the output height tiles."), llvm::cl::init(1)}; Option tilingOW{*this, "tiling-width", - llvm::cl::desc("tiling the output width."), - llvm::cl::init(1)}; - Option tilingOC{*this, "tiling-channel", - llvm::cl::desc("tiling the output channel."), + llvm::cl::desc("number of the output width tiles."), llvm::cl::init(1)}; + Option tilingOC{ + *this, "tiling-channel", + llvm::cl::desc("number of the output channel tiles."), llvm::cl::init(1)}; }; } // end anonymous namespace. 
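
For example, the configuration exercised in the MLIRLinalg makefile,
-conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3",
conceptually splits the conv_2d_nhwc_fhwc output into 2 tiles along the output
height and 3 tiles along the output width (six tiles in total), while vec-size
sets the vector width used within each tile.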
From b75bbb6c8b29ff66df2285a7cbdaf445b017668d Mon Sep 17 00:00:00 2001 From: somehow6 Date: Fri, 11 Oct 2024 19:23:17 +0800 Subject: [PATCH 03/17] [Midend] modify batch matmul the cmake construction and example MLIRLinalg --- examples/MLIRLinalg/makefile | 41 ++++++++++++++++--- .../MatMulOptimization/CMakeLists.txt | 6 --- tools/buddy-opt/CMakeLists.txt | 2 - tools/buddy-opt/buddy-opt.cpp | 1 - 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index e257022013..d9a37926f4 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -64,11 +64,6 @@ linalg-conv2d_nhwc_fhwc-optimize-lower: @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ -conv-nhwc-fhwc-optimize="vec-size=16" \ -o ./log.mlir - -linalg-conv2d_nhwc_fhwc-tile-optimize-lower: - @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ - -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ - -o ./log.mlir linalg-conv2d_nhwc_fhwc-optimize-run: @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir ${MLIR_OPT_OPTIONS} \ @@ -78,6 +73,12 @@ linalg-conv2d_nhwc_fhwc-optimize-run: -convert-func-to-llvm -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-conv2d_nhwc_fhwc-tile-optimize-lower: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ + -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ + -o ./log.mlir + linalg-conv2d_nhwc_fhwc-tile-optimize-run: @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir ${MLIR_OPT_OPTIONS} \ -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ @@ -221,11 +222,41 @@ linalg-batch-matmul-tile-optimize-lower: -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2" \ -o ./log.mlir +linalg-batch-matmul-tile-optimize-run: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2" \ + -convert-linalg-to-loops \ + -expand-strided-metadata \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + linalg-batch-matmul-scf-optimize-lower: @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ -batchmatmul-scf-optimize="vector-size=64" \ -o ./log.mlir +linalg-batch-matmul-scf-optimize-run: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-scf-optimize="vector-size=64" \ + -convert-linalg-to-loops \ + -expand-strided-metadata \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + linalg-batch-matmul-optimize-translate: @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -batchmatmul-optimize="vector-size=64" \ diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt index 2803af674e..047a7f5c48 100644 --- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt @@ -11,13 
+11,7 @@ add_mlir_library(MatMulOptimization add_mlir_library(BatchMatMulOptimization BatchMatMulOptimize.cpp -) - -add_mlir_library(BatchMatMulTileOptimization BatchMatMulTileOptimize.cpp -) - -add_mlir_library(BatchMatMulSCFOptimization BatchMatMulSCFOptimize.cpp ) diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt index 94109d28d7..7ab19cbcd1 100644 --- a/tools/buddy-opt/CMakeLists.txt +++ b/tools/buddy-opt/CMakeLists.txt @@ -26,8 +26,6 @@ target_link_libraries(buddy-opt LowerRVVPass MatMulOptimization BatchMatMulOptimization - BatchMatMulTileOptimization - BatchMatMulSCFOptimization MatMulParallelVectorization TransposeOptimization ConvOptimization diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index a40fda18f8..c737b5cf72 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -84,7 +84,6 @@ int main(int argc, char **argv) { // Register all MLIR passes. mlir::registerAllPasses(); mlir::buddy::registerPointwiseConvToGemmPass(); - // mlir::buddy::registerPointwiseConvToGemmForNhwcFhwcPass(); // Register Vectorization of Convolution. mlir::buddy::registerConvVectorizationPass(); // Register Vectorization of Pooling. From 9e9d6b26f93485863f1c13bd0f6093943362fc47 Mon Sep 17 00:00:00 2001 From: somehow6 Date: Fri, 11 Oct 2024 19:27:45 +0800 Subject: [PATCH 04/17] [Midend] modify conv opt file name --- midend/lib/Conversion/ConvOptimization/CMakeLists.txt | 2 +- ...onvNhwcFhwcOptimizeTile.cpp => ConvNhwcFhwcTileOptimize.cpp} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename midend/lib/Conversion/ConvOptimization/{ConvNhwcFhwcOptimizeTile.cpp => ConvNhwcFhwcTileOptimize.cpp} (100%) diff --git a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt index 336c95a303..9f77079d38 100644 --- a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_library(ConvOptimization ConvOptimize.cpp ConvNhwcFhwcOptimize.cpp - ConvNhwcFhwcOptimizeTile.cpp + ConvNhwcFhwcTileOptimize.cpp ) diff --git a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcTileOptimize.cpp similarity index 100% rename from midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimizeTile.cpp rename to midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcTileOptimize.cpp From 27361e9efdbccfd8a33e78d9a0ed929955094ae4 Mon Sep 17 00:00:00 2001 From: somehow6 Date: Mon, 14 Oct 2024 22:49:24 +0800 Subject: [PATCH 05/17] [Midend] Enhancements and Optimizations for batch matmul and convolution [Examples] Added MLIRLinalg Examples for Various Optimization Options. 
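
Both files remain listed in the MatMulOptimization library, so dropping them
here avoids compiling the same sources into two targets; the library reduces
to:

    add_mlir_library(BatchMatMulOptimization
      BatchMatMulOptimize.cpp
    )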
--- midend/lib/Conversion/MatMulOptimization/CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt index 047a7f5c48..776a766b47 100644 --- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt @@ -11,8 +11,6 @@ add_mlir_library(MatMulOptimization add_mlir_library(BatchMatMulOptimization BatchMatMulOptimize.cpp - BatchMatMulTileOptimize.cpp - BatchMatMulSCFOptimize.cpp ) add_mlir_library(MatMulParallelVectorization From e0094ded51ec505fb52de6c32bdf4d55295d579d Mon Sep 17 00:00:00 2001 From: somehow6 Date: Tue, 15 Oct 2024 09:57:04 +0800 Subject: [PATCH 06/17] [Midend] Enhancements and Optimizations for batch matmul and convolution. example check implement. --- .../MLIRLinalg/linalg-batch-matmul-dync.mlir | 30 ++++++--- .../MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir | 61 ++++++++++--------- .../linalg-depthwise_conv_2d_nhwc_hwc.mlir | 32 ++++++---- 3 files changed, 74 insertions(+), 49 deletions(-) diff --git a/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir index 1b910e4a3e..a84e993591 100644 --- a/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir +++ b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir @@ -48,14 +48,28 @@ module { call @buddy_batchmatmul_f32(%A, %B, %C) : (memref, memref, memref) -> () - // Print output. - // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = - // CHECK-NEXT: [ - // CHECK-SAME: [5, 5, 5, 5], - // CHECK-NEXT: [5, 5, 5, 5], - // CHECK-NEXT: [5, 5, 5, 5], - // CHECK-NEXT: [5, 5, 5, 5] - // CHECK-SAME: ] + // CHECK: {{ Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 3 offset = 0 sizes = \[10, 2, 5\] strides = \[10, 5, 1\] data = }} + // CHECK{LITERAL}: [[[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]], + // CHECK{LITERAL}: [[8, 8, 8, 8, 8], + // CHECK{LITERAL}: [8, 8, 8, 8, 8]]] + %print_C = memref.cast %C : memref to memref<*xf32> call @printMemrefF32(%print_C) : (memref<*xf32>) -> () diff --git a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir index 2c8cc171ec..fcddd1e921 100644 --- a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir +++ b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir @@ -35,17 +35,17 @@ module { %cst = arith.constant 0.500000e+00 : f32 %cst_0 = arith.constant 0.000000e+00 : f32 - %current_image_n = arith.constant 2 : index - %current_image_c = arith.constant 18 : index + %current_image_n = arith.constant 1 : index + %current_image_c = arith.constant 2 : index %current_image_h = arith.constant 8 : index %current_image_w = arith.constant 8 : index - %current_filter_f = arith.constant 2 : index - %current_filter_c = arith.constant 18 : index + %current_filter_f = arith.constant 1 : index 
+ %current_filter_c = arith.constant 2 : index %current_filter_h = arith.constant 4 : index %current_filter_w = arith.constant 4 : index - %current_output_n = arith.constant 2 : index + %current_output_n = arith.constant 1 : index %current_output_c = arith.constant 2 : index %current_output_h = arith.constant 5 : index %current_output_w = arith.constant 5 : index @@ -60,32 +60,33 @@ module { call @conv_2d_nhwc_fhwc(%image, %filter, %output) : (memref, memref, memref) -> () %3 = memref.cast %output : memref to memref<*xf32> - - // Print output. - // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [2, 2, 4, 4] strides = [32, 16, 4, 1] data = - // CHECK-NEXT: [ - // CHECK-SAME: [ - // CHECK-SAME: [ - // CHECK-COUNT-3: [32, 32, 32, 32], - // CHECK-NEXT: [32, 32, 32, 32] - // CHECK-SAME: ], - // CHECK-NEXT: [ - // CHECK-COUNT-3: [32, 32, 32, 32], - // CHECK-NEXT: [32, 32, 32, 32] - // CHECK-SAME: ] - // CHECK-SAME: ], - // CHECK-NEXT: [ - // CHECK-SAME: [ - // CHECK-COUNT-3: [32, 32, 32, 32], - // CHECK-NEXT: [32, 32, 32, 32] - // CHECK-SAME: ], - // CHECK-NEXT: [ - // CHECK-COUNT-3: [32, 32, 32, 32], - // CHECK-NEXT: [32, 32, 32, 32] - // CHECK-SAME: ] - // CHECK-SAME: ] - // CHECK-SAME: ] call @printMemrefF32(%3) : (memref<*xf32>) -> () + // CHECK: {{ Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 4 offset = 0 sizes = \[1, 5, 5, 2\] strides = \[50, 10, 2, 1\] data = }} + // CHECK{LITERAL}: [[[[16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1]], + // CHECK{LITERAL}: [[16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1]], + // CHECK{LITERAL}: [[16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1]], + // CHECK{LITERAL}: [[16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1]], + // CHECK{LITERAL}: [[16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1], + // CHECK{LITERAL}: [16, 1]]]] memref.dealloc %output : memref memref.dealloc %image : memref diff --git a/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir index 510835a271..0b6ace667d 100644 --- a/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir +++ b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir @@ -24,19 +24,19 @@ module { %cst_0 = arith.constant 0.000000e+00 : f32 %cf1 = arith.constant 1.0 : f32 - %image_n = arith.constant 2 : index - %image_h = arith.constant 8 : index - %image_w = arith.constant 8 : index - %image_c = arith.constant 18 : index + %image_n = arith.constant 1 : index + %image_h = arith.constant 4 : index + %image_w = arith.constant 4 : index + %image_c = arith.constant 2 : index - %filter_h = arith.constant 4 : index - %filter_w = arith.constant 4 : index - %filter_c = arith.constant 18 : index + %filter_h = arith.constant 1 : index + %filter_w = arith.constant 2 : index + %filter_c = arith.constant 2 : index - %output_n = arith.constant 2 : index - %output_h = arith.constant 5 : index - %output_w = arith.constant 5 : index - %output_c = arith.constant 18 : index + %output_n = arith.constant 1 : index + %output_h = arith.constant 3 : index + %output_w = arith.constant 3 : index + %output_c = arith.constant 2 : index %image = memref.alloc(%image_n,%image_h,%image_w,%image_c) : 
memref %filter = memref.alloc(%filter_h,%filter_w,%filter_c) : memref @@ -61,6 +61,16 @@ module { // Print the output. call @printMemrefF32(%output_cast) : (memref<*xf32>) -> () + // CHECK: {{ Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 4 offset = 0 sizes = \[1, 3, 3, 2\] strides = \[18, 6, 2, 1\] data = }} + // CHECK{{LITERAL}}: [[[[3, 3], + // CHECK{{LITERAL}}: [3, 3], + // CHECK{{LITERAL}}: [3, 3]], + // CHECK{{LITERAL}}: [[3, 3], + // CHECK{{LITERAL}}: [3, 3], + // CHECK{{LITERAL}}: [3, 3]], + // CHECK{{LITERAL}}: [[3, 3], + // CHECK{{LITERAL}}: [3, 3], + // CHECK{{LITERAL}}: [3, 3]]]] // Deallocate memory. memref.dealloc %output : memref From 7e6c2c6dd127105cdc06a58d9d01af0569c07038 Mon Sep 17 00:00:00 2001 From: somehow6 Date: Tue, 15 Oct 2024 17:32:05 +0800 Subject: [PATCH 07/17] [Examples] Added MLIRLinalg Examples for Various Optimization Options. Fixed FILECHECK bugs. --- .../MLIRLinalg/linalg-batch-matmul-dync.mlir | 46 ++++++----------- .../MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir | 51 +++++++------------ .../linalg-depthwise_conv_2d_nhwc_hwc.mlir | 23 +++++---- 3 files changed, 45 insertions(+), 75 deletions(-) diff --git a/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir index a84e993591..04dea80df6 100644 --- a/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir +++ b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir @@ -1,9 +1,9 @@ -// RUN: buddy-opt %s \ -// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ -// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ -// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ -// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: buddy-opt %s -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2" \ +// RUN: -convert-linalg-to-loops -expand-strided-metadata -lower-affine \ +// RUN: -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm \ +// RUN: -convert-arith-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ // RUN: | FileCheck %s @@ -20,9 +20,9 @@ module { func.func @main(){ // Set up dims. - %cBatch = arith.constant 10:index + %cBatch = arith.constant 2:index %cM = arith.constant 2 : index - %cN = arith.constant 5 : index + %cN = arith.constant 3 : index %cK = arith.constant 4 : index // Set Init Value. 
@@ -48,28 +48,6 @@ module { call @buddy_batchmatmul_f32(%A, %B, %C) : (memref, memref, memref) -> () - // CHECK: {{ Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 3 offset = 0 sizes = \[10, 2, 5\] strides = \[10, 5, 1\] data = }} - // CHECK{LITERAL}: [[[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]], - // CHECK{LITERAL}: [[8, 8, 8, 8, 8], - // CHECK{LITERAL}: [8, 8, 8, 8, 8]]] - %print_C = memref.cast %C : memref to memref<*xf32> call @printMemrefF32(%print_C) : (memref<*xf32>) -> () @@ -77,5 +55,11 @@ module { memref.dealloc %B : memref memref.dealloc %A : memref return - } + } } + +// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [2, 2, 3] strides = [6, 3, 1] data = +// CHECK{LITERAL}: [[[8, 8, 8], +// CHECK{LITERAL}: [8, 8, 8]], +// CHECK{LITERAL}: [[8, 8, 8], +// CHECK{LITERAL}: [8, 8, 8]]] diff --git a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir index fcddd1e921..a06c0d562d 100644 --- a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir +++ b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir @@ -1,5 +1,5 @@ // RUN: buddy-opt %s \ -// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -conv-nhwc-fhwc-optimize -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ // RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ // RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ @@ -37,18 +37,18 @@ module { %current_image_n = arith.constant 1 : index %current_image_c = arith.constant 2 : index - %current_image_h = arith.constant 8 : index - %current_image_w = arith.constant 8 : index + %current_image_h = arith.constant 4 : index + %current_image_w = arith.constant 4 : index %current_filter_f = arith.constant 1 : index %current_filter_c = arith.constant 2 : index - %current_filter_h = arith.constant 4 : index - %current_filter_w = arith.constant 4 : index + %current_filter_h = arith.constant 2 : index + %current_filter_w = arith.constant 2 : index %current_output_n = arith.constant 1 : index %current_output_c = arith.constant 2 : index - %current_output_h = arith.constant 5 : index - %current_output_w = arith.constant 5 : index + %current_output_h = arith.constant 3 : index + %current_output_w = arith.constant 3 : index // Image. 
%image = call @alloc_2d_filled_f32(%current_image_n,%current_image_h, %current_image_w, %current_image_c, %cst) : (index, index, index, index, f32) -> memref @@ -61,32 +61,7 @@ module { %3 = memref.cast %output : memref to memref<*xf32> call @printMemrefF32(%3) : (memref<*xf32>) -> () - // CHECK: {{ Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 4 offset = 0 sizes = \[1, 5, 5, 2\] strides = \[50, 10, 2, 1\] data = }} - // CHECK{LITERAL}: [[[[16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1]], - // CHECK{LITERAL}: [[16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1]], - // CHECK{LITERAL}: [[16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1]], - // CHECK{LITERAL}: [[16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1]], - // CHECK{LITERAL}: [[16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1], - // CHECK{LITERAL}: [16, 1]]]] + memref.dealloc %output : memref memref.dealloc %image : memref @@ -95,3 +70,13 @@ module { } } +// CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 3, 3, 2] strides = [18, 6, 2, 1] data = +// CHECK{LITERAL}: [[[[4, 1], +// CHECK{LITERAL}: [4, 1], +// CHECK{LITERAL}: [4, 1]], +// CHECK{LITERAL}: [[4, 1], +// CHECK{LITERAL}: [4, 1], +// CHECK{LITERAL}: [4, 1]], +// CHECK{LITERAL}: [[4, 1], +// CHECK{LITERAL}: [4, 1], +// CHECK{LITERAL}: [4, 1]]]] diff --git a/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir index 0b6ace667d..905df48bd8 100644 --- a/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir +++ b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir @@ -1,5 +1,5 @@ // RUN: buddy-opt %s \ -// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -depthwise-conv-nhwc-hwc-optimize -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ // RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ // RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ // RUN: | mlir-cpu-runner -e main -entry-point-result=void \ @@ -61,16 +61,6 @@ module { // Print the output. call @printMemrefF32(%output_cast) : (memref<*xf32>) -> () - // CHECK: {{ Unranked Memref base@ = 0x[0-9A-Fa-f]{1,} rank = 4 offset = 0 sizes = \[1, 3, 3, 2\] strides = \[18, 6, 2, 1\] data = }} - // CHECK{{LITERAL}}: [[[[3, 3], - // CHECK{{LITERAL}}: [3, 3], - // CHECK{{LITERAL}}: [3, 3]], - // CHECK{{LITERAL}}: [[3, 3], - // CHECK{{LITERAL}}: [3, 3], - // CHECK{{LITERAL}}: [3, 3]], - // CHECK{{LITERAL}}: [[3, 3], - // CHECK{{LITERAL}}: [3, 3], - // CHECK{{LITERAL}}: [3, 3]]]] // Deallocate memory. 
memref.dealloc %output : memref @@ -79,3 +69,14 @@ module { return } } + +// CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 3, 3, 2] strides = [18, 6, 2, 1] data = +// CHECK{LITERAL}: [[[[3, 3], +// CHECK{LITERAL}: [3, 3], +// CHECK{LITERAL}: [3, 3]], +// CHECK{LITERAL}: [[3, 3], +// CHECK{LITERAL}: [3, 3], +// CHECK{LITERAL}: [3, 3]], +// CHECK{LITERAL}: [[3, 3], +// CHECK{LITERAL}: [3, 3], +// CHECK{LITERAL}: [3, 3]]]] From ddc1a48beb073340b9aebfe012b32a437557cdf3 Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Tue, 15 Oct 2024 11:05:53 +0800 Subject: [PATCH 08/17] [Container] Add macro to control libpng. Co-authored-by: zhanghb97 --- CMakeLists.txt | 5 ++ README.md | 14 ++++ frontend/Interfaces/buddy/DIP/ImgContainer.h | 15 +++- tests/CMakeLists.txt | 5 +- tests/Interface/core/CMakeLists.txt | 16 +++- ...rTest.cpp => NewImageContainerTestBmp.cpp} | 59 +------------ .../core/NewImageContainerTestPng.cpp | 82 +++++++++++++++++++ tests/Interface/core/lit.local.cfg | 6 +- tests/lit.cfg.py | 4 +- tests/lit.site.cfg.py.in | 1 + 10 files changed, 141 insertions(+), 66 deletions(-) rename tests/Interface/core/{NewImageContainerTest.cpp => NewImageContainerTestBmp.cpp} (75%) create mode 100644 tests/Interface/core/NewImageContainerTestPng.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 796f5e344f..cd2379468a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,11 @@ if(BUDDY_MLIR_ENABLE_DIP_LIB) find_package(PNG REQUIRED) endif() +if(BUDDY_ENABLE_PNG) + add_definitions(-DBUDDY_ENABLE_PNG) + find_package(PNG REQUIRED) +endif() + # Generate libraries into `lib` of build directory. set(LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) diff --git a/README.md b/README.md index d413513ca5..2e44658b02 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,20 @@ $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build $ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` +To configure the build environment for using image processing libraries, follow these steps: + +``` +$ cmake -G Ninja .. \ + -DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \ + -DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_BUILD_TYPE=RELEASE \ + -DBUDDY_MLIR_ENABLE_DIP_LIB=ON \ + -DBUDDY_ENABLE_PNG=ON +$ ninja +$ ninja check-buddy +``` + To build buddy-mlir with custom LLVM sources: ``` diff --git a/frontend/Interfaces/buddy/DIP/ImgContainer.h b/frontend/Interfaces/buddy/DIP/ImgContainer.h index 382974e967..2525641bff 100644 --- a/frontend/Interfaces/buddy/DIP/ImgContainer.h +++ b/frontend/Interfaces/buddy/DIP/ImgContainer.h @@ -25,7 +25,10 @@ #include #include #include +#include +#ifdef BUDDY_ENABLE_PNG #include +#endif namespace dip { enum ImageModes { @@ -88,7 +91,9 @@ template class Image : public MemRef { // Decodes a BMP image from raw file data. bool decodeBMP(const std::vector &fileData); // Decodes a PNG image from raw file data. 
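+  // Guarded so that builds configured without libpng support
+  // (no -DBUDDY_ENABLE_PNG=ON) still compile.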
+#ifdef BUDDY_ENABLE_PNG bool decodePNG(const std::vector &fileData); +#endif }; // Image Container Constructor @@ -129,13 +134,17 @@ Image::Image(std::string filePath, ImageModes mode, bool norm) this->imageFormat = ImageFormat::ERROR; throw std::runtime_error("Failed to decode BMP file from " + filePath); }; - } else if (this->imageFormat == ImageFormat::PNG) { + } +#ifdef BUDDY_ENABLE_PNG + else if (this->imageFormat == ImageFormat::PNG) { bool success = decodePNG(fileData); if (!success) { this->imageFormat = ImageFormat::ERROR; throw std::runtime_error("Failed to decode PNG file from " + filePath); }; - } else { + } +#endif + else { throw std::runtime_error("Unsupported image file format."); } } @@ -414,6 +423,7 @@ bool Image::decodeBMP(const std::vector &fileData) { } // PNG Image File Decoder +#ifdef BUDDY_ENABLE_PNG template bool Image::decodePNG(const std::vector &fileData) { // Check if the provided data is large enough to contain a minimal PNG header @@ -604,6 +614,7 @@ bool Image::decodePNG(const std::vector &fileData) { } return true; } +#endif } // namespace dip diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bd4b3ef335..2cffa98469 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -22,7 +22,10 @@ if(BUDDY_ENABLE_OPENCV) endif() if(BUDDY_MLIR_ENABLE_DIP_LIB) - list(APPEND BUDDY_TEST_DEPENDS buddy-new-image-container-test) + list(APPEND BUDDY_TEST_DEPENDS buddy-new-image-container-test-bmp) + if(BUDDY_ENABLE_PNG) + list(APPEND BUDDY_TEST_DEPENDS buddy-new-image-container-test-png) + endif() endif() add_lit_testsuite(check-tests "Running the buddy regression tests..." diff --git a/tests/Interface/core/CMakeLists.txt b/tests/Interface/core/CMakeLists.txt index dd7a191e7f..b84ae71aef 100644 --- a/tests/Interface/core/CMakeLists.txt +++ b/tests/Interface/core/CMakeLists.txt @@ -17,10 +17,18 @@ if(BUDDY_MLIR_ENABLE_DIP_LIB OR BUDDY_ENABLE_OPENCV) ) endif() -if (BUDDY_MLIR_ENABLE_DIP_LIB) - set(NEW_DIP_LIBS ${PNG_LIBRARIES}) - _add_test_executable(buddy-new-image-container-test - NewImageContainerTest.cpp +if(BUDDY_MLIR_ENABLE_DIP_LIB) + set(NEW_DIP_LIBS "") + if(BUDDY_ENABLE_PNG) + list(APPEND NEW_DIP_LIBS ${PNG_LIBRARIES}) + _add_test_executable(buddy-new-image-container-test-png + NewImageContainerTestPng.cpp + LINK_LIBS + ${NEW_DIP_LIBS} + ) + endif() + _add_test_executable(buddy-new-image-container-test-bmp + NewImageContainerTestBmp.cpp LINK_LIBS ${NEW_DIP_LIBS} ) diff --git a/tests/Interface/core/NewImageContainerTest.cpp b/tests/Interface/core/NewImageContainerTestBmp.cpp similarity index 75% rename from tests/Interface/core/NewImageContainerTest.cpp rename to tests/Interface/core/NewImageContainerTestBmp.cpp index c109230790..13f1a9c7cf 100644 --- a/tests/Interface/core/NewImageContainerTest.cpp +++ b/tests/Interface/core/NewImageContainerTestBmp.cpp @@ -1,4 +1,4 @@ -//===- NewImageContainerTest.cpp ------------------------------------------===// +//===- NewImageContainerTestBmp.cpp ---------------------------------------===// // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -18,7 +18,7 @@ // //===----------------------------------------------------------------------===// -// RUN: buddy-new-image-container-test 2>&1 | FileCheck %s +// RUN: buddy-new-image-container-test-bmp 2>&1 | FileCheck %s #include @@ -167,60 +167,5 @@ int main() { // CHECK: 0.45490 fprintf(stderr, "%f\n", bmp24bitRGBNorm.getData()[0]); - // Default Gray Scale - dip::Image pngGrayDefault( - "../../../../tests/Interface/core/TestGrayImage.png", dip::DIP_GRAYSCALE); - // CHECK: PNG - fprintf(stderr, "%s\n", pngGrayDefault.getFormatName().c_str()); - // CHECK: 4 - fprintf(stderr, "%ld\n", pngGrayDefault.getWidth()); - // CHECK: 4 - fprintf(stderr, "%ld\n", pngGrayDefault.getHeight()); - // CHECK: 8 - fprintf(stderr, "%d\n", pngGrayDefault.getBitDepth()); - // CHECK: 15 - fprintf(stderr, "%f\n", pngGrayDefault.getData()[0]); - // Gray Scale + Normalization - dip::Image pngGrayNorm( - "../../../../tests/Interface/core/TestGrayImage.png", dip::DIP_GRAYSCALE, - true /* norm */); - // CHECK: PNG - fprintf(stderr, "%s\n", pngGrayNorm.getFormatName().c_str()); - // CHECK: 4 - fprintf(stderr, "%ld\n", pngGrayNorm.getWidth()); - // CHECK: 4 - fprintf(stderr, "%ld\n", pngGrayNorm.getHeight()); - // CHECK: 8 - fprintf(stderr, "%d\n", pngGrayNorm.getBitDepth()); - // CHECK: 0.058824 - fprintf(stderr, "%f\n", pngGrayNorm.getData()[0]); - - dip::Image pngRGBDefault( - "../../../../tests/Interface/core/TestImage-RGB.png", dip::DIP_RGB); - // CHECK: PNG - fprintf(stderr, "%s\n", pngRGBDefault.getFormatName().c_str()); - // CHECK: 224 - fprintf(stderr, "%ld\n", pngRGBDefault.getWidth()); - // CHECK: 224 - fprintf(stderr, "%ld\n", pngRGBDefault.getHeight()); - // CHECK: 8 - fprintf(stderr, "%d\n", pngRGBDefault.getBitDepth()); - // CHECK: 144 - fprintf(stderr, "%f\n", pngRGBDefault.getData()[0]); - - dip::Image pngRGBNorm( - "../../../../tests/Interface/core/TestImage-RGB.png", dip::DIP_RGB, - true /* norm */); - // CHECK: PNG - fprintf(stderr, "%s\n", pngRGBNorm.getFormatName().c_str()); - // CHECK: 224 - fprintf(stderr, "%ld\n", pngRGBNorm.getWidth()); - // CHECK: 224 - fprintf(stderr, "%ld\n", pngRGBNorm.getHeight()); - // CHECK: 8 - fprintf(stderr, "%d\n", pngRGBNorm.getBitDepth()); - // CHECK: 0.5647 - fprintf(stderr, "%f\n", pngRGBNorm.getData()[0]); - return 0; } diff --git a/tests/Interface/core/NewImageContainerTestPng.cpp b/tests/Interface/core/NewImageContainerTestPng.cpp new file mode 100644 index 0000000000..0f1dea37c3 --- /dev/null +++ b/tests/Interface/core/NewImageContainerTestPng.cpp @@ -0,0 +1,82 @@ +//===- NewImageContainerTestPng.cpp ---------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the image container test file. 
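+// It covers the PNG decoding path only: grayscale and RGB images, each with
+// and without normalization, verified via FileCheck.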
+// +//===----------------------------------------------------------------------===// + +// RUN: buddy-new-image-container-test-png 2>&1 | FileCheck %s + +#include + +int main() { + // Default Gray Scale + dip::Image pngGrayDefault( + "../../../../tests/Interface/core/TestGrayImage.png", dip::DIP_GRAYSCALE); + // CHECK: PNG + fprintf(stderr, "%s\n", pngGrayDefault.getFormatName().c_str()); + // CHECK: 4 + fprintf(stderr, "%ld\n", pngGrayDefault.getWidth()); + // CHECK: 4 + fprintf(stderr, "%ld\n", pngGrayDefault.getHeight()); + // CHECK: 8 + fprintf(stderr, "%d\n", pngGrayDefault.getBitDepth()); + // CHECK: 15 + fprintf(stderr, "%f\n", pngGrayDefault.getData()[0]); + // Gray Scale + Normalization + dip::Image pngGrayNorm( + "../../../../tests/Interface/core/TestGrayImage.png", dip::DIP_GRAYSCALE, + true /* norm */); + // CHECK: PNG + fprintf(stderr, "%s\n", pngGrayNorm.getFormatName().c_str()); + // CHECK: 4 + fprintf(stderr, "%ld\n", pngGrayNorm.getWidth()); + // CHECK: 4 + fprintf(stderr, "%ld\n", pngGrayNorm.getHeight()); + // CHECK: 8 + fprintf(stderr, "%d\n", pngGrayNorm.getBitDepth()); + // CHECK: 0.058824 + fprintf(stderr, "%f\n", pngGrayNorm.getData()[0]); + + dip::Image pngRGBDefault( + "../../../../tests/Interface/core/TestImage-RGB.png", dip::DIP_RGB); + // CHECK: PNG + fprintf(stderr, "%s\n", pngRGBDefault.getFormatName().c_str()); + // CHECK: 224 + fprintf(stderr, "%ld\n", pngRGBDefault.getWidth()); + // CHECK: 224 + fprintf(stderr, "%ld\n", pngRGBDefault.getHeight()); + // CHECK: 8 + fprintf(stderr, "%d\n", pngRGBDefault.getBitDepth()); + // CHECK: 144 + fprintf(stderr, "%f\n", pngRGBDefault.getData()[0]); + + dip::Image pngRGBNorm( + "../../../../tests/Interface/core/TestImage-RGB.png", dip::DIP_RGB, + true /* norm */); + // CHECK: PNG + fprintf(stderr, "%s\n", pngRGBNorm.getFormatName().c_str()); + // CHECK: 224 + fprintf(stderr, "%ld\n", pngRGBNorm.getWidth()); + // CHECK: 224 + fprintf(stderr, "%ld\n", pngRGBNorm.getHeight()); + // CHECK: 8 + fprintf(stderr, "%d\n", pngRGBNorm.getBitDepth()); + // CHECK: 0.5647 + fprintf(stderr, "%f\n", pngRGBNorm.getData()[0]); + + return 0; +} diff --git a/tests/Interface/core/lit.local.cfg b/tests/Interface/core/lit.local.cfg index 83f1696f7c..f5c2722550 100644 --- a/tests/Interface/core/lit.local.cfg +++ b/tests/Interface/core/lit.local.cfg @@ -2,4 +2,8 @@ if config.buddy_enable_opencv != 'ON': config.excludes.add('ImageContainerTest.cpp') if config.buddy_mlir_enable_dip_lib != 'ON': - config.excludes.add('NewImageContainerTest.cpp') \ No newline at end of file + config.excludes.add('NewImageContainerTestBmp.cpp') + config.excludes.add('NewImageContainerTestPng.cpp') + +if config.buddy_enable_png != 'ON': + config.excludes.add('NewImageContainerTestPng.cpp') diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py index 0ff7379027..2982e2851f 100644 --- a/tests/lit.cfg.py +++ b/tests/lit.cfg.py @@ -108,6 +108,8 @@ tools.append("buddy-image-container-test") if config.buddy_mlir_enable_dip_lib == "ON": - tools.append("buddy-new-image-container-test") + tools.append("buddy-new-image-container-test-bmp") + if config.buddy_enable_png == "ON": + tools.append("buddy-new-image-container-test-png") llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tests/lit.site.cfg.py.in b/tests/lit.site.cfg.py.in index 5f2d7276b0..0011f056f7 100644 --- a/tests/lit.site.cfg.py.in +++ b/tests/lit.site.cfg.py.in @@ -34,6 +34,7 @@ config.buddy_src_root = "@CMAKE_SOURCE_DIR@" config.buddy_obj_root = "@CMAKE_BINARY_DIR@" config.buddy_tools_dir = 
"@BUDDY_BINARY_DIR@" config.buddy_enable_opencv = "@BUDDY_ENABLE_OPENCV@" +config.buddy_enable_png = "@BUDDY_ENABLE_PNG@" config.buddy_mlir_enable_dip_lib = "@BUDDY_MLIR_ENABLE_DIP_LIB@" config.buddy_mlir_enable_python_packages = "@BUDDY_MLIR_ENABLE_PYTHON_PACKAGES@" config.buddy_python_packages_dir = "@BUDDY_MLIR_PYTHON_PACKAGES_DIR@" From 5011b33a790d4bc89f9cec7868c5f14796c99a48 Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Wed, 16 Oct 2024 10:56:52 +0000 Subject: [PATCH 09/17] [examples] Update MobileNet example doc. --- examples/BuddyMobileNetV3/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/BuddyMobileNetV3/README.md b/examples/BuddyMobileNetV3/README.md index 984fe9e306..a55cd74304 100644 --- a/examples/BuddyMobileNetV3/README.md +++ b/examples/BuddyMobileNetV3/README.md @@ -16,7 +16,8 @@ $ cmake -G Ninja .. \ -DCMAKE_BUILD_TYPE=RELEASE \ -DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \ -DPython3_EXECUTABLE=$(which python3) \ - -DBUDDY_MLIR_ENABLE_DIP_LIB=ON + -DBUDDY_MLIR_ENABLE_DIP_LIB=ON \ + -DBUDDY_ENABLE_PNG=ON $ ninja $ ninja check-buddy ``` From ec6860494f32e819f4caf1a7a2c32bd5593e1c46 Mon Sep 17 00:00:00 2001 From: Wu Xintong <56297184+WuXintong123@users.noreply.github.com> Date: Wed, 16 Oct 2024 19:34:41 +0800 Subject: [PATCH 10/17] [examples] Remove the num_batches_tracked attribute for MobileNet example. (#407) --- examples/BuddyMobileNetV3/CMakeLists.txt | 1 - .../buddy-mobilenetv3-import.py | 21 ++++-- .../buddy-mobilenetv3-main.cpp | 65 +++++++++---------- 3 files changed, 45 insertions(+), 42 deletions(-) diff --git a/examples/BuddyMobileNetV3/CMakeLists.txt b/examples/BuddyMobileNetV3/CMakeLists.txt index c5a932c673..ef60c7e931 100644 --- a/examples/BuddyMobileNetV3/CMakeLists.txt +++ b/examples/BuddyMobileNetV3/CMakeLists.txt @@ -1,6 +1,5 @@ add_custom_command( OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/arg0.data - ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/arg1.data ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/forward.mlir ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/subgraph0.mlir COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/buddy-mobilenetv3-import.py diff --git a/examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py b/examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py index 2403800bf9..704b8fc2e3 100644 --- a/examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py +++ b/examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py @@ -38,9 +38,17 @@ "The environment variable 'MOBILENETV3_MODEL_PATH' is not set or is invalid." ) -model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1, pretrained=True) +model = models.mobilenet_v3_small( + weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1, pretrained=True +) model = model.eval() +# Remove the num_batches_tracked attribute. +for layer in model.modules(): + if isinstance(layer, torch.nn.BatchNorm2d): + if hasattr(layer, "num_batches_tracked"): + del layer.num_batches_tracked + # Initialize Dynamo Compiler with specific configurations as an importer. 
dynamo_compiler = DynamoCompiler( primary_registry=tosa.ops_registry, @@ -68,11 +76,10 @@ float32_param = np.concatenate( - [param.detach().numpy().reshape([-1]) for param in params if param.dtype == torch.float32] + [ + param.detach().numpy().reshape([-1]) + for param in params + if param.dtype == torch.float32 + ] ) float32_param.tofile(Path(current_path) / "arg0.data") - -int64_param = np.concatenate( - [param.detach().numpy().reshape([-1]) for param in params if param.dtype == torch.int64] -) -int64_param.tofile(Path(current_path) / "arg1.data") diff --git a/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp b/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp index a9eb1a2aa1..90defb895e 100644 --- a/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp +++ b/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp @@ -33,43 +33,43 @@ const std::string ImgName = "dog.png"; // Declare the mobilenet C interface. extern "C" void _mlir_ciface_forward(MemRef *output, MemRef *arg0, - MemRef *arg1, MemRef *input); /// Print [Log] label in bold blue format. void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } -void loadParameters(const std::string &floatParamPath, - const std::string &int64ParamPath, - MemRef &floatParam, - MemRef &int64Param) { - std::ifstream floatParamFile(floatParamPath, std::ios::in | std::ios::binary); - if (!floatParamFile.is_open()) { - std::string errMsg = "Failed to open float param file: " + - std::filesystem::canonical(floatParamPath).string(); - throw std::runtime_error(errMsg); +/// Load parameters into data container. +void loadParameters(const std::string ¶mFilePath, + MemRef ¶ms) { + const auto loadStart = std::chrono::high_resolution_clock::now(); + // Open the parameter file in binary mode. + std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary); + if (!paramFile.is_open()) { + throw std::runtime_error("[Error] Failed to open params file!"); } - floatParamFile.read(reinterpret_cast(floatParam.getData()), - floatParam.getSize() * sizeof(float)); - if (floatParamFile.fail()) { - throw std::runtime_error("Failed to read float param file"); + printLogLabel(); + std::cout << "Loading params..." << std::endl; + printLogLabel(); + // Print the canonical path of the parameter file. + std::cout << "Params file: " << std::filesystem::canonical(paramFilePath) + << std::endl; + // Read the parameter data into the provided memory reference. + paramFile.read(reinterpret_cast(params.getData()), + sizeof(float) * (params.getSize())); + if (paramFile.fail()) { + throw std::runtime_error("Error occurred while reading params file!"); } - floatParamFile.close(); - - std::ifstream int64ParamFile(int64ParamPath, std::ios::in | std::ios::binary); - if (!int64ParamFile.is_open()) { - std::string errMsg = "Failed to open int64 param file: " + - std::filesystem::canonical(int64ParamPath).string(); - throw std::runtime_error(errMsg); - } - int64ParamFile.read(reinterpret_cast(int64Param.getData()), - int64Param.getSize() * sizeof(long long)); - if (int64ParamFile.fail()) { - throw std::runtime_error("Failed to read int64 param file"); - } - int64ParamFile.close(); + paramFile.close(); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "Params load time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; } + // Softmax function. 
void softmax(float *input, size_t size) { size_t i; @@ -124,13 +124,10 @@ int main() { // Load model parameters from the specified file. std::string paramsDir = mobilenetDir + "/arg0.data"; - std::string intDir = mobilenetDir + "/arg1.data"; - MemRef paramsContainerf32({ParamsSize}); - MemRef ParamsContainerInt64({34}); - loadParameters(paramsDir, intDir, paramsContainerf32, ParamsContainerInt64); + MemRef paramsContainer({ParamsSize}); + loadParameters(paramsDir, paramsContainer); // Call the forward function of the model. - _mlir_ciface_forward(&output, ¶msContainerf32, &ParamsContainerInt64, - &inputResize); + _mlir_ciface_forward(&output, ¶msContainer, &inputResize); auto out = output.getData(); softmax(out, 1000); From e154c3aa926415c54abbeec1b394355563e1922b Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Fri, 18 Oct 2024 02:22:15 +0000 Subject: [PATCH 11/17] [examples] Update vectorization iteration pattern. --- examples/MLIRVector/vector-iteration.mlir | 28 ++++++++++------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/examples/MLIRVector/vector-iteration.mlir b/examples/MLIRVector/vector-iteration.mlir index ba10f27ce0..7d63f22896 100644 --- a/examples/MLIRVector/vector-iteration.mlir +++ b/examples/MLIRVector/vector-iteration.mlir @@ -60,16 +60,14 @@ func.func @main() -> i32 { %load_vec2 = vector.load %mem_pat_1[%i] : memref<10xf32>, vector<4xf32> %res = arith.addf %load_vec1, %load_vec2 : vector<4xf32> vector.store %res, %mem_pat_1[%i] : memref<10xf32>, vector<4xf32> - scf.yield %i : index + %i_next = arith.addi %i, %vl_step_pat_1 : index + scf.yield %i_next : index } // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9] call @printMemrefF32(%print_mem_pat_1) : (memref<*xf32>) -> () - // 5. Calculate the position for tail processing. - %tail_idx_pat_1 = arith.addi %iter_idx_pat_1, %vl_step_pat_1 : index - - // 6. Process the remainder of the elements with scalar operations. - scf.for %i = %tail_idx_pat_1 to %vl_total_pat_1 step %c1 { + // 5. Process the remainder of the elements with scalar operations. + scf.for %i = %iter_idx_pat_1 to %vl_total_pat_1 step %c1 { %ele1 = memref.load %mem_pat_1[%i] : memref<10xf32> %ele2 = memref.load %mem_pat_1[%i] : memref<10xf32> %res = arith.addf %ele1, %ele2 : f32 @@ -105,25 +103,23 @@ func.func @main() -> i32 { %load_vec2 = vector.load %mem_pat_2[%i] : memref<10xf32>, vector<4xf32> %res = arith.addf %load_vec1, %load_vec2 : vector<4xf32> vector.store %res, %mem_pat_2[%i] : memref<10xf32>, vector<4xf32> - scf.yield %i : index + %i_next = arith.addi %i, %vl_step_pat_1 : index + scf.yield %i_next : index } // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9] call @printMemrefF32(%print_mem_pat_2) : (memref<*xf32>) -> () - // 5. Calculate the position for tail processing. - %tail_idx_pat_2 = arith.addi %iter_idx_pat_2, %vl_step_pat_2 : index - - // 6. Compute the tail size and create mask and pass-through vector for the + // 5. Compute the tail size and create mask and pass-through vector for the // remaining elements. - %tail_size_pat_2 = arith.subi %vl_total_pat_2, %iter_idx_pat_2 :index + %tail_size_pat_2 = arith.subi %vl_total_pat_2, %iter_idx_pat_2 : index %mask_pat_2 = vector.create_mask %tail_size_pat_2 : vector<4xi1> %pass_thr_vec = arith.constant dense<0.> : vector<4xf32> - // 7. Process the remaining elements using masked vector operations. 
-  %ele1 = vector.maskedload %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
-  %ele2 = vector.maskedload %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  // 6. Process the remaining elements using masked vector operations.
+  %ele1 = vector.maskedload %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  %ele2 = vector.maskedload %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
   %res = arith.addf %ele1, %ele2 : vector<4xf32>
-  vector.maskedstore %mem_pat_2[%tail_idx_pat_2], %mask_pat_2, %res : memref<10xf32>, vector<4xi1>, vector<4xf32>
+  vector.maskedstore %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %res : memref<10xf32>, vector<4xi1>, vector<4xf32>
 
   // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
   call @printMemrefF32(%print_mem_pat_2) : (memref<*xf32>) -> ()

From 4dfcb7ee2babd5fac6231c35b9decf1a68f01838 Mon Sep 17 00:00:00 2001
From: somehow6
Date: Fri, 18 Oct 2024 15:27:57 +0800
Subject: [PATCH 12/17] [Examples/MLIRLinalg] Fix input arguments in
 linalg-conv2d_nhwc_fhwc.mlir

---
 .../MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir
index a06c0d562d..ea81007153 100644
--- a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir
+++ b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir
@@ -40,7 +40,7 @@ module {
     %current_image_h = arith.constant 4 : index
     %current_image_w = arith.constant 4 : index
 
-    %current_filter_f = arith.constant 1 : index
+    %current_filter_f = arith.constant 2 : index
     %current_filter_c = arith.constant 2 : index
     %current_filter_h = arith.constant 2 : index
     %current_filter_w = arith.constant 2 : index
@@ -71,12 +71,12 @@
 }
 
 // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 3, 3, 2] strides = [18, 6, 2, 1] data =
-// CHECK{LITERAL}: [[[[4, 1],
-// CHECK{LITERAL}: [4, 1],
-// CHECK{LITERAL}: [4, 1]],
-// CHECK{LITERAL}: [[4, 1],
-// CHECK{LITERAL}: [4, 1],
-// CHECK{LITERAL}: [4, 1]],
-// CHECK{LITERAL}: [[4, 1],
-// CHECK{LITERAL}: [4, 1],
-// CHECK{LITERAL}: [4, 1]]]]
+// CHECK{LITERAL}: [[[[4, 5],
+// CHECK{LITERAL}: [4, 5],
+// CHECK{LITERAL}: [4, 5]],
+// CHECK{LITERAL}: [[4, 5],
+// CHECK{LITERAL}: [4, 5],
+// CHECK{LITERAL}: [4, 5]],
+// CHECK{LITERAL}: [[4, 5],
+// CHECK{LITERAL}: [4, 5],
+// CHECK{LITERAL}: [4, 5]]]]

From c1175b74e7459bea3bb89060371d1336571e6240 Mon Sep 17 00:00:00 2001
From: zhanghb97
Date: Sat, 19 Oct 2024 11:21:13 +0000
Subject: [PATCH 13/17] [examples] Add VectorExp iteration pattern.
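
The pattern keeps the remaining effective vector length as an iter_args
value of the scf.for loop and guards each memory access with
vector_exp.predication, so the tail iteration runs with a shortened
vector length instead of a separate scalar epilogue. The guarded loads
and stores in the added example have this shape:

    %load_vec1 = vector_exp.predication %mask, %iter_vl_i32 : vector<[2]xi1>, i32 {
      %ele = vector.load %mem[%i] : memref<10xf32>, vector<[2]xf32>
      vector.yield %ele : vector<[2]xf32>
    } : vector<[2]xf32>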
--- examples/VectorExpDialect/makefile | 21 +++++++ .../vector-exp-iteration.mlir | 57 +++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 examples/VectorExpDialect/vector-exp-iteration.mlir diff --git a/examples/VectorExpDialect/makefile b/examples/VectorExpDialect/makefile index ab85a8a2cc..fc88556419 100644 --- a/examples/VectorExpDialect/makefile +++ b/examples/VectorExpDialect/makefile @@ -319,3 +319,24 @@ vector-exp-dynamic-vector-run: -L${CROSS_MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ -o a.out @LD_LIBRARY_PATH=${CROSS_MLIR_LIB} ${QEMU} -L ${RISCV_GNU_TOOLCHAIN_SYSROOT} -cpu max a.out + +vector-exp-iteration-aot: + @${BUDDY_OPT} ./vector-exp-iteration.mlir \ + -lower-vector-exp \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-index-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir -o log.ll + ${LOCAL_CLANG} -O3 log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu -fPIC \ + --sysroot=${RISCV_GNU_TOOLCHAIN}/sysroot \ + --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -L${CROSS_MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ + -o a.out + diff --git a/examples/VectorExpDialect/vector-exp-iteration.mlir b/examples/VectorExpDialect/vector-exp-iteration.mlir new file mode 100644 index 0000000000..2606596988 --- /dev/null +++ b/examples/VectorExpDialect/vector-exp-iteration.mlir @@ -0,0 +1,57 @@ +memref.global "private" @gv : memref<10xf32> = dense<[0. , 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9.]> + +func.func private @printMemrefF32(memref<*xf32>) + +func.func @main() -> i32 { + %c0 = arith.constant 0 : index + + // --------------------------------------------------------------------------- + // Iteration Pattern for RVV Dynamic Vector Length + // --------------------------------------------------------------------------- + + // 1. Get the total length of the workload. + %mem = memref.get_global @gv : memref<10xf32> + %print_mem = memref.cast %mem : memref<10xf32> to memref<*xf32> + %vl_total = memref.dim %mem, %c0 : memref<10xf32> + + // 2. Set the scale factor, iteration step, and mask. + %vs = vector.vscale + %factor = arith.constant 2 : index + %vl_step = arith.muli %vs, %factor : index + %mask = arith.constant dense<1> : vector<[2]xi1> + %vl_total_i32 = index.casts %vl_total : index to i32 + %vl_step_i32 = index.casts %vl_step : index to i32 + + // 3. Perform the vectorization. + %iter_vl = scf.for %i = %c0 to %vl_total step %vl_step + iter_args(%iter_vl_i32 = %vl_total_i32) -> (i32) { + + %load_vec1 = vector_exp.predication %mask, %iter_vl_i32 : vector<[2]xi1>, i32 { + %ele = vector.load %mem[%i] : memref<10xf32>, vector<[2]xf32> + vector.yield %ele : vector<[2]xf32> + } : vector<[2]xf32> + + %load_vec2 = vector_exp.predication %mask, %iter_vl_i32 : vector<[2]xi1>, i32 { + %ele = vector.load %mem[%i] : memref<10xf32>, vector<[2]xf32> + vector.yield %ele : vector<[2]xf32> + } : vector<[2]xf32> + + %res = "llvm.intr.vp.fadd" (%load_vec1, %load_vec2, %mask, %iter_vl_i32) : + (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xi1>, i32) -> vector<[2]xf32> + + vector_exp.predication %mask, %iter_vl_i32 : vector<[2]xi1>, i32 { + vector.store %res, %mem[%i] : memref<10xf32>, vector<[2]xf32> + vector.yield + } : () -> () + + // Update dynamic vector length. 
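+    // The i32 yielded here feeds the predicated operations on the next
+    // trip, so the last iteration only touches the elements that remain.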
+ %new_vl = arith.subi %vl_total_i32, %vl_step_i32 : i32 + scf.yield %new_vl : i32 + } + + // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9] + call @printMemrefF32(%print_mem) : (memref<*xf32>) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} From 49f4410433e0e8b0a7ceb0d24c5a1bca43935e3b Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Sat, 19 Oct 2024 11:33:35 +0000 Subject: [PATCH 14/17] [examples] Update VectorExp iteration value. --- examples/VectorExpDialect/vector-exp-iteration.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/VectorExpDialect/vector-exp-iteration.mlir b/examples/VectorExpDialect/vector-exp-iteration.mlir index 2606596988..bc879d0103 100644 --- a/examples/VectorExpDialect/vector-exp-iteration.mlir +++ b/examples/VectorExpDialect/vector-exp-iteration.mlir @@ -45,7 +45,7 @@ func.func @main() -> i32 { } : () -> () // Update dynamic vector length. - %new_vl = arith.subi %vl_total_i32, %vl_step_i32 : i32 + %new_vl = arith.subi %iter_vl_i32, %vl_step_i32 : i32 scf.yield %new_vl : i32 } From b91e06460f14ccef4640a7580c8c254358442f1e Mon Sep 17 00:00:00 2001 From: RISHIK RAM <44949025+markram1729@users.noreply.github.com> Date: Tue, 22 Oct 2024 07:44:24 +0530 Subject: [PATCH 15/17] [Container] Add OpenCV test for ImageContainer (#303) --- .../Interfaces/buddy/DIP/ImageContainer.h | 8 +- tests/Interface/core/ImageContainerTest.cpp | 206 ++++++++++++++++-- 2 files changed, 197 insertions(+), 17 deletions(-) diff --git a/frontend/Interfaces/buddy/DIP/ImageContainer.h b/frontend/Interfaces/buddy/DIP/ImageContainer.h index a613ceb351..4470e4a443 100644 --- a/frontend/Interfaces/buddy/DIP/ImageContainer.h +++ b/frontend/Interfaces/buddy/DIP/ImageContainer.h @@ -141,6 +141,7 @@ Img::Img(T *data, intptr_t sizes[N]) : MemRef(data, sizes) {} #ifdef BUDDY_ENABLE_OPENCV // Image Constructor from OpenCV Mat. 
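+// Handles both 2-D grayscale Mats (channels() == 1) and 4-D NCHW blobs such
+// as those from cv::dnn::blobFromImages; when norm is set, pixel values are
+// divided by 255 so the stored data lies in [0, 1].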
+ template Img::Img(cv::Mat image, intptr_t sizes[N], bool norm) : MemRef() { if (image.channels() == 1) { @@ -189,14 +190,16 @@ Img::Img(cv::Mat image, intptr_t sizes[N], bool norm) : MemRef() { this->allocated = new T[size]; this->aligned = this->allocated; size_t k = 0; + //NCHW Layout for (int batch = 0; batch < this->sizes[0]; batch++) { for (int channel = 0; channel < this->sizes[1]; channel++) { + T *chandata = image.ptr(batch, channel); for (int row = 0; row < this->sizes[2]; row++) { for (int col = 0; col < this->sizes[3]; col++) { if (norm) { - this->aligned[k] = (T)image.at(row, col) / 255; + this->aligned[k] = chandata[row * this->sizes[3] + col] / 255; } else { - this->aligned[k] = (T)image.at(row, col); + this->aligned[k] = chandata[row * this->sizes[3] + col]; } k++; } @@ -205,6 +208,7 @@ Img::Img(cv::Mat image, intptr_t sizes[N], bool norm) : MemRef() { } } } + #endif template int Img::channels() { diff --git a/tests/Interface/core/ImageContainerTest.cpp b/tests/Interface/core/ImageContainerTest.cpp index 442f79ca6c..f84bc4237f 100644 --- a/tests/Interface/core/ImageContainerTest.cpp +++ b/tests/Interface/core/ImageContainerTest.cpp @@ -24,6 +24,54 @@ #include #include +bool compare_flt(float a, float b) { return (std::abs(a - b) < FLT_EPSILON); } + +template +bool testImgcvnorm(cv::Mat testImgcv, Img testImg, bool norm = false, + intptr_t sizes[N] = nullptr) { + int cvn = testImgcv.dims; + if (cvn != N) + return false; + for (size_t i = 0; i < N; ++i) { + if (testImgcv.size[i] != testImg.getSizes()[i]) + return false; + } + T *data = testImg.getData(); + if (N == 2) { + size_t k = 0; + for (int i = 0; i < testImg.getSizes()[0]; ++i) { + for (int j = 0; j < testImg.getSizes()[1]; ++j) { + if (norm ? !compare_flt(data[k], (T)testImgcv.at(i, j)) + : !compare_flt(data[k], (T)testImgcv.at(i, j))) + return false; + + ++k; + } + } + return true; + } else if (N == 4) { + if (sizes == nullptr) { + return false; + } + size_t k = 0; + // NCHW layout + for (size_t batch = 0; batch < sizes[0]; ++batch) { + for (size_t channel = 0; channel < sizes[1]; ++channel) { + T *chandata = testImgcv.ptr(batch, channel); + for (size_t row = 0; row < sizes[2]; ++row) { + for (size_t col = 0; col < sizes[3]; ++col) { + if (!compare_flt(data[k], chandata[row * sizes[3] + col])) + return false; + + ++k; + } + } + } + } + return true; + } +} + int main() { // The original test image is a gray scale image, and the pixel values are as // follows: @@ -33,7 +81,7 @@ int main() { // 195.0, 210.0, 225.0, 240.0 // The test running directory is in /tests/Interface/core, so the // `imread` function uses the following relative path. - + //===--------------------------------------------------------------------===// // Test bmp format image. 
//===--------------------------------------------------------------------===// @@ -75,12 +123,10 @@ int main() { fprintf(stderr, "%ld\n", testCopyConstructor2.getSize()); // CHECK: 60.0 fprintf(stderr, "%f\n", testCopyConstructor2[3]); - Img testCopyConstructor3 = - Img(grayimage_bmp); + Img testCopyConstructor3 = Img(grayimage_bmp); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor3[0]); - Img *testCopyConstructor4 = - new Img(grayimage_bmp); + Img *testCopyConstructor4 = new Img(grayimage_bmp); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor4->getData()[0]); delete testCopyConstructor4; @@ -132,7 +178,50 @@ int main() { const Img testBracketOperator2(grayimage_bmp); // CHECK: 240.0 fprintf(stderr, "%f\n", testBracketOperator2[15]); + //===--------------------------------------------------------------------===// + // Test Opencv Image without norm + //===--------------------------------------------------------------------===// + cv::Mat testImgcvbmp = + cv::imread("../../../../tests/Interface/core/TestGrayImage.bmp", + cv::IMREAD_GRAYSCALE); + Img testImgbmp(testImgcvbmp); + bool testbmp = testImgcvnorm(testImgcvbmp, testImgbmp); + // CHECK: 1 + fprintf(stderr, "%d \n", testbmp); + //===--------------------------------------------------------------------===// + // Test Opencv Image with norm + //===--------------------------------------------------------------------===// + Img testImgbmpnorm(testImgcvbmp, nullptr, true); + cv::Mat checkimgbmp(testImgcvbmp.rows, testImgcvbmp.cols, CV_32FC1); + testImgcvbmp.convertTo(checkimgbmp, CV_32FC1, 1.f / 255); + bool testbmp1 = testImgcvnorm(checkimgbmp, testImgbmpnorm, true); + // CHECK: 1 + fprintf(stderr, "%d \n", testbmp1); + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) without norm (NCHW) + //===--------------------------------------------------------------------===// + std::vector testbmpvec = {testImgcvbmp, testImgcvbmp}; + cv::Mat testcvbmpblob = cv::dnn::blobFromImages( + testbmpvec, 1.0, cv::Size(testImgcvbmp.rows, testImgcvbmp.cols)); + intptr_t sizesbmp[4] = {testcvbmpblob.size[0], testcvbmpblob.size[1], + testcvbmpblob.size[2], testcvbmpblob.size[3]}; + Img testImgbmpblob(testcvbmpblob, sizesbmp, false); + bool testbmpN4 = + testImgcvnorm(testcvbmpblob, testImgbmpblob, false, sizesbmp); + // CHECK: 1 + fprintf(stderr, "%d \n", testbmpN4); + + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) with norm (NCHW) + //===--------------------------------------------------------------------===// + cv::Mat testcvbmpblob2 = cv::dnn::blobFromImages( + testbmpvec, 1.0f / 255.0, cv::Size(testImgcvbmp.rows, testImgcvbmp.cols)); + Img testImgbmpblobnorm(testcvbmpblob, sizesbmp, true); + bool testbmpN4norm = testImgcvnorm( + testcvbmpblob2, testImgbmpblobnorm, true, sizesbmp); + // CHECK: 1 + fprintf(stderr, "%d \n", testbmpN4norm); //===--------------------------------------------------------------------===// // Test jpeg format image. 
@@ -175,12 +264,10 @@ int main() { fprintf(stderr, "%ld\n", testCopyConstructor6.getSize()); // CHECK: 60.0 fprintf(stderr, "%f\n", testCopyConstructor6[3]); - Img testCopyConstructor7 = - Img(grayimage_jpg); + Img testCopyConstructor7 = Img(grayimage_jpg); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor7[0]); - Img *testCopyConstructor8 = - new Img(grayimage_jpg); + Img *testCopyConstructor8 = new Img(grayimage_jpg); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor8->getData()[0]); delete testCopyConstructor8; @@ -233,6 +320,51 @@ int main() { // CHECK: 240.0 fprintf(stderr, "%f\n", testBracketOperator4[15]); + //===--------------------------------------------------------------------===// + // Test Opencv Image without norm + //===--------------------------------------------------------------------===// + cv::Mat testImgcvjpg = + cv::imread("../../../../tests/Interface/core/TestGrayImage.jpg", + cv::IMREAD_GRAYSCALE); + Img testImgjpg(testImgcvjpg); + bool testjpg = testImgcvnorm(testImgcvjpg, testImgjpg); + // CHECK: 1 + fprintf(stderr, "%d \n", testjpg); + + //===--------------------------------------------------------------------===// + // Test Opencv Image with norm + //===--------------------------------------------------------------------===// + Img testImgjpgnorm(testImgcvjpg, nullptr, true); + cv::Mat checkimgjpg(testImgcvjpg.rows, testImgcvjpg.cols, CV_32FC1); + testImgcvjpg.convertTo(checkimgjpg, CV_32FC1, 1.f / 255); + bool testjpg1 = testImgcvnorm(checkimgjpg, testImgjpgnorm, true); + // CHECK: 1 + fprintf(stderr, "%d \n", testjpg1); + + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) without norm (NCHW) + //===--------------------------------------------------------------------===// + std::vector testjpgvec = {testImgcvjpg, testImgcvjpg}; + cv::Mat testcvjpgblob = cv::dnn::blobFromImages( + testjpgvec, 1.0, cv::Size(testImgcvjpg.rows, testImgcvjpg.cols)); + intptr_t sizesjpg[4] = {testcvjpgblob.size[0], testcvjpgblob.size[1], + testcvjpgblob.size[2], testcvjpgblob.size[3]}; + Img testImgjpgblob(testcvjpgblob, sizesjpg, false); + bool testjpgN4 = + testImgcvnorm(testcvjpgblob, testImgjpgblob, false, sizesjpg); + // CHECK: 1 + fprintf(stderr, "%d \n", testjpgN4); + + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) with norm (NCHW) + //===--------------------------------------------------------------------===// + cv::Mat testcvjpgblob2 = cv::dnn::blobFromImages( + testjpgvec, 1.0f / 255.0, cv::Size(testImgcvjpg.rows, testImgcvjpg.cols)); + Img testImgjpgblobnorm(testcvjpgblob, sizesjpg, true); + bool testjpgN4norm = testImgcvnorm( + testcvjpgblob2, testImgjpgblobnorm, true, sizesjpg); + // CHECK: 1 + fprintf(stderr, "%d \n", testjpgN4norm); //===--------------------------------------------------------------------===// // Test png format image. 
@@ -275,12 +407,10 @@ int main() { fprintf(stderr, "%ld\n", testCopyConstructor10.getSize()); // CHECK: 60.0 fprintf(stderr, "%f\n", testCopyConstructor10[3]); - Img testCopyConstructor11 = - Img(grayimage_png); + Img testCopyConstructor11 = Img(grayimage_png); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor11[0]); - Img *testCopyConstructor12 = - new Img(grayimage_png); + Img *testCopyConstructor12 = new Img(grayimage_png); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor12->getData()[0]); delete testCopyConstructor12; @@ -332,6 +462,52 @@ int main() { const Img testBracketOperator6(grayimage_png); // CHECK: 240.0 fprintf(stderr, "%f\n", testBracketOperator6[15]); - + + //===--------------------------------------------------------------------===// + // Test Opencv Image without norm + //===--------------------------------------------------------------------===// + cv::Mat testImgcvpng = + cv::imread("../../../../tests/Interface/core/TestGrayImage.png", + cv::IMREAD_GRAYSCALE); + Img testImgpng(testImgcvpng); + bool testpng = testImgcvnorm(testImgcvpng, testImgpng); + /// CHECK: 1 + fprintf(stderr, "%d \n", testpng); + + //===--------------------------------------------------------------------===// + // Test Opencv Image with norm + //===--------------------------------------------------------------------===// + Img testImgpngnorm(testImgcvpng, nullptr, true); + cv::Mat checkimgpng(testImgcvpng.rows, testImgcvpng.cols, CV_32FC1); + testImgcvpng.convertTo(checkimgpng, CV_32FC1, 1.f / 255); + bool testpng1 = testImgcvnorm(checkimgpng, testImgpngnorm, true); + // CHECK: 1 + fprintf(stderr, "%d \n", testpng1); + + ///===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) without norm (NCHW) + //===--------------------------------------------------------------------===// + std::vector testpngvec = {testImgcvpng, testImgcvpng}; + cv::Mat testcvpngblob = cv::dnn::blobFromImages( + testpngvec, 1.0, cv::Size(testImgcvpng.rows, testImgcvpng.cols)); + intptr_t sizespng[4] = {testcvpngblob.size[0], testcvpngblob.size[1], + testcvpngblob.size[2], testcvpngblob.size[3]}; + Img testImgpngblob(testcvpngblob, sizespng, false); + bool testpngN4 = + testImgcvnorm(testcvpngblob, testImgpngblob, false, sizespng); + // CHECK: 1 + fprintf(stderr, "%d \n", testpngN4); + + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) with norm (NCHW) + //===--------------------------------------------------------------------===// + cv::Mat testcvpngblob2 = cv::dnn::blobFromImages( + testpngvec, 1.0f / 255.0, cv::Size(testImgcvpng.rows, testImgcvpng.cols)); + Img testImgpngblobnorm(testcvpngblob, sizespng, true); + bool testpngN4norm = testImgcvnorm( + testcvpngblob2, testImgpngblobnorm, true, sizespng); + // CHECK: 1 + fprintf(stderr, "%d \n", testpngN4norm); + return 0; -} \ No newline at end of file +} From 5d9d52b392967cb13c06b2b54a9ce28e73ef3cbb Mon Sep 17 00:00:00 2001 From: zhxzh-2001 <70198007+zhxzh-2001@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:01:19 +0800 Subject: [PATCH 16/17] add pass for linalg matmultransposeb op (#377) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add pass for linalg matmultransposeb op * add example for linalg matmultransposeb op and check the code layout * check the code layout * vector.transfer_read -> vector.load * fuse tanspose+matmul * fix problems about pass 
MatMulTransposeBVec * deal with conflicts * change pass name * fix problem for check-buddy * correct the format,and removed some unnecessary changes to the pass. --------- Co-authored-by: “username” <“email”> --- .../linalg-transposematmulb-f32.mlir | 75 ++++++ examples/BuddyMatmul/makefile | 18 ++ frontend/Python/graph/transform/fuse_ops.py | 6 +- .../MatMulOptimization/CMakeLists.txt | 8 + .../MatMulTransposeBVec.cpp | 214 ++++++++++++++++++ tools/buddy-opt/CMakeLists.txt | 7 + tools/buddy-opt/buddy-opt.cpp | 4 + 7 files changed, 328 insertions(+), 4 deletions(-) create mode 100644 examples/BuddyMatmul/linalg-transposematmulb-f32.mlir create mode 100644 midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp diff --git a/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir b/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir new file mode 100644 index 0000000000..26a4458c53 --- /dev/null +++ b/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir @@ -0,0 +1,75 @@ +// RUN: buddy-opt %s \ +// RUN: -matmul-transpose-b-vectorization \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -lower-affine \ +// RUN: -convert-vector-to-scf \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -expand-strided-metadata \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func.func private @printMemrefF32(memref<*xf32>) + +func.func @test(%a : memref, %b : memref, %c : memref) { + linalg.matmul_transpose_b + ins(%a, %b: memref, memref) + outs(%c: memref) + return + } + +func.func @alloc_f32(%arg0: index, %arg1: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + memref.store %arg4, %0[%idx0, %idx1] : memref + } + } + return %0 : memref +} + +func.func @main(){ + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %c3 = arith.constant 3 : index + %f0 = arith.constant 0.0 : f32 + %f1 = arith.constant 1.0 : f32 + + %m0 = call @alloc_f32(%c32,%c1024, %f1) : (index, index, f32) -> memref + %m1 = call @alloc_f32(%c32,%c1024, %f1) : (index, index, f32) -> memref + %m2 = call @alloc_f32(%c32,%c32, %f0) : (index, index, f32) -> memref + + call @test(%m0, %m1, %m2) : (memref, memref, memref) -> () + + %printed_m2 = memref.cast %m2 : memref to memref<*xf32> + + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [32, 32] strides = [32, 1] data = + // CHECK-NEXT: [ + // CHECK: [1024{{(, 1024)*}}] + call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> () + + %m3 = call @alloc_f32(%c3,%c3, %f1) : (index, index, f32) -> memref + %m4 = call @alloc_f32(%c3,%c3, %f1) : (index, index, f32) -> memref + %m5 = call @alloc_f32(%c3,%c3, %f0) : (index, index, f32) -> memref + + call @test(%m3, %m4, %m5) : (memref, memref, memref) -> () + + %printed_m5 = memref.cast %m5 : memref to memref<*xf32> + + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] data = + // CHECK-NEXT: [ + // CHECK: [3{{(, 3)*}}] + call @printMemrefF32(%printed_m5) : 
(memref<*xf32>) -> () + + return +} diff --git a/examples/BuddyMatmul/makefile b/examples/BuddyMatmul/makefile index 812e68b150..0940d608da 100644 --- a/examples/BuddyMatmul/makefile +++ b/examples/BuddyMatmul/makefile @@ -35,3 +35,21 @@ linalg-batchmatmul-f32-run: -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-matmul-transpose-b-f32-run: + @${BUDDY_OPT} ./linalg-transposematmulb-f32.mlir\ + -matmul-transpose-b-vectorization \ + -convert-linalg-to-affine-loops \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -expand-strided-metadata \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index 61f6a5b54a..ac7d34c99c 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -19,14 +19,14 @@ # ===--------------------------------------------------------------------------- from .. import Graph -from ..operation import PlaceholderOp, OpType +from ..operation import * from .. import DeviceType # TODO: classify op type for op fusion # OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType] # OP_TYPE_UNFUSABLE = [OpType.Unfusable, OpType.ConcatType] # OP_TYPE_FUSABLE_BY_SPECIFIC_PASS = [] -# ANCHOR_OP_TYPE = [] +# ANCHOR_OP_TYPE = [] def simply_fuse(graph: Graph): """ @@ -47,5 +47,3 @@ def simply_fuse(graph: Graph): graph.op_groups = {} graph.op_groups["subgraph0"] = new_op_group graph.group_map_device = {"subgraph0": device} - - diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt index 776a766b47..7ec2cf4ac4 100644 --- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt @@ -5,6 +5,10 @@ add_mlir_library(MatMulOptimization BatchMatMulOptimize.cpp BatchMatMulTileOptimize.cpp BatchMatMulSCFOptimize.cpp + MatMulTransposeBVec.cpp + BatchMatMulOptimize.cpp + BatchMatMulTileOptimize.cpp + BatchMatMulSCFOptimize.cpp LINK_LIBS PUBLIC BuddyUtils ) @@ -16,3 +20,7 @@ add_mlir_library(BatchMatMulOptimization add_mlir_library(MatMulParallelVectorization MatMulParallelVectorization.cpp ) + +add_mlir_library(MatMulTransposeBVec + MatMulTransposeBVec.cpp +) diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp new file mode 100644 index 0000000000..4500119d76 --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp @@ -0,0 +1,214 @@ +//===- MatMulTransposeBVec.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Matmul_TransposeB vectorization. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Utils/Utils.h" + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { +class MatMulTransposeBVecPattern : public ConversionPattern{ +public: + explicit MatMulTransposeBVecPattern(MLIRContext *context,int64_t vecSizeparam) + : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(),1,context){ + vecSize = vecSizeparam; + } + + LogicalResult + matchAndRewrite(Operation *op,ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override{ + auto loc = op->getLoc(); + auto ctx = op->getContext(); + // Get input A, B, C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Get shape of input and output. + ShapedType ATy = A.getType().cast(); + Type eleTy = ATy.getElementType(); + + // the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + + VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); + VectorType vectorMaskTy = VectorType::get({vecSize}, i1); + + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value step = rewriter.create(loc, vecSize); + + const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); + Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); + + const Value aRow = rewriter.create(loc, A, c0); + const Value bRow = rewriter.create(loc, B, c0); + const Value bCol = rewriter.create(loc, B, c1); + + AffineExpr d0; + bindDims(ctx, d0); + AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); + SmallVector lowerBounds(2, c0); + SmallVector uperBounds{aRow, bRow}; + SmallVector steps(2, 1); + // clang-format off + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create loop based on vector size. + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{bCol}, vecTailMap, 1, std::nullopt, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + AffineExpr a,b,c; + bindDims(ctx, a,b,c); + AffineMap AVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); + // Check tail. + AffineExpr m, n, k; + bindDims(ctx, m, n, k); + AffineMap BVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); + + // Calculate the tail. 
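+            // tailLen = bCol - iv * vecSize; a full vector fits only while
+            // tailLen >= vecSize, otherwise the masked branch below runs.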
+ Value bColCur = builder.create(loc, iv, step); + Value tailLen = builder.create(loc, bCol, bColCur); + Value tailFlag = rewriter.create( + loc, arith::CmpIPredicate::sge, tailLen, step); + // If the current column does not reach the tail. + builder.create(loc, tailFlag, + [&](OpBuilder &builder, Location loc) { + Value aVec = builder.create( + loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); + Value bVec = builder.create( + loc, vectorTy, B, BVectorMap, ValueRange{ivs[1], ivs[1], iv}); + Value resvec = builder.create(loc,aVec,bVec); + Value res1 = builder.create( + loc,mlir::vector::CombiningKind::ADD,resvec); + Value res2 = builder.create( + loc, C, ValueRange{ivs[0], ivs[1]}); + Value sum = builder.create(loc, res1, res2); + builder.create(loc, sum, + C, ValueRange{ivs[0], ivs[1]}); + builder.create(loc); + }, + // The else branch + [&](OpBuilder &builder, Location loc) { + Value aVec = builder.create( + loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); + // Create mask according to the tail. + Value maskVec = builder.create( + loc, vectorMaskTy, tailLen); + Value ColIdxTail = builder.create(loc, iv, step); + + Value aVecTail = builder.create( + loc, vectorTy, A, ValueRange{ivs[0], ColIdxTail}, + maskVec, passthruVec); + + Value bVecTail = builder.create( + loc, vectorTy, B, ValueRange{ivs[1], ColIdxTail}, + maskVec, passthruVec); + + Value resvec = builder.create(loc,aVecTail,bVecTail); + Value res1 = builder.create( + loc,mlir::vector::CombiningKind::ADD,resvec); + Value res2 = builder.create( + loc, C, ValueRange{ivs[0], ivs[1]}); + Value sum = builder.create(loc, res1, res2); + builder.create(loc, sum, C, ValueRange{ivs[0], ivs[1]}); + builder.create(loc); + }); + builder.create(loc); + }); + }); + // clang-format on + rewriter.eraseOp(op); + return success(); + } +private: + int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// MatMulVectorizationPass +//===----------------------------------------------------------------------===// + +namespace{ + class MatMulTransposeBVecPass + :public PassWrapper>{ +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulTransposeBVecPass) + StringRef getArgument() const final{ return "matmul-transpose-b-vectorization"; } + StringRef getDescription() const final { return "vectorize linalg MatmulTransposeBOp"; } + MatMulTransposeBVecPass() = default; + MatMulTransposeBVecPass(const MatMulTransposeBVecPass &) {} + void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override{ + registry.insert(); + } + Option vecSize{*this,"vec-size", + llvm::cl::desc("The size of vectorization"), + llvm::cl::init(32)}; + +}; +} + +void MatMulTransposeBVecPass::runOnOperation(){ + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context,vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerMatMulTransposeBVecPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt index 325e1a20d0..0abb857fad 100644 --- a/tools/buddy-opt/CMakeLists.txt +++ b/tools/buddy-opt/CMakeLists.txt @@ -44,4 +44,11 @@ 
target_link_libraries(buddy-opt MLIRTestTransformDialect MLIRTransforms MLIRTransformUtils + MatMulTransposeBVec + MLIRGPUPasses + BuddyGPUTransformOPs + MLIRTestTransforms + MLIRTestTransformDialect + MLIRTransforms + MLIRTransformUtils ) diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index 50f65ba0a1..08e172f8bc 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -80,6 +80,9 @@ void registerLowerSchePass(); void registerFuncBufferizeDynamicOffsetPass(); void registerConvertMemcpyToGPUPass(); void registerLegalizeShmemOutliningPass(); +void registerMatMulTransposeBVecPass(); +void registerConvertMemcpyToGPUPass(); +void registerLegalizeShmemOutliningPass(); } // namespace buddy } // namespace mlir @@ -117,6 +120,7 @@ int main(int argc, char **argv) { mlir::buddy::registerDeviceSchedulePass(); mlir::buddy::registerLowerSchePass(); mlir::buddy::registerFuncBufferizeDynamicOffsetPass(); + mlir::buddy::registerMatMulTransposeBVecPass(); // Register gpu passes mlir::buddy::registerConvertMemcpyToGPUPass(); From 2b2a8dfb54bb00263e59e7196e99d41b10332cc3 Mon Sep 17 00:00:00 2001 From: BrokenArrow Date: Tue, 22 Oct 2024 11:28:36 +0800 Subject: [PATCH 17/17] [Midend] Add RFFT op in Extend DAP Pass (#387) --- examples/DAPDialect/CMakeLists.txt | 7 + examples/DAPDialect/RFFT.cpp | 75 + .../buddy/DAP/DSP/WhisperPreprocess.h | 7 + frontend/Interfaces/lib/DAP-extend.mlir | 4 + midend/include/Dialect/DAP/DAPOps.td | 6 +- .../Conversion/ExtendDAP/ExtendDAPPass.cpp | 2960 +++++++++++++++-- 6 files changed, 2715 insertions(+), 344 deletions(-) create mode 100644 examples/DAPDialect/RFFT.cpp diff --git a/examples/DAPDialect/CMakeLists.txt b/examples/DAPDialect/CMakeLists.txt index dff9b10ffb..96b921ee3a 100644 --- a/examples/DAPDialect/CMakeLists.txt +++ b/examples/DAPDialect/CMakeLists.txt @@ -62,3 +62,10 @@ target_link_libraries(buddy-whisper-preprocess BuddyLibDAP mlir_c_runner_utils ) + +add_executable(buddy-rfft RFFT.cpp) +add_dependencies(buddy-rfft buddy-opt) +target_link_libraries(buddy-rfft + BuddyLibDAP + mlir_c_runner_utils +) diff --git a/examples/DAPDialect/RFFT.cpp b/examples/DAPDialect/RFFT.cpp new file mode 100644 index 0000000000..993fec95e1 --- /dev/null +++ b/examples/DAPDialect/RFFT.cpp @@ -0,0 +1,75 @@ +//===- RFFT.cpp - Example of DAP RFFT Operation ---------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// An example of the RFFT function from Whisper Preprocessor operation. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#define testLength 840 + +using namespace dap; +using namespace std; + +// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + +// Write preprocessing results to a text file. 
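+// The transform is computed in place, so after dap::RFFT the input buffer
+// holds the frequency-domain result that is written out here.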
+void printResult(MemRef &outputMemRef) { + ofstream fout("whisperPreprocessResultRFFT.txt"); + // Print title. + fout << "-----------------------------------------" << std::endl; + fout << "[ Buddy RFFT Result ]" << std::endl; + fout << "-----------------------------------------" << std::endl; + // Print reuslt data. + for (int i = 0; i < testLength; ++i) { + fout << outputMemRef[i] << std::endl; + } + fout.close(); +} + +int main() { + // Print the title of this example. + const std::string title = "RFFT Operation Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + double *inputAlign = new double[testLength]; + for (int i = 0; i < testLength; ++i) { + inputAlign[i] = static_cast(i); + } + intptr_t inputSizes[1] = {testLength}; + MemRef inputMemRef(inputAlign, inputSizes); + + printLogLabel(); + std::cout << "Running RFFT operation" << std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::RFFT(&inputMemRef); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "RFFT time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; + + printResult(inputMemRef); + + return 0; +} diff --git a/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h b/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h index a6c3ef3b2e..d0d1d8fb63 100644 --- a/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h +++ b/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h @@ -40,6 +40,9 @@ extern "C" { // first operand. void _mlir_ciface_buddy_whisperPreprocess(MemRef *outputFeatures, MemRef *inputRawSpeech); + +void _mlir_ciface_buddy_RFFT(MemRef *inputRawSpeech); + } } // namespace detail @@ -49,6 +52,10 @@ void whisperPreprocess(MemRef *inputRawSpeech, detail::_mlir_ciface_buddy_whisperPreprocess(outputFeatures, inputRawSpeech); } + +void RFFT(MemRef *inputRawSpeech) { + detail::_mlir_ciface_buddy_RFFT(inputRawSpeech); +} } // namespace dap #endif // FRONTEND_INTERFACES_BUDDY_DAP_DSP_WHISPERPREPROCESS diff --git a/frontend/Interfaces/lib/DAP-extend.mlir b/frontend/Interfaces/lib/DAP-extend.mlir index c77fe38735..2c9b7a5a3b 100644 --- a/frontend/Interfaces/lib/DAP-extend.mlir +++ b/frontend/Interfaces/lib/DAP-extend.mlir @@ -2,3 +2,7 @@ func.func @buddy_whisperPreprocess(%in : memref) -> memref<1x80x3000xf32> %out = dap.whisper_preprocess %in : memref to memref<1x80x3000xf32> return %out : memref<1x80x3000xf32> } +func.func @buddy_RFFT(%in : memref) -> () { + dap.rfft %in : memref + return +} diff --git a/midend/include/Dialect/DAP/DAPOps.td b/midend/include/Dialect/DAP/DAPOps.td index 70d7a21fe6..d14ca5cfcd 100644 --- a/midend/include/Dialect/DAP/DAPOps.td +++ b/midend/include/Dialect/DAP/DAPOps.td @@ -93,8 +93,8 @@ def DAP_IirOp : DAP_Op<"iir"> { }]; } -def DAP_RFFT400Op : DAP_Op<"rfft400"> { - let summary = "RFFT operation for length 400."; +def DAP_RFFTOp : DAP_Op<"rfft"> { + let summary = "RFFT operation."; let description = [{ The RFFT algorithm is designed to handle real-valued input signals. 
Real signals exhibit conjugate symmetry in the frequency domain, meaning that @@ -105,7 +105,7 @@ def DAP_RFFT400Op : DAP_Op<"rfft400"> { Example: ```mlir - dap.rfft400 %data : memref<400xf64> + dap.rfft %data : memref ``` }]; diff --git a/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp b/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp index 20918fda97..32fc42fcf7 100644 --- a/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp +++ b/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp @@ -38,6 +38,7 @@ using namespace vector; using namespace mlir::arith; using namespace mlir::linalg; using namespace mlir::bufferization; +using namespace mlir::scf; //===----------------------------------------------------------------------===// // Rewrite Pattern @@ -756,6 +757,28 @@ Value padReflect(PatternRewriter &rewriter, Location loc, Value c0, Value c1, return padOp2.getResult(); } +// function to print a memref for debug +void printMemref(OpBuilder &rewriter, Location loc, Value input, int l) { + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value length = rewriter.create(loc, l); + rewriter.create(loc, "Print Start:\n"); + + rewriter.create( + loc, c0, length, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + Value x = b.create(loc, input, i); + b.create(loc, x); + + b.create(loc, std::nullopt); + }); + + rewriter.create(loc, "\n"); +} + +// WA CC CH PM MULPM C1 C1w C2 CH2 CH2w CH_radfg CCw CSARR AR AI IANG are helper +// functions for RFFTP inline Value WA(OpBuilder &builder, Location loc, Value wa, Value x, Value i, Value ido, Value c1) { Value idom1 = builder.create(loc, ido, c1); @@ -799,15 +822,695 @@ inline std::vector MULPM(OpBuilder &builder, Location loc, Value c, builder.create(loc, tmp3, tmp4)}; } -void radf4Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, - Value wa, Value ido, Value l1, Value cdim, Value c0, Value c1, - Value c2, Value c3) { +inline Value C1(OpBuilder &builder, Location loc, Value cc, Value a, Value b, + Value c, Value ido, Value l1) { + Value tmp1 = builder.create(loc, l1, c); + Value tmp2 = builder.create(loc, tmp1, b); + Value tmp3 = builder.create(loc, tmp2, ido); + Value index = builder.create(loc, tmp3, a); + return builder.create(loc, cc, index); +} + +inline void C1w(OpBuilder &builder, Location loc, Value cc, Value a, Value b, + Value c, Value ido, Value l1, Value toWrite) { + Value tmp1 = builder.create(loc, l1, c); + Value tmp2 = builder.create(loc, tmp1, b); + Value tmp3 = builder.create(loc, tmp2, ido); + Value index = builder.create(loc, tmp3, a); + builder.create(loc, toWrite, cc, index); + return; +} + +inline Value C2(OpBuilder &builder, Location loc, Value cc, Value a, Value b, + Value idl1) { + Value tmp1 = builder.create(loc, idl1, b); + Value index = builder.create(loc, tmp1, a); + return builder.create(loc, cc, index); +} + +inline Value CH2(OpBuilder &builder, Location loc, Value ch, Value a, Value b, + Value idl1) { + Value tmp1 = builder.create(loc, idl1, b); + Value index = builder.create(loc, tmp1, a); + return builder.create(loc, ch, index); +} + +inline void CH2w(OpBuilder &builder, Location loc, Value ch, Value a, Value b, + Value idl1, Value toWrite) { + Value tmp1 = builder.create(loc, idl1, b); + Value index = builder.create(loc, tmp1, a); + builder.create(loc, toWrite, ch, index); + return; +} + +inline Value CH_radfg(OpBuilder &builder, Location loc, Value ch, Value a, + Value b, Value c, Value ido, Value l1) { + Value tmp = builder.create(loc, l1, c); + 
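+  // Computes the flattened offset ch[a + ido * (b + l1 * c)].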
Value tmp1 = builder.create(loc, b, tmp); + Value tmp2 = builder.create(loc, tmp1, ido); + Value index = builder.create(loc, tmp2, a); + return builder.create(loc, ch, index); +} + +inline void CCw(OpBuilder &builder, Location loc, Value cc, Value a, Value b, + Value c, Value ido, Value cdim, Value toWrite) { + Value tmp = builder.create(loc, cdim, c); + Value tmp1 = builder.create(loc, b, tmp); + Value tmp2 = builder.create(loc, tmp1, ido); + Value index = builder.create(loc, tmp2, a); + builder.create(loc, toWrite, cc, index); + return; +} + +inline Value CSARR(OpBuilder &builder, Location loc, Value csarr, Value index) { + + return builder.create(loc, csarr, index); +} + +inline Value AR(OpBuilder &builder, Location loc, Value csarr, Value iang) { + Value c2 = builder.create(loc, 2); + Value index = builder.create(loc, iang, c2); + return CSARR(builder, loc, csarr, index); +} + +inline Value AI(OpBuilder &builder, Location loc, Value csarr, Value iang) { + Value c1 = builder.create(loc, 1); + Value c2 = builder.create(loc, 2); + Value tmp = builder.create(loc, iang, c2); + Value index = builder.create(loc, tmp, c1); + return CSARR(builder, loc, csarr, index); +} + +inline Value IANG(OpBuilder &builder, Location loc, Value iang, Value l, + Value ip) { + + Value iang_new = builder.create(loc, iang, l); + + Value condition = builder.create( + loc, arith::CmpIPredicate::sge, iang_new, ip); + + auto result = builder.create( + loc, condition, + [&](OpBuilder &b, Location loc) { + Value res = b.create(loc, iang_new, ip); + b.create(loc, ValueRange{res}); + }, + [&](OpBuilder &b, Location loc) { + b.create(loc, ValueRange{iang_new}); + }); + + return result.getResult(0); +} + +void radfgExtend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value csarr, Value ido, Value ip, Value l1) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + + Value cdim = opBuilder.create(loc, ip, c0); + Value tmp0 = opBuilder.create(loc, ip, c1); + Value ipph = opBuilder.create(loc, tmp0, c2); + Value idom1 = opBuilder.create(loc, ido, c1); + Value idom2 = opBuilder.create(loc, ido, c2); + Value idl1 = opBuilder.create(loc, ido, l1); + + opBuilder.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value ik, ValueRange ik_args) { + Value c2ik0 = C2(builder, loc, cc, ik, c0, idl1); + CH2w(builder, loc, ch, ik, c0, idl1, c2ik0); + builder.create(loc, std::nullopt); + }); + + opBuilder.create( + loc, c1, ipph, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value j, ValueRange j_args) { + builder.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value ik, ValueRange ik_args) { + Value c2ikj = C2(b, loc, cc, ik, j, idl1); + Value ch2ik0 = CH2(b, loc, ch, ik, c0, idl1); + Value ch2ik0_updated = + b.create(loc, ch2ik0, c2ikj); + + CH2w(b, loc, ch, ik, c0, idl1, ch2ik0_updated); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); + opBuilder.create( loc, c0, l1, c1, std::nullopt, - [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) { + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { + builder.create( + loc, c0, ido, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value chik0 = CH_radfg(b, loc, ch, i, k, c0, ido, l1); + + CCw(b, loc, cc, i, c0, k, ido, cdim, chik0); + b.create(loc, 
std::nullopt); + }); + builder.create(loc, std::nullopt); + }); + + Value j_start_0 = opBuilder.create(loc, 1); + Value jc_start_0 = opBuilder.create(loc, ip, c1); + + opBuilder.create( + loc, c1, ipph, c1, ValueRange{j_start_0, jc_start_0}, + [&](OpBuilder &builder, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + + Value tmp = builder.create(loc, j, c2); + Value j2 = builder.create(loc, tmp, c1); + Value j2p1 = builder.create(loc, j2, c1); + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value k, ValueRange k_args) { + Value ch0kj = CH_radfg(b, loc, ch, c0, k, j, ido, l1); + CCw(b, loc, cc, idom1, j2, k, ido, cdim, ch0kj); + + Value ch0kjc = CH_radfg(b, loc, ch, c0, k, jc, ido, l1); + CCw(b, loc, cc, c0, j2p1, k, ido, cdim, ch0kjc); + + b.create(loc, std::nullopt); + }); + + Value j_next = builder.create(loc, j, c1); + Value jc_next = builder.create(loc, jc, c1); + builder.create(loc, std::vector{j_next, jc_next}); + }); + + Value condition1 = + opBuilder.create(loc, arith::CmpIPredicate::ne, ido, l1); + + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + Value j_start_1 = opBuilder.create(loc, 1); + Value jc_start_1 = opBuilder.create(loc, ip, c1); + + builder.create( + loc, c1, ipph, c1, ValueRange{j_start_1, jc_start_1}, + [&](OpBuilder &b, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + + Value tmp = b.create(loc, j, c2); + Value j2 = b.create(loc, tmp, c1); + Value j2p1 = b.create(loc, j2, c1); + + b.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value k, ValueRange k_args) { + Value i_start_0 = b2.create(loc, 1); + Value ic_start_0 = b2.create(loc, ido, c3); + + b2.create( + loc, c1, idom1, c2, ValueRange{i_start_0, ic_start_0}, + [&](OpBuilder &b3, Location loc, Value i_loop, + ValueRange i_loop_args) { + Value i = i_loop_args[0]; + Value ic = i_loop_args[1]; + + Value ip1 = b3.create(loc, i, c1); + Value icp1 = b3.create(loc, ic, c1); + + Value chikj = CH_radfg(b3, loc, ch, i, k, j, ido, l1); + Value chikjc = + CH_radfg(b3, loc, ch, i, k, jc, ido, l1); + Value tmp2 = + b3.create(loc, chikj, chikjc); + Value tmp3 = + b3.create(loc, chikj, chikjc); + CCw(b3, loc, cc, i, j2p1, k, ido, cdim, tmp2); + CCw(b3, loc, cc, ic, j2, k, ido, cdim, tmp3); + + Value chip1kj = + CH_radfg(b3, loc, ch, ip1, k, j, ido, l1); + Value chip1kjc = + CH_radfg(b3, loc, ch, ip1, k, jc, ido, l1); + Value tmp4 = + b3.create(loc, chip1kj, chip1kjc); + Value tmp5 = + b3.create(loc, chip1kjc, chip1kj); + CCw(b3, loc, cc, ip1, j2p1, k, ido, cdim, tmp4); + CCw(b3, loc, cc, icp1, j2, k, ido, cdim, tmp5); + + Value i_next = b3.create(loc, i, c2); + Value ic_next = b3.create(loc, ic, c2); + b3.create( + loc, std::vector{i_next, ic_next}); + }); + b2.create(loc, std::nullopt); + }); + + Value j_next = b.create(loc, j, c1); + Value jc_next = b.create(loc, jc, c1); + b.create(loc, std::vector{j_next, jc_next}); + }); + builder.create(loc, std::nullopt); + }); + + return; +} + +// Handle general radix FFT computation. 
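+// radfg computes one general-radix pass of the real forward FFT: ip is the
+// radix (factor), l1 the number of sub-transforms in this pass, ido the
+// inner stride between them, and csarr a precomputed twiddle table of
+// interleaved cos/sin pairs (the naming follows pocketfft's radfg).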
+void radfg(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value csarr, Value ido, Value ip, Value l1) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + + Value ipm1 = opBuilder.create(loc, ip, c1); + Value ipm2 = opBuilder.create(loc, ip, c2); + + Value cdim = opBuilder.create(loc, ip, c0); + Value tmp = opBuilder.create(loc, ip, c1); + Value ipph = opBuilder.create(loc, tmp, c2); + + Value idl1 = opBuilder.create(loc, ido, l1); + Value idom1 = opBuilder.create(loc, ido, c1); + Value idom2 = opBuilder.create(loc, ido, c2); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, l1); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value jc_start = builder.create(loc, ip, c1); + + builder.create( + loc, c1, ipph, c1, ValueRange{jc_start}, + [&](OpBuilder &b, Location loc, Value j, ValueRange j_args) { + Value jc = j_args[0]; + + Value jm1 = b.create(loc, j, c1); + Value jcm1 = b.create(loc, jc, c1); + + Value is = b.create(loc, jm1, idom1); + Value is2 = b.create(loc, jcm1, idom1); + + b.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value k, ValueRange k_args) { + Value idij_start = b2.create(loc, is, c0); + Value idij2_start = b2.create(loc, is2, c0); + + b2.create( + loc, c1, idom1, c2, ValueRange{idij_start, idij2_start}, + [&](OpBuilder &b3, Location loc, Value i, + ValueRange i_args) { + Value idij = i_args[0]; + Value idij2 = i_args[1]; + + Value ip1 = b3.create(loc, i, c1); + Value idijp1 = + b3.create(loc, idij, c1); + Value idij2p1 = + b3.create(loc, idij2, c1); + + Value t1 = C1(b3, loc, cc, i, k, j, ido, l1); + Value t2 = C1(b3, loc, cc, ip1, k, j, ido, l1); + Value t3 = C1(b3, loc, cc, i, k, jc, ido, l1); + Value t4 = C1(b3, loc, cc, ip1, k, jc, ido, l1); + + Value waidij = + b3.create(loc, wa, idij); + Value waidijp1 = + b3.create(loc, wa, idijp1); + Value waidij2 = + b3.create(loc, wa, idij2); + Value waidij2p1 = + b3.create(loc, wa, idij2p1); + + Value tmp1_x1 = + b3.create(loc, waidij, t1); + Value tmp2_x1 = + b3.create(loc, waidijp1, t2); + Value x1 = + b3.create(loc, tmp1_x1, tmp2_x1); + + Value tmp1_x2 = + b3.create(loc, waidij, t2); + Value tmp2_x2 = + b3.create(loc, waidijp1, t1); + Value x2 = + b3.create(loc, tmp1_x2, tmp2_x2); + + Value tmp1_x3 = + b3.create(loc, waidij2, t3); + Value tmp2_x3 = + b3.create(loc, waidij2p1, t4); + Value x3 = + b3.create(loc, tmp1_x3, tmp2_x3); + + Value tmp1_x4 = + b3.create(loc, waidij2, t4); + Value tmp2_x4 = + b3.create(loc, waidij2p1, t3); + Value x4 = + b3.create(loc, tmp1_x4, tmp2_x4); + + Value tmp3 = b3.create(loc, x1, x3); + Value tmp4 = b3.create(loc, x2, x4); + Value tmp5 = b3.create(loc, x2, x4); + Value tmp6 = b3.create(loc, x3, x1); + + C1w(b3, loc, cc, i, k, j, ido, l1, tmp3); + C1w(b3, loc, cc, i, k, jc, ido, l1, tmp4); + C1w(b3, loc, cc, ip1, k, j, ido, l1, tmp5); + C1w(b3, loc, cc, ip1, k, jc, ido, l1, tmp6); + + Value idij_next = + b3.create(loc, idij, c2); + Value idij2_next = + b3.create(loc, idij2, c2); + + b3.create( + loc, std::vector{idij_next, idij2_next}); + }); + b2.create(loc, std::nullopt); + } + + ); + + Value jc_next = b.create(loc, jc, c1); + b.create(loc, jc_next); + }); + + builder.create(loc, std::nullopt); + }); + + Value jc_a_start = opBuilder.create(loc, ip, c1); + + opBuilder.create( + loc, c1, ipph, c1, ValueRange{jc_a_start}, + [&](OpBuilder 
&builder, Location loc, Value j_a, ValueRange j_a_args) { + Value jc_a = j_a_args[0]; + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value k_a, ValueRange k_a_args) { + Value t1_a = C1(b, loc, cc, c0, k_a, j_a, ido, l1); + Value t2_a = C1(b, loc, cc, c0, k_a, jc_a, ido, l1); + + Value tmp_a = b.create(loc, t1_a, t2_a); + Value tmp1_a = b.create(loc, t2_a, t1_a); + + C1w(b, loc, cc, c0, k_a, j_a, ido, l1, tmp_a); + C1w(b, loc, cc, c0, k_a, jc_a, ido, l1, tmp1_a); + b.create(loc, std::nullopt); + }); + + Value jc_a_next = builder.create(loc, jc_a, c1); + builder.create(loc, jc_a_next); + }); + + Value lc_b_start = opBuilder.create(loc, ip, c1); + + opBuilder.create( + loc, c1, ipph, c1, ValueRange{lc_b_start}, + [&](OpBuilder &builder, Location loc, Value l_b, ValueRange l_b_args) { + Value lc_b = l_b_args[0]; + + builder.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value ik_b, ValueRange ik_b_args) { + Value m2l = b.create(loc, l_b, c2); + Value m4l = b.create(loc, l_b, c4); + Value m2lp1 = b.create(loc, m2l, c1); + Value m4lp1 = b.create(loc, m4l, c1); + + Value csarr2l = CSARR(b, loc, csarr, m2l); + Value csarr4l = CSARR(b, loc, csarr, m4l); + Value csarr2lp1 = CSARR(b, loc, csarr, m2lp1); + Value csarr4lp1 = CSARR(b, loc, csarr, m4lp1); + + Value c2ik0 = C2(b, loc, cc, ik_b, c0, idl1); + Value c2ik1 = C2(b, loc, cc, ik_b, c1, idl1); + Value c2ik2 = C2(b, loc, cc, ik_b, c2, idl1); + + Value c2ikipm1 = C2(b, loc, cc, ik_b, ipm1, idl1); + Value c2ikipm2 = C2(b, loc, cc, ik_b, ipm2, idl1); + + Value tmp_b = b.create(loc, csarr2l, c2ik1); + Value tmp1_b = b.create(loc, csarr4l, c2ik2); + Value tmp2_b = b.create(loc, tmp_b, tmp1_b); + Value tmp3_b = b.create(loc, c2ik0, tmp2_b); + + CH2w(b, loc, ch, ik_b, l_b, idl1, tmp3_b); + + Value tmp4_b = b.create(loc, csarr2lp1, c2ikipm1); + Value tmp5_b = b.create(loc, csarr4lp1, c2ikipm2); + Value tmp6_b = b.create(loc, tmp4_b, tmp5_b); + + CH2w(b, loc, ch, ik_b, lc_b, idl1, tmp6_b); + b.create(loc, std::nullopt); + }); + + Value iang_start_c = builder.create(loc, c2, l_b); + Value j_start_c = builder.create(loc, 3); + Value jc_start_c = builder.create(loc, ip, c3); + Value ipphm1 = builder.create(loc, ipph, c1); + Value ipphm3 = builder.create(loc, ipph, c3); + + auto loop1 = builder.create( + loc, j_start_c, ipphm3, c4, + ValueRange{j_start_c, jc_start_c, iang_start_c}, + [&](OpBuilder &b, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + Value iang = j_loop_args[2]; + + Value iang_1_c = IANG(b, loc, iang, l_b, ip); + Value ar1 = AR(b, loc, csarr, iang_1_c); + Value ai1 = AI(b, loc, csarr, iang_1_c); + + Value iang_2_c = IANG(b, loc, iang_1_c, l_b, ip); + Value ar2 = AR(b, loc, csarr, iang_2_c); + Value ai2 = AI(b, loc, csarr, iang_2_c); + + Value iang_3_c = IANG(b, loc, iang_2_c, l_b, ip); + Value ar3 = AR(b, loc, csarr, iang_3_c); + Value ai3 = AI(b, loc, csarr, iang_3_c); + + Value iang_4_c = IANG(b, loc, iang_3_c, l_b, ip); + Value ar4 = AR(b, loc, csarr, iang_4_c); + Value ai4 = AI(b, loc, csarr, iang_4_c); + + b.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value ik_c, + ValueRange ik_c_args) { + Value jp1 = b2.create(loc, j, c1); + Value jp2 = b2.create(loc, j, c2); + Value jp3 = b2.create(loc, j, c3); + Value jm1 = b2.create(loc, j, c1); + Value jm2 = b2.create(loc, j, c2); + Value jm3 = b2.create(loc, j, c3); + + Value c2ikj = C2(b2, loc, cc, ik_c, j, idl1); + Value c2ikjp1 = 
C2(b2, loc, cc, ik_c, jp1, idl1); + Value c2ikjp2 = C2(b2, loc, cc, ik_c, jp2, idl1); + Value c2ikjp3 = C2(b2, loc, cc, ik_c, jp3, idl1); + + Value tmp_c = b2.create(loc, ar1, c2ikj); + Value tmp1_c = b2.create(loc, ar2, c2ikjp1); + Value tmp2_c = b2.create(loc, ar3, c2ikjp2); + Value tmp3_c = b2.create(loc, ar4, c2ikjp3); + + Value tmp4_c = b2.create(loc, tmp_c, tmp1_c); + Value tmp5_c = + b2.create(loc, tmp4_c, tmp2_c); + Value tmp6_c = + b2.create(loc, tmp5_c, tmp3_c); + + Value ch2ikl = CH2(b2, loc, ch, ik_c, l_b, idl1); + Value tmp7_c = + b2.create(loc, tmp6_c, ch2ikl); + CH2w(b2, loc, ch, ik_c, l_b, idl1, tmp7_c); + + Value jcm1 = b2.create(loc, jc, c1); + Value jcm2 = b2.create(loc, jc, c2); + Value jcm3 = b2.create(loc, jc, c3); + + Value c2ikjc = C2(b2, loc, cc, ik_c, jc, idl1); + Value c2ikjcm1 = C2(b2, loc, cc, ik_c, jcm1, idl1); + Value c2ikjcm2 = C2(b2, loc, cc, ik_c, jcm2, idl1); + Value c2ikjcm3 = C2(b2, loc, cc, ik_c, jcm3, idl1); + + Value tmp_ai1 = b2.create(loc, ai1, c2ikjc); + Value tmp_ai2 = + b2.create(loc, ai2, c2ikjcm1); + Value tmp_ai3 = + b2.create(loc, ai3, c2ikjcm2); + Value tmp_ai4 = + b2.create(loc, ai4, c2ikjcm3); + + Value tmp_ai5 = + b2.create(loc, tmp_ai1, tmp_ai2); + Value tmp_ai6 = + b2.create(loc, tmp_ai5, tmp_ai3); + Value tmp_ai7 = + b2.create(loc, tmp_ai6, tmp_ai4); + + Value ch2iklc = CH2(b2, loc, ch, ik_c, lc_b, idl1); + Value tmp_ai8 = + b2.create(loc, tmp_ai7, ch2iklc); + CH2w(b2, loc, ch, ik_c, lc_b, idl1, tmp_ai8); + + b2.create(loc, std::nullopt); + }); + + Value j_next = b.create(loc, j, c4); + Value jc_next = b.create(loc, jc, c4); + builder.create( + loc, std::vector{j_next, jc_next, iang_4_c}); + }); + + Value j_1_c = loop1.getResults()[0]; + Value jc_1_c = loop1.getResults()[1]; + Value iang1_c = loop1.getResults()[2]; + + auto loop2 = builder.create( + loc, j_1_c, ipphm1, c2, ValueRange{j_1_c, jc_1_c, iang1_c}, + [&](OpBuilder &b, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + Value iang = j_loop_args[2]; + + Value iang_1_d = IANG(b, loc, iang, l_b, ip); + Value ar1 = AR(b, loc, csarr, iang_1_d); + Value ai1 = AI(b, loc, csarr, iang_1_d); + + Value iang_2_d = IANG(b, loc, iang_1_d, l_b, ip); + Value ar2 = AR(b, loc, csarr, iang_2_d); + Value ai2 = AI(b, loc, csarr, iang_2_d); + + b.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value ik_d, + ValueRange ik_d_args) { + Value jp1 = b2.create(loc, j, c1); + Value jm1 = b2.create(loc, j, c1); + + Value c2ikj = C2(b2, loc, cc, ik_d, j, idl1); + Value c2ikjp1 = C2(b2, loc, cc, ik_d, jp1, idl1); + + Value tmp_c = b2.create(loc, ar1, c2ikj); + Value tmp1_c = b2.create(loc, ar2, c2ikjp1); + Value tmp2_c = b2.create(loc, tmp_c, tmp1_c); + + Value ch2ikl = CH2(b2, loc, ch, ik_d, l_b, idl1); + Value tmp3_c = + b2.create(loc, tmp2_c, ch2ikl); + CH2w(b2, loc, ch, ik_d, l_b, idl1, tmp3_c); + + Value jcm1 = b2.create(loc, jc, c1); + Value c2ikjc = C2(b2, loc, cc, ik_d, jc, idl1); + Value c2ikjcm1 = C2(b2, loc, cc, ik_d, jcm1, idl1); + + Value tmp_ai1 = b2.create(loc, ai1, c2ikjc); + Value tmp_ai2 = + b2.create(loc, ai2, c2ikjcm1); + Value tmp_ai3 = + b2.create(loc, tmp_ai1, tmp_ai2); + + Value ch2iklc = CH2(b2, loc, ch, ik_d, lc_b, idl1); + Value tmp_ai4 = + b2.create(loc, tmp_ai3, ch2iklc); + CH2w(b2, loc, ch, ik_d, lc_b, idl1, tmp_ai4); + + b2.create(loc, std::nullopt); + }); + + Value j_next = b.create(loc, j, c2); + Value jc_next = b.create(loc, jc, c2); + builder.create( + loc, std::vector{j_next, 
jc_next, iang_2_d}); + }); + + Value j_2_c = loop2.getResults()[0]; + Value jc_2_c = loop2.getResults()[1]; + Value iang2_c = loop2.getResults()[2]; + + auto loop3 = builder.create( + loc, j_2_c, ipph, c1, ValueRange{j_2_c, jc_2_c, iang2_c}, + [&](OpBuilder &b, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + Value iang = j_loop_args[2]; + + Value iang_1_e = IANG(b, loc, iang, l_b, ip); + Value ar = AR(b, loc, csarr, iang_1_e); + Value ai = AI(b, loc, csarr, iang_1_e); + + b.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value ik_e, + ValueRange ik_e_args) { + Value c2ikj = C2(b2, loc, cc, ik_e, j, idl1); + Value tmp_c = b2.create(loc, ar, c2ikj); + Value ch2ikl = CH2(b2, loc, ch, ik_e, l_b, idl1); + Value tmp2_c = b2.create(loc, tmp_c, ch2ikl); + CH2w(b2, loc, ch, ik_e, l_b, idl1, tmp2_c); + + Value c2ikjc = C2(b2, loc, cc, ik_e, jc, idl1); + Value tmp_ai = b2.create(loc, ai, c2ikjc); + Value ch2iklc = CH2(b2, loc, ch, ik_e, lc_b, idl1); + Value tmp2_ai = + b2.create(loc, tmp_ai, ch2iklc); + CH2w(b2, loc, ch, ik_e, lc_b, idl1, tmp2_ai); + + b2.create(loc, std::nullopt); + }); + + Value j_next = b.create(loc, j, c2); + Value jc_next = b.create(loc, jc, c2); + builder.create( + loc, std::vector{j_next, jc_next, iang_1_e}); + }); + + Value lc_b_next = builder.create(loc, lc_b, c1); + builder.create(loc, lc_b_next); + }); + + radfgExtend(opBuilder, loc, cc, ch, wa, csarr, ido, ip, l1); +} + +void radf2Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim) { + FloatType f64Ty = opBuilder.getF64Type(); + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c20 = opBuilder.create(loc, 20); + + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { builder.create( loc, c2, ido, c2, std::nullopt, - [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { Value ic = b.create(loc, ido, i); Value icm1 = b.create(loc, ic, c1); Value im1 = b.create(loc, i, c1); @@ -817,303 +1520,1946 @@ void radf4Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); - std::vector cr2_ci2 = + std::vector tr2_ti2 = + MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); + + Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); + std::vector ccim1k0_tr2 = PM(b, loc, ccim1k0, tr2_ti2[0]); + std::vector ti2_ccik0 = PM(b, loc, tr2_ti2[1], ccik0); + + CH(b, loc, ch, im1, c0, k, ido, cdim, ccim1k0_tr2[0]); + CH(b, loc, ch, icm1, c1, k, ido, cdim, ccim1k0_tr2[1]); + + CH(b, loc, ch, i, c0, k, ido, cdim, ti2_ccik0[0]); + CH(b, loc, ch, ic, c1, k, ido, cdim, ti2_ccik0[1]); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); +} + +// Handle radix-2 FFT computation +void radf2(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1) { + + FloatType f64Ty = opBuilder.getF64Type(); + Value cdim = opBuilder.create(loc, 2); + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); 
+ Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c20 = opBuilder.create(loc, 20); + + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iv_args) { + Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1); + Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1); + std::vector cc0k0_cc0k1 = PM(builder, loc, cc0k0, cc0k1); + CH(builder, loc, ch, c0, c0, iv, ido, cdim, cc0k0_cc0k1[0]); + CH(builder, loc, ch, idom1, c1, iv, ido, cdim, cc0k0_cc0k1[1]); + builder.create(loc, std::nullopt); + }); + + Value flag = opBuilder.create(loc, ido, c2); + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::eq, flag, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value k, ValueRange k_args) { + Value ccidom1k1 = CC(b, loc, cc, idom1, k, c1, ido, l1); + Value tmp = b.create(loc, ccidom1k1); + CH(b, loc, ch, c0, c1, k, ido, cdim, tmp); + Value ccidom1k0 = CC(b, loc, cc, idom1, k, c0, ido, l1); + CH(b, loc, ch, idom1, c0, k, ido, cdim, ccidom1k0); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); + + Value condition1 = + opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, c2); + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + radf2Extend(builder, loc, cc, ch, wa, ido, l1, cdim); + builder.create(loc, std::nullopt); + }); +} + +void radf3Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim) { + + FloatType f64Ty = opBuilder.getF64Type(); + Value taur = + opBuilder.create(loc, APFloat(double(-0.5)), f64Ty); + Value taui = opBuilder.create( + loc, APFloat(double(0.86602540378443864676)), f64Ty); + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { + builder.create( + loc, c2, ido, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value ic = b.create(loc, ido, i); + Value icm1 = b.create(loc, ic, c1); + Value im1 = b.create(loc, i, c1); + Value im2 = b.create(loc, i, c2); + + Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1); + Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); + Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); + Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); + std::vector dr2_di2 = MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1); Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1); Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1); Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1); - std::vector cr3_ci3 = + std::vector dr3_di3 = MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2); - Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1); - Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1); - Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1); - Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1); - std::vector cr4_ci4 = - MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3); + Value cr2 = b.create(loc, dr2_di2[0], dr3_di3[0]); + Value ci2 = b.create(loc, dr2_di2[1], dr3_di3[1]); - std::vector tr1_tr4 = PM(b, loc, cr4_ci4[0], cr2_ci2[0]); - 
std::vector ti1_ti4 = PM(b, loc, cr2_ci2[1], cr4_ci4[1]); Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); - std::vector tr2_tr3 = PM(b, loc, ccim1k0, cr3_ci3[0]); + Value tmp5 = b.create(loc, ccim1k0, cr2); + CH(builder, loc, ch, im1, c0, k, ido, cdim, tmp5); + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); - std::vector ti2_ti3 = PM(b, loc, ccik0, cr3_ci3[1]); + Value tmp6 = b.create(loc, ccik0, ci2); + CH(builder, loc, ch, i, c0, k, ido, cdim, tmp6); - std::vector chtmp0 = PM(b, loc, tr2_tr3[0], tr1_tr4[0]); - CH(b, loc, ch, im1, c0, k, ido, cdim, chtmp0[0]); - CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp0[1]); + Value tmp7 = b.create(loc, taur, cr2); + Value tr2 = b.create(loc, ccim1k0, tmp7); - std::vector chtmp1 = PM(b, loc, ti1_ti4[0], ti2_ti3[0]); - CH(b, loc, ch, i, c0, k, ido, cdim, chtmp1[0]); - CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp1[1]); + Value tmp8 = b.create(loc, taur, ci2); + Value ti2 = b.create(loc, ccik0, tmp8); - std::vector chtmp2 = PM(b, loc, tr2_tr3[1], ti1_ti4[1]); - CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp2[0]); - CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp2[1]); + Value tmp9 = b.create(loc, dr2_di2[1], dr3_di3[1]); + Value tr3 = b.create(loc, taui, tmp9); - std::vector chtmp3 = PM(b, loc, tr1_tr4[1], ti2_ti3[1]); - CH(b, loc, ch, i, c2, k, ido, cdim, chtmp3[0]); - CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp3[1]); + Value tmp10 = + b.create(loc, dr3_di3[0], dr2_di2[0]); + Value ti3 = b.create(loc, taui, tmp10); + + std::vector tr2_tr3 = PM(b, loc, tr2, tr3); + std::vector ti3_ti2 = PM(b, loc, ti3, ti2); + CH(builder, loc, ch, im1, c2, k, ido, cdim, tr2_tr3[0]); + CH(builder, loc, ch, icm1, c1, k, ido, cdim, tr2_tr3[1]); + + CH(builder, loc, ch, i, c2, k, ido, cdim, ti3_ti2[0]); + CH(builder, loc, ch, ic, c1, k, ido, cdim, ti3_ti2[1]); + + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); +} + +// Handle radix-3 FFT computation +void radf3(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1) { + + FloatType f64Ty = opBuilder.getF64Type(); + Value cdim = opBuilder.create(loc, 3); + Value taur = + opBuilder.create(loc, APFloat(double(-0.5)), f64Ty); + Value taui = opBuilder.create( + loc, APFloat(double(0.86602540378443864676)), f64Ty); + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iv_args) { + Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1); + Value cc0k2 = CC(builder, loc, cc, c0, iv, c2, ido, l1); + Value cr2 = builder.create(loc, cc0k1, cc0k2); + + Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1); + Value tmp0 = builder.create(loc, cc0k0, cr2); + CH(builder, loc, ch, c0, c0, iv, ido, cdim, tmp0); + + Value tmp1 = builder.create(loc, cc0k2, cc0k1); + Value tmp2 = builder.create(loc, tmp1, taui); + CH(builder, loc, ch, c0, c2, iv, ido, cdim, tmp2); + + Value tmp3 = builder.create(loc, taur, cr2); + Value tmp4 = builder.create(loc, tmp3, cc0k0); + CH(builder, loc, ch, idom1, c1, iv, ido, cdim, tmp4); + + builder.create(loc, std::nullopt); + }); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, ido, c1); + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + radf3Extend(builder, loc, cc, ch, wa, ido, l1, cdim); + 
builder.create(loc, std::nullopt); + }); +} + +void radf4Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim, Value c0, Value c1, + Value c2, Value c3) { + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) { + builder.create( + loc, c2, ido, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + Value ic = b.create(loc, ido, i); + Value icm1 = b.create(loc, ic, c1); + Value im1 = b.create(loc, i, c1); + Value im2 = b.create(loc, i, c2); + + Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1); + Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); + Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); + Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); + std::vector cr2_ci2 = + MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); + + Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1); + Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1); + Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1); + Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1); + std::vector cr3_ci3 = + MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2); + + Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1); + Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1); + Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1); + Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1); + std::vector cr4_ci4 = + MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3); + + std::vector tr1_tr4 = PM(b, loc, cr4_ci4[0], cr2_ci2[0]); + std::vector ti1_ti4 = PM(b, loc, cr2_ci2[1], cr4_ci4[1]); + Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); + std::vector tr2_tr3 = PM(b, loc, ccim1k0, cr3_ci3[0]); + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); + std::vector ti2_ti3 = PM(b, loc, ccik0, cr3_ci3[1]); + + std::vector chtmp0 = PM(b, loc, tr2_tr3[0], tr1_tr4[0]); + CH(b, loc, ch, im1, c0, k, ido, cdim, chtmp0[0]); + CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp0[1]); + + std::vector chtmp1 = PM(b, loc, ti1_ti4[0], ti2_ti3[0]); + CH(b, loc, ch, i, c0, k, ido, cdim, chtmp1[0]); + CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp1[1]); + + std::vector chtmp2 = PM(b, loc, tr2_tr3[1], ti1_ti4[1]); + CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp2[0]); + CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp2[1]); + + std::vector chtmp3 = PM(b, loc, tr1_tr4[1], ti2_ti3[1]); + CH(b, loc, ch, i, c2, k, ido, cdim, chtmp3[0]); + CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp3[1]); + + b.create(loc, std::nullopt); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +// Handle radix-4 FFT computation +void radf4(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1, Value c0, Value c1, Value c2, Value c3) { + FloatType f64Ty = opBuilder.getF64Type(); + Value cdim = opBuilder.create(loc, 4); + Value hsqt2 = opBuilder.create( + loc, APFloat(double(0.70710678118654752440)), f64Ty); + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iargs) { + Value cc0k3 = CC(builder, loc, cc, c0, iv, c3, ido, l1); + Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1); + std::vector tr1_tmp0 = PM(builder, loc, cc0k3, cc0k1); + CH(builder, loc, ch, c0, c2, iv, ido, cdim, tr1_tmp0[1]); + + Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1); + Value cc0k2 = CC(builder, loc, cc, c0, iv, c2, ido, l1); + std::vector tr2_tmp1 = PM(builder, loc, cc0k0, cc0k2); + CH(builder, loc, ch, idom1, c1, iv, ido, cdim, tr2_tmp1[1]); + + std::vector 
tmp2_tmp3 =
+              PM(builder, loc, tr2_tmp1[0], tr1_tmp0[0]);
+          CH(builder, loc, ch, c0, c0, iv, ido, cdim, tmp2_tmp3[0]);
+          CH(builder, loc, ch, idom1, c3, iv, ido, cdim, tmp2_tmp3[1]);
+
+          builder.create(loc, std::nullopt);
+        });
+
+  // Even ido: the trailing column is combined with the +/- sqrt(0.5)
+  // constants below.
+  Value remainder = opBuilder.create(loc, ido, c2);
+  Value condition0 = opBuilder.create(
+      loc, arith::CmpIPredicate::eq, remainder, c0);
+  opBuilder.create(
+      loc, condition0, [&](OpBuilder &builder, Location loc) {
+        Value negHsqt2 = builder.create(
+            loc, APFloat(double(-0.70710678118654752440)), f64Ty);
+
+        builder.create(
+            loc, c0, l1, c1, std::nullopt,
+            [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) {
+              Value ccidom1k1 = CC(b, loc, cc, idom1, iv, c1, ido, l1);
+              Value ccidom1k3 = CC(b, loc, cc, idom1, iv, c3, ido, l1);
+              Value tmp0 = b.create(loc, ccidom1k1, ccidom1k3);
+              Value ti1 = b.create(loc, negHsqt2, tmp0);
+
+              Value tmp1 = b.create(loc, ccidom1k1, ccidom1k3);
+              Value tr1 = b.create(loc, hsqt2, tmp1);
+
+              Value ccidom1k0 = CC(b, loc, cc, idom1, iv, c0, ido, l1);
+              std::vector tmp2_tmp3 = PM(b, loc, ccidom1k0, tr1);
+              CH(b, loc, ch, idom1, c0, iv, ido, cdim, tmp2_tmp3[0]);
+              CH(b, loc, ch, idom1, c2, iv, ido, cdim, tmp2_tmp3[1]);
+
+              Value ccidom1k2 = CC(b, loc, cc, idom1, iv, c2, ido, l1);
+              std::vector tmp4_tmp5 = PM(b, loc, ti1, ccidom1k2);
+              CH(b, loc, ch, c0, c3, iv, ido, cdim, tmp4_tmp5[0]);
+              CH(b, loc, ch, c0, c1, iv, ido, cdim, tmp4_tmp5[1]);
+
+              b.create(loc, std::nullopt);
+            });
+
+        builder.create(loc, std::nullopt);
+      });
+
+  Value condition1 =
+      opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, c2);
+  opBuilder.create(
+      loc, condition1, [&](OpBuilder &builder, Location loc) {
+        radf4Extend(builder, loc, cc, ch, wa, ido, l1, cdim, c0, c1, c2, c3);
+        builder.create(loc, std::nullopt);
+      });
+
+  return;
+}
+
+void radf5Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch,
+                 Value wa, Value ido, Value l1, Value cdim, Value tr11,
+                 Value tr12, Value ti11, Value ti12, Value c0, Value c1,
+                 Value c2, Value c3, Value c4) {
+  opBuilder.create(
+      loc, c0, l1, c1, std::nullopt,
+      [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) {
+        builder.create(
+            loc, c2, ido, c2, std::nullopt,
+            [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) {
+              Value ic = b.create(loc, ido, i);
+              Value icm1 = b.create(loc, ic, c1);
+              Value im1 = b.create(loc, i, c1);
+              Value im2 = b.create(loc, i, c2);
+
+              Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1);
+              Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1);
+              Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1);
+              Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1);
+              std::vector dr2_di2 =
+                  MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1);
+
+              Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1);
+              Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1);
+              Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1);
+              Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1);
+              std::vector dr3_di3 =
+                  MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2);
+
+              Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1);
+              Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1);
+              Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1);
+              Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1);
+              std::vector dr4_di4 =
+                  MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3);
+
+              Value wa3im2 = WA(b, loc, wa, c3, im2, ido, c1);
+              Value wa3im1 = WA(b, loc, wa, c3, im1, ido, c1);
+              Value ccim1k4 = CC(b, loc, cc, im1, k, c4, ido, l1);
+              Value ccik4 = CC(b, loc, cc, i, k, c4, ido, l1);
+              std::vector dr5_di5 =
+                  MULPM(b, loc, wa3im2, wa3im1, ccim1k4, ccik4);
+
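+              // Sum/difference recombination of the twiddled inputs: PM
+              // yields {a + b, a - b}, producing the pairs cr2/ci5, ci2/cr5,
+              // cr3/ci4, and ci3/cr4 consumed by the radix-5 butterfly.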
std::vector cr2_ci5 = PM(b, loc, dr5_di5[0], dr2_di2[0]); + std::vector ci2_cr5 = PM(b, loc, dr2_di2[1], dr5_di5[1]); + std::vector cr3_ci4 = PM(b, loc, dr4_di4[0], dr3_di3[0]); + std::vector ci3_cr4 = PM(b, loc, dr3_di3[1], dr4_di4[1]); + + Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); + Value tmpch0 = b.create(loc, ccim1k0, cr2_ci5[0]); + Value chim10k = b.create(loc, tmpch0, cr3_ci4[0]); + CH(b, loc, ch, im1, c0, k, ido, cdim, chim10k); + + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); + Value tmpch1 = b.create(loc, ccik0, ci2_cr5[0]); + Value chi0k = b.create(loc, tmpch1, ci3_cr4[0]); + CH(b, loc, ch, i, c0, k, ido, cdim, chi0k); + + Value tmp0 = b.create(loc, tr11, cr2_ci5[0]); + Value tmp1 = b.create(loc, ccim1k0, tmp0); + Value tmp2 = b.create(loc, tr12, cr3_ci4[0]); + Value tr2 = b.create(loc, tmp1, tmp2); + + Value tmp3 = b.create(loc, tr11, ci2_cr5[0]); + Value tmp4 = b.create(loc, ccik0, tmp3); + Value tmp5 = b.create(loc, tr12, ci3_cr4[0]); + Value ti2 = b.create(loc, tmp4, tmp5); + + Value tmp6 = b.create(loc, tr12, cr2_ci5[0]); + Value tmp7 = b.create(loc, ccim1k0, tmp6); + Value tmp8 = b.create(loc, tr11, cr3_ci4[0]); + Value tr3 = b.create(loc, tmp7, tmp8); + + Value tmp9 = b.create(loc, tr12, ci2_cr5[0]); + Value tmp10 = b.create(loc, ccik0, tmp9); + Value tmp11 = b.create(loc, tr11, ci3_cr4[0]); + Value ti3 = b.create(loc, tmp10, tmp11); + + std::vector tr5_tr4 = + MULPM(b, loc, ci2_cr5[1], ci3_cr4[1], ti11, ti12); + std::vector ti5_ti4 = + MULPM(b, loc, cr2_ci5[1], cr3_ci4[1], ti11, ti12); + + std::vector chtmp0 = PM(b, loc, tr2, tr5_tr4[0]); + CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp0[0]); + CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp0[1]); + + std::vector chtmp1 = PM(b, loc, ti5_ti4[0], ti2); + CH(b, loc, ch, i, c2, k, ido, cdim, chtmp1[0]); + CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp1[1]); + + std::vector chtmp2 = PM(b, loc, tr3, tr5_tr4[1]); + CH(b, loc, ch, im1, c4, k, ido, cdim, chtmp2[0]); + CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp2[1]); + + std::vector chtmp3 = PM(b, loc, ti5_ti4[1], ti3); + CH(b, loc, ch, i, c4, k, ido, cdim, chtmp3[0]); + CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp3[1]); + + b.create(loc, std::nullopt); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +// Handle radix-5 FFT computation +void radf5(OpBuilder &builder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1, Value c0, Value c1, Value c2, Value c3, + Value c4) { + + FloatType f64Ty = builder.getF64Type(); + Value cdim = builder.create(loc, 5); + Value tr11 = builder.create( + loc, APFloat(double(0.3090169943749474241)), f64Ty); + Value tr12 = builder.create( + loc, APFloat(double(-0.8090169943749474241)), f64Ty); + Value ti11 = builder.create( + loc, APFloat(double(0.95105651629515357212)), f64Ty); + Value ti12 = builder.create( + loc, APFloat(double(0.58778525229247312917)), f64Ty); + Value idom1 = builder.create(loc, ido, c1); + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) { + Value cc0k4 = CC(b, loc, cc, c0, iv, c4, ido, l1); + Value cc0k1 = CC(b, loc, cc, c0, iv, c1, ido, l1); + std::vector cr2_ci5 = PM(b, loc, cc0k4, cc0k1); + + Value cc0k3 = CC(b, loc, cc, c0, iv, c3, ido, l1); + Value cc0k2 = CC(b, loc, cc, c0, iv, c2, ido, l1); + std::vector cr3_ci4 = PM(b, loc, cc0k3, cc0k2); + + Value cc0k0 = CC(b, loc, cc, c0, iv, c0, ido, l1); + Value tmpch0 = b.create(loc, cc0k0, cr2_ci5[0]); + Value ch0 = b.create(loc, tmpch0, cr3_ci4[0]); + CH(b, loc, ch, c0, c0, iv, 
ido, cdim, ch0);
+
+        Value tmpch1 = b.create(loc, tr11, cr2_ci5[0]);
+        Value tmpch2 = b.create(loc, tr12, cr3_ci4[0]);
+        Value tmpch3 = b.create(loc, cc0k0, tmpch1);
+        Value ch1 = b.create(loc, tmpch2, tmpch3);
+        CH(b, loc, ch, idom1, c1, iv, ido, cdim, ch1);
+
+        Value tmpch4 = b.create(loc, ti11, cr2_ci5[1]);
+        Value tmpch5 = b.create(loc, ti12, cr3_ci4[1]);
+        Value ch2 = b.create(loc, tmpch4, tmpch5);
+        CH(b, loc, ch, c0, c2, iv, ido, cdim, ch2);
+
+        Value tmpch6 = b.create(loc, tr12, cr2_ci5[0]);
+        Value tmpch7 = b.create(loc, tr11, cr3_ci4[0]);
+        Value tmpch8 = b.create(loc, tmpch6, tmpch7);
+        Value ch3 = b.create(loc, cc0k0, tmpch8);
+        CH(b, loc, ch, idom1, c3, iv, ido, cdim, ch3);
+
+        Value tmpch9 = b.create(loc, ti12, cr2_ci5[1]);
+        Value tmpch10 = b.create(loc, ti11, cr3_ci4[1]);
+        Value ch4 = b.create(loc, tmpch9, tmpch10);
+        CH(b, loc, ch, c0, c4, iv, ido, cdim, ch4);
+
+        b.create(loc, std::nullopt);
+      });
+
+  Value condition =
+      builder.create(loc, arith::CmpIPredicate::ne, ido, c1);
+  builder.create(loc, condition, [&](OpBuilder &b, Location loc) {
+    radf5Extend(b, loc, cc, ch, wa, ido, l1, cdim, tr11, tr12, ti11, ti12, c0,
+                c1, c2, c3, c4);
+    b.create(loc, std::nullopt);
+  });
+
+  return;
+}
+
+// Increment the index value stored in a one-element memref (the
+// builder-level equivalent of the ++ operation).
+void index_increment(OpBuilder &opBuilder, Location loc, Value target) {
+  Value c0 = opBuilder.create(loc, 0);
+  Value c1 = opBuilder.create(loc, 1);
+  Value a = opBuilder.create(loc, target, c0);
+  Value b = opBuilder.create(loc, a, c1);
+  opBuilder.create(loc, b, target, c0);
+}
+
+// Swap two elements of an array.
+void index_SWAP(OpBuilder &opBuilder, Location loc, Value array, Value target1,
+                Value target2) {
+  Value a = opBuilder.create(loc, array, target1);
+  Value b = opBuilder.create(loc, array, target2);
+
+  opBuilder.create(loc, a, array, target2);
+  opBuilder.create(loc, b, array, target1);
+}
+
+// Factorize the input length and store the factors in Rfftp_fctdata_fct.
+Value rfftp_factorize(OpBuilder &opBuilder, Location loc,
+                      Value Rfftp_fctdata_fct, Value Rfftp_fctdata_tw,
+                      Value Rfftp_fctdata_tws, Value Rfftp_plan_length,
+                      Value Rfftp_plan_nfct, Value Rfftp_plan_mem) {
+  Value c0 = opBuilder.create(loc, 0);
+  Value c1 = opBuilder.create(loc, 1);
+  Value c2 = opBuilder.create(loc, 2);
+  Value c3 = opBuilder.create(loc, 3);
+  Value c4 = opBuilder.create(loc, 4);
+  Value c_neg1 = opBuilder.create(loc, -1);
+  Value NFCT = opBuilder.create(loc, 25);
+
+  FloatType f64Ty = opBuilder.getF64Type();
+  IndexType indexTy = opBuilder.getIndexType();
+
+  Value length =
+      opBuilder.create(loc, MemRefType::get(1, indexTy));
+  Value length_1 = opBuilder.create(loc, Rfftp_plan_length, c0);
+  opBuilder.create(loc, length_1, length, c0);
+
+  Value nfct =
+      opBuilder.create(loc, MemRefType::get(1, indexTy));
+
+  opBuilder.create(loc, c0, nfct, c0);
+
+  // First peel off all factors of 4.
+  auto loop = opBuilder.create(
+      loc, TypeRange{indexTy}, ValueRange{length_1},
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value length_while = args[0];
+
+        Value length_mod_4 =
+            builder.create(loc, length_while, c4);
+        Value condition = builder.create(
+            loc, arith::CmpIPredicate::eq, length_mod_4, c0);
+        builder.create(loc, condition,
+                       ValueRange{length_while});
+      },
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value length_while = args[0];
+
+        Value current_nfct = builder.create(loc, nfct, c0);
+        builder.create(loc, c4, Rfftp_fctdata_fct,
+                       current_nfct);
+        index_increment(builder, loc, nfct);
+        Value length_next =
+            builder.create(loc, length_while, c2);
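+        // length_next = length / 4; the reduced length is written back and
+        // carried into the next factorization round.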
+        builder.create(loc, length_next, length, c0);
+
+        builder.create(loc, std::vector{length_next});
+      });
+
+  Value length_if = opBuilder.create(loc, length, c0);
+  Value length_if_mod_2 = opBuilder.create(loc, length_if, c2);
+  Value condition = opBuilder.create(
+      loc, arith::CmpIPredicate::eq, length_if_mod_2, c0);
+
+  // After the 4s are gone, peel at most one remaining factor of 2 and swap
+  // it to the front of the factor list.
+  opBuilder.create(
+      loc, condition, [&](OpBuilder &builder, Location loc) {
+        Value length_next = builder.create(loc, length_if, c1);
+        builder.create(loc, length_next, length, c0);
+
+        Value current_nfct = builder.create(loc, nfct, c0);
+        builder.create(loc, c2, Rfftp_fctdata_fct,
+                       current_nfct);
+        index_increment(builder, loc, nfct);
+
+        Value current_nfct_1 = builder.create(loc, nfct, c0);
+        Value nfctm1 = builder.create(loc, current_nfct_1, c1);
+        index_SWAP(builder, loc, Rfftp_fctdata_fct, nfctm1, c0);
+
+        builder.create(loc, std::nullopt);
+      });
+
+  TypeRange type1 = TypeRange{f64Ty};
+  TypeRange type2 = TypeRange{indexTy};
+
+  // maxl = floor(sqrt(length)) + 1 bounds the odd trial divisors below.
+  Value maxl =
+      opBuilder.create(loc, MemRefType::get(1, indexTy));
+  Value current_length2 = opBuilder.create(loc, length, c0);
+  Value current_length2_i32 = opBuilder.create(
+      loc, opBuilder.getI32Type(), current_length2);
+  Value length_f64 = opBuilder.create(
+      loc, opBuilder.getF64Type(), current_length2_i32);
+  Value sqrt_length = opBuilder.create(loc, length_f64);
+  Value maxl_index = opBuilder.create(
+      loc, opBuilder.getI32Type(), sqrt_length);
+  Value maxl_index_index = opBuilder.create(
+      loc, opBuilder.getIndexType(), maxl_index);
+  Value maxl_final = opBuilder.create(loc, maxl_index_index, c1);
+  opBuilder.create(loc, maxl_final, maxl, c0);
+
+  // Trial-divide by the odd candidates 3, 5, 7, ... while length > 1 and
+  // the divisor stays below maxl.
+  opBuilder.create(
+      loc, TypeRange{indexTy}, ValueRange{c3},
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value divisor = args[0];
+        Value length_while = builder.create(loc, length, c0);
+        Value current_maxl = builder.create(loc, maxl, c0);
+
+        Value condition1 = builder.create(
+            loc, arith::CmpIPredicate::sgt, length_while, c1);
+        Value condition2 = builder.create(
+            loc, arith::CmpIPredicate::slt, divisor, current_maxl);
+        Value and_cond =
+            builder.create(loc, condition1, condition2);
+        builder.create(loc, and_cond, ValueRange{divisor});
+      },
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value divisor = args[0];
+
+        Value length_while = builder.create(loc, length, c0);
+        Value length_mod_divisor =
+            builder.create(loc, length_while, divisor);
+        Value condition1 = builder.create(
+            loc, arith::CmpIPredicate::eq, length_mod_divisor, c0);
+        builder.create(
+            loc, condition1, [&](OpBuilder &b, Location loc) {
+              b.create(
+                  loc, TypeRange{indexTy}, ValueRange{c1},
+                  [&](OpBuilder &b2, Location loc, ValueRange args) {
+                    Value x = args[0];
+
+                    Value length_while_1 =
+                        b2.create(loc, length, c0);
+                    Value length_mod_divisor_1 =
+                        b2.create(loc, length_while_1, divisor);
+
+                    Value condition2 =
+                        b2.create(loc, arith::CmpIPredicate::eq,
+                                  length_mod_divisor_1, c0);
+                    b2.create(loc, condition2, ValueRange{x});
+                  },
+                  [&](OpBuilder &b2, Location loc, ValueRange args) {
+                    Value x = args[0];
+
+                    Value current_nfct =
+                        b2.create(loc, nfct, c0);
+                    b2.create(loc, divisor, Rfftp_fctdata_fct,
+                              current_nfct);
+                    index_increment(b2, loc, nfct);
+
+                    Value length_while_1 =
+                        b2.create(loc, length, c0);
+                    Value length_new =
+                        b2.create(loc, length_while_1, divisor);
+                    b2.create(loc, length_new, length, c0);
+
+                    b2.create(loc, std::vector{x});
+                  });
+
+              Value current_length2_1 =
+                  b.create(loc, length, c0);
+              Value currnet_length2_i32_1 = b.create(
+                  loc, opBuilder.getI32Type(),
current_length2_1); + Value length_f64_1 = b.create( + loc, opBuilder.getF64Type(), currnet_length2_i32_1); + Value sqrt_length_1 = b.create(loc, length_f64_1); + Value maxl_index_1 = + b.create(loc, b.getI32Type(), sqrt_length_1); + Value maxl_index_index_1 = b.create( + loc, opBuilder.getIndexType(), maxl_index_1); + Value maxl_final_1 = + b.create(loc, maxl_index_index_1, c1); + b.create(loc, maxl_final_1, maxl, c0); + + b.create(loc, std::nullopt); + }); + + Value divisor_next = builder.create(loc, divisor, c2); + builder.create(loc, std::vector{divisor_next}); + }); + + Value current_length1 = opBuilder.create(loc, length, c0); + Value condition1 = opBuilder.create( + loc, arith::CmpIPredicate::sgt, current_length1, c1); + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + Value current_nfct = builder.create(loc, nfct, c0); + builder.create(loc, current_length1, Rfftp_fctdata_fct, + current_nfct); + index_increment(builder, loc, nfct); + builder.create(loc, std::nullopt); + }); + + Value current_nfct1 = opBuilder.create(loc, nfct, c0); + opBuilder.create(loc, current_nfct1, Rfftp_plan_nfct, c0); + + return c0; +} + +Value index_to_f64(OpBuilder &opBuilder, Location loc, Value n) { + TypeRange type = TypeRange{opBuilder.getF64Type()}; + Value n_i32 = + opBuilder.create(loc, opBuilder.getI32Type(), n); + Value n_f64 = + opBuilder.create(loc, opBuilder.getF64Type(), n_i32); + return n_f64; +} + +Value f64_to_index(OpBuilder &opBuilder, Location loc, Value n_f64) { + TypeRange type = TypeRange{opBuilder.getI32Type()}; + Value n_i32 = + opBuilder.create(loc, opBuilder.getI32Type(), n_f64); + Value n_index = opBuilder.create( + loc, opBuilder.getIndexType(), n_i32); + return n_index; +} + +void my_sincosm1pi(OpBuilder &opBuilder, Location loc, Value a, Value res, + Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, resultType, res, SmallVector{bias}, + SmallVector{c2}, SmallVector{c1}); + + Value s = opBuilder.create(loc, a, a); + + Value r1 = opBuilder.create( + loc, APFloat(double(-1.0369917389758117e-4)), f64Ty); + Value r2 = opBuilder.create( + loc, APFloat(double(1.9294935641298806e-3)), f64Ty); + Value r3 = opBuilder.create( + loc, APFloat(double(-2.5806887942825395e-2)), f64Ty); + Value r4 = opBuilder.create( + loc, APFloat(double(2.3533063028328211e-1)), f64Ty); + Value r5 = opBuilder.create( + loc, APFloat(double(-1.3352627688538006e+0)), f64Ty); + Value r6 = opBuilder.create( + loc, APFloat(double(4.0587121264167623e+0)), f64Ty); + Value r7 = opBuilder.create( + loc, APFloat(double(-4.9348022005446790e+0)), f64Ty); + + Value fma1 = opBuilder.create(loc, r1, s, r2); + Value fma2 = opBuilder.create(loc, fma1, s, r3); + Value fma3 = opBuilder.create(loc, fma2, s, r4); + Value fma4 = opBuilder.create(loc, fma3, s, r5); + Value fma5 = opBuilder.create(loc, fma4, s, r6); + Value fma6 = opBuilder.create(loc, fma5, s, r7); + + Value c = opBuilder.create(loc, fma6, s); + + Value r8 = opBuilder.create( + loc, APFloat(double(4.6151442520157035e-4)), f64Ty); + Value r9 = opBuilder.create( + loc, APFloat(double(-7.3700183130883555e-3)), f64Ty); + Value r10 = 
opBuilder.create( + loc, APFloat(double(8.2145868949323936e-2)), f64Ty); + Value r11 = opBuilder.create( + loc, APFloat(double(-5.9926452893214921e-1)), f64Ty); + Value r12 = opBuilder.create( + loc, APFloat(double(2.5501640398732688e+0)), f64Ty); + Value r13 = opBuilder.create( + loc, APFloat(double(-5.1677127800499516e+0)), f64Ty); + + Value fma7 = opBuilder.create(loc, r8, s, r9); + Value fma8 = opBuilder.create(loc, fma7, s, r10); + Value fma9 = opBuilder.create(loc, fma8, s, r11); + Value fma10 = opBuilder.create(loc, fma9, s, r12); + Value fma11 = opBuilder.create(loc, fma10, s, r13); + + Value s_new = opBuilder.create(loc, s, a); + Value r = opBuilder.create(loc, fma11, s_new); + + Value pi = opBuilder.create( + loc, APFloat(double(3.1415926535897931e+0)), f64Ty); + Value s_final = opBuilder.create(loc, a, pi, r); + + opBuilder.create(loc, c, res_raw, c0); + opBuilder.create(loc, s_final, res_raw, c1); + + return; +} + +void calc_first_octant_extend2(OpBuilder &opBuilder, Location loc, Value den, + Value res, Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c50 = opBuilder.create(loc, 50); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, bias); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, resultType, res, SmallVector{bias}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value f2 = + opBuilder.create(loc, APFloat(double(2.0)), f64Ty); + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + + Value n_f64 = index_to_f64(opBuilder, loc, n); + Value l1_f64 = opBuilder.create(loc, n_f64); + Value l1 = f64_to_index(opBuilder, loc, l1_f64); + + opBuilder.create( + loc, c1, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iargs) { + Value i_f64 = index_to_f64(builder, loc, i); + Value den_f64 = index_to_f64(builder, loc, den); + Value arg = builder.create(loc, i_f64, den_f64); + Value arg_scaled = builder.create(loc, arg, f2); + + Value im2 = builder.create(loc, i, c2); + Value im2_bias = builder.create(loc, im2, bias); + + my_sincosm1pi(builder, loc, arg_scaled, res, im2_bias); + builder.create(loc, std::nullopt); + }); + + Value start_start = opBuilder.create(loc, l1, c0); + + opBuilder.create( + loc, start_start, n, l1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value start_loop, + ValueRange start_loop_args) { + Value start_f64 = index_to_f64(builder, loc, start_loop); + Value den_f64 = index_to_f64(builder, loc, den); + Value arg = builder.create(loc, start_f64, den_f64); + Value arg_scaled = builder.create(loc, arg, f2); + + Value cs = + builder.create(loc, MemRefType::get(2, f64Ty)); + my_sincosm1pi(builder, loc, arg_scaled, cs, c0); + + Value cs0 = builder.create(loc, cs, c0); + Value cs1 = builder.create(loc, cs, c1); + + Value cs0_plus_1 = builder.create(loc, cs0, f1); + + Value start_2 = builder.create(loc, 
start_loop, c2); + builder.create(loc, cs0_plus_1, res_raw, start_2); + Value start_2_plus_1 = builder.create(loc, start_2, c1); + builder.create(loc, cs1, res_raw, start_2_plus_1); + + Value n_minus_start = builder.create(loc, n, start_loop); + Value end_1 = builder.create(loc, l1, c0); + Value sum = builder.create(loc, start_loop, end_1); + Value condition = builder.create( + loc, arith::CmpIPredicate::sgt, sum, n); + Value end = builder.create(loc, condition, + n_minus_start, end_1); + + builder.create( + loc, c1, end, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value i_2 = b.create(loc, i, c2); + Value csx0 = b.create(loc, res_raw, i_2); + Value i_2_plus_1 = b.create(loc, i_2, c1); + Value csx1 = b.create(loc, res_raw, i_2_plus_1); + + Value tmp1 = b.create(loc, cs0, csx0); + Value tmp2 = b.create(loc, cs1, csx1); + Value tmp3 = b.create(loc, tmp1, tmp2); + Value tmp4 = b.create(loc, tmp3, cs0); + Value tmp5 = b.create(loc, tmp4, csx0); + Value res_real = b.create(loc, tmp5, f1); + + Value tmp6 = b.create(loc, cs0, csx1); + Value tmp7 = b.create(loc, cs1, csx0); + Value tmp8 = b.create(loc, tmp6, tmp7); + Value tmp9 = b.create(loc, tmp8, cs1); + Value res_imag = b.create(loc, tmp9, csx1); + + Value start_plus_i = b.create(loc, start_loop, i); + Value start_plus_i_2 = + b.create(loc, start_plus_i, c2); + Value start_plus_i_2_plus_1 = + b.create(loc, start_plus_i_2, c1); + b.create(loc, res_real, res_raw, start_plus_i_2); + b.create(loc, res_imag, res_raw, + start_plus_i_2_plus_1); + b.create(loc, std::nullopt); + }); + + builder.create(loc, cs); + builder.create(loc, std::nullopt); + }); + + opBuilder.create( + loc, c1, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange i_args) { + Value i_2 = builder.create(loc, i, c2); + Value val = builder.create(loc, res_raw, i_2); + Value val_plus_1 = builder.create(loc, val, f1); + builder.create(loc, val_plus_1, res_raw, i_2); + builder.create(loc, std::nullopt); + }); + + return; +} + +void calc_first_octant_extend1(OpBuilder &opBuilder, Location loc, Value den, + Value res, Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, bias); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, resultType, res, SmallVector{bias}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + + opBuilder.create(loc, f1, res_raw, c0); + opBuilder.create(loc, f0, res_raw, c1); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, n, c1); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + calc_first_octant_extend2(builder, loc, den, res, bias); + builder.create(loc, std::nullopt); + }); +} + +void calc_first_octant(OpBuilder &opBuilder, Location loc, 
Value den, Value res, + Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, n, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + calc_first_octant_extend1(builder, loc, den, res, bias); + builder.create(loc, std::nullopt); + }); +} + +void calc_first_quadrant(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, n); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value p_raw = opBuilder.create( + loc, resultType, res, SmallVector{n}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value n_times_2 = opBuilder.create(loc, n, c1); + calc_first_octant(opBuilder, loc, n_times_2, res, n); + + Value n_plus_2 = opBuilder.create(loc, n, c2); + Value ndone = opBuilder.create(loc, n_plus_2, c2); + Value ndonem1 = opBuilder.create(loc, ndone, c1); + Value ndone2 = opBuilder.create(loc, ndone, c2); + Value idx2_start = opBuilder.create(loc, ndone2, c2); + + Value i_start = opBuilder.create(loc, 0); + Value idx1_start = opBuilder.create(loc, 0); + + auto loop = opBuilder.create( + loc, i_start, ndonem1, c2, ValueRange{i_start, idx1_start, idx2_start}, + [&](OpBuilder &builder, Location loc, Value i_loop, + ValueRange i_loop_args) { + Value i_loop1 = i_loop_args[0]; + Value idx1 = i_loop_args[1]; + Value idx2 = i_loop_args[2]; + + Value p_2i = builder.create(loc, i_loop1, c2); + Value p_val = builder.create(loc, p_raw, p_2i); + builder.create(loc, p_val, res, idx1); + + Value p_2i_plus_1 = builder.create(loc, p_2i, c1); + Value p_val_1 = builder.create(loc, p_raw, p_2i_plus_1); + Value idx1_plus_1 = builder.create(loc, idx1, c1); + builder.create(loc, p_val_1, res, idx1_plus_1); + + Value p_2i_plus_3 = builder.create(loc, p_2i, c3); + Value p_val_3 = builder.create(loc, p_raw, p_2i_plus_3); + builder.create(loc, p_val_3, res, idx2); + + Value p_2i_plus_2 = builder.create(loc, p_2i, c2); + Value p_val_2 = builder.create(loc, p_raw, p_2i_plus_2); + Value idx2_plus_1 = builder.create(loc, idx2, c1); + builder.create(loc, p_val_2, res, idx2_plus_1); + + Value i_loop1_next = builder.create(loc, i_loop1, c2); + Value idx1_next = builder.create(loc, idx1, c2); + Value idx2_next = builder.create(loc, idx2, c2); + builder.create( + loc, std::vector{i_loop1_next, idx1_next, idx2_next}); + }); + + Value i_v = loop.getResults()[0]; + Value idx1_v = loop.getResults()[1]; + Value idx2_v = loop.getResults()[2]; + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::ne, i_v, ndone); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value p_2i = builder.create(loc, i_v, c2); + Value p_val = 
builder.create(loc, p_raw, p_2i); + builder.create(loc, p_val, res, idx1_v); + + Value p_2i_plus_1 = builder.create(loc, p_2i, c1); + Value p_val_1 = builder.create(loc, p_raw, p_2i_plus_1); + Value idx1_plus_1 = builder.create(loc, idx1_v, c1); + builder.create(loc, p_val_1, res, idx1_plus_1); + builder.create(loc, std::nullopt); + }); + + return; +} + +void calc_first_half(OpBuilder &opBuilder, Location loc, Value n, Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + IndexType indexTy = opBuilder.getIndexType(); + FloatType f64Ty = opBuilder.getF64Type(); + + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + + Value n_plus_1 = opBuilder.create(loc, n, c1); + Value ndone = opBuilder.create(loc, n_plus_1, c1); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, n); + Value remaining_size_p1 = + opBuilder.create(loc, remaining_size, c1); + + Value nm1 = opBuilder.create(loc, n, c1); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value p_raw = opBuilder.create( + loc, resultType, res, SmallVector{nm1}, + SmallVector{remaining_size_p1}, + SmallVector{c1}); + + Value n_times_4 = opBuilder.create(loc, n, c2); + calc_first_octant(opBuilder, loc, n_times_4, res, nm1); + + Value i4_start = opBuilder.create(loc, 0); + Value i_start = opBuilder.create(loc, 0); + Value in = opBuilder.create(loc, n, c0); + + auto loop = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{i4_start, i_start}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_minus_i4 = builder.create(loc, in, i4); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4, in_minus_i4); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value i4_2 = builder.create(loc, i4, c2); + Value i_2 = builder.create(loc, i, c2); + Value i4_2_p1 = builder.create(loc, i4_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_i4_2 = builder.create(loc, p_raw, i4_2); + Value p_i4_2_p1 = builder.create(loc, p_raw, i4_2_p1); + + builder.create(loc, p_i4_2, res, i_2); + builder.create(loc, p_i4_2_p1, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_0 = loop.getResults()[0]; + Value final_i_0 = loop.getResults()[1]; + + auto loop1 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_0, final_i_0}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value i4_minus_in = builder.create(loc, i4, in); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4_minus_in, c0); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value xm = builder.create(loc, in, i4); + Value xm_2 = 
builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, xm_2); + + builder.create(loc, p_xm_2_p1, res, i_2); + builder.create(loc, p_xm_2, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_1 = loop1.getResults()[0]; + Value final_i_1 = loop1.getResults()[1]; + + auto loop2 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_1, final_i_1}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_3 = builder.create(loc, in, c3); + Value in_3_m_i4 = builder.create(loc, in_3, i4); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4, in_3_m_i4); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value xm = builder.create(loc, i4, in); + Value xm_2 = builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, xm_2); + + Value m_p_xm_2_p1 = builder.create(loc, f0, p_xm_2_p1); + + builder.create(loc, m_p_xm_2_p1, res, i_2); + builder.create(loc, p_xm_2, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_2 = loop2.getResults()[0]; + Value final_i_2 = loop2.getResults()[1]; + + auto loop3 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_2, final_i_2}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value condition = builder.create( + loc, arith::CmpIPredicate::slt, i, ndone); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_2 = builder.create(loc, in, c2); + + Value xm = builder.create(loc, in_2, i4); + Value xm_2 = builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, xm_2); + + Value m_p_xm_2 = builder.create(loc, f0, p_xm_2); + + builder.create(loc, m_p_xm_2, res, i_2); + builder.create(loc, p_xm_2_p1, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + + builder.create(loc, std::vector{i4_next, i_next}); + }); + + return; +} + +void fill_first_quadrant(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c8 = opBuilder.create(loc, 8); + + FloatType f64Ty = opBuilder.getF64Type(); + + Value hsqt2 = opBuilder.create( + loc, APFloat(double(0.707106781186547524400844362104849)), f64Ty); 
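+  // hsqt2 = sqrt(0.5) = cos(pi/4) = sin(pi/4); when n is divisible by 8 it
+  // fills the two table slots at the quadrant midpoint.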
+ + Value quart = opBuilder.create(loc, n, c2); + Value n_mod_8 = opBuilder.create(loc, n, c8); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, n_mod_8, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value quart_plus_1 = builder.create(loc, quart, c1); + builder.create(loc, hsqt2, res, quart); + builder.create(loc, hsqt2, res, quart_plus_1); + builder.create(loc, std::nullopt); + }); + + Value two_quart = opBuilder.create(loc, quart, c2); + Value two_quart_minus_2 = opBuilder.create(loc, two_quart, c2); + + opBuilder.create( + loc, c2, quart, c2, ValueRange{two_quart_minus_2}, + [&](OpBuilder &builder, Location loc, Value i, ValueRange i_args) { + Value j = i_args[0]; + + Value i_plus_1 = builder.create(loc, i, c1); + Value j_plus_1 = builder.create(loc, j, c1); + + Value val_i = builder.create(loc, res, i); + Value val_i_plus_1 = builder.create(loc, res, i_plus_1); + + builder.create(loc, val_i_plus_1, res, j); + builder.create(loc, val_i, res, j_plus_1); + + Value j_next = builder.create(loc, j, c2); + builder.create(loc, j_next); + }); + + return; +} + +void fill_first_half(OpBuilder &opBuilder, Location loc, Value n, Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + FloatType f64Ty = opBuilder.getF64Type(); + Value c_1 = + opBuilder.create(loc, APFloat(double(-1.0)), f64Ty); + + Value half = opBuilder.create(loc, n, c1); + Value n_mod_4 = opBuilder.create(loc, n, c4); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, n_mod_4, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, c0, half, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value i_plus_1 = b.create(loc, i, c1); + Value i_plus_half = b.create(loc, i, half); + Value i_plus_half_plus_1 = + b.create(loc, i_plus_half, c1); + + Value val_i = b.create(loc, res, i); + Value val_i_plus_1 = b.create(loc, res, i_plus_1); + + Value neg_val_i_plus_1 = + b.create(loc, val_i_plus_1, c_1); + b.create(loc, neg_val_i_plus_1, res, + i_plus_half); + b.create(loc, val_i, res, i_plus_half_plus_1); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }, + [&](OpBuilder &builder, Location loc) { + Value two_half_minus_2 = builder.create(loc, half, c1); + Value two_half_minus_2_mul_2 = + builder.create(loc, two_half_minus_2, c2); + + builder.create( + loc, c2, half, c2, ValueRange{two_half_minus_2_mul_2}, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value j = i_args[0]; + Value i_plus_1 = builder.create(loc, i, c1); + Value j_plus_1 = builder.create(loc, j, c1); + Value val_i = b.create(loc, res, i); + Value val_i_plus_1 = b.create(loc, res, i_plus_1); + Value neg_val_i = b.create(loc, val_i, c_1); + b.create(loc, neg_val_i, res, j); + b.create(loc, val_i_plus_1, res, j_plus_1); + + Value j_next = builder.create(loc, j, c2); + b.create(loc, j_next); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +void sincos_2pibyn_half(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); 
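+  // The table is built from the smallest symmetric region that applies:
+  // n % 4 == 0 uses the first octant, even n the first quadrant, and odd n
+  // the first half period.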
+  Value c50 = opBuilder.create(loc, 50);
+
+  Value n_mod_4 = opBuilder.create(loc, n, c4);
+
+  Value condition = opBuilder.create(
+      loc, arith::CmpIPredicate::eq, n_mod_4, c0);
+
+  opBuilder.create(
+      loc, condition,
+      [&](OpBuilder &builder, Location loc) {
+        calc_first_octant(builder, loc, n, res, c0);
+
+        fill_first_quadrant(builder, loc, n, res);
+        fill_first_half(builder, loc, n, res);
+        builder.create(loc, std::nullopt);
+      },
+      [&](OpBuilder &builder, Location loc) {
+        Value n_mod_2 = builder.create(loc, n, c2);
+        Value condition1 = builder.create(
+            loc, arith::CmpIPredicate::eq, n_mod_2, c0);
+
+        builder.create(
+            loc, condition1,
+            [&](OpBuilder &b, Location loc) {
+              calc_first_quadrant(b, loc, n, res);
+              fill_first_half(b, loc, n, res);
+              b.create(loc, std::nullopt);
+            },
+            [&](OpBuilder &b, Location loc) {
+              calc_first_half(b, loc, n, res);
               b.create(loc, std::nullopt);
             });
-        builder.create(loc, std::nullopt);
       });
-
-  return;
 }

-void radf4(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa,
-           Value ido, Value l1, Value c0, Value c1, Value c2, Value c3) {
+// calculate the twiddle factors for the input length
+Value rfftp_comp_twiddle(OpBuilder &opBuilder, Location loc, Value length,
+                         Value Rfftp_fctdata_fct, Value Rfftp_fctdata_tw,
+                         Value Rfftp_fctdata_tws, Value Rfftp_plan_length,
+                         Value Rfftp_plan_nfct, Value Rfftp_plan_mem) {
+  Value c0 = opBuilder.create(loc, 0);
+  Value c1 = opBuilder.create(loc, 1);
+  Value c2 = opBuilder.create(loc, 2);
+  Value c3 = opBuilder.create(loc, 3);
+  Value c4 = opBuilder.create(loc, 4);
+  Value c5 = opBuilder.create(loc, 5);
+  Value c50 = opBuilder.create(loc, 50);
+
+  Value length_2 = opBuilder.create(loc, length, c2);
   FloatType f64Ty = opBuilder.getF64Type();
-  Value cdim = opBuilder.create(loc, 4);
-  Value hsqt2 = opBuilder.create(
-      loc, APFloat(double(0.70710678118654752440)), f64Ty);
-  Value idom1 = opBuilder.create(loc, ido, c1);
-
-  opBuilder.create(
-      loc, c0, l1, c1, std::nullopt,
-      [&](OpBuilder &builder, Location loc, Value iv, ValueRange iargs) {
-        Value cc0k3 = CC(builder, loc, cc, c0, iv, c3, ido, l1);
-        Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1);
-        std::vector tr1_tmp0 = PM(builder, loc, cc0k3, cc0k1);
-        CH(builder, loc, ch, c0, c2, iv, ido, cdim, tr1_tmp0[1]);
-
-        Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1);
-        Value cc0k2 = CC(builder, loc, cc, c0, iv, c2, ido, l1);
-        std::vector tr2_tmp1 = PM(builder, loc, cc0k0, cc0k2);
-        CH(builder, loc, ch, idom1, c1, iv, ido, cdim, tr2_tmp1[1]);
-        std::vector tmp2_tmp3 =
-            PM(builder, loc, tr2_tmp1[0], tr1_tmp0[0]);
-        CH(builder, loc, ch, c0, c0, iv, ido, cdim, tmp2_tmp3[0]);
-        CH(builder, loc, ch, idom1, c3, iv, ido, cdim, tmp2_tmp3[1]);
+  Value twid = opBuilder.create(
+      loc, MemRefType::get(ShapedType::kDynamic, f64Ty),
+      /*dynamicOperands=*/length_2);
-        builder.create(loc, std::nullopt);
-      });
+  Value plan_nfct = opBuilder.create(loc, Rfftp_plan_nfct, c0);
-  Value reminder = opBuilder.create(loc, ido, c2);
-  Value condition0 = opBuilder.create(
-      loc, arith::CmpIPredicate::eq, reminder, c0);
-  opBuilder.create(
-      loc, condition0, [&](OpBuilder &builder, Location loc) {
-        Value negHsqt2 = builder.create(
-            loc, APFloat(double(-0.70710678118654752440)), f64Ty);
+  sincos_2pibyn_half(opBuilder, loc, length, twid);
-        builder.create(
-            loc, c0, l1, c1, std::nullopt,
-            [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) {
-              Value ccidom1k1 = CC(b, loc, cc, idom1, iv, c1, ido, l1);
-              Value ccidom1k3 = CC(b, loc, cc, idom1, iv, c3, ido, l1);
-
Value tmp0 = b.create(loc, ccidom1k1, ccidom1k3); - Value ti1 = b.create(loc, negHsqt2, tmp0); + Value l1_start = opBuilder.create(loc, 1); - Value tmp1 = b.create(loc, ccidom1k1, ccidom1k3); - Value tr1 = b.create(loc, hsqt2, tmp1); + opBuilder.create( + loc, c0, plan_nfct, c1, ValueRange{l1_start}, + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { + Value l1 = k_args[0]; + + Value ip = builder.create(loc, Rfftp_fctdata_fct, k); + + Value l1_m_ip = builder.create(loc, l1, ip); + Value ido = builder.create(loc, length, l1_m_ip); + Value plan_nfct_m1 = builder.create(loc, plan_nfct, c1); + + Value condition1 = builder.create( + loc, arith::CmpIPredicate::slt, k, plan_nfct_m1); + + builder.create( + loc, condition1, [&](OpBuilder &b, Location loc) { + Value ido_m1 = b.create(loc, ido, c1); + Value ido_m1_d2 = b.create(loc, ido_m1, c2); + Value ido_m1_d2_p1 = b.create(loc, ido_m1_d2, c1); + + b.create( + loc, c1, ip, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value j, ValueRange j_args) { + b2.create( + loc, c1, ido_m1_d2_p1, c1, std::nullopt, + [&](OpBuilder &b3, Location loc, Value i, + ValueRange i_args) { + Value j2 = b3.create(loc, j, c2); + Value j2_l1 = b3.create(loc, j2, l1); + Value j2_l1_i = + b3.create(loc, j2_l1, i); + Value j2_l1_i_p1 = + b3.create(loc, j2_l1_i, c1); + + Value j_m1 = b3.create(loc, j, c1); + Value ido_m1_j_m1 = + b3.create(loc, ido_m1, j_m1); + + Value i2 = b3.create(loc, i, c2); + Value i2_m1 = b3.create(loc, i2, c1); + Value i2_m2 = b3.create(loc, i2, c2); + + Value tw_a = + b3.create(loc, ido_m1_j_m1, i2_m2); + Value tw_b = + b3.create(loc, ido_m1_j_m1, i2_m1); + + Value twid_a = + b3.create(loc, twid, j2_l1_i); + Value twid_b = + b3.create(loc, twid, j2_l1_i_p1); + + Value fct_k = b3.create( + loc, Rfftp_fctdata_tw, k); + + b3.create(loc, twid_a, fct_k, tw_a); + b3.create(loc, twid_b, fct_k, tw_b); + + b3.create(loc, std::nullopt); + }); + b2.create(loc, std::nullopt); + }); - Value ccidom1k0 = CC(b, loc, cc, idom1, iv, c0, ido, l1); - std::vector tmp2_tmp3 = PM(b, loc, ccidom1k0, tr1); - CH(b, loc, ch, idom1, c0, iv, ido, cdim, tmp2_tmp3[0]); - CH(b, loc, ch, idom1, c2, iv, ido, cdim, tmp2_tmp3[1]); + b.create(loc, std::nullopt); + }); - Value ccidom1k2 = CC(b, loc, cc, idom1, iv, c2, ido, l1); - std::vector tmp4_tmp5 = PM(b, loc, ti1, ccidom1k2); - CH(b, loc, ch, c0, c3, iv, ido, cdim, tmp4_tmp5[0]); - CH(b, loc, ch, c0, c1, iv, ido, cdim, tmp4_tmp5[1]); + Value condition2 = builder.create( + loc, arith::CmpIPredicate::sgt, ip, c5); + + builder.create( + loc, condition2, [&](OpBuilder &b, Location loc) { + Value fct_k = b.create(loc, Rfftp_fctdata_tws, k); + Value c_f0 = + b.create(loc, APFloat(double(0.0)), f64Ty); + Value c_f1 = + b.create(loc, APFloat(double(1.0)), f64Ty); + + b.create(loc, c_f1, fct_k, c0); + b.create(loc, c_f0, fct_k, c1); + + Value ip_div_2 = b.create(loc, ip, c1); + Value ip_div_2_p1 = b.create(loc, ip_div_2, c1); + + b.create( + loc, c1, ip_div_2_p1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value i, ValueRange i_args) { + Value i2 = b2.create(loc, i, c2); + Value i2_p1 = b2.create(loc, i2, c1); + Value ip_m_i = b2.create(loc, ip, i); + Value ip_m_i_2 = b2.create(loc, ip_m_i, c2); + Value ip_m_i_2_p1 = + b2.create(loc, ip_m_i_2, c1); + + Value length_div_ip = + b2.create(loc, length, ip); + Value i2_length_div_ip = + b2.create(loc, i2, length_div_ip); + Value i2_length_div_ip_p1 = + b2.create(loc, i2_length_div_ip, c1); + + Value twid_a = + b2.create(loc, twid, i2_length_div_ip); + Value 
+        Value condition2 = builder.create<arith::CmpIOp>(
+            loc, arith::CmpIPredicate::sgt, ip, c5);
+
+        builder.create<scf::IfOp>(
+            loc, condition2, [&](OpBuilder &b, Location loc) {
+              Value fct_k =
+                  b.create<memref::LoadOp>(loc, Rfftp_fctdata_tws, k);
+              Value c_f0 = b.create<arith::ConstantFloatOp>(
+                  loc, APFloat(double(0.0)), f64Ty);
+              Value c_f1 = b.create<arith::ConstantFloatOp>(
+                  loc, APFloat(double(1.0)), f64Ty);
+
+              b.create<memref::StoreOp>(loc, c_f1, fct_k, c0);
+              b.create<memref::StoreOp>(loc, c_f0, fct_k, c1);
+
+              Value ip_div_2 = b.create<arith::ShRUIOp>(loc, ip, c1);
+              Value ip_div_2_p1 = b.create<arith::AddIOp>(loc, ip_div_2, c1);
+
+              b.create<scf::ForOp>(
+                  loc, c1, ip_div_2_p1, c1, std::nullopt,
+                  [&](OpBuilder &b2, Location loc, Value i, ValueRange i_args) {
+                    Value i2 = b2.create<arith::MulIOp>(loc, i, c2);
+                    Value i2_p1 = b2.create<arith::AddIOp>(loc, i2, c1);
+                    Value ip_m_i = b2.create<arith::SubIOp>(loc, ip, i);
+                    Value ip_m_i_2 = b2.create<arith::MulIOp>(loc, ip_m_i, c2);
+                    Value ip_m_i_2_p1 =
+                        b2.create<arith::AddIOp>(loc, ip_m_i_2, c1);
+
+                    Value length_div_ip =
+                        b2.create<arith::DivSIOp>(loc, length, ip);
+                    Value i2_length_div_ip =
+                        b2.create<arith::MulIOp>(loc, i2, length_div_ip);
+                    Value i2_length_div_ip_p1 =
+                        b2.create<arith::AddIOp>(loc, i2_length_div_ip, c1);
+
+                    Value twid_a =
+                        b2.create<memref::LoadOp>(loc, twid, i2_length_div_ip);
+                    Value twid_b = b2.create<memref::LoadOp>(
+                        loc, twid, i2_length_div_ip_p1);
+                    Value twid_c = b2.create<arith::AddFOp>(loc, c_f0, twid_a);
+                    Value twid_d = b2.create<arith::SubFOp>(loc, c_f0, twid_b);
+
+                    b2.create<memref::StoreOp>(loc, twid_a, fct_k, i2);
+                    b2.create<memref::StoreOp>(loc, twid_b, fct_k, i2_p1);
+                    b2.create<memref::StoreOp>(loc, twid_c, fct_k, ip_m_i_2);
+                    b2.create<memref::StoreOp>(loc, twid_d, fct_k,
+                                               ip_m_i_2_p1);
+                    b2.create<scf::YieldOp>(loc, std::nullopt);
+                  });
+              b.create<scf::YieldOp>(loc, std::nullopt);
+            });
-        builder.create<scf::YieldOp>(loc, std::nullopt);
+        Value l1_next = builder.create<arith::MulIOp>(loc, l1, ip);
+        builder.create<scf::YieldOp>(loc, l1_next);
+      });
-  Value condition1 =
-      opBuilder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, ido, c2);
-  opBuilder.create<scf::IfOp>(
-      loc, condition1, [&](OpBuilder &builder, Location loc) {
-        radf4Extend(builder, loc, cc, ch, wa, ido, l1, cdim, c0, c1, c2, c3);
-        builder.create<scf::YieldOp>(loc, std::nullopt);
-      });
+  opBuilder.create<memref::DeallocOp>(loc, twid);
-  return;
+  return c0;
 }
 
-void radf5Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch,
-                 Value wa, Value ido, Value l1, Value cdim, Value tr11,
-                 Value tr12, Value ti11, Value ti12, Value c0, Value c1,
-                 Value c2, Value c3, Value c4) {
-  opBuilder.create<scf::ForOp>(
-      loc, c0, l1, c1, std::nullopt,
-      [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) {
-        builder.create<scf::ForOp>(
-            loc, c2, ido, c2, std::nullopt,
-            [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) {
-              Value ic = b.create<arith::SubIOp>(loc, ido, i);
-              Value icm1 = b.create<arith::SubIOp>(loc, ic, c1);
-              Value im1 = b.create<arith::SubIOp>(loc, i, c1);
-              Value im2 = b.create<arith::SubIOp>(loc, i, c2);
+// Calculate the twiddle factors and generate the computation order of the
+// butterfly operators.
+std::vector<Value> make_rfftp_plan(OpBuilder &opBuilder, Location loc,
+                                   Value length) {
-              Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1);
-              Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1);
-              Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1);
-              Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1);
-              std::vector<Value> dr2_di2 =
-                  MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1);
+  Value c0 = opBuilder.create<arith::ConstantIndexOp>(loc, 0);
+  Value c1 = opBuilder.create<arith::ConstantIndexOp>(loc, 1);
+  Value c2 = opBuilder.create<arith::ConstantIndexOp>(loc, 2);
+  Value c3 = opBuilder.create<arith::ConstantIndexOp>(loc, 3);
+  Value c4 = opBuilder.create<arith::ConstantIndexOp>(loc, 4);
+  Value c5 = opBuilder.create<arith::ConstantIndexOp>(loc, 5);
-              Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1);
-              Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1);
-              Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1);
-              Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1);
-              std::vector<Value> dr3_di3 =
-                  MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2);
+  int64_t NFCT_num = 25;
+  Value NFCT = opBuilder.create<arith::ConstantIndexOp>(loc, NFCT_num);
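The six memrefs threaded through `rfftp_comp_twiddle` and `make_rfftp_plan` emulate a pocketfft-style plan record. There is no struct type at the MLIR level, but the intended shape is roughly the following (C++ sketch, names hypothetical):

    #include <cstddef>
    // Assumed layout mirrored by the six memrefs (up to NFCT_num == 25 stages).
    struct rfftp_fctdata_ref {
      size_t fct;   // stage radix           -> Rfftp_fctdata_fct[k]
      double *tw;   // stage twiddles        -> Rfftp_fctdata_tw[k]
      double *tws;  // radfg extra twiddles  -> Rfftp_fctdata_tws[k]
    };
    struct rfftp_plan_ref {
      size_t length;             // -> Rfftp_plan_length[0]
      size_t nfct;               // -> Rfftp_plan_nfct[0]
      double *mem;               // -> Rfftp_plan_mem (scratch)
      rfftp_fctdata_ref fct[25];
    };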
-              Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1);
-              Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1);
-              Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1);
-              Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1);
-              std::vector<Value> dr4_di4 =
-                  MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3);
+  FloatType f64Ty = opBuilder.getF64Type();
+  IndexType indexTy = opBuilder.getIndexType();
-              Value wa3im2 = WA(b, loc, wa, c3, im2, ido, c1);
-              Value wa3im1 = WA(b, loc, wa, c3, im1, ido, c1);
-              Value ccim1k4 = CC(b, loc, cc, im1, k, c4, ido, l1);
-              Value ccik4 = CC(b, loc, cc, i, k, c4, ido, l1);
-              std::vector<Value> dr5_di5 =
-                  MULPM(b, loc, wa3im2, wa3im1, ccim1k4, ccik4);
+  Value length_2 = opBuilder.create<arith::MulIOp>(loc, length, c2);
-              std::vector<Value> cr2_ci5 = PM(b, loc, dr5_di5[0], dr2_di2[0]);
-              std::vector<Value> ci2_cr5 = PM(b, loc, dr2_di2[1], dr5_di5[1]);
-              std::vector<Value> cr3_ci4 = PM(b, loc, dr4_di4[0], dr3_di3[0]);
-              std::vector<Value> ci3_cr4 = PM(b, loc, dr3_di3[1], dr4_di4[1]);
+  MemRefType type = MemRefType::get(NFCT_num, indexTy);
+  // MemRefType type1 = MemRefType::get(length_num2, f64Ty);
+  MemRefType type1 = MemRefType::get(ShapedType::kDynamic, f64Ty);
+  MemRefType type2 = MemRefType::get(NFCT_num, type1);
+  MemRefType type3 = MemRefType::get(1, indexTy);
+  MemRefType type4 = MemRefType::get(1, f64Ty);
-              Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1);
-              Value tmpch0 = b.create<arith::AddFOp>(loc, ccim1k0, cr2_ci5[0]);
-              Value chim10k = b.create<arith::AddFOp>(loc, tmpch0, cr3_ci4[0]);
-              CH(b, loc, ch, im1, c0, k, ido, cdim, chim10k);
+  Value Rfftp_fctdata_fct = opBuilder.create<memref::AllocOp>(loc, type);
+  Value Rfftp_fctdata_tw = opBuilder.create<memref::AllocOp>(loc, type2);
+  Value Rfftp_fctdata_tws = opBuilder.create<memref::AllocOp>(loc, type2);
+  Value Rfftp_plan_length = opBuilder.create<memref::AllocOp>(loc, type3);
+  Value Rfftp_plan_nfct = opBuilder.create<memref::AllocOp>(loc, type3);
+  Value Rfftp_plan_mem = opBuilder.create<memref::AllocOp>(loc, type4);
-              Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1);
-              Value tmpch1 = b.create<arith::AddFOp>(loc, ccik0, ci2_cr5[0]);
-              Value chi0k = b.create<arith::AddFOp>(loc, tmpch1, ci3_cr4[0]);
-              CH(b, loc, ch, i, c0, k, ido, cdim, chi0k);
+  opBuilder.create<memref::StoreOp>(loc, length, Rfftp_plan_length, c0);
+  opBuilder.create<memref::StoreOp>(loc, c0, Rfftp_plan_nfct, c0);
-              Value tmp0 = b.create<arith::MulFOp>(loc, tr11, cr2_ci5[0]);
-              Value tmp1 = b.create<arith::AddFOp>(loc, ccim1k0, tmp0);
-              Value tmp2 = b.create<arith::MulFOp>(loc, tr12, cr3_ci4[0]);
-              Value tr2 = b.create<arith::AddFOp>(loc, tmp1, tmp2);
+  opBuilder.create<scf::ForOp>(
+      loc, c0, NFCT, c1, std::nullopt,
+      [&](OpBuilder &builder, Location loc, Value i, ValueRange iargs) {
+        builder.create<memref::StoreOp>(loc, c0, Rfftp_fctdata_fct, i);
-              Value tmp3 = b.create<arith::MulFOp>(loc, tr11, ci2_cr5[0]);
-              Value tmp4 = b.create<arith::AddFOp>(loc, ccik0, tmp3);
-              Value tmp5 = b.create<arith::MulFOp>(loc, tr12, ci3_cr4[0]);
-              Value ti2 = b.create<arith::AddFOp>(loc, tmp4, tmp5);
+        Value tw_i = builder.create<memref::AllocOp>(
+            loc, type1, /*dynamicOperands=*/length_2);
+        builder.create<memref::StoreOp>(loc, tw_i, Rfftp_fctdata_tw, i);
+        Value tws_i = builder.create<memref::AllocOp>(
+            loc, type1, /*dynamicOperands=*/length_2);
+        builder.create<memref::StoreOp>(loc, tws_i, Rfftp_fctdata_tws, i);
-              Value tmp6 = b.create<arith::MulFOp>(loc, tr12, cr2_ci5[0]);
-              Value tmp7 = b.create<arith::AddFOp>(loc, ccim1k0, tmp6);
-              Value tmp8 = b.create<arith::MulFOp>(loc, tr11, cr3_ci4[0]);
-              Value tr3 = b.create<arith::AddFOp>(loc, tmp7, tmp8);
+        builder.create<scf::YieldOp>(loc, std::nullopt);
+      });
-              Value tmp9 = b.create<arith::MulFOp>(loc, tr12, ci2_cr5[0]);
-              Value tmp10 = b.create<arith::AddFOp>(loc, ccik0, tmp9);
-              Value tmp11 = b.create<arith::MulFOp>(loc, tr11, ci3_cr4[0]);
-              Value ti3 = b.create<arith::AddFOp>(loc, tmp10, tmp11);
+  Value condition = opBuilder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ne, length, c1);
-              std::vector<Value> tr5_tr4 =
-                  MULPM(b, loc, ci2_cr5[1], ci3_cr4[1], ti11, ti12);
-              std::vector<Value> ti5_ti4 =
-                  MULPM(b, loc, cr2_ci5[1], cr3_ci4[1], ti11, ti12);
+  opBuilder.create<scf::IfOp>(
+      loc, condition, [&](OpBuilder &builder, Location loc) {
+        Value xxx = builder.create<arith::ConstantIndexOp>(loc, 1);
+        rfftp_factorize(builder, loc, Rfftp_fctdata_fct, Rfftp_fctdata_tw,
+                        Rfftp_fctdata_tws, Rfftp_plan_length, Rfftp_plan_nfct,
+                        Rfftp_plan_mem);
+        rfftp_comp_twiddle(builder, loc, length, Rfftp_fctdata_fct,
+                           Rfftp_fctdata_tw, Rfftp_fctdata_tws,
+                           Rfftp_plan_length, Rfftp_plan_nfct, Rfftp_plan_mem);
+        builder.create<scf::YieldOp>(loc, std::nullopt);
+      });
-              std::vector<Value> chtmp0 = PM(b, loc, tr2, tr5_tr4[0]);
-              CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp0[0]);
-              CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp0[1]);
+  return {Rfftp_fctdata_fct, Rfftp_fctdata_tw, Rfftp_fctdata_tws,
+          Rfftp_plan_length, Rfftp_plan_nfct, Rfftp_plan_mem};
+}
-              std::vector<Value> chtmp1 = PM(b, loc, ti5_ti4[0], ti2);
-              CH(b, loc, ch, i, c2, k, ido, cdim, chtmp1[0]);
-              CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp1[1]);
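`rfftp_factorize` itself is outside this hunk; the call above assumes the usual pocketfft behavior of peeling factors of 4 and 2 first and then odd divisors, writing the radices into `Rfftp_fctdata_fct` and the stage count into `Rfftp_plan_nfct`. A scalar sketch of that assumed contract:

    #include <cstddef>
    #include <utility>
    // Assumed behavior of rfftp_factorize (sketch, not the actual body).
    size_t factorize_ref(size_t len, size_t *fct) {
      size_t nfct = 0;
      while ((len & 3) == 0) { fct[nfct++] = 4; len >>= 2; }
      if ((len & 1) == 0) {
        len >>= 1;
        fct[nfct++] = 2;
        std::swap(fct[0], fct[nfct - 1]); // move the factor 2 to the front
      }
      for (size_t d = 3; d * d <= len; d += 2)
        while (len % d == 0) { fct[nfct++] = d; len /= d; }
      if (len > 1) fct[nfct++] = len;
      return nfct;
    }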
+void memref_SWAP(OpBuilder &opBuilder, Location loc, Value p, Value p1) {
+  Value c0 = opBuilder.create<arith::ConstantIndexOp>(loc, 0);
+  Value c1 = opBuilder.create<arith::ConstantIndexOp>(loc, 1);
+  Value c2 = opBuilder.create<arith::ConstantIndexOp>(loc, 2);
+  Value c3 = opBuilder.create<arith::ConstantIndexOp>(loc, 3);
+  Value c4 = opBuilder.create<arith::ConstantIndexOp>(loc, 4);
+  Value c5 = opBuilder.create<arith::ConstantIndexOp>(loc, 5);
-              std::vector<Value> chtmp2 = PM(b, loc, tr3, tr5_tr4[1]);
-              CH(b, loc, ch, im1, c4, k, ido, cdim, chtmp2[0]);
-              CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp2[1]);
+  Value length = opBuilder.create<memref::DimOp>(loc, p, c0);
-              std::vector<Value> chtmp3 = PM(b, loc, ti5_ti4[1], ti3);
-              CH(b, loc, ch, i, c4, k, ido, cdim, chtmp3[0]);
-              CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp3[1]);
+  opBuilder.create<scf::ForOp>(
+      loc, c0, length, c1, std::nullopt,
+      [&](OpBuilder builder, Location loc, Value i, ValueRange i_args) {
+        Value val_p = builder.create<memref::LoadOp>(loc, p, i);
+        Value val_p1 = builder.create<memref::LoadOp>(loc, p1, i);
+
+        builder.create<memref::StoreOp>(loc, val_p, p1, i);
+        builder.create<memref::StoreOp>(loc, val_p1, p, i);
+        builder.create<scf::YieldOp>(loc, std::nullopt);
+      });
+}
+
+void flag_SWAP(OpBuilder &opBuilder, Location loc, Value flag) {
+  Value c0 = opBuilder.create<arith::ConstantIndexOp>(loc, 0);
+  Value c1 = opBuilder.create<arith::ConstantIndexOp>(loc, 1);
+
+  Value val = opBuilder.create<memref::LoadOp>(loc, flag, c0);
+  Value condition = opBuilder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, val, c0);
+
+  Value x = opBuilder.create<arith::SelectOp>(loc, condition, c1, c0);
+
+  opBuilder.create<memref::StoreOp>(loc, x, flag, c0);
+}
+
+void copy_and_norm(OpBuilder &opBuilder, Location loc, Value c, Value p1,
+                   Value n, Value fct, Value flag) {
+  Value c0 = opBuilder.create<arith::ConstantIndexOp>(loc, 0);
+  Value c1 = opBuilder.create<arith::ConstantIndexOp>(loc, 1);
+  Value c2 = opBuilder.create<arith::ConstantIndexOp>(loc, 2);
+  Value c3 = opBuilder.create<arith::ConstantIndexOp>(loc, 3);
+  FloatType f64Ty = opBuilder.getF64Type();
+  Value f1 = opBuilder.create<arith::ConstantFloatOp>(
+      loc, APFloat(double(1.0)), f64Ty);
+
+  Value flag_val = opBuilder.create<memref::LoadOp>(loc, flag, c0);
+  Value condition = opBuilder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::eq, flag_val, c0);
+
+  opBuilder.create<scf::IfOp>(
+      loc, condition,
+      [&](OpBuilder &builder, Location loc) {
+        Value condition1 = builder.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::ONE, fct, f1);
+        builder.create<scf::IfOp>(
+            loc, condition1,
+            [&](OpBuilder &b, Location loc) {
+              b.create<scf::ForOp>(
+                  loc, c0, n, c1, std::nullopt,
+                  [&](OpBuilder b2, Location loc, Value i, ValueRange i_args) {
+                    Value p1_i = b2.create<memref::LoadOp>(loc, p1, i);
+                    Value v = b2.create<arith::MulFOp>(loc, fct, p1_i);
+                    b2.create<memref::StoreOp>(loc, v, c, i);
+                    b2.create<scf::YieldOp>(loc, std::nullopt);
+                  });
+              b.create<scf::YieldOp>(loc, std::nullopt);
+            },
+            [&](OpBuilder &b, Location loc) {
+              b.create<scf::ForOp>(
+                  loc, c0, n, c1, std::nullopt,
+                  [&](OpBuilder b2, Location loc, Value i, ValueRange i_args) {
+                    Value val = b2.create<memref::LoadOp>(loc, p1, i);
+                    b2.create<memref::StoreOp>(loc, val, c, i);
+                    b2.create<scf::YieldOp>(loc, std::nullopt);
+                  });
+              b.create<scf::YieldOp>(loc, std::nullopt);
+            });
+        builder.create<scf::YieldOp>(loc, std::nullopt);
+      },
+      [&](OpBuilder &builder, Location loc) {
+        Value condition2 = builder.create<arith::CmpFOp>(
+            loc, arith::CmpFPredicate::ONE, fct, f1);
+        builder.create<scf::IfOp>(
+            loc, condition2, [&](OpBuilder &b, Location loc) {
+              b.create<scf::ForOp>(
+                  loc, c0, n, c1, std::nullopt,
+                  [&](OpBuilder &b2, Location loc, Value i, ValueRange i_args) {
+                    Value c_i = b2.create<memref::LoadOp>(loc, c, i);
+                    Value newC = b2.create<arith::MulFOp>(loc, fct, c_i);
+                    b2.create<memref::StoreOp>(loc, newC, c, i);
+                    b2.create<scf::YieldOp>(loc, std::nullopt);
+                  });
+              b.create<scf::YieldOp>(loc, std::nullopt);
+            });
+        builder.create<scf::YieldOp>(loc, std::nullopt);
+      });
-        builder.create<scf::YieldOp>(loc, std::nullopt);
-      });
-
-  return;
 }
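A scalar reading of `copy_and_norm` above: `flag` tracks which buffer of the ping-pong pair currently holds the result, and the output is scaled by `fct` only when `fct != 1.0` (illustrative C++):

    #include <cstddef>
    // flag_is_zero corresponds to flag[0] == 0 above: result is still in p1.
    void copy_and_norm_ref(double *c, const double *p1, size_t n, double fct,
                           bool flag_is_zero) {
      if (flag_is_zero) {
        if (fct != 1.0)
          for (size_t i = 0; i < n; ++i) c[i] = fct * p1[i];
        else
          for (size_t i = 0; i < n; ++i) c[i] = p1[i];
      } else if (fct != 1.0) {
        for (size_t i = 0; i < n; ++i) c[i] *= fct;
      }
    }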
-void radf5(OpBuilder &builder, Location loc, Value cc, Value ch, Value wa,
-           Value ido, Value l1, Value c0, Value c1, Value c2, Value c3,
-           Value c4) {
-  FloatType f64Ty = builder.getF64Type();
-  Value cdim = builder.create<arith::ConstantIndexOp>(loc, 5);
-  Value tr11 = builder.create<arith::ConstantFloatOp>(
-      loc, APFloat(double(0.3090169943749474241)), f64Ty);
-  Value tr12 = builder.create<arith::ConstantFloatOp>(
-      loc, APFloat(double(-0.8090169943749474241)), f64Ty);
-  Value ti11 = builder.create<arith::ConstantFloatOp>(
-      loc, APFloat(double(0.95105651629515357212)), f64Ty);
-  Value ti12 = builder.create<arith::ConstantFloatOp>(
-      loc, APFloat(double(0.58778525229247312917)), f64Ty);
-  Value idom1 = builder.create<arith::SubIOp>(loc, ido, c1);
+// FFT forward function for real numbers.
+void rfftp_forward(OpBuilder &opBuilder, Location loc, Value Rfftp_fctdata_fct,
+                   Value Rfftp_fctdata_tw, Value Rfftp_fctdata_tws,
+                   Value Rfftp_plan_length, Value Rfftp_plan_nfct,
+                   Value Rfftp_plan_mem, Value c, Value fct) {
+
+  Value c0 = opBuilder.create<arith::ConstantIndexOp>(loc, 0);
+  Value c1 = opBuilder.create<arith::ConstantIndexOp>(loc, 1);
+  Value c2 = opBuilder.create<arith::ConstantIndexOp>(loc, 2);
+  Value c3 = opBuilder.create<arith::ConstantIndexOp>(loc, 3);
+  Value c4 = opBuilder.create<arith::ConstantIndexOp>(loc, 4);
+  Value c5 = opBuilder.create<arith::ConstantIndexOp>(loc, 5);
+  Value c20 = opBuilder.create<arith::ConstantIndexOp>(loc, 20);
+  FloatType f64Ty = opBuilder.getF64Type();
-  builder.create<scf::ForOp>(
-      loc, c0, l1, c1, std::nullopt,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) {
-        Value cc0k4 = CC(b, loc, cc, c0, iv, c4, ido, l1);
-        Value cc0k1 = CC(b, loc, cc, c0, iv, c1, ido, l1);
-        std::vector<Value> cr2_ci5 = PM(b, loc, cc0k4, cc0k1);
+  Value n = opBuilder.create<memref::LoadOp>(loc, Rfftp_plan_length, c0);
-        Value cc0k3 = CC(b, loc, cc, c0, iv, c3, ido, l1);
-        Value cc0k2 = CC(b, loc, cc, c0, iv, c2, ido, l1);
-        std::vector<Value> cr3_ci4 = PM(b, loc, cc0k3, cc0k2);
+  Value condition =
+      opBuilder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, n, c1);
-        Value cc0k0 = CC(b, loc, cc, c0, iv, c0, ido, l1);
-        Value tmpch0 = b.create<arith::AddFOp>(loc, cc0k0, cr2_ci5[0]);
-        Value ch0 = b.create<arith::AddFOp>(loc, tmpch0, cr3_ci4[0]);
-        CH(b, loc, ch, c0, c0, iv, ido, cdim, ch0);
+  opBuilder.create<scf::IfOp>(
+      loc, condition, [&](OpBuilder &builder, Location loc) {
+        Value flag = builder.create<memref::AllocOp>(
+            loc, MemRefType::get(1, builder.getIndexType()));
+        builder.create<memref::StoreOp>(loc, c1, flag, c0);
+        Value l1_raw = builder.create<arith::AddIOp>(loc, n, c0);
+        Value nf = builder.create<memref::LoadOp>(loc, Rfftp_plan_nfct, c0);
-        Value tmpch1 = b.create<arith::MulFOp>(loc, tr11, cr2_ci5[0]);
-        Value tmpch2 = b.create<arith::MulFOp>(loc, tr12, cr3_ci4[0]);
-        Value tmpch3 = b.create<arith::AddFOp>(loc, cc0k0, tmpch1);
-        Value ch1 = b.create<arith::AddFOp>(loc, tmpch2, tmpch3);
-        CH(b, loc, ch, idom1, c1, iv, ido, cdim, ch1);
+        MemRefType cType = dyn_cast<MemRefType>(c.getType());
+        Value dimSize = builder.create<memref::DimOp>(loc, c, 0);
+        Value ch = builder.create<memref::AllocOp>(
+            loc, cType, /*dynamicOperands=*/dimSize);
-        Value tmpch4 = b.create<arith::MulFOp>(loc, ti11, cr2_ci5[1]);
-        Value tmpch5 = b.create<arith::MulFOp>(loc, ti12, cr3_ci4[1]);
-        Value ch2 = b.create<arith::AddFOp>(loc, tmpch4, tmpch5);
-        CH(b, loc, ch, c0, c2, iv, ido, cdim, ch2);
+        // Value ch = builder.create<memref::AllocOp>(
+        //     loc, MemRefType::get(cType.getShape(), f64Ty));
-        Value tmpch6 = b.create<arith::MulFOp>(loc, tr12, cr2_ci5[0]);
-        Value tmpch7 = b.create<arith::MulFOp>(loc, tr11, cr3_ci4[0]);
-        Value tmpch8 = b.create<arith::AddFOp>(loc, tmpch6, tmpch7);
-        Value ch3 = b.create<arith::AddFOp>(loc, cc0k0, tmpch8);
-        CH(b, loc, ch, idom1, c3, iv, ido, cdim, ch3);
+        FailureOr<StridedLayoutAttr> computelayout = StridedLayoutAttr::get(
+            opBuilder.getContext(),
+            /*offset=*/ShapedType::kDynamic, /*strides=*/{1});
+        MemRefType resultType =
+            MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout);
-        Value tmpch9 = b.create<arith::MulFOp>(loc, ti12, cr2_ci5[1]);
-        Value tmpch10 = b.create<arith::MulFOp>(loc, ti11, cr3_ci4[1]);
-        Value ch4 = b.create<arith::SubFOp>(loc, tmpch9, tmpch10);
-        CH(b, loc, ch, c0, c4, iv, ido, cdim, ch4);
+        // memref<?xf64, strided<[1], offset: ?>>
-        b.create<scf::YieldOp>(loc, std::nullopt);
-      });
+        Value p1_raw = builder.create<memref::SubViewOp>(
+            loc, resultType, c, SmallVector<OpFoldResult>{c0},
+            SmallVector<OpFoldResult>{n}, SmallVector<OpFoldResult>{c1});
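The loop that follows is the pocketfft-style driver: radices are consumed from the back of the factor list, `ido` and `l1` are recomputed per stage, and the two subviews are swapped after every butterfly pass. In scalar form:

    #include <cstddef>
    #include <utility>
    // Scalar outline of the stage loop built below (kernel dispatch elided).
    void rfftp_forward_ref(size_t n, size_t nf, const size_t *fct,
                           double *p1, double *p2) {
      size_t l1 = n;
      for (size_t k1 = 0; k1 < nf; ++k1) {
        size_t k = nf - k1 - 1; // factors consumed in reverse order
        size_t ip = fct[k];
        size_t ido = n / l1;    // first stage runs with ido == 1
        l1 /= ip;
        // radf4 / radf2 / radf3 / radf5 / radfg on (p1 -> p2) goes here.
        std::swap(p1, p2);      // ping-pong, mirrored by memref_SWAP below
      }
    }

Because memref SSA values cannot be pointer-swapped, the generated code emulates `std::swap` by copying elements (`memref_SWAP`) and tracking the parity in `flag` (`flag_SWAP`).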
-  Value condition =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, ido, c1);
-  builder.create<scf::IfOp>(loc, condition, [&](OpBuilder &b, Location loc) {
-    radf5Extend(b, loc, cc, ch, wa, ido, l1, cdim, tr11, tr12, ti11, ti12, c0,
-                c1, c2, c3, c4);
-    b.create<scf::YieldOp>(loc, std::nullopt);
-  });
+        Value p2_raw = builder.create<memref::SubViewOp>(
+            loc, resultType, ch, SmallVector<OpFoldResult>{c0},
+            SmallVector<OpFoldResult>{n}, SmallVector<OpFoldResult>{c1});
-
-  return;
+        builder.create<scf::ForOp>(
+            loc, c0, nf, c1, ValueRange{l1_raw},
+            [&](OpBuilder b, Location loc, Value k1, ValueRange k1_args) {
+              Value l1_old = k1_args[0];
+
+              Value nf_m_k1 = b.create<arith::SubIOp>(loc, nf, k1);
+              Value k = b.create<arith::SubIOp>(loc, nf_m_k1, c1);
+              Value ip = b.create<memref::LoadOp>(loc, Rfftp_fctdata_fct, k);
+              Value ido = b.create<arith::DivSIOp>(loc, n, l1_old);
+              Value l1 = b.create<arith::DivSIOp>(loc, l1_old, ip);
+
+              Value tw = b.create<memref::LoadOp>(loc, Rfftp_fctdata_tw, k);
+
+              Value condition1 = b.create<arith::CmpIOp>(
+                  loc, arith::CmpIPredicate::eq, ip, c4);
+
+              b.create<scf::IfOp>(
+                  loc, condition1,
+                  [&](OpBuilder &b2, Location loc) {
+                    radf4(b2, loc, p1_raw, p2_raw, tw, ido, l1, c0, c1, c2, c3);
+                    b2.create<scf::YieldOp>(loc, std::nullopt);
+                  },
+                  [&](OpBuilder &b2, Location loc) {
+                    Value condition2 = b2.create<arith::CmpIOp>(
+                        loc, arith::CmpIPredicate::eq, ip, c2);
+                    b2.create<scf::IfOp>(
+                        loc, condition2,
+                        [&](OpBuilder &b3, Location loc) {
+                          radf2(b3, loc, p1_raw, p2_raw, tw, ido, l1);
+                          b3.create<scf::YieldOp>(loc, std::nullopt);
+                        },
+                        [&](OpBuilder &b3, Location loc) {
+                          Value condition3 = b3.create<arith::CmpIOp>(
+                              loc, arith::CmpIPredicate::eq, ip, c3);
+                          b3.create<scf::IfOp>(
+                              loc, condition3,
+                              [&](OpBuilder &b4, Location loc) {
+                                radf3(b4, loc, p1_raw, p2_raw, tw, ido, l1);
+                                b4.create<scf::YieldOp>(loc, std::nullopt);
+                              },
+                              [&](OpBuilder &b4, Location loc) {
+                                Value condition4 = b4.create<arith::CmpIOp>(
+                                    loc, arith::CmpIPredicate::eq, ip, c5);
+                                b4.create<scf::IfOp>(
+                                    loc, condition4,
+                                    [&](OpBuilder &b5, Location loc) {
+                                      radf5(b5, loc, p1_raw, p2_raw, tw, ido,
+                                            l1, c0, c1, c2, c3, c4);
+                                      b5.create<scf::YieldOp>(loc,
+                                                              std::nullopt);
+                                    },
+                                    [&](OpBuilder &b5, Location loc) {
+                                      Value tws = b5.create<memref::LoadOp>(
+                                          loc, Rfftp_fctdata_tws, k);
+                                      radfg(b5, loc, p1_raw, p2_raw, tw, tws,
+                                            ido, ip, l1);
+                                      memref_SWAP(b5, loc, p1_raw, p2_raw);
+                                      flag_SWAP(b5, loc, flag);
+                                      b5.create<scf::YieldOp>(loc,
+                                                              std::nullopt);
+                                    });
+                                b4.create<scf::YieldOp>(loc, std::nullopt);
+                              });
+                          b3.create<scf::YieldOp>(loc, std::nullopt);
+                        });
+                    b2.create<scf::YieldOp>(loc, std::nullopt);
+                  });
+
+              memref_SWAP(b, loc, p1_raw, p2_raw);
+              flag_SWAP(b, loc, flag);
+
+              b.create<scf::YieldOp>(loc, l1);
+            });
+
+        copy_and_norm(builder, loc, c, p1_raw, n, fct, flag);
+
+        builder.create<scf::YieldOp>(loc, std::nullopt);
+      });
 }
 
 // Calculate abspower of bufferMem and store result to a specific line in the
@@ -1198,11 +3544,20 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0,
       });
   Value multiplied = mulfOp.getResult(0);
 
-  Value bufferMem =
+  Value bufferMem_raw =
       builder.create<bufferization::ToMemrefOp>(loc, mTp, multiplied);
 
-  // Compute 'dap.rfft400' operation, result stores in `bufferMem`.
-  builder.create<dap::RFFT400Op>(loc, bufferMem);
+  MemRefType type0 = MemRefType::get({400}, f64Ty);
+  MemRefType type1 = MemRefType::get(ShapedType::kDynamic, f64Ty);
+
+  Value bufferMem_rfft =
+      builder.create<memref::CastOp>(loc, type1, bufferMem_raw);
+
+  // Compute 'dap.rfft' operation, result stores in `bufferMem`.
+  builder.create<dap::RFFTOp>(loc, bufferMem_rfft);
+
+  Value bufferMem =
+      builder.create<memref::CastOp>(loc, type0, bufferMem_rfft);
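`absPower` (defined elsewhere in this file) consumes the in-place real FFT result. A pocketfft-style `rfftp_forward` is assumed to leave the buffer in halfcomplex packing, i.e. [Re0, Re1, Im1, Re2, Im2, ..., Re_{n/2}] for even n; under that assumption the per-bin power is computed as follows (illustrative C++, `abs_power_ref` is a hypothetical name):

    #include <cstddef>
    // Power spectrum of a halfcomplex buffer x of even length n.
    void abs_power_ref(const double *x, double *out, size_t n) {
      out[0] = x[0] * x[0];
      for (size_t k = 1; k < n / 2; ++k) {
        double re = x[2 * k - 1], im = x[2 * k];
        out[k] = re * re + im * im;
      }
      out[n / 2] = x[n - 1] * x[n - 1]; // Nyquist bin
    }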
       // Store the result in a single line specified by `iv`.
       absPower(builder, loc, bufferMem, spectrogram, iv, c0, c1, c2);
@@ -1278,14 +3633,14 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0,
 }
 
 namespace {
-class DAPRFFT400Lowering : public OpRewritePattern<dap::RFFT400Op> {
+class DAPRFFTLowering : public OpRewritePattern<dap::RFFTOp> {
 public:
-  using OpRewritePattern<dap::RFFT400Op>::OpRewritePattern;
+  using OpRewritePattern<dap::RFFTOp>::OpRewritePattern;
 
-  explicit DAPRFFT400Lowering(MLIRContext *context)
+  explicit DAPRFFTLowering(MLIRContext *context)
       : OpRewritePattern(context) {}
 
-  LogicalResult matchAndRewrite(dap::RFFT400Op op,
+  LogicalResult matchAndRewrite(dap::RFFTOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
     auto ctx = op->getContext();
@@ -1297,114 +3652,35 @@ class DAPRFFT400Lowering : public OpRewritePattern<dap::RFFT400Op> {
     Value c3 = rewriter.create<arith::ConstantIndexOp>(loc, 3);
     Value c4 = rewriter.create<arith::ConstantIndexOp>(loc, 4);
     Value c5 = rewriter.create<arith::ConstantIndexOp>(loc, 5);
+    Value c9 = rewriter.create<arith::ConstantIndexOp>(loc, 9);
+    Value c24 = rewriter.create<arith::ConstantIndexOp>(loc, 24);
+    Value c25 = rewriter.create<arith::ConstantIndexOp>(loc, 25);
+    Value c50 = rewriter.create<arith::ConstantIndexOp>(loc, 50);
+
+    Value inputFeatures = rewriter.create<bufferization::ToTensorOp>(
+        loc, bufferMem, /*restrict=*/true, /*writable=*/true);
+    Value inputFeaturesSize =
+        rewriter.create<tensor::DimOp>(loc, inputFeatures, c0);
 
     FloatType f64Ty = rewriter.getF64Type();
+
     Value f0 = rewriter.create<arith::ConstantFloatOp>(
        loc, APFloat(double(0.0)), f64Ty);
-    int64_t inputLength = 400;
-
-    // Generate ch MemRef
-    RankedTensorType tensorTy = RankedTensorType::get({inputLength}, f64Ty);
-    MemRefType m25Ty = MemRefType::get({inputLength}, f64Ty);
-    Value chTensor = rewriter.create<tensor::SplatOp>(loc, tensorTy, f0);
-    Value ch =
-        rewriter.create<bufferization::ToMemrefOp>(loc, m25Ty, chTensor);
-
-    // Generate wa MemRefs
-    std::vector<double> tw0Vec{
-        0.999877,  0.015707, 0.999507,  0.031411, 0.998890,  0.047106,
-        0.998027,  0.062791, 0.996917,  0.078459, 0.995562,  0.094108,
-        0.993961,  0.109734, 0.992115,  0.125333, 0.990024,  0.140901,
-        0.987688,  0.156434, 0.985109,  0.171929, 0.982287,  0.187381,
-        0.979223,  0.202787, 0.975917,  0.218143, 0.972370,  0.233445,
-        0.968583,  0.248690, 0.964557,  0.263873, 0.960294,  0.278991,
-        0.955793,  0.294040, 0.951057,  0.309017, 0.946085,  0.323917,
-        0.940881,  0.338738, 0.935444,  0.353475, 0.929776,  0.368125,
-        0.923880,  0.382683, 0.917755,  0.397148, 0.911403,  0.411514,
-        0.904827,  0.425779, 0.898028,  0.439939, 0.891007,  0.453990,
-        0.883766,  0.467930, 0.876307,  0.481754, 0.868632,  0.495459,
-        0.860742,  0.509041, 0.852640,  0.522499, 0.844328,  0.535827,
-        0.835807,  0.549023, 0.827081,  0.562083, 0.818150,  0.575005,
-        0.809017,  0.587785, 0.799685,  0.600420, 0.790155,  0.612907,
-        0.780430,  0.625243, 0.770513,  0.637424, 0.760406,  0.649448,
-        0.750111,  0.661312, 0.739631,  0.673013, 0.728969,  0.684547,
-        0.718126,  0.695913, 0.000000,  0.999507, 0.031411,  0.998027,
-        0.062791,  0.995562, 0.094108,  0.992115, 0.125333,  0.987688,
-        0.156434,  0.982287, 0.187381,  0.975917, 0.218143,  0.968583,
-        0.248690,  0.960294, 0.278991,  0.951057, 0.309017,  0.940881,
-        0.338738,  0.929776, 0.368125,  0.917755, 0.397148,  0.904827,
-        0.425779,  0.891007, 0.453990,  0.876307, 0.481754,  0.860742,
-        0.509041,  0.844328, 0.535827,  0.827081, 0.562083,  0.809017,
-        0.587785,  0.790155, 0.612907,  0.770513, 0.637424,  0.750111,
-        0.661312,  0.728969, 0.684547,  0.707107, 0.707107,  0.684547,
-        0.728969,  0.661312, 0.750111,  0.637424, 0.770513,  0.612907,
-        0.790155,  0.587785, 0.809017,  0.562083, 0.827081,  0.535827,
-        0.844328,  0.509041, 0.860742,  0.481754, 0.876307,  0.453990,
-        0.891007,  0.425779, 0.904827,  0.397148, 0.917755,  0.368125,
-        0.929776,  0.338738, 0.940881,  0.309017, 0.951057,  0.278991,
-        0.960294,  0.248690, 0.968583,  0.218143, 0.975917,  0.187381,
-        0.982287,  0.156434, 0.987688,  0.125333, 0.992115,  0.094108,
-        0.995562,  0.062791, 0.998027,  0.031411, 0.999507,  0.000000,
-        0.998890,  0.047106, 0.995562,  0.094108, 0.990024,  0.140901,
-        0.982287,  0.187381, 0.972370,  0.233445, 0.960294,  0.278991,
-        0.946085,  0.323917, 0.929776,  0.368125, 0.911403,  0.411514,
-        0.891007,  0.453990, 0.868632,  0.495459, 0.844328,  0.535827,
-        0.818150,  0.575005, 0.790155,  0.612907, 0.760406,  0.649448,
-        0.728969,  0.684547, 0.695913,  0.718126, 0.661312,  0.750111,
-        0.625243,  0.780430, 0.587785,  0.809017, 0.549023,  0.835807,
-        0.509041,  0.860742, 0.467930,  0.883766, 0.425779,  0.904827,
-        0.382683,  0.923880, 0.338738,  0.940881, 0.294040,  0.955793,
-        0.248690,  0.968583, 0.202787,  0.979223, 0.156434,  0.987688,
-        0.109734,  0.993961, 0.062791,  0.998027, 0.015707,  0.999877,
-        -0.031411, 0.999507, -0.078459, 0.996917, -0.125333, 0.992115,
-        -0.171929, 0.985109, -0.218143, 0.975917, -0.263873, 0.964557,
-        -0.309017, 0.951057, -0.353475, 0.935444, -0.397148, 0.917755,
-        -0.439939, 0.898028, -0.481754, 0.876307, -0.522499, 0.852640,
-        -0.562083, 0.827081, -0.600420, 0.799685, -0.637424, 0.770513,
-        -0.673013, 0.739631, 0.000000};
-    Value wa0Tensor = rewriter.create<arith::ConstantOp>(
-        loc, DenseFPElementsAttr::get(RankedTensorType::get({297}, f64Ty),
-                                      ArrayRef<double>(tw0Vec)));
-    Value wa0 = rewriter.create<bufferization::ToMemrefOp>(
-        loc, MemRefType::get({297}, f64Ty), wa0Tensor);
-
-    std::vector<double> tw1Vec{
-        0.998027,  0.062791, 0.992115,  0.125333, 0.982287,  0.187381,
-        0.968583,  0.248690, 0.951057,  0.309017, 0.929776,  0.368125,
-        0.904827,  0.425779, 0.876307,  0.481754, 0.844328,  0.535827,
-        0.809017,  0.587785, 0.770513,  0.637424, 0.728969,  0.684547,
-        0.992115,  0.125333, 0.968583,  0.248690, 0.929776,  0.368125,
-        0.876307,  0.481754, 0.809017,  0.587785, 0.728969,  0.684547,
-        0.637424,  0.770513, 0.535827,  0.844328, 0.425779,  0.904827,
-        0.309017,  0.951057, 0.187381,  0.982287, 0.062791,  0.998027,
-        0.982287,  0.187381, 0.929776,  0.368125, 0.844328,  0.535827,
-        0.728969,  0.684547, 0.587785,  0.809017, 0.425779,  0.904827,
-        0.248690,  0.968583, 0.062791,  0.998027, -0.125333, 0.992115,
-        -0.309017, 0.951057, -0.481754, 0.876307, -0.637424, 0.770513};
-    Value wa1Tensor = rewriter.create<arith::ConstantOp>(
-        loc, DenseFPElementsAttr::get(RankedTensorType::get({72}, f64Ty),
-                                      ArrayRef<double>(tw1Vec)));
-    Value wa1 = rewriter.create<bufferization::ToMemrefOp>(
-        loc, MemRefType::get({72}, f64Ty), wa1Tensor);
-
-    std::vector<double> tw2Vec{0.968583, 0.248690, 0.876307,  0.481754,
-                               0.876307, 0.481754, 0.535827,  0.844328,
-                               0.728969, 0.684547, 0.062791,  0.998027,
-                               0.535827, 0.844328, -0.425779, 0.904827};
-    Value wa2Tensor = rewriter.create<arith::ConstantOp>(
-        loc, DenseFPElementsAttr::get(RankedTensorType::get({16}, f64Ty),
                                      ArrayRef<double>(tw2Vec)));
-    Value wa2 = rewriter.create<bufferization::ToMemrefOp>(
-        loc, MemRefType::get({16}, f64Ty), wa2Tensor);
-
-    Value c16 = rewriter.create<arith::ConstantIndexOp>(loc, 16);
-    Value c25 = rewriter.create<arith::ConstantIndexOp>(loc, 25);
-    Value c80 = rewriter.create<arith::ConstantIndexOp>(loc, 80);
-    Value c100 = rewriter.create<arith::ConstantIndexOp>(loc, 100);
+    Value f1 = rewriter.create<arith::ConstantFloatOp>(
+        loc, APFloat(double(1.0)), f64Ty);
+
+    std::vector<Value> plan = make_rfftp_plan(rewriter, loc, inputFeaturesSize);
+
+    Value Rfftp_fctdata_fct = plan[0];
+    Value Rfftp_fctdata_tw = plan[1];
+    Value Rfftp_fctdata_tws = plan[2];
+    Value Rfftp_plan_length = plan[3];
+    Value Rfftp_plan_nfct = plan[4];
+    Value Rfftp_plan_mem = plan[5];
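As a sanity check, the general plan reproduces the hard-coded 400-point pipeline removed just below. Factorization gives 400 = 4 * 4 * 5 * 5, i.e. fct = {4, 4, 5, 5}, and the stage loop consumes the factors from the back: stage 1 has ip = 5, ido = 400/400 = 1, l1 = 400/5 = 80 (the old radf5 call with c1, c80); stage 2 has ip = 5, ido = 400/80 = 5, l1 = 80/5 = 16 (radf5 with c5, c16); stage 3 has ip = 4, ido = 400/16 = 25, l1 = 16/4 = 4 (radf4 with c25, c4); stage 4 has ip = 4, ido = 400/4 = 100, l1 = 4/4 = 1 (radf4 with c100, c1). The per-stage twiddle count (ido - 1) * (ip - 1) evaluates to 16, 72, and 297 for stages 2 through 4, which matches the sizes of the removed wa2, wa1, and wa0 tables; those constants are exactly what make_rfftp_plan now computes at runtime.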
-    radf5(rewriter, loc, bufferMem, ch, wa2, c1, c80, c0, c1, c2, c3, c4);
-    radf5(rewriter, loc, ch, bufferMem, wa2, c5, c16, c0, c1, c2, c3, c4);
-    radf4(rewriter, loc, bufferMem, ch, wa1, c25, c4, c0, c1, c2, c3);
-    radf4(rewriter, loc, ch, bufferMem, wa0, c100, c1, c0, c1, c2, c3);
+    rfftp_forward(rewriter, loc, Rfftp_fctdata_fct, Rfftp_fctdata_tw,
+                  Rfftp_fctdata_tws, Rfftp_plan_length, Rfftp_plan_nfct,
+                  Rfftp_plan_mem, bufferMem, f1);
 
     rewriter.eraseOp(op);
     return success();
@@ -1568,7 +3844,7 @@ class DAPWhisperPreprocessLowering
 
 void populateExtendDAPConversionPatterns(RewritePatternSet &patterns) {
   patterns.add<DAPWhisperPreprocessLowering>(patterns.getContext());
-  patterns.add<DAPRFFT400Lowering>(patterns.getContext());
+  patterns.add<DAPRFFTLowering>(patterns.getContext());
   // TODO : extract operators
 }
@@ -1599,6 +3875,7 @@ class ExtendDAPPass
     registry.insert<vector::VectorDialect>();
    registry.insert<memref::MemRefDialect>();
    registry.insert<scf::SCFDialect>();
+    registry.insert<tensor::TensorDialect>();
     // Buddy Compiler designed dialect
     registry.insert<buddy::dap::DAPDialect>();
   }
@@ -1620,6 +3897,7 @@ void ExtendDAPPass::runOnOperation() {
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalDialect<memref::MemRefDialect>();
   target.addLegalDialect<scf::SCFDialect>();
+  target.addLegalDialect<tensor::TensorDialect>();
   // Add legal operations.
   target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
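For downstream users nothing else changes: the pass entry point stays the same, and the renamed pattern now fires on any `dap.rfft` whose operand is a rank-1 f64 memref, with the plan built at lowering time from the runtime length. A minimal sketch of driving just these patterns (standard MLIR conversion boilerplate; the exact legal-dialect list here is an assumption):

    // Sketch: apply the DAP lowering patterns to `module` inside a pass.
    RewritePatternSet patterns(context);
    populateExtendDAPConversionPatterns(patterns); // now adds DAPRFFTLowering
    ConversionTarget target(*context);
    target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
                           memref::MemRefDialect, tensor::TensorDialect,
                           bufferization::BufferizationDialect>();
    if (failed(applyPartialConversion(module, target, std::move(patterns))))
      signalPassFailure();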