Skip to content

Commit

Permalink
[midend] Fix batch matmul vectorization pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghb97 committed Aug 16, 2024
1 parent 4b87f52 commit 7e796f6
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 1 deletion.
2 changes: 2 additions & 0 deletions examples/BuddyMatmul/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
log.*
s
82 changes: 82 additions & 0 deletions examples/BuddyMatmul/linalg-batchmatmul-f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// RUN: buddy-opt %s \
// RUN: -batchmatmul-optimize \
// RUN: -convert-linalg-to-affine-loops \
// RUN: -lower-affine \
// RUN: -convert-vector-to-scf \
// RUN: -convert-scf-to-cf \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-math-to-llvm \
// RUN: -convert-math-to-libm \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-func-to-llvm \
// RUN: -expand-strided-metadata \
// RUN: -finalize-memref-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

func.func private @printMemrefF32(memref<*xf32>)

// Thin wrapper around a single linalg.batch_matmul: %lhs and %rhs are the
// input operands and %out is the output buffer. All three are dynamically
// shaped rank-3 f32 memrefs, so the shapes are only known at run time.
func.func @batch_matmul(%lhs: memref<?x?x?xf32>,
                        %rhs: memref<?x?x?xf32>,
                        %out: memref<?x?x?xf32>) {
  linalg.batch_matmul
      ins(%lhs, %rhs : memref<?x?x?xf32>, memref<?x?x?xf32>)
      outs(%out : memref<?x?x?xf32>)
  return
}

// Allocates a dynamically shaped rank-3 f32 memref with extents
// %dim0 x %dim1 x %dim2 and stores %fill into every element.
// The buffer is returned without being deallocated here.
func.func @alloc_f32(%dim0: index, %dim1: index, %dim2: index, %fill: f32) -> memref<?x?x?xf32> {
  %buf = memref.alloc(%dim0, %dim1, %dim2) : memref<?x?x?xf32>
  // Shared lower bound and unit step for the three nested fill loops.
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  scf.for %i = %lb to %dim0 step %step {
    scf.for %j = %lb to %dim1 step %step {
      scf.for %k = %lb to %dim2 step %step {
        memref.store %fill, %buf[%i, %j, %k] : memref<?x?x?xf32>
      }
    }
  }
  return %buf : memref<?x?x?xf32>
}

