Skip to content

Commit

Permalink
Merge branch 'buddy-compiler:main' into LeNet-GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
WuXintong123 authored Oct 27, 2024
2 parents 82b92f8 + 2b2a8df commit 0c473dc
Show file tree
Hide file tree
Showing 45 changed files with 5,449 additions and 504 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ if(BUDDY_MLIR_ENABLE_DIP_LIB)
find_package(PNG REQUIRED)
endif()

# Optional PNG support: when BUDDY_ENABLE_PNG is set, add the
# -DBUDDY_ENABLE_PNG compile definition and require the PNG package.
if(BUDDY_ENABLE_PNG)
add_definitions(-DBUDDY_ENABLE_PNG)
find_package(PNG REQUIRED)
endif()

# Generate libraries into `lib` of build directory.
set(LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)

Expand Down
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build
$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH}
```

To configure the build environment for using image processing libraries, follow these steps:

```
$ cmake -G Ninja .. \
-DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \
-DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DCMAKE_BUILD_TYPE=RELEASE \
-DBUDDY_MLIR_ENABLE_DIP_LIB=ON \
-DBUDDY_ENABLE_PNG=ON
$ ninja
$ ninja check-buddy
```

To build buddy-mlir with custom LLVM sources:

```
Expand Down
75 changes: 75 additions & 0 deletions examples/BuddyMatmul/linalg-transposematmulb-f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// RUN: buddy-opt %s \
// RUN: -matmul-transpose-b-vectorization \
// RUN: -convert-linalg-to-affine-loops \
// RUN: -lower-affine \
// RUN: -convert-vector-to-scf \
// RUN: -convert-scf-to-cf \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-math-to-llvm \
// RUN: -convert-math-to-libm \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-func-to-llvm \
// RUN: -expand-strided-metadata \
// RUN: -finalize-memref-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

func.func private @printMemrefF32(memref<*xf32>)

// Kernel under test: applies linalg.matmul_transpose_b to inputs %a and %b
// with %c as the output operand. All three operands are dynamically shaped
// 2-D f32 memrefs.
func.func @test(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>) {
linalg.matmul_transpose_b
ins(%a, %b: memref<?x?xf32>, memref<?x?xf32>)
outs(%c: memref<?x?xf32>)
return
}

// Allocates an %arg0 x %arg1 f32 memref and stores %arg4 into every element.
// The caller receives the freshly allocated buffer; this function does not
// deallocate it.
func.func @alloc_f32(%arg0: index, %arg1: index, %arg4: f32) -> memref<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
// Row-major fill: outer loop over dimension 0, inner loop over dimension 1.
scf.for %idx0 = %c0 to %arg0 step %c1 {
scf.for %idx1 = %c0 to %arg1 step %c1 {
memref.store %arg4, %0[%idx0, %idx1] : memref<?x?xf32>
}
}
return %0 : memref<?x?xf32>
}

// Test driver: runs @test on two problem sizes and prints each result so the
// FileCheck directives below can verify the printed memref.
func.func @main(){
%c32 = arith.constant 32 : index
%c1024 = arith.constant 1024 : index
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%f1 = arith.constant 1.0 : f32

// Case 1: both inputs are 32x1024 filled with 1.0 and the output is 32x32
// filled with 0.0, so every printed element is expected to be 1024.
%m0 = call @alloc_f32(%c32,%c1024, %f1) : (index, index, f32) -> memref<?x?xf32>
%m1 = call @alloc_f32(%c32,%c1024, %f1) : (index, index, f32) -> memref<?x?xf32>
%m2 = call @alloc_f32(%c32,%c32, %f0) : (index, index, f32) -> memref<?x?xf32>

call @test(%m0, %m1, %m2) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()

// printMemrefF32 takes an unranked memref, hence the cast before printing.
%printed_m2 = memref.cast %m2 : memref<?x?xf32> to memref<*xf32>

// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [32, 32] strides = [32, 1] data =
// CHECK-NEXT: [
// CHECK: [1024{{(, 1024)*}}]
call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> ()

// Case 2: small 3x3 inputs filled with 1.0 and a 3x3 output filled with 0.0,
// so every printed element is expected to be 3.
%m3 = call @alloc_f32(%c3,%c3, %f1) : (index, index, f32) -> memref<?x?xf32>
%m4 = call @alloc_f32(%c3,%c3, %f1) : (index, index, f32) -> memref<?x?xf32>
%m5 = call @alloc_f32(%c3,%c3, %f0) : (index, index, f32) -> memref<?x?xf32>

call @test(%m3, %m4, %m5) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()

%printed_m5 = memref.cast %m5 : memref<?x?xf32> to memref<*xf32>

// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] data =
// CHECK-NEXT: [
// CHECK: [3{{(, 3)*}}]
call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> ()

return
}
18 changes: 18 additions & 0 deletions examples/BuddyMatmul/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,21 @@ linalg-batchmatmul-f32-run:
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

# Run the matmul-transpose-b example: buddy-opt applies the
# matmul-transpose-b vectorization pass and the lowering passes listed below
# to linalg-transposematmulb-f32.mlir, then the result is piped into the MLIR
# CPU runner, which executes `main` with the runner utility libraries loaded.
linalg-matmul-transpose-b-f32-run:
@${BUDDY_OPT} ./linalg-transposematmulb-f32.mlir\
-matmul-transpose-b-vectorization \
-convert-linalg-to-affine-loops \
-lower-affine \
-convert-vector-to-scf \
-convert-scf-to-cf \
-convert-vector-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-arith-to-llvm \
-convert-func-to-llvm \
-expand-strided-metadata \
-finalize-memref-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
1 change: 0 additions & 1 deletion examples/BuddyMobileNetV3/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
add_custom_command(
OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/arg0.data
${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/arg1.data
${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/forward.mlir
${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/subgraph0.mlir
COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyMobileNetV3/buddy-mobilenetv3-import.py
Expand Down
3 changes: 2 additions & 1 deletion examples/BuddyMobileNetV3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ $ cmake -G Ninja .. \
-DCMAKE_BUILD_TYPE=RELEASE \
-DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \
-DPython3_EXECUTABLE=$(which python3) \
-DBUDDY_MLIR_ENABLE_DIP_LIB=ON
-DBUDDY_MLIR_ENABLE_DIP_LIB=ON \
-DBUDDY_ENABLE_PNG=ON
$ ninja
$ ninja check-buddy
```
Expand Down
21 changes: 14 additions & 7 deletions examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,17 @@
"The environment variable 'MOBILENETV3_MODEL_PATH' is not set or is invalid."
)

model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1, pretrained=True)
model = models.mobilenet_v3_small(
weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1, pretrained=True
)
model = model.eval()

# Remove the num_batches_tracked attribute.
for layer in model.modules():
if isinstance(layer, torch.nn.BatchNorm2d):
if hasattr(layer, "num_batches_tracked"):
del layer.num_batches_tracked

# Initialize Dynamo Compiler with specific configurations as an importer.
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
Expand Down Expand Up @@ -68,11 +76,10 @@


float32_param = np.concatenate(
[param.detach().numpy().reshape([-1]) for param in params if param.dtype == torch.float32]
[
param.detach().numpy().reshape([-1])
for param in params
if param.dtype == torch.float32
]
)
float32_param.tofile(Path(current_path) / "arg0.data")

int64_param = np.concatenate(
[param.detach().numpy().reshape([-1]) for param in params if param.dtype == torch.int64]
)
int64_param.tofile(Path(current_path) / "arg1.data")
65 changes: 31 additions & 34 deletions examples/BuddyMobileNetV3/buddy-mobilenetv3-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,43 +33,43 @@ const std::string ImgName = "dog.png";
// Declare the mobilenet C interface.
extern "C" void _mlir_ciface_forward(MemRef<float, 2> *output,
MemRef<float, 1> *arg0,
MemRef<long long, 1> *arg1,
MemRef<float, 4> *input);

/// Print [Log] label in bold blue format.
void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; }

void loadParameters(const std::string &floatParamPath,
const std::string &int64ParamPath,
MemRef<float, 1> &floatParam,
MemRef<long long, 1> &int64Param) {
std::ifstream floatParamFile(floatParamPath, std::ios::in | std::ios::binary);
if (!floatParamFile.is_open()) {
std::string errMsg = "Failed to open float param file: " +
std::filesystem::canonical(floatParamPath).string();
throw std::runtime_error(errMsg);
/// Load parameters into data container.
void loadParameters(const std::string &paramFilePath,
MemRef<float, 1> &params) {
const auto loadStart = std::chrono::high_resolution_clock::now();
// Open the parameter file in binary mode.
std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary);
if (!paramFile.is_open()) {
throw std::runtime_error("[Error] Failed to open params file!");
}
floatParamFile.read(reinterpret_cast<char *>(floatParam.getData()),
floatParam.getSize() * sizeof(float));
if (floatParamFile.fail()) {
throw std::runtime_error("Failed to read float param file");
printLogLabel();
std::cout << "Loading params..." << std::endl;
printLogLabel();
// Print the canonical path of the parameter file.
std::cout << "Params file: " << std::filesystem::canonical(paramFilePath)
<< std::endl;
// Read the parameter data into the provided memory reference.
paramFile.read(reinterpret_cast<char *>(params.getData()),
sizeof(float) * (params.getSize()));
if (paramFile.fail()) {
throw std::runtime_error("Error occurred while reading params file!");
}
floatParamFile.close();

std::ifstream int64ParamFile(int64ParamPath, std::ios::in | std::ios::binary);
if (!int64ParamFile.is_open()) {
std::string errMsg = "Failed to open int64 param file: " +
std::filesystem::canonical(int64ParamPath).string();
throw std::runtime_error(errMsg);
}
int64ParamFile.read(reinterpret_cast<char *>(int64Param.getData()),
int64Param.getSize() * sizeof(long long));
if (int64ParamFile.fail()) {
throw std::runtime_error("Failed to read int64 param file");
}
int64ParamFile.close();
paramFile.close();
const auto loadEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> loadTime =
loadEnd - loadStart;
printLogLabel();
std::cout << "Params load time: " << (double)(loadTime.count()) / 1000
<< "s\n"
<< std::endl;
}


// Softmax function.
void softmax(float *input, size_t size) {
size_t i;
Expand Down Expand Up @@ -124,13 +124,10 @@ int main() {

// Load model parameters from the specified file.
std::string paramsDir = mobilenetDir + "/arg0.data";
std::string intDir = mobilenetDir + "/arg1.data";
MemRef<float, 1> paramsContainerf32({ParamsSize});
MemRef<long long, 1> ParamsContainerInt64({34});
loadParameters(paramsDir, intDir, paramsContainerf32, ParamsContainerInt64);
MemRef<float, 1> paramsContainer({ParamsSize});
loadParameters(paramsDir, paramsContainer);
// Call the forward function of the model.
_mlir_ciface_forward(&output, &paramsContainerf32, &ParamsContainerInt64,
&inputResize);
_mlir_ciface_forward(&output, &paramsContainer, &inputResize);

auto out = output.getData();
softmax(out, 1000);
Expand Down
7 changes: 7 additions & 0 deletions examples/DAPDialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,10 @@ target_link_libraries(buddy-whisper-preprocess
BuddyLibDAP
mlir_c_runner_utils
)

# RFFT example executable, built from RFFT.cpp. It depends on the buddy-opt
# target and links against BuddyLibDAP and mlir_c_runner_utils.
add_executable(buddy-rfft RFFT.cpp)
add_dependencies(buddy-rfft buddy-opt)
target_link_libraries(buddy-rfft
BuddyLibDAP
mlir_c_runner_utils
)
75 changes: 75 additions & 0 deletions examples/DAPDialect/RFFT.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//===- RFFT.cpp - Example of DAP RFFT Operation ---------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//
//
// An example of the RFFT operation used in the Whisper preprocessor.
//
//===----------------------------------------------------------------------===//

#include <buddy/DAP/DAP.h>
#include <chrono>
#include <fstream>
#include <iostream>

#define testLength 840

using namespace dap;
using namespace std;

// Emit the "[Log] " prefix on stdout, wrapped in the ANSI escape sequences
// for bold blue text. No newline is written, so the caller's message follows
// on the same line.
void printLogLabel() {
  static const char *const kLogLabel = "\033[34;1m[Log] \033[0m";
  std::cout << kLogLabel;
}

// Write the RFFT results to a text file.
//
// Creates (or truncates) "whisperPreprocessResultRFFT.txt" in the current
// working directory, writes a three-line title banner, and then writes the
// first `testLength` elements of `outputMemRef`, one value per line.
// If the file cannot be opened, an error is reported on stderr and nothing is
// written.
void printResult(MemRef<double, 1> &outputMemRef) {
  std::ofstream fout("whisperPreprocessResultRFFT.txt");
  // Report an unwritable output file instead of silently dropping the results.
  if (!fout.is_open()) {
    std::cerr << "Failed to open whisperPreprocessResultRFFT.txt for writing."
              << std::endl;
    return;
  }
  // Print title.
  fout << "-----------------------------------------" << '\n';
  fout << "[ Buddy RFFT Result ]" << '\n';
  fout << "-----------------------------------------" << '\n';
  // Print result data. '\n' is used instead of std::endl so the stream is not
  // flushed once per element; the file contents are unchanged.
  for (int i = 0; i < testLength; ++i) {
    fout << outputMemRef[i] << '\n';
  }
  // The stream is flushed and closed by its destructor on return.
}

int main() {
// Print the title of this example.
const std::string title = "RFFT Operation Powered by Buddy Compiler";
std::cout << "\033[33;1m" << title << "\033[0m" << std::endl;

double *inputAlign = new double[testLength];
for (int i = 0; i < testLength; ++i) {
inputAlign[i] = static_cast<double>(i);
}
intptr_t inputSizes[1] = {testLength};
MemRef<double, 1> inputMemRef(inputAlign, inputSizes);

printLogLabel();
std::cout << "Running RFFT operation" << std::endl;
const auto loadStart = std::chrono::high_resolution_clock::now();
dap::RFFT(&inputMemRef);
const auto loadEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> loadTime =
loadEnd - loadStart;
printLogLabel();
std::cout << "RFFT time: " << (double)(loadTime.count()) / 1000
<< "s\n"
<< std::endl;

printResult(inputMemRef);

return 0;
}
Loading

0 comments on commit 0c473dc

Please sign in to comment.