diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index 35b814c6038..c38b990ceb6 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -27,17 +27,18 @@ using namespace mlir::enzyme; namespace { static mlir::Type batchType(mlir::Type type, int64_t width) { - if (width > 1 || ShapedType::isDynamic(width)) { - if (auto TT = dyn_cast(type)) { - SmallVector shape; - shape.reserve(TT.getShape().size() + 1); - shape.push_back(width); - shape.append(TT.getShape().begin(), TT.getShape().end()); - return TT.clone(shape); - } - return RankedTensorType::get({width}, type); + if (width == 1) + return type; + + if (auto TT = dyn_cast(type)) { + SmallVector shape; + shape.reserve(TT.getShape().size() + 1); + shape.push_back(width); + shape.append(TT.getShape().begin(), TT.getShape().end()); + return TT.clone(shape); } - return type; + + return RankedTensorType::get({width}, type); } class FloatTypeInterface