Skip to content

Commit

Permalink
[DAP] Add document for RFFT development and OSPP proposal guide. (bud…
Browse files Browse the repository at this point in the history
  • Loading branch information
taiqzheng authored May 15, 2024
1 parent faec06b commit 61de7a5
Showing 1 changed file with 334 additions and 0 deletions.
334 changes: 334 additions & 0 deletions docs/RFFTAlgorithmDev.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
# Real-Valued Fast Fourier Transform (RFFT) Algorithm Development

In this document, RFFT will represent the Real-Valued Fast Fourier Transform.
In many situations, such as the preprocessing stage for the openai/whisper model, the input data for the DFT consists of purely real numbers. In such cases, the RFFT runs faster than the FFT algorithm.

## Introduction to RFFT
The RFFT algorithm is designed to handle real-valued input signals. Real signals exhibit conjugate symmetry in the frequency domain, meaning that the positive and negative frequency components are complex conjugates of each other. This symmetry property allows the RFFT algorithm to compute only half of the frequency spectrum, reducing computational costs.

## Motivation
The Buddy compiler aims to develop an end-to-end Whisper model with the MLIR project. The Whisper model uses RFFT for generating input features with a fixed length of 400, and this process is repeated 3001 times for each audio. To ensure accuracy for the generated input features, we use the same algorithm as NumPy (which uses the PocketFFT library).

The implementation of PocketFFT's RFFT typically involves two stages: the planning stage and the computation stage. The planning stage calculates the twiddle factors and generates the computation order of butterfly operators. The computation stage follows the schedule order generated by the planning stage and completes the RFFT computation process.

### Planning Stage
- Divide factors: Divide the input length into small factors, which are called codelets or butterfly operators.
- Schedule codelet order: Most FFT libraries use a planner to schedule the computation order for butterfly operators.
- Compute twiddle factors: Compute twiddle factors for different butterfly operators.

### Computation Stage
Follow the scheduled order to achieve RFFT computation. This project already supports butterfly operators 4 and 5, represented as radf4 and radf5, respectively.

To achieve a universal RFFT algorithm, butterfly operators radf2, radf3, and radfg are needed. This part of the work will be done in the 2024 OSPP project.

### Optimization Plan
Currently, only the scalar version is supported; the vectorization pass is still in progress.

## 2024 OSPP
In the 2024 OSPP, three additional operators are needed to achieve universal RFFT functions: which are rad2, rad3, and radfg. Radfg represents for larger prime butterfly factors and is less efficient.

### Pre-Task
1. Algorithm Validation

Before using Rewrite Pattern in MLIR to develop a pass, it is highly recommended to validate the algorithm in handwritten MLIR first. This process can help developers to find the correct pass pipeline and operation syntax (for whatever dialect).

Once buddy-opt is compiled (from the buddy-mlir project), use radf5.mlir and the Makefile documents to compute the a demo RFFT with 25 numbers. You need to change the paths in the Makefile first, then run the "make run-radf5" command to compute the result. Please paste the result in the OSPP Project Proposal (项目申请书).

radf5.mlir:
```
memref.global "private" @ccMem : memref<25xf64> = dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0]>
memref.global "private" @chMem : memref<25xf64> = dense<[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]>
memref.global "private" @twMem : memref<16xf64> = dense<[0.968583, 0.248690, 0.876307, 0.481754, 0.876307, 0.481754, 0.535827, 0.844328, 0.728969, 0.684547, 0.062791, 0.998027, 0.535827, 0.844328, -0.425779, 0.904827]>
memref.global "private" @testMem : memref<5xf64> = dense<[0.0, 0.0, 0.0, 0.0, 0.0]>
// represent factors
memref.global "private" @fctMem : memref<2xindex> = dense<[5, 5]>
func.func private @printMemrefF64(memref<*xf64>)
// #define WA(x,i) wa[(i)+(x)*(ido-1)]
// #define PM(a,b,c,d) { a=c+d; b=c-d; }
// #define MULPM(a,b,c,d,e,f) { a=c*e+d*f; b=c*f-d*e; }
// #define CC(a,b,c) cc[(a)+ido*((b)+l1*(c))]
// #define CH(a,b,c) ch[(a)+ido*((b)+cdim*(c))]
func.func private @WA(%wa : memref<16xf64>, %x : index, %i : index, %ido : index, %i1 : index) -> f64 {
%idom1 = arith.subi %ido, %i1 : index
%tmp1 = arith.muli %x, %idom1 : index
%index = arith.addi %tmp1, %i : index
%result = memref.load %wa[%index] : memref<16xf64>
return %result : f64
}
func.func private @CC(%cc : memref<25xf64>, %a : index, %b : index, %c : index, %ido : index, %l1 : index) -> f64 {
%tmp1 = arith.muli %l1, %c : index
%tmp2 = arith.addi %tmp1, %b : index
%tmp3 = arith.muli %tmp2, %ido : index
%index = arith.addi %tmp3, %a : index
%result = memref.load %cc[%index] : memref<25xf64>
return %result : f64
}
func.func private @CH(%ch : memref<25xf64>, %a : index, %b : index, %c : index, %ido : index, %cdim : index, %toWrite : f64) {
%tmp1 = arith.muli %cdim, %c : index
%tmp2 = arith.addi %tmp1, %b : index
%tmp3 = arith.muli %tmp2, %ido : index
%index = arith.addi %tmp3, %a : index
memref.store %toWrite, %ch[%index] : memref<25xf64>
return
}
func.func private @PM(%c : f64, %d : f64) -> (f64, f64) {
%a = arith.addf %c, %d : f64
%b = arith.subf %c, %d : f64
return %a, %b : f64, f64
}
func.func private @MULPM(%c : f64, %d : f64, %e : f64, %f : f64) -> (f64, f64) {
%tmp1 = arith.mulf %c, %e : f64
%tmp2 = arith.mulf %d, %f : f64
%a = arith.addf %tmp1, %tmp2 : f64
%tmp3 = arith.mulf %c, %f : f64
%tmp4 = arith.mulf %d, %e : f64
%b = arith.subf %tmp3, %tmp4 : f64
return %a, %b : f64, f64
}
func.func @radf5Extend(%cc : memref<25xf64>, %ch : memref<25xf64>, %wa : memref<16xf64>, %ido : index, %l1 : index, %cdim : index) -> () {
%tr11 = arith.constant 0.3090169943749474241 : f64
%tr12 = arith.constant -0.8090169943749474241 : f64
%ti11 = arith.constant 0.95105651629515357212 : f64
%ti12 = arith.constant 0.58778525229247312917 : f64
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : index
%i2 = arith.constant 2 : index
%i3 = arith.constant 3 : index
%i4 = arith.constant 4 : index
scf.for %k = %i0 to %l1 step %i1 {
scf.for %i = %i2 to %ido step %i2 {
%ic = arith.subi %ido, %i : index
%icm1 = arith.subi %ic, %i1 : index
%im1 = arith.subi %i, %i1 : index
%im2 = arith.subi %i, %i2 : index
%wa0im2 = func.call @WA(%wa, %i0, %im2, %ido, %i1) : (memref<16xf64>, index, index, index, index) -> (f64)
%wa0im1 = func.call @WA(%wa, %i0, %im1, %ido, %i1) : (memref<16xf64>, index, index, index, index) -> (f64)
%ccim1k1 = func.call @CC(%cc, %im1, %k, %i1, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%ccik1 = func.call @CC(%cc, %i, %k, %i1, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%dr2_di2:2 = func.call @MULPM(%wa0im2, %wa0im1, %ccim1k1, %ccik1) : (f64, f64, f64, f64) -> (f64, f64)
%wa1im2 = func.call @WA(%wa, %i1, %im2, %ido, %i1) : (memref<16xf64>, index, index, index, index) -> (f64)
%wa1im1 = func.call @WA(%wa, %i1, %im1, %ido, %i1) : (memref<16xf64>, index, index, index, index) -> (f64)
%ccim1k2 = func.call @CC(%cc, %im1, %k, %i2, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%ccik2 = func.call @CC(%cc, %i, %k, %i2, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%dr3_di3:2 = func.call @MULPM(%wa1im2, %wa1im1, %ccim1k2, %ccik2) : (f64, f64, f64, f64) -> (f64, f64)
%wa2im2 = func.call @WA(%wa, %i2, %im2, %ido, %i1) : (memref<16xf64>, index, index, index, index) -> (f64)
%wa2im1 = func.call @WA(%wa, %i2, %im1, %ido, %i1) : (memref<16xf64>, index, index, index, index) -> (f64)
%ccim1k3 = func.call @CC(%cc, %im1, %k, %i3, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%ccik3 = func.call @CC(%cc, %i, %k, %i3, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%dr4_di4:2 = func.call @MULPM(%wa2im2, %wa2im1, %ccim1k3, %ccik3) : (f64, f64, f64, f64) -> (f64, f64)
%wa3im2 = func.call @WA(%wa, %i3, %im2, %ido, %i1) : (memref<16xf64>, index, index, index, index) -> (f64)
%wa3im1 = func.call @WA(%wa, %i3, %im1, %ido, %i1) : (memref<16xf64>, index, index, index, index) -> (f64)
%ccim1k4 = func.call @CC(%cc, %im1, %k, %i4, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%ccik4 = func.call @CC(%cc, %i, %k, %i4, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%dr5_di5:2 = func.call @MULPM(%wa3im2, %wa3im1, %ccim1k4, %ccik4) : (f64, f64, f64, f64) -> (f64, f64)
%cr2_ci5:2 = func.call @PM(%dr5_di5#0, %dr2_di2#0) : (f64, f64) -> (f64, f64)
%ci2_cr5:2 = func.call @PM(%dr2_di2#1, %dr5_di5#1) : (f64, f64) -> (f64, f64)
%cr3_ci4:2 = func.call @PM(%dr4_di4#0, %dr3_di3#0) : (f64, f64) -> (f64, f64)
%ci3_cr4:2 = func.call @PM(%dr3_di3#1, %dr4_di4#1) : (f64, f64) -> (f64, f64)
%ccim1k0 = func.call @CC(%cc, %im1, %k, %i0, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%tmpch0 = arith.addf %ccim1k0, %cr2_ci5#0 : f64
%chim10k = arith.addf %tmpch0, %cr3_ci4#0 : f64
func.call @CH(%ch, %im1, %i0, %k, %ido, %cdim, %chim10k) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%ccik0 = func.call @CC(%cc, %i, %k, %i0, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%tmpch1 = arith.addf %ccik0, %ci2_cr5#0 : f64
%chi0k = arith.addf %tmpch1, %ci3_cr4#0 : f64
func.call @CH(%ch, %i, %i0, %k, %ido, %cdim, %chi0k) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%tmp2 = arith.mulf %tr11, %cr2_ci5#0 : f64
%tmp3 = arith.addf %ccim1k0, %tmp2 : f64
%tmp4 = arith.mulf %tr12, %cr3_ci4#0 : f64
%tr2 = arith.addf %tmp3, %tmp4 : f64
%tmp5 = arith.mulf %tr11, %ci2_cr5#0 : f64
%tmp6 = arith.addf %ccik0, %tmp5 : f64
%tmp7 = arith.mulf %tr12, %ci3_cr4#0 : f64
%ti2 = arith.addf %tmp6, %tmp7 : f64
%tmp8 = arith.mulf %tr12, %cr2_ci5#0 : f64
%tmp9 = arith.addf %ccim1k0, %tmp8 : f64
%tmp10 = arith.mulf %tr11, %cr3_ci4#0 : f64
%tr3 = arith.addf %tmp9, %tmp10 : f64
%tmp11 = arith.mulf %tr12, %ci2_cr5#0 : f64
%tmp12 = arith.addf %ccik0, %tmp11 : f64
%tmp13 = arith.mulf %tr11, %ci3_cr4#0 : f64
%ti3 = arith.addf %tmp12, %tmp13 : f64
%tr5_tr4:2 = func.call @MULPM(%ci2_cr5#1, %ci3_cr4#1, %ti11, %ti12) : (f64, f64, f64, f64) -> (f64, f64)
%ti5_ti4:2 = func.call @MULPM(%cr2_ci5#1, %cr3_ci4#1, %ti11, %ti12) : (f64, f64, f64, f64) -> (f64, f64)
%chtmp1:2 = func.call @PM(%tr2, %tr5_tr4#0) : (f64, f64) -> (f64, f64)
func.call @CH(%ch, %im1, %i2, %k, %ido, %cdim, %chtmp1#0) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
func.call @CH(%ch, %icm1, %i1, %k, %ido, %cdim, %chtmp1#1) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%chtmp2:2 = func.call @PM(%ti5_ti4#0, %ti2) : (f64, f64) -> (f64, f64)
func.call @CH(%ch, %i, %i2, %k, %ido, %cdim, %chtmp2#0) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
func.call @CH(%ch, %ic, %i1, %k, %ido, %cdim, %chtmp2#1) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%chtmp3:2 = func.call @PM(%tr3, %tr5_tr4#1) : (f64, f64) -> (f64, f64)
func.call @CH(%ch, %im1, %i4, %k, %ido, %cdim, %chtmp3#0) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
func.call @CH(%ch, %icm1, %i3, %k, %ido, %cdim, %chtmp3#1) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%chtmp4:2 = func.call @PM(%ti5_ti4#1, %ti3) : (f64, f64) -> (f64, f64)
func.call @CH(%ch, %i, %i4, %k, %ido, %cdim, %chtmp4#0) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
func.call @CH(%ch, %ic, %i3, %k, %ido, %cdim, %chtmp4#1) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
scf.yield
}
scf.yield
}
return
}
func.func @radf5(%cc : memref<25xf64>, %ch : memref<25xf64>, %wa : memref<16xf64>, %ido : index, %l1 : index) {
%cdim = arith.constant 5 : index
%tr11 = arith.constant 0.3090169943749474241 : f64
%tr12 = arith.constant -0.8090169943749474241 : f64
%ti11 = arith.constant 0.95105651629515357212 : f64
%ti12 = arith.constant 0.58778525229247312917 : f64
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : index
%i2 = arith.constant 2 : index
%i3 = arith.constant 3 : index
%i4 = arith.constant 4 : index
%idom1 = arith.subi %ido, %i1 : index
scf.for %iv = %i0 to %l1 step %i1 {
%cc0k4 = func.call @CC(%cc, %i0, %iv, %i4, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%cc0k1 = func.call @CC(%cc, %i0, %iv, %i1, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%cr2_ci5:2 = func.call @PM(%cc0k4, %cc0k1) : (f64, f64) -> (f64, f64)
%cc0k3 = func.call @CC(%cc, %i0, %iv, %i3, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%cc0k2 = func.call @CC(%cc, %i0, %iv, %i2, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%cr3_ci4:2 = func.call @PM(%cc0k3, %cc0k2) : (f64, f64) -> (f64, f64)
%cc0k0 = func.call @CC(%cc, %i0, %iv, %i0, %ido, %l1) : (memref<25xf64>, index, index, index, index, index) -> (f64)
%tmpch0 = arith.addf %cc0k0, %cr2_ci5#0 : f64
%ch0 = arith.addf %tmpch0, %cr3_ci4#0 : f64
func.call @CH(%ch, %i0, %i0, %iv, %ido, %cdim, %ch0) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%tmpch1 = arith.mulf %tr11, %cr2_ci5#0 : f64
%tmpch2 = arith.mulf %tr12, %cr3_ci4#0 : f64
%tmpch3 = arith.addf %cc0k0, %tmpch1 : f64
%ch1 = arith.addf %tmpch2, %tmpch3 : f64
func.call @CH(%ch, %idom1, %i1, %iv, %ido, %cdim, %ch1) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%tmpch4 = arith.mulf %ti11, %cr2_ci5#1 : f64
%tmpch5 = arith.mulf %ti12, %cr3_ci4#1 : f64
%ch2 = arith.addf %tmpch4, %tmpch5 : f64
func.call @CH(%ch, %i0, %i2, %iv, %ido, %cdim, %ch2) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%tmpch6 = arith.mulf %tr12, %cr2_ci5#0 : f64
%tmpch7 = arith.mulf %tr11, %cr3_ci4#0 : f64
%tmpch8 = arith.addf %tmpch6, %tmpch7 : f64
%ch3 = arith.addf %cc0k0, %tmpch8 : f64
func.call @CH(%ch, %idom1, %i3, %iv, %ido, %cdim, %ch3) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
%tmpch9 = arith.mulf %ti12, %cr2_ci5#1 : f64
%tmpch10 = arith.mulf %ti11, %cr3_ci4#1 : f64
%ch4 = arith.subf %tmpch9, %tmpch10 : f64
func.call @CH(%ch, %i0, %i4, %iv, %ido, %cdim, %ch4) : (memref<25xf64>, index, index, index, index, index, f64) -> ()
scf.yield
}
%condition = arith.cmpi ne, %ido, %i1 : index
scf.if %condition {
func.call @radf5Extend(%cc, %ch, %wa, %ido, %l1, %cdim) : (memref<25xf64>, memref<25xf64>, memref<16xf64>, index, index, index) -> ()
scf.yield
}
return
}
func.func @main() {
%cc = memref.get_global @ccMem : memref<25xf64>
%ch = memref.get_global @chMem : memref<25xf64>
%wa = memref.get_global @twMem : memref<16xf64>
%fct = memref.get_global @fctMem : memref<2xindex>
%ido = arith.constant 1 : index
%l1 = arith.constant 5 : index
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : index
func.call @radf5(%cc, %ch, %wa, %ido, %l1) : (memref<25xf64>, memref<25xf64>, memref<16xf64>, index, index) -> ()
func.call @radf5(%ch, %cc, %wa, %l1, %ido) : (memref<25xf64>, memref<25xf64>, memref<16xf64>, index, index) -> ()
%print_out = memref.cast %cc : memref<25xf64> to memref<*xf64>
func.call @printMemrefF64(%print_out) : (memref<*xf64>) -> ()
return
}
```

Makefile:
```
#!/bin/bash
BUDDY_OPT := /PATH/TO/build/bin/buddy-opt
MLIR_OPT := /PATH/TO/llvm/build/bin/mlir-opt
MLIR_CPU_RUNNER := /PATH/TO/llvm/build/bin/mlir-cpu-runner
MLIR_TRANSLATE := /PATH/TO/llvm/build/bin/mlir-translate
LLC := /PATH/TO/llvm/build/bin/llc
OPT_FLAG := -O0
ifeq ($(shell uname),Linux)
MLIR_RUNNER_UTILS := /PATH/TO/llvm/build/lib/libmlir_runner_utils.so
MLIR_C_RUNNER_UTILS := /PATH/TO/llvm/build/lib/libmlir_c_runner_utils.so
MTRIPLE := x86_64-unknown-linux-gnu
BUDDY_OPT_ATTR := avx512f
else ifeq ($(shell uname),Darwin)
MLIR_RUNNER_UTILS := /PATH/TO/llvm/build/lib/libmlir_runner_utils.dylib
MLIR_C_RUNNER_UTILS := /PATH/TO/llvm/build/lib/libmlir_c_runner_utils.dylib
MTRIPLE := x86_64-apple-darwin
endif
run-radf5:
@${BUDDY_OPT} ./radf5.mlir \
-one-shot-bufferize=bufferize-function-boundaries \
-arith-expand \
-convert-math-to-funcs \
-convert-vector-to-scf \
-convert-linalg-to-loops \
-convert-scf-to-cf \
-convert-cf-to-llvm \
-finalize-memref-to-llvm \
-convert-arith-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-vector-to-llvm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
```

2. Develop and Debug

Develop the first loop of radf3 function in mlir and paste the radf3.mlir and Makefile code in the OSPP Project Proposal (项目申请书).

The C-style code can be find in pocketfft.c file from the [PocketFFT library](https://gitlab.mpcdf.mpg.de/mtr/pocketfft). You can reuse the code in radf5.mlir. Please also design an example like "run-radf5"; input with three numbers is enough.

3. Extra Task

After algorithm validation, there is still some work remaining. This include rewriting the above algorithm into a lowering pass, writing examples, and benchmarks.

Follow MLIR's [toy tutorial](https://mlir.llvm.org/docs/Tutorials/Toy/) or check the code of DAP/DIP dialects(designed in buddy-mlir) to understand the process for developing a dialect operation.

An optional task:
Please describe the work required to add an operation (in a dialect). For example, specify which files are needed and where the examples and benchmarks are located.

0 comments on commit 61de7a5

Please sign in to comment.