From 0a991a98137d490f75ceb765ea5823917beff377 Mon Sep 17 00:00:00 2001 From: qzylalala <304228244@qq.com> Date: Mon, 6 May 2024 12:49:58 +0800 Subject: [PATCH 1/2] [Gemmini] fuse matmul and transpose --- .../GemminiDialect/matmul-transpose-fuse.mlir | 22 +++++++++++++++ .../LowerLinalgToGemmini.cpp | 27 ++++++++++++++++--- 2 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 examples/GemminiDialect/matmul-transpose-fuse.mlir diff --git a/examples/GemminiDialect/matmul-transpose-fuse.mlir b/examples/GemminiDialect/matmul-transpose-fuse.mlir new file mode 100644 index 0000000000..51f2d350af --- /dev/null +++ b/examples/GemminiDialect/matmul-transpose-fuse.mlir @@ -0,0 +1,22 @@ +// RUN: buddy-opt %s \ +// RUN: --convert-linalg-to-gemmini | \ +// RUN: FileCheck %s + +func.func @matmul_transpose(%lhs: memref<3x4xi8>, %rhs: memref<4x3xi8>, + %output: memref<3x3xi8>) { + // Matrix-matrix multiplication + %matmul = memref.alloc() : memref<3x3xi8> + // CHECK: gemmini.tile_matmul %arg0 %arg1 %arg2 %alloc_0 + linalg.matmul + ins(%lhs, %rhs: memref<3x4xi8>, memref<4x3xi8>) + outs(%output: memref<3x3xi8>) + + // transpose + linalg.transpose + ins(%matmul: memref<3x3xi8>) + outs(%output: memref<3x3xi8>) + permutation = [1, 0] + + memref.dealloc %matmul : memref<3x3xi8> + return +} diff --git a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp index bfee320cc4..5ce24a5e90 100644 --- a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp +++ b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp @@ -63,10 +63,29 @@ class MatmulLowering : public OpRewritePattern { Value fillOpInputValue = rewriter.create(loc, fillOpInsType, fillOpInputAttr); rewriter.create(loc, fillOpInputValue, bias); - rewriter.replaceOpWithNewOp( - matMulOp, input0, input1, output0, bias, /*aScaleFactor = */ scale1, - /*bScaleFactor = */ scale1, /*dScaleFactor = */ scale1, /*act 
= */ 0, - /*accScale = */ scale1, /*bertScale = */ scale0); + + // If this matmul operation is followed by a transpose operation, do fusion. + // We assume that the result of this matmul op only has one user. + if (matMulOp->hasOneUse()) { + // llvm::outs() << "Step in. \n"; + Operation* userOp = *matMulOp->user_begin(); + if (auto transposeOp = dyn_cast(userOp)) { + // (A * B)T = BT * AT + rewriter.replaceOpWithNewOp( + matMulOp, input1, input0, output0, bias, /*aScaleFactor = */ scale1, + /*bScaleFactor = */ scale1, /*dScaleFactor = */ scale1, /*act = */0, + /*accScale = */ scale1, /*bertScale = */ scale0, + /*aTranspose = */ true, /*bTranspose = */ true); + rewriter.eraseOp(transposeOp); + } + } else { + // llvm::outs() << "Not step in. \n"; + rewriter.replaceOpWithNewOp( + matMulOp, input0, input1, output0, bias, /*aScaleFactor = */ scale1, + /*bScaleFactor = */ scale1, /*dScaleFactor = */ scale1, /*act = */ 0, + /*accScale = */ scale1, /*bertScale = */ scale0); + } + rewriter.create(loc, bias); return success(); } From de0cc06fea717c557a47888ca7c9bedc611eb319 Mon Sep 17 00:00:00 2001 From: qzylalala <304228244@qq.com> Date: Tue, 7 May 2024 14:58:43 +0800 Subject: [PATCH 2/2] [Gemmini] add test for fusion --- examples/GemminiDialect/makefile | 12 +++++++ .../GemminiDialect/matmul-transpose-fuse.mlir | 34 +++++++++++++------ .../LowerLinalgToGemmini.cpp | 33 ++++++++++++------ 3 files changed, 59 insertions(+), 20 deletions(-) diff --git a/examples/GemminiDialect/makefile b/examples/GemminiDialect/makefile index cba84b780a..0beb7f3827 100644 --- a/examples/GemminiDialect/makefile +++ b/examples/GemminiDialect/makefile @@ -49,6 +49,18 @@ matmul-os-run: @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out @spike --extension=gemmini pk a.out +matmul-transpose-fuse-run: + @${BUDDY_OPT} ./matmul-transpose-fuse.mlir \ + -convert-linalg-to-gemmini \ + -convert-linalg-to-loops \ + -lower-gemmini | \ + ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ + ${BUDDY_LLC} 
-filetype=obj -mtriple=riscv64 \ + -mattr=+buddyext,+D -float-abi=hard \ + -o log.o + @riscv64-unknown-linux-gnu-gcc log.o -O2 -static -o a.out + @spike --extension=gemmini pk a.out + compute-accumulated-run: @${BUDDY_OPT} ./compute-accumulated.mlir -lower-gemmini | \ ${BUDDY_TRANSLATE} --buddy-to-llvmir | \ diff --git a/examples/GemminiDialect/matmul-transpose-fuse.mlir b/examples/GemminiDialect/matmul-transpose-fuse.mlir index 51f2d350af..138a9c8efd 100644 --- a/examples/GemminiDialect/matmul-transpose-fuse.mlir +++ b/examples/GemminiDialect/matmul-transpose-fuse.mlir @@ -2,21 +2,35 @@ // RUN: --convert-linalg-to-gemmini | \ // RUN: FileCheck %s -func.func @matmul_transpose(%lhs: memref<3x4xi8>, %rhs: memref<4x3xi8>, - %output: memref<3x3xi8>) { +memref.global "private" @gv1 : memref<3x4xi8> = dense<[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]]> +memref.global "private" @gv2 : memref<4x3xi8> = dense<[[1, 1, 1], + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]> + +func.func @main() -> i8 { + %arrayA = memref.get_global @gv1 : memref<3x4xi8> + %arrayB = memref.get_global @gv2 : memref<4x3xi8> + %arrayC = memref.alloc() : memref<3x3xi8> + %cst0 = arith.constant 0 : i8 + gemmini.print %arrayC : memref<3x3xi8> // Matrix-matrix multiplication - %matmul = memref.alloc() : memref<3x3xi8> - // CHECK: gemmini.tile_matmul %arg0 %arg1 %arg2 %alloc_0 + // CHECK: gemmini.tile_matmul %1 %0 %alloc %alloc_0 {aTranspose = true, bTranspose = true} : + // CHECK-SAME: memref<4x3xi8> memref<3x4xi8> memref<3x3xi8> memref<3x4xi32> linalg.matmul - ins(%lhs, %rhs: memref<3x4xi8>, memref<4x3xi8>) - outs(%output: memref<3x3xi8>) + ins(%arrayA, %arrayB: memref<3x4xi8>, memref<4x3xi8>) + outs(%arrayC: memref<3x3xi8>) // transpose linalg.transpose - ins(%matmul: memref<3x3xi8>) - outs(%output: memref<3x3xi8>) + ins(%arrayC: memref<3x3xi8>) + outs(%arrayC: memref<3x3xi8>) permutation = [1, 0] - memref.dealloc %matmul : memref<3x3xi8> - return + gemmini.print %arrayC : memref<3x3xi8> + memref.dealloc 
%arrayC : memref<3x3xi8> + + return %cst0 : i8 } diff --git a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp index 5ce24a5e90..609b18490c 100644 --- a/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp +++ b/midend/lib/Conversion/LowerLinalgToGemmini/LowerLinalgToGemmini.cpp @@ -64,22 +64,35 @@ class MatmulLowering : public OpRewritePattern { rewriter.create(loc, fillOpInsType, fillOpInputAttr); rewriter.create(loc, fillOpInputValue, bias); + // llvm::outs() << " has " + // << std::distance(output0.getUses().begin(), + // output0.getUses().end()) + // << " uses:\n"; + // for (Operation *userOp : output0.getUsers()) { + // llvm::outs() << " - " << userOp->getName() << "\n"; + // } + // If this matmul operation is followed by a transpose operation, do fusion. - // We assume that the result of this matmul op only has one user. - if (matMulOp->hasOneUse()) { - // llvm::outs() << "Step in. \n"; - Operation* userOp = *matMulOp->user_begin(); + // TODO: verify that the result of this matmul op has only one user; the loop below only counts the transpose users of output0. + Operation* fuseOp = *output0.user_begin(); + int output0Use = 0; + for (auto userOp : output0.getUsers()) { if (auto transposeOp = dyn_cast(userOp)) { - // (A * B)T = BT * AT - rewriter.replaceOpWithNewOp( + fuseOp = transposeOp; + output0Use ++; + } + } + + if (output0Use) { + // llvm::outs() << "Fuse linalg.matmul and linalg.transpose. \n"; + rewriter.replaceOpWithNewOp( matMulOp, input1, input0, output0, bias, /*aScaleFactor = */ scale1, /*bScaleFactor = */ scale1, /*dScaleFactor = */ scale1, /*act = */0, /*accScale = */ scale1, /*bertScale = */ scale0, - /*aTranspose = */ true, /*bTranspose = */ true); - rewriter.eraseOp(transposeOp); - } + /*repeatingBias = */ false, /*aTranspose = */ true, + /*bTranspose = */ true); + rewriter.eraseOp(fuseOp); } else { - // llvm::outs() << "Not step in. 
\n"; rewriter.replaceOpWithNewOp( matMulOp, input0, input1, output0, bias, /*aScaleFactor = */ scale1, /*bScaleFactor = */ scale1, /*dScaleFactor = */ scale1, /*act = */ 0,