Skip to content

Commit

Permalink
[frontend] Add ops fusion demo. (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
effrey-liu authored Oct 31, 2024
1 parent 86cd5af commit fdd324c
Show file tree
Hide file tree
Showing 15 changed files with 919 additions and 13 deletions.
83 changes: 83 additions & 0 deletions examples/BuddyFusedLeNet/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
file(GLOB SUBGRAPH_FILES "${CMAKE_CURRENT_SOURCE_DIR}/subgraph*.mlir")

add_custom_command(
OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyFusedLeNet/arg0.data
${BUDDY_EXAMPLES_DIR}/BuddyFusedLeNet/forward.mlir
${SUBGRAPH_FILES}
COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyFusedLeNet/buddy-lenet-import.py
COMMENT "Generating forward.mlir, subgraph*.mlir and parameter files"
)

set(SUBGRAPH_OBJECTS "")
foreach(SUBGRAPH_FILE ${SUBGRAPH_FILES})
get_filename_component(SUBGRAPH_NAME ${SUBGRAPH_FILE} NAME_WE)
set(OBJECT_FILE ${BUDDY_BINARY_DIR}/../examples/BuddyFusedLeNet/${SUBGRAPH_NAME}.o)
list(APPEND SUBGRAPH_OBJECTS ${OBJECT_FILE})

add_custom_command(
OUTPUT ${OBJECT_FILE}
COMMAND ${BUDDY_BINARY_DIR}/buddy-opt ${SUBGRAPH_FILE}
-pass-pipeline
"builtin.module(func.func(tosa-to-linalg-named, tosa-to-arith, tosa-to-linalg, tosa-to-tensor))" |
${BUDDY_BINARY_DIR}/buddy-opt
-convert-elementwise-to-linalg
-func-bufferize-dynamic-offset
-arith-bufferize
-func-bufferize
-tensor-bufferize
-linalg-bufferize
-finalizing-bufferize
-convert-linalg-to-loops
-lower-affine
-convert-scf-to-cf
-llvm-request-c-wrappers
-convert-math-to-llvm
-convert-math-to-libm
-convert-arith-to-llvm
-convert-func-to-llvm
-expand-strided-metadata
-finalize-memref-to-llvm
-reconcile-unrealized-casts |
${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_TOOLS_BINARY_DIR}/llvm-as |
${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3
-o ${OBJECT_FILE}
DEPENDS ${SUBGRAPH_FILE} buddy-opt
COMMENT "Building ${SUBGRAPH_NAME}.o"
VERBATIM
)
endforeach()

add_custom_command(
OUTPUT forward.o
COMMAND ${LLVM_TOOLS_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyFusedLeNet/forward.mlir
-pass-pipeline
"builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), \
empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, \
func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" |
${LLVM_TOOLS_BINARY_DIR}/mlir-opt
-pass-pipeline
"builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), \
eliminate-empty-tensors, func.func(llvm-request-c-wrappers), \
convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, \
convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, \
convert-func-to-llvm, reconcile-unrealized-casts)" |
${LLVM_TOOLS_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_TOOLS_BINARY_DIR}/llvm-as |
${LLVM_TOOLS_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O3
-o ${BUDDY_BINARY_DIR}/../examples/BuddyFusedLeNet/forward.o
DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyFusedLeNet/forward.mlir
COMMENT "Building forward.o"
VERBATIM
)

add_library(FUSED_LENET STATIC ${SUBGRAPH_OBJECTS} forward.o)

SET_TARGET_PROPERTIES(FUSED_LENET PROPERTIES LINKER_LANGUAGE C)

add_executable(buddy-fused-lenet-run buddy-lenet-main.cpp)
target_link_directories(buddy-fused-lenet-run PRIVATE ${LLVM_LIBRARY_DIR})

set(BUDDY_LENET_LIBS FUSED_LENET mlir_c_runner_utils ${PNG_LIBRARIES})

target_link_libraries(buddy-fused-lenet-run ${BUDDY_LENET_LIBS})
47 changes: 47 additions & 0 deletions examples/BuddyFusedLeNet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Buddy Compiler Fuse Ops Example

## Fused LeNet Model Inference

0. Activate your python environment.

1. Build buddy-mlir

```bash
$ mkdir build && cd build
$ cmake -G Ninja .. \
-DMLIR_DIR=$PWD/../llvm/build/lib/cmake/mlir \
-DLLVM_DIR=$PWD/../llvm/build/lib/cmake/llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DCMAKE_BUILD_TYPE=RELEASE \
-DBUDDY_MLIR_ENABLE_PYTHON_PACKAGES=ON \
-DPython3_EXECUTABLE=$(which python3) \
-DBUDDY_MLIR_ENABLE_DIP_LIB=ON \
-DBUDDY_ENABLE_PNG=ON
$ ninja
$ ninja check-buddy
```

2. Set the `PYTHONPATH` environment variable.

Make sure you are in the build directory.

```bash
$ export BUDDY_MLIR_BUILD_DIR=$PWD
$ export LLVM_MLIR_BUILD_DIR=$PWD/../llvm/build
$ export PYTHONPATH=${LLVM_MLIR_BUILD_DIR}/tools/mlir/python_packages/mlir_core:${BUDDY_MLIR_BUILD_DIR}/python_packages:${PYTHONPATH}
```

3. Set the `FUSED_LENET_EXAMPLE_PATH` environment variable.

```bash
$ export FUSED_LENET_EXAMPLE_PATH=${BUDDY_MLIR_BUILD_DIR}/../examples/BuddyFusedLeNet/
```

4. Build and run the Fused LeNet example

```bash
$ cmake -G Ninja .. -DBUDDY_FUSED_LENET_EXAMPLES=ON
$ ninja buddy-fused-lenet-run
$ cd bin
$ ./buddy-fused-lenet-run
```
79 changes: 79 additions & 0 deletions examples/BuddyFusedLeNet/buddy-lenet-import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# ===- buddy-lenet-import.py ---------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# This is the LeNet model AOT importer.
#
# ===---------------------------------------------------------------------------

import os
from pathlib import Path

import numpy as np
import torch
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse, my_fuse_ops_test
from buddy.compiler.ops import tosa
from model import LeNet

# Retrieve the LeNet model path from environment variables.
model_path = os.environ.get("FUSED_LENET_EXAMPLE_PATH")
if model_path is None:
raise EnvironmentError(
"The environment variable 'LENET_MODEL_PATH' is not set or is invalid."
)

model = LeNet()
model = torch.load(model_path + "/lenet-model.pth")
model = model.eval()

# Initialize Dynamo Compiler with specific configurations as an importer.
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
)

data = torch.randn([1, 1, 28, 28])
# Import the model into MLIR module and parameters.
with torch.no_grad():
graphs = dynamo_compiler.importer(model, data)

assert len(graphs) == 1
graph = graphs[0]
params = dynamo_compiler.imported_params[graph]
pattern_list = [my_fuse_ops_test]
graph.fuse_ops(pattern_list)
driver = GraphDriver(graph)
path_prefix = os.path.dirname(os.path.abspath(__file__))

for i in range(len(driver.subgraphs)):
driver.subgraphs[i].lower_to_top_level_ir()
with open(os.path.join(path_prefix, f"subgraph{i}.mlir"), "w") as module_file:
print(driver.subgraphs[i]._imported_module, file=module_file)

with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
print(driver.construct_main_graph(True), file=module_file)

params = dynamo_compiler.imported_params[graph]
current_path = os.path.dirname(os.path.abspath(__file__))

float32_param = np.concatenate(
[param.detach().numpy().reshape([-1]) for param in params]
)

float32_param.tofile(Path(current_path) / "arg0.data")
133 changes: 133 additions & 0 deletions examples/BuddyFusedLeNet/buddy-lenet-main.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//===- buddy-lenet-main.cpp -----------------------------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//

#include <buddy/Core/Container.h>
#include <buddy/DIP/ImgContainer.h>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

constexpr size_t ParamsSize = 44426;
const std::string ImgName = "1-28*28.png";

/// Declare LeNet forward function.
extern "C" void _mlir_ciface_forward(MemRef<float, 2> *output,
MemRef<float, 1> *arg0,
dip::Image<float, 4> *input);

/// Print [Log] label in bold blue format.
void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; }

/// Load parameters into data container.
void loadParameters(const std::string &paramFilePath,
MemRef<float, 1> &params) {
const auto loadStart = std::chrono::high_resolution_clock::now();
// Open the parameter file in binary mode.
std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary);
if (!paramFile.is_open()) {
throw std::runtime_error("[Error] Failed to open params file!");
}
printLogLabel();
std::cout << "Loading params..." << std::endl;
printLogLabel();
// Print the canonical path of the parameter file.
std::cout << "Params file: " << std::filesystem::canonical(paramFilePath)
<< std::endl;
// Read the parameter data into the provided memory reference.
paramFile.read(reinterpret_cast<char *>(params.getData()),
sizeof(float) * (params.getSize()));
if (paramFile.fail()) {
throw std::runtime_error("Error occurred while reading params file!");
}
paramFile.close();
const auto loadEnd = std::chrono::high_resolution_clock::now();
const std::chrono::duration<double, std::milli> loadTime =
loadEnd - loadStart;
printLogLabel();
std::cout << "Params load time: " << (double)(loadTime.count()) / 1000
<< "s\n"
<< std::endl;
}

/// Softmax function to convert logits to probabilities.
void softmax(float *input, size_t size) {
size_t i;
float max_value = -INFINITY;
double sum = 0.0;
// Find the maximum value in the input array for numerical stability.
for (i = 0; i < size; ++i) {
if (max_value < input[i]) {
max_value = input[i];
}
}
// Calculate the sum of the exponentials of the input elements, normalized by
// the max value.
for (i = 0; i < size; ++i) {
sum += exp(input[i] - max_value);
}
// Normalize the input array with the softmax calculation.
for (i = 0; i < size; ++i) {
input[i] = exp(input[i] - max_value) / sum;
}
}

int main() {
// Print the title of this example.
const std::string title = "Fused LeNet Inference Powered by Buddy Compiler";
std::cout << "\033[33;1m" << title << "\033[0m" << std::endl;

// Define the sizes of the output tensors.
intptr_t sizesOutput[2] = {1, 10};

// Create input and output containers for the image and model output.
std::string lenetDir = getenv("FUSED_LENET_EXAMPLE_PATH");
std::string imgPath = lenetDir + "/images/" + ImgName;
dip::Image<float, 4> input(imgPath, dip::DIP_GRAYSCALE, true /* norm */);
MemRef<float, 2> output(sizesOutput);

// Load model parameters from the specified file.
std::string paramsDir = lenetDir + "/arg0.data";
MemRef<float, 1> paramsContainer({ParamsSize});
loadParameters(paramsDir, paramsContainer);

// Call the forward function of the model.
_mlir_ciface_forward(&output, &paramsContainer, &input);

// Apply softmax to the output logits to get probabilities.
auto out = output.getData();
softmax(out, 10);

// Find the classification and print the result.
float maxVal = 0;
float maxIdx = 0;
for (int i = 0; i < 10; ++i) {
if (out[i] > maxVal) {
maxVal = out[i];
maxIdx = i;
}
}

std::cout << "Classification: " << maxIdx << std::endl;
std::cout << "Probability: " << maxVal << std::endl;

return 0;
}
Binary file added examples/BuddyFusedLeNet/images/1-28*28.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/BuddyFusedLeNet/lenet-model.pth
Binary file not shown.
41 changes: 41 additions & 0 deletions examples/BuddyFusedLeNet/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# ===- model.py ----------------------------------------------------------------
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===---------------------------------------------------------------------------
#
# LeNet model definition.
#
# ===---------------------------------------------------------------------------

import torch
import torch.nn as nn

class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
4 changes: 4 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ if (BUDDY_LENET_EXAMPLES)
add_subdirectory(BuddyLeNet)
endif()

if (BUDDY_FUSED_LENET_EXAMPLES)
add_subdirectory(BuddyFusedLeNet)
endif()

if(BUDDY_WHISPER_EXAMPLES)
add_subdirectory(BuddyWhisper)
endif()
Expand Down
Loading

0 comments on commit fdd324c

Please sign in to comment.