Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aievec] Make transpose lowering more progressive #1574

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2979,6 +2979,67 @@ struct LowerVectorTransposeOpToAIEVecShuffleOpPattern
}
};

// Convert a `vector.flat_transpose` op to an `aievec.shuffle` op for AIEml.
struct LowerVectorFlatTransposeOpToAIEVecShuffleOpPattern
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably remove the old pattern LowerVectorTransposeOpToAIEVecShuffleOpPattern in the future, to reduce the maintenance cost.

: OpConversionPattern<vector::FlatTransposeOp> {
using OpConversionPattern::OpConversionPattern;
// Replaces a `vector.flat_transpose` with an `aievec.shuffle`, selecting the
// shuffle mode from the element bit width and the op's `rows`/`columns`
// attributes. Fails to match unless the result vector is exactly 512 bits wide
// and its elements are 8, 16 or 32 bits wide.
LogicalResult
matchAndRewrite(vector::FlatTransposeOp transpOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rows = transpOp.getRows();
auto cols = transpOp.getColumns();
auto resVecTy = cast<VectorType>(transpOp.getResult().getType());
auto elemTyBitWidth = resVecTy.getElementTypeBitWidth();
// Total vector width in bits, derived from the matrix shape attributes.
auto vBitWidth = elemTyBitWidth * rows * cols;

if (vBitWidth != 512)
return failure();

if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
return failure();

// Start from the 32-bit mode; the 8- and 16-bit branches below override it.
auto shuffleMode = aievec::ShuffleMode::T32_4X4;
if (elemTyBitWidth == 8) {
// rows * cols is pinned by the 512-bit check above, so `rows` alone
// identifies the matrix shape.
switch (rows) {
case 4:
shuffleMode = aievec::ShuffleMode::T8_4X16;
break;
case 8:
shuffleMode = aievec::ShuffleMode::T8_8X8;
break;
case 16:
shuffleMode = aievec::ShuffleMode::T8_16X4;
break;
default:
return failure();
}
} else if (elemTyBitWidth == 16) {
switch (rows) {
case 2:
shuffleMode = aievec::ShuffleMode::T16_2X16;
break;
case 4:
shuffleMode = aievec::ShuffleMode::T16_4X8;
break;
case 8:
shuffleMode = aievec::ShuffleMode::T16_8X4;
break;
case 16:
shuffleMode = aievec::ShuffleMode::T16_16X2;
break;
default:
return failure();
}
// 32-bit elements: 512 bits hold 16 elements, so cols == 4 means a 4x4
// matrix, which is the only 32-bit shape accepted here.
} else if (cols != 4)
return failure();

// NOTE(review): the `nullptr` argument presumably leaves an optional second
// shuffle operand unset -- confirm against the aievec.shuffle op definition.
rewriter.replaceOpWithNewOp<aievec::ShuffleOp>(
transpOp, resVecTy, adaptor.getMatrix(), nullptr, shuffleMode);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

operand name is matrix, interesting.


return success();
}
};

//===----------------------------------------------------------------------===//
// Pattern collection
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3061,7 +3122,8 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
ConvertMulAddToAIEVecFMAElemOpPattern,
ConvertVectorFMAOpToAIEVecFMAElemOpPattern,
LowerVectorExtractStridedSliceOpAIE2Pattern,
LowerVectorTransposeOpToAIEVecShuffleOpPattern
LowerVectorTransposeOpToAIEVecShuffleOpPattern,
LowerVectorFlatTransposeOpToAIEVecShuffleOpPattern
>(patterns.getContext());
patterns.add<LowerVectorContractionOpToAIEVecMatMulPattern
>(patterns.getContext(), backend == TargetBackend::CPP);
Expand Down
77 changes: 76 additions & 1 deletion lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,54 @@ struct ExtractTransposeFromContractionOp
}
};

// This pattern flattens a `vector.transpose` operation for shapes that can be
// handled by basic AIE shuffle ops.
struct FlattenVectorTransposeOpPattern
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a separate lit test for this pattern? Thanks!

: public OpConversionPattern<vector::TransposeOp> {
using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

// Matches a `vector.transpose` that only swaps the two innermost dimensions of
// a 512-bit vector with 8-, 16- or 32-bit elements (all leading result
// dimensions being 1) and rewrites it as:
//   vector.shape_cast -> vector.flat_transpose -> vector.shape_cast
// Returns failure() without touching the IR for any other transpose.
LogicalResult
matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto resTy = transpOp.getResultVectorType();
  auto resShape = resTy.getShape();
  auto elemTyBitWidth = resTy.getElementTypeBitWidth();
  auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
                                   elemTyBitWidth, std::multiplies<>());
  if (vBitWidth != 512)
    return failure();

  if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
    return failure();

  // Swapping the two innermost dimensions needs at least two of them. The
  // sizes are converted to signed *before* subtracting so that `rank - 2`
  // cannot wrap around for low-rank vectors.
  ArrayRef<int64_t> perm = transpOp.getPermutation();
  const auto rank = static_cast<int64_t>(resShape.size());
  const auto permRank = static_cast<int64_t>(perm.size());
  if (rank < 2 || permRank < 2)
    return failure();

  // Verify leading dimensions are all 1.
  for (int64_t i = 0; i < rank - 2; ++i)
    if (resShape[i] != 1)
      return failure();

  // Only permutations that swap the 2 innermost dimensions are supported:
  // every leading index must map to itself and the last index must map to
  // the second-to-last one.
  for (int64_t i = 0; i < permRank - 2; ++i)
    if (perm[i] != i)
      return failure();
  if (perm.back() != permRank - 2)
    return failure();

  // The flattened type holds all 512 bits in a single dimension.
  auto flatVecTy =
      VectorType::get({512 / elemTyBitWidth}, resTy.getElementType());
  auto loc = transpOp.getLoc();
  auto flatInput = rewriter.create<vector::ShapeCastOp>(loc, flatVecTy,
                                                        adaptor.getVector());
  // The result's innermost dimension is passed first and its second
  // innermost dimension second, i.e. the two sizes in source-matrix order.
  auto flatTranspOp = rewriter.create<vector::FlatTransposeOp>(
      loc, flatVecTy, flatInput, static_cast<int32_t>(resShape.back()),
      static_cast<int32_t>(resShape[rank - 2]));
  // Restore the original (transposed) multi-dimensional result type.
  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resTy,
                                                   flatTranspOp);

  return success();
}
};

//============================================================================//
//============ AIE2 canonicalization conversion patterns ===============//
//============================================================================//
Expand Down Expand Up @@ -690,6 +738,32 @@ static void configureAIE2CanonicalizeLegalizations(ConversionTarget &target,
[](vector::ContractionOp op) {
return !isGemmBTransposedContractionOp(op);
});
// A `vector.transpose` is illegal (i.e. must be rewritten) only when it swaps
// the two innermost dimensions of a 512-bit vector with 8-, 16- or 32-bit
// elements whose leading result dimensions are all 1. These are the same
// conditions FlattenVectorTransposeOpPattern matches on. Every other
// transpose is reported as legal and left alone.
target.addDynamicallyLegalOp<vector::TransposeOp>([](vector::TransposeOp op) {
  auto resTy = op.getResultVectorType();
  auto resShape = resTy.getShape();
  auto elemTyBitWidth = resTy.getElementTypeBitWidth();
  auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
                                   elemTyBitWidth, std::multiplies<>());
  if (vBitWidth != 512)
    return true;

  if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
    return true;

  // Swapping the two innermost dimensions needs at least two of them. The
  // sizes are converted to signed *before* subtracting so that `rank - 2`
  // cannot wrap around for low-rank vectors.
  ArrayRef<int64_t> perm = op.getPermutation();
  const auto rank = static_cast<int64_t>(resShape.size());
  const auto permRank = static_cast<int64_t>(perm.size());
  if (rank < 2 || permRank < 2)
    return true;

  // Verify leading dimensions are all 1.
  for (int64_t i = 0; i < rank - 2; ++i)
    if (resShape[i] != 1)
      return true;

  // Only permutations that swap the 2 innermost dimensions are supported.
  for (int64_t i = 0; i < permRank - 2; ++i)
    if (perm[i] != i)
      return true;
  if (perm.back() != permRank - 2)
    return true;
  return false;
});
}

static void
Expand All @@ -699,7 +773,8 @@ populateAIE2CanonicalizeConversionPatterns(RewritePatternSet &patterns,
256);
patterns
.add<ExtractTransposeFromContractionOp, FlattenMultDimTransferReadPattern,
FlattenMultDimTransferWritePattern>(patterns.getContext());
FlattenMultDimTransferWritePattern, FlattenVectorTransposeOpPattern>(
patterns.getContext());
}

//============================================================================//
Expand Down
50 changes: 45 additions & 5 deletions test/dialect/AIEVec/precanonicalization-aieml.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,13 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
%B : vector<1x1x4x8xbf16>,
%C : vector<1x1x4x4xf32>)
-> vector<1x1x4x4xf32> {
// CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]]
// CHECK-SAME: : vector<1x1x4x8xbf16> to vector<32xbf16>
// CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]]
// CHECK-SAME: {columns = 8 : i32, rows = 4 : i32}
// CHECK-SAME: : vector<32xbf16> -> vector<32xbf16>
// CHECK: %[[TRB:.*]] = vector.shape_cast %[[FVBT]]
// CHECK-SAME: : vector<32xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[RES:.*]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction",
Expand Down Expand Up @@ -144,8 +149,13 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
-> vector<1x1x4x4xf32> {
// CHECK: %[[LHS:.*]] = arith.extf %[[VA]] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
// CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]]
// CHECK-SAME: : vector<1x1x4x8xbf16> to vector<32xbf16>
// CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]]
// CHECK-SAME: {columns = 8 : i32, rows = 4 : i32}
// CHECK-SAME: : vector<32xbf16> -> vector<32xbf16>
// CHECK: %[[TRB:.*]] = vector.shape_cast %[[FVBT]]
// CHECK-SAME: : vector<32xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[RHS:.*]] = arith.extf %[[TRB]] :
// CHECK-SAME: vector<1x1x8x4xbf16> to vector<1x1x8x4xf32>
// CHECK: %[[RES:.*]] = vector.contract {
Expand All @@ -165,4 +175,34 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
kind = #vector.kind<add>} %lhs, %rhs, %C :
vector<1x1x4x8xf32>, vector<1x1x4x8xf32> into vector<1x1x4x4xf32>
return %res : vector<1x1x4x4xf32>
}
}

//
// -----
//

// Verifies that 2-D transposes of 512-bit vectors (32 x bf16) are rewritten
// into a shape_cast / flat_transpose / shape_cast sequence, for both the 4x8
// and the 8x4 source shapes.
// CHECK-LABEL: func.func @vector_transpose(
// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<4x8xbf16>,
// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<8x4xbf16>)
func.func @vector_transpose(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>)
-> (vector<4x8xbf16>, vector<8x4xbf16>) {
// CHECK: %[[FVA:.*]] = vector.shape_cast %[[VA]]
// CHECK-SAME: : vector<4x8xbf16> to vector<32xbf16>
// CHECK: %[[FVAT:.*]] = vector.flat_transpose %[[FVA]]
// CHECK-SAME: {columns = 8 : i32, rows = 4 : i32}
// CHECK-SAME: : vector<32xbf16> -> vector<32xbf16>
// CHECK: %[[VAT:.*]] = vector.shape_cast %[[FVAT]]
// CHECK-SAME: : vector<32xbf16> to vector<8x4xbf16>
%tA = vector.transpose %A, [1, 0] : vector<4x8xbf16> to vector<8x4xbf16>
// CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]]
// CHECK-SAME: : vector<8x4xbf16> to vector<32xbf16>
// CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]]
// CHECK-SAME: {columns = 4 : i32, rows = 8 : i32}
// CHECK-SAME: : vector<32xbf16> -> vector<32xbf16>
// CHECK: %[[VBT:.*]] = vector.shape_cast %[[FVBT]]
// CHECK-SAME: : vector<32xbf16> to vector<4x8xbf16>
%tB = vector.transpose %B, [1, 0] : vector<8x4xbf16> to vector<4x8xbf16>
// Both transposed values feed an add whose result is returned, so each
// transpose has a user.
%AtB = arith.addf %A, %tB : vector<4x8xbf16>
%tAB = arith.addf %tA, %B : vector<8x4xbf16>
return %AtB, %tAB : vector<4x8xbf16>, vector<8x4xbf16>
}
Loading