// Entry point of the example. Runs @batch_matmul on two problem sizes and
// prints each result so FileCheck can verify it:
//   1. [1 x 1 x 576]  * [1 x 576 x 1024]  -> [1 x 1 x 1024], each entry 2*3*576  = 3456
//   2. [1 x 1 x 1024] * [1 x 1024 x 1000] -> [1 x 1 x 1000], each entry 2*3*1024 = 6144
// The second width (1000) differs from the first (1024), so the two cases
// exercise different column counts in the pass under test.
// Every buffer obtained from @alloc_f32 is released once it is no longer used.
func.func @main(){
  %c1 = arith.constant 1 : index
  %c576 = arith.constant 576 : index
  %c1024 = arith.constant 1024 : index
  %c1000 = arith.constant 1000 : index
  %f0 = arith.constant 0.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32

  // Case 1: inputs filled with 2.0 and 3.0, output zero-initialized.
  %m0 = call @alloc_f32(%c1, %c1, %c576, %f2) : (index, index, index, f32) -> memref<?x?x?xf32>
  %m1 = call @alloc_f32(%c1, %c576, %c1024, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
  %m2 = call @alloc_f32(%c1, %c1, %c1024, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>

  call @batch_matmul(%m0, %m1, %m2) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

  // printMemrefF32 takes an unranked memref, so erase the rank first.
  %printed_m2 = memref.cast %m2 : memref<?x?x?xf32> to memref<*xf32>

  // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1024] strides = [1024, 1024, 1] data =
  // CHECK-NEXT: [
  // CHECK: [
  // CHECK: [3456{{(, 3456)*}}]
  call @printMemrefF32(%printed_m2) : (memref<*xf32>) -> ()

  // Case 1 buffers are dead from here on.
  memref.dealloc %m0 : memref<?x?x?xf32>
  memref.dealloc %m1 : memref<?x?x?xf32>
  memref.dealloc %m2 : memref<?x?x?xf32>

  // Case 2: same fill values, different reduction and column sizes.
  %m3 = call @alloc_f32(%c1, %c1, %c1024, %f2) : (index, index, index, f32) -> memref<?x?x?xf32>
  %m4 = call @alloc_f32(%c1, %c1024, %c1000, %f3) : (index, index, index, f32) -> memref<?x?x?xf32>
  %m5 = call @alloc_f32(%c1, %c1, %c1000, %f0) : (index, index, index, f32) -> memref<?x?x?xf32>

  call @batch_matmul(%m3, %m4, %m5) : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>) -> ()

  %printed_m5 = memref.cast %m5 : memref<?x?x?xf32> to memref<*xf32>

  // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 1, 1000] strides = [1000, 1000, 1] data =
  // CHECK-NEXT: [
  // CHECK: [
  // CHECK: [6144{{(, 6144)*}}]
  call @printMemrefF32(%printed_m5) : (memref<*xf32>) -> ()

  // Case 2 buffers are dead from here on.
  memref.dealloc %m3 : memref<?x?x?xf32>
  memref.dealloc %m4 : memref<?x?x?xf32>
  memref.dealloc %m5 : memref<?x?x?xf32>

  return
}
35 changes: 35 additions & 0 deletions examples/BuddyMatmul/makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Makefile for the BuddyMatmul examples.
# This file is read by make, not executed by a shell, so it carries no shebang.

# Tool locations, relative to this directory.
# NOTE(review): BUDDY_OPT points at "build-review" while the LLVM tools use a
# plain "build" directory; this looks like a developer-local path. Confirm, or
# override on the command line: make BUDDY_OPT=/path/to/buddy-opt <target>
BUDDY_OPT := ../../build-review/bin/buddy-opt
MLIR_OPT := ../../llvm/build/bin/mlir-opt
MLIR_TRANSLATE := ../../llvm/build/bin/mlir-translate
MLIR_CPU_RUNNER := ../../llvm/build/bin/mlir-cpu-runner
LLC := ../../llvm/build/bin/llc
OPT_FLAG := -O0

# Pick the shared-library suffix and target triple for the host OS.
# On any other OS these variables stay empty.
ifeq ($(shell uname),Linux)
MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.so
MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.so
MTRIPLE := x86_64-unknown-linux-gnu
else ifeq ($(shell uname),Darwin)
MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.dylib
MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.dylib
MTRIPLE := x86_64-apple-darwin
endif

# The run target produces no file, so it must run even if a file with the
# same name happens to exist.
.PHONY: linalg-batchmatmul-f32-run

# Applies the batch matmul optimization, lowers the example to the LLVM
# dialect, and runs @main through mlir-cpu-runner with the runner utility
# libraries loaded.
linalg-batchmatmul-f32-run:
	@${BUDDY_OPT} ./linalg-batchmatmul-f32.mlir \
		-batchmatmul-optimize \
		-convert-linalg-to-affine-loops \
		-lower-affine \
		-convert-vector-to-scf \
		-convert-scf-to-cf \
		-convert-vector-to-llvm \
		-convert-math-to-llvm \
		-convert-math-to-libm \
		-convert-arith-to-llvm \
		-convert-func-to-llvm \
		-expand-strided-metadata \
		-finalize-memref-to-llvm \
		-reconcile-unrealized-casts | \
	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
		-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
IntegerSet::get(
1, 1, {d0 * -affineVectorSize + s0 - affineVectorSize},
{false}),
ValueRange{loopVarBatchIdx, bCol}, true);
ValueRange{loopVarColOfB, bCol}, true);

// Branch handling full vector operations.
OpBuilder trueBranchBuilder = branchingOp.getThenBodyBuilder();
Expand Down

0 comments on commit 7e796f6

Please sign in to comment.