diff --git a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp index 345de4c1de..2f81414a33 100644 --- a/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp +++ b/midend/lib/Conversion/MatMulOptimization/MatMulTransposeBVec.cpp @@ -19,12 +19,16 @@ //===----------------------------------------------------------------------===// #include +#include #include #include +#include +#include #include #include #include #include +#include #include #include "Utils/Utils.h" @@ -37,132 +41,138 @@ using namespace vector; //===----------------------------------------------------------------------===// namespace { -class MatMulTransposeBVecPattern : public ConversionPattern{ +class MatMulTransposeBVecPattern : public ConversionPattern { public: - explicit MatMulTransposeBVecPattern(MLIRContext *context,int64_t vecSizeparam) - : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(),1,context){ - vecSize = vecSizeparam; - } - - LogicalResult - matchAndRewrite(Operation *op,ArrayRef /*operands*/, - ConversionPatternRewriter &rewriter) const override{ - auto loc = op->getLoc(); - auto ctx = op->getContext(); - // Get input A, B, C. - Value A = op->getOperand(0); - Value B = op->getOperand(1); - Value C = op->getOperand(2); - - // Get shape of input and output. - ShapedType ATy = A.getType().cast(); - Type eleTy = ATy.getElementType(); - - // the element type for mask vector. 
- IntegerType i1 = IntegerType::get(ctx, 1); - - VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); - VectorType vectorMaskTy = VectorType::get({vecSize}, i1); - - const Value c0 = - rewriter.create(loc, rewriter.getIndexAttr(0)); - const Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); - const Value step = rewriter.create(loc, vecSize); - - const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); - Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); - - const Value aRow = rewriter.create(loc, A, c0); - const Value bRow = rewriter.create(loc, B, c0); - const Value bCol = rewriter.create(loc, B, c1); - - AffineExpr d0; - bindDims(ctx, d0); - AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); - SmallVector lowerBounds(2, c0); - SmallVector uperBounds{aRow, bRow}; - SmallVector steps(2, 1); - // clang-format off - affine::buildAffineLoopNest( - rewriter, loc, lowerBounds, uperBounds, steps, - [&](OpBuilder &builder, Location loc, ValueRange ivs) { - // Create loop based on vector size. - builder.create( - loc, ValueRange{c0}, builder.getDimIdentityMap(), - ValueRange{bCol}, vecTailMap, 1, std::nullopt, - [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, - ValueRange itrArgs) { - AffineExpr a,b,c; - bindDims(ctx, a,b,c); - AffineMap AVectorMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); - // Check tail. - AffineExpr m, n, k; - bindDims(ctx, m, n, k); - AffineMap BVectorMap = AffineMap::get( - /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); - - // Calculate the tail. - Value bColCur = builder.create(loc, iv, step); - Value tailLen = builder.create(loc, bCol, bColCur); - Value tailFlag = rewriter.create( - loc, arith::CmpIPredicate::sge, tailLen, step); - // If the current column does not reach the tail. 
- builder.create(loc, tailFlag, - [&](OpBuilder &builder, Location loc) { - Value aVec = builder.create( - loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - Value bVec = builder.create( - loc, vectorTy, B, BVectorMap, ValueRange{ivs[1], ivs[1], iv}); - Value resvec = builder.create(loc,aVec,bVec); - Value res1 = builder.create( - loc,mlir::vector::CombiningKind::ADD,resvec); - Value res2 = builder.create( - loc, C, ValueRange{ivs[0], ivs[1]}); - Value sum = builder.create(loc, res1, res2); - builder.create(loc, sum, - C, ValueRange{ivs[0], ivs[1]}); - builder.create(loc); - }, - // The else branch - [&](OpBuilder &builder, Location loc) { - // TODO: remove this value and operation? - // Value aVec = builder.create( - // loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - builder.create( - loc, vectorTy, A, AVectorMap, ValueRange{ivs[0], ivs[1], iv}); - // Create mask according to the tail. - Value maskVec = builder.create( - loc, vectorMaskTy, tailLen); - Value ColIdxTail = builder.create(loc, iv, step); - - Value aVecTail = builder.create( - loc, vectorTy, A, ValueRange{ivs[0], ColIdxTail}, - maskVec, passthruVec); - - Value bVecTail = builder.create( - loc, vectorTy, B, ValueRange{ivs[1], ColIdxTail}, - maskVec, passthruVec); - - Value resvec = builder.create(loc,aVecTail,bVecTail); - Value res1 = builder.create( - loc,mlir::vector::CombiningKind::ADD,resvec); - Value res2 = builder.create( - loc, C, ValueRange{ivs[0], ivs[1]}); - Value sum = builder.create(loc, res1, res2); - builder.create(loc, sum, C, ValueRange{ivs[0], ivs[1]}); - builder.create(loc); - }); - builder.create(loc); - }); + explicit MatMulTransposeBVecPattern(MLIRContext *context, + int64_t vecSizeparam) + : ConversionPattern(linalg::MatmulTransposeBOp::getOperationName(), 1, + context) { + vecSize = vecSizeparam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + auto loc = 
op->getLoc(); + auto ctx = op->getContext(); + // Get input A, B, C. + Value A = op->getOperand(0); + Value B = op->getOperand(1); + Value C = op->getOperand(2); + + // Get shape of input and output. + ShapedType ATy = A.getType().cast(); + Type eleTy = ATy.getElementType(); + + // the element type for mask vector. + IntegerType i1 = IntegerType::get(ctx, 1); + + VectorType vectorTy = mlir::VectorType::get({vecSize}, eleTy); + VectorType vectorMaskTy = VectorType::get({vecSize}, i1); + + const Value c0 = + rewriter.create(loc, rewriter.getIndexAttr(0)); + const Value c1 = + rewriter.create(loc, rewriter.getIndexAttr(1)); + const Value step = rewriter.create(loc, vecSize); + + const Value c0Ele = buddy::insertZeroConstantOp(ctx, rewriter, loc, eleTy); + Value passthruVec = rewriter.create(loc, vectorTy, c0Ele); + + const Value aRow = rewriter.create(loc, A, c0); + const Value bRow = rewriter.create(loc, B, c0); + const Value bCol = rewriter.create(loc, B, c1); + + AffineExpr d0; + bindDims(ctx, d0); + AffineMap vecTailMap = AffineMap::get(1, 0, {d0.ceilDiv(vecSize)}, ctx); + SmallVector lowerBounds(2, c0); + SmallVector uperBounds{aRow, bRow}; + SmallVector steps(2, 1); + + affine::buildAffineLoopNest( + rewriter, loc, lowerBounds, uperBounds, steps, + [&](OpBuilder &builder, Location loc, ValueRange ivs) { + // Create loop based on vector size. + auto innerLoop = builder.create( + loc, ValueRange{c0}, builder.getDimIdentityMap(), + ValueRange{bCol}, vecTailMap, 1, ValueRange{passthruVec}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, + ValueRange itrArgs) { + Value acc = itrArgs[0]; + + AffineExpr a, b, c; + bindDims(ctx, a, b, c); + AffineMap AVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {a, c * vecSize}, ctx); + // Check tail. + AffineExpr m, n, k; + bindDims(ctx, m, n, k); + AffineMap BVectorMap = AffineMap::get( + /*dimCount=*/3, /*symbolCount=*/0, {m, k * vecSize}, ctx); + + // Calculate the tail. 
+ Value bColCur = builder.create(loc, iv, step); + Value tailLen = + builder.create(loc, bCol, bColCur); + Value tailFlag = rewriter.create( + loc, arith::CmpIPredicate::sge, tailLen, step); + // If the current column does not reach the tail. + auto ifOp = builder.create( + loc, tailFlag, + [&](OpBuilder &builder, Location loc) { + Value aVec = builder.create( + loc, vectorTy, A, AVectorMap, + ValueRange{ivs[0], ivs[1], iv}); + Value bVec = builder.create( + loc, vectorTy, B, BVectorMap, + ValueRange{ivs[1], ivs[1], iv}); + Value resvec = + builder.create(loc, aVec, bVec); + Value newAcc = + builder.create(loc, acc, resvec); + builder.create(loc, newAcc); + }, + // The else branch + [&](OpBuilder &builder, Location loc) { + // Create mask according to the tail. + Value maskVec = builder.create( + loc, vectorMaskTy, tailLen); + Value ColIdxTail = + builder.create(loc, iv, step); + + Value aVecTail = builder.create( + loc, vectorTy, A, ValueRange{ivs[0], ColIdxTail}, + maskVec, passthruVec); + + Value bVecTail = builder.create( + loc, vectorTy, B, ValueRange{ivs[1], ColIdxTail}, + maskVec, passthruVec); + + Value resvec = builder.create( + loc, aVecTail, bVecTail); + Value newAcc = + builder.create(loc, acc, resvec); + builder.create(loc, newAcc); + }); + builder.create(loc, ifOp.getResult(0)); + }); + + Value load = builder.create( + loc, C, ValueRange{ivs[0], ivs[1]}); + Value reduction = builder.create( + loc, CombiningKind::ADD, innerLoop->getResult(0), load, + arith::FastMathFlags::reassoc); + builder.create(loc, reduction, C, + ValueRange{ivs[0], ivs[1]}); }); - // clang-format on - rewriter.eraseOp(op); - return success(); - } + + rewriter.eraseOp(op); + return success(); + } + private: - int64_t vecSize; + int64_t vecSize; }; } // end anonymous namespace @@ -170,41 +180,45 @@ class MatMulTransposeBVecPattern : public ConversionPattern{ // MatMulVectorizationPass //===----------------------------------------------------------------------===// -namespace{ - 
class MatMulTransposeBVecPass - :public PassWrapper>{ +namespace { +class MatMulTransposeBVecPass + : public PassWrapper> { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulTransposeBVecPass) - StringRef getArgument() const final{ return "matmul-transpose-b-vectorization"; } - StringRef getDescription() const final { return "vectorize linalg MatmulTransposeBOp"; } - MatMulTransposeBVecPass() = default; - MatMulTransposeBVecPass(const MatMulTransposeBVecPass &) {} - void runOnOperation() override; - void getDependentDialects(DialectRegistry &registry) const override{ - registry.insert(); - } - Option vecSize{*this,"vec-size", - llvm::cl::desc("The size of vectorization"), - llvm::cl::init(32)}; - + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MatMulTransposeBVecPass) + StringRef getArgument() const final { + return "matmul-transpose-b-vectorization"; + } + StringRef getDescription() const final { + return "vectorize linalg MatmulTransposeBOp"; + } + MatMulTransposeBVecPass() = default; + MatMulTransposeBVecPass(const MatMulTransposeBVecPass &) {} + void runOnOperation() override; + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + Option vecSize{*this, "vec-size", + llvm::cl::desc("The size of vectorization"), + llvm::cl::init(32)}; }; -} +} // namespace -void MatMulTransposeBVecPass::runOnOperation(){ - MLIRContext *context = &getContext(); - ModuleOp module = getOperation(); +void MatMulTransposeBVecPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); - ConversionTarget target(*context); - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); - RewritePatternSet patterns(context); - patterns.add(context,vecSize); + RewritePatternSet patterns(context); + patterns.add(context, vecSize); - if (failed(applyPartialConversion(module, 
target, std::move(patterns)))) signalPassFailure(); } diff --git a/tests/Conversion/matmul-transpose-b-vectorization.mlir b/tests/Conversion/matmul-transpose-b-vectorization.mlir new file mode 100644 index 0000000000..391c7ce84d --- /dev/null +++ b/tests/Conversion/matmul-transpose-b-vectorization.mlir @@ -0,0 +1,101 @@ +// RUN: buddy-opt %s \ +// RUN: -matmul-transpose-b-vectorization="vec-size=64" \ +// RUN: -convert-linalg-to-loops -lower-affine -convert-scf-to-cf \ +// RUN: -convert-vector-to-llvm -finalize-memref-to-llvm -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + +module{ + func.func private @printMemrefF32(memref<*xf32>) + func.func private @printMemrefF64(memref<*xf64>) + + func.func @matmul_f32(%a : memref, %b : memref, %c : memref) { + linalg.matmul_transpose_b + ins(%a, %b: memref, memref) + outs(%c:memref) + return + } + + func.func @matmul_f64(%a : memref, %b : memref, %c : memref) { + linalg.matmul_transpose_b + ins(%a, %b: memref, memref) + outs(%c:memref) + return + } + + func.func @main(){ + // Set up dims. + %cM = arith.constant 4 : index + %cN = arith.constant 4 : index + %cK = arith.constant 4 : index + + //-------------------------------------------------------------------------- + // Test f32 as element type. + //-------------------------------------------------------------------------- + + // Set Init Value. 
+ %cf1_32 = arith.constant 1.0 : f32 + + %A_f32 = memref.alloc(%cM, %cK) : memref + %B_f32 = memref.alloc(%cK, %cN) : memref + %C_f32 = memref.alloc(%cM, %cN) : memref + + linalg.fill ins(%cf1_32 : f32) outs(%A_f32 : memref) + linalg.fill ins(%cf1_32 : f32) outs(%B_f32 : memref) + linalg.fill ins(%cf1_32 : f32) outs(%C_f32 : memref) + + call @matmul_f32(%A_f32, %B_f32, %C_f32) : (memref, memref, memref) -> () + + // Print output. + // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C_f32 = memref.cast %C_f32 : memref to memref<*xf32> + call @printMemrefF32(%print_C_f32) : (memref<*xf32>) -> () + + memref.dealloc %C_f32 : memref + memref.dealloc %B_f32 : memref + memref.dealloc %A_f32 : memref + + //-------------------------------------------------------------------------- + // Test f64 as element type. + //-------------------------------------------------------------------------- + + // Set Init Value. + %cf1_64 = arith.constant 1.0 : f64 + + %A_f64 = memref.alloc(%cM, %cK) : memref + %B_f64 = memref.alloc(%cK, %cN) : memref + %C_f64 = memref.alloc(%cM, %cN) : memref + + linalg.fill ins(%cf1_64 : f64) outs(%A_f64 : memref) + linalg.fill ins(%cf1_64 : f64) outs(%B_f64 : memref) + linalg.fill ins(%cf1_64 : f64) outs(%C_f64 : memref) + + call @matmul_f64(%A_f64, %B_f64, %C_f64) : (memref, memref, memref) -> () + + // Print output. 
+ // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5], + // CHECK-NEXT: [5, 5, 5, 5] + // CHECK-SAME: ] + %print_C_f64 = memref.cast %C_f64 : memref to memref<*xf64> + call @printMemrefF64(%print_C_f64) : (memref<*xf64>) -> () + + memref.dealloc %C_f64 : memref + memref.dealloc %B_f64 : memref + memref.dealloc %A_f64 : memref + + return + } +}