LLVM update 0e779ad #3086

Draft: wants to merge 14 commits into base: main
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
+cd llvm-project && git checkout 0e779ad4998ef65907502101c5b82ede05ddfa4e && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
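For reference, the checkout above is normally followed by an MLIR build along these lines. This is only a rough sketch of what `utils/build-mlir.sh` does; consult that script for the authoritative flags.

``` bash
# Sketch only: approximates utils/build-mlir.sh, exact flags may differ.
mkdir -p llvm-project/build && cd llvm-project/build
cmake -G Ninja ../llvm \
  -DLLVM_ENABLE_PROJECTS=mlir \
  -DLLVM_TARGETS_TO_BUILD="host" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=ON \
  -DLLVM_ENABLE_RTTI=ON
cmake --build .
cmake --build . --target check-mlir
```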
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
+cd llvm-project && git checkout 0e779ad4998ef65907502101c5b82ede05ddfa4e && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
@@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
auto int16Ty = IntegerType::get(context, 16);
auto int32Ty = IntegerType::get(context, 32);
auto int64Ty = IntegerType::get(context, 64);
-  auto float32Ty = FloatType::getF32(context);
+  auto float32Ty = Float32Type::get(context);

// Declare API type as an enum value, its string name and an LLVM Type
// specifying its signature.
@@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) {
Type llvmI64Ty = IntegerType::get(context, 64);
Type llvmI1Ty = IntegerType::get(context, 1);
Type llvmI8Ty = IntegerType::get(context, 8);
-  Type llvmF32Ty = FloatType::getF32(context);
+  Type llvmF32Ty = Float32Type::get(context);
Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3);
Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20);
Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty);
@@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float");
create.llvm.store(recScale, recScalePtr);
} else {
-    Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
+    Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, recScalePtr);
}

@@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
offsetTy.isF32() && "Wrong type for zTensor's offset. Must be float");
create.llvm.store(offset, offsetPtr);
} else {
-    Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
+    Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, offsetPtr);
}

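The change above is the recurring pattern in this update: recent LLVM/MLIR removed the static `FloatType::getF32`/`getF64`/`getF16`/`getBF16` helpers, and each concrete float type now exposes its own `::get(MLIRContext *)`. A minimal standalone sketch of the migration; the helper function here is hypothetical and not part of this PR:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Hypothetical helper showing the before/after of the float-type API change.
static Type makeF32(MLIRContext *ctx) {
  // Old API (removed upstream):  FloatType::getF32(ctx)
  // New API: each concrete float type has its own ::get entry point.
  return Float32Type::get(ctx);
}
```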
4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp
@@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern {
// or
// (memref<3x4x5xf64>, index, f64, f64, f64)
Type llvmVoidTy = LLVM::LLVMVoidType::get(context);
-  Type llvmOptionsTy = FloatType::getF32(context);
+  Type llvmOptionsTy = Float32Type::get(context);
Type llvmOutputTy = getPointerType(context, llvmOptionsTy);
if (inType.isF64()) {
-    llvmOptionsTy = FloatType::getF64(context);
+    llvmOptionsTy = Float64Type::get(context);
llvmOutputTy = getPointerType(context, llvmOptionsTy);
}
Type llvmI64Ty = IntegerType::get(context, 64);
13 changes: 6 additions & 7 deletions src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp
@@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
Type outType = op->getResultTypes().front();
Type llvmInType, llvmOutType;
if (inType.isF16())
-    llvmInType = FloatType::getF16(context);
+    llvmInType = Float16Type::get(context);
else if (inType.isF32())
-    llvmInType = FloatType::getF32(context);
+    llvmInType = Float32Type::get(context);
else if (inType.isF64())
-    llvmInType = FloatType::getF64(context);
+    llvmInType = Float64Type::get(context);
else if (inType.isBF16())
-    llvmInType = FloatType::getBF16(context);
+    llvmInType = BFloat16Type::get(context);
if (outType.isInteger(1))
llvmOutType = IntegerType::get(context, 1);
else if (outType.isF32())
-    llvmOutType = FloatType::getF32(context);
+    llvmOutType = Float32Type::get(context);
else if (outType.isF64())
-    llvmOutType = FloatType::getF64(context);
+    llvmOutType = Float64Type::get(context);

// Insert and/or get reference to elementary math function declaration.
assert(
@@ -214,7 +214,7 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
return SymbolRefAttr::get(context, mathFuncName);

// Create function declaration.
-  // auto llvmF32Ty = FloatType::get(context);
auto llvmFnType =
LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef<Type>({llvmInType}));

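For context, the declaration built below follows the usual MLIR LLVM-dialect recipe for an external libm call. A rough sketch under the assumption of a float-to-float function such as `tanhf`; the names here are illustrative, not taken from this file:

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Sketch: declare `float tanhf(float)` at module scope if it is not already
// there, and return a symbol reference to it.
static FlatSymbolRefAttr getOrDeclareTanhf(OpBuilder &b, ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  if (!module.lookupSymbol<LLVM::LLVMFuncOp>("tanhf")) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(module.getBody());
    auto fnType = LLVM::LLVMFunctionType::get(
        Float32Type::get(ctx), {Float32Type::get(ctx)});
    b.create<LLVM::LLVMFuncOp>(module.getLoc(), "tanhf", fnType);
  }
  return SymbolRefAttr::get(ctx, "tanhf");
}
```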
4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
@@ -62,7 +62,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

// Get memRefDescriptor, the new memref descriptor.
MemRefDescriptor memRefDescriptor =
-        MemRefDescriptor::undef(rewriter, loc, targetStructType);
+        MemRefDescriptor::poison(rewriter, loc, targetStructType);
auto targetElementPtrType = memRefDescriptor.getElementPtrType();

// Set the new memref to the same buffer as the source memref.
@@ -78,7 +78,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

int64_t offset;
SmallVector<int64_t, 4> strides;
-    if (failed(getStridesAndOffset(targetType, strides, offset)))
+    if (failed(targetType.getStridesAndOffset(strides, offset)))
return failure();

// Unhandled dynamic offset.
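Both changes in this file track upstream MLIR renames: the free function `getStridesAndOffset(MemRefType, ...)` became a member of `MemRefType`, and `MemRefDescriptor::undef` was renamed to `poison`. A small sketch of the member-function form, as assumed usage rather than code from this PR:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Sketch: report whether a memref type has fully static strides and offset,
// using the member-function form of getStridesAndOffset.
static bool hasStaticStrides(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(type.getStridesAndOffset(strides, offset)))
    return false;
  return offset != ShapedType::kDynamic &&
         llvm::none_of(
             strides, [](int64_t s) { return s == ShapedType::kDynamic; });
}
```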
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
@@ -281,7 +281,7 @@ struct ONNXCategoryMapperOpLowering
SmallVector<int64_t, 4> strides;
int64_t alignmentOffset; // not used, just to make the function call
// completed.
-    if (getStridesAndOffset(memRefType, strides, alignmentOffset)
+    if (memRefType.getStridesAndOffset(strides, alignmentOffset)
.failed())
llvm_unreachable("Failed to get strides");
Value stringMemRef =
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/Math/LRN.cpp
@@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern<ONNXLRNOp> {
float alphaLit = adaptor.getAlpha().convertToFloat();
float betaLit = adaptor.getBeta().convertToFloat();
int sizeLit = adaptor.getSize();
-  auto f32Type = FloatType::getF32(rewriter.getContext());
+  auto f32Type = Float32Type::get(rewriter.getContext());
Value biasValue = create.math.constant(f32Type, biasLit);
Value alphaDivSizeValue =
create.math.constant(f32Type, alphaLit / static_cast<float>(sizeLit));
24 changes: 16 additions & 8 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
@@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -147,14 +148,16 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {

Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef<int64_t> size,
llvm::ArrayRef<int64_t> start) {
-  DenseI64ArrayAttr sizeAttr = rewriter().getDenseI64ArrayAttr(size);
-  DenseI64ArrayAttr startAttr = rewriter().getDenseI64ArrayAttr(start);
+  auto startVal =
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(start));
+  auto sizeVal =
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(size));
Value newSliceInput =
tosa::CreateOpAndInfer<mlir::tosa::SliceOp>(rewriter(), loc(),
RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(size.size(), ShapedType::kDynamic),
mlir::cast<ShapedType>(inputConst.getType()).getElementType()),
-          inputConst, startAttr, sizeAttr);
+          inputConst, startVal, sizeVal);
return newSliceInput;
}

@@ -164,8 +167,9 @@ Value TosaBuilder::reshape(Value &value, llvm::ArrayRef<int64_t> shape) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(shape.size(), ShapedType::kDynamic),
valueType.getElementType());
-  return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(
-      rewriter(), loc(), newValueType, value, shapeAttr);
+  return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter(), loc(),
+      newValueType, value,
+      mlir::tosa::getTosaConstShape(rewriter(), loc(), shapeAttr));
}

Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
@@ -178,8 +182,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());

+  auto int8Type = rewriter().getI8Type();
+  auto shiftValue = TosaBuilder::createConst(
+      ArrayRef<int8_t>{static_cast<int8_t>(shift)}, {1}, int8Type);
  return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
-      rewriter(), loc(), newValueType, lhs, rhs, shift);
+      rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
}

Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
@@ -236,8 +244,8 @@ template Value TosaBuilder::binaryOp<mlir::tosa::SubOp>(Value &lhs, Value &rhs);
// Return null if none is found.
ElementsAttr IndexExprBuilderForTosa::getConst(Value value) {
auto definingOp = value.getDefiningOp();
-  // If we have a cast between index/integer, skip it, i.e. get the defining op
-  // that is the input to the cast.
+  // If we have a cast between index/integer, skip it, i.e. get the defining
+  // op that is the input to the cast.
if (auto castOp = dyn_cast_or_null<arith::IndexCastOp>(definingOp)) {
Value input = castOp.getIn();
definingOp = input.getDefiningOp();
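The TOSA changes in this builder follow the dialect's move away from `DenseI64ArrayAttr` operands: slice and reshape now take explicit const-shape values, and mul takes its shift as a rank-1 tensor operand. A condensed sketch of the new slice construction, using the same `getTosaConstShape` helper as above; this is an illustration, not the PR's `TosaBuilder::slice` itself:

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

using namespace mlir;

// Sketch: build a tosa.slice whose start/size are tosa const-shape values
// rather than DenseI64ArrayAttr, mirroring the updated TosaBuilder::slice.
static Value buildSlice(PatternRewriter &rewriter, Location loc, Value input,
    ArrayRef<int64_t> start, ArrayRef<int64_t> size) {
  Value startVal = tosa::getTosaConstShape(rewriter, loc, start);
  Value sizeVal = tosa::getTosaConstShape(rewriter, loc, size);
  auto elemTy = cast<ShapedType>(input.getType()).getElementType();
  auto resultTy = RankedTensorType::get(
      SmallVector<int64_t>(size.size(), ShapedType::kDynamic), elemTy);
  return rewriter.create<tosa::SliceOp>(
      loc, resultTy, input, startVal, sizeVal);
}
```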
20 changes: 15 additions & 5 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
@@ -121,11 +121,21 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
// Quantized types are not supported right now (in type conversion).
// Once they are, the input should be rescaled for quantized types. (TBD)
// Maps to `tosa.clamp` which has both int and fp limits.
-    rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(op, op.getType(), input,
-        rewriter.getI64IntegerAttr(0),
-        rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
-        rewriter.getF32FloatAttr(0.0f),
-        rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
+    auto inputElementType =
+        llvm::cast<TensorType>(op.getType()).getElementType();
+    if (llvm::isa<IntegerType>(inputElementType)) {
+      auto minClamp = rewriter.getI64IntegerAttr(0);
+      auto maxClamp =
+          rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max());
+      rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
+          op, op.getType(), input, minClamp, maxClamp);
+    } else {
+      auto minClamp = rewriter.getF32FloatAttr(0.0f);
+      auto maxClamp =
+          rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
+      rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
+          op, op.getType(), input, minClamp, maxClamp);
+    }
return success();
}
};
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToTOSA/Math/Gemm.cpp
@@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
@@ -67,13 +68,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {

llvm::SmallVector<int64_t> dynamicTensorShape = {
ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic};

A = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, AType.getElementType()), A,
-            rewriter.getDenseI64ArrayAttr(newShapeA))
+            mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeA))
.getResult();
B = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, BType.getElementType()), B,
-            rewriter.getDenseI64ArrayAttr(newShapeB))
+            mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeB))
.getResult();

// If transA or transB are present, create Transpose operators.
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
@@ -14,6 +14,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
@@ -60,8 +61,8 @@ Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,
}
tosaPads.insert(tosaPads.end(), lastVals.begin(), lastVals.end());
TosaBuilder tosaBuilder(rewriter, loc);
-  return tosaBuilder.getConst(
-      tosaPads, {static_cast<int64_t>(tosaPads.size())});
+  return mlir::tosa::getTosaConstShape(rewriter, loc, tosaPads);
}

} // namespace tosa
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp
@@ -45,6 +45,7 @@ T getValueFromTosaConst(mlir::Value &val) {
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc,
mlir::Type result_ty, Args &&... args) {

auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

mlir::InferShapedTypeOpInterface shapeInterface =
@@ -64,6 +65,7 @@ TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc,
// the new result shaped type. This is because rescale can include a cast to
// different bit-width types and does not have a TypeAttr to define the
// target type.
+  assert(returnedShapes.size() >= 1 && "Expected at least one returned shape");
auto predictedShape = returnedShapes[0];
if (predictedShape.hasRank())
updateType(nullptr, op, predictedShape.getDims(),
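For context, a typical call site of this helper only needs a result type with the right element type and rank; the inferred shape is patched in afterwards. A rough illustration, assuming it is compiled inside onnx-mlir where `tosa::CreateOpAndInfer` refers to the template above; the op choice and operands are placeholders:

```cpp
// Illustration only: emit a tosa.add whose result type is refined by
// CreateOpAndInfer via InferShapedTypeOpInterface.
mlir::Value emitAdd(mlir::PatternRewriter &rewriter, mlir::Location loc,
    mlir::Value lhs, mlir::Value rhs) {
  auto lhsTy = mlir::cast<mlir::ShapedType>(lhs.getType());
  auto resultTy = mlir::RankedTensorType::get(
      llvm::SmallVector<int64_t>(lhsTy.getRank(), mlir::ShapedType::kDynamic),
      lhsTy.getElementType());
  return tosa::CreateOpAndInfer<mlir::tosa::AddOp>(
      rewriter, loc, resultTy, lhs, rhs);
}
```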
10 changes: 5 additions & 5 deletions src/Dialect/ONNX/ElementsAttr/BType.cpp
@@ -55,10 +55,10 @@ Type mlirTypeOfBType(BType btype, MLIRContext *ctx) {
case BType::FLOAT : return b.getF32Type();
case BType::FLOAT16 : return b.getF16Type();
case BType::BFLOAT16 : return b.getBF16Type();
-    case BType::FLOAT8E4M3FN   : return b.getFloat8E4M3FNType();
-    case BType::FLOAT8E4M3FNUZ : return b.getFloat8E4M3FNUZType();
-    case BType::FLOAT8E5M2     : return b.getFloat8E5M2Type();
-    case BType::FLOAT8E5M2FNUZ : return b.getFloat8E5M2FNUZType();
+    case BType::FLOAT8E4M3FN   : return b.getType<Float8E4M3FNType>();
+    case BType::FLOAT8E4M3FNUZ : return b.getType<Float8E4M3FNUZType>();
+    case BType::FLOAT8E5M2     : return b.getType<Float8E5M2Type>();
+    case BType::FLOAT8E5M2FNUZ : return b.getType<Float8E5M2FNUZType>();
default: llvm_unreachable("unsupported data type");
}
// clang-format on
@@ -104,4 +104,4 @@ BType wideBTypeOfBType(BType d) {
[](auto btype) { return toBType<typename BTypeTrait<btype>::widetype>; });
}

-} // namespace onnx_mlir
+} // namespace onnx_mlir
2 changes: 1 addition & 1 deletion src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
@@ -96,7 +96,7 @@ LogicalResult ONNXOneHotEncoderOp::inferShapes(
return success();

ONNXOneHotEncoderOpShapeHelper shapeHelper(getOperation(), {});
-  return shapeHelper.computeShapeAndUpdateType(FloatType::getF32(getContext()));
+  return shapeHelper.computeShapeAndUpdateType(Float32Type::get(getContext()));
return success();
}

2 changes: 1 addition & 1 deletion src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
@@ -452,7 +452,7 @@ LogicalResult ONNXScalerOp::inferShapes(
ONNXUnaryOpShapeHelper shapeHelper(getOperation(), {});
RankedTensorType xType = mlir::dyn_cast<RankedTensorType>(getX().getType());
return shapeHelper.computeShapeAndUpdateType(
-      FloatType::getF32(getContext()), xType.getEncoding());
+      Float32Type::get(getContext()), xType.getEncoding());
}

//===----------------------------------------------------------------------===//
14 changes: 7 additions & 7 deletions src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
@@ -47,16 +47,16 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
Type elementType;
if (auto attr = getDtypeAttr()) {
if (getDtype() == 0) {
-      elementType = FloatType::getF16(getContext());
+      elementType = Float16Type::get(getContext());
} else if (getDtype() == 1) {
-      elementType = FloatType::getF32(getContext());
+      elementType = Float32Type::get(getContext());
} else if (getDtype() == 2) {
-      elementType = FloatType::getF64(getContext());
+      elementType = Float64Type::get(getContext());
} else {
llvm_unreachable("dtype not supported for RandomNormal");
}
} else {
-    elementType = FloatType::getF32(getContext());
+    elementType = Float32Type::get(getContext());
}
return {UnrankedTensorType::get(elementType)};
}
@@ -68,11 +68,11 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
LogicalResult ONNXRandomNormalOp::inferShapes(
std::function<void(Region &)> doShapeInference) {
auto elementTypeID = getDtype();
-  Type elementType = FloatType::getF32(getContext());
+  Type elementType = Float32Type::get(getContext());
if (elementTypeID == 0)
-    elementType = FloatType::getF16(getContext());
+    elementType = Float16Type::get(getContext());
else if (elementTypeID == 2)
-    elementType = FloatType::getF64(getContext());
+    elementType = Float64Type::get(getContext());

ONNXRandomNormalOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateType(elementType);
14 changes: 6 additions & 8 deletions src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
@@ -42,13 +42,11 @@ LogicalResult ONNXRandomNormalLikeOp::verify() {
if (elementTypeID < 0 || elementTypeID > 2) {
return emitOpError("dtype not 0, 1 or 2.");
}
-    if (elementTypeID == 0 && outputType != FloatType::getF16(getContext()))
+    if (elementTypeID == 0 && outputType != Float16Type::get(getContext()))
return emitOpError("output tensor does match 0 dtype.");
-    else if (elementTypeID == 1 &&
-             outputType != FloatType::getF32(getContext()))
+    else if (elementTypeID == 1 && outputType != Float32Type::get(getContext()))
return emitOpError("output tensor does match 1 dtype.");
-    else if (elementTypeID == 2 &&
-             outputType != FloatType::getF64(getContext()))
+    else if (elementTypeID == 2 && outputType != Float64Type::get(getContext()))
return emitOpError("output tensor does match 2 dtype.");
} else if (inputType != outputType) {
return emitOpError("output and input element types do not match.");
@@ -75,11 +73,11 @@ LogicalResult ONNXRandomNormalLikeOp::inferShapes(
} else {
int64_t elementTypeID = elementTypeIDDType.value();
if (elementTypeID == 0)
-      elementType = FloatType::getF16(getContext());
+      elementType = Float16Type::get(getContext());
else if (elementTypeID == 1)
-      elementType = FloatType::getF32(getContext());
+      elementType = Float32Type::get(getContext());
else if (elementTypeID == 2)
-      elementType = FloatType::getF64(getContext());
+      elementType = Float64Type::get(getContext());
else
return emitError("dtype attribute is invalid (use: 0, 1 or 2)");
}