Skip to content

Commit

Permalink
update float types
Browse files Browse the repository at this point in the history
  • Loading branch information
brnorris03 committed Feb 21, 2025
1 parent 4a78a79 commit f4a3e00
Show file tree
Hide file tree
Showing 13 changed files with 42 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
auto int16Ty = IntegerType::get(context, 16);
auto int32Ty = IntegerType::get(context, 32);
auto int64Ty = IntegerType::get(context, 64);
auto float32Ty = FloatType::getF32(context);
auto float32Ty = Float32Type::get(context);

// Declare API type as an enum value, its string name and an LLVM Type
// specifying its signature.
Expand Down Expand Up @@ -570,7 +570,7 @@ Type getZTensorStructTy(MLIRContext *context) {
Type llvmI64Ty = IntegerType::get(context, 64);
Type llvmI1Ty = IntegerType::get(context, 1);
Type llvmI8Ty = IntegerType::get(context, 8);
Type llvmF32Ty = FloatType::getF32(context);
Type llvmF32Ty = Float32Type::get(context);
Type llvmArray3I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 3);
Type llvmArray20I8Ty = LLVM::LLVMArrayType::get(llvmI8Ty, 20);
Type llvmI8PtrTy = krnl::getPointerType(context, llvmI8Ty);
Expand Down Expand Up @@ -662,7 +662,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
scaleTy.isF32() && "Wrong type for zTensor's rec_scale. Must be float");
create.llvm.store(recScale, recScalePtr);
} else {
Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, recScalePtr);
}

Expand All @@ -675,7 +675,7 @@ void fillInZTensor(PatternRewriter &rewriter, Location loc, ModuleOp module,
offsetTy.isF32() && "Wrong type for zTensor's offset. Must be float");
create.llvm.store(offset, offsetPtr);
} else {
Value zero = create.llvm.constant(FloatType::getF32(context), (double)0.);
Value zero = create.llvm.constant(Float32Type::get(context), (double)0.);
create.llvm.store(zero, offsetPtr);
}

Expand Down
4 changes: 2 additions & 2 deletions src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ class KrnlRandomNormalOpLowering : public ConversionPattern {
// or
// (memref<3x4x5xf64>, index, f64, f64, f64)
Type llvmVoidTy = LLVM::LLVMVoidType::get(context);
Type llvmOptionsTy = FloatType::getF32(context);
Type llvmOptionsTy = Float32Type::get(context);
Type llvmOutputTy = getPointerType(context, llvmOptionsTy);
if (inType.isF64()) {
llvmOptionsTy = FloatType::getF64(context);
llvmOptionsTy = Float64Type::get(context);
llvmOutputTy = getPointerType(context, llvmOptionsTy);
}
Type llvmI64Ty = IntegerType::get(context, 64);
Expand Down
13 changes: 6 additions & 7 deletions src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,19 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
Type outType = op->getResultTypes().front();
Type llvmInType, llvmOutType;
if (inType.isF16())
llvmInType = FloatType::getF16(context);
llvmInType = Float16Type::get(context);
else if (inType.isF32())
llvmInType = FloatType::getF32(context);
llvmInType = Float32Type::get(context);
else if (inType.isF64())
llvmInType = FloatType::getF64(context);
llvmInType = Float64Type::get(context);
else if (inType.isBF16())
llvmInType = FloatType::getBF16(context);
    llvmInType = BFloat16Type::get(context);
if (outType.isInteger(1))
llvmOutType = IntegerType::get(context, 1);
else if (outType.isF32())
llvmOutType = FloatType::getF32(context);
llvmOutType = Float32Type::get(context);
else if (outType.isF64())
llvmOutType = FloatType::getF64(context);
llvmOutType = Float64Type::get(context);

// Insert and/or get reference to elementary math function declaration.
assert(
Expand Down Expand Up @@ -214,7 +214,6 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
return SymbolRefAttr::get(context, mathFuncName);

// Create function declaration.
// auto llvmF32Ty = FloatType::get(context);
auto llvmFnType =
LLVM::LLVMFunctionType::get(llvmOutType, ArrayRef<Type>({llvmInType}));

Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/ONNXToKrnl/Math/LRN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct ONNXLRNOpLowering : public OpConversionPattern<ONNXLRNOp> {
float alphaLit = adaptor.getAlpha().convertToFloat();
float betaLit = adaptor.getBeta().convertToFloat();
int sizeLit = adaptor.getSize();
auto f32Type = FloatType::getF32(rewriter.getContext());
auto f32Type = Float32Type::get(rewriter.getContext());
Value biasValue = create.math.constant(f32Type, biasLit);
Value alphaDivSizeValue =
create.math.constant(f32Type, alphaLit / static_cast<float>(sizeLit));
Expand Down
10 changes: 5 additions & 5 deletions src/Dialect/ONNX/ElementsAttr/BType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ Type mlirTypeOfBType(BType btype, MLIRContext *ctx) {
case BType::FLOAT : return b.getF32Type();
case BType::FLOAT16 : return b.getF16Type();
case BType::BFLOAT16 : return b.getBF16Type();
case BType::FLOAT8E4M3FN : return b.getFloat8E4M3FNType();
case BType::FLOAT8E4M3FNUZ : return b.getFloat8E4M3FNUZType();
case BType::FLOAT8E5M2 : return b.getFloat8E5M2Type();
case BType::FLOAT8E5M2FNUZ : return b.getFloat8E5M2FNUZType();
case BType::FLOAT8E4M3FN : return b.getType<Float8E4M3FNType>();
case BType::FLOAT8E4M3FNUZ : return b.getType<Float8E4M3FNUZType>();
case BType::FLOAT8E5M2 : return b.getType<Float8E5M2Type>();
case BType::FLOAT8E5M2FNUZ : return b.getType<Float8E5M2FNUZType>();
default: llvm_unreachable("unsupported data type");
}
// clang-format on
Expand Down Expand Up @@ -104,4 +104,4 @@ BType wideBTypeOfBType(BType d) {
[](auto btype) { return toBType<typename BTypeTrait<btype>::widetype>; });
}

} // namespace onnx_mlir
} // namespace onnx_mlir
2 changes: 1 addition & 1 deletion src/Dialect/ONNX/ONNXOps/ML/OneHotEncoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ LogicalResult ONNXOneHotEncoderOp::inferShapes(
return success();

ONNXOneHotEncoderOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateType(FloatType::getF32(getContext()));
return shapeHelper.computeShapeAndUpdateType(Float32Type::get(getContext()));
return success();
}

Expand Down
2 changes: 1 addition & 1 deletion src/Dialect/ONNX/ONNXOps/Math/ElementwiseUnary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ LogicalResult ONNXScalerOp::inferShapes(
ONNXUnaryOpShapeHelper shapeHelper(getOperation(), {});
RankedTensorType xType = mlir::dyn_cast<RankedTensorType>(getX().getType());
return shapeHelper.computeShapeAndUpdateType(
FloatType::getF32(getContext()), xType.getEncoding());
Float32Type::get(getContext()), xType.getEncoding());
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 7 additions & 7 deletions src/Dialect/ONNX/ONNXOps/Math/RandomNormal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
Type elementType;
if (auto attr = getDtypeAttr()) {
if (getDtype() == 0) {
elementType = FloatType::getF16(getContext());
elementType = Float16Type::get(getContext());
} else if (getDtype() == 1) {
elementType = FloatType::getF32(getContext());
elementType = Float32Type::get(getContext());
} else if (getDtype() == 2) {
elementType = FloatType::getF64(getContext());
elementType = Float64Type::get(getContext());
} else {
llvm_unreachable("dtype not supported for RandomNormal");
}
} else {
elementType = FloatType::getF32(getContext());
elementType = Float32Type::get(getContext());
}
return {UnrankedTensorType::get(elementType)};
}
Expand All @@ -68,11 +68,11 @@ std::vector<Type> ONNXRandomNormalOp::resultTypeInference() {
LogicalResult ONNXRandomNormalOp::inferShapes(
std::function<void(Region &)> doShapeInference) {
auto elementTypeID = getDtype();
Type elementType = FloatType::getF32(getContext());
Type elementType = Float32Type::get(getContext());
if (elementTypeID == 0)
elementType = FloatType::getF16(getContext());
elementType = Float16Type::get(getContext());
else if (elementTypeID == 2)
elementType = FloatType::getF64(getContext());
elementType = Float64Type::get(getContext());

ONNXRandomNormalOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateType(elementType);
Expand Down
14 changes: 6 additions & 8 deletions src/Dialect/ONNX/ONNXOps/Math/RandomNormalLike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ LogicalResult ONNXRandomNormalLikeOp::verify() {
if (elementTypeID < 0 || elementTypeID > 2) {
return emitOpError("dtype not 0, 1 or 2.");
}
if (elementTypeID == 0 && outputType != FloatType::getF16(getContext()))
if (elementTypeID == 0 && outputType != Float16Type::get(getContext()))
return emitOpError("output tensor does match 0 dtype.");
else if (elementTypeID == 1 &&
outputType != FloatType::getF32(getContext()))
else if (elementTypeID == 1 && outputType != Float32Type::get(getContext()))
return emitOpError("output tensor does match 1 dtype.");
else if (elementTypeID == 2 &&
outputType != FloatType::getF64(getContext()))
else if (elementTypeID == 2 && outputType != Float64Type::get(getContext()))
return emitOpError("output tensor does match 2 dtype.");
} else if (inputType != outputType) {
return emitOpError("output and input element types do not match.");
Expand All @@ -75,11 +73,11 @@ LogicalResult ONNXRandomNormalLikeOp::inferShapes(
} else {
int64_t elementTypeID = elementTypeIDDType.value();
if (elementTypeID == 0)
elementType = FloatType::getF16(getContext());
elementType = Float16Type::get(getContext());
else if (elementTypeID == 1)
elementType = FloatType::getF32(getContext());
elementType = Float32Type::get(getContext());
else if (elementTypeID == 2)
elementType = FloatType::getF64(getContext());
elementType = Float64Type::get(getContext());
else
return emitError("dtype attribute is invalid (use: 0, 1 or 2)");
}
Expand Down
8 changes: 4 additions & 4 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,13 +608,13 @@ Type convertONNXTypeToMLIRType(
Builder &builder, onnx::TensorProto_DataType onnxType) {
switch (onnxType) {
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FN:
return builder.getFloat8E4M3FNType();
return builder.getType<Float8E4M3FNType>();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E4M3FNUZ:
return builder.getFloat8E4M3FNUZType();
return builder.getType<Float8E4M3FNUZType>();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2:
return builder.getFloat8E5M2Type();
return builder.getType<Float8E5M2Type>();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
return builder.getFloat8E5M2FNUZType();
return builder.getType<Float8E5M2FNUZType>();
case onnx::TensorProto_DataType::TensorProto_DataType_BFLOAT16:
return builder.getBF16Type();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes(

IntegerType ui8Type =
IntegerType::get(getContext(), 8, IntegerType::Unsigned);
FloatType f32Type = FloatType::getF32(getContext());
FloatType f32Type = Float32Type::get(getContext());

ONNXDynamicQuantizeLinearOpShapeHelper shapeHelper(getOperation(), {});
return shapeHelper.computeShapeAndUpdateTypes(
Expand Down
4 changes: 2 additions & 2 deletions src/Dialect/ONNX/ONNXOps/Tensor/Constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ std::vector<Type> ONNXConstantOp::resultTypeInference() {
} else if (auto attr = getSparseValueAttr()) {
type = mlir::cast<ElementsAttr>(attr).getShapedType();
} else if (auto attr = getValueFloatAttr()) {
type = RankedTensorType::get({}, FloatType::getF32(getContext()));
type = RankedTensorType::get({}, Float32Type::get(getContext()));
} else if (auto attr = getValueFloatsAttr()) {
int64_t size = attr.size();
type = RankedTensorType::get({size}, FloatType::getF32(getContext()));
type = RankedTensorType::get({size}, Float32Type::get(getContext()));
} else if (auto attr = getValueIntAttr()) {
type = RankedTensorType::get({}, IntegerType::get(getContext(), 64));
} else if (auto attr = getValueIntsAttr()) {
Expand Down
4 changes: 2 additions & 2 deletions src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ std::vector<Type> ONNXConstantOfShapeOp::resultTypeInference() {
if (auto attr = getValueAttr()) {
elementType = mlir::cast<ElementsAttr>(attr).getElementType();
} else {
elementType = FloatType::getF32(getContext());
elementType = Float32Type::get(getContext());
}
return {UnrankedTensorType::get(elementType)};
}
Expand All @@ -125,7 +125,7 @@ LogicalResult ONNXConstantOfShapeOp::inferShapes(
} else {
// If 'value' attribute is not specified, it defaults to a tensor of
// value 0 and datatype float32.
elementType = FloatType::getF32(getContext());
elementType = Float32Type::get(getContext());

llvm::SmallVector<int64_t, 2> dims(1, 1);
auto tensorType = RankedTensorType::get(dims, elementType);
Expand Down

0 comments on commit f4a3e00

Please sign in to comment.