Commit: Merge branch 'main' into poolingnhwcmax

FloatingcloudKnight authored Dec 27, 2024
2 parents b48655f + 76947f0 commit 5131c79
Showing 75 changed files with 55,715 additions and 412 deletions.
2 changes: 1 addition & 1 deletion examples/BuddyBert/bert-main.cpp
@@ -93,7 +93,7 @@ int main() {

/// Execute forward inference of the model.
_mlir_ciface_forward(&result, &arg0, &arg1, &pureStrContainer,
&attention_mask, &token_type_ids);
&token_type_ids, &attention_mask);

const auto inferenceEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> inferenceTime =
58 changes: 57 additions & 1 deletion examples/BuddyLeNet/CMakeLists.txt
@@ -57,6 +57,62 @@ SET_TARGET_PROPERTIES(LENET PROPERTIES LINKER_LANGUAGE C)
add_executable(buddy-lenet-run buddy-lenet-main.cpp)
target_link_directories(buddy-lenet-run PRIVATE ${LLVM_LIBRARY_DIR})

set(BUDDY_LENET_LIBS LENET mlir_c_runner_utils ${PNG_LIBRARIES})
if(NOT DEFINED BUDDY_ENABLE_PNG)
message(FATAL_ERROR "To run LeNet inference, the png library is required. Please define BUDDY_ENABLE_PNG for CMake.")
endif()
set(BUDDY_LENET_LIBS LENET mlir_c_runner_utils mlir_async_runtime mlir_runner_utils ${PNG_LIBRARIES})

target_link_libraries(buddy-lenet-run ${BUDDY_LENET_LIBS})

set(ONE_SHOT_BUFFERIZE_OPTION "bufferize-function-boundaries=1 function-boundary-type-conversion=identity-layout-map")
set(LOWER_TO_NVVM_OPTION "cubin-chip=sm_80 cubin-features=+ptx71 cubin-format=fatbin")
set(CONVERT_MEMCPY_TO_GPU_OPTION "process-args=1")
set(CONVERT_MEMCPY_TO_GPU_OPTION_DISABLE_PROCESS_ARG "process-args=0")

add_custom_command(
OUTPUT forward_gpu.o
COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir
-buffer-deallocation
-canonicalize -cse -expand-strided-metadata -convert-memcpy-to-gpu -gpu-async-region |
${LLVM_TOOLS_BINARY_DIR}/mlir-opt -llvm-request-c-wrappers --gpu-to-llvm |
${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_TOOLS_BINARY_DIR}/llvm-as |
${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/forward_gpu.o
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/forward.mlir
COMMENT "Building forward_gpu.o"
VERBATIM)

add_custom_command(
OUTPUT subgraph0_gpu.o
COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" |
${BUDDY_BINARY_DIR}/buddy-opt
-one-shot-bufferize
-func-bufferize-dynamic-offset
-convert-linalg-to-parallel-loops
-canonicalize
-gpu-map-parallel-loops
-convert-parallel-loops-to-gpu
-gpu-kernel-outlining
-buffer-deallocation
-canonicalize
-cse |
${BUDDY_BINARY_DIR}/buddy-opt -convert-memcpy-to-gpu=${CONVERT_MEMCPY_TO_GPU_OPTION_DISABLE_PROCESS_ARG} -gpu-async-region -canonicalize |
${LLVM_TOOLS_BINARY_DIR}/mlir-opt -llvm-request-c-wrappers --test-lower-to-nvvm=${LOWER_TO_NVVM_OPTION} |
${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_TOOLS_BINARY_DIR}/llvm-as |
${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyLeNet/subgraph0_gpu.o
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyLeNet/subgraph0.mlir
COMMENT "Building subgraph0_gpu.o"
VERBATIM)

add_library(LENET_GPU STATIC subgraph0_gpu.o forward_gpu.o)

SET_TARGET_PROPERTIES(LENET_GPU PROPERTIES LINKER_LANGUAGE C)

add_executable(buddy-lenet-run-gpu buddy-lenet-main.cpp)
target_link_directories(buddy-lenet-run-gpu PRIVATE ${LLVM_LIBRARY_DIR})

set(BUDDY_LENET_LIBS_GPU LENET_GPU mlir_c_runner_utils mlir_async_runtime mlir_runner_utils mlir_cuda_runtime ${PNG_LIBRARIES})

target_link_libraries(buddy-lenet-run-gpu ${BUDDY_LENET_LIBS_GPU})
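
With this change, configuring the LeNet example without libpng fails at CMake time, and a separate GPU runner target is added. A minimal configuration sketch under these assumptions (the cache variables come from the checks above; paths and generator may differ on your machine):

```bash
# From the buddy-mlir build directory, with a CUDA-enabled LLVM build for the GPU path.
$ cmake -G Ninja .. -DBUDDY_LENET_EXAMPLES=ON -DBUDDY_ENABLE_PNG=ON
$ ninja buddy-lenet-run       # CPU executable
$ ninja buddy-lenet-run-gpu   # GPU executable linking LENET_GPU
```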
53 changes: 48 additions & 5 deletions examples/BuddyLeNet/README.md
@@ -12,11 +12,47 @@ $ python pytorch-lenet-train.py

## LeNet Model Inference

0. Activate your python environment.
### Activate your Python environment

1. Build buddy-mlir
```bash
$ conda activate <your env>
```

### Build LLVM

```bash
$ cd buddy-mlir
$ mkdir llvm/build
$ cd llvm/build

# CPU
$ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \
-DLLVM_TARGETS_TO_BUILD="host;RISCV" \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DOPENMP_ENABLE_LIBOMPTARGET=OFF \
-DCMAKE_BUILD_TYPE=RELEASE \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DPython3_EXECUTABLE=$(which python3)

# GPU
$ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS="mlir;clang;openmp" \
-DLLVM_TARGETS_TO_BUILD="host;RISCV;NVPTX" \
-DMLIR_ENABLE_CUDA_RUNNER=ON \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DOPENMP_ENABLE_LIBOMPTARGET=OFF \
-DCMAKE_BUILD_TYPE=RELEASE \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DPython3_EXECUTABLE=$(which python3)

$ ninja check-clang check-mlir omp
```

### Build buddy-mlir

```bash
$ cd buddy-mlir
$ mkdir build && cd build
$ cmake -G Ninja .. \
-DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \
@@ -31,7 +67,7 @@ $ ninja
$ ninja check-buddy
```

2. Set the `PYTHONPATH` environment variable.
### Set the `PYTHONPATH` environment variable

Make sure you are in the build directory.

@@ -41,19 +77,26 @@ $ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build
$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH}
```

3. Set the `LENET_EXAMPLE_PATH` environment variable.
### Set the `LENET_EXAMPLE_PATH` environment variable

```bash
$ export LENET_EXAMPLE_PATH=${BUDDY_MLIR_BUILD_DIR}/../examples/BuddyLeNet/
```

4. Build and run the LeNet example
### Build and run the LeNet example

```bash
$ cmake -G Ninja .. -DBUDDY_LENET_EXAMPLES=ON

# CPU
$ ninja buddy-lenet-run
$ cd bin
$ ./buddy-lenet-run

# GPU
$ ninja buddy-lenet-run-gpu
$ cd bin
$ ./buddy-lenet-run-gpu
```

## Debug the Lowering Pass Pipeline with Fake Parameters
4 changes: 1 addition & 3 deletions examples/BuddyLeNet/buddy-lenet-import.py
@@ -23,7 +23,6 @@

import numpy as np
import torch
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
@@ -39,13 +38,12 @@
)

model = LeNet()
model = torch.load(model_path + "/lenet-model.pth")
model = torch.load(model_path + "/lenet-model.pth", weights_only=False)
model = model.eval()

# Initialize Dynamo Compiler with specific configurations as an importer.
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
)

data = torch.randn([1, 1, 28, 28])
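
A note on the `torch.load` change above: recent PyTorch releases default `weights_only` to `True`, which refuses to unpickle a full `nn.Module`, so loading the saved LeNet module requires `weights_only=False`. A sketch of an alternative that keeps the safer default, assuming the training script is changed to save a state dict rather than the whole module (file name here is hypothetical):

```python
import torch

# Hypothetical save side (e.g., in pytorch-lenet-train.py):
torch.save(model.state_dict(), model_path + "/lenet-weights.pth")

# Load side: rebuild the module, then load weights only.
model = LeNet()
model.load_state_dict(
    torch.load(model_path + "/lenet-weights.pth", weights_only=True)
)
model = model.eval()
```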
1 change: 1 addition & 0 deletions examples/BuddyLlama/CMakeLists.txt
@@ -53,6 +53,7 @@ add_custom_command(
COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyLlama/subgraph0.mlir
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" |
${BUDDY_BINARY_DIR}/buddy-opt
-convert-elementwise-to-linalg
-arith-expand
-eliminate-empty-tensors
-empty-tensor-to-alloc-tensor
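
For context, the newly added `-convert-elementwise-to-linalg` pass rewrites elementwise ops on ranked tensors into `linalg.generic`, so the rest of the bufferization pipeline sees a uniform linalg form. A schematic before/after sketch (types are illustrative; the generated op is shown in comments):

```mlir
// Before: an elementwise arith op on ranked tensors.
%0 = arith.addf %a, %b : tensor<4xf32>

// After -convert-elementwise-to-linalg (schematic):
// %0 = linalg.generic {indexing_maps = [#id, #id, #id],
//                      iterator_types = ["parallel"]}
//        ins(%a, %b : tensor<4xf32>, tensor<4xf32>)
//        outs(%init : tensor<4xf32>) {
//   ^bb0(%x: f32, %y: f32, %out: f32):
//     %s = arith.addf %x, %y : f32
//     linalg.yield %s : f32
// } -> tensor<4xf32>
```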
2 changes: 1 addition & 1 deletion examples/BuddyLlama/import-llama2.py
@@ -38,7 +38,7 @@
)

# Initialize the tokenizer and model from the specified model path.
tokenizer = LlamaTokenizer.from_pretrained(model_path)
tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True)
model = LlamaForCausalLM.from_pretrained(model_path, torchscript=True)
model.config.use_cache = False

2 changes: 1 addition & 1 deletion examples/BuddyLlama/llama-main.cpp
@@ -24,7 +24,7 @@

using namespace buddy;

constexpr size_t ParamsSize = 6755192832;
constexpr size_t ParamsSize = 6738415680;
constexpr size_t MaxVocabSize = 32000;
constexpr size_t MaxTokenLength = 40;
constexpr size_t HiddenSize = 4096;
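
The `ParamsSize` constant must match the number of f32 elements in the parameter file written by the import script; the new value reflects the updated export. A sanity-check sketch, assuming the importer writes the parameters as one flat binary of f32 values (the file name below is illustrative, not part of the repo):

```python
import os

# Hypothetical path to the flat f32 parameter file produced by import-llama2.py.
params_file = "arg0.data"

num_f32 = os.path.getsize(params_file) // 4  # 4 bytes per f32 element
print(num_f32)  # compare against ParamsSize in llama-main.cpp
```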
3 changes: 1 addition & 2 deletions examples/BuddyLlama/llama_annotation.mlir
@@ -206,10 +206,9 @@ module {
%113 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x32x40x128xf32>}> : () -> tensor<1x32x40x128xf32>
%114 = tosa.add %65, %113 : (tensor<1x32x40x128xf32>, tensor<1x32x40x128xf32>) -> tensor<1x32x40x128xf32>
%115 = tosa.reshape %114 {new_shape = array<i64: 32, 40, 128>} : (tensor<1x32x40x128xf32>) -> tensor<32x40x128xf32>
//
%116 = tosa.matmul %112, %115 : (tensor<32x40x40xf32>, tensor<32x40x128xf32>) -> tensor<32x40x128xf32>
// complete one head Softmax(QK/sqrt(d_k)), collect all heads.
%117 = tosa.reshape %116 {new_shape = array<i64: 1, 32, 40, 128>} : (tensor<32x40x128xf32>) -> tensor<1x32x40x128xf32>
// complete one head Softmax(QK/sqrt(d_k)), collect all heads.
%118 = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
%119 = tosa.transpose %117, %118 : (tensor<1x32x40x128xf32>, tensor<4xi32>) -> tensor<1x40x32x128xf32>
%120 = tosa.identity %119 : (tensor<1x40x32x128xf32>) -> tensor<1x40x32x128xf32>
25 changes: 15 additions & 10 deletions examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
@@ -18,11 +18,24 @@
// RUN: | FileCheck %s

func.func private @printMemrefF32(memref<*xf32>)
func.func private @rtclock() -> f64

func.func @batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
%t_start = call @rtclock() : () -> f64

linalg.batch_matmul
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
outs(%arg2 : memref<?x?x?xf32>)

%t_end = call @rtclock() : () -> f64
%time = arith.subf %t_end, %t_start : f64

%printed_output = memref.cast %arg2 : memref<?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()

// Print timings.
vector.print %time : f64

return
}

@@ -54,29 +67,21 @@ func.func @main(){
%m1 = call @alloc_f32(%c1, %c576, %c1024, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
%m2 = call @alloc_f32(%c1, %c1, %c1024, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>

call @batch_matmul(%m0, %m1, %m2) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

%printed_m2 = memref.cast %m2 : memref<?x?x?xf32> to memref<*xf32>

// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1024] strides = [1024, 1024, 1] data =
// CHECK-NEXT: [
// CHECK: [
// CHECK: [3456{{(, 3456)*}}]
call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> ()
call @batch_matmul(%m0, %m1, %m2) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

%m3 = call @alloc_f32(%c1, %c1, %c1024, %f2) : (index, index, index, f32) -> memref<?x?x?xf32>
%m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
%m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>

call @batch_matmul(%m3, %m4, %m5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

%printed_m5 = memref.cast %m5 : memref<?x?x?xf32> to memref<*xf32>

// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data =
// CHECK-NEXT: [
// CHECK: [
// CHECK: [6144{{(, 6144)*}}]
call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> ()
call @batch_matmul(%m3, %m4, %m5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

return
}
47 changes: 47 additions & 0 deletions examples/BuddyMatmul/makefile
@@ -11,6 +11,7 @@ OPT_FLAG := -O0
ifeq ($(shell uname),Linux)
MLIR_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_runner_utils.so
MLIR_C_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_c_runner_utils.so
LIB_OMP := ${LLVM_BUILD_DIR}/lib/libomp.so
MTRIPLE := x86_64-unknown-linux-gnu
else ifeq ($(shell uname),Darwin)
MLIR_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_runner_utils.dylib
@@ -36,6 +37,52 @@ linalg-batchmatmul-f32-run:
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-batchmatmul-f32-omp-lower:
@${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \
-batchmatmul-optimize \
-convert-linalg-to-affine-loops \
-affine-parallelize \
-lower-affine \
-convert-scf-to-openmp \
-convert-vector-to-scf \
-expand-strided-metadata \
-convert-vector-to-llvm \
-memref-expand \
-arith-expand \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-scf-to-cf \
-convert-openmp-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-func-to-llvm \
-reconcile-unrealized-casts \
-o log.mlir

linalg-batchmatmul-f32-omp-run:
@${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \
-batchmatmul-optimize \
-convert-linalg-to-affine-loops \
-affine-parallelize \
-lower-affine \
-convert-scf-to-openmp \
-convert-vector-to-scf \
-expand-strided-metadata \
-convert-vector-to-llvm \
-memref-expand \
-arith-expand \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-scf-to-cf \
-convert-openmp-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} \
-shared-libs=${LIB_OMP}

linalg-matmul-transpose-b-f32-run:
@${BUDDY_OPT} ./linalg-transposematmulb-f32.mlir\
-matmul-transpose-b-vectorization \
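
A minimal usage sketch for the new OpenMP targets (assumes `LLVM_BUILD_DIR` in this makefile points at an LLVM build that ships `libomp.so`; the thread-count variable is optional):

```bash
$ cd examples/BuddyMatmul
# Inspect the lowered IR without executing; the output lands in log.mlir.
$ make linalg-batchmatmul-f32-omp-lower
# Execute through the MLIR CPU runner with the OpenMP runtime loaded.
$ OMP_NUM_THREADS=8 make linalg-batchmatmul-f32-omp-run
```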
1 change: 1 addition & 0 deletions examples/BuddyMobileNetV3/buddy-mobilenetv3-import.py
@@ -23,6 +23,7 @@
from pathlib import Path
import numpy as np
import torch
import torch._inductor.lowering
import torchvision.models as models
from torch._inductor.decomposition import decompositions as inductor_decomp

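
The extra `import torch._inductor.lowering` appears intended to run inductor's lowering registrations before the decomposition table is consumed; importing the table alone can leave it incomplete on recent PyTorch versions (this rationale is inferred from the change, not stated in the commit). The resulting import order is:

```python
import torch
import torch._inductor.lowering  # populate inductor registries first
from torch._inductor.decomposition import decompositions as inductor_decomp
```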
(Diff truncated: the remaining changed files are not shown here.)
