Skip to content

Commit

Permalink
Decompose stablehlo transpose into multiple ttir.transpose ops (#1273)
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT authored Nov 15, 2024
1 parent 41323e9 commit 4e8207f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 25 deletions.
61 changes: 36 additions & 25 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,38 +148,49 @@ class StableHLOToTTIRTransposeOpConversionPattern
matchAndRewrite(mlir::stablehlo::TransposeOp srcOp,
mlir::stablehlo::TransposeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
if (!legalityResult.succeeded()) {
return legalityResult;
}

auto outputType = mlir::cast<RankedTensorType>(
getTypeConverter()->convertType(srcOp.getResult().getType()));
tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
auto input = Value(adaptor.getOperand());
auto transposes = getPermutationTransposes(adaptor.getPermutation().vec());

rewriter.replaceOpWithNewOp<mlir::tt::ttir::TransposeOp>(
srcOp, getTypeConverter()->convertType(outputTensor.getType()),
Value(adaptor.getOperand()), Value(outputTensor),
adaptor.getPermutation()[0], adaptor.getPermutation()[1],
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
for (auto transposeDims : transposes) {
auto dim0 = std::get<0>(transposeDims);
auto dim1 = std::get<1>(transposeDims);

auto inputType = mlir::cast<RankedTensorType>(input.getType());
auto outputShape = inputType.getShape().vec();
std::swap(outputShape[dim0], outputShape[dim1]);

auto outputType = RankedTensorType::get(
outputShape, inputType.getElementType(), inputType.getEncoding());

auto outputTensor = rewriter.create<tensor::EmptyOp>(
srcOp.getLoc(), outputShape, outputType.getElementType());

input = rewriter.create<mlir::tt::ttir::TransposeOp>(
srcOp.getLoc(), outputType, input, outputTensor,
rewriter.getSI32IntegerAttr(dim0), rewriter.getSI32IntegerAttr(dim1),
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
OperandConstraint::AnyDeviceTile))));
}
rewriter.replaceOp(srcOp, input);
return success();
}

LogicalResult
checkBasicLegality(mlir::stablehlo::TransposeOp &srcOp,
mlir::stablehlo::TransposeOp::Adaptor &adaptor,
ConversionPatternRewriter &rewriter) const {

if (adaptor.getPermutation().size() != 2) {
return rewriter.notifyMatchFailure(
srcOp, "TTIR supports only two dimensional transposeOp.");
private:
std::vector<std::tuple<int64_t, int64_t>>
getPermutationTransposes(std::vector<int64_t> permutation) const {
std::vector<std::tuple<int64_t, int64_t>> transposes;
for (uint32_t i = 0; i < permutation.size(); i++) {
while (i != permutation[i]) {
transposes.push_back(
std::make_tuple(permutation[i], permutation[permutation[i]]));
std::swap(permutation[i], permutation[permutation[i]]);
}
}

return success();
return transposes;
}
};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module {
func.func @main(%arg0: tensor<1x32x64x128xf32>) -> tensor<1x128x32x64xf32> {
// CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.transpose"[[C:.*]]
%0 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x32x64x128xf32>) -> tensor<1x128x32x64xf32>
return %0 : tensor<1x128x32x64xf32>
}
}

0 comments on commit 4e8207f

Please sign in to comment.