diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 9846f9ae618..29f977c95a9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -121,6 +121,7 @@ def TypeOf : Operation { class ComplexInst : Inst; class ArithInst : Inst; +class LlvmInst : Inst; class MathInst : Inst; def AddF : ArithInst<"AddFOp">; @@ -133,6 +134,9 @@ def RemF : ArithInst<"RemFOp">; def CheckedMulF : ArithInst<"MulFOp">; def CheckedDivF : ArithInst<"DivFOp">; +def LlvmCheckedMulF : LlvmInst<"FMulOp">; +def LlvmExpF : LlvmInst<"ExpOp">; + def CosF : MathInst<"CosOp">; def SinF : MathInst<"SinOp">; def ExpF : MathInst<"ExpOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index b9e9ade7421..11b191d7fdd 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -15,6 +15,7 @@ #include "Interfaces/AutoDiffOpInterface.h" #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td index e77e88aea47..949eeb22e09 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMDerivatives.td @@ -26,3 +26,9 @@ def : ReadOnlyIdentityOp<"LLVM", "PtrToIntOp", [0]>; def : ReadOnlyIdentityOp<"LLVM", "IntToPtrOp", [0]>; def : AllocationOp<"LLVM", "AllocaOp">; + +def : MLIRDerivative<"LLVM", "ExpOp", (Op $x), + [ + (LlvmCheckedMulF (DiffeRet), (LlvmExpF $x)) + ] + >; diff --git a/enzyme/test/MLIR/ForwardMode/llvm.mlir b/enzyme/test/MLIR/ForwardMode/llvm.mlir index df7e572ce4c..f6cd8c3a9c2 100644 --- a/enzyme/test/MLIR/ForwardMode/llvm.mlir +++ b/enzyme/test/MLIR/ForwardMode/llvm.mlir @@ -13,6 +13,16 @@ module { %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f64, f64) -> (f64) return %r : f64 } + + func.func @exp(%x: f32) -> f32 { + %0 = llvm.intr.exp(%x) : (f32) -> f32 + return %0 : f32 + } + + func.func @dexp(%x: f32, %dx: f32) -> f32 { + %r = enzyme.fwddiff @exp(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f32, f32) -> f32 + return %r : f32 + } } // CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { @@ -29,3 +39,10 @@ module { // CHECK-NEXT: %[[i7:.+]] = llvm.load %[[i1]] : !llvm.ptr -> f64 // CHECK-NEXT: return %[[i6]] : f64 // CHECK-NEXT: } + +// CHECK: func.func private @fwddiffeexp(%[[arg0:.+]]: f32, %[[arg1:.+]]: f32) -> f32 { +// CHECK-NEXT: %[[der:.+]] = llvm.intr.exp(%[[arg0]]) : (f32) -> f32 +// CHECK-NEXT: %[[res:.+]] = llvm.fmul %[[arg1]], %[[der]] : f32 +// CHECK-NEXT: %[[exp:.+]] = llvm.intr.exp(%[[arg0]]) : (f32) -> f32 +// CHECK-NEXT: return %[[res]] : f32 +// CHECK-NEXT: }