From 7aecaa59e6e55c5d72ccdf5b5f854267a74526ff Mon Sep 17 00:00:00 2001 From: Javier Setoain Date: Tue, 18 Jun 2024 16:42:53 -0600 Subject: [PATCH] [aievec] Make transpose lowering more progressive --- .../Transforms/VectorToAIEVecConversions.cpp | 64 ++++++++++++++- .../Transforms/VectorToVectorConversions.cpp | 77 ++++++++++++++++++- .../AIEVec/precanonicalization-aieml.mlir | 50 ++++++++++-- 3 files changed, 184 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp index 438a0350ca..f3332ec563 100644 --- a/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp +++ b/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp @@ -2979,6 +2979,67 @@ struct LowerVectorTransposeOpToAIEVecShuffleOpPattern } }; +// Convert a `vector.flat_transpose` op to an `aievec.shuffle` op for AIEml. +struct LowerVectorFlatTransposeOpToAIEVecShuffleOpPattern + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(vector::FlatTransposeOp transpOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto rows = transpOp.getRows(); + auto cols = transpOp.getColumns(); + auto resVecTy = cast(transpOp.getResult().getType()); + auto elemTyBitWidth = resVecTy.getElementTypeBitWidth(); + auto vBitWidth = elemTyBitWidth * rows * cols; + + if (vBitWidth != 512) + return failure(); + + if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32) + return failure(); + + auto shuffleMode = aievec::ShuffleMode::T32_4X4; + if (elemTyBitWidth == 8) { + switch (rows) { + case 4: + shuffleMode = aievec::ShuffleMode::T8_4X16; + break; + case 8: + shuffleMode = aievec::ShuffleMode::T8_8X8; + break; + case 16: + shuffleMode = aievec::ShuffleMode::T8_16X4; + break; + default: + return failure(); + } + } else if (elemTyBitWidth == 16) { + switch (rows) { + case 2: + shuffleMode = aievec::ShuffleMode::T16_2X16; + break; + case 4: + shuffleMode = aievec::ShuffleMode::T16_4X8; + break; + case 8: + shuffleMode = aievec::ShuffleMode::T16_8X4; + break; + case 16: + shuffleMode = aievec::ShuffleMode::T16_16X2; + break; + default: + return failure(); + } + } else if (cols != 4) + return failure(); + + rewriter.replaceOpWithNewOp( + transpOp, resVecTy, adaptor.getMatrix(), nullptr, shuffleMode); + + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pattern collection //===----------------------------------------------------------------------===// @@ -3061,7 +3122,8 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns, ConvertMulAddToAIEVecFMAElemOpPattern, ConvertVectorFMAOpToAIEVecFMAElemOpPattern, LowerVectorExtractStridedSliceOpAIE2Pattern, - LowerVectorTransposeOpToAIEVecShuffleOpPattern + LowerVectorTransposeOpToAIEVecShuffleOpPattern, + LowerVectorFlatTransposeOpToAIEVecShuffleOpPattern >(patterns.getContext()); patterns.add(patterns.getContext(), backend == TargetBackend::CPP); diff --git a/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp b/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp index e12ebe189a..7de31b9fba 100644 --- a/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp +++ b/lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp @@ -627,6 +627,54 @@ struct ExtractTransposeFromContractionOp } }; +// This pattern flattens a `vector.transpose` operation for shapes that can be +// handled by basic AIE shuffle ops. +struct FlattenVectorTransposeOpPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto resTy = transpOp.getResultVectorType(); + auto resShape = resTy.getShape(); + auto elemTyBitWidth = resTy.getElementTypeBitWidth(); + auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(), + elemTyBitWidth, std::multiplies<>()); + if (vBitWidth != 512) + return failure(); + + if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32) + return failure(); + + // Verify leading dimensions are all 1. + for (int64_t i = 0; i < static_cast(resShape.size() - 2); ++i) + if (resShape[i] != 1) + return failure(); + + // Only permutation of the 2 innermost dimensions are supported. + ArrayRef perm = transpOp.getPermutation(); + for (int64_t i = 0; i < static_cast(perm.size() - 2); ++i) + if (perm[i] != i) + return failure(); + if (perm.back() != static_cast(perm.size() - 2)) + return failure(); + + auto flatVecTy = + VectorType::get({512 / elemTyBitWidth}, resTy.getElementType()); + auto loc = transpOp.getLoc(); + auto flatInput = rewriter.create(loc, flatVecTy, + adaptor.getVector()); + auto flatTranspOp = rewriter.create( + loc, flatVecTy, flatInput, static_cast(resShape.back()), + static_cast(resShape[resShape.size() - 2])); + rewriter.replaceOpWithNewOp(transpOp, resTy, + flatTranspOp); + + return success(); + } +}; + //============================================================================// //============ AIE2 canonicalization conversion patterns ===============// //============================================================================// @@ -690,6 +738,32 @@ static void configureAIE2CanonicalizeLegalizations(ConversionTarget &target, [](vector::ContractionOp op) { return !isGemmBTransposedContractionOp(op); }); + target.addDynamicallyLegalOp([](vector::TransposeOp op) { + auto resTy = op.getResultVectorType(); + auto resShape = resTy.getShape(); + auto elemTyBitWidth = resTy.getElementTypeBitWidth(); + auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(), + elemTyBitWidth, std::multiplies<>()); + if (vBitWidth != 512) + return true; + + if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32) + return true; + + // Verify leading dimensions are all 1. + for (int64_t i = 0; i < static_cast(resShape.size() - 2); ++i) + if (resShape[i] != 1) + return true; + + // Only permutation of the 2 innermost dimensions are supported. + ArrayRef perm = op.getPermutation(); + for (int64_t i = 0; i < static_cast(perm.size() - 2); ++i) + if (perm[i] != i) + return true; + if (perm.back() != static_cast(perm.size() - 2)) + return true; + return false; + }); } static void @@ -699,7 +773,8 @@ populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns, 256); patterns .add(patterns.getContext()); + FlattenMultDimTransferWritePattern, FlattenVectorTransposeOpPattern>( + patterns.getContext()); } //============================================================================// diff --git a/test/dialect/AIEVec/precanonicalization-aieml.mlir b/test/dialect/AIEVec/precanonicalization-aieml.mlir index fbc79fd337..60a990bb5e 100644 --- a/test/dialect/AIEVec/precanonicalization-aieml.mlir +++ b/test/dialect/AIEVec/precanonicalization-aieml.mlir @@ -104,8 +104,13 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>, %B : vector<1x1x4x8xbf16>, %C : vector<1x1x4x4xf32>) -> vector<1x1x4x4xf32> { - // CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] : - // CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16> + // CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]] + // CHECK-SAME: : vector<1x1x4x8xbf16> to vector<32xbf16> + // CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]] + // CHECK-SAME: {columns = 8 : i32, rows = 4 : i32} + // CHECK-SAME: : vector<32xbf16> -> vector<32xbf16> + // CHECK: %[[TRB:.*]] = vector.shape_cast %[[FVBT]] + // CHECK-SAME: : vector<32xbf16> to vector<1x1x8x4xbf16> // CHECK: %[[RES:.*]] = vector.contract { // CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", @@ -144,8 +149,13 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>, -> vector<1x1x4x4xf32> { // CHECK: %[[LHS:.*]] = arith.extf %[[VA]] : // CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x4x8xf32> - // CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] : - // CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16> + // CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]] + // CHECK-SAME: : vector<1x1x4x8xbf16> to vector<32xbf16> + // CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]] + // CHECK-SAME: {columns = 8 : i32, rows = 4 : i32} + // CHECK-SAME: : vector<32xbf16> -> vector<32xbf16> + // CHECK: %[[TRB:.*]] = vector.shape_cast %[[FVBT]] + // CHECK-SAME: : vector<32xbf16> to vector<1x1x8x4xbf16> // CHECK: %[[RHS:.*]] = arith.extf %[[TRB]] : // CHECK-SAME: vector<1x1x8x4xbf16> to vector<1x1x8x4xf32> // CHECK: %[[RES:.*]] = vector.contract { @@ -165,4 +175,34 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>, kind = #vector.kind} %lhs, %rhs, %C : vector<1x1x4x8xf32>, vector<1x1x4x8xf32> into vector<1x1x4x4xf32> return %res : vector<1x1x4x4xf32> -} \ No newline at end of file +} + +// +// ----- +// + +// CHECK-LABEL: func.func @vector_transpose( +// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<4x8xbf16>, +// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<8x4xbf16>) +func.func @vector_transpose(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>) + -> (vector<4x8xbf16>, vector<8x4xbf16>) { + // CHECK: %[[FVA:.*]] = vector.shape_cast %[[VA]] + // CHECK-SAME: : vector<4x8xbf16> to vector<32xbf16> + // CHECK: %[[FVAT:.*]] = vector.flat_transpose %[[FVA]] + // CHECK-SAME: {columns = 8 : i32, rows = 4 : i32} + // CHECK-SAME: : vector<32xbf16> -> vector<32xbf16> + // CHECK: %[[VAT:.*]] = vector.shape_cast %[[FVAT]] + // CHECK-SAME: : vector<32xbf16> to vector<8x4xbf16> + %tA = vector.transpose %A, [1, 0] : vector<4x8xbf16> to vector<8x4xbf16> + // CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]] + // CHECK-SAME: : vector<8x4xbf16> to vector<32xbf16> + // CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]] + // CHECK-SAME: {columns = 4 : i32, rows = 8 : i32} + // CHECK-SAME: : vector<32xbf16> -> vector<32xbf16> + // CHECK: %[[VBT:.*]] = vector.shape_cast %[[FVBT]] + // CHECK-SAME: : vector<32xbf16> to vector<4x8xbf16> + %tB = vector.transpose %B, [1, 0] : vector<8x4xbf16> to vector<4x8xbf16> + %AtB = arith.addf %A, %tB : vector<4x8xbf16> + %tAB = arith.addf %tA, %B : vector<8x4xbf16> + return %AtB, %tAB : vector<4x8xbf16>, vector<8x4xbf16> +}