Skip to content

Commit

Permalink
[aievec] Make transpose lowering more progressive
Browse files Browse the repository at this point in the history
  • Loading branch information
jsetoain committed Jun 20, 2024
1 parent 875648d commit 993340c
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 7 deletions.
64 changes: 63 additions & 1 deletion lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2880,6 +2880,67 @@ struct LowerVectorTransposeOpToAIEVecShuffleOpPattern
}
};

// Convert a `vector.flat_transpose` op to an `aievec.shuffle` op for AIEml.
struct LowerVectorFlatTransposeOpToAIEVecShuffleOpPattern
: OpConversionPattern<vector::FlatTransposeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(vector::FlatTransposeOp transpOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rows = transpOp.getRows();
auto cols = transpOp.getColumns();
auto resVecTy = cast<VectorType>(transpOp.getResult().getType());
auto elemTyBitWidth = resVecTy.getElementTypeBitWidth();
auto vBitWidth = elemTyBitWidth * rows * cols;

if (vBitWidth != 512)
return failure();

if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
return failure();

auto shuffleMode = aievec::ShuffleMode::T32_4X4;
if (elemTyBitWidth == 8) {
switch (rows) {
case 4:
shuffleMode = aievec::ShuffleMode::T8_4X16;
break;
case 8:
shuffleMode = aievec::ShuffleMode::T8_8X8;
break;
case 16:
shuffleMode = aievec::ShuffleMode::T8_16X4;
break;
default:
return failure();
}
} else if (elemTyBitWidth == 16) {
switch (rows) {
case 2:
shuffleMode = aievec::ShuffleMode::T16_2X16;
break;
case 4:
shuffleMode = aievec::ShuffleMode::T16_4X8;
break;
case 8:
shuffleMode = aievec::ShuffleMode::T16_8X4;
break;
case 16:
shuffleMode = aievec::ShuffleMode::T16_16X2;
break;
default:
return failure();
}
} else if (cols != 4)
return failure();

rewriter.replaceOpWithNewOp<aievec::ShuffleOp>(
transpOp, resVecTy, adaptor.getMatrix(), nullptr, shuffleMode);

return success();
}
};

//===----------------------------------------------------------------------===//
// Pattern collection
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2961,7 +3022,8 @@ static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
ConvertBroadcastToAIEBroadcast,
ConvertMulAddToAIEVecFMAElemOpPattern,
LowerVectorExtractStridedSliceOpAIEMLPattern,
LowerVectorTransposeOpToAIEVecShuffleOpPattern
LowerVectorTransposeOpToAIEVecShuffleOpPattern,
LowerVectorFlatTransposeOpToAIEVecShuffleOpPattern
>(patterns.getContext());
patterns.add<LowerVectorContractionOpToAIEVecMatMulPattern
>(patterns.getContext(), backend == TargetBackend::CPP);
Expand Down
77 changes: 76 additions & 1 deletion lib/Dialect/AIEVec/Transforms/VectorToVectorConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,54 @@ struct ExtractTransposeFromContractionOp
}
};

// This pattern flattens a `vector.transpose` operation for shapes that can be
// handled by basic AIE shuffle ops.
struct FlattenVectorTransposeOpPattern
: public OpConversionPattern<vector::TransposeOp> {
using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resTy = transpOp.getResultVectorType();
auto resShape = resTy.getShape();
auto elemTyBitWidth = resTy.getElementTypeBitWidth();
auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
elemTyBitWidth, std::multiplies<>());
if (vBitWidth != 512)
return failure();

if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
return failure();

// Verify leading dimensions are all 1.
for (int64_t i = 0; i < static_cast<int64_t>(resShape.size() - 2); ++i)
if (resShape[i] != 1)
return failure();

// Only permutation of the 2 innermost dimensions are supported.
ArrayRef<int64_t> perm = transpOp.getPermutation();
for (int64_t i = 0; i < static_cast<int64_t>(perm.size() - 2); ++i)
if (perm[i] != i)
return failure();
if (perm.back() != static_cast<int64_t>(perm.size() - 2))
return failure();

auto flatVecTy =
VectorType::get({512 / elemTyBitWidth}, resTy.getElementType());
auto loc = transpOp.getLoc();
auto flatInput = rewriter.create<vector::ShapeCastOp>(loc, flatVecTy,
adaptor.getVector());
auto flatTranspOp = rewriter.create<vector::FlatTransposeOp>(
loc, flatVecTy, flatInput, static_cast<int32_t>(resShape.back()),
static_cast<int32_t>(resShape[resShape.size() - 2]));
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resTy,
flatTranspOp);

return success();
}
};

//============================================================================//
//============ AIEML canonicalization conversion patterns ===============//
//============================================================================//
Expand Down Expand Up @@ -625,6 +673,32 @@ static void configureAIEMLCanonicalizeLegalizations(ConversionTarget &target,
[](vector::ContractionOp op) {
return !isGemmBTransposedContractionOp(op);
});
target.addDynamicallyLegalOp<vector::TransposeOp>([](vector::TransposeOp op) {
auto resTy = op.getResultVectorType();
auto resShape = resTy.getShape();
auto elemTyBitWidth = resTy.getElementTypeBitWidth();
auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
elemTyBitWidth, std::multiplies<>());
if (vBitWidth != 512)
return true;

if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
return true;

// Verify leading dimensions are all 1.
for (int64_t i = 0; i < static_cast<int64_t>(resShape.size() - 2); ++i)
if (resShape[i] != 1)
return true;

// Only permutation of the 2 innermost dimensions are supported.
ArrayRef<int64_t> perm = op.getPermutation();
for (int64_t i = 0; i < static_cast<int64_t>(perm.size() - 2); ++i)
if (perm[i] != i)
return true;
if (perm.back() != static_cast<int64_t>(perm.size() - 2))
return true;
return false;
});
}

static void
Expand All @@ -634,7 +708,8 @@ populateAIEMLCanonicalizeConversionPatterns(RewritePatternSet &patterns,
256);
patterns
.add<ExtractTransposeFromContractionOp, FlattenMultDimTransferReadPattern,
FlattenMultDimTransferWritePattern>(patterns.getContext());
FlattenMultDimTransferWritePattern, FlattenVectorTransposeOpPattern>(
patterns.getContext());
}

//============================================================================//
Expand Down
50 changes: 45 additions & 5 deletions test/dialect/AIEVec/precanonicalization-aieml.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,13 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
%B : vector<1x1x4x8xbf16>,
%C : vector<1x1x4x4xf32>)
-> vector<1x1x4x4xf32> {
// CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]]
// CHECK-SAME: : vector<1x1x4x8xbf16> to vector<32xbf16>
// CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]]
// CHECK-SAME: {columns = 8 : i32, rows = 4 : i32}
// CHECK-SAME: : vector<32xbf16> -> vector<32xbf16>
// CHECK: %[[TRB:.*]] = vector.shape_cast %[[FVBT]]
// CHECK-SAME: : vector<32xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[RES:.*]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[IDXMAPA]], #[[IDXMAPB]], #[[IDXMAPC]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction",
Expand Down Expand Up @@ -144,8 +149,13 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
-> vector<1x1x4x4xf32> {
// CHECK: %[[LHS:.*]] = arith.extf %[[VA]] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x4x8xf32>
// CHECK: %[[TRB:.*]] = vector.transpose %[[VB]], [0, 1, 3, 2] :
// CHECK-SAME: vector<1x1x4x8xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]]
// CHECK-SAME: : vector<1x1x4x8xbf16> to vector<32xbf16>
// CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]]
// CHECK-SAME: {columns = 8 : i32, rows = 4 : i32}
// CHECK-SAME: : vector<32xbf16> -> vector<32xbf16>
// CHECK: %[[TRB:.*]] = vector.shape_cast %[[FVBT]]
// CHECK-SAME: : vector<32xbf16> to vector<1x1x8x4xbf16>
// CHECK: %[[RHS:.*]] = arith.extf %[[TRB]] :
// CHECK-SAME: vector<1x1x8x4xbf16> to vector<1x1x8x4xf32>
// CHECK: %[[RES:.*]] = vector.contract {
Expand All @@ -165,4 +175,34 @@ func.func @vector_contract_permuted_b(%A : vector<1x1x4x8xbf16>,
kind = #vector.kind<add>} %lhs, %rhs, %C :
vector<1x1x4x8xf32>, vector<1x1x4x8xf32> into vector<1x1x4x4xf32>
return %res : vector<1x1x4x4xf32>
}
}

//
// -----
//

// CHECK-LABEL: func.func @vector_transpose(
// CHECK-SAME: %[[VA:[a-zA-Z0-9]+]]: vector<4x8xbf16>,
// CHECK-SAME: %[[VB:[a-zA-Z0-9]+]]: vector<8x4xbf16>)
func.func @vector_transpose(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>)
-> (vector<4x8xbf16>, vector<8x4xbf16>) {
// CHECK: %[[FVA:.*]] = vector.shape_cast %[[VA]]
// CHECK-SAME: : vector<4x8xbf16> to vector<32xbf16>
// CHECK: %[[FVAT:.*]] = vector.flat_transpose %[[FVA]]
// CHECK-SAME: {columns = 8 : i32, rows = 4 : i32}
// CHECK-SAME: : vector<32xbf16> -> vector<32xbf16>
// CHECK: %[[VAT:.*]] = vector.shape_cast %[[FVAT]]
// CHECK-SAME: : vector<32xbf16> to vector<8x4xbf16>
%tA = vector.transpose %A, [1, 0] : vector<4x8xbf16> to vector<8x4xbf16>
// CHECK: %[[FVB:.*]] = vector.shape_cast %[[VB]]
// CHECK-SAME: : vector<8x4xbf16> to vector<32xbf16>
// CHECK: %[[FVBT:.*]] = vector.flat_transpose %[[FVB]]
// CHECK-SAME: {columns = 4 : i32, rows = 8 : i32}
// CHECK-SAME: : vector<32xbf16> -> vector<32xbf16>
// CHECK: %[[VBT:.*]] = vector.shape_cast %[[FVBT]]
// CHECK-SAME: : vector<32xbf16> to vector<4x8xbf16>
%tB = vector.transpose %B, [1, 0] : vector<8x4xbf16> to vector<4x8xbf16>
%AtB = arith.addf %A, %tB : vector<4x8xbf16>
%tAB = arith.addf %tA, %B : vector<8x4xbf16>
return %AtB, %tAB : vector<4x8xbf16>, vector<8x4xbf16>
}

0 comments on commit 993340c

Please sign in to comment.