Commit

new layer SelfAttention & RMSNorm finish
SamanthaWangdl committed Aug 25, 2024
1 parent ba24f38 commit 46e69c1
Showing 10 changed files with 778 additions and 0 deletions.
2 changes: 2 additions & 0 deletions benchmarks/DeepLearning/Layers/CMakeLists.txt
@@ -1 +1,3 @@
add_subdirectory(FFN)
add_subdirectory(SelfAttention)
add_subdirectory(RMSNorm)
9 changes: 9 additions & 0 deletions benchmarks/DeepLearning/Layers/RMSNorm/.gitignore
@@ -0,0 +1,9 @@
__pycache__

# model params file
arg0.data
arg1.data

# model mlir file
forward.mlir
subgraph0.mlir
169 changes: 169 additions & 0 deletions benchmarks/DeepLearning/Layers/RMSNorm/CMakeLists.txt
@@ -0,0 +1,169 @@
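# All object files below follow the same recipe: the MLIR files generated by
# buddy_rmsnorm_import.py are lowered with buddy-opt, translated to LLVM IR with
# mlir-translate, compiled to objects with llc, and archived into per-variant
# static libraries that the benchmark executable links against.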
add_custom_command(
OUTPUT
${CMAKE_CURRENT_SOURCE_DIR}/forward.mlir
${CMAKE_CURRENT_SOURCE_DIR}/subgraph0.mlir
COMMAND
python3 ${CMAKE_CURRENT_SOURCE_DIR}/buddy_rmsnorm_import.py
COMMENT "Generating forward.mlir, subgraph0.mlir"
)

add_custom_command(
OUTPUT forward_scalar.o
COMMAND
cat ${CMAKE_CURRENT_SOURCE_DIR}/forward.mlir |
sed -e {s/@forward/@forward_scalar/}
-e {s/@subgraph0/@subgraph0_scalar/} |
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
-pass-pipeline
"builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), \
empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, \
func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" |
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
-pass-pipeline
"builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), \
eliminate-empty-tensors, func.func(llvm-request-c-wrappers), \
convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, \
convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, \
convert-func-to-llvm, reconcile-unrealized-casts)" |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llc -O0
-mtriple=${BUDDY_OPT_TRIPLE} -mattr=${BUDDY_OPT_ATTR} -filetype=obj
-o ${CMAKE_CURRENT_BINARY_DIR}/forward_scalar.o
DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/forward.mlir
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
COMMENT "Building forward_scalar.o"
VERBATIM)

add_custom_command(
OUTPUT subgraph0_scalar.o
COMMAND
cat ${CMAKE_CURRENT_SOURCE_DIR}/subgraph0.mlir |
sed -e {s/@subgraph0/@subgraph0_scalar/} |
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
-pass-pipeline
"builtin.module(func.func(tosa-to-linalg-named, tosa-to-arith, tosa-to-linalg, tosa-to-tensor))" |
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
-convert-elementwise-to-linalg
-func-bufferize-dynamic-offset
-arith-bufferize
-func-bufferize
-tensor-bufferize
-linalg-bufferize
-finalizing-bufferize
-convert-linalg-to-affine-loops
-lower-affine
-convert-vector-to-scf
-convert-scf-to-cf
-llvm-request-c-wrappers
-convert-vector-to-llvm
-convert-math-to-llvm
-convert-math-to-libm
-convert-arith-to-llvm
-convert-func-to-llvm
-expand-strided-metadata
-finalize-memref-to-llvm
-reconcile-unrealized-casts |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llc -O0
-mtriple=${BUDDY_OPT_TRIPLE} -mattr=${BUDDY_OPT_ATTR} -filetype=obj
-o ${CMAKE_CURRENT_BINARY_DIR}/subgraph0_scalar.o
DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/subgraph0.mlir
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
COMMENT "Building subgraph0_scalar.o"
VERBATIM)

add_custom_command(
OUTPUT forward_auto_vectorization.o
COMMAND
cat ${CMAKE_CURRENT_SOURCE_DIR}/forward.mlir |
sed -e {s/@forward/@forward_auto_vectorization/}
-e {s/@subgraph0/@subgraph0_auto_vectorization/} |
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
-pass-pipeline
"builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), \
empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, \
func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" |
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
-pass-pipeline
"builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), \
eliminate-empty-tensors, func.func(llvm-request-c-wrappers), \
convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, \
convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, \
convert-func-to-llvm, reconcile-unrealized-casts)" |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llc -O3
-mtriple=${BUDDY_OPT_TRIPLE} -mattr=${BUDDY_OPT_ATTR} -filetype=obj
-o ${CMAKE_CURRENT_BINARY_DIR}/forward_auto_vectorization.o
DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/forward.mlir
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
COMMENT "Building forward_auto_vectorization.o"
VERBATIM)

add_custom_command(
OUTPUT subgraph0_auto_vectorization.o
COMMAND
cat ${CMAKE_CURRENT_SOURCE_DIR}/subgraph0.mlir |
sed -e {s/@subgraph0/@subgraph0_auto_vectorization/} |
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
-pass-pipeline
"builtin.module(func.func(tosa-to-linalg-named, tosa-to-arith, tosa-to-linalg, tosa-to-tensor))" |
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
-convert-elementwise-to-linalg
-func-bufferize-dynamic-offset
-arith-bufferize
-func-bufferize
-tensor-bufferize
-linalg-bufferize
-finalizing-bufferize
-convert-linalg-to-loops
-lower-affine
-convert-scf-to-cf
-llvm-request-c-wrappers
-convert-math-to-llvm
-convert-math-to-libm
-convert-arith-to-llvm
-convert-func-to-llvm
-expand-strided-metadata
-finalize-memref-to-llvm
-reconcile-unrealized-casts |
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
${LLVM_MLIR_BINARY_DIR}/llc -O3
-mtriple=${BUDDY_OPT_TRIPLE} -mattr=${BUDDY_OPT_ATTR} -filetype=obj
-o ${CMAKE_CURRENT_BINARY_DIR}/subgraph0_auto_vectorization.o
DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/subgraph0.mlir
${BUDDY_MLIR_BINARY_DIR}/buddy-opt
COMMENT "Building subgraph0_auto_vectorization.o"
VERBATIM)

add_library(RMSNORM_SCALAR STATIC subgraph0_scalar.o forward_scalar.o)
set_target_properties(RMSNORM_SCALAR PROPERTIES LINKER_LANGUAGE CXX)

add_library(RMSNORM_AUTO_VECTORIZATION STATIC subgraph0_auto_vectorization.o forward_auto_vectorization.o)
set_target_properties(RMSNORM_AUTO_VECTORIZATION PROPERTIES LINKER_LANGUAGE CXX)

add_executable(dl-layer-rmsnorm-benchmark
GoogleBenchmarkMain.cpp
)

set_target_properties(dl-layer-rmsnorm-benchmark PROPERTIES
LINK_FLAGS "-static"
)

set(BenchmarkTool GoogleBenchmark)

if(CROSS_COMPILE_RVV)
set(BUDDY_LIB_DIR ${BUDDY_MLIR_CROSS_LIB_DIR})
else()
set(BUDDY_LIB_DIR ${BUDDY_MLIR_LIB_DIR})
endif()

target_link_libraries(dl-layer-rmsnorm-benchmark
${BenchmarkTool}
RMSNORM_AUTO_VECTORIZATION
RMSNORM_SCALAR
${BUDDY_LIB_DIR}/libStaticMLIRCRunnerUtils.a
)
147 changes: 147 additions & 0 deletions benchmarks/DeepLearning/Layers/RMSNorm/GoogleBenchmarkMain.cpp
@@ -0,0 +1,147 @@
//===- GoogleBenchmarkMain.cpp --------------------------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//
//
// This file implements the benchmark for the RMSNorm layer.
//
//===----------------------------------------------------------------------===//

#include <benchmark/benchmark.h>
#include <buddy/Core/Container.h>
#include <iostream>
#include <random>

// Define target layout.
#define INPUT_DIM 256
#define BATCH_SIZE 1
#define OUTPUT_DIM 256
constexpr size_t ParamsSize = 256;

// Helper functions and variables.
namespace {
const std::string PASS = "\033[32mPASS\033[0m";
const std::string FAIL = "\033[31mFAIL\033[0m";

bool areArraysEqual(float array1[], float array2[], int size) {
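// Exact element-wise float comparison; any mismatch fails the verification.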
for (int i = 0; i < size; ++i) {
if (array1[i] != array2[i]) {
return false;
}
}
return true;
}
} // namespace

namespace {
// Declare the C interface for the RMSNorm layer forward functions.
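// The llvm-request-c-wrappers pass in the lowering pipeline emits these
// _mlir_ciface_* wrappers; the output memref is passed as the first pointer argument.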
extern "C" {
void _mlir_ciface_forward_scalar(MemRef<float, 2> *output,
MemRef<float, 1> *input1,
MemRef<float, 2> *input2);
void _mlir_ciface_forward_auto_vectorization(MemRef<float, 2> *output,
MemRef<float, 1> *input1,
MemRef<float, 2> *input2);
}

} // namespace

template <typename Func> void DL_LAYER_RMSNORM(benchmark::State &state, Func func) {

// Define the sizes of the input and output tensors.
intptr_t sizesInput[2] = {BATCH_SIZE, INPUT_DIM};
intptr_t sizesOutput[2] = {BATCH_SIZE, OUTPUT_DIM};
intptr_t sizesParams[1] = {ParamsSize};

MemRef<float, 2> input1(sizesInput, 2);
MemRef<float, 2> output(sizesOutput, 0);
MemRef<float, 1> paramsContainer(sizesParams, 3);
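// In the MemRef constructors above, the second argument is the initial fill value
// for every element (input1 -> 2.0f, output -> 0.0f, paramsContainer -> 3.0f).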

for (auto _ : state) {
func(&output, &paramsContainer, &input1);
}
}

BENCHMARK_CAPTURE(DL_LAYER_RMSNORM, Scalar, _mlir_ciface_forward_scalar)
->Unit(benchmark::kMillisecond);
BENCHMARK_CAPTURE(DL_LAYER_RMSNORM, Auto_Vectorization,
_mlir_ciface_forward_auto_vectorization)
->Unit(benchmark::kMillisecond);

/// Correctness Verification
/// The verification does not affect the performance.
/// - Set the scalar case as the criteria.
/// - Input elements are random numbers.
/// - Output elements are initialized to zero.
/// - Compare the output of various optimizations with the scalar version to
/// verify correctness.
void verification() {
// Set the random number generator.
std::random_device rd;
std::mt19937 generator(rd());
std::uniform_real_distribution<float> distribution(0.0, 1.0);

// Set the layout sizes of input and output memref container.
intptr_t sizesInput[2] = {BATCH_SIZE, INPUT_DIM};
intptr_t sizesOutput[2] = {BATCH_SIZE, OUTPUT_DIM};
intptr_t sizesParams[1] = {ParamsSize};

// Generate input memref containers with random numbers.
const int inputSize = BATCH_SIZE * INPUT_DIM;
float inputRand1[inputSize];
float inputRand2[ParamsSize];

for (int i = 0; i < inputSize; ++i) {
inputRand1[i] = distribution(generator);
}
for (int i = 0; i < ParamsSize; ++i) {
inputRand2[i] = distribution(generator);
}

MemRef<float, 2> inputMemRef(inputRand1, sizesInput);
MemRef<float, 1> paramsContainer(inputRand2, sizesParams);

// Generate output memref containers with zero.
MemRef<float, 2> outputScalar(sizesOutput);
MemRef<float, 2> outputAutoVectorization(sizesOutput);

// Run all the RMSNorm implementations.
_mlir_ciface_forward_scalar(&outputScalar, &paramsContainer, &inputMemRef);
_mlir_ciface_forward_auto_vectorization(&outputAutoVectorization,
&paramsContainer, &inputMemRef);
// Get the result array.
auto resultScalar = outputScalar.getData();
auto resultAutoVectorization = outputAutoVectorization.getData();

// Print the verification result.
std::cout << "-----------------------------------------------------------"
<< std::endl;
std::cout << "Correctness Verification: "
<< (areArraysEqual(resultScalar, resultAutoVectorization,
sizesOutput[0] * sizesOutput[1])
? PASS
: FAIL)
<< std::endl;
std::cout << "-----------------------------------------------------------"
<< std::endl;
}

int main(int argc, char **argv) {
// Run benchmark.
::benchmark::Initialize(&argc, argv);
::benchmark::RunSpecifiedBenchmarks();
// Run correctness verification.
verification();
return 0;
}
62 changes: 62 additions & 0 deletions benchmarks/DeepLearning/Layers/RMSNorm/buddy_rmsnorm_import.py
@@ -0,0 +1,62 @@
import os
import torch
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.graph import GraphDriver
from buddy.compiler.graph.transform import simply_fuse
from buddy.compiler.ops import tosa

# Define the RMSNorm layer.
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps=1e-8):
super(RMSNorm, self).__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))

def forward(self, x):
# Compute the root mean square of x
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
# Normalize and apply the scale parameter
x = x / rms * self.weight
return x

# Initialize the RMSNorm model and set to evaluation mode.
input_dim = 256
model = RMSNorm(input_dim)
model = model.eval()

# Initialize Dynamo Compiler with specific configurations as an importer.
dynamo_compiler = DynamoCompiler(
primary_registry=tosa.ops_registry,
aot_autograd_decomposition=inductor_decomp,
)

# Define input data.
data = torch.randn([1, input_dim], dtype=torch.float32)

# Import the model into MLIR module and parameters.
with torch.no_grad():
graphs = dynamo_compiler.importer(model, data)
assert len(graphs) == 1
graph = graphs[0]
params = dynamo_compiler.imported_params[graph]

# Apply graph transformations (e.g., fusion of operations).
pattern_list = [simply_fuse]
graph.fuse_ops(pattern_list)

# Lower the graph to the top-level MLIR IR.
driver = GraphDriver(graph)
driver.subgraphs[0].lower_to_top_level_ir()

# Define the output path for MLIR files.
path_prefix = os.path.dirname(os.path.abspath(__file__))

# Save the generated subgraph MLIR module to a file.
with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
print(driver.subgraphs[0]._imported_module, file=module_file)

# Save the forward MLIR module to a file.
with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
print(driver.construct_main_graph(True), file=module_file)
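
As a quick cross-check of the computation being imported, the same RMSNorm can be sketched in plain NumPy (illustrative only; the eps value, shapes, and weight layout are assumed to match the module above):

import numpy as np

def rmsnorm_reference(x: np.ndarray, weight: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # rms(x) = sqrt(mean(x^2) + eps) over the last dimension, keeping dims for broadcasting.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    # Normalize, then apply the learned scale parameter.
    return x / rms * weight

# Same shapes as the imported model: batch 1, feature dimension 256.
x = np.random.randn(1, 256).astype(np.float32)
w = np.ones(256, dtype=np.float32)
print(rmsnorm_reference(x, w).shape)  # -> (1, 256)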