From 2b2a8dfb54bb00263e59e7196e99d41b10332cc3 Mon Sep 17 00:00:00 2001 From: BrokenArrow Date: Tue, 22 Oct 2024 11:28:36 +0800 Subject: [PATCH] [Midend] Add RFFT op in Extend DAP Pass (#387) --- examples/DAPDialect/CMakeLists.txt | 7 + examples/DAPDialect/RFFT.cpp | 75 + .../buddy/DAP/DSP/WhisperPreprocess.h | 7 + frontend/Interfaces/lib/DAP-extend.mlir | 4 + midend/include/Dialect/DAP/DAPOps.td | 6 +- .../Conversion/ExtendDAP/ExtendDAPPass.cpp | 2960 +++++++++++++++-- 6 files changed, 2715 insertions(+), 344 deletions(-) create mode 100644 examples/DAPDialect/RFFT.cpp diff --git a/examples/DAPDialect/CMakeLists.txt b/examples/DAPDialect/CMakeLists.txt index dff9b10ffb..96b921ee3a 100644 --- a/examples/DAPDialect/CMakeLists.txt +++ b/examples/DAPDialect/CMakeLists.txt @@ -62,3 +62,10 @@ target_link_libraries(buddy-whisper-preprocess BuddyLibDAP mlir_c_runner_utils ) + +add_executable(buddy-rfft RFFT.cpp) +add_dependencies(buddy-rfft buddy-opt) +target_link_libraries(buddy-rfft + BuddyLibDAP + mlir_c_runner_utils +) diff --git a/examples/DAPDialect/RFFT.cpp b/examples/DAPDialect/RFFT.cpp new file mode 100644 index 0000000000..993fec95e1 --- /dev/null +++ b/examples/DAPDialect/RFFT.cpp @@ -0,0 +1,75 @@ +//===- RFFT.cpp - Example of DAP RFFT Operation ---------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// An example of the RFFT function from Whisper Preprocessor operation. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#define testLength 840 + +using namespace dap; +using namespace std; + +// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + +// Write preprocessing results to a text file. +void printResult(MemRef &outputMemRef) { + ofstream fout("whisperPreprocessResultRFFT.txt"); + // Print title. + fout << "-----------------------------------------" << std::endl; + fout << "[ Buddy RFFT Result ]" << std::endl; + fout << "-----------------------------------------" << std::endl; + // Print reuslt data. + for (int i = 0; i < testLength; ++i) { + fout << outputMemRef[i] << std::endl; + } + fout.close(); +} + +int main() { + // Print the title of this example. + const std::string title = "RFFT Operation Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + double *inputAlign = new double[testLength]; + for (int i = 0; i < testLength; ++i) { + inputAlign[i] = static_cast(i); + } + intptr_t inputSizes[1] = {testLength}; + MemRef inputMemRef(inputAlign, inputSizes); + + printLogLabel(); + std::cout << "Running RFFT operation" << std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::RFFT(&inputMemRef); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "RFFT time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; + + printResult(inputMemRef); + + return 0; +} diff --git a/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h b/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h index a6c3ef3b2e..d0d1d8fb63 100644 --- a/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h +++ b/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h @@ -40,6 +40,9 @@ extern "C" { // first operand. void _mlir_ciface_buddy_whisperPreprocess(MemRef *outputFeatures, MemRef *inputRawSpeech); + +void _mlir_ciface_buddy_RFFT(MemRef *inputRawSpeech); + } } // namespace detail @@ -49,6 +52,10 @@ void whisperPreprocess(MemRef *inputRawSpeech, detail::_mlir_ciface_buddy_whisperPreprocess(outputFeatures, inputRawSpeech); } + +void RFFT(MemRef *inputRawSpeech) { + detail::_mlir_ciface_buddy_RFFT(inputRawSpeech); +} } // namespace dap #endif // FRONTEND_INTERFACES_BUDDY_DAP_DSP_WHISPERPREPROCESS diff --git a/frontend/Interfaces/lib/DAP-extend.mlir b/frontend/Interfaces/lib/DAP-extend.mlir index c77fe38735..2c9b7a5a3b 100644 --- a/frontend/Interfaces/lib/DAP-extend.mlir +++ b/frontend/Interfaces/lib/DAP-extend.mlir @@ -2,3 +2,7 @@ func.func @buddy_whisperPreprocess(%in : memref) -> memref<1x80x3000xf32> %out = dap.whisper_preprocess %in : memref to memref<1x80x3000xf32> return %out : memref<1x80x3000xf32> } +func.func @buddy_RFFT(%in : memref) -> () { + dap.rfft %in : memref + return +} diff --git a/midend/include/Dialect/DAP/DAPOps.td b/midend/include/Dialect/DAP/DAPOps.td index 70d7a21fe6..d14ca5cfcd 100644 --- a/midend/include/Dialect/DAP/DAPOps.td +++ b/midend/include/Dialect/DAP/DAPOps.td @@ -93,8 +93,8 @@ def DAP_IirOp : DAP_Op<"iir"> { }]; } -def DAP_RFFT400Op : DAP_Op<"rfft400"> { - let summary = "RFFT operation for length 400."; +def DAP_RFFTOp : DAP_Op<"rfft"> { + let summary = "RFFT operation."; let description = [{ The RFFT algorithm is designed to handle real-valued input signals. Real signals exhibit conjugate symmetry in the frequency domain, meaning that @@ -105,7 +105,7 @@ def DAP_RFFT400Op : DAP_Op<"rfft400"> { Example: ```mlir - dap.rfft400 %data : memref<400xf64> + dap.rfft %data : memref ``` }]; diff --git a/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp b/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp index 20918fda97..32fc42fcf7 100644 --- a/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp +++ b/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp @@ -38,6 +38,7 @@ using namespace vector; using namespace mlir::arith; using namespace mlir::linalg; using namespace mlir::bufferization; +using namespace mlir::scf; //===----------------------------------------------------------------------===// // Rewrite Pattern @@ -756,6 +757,28 @@ Value padReflect(PatternRewriter &rewriter, Location loc, Value c0, Value c1, return padOp2.getResult(); } +// function to print a memref for debug +void printMemref(OpBuilder &rewriter, Location loc, Value input, int l) { + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value length = rewriter.create(loc, l); + rewriter.create(loc, "Print Start:\n"); + + rewriter.create( + loc, c0, length, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + Value x = b.create(loc, input, i); + b.create(loc, x); + + b.create(loc, std::nullopt); + }); + + rewriter.create(loc, "\n"); +} + +// WA CC CH PM MULPM C1 C1w C2 CH2 CH2w CH_radfg CCw CSARR AR AI IANG are helper +// functions for RFFTP inline Value WA(OpBuilder &builder, Location loc, Value wa, Value x, Value i, Value ido, Value c1) { Value idom1 = builder.create(loc, ido, c1); @@ -799,15 +822,695 @@ inline std::vector MULPM(OpBuilder &builder, Location loc, Value c, builder.create(loc, tmp3, tmp4)}; } -void radf4Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, - Value wa, Value ido, Value l1, Value cdim, Value c0, Value c1, - Value c2, Value c3) { +inline Value C1(OpBuilder &builder, Location loc, Value cc, Value a, Value b, + Value c, Value ido, Value l1) { + Value tmp1 = builder.create(loc, l1, c); + Value tmp2 = builder.create(loc, tmp1, b); + Value tmp3 = builder.create(loc, tmp2, ido); + Value index = builder.create(loc, tmp3, a); + return builder.create(loc, cc, index); +} + +inline void C1w(OpBuilder &builder, Location loc, Value cc, Value a, Value b, + Value c, Value ido, Value l1, Value toWrite) { + Value tmp1 = builder.create(loc, l1, c); + Value tmp2 = builder.create(loc, tmp1, b); + Value tmp3 = builder.create(loc, tmp2, ido); + Value index = builder.create(loc, tmp3, a); + builder.create(loc, toWrite, cc, index); + return; +} + +inline Value C2(OpBuilder &builder, Location loc, Value cc, Value a, Value b, + Value idl1) { + Value tmp1 = builder.create(loc, idl1, b); + Value index = builder.create(loc, tmp1, a); + return builder.create(loc, cc, index); +} + +inline Value CH2(OpBuilder &builder, Location loc, Value ch, Value a, Value b, + Value idl1) { + Value tmp1 = builder.create(loc, idl1, b); + Value index = builder.create(loc, tmp1, a); + return builder.create(loc, ch, index); +} + +inline void CH2w(OpBuilder &builder, Location loc, Value ch, Value a, Value b, + Value idl1, Value toWrite) { + Value tmp1 = builder.create(loc, idl1, b); + Value index = builder.create(loc, tmp1, a); + builder.create(loc, toWrite, ch, index); + return; +} + +inline Value CH_radfg(OpBuilder &builder, Location loc, Value ch, Value a, + Value b, Value c, Value ido, Value l1) { + Value tmp = builder.create(loc, l1, c); + Value tmp1 = builder.create(loc, b, tmp); + Value tmp2 = builder.create(loc, tmp1, ido); + Value index = builder.create(loc, tmp2, a); + return builder.create(loc, ch, index); +} + +inline void CCw(OpBuilder &builder, Location loc, Value cc, Value a, Value b, + Value c, Value ido, Value cdim, Value toWrite) { + Value tmp = builder.create(loc, cdim, c); + Value tmp1 = builder.create(loc, b, tmp); + Value tmp2 = builder.create(loc, tmp1, ido); + Value index = builder.create(loc, tmp2, a); + builder.create(loc, toWrite, cc, index); + return; +} + +inline Value CSARR(OpBuilder &builder, Location loc, Value csarr, Value index) { + + return builder.create(loc, csarr, index); +} + +inline Value AR(OpBuilder &builder, Location loc, Value csarr, Value iang) { + Value c2 = builder.create(loc, 2); + Value index = builder.create(loc, iang, c2); + return CSARR(builder, loc, csarr, index); +} + +inline Value AI(OpBuilder &builder, Location loc, Value csarr, Value iang) { + Value c1 = builder.create(loc, 1); + Value c2 = builder.create(loc, 2); + Value tmp = builder.create(loc, iang, c2); + Value index = builder.create(loc, tmp, c1); + return CSARR(builder, loc, csarr, index); +} + +inline Value IANG(OpBuilder &builder, Location loc, Value iang, Value l, + Value ip) { + + Value iang_new = builder.create(loc, iang, l); + + Value condition = builder.create( + loc, arith::CmpIPredicate::sge, iang_new, ip); + + auto result = builder.create( + loc, condition, + [&](OpBuilder &b, Location loc) { + Value res = b.create(loc, iang_new, ip); + b.create(loc, ValueRange{res}); + }, + [&](OpBuilder &b, Location loc) { + b.create(loc, ValueRange{iang_new}); + }); + + return result.getResult(0); +} + +void radfgExtend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value csarr, Value ido, Value ip, Value l1) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + + Value cdim = opBuilder.create(loc, ip, c0); + Value tmp0 = opBuilder.create(loc, ip, c1); + Value ipph = opBuilder.create(loc, tmp0, c2); + Value idom1 = opBuilder.create(loc, ido, c1); + Value idom2 = opBuilder.create(loc, ido, c2); + Value idl1 = opBuilder.create(loc, ido, l1); + + opBuilder.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value ik, ValueRange ik_args) { + Value c2ik0 = C2(builder, loc, cc, ik, c0, idl1); + CH2w(builder, loc, ch, ik, c0, idl1, c2ik0); + builder.create(loc, std::nullopt); + }); + + opBuilder.create( + loc, c1, ipph, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value j, ValueRange j_args) { + builder.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value ik, ValueRange ik_args) { + Value c2ikj = C2(b, loc, cc, ik, j, idl1); + Value ch2ik0 = CH2(b, loc, ch, ik, c0, idl1); + Value ch2ik0_updated = + b.create(loc, ch2ik0, c2ikj); + + CH2w(b, loc, ch, ik, c0, idl1, ch2ik0_updated); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); + opBuilder.create( loc, c0, l1, c1, std::nullopt, - [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) { + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { + builder.create( + loc, c0, ido, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value chik0 = CH_radfg(b, loc, ch, i, k, c0, ido, l1); + + CCw(b, loc, cc, i, c0, k, ido, cdim, chik0); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); + + Value j_start_0 = opBuilder.create(loc, 1); + Value jc_start_0 = opBuilder.create(loc, ip, c1); + + opBuilder.create( + loc, c1, ipph, c1, ValueRange{j_start_0, jc_start_0}, + [&](OpBuilder &builder, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + + Value tmp = builder.create(loc, j, c2); + Value j2 = builder.create(loc, tmp, c1); + Value j2p1 = builder.create(loc, j2, c1); + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value k, ValueRange k_args) { + Value ch0kj = CH_radfg(b, loc, ch, c0, k, j, ido, l1); + CCw(b, loc, cc, idom1, j2, k, ido, cdim, ch0kj); + + Value ch0kjc = CH_radfg(b, loc, ch, c0, k, jc, ido, l1); + CCw(b, loc, cc, c0, j2p1, k, ido, cdim, ch0kjc); + + b.create(loc, std::nullopt); + }); + + Value j_next = builder.create(loc, j, c1); + Value jc_next = builder.create(loc, jc, c1); + builder.create(loc, std::vector{j_next, jc_next}); + }); + + Value condition1 = + opBuilder.create(loc, arith::CmpIPredicate::ne, ido, l1); + + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + Value j_start_1 = opBuilder.create(loc, 1); + Value jc_start_1 = opBuilder.create(loc, ip, c1); + + builder.create( + loc, c1, ipph, c1, ValueRange{j_start_1, jc_start_1}, + [&](OpBuilder &b, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + + Value tmp = b.create(loc, j, c2); + Value j2 = b.create(loc, tmp, c1); + Value j2p1 = b.create(loc, j2, c1); + + b.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value k, ValueRange k_args) { + Value i_start_0 = b2.create(loc, 1); + Value ic_start_0 = b2.create(loc, ido, c3); + + b2.create( + loc, c1, idom1, c2, ValueRange{i_start_0, ic_start_0}, + [&](OpBuilder &b3, Location loc, Value i_loop, + ValueRange i_loop_args) { + Value i = i_loop_args[0]; + Value ic = i_loop_args[1]; + + Value ip1 = b3.create(loc, i, c1); + Value icp1 = b3.create(loc, ic, c1); + + Value chikj = CH_radfg(b3, loc, ch, i, k, j, ido, l1); + Value chikjc = + CH_radfg(b3, loc, ch, i, k, jc, ido, l1); + Value tmp2 = + b3.create(loc, chikj, chikjc); + Value tmp3 = + b3.create(loc, chikj, chikjc); + CCw(b3, loc, cc, i, j2p1, k, ido, cdim, tmp2); + CCw(b3, loc, cc, ic, j2, k, ido, cdim, tmp3); + + Value chip1kj = + CH_radfg(b3, loc, ch, ip1, k, j, ido, l1); + Value chip1kjc = + CH_radfg(b3, loc, ch, ip1, k, jc, ido, l1); + Value tmp4 = + b3.create(loc, chip1kj, chip1kjc); + Value tmp5 = + b3.create(loc, chip1kjc, chip1kj); + CCw(b3, loc, cc, ip1, j2p1, k, ido, cdim, tmp4); + CCw(b3, loc, cc, icp1, j2, k, ido, cdim, tmp5); + + Value i_next = b3.create(loc, i, c2); + Value ic_next = b3.create(loc, ic, c2); + b3.create( + loc, std::vector{i_next, ic_next}); + }); + b2.create(loc, std::nullopt); + }); + + Value j_next = b.create(loc, j, c1); + Value jc_next = b.create(loc, jc, c1); + b.create(loc, std::vector{j_next, jc_next}); + }); + builder.create(loc, std::nullopt); + }); + + return; +} + +// Handle general radix FFT computation. +void radfg(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value csarr, Value ido, Value ip, Value l1) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + + Value ipm1 = opBuilder.create(loc, ip, c1); + Value ipm2 = opBuilder.create(loc, ip, c2); + + Value cdim = opBuilder.create(loc, ip, c0); + Value tmp = opBuilder.create(loc, ip, c1); + Value ipph = opBuilder.create(loc, tmp, c2); + + Value idl1 = opBuilder.create(loc, ido, l1); + Value idom1 = opBuilder.create(loc, ido, c1); + Value idom2 = opBuilder.create(loc, ido, c2); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, l1); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value jc_start = builder.create(loc, ip, c1); + + builder.create( + loc, c1, ipph, c1, ValueRange{jc_start}, + [&](OpBuilder &b, Location loc, Value j, ValueRange j_args) { + Value jc = j_args[0]; + + Value jm1 = b.create(loc, j, c1); + Value jcm1 = b.create(loc, jc, c1); + + Value is = b.create(loc, jm1, idom1); + Value is2 = b.create(loc, jcm1, idom1); + + b.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value k, ValueRange k_args) { + Value idij_start = b2.create(loc, is, c0); + Value idij2_start = b2.create(loc, is2, c0); + + b2.create( + loc, c1, idom1, c2, ValueRange{idij_start, idij2_start}, + [&](OpBuilder &b3, Location loc, Value i, + ValueRange i_args) { + Value idij = i_args[0]; + Value idij2 = i_args[1]; + + Value ip1 = b3.create(loc, i, c1); + Value idijp1 = + b3.create(loc, idij, c1); + Value idij2p1 = + b3.create(loc, idij2, c1); + + Value t1 = C1(b3, loc, cc, i, k, j, ido, l1); + Value t2 = C1(b3, loc, cc, ip1, k, j, ido, l1); + Value t3 = C1(b3, loc, cc, i, k, jc, ido, l1); + Value t4 = C1(b3, loc, cc, ip1, k, jc, ido, l1); + + Value waidij = + b3.create(loc, wa, idij); + Value waidijp1 = + b3.create(loc, wa, idijp1); + Value waidij2 = + b3.create(loc, wa, idij2); + Value waidij2p1 = + b3.create(loc, wa, idij2p1); + + Value tmp1_x1 = + b3.create(loc, waidij, t1); + Value tmp2_x1 = + b3.create(loc, waidijp1, t2); + Value x1 = + b3.create(loc, tmp1_x1, tmp2_x1); + + Value tmp1_x2 = + b3.create(loc, waidij, t2); + Value tmp2_x2 = + b3.create(loc, waidijp1, t1); + Value x2 = + b3.create(loc, tmp1_x2, tmp2_x2); + + Value tmp1_x3 = + b3.create(loc, waidij2, t3); + Value tmp2_x3 = + b3.create(loc, waidij2p1, t4); + Value x3 = + b3.create(loc, tmp1_x3, tmp2_x3); + + Value tmp1_x4 = + b3.create(loc, waidij2, t4); + Value tmp2_x4 = + b3.create(loc, waidij2p1, t3); + Value x4 = + b3.create(loc, tmp1_x4, tmp2_x4); + + Value tmp3 = b3.create(loc, x1, x3); + Value tmp4 = b3.create(loc, x2, x4); + Value tmp5 = b3.create(loc, x2, x4); + Value tmp6 = b3.create(loc, x3, x1); + + C1w(b3, loc, cc, i, k, j, ido, l1, tmp3); + C1w(b3, loc, cc, i, k, jc, ido, l1, tmp4); + C1w(b3, loc, cc, ip1, k, j, ido, l1, tmp5); + C1w(b3, loc, cc, ip1, k, jc, ido, l1, tmp6); + + Value idij_next = + b3.create(loc, idij, c2); + Value idij2_next = + b3.create(loc, idij2, c2); + + b3.create( + loc, std::vector{idij_next, idij2_next}); + }); + b2.create(loc, std::nullopt); + } + + ); + + Value jc_next = b.create(loc, jc, c1); + b.create(loc, jc_next); + }); + + builder.create(loc, std::nullopt); + }); + + Value jc_a_start = opBuilder.create(loc, ip, c1); + + opBuilder.create( + loc, c1, ipph, c1, ValueRange{jc_a_start}, + [&](OpBuilder &builder, Location loc, Value j_a, ValueRange j_a_args) { + Value jc_a = j_a_args[0]; + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value k_a, ValueRange k_a_args) { + Value t1_a = C1(b, loc, cc, c0, k_a, j_a, ido, l1); + Value t2_a = C1(b, loc, cc, c0, k_a, jc_a, ido, l1); + + Value tmp_a = b.create(loc, t1_a, t2_a); + Value tmp1_a = b.create(loc, t2_a, t1_a); + + C1w(b, loc, cc, c0, k_a, j_a, ido, l1, tmp_a); + C1w(b, loc, cc, c0, k_a, jc_a, ido, l1, tmp1_a); + b.create(loc, std::nullopt); + }); + + Value jc_a_next = builder.create(loc, jc_a, c1); + builder.create(loc, jc_a_next); + }); + + Value lc_b_start = opBuilder.create(loc, ip, c1); + + opBuilder.create( + loc, c1, ipph, c1, ValueRange{lc_b_start}, + [&](OpBuilder &builder, Location loc, Value l_b, ValueRange l_b_args) { + Value lc_b = l_b_args[0]; + + builder.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value ik_b, ValueRange ik_b_args) { + Value m2l = b.create(loc, l_b, c2); + Value m4l = b.create(loc, l_b, c4); + Value m2lp1 = b.create(loc, m2l, c1); + Value m4lp1 = b.create(loc, m4l, c1); + + Value csarr2l = CSARR(b, loc, csarr, m2l); + Value csarr4l = CSARR(b, loc, csarr, m4l); + Value csarr2lp1 = CSARR(b, loc, csarr, m2lp1); + Value csarr4lp1 = CSARR(b, loc, csarr, m4lp1); + + Value c2ik0 = C2(b, loc, cc, ik_b, c0, idl1); + Value c2ik1 = C2(b, loc, cc, ik_b, c1, idl1); + Value c2ik2 = C2(b, loc, cc, ik_b, c2, idl1); + + Value c2ikipm1 = C2(b, loc, cc, ik_b, ipm1, idl1); + Value c2ikipm2 = C2(b, loc, cc, ik_b, ipm2, idl1); + + Value tmp_b = b.create(loc, csarr2l, c2ik1); + Value tmp1_b = b.create(loc, csarr4l, c2ik2); + Value tmp2_b = b.create(loc, tmp_b, tmp1_b); + Value tmp3_b = b.create(loc, c2ik0, tmp2_b); + + CH2w(b, loc, ch, ik_b, l_b, idl1, tmp3_b); + + Value tmp4_b = b.create(loc, csarr2lp1, c2ikipm1); + Value tmp5_b = b.create(loc, csarr4lp1, c2ikipm2); + Value tmp6_b = b.create(loc, tmp4_b, tmp5_b); + + CH2w(b, loc, ch, ik_b, lc_b, idl1, tmp6_b); + b.create(loc, std::nullopt); + }); + + Value iang_start_c = builder.create(loc, c2, l_b); + Value j_start_c = builder.create(loc, 3); + Value jc_start_c = builder.create(loc, ip, c3); + Value ipphm1 = builder.create(loc, ipph, c1); + Value ipphm3 = builder.create(loc, ipph, c3); + + auto loop1 = builder.create( + loc, j_start_c, ipphm3, c4, + ValueRange{j_start_c, jc_start_c, iang_start_c}, + [&](OpBuilder &b, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + Value iang = j_loop_args[2]; + + Value iang_1_c = IANG(b, loc, iang, l_b, ip); + Value ar1 = AR(b, loc, csarr, iang_1_c); + Value ai1 = AI(b, loc, csarr, iang_1_c); + + Value iang_2_c = IANG(b, loc, iang_1_c, l_b, ip); + Value ar2 = AR(b, loc, csarr, iang_2_c); + Value ai2 = AI(b, loc, csarr, iang_2_c); + + Value iang_3_c = IANG(b, loc, iang_2_c, l_b, ip); + Value ar3 = AR(b, loc, csarr, iang_3_c); + Value ai3 = AI(b, loc, csarr, iang_3_c); + + Value iang_4_c = IANG(b, loc, iang_3_c, l_b, ip); + Value ar4 = AR(b, loc, csarr, iang_4_c); + Value ai4 = AI(b, loc, csarr, iang_4_c); + + b.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value ik_c, + ValueRange ik_c_args) { + Value jp1 = b2.create(loc, j, c1); + Value jp2 = b2.create(loc, j, c2); + Value jp3 = b2.create(loc, j, c3); + Value jm1 = b2.create(loc, j, c1); + Value jm2 = b2.create(loc, j, c2); + Value jm3 = b2.create(loc, j, c3); + + Value c2ikj = C2(b2, loc, cc, ik_c, j, idl1); + Value c2ikjp1 = C2(b2, loc, cc, ik_c, jp1, idl1); + Value c2ikjp2 = C2(b2, loc, cc, ik_c, jp2, idl1); + Value c2ikjp3 = C2(b2, loc, cc, ik_c, jp3, idl1); + + Value tmp_c = b2.create(loc, ar1, c2ikj); + Value tmp1_c = b2.create(loc, ar2, c2ikjp1); + Value tmp2_c = b2.create(loc, ar3, c2ikjp2); + Value tmp3_c = b2.create(loc, ar4, c2ikjp3); + + Value tmp4_c = b2.create(loc, tmp_c, tmp1_c); + Value tmp5_c = + b2.create(loc, tmp4_c, tmp2_c); + Value tmp6_c = + b2.create(loc, tmp5_c, tmp3_c); + + Value ch2ikl = CH2(b2, loc, ch, ik_c, l_b, idl1); + Value tmp7_c = + b2.create(loc, tmp6_c, ch2ikl); + CH2w(b2, loc, ch, ik_c, l_b, idl1, tmp7_c); + + Value jcm1 = b2.create(loc, jc, c1); + Value jcm2 = b2.create(loc, jc, c2); + Value jcm3 = b2.create(loc, jc, c3); + + Value c2ikjc = C2(b2, loc, cc, ik_c, jc, idl1); + Value c2ikjcm1 = C2(b2, loc, cc, ik_c, jcm1, idl1); + Value c2ikjcm2 = C2(b2, loc, cc, ik_c, jcm2, idl1); + Value c2ikjcm3 = C2(b2, loc, cc, ik_c, jcm3, idl1); + + Value tmp_ai1 = b2.create(loc, ai1, c2ikjc); + Value tmp_ai2 = + b2.create(loc, ai2, c2ikjcm1); + Value tmp_ai3 = + b2.create(loc, ai3, c2ikjcm2); + Value tmp_ai4 = + b2.create(loc, ai4, c2ikjcm3); + + Value tmp_ai5 = + b2.create(loc, tmp_ai1, tmp_ai2); + Value tmp_ai6 = + b2.create(loc, tmp_ai5, tmp_ai3); + Value tmp_ai7 = + b2.create(loc, tmp_ai6, tmp_ai4); + + Value ch2iklc = CH2(b2, loc, ch, ik_c, lc_b, idl1); + Value tmp_ai8 = + b2.create(loc, tmp_ai7, ch2iklc); + CH2w(b2, loc, ch, ik_c, lc_b, idl1, tmp_ai8); + + b2.create(loc, std::nullopt); + }); + + Value j_next = b.create(loc, j, c4); + Value jc_next = b.create(loc, jc, c4); + builder.create( + loc, std::vector{j_next, jc_next, iang_4_c}); + }); + + Value j_1_c = loop1.getResults()[0]; + Value jc_1_c = loop1.getResults()[1]; + Value iang1_c = loop1.getResults()[2]; + + auto loop2 = builder.create( + loc, j_1_c, ipphm1, c2, ValueRange{j_1_c, jc_1_c, iang1_c}, + [&](OpBuilder &b, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + Value iang = j_loop_args[2]; + + Value iang_1_d = IANG(b, loc, iang, l_b, ip); + Value ar1 = AR(b, loc, csarr, iang_1_d); + Value ai1 = AI(b, loc, csarr, iang_1_d); + + Value iang_2_d = IANG(b, loc, iang_1_d, l_b, ip); + Value ar2 = AR(b, loc, csarr, iang_2_d); + Value ai2 = AI(b, loc, csarr, iang_2_d); + + b.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value ik_d, + ValueRange ik_d_args) { + Value jp1 = b2.create(loc, j, c1); + Value jm1 = b2.create(loc, j, c1); + + Value c2ikj = C2(b2, loc, cc, ik_d, j, idl1); + Value c2ikjp1 = C2(b2, loc, cc, ik_d, jp1, idl1); + + Value tmp_c = b2.create(loc, ar1, c2ikj); + Value tmp1_c = b2.create(loc, ar2, c2ikjp1); + Value tmp2_c = b2.create(loc, tmp_c, tmp1_c); + + Value ch2ikl = CH2(b2, loc, ch, ik_d, l_b, idl1); + Value tmp3_c = + b2.create(loc, tmp2_c, ch2ikl); + CH2w(b2, loc, ch, ik_d, l_b, idl1, tmp3_c); + + Value jcm1 = b2.create(loc, jc, c1); + Value c2ikjc = C2(b2, loc, cc, ik_d, jc, idl1); + Value c2ikjcm1 = C2(b2, loc, cc, ik_d, jcm1, idl1); + + Value tmp_ai1 = b2.create(loc, ai1, c2ikjc); + Value tmp_ai2 = + b2.create(loc, ai2, c2ikjcm1); + Value tmp_ai3 = + b2.create(loc, tmp_ai1, tmp_ai2); + + Value ch2iklc = CH2(b2, loc, ch, ik_d, lc_b, idl1); + Value tmp_ai4 = + b2.create(loc, tmp_ai3, ch2iklc); + CH2w(b2, loc, ch, ik_d, lc_b, idl1, tmp_ai4); + + b2.create(loc, std::nullopt); + }); + + Value j_next = b.create(loc, j, c2); + Value jc_next = b.create(loc, jc, c2); + builder.create( + loc, std::vector{j_next, jc_next, iang_2_d}); + }); + + Value j_2_c = loop2.getResults()[0]; + Value jc_2_c = loop2.getResults()[1]; + Value iang2_c = loop2.getResults()[2]; + + auto loop3 = builder.create( + loc, j_2_c, ipph, c1, ValueRange{j_2_c, jc_2_c, iang2_c}, + [&](OpBuilder &b, Location loc, Value j_loop, + ValueRange j_loop_args) { + Value j = j_loop_args[0]; + Value jc = j_loop_args[1]; + Value iang = j_loop_args[2]; + + Value iang_1_e = IANG(b, loc, iang, l_b, ip); + Value ar = AR(b, loc, csarr, iang_1_e); + Value ai = AI(b, loc, csarr, iang_1_e); + + b.create( + loc, c0, idl1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value ik_e, + ValueRange ik_e_args) { + Value c2ikj = C2(b2, loc, cc, ik_e, j, idl1); + Value tmp_c = b2.create(loc, ar, c2ikj); + Value ch2ikl = CH2(b2, loc, ch, ik_e, l_b, idl1); + Value tmp2_c = b2.create(loc, tmp_c, ch2ikl); + CH2w(b2, loc, ch, ik_e, l_b, idl1, tmp2_c); + + Value c2ikjc = C2(b2, loc, cc, ik_e, jc, idl1); + Value tmp_ai = b2.create(loc, ai, c2ikjc); + Value ch2iklc = CH2(b2, loc, ch, ik_e, lc_b, idl1); + Value tmp2_ai = + b2.create(loc, tmp_ai, ch2iklc); + CH2w(b2, loc, ch, ik_e, lc_b, idl1, tmp2_ai); + + b2.create(loc, std::nullopt); + }); + + Value j_next = b.create(loc, j, c2); + Value jc_next = b.create(loc, jc, c2); + builder.create( + loc, std::vector{j_next, jc_next, iang_1_e}); + }); + + Value lc_b_next = builder.create(loc, lc_b, c1); + builder.create(loc, lc_b_next); + }); + + radfgExtend(opBuilder, loc, cc, ch, wa, csarr, ido, ip, l1); +} + +void radf2Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim) { + FloatType f64Ty = opBuilder.getF64Type(); + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c20 = opBuilder.create(loc, 20); + + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { builder.create( loc, c2, ido, c2, std::nullopt, - [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { Value ic = b.create(loc, ido, i); Value icm1 = b.create(loc, ic, c1); Value im1 = b.create(loc, i, c1); @@ -817,303 +1520,1946 @@ void radf4Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); - std::vector cr2_ci2 = + std::vector tr2_ti2 = + MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); + + Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); + std::vector ccim1k0_tr2 = PM(b, loc, ccim1k0, tr2_ti2[0]); + std::vector ti2_ccik0 = PM(b, loc, tr2_ti2[1], ccik0); + + CH(b, loc, ch, im1, c0, k, ido, cdim, ccim1k0_tr2[0]); + CH(b, loc, ch, icm1, c1, k, ido, cdim, ccim1k0_tr2[1]); + + CH(b, loc, ch, i, c0, k, ido, cdim, ti2_ccik0[0]); + CH(b, loc, ch, ic, c1, k, ido, cdim, ti2_ccik0[1]); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); +} + +// Handle radix-2 FFT computation +void radf2(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1) { + + FloatType f64Ty = opBuilder.getF64Type(); + Value cdim = opBuilder.create(loc, 2); + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c20 = opBuilder.create(loc, 20); + + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iv_args) { + Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1); + Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1); + std::vector cc0k0_cc0k1 = PM(builder, loc, cc0k0, cc0k1); + CH(builder, loc, ch, c0, c0, iv, ido, cdim, cc0k0_cc0k1[0]); + CH(builder, loc, ch, idom1, c1, iv, ido, cdim, cc0k0_cc0k1[1]); + builder.create(loc, std::nullopt); + }); + + Value flag = opBuilder.create(loc, ido, c2); + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::eq, flag, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value k, ValueRange k_args) { + Value ccidom1k1 = CC(b, loc, cc, idom1, k, c1, ido, l1); + Value tmp = b.create(loc, ccidom1k1); + CH(b, loc, ch, c0, c1, k, ido, cdim, tmp); + Value ccidom1k0 = CC(b, loc, cc, idom1, k, c0, ido, l1); + CH(b, loc, ch, idom1, c0, k, ido, cdim, ccidom1k0); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); + + Value condition1 = + opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, c2); + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + radf2Extend(builder, loc, cc, ch, wa, ido, l1, cdim); + builder.create(loc, std::nullopt); + }); +} + +void radf3Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim) { + + FloatType f64Ty = opBuilder.getF64Type(); + Value taur = + opBuilder.create(loc, APFloat(double(-0.5)), f64Ty); + Value taui = opBuilder.create( + loc, APFloat(double(0.86602540378443864676)), f64Ty); + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { + builder.create( + loc, c2, ido, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value ic = b.create(loc, ido, i); + Value icm1 = b.create(loc, ic, c1); + Value im1 = b.create(loc, i, c1); + Value im2 = b.create(loc, i, c2); + + Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1); + Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); + Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); + Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); + std::vector dr2_di2 = MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1); Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1); Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1); Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1); - std::vector cr3_ci3 = + std::vector dr3_di3 = MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2); - Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1); - Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1); - Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1); - Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1); - std::vector cr4_ci4 = - MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3); + Value cr2 = b.create(loc, dr2_di2[0], dr3_di3[0]); + Value ci2 = b.create(loc, dr2_di2[1], dr3_di3[1]); - std::vector tr1_tr4 = PM(b, loc, cr4_ci4[0], cr2_ci2[0]); - std::vector ti1_ti4 = PM(b, loc, cr2_ci2[1], cr4_ci4[1]); Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); - std::vector tr2_tr3 = PM(b, loc, ccim1k0, cr3_ci3[0]); + Value tmp5 = b.create(loc, ccim1k0, cr2); + CH(builder, loc, ch, im1, c0, k, ido, cdim, tmp5); + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); - std::vector ti2_ti3 = PM(b, loc, ccik0, cr3_ci3[1]); + Value tmp6 = b.create(loc, ccik0, ci2); + CH(builder, loc, ch, i, c0, k, ido, cdim, tmp6); - std::vector chtmp0 = PM(b, loc, tr2_tr3[0], tr1_tr4[0]); - CH(b, loc, ch, im1, c0, k, ido, cdim, chtmp0[0]); - CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp0[1]); + Value tmp7 = b.create(loc, taur, cr2); + Value tr2 = b.create(loc, ccim1k0, tmp7); - std::vector chtmp1 = PM(b, loc, ti1_ti4[0], ti2_ti3[0]); - CH(b, loc, ch, i, c0, k, ido, cdim, chtmp1[0]); - CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp1[1]); + Value tmp8 = b.create(loc, taur, ci2); + Value ti2 = b.create(loc, ccik0, tmp8); - std::vector chtmp2 = PM(b, loc, tr2_tr3[1], ti1_ti4[1]); - CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp2[0]); - CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp2[1]); + Value tmp9 = b.create(loc, dr2_di2[1], dr3_di3[1]); + Value tr3 = b.create(loc, taui, tmp9); - std::vector chtmp3 = PM(b, loc, tr1_tr4[1], ti2_ti3[1]); - CH(b, loc, ch, i, c2, k, ido, cdim, chtmp3[0]); - CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp3[1]); + Value tmp10 = + b.create(loc, dr3_di3[0], dr2_di2[0]); + Value ti3 = b.create(loc, taui, tmp10); + + std::vector tr2_tr3 = PM(b, loc, tr2, tr3); + std::vector ti3_ti2 = PM(b, loc, ti3, ti2); + CH(builder, loc, ch, im1, c2, k, ido, cdim, tr2_tr3[0]); + CH(builder, loc, ch, icm1, c1, k, ido, cdim, tr2_tr3[1]); + + CH(builder, loc, ch, i, c2, k, ido, cdim, ti3_ti2[0]); + CH(builder, loc, ch, ic, c1, k, ido, cdim, ti3_ti2[1]); + + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); +} + +// Handle radix-3 FFT computation +void radf3(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1) { + + FloatType f64Ty = opBuilder.getF64Type(); + Value cdim = opBuilder.create(loc, 3); + Value taur = + opBuilder.create(loc, APFloat(double(-0.5)), f64Ty); + Value taui = opBuilder.create( + loc, APFloat(double(0.86602540378443864676)), f64Ty); + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iv_args) { + Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1); + Value cc0k2 = CC(builder, loc, cc, c0, iv, c2, ido, l1); + Value cr2 = builder.create(loc, cc0k1, cc0k2); + + Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1); + Value tmp0 = builder.create(loc, cc0k0, cr2); + CH(builder, loc, ch, c0, c0, iv, ido, cdim, tmp0); + + Value tmp1 = builder.create(loc, cc0k2, cc0k1); + Value tmp2 = builder.create(loc, tmp1, taui); + CH(builder, loc, ch, c0, c2, iv, ido, cdim, tmp2); + + Value tmp3 = builder.create(loc, taur, cr2); + Value tmp4 = builder.create(loc, tmp3, cc0k0); + CH(builder, loc, ch, idom1, c1, iv, ido, cdim, tmp4); + + builder.create(loc, std::nullopt); + }); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, ido, c1); + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + radf3Extend(builder, loc, cc, ch, wa, ido, l1, cdim); + builder.create(loc, std::nullopt); + }); +} + +void radf4Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim, Value c0, Value c1, + Value c2, Value c3) { + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) { + builder.create( + loc, c2, ido, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + Value ic = b.create(loc, ido, i); + Value icm1 = b.create(loc, ic, c1); + Value im1 = b.create(loc, i, c1); + Value im2 = b.create(loc, i, c2); + + Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1); + Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); + Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); + Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); + std::vector cr2_ci2 = + MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); + + Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1); + Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1); + Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1); + Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1); + std::vector cr3_ci3 = + MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2); + + Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1); + Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1); + Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1); + Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1); + std::vector cr4_ci4 = + MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3); + + std::vector tr1_tr4 = PM(b, loc, cr4_ci4[0], cr2_ci2[0]); + std::vector ti1_ti4 = PM(b, loc, cr2_ci2[1], cr4_ci4[1]); + Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); + std::vector tr2_tr3 = PM(b, loc, ccim1k0, cr3_ci3[0]); + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); + std::vector ti2_ti3 = PM(b, loc, ccik0, cr3_ci3[1]); + + std::vector chtmp0 = PM(b, loc, tr2_tr3[0], tr1_tr4[0]); + CH(b, loc, ch, im1, c0, k, ido, cdim, chtmp0[0]); + CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp0[1]); + + std::vector chtmp1 = PM(b, loc, ti1_ti4[0], ti2_ti3[0]); + CH(b, loc, ch, i, c0, k, ido, cdim, chtmp1[0]); + CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp1[1]); + + std::vector chtmp2 = PM(b, loc, tr2_tr3[1], ti1_ti4[1]); + CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp2[0]); + CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp2[1]); + + std::vector chtmp3 = PM(b, loc, tr1_tr4[1], ti2_ti3[1]); + CH(b, loc, ch, i, c2, k, ido, cdim, chtmp3[0]); + CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp3[1]); + + b.create(loc, std::nullopt); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +// Handle radix-4 FFT computation +void radf4(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1, Value c0, Value c1, Value c2, Value c3) { + FloatType f64Ty = opBuilder.getF64Type(); + Value cdim = opBuilder.create(loc, 4); + Value hsqt2 = opBuilder.create( + loc, APFloat(double(0.70710678118654752440)), f64Ty); + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iargs) { + Value cc0k3 = CC(builder, loc, cc, c0, iv, c3, ido, l1); + Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1); + std::vector tr1_tmp0 = PM(builder, loc, cc0k3, cc0k1); + CH(builder, loc, ch, c0, c2, iv, ido, cdim, tr1_tmp0[1]); + + Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1); + Value cc0k2 = CC(builder, loc, cc, c0, iv, c2, ido, l1); + std::vector tr2_tmp1 = PM(builder, loc, cc0k0, cc0k2); + CH(builder, loc, ch, idom1, c1, iv, ido, cdim, tr2_tmp1[1]); + + std::vector tmp2_tmp3 = + PM(builder, loc, tr2_tmp1[0], tr1_tmp0[0]); + CH(builder, loc, ch, c0, c0, iv, ido, cdim, tmp2_tmp3[0]); + CH(builder, loc, ch, idom1, c3, iv, ido, cdim, tmp2_tmp3[1]); + + builder.create(loc, std::nullopt); + }); + + Value reminder = opBuilder.create(loc, ido, c2); + Value condition0 = opBuilder.create( + loc, arith::CmpIPredicate::eq, reminder, c0); + opBuilder.create( + loc, condition0, [&](OpBuilder &builder, Location loc) { + Value negHsqt2 = builder.create( + loc, APFloat(double(-0.70710678118654752440)), f64Ty); + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) { + Value ccidom1k1 = CC(b, loc, cc, idom1, iv, c1, ido, l1); + Value ccidom1k3 = CC(b, loc, cc, idom1, iv, c3, ido, l1); + Value tmp0 = b.create(loc, ccidom1k1, ccidom1k3); + Value ti1 = b.create(loc, negHsqt2, tmp0); + + Value tmp1 = b.create(loc, ccidom1k1, ccidom1k3); + Value tr1 = b.create(loc, hsqt2, tmp1); + + Value ccidom1k0 = CC(b, loc, cc, idom1, iv, c0, ido, l1); + std::vector tmp2_tmp3 = PM(b, loc, ccidom1k0, tr1); + CH(b, loc, ch, idom1, c0, iv, ido, cdim, tmp2_tmp3[0]); + CH(b, loc, ch, idom1, c2, iv, ido, cdim, tmp2_tmp3[1]); + + Value ccidom1k2 = CC(b, loc, cc, idom1, iv, c2, ido, l1); + std::vector tmp4_tmp5 = PM(b, loc, ti1, ccidom1k2); + CH(b, loc, ch, c0, c3, iv, ido, cdim, tmp4_tmp5[0]); + CH(b, loc, ch, c0, c1, iv, ido, cdim, tmp4_tmp5[1]); + + b.create(loc, std::nullopt); + }); + + builder.create(loc, std::nullopt); + }); + + Value condition1 = + opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, c2); + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + radf4Extend(builder, loc, cc, ch, wa, ido, l1, cdim, c0, c1, c2, c3); + builder.create(loc, std::nullopt); + }); + + return; +} + +void radf5Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim, Value tr11, + Value tr12, Value ti11, Value ti12, Value c0, Value c1, + Value c2, Value c3, Value c4) { + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) { + builder.create( + loc, c2, ido, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + Value ic = b.create(loc, ido, i); + Value icm1 = b.create(loc, ic, c1); + Value im1 = b.create(loc, i, c1); + Value im2 = b.create(loc, i, c2); + + Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1); + Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); + Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); + Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); + std::vector dr2_di2 = + MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); + + Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1); + Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1); + Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1); + Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1); + std::vector dr3_di3 = + MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2); + + Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1); + Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1); + Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1); + Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1); + std::vector dr4_di4 = + MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3); + + Value wa3im2 = WA(b, loc, wa, c3, im2, ido, c1); + Value wa3im1 = WA(b, loc, wa, c3, im1, ido, c1); + Value ccim1k4 = CC(b, loc, cc, im1, k, c4, ido, l1); + Value ccik4 = CC(b, loc, cc, i, k, c4, ido, l1); + std::vector dr5_di5 = + MULPM(b, loc, wa3im2, wa3im1, ccim1k4, ccik4); + + std::vector cr2_ci5 = PM(b, loc, dr5_di5[0], dr2_di2[0]); + std::vector ci2_cr5 = PM(b, loc, dr2_di2[1], dr5_di5[1]); + std::vector cr3_ci4 = PM(b, loc, dr4_di4[0], dr3_di3[0]); + std::vector ci3_cr4 = PM(b, loc, dr3_di3[1], dr4_di4[1]); + + Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); + Value tmpch0 = b.create(loc, ccim1k0, cr2_ci5[0]); + Value chim10k = b.create(loc, tmpch0, cr3_ci4[0]); + CH(b, loc, ch, im1, c0, k, ido, cdim, chim10k); + + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); + Value tmpch1 = b.create(loc, ccik0, ci2_cr5[0]); + Value chi0k = b.create(loc, tmpch1, ci3_cr4[0]); + CH(b, loc, ch, i, c0, k, ido, cdim, chi0k); + + Value tmp0 = b.create(loc, tr11, cr2_ci5[0]); + Value tmp1 = b.create(loc, ccim1k0, tmp0); + Value tmp2 = b.create(loc, tr12, cr3_ci4[0]); + Value tr2 = b.create(loc, tmp1, tmp2); + + Value tmp3 = b.create(loc, tr11, ci2_cr5[0]); + Value tmp4 = b.create(loc, ccik0, tmp3); + Value tmp5 = b.create(loc, tr12, ci3_cr4[0]); + Value ti2 = b.create(loc, tmp4, tmp5); + + Value tmp6 = b.create(loc, tr12, cr2_ci5[0]); + Value tmp7 = b.create(loc, ccim1k0, tmp6); + Value tmp8 = b.create(loc, tr11, cr3_ci4[0]); + Value tr3 = b.create(loc, tmp7, tmp8); + + Value tmp9 = b.create(loc, tr12, ci2_cr5[0]); + Value tmp10 = b.create(loc, ccik0, tmp9); + Value tmp11 = b.create(loc, tr11, ci3_cr4[0]); + Value ti3 = b.create(loc, tmp10, tmp11); + + std::vector tr5_tr4 = + MULPM(b, loc, ci2_cr5[1], ci3_cr4[1], ti11, ti12); + std::vector ti5_ti4 = + MULPM(b, loc, cr2_ci5[1], cr3_ci4[1], ti11, ti12); + + std::vector chtmp0 = PM(b, loc, tr2, tr5_tr4[0]); + CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp0[0]); + CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp0[1]); + + std::vector chtmp1 = PM(b, loc, ti5_ti4[0], ti2); + CH(b, loc, ch, i, c2, k, ido, cdim, chtmp1[0]); + CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp1[1]); + + std::vector chtmp2 = PM(b, loc, tr3, tr5_tr4[1]); + CH(b, loc, ch, im1, c4, k, ido, cdim, chtmp2[0]); + CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp2[1]); + + std::vector chtmp3 = PM(b, loc, ti5_ti4[1], ti3); + CH(b, loc, ch, i, c4, k, ido, cdim, chtmp3[0]); + CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp3[1]); + + b.create(loc, std::nullopt); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +// Handle radix-5 FFT computation +void radf5(OpBuilder &builder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1, Value c0, Value c1, Value c2, Value c3, + Value c4) { + + FloatType f64Ty = builder.getF64Type(); + Value cdim = builder.create(loc, 5); + Value tr11 = builder.create( + loc, APFloat(double(0.3090169943749474241)), f64Ty); + Value tr12 = builder.create( + loc, APFloat(double(-0.8090169943749474241)), f64Ty); + Value ti11 = builder.create( + loc, APFloat(double(0.95105651629515357212)), f64Ty); + Value ti12 = builder.create( + loc, APFloat(double(0.58778525229247312917)), f64Ty); + Value idom1 = builder.create(loc, ido, c1); + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) { + Value cc0k4 = CC(b, loc, cc, c0, iv, c4, ido, l1); + Value cc0k1 = CC(b, loc, cc, c0, iv, c1, ido, l1); + std::vector cr2_ci5 = PM(b, loc, cc0k4, cc0k1); + + Value cc0k3 = CC(b, loc, cc, c0, iv, c3, ido, l1); + Value cc0k2 = CC(b, loc, cc, c0, iv, c2, ido, l1); + std::vector cr3_ci4 = PM(b, loc, cc0k3, cc0k2); + + Value cc0k0 = CC(b, loc, cc, c0, iv, c0, ido, l1); + Value tmpch0 = b.create(loc, cc0k0, cr2_ci5[0]); + Value ch0 = b.create(loc, tmpch0, cr3_ci4[0]); + CH(b, loc, ch, c0, c0, iv, ido, cdim, ch0); + + Value tmpch1 = b.create(loc, tr11, cr2_ci5[0]); + Value tmpch2 = b.create(loc, tr12, cr3_ci4[0]); + Value tmpch3 = b.create(loc, cc0k0, tmpch1); + Value ch1 = b.create(loc, tmpch2, tmpch3); + CH(b, loc, ch, idom1, c1, iv, ido, cdim, ch1); + + Value tmpch4 = b.create(loc, ti11, cr2_ci5[1]); + Value tmpch5 = b.create(loc, ti12, cr3_ci4[1]); + Value ch2 = b.create(loc, tmpch4, tmpch5); + CH(b, loc, ch, c0, c2, iv, ido, cdim, ch2); + + Value tmpch6 = b.create(loc, tr12, cr2_ci5[0]); + Value tmpch7 = b.create(loc, tr11, cr3_ci4[0]); + Value tmpch8 = b.create(loc, tmpch6, tmpch7); + Value ch3 = b.create(loc, cc0k0, tmpch8); + CH(b, loc, ch, idom1, c3, iv, ido, cdim, ch3); + + Value tmpch9 = b.create(loc, ti12, cr2_ci5[1]); + Value tmpch10 = b.create(loc, ti11, cr3_ci4[1]); + Value ch4 = b.create(loc, tmpch9, tmpch10); + CH(b, loc, ch, c0, c4, iv, ido, cdim, ch4); + + b.create(loc, std::nullopt); + }); + + Value condition = + builder.create(loc, arith::CmpIPredicate::ne, ido, c1); + builder.create(loc, condition, [&](OpBuilder &b, Location loc) { + radf5Extend(b, loc, cc, ch, wa, ido, l1, cdim, tr11, tr12, ti11, ti12, c0, + c1, c2, c3, c4); + b.create(loc, std::nullopt); + }); + + return; +} + +// function to implement ++ operation +void index_increment(OpBuilder &opBuilder, Location loc, Value target) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value a = opBuilder.create(loc, target, c0); + Value b = opBuilder.create(loc, a, c1); + opBuilder.create(loc, b, target, c0); +} + +// switch 2 element in an array +void index_SWAP(OpBuilder &opBuilder, Location loc, Value array, Value target1, + Value target2) { + Value a = opBuilder.create(loc, array, target1); + Value b = opBuilder.create(loc, array, target2); + + opBuilder.create(loc, a, array, target2); + opBuilder.create(loc, b, array, target1); +} + +// factorize the input length ans store factors in Rfftp_fctdata_fct +Value rfftp_factorize(OpBuilder &opBuilder, Location loc, + Value Rfftp_fctdata_fct, Value Rfftp_fctdata_tw, + Value Rfftp_fctdata_tws, Value Rfftp_plan_length, + Value Rfftp_plan_nfct, Value Rfftp_plan_mem) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c_neg1 = opBuilder.create(loc, -1); + Value NFCT = opBuilder.create(loc, 25); + + FloatType f64Ty = opBuilder.getF64Type(); + IndexType indexTy = opBuilder.getIndexType(); + + Value length = + opBuilder.create(loc, MemRefType::get(1, indexTy)); + Value length_1 = opBuilder.create(loc, Rfftp_plan_length, c0); + opBuilder.create(loc, length_1, length, c0); + + Value nfct = + opBuilder.create(loc, MemRefType::get(1, indexTy)); + + opBuilder.create(loc, c0, nfct, c0); + + auto loop = opBuilder.create( + loc, TypeRange{indexTy}, ValueRange{length_1}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value length_while = args[0]; + + Value length_mod_4 = + builder.create(loc, length_while, c4); + Value condition = builder.create( + loc, arith::CmpIPredicate::eq, length_mod_4, c0); + builder.create(loc, condition, + ValueRange{length_while}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value length_while = args[0]; + + Value currnet_nfct = builder.create(loc, nfct, c0); + builder.create(loc, c4, Rfftp_fctdata_fct, + currnet_nfct); + index_increment(builder, loc, nfct); + Value length_next = + builder.create(loc, length_while, c2); + builder.create(loc, length_next, length, c0); + + builder.create(loc, std::vector{length_next}); + }); + + Value length_if = opBuilder.create(loc, length, c0); + Value length_if_mod_2 = opBuilder.create(loc, length_if, c2); + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, length_if_mod_2, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value length_next = builder.create(loc, length_if, c1); + builder.create(loc, length_next, length, c0); + + Value currnet_nfct = builder.create(loc, nfct, c0); + builder.create(loc, c2, Rfftp_fctdata_fct, + currnet_nfct); + index_increment(builder, loc, nfct); + + Value currnet_nfct_1 = builder.create(loc, nfct, c0); + Value nfctm1 = builder.create(loc, currnet_nfct_1, c1); + index_SWAP(builder, loc, Rfftp_fctdata_fct, nfctm1, c0); + + builder.create(loc, std::nullopt); + }); + + TypeRange type1 = TypeRange{f64Ty}; + TypeRange type2 = TypeRange{indexTy}; + + Value maxl = + opBuilder.create(loc, MemRefType::get(1, indexTy)); + Value current_length2 = opBuilder.create(loc, length, c0); + Value current_length2_i32 = opBuilder.create( + loc, opBuilder.getI32Type(), current_length2); + Value length_f64 = opBuilder.create( + loc, opBuilder.getF64Type(), current_length2_i32); + Value sqrt_length = opBuilder.create(loc, length_f64); + Value maxl_index = opBuilder.create( + loc, opBuilder.getI32Type(), sqrt_length); + Value maxl_index_index = opBuilder.create( + loc, opBuilder.getIndexType(), maxl_index); + Value maxl_final = opBuilder.create(loc, maxl_index_index, c1); + opBuilder.create(loc, maxl_final, maxl, c0); + + opBuilder.create( + loc, TypeRange{indexTy}, ValueRange{c3}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value divisor = args[0]; + Value length_while = builder.create(loc, length, c0); + Value current_maxl = builder.create(loc, maxl, c0); + + Value condition1 = builder.create( + loc, arith::CmpIPredicate::sgt, length_while, c1); + Value condition2 = builder.create( + loc, arith::CmpIPredicate::slt, divisor, current_maxl); + Value and_cond = + builder.create(loc, condition1, condition2); + builder.create(loc, and_cond, ValueRange{divisor}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value divisor = args[0]; + + Value length_while = builder.create(loc, length, c0); + Value length_mod_divisor = + builder.create(loc, length_while, divisor); + Value condition1 = builder.create( + loc, arith::CmpIPredicate::eq, length_mod_divisor, c0); + builder.create( + loc, condition1, [&](OpBuilder &b, Location loc) { + b.create( + loc, TypeRange{indexTy}, ValueRange{c1}, + [&](OpBuilder &b2, Location loc, ValueRange args) { + Value x = args[0]; + + Value length_while_1 = + b2.create(loc, length, c0); + Value length_mod_divisor_1 = + b2.create(loc, length_while_1, divisor); + + Value condition2 = + b2.create(loc, arith::CmpIPredicate::eq, + length_mod_divisor_1, c0); + b2.create(loc, condition2, ValueRange{x}); + }, + [&](OpBuilder &b2, Location loc, ValueRange args) { + Value x = args[0]; + + Value currnet_nfct = + b2.create(loc, nfct, c0); + b2.create(loc, divisor, Rfftp_fctdata_fct, + currnet_nfct); + index_increment(b2, loc, nfct); + + Value length_while_1 = + b2.create(loc, length, c0); + Value length_new = + b2.create(loc, length_while_1, divisor); + b2.create(loc, length_new, length, c0); + + b2.create(loc, std::vector{x}); + }); + + Value current_length2_1 = + b.create(loc, length, c0); + Value currnet_length2_i32_1 = b.create( + loc, opBuilder.getI32Type(), current_length2_1); + Value length_f64_1 = b.create( + loc, opBuilder.getF64Type(), currnet_length2_i32_1); + Value sqrt_length_1 = b.create(loc, length_f64_1); + Value maxl_index_1 = + b.create(loc, b.getI32Type(), sqrt_length_1); + Value maxl_index_index_1 = b.create( + loc, opBuilder.getIndexType(), maxl_index_1); + Value maxl_final_1 = + b.create(loc, maxl_index_index_1, c1); + b.create(loc, maxl_final_1, maxl, c0); + + b.create(loc, std::nullopt); + }); + + Value divisor_next = builder.create(loc, divisor, c2); + builder.create(loc, std::vector{divisor_next}); + }); + + Value current_length1 = opBuilder.create(loc, length, c0); + Value condition1 = opBuilder.create( + loc, arith::CmpIPredicate::sgt, current_length1, c1); + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + Value current_nfct = builder.create(loc, nfct, c0); + builder.create(loc, current_length1, Rfftp_fctdata_fct, + current_nfct); + index_increment(builder, loc, nfct); + builder.create(loc, std::nullopt); + }); + + Value current_nfct1 = opBuilder.create(loc, nfct, c0); + opBuilder.create(loc, current_nfct1, Rfftp_plan_nfct, c0); + + return c0; +} + +Value index_to_f64(OpBuilder &opBuilder, Location loc, Value n) { + TypeRange type = TypeRange{opBuilder.getF64Type()}; + Value n_i32 = + opBuilder.create(loc, opBuilder.getI32Type(), n); + Value n_f64 = + opBuilder.create(loc, opBuilder.getF64Type(), n_i32); + return n_f64; +} + +Value f64_to_index(OpBuilder &opBuilder, Location loc, Value n_f64) { + TypeRange type = TypeRange{opBuilder.getI32Type()}; + Value n_i32 = + opBuilder.create(loc, opBuilder.getI32Type(), n_f64); + Value n_index = opBuilder.create( + loc, opBuilder.getIndexType(), n_i32); + return n_index; +} + +void my_sincosm1pi(OpBuilder &opBuilder, Location loc, Value a, Value res, + Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, resultType, res, SmallVector{bias}, + SmallVector{c2}, SmallVector{c1}); + + Value s = opBuilder.create(loc, a, a); + + Value r1 = opBuilder.create( + loc, APFloat(double(-1.0369917389758117e-4)), f64Ty); + Value r2 = opBuilder.create( + loc, APFloat(double(1.9294935641298806e-3)), f64Ty); + Value r3 = opBuilder.create( + loc, APFloat(double(-2.5806887942825395e-2)), f64Ty); + Value r4 = opBuilder.create( + loc, APFloat(double(2.3533063028328211e-1)), f64Ty); + Value r5 = opBuilder.create( + loc, APFloat(double(-1.3352627688538006e+0)), f64Ty); + Value r6 = opBuilder.create( + loc, APFloat(double(4.0587121264167623e+0)), f64Ty); + Value r7 = opBuilder.create( + loc, APFloat(double(-4.9348022005446790e+0)), f64Ty); + + Value fma1 = opBuilder.create(loc, r1, s, r2); + Value fma2 = opBuilder.create(loc, fma1, s, r3); + Value fma3 = opBuilder.create(loc, fma2, s, r4); + Value fma4 = opBuilder.create(loc, fma3, s, r5); + Value fma5 = opBuilder.create(loc, fma4, s, r6); + Value fma6 = opBuilder.create(loc, fma5, s, r7); + + Value c = opBuilder.create(loc, fma6, s); + + Value r8 = opBuilder.create( + loc, APFloat(double(4.6151442520157035e-4)), f64Ty); + Value r9 = opBuilder.create( + loc, APFloat(double(-7.3700183130883555e-3)), f64Ty); + Value r10 = opBuilder.create( + loc, APFloat(double(8.2145868949323936e-2)), f64Ty); + Value r11 = opBuilder.create( + loc, APFloat(double(-5.9926452893214921e-1)), f64Ty); + Value r12 = opBuilder.create( + loc, APFloat(double(2.5501640398732688e+0)), f64Ty); + Value r13 = opBuilder.create( + loc, APFloat(double(-5.1677127800499516e+0)), f64Ty); + + Value fma7 = opBuilder.create(loc, r8, s, r9); + Value fma8 = opBuilder.create(loc, fma7, s, r10); + Value fma9 = opBuilder.create(loc, fma8, s, r11); + Value fma10 = opBuilder.create(loc, fma9, s, r12); + Value fma11 = opBuilder.create(loc, fma10, s, r13); + + Value s_new = opBuilder.create(loc, s, a); + Value r = opBuilder.create(loc, fma11, s_new); + + Value pi = opBuilder.create( + loc, APFloat(double(3.1415926535897931e+0)), f64Ty); + Value s_final = opBuilder.create(loc, a, pi, r); + + opBuilder.create(loc, c, res_raw, c0); + opBuilder.create(loc, s_final, res_raw, c1); + + return; +} + +void calc_first_octant_extend2(OpBuilder &opBuilder, Location loc, Value den, + Value res, Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c50 = opBuilder.create(loc, 50); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, bias); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, resultType, res, SmallVector{bias}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value f2 = + opBuilder.create(loc, APFloat(double(2.0)), f64Ty); + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + + Value n_f64 = index_to_f64(opBuilder, loc, n); + Value l1_f64 = opBuilder.create(loc, n_f64); + Value l1 = f64_to_index(opBuilder, loc, l1_f64); + + opBuilder.create( + loc, c1, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iargs) { + Value i_f64 = index_to_f64(builder, loc, i); + Value den_f64 = index_to_f64(builder, loc, den); + Value arg = builder.create(loc, i_f64, den_f64); + Value arg_scaled = builder.create(loc, arg, f2); + + Value im2 = builder.create(loc, i, c2); + Value im2_bias = builder.create(loc, im2, bias); + + my_sincosm1pi(builder, loc, arg_scaled, res, im2_bias); + builder.create(loc, std::nullopt); + }); + + Value start_start = opBuilder.create(loc, l1, c0); + + opBuilder.create( + loc, start_start, n, l1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value start_loop, + ValueRange start_loop_args) { + Value start_f64 = index_to_f64(builder, loc, start_loop); + Value den_f64 = index_to_f64(builder, loc, den); + Value arg = builder.create(loc, start_f64, den_f64); + Value arg_scaled = builder.create(loc, arg, f2); + + Value cs = + builder.create(loc, MemRefType::get(2, f64Ty)); + my_sincosm1pi(builder, loc, arg_scaled, cs, c0); + + Value cs0 = builder.create(loc, cs, c0); + Value cs1 = builder.create(loc, cs, c1); + + Value cs0_plus_1 = builder.create(loc, cs0, f1); + + Value start_2 = builder.create(loc, start_loop, c2); + builder.create(loc, cs0_plus_1, res_raw, start_2); + Value start_2_plus_1 = builder.create(loc, start_2, c1); + builder.create(loc, cs1, res_raw, start_2_plus_1); + + Value n_minus_start = builder.create(loc, n, start_loop); + Value end_1 = builder.create(loc, l1, c0); + Value sum = builder.create(loc, start_loop, end_1); + Value condition = builder.create( + loc, arith::CmpIPredicate::sgt, sum, n); + Value end = builder.create(loc, condition, + n_minus_start, end_1); + + builder.create( + loc, c1, end, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value i_2 = b.create(loc, i, c2); + Value csx0 = b.create(loc, res_raw, i_2); + Value i_2_plus_1 = b.create(loc, i_2, c1); + Value csx1 = b.create(loc, res_raw, i_2_plus_1); + + Value tmp1 = b.create(loc, cs0, csx0); + Value tmp2 = b.create(loc, cs1, csx1); + Value tmp3 = b.create(loc, tmp1, tmp2); + Value tmp4 = b.create(loc, tmp3, cs0); + Value tmp5 = b.create(loc, tmp4, csx0); + Value res_real = b.create(loc, tmp5, f1); + + Value tmp6 = b.create(loc, cs0, csx1); + Value tmp7 = b.create(loc, cs1, csx0); + Value tmp8 = b.create(loc, tmp6, tmp7); + Value tmp9 = b.create(loc, tmp8, cs1); + Value res_imag = b.create(loc, tmp9, csx1); + + Value start_plus_i = b.create(loc, start_loop, i); + Value start_plus_i_2 = + b.create(loc, start_plus_i, c2); + Value start_plus_i_2_plus_1 = + b.create(loc, start_plus_i_2, c1); + b.create(loc, res_real, res_raw, start_plus_i_2); + b.create(loc, res_imag, res_raw, + start_plus_i_2_plus_1); + b.create(loc, std::nullopt); + }); + + builder.create(loc, cs); + builder.create(loc, std::nullopt); + }); + + opBuilder.create( + loc, c1, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange i_args) { + Value i_2 = builder.create(loc, i, c2); + Value val = builder.create(loc, res_raw, i_2); + Value val_plus_1 = builder.create(loc, val, f1); + builder.create(loc, val_plus_1, res_raw, i_2); + builder.create(loc, std::nullopt); + }); + + return; +} + +void calc_first_octant_extend1(OpBuilder &opBuilder, Location loc, Value den, + Value res, Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, bias); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, resultType, res, SmallVector{bias}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + + opBuilder.create(loc, f1, res_raw, c0); + opBuilder.create(loc, f0, res_raw, c1); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, n, c1); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + calc_first_octant_extend2(builder, loc, den, res, bias); + builder.create(loc, std::nullopt); + }); +} + +void calc_first_octant(OpBuilder &opBuilder, Location loc, Value den, Value res, + Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, n, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + calc_first_octant_extend1(builder, loc, den, res, bias); + builder.create(loc, std::nullopt); + }); +} + +void calc_first_quadrant(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, n); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value p_raw = opBuilder.create( + loc, resultType, res, SmallVector{n}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value n_times_2 = opBuilder.create(loc, n, c1); + calc_first_octant(opBuilder, loc, n_times_2, res, n); + + Value n_plus_2 = opBuilder.create(loc, n, c2); + Value ndone = opBuilder.create(loc, n_plus_2, c2); + Value ndonem1 = opBuilder.create(loc, ndone, c1); + Value ndone2 = opBuilder.create(loc, ndone, c2); + Value idx2_start = opBuilder.create(loc, ndone2, c2); + + Value i_start = opBuilder.create(loc, 0); + Value idx1_start = opBuilder.create(loc, 0); + + auto loop = opBuilder.create( + loc, i_start, ndonem1, c2, ValueRange{i_start, idx1_start, idx2_start}, + [&](OpBuilder &builder, Location loc, Value i_loop, + ValueRange i_loop_args) { + Value i_loop1 = i_loop_args[0]; + Value idx1 = i_loop_args[1]; + Value idx2 = i_loop_args[2]; + + Value p_2i = builder.create(loc, i_loop1, c2); + Value p_val = builder.create(loc, p_raw, p_2i); + builder.create(loc, p_val, res, idx1); + + Value p_2i_plus_1 = builder.create(loc, p_2i, c1); + Value p_val_1 = builder.create(loc, p_raw, p_2i_plus_1); + Value idx1_plus_1 = builder.create(loc, idx1, c1); + builder.create(loc, p_val_1, res, idx1_plus_1); + + Value p_2i_plus_3 = builder.create(loc, p_2i, c3); + Value p_val_3 = builder.create(loc, p_raw, p_2i_plus_3); + builder.create(loc, p_val_3, res, idx2); + + Value p_2i_plus_2 = builder.create(loc, p_2i, c2); + Value p_val_2 = builder.create(loc, p_raw, p_2i_plus_2); + Value idx2_plus_1 = builder.create(loc, idx2, c1); + builder.create(loc, p_val_2, res, idx2_plus_1); + + Value i_loop1_next = builder.create(loc, i_loop1, c2); + Value idx1_next = builder.create(loc, idx1, c2); + Value idx2_next = builder.create(loc, idx2, c2); + builder.create( + loc, std::vector{i_loop1_next, idx1_next, idx2_next}); + }); + + Value i_v = loop.getResults()[0]; + Value idx1_v = loop.getResults()[1]; + Value idx2_v = loop.getResults()[2]; + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::ne, i_v, ndone); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value p_2i = builder.create(loc, i_v, c2); + Value p_val = builder.create(loc, p_raw, p_2i); + builder.create(loc, p_val, res, idx1_v); + + Value p_2i_plus_1 = builder.create(loc, p_2i, c1); + Value p_val_1 = builder.create(loc, p_raw, p_2i_plus_1); + Value idx1_plus_1 = builder.create(loc, idx1_v, c1); + builder.create(loc, p_val_1, res, idx1_plus_1); + builder.create(loc, std::nullopt); + }); + + return; +} + +void calc_first_half(OpBuilder &opBuilder, Location loc, Value n, Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + IndexType indexTy = opBuilder.getIndexType(); + FloatType f64Ty = opBuilder.getF64Type(); + + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + + Value n_plus_1 = opBuilder.create(loc, n, c1); + Value ndone = opBuilder.create(loc, n_plus_1, c1); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, n); + Value remaining_size_p1 = + opBuilder.create(loc, remaining_size, c1); + + Value nm1 = opBuilder.create(loc, n, c1); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value p_raw = opBuilder.create( + loc, resultType, res, SmallVector{nm1}, + SmallVector{remaining_size_p1}, + SmallVector{c1}); + + Value n_times_4 = opBuilder.create(loc, n, c2); + calc_first_octant(opBuilder, loc, n_times_4, res, nm1); + + Value i4_start = opBuilder.create(loc, 0); + Value i_start = opBuilder.create(loc, 0); + Value in = opBuilder.create(loc, n, c0); + + auto loop = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{i4_start, i_start}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_minus_i4 = builder.create(loc, in, i4); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4, in_minus_i4); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value i4_2 = builder.create(loc, i4, c2); + Value i_2 = builder.create(loc, i, c2); + Value i4_2_p1 = builder.create(loc, i4_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_i4_2 = builder.create(loc, p_raw, i4_2); + Value p_i4_2_p1 = builder.create(loc, p_raw, i4_2_p1); + + builder.create(loc, p_i4_2, res, i_2); + builder.create(loc, p_i4_2_p1, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_0 = loop.getResults()[0]; + Value final_i_0 = loop.getResults()[1]; + + auto loop1 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_0, final_i_0}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value i4_minus_in = builder.create(loc, i4, in); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4_minus_in, c0); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value xm = builder.create(loc, in, i4); + Value xm_2 = builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, xm_2); + + builder.create(loc, p_xm_2_p1, res, i_2); + builder.create(loc, p_xm_2, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_1 = loop1.getResults()[0]; + Value final_i_1 = loop1.getResults()[1]; + + auto loop2 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_1, final_i_1}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_3 = builder.create(loc, in, c3); + Value in_3_m_i4 = builder.create(loc, in_3, i4); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4, in_3_m_i4); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value xm = builder.create(loc, i4, in); + Value xm_2 = builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, xm_2); + + Value m_p_xm_2_p1 = builder.create(loc, f0, p_xm_2_p1); + + builder.create(loc, m_p_xm_2_p1, res, i_2); + builder.create(loc, p_xm_2, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_2 = loop2.getResults()[0]; + Value final_i_2 = loop2.getResults()[1]; + + auto loop3 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_2, final_i_2}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value condition = builder.create( + loc, arith::CmpIPredicate::slt, i, ndone); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_2 = builder.create(loc, in, c2); + + Value xm = builder.create(loc, in_2, i4); + Value xm_2 = builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, xm_2); + + Value m_p_xm_2 = builder.create(loc, f0, p_xm_2); + + builder.create(loc, m_p_xm_2, res, i_2); + builder.create(loc, p_xm_2_p1, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + + builder.create(loc, std::vector{i4_next, i_next}); + }); + + return; +} + +void fill_first_quadrant(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c8 = opBuilder.create(loc, 8); + + FloatType f64Ty = opBuilder.getF64Type(); + + Value hsqt2 = opBuilder.create( + loc, APFloat(double(0.707106781186547524400844362104849)), f64Ty); + + Value quart = opBuilder.create(loc, n, c2); + Value n_mod_8 = opBuilder.create(loc, n, c8); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, n_mod_8, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value quart_plus_1 = builder.create(loc, quart, c1); + builder.create(loc, hsqt2, res, quart); + builder.create(loc, hsqt2, res, quart_plus_1); + builder.create(loc, std::nullopt); + }); + + Value two_quart = opBuilder.create(loc, quart, c2); + Value two_quart_minus_2 = opBuilder.create(loc, two_quart, c2); + + opBuilder.create( + loc, c2, quart, c2, ValueRange{two_quart_minus_2}, + [&](OpBuilder &builder, Location loc, Value i, ValueRange i_args) { + Value j = i_args[0]; + + Value i_plus_1 = builder.create(loc, i, c1); + Value j_plus_1 = builder.create(loc, j, c1); + + Value val_i = builder.create(loc, res, i); + Value val_i_plus_1 = builder.create(loc, res, i_plus_1); + + builder.create(loc, val_i_plus_1, res, j); + builder.create(loc, val_i, res, j_plus_1); + + Value j_next = builder.create(loc, j, c2); + builder.create(loc, j_next); + }); + + return; +} + +void fill_first_half(OpBuilder &opBuilder, Location loc, Value n, Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + FloatType f64Ty = opBuilder.getF64Type(); + Value c_1 = + opBuilder.create(loc, APFloat(double(-1.0)), f64Ty); + + Value half = opBuilder.create(loc, n, c1); + Value n_mod_4 = opBuilder.create(loc, n, c4); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, n_mod_4, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, c0, half, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value i_plus_1 = b.create(loc, i, c1); + Value i_plus_half = b.create(loc, i, half); + Value i_plus_half_plus_1 = + b.create(loc, i_plus_half, c1); + + Value val_i = b.create(loc, res, i); + Value val_i_plus_1 = b.create(loc, res, i_plus_1); + + Value neg_val_i_plus_1 = + b.create(loc, val_i_plus_1, c_1); + b.create(loc, neg_val_i_plus_1, res, + i_plus_half); + b.create(loc, val_i, res, i_plus_half_plus_1); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }, + [&](OpBuilder &builder, Location loc) { + Value two_half_minus_2 = builder.create(loc, half, c1); + Value two_half_minus_2_mul_2 = + builder.create(loc, two_half_minus_2, c2); + + builder.create( + loc, c2, half, c2, ValueRange{two_half_minus_2_mul_2}, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value j = i_args[0]; + Value i_plus_1 = builder.create(loc, i, c1); + Value j_plus_1 = builder.create(loc, j, c1); + Value val_i = b.create(loc, res, i); + Value val_i_plus_1 = b.create(loc, res, i_plus_1); + Value neg_val_i = b.create(loc, val_i, c_1); + b.create(loc, neg_val_i, res, j); + b.create(loc, val_i_plus_1, res, j_plus_1); + + Value j_next = builder.create(loc, j, c2); + b.create(loc, j_next); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +void sincos_2pibyn_half(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c50 = opBuilder.create(loc, 50); + + Value n_mod_4 = opBuilder.create(loc, n, c4); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, n_mod_4, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + calc_first_octant(builder, loc, n, res, c0); + + fill_first_quadrant(builder, loc, n, res); + fill_first_half(builder, loc, n, res); + builder.create(loc, std::nullopt); + }, + [&](OpBuilder &builder, Location loc) { + Value n_mod_2 = builder.create(loc, n, c2); + Value condition1 = builder.create( + loc, arith::CmpIPredicate::eq, n_mod_2, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &b, Location loc) { + calc_first_quadrant(b, loc, n, res); + fill_first_half(b, loc, n, res); + b.create(loc, std::nullopt); + }, + [&](OpBuilder &b, Location loc) { + calc_first_half(b, loc, n, res); b.create(loc, std::nullopt); }); - builder.create(loc, std::nullopt); }); - - return; } -void radf4(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, - Value ido, Value l1, Value c0, Value c1, Value c2, Value c3) { +// calcuate the twiddle factors for the input length +Value rfftp_comp_twiddle(OpBuilder &opBuilder, Location loc, Value length, + Value Rfftp_fctdata_fct, Value Rfftp_fctdata_tw, + Value Rfftp_fctdata_tws, Value Rfftp_plan_length, + Value Rfftp_plan_nfct, Value Rfftp_plan_mem) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c50 = opBuilder.create(loc, 50); + + Value length_2 = opBuilder.create(loc, length, c2); FloatType f64Ty = opBuilder.getF64Type(); - Value cdim = opBuilder.create(loc, 4); - Value hsqt2 = opBuilder.create( - loc, APFloat(double(0.70710678118654752440)), f64Ty); - Value idom1 = opBuilder.create(loc, ido, c1); - - opBuilder.create( - loc, c0, l1, c1, std::nullopt, - [&](OpBuilder &builder, Location loc, Value iv, ValueRange iargs) { - Value cc0k3 = CC(builder, loc, cc, c0, iv, c3, ido, l1); - Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1); - std::vector tr1_tmp0 = PM(builder, loc, cc0k3, cc0k1); - CH(builder, loc, ch, c0, c2, iv, ido, cdim, tr1_tmp0[1]); - - Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1); - Value cc0k2 = CC(builder, loc, cc, c0, iv, c2, ido, l1); - std::vector tr2_tmp1 = PM(builder, loc, cc0k0, cc0k2); - CH(builder, loc, ch, idom1, c1, iv, ido, cdim, tr2_tmp1[1]); - std::vector tmp2_tmp3 = - PM(builder, loc, tr2_tmp1[0], tr1_tmp0[0]); - CH(builder, loc, ch, c0, c0, iv, ido, cdim, tmp2_tmp3[0]); - CH(builder, loc, ch, idom1, c3, iv, ido, cdim, tmp2_tmp3[1]); + Value twid = opBuilder.create( + loc, MemRefType::get(ShapedType::kDynamic, f64Ty), + /*dynamicOperands=*/length_2); - builder.create(loc, std::nullopt); - }); + Value plan_nfct = opBuilder.create(loc, Rfftp_plan_nfct, c0); - Value reminder = opBuilder.create(loc, ido, c2); - Value condition0 = opBuilder.create( - loc, arith::CmpIPredicate::eq, reminder, c0); - opBuilder.create( - loc, condition0, [&](OpBuilder &builder, Location loc) { - Value negHsqt2 = builder.create( - loc, APFloat(double(-0.70710678118654752440)), f64Ty); + sincos_2pibyn_half(opBuilder, loc, length, twid); - builder.create( - loc, c0, l1, c1, std::nullopt, - [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) { - Value ccidom1k1 = CC(b, loc, cc, idom1, iv, c1, ido, l1); - Value ccidom1k3 = CC(b, loc, cc, idom1, iv, c3, ido, l1); - Value tmp0 = b.create(loc, ccidom1k1, ccidom1k3); - Value ti1 = b.create(loc, negHsqt2, tmp0); + Value l1_start = opBuilder.create(loc, 1); - Value tmp1 = b.create(loc, ccidom1k1, ccidom1k3); - Value tr1 = b.create(loc, hsqt2, tmp1); + opBuilder.create( + loc, c0, plan_nfct, c1, ValueRange{l1_start}, + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { + Value l1 = k_args[0]; + + Value ip = builder.create(loc, Rfftp_fctdata_fct, k); + + Value l1_m_ip = builder.create(loc, l1, ip); + Value ido = builder.create(loc, length, l1_m_ip); + Value plan_nfct_m1 = builder.create(loc, plan_nfct, c1); + + Value condition1 = builder.create( + loc, arith::CmpIPredicate::slt, k, plan_nfct_m1); + + builder.create( + loc, condition1, [&](OpBuilder &b, Location loc) { + Value ido_m1 = b.create(loc, ido, c1); + Value ido_m1_d2 = b.create(loc, ido_m1, c2); + Value ido_m1_d2_p1 = b.create(loc, ido_m1_d2, c1); + + b.create( + loc, c1, ip, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value j, ValueRange j_args) { + b2.create( + loc, c1, ido_m1_d2_p1, c1, std::nullopt, + [&](OpBuilder &b3, Location loc, Value i, + ValueRange i_args) { + Value j2 = b3.create(loc, j, c2); + Value j2_l1 = b3.create(loc, j2, l1); + Value j2_l1_i = + b3.create(loc, j2_l1, i); + Value j2_l1_i_p1 = + b3.create(loc, j2_l1_i, c1); + + Value j_m1 = b3.create(loc, j, c1); + Value ido_m1_j_m1 = + b3.create(loc, ido_m1, j_m1); + + Value i2 = b3.create(loc, i, c2); + Value i2_m1 = b3.create(loc, i2, c1); + Value i2_m2 = b3.create(loc, i2, c2); + + Value tw_a = + b3.create(loc, ido_m1_j_m1, i2_m2); + Value tw_b = + b3.create(loc, ido_m1_j_m1, i2_m1); + + Value twid_a = + b3.create(loc, twid, j2_l1_i); + Value twid_b = + b3.create(loc, twid, j2_l1_i_p1); + + Value fct_k = b3.create( + loc, Rfftp_fctdata_tw, k); + + b3.create(loc, twid_a, fct_k, tw_a); + b3.create(loc, twid_b, fct_k, tw_b); + + b3.create(loc, std::nullopt); + }); + b2.create(loc, std::nullopt); + }); - Value ccidom1k0 = CC(b, loc, cc, idom1, iv, c0, ido, l1); - std::vector tmp2_tmp3 = PM(b, loc, ccidom1k0, tr1); - CH(b, loc, ch, idom1, c0, iv, ido, cdim, tmp2_tmp3[0]); - CH(b, loc, ch, idom1, c2, iv, ido, cdim, tmp2_tmp3[1]); + b.create(loc, std::nullopt); + }); - Value ccidom1k2 = CC(b, loc, cc, idom1, iv, c2, ido, l1); - std::vector tmp4_tmp5 = PM(b, loc, ti1, ccidom1k2); - CH(b, loc, ch, c0, c3, iv, ido, cdim, tmp4_tmp5[0]); - CH(b, loc, ch, c0, c1, iv, ido, cdim, tmp4_tmp5[1]); + Value condition2 = builder.create( + loc, arith::CmpIPredicate::sgt, ip, c5); + + builder.create( + loc, condition2, [&](OpBuilder &b, Location loc) { + Value fct_k = b.create(loc, Rfftp_fctdata_tws, k); + Value c_f0 = + b.create(loc, APFloat(double(0.0)), f64Ty); + Value c_f1 = + b.create(loc, APFloat(double(1.0)), f64Ty); + + b.create(loc, c_f1, fct_k, c0); + b.create(loc, c_f0, fct_k, c1); + + Value ip_div_2 = b.create(loc, ip, c1); + Value ip_div_2_p1 = b.create(loc, ip_div_2, c1); + + b.create( + loc, c1, ip_div_2_p1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value i, ValueRange i_args) { + Value i2 = b2.create(loc, i, c2); + Value i2_p1 = b2.create(loc, i2, c1); + Value ip_m_i = b2.create(loc, ip, i); + Value ip_m_i_2 = b2.create(loc, ip_m_i, c2); + Value ip_m_i_2_p1 = + b2.create(loc, ip_m_i_2, c1); + + Value length_div_ip = + b2.create(loc, length, ip); + Value i2_length_div_ip = + b2.create(loc, i2, length_div_ip); + Value i2_length_div_ip_p1 = + b2.create(loc, i2_length_div_ip, c1); + + Value twid_a = + b2.create(loc, twid, i2_length_div_ip); + Value twid_b = b2.create( + loc, twid, i2_length_div_ip_p1); + Value twid_c = b2.create(loc, c_f0, twid_a); + Value twid_d = b2.create(loc, c_f0, twid_b); + + b2.create(loc, twid_a, fct_k, i2); + b2.create(loc, twid_b, fct_k, i2_p1); + b2.create(loc, twid_c, fct_k, ip_m_i_2); + b2.create(loc, twid_d, fct_k, ip_m_i_2_p1); + b2.create(loc, std::nullopt); + }); b.create(loc, std::nullopt); }); - builder.create(loc, std::nullopt); + Value l1_next = builder.create(loc, l1, ip); + builder.create(loc, l1_next); }); - Value condition1 = - opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, c2); - opBuilder.create( - loc, condition1, [&](OpBuilder &builder, Location loc) { - radf4Extend(builder, loc, cc, ch, wa, ido, l1, cdim, c0, c1, c2, c3); - builder.create(loc, std::nullopt); - }); + opBuilder.create(loc, twid); - return; + return c0; } -void radf5Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, - Value wa, Value ido, Value l1, Value cdim, Value tr11, - Value tr12, Value ti11, Value ti12, Value c0, Value c1, - Value c2, Value c3, Value c4) { - opBuilder.create( - loc, c0, l1, c1, std::nullopt, - [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) { - builder.create( - loc, c2, ido, c2, std::nullopt, - [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { - Value ic = b.create(loc, ido, i); - Value icm1 = b.create(loc, ic, c1); - Value im1 = b.create(loc, i, c1); - Value im2 = b.create(loc, i, c2); +// calculate the twiddle factors and generates the computation order of +// butterfly operators +std::vector make_rfftp_plan(OpBuilder &opBuilder, Location loc, + Value length) { - Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1); - Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); - Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); - Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); - std::vector dr2_di2 = - MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); - Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1); - Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1); - Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1); - Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1); - std::vector dr3_di3 = - MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2); + int64_t NFCT_num = 25; + Value NFCT = opBuilder.create(loc, NFCT_num); - Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1); - Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1); - Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1); - Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1); - std::vector dr4_di4 = - MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3); + FloatType f64Ty = opBuilder.getF64Type(); + IndexType indexTy = opBuilder.getIndexType(); - Value wa3im2 = WA(b, loc, wa, c3, im2, ido, c1); - Value wa3im1 = WA(b, loc, wa, c3, im1, ido, c1); - Value ccim1k4 = CC(b, loc, cc, im1, k, c4, ido, l1); - Value ccik4 = CC(b, loc, cc, i, k, c4, ido, l1); - std::vector dr5_di5 = - MULPM(b, loc, wa3im2, wa3im1, ccim1k4, ccik4); + Value length_2 = opBuilder.create(loc, length, c2); - std::vector cr2_ci5 = PM(b, loc, dr5_di5[0], dr2_di2[0]); - std::vector ci2_cr5 = PM(b, loc, dr2_di2[1], dr5_di5[1]); - std::vector cr3_ci4 = PM(b, loc, dr4_di4[0], dr3_di3[0]); - std::vector ci3_cr4 = PM(b, loc, dr3_di3[1], dr4_di4[1]); + MemRefType type = MemRefType::get(NFCT_num, indexTy); + // MemRefType type1 = MemRefType::get(length_num2, f64Ty); + MemRefType type1 = MemRefType::get(ShapedType::kDynamic, f64Ty); + MemRefType type2 = MemRefType::get(NFCT_num, type1); + MemRefType type3 = MemRefType::get(1, indexTy); + MemRefType type4 = MemRefType::get(1, f64Ty); - Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); - Value tmpch0 = b.create(loc, ccim1k0, cr2_ci5[0]); - Value chim10k = b.create(loc, tmpch0, cr3_ci4[0]); - CH(b, loc, ch, im1, c0, k, ido, cdim, chim10k); + Value Rfftp_fctdata_fct = opBuilder.create(loc, type); + Value Rfftp_fctdata_tw = opBuilder.create(loc, type2); + Value Rfftp_fctdata_tws = opBuilder.create(loc, type2); + Value Rfftp_plan_length = opBuilder.create(loc, type3); + Value Rfftp_plan_nfct = opBuilder.create(loc, type3); + Value Rfftp_plan_mem = opBuilder.create(loc, type4); - Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); - Value tmpch1 = b.create(loc, ccik0, ci2_cr5[0]); - Value chi0k = b.create(loc, tmpch1, ci3_cr4[0]); - CH(b, loc, ch, i, c0, k, ido, cdim, chi0k); + opBuilder.create(loc, length, Rfftp_plan_length, c0); + opBuilder.create(loc, c0, Rfftp_plan_nfct, c0); - Value tmp0 = b.create(loc, tr11, cr2_ci5[0]); - Value tmp1 = b.create(loc, ccim1k0, tmp0); - Value tmp2 = b.create(loc, tr12, cr3_ci4[0]); - Value tr2 = b.create(loc, tmp1, tmp2); + opBuilder.create( + loc, c0, NFCT, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iargs) { + builder.create(loc, c0, Rfftp_fctdata_fct, i); - Value tmp3 = b.create(loc, tr11, ci2_cr5[0]); - Value tmp4 = b.create(loc, ccik0, tmp3); - Value tmp5 = b.create(loc, tr12, ci3_cr4[0]); - Value ti2 = b.create(loc, tmp4, tmp5); + Value tw_i = builder.create( + loc, type1, /*dynamicOperands=*/length_2); + builder.create(loc, tw_i, Rfftp_fctdata_tw, i); + Value tws_i = builder.create( + loc, type1, /*dynamicOperands=*/length_2); + builder.create(loc, tws_i, Rfftp_fctdata_tws, i); - Value tmp6 = b.create(loc, tr12, cr2_ci5[0]); - Value tmp7 = b.create(loc, ccim1k0, tmp6); - Value tmp8 = b.create(loc, tr11, cr3_ci4[0]); - Value tr3 = b.create(loc, tmp7, tmp8); + builder.create(loc, std::nullopt); + }); - Value tmp9 = b.create(loc, tr12, ci2_cr5[0]); - Value tmp10 = b.create(loc, ccik0, tmp9); - Value tmp11 = b.create(loc, tr11, ci3_cr4[0]); - Value ti3 = b.create(loc, tmp10, tmp11); + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::ne, length, c1); - std::vector tr5_tr4 = - MULPM(b, loc, ci2_cr5[1], ci3_cr4[1], ti11, ti12); - std::vector ti5_ti4 = - MULPM(b, loc, cr2_ci5[1], cr3_ci4[1], ti11, ti12); + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value xxx = builder.create(loc, 1); + rfftp_factorize(builder, loc, Rfftp_fctdata_fct, Rfftp_fctdata_tw, + Rfftp_fctdata_tws, Rfftp_plan_length, Rfftp_plan_nfct, + Rfftp_plan_mem); + rfftp_comp_twiddle(builder, loc, length, Rfftp_fctdata_fct, + Rfftp_fctdata_tw, Rfftp_fctdata_tws, + Rfftp_plan_length, Rfftp_plan_nfct, Rfftp_plan_mem); + builder.create(loc, std::nullopt); + }); - std::vector chtmp0 = PM(b, loc, tr2, tr5_tr4[0]); - CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp0[0]); - CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp0[1]); + return {Rfftp_fctdata_fct, Rfftp_fctdata_tw, Rfftp_fctdata_tws, + Rfftp_plan_length, Rfftp_plan_nfct, Rfftp_plan_mem}; +} - std::vector chtmp1 = PM(b, loc, ti5_ti4[0], ti2); - CH(b, loc, ch, i, c2, k, ido, cdim, chtmp1[0]); - CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp1[1]); +void memref_SWAP(OpBuilder &opBuilder, Location loc, Value p, Value p1) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); - std::vector chtmp2 = PM(b, loc, tr3, tr5_tr4[1]); - CH(b, loc, ch, im1, c4, k, ido, cdim, chtmp2[0]); - CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp2[1]); + Value length = opBuilder.create(loc, p, c0); - std::vector chtmp3 = PM(b, loc, ti5_ti4[1], ti3); - CH(b, loc, ch, i, c4, k, ido, cdim, chtmp3[0]); - CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp3[1]); + opBuilder.create( + loc, c0, length, c1, std::nullopt, + [&](OpBuilder builder, Location loc, Value i, ValueRange i_args) { + Value val_p = builder.create(loc, p, i); + Value val_p1 = builder.create(loc, p1, i); + + builder.create(loc, val_p, p1, i); + builder.create(loc, val_p1, p, i); + builder.create(loc, std::nullopt); + }); +} +void flag_SWAP(OpBuilder &opBuilder, Location loc, Value flag) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + + Value val = opBuilder.create(loc, flag, c0); + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::eq, val, c0); + + Value x = opBuilder.create(loc, condition, c1, c0); + + opBuilder.create(loc, x, flag, c0); +} + +void copy_and_norm(OpBuilder &opBuilder, Location loc, Value c, Value p1, + Value n, Value fct, Value flag) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + FloatType f64Ty = opBuilder.getF64Type(); + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + + Value flag_val = opBuilder.create(loc, flag, c0); + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, flag_val, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + Value condition1 = builder.create( + loc, arith::CmpFPredicate::ONE, fct, f1); + builder.create( + loc, condition1, + [&](OpBuilder &b, Location loc) { + b.create( + loc, c0, n, c1, std::nullopt, + [&](OpBuilder b2, Location loc, Value i, ValueRange i_args) { + Value p1_i = b2.create(loc, p1, i); + Value v = b2.create(loc, fct, p1_i); + b2.create(loc, v, c, i); + b2.create(loc, std::nullopt); + }); + b.create(loc, std::nullopt); + }, + [&](OpBuilder &b, Location loc) { + b.create( + loc, c0, n, c1, std::nullopt, + [&](OpBuilder b2, Location loc, Value i, ValueRange i_args) { + Value val = b2.create(loc, p1, i); + b2.create(loc, val, c, i); + b2.create(loc, std::nullopt); + }); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }, + [&](OpBuilder &builder, Location loc) { + Value condition2 = builder.create( + loc, arith::CmpFPredicate::ONE, fct, f1); + builder.create( + loc, condition2, [&](OpBuilder &b, Location loc) { + b.create( + loc, c0, n, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value i, ValueRange i_args) { + Value c_i = b2.create(loc, c, i); + Value newC = b2.create(loc, fct, c_i); + b2.create(loc, newC, c, i); + b2.create(loc, std::nullopt); + }); b.create(loc, std::nullopt); }); - builder.create(loc, std::nullopt); }); - - return; } -void radf5(OpBuilder &builder, Location loc, Value cc, Value ch, Value wa, - Value ido, Value l1, Value c0, Value c1, Value c2, Value c3, - Value c4) { - FloatType f64Ty = builder.getF64Type(); - Value cdim = builder.create(loc, 5); - Value tr11 = builder.create( - loc, APFloat(double(0.3090169943749474241)), f64Ty); - Value tr12 = builder.create( - loc, APFloat(double(-0.8090169943749474241)), f64Ty); - Value ti11 = builder.create( - loc, APFloat(double(0.95105651629515357212)), f64Ty); - Value ti12 = builder.create( - loc, APFloat(double(0.58778525229247312917)), f64Ty); - Value idom1 = builder.create(loc, ido, c1); +// FFT forward function for real number +void rfftp_forward(OpBuilder &opBuilder, Location loc, Value Rfftp_fctdata_fct, + Value Rfftp_fctdata_tw, Value Rfftp_fctdata_tws, + Value Rfftp_plan_length, Value Rfftp_plan_nfct, + Value Rfftp_plan_mem, Value c, Value fct) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c20 = opBuilder.create(loc, 20); + FloatType f64Ty = opBuilder.getF64Type(); - builder.create( - loc, c0, l1, c1, std::nullopt, - [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) { - Value cc0k4 = CC(b, loc, cc, c0, iv, c4, ido, l1); - Value cc0k1 = CC(b, loc, cc, c0, iv, c1, ido, l1); - std::vector cr2_ci5 = PM(b, loc, cc0k4, cc0k1); + Value n = opBuilder.create(loc, Rfftp_plan_length, c0); - Value cc0k3 = CC(b, loc, cc, c0, iv, c3, ido, l1); - Value cc0k2 = CC(b, loc, cc, c0, iv, c2, ido, l1); - std::vector cr3_ci4 = PM(b, loc, cc0k3, cc0k2); + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, n, c1); - Value cc0k0 = CC(b, loc, cc, c0, iv, c0, ido, l1); - Value tmpch0 = b.create(loc, cc0k0, cr2_ci5[0]); - Value ch0 = b.create(loc, tmpch0, cr3_ci4[0]); - CH(b, loc, ch, c0, c0, iv, ido, cdim, ch0); + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value flag = builder.create( + loc, MemRefType::get(1, builder.getIndexType())); + builder.create(loc, c1, flag, c0); + Value l1_raw = builder.create(loc, n, c0); + Value nf = builder.create(loc, Rfftp_plan_nfct, c0); - Value tmpch1 = b.create(loc, tr11, cr2_ci5[0]); - Value tmpch2 = b.create(loc, tr12, cr3_ci4[0]); - Value tmpch3 = b.create(loc, cc0k0, tmpch1); - Value ch1 = b.create(loc, tmpch2, tmpch3); - CH(b, loc, ch, idom1, c1, iv, ido, cdim, ch1); + MemRefType cType = dyn_cast(c.getType()); + Value dimSize = builder.create(loc, c, 0); + Value ch = builder.create(loc, cType, + /*dynamicOperands=*/dimSize); - Value tmpch4 = b.create(loc, ti11, cr2_ci5[1]); - Value tmpch5 = b.create(loc, ti12, cr3_ci4[1]); - Value ch2 = b.create(loc, tmpch4, tmpch5); - CH(b, loc, ch, c0, c2, iv, ido, cdim, ch2); + // Value ch = builder.create( + // loc, MemRefType::get(cType.getShape(), f64Ty)); - Value tmpch6 = b.create(loc, tr12, cr2_ci5[0]); - Value tmpch7 = b.create(loc, tr11, cr3_ci4[0]); - Value tmpch8 = b.create(loc, tmpch6, tmpch7); - Value ch3 = b.create(loc, cc0k0, tmpch8); - CH(b, loc, ch, idom1, c3, iv, ido, cdim, ch3); + FailureOr computelayout = StridedLayoutAttr::get( + opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); - Value tmpch9 = b.create(loc, ti12, cr2_ci5[1]); - Value tmpch10 = b.create(loc, ti11, cr3_ci4[1]); - Value ch4 = b.create(loc, tmpch9, tmpch10); - CH(b, loc, ch, c0, c4, iv, ido, cdim, ch4); + // memref> - b.create(loc, std::nullopt); - }); + Value p1_raw = builder.create( + loc, resultType, c, SmallVector{c0}, + SmallVector{n}, SmallVector{c1}); - Value condition = - builder.create(loc, arith::CmpIPredicate::ne, ido, c1); - builder.create(loc, condition, [&](OpBuilder &b, Location loc) { - radf5Extend(b, loc, cc, ch, wa, ido, l1, cdim, tr11, tr12, ti11, ti12, c0, - c1, c2, c3, c4); - b.create(loc, std::nullopt); - }); + Value p2_raw = builder.create( + loc, resultType, ch, SmallVector{c0}, + SmallVector{n}, SmallVector{c1}); - return; + builder.create( + loc, c0, nf, c1, ValueRange{l1_raw}, + [&](OpBuilder b, Location loc, Value k1, ValueRange k1_args) { + Value l1_old = k1_args[0]; + + Value nf_m_k1 = b.create(loc, nf, k1); + Value k = b.create(loc, nf_m_k1, c1); + Value ip = b.create(loc, Rfftp_fctdata_fct, k); + Value ido = b.create(loc, n, l1_old); + Value l1 = b.create(loc, l1_old, ip); + + Value tw = b.create(loc, Rfftp_fctdata_tw, k); + + Value condition1 = b.create( + loc, arith::CmpIPredicate::eq, ip, c4); + + b.create( + loc, condition1, + [&](OpBuilder &b2, Location loc) { + radf4(b2, loc, p1_raw, p2_raw, tw, ido, l1, c0, c1, c2, c3); + b2.create(loc, std::nullopt); + }, + [&](OpBuilder &b2, Location loc) { + Value condition2 = b2.create( + loc, arith::CmpIPredicate::eq, ip, c2); + b2.create( + loc, condition2, + [&](OpBuilder &b3, Location loc) { + radf2(b3, loc, p1_raw, p2_raw, tw, ido, l1); + b3.create(loc, std::nullopt); + }, + [&](OpBuilder &b3, Location loc) { + Value condition3 = b3.create( + loc, arith::CmpIPredicate::eq, ip, c3); + b3.create( + loc, condition3, + [&](OpBuilder &b4, Location loc) { + radf3(b4, loc, p1_raw, p2_raw, tw, ido, l1); + b4.create(loc, std::nullopt); + }, + [&](OpBuilder &b4, Location loc) { + Value condition4 = b4.create( + loc, arith::CmpIPredicate::eq, ip, c5); + b4.create( + loc, condition4, + [&](OpBuilder &b5, Location loc) { + radf5(b5, loc, p1_raw, p2_raw, tw, ido, + l1, c0, c1, c2, c3, c4); + b5.create(loc, + std::nullopt); + }, + [&](OpBuilder &b5, Location loc) { + Value tws = b5.create( + loc, Rfftp_fctdata_tws, k); + radfg(b5, loc, p1_raw, p2_raw, tw, tws, + ido, ip, l1); + memref_SWAP(b5, loc, p1_raw, p2_raw); + flag_SWAP(b5, loc, flag); + b5.create(loc, + std::nullopt); + }); + b4.create(loc, std::nullopt); + }); + b3.create(loc, std::nullopt); + } + + ); + b2.create(loc, std::nullopt); + }); + + memref_SWAP(b, loc, p1_raw, p2_raw); + flag_SWAP(b, loc, flag); + + b.create(loc, l1); + }); + + copy_and_norm(builder, loc, c, p1_raw, n, fct, flag); + + builder.create(loc, std::nullopt); + }); } // Calculate abspower of bufferMem and store result to a specific line in the @@ -1198,11 +3544,20 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0, }); Value multiplied = mulfOp.getResult(0); - Value bufferMem = + Value bufferMem_raw = builder.create(loc, mTp, multiplied); - // Compute 'dap.rfft400' operation, result stores in `bufferMem`. - builder.create(loc, bufferMem); + MemRefType type0 = MemRefType::get({400}, f64Ty); + MemRefType type1 = MemRefType::get(ShapedType::kDynamic, f64Ty); + + Value bufferMem_rfft = + builder.create(loc, type1, bufferMem_raw); + + // Compute 'dap.rfft' operation, result stores in `bufferMem`. + builder.create(loc, bufferMem_rfft); + + Value bufferMem = + builder.create(loc, type0, bufferMem_rfft); // Store the result in a single line specified by `iv`. absPower(builder, loc, bufferMem, spectrogram, iv, c0, c1, c2); @@ -1278,14 +3633,14 @@ Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0, } namespace { -class DAPRFFT400Lowering : public OpRewritePattern { +class DAPRFFTLowering : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - explicit DAPRFFT400Lowering(MLIRContext *context) + explicit DAPRFFTLowering(MLIRContext *context) : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(dap::RFFT400Op op, + LogicalResult matchAndRewrite(dap::RFFTOp op, PatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto ctx = op->getContext(); @@ -1297,114 +3652,35 @@ class DAPRFFT400Lowering : public OpRewritePattern { Value c3 = rewriter.create(loc, 3); Value c4 = rewriter.create(loc, 4); Value c5 = rewriter.create(loc, 5); + Value c9 = rewriter.create(loc, 9); + Value c24 = rewriter.create(loc, 24); + Value c25 = rewriter.create(loc, 25); + Value c50 = rewriter.create(loc, 50); + + Value inputFeatures = rewriter.create( + loc, bufferMem, /*restrict=*/true, /*writable=*/true); + Value inputFeaturesSize = + rewriter.create(loc, inputFeatures, c0); FloatType f64Ty = rewriter.getF64Type(); + Value f0 = rewriter.create(loc, APFloat(double(0.0)), f64Ty); - int64_t inputLength = 400; - - // Generate ch MemRef - RankedTensorType tensorTy = RankedTensorType::get({inputLength}, f64Ty); - MemRefType m25Ty = MemRefType::get({inputLength}, f64Ty); - Value chTensor = rewriter.create(loc, tensorTy, f0); - Value ch = rewriter.create(loc, m25Ty, chTensor); - - // Generate wa MemRefs - std::vector tw0Vec{ - 0.999877, 0.015707, 0.999507, 0.031411, 0.998890, 0.047106, - 0.998027, 0.062791, 0.996917, 0.078459, 0.995562, 0.094108, - 0.993961, 0.109734, 0.992115, 0.125333, 0.990024, 0.140901, - 0.987688, 0.156434, 0.985109, 0.171929, 0.982287, 0.187381, - 0.979223, 0.202787, 0.975917, 0.218143, 0.972370, 0.233445, - 0.968583, 0.248690, 0.964557, 0.263873, 0.960294, 0.278991, - 0.955793, 0.294040, 0.951057, 0.309017, 0.946085, 0.323917, - 0.940881, 0.338738, 0.935444, 0.353475, 0.929776, 0.368125, - 0.923880, 0.382683, 0.917755, 0.397148, 0.911403, 0.411514, - 0.904827, 0.425779, 0.898028, 0.439939, 0.891007, 0.453990, - 0.883766, 0.467930, 0.876307, 0.481754, 0.868632, 0.495459, - 0.860742, 0.509041, 0.852640, 0.522499, 0.844328, 0.535827, - 0.835807, 0.549023, 0.827081, 0.562083, 0.818150, 0.575005, - 0.809017, 0.587785, 0.799685, 0.600420, 0.790155, 0.612907, - 0.780430, 0.625243, 0.770513, 0.637424, 0.760406, 0.649448, - 0.750111, 0.661312, 0.739631, 0.673013, 0.728969, 0.684547, - 0.718126, 0.695913, 0.000000, 0.999507, 0.031411, 0.998027, - 0.062791, 0.995562, 0.094108, 0.992115, 0.125333, 0.987688, - 0.156434, 0.982287, 0.187381, 0.975917, 0.218143, 0.968583, - 0.248690, 0.960294, 0.278991, 0.951057, 0.309017, 0.940881, - 0.338738, 0.929776, 0.368125, 0.917755, 0.397148, 0.904827, - 0.425779, 0.891007, 0.453990, 0.876307, 0.481754, 0.860742, - 0.509041, 0.844328, 0.535827, 0.827081, 0.562083, 0.809017, - 0.587785, 0.790155, 0.612907, 0.770513, 0.637424, 0.750111, - 0.661312, 0.728969, 0.684547, 0.707107, 0.707107, 0.684547, - 0.728969, 0.661312, 0.750111, 0.637424, 0.770513, 0.612907, - 0.790155, 0.587785, 0.809017, 0.562083, 0.827081, 0.535827, - 0.844328, 0.509041, 0.860742, 0.481754, 0.876307, 0.453990, - 0.891007, 0.425779, 0.904827, 0.397148, 0.917755, 0.368125, - 0.929776, 0.338738, 0.940881, 0.309017, 0.951057, 0.278991, - 0.960294, 0.248690, 0.968583, 0.218143, 0.975917, 0.187381, - 0.982287, 0.156434, 0.987688, 0.125333, 0.992115, 0.094108, - 0.995562, 0.062791, 0.998027, 0.031411, 0.999507, 0.000000, - 0.998890, 0.047106, 0.995562, 0.094108, 0.990024, 0.140901, - 0.982287, 0.187381, 0.972370, 0.233445, 0.960294, 0.278991, - 0.946085, 0.323917, 0.929776, 0.368125, 0.911403, 0.411514, - 0.891007, 0.453990, 0.868632, 0.495459, 0.844328, 0.535827, - 0.818150, 0.575005, 0.790155, 0.612907, 0.760406, 0.649448, - 0.728969, 0.684547, 0.695913, 0.718126, 0.661312, 0.750111, - 0.625243, 0.780430, 0.587785, 0.809017, 0.549023, 0.835807, - 0.509041, 0.860742, 0.467930, 0.883766, 0.425779, 0.904827, - 0.382683, 0.923880, 0.338738, 0.940881, 0.294040, 0.955793, - 0.248690, 0.968583, 0.202787, 0.979223, 0.156434, 0.987688, - 0.109734, 0.993961, 0.062791, 0.998027, 0.015707, 0.999877, - -0.031411, 0.999507, -0.078459, 0.996917, -0.125333, 0.992115, - -0.171929, 0.985109, -0.218143, 0.975917, -0.263873, 0.964557, - -0.309017, 0.951057, -0.353475, 0.935444, -0.397148, 0.917755, - -0.439939, 0.898028, -0.481754, 0.876307, -0.522499, 0.852640, - -0.562083, 0.827081, -0.600420, 0.799685, -0.637424, 0.770513, - -0.673013, 0.739631, 0.000000}; - Value wa0Tensor = rewriter.create( - loc, DenseFPElementsAttr::get(RankedTensorType::get({297}, f64Ty), - ArrayRef(tw0Vec))); - Value wa0 = rewriter.create( - loc, MemRefType::get({297}, f64Ty), wa0Tensor); - - std::vector tw1Vec{ - 0.998027, 0.062791, 0.992115, 0.125333, 0.982287, 0.187381, - 0.968583, 0.248690, 0.951057, 0.309017, 0.929776, 0.368125, - 0.904827, 0.425779, 0.876307, 0.481754, 0.844328, 0.535827, - 0.809017, 0.587785, 0.770513, 0.637424, 0.728969, 0.684547, - 0.992115, 0.125333, 0.968583, 0.248690, 0.929776, 0.368125, - 0.876307, 0.481754, 0.809017, 0.587785, 0.728969, 0.684547, - 0.637424, 0.770513, 0.535827, 0.844328, 0.425779, 0.904827, - 0.309017, 0.951057, 0.187381, 0.982287, 0.062791, 0.998027, - 0.982287, 0.187381, 0.929776, 0.368125, 0.844328, 0.535827, - 0.728969, 0.684547, 0.587785, 0.809017, 0.425779, 0.904827, - 0.248690, 0.968583, 0.062791, 0.998027, -0.125333, 0.992115, - -0.309017, 0.951057, -0.481754, 0.876307, -0.637424, 0.770513}; - Value wa1Tensor = rewriter.create( - loc, DenseFPElementsAttr::get(RankedTensorType::get({72}, f64Ty), - ArrayRef(tw1Vec))); - Value wa1 = rewriter.create( - loc, MemRefType::get({72}, f64Ty), wa1Tensor); - - std::vector tw2Vec{0.968583, 0.248690, 0.876307, 0.481754, - 0.876307, 0.481754, 0.535827, 0.844328, - 0.728969, 0.684547, 0.062791, 0.998027, - 0.535827, 0.844328, -0.425779, 0.904827}; - Value wa2Tensor = rewriter.create( - loc, DenseFPElementsAttr::get(RankedTensorType::get({16}, f64Ty), - ArrayRef(tw2Vec))); - Value wa2 = rewriter.create( - loc, MemRefType::get({16}, f64Ty), wa2Tensor); - - Value c16 = rewriter.create(loc, 16); - Value c25 = rewriter.create(loc, 25); - Value c80 = rewriter.create(loc, 80); - Value c100 = rewriter.create(loc, 100); + Value f1 = + rewriter.create(loc, APFloat(double(1.0)), f64Ty); + + std::vector plan = make_rfftp_plan(rewriter, loc, inputFeaturesSize); + + Value Rfftp_fctdata_fct = plan[0]; + Value Rfftp_fctdata_tw = plan[1]; + Value Rfftp_fctdata_tws = plan[2]; + Value Rfftp_plan_length = plan[3]; + Value Rfftp_plan_nfct = plan[4]; + Value Rfftp_plan_mem = plan[5]; - radf5(rewriter, loc, bufferMem, ch, wa2, c1, c80, c0, c1, c2, c3, c4); - radf5(rewriter, loc, ch, bufferMem, wa2, c5, c16, c0, c1, c2, c3, c4); - radf4(rewriter, loc, bufferMem, ch, wa1, c25, c4, c0, c1, c2, c3); - radf4(rewriter, loc, ch, bufferMem, wa0, c100, c1, c0, c1, c2, c3); + rfftp_forward(rewriter, loc, Rfftp_fctdata_fct, Rfftp_fctdata_tw, + Rfftp_fctdata_tws, Rfftp_plan_length, Rfftp_plan_nfct, + Rfftp_plan_mem, bufferMem, f1); rewriter.eraseOp(op); return success(); @@ -1568,7 +3844,7 @@ class DAPWhisperPreprocessLowering void populateExtendDAPConversionPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); // TODO : extract operators } @@ -1599,6 +3875,7 @@ class ExtendDAPPass registry.insert(); registry.insert(); registry.insert(); + registry.insert(); // Buddy Compiler designed dialect registry.insert(); } @@ -1620,6 +3897,7 @@ void ExtendDAPPass::runOnOperation() { target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); // Add legal operations. target.addLegalOp();