diff --git a/.gitignore b/.gitignore index 485cccfcf9..1ffba60cbc 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,6 @@ # Clangd cache .cache + +# Clangd configurations +.clangd \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 83b7981421..cd2379468a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,13 +42,15 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR OR BUDDY_MLIR_OUT_OF_TREE_ message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") - set(LLVM_MLIR_BINARY_DIR ${MLIR_DIR}/../../../bin) - set(LLVM_MLIR_LIBRARY_DIR ${MLIR_DIR}/../../../lib) - set(LLVM_PROJECT_BUILD_DIR ${MLIR_DIR}/../../../) - if(NOT DEFINED LLVM_PROJECT_SOURCE_DIR) - get_filename_component(LLVM_PROJECT_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/llvm/ ABSOLUTE) + # LLVM_MAIN_SRC_DIR is a private variable for the LLVM in-tree build. + # To provide compatibility for unifying the one-step and two-step build, + # we set LLVM_MAIN_SRC_DIR ourselves here. + # This could benefit users who want to specify a custom LLVM source directory, + # but also not interfere with normal users who just want to use the buddy-mlir provided LLVM sources. + if(NOT DEFINED LLVM_MAIN_SRC_DIR) + get_filename_component(LLVM_MAIN_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/llvm/llvm ABSOLUTE) endif() - set(LLVM_MLIR_SOURCE_DIR ${LLVM_PROJECT_SOURCE_DIR}/mlir) + set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") @@ -66,16 +68,9 @@ else() #------------------------------------------------------------------------------- # MLIR/LLVM Configuration #------------------------------------------------------------------------------- - - # Allow using out-of-tree llvm directory - set(LLVM_PROJECT_SOURCE_DIR ${LLVM_MAIN_SRC_DIR}/..) 
- message(STATUS "Using LLVM Project ${LLVM_PROJECT_SOURCE_DIR}") - set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) set(MLIR_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include) set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include) - set(LLVM_MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}/bin) - set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") endif() #------------------------------------------------------------------------------- @@ -98,12 +93,22 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${BUDDY_LIBRARY_DIR}) set(BUDDY_EXAMPLES OFF CACHE BOOL "Build examples") set(BUDDY_ENABLE_OPENCV OFF CACHE BOOL "Enable OpenCV support.") -if(BUDDY_ENABLE_OPENCV) - add_definitions(-DBUDDY_ENABLE_OPENCV) - find_package(JPEG REQUIRED) + if(BUDDY_ENABLE_OPENCV) + add_definitions(-DBUDDY_ENABLE_OPENCV) + find_package(JPEG REQUIRED) + find_package(PNG REQUIRED) + find_package(OpenCV REQUIRED CONFIG) + include_directories(${OpenCV_INCLUDE_DIRS}) + endif() + +if(BUDDY_MLIR_ENABLE_DIP_LIB) + add_definitions(-DBUDDY_MLIR_ENABLE_DIP_LIB) + find_package(PNG REQUIRED) +endif() + +if(BUDDY_ENABLE_PNG) + add_definitions(-DBUDDY_ENABLE_PNG) find_package(PNG REQUIRED) - find_package(OpenCV REQUIRED CONFIG) - include_directories(${OpenCV_INCLUDE_DIRS}) endif() # Generate libraries into `lib` of build directory. @@ -220,6 +225,8 @@ if(BUDDY_MLIR_ENABLE_PYTHON_PACKAGES) # Create empty __init__.py files to make these directories Python packages file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/__init__.py "") file(WRITE ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy/compiler/__init__.py "") + + install(DIRECTORY ${BUDDY_MLIR_PYTHON_PACKAGES_DIR}/buddy DESTINATION python_packages) endif() #------------------------------------------------------------------------------- diff --git a/README.md b/README.md index cb9a5f1c24..2e44658b02 100644 --- a/README.md +++ b/README.md @@ -96,13 +96,37 @@ $ cmake -G Ninja .. 
\ -DCMAKE_BUILD_TYPE=RELEASE \ -DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \ -DPython3_EXECUTABLE=$(which python3) +$ ninja +$ ninja check-buddy +$ export BUDDY_MLIR_BUILD_DIR=$PWD +$ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build +$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} +``` + +To configure the build environment for using image processing libraries, follow these steps: + +``` +$ cmake -G Ninja .. \ + -DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \ + -DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_BUILD_TYPE=RELEASE \ + -DBUDDY_MLIR_ENABLE_DIP_LIB=ON \ + -DBUDDY_ENABLE_PNG=ON +$ ninja +$ ninja check-buddy ``` -If you want to add domain-specific framework support, please add the following cmake options: +To build buddy-mlir with custom LLVM sources: -| Framework | Enable Option | Other Options | -| -------------- | ------------- | ------------- | -| OpenCV | `-DBUDDY_ENABLE_OPENCV=ON` | Add `-DOpenCV_DIR=` or install OpenCV release version on your local device. | +``` +$ cmake -G Ninja .. \ + -DMLIR_DIR=PATH/TO/LLVM/lib/cmake/mlir \ + -DLLVM_DIR=PATH/TO/LLVM/lib/cmake/llvm \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_BUILD_TYPE=RELEASE \ + -DLLVM_MAIN_SRC_DIR=PATH/TO/LLVM_SOURCE +```

One-step building strategy

@@ -134,7 +158,7 @@ This repository have nix flake support. You can follow the [nix installation ins nix develop . ``` -This will setup a bash shell with `clang`, `clangd`, `cmake`, `ninja`, and other necessary dependencies to build buddy-mlir from source. +This will setup a bash shell with `clang`, `ccls`, `cmake`, `ninja`, and other necessary dependencies to build buddy-mlir from source. - If you want to use the buddy-mlir bintools diff --git a/backend/include/llvm/IR/CMakeLists.txt b/backend/include/llvm/IR/CMakeLists.txt index b3447eae61..2de6b999b3 100644 --- a/backend/include/llvm/IR/CMakeLists.txt +++ b/backend/include/llvm/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -include_directories(${LLVM_PROJECT_SOURCE_DIR}/llvm/include/llvm/IR/) +include_directories(${LLVM_MAIN_SRC_DIR}/include/llvm/IR/) set(LLVM_TARGET_DEFINITIONS IntrinsicsBuddyExt.td) tablegen(LLVM IntrinsicImpl.inc -gen-intrinsic-impl) diff --git a/backend/llvm/lib/Analysis/CMakeLists.txt b/backend/llvm/lib/Analysis/CMakeLists.txt index 2a3a65971b..117f75d89b 100644 --- a/backend/llvm/lib/Analysis/CMakeLists.txt +++ b/backend/llvm/lib/Analysis/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_Analysis_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Analysis) +set(LLVM_Analysis_DIR ${LLVM_MAIN_SRC_DIR}/lib/Analysis) add_llvm_component_library(LLVMBuddyAnalysis diff --git a/backend/llvm/lib/AsmParser/CMakeLists.txt b/backend/llvm/lib/AsmParser/CMakeLists.txt index b5411d1007..d687d1d3bc 100644 --- a/backend/llvm/lib/AsmParser/CMakeLists.txt +++ b/backend/llvm/lib/AsmParser/CMakeLists.txt @@ -1,6 +1,6 @@ # AsmParser -set(LLVM_AsmParser_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/AsmParser) +set(LLVM_AsmParser_DIR ${LLVM_MAIN_SRC_DIR}/lib/AsmParser) add_llvm_component_library(LLVMBuddyAsmParser ${LLVM_AsmParser_DIR}/LLLexer.cpp diff --git a/backend/llvm/lib/Bitcode/Reader/CMakeLists.txt b/backend/llvm/lib/Bitcode/Reader/CMakeLists.txt index cf92a543fd..7ea9048011 100644 --- a/backend/llvm/lib/Bitcode/Reader/CMakeLists.txt +++ 
b/backend/llvm/lib/Bitcode/Reader/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_Reader_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Bitcode/Reader) +set(LLVM_Reader_DIR ${LLVM_MAIN_SRC_DIR}/lib/Bitcode/Reader) add_llvm_component_library(LLVMBuddyBitReader ${LLVM_Reader_DIR}/BitcodeAnalyzer.cpp diff --git a/backend/llvm/lib/Bitcode/Writer/CMakeLists.txt b/backend/llvm/lib/Bitcode/Writer/CMakeLists.txt index f19595cead..a8b7f0c274 100644 --- a/backend/llvm/lib/Bitcode/Writer/CMakeLists.txt +++ b/backend/llvm/lib/Bitcode/Writer/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_Writer_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Bitcode/Writer) +set(LLVM_Writer_DIR ${LLVM_MAIN_SRC_DIR}/lib/Bitcode/Writer) add_llvm_component_library(LLVMBuddyBitWriter diff --git a/backend/llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt b/backend/llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt index fe3273dd5d..b942f4f734 100644 --- a/backend/llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt +++ b/backend/llvm/lib/CodeGen/AsmPrinter/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_AsmPrinter_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/CodeGen/AsmPrinter) +set(LLVM_AsmPrinter_DIR ${LLVM_MAIN_SRC_DIR}/lib/CodeGen/AsmPrinter) add_llvm_component_library(LLVMBuddyAsmPrinter ${LLVM_AsmPrinter_DIR}/AccelTable.cpp diff --git a/backend/llvm/lib/CodeGen/CMakeLists.txt b/backend/llvm/lib/CodeGen/CMakeLists.txt index 1794b38fa4..7eb38876db 100644 --- a/backend/llvm/lib/CodeGen/CMakeLists.txt +++ b/backend/llvm/lib/CodeGen/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_CodeGen_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/CodeGen) +set(LLVM_CodeGen_DIR ${LLVM_MAIN_SRC_DIR}/lib/CodeGen) add_llvm_component_library(LLVMBuddyCodeGen ${LLVM_CodeGen_DIR}/AggressiveAntiDepBreaker.cpp diff --git a/backend/llvm/lib/CodeGen/MIRParser/CMakeLists.txt b/backend/llvm/lib/CodeGen/MIRParser/CMakeLists.txt index 6275b1ece0..1ab94ee930 100644 --- a/backend/llvm/lib/CodeGen/MIRParser/CMakeLists.txt +++ b/backend/llvm/lib/CodeGen/MIRParser/CMakeLists.txt @@ -1,4 +1,4 @@ 
-set(LLVM_MIRParser_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/CodeGen/MIRParser) +set(LLVM_MIRParser_DIR ${LLVM_MAIN_SRC_DIR}/lib/CodeGen/MIRParser) add_llvm_component_library(LLVMBuddyMIRParser ${LLVM_MIRParser_DIR}/MILexer.cpp diff --git a/backend/llvm/lib/CodeGen/SelectionDAG/CMakeLists.txt b/backend/llvm/lib/CodeGen/SelectionDAG/CMakeLists.txt index 4bb3cde980..3b467a4eda 100644 --- a/backend/llvm/lib/CodeGen/SelectionDAG/CMakeLists.txt +++ b/backend/llvm/lib/CodeGen/SelectionDAG/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_SelectionDAG_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/CodeGen/SelectionDAG) +set(LLVM_SelectionDAG_DIR ${LLVM_MAIN_SRC_DIR}/lib/CodeGen/SelectionDAG) add_llvm_component_library(LLVMBuddySelectionDAG ${LLVM_SelectionDAG_DIR}/DAGCombiner.cpp diff --git a/backend/llvm/lib/IR/CMakeLists.txt b/backend/llvm/lib/IR/CMakeLists.txt index e6895a1f80..0d56184730 100644 --- a/backend/llvm/lib/IR/CMakeLists.txt +++ b/backend/llvm/lib/IR/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_IR_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/IR) +set(LLVM_IR_DIR ${LLVM_MAIN_SRC_DIR}/lib/IR) add_llvm_component_library(LLVMBuddyCore ${LLVM_IR_DIR}/AbstractCallSite.cpp diff --git a/backend/llvm/lib/IRReader/CMakeLists.txt b/backend/llvm/lib/IRReader/CMakeLists.txt index 9b315dec3b..72e95722a8 100644 --- a/backend/llvm/lib/IRReader/CMakeLists.txt +++ b/backend/llvm/lib/IRReader/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_IRReader_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/IRReader) +set(LLVM_IRReader_DIR ${LLVM_MAIN_SRC_DIR}/lib/IRReader) add_llvm_component_library(LLVMBuddyIRReader ${LLVM_IRReader_DIR}/IRReader.cpp diff --git a/backend/llvm/lib/Object/CMakeLists.txt b/backend/llvm/lib/Object/CMakeLists.txt index 8695d55ba9..a8425e97c0 100644 --- a/backend/llvm/lib/Object/CMakeLists.txt +++ b/backend/llvm/lib/Object/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_Object_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Object) +set(LLVM_Object_DIR ${LLVM_MAIN_SRC_DIR}/lib/Object) 
add_llvm_component_library(LLVMBuddyObject ${LLVM_Object_DIR}/Archive.cpp diff --git a/backend/llvm/lib/ProfileData/CMakeLists.txt b/backend/llvm/lib/ProfileData/CMakeLists.txt index 9ae05a36fe..742ecf662a 100644 --- a/backend/llvm/lib/ProfileData/CMakeLists.txt +++ b/backend/llvm/lib/ProfileData/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_ProfileData_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/ProfileData) +set(LLVM_ProfileData_DIR ${LLVM_MAIN_SRC_DIR}/lib/ProfileData) add_llvm_component_library(LLVMBuddyProfileData ${LLVM_ProfileData_DIR}/GCOV.cpp diff --git a/backend/llvm/lib/Remarks/CMakeLists.txt b/backend/llvm/lib/Remarks/CMakeLists.txt index 4ed8775770..5c1c81b7d8 100644 --- a/backend/llvm/lib/Remarks/CMakeLists.txt +++ b/backend/llvm/lib/Remarks/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_Remarks_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Remarks) +set(LLVM_Remarks_DIR ${LLVM_MAIN_SRC_DIR}/lib/Remarks) add_llvm_component_library(LLVMBuddyRemarks ${LLVM_Remarks_DIR}/BitstreamRemarkParser.cpp diff --git a/backend/llvm/lib/Target/CMakeLists.txt b/backend/llvm/lib/Target/CMakeLists.txt index c6298c3837..1dd5cd34f3 100644 --- a/backend/llvm/lib/Target/CMakeLists.txt +++ b/backend/llvm/lib/Target/CMakeLists.txt @@ -2,7 +2,7 @@ list(APPEND LLVM_COMMON_DEPENDS buddy_intrinsics_gen) list(APPEND LLVM_TABLEGEN_FLAGS -I ${LLVM_MAIN_SRC_DIR}/lib/Target) -set(LLVM_Target_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Target) +set(LLVM_Target_DIR ${LLVM_MAIN_SRC_DIR}/lib/Target) add_llvm_component_library(LLVMBuddyTarget ${LLVM_Target_DIR}/Target.cpp diff --git a/backend/llvm/lib/Target/RISCV/CMakeLists.txt b/backend/llvm/lib/Target/RISCV/CMakeLists.txt index 4a66f65292..6bfee7c2f9 100644 --- a/backend/llvm/lib/Target/RISCV/CMakeLists.txt +++ b/backend/llvm/lib/Target/RISCV/CMakeLists.txt @@ -21,7 +21,7 @@ macro(buddy_add_llvm_target target_name) set( CURRENT_LLVM_TARGET LLVM${target_name} ) endmacro(buddy_add_llvm_target) -set(LLVM_TARGET_RISCV_DIR 
${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Target/RISCV) +set(LLVM_TARGET_RISCV_DIR ${LLVM_MAIN_SRC_DIR}/lib/Target/RISCV) # ------------------------------------------------------------------------------ # Configure RISC-V Buddy Extension. diff --git a/backend/llvm/lib/Transforms/IPO/CMakeLists.txt b/backend/llvm/lib/Transforms/IPO/CMakeLists.txt index 74ff798637..08392abf87 100644 --- a/backend/llvm/lib/Transforms/IPO/CMakeLists.txt +++ b/backend/llvm/lib/Transforms/IPO/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_IPO_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Transforms/IPO) +set(LLVM_IPO_DIR ${LLVM_MAIN_SRC_DIR}/lib/Transforms/IPO) add_llvm_component_library(LLVMBuddyIPO ${LLVM_IPO_DIR}/AlwaysInliner.cpp diff --git a/backend/llvm/lib/Transforms/Scalar/CMakeLists.txt b/backend/llvm/lib/Transforms/Scalar/CMakeLists.txt index c3c412b9a9..6bbcf432e8 100644 --- a/backend/llvm/lib/Transforms/Scalar/CMakeLists.txt +++ b/backend/llvm/lib/Transforms/Scalar/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_Scalar_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Transforms/Scalar) +set(LLVM_Scalar_DIR ${LLVM_MAIN_SRC_DIR}/lib/Transforms/Scalar) add_llvm_component_library(LLVMBuddyScalarOpts ${LLVM_Scalar_DIR}/ADCE.cpp diff --git a/backend/llvm/lib/Transforms/Utils/CMakeLists.txt b/backend/llvm/lib/Transforms/Utils/CMakeLists.txt index 989a672edd..e3313e07b2 100644 --- a/backend/llvm/lib/Transforms/Utils/CMakeLists.txt +++ b/backend/llvm/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_Utils_DIR ${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Transforms/Utils) +set(LLVM_Utils_DIR ${LLVM_MAIN_SRC_DIR}/lib/Transforms/Utils) add_llvm_component_library(LLVMBuddyTransformUtils diff --git a/backend/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/backend/llvm/lib/Transforms/Vectorize/CMakeLists.txt index e9cece2c46..669aae5850 100644 --- a/backend/llvm/lib/Transforms/Vectorize/CMakeLists.txt +++ b/backend/llvm/lib/Transforms/Vectorize/CMakeLists.txt @@ -1,4 +1,4 @@ -set(LLVM_Vectorize_DIR 
${LLVM_PROJECT_SOURCE_DIR}/llvm/lib/Transforms/Vectorize) +set(LLVM_Vectorize_DIR ${LLVM_MAIN_SRC_DIR}/lib/Transforms/Vectorize) add_llvm_component_library(LLVMBuddyVectorize ${LLVM_Vectorize_DIR}/LoadStoreVectorizer.cpp diff --git a/docs/PythonEnvironment.md b/docs/PythonEnvironment.md new file mode 100644 index 0000000000..77f431e85c --- /dev/null +++ b/docs/PythonEnvironment.md @@ -0,0 +1,10 @@ +# Python Virtual Environment Setup Guide for Buddy-mlir + +We recommend using Anaconda3 to create a Python virtual environment. Install the Python packages listed in buddy-mlir/requirements.txt. + +```bash +$ conda create -n buddy python=3.11 +$ conda activate buddy +$ cd buddy-mlir +$ pip install -r requirements.txt +``` \ No newline at end of file diff --git a/docs/RVVEnviroment.md b/docs/RVVEnvironment.md similarity index 100% rename from docs/RVVEnviroment.md rename to docs/RVVEnvironment.md diff --git a/examples/BuddyBert/CMakeLists.txt b/examples/BuddyBert/CMakeLists.txt index 93dc7c2daa..95c98dfa96 100644 --- a/examples/BuddyBert/CMakeLists.txt +++ b/examples/BuddyBert/CMakeLists.txt @@ -7,13 +7,13 @@ add_custom_command( add_custom_command( OUTPUT forward.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyBert/forward.mlir + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyBert/forward.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" | - ${LLVM_MLIR_BINARY_DIR}/mlir-opt + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt -pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm,
reconcile-unrealized-casts)" | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyBert/forward.o + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyBert/forward.o DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyBert/forward.mlir COMMENT "Building forward.o" VERBATIM) @@ -22,11 +22,11 @@ add_custom_command( OUTPUT subgraph0.o COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_EXAMPLES_DIR}/BuddyBert/subgraph0.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, func-bufferize-dynamic-offset, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize))" | - ${LLVM_MLIR_BINARY_DIR}/mlir-opt + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt -pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyBert/subgraph0.o + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyBert/subgraph0.o DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyBert/subgraph0.mlir COMMENT "Building subgraph0.o" VERBATIM) @@ -36,7 +36,7 @@ add_library(BERT STATIC forward.o 
subgraph0.o) SET_TARGET_PROPERTIES(BERT PROPERTIES LINKER_LANGUAGE C) add_executable(buddy-bert-run bert-main.cpp) -target_link_directories(buddy-bert-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR}) +target_link_directories(buddy-bert-run PRIVATE ${LLVM_LIBRARY_DIR}) set(BUDDY_BERT_LIBS BERT mlir_c_runner_utils) target_link_libraries(buddy-bert-run ${BUDDY_BERT_LIBS}) diff --git a/examples/BuddyConvolution/.gitignore b/examples/BuddyConvolution/.gitignore index 0194ea7a68..df9389428a 100644 --- a/examples/BuddyConvolution/.gitignore +++ b/examples/BuddyConvolution/.gitignore @@ -1,3 +1,4 @@ log.mlir log.ll log.s +a.out diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir new file mode 100644 index 0000000000..76d5e4d932 --- /dev/null +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir @@ -0,0 +1,137 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-vector-to-scf \ +// RUN: -lower-affine \ +// RUN: -arith-bufferize \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -O3 -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +// Using `8` as the vector size. +#map = affine_map<(d0) -> (d0 floordiv 8)> +#map0 = affine_map<(d0, d1, d2, d3) -> (d2)> +#map1 = affine_map<(d0, d1) -> (d0 + d1)> +#map2 = affine_map<(d0, d1) -> (d0 + d1 * 8)> +#map3 = affine_map<(d0) -> (d0 * 8)> + +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func private @rtclock() -> f64 + + func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + %f0 = arith.constant 0. 
: f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %n = memref.dim %arg0, %c0 : memref + %h_i = memref.dim %arg0, %c1 : memref + %w_i = memref.dim %arg0, %c2 : memref + %c = memref.dim %arg0, %c3 : memref + %f = memref.dim %arg1, %c0 : memref + %h_k = memref.dim %arg1, %c1 : memref + %w_k = memref.dim %arg1, %c2 : memref + %h_o = memref.dim %arg2, %c1 : memref + %w_o = memref.dim %arg2, %c2 : memref + + // Output is NHoWoF + affine.for %idx_n = %c0 to %n { + affine.for %idx_f = %c0 to %f { + affine.for %idx_c = %c0 to %c { + affine.for %idx_h_o = %c0 to %h_o { + affine.for %idx_h_k = %c0 to %h_k { + affine.for %idx_w_k = %c0 to %w_k { + affine.for %idx_w_o = %c0 to #map(%w_o) { + %kernel_ele = memref.load %arg1[%idx_f, %idx_h_k, %idx_w_k, %idx_c] : memref + %kernel_vec = vector.broadcast %kernel_ele : f32 to vector<8xf32> + %in_iter_h = affine.apply #map1 (%idx_h_k, %idx_h_o) + %in_iter_w = affine.apply #map2 (%idx_w_k, %idx_w_o) + %out_iter_w = affine.apply #map3 (%idx_w_o) + %input_vec = vector.transfer_read %arg0[%idx_n, %in_iter_h, %in_iter_w, %idx_c], %f0 + { permutation_map = #map0 } : memref, vector<8xf32> + %output_vec = vector.transfer_read %arg2[%idx_n, %idx_h_o, %out_iter_w, %idx_f], %f0 + { permutation_map = #map0 } : memref, vector<8xf32> + %res_vec = vector.fma %kernel_vec, %input_vec, %output_vec : vector<8xf32> + vector.transfer_write %res_vec, %arg2[%idx_n, %idx_h_o, %out_iter_w, %idx_f] + { permutation_map = #map0 } : vector<8xf32>, memref + } + } + } + } + } + } + } + + return + } + + func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + scf.for %idx2 = %c0 to %arg2 step %c1 { + scf.for %idx3 = %c0 to %arg3 
step %c1 { + memref.store %arg4, %0[%idx0, %idx1, %idx2, %idx3] : memref + } + } + } + } + return %0 : memref + } + + func.func @main() { + %f0 = arith.constant 0.000000e+00 : f32 + %f2 = arith.constant 2.000000e+00 : f32 + %f3 = arith.constant 3.000000e+00 : f32 + + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c8 = arith.constant 8 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c24 = arith.constant 24 : index + %c28 = arith.constant 28 : index + + // %v0 = call @alloc_f32(%c1, %c12, %c12, %c6, %f2) : (index, index, index, index, f32) -> memref + // %v1 = call @alloc_f32(%c16, %c5, %c5, %c6, %f3) : (index, index, index, index, f32) -> memref + // %v2 = call @alloc_f32(%c1, %c8, %c8, %c16, %f0) : (index, index, index, index, f32) -> memref + + %v0 = call @alloc_f32(%c1, %c28, %c28, %c1, %f2) : (index, index, index, index, f32) -> memref + %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref + %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref + + %t_start = call @rtclock() : () -> f64 + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + %t_end = call @rtclock() : () -> f64 + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. 
+ // CHECK: Unranked Memref + // CHECK: [ + // CHECK: [ + // CHECK: [ + // CHECK: [150{{(, 150)*}}], + %print_v2 = memref.cast %v2 : memref to memref<*xf32> + call @printMemrefF32(%print_v2) : (memref<*xf32>) -> () + + %time = arith.subf %t_end, %t_start : f64 + vector.print %time : f64 + + memref.dealloc %v0 : memref + memref.dealloc %v1 : memref + memref.dealloc %v2 : memref + + return + } +} diff --git a/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir new file mode 100644 index 0000000000..90759355e9 --- /dev/null +++ b/examples/BuddyConvolution/conv2d-nhwc-fhwc.mlir @@ -0,0 +1,88 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-linalg-to-loops \ +// RUN: -lower-affine \ +// RUN: -arith-bufferize \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func private @rtclock() -> f64 + + func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_2d_nhwc_fhwc ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) + return + } + + func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + scf.for %idx2 = %c0 to %arg2 step %c1 { + scf.for %idx3 = %c0 to %arg3 step %c1 { + memref.store %arg4, %0[%idx0, %idx1, %idx2, %idx3] : memref + } + } + } + } + return %0 : memref + } + + func.func 
@main() { + %f0 = arith.constant 0.000000e+00 : f32 + %f2 = arith.constant 2.000000e+00 : f32 + %f3 = arith.constant 3.000000e+00 : f32 + + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c6 = arith.constant 6 : index + %c8 = arith.constant 8 : index + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c24 = arith.constant 24 : index + %c28 = arith.constant 28 : index + + // %v0 = call @alloc_f32(%c1, %c12, %c12, %c6, %f2) : (index, index, index, index, f32) -> memref + // %v1 = call @alloc_f32(%c16, %c5, %c5, %c6, %f3) : (index, index, index, index, f32) -> memref + // %v2 = call @alloc_f32(%c1, %c8, %c8, %c16, %f0) : (index, index, index, index, f32) -> memref + + %v0 = call @alloc_f32(%c1, %c28, %c28, %c1, %f2) : (index, index, index, index, f32) -> memref + %v1 = call @alloc_f32(%c6, %c5, %c5, %c1, %f3) : (index, index, index, index, f32) -> memref + %v2 = call @alloc_f32(%c1, %c24, %c24, %c6, %f0) : (index, index, index, index, f32) -> memref + + %t_start = call @rtclock() : () -> f64 + call @conv_2d_nhwc_fhwc(%v0, %v1, %v2) : (memref, memref, memref) -> () + %t_end = call @rtclock() : () -> f64 + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. 
+ // CHECK: Unranked Memref + // CHECK: [ + // CHECK: [ + // CHECK: [ + // CHECK: [150{{(, 150)*}}], + %print_v2 = memref.cast %v2 : memref to memref<*xf32> + call @printMemrefF32(%print_v2) : (memref<*xf32>) -> () + + %time = arith.subf %t_end, %t_start : f64 + vector.print %time : f64 + + memref.dealloc %v0 : memref + memref.dealloc %v1 : memref + memref.dealloc %v2 : memref + return + } +} diff --git a/examples/BuddyConvolution/makefile b/examples/BuddyConvolution/makefile index 063832fa0f..1962643766 100644 --- a/examples/BuddyConvolution/makefile +++ b/examples/BuddyConvolution/makefile @@ -1,10 +1,12 @@ #!/bin/bash BUDDY_OPT := ../../build/bin/buddy-opt MLIR_OPT := ../../llvm/build/bin/mlir-opt +CLANG := ../../llvm/build/bin/clang MLIR_TRANSLATE := ../../llvm/build/bin/mlir-translate MLIR_CPU_RUNNER := ../../llvm/build/bin/mlir-cpu-runner LLC := ../../llvm/build/bin/llc -OPT_FLAG := -O0 +OPT_FLAG := -O3 +MLIR_LIB := ../../llvm/build/lib/ ifeq ($(shell uname),Linux) MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.so @@ -61,3 +63,65 @@ conv2d-run: -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +conv2d-nhwc-fhwc-run: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +conv2d-nhwc-fhwc-aot: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc.mlir \ + -convert-linalg-to-loops \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + 
${MLIR_TRANSLATE} -mlir-to-llvmir -o log.ll + ${CLANG} log.ll ${OPT_FLAG} \ + -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ + -o a.out + @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out + +conv2d-nhwc-fhwc-opt-run: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc-opt.mlir \ + -convert-vector-to-scf \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} -O3 -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +conv2d-nhwc-fhwc-opt-aot: + @${BUDDY_OPT} ./conv2d-nhwc-fhwc-opt.mlir \ + -convert-vector-to-scf \ + -lower-affine \ + -arith-bufferize \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} -mlir-to-llvmir -o log.ll + ${CLANG} log.ll -O3 \ + -L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ + -o a.out + @LD_LIBRARY_PATH=${MLIR_LIB} ./a.out diff --git a/examples/BuddyGPU/.gitignore b/examples/BuddyGPU/.gitignore new file mode 100644 index 0000000000..d82aeb33bb --- /dev/null +++ b/examples/BuddyGPU/.gitignore @@ -0,0 +1,4 @@ +log.mlir +log.ll +log.s +matmul-cubin.mlir diff --git a/examples/BuddyGPU/README.md b/examples/BuddyGPU/README.md new file mode 100644 index 0000000000..7c4081e401 --- /dev/null +++ b/examples/BuddyGPU/README.md @@ -0,0 +1,40 @@ +# Buddy GPU Example +This example demonstrates how to use the Buddy GPU to run a simple single-kernel program. + +## Matmul +The example program is a simple matrix multiplication kernel. The linalg definition is in the `matmul.mlir` file. +A transform sequence is in `transform.mlir` to optimize this kernel and prepare it for execution on the GPU. +The `matmul-cubin.mlir` provides a lowered file, in case the pipeline is not working. 
+ +Run the following command to compile and run the program: +``` + make buddy-gpu-matmul + python run-module-gpu.py --source matmul.mlir --target matmul-cubin.mlir --llvm_dir ../../llvm +``` + +The result should be: +``` +[[502.9141 499.7761 511.35623 ... 500.9083 505.25574 511.03818] + [499.57034 494.8066 506.427 ... 492.7868 497.22513 509.95612] + [511.2017 516.017 513.631 ... 515.5991 515.6389 521.8318 ] + ... + [496.2721 496.3155 506.08054 ... 502.36798 505.94202 516.3577 ] + [512.06866 505.80127 518.81934 ... 510.64966 510.10333 531.85364] + [501.23514 500.17123 505.71808 ... 496.4447 500.5735 514.4204 ]] +[[503.26013 500.11093 511.70193 ... 501.24622 505.60373 511.38376] + [499.89877 495.13043 506.762 ... 493.1151 497.5555 510.29483] + [511.54883 516.35547 513.9717 ... 515.944 515.9865 522.1828 ] + ... + [496.59937 496.63785 506.41483 ... 502.70337 506.27927 516.6994 ] + [512.4154 506.1411 519.17175 ... 510.9929 510.45322 532.2152 ] + [501.57388 500.5093 506.06213 ... 496.7807 500.91638 514.77124]] +MLIR equal to NumPy? True +``` + +As the tensorcore doesn't support fp32 computation, the operands are converted to tf32, hence the result is not exactly the same as the PyTorch result. + +### Profiling +You need to install nsight compute first. 
+``` +ncu -o profile-result --set full python run-module-gpu.py --source matmul.mlir --target matmul-cubin.mlir --llvm_dir ../../llvm +``` \ No newline at end of file diff --git a/examples/BuddyGPU/makefile b/examples/BuddyGPU/makefile new file mode 100644 index 0000000000..5dbd9c25cd --- /dev/null +++ b/examples/BuddyGPU/makefile @@ -0,0 +1,22 @@ +#!/bin/bash +BUDDY_OPT := ../../build/bin/buddy-opt +MLIR_OPT := ../../llvm/build/bin/mlir-opt +MLIR_TRANSLATE := ../../llvm/build/bin/mlir-translate +MLIR_CPU_RUNNER := ../../llvm/build/bin/mlir-cpu-runner +LLC := ../../llvm/build/bin/llc + +buddy-gpu-matmul-lower: + @${BUDDY_OPT} matmul.mlir \ + -transform-preload-library="transform-library-paths=transform.mlir" \ + -transform-interpreter="entry-point=codegen" \ + -o log.mlir + +buddy-gpu-matmul: + @${BUDDY_OPT} matmul.mlir -transform-preload-library="transform-library-paths=transform.mlir" -transform-interpreter="entry-point=codegen" | \ + ${BUDDY_OPT} --pass-pipeline='builtin.module(func.func(nvgpu-optimize-shared-memory))' | \ + ${BUDDY_OPT} -arith-expand -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -linalg-bufferize -convert-linalg-to-affine-loops -affine-loop-fusion -affine-parallelize -lower-affine -canonicalize -func-bufferize -arith-bufferize -tensor-bufferize -buffer-deallocation -finalizing-bufferize -canonicalize | \ + ${BUDDY_OPT} -gpu-launch-sink-index-computations -canonicalize -legalize-shmem-outlining -canonicalize | \ + ${BUDDY_OPT} -convert-memcpy-to-gpu -gpu-async-region -canonicalize | \ + ${BUDDY_OPT} -convert-scf-to-cf -memref-expand -finalize-memref-to-llvm -convert-arith-to-llvm --convert-vector-to-llvm -convert-gpu-to-nvvm='has-redux=1' | \ + ${BUDDY_OPT} -llvm-request-c-wrappers -canonicalize -cse -sccp | \ + ${MLIR_OPT} --test-lower-to-nvvm="cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=fatbin" -o matmul-cubin.mlir diff --git a/examples/BuddyGPU/matmul.mlir b/examples/BuddyGPU/matmul.mlir new file mode 100644 index 
0000000000..2f0fa226c1 --- /dev/null +++ b/examples/BuddyGPU/matmul.mlir @@ -0,0 +1,12 @@ +!unit = f32 +!lhs = tensor<5376x2048x!unit> +!rhs = tensor<2048x5376x!unit> +!res = tensor<5376x5376x!unit> + +func.func @matmul(%arg0: !lhs, %arg1: !rhs) -> !res { + %cst = arith.constant 0.000000e+00 : !unit + %0 = tensor.empty() : !res + %1 = linalg.fill ins(%cst : !unit) outs(%0 : !res) -> !res + %2 = linalg.matmul ins(%arg0, %arg1: !lhs, !rhs) outs(%1: !res) -> !res + func.return %2 : !res +} diff --git a/examples/BuddyGPU/run-module-gpu.py b/examples/BuddyGPU/run-module-gpu.py new file mode 100644 index 0000000000..a7b210b379 --- /dev/null +++ b/examples/BuddyGPU/run-module-gpu.py @@ -0,0 +1,145 @@ +# ===- run-module-gpu.py --------------------------------------------------===// +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ===----------------------------------------------------------------------===// +# +# This file is a script to test whether the specified MLIR module on the GPU +# calculates the same result as NumPy. 
+# +# ===----------------------------------------------------------------------===// + +import mlir.dialects.func as func +from mlir.passmanager import * +from mlir.execution_engine import * +from mlir import runtime as rt +from mlir.ir import * +import numpy as np +import ctypes +import argparse as ap + + +def to_numpy(element_type: str) -> np.dtype: + match element_type: + case "f16": + return np.float16 + case "f32": + return np.float32 + case "f64": + return np.float64 + case "i8": + return np.int8 + case "i16": + return np.int16 + case "i32": + return np.int32 + case "i64": + return np.int64 + case "bf16": + return np.dtype("bfloat16") + case _: + raise ValueError(f"Unsupported type: {element_type}") + + +def new_ranked_memref_descriptor(nparray: np.ndarray): + if nparray.dtype == "bfloat16": + ctp = rt.F16 + else: + ctp = rt.as_ctype(nparray.dtype) + + if nparray.ndim == 0: + x = rt.make_zero_d_memref_descriptor(ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + return x + + x = rt.make_nd_memref_descriptor(nparray.ndim, ctp)() + nbytes = nparray.nbytes + buffer = ctypes.create_string_buffer(nbytes) + ctypes.memmove(buffer, nparray.ctypes.data, nbytes) + x.allocated = ctypes.cast(buffer, ctypes.c_void_p).value + x.aligned = ctypes.cast(buffer, ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + x.shape = nparray.ctypes.shape + + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. 
+ strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t( + *[x // nparray.itemsize for x in nparray.strides] + ) + return x + + +def get_memref_descriptors(args: list[Type]): + memref_ptrs = [] + for arg in args: + elem_type = to_numpy(str(arg.element_type)) + np_arg = np.random.rand(*arg.shape).astype(elem_type) + memref_ptrs.append( + ctypes.pointer(ctypes.pointer(new_ranked_memref_descriptor(np_arg))) + ) + return memref_ptrs + + +def test(source, target, llvm_dir): + with Context() as ctx: + file = open(source, "r") + module: Module = Module.parse(file.read()) + funcOp: func.FuncOp = ( + module.operation.regions[0].blocks[0].operations[0] + ) + funcName = str(funcOp.name).replace('"', "") + assert isinstance(funcOp, func.FuncOp) + args_type: list[Type] = [arg.type for arg in funcOp.arguments] + res_type = funcOp.type.results + + file = open(target, "r") + # newModule = lower_to_llvm_cpu(module) + newModule = Module.parse(file.read()) + memref_ptrs = get_memref_descriptors(res_type + args_type) + + engine = ExecutionEngine( + newModule, + shared_libs=[ + llvm_dir + "/build/lib/libomp.so", + llvm_dir + "/build/lib/libmlir_c_runner_utils.so", + llvm_dir + "/build/lib/libmlir_async_runtime.so", + llvm_dir + "/build/lib/libmlir_runner_utils.so", + llvm_dir + "/build/lib/libmlir_cuda_runtime.so", + ], + opt_level=3, + ) + engine.invoke(funcName, *memref_ptrs) + out = rt.ranked_memref_to_numpy(memref_ptrs[0][0]) + if str(res_type[0].element_type) == "bf16": + print("Running on BF16 mode, skipping numpy comparison.") + else: + print(out) + input1 = rt.ranked_memref_to_numpy(memref_ptrs[1][0]) + input2 = rt.ranked_memref_to_numpy(memref_ptrs[2][0]) + numpy_out = np.matmul(input1, input2) + print(numpy_out) + print( + f"MLIR equal to NumPy? 
{np.allclose(out, numpy_out,rtol=1e-03, atol=1e-03)}" + ) + + +if __name__ == "__main__": + parser = ap.ArgumentParser() + parser.add_argument("--source", type=str, required=True) + parser.add_argument("--target", type=str, required=True) + parser.add_argument("--llvm_dir", type=str, required=True) + args = parser.parse_args() + test(args.source, args.target, args.llvm_dir) diff --git a/examples/BuddyGPU/transform.mlir b/examples/BuddyGPU/transform.mlir new file mode 100644 index 0000000000..e2a02a9a97 --- /dev/null +++ b/examples/BuddyGPU/transform.mlir @@ -0,0 +1,311 @@ +module attributes { transform.with_named_sequence } { + transform.named_sequence @codegen(%arg0: !transform.any_op) { + // Match the target operations and assign them to SSA values. + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %fill = transform.structured.match ops{["linalg.fill"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + + // Perform tiling for the grid. + // For the matrix multiplication of 5376x2048 and 2048x5376, the compilation + // strategy sets the tile size for grid-based partitioning to 128x256. + // This means that each [128, 2048] @ [2048, 256] matmul tile is computed within a GPU block, + // while multiple such blocks are computed in parallel across the grid. + // `tile_sizes` specify the dimensions of the tiled matmul result. + // `%tiled_op` is the tiled matmul operation within the `scf.forall` loop. + // `%forall_op` is the `scf.forall` loop that maintains tile information. + %tiled_op, %forall_op = transform.structured.tile_using_forall %matmul + tile_sizes [128, 256] (mapping = [#gpu.block, #gpu.block]) + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Perform canonicalization. 
+ %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %1 : !transform.any_op + %all_loops = transform.structured.match interface{LoopLikeInterface} + in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops : !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.linalg.tiling_canonicalization + } : !transform.any_op + + // Fuse the fill operation into the scf.all op. + %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %fill into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Further tile the tiled matmul + // Tile the third dimension in matmul. + // [128, 2048] @ [2048, 256] matmul is further tiled into [128, 16] @ [16, 256] matmul. + %tiled_linalg_op, %loops = transform.structured.tile_using_for %tiled_op [0, 0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Create pad op and prepare for mapping to GPU. + // Nothing has changed in the operation. + %padded, %pad, %copy = transform.structured.pad %tiled_linalg_op {copy_back_op = "none", pack_paddings = [1, 1, 1], pad_to_multiple_of = [1, 1, 1], padding_dimensions = [0, 1, 2], padding_values = [0.000000e+00 : f32, 0.000000e+00 : f32, 0.000000e+00 : f32]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + // Rewrite tensor.pad into linalg.copy. 
+ %3 = transform.get_producer_of_operand %padded[0] : (!transform.any_op) -> !transform.any_op + %4 = transform.get_producer_of_operand %padded[1] : (!transform.any_op) -> !transform.any_op + %5 = transform.get_producer_of_operand %padded[2] : (!transform.any_op) -> !transform.any_op + %6 = transform.structured.rewrite_in_destination_passing_style %3 : (!transform.any_op) -> !transform.any_op + %7 = transform.structured.rewrite_in_destination_passing_style %4 : (!transform.any_op) -> !transform.any_op + %8 = transform.structured.rewrite_in_destination_passing_style %5 : (!transform.any_op) -> !transform.any_op + + // Tile the linalg.copy op and map it to GPU thread level, + // such that the tiled matrix are copied to GPU shared memory. + // num_threads is different from tile_sizes used above, + // as it specifies the number of tile instead of the size of the tile. + // The first transform tile the [128, 16] into [4, 4], + // and the second transform tile the [16, 256] into [2, 16]. + %tiled_op_0, %forall_op_1 = transform.structured.tile_using_forall %6 num_threads [32, 4](mapping = [#gpu.thread, #gpu.thread]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %tiled_op_2, %forall_op_3 = transform.structured.tile_using_forall %7 num_threads [8, 16](mapping = [#gpu.thread, #gpu.thread]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile the linalg.matmul op and map it to GPU warp level. + %tiled_op_4, %forall_op_5 = transform.structured.tile_using_forall %padded num_threads [2, 2](mapping = [#gpu.warp, #gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // Tile the linalg.fill op and map it to GPU warp level. + %tiled_op_6, %forall_op_7 = transform.structured.tile_using_forall %fused_op num_threads [2, 2](mapping = [#gpu.warp, #gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Perform canonicalization. 
+ %9 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %9 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %9 : !transform.any_op + %all_loops_2 = transform.structured.match interface{LoopLikeInterface} + in %9 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_2 : !transform.any_op + transform.apply_patterns to %9 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Perform vectorization. + // Vectorize the linalg.copy, linalg.fill, and linalg.matmul operations. + %10 = transform.structured.vectorize_children_and_apply_patterns %9 : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. + transform.apply_patterns to %10 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %10 : !transform.any_op + %all_loops_3 = transform.structured.match interface{LoopLikeInterface} + in %10 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_3 : !transform.any_op + transform.apply_patterns to %10 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Match bufferization.alloc_tensors inside the forall op + %scf_forall = transform.structured.match ops{["scf.forall"]} attributes{mapping = [#gpu.block, #gpu.block]} in %arg0 : (!transform.any_op) -> !transform.any_op + %alloc_tensor_ops = transform.structured.match ops{["bufferization.alloc_tensor"]} in %scf_forall : (!transform.any_op) -> !transform.any_op + + // Bufferize the 
alloc_tensor ops to memref.alloc ops. + // The memory_space attribute for GPU Dialect 0 means global memory, 3 means workgroup memory address, 5 means private memory address. + // According to https://discourse.llvm.org/t/rfc-memref-memory-shape-as-attribute/2229 + %buffer, %new_ops = transform.structured.bufferize_to_allocation %alloc_tensor_ops {memory_space = 3 } : !transform.any_op + + // Eliminate empty tensors and erase unnecessary inputs. + transform.structured.eliminate_empty_tensors %arg0 : !transform.any_op + %func_eras = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_eras { + transform.apply_patterns.linalg.erase_unnecessary_inputs + } : !transform.any_op + + // Bufferize the remaining operations in one time. + %11 = transform.bufferization.one_shot_bufferize %arg0 { bufferize_function_boundaries = true, function_boundary_type_conversion = 1 : i32} : (!transform.any_op) -> !transform.any_op + + // Erase dead alloc and stores. + %12 = transform.structured.match ops{["func.func"]} in %11 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %12 : (!transform.any_op) -> () + + // Generate GPU launch. + %13 = transform.structured.match ops{["func.func"]} in %11 : (!transform.any_op) -> !transform.any_op + %gpu_launch = transform.gpu.map_forall_to_blocks %13 { generate_gpu_launch } : (!transform.any_op) -> !transform.any_op + + // Rewrite bufferized scf.forall ops to distributed gpu.thread_id attribute. + %mapped = transform.gpu.map_nested_forall_to_threads %gpu_launch block_dims = [64, 2, 1] warp_size = 32 : (!transform.any_op) -> !transform.any_op + + %15 = transform.structured.match ops{["func.func"]} in %11 : (!transform.any_op) -> !transform.any_op + + // Removes unnecessary GPU barriers from the function. + // %15 = transform.buddy.eliminate_gpu_barriers %14 : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. 
+ transform.apply_patterns to %15 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %15 : !transform.any_op + %all_loops_4 = transform.structured.match interface{LoopLikeInterface} + in %15 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_4 : !transform.any_op + transform.apply_patterns to %15 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Identify static memory allocations within the given region, + // and move them to a higher level (hoisting). + transform.buddy.hoist_static_alloc %15 : (!transform.any_op) -> () + + // Collects patterns for folding memref aliasing ops (memref.subview) into consumer load/store ops (affine.load, memref.load, nvgpu.ldmatrix, vector.load, vector.transfer_read, affine.store, memref.store, etc.) and other ops (e.g., memref.subview). + transform.apply_patterns to %15 { + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + // Collects patterns for extracting address computations from operations with memory accesses such that these memory accesses use only a base pointer. + transform.apply_patterns to %15 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op + // Perform canonicalization. 
+ transform.apply_patterns to %15 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %15 : !transform.any_op + %all_loops_5 = transform.structured.match interface{LoopLikeInterface} + in %15 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_5 : !transform.any_op + transform.apply_patterns to %15 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Adds patterns that unroll vectors to a native tile size for GPUs with mma operations + transform.apply_patterns to %15 { + transform.apply_patterns.buddy.unroll_vectors_gpu_mma_sync + } : !transform.any_op + + // Insert a gpu.barrier after a given scf.for loop + %16 = transform.structured.match ops{["scf.for"]} in %15 : (!transform.any_op) -> !transform.op<"scf.for"> + // transform.buddy.synchronize_loop %16 : (!transform.op<"scf.for">) -> () + + + transform.apply_patterns to %15 { + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + transform.apply_cse to %15 : !transform.any_op + + // Hoist vector.transfer_read / vector.transfer_write pairs out of immediately enclosing scf::ForOp iteratively + // Warning: Deprecated + %17 = transform.structured.hoist_redundant_vector_transfers %15 : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. 
+ transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + %all_loops_6 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_6 : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // This converts slices of operations containing vector.contract op into + // mma operations, targetting warp level tensorcore operations. + transform.buddy.vector.vector_to_mma_conversion %17 {use_mma_sync} : (!transform.any_op) -> () + + // %18 = transform.buddy.eliminate_gpu_barriers %17 : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + %all_loops_7 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_7 : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + %19 = transform.structured.match ops{["gpu.launch"]} in %17 : (!transform.any_op) -> !transform.any_op + %fwfa = transform.structured.match ops{["memref.alloc"]} in %19 : (!transform.any_op) -> !transform.op<"memref.alloc"> + + // Do multi-buffering/array expansion to remove dependencies on the temporary allocation between consecutive loop iterations. 
+ transform.memref.multibuffer %fwfa {factor = 3 : i64, skip_analysis} : (!transform.op<"memref.alloc">) -> !transform.any_op + + transform.apply_patterns to %17 { + transform.apply_patterns.vector.transfer_to_scf full_unroll = true + } : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + %all_loops_8 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_8 : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.vector.lower_masked_transfers + } : !transform.any_op + + // Convert sync copies to shared memory to async. + // transform.buddy.create_async_groups %17 {use_mma_sync} : (!transform.any_op) -> () + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + %all_loops_9 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_9 : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + + + %20 = transform.structured.match ops{["nvgpu.mma.sync"]} in %17 : (!transform.any_op) -> !transform.any_op + %21 = transform.get_parent_op %20 {deduplicate, op_name = "scf.for"} : (!transform.any_op) -> !transform.any_op + // This applies software pipelining to a given scf.for loop. + // The pipelining strategy will look for a copy to shared memory and pipeline it to overlap it with the rest of the loop. 
+ // %22 = transform.buddy.pipeline_shared_memory_copies %21 {depth = 3 : i64, use_mma_sync, peel_epilogue} : (!transform.any_op) -> !transform.any_op + + // Perform canonicalization. + transform.apply_patterns to %17 { + transform.apply_patterns.vector.lower_masks + } : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.vector.materialize_masks + } : !transform.any_op + transform.apply_patterns to %17 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + + %all_loops_10 = transform.structured.match interface{LoopLikeInterface} + in %17 + : (!transform.any_op) -> !transform.any_op + transform.apply_licm to %all_loops_10 : !transform.any_op + transform.apply_cse to %17 : !transform.any_op + + transform.yield + } +} // module diff --git a/examples/BuddyGen/.gitignore b/examples/BuddyGen/.gitignore new file mode 100644 index 0000000000..df9389428a --- /dev/null +++ b/examples/BuddyGen/.gitignore @@ -0,0 +1,4 @@ +log.mlir +log.ll +log.s +a.out diff --git a/examples/BuddyGen/GenMemRef.cpp b/examples/BuddyGen/GenMemRef.cpp new file mode 100644 index 0000000000..8ca2526b79 --- /dev/null +++ b/examples/BuddyGen/GenMemRef.cpp @@ -0,0 +1,43 @@ +//===- GenMemRef.cpp ------------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// + +// $ export LLVM_DIR=$PWD/../../llvm/ +// $ export LLVM_BUILD_DIR=$LLVM_DIR/build +// $ c++ GenMemRef.cpp \ + -I $LLVM_DIR/llvm/include/ -I $LLVM_BUILD_DIR/include/ \ + -I $LLVM_DIR/mlir/include/ -I $LLVM_BUILD_DIR/tools/mlir/include/ \ + -L$LLVM_BUILD_DIR/lib -lMLIRIR -lMLIRParser -lMLIRSupport -lLLVMCore \ + -lLLVMSupport -lncurses -ltinfo -lstdc++ -lLLVMDemangle \ + -o a.out +// $ ./a.out + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" + +int main() { + mlir::MLIRContext context; + mlir::OpBuilder builder(&context); + mlir::Type eleType = builder.getF64Type(); + // Target memref type: + // `memref>` + mlir::MemRefType memrefType = mlir::MemRefType::get( + {mlir::ShapedType::kDynamic}, eleType, + mlir::StridedLayoutAttr::get( + &context, /*offset=*/mlir::ShapedType::kDynamic, /*strides=*/{1})); + memrefType.dump(); + return 0; +} diff --git a/examples/BuddyLeNet/CMakeLists.txt b/examples/BuddyLeNet/CMakeLists.txt index 9698f617bc..b765218c68 100644 --- a/examples/BuddyLeNet/CMakeLists.txt +++ b/examples/BuddyLeNet/CMakeLists.txt @@ -6,25 +6,26 @@ add_custom_command( add_custom_command( OUTPUT forward.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" | - ${LLVM_MLIR_BINARY_DIR}/mlir-opt + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt -pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), 
eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/forward.o + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/forward.o DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir COMMENT "Building forward.o" VERBATIM) add_custom_command( OUTPUT subgraph0.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | ${BUDDY_BINARY_DIR}/buddy-opt -eliminate-empty-tensors - -convert-tensor-to-linalg + -convert-tensor-to-linalg -linalg-bufferize + -batchmatmul-optimize -convert-linalg-to-affine-loops -lower-affine -func-bufferize-dynamic-offset @@ -42,9 +43,9 @@ add_custom_command( -convert-arith-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/subgraph0.o + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/subgraph0.o DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir COMMENT "Building 
subgraph0.o" VERBATIM) @@ -54,7 +55,8 @@ add_library(LENET STATIC subgraph0.o forward.o) SET_TARGET_PROPERTIES(LENET PROPERTIES LINKER_LANGUAGE C) add_executable(buddy-lenet-run buddy-lenet-main.cpp) -target_link_directories(buddy-lenet-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR}) +target_link_directories(buddy-lenet-run PRIVATE ${LLVM_LIBRARY_DIR}) + +set(BUDDY_LENET_LIBS LENET mlir_c_runner_utils ${PNG_LIBRARIES}) -set(BUDDY_LENET_LIBS LENET mlir_c_runner_utils ${OpenCV_LIBS}) target_link_libraries(buddy-lenet-run ${BUDDY_LENET_LIBS}) diff --git a/examples/BuddyLeNet/README.md b/examples/BuddyLeNet/README.md index 5988edbe7b..b9b0c44a5f 100644 --- a/examples/BuddyLeNet/README.md +++ b/examples/BuddyLeNet/README.md @@ -25,8 +25,7 @@ $ cmake -G Ninja .. \ -DCMAKE_BUILD_TYPE=RELEASE \ -DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \ -DPython3_EXECUTABLE=$(which python3) \ - -DBUDDY_ENABLE_OPENCV=ON \ - -DOpenCV_DIR= + -DBUDDY_MLIR_ENABLE_DIP_LIB=ON $ ninja $ ninja check-buddy ``` diff --git a/examples/BuddyLeNet/buddy-lenet-main.cpp b/examples/BuddyLeNet/buddy-lenet-main.cpp index 4e2dc2efe0..2fc8b0fbe3 100644 --- a/examples/BuddyLeNet/buddy-lenet-main.cpp +++ b/examples/BuddyLeNet/buddy-lenet-main.cpp @@ -15,41 +15,24 @@ //===----------------------------------------------------------------------===// #include -#include +#include #include +#include #include #include #include #include -#include #include #include #include constexpr size_t ParamsSize = 44426; -const std::string ImgName = "3.png"; +const std::string ImgName = "1-28*28.png"; /// Declare LeNet forward function. extern "C" void _mlir_ciface_forward(MemRef *output, MemRef *arg0, - Img *input); - -/// Function for preprocessing the image to match model input requirements. -const cv::Mat imagePreprocessing() { - // Get the directory of the LeNet example and construct the image path. 
- std::string lenetDir = getenv("LENET_EXAMPLE_PATH"); - std::string imgPath = lenetDir + "/images/" + ImgName; - // Read the image in grayscale mode. - cv::Mat inputImage = cv::imread(imgPath, cv::IMREAD_GRAYSCALE); - assert(!inputImage.empty() && "Could not read the image."); - cv::Mat resizedImage; - int imageWidth = 28; - int imageHeight = 28; - // Resize the image to 28x28 pixels. - cv::resize(inputImage, resizedImage, cv::Size(imageWidth, imageHeight), - cv::INTER_LINEAR); - return resizedImage; -} + dip::Image *input); /// Print [Log] label in bold blue format. void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } @@ -112,19 +95,16 @@ int main() { const std::string title = "LeNet Inference Powered by Buddy Compiler"; std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; - // Preprocess the image to match the input requirements of the model. - cv::Mat image = imagePreprocessing(); - - // Define the sizes of the input and output tensors. - intptr_t sizesInput[4] = {1, 1, 28, 28}; + // Define the sizes of the output tensors. intptr_t sizesOutput[2] = {1, 10}; // Create input and output containers for the image and model output. - Img input(image, sizesInput, true); + std::string lenetDir = getenv("LENET_EXAMPLE_PATH"); + std::string imgPath = lenetDir + "/images/" + ImgName; + dip::Image input(imgPath, dip::DIP_GRAYSCALE, true /* norm */); MemRef output(sizesOutput); // Load model parameters from the specified file. 
- std::string lenetDir = getenv("LENET_EXAMPLE_PATH"); std::string paramsDir = lenetDir + "/arg0.data"; MemRef paramsContainer({ParamsSize}); loadParameters(paramsDir, paramsContainer); diff --git a/examples/BuddyLeNet/fake-lenet.mlir b/examples/BuddyLeNet/fake-lenet.mlir index 48d91a7fd3..d7d80a533a 100644 --- a/examples/BuddyLeNet/fake-lenet.mlir +++ b/examples/BuddyLeNet/fake-lenet.mlir @@ -1,5 +1,6 @@ module { func.func private @printMemrefF32(%ptr : tensor<*xf32>) + func.func private @rtclock() -> f64 func.func @forward(%arg0: tensor<44426xf32>, %arg1: tensor<1x1x28x28xf32>) -> tensor<1x10xf32> { %extracted_slice = tensor.extract_slice %arg0[0] [150] [1] : tensor<44426xf32> to tensor<150xf32> @@ -81,10 +82,16 @@ module { %fake_params = arith.constant dense<1.0> : tensor<44426xf32> %fake_input = arith.constant dense<2.0> : tensor<1x1x28x28xf32> + %t_start = call @rtclock() : () -> f64 %fake_output = call @forward(%fake_params, %fake_input) : (tensor<44426xf32>, tensor<1x1x28x28xf32>) -> tensor<1x10xf32> + %t_end = call @rtclock() : () -> f64 %tensor_unranked = tensor.cast %fake_output : tensor<1x10xf32> to tensor<*xf32> call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> () + + %time = arith.subf %t_end, %t_start : f64 + vector.print %time : f64 + return } } diff --git a/examples/BuddyLeNet/images/0-28*28.png b/examples/BuddyLeNet/images/0-28*28.png new file mode 100644 index 0000000000..a7a3b2a327 Binary files /dev/null and b/examples/BuddyLeNet/images/0-28*28.png differ diff --git a/examples/BuddyLeNet/images/1-28*28.png b/examples/BuddyLeNet/images/1-28*28.png new file mode 100644 index 0000000000..0f25e8b026 Binary files /dev/null and b/examples/BuddyLeNet/images/1-28*28.png differ diff --git a/examples/BuddyLeNet/images/8-16bit-565-28*28.bmp b/examples/BuddyLeNet/images/8-16bit-565-28*28.bmp new file mode 100644 index 0000000000..d4a43393d3 Binary files /dev/null and b/examples/BuddyLeNet/images/8-16bit-565-28*28.bmp differ diff --git 
a/examples/BuddyLeNet/images/8-24bit-28*28.bmp b/examples/BuddyLeNet/images/8-24bit-28*28.bmp new file mode 100644 index 0000000000..6591e87be8 Binary files /dev/null and b/examples/BuddyLeNet/images/8-24bit-28*28.bmp differ diff --git a/examples/BuddyLeNet/images/8.bmp b/examples/BuddyLeNet/images/8.bmp new file mode 100644 index 0000000000..7a9e02a295 Binary files /dev/null and b/examples/BuddyLeNet/images/8.bmp differ diff --git a/examples/BuddyLeNet/makefile b/examples/BuddyLeNet/makefile index 6f06642728..fe87b6da1a 100644 --- a/examples/BuddyLeNet/makefile +++ b/examples/BuddyLeNet/makefile @@ -1,30 +1,33 @@ #!/bin/bash -BUDDY_OPT := ../../build/bin/buddy-opt -MLIR_OPT := ../../llvm/build/bin/mlir-opt -MLIR_TRANSLATE := ../../llvm/build/bin/mlir-translate -MLIR_CPU_RUNNER := ../../llvm/build/bin/mlir-cpu-runner -LLC := ../../llvm/build/bin/llc -OPT_FLAG := -O0 +BUDDY_BUILD_DIR := ../../build/ +LLVM_BUILD_DIR := ../../llvm/build/ +BUDDY_OPT := ${BUDDY_BUILD_DIR}/bin/buddy-opt +MLIR_OPT := ${LLVM_BUILD_DIR}/bin/mlir-opt +MLIR_TRANSLATE := ${LLVM_BUILD_DIR}/bin/mlir-translate +MLIR_CPU_RUNNER := ${LLVM_BUILD_DIR}/bin/mlir-cpu-runner +LLC := ${LLVM_BUILD_DIR}/bin/llc +OPT_FLAG := -O3 ifeq ($(shell uname),Linux) -MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.so -MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.so -MLIR_ASYNC_RUNTIME := ../../llvm/build/lib/libmlir_async_runtime.so +MLIR_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_runner_utils.so +MLIR_C_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_c_runner_utils.so +MLIR_ASYNC_RUNTIME := ${LLVM_BUILD_DIR}/lib/libmlir_async_runtime.so MTRIPLE := x86_64-unknown-linux-gnu else ifeq ($(shell uname),Darwin) -MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.dylib -MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.dylib -MLIR_ASYNC_RUNTIME := ./../llvm/build/lib/libmlir_async_runtime.dylib +MLIR_RUNNER_UTILS := 
${LLVM_BUILD_DIR}/lib/libmlir_runner_utils.dylib +MLIR_C_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_c_runner_utils.dylib +MLIR_ASYNC_RUNTIME := ${LLVM_BUILD_DIR}/lib/libmlir_async_runtime.dylib MTRIPLE := x86_64-apple-darwin endif buddy-lenet-lower: - @${MLIR_OPT} ./fake-lenet.mlir \ + @${BUDDY_OPT} ./fake-lenet.mlir \ -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | \ - ${MLIR_OPT} \ + ${BUDDY_OPT} \ -eliminate-empty-tensors \ -convert-tensor-to-linalg \ -linalg-bufferize \ + -batchmatmul-optimize \ -convert-linalg-to-affine-loops \ -lower-affine \ -func-bufferize \ @@ -38,16 +41,15 @@ buddy-lenet-lower: -convert-arith-to-llvm \ -finalize-memref-to-llvm \ -convert-scf-to-cf \ - -llvm-request-c-wrappers \ -convert-arith-to-llvm \ -convert-func-to-llvm \ -reconcile-unrealized-casts \ -o ./log.mlir buddy-lenet-translate: - @${MLIR_OPT} ./fake-lenet.mlir \ + @${BUDDY_OPT} ./fake-lenet.mlir \ -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | \ - ${MLIR_OPT} \ + ${BUDDY_OPT} \ -eliminate-empty-tensors \ -convert-tensor-to-linalg \ -linalg-bufferize \ @@ -64,7 +66,6 @@ buddy-lenet-translate: -convert-arith-to-llvm \ -finalize-memref-to-llvm \ -convert-scf-to-cf \ - -llvm-request-c-wrappers \ -convert-arith-to-llvm \ -convert-func-to-llvm \ -reconcile-unrealized-casts | \ @@ -72,9 +73,9 @@ buddy-lenet-translate: buddy-lenet-run: - @${MLIR_OPT} ./fake-lenet.mlir \ + @${BUDDY_OPT} ./fake-lenet.mlir \ -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | \ - ${MLIR_OPT} \ + ${BUDDY_OPT} \ -eliminate-empty-tensors \ -convert-tensor-to-linalg \ -linalg-bufferize \ @@ -91,7 +92,33 @@ buddy-lenet-run: -convert-arith-to-llvm \ -finalize-memref-to-llvm \ -convert-scf-to-cf \ - -llvm-request-c-wrappers \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + 
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +buddy-lenet-opt-run: + @${BUDDY_OPT} ./fake-lenet.mlir \ + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | \ + ${BUDDY_OPT} \ + -eliminate-empty-tensors \ + -convert-tensor-to-linalg \ + -linalg-bufferize \ + -batchmatmul-optimize \ + -convert-linalg-to-affine-loops \ + -lower-affine \ + -func-bufferize \ + -arith-bufferize \ + -tensor-bufferize \ + -buffer-deallocation \ + -finalizing-bufferize \ + -convert-vector-to-scf \ + -expand-strided-metadata \ + -convert-vector-to-llvm \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-scf-to-cf \ -convert-arith-to-llvm \ -convert-func-to-llvm \ -reconcile-unrealized-casts | \ diff --git a/examples/BuddyLlama/CMakeLists.txt b/examples/BuddyLlama/CMakeLists.txt index 97aa736cb7..a6bfc2f742 100644 --- a/examples/BuddyLlama/CMakeLists.txt +++ b/examples/BuddyLlama/CMakeLists.txt @@ -6,14 +6,14 @@ add_custom_command( add_custom_command( OUTPUT forward.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/forward.mlir + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/forward.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | ${BUDDY_BINARY_DIR}/buddy-opt -arith-expand -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize - -matmul-paralell-vectorization-optimize + -matmul-parallel-vectorization-optimize -batchmatmul-optimize -convert-linalg-to-affine-loops -affine-loop-fusion @@ -40,9 +40,9 @@ add_custom_command( -convert-math-to-libm -convert-func-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj 
-relocation-model=pic -O3 + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLlama/forward.o DEPENDS buddy-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/forward.mlir COMMENT "Building forward.o " @@ -50,14 +50,14 @@ add_custom_command( add_custom_command( OUTPUT subgraph.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/subgraph0.mlir + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/subgraph0.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | ${BUDDY_BINARY_DIR}/buddy-opt -arith-expand -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize - -matmul-paralell-vectorization-optimize + -matmul-parallel-vectorization-optimize -batchmatmul-optimize -convert-linalg-to-affine-loops -affine-loop-fusion @@ -85,9 +85,9 @@ add_custom_command( -convert-math-to-libm -convert-func-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLlama/subgraph.o DEPENDS buddy-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/subgraph0.mlir COMMENT "Building subgraph.o " @@ -107,7 +107,7 @@ SET_TARGET_PROPERTIES( LINKER_LANGUAGE C) add_executable(buddy-llama-run llama-main.cpp) -target_link_directories(buddy-llama-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR}) +target_link_directories(buddy-llama-run PRIVATE ${LLVM_LIBRARY_DIR}) set(BUDDY_LLAMA_LIBS LLAMA diff --git a/examples/BuddyLlama/import-llama2.py b/examples/BuddyLlama/import-llama2.py 
index fbd12e5bff..2903d6bd81 100644 --- a/examples/BuddyLlama/import-llama2.py +++ b/examples/BuddyLlama/import-llama2.py @@ -1,11 +1,3 @@ -import os -import torch -import torch._dynamo as dynamo -from transformers import LlamaForCausalLM, LlamaTokenizer -from torch._inductor.decomposition import decompositions as inductor_decomp -import numpy - -from buddy.compiler.frontend import DynamoCompiler # ===- import-llama2.py -------------------------------------------------------- # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,6 +17,15 @@ # This is the test of llama2 model. # # ===--------------------------------------------------------------------------- + +import os +import torch +import torch._dynamo as dynamo +from transformers import LlamaForCausalLM, LlamaTokenizer +from torch._inductor.decomposition import decompositions as inductor_decomp +import numpy + +from buddy.compiler.frontend import DynamoCompiler from buddy.compiler.ops import tosa from buddy.compiler.graph import GraphDriver from buddy.compiler.graph.transform import simply_fuse diff --git a/examples/BuddyMatmul/.gitignore b/examples/BuddyMatmul/.gitignore new file mode 100644 index 0000000000..80a243fa81 --- /dev/null +++ b/examples/BuddyMatmul/.gitignore @@ -0,0 +1 @@ +log.* diff --git a/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir b/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir new file mode 100644 index 0000000000..58c9142398 --- /dev/null +++ b/examples/BuddyMatmul/linalg-batchmatmul-f32.mlir @@ -0,0 +1,82 @@ +// RUN: buddy-opt %s \ +// RUN: -batchmatmul-optimize \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -lower-affine \ +// RUN: -convert-vector-to-scf \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -expand-strided-metadata \ +// RUN: -finalize-memref-to-llvm \ +// RUN: 
-reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func.func private @printMemrefF32(memref<*xf32>) + +func.func @batch_matmul(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.batch_matmul + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + return +} + +func.func @alloc_f32(%arg0: index, %arg1: index, %arg2: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + scf.for %idx2 = %c0 to %arg2 step %c1 { + memref.store %arg4, %0[%idx0, %idx1, %idx2] : memref + } + } + } + return %0 : memref +} + +func.func @main(){ + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c576 = arith.constant 576 : index + %c1024 = arith.constant 1024 : index + %c1000 = arith.constant 1000 : index + %f0 = arith.constant 0.0 : f32 + %f2 = arith.constant 2.0 : f32 + %f3 = arith.constant 3.0 : f32 + + %m0 = call @alloc_f32(%c1, %c1, %c576, %f2) : (index, index, index, f32) -> memref + %m1 = call @alloc_f32(%c1, %c576, %c1024, %f3) : (index, index, index, f32) -> memref + %m2 = call @alloc_f32(%c1, %c1, %c1024, %f0) : (index, index, index, f32) -> memref + + call @batch_matmul(%m0, %m1, %m2) : (memref, memref, memref) -> () + + %printed_m2 = memref.cast %m2 : memref to memref<*xf32> + + // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1024] strides = [1024, 1024, 1] data = + // CHECK-NEXT: [ + // CHECK: [ + // CHECK: [3456{{(, 3456)*}}] + call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> () + + %m3 = call @alloc_f32(%c1, %c1, %c1024, %f2) : (index, index, index, f32) -> memref + %m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : 
(index, index, index, f32) -> memref + %m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref + + call @batch_matmul(%m3, %m4, %m5) : (memref, memref, memref) -> () + + %printed_m5 = memref.cast %m5 : memref to memref<*xf32> + + // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data = + // CHECK-NEXT: [ + // CHECK: [ + // CHECK: [6144{{(, 6144)*}}] + call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> () + + return +} diff --git a/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir b/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir new file mode 100644 index 0000000000..26a4458c53 --- /dev/null +++ b/examples/BuddyMatmul/linalg-transposematmulb-f32.mlir @@ -0,0 +1,75 @@ +// RUN: buddy-opt %s \ +// RUN: -matmul-transpose-b-vectorization \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -lower-affine \ +// RUN: -convert-vector-to-scf \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -expand-strided-metadata \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func.func private @printMemrefF32(memref<*xf32>) + +func.func @test(%a : memref, %b : memref, %c : memref) { + linalg.matmul_transpose_b + ins(%a, %b: memref, memref) + outs(%c: memref) + return + } + +func.func @alloc_f32(%arg0: index, %arg1: index, %arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1) : memref + scf.for %idx0 = %c0 to %arg0 step %c1 { + scf.for %idx1 = %c0 to %arg1 step %c1 { + memref.store %arg4, 
%0[%idx0, %idx1] : memref + } + } + return %0 : memref +} + +func.func @main(){ + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %c3 = arith.constant 3 : index + %f0 = arith.constant 0.0 : f32 + %f1 = arith.constant 1.0 : f32 + + %m0 = call @alloc_f32(%c32,%c1024, %f1) : (index, index, f32) -> memref + %m1 = call @alloc_f32(%c32,%c1024, %f1) : (index, index, f32) -> memref + %m2 = call @alloc_f32(%c32,%c32, %f0) : (index, index, f32) -> memref + + call @test(%m0, %m1, %m2) : (memref, memref, memref) -> () + + %printed_m2 = memref.cast %m2 : memref to memref<*xf32> + + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [32, 32] strides = [32, 1] data = + // CHECK-NEXT: [ + // CHECK: [1024{{(, 1024)*}}] + call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> () + + %m3 = call @alloc_f32(%c3,%c3, %f1) : (index, index, f32) -> memref + %m4 = call @alloc_f32(%c3,%c3, %f1) : (index, index, f32) -> memref + %m5 = call @alloc_f32(%c3,%c3, %f0) : (index, index, f32) -> memref + + call @test(%m3, %m4, %m5) : (memref, memref, memref) -> () + + %printed_m5 = memref.cast %m5 : memref to memref<*xf32> + + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] data = + // CHECK-NEXT: [ + // CHECK: [3{{(, 3)*}}] + call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> () + + return +} diff --git a/examples/BuddyMatmul/makefile b/examples/BuddyMatmul/makefile new file mode 100644 index 0000000000..0940d608da --- /dev/null +++ b/examples/BuddyMatmul/makefile @@ -0,0 +1,55 @@ +#!/bin/bash +BUDDY_BUILD_DIR := ../../build/ +LLVM_BUILD_DIR := ../../llvm/build/ +BUDDY_OPT := ${BUDDY_BUILD_DIR}/bin/buddy-opt +MLIR_OPT := ${LLVM_BUILD_DIR}/bin/mlir-opt +MLIR_TRANSLATE := ${LLVM_BUILD_DIR}/bin/mlir-translate +MLIR_CPU_RUNNER := ${LLVM_BUILD_DIR}/bin/mlir-cpu-runner +LLC := ${LLVM_BUILD_DIR}/bin/llc +OPT_FLAG := -O0 + +ifeq ($(shell uname),Linux) +MLIR_RUNNER_UTILS := 
${LLVM_BUILD_DIR}/lib/libmlir_runner_utils.so +MLIR_C_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_c_runner_utils.so +MTRIPLE := x86_64-unknown-linux-gnu +else ifeq ($(shell uname),Darwin) +MLIR_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_runner_utils.dylib +MLIR_C_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_c_runner_utils.dylib +MTRIPLE := x86_64-apple-darwin +endif + +linalg-batchmatmul-f32-run: + @${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \ + -batchmatmul-optimize \ + -convert-linalg-to-affine-loops \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -expand-strided-metadata \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-matmul-transpose-b-f32-run: + @${BUDDY_OPT} ./linalg-transposematmulb-f32.mlir\ + -matmul-transpose-b-vectorization \ + -convert-linalg-to-affine-loops \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -expand-strided-metadata \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/examples/BuddyMobileNetV3/CMakeLists.txt b/examples/BuddyMobileNetV3/CMakeLists.txt index e55cc61711..ef60c7e931 100644 --- a/examples/BuddyMobileNetV3/CMakeLists.txt +++ b/examples/BuddyMobileNetV3/CMakeLists.txt @@ -1,6 +1,5 @@ add_custom_command( OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/arg0.data - ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/arg1.data ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/forward.mlir 
${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/subgraph0.mlir COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/buddy-mobilenetv3-import.py @@ -10,21 +9,21 @@ add_custom_command( add_custom_command( OUTPUT forward.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/forward.mlir + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/forward.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), \ empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, \ func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" | - ${LLVM_MLIR_BINARY_DIR}/mlir-opt + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt -pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), \ eliminate-empty-tensors, func.func(llvm-request-c-wrappers), \ convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, \ convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, \ convert-func-to-llvm, reconcile-unrealized-casts)" | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 -o ${BUDDY_BINARY_DIR}/../examples/BuddyMobileNetV3/forward.o DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/forward.mlir COMMENT "Building forward.o" @@ -55,9 +54,9 @@ add_custom_command( -expand-strided-metadata -finalize-memref-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc 
-filetype=obj -relocation-model=pic -O3 -o ${BUDDY_BINARY_DIR}/../examples/BuddyMobileNetV3/subgraph0.o DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/subgraph0.mlir buddy-opt @@ -69,7 +68,7 @@ add_library(MOBILENETV3 STATIC subgraph0.o forward.o) SET_TARGET_PROPERTIES(MOBILENETV3 PROPERTIES LINKER_LANGUAGE C) add_executable(buddy-mobilenetv3-run buddy-mobilenetv3-main.cpp) -target_link_directories(buddy-mobilenetv3-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR}) +target_link_directories(buddy-mobilenetv3-run PRIVATE ${LLVM_LIBRARY_DIR}) -set(BUDDY_MOBILENETV3_LIBS MOBILENETV3 mlir_c_runner_utils ${OpenCV_LIBS}) +set(BUDDY_MOBILENETV3_LIBS MOBILENETV3 mlir_c_runner_utils BuddyLibDIP ${PNG_LIBRARIES}) target_link_libraries(buddy-mobilenetv3-run ${BUDDY_MOBILENETV3_LIBS}) diff --git a/examples/BuddyMobileNetV3/Labels.txt b/examples/BuddyMobileNetV3/Labels.txt index fe811239d8..8bdc20a086 100644 --- a/examples/BuddyMobileNetV3/Labels.txt +++ b/examples/BuddyMobileNetV3/Labels.txt @@ -1,4 +1,3 @@ -background tench goldfish great white shark @@ -133,7 +132,7 @@ flamingo little blue heron American egret bittern -crane +crane bird limpkin European gallinule American coot @@ -638,7 +637,7 @@ magnetic compass mailbag mailbox maillot -maillot +maillot tank suit manhole cover maraca marimba diff --git a/examples/BuddyMobileNetV3/README.md b/examples/BuddyMobileNetV3/README.md index 1146addb69..a55cd74304 100644 --- a/examples/BuddyMobileNetV3/README.md +++ b/examples/BuddyMobileNetV3/README.md @@ -16,8 +16,8 @@ $ cmake -G Ninja .. 
\ -DCMAKE_BUILD_TYPE=RELEASE \ -DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \ -DPython3_EXECUTABLE=$(which python3) \ - -DBUDDY_ENABLE_OPENCV=ON \ - -DOpenCV_DIR= + -DBUDDY_MLIR_ENABLE_DIP_LIB=ON \ + -DBUDDY_ENABLE_PNG=ON $ ninja $ ninja check-buddy ``` diff --git a/examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py b/examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py index 2403800bf9..704b8fc2e3 100644 --- a/examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py +++ b/examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py @@ -38,9 +38,17 @@ "The environment variable 'MOBILENETV3_MODEL_PATH' is not set or is invalid." ) -model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1, pretrained=True) +model = models.mobilenet_v3_small( + weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1, pretrained=True +) model = model.eval() +# Remove the num_batches_tracked attribute. +for layer in model.modules(): + if isinstance(layer, torch.nn.BatchNorm2d): + if hasattr(layer, "num_batches_tracked"): + del layer.num_batches_tracked + # Initialize Dynamo Compiler with specific configurations as an importer. 
dynamo_compiler = DynamoCompiler( primary_registry=tosa.ops_registry, @@ -68,11 +76,10 @@ float32_param = np.concatenate( - [param.detach().numpy().reshape([-1]) for param in params if param.dtype == torch.float32] + [ + param.detach().numpy().reshape([-1]) + for param in params + if param.dtype == torch.float32 + ] ) float32_param.tofile(Path(current_path) / "arg0.data") - -int64_param = np.concatenate( - [param.detach().numpy().reshape([-1]) for param in params if param.dtype == torch.int64] -) -int64_param.tofile(Path(current_path) / "arg1.data") diff --git a/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp b/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp index 68d9d15411..90defb895e 100644 --- a/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp +++ b/examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp @@ -1,4 +1,4 @@ -//===- MobileNetBenchmark.cpp ---------------------------------------------===// +//===- buddy-mobilenetv3-main.cpp -----------------------------------------===// // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,13 +15,14 @@ //===----------------------------------------------------------------------===// #include -#include +#include +#include #include +#include #include #include #include #include -#include #include #include #include @@ -31,61 +32,44 @@ const std::string ImgName = "dog.png"; // Declare the mobilenet C interface. extern "C" void _mlir_ciface_forward(MemRef *output, - MemRef *arg0, - MemRef *arg1, - Img *input); - -const cv::Mat imagePreprocessing() { - // Get the directory of the LeNet example and construct the image path. - std::string mobilenetDir = getenv("MOBILENETV3_EXAMPLE_PATH"); - std::string imgPath = mobilenetDir + "/images/" + ImgName; - // Read the image in grayscale mode. 
- cv::Mat inputImage = cv::imread(imgPath, cv::IMREAD_GRAYSCALE); - assert(!inputImage.empty() && "Could not read the image."); - cv::Mat resizedImage; - int imageWidth = 224; - int imageHeight = 224; - // Resize the image to 224x224 pixels. - cv::resize(inputImage, resizedImage, cv::Size(imageWidth, imageHeight), - cv::INTER_LINEAR); - return resizedImage; -} + MemRef *arg0, + MemRef *input); /// Print [Log] label in bold blue format. void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } -void loadParameters(const std::string &floatParamPath, - const std::string &int64ParamPath, - MemRef &floatParam, - MemRef &int64Param) { - std::ifstream floatParamFile(floatParamPath, std::ios::in | std::ios::binary); - if (!floatParamFile.is_open()) { - std::string errMsg = "Failed to open float param file: " + - std::filesystem::canonical(floatParamPath).string(); - throw std::runtime_error(errMsg); - } - floatParamFile.read(reinterpret_cast(floatParam.getData()), - floatParam.getSize() * sizeof(float)); - if (floatParamFile.fail()) { - throw std::runtime_error("Failed to read float param file"); - } - floatParamFile.close(); - - - std::ifstream int64ParamFile(int64ParamPath, std::ios::in | std::ios::binary); - if (!int64ParamFile.is_open()) { - std::string errMsg = "Failed to open int64 param file: " + - std::filesystem::canonical(int64ParamPath).string(); - throw std::runtime_error(errMsg); +/// Load parameters into data container. +void loadParameters(const std::string ¶mFilePath, + MemRef ¶ms) { + const auto loadStart = std::chrono::high_resolution_clock::now(); + // Open the parameter file in binary mode. 
+ std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary); + if (!paramFile.is_open()) { + throw std::runtime_error("[Error] Failed to open params file!"); } - int64ParamFile.read(reinterpret_cast(int64Param.getData()), - int64Param.getSize() * sizeof(long long)); - if (int64ParamFile.fail()) { - throw std::runtime_error("Failed to read int64 param file"); + printLogLabel(); + std::cout << "Loading params..." << std::endl; + printLogLabel(); + // Print the canonical path of the parameter file. + std::cout << "Params file: " << std::filesystem::canonical(paramFilePath) + << std::endl; + // Read the parameter data into the provided memory reference. + paramFile.read(reinterpret_cast(params.getData()), + sizeof(float) * (params.getSize())); + if (paramFile.fail()) { + throw std::runtime_error("Error occurred while reading params file!"); } - int64ParamFile.close(); + paramFile.close(); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "Params load time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; } + // Softmax function. void softmax(float *input, size_t size) { size_t i; @@ -110,8 +94,7 @@ void softmax(float *input, size_t size) { std::string getLabel(int idx) { std::string mobilenetDir = getenv("MOBILENETV3_EXAMPLE_PATH"); - std::ifstream in( - mobilenetDir + "Labels.txt"); + std::ifstream in(mobilenetDir + "Labels.txt"); assert(in.is_open() && "Could not read the label file."); std::string label; for (int i = 0; i < idx; ++i) @@ -126,27 +109,26 @@ int main() { const std::string title = "MobileNetV3 Inference Powered by Buddy Compiler"; std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; - // Preprocess the image to match the input requirements of the model. - cv::Mat image = imagePreprocessing(); - // Define the sizes of the input and output tensors. 
- intptr_t sizesInput[4] = {1, 3, 224, 224}; intptr_t sizesOutput[2] = {1, 1000}; // Create input and output containers for the image and model output. - Img input(image, sizesInput, true); + std::string mobilenetDir = getenv("MOBILENETV3_EXAMPLE_PATH"); + std::string imgPath = mobilenetDir + "/images/" + ImgName; + dip::Image input(imgPath, dip::DIP_RGB, true /* norm */); + MemRef inputResize = dip::Resize4D_NCHW( + &input, dip::INTERPOLATION_TYPE::BILINEAR_INTERPOLATION, + {1, 3, 224, 224} /*{image_cols, image_rows}*/); + MemRef output(sizesOutput); // Load model parameters from the specified file. - std::string mobilenetDir = getenv("MOBILENETV3_EXAMPLE_PATH"); std::string paramsDir = mobilenetDir + "/arg0.data"; - std::string intDir = mobilenetDir + "/arg1.data"; - MemRef paramsContainerf32({ParamsSize}); - MemRef ParamsContainerInt64({34}); - loadParameters(paramsDir, intDir, paramsContainerf32, ParamsContainerInt64); + MemRef paramsContainer({ParamsSize}); + loadParameters(paramsDir, paramsContainer); // Call the forward function of the model. - _mlir_ciface_forward(&output, ¶msContainerf32, &ParamsContainerInt64, &input); - + _mlir_ciface_forward(&output, ¶msContainer, &inputResize); + auto out = output.getData(); softmax(out, 1000); // Find the classification and print the result. 
diff --git a/examples/BuddyMobileNetV3/images/curtain-224*224.png b/examples/BuddyMobileNetV3/images/curtain-224*224.png new file mode 100644 index 0000000000..2fa9c06449 Binary files /dev/null and b/examples/BuddyMobileNetV3/images/curtain-224*224.png differ diff --git a/examples/BuddyMobileNetV3/images/curtain.png b/examples/BuddyMobileNetV3/images/curtain.png index 1ae383d359..67a54dbdde 100644 Binary files a/examples/BuddyMobileNetV3/images/curtain.png and b/examples/BuddyMobileNetV3/images/curtain.png differ diff --git a/examples/BuddyMobileNetV3/images/dog-224*224.png b/examples/BuddyMobileNetV3/images/dog-224*224.png new file mode 100644 index 0000000000..4c6649714c Binary files /dev/null and b/examples/BuddyMobileNetV3/images/dog-224*224.png differ diff --git a/examples/BuddyMobileNetV3/images/dog-32bit_224*224.bmp b/examples/BuddyMobileNetV3/images/dog-32bit_224*224.bmp new file mode 100644 index 0000000000..201f030d7c Binary files /dev/null and b/examples/BuddyMobileNetV3/images/dog-32bit_224*224.bmp differ diff --git a/examples/BuddyMobileNetV3/images/dog.bmp b/examples/BuddyMobileNetV3/images/dog.bmp new file mode 100644 index 0000000000..12f0e0dd11 Binary files /dev/null and b/examples/BuddyMobileNetV3/images/dog.bmp differ diff --git a/examples/BuddyMobileNetV3/images/dog.png b/examples/BuddyMobileNetV3/images/dog.png index 12f0e0dd11..4c6000a1fa 100644 Binary files a/examples/BuddyMobileNetV3/images/dog.png and b/examples/BuddyMobileNetV3/images/dog.png differ diff --git a/examples/BuddyMobileNetV3/images/ice-cream-224*224.png b/examples/BuddyMobileNetV3/images/ice-cream-224*224.png new file mode 100644 index 0000000000..1cd06efd4e Binary files /dev/null and b/examples/BuddyMobileNetV3/images/ice-cream-224*224.png differ diff --git a/examples/BuddyMobileNetV3/images/ice-cream-24bit-224*224.bmp b/examples/BuddyMobileNetV3/images/ice-cream-24bit-224*224.bmp new file mode 100644 index 0000000000..75ad4012e0 Binary files /dev/null and 
b/examples/BuddyMobileNetV3/images/ice-cream-24bit-224*224.bmp differ diff --git a/examples/BuddyMobileNetV3/images/ice-cream.png b/examples/BuddyMobileNetV3/images/ice-cream.png index 209d8999d6..9bb408cea7 100644 Binary files a/examples/BuddyMobileNetV3/images/ice-cream.png and b/examples/BuddyMobileNetV3/images/ice-cream.png differ diff --git a/examples/BuddyMobileNetV3/images/kite.png b/examples/BuddyMobileNetV3/images/kite.png index 23ffe9613d..51912cddc6 100644 Binary files a/examples/BuddyMobileNetV3/images/kite.png and b/examples/BuddyMobileNetV3/images/kite.png differ diff --git a/examples/BuddyMobileNetV3/images/traffic-light-24bit-224*224.bmp b/examples/BuddyMobileNetV3/images/traffic-light-24bit-224*224.bmp new file mode 100644 index 0000000000..948a1ea796 Binary files /dev/null and b/examples/BuddyMobileNetV3/images/traffic-light-24bit-224*224.bmp differ diff --git a/examples/BuddyMobileNetV3/images/traffic-light-32bit-224*224.bmp b/examples/BuddyMobileNetV3/images/traffic-light-32bit-224*224.bmp new file mode 100644 index 0000000000..c415c8dc32 Binary files /dev/null and b/examples/BuddyMobileNetV3/images/traffic-light-32bit-224*224.bmp differ diff --git a/examples/BuddyMobileNetV3/images/traffic-light.png b/examples/BuddyMobileNetV3/images/traffic-light.png index fa1a1e3f61..3fa00918da 100644 Binary files a/examples/BuddyMobileNetV3/images/traffic-light.png and b/examples/BuddyMobileNetV3/images/traffic-light.png differ diff --git a/examples/BuddyNext/makefile b/examples/BuddyNext/makefile index 7e93591409..443907d352 100644 --- a/examples/BuddyNext/makefile +++ b/examples/BuddyNext/makefile @@ -164,3 +164,69 @@ next-attention-fusion-run: -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +next-sigmoid-run: + @${MLIR_OPT} ./next-sigmoid.mlir \ + -pass-pipeline 
"builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \ + ${MLIR_OPT} \ + -arith-expand \ + -eliminate-empty-tensors \ + -empty-tensor-to-alloc-tensor \ + -one-shot-bufferize \ + -convert-linalg-to-affine-loops \ + -affine-loop-fusion \ + -lower-affine \ + -func-bufferize \ + -arith-bufferize \ + -tensor-bufferize \ + -buffer-deallocation \ + -finalizing-bufferize \ + -convert-vector-to-scf \ + -expand-strided-metadata \ + -convert-vector-to-llvm \ + -memref-expand \ + -arith-expand \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-scf-to-cf \ + -convert-openmp-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +next-rope-run: + @${MLIR_OPT} ./next-rope.mlir \ + -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \ + ${MLIR_OPT} \ + -arith-expand \ + -eliminate-empty-tensors \ + -empty-tensor-to-alloc-tensor \ + -one-shot-bufferize \ + -convert-linalg-to-affine-loops \ + -affine-loop-fusion \ + -lower-affine \ + -func-bufferize \ + -arith-bufferize \ + -tensor-bufferize \ + -buffer-deallocation \ + -finalizing-bufferize \ + -convert-vector-to-scf \ + -expand-strided-metadata \ + -convert-vector-to-llvm \ + -memref-expand \ + -arith-expand \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-scf-to-cf \ + -convert-openmp-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git 
a/examples/BuddyNext/next-rope.mlir b/examples/BuddyNext/next-rope.mlir new file mode 100644 index 0000000000..091b2c220f --- /dev/null +++ b/examples/BuddyNext/next-rope.mlir @@ -0,0 +1,157 @@ +// RUN: buddy-opt %s \ +// RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \ +// RUN: | buddy-opt \ +// RUN: -arith-expand \ +// RUN: -eliminate-empty-tensors \ +// RUN: -empty-tensor-to-alloc-tensor \ +// RUN: -one-shot-bufferize \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -affine-loop-fusion \ +// RUN: -lower-affine \ +// RUN: -func-bufferize \ +// RUN: -arith-bufferize \ +// RUN: -tensor-bufferize \ +// RUN: -buffer-deallocation \ +// RUN: -finalizing-bufferize \ +// RUN: -convert-vector-to-scf \ +// RUN: -expand-strided-metadata \ +// RUN: -convert-vector-to-llvm \ +// RUN: -memref-expand \ +// RUN: -arith-expand \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-openmp-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func.func private @rtclock() -> f64 + +#map = affine_map<(d0, d1, d2) -> (d1)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#map3 = affine_map<(d0, d1) -> (d0, d1)> +#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map5 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map6 = affine_map<(d0, d1, d2) -> (d0, 0, d1, d2)> +#map7 = affine_map<(d0, d1) -> (0, d0, d1)> + +func.func @kenerl(%arg0 : tensor<1x40x4096xf32>, %arg1 : tensor<1x40x4096xf32>, %arg2 : 
tensor<1x40x4096xf32>, %arg3 : tensor<1x1x2048x128xf32>, %arg4 : tensor<1x1x2048x128xf32>, %arg5 : tensor<1x40xi64>) { + %t_start = call @rtclock() : () -> f64 + + %57 = tosa.reshape %arg0 {new_shape = array} : (tensor<1x40x4096xf32>) -> tensor<1x40x32x128xf32> + %58 = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + %59 = tosa.transpose %57, %58 : (tensor<1x40x32x128xf32>, tensor<4xi32>) -> tensor<1x32x40x128xf32> + + %60 = tosa.reshape %arg1 {new_shape = array} : (tensor<1x40x4096xf32>) -> tensor<1x40x32x128xf32> + %61 = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + %62 = tosa.transpose %60, %61 : (tensor<1x40x32x128xf32>, tensor<4xi32>) -> tensor<1x32x40x128xf32> + + %63 = tosa.reshape %arg2 {new_shape = array} : (tensor<1x40x4096xf32>) -> tensor<1x40x32x128xf32> + %64 = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + %65 = tosa.transpose %63, %64 : (tensor<1x40x32x128xf32>, tensor<4xi32>) -> tensor<1x32x40x128xf32> + + %extracted_slice_9 = tensor.extract_slice %arg3[0, 0, 0, 0] [1, 1, 2048, 128] [1, 1, 1, 1] : tensor<1x1x2048x128xf32> to tensor<1x1x2048x128xf32> + %extracted_slice_10 = tensor.extract_slice %extracted_slice_9[0, 0, 0, 0] [1, 1, 2048, 128] [1, 1, 1, 1] : tensor<1x1x2048x128xf32> to tensor<1x1x2048x128xf32> + %extracted_slice_11 = tensor.extract_slice %extracted_slice_10[0, 0, 0, 0] [1, 1, 40, 128] [1, 1, 1, 1] : tensor<1x1x2048x128xf32> to tensor<1x1x40x128xf32> + %extracted_slice_12 = tensor.extract_slice %arg4[0, 0, 0, 0] [1, 1, 2048, 128] [1, 1, 1, 1] : tensor<1x1x2048x128xf32> to tensor<1x1x2048x128xf32> + %extracted_slice_13 = tensor.extract_slice %extracted_slice_12[0, 0, 0, 0] [1, 1, 2048, 128] [1, 1, 1, 1] : tensor<1x1x2048x128xf32> to tensor<1x1x2048x128xf32> + %extracted_slice_14 = tensor.extract_slice %extracted_slice_13[0, 0, 0, 0] [1, 1, 40, 128] [1, 1, 1, 1] : tensor<1x1x2048x128xf32> to tensor<1x1x40x128xf32> + %66 
= tensor.empty() : tensor<1x40x128xf32> + %67 = linalg.generic {indexing_maps = [#map6, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_11 : tensor<1x1x40x128xf32>) outs(%66 : tensor<1x40x128xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x40x128xf32> + %68 = tensor.empty() : tensor<40x128xf32> + %69 = linalg.generic {indexing_maps = [#map7, #map3], iterator_types = ["parallel", "parallel"]} ins(%67 : tensor<1x40x128xf32>) outs(%68 : tensor<40x128xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<40x128xf32> + %70 = tensor.empty() : tensor<1x40x128xf32> + %71 = linalg.generic {indexing_maps = [#map6, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_14 : tensor<1x1x40x128xf32>) outs(%70 : tensor<1x40x128xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x40x128xf32> + %72 = tensor.empty() : tensor<40x128xf32> + %73 = linalg.generic {indexing_maps = [#map7, #map3], iterator_types = ["parallel", "parallel"]} ins(%71 : tensor<1x40x128xf32>) outs(%72 : tensor<40x128xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<40x128xf32> + // precompute_theta_pos_frequencies function, which is used to calculating special values ​​of RoPE according to: https://hyper.ai/wiki/29220 + %74 = tensor.empty() : tensor<1x40x128xf32> + %75 = linalg.generic {indexing_maps = [#map2, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg5 : tensor<1x40xi64>) outs(%74 : tensor<1x40x128xf32>) { + ^bb0(%in: i64, %out: f32): + %4175 = arith.index_cast %in : i64 to index + %4176 = linalg.index 2 : index + %extracted = tensor.extract %69[%4175, %4176] : tensor<40x128xf32> + linalg.yield %extracted : f32 + } -> tensor<1x40x128xf32> + %76 = tosa.reshape %75 {new_shape = array} : (tensor<1x40x128xf32>) -> tensor<1x1x40x128xf32> + %77 = tensor.empty() : tensor<1x40x128xf32> + %78 = linalg.generic 
{indexing_maps = [#map2, #map5], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg5 : tensor<1x40xi64>) outs(%77 : tensor<1x40x128xf32>) { + ^bb0(%in: i64, %out: f32): + %4175 = arith.index_cast %in : i64 to index + %4176 = linalg.index 2 : index + %extracted = tensor.extract %73[%4175, %4176] : tensor<40x128xf32> + linalg.yield %extracted : f32 + } -> tensor<1x40x128xf32> + %79 = tosa.reshape %78 {new_shape = array} : (tensor<1x40x128xf32>) -> tensor<1x1x40x128xf32> + %80 = tosa.mul %59, %76 {shift = 0 : i8} : (tensor<1x32x40x128xf32>, tensor<1x1x40x128xf32>) -> tensor<1x32x40x128xf32> + %extracted_slice_15 = tensor.extract_slice %59[0, 0, 0, 0] [1, 32, 40, 64] [1, 1, 1, 1] : tensor<1x32x40x128xf32> to tensor<1x32x40x64xf32> + %extracted_slice_16 = tensor.extract_slice %59[0, 0, 0, 64] [1, 32, 40, 64] [1, 1, 1, 1] : tensor<1x32x40x128xf32> to tensor<1x32x40x64xf32> + %81 = tosa.negate %extracted_slice_16 : (tensor<1x32x40x64xf32>) -> tensor<1x32x40x64xf32> + %82 = tensor.empty() : tensor<1x32x40x128xf32> + %inserted_slice = tensor.insert_slice %81 into %82[0, 0, 0, 0] [1, 32, 40, 64] [1, 1, 1, 1] : tensor<1x32x40x64xf32> into tensor<1x32x40x128xf32> + %inserted_slice_17 = tensor.insert_slice %extracted_slice_15 into %inserted_slice[0, 0, 0, 64] [1, 32, 40, 64] [1, 1, 1, 1] : tensor<1x32x40x64xf32> into tensor<1x32x40x128xf32> + %83 = tosa.mul %inserted_slice_17, %79 {shift = 0 : i8} : (tensor<1x32x40x128xf32>, tensor<1x1x40x128xf32>) -> tensor<1x32x40x128xf32> + %84 = tosa.add %80, %83 : (tensor<1x32x40x128xf32>, tensor<1x32x40x128xf32>) -> tensor<1x32x40x128xf32> + %85 = tosa.mul %62, %76 {shift = 0 : i8} : (tensor<1x32x40x128xf32>, tensor<1x1x40x128xf32>) -> tensor<1x32x40x128xf32> + %extracted_slice_18 = tensor.extract_slice %62[0, 0, 0, 0] [1, 32, 40, 64] [1, 1, 1, 1] : tensor<1x32x40x128xf32> to tensor<1x32x40x64xf32> + %extracted_slice_19 = tensor.extract_slice %62[0, 0, 0, 64] [1, 32, 40, 64] [1, 1, 1, 1] : tensor<1x32x40x128xf32> to 
tensor<1x32x40x64xf32> + %86 = tosa.negate %extracted_slice_19 : (tensor<1x32x40x64xf32>) -> tensor<1x32x40x64xf32> + %87 = tensor.empty() : tensor<1x32x40x128xf32> + %inserted_slice_20 = tensor.insert_slice %86 into %87[0, 0, 0, 0] [1, 32, 40, 64] [1, 1, 1, 1] : tensor<1x32x40x64xf32> into tensor<1x32x40x128xf32> + %inserted_slice_21 = tensor.insert_slice %extracted_slice_18 into %inserted_slice_20[0, 0, 0, 64] [1, 32, 40, 64] [1, 1, 1, 1] : tensor<1x32x40x64xf32> into tensor<1x32x40x128xf32> + + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %tensor_unranked = tensor.cast %inserted_slice_21 : tensor<1x32x40x128xf32> to tensor<*xf32> + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. + // CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 32, 40, 128] strides = [163840, 5120, 128, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-SAME: [ + // CHECK-SAME: [-3{{(, [-]?3)*}}], + + // Print results. + call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> () + // Print timings. 
+ vector.print %time : f64 + + return +} + +func.func @main() { + + %c2 = arith.constant dense<2.0> : tensor<1x40x4096xf32> + %c3 = arith.constant dense<3.0> : tensor<1x40x4096xf32> + %c4 = arith.constant dense<4.0> : tensor<1x40x4096xf32> + %c5 = arith.constant dense<5.0> : tensor<1x1x2048x128xf32> + %c6 = arith.constant dense<6.0> : tensor<1x1x2048x128xf32> + %c7 = arith.constant dense<7> : tensor<1x40xi64> + + call @kenerl(%c2, %c3, %c4, %c5, %c6, %c7) : (tensor<1x40x4096xf32>, tensor<1x40x4096xf32>, tensor<1x40x4096xf32>, tensor<1x1x2048x128xf32>, tensor<1x1x2048x128xf32>, tensor<1x40xi64>) -> () + + return +} +func.func private @printMemrefF32(%ptr : tensor<*xf32>) diff --git a/examples/BuddyNext/next-sigmoid.mlir b/examples/BuddyNext/next-sigmoid.mlir new file mode 100644 index 0000000000..f49f2d7943 --- /dev/null +++ b/examples/BuddyNext/next-sigmoid.mlir @@ -0,0 +1,70 @@ +// RUN: buddy-opt %s \ +// RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \ +// RUN: | buddy-opt \ +// RUN: -arith-expand \ +// RUN: -eliminate-empty-tensors \ +// RUN: -empty-tensor-to-alloc-tensor \ +// RUN: -one-shot-bufferize \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -affine-loop-fusion \ +// RUN: -lower-affine \ +// RUN: -func-bufferize \ +// RUN: -arith-bufferize \ +// RUN: -tensor-bufferize \ +// RUN: -buffer-deallocation \ +// RUN: -finalizing-bufferize \ +// RUN: -convert-vector-to-scf \ +// RUN: -expand-strided-metadata \ +// RUN: -convert-vector-to-llvm \ +// RUN: -memref-expand \ +// RUN: -arith-expand \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-openmp-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ 
+// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +func.func private @rtclock() -> f64 + +func.func @kenerl(%arg0 : tensor<1x40x11008xf32>) { + %t_start = call @rtclock() : () -> f64 + + %sigmoid = tosa.sigmoid %arg0 : (tensor<1x40x11008xf32>) -> tensor<1x40x11008xf32> + + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %tensor_unranked = tensor.cast %sigmoid : tensor<1x40x11008xf32> to tensor<*xf32> + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. + // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 40, 11008] strides = [440320, 11008, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [ + // CHECK-SAME: [0.952574{{(, 0.952574)*}}], + + // Print results. + call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> () + // Print timings. + vector.print %time : f64 + + return +} + +func.func @main() { + + %c3 = arith.constant dense<3.0> : tensor<1x40x11008xf32> + + call @kenerl(%c3) : (tensor<1x40x11008xf32>) -> () + + return +} +func.func private @printMemrefF32(%ptr : tensor<*xf32>) diff --git a/examples/BuddyPython/module_gen.py b/examples/BuddyPython/module_gen.py index e2c722cebf..1f657d2609 100644 --- a/examples/BuddyPython/module_gen.py +++ b/examples/BuddyPython/module_gen.py @@ -43,12 +43,11 @@ def foo(x, y): aot_autograd_decomposition=inductor_decomp, ) -# Pass the function and input data to the dynamo compiler's importer, the -# importer will first build a graph. Then, lower the graph to top-level IR. +# Pass the function and input data to the dynamo compiler's importer, the +# importer will first build a graph. Then, lower the graph to top-level IR. # (tosa, linalg, etc.). Finally, accepts the generated module and weight parameters. 
-graphs = dynamo_compiler.importer(foo, *(float32_in1, float32_in2)) +graphs = dynamo_compiler.importer(foo, float32_in1, float32_in2) graph = graphs[0] -graph.lower_to_top_level_ir(do_params_pack=True) +graph.lower_to_top_level_ir() print(graph._imported_module) -print(dynamo_compiler.imported_params[graph]) diff --git a/examples/BuddyWhisper/CMakeLists.txt b/examples/BuddyWhisper/CMakeLists.txt index 16518ffb62..756d6db081 100644 --- a/examples/BuddyWhisper/CMakeLists.txt +++ b/examples/BuddyWhisper/CMakeLists.txt @@ -6,22 +6,22 @@ add_custom_command( set(PATTERN_ARG "test-generalize-pad-tensor") add_custom_command( OUTPUT forward.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyWhisper/forward.mlir + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyWhisper/forward.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" | ${BUDDY_BINARY_DIR}/buddy-opt - -pass-pipeline "builtin.module( func.func(buffer-deallocation-simplification, convert-linalg-to-loops),matmul-paralell-vectorization-optimize, batchmatmul-optimize, eliminate-empty-tensors,func-bufferize-dynamic-offset, func.func(llvm-request-c-wrappers),convert-scf-to-openmp, convert-openmp-to-llvm, convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyWhisper/forward.o + -pass-pipeline "builtin.module( func.func(buffer-deallocation-simplification, convert-linalg-to-loops),matmul-parallel-vectorization-optimize, batchmatmul-optimize, 
eliminate-empty-tensors,func-bufferize-dynamic-offset, func.func(llvm-request-c-wrappers),convert-scf-to-openmp, convert-openmp-to-llvm, convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" | + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyWhisper/forward.o DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyWhisper/forward.mlir COMMENT "Building forward.o" VERBATIM) add_custom_command( OUTPUT subgraph0.o - COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyWhisper/subgraph0.mlir + COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyWhisper/subgraph0.mlir -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" | - ${LLVM_MLIR_BINARY_DIR}/mlir-opt + ${LLVM_TOOLS_BINARY_DIR}/mlir-opt -test-linalg-transform-patterns=${PATTERN_ARG} | ${BUDDY_BINARY_DIR}/buddy-opt -arith-expand @@ -29,7 +29,7 @@ add_custom_command( -convert-elementwise-to-linalg -empty-tensor-to-alloc-tensor -one-shot-bufferize - -matmul-paralell-vectorization-optimize + -matmul-parallel-vectorization-optimize -batchmatmul-optimize -convert-linalg-to-affine-loops -affine-loop-fusion @@ -55,9 +55,9 @@ add_custom_command( -convert-math-to-libm -convert-func-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llvm-as | - ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 -o ${BUDDY_BINARY_DIR}/../examples/BuddyWhisper/subgraph0.o + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llvm-as | + ${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3 -o ${BUDDY_BINARY_DIR}/../examples/BuddyWhisper/subgraph0.o 
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyWhisper/subgraph0.mlir COMMENT "Building subgraph0.o " VERBATIM) @@ -75,11 +75,16 @@ SET_TARGET_PROPERTIES( PROPERTIES LINKER_LANGUAGE C) -add_executable(buddy-whisper-run whisper-main.cpp) -target_link_directories(buddy-whisper-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR}) +set(BUDDY_WHISPER_FILES + whisper-main.cpp +) + +add_executable(buddy-whisper-run ${BUDDY_WHISPER_FILES}) +target_link_directories(buddy-whisper-run PRIVATE ${LLVM_LIBRARY_DIR}) set(BUDDY_WHISPER_LIBS WHISPER + BuddyLibDAP mlir_c_runner_utils omp ) diff --git a/examples/BuddyWhisper/README.md b/examples/BuddyWhisper/README.md index f26a1d845a..680fb34ce9 100644 --- a/examples/BuddyWhisper/README.md +++ b/examples/BuddyWhisper/README.md @@ -1,7 +1,7 @@ # Buddy Compiler WHISPER Example ## Introduction -This example shows how to use Buddy Compiler to compile a WHISPER model to MLIR code then run it. The [model](openai/whisper-base) is a pre-trained model for automatic speech recognition (ASR) and speech translation. +This example shows how to use Buddy Compiler to compile a WHISPER model to MLIR code then run it. The [model](https://huggingface.co/openai/whisper-base) is a pre-trained model for automatic speech recognition (ASR) and speech translation (ST). ## How to run @@ -63,16 +63,15 @@ $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build $ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH} ``` -3. Set model and dataset environment variable. +3. Set model environment variable. ```bash $ export WHISPER_MODEL_PATH=/path-to-whisper-model/ -$ export AUDIO_DATASET_PATH=/path-to-audio-dataset/ // For example: $ export WHISPER_MODEL_PATH=/home/xxx/whisper-base -$ export AUDIO_DATASET_PATH=/home/xxx/librispeech_asr_dummy ``` +Alternatively, you can leave the path blank, and import-whisper.py will automatically download the model for you. 4. 
Build and run the WHISPER example @@ -83,4 +82,4 @@ $ cd bin $ ./buddy-whisper-run ``` -4. Enjoy it! +5. Enjoy it! diff --git a/examples/BuddyWhisper/audio.wav b/examples/BuddyWhisper/audio.wav new file mode 100644 index 0000000000..069c2329ef Binary files /dev/null and b/examples/BuddyWhisper/audio.wav differ diff --git a/examples/BuddyWhisper/import-whisper.py b/examples/BuddyWhisper/import-whisper.py index 7b5d3681fe..449646a676 100644 --- a/examples/BuddyWhisper/import-whisper.py +++ b/examples/BuddyWhisper/import-whisper.py @@ -14,7 +14,7 @@ # # ===--------------------------------------------------------------------------- # -# This is the example of whisper model. +# This is an example for whisper model. # # ===--------------------------------------------------------------------------- @@ -22,8 +22,7 @@ import torch import torch._dynamo as dynamo from torch._inductor.decomposition import decompositions as inductor_decomp -from transformers import WhisperProcessor, WhisperForConditionalGeneration -from datasets import load_dataset +from transformers import WhisperForConditionalGeneration import numpy from buddy.compiler.frontend import DynamoCompiler @@ -34,27 +33,20 @@ # Retrieve the Whisper model path from environment variables. model_path = os.environ.get("WHISPER_MODEL_PATH") if model_path is None: - raise EnvironmentError( - "The environment variable 'WHISPER_MODEL_PATH' is not set or is invalid." - ) + model_path = "openai/whisper-base" -# Initialize the tokenizer and model from the specified model path. -processor = WhisperProcessor.from_pretrained(model_path) +# Initialize the model from the specified model path. 
model = WhisperForConditionalGeneration.from_pretrained(model_path) model.config.use_cache = False -dataset_path = os.environ.get("AUDIO_DATASET_PATH") -ds = load_dataset(dataset_path, "clean", split="validation") -sample = ds[1]["audio"] -input_features = processor( - sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt" -).input_features - -decoder_input_ids = torch.tensor([[50258] * 448], dtype=torch.long) +# Generate placeholder for inputs. +input_features = torch.zeros(size=(1, 80, 3000), dtype=torch.float32) +decoder_input_ids = torch.zeros(size=(1, 448), dtype=torch.long) inputs = { "input_features": input_features, "decoder_input_ids": decoder_input_ids, } + # Initialize Dynamo Compiler with specific configurations as an importer. dynamo_compiler = DynamoCompiler( primary_registry=tosa.ops_registry, diff --git a/examples/BuddyWhisper/input_features.data b/examples/BuddyWhisper/input_features.data deleted file mode 100644 index c85c98ddf3..0000000000 Binary files a/examples/BuddyWhisper/input_features.data and /dev/null differ diff --git a/examples/BuddyWhisper/whisper-main.cpp b/examples/BuddyWhisper/whisper-main.cpp index 2ba9138544..7d69ea3074 100644 --- a/examples/BuddyWhisper/whisper-main.cpp +++ b/examples/BuddyWhisper/whisper-main.cpp @@ -13,21 +13,29 @@ // limitations under the License. // //===----------------------------------------------------------------------===// +// +// This file implements an example for Whisper Model Inference. +// +// ------------------------------------------------------------------------===// #include +#include #include #include #include #include +#include #include #include #include + +using namespace std; using namespace buddy; +using namespace dap; constexpr size_t ParamsSize = 99148800; constexpr size_t MaxVocabSize = 51865; constexpr size_t MaxTokenLength = 448; -constexpr size_t HiddenSize = 512; /// Declare Whisper forward function. 
extern "C" void _mlir_ciface_forward(MemRef *, MemRef *, @@ -37,14 +45,6 @@ extern "C" void _mlir_ciface_forward(MemRef *, MemRef *, // Helper Functions // ----------------------------------------------------------------------------- -/// Capture input message. -void getUserInput(std::string &inputStr) { - std::cout << "\nPlease send a message:" << std::endl; - std::cout << ">>> "; - getline(std::cin, inputStr); - std::cout << std::endl; -} - /// Print [Log] label in bold blue format. void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } @@ -83,30 +83,18 @@ void loadParameters(const std::string ¶mFilePath, << std::endl; } -void loadAudio(const std::string ¶mFilePath, MemRef ¶ms) { - const auto loadStart = std::chrono::high_resolution_clock::now(); - std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary); - if (!paramFile.is_open()) { - throw std::runtime_error("[Error] Failed to open input_features file!"); - } - printLogLabel(); - std::cout << "Loading input_features..." << std::endl; +/// Conduct audio data preprocess. +void runPreprocess(dap::Audio &rawAudioContainer, + MemRef &audioFeatures) { printLogLabel(); - std::cout << "input_features file: " - << std::filesystem::canonical(paramFilePath) << std::endl; - - paramFile.read(reinterpret_cast(params.getData()), - sizeof(float) * (params.getSize())); - - if (paramFile.fail()) { - throw std::runtime_error("Error occurred while reading params file!"); - } - paramFile.close(); + std::cout << "Preprocessing audio..." 
<< std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::whisperPreprocess(&rawAudioContainer, &audioFeatures); const auto loadEnd = std::chrono::high_resolution_clock::now(); const std::chrono::duration loadTime = loadEnd - loadStart; printLogLabel(); - std::cout << "input_features load time: " << (double)(loadTime.count()) / 1000 + std::cout << "Audio preprocess time: " << (double)(loadTime.count()) / 1000 << "s\n" << std::endl; } @@ -129,14 +117,13 @@ int main() { /// Define directories of vacabulary and parameter file. const std::string vocabDir = "../../examples/BuddyWhisper/vocab.txt"; const std::string paramsDir = "../../examples/BuddyWhisper/arg0.data"; - const std::string input_featuresDir = - "../../examples/BuddyWhisper/input_features.data"; /// Initialize data containers // - Result container // - Output container. // - Parameters container. Text outputContainer; + Audio rawAudioContainer("../../examples/BuddyWhisper/audio.wav"); MemRef audioInput({1, 80, 3000}); MemRef resultContainer[2] = { MemRef({1, 1500, 512}, false, 0), @@ -148,16 +135,17 @@ int main() { /// Fill data into containers // - Output: register vocabulary. // - Parameters: load parameters from the `arg0` file into the container. + // - Input: compute audioInput. outputContainer.loadVocab(vocabDir); loadParameters(paramsDir, paramsContainer); - loadAudio(input_featuresDir, audioInput); + runPreprocess(rawAudioContainer, audioInput); /// Run Whisper Inference // - Perform the forward function. // - Find and append the generated token. // - Continue iterating until the terminal condition is met. - for (int i = 0; i < MaxTokenLength - 1; i++) { + for (size_t i = 0; i < MaxTokenLength - 1; i++) { const auto inferenceStart = std::chrono::high_resolution_clock::now(); // Execute the forward pass of the model. 
_mlir_ciface_forward(resultContainer, ¶msContainer, &audioInput, diff --git a/examples/ConvOpt/CMakeLists.txt b/examples/ConvOpt/CMakeLists.txt index 83aa26b686..e01f2b46c6 100644 --- a/examples/ConvOpt/CMakeLists.txt +++ b/examples/ConvOpt/CMakeLists.txt @@ -16,14 +16,14 @@ message(STATUS "Spliting size: ${SPLITING_SIZE}") add_custom_command(OUTPUT conv2d.o COMMAND ${CMAKE_BINARY_DIR}/bin/buddy-opt ${BUDDY_EXAMPLES_DIR}/ConvOpt/conv2d.mlir -conv-vectorization="strip-mining=${SPLITING_SIZE}" -lower-affine -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm -llvm-request-c-wrappers -convert-func-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llc -mtriple=${BUDDY_TARGET_TRIPLE} -mattr=${BUDDY_OPT_ATTR} --filetype=obj -o ${BUDDY_BINARY_DIR}/../examples/ConvOpt/conv2d.o + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate --mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llc -mtriple=${BUDDY_TARGET_TRIPLE} -mattr=${BUDDY_OPT_ATTR} --filetype=obj -o ${BUDDY_BINARY_DIR}/../examples/ConvOpt/conv2d.o DEPENDS buddy-opt) # add_custom_command(OUTPUT conv2d.o -# COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/ConvOpt/conv2d.mlir -convert-linalg-to-loops -convert-scf-to-cf -convert-linalg-to-llvm -lower-affine -convert-scf-to-cf --finalize-memref-to-llvm -convert-func-to-llvm='emit-c-wrappers=1' -reconcile-unrealized-casts | -# ${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir | -# ${LLVM_MLIR_BINARY_DIR}/llc -mtriple=${BUDDY_OPT_TRIPLE} -mattr=${BUDDY_OPT_ATTR} --filetype=obj -o ${BUDDY_BINARY_DIR}/../examples/ConvOpt/conv2d.o +# COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/ConvOpt/conv2d.mlir -convert-linalg-to-loops -convert-scf-to-cf -convert-linalg-to-llvm -lower-affine -convert-scf-to-cf --finalize-memref-to-llvm -convert-func-to-llvm='emit-c-wrappers=1' -reconcile-unrealized-casts | +# ${LLVM_TOOLS_BINARY_DIR}/mlir-translate --mlir-to-llvmir | +# 
${LLVM_TOOLS_BINARY_DIR}/llc -mtriple=${BUDDY_OPT_TRIPLE} -mattr=${BUDDY_OPT_ATTR} --filetype=obj -o ${BUDDY_BINARY_DIR}/../examples/ConvOpt/conv2d.o # DEPENDS buddy-opt) add_library(Conv2D STATIC conv2d.o) diff --git a/examples/DAPDialect/CMakeLists.txt b/examples/DAPDialect/CMakeLists.txt index b147d56047..96b921ee3a 100644 --- a/examples/DAPDialect/CMakeLists.txt +++ b/examples/DAPDialect/CMakeLists.txt @@ -20,6 +20,7 @@ add_executable(buddy-fir FIRLowpass.cpp) add_dependencies(buddy-fir buddy-opt) target_link_libraries(buddy-fir BuddyLibDAP + mlir_c_runner_utils ) #------------------------------------------------------------------------------- @@ -30,6 +31,7 @@ add_executable(buddy-biquad biquad.cpp) add_dependencies(buddy-biquad buddy-opt) target_link_libraries(buddy-biquad BuddyLibDAP + mlir_c_runner_utils ) #------------------------------------------------------------------------------- @@ -40,10 +42,30 @@ add_executable(buddy-iir-scalar IIRLowpass.cpp) add_dependencies(buddy-iir-scalar buddy-opt) target_link_libraries(buddy-iir-scalar BuddyLibDAP + mlir_c_runner_utils ) add_executable(buddy-iir-vectorization IIRVectorization.cpp) add_dependencies(buddy-iir-vectorization buddy-opt) target_link_libraries(buddy-iir-vectorization - BuddyLibDAPVectorization + BuddyLibDAP + mlir_c_runner_utils +) + +#------------------------------------------------------------------------------- +# Buddy DAP Dialect WhisperPreprocess Operation +#------------------------------------------------------------------------------- + +add_executable(buddy-whisper-preprocess WhisperPreprocess.cpp) +add_dependencies(buddy-whisper-preprocess buddy-opt) +target_link_libraries(buddy-whisper-preprocess + BuddyLibDAP + mlir_c_runner_utils +) + +add_executable(buddy-rfft RFFT.cpp) +add_dependencies(buddy-rfft buddy-opt) +target_link_libraries(buddy-rfft + BuddyLibDAP + mlir_c_runner_utils ) diff --git a/examples/DAPDialect/FIRLowpass.cpp b/examples/DAPDialect/FIRLowpass.cpp index 
cfce56091d..3a8217730a 100644 --- a/examples/DAPDialect/FIRLowpass.cpp +++ b/examples/DAPDialect/FIRLowpass.cpp @@ -14,45 +14,76 @@ // //===----------------------------------------------------------------------===// // -// This file implements an end to end example for fir filter in buddy-mlir. It -// generates coefficients for a filter and apply it on a piece of mono audio, -// then saves the audio. -// This file will be linked with the object file generated by mlir to generate -// the executable file. +// An end-to-end example of an FIR (Finite Impulse Response) operation in +// buddy-mlir. // //===----------------------------------------------------------------------===// #include +#include #include using namespace dap; using namespace std; -int main(int argc, char *argv[]) { - string fileName = "../../tests/Interface/core/NASA_Mars.wav"; - ; - string saveFileName = "FIR_NASA_Mars.wav"; - if (argc >= 2) { - fileName = argv[1]; - } - if (argc == 3) { - saveFileName = argv[2]; - } - cout << "Usage: FIRLowpass [loadPath] [savePath]" << endl; - cout << "Current specified path: \n"; - cout << "Load: " << fileName << endl; - cout << "Save: " << saveFileName << endl; +// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + +int main() { + // Print the title of this example. + const std::string title = "FIR Operation Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + // Generate the kernel for a FIR filter operation. + // Params: + // Input kernel: Stores generated kernel data. + // Type: Specifies the window type from the WINDOW_TYPE enum class. + // Length: The length of the filter. + // Cutoff: The lowpass cutoff frequency. + // Argument: Filter-specific arguments, with size limited by the + // WINDOW_TYPE. 
intptr_t kernelSize = 100; MemRef kernel(&kernelSize); - dap::firLowpass(kernel, dap::WINDOW_TYPE::BLACKMANHARRIS7, - kernelSize, 0.3, nullptr); - auto aud = dap::Audio(fileName); - aud.getAudioFile().printSummary(); - dap::Audio output; - output.fetchMetadata(aud.getAudioFile()); - output.getAudioFile().setAudioBuffer(nullptr); - dap::fir(&aud.getMemRef(), &kernel, &output.getMemRef()); - cout << "Saving file:" << endl; - cout << (output.save(saveFileName) ? "OK" : "ERROR") << endl; + dap::firLowpass(/*input=*/kernel, + /*type=*/dap::WINDOW_TYPE::BLACKMANHARRIS7, + /*len=*/kernelSize, /*cutoff=*/0.3, + /*args=*/nullptr); + + // Initialize data containers. + // Params: + // Input container: Stores the raw audio data. + // Returns: + // Output memory reference: Provides a MemRef for saving the output. + Audio inputContainer("../../tests/Interface/core/TestAudio.wav"); + intptr_t samplesNum = static_cast(inputContainer.getSamplesNum()); + MemRef outputMemRef(&samplesNum); + + // Apply the FIR filter operation to the audio data. + printLogLabel(); + std::cout << "Running FIR operation..." << std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::fir(&inputContainer, &kernel, &outputMemRef); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "Audio processing time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; + + // Convert a MemRef object to an Audio object and set the metadata. + Audio outputContainer(std::move(outputMemRef)); + outputContainer.setBitDepth(inputContainer.getBitDepth()); + outputContainer.setSamplesNum(inputContainer.getSamplesNum()); + outputContainer.setChannelsNum(inputContainer.getChannelsNum()); + outputContainer.setSampleRate(inputContainer.getSampleRate()); + + // Save the processed data to an audio file. 
+ std::string saveFileName = "FIRTestAudio.wav"; + outputContainer.saveToFile(saveFileName, "wave"); + printLogLabel(); + std::cout << "Processed audio data saved in: " << saveFileName << "\n" + << std::endl; + return 0; } diff --git a/examples/DAPDialect/IIRLowpass.cpp b/examples/DAPDialect/IIRLowpass.cpp index 1b69ec08b0..ec5de06c95 100644 --- a/examples/DAPDialect/IIRLowpass.cpp +++ b/examples/DAPDialect/IIRLowpass.cpp @@ -14,52 +14,81 @@ // //===----------------------------------------------------------------------===// // -// This file implements an end to end example for iir filter in buddy-mlir. It -// generates coefficients for a filter and apply it on a piece of mono audio, -// then saves the audio. -// This file will be linked with the object file generated by mlir to generate -// the executable file. +// An end-to-end example of the scalar version IIR (Infinite Impulse Response) +// operation in buddy-mlir. // //===----------------------------------------------------------------------===// #include +#include #include using namespace dap; using namespace std; +// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + int main(int argc, char *argv[]) { - string fileName = "../../tests/Interface/core/NASA_Mars.wav"; - string saveFileName = "IIR_LOWPASS_NASA_Mars.wav"; - if (argc >= 2) { - fileName = argv[1]; - } - if (argc == 3) { - saveFileName = argv[2]; - } - cout << "Usage: IIRLowpass [loadPath] [savePath]" << endl; - cout << "Current specified path: \n"; - cout << "Load: " << fileName << endl; - cout << "Save: " << saveFileName << endl; - // Order of butterworth filter + // Print the title of this example. + const std::string title = + "Scalar Version IIR Operation Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + // Allocate kernel MemRef for an IIR filter operation. + // Params: + // Order: The order of the butterworth filter. 
+ // Parameter size: Each SOS matrix has 6 parameters. int order = 8; - // Each SOS matrix has 6 paramters. intptr_t kernelSize[2] = {int(order / 2), 6}; MemRef kernel(kernelSize); - // cutoff frequency = 1000, fs = 48000. - dap::iirLowpass(kernel, dap::butterworth(order), 1000, - 48000); - auto aud = dap::Audio(fileName); - aud.getAudioFile().printSummary(); - dap::Audio output; - output.fetchMetadata(aud.getAudioFile()); - output.getAudioFile().setAudioBuffer(nullptr); + // Generate the kernel for an IIR filter operation. + // Params: + // Input kernel: Stores generated kernel data. + // Lowpass filter: Supports butterworth filter upto order 12 for now. + // Lowpass frequency: The lowpass cutoff frequency. + // Sampling frequency: The rate at which the input data is sampled. + dap::iirLowpass(/*kernel=*/kernel, + /*filter=*/dap::butterworth(order), + /*frequency=*/1000, + /*fs=*/48000); + + // Initialize data containers. + // Params: + // Input container: Stores the raw audio data. + // Returns: + // Output memory reference: Provides a MemRef for saving the output. + Audio inputContainer("../../tests/Interface/core/TestAudio.wav"); + intptr_t samplesNum = static_cast(inputContainer.getSamplesNum()); + MemRef outputMemRef(&samplesNum); + + // Apply scalar version IIR operation to the audio data. + printLogLabel(); + std::cout << "Running scalar version IIR operation..." << std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::IIR(&inputContainer, &kernel, &outputMemRef); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "Audio processing time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; - dap::IIR(&aud.getMemRef(), &kernel, &output.getMemRef()); + // Convert a MemRef object to an Audio object and set the metadata. 
+ Audio outputContainer(std::move(outputMemRef)); + outputContainer.setBitDepth(inputContainer.getBitDepth()); + outputContainer.setSamplesNum(inputContainer.getSamplesNum()); + outputContainer.setChannelsNum(inputContainer.getChannelsNum()); + outputContainer.setSampleRate(inputContainer.getSampleRate()); - cout << "Saving file:" << endl; - cout << (output.save(saveFileName) ? "OK" : "ERROR") << endl; + // Save the processed data to an audio file. + std::string saveFileName = "ScalarVersionIIRTestAudio.wav"; + outputContainer.saveToFile(saveFileName, "wave"); + printLogLabel(); + std::cout << "Processed audio data saved in: " << saveFileName << "\n" + << std::endl; return 0; } diff --git a/examples/DAPDialect/IIRVectorization.cpp b/examples/DAPDialect/IIRVectorization.cpp index c7d0c19553..e766c85889 100644 --- a/examples/DAPDialect/IIRVectorization.cpp +++ b/examples/DAPDialect/IIRVectorization.cpp @@ -14,53 +14,82 @@ // //===----------------------------------------------------------------------===// // -// This file implements an end to end example for iir filter in buddy-mlir. It -// generates coefficients for a filter and apply it on a piece of mono audio, -// then saves the audio. -// This file will be linked with the object file which use dap vectorization -// pass to generate the executable file. +// An end-to-end example of the vectorized IIR (Infinite Impulse Response) +// operation in buddy-mlir. 
// //===----------------------------------------------------------------------===// #include +#include #include using namespace dap; using namespace std; -int main(int argc, char *argv[]) { - string fileName = "../../tests/Interface/core/NASA_Mars.wav"; - string saveFileName = "IIR_VECTORIZATION_PASS_NASA_Mars.wav"; - if (argc >= 2) { - fileName = argv[1]; - } - if (argc == 3) { - saveFileName = argv[2]; - } - cout << "Usage: IIRVectorizationPass [loadPath] [savePath]" << endl; - cout << "Current specified path: \n"; - cout << "Load: " << fileName << endl; - cout << "Save: " << saveFileName << endl; - // Order for butterworth filter. +// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + +int main() { + // Print the title of this example. + const std::string title = + "Vectorized IIR Operation Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + // Allocate kernel MemRef for an IIR filter operation. + // Params: + // Order: The order of the butterworth filter. + // Parameter size: Each SOS matrix has 6 parameters. int order = 8; - // Each SOS matrix has 6 paramters. intptr_t kernelSize[2] = {int(order / 2), 6}; MemRef kernel(kernelSize); - // cutoff frequency = 1000, fs = 48000. - dap::iirLowpass(kernel, dap::butterworth(order), 1000, - 48000); - auto aud = dap::Audio(fileName); - aud.getAudioFile().printSummary(); - dap::Audio output; - output.fetchMetadata(aud.getAudioFile()); - output.getAudioFile().setAudioBuffer(nullptr); + // Generate the kernel for an IIR filter operation. + // Params: + // Input kernel: Stores generated kernel data. + // Lowpass filter: Supports butterworth filter upto order 12 for now. + // Lowpass frequency: The lowpass cutoff frequency. + // Sampling frequency: The rate at which the input data is sampled. 
+ dap::iirLowpass(/*kernel=*/kernel, + /*filter=*/dap::butterworth(order), + /*frequency=*/1000, + /*fs=*/48000); + + // Initialize data containers. + // Params: + // Input container: Stores the raw audio data. + // Returns: + // Output memory reference: Provides a MemRef for saving the output. + Audio inputContainer("../../tests/Interface/core/TestAudio.wav"); + intptr_t samplesNum = static_cast(inputContainer.getSamplesNum()); + MemRef outputMemRef(&samplesNum); - dap::IIR(&aud.getMemRef(), &kernel, &output.getMemRef(), + // Apply vectorized IIR operation to the audio data. + printLogLabel(); + std::cout << "Running vectorized IIR operation..." << std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::IIR(&inputContainer, &kernel, &outputMemRef, /*isVectorization=*/true); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "Audio processing time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; + + // Convert a MemRef object to an Audio object and set the metadata. + Audio outputContainer(std::move(outputMemRef)); + outputContainer.setBitDepth(inputContainer.getBitDepth()); + outputContainer.setSamplesNum(inputContainer.getSamplesNum()); + outputContainer.setChannelsNum(inputContainer.getChannelsNum()); + outputContainer.setSampleRate(inputContainer.getSampleRate()); - cout << "Saving file:" << endl; - cout << (output.save(saveFileName) ? "OK" : "ERROR") << endl; + // Save the processed data to an audio file. 
+ std::string saveFileName = "VectorizedIIRTestAudio.wav"; + outputContainer.saveToFile(saveFileName, "wave"); + printLogLabel(); + std::cout << "Processed audio data saved in: " << saveFileName << "\n" + << std::endl; return 0; } diff --git a/examples/DAPDialect/RFFT.cpp b/examples/DAPDialect/RFFT.cpp new file mode 100644 index 0000000000..993fec95e1 --- /dev/null +++ b/examples/DAPDialect/RFFT.cpp @@ -0,0 +1,75 @@ +//===- RFFT.cpp - Example of DAP RFFT Operation ---------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// An example of the RFFT function from Whisper Preprocessor operation. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#define testLength 840 + +using namespace dap; +using namespace std; + +// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + +// Write preprocessing results to a text file. +void printResult(MemRef &outputMemRef) { + ofstream fout("whisperPreprocessResultRFFT.txt"); + // Print title. + fout << "-----------------------------------------" << std::endl; + fout << "[ Buddy RFFT Result ]" << std::endl; + fout << "-----------------------------------------" << std::endl; + // Print reuslt data. 
+ for (int i = 0; i < testLength; ++i) { + fout << outputMemRef[i] << std::endl; + } + fout.close(); +} + +int main() { + // Print the title of this example. + const std::string title = "RFFT Operation Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + double *inputAlign = new double[testLength]; + for (int i = 0; i < testLength; ++i) { + inputAlign[i] = static_cast(i); + } + intptr_t inputSizes[1] = {testLength}; + MemRef inputMemRef(inputAlign, inputSizes); + + printLogLabel(); + std::cout << "Running RFFT operation" << std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::RFFT(&inputMemRef); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "RFFT time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; + + printResult(inputMemRef); + + return 0; +} diff --git a/examples/DAPDialect/WhisperPreprocess.cpp b/examples/DAPDialect/WhisperPreprocess.cpp new file mode 100644 index 0000000000..db69ac836e --- /dev/null +++ b/examples/DAPDialect/WhisperPreprocess.cpp @@ -0,0 +1,77 @@ +//===- WhisperPreprocessor.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// An example of the Whisper Preprocessor operation. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +using namespace dap; +using namespace std; + +// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + +// Write preprocessing results to a text file. +void printResult(MemRef &outputMemRef) { + ofstream fout("whisperPreprocessResult.txt"); + // Print title. + fout << "-----------------------------------------" << std::endl; + fout << "[ Whisper Preprocess Result ]" << std::endl; + fout << "-----------------------------------------" << std::endl; + // Print reuslt data. + for (int i = 0; i < 240000; ++i) { + fout << outputMemRef[i] << std::endl; + } + fout.close(); +} + +int main() { + // Print the title of this example. + const std::string title = "Whisper Preprocess Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + // Initialize data containers. + // Params: + // Input container: Stores raw audio data. + // Returns: + // Output memory reference: Features formatted as memref<1x80x3000xf32>. + Audio inputContainer("../../examples/BuddyWhisper/audio.wav"); + float *outputAlign = new float[240000]; + intptr_t outputSizes[3] = {1, 80, 3000}; + MemRef outputMemRef(outputAlign, outputSizes); + + // Compute audio features from raw audio data. + printLogLabel(); + std::cout << "Preprocessing audio..." 
<< std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::whisperPreprocess(&inputContainer, &outputMemRef); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "Audio preprocess time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; + + // printResult(outputMemRef); + + return 0; +} diff --git a/examples/DAPDialect/biquad.cpp b/examples/DAPDialect/biquad.cpp index 14a78084a0..e606c2d0e8 100644 --- a/examples/DAPDialect/biquad.cpp +++ b/examples/DAPDialect/biquad.cpp @@ -14,45 +14,70 @@ // //===----------------------------------------------------------------------===// // -// This file implements an end to end example for biquad filter in buddy-mlir. -// It generates coefficients for a filter and apply it on a piece of mono audio, -// then saves the audio. -// This file will be linked with the object file generated by mlir to generate -// the executable file. +// An end-to-end example of a biquad operation in buddy-mlir. // //===----------------------------------------------------------------------===// #include +#include #include using namespace dap; using namespace std; -int main(int argc, char *argv[]) { - string fileName = "../../tests/Interface/core/NASA_Mars.wav"; - string saveFileName = "BIQUAD_NASA_Mars.wav"; - if (argc >= 2) { - fileName = argv[1]; - } - if (argc == 3) { - saveFileName = argv[2]; - } - cout << "Usage: BiquadLowpass [loadPath] [savePath]" << endl; - cout << "Current specified path: \n"; - cout << "Load: " << fileName << endl; - cout << "Save: " << saveFileName << endl; +// Print [Log] label in bold blue format. +void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; } + +int main() { + // Print the title of this example. 
+ const std::string title = "Biquad Operation Powered by Buddy Compiler"; + std::cout << "\033[33;1m" << title << "\033[0m" << std::endl; + + // Generate the kernel for a biquad filter operation. + // Params: + // Input kernel: Stores generated kernel data. + // Frequency: Normalized frequency (frequency_Hz / samplerate_Hz). + // Quality factor: Defines the filter's bandwidth relative to its + // center frequency. intptr_t kernelSize = 6; MemRef kernel(&kernelSize); - dap::biquadLowpass(kernel, 0.3, -1.0); - auto aud = dap::Audio(fileName); - aud.getAudioFile().printSummary(); - dap::Audio output; - output.fetchMetadata(aud.getAudioFile()); - output.getAudioFile().setAudioBuffer(nullptr); + dap::biquadLowpass(kernel, /*frequency=*/0.3, /*Q=*/-1.0); + + // Initialize data containers. + // Params: + // Input container: Stores the raw audio data. + // Returns: + // Output memory reference: Provides a MemRef for saving the output. + Audio inputContainer("../../tests/Interface/core/TestAudio.wav"); + intptr_t samplesNum = static_cast(inputContainer.getSamplesNum()); + MemRef outputMemRef(&samplesNum); + + // Apply the biquad filter operation to the audio data. + printLogLabel(); + std::cout << "Running biquad operation..." << std::endl; + const auto loadStart = std::chrono::high_resolution_clock::now(); + dap::biquad(&inputContainer, &kernel, &outputMemRef); + const auto loadEnd = std::chrono::high_resolution_clock::now(); + const std::chrono::duration loadTime = + loadEnd - loadStart; + printLogLabel(); + std::cout << "Audio processing time: " << (double)(loadTime.count()) / 1000 + << "s\n" + << std::endl; + + // Convert a MemRef object to an Audio object and set the metadata. 
+ Audio outputContainer(std::move(outputMemRef)); + outputContainer.setBitDepth(inputContainer.getBitDepth()); + outputContainer.setSamplesNum(inputContainer.getSamplesNum()); + outputContainer.setChannelsNum(inputContainer.getChannelsNum()); + outputContainer.setSampleRate(inputContainer.getSampleRate()); - dap::biquad(&aud.getMemRef(), &kernel, &output.getMemRef()); + // Save the processed data to an audio file. + std::string saveFileName = "BiquadTestAudio.wav"; + outputContainer.saveToFile(saveFileName, "wave"); + printLogLabel(); + std::cout << "Processed audio data saved in: " << saveFileName << "\n" + << std::endl; - cout << "Saving file:" << endl; - cout << (output.save(saveFileName) ? "OK" : "ERROR") << endl; return 0; } diff --git a/examples/DIPDialect/CMakeLists.txt b/examples/DIPDialect/CMakeLists.txt index 27abb889f2..2f897ad633 100644 --- a/examples/DIPDialect/CMakeLists.txt +++ b/examples/DIPDialect/CMakeLists.txt @@ -1,4 +1,4 @@ -set(DIP_LIBS ${JPEG_LIBRARY} ${PNG_LIBRARY} BuddyLibDIP) +set(DIP_LIBS ${JPEG_LIBRARY} ${PNG_LIBRARIES} BuddyLibDIP) if(BUDDY_ENABLE_OPENCV) find_package(OpenCV REQUIRED CONFIG) @@ -25,3 +25,9 @@ target_link_libraries(rotation2D ${DIP_LIBS}) add_executable(resize2D resize2D.cpp) target_link_libraries(resize2D ${DIP_LIBS}) + +add_executable(resize4D_nhwc resize4D_nhwc.cpp) +target_link_libraries(resize4D_nhwc ${DIP_LIBS}) + +add_executable(resize4D_nchw resize4D_nchw.cpp) +target_link_libraries(resize4D_nchw ${DIP_LIBS}) diff --git a/examples/DIPDialect/resize4D_nchw.cpp b/examples/DIPDialect/resize4D_nchw.cpp new file mode 100644 index 0000000000..95d77cc27d --- /dev/null +++ b/examples/DIPDialect/resize4D_nchw.cpp @@ -0,0 +1,58 @@ +//====- resize4D.cpp - Example of buddy-opt tool =============================// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements a 4D resize example with dip.resize_4d operation. +// The dip.resize_4d operation will be compiled into an object file with the +// buddy-opt tool. +// This file will be linked with the object file to generate the executable +// file. +// +//===----------------------------------------------------------------------===// +#include "buddy/DIP/imgcodecs/loadsave.h" +#include +#include +#include +#include +#include +#include + +using namespace std; + +void testImplementation(int argc, char *argv[]) { + // Read as colar image. + dip::Image inputBatch(argv[1], dip::DIP_RGB, true); + + // Note : Both values in output image dimensions and scaling ratios must be + // positive numbers. + MemRef output = dip::Resize4D_NCHW( + &inputBatch, dip::INTERPOLATION_TYPE::BILINEAR_INTERPOLATION, + {1, 3, 224, 224} /*{image_cols, image_rows}*/); + + // Define Img with the output of Resize4D. 
+ intptr_t outSizes[3] = {output.getSizes()[2], output.getSizes()[3], + output.getSizes()[1]}; + + Img outputImageResize4D(output.getData(), outSizes); + + // dip::imwrite(argv[2], outputImageResize4D); + + return; +} + +int main(int argc, char *argv[]) { + testImplementation(argc, argv); + return 0; +} diff --git a/examples/DIPDialect/resize4D_nhwc.cpp b/examples/DIPDialect/resize4D_nhwc.cpp new file mode 100644 index 0000000000..affb8a8a09 --- /dev/null +++ b/examples/DIPDialect/resize4D_nhwc.cpp @@ -0,0 +1,61 @@ +//====- resize4D.cpp - Example of buddy-opt tool =============================// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements a 4D resize example with dip.resize_4d operation. +// The dip.resize_4d operation will be compiled into an object file with the +// buddy-opt tool. +// This file will be linked with the object file to generate the executable +// file. +// +//===----------------------------------------------------------------------===// +#include "buddy/DIP/imgcodecs/loadsave.h" +#include +#include +#include +#include +#include + +using namespace std; + +void testImplementation(int argc, char *argv[]) { + // Read as colar image. 
+ Img input = dip::imread(argv[1], dip::IMGRD_COLOR); + + intptr_t sizes[4] = {1, input.getSizes()[0], input.getSizes()[1], + input.getSizes()[2]}; + Img inputBatch(input.getData(), sizes); + + // Note : Both values in output image dimensions and scaling ratios must be + // positive numbers. + MemRef output = dip::Resize4D_NHWC( + &inputBatch, dip::INTERPOLATION_TYPE::BILINEAR_INTERPOLATION, + {1, 224, 224, 3} /*{image_cols, image_rows}*/); + + // Define Img with the output of Resize4D. + intptr_t outSizes[3] = {output.getSizes()[1], output.getSizes()[2], + output.getSizes()[3]}; + + Img outputImageResize4D(output.getData(), outSizes); + + dip::imwrite(argv[2], outputImageResize4D); + + return; +} + +int main(int argc, char *argv[]) { + testImplementation(argc, argv); + return 0; +} diff --git a/examples/MLIRCF/.gitignore b/examples/MLIRCF/.gitignore new file mode 100644 index 0000000000..790429d34e --- /dev/null +++ b/examples/MLIRCF/.gitignore @@ -0,0 +1,3 @@ +log* +core +a.out diff --git a/examples/MLIRCF/cf-iteration-exit.mlir b/examples/MLIRCF/cf-iteration-exit.mlir new file mode 100644 index 0000000000..89281c9e34 --- /dev/null +++ b/examples/MLIRCF/cf-iteration-exit.mlir @@ -0,0 +1,47 @@ +// RUN: buddy-opt %s \ +// RUN: -convert-vector-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +// The example is equivalent to the following code. 
+// int main() { +// int val = 0; +// for (int i = 1; i < 5; i++) { +// val += 5; +// if (i == 3) { +// std::cout << val << std::endl; +// return 0; +// } +// } +// return 0; +// } + +module { + func.func @main() { + %c0 = arith.constant 0 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c1 = arith.constant 1 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + %cst_5 = arith.constant 5.000000e+00 : f32 + cf.br ^bb1(%c0, %cst_0 : index, f32) + ^bb1(%0: index, %1: f32): + %2 = arith.cmpi slt, %0, %c5 : index + cf.cond_br %2, ^bb2, ^bb4(%1: f32) + ^bb2: + %3 = arith.addf %1, %cst_5 : f32 + %4 = arith.addi %0, %c1 : index + cf.br ^bb3 (%4, %3 : index, f32) + ^bb3(%iter_idx: index, %iter_var: f32): + %eq = arith.cmpi eq, %iter_idx, %c3 : index + cf.cond_br %eq, ^bb4(%iter_var: f32), ^bb1(%iter_idx, %iter_var: index, f32) + ^bb4(%ret_var: f32): + // CHECK: 15 + vector.print %ret_var : f32 + return + } +} diff --git a/examples/MLIRCF/makefile b/examples/MLIRCF/makefile new file mode 100644 index 0000000000..5837ebf442 --- /dev/null +++ b/examples/MLIRCF/makefile @@ -0,0 +1,44 @@ +#!/bin/bash +BUDDY_OPT := ../../build/bin/buddy-opt +MLIR_OPT := ../../llvm/build/bin/mlir-opt +MLIR_TRANSLATE := ../../llvm/build/bin/mlir-translate +MLIR_CPU_RUNNER := ../../llvm/build/bin/mlir-cpu-runner +LLC := ../../llvm/build/bin/llc +OPT_FLAG := -O0 +CLANG := ../../llvm/build//bin/clang +MLIR_LIB := ../../llvm/build/lib/ +BUDDY_LIB := ../../build/midend/lib/ + +ifeq ($(shell uname),Linux) +MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.so +MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.so +MLIR_ASYNC_RUNTIME := ../../llvm/build/lib/libmlir_async_runtime.so +MTRIPLE := x86_64-unknown-linux-gnu +else ifeq ($(shell uname),Darwin) +MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.dylib +MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.dylib +MLIR_ASYNC_RUNTIME := 
./../llvm/build/lib/libmlir_async_runtime.dylib +MTRIPLE := x86_64-apple-darwin +endif + +cf-iteration-exit-lower: + @${MLIR_OPT} ./cf-iteration-exit.mlir \ + -convert-vector-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts \ + -o ./log.mlir + +cf-iteration-exit-translate: + @${MLIR_OPT} ./cf-iteration-exit.mlir \ + -convert-vector-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +cf-iteration-exit-run: + @${MLIR_OPT} ./cf-iteration-exit.mlir \ + -convert-vector-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir new file mode 100644 index 0000000000..04dea80df6 --- /dev/null +++ b/examples/MLIRLinalg/linalg-batch-matmul-dync.mlir @@ -0,0 +1,65 @@ +// RUN: buddy-opt %s -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2" \ +// RUN: -convert-linalg-to-loops -expand-strided-metadata -lower-affine \ +// RUN: -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm \ +// RUN: -convert-arith-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + + // Definition for the batch matrix multiplication function + func.func @buddy_batchmatmul_f32(%A: memref, %B: memref, %C: memref) { + linalg.batch_matmul + ins(%A, %B: memref, memref) + outs(%C: memref) + return + } + + func.func @main(){ + // Set up dims. 
+ %cBatch = arith.constant 2:index + %cM = arith.constant 2 : index + %cN = arith.constant 3 : index + %cK = arith.constant 4 : index + + // Set Init Value. + %cf1 = arith.constant 1.0 : f32 + %cf2 = arith.constant 2.0 : f32 + %c0 = arith.constant 0.0 : f32 + + %A = memref.alloc(%cBatch,%cM, %cK) : memref + %B = memref.alloc(%cBatch,%cK, %cN) : memref + %C = memref.alloc(%cBatch,%cM, %cN) : memref + + linalg.fill + ins(%cf1 : f32) + outs(%A:memref) + + linalg.fill + ins(%cf2 : f32) + outs(%B:memref) + + linalg.fill + ins(%c0 : f32) + outs(%C:memref) + + call @buddy_batchmatmul_f32(%A, %B, %C) : (memref, memref, memref) -> () + + %print_C = memref.cast %C : memref to memref<*xf32> + call @printMemrefF32(%print_C) : (memref<*xf32>) -> () + + memref.dealloc %C : memref + memref.dealloc %B : memref + memref.dealloc %A : memref + return + } +} + +// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [2, 2, 3] strides = [6, 3, 1] data = +// CHECK{LITERAL}: [[[8, 8, 8], +// CHECK{LITERAL}: [8, 8, 8]], +// CHECK{LITERAL}: [[8, 8, 8], +// CHECK{LITERAL}: [8, 8, 8]]] diff --git a/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir new file mode 100644 index 0000000000..ea81007153 --- /dev/null +++ b/examples/MLIRLinalg/linalg-conv2d_nhwc_fhwc.mlir @@ -0,0 +1,82 @@ +// RUN: buddy-opt %s \ +// RUN: -conv-nhwc-fhwc-optimize -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + func.func @alloc_2d_filled_f32(%arg0: index, %arg1: index, %arg2: index, %arg3: index, 
%arg4: f32) -> memref { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = memref.alloc(%arg0, %arg1, %arg2, %arg3) : memref + scf.for %arg5 = %c0 to %arg0 step %c1 { + scf.for %arg6 = %c0 to %arg1 step %c1 { + scf.for %arg7 = %c0 to %arg2 step %c1 { + scf.for %arg8 = %c0 to %arg3 step %c1 { + %iarg8=arith.index_cast %arg8 : index to i32 + %loopf= arith.sitofp %iarg8 : i32 to f32 + memref.store %loopf, %0[%arg5, %arg6, %arg7, %arg8] : memref + } + } + } + } + return %0 : memref + } + func.func @conv_2d_nhwc_fhwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_2d_nhwc_fhwc ins(%arg0, %arg1 : memref, memref) outs(%arg2 : memref) + return + } + func.func @main() { + // Intput(image, filter) and output value. + %cst = arith.constant 0.500000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + + %current_image_n = arith.constant 1 : index + %current_image_c = arith.constant 2 : index + %current_image_h = arith.constant 4 : index + %current_image_w = arith.constant 4 : index + + %current_filter_f = arith.constant 2 : index + %current_filter_c = arith.constant 2 : index + %current_filter_h = arith.constant 2 : index + %current_filter_w = arith.constant 2 : index + + %current_output_n = arith.constant 1 : index + %current_output_c = arith.constant 2 : index + %current_output_h = arith.constant 3 : index + %current_output_w = arith.constant 3 : index + + // Image. + %image = call @alloc_2d_filled_f32(%current_image_n,%current_image_h, %current_image_w, %current_image_c, %cst) : (index, index, index, index, f32) -> memref + // Filter. + %filter = call @alloc_2d_filled_f32(%current_filter_f, %current_filter_h, %current_filter_w,%current_filter_c, %cst) : (index, index, index, index, f32) -> memref + // Output. 
+ %output = call @alloc_2d_filled_f32(%current_output_n, %current_output_h, %current_output_w,%current_output_c, %cst_0) : (index, index, index, index, f32) -> memref + + call @conv_2d_nhwc_fhwc(%image, %filter, %output) : (memref, memref, memref) -> () + + %3 = memref.cast %output : memref to memref<*xf32> + call @printMemrefF32(%3) : (memref<*xf32>) -> () + + + memref.dealloc %output : memref + memref.dealloc %image : memref + memref.dealloc %filter : memref + return + } +} + +// CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 3, 3, 2] strides = [18, 6, 2, 1] data = +// CHECK{LITERAL}: [[[[4, 5], +// CHECK{LITERAL}: [4, 5], +// CHECK{LITERAL}: [4, 5]], +// CHECK{LITERAL}: [[4, 5], +// CHECK{LITERAL}: [4, 5], +// CHECK{LITERAL}: [4, 5]], +// CHECK{LITERAL}: [[4, 5], +// CHECK{LITERAL}: [4, 5], +// CHECK{LITERAL}: [4, 5]]]] diff --git a/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir new file mode 100644 index 0000000000..905df48bd8 --- /dev/null +++ b/examples/MLIRLinalg/linalg-depthwise_conv_2d_nhwc_hwc.mlir @@ -0,0 +1,82 @@ +// RUN: buddy-opt %s \ +// RUN: -depthwise-conv-nhwc-hwc-optimize -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module { + func.func private @printMemrefF32(memref<*xf32>) + + func.func @depthwise_conv_2d_nhwc_hwc(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.depthwise_conv_2d_nhwc_hwc + {dilations = dense<[1,1]> : tensor<2xi64>, strides = dense<[1,1]> : tensor<2xi64>} + ins(%arg0, %arg1 : memref, memref) + outs(%arg2 : memref) + return + } + + 
func.func @main() { + // Constants for input image, filter, and output sizes. + %cst = arith.constant 0.500000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %cf1 = arith.constant 1.0 : f32 + + %image_n = arith.constant 1 : index + %image_h = arith.constant 4 : index + %image_w = arith.constant 4 : index + %image_c = arith.constant 2 : index + + %filter_h = arith.constant 1 : index + %filter_w = arith.constant 2 : index + %filter_c = arith.constant 2 : index + + %output_n = arith.constant 1 : index + %output_h = arith.constant 3 : index + %output_w = arith.constant 3 : index + %output_c = arith.constant 2 : index + + %image = memref.alloc(%image_n,%image_h,%image_w,%image_c) : memref + %filter = memref.alloc(%filter_h,%filter_w,%filter_c) : memref + %output = memref.alloc(%output_n,%output_h,%output_w,%output_c) : memref + + // Allocate and fill image, filter, and output. + linalg.fill + ins(%cf1 : f32) + outs(%image:memref) + + linalg.fill + ins(%cf1 : f32) + outs(%filter:memref) + linalg.fill + ins(%cf1 : f32) + outs(%output:memref) + + // Call depthwise convolution. + call @depthwise_conv_2d_nhwc_hwc(%image, %filter, %output) : (memref, memref, memref) -> () + + %output_cast = memref.cast %output : memref to memref<*xf32> + + // Print the output. + call @printMemrefF32(%output_cast) : (memref<*xf32>) -> () + + // Deallocate memory. 
+ memref.dealloc %output : memref + memref.dealloc %image : memref + memref.dealloc %filter : memref + return + } +} + +// CHECK: Unranked Memref base@ = {{.*}} rank = 4 offset = 0 sizes = [1, 3, 3, 2] strides = [18, 6, 2, 1] data = +// CHECK{LITERAL}: [[[[3, 3], +// CHECK{LITERAL}: [3, 3], +// CHECK{LITERAL}: [3, 3]], +// CHECK{LITERAL}: [[3, 3], +// CHECK{LITERAL}: [3, 3], +// CHECK{LITERAL}: [3, 3]], +// CHECK{LITERAL}: [[3, 3], +// CHECK{LITERAL}: [3, 3], +// CHECK{LITERAL}: [3, 3]]]] diff --git a/examples/MLIRLinalg/linalg-matmul-opt-f32.mlir b/examples/MLIRLinalg/linalg-matmul-opt-f32.mlir index 5111b57dbe..53148b0d0a 100644 --- a/examples/MLIRLinalg/linalg-matmul-opt-f32.mlir +++ b/examples/MLIRLinalg/linalg-matmul-opt-f32.mlir @@ -1,4 +1,4 @@ -// RUN: buddy-opt -matmul-paralell-vectorization-optimize -verify-diagnostics -expand-strided-metadata -lower-affine \ +// RUN: buddy-opt -matmul-parallel-vectorization-optimize -verify-diagnostics -expand-strided-metadata -lower-affine \ // RUN: -convert-linalg-to-loops -convert-vector-to-scf -convert-scf-to-cf -convert-vector-to-llvm -finalize-memref-to-llvm \ // RUN: -llvm-request-c-wrappers -convert-func-to-llvm -reconcile-unrealized-casts %s \ // RUN: | mlir-cpu-runner -O0 -e buddy_matmul_f32 -entry-point-result=void \ diff --git a/examples/MLIRLinalg/linalg-matmul-opt-i8.mlir b/examples/MLIRLinalg/linalg-matmul-opt-i8.mlir index 9a7b72e5e0..26aa92cbe5 100644 --- a/examples/MLIRLinalg/linalg-matmul-opt-i8.mlir +++ b/examples/MLIRLinalg/linalg-matmul-opt-i8.mlir @@ -1,4 +1,4 @@ -// RUN: buddy-opt -matmul-paralell-vectorization-optimize -verify-diagnostics -expand-strided-metadata \ +// RUN: buddy-opt -matmul-parallel-vectorization-optimize -verify-diagnostics -expand-strided-metadata \ // RUN: -lower-affine -convert-vector-to-llvm -finalize-memref-to-llvm -convert-scf-to-cf \ // RUN: -convert-linalg-to-loops -convert-scf-to-cf -llvm-request-c-wrappers -convert-func-to-llvm \ // RUN: -reconcile-unrealized-casts %s 
\ diff --git a/examples/MLIRLinalg/makefile b/examples/MLIRLinalg/makefile index f214fa7f67..d9a37926f4 100644 --- a/examples/MLIRLinalg/makefile +++ b/examples/MLIRLinalg/makefile @@ -60,6 +60,46 @@ linalg-conv2d-tiling-run: -convert-func-to-llvm -reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} +linalg-conv2d_nhwc_fhwc-optimize-lower: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ + -conv-nhwc-fhwc-optimize="vec-size=16" \ + -o ./log.mlir + +linalg-conv2d_nhwc_fhwc-optimize-run: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir ${MLIR_OPT_OPTIONS} \ + -conv-nhwc-fhwc-optimize="vec-size=16" \ + -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + + +linalg-conv2d_nhwc_fhwc-tile-optimize-lower: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir \ + -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ + -o ./log.mlir + +linalg-conv2d_nhwc_fhwc-tile-optimize-run: + @${BUDDY_OPT} linalg-conv2d_nhwc_fhwc.mlir ${MLIR_OPT_OPTIONS} \ + -conv-nhwc-fhwc-tile-optimize="vec-size=16 tiling-height=2 tiling-width=3" \ + -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-depthwise_conv_2d_nhwc_hwc-optimize-lower: + @${BUDDY_OPT} linalg-depthwise_conv_2d_nhwc_hwc.mlir \ + -depthwise-conv-nhwc-hwc-optimize="vec-size=16" \ + -o ./log.mlir + +linalg-depthwise_conv_2d_nhwc_hwc-optimize-run: + @${BUDDY_OPT} linalg-depthwise_conv_2d_nhwc_hwc.mlir \ + 
-depthwise-conv-nhwc-hwc-optimize="vec-size=16" \ + -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ + -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ + -convert-func-to-llvm -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + linalg-generic-lower: @${MLIR_OPT} ./linalg-generic.mlir \ -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ @@ -177,6 +217,46 @@ linalg-batch-matmul-optimize-lower: -batchmatmul-optimize="vector-size=64" \ -o ./log.mlir +linalg-batch-matmul-tile-optimize-lower: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2" \ + -o ./log.mlir + +linalg-batch-matmul-tile-optimize-run: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-tile-optimize="vec-size=64 kernel-m=4 kernel-n=2" \ + -convert-linalg-to-loops \ + -expand-strided-metadata \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +linalg-batch-matmul-scf-optimize-lower: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-scf-optimize="vector-size=64" \ + -o ./log.mlir + +linalg-batch-matmul-scf-optimize-run: + @${BUDDY_OPT} linalg-batch-matmul-dync.mlir ${MLIR_OPT_OPTIONS} \ + -batchmatmul-scf-optimize="vector-size=64" \ + -convert-linalg-to-loops \ + -expand-strided-metadata \ + -lower-affine \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main 
-entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + linalg-batch-matmul-optimize-translate: @${BUDDY_OPT} linalg-batch-matmul-f32.mlir ${MLIR_OPT_OPTIONS} \ -batchmatmul-optimize="vector-size=64" \ @@ -248,7 +328,7 @@ linalg-batch-matmul-i8-optimize-translate: linalg-matmul-parallized-vectorized-optmize-run: @${BUDDY_OPT} linalg-matmul-opt-f32.mlir ${MLIR_OPT_OPTIONS} \ - -matmul-paralell-vectorization-optimize="vector-size=128" \ + -matmul-parallel-vectorization-optimize="vector-size=128" \ -convert-linalg-to-loops \ -expand-strided-metadata \ -lower-affine \ @@ -263,12 +343,12 @@ linalg-matmul-parallized-vectorized-optmize-run: linalg-matmul-parallized-vectorized-optmize-lower: @${BUDDY_OPT} linalg-matmul-opt-f32.mlir ${MLIR_OPT_OPTIONS} \ - -matmul-paralell-vectorization-optimize="vector-size=128" \ + -matmul-parallel-vectorization-optimize="vector-size=128" \ -o ./log.mlir linalg-matmul-parallized-vectorized-optmize-translate: @${BUDDY_OPT} linalg-matmul-opt-f32.mlir ${MLIR_OPT_OPTIONS} \ - -matmul-paralell-vectorization-optimize="vector-size=128" \ + -matmul-parallel-vectorization-optimize="vector-size=128" \ -convert-linalg-to-loops \ -expand-strided-metadata \ -lower-affine \ @@ -282,7 +362,7 @@ linalg-matmul-parallized-vectorized-optmize-translate: linalg-matmul-i8-parallized-vectorized-optmize-run: @${BUDDY_OPT} linalg-matmul-opt-i8.mlir ${MLIR_OPT_OPTIONS} \ - -matmul-paralell-vectorization-optimize="vector-size=128" \ + -matmul-parallel-vectorization-optimize="vector-size=128" \ -convert-linalg-to-loops \ -expand-strided-metadata \ -lower-affine \ @@ -297,12 +377,12 @@ linalg-matmul-i8-parallized-vectorized-optmize-run: linalg-matmul-i8-parallized-vectorized-optmize-lower: @${BUDDY_OPT} linalg-matmul-opt-i8.mlir ${MLIR_OPT_OPTIONS} \ - -matmul-paralell-vectorization-optimize="vector-size=128" \ + -matmul-parallel-vectorization-optimize="vector-size=128" \ -o ./log.mlir 
linalg-matmul-i8-parallized-vectorized-optmize-translate: @${BUDDY_OPT} linalg-matmul-opt-i8.mlir ${MLIR_OPT_OPTIONS} \ - -matmul-paralell-vectorization-optimize="vector-size=128" \ + -matmul-parallel-vectorization-optimize="vector-size=128" \ -convert-linalg-to-loops \ -expand-strided-metadata \ -lower-affine \ diff --git a/examples/MLIRVector/makefile b/examples/MLIRVector/makefile index 681335c7fd..ccc9e9af24 100644 --- a/examples/MLIRVector/makefile +++ b/examples/MLIRVector/makefile @@ -43,17 +43,20 @@ vector-load-run: vector-broadcast-lower: @${MLIR_OPT} ./vector-broadcast.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts -o ./log.mlir vector-broadcast-translate: @${MLIR_OPT} ./vector-broadcast.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll vector-broadcast-asm-x86: @${MLIR_OPT} ./vector-broadcast.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir | \ @@ -62,6 +65,7 @@ vector-broadcast-asm-x86: vector-broadcast-asm-rv: @${MLIR_OPT} ./vector-broadcast.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir | \ @@ -72,6 +76,7 @@ vector-broadcast-asm-rv: run-targets += vector-broadcast-run vector-broadcast-run: @${MLIR_OPT} ./vector-broadcast.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -79,17 +84,20 @@ vector-broadcast-run: 
vector-fma-lower: @${MLIR_OPT} ./vector-fma.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts -o ./log.mlir vector-fma-translate: @${MLIR_OPT} ./vector-fma.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll vector-fma-asm-x86: @${MLIR_OPT} ./vector-fma.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir | \ @@ -98,6 +106,7 @@ vector-fma-asm-x86: vector-fma-asm-rv: @${MLIR_OPT} ./vector-fma.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir | \ @@ -108,6 +117,7 @@ vector-fma-asm-rv: run-targets += vector-fma-run vector-fma-run: @${MLIR_OPT} ./vector-fma.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -115,17 +125,20 @@ vector-fma-run: vector-long-lower: @${MLIR_OPT} ./vector-long.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts -o ./log.mlir vector-long-translate: @${MLIR_OPT} ./vector-long.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll vector-long-asm-x86: @${MLIR_OPT} ./vector-long.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm 
--finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir | \ @@ -134,6 +147,7 @@ vector-long-asm-x86: vector-long-asm-rv: @${MLIR_OPT} ./vector-long.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_TRANSLATE} --mlir-to-llvmir | \ @@ -144,6 +158,7 @@ vector-long-asm-rv: run-targets += vector-long-run vector-long-run: @${MLIR_OPT} ./vector-long.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -187,6 +202,7 @@ vector-shape-cast-translate: run-targets += vector-shape-cast-run vector-shape-cast-run: @${MLIR_OPT} ./vector-shape-cast.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -209,6 +225,7 @@ vector-type-cast-translate: run-targets += vector-type-cast-run vector-type-cast-run: @${MLIR_OPT} ./vector-type-cast.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -253,6 +270,7 @@ vector-shuffle-translate: run-targets += vector-shuffle-run vector-shuffle-run: @${MLIR_OPT} ./vector-shuffle.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -275,6 +293,7 @@ vector-splat-translate: run-targets += vector-splat-run vector-splat-run: @${MLIR_OPT} ./vector-splat.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm 
--convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -297,6 +316,7 @@ vector-insert-translate: run-targets += vector-insert-run vector-insert-run: @${MLIR_OPT} ./vector-insert.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -319,6 +339,7 @@ vector-reduction-translate: run-targets += vector-reduction-run vector-reduction-run: @${MLIR_OPT} ./vector-reduction.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -341,6 +362,7 @@ vector-outerproduct-translate: run-targets += vector-outerproduct-run vector-outerproduct-run: @${MLIR_OPT} ./vector-outerproduct.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -363,6 +385,7 @@ vector-create-mask-translate: run-targets += vector-create-mask-run vector-create-mask-run: @${MLIR_OPT} ./vector-create-mask.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -384,6 +407,7 @@ vector-extract-translate: run-targets += vector-extract-run vector-extract-run: @${MLIR_OPT} ./vector-extract.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -405,6 +429,7 @@ vector-maskedload-translate: run-targets += vector-maskedload-run vector-maskedload-run: @${MLIR_OPT} ./vector-maskedload.mlir \ + 
-convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -427,6 +452,7 @@ vector-maskedstore-translate: run-targets += vector-maskedstore-run vector-maskedstore-run: @${MLIR_OPT} ./vector-maskedstore.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -449,6 +475,7 @@ vector-extract-strided-slice-translate: run-targets += vector-extract-strided-slice-run vector-extract-strided-slice-run: @${MLIR_OPT} ./vector-extract-strided-slice.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -470,6 +497,7 @@ vector-constant-mask-translate: run-targets += vector-constant-mask-run vector-constant-mask-run: @${MLIR_OPT} ./vector-constant-mask.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -491,6 +519,7 @@ vector-expandload-translate: run-targets += vector-expandload-run vector-expandload-run: @${MLIR_OPT} ./vector-expandload.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -512,6 +541,7 @@ vector-compressstore-translate: run-targets += vector-compressstore-run vector-compressstore-run: @${MLIR_OPT} ./vector-compressstore.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ 
--reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -533,6 +563,7 @@ vector-insert-strided-slice-translate: run-targets += vector-insert-strided-slice-run vector-insert-strided-slice-run: @${MLIR_OPT} ./vector-insert-strided-slice.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -554,6 +585,7 @@ vector-scatter-translate: run-targets += vector-scatter-run vector-scatter-run: @${MLIR_OPT} ./vector-scatter.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -576,6 +608,7 @@ vector-gather-translate: run-targets += vector-gather-run vector-gather-run: @${MLIR_OPT} ./vector-gather.mlir \ + -convert-vector-to-scf -convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ -split-input-file -verify-diagnostics \ --reconcile-unrealized-casts | \ @@ -598,7 +631,7 @@ vector-transfer-read-translate: run-targets += vector-transfer-read-run vector-transfer-read-run: @${MLIR_OPT} ./vector-transfer-read.mlir \ - --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ + --convert-vector-to-scf --lower-affine --convert-scf-to-cf \ --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ @@ -669,3 +702,27 @@ vector-store-run: --reconcile-unrealized-casts | \ ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +vector-iteration-lower: + @${MLIR_OPT} ./vector-iteration.mlir \ + --lower-affine \ + -convert-vector-to-scf -convert-scf-to-cf \ + --convert-vector-to-llvm 
--finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts -o ./log.mlir + +vector-iteration-translate: + @${MLIR_OPT} ./vector-iteration.mlir \ + --lower-affine \ + -convert-vector-to-scf -convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_TRANSLATE} --mlir-to-llvmir -o log.ll + +vector-iteration-run: + @${MLIR_OPT} ./vector-iteration.mlir \ + --lower-affine \ + -convert-vector-to-scf -convert-scf-to-cf \ + --convert-vector-to-llvm --finalize-memref-to-llvm --convert-func-to-llvm \ + --reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=i32 \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} diff --git a/examples/MLIRVector/vector-iteration.mlir b/examples/MLIRVector/vector-iteration.mlir new file mode 100644 index 0000000000..7d63f22896 --- /dev/null +++ b/examples/MLIRVector/vector-iteration.mlir @@ -0,0 +1,128 @@ +// RUN: buddy-opt %s \ +// RUN: -lower-affine \ +// RUN: -convert-vector-to-scf -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=i32 \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +memref.global "private" @gv : memref<4x4xf32> = dense<[[0. , 1. , 2. , 3. ], + [10., 11., 12., 13.], + [20., 21., 22., 23.], + [30., 31., 32., 33.]]> + +memref.global "private" @gv_pat_1 : memref<10xf32> = dense<[0. , 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. , 9.]> +memref.global "private" @gv_pat_2 : memref<10xf32> = dense<[0. , 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. 
, 9.]> + +func.func private @printMemrefF32(memref<*xf32>) + +func.func @main() -> i32 { + %mem = memref.get_global @gv : memref<4x4xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %sum_0 = arith.constant dense<0.000000e+00> : vector<4xf32> + %sum = affine.for %i = 0 to 3 iter_args(%sum_iter = %sum_0) -> (vector<4xf32>) { + %load_vec1 = vector.load %mem[%c0, %c0] : memref<4x4xf32>, vector<4xf32> + %load_vec2 = vector.load %mem[%i, %c0] : memref<4x4xf32>, vector<4xf32> + %sum_next = vector.fma %load_vec1, %load_vec2, %sum_iter : vector<4xf32> + affine.yield %sum_next : vector<4xf32> + } + // CHECK: ( 0, 33, 72, 117 ) + vector.print %sum : vector<4xf32> + + // --------------------------------------------------------------------------- + // Iteration Pattern 1 + // Main Vector Loop + Scalar Remainder + Fixed Vector Type + // --------------------------------------------------------------------------- + + // 1. Get the total length of the workload. + %mem_pat_1 = memref.get_global @gv_pat_1 : memref<10xf32> + %print_mem_pat_1 = memref.cast %mem_pat_1 : memref<10xf32> to memref<*xf32> + %vl_total_pat_1 = memref.dim %mem_pat_1, %c0 : memref<10xf32> + + // 2. Set the iteration step (vector size). + %vl_step_pat_1 = arith.constant 4 : index + + // 3. Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length is divisible + // by the vector size. + %vl_upbound_pat_1_ = arith.subi %vl_total_pat_1, %vl_step_pat_1 : index + %vl_upbound_pat_1 = arith.addi %vl_upbound_pat_1_, %c1 : index + + // 4. Perform the vectorization body. 
+ %iter_idx_pat_1 = scf.for %i = %c0 to %vl_upbound_pat_1 step %vl_step_pat_1 + iter_args(%iter_init = %c0) -> (index) { + %load_vec1 = vector.load %mem_pat_1[%i] : memref<10xf32>, vector<4xf32> + %load_vec2 = vector.load %mem_pat_1[%i] : memref<10xf32>, vector<4xf32> + %res = arith.addf %load_vec1, %load_vec2 : vector<4xf32> + vector.store %res, %mem_pat_1[%i] : memref<10xf32>, vector<4xf32> + %i_next = arith.addi %i, %vl_step_pat_1 : index + scf.yield %i_next : index + } + // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9] + call @printMemrefF32(%print_mem_pat_1) : (memref<*xf32>) -> () + + // 5. Process the remainder of the elements with scalar operations. + scf.for %i = %iter_idx_pat_1 to %vl_total_pat_1 step %c1 { + %ele1 = memref.load %mem_pat_1[%i] : memref<10xf32> + %ele2 = memref.load %mem_pat_1[%i] : memref<10xf32> + %res = arith.addf %ele1, %ele2 : f32 + memref.store %res, %mem_pat_1[%i] : memref<10xf32> + } + // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + call @printMemrefF32(%print_mem_pat_1) : (memref<*xf32>) -> () + + // --------------------------------------------------------------------------- + // Iteration Pattern 2 + // Main Vector Loop + Masked Vector Remainder + Fixed Vector Type + // --------------------------------------------------------------------------- + + // 1. Get the total length of the workload. + %mem_pat_2 = memref.get_global @gv_pat_2 : memref<10xf32> + %print_mem_pat_2 = memref.cast %mem_pat_2 : memref<10xf32> to memref<*xf32> + %vl_total_pat_2 = memref.dim %mem_pat_2, %c0 : memref<10xf32> + + // 2. Set the iteration step (vector size). + %vl_step_pat_2 = arith.constant 4 : index + + // 3. Calculate the upper bound for vectorized processing + // - Subtract `vl_step` is to avoid overflow at the vectorization tail. + // - Add 1 to ensure the final loop runs when the workload length is divisible + // by the vector size. 
+ %vl_upbound_pat_2_ = arith.subi %vl_total_pat_2, %vl_step_pat_2 : index + %vl_upbound_pat_2 = arith.addi %vl_upbound_pat_2_, %c1 : index + + // 4. Perform the vectorization body. + %iter_idx_pat_2 = scf.for %i = %c0 to %vl_upbound_pat_2 step %vl_step_pat_2 + iter_args(%iter_init = %c0) -> (index) { + %load_vec1 = vector.load %mem_pat_2[%i] : memref<10xf32>, vector<4xf32> + %load_vec2 = vector.load %mem_pat_2[%i] : memref<10xf32>, vector<4xf32> + %res = arith.addf %load_vec1, %load_vec2 : vector<4xf32> + vector.store %res, %mem_pat_2[%i] : memref<10xf32>, vector<4xf32> + %i_next = arith.addi %i, %vl_step_pat_1 : index + scf.yield %i_next : index + } + // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9] + call @printMemrefF32(%print_mem_pat_2) : (memref<*xf32>) -> () + + // 5. Compute the tail size and create mask and pass-through vector for the + // remaining elements. + %tail_size_pat_2 = arith.subi %vl_total_pat_2, %iter_idx_pat_2 : index + %mask_pat_2 = vector.create_mask %tail_size_pat_2 : vector<4xi1> + %pass_thr_vec = arith.constant dense<0.> : vector<4xf32> + + // 6. Process the remaining elements using masked vector operations. 
+ %ele1 = vector.maskedload %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32> + %ele2 = vector.maskedload %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %pass_thr_vec : memref<10xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32> + %res = arith.addf %ele1, %ele2 : vector<4xf32> + vector.maskedstore %mem_pat_2[%iter_idx_pat_2], %mask_pat_2, %res : memref<10xf32>, vector<4xi1>, vector<4xf32> + // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + call @printMemrefF32(%print_mem_pat_2) : (memref<*xf32>) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/RVVExperiment/makefile b/examples/RVVExperiment/makefile index 5a8a28f38d..6cadb07cdc 100644 --- a/examples/RVVExperiment/makefile +++ b/examples/RVVExperiment/makefile @@ -110,19 +110,16 @@ rvv-insert-extract-intrinsics-asm: -mattr=+m,+d,+v -riscv-v-vector-bits-min=256 \ --filetype=asm -o log.s -# TODO: Fix me. rvv-c-setvl-translate: @${LOCAL_CLANG} -march=rv64gcv --target=riscv64-unknown-linux-gnu \ --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ ./rvv-c-setvl.c -fPIC -S -emit-llvm -o log.ll -# TODO: Fix me. rvv-c-setvl-asm: @${LOCAL_CLANG} -march=rv64gcv --target=riscv64-unknown-linux-gnu \ --sysroot=${RISCV_GNU_TOOLCHAIN_SYSROOT} --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ ./rvv-c-setvl.c -fPIC -S -o log.s -# TODO: Fix me. 
run-targets += rvv-c-setvl-run rvv-c-setvl-run: @${LOCAL_CLANG} -march=rv64gcv --target=riscv64-unknown-linux-gnu \ diff --git a/examples/RVVExperiment/rvv-c-setvl.c b/examples/RVVExperiment/rvv-c-setvl.c index c8d1ccfbb1..4a8489d55d 100644 --- a/examples/RVVExperiment/rvv-c-setvl.c +++ b/examples/RVVExperiment/rvv-c-setvl.c @@ -3,7 +3,7 @@ int main() { int avl = 70; - int vl = vsetvl_e32m2(avl); + int vl = __riscv_vsetvl_e32m2(avl); printf("vl: %d\n", vl); return 0; diff --git a/examples/VectorExpDialect/makefile b/examples/VectorExpDialect/makefile index ab85a8a2cc..fc88556419 100644 --- a/examples/VectorExpDialect/makefile +++ b/examples/VectorExpDialect/makefile @@ -319,3 +319,24 @@ vector-exp-dynamic-vector-run: -L${CROSS_MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ -o a.out @LD_LIBRARY_PATH=${CROSS_MLIR_LIB} ${QEMU} -L ${RISCV_GNU_TOOLCHAIN_SYSROOT} -cpu max a.out + +vector-exp-iteration-aot: + @${BUDDY_OPT} ./vector-exp-iteration.mlir \ + -lower-vector-exp \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-vector-to-llvm \ + -convert-index-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -finalize-memref-to-llvm \ + -reconcile-unrealized-casts | \ + ${BUDDY_TRANSLATE} -buddy-to-llvmir -o log.ll + ${LOCAL_CLANG} -O3 log.ll \ + -march=rv64gcv --target=riscv64-unknown-linux-gnu -fPIC \ + --sysroot=${RISCV_GNU_TOOLCHAIN}/sysroot \ + --gcc-toolchain=${RISCV_GNU_TOOLCHAIN} \ + -L${CROSS_MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils \ + -o a.out + diff --git a/examples/VectorExpDialect/vector-exp-iteration.mlir b/examples/VectorExpDialect/vector-exp-iteration.mlir new file mode 100644 index 0000000000..bc879d0103 --- /dev/null +++ b/examples/VectorExpDialect/vector-exp-iteration.mlir @@ -0,0 +1,57 @@ +memref.global "private" @gv : memref<10xf32> = dense<[0. , 1. , 2. , 3. , 4. , 5. , 6. , 7. , 8. 
, 9.]> + +func.func private @printMemrefF32(memref<*xf32>) + +func.func @main() -> i32 { + %c0 = arith.constant 0 : index + + // --------------------------------------------------------------------------- + // Iteration Pattern for RVV Dynamic Vector Length + // --------------------------------------------------------------------------- + + // 1. Get the total length of the workload. + %mem = memref.get_global @gv : memref<10xf32> + %print_mem = memref.cast %mem : memref<10xf32> to memref<*xf32> + %vl_total = memref.dim %mem, %c0 : memref<10xf32> + + // 2. Set the scale factor, iteration step, and mask. + %vs = vector.vscale + %factor = arith.constant 2 : index + %vl_step = arith.muli %vs, %factor : index + %mask = arith.constant dense<1> : vector<[2]xi1> + %vl_total_i32 = index.casts %vl_total : index to i32 + %vl_step_i32 = index.casts %vl_step : index to i32 + + // 3. Perform the vectorization. + %iter_vl = scf.for %i = %c0 to %vl_total step %vl_step + iter_args(%iter_vl_i32 = %vl_total_i32) -> (i32) { + + %load_vec1 = vector_exp.predication %mask, %iter_vl_i32 : vector<[2]xi1>, i32 { + %ele = vector.load %mem[%i] : memref<10xf32>, vector<[2]xf32> + vector.yield %ele : vector<[2]xf32> + } : vector<[2]xf32> + + %load_vec2 = vector_exp.predication %mask, %iter_vl_i32 : vector<[2]xi1>, i32 { + %ele = vector.load %mem[%i] : memref<10xf32>, vector<[2]xf32> + vector.yield %ele : vector<[2]xf32> + } : vector<[2]xf32> + + %res = "llvm.intr.vp.fadd" (%load_vec1, %load_vec2, %mask, %iter_vl_i32) : + (vector<[2]xf32>, vector<[2]xf32>, vector<[2]xi1>, i32) -> vector<[2]xf32> + + vector_exp.predication %mask, %iter_vl_i32 : vector<[2]xi1>, i32 { + vector.store %res, %mem[%i] : memref<10xf32>, vector<[2]xf32> + vector.yield + } : () -> () + + // Update dynamic vector length. 
+ %new_vl = arith.subi %iter_vl_i32, %vl_step_i32 : i32 + scf.yield %new_vl : i32 + } + + // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, 8, 9] + call @printMemrefF32(%print_mem) : (memref<*xf32>) -> () + + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/examples/lit.cfg.py b/examples/lit.cfg.py index a1527a03a3..c1c4c05bd6 100644 --- a/examples/lit.cfg.py +++ b/examples/lit.cfg.py @@ -43,6 +43,7 @@ 'BuddyBert', 'BuddyMobileNetV3', 'BuddyResNet18', + 'BuddyGPU', 'ConvOpt', 'DAPDialect', 'DIPDialect', diff --git a/flake.lock b/flake.lock index 7bdd046777..bd79922394 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1694529238, - "narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=", + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", "owner": "numtide", "repo": "flake-utils", - "rev": "ff7b65b44d01cf9ba6a71320833626af21126384", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", "type": "github" }, "original": { @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1699099776, - "narHash": "sha256-X09iKJ27mGsGambGfkKzqvw5esP1L/Rf8H3u3fCqIiU=", + "lastModified": 1722813957, + "narHash": "sha256-IAoYyYnED7P8zrBFMnmp7ydaJfwTnwcnqxUElC1I26Y=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "85f1ba3e51676fa8cc604a3d863d729026a6b8eb", + "rev": "cb9a96f23c491c081b38eab96d22fa958043c9fa", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 8f94e2aec0..c3af6d9d59 100644 --- a/flake.nix +++ b/flake.nix @@ -9,36 +9,17 @@ outputs = { self, nixpkgs, flake-utils }@inputs: let overlay = import ./nix/overlay.nix; - pkgsForSys = system: import nixpkgs { overlays = [ overlay ]; inherit system; }; in flake-utils.lib.eachDefaultSystem (system: let - pkgs = pkgsForSys system; - mkLLVMShell = pkgs.mkShell.override { stdenv = pkgs.llvmPkgs.stdenv; }; + pkgs = import nixpkgs { overlays = [ overlay ]; inherit system; }; in { # 
Help other use packages in this flake legacyPackages = pkgs; - devShells.default = mkLLVMShell { - buildInputs = with pkgs; [ - # buddy-mlir build tools - cmake - ninja - python3 - llvmPkgs.bintools # For ld.lld - - # buddy-mlir libraries - libjpeg - libpng - zlib-ng - ]; - - postHook = '' - export PATH="${pkgs.clang-tools}/bin:$PATH" - ''; - }; + devShells.default = pkgs.buddy-mlir.devShell; formatter = pkgs.nixpkgs-fmt; }) // diff --git a/frontend/Interfaces/buddy/Core/Container.h b/frontend/Interfaces/buddy/Core/Container.h index db8b66c179..6e3ff18d53 100644 --- a/frontend/Interfaces/buddy/Core/Container.h +++ b/frontend/Interfaces/buddy/Core/Container.h @@ -132,7 +132,7 @@ MemRef::MemRef(intptr_t sizes[N], T init) : MemRef(sizes) { template MemRef::MemRef(intptr_t sizes[N], bool needMalloc, intptr_t offset) - : offset(offset), aligned(nullptr), allocated(nullptr) { + : allocated(nullptr), aligned(nullptr), offset(offset) { for (size_t i = 0; i < N; i++) { this->sizes[i] = sizes[i]; } @@ -152,7 +152,7 @@ MemRef::MemRef(std::vector sizes, T init) : MemRef(sizes) { template MemRef::MemRef(std::vector sizes, bool needMalloc, intptr_t offset) - : offset(offset), aligned(nullptr), allocated(nullptr) { + : allocated(nullptr), aligned(nullptr), offset(offset) { if (sizes.size() != N) { throw std::runtime_error("Invalid number of dimensions."); } diff --git a/frontend/Interfaces/buddy/DAP/AudioContainer.h b/frontend/Interfaces/buddy/DAP/AudioContainer.h index 9bc9245742..7c3901e733 100644 --- a/frontend/Interfaces/buddy/DAP/AudioContainer.h +++ b/frontend/Interfaces/buddy/DAP/AudioContainer.h @@ -14,6 +14,13 @@ // //===----------------------------------------------------------------------===// // +// The audio decoding process in this file references the `AudioFile` library, +// which is hereby acknowledged. 
+// For the license of the `AudioFile` library, +// please see: https://github.com/adamstark/AudioFile/blob/master/LICENSE +// +//===----------------------------------------------------------------------===// +// // Audio container descriptor. // //===----------------------------------------------------------------------===// @@ -21,79 +28,592 @@ #ifndef FRONTEND_INTERFACES_BUDDY_DAP_AUDIOCONTAINER #define FRONTEND_INTERFACES_BUDDY_DAP_AUDIOCONTAINER -#include "AudioFile.h" #include "buddy/Core/Container.h" +#include +#include +#include +#include +#include namespace dap { - -// Audio container. -// - T represents the type of the elements. -// - N represents the number of audio channels (Normally would be 1 or 2). -// If N is smaller than channels from the file, only previous N channels will be -// manipulated. -template class Audio { +template class Audio : public MemRef { public: - Audio() : audioFile(), data(nullptr) {} - explicit Audio(std::string filename) : audioFile(filename), data(nullptr) {} - void fetchMetadata(const AudioFile &aud); - bool save(std::string filename); - AudioFile &getAudioFile() { - moveToAudioFile(); - return audioFile; - } - MemRef &getMemRef() { - moveToMemRef(); - return *data; - } - -protected: - void moveToMemRef(); - void moveToAudioFile(); - AudioFile audioFile; - MemRef *data; + // Constructor to initialize the Audio MemRef object with a file name. + Audio(std::string filename); + // Constructor to convert MemRef object to Audio MemRef object. Member + // variables are initialized with default values. + Audio(MemRef &&memref) noexcept; + + // Retrieve the name of the audio format. + std::string getFormatName() const { + switch (this->audioFormat) { + case AudioFormat::WAV: + return "WAV"; + default: + return "Unsupported format"; + } + } + // Returns the number of bits per sample. + int getBitDepth() const { return static_cast(this->bitsPerSample); } + // Returns the number of samples per channel. 
+ size_t getSamplesNum() const { return this->numSamples; } + // Returns the number of audio channels. + int getChannelsNum() const { return static_cast(this->numChannels); } + // Returns the sampling rate in samples per second. + int getSampleRate() const { return static_cast(this->sampleRate); } + + // Sets the number of bits per sample. + void setBitDepth(int bitDepth) { + this->bitsPerSample = static_cast(bitDepth); + } + // Sets the number of samples per channel. + void setSamplesNum(size_t samplesNum) { this->numSamples = samplesNum; } + // Sets the number of audio channels. + void setChannelsNum(int channelsNum) { + this->numChannels = static_cast(channelsNum); + } + // Sets the sampling rate in samples per second. + void setSampleRate(int sampleRate) { + this->sampleRate = static_cast(sampleRate); + } + + // Create an Audio File with file name and format. + bool saveToFile(std::string filename, std::string format); + +private: + // Sample bit depth. + uint16_t bitsPerSample; + // Number of samples per channel. + size_t numSamples; + // Number of audio channels. + uint16_t numChannels; + // Samples per second (Hz). + uint32_t sampleRate; + // Enum to represent supported audio formats. + enum class AudioFormat { + ERROR, // Represents an error or unsupported format. + WAV, // WAV format. + } audioFormat; + // Enum to represent byte order of data. + enum class Endianness { LittleEndian, BigEndian }; + + // Decoders for multiple audio file formats. + // Decode a WAV file into MemRef format. + bool decodeWaveFile(const std::vector &fileData); + + // Encoders for multiple audio file formats. + // Encode a MemRef into WAV format. + bool EncodeWaveFile(std::vector &fileData); + + // Helper functions for decoding and data manipulation + // Find the index of a specified chunk in the audio file. 
+ size_t getIndexOfChunk(const std::vector &fileData, + const std::string &chunkHeaderID, size_t startIndex, + Endianness endianness = Endianness::LittleEndian); + // Convert four bytes to a 32-bit integer according to byte order of data. + int32_t fourBytesToI32(const std::vector &fileData, + size_t startIndex, + Endianness endianness = Endianness::LittleEndian); + // Convert two bytes to a 16-bit integer according to byte order of data. + int16_t twoBytesToI16(const std::vector &fileData, size_t startIndex, + Endianness endianness = Endianness::LittleEndian); + // Normalize 8-bit unsigned integer sample to a range of -1.0 to 1.0. + T oneByteToSample(uint8_t data) { + return static_cast(data - 128) / static_cast(128.); + } + // Normalize 16-bit signed integer sample to a range of -1.0 to 1.0. + T twoBytesToSample(int16_t data) { + return static_cast(data) / static_cast(32768.); + } + + // Helper functions for encoding and data manipulation. + // Converts each character in the string to a byte. + void stringToBytes(std::vector &fileData, const std::string &str) { + for (size_t i = 0; i < str.size(); i++) + fileData.push_back(static_cast(str[i])); + } + // Converts a 32-bit integer to four bytes according to byte order of data. + void i32ToFourBytes(std::vector &fileData, int32_t num, + Endianness endianness = Endianness::LittleEndian); + // Converts a 16-bit integer to two bytes according to byte order of data. + void i16ToTwoBytes(std::vector &fileData, int16_t num, + Endianness endianness = Endianness::LittleEndian); + // Converts an audio sample to a 8-bit PCM format (one byte). + uint8_t sampleToOneByte(T sample); + // Converts an audio sample to a 16-bit PCM format (two bytes). 
+ int16_t sampleToI16(T sample); }; -template bool Audio::save(std::string filename) { - if (!this->audioFile.samples) { - auto temp = this->data->release(); - if constexpr (std::is_same_v) { - for (int i = 0; i < audioFile.numSamples; i++) { - if (temp[i] != temp[i]) { // To handle NaN values - temp[i] = 0.9999999; - } else { // Clamp the values between -1.0 to 1.0 - temp[i] = std::clamp(temp[i], float(-1.0), float(0.9999999)); - } +// Audio Container Constructor. +// Constructs an audio container object from the audio file path. +template Audio::Audio(std::string filePath) { + // --------------------------------------------------------------------------- + // 1. Read the audio file into a std::vector. + // --------------------------------------------------------------------------- + // Open the file in binary mode and position the file pointer at the end of + // the file. + std::ifstream file(filePath, std::ios::binary | std::ios::ate); + // Check if the file was successfully opened. + if (!file) { + throw std::runtime_error("Error: Unable to open file at " + filePath); + } + // Get the size of the file. + size_t dataLength = file.tellg(); + // Move file pointer to the beginning of the file. + file.seekg(0, std::ios::beg); + // Create a vector to store the data. + std::vector fileData(dataLength); + // Read the data. + if (!file.read(reinterpret_cast(fileData.data()), dataLength)) { + throw std::runtime_error("Error: Unable to read data from " + filePath); + } + // --------------------------------------------------------------------------- + // 2. Determine the audio format and decode the audio data into MemRef. + // --------------------------------------------------------------------------- + std::string header(fileData.begin(), fileData.begin() + 4); + // Check the file header to determine the format. 
+ if (header == "RIFF") { + this->audioFormat = AudioFormat::WAV; + bool success = decodeWaveFile(fileData); + if (!success) { + this->audioFormat = AudioFormat::ERROR; + throw std::runtime_error("Failed to decode WAV file from " + filePath); + }; + } else { + this->audioFormat = AudioFormat::ERROR; + throw std::runtime_error("Unsupported audio format detected in file " + + filePath); + } +} + +// Constructs an audio container object from a MemRef object. Initializes +// metadata with default values. +template +Audio::Audio(MemRef &&memref) noexcept + : MemRef(std::move(memref)), bitsPerSample(0), numSamples(0), + numChannels(0), sampleRate(0) {} + +// Create Audio File. +// Save Audio MemRef to the specified file path using the desired format. +template +bool Audio::saveToFile(std::string filePath, std::string format) { + // --------------------------------------------------------------------------- + // 1. Determine the audio format and encode the MemRef into file data. + // --------------------------------------------------------------------------- + // Convert the string to lowercase before comparison, ensuring that case + // variations are handled without repeating conditions. + std::transform(format.begin(), format.end(), format.begin(), ::tolower); + // Vector for storing bytes in a specific audio format. + std::vector fileData; + // Select encoder. + if (format == "wav" || format == "wave") { + bool success = EncodeWaveFile(fileData); + if (!success) { + std::cerr << "Failed to encode WAVE file." << std::endl; + return false; + } + } else { + std::cerr << "Unsupported: The encoding method for " << format + << " format is not yet supported." << std::endl; + return false; + } + // --------------------------------------------------------------------------- + // 2. Write std::vector into audio file. 
+ // --------------------------------------------------------------------------- + std::ofstream outputFile(filePath, std::ios::binary); + + if (outputFile.is_open()) { + for (size_t i = 0; i < fileData.size(); i++) { + char value = static_cast(fileData[i]); + outputFile.write(&value, sizeof(char)); + } + + outputFile.close(); + + return true; + } + + return false; +} + +// WAV Audio File Decoder +template +bool Audio::decodeWaveFile(const std::vector &fileData) { + // This container class only cares about the data and key information in the + // audio file, so only the format and data chunk are decoded here. + // Find the starting indices of critical chunks within the WAV file. + size_t indexOfFormatChunk = getIndexOfChunk(fileData, "fmt ", 12); + size_t indexOfDataChunk = getIndexOfChunk(fileData, "data", 12); + + // Decode the 'format' chunk to obtain format specifications. + // Format sub-chunk: + // sub-chunk ID: char[4] | 4 bytes | "fmt " + // sub-chunk size: uint32_t | 4 bytes + // audio format: uint16_t | 2 bytes | 1 for PCM + // number of channels: uint16_t | 2 bytes + // sample rate: uint32_t | 4 bytes + // byte rate: uint32_t | 4 bytes + // block align: uint16_t | 2 bytes + // bits per sample: uint16_t | 2 bytes + std::string formatChunkID(fileData.begin() + indexOfFormatChunk, + fileData.begin() + indexOfFormatChunk + 4); + // uint32_t fmtChunkSize = fourBytesToI32(fileData, indexOfFormatChunk + 4); + // uint16_t audioFormat = twoBytesToI16(fileData, indexOfFormatChunk + 8); + this->numChannels = twoBytesToI16(fileData, indexOfFormatChunk + 10); + this->sampleRate = fourBytesToI32(fileData, indexOfFormatChunk + 12); + // byteRate = sampleRate * numChannels * bitsPerSample / 8 + // uint32_t byteRate = fourBytesToI32(fileData, indexOfFormatChunk + 16); + // blockAlign = numChannels * bitsPerSample / 8 + uint16_t blockAlign = twoBytesToI16(fileData, indexOfFormatChunk + 20); + this->bitsPerSample = twoBytesToI16(fileData, indexOfFormatChunk + 22); + 
uint16_t numBytesPerSample = static_cast(this->bitsPerSample) / 8; + + // Decode `data` chunk. + // Data sub-chunk: + // sub-chunk ID: char[4] | 4 bytes | "data" + // sub-chunk size: uint32_t | 4 bytes + // data | remains + std::string dataChunkID(fileData.begin() + indexOfDataChunk, + fileData.begin() + indexOfDataChunk + 4); + int32_t dataChunkSize = fourBytesToI32(fileData, indexOfDataChunk + 4); + this->numSamples = dataChunkSize / blockAlign; + // size_t numSamplesPerChannels = this->numSamples / this->numChannels; + size_t samplesStartIndex = indexOfDataChunk + 8; + + // Audio MemRef layout defaults to 1 dimension. + // Sample values from multiple channels are stored together. + if (N == 1) { + this->sizes[0] = this->numSamples; + } else if (N == this->numChannels) { + // TODO: add conversion from 1 dimension to multi-dimension + std::cerr << "Unsupported: The MemRef layout of multi-dimensional channels " + "is not yet supported." + << std::endl; + return false; + } else { + std::cerr << "Error: dimension mismatch (audio file channel: " + << this->numChannels << " MemRef layout channel: " << N << ")" + << std::endl; + return false; + } + + // Allocate memory for MemRef. 
+ this->setStrides(); + size_t size = this->product(this->sizes); + this->allocated = (T *)malloc(sizeof(T) * size); + this->aligned = this->allocated; + + // Sample data type: 8 bit + if (this->bitsPerSample == 8) { + size_t memrefIndex = 0; + for (size_t i = 0; i < this->numSamples; i++) { + for (size_t channel = 0; channel < this->numChannels; channel++) { + size_t sampleIndex = + samplesStartIndex + (blockAlign * i) + channel * numBytesPerSample; + this->aligned[memrefIndex] = oneByteToSample(fileData[sampleIndex]); + memrefIndex++; + } + } + } + // Sample data type: 16 bit + else if (this->bitsPerSample == 16) { + size_t memrefIndex = 0; + for (size_t i = 0; i < this->numSamples; i++) { + for (size_t channel = 0; channel < this->numChannels; channel++) { + size_t sampleIndex = + samplesStartIndex + (blockAlign * i) + channel * numBytesPerSample; + int16_t dataTwoBytes = twoBytesToI16(fileData, sampleIndex); + this->aligned[memrefIndex] = twoBytesToSample(dataTwoBytes); + memrefIndex++; + } + } + } + // Other data types are not currently supported. + else { + std::cerr << "Unsupported audio data type." << std::endl; + return false; + } + + return true; +} + +// WAV Audio File Encoder +template +bool Audio::EncodeWaveFile(std::vector &fileData) { + // Encode the 'header' chunk. + // RIFF chunk descriptor + // chunk ID: char[4] | 4 bytes | "RIFF" + // chunk size: uint32_t | 4bytes + // format: char[4] | 4 bytes | "WAVE" + stringToBytes(fileData, "RIFF"); + int16_t audioFormat = this->bitsPerSample == 32 ? 0 : 1; + // Size for 'format' sub-chunk, doesn't include metadata length. + int32_t formatChunkSize = audioFormat == 1 ? 16 : 18; + // Size for 'data' sub-chunk, doesn't include metadata length. 
+ int32_t dataChunkSize = + this->numSamples * this->numChannels * this->bitsPerSample / 8; + // The file size in bytes include header chunk size(4, not counting RIFF and + // WAVE), the format chunk size(formatChunkSize and 8 bytes for metadata), the + // data chunk size(dataChunkSize and 8 bytes for metadata). + int32_t fileSizeInBytes = 4 + formatChunkSize + 8 + dataChunkSize + 8; + i32ToFourBytes(fileData, fileSizeInBytes); + stringToBytes(fileData, "WAVE"); + + // Encode the 'format' chunk. + // Format sub-chunk: + // sub-chunk ID: char[4] | 4 bytes | "fmt " + // sub-chunk size: uint32_t | 4 bytes + // audio format: uint16_t | 2 bytes | 1 for PCM + // number of channels: uint16_t | 2 bytes + // sample rate: uint32_t | 4 bytes + // byte rate: uint32_t | 4 bytes + // block align: uint16_t | 2 bytes + // bits per sample: uint16_t | 2 bytes + stringToBytes(fileData, "fmt "); + i32ToFourBytes(fileData, formatChunkSize); + i16ToTwoBytes(fileData, audioFormat); + i16ToTwoBytes(fileData, static_cast(this->numChannels)); + i32ToFourBytes(fileData, static_cast(this->sampleRate)); + int16_t numBytesPerBlock = + static_cast(dataChunkSize / this->numSamples); + int32_t numBytesPerSecond = + static_cast(this->sampleRate * numBytesPerBlock); + i32ToFourBytes(fileData, numBytesPerSecond); + i16ToTwoBytes(fileData, numBytesPerBlock); + i16ToTwoBytes(fileData, static_cast(this->bitsPerSample)); + + // Encode the 'data' chunk. 
+ // Data sub-chunk: + // sub-chunk ID: char[4] | 4 bytes | "data" + // sub-chunk size: uint32_t | 4 bytes + // data | remains + stringToBytes(fileData, "data"); + i32ToFourBytes(fileData, dataChunkSize); + + // Sample data length: 8 bit + if (this->bitsPerSample == 8) { + size_t memrefIndex = 0; + for (size_t i = 0; i < this->numSamples; i++) { + for (size_t channel = 0; channel < this->numChannels; channel++) { + uint8_t byte = sampleToOneByte(this->aligned[memrefIndex]); + fileData.push_back(byte); + memrefIndex++; + } + } + } + // Sample data length: 16 bit + else if (this->bitsPerSample == 16) { + size_t memrefIndex = 0; + for (size_t i = 0; i < this->numSamples; i++) { + for (size_t channel = 0; channel < this->numChannels; channel++) { + int16_t sampleAsInt = sampleToI16(this->aligned[memrefIndex]); + i16ToTwoBytes(fileData, sampleAsInt); + memrefIndex++; } } - this->audioFile.samples.reset(temp); } - return this->audioFile.save(filename); + // Other data length are not yet supported. + else { + std::cerr << "Unsupported audio data length: " << this->bitsPerSample + << " bit" << std::endl; + return false; + } + + return true; +} + +// Locates the start index of a specific chunk in a WAV file data buffer. +// Params: +// fileData: Vector containing the raw binary data of the WAV file. +// chunkHeaderID: The 4-byte identifier for the chunk (e.g., "fmt ", "data"). +// startIndex: Index to start the search from within the fileData. +// endianness: Byte order used to interpret multi-byte values in the chunk +// size. +// Returns: +// The index of the start of the chunk or 0 if not found. 
+template +size_t Audio::getIndexOfChunk(const std::vector &fileData, + const std::string &chunkHeaderID, + size_t startIndex, Endianness endianness) { + constexpr int dataLen = 4; + if (chunkHeaderID.size() != dataLen) { + assert(false && "Chunk header ID must be exactly 4 characters long"); + return -1; + } + size_t i = startIndex; + while (i < fileData.size() - dataLen) { + // Check if the current bytes match the chunk header ID + if (memcmp(&fileData[i], chunkHeaderID.data(), dataLen) == 0) { + return i; + } + // Skip to the next chunk: advance by the size of the current chunk + // Move index to the size part of the chunk + i += dataLen; + // Prevent reading beyond vector size + if (i + dataLen > fileData.size()) + break; + // Get the size of the chunk. + auto chunkSize = fourBytesToI32(fileData, i, endianness); + if (chunkSize < 0) { + assert(false && "Invalid chunk size encountered"); + return -1; + } + // Move to the next chunk header + i += (dataLen + chunkSize); + } + // Return 0 if the chunk is not found + return 0; +} + +// Converts four bytes from the file data array to a 32-bit integer based on +// endianness. Params: +// fileData: Vector containing the raw binary data. +// startIndex: Index in fileData where the 4-byte sequence starts. +// endianness: Specifies the byte order (LittleEndian or BigEndian). +// Returns: +// The 32-bit integer converted from the byte sequence. +template +int32_t Audio::fourBytesToI32(const std::vector &fileData, + size_t startIndex, Endianness endianness) { + // Ensure the index is within the bounds to prevent out-of-range access. + if (startIndex + 3 >= fileData.size()) { + throw std::out_of_range("Index out of range for fourBytesToI32"); + } + // Use uint32_t to prevent sign extension and maintain accurate binary + // representation during bit operations. 
+ uint32_t result; + if (endianness == Endianness::LittleEndian) { + result = (static_cast(fileData[startIndex + 3]) << 24) | + (static_cast(fileData[startIndex + 2]) << 16) | + (static_cast(fileData[startIndex + 1]) << 8) | + static_cast(fileData[startIndex]); + } else { + result = (static_cast(fileData[startIndex]) << 24) | + (static_cast(fileData[startIndex + 1]) << 16) | + (static_cast(fileData[startIndex + 2]) << 8) | + static_cast(fileData[startIndex + 3]); + } + // Convert the unsigned result to signed int32_t to match the function's + // return type. + return static_cast(result); +} + +// Converts two bytes from the file data array to a 16-bit integer based on +// endianness. Params: +// fileData: Vector containing the raw binary data. +// startIndex: Index in fileData where the 2-byte sequence starts. +// endianness: Specifies the byte order (LittleEndian or BigEndian). +// Returns: +// The 16-bit integer converted from the byte sequence. +template +int16_t Audio::twoBytesToI16(const std::vector &fileData, + size_t startIndex, Endianness endianness) { + // Ensure the index is within the bounds to prevent out-of-range access. + if (startIndex + 1 >= fileData.size()) { + throw std::out_of_range("Index out of range for twoBytesToI16"); + } + // Use uint16_t to prevent sign extension and maintain accurate binary + // representation during bit operations. + uint16_t result; + if (endianness == Endianness::LittleEndian) { + result = (static_cast(fileData[startIndex + 1]) << 8) | + static_cast(fileData[startIndex]); + } else { + result = (static_cast(fileData[startIndex]) << 8) | + static_cast(fileData[startIndex + 1]); + } + // Convert the unsigned result to signed int16_t to match the function's + // return type. + return static_cast(result); } +// Converts a 32-bit integer to four bytes based on endianness. +// Params: +// fileData: Vector containing the raw binary data. +// num: A 32-bit integer prepared for convertion. 
+// endianness: Specifies the byte order (LittleEndian or BigEndian). template -void Audio::fetchMetadata(const AudioFile &aud) { - this->audioFile.setBitDepth(aud.getBitDepth()); - this->audioFile.setSampleRate(aud.getSampleRate()); - this->audioFile.numSamples = aud.numSamples; - this->audioFile.numChannels = aud.numChannels; - this->audioFile.setAudioBuffer(nullptr); +void Audio::i32ToFourBytes(std::vector &fileData, int32_t num, + Endianness endianness) { + // Use uint8_t to prevent sign extension and maintain accurate binary + // representation during bit operations. + uint8_t bytes[4]; + if (endianness == Endianness::LittleEndian) { + bytes[3] = static_cast(num >> 24) & 0xFF; + bytes[2] = static_cast(num >> 16) & 0xFF; + bytes[1] = static_cast(num >> 8) & 0xFF; + bytes[0] = static_cast(num) & 0xFF; + } else { + bytes[0] = static_cast(num >> 24) & 0xFF; + bytes[1] = static_cast(num >> 16) & 0xFF; + bytes[2] = static_cast(num >> 8) & 0xFF; + bytes[3] = static_cast(num) & 0xFF; + } + // Append the converted bytes to the fileData vector. + for (size_t i = 0; i < 4; i++) + fileData.push_back(bytes[i]); } -template void Audio::moveToMemRef() { - if (data) - delete data; - intptr_t sizes[N]; - for (size_t i = 0; i < N; ++i) { - sizes[i] = audioFile.numSamples; - } - data = new MemRef(audioFile.samples, sizes); + +// Converts a 16-bit integer to two bytes based on endianness. +// Params: +// fileData: Vector containing the raw binary data. +// num: A 16-bit integer prepared for convertion. +// endianness: Specifies the byte order (LittleEndian or BigEndian). +template +void Audio::i16ToTwoBytes(std::vector &fileData, int16_t num, + Endianness endianness) { + // Use uint8_t to prevent sign extension and maintain accurate binary + // representation during bit operations. 
+ uint8_t bytes[2]; + if (endianness == Endianness::LittleEndian) { + bytes[1] = static_cast(num >> 8) & 0xFF; + bytes[0] = static_cast(num) & 0xFF; + } else { + bytes[0] = static_cast(num >> 8) & 0xFF; + bytes[1] = static_cast(num) & 0xFF; + } + // Append the converted bytes to the fileData vector. + fileData.push_back(bytes[0]); + fileData.push_back(bytes[1]); } -template void Audio::moveToAudioFile() { - if (data) { - auto temp = data->release(); - audioFile.setAudioBuffer(temp); + +// Converts an audio sample to a 8-bit PCM format (one byte). +// Params: +// sample: A floating-point value representing the audio sample. +// Returns: +// An 8-bit unsigned integer representing the sample as one byte. +template uint8_t Audio::sampleToOneByte(T sample) { + if (std::isnan(sample)) { + // Handle corner case for NaN (Not a Number). Reset NaN to 1. + sample = static_cast(1.); + } else { + // Restricts sample value in range [-1.0, 1.0]. + sample = std::min(sample, static_cast(1.)); + sample = std::max(sample, static_cast(-1.)); } + // Converts a normalized floating-point audio sample to the [0, 255] range. + sample = (sample + static_cast(1.)) / static_cast(2.); + return static_cast(sample * 255.); } +// Converts an audio sample to a 16-bit PCM format (two bytes). +// Params: +// sample: A floating-point value representing the audio sample. +// Returns: +// A 16-bit signed integer representing the sample as two bytes. +template int16_t Audio::sampleToI16(T sample) { + if (std::isnan(sample)) { + // Handle corner case for NaN (Not a Number). Reset NaN to 1. + sample = static_cast(1.); + } else { + // Restricts sample value in range [-1.0, 1.0]. + sample = std::min(sample, static_cast(1.)); + sample = std::max(sample, static_cast(-1.)); + } + // Converts a normalized floating-point audio sample to the [-32767, 32767] + // range. 
+ return static_cast(sample * 32767.); +} } // namespace dap #endif // FRONTEND_INTERFACES_BUDDY_DAP_AUDIOCONTAINER diff --git a/frontend/Interfaces/buddy/DAP/DAP.h b/frontend/Interfaces/buddy/DAP/DAP.h index 5f86565ccf..48fd2afbf3 100644 --- a/frontend/Interfaces/buddy/DAP/DAP.h +++ b/frontend/Interfaces/buddy/DAP/DAP.h @@ -21,10 +21,10 @@ #ifndef FRONTEND_INTERFACES_BUDDY_DAP_DAP #define FRONTEND_INTERFACES_BUDDY_DAP_DAP -#include "AudioFile.h" #include "buddy/DAP/AudioContainer.h" #include "buddy/DAP/DSP/Biquad.h" #include "buddy/DAP/DSP/FIR.h" #include "buddy/DAP/DSP/IIR.h" +#include "buddy/DAP/DSP/WhisperPreprocess.h" #endif // FRONTEND_INTERFACES_BUDDY_DAP_DAP diff --git a/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h b/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h new file mode 100644 index 0000000000..d0d1d8fb63 --- /dev/null +++ b/frontend/Interfaces/buddy/DAP/DSP/WhisperPreprocess.h @@ -0,0 +1,61 @@ +//===- WhisperPreprocess.h ------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// Header file for whisper preprocess operation in DAP dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef FRONTEND_INTERFACES_BUDDY_DAP_DSP_WHISPERPREPROCESS +#define FRONTEND_INTERFACES_BUDDY_DAP_DSP_WHISPERPREPROCESS + +#include "buddy/Core/Container.h" +#include "buddy/DAP/AudioContainer.h" +#include "buddy/DAP/DSP/IIRDesign.h" + +namespace dap { +namespace detail { +// Declare the whisper preprocess C interface. +extern "C" { +// The original MLIR function: +// ```mlir +// func.func @buddy_whisperPreprocess(%in : memref) -> +// memref<1x80x3000xf32> +// ``` +// +// After applying the '-llvm-request-c-wrappers' pass: +// The result of the function (memref<1x80x3000xf32>) is modified to be the +// first operand. +void _mlir_ciface_buddy_whisperPreprocess(MemRef *outputFeatures, + MemRef *inputRawSpeech); + +void _mlir_ciface_buddy_RFFT(MemRef *inputRawSpeech); + +} +} // namespace detail + +// Function for Whisper preprocess +void whisperPreprocess(MemRef *inputRawSpeech, + MemRef *outputFeatures) { + detail::_mlir_ciface_buddy_whisperPreprocess(outputFeatures, inputRawSpeech); +} + + +void RFFT(MemRef *inputRawSpeech) { + detail::_mlir_ciface_buddy_RFFT(inputRawSpeech); +} +} // namespace dap + +#endif // FRONTEND_INTERFACES_BUDDY_DAP_DSP_WHISPERPREPROCESS diff --git a/frontend/Interfaces/buddy/DIP/DIP.h b/frontend/Interfaces/buddy/DIP/DIP.h index bf61c701b8..8598b61fc5 100644 --- a/frontend/Interfaces/buddy/DIP/DIP.h +++ b/frontend/Interfaces/buddy/DIP/DIP.h @@ -23,6 +23,8 @@ #include "buddy/Core/Container.h" #include "buddy/DIP/ImageContainer.h" +#include "buddy/DIP/ImgContainer.h" +#include #include namespace dip { // Availale types of boundary extrapolation techniques provided in DIP dialect. @@ -69,10 +71,28 @@ void _mlir_ciface_resize_2d_nearest_neighbour_interpolation( Img *input, float horizontalScalingFactor, float verticalScalingFactor, MemRef *output); +// Declare the Resize4D C interface. 
+void _mlir_ciface_resize_4d_nhwc_nearest_neighbour_interpolation( + Img *input, float horizontalScalingFactor, + float verticalScalingFactor, MemRef *output); + +void _mlir_ciface_resize_4d_nchw_nearest_neighbour_interpolation( + dip::Image *input, float horizontalScalingFactor, + float verticalScalingFactor, MemRef *output); + void _mlir_ciface_resize_2d_bilinear_interpolation( Img *input, float horizontalScalingFactor, float verticalScalingFactor, MemRef *output); +// Declare the Resize4D C interface. +void _mlir_ciface_resize_4d_nhwc_bilinear_interpolation( + Img *input, float horizontalScalingFactor, + float verticalScalingFactor, MemRef *output); + +void _mlir_ciface_resize_4d_nchw_bilinear_interpolation( + dip::Image *input, float horizontalScalingFactor, + float verticalScalingFactor, MemRef *output); + // Declare the Morphology 2D C interface. void _mlir_ciface_erosion_2d_constant_padding( Img input, MemRef *kernel, MemRef *output, @@ -201,6 +221,49 @@ inline MemRef Resize2D_Impl(Img *input, return output; } + +// Helper function for applying 4D resize operation on images. 
+inline MemRef Resize4D_NHWC_Impl(Img *input, + INTERPOLATION_TYPE type, + std::vector scalingRatios, + intptr_t outputSize[4]) { + MemRef output(outputSize); + + if (type == INTERPOLATION_TYPE::NEAREST_NEIGHBOUR_INTERPOLATION) { + detail::_mlir_ciface_resize_4d_nhwc_nearest_neighbour_interpolation( + input, scalingRatios[0], scalingRatios[1], &output); + } else if (type == INTERPOLATION_TYPE::BILINEAR_INTERPOLATION) { + detail::_mlir_ciface_resize_4d_nhwc_bilinear_interpolation( + input, scalingRatios[0], scalingRatios[1], &output); + } else { + throw std::invalid_argument( + "Please chose a supported type of interpolation " + "(Nearest neighbour interpolation or Bilinear interpolation)\n"); + } + + return output; +} + +inline MemRef Resize4D_NCHW_Impl(dip::Image *input, + INTERPOLATION_TYPE type, + std::vector scalingRatios, + intptr_t outputSize[4]) { + MemRef output(outputSize); + + if (type == INTERPOLATION_TYPE::NEAREST_NEIGHBOUR_INTERPOLATION) { + detail::_mlir_ciface_resize_4d_nchw_nearest_neighbour_interpolation( + input, scalingRatios[0], scalingRatios[1], &output); + } else if (type == INTERPOLATION_TYPE::BILINEAR_INTERPOLATION) { + detail::_mlir_ciface_resize_4d_nchw_bilinear_interpolation( + input, scalingRatios[0], scalingRatios[1], &output); + } else { + throw std::invalid_argument( + "Please chose a supported type of interpolation " + "(Nearest neighbour interpolation or Bilinear interpolation)\n"); + } + + return output; +} } // namespace detail // User interface for 2D Correlation. @@ -325,25 +388,44 @@ inline MemRef Rotate2D(Img *input, float angle, // User interface for 2D Resize. 
inline MemRef Resize2D(Img *input, INTERPOLATION_TYPE type, - std::vector scalingRatios) { - if (scalingRatios[0] <= 0 || scalingRatios[1] <= 0) { - throw std::invalid_argument( - "Please enter positive values of scaling ratios.\n" - "Note : scaling ratio = " - "output_image_dimension / input_image_dimension\n"); + std::vector size) { + if (size.size() != 2) { + throw std::invalid_argument("Dimension of an image should be 2\n"); } - std::reverse(scalingRatios.begin(), scalingRatios.end()); - - intptr_t outputSize[2] = {static_cast(std::round( - input->getSizes()[0] * scalingRatios[0])), - static_cast(std::round( - input->getSizes()[1] * scalingRatios[1]))}; + intptr_t outputSize[2] = {size[0], size[1]}; + return detail::Resize2D_Impl(input, type, + {(float)input->getSizes()[0] / (float)size[0], + (float)input->getSizes()[1] / (float)size[1]}, + outputSize); +} - scalingRatios[0] = 1 / scalingRatios[0]; - scalingRatios[1] = 1 / scalingRatios[1]; +// User interface for 4D Resize. +inline MemRef Resize4D_NHWC(Img *input, + INTERPOLATION_TYPE type, + std::vector size) { + if (size.size() != 4) { + throw std::invalid_argument("Dimension of an image should be 4\n"); + } + intptr_t outputSize[4] = {size[0], size[1], size[2], size[3]}; + return detail::Resize4D_NHWC_Impl( + input, type, + {(float)input->getSizes()[1] / (float)size[1], + (float)input->getSizes()[2] / (float)size[2]}, + outputSize); +} - return detail::Resize2D_Impl( - input, type, {scalingRatios[1], scalingRatios[0]}, outputSize); +inline MemRef Resize4D_NCHW(dip::Image *input, + INTERPOLATION_TYPE type, + std::vector size) { + if (size.size() != 4) { + throw std::invalid_argument("Dimension of an image should be 4\n"); + } + intptr_t outputSize[4] = {size[0], size[1], size[2], size[3]}; + return detail::Resize4D_NCHW_Impl( + input, type, + {(float)input->getSizes()[2] / (float)size[2], + (float)input->getSizes()[3] / (float)size[3]}, + outputSize); } // User interface for 2D Resize. 
diff --git a/frontend/Interfaces/buddy/DIP/ImageContainer.h b/frontend/Interfaces/buddy/DIP/ImageContainer.h index a613ceb351..4470e4a443 100644 --- a/frontend/Interfaces/buddy/DIP/ImageContainer.h +++ b/frontend/Interfaces/buddy/DIP/ImageContainer.h @@ -141,6 +141,7 @@ Img::Img(T *data, intptr_t sizes[N]) : MemRef(data, sizes) {} #ifdef BUDDY_ENABLE_OPENCV // Image Constructor from OpenCV Mat. + template Img::Img(cv::Mat image, intptr_t sizes[N], bool norm) : MemRef() { if (image.channels() == 1) { @@ -189,14 +190,16 @@ Img::Img(cv::Mat image, intptr_t sizes[N], bool norm) : MemRef() { this->allocated = new T[size]; this->aligned = this->allocated; size_t k = 0; + //NCHW Layout for (int batch = 0; batch < this->sizes[0]; batch++) { for (int channel = 0; channel < this->sizes[1]; channel++) { + T *chandata = image.ptr(batch, channel); for (int row = 0; row < this->sizes[2]; row++) { for (int col = 0; col < this->sizes[3]; col++) { if (norm) { - this->aligned[k] = (T)image.at(row, col) / 255; + this->aligned[k] = chandata[row * this->sizes[3] + col] / 255; } else { - this->aligned[k] = (T)image.at(row, col); + this->aligned[k] = chandata[row * this->sizes[3] + col]; } k++; } @@ -205,6 +208,7 @@ Img::Img(cv::Mat image, intptr_t sizes[N], bool norm) : MemRef() { } } } + #endif template int Img::channels() { diff --git a/frontend/Interfaces/buddy/DIP/ImgContainer.h b/frontend/Interfaces/buddy/DIP/ImgContainer.h new file mode 100644 index 0000000000..2525641bff --- /dev/null +++ b/frontend/Interfaces/buddy/DIP/ImgContainer.h @@ -0,0 +1,621 @@ +//===- ImgContainer.h -----------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// Image container descriptor (without OpenCV dependency). +// +//===----------------------------------------------------------------------===// + +#ifndef FRONTEND_INTERFACES_BUDDY_DIP_IMGCONTAINER +#define FRONTEND_INTERFACES_BUDDY_DIP_IMGCONTAINER + +#include "buddy/Core/Container.h" +#include +#include +#include +#include +#ifdef BUDDY_ENABLE_PNG +#include +#endif + +namespace dip { +enum ImageModes { + DIP_GRAYSCALE = 0, + DIP_RGB = 1, +}; + +inline bool ifBigEndian() { + int num = 1; + char *ptr = (char *)# + return (*ptr == 0); +} + +template class Image : public MemRef { +public: + // Constructor initializes the image by loading from a file. + // Params: + // filename: Specifies the path to the image file. + // mode: Specifies the image mode (e.g., DIP_GRAYSCALE, DIP_RGB). + // norm: Indicates whether to normalize pixel values (default is false). + Image(std::string filename, ImageModes mode, bool norm = false); + + // Retrieves the name of the current image format as a string. + std::string getFormatName() const { + switch (this->imageFormat) { + case ImageFormat::BMP: + return "BMP"; + case ImageFormat::PNG: + return "PNG"; + default: + return "Unsupported format"; + } + } + // Returns the width of the image in pixels. + size_t getWidth() const { return this->width; } + // Returns the height of the image in pixels. + size_t getHeight() const { return this->height; } + // Returns the bit depth of the image. 
+ int getBitDepth() const { return this->bitDepth; } + +private: + // Enum to represent supported image formats. + enum class ImageFormat { + ERROR, // Represents an error or unsupported format. + BMP, // BMP file format. + PNG, // PNG file format. + } imageFormat; + // Mode of the image (e.g., DIP_GRAYSCALE, DIP_RGB). + ImageModes imageMode; + // Width of the image in pixels. + size_t width; + // Height of the image in pixels. + size_t height; + // Bit depth of the image. + int bitDepth; + // Normalization flag. + bool isNorm; + // Determines the image format from raw file data. + void determineFormat(const std::vector &fileData); + // Decodes a BMP image from raw file data. + bool decodeBMP(const std::vector &fileData); + // Decodes a PNG image from raw file data. +#ifdef BUDDY_ENABLE_PNG + bool decodePNG(const std::vector &fileData); +#endif +}; + +// Image Container Constructor +// Constructs an image container object from the image file path. +template +Image::Image(std::string filePath, ImageModes mode, bool norm) + : imageMode(mode), isNorm(norm) { + // --------------------------------------------------------------------------- + // 1. Read the image file into a std::vector. + // --------------------------------------------------------------------------- + // Open the file in binary mode and position the file pointer at the end of + // the file. + std::ifstream file(filePath, std::ios::binary | std::ios::ate); + // Check if the file was successfully opened. + if (!file) { + throw std::runtime_error("Error: Unable to open file at " + filePath); + } + // Get the size of the file. + size_t dataLength = file.tellg(); + // Move file pointer to the beginning of the file. + file.seekg(0, std::ios::beg); + // Create a vector to store the data. + std::vector fileData(dataLength); + // Read the data. 
+ if (!file.read(reinterpret_cast(fileData.data()), dataLength)) { + throw std::runtime_error("Error: Unable to read data from " + filePath); + } + file.close(); + + // --------------------------------------------------------------------------- + // 2. Determine the image format and decode the image data into MemRef. + // --------------------------------------------------------------------------- + // Determine the image format from the raw file data. + determineFormat(fileData); + if (this->imageFormat == ImageFormat::BMP) { + bool success = decodeBMP(fileData); + if (!success) { + this->imageFormat = ImageFormat::ERROR; + throw std::runtime_error("Failed to decode BMP file from " + filePath); + }; + } +#ifdef BUDDY_ENABLE_PNG + else if (this->imageFormat == ImageFormat::PNG) { + bool success = decodePNG(fileData); + if (!success) { + this->imageFormat = ImageFormat::ERROR; + throw std::runtime_error("Failed to decode PNG file from " + filePath); + }; + } +#endif + else { + throw std::runtime_error("Unsupported image file format."); + } +} + +// Determines the image format by inspecting the header of the file data. +template +void Image::determineFormat(const std::vector &fileData) { + std::array pngHeader = {0x89, 0x50, 0x4E, 0x47, + 0x0D, 0x0A, 0x1A, 0x0A}; + if (fileData.size() > 2 && fileData[0] == 'B' && fileData[1] == 'M') { + this->imageFormat = ImageFormat::BMP; + } else if (fileData.size() > 7 && + std::memcmp(fileData.data(), pngHeader.data(), 8) == 0) { + this->imageFormat = ImageFormat::PNG; + } else { + this->imageFormat = ImageFormat::ERROR; + } +} + +// BMP Image File Decoder +template +bool Image::decodeBMP(const std::vector &fileData) { + // Check if the provided data is large enough to contain a minimal BMP header + // (54 bytes). 
+ if (fileData.size() < 54) { + throw std::runtime_error("Invalid BMP File: too small to contain header"); + } + + // Extract image information from BMP header + this->width = *reinterpret_cast(&fileData[18]); + this->height = *reinterpret_cast(&fileData[22]); + this->bitDepth = *reinterpret_cast(&fileData[28]); + uint32_t compression = *reinterpret_cast(&fileData[30]); + size_t pixelDataOffset = *reinterpret_cast(&fileData[10]); + + // Currently, only the BI_RGB (value 0) or BI_BITFIELDS (value 3) compression + // method is supported. + if (compression != 0 && compression != 3) { + std::cerr << "Unsupported BMP file compression method." << std::endl; + return false; + } + + // Currently, only the NCHW format with 4 dimensions is supported. + if (N == 4) { + if (this->imageMode == ImageModes::DIP_GRAYSCALE) { + // TODO: Add batch setting. + this->sizes[0] = 1; + this->sizes[1] = 1; + this->sizes[2] = this->height; + this->sizes[3] = this->width; + this->setStrides(); + size_t size = this->product(this->sizes); + this->allocated = (T *)malloc(sizeof(T) * size); + this->aligned = this->allocated; + // Fullfill data to memref container. + size_t memrefIndex = 0; + if (this->bitDepth == 32) { + // BMP file is upside-down storage. + for (size_t i = this->height; i > 0; i--) { + for (size_t j = 0; j < this->width; j++) { + // Locate the current pixel. + size_t pixelIndex = + pixelDataOffset + (((i - 1) * this->width) + j) * 4; + // Extract the blue, green, and red value from the current pixel. + int bluePixel = + *reinterpret_cast(&fileData[pixelIndex]); + int greenPixel = + *reinterpret_cast(&fileData[pixelIndex + 1]); + int redPixel = + *reinterpret_cast(&fileData[pixelIndex + 2]); + // Calculate the gray scale value. + int grayScaleValue = static_cast( + 0.299 * redPixel + 0.587 * greenPixel + 0.114 * bluePixel); + // Store the gray scale value into memref container. + this->aligned[memrefIndex] = + this->isNorm ? 
static_cast(grayScaleValue) / 255 + : static_cast(grayScaleValue); + memrefIndex++; + } + } + } else if (this->bitDepth == 24) { + // BMP file is upside-down storage. + for (size_t i = this->height; i > 0; i--) { + for (size_t j = 0; j < this->width; j++) { + // Locate the current pixel. + size_t pixelIndex = + pixelDataOffset + (((i - 1) * this->width) + j) * 3; + // Extract the blue, green, and red value from the current pixel. + int bluePixel = + *reinterpret_cast(&fileData[pixelIndex]); + int greenPixel = + *reinterpret_cast(&fileData[pixelIndex + 1]); + int redPixel = + *reinterpret_cast(&fileData[pixelIndex + 2]); + // Calculate the gray scale value. + int grayScaleValue = static_cast( + 0.299 * redPixel + 0.587 * greenPixel + 0.114 * bluePixel); + // Store the gray scale value into memref container. + this->aligned[memrefIndex] = + this->isNorm ? static_cast(grayScaleValue) / 255 + : static_cast(grayScaleValue); + memrefIndex++; + } + } + } else if (this->bitDepth == 16) { + // BMP file is upside-down storage. + for (size_t i = this->height; i > 0; i--) { + for (size_t j = 0; j < this->width; j++) { + // Locate the current pixel. 
+ size_t pixelIndex = + pixelDataOffset + (((i - 1) * this->width) + j) * 2; + // Extract the 16-bit pixel value + uint16_t pixelValue = + *reinterpret_cast(&fileData[pixelIndex]); + + int redPixel, greenPixel, bluePixel; + if (compression == 3) { + // Extract individual color components (assuming RGB565 format) + redPixel = (pixelValue >> 11) & 0x1F; + greenPixel = (pixelValue >> 5) & 0x3F; + bluePixel = pixelValue & 0x1F; + + // Expand to 8-bit per channel + redPixel = (redPixel << 3) | (redPixel >> 2); + greenPixel = (greenPixel << 2) | (greenPixel >> 4); + bluePixel = (bluePixel << 3) | (bluePixel >> 2); + } else { + // Extract individual color components for 5-5-5 format + redPixel = (pixelValue >> 10) & 0x1F; + greenPixel = (pixelValue >> 5) & 0x1F; + bluePixel = pixelValue & 0x1F; + + // Expand to 8-bit per channel + redPixel = (redPixel << 3) | (redPixel >> 2); + greenPixel = (greenPixel << 3) | (greenPixel >> 2); + bluePixel = (bluePixel << 3) | (bluePixel >> 2); + } + // Calculate the gray scale value. + int grayScaleValue = static_cast( + 0.299 * redPixel + 0.587 * greenPixel + 0.114 * bluePixel); + // Store the gray scale value into memref container. + this->aligned[memrefIndex] = + this->isNorm ? static_cast(grayScaleValue) / 255 + : static_cast(grayScaleValue); + memrefIndex++; + } + } + } else { + std::cerr << "Unsupported: " << this->bitDepth << "bit depth." + << std::endl; + return false; + } + } else if (this->imageMode == ImageModes::DIP_RGB) { + // TODO: Add batch setting. + this->sizes[0] = 1; + this->sizes[1] = 3; + this->sizes[2] = this->height; + this->sizes[3] = this->width; + this->setStrides(); + size_t size = this->product(this->sizes); + this->allocated = (T *)malloc(sizeof(T) * size); + this->aligned = this->allocated; + // Fullfill data to memref container. + size_t memrefIndex = 0; + size_t colorStride = this->height * this->width; + + if (this->bitDepth == 32) { + // BMP file is upside-down storage. 
+ for (size_t i = height; i > 0; i--) { + for (size_t j = 0; j < width; j++) { + // Locate the current pixel. + size_t pixelIndex = pixelDataOffset + (((i - 1) * width) + j) * 4; + // Extract the blue, green, and red value from the current pixel. + int bluePixel = + *reinterpret_cast(&fileData[pixelIndex]); + int greenPixel = + *reinterpret_cast(&fileData[pixelIndex + 1]); + int redPixel = + *reinterpret_cast(&fileData[pixelIndex + 2]); + // Store the values into memref container as RGB order. (BGR -> RGB) + this->aligned[memrefIndex] = this->isNorm + ? static_cast(redPixel) / 255 + : static_cast(redPixel); + this->aligned[memrefIndex + colorStride] = + this->isNorm ? static_cast(greenPixel) / 255 + : static_cast(greenPixel); + this->aligned[memrefIndex + 2 * colorStride] = + this->isNorm ? static_cast(bluePixel) / 255 + : static_cast(bluePixel); + memrefIndex++; + } + } + } else if (this->bitDepth == 24) { + // BMP file is upside-down storage. + for (size_t i = height; i > 0; i--) { + for (size_t j = 0; j < width; j++) { + // Locate the current pixel. + size_t pixelIndex = pixelDataOffset + (((i - 1) * width) + j) * 3; + // Extract the blue, green, and red value from the current pixel. + int bluePixel = + *reinterpret_cast(&fileData[pixelIndex]); + int greenPixel = + *reinterpret_cast(&fileData[pixelIndex + 1]); + int redPixel = + *reinterpret_cast(&fileData[pixelIndex + 2]); + // Store the values into memref container as RGB order. (BGR -> RGB) + this->aligned[memrefIndex] = this->isNorm + ? static_cast(redPixel) / 255 + : static_cast(redPixel); + this->aligned[memrefIndex + colorStride] = + this->isNorm ? static_cast(greenPixel) / 255 + : static_cast(greenPixel); + this->aligned[memrefIndex + 2 * colorStride] = + this->isNorm ? static_cast(bluePixel) / 255 + : static_cast(bluePixel); + memrefIndex++; + } + } + } else if (this->bitDepth == 16) { + // BMP file is upside-down storage. 
+ for (size_t i = height; i > 0; i--) { + for (size_t j = 0; j < width; j++) { + // Locate the current pixel. + size_t pixelIndex = pixelDataOffset + (((i - 1) * width) + j) * 2; + // Extract the 16-bit pixel value + uint16_t pixelValue = + *reinterpret_cast(&fileData[pixelIndex]); + + int redPixel, greenPixel, bluePixel; + if (compression == 3) { + // Extract individual color components (assuming RGB565 format) + redPixel = (pixelValue >> 11) & 0x1F; + greenPixel = (pixelValue >> 5) & 0x3F; + bluePixel = pixelValue & 0x1F; + + // Expand to 8-bit per channel + redPixel = (redPixel << 3) | (redPixel >> 2); + greenPixel = (greenPixel << 2) | (greenPixel >> 4); + bluePixel = (bluePixel << 3) | (bluePixel >> 2); + } else { + // Extract individual color components for 5-5-5 format + redPixel = (pixelValue >> 10) & 0x1F; + greenPixel = (pixelValue >> 5) & 0x1F; + bluePixel = pixelValue & 0x1F; + + // Expand to 8-bit per channel + redPixel = (redPixel << 3) | (redPixel >> 2); + greenPixel = (greenPixel << 3) | (greenPixel >> 2); + bluePixel = (bluePixel << 3) | (bluePixel >> 2); + } + + // Store the values into memref container as RGB order. (BGR -> RGB) + this->aligned[memrefIndex] = this->isNorm + ? static_cast(redPixel) / 255 + : static_cast(redPixel); + this->aligned[memrefIndex + colorStride] = + this->isNorm ? static_cast(greenPixel) / 255 + : static_cast(greenPixel); + this->aligned[memrefIndex + 2 * colorStride] = + this->isNorm ? static_cast(bluePixel) / 255 + : static_cast(bluePixel); + memrefIndex++; + } + } + } else { + std::cerr << "Unsupported: " << this->bitDepth << "bit depth." + << std::endl; + return false; + } + } + } else { + std::cerr << "Unsupported: " << N << " dimension layout." << std::endl; + return false; + } + return true; +} + +// PNG Image File Decoder +#ifdef BUDDY_ENABLE_PNG +template +bool Image::decodePNG(const std::vector &fileData) { + // Check if the provided data is large enough to contain a minimal PNG header + // (33 bytes). 
+ if (fileData.size() < 33) { + throw std::runtime_error("Invalid PNG File: too small to contain header"); + } + + // Extract image information from PNG header. Convert Big-Endian to + // Little-Endian. + this->width = (fileData[16] << 24) | (fileData[17] << 16) | + (fileData[18] << 8) | fileData[19]; + this->height = (fileData[20] << 24) | (fileData[21] << 16) | + (fileData[22] << 8) | fileData[23]; + this->bitDepth = *reinterpret_cast(&fileData[24]); + int colorType = *reinterpret_cast(&fileData[25]); + uint8_t interlace = *reinterpret_cast(&fileData[28]); + + // Currently, only the NCHW format with 4 dimensions is supported. + if (N == 4) { + // use libpng to read png image. Initialize libpng parameters + png_structp png_ptr = + png_create_read_struct(PNG_LIBPNG_VER_STRING, 0, 0, 0); + if (!png_ptr) { + std::cerr << "png_ptr creation failed" << std::endl; + return false; + } + + png_infop info_ptr = png_create_info_struct(png_ptr); + if (!info_ptr) { + std::cerr << "png_infop creation failed" << std::endl; + return false; + } + + // Set jump point for error handling + if (setjmp(png_jmpbuf(png_ptr))) { + std::cerr << "error during PNG reading" << std::endl; + // close PNG reading and free memory + png_destroy_read_struct(&png_ptr, &info_ptr, NULL); + return false; + } + + // copy filedata. Read image data from memory. 
+ std::vector dataCopy = fileData; + png_set_read_fn( + png_ptr, &dataCopy, + [](png_structp png_ptr, png_bytep data, png_size_t length) { + std::vector *fileData = + static_cast *>(png_get_io_ptr(png_ptr)); + if (fileData->size() < length) { + png_error(png_ptr, "Read error from memory"); + } + std::copy(fileData->begin(), fileData->begin() + length, data); + fileData->erase(fileData->begin(), fileData->begin() + length); + }); + + png_read_info(png_ptr, info_ptr); + + // Convert big or little Endian and convert 16 bits to 8 bits + if (this->bitDepth == 16) + png_set_strip_16(png_ptr); + else if (!ifBigEndian()) + png_set_swap(png_ptr); + + // Remove alpha channel + if (colorType & PNG_COLOR_MASK_ALPHA) + png_set_strip_alpha(png_ptr); + + // Convert palette to rgb + if (colorType == PNG_COLOR_TYPE_PALETTE) + png_set_palette_to_rgb(png_ptr); + + // Convert low depth grayscale to 8-bit grayscale + if ((colorType & PNG_COLOR_MASK_COLOR) == 0 && this->bitDepth < 8) +#if (PNG_LIBPNG_VER_MAJOR * 10000 + PNG_LIBPNG_VER_MINOR * 100 + \ + PNG_LIBPNG_VER_RELEASE >= \ + 10209) || \ + (PNG_LIBPNG_VER_MAJOR == 1 && PNG_LIBPNG_VER_MINOR == 0 && \ + PNG_LIBPNG_VER_RELEASE >= 18) + png_set_expand_gray_1_2_4_to_8(png_ptr); +#else + png_set_gray_1_2_4_to_8(png_ptr); +#endif + + // Processing interleaved PNG images + if (interlace) + png_set_interlace_handling(png_ptr); + + if (this->imageMode == ImageModes::DIP_GRAYSCALE) { + // TODO: Add batch setting. 
+ this->sizes[0] = 1; + this->sizes[1] = 1; + this->sizes[2] = this->height; + this->sizes[3] = this->width; + this->setStrides(); + size_t size = this->product(this->sizes); + this->allocated = (T *)malloc(sizeof(T) * size); + this->aligned = this->allocated; + + // RGB->Gray + if ((colorType & PNG_COLOR_MASK_COLOR) || + (colorType == PNG_COLOR_TYPE_PALETTE)) + png_set_rgb_to_gray(png_ptr, 1, 0.299, 0.587); + + // Update reading setting + png_read_update_info(png_ptr, info_ptr); + + // Allocate memory for libpng to read images + std::vector imgData(this->height * this->width); + std::vector row_pointers(this->height); + for (size_t y = 0; y < this->height; ++y) { + row_pointers[y] = imgData.data() + y * this->width; + } + + // Reading image + png_read_image(png_ptr, row_pointers.data()); + + // Fullfill data to memref container. + for (size_t i = 0; i < this->height; i++) + for (size_t j = 0; j < this->width; j++) { + size_t memrefIndex = i * this->width + j; + this->aligned[memrefIndex] = + this->isNorm ? static_cast(imgData[memrefIndex]) / 255 + : static_cast(imgData[memrefIndex]); + ; + } + } else if (this->imageMode == ImageModes::DIP_RGB) { + // TODO: Add batch setting. 
+ this->sizes[0] = 1; + this->sizes[1] = 3; + this->sizes[2] = this->height; + this->sizes[3] = this->width; + this->setStrides(); + size_t size = this->product(this->sizes); + this->allocated = (T *)malloc(sizeof(T) * size); + this->aligned = this->allocated; + size_t colorStride = this->height * this->width; + + // Gray->RGB + if (colorType & PNG_COLOR_TYPE_GRAY) + png_set_gray_to_rgb(png_ptr); + + // Update reading setting + png_read_update_info(png_ptr, info_ptr); + + // Allocate memory for libpng to read images + std::vector imgData(this->height * this->width * 3); + std::vector row_pointers(this->height); + for (size_t y = 0; y < this->height; ++y) { + row_pointers[y] = imgData.data() + y * this->width * 3; + } + + // Reading image + png_read_image(png_ptr, row_pointers.data()); + + // Separate pixel data by channel + size_t memrefIndex = 0; + for (size_t i = 0; i < this->height; i++) + for (size_t j = 0; j < this->width; j++) { + // Locate the current pixel. + size_t pixelIndex = ((i * width) + j) * 3; + // Extract the red, green, and blue value from the current pixel. + int redPixel = + *reinterpret_cast(&imgData[pixelIndex]); + int greenPixel = + *reinterpret_cast(&imgData[pixelIndex + 1]); + int bluePixel = + *reinterpret_cast(&imgData[pixelIndex + 2]); + // Store the values into memref container as RGB order. + this->aligned[memrefIndex] = this->isNorm + ? static_cast(redPixel) / 255 + : static_cast(redPixel); + this->aligned[memrefIndex + colorStride] = + this->isNorm ? static_cast(greenPixel) / 255 + : static_cast(greenPixel); + this->aligned[memrefIndex + 2 * colorStride] = + this->isNorm ? static_cast(bluePixel) / 255 + : static_cast(bluePixel); + memrefIndex++; + } + } + + // close PNG reading and free memory + png_destroy_read_struct(&png_ptr, &info_ptr, NULL); + } else { + std::cerr << "Unsupported: " << N << " dimension layout." 
<< std::endl; + return false; + } + return true; +} +#endif + +} // namespace dip + +#endif // FRONTEND_INTERFACES_BUDDY_DIP_IMGCONTAINER diff --git a/frontend/Interfaces/lib/CMakeLists.txt b/frontend/Interfaces/lib/CMakeLists.txt index 9f6f61b298..6a98a18b93 100644 --- a/frontend/Interfaces/lib/CMakeLists.txt +++ b/frontend/Interfaces/lib/CMakeLists.txt @@ -21,13 +21,13 @@ add_custom_command(OUTPUT DIP.o -finalize-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llc + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate --mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llc -mtriple=${BUDDY_TARGET_TRIPLE} -mattr=${BUDDY_OPT_ATTR} --filetype=obj -o ${CMAKE_CURRENT_BINARY_DIR}/DIP.o - DEPENDS buddy-opt + DEPENDS mlir-translate llc buddy-opt ) add_library(BuddyLibDIP STATIC DIP.o) @@ -50,23 +50,42 @@ add_custom_command(OUTPUT DAP.o -llvm-request-c-wrappers -convert-func-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llc + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate --mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llc -mtriple=${BUDDY_TARGET_TRIPLE} -mattr=${BUDDY_OPT_ATTR} --filetype=obj -o ${CMAKE_CURRENT_BINARY_DIR}/DAP.o - DEPENDS buddy-opt + DEPENDS mlir-translate llc buddy-opt ) -add_library(BuddyLibDAP STATIC DAP.o) - -SET_TARGET_PROPERTIES(BuddyLibDAP PROPERTIES - LINKER_LANGUAGE CXX - ARCHIVE_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_DIRECTORY} +add_custom_command(OUTPUT DAP-extend.o + COMMAND ${CMAKE_BINARY_DIR}/bin/buddy-opt ${CMAKE_CURRENT_SOURCE_DIR}/DAP-extend.mlir + -extend-dap + -one-shot-bufferize + -convert-linalg-to-loops + -convert-scf-to-cf + -expand-strided-metadata + -lower-affine + -convert-vector-to-llvm + -memref-expand + -arith-expand + -convert-arith-to-llvm + -finalize-memref-to-llvm + -convert-math-to-llvm + -llvm-request-c-wrappers + -convert-func-to-llvm + -reconcile-unrealized-casts | + 
${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llc + -mtriple=${BUDDY_TARGET_TRIPLE} + -mattr=${BUDDY_OPT_ATTR} + -filetype=obj -relocation-model=pic + -o ${CMAKE_CURRENT_BINARY_DIR}/DAP-extend.o + DEPENDS mlir-translate llc buddy-opt ) - add_custom_command(OUTPUT DAPVectorization.o +add_custom_command(OUTPUT DAPVectorization.o COMMAND cat ${CMAKE_CURRENT_SOURCE_DIR}/DAP.mlir | sed 's/buddy_fir/buddy_fir_vectorization/' | sed 's/buddy_iir/buddy_iir_vectorization/' | @@ -83,18 +102,22 @@ SET_TARGET_PROPERTIES(BuddyLibDAP PROPERTIES -llvm-request-c-wrappers -convert-func-to-llvm -reconcile-unrealized-casts | - ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir | - ${LLVM_MLIR_BINARY_DIR}/llc + ${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir | + ${LLVM_TOOLS_BINARY_DIR}/llc -mtriple=${BUDDY_TARGET_TRIPLE} -mattr=${BUDDY_OPT_ATTR} -filetype=obj -o ${CMAKE_CURRENT_BINARY_DIR}/DAPVectorization.o - DEPENDS buddy-opt + DEPENDS mlir-translate llc buddy-opt ) -add_library(BuddyLibDAPVectorization STATIC DAPVectorization.o) +add_library(BuddyLibDAP STATIC + DAP.o + DAP-extend.o + DAPVectorization.o + ) -SET_TARGET_PROPERTIES(BuddyLibDAPVectorization PROPERTIES +SET_TARGET_PROPERTIES(BuddyLibDAP PROPERTIES LINKER_LANGUAGE CXX ARCHIVE_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_DIRECTORY} ) diff --git a/frontend/Interfaces/lib/DAP-extend.mlir b/frontend/Interfaces/lib/DAP-extend.mlir new file mode 100644 index 0000000000..2c9b7a5a3b --- /dev/null +++ b/frontend/Interfaces/lib/DAP-extend.mlir @@ -0,0 +1,8 @@ +func.func @buddy_whisperPreprocess(%in : memref) -> memref<1x80x3000xf32> { + %out = dap.whisper_preprocess %in : memref to memref<1x80x3000xf32> + return %out : memref<1x80x3000xf32> +} +func.func @buddy_RFFT(%in : memref) -> () { + dap.rfft %in : memref + return +} diff --git a/frontend/Interfaces/lib/DIP.mlir b/frontend/Interfaces/lib/DIP.mlir index 406cb21b7c..3153d1ebe8 100644 --- a/frontend/Interfaces/lib/DIP.mlir +++ 
b/frontend/Interfaces/lib/DIP.mlir @@ -54,6 +54,30 @@ func.func @resize_2d_bilinear_interpolation(%inputImage : memref, %hori return } +func.func @resize_4d_nhwc_nearest_neighbour_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +{ + dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + return +} + +func.func @resize_4d_nhwc_bilinear_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +{ + dip.resize_4d_nhwc BILINEAR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + return +} + +func.func @resize_4d_nchw_nearest_neighbour_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +{ + dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + return +} + +func.func @resize_4d_nchw_bilinear_interpolation(%inputImage : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %outputImage : memref) attributes{llvm.emit_c_interface} +{ + dip.resize_4d_nchw BILINEAR_INTERPOLATION %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + return +} + func.func @erosion_2d_constant_padding(%inputImage : memref, %kernel : memref, %outputImage : memref, %copymemref : memref, %centerX : index, %centerY : index, %iterations : index, %constantValue: f32) attributes{llvm.emit_c_interface} { dip.erosion_2d %inputImage, %kernel, %outputImage, %copymemref, %centerX, %centerY, %iterations, %constantValue: memref, memref, memref, 
memref, index, index, index, f32 diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py index 8b1b954fdc..865614da5a 100644 --- a/frontend/Python/frontend.py +++ b/frontend/Python/frontend.py @@ -161,10 +161,14 @@ def __init__( "reciprocal.default": ReciprocalOp, "clamp_min.default": ClampMinOp, "clamp_max.default": ClampMaxOp, - "ge.Scalar": GreaterEqualOp, - "gt.Tensor": GreaterThanOp, + "randint.low": RandIntLowOp, "cos.default": CosOp, "sin.default": SinOp, + "argmax.default": ArgMaxOp, + "split.Tensor":SplitOp, + "max.default":MaxOp, + "gt.Scalar":GtOp, + "ge.Scalar": GeOp, } @property @@ -284,6 +288,7 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): fake_params, self._ops_registry, self._func_name, + self._verbose ) for gm_node in _gm.graph.nodes: node_users = [] @@ -323,7 +328,8 @@ def _compiler(_gm: torch.fx.GraphModule, _inputs: List[torch.Tensor]): else: tensor_meta = gm_node.meta.get("tensor_meta") val = gm_node.meta.get("val") - num_returns = len(gm_node.target._schema.returns) + # num_returns = len(gm_node.target._schema.returns) + num_returns = len(val) if isinstance(val, list) else len(gm_node.target._schema.returns) if num_returns == 1: node_dtype = self._torch_dtype_translate( str(tensor_meta.dtype) diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py index eb78c0ff33..ce35693efd 100644 --- a/frontend/Python/graph/graph.py +++ b/frontend/Python/graph/graph.py @@ -105,6 +105,7 @@ def __init__( fake_params: List[TensorMeta], ops_registry: dict, func_name: str, + verbose=False ) -> None: """ Initializes the Graph. 
@@ -125,6 +126,7 @@ def __init__( self._fake_params = fake_params self.device = "cpu" self._imported_module = None + self._verbose = verbose self._ops_registry = ops_registry self._func_name = func_name self._ctx = ir.Context() @@ -237,6 +239,7 @@ def lower_to_top_level_ir(self): self._inputs, self._func_name, self._ops_registry, + verbose=self._verbose ) self._imported_module = fx_importer.import_graph() outputs = fx_importer.get_output_nodes() @@ -347,6 +350,7 @@ def __init__( func_name: str, ops_registry: dict, do_param_pack: bool = False, + verbose=False ): """ Initializes the buddy Graph importer. @@ -364,6 +368,7 @@ def __init__( self._func_name = func_name self._params = params self._inputs = inputs + self._verbose = verbose self._do_param_pack = do_param_pack self._param_packs = [] self._num_input_visited = 0 @@ -451,9 +456,11 @@ def import_graph(self) -> ir.Module: @func.FuncOp.from_py_func(*arguments, name=self._func_name) def generated_func(*args): args_list = list(args) + func_op = self._module.body.operations[0] for node in self._body: if node in extern_func: continue + old_ops = [op for op in func_op.body.blocks[0].operations] if isinstance(node, OutputOp): output_node_args = node.args returns = [ @@ -471,7 +478,20 @@ def generated_func(*args): ] else: self._import_op(node) - + new_ops = [op for op in func_op.body.blocks[0].operations] + if self._verbose: + print('='*20 + "Graph Node" + "="*20) + print("Node: " + node.name) + print("Type: " + str(node._op_type)) + print("Arguments: " + str(node.args)) + print("Parents: " + str(node._parents)) + print("Children: " + str(node._children)) + print('-'*20 + "MLIR OPS" + '-'*20) + for op in new_ops: + if op not in old_ops: + print(op) + print("") + return self._symbol_table.get(("output", 0)) return self._module diff --git a/frontend/Python/graph/graph_driver.py b/frontend/Python/graph/graph_driver.py index 50a8869d5a..dd37aa12bd 100644 --- a/frontend/Python/graph/graph_driver.py +++ 
b/frontend/Python/graph/graph_driver.py @@ -142,7 +142,7 @@ def build_subgraph_by_group(self): # Create subgraph and add it to the dictionary subgraph = Graph( - subgraph_input, [], self._graph._ops_registry, subgraph_name + subgraph_input, [], self._graph._ops_registry, subgraph_name, verbose=self._graph._verbose ) subgraph.body = subgraph_body for op in subgraph_body: @@ -172,6 +172,7 @@ def construct_main_graph(self, do_param_pack=False): self._graph._fake_params, self._graph._ops_registry, self._graph._func_name, + self._graph._verbose ) # Adding FuncOp nodes for each subgraph diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py index a9cd18520d..9c2618e3c6 100644 --- a/frontend/Python/graph/operation.py +++ b/frontend/Python/graph/operation.py @@ -494,25 +494,49 @@ def __init__(self) -> None: self._op_type = OpType.ElementwiseType -class GreaterEqualOp(Op): +class RandIntLowOp(Op): def __init__(self) -> None: super().__init__() - self._op_type = OpType.BroadcastType + self._op_type = OpType.PlaceholderType -class GreaterThanOp(Op): +class CosOp(Op): def __init__(self) -> None: super().__init__() - self._op_type = OpType.BroadcastType + self._op_type = OpType.ElementwiseType -class CosOp(Op): +class SinOp(Op): def __init__(self) -> None: super().__init__() self._op_type = OpType.ElementwiseType -class SinOp(Op): +class ArgMaxOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ReduceType + + +class SplitOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ReshapeType + + +class MaxOp(Op): def __init__(self) -> None: super().__init__() self._op_type = OpType.ElementwiseType + + +class GtOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ElementwiseType + + +class GeOp(Op): + def __init__(self) -> None: + super().__init__() + self._op_type = OpType.ElementwiseType \ No newline at end of file diff --git 
a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index 61f6a5b54a..ac7d34c99c 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -19,14 +19,14 @@ # ===--------------------------------------------------------------------------- from .. import Graph -from ..operation import PlaceholderOp, OpType +from ..operation import * from .. import DeviceType # TODO: classify op type for op fusion # OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType] # OP_TYPE_UNFUSABLE = [OpType.Unfusable, OpType.ConcatType] # OP_TYPE_FUSABLE_BY_SPECIFIC_PASS = [] -# ANCHOR_OP_TYPE = [] +# ANCHOR_OP_TYPE = [] def simply_fuse(graph: Graph): """ @@ -47,5 +47,3 @@ def simply_fuse(graph: Graph): graph.op_groups = {} graph.op_groups["subgraph0"] = new_op_group graph.group_map_device = {"subgraph0": device} - - diff --git a/frontend/Python/ops/linalg.py b/frontend/Python/ops/linalg.py index 3cdae6c0a8..ec67b39a19 100644 --- a/frontend/Python/ops/linalg.py +++ b/frontend/Python/ops/linalg.py @@ -22,7 +22,7 @@ import mlir.ir as ir from mlir.dialects import tosa, linalg, arith, tensor, math -import copy +import copy, array, sys import numpy import functools @@ -1853,8 +1853,127 @@ def scalar_tensor_op(node: ScalarTensorOp, symbol_table): return op +def split_op(node: SplitOp, symbol_table): + """ + Split the input tensor into smaller tensors along the specified dimension. + + Args: + node (SplitOp): The split operation node with metadata. + symbol_table: Mapping of variable names to tensor references. + + Returns: + List[Tensor]: List of split tensors. 
+ """ + # Get the input tensor and parameters + input_tensor = symbol_table.get((str(node.args[0]), 0), node.args[0]) + split_size = node.args[1] # Size of each split tensor + input_shape = input_tensor.type.shape + dim = node.args[2] # Dimension to split along + if dim < 0: + dim += len(input_shape) + + split_count = (input_shape[dim] + split_size - 1) // split_size # Round up + tensor_rank = len(input_shape) + default_sizes = list(input_shape) + default_strides = [1] * tensor_rank + splits = [] + + for i in range(split_count): + # Calculate the offset along the specified dimension + offsets = [0] * tensor_rank + offsets[dim] = i * split_size + offsets_attr = ir._denseI64ArrayAttr(offsets, None) + + # Set the size along the split dimension; the last slice may be smaller than split_size + sizes = list(default_sizes) + sizes[dim] = min(split_size, input_shape[dim] - i * split_size) + sizes_attr = ir._denseI64ArrayAttr(sizes, None) + + # The stride for each dimension is set to 1 by default + strides = list(default_strides) + strides_attr = ir._denseI64ArrayAttr(strides, None) + + output_shape = list(node.tensor_meta["shape"][i]) + dtype = node.tensor_meta["dtype"][i] + mlir_dtype = mlir_element_type_get(dtype) + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + + slice_op = tensor.ExtractSliceOp( + tensor_type, + input_tensor, + [], + [], + [], + offsets_attr, + sizes_attr, + strides_attr, + ) + splits.append(slice_op.result) + + return splits + + +def max_op(node: MaxOp, symbol_table): + """ + Computes the maximum value from the input tensor and returns it as a tensor. + + Args: + node: The operation node containing input tensor information. + symbol_table: A table mapping identifiers to tensor values. + + Returns: + A tensor containing the maximum value extracted from the input tensor. 
+ """ + input1 = symbol_table.get((str(node.args[0]), 0), node.args[0]) + dtype = node.tensor_meta["dtype"] + mlir_dtype = mlir_element_type_get(dtype) + output_shape = node.tensor_meta["shape"] + tensor_type = ir.RankedTensorType.get(output_shape, mlir_dtype) + input_shape = ir.RankedTensorType(input1.type).shape + + total_size = 1 + for x in input_shape: + total_size *= x + reshape_op = tosa.ReshapeOp( + input1, memoryview(array.array("i", [total_size])) + ) + + argmax_result = ir.RankedTensorType.get([], ir.IntegerType.get_signless(64)) + argmax_op = tosa.ArgMaxOp(argmax_result, reshape_op.result, 0) + index_value = tensor.ExtractOp(argmax_op, []) + index = arith.IndexCastOp(ir.IndexType.get(), index_value) + max_value = tensor.ExtractOp(reshape_op, index) + output = tensor.FromElementsOp(tensor_type, max_value) + + return output + + +def gt_op(node: GtOp, symbol_table): + """ + Compares an input tensor with a scalar value to determine element-wise greater than. + + Parameters: + - node: The operation node containing arguments and metadata. + - symbol_table: A mapping of tensor names to their corresponding MLIR objects. + + Returns: + - cmp_op: A comparison operation result indicating where the input tensor's elements are greater than the scalar. 
+ """ + input_tensor = symbol_table.get((str(node.args[0]), 0), node.args[0]) + input_dtype = ir.RankedTensorType(input_tensor.type).element_type + input_shape = ir.RankedTensorType(input_tensor.type).shape + tensor_type = ir.RankedTensorType.get(input_shape, input_dtype) + scalar = arith.ConstantOp(input_dtype, node.args[1]) + rhs = tensor.SplatOp(tensor_type, scalar) + if str(input_dtype).find("i") != -1: + cmp_op = arith.CmpIOp(4, input_tensor, rhs) + else: + cmp_op = arith.CmpFOp(2, input_tensor, rhs) + + return cmp_op + def ge_op( - node: GreaterEqualOp, + node: GeOp, symbol_table: Dict[Tuple[str, int], ir.Operation], ): """ @@ -1926,99 +2045,6 @@ def ge_op( return op -def gt_op( - node: GreaterThanOp, - symbol_table: Dict[Tuple[str, int], ir.Operation], -): - """ - Import the tensor greater than operation. - From buddy GreaterThanOp to MLIR arith `constant` operation. - - Note: This op, campare two input nodes, and output bool tensor to represent - compare result. - Args: - node: Containing information from the input graph node. - symbol_table: A dictionary mapping symbols to their corresponding - operations. - - Returns: - op: The operation return the linalg.generic op. 
- """ - input1 = symbol_table.get((str(node.args[0]), 0)) - input2 = symbol_table.get((str(node.args[1]), 0)) - output_shape = list(node.tensor_meta["shape"]) - dtype = node.tensor_meta["dtype"] - value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 4) - shp1 = list(ir.RankedTensorType(ir.Value(input1).type).shape) - shp2 = list(ir.RankedTensorType(ir.Value(input2).type).shape) - dtype = mlir_element_type_get(dtype) - tensor_type = ir.RankedTensorType.get(output_shape, dtype) - output = tensor.EmptyOp(output_shape, dtype) - if len(shp1) < len(shp2): - if int(shp1[-1]) > 1 and shp2[-1] == 1: - generic_map = ir.AffineMap.get_permutation( - [i for i in range(len(shp2) + 1)] - ) - op = linalg.GenericOp( - [tensor_type], - [input1, input2], - [output], - ir.ArrayAttr.get( - [ - ir.AffineMapAttr.get( - generic_map.get_submap( - [ - i - for i in range( - len(shp2) - len(shp1), len(shp2) - ) - ] - ) - ), - ir.AffineMapAttr.get( - generic_map.get_submap( - [i for i in range(0, len(shp2) - 1)] - + [len(shp2)] - ) - ), - ir.AffineMapAttr.get( - generic_map.get_submap( - [i for i in range(0, len(shp2))] - ) - ), - ] - ), - ir.ArrayAttr.get( - [ir.Attribute.parse("#linalg.iterator_type")] - * len(shp2) - + [ir.Attribute.parse("#linalg.iterator_type")] - ), - ) - block = ir.Block.create_at_start( - op.region, - [ - ir.RankedTensorType(input2.type).element_type, - ir.RankedTensorType(input2.type).element_type, - dtype, - ], - ) - if ( - str(ir.RankedTensorType(input2.type).element_type).find("i") - != -1 - ): - cmpop = arith.CmpIOp( - value, block.arguments[0], block.arguments[1] - ) - else: - cmpop = arith.CmpFOp( - value, block.arguments[0], block.arguments[1] - ) - block.append(cmpop) - block.append(linalg.YieldOp([cmpop.result])) - - return op - - ops_registry = { "MatmulOp": matmul_op, "ArangeOp": arange_op, @@ -2051,6 +2077,8 @@ def gt_op( "AddOp": add_op, "WhereOp": where_op, "ScalarTensorOp": scalar_tensor_op, - "GreaterEqualOp": ge_op, - "GreaterThanOp": gt_op, + 
"SplitOp": split_op, + "MaxOp": max_op, + "GtOp": gt_op, + "GeOp":ge_op, } diff --git a/frontend/Python/ops/math.py b/frontend/Python/ops/math.py index cc2ab2634b..6ce2e868d5 100644 --- a/frontend/Python/ops/math.py +++ b/frontend/Python/ops/math.py @@ -32,6 +32,16 @@ def sqrt_op(node, symbol_table): op = math.SqrtOp(input_tensor) return op +def cos_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + op = math.CosOp(input_tensor) + return op + +def sin_op(node, symbol_table): + input_tensor = symbol_table.get((str(node.args[0]), 0)) + op = math.SinOp(input_tensor) + return op + def cos_op(node, symbol_table): input_tensor = symbol_table.get((str(node.args[0]), 0)) diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py index 9a3c3db4fe..962672d26c 100644 --- a/frontend/Python/ops/tosa.py +++ b/frontend/Python/ops/tosa.py @@ -60,6 +60,8 @@ MeanOp, ClampMinOp, ClampMaxOp, + RandIntLowOp, + ArgMaxOp, ) from .utils import * @@ -514,6 +516,7 @@ def convert_element_type_op(node: ConvertElementTypeOp, symbol_table): TensorDType.Float64: ir.F64Type.get(), TensorDType.Float32: ir.F32Type.get(), TensorDType.Float16: ir.F16Type.get(), + TensorDType.Int64: ir.IntegerType.get_signless(64), TensorDType.Int32: ir.IntegerType.get_signless(32), TensorDType.Bool: ir.IntegerType.get_signless(1), } @@ -802,7 +805,10 @@ def expand_op(node: ExpandOp, symbol_table) -> ir.Operation: result_element_type = ir.RankedTensorType( to_expand_tensor.type ).element_type - if result_element_type == ir.IntegerType.get_signless(1): + if result_element_type in ( + ir.IntegerType.get_signless(1), + ir.IntegerType.get_signless(64), + ): element = ir.IntegerAttr.get(result_element_type, 0) elif result_element_type == ir.F32Type.get(): element = ir.FloatAttr.get(result_element_type, 0.0) @@ -1434,6 +1440,52 @@ def clamp_max_op(node: ClampMaxOp, symbol_table): return op +def randint_low_op(node: RandIntLowOp, symbol_table): + """ + Generates a tensor of random 
integers within a specified range. + + Parameters: + - node (RandIntLowOp): Node containing the range and shape. + - symbol_table (dict): Maps identifiers to values. + + Returns: + - tosa.ConstOp: Tensor with random integers. + """ + min_value = symbol_table.get((str(node.args[0]), 0), node.args[0]) + max_value = symbol_table.get((str(node.args[1]), 0), node.args[1]) + shape = symbol_table.get((str(node.args[2]), 0), node.args[2]) + output = ir.DenseElementsAttr.get( + numpy.random.randint(min_value, max_value, size=shape) + ) + op = tosa.ConstOp(output) + return op + + +def argmax_op(node: ArgMaxOp, symbol_table): + """ + Compute the indices of the maximum values along the specified axis. + + Args: + node (ArgMaxOp): The ArgMax operation node with metadata. + symbol_table: Mapping of variable names to tensor references. + + Returns: + op: The constructed ArgMax operation. + """ + input_tensor = symbol_table.get((str(node.args[0]), 0), node.args[0]) + axis = symbol_table.get((str(node.args[1]), 0), node.args[1]) + input_shape = list(ir.RankedTensorType(input_tensor.type).shape) + + if axis < 0: + axis += len(input_shape) + + result_shape = input_shape[:axis] + input_shape[axis + 1 :] + result_type = ir.IntegerType.get_signless(64) + result = ir.RankedTensorType.get(result_shape, result_type) + op = tosa.ArgMaxOp(result, input_tensor, axis) + return op + + ops_registry = { "AddOp": add_op, "MulOp": mul_op, @@ -1468,4 +1520,6 @@ def clamp_max_op(node: ClampMaxOp, symbol_table): "MeanOp": mean_op, "ClampMinOp": clamp_min_op, "ClampMaxOp": clamp_max_op, + "RandIntLowOp": randint_low_op, + "ArgMaxOp": argmax_op, } diff --git a/midend/include/Dialect/CMakeLists.txt b/midend/include/Dialect/CMakeLists.txt index 8ab8f29f58..afedee5d69 100644 --- a/midend/include/Dialect/CMakeLists.txt +++ b/midend/include/Dialect/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(RVV) add_subdirectory(VectorExp) add_subdirectory(Gemmini) add_subdirectory(Sche) +add_subdirectory(GPU) diff 
--git a/midend/include/Dialect/DAP/DAPOps.td b/midend/include/Dialect/DAP/DAPOps.td index 9e7d894b90..d14ca5cfcd 100644 --- a/midend/include/Dialect/DAP/DAPOps.td +++ b/midend/include/Dialect/DAP/DAPOps.td @@ -50,8 +50,7 @@ def DAP_FirOp : DAP_Op<"fir"> { }]; } -def DAP_BiquadOp : DAP_Op<"biquad"> -{ +def DAP_BiquadOp : DAP_Op<"biquad"> { let summary = [{Biquad filter, a infinite impulse response (IIR) filter. ```mlir @@ -94,4 +93,49 @@ def DAP_IirOp : DAP_Op<"iir"> { }]; } +def DAP_RFFTOp : DAP_Op<"rfft"> { + let summary = "RFFT operation."; + let description = [{ + The RFFT algorithm is designed to handle real-valued input signals. Real + signals exhibit conjugate symmetry in the frequency domain, meaning that + the positive and negative frequency components are complex conjugates of + each other. This symmetry property allows the RFFT algorithm to compute + only half of the frequency spectrum, reducing computational costs. + + Example: + + ```mlir + dap.rfft %data : memref + ``` + }]; + + let arguments = (ins AnyRankedOrUnrankedMemRef:$memref); + let assemblyFormat = [{ + $memref attr-dict `:` type($memref) + }]; +} + +def DAP_WhisperPreprocessOp : DAP_Op<"whisper_preprocess"> { + let summary = "preprocessor for Whisper model"; + let description = [{ + Preprocessor for Whisper model, do features extraction for input audio. + Input MemRef stores the raw speech data, Output MemRef contains computed + features with shape memref<1x80x3000xf32>. 
+ + Example: + + ```mlir + %output = dap.whisper_preprocess %input : memref to memref<1x80x3000xf32> + ``` + }]; + + let arguments = (ins Arg:$memrefI); + let results = (outs Res:$memrefO); + let assemblyFormat = [{ + $memrefI attr-dict `:` type($memrefI) `to` type($memrefO) + }]; +} + #endif // DAP_DAPOPS_TD diff --git a/midend/include/Dialect/DIP/DIPOps.td b/midend/include/Dialect/DIP/DIPOps.td index aa32a60b8c..179e66359d 100644 --- a/midend/include/Dialect/DIP/DIPOps.td +++ b/midend/include/Dialect/DIP/DIPOps.td @@ -210,6 +210,92 @@ def DIP_Resize2DOp : DIP_Op<"resize_2d"> }]; } +def DIP_Resize4D_NHWCOp : DIP_Op<"resize_4d_nhwc"> +{ + let summary = [{ + This operation intends to provide a utility for resizing images using the DIP dialect. + Image resizing has many applications such as data augmentation, dimension adjustment in ML + models, etc. and can thus be used in native MLIR pipelines catering to above mentioned + use-cases. + + As of now, two different mechanisms for pixel interpolation are provided namely nearest + neighbour interpolation and bilinear interpolation. The user can specify the desired type of + interpolation via an attribute provided as argument to the operation. The operation also + expects scaling ratios (Input image dimension / Output image dimension) for both dimensions + of input and output images as arguments. + + The operation is flexible for its use with images of different sizes without necessarily + lowering it every time for each new image (Refer to the example provided in examples + directory for the DIP dialect). + + The processed image format is (batch, height, weight, channel). + + Syntax : + + ```mlir + dip.resize_4d_nhwc INTERPOLATION_TYPE %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + ``` + + where ```INTERPOLATION_TYPE``` can be ```NEAREST_NEIGHBOUR_INTERPOLATION``` or + ```BILINEAR_INTERPOLATION```. 
+ }]; + + let arguments = (ins Arg:$memrefI, + F32 : $horizontal_scaling_factor, + F32 : $vertical_scaling_factor, + Arg:$memrefO, + DIP_InterpolationAttr:$interpolation_type); + + let assemblyFormat = [{ + $interpolation_type $memrefI `,` $horizontal_scaling_factor `,` $vertical_scaling_factor `,` $memrefO attr-dict `:` type($memrefI) `,` type($horizontal_scaling_factor) `,` type($vertical_scaling_factor) `,` type($memrefO) + }]; +} + +def DIP_Resize4D_NCHWOp : DIP_Op<"resize_4d_nchw"> +{ + let summary = [{ + This operation intends to provide a utility for resizing images using the DIP dialect. + Image resizing has many applications such as data augmentation, dimension adjustment in ML + models, etc. and can thus be used in native MLIR pipelines catering to above mentioned + use-cases. + + As of now, two different mechanisms for pixel interpolation are provided namely nearest + neighbour interpolation and bilinear interpolation. The user can specify the desired type of + interpolation via an attribute provided as argument to the operation. The operation also + expects scaling ratios (Input image dimension / Output image dimension) for both dimensions + of input and output images as arguments. + + The operation is flexible for its use with images of different sizes without necessarily + lowering it every time for each new image (Refer to the example provided in examples + directory for the DIP dialect). + + The processed image format is (batch, channel, height, weight). + + Syntax : + + ```mlir + dip.resize_4d_nchw INTERPOLATION_TYPE %inputImage, %horizontal_scaling_factor, %vertical_scaling_factor, %outputImage : memref, f32, f32, memref + ``` + + where ```INTERPOLATION_TYPE``` can be ```NEAREST_NEIGHBOUR_INTERPOLATION``` or + ```BILINEAR_INTERPOLATION```. 
+ }]; + + let arguments = (ins Arg:$memrefI, + F32 : $horizontal_scaling_factor, + F32 : $vertical_scaling_factor, + Arg:$memrefO, + DIP_InterpolationAttr:$interpolation_type); + + let assemblyFormat = [{ + $interpolation_type $memrefI `,` $horizontal_scaling_factor `,` $vertical_scaling_factor `,` $memrefO attr-dict `:` type($memrefI) `,` type($horizontal_scaling_factor) `,` type($vertical_scaling_factor) `,` type($memrefO) + }]; +} + def DIP_Erosion2DOp : DIP_Op<"erosion_2d"> { let summary = [{This operation aims to provide utility to perform Erosion on a 2d single channel image.}]; diff --git a/midend/include/Dialect/GPU/CMakeLists.txt b/midend/include/Dialect/GPU/CMakeLists.txt new file mode 100644 index 0000000000..7278959827 --- /dev/null +++ b/midend/include/Dialect/GPU/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS TransformOps.td) +mlir_tablegen(TransformOps.h.inc -gen-op-decls) +mlir_tablegen(TransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(TransformOpsIncGen) diff --git a/midend/include/Dialect/GPU/TransformOps.h b/midend/include/Dialect/GPU/TransformOps.h new file mode 100644 index 0000000000..d69c467f56 --- /dev/null +++ b/midend/include/Dialect/GPU/TransformOps.h @@ -0,0 +1,74 @@ +//===- TransformOps.h -----------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file defines GPU transform ops for code generation. +// +//===----------------------------------------------------------------------===// + +#ifndef TRANSFORM_OPS_H +#define TRANSFORM_OPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" + +namespace mlir { +class DialectRegistry; + +namespace func { +class FuncOp; +} + +namespace scf { +class ForallOp; +class IfOp; +class ForOp; +} // namespace scf + +namespace vector { +class VectorDialect; +class WarpExecuteOnLane0Op; +} // namespace vector + +} // namespace mlir + +namespace mlir { +namespace buddy { +void registerBuddyGPUTransformOps(mlir::DialectRegistry ®istry); + +namespace gpu { + +class TransformExtensions + : public mlir::transform::TransformDialectExtension< + TransformExtensions> { +public: + TransformExtensions(); +}; +} // namespace gpu +} // namespace buddy +} // namespace mlir + +#define GET_OP_CLASSES +#include "GPU/TransformOps.h.inc" + +#endif // TRANSFORM_OPS_H diff --git a/midend/include/Dialect/GPU/TransformOps.td b/midend/include/Dialect/GPU/TransformOps.td new file mode 100644 index 0000000000..8eb7fac01d --- /dev/null +++ b/midend/include/Dialect/GPU/TransformOps.td @@ -0,0 +1,127 @@ +//===- TransformOps.td ----------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file defines the transform operations of the gpu dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRANSFORM_OPS_TD +#define TRANSFORM_OPS_TD + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Transform/IR/TransformAttrs.td" +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" + +// From IREE Common Extension OPs +def HoistStaticAllocOp : Op, + TransformEachOpTrait, + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { + let summary = "Hoist static allocations"; + let description = [{ + Find static allocations and hoist them to the top level. 
+
+    #### Return modes
+    This transform applies static alloc hoisting to the whole region of the operand.
+
+    It does not consume the target handle and always returns success.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+  let cppNamespace = "mlir::buddy::gpu";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::func::FuncOp funcOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def ApplyUnrollVectorsGpuMmaSyncPatternsOp : Op,
+    ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Populate patterns that unroll vectors. TODO: better documentation.
+  }];
+
+  let cppNamespace = "mlir::buddy::gpu";
+  let assemblyFormat = "attr-dict";
+}
+
+def VectorToMMAConversionOp : Op,
+    TransformEachOpTrait,
+    TransformOpInterface,
+    ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    This converts slices of operations containing vector.contract op into
+    mma operations, targeting warp level tensorcore operations. If the vector
+    operations are bigger than the native mma size it will first split up those
+    vector operations.
+
+    Exactly one of use_wmma or use_mma_sync must be specified.
+
+    #### Return modes
+
+    This transform consumes the target handle and produces a result handle.
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + UnitAttr:$use_mma_sync, + UnitAttr:$use_wmma); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + let cppNamespace = "mlir::buddy::gpu"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +#endif // TRANSFORM_OPS_TD diff --git a/midend/include/Utils/DIPUtils.h b/midend/include/Utils/DIPUtils.h index 7d17872d58..a8b77e8f23 100644 --- a/midend/include/Utils/DIPUtils.h +++ b/midend/include/Utils/DIPUtils.h @@ -103,6 +103,15 @@ void fillPixels(OpBuilder &builder, Location loc, Value resXVec, Value resYVec, Value outputColLastElemF32, Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32); +// Fill appropriate pixel 4D data in its corresponding rotated co-ordinate of +// output image. +void fillPixelsNearestNeighbour4D( + OpBuilder &builder, Location loc, Value ivs0, Value ivs1, Value resXVec, + Value resYVec, Value xVec, Value yVec, Value input, Value output, Value c0, + Value strideVal, Value outputRowLastElemF32, Value outputColLastElemF32, + Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32, + Value dataCondition); + // Calculate tan(angle / 2) where angle is a function parameter. Value customTanVal(OpBuilder &builder, Location loc, Value angleVal); @@ -150,6 +159,15 @@ void fillPixelsBilinearInterpolate( Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32, Value c1F32); +// Fills pixels in 4D of bilinear interpolation fashion. 
+void fillPixelsBilinearInterpolate4D(
+    OpBuilder &builder, Location loc, Value ivs0, Value ivs1, Value resXVec,
+    Value resYVec, Value xVec_L, Value yVec_L, Value xVec_H, Value yVec_H,
+    Value input, Value output, Value c0, Value strideVal, Value xVecWeight,
+    Value yVecWeight, Value outputRowLastElemF32, Value outputColLastElemF32,
+    Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32,
+    Value c1F32, Value dataCondition);
+
 // Helper function for resizing an image using nearest neighbour interpolation
 // mechanism.
 void NearestNeighbourInterpolationResizing(
@@ -161,6 +179,17 @@ void NearestNeighbourInterpolationResizing(
     Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32,
     int64_t stride, Value c0, Value c0F32);
 
+// Helper function for resizing a 4D image using nearest neighbour
+// interpolation mechanism.
+void NearestNeighbourInterpolationResizing4D(
+    OpBuilder &builder, Location loc, MLIRContext *ctx,
+    SmallVector lowerBounds, SmallVector upperBounds,
+    SmallVector steps, Value strideVal, Value input, Value output,
+    Value horizontalScalingFactorVec, Value verticalScalingFactorVec,
+    Value outputRowLastElemF32, Value outputColLastElemF32,
+    Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32,
+    int64_t stride, Value c0, Value c0F32, Value dataCondition);
+
 // Helper function for resizing an image using bilinear interpolation mechanism.
 void BilinearInterpolationResizing(
     OpBuilder &builder, Location loc, MLIRContext *ctx,
@@ -171,6 +200,17 @@ void BilinearInterpolationResizing(
     Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32,
     int64_t stride, Value c0, Value c0F32, Value c1F32);
 
+// Helper function for resizing a 4D image using bilinear interpolation
+// mechanism.
+void BilinearInterpolationResizing4D( + OpBuilder &builder, Location loc, MLIRContext *ctx, + SmallVector lowerBounds, SmallVector upperBounds, + SmallVector steps, Value strideVal, Value input, Value output, + Value horizontalScalingFactorVec, Value verticalScalingFactorVec, + Value outputRowLastElemF32, Value outputColLastElemF32, + Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32, + int64_t stride, Value c0, Value c0F32, Value c1F32, Value dataCondition); + // Util function for morphological transformations ; compares two vectors and // returns a mask Value createCompVecMorph(OpBuilder &builder, Location loc, VectorType type, @@ -191,9 +231,10 @@ void calcAndStorewTailProcessingMorph( Value zeroPadding, Value inputCol, VectorType vectorMaskTy, Type elemTy, Value kernelValue, Value zeroPaddingElem, DIP_OP op); -// Utility function for traversing an image with support for boundary extrapolation, -// variable anchor point positioning, and tail processing. It is used to compose more -// complicated operations on top of it, like 2D Correlation and morphological operations. +// Utility function for traversing an image with support for boundary +// extrapolation, variable anchor point positioning, and tail processing. It is +// used to compose more complicated operations on top of it, like 2D Correlation +// and morphological operations. void traverseImagewBoundaryExtrapolation( OpBuilder &rewriter, Location loc, MLIRContext *ctx, Value input, Value kernel, Value output, Value centerX, Value centerY, diff --git a/midend/include/Utils/GPUUtils.h b/midend/include/Utils/GPUUtils.h new file mode 100644 index 0000000000..88605fe1d3 --- /dev/null +++ b/midend/include/Utils/GPUUtils.h @@ -0,0 +1,104 @@ +//===- GPUUtils.h ---------------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file implements GPU dialect specific utility functions for the buddy +// compiler ecosystem. +// +//===----------------------------------------------------------------------===// + +#ifndef INCLUDE_UTILS_GPUUTILS_H +#define INCLUDE_UTILS_GPUUTILS_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/TargetParser/Triple.h" + +namespace mlir{ +namespace buddy::gpu{ +static constexpr int32_t kNumGPUDims = 3; +static constexpr int32_t kWarpSize = 32; + +/// Pick an unrolling order that will allow tensorcore operation to reuse LHS +/// register. This is needed to get good performance on sm_80 target. 
+std::optional>
+gpuMmaUnrollOrder(vector::ContractionOp contract);
+
+/// Helper function to return native size for MMA.SYNC-based operations.
+std::optional> getMmaNativeVectorSize(Operation *op);
+
+/// Return true if the given memref has workgroup memory space.
+bool hasSharedMemoryAddressSpace(MemRefType memrefType);
+
+/// Packs vector of lower precision into a single 32-bit width element.
+/// (i.e <2xf16> -> i32 and <4xi8> -> i32)
+Value packVectorToSupportedWidth(Location loc, OpBuilder &builder, Value input);
+
+/// Unpack single scalar element into a target vector type.
+/// (i.e i32 -> vector<4xi8> or f32 -> vector<2xf16>)
+Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput,
+                     VectorType targetVecType);
+
+/// Creates an allocation in the entry block of the function if the size is
+/// statically bounded. For a static allocation, it returns an allocation
+/// of the same size but in the entry basic block. For dynamic (still bounded)
+/// allocations creates an allocation, and inserts a subview to match the
+/// dynamic shape of the allocation. Returns std::nullopt if the method
+/// couldn't create an allocation in the entry block.
+template
+std::optional
+hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, OpBuilder &builder,
+                                  Location loc, MemRefType allocaType,
+                                  ValueRange dynamicSizes,
+                                  std::optional alignment);
+
+/// Hoists `allocaOp` to the entry block of the function if the size is
+/// statically bounded. For a static allocation, it returns an allocation
+/// of the same size but in the entry basic block. For dynamic (still bounded)
+/// allocations creates an allocation, and inserts a subview to match the
+/// dynamic shape of the allocation. The method returns a value, but
+/// does not replace the uses of the `allocaOp`.
+template +std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, OpBuilder &builder, + AllocLikeOpType allocaOp); + +/// Traverse funcOp and try to hoist every AllocaOp to the entry block of the +/// function if the size is statically bounded. +template +void hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter, + func::FuncOp funcOp); + +} // namespace buddy::gpu +} // namespace mlir + +#endif // INCLUDE_UTILS_GPUUTILS_H diff --git a/midend/lib/Conversion/CMakeLists.txt b/midend/lib/Conversion/CMakeLists.txt index bd3c7f1509..c3c2fa2ddd 100644 --- a/midend/lib/Conversion/CMakeLists.txt +++ b/midend/lib/Conversion/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(LowerBud) add_subdirectory(LowerDIP) add_subdirectory(LowerRVV) add_subdirectory(LowerDAP) +add_subdirectory(ExtendDAP) add_subdirectory(DAPVectorization) add_subdirectory(MatMulOptimization) add_subdirectory(TransposeOptimization) @@ -13,3 +14,5 @@ add_subdirectory(LowerLinalgToGemmini) add_subdirectory(SchedulingOnDevices) add_subdirectory(LowerSche) add_subdirectory(FuncBufferize) +add_subdirectory(DepthwiseConvOptimization) +add_subdirectory(MLIRGPU) diff --git a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt index fc88a92ef6..9f77079d38 100644 --- a/midend/lib/Conversion/ConvOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/ConvOptimization/CMakeLists.txt @@ -1,3 +1,5 @@ add_mlir_library(ConvOptimization ConvOptimize.cpp + ConvNhwcFhwcOptimize.cpp + ConvNhwcFhwcTileOptimize.cpp ) diff --git a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp new file mode 100644 index 0000000000..e4bc67e361 --- /dev/null +++ b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcOptimize.cpp @@ -0,0 +1,276 @@ +//====- ConvNhwcFhwcOptimize.cpp----------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// 
you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Conv2DNhwcFhwcOp optimize. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class ConvNhwcFhwcOptimizePattern : public ConversionPattern { +public: + explicit ConvNhwcFhwcOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::Conv2DNhwcFhwcOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. 
+ const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. + if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. + if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC + Value IC = rewriter.create(loc, input, 3); // IC + Value FH = rewriter.create(loc, filter, 1); // FH + Value FW = rewriter.create(loc, filter, 2); // FW + + // clang format off + // Step 1: Create outer most loops. 
+ // Create the scf::ForallOp operation For N,OH,OW,OC + auto outputForAllOp = rewriter.create( + loc, SmallVector({N, OH, OW, OC}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + Value ivOH = loopIndices[1]; // Index for the second dimension OH + Value ivOW = loopIndices[2]; // Index for the third dimension OW + Value ivOC = loopIndices[3]; // Index for the third dimension OC + + Value addRes = nestedBuilder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + // IC + auto forOp = nestedBuilder.create( + nestedLoc, c0, IC, vecSizeValue, ValueRange{addRes}, + [&](OpBuilder &builder, Location loc, Value ivIC, + ValueRange iargs) { + Value tVec; + if (isa(elemTy)) { + tVec = builder.create(loc, vecTy, + zeroElementType); + } else { + tVec = builder.create(loc, vecTy, + zeroElementType); + } + + Value remainLen = builder.create( + loc, + AffineMap::get(2, 1, {-d0 + s0, d1}, builder.getContext()), + ValueRange{ivIC, vecSizeValue, IC}); + Value remainMask = builder.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{remainLen}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = ivFW; + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, ivIC}); + Value fVec = builder.create( + 
loc, vecTy, filter, + ValueRange{ivOC, rowFilter, columnFilter, + ivIC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + auto reduceVecOp = builder.create( + loc, vector::CombiningKind::ADD, forOp.getResult(0)); + auto maskedOp = + cast(mlir::vector::maskOperation( + builder, reduceVecOp, remainMask)); + Value reduceVec = maskedOp->getResult(0); + Value addNext; + if (isa(elemTy)) { + addNext = + builder.create(loc, iargs[0], reduceVec); + } else { + addNext = + builder.create(loc, iargs[0], reduceVec); + } + builder.create(loc, ValueRange{addNext}); + }); + + nestedBuilder.create( + loc, forOp.getResult(0), output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ConvNhwcFhwcOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class ConvNhwcFhwcOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvNhwcFhwcOptimizePass) + StringRef getArgument() const final { return "conv-nhwc-fhwc-optimize"; } + StringRef getDescription() const final { + return "Conv2d NHWC FHWC optimize."; + } + ConvNhwcFhwcOptimizePass() = default; + ConvNhwcFhwcOptimizePass(const ConvNhwcFhwcOptimizePass &) {} + explicit ConvNhwcFhwcOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option 
vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. + +void ConvNhwcFhwcOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerConvNhwcFhwcOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcTileOptimize.cpp b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcTileOptimize.cpp new file mode 100644 index 0000000000..41e1c066ee --- /dev/null +++ b/midend/lib/Conversion/ConvOptimization/ConvNhwcFhwcTileOptimize.cpp @@ -0,0 +1,342 @@ +//====- ConvNhwcFhwcOptimizeTile.cpp------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Conv2DNhwcFhwcOp tile optimize. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class ConvNhwcFhwcTileOptimizePattern : public ConversionPattern { +public: + explicit ConvNhwcFhwcTileOptimizePattern(MLIRContext *context, + int64_t vecSizeParam, + int64_t tilingOHParam, + int64_t tilingOWParam, + int64_t tilingOCParam) + : ConversionPattern(linalg::Conv2DNhwcFhwcOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + tilingOH = tilingOHParam; + tilingOW = tilingOWParam; + tilingOC = tilingOCParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. + if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. 
+ if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC + Value IC = rewriter.create(loc, input, 3); // IC + Value FH = rewriter.create(loc, filter, 1); // FH + Value FW = rewriter.create(loc, filter, 2); // FW + + auto tilingUpperBound = + AffineMap::get(2, 1, {d0 + d1, s0}, rewriter.getContext()); + + Value stepOH = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOH)), OH); + Value stepOW = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOW)), OW); + Value stepOC = rewriter.create( + loc, AffineMap::get(1, 0, d0.ceilDiv(tilingOC)), OC); + + // clang format off + // Step 1: Create outer most loops. 
+ // Create the scf::ForallOp operation For N,OH,OW,OC + rewriter.create( + loc, SmallVector{c0, c0, c0, c0}, + SmallVector({N, OH, OW, OC}), + SmallVector({c1, stepOH, stepOW, stepOC}), + ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + + Value ubOH = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[1], stepOH, + OH}); // ub for the second dimension OH + Value ubOW = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[2], stepOW, + OW}); // ub for the second dimension OW + Value ubOC = nestedBuilder.create( + loc, tilingUpperBound, + ValueRange{loopIndices[3], stepOC, + OC}); // ub for the second dimension OC + + rewriter.create( + loc, + SmallVector{loopIndices[1], loopIndices[2], + loopIndices[3]}, + SmallVector({ubOH, ubOW, ubOC}), + SmallVector({c1, c1, c1}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivOH = loopIndices[0]; // Index for the first dimension OH + Value ivOW = loopIndices[1]; // Index for the first dimension OW + Value ivOC = loopIndices[2]; // Index for the first dimension OC + + Value addRes = nestedBuilder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + // IC + auto forOp = nestedBuilder.create( + nestedLoc, c0, IC, vecSizeValue, ValueRange{addRes}, + [&](OpBuilder &builder, Location loc, Value ivIC, + ValueRange iargs) { + Value tVec; + if (isa(elemTy)) { + tVec = builder.create( + loc, vecTy, zeroElementType); + } else { + tVec = builder.create(loc, vecTy, + zeroElementType); + } + + Value remainLen = builder.create( + loc, + AffineMap::get(2, 1, {-d0 + s0, d1}, + builder.getContext()), + ValueRange{ivIC, vecSizeValue, IC}); + Value remainMask = builder.create( + loc, VectorType::get({vecSize}, 
rewriter.getI1Type()), + ValueRange{remainLen}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, + Value ivFW, ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get(2, 0, + d0 * strWidth + + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, + ivIC}); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{ivOC, rowFilter, columnFilter, + ivIC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = + builder.create(loc, iVec, + fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create( + loc, ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + auto reduceVecOp = builder.create( + loc, vector::CombiningKind::ADD, forOp.getResult(0)); + auto maskedOp = + cast(mlir::vector::maskOperation( + builder, reduceVecOp, remainMask)); + Value reduceVec = maskedOp->getResult(0); + Value addNext; + if (isa(elemTy)) { + addNext = builder.create(loc, iargs[0], + reduceVec); + } else { + addNext = builder.create(loc, iargs[0], + reduceVec); + } + builder.create(loc, ValueRange{addNext}); + }); + + nestedBuilder.create( + loc, forOp.getResult(0), output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + nestedBuilder.create(nestedLoc); + }); + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + 
int64_t vecSize; + int64_t tilingOH; + int64_t tilingOW; + int64_t tilingOC; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// ConvNhwcFhwcTileOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class ConvNhwcFhwcTileOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvNhwcFhwcTileOptimizePass) + StringRef getArgument() const final { return "conv-nhwc-fhwc-tile-optimize"; } + StringRef getDescription() const final { + return "Conv2d NHWC FHWC optimize with Tile."; + } + ConvNhwcFhwcTileOptimizePass() = default; + ConvNhwcFhwcTileOptimizePass(const ConvNhwcFhwcTileOptimizePass &) {} + explicit ConvNhwcFhwcTileOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; + Option tilingOH{*this, "tiling-height", + llvm::cl::desc("number of the output height tiles."), + llvm::cl::init(1)}; + Option tilingOW{*this, "tiling-width", + llvm::cl::desc("number of the output width tiles."), + llvm::cl::init(1)}; + Option tilingOC{ + *this, "tiling-channel", + llvm::cl::desc("number of the output channel tiles."), llvm::cl::init(1)}; +}; +} // end anonymous namespace. 
+ +void ConvNhwcFhwcTileOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize, tilingOH, + tilingOW, tilingOC); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerConvNhwcFhwcTileOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp b/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp index 55c876dd63..918a1388d6 100644 --- a/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp +++ b/midend/lib/Conversion/ConvVectorization/GEMMPointwiseConv2DNhwcHwcf.cpp @@ -122,8 +122,7 @@ class GEMMPointwiseConvPattern : public ConversionPattern { namespace { class PointwiseConvToGemmPass - : public PassWrapper> { + : public PassWrapper> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PointwiseConvToGemmPass) StringRef getArgument() const final { return "pointwise-conv-to-gemm"; } @@ -144,14 +143,20 @@ class PointwiseConvToGemmPass void PointwiseConvToGemmPass::runOnOperation() { MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); ConversionTarget target(*context); - target.addLegalDialect(); + target + .addLegalDialect(); target.addLegalOp(); target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context); + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); } namespace mlir { diff --git a/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt b/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt new file mode 100644 index 0000000000..8493e2a60a --- /dev/null +++ 
b/midend/lib/Conversion/DepthwiseConvOptimization/CMakeLists.txt @@ -0,0 +1,3 @@ +add_mlir_library(DepthwiseConvOptimization + DepthwiseConvNhwcHwc.cpp + ) diff --git a/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp b/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp new file mode 100644 index 0000000000..04bf76f769 --- /dev/null +++ b/midend/lib/Conversion/DepthwiseConvOptimization/DepthwiseConvNhwcHwc.cpp @@ -0,0 +1,331 @@ +//====- DepthwiseConvNhwcHwc.cpp +//--------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the DepthwiseConvNhwcHwc optimize. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class DepthwiseConv2DNhwcHwcOptimizePattern : public ConversionPattern { +public: + explicit DepthwiseConv2DNhwcHwcOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::DepthwiseConv2DNhwcHwcOp::getOperationName(), + 1, context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto convOp = dyn_cast_or_null(op); + auto loc = op->getLoc(); + + // Some constant we need. + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const Value vecSizeValue = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + + Value input = op->getOperand(0); + Value filter = op->getOperand(1); + Value output = op->getOperand(2); + + int strHeight, strWidth, dilHeight, dilWidth; + + // Strides. + if (!convOp.getStrides()) { + strHeight = 1; + strWidth = 1; + } else { + strHeight = convOp.getStrides().getValues()[0]; + strWidth = convOp.getStrides().getValues() + [convOp.getStrides().getValues().size() - 1]; + } + + // Dilations. 
+ if (!convOp.getDilations()) { + dilHeight = 1; + dilWidth = 1; + } else { + dilHeight = convOp.getDilations().getValues()[0]; + dilWidth = convOp.getDilations().getValues() + [convOp.getDilations().getValues().size() - 1]; + } + + ShapedType inputTy = input.getType().cast(); + Type elemTy = inputTy.getElementType(); + VectorType vecTy = VectorType::get(vecSize, elemTy); + + const Value zeroElementType = + rewriter.create(loc, rewriter.getZeroAttr(elemTy)); + + Value zeroElementTypeVec; + if (isa(elemTy)) { + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + } else { + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + } + // Dims + Value N = rewriter.create(loc, output, 0); // N + Value OH = rewriter.create(loc, output, 1); // OH + Value OW = rewriter.create(loc, output, 2); // OW + Value OC = rewriter.create(loc, output, 3); // OC/FC/IC + + Value applyOC = rewriter.create( + loc, AffineMap::get(1, 0, d0.floorDiv(vecSize) * vecSize), OC); + Value tailLength = rewriter.create( + loc, AffineMap::get(1, 0, d0 % vecSize), ValueRange{OC}); + Value maskVector = rewriter.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{tailLength}); + + Value FH = rewriter.create(loc, filter, 0); // FH + Value FW = rewriter.create(loc, filter, 1); // FW + + // clang format off + // Step 1: Create outer most loops. 
+ // Create the scf::ForallOp operation For N,OH,OW + auto outputForAllOp = rewriter.create( + loc, SmallVector({N, OH, OW}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange loopIndices) { + Value ivN = loopIndices[0]; // Index for the first dimension N + Value ivOH = loopIndices[1]; // Index for the second dimension OH + Value ivOW = loopIndices[2]; // Index for the third dimension OW + // OC + nestedBuilder.create( + nestedLoc, c0, applyOC, vecSizeValue, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value ivOC, + ValueRange iargs) { + Value tVec = builder.create( + loc, vecTy, output, ValueRange{ivN, ivOH, ivOW, ivOC}); + + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, ivOC}); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{rowFilter, columnFilter, ivOC}); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + builder.create( + loc, forOp.getResult(0), 
output, + ValueRange{ivN, ivOH, ivOW, ivOC}); + + builder.create(loc, ValueRange{std::nullopt}); + }); + + // applyOC + Value condition = nestedBuilder.create( + loc, arith::CmpIPredicate::sgt, tailLength, c0); + nestedBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value tVec = builder.create( + loc, vecTy, output, ValueRange{ivN, ivOH, ivOW, applyOC}, + maskVector, zeroElementTypeVec); + // FH + auto forOp = builder.create( + loc, c0, FH, c1, ValueRange{tVec}, + [&](OpBuilder &builder, Location loc, Value ivFH, + ValueRange iargs) { + Value rowInput = builder.create( + loc, + AffineMap::get(2, 0, d0 * strHeight + d1 * dilHeight), + ValueRange{ivOH, ivFH}); + Value rowFilter = ivFH; + // FW + auto forOp = builder.create( + loc, c0, FW, c1, ValueRange{iargs[0]}, + [&](OpBuilder &builder, Location loc, Value ivFW, + ValueRange iargs) { + Value columnInput = + builder.create( + loc, + AffineMap::get( + 2, 0, d0 * strWidth + d1 * dilWidth), + ValueRange{ivOW, ivFW}); + Value columnFilter = + builder.create( + loc, AffineMap::get(1, 0, d0), ivFW); + Value iVec = builder.create( + loc, vecTy, input, + ValueRange{ivN, rowInput, columnInput, applyOC}, + maskVector, zeroElementTypeVec); + Value fVec = builder.create( + loc, vecTy, filter, + ValueRange{rowFilter, columnFilter, applyOC}, + maskVector, zeroElementTypeVec); + Value tVecNext; + if (isa(elemTy)) { + Value mulVec = builder.create( + loc, iVec, fVec); + tVecNext = builder.create( + loc, mulVec, iargs[0]); + } else { + tVecNext = builder.create( + loc, vecTy, iVec, fVec, iargs[0]); + } + + builder.create(loc, + ValueRange{tVecNext}); + }); + builder.create( + loc, ValueRange{forOp.getResult(0)}); + }); + builder.create( + loc, output, ValueRange{ivN, ivOH, ivOW, applyOC}, + maskVector, forOp.getResult(0)); + builder.create(loc, ValueRange{std::nullopt}); + }); + + nestedBuilder.create(nestedLoc); + }); + // clang format on + + rewriter.eraseOp(op); + return success(); + } + +private: + 
int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// DepthwiseConv2DNhwcHwcOptimizePass +//===----------------------------------------------------------------------===// + +namespace { +class DepthwiseConv2DNhwcHwcOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + DepthwiseConv2DNhwcHwcOptimizePass) + StringRef getArgument() const final { + return "depthwise-conv-nhwc-hwc-optimize"; + } + StringRef getDescription() const final { + return "Depthwise Conv2d NHWC HWC optimize."; + } + DepthwiseConv2DNhwcHwcOptimizePass() = default; + DepthwiseConv2DNhwcHwcOptimizePass( + const DepthwiseConv2DNhwcHwcOptimizePass &) {} + explicit DepthwiseConv2DNhwcHwcOptimizePass(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", llvm::cl::desc("Vector size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. 
+ +void DepthwiseConv2DNhwcHwcOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerDepthwiseConv2DNhwcHwcOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/ExtendDAP/CMakeLists.txt b/midend/lib/Conversion/ExtendDAP/CMakeLists.txt new file mode 100644 index 0000000000..5ecaa64c99 --- /dev/null +++ b/midend/lib/Conversion/ExtendDAP/CMakeLists.txt @@ -0,0 +1,3 @@ +add_mlir_library(ExtendDAPPass + ExtendDAPPass.cpp + ) diff --git a/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp b/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp new file mode 100644 index 0000000000..32fc42fcf7 --- /dev/null +++ b/midend/lib/Conversion/ExtendDAP/ExtendDAPPass.cpp @@ -0,0 +1,3915 @@ +//====- ExtendDAPPass.cpp - Extend DAP Dialect Lowering Pass -------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file defines Extend DAP dialect lowering pass. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" + +#include "DAP/DAPDialect.h" +#include "DAP/DAPOps.h" +#include + +using namespace mlir; +using namespace buddy; +using namespace vector; +using namespace mlir::arith; +using namespace mlir::linalg; +using namespace mlir::bufferization; +using namespace mlir::scf; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// +Value initMelFilter(PatternRewriter &rewriter, Location loc, Value c0, Value c1, + Value f0) { + FloatType f64Ty = rewriter.getF64Type(); + std::vector data{ + 0.024862593984176087, 0.0019908218880980706, 0.022871772096078023, + 0.003981643776196141, 0.020880950207979945, 0.005972465664294215, + 0.018890128319881874, 0.007963287552392284, 0.016899306431783803, + 0.00995410944049036, 0.014908484543685726, 0.011944931328588433, + 0.012917662655587655, 0.013935753216686492, 0.0109268407674896, + 0.015926575104784558, 0.008936018879391525, 0.017917396992882653, + 0.006945196991293433, 0.019908218880980738, 0.004954375103195362, + 0.021899040769078785, 0.0029635532150973053, 0.02388986265717686, + 0.000972731326999214, 0.025880684545274913, 0.025835325311175383, + 0.0010180905610987637, 0.023844503423077285, 0.003008912449196894, + 0.0218536815349792, 0.004999734337294947, 0.019862859646881146, + 0.00699055622539301, 0.017872037758783092, 0.008981378113491093, + 0.015881215870684983, 0.010972200001589204, 0.013890393982586886, + 
0.012963021889687279, 0.011899572094488815, 0.01495384377778533, + 0.009908750206390747, 0.016944665665883398, 0.007917928318292721, + 0.01893548755398142, 0.00592710643019463, 0.020874010059259842, + 0.0040404255283634505, 0.02211421726709443, 0.003318606124028175, + 0.021736724240202378, 0.0036109675065762467, 0.020497701500567712, + 0.0047621938624405336, 0.018486659689787407, 0.00659261778657699, + 0.015856038061722075, 0.00896277117173217, 0.01273876809538182, + 0.011751329556399702, 0.009250369503549623, 0.014853145171184273, + 0.005490841259941731, 0.018177473406255133, 0.0015463665611462304, + 0.0028155461301291296, 0.016329520204672005, 0.0074201889869586046, + 0.011181051149095055, 0.012018863908450889, 0.006065350444974551, + 0.016561277418378373, 0.0010297985194884273, 0.004360878822945124, + 0.012770536755428743, 0.009707189146398206, 0.00698640273562069, + 0.014854299940667337, 0.0014180475372245899, 0.004391219533926463, + 0.011486922519974862, 0.010089744452471235, 0.005411105098683222, + 0.00040022287019040627, 0.01473556574143922, 0.006518189694660876, + 0.008278412940789993, 0.012277561276489102, 0.0021878083895293813, + 0.003967812818254116, 0.01018448003033658, 0.009981875463793392, + 0.0038694348523115605, 0.002286485205918727, 0.011274894758755125, + 0.00846622203459276, 0.00482029386344711, 0.0013397691078130697, + 0.01167825163503901, 0.0076086824024487005, 0.005156961757386885, + 0.0010091040034666294, 0.011507895445653202, 0.007301822420859842, + 0.00498210435635522, 0.00119016554305424, 0.010863499380371759, + 0.007451189902152603, 0.004385921781217367, 0.001791381421388157, + 0.009832492179588035, 0.007973956151288167, 0.003447454713364045, + 0.00273258990979288, 0.008491348038548032, 0.008797688393881514, + 0.00223576473247454, 0.003943828025486854, 0.00690675179681992, + 0.00985924113208423, 0.0008110002442122107, 0.005364237465034378, + 0.005136650837642653, 0.0008692337979845255, 0.009462301431073093, + 0.0069410774768914035, 
0.0032312041128928558, 0.0027783994364693502, + 0.0072370497293947405, 0.008628834806348044, 0.0012336377872858904, + 0.004773912819246566, 0.004943322320664174, 0.0009189908321450893, + 0.008653006854042458, 0.006818502698028209, 0.0026164354511310182, + 0.0032485836741847724, 0.006051854749492107, 0.008880466967901061, + 0.0002863667968266877, 0.005574480003847867, 0.0034677978994882394, + 0.0022684930397946727, 0.006649229002149792, 0.007871464954585431, + 0.0009245911230515482, 0.004809896965650232, 0.003870811942679267, + 0.0017483289767150309, 0.006817032762306986, 0.007283343633866355, + 0.00117032890964782, 0.00444812418116144, 0.0038987290660258203, + 0.001612904728456526, 0.006627129222403821, 0.007047320122030166, + 0.001089299240400779, 0.004421714754438077, 0.003615982700014719, + 0.0017961093868459883, 0.006142666159628658, 0.007102936106246769, + 0.000739228059530935, 0.00467144758787073, 0.003079108186777991, + 0.002239959069494691, 0.005418988314025048, 0.007397185776342812, + 0.0001706867135296284, 0.005145462616431531, 0.0023375742945811327, + 0.0028937394565202498, 0.004504461875632637, 0.0006420162966089689, + 0.006671349456684142, 0.005798479593256033, 0.0014345343541053727, + 0.003713231338714718, 0.0034412191129693185, 0.0016279830841734021, + 0.005447903871833263, 0.006591092774765041, 0.0004075043790786325, + 0.004660011565279688, 0.0022658304669981606, 0.002728930355794336, + 0.0041241565549176885, 0.0007978491463089834, 0.0059824826428372165, + 0.005700822451952891, 0.001009910787399597, 0.003912510374718906, + 0.0027308466911522586, 0.0021241982974849216, 0.00445178259490492, + 0.00033588622025093644, 0.006172718498657583, 0.005150905140435021, + 0.001294369078080886, 0.003494806955712255, 0.002888072359801849, + 0.0018387087709894891, 0.004481775641522813, 0.0001826105862667237, + 0.006075478923243776, 0.004886660120436632, 0.0013131388257841057, + 0.0033530009608540383, 0.0027890160764299436, 0.001819341801271444, + 
0.004264893327075782, 0.0002856826416888495, 0.00574077057772162, + 0.00485997904767977, 0.0011121684505728251, 0.003439706723569149, + 0.002478930810892349, 0.002019434399458528, 0.003845693171211873, + 0.0005991620753479068, 0.0052124555315313965, 0.005028639927429189, + 0.0007317548848252071, 0.0037133714976539506, 0.001997469461541363, + 0.002398103067878713, 0.0032631840382575193, 0.0010828346381034754, + 0.004528898614973675, 0.005355687096081724, 0.00020713973879943928, + 0.004137659389003323, 0.001379277219478372, 0.002919631681924923, + 0.0025514147001573046, 0.0017016039748465222, 0.0037235521808362372, + 0.0004835762677681217, 0.00489568966151517, 0.004680894973891053, + 0.0006545265135735772, 0.003552918766430338, 0.001740005261220558, + 0.002424942558969623, 0.0028254840088675383, 0.001296966351508908, + 0.003910962756514519, 0.0001689901440481931, 0.0049964415041615, + 0.004270978712160627, 0.0008546266928695816, 0.003226396296503344, + 0.001859853580513147, 0.00218181388084606, 0.002865080468156712, + 0.0011372314651887762, 0.003870307355800277, 9.264904953149247e-05, + 0.004875534243443842, 0.00408313784805493, 0.0008483414885398006, + 0.0031157837355450563, 0.0017792497148243077, 0.0021484296230351824, + 0.002710157941108815, 0.0011810755105253086, 0.003641066167393322, + 0.00021372139801543514, 0.00457197439367783, 0.004079728686464056, + 0.0006716204322456959, 0.0031838932167458007, 0.0015337045412328826, + 0.0022880577470275453, 0.0023957886502200695, 0.0013922222773092897, + 0.0032578727592072563, 0.0004963868075910343, 0.004119956868194443, + 0.004227725348621998, 0.00035597961441710105, 0.003398120989238948, + 0.0011543279262957885, 0.0025685166298558978, 0.0019526762381744762, + 0.0017389122704728477, 0.0027510245500531635, 0.0009093079110897976, + 0.0035493728619318513, 7.970355170674751e-05, 0.004347721173810539, + 0.0037299628548962886, 0.0006682946412839993, 0.0029616929924361443, + 0.001407619285405272, 0.0021934231299760003, 
0.002146943929526545, + 0.001425153267515856, 0.0028862685736478176, 0.000656883405055712, + 0.0036255932177690904, 0.004154576575021074, 9.92650919068254e-05, + 0.00344310661360058, 0.0007839298194100073, 0.0027316366521800855, + 0.0014685945469131891, 0.002020166690759591, 0.002153259274416371, + 0.0013086967293390965, 0.002837924001919553, 0.0005972267679186017, + 0.0035225887294227346, 0.00399151753267442, 0.00010181095051366054, + 0.0033326481291776314, 0.0007358568907029797, 0.002673778725680842, + 0.0013699028308922986, 0.002014909322184054, 0.0020039487710816176, + 0.0013560399186872648, 0.002637994711270937, 0.0006971705151904762, + 0.003272040651460256, 3.8301111693687365e-05, 0.0039060865916495753, + 0.0033682563798377797, 0.000553036427908215, 0.0027580986579326967, + 0.0011402059404032124, 0.0021479409360276127, 0.0017273754528982098, + 0.0015377832141225299, 0.002314544965393207, 0.0009276254922174464, + 0.002901714477888205, 0.00031746777031236304, 0.003488883990383202, + 0.0035233401613614045, 0.0002608386658672823, 0.0029582927579996227, + 0.0008045974293106004, 0.0023932453546378412, 0.0013483561927539183, + 0.0018281979512760594, 0.0018921149561972363, 0.0012631505479142777, + 0.0024358737196405545, 0.0006981031445524962, 0.0029796324830838727, + 0.00013305574119071463, 0.003523391246527191, 0.0032513804428682767, + 0.0003849812001174835, 0.0027281082517878774, 0.0008885386678154236, + 0.0022048360607074776, 0.0013920961355133636, 0.0016815638696270783, + 0.0018956536032113034, 0.001158291678546679, 0.002399211070909244, + 0.0006350194874662795, 0.0029027685386071836, 0.0001117472963858801, + 0.003406326006305124, 0.0031327631853624, 0.0003667416797849643, + 0.002648177672131103, 0.0008330700230168817, 0.0021635921588998063, + 0.001299398366248799, 0.00167900664566851, 0.0017657267094807168, + 0.0011944211324372133, 0.002232055052712634, 0.0007098356192059166, + 0.0026983833959445514, 0.00022525010597461998, 0.003164711739176469, + 
0.0031413131931234614, 0.002692554165534394, 0.002243795137945327, + 0.0017950361103562596, 0.001346277082767192, 0.0008975180551781247, + 0.0004487590275890572}; + Value melFilterData = rewriter.create( + loc, DenseFPElementsAttr::get(RankedTensorType::get(391, f64Ty), + ArrayRef(data))); + + IndexType idxTy = rewriter.getIndexType(); + std::vector D1Index{ + 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, + 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, + 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, + 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, + 31, 31, 32, 32, 33, 33, 34, 34, 35, 35, 36, 36, 37, 37, 38, + 38, 39, 39, 40, 40, 41, 41, 42, 42, 43, 43, 44, 44, 45, 45, + 46, 46, 47, 47, 48, 48, 49, 49, 50, 50, 51, 51, 52, 52, 53, + 53, 54, 54, 55, 55, 56, 56, 57, 57, 58, 58, 59, 59, 60, 60, + 61, 61, 62, 62, 63, 63, 64, 64, 65, 65, 66, 66, 67, 67, 68, + 68, 69, 69, 70, 70, 71, 71, 72, 72, 73, 73, 74, 74, 75, 75, + 76, 76, 77, 77, 78, 78, 79, 79, 80, 80, 81, 81, 82, 82, 83, + 83, 84, 84, 85, 85, 86, 86, 87, 87, 88, 88, 89, 89, 90, 90, + 91, 91, 92, 92, 93, 93, 94, 94, 95, 95, 96, 96, 97, 97, 98, + 98, 99, 99, 100, 100, 101, 101, 102, 102, 103, 103, 104, 104, 105, 105, + 106, 106, 107, 107, 108, 108, 109, 109, 110, 110, 111, 111, 112, 112, 113, + 113, 114, 114, 115, 115, 116, 116, 117, 117, 118, 118, 119, 119, 120, 120, + 121, 121, 122, 122, 123, 123, 124, 124, 125, 125, 126, 126, 127, 127, 128, + 128, 129, 129, 130, 130, 131, 131, 132, 132, 133, 133, 134, 134, 135, 135, + 136, 136, 137, 137, 138, 138, 139, 139, 140, 140, 141, 141, 142, 142, 143, + 143, 144, 144, 145, 145, 146, 146, 147, 147, 148, 148, 149, 149, 150, 150, + 151, 151, 152, 152, 153, 153, 154, 154, 155, 155, 156, 156, 157, 157, 158, + 158, 159, 159, 160, 160, 161, 161, 162, 162, 163, 163, 164, 164, 165, 165, + 166, 166, 167, 167, 168, 168, 169, 169, 170, 170, 171, 171, 172, 172, 173, + 173, 174, 174, 175, 175, 176, 176, 177, 177, 178, 178, 179, 179, 180, 180, + 181, 
181, 182, 182, 183, 183, 184, 184, 185, 185, 186, 186, 187, 187, 188, + 188, 189, 189, 190, 190, 191, 191, 192, 192, 193, 194, 195, 196, 197, 198, + 199}; + Value dim1Index = rewriter.create( + loc, DenseElementsAttr::get(RankedTensorType::get(391, idxTy), + ArrayRef(D1Index))); + + std::vector D2Index{ + 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, + 9, 10, 10, 11, 11, 12, 12, 13, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, + 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, + 28, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34, 33, 34, 34, 35, 35, 36, + 36, 37, 36, 37, 37, 38, 38, 39, 38, 39, 39, 40, 39, 40, 40, 41, 41, 42, + 41, 42, 42, 43, 42, 43, 43, 44, 43, 44, 44, 45, 44, 45, 45, 46, 45, 46, + 46, 47, 46, 47, 47, 48, 47, 48, 48, 49, 48, 49, 49, 50, 49, 50, 49, 50, + 50, 51, 50, 51, 51, 52, 51, 52, 51, 52, 52, 53, 52, 53, 53, 54, 53, 54, + 53, 54, 54, 55, 54, 55, 54, 55, 55, 56, 55, 56, 55, 56, 56, 57, 56, 57, + 56, 57, 57, 58, 57, 58, 57, 58, 58, 59, 58, 59, 58, 59, 58, 59, 59, 60, + 59, 60, 59, 60, 60, 61, 60, 61, 60, 61, 60, 61, 61, 62, 61, 62, 61, 62, + 61, 62, 62, 63, 62, 63, 62, 63, 62, 63, 63, 64, 63, 64, 63, 64, 63, 64, + 64, 65, 64, 65, 64, 65, 64, 65, 65, 66, 65, 66, 65, 66, 65, 66, 66, 67, + 66, 67, 66, 67, 66, 67, 66, 67, 67, 68, 67, 68, 67, 68, 67, 68, 67, 68, + 68, 69, 68, 69, 68, 69, 68, 69, 68, 69, 69, 70, 69, 70, 69, 70, 69, 70, + 69, 70, 70, 71, 70, 71, 70, 71, 70, 71, 70, 71, 71, 72, 71, 72, 71, 72, + 71, 72, 71, 72, 71, 72, 72, 73, 72, 73, 72, 73, 72, 73, 72, 73, 73, 74, + 73, 74, 73, 74, 73, 74, 73, 74, 73, 74, 74, 75, 74, 75, 74, 75, 74, 75, + 74, 75, 74, 75, 74, 75, 75, 76, 75, 76, 75, 76, 75, 76, 75, 76, 75, 76, + 76, 77, 76, 77, 76, 77, 76, 77, 76, 77, 76, 77, 76, 77, 77, 78, 77, 78, + 77, 78, 77, 78, 77, 78, 77, 78, 77, 78, 78, 79, 78, 79, 78, 79, 78, 79, + 78, 79, 78, 79, 78, 79, 79, 79, 79, 79, 79, 79, 79}; + Value dim2Index = rewriter.create( + loc, DenseElementsAttr::get(RankedTensorType::get(391, idxTy), + 
ArrayRef(D2Index))); + + RankedTensorType melFilterType = RankedTensorType::get({201, 80}, f64Ty); + Value melFilter = rewriter.create(loc, melFilterType, f0); + auto mTp = + MemRefType::get(melFilterType.getShape(), melFilterType.getElementType()); + Value melFilterMemRef = + rewriter.create(loc, mTp, melFilter); + + // TODO : remove tomemref & totensor, and use insert to replace store. !! + Value c391 = rewriter.create(loc, 391); + Value number, d1, d2; + // rewriter.create(loc, c0, c391, c1, std::nullopt, + // [&](OpBuilder &builder, Location loc, Value iv, ValueRange iargs) { + // number = builder.create(loc, melFilterData, iv); + // d1 = builder.create(loc, dim1Index, iv); + // d2 = builder.create(loc, dim2Index, iv); + // builder.create(loc, number, melFilterMemRef, + // ValueRange{d1, d2}); builder.create(loc, std::nullopt); + // }); + auto loopOp = rewriter.create(loc, c0, c391, c1); + rewriter.setInsertionPointToStart(loopOp.getBody()); + + Value iv = loopOp.getInductionVar(); + number = rewriter.create(loc, melFilterData, iv); + d1 = rewriter.create(loc, dim1Index, iv); + d2 = rewriter.create(loc, dim2Index, iv); + rewriter.create(loc, number, melFilterMemRef, + ValueRange{d1, d2}); + + rewriter.setInsertionPointAfter(loopOp); + + Value newMelFilter = rewriter.create( + loc, melFilterMemRef, /*restrict=*/true, /*writable=*/false); + + return newMelFilter; +} + +Value getHanningWindow400(PatternRewriter &rewriter, Location loc) { + FloatType f64Ty = rewriter.getF64Type(); + std::vector hanningWindow400{0.0, + 6.168375916970614e-05, + 0.0002467198171342, + 0.0005550625190150482, + 0.0009866357858642205, + 0.001541333133436018, + 0.002219017698460002, + 0.003019522272410202, + 0.0039426493427611176, + 0.0049881711417212315, + 0.00615582970243117, + 0.007445336922613066, + 0.00885637463565564, + 0.01038859468911707, + 0.012041619030626338, + 0.013815039801161721, + 0.015708419435684517, + 0.017721290771101017, + 0.019853157161528523, + 0.02210349260083494, + 
0.024471741852423234, + 0.02695732058622735, + 0.029559615522887273, + 0.03227798458506631, + 0.035111757055874326, + 0.03806023374435674, + 0.04112268715800954, + 0.044298361682277465, + 0.04758647376699032, + 0.05098621211969223, + 0.054496737905816106, + 0.05811718495565327, + 0.06184665997806821, + 0.06568424278090434, + 0.06962898649802812, + 0.07367991782295402, + 0.07783603724899257, + 0.08209631931586497, + 0.08645971286271914, + 0.09092514128748835, + 0.09549150281252633, + 0.10015767075645471, + 0.1049224938121548, + 0.10978479633083521, + 0.11474337861210543, + 0.11979701719998453, + 0.1249444651847702, + 0.1301844525106951, + 0.13551568628929433, + 0.14093685111840565, + 0.14644660940672627, + 0.15204360170384285, + 0.15772644703565564, + 0.1634937432451133, + 0.16934406733817414, + 0.17527597583490823, + 0.18128800512565513, + 0.1873786718321474, + 0.1935464731735117, + 0.19978988733705805, + 0.2061073738537635, + 0.21249737397836072, + 0.21895831107393465, + 0.22548859100093405, + 0.23208660251050156, + 0.2387507176420256, + 0.24547929212481434, + 0.2522706657837962, + 0.2591231629491423, + 0.2660350928697134, + 0.2730047501302266, + 0.2800304150720424, + 0.28711035421746367, + 0.2942428206974456, + 0.30142605468260963, + 0.30865828381745525, + 0.31593772365766115, + 0.3232625781103715, + 0.3306310398773543, + 0.3380412909009253, + 0.34549150281252644, + 0.3529798373838481, + 0.3605044469803854, + 0.36806347501731357, + 0.3756550564175726, + 0.38327731807204724, + 0.39092837930172886, + 0.3986063523217438, + 0.4063093427071377, + 0.41403544986029517, + 0.4217827674798846, + 0.4295493840312088, + 0.4373333832178479, + 0.44513284445447737, + 0.45294584334074284, + 0.4607704521360776, + 0.4686047402353433, + 0.4764467746451787, + 0.48429462046093585, + 0.49214634134408974, + 0.5, + 0.5078536586559104, + 0.5157053795390641, + 0.5235532253548213, + 0.5313952597646567, + 0.5392295478639225, + 0.5470541566592572, + 0.5548671555455227, + 0.5626666167821522, + 
0.5704506159687914, + 0.5782172325201155, + 0.5859645501397047, + 0.5936906572928624, + 0.6013936476782563, + 0.6090716206982714, + 0.6167226819279528, + 0.6243449435824273, + 0.6319365249826864, + 0.6394955530196147, + 0.647020162616152, + 0.6545084971874737, + 0.6619587090990747, + 0.6693689601226458, + 0.6767374218896286, + 0.6840622763423391, + 0.6913417161825449, + 0.6985739453173903, + 0.7057571793025544, + 0.7128896457825363, + 0.7199695849279575, + 0.7269952498697734, + 0.7339649071302867, + 0.7408768370508576, + 0.7477293342162038, + 0.7545207078751857, + 0.7612492823579744, + 0.7679133974894983, + 0.7745114089990659, + 0.7810416889260654, + 0.7875026260216393, + 0.7938926261462367, + 0.8002101126629421, + 0.8064535268264883, + 0.8126213281678527, + 0.8187119948743449, + 0.8247240241650918, + 0.8306559326618259, + 0.8365062567548867, + 0.8422735529643444, + 0.8479563982961571, + 0.8535533905932737, + 0.8590631488815944, + 0.8644843137107058, + 0.8698155474893048, + 0.8750555348152298, + 0.8802029828000155, + 0.8852566213878946, + 0.8902152036691648, + 0.8950775061878451, + 0.8998423292435453, + 0.9045084971874737, + 0.9090748587125117, + 0.9135402871372809, + 0.9179036806841352, + 0.9221639627510075, + 0.9263200821770461, + 0.9303710135019718, + 0.9343157572190957, + 0.9381533400219317, + 0.9418828150443468, + 0.9455032620941839, + 0.9490137878803078, + 0.9524135262330098, + 0.9557016383177226, + 0.9588773128419905, + 0.9619397662556434, + 0.9648882429441257, + 0.9677220154149337, + 0.9704403844771128, + 0.9730426794137726, + 0.9755282581475768, + 0.977896507399165, + 0.9801468428384715, + 0.982278709228899, + 0.9842915805643155, + 0.9861849601988383, + 0.9879583809693737, + 0.9896114053108829, + 0.9911436253643444, + 0.9925546630773869, + 0.9938441702975689, + 0.9950118288582788, + 0.996057350657239, + 0.9969804777275899, + 0.99778098230154, + 0.998458666866564, + 0.9990133642141358, + 0.9994449374809851, + 0.9997532801828658, + 0.9999383162408303, + 1.0, 
+ 0.9999383162408303, + 0.9997532801828658, + 0.9994449374809851, + 0.9990133642141358, + 0.998458666866564, + 0.99778098230154, + 0.9969804777275899, + 0.996057350657239, + 0.9950118288582788, + 0.9938441702975689, + 0.9925546630773869, + 0.9911436253643444, + 0.9896114053108829, + 0.9879583809693737, + 0.9861849601988383, + 0.9842915805643155, + 0.982278709228899, + 0.9801468428384715, + 0.977896507399165, + 0.9755282581475768, + 0.9730426794137726, + 0.9704403844771128, + 0.9677220154149337, + 0.9648882429441257, + 0.9619397662556434, + 0.9588773128419905, + 0.9557016383177226, + 0.9524135262330098, + 0.9490137878803078, + 0.9455032620941839, + 0.9418828150443468, + 0.9381533400219317, + 0.9343157572190957, + 0.9303710135019718, + 0.9263200821770461, + 0.9221639627510075, + 0.9179036806841352, + 0.9135402871372809, + 0.9090748587125117, + 0.9045084971874737, + 0.8998423292435453, + 0.8950775061878451, + 0.8902152036691648, + 0.8852566213878946, + 0.8802029828000155, + 0.8750555348152298, + 0.8698155474893048, + 0.8644843137107058, + 0.8590631488815944, + 0.8535533905932737, + 0.8479563982961571, + 0.8422735529643444, + 0.8365062567548867, + 0.8306559326618259, + 0.8247240241650918, + 0.8187119948743449, + 0.8126213281678527, + 0.8064535268264883, + 0.8002101126629421, + 0.7938926261462367, + 0.7875026260216393, + 0.7810416889260654, + 0.7745114089990659, + 0.7679133974894983, + 0.7612492823579744, + 0.7545207078751857, + 0.7477293342162038, + 0.7408768370508576, + 0.7339649071302867, + 0.7269952498697734, + 0.7199695849279575, + 0.7128896457825363, + 0.7057571793025544, + 0.6985739453173903, + 0.6913417161825449, + 0.6840622763423391, + 0.6767374218896286, + 0.6693689601226458, + 0.6619587090990747, + 0.6545084971874737, + 0.647020162616152, + 0.6394955530196147, + 0.6319365249826864, + 0.6243449435824273, + 0.6167226819279528, + 0.6090716206982714, + 0.6013936476782563, + 0.5936906572928624, + 0.5859645501397047, + 0.5782172325201155, + 0.5704506159687914, + 
0.5626666167821522, + 0.5548671555455227, + 0.5470541566592572, + 0.5392295478639225, + 0.5313952597646567, + 0.5235532253548213, + 0.5157053795390641, + 0.5078536586559104, + 0.5, + 0.49214634134408974, + 0.48429462046093585, + 0.4764467746451787, + 0.4686047402353433, + 0.4607704521360776, + 0.45294584334074284, + 0.44513284445447737, + 0.4373333832178479, + 0.4295493840312088, + 0.4217827674798846, + 0.41403544986029517, + 0.4063093427071377, + 0.3986063523217438, + 0.39092837930172886, + 0.38327731807204724, + 0.3756550564175726, + 0.36806347501731357, + 0.3605044469803854, + 0.3529798373838481, + 0.34549150281252644, + 0.3380412909009253, + 0.3306310398773543, + 0.3232625781103715, + 0.31593772365766115, + 0.30865828381745525, + 0.30142605468260963, + 0.2942428206974456, + 0.28711035421746367, + 0.2800304150720424, + 0.2730047501302266, + 0.2660350928697134, + 0.2591231629491423, + 0.2522706657837962, + 0.24547929212481434, + 0.2387507176420256, + 0.23208660251050156, + 0.22548859100093405, + 0.21895831107393465, + 0.21249737397836072, + 0.2061073738537635, + 0.19978988733705805, + 0.1935464731735117, + 0.1873786718321474, + 0.18128800512565513, + 0.17527597583490823, + 0.16934406733817414, + 0.1634937432451133, + 0.15772644703565564, + 0.15204360170384285, + 0.14644660940672627, + 0.14093685111840565, + 0.13551568628929433, + 0.1301844525106951, + 0.1249444651847702, + 0.11979701719998453, + 0.11474337861210543, + 0.10978479633083521, + 0.1049224938121548, + 0.10015767075645471, + 0.09549150281252633, + 0.09092514128748835, + 0.08645971286271914, + 0.08209631931586497, + 0.07783603724899257, + 0.07367991782295402, + 0.06962898649802812, + 0.06568424278090434, + 0.06184665997806821, + 0.05811718495565327, + 0.054496737905816106, + 0.05098621211969223, + 0.04758647376699032, + 0.044298361682277465, + 0.04112268715800954, + 0.03806023374435674, + 0.035111757055874326, + 0.03227798458506631, + 0.029559615522887273, + 0.02695732058622735, + 0.024471741852423234, + 
0.02210349260083494,   0.019853157161528523,  0.017721290771101017,
    0.015708419435684517,  0.013815039801161721,  0.012041619030626338,
    0.01038859468911707,   0.00885637463565564,   0.007445336922613066,
    0.00615582970243117,   0.0049881711417212315, 0.0039426493427611176,
    0.003019522272410202,  0.002219017698460002,  0.001541333133436018,
    0.0009866357858642205, 0.0005550625190150482, 0.0002467198171342,
    6.168375916970614e-05};
  // Materialize the precomputed 400-point window table as a dense f64
  // tensor constant and return it.
  // NOTE(review): throughout this chunk the `create<OpTy>` template
  // arguments were lost in extraction (calls read `rewriter.create(...)`);
  // the intended op types (arith/memref/tensor/scf ops) must be restored
  // from the upstream source — TODO confirm against the original file.
  Value window = rewriter.create(
      loc, DenseFPElementsAttr::get(RankedTensorType::get(400, f64Ty),
                                    ArrayRef(hanningWindow400)));
  return window;
}

// Implement numpy-style reflect padding for a rank-1 f64 tensor.
// `low` is the left padding length, `high` is the right padding length.
// The result tensor has static size inputSize + low + high.
// `c0`/`c1` are caller-provided index constants 0 and 1.
Value padReflect(PatternRewriter &rewriter, Location loc, Value c0, Value c1,
                 Value input, int64_t low, int64_t high) {
  Value lowPadLen = rewriter.create(loc, low);
  Value highPadLen = rewriter.create(loc, high);
  SmallVector lowValues;
  SmallVector highValues;
  lowValues.push_back(lowPadLen);
  highValues.push_back(c0);

  FloatType f64Ty = rewriter.getF64Type();
  IndexType idxTy = rewriter.getIndexType();
  // Pad left part(low) for input tensor.
  // The input is statically shaped rank-1, so the padded size is computed
  // at compile time from its shape.
  int64_t inputSize =
      llvm::cast(input.getType()).getShape()[0];
  int64_t lowPaddedSize = inputSize + low;
  auto padOp1 = rewriter.create(
      loc, RankedTensorType::get(lowPaddedSize, f64Ty), input, lowValues,
      highValues);

  Region *padOpRegion1 = &padOp1.getRegion();
  int64_t sourceRank1 = llvm::cast(input.getType()).getRank();
  SmallVector blockArgTypes1(sourceRank1, idxTy);
  SmallVector blockArgLocs1(sourceRank1, loc);

  // Create Block for padOp1 and insert operations.
  // The pad region yields, for output index iv1 in the left pad area,
  // the reflected element input[lowPadLen - iv1].
  OpBuilder::InsertPoint ip1(rewriter.saveInsertionPoint());
  rewriter.createBlock(padOpRegion1, padOpRegion1->end(), blockArgTypes1,
                       blockArgLocs1);
  Value iv1 = padOp1.getRegion().front().getArgument(0);
  Value idx1 = rewriter.create(loc, lowPadLen, iv1);
  Value elem1 = rewriter.create(loc, input, idx1);
  rewriter.create(loc, elem1);
  rewriter.restoreInsertionPoint(ip1);
  lowValues.clear();
  highValues.clear();

  Value lowPaddedInput = padOp1.getResult();

  // Pad right part(high) for lowPaddedInput tensor.
  // symIndex is the last valid index of the low-padded tensor; the region
  // yields lowPaddedInput[symIndex - (iv2 - symIndex)], i.e. the mirror
  // image around the last element.
  lowValues.push_back(c0);
  highValues.push_back(highPadLen);
  int64_t highPaddedSize = lowPaddedSize + high;
  Value lowPaddedInputDim =
      rewriter.create(loc, lowPaddedInput, c0);
  Value symIndex = rewriter.create(loc, lowPaddedInputDim, c1);
  auto padOp2 = rewriter.create(
      loc, RankedTensorType::get(highPaddedSize, f64Ty), lowPaddedInput,
      lowValues, highValues);
  Region *padOpRegion2 = &padOp2.getRegion();
  int64_t sourceRank2 =
      llvm::cast(lowPaddedInput.getType()).getRank();
  SmallVector blockArgTypes2(sourceRank2, idxTy);
  SmallVector blockArgLocs2(sourceRank2, loc);

  OpBuilder::InsertPoint ip2(rewriter.saveInsertionPoint());
  rewriter.createBlock(padOpRegion2, padOpRegion2->end(), blockArgTypes2,
                       blockArgLocs2);
  Value iv2 = padOp2.getRegion().front().getArgument(0);
  Value sub = rewriter.create(loc, iv2, symIndex);
  Value idx2 = rewriter.create(loc, symIndex, sub);
  Value elem2 = rewriter.create(loc, lowPaddedInput, idx2);
  rewriter.create(loc, elem2);
  rewriter.restoreInsertionPoint(ip2);
  lowValues.clear();
  highValues.clear();

  return padOp2.getResult();
}

// Debug helper: emit IR that prints the first `l` elements of a rank-1
// memref, bracketed by "Print Start:" and a trailing newline.
// Intended for development only; not used on the hot path.
void printMemref(OpBuilder &rewriter, Location loc, Value input, int l) {

  Value c0 = rewriter.create(loc, 0);
  Value c1 = rewriter.create(loc, 1);
  Value length = rewriter.create(loc, l);
  rewriter.create(loc, "Print Start:\n");

  rewriter.create(
      loc, c0, length, c1, std::nullopt,
      [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) {
        Value x = b.create(loc, input, i);
        b.create(loc, x);

        b.create(loc, std::nullopt);
      });

  rewriter.create(loc, "\n");
}

// WA CC CH PM MULPM C1 C1w C2 CH2 CH2w CH_radfg CCw CSARR AR AI IANG are
// helper functions for RFFTP. They mirror the array-indexing macros of the
// pocketfft/FFTPACK real-FFT routines: each computes a flattened index into
// a 1-D buffer and loads (Value-returning helpers) or stores (void "...w"
// helpers) one element.

// wa[(x * (ido - 1)) + i] — twiddle-factor accessor.
inline Value WA(OpBuilder &builder, Location loc, Value wa, Value x, Value i,
                Value ido, Value c1) {
  Value idom1 = builder.create(loc, ido, c1);
  Value tmp1 = builder.create(loc, x, idom1);
  Value index = builder.create(loc, tmp1, i);
  return builder.create(loc, wa, index);
}

// cc[a + ido * (b + l1 * c)] — input accessor (read).
inline Value CC(OpBuilder &builder, Location loc, Value cc, Value a, Value b,
                Value c, Value ido, Value l1) {
  Value tmp1 = builder.create(loc, l1, c);
  Value tmp2 = builder.create(loc, tmp1, b);
  Value tmp3 = builder.create(loc, tmp2, ido);
  Value index = builder.create(loc, tmp3, a);
  return builder.create(loc, cc, index);
}

// ch[a + ido * (b + cdim * c)] = toWrite — output accessor (write).
inline void CH(OpBuilder &builder, Location loc, Value ch, Value a, Value b,
               Value c, Value ido, Value cdim, Value toWrite) {
  Value tmp1 = builder.create(loc, cdim, c);
  Value tmp2 = builder.create(loc, tmp1, b);
  Value tmp3 = builder.create(loc, tmp2, ido);
  Value index = builder.create(loc, tmp3, a);
  builder.create(loc, toWrite, ch, index);
  return;
}

// Returns {c + d, c - d} (pocketfft's PM "plus/minus" pair).
// NOTE(review): op types stripped — presumably float add/sub; confirm.
inline std::vector PM(OpBuilder &builder, Location loc, Value c,
                      Value d) {
  return {builder.create(loc, c, d),
          builder.create(loc, c, d)};
}

// Returns {c*e + d*f, c*f - d*e} (pocketfft's MULPM complex multiply).
inline std::vector MULPM(OpBuilder &builder, Location loc, Value c,
                         Value d, Value e, Value f) {
  Value tmp1 = builder.create(loc, c, e);
  Value tmp2 = builder.create(loc, d, f);
  Value tmp3 = builder.create(loc, c, f);
  Value tmp4 = builder.create(loc, d, e);
  return {builder.create(loc, tmp1, tmp2),
          builder.create(loc, tmp3, tmp4)};
}

// cc[a + ido * (b + l1 * c)] — same layout as CC, used on the work buffer.
inline Value C1(OpBuilder &builder, Location loc, Value cc, Value a, Value b,
                Value c, Value ido, Value l1) {
  Value tmp1 = builder.create(loc, l1, c);
  Value tmp2 = builder.create(loc, tmp1, b);
  Value tmp3 = builder.create(loc, tmp2, ido);
  Value index = builder.create(loc, tmp3, a);
  return builder.create(loc, cc, index);
}

// Store form of C1.
inline void C1w(OpBuilder &builder, Location loc, Value cc, Value a, Value b,
                Value c, Value ido, Value l1, Value toWrite) {
  Value tmp1 = builder.create(loc, l1, c);
  Value tmp2 = builder.create(loc, tmp1, b);
  Value tmp3 = builder.create(loc, tmp2, ido);
  Value index = builder.create(loc, tmp3, a);
  builder.create(loc, toWrite, cc, index);
  return;
}

// cc[a + idl1 * b] — 2-D view with leading dimension idl1 (read).
inline Value C2(OpBuilder &builder, Location loc, Value cc, Value a, Value b,
                Value idl1) {
  Value tmp1 = builder.create(loc, idl1, b);
  Value index = builder.create(loc, tmp1, a);
  return builder.create(loc, cc, index);
}

// ch[a + idl1 * b] (read).
inline Value CH2(OpBuilder &builder, Location loc, Value ch, Value a, Value b,
                 Value idl1) {
  Value tmp1 = builder.create(loc, idl1, b);
  Value index = builder.create(loc, tmp1, a);
  return builder.create(loc, ch, index);
}

// ch[a + idl1 * b] = toWrite (write).
inline void CH2w(OpBuilder &builder, Location loc, Value ch, Value a, Value b,
                 Value idl1, Value toWrite) {
  Value tmp1 = builder.create(loc, idl1, b);
  Value index = builder.create(loc, tmp1, a);
  builder.create(loc, toWrite, ch, index);
  return;
}

// ch[a + ido * (b + l1 * c)] (read) — radfg-specific CH layout.
inline Value CH_radfg(OpBuilder &builder, Location loc, Value ch, Value a,
                      Value b, Value c, Value ido, Value l1) {
  Value tmp = builder.create(loc, l1, c);
  Value tmp1 = builder.create(loc, b, tmp);
  Value tmp2 = builder.create(loc, tmp1, ido);
  Value index = builder.create(loc, tmp2, a);
  return builder.create(loc, ch, index);
}

// cc[a + ido * (b + cdim * c)] = toWrite — radfg-specific CC write.
inline void CCw(OpBuilder &builder, Location loc, Value cc, Value a, Value b,
                Value c, Value ido, Value cdim, Value toWrite) {
  Value tmp = builder.create(loc, cdim, c);
  Value tmp1 = builder.create(loc, b, tmp);
  Value tmp2 = builder.create(loc, tmp1, ido);
  Value index = builder.create(loc, tmp2, a);
  builder.create(loc, toWrite, cc, index);
  return;
}

// csarr[index] — cosine/sine twiddle table accessor.
inline Value CSARR(OpBuilder &builder, Location loc, Value csarr, Value index) {

  return builder.create(loc, csarr, index);
}

// AR = csarr[2 * iang] — real (cosine) part at angle index iang.
inline Value AR(OpBuilder &builder, Location loc, Value csarr, Value iang) {
  Value c2 = builder.create(loc, 2);
  Value index = builder.create(loc, iang, c2);
  return CSARR(builder, loc, csarr, index);
}

// AI = csarr[2 * iang + 1] — imaginary (sine) part at angle index iang.
inline Value AI(OpBuilder &builder, Location loc, Value csarr, Value iang) {
  Value c1 = builder.create(loc, 1);
  Value c2 = builder.create(loc, 2);
  Value tmp = builder.create(loc, iang, c2);
  Value index = builder.create(loc, tmp, c1);
  return CSARR(builder, loc, csarr, index);
}

// Advance the angle index modulo ip: iang = (iang + l) and, when the sum
// reaches ip, wrap it back into [0, ip). Emits an scf-style if/else that
// yields the (possibly wrapped) value.
inline Value IANG(OpBuilder &builder, Location loc, Value iang, Value l,
                  Value ip) {

  Value iang_new = builder.create(loc, iang, l);

  Value condition = builder.create(
      loc, arith::CmpIPredicate::sge, iang_new, ip);

  auto result = builder.create(
      loc, condition,
      [&](OpBuilder &b, Location loc) {
        Value res = b.create(loc, iang_new, ip);
        b.create(loc, ValueRange{res});
      },
      [&](OpBuilder &b, Location loc) {
        b.create(loc, ValueRange{iang_new});
      });

  return result.getResult(0);
}

// Second half of the general-radix real-FFT forward pass (pocketfft radfg):
// accumulates the j-column sums into ch, copies the k/i plane of ch into cc,
// and recombines the symmetric j / jc column pairs into the packed
// real-FFT output layout. Called by radfg() after its butterfly phase.
void radfgExtend(OpBuilder &opBuilder, Location loc, Value cc, Value ch,
                 Value wa, Value csarr, Value ido, Value ip, Value l1) {

  Value c0 = opBuilder.create(loc, 0);
  Value c1 = opBuilder.create(loc, 1);
  Value c2 = opBuilder.create(loc, 2);
  Value c3 = opBuilder.create(loc, 3);
  Value c4 = opBuilder.create(loc, 4);

  // cdim is derived from ip and c0; ipph = (ip + 1) / 2 presumably —
  // TODO confirm op types (stripped in extraction).
  Value cdim = opBuilder.create(loc, ip, c0);
  Value tmp0 = opBuilder.create(loc, ip, c1);
  Value ipph = opBuilder.create(loc, tmp0, c2);
  Value idom1 = opBuilder.create(loc, ido, c1);
  // NOTE(review): idom2 appears unused below — verify before removing.
  Value idom2 = opBuilder.create(loc, ido, c2);
  Value idl1 = opBuilder.create(loc, ido, l1);

  // ch2[ik][0] = c2[ik][0]
  opBuilder.create(
      loc, c0, idl1, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value ik, ValueRange ik_args) {
        Value c2ik0 = C2(builder, loc, cc, ik, c0, idl1);
        CH2w(builder, loc, ch, ik, c0, idl1, c2ik0);
        builder.create(loc, std::nullopt);
      });

  // ch2[ik][0] += c2[ik][j] for j in [1, ipph)
  opBuilder.create(
      loc, c1, ipph, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value j, ValueRange j_args) {
        builder.create(
            loc, c0, idl1, c1, std::nullopt,
            [&](OpBuilder &b, Location loc, Value ik, ValueRange ik_args) {
              Value c2ikj = C2(b, loc, cc, ik, j, idl1);
              Value ch2ik0 = CH2(b, loc, ch, ik, c0, idl1);
              Value ch2ik0_updated =
                  b.create(loc, ch2ik0, c2ikj);

              CH2w(b, loc, ch, ik, c0, idl1, ch2ik0_updated);
              b.create(loc, std::nullopt);
            });
        builder.create(loc, std::nullopt);
      });

  // cc[i][0][k] = ch[i][k][0]
  opBuilder.create(
      loc, c0, l1, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) {
        builder.create(
            loc, c0, ido, c1, std::nullopt,
            [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) {
              Value chik0 = CH_radfg(b, loc, ch, i, k, c0, ido, l1);

              CCw(b, loc, cc, i, c0, k, ido, cdim, chik0);
              b.create(loc, std::nullopt);
            });
        builder.create(loc, std::nullopt);
      });

  // Column-pair recombination: j counts up from 1, jc counts down from
  // ip - 1; j2 = 2*j - 1.
  Value j_start_0 = opBuilder.create(loc, 1);
  Value jc_start_0 = opBuilder.create(loc, ip, c1);

  opBuilder.create(
      loc, c1, ipph, c1, ValueRange{j_start_0, jc_start_0},
      [&](OpBuilder &builder, Location loc, Value j_loop,
          ValueRange j_loop_args) {
        Value j = j_loop_args[0];
        Value jc = j_loop_args[1];

        Value tmp = builder.create(loc, j, c2);
        Value j2 = builder.create(loc, tmp, c1);
        Value j2p1 = builder.create(loc, j2, c1);

        builder.create(
            loc, c0, l1, c1, std::nullopt,
            [&](OpBuilder &b, Location loc, Value k, ValueRange k_args) {
              Value ch0kj = CH_radfg(b, loc, ch, c0, k, j, ido, l1);
              CCw(b, loc, cc, idom1, j2, k, ido, cdim, ch0kj);

              Value ch0kjc = CH_radfg(b, loc, ch, c0, k, jc, ido, l1);
              CCw(b, loc, cc, c0, j2p1, k, ido, cdim, ch0kjc);

              b.create(loc, std::nullopt);
            });

        Value j_next = builder.create(loc, j, c1);
        Value jc_next = builder.create(loc, jc, c1);
        builder.create(loc, std::vector{j_next, jc_next});
      });

  // Only when ido != l1: recombine the interior i-columns as well.
  // NOTE(review): upstream pocketfft gates this on ido != 1 — verify the
  // ido/l1 comparison here against the original source.
  Value condition1 =
      opBuilder.create(loc, arith::CmpIPredicate::ne, ido, l1);

  opBuilder.create(
      loc, condition1, [&](OpBuilder &builder, Location loc) {
        // NOTE(review): these two constants are built with the outer
        // opBuilder inside a nested region — confirm intended.
        Value j_start_1 = opBuilder.create(loc, 1);
        Value jc_start_1 = opBuilder.create(loc, ip, c1);

        builder.create(
            loc, c1, ipph, c1, ValueRange{j_start_1, jc_start_1},
            [&](OpBuilder &b, Location loc, Value j_loop,
                ValueRange j_loop_args) {
              Value j = j_loop_args[0];
              Value jc = j_loop_args[1];

              Value tmp = b.create(loc, j, c2);
              Value j2 = b.create(loc, tmp, c1);
              Value j2p1 = b.create(loc, j2, c1);

              b.create(
                  loc, c0, l1, c1, std::nullopt,
                  [&](OpBuilder &b2, Location loc, Value k, ValueRange k_args) {
                    // i walks up by 2 (re/im interleaved); ic walks down.
                    Value i_start_0 = b2.create(loc, 1);
                    Value ic_start_0 = b2.create(loc, ido, c3);

                    b2.create(
                        loc, c1, idom1, c2, ValueRange{i_start_0, ic_start_0},
                        [&](OpBuilder &b3, Location loc, Value i_loop,
                            ValueRange i_loop_args) {
                          Value i = i_loop_args[0];
                          Value ic = i_loop_args[1];

                          Value ip1 = b3.create(loc, i, c1);
                          Value icp1 = b3.create(loc, ic, c1);

                          // Real parts: sum and difference of the j/jc pair.
                          Value chikj = CH_radfg(b3, loc, ch, i, k, j, ido, l1);
                          Value chikjc =
                              CH_radfg(b3, loc, ch, i, k, jc, ido, l1);
                          Value tmp2 =
                              b3.create(loc, chikj, chikjc);
                          Value tmp3 =
                              b3.create(loc, chikj, chikjc);
                          CCw(b3, loc, cc, i, j2p1, k, ido, cdim, tmp2);
                          CCw(b3, loc, cc, ic, j2, k, ido, cdim, tmp3);

                          // Imaginary parts: note the swapped operand order
                          // for the second combination.
                          Value chip1kj =
                              CH_radfg(b3, loc, ch, ip1, k, j, ido, l1);
                          Value chip1kjc =
                              CH_radfg(b3, loc, ch, ip1, k, jc, ido, l1);
                          Value tmp4 =
                              b3.create(loc, chip1kj, chip1kjc);
                          Value tmp5 =
                              b3.create(loc, chip1kjc, chip1kj);
                          CCw(b3, loc, cc, ip1, j2p1, k, ido, cdim, tmp4);
                          CCw(b3, loc, cc, icp1, j2, k, ido, cdim, tmp5);

                          Value i_next = b3.create(loc, i, c2);
                          Value ic_next = b3.create(loc, ic, c2);
                          b3.create(
                              loc, std::vector{i_next, ic_next});
                        });
                    b2.create(loc, std::nullopt);
                  });

              Value j_next = b.create(loc, j, c1);
              Value jc_next = b.create(loc, jc, c1);
              b.create(loc, std::vector{j_next, jc_next});
            });
        builder.create(loc, std::nullopt);
      });

  return;
}

// Handle general radix FFT computation.
// General-radix forward real-FFT stage (pocketfft/FFTPACK "radfg"):
// applies twiddle factors to the work buffer cc, forms symmetric column
// sums/differences, accumulates the cos/sin (csarr) weighted combinations
// into ch in three strength-reduced loops (step 4, step 2, step 1), then
// delegates the final recombination to radfgExtend().
// NOTE(review): all `create<OpTy>` template arguments in this chunk were
// stripped by extraction; op types below (arith/memref/scf) must be
// restored from the upstream source — TODO confirm.
void radfg(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa,
           Value csarr, Value ido, Value ip, Value l1) {

  Value c0 = opBuilder.create(loc, 0);
  Value c1 = opBuilder.create(loc, 1);
  Value c2 = opBuilder.create(loc, 2);
  Value c3 = opBuilder.create(loc, 3);
  Value c4 = opBuilder.create(loc, 4);

  Value ipm1 = opBuilder.create(loc, ip, c1);
  Value ipm2 = opBuilder.create(loc, ip, c2);

  // cdim derived from ip and c0; ipph presumably (ip + 1) / 2 — confirm.
  Value cdim = opBuilder.create(loc, ip, c0);
  Value tmp = opBuilder.create(loc, ip, c1);
  Value ipph = opBuilder.create(loc, tmp, c2);

  Value idl1 = opBuilder.create(loc, ido, l1);
  Value idom1 = opBuilder.create(loc, ido, c1);
  // NOTE(review): idom2 appears unused below — verify before removing.
  Value idom2 = opBuilder.create(loc, ido, c2);

  // Twiddle phase, only taken for the larger-ido case.
  Value condition =
      opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, l1);

  opBuilder.create(
      loc, condition, [&](OpBuilder &builder, Location loc) {
        Value jc_start = builder.create(loc, ip, c1);

        builder.create(
            loc, c1, ipph, c1, ValueRange{jc_start},
            [&](OpBuilder &b, Location loc, Value j, ValueRange j_args) {
              Value jc = j_args[0];

              Value jm1 = b.create(loc, j, c1);
              Value jcm1 = b.create(loc, jc, c1);

              // Twiddle base offsets for columns j and jc.
              Value is = b.create(loc, jm1, idom1);
              Value is2 = b.create(loc, jcm1, idom1);

              b.create(
                  loc, c0, l1, c1, std::nullopt,
                  [&](OpBuilder &b2, Location loc, Value k, ValueRange k_args) {
                    Value idij_start = b2.create(loc, is, c0);
                    Value idij2_start = b2.create(loc, is2, c0);

                    b2.create(
                        loc, c1, idom1, c2, ValueRange{idij_start, idij2_start},
                        [&](OpBuilder &b3, Location loc, Value i,
                            ValueRange i_args) {
                          Value idij = i_args[0];
                          Value idij2 = i_args[1];

                          Value ip1 = b3.create(loc, i, c1);
                          Value idijp1 =
                              b3.create(loc, idij, c1);
                          Value idij2p1 =
                              b3.create(loc, idij2, c1);

                          // (t1,t2) / (t3,t4): re/im pairs of columns j / jc.
                          Value t1 = C1(b3, loc, cc, i, k, j, ido, l1);
                          Value t2 = C1(b3, loc, cc, ip1, k, j, ido, l1);
                          Value t3 = C1(b3, loc, cc, i, k, jc, ido, l1);
                          Value t4 = C1(b3, loc, cc, ip1, k, jc, ido, l1);

                          Value waidij =
                              b3.create(loc, wa, idij);
                          Value waidijp1 =
                              b3.create(loc, wa, idijp1);
                          Value waidij2 =
                              b3.create(loc, wa, idij2);
                          Value waidij2p1 =
                              b3.create(loc, wa, idij2p1);

                          // Complex multiply (t1,t2) by twiddle (waidij,
                          // waidijp1): x1 = re, x2 = im.
                          Value tmp1_x1 =
                              b3.create(loc, waidij, t1);
                          Value tmp2_x1 =
                              b3.create(loc, waidijp1, t2);
                          Value x1 =
                              b3.create(loc, tmp1_x1, tmp2_x1);

                          Value tmp1_x2 =
                              b3.create(loc, waidij, t2);
                          Value tmp2_x2 =
                              b3.create(loc, waidijp1, t1);
                          Value x2 =
                              b3.create(loc, tmp1_x2, tmp2_x2);

                          // Same for column jc: x3 = re, x4 = im.
                          Value tmp1_x3 =
                              b3.create(loc, waidij2, t3);
                          Value tmp2_x3 =
                              b3.create(loc, waidij2p1, t4);
                          Value x3 =
                              b3.create(loc, tmp1_x3, tmp2_x3);

                          Value tmp1_x4 =
                              b3.create(loc, waidij2, t4);
                          Value tmp2_x4 =
                              b3.create(loc, waidij2p1, t3);
                          Value x4 =
                              b3.create(loc, tmp1_x4, tmp2_x4);

                          // Symmetric combinations written back in place.
                          Value tmp3 = b3.create(loc, x1, x3);
                          Value tmp4 = b3.create(loc, x2, x4);
                          Value tmp5 = b3.create(loc, x2, x4);
                          Value tmp6 = b3.create(loc, x3, x1);

                          C1w(b3, loc, cc, i, k, j, ido, l1, tmp3);
                          C1w(b3, loc, cc, i, k, jc, ido, l1, tmp4);
                          C1w(b3, loc, cc, ip1, k, j, ido, l1, tmp5);
                          C1w(b3, loc, cc, ip1, k, jc, ido, l1, tmp6);

                          Value idij_next =
                              b3.create(loc, idij, c2);
                          Value idij2_next =
                              b3.create(loc, idij2, c2);

                          b3.create(
                              loc, std::vector{idij_next, idij2_next});
                        });
                    b2.create(loc, std::nullopt);
                  }

              );

              Value jc_next = b.create(loc, jc, c1);
              b.create(loc, jc_next);
            });

        builder.create(loc, std::nullopt);
      });

  // First-element column pairs: combine cc[0][k][j] and cc[0][k][jc].
  Value jc_a_start = opBuilder.create(loc, ip, c1);

  opBuilder.create(
      loc, c1, ipph, c1, ValueRange{jc_a_start},
      [&](OpBuilder &builder, Location loc, Value j_a, ValueRange j_a_args) {
        Value jc_a = j_a_args[0];

        builder.create(
            loc, c0, l1, c1, std::nullopt,
            [&](OpBuilder &b, Location loc, Value k_a, ValueRange k_a_args) {
              Value t1_a = C1(b, loc, cc, c0, k_a, j_a, ido, l1);
              Value t2_a = C1(b, loc, cc, c0, k_a, jc_a, ido, l1);

              Value tmp_a = b.create(loc, t1_a, t2_a);
              Value tmp1_a = b.create(loc, t2_a, t1_a);

              C1w(b, loc, cc, c0, k_a, j_a, ido, l1, tmp_a);
              C1w(b, loc, cc, c0, k_a, jc_a, ido, l1, tmp1_a);
              b.create(loc, std::nullopt);
            });

        Value jc_a_next = builder.create(loc, jc_a, c1);
        builder.create(loc, jc_a_next);
      });

  // Main accumulation: for each output column pair (l_b, lc_b), seed ch
  // with the first cos/sin terms, then fold in the remaining columns with
  // angle stepping via IANG — unrolled by 4, then by 2, then by 1.
  Value lc_b_start = opBuilder.create(loc, ip, c1);

  opBuilder.create(
      loc, c1, ipph, c1, ValueRange{lc_b_start},
      [&](OpBuilder &builder, Location loc, Value l_b, ValueRange l_b_args) {
        Value lc_b = l_b_args[0];

        builder.create(
            loc, c0, idl1, c1, std::nullopt,
            [&](OpBuilder &b, Location loc, Value ik_b, ValueRange ik_b_args) {
              Value m2l = b.create(loc, l_b, c2);
              Value m4l = b.create(loc, l_b, c4);
              Value m2lp1 = b.create(loc, m2l, c1);
              Value m4lp1 = b.create(loc, m4l, c1);

              Value csarr2l = CSARR(b, loc, csarr, m2l);
              Value csarr4l = CSARR(b, loc, csarr, m4l);
              Value csarr2lp1 = CSARR(b, loc, csarr, m2lp1);
              Value csarr4lp1 = CSARR(b, loc, csarr, m4lp1);

              Value c2ik0 = C2(b, loc, cc, ik_b, c0, idl1);
              Value c2ik1 = C2(b, loc, cc, ik_b, c1, idl1);
              Value c2ik2 = C2(b, loc, cc, ik_b, c2, idl1);

              Value c2ikipm1 = C2(b, loc, cc, ik_b, ipm1, idl1);
              Value c2ikipm2 = C2(b, loc, cc, ik_b, ipm2, idl1);

              // Seed: ch2[ik][l] = c2[ik][0] + cos-weighted columns 1, 2.
              Value tmp_b = b.create(loc, csarr2l, c2ik1);
              Value tmp1_b = b.create(loc, csarr4l, c2ik2);
              Value tmp2_b = b.create(loc, tmp_b, tmp1_b);
              Value tmp3_b = b.create(loc, c2ik0, tmp2_b);

              CH2w(b, loc, ch, ik_b, l_b, idl1, tmp3_b);

              // Seed: ch2[ik][lc] = sin-weighted columns ip-1, ip-2.
              Value tmp4_b = b.create(loc, csarr2lp1, c2ikipm1);
              Value tmp5_b = b.create(loc, csarr4lp1, c2ikipm2);
              Value tmp6_b = b.create(loc, tmp4_b, tmp5_b);

              CH2w(b, loc, ch, ik_b, lc_b, idl1, tmp6_b);
              b.create(loc, std::nullopt);
            });

        Value iang_start_c = builder.create(loc, c2, l_b);
        Value j_start_c = builder.create(loc, 3);
        Value jc_start_c = builder.create(loc, ip, c3);
        Value ipphm1 = builder.create(loc, ipph, c1);
        Value ipphm3 = builder.create(loc, ipph, c3);

        // Unroll-by-4 accumulation loop; carries (j, jc, iang).
        auto loop1 = builder.create(
            loc, j_start_c, ipphm3, c4,
            ValueRange{j_start_c, jc_start_c, iang_start_c},
            [&](OpBuilder &b, Location loc, Value j_loop,
                ValueRange j_loop_args) {
              Value j = j_loop_args[0];
              Value jc = j_loop_args[1];
              Value iang = j_loop_args[2];

              Value iang_1_c = IANG(b, loc, iang, l_b, ip);
              Value ar1 = AR(b, loc, csarr, iang_1_c);
              Value ai1 = AI(b, loc, csarr, iang_1_c);

              Value iang_2_c = IANG(b, loc, iang_1_c, l_b, ip);
              Value ar2 = AR(b, loc, csarr, iang_2_c);
              Value ai2 = AI(b, loc, csarr, iang_2_c);

              Value iang_3_c = IANG(b, loc, iang_2_c, l_b, ip);
              Value ar3 = AR(b, loc, csarr, iang_3_c);
              Value ai3 = AI(b, loc, csarr, iang_3_c);

              Value iang_4_c = IANG(b, loc, iang_3_c, l_b, ip);
              Value ar4 = AR(b, loc, csarr, iang_4_c);
              Value ai4 = AI(b, loc, csarr, iang_4_c);

              b.create(
                  loc, c0, idl1, c1, std::nullopt,
                  [&](OpBuilder &b2, Location loc, Value ik_c,
                      ValueRange ik_c_args) {
                    Value jp1 = b2.create(loc, j, c1);
                    Value jp2 = b2.create(loc, j, c2);
                    Value jp3 = b2.create(loc, j, c3);
                    // NOTE(review): jm1/jm2/jm3 appear unused — verify.
                    Value jm1 = b2.create(loc, j, c1);
                    Value jm2 = b2.create(loc, j, c2);
                    Value jm3 = b2.create(loc, j, c3);

                    Value c2ikj = C2(b2, loc, cc, ik_c, j, idl1);
                    Value c2ikjp1 = C2(b2, loc, cc, ik_c, jp1, idl1);
                    Value c2ikjp2 = C2(b2, loc, cc, ik_c, jp2, idl1);
                    Value c2ikjp3 = C2(b2, loc, cc, ik_c, jp3, idl1);

                    // Cosine-weighted contributions to ch2[ik][l].
                    Value tmp_c = b2.create(loc, ar1, c2ikj);
                    Value tmp1_c = b2.create(loc, ar2, c2ikjp1);
                    Value tmp2_c = b2.create(loc, ar3, c2ikjp2);
                    Value tmp3_c = b2.create(loc, ar4, c2ikjp3);

                    Value tmp4_c = b2.create(loc, tmp_c, tmp1_c);
                    Value tmp5_c =
                        b2.create(loc, tmp4_c, tmp2_c);
                    Value tmp6_c =
                        b2.create(loc, tmp5_c, tmp3_c);

                    Value ch2ikl = CH2(b2, loc, ch, ik_c, l_b, idl1);
                    Value tmp7_c =
                        b2.create(loc, tmp6_c, ch2ikl);
                    CH2w(b2, loc, ch, ik_c, l_b, idl1, tmp7_c);

                    Value jcm1 = b2.create(loc, jc, c1);
                    Value jcm2 = b2.create(loc, jc, c2);
                    Value jcm3 = b2.create(loc, jc, c3);

                    Value c2ikjc = C2(b2, loc, cc, ik_c, jc, idl1);
                    Value c2ikjcm1 = C2(b2, loc, cc, ik_c, jcm1, idl1);
                    Value c2ikjcm2 = C2(b2, loc, cc, ik_c, jcm2, idl1);
                    Value c2ikjcm3 = C2(b2, loc, cc, ik_c, jcm3, idl1);

                    // Sine-weighted contributions to ch2[ik][lc].
                    Value tmp_ai1 = b2.create(loc, ai1, c2ikjc);
                    Value tmp_ai2 =
                        b2.create(loc, ai2, c2ikjcm1);
                    Value tmp_ai3 =
                        b2.create(loc, ai3, c2ikjcm2);
                    Value tmp_ai4 =
                        b2.create(loc, ai4, c2ikjcm3);

                    Value tmp_ai5 =
                        b2.create(loc, tmp_ai1, tmp_ai2);
                    Value tmp_ai6 =
                        b2.create(loc, tmp_ai5, tmp_ai3);
                    Value tmp_ai7 =
                        b2.create(loc, tmp_ai6, tmp_ai4);

                    Value ch2iklc = CH2(b2, loc, ch, ik_c, lc_b, idl1);
                    Value tmp_ai8 =
                        b2.create(loc, tmp_ai7, ch2iklc);
                    CH2w(b2, loc, ch, ik_c, lc_b, idl1, tmp_ai8);

                    b2.create(loc, std::nullopt);
                  });

              Value j_next = b.create(loc, j, c4);
              Value jc_next = b.create(loc, jc, c4);
              // NOTE(review): yield built with outer `builder`, not the
              // loop-local `b` — confirm intended in the original source.
              builder.create(
                  loc, std::vector{j_next, jc_next, iang_4_c});
            });

        Value j_1_c = loop1.getResults()[0];
        Value jc_1_c = loop1.getResults()[1];
        Value iang1_c = loop1.getResults()[2];

        // Unroll-by-2 remainder loop.
        auto loop2 = builder.create(
            loc, j_1_c, ipphm1, c2, ValueRange{j_1_c, jc_1_c, iang1_c},
            [&](OpBuilder &b, Location loc, Value j_loop,
                ValueRange j_loop_args) {
              Value j = j_loop_args[0];
              Value jc = j_loop_args[1];
              Value iang = j_loop_args[2];

              Value iang_1_d = IANG(b, loc, iang, l_b, ip);
              Value ar1 = AR(b, loc, csarr, iang_1_d);
              Value ai1 = AI(b, loc, csarr, iang_1_d);

              Value iang_2_d = IANG(b, loc, iang_1_d, l_b, ip);
              Value ar2 = AR(b, loc, csarr, iang_2_d);
              Value ai2 = AI(b, loc, csarr, iang_2_d);

              b.create(
                  loc, c0, idl1, c1, std::nullopt,
                  [&](OpBuilder &b2, Location loc, Value ik_d,
                      ValueRange ik_d_args) {
                    Value jp1 = b2.create(loc, j, c1);
                    // NOTE(review): jm1 appears unused — verify.
                    Value jm1 = b2.create(loc, j, c1);

                    Value c2ikj = C2(b2, loc, cc, ik_d, j, idl1);
                    Value c2ikjp1 = C2(b2, loc, cc, ik_d, jp1, idl1);

                    Value tmp_c = b2.create(loc, ar1, c2ikj);
                    Value tmp1_c = b2.create(loc, ar2, c2ikjp1);
                    Value tmp2_c = b2.create(loc, tmp_c, tmp1_c);

                    Value ch2ikl = CH2(b2, loc, ch, ik_d, l_b, idl1);
                    Value tmp3_c =
                        b2.create(loc, tmp2_c, ch2ikl);
                    CH2w(b2, loc, ch, ik_d, l_b, idl1, tmp3_c);

                    Value jcm1 = b2.create(loc, jc, c1);
                    Value c2ikjc = C2(b2, loc, cc, ik_d, jc, idl1);
                    Value c2ikjcm1 = C2(b2, loc, cc, ik_d, jcm1, idl1);

                    Value tmp_ai1 = b2.create(loc, ai1, c2ikjc);
                    Value tmp_ai2 =
                        b2.create(loc, ai2, c2ikjcm1);
                    Value tmp_ai3 =
                        b2.create(loc, tmp_ai1, tmp_ai2);

                    Value ch2iklc = CH2(b2, loc, ch, ik_d, lc_b, idl1);
                    Value tmp_ai4 =
                        b2.create(loc, tmp_ai3, ch2iklc);
                    CH2w(b2, loc, ch, ik_d, lc_b, idl1, tmp_ai4);

                    b2.create(loc, std::nullopt);
                  });

              Value j_next = b.create(loc, j, c2);
              Value jc_next = b.create(loc, jc, c2);
              // NOTE(review): outer `builder` used for yield — see loop1.
              builder.create(
                  loc, std::vector{j_next, jc_next, iang_2_d});
            });

        Value j_2_c = loop2.getResults()[0];
        Value jc_2_c = loop2.getResults()[1];
        Value iang2_c = loop2.getResults()[2];

        // Final single-step remainder loop.
        auto loop3 = builder.create(
            loc, j_2_c, ipph, c1, ValueRange{j_2_c, jc_2_c, iang2_c},
            [&](OpBuilder &b, Location loc, Value j_loop,
                ValueRange j_loop_args) {
              Value j = j_loop_args[0];
              Value jc = j_loop_args[1];
              Value iang = j_loop_args[2];

              Value iang_1_e = IANG(b, loc, iang, l_b, ip);
              Value ar = AR(b, loc, csarr, iang_1_e);
              Value ai = AI(b, loc, csarr, iang_1_e);

              b.create(
                  loc, c0, idl1, c1, std::nullopt,
                  [&](OpBuilder &b2, Location loc, Value ik_e,
                      ValueRange ik_e_args) {
                    Value c2ikj = C2(b2, loc, cc, ik_e, j, idl1);
                    Value tmp_c = b2.create(loc, ar, c2ikj);
                    Value ch2ikl = CH2(b2, loc, ch, ik_e, l_b, idl1);
                    Value tmp2_c = b2.create(loc, tmp_c, ch2ikl);
                    CH2w(b2, loc, ch, ik_e, l_b, idl1, tmp2_c);

                    Value c2ikjc = C2(b2, loc, cc, ik_e, jc, idl1);
                    Value tmp_ai = b2.create(loc, ai, c2ikjc);
                    Value ch2iklc = CH2(b2, loc, ch, ik_e, lc_b, idl1);
                    Value tmp2_ai =
                        b2.create(loc, tmp_ai, ch2iklc);
                    CH2w(b2, loc, ch, ik_e, lc_b, idl1, tmp2_ai);

                    b2.create(loc, std::nullopt);
                  });

              Value j_next = b.create(loc, j, c2);
              Value jc_next = b.create(loc, jc, c2);
              // NOTE(review): outer `builder` used for yield — see loop1.
              builder.create(
                  loc, std::vector{j_next, jc_next, iang_1_e});
            });

        Value lc_b_next = builder.create(loc, lc_b, c1);
        builder.create(loc, lc_b_next);
      });

  // Final recombination into the packed real-FFT layout.
  radfgExtend(opBuilder, loc, cc, ch, wa, csarr, ido, ip, l1);
}

// Inner-column (i >= 2) butterfly of the radix-2 stage; only invoked by
// radf2() when ido > 2. Applies twiddle factors via WA/MULPM and writes
// the packed sum/difference pairs into ch via CH.
void radf2Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch,
                 Value wa, Value ido, Value l1, Value cdim) {
  // NOTE(review): f64Ty and several constants below (c3, c4, c20, idom1)
  // appear unused in this function — verify before removing.
  FloatType f64Ty = opBuilder.getF64Type();

  Value c0 = opBuilder.create(loc, 0);
  Value c1 = opBuilder.create(loc, 1);
  Value c2 = opBuilder.create(loc, 2);
  Value c3 = opBuilder.create(loc, 3);
  Value c4 = opBuilder.create(loc, 4);
  Value c20 = opBuilder.create(loc, 20);

  Value idom1 = opBuilder.create(loc, ido, c1);

  opBuilder.create(
      loc, c0, l1, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) {
        builder.create(
            loc, c2, ido, c2, std::nullopt,
            [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) {
              // ic mirrors i from the top of the column; i-1/i-2 select the
              // interleaved re/im input elements.
              Value ic = b.create(loc, ido, i);
              Value icm1 = b.create(loc, ic, c1);
              Value im1 = b.create(loc, i, c1);
              Value im2 = b.create(loc, i, c2);

              Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1);
              Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1);
              Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1);
              Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1);
              // (tr2, ti2) = twiddle * second input column.
              std::vector tr2_ti2 =
                  MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1);

              Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1);
              Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1);
              std::vector ccim1k0_tr2 = PM(b, loc, ccim1k0, tr2_ti2[0]);
              std::vector ti2_ccik0 = PM(b, loc, tr2_ti2[1], ccik0);

              CH(b, loc, ch, im1, c0, k, ido, cdim, ccim1k0_tr2[0]);
              CH(b, loc, ch, icm1, c1, k, ido, cdim, ccim1k0_tr2[1]);

              CH(b, loc, ch, i, c0, k, ido, cdim, ti2_ccik0[0]);
              CH(b, loc, ch, ic, c1, k, ido, cdim, ti2_ccik0[1]);
              b.create(loc, std::nullopt);
            });
        builder.create(loc, std::nullopt);
      });
}

// Handle radix-2 FFT computation (pocketfft/FFTPACK "radf2"): first-element
// butterflies for every k, the even-ido special case, then the inner
// columns via radf2Extend when ido > 2.
void radf2(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa,
           Value ido, Value l1) {

  // NOTE(review): f64Ty, c3, c4 and c20 appear unused here — verify.
  FloatType f64Ty = opBuilder.getF64Type();
  Value cdim = opBuilder.create(loc, 2);

  Value c0 = opBuilder.create(loc, 0);
  Value c1 = opBuilder.create(loc, 1);
  Value c2 = opBuilder.create(loc, 2);
  Value c3 = opBuilder.create(loc, 3);
  Value c4 = opBuilder.create(loc, 4);
  Value c20 = opBuilder.create(loc, 20);

  Value idom1 = opBuilder.create(loc, ido, c1);

  // ch[0][0][k] / ch[ido-1][1][k] = cc[0][k][0] ± cc[0][k][1].
  opBuilder.create(
      loc, c0, l1, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value iv, ValueRange iv_args) {
        Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1);
        Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1);
        std::vector cc0k0_cc0k1 = PM(builder, loc, cc0k0, cc0k1);
        CH(builder, loc, ch, c0, c0, iv, ido, cdim, cc0k0_cc0k1[0]);
        CH(builder, loc, ch, idom1, c1, iv, ido, cdim, cc0k0_cc0k1[1]);
        builder.create(loc, std::nullopt);
      });

  // Even-ido special case (ido % 2 == 0): move/negate the last elements.
  Value flag = opBuilder.create(loc, ido, c2);
  Value condition =
      opBuilder.create(loc, arith::CmpIPredicate::eq, flag, c0);

  opBuilder.create(
      loc, condition, [&](OpBuilder &builder, Location loc) {
        builder.create(
            loc, c0, l1, c1, std::nullopt,
            [&](OpBuilder &b, Location loc, Value k, ValueRange k_args) {
              Value ccidom1k1 = CC(b, loc, cc, idom1, k, c1, ido, l1);
              // Negation of the second column's last element.
              Value tmp = b.create(loc, ccidom1k1);
              CH(b, loc, ch, c0, c1, k, ido, cdim, tmp);
              Value ccidom1k0 = CC(b, loc, cc, idom1, k, c0, ido, l1);
              CH(b, loc, ch, idom1, c0, k, ido, cdim, ccidom1k0);
              b.create(loc, std::nullopt);
            });
        builder.create(loc, std::nullopt);
      });

  // Inner columns only exist when ido > 2.
  Value condition1 =
      opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, c2);
  opBuilder.create(
      loc, condition1, [&](OpBuilder &builder, Location loc) {
        radf2Extend(builder, loc, cc, ch, wa, ido, l1, cdim);
        builder.create(loc, std::nullopt);
      });
}

// Inner-column butterfly of the radix-3 stage; invoked by radf3() for
// ido != 1. taur = cos(2π/3) = -0.5, taui = sin(2π/3).
void radf3Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch,
                 Value wa, Value ido, Value l1, Value cdim) {

  FloatType f64Ty = opBuilder.getF64Type();
  Value taur =
      opBuilder.create(loc, APFloat(double(-0.5)), f64Ty);
  Value taui = opBuilder.create(
      loc, APFloat(double(0.86602540378443864676)), f64Ty);

  Value c0 = opBuilder.create(loc, 0);
  Value c1 = opBuilder.create(loc, 1);
  Value c2 = opBuilder.create(loc, 2);
  // NOTE(review): c3 and c4 appear unused here — verify.
  Value c3 = opBuilder.create(loc, 3);
  Value c4 = opBuilder.create(loc, 4);

  opBuilder.create(
      loc, c0, l1, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) {
        builder.create(
            loc, c2, ido, c2, std::nullopt,
            [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) {
              Value ic = b.create(loc, ido, i);
              Value icm1 = b.create(loc, ic, c1);
              Value im1 = b.create(loc, i, c1);
              Value im2 = b.create(loc, i, c2);

              // (dr2, di2) = twiddle0 * column 1.
              Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1);
              Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1);
              Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1);
              Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1);
              std::vector dr2_di2 =
                  MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1);

              // (dr3, di3) = twiddle1 * column 2.
              Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1);
              Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1);
              Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1);
              Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1);
              std::vector dr3_di3 =
                  MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2);

              Value cr2 = b.create(loc, dr2_di2[0], dr3_di3[0]);
              Value ci2 = b.create(loc, dr2_di2[1], dr3_di3[1]);

              Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1);
              Value tmp5 = b.create(loc, ccim1k0, cr2);
              // NOTE(review): CH called with outer `builder` inside this
              // inner lambda (elsewhere `b` is used) — confirm intended.
              CH(builder, loc, ch, im1, c0, k, ido, cdim, tmp5);

              Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1);
              Value tmp6 = b.create(loc, ccik0, ci2);
              CH(builder, loc, ch, i, c0, k, ido, cdim, tmp6);

              Value tmp7 = b.create(loc, taur, cr2);
              Value tr2 = b.create(loc, ccim1k0, tmp7);

              Value tmp8 = b.create(loc, taur, ci2);
              Value ti2 = b.create(loc, ccik0, tmp8);

              Value tmp9 = b.create(loc, dr2_di2[1], dr3_di3[1]);
              Value tr3 = b.create(loc, taui, tmp9);

              Value tmp10 =
                  b.create(loc, dr3_di3[0], dr2_di2[0]);
              Value ti3 =
                  b.create(loc, taui, tmp10);

              std::vector tr2_tr3 = PM(b, loc, tr2, tr3);
              std::vector ti3_ti2 = PM(b, loc, ti3, ti2);

              CH(builder, loc, ch, im1, c2, k, ido, cdim, tr2_tr3[0]);
              CH(builder, loc, ch, icm1, c1, k, ido, cdim, tr2_tr3[1]);

              CH(builder, loc, ch, i, c2, k, ido, cdim, ti3_ti2[0]);
              CH(builder, loc, ch, ic, c1, k, ido, cdim, ti3_ti2[1]);

              b.create(loc, std::nullopt);
            });
        builder.create(loc, std::nullopt);
      });
}

// Handle radix-3 FFT computation
// (Definition continues past the end of this chunk; reproduced unchanged.)
void radf3(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa,
           Value ido, Value l1) {

  FloatType f64Ty = opBuilder.getF64Type();
  Value cdim = opBuilder.create(loc, 3);
  Value taur =
      opBuilder.create(loc, APFloat(double(-0.5)), f64Ty);
  Value taui = opBuilder.create(
      loc, APFloat(double(0.86602540378443864676)), f64Ty);

  Value c0 = opBuilder.create(loc, 0);
  Value c1 = opBuilder.create(loc, 1);
  Value c2 = opBuilder.create(loc, 2);
  Value c3 = opBuilder.create(loc, 3);
  Value c4 = opBuilder.create(loc, 4);

  Value idom1 = opBuilder.create(loc, ido, c1);

  opBuilder.create(
      loc, c0, l1, c1, std::nullopt,
      [&](OpBuilder &builder, Location loc, Value iv, ValueRange iv_args) {
        Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1);
        Value cc0k2 = CC(builder, loc, cc, c0, iv, c2, ido, l1);
        Value cr2 = builder.create(loc, cc0k1, cc0k2);

        Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1);
        Value tmp0 = builder.create(loc, cc0k0, cr2);
        CH(builder, loc, ch, c0, c0, iv, ido, cdim, tmp0);

        Value tmp1 = builder.create(loc, cc0k2, cc0k1);
        Value tmp2 = builder.create(loc, tmp1, taui);
        CH(builder, loc, ch, c0, c2, iv, ido, cdim, tmp2);

        Value tmp3 = builder.create(loc, taur, cr2);
        Value tmp4 = builder.create(loc, tmp3, cc0k0);
        CH(builder, loc, ch, idom1, c1, iv, ido, cdim, tmp4);

        builder.create(loc, std::nullopt);
      });

  Value condition =
      opBuilder.create(loc, arith::CmpIPredicate::ne, ido, c1);
  opBuilder.create(
      loc,
condition, [&](OpBuilder &builder, Location loc) { + radf3Extend(builder, loc, cc, ch, wa, ido, l1, cdim); + builder.create(loc, std::nullopt); + }); +} + +void radf4Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim, Value c0, Value c1, + Value c2, Value c3) { + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) { + builder.create( + loc, c2, ido, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + Value ic = b.create(loc, ido, i); + Value icm1 = b.create(loc, ic, c1); + Value im1 = b.create(loc, i, c1); + Value im2 = b.create(loc, i, c2); + + Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1); + Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); + Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); + Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); + std::vector cr2_ci2 = + MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); + + Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1); + Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1); + Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1); + Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1); + std::vector cr3_ci3 = + MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2); + + Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1); + Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1); + Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1); + Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1); + std::vector cr4_ci4 = + MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3); + + std::vector tr1_tr4 = PM(b, loc, cr4_ci4[0], cr2_ci2[0]); + std::vector ti1_ti4 = PM(b, loc, cr2_ci2[1], cr4_ci4[1]); + Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); + std::vector tr2_tr3 = PM(b, loc, ccim1k0, cr3_ci3[0]); + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); + std::vector ti2_ti3 = PM(b, loc, ccik0, cr3_ci3[1]); + + std::vector chtmp0 = PM(b, loc, tr2_tr3[0], tr1_tr4[0]); + CH(b, loc, ch, im1, c0, k, ido, cdim, 
chtmp0[0]); + CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp0[1]); + + std::vector chtmp1 = PM(b, loc, ti1_ti4[0], ti2_ti3[0]); + CH(b, loc, ch, i, c0, k, ido, cdim, chtmp1[0]); + CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp1[1]); + + std::vector chtmp2 = PM(b, loc, tr2_tr3[1], ti1_ti4[1]); + CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp2[0]); + CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp2[1]); + + std::vector chtmp3 = PM(b, loc, tr1_tr4[1], ti2_ti3[1]); + CH(b, loc, ch, i, c2, k, ido, cdim, chtmp3[0]); + CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp3[1]); + + b.create(loc, std::nullopt); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +// Handle radix-4 FFT computation +void radf4(OpBuilder &opBuilder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1, Value c0, Value c1, Value c2, Value c3) { + FloatType f64Ty = opBuilder.getF64Type(); + Value cdim = opBuilder.create(loc, 4); + Value hsqt2 = opBuilder.create( + loc, APFloat(double(0.70710678118654752440)), f64Ty); + Value idom1 = opBuilder.create(loc, ido, c1); + + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iargs) { + Value cc0k3 = CC(builder, loc, cc, c0, iv, c3, ido, l1); + Value cc0k1 = CC(builder, loc, cc, c0, iv, c1, ido, l1); + std::vector tr1_tmp0 = PM(builder, loc, cc0k3, cc0k1); + CH(builder, loc, ch, c0, c2, iv, ido, cdim, tr1_tmp0[1]); + + Value cc0k0 = CC(builder, loc, cc, c0, iv, c0, ido, l1); + Value cc0k2 = CC(builder, loc, cc, c0, iv, c2, ido, l1); + std::vector tr2_tmp1 = PM(builder, loc, cc0k0, cc0k2); + CH(builder, loc, ch, idom1, c1, iv, ido, cdim, tr2_tmp1[1]); + + std::vector tmp2_tmp3 = + PM(builder, loc, tr2_tmp1[0], tr1_tmp0[0]); + CH(builder, loc, ch, c0, c0, iv, ido, cdim, tmp2_tmp3[0]); + CH(builder, loc, ch, idom1, c3, iv, ido, cdim, tmp2_tmp3[1]); + + builder.create(loc, std::nullopt); + }); + + Value reminder = opBuilder.create(loc, ido, c2); + Value condition0 = opBuilder.create( + loc, 
arith::CmpIPredicate::eq, reminder, c0); + opBuilder.create( + loc, condition0, [&](OpBuilder &builder, Location loc) { + Value negHsqt2 = builder.create( + loc, APFloat(double(-0.70710678118654752440)), f64Ty); + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) { + Value ccidom1k1 = CC(b, loc, cc, idom1, iv, c1, ido, l1); + Value ccidom1k3 = CC(b, loc, cc, idom1, iv, c3, ido, l1); + Value tmp0 = b.create(loc, ccidom1k1, ccidom1k3); + Value ti1 = b.create(loc, negHsqt2, tmp0); + + Value tmp1 = b.create(loc, ccidom1k1, ccidom1k3); + Value tr1 = b.create(loc, hsqt2, tmp1); + + Value ccidom1k0 = CC(b, loc, cc, idom1, iv, c0, ido, l1); + std::vector tmp2_tmp3 = PM(b, loc, ccidom1k0, tr1); + CH(b, loc, ch, idom1, c0, iv, ido, cdim, tmp2_tmp3[0]); + CH(b, loc, ch, idom1, c2, iv, ido, cdim, tmp2_tmp3[1]); + + Value ccidom1k2 = CC(b, loc, cc, idom1, iv, c2, ido, l1); + std::vector tmp4_tmp5 = PM(b, loc, ti1, ccidom1k2); + CH(b, loc, ch, c0, c3, iv, ido, cdim, tmp4_tmp5[0]); + CH(b, loc, ch, c0, c1, iv, ido, cdim, tmp4_tmp5[1]); + + b.create(loc, std::nullopt); + }); + + builder.create(loc, std::nullopt); + }); + + Value condition1 = + opBuilder.create(loc, arith::CmpIPredicate::sgt, ido, c2); + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + radf4Extend(builder, loc, cc, ch, wa, ido, l1, cdim, c0, c1, c2, c3); + builder.create(loc, std::nullopt); + }); + + return; +} + +void radf5Extend(OpBuilder &opBuilder, Location loc, Value cc, Value ch, + Value wa, Value ido, Value l1, Value cdim, Value tr11, + Value tr12, Value ti11, Value ti12, Value c0, Value c1, + Value c2, Value c3, Value c4) { + opBuilder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value k, ValueRange kargs) { + builder.create( + loc, c2, ido, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange iargs) { + Value ic = b.create(loc, ido, i); + Value icm1 = 
b.create(loc, ic, c1); + Value im1 = b.create(loc, i, c1); + Value im2 = b.create(loc, i, c2); + + Value wa0im2 = WA(b, loc, wa, c0, im2, ido, c1); + Value wa0im1 = WA(b, loc, wa, c0, im1, ido, c1); + Value ccim1k1 = CC(b, loc, cc, im1, k, c1, ido, l1); + Value ccik1 = CC(b, loc, cc, i, k, c1, ido, l1); + std::vector dr2_di2 = + MULPM(b, loc, wa0im2, wa0im1, ccim1k1, ccik1); + + Value wa1im2 = WA(b, loc, wa, c1, im2, ido, c1); + Value wa1im1 = WA(b, loc, wa, c1, im1, ido, c1); + Value ccim1k2 = CC(b, loc, cc, im1, k, c2, ido, l1); + Value ccik2 = CC(b, loc, cc, i, k, c2, ido, l1); + std::vector dr3_di3 = + MULPM(b, loc, wa1im2, wa1im1, ccim1k2, ccik2); + + Value wa2im2 = WA(b, loc, wa, c2, im2, ido, c1); + Value wa2im1 = WA(b, loc, wa, c2, im1, ido, c1); + Value ccim1k3 = CC(b, loc, cc, im1, k, c3, ido, l1); + Value ccik3 = CC(b, loc, cc, i, k, c3, ido, l1); + std::vector dr4_di4 = + MULPM(b, loc, wa2im2, wa2im1, ccim1k3, ccik3); + + Value wa3im2 = WA(b, loc, wa, c3, im2, ido, c1); + Value wa3im1 = WA(b, loc, wa, c3, im1, ido, c1); + Value ccim1k4 = CC(b, loc, cc, im1, k, c4, ido, l1); + Value ccik4 = CC(b, loc, cc, i, k, c4, ido, l1); + std::vector dr5_di5 = + MULPM(b, loc, wa3im2, wa3im1, ccim1k4, ccik4); + + std::vector cr2_ci5 = PM(b, loc, dr5_di5[0], dr2_di2[0]); + std::vector ci2_cr5 = PM(b, loc, dr2_di2[1], dr5_di5[1]); + std::vector cr3_ci4 = PM(b, loc, dr4_di4[0], dr3_di3[0]); + std::vector ci3_cr4 = PM(b, loc, dr3_di3[1], dr4_di4[1]); + + Value ccim1k0 = CC(b, loc, cc, im1, k, c0, ido, l1); + Value tmpch0 = b.create(loc, ccim1k0, cr2_ci5[0]); + Value chim10k = b.create(loc, tmpch0, cr3_ci4[0]); + CH(b, loc, ch, im1, c0, k, ido, cdim, chim10k); + + Value ccik0 = CC(b, loc, cc, i, k, c0, ido, l1); + Value tmpch1 = b.create(loc, ccik0, ci2_cr5[0]); + Value chi0k = b.create(loc, tmpch1, ci3_cr4[0]); + CH(b, loc, ch, i, c0, k, ido, cdim, chi0k); + + Value tmp0 = b.create(loc, tr11, cr2_ci5[0]); + Value tmp1 = b.create(loc, ccim1k0, tmp0); + Value tmp2 = 
b.create(loc, tr12, cr3_ci4[0]); + Value tr2 = b.create(loc, tmp1, tmp2); + + Value tmp3 = b.create(loc, tr11, ci2_cr5[0]); + Value tmp4 = b.create(loc, ccik0, tmp3); + Value tmp5 = b.create(loc, tr12, ci3_cr4[0]); + Value ti2 = b.create(loc, tmp4, tmp5); + + Value tmp6 = b.create(loc, tr12, cr2_ci5[0]); + Value tmp7 = b.create(loc, ccim1k0, tmp6); + Value tmp8 = b.create(loc, tr11, cr3_ci4[0]); + Value tr3 = b.create(loc, tmp7, tmp8); + + Value tmp9 = b.create(loc, tr12, ci2_cr5[0]); + Value tmp10 = b.create(loc, ccik0, tmp9); + Value tmp11 = b.create(loc, tr11, ci3_cr4[0]); + Value ti3 = b.create(loc, tmp10, tmp11); + + std::vector tr5_tr4 = + MULPM(b, loc, ci2_cr5[1], ci3_cr4[1], ti11, ti12); + std::vector ti5_ti4 = + MULPM(b, loc, cr2_ci5[1], cr3_ci4[1], ti11, ti12); + + std::vector chtmp0 = PM(b, loc, tr2, tr5_tr4[0]); + CH(b, loc, ch, im1, c2, k, ido, cdim, chtmp0[0]); + CH(b, loc, ch, icm1, c1, k, ido, cdim, chtmp0[1]); + + std::vector chtmp1 = PM(b, loc, ti5_ti4[0], ti2); + CH(b, loc, ch, i, c2, k, ido, cdim, chtmp1[0]); + CH(b, loc, ch, ic, c1, k, ido, cdim, chtmp1[1]); + + std::vector chtmp2 = PM(b, loc, tr3, tr5_tr4[1]); + CH(b, loc, ch, im1, c4, k, ido, cdim, chtmp2[0]); + CH(b, loc, ch, icm1, c3, k, ido, cdim, chtmp2[1]); + + std::vector chtmp3 = PM(b, loc, ti5_ti4[1], ti3); + CH(b, loc, ch, i, c4, k, ido, cdim, chtmp3[0]); + CH(b, loc, ch, ic, c3, k, ido, cdim, chtmp3[1]); + + b.create(loc, std::nullopt); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +// Handle radix-5 FFT computation +void radf5(OpBuilder &builder, Location loc, Value cc, Value ch, Value wa, + Value ido, Value l1, Value c0, Value c1, Value c2, Value c3, + Value c4) { + + FloatType f64Ty = builder.getF64Type(); + Value cdim = builder.create(loc, 5); + Value tr11 = builder.create( + loc, APFloat(double(0.3090169943749474241)), f64Ty); + Value tr12 = builder.create( + loc, APFloat(double(-0.8090169943749474241)), f64Ty); + Value ti11 = builder.create( + loc, 
APFloat(double(0.95105651629515357212)), f64Ty); + Value ti12 = builder.create( + loc, APFloat(double(0.58778525229247312917)), f64Ty); + Value idom1 = builder.create(loc, ido, c1); + + builder.create( + loc, c0, l1, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) { + Value cc0k4 = CC(b, loc, cc, c0, iv, c4, ido, l1); + Value cc0k1 = CC(b, loc, cc, c0, iv, c1, ido, l1); + std::vector cr2_ci5 = PM(b, loc, cc0k4, cc0k1); + + Value cc0k3 = CC(b, loc, cc, c0, iv, c3, ido, l1); + Value cc0k2 = CC(b, loc, cc, c0, iv, c2, ido, l1); + std::vector cr3_ci4 = PM(b, loc, cc0k3, cc0k2); + + Value cc0k0 = CC(b, loc, cc, c0, iv, c0, ido, l1); + Value tmpch0 = b.create(loc, cc0k0, cr2_ci5[0]); + Value ch0 = b.create(loc, tmpch0, cr3_ci4[0]); + CH(b, loc, ch, c0, c0, iv, ido, cdim, ch0); + + Value tmpch1 = b.create(loc, tr11, cr2_ci5[0]); + Value tmpch2 = b.create(loc, tr12, cr3_ci4[0]); + Value tmpch3 = b.create(loc, cc0k0, tmpch1); + Value ch1 = b.create(loc, tmpch2, tmpch3); + CH(b, loc, ch, idom1, c1, iv, ido, cdim, ch1); + + Value tmpch4 = b.create(loc, ti11, cr2_ci5[1]); + Value tmpch5 = b.create(loc, ti12, cr3_ci4[1]); + Value ch2 = b.create(loc, tmpch4, tmpch5); + CH(b, loc, ch, c0, c2, iv, ido, cdim, ch2); + + Value tmpch6 = b.create(loc, tr12, cr2_ci5[0]); + Value tmpch7 = b.create(loc, tr11, cr3_ci4[0]); + Value tmpch8 = b.create(loc, tmpch6, tmpch7); + Value ch3 = b.create(loc, cc0k0, tmpch8); + CH(b, loc, ch, idom1, c3, iv, ido, cdim, ch3); + + Value tmpch9 = b.create(loc, ti12, cr2_ci5[1]); + Value tmpch10 = b.create(loc, ti11, cr3_ci4[1]); + Value ch4 = b.create(loc, tmpch9, tmpch10); + CH(b, loc, ch, c0, c4, iv, ido, cdim, ch4); + + b.create(loc, std::nullopt); + }); + + Value condition = + builder.create(loc, arith::CmpIPredicate::ne, ido, c1); + builder.create(loc, condition, [&](OpBuilder &b, Location loc) { + radf5Extend(b, loc, cc, ch, wa, ido, l1, cdim, tr11, tr12, ti11, ti12, c0, + c1, c2, c3, c4); + b.create(loc, 
std::nullopt); + }); + + return; +} + +// function to implement ++ operation +void index_increment(OpBuilder &opBuilder, Location loc, Value target) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value a = opBuilder.create(loc, target, c0); + Value b = opBuilder.create(loc, a, c1); + opBuilder.create(loc, b, target, c0); +} + +// switch 2 element in an array +void index_SWAP(OpBuilder &opBuilder, Location loc, Value array, Value target1, + Value target2) { + Value a = opBuilder.create(loc, array, target1); + Value b = opBuilder.create(loc, array, target2); + + opBuilder.create(loc, a, array, target2); + opBuilder.create(loc, b, array, target1); +} + +// factorize the input length ans store factors in Rfftp_fctdata_fct +Value rfftp_factorize(OpBuilder &opBuilder, Location loc, + Value Rfftp_fctdata_fct, Value Rfftp_fctdata_tw, + Value Rfftp_fctdata_tws, Value Rfftp_plan_length, + Value Rfftp_plan_nfct, Value Rfftp_plan_mem) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c_neg1 = opBuilder.create(loc, -1); + Value NFCT = opBuilder.create(loc, 25); + + FloatType f64Ty = opBuilder.getF64Type(); + IndexType indexTy = opBuilder.getIndexType(); + + Value length = + opBuilder.create(loc, MemRefType::get(1, indexTy)); + Value length_1 = opBuilder.create(loc, Rfftp_plan_length, c0); + opBuilder.create(loc, length_1, length, c0); + + Value nfct = + opBuilder.create(loc, MemRefType::get(1, indexTy)); + + opBuilder.create(loc, c0, nfct, c0); + + auto loop = opBuilder.create( + loc, TypeRange{indexTy}, ValueRange{length_1}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value length_while = args[0]; + + Value length_mod_4 = + builder.create(loc, length_while, c4); + Value condition = builder.create( + loc, arith::CmpIPredicate::eq, length_mod_4, c0); + builder.create(loc, 
condition, + ValueRange{length_while}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value length_while = args[0]; + + Value currnet_nfct = builder.create(loc, nfct, c0); + builder.create(loc, c4, Rfftp_fctdata_fct, + currnet_nfct); + index_increment(builder, loc, nfct); + Value length_next = + builder.create(loc, length_while, c2); + builder.create(loc, length_next, length, c0); + + builder.create(loc, std::vector{length_next}); + }); + + Value length_if = opBuilder.create(loc, length, c0); + Value length_if_mod_2 = opBuilder.create(loc, length_if, c2); + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, length_if_mod_2, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value length_next = builder.create(loc, length_if, c1); + builder.create(loc, length_next, length, c0); + + Value currnet_nfct = builder.create(loc, nfct, c0); + builder.create(loc, c2, Rfftp_fctdata_fct, + currnet_nfct); + index_increment(builder, loc, nfct); + + Value currnet_nfct_1 = builder.create(loc, nfct, c0); + Value nfctm1 = builder.create(loc, currnet_nfct_1, c1); + index_SWAP(builder, loc, Rfftp_fctdata_fct, nfctm1, c0); + + builder.create(loc, std::nullopt); + }); + + TypeRange type1 = TypeRange{f64Ty}; + TypeRange type2 = TypeRange{indexTy}; + + Value maxl = + opBuilder.create(loc, MemRefType::get(1, indexTy)); + Value current_length2 = opBuilder.create(loc, length, c0); + Value current_length2_i32 = opBuilder.create( + loc, opBuilder.getI32Type(), current_length2); + Value length_f64 = opBuilder.create( + loc, opBuilder.getF64Type(), current_length2_i32); + Value sqrt_length = opBuilder.create(loc, length_f64); + Value maxl_index = opBuilder.create( + loc, opBuilder.getI32Type(), sqrt_length); + Value maxl_index_index = opBuilder.create( + loc, opBuilder.getIndexType(), maxl_index); + Value maxl_final = opBuilder.create(loc, maxl_index_index, c1); + opBuilder.create(loc, maxl_final, maxl, c0); + + 
opBuilder.create( + loc, TypeRange{indexTy}, ValueRange{c3}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value divisor = args[0]; + Value length_while = builder.create(loc, length, c0); + Value current_maxl = builder.create(loc, maxl, c0); + + Value condition1 = builder.create( + loc, arith::CmpIPredicate::sgt, length_while, c1); + Value condition2 = builder.create( + loc, arith::CmpIPredicate::slt, divisor, current_maxl); + Value and_cond = + builder.create(loc, condition1, condition2); + builder.create(loc, and_cond, ValueRange{divisor}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value divisor = args[0]; + + Value length_while = builder.create(loc, length, c0); + Value length_mod_divisor = + builder.create(loc, length_while, divisor); + Value condition1 = builder.create( + loc, arith::CmpIPredicate::eq, length_mod_divisor, c0); + builder.create( + loc, condition1, [&](OpBuilder &b, Location loc) { + b.create( + loc, TypeRange{indexTy}, ValueRange{c1}, + [&](OpBuilder &b2, Location loc, ValueRange args) { + Value x = args[0]; + + Value length_while_1 = + b2.create(loc, length, c0); + Value length_mod_divisor_1 = + b2.create(loc, length_while_1, divisor); + + Value condition2 = + b2.create(loc, arith::CmpIPredicate::eq, + length_mod_divisor_1, c0); + b2.create(loc, condition2, ValueRange{x}); + }, + [&](OpBuilder &b2, Location loc, ValueRange args) { + Value x = args[0]; + + Value currnet_nfct = + b2.create(loc, nfct, c0); + b2.create(loc, divisor, Rfftp_fctdata_fct, + currnet_nfct); + index_increment(b2, loc, nfct); + + Value length_while_1 = + b2.create(loc, length, c0); + Value length_new = + b2.create(loc, length_while_1, divisor); + b2.create(loc, length_new, length, c0); + + b2.create(loc, std::vector{x}); + }); + + Value current_length2_1 = + b.create(loc, length, c0); + Value currnet_length2_i32_1 = b.create( + loc, opBuilder.getI32Type(), current_length2_1); + Value length_f64_1 = b.create( + loc, 
opBuilder.getF64Type(), currnet_length2_i32_1); + Value sqrt_length_1 = b.create(loc, length_f64_1); + Value maxl_index_1 = + b.create(loc, b.getI32Type(), sqrt_length_1); + Value maxl_index_index_1 = b.create( + loc, opBuilder.getIndexType(), maxl_index_1); + Value maxl_final_1 = + b.create(loc, maxl_index_index_1, c1); + b.create(loc, maxl_final_1, maxl, c0); + + b.create(loc, std::nullopt); + }); + + Value divisor_next = builder.create(loc, divisor, c2); + builder.create(loc, std::vector{divisor_next}); + }); + + Value current_length1 = opBuilder.create(loc, length, c0); + Value condition1 = opBuilder.create( + loc, arith::CmpIPredicate::sgt, current_length1, c1); + opBuilder.create( + loc, condition1, [&](OpBuilder &builder, Location loc) { + Value current_nfct = builder.create(loc, nfct, c0); + builder.create(loc, current_length1, Rfftp_fctdata_fct, + current_nfct); + index_increment(builder, loc, nfct); + builder.create(loc, std::nullopt); + }); + + Value current_nfct1 = opBuilder.create(loc, nfct, c0); + opBuilder.create(loc, current_nfct1, Rfftp_plan_nfct, c0); + + return c0; +} + +Value index_to_f64(OpBuilder &opBuilder, Location loc, Value n) { + TypeRange type = TypeRange{opBuilder.getF64Type()}; + Value n_i32 = + opBuilder.create(loc, opBuilder.getI32Type(), n); + Value n_f64 = + opBuilder.create(loc, opBuilder.getF64Type(), n_i32); + return n_f64; +} + +Value f64_to_index(OpBuilder &opBuilder, Location loc, Value n_f64) { + TypeRange type = TypeRange{opBuilder.getI32Type()}; + Value n_i32 = + opBuilder.create(loc, opBuilder.getI32Type(), n_f64); + Value n_index = opBuilder.create( + loc, opBuilder.getIndexType(), n_i32); + return n_index; +} + +void my_sincosm1pi(OpBuilder &opBuilder, Location loc, Value a, Value res, + Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + 
StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, resultType, res, SmallVector{bias}, + SmallVector{c2}, SmallVector{c1}); + + Value s = opBuilder.create(loc, a, a); + + Value r1 = opBuilder.create( + loc, APFloat(double(-1.0369917389758117e-4)), f64Ty); + Value r2 = opBuilder.create( + loc, APFloat(double(1.9294935641298806e-3)), f64Ty); + Value r3 = opBuilder.create( + loc, APFloat(double(-2.5806887942825395e-2)), f64Ty); + Value r4 = opBuilder.create( + loc, APFloat(double(2.3533063028328211e-1)), f64Ty); + Value r5 = opBuilder.create( + loc, APFloat(double(-1.3352627688538006e+0)), f64Ty); + Value r6 = opBuilder.create( + loc, APFloat(double(4.0587121264167623e+0)), f64Ty); + Value r7 = opBuilder.create( + loc, APFloat(double(-4.9348022005446790e+0)), f64Ty); + + Value fma1 = opBuilder.create(loc, r1, s, r2); + Value fma2 = opBuilder.create(loc, fma1, s, r3); + Value fma3 = opBuilder.create(loc, fma2, s, r4); + Value fma4 = opBuilder.create(loc, fma3, s, r5); + Value fma5 = opBuilder.create(loc, fma4, s, r6); + Value fma6 = opBuilder.create(loc, fma5, s, r7); + + Value c = opBuilder.create(loc, fma6, s); + + Value r8 = opBuilder.create( + loc, APFloat(double(4.6151442520157035e-4)), f64Ty); + Value r9 = opBuilder.create( + loc, APFloat(double(-7.3700183130883555e-3)), f64Ty); + Value r10 = opBuilder.create( + loc, APFloat(double(8.2145868949323936e-2)), f64Ty); + Value r11 = opBuilder.create( + loc, APFloat(double(-5.9926452893214921e-1)), f64Ty); + Value r12 = opBuilder.create( + loc, APFloat(double(2.5501640398732688e+0)), f64Ty); + Value r13 = opBuilder.create( + loc, APFloat(double(-5.1677127800499516e+0)), f64Ty); + + Value fma7 = opBuilder.create(loc, r8, s, r9); + Value fma8 = opBuilder.create(loc, fma7, s, r10); + Value fma9 = opBuilder.create(loc, fma8, s, 
r11); + Value fma10 = opBuilder.create(loc, fma9, s, r12); + Value fma11 = opBuilder.create(loc, fma10, s, r13); + + Value s_new = opBuilder.create(loc, s, a); + Value r = opBuilder.create(loc, fma11, s_new); + + Value pi = opBuilder.create( + loc, APFloat(double(3.1415926535897931e+0)), f64Ty); + Value s_final = opBuilder.create(loc, a, pi, r); + + opBuilder.create(loc, c, res_raw, c0); + opBuilder.create(loc, s_final, res_raw, c1); + + return; +} + +void calc_first_octant_extend2(OpBuilder &opBuilder, Location loc, Value den, + Value res, Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c50 = opBuilder.create(loc, 50); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, bias); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, resultType, res, SmallVector{bias}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value f2 = + opBuilder.create(loc, APFloat(double(2.0)), f64Ty); + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + + Value n_f64 = index_to_f64(opBuilder, loc, n); + Value l1_f64 = opBuilder.create(loc, n_f64); + Value l1 = f64_to_index(opBuilder, loc, l1_f64); + + opBuilder.create( + loc, c1, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iargs) { + Value i_f64 = index_to_f64(builder, loc, i); + Value 
den_f64 = index_to_f64(builder, loc, den); + Value arg = builder.create(loc, i_f64, den_f64); + Value arg_scaled = builder.create(loc, arg, f2); + + Value im2 = builder.create(loc, i, c2); + Value im2_bias = builder.create(loc, im2, bias); + + my_sincosm1pi(builder, loc, arg_scaled, res, im2_bias); + builder.create(loc, std::nullopt); + }); + + Value start_start = opBuilder.create(loc, l1, c0); + + opBuilder.create( + loc, start_start, n, l1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value start_loop, + ValueRange start_loop_args) { + Value start_f64 = index_to_f64(builder, loc, start_loop); + Value den_f64 = index_to_f64(builder, loc, den); + Value arg = builder.create(loc, start_f64, den_f64); + Value arg_scaled = builder.create(loc, arg, f2); + + Value cs = + builder.create(loc, MemRefType::get(2, f64Ty)); + my_sincosm1pi(builder, loc, arg_scaled, cs, c0); + + Value cs0 = builder.create(loc, cs, c0); + Value cs1 = builder.create(loc, cs, c1); + + Value cs0_plus_1 = builder.create(loc, cs0, f1); + + Value start_2 = builder.create(loc, start_loop, c2); + builder.create(loc, cs0_plus_1, res_raw, start_2); + Value start_2_plus_1 = builder.create(loc, start_2, c1); + builder.create(loc, cs1, res_raw, start_2_plus_1); + + Value n_minus_start = builder.create(loc, n, start_loop); + Value end_1 = builder.create(loc, l1, c0); + Value sum = builder.create(loc, start_loop, end_1); + Value condition = builder.create( + loc, arith::CmpIPredicate::sgt, sum, n); + Value end = builder.create(loc, condition, + n_minus_start, end_1); + + builder.create( + loc, c1, end, c1, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value i_2 = b.create(loc, i, c2); + Value csx0 = b.create(loc, res_raw, i_2); + Value i_2_plus_1 = b.create(loc, i_2, c1); + Value csx1 = b.create(loc, res_raw, i_2_plus_1); + + Value tmp1 = b.create(loc, cs0, csx0); + Value tmp2 = b.create(loc, cs1, csx1); + Value tmp3 = b.create(loc, tmp1, tmp2); + Value tmp4 = 
b.create(loc, tmp3, cs0); + Value tmp5 = b.create(loc, tmp4, csx0); + Value res_real = b.create(loc, tmp5, f1); + + Value tmp6 = b.create(loc, cs0, csx1); + Value tmp7 = b.create(loc, cs1, csx0); + Value tmp8 = b.create(loc, tmp6, tmp7); + Value tmp9 = b.create(loc, tmp8, cs1); + Value res_imag = b.create(loc, tmp9, csx1); + + Value start_plus_i = b.create(loc, start_loop, i); + Value start_plus_i_2 = + b.create(loc, start_plus_i, c2); + Value start_plus_i_2_plus_1 = + b.create(loc, start_plus_i_2, c1); + b.create(loc, res_real, res_raw, start_plus_i_2); + b.create(loc, res_imag, res_raw, + start_plus_i_2_plus_1); + b.create(loc, std::nullopt); + }); + + builder.create(loc, cs); + builder.create(loc, std::nullopt); + }); + + opBuilder.create( + loc, c1, l1, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange i_args) { + Value i_2 = builder.create(loc, i, c2); + Value val = builder.create(loc, res_raw, i_2); + Value val_plus_1 = builder.create(loc, val, f1); + builder.create(loc, val_plus_1, res_raw, i_2); + builder.create(loc, std::nullopt); + }); + + return; +} + +void calc_first_octant_extend1(OpBuilder &opBuilder, Location loc, Value den, + Value res, Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, bias); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value res_raw = opBuilder.create( + loc, 
resultType, res, SmallVector{bias}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + + opBuilder.create(loc, f1, res_raw, c0); + opBuilder.create(loc, f0, res_raw, c1); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, n, c1); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + calc_first_octant_extend2(builder, loc, den, res, bias); + builder.create(loc, std::nullopt); + }); +} + +void calc_first_octant(OpBuilder &opBuilder, Location loc, Value den, Value res, + Value bias) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value den_plus_4 = opBuilder.create(loc, den, c4); + Value n = opBuilder.create(loc, den_plus_4, c3); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, n, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + calc_first_octant_extend1(builder, loc, den, res, bias); + builder.create(loc, std::nullopt); + }); +} + +void calc_first_quadrant(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, n); + + FloatType f64Ty = opBuilder.getF64Type(); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value 
p_raw = opBuilder.create( + loc, resultType, res, SmallVector{n}, + SmallVector{remaining_size}, SmallVector{c1}); + + Value n_times_2 = opBuilder.create(loc, n, c1); + calc_first_octant(opBuilder, loc, n_times_2, res, n); + + Value n_plus_2 = opBuilder.create(loc, n, c2); + Value ndone = opBuilder.create(loc, n_plus_2, c2); + Value ndonem1 = opBuilder.create(loc, ndone, c1); + Value ndone2 = opBuilder.create(loc, ndone, c2); + Value idx2_start = opBuilder.create(loc, ndone2, c2); + + Value i_start = opBuilder.create(loc, 0); + Value idx1_start = opBuilder.create(loc, 0); + + auto loop = opBuilder.create( + loc, i_start, ndonem1, c2, ValueRange{i_start, idx1_start, idx2_start}, + [&](OpBuilder &builder, Location loc, Value i_loop, + ValueRange i_loop_args) { + Value i_loop1 = i_loop_args[0]; + Value idx1 = i_loop_args[1]; + Value idx2 = i_loop_args[2]; + + Value p_2i = builder.create(loc, i_loop1, c2); + Value p_val = builder.create(loc, p_raw, p_2i); + builder.create(loc, p_val, res, idx1); + + Value p_2i_plus_1 = builder.create(loc, p_2i, c1); + Value p_val_1 = builder.create(loc, p_raw, p_2i_plus_1); + Value idx1_plus_1 = builder.create(loc, idx1, c1); + builder.create(loc, p_val_1, res, idx1_plus_1); + + Value p_2i_plus_3 = builder.create(loc, p_2i, c3); + Value p_val_3 = builder.create(loc, p_raw, p_2i_plus_3); + builder.create(loc, p_val_3, res, idx2); + + Value p_2i_plus_2 = builder.create(loc, p_2i, c2); + Value p_val_2 = builder.create(loc, p_raw, p_2i_plus_2); + Value idx2_plus_1 = builder.create(loc, idx2, c1); + builder.create(loc, p_val_2, res, idx2_plus_1); + + Value i_loop1_next = builder.create(loc, i_loop1, c2); + Value idx1_next = builder.create(loc, idx1, c2); + Value idx2_next = builder.create(loc, idx2, c2); + builder.create( + loc, std::vector{i_loop1_next, idx1_next, idx2_next}); + }); + + Value i_v = loop.getResults()[0]; + Value idx1_v = loop.getResults()[1]; + Value idx2_v = loop.getResults()[2]; + + Value condition = opBuilder.create( + 
loc, arith::CmpIPredicate::ne, i_v, ndone); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value p_2i = builder.create(loc, i_v, c2); + Value p_val = builder.create(loc, p_raw, p_2i); + builder.create(loc, p_val, res, idx1_v); + + Value p_2i_plus_1 = builder.create(loc, p_2i, c1); + Value p_val_1 = builder.create(loc, p_raw, p_2i_plus_1); + Value idx1_plus_1 = builder.create(loc, idx1_v, c1); + builder.create(loc, p_val_1, res, idx1_plus_1); + builder.create(loc, std::nullopt); + }); + + return; +} + +void calc_first_half(OpBuilder &opBuilder, Location loc, Value n, Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + IndexType indexTy = opBuilder.getIndexType(); + FloatType f64Ty = opBuilder.getF64Type(); + + Value f0 = + opBuilder.create(loc, APFloat(double(0.0)), f64Ty); + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + + Value n_plus_1 = opBuilder.create(loc, n, c1); + Value ndone = opBuilder.create(loc, n_plus_1, c1); + + Value size = opBuilder.create(loc, res, c0); + Value remaining_size = opBuilder.create(loc, size, n); + Value remaining_size_p1 = + opBuilder.create(loc, remaining_size, c1); + + Value nm1 = opBuilder.create(loc, n, c1); + + FailureOr computelayout = + StridedLayoutAttr::get(opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value p_raw = opBuilder.create( + loc, resultType, res, SmallVector{nm1}, + SmallVector{remaining_size_p1}, + SmallVector{c1}); + + Value n_times_4 = opBuilder.create(loc, n, c2); + calc_first_octant(opBuilder, loc, n_times_4, res, nm1); + + Value i4_start = opBuilder.create(loc, 0); + Value i_start = opBuilder.create(loc, 0); + Value in = 
opBuilder.create(loc, n, c0); + + auto loop = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{i4_start, i_start}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_minus_i4 = builder.create(loc, in, i4); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4, in_minus_i4); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value i4_2 = builder.create(loc, i4, c2); + Value i_2 = builder.create(loc, i, c2); + Value i4_2_p1 = builder.create(loc, i4_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_i4_2 = builder.create(loc, p_raw, i4_2); + Value p_i4_2_p1 = builder.create(loc, p_raw, i4_2_p1); + + builder.create(loc, p_i4_2, res, i_2); + builder.create(loc, p_i4_2_p1, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_0 = loop.getResults()[0]; + Value final_i_0 = loop.getResults()[1]; + + auto loop1 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_0, final_i_0}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value i4_minus_in = builder.create(loc, i4, in); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4_minus_in, c0); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value xm = builder.create(loc, in, i4); + Value xm_2 = builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, 
xm_2); + + builder.create(loc, p_xm_2_p1, res, i_2); + builder.create(loc, p_xm_2, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_1 = loop1.getResults()[0]; + Value final_i_1 = loop1.getResults()[1]; + + auto loop2 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_1, final_i_1}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_3 = builder.create(loc, in, c3); + Value in_3_m_i4 = builder.create(loc, in_3, i4); + Value condition = builder.create( + loc, arith::CmpIPredicate::sle, i4, in_3_m_i4); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value xm = builder.create(loc, i4, in); + Value xm_2 = builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, xm_2); + + Value m_p_xm_2_p1 = builder.create(loc, f0, p_xm_2_p1); + + builder.create(loc, m_p_xm_2_p1, res, i_2); + builder.create(loc, p_xm_2, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + builder.create(loc, std::vector{i4_next, i_next}); + }); + + Value final_i4_2 = loop2.getResults()[0]; + Value final_i_2 = loop2.getResults()[1]; + + auto loop3 = opBuilder.create( + loc, TypeRange{indexTy, indexTy}, ValueRange{final_i4_2, final_i_2}, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value condition = builder.create( + loc, arith::CmpIPredicate::slt, i, ndone); + builder.create(loc, condition, ValueRange{i4, i}); + }, + [&](OpBuilder &builder, 
Location loc, ValueRange args) { + Value i4 = args[0]; + Value i = args[1]; + + Value in_2 = builder.create(loc, in, c2); + + Value xm = builder.create(loc, in_2, i4); + Value xm_2 = builder.create(loc, xm, c2); + Value i_2 = builder.create(loc, i, c2); + Value xm_2_p1 = builder.create(loc, xm_2, c1); + Value i_2_p1 = builder.create(loc, i_2, c1); + + Value p_xm_2_p1 = builder.create(loc, p_raw, xm_2_p1); + Value p_xm_2 = builder.create(loc, p_raw, xm_2); + + Value m_p_xm_2 = builder.create(loc, f0, p_xm_2); + + builder.create(loc, m_p_xm_2, res, i_2); + builder.create(loc, p_xm_2_p1, res, i_2_p1); + + Value i4_next = builder.create(loc, i4, c4); + Value i_next = builder.create(loc, i, c1); + + builder.create(loc, std::vector{i4_next, i_next}); + }); + + return; +} + +void fill_first_quadrant(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c8 = opBuilder.create(loc, 8); + + FloatType f64Ty = opBuilder.getF64Type(); + + Value hsqt2 = opBuilder.create( + loc, APFloat(double(0.707106781186547524400844362104849)), f64Ty); + + Value quart = opBuilder.create(loc, n, c2); + Value n_mod_8 = opBuilder.create(loc, n, c8); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, n_mod_8, c0); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value quart_plus_1 = builder.create(loc, quart, c1); + builder.create(loc, hsqt2, res, quart); + builder.create(loc, hsqt2, res, quart_plus_1); + builder.create(loc, std::nullopt); + }); + + Value two_quart = opBuilder.create(loc, quart, c2); + Value two_quart_minus_2 = opBuilder.create(loc, two_quart, c2); + + opBuilder.create( + loc, c2, quart, c2, ValueRange{two_quart_minus_2}, + [&](OpBuilder &builder, Location loc, Value i, ValueRange 
i_args) { + Value j = i_args[0]; + + Value i_plus_1 = builder.create(loc, i, c1); + Value j_plus_1 = builder.create(loc, j, c1); + + Value val_i = builder.create(loc, res, i); + Value val_i_plus_1 = builder.create(loc, res, i_plus_1); + + builder.create(loc, val_i_plus_1, res, j); + builder.create(loc, val_i, res, j_plus_1); + + Value j_next = builder.create(loc, j, c2); + builder.create(loc, j_next); + }); + + return; +} + +void fill_first_half(OpBuilder &opBuilder, Location loc, Value n, Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + FloatType f64Ty = opBuilder.getF64Type(); + Value c_1 = + opBuilder.create(loc, APFloat(double(-1.0)), f64Ty); + + Value half = opBuilder.create(loc, n, c1); + Value n_mod_4 = opBuilder.create(loc, n, c4); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, n_mod_4, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, c0, half, c2, std::nullopt, + [&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value i_plus_1 = b.create(loc, i, c1); + Value i_plus_half = b.create(loc, i, half); + Value i_plus_half_plus_1 = + b.create(loc, i_plus_half, c1); + + Value val_i = b.create(loc, res, i); + Value val_i_plus_1 = b.create(loc, res, i_plus_1); + + Value neg_val_i_plus_1 = + b.create(loc, val_i_plus_1, c_1); + b.create(loc, neg_val_i_plus_1, res, + i_plus_half); + b.create(loc, val_i, res, i_plus_half_plus_1); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }, + [&](OpBuilder &builder, Location loc) { + Value two_half_minus_2 = builder.create(loc, half, c1); + Value two_half_minus_2_mul_2 = + builder.create(loc, two_half_minus_2, c2); + + builder.create( + loc, c2, half, c2, ValueRange{two_half_minus_2_mul_2}, + 
[&](OpBuilder &b, Location loc, Value i, ValueRange i_args) { + Value j = i_args[0]; + Value i_plus_1 = builder.create(loc, i, c1); + Value j_plus_1 = builder.create(loc, j, c1); + Value val_i = b.create(loc, res, i); + Value val_i_plus_1 = b.create(loc, res, i_plus_1); + Value neg_val_i = b.create(loc, val_i, c_1); + b.create(loc, neg_val_i, res, j); + b.create(loc, val_i_plus_1, res, j_plus_1); + + Value j_next = builder.create(loc, j, c2); + b.create(loc, j_next); + }); + + builder.create(loc, std::nullopt); + }); + + return; +} + +void sincos_2pibyn_half(OpBuilder &opBuilder, Location loc, Value n, + Value res) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c50 = opBuilder.create(loc, 50); + + Value n_mod_4 = opBuilder.create(loc, n, c4); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, n_mod_4, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + calc_first_octant(builder, loc, n, res, c0); + + fill_first_quadrant(builder, loc, n, res); + fill_first_half(builder, loc, n, res); + builder.create(loc, std::nullopt); + }, + [&](OpBuilder &builder, Location loc) { + Value n_mod_2 = builder.create(loc, n, c2); + Value condition1 = builder.create( + loc, arith::CmpIPredicate::eq, n_mod_2, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &b, Location loc) { + calc_first_quadrant(b, loc, n, res); + fill_first_half(b, loc, n, res); + b.create(loc, std::nullopt); + }, + [&](OpBuilder &b, Location loc) { + calc_first_half(b, loc, n, res); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); +} + +// calcuate the twiddle factors for the input length +Value rfftp_comp_twiddle(OpBuilder &opBuilder, Location loc, Value length, + Value Rfftp_fctdata_fct, Value Rfftp_fctdata_tw, + Value 
Rfftp_fctdata_tws, Value Rfftp_plan_length, + Value Rfftp_plan_nfct, Value Rfftp_plan_mem) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c50 = opBuilder.create(loc, 50); + + Value length_2 = opBuilder.create(loc, length, c2); + FloatType f64Ty = opBuilder.getF64Type(); + + Value twid = opBuilder.create( + loc, MemRefType::get(ShapedType::kDynamic, f64Ty), + /*dynamicOperands=*/length_2); + + Value plan_nfct = opBuilder.create(loc, Rfftp_plan_nfct, c0); + + sincos_2pibyn_half(opBuilder, loc, length, twid); + + Value l1_start = opBuilder.create(loc, 1); + + opBuilder.create( + loc, c0, plan_nfct, c1, ValueRange{l1_start}, + [&](OpBuilder &builder, Location loc, Value k, ValueRange k_args) { + Value l1 = k_args[0]; + + Value ip = builder.create(loc, Rfftp_fctdata_fct, k); + + Value l1_m_ip = builder.create(loc, l1, ip); + Value ido = builder.create(loc, length, l1_m_ip); + Value plan_nfct_m1 = builder.create(loc, plan_nfct, c1); + + Value condition1 = builder.create( + loc, arith::CmpIPredicate::slt, k, plan_nfct_m1); + + builder.create( + loc, condition1, [&](OpBuilder &b, Location loc) { + Value ido_m1 = b.create(loc, ido, c1); + Value ido_m1_d2 = b.create(loc, ido_m1, c2); + Value ido_m1_d2_p1 = b.create(loc, ido_m1_d2, c1); + + b.create( + loc, c1, ip, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value j, ValueRange j_args) { + b2.create( + loc, c1, ido_m1_d2_p1, c1, std::nullopt, + [&](OpBuilder &b3, Location loc, Value i, + ValueRange i_args) { + Value j2 = b3.create(loc, j, c2); + Value j2_l1 = b3.create(loc, j2, l1); + Value j2_l1_i = + b3.create(loc, j2_l1, i); + Value j2_l1_i_p1 = + b3.create(loc, j2_l1_i, c1); + + Value j_m1 = b3.create(loc, j, c1); + Value ido_m1_j_m1 = + b3.create(loc, ido_m1, j_m1); + + Value i2 = b3.create(loc, i, c2); + Value i2_m1 = 
b3.create(loc, i2, c1); + Value i2_m2 = b3.create(loc, i2, c2); + + Value tw_a = + b3.create(loc, ido_m1_j_m1, i2_m2); + Value tw_b = + b3.create(loc, ido_m1_j_m1, i2_m1); + + Value twid_a = + b3.create(loc, twid, j2_l1_i); + Value twid_b = + b3.create(loc, twid, j2_l1_i_p1); + + Value fct_k = b3.create( + loc, Rfftp_fctdata_tw, k); + + b3.create(loc, twid_a, fct_k, tw_a); + b3.create(loc, twid_b, fct_k, tw_b); + + b3.create(loc, std::nullopt); + }); + b2.create(loc, std::nullopt); + }); + + b.create(loc, std::nullopt); + }); + + Value condition2 = builder.create( + loc, arith::CmpIPredicate::sgt, ip, c5); + + builder.create( + loc, condition2, [&](OpBuilder &b, Location loc) { + Value fct_k = b.create(loc, Rfftp_fctdata_tws, k); + Value c_f0 = + b.create(loc, APFloat(double(0.0)), f64Ty); + Value c_f1 = + b.create(loc, APFloat(double(1.0)), f64Ty); + + b.create(loc, c_f1, fct_k, c0); + b.create(loc, c_f0, fct_k, c1); + + Value ip_div_2 = b.create(loc, ip, c1); + Value ip_div_2_p1 = b.create(loc, ip_div_2, c1); + + b.create( + loc, c1, ip_div_2_p1, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value i, ValueRange i_args) { + Value i2 = b2.create(loc, i, c2); + Value i2_p1 = b2.create(loc, i2, c1); + Value ip_m_i = b2.create(loc, ip, i); + Value ip_m_i_2 = b2.create(loc, ip_m_i, c2); + Value ip_m_i_2_p1 = + b2.create(loc, ip_m_i_2, c1); + + Value length_div_ip = + b2.create(loc, length, ip); + Value i2_length_div_ip = + b2.create(loc, i2, length_div_ip); + Value i2_length_div_ip_p1 = + b2.create(loc, i2_length_div_ip, c1); + + Value twid_a = + b2.create(loc, twid, i2_length_div_ip); + Value twid_b = b2.create( + loc, twid, i2_length_div_ip_p1); + Value twid_c = b2.create(loc, c_f0, twid_a); + Value twid_d = b2.create(loc, c_f0, twid_b); + + b2.create(loc, twid_a, fct_k, i2); + b2.create(loc, twid_b, fct_k, i2_p1); + b2.create(loc, twid_c, fct_k, ip_m_i_2); + b2.create(loc, twid_d, fct_k, ip_m_i_2_p1); + b2.create(loc, std::nullopt); + }); + + b.create(loc, 
std::nullopt); + }); + + Value l1_next = builder.create(loc, l1, ip); + builder.create(loc, l1_next); + }); + + opBuilder.create(loc, twid); + + return c0; +} + +// calculate the twiddle factors and generates the computation order of +// butterfly operators +std::vector make_rfftp_plan(OpBuilder &opBuilder, Location loc, + Value length) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + int64_t NFCT_num = 25; + Value NFCT = opBuilder.create(loc, NFCT_num); + + FloatType f64Ty = opBuilder.getF64Type(); + IndexType indexTy = opBuilder.getIndexType(); + + Value length_2 = opBuilder.create(loc, length, c2); + + MemRefType type = MemRefType::get(NFCT_num, indexTy); + // MemRefType type1 = MemRefType::get(length_num2, f64Ty); + MemRefType type1 = MemRefType::get(ShapedType::kDynamic, f64Ty); + MemRefType type2 = MemRefType::get(NFCT_num, type1); + MemRefType type3 = MemRefType::get(1, indexTy); + MemRefType type4 = MemRefType::get(1, f64Ty); + + Value Rfftp_fctdata_fct = opBuilder.create(loc, type); + Value Rfftp_fctdata_tw = opBuilder.create(loc, type2); + Value Rfftp_fctdata_tws = opBuilder.create(loc, type2); + Value Rfftp_plan_length = opBuilder.create(loc, type3); + Value Rfftp_plan_nfct = opBuilder.create(loc, type3); + Value Rfftp_plan_mem = opBuilder.create(loc, type4); + + opBuilder.create(loc, length, Rfftp_plan_length, c0); + opBuilder.create(loc, c0, Rfftp_plan_nfct, c0); + + opBuilder.create( + loc, c0, NFCT, c1, std::nullopt, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iargs) { + builder.create(loc, c0, Rfftp_fctdata_fct, i); + + Value tw_i = builder.create( + loc, type1, /*dynamicOperands=*/length_2); + builder.create(loc, tw_i, Rfftp_fctdata_tw, i); + Value tws_i = builder.create( + loc, type1, /*dynamicOperands=*/length_2); + builder.create(loc, 
tws_i, Rfftp_fctdata_tws, i); + + builder.create(loc, std::nullopt); + }); + + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::ne, length, c1); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value xxx = builder.create(loc, 1); + rfftp_factorize(builder, loc, Rfftp_fctdata_fct, Rfftp_fctdata_tw, + Rfftp_fctdata_tws, Rfftp_plan_length, Rfftp_plan_nfct, + Rfftp_plan_mem); + rfftp_comp_twiddle(builder, loc, length, Rfftp_fctdata_fct, + Rfftp_fctdata_tw, Rfftp_fctdata_tws, + Rfftp_plan_length, Rfftp_plan_nfct, Rfftp_plan_mem); + builder.create(loc, std::nullopt); + }); + + return {Rfftp_fctdata_fct, Rfftp_fctdata_tw, Rfftp_fctdata_tws, + Rfftp_plan_length, Rfftp_plan_nfct, Rfftp_plan_mem}; +} + +void memref_SWAP(OpBuilder &opBuilder, Location loc, Value p, Value p1) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + + Value length = opBuilder.create(loc, p, c0); + + opBuilder.create( + loc, c0, length, c1, std::nullopt, + [&](OpBuilder builder, Location loc, Value i, ValueRange i_args) { + Value val_p = builder.create(loc, p, i); + Value val_p1 = builder.create(loc, p1, i); + + builder.create(loc, val_p, p1, i); + builder.create(loc, val_p1, p, i); + builder.create(loc, std::nullopt); + }); +} + +void flag_SWAP(OpBuilder &opBuilder, Location loc, Value flag) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + + Value val = opBuilder.create(loc, flag, c0); + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::eq, val, c0); + + Value x = opBuilder.create(loc, condition, c1, c0); + + opBuilder.create(loc, x, flag, c0); +} + +void copy_and_norm(OpBuilder &opBuilder, Location loc, Value c, Value p1, + Value n, Value fct, Value flag) { + Value c0 = opBuilder.create(loc, 0); + Value c1 = 
opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + FloatType f64Ty = opBuilder.getF64Type(); + Value f1 = + opBuilder.create(loc, APFloat(double(1.0)), f64Ty); + + Value flag_val = opBuilder.create(loc, flag, c0); + Value condition = opBuilder.create( + loc, arith::CmpIPredicate::eq, flag_val, c0); + + opBuilder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + Value condition1 = builder.create( + loc, arith::CmpFPredicate::ONE, fct, f1); + builder.create( + loc, condition1, + [&](OpBuilder &b, Location loc) { + b.create( + loc, c0, n, c1, std::nullopt, + [&](OpBuilder b2, Location loc, Value i, ValueRange i_args) { + Value p1_i = b2.create(loc, p1, i); + Value v = b2.create(loc, fct, p1_i); + b2.create(loc, v, c, i); + b2.create(loc, std::nullopt); + }); + b.create(loc, std::nullopt); + }, + [&](OpBuilder &b, Location loc) { + b.create( + loc, c0, n, c1, std::nullopt, + [&](OpBuilder b2, Location loc, Value i, ValueRange i_args) { + Value val = b2.create(loc, p1, i); + b2.create(loc, val, c, i); + b2.create(loc, std::nullopt); + }); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }, + [&](OpBuilder &builder, Location loc) { + Value condition2 = builder.create( + loc, arith::CmpFPredicate::ONE, fct, f1); + builder.create( + loc, condition2, [&](OpBuilder &b, Location loc) { + b.create( + loc, c0, n, c1, std::nullopt, + [&](OpBuilder &b2, Location loc, Value i, ValueRange i_args) { + Value c_i = b2.create(loc, c, i); + Value newC = b2.create(loc, fct, c_i); + b2.create(loc, newC, c, i); + b2.create(loc, std::nullopt); + }); + b.create(loc, std::nullopt); + }); + builder.create(loc, std::nullopt); + }); +} + +// FFT forward function for real number +void rfftp_forward(OpBuilder &opBuilder, Location loc, Value Rfftp_fctdata_fct, + Value Rfftp_fctdata_tw, Value Rfftp_fctdata_tws, + Value Rfftp_plan_length, Value Rfftp_plan_nfct, + Value Rfftp_plan_mem, Value c, Value 
fct) { + + Value c0 = opBuilder.create(loc, 0); + Value c1 = opBuilder.create(loc, 1); + Value c2 = opBuilder.create(loc, 2); + Value c3 = opBuilder.create(loc, 3); + Value c4 = opBuilder.create(loc, 4); + Value c5 = opBuilder.create(loc, 5); + Value c20 = opBuilder.create(loc, 20); + FloatType f64Ty = opBuilder.getF64Type(); + + Value n = opBuilder.create(loc, Rfftp_plan_length, c0); + + Value condition = + opBuilder.create(loc, arith::CmpIPredicate::ne, n, c1); + + opBuilder.create( + loc, condition, [&](OpBuilder &builder, Location loc) { + Value flag = builder.create( + loc, MemRefType::get(1, builder.getIndexType())); + builder.create(loc, c1, flag, c0); + Value l1_raw = builder.create(loc, n, c0); + Value nf = builder.create(loc, Rfftp_plan_nfct, c0); + + MemRefType cType = dyn_cast(c.getType()); + Value dimSize = builder.create(loc, c, 0); + Value ch = builder.create(loc, cType, + /*dynamicOperands=*/dimSize); + + // Value ch = builder.create( + // loc, MemRefType::get(cType.getShape(), f64Ty)); + + FailureOr computelayout = StridedLayoutAttr::get( + opBuilder.getContext(), + /*offset=*/ShapedType::kDynamic, /*strides=*/{1}); + MemRefType resultType = + MemRefType::get(ShapedType::kDynamic, f64Ty, *computelayout); + + // memref> + + Value p1_raw = builder.create( + loc, resultType, c, SmallVector{c0}, + SmallVector{n}, SmallVector{c1}); + + Value p2_raw = builder.create( + loc, resultType, ch, SmallVector{c0}, + SmallVector{n}, SmallVector{c1}); + + builder.create( + loc, c0, nf, c1, ValueRange{l1_raw}, + [&](OpBuilder b, Location loc, Value k1, ValueRange k1_args) { + Value l1_old = k1_args[0]; + + Value nf_m_k1 = b.create(loc, nf, k1); + Value k = b.create(loc, nf_m_k1, c1); + Value ip = b.create(loc, Rfftp_fctdata_fct, k); + Value ido = b.create(loc, n, l1_old); + Value l1 = b.create(loc, l1_old, ip); + + Value tw = b.create(loc, Rfftp_fctdata_tw, k); + + Value condition1 = b.create( + loc, arith::CmpIPredicate::eq, ip, c4); + + b.create( + loc, 
condition1, + [&](OpBuilder &b2, Location loc) { + radf4(b2, loc, p1_raw, p2_raw, tw, ido, l1, c0, c1, c2, c3); + b2.create(loc, std::nullopt); + }, + [&](OpBuilder &b2, Location loc) { + Value condition2 = b2.create( + loc, arith::CmpIPredicate::eq, ip, c2); + b2.create( + loc, condition2, + [&](OpBuilder &b3, Location loc) { + radf2(b3, loc, p1_raw, p2_raw, tw, ido, l1); + b3.create(loc, std::nullopt); + }, + [&](OpBuilder &b3, Location loc) { + Value condition3 = b3.create( + loc, arith::CmpIPredicate::eq, ip, c3); + b3.create( + loc, condition3, + [&](OpBuilder &b4, Location loc) { + radf3(b4, loc, p1_raw, p2_raw, tw, ido, l1); + b4.create(loc, std::nullopt); + }, + [&](OpBuilder &b4, Location loc) { + Value condition4 = b4.create( + loc, arith::CmpIPredicate::eq, ip, c5); + b4.create( + loc, condition4, + [&](OpBuilder &b5, Location loc) { + radf5(b5, loc, p1_raw, p2_raw, tw, ido, + l1, c0, c1, c2, c3, c4); + b5.create(loc, + std::nullopt); + }, + [&](OpBuilder &b5, Location loc) { + Value tws = b5.create( + loc, Rfftp_fctdata_tws, k); + radfg(b5, loc, p1_raw, p2_raw, tw, tws, + ido, ip, l1); + memref_SWAP(b5, loc, p1_raw, p2_raw); + flag_SWAP(b5, loc, flag); + b5.create(loc, + std::nullopt); + }); + b4.create(loc, std::nullopt); + }); + b3.create(loc, std::nullopt); + } + + ); + b2.create(loc, std::nullopt); + }); + + memref_SWAP(b, loc, p1_raw, p2_raw); + flag_SWAP(b, loc, flag); + + b.create(loc, l1); + }); + + copy_and_norm(builder, loc, c, p1_raw, n, fct, flag); + + builder.create(loc, std::nullopt); + }); +} + +// Calculate abspower of bufferMem and store result to a specific line in the +// resultMem +void absPower(OpBuilder &builder, Location loc, Value bufferMem, + Value resultMem, Value idx, Value c0, Value c1, Value c2) { + Value c200 = builder.create(loc, 200); + Value c398 = builder.create(loc, 398); + Value c399 = builder.create(loc, 399); + Value power = builder.create(loc, 2); + + Value firstNum = builder.create(loc, bufferMem, c0); + Value 
firstPow = builder.create(loc, firstNum, power); + builder.create(loc, firstPow, resultMem, + ValueRange{idx, c0}); + + Value lastNum = builder.create(loc, bufferMem, c399); + Value lastPow = builder.create(loc, lastNum, power); + builder.create(loc, lastPow, resultMem, + ValueRange{idx, c200}); + + builder.create( + loc, c1, c398, c2, ValueRange{c1}, + [&](OpBuilder &b, Location loc, Value iv, ValueRange iargs) { + Value j = b.create(loc, iv, c1); + Value num1 = b.create(loc, bufferMem, iv); + Value num2 = b.create(loc, bufferMem, j); + Value pow1 = b.create(loc, num1, power); + Value pow2 = b.create(loc, num2, power); + Value add = b.create(loc, pow1, pow2); + b.create(loc, add, resultMem, + ValueRange{idx, iargs[0]}); + + Value indexNext = b.create(loc, iargs[0], c1); + + b.create(loc, indexNext); + }); + + return; +} + +// Compute Log Mel Spectrogram +Value spectrogram(PatternRewriter &rewriter, Location loc, Value f0, Value c0, + Value c1, Value c2, Value c3, Value c4, Value c5, Value input, + Value window, Value melFilters) { + FloatType f64Ty = rewriter.getF64Type(); + + Value numFrames = rewriter.create(loc, 3001); + Value hopLength = rewriter.create(loc, 160); + Value c400 = rewriter.create(loc, 400); + + MemRefType spectrogramTy = MemRefType::get({3001, 201}, f64Ty); + Value spectrogram = rewriter.create(loc, spectrogramTy); + + RankedTensorType tensorTy0 = RankedTensorType::get({400}, f64Ty); + MemRefType mTp = MemRefType::get({400}, f64Ty); + + // #mulf_trait for 'linalg.generic' operation. 
+ AffineMap mulFIdMap = + AffineMap::getMultiDimIdentityMap(1, rewriter.getContext()); + SmallVector mulFIndexingMaps = {mulFIdMap, mulFIdMap, mulFIdMap}; + SmallVector mulFIteratorTypes = { + utils::IteratorType::parallel}; + + rewriter.create( + loc, c0, numFrames, c1, ValueRange{c0}, + [&](OpBuilder &builder, Location loc, Value iv, ValueRange iargs) { + auto extractSliceOp = rewriter.create( + loc, input, iargs[0], c400, c1); + Value buffer400 = extractSliceOp.getResult(); + Value buffer = + rewriter.create(loc, tensorTy0, buffer400); + + // 'linalg.generic' operation use #mulf_trait. + auto mulfOp = rewriter.create( + loc, /*resultTensorTypes=*/tensorTy0, + /*inputs=*/ValueRange{buffer, window}, + /*outputs=*/ValueRange{buffer}, mulFIndexingMaps, mulFIteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value elem = b.create(loc, args[0], args[1]); + b.create(loc, elem); + }); + Value multiplied = mulfOp.getResult(0); + + Value bufferMem_raw = + builder.create(loc, mTp, multiplied); + + MemRefType type0 = MemRefType::get({400}, f64Ty); + MemRefType type1 = MemRefType::get(ShapedType::kDynamic, f64Ty); + + Value bufferMem_rfft = + builder.create(loc, type1, bufferMem_raw); + + // Compute 'dap.rfft' operation, result stores in `bufferMem`. + builder.create(loc, bufferMem_rfft); + + Value bufferMem = + builder.create(loc, type0, bufferMem_rfft); + + // Store the result in a single line specified by `iv`. 
+ absPower(builder, loc, bufferMem, spectrogram, iv, c0, c1, c2); + + Value timestepNext = + builder.create(loc, iargs[0], hopLength); + + builder.create(loc, timestepNext); + }); + + // TODO: check alloc and dealloc + // MemRefType melFiltersTransposeTy = MemRefType::get({80, 201}, f64Ty); + // Value alloc0 = rewriter.create(loc, + // melFiltersTransposeTy); Value init0 = + // rewriter.create(loc, alloc0); + Value init0 = + rewriter.create(loc, ArrayRef{80, 201}, f64Ty); + auto transposeOp0 = rewriter.create( + loc, /*input=*/melFilters, + /*init=*/init0, + /*permutation=*/ArrayRef{1, 0}); + Value melFiltersT = transposeOp0.getResult()[0]; + + Value gram = rewriter.create( + loc, spectrogram, /*restrict=*/true, /*writable=*/false); + Value init1 = rewriter.create( + loc, ArrayRef{201, 3001}, f64Ty); + auto transposeOp1 = rewriter.create( + loc, /*input=*/gram, + /*init=*/init1, + /*permutation=*/ArrayRef{1, 0}); + Value spectrogramT = transposeOp1.getResult()[0]; + + rewriter.create(loc, spectrogram); + + Value init2 = + rewriter.create(loc, ArrayRef{80, 3001}, f64Ty); + auto matmulOp = rewriter.create( + loc, /*inputs=*/ValueRange{melFiltersT, spectrogramT}, + /*outputs=*/ValueRange{init2}); + Value matMulResult = matmulOp.getResultTensors()[0]; + + // Initialize a tensor with constant `1e-10`. + RankedTensorType tensorTy1 = RankedTensorType::get({80, 3001}, f64Ty); + Value cMelFloor = rewriter.create( + loc, APFloat(double(0.0000000001)), f64Ty); + Value melFloor = rewriter.create(loc, tensorTy1, cMelFloor); + + auto linalgMaxOp = rewriter.create( + loc, /*input=*/ValueRange{melFloor, matMulResult}, + /*outputs=*/ValueRange{melFloor}); + Value spectrogramMax = linalgMaxOp.getResultTensors()[0]; + + // #log10_trait for 'linalg.generic' operation. 
+ AffineMap log10IdMap = + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); + SmallVector log10IndexingMaps = {log10IdMap, log10IdMap}; + SmallVector log10IteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::parallel}; + + // 'linalg.generic' operation use #log10_trait. + auto log10Op = rewriter.create( + loc, /*resultTensorTypes=*/tensorTy1, + /*inputs=*/ValueRange{spectrogramMax}, + /*outputs=*/ValueRange{spectrogramMax}, log10IndexingMaps, + log10IteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + Value elem = b.create(loc, args[0]); + b.create(loc, elem); + }); + Value spectrogramLog10 = log10Op.getResult(0); + + return spectrogramLog10; +} + +namespace { +class DAPRFFTLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit DAPRFFTLowering(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(dap::RFFTOp op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + Value bufferMem = op->getOperand(0); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + Value c4 = rewriter.create(loc, 4); + Value c5 = rewriter.create(loc, 5); + Value c9 = rewriter.create(loc, 9); + Value c24 = rewriter.create(loc, 24); + Value c25 = rewriter.create(loc, 25); + Value c50 = rewriter.create(loc, 50); + + Value inputFeatures = rewriter.create( + loc, bufferMem, /*restrict=*/true, /*writable=*/true); + Value inputFeaturesSize = + rewriter.create(loc, inputFeatures, c0); + + FloatType f64Ty = rewriter.getF64Type(); + + Value f0 = + rewriter.create(loc, APFloat(double(0.0)), f64Ty); + Value f1 = + rewriter.create(loc, APFloat(double(1.0)), f64Ty); + + std::vector plan = make_rfftp_plan(rewriter, loc, inputFeaturesSize); + + Value Rfftp_fctdata_fct = plan[0]; + Value Rfftp_fctdata_tw = plan[1]; + Value 
Rfftp_fctdata_tws = plan[2]; + Value Rfftp_plan_length = plan[3]; + Value Rfftp_plan_nfct = plan[4]; + Value Rfftp_plan_mem = plan[5]; + + rfftp_forward(rewriter, loc, Rfftp_fctdata_fct, Rfftp_fctdata_tw, + Rfftp_fctdata_tws, Rfftp_plan_length, Rfftp_plan_nfct, + Rfftp_plan_mem, bufferMem, f1); + + rewriter.eraseOp(op); + return success(); + } +}; + +class DAPWhisperPreprocessLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit DAPWhisperPreprocessLowering(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(dap::WhisperPreprocessOp op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + Value input = op->getOperand(0); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + Value c4 = rewriter.create(loc, 4); + Value c5 = rewriter.create(loc, 5); + Value c80 = rewriter.create(loc, 80); + Value c3000 = rewriter.create(loc, 3000); + Value c480000 = rewriter.create(loc, 480000); + + FloatType f32 = FloatType::getF32(ctx); + FloatType f64 = FloatType::getF64(ctx); + + Value inputFeatures = rewriter.create( + loc, input, /*restrict=*/true, /*writable=*/false); + Value inputFeaturesSize = + rewriter.create(loc, inputFeatures, c0); + Value padConstantHigh = + rewriter.create(loc, c480000, inputFeaturesSize); + + // Pad inputFeatures to MaxLength = 480000 + SmallVector paddedShape; + paddedShape.push_back(480000); + + SmallVector lowValues; + SmallVector highValues; + lowValues.push_back(c0); + highValues.push_back(padConstantHigh); + + Value f0 = + rewriter.create(loc, APFloat(double(0.0)), f64); + auto padConstantOp = rewriter.create( + loc, RankedTensorType::get(paddedShape, f64), inputFeatures, lowValues, + highValues, f0); + Value paddedInput = padConstantOp.getResult(); + + // Generate melFilter with 391 numbers + Value 
melFilter = initMelFilter(rewriter, loc, c0, c1, f0); + + // Generate hanning window with length 400 + Value window = getHanningWindow400(rewriter, loc); + + // Reflect pad for paddedInput, both left and right part pad with length 200 + Value finalPaddedInput = + padReflect(rewriter, loc, c0, c1, paddedInput, 200, 200); + Value logSpec = spectrogram(rewriter, loc, f0, c0, c1, c2, c3, c4, c5, + finalPaddedInput, window, melFilter); + + auto extractSliceOp = rewriter.create( + loc, /*source=*/logSpec, + /*offsets=*/ValueRange{c0, c0}, + /*sizes=*/ValueRange{c80, c3000}, + /*strides=*/ValueRange{c1, c1}); + Value logSpecCut = extractSliceOp.getResult(); + + Value maxInit = + rewriter.create(loc, APFloat(double(-10.0)), f64); + auto forOp0 = rewriter.create( + loc, c0, c80, c1, maxInit, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iargs0) { + auto forOp1 = builder.create( + loc, c0, c3000, c1, iargs0[0], + [&](OpBuilder &b, Location loc, Value j, ValueRange iargs1) { + Value elem = b.create(loc, logSpecCut, + ValueRange{i, j}); + Value larger = + b.create(loc, elem, iargs1[0]); + b.create(loc, larger); + }); + + Value maxNext = forOp1.getResults()[0]; + builder.create(loc, maxNext); + }); + Value maxNum = forOp0.getResults()[0]; + + Value f8 = rewriter.create(loc, APFloat(double(8.0)), f64); + Value maxNumMinus8 = rewriter.create(loc, maxNum, f8); + Value logSpecFloor = rewriter.create( + loc, RankedTensorType::get({80, 3000}, f64), maxNumMinus8); + + auto linalgMaxOp = rewriter.create( + loc, /*input=*/ValueRange{logSpecCut, logSpecFloor}, + /*outputs=*/ValueRange{logSpecFloor}); + Value logSpecMax = linalgMaxOp.getResultTensors()[0]; + + Value f0F32 = + rewriter.create(loc, APFloat(float(0.0)), f32); + Value f4 = rewriter.create(loc, APFloat(double(4.0)), f64); + RankedTensorType resultTy = RankedTensorType::get({80, 3000}, f32); + Value InputFeaturesF32 = + rewriter.create(loc, resultTy, f0F32); + + // #tail_processing_trait for 'linalg.generic' 
operation. + AffineMap IdMap = + AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); + SmallVector IndexingMaps = {IdMap, IdMap}; + SmallVector IteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::parallel}; + + // 'linalg.generic' operation use #tail_processing_trait. + auto tailProcessOp = rewriter.create( + loc, /*resultTensorTypes=*/resultTy, + /*inputs=*/ValueRange{logSpecMax}, + /*outputs=*/ValueRange{InputFeaturesF32}, IndexingMaps, IteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value add4 = b.create(loc, args[0], f4); + Value div4 = b.create(loc, add4, f4); + Value elem = b.create(loc, f32, div4); + b.create(loc, elem); + }); + Value result = tailProcessOp.getResult(0); + + // Compute reassociation indices [[0, 1], 2] + SmallVector> reassociationIndices( + resultTy.getRank()); + int64_t index = 0; + for (index = 0; index <= 1; index++) { + reassociationIndices[0].push_back(index); + } + reassociationIndices[1].push_back(index); + + RankedTensorType expandTy = RankedTensorType::get({1, 80, 3000}, f32); + + Value resultExpand = rewriter.create( + loc, /*resultType=*/expandTy, /*src=*/result, + /*reassociation=*/reassociationIndices); + + auto resultMemTp = + MemRefType::get(expandTy.getShape(), expandTy.getElementType()); + Value resultMemRef = rewriter.create( + loc, resultMemTp, resultExpand); + + // Replace 'dap.whisper_preprocess' operation with the generated result. The + // replaced op is erased. 
+ rewriter.replaceOp(op, resultMemRef); + return success(); + } +}; + +} // end anonymous namespace + +void populateExtendDAPConversionPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); + // TODO : extract operators +} + +//===----------------------------------------------------------------------===// +// ExtendDAPPass +//===----------------------------------------------------------------------===// + +namespace { +class ExtendDAPPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExtendDAPPass) + ExtendDAPPass() = default; + ExtendDAPPass(const ExtendDAPPass &) {} + + StringRef getArgument() const final { return "extend-dap"; } + StringRef getDescription() const final { return "Extend DAP Dialect."; } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + // Buddy Compiler designed dialect + registry.insert(); + } +}; +} // end anonymous namespace. + +void ExtendDAPPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + // Add legal dialects. + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + // Add legal operations. 
+ target.addLegalOp(); + + RewritePatternSet patterns(context); + populateExtendDAPConversionPatterns(patterns); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerExtendDAPPass() { PassRegistration(); } +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp b/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp index 66b6d69ad2..6118ecc1c3 100644 --- a/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp +++ b/midend/lib/Conversion/LowerDIP/LowerDIPPass.cpp @@ -360,6 +360,286 @@ class DIPResize2DOpLowering : public OpRewritePattern { int64_t stride; }; +class DIPResize4D_NHWCOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit DIPResize4D_NHWCOpLowering(MLIRContext *context, int64_t strideParam) + : OpRewritePattern(context) { + stride = strideParam; + } + + LogicalResult matchAndRewrite(dip::Resize4D_NHWCOp op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + + // Register operand values. + Value input = op->getOperand(0); + Value horizontalScalingFactor = op->getOperand(1); + Value verticalScalingFactor = op->getOperand(2); + Value output = op->getOperand(3); + auto interpolationAttr = op.getInterpolationType(); + Value strideVal = rewriter.create(loc, stride); + + auto inElemTy = input.getType().cast().getElementType(); + dip::DIP_ERROR error = + dip::checkDIPCommonTypes(op, {input, output}); + + if (error == dip::DIP_ERROR::INCONSISTENT_TYPES) { + return op->emitOpError() + << "input, and output must have the same element type"; + } else if (error == dip::DIP_ERROR::UNSUPPORTED_TYPE) { + return op->emitOpError() << "supports only f32, f64 and integer types. 
" + << inElemTy << "is passed"; + } + + // true: NHWC, false: NCHW + Value dataCondition = rewriter.create( + loc, rewriter.getI1Type(), rewriter.getBoolAttr(true)); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + + Value c0F32 = indexToF32(rewriter, loc, c0); + + // Value inputBatch = rewriter.create(loc, input, c0); + Value inputRow = rewriter.create(loc, input, c1); + Value inputCol = rewriter.create(loc, input, c2); + Value inputColor = rewriter.create(loc, input, c3); + + Value outputBatch = rewriter.create(loc, output, c0); + Value outputRow = rewriter.create(loc, output, c1); + Value outputCol = rewriter.create(loc, output, c2); + Value outputColor = rewriter.create(loc, output, c3); + + // Determine lower bound for second call of resize function (this is done + // for efficient tail processing). + Value outputColStrideRatio = + rewriter.create(loc, outputCol, strideVal); + Value outputColMultiple = + rewriter.create(loc, strideVal, outputColStrideRatio); + + SmallVector lowerBounds1{c0, c0, c0, c0}; + SmallVector upperBounds1{outputBatch, outputColor, outputRow, + outputColMultiple}; + + SmallVector steps{1, 1, 1, stride}; + Value strideTailVal = + rewriter.create(loc, outputCol, outputColMultiple); + + SmallVector lowerBounds2{c0, c0, c0, outputColMultiple}; + SmallVector upperBounds2{outputBatch, outputColor, outputRow, + outputCol}; + + FloatType f32 = FloatType::getF32(ctx); + VectorType vectorTy32 = VectorType::get({stride}, f32); + + Value horizontalScalingFactorVec = rewriter.create( + loc, vectorTy32, horizontalScalingFactor); + Value verticalScalingFactorVec = rewriter.create( + loc, vectorTy32, verticalScalingFactor); + + // Obtain extreme allocatable value(s) in input and output for bounding + // purpose. 
+ Value inputRowLastElem = rewriter.create(loc, inputRow, c1); + Value inputRowLastElemF32 = indexToF32(rewriter, loc, inputRowLastElem); + + Value inputColLastElem = rewriter.create(loc, inputCol, c1); + Value inputColLastElemF32 = indexToF32(rewriter, loc, inputColLastElem); + + Value outputRowLastElem = + rewriter.create(loc, outputRow, c1); + Value outputRowLastElemF32 = indexToF32(rewriter, loc, outputRowLastElem); + + Value outputColLastElem = + rewriter.create(loc, outputCol, c1); + Value outputColLastElemF32 = indexToF32(rewriter, loc, outputColLastElem); + + if (interpolationAttr == + dip::InterpolationType::NearestNeighbourInterpolation) { + dip::NearestNeighbourInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds1, upperBounds1, steps, strideVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, dataCondition); + + dip::NearestNeighbourInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds2, upperBounds2, steps, strideTailVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, dataCondition); + } else if (interpolationAttr == + dip::InterpolationType::BilinearInterpolation) { + Value c1F32 = indexToF32(rewriter, loc, c1); + + dip::BilinearInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds1, upperBounds1, steps, strideVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32, + dataCondition); + + dip::BilinearInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds2, upperBounds2, steps, strideTailVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, 
outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32, + dataCondition); + } + + // Remove the original resize operation. + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t stride; +}; + +class DIPResize4D_NCHWOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + explicit DIPResize4D_NCHWOpLowering(MLIRContext *context, int64_t strideParam) + : OpRewritePattern(context) { + stride = strideParam; + } + + LogicalResult matchAndRewrite(dip::Resize4D_NCHWOp op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + + // Register operand values. + Value input = op->getOperand(0); + Value horizontalScalingFactor = op->getOperand(1); + Value verticalScalingFactor = op->getOperand(2); + Value output = op->getOperand(3); + auto interpolationAttr = op.getInterpolationType(); + Value strideVal = rewriter.create(loc, stride); + + auto inElemTy = input.getType().cast().getElementType(); + dip::DIP_ERROR error = + dip::checkDIPCommonTypes(op, {input, output}); + + if (error == dip::DIP_ERROR::INCONSISTENT_TYPES) { + return op->emitOpError() + << "input, and output must have the same element type"; + } else if (error == dip::DIP_ERROR::UNSUPPORTED_TYPE) { + return op->emitOpError() << "supports only f32, f64 and integer types. 
" + << inElemTy << "is passed"; + } + + // true: NHWC, false: NCHW + Value dataCondition = rewriter.create( + loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + + Value c0 = rewriter.create(loc, 0); + Value c1 = rewriter.create(loc, 1); + Value c2 = rewriter.create(loc, 2); + Value c3 = rewriter.create(loc, 3); + + Value c0F32 = indexToF32(rewriter, loc, c0); + + // Value inputBatch = rewriter.create(loc, input, c0); + Value inputColor = rewriter.create(loc, input, c1); + Value inputRow = rewriter.create(loc, input, c2); + Value inputCol = rewriter.create(loc, input, c3); + + Value outputBatch = rewriter.create(loc, output, c0); + Value outputColor = rewriter.create(loc, output, c1); + Value outputRow = rewriter.create(loc, output, c2); + Value outputCol = rewriter.create(loc, output, c3); + + // Determine lower bound for second call of resize function (this is done + // for efficient tail processing). + Value outputColStrideRatio = + rewriter.create(loc, outputCol, strideVal); + Value outputColMultiple = + rewriter.create(loc, strideVal, outputColStrideRatio); + + SmallVector lowerBounds1{c0, c0, c0, c0}; + SmallVector upperBounds1{outputBatch, outputColor, outputRow, + outputColMultiple}; + + SmallVector steps{1, 1, 1, stride}; + Value strideTailVal = + rewriter.create(loc, outputCol, outputColMultiple); + + SmallVector lowerBounds2{c0, c0, c0, outputColMultiple}; + SmallVector upperBounds2{outputBatch, outputColor, outputRow, + outputCol}; + + FloatType f32 = FloatType::getF32(ctx); + VectorType vectorTy32 = VectorType::get({stride}, f32); + + Value horizontalScalingFactorVec = rewriter.create( + loc, vectorTy32, horizontalScalingFactor); + Value verticalScalingFactorVec = rewriter.create( + loc, vectorTy32, verticalScalingFactor); + + // Obtain extreme allocatable value(s) in input and output for bounding + // purpose. 
+ Value inputRowLastElem = rewriter.create(loc, inputRow, c1); + Value inputRowLastElemF32 = indexToF32(rewriter, loc, inputRowLastElem); + + Value inputColLastElem = rewriter.create(loc, inputCol, c1); + Value inputColLastElemF32 = indexToF32(rewriter, loc, inputColLastElem); + + Value outputRowLastElem = + rewriter.create(loc, outputRow, c1); + Value outputRowLastElemF32 = indexToF32(rewriter, loc, outputRowLastElem); + + Value outputColLastElem = + rewriter.create(loc, outputCol, c1); + Value outputColLastElemF32 = indexToF32(rewriter, loc, outputColLastElem); + + if (interpolationAttr == + dip::InterpolationType::NearestNeighbourInterpolation) { + dip::NearestNeighbourInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds1, upperBounds1, steps, strideVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, dataCondition); + + dip::NearestNeighbourInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds2, upperBounds2, steps, strideTailVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, dataCondition); + } else if (interpolationAttr == + dip::InterpolationType::BilinearInterpolation) { + Value c1F32 = indexToF32(rewriter, loc, c1); + + dip::BilinearInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds1, upperBounds1, steps, strideVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32, + dataCondition); + + dip::BilinearInterpolationResizing4D( + rewriter, loc, ctx, lowerBounds2, upperBounds2, steps, strideTailVal, + input, output, horizontalScalingFactorVec, verticalScalingFactorVec, + outputRowLastElemF32, 
outputColLastElemF32, inputRowLastElemF32, + inputColLastElemF32, vectorTy32, stride, c0, c0F32, c1F32, + dataCondition); + } + + // Remove the original resize operation. + rewriter.eraseOp(op); + return success(); + } + +private: + int64_t stride; +}; + class DIPErosion2DOpLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1308,6 +1588,8 @@ void populateLowerDIPConversionPatterns(RewritePatternSet &patterns, patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); + patterns.add(patterns.getContext(), stride); + patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); patterns.add(patterns.getContext(), stride); diff --git a/midend/lib/Conversion/MLIRGPU/CMakeLists.txt b/midend/lib/Conversion/MLIRGPU/CMakeLists.txt new file mode 100644 index 0000000000..be7148357f --- /dev/null +++ b/midend/lib/Conversion/MLIRGPU/CMakeLists.txt @@ -0,0 +1,28 @@ +add_mlir_library(MLIRGPUPasses + ConvertMemcpyToGPU.cpp + LegalizeShmemOutlining.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRBufferizationDialect + MLIRControlFlowInterfaces + MLIRFuncDialect + MLIRFunctionInterfaces + MLIRInferTypeOpInterface + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRTensorDialect + MLIRSCFDialect + MLIRSideEffectInterfaces + MLIRSubsetOpInterface + MLIRTransforms + MLIRViewLikeInterface + MLIRSupport + BuddyUtils + MLIRBufferizationTransforms + MLIRGPUDialect +) diff --git a/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp b/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp new file mode 100644 index 0000000000..dd50feccf8 --- /dev/null +++ b/midend/lib/Conversion/MLIRGPU/ConvertMemcpyToGPU.cpp @@ -0,0 +1,263 @@ +//===- ConvertMemcpyToGPU.cpp ---------------------------------------------===// +// +// Licensed under 
the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the pass that converts memcpy to gpu operations. +// +//===---------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// ConvertMemcpyToGPUPass +//===----------------------------------------------------------------------===// + +namespace { + +class ConvertMemcpyToGPUPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertMemcpyToGPUPass) + StringRef getArgument() const final { return "convert-memcpy-to-gpu"; } + StringRef getDescription() const final { + return "Convert 
memref opertaions to gpu operations."; + } + ConvertMemcpyToGPUPass() = default; + ConvertMemcpyToGPUPass(const ConvertMemcpyToGPUPass &) {} + + Option processArgs{ + *this, "process-args", + llvm::cl::desc("Whether the pass processes the input args."), + llvm::cl::init(true)}; + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } +}; + +void ConvertMemcpyToGPUPass::runOnOperation() { + auto funcOp = getOperation(); + + // Make sure the gpu function is already outlined. + funcOp->walk([&](Operation *nestedOp) { + if (auto gpuLaunchOp = dyn_cast(nestedOp)) { + nestedOp->emitOpError("The gpu function should be outlined."); + } + return WalkResult::advance(); + }); + + std::set unDeallocatedOperations; + OpBuilder builder(funcOp->getContext()); + // Copy all function arguments to gpu, needs deallocation + if (processArgs) { + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + unsigned numArgs = funcOp.getNumArguments(); + for (unsigned i = 0; i < numArgs; ++i) { + BlockArgument arg = funcOp.getArgument(i); + // Create a gpu.alloc op, then copy memory to it + // TODO: Move this out of operation, make the copy process async + auto memrefType = dyn_cast(arg.getType()); + auto gpuAllocOp = builder.create( + builder.getUnknownLoc(), TypeRange({memrefType}), ValueRange({})); + unDeallocatedOperations.insert(&gpuAllocOp); + auto gpuMemcpyOp = builder.create( + gpuAllocOp.getLoc(), TypeRange(), ValueRange(), + gpuAllocOp.getResult(0), arg); + // Replace all users with GPU memory + auto users = arg.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + // Don't replace memcpy's operand + if (isa(user)) + continue; + for (size_t j = 0; j < user->getNumOperands(); j++) { + if (user->getOperand(j) == arg) { + user->setOperand(j, gpuAllocOp.getResult(0)); + } + } + } + } + } + + funcOp->walk([&](Operation *nestedOp) { + // Replace all allocations 
with GPU.alloc + if (auto allocOp = dyn_cast(nestedOp)) { + // Rewrite this allocOp to gpu.alloc, change for all users + builder.setInsertionPointAfter(allocOp); + auto result = allocOp->getResult(0); + auto memrefType = dyn_cast(result.getType()); + auto memorySpace = memrefType.getMemorySpace(); + + // Filter operations. + if (memorySpace) { + if (auto intMemorySpace = llvm::dyn_cast(memorySpace)) { + if (intMemorySpace.getInt() != 0) { + return WalkResult::advance(); + } + } else if (auto gpuMemorySpace = + llvm::dyn_cast(memorySpace)) { + if (gpuMemorySpace.getValue() != gpu::AddressSpace::Global) { + return WalkResult::advance(); + } + } else + return WalkResult::advance(); + } + + auto gpuAllocOp = builder.create( + allocOp->getLoc(), TypeRange({memrefType}), ValueRange({})); + auto users = result.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + for (size_t j = 0; j < user->getNumOperands(); j++) { + // Only the return value will not have dealloc op + if (auto deallocOp = dyn_cast(user)) { + builder.setInsertionPointAfter(deallocOp); + auto gpuDeallocOp = builder.create( + deallocOp->getLoc(), TypeRange(), ValueRange(), + gpuAllocOp.getResult(0)); + deallocOp->erase(); + } else if (user->getOperand(j) == result) { + user->setOperand(j, gpuAllocOp.getResult(0)); + } + } + } + allocOp->erase(); + } + // Replace all memory.copy operations with gpu.memcpy + else if (auto copyOp = dyn_cast(nestedOp)) { + auto src = copyOp.getOperand(0); + auto dst = copyOp.getOperand(1); + // Notice: GPU.memcpy has a different src dst order + builder.setInsertionPointAfter(copyOp); + auto gpuMemcpyOp = builder.create( + copyOp->getLoc(), TypeRange(), ValueRange(), dst, src); + { + auto users = src.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + for (size_t j = 0; j < user->getNumOperands(); j++) { + if (user->getOperand(j) == src) { + user->setOperand(j, gpuMemcpyOp.getOperand(1)); + } 
+ } + } + } + { + auto users = dst.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + for (size_t j = 0; j < user->getNumOperands(); j++) { + if (user->getOperand(j) == src) { + user->setOperand(j, gpuMemcpyOp.getOperand(0)); + } + } + } + } + copyOp->erase(); + } + // Allocate space on GPU and copy global memrefs to GPU, needs deallocation + else if (auto getGlobalOp = dyn_cast(nestedOp)) { + builder.setInsertionPointAfter(getGlobalOp); + auto result = getGlobalOp->getResult(0); + auto memrefType = dyn_cast(result.getType()); + auto gpuAllocOp = builder.create( + getGlobalOp->getLoc(), TypeRange({memrefType}), ValueRange({})); + unDeallocatedOperations.insert(&gpuAllocOp); + auto src = result; + auto dst = gpuAllocOp->getResult(0); + auto gpuMemcpyOp = builder.create( + gpuAllocOp->getLoc(), TypeRange(), ValueRange(), dst, src); + { + auto users = src.getUsers(); + std::vector usersVec(users.begin(), users.end()); + for (auto user : usersVec) { + if (isa(user)) + continue; + // TODO: replace with src.replaceAllUsesExcept() + for (size_t j = 0; j < user->getNumOperands(); j++) { + if (user->getOperand(j) == src) { + user->setOperand(j, dst); + } + } + } + } + } + // Copy data back to CPU, deallocate GPU, then return + else if (auto returnOp = dyn_cast(nestedOp)) { + builder.setInsertionPoint(returnOp); + + for (auto *gpuAllocOp : unDeallocatedOperations) { + auto gpuDeallocOp = builder.create( + builder.getUnknownLoc(), TypeRange(), ValueRange(), + gpuAllocOp->getResult(0)); + } + builder.setInsertionPoint(returnOp); + for (unsigned i = 0; i < returnOp.getNumOperands(); ++i) { + auto val = returnOp->getOperand(i); + auto memRefType = dyn_cast(val.getType()); + auto allocOp = builder.create(builder.getUnknownLoc(), + memRefType); + auto gpuMemcpyOp = builder.create( + allocOp.getLoc(), TypeRange(), ValueRange(), allocOp->getResult(0), + val); + auto gpuDeallocOp = builder.create( + gpuMemcpyOp->getLoc(), TypeRange(), 
ValueRange(), val); + returnOp->setOperand(i, allocOp->getResult(0)); + } + } + return WalkResult::advance(); + }); +} +} // end anonymous namespace. + +namespace mlir { +namespace buddy { +void registerConvertMemcpyToGPUPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MLIRGPU/LegalizeShmemOutlining.cpp b/midend/lib/Conversion/MLIRGPU/LegalizeShmemOutlining.cpp new file mode 100644 index 0000000000..79638d4603 --- /dev/null +++ b/midend/lib/Conversion/MLIRGPU/LegalizeShmemOutlining.cpp @@ -0,0 +1,433 @@ +//===- LegalizeShmemOutlining.cpp -----------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the pass that legalizes shared memory operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +using namespace mlir; +using namespace vector; + +//===---------------------------------------------------------------------===// +// From mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +//===---------------------------------------------------------------------===// + +namespace mlir { +#define GEN_PASS_DEF_GPULAUNCHSINKINDEXCOMPUTATIONS +#define GEN_PASS_DEF_GPUKERNELOUTLINING +#include "mlir/Dialect/GPU/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +template +static void createForAllDimensions(OpBuilder &builder, Location loc, + SmallVectorImpl &values) { + for (auto dim : {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z}) + values.push_back(builder.create(loc, 
builder.getIndexType(), dim)); +} + +/// Adds operations generating block/thread ids and grid/block dimensions at the +/// beginning of the `launchFuncOpBody` region. Add mapping from argument in +/// entry block of `launchOpBody`, to the corresponding result value of the +/// added operations. +static void injectGpuIndexOperations(Location loc, Region &launchFuncOpBody, + Region &launchOpBody, IRMapping &map) { + OpBuilder builder(loc->getContext()); + Block &firstBlock = launchOpBody.front(); + builder.setInsertionPointToStart(&launchFuncOpBody.front()); + SmallVector indexOps; + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + // Replace the leading 12 function args with the respective thread/block index + // operations. Iterate backwards since args are erased and indices change. + for (const auto &indexOp : enumerate(indexOps)) + map.map(firstBlock.getArgument(indexOp.index()), indexOp.value()); +} + +/// Return the provided KernelDim3 as an array of i32 constants if possible. +static DenseI32ArrayAttr maybeConstantDimsAttr(gpu::KernelDim3 dims) { + SmallVector constants; + MLIRContext *ctx = dims.x.getContext(); + for (Value v : {dims.x, dims.y, dims.z}) { + APInt constValue; + if (!matchPattern(v, m_ConstantInt(&constValue))) + return nullptr; + // In the event someone called for a too-large block or grid dimension, + // don't set bounds as it is likely to cause more confusing behavior. + if (constValue.ugt(std::numeric_limits::max())) + return nullptr; + constants.push_back( + constValue.getLimitedValue(std::numeric_limits::max())); + } + return DenseI32ArrayAttr::get(ctx, constants); +} + +/// Outline the `gpu.launch` operation body into a kernel function. Replace +/// `gpu.terminator` operations by `gpu.return` in the generated function. +/// Set block and grid size bounds if known. 
+static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp, + StringRef kernelFnName, + SetVector &operands) { + Location loc = launchOp.getLoc(); + // Create a builder with no insertion point, insertion will happen separately + // due to symbol table manipulation. + OpBuilder builder(launchOp.getContext()); + Region &launchOpBody = launchOp.getBody(); + + // Identify uses from values defined outside of the scope of the launch + // operation. + getUsedValuesDefinedAbove(launchOpBody, operands); + + // Create the gpu.func operation. + SmallVector kernelOperandTypes; + kernelOperandTypes.reserve(operands.size()); + for (Value operand : operands) { + kernelOperandTypes.push_back(operand.getType()); + } + FunctionType type = + FunctionType::get(launchOp.getContext(), kernelOperandTypes, {}); + auto outlinedFunc = builder.create( + loc, kernelFnName, type, + TypeRange(ValueRange(launchOp.getWorkgroupAttributions())), + TypeRange(ValueRange(launchOp.getPrivateAttributions()))); + outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + + // If we can infer bounds on the grid and/or block sizes from the arguments + // to the launch op, propagate them to the generated kernel. This is safe + // because multiple launches with the same body are not deduplicated. + if (auto blockBounds = + maybeConstantDimsAttr(launchOp.getBlockSizeOperandValues())) + outlinedFunc->setAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName(), + blockBounds); + if (auto gridBounds = + maybeConstantDimsAttr(launchOp.getGridSizeOperandValues())) + outlinedFunc->setAttr(gpu::GPUFuncOp::getKnownGridSizeAttrName(), + gridBounds); + + IRMapping map; + + // Map the arguments corresponding to the launch parameters like blockIdx, + // threadIdx, etc. + Region &outlinedFuncBody = outlinedFunc.getBody(); + injectGpuIndexOperations(loc, outlinedFuncBody, launchOpBody, map); + + // Map memory attributions from the LaunOp op to the GPUFuncOp attributions. 
+ for (const auto &[launchArg, funcArg] : + llvm::zip(launchOp.getWorkgroupAttributions(), + outlinedFunc.getWorkgroupAttributions())) + map.map(launchArg, funcArg); + for (const auto &[launchArg, funcArg] : + llvm::zip(launchOp.getPrivateAttributions(), + outlinedFunc.getPrivateAttributions())) + map.map(launchArg, funcArg); + + // Map arguments from gpu.launch region to the arguments of the gpu.func + // operation. + Block &entryBlock = outlinedFuncBody.front(); + for (const auto &operand : enumerate(operands)) + map.map(operand.value(), entryBlock.getArgument(operand.index())); + + // Clone the region of the gpu.launch operation into the gpu.func operation. + // TODO: If cloneInto can be modified such that if a mapping for + // a block exists, that block will be used to clone operations into (at the + // end of the block), instead of creating a new block, this would be much + // cleaner. + launchOpBody.cloneInto(&outlinedFuncBody, map); + + // Branch from entry of the gpu.func operation to the block that is cloned + // from the entry block of the gpu.launch operation. + Block &launchOpEntry = launchOpBody.front(); + Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry); + builder.setInsertionPointToEnd(&entryBlock); + builder.create(loc, clonedLaunchOpEntry); + + outlinedFunc.walk([](gpu::TerminatorOp op) { + OpBuilder replacer(op); + replacer.create(op.getLoc()); + op.erase(); + }); + return outlinedFunc; +} + +/// Replace `gpu.launch` operations with an `gpu.launch_func` operation +/// launching `kernelFunc`. The kernel func contains the body of the +/// `gpu.launch` with constant region arguments inlined. +static void convertToLaunchFuncOp(gpu::LaunchOp launchOp, + gpu::GPUFuncOp kernelFunc, + ValueRange operands) { + OpBuilder builder(launchOp); + // The launch op has an optional dynamic shared memory size. If it doesn't + // exist, we use zero. 
+ Value asyncToken = launchOp.getAsyncToken(); + auto launchFunc = builder.create( + launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(), + launchOp.getBlockSizeOperandValues(), + launchOp.getDynamicSharedMemorySize(), operands, + asyncToken ? asyncToken.getType() : nullptr, + launchOp.getAsyncDependencies()); + launchOp.replaceAllUsesWith(launchFunc); + launchOp.erase(); +} + +/// Pass that moves the kernel of each LaunchOp into its separate nested module. +/// +/// This pass moves the kernel code of each LaunchOp into a function created +/// inside a nested module. It also creates an external function of the same +/// name in the parent module. +/// +/// The gpu.modules are intended to be compiled to a cubin blob independently in +/// a separate pass. The external functions can then be annotated with the +/// symbol of the cubin accessor function. + +namespace { +class LegalizeShmemOutliningPass + : public PassWrapper> { +public: + std::vector shmemAllocations; + std::map shmemGlobalPairs; + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeShmemOutliningPass) + StringRef getArgument() const final { return "legalize-shmem-outlining"; } + StringRef getDescription() const final { + return "Convert shared memory outlining to global memref declaration."; + } + + void runOnOperation() override { + SymbolTable symbolTable(getOperation()); + + bool modified = false; + for (auto func : getOperation().getOps()) { + // Insert just after the function. + Block::iterator insertPt(func->getNextNode()); + + // Collects all allocations for shared memory outside the kernel. + // The collection must happen before the kernel outlining. + // It moves back all shared allocations back into their GPU body + // Allowing the functions to create kernels without shared memory + // as parameters. 
+ func.walk([&](memref::AllocOp allocOp) { + auto result = allocOp->getResult(0); + auto memrefType = dyn_cast(result.getType()); + auto memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + return WalkResult::advance(); + else { + if (auto intMemorySpace = llvm::dyn_cast(memorySpace)) { + if (intMemorySpace.getInt() != 3) { + return WalkResult::advance(); + } + } else if (auto gpuMemorySpace = + llvm::dyn_cast(memorySpace)) { + if (gpuMemorySpace.getValue() != gpu::AddressSpace::Workgroup) { + return WalkResult::advance(); + } + } else + return WalkResult::advance(); + } + auto users = allocOp->getUsers(); + for (auto user : users) { + if (isa(user)) { + user->erase(); + continue; + } + // Locates the gpu kernel wrapper + auto launchOp = user->getParentOfType(); + OpBuilder builder(launchOp); + builder.setInsertionPointToStart( + &launchOp.getBody().getBlocks().front()); + auto newAllocOp = + builder.create(launchOp.getLoc(), memrefType); + allocOp->replaceAllUsesWith(newAllocOp); + allocOp->erase(); + break; + } + return WalkResult::advance(); + }); + + auto funcWalkResult = func.walk([&](gpu::LaunchOp op) { + SetVector operands; + std::string kernelFnName = + Twine(op->getParentOfType().getName(), "_kernel") + .str(); + + gpu::GPUFuncOp outlinedFunc = + outlineKernelFuncImpl(op, kernelFnName, operands); + + // Create nested module and insert outlinedFunc. The module will + // originally get the same name as the function, but may be renamed on + // insertion into the parent module. 
+ auto kernelModule = createKernelModule(outlinedFunc, symbolTable); + symbolTable.insert(kernelModule, insertPt); + + size_t counter = 0; + // Walk the funcop and replace all shmem allocations with global memref + outlinedFunc->walk([&](memref::AllocOp allocOp) { + auto result = allocOp->getResult(0); + auto memrefType = dyn_cast(result.getType()); + auto memorySpace = memrefType.getMemorySpace(); + if (!memorySpace) + allocOp->emitOpError() + << "Found non-shared memory inside a kernel function"; + else { + if (auto intMemorySpace = + llvm::dyn_cast(memorySpace)) { + if (intMemorySpace.getInt() != 3) { + return WalkResult::advance(); + } + } else if (auto gpuMemorySpace = + llvm::dyn_cast(memorySpace)) { + if (gpuMemorySpace.getValue() != gpu::AddressSpace::Workgroup) { + return WalkResult::advance(); + } + } else + return WalkResult::advance(); + } + + OpBuilder builder(outlinedFunc); + + auto name = Twine("shmem_", std::to_string(counter++)).str(); + + auto globalOp = builder.create( + kernelModule->getLoc(), + /*sym_name=*/name, + /*sym_visibility=*/builder.getStringAttr("private"), + /*type=*/memrefType, + /*initial_value=*/ElementsAttr(), + /*constant=*/false, + /*alignment=*/builder.getI64IntegerAttr(64)); + // symbolTable.insert(globalOp); + builder.setInsertionPointAfter(allocOp); + Value getGlobalOp = builder.create( + allocOp->getLoc(), globalOp.getType(), name); + allocOp.replaceAllUsesWith(getGlobalOp); + allocOp->erase(); + return WalkResult::advance(); + }); + + // Potentially changes signature, pulling in constants. + convertToLaunchFuncOp(op, outlinedFunc, operands.getArrayRef()); + modified = true; + return WalkResult::advance(); + }); + if (funcWalkResult.wasInterrupted()) + return signalPassFailure(); + } + + // If any new module was inserted in this module, annotate this module as + // a container module. 
+ if (modified) + getOperation()->setAttr(gpu::GPUDialect::getContainerModuleAttrName(), + UnitAttr::get(&getContext())); + } + +private: + /// Returns a gpu.module containing kernelFunc and all callees (recursive). + gpu::GPUModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc, + const SymbolTable &parentSymbolTable) { + // TODO: This code cannot use an OpBuilder because it must be inserted into + // a SymbolTable by the caller. SymbolTable needs to be refactored to + // prevent manual building of Ops with symbols in code using SymbolTables + // and then this needs to use the OpBuilder. + auto *context = getOperation().getContext(); + OpBuilder builder(context); + auto kernelModule = builder.create(kernelFunc.getLoc(), + kernelFunc.getName()); + + SymbolTable symbolTable(kernelModule); + symbolTable.insert(kernelFunc); + + SmallVector symbolDefWorklist = {kernelFunc}; + while (!symbolDefWorklist.empty()) { + if (std::optional symbolUses = + SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + StringRef symbolName = + cast(symbolUse.getSymbolRef()).getValue(); + if (symbolTable.lookup(symbolName)) + continue; + + Operation *symbolDefClone = + parentSymbolTable.lookup(symbolName)->clone(); + symbolDefWorklist.push_back(symbolDefClone); + symbolTable.insert(symbolDefClone); + } + } + } + + return kernelModule; + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// LegalizeShmemOutliningPass +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace buddy { +void registerLegalizeShmemOutliningPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp index 757ac8ae91..6cedaa1655 100644 --- 
a/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulOptimize.cpp @@ -155,7 +155,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern { IntegerSet::get( 1, 1, {d0 * -affineVectorSize + s0 - affineVectorSize}, {false}), - ValueRange{loopVarBatchIdx, bCol}, true); + ValueRange{loopVarColOfB, bCol}, true); // Branch handling full vector operations. OpBuilder trueBranchBuilder = branchingOp.getThenBodyBuilder(); diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp new file mode 100644 index 0000000000..a3d079be22 --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulSCFOptimize.cpp @@ -0,0 +1,281 @@ +//===- BatchMatMulOptimize.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the batchmatmul scf vectorization optimization. 
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; +using namespace affine; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class BatchMatMuSCFOptimizePattern : public ConversionPattern { +private: + int64_t vecSize; + +public: + explicit BatchMatMuSCFOptimizePattern(MLIRContext *context, + int64_t vecSizeParam) + : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Retrieve input tensors A, B, and C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Acquire the element type of input tensors. + Type elementType = A.getType().cast().getElementType(); + + // Define constants. 
+ const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value cVecSize = + rewriter.create(loc, rewriter.getIndexAttr(vecSize)); + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr d2 = rewriter.getAffineDimExpr(2); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); + + const Value zeroElementType = rewriter.create( + loc, rewriter.getZeroAttr(elementType)); + + // Get dimensions of input tensors. + Value batch = rewriter.create(loc, A, 0); + Value aRow = rewriter.create(loc, A, 1); + Value bCol = rewriter.create(loc, B, 2); + Value bRow = rewriter.create(loc, B, 1); + + VectorType vecTy = VectorType::get({vecSize}, elementType); + Value zeroElementTypeVec; + if (isa(elementType)) + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + else + zeroElementTypeVec = + rewriter.create(loc, vecTy, zeroElementType); + // Calculate the length of the tail, which might not fit in a + // vector. + Value tailLength = rewriter.create( + loc, AffineMap::get(1, 0, d0 % vecSize), ValueRange{bCol}); + + // Generate a mask vector based on the tail length. 
+ Value maskVector = rewriter.create( + loc, VectorType::get({vecSize}, rewriter.getI1Type()), + ValueRange{tailLength}); + + Value ApplyBCol = rewriter.create( + loc, AffineMap::get(1, 0, d0.floorDiv(vecSize) * vecSize), bCol); + + rewriter.create( + loc, SmallVector({c0}), + SmallVector({batch}), + SmallVector({c1}), ValueRange{}, + std::nullopt, // No mapping specified in this example + [&](OpBuilder &builder, Location loc, ValueRange loopIndices) { + Value loopVarBatchIdx = loopIndices[0]; + builder.create( + loc, c0, aRow, c1, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value loopVarRowOfA, + ValueRange iargs) { + builder.create( + loc, c0, bRow, c1, ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, Value loopVarRowOfB, + ValueRange iargs) { + Value aElement = builder.create( + loc, A, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarRowOfB}); + Value aVec = builder.create( + loc, vecTy, aElement); + builder.create( + loc, c0, ApplyBCol, cVecSize, + ValueRange{std::nullopt}, + [&](OpBuilder &builder, Location loc, + Value loopVarColOfB, ValueRange iargs) { + Value bVec = builder.create( + loc, vecTy, B, + ValueRange{loopVarBatchIdx, loopVarRowOfB, + loopVarColOfB}); + + Value cVec = builder.create( + loc, vecTy, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarColOfB}); + Value computedVec; + + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, bVec); + computedVec = builder.create( + loc, mulVec, cVec); + } else { + computedVec = builder.create( + loc, aVec, bVec, cVec); + } + builder.create( + loc, computedVec, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + loopVarColOfB}); + builder.create( + loc, ValueRange{std::nullopt}); + }); + Value condition = builder.create( + loc, arith::CmpIPredicate::sgt, tailLength, c0); + builder.create( + loc, condition, + [&](OpBuilder &builder, Location loc) { + Value bVec = builder.create( + loc, vecTy, B, + ValueRange{loopVarBatchIdx, loopVarRowOfB, + 
ApplyBCol}, + maskVector, zeroElementTypeVec); + + Value cVec = builder.create( + loc, vecTy, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + ApplyBCol}, + maskVector, zeroElementTypeVec); + + Value computedVec; + + if (isa(elementType)) { + Value mulVec = builder.create( + loc, aVec, bVec); + computedVec = builder.create( + loc, mulVec, cVec); + } else { + computedVec = builder.create( + loc, aVec, bVec, cVec); + } + + builder.create( + loc, C, + ValueRange{loopVarBatchIdx, loopVarRowOfA, + ApplyBCol}, + maskVector, computedVec); + builder.create(loc); + }); + builder.create(loc, + ValueRange{std::nullopt}); + }); + builder.create(loc, ValueRange{std::nullopt}); + }); + + builder.create(loc); + }); + + rewriter.eraseOp(op); + return success(); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// BatchMatMuSCFOptimize +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg pooling operations to mixture of +/// Affine + Vector operations. +namespace { +class BatchMatMuSCFOptimize + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BatchMatMuSCFOptimize) + StringRef getArgument() const final { return "batchmatmul-scf-optimize"; } + StringRef getDescription() const final { + return "BatchMatMul SCF Optimization."; + } + BatchMatMuSCFOptimize() = default; + BatchMatMuSCFOptimize(const BatchMatMuSCFOptimize &) {} + explicit BatchMatMuSCFOptimize(int64_t vecSizeParam) { + vecSize = vecSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vector-size", + llvm::cl::desc("Strip mining size."), + llvm::cl::init(16)}; +}; +} // end anonymous namespace. 
+ +void BatchMatMuSCFOptimize::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} +// add to buddy-opt.cpp +namespace mlir { +namespace buddy { +void registerBatchMatMuSCFOptimize() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp b/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp new file mode 100644 index 0000000000..91d10c6456 --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/BatchMatMulTileOptimize.cpp @@ -0,0 +1,353 @@ +//===- BatchMatMulOptimize.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the batchmatmul tile optimization. 
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/ValueRange.h" +#include "llvm/ADT/ArrayRef.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace vector; +using namespace affine; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { + +class BatchMatMulTileOptimizePattern : public ConversionPattern { +private: + int64_t vecSize, kernelM, kernelN; + +public: + explicit BatchMatMulTileOptimizePattern(MLIRContext *context, + int64_t vecSizeParam, + int64_t kernelMParam, + int64_t kernelNParam) + : ConversionPattern(linalg::BatchMatmulOp::getOperationName(), 1, + context) { + vecSize = vecSizeParam; + kernelM = kernelMParam; + kernelN = kernelNParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // Retrieve input tensors A, B, and C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Acquire the element type of input tensors. + Type elementType = A.getType().cast().getElementType(); + ShapedType ATy = A.getType().cast(); + + // Define constants. 
+ const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + + const AffineExpr d0 = rewriter.getAffineDimExpr(0); + const AffineExpr d1 = rewriter.getAffineDimExpr(1); + const AffineExpr d2 = rewriter.getAffineDimExpr(2); + const AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + const AffineExpr s1 = rewriter.getAffineSymbolExpr(1); + const AffineExpr s2 = rewriter.getAffineSymbolExpr(2); + + const AffineExpr zeroAffine = rewriter.getAffineConstantExpr(0); + + // Get dimensions of input tensors. + Value batch = rewriter.create(loc, A, 0); + Value M = rewriter.create(loc, A, 1); // aRow + Value K = rewriter.create(loc, B, 1); // bRow + Value N = rewriter.create(loc, B, 2); // bCol + + SmallVector reducedValues = llvm::to_vector<4>( + llvm::map_range(ArrayRef{}, + [](const LoopReduction &red) { return red.value; })); + + // Configs + int64_t kNLen = vecSize * kernelN; + + // Create the primary parallel batch level loop. + AffineParallelOp parallelBatchLoop = + rewriter.create( + loc, ValueRange(reducedValues).getTypes(), ValueRange{batch}, + ArrayRef{ + rewriter.getNamedAttr("lowerBoundsGroups", + rewriter.getI32TensorAttr({1})), + rewriter.getNamedAttr("upperBoundsGroups", + rewriter.getI32TensorAttr({1})), + rewriter.getNamedAttr( + "lowerBoundsMap", + AffineMapAttr::get(AffineMap::get(0, 0, {zeroAffine}, + rewriter.getContext()))), + rewriter.getNamedAttr("upperBoundsMap", + AffineMapAttr::get(AffineMap::get( + 1, 0, {d0}, rewriter.getContext()))), + rewriter.getNamedAttr("reductions", rewriter.getArrayAttr({})), + rewriter.getNamedAttr("steps", rewriter.getI64ArrayAttr({1}))}); + + // Create the loop body for the parallel loop. + Block *loopBody = new Block(); + rewriter.setInsertionPointToStart(loopBody); + loopBody->addArgument(rewriter.getIndexType(), loc); + Value loopVarBatchIdx = loopBody->getArguments()[0]; + + // Prefetching data from tensor 'A' for better cache utilization. 
+ rewriter.create( + loc, A, AffineMap::get(3, 0, {d0, d1, d2}, rewriter.getContext()), + ArrayRef{loopVarBatchIdx, M, K}, false, 3, true); + + // build loop body + affine::buildAffineLoopNest( + rewriter, loc, {c0}, {N}, kNLen, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + auto ivJ = ivRange.front(); + affine::buildAffineLoopNest( + builder, loc, {c0}, {M}, kernelM, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + Value ivI = ivRange.front(); + SmallVector cptrs; + + const VectorType vTy = + VectorType::get(vecSize, ATy.getElementType()); + + for (int i = 0; i < kernelM; i++) { + Value fixedIV = builder.create( + loc, + AffineMap::get(1, 1, {d0 + i, s0 - 1}, + builder.getContext()), + SmallVector{ivI, M}); + MemRefType resTy = MemRefType::get( + ATy.getShape(), ATy.getElementType(), + AffineMap::get(3, 3, d1 * s2 + d0 * s1 + s0 + d2)); + auto cptr = builder.create( + loc, resTy, C, + SmallVector{loopVarBatchIdx, fixedIV, c0}, + SmallVector{c1, c1, N}, + SmallVector{c1, c1, c1}); + cptrs.push_back(cptr); + } + affine::buildAffineLoopNest( + builder, loc, {c0}, {K}, 1, + [&](OpBuilder &builder, Location loc, ValueRange ivRange) { + Value ivK = ivRange.front(); + SmallVector bs; + + for (int j = 0; j < kernelN; j++) { + Value fixedJV = ivJ; + if (j != 0) { + fixedJV = builder.create( + loc, AffineMap::get(1, 0, d0 + j * vecSize), ivJ); + } + bs.push_back(builder.create( + loc, vTy, B, + ValueRange{loopVarBatchIdx, ivK, fixedJV})); + } + + for (int i = 0; i < kernelM; ++i) { + Value fixedIV = ivI; + if (i != 0) { + fixedIV = builder.create( + loc, + AffineMap::get(1, 0, {d0 + i}, + builder.getContext()), + SmallVector{ivI}); + } + affine::AffineIfOp mBranchingOp = + builder.create( + loc, + IntegerSet::get(1, 1, {-d0 + s0 - 1}, {false}), + ValueRange{fixedIV, M}, false); + OpBuilder mTrueBranchBuilder = + mBranchingOp.getThenBodyBuilder(); + Value ksubAElement = + mTrueBranchBuilder.create( + loc, A, + ValueRange{loopVarBatchIdx, 
fixedIV, ivK}); + + for (int j = 0; j < kernelN; j++) { + Value fixedJV = ivJ; + if (j != 0) { + fixedJV = + mTrueBranchBuilder + .create( + loc, + AffineMap::get(1, 0, d0 + j * vecSize), + ivJ); + } + Value vecC = mTrueBranchBuilder.create( + loc, vTy, cptrs[i], ValueRange{c0, c0, fixedJV}); + if (isa(elementType)) { + Value vecA = + mTrueBranchBuilder.create( + loc, vTy, ksubAElement); + Value vecMul = + mTrueBranchBuilder.create( + loc, vTy, vecA, bs[j]); + vecC = mTrueBranchBuilder.create( + loc, vTy, vecMul, vecC); + } else { + Value vecA = + mTrueBranchBuilder.create( + loc, vTy, ksubAElement); + vecC = mTrueBranchBuilder.create( + loc, vTy, vecA, bs[j], vecC); + } + // store vecC + Value tailLength = + mTrueBranchBuilder.create( + loc, AffineMap::get(2, 0, -d0 + d1), + ValueRange{fixedJV, N}); + affine::AffineIfOp nBranchingOp = + mTrueBranchBuilder.create( + loc, + IntegerSet::get(1, 0, {-vecSize + d0}, + {false}), + ValueRange{tailLength}, true); + // Calculate the length of the tail, which might not + // fit in a vector. + OpBuilder nTrueBranchBuilder = + nBranchingOp.getThenBodyBuilder(); + nTrueBranchBuilder.create( + loc, vecC, cptrs[i], ValueRange{c0, c0, fixedJV}); + OpBuilder nFalseBranchBuilder = + nBranchingOp.getElseBodyBuilder(); + // Generate a mask vector based on the tail length. + Value maskVector = + nFalseBranchBuilder.create( + loc, + VectorType::get({vecSize}, + rewriter.getI1Type()), + ValueRange{tailLength}); + nFalseBranchBuilder.create( + loc, cptrs[i], ValueRange{c0, c0, fixedJV}, + maskVector, vecC); + } + } + }); + }); + }); + + rewriter.create(loc); + + // Finalize the loop and erase the original operation. 
+ parallelBatchLoop.getRegion().push_back(loopBody); + rewriter.setInsertionPointAfter(parallelBatchLoop); + + rewriter.eraseOp(op); + return success(); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// BatchMatMulTileOptimizePass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering linalg pooling operations to mixture of +/// Affine + Vector operations. +namespace { +class BatchMatMulTileOptimizePass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BatchMatMulTileOptimizePass) + StringRef getArgument() const final { return "batchmatmul-tile-optimize"; } + StringRef getDescription() const final { + return "BatchMatMul Tile Optimization."; + } + BatchMatMulTileOptimizePass() = default; + BatchMatMulTileOptimizePass(const BatchMatMulTileOptimizePass &) {} + explicit BatchMatMulTileOptimizePass(int64_t vecSizeParam, + int64_t kernelMParam, + int64_t kernelNParam) { + vecSize = vecSizeParam; + kernelM = kernelMParam; + kernelN = kernelNParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vecSize{*this, "vec-size", + llvm::cl::desc("Strip mining size."), + llvm::cl::init(16)}; + + Option kernelM{*this, "kernel-m", + llvm::cl::desc("Strip mining size."), + llvm::cl::init(4)}; + + Option kernelN{*this, "kernel-n", + llvm::cl::desc("Strip mining size."), + llvm::cl::init(2)}; +}; +} // end anonymous namespace. 
+ +void BatchMatMulTileOptimizePass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context, vecSize, kernelM, + kernelN); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} +// add to buddy-opt.cpp +namespace mlir { +namespace buddy { +void registerBatchMatMulTileOptimizePass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt index 8e726863eb..7ec2cf4ac4 100644 --- a/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt +++ b/midend/lib/Conversion/MatMulOptimization/CMakeLists.txt @@ -1,8 +1,14 @@ add_mlir_library(MatMulOptimization - BatchMatMulOptimize.cpp MatMulOptimize.cpp MatMulVectorization.cpp MatMulParallelVectorization.cpp + BatchMatMulOptimize.cpp + BatchMatMulTileOptimize.cpp + BatchMatMulSCFOptimize.cpp + MatMulTransposeBVec.cpp + BatchMatMulOptimize.cpp + BatchMatMulTileOptimize.cpp + BatchMatMulSCFOptimize.cpp LINK_LIBS PUBLIC BuddyUtils ) @@ -14,3 +20,7 @@ add_mlir_library(BatchMatMulOptimization add_mlir_library(MatMulParallelVectorization MatMulParallelVectorization.cpp ) + +add_mlir_library(MatMulTransposeBVec + MatMulTransposeBVec.cpp +) diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulParallelVectorization.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulParallelVectorization.cpp index d10c80e3ac..23d0ef4e7b 100644 --- a/midend/lib/Conversion/MatMulOptimization/MatMulParallelVectorization.cpp +++ b/midend/lib/Conversion/MatMulOptimization/MatMulParallelVectorization.cpp @@ -14,7 +14,7 @@ // //===----------------------------------------------------------------------===// // -// This file implements the 
matmul-paralell-vectorization optimization. +// This file implements the matmul-parallel-vectorization optimization. // //===----------------------------------------------------------------------===// @@ -318,7 +318,7 @@ class MatMulParallelVectorizationPass public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulParallelVectorizationPass) StringRef getArgument() const final { - return "matmul-paralell-vectorization-optimize"; + return "matmul-parallel-vectorization-optimize"; } StringRef getDescription() const final { return "MatMulParallelVectorization Optimization."; diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp new file mode 100644 index 0000000000..4500119d76 --- /dev/null +++ b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp @@ -0,0 +1,214 @@ +//===- MatMulTransposeBVec.cpp --------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Matmul_TransposeB vectorization. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Utils/Utils.h" + +using namespace mlir; +using namespace vector; + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +namespace { +class MatMulTransposeBVecPattern : public ConversionPattern{ +public: + explicit MatMulTransposeBVecPattern(MLIRContext *context,int64_t vecSizeparam) + : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(),1,context){ + vecSize = vecSizeparam; + } + + LogicalResult + matchAndRewrite(Operation *op,ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override{ + auto loc = op->getLoc(); + auto ctx = op->getContext(); + // Get input A, B, C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Get shape of input and output. + ShapedType ATy = A.getType().cast(); + Type eleTy = ATy.getElementType(); + + // the element type for mask vector. 
+ IntegerType i1 = IntegerType::get(ctx, 1); + + VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); + VectorType vectorMaskTy = VectorType::get({vecSize}, i1); + + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value step = rewriter.create(loc, vecSize); + + const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); + Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); + + const Value aRow = rewriter.create(loc, A, c0); + const Value bRow = rewriter.create(loc, B, c0); + const Value bCol = rewriter.create(loc, B, c1); + + AffineExpr d0; + bindDims(ctx, d0); + AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); + SmallVector lowerBounds(2, c0); + SmallVector uperBounds{aRow, bRow}; + SmallVector steps(2, 1); + // clang-format off + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create loop based on vector size. + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{bCol}, vecTailMap, 1, std::nullopt, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + AffineExpr a,b,c; + bindDims(ctx, a,b,c); + AffineMap AVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); + // Check tail. + AffineExpr m, n, k; + bindDims(ctx, m, n, k); + AffineMap BVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); + + // Calculate the tail. + Value bColCur = builder.create(loc, iv, step); + Value tailLen = builder.create(loc, bCol, bColCur); + Value tailFlag = rewriter.create( + loc, arith::CmpIPredicate::sge, tailLen, step); + // If the current column does not reach the tail. 
+ builder.create(loc, tailFlag, + [&](OpBuilder &builder, Location loc) { + Value aVec = builder.create( + loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); + Value bVec = builder.create( + loc, vectorTy, B, BVectorMap, ValueRange{ivs[1], ivs[1], iv}); + Value resvec = builder.create(loc,aVec,bVec); + Value res1 = builder.create( + loc,mlir::vector::CombiningKind::ADD,resvec); + Value res2 = builder.create( + loc, C, ValueRange{ivs[0], ivs[1]}); + Value sum = builder.create(loc, res1, res2); + builder.create(loc, sum, + C, ValueRange{ivs[0], ivs[1]}); + builder.create(loc); + }, + // The else branch + [&](OpBuilder &builder, Location loc) { + Value aVec = builder.create( + loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); + // Create mask according to the tail. + Value maskVec = builder.create( + loc, vectorMaskTy, tailLen); + Value ColIdxTail = builder.create(loc, iv, step); + + Value aVecTail = builder.create( + loc, vectorTy, A, ValueRange{ivs[0], ColIdxTail}, + maskVec, passthruVec); + + Value bVecTail = builder.create( + loc, vectorTy, B, ValueRange{ivs[1], ColIdxTail}, + maskVec, passthruVec); + + Value resvec = builder.create(loc,aVecTail,bVecTail); + Value res1 = builder.create( + loc,mlir::vector::CombiningKind::ADD,resvec); + Value res2 = builder.create( + loc, C, ValueRange{ivs[0], ivs[1]}); + Value sum = builder.create(loc, res1, res2); + builder.create(loc, sum, C, ValueRange{ivs[0], ivs[1]}); + builder.create(loc); + }); + builder.create(loc); + }); + }); + // clang-format on + rewriter.eraseOp(op); + return success(); + } +private: + int64_t vecSize; +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// MatMulVectorizationPass +//===----------------------------------------------------------------------===// + +namespace{ + class MatMulTransposeBVecPass + :public PassWrapper>{ +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulTransposeBVecPass) + 
StringRef getArgument() const final{ return "matmul-transpose-b-vectorization"; } + StringRef getDescription() const final { return "vectorize linalg MatmulTransposeBOp"; } + MatMulTransposeBVecPass() = default; + MatMulTransposeBVecPass(const MatMulTransposeBVecPass &) {} + void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override{ + registry.insert(); + } + Option vecSize{*this,"vec-size", + llvm::cl::desc("The size of vectorization"), + llvm::cl::init(32)}; + +}; +} + +void MatMulTransposeBVecPass::runOnOperation(){ + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(context); + patterns.add(context,vecSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} + +namespace mlir { +namespace buddy { +void registerMatMulTransposeBVecPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir diff --git a/midend/lib/Dialect/CMakeLists.txt b/midend/lib/Dialect/CMakeLists.txt index 8ab8f29f58..afedee5d69 100644 --- a/midend/lib/Dialect/CMakeLists.txt +++ b/midend/lib/Dialect/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(RVV) add_subdirectory(VectorExp) add_subdirectory(Gemmini) add_subdirectory(Sche) +add_subdirectory(GPU) diff --git a/midend/lib/Dialect/GPU/CMakeLists.txt b/midend/lib/Dialect/GPU/CMakeLists.txt new file mode 100644 index 0000000000..b575a44e27 --- /dev/null +++ b/midend/lib/Dialect/GPU/CMakeLists.txt @@ -0,0 +1,42 @@ +add_mlir_library(BuddyGPUTransformOPs + TransformOps.cpp + + DEPENDS + TransformOpsIncGen + + LINK_LIBS PUBLIC + LLVMSupport + BuddyGPUUtils + MLIRAffineDialect + MLIRArithDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRBytecodeWriter + MLIRFuncDialect + MLIRFunctionInterfaces + MLIRGPUDialect + MLIRGPUTransformOps + MLIRNVGPUDialect + 
MLIRIndexDialect + MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRLinalgUtils + MLIRMemRefDialect + MLIRNVGPUDialect + MLIRNVGPUTransforms + MLIRParser + MLIRPDLDialect + MLIRPass + MLIRSCFDialect + MLIRSideEffectInterfaces + MLIRTensorTransformOps + MLIRTransformDialect + MLIRTransformDialectUtils + MLIRTransforms + MLIRVectorDialect + MLIRVectorToGPU + MLIRVectorTransforms + MLIRViewLikeInterface + MLIRGPUPasses + ) diff --git a/midend/lib/Dialect/GPU/TransformOps.cpp b/midend/lib/Dialect/GPU/TransformOps.cpp new file mode 100644 index 0000000000..3e689fc931 --- /dev/null +++ b/midend/lib/Dialect/GPU/TransformOps.cpp @@ -0,0 +1,211 @@ +//===- TransformOps.cpp ---------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file implements transform ops for GPU targets. 
+// +//===----------------------------------------------------------------------===// + +#include "GPU/TransformOps.h" + +#include "mlir/Conversion/VectorToGPU/VectorToGPU.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include +#include + +#include "Utils/GPUUtils.h" + +using namespace mlir; +using namespace mlir::buddy; + +using llvm::dbgs; + +#define DEBUG_TYPE "transform-llvmgpu-extensions" +#define DEBUG_TYPE_ALIAS "transform-llvmgpu-extensions-alias" +#define DEBUG_VECTOR_TO_MMA "transform-llvmgpu-extensions-vector-to-mma" + +#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(dbgs() << '[' << DEBUG_TYPE << "] " << X) +#define DBGS_ALIAS() (dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") +#define 
DBGS_VECTOR_TO_MMA() (dbgs() << '[' << DEBUG_VECTOR_TO_MMA << "] ") + +buddy::gpu::TransformExtensions::TransformExtensions() { + // CreateAsyncGroupsOp depends on the following two dialects. + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< +#define GET_OP_LIST +#include "GPU/TransformOps.cpp.inc" + >(); +} + +void buddy::registerBuddyGPUTransformOps(DialectRegistry ®istry) { + registry.addExtensions(); +} + +//===----------------------------------------------------------------------===// +// HoistStaticAllocOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure buddy::gpu::HoistStaticAllocOp::applyToOne( + mlir::transform::TransformRewriter &rewriter, func::FuncOp target, + mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + hoistStaticallyBoundAllocationsInFunc(rewriter, target); + return DiagnosedSilenceableFailure::success(); +} + +void buddy::gpu::HoistStaticAllocOp::getEffects( + SmallVectorImpl &effects) { + mlir::transform::onlyReadsHandle(getTarget(), effects); + mlir::transform::modifiesPayload(effects); +} + +//===---------------------------------------------------------------------===// +// ApplyUnrollVectorsGpuMmaSyncPatternsOp +//===---------------------------------------------------------------------===// + +static std::optional> +getGPUTensorCoreNativeMmaSyncVectorSize(Operation *op) { + return buddy::gpu::getMmaNativeVectorSize(op); +} + +void buddy::gpu::ApplyUnrollVectorsGpuMmaSyncPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + auto unrollOrder = [](Operation *op) -> std::optional> { + auto contract = dyn_cast(op); + if (!contract) + return std::nullopt; + return gpuMmaUnrollOrder(contract); + }; + vector::populateVectorUnrollPatterns( + patterns, vector::UnrollVectorOptions() + .setNativeShapeFn(getGPUTensorCoreNativeMmaSyncVectorSize) + .setUnrollTraversalOrderFn(unrollOrder)); +} + 
+//===---------------------------------------------------------------------===// +// VectorToMMAConversionOp +//===---------------------------------------------------------------------===// + +void buddy::gpu::VectorToMMAConversionOp::getEffects( + SmallVectorImpl &effects) { + mlir::transform::onlyReadsHandle(getTarget(), effects); + mlir::transform::modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +buddy::gpu::VectorToMMAConversionOp::applyToOne( + mlir::transform::TransformRewriter &rewriter, Operation *target, + mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformState &state) { + if (!target->hasTrait()) { + // target->emitOpError( + // "applies only to isolated-from-above targets because it " + // "needs to apply " + // "patterns greedily"); + // return emitDefaultDefiniteFailure(target); + } + + auto funcOp = dyn_cast(target); + if (!funcOp) { + target->emitOpError("Must apply to a func op"); + return emitDefaultDefiniteFailure(target); + } + + if (!(getUseMmaSync() ^ getUseWmma())) { + target->emitOpError( + "Exactly one of use_mma_sync or use_wmma must be specified"); + return emitDefaultDefiniteFailure(target); + } + + MLIRContext *ctx = target->getContext(); + mlir::transform::ErrorCheckingTrackingListener listener(state, *this); + GreedyRewriteConfig config; + config.listener = &listener; + + // Unrolling to native vector size must have previously occurred. + // TODO: Add pattern to propagate the extract through the scf.for + // ops. Convert slice of contract operations to mma_sync/wmma ops. 
+ RewritePatternSet patterns(ctx); + mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + populatePrepareVectorToMMAPatterns(patterns, getUseMmaSync()); + if (failed( + applyPatternsAndFoldGreedily(target, std::move(patterns), config))) { + target->emitOpError("vector to mma preparation patterns failed to apply"); + return emitDefaultDefiniteFailure(target); + } + + auto diag = DiagnosedSilenceableFailure::success(); + if (getUseWmma()) { + if (failed(convertVectorToMMAOps(rewriter, target))) + return mlir::emitDefiniteFailure( + target, "vector to wmma patterns failed to apply"); + return listener.checkAndResetError(); + } + + if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, funcOp))) + return mlir::emitDefiniteFailure(target, + "vector to mma patterns failed to apply"); + + // Using TF32 for Float. + RewritePatternSet f32ToTF32patterns(funcOp.getContext()); + nvgpu::populateMmaSyncF32ToTF32Patterns(f32ToTF32patterns, + nvgpu::MmaSyncF32Lowering::TF32); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(f32ToTF32patterns), + config))) + return mlir::emitDefiniteFailure( + target, "vector to mma F32ToTF32 patterns failed to apply"); + + return listener.checkAndResetError(); +} + +#define GET_OP_CLASSES +#include "GPU/TransformOps.cpp.inc" diff --git a/midend/lib/Utils/CMakeLists.txt b/midend/lib/Utils/CMakeLists.txt index ff9aa6e380..9cedf4f6a7 100644 --- a/midend/lib/Utils/CMakeLists.txt +++ b/midend/lib/Utils/CMakeLists.txt @@ -2,12 +2,13 @@ add_mlir_library(BuddyUtils Utils.cpp DIPUtils.cpp DAPUtils.cpp + GPUUtils.cpp AffineTransformUtils.cpp ) add_mlir_library(BuddyDIPUtils DIPUtils.cpp - + LINK_LIBS PUBLIC BuddyUtils ) @@ -18,3 +19,32 @@ add_mlir_library(BuddyDAPUtils LINK_LIBS PUBLIC BuddyUtils ) + +add_mlir_library(BuddyGPUUtils + GPUUtils.cpp + + LINK_LIBS PUBLIC + LLVMSupport + LLVMTargetParser + MLIRAffineDialect + MLIRAffineUtils + MLIRAnalysis + MLIRArithDialect + MLIRArithUtils + MLIRFuncDialect + MLIRGPUDialect + 
MLIRIR + MLIRLinalgDialect + MLIRLinalgTransforms + MLIRLinalgUtils + MLIRMemRefDialect + MLIRSCFDialect + MLIRSideEffectInterfaces + MLIRSupport + MLIRTensorDialect + MLIRTilingInterface + MLIRTransformUtils + MLIRVectorDialect + MLIRViewLikeInterface + MLIRGPUPasses +) diff --git a/midend/lib/Utils/DIPUtils.cpp b/midend/lib/Utils/DIPUtils.cpp index 0e71af0f1f..da41b65cd6 100644 --- a/midend/lib/Utils/DIPUtils.cpp +++ b/midend/lib/Utils/DIPUtils.cpp @@ -53,6 +53,15 @@ checkDIPCommonTypes(dip::Rotate2DOp, template DIP_ERROR checkDIPCommonTypes(dip::Resize2DOp, const std::vector &args); + +template DIP_ERROR +checkDIPCommonTypes(dip::Resize4D_NHWCOp, + const std::vector &args); + +template DIP_ERROR +checkDIPCommonTypes(dip::Resize4D_NCHWOp, + const std::vector &args); + template DIP_ERROR checkDIPCommonTypes(dip::Erosion2DOp, const std::vector &args); @@ -107,7 +116,9 @@ DIP_ERROR checkDIPCommonTypes(DIPOP op, const std::vector &args) { return DIP_ERROR::UNSUPPORTED_TYPE; } } else if (op->getName().stripDialect() == "rotate_2d" || - op->getName().stripDialect() == "resize_2d") { + op->getName().stripDialect() == "resize_2d" || + op->getName().stripDialect() == "resize_4d_nhwc" || + op->getName().stripDialect() == "resize_4d_nchw") { auto inElemTy = getElementType(0); auto outElemTy = getElementType(1); @@ -381,6 +392,61 @@ void fillPixels(OpBuilder &builder, Location loc, Value resXVec, Value resYVec, }); } +// Fill appropriate pixel data in its corresponding co-ordinate of the output +// image. 
+void fillPixelsNearestNeighbour4D( + OpBuilder &builder, Location loc, Value ivs0, Value ivs1, Value resXVec, + Value resYVec, Value xVec, Value yVec, Value input, Value output, Value c0, + Value strideVal, Value outputRowLastElemF32, Value outputColLastElemF32, + Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32, + Value dataCondition) { + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{strideVal}, + builder.getDimIdentityMap(), /*step*/ 1, std::nullopt, + [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange iterArg) { + std::vector origIndices = + extractIndices(builder, loc, xVec, yVec, ivs[0], + inputColLastElemF32, inputRowLastElemF32, c0F32); + std::vector resIndices = + extractIndices(builder, loc, resXVec, resYVec, ivs[0], + outputColLastElemF32, outputRowLastElemF32, c0F32); + + auto ifop = builder.create( + loc, dataCondition, + [&](OpBuilder &builder, Location loc) { + Value pixelVal = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, origIndices[1], origIndices[0], ivs1}); + builder.create(loc, pixelVal); + }, + [&](OpBuilder &builder, Location loc) { + Value pixelVal = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, origIndices[1], origIndices[0]}); + builder.create(loc, pixelVal); + }); + Value pixelVal = ifop.getResult(0); + + builder.create( + loc, dataCondition, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, pixelVal, output, + ValueRange{ivs0, resIndices[1], resIndices[0], ivs1}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, pixelVal, output, + ValueRange{ivs0, ivs1, resIndices[1], resIndices[0]}); + builder.create(loc); + }); + + builder.create(loc); + }); +} + // Calculate tan(angle / 2) where angle is a function parameter. 
Value customTanVal(OpBuilder &builder, Location loc, Value angleVal) { Value c2F32 = builder.create(loc, (llvm::APFloat)2.0f, @@ -517,7 +583,8 @@ void affineTransformController(OpBuilder &builder, Location loc, Value xMm3 = builder.create(loc, dynamicTypeI32, outputColMultiple); - // RSV_BITS = reserved bits, how many bits should be reserved for fraction part + // RSV_BITS = reserved bits, how many bits should be reserved for fraction + // part // TODO: make reserved bits configurable const int RSV_BITS = 5; Value c_rsv = builder.create( @@ -735,6 +802,133 @@ void fillPixelsBilinearInterpolate( }); } +// Fills pixels in bilinear interpolation fashion. +void fillPixelsBilinearInterpolate4D( + OpBuilder &builder, Location loc, Value ivs0, Value ivs1, Value resXVec, + Value resYVec, Value xVec_L, Value yVec_L, Value xVec_H, Value yVec_H, + Value input, Value output, Value c0, Value strideVal, Value xVecWeight, + Value yVecWeight, Value outputRowLastElemF32, Value outputColLastElemF32, + Value inputRowLastElemF32, Value inputColLastElemF32, Value c0F32, + Value c1F32, Value dataCondition) { + builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), ValueRange{strideVal}, + builder.getDimIdentityMap(), /*step*/ 1, std::nullopt, + [&](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange iterArg) { + std::vector resIndices = + extractIndices(builder, loc, resXVec, resYVec, ivs[0], + outputColLastElemF32, outputRowLastElemF32, c0F32); + + std::vector inputIndices_L = + extractIndices(builder, loc, xVec_L, yVec_L, ivs[0], + inputColLastElemF32, inputRowLastElemF32, c0F32); + std::vector inputIndices_H = + extractIndices(builder, loc, xVec_H, yVec_H, ivs[0], + inputColLastElemF32, inputRowLastElemF32, c0F32); + + std::vector indexWeights; + Value xPos_temp = + builder.create(loc, xVecWeight, ivs[0]); + Value yPos_temp = + builder.create(loc, yVecWeight, ivs[0]); + + indexWeights.push_back( + valBound(builder, loc, xPos_temp, inputColLastElemF32, c0F32)); 
+ indexWeights.push_back( + valBound(builder, loc, yPos_temp, inputRowLastElemF32, c0F32)); + + std::vector indexWeights_UnitComplements = { + builder.create(loc, c1F32, indexWeights[0]), + builder.create(loc, c1F32, indexWeights[1])}; + + auto ifop = builder.create( + loc, dataCondition, + [&](OpBuilder &builder, Location loc) { + Value pixelVal_a = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, inputIndices_L[1], inputIndices_L[0], ivs1}); + Value pixelVal_b = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, inputIndices_H[1], inputIndices_L[0], ivs1}); + Value pixelVal_c = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, inputIndices_L[1], inputIndices_H[0], ivs1}); + Value pixelVal_d = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, inputIndices_H[1], inputIndices_H[0], ivs1}); + builder.create( + loc, + ValueRange{pixelVal_a, pixelVal_b, pixelVal_c, pixelVal_d}); + }, + [&](OpBuilder &builder, Location loc) { + Value pixelVal_a = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, inputIndices_L[1], inputIndices_L[0]}); + Value pixelVal_b = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, inputIndices_H[1], inputIndices_L[0]}); + Value pixelVal_c = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, inputIndices_L[1], inputIndices_H[0]}); + Value pixelVal_d = builder.create( + loc, builder.getF32Type(), input, + ValueRange{ivs0, ivs1, inputIndices_H[1], inputIndices_H[0]}); + builder.create( + loc, + ValueRange{pixelVal_a, pixelVal_b, pixelVal_c, pixelVal_d}); + }); + Value pixelVal_a = ifop.getResult(0); + Value pixelVal_b = ifop.getResult(1); + Value pixelVal_c = ifop.getResult(2); + Value pixelVal_d = ifop.getResult(3); + + Value weightVal1 = + builder.create(loc, indexWeights_UnitComplements[0], + indexWeights_UnitComplements[1]); + Value weightVal2 = builder.create( + loc, indexWeights[0], 
indexWeights_UnitComplements[1]); + Value weightVal3 = builder.create( + loc, indexWeights[1], indexWeights_UnitComplements[0]); + Value weightVal4 = builder.create(loc, indexWeights[0], + indexWeights[1]); + + Value interm1 = + builder.create(loc, pixelVal_a, weightVal1); + Value interm2 = + builder.create(loc, pixelVal_b, weightVal2); + Value interm3 = + builder.create(loc, pixelVal_c, weightVal3); + Value interm4 = + builder.create(loc, pixelVal_d, weightVal4); + + Value pixel_interm1 = + builder.create(loc, interm1, interm2); + Value pixel_interm2 = + builder.create(loc, interm3, interm4); + Value pixelVal = + builder.create(loc, pixel_interm1, pixel_interm2); + + // Value pixelVal = roundOff(builder, loc, pixel_interm3); + + builder.create( + loc, dataCondition, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, pixelVal, output, + ValueRange{ivs0, resIndices[1], resIndices[0], ivs1}); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, pixelVal, output, + ValueRange{ivs0, ivs1, resIndices[1], resIndices[0]}); + builder.create(loc); + }); + + builder.create(loc); + }); +} + // Helper function for resizing an image using nearest neighbour interpolation // mechanism. void NearestNeighbourInterpolationResizing( @@ -767,6 +961,39 @@ void NearestNeighbourInterpolationResizing( }); } +// Helper function for resizing 4D an image using nearest neighbour +// interpolation mechanism. 
+void NearestNeighbourInterpolationResizing4D( + OpBuilder &builder, Location loc, MLIRContext *ctx, + SmallVector lowerBounds, SmallVector upperBounds, + SmallVector steps, Value strideVal, Value input, Value output, + Value horizontalScalingFactorVec, Value verticalScalingFactorVec, + Value outputRowLastElemF32, Value outputColLastElemF32, + Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32, + int64_t stride, Value c0, Value c0F32, Value dataCondition) { + affine::buildAffineLoopNest( + builder, loc, lowerBounds, upperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + Value ivs2F32 = indexToF32(builder, loc, ivs[2]); + Value yVec = builder.create(loc, vectorTy32, ivs2F32); + Value xVec = iotaVec(builder, loc, ctx, ivs[3], strideVal, vectorTy32, + c0, stride); + + Value resXVecInterm = + builder.create(loc, xVec, verticalScalingFactorVec); + Value resYVecInterm = builder.create( + loc, yVec, horizontalScalingFactorVec); + + Value resXVec = roundOff(builder, loc, resXVecInterm); + Value resYVec = roundOff(builder, loc, resYVecInterm); + + fillPixelsNearestNeighbour4D( + builder, loc, ivs[0], ivs[1], xVec, yVec, resXVec, resYVec, input, + output, c0, strideVal, outputRowLastElemF32, outputColLastElemF32, + inputRowLastElemF32, inputColLastElemF32, c0F32, dataCondition); + }); +} + // Helper function for resizing an image using bilinear interpolation mechanism. void BilinearInterpolationResizing( OpBuilder &builder, Location loc, MLIRContext *ctx, @@ -808,6 +1035,49 @@ void BilinearInterpolationResizing( }); } +// Helper function for resizing 4D an image using bilinear interpolation +// mechanism. 
+void BilinearInterpolationResizing4D( + OpBuilder &builder, Location loc, MLIRContext *ctx, + SmallVector lowerBounds, SmallVector upperBounds, + SmallVector steps, Value strideVal, Value input, Value output, + Value horizontalScalingFactorVec, Value verticalScalingFactorVec, + Value outputRowLastElemF32, Value outputColLastElemF32, + Value inputRowLastElemF32, Value inputColLastElemF32, VectorType vectorTy32, + int64_t stride, Value c0, Value c0F32, Value c1F32, Value dataCondition) { + affine::buildAffineLoopNest( + builder, loc, lowerBounds, upperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + Value ivs0F32 = indexToF32(builder, loc, ivs[2]); + Value yVec = builder.create(loc, vectorTy32, ivs0F32); + Value xVec = iotaVec(builder, loc, ctx, ivs[3], strideVal, vectorTy32, + c0, stride); + + Value xVecInterm = + builder.create(loc, xVec, verticalScalingFactorVec); + Value yVecInterm = builder.create( + loc, yVec, horizontalScalingFactorVec); + + Value xVecInterm_L = builder.create(loc, xVecInterm); + Value xVecInterm_H = builder.create(loc, xVecInterm); + + Value yVecInterm_L = builder.create(loc, yVecInterm); + Value yVecInterm_H = builder.create(loc, yVecInterm); + + Value xVecWeight = + builder.create(loc, xVecInterm, xVecInterm_L); + Value yVecWeight = + builder.create(loc, yVecInterm, yVecInterm_L); + + fillPixelsBilinearInterpolate4D( + builder, loc, ivs[0], ivs[1], xVec, yVec, xVecInterm_L, + yVecInterm_L, xVecInterm_H, yVecInterm_H, input, output, c0, + strideVal, xVecWeight, yVecWeight, outputRowLastElemF32, + outputColLastElemF32, inputRowLastElemF32, inputColLastElemF32, + c0F32, c1F32, dataCondition); + }); +} + // Function to test whether a value is equivalent to zero or not. 
Value zeroCond(OpBuilder &builder, Location loc, Type elemType, Value value, Value zeroElem) { diff --git a/midend/lib/Utils/GPUUtils.cpp b/midend/lib/Utils/GPUUtils.cpp new file mode 100644 index 0000000000..82058c8813 --- /dev/null +++ b/midend/lib/Utils/GPUUtils.cpp @@ -0,0 +1,536 @@ +//====- GPUUtils.cpp ------------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// The process in this file references the IREE project, +// which is hereby acknowledged. +// For the license of the IREE project +// please see: https://github.com/iree-org/iree/blob/main/LICENSE +// +//===----------------------------------------------------------------------===// +// +// This file implements GPU dialect specific utility functions for the buddy +// compiler ecosystem. 
+// +//===----------------------------------------------------------------------===// + +#ifndef UTILS_GPUUTILS_DEF +#define UTILS_GPUUTILS_DEF + +#include "mlir/Analysis/Liveness.h" +#include "mlir/Analysis/Presburger/IntegerRelation.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" +#include "mlir/Transforms/TopologicalSortUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" + +#include "Utils/GPUUtils.h" + +#include + +#define DEBUG_TYPE "buddy-codegen-gpu-utils" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DBGSNL() (llvm::dbgs() << "\n") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; + +static constexpr unsigned kShuffleBitWidth = 32; + +namespace mlir::buddy { +namespace gpu { + +/// Pick an unrolling order that will allow tensorcore operation to reuse LHS +/// register. This is needed to get good performance on sm_80 target. +std::optional> +gpuMmaUnrollOrder(vector::ContractionOp contract) { + SmallVector order; + // First make reduction the outer dimensions. 
+ for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { + if (vector::isReductionIterator(iter)) { + order.push_back(index); + } + } + + llvm::SmallDenseSet dims; + for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults()) { + dims.insert(expr.cast().getPosition()); + } + // Then parallel dimensions that are part of Lhs as we want to re-use Lhs. + for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { + if (vector::isParallelIterator(iter) && dims.count(index)) { + order.push_back(index); + } + } + // Then the remaining parallel loops. + for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes())) { + if (vector::isParallelIterator(iter) && !dims.count(index)) { + order.push_back(index); + } + } + return order; +} + +//===----------------------------------------------------------------------===// +// Reduction utils +//===----------------------------------------------------------------------===// + +/// Packs scalar element to it's vector equivalent. 
+/// (i.e f16 -> vector<1xf16> and f32 -> vector<1xf32>) +static Value promoteElementToVector(Location loc, OpBuilder &builder, + Value input) { + VectorType vectorTypeBroadcast = VectorType::get({1}, input.getType()); + Value vectorInput = + builder.create(loc, vectorTypeBroadcast, input); + return vectorInput; +} + +Value packVectorToSupportedWidth(Location loc, OpBuilder &builder, + Value input) { + LLVM_DEBUG({ + auto vecType = input.getType().cast(); + Type elementType = vecType.getElementType(); + assert(vecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() == + kShuffleBitWidth && + "vecSize * vecBitWidth needs to packable into 32-bitwidth."); + assert(elementType.isIntOrFloat() && + "Only int and float packing is supported."); + }); + VectorType packed32Type = VectorType::get({1}, builder.getI32Type()); + Value packedInputVec = + builder.create(loc, packed32Type, input); + Value packedInput = builder.create(loc, packedInputVec, 0); + return packedInput; +} + +Value unpackToVector(Location loc, OpBuilder &builder, Value packedInput, + VectorType targetVecType) { + LLVM_DEBUG({ + Type packedType = packedInput.getType(); + assert(packedType.isIntOrFloat() && "Only ints and floats are unpackable."); + Type elementType = targetVecType.getElementType(); + assert(targetVecType.getDimSize(0) * elementType.getIntOrFloatBitWidth() == + packedType.getIntOrFloatBitWidth() && + "packed width needs to be unpackable to vecSize * vecBitWidth."); + }); + Value packedVector = promoteElementToVector(loc, builder, packedInput); + Value unpackedVector = + builder.create(loc, targetVecType, packedVector); + return unpackedVector; +} + +//===----------------------------------------------------------------------===// +// getMmaNativeVectorSize +//===----------------------------------------------------------------------===// +/// Returns vector::ContractionOp operand's index where the result is used. 
+static std::optional +getVectorContractOpOperandId(vector::ContractionOp contractOp, + OpResult result) { + if (contractOp.getLhs() == result) + return 0; + if (contractOp.getRhs() == result) + return 1; + if (contractOp.getAcc() == result) + return 2; + return std::nullopt; +} + +/// Returns vector::ContractionOp operand's index where the +/// vector::TransferReadOp is consumed either consumed directly or via +/// vector::ExtractStridedSliceOp. +static std::optional +getVectorContractOpOperandIdForVectorReadOp(Operation *op) { + vector::ContractionOp contractOp; + + // Check if the vector::TransferReadOp is consumed directly by + // vector::ContractionOp. + if (op->use_empty()) + return std::nullopt; + Operation *firstLevelUser = *((op->getUsers()).begin()); + if (!firstLevelUser) + return std::nullopt; + if (auto contractOp = dyn_cast(firstLevelUser)) + return getVectorContractOpOperandId(contractOp, op->getResult(0)); + + // Check if the vector::TransferReadOp is consumed indirectly by + // vector::ContractionOp. Only check until the second level of use-def chain. + if (firstLevelUser->use_empty()) + return std::nullopt; + Operation *secondLevelUser = *((firstLevelUser->getUsers()).begin()); + if (!secondLevelUser) + return std::nullopt; + if (auto contractOp = dyn_cast(secondLevelUser)) + return getVectorContractOpOperandId(contractOp, + firstLevelUser->getResult(0)); + return std::nullopt; +} + +/// Helper function to return native size for MMA.SYNC-based operations. +std::optional> getMmaNativeVectorSize(Operation *op) { + // Shape of native Tensor Core GPU mma.sync operations. + int64_t mmaShapeM = 16; + int64_t mmaShapeN = 8; + int64_t mmaShapeK; + + // Shape the mma.sync warp-level operation. + if (auto contract = dyn_cast(op)) { + Type sourceType = contract.getLhsType().getElementType(); + + // Set mmaShapeK based on sourceType. 
+ if (sourceType.isInteger(4)) + mmaShapeK = 64; + else if (sourceType.isInteger(8)) + mmaShapeK = 32; + else if (sourceType.isF16() || sourceType.isBF16()) + mmaShapeK = 16; + else if (sourceType.isF32()) + mmaShapeK = 8; + else { + LDBG("unsupported shape for vector.contract: "); + return std::nullopt; + } + + // Initialize/set the starting dims of the ranked shape, such as batch, + // to 1. + SmallVector mmaShape(contract.getIteratorTypes().size() - 3, 1); + mmaShape.append({mmaShapeM, mmaShapeN, mmaShapeK}); + LLVM_DEBUG({ + llvm::interleaveComma(mmaShape, DBGS() << "shape for vector.contract: "); + llvm::dbgs() << "\n"; + }); + return mmaShape; + } + + // Shape of warp-level vector write operation. + if (auto writeOp = dyn_cast(op)) { + if (writeOp.getVectorType().getRank() < 2) + return std::nullopt; + SmallVector outputShape(writeOp.getVectorType().getRank() - 2, 1); + outputShape.append({mmaShapeM, mmaShapeN}); + LLVM_DEBUG({ + llvm::interleaveComma(outputShape, + DBGS() << "shape for vector.xfer_write: "); + llvm::dbgs() << "\n"; + }); + return outputShape; + } + + // Shape of warp-level vector read (load) operation. + if (auto readOp = dyn_cast(op)) { + auto resultVectorType = + llvm::cast(readOp.getVector().getType()); + Type resultElementType = resultVectorType.getElementType(); + + std::optional operandId = + getVectorContractOpOperandIdForVectorReadOp(op); + if (!operandId) { + LLVM_DEBUG({ + DBGS() << "Failed to get operandId for vector::xfer_read: " << *op + << "\n"; + }); + return std::nullopt; + } + + // Loading F16 values from Shared Memory to Registers. + if (resultElementType.isF16() || resultElementType.isBF16()) { + // For matrixC. + if (*operandId == 2) { + SmallVector readShape; + readShape.append({mmaShapeM, mmaShapeN}); + LLVM_DEBUG({ + llvm::interleaveComma(readShape, + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return readShape; + } + + // For matrixA and matrixB. 
+ if (*operandId == 0 || *operandId == 1) { + // MmaSyncOp input operands: matrixA and matrixB. + // LDSMx1, x2, x4: + // - LDSMx1 loads a 1 tile of 8x8. + // - LDSMx2 loads a 2 tiles of 8x8. + // - LDSMx4 loads a 4 tiles of 8x8. (in use) + // here uses the largest tiled load, i.e., LDSMx4. + + // MmaSyncOp source operand: matrixC. + // matrixC is also read/written in tiled block of 16x16. In the pass + // OptimizeVectorTransfer, matrixC reads are moved above the mainloop + // and writes are moved below the mainloop. Thus, mma.sync read/write + // accumulator inplace. + SmallVector readShape; + readShape.append({16, 16}); + LLVM_DEBUG({ + llvm::interleaveComma(readShape, + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return readShape; + } + } + + // Loading F32 values from Shared Memory to Registers. + if (resultElementType.isF32()) { + // Set mmaShapeK for F32 datatype mma.sync.f32.tf32.m16n8k8. + mmaShapeK = 8; + + // For matrixC. + if (*operandId == 2) { + SmallVector readShape; + readShape.append({mmaShapeM, mmaShapeN}); + LLVM_DEBUG({ + llvm::interleaveComma(readShape, + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return readShape; + } + // For matrixA. + if (*operandId == 0) { + SmallVector readShape; + readShape.append({mmaShapeM, mmaShapeK}); + LLVM_DEBUG({ + llvm::interleaveComma(readShape, + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return readShape; + } + // For matrixB. + if (*operandId == 1) { + // Do not use ldmatrix for matrixB. + // Transfer read ops may need different shapes based on how they are + // being used. For simplicity just match the shape used by the extract + // strided op. 
+ VectorType sliceType; + for (Operation *users : op->getUsers()) { + auto extract = dyn_cast(users); + if (!extract) + return std::nullopt; + auto vecType = llvm::cast(extract.getResult().getType()); + if (sliceType && sliceType != vecType) + return std::nullopt; + sliceType = vecType; + } + LLVM_DEBUG({ + llvm::interleaveComma(sliceType.getShape(), + DBGS() << "shape for vector.xfer_read: "); + llvm::dbgs() << "\n"; + }); + return llvm::to_vector(sliceType.getShape()); + } + } + } + LDBG("unsupported shape for " << op->getName().getStringRef()); + return std::nullopt; +} + +bool hasSharedMemoryAddressSpace(MemRefType memrefType) { + auto addrSpace = llvm::dyn_cast_if_present( + memrefType.getMemorySpace()); + return addrSpace && addrSpace.getValue() == + mlir::gpu::GPUDialect::getWorkgroupAddressSpace(); +} + +template +std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, OpBuilder &builder, + Location loc, MemRefType allocLikeType, + ValueRange dynamicSizes, + std::optional alignment) { + IntegerAttr alignmentAttr = + alignment ? builder.getI64IntegerAttr(alignment.value()) : nullptr; + // For static case just create a new allocation in the entry block of the same + // size. No need to insert a subview. + if (dynamicSizes.empty()) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(&funcOp.getBody().front()); + Value allocation = + builder.create(loc, allocLikeType, alignmentAttr); + if (std::is_same::value) { + builder.setInsertionPoint(funcOp.getBody().front().getTerminator()); + builder.create(loc, allocation); + } + return allocation; + } + + /// For the dynamic but bounded case, insert an allocation of the shape of the + /// bounds, and a subview of the required size to be used as a replacement. 
+ SmallVector staticShape; + SmallVector subviewSizes; + staticShape.reserve(allocLikeType.getRank()); + subviewSizes.reserve(allocLikeType.getRank()); + + int index = 0; + for (auto dimSize : allocLikeType.getShape()) { + if (!ShapedType::isDynamic(dimSize)) { + staticShape.push_back(dimSize); + subviewSizes.push_back(builder.getIndexAttr(dimSize)); + continue; + } + Value dynamicSize = dynamicSizes[index++]; + auto ub = ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::UB, dynamicSize, /*dim=*/std::nullopt, + /*stopCondition=*/nullptr, /*closedUB=*/true); + if (failed(ub)) { + return std::nullopt; + } + staticShape.push_back(ub.value()); + subviewSizes.push_back(dynamicSize); + } + SmallVector offsets(allocLikeType.getRank(), + builder.getIndexAttr(0)); + SmallVector strides(allocLikeType.getRank(), + builder.getIndexAttr(1)); + + Value allocation; + { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(&funcOp.getBody().front()); + auto allocationType = + MemRefType::get(staticShape, allocLikeType.getElementType()); + allocation = + builder.create(loc, allocationType, alignmentAttr); + } + + Value subviewOp = builder.create(loc, allocation, offsets, + subviewSizes, strides); + + if (std::is_same::value) { + builder.setInsertionPoint(funcOp.getBody().front().getTerminator()); + builder.create(loc, allocation); + } + return subviewOp; +} + +template +std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, OpBuilder &builder, + AllocLikeOpType allocLikeOp) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(allocLikeOp); + return hoistOneStaticallyBoundAllocation( + funcOp, builder, allocLikeOp.getLoc(), allocLikeOp.getType(), + allocLikeOp.getDynamicSizes(), allocLikeOp.getAlignment()); +} + +/// Some uses of a AllocLike can be replaced with a `memref.subview` +/// easily. 
Other uses (like a use in a `scf.yield` or `func.return`) are +/// non-trivial because of compatibility between types of different SSA values. +static bool isUseReplaceableWithSubview(OpOperand &use) { + Operation *user = use.getOwner(); + return isa(user); +} + +template +void hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter, + func::FuncOp funcOp) { + SmallVector allocLikeOps; + + // Collect all allocLikes that are hoistable. + funcOp.walk([&](AllocLikeOpType allocLikeOp) { + if (allocLikeOp->getBlock() == &funcOp.getBody().front()) + return; + if (allocLikeOp.getDynamicSizes().empty()) { + allocLikeOps.push_back(allocLikeOp); + return; + } + if (llvm::all_of(allocLikeOp->getUses(), [](OpOperand &use) { + return isUseReplaceableWithSubview(use); + })) { + allocLikeOps.push_back(allocLikeOp); + return; + } + }); + + // Hoist the allocLikes and replace all uses. + for (auto allocLikeOp : allocLikeOps) { + // Record potential memref::DeallocOps to clean up after hoisting occurs. + SmallVector deallocOps; + for (Operation *user : allocLikeOp->getUsers()) { + auto dealloc = dyn_cast(user); + if (dealloc) + deallocOps.push_back(dealloc); + } + + LLVM_DEBUG({ + llvm::dbgs() << "Alloca Op : "; + allocLikeOp->dump(); + int numUses = std::distance(allocLikeOp.getResult().use_begin(), + allocLikeOp.getResult().use_end()); + llvm::dbgs() << " num Uses : " << numUses; + }); + std::optional replacement = + hoistOneStaticallyBoundAllocation(funcOp, rewriter, allocLikeOp); + if (!replacement) + continue; + LLVM_DEBUG({ + llvm::dbgs() << "Replacement : "; + replacement->dump(); + }); + Value replacementVal = replacement.value(); + rewriter.replaceOp(allocLikeOp, replacementVal); + + for (memref::DeallocOp deallocOp : deallocOps) + rewriter.eraseOp(deallocOp); + } +} + +/// Explicit instantiations for `hoistStaticallyBoundAllocationsInFunc` and +/// dependent functions. 
+template std::optional +hoistOneStaticallyBoundAllocation( + func::FuncOp funcOp, OpBuilder &builder, Location loc, + MemRefType allocLikeType, ValueRange dynamicSizes, + std::optional alignment); +template std::optional +hoistOneStaticallyBoundAllocation( + func::FuncOp funcOp, OpBuilder &builder, Location loc, + MemRefType allocLikeType, ValueRange dynamicSizes, + std::optional alignment); +template std::optional +hoistOneStaticallyBoundAllocation(func::FuncOp funcOp, + OpBuilder &builder, + memref::AllocOp allocLikeOp); +template std::optional +hoistOneStaticallyBoundAllocation( + func::FuncOp funcOp, OpBuilder &builder, memref::AllocaOp allocLikeOp); +template void +hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter, + func::FuncOp funcOp); +template void +hoistStaticallyBoundAllocationsInFunc(RewriterBase &rewriter, + func::FuncOp funcOp); + +} // namespace gpu +} // namespace mlir::buddy +#endif // UTILS_GPUUTILS_DEF diff --git a/nix/buddy-llvm.nix b/nix/buddy-llvm.nix new file mode 100644 index 0000000000..af5bc1c867 --- /dev/null +++ b/nix/buddy-llvm.nix @@ -0,0 +1,76 @@ +{ stdenv +, cmake +, ninja +, python3 +, fetchFromGitHub +}: + +let + pythonEnv = python3.withPackages (ps: [ + ps.numpy + ps.pybind11 + ps.pyyaml + ps.ml-dtypes + ]); +in +stdenv.mkDerivation rec { + name = "llvm-for-buddy-mlir"; + version = "6c59f0e1b0fb56c909ad7c9aad4bde37dc006ae0"; + src = fetchFromGitHub { + owner = "llvm"; + repo = "llvm-project"; + rev = version; + hash = "sha256-bMJJ2q1hSh7m0ewclHOmIe7lOHv110rz/P7D3pw8Uiw="; + }; + + requiredSystemFeatures = [ "big-parallel" ]; + + propagatedBuildInputs = [ + pythonEnv + ]; + + nativeBuildInputs = [ + cmake + ninja + ]; + + cmakeDir = "../llvm"; + cmakeFlags = [ + "-DLLVM_ENABLE_PROJECTS=mlir" + "-DLLVM_TARGETS_TO_BUILD=host;RISCV" + "-DLLVM_ENABLE_ASSERTIONS=ON" + "-DCMAKE_BUILD_TYPE=Release" + # required for MLIR python binding + "-DMLIR_ENABLE_BINDINGS_PYTHON=ON" + # required for not, FileCheck... 
+ "-DLLVM_INSTALL_UTILS=ON" + ]; + + outputs = [ "out" "lib" "dev" ]; + + postInstall = '' + # buddy-mlir have custom RVV backend that required LLVM backend, + # and those LLVM backend headers require this config.h header file. + # However for LLVM, this config.h is meant to be used on build phase only, + # so it will not be installed for cmake install. + # We have to do some hack + cp -v "include/llvm/Config/config.h" "$dev/include/llvm/Config/config.h" + + # move llvm-config to $dev to resolve a circular dependency + moveToOutput "bin/llvm-config*" "$dev" + + # move all lib files to $lib except lib/cmake + moveToOutput "lib" "$lib" + moveToOutput "lib/cmake" "$dev" + + # patch configuration files so each path points to the new $lib or $dev paths + substituteInPlace "$dev/lib/cmake/llvm/LLVMConfig.cmake" \ + --replace 'set(LLVM_BINARY_DIR "''${LLVM_INSTALL_PREFIX}")' 'set(LLVM_BINARY_DIR "'"$lib"'")' + substituteInPlace \ + "$dev/lib/cmake/llvm/LLVMExports-release.cmake" \ + "$dev/lib/cmake/mlir/MLIRTargets-release.cmake" \ + --replace "\''${_IMPORT_PREFIX}/lib/lib" "$lib/lib/lib" \ + --replace "\''${_IMPORT_PREFIX}/lib/objects-Release" "$lib/lib/objects-Release" \ + --replace "$out/bin/llvm-config" "$dev/bin/llvm-config" # patch path for llvm-config + ''; +} diff --git a/nix/buddy-mlir.nix b/nix/buddy-mlir.nix index b59d82275f..db10c6281c 100644 --- a/nix/buddy-mlir.nix +++ b/nix/buddy-mlir.nix @@ -1,51 +1,68 @@ -{ cmake, ninja, python3, llvmPackages_16, fetchFromGitHub, libjpeg, libpng, zlib-ng }: +{ lib +, stdenv +, buddy-llvm +, cmake +, ninja +, llvmPkgs +, libjpeg +, libpng +, zlib-ng +, ccls +}: let - # Using git submodule to obtain the llvm source is really slow. - # So here I use tarball to save time from git index. 
- llvmSrc = fetchFromGitHub { - owner = "llvm"; - repo = "llvm-project"; - rev = "6c59f0e1b0fb56c909ad7c9aad4bde37dc006ae0"; - sha256 = "sha256-bMJJ2q1hSh7m0ewclHOmIe7lOHv110rz/P7D3pw8Uiw"; - }; -in -# Use clang instead of gcc to build -llvmPackages_16.stdenv.mkDerivation { - pname = "buddy-mlir"; - version = "unstable-2023-11-07+rev=38bfd56"; - - srcs = [ - llvmSrc - ../. - ]; - sourceRoot = "llvm-project"; - unpackPhase = '' - sourceArray=($srcs) - cp -r ''${sourceArray[0]} llvm-project - cp -r ''${sourceArray[1]} buddy-mlir + self = stdenv.mkDerivation { + pname = "buddy-mlir"; + version = "unstable-2024-07-18"; - # Directories copied from nix store are read only - chmod -R u+w llvm-project buddy-mlir - ''; + src = with lib.fileset; toSource { + root = ./..; + fileset = unions [ + ./../backend + ./../cmake + ./../examples + ./../frontend + ./../midend + ./../tests + ./../tools + ./../thirdparty + ./../CMakeLists.txt + ./../flake.lock + ./../flake.nix + ]; + }; - # Tablegen in latest commit have bug. 
See llvm-projects issue #68166 - prePatch = "pushd $NIX_BUILD_TOP/llvm-project"; - patches = [ ./tblgen.patch ]; - postPatch = "popd"; + nativeBuildInputs = [ + cmake + ninja + llvmPkgs.bintools + ]; - nativeBuildInputs = [ cmake ninja python3 llvmPackages_16.bintools libjpeg libpng zlib-ng ]; + buildInputs = [ + buddy-llvm + ]; - cmakeDir = "../llvm"; - cmakeFlags = [ - "-DCMAKE_BUILD_TYPE=Release" - "-DLLVM_ENABLE_PROJECTS=mlir" - "-DLLVM_TARGETS_TO_BUILD=host;RISCV" - "-DLLVM_ENABLE_ASSERTIONS=ON" - "-DLLVM_USE_LINKER=lld" + cmakeFlags = [ + "-DMLIR_DIR=${buddy-llvm.dev}/lib/cmake/mlir" + "-DLLVM_DIR=${buddy-llvm.dev}/lib/cmake/llvm" + "-DLLVM_MAIN_SRC_DIR=${buddy-llvm.src}/llvm" + "-DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON" + "-DCMAKE_BUILD_TYPE=Release" + ]; - "-DLLVM_EXTERNAL_PROJECTS=buddy-mlir" - "-DLLVM_EXTERNAL_BUDDY_MLIR_SOURCE_DIR=../../buddy-mlir" - ]; + passthru = { + llvm = buddy-llvm; + devShell = self.overrideAttrs (old: { + nativeBuildInputs = old.nativeBuildInputs ++ [ + libjpeg + libpng + zlib-ng + ccls + ]; + }); + }; - checkTarget = "check-mlir check-buddy"; -} + # No need to do check, and it also takes too much time to finish. + doCheck = false; + }; +in +self diff --git a/nix/overlay.nix b/nix/overlay.nix index 19c97fc33c..767f23bdd4 100644 --- a/nix/overlay.nix +++ b/nix/overlay.nix @@ -1,6 +1,8 @@ final: prev: { # Add an alias here can help future migration - llvmPkgs = final.llvmPackages_16; - buddy-mlir = final.callPackage ./buddy-mlir.nix { }; + llvmPkgs = final.llvmPackages_17; + # Use clang instead of gcc to compile, to avoid gcc13 miscompile issue. 
+ buddy-llvm = final.callPackage ./buddy-llvm.nix { stdenv = final.llvmPkgs.stdenv; }; + buddy-mlir = final.callPackage ./buddy-mlir.nix { stdenv = final.llvmPkgs.stdenv; }; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 024ce2cc38..2cffa98469 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -21,6 +21,13 @@ if(BUDDY_ENABLE_OPENCV) list(APPEND BUDDY_TEST_DEPENDS buddy-image-container-test) endif() +if(BUDDY_MLIR_ENABLE_DIP_LIB) + list(APPEND BUDDY_TEST_DEPENDS buddy-new-image-container-test-bmp) + if(BUDDY_ENABLE_PNG) + list(APPEND BUDDY_TEST_DEPENDS buddy-new-image-container-test-png) + endif() +endif() + add_lit_testsuite(check-tests "Running the buddy regression tests..." ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${BUDDY_TEST_DEPENDS} diff --git a/tests/Conversion/convert-memcpy-to-gpu.mlir b/tests/Conversion/convert-memcpy-to-gpu.mlir new file mode 100644 index 0000000000..63edfd8d02 --- /dev/null +++ b/tests/Conversion/convert-memcpy-to-gpu.mlir @@ -0,0 +1,23 @@ +// RUN: buddy-opt -convert-memcpy-to-gpu -canonicalize %s | FileCheck %s + +// CHECK: %memref = gpu.alloc () : memref<32x32xf32> +// CHECK: %memref_0 = gpu.alloc () : memref<32x32xf32> +// CHECK: gpu.dealloc %memref : memref<32x32xf32> +// CHECK: %alloc = memref.alloc() : memref<32x32xf32> +// CHECK: gpu.memcpy %alloc, %memref_0 : memref<32x32xf32>, memref<32x32xf32> +// CHECK: gpu.dealloc %memref_0 : memref<32x32xf32> +module attributes {gpu.container_module} { + func.func @matmul(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> { + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + gpu.launch_func @matmul_kernel::@matmul_kernel blocks in (%c1, %c1, %c1) threads in (%c64, %c2, %c1) + return %alloc : memref<32x32xf32> + } + gpu.module @matmul_kernel { + gpu.func @matmul_kernel() kernel attributes {gpu.known_block_size = array, 
gpu.known_grid_size = array} { + gpu.return + } + } +} diff --git a/tests/Conversion/legalize-shmem-outlining.mlir b/tests/Conversion/legalize-shmem-outlining.mlir new file mode 100644 index 0000000000..f80c9b761a --- /dev/null +++ b/tests/Conversion/legalize-shmem-outlining.mlir @@ -0,0 +1,26 @@ +// RUN: buddy-opt -legalize-shmem-outlining -canonicalize %s | FileCheck %s + +// CHECK: module attributes {gpu.container_module} +// CHECK: gpu.launch_func @matmul_kernel::@matmul_kernel blocks in (%c1, %c1, %c1) threads in (%c64, %c2, %c1) +// CHECK: return %alloc : memref<32x32xf32> +// CHECK: gpu.module @matmul_kernel { +// CHECK-NEXT: gpu.func @matmul_kernel() kernel attributes {gpu.known_block_size = array, gpu.known_grid_size = array} { +// CHECK-NEXT: gpu.return +// CHECK-NEXT: } +// CHECK-NEXT: } +module { + func.func @matmul(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) -> memref<32x32xf32> { + %alloc = memref.alloc() : memref<16x32xf32, 3> + %alloc_2 = memref.alloc() : memref<32x16xf32, 3> + %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c2 = arith.constant 2 : index + gpu.launch blocks(%arg2, %arg3, %arg4) in (%arg8 = %c1, %arg9 = %c1, %arg10 = %c1) threads(%arg5, %arg6, %arg7) in (%arg11 = %c64, %arg12 = %c2, %arg13 = %c1) { + gpu.terminator + } + memref.dealloc %alloc_2 : memref<32x16xf32, 3> + memref.dealloc %alloc : memref<16x32xf32, 3> + return %alloc_3 : memref<32x32xf32> + } +} diff --git a/tests/Dialect/BuddyGPU/hoist-static-alloc.mlir b/tests/Dialect/BuddyGPU/hoist-static-alloc.mlir new file mode 100644 index 0000000000..0578b2e001 --- /dev/null +++ b/tests/Dialect/BuddyGPU/hoist-static-alloc.mlir @@ -0,0 +1,92 @@ +// RUN: buddy-opt --split-input-file --transform-interpreter %s | FileCheck %s + +func.func @non_entry_bb_allocs() { + cf.br ^bb1 + ^bb1() : + %0 = memref.alloc() : memref<16xi32> + memref.dealloc %0 : memref<16xi32> + return +} +// 
CHECK-LABEL: func @non_entry_bb_allocs() +// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32> +// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16xi32> +// CHECK-NEXT: cf.br ^bb1 +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: return + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.op<"func.func"> + transform.buddy.hoist_static_alloc %func : (!transform.op<"func.func">) -> () + transform.yield + } // @__transform_main +} // module + +// ----- + +#map = affine_map<(d0) -> (d0, 16)> +func.func @nested_op_alloc_subview_use_static(%arg0 : index, %o0 : index, %o1 : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c42 = arith.constant 42 : i32 + scf.for %iv = %c0 to %arg0 step %c1 { + %0 = affine.min #map(%iv) + %1 = memref.alloc() : memref<16x16xi32> + %2 = memref.subview %1[%o0, %o1][%c1, %0][1, 1] : memref<16x16xi32> to memref> + memref.dealloc %1 : memref<16x16xi32> + scf.yield + } + return +} +// CHECK-LABEL: func @nested_op_alloc_subview_use_static( +// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16x16xi32> +// CHECK: scf.for +// CHECK: %[[SIZE:.+]] = affine.min +// CHECK: memref.subview %[[ALLOC]] +// CHECK-NEXT: } +// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16x16xi32> + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.op<"func.func"> + transform.buddy.hoist_static_alloc %func : (!transform.op<"func.func">) -> () + transform.yield + } // @__transform_main +} // module + +// ----- + +#map = affine_map<(d0) -> (d0, 16)> +func.func @nested_op_alloc_subview_use_dynamic(%arg0 : index, %o0 : index, %o1 : index) { + 
%c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c42 = arith.constant 42 : i32 + scf.for %iv = %c0 to %arg0 step %c1 { + %0 = affine.min #map(%iv) + %1 = memref.alloc(%0, %0) : memref + %2 = memref.subview %1[%o0, %o1][%c1, %0][1, 1] : memref to memref> + memref.dealloc %1 : memref + scf.yield + } + return +} + +// CHECK-LABEL: func @nested_op_alloc_subview_use_dynamic( +// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<16x16xi32> +// CHECK: scf.for +// CHECK: %[[SIZE:.+]] = affine.min +// CHECK: %subview = memref.subview %[[ALLOC]][0, 0] [%[[SIZE]], %[[SIZE]]] [1, 1] +// CHECK: %subview_0 = memref.subview %subview[%arg1, %arg2] [%c1, %0] [1, 1] +// CHECK-NEXT: } +// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<16x16xi32> + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.op<"func.func"> + transform.buddy.hoist_static_alloc %func : (!transform.op<"func.func">) -> () + transform.yield + } // @__transform_main +} // module diff --git a/tests/Dialect/BuddyGPU/transform-dialect-vector-to-nvgpu-mma.mlir b/tests/Dialect/BuddyGPU/transform-dialect-vector-to-nvgpu-mma.mlir new file mode 100644 index 0000000000..3aee301cea --- /dev/null +++ b/tests/Dialect/BuddyGPU/transform-dialect-vector-to-nvgpu-mma.mlir @@ -0,0 +1,97 @@ +// RUN: buddy-opt --split-input-file --transform-interpreter %s | FileCheck %s + + +#matmat_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] + +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +func.func @wmma(%a: memref<16x16xf32>, %b: memref<16x16xf32>, %c: memref<16x16xf32>) { + %c0 = arith.constant 0: index + %cst = arith.constant 0.0: f32 + %va = vector.transfer_read %a[%c0, %c0], 
%cst: memref<16x16xf32>, vector<16x16xf32> + %vb = vector.transfer_read %b[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + %vc = vector.transfer_read %c[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + + // CHECK-NOT: vector.contract + // CHECK: gpu.subgroup_mma_compute + %vres = vector.contract #matmat_trait %va, %vb, %vc + : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32> + vector.transfer_write %vres, %c[%c0, %c0]: vector<16x16xf32>, memref<16x16xf32> + return +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main( + %module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.buddy.unroll_vectors_gpu_mma_sync + } : !transform.any_op + transform.buddy.vector.vector_to_mma_conversion %func { use_wmma } : (!transform.any_op) -> () + + // Apply canonicalization post-hoc to trigger DCE and pass the test + // (i.e. all vector.contract are dead). + // TODO: consider having the vector_to_mma_conversion do the DCE automatically. 
+ transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + + transform.yield + } +} + +// ----- + +#matmat_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +func.func @mma_sync(%a: memref<16x16xf32>, %b: memref<16x16xf32>, %c: memref<16x16xf32>) { + %c0 = arith.constant 0: index + %cst = arith.constant 0.0: f32 + %va = vector.transfer_read %a[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + %vb = vector.transfer_read %b[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + %vc = vector.transfer_read %c[%c0, %c0], %cst: memref<16x16xf32>, vector<16x16xf32> + + // CHECK-NOT: vector.contract + // CHECK: nvgpu.mma.sync{{.*}} tf32Enabled} + %vres = vector.contract #matmat_trait %va, %vb, %vc + : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32> + vector.transfer_write %vres, %c[%c0, %c0]: vector<16x16xf32>, memref<16x16xf32> + return +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main( + %module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.buddy.unroll_vectors_gpu_mma_sync + } : !transform.any_op + transform.buddy.vector.vector_to_mma_conversion %func { use_mma_sync } : (!transform.any_op) -> () + + // Apply canonicalization post-hoc to trigger DCE and pass the test + // (i.e. all vector.contract are dead). + // TODO: consider having the vector_to_mma_conversion do the DCE automatically. 
+ transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + + transform.yield + } +} diff --git a/tests/Dialect/DIP/resize4D_nchw_lowering.mlir b/tests/Dialect/DIP/resize4D_nchw_lowering.mlir new file mode 100644 index 0000000000..92f6cf3728 --- /dev/null +++ b/tests/Dialect/DIP/resize4D_nchw_lowering.mlir @@ -0,0 +1,7 @@ +// RUN: buddy-opt --lower-dip %s | buddy-opt | FileCheck %s + +func.func @buddy_resize4d_nchw_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: memref.store %57, %arg3[%arg4, %arg5, %56, %54] : memref + dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} diff --git a/tests/Dialect/DIP/resize4D_nchw_roundtrip.mlir b/tests/Dialect/DIP/resize4D_nchw_roundtrip.mlir new file mode 100644 index 0000000000..3850a88dc5 --- /dev/null +++ b/tests/Dialect/DIP/resize4D_nchw_roundtrip.mlir @@ -0,0 +1,25 @@ +// RUN: buddy-opt -verify-diagnostics %s | buddy-opt | FileCheck %s + +func.func @buddy_resize4d_nchw_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nchw_NEAREST_NEIGHBOUR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nchw NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, 
f32, f32, memref + return +} + +func.func @buddy_resize4d_nchw_BILINEAR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nchw BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nchw BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nchw_BILINEAR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nchw BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nchw BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} diff --git a/tests/Dialect/DIP/resize4D_nhwc_lowering.mlir b/tests/Dialect/DIP/resize4D_nhwc_lowering.mlir new file mode 100644 index 0000000000..79291b9340 --- /dev/null +++ b/tests/Dialect/DIP/resize4D_nhwc_lowering.mlir @@ -0,0 +1,7 @@ +// RUN: buddy-opt --lower-dip %s | buddy-opt | FileCheck %s + +func.func @buddy_resize4d_nhwc_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: memref.store %57, %arg3[%arg4, %arg5, %56, %54] : memref + dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} diff --git a/tests/Dialect/DIP/resize4D_nhwc_roundtrip.mlir b/tests/Dialect/DIP/resize4D_nhwc_roundtrip.mlir new file mode 100644 index 0000000000..46bea2fdbd --- /dev/null +++ b/tests/Dialect/DIP/resize4D_nhwc_roundtrip.mlir @@ -0,0 +1,25 @@ +// RUN: buddy-opt -verify-diagnostics %s | buddy-opt | FileCheck %s + +func.func @buddy_resize4d_nhwc_NEAREST_NEIGHBOUR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, 
%vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nhwc_NEAREST_NEIGHBOUR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nhwc NEAREST_NEIGHBOUR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nhwc_BILINEAR_INTERPOLATION_f32(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nhwc BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nhwc BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} + +func.func @buddy_resize4d_nhwc_BILINEAR_INTERPOLATION_f64(%input : memref, %horizontal_scaling_factor : f32, %vertical_scaling_factor : f32, %output : memref) -> () { + // CHECK: dip.resize_4d_nhwc BILINEAR_INTERPOLATION{{.*}} : memref, f32, f32, memref + dip.resize_4d_nhwc BILINEAR_INTERPOLATION %input, %horizontal_scaling_factor, %vertical_scaling_factor, %output : memref, f32, f32, memref + return +} diff --git a/tests/Interface/core/AudioContainerTest.cpp b/tests/Interface/core/AudioContainerTest.cpp index a31c7800f7..684584c3a4 100644 --- a/tests/Interface/core/AudioContainerTest.cpp +++ b/tests/Interface/core/AudioContainerTest.cpp @@ -20,22 +20,72 @@ // RUN: buddy-audio-container-test 2>&1 | FileCheck %s +#include "AudioFile.h" #include #include using namespace std; int main() { - dap::Audio 
aud("../../../../tests/Interface/core/NASA_Mars.wav"); - auto &audioFile = aud.getAudioFile(); + // --------------------------------------------------------------------------- + // 1. Print Decoded Reuslts using Buddy Audio Container + // --------------------------------------------------------------------------- + + // Read and decode audio file with Buddy Audio Container. + dap::Audio aud("../../../../tests/Interface/core/TestAudio.wav"); + + // CHECK: WAV + fprintf(stderr, "%s\n", aud.getFormatName().c_str()); + // CHECK: 16 + fprintf(stderr, "%d\n", aud.getBitDepth()); + // CHECK: 77040 + fprintf(stderr, "%lu\n", aud.getSamplesNum()); + // CHECK: 1 + fprintf(stderr, "%d\n", aud.getChannelsNum()); + // CHECK: 16000 + fprintf(stderr, "%d\n", aud.getSampleRate()); + // CHECK: -0.000153 + fprintf(stderr, "%f\n", aud.getData()[3]); + // CHECK: -0.000275 + fprintf(stderr, "%f\n", aud.getData()[4]); + + // --------------------------------------------------------------------------- + // 2. Compare Encoded results using Buddy Audio Container and AudioFile.h + // --------------------------------------------------------------------------- + + // Encode the audio data and save it to a file using the Buddy Audio Container + string filePath = "./buddyEncodeResult.wav"; + aud.saveToFile(filePath, "WAVE"); + + // Print metadata and sample values using the Buddy Audio Container. + dap::Audio audContainer(filePath); + // CHECK: 16 + fprintf(stderr, "%d\n", audContainer.getBitDepth()); + // CHECK: 77040 + fprintf(stderr, "%lu\n", audContainer.getSamplesNum()); + // CHECK: 1 + fprintf(stderr, "%d\n", audContainer.getChannelsNum()); + // CHECK: 16000 + fprintf(stderr, "%d\n", audContainer.getSampleRate()); + // CHECK: -0.000122 + fprintf(stderr, "%f\n", audContainer.getData()[3]); + // CHECK: -0.000244 + fprintf(stderr, "%f\n", audContainer.getData()[4]); + + // Print metadata and sample values using the third-party (AudioFile.h). 
+ AudioFile audFile(filePath); + // CHECK: 16 + fprintf(stderr, "%d\n", audFile.getBitDepth()); + // CHECK: 77040 + fprintf(stderr, "%d\n", audFile.getNumSamplesPerChannel()); // CHECK: 1 - fprintf(stderr, "%u\n", audioFile.getNumChannels()); - // CHECK: 24 - fprintf(stderr, "%u\n", audioFile.getBitDepth()); - // CHECK: 2000000 - fprintf(stderr, "%u\n", audioFile.getNumSamplesPerChannel()); - // CHECK: 100000 - fprintf(stderr, "%u\n", audioFile.getSampleRate()); + fprintf(stderr, "%d\n", audFile.getNumChannels()); + // CHECK: 16000 + fprintf(stderr, "%d\n", audFile.getSampleRate()); + // CHECK: -0.000122 + fprintf(stderr, "%f\n", audFile.getSample(0, 3)); + // CHECK: -0.000244 + fprintf(stderr, "%f\n", audFile.getSample(0, 4)); return 0; } diff --git a/tests/Interface/core/CMakeLists.txt b/tests/Interface/core/CMakeLists.txt index c82cb5a283..b84ae71aef 100644 --- a/tests/Interface/core/CMakeLists.txt +++ b/tests/Interface/core/CMakeLists.txt @@ -17,10 +17,27 @@ if(BUDDY_MLIR_ENABLE_DIP_LIB OR BUDDY_ENABLE_OPENCV) ) endif() +if(BUDDY_MLIR_ENABLE_DIP_LIB) + set(NEW_DIP_LIBS "") + if(BUDDY_ENABLE_PNG) + list(APPEND NEW_DIP_LIBS ${PNG_LIBRARIES}) + _add_test_executable(buddy-new-image-container-test-png + NewImageContainerTestPng.cpp + LINK_LIBS + ${NEW_DIP_LIBS} + ) + endif() + _add_test_executable(buddy-new-image-container-test-bmp + NewImageContainerTestBmp.cpp + LINK_LIBS + ${NEW_DIP_LIBS} + ) +endif() + _add_test_executable(buddy-audio-container-test AudioContainerTest.cpp ) _add_test_executable(buddy-text-container-test TextContainerTest.cpp -) \ No newline at end of file +) diff --git a/tests/Interface/core/ContainerTest.cpp b/tests/Interface/core/ContainerTest.cpp index 3d80b3375d..c58ae1249b 100644 --- a/tests/Interface/core/ContainerTest.cpp +++ b/tests/Interface/core/ContainerTest.cpp @@ -66,7 +66,7 @@ int main() { MemRef testArrayNoMallocConstructor(arrayShape, false, 0); // CHECK: {{(nil)|0x0}} fprintf(stderr, "%p\n", 
testArrayNoMallocConstructor.getData()); - + //===--------------------------------------------------------------------===// // Test array constructor. //===--------------------------------------------------------------------===// @@ -77,6 +77,16 @@ int main() { // CHECK: 5.0 fprintf(stderr, "%f\n", testArrayConstructor[5]); + //===--------------------------------------------------------------------===// + // Test 2-D array constructor. + //===--------------------------------------------------------------------===// + float data2D[2][3] = {{0.0, 1.0, 2.0}, {0.0, 1.0, 2.0}}; + MemRef test2DArrayConstructor(&data2D[0][0], sizes); + // CHECK: 0.0 + fprintf(stderr, "%f\n", test2DArrayConstructor.getData()[0]); + // CHECK: 1.0 + fprintf(stderr, "%f\n", test2DArrayConstructor[4]); + //===--------------------------------------------------------------------===// // Test copy constructor and copy assignment operator. //===--------------------------------------------------------------------===// @@ -118,7 +128,6 @@ int main() { //===--------------------------------------------------------------------===// // Test overloading bracket operator. 
//===--------------------------------------------------------------------===// - float data1[6] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}; MemRef testBracketOperator1(data1, sizes); // CHECK: 0.0 diff --git a/tests/Interface/core/ImageContainerTest.cpp b/tests/Interface/core/ImageContainerTest.cpp index 442f79ca6c..f84bc4237f 100644 --- a/tests/Interface/core/ImageContainerTest.cpp +++ b/tests/Interface/core/ImageContainerTest.cpp @@ -24,6 +24,54 @@ #include #include +bool compare_flt(float a, float b) { return (std::abs(a - b) < FLT_EPSILON); } + +template +bool testImgcvnorm(cv::Mat testImgcv, Img testImg, bool norm = false, + intptr_t sizes[N] = nullptr) { + int cvn = testImgcv.dims; + if (cvn != N) + return false; + for (size_t i = 0; i < N; ++i) { + if (testImgcv.size[i] != testImg.getSizes()[i]) + return false; + } + T *data = testImg.getData(); + if (N == 2) { + size_t k = 0; + for (int i = 0; i < testImg.getSizes()[0]; ++i) { + for (int j = 0; j < testImg.getSizes()[1]; ++j) { + if (norm ? !compare_flt(data[k], (T)testImgcv.at(i, j)) + : !compare_flt(data[k], (T)testImgcv.at(i, j))) + return false; + + ++k; + } + } + return true; + } else if (N == 4) { + if (sizes == nullptr) { + return false; + } + size_t k = 0; + // NCHW layout + for (size_t batch = 0; batch < sizes[0]; ++batch) { + for (size_t channel = 0; channel < sizes[1]; ++channel) { + T *chandata = testImgcv.ptr(batch, channel); + for (size_t row = 0; row < sizes[2]; ++row) { + for (size_t col = 0; col < sizes[3]; ++col) { + if (!compare_flt(data[k], chandata[row * sizes[3] + col])) + return false; + + ++k; + } + } + } + } + return true; + } +} + int main() { // The original test image is a gray scale image, and the pixel values are as // follows: @@ -33,7 +81,7 @@ int main() { // 195.0, 210.0, 225.0, 240.0 // The test running directory is in /tests/Interface/core, so the // `imread` function uses the following relative path. 
- + //===--------------------------------------------------------------------===// // Test bmp format image. //===--------------------------------------------------------------------===// @@ -75,12 +123,10 @@ int main() { fprintf(stderr, "%ld\n", testCopyConstructor2.getSize()); // CHECK: 60.0 fprintf(stderr, "%f\n", testCopyConstructor2[3]); - Img testCopyConstructor3 = - Img(grayimage_bmp); + Img testCopyConstructor3 = Img(grayimage_bmp); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor3[0]); - Img *testCopyConstructor4 = - new Img(grayimage_bmp); + Img *testCopyConstructor4 = new Img(grayimage_bmp); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor4->getData()[0]); delete testCopyConstructor4; @@ -132,7 +178,50 @@ int main() { const Img testBracketOperator2(grayimage_bmp); // CHECK: 240.0 fprintf(stderr, "%f\n", testBracketOperator2[15]); + //===--------------------------------------------------------------------===// + // Test Opencv Image without norm + //===--------------------------------------------------------------------===// + cv::Mat testImgcvbmp = + cv::imread("../../../../tests/Interface/core/TestGrayImage.bmp", + cv::IMREAD_GRAYSCALE); + Img testImgbmp(testImgcvbmp); + bool testbmp = testImgcvnorm(testImgcvbmp, testImgbmp); + // CHECK: 1 + fprintf(stderr, "%d \n", testbmp); + //===--------------------------------------------------------------------===// + // Test Opencv Image with norm + //===--------------------------------------------------------------------===// + Img testImgbmpnorm(testImgcvbmp, nullptr, true); + cv::Mat checkimgbmp(testImgcvbmp.rows, testImgcvbmp.cols, CV_32FC1); + testImgcvbmp.convertTo(checkimgbmp, CV_32FC1, 1.f / 255); + bool testbmp1 = testImgcvnorm(checkimgbmp, testImgbmpnorm, true); + // CHECK: 1 + fprintf(stderr, "%d \n", testbmp1); + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) without norm (NCHW) + 
//===--------------------------------------------------------------------===// + std::vector testbmpvec = {testImgcvbmp, testImgcvbmp}; + cv::Mat testcvbmpblob = cv::dnn::blobFromImages( + testbmpvec, 1.0, cv::Size(testImgcvbmp.rows, testImgcvbmp.cols)); + intptr_t sizesbmp[4] = {testcvbmpblob.size[0], testcvbmpblob.size[1], + testcvbmpblob.size[2], testcvbmpblob.size[3]}; + Img testImgbmpblob(testcvbmpblob, sizesbmp, false); + bool testbmpN4 = + testImgcvnorm(testcvbmpblob, testImgbmpblob, false, sizesbmp); + // CHECK: 1 + fprintf(stderr, "%d \n", testbmpN4); + + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) with norm (NCHW) + //===--------------------------------------------------------------------===// + cv::Mat testcvbmpblob2 = cv::dnn::blobFromImages( + testbmpvec, 1.0f / 255.0, cv::Size(testImgcvbmp.rows, testImgcvbmp.cols)); + Img testImgbmpblobnorm(testcvbmpblob, sizesbmp, true); + bool testbmpN4norm = testImgcvnorm( + testcvbmpblob2, testImgbmpblobnorm, true, sizesbmp); + // CHECK: 1 + fprintf(stderr, "%d \n", testbmpN4norm); //===--------------------------------------------------------------------===// // Test jpeg format image. 
@@ -175,12 +264,10 @@ int main() { fprintf(stderr, "%ld\n", testCopyConstructor6.getSize()); // CHECK: 60.0 fprintf(stderr, "%f\n", testCopyConstructor6[3]); - Img testCopyConstructor7 = - Img(grayimage_jpg); + Img testCopyConstructor7 = Img(grayimage_jpg); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor7[0]); - Img *testCopyConstructor8 = - new Img(grayimage_jpg); + Img *testCopyConstructor8 = new Img(grayimage_jpg); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor8->getData()[0]); delete testCopyConstructor8; @@ -233,6 +320,51 @@ int main() { // CHECK: 240.0 fprintf(stderr, "%f\n", testBracketOperator4[15]); + //===--------------------------------------------------------------------===// + // Test Opencv Image without norm + //===--------------------------------------------------------------------===// + cv::Mat testImgcvjpg = + cv::imread("../../../../tests/Interface/core/TestGrayImage.jpg", + cv::IMREAD_GRAYSCALE); + Img testImgjpg(testImgcvjpg); + bool testjpg = testImgcvnorm(testImgcvjpg, testImgjpg); + // CHECK: 1 + fprintf(stderr, "%d \n", testjpg); + + //===--------------------------------------------------------------------===// + // Test Opencv Image with norm + //===--------------------------------------------------------------------===// + Img testImgjpgnorm(testImgcvjpg, nullptr, true); + cv::Mat checkimgjpg(testImgcvjpg.rows, testImgcvjpg.cols, CV_32FC1); + testImgcvjpg.convertTo(checkimgjpg, CV_32FC1, 1.f / 255); + bool testjpg1 = testImgcvnorm(checkimgjpg, testImgjpgnorm, true); + // CHECK: 1 + fprintf(stderr, "%d \n", testjpg1); + + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) without norm (NCHW) + //===--------------------------------------------------------------------===// + std::vector testjpgvec = {testImgcvjpg, testImgcvjpg}; + cv::Mat testcvjpgblob = cv::dnn::blobFromImages( + testjpgvec, 1.0, cv::Size(testImgcvjpg.rows, testImgcvjpg.cols)); 
+ intptr_t sizesjpg[4] = {testcvjpgblob.size[0], testcvjpgblob.size[1], + testcvjpgblob.size[2], testcvjpgblob.size[3]}; + Img testImgjpgblob(testcvjpgblob, sizesjpg, false); + bool testjpgN4 = + testImgcvnorm(testcvjpgblob, testImgjpgblob, false, sizesjpg); + // CHECK: 1 + fprintf(stderr, "%d \n", testjpgN4); + + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) with norm (NCHW) + //===--------------------------------------------------------------------===// + cv::Mat testcvjpgblob2 = cv::dnn::blobFromImages( + testjpgvec, 1.0f / 255.0, cv::Size(testImgcvjpg.rows, testImgcvjpg.cols)); + Img testImgjpgblobnorm(testcvjpgblob, sizesjpg, true); + bool testjpgN4norm = testImgcvnorm( + testcvjpgblob2, testImgjpgblobnorm, true, sizesjpg); + // CHECK: 1 + fprintf(stderr, "%d \n", testjpgN4norm); //===--------------------------------------------------------------------===// // Test png format image. @@ -275,12 +407,10 @@ int main() { fprintf(stderr, "%ld\n", testCopyConstructor10.getSize()); // CHECK: 60.0 fprintf(stderr, "%f\n", testCopyConstructor10[3]); - Img testCopyConstructor11 = - Img(grayimage_png); + Img testCopyConstructor11 = Img(grayimage_png); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor11[0]); - Img *testCopyConstructor12 = - new Img(grayimage_png); + Img *testCopyConstructor12 = new Img(grayimage_png); // CHECK: 15.0 fprintf(stderr, "%f\n", testCopyConstructor12->getData()[0]); delete testCopyConstructor12; @@ -332,6 +462,52 @@ int main() { const Img testBracketOperator6(grayimage_png); // CHECK: 240.0 fprintf(stderr, "%f\n", testBracketOperator6[15]); - + + //===--------------------------------------------------------------------===// + // Test Opencv Image without norm + //===--------------------------------------------------------------------===// + cv::Mat testImgcvpng = + cv::imread("../../../../tests/Interface/core/TestGrayImage.png", + cv::IMREAD_GRAYSCALE); + Img 
testImgpng(testImgcvpng); + bool testpng = testImgcvnorm(testImgcvpng, testImgpng); + /// CHECK: 1 + fprintf(stderr, "%d \n", testpng); + + //===--------------------------------------------------------------------===// + // Test Opencv Image with norm + //===--------------------------------------------------------------------===// + Img testImgpngnorm(testImgcvpng, nullptr, true); + cv::Mat checkimgpng(testImgcvpng.rows, testImgcvpng.cols, CV_32FC1); + testImgcvpng.convertTo(checkimgpng, CV_32FC1, 1.f / 255); + bool testpng1 = testImgcvnorm(checkimgpng, testImgpngnorm, true); + // CHECK: 1 + fprintf(stderr, "%d \n", testpng1); + + ///===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) without norm (NCHW) + //===--------------------------------------------------------------------===// + std::vector testpngvec = {testImgcvpng, testImgcvpng}; + cv::Mat testcvpngblob = cv::dnn::blobFromImages( + testpngvec, 1.0, cv::Size(testImgcvpng.rows, testImgcvpng.cols)); + intptr_t sizespng[4] = {testcvpngblob.size[0], testcvpngblob.size[1], + testcvpngblob.size[2], testcvpngblob.size[3]}; + Img testImgpngblob(testcvpngblob, sizespng, false); + bool testpngN4 = + testImgcvnorm(testcvpngblob, testImgpngblob, false, sizespng); + // CHECK: 1 + fprintf(stderr, "%d \n", testpngN4); + + //===--------------------------------------------------------------------===// + // Test Opencv blob Image (batched images) with norm (NCHW) + //===--------------------------------------------------------------------===// + cv::Mat testcvpngblob2 = cv::dnn::blobFromImages( + testpngvec, 1.0f / 255.0, cv::Size(testImgcvpng.rows, testImgcvpng.cols)); + Img testImgpngblobnorm(testcvpngblob, sizespng, true); + bool testpngN4norm = testImgcvnorm( + testcvpngblob2, testImgpngblobnorm, true, sizespng); + // CHECK: 1 + fprintf(stderr, "%d \n", testpngN4norm); + return 0; -} \ No newline at end of file +} diff --git 
a/tests/Interface/core/NewImageContainerTestBmp.cpp b/tests/Interface/core/NewImageContainerTestBmp.cpp new file mode 100644 index 0000000000..13f1a9c7cf --- /dev/null +++ b/tests/Interface/core/NewImageContainerTestBmp.cpp @@ -0,0 +1,171 @@ +//===- NewImageContainerTestBmp.cpp ---------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the image container test file. +// +//===----------------------------------------------------------------------===// + +// RUN: buddy-new-image-container-test-bmp 2>&1 | FileCheck %s + +#include + +int main() { + //===--------------------------------------------------------------------===// + // Test new image container - bmp format image. 
+ //===--------------------------------------------------------------------===// + // Default Gray Scale + dip::Image bmp32bitGrayDefault( + "../../../../tests/Interface/core/TestImage-gray.bmp", + dip::DIP_GRAYSCALE); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp32bitGrayDefault.getFormatName().c_str()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp32bitGrayDefault.getWidth()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp32bitGrayDefault.getHeight()); + // CHECK: 32 + fprintf(stderr, "%d\n", bmp32bitGrayDefault.getBitDepth()); + // CHECK: 7 + fprintf(stderr, "%f\n", bmp32bitGrayDefault.getData()[0]); + // Gray Scale + Normalization + dip::Image bmp32bitGrayNorm( + "../../../../tests/Interface/core/TestImage-gray.bmp", dip::DIP_GRAYSCALE, + true /* norm */); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp32bitGrayNorm.getFormatName().c_str()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp32bitGrayNorm.getWidth()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp32bitGrayNorm.getHeight()); + // CHECK: 32 + fprintf(stderr, "%d\n", bmp32bitGrayNorm.getBitDepth()); + // CHECK: 0.027451 + fprintf(stderr, "%f\n", bmp32bitGrayNorm.getData()[0]); + + // BMP 24bit Default Gray Scale + dip::Image bmp24bitGrayDefault( + "../../../../tests/Interface/core/TestImage-gray-24bit.bmp", + dip::DIP_GRAYSCALE); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp24bitGrayDefault.getFormatName().c_str()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp24bitGrayDefault.getWidth()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp24bitGrayDefault.getHeight()); + // CHECK: 24 + fprintf(stderr, "%d\n", bmp24bitGrayDefault.getBitDepth()); + // CHECK: 7 + fprintf(stderr, "%f\n", bmp24bitGrayDefault.getData()[0]); + // BMP 24bit Gray Scale + Normalization + dip::Image bmp24bitGrayNorm( + "../../../../tests/Interface/core/TestImage-gray-24bit.bmp", + dip::DIP_GRAYSCALE, true /* norm */); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp24bitGrayNorm.getFormatName().c_str()); + // CHECK: 28 + fprintf(stderr, 
"%ld\n", bmp24bitGrayNorm.getWidth()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp24bitGrayNorm.getHeight()); + // CHECK: 24 + fprintf(stderr, "%d\n", bmp24bitGrayNorm.getBitDepth()); + // CHECK: 0.027451 + fprintf(stderr, "%f\n", bmp24bitGrayNorm.getData()[0]); + + // BMP 16bit Default Gray Scale + dip::Image bmp16bitGrayDefault( + "../../../../tests/Interface/core/TestImage-gray-16bit-rgb565.bmp", + dip::DIP_GRAYSCALE); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp16bitGrayDefault.getFormatName().c_str()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp16bitGrayDefault.getWidth()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp16bitGrayDefault.getHeight()); + // CHECK: 16 + fprintf(stderr, "%d\n", bmp16bitGrayDefault.getBitDepth()); + // CHECK: 2 + fprintf(stderr, "%f\n", bmp16bitGrayDefault.getData()[0]); + // BMP 16bit Gray Scale + Normalization + dip::Image bmp16bitGrayNorm( + "../../../../tests/Interface/core/TestImage-gray-16bit-rgb565.bmp", + dip::DIP_GRAYSCALE, true /* norm */); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp16bitGrayNorm.getFormatName().c_str()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp16bitGrayNorm.getWidth()); + // CHECK: 28 + fprintf(stderr, "%ld\n", bmp16bitGrayNorm.getHeight()); + // CHECK: 16 + fprintf(stderr, "%d\n", bmp16bitGrayNorm.getBitDepth()); + // CHECK: 0.007843 + fprintf(stderr, "%f\n", bmp16bitGrayNorm.getData()[0]); + + dip::Image bmp32bitRGBDefault( + "../../../../tests/Interface/core/TestImage-RGB-32bit.bmp", dip::DIP_RGB); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp32bitRGBDefault.getFormatName().c_str()); + // CHECK: 224 + fprintf(stderr, "%ld\n", bmp32bitRGBDefault.getWidth()); + // CHECK: 224 + fprintf(stderr, "%ld\n", bmp32bitRGBDefault.getHeight()); + // CHECK: 32 + fprintf(stderr, "%d\n", bmp32bitRGBDefault.getBitDepth()); + // CHECK: 116 + fprintf(stderr, "%f\n", bmp32bitRGBDefault.getData()[0]); + + dip::Image bmp32bitRGBNorm( + "../../../../tests/Interface/core/TestImage-RGB-32bit.bmp", dip::DIP_RGB, + 
true); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp32bitRGBNorm.getFormatName().c_str()); + // CHECK: 224 + fprintf(stderr, "%ld\n", bmp32bitRGBNorm.getWidth()); + // CHECK: 224 + fprintf(stderr, "%ld\n", bmp32bitRGBNorm.getHeight()); + // CHECK: 32 + fprintf(stderr, "%d\n", bmp32bitRGBNorm.getBitDepth()); + // CHECK: 0.45490 + fprintf(stderr, "%f\n", bmp32bitRGBNorm.getData()[0]); + + dip::Image bmp24bitRGBDefault( + "../../../../tests/Interface/core/TestImage-RGB-24bit.bmp", dip::DIP_RGB); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp24bitRGBDefault.getFormatName().c_str()); + // CHECK: 224 + fprintf(stderr, "%ld\n", bmp24bitRGBDefault.getWidth()); + // CHECK: 224 + fprintf(stderr, "%ld\n", bmp24bitRGBDefault.getHeight()); + // CHECK: 24 + fprintf(stderr, "%d\n", bmp24bitRGBDefault.getBitDepth()); + // CHECK: 116 + fprintf(stderr, "%f\n", bmp24bitRGBDefault.getData()[0]); + + dip::Image bmp24bitRGBNorm( + "../../../../tests/Interface/core/TestImage-RGB-24bit.bmp", dip::DIP_RGB, + true); + // CHECK: BMP + fprintf(stderr, "%s\n", bmp24bitRGBNorm.getFormatName().c_str()); + // CHECK: 224 + fprintf(stderr, "%ld\n", bmp24bitRGBNorm.getWidth()); + // CHECK: 224 + fprintf(stderr, "%ld\n", bmp24bitRGBNorm.getHeight()); + // CHECK: 24 + fprintf(stderr, "%d\n", bmp24bitRGBNorm.getBitDepth()); + // CHECK: 0.45490 + fprintf(stderr, "%f\n", bmp24bitRGBNorm.getData()[0]); + + return 0; +} diff --git a/tests/Interface/core/NewImageContainerTestPng.cpp b/tests/Interface/core/NewImageContainerTestPng.cpp new file mode 100644 index 0000000000..0f1dea37c3 --- /dev/null +++ b/tests/Interface/core/NewImageContainerTestPng.cpp @@ -0,0 +1,82 @@ +//===- NewImageContainerTestPng.cpp ---------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This is the image container test file. +// +//===----------------------------------------------------------------------===// + +// RUN: buddy-new-image-container-test-png 2>&1 | FileCheck %s + +#include + +int main() { + // Default Gray Scale + dip::Image pngGrayDefault( + "../../../../tests/Interface/core/TestGrayImage.png", dip::DIP_GRAYSCALE); + // CHECK: PNG + fprintf(stderr, "%s\n", pngGrayDefault.getFormatName().c_str()); + // CHECK: 4 + fprintf(stderr, "%ld\n", pngGrayDefault.getWidth()); + // CHECK: 4 + fprintf(stderr, "%ld\n", pngGrayDefault.getHeight()); + // CHECK: 8 + fprintf(stderr, "%d\n", pngGrayDefault.getBitDepth()); + // CHECK: 15 + fprintf(stderr, "%f\n", pngGrayDefault.getData()[0]); + // Gray Scale + Normalization + dip::Image pngGrayNorm( + "../../../../tests/Interface/core/TestGrayImage.png", dip::DIP_GRAYSCALE, + true /* norm */); + // CHECK: PNG + fprintf(stderr, "%s\n", pngGrayNorm.getFormatName().c_str()); + // CHECK: 4 + fprintf(stderr, "%ld\n", pngGrayNorm.getWidth()); + // CHECK: 4 + fprintf(stderr, "%ld\n", pngGrayNorm.getHeight()); + // CHECK: 8 + fprintf(stderr, "%d\n", pngGrayNorm.getBitDepth()); + // CHECK: 0.058824 + fprintf(stderr, "%f\n", pngGrayNorm.getData()[0]); + + dip::Image pngRGBDefault( + "../../../../tests/Interface/core/TestImage-RGB.png", dip::DIP_RGB); + // CHECK: PNG + fprintf(stderr, "%s\n", pngRGBDefault.getFormatName().c_str()); + // CHECK: 224 + fprintf(stderr, "%ld\n", pngRGBDefault.getWidth()); + // 
CHECK: 224 + fprintf(stderr, "%ld\n", pngRGBDefault.getHeight()); + // CHECK: 8 + fprintf(stderr, "%d\n", pngRGBDefault.getBitDepth()); + // CHECK: 144 + fprintf(stderr, "%f\n", pngRGBDefault.getData()[0]); + + dip::Image pngRGBNorm( + "../../../../tests/Interface/core/TestImage-RGB.png", dip::DIP_RGB, + true /* norm */); + // CHECK: PNG + fprintf(stderr, "%s\n", pngRGBNorm.getFormatName().c_str()); + // CHECK: 224 + fprintf(stderr, "%ld\n", pngRGBNorm.getWidth()); + // CHECK: 224 + fprintf(stderr, "%ld\n", pngRGBNorm.getHeight()); + // CHECK: 8 + fprintf(stderr, "%d\n", pngRGBNorm.getBitDepth()); + // CHECK: 0.5647 + fprintf(stderr, "%f\n", pngRGBNorm.getData()[0]); + + return 0; +} diff --git a/tests/Interface/core/TestAudio.wav b/tests/Interface/core/TestAudio.wav new file mode 100644 index 0000000000..069c2329ef Binary files /dev/null and b/tests/Interface/core/TestAudio.wav differ diff --git a/tests/Interface/core/TestImage-RGB-24bit.bmp b/tests/Interface/core/TestImage-RGB-24bit.bmp new file mode 100644 index 0000000000..948a1ea796 Binary files /dev/null and b/tests/Interface/core/TestImage-RGB-24bit.bmp differ diff --git a/tests/Interface/core/TestImage-RGB-32bit.bmp b/tests/Interface/core/TestImage-RGB-32bit.bmp new file mode 100644 index 0000000000..c415c8dc32 Binary files /dev/null and b/tests/Interface/core/TestImage-RGB-32bit.bmp differ diff --git a/tests/Interface/core/TestImage-RGB.png b/tests/Interface/core/TestImage-RGB.png new file mode 100644 index 0000000000..e5f89a6ee4 Binary files /dev/null and b/tests/Interface/core/TestImage-RGB.png differ diff --git a/tests/Interface/core/TestImage-gray-16bit-rgb565.bmp b/tests/Interface/core/TestImage-gray-16bit-rgb565.bmp new file mode 100644 index 0000000000..d4a43393d3 Binary files /dev/null and b/tests/Interface/core/TestImage-gray-16bit-rgb565.bmp differ diff --git a/tests/Interface/core/TestImage-gray-24bit.bmp b/tests/Interface/core/TestImage-gray-24bit.bmp new file mode 100644 index 
0000000000..6591e87be8 Binary files /dev/null and b/tests/Interface/core/TestImage-gray-24bit.bmp differ diff --git a/tests/Interface/core/TestImage-gray.bmp b/tests/Interface/core/TestImage-gray.bmp new file mode 100644 index 0000000000..7a9e02a295 Binary files /dev/null and b/tests/Interface/core/TestImage-gray.bmp differ diff --git a/tests/Interface/core/lit.local.cfg b/tests/Interface/core/lit.local.cfg index 630d46bf16..f5c2722550 100644 --- a/tests/Interface/core/lit.local.cfg +++ b/tests/Interface/core/lit.local.cfg @@ -1,2 +1,9 @@ if config.buddy_enable_opencv != 'ON': config.excludes.add('ImageContainerTest.cpp') + +if config.buddy_mlir_enable_dip_lib != 'ON': + config.excludes.add('NewImageContainerTestBmp.cpp') + config.excludes.add('NewImageContainerTestPng.cpp') + +if config.buddy_enable_png != 'ON': + config.excludes.add('NewImageContainerTestPng.cpp') diff --git a/tests/Python/test_verbose_mode.py b/tests/Python/test_verbose_mode.py index 82279ac2ed..e96c177652 100644 --- a/tests/Python/test_verbose_mode.py +++ b/tests/Python/test_verbose_mode.py @@ -23,6 +23,7 @@ def foo(x, y): # Test the dynamo compiler verbose mode. dynamo_compiler_verbose_on = DynamoCompiler(verbose=True) graphs = dynamo_compiler_verbose_on.importer(foo, *(float32_in1, float32_in2)) +graphs[0].lower_to_top_level_ir() # Test output in the verbose mode. 
# CHECK: placeholder @@ -31,9 +32,36 @@ def foo(x, y): # CHECK: call_function # CHECK: output +# CHECK: ====================Graph Node==================== +# CHECK: Node: mul +# CHECK: Type: OpType.BroadcastType +# CHECK: Arguments: ['arg0_1', 'arg1_1'] +# CHECK: Parents: ['arg0_1', 'arg1_1'] +# CHECK: Children: ['add'] +# CHECK: --------------------MLIR OPS-------------------- +# CHECK: %{{.*}} = "tosa.mul" + +# CHECK: ====================Graph Node==================== +# CHECK: Node: add +# CHECK: Type: OpType.BroadcastType +# CHECK: Arguments: ['mul', 'arg0_1'] +# CHECK: Parents: ['mul', 'arg0_1'] +# CHECK: Children: ['output'] +# CHECK: --------------------MLIR OPS-------------------- +# CHECK: %{{.*}} = "tosa.add" + +# CHECK: ====================Graph Node==================== +# CHECK: Node: output +# CHECK: Type: OpType.GetItemType +# CHECK: Arguments: ['add'] +# CHECK: Parents: [] +# CHECK: Children: [] +# CHECK: --------------------MLIR OPS-------------------- + # Test the dynamo compiler verbose mode off. dynamo_compiler_verbose_off = DynamoCompiler(verbose=False) graphs = dynamo_compiler_verbose_off.importer(foo, *(float32_in1, float32_in2)) +graphs[0].lower_to_top_level_ir() # Ensure no output is printed when the verbose mode is off. # CHECK-NOT: . 
diff --git a/tests/lit.cfg.py b/tests/lit.cfg.py index 4cf5e245f7..2982e2851f 100644 --- a/tests/lit.cfg.py +++ b/tests/lit.cfg.py @@ -107,4 +107,9 @@ if config.buddy_enable_opencv == "ON": tools.append("buddy-image-container-test") +if config.buddy_mlir_enable_dip_lib == "ON": + tools.append("buddy-new-image-container-test-bmp") + if config.buddy_enable_png == "ON": + tools.append("buddy-new-image-container-test-png") + llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tests/lit.site.cfg.py.in b/tests/lit.site.cfg.py.in index 6a5e5f37e3..0011f056f7 100644 --- a/tests/lit.site.cfg.py.in +++ b/tests/lit.site.cfg.py.in @@ -28,12 +28,14 @@ config.enable_libcxx = "@LLVM_ENABLE_LIBCXX@" config.host_ldflags = '@HOST_LDFLAGS@' config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' -config.llvm_build_dir = "@LLVM_PROJECT_BUILD_DIR@" +config.llvm_build_dir = "@LLVM_BINARY_DIR@" config.host_arch = "@HOST_ARCH@" config.buddy_src_root = "@CMAKE_SOURCE_DIR@" config.buddy_obj_root = "@CMAKE_BINARY_DIR@" config.buddy_tools_dir = "@BUDDY_BINARY_DIR@" config.buddy_enable_opencv = "@BUDDY_ENABLE_OPENCV@" +config.buddy_enable_png = "@BUDDY_ENABLE_PNG@" +config.buddy_mlir_enable_dip_lib = "@BUDDY_MLIR_ENABLE_DIP_LIB@" config.buddy_mlir_enable_python_packages = "@BUDDY_MLIR_ENABLE_PYTHON_PACKAGES@" config.buddy_python_packages_dir = "@BUDDY_MLIR_PYTHON_PACKAGES_DIR@" config.mlir_runner_utils_dir = "@LLVM_LIBS_DIR@" diff --git a/tools/buddy-llc/CMakeLists.txt b/tools/buddy-llc/CMakeLists.txt index 06b2a9c32f..cac8ac20a0 100644 --- a/tools/buddy-llc/CMakeLists.txt +++ b/tools/buddy-llc/CMakeLists.txt @@ -29,7 +29,7 @@ set(LLVM_LINK_COMPONENTS ) add_llvm_tool(buddy-llc - ${LLVM_PROJECT_SOURCE_DIR}/llvm/tools/llc/llc.cpp + ${LLVM_MAIN_SRC_DIR}/tools/llc/llc.cpp DEPENDS buddy_intrinsics_gen diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt index 69a897cef2..0abb857fad 100644 --- 
a/tools/buddy-opt/CMakeLists.txt +++ b/tools/buddy-opt/CMakeLists.txt @@ -20,6 +20,7 @@ target_link_libraries(buddy-opt LowerDIPPass BuddyDAP LowerDAPPass + ExtendDAPPass DAPVectorization BuddyRVV LowerRVVPass @@ -28,6 +29,7 @@ target_link_libraries(buddy-opt MatMulParallelVectorization TransposeOptimization ConvOptimization + DepthwiseConvOptimization VectorExp LowerVectorExpPass BuddyGemmini @@ -36,4 +38,17 @@ target_link_libraries(buddy-opt SchedulingOnDevices LowerSche FuncBufferizeDynamicOffset + MLIRGPUPasses + BuddyGPUTransformOPs + MLIRTestTransforms + MLIRTestTransformDialect + MLIRTransforms + MLIRTransformUtils + MatMulTransposeBVec + MLIRGPUPasses + BuddyGPUTransformOPs + MLIRTestTransforms + MLIRTestTransformDialect + MLIRTransforms + MLIRTransformUtils ) diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index 5a1c286344..08e172f8bc 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -40,36 +40,49 @@ #include "DAP/DAPOps.h" #include "DIP/DIPDialect.h" #include "DIP/DIPOps.h" -#include "RVV/RVVDialect.h" -#include "VectorExp/VectorExpDialect.h" -#include "VectorExp/VectorExpOps.h" +#include "GPU/TransformOps.h" #include "Gemmini/GemminiDialect.h" #include "Gemmini/GemminiOps.h" +#include "RVV/RVVDialect.h" #include "Sche/ScheDialect.h" #include "Sche/ScheOps.h" +#include "VectorExp/VectorExpDialect.h" +#include "VectorExp/VectorExpOps.h" namespace mlir { namespace buddy { void registerConvVectorizationPass(); void registerPointwiseConvToGemmPass(); +void registerPointwiseConvToGemmForNhwcFhwcPass(); void registerPoolingVectorizationPass(); void registerLowerBudPass(); void registerLowerDIPPass(); +void registerBatchMatMulOptimizePass(); +void registerBatchMatMulTileOptimizePass(); +void registerBatchMatMuSCFOptimize(); void registerLowerDAPPass(); +void registerExtendDAPPass(); void registerDAPVectorizePass(); void registerLowerRVVPass(); -void registerBatchMatMulOptimizePass(); void 
registerMatMulOptimizePass(); void registerMatMulVectorizationPass(); void registerMatMulParallelVectorizationPass(); void registerTransposeOptimizationPass(); void registerConvOptimizePass(); +void registerConvNhwcFhwcOptimizePass(); +void registerConvNhwcFhwcTileOptimizePass(); +void registerDepthwiseConv2DNhwcHwcOptimizePass(); void registerLowerVectorExpPass(); void registerLowerGemminiPass(); void registerLowerLinalgToGemminiPass(); void registerDeviceSchedulePass(); void registerLowerSchePass(); void registerFuncBufferizeDynamicOffsetPass(); +void registerConvertMemcpyToGPUPass(); +void registerLegalizeShmemOutliningPass(); +void registerMatMulTransposeBVecPass(); +void registerConvertMemcpyToGPUPass(); +void registerLegalizeShmemOutliningPass(); } // namespace buddy } // namespace mlir @@ -84,6 +97,7 @@ int main(int argc, char **argv) { mlir::buddy::registerLowerBudPass(); mlir::buddy::registerLowerDIPPass(); mlir::buddy::registerLowerDAPPass(); + mlir::buddy::registerExtendDAPPass(); // Register Vectorization of DAP Dialect. mlir::buddy::registerDAPVectorizePass(); mlir::buddy::registerLowerRVVPass(); @@ -93,14 +107,24 @@ int main(int argc, char **argv) { // Register Several Optimize Pass. 
mlir::buddy::registerMatMulOptimizePass(); + mlir::buddy::registerBatchMatMulOptimizePass(); + mlir::buddy::registerBatchMatMulTileOptimizePass(); + mlir::buddy::registerBatchMatMuSCFOptimize(); mlir::buddy::registerMatMulVectorizationPass(); mlir::buddy::registerMatMulParallelVectorizationPass(); - mlir::buddy::registerBatchMatMulOptimizePass(); mlir::buddy::registerTransposeOptimizationPass(); mlir::buddy::registerConvOptimizePass(); + mlir::buddy::registerConvNhwcFhwcOptimizePass(); + mlir::buddy::registerConvNhwcFhwcTileOptimizePass(); + mlir::buddy::registerDepthwiseConv2DNhwcHwcOptimizePass(); mlir::buddy::registerDeviceSchedulePass(); mlir::buddy::registerLowerSchePass(); mlir::buddy::registerFuncBufferizeDynamicOffsetPass(); + mlir::buddy::registerMatMulTransposeBVecPass(); + + // Register gpu passes + mlir::buddy::registerConvertMemcpyToGPUPass(); + mlir::buddy::registerLegalizeShmemOutliningPass(); mlir::DialectRegistry registry; // Register all MLIR core dialects. @@ -117,6 +141,8 @@ int main(int argc, char **argv) { buddy::sche::ScheDialect>(); // clang-format on + mlir::buddy::registerBuddyGPUTransformOps(registry); + return mlir::failed( mlir::MlirOptMain(argc, argv, "buddy-mlir optimizer driver", registry)); }