Skip to content

Commit

Permalink
dialects: (llvm) Add a bunch of float methods
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonLydike committed Jan 23, 2025
1 parent e49877b commit 774aee4
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 41 deletions.
22 changes: 21 additions & 1 deletion tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: XDSL_ROUNDTRIP

%arg0, %arg1 = "test.op"() : () -> (i32, i32)
%arg0, %arg1, %f1 = "test.op"() : () -> (i32, i32, f32)

%add_both = llvm.add %arg0, %arg1 {"overflowFlags" = #llvm.overflow<nsw, nuw>} : i32
// CHECK: %add_both = llvm.add %arg0, %arg1 {overflowFlags = #llvm.overflow<nsw,nuw>} : i32
Expand Down Expand Up @@ -121,3 +121,23 @@

%icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32
// CHECK: %icmp_uge = llvm.icmp "uge" %arg0, %arg1 : i32

// float arith:

%fmul = llvm.fmul %f1, %f1 : f32
// CHECK: %fmul = llvm.fmul %f1, %f1 : f32

%fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath<fast>} : f32
// CHECK: %fmul_fast = llvm.fmul %f1, %f1 {fastmathFlags = #llvm.fastmath<fast>} : f32

%fdiv = llvm.fdiv %f1, %f1 : f32
// CHECK: %fdiv = llvm.fdiv %f1, %f1 : f32

%fadd = llvm.fadd %f1, %f1 : f32
// CHECK: %fadd = llvm.fadd %f1, %f1 : f32

%fsub = llvm.fsub %f1, %f1 : f32
// CHECK: %fsub = llvm.fsub %f1, %f1 : f32

%frem = llvm.frem %f1, %f1 : f32
// CHECK: %frem = llvm.frem %f1, %f1 : f32
8 changes: 8 additions & 0 deletions tests/filecheck/dialects/llvm/example.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,12 @@ builtin.module {

// CHECK: %val = "test.op"() : () -> i32
// CHECK-NEXT: %fval = llvm.bitcast %val : i32 to f32

%fval2 = llvm.sitofp %val : i32 to f32

// CHECK-NEXT: %fval2 = llvm.sitofp %val : i32 to f32

%fval3 = llvm.fpext %fval : f32 to f64

// CHECK-NEXT: %fval3 = llvm.fpext %fval : f32 to f64
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s --mlir-print-op-generic | xdsl-opt | filecheck %s

builtin.module {
%arg0, %arg1 = "test.op"() : () -> (i32, i32)
%arg0, %arg1, %f1 = "test.op"() : () -> (i32, i32, f32)

%add = llvm.add %arg0, %arg1 : i32
// CHECK: %{{.*}} = llvm.add %{{.*}}, %{{.*}} : i32
Expand Down Expand Up @@ -44,4 +44,24 @@ builtin.module {

%ashr = llvm.ashr %arg0, %arg1 : i32
// CHECK: %{{.*}} = llvm.ashr %{{.*}}, %{{.*}} : i32

// float arith:

%fmul = llvm.fmul %f1, %f1 : f32
// CHECK: %{{\d+}} = llvm.fmul %2, %2 : f32

%fmul_fast = llvm.fmul %f1, %f1 {test = true, fastmathFlags = #llvm.fastmath<fast>} : f32
// CHECK: %{{\d+}} = llvm.fmul %2, %2 {test = true, fastmathFlags = #llvm.fastmath<fast>} : f32

%fdiv = llvm.fdiv %f1, %f1 : f32
// CHECK: %{{\d+}} = llvm.fdiv %2, %2 : f32

%fadd = llvm.fadd %f1, %f1 : f32
// CHECK: %{{\d+}} = llvm.fadd %2, %2 : f32

%fsub = llvm.fsub %f1, %f1 : f32
// CHECK: %{{\d+}} = llvm.fsub %2, %2 : f32

%frem = llvm.frem %f1, %f1 : f32
// CHECK: %{{\d+}} = llvm.frem %2, %2 : f32
}
163 changes: 124 additions & 39 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
ContainerType,
DenseArrayBase,
DenseI64ArrayConstr,
Float16Type,
Float32Type,
Float64Type,
IndexType,
IntAttr,
IntegerAttr,
Expand All @@ -25,7 +28,7 @@
i32,
i64,
)
from xdsl.dialects.utils import FastMathAttrBase
from xdsl.dialects.utils import FastMathAttrBase, FastMathFlag
from xdsl.ir import (
Attribute,
BitEnumAttribute,
Expand All @@ -38,9 +41,11 @@
TypeAttribute,
)
from xdsl.irdl import (
AnyOf,
BaseAttr,
IRDLOperation,
ParameterDef,
ParsePropInAttrDict,
VarConstraint,
base,
irdl_attr_definition,
Expand All @@ -57,7 +62,7 @@
)
from xdsl.parser import AttrParser, Parser
from xdsl.printer import Printer
from xdsl.traits import IsTerminator, NoMemoryEffect, SymbolOpInterface
from xdsl.traits import IsTerminator, NoMemoryEffect, Pure, SymbolOpInterface
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr
Expand Down Expand Up @@ -1710,10 +1715,7 @@ class ZeroOp(IRDLOperation):
res = result_def(LLVMTypeConstr)


@irdl_op_definition
class BitcastOp(IRDLOperation):
name = "llvm.bitcast"

class GenericCastOp(IRDLOperation, ABC):
arg = operand_def(Attribute)
"""
LLVM-compatible non-aggregate type
Expand All @@ -1735,56 +1737,139 @@ def __init__(self, val: Operation | SSAValue, res_type: Attribute):
)


floatingPointLike = AnyOf([Float16Type, Float32Type, Float64Type])


class AbstractFloatArithOp(IRDLOperation, ABC):
T: ClassVar = VarConstraint("T", floatingPointLike)

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)

fastmathFlags = prop_def(FastMathAttr, default_value=FastMathAttr(None))

traits = traits_def(Pure())

assembly_format = "$lhs `,` $rhs attr-dict `:` type($lhs)"

irdl_options = [ParsePropInAttrDict()]

def __init__(
self,
lhs: SSAValue | Operation,
rhs: SSAValue | Operation,
fast_math: FastMathAttr | FastMathFlag | None = None,
attrs: dict[str, Attribute] | None = None,
):
if isinstance(fast_math, FastMathFlag | str | None):
fast_math = FastMathAttr(fast_math)

super().__init__(
operands=[lhs, rhs],
result_types=[SSAValue.get(lhs).type],
properties={"fastmathFlags": fast_math},
attributes=attrs,
)


@irdl_op_definition
class FAddOp(AbstractFloatArithOp):
name = "llvm.fadd"


@irdl_op_definition
class FMulOp(AbstractFloatArithOp):
name = "llvm.fmul"


@irdl_op_definition
class FDivOp(AbstractFloatArithOp):
name = "llvm.fdiv"


@irdl_op_definition
class FSubOp(AbstractFloatArithOp):
name = "llvm.fsub"


@irdl_op_definition
class FRemOp(AbstractFloatArithOp):
name = "llvm.frem"


@irdl_op_definition
class BitcastOp(GenericCastOp):
name = "llvm.bitcast"


@irdl_op_definition
class SIToFPOp(GenericCastOp):
name = "llvm.sitofp"


@irdl_op_definition
class FPExtOp(GenericCastOp):
name = "llvm.fpext"


LLVM = Dialect(
"llvm",
[
AShrOp,
AddOp,
AddressOfOp,
AllocaOp,
AndOp,
BitcastOp,
SubOp,
CallIntrinsicOp,
CallOp,
ConstantOp,
ExtractValueOp,
FAddOp,
FDivOp,
FMulOp,
FPExtOp,
FRemOp,
FSubOp,
FuncOp,
GEPOp,
GlobalOp,
ICmpOp,
InlineAsmOp,
InsertValueOp,
IntToPtrOp,
LShrOp,
LoadOp,
MulOp,
UDivOp,
NullOp,
OrOp,
ReturnOp,
SDivOp,
URemOp,
SExtOp,
SIToFPOp,
SRemOp,
AndOp,
OrOp,
XOrOp,
ShlOp,
LShrOp,
AShrOp,
StoreOp,
SubOp,
TruncOp,
ZExtOp,
SExtOp,
ICmpOp,
ExtractValueOp,
InsertValueOp,
InlineAsmOp,
UDivOp,
URemOp,
UndefOp,
AllocaOp,
GEPOp,
IntToPtrOp,
NullOp,
LoadOp,
StoreOp,
GlobalOp,
AddressOfOp,
FuncOp,
CallOp,
ReturnOp,
ConstantOp,
CallIntrinsicOp,
XOrOp,
ZExtOp,
ZeroOp,
],
[
LLVMStructType,
LLVMPointerType,
CallingConventionAttr,
FastMathAttr,
LLVMArrayType,
LLVMVoidType,
LLVMFunctionType,
LLVMPointerType,
LLVMStructType,
LLVMVoidType,
LinkageAttr,
CallingConventionAttr,
TailCallKindAttr,
FastMathAttr,
OverflowAttr,
TailCallKindAttr,
],
)

0 comments on commit 774aee4

Please sign in to comment.