diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index be139fb3d8b..72672a95940 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> { } +def BroadcastOp : Enzyme_Op<"broadcast"> { + let description = [{ + Broadcast the operand by adding extra dimensions with sizes provided by the `shape` attribute to the front. + For scalar operands, ranked tensor is created. + + NOTE: Only works for scalar and *ranked* tensor operands for now. + }]; + + let arguments = (ins AnyType:$input, DenseI64ArrayAttr:$shape); + let results = (outs AnyRankedTensor:$output); + + let builders = [ + OpBuilder<(ins "Value":$input, "ArrayRef":$shape)> + ]; +} + #endif // ENZYME_OPS diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 3e318542730..7e48db2d583 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" @@ -191,3 +192,17 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } + +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input, + ArrayRef shape) { + auto shapeAttr = builder.getDenseI64ArrayAttr(shape); + auto resultTy = input.getType(); + for (auto s : llvm::reverse(shape)) { + resultTy = resultTy.cast().getShadowType(s); + } + build(builder, result, resultTy, input, shapeAttr); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp index 9b27503d79d..8d3650969d0 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp @@ -17,6 +17,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" @@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface( arith::ConstantOp::attachInterface(*context); }); } + +void mlir::enzyme::registerTensorDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index d2d6ddfe19b..7c72b97d093 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -45,8 +45,11 @@ class FloatTypeInterface } Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); - return self; + if (width > 1) { + return RankedTensorType::get({width}, self); + } else { + return self; + } } bool isMutable(Type self) const { return false; } @@ -106,7 +109,14 @@ class TensorTypeInterface } Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); + if (width != 1) { + auto tenType = self.cast(); + auto shape = tenType.getShape(); + SmallVector newShape; + newShape.push_back(width); + newShape.append(shape.begin(), shape.end()); + return RankedTensorType::get(newShape, tenType.getElementType()); + } return self; } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 355808cdbcc..f727dca2f87 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -74,7 +74,8 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, newVals.push_back(gutils->invertPointerM(op, builder)); } else { Type retTy = - arg.getType().cast().getShadowType(); + arg.getType().cast().getShadowType( + gutils->width); auto toret = retTy.cast().createNullValue( builder, op.getLoc()); newVals.push_back(toret); @@ -146,7 +147,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( if (auto iface = dyn_cast(operand.get().getType())) { if (!iface.isMutable()) { - Type retTy = iface.getShadowType(); + Type retTy = iface.getShadowType(gutils->width); auto toret = retTy.cast().createNullValue( builder, operand.get().getLoc()); newOperands.push_back(toret); @@ -346,7 +347,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( << result.getType() << "\n"; return failure(); } - newOpResultTypes.push_back(typeIface.getShadowType()); + newOpResultTypes.push_back(typeIface.getShadowType(gutils->width)); } SmallVector newOperands; @@ -432,4 +433,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces( enzyme::registerCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); enzyme::registerFuncDialectAutoDiffInterface(registry); + enzyme::registerTensorDialectAutoDiffInterface(registry); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index d6f28ccfc73..650f6c6326b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry); +void registerTensorDialectAutoDiffInterface(DialectRegistry ®istry); void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 69cfad436cf..5ec908f1268 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -245,9 +245,11 @@ FunctionOpInterface CloneFunctionWithReturns( mlir::Value val = blk.getArgument(i); mlir::Value dval; if (i == ArgActivity.size() - 1) - dval = blk.addArgument(val.getType(), val.getLoc()); + dval = blk.addArgument(getShadowType(val.getType(), width), + val.getLoc()); else - dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(), + dval = blk.insertArgument(blk.args_begin() + i + 1, + getShadowType(val.getType(), width), val.getLoc()); ptrInputs.map(oval, dval); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 1ec4212dc5a..32cb5b79614 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -108,7 +108,8 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, return invertedPointers.lookupOrNull(v); if (isConstantValue(v)) { - if (auto iface = v.getType().dyn_cast()) { + if (auto iface = + getShadowType(v.getType()).dyn_cast()) { OpBuilder::InsertionGuard guard(Builder2); if (auto op = v.getDefiningOp()) Builder2.setInsertionPoint(getNewFromOriginal(op)); diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 0445fc43064..99db4d80034 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms MLIRFuncDialect MLIRFuncTransforms MLIRGPUDialect + MLIRTensorDialect MLIRIR MLIRLLVMDialect MLIRMathDialect diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 58c43be236d..fb6df3e2208 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "Dialect/Dialect.h" @@ -80,6 +81,10 @@ namespace affine { class AffineDialect; } // end namespace affine +namespace tensor { +class TensorDialect; +} // end namespace tensor + namespace LLVM { class LLVMDialect; } // end namespace LLVM diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 6458e63b273..c5b4df76917 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -16,7 +16,8 @@ def DifferentiatePass : Pass<"enzyme"> { let dependentDialects = [ "arith::ArithDialect", "complex::ComplexDialect", - "cf::ControlFlowDialect" + "cf::ControlFlowDialect", + "tensor::TensorDialect", ]; let constructor = "mlir::enzyme::createDifferentiatePass()"; } diff --git a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp index 0e6bdf7b101..99e7243129b 100644 --- a/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp +++ b/enzyme/Enzyme/MLIR/enzymemlir-opt.cpp @@ -67,6 +67,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); diff --git a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir new file mode 100644 index 00000000000..f20989aa424 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %y : f64) -> f64 { + %c = arith.cmpf ult, %x, %y : f64 + cf.cond_br %c, ^blk2(%x : f64), ^blk2(%y : f64) + + ^blk2(%r : f64): + return %r : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>, %y : f64, %dy : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffesquare(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64 +// CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>) +// CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0 +// CHECK-NEXT: return %[[i2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_for.mlir b/enzyme/test/MLIR/ForwardMode/batched_for.mlir new file mode 100644 index 00000000000..95557cb0b6f --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_for.mlir @@ -0,0 +1,33 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (f64) { + %n = arith.addf %arg2, %x : f64 + scf.yield %n : f64 + } + return %r : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64> +// CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, tensor<2xf64>) { +// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : tensor<2xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, tensor<2xf64> +// CHECK-NEXT: } +// CHECK-NEXT: return %[[i0]]#1 : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_if.mlir b/enzyme/test/MLIR/ForwardMode/batched_if.mlir new file mode 100644 index 00000000000..33c9e1b9fe8 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_if.mlir @@ -0,0 +1,43 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { + %mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>, %c : i1) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>, i1) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: i1) -> tensor<2xf64> { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, tensor<2xf64>, f64) { +// CHECK-NEXT: %[[t4:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[t5:.+]] = arith.mulf %[[arg1]], %[[t4]] : tensor<2xf64> +// CHECK-NEXT: %[[t6:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[t7:.+]] = arith.mulf %[[arg1]], %[[t6]] : tensor<2xf64> +// CHECK-NEXT: %[[t8:.+]] = arith.addf %[[t5]], %[[t7]] : tensor<2xf64> +// CHECK-NEXT: %[[t9:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[t9]], %[[t8]], %[[cst2]] : f64, tensor<2xf64>, f64 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[e4:.+]] = arith.addf %[[arg1]], %[[arg1]] : tensor<2xf64> +// CHECK-NEXT: %[[e5:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[e5]], %[[e4]], %[[cst10]] : f64, tensor<2xf64>, f64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = "enzyme.broadcast"(%[[r0]]#2) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#1, %[[r1]] : tensor<2xf64> +// CHECK-NEXT: %[[r3:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir new file mode 100644 index 00000000000..f06f86d2a04 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64{ + %y = arith.mulf %x, %x : f64 + return %y : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : f64 -> tensor<2xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64> +// CHECK-NEXT: return %[[i2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir new file mode 100644 index 00000000000..11b75f634a6 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{ + %y = arith.mulf %x, %x : tensor<10xf64> + return %y : tensor<10xf64> + } + func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>) + return %r : tensor<2x10xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64> +// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64> +// CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 900c5c813cd..dccbc7b7923 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -275,8 +275,19 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, os << ord; } if (!vecValue && !startsWith(ord, "local")) { - if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) + if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) { os << ")"; + if (intrinsic == MLIRDerivatives) { + os << ";\n"; + os << "if (gutils->width != 1) {\n" + << " " << argName << "_" << (idx - 1) + << " = builder.create(\n" + << " op.getLoc(),\n" + << " " << argName << "_" << (idx - 1) << ",\n" + << " llvm::SmallVector({gutils->width}));\n" + << "}"; + } + } if (lookup && intrinsic != MLIRDerivatives) os << ", " << builder << ")";