Skip to content

Commit

Permalink
mostly tosa updates
Browse files Browse the repository at this point in the history
  • Loading branch information
brnorris03 committed Feb 21, 2025
1 parent f4a3e00 commit 7a108f7
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

// Get memRefDescriptor, the new memref descriptor.
MemRefDescriptor memRefDescriptor =
MemRefDescriptor::undef(rewriter, loc, targetStructType);
MemRefDescriptor::poison(rewriter, loc, targetStructType);
auto targetElementPtrType = memRefDescriptor.getElementPtrType();

// Set the new memref to the same buffer as the source memref.
Expand All @@ -78,7 +78,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern {

int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(targetType, strides, offset)))
if (failed(targetType.getStridesAndOffset(strides, offset)))
return failure();

// Unhandled dynamic offset.
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ struct ONNXCategoryMapperOpLowering
SmallVector<int64_t, 4> strides;
int64_t alignmentOffset; // not used, just to make the function call
// completed.
if (getStridesAndOffset(memRefType, strides, alignmentOffset)
if (memRefType.getStridesAndOffset(strides, alignmentOffset)
.failed())
llvm_unreachable("Failed to get strides");
Value stringMemRef =
Expand Down
24 changes: 16 additions & 8 deletions src/Conversion/ONNXToTOSA/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
Expand Down Expand Up @@ -147,14 +148,16 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef<int32_t> perm) {

Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef<int64_t> size,
llvm::ArrayRef<int64_t> start) {
DenseI64ArrayAttr sizeAttr = rewriter().getDenseI64ArrayAttr(size);
DenseI64ArrayAttr startAttr = rewriter().getDenseI64ArrayAttr(start);
auto startVal =
mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(start));
auto sizeVal =
mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(size));
Value newSliceInput =
tosa::CreateOpAndInfer<mlir::tosa::SliceOp>(rewriter(), loc(),
RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(size.size(), ShapedType::kDynamic),
mlir::cast<ShapedType>(inputConst.getType()).getElementType()),
inputConst, startAttr, sizeAttr);
inputConst, startVal, sizeVal);
return newSliceInput;
}

Expand All @@ -164,8 +167,9 @@ Value TosaBuilder::reshape(Value &value, llvm::ArrayRef<int64_t> shape) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(shape.size(), ShapedType::kDynamic),
valueType.getElementType());
return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(
rewriter(), loc(), newValueType, value, shapeAttr);
return tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter(), loc(),
newValueType, value,
mlir::tosa::getTosaConstShape(rewriter(), loc(), shapeAttr));
}

Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
Expand All @@ -178,8 +182,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) {
Type newValueType = RankedTensorType::get(
llvm::SmallVector<int64_t, 4>(lhsType.getRank(), ShapedType::kDynamic),
lhsType.getElementType());

auto int8Type = rewriter().getI8Type();
auto shiftValue =
TosaBuilder::createConst(ArrayRef<int32_t>{shift}, {1}, int8Type);
return tosa::CreateOpAndInfer<mlir::tosa::MulOp>(
rewriter(), loc(), newValueType, lhs, rhs, shift);
rewriter(), loc(), newValueType, lhs, rhs, shiftValue);
}

Value TosaBuilder::intdiv(Value &lhs, Value &rhs) {
Expand Down Expand Up @@ -236,8 +244,8 @@ template Value TosaBuilder::binaryOp<mlir::tosa::SubOp>(Value &lhs, Value &rhs);
// Return null if none is found.
ElementsAttr IndexExprBuilderForTosa::getConst(Value value) {
auto definingOp = value.getDefiningOp();
// If we have a cast between index/integer, skip it, i.e. get the defining op
// that is the input to the cast.
// If we have a cast between index/integer, skip it, i.e. get the defining
// op that is the input to the cast.
if (auto castOp = dyn_cast_or_null<arith::IndexCastOp>(definingOp)) {
Value input = castOp.getIn();
definingOp = input.getDefiningOp();
Expand Down
20 changes: 15 additions & 5 deletions src/Conversion/ONNXToTOSA/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,21 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern<ONNXReluOp> {
// Quantized types are not supported right now (in type conversion).
// Once they are, the input should be rescaled for quantized types. (TBD)
// Maps to `tosa.clamp` which has both int and fp limits.
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(op, op.getType(), input,
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max()),
rewriter.getF32FloatAttr(0.0f),
rewriter.getF32FloatAttr(std::numeric_limits<float>::max()));
auto inputElementType =
llvm::cast<TensorType>(op.getType()).getElementType();
if (llvm::isa<IntegerType>(inputElementType)) {
auto minClamp = rewriter.getI64IntegerAttr(0);
auto maxClamp =
rewriter.getI64IntegerAttr(std::numeric_limits<int32_t>::max());
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
op, op.getType(), input, minClamp, maxClamp);
} else {
auto minClamp = rewriter.getF32FloatAttr(0.0f);
auto maxClamp =
rewriter.getF32FloatAttr(std::numeric_limits<float>::max());
rewriter.replaceOpWithNewOp<mlir::tosa::ClampOp>(
op, op.getType(), input, minClamp, maxClamp);
}
return success();
}
};
Expand Down
10 changes: 6 additions & 4 deletions src/Conversion/ONNXToTOSA/Math/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp"
#include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp"
Expand Down Expand Up @@ -67,13 +68,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern<ONNXGemmOp> {

llvm::SmallVector<int64_t> dynamicTensorShape = {
ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic};
A = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),

tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, AType.getElementType()), A,
rewriter.getDenseI64ArrayAttr(newShapeA))
.getResult();
mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeA))
.getResult();
B = tosa::CreateOpAndInfer<mlir::tosa::ReshapeOp>(rewriter, op->getLoc(),
RankedTensorType::get(dynamicTensorShape, BType.getElementType()), B,
rewriter.getDenseI64ArrayAttr(newShapeB))
mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeB))
.getResult();

// If transA or transB are present, create Transpose operators.
Expand Down

0 comments on commit 7a108f7

Please sign in to comment.