Integrate LLVM at e2402615a5a76d46a433dfcc1de10b38a1263c9d (#3982)
Update LLVM to llvm/llvm-project@e240261.
Update StableHLO to openxla/stablehlo@8cd9444.

Updates the following API calls (sketched below):
1. `applyPatternsAndFoldGreedily` -> `applyPatternsGreedily`
2. `applyOpPatternsAndFold` -> `applyOpPatternsGreedily`
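
A minimal before/after sketch of the rename, assuming a hypothetical `MyPass`
(the actual change in this commit is in
`lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp` below):

```cpp
// Hypothetical pass body -- `MyPass` is a placeholder, not part of this commit.
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

void MyPass::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());
  // ... populate `patterns` ...

  // Before this commit:
  //   if (failed(applyPatternsAndFoldGreedily(getOperation(),
  //                                           std::move(patterns))))
  //     return signalPassFailure();

  // After this commit (same greedy driver, renamed entry point):
  if (failed(mlir::applyPatternsGreedily(getOperation(), std::move(patterns))))
    return signalPassFailure();
  // `applyOpPatternsAndFold` is renamed to `applyOpPatternsGreedily` in the
  // same way for the single-op variant.
}
```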

This commit also inlines the `BufferizeTypeConverter` into Torch-MLIR
(abridged sketch below), since it was removed from the LLVM project in
llvm/llvm-project@2ff2e87.
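
In short, the inlined converter maps tensor types to the corresponding memref
types and registers the `bufferization` materializations; abridged from the
`Bufferize.cpp` diff below:

```cpp
// Abridged from lib/Dialect/TMTensor/Transforms/Bufferize.cpp (the full
// version, including the target materialization, is in the diff below).
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
// Convert RankedTensorType to MemRefType.
typeConverter.addConversion([](RankedTensorType type) -> Type {
  return MemRefType::get(type.getShape(), type.getElementType());
});
// Convert UnrankedTensorType to UnrankedMemRefType.
typeConverter.addConversion([](UnrankedTensorType type) -> Type {
  return UnrankedMemRefType::get(type.getElementType(), 0);
});
// memref -> tensor materializations go through bufferization::ToTensorOp.
typeConverter.addSourceMaterialization(materializeToTensor);
```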

This commit also updates the `AdjustCallingConventions` pass to align with the
upstream `TypeConverter` changes. Some of the tests in
`adjust-calling-conventions.mlir` are disabled for the time being since they
are not supported even after the changes to the pass. We will re-enable them
in a separate PR once the `AdjustCallingConventions` pass is fully functional.
The fix is tracked by #3983.


TOSA Updates Summary:

Update the Torch to TOSA legalizations to the TOSA 1.0 op forms from LLVM
hash 64edde66. Changes include (sketched below):

- TOSA Pad op's new shape requirement
- TOSA Convolution ops' new `acc_type` attribute
- TOSA Tile op's multiples now passed as a `!tosa.shape` input
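
A sketch of the new op forms, using a hypothetical helper with placeholder
operands and result types; the real uses are in the
`lib/Conversion/TorchToTosa/TorchToTosa.cpp` diff below:

```cpp
// Sketch only -- not part of the diff. All operand names, attribute values,
// and result types here are placeholders; see TorchToTosa.cpp below for the
// actual legalizations.
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"

using namespace mlir;

static LogicalResult buildTosa10Forms(PatternRewriter &rewriter, Location loc,
                                      Value input, Value weight, Value bias,
                                      RankedTensorType inputTy,
                                      RankedTensorType weightTy,
                                      RankedTensorType outputTy) {
  // 1. Convolution ops now carry an explicit accumulator type (acc_type).
  TypeAttr accType;
  if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy,
                                     accType)))
    return failure();
  rewriter.create<tosa::Conv2DOp>(
      loc, outputTy, input, weight, bias,
      rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), // pad
      rewriter.getDenseI64ArrayAttr({1, 1}),       // stride
      rewriter.getDenseI64ArrayAttr({1, 1}),       // dilation
      accType);

  // 2. Tile now takes its multiples as a !tosa.shape value rather than an
  //    i64 array attribute.
  Value multiples = tosa::getTosaConstShape(rewriter, loc, {1, 2});
  rewriter.create<tosa::TileOp>(loc, outputTy, input, multiples);

  // 3. Pad's padding constant is a flat [2 * rank] tensor instead of a
  //    [rank, 2] tensor.
  int64_t rank = inputTy.getRank();
  SmallVector<int64_t> pads(2 * rank, 0);
  auto paddingAttr = DenseIntElementsAttr::get(
      RankedTensorType::get({2 * rank}, rewriter.getI64Type()), pads);
  rewriter.create<tosa::ConstOp>(loc, paddingAttr.getType(), paddingAttr);
  return success();
}
```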

---------

Signed-off-by: Vivek Khandelwal <[email protected]>
Co-authored-by: Justin Ngo <[email protected]>
vivekkhandelwal1 and justin-ngo-arm authored Jan 28, 2025
1 parent af8514c commit 1225073
Showing 29 changed files with 430 additions and 178 deletions.
2 changes: 1 addition & 1 deletion externals/llvm-project
Submodule llvm-project updated 18624 files
2 changes: 1 addition & 1 deletion externals/stablehlo
Submodule stablehlo updated 143 files
11 changes: 11 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -121,6 +121,17 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
TypeAttr &accType);

// Get accumulator type for TOSA convolution ops
LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
RankedTensorType inputTy,
RankedTensorType weightTy,
RankedTensorType outputTy, TypeAttr &accType);

// Temporary function to get TOSA const shape
// TODO: Remove this function when getTosaConstShape is available in
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape);
} // namespace tosa
} // namespace mlir

6 changes: 3 additions & 3 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -549,7 +549,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
Type floatDtype = mlir::Float64Type::get(context);
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype);
Value zero =
@@ -569,7 +569,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (isa<AtenLogicalNotOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
Type floatDtype = mlir::Float64Type::get(context);
Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
Value zero =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
@@ -1028,7 +1028,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Type powType = dtype;
if (payloadArgs[0].getType().isInteger() ||
payloadArgs[1].getType().isInteger())
powType = mlir::FloatType::getF64(op->getContext());
powType = mlir::Float64Type::get(op->getContext());
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType);
auto powOp = b.create<math::PowFOp>(loc, lhs, rhs);
65 changes: 44 additions & 21 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
@@ -2252,6 +2253,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op,
"non-const dilation list unsupported");

TypeAttr accType;
if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy,
accType)))
return rewriter.notifyMatchFailure(
op, "failed to get accumulator type for convolution ops");

// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
// Perform the necessary transformations.
std::optional<Value> nchwToNhwcTransposeConst =
Expand Down Expand Up @@ -2365,12 +2372,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
// full convolution
convOpResult =
rewriter
.create<tosa::Conv2DOp>(op->getLoc(),
getTypeConverter()->convertType(convOpTy),
transposedInput, transformedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation))
.create<tosa::Conv2DOp>(
op->getLoc(), getTypeConverter()->convertType(convOpTy),
transposedInput, transformedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation), accType)
.getResult();
} else if (weightShape[1] == 1) {
// depthwise convolution
@@ -2381,7 +2388,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
transposedInput, transformedWeight, bias,
rewriter.getDenseI64ArrayAttr(padding),
rewriter.getDenseI64ArrayAttr(stride),
rewriter.getDenseI64ArrayAttr(dilation))
rewriter.getDenseI64ArrayAttr(dilation), accType)
.getResult();
} else {
llvm_unreachable("Unhandled convolution type");
@@ -3909,9 +3916,11 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
}
}

auto result = rewriter.create<tosa::TileOp>(
op->getLoc(), resultType, reshapedInput,
rewriter.getDenseI64ArrayAttr(tileOpShape));
auto tileOpMultiples =
tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShape);

auto result = rewriter.create<tosa::TileOp>(op->getLoc(), resultType,
reshapedInput, tileOpMultiples);

rewriter.replaceOp(op, {result.getResult()});
}
@@ -4104,9 +4113,11 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
rewriter.getIntegerType(32));

auto tileOpMultiples =
tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape);

auto expandedIndices = rewriter.create<tosa::TileOp>(
op->getLoc(), tileType, reshapedIndices.getResult(),
rewriter.getDenseI64ArrayAttr(tileShape));
op->getLoc(), tileType, reshapedIndices.getResult(), tileOpMultiples);

// convert torch style index and dim into tf style indices
// tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
@@ -4445,17 +4456,23 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
if (needsTiling) {
auto idxType =
dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());

// indicesTfConcatTensors has a trailing [1] dim for the final concat.
auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
maxRankMaxDimShapeTf.push_back(1);

auto tileOpShapeTf(tileOpShape);
tileOpShapeTf.push_back(1);

auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
idxType.getElementType());
auto reshapedIdxTensor = indicesTfConcatTensors[i];

auto tileOpMultiples =
tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShapeTf);

indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
op->getLoc(), tileOutputTy, reshapedIdxTensor,
rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
op->getLoc(), tileOutputTy, reshapedIdxTensor, tileOpMultiples);
}

// Every index tensor now has the same rank and shape
@@ -6023,12 +6040,14 @@ class ConvertAtenFillOp : public OpConversionPattern<AtenOpT> {
op->getLoc(), fillValueMatchedInputRankType, fillValue,
rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));

auto tileOpMultiples =
tosa::getTosaConstShape(rewriter, op->getLoc(), outType.getShape());

fillValueTargetTensor = rewriter.create<tosa::TileOp>(
op->getLoc(),
RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
fillValueElemTy),
fillValueMatchedInputRankTensor.getResult(),
makeShapeTorchCompatible(outType.getShape()));
fillValueMatchedInputRankTensor.getResult(), tileOpMultiples);
} else {
if (failed(torchScalarToTosaTensor(
rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
@@ -6179,7 +6198,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
}

DenseElementsAttr paddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({rank, 2}, rewriter.getI64Type()),
RankedTensorType::get({2 * rank}, rewriter.getI64Type()),
translatePadsList);

Value padsList1 = rewriter.create<mlir::tosa::ConstOp>(
@@ -7836,9 +7855,11 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
resultType.getElementType()),
self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));

auto selfTileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(),
resultShapeIndex0Replaced);

auto selfTiled = rewriter.create<tosa::TileOp>(
op->getLoc(), resultType, selfReshaped.getResult(),
rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
op->getLoc(), resultType, selfReshaped.getResult(), selfTileOpMultiples);

// Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]}
auto vec2Reshaped = rewriter.create<tosa::ReshapeOp>(
@@ -7847,9 +7868,11 @@ LogicalResult ConvertAtenOp<AtenOuterOp>::matchAndRewrite(
resultType.getElementType()),
vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));

auto vec2TileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(),
resultShapeIndex1Replaced);

auto vec2Tiled = rewriter.create<tosa::TileOp>(
op->getLoc(), resultType, vec2Reshaped.getResult(),
rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
op->getLoc(), resultType, vec2Reshaped.getResult(), vec2TileOpMultiples);

auto result =
tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(),
6 changes: 4 additions & 2 deletions lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

@@ -566,11 +567,12 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,

// [0] -> [0,0,0]
SmallVector<int64_t, 1> tileShape({W}); // {3}
auto tileOpMultiples =
tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape);
auto tosaFillValuesTileOp = tosa::CreateOpAndInfer<tosa::TileOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()),
tosaFillValuesOneReshapeOp.getResult(),
rewriter.getDenseI64ArrayAttr(tileShape));
tosaFillValuesOneReshapeOp.getResult(), tileOpMultiples);

// [0,0,0] -> [[0,0,0]]
SmallVector<int64_t, 2> newTosaFillValuesShape({N, W}); // {1,3}
58 changes: 58 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -454,5 +454,63 @@ LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
return success();
}

// Get accumulator type for TOSA convolution ops
LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
RankedTensorType inputTy,
RankedTensorType weightTy,
RankedTensorType outputTy, TypeAttr &accType) {
auto inputElemTy = inputTy.getElementType();
auto weightElemTy = weightTy.getElementType();
auto outputElemTy = outputTy.getElementType();

auto quantTy = dyn_cast<quant::QuantizedType>(inputElemTy);
if (quantTy)
inputElemTy = quantTy.getStorageType();

// Get TOSA conv ops acc type based on input, weight, and output types
// according to the spec:
// https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
// https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
// https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d
//
// For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the
// output type but does not offer any guarantee on the numerical precision
// since such cases will fail TOSA validation.
if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) ||
(inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) ||
(inputElemTy.isBF16() && weightElemTy.isBF16() &&
outputElemTy.isBF16())) {
accType = mlir::TypeAttr::get(rewriter.getF32Type());
} else if (inputElemTy.isInteger(8) &&
(weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) &&
outputElemTy.isInteger(32)) {
accType = mlir::TypeAttr::get(rewriter.getIntegerType(32));
} else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
outputElemTy.isInteger(48)) {
accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
} else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
outputElemTy.isF16()) ||
(inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
outputElemTy.isF16())) {
accType = mlir::TypeAttr::get(rewriter.getF16Type());
} else {
accType = mlir::TypeAttr::get(outputElemTy);
}

return success();
}

// Temporary function to get TOSA const shape
// TODO: Remove this function when getTosaConstShape is available in
// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
llvm::ArrayRef<int64_t> shape) {
auto attr = rewriter.getIndexTensorAttr(shape);
auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
mlir::Operation *mlir_op =
rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
return mlir_op->getResult(0);
}

} // namespace tosa
} // namespace mlir
50 changes: 49 additions & 1 deletion lib/Dialect/TMTensor/Transforms/Bufferize.cpp
@@ -121,6 +121,14 @@ class BufferizeAnyTMTensorOp : public OpInterfaceConversionPattern<TMTensorOp> {
};

namespace {

static Value materializeToTensor(OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(isa<BaseMemRefType>(inputs[0].getType()));
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
}

/// Converts TMTensor operations that work on tensor-type operands or results to
/// work on buffers.
struct TMTensorBufferizePass
@@ -133,7 +141,47 @@ struct TMTensorBufferizePass
void runOnOperation() override {
MLIRContext &context = getContext();
ConversionTarget target(context);
bufferization::BufferizeTypeConverter typeConverter;
// Since the `BufferizeTypeConverter` has been removed here
// https://github.com/llvm/llvm-project/commit/2ff2e871f5e632ea493efaf4f2192f8b18a54ab1,
// hence we have inlined the converter here.
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
// Convert RankedTensorType to MemRefType.
typeConverter.addConversion([](RankedTensorType type) -> Type {
return MemRefType::get(type.getShape(), type.getElementType());
});
// Convert UnrankedTensorType to UnrankedMemRefType.
typeConverter.addConversion([](UnrankedTensorType type) -> Type {
return UnrankedMemRefType::get(type.getElementType(), 0);
});
typeConverter.addArgumentMaterialization(materializeToTensor);
typeConverter.addSourceMaterialization(materializeToTensor);
typeConverter.addTargetMaterialization([](OpBuilder &builder,
BaseMemRefType type,
ValueRange inputs,
Location loc) -> Value {
assert(inputs.size() == 1 && "expected exactly one input");
if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) {
// MemRef to MemRef cast.
assert(inputType != type && "expected different types");
// Ranked to unranked casts must be explicit.
auto rankedDestType = dyn_cast<MemRefType>(type);
if (!rankedDestType)
return nullptr;
bufferization::BufferizationOptions options;
options.bufferAlignment = 0;
FailureOr<Value> replacement = castOrReallocMemRefValue(
builder, inputs[0], rankedDestType, options);
if (failed(replacement))
return nullptr;
return *replacement;
}
if (isa<TensorType>(inputs[0].getType())) {
// Tensor to MemRef cast.
return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]);
}
llvm_unreachable("only tensor/memref input types supported");
});

// Mark all Standard operations legal.
target.addLegalDialect<arith::ArithDialect, func::FuncDialect,
3 changes: 1 addition & 2 deletions lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp
@@ -110,8 +110,7 @@ struct TMTensorToLoopsPass : public TMTensorToLoopsBase<TMTensorToLoopsPass> {

RewritePatternSet patterns(context);
patterns.insert<ScalarLoopOpInterfaceLowerToLoopsPattern>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}