diff --git a/docs/BuildOnLinuxOSX.md b/docs/BuildOnLinuxOSX.md index 21fbc37e0c..b87d0205e0 100644 --- a/docs/BuildOnLinuxOSX.md +++ b/docs/BuildOnLinuxOSX.md @@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd .. +cd llvm-project && git checkout 0e779ad4998ef65907502101c5b82ede05ddfa4e && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) diff --git a/docs/BuildOnWindows.md b/docs/BuildOnWindows.md index 13e2a002ec..144e0415cf 100644 --- a/docs/BuildOnWindows.md +++ b/docs/BuildOnWindows.md @@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd .. +cd llvm-project && git checkout 0e779ad4998ef65907502101c5b82ede05ddfa4e && cd .. ``` [same-as-file]: <> (utils/build-mlir.cmd) diff --git a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp index 114c19d618..6853f9d070 100644 --- a/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ZLowToLLVM/ZLowToLLVMCommon.cpp @@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) { auto int16Ty = IntegerType::get(context, 16); auto int32Ty = IntegerType::get(context, 32); auto int64Ty = IntegerType::get(context, 64); - auto float32Ty = FloatType::getF32(context); + auto float32Ty = Float32Type::get(context); // Declare API type as an enum value, its string name and an LLVM Type // specifying its signature. @@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) { Type llvmI64Ty = IntegerType::get(context, 64); Type llvmI1Ty = IntegerType::get(context, 1); Type llvmI8Ty = IntegerType::get(context, 8); - Type llvmF32Ty = FloatType::getF32(context); + Type llvmF32Ty = Float32Type::get(context); Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3); Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20); Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty); @@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module, scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float"); create.llvm.store(recScale, recScalePtr); } else { - Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.); + Value zero = create.llvm.constant(Float32Type::get(context), (double)0.); create.llvm.store(zero, recScalePtr); } @@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module, offsetTy.isF32() && "Wrong type for zTensor's offset. 
Must be float"); create.llvm.store(offset, offsetPtr); } else { - Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.); + Value zero = create.llvm.constant(Float32Type::get(context), (double)0.); create.llvm.store(zero, offsetPtr); } diff --git a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp index e976b42b7f..5a4c494f14 100644 --- a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp @@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern { // or // (memref<3x4x5xf64>, index, f64, f64, f64) Type llvmVoidTy = LLVM::LLVMVoidType::get(context); - Type llvmOptionsTy = FloatType::getF32(context); + Type llvmOptionsTy = Float32Type::get(context); Type llvmOutputTy = getPointerType(context, llvmOptionsTy); if (inType.isF64()) { - llvmOptionsTy = FloatType::getF64(context); + llvmOptionsTy = Float64Type::get(context); llvmOutputTy = getPointerType(context, llvmOptionsTy); } Type llvmI64Ty = IntegerType::get(context, 64); diff --git a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp index 2a0ee747c7..a50acf402f 100644 --- a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp @@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern { Type outType = op->getResultTypes().front(); Type llvmInType, llvmOutType; if (inType.isF16()) - llvmInType = FloatType::getF16(context); + llvmInType = Float16Type::get(context); else if (inType.isF32()) - llvmInType = FloatType::getF32(context); + llvmInType = Float32Type::get(context); else if (inType.isF64()) - llvmInType = FloatType::getF64(context); + llvmInType = Float64Type::get(context); else if (inType.isBF16()) - llvmInType = FloatType::getBF16(context); + llvmInType = Float64Type::get(context); if (outType.isInteger(1)) llvmOutType = IntegerType::get(context, 1); else if (outType.isF32()) - llvmOutType = FloatType::getF32(context); + llvmOutType = Float32Type::get(context); else if (outType.isF64()) - llvmOutType = FloatType::getF64(context); + llvmOutType = Float64Type::get(context); // Insert and/or get reference to elementary math function declaration. assert( @@ -214,7 +214,6 @@ class KrnlUnaryMathOpLowering : public ConversionPattern { return SymbolRefAttr::get(context, mathFuncName); // Create function declaration. - // auto llvmF32Ty = FloatType::get(context); auto llvmFnType = LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef({llvmInType})); diff --git a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp index 62d7c25de3..a52e57afe7 100644 --- a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp @@ -62,7 +62,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern { // Get memRefDescriptor, the new memref descriptor. MemRefDescriptor memRefDescriptor = - MemRefDescriptor::undef(rewriter, loc, targetStructType); + MemRefDescriptor::poison(rewriter, loc, targetStructType); auto targetElementPtrType = memRefDescriptor.getElementPtrType(); // Set the new memref to the same buffer as the source memref. @@ -78,7 +78,7 @@ class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern { int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(targetType, strides, offset))) + if (failed(targetType.getStridesAndOffset(strides, offset))) return failure(); // Unhandled dynamic offset. 
diff --git a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp index 565e63a7d7..00e252fdb6 100644 --- a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp +++ b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp @@ -281,7 +281,7 @@ struct ONNXCategoryMapperOpLowering SmallVector strides; int64_t alignmentOffset; // not used, just to make the function call // completed. - if (getStridesAndOffset(memRefType, strides, alignmentOffset) + if (memRefType.getStridesAndOffset(strides, alignmentOffset) .failed()) llvm_unreachable("Failed to get strides"); Value stringMemRef = diff --git a/src/Conversion/ONNXToKrnl/Math/LRN.cpp b/src/Conversion/ONNXToKrnl/Math/LRN.cpp index 1b08661a2d..12a596d08c 100644 --- a/src/Conversion/ONNXToKrnl/Math/LRN.cpp +++ b/src/Conversion/ONNXToKrnl/Math/LRN.cpp @@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern { float alphaLit = adaptor.getAlpha().convertToFloat(); float betaLit = adaptor.getBeta().convertToFloat(); int sizeLit = adaptor.getSize(); - auto f32Type = FloatType::getF32(rewriter.getContext()); + auto f32Type = Float32Type::get(rewriter.getContext()); Value biasValue = create.math.constant(f32Type, biasLit); Value alphaDivSizeValue = create.math.constant(f32Type, alphaLit / static_cast(sizeLit)); diff --git a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp index adf494c88e..655bf3c89e 100644 --- a/src/Conversion/ONNXToTOSA/DialectBuilder.cpp +++ b/src/Conversion/ONNXToTOSA/DialectBuilder.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -147,14 +148,16 @@ Value TosaBuilder::transpose(Value &value, llvm::ArrayRef perm) { Value TosaBuilder::slice(Value &inputConst, llvm::ArrayRef size, llvm::ArrayRef start) { - DenseI64ArrayAttr sizeAttr = rewriter().getDenseI64ArrayAttr(size); - DenseI64ArrayAttr startAttr = rewriter().getDenseI64ArrayAttr(start); + auto startVal = + mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(start)); + auto sizeVal = + mlir::tosa::getTosaConstShape(rewriter(), loc(), llvm::to_vector(size)); Value newSliceInput = tosa::CreateOpAndInfer(rewriter(), loc(), RankedTensorType::get( llvm::SmallVector(size.size(), ShapedType::kDynamic), mlir::cast(inputConst.getType()).getElementType()), - inputConst, startAttr, sizeAttr); + inputConst, startVal, sizeVal); return newSliceInput; } @@ -164,8 +167,9 @@ Value TosaBuilder::reshape(Value &value, llvm::ArrayRef shape) { Type newValueType = RankedTensorType::get( llvm::SmallVector(shape.size(), ShapedType::kDynamic), valueType.getElementType()); - return tosa::CreateOpAndInfer( - rewriter(), loc(), newValueType, value, shapeAttr); + return tosa::CreateOpAndInfer(rewriter(), loc(), + newValueType, value, + mlir::tosa::getTosaConstShape(rewriter(), loc(), shapeAttr)); } Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) { @@ -178,8 +182,12 @@ Value TosaBuilder::mul(Value &lhs, Value &rhs, int32_t shift) { Type newValueType = RankedTensorType::get( llvm::SmallVector(lhsType.getRank(), ShapedType::kDynamic), lhsType.getElementType()); + + auto int8Type = rewriter().getI8Type(); + auto shiftValue = TosaBuilder::createConst( + ArrayRef{static_cast(shift)}, {1}, int8Type); return tosa::CreateOpAndInfer( - rewriter(), loc(), newValueType, lhs, rhs, 
shift); + rewriter(), loc(), newValueType, lhs, rhs, shiftValue); } Value TosaBuilder::intdiv(Value &lhs, Value &rhs) { @@ -236,8 +244,8 @@ template Value TosaBuilder::binaryOp(Value &lhs, Value &rhs); // Return null if none is found. ElementsAttr IndexExprBuilderForTosa::getConst(Value value) { auto definingOp = value.getDefiningOp(); - // If we have a cast between index/integer, skip it, i.e. get the defining op - // that is the input to the cast. + // If we have a cast between index/integer, skip it, i.e. get the defining + // op that is the input to the cast. if (auto castOp = dyn_cast_or_null(definingOp)) { Value input = castOp.getIn(); definingOp = input.getDefiningOp(); diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 2e105d2dc5..ab8b9a43a0 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -121,11 +121,21 @@ class ONNXReluOpLoweringToTOSA : public OpConversionPattern { // Quantized types are not supported right now (in type conversion). // Once they are, the input should be rescaled for quantized types. (TBD) // Maps to `tosa.clamp` which has both int and fp limits. - rewriter.replaceOpWithNewOp(op, op.getType(), input, - rewriter.getI64IntegerAttr(0), - rewriter.getI64IntegerAttr(std::numeric_limits::max()), - rewriter.getF32FloatAttr(0.0f), - rewriter.getF32FloatAttr(std::numeric_limits::max())); + auto inputElementType = + llvm::cast(op.getType()).getElementType(); + if (llvm::isa(inputElementType)) { + auto minClamp = rewriter.getI64IntegerAttr(0); + auto maxClamp = + rewriter.getI64IntegerAttr(std::numeric_limits::max()); + rewriter.replaceOpWithNewOp( + op, op.getType(), input, minClamp, maxClamp); + } else { + auto minClamp = rewriter.getF32FloatAttr(0.0f); + auto maxClamp = + rewriter.getF32FloatAttr(std::numeric_limits::max()); + rewriter.replaceOpWithNewOp( + op, op.getType(), input, minClamp, maxClamp); + } return success(); } }; diff --git a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp index 4f1028002c..dedd51a451 100644 --- a/src/Conversion/ONNXToTOSA/Math/Gemm.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Gemm.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" @@ -67,13 +68,14 @@ class ONNXGemmOpLoweringToTOSA : public OpConversionPattern { llvm::SmallVector dynamicTensorShape = { ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic}; + A = tosa::CreateOpAndInfer(rewriter, op->getLoc(), RankedTensorType::get(dynamicTensorShape, AType.getElementType()), A, - rewriter.getDenseI64ArrayAttr(newShapeA)) + mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeA)) .getResult(); B = tosa::CreateOpAndInfer(rewriter, op->getLoc(), RankedTensorType::get(dynamicTensorShape, BType.getElementType()), B, - rewriter.getDenseI64ArrayAttr(newShapeB)) + mlir::tosa::getTosaConstShape(rewriter, op.getLoc(), newShapeB)) .getResult(); // If transA or transB are present, create Transpose operators. 
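The TOSA hunks above track the upstream dialect change in which tosa.reshape and tosa.slice take their target shapes as !tosa.shape operands built with mlir::tosa::getTosaConstShape, rather than DenseI64ArrayAttr attributes, and tosa.mul takes its shift as a one-element i8 constant tensor. The sketch below shows only the new slice construction; it is illustrative, not patch code, and the function and variable names are invented.

```cpp
// Illustrative sketch only, assuming the TOSA utilities used by this patch.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"

// Slice out the first row of a 2-D tensor. Start and size are now passed as
// !tosa.shape operands created from constants, not as DenseI64ArrayAttr.
static mlir::Value sliceFirstRow(mlir::PatternRewriter &rewriter,
    mlir::Location loc, mlir::Value input, mlir::RankedTensorType resultTy) {
  llvm::SmallVector<int64_t> start = {0, 0};
  llvm::SmallVector<int64_t> size = {1, resultTy.getDimSize(1)};
  mlir::Value startVal = mlir::tosa::getTosaConstShape(rewriter, loc, start);
  mlir::Value sizeVal = mlir::tosa::getTosaConstShape(rewriter, loc, size);
  return rewriter.create<mlir::tosa::SliceOp>(
      loc, resultTy, input, startVal, sizeVal);
}
```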
diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp index 321a2b35e2..1ec1dec493 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp @@ -14,6 +14,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -60,8 +61,8 @@ Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter, } tosaPads.insert(tosaPads.end(), lastVals.begin(), lastVals.end()); TosaBuilder tosaBuilder(rewriter, loc); - return tosaBuilder.getConst( - tosaPads, {static_cast(tosaPads.size())}); + + return mlir::tosa::getTosaConstShape(rewriter, loc, tosaPads); } } // namespace tosa diff --git a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp index bcd5c7c128..6b00198e17 100644 --- a/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp +++ b/src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp @@ -45,6 +45,7 @@ T getValueFromTosaConst(mlir::Value &val) { template TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type result_ty, Args &&... args) { + auto op = rewriter.create(loc, result_ty, args...); mlir::InferShapedTypeOpInterface shapeInterface = @@ -64,6 +65,7 @@ TosaOp CreateOpAndInfer(mlir::PatternRewriter &rewriter, mlir::Location loc, // the new result shaped type. This is because rescale can include a cast to // different bit-width types and does not have a TypeAttr to define the // target type. 
+ assert(returnedShapes.size() >= 1 && "Expected at least one returned shape"); auto predictedShape = returnedShapes[0]; if (predictedShape.hasRank()) updateType(nullptr, op, predictedShape.getDims(), diff --git a/src/Dialect/ONNX/ElementsAttr/BType.cpp b/src/Dialect/ONNX/ElementsAttr/BType.cpp index 8073d2a4e2..a6aa4b17f5 100644 --- a/src/Dialect/ONNX/ElementsAttr/BType.cpp +++ b/src/Dialect/ONNX/ElementsAttr/BType.cpp @@ -55,10 +55,10 @@ Type mlirTypeOfBType(BType btype, MLIRContext *ctx) { case BType::FLOAT : return b.getF32Type(); case BType::FLOAT16 : return b.getF16Type(); case BType::BFLOAT16 : return b.getBF16Type(); - case BType::FLOAT8E4M3FN : return b.getFloat8E4M3FNType(); - case BType::FLOAT8E4M3FNUZ : return b.getFloat8E4M3FNUZType(); - case BType::FLOAT8E5M2 : return b.getFloat8E5M2Type(); - case BType::FLOAT8E5M2FNUZ : return b.getFloat8E5M2FNUZType(); + case BType::FLOAT8E4M3FN : return b.getType<Float8E4M3FNType>(); + case BType::FLOAT8E4M3FNUZ : return b.getType<Float8E4M3FNUZType>(); + case BType::FLOAT8E5M2 : return b.getType<Float8E5M2Type>(); + case BType::FLOAT8E5M2FNUZ : return b.getType<Float8E5M2FNUZType>(); default: llvm_unreachable("unsupported data type"); } // clang-format on @@ -104,4 +104,4 @@ BType wideBTypeOfBType(BType d) { [](auto btype) { return toBType::widetype>; }); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp index 47a74a0093..56fd3c5ca8 100644 --- a/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp +++ b/src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp @@ -96,7 +96,7 @@ LogicalResult ONNXOneHotEncoderOp::inferShapes( return success(); ONNXOneHotEncoderOpShapeHelper shapeHelper(getOperation(), {}); - return shapeHelper.computeShapeAndUpdateType(FloatType::getF32(getContext())); + return shapeHelper.computeShapeAndUpdateType(Float32Type::get(getContext())); return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp index a38ddfcb11..13308602cd 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp @@ -452,7 +452,7 @@ LogicalResult ONNXScalerOp::inferShapes( ONNXUnaryOpShapeHelper shapeHelper(getOperation(), {}); RankedTensorType xType = mlir::dyn_cast<RankedTensorType>(getX().getType()); return shapeHelper.computeShapeAndUpdateType( - FloatType::getF32(getContext()), xType.getEncoding()); + Float32Type::get(getContext()), xType.getEncoding()); } //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp b/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp index 926f37764f..e5cdb01cde 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp @@ -47,16 +47,16 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() { Type elementType; if (auto attr = getDtypeAttr()) { if (getDtype() == 0) { - elementType = FloatType::getF16(getContext()); + elementType = Float16Type::get(getContext()); } else if (getDtype() == 1) { - elementType = FloatType::getF32(getContext()); + elementType = Float32Type::get(getContext()); } else if (getDtype() == 2) { - elementType = FloatType::getF64(getContext()); + elementType = Float64Type::get(getContext()); } else { llvm_unreachable("dtype not supported for RandomNormal"); } } else { - elementType = FloatType::getF32(getContext()); + elementType = Float32Type::get(getContext()); } return 
{UnrankedTensorType::get(elementType)}; } @@ -68,11 +68,11 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() { LogicalResult ONNXRandomNormalOp::inferShapes( std::function<void(Region &)> doShapeInference) { auto elementTypeID = getDtype(); - Type elementType = FloatType::getF32(getContext()); + Type elementType = Float32Type::get(getContext()); if (elementTypeID == 0) - elementType = FloatType::getF16(getContext()); + elementType = Float16Type::get(getContext()); else if (elementTypeID == 2) - elementType = FloatType::getF64(getContext()); + elementType = Float64Type::get(getContext()); ONNXRandomNormalOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateType(elementType); diff --git a/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp b/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp index 9df2bbe18b..321d2b55a1 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp @@ -42,13 +42,11 @@ LogicalResult ONNXRandomNormalLikeOp::verify() { if (elementTypeID < 0 || elementTypeID > 2) { return emitOpError("dtype not 0, 1 or 2."); } - if (elementTypeID == 0 && outputType != FloatType::getF16(getContext())) + if (elementTypeID == 0 && outputType != Float16Type::get(getContext())) return emitOpError("output tensor does match 0 dtype."); - else if (elementTypeID == 1 && - outputType != FloatType::getF32(getContext())) + else if (elementTypeID == 1 && outputType != Float32Type::get(getContext())) return emitOpError("output tensor does match 1 dtype."); - else if (elementTypeID == 2 && - outputType != FloatType::getF64(getContext())) + else if (elementTypeID == 2 && outputType != Float64Type::get(getContext())) return emitOpError("output tensor does match 2 dtype."); } else if (inputType != outputType) { return emitOpError("output and input element types do not match."); @@ -75,11 +73,11 @@ LogicalResult ONNXRandomNormalLikeOp::inferShapes( } else { int64_t elementTypeID = elementTypeIDDType.value(); if (elementTypeID == 0) - elementType = FloatType::getF16(getContext()); + elementType = Float16Type::get(getContext()); else if (elementTypeID == 1) - elementType = FloatType::getF32(getContext()); + elementType = Float32Type::get(getContext()); else if (elementTypeID == 2) - elementType = FloatType::getF64(getContext()); + elementType = Float64Type::get(getContext()); else return emitError("dtype attribute is invalid (use: 0, 1 or 2)"); } diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index 7f260f2e99..9eb0b27d21 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -608,13 +608,13 @@ Type convertONNXTypeToMLIRType( Builder &builder, onnx::TensorProto_DataType onnxType) { switch (onnxType) { case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN: - return builder.getFloat8E4M3FNType(); + return builder.getType<Float8E4M3FNType>(); case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ: - return builder.getFloat8E4M3FNUZType(); + return builder.getType<Float8E4M3FNUZType>(); case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2: - return builder.getFloat8E5M2Type(); + return builder.getType<Float8E5M2Type>(); case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ: - return builder.getFloat8E5M2FNUZType(); + return builder.getType<Float8E5M2FNUZType>(); case onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16: return builder.getBF16Type(); case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: diff --git 
a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp index 7f27d19ebb..ae1ea165fd 100644 --- a/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp +++ b/src/Dialect/ONNX/ONNXOps/Quantize/DynamicQuantizeLinear.cpp @@ -61,7 +61,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes( IntegerType ui8Type = IntegerType::get(getContext(), 8, IntegerType::Unsigned); - FloatType f32Type = FloatType::getF32(getContext()); + FloatType f32Type = Float32Type::get(getContext()); ONNXDynamicQuantizeLinearOpShapeHelper shapeHelper(getOperation(), {}); return shapeHelper.computeShapeAndUpdateTypes( diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp index 70ee132682..bfa487d74a 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp @@ -54,10 +54,10 @@ std::vector ONNXConstantOp::resultTypeInference() { } else if (auto attr = getSparseValueAttr()) { type = mlir::cast(attr).getShapedType(); } else if (auto attr = getValueFloatAttr()) { - type = RankedTensorType::get({}, FloatType::getF32(getContext())); + type = RankedTensorType::get({}, Float32Type::get(getContext())); } else if (auto attr = getValueFloatsAttr()) { int64_t size = attr.size(); - type = RankedTensorType::get({size}, FloatType::getF32(getContext())); + type = RankedTensorType::get({size}, Float32Type::get(getContext())); } else if (auto attr = getValueIntAttr()) { type = RankedTensorType::get({}, IntegerType::get(getContext(), 64)); } else if (auto attr = getValueIntsAttr()) { diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp index 6058adfcdb..773152fc52 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp @@ -99,7 +99,7 @@ std::vector ONNXConstantOfShapeOp::resultTypeInference() { if (auto attr = getValueAttr()) { elementType = mlir::cast(attr).getElementType(); } else { - elementType = FloatType::getF32(getContext()); + elementType = Float32Type::get(getContext()); } return {UnrankedTensorType::get(elementType)}; } @@ -125,7 +125,7 @@ LogicalResult ONNXConstantOfShapeOp::inferShapes( } else { // If 'value' attribute is not specified, it defaults to a tensor of // value 0 and datatype float32. 
- elementType = FloatType::getF32(getContext()); + elementType = Float32Type::get(getContext()); llvm::SmallVector dims(1, 1); auto tensorType = RankedTensorType::get(dims, elementType); diff --git a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir index 2b56c8db2b..1db01a656a 100644 --- a/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir +++ b/test/mlir/accelerators/nnpa/conversion/lower-all-to-llvm-constant-shape.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-krnl-to-llvm %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --convert-krnl-to-llvm -cse %s -split-input-file | FileCheck %s // ----- @@ -15,11 +15,31 @@ func.func @test_zlow_softmax_constant_shape() -> () { %work_area = memref.alloc() {alignment = 4096 : i64} : memref<8192xi8> "zlow.softmax"(%input, %work_area, %shape, %res) {act_func = "ACT_NONE"} : (memref<1x1x1x1x32x64xf16>, memref<8192xi8>, memref<3xi64>, memref<1x1x1x1x32x64xf16>) -> () return +} +// CHECK: llvm.mlir.global internal constant @[[SHAPE_CONST_GLOBAL:.*]](dense<[1, 5, 10]> : tensor<3xi64>) {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<3 x i64> +// CHECK-LABEL: llvm.func @test_zlow_softmax_constant_shape +// CHECK-DAG: [[SHAPE_MEMREF_0:%.+]] = llvm.mlir.addressof @[[SHAPE_CONST_GLOBAL]] : !llvm.ptr +// CHECK-DAG: [[SHAPE_MEMREF_1:%.+]] = llvm.bitcast [[SHAPE_MEMREF_0]] : !llvm.ptr to !llvm.ptr +// CHECK-DAG: [[SHAPE_MEMREF_2:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: [[SHAPE_MEMREF_3:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_1]], [[SHAPE_MEMREF_2]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: [[SHAPE_MEMREF_4:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_1]], [[SHAPE_MEMREF_3]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: [[SHAPE_MEMREF_5:%.+]] = llvm.mlir.constant(0 : index) : i64 +// CHECK-NEXT: [[SHAPE_MEMREF_6:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_5]], [[SHAPE_MEMREF_4]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: [[SHAPE_MEMREF_7:%.+]] = llvm.mlir.constant(3 : index) : i64 +// CHECK-NEXT: [[SHAPE_MEMREF_8:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_7]], [[SHAPE_MEMREF_6]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: [[SHAPE_MEMREF_9:%.+]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: [[SHAPE_MEMREF_10:%.+]] = llvm.insertvalue [[SHAPE_MEMREF_9]], [[SHAPE_MEMREF_8]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - // CHECK-LABEL: llvm.func @test_zlow_softmax_constant_shape() {{.*}} { - // CHECK: %[[DIM0:.*]] = llvm.mlir.constant(1 : i64) : i64 - // CHECK: %[[DIM1:.*]] = llvm.mlir.constant(5 : i64) : i64 - // CHECK: %[[DIM2:.*]] = llvm.mlir.constant(10 : i64) : i64 - // CHECK: llvm.call @zdnn_init_pre_transformed_desc({{.*}}, {{.*}}, {{.*}}, %[[DIM0]], %[[DIM1]], %[[DIM2]]) vararg(!llvm.func) : (i64, i64, !llvm.ptr, i64, i64, i64) -> () +// ... 
-} +// CHECK: %[[SHAPE:.*]] = llvm.extractvalue [[SHAPE_MEMREF_10]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[DIM0_0:.*]] = llvm.getelementptr %[[SHAPE]][0] : (!llvm.ptr) -> !llvm.ptr, i64 +// CHECK-NEXT: %[[DIM0_1:.*]] = llvm.load %[[DIM0_0]] : !llvm.ptr -> i64 +// CHECK-NEXT: %[[DIM1_0:.*]] = llvm.getelementptr %[[SHAPE]][1] : (!llvm.ptr) -> !llvm.ptr, i64 +// CHECK-NEXT: %[[DIM1_1:.*]] = llvm.load %[[DIM1_0]] : !llvm.ptr -> i64 +// CHECK-NEXT: %[[DIM2_0:.*]] = llvm.getelementptr %[[SHAPE]][2] : (!llvm.ptr) -> !llvm.ptr, i64 +// CHECK-NEXT: %[[DIM2_1:.*]] = llvm.load %[[DIM2_0]] : !llvm.ptr -> i64 + +// ... + +// CHECK: llvm.call @zdnn_init_pre_transformed_desc({{.*}}, {{.*}}, {{.*}}, %[[DIM0_1]], %[[DIM1_1]], %[[DIM2_1]]) vararg(!llvm.func) : (i64, i64, !llvm.ptr, i64, i64, i64) -> () diff --git a/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir b/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir index 9711b01c79..4d871b5260 100644 --- a/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir +++ b/test/mlir/conversion/krnl_to_llvm/krnl_category_mapper.mlir @@ -179,7 +179,7 @@ func.func private @test_category_mapper_int64_to_string(%arg0: memref<2x2xi64>) // CHECK-LABEL: @test_category_mapper_int64_to_string(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-DAG: [[LEN:%.+]] = llvm.mlir.constant(3 : i32) : i32 // CHECK: [[MALLOC:%.+]] = llvm.call @malloc({{.*}}) : (i64) -> !llvm.ptr - // CHECK: [[UNDEF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> + // CHECK: [[UNDEF:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)> // CHECK: [[EV_1:%.+]] = llvm.insertvalue {{.*}}, [[UNDEF]][0] // CHECK: [[EV_2:%.+]] = llvm.insertvalue {{.*}}, [[EV_1]][1] // CHECK: [[C0:%.+]] = llvm.mlir.constant(0 : index) : i64 @@ -222,7 +222,7 @@ func.func private @test_krnl_global_with_129_elements() -> memref<129x!krnl.stri // CHECK: llvm.func @test_krnl_global_with_129_elements() -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> attributes {llvm.emit_c_interface, sym_visibility = "private"} { // CHECK: [[VAR_0_1_:%.+]] = llvm.mlir.addressof @cats_strings : !llvm.ptr // CHECK-DAG: [[VAR_1_1_:%.+]] = llvm.bitcast [[VAR_0_1_]] : !llvm.ptr to !llvm.ptr - // CHECK-DAG: [[VAR_2_1_:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-DAG: [[VAR_2_1_:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[VAR_3_1_:%.+]] = llvm.insertvalue [[VAR_1_1_]], [[VAR_2_1_]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-DAG: [[VAR_4_1_:%.+]] = llvm.insertvalue [[VAR_1_1_]], [[VAR_3_1_]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-DAG: [[VAR_5_1_:%.+]] = llvm.mlir.constant(0 : index) : i64 diff --git a/test/mlir/conversion/krnl_to_llvm/krnl_global_with_alignment_lowering.mlir b/test/mlir/conversion/krnl_to_llvm/krnl_global_with_alignment_lowering.mlir index c736b12cf4..65a9bf0f73 100644 --- a/test/mlir/conversion/krnl_to_llvm/krnl_global_with_alignment_lowering.mlir +++ b/test/mlir/conversion/krnl_to_llvm/krnl_global_with_alignment_lowering.mlir @@ -11,7 +11,7 @@ func.func @test_krnl_global_constant_alignment() -> memref<3xf32> { // CHECK-LABEL: llvm.func @test_krnl_global_constant_alignment() -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
attributes {llvm.emit_c_interface} { // CHECK: [[VAR_0_:%.+]] = llvm.mlir.addressof @constant : !llvm.ptr // CHECK-DAG: [[VAR_1_:%.+]] = llvm.bitcast [[VAR_0_]] : !llvm.ptr to !llvm.ptr -// CHECK-DAG: [[VAR_2_:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-DAG: [[VAR_2_:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[VAR_3_:%.+]] = llvm.insertvalue [[VAR_1_]], [[VAR_2_]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-DAG: [[VAR_4_:%.+]] = llvm.insertvalue [[VAR_1_]], [[VAR_3_]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-DAG: [[VAR_5_:%.+]] = llvm.mlir.constant(0 : index) : i64 @@ -37,7 +37,7 @@ func.func @test_krnl_global_constant_no_alignment() -> memref<2xi64> { // CHECK-LABEL: llvm.func @test_krnl_global_constant_no_alignment() -> !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> attributes {llvm.emit_c_interface} { // CHECK: [[VAR_0_:%.+]] = llvm.mlir.addressof @constant : !llvm.ptr // CHECK-DAG: [[VAR_1_:%.+]] = llvm.bitcast [[VAR_0_]] : !llvm.ptr to !llvm.ptr -// CHECK-DAG: [[VAR_2_:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-DAG: [[VAR_2_:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: [[VAR_3_:%.+]] = llvm.insertvalue [[VAR_1_]], [[VAR_2_]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-DAG: [[VAR_4_:%.+]] = llvm.insertvalue [[VAR_1_]], [[VAR_3_]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-DAG: [[VAR_5_:%.+]] = llvm.mlir.constant(0 : index) : i64 diff --git a/test/mlir/conversion/krnl_to_llvm/reshape.mlir b/test/mlir/conversion/krnl_to_llvm/reshape.mlir index 97d5374ec5..80edf5c5a6 100644 --- a/test/mlir/conversion/krnl_to_llvm/reshape.mlir +++ b/test/mlir/conversion/krnl_to_llvm/reshape.mlir @@ -7,7 +7,7 @@ func.func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi64>) -> tens "func.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: llvm.func @test_reshape -// CHECK: [[OLD_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: [[OLD_MEMREF:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[INSERT_1_:%.+]] = llvm.insertvalue {{.*}}, [[OLD_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[INSERT_2_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_1_]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: [[INSERT_3_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_2_]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -17,7 +17,7 @@ func.func @test_reshape(%arg0 : tensor, %arg1 : tensor<4xi64>) -> tens // CHECK-DAG:[[INSERT_7_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_6_]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // COM: Check that there is no copy but only a new MemRef with a new view, i.e. new sizes and strides. 
-// CHECK-DAG: [[NEW_MEMREF:%.+]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> +// CHECK-DAG: [[NEW_MEMREF:%.+]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> // CHECK: [[INSERT_8_:%.+]] = llvm.insertvalue {{.*}}, [[NEW_MEMREF]][0] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> // CHECK-DAG: [[INSERT_9_:%.+]] = llvm.insertvalue {{.*}}, [[INSERT_8_]][1] : !llvm.struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)> // CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : index) : i64 diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize.mlir index 6b2e395407..0baf28f0ab 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize.mlir @@ -290,14 +290,11 @@ func.func private @test_reducesum1(%arg0: tensor<3x2x2xf32>, %arg1: tensor [[I_1_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_2_:%.+]] = 0 to 2, [[LOOP_1_]]#2 -> [[I_3_:%.+]] = 0 to 2){ // CHECK-DAG: [[VAR_2_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) // CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_4_1_:%.+]] = arith.cmpi eq, [[LOAD_PARAM_1_MEM_1_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.select [[VAR_4_1_]], [[CST_0_1_]], [[VAR_2_1_]]#0 : index +// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.select [[LOAD_PARAM_1_MEM_1_]], [[CST_0_1_]], [[VAR_2_1_]]#0 : index // CHECK-DAG: [[VAR_6_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi eq, [[VAR_6_1_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_1_]], [[CST_0_1_]], [[VAR_2_1_]]#1 : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_6_1_]], [[CST_0_1_]], [[VAR_2_1_]]#1 : index // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_10_:%.+]] = arith.cmpi eq, [[LOAD_RES_MEM_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[CST_0_1_]], [[VAR_2_1_]]#2 : index +// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[LOAD_RES_MEM_]], [[CST_0_1_]], [[VAR_2_1_]]#2 : index // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2] : memref<3x2x2xf32> // CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_5_1_]], [[VAR_8_]], [[VAR_11_]]{{.}} : memref<3x1x2xf32> // CHECK: [[VAR_14_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 @@ -348,14 +345,11 @@ func.func @test_reducesum2(%arg0: tensor<3x2x2xf32>, %arg1: tensor) -> te // CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_2_:%.+]] = 0 to 2, [[LOOP_1_]]#2 -> [[I_3_:%.+]] = 0 to 2){ // CHECK-DAG: [[VAR_3_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) // CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_5_1_:%.+]] = arith.cmpi eq, [[LOAD_PARAM_1_MEM_1_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_6_1_:%.+]] = arith.select [[VAR_5_1_]], [[CST_0_1_]], [[VAR_3_1_]]#0 : index +// CHECK-DAG: [[VAR_6_1_:%.+]] 
= arith.select [[LOAD_PARAM_1_MEM_1_]], [[CST_0_1_]], [[VAR_3_1_]]#0 : index // CHECK-DAG: [[VAR_7_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_8_1_:%.+]] = arith.cmpi eq, [[VAR_7_1_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_1_]], [[CST_0_1_]], [[VAR_3_1_]]#1 : index +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_1_]], [[CST_0_1_]], [[VAR_3_1_]]#1 : index // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_11_:%.+]] = arith.cmpi eq, [[LOAD_RES_MEM_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[CST_0_1_]], [[VAR_3_1_]]#2 : index +// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[LOAD_RES_MEM_]], [[CST_0_1_]], [[VAR_3_1_]]#2 : index // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_3_1_]]#0, [[VAR_3_1_]]#1, [[VAR_3_1_]]#2] : memref<3x2x2xf32> // CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_6_1_]], [[VAR_9_]], [[VAR_12_]]{{.}} : memref<3x1x2xf32> // CHECK: [[VAR_15_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir index f35603fc9e..59c67b8377 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir @@ -245,14 +245,11 @@ func.func private @test_reducesum1(%arg0: tensor<3x2x2xf32>, %arg1: tensor [[I_1_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_2_:%.+]] = 0 to 2, [[LOOP_1_]]#2 -> [[I_3_:%.+]] = 0 to 2){ // CHECK-DAG: [[VAR_2_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) // CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_4_1_:%.+]] = arith.cmpi eq, [[LOAD_PARAM_1_MEM_1_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.select [[VAR_4_1_]], [[CST_0_1_]], [[VAR_2_1_]]#0 : index +// CHECK-DAG: [[VAR_5_1_:%.+]] = arith.select [[LOAD_PARAM_1_MEM_1_]], [[CST_0_1_]], [[VAR_2_1_]]#0 : index // CHECK-DAG: [[VAR_6_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi eq, [[VAR_6_1_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_1_]], [[CST_0_1_]], [[VAR_2_1_]]#1 : index -// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_10_:%.+]] = arith.cmpi eq, [[LOAD_RES_MEM_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_10_]], [[CST_0_1_]], [[VAR_2_1_]]#2 : index +// CHECK: [[VAR_8_:%.+]] = arith.select [[VAR_6_1_]], [[CST_0_1_]], [[VAR_2_1_]]#1 : index +// CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[LOAD_RES_MEM_]], [[CST_0_1_]], [[VAR_2_1_]]#2 : index // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_1_]]#0, [[VAR_2_1_]]#1, [[VAR_2_1_]]#2] : memref<3x2x2xf32> // CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_5_1_]], [[VAR_8_]], [[VAR_11_]]{{.}} : memref<3x1x2xf32> // CHECK: [[VAR_14_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 @@ -303,14 +300,11 @@ func.func @test_reducesum2(%arg0: tensor<3x2x2xf32>, %arg1: tensor) -> te // 
CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) with ([[LOOP_1_]]#0 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_1_]]#1 -> [[I_2_:%.+]] = 0 to 2, [[LOOP_1_]]#2 -> [[I_3_:%.+]] = 0 to 2){ // CHECK-DAG: [[VAR_3_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index) // CHECK-DAG: [[LOAD_PARAM_1_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_5_1_:%.+]] = arith.cmpi eq, [[LOAD_PARAM_1_MEM_1_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_6_1_:%.+]] = arith.select [[VAR_5_1_]], [[CST_0_1_]], [[VAR_3_1_]]#0 : index +// CHECK-DAG: [[VAR_6_1_:%.+]] = arith.select [[LOAD_PARAM_1_MEM_1_]], [[CST_0_1_]], [[VAR_3_1_]]#0 : index // CHECK-DAG: [[VAR_7_1_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_8_1_:%.+]] = arith.cmpi eq, [[VAR_7_1_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_1_]], [[CST_0_1_]], [[VAR_3_1_]]#1 : index +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_1_]], [[CST_0_1_]], [[VAR_3_1_]]#1 : index // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_11_:%.+]] = arith.cmpi eq, [[LOAD_RES_MEM_]], [[VAR_true_]] : i1 -// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_11_]], [[CST_0_1_]], [[VAR_3_1_]]#2 : index +// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[LOAD_RES_MEM_]], [[CST_0_1_]], [[VAR_3_1_]]#2 : index // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_3_1_]]#0, [[VAR_3_1_]]#1, [[VAR_3_1_]]#2] : memref<3x2x2xf32> // CHECK: [[LOAD_RES_1_MEM_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_6_1_]], [[VAR_9_]], [[VAR_12_]]{{.}} : memref<3x1x2xf32> // CHECK: [[VAR_15_:%.+]] = arith.addf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 diff --git a/test/mlir/conversion/onnx_to_krnl/Tensor/Compress_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Tensor/Compress_with_canonicalize.mlir index 336fa49a2a..c309eacffe 100644 --- a/test/mlir/conversion/onnx_to_krnl/Tensor/Compress_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Tensor/Compress_with_canonicalize.mlir @@ -14,15 +14,13 @@ func.func @compress_axis0(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi1>) -> tensor // CHECK-SAME: ([[INPUT_:%.+]]: memref<3x2xf32>, [[CONDITION_:%.+]]: memref<3xi1>) -> memref { // CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false // CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref // CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 3){ // CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index // CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_7_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_c1_]], [[VAR_c0_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref // CHECK: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_8_]] : index // CHECK: krnl.store [[VAR_10_]], [[RES_]][] : memref @@ -34,8 +32,7 @@ func.func @compress_axis0(%arg0: tensor<3x2xf32>, %arg1: 
tensor<3xi1>) -> tensor // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 3){ // CHECK: [[VAR_5_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index // CHECK: [[LOAD_CONDITION_MEM_1_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_1_]], [[VAR_false_]] : i1 -// CHECK: scf.if [[VAR_7_1_]] { +// CHECK: scf.if [[LOAD_CONDITION_MEM_1_]] { // CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 2){ @@ -64,15 +61,13 @@ func.func @compress_axis0_not_enough(%arg0: tensor<3x2xf32>, %arg1: tensor<2xi1> // CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false // CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref // CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 2){ // CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index // CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_]]{{.}} : memref<2xi1> -// CHECK: [[VAR_7_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_c1_]], [[VAR_c0_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref // CHECK: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_8_]] : index // CHECK: krnl.store [[VAR_10_]], [[RES_]][] : memref @@ -86,8 +81,7 @@ func.func @compress_axis0_not_enough(%arg0: tensor<3x2xf32>, %arg1: tensor<2xi1> // CHECK: [[LOAD_CONDITION_MEM_1_:%.+]] = arith.cmpi slt, [[VAR_5_1_]], [[VAR_c2_]] : index // CHECK: scf.if [[LOAD_CONDITION_MEM_1_]] { // CHECK: [[LOAD_CONDITION_MEM_2_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_1_]]{{.}} : memref<2xi1> -// CHECK: [[VAR_8_1_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_2_]], [[VAR_false_]] : i1 -// CHECK: scf.if [[VAR_8_1_]] { +// CHECK: scf.if [[LOAD_CONDITION_MEM_2_]] { // CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 2){ @@ -116,15 +110,13 @@ func.func @compress_axis1(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi1>) -> tensor // CHECK-SAME: ([[INPUT_:%.+]]: memref<3x2xf32>, [[CONDITION_:%.+]]: memref<3xi1>) -> memref<3x?xf32> { // CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false // CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref // CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 3){ // CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index // CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_7_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1 -// CHECK-DAG: 
[[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_c1_]], [[VAR_c0_]] : index +// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref // CHECK: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_8_]] : index // CHECK: krnl.store [[VAR_10_]], [[RES_]][] : memref @@ -136,8 +128,7 @@ func.func @compress_axis1(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi1>) -> tensor // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 2){ // CHECK: [[VAR_5_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index // CHECK: [[LOAD_CONDITION_MEM_1_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_5_1_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_7_1_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_1_]], [[VAR_false_]] : i1 -// CHECK: scf.if [[VAR_7_1_]] { +// CHECK: scf.if [[LOAD_CONDITION_MEM_1_]] { // CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 3){ @@ -166,15 +157,13 @@ func.func @compress_no_axis_not_elided(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi // CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false // CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref // CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 3){ // CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index // CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_6_]]{{.}} : memref<3xi1> -// CHECK: [[VAR_8_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_c1_]], [[VAR_c0_]] : index +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref // CHECK: [[VAR_11_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_9_]] : index // CHECK: krnl.store [[VAR_11_]], [[RES_]][] : memref @@ -191,8 +180,7 @@ func.func @compress_no_axis_not_elided(%arg0: tensor<3x2xf32>, %arg1: tensor<3xi // CHECK: [[VAR_8_1_:%.+]] = arith.cmpi slt, [[LOAD_CONDITION_MEM_1_]], [[VAR_c3_]] : index // CHECK: scf.if [[VAR_8_1_]] { // CHECK: [[LOAD_CONDITION_MEM_2_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[LOAD_CONDITION_MEM_1_]]{{.}} : memref<3xi1> -// CHECK: [[LOAD_RES_MEM_2_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_2_]], [[VAR_false_]] : i1 -// CHECK: scf.if [[LOAD_RES_MEM_2_]] { +// CHECK: scf.if [[LOAD_CONDITION_MEM_2_]] { // CHECK-DAG: [[LOAD_INPUT_MEM_:%.+]] = krnl.load [[INPUT_]]{{.}}[[VAR_6_1_]]#0, [[VAR_6_1_]]#1] : memref<3x2xf32> // CHECK-DAG: [[LOAD_RES_MEM_3_:%.+]] = krnl.load [[RES_]][] : memref // CHECK: krnl.store [[LOAD_INPUT_MEM_]], [[RES_1_]]{{.}}[[LOAD_RES_MEM_3_]]{{.}} : memref @@ -218,15 +206,13 @@ func.func @compress_no_axis_enough_cond(%arg0: tensor<3x2xf32>, %arg1: tensor<6x // CHECK-SAME: ([[INPUT_:%.+]]: memref<3x2xf32>, [[CONDITION_:%.+]]: memref<6xi1>) -> memref { // CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false // 
CHECK-DAG: [[RES_:%.+]] = memref.alloca() : memref // CHECK: krnl.store [[VAR_c0_]], [[RES_]][] : memref // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ // CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index // CHECK: [[LOAD_CONDITION_MEM_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[VAR_6_]]{{.}} : memref<6xi1> -// CHECK: [[VAR_8_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_]], [[VAR_false_]] : i1 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_8_]], [[VAR_c1_]], [[VAR_c0_]] : index +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[LOAD_CONDITION_MEM_]], [[VAR_c1_]], [[VAR_c0_]] : index // CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]][] : memref // CHECK: [[VAR_11_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[VAR_9_]] : index // CHECK: krnl.store [[VAR_11_]], [[RES_]][] : memref @@ -241,8 +227,7 @@ func.func @compress_no_axis_enough_cond(%arg0: tensor<3x2xf32>, %arg1: tensor<6x // CHECK-DAG: [[VAR_6_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[LOAD_CONDITION_MEM_1_:%.+]] = krnl.load [[RES_2_]][] : memref // CHECK: [[LOAD_CONDITION_MEM_2_:%.+]] = krnl.load [[CONDITION_]]{{.}}[[LOAD_CONDITION_MEM_1_]]{{.}} : memref<6xi1> -// CHECK: [[VAR_9_1_:%.+]] = arith.cmpi ne, [[LOAD_CONDITION_MEM_2_]], [[VAR_false_]] : i1 -// CHECK: scf.if [[VAR_9_1_]] { +// CHECK: scf.if [[LOAD_CONDITION_MEM_2_]] { // CHECK-DAG: [[VAR_11_1_:%.+]] = krnl.load [[INPUT_]]{{.}}[[VAR_6_1_]]#0, [[VAR_6_1_]]#1] : memref<3x2xf32> // CHECK-DAG: [[LOAD_RES_MEM_2_:%.+]] = krnl.load [[RES_]][] : memref // CHECK: krnl.store [[VAR_11_1_]], [[RES_1_]]{{.}}[[LOAD_RES_MEM_2_]]{{.}} : memref diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir index 0437de4625..d754e00f0c 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Conv.mlir @@ -11,7 +11,8 @@ func.func @test_onnx_conv2d_stride_13(%arg0: tensor<5x3x256x256xf32>, %arg1 : te // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<2x3x64x64xf32>, tensor<4xi32>) -> tensor<2x64x64x3xf32> -// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>) -> tensor<5x15x15x2xf32> +// CHECK: %[[VAL_6_0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]], %[[VAL_6_0]], %[[VAL_6_0]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x15x15x2xf32> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x15x15x2xf32>, tensor<4xi32>) -> tensor<5x2x15x15xf32> // CHECK: return %[[VAL_8]] : tensor<5x2x15x15xf32> @@ -29,7 +30,8 @@ func.func @test_onnx_conv2d_novalue(%arg0: tensor<5x3x256x256xf32>, %arg1 : tens // CHECK: 
%[[VAL_3:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_2]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32> // CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_2]] : (tensor<2x3x64x64xf32>, tensor<4xi32>) -> tensor<2x64x64x3xf32> // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> -// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>) -> tensor<5x197x199x2xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_9]], %[[VAL_9]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x197x199x2xf32> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x197x199x2xf32>, tensor<4xi32>) -> tensor<5x2x197x199xf32> // CHECK: return %[[VAL_8]] : tensor<5x2x197x199xf32> @@ -47,7 +49,8 @@ func.func @test_onnx_conv2d_no_dilation_pad(%arg0: tensor<5x3x256x256xf32>, %arg // CHECK: %[[VAL_3:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_2]] : (tensor<5x3x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x3xf32> // CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_2]] : (tensor<7x3x64x64xf32>, tensor<4xi32>) -> tensor<7x64x64x3xf32> // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<7xf32>}> : () -> tensor<7xf32> -// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<7x64x64x3xf32>, tensor<7xf32>) -> tensor<5x15x15x7xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_9]], %[[VAL_9]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x3xf32>, tensor<7x64x64x3xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x15x15x7xf32> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x15x15x7xf32>, tensor<4xi32>) -> tensor<5x7x15x15xf32> // CHECK: return %[[VAL_8]] : tensor<5x7x15x15xf32> @@ -65,7 +68,8 @@ func.func @test_onnx_conv2d_no_dilation_pad_stride(%arg0: tensor<5x3x256x260xf32 // CHECK: %[[VAL_3:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_2]] : (tensor<5x3x256x260xf32>, tensor<4xi32>) -> tensor<5x256x260x3xf32> // CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_2]] : (tensor<2x3x60x64xf32>, tensor<4xi32>) -> tensor<2x60x64x3xf32> // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> -// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x260x3xf32>, tensor<2x60x64x3xf32>, tensor<2xf32>) -> tensor<5x197x197x2xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_3]], 
%[[VAL_4]], %[[VAL_5]], %[[VAL_9]], %[[VAL_9]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x260x3xf32>, tensor<2x60x64x3xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x197x197x2xf32> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x197x197x2xf32>, tensor<4xi32>) -> tensor<5x2x197x197xf32> // CHECK: return %[[VAL_8]] : tensor<5x2x197x197xf32> @@ -82,22 +86,36 @@ func.func @test_onnx_conv2d_group(%arg0: tensor<5x64x256x256xf32>, %arg1 : tenso // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x64x256x256xf32>, tensor<4xi32>) -> tensor<5x256x256x64xf32> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<12x16x45x45xf32>, tensor<4xi32>) -> tensor<12x45x45x16xf32> -// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32> -// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> -// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> -// CHECK: %[[VAL_9:.*]] = tosa.conv2d %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> -// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32> -// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> -// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.conv2d %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> -// CHECK: %[[VAL_14:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32> -// CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> -// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> -// CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> -// CHECK: %[[VAL_18:.*]] = tosa.slice %[[VAL_4]] {size = array, start = array} : (tensor<5x256x256x64xf32>) -> tensor<5x256x256x16xf32> -// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_5]] {size = array, start = array} : (tensor<12x45x45x16xf32>) -> tensor<3x45x45x16xf32> -// CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_2]] {size = array, start = array} : (tensor<12xf32>) -> tensor<3xf32> -// CHECK: %[[VAL_21:.*]] = tosa.conv2d %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>) -> tensor<5x17x17x3xf32> +// CHECK: %[[STARTS_0:.*]] = tosa.const_shape {value = 
dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[SIZES_0:.*]] = tosa.const_shape {value = dense<[5, 256, 256, 16]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_4]], %[[STARTS_0]], %[[SIZES_0]] : (tensor<5x256x256x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x256x256x16xf32> +// CHECK: %[[SIZES_1:.*]] = tosa.const_shape {value = dense<[3, 45, 45, 16]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_5]], %[[STARTS_0]], %[[SIZES_1]] : (tensor<12x45x45x16xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<3x45x45x16xf32> +// CHECK: %[[STARTS_2:.*]] = tosa.const_shape {value = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[THREE:.*]] = tosa.const_shape {value = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_2]], %[[STARTS_2]], %[[THREE]] : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<3xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.conv2d %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x17x17x3xf32> +// CHECK: %[[STARTS_3:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 16]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_4]], %[[STARTS_3]], %[[SIZES_0]] : (tensor<5x256x256x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x256x256x16xf32> +// CHECK: %[[STARTS_4:.*]] = tosa.const_shape {value = dense<[3, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_5]], %[[STARTS_4]], %[[SIZES_1]] : (tensor<12x45x45x16xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<3x45x45x16xf32> +// CHECK: %[[VAL_12:.*]] = tosa.slice %[[VAL_2]], %[[THREE]], %[[THREE]] : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.conv2d %[[VAL_10]], %[[VAL_11]], %[[VAL_12]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x17x17x3xf32> +// CHECK: %[[STARTS_5:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 32]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_14:.*]] = tosa.slice %[[VAL_4]], %[[STARTS_5]], %[[SIZES_0]] : (tensor<5x256x256x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x256x256x16xf32> +// CHECK: %[[STARTS_6:.*]] = tosa.const_shape {value = dense<[6, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_5]], %[[STARTS_6]], %[[SIZES_1]] : (tensor<12x45x45x16xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<3x45x45x16xf32> +// CHECK: %[[SIX:.*]] = tosa.const_shape {value = dense<6> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_2]], %[[SIX]], %[[THREE]] : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<3xf32> +// CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x17x17x3xf32> +// CHECK: %[[STARTS_7:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 48]> : 
tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_18:.*]] = tosa.slice %[[VAL_4]], %[[STARTS_7]], %[[SIZES_0]] : (tensor<5x256x256x64xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<5x256x256x16xf32> +// CHECK: %[[STARTS_8:.*]] = tosa.const_shape {value = dense<[9, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_5]], %[[STARTS_8]], %[[SIZES_1]] : (tensor<12x45x45x16xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<3x45x45x16xf32> +// CHECK: %[[NINE:.*]] = tosa.const_shape {value = dense<9> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_2]], %[[NINE]], %[[THREE]] : (tensor<12xf32>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<3xf32> +// CHECK: %[[VAL_21:.*]] = tosa.conv2d %[[VAL_18]], %[[VAL_19]], %[[VAL_20]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x256x256x16xf32>, tensor<3x45x45x16xf32>, tensor<3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x17x17x3xf32> // CHECK: %[[VAL_22:.*]] = tosa.concat %[[VAL_9]], %[[VAL_13]], %[[VAL_17]], %[[VAL_21]] {axis = 3 : i32} : (tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>, tensor<5x17x17x3xf32>) -> tensor<5x17x17x12xf32> // CHECK: %[[VAL_23:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_22]], %[[VAL_23]] : (tensor<5x17x17x12xf32>, tensor<4xi32>) -> tensor<5x12x17x17xf32> @@ -115,8 +133,9 @@ func.func @test_onnx_conv2d_autopad(%arg0: tensor<5x3x125x256xf32>, %arg1 : tens // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_4:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_3]] : (tensor<5x3x125x256xf32>, tensor<4xi32>) -> tensor<5x125x256x3xf32> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_3]] : (tensor<2x3x64x64xf32>, tensor<4xi32>) -> tensor<2x64x64x3xf32> -// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x125x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>) -> tensor<5x125x256x2xf32> +// CHECK-DAG: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.conv2d %[[VAL_4]], %[[VAL_5]], %[[VAL_2]], %[[ZERO]], %[[ZERO]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x125x256x3xf32>, tensor<2x64x64x3xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<5x125x256x2xf32> // CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_6]], %[[VAL_7]] : (tensor<5x125x256x2xf32>, tensor<4xi32>) -> tensor<5x2x125x256xf32> // CHECK: return %[[VAL_8]] : tensor<5x2x125x256xf32> -} \ No newline at end of file +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 623ef3fe5f..50957fd565 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -5,7 +5,7 @@ func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { "func.return"(%0) : (tensor<10x10xf32>) -> () // CHECK-LABEL: func @test_relu // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> { -// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 3.40282347E+38 : f32, 
max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32>
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<10x10xf32>) -> tensor<10x10xf32>
 // CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32>
 // CHECK-NEXT: }
 }
@@ -17,7 +17,7 @@ func.func @test_relu_dynamic(%arg0 : tensor) -> tensor<*xf32> {
 "func.return"(%0) : (tensor<*xf32>) -> ()
 // CHECK-LABEL: func @test_relu_dynamic
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor {
-// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor) -> tensor
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor) -> tensor
 // CHECK-NEXT: return [[VAR_0_]] : tensor
 // CHECK-NEXT: }
 }
@@ -60,7 +60,8 @@ func.func @test_add_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>)
 "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
 // CHECK-LABEL: func.func @test_add_broadcast
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
-// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32>
+// CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
 // CHECK: [[VAR_1_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
 // CHECK: return [[VAR_1_]] : tensor<13x21x1xf32>
 }
@@ -83,7 +84,8 @@ func.func @test_sub_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>)
 "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
 // CHECK-LABEL: func.func @test_sub_broadcast
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
-// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32>
+// CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
 // CHECK: [[VAR_1_:%.+]] = tosa.sub [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
 // CHECK: return [[VAR_1_]] : tensor<13x21x1xf32>
 }
@@ -106,7 +108,8 @@ func.func @test_div_broadcast(%arg0: tensor<13x21x1xi32>, %arg1: tensor<1xi32>)
 "func.return"(%0) : (tensor<13x21x1xi32>) -> ()
 // CHECK-LABEL: func @test_div_broadcast
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi32>, [[PARAM_1_:%.+]]: tensor<1xi32>) -> tensor<13x21x1xi32> {
-// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1xi32>
+// CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xi32>, !tosa.shape<3>) -> tensor<1x1x1xi32>
 // CHECK-NEXT: [[VAR_1_:%.+]] = tosa.int_div [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi32>, tensor<1x1x1xi32>) -> tensor<13x21x1xi32>
 }
@@ -118,7 +121,8 @@ func.func @test_div_decomposed(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1
 // CHECK-LABEL: func @test_div_decomposed
 // CHECK-SAME:
([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> { // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32> -// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32> +// CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32> } // ----- @@ -129,6 +133,8 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens // CHECK-LABEL: func @test_div_decomposed_broadcast // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> { // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<1xf32>) -> tensor<1xf32> -// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1xf32> -// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> +// CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32> +// CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32> } diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir index 5ccbd32a28..37ef4caee4 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_linear.mlir @@ -21,10 +21,11 @@ func.func @gemm_to_fc_broadcast(%arg0: tensor<2x5xf32>, %arg1: tensor<4x5xf32>, // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x5xf32>, [[PARAM_1_:%.+]]: tensor<4x5xf32>, [[PARAM_2_:%.+]]: tensor<1xf32>) -> tensor<2x4xf32> { // CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> // CHECK-DAG: [[VAR_1_:%.+]] = tosa.fully_connected [[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]] : (tensor<2x5xf32>, tensor<4x5xf32>, tensor<4xf32>) -> tensor<2x4xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> -// CHECK: [[VAR_3_:%.+]] = tosa.add [[VAR_1_]], [[VAR_2_]] : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> -// CHECK: return [[VAR_3_]] : tensor<2x4xf32> -// CHECK: } +// CHECK-NEXT: [[VAR_2_:%.+]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-NEXT: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]], [[VAR_2_]] : (tensor<1xf32>, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK-NEXT: [[VAR_4_:%.+]] = tosa.add [[VAR_1_]], [[VAR_3_]] : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: return [[VAR_4_]] : tensor<2x4xf32> +// CHECK-NEXT: } } // ----- @@ -41,4 +42,4 @@ func.func @gemm_to_fc_opt(%arg0: tensor<1x5xf32>, %arg1: tensor<4x5xf32>) -> ten // CHECK: %[[VAL_4:.*]] = tosa.fully_connected %[[VAL_0]], %[[VAL_1]], %[[VAL_3]] : (tensor<1x5xf32>, tensor<4x5xf32>, tensor<4xf32>) -> tensor<1x4xf32> // 
CHECK: return %[[VAL_4]] : tensor<1x4xf32> // CHECK: } -} \ No newline at end of file +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir index 3654d493ea..81b08bd2b2 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Gemm_to_matmul.mlir @@ -5,13 +5,17 @@ func.func @test_gemm_to_matmul(%arg0: tensor<3x5xf32>, %arg1: tensor<5x4xf32>, % return %0 : tensor<3x4xf32> // CHECK-LABEL: func.func @test_gemm_to_matmul // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x5xf32>, [[PARAM_1_:%.+]]: tensor<5x4xf32>, [[PARAM_2_:%.+]]: tensor<3x4xf32>) -> tensor<3x4xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<3x5xf32>) -> tensor<1x3x5xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<5x4xf32>) -> tensor<1x5x4xf32> +// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 3, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<3x5xf32>, !tosa.shape<3>) -> tensor<1x3x5xf32> +// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 5, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<5x4xf32>, !tosa.shape<3>) -> tensor<1x5x4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x5xf32>, tensor<1x5x4xf32>) -> tensor<1x3x4xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<3x4xf32>) -> tensor<1x3x4xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x5xf32>, tensor<1x5x4xf32>) -> tensor<1x3x4xf32> +// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x4xf32>, !tosa.shape<3>) -> tensor<1x3x4xf32> // CHECK: [[VAR_4_:%.+]] = tosa.add [[VAR_2_]], [[VAR_3_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32> -// CHECK: [[VAR_5_:%.+]] = tosa.reshape [[VAR_4_]] {new_shape = array} : (tensor<1x3x4xf32>) -> tensor<3x4xf32> +// CHECK-DAG: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_5_:%.+]] = tosa.reshape [[VAR_4_]], [[SHAPE_3]] : (tensor<1x3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> // CHECK: return [[VAR_5_]] : tensor<3x4xf32> // CHECK: } } @@ -23,14 +27,19 @@ func.func @test_alpha(%arg0: tensor<3x6xf32>, %arg1: tensor<6x4xf32>, %arg2: ten return %0 : tensor<3x4xf32> // CHECK-LABEL: func.func @test_alpha // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x6xf32>, [[PARAM_1_:%.+]]: tensor<6x4xf32>, [[PARAM_2_:%.+]]: tensor<3x4xf32>) -> tensor<3x4xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<3x6xf32>) -> tensor<1x3x6xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x6x4xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.618000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> 
tensor<1x3x4xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<3x4xf32>) -> tensor<1x3x4xf32> +// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 3, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32> +// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 6, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<6x4xf32>, !tosa.shape<3>) -> tensor<1x6x4xf32> +// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.618000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_0_]], [[ZERO]] : (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32> +// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x4xf32>, !tosa.shape<3>) -> tensor<1x3x4xf32> // CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_4_]], [[VAR_5_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32> -// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]] {new_shape = array} : (tensor<1x3x4xf32>) -> tensor<3x4xf32> +// CHECK-DAG: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]], [[SHAPE_3]] : (tensor<1x3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> // CHECK: return [[VAR_7_]] : tensor<3x4xf32> // CHECK: } } @@ -42,15 +51,20 @@ func.func @test_beta(%arg0: tensor<3x6xf32>, %arg1: tensor<6x6xf32>, %arg2: tens return %0 : tensor<3x6xf32> // CHECK-LABEL: func.func @test_beta // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x6xf32>, [[PARAM_1_:%.+]]: tensor<6x6xf32>, [[PARAM_2_:%.+]]: tensor<3x6xf32>) -> tensor<3x6xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<3x6xf32>) -> tensor<1x3x6xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<6x6xf32>) -> tensor<1x6x6xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<3x6xf32>) -> tensor<1x3x6xf32> +// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 3, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32> +// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 6, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<6x6xf32>, !tosa.shape<3>) -> tensor<1x6x6xf32> +// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_3_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32> // CHECK-NOT: separator of 
consecutive DAGs -// CHECK-DAG: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x6xf32>) -> tensor<1x3x6xf32> +// CHECK: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_2_]], [[VAR_3_]], [[ZERO]] : (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32> +// CHECK: [[VAR_5_:%.+]] = tosa.matmul [[VAR_0_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x6xf32>) -> tensor<1x3x6xf32> // CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_5_]], [[VAR_4_]] : (tensor<1x3x6xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32> -// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]] {new_shape = array} : (tensor<1x3x6xf32>) -> tensor<3x6xf32> +// CHECK-DAG: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[3, 6]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]], [[SHAPE_3]] : (tensor<1x3x6xf32>, !tosa.shape<2>) -> tensor<3x6xf32> // CHECK: return [[VAR_7_]] : tensor<3x6xf32> // CHECK: } } @@ -62,14 +76,18 @@ func.func @test_transa(%arg0: tensor<6x3xf32>, %arg1: tensor<6x4xf32>, %arg2: te return %0 : tensor<3x4xf32> // CHECK-LABEL: func.func @test_transa // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<6x3xf32>, [[PARAM_1_:%.+]]: tensor<6x4xf32>, [[PARAM_2_:%.+]]: tensor<3x4xf32>) -> tensor<3x4xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<6x3xf32>) -> tensor<1x6x3xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x6x4xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 6, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<6x3xf32>, !tosa.shape<3>) -> tensor<1x6x3xf32> +// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 6, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<6x4xf32>, !tosa.shape<3>) -> tensor<1x6x4xf32> +// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> // CHECK: [[VAR_3_:%.+]] = tosa.transpose [[VAR_0_]], [[VAR_2_]] : (tensor<1x6x3xf32>, tensor<3xi32>) -> tensor<1x3x6xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<3x4xf32>) -> tensor<1x3x4xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.matmul [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32> +// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_5_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x4xf32>, !tosa.shape<3>) -> tensor<1x3x4xf32> // CHECK: [[VAR_6_:%.+]] = tosa.add [[VAR_4_]], [[VAR_5_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32> -// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]] {new_shape = array} : (tensor<1x3x4xf32>) -> tensor<3x4xf32> +// CHECK-DAG: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: 
[[VAR_7_:%.+]] = tosa.reshape [[VAR_6_]], [[SHAPE_3]] : (tensor<1x3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> // CHECK: return [[VAR_7_]] : tensor<3x4xf32> // CHECK: } } @@ -81,17 +99,22 @@ func.func @test_transb(%arg0: tensor<3x6xf32>, %arg1: tensor<4x6xf32>, %arg2: te return %0 : tensor<3x4xf32> // CHECK-LABEL: func.func @test_transb // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x6xf32>, [[PARAM_1_:%.+]]: tensor<4x6xf32>, [[PARAM_2_:%.+]]: tensor<3x4xf32>) -> tensor<3x4xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<3x6xf32>) -> tensor<1x3x6xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<4x6xf32>) -> tensor<1x4x6xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 3, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<3x6xf32>, !tosa.shape<3>) -> tensor<1x3x6xf32> +// CHECK: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 4, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<4x6xf32>, !tosa.shape<3>) -> tensor<1x4x6xf32> +// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = tosa.transpose [[VAR_1_]], [[VAR_2_]] : (tensor<1x4x6xf32>, tensor<3xi32>) -> tensor<1x6x4xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<1.184000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x3x6xf32>) -> tensor<1x3x6xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = tosa.matmul [[VAR_5_]], [[VAR_3_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32> -// CHECK-DAG: [[VAR_7_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<3x4xf32>) -> tensor<1x3x4xf32> +// CHECK: [[VAR_3_:%.+]] = tosa.transpose [[VAR_1_]], [[VAR_2_]] : (tensor<1x4x6xf32>, tensor<3xi32>) -> tensor<1x6x4xf32> +// CHECK: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<1.184000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_0_]], [[ZERO]] : (tensor<1x1x1xf32>, tensor<1x3x6xf32>, tensor<1xi8>) -> tensor<1x3x6xf32> +// CHECK: [[VAR_6_:%.+]] = tosa.matmul [[VAR_5_]], [[VAR_3_]] : (tensor<1x3x6xf32>, tensor<1x6x4xf32>) -> tensor<1x3x4xf32> +// CHECK: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 3, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_7_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<3x4xf32>, !tosa.shape<3>) -> tensor<1x3x4xf32> // CHECK: [[VAR_8_:%.+]] = tosa.add [[VAR_6_]], [[VAR_7_]] : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x3x4xf32> -// CHECK: [[VAR_9_:%.+]] = tosa.reshape [[VAR_8_]] {new_shape = array} : (tensor<1x3x4xf32>) -> tensor<3x4xf32> +// CHECK: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[3, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_9_:%.+]] = tosa.reshape [[VAR_8_]], [[SHAPE_3]] : (tensor<1x3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> // CHECK: return [[VAR_9_]] : tensor<3x4xf32> // CHECK: } } @@ -105,12 +128,15 @@ func.func @test_no_c(%arg0: tensor<1x5xf32>, 
%arg1: tensor<5x5xf32>) -> tensor<1 // CHECK-LABEL: func.func @test_no_c // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x5xf32>, [[PARAM_1_:%.+]]: tensor<5x5xf32>) -> tensor<1x5xf32> { // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<1x5xf32>) -> tensor<1x1x5xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<5x5xf32>) -> tensor<1x5x5xf32> -// CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<1x5xf32>, !tosa.shape<3>) -> tensor<1x1x5xf32> +// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 5, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<5x5xf32>, !tosa.shape<3>) -> tensor<1x5x5xf32> +// CHECK: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> // CHECK: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<1x5x5xf32>, tensor<3xi32>) -> tensor<1x5x5xf32> // CHECK: [[VAR_5_:%.+]] = tosa.matmul [[VAR_1_]], [[VAR_4_]] : (tensor<1x1x5xf32>, tensor<1x5x5xf32>) -> tensor<1x1x5xf32> -// CHECK: [[VAR_6_:%.+]] = tosa.reshape [[VAR_5_]] {new_shape = array} : (tensor<1x1x5xf32>) -> tensor<1x5xf32> +// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_6_:%.+]] = tosa.reshape [[VAR_5_]], [[SHAPE_2]] : (tensor<1x1x5xf32>, !tosa.shape<2>) -> tensor<1x5xf32> // CHECK: return [[VAR_6_]] : tensor<1x5xf32> // CHECK: } } @@ -124,12 +150,16 @@ func.func @test_no_c_no_trans(%arg0: tensor<1x5xf32>, %arg1: tensor<5x6xf32>) -> // CHECK-LABEL: func.func @test_no_c_no_trans // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x5xf32>, [[PARAM_1_:%.+]]: tensor<5x6xf32>) -> tensor<1x6xf32> { // CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<1x5xf32>) -> tensor<1x1x5xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<5x6xf32>) -> tensor<1x5x6xf32> +// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<1x5xf32>, !tosa.shape<3>) -> tensor<1x1x5xf32> +// CHECK-DAG: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 5, 6]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_2_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<5x6xf32>, !tosa.shape<3>) -> tensor<1x5x6xf32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<1.349000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x1x5xf32>) -> tensor<1x1x5xf32> +// CHECK-DAG: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_4_:%.+]] = tosa.mul [[VAR_3_]], [[VAR_1_]], [[ZERO]] : (tensor<1x1x1xf32>, tensor<1x1x5xf32>, tensor<1xi8>) -> tensor<1x1x5xf32> // CHECK: [[VAR_5_:%.+]] = tosa.matmul [[VAR_4_]], [[VAR_2_]] : (tensor<1x1x5xf32>, tensor<1x5x6xf32>) -> tensor<1x1x6xf32> -// CHECK: 
[[VAR_6_:%.+]] = tosa.reshape [[VAR_5_]] {new_shape = array} : (tensor<1x1x6xf32>) -> tensor<1x6xf32> +// CHECK-DAG: [[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 6]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_6_:%.+]] = tosa.reshape [[VAR_5_]], [[SHAPE_2]] : (tensor<1x1x6xf32>, !tosa.shape<2>) -> tensor<1x6xf32> // CHECK: return [[VAR_6_]] : tensor<1x6xf32> // CHECK: } } @@ -141,24 +171,30 @@ func.func @test_mixed(%arg0: tensor<11x5xf32>, %arg1: tensor<3x11xf32>, %arg2: t return %0 : tensor<5x3xf32> // CHECK-LABEL: func.func @test_mixed // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<11x5xf32>, [[PARAM_1_:%.+]]: tensor<3x11xf32>, [[PARAM_2_:%.+]]: tensor<5x3xf32>) -> tensor<5x3xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<11x5xf32>) -> tensor<1x11x5xf32> -// CHECK-DAG: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array} : (tensor<3x11xf32>) -> tensor<1x3x11xf32> -// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: [[SHAPE_0:%.+]] = tosa.const_shape {value = dense<[1, 11, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE_0]] : (tensor<11x5xf32>, !tosa.shape<3>) -> tensor<1x11x5xf32> +// CHECK: [[SHAPE_1:%.+]] = tosa.const_shape {value = dense<[1, 3, 11]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE_1]] : (tensor<3x11xf32>, !tosa.shape<3>) -> tensor<1x3x11xf32> +// CHECK: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_3_:%.+]] = tosa.transpose [[VAR_0_]], [[VAR_2_]] : (tensor<1x11x5xf32>, tensor<3xi32>) -> tensor<1x5x11xf32> -// CHECK-DAG: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: [[VAR_3_:%.+]] = tosa.transpose [[VAR_0_]], [[VAR_2_]] : (tensor<1x11x5xf32>, tensor<3xi32>) -> tensor<1x5x11xf32> +// CHECK: [[VAR_4_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_5_:%.+]] = tosa.transpose [[VAR_1_]], [[VAR_4_]] : (tensor<1x3x11xf32>, tensor<3xi32>) -> tensor<1x11x3xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = "tosa.const"() <{value = dense<1.402000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: [[VAR_5_:%.+]] = tosa.transpose [[VAR_1_]], [[VAR_4_]] : (tensor<1x3x11xf32>, tensor<3xi32>) -> tensor<1x11x3xf32> +// CHECK: [[VAR_6_:%.+]] = "tosa.const"() <{value = dense<1.402000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: [[ZERO_0:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = tosa.mul [[VAR_6_]], [[VAR_3_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x5x11xf32>) -> tensor<1x5x11xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.998000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: [[VAR_9_:%.+]] = tosa.reshape [[PARAM_2_]] {new_shape = array} : (tensor<5x3xf32>) -> tensor<1x5x3xf32> +// CHECK: [[VAR_7_:%.+]] = tosa.mul [[VAR_6_]], [[VAR_3_]], [[ZERO_0]] : (tensor<1x1x1xf32>, tensor<1x5x11xf32>, tensor<1xi8>) -> tensor<1x5x11xf32> +// CHECK: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.998000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK: 
[[SHAPE_2:%.+]] = tosa.const_shape {value = dense<[1, 5, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: [[VAR_9_:%.+]] = tosa.reshape [[PARAM_2_]], [[SHAPE_2]] : (tensor<5x3xf32>, !tosa.shape<3>) -> tensor<1x5x3xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = tosa.mul [[VAR_8_]], [[VAR_9_]] {shift = 0 : i8} : (tensor<1x1x1xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = tosa.matmul [[VAR_7_]], [[VAR_5_]] : (tensor<1x5x11xf32>, tensor<1x11x3xf32>) -> tensor<1x5x3xf32> +// CHECK: [[ZERO_1:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: [[VAR_10_:%.+]] = tosa.mul [[VAR_8_]], [[VAR_9_]], [[ZERO_1]] : (tensor<1x1x1xf32>, tensor<1x5x3xf32>, tensor<1xi8>) -> tensor<1x5x3xf32> +// CHECK: [[VAR_11_:%.+]] = tosa.matmul [[VAR_7_]], [[VAR_5_]] : (tensor<1x5x11xf32>, tensor<1x11x3xf32>) -> tensor<1x5x3xf32> // CHECK: [[VAR_12_:%.+]] = tosa.add [[VAR_11_]], [[VAR_10_]] : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32> -// CHECK: [[VAR_13_:%.+]] = tosa.reshape [[VAR_12_]] {new_shape = array} : (tensor<1x5x3xf32>) -> tensor<5x3xf32> +// CHECK: [[SHAPE_3:%.+]] = tosa.const_shape {value = dense<[5, 3]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: [[VAR_13_:%.+]] = tosa.reshape [[VAR_12_]], [[SHAPE_3]] : (tensor<1x5x3xf32>, !tosa.shape<2>) -> tensor<5x3xf32> // CHECK: return [[VAR_13_]] : tensor<5x3xf32> // CHECK: } -} \ No newline at end of file +} diff --git a/test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir b/test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir index 937638926e..6f67c65e36 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/ReduceMean.mlir @@ -9,8 +9,10 @@ return %1 : tensor<2x5x1x1xf32> // CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32> +// CHECK: %[[SHAPE:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]], %[[SHAPE]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]], %[[ZERO]] : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<2x5x1x1xf32> // CHECK: return %[[VAL_5]] : tensor<2x5x1x1xf32> } @@ -27,8 +29,10 @@ return %0 : tensor<1x1x1x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<1x1x9x11xf32>) -> tensor<1x1x1x11xf32> // CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<1x1x1x11xf32>) -> tensor<1x1x1x1xf32> // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.00101010106> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]] {shift = 0 : i8} : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> 
tensor<1x1x1x1xf32> +// CHECK: %[[SHAPE:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]], %[[SHAPE]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]], %[[ZERO]] : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<1x1x1x1xf32> // CHECK: return %[[VAL_7]] : tensor<1x1x1x1xf32> } @@ -42,10 +46,13 @@ return %1 : tensor<2x5xf32> // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5xf32> { // CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x5x1x1xf32>) -> tensor<2x5xf32> +// CHECK: %[[SHAPE_0:.*]] = tosa.const_shape {value = dense<[2, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]], %[[SHAPE_0]] : (tensor<2x5x1x1xf32>, !tosa.shape<2>) -> tensor<2x5xf32> // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x5xf32>, tensor<1x1xf32>) -> tensor<2x5xf32> +// CHECK: %[[SHAPE_1:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]], %[[SHAPE_1]] : (tensor, !tosa.shape<2>) -> tensor<1x1xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]], %[[ZERO]] : (tensor<2x5xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x5xf32> // CHECK: return %[[VAL_6]] : tensor<2x5xf32> } @@ -60,8 +67,10 @@ return %1 : tensor<2x5x1x1xf32> // CHECK: %[[VAL_1:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 2 : i32} : (tensor<2x5x9x11xf32>) -> tensor<2x5x1x11xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 3 : i32} : (tensor<2x5x1x11xf32>) -> tensor<2x5x1x1xf32> // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0101010101> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x1x1xf32> +// CHECK: %[[SHAPE:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]], %[[SHAPE]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]], %[[ZERO]] : (tensor<2x5x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<2x5x1x1xf32> // CHECK: return %[[VAL_5]] : tensor<2x5x1x1xf32> } @@ -75,7 +84,9 @@ return %0 : tensor<2x5x9x11xf32> // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> { // CHECK: %[[VAL_1:.*]] = tosa.identity %[[VAL_0]] : (tensor<2x5x9x11xf32>) -> tensor<2x5x9x11xf32> // CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> 
tensor -// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x5x9x11xf32>, tensor<1x1x1x1xf32>) -> tensor<2x5x9x11xf32> +// CHECK: %[[SHAPE:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]], %[[SHAPE]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], %[[ZERO]] : (tensor<2x5x9x11xf32>, tensor<1x1x1x1xf32>, tensor<1xi8>) -> tensor<2x5x9x11xf32> // CHECK: return %[[VAL_4]] : tensor<2x5x9x11xf32> } diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir index 2ecb8f5795..616442e935 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Softmax.mlir @@ -7,7 +7,8 @@ func.func @test_softmax_v13(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK: %[[VAL_1:.*]] = tosa.exp %[[VAL_0]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 2 : i32} : (tensor<13x21x3xf32>) -> tensor<13x21x1xf32> // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], %[[ZERO]] : (tensor<13x21x3xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> } // ----- @@ -19,7 +20,8 @@ func.func @test_softmax_v13_axis_one(%arg0: tensor<13x21x3xf32>) -> tensor<13x21 // CHECK: %[[VAL_1:.*]] = tosa.exp %[[VAL_0]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32> // CHECK: %[[VAL_3:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<13x1x3xf32>) -> tensor<13x1x3xf32> -// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]], %[[ZERO]] : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> } // ----- @@ -32,7 +34,8 @@ func.func @test_softmax_before_v13(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3 // CHECK: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 1 : i32} : (tensor<13x21x3xf32>) -> tensor<13x1x3xf32> // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 2 : i32} : (tensor<13x1x3xf32>) -> tensor<13x1x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<13x1x1xf32>) -> tensor<13x1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<13x1x1xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]], %[[ZERO]] : (tensor<13x21x3xf32>, tensor<13x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> } // ----- @@ -46,5 +49,6 @@ func.func 
@test_softmax_before_v13_axis_zero(%arg0: tensor<13x21x3xf32>) -> tens // CHECK: %[[VAL_3:.*]] = tosa.reduce_sum %[[VAL_2]] {axis = 1 : i32} : (tensor<1x21x3xf32>) -> tensor<1x1x3xf32> // CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32> // CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<1x1x1xf32>) -> tensor<1x1x1xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> -} \ No newline at end of file +// CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], %[[ZERO]] : (tensor<13x21x3xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> +} diff --git a/test/mlir/conversion/onnx_to_tosa/NN/AveragePool.mlir b/test/mlir/conversion/onnx_to_tosa/NN/AveragePool.mlir index 8d9d426494..e15e8653b5 100644 --- a/test/mlir/conversion/onnx_to_tosa/NN/AveragePool.mlir +++ b/test/mlir/conversion/onnx_to_tosa/NN/AveragePool.mlir @@ -141,10 +141,10 @@ func.func @test_averagepool_pad_with_count_include_pad(%arg0 : tensor<5x5x32x32x } // CHECK-LABEL: func.func @test_averagepool_pad_with_count_include_pad // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x32x32xf32>) -> tensor<5x5x32x32xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xi64>}> : () -> tensor<8xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xindex>} : () -> !tosa.shape<8> // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<5x5x32x32xf32>, tensor<8xi64>, tensor) -> tensor<5x5x34x34xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<5x5x32x32xf32>, !tosa.shape<8>, tensor) -> tensor<5x5x34x34xf32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> // CHECK: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<5x5x34x34xf32>, tensor<4xi32>) -> tensor<5x34x34x5xf32> // CHECK-DAG: [[VAR_5_:%.+]] = tosa.avg_pool2d [[VAR_4_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x34x34x5xf32>) -> tensor<5x32x32x5xf32> @@ -162,10 +162,10 @@ func.func @test_averagepool_pad_nonunif_with_count_include_pad(%arg0 : tensor<5x } // CHECK-LABEL: func.func @test_averagepool_pad_nonunif_with_count_include_pad // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x32x32xf32>) -> tensor<5x5x30x34xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 0, 0, 0, 0, 2, 1, 3]> : tensor<8xi64>}> : () -> tensor<8xi64> +// CHECK-DAG: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 2, 1, 3]> : tensor<8xindex>} : () -> !tosa.shape<8> // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<5x5x32x32xf32>, tensor<8xi64>, tensor) -> tensor<5x5x34x36xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<5x5x32x32xf32>, !tosa.shape<8>, tensor) -> tensor<5x5x34x36xf32> // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : 
tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<5x5x34x36xf32>, tensor<4xi32>) -> tensor<5x34x36x5xf32>
 // CHECK-DAG: [[VAR_5_:%.+]] = tosa.avg_pool2d [[VAR_4_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x34x36x5xf32>) -> tensor<5x30x34x5xf32>
@@ -183,10 +183,10 @@ func.func @test_averagepool_strides_nonunifpad_ceil_with_count_include_pad(%arg0
 }
 // CHECK-LABEL: func.func @test_averagepool_strides_nonunifpad_ceil_with_count_include_pad
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<5x5x30x32xf32>) -> tensor<5x5x16x17xf32> {
-// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 0, 0, 0, 1, 0, 2, 0]> : tensor<8xi64>}> : () -> tensor<8xi64>
+// CHECK-DAG: [[VAR_0_:%.+]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 1, 0, 2, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
 // CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor
 // CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<5x5x30x32xf32>, tensor<8xi64>, tensor) -> tensor<5x5x31x34xf32>
+// CHECK-DAG: [[VAR_2_:%.+]] = tosa.pad [[PARAM_0_]], [[VAR_0_]], [[VAR_1_]] : (tensor<5x5x30x32xf32>, !tosa.shape<8>, tensor) -> tensor<5x5x31x34xf32>
 // CHECK-DAG: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32>
 // CHECK: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<5x5x31x34xf32>, tensor<4xi32>) -> tensor<5x31x34x5xf32>
 // CHECK-DAG: [[VAR_5_:%.+]] = tosa.avg_pool2d [[VAR_4_]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<5x31x34x5xf32>) -> tensor<5x16x17x5xf32>
diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir
index 73a653be96..b2f105c031 100644
--- a/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir
+++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir
@@ -6,7 +6,8 @@ func.func @test_reshape(%arg0 : tensor<128x1024xf32>) -> tensor<1x128x16x64xf32>
 "func.return"(%1) : (tensor<1x128x16x64xf32>) -> ()
 // CHECK-LABEL: func @test_reshape
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<128x1024xf32>) -> tensor<1x128x16x64xf32> {
-// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<128x1024xf32>) -> tensor<1x128x16x64xf32>
+// CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<[1, 128, 16, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE]] : (tensor<128x1024xf32>, !tosa.shape<4>) -> tensor<1x128x16x64xf32>
 // CHECK-NEXT: return [[VAR_1_]] : tensor<1x128x16x64xf32>
 }
@@ -16,6 +17,7 @@ func.func @test_reshape_allowzero(%arg0 : tensor<12x128x1024xf32>) -> tensor<12x
 "func.return"(%1) : (tensor<12x128x16x64xf32>) -> ()
 // CHECK-LABEL: func @test_reshape
 // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x128x1024xf32>) -> tensor<12x128x16x64xf32> {
-// CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<12x128x1024xf32>) -> tensor<12x128x16x64xf32>
+// CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<[12, 128, 16, 64]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]], [[SHAPE]] : (tensor<12x128x1024xf32>, !tosa.shape<4>) -> tensor<12x128x16x64xf32>
 // CHECK-NEXT: return [[VAR_1_]] : tensor<12x128x16x64xf32>
 }
diff --git a/third_party/stablehlo b/third_party/stablehlo
index 459e481b77..459897561d 160000
--- a/third_party/stablehlo
+++ b/third_party/stablehlo
@@ -1 +1 @@
-Subproject commit 459e481b77f8537aae3f8e8e8ad9550721afe202
+Subproject commit 459897561d365ef97caba46984847f9184d472ec
diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh
index 804dff5fda..c7366d7b5f 100644
--- a/utils/clone-mlir.sh
+++ b/utils/clone-mlir.sh
@@ -1,3 +1,3 @@
 git clone -n https://github.com/llvm/llvm-project.git
 # Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
+cd llvm-project && git checkout 0e779ad4998ef65907502101c5b82ede05ddfa4e && cd ..