Skip to content

Commit

Permalink
[examples] Add omp pass pipeline for batch matmul.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Dec 19, 2024
1 parent 1953db1 commit 9629eea
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
25 changes: 15 additions & 10 deletions examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,24 @@
// RUN: | FileCheck %s

func.func private @printMemrefF32(memref<*xf32>)
func.func private @rtclock() -> f64

func.func @batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
%t_start = call @rtclock() : () -> f64

linalg.batch_matmul
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>)
outs(%arg2 : memref<?x?x?xf32>)

%t_end = call @rtclock() : () -> f64
%time = arith.subf %t_end, %t_start : f64

%printed_output = memref.cast %arg2 : memref<?x?x?xf32> to memref<*xf32>
call @printMemrefF32(%printed_output) : (memref<*xf32>) -> ()

// Print timings.
vector.print %time : f64

return
}

Expand Down Expand Up @@ -54,29 +67,21 @@ func.func @main(){
%m1 = call @alloc_f32(%c1, %c576, %c1024, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
%m2 = call @alloc_f32(%c1, %c1, %c1024, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>

call @batch_matmul(%m0, %m1, %m2) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

%printed_m2 = memref.cast %m2 : memref<?x?x?xf32> to memref<*xf32>

// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1024] strides = [1024, 1024, 1] data =
// CHECK-NEXT: [
// CHECK: [
// CHECK: [3456{{(, 3456)*}}]
call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> ()
call @batch_matmul(%m0, %m1, %m2) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

%m3 = call @alloc_f32(%c1, %c1, %c1024, %f2) : (index, index, index, f32) -> memref<?x?x?xf32>
%m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
%m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>

call @batch_matmul(%m3, %m4, %m5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

%printed_m5 = memref.cast %m5 : memref<?x?x?xf32> to memref<*xf32>

// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data =
// CHECK-NEXT: [
// CHECK: [
// CHECK: [6144{{(, 6144)*}}]
call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> ()
call @batch_matmul(%m3, %m4, %m5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

return
}
47 changes: 47 additions & 0 deletions examples/BuddyMatmul/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ OPT_FLAG := -O0
ifeq ($(shell uname),Linux)
MLIR_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_runner_utils.so
MLIR_C_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_c_runner_utils.so
LIB_OMP := ${LLVM_BUILD_DIR}/lib/libomp.so
MTRIPLE := x86_64-unknown-linux-gnu
else ifeq ($(shell uname),Darwin)
MLIR_RUNNER_UTILS := ${LLVM_BUILD_DIR}/lib/libmlir_runner_utils.dylib
Expand All @@ -36,6 +37,52 @@ linalg-batchmatmul-f32-run:
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

linalg-batchmatmul-f32-omp-lower:
@${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \
-batchmatmul-optimize \
-convert-linalg-to-affine-loops \
-affine-parallelize \
-lower-affine \
-convert-scf-to-openmp \
-convert-vector-to-scf \
-expand-strided-metadata \
-convert-vector-to-llvm \
-memref-expand \
-arith-expand \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-scf-to-cf \
-convert-openmp-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-func-to-llvm \
-reconcile-unrealized-casts \
-o log.mlir

linalg-batchmatmul-f32-omp-run:
@${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \
-batchmatmul-optimize \
-convert-linalg-to-affine-loops \
-affine-parallelize \
-lower-affine \
-convert-scf-to-openmp \
-convert-vector-to-scf \
-expand-strided-metadata \
-convert-vector-to-llvm \
-memref-expand \
-arith-expand \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-scf-to-cf \
-convert-openmp-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} \
-shared-libs=${LIB_OMP}

linalg-matmul-transpose-b-f32-run:
@${BUDDY_OPT} ./linalg-transposematmulb-f32.mlir\
-matmul-transpose-b-vectorization \
Expand Down

0 comments on commit 9629eea

Please sign in to comment.