From c8360ca7f794f559dd05e85769a6d66b49f30ef9 Mon Sep 17 00:00:00 2001 From: Junyi Mei Date: Tue, 25 Feb 2025 19:52:38 +0800 Subject: [PATCH] [midend] Optimization on matmul-transpose-b vectorization (#465) Current vectorization pass of matmul-transpose-b reduce the vector in each iteration and accumulate it to the result element. This commit modify it into elementwise addition and do the reduction after the inner loop with reassoc enabled. Signed-off-by: Junyi Mei --- .../MatMulTransposeBVec.cpp | 316 +++++++++--------- .../matmul-transpose-b-vectorization.mlir | 101 ++++++ 2 files changed, 266 insertions(+), 151 deletions(-) create mode 100644 tests/Conversion/matmul-transpose-b-vectorization.mlir diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp index 345de4c1de..2f81414a33 100644 --- a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp +++ b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp @@ -19,12 +19,16 @@ //===----------------------------------------------------------------------===// #include +#include #include #include +#include +#include #include #include #include #include +#include #include #include "Utils/Utils.h" @@ -37,132 +41,138 @@ using namespace vector; //===----------------------------------------------------------------------===// namespace { -class MatMulTransposeBVecPattern : public ConversionPattern{ +class MatMulTransposeBVecPattern : public ConversionPattern { public: - explicit MatMulTransposeBVecPattern(MLIRContext *context,int64_t vecSizeparam) - : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(),1,context){ - vecSize = vecSizeparam; - } - - LogicalResult - matchAndRewrite(Operation *op,ArrayRef /*operands*/, - ConversionPatternRewriter &rewriter) const override{ - auto loc = op->getLoc(); - auto ctx = op->getContext(); - // Get input A, B, C. - Value A = op->getOperand(0); - Value B = op->getOperand(1); - Value C = op->getOperand(2); - - // Get shape of input and output. - ShapedType ATy = A.getType().cast(); - Type eleTy = ATy.getElementType(); - - // the element type for mask vector. - IntegerType i1 = IntegerType::get(ctx, 1); - - VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); - VectorType vectorMaskTy = VectorType::get({vecSize}, i1); - - const Value c0 = - rewriter.create(loc, rewriter.getIndexAttr(0)); - const Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); - const Value step = rewriter.create(loc, vecSize); - - const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); - Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); - - const Value aRow = rewriter.create(loc, A, c0); - const Value bRow = rewriter.create(loc, B, c0); - const Value bCol = rewriter.create(loc, B, c1); - - AffineExpr d0; - bindDims(ctx, d0); - AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); - SmallVector lowerBounds(2, c0); - SmallVector uperBounds{aRow, bRow}; - SmallVector steps(2, 1); - // clang-format off - affine::buildAffineLoopNest( - rewriter, loc, lowerBounds, uperBounds, steps, - [&](OpBuilder &builder, Location loc, ValueRange ivs) { - // Create loop based on vector size. - builder.create( - loc, ValueRange{c0}, builder.getDimIdentityMap(), - ValueRange{bCol}, vecTailMap, 1, std::nullopt, - [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, - ValueRange itrArgs) { - AffineExpr a,b,c; - bindDims(ctx, a,b,c); - AffineMap AVectorMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); - // Check tail. - AffineExpr m, n, k; - bindDims(ctx, m, n, k); - AffineMap BVectorMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); - - // Calculate the tail. - Value bColCur = builder.create(loc, iv, step); - Value tailLen = builder.create(loc, bCol, bColCur); - Value tailFlag = rewriter.create( - loc, arith::CmpIPredicate::sge, tailLen, step); - // If the current column does not reach the tail. - builder.create(loc, tailFlag, - [&](OpBuilder &builder, Location loc) { - Value aVec = builder.create( - loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - Value bVec = builder.create( - loc, vectorTy, B, BVectorMap, ValueRange{ivs[1], ivs[1], iv}); - Value resvec = builder.create(loc,aVec,bVec); - Value res1 = builder.create( - loc,mlir::vector::CombiningKind::ADD,resvec); - Value res2 = builder.create( - loc, C, ValueRange{ivs[0], ivs[1]}); - Value sum = builder.create(loc, res1, res2); - builder.create(loc, sum, - C, ValueRange{ivs[0], ivs[1]}); - builder.create(loc); - }, - // The else branch - [&](OpBuilder &builder, Location loc) { - // TODO: remove this value and operation? - // Value aVec = builder.create( - // loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - builder.create( - loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - // Create mask according to the tail. - Value maskVec = builder.create( - loc, vectorMaskTy, tailLen); - Value ColIdxTail = builder.create(loc, iv, step); - - Value aVecTail = builder.create( - loc, vectorTy, A, ValueRange{ivs[0], ColIdxTail}, - maskVec, passthruVec); - - Value bVecTail = builder.create( - loc, vectorTy, B, ValueRange{ivs[1], ColIdxTail}, - maskVec, passthruVec); - - Value resvec = builder.create(loc,aVecTail,bVecTail); - Value res1 = builder.create( - loc,mlir::vector::CombiningKind::ADD,resvec); - Value res2 = builder.create( - loc, C, ValueRange{ivs[0], ivs[1]}); - Value sum = builder.create(loc, res1, res2); - builder.create(loc, sum, C, ValueRange{ivs[0], ivs[1]}); - builder.create(loc); - }); - builder.create(loc); - }); + explicit MatMulTransposeBVecPattern(MLIRContext *context, + int64_t vecSizeparam) + : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(), 1, + context) { + vecSize = vecSizeparam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto ctx = op->getContext(); + // Get input A, B, C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Get shape of input and output. + ShapedType ATy = A.getType().cast(); + Type eleTy = ATy.getElementType(); + + // the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + + VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); + VectorType vectorMaskTy = VectorType::get({vecSize}, i1); + + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value step = rewriter.create(loc, vecSize); + + const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); + Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); + + const Value aRow = rewriter.create(loc, A, c0); + const Value bRow = rewriter.create(loc, B, c0); + const Value bCol = rewriter.create(loc, B, c1); + + AffineExpr d0; + bindDims(ctx, d0); + AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); + SmallVector lowerBounds(2, c0); + SmallVector uperBounds{aRow, bRow}; + SmallVector steps(2, 1); + + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create loop based on vector size. + auto innerLoop = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{bCol}, vecTailMap, 1, ValueRange{passthruVec}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + Value acc = itrArgs[0]; + + AffineExpr a, b, c; + bindDims(ctx, a, b, c); + AffineMap AVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); + // Check tail. + AffineExpr m, n, k; + bindDims(ctx, m, n, k); + AffineMap BVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); + + // Calculate the tail. + Value bColCur = builder.create(loc, iv, step); + Value tailLen = + builder.create(loc, bCol, bColCur); + Value tailFlag = rewriter.create( + loc, arith::CmpIPredicate::sge, tailLen, step); + // If the current column does not reach the tail. + auto ifOp = builder.create( + loc, tailFlag, + [&](OpBuilder &builder, Location loc) { + Value aVec = builder.create( + loc, vectorTy, A, AVectorMap, + ValueRange{ivs[0], ivs[1], iv}); + Value bVec = builder.create( + loc, vectorTy, B, BVectorMap, + ValueRange{ivs[1], ivs[1], iv}); + Value resvec = + builder.create(loc, aVec, bVec); + Value newAcc = + builder.create(loc, acc, resvec); + builder.create(loc, newAcc); + }, + // The else branch + [&](OpBuilder &builder, Location loc) { + // Create mask according to the tail. + Value maskVec = builder.create( + loc, vectorMaskTy, tailLen); + Value ColIdxTail = + builder.create(loc, iv, step); + + Value aVecTail = builder.create( + loc, vectorTy, A, ValueRange{ivs[0], ColIdxTail}, + maskVec, passthruVec); + + Value bVecTail = builder.create( + loc, vectorTy, B, ValueRange{ivs[1], ColIdxTail}, + maskVec, passthruVec); + + Value resvec = builder.create( + loc, aVecTail, bVecTail); + Value newAcc = + builder.create(loc, acc, resvec); + builder.create(loc, newAcc); + }); + builder.create(loc, ifOp.getResult(0)); + }); + + Value load = builder.create( + loc, C, ValueRange{ivs[0], ivs[1]}); + Value reduction = builder.create( + loc, CombiningKind::ADD, innerLoop->getResult(0), load, + arith::FastMathFlags::reassoc); + builder.create(loc, reduction, C, + ValueRange{ivs[0], ivs[1]}); }); - // clang-format on - rewriter.eraseOp(op); - return success(); - } + + rewriter.eraseOp(op); + return success(); + } + private: - int64_t vecSize; + int64_t vecSize; }; } // end anonymous namespace @@ -170,41 +180,45 @@ class MatMulTransposeBVecPattern : public ConversionPattern{ // MatMulVectorizationPass //===----------------------------------------------------------------------===// -namespace{ - class MatMulTransposeBVecPass - :public PassWrapper>{ +namespace { +class MatMulTransposeBVecPass + : public PassWrapper> { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulTransposeBVecPass) - StringRef getArgument() const final{ return "matmul-transpose-b-vectorization"; } - StringRef getDescription() const final { return "vectorize linalg MatmulTransposeBOp"; } - MatMulTransposeBVecPass() = default; - MatMulTransposeBVecPass(const MatMulTransposeBVecPass &) {} - void runOnOperation() override; - void getDependentDialects(DialectRegistry ®istry) const override{ - registry.insert(); - } - Option vecSize{*this,"vec-size", - llvm::cl::desc("The size of vectorization"), - llvm::cl::init(32)}; - + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulTransposeBVecPass) + StringRef getArgument() const final { + return "matmul-transpose-b-vectorization"; + } + StringRef getDescription() const final { + return "vectorize linalg MatmulTransposeBOp"; + } + MatMulTransposeBVecPass() = default; + MatMulTransposeBVecPass(const MatMulTransposeBVecPass &) {} + void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + Option vecSize{*this, "vec-size", + llvm::cl::desc("The size of vectorization"), + llvm::cl::init(32)}; }; -} +} // namespace -void MatMulTransposeBVecPass::runOnOperation(){ - MLIRContext *context = &getContext(); - ModuleOp module = getOperation(); +void MatMulTransposeBVecPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); - ConversionTarget target(*context); - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); - RewritePatternSet patterns(context); - patterns.add(context,vecSize); + RewritePatternSet patterns(context); + patterns.add(context, vecSize); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/tests/Conversion/matmul-transpose-b-vectorization.mlir b/tests/Conversion/matmul-transpose-b-vectorization.mlir new file mode 100644 index 0000000000..391c7ce84d --- /dev/null +++ b/tests/Conversion/matmul-transpose-b-vectorization.mlir @@ -0,0 +1,101 @@ +// RUN: buddy-opt %s \ +// RUN: -matmul-transpose-b-vectorization="vec-size=64" \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module{ + func.func private @printMemrefF32(memref<*xf32>) + func.func private @printMemrefF64(memref<*xf64>) + + func.func @matmul_f32(%a : memref, %b : memref, %c : memref) { + linalg.matmul_transpose_b + ins(%a, %b: memref, memref) + outs(%c:memref) + return + } + + func.func @matmul_f64(%a : memref, %b : memref, %c : memref) { + linalg.matmul_transpose_b + ins(%a, %b: memref, memref) + outs(%c:memref) + return + } + + func.func @main(){ + // Set up dims. + %cM = arith.constant 4 : index + %cN = arith.constant 4 : index + %cK = arith.constant 4 : index + + //-------------------------------------------------------------------------- + // Test f32 as element type. + //-------------------------------------------------------------------------- + + // Set Init Value. + %cf1_32 = arith.constant 1.0 : f32 + + %A_f32 = memref.alloc(%cM, %cK) : memref + %B_f32 = memref.alloc(%cK, %cN) : memref + %C_f32 = memref.alloc(%cM, %cN) : memref + + linalg.fill ins(%cf1_32 : f32) outs(%A_f32 : memref) + linalg.fill ins(%cf1_32 : f32) outs(%B_f32 : memref) + linalg.fill ins(%cf1_32 : f32) outs(%C_f32 : memref) + + call @matmul_f32(%A_f32, %B_f32, %C_f32) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C_f32 = memref.cast %C_f32 : memref to memref<*xf32> + call @printMemrefF32(%print_C_f32) : (memref<*xf32>) -> () + + memref.dealloc %C_f32 : memref + memref.dealloc %B_f32 : memref + memref.dealloc %A_f32 : memref + + //-------------------------------------------------------------------------- + // Test f64 as element type. + //-------------------------------------------------------------------------- + + // Set Init Value. + %cf1_64 = arith.constant 1.0 : f64 + + %A_f64 = memref.alloc(%cM, %cK) : memref + %B_f64 = memref.alloc(%cK, %cN) : memref + %C_f64 = memref.alloc(%cM, %cN) : memref + + linalg.fill ins(%cf1_64 : f64) outs(%A_f64 : memref) + linalg.fill ins(%cf1_64 : f64) outs(%B_f64 : memref) + linalg.fill ins(%cf1_64 : f64) outs(%C_f64 : memref) + + call @matmul_f64(%A_f64, %B_f64, %C_f64) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C_f64 = memref.cast %C_f64 : memref to memref<*xf64> + call @printMemrefF64(%print_C_f64) : (memref<*xf64>) -> () + + memref.dealloc %C_f64 : memref + memref.dealloc %B_f64 : memref + memref.dealloc %A_f64 : memref + + return + } +}