diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp
index 247dc5fe0a5..44b4329c136 100644
--- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp
@@ -25,6 +25,30 @@ using namespace mlir;
 using namespace mlir::enzyme;
 
 namespace {
+
+/// Returns the type that holds `width` values of type `type`.
+///
+/// When `width` is neither greater than 1 nor dynamic, `type` is returned
+/// unchanged. Otherwise a leading dimension of size `width` is added: a ranked
+/// tensor gets it prepended to its shape (a tensor is not a valid tensor
+/// element type), and any other type becomes the element type of a rank-1
+/// tensor.
+static mlir::Type batchType(mlir::Type type, int64_t width) {
+  if (width <= 1 && !ShapedType::isDynamic(width))
+    return type;
+
+  if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
+    SmallVector<int64_t> shape;
+    shape.reserve(tensorTy.getRank() + 1);
+    shape.push_back(width);
+    shape.append(tensorTy.getShape().begin(), tensorTy.getShape().end());
+    return RankedTensorType::get(shape, tensorTy.getElementType(),
+                                 tensorTy.getEncoding());
+  }
+
+  return RankedTensorType::get({width}, type);
+}
+
 class FloatTypeInterface
     : public AutoDiffTypeInterface::ExternalModel<FloatTypeInterface,
                                                   FloatType> {
@@ -44,12 +68,8 @@ class FloatTypeInterface
     return a;
   }
 
-  Type getShadowType(Type self, unsigned width) const {
-    if (width > 1) {
-      return RankedTensorType::get({width}, self);
-    } else {
-      return self;
-    }
+  Type getShadowType(Type self, int64_t width) const {
+    return batchType(self, width);
   }
 
   bool isMutable(Type self) const { return false; }
@@ -108,9 +128,8 @@ class TensorTypeInterface
     return added;
   }
 
-  Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+  Type getShadowType(Type self, int64_t width) const {
+    return batchType(self, width);
   }
 
   bool isMutable(Type self) const { return false; }
@@ -141,9 +160,8 @@ class IntegerTypeInterface
     return a;
   }
 
-  Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+  Type getShadowType(Type self, int64_t width) const {
+    return batchType(self, width);
   }
 
   bool isMutable(Type self) const { return false; }
@@ -175,9 +193,8 @@ class ComplexTypeInterface
     return builder.create<complex::ConjOp>(loc, a)->getResult(0);
   }
 
-  Type getShadowType(Type self, unsigned width) const {
-    assert(width == 1 && "unsupported width != 1");
-    return self;
+  Type getShadowType(Type self, int64_t width) const {
+    return batchType(self, width);
   }
 
   bool isMutable(Type self) const { return false; }
diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
index 9623aef4fd0..1bb32c817d3 100644
--- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
@@ -36,7 +36,7 @@ namespace {
 #include "Implementations/SCFDerivatives.inc"
 
 // TODO: support non constant number of iteration by using unknown dimensions
-static std::optional<int64_t> getNumberOfIterations(scf::ForOp forOp) {
+static std::optional<int64_t> getConstantNumberOfIterations(scf::ForOp forOp) {
   auto lb = forOp.getLowerBound();
   auto ub = forOp.getUpperBound();
   auto step = forOp.getStep();
@@ -55,6 +55,19 @@ static std::optional<int64_t> getNumberOfIterations(scf::ForOp forOp) {
   return (ubI - lbI) / stepI;
 }
 
+/// Emits, at the builder's insertion point, the run-time number of iterations
+/// of `forOp`, computed as `(ub - lb) / step`.
+// NOTE(review): like getConstantNumberOfIterations this truncates, so it
+// under-counts when `ub - lb` is not a multiple of `step` -- confirm such
+// loops cannot reach this path.
+static Value getNumberOfIterations(OpBuilder &builder, scf::ForOp forOp) {
+  Value lb = forOp.getLowerBound(), ub = forOp.getUpperBound(),
+        step = forOp.getStep();
+  Value diff = builder.create<arith::SubIOp>(forOp->getLoc(), ub, lb);
+  Value nSteps = builder.create<arith::DivUIOp>(forOp->getLoc(), diff, step);
+  return nSteps;
+}
+
 struct ForOpEnzymeOpsRemover
     : public EnzymeOpsRemoverOpInterface::ExternalModel<ForOpEnzymeOpsRemover,
                                                         scf::ForOp> {
@@ -132,7 +145,7 @@ struct ForOpEnzymeOpsRemover
       }
     }
 
-    auto numIters = getNumberOfIterations(forOp);
+    auto numIters = getConstantNumberOfIterations(forOp);
     Value inductionVariable; // [0, N[ counter
 
     if (matchPattern(forOp.getLowerBound(), m_Zero()) &&
@@ -186,16 +199,25 @@ struct ForOpEnzymeOpsRemover
       }
 
       auto newType =
-          info.batchType(numIters.value_or(mlir::ShapedType::kDynamic));
-      ValueRange operands =
-          numIters.has_value()
-              ? ValueRange{}
-              : ValueRange{builder
-                               .create<arith::ConstantOp>(
-                                   forOp->getLoc(), builder.getIndexAttr(10))
-                               .getResult()};
-      auto initValue = builder.create<tensor::EmptyOp>(info.initOp->getLoc(),
-                                                       newType, operands);
+          info.cachedType()
+              .cast<AutoDiffTypeInterface>()
+              .getShadowType(numIters.value_or(mlir::ShapedType::kDynamic))
+              .cast<ShapedType>();
+
+      SmallVector<Value> dynamicDims;
+
+      for (auto it : llvm::enumerate(newType.getShape())) {
+        if (ShapedType::isDynamic(it.value())) {
+          if (it.index() == 0)
+            dynamicDims.push_back(getNumberOfIterations(builder, forOp));
+          else
+            return failure(); // TODO: find dynamic dims within the body.
+        }
+      }
+
+      Value initValue = builder.create<tensor::EmptyOp>(info.initOp->getLoc(),
+                                                        newType, dynamicDims);
+
       // cast<AutoDiffTypeInterface>(newType).createNullValue(
       // builder, info.initOp->getLoc());
 
@@ -241,9 +263,21 @@ struct ForOpEnzymeOpsRemover
         builder.setInsertionPoint(otherForOp);
         SmallVector<Value> operands(otherForOp.getInitArgs().begin(),
                                     otherForOp.getInitArgs().end());
-        operands.push_back(builder.create<arith::ConstantOp>(
-            otherForOp->getLoc(),
-            builder.getIndexAttr(numIters.value_or(1) - 1)));
+        // The constant case seeds this operand with N - 1; apply the same offset
+        // to the run-time count so that both paths agree.
+        Value lastCacheIndex;
+        if (numIters.has_value()) {
+          lastCacheIndex = builder.create<arith::ConstantOp>(
+              otherForOp->getLoc(),
+              builder.getIndexAttr(numIters.value() - 1));
+        } else {
+          Value one =
+              builder.create<arith::ConstantIndexOp>(otherForOp->getLoc(), 1);
+          lastCacheIndex = builder.create<arith::SubIOp>(
+              otherForOp->getLoc(), getNumberOfIterations(builder, forOp),
+              one);
+        }
+        operands.push_back(lastCacheIndex);
 
         Block *otherBody = otherForOp.getBody();
         Value otherInductionVariable =
@@ -285,7 +319,9 @@ struct ForOpEnzymeOpsRemover
 
         Value cache = info.initOp.getResult();
 
-        auto newType = info.batchType(numIters.value());
+        auto newType =
+            info.cachedType().cast<AutoDiffTypeInterface>().getShadowType(
+                numIters.value());
         enzyme::InitOp newInit = ({
           OpBuilder::InsertionGuard guard(builder);
           builder.setInsertionPoint(info.initOp);
diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td
index 2e9f1697af4..ced1e68700b 100644
--- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td
+++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td
@@ -57,7 +57,7 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> {
       }],
       /*retTy=*/"::mlir::Type",
       /*methodName=*/"getShadowType",
-      /*args=*/(ins "unsigned":$width)
+      /*args=*/(ins "int64_t":$width)
     >,
     InterfaceMethod<
       /*desc=*/[{
diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
index a624744f134..1d4b2a04622 100644
--- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
@@ -9,23 +9,6 @@
 #include "RemovalUtils.h"
 #include "Interfaces/AutoDiffOpInterface.h"
 
-mlir::Type mlir::enzyme::CacheInfo::batchType() {
-  return mlir::enzyme::CacheInfo::batchType(mlir::ShapedType::kDynamic);
-}
-
-mlir::Type mlir::enzyme::CacheInfo::batchType(int64_t dim) {
-  auto T = pushedValue().getType();
-
-  if (auto TT = dyn_cast<mlir::TensorType>(T)) {
-    SmallVector<int64_t> shape;
-    shape.push_back(dim);
-    shape.append(TT.getShape().begin(), TT.getShape().end());
-    return TT.clone(shape);
-  }
-
-  return mlir::RankedTensorType::get({dim}, T);
-}
-
 mlir::LogicalResult mlir::enzyme::removeOpsWithinBlock(mlir::Block *block) {
   bool valid = true;
 
diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h
index 83b9cc45f9f..d360de05af5 100644
--- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h
+++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.h
@@ -35,9 +35,10 @@ struct CacheInfo {
   }
 
   Value pushedValue() { return pushOp.getValue(); }
-
-  Type batchType(int64_t dim);
-  Type batchType(); // unknown size
+  // Type wrapped by the type of `initOp`'s result.
+  Type cachedType() {
+    return initOp.getResult().getType().cast<enzyme::CacheType>().getType();
+  }
 };
 
 LogicalResult removeOpsWithinBlock(Block *block);