diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 54db540c3..a8a5ac47e 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -10,6 +10,8 @@
 // ops.
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -1449,6 +1451,147 @@ struct ShiftRightLogicalSimplify final
   }
 };
 
+// Removes dead results of a stablehlo.while: a result is dead when it has no
+// uses and its iteration argument does not feed the loop condition or any
+// other (live) yielded value through impure or escaping computation.
+// NOTE(review): template/type arguments in this hunk were reconstructed from
+// usage after `<...>` tokens were stripped from the patch text — confirm
+// against the upstream commit before applying.
+struct WhileDeadResults final
+    : OpRewritePattern<mlir::stablehlo::WhileOp> {
+  using OpRewritePattern<mlir::stablehlo::WhileOp>::OpRewritePattern;
+
+  // Returns true if `result` (and its matching block arguments) can be
+  // removed without changing observable behavior.
+  bool isLoopResultDead(OpResult result) const {
+    // Not dead if the result is in use.
+    if (!result.use_empty())
+      return false;
+
+    // Or if the corresponding argument is being used in computing the
+    // condition.
+    auto whileOp = cast<mlir::stablehlo::WhileOp>(result.getOwner());
+    Value condArgument =
+        whileOp.getCond().getArgument(result.getResultNumber());
+    SetVector<Operation *> forwardSlice;
+    getForwardSlice(condArgument, &forwardSlice);
+    if (!llvm::all_of(forwardSlice, mlir::isPure))
+      return false;
+    if (forwardSlice.contains(whileOp.getCond().front().getTerminator()))
+      return false;
+
+    // Or in computing another result. We first do a fast-path check of having
+    // the argument not influencing the terminator operation, before going into
+    // finer-grain analysis.
+    //
+    // TODO: it is possible that this argument does influence another
+    // terminator operand, but that operand in turn corresponds to a dead
+    // value, but handling that would require more complex logic of detecting
+    // dead cycles of value chains.
+    forwardSlice.clear();
+    assert(llvm::hasSingleElement(whileOp.getBody()));
+    Value bodyArgument =
+        whileOp.getBody().getArgument(result.getResultNumber());
+    getForwardSlice(bodyArgument, &forwardSlice);
+    if (!llvm::all_of(forwardSlice, mlir::isPure))
+      return false;
+
+    Operation *bodyTerminator = whileOp.getBody().front().getTerminator();
+    if (!forwardSlice.contains(bodyTerminator))
+      return true;
+
+    for (OpOperand &terminatorOperand : bodyTerminator->getOpOperands()) {
+      if (terminatorOperand.getOperandNumber() == result.getResultNumber())
+        continue;
+
+      SetVector<Operation *> backwardSlice;
+      BackwardSliceOptions options;
+      options.omitBlockArguments = true;
+      getBackwardSlice(terminatorOperand.get(), &backwardSlice, options);
+      for (Operation *op : backwardSlice) {
+        if (llvm::is_contained(op->getOperands(), bodyArgument))
+          return false;
+      }
+    }
+    return true;
+  }
+
+  // Rebuilds the terminator of `region` without the operands whose positions
+  // are listed in `deadResults`.
+  void replaceTerminator(PatternRewriter &rewriter, Region &region,
+                         ArrayRef<int64_t> deadResults) const {
+    Operation *terminator = region.front().getTerminator();
+    SmallVector<Value> terminatorOperands;
+    for (auto &&[i, operand] : llvm::enumerate(terminator->getOperands())) {
+      if (!llvm::is_contained(deadResults, i))
+        terminatorOperands.push_back(operand);
+    }
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(terminator);
+    rewriter.replaceOpWithNewOp<mlir::stablehlo::ReturnOp>(
+        terminator, TypeRange(), terminatorOperands, terminator->getAttrs());
+  }
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t> deadResults;
+    for (OpResult result : op.getResults()) {
+      if (!isLoopResultDead(result))
+        continue;
+
+      deadResults.push_back(result.getResultNumber());
+    }
+    if (deadResults.empty())
+      return failure();
+
+    // Collect the (pure) computations rooted at the dead arguments so they
+    // can be erased after the terminators stop referencing them.
+    SetVector<Operation *> condSlice, bodySlice;
+    for (int64_t i : deadResults) {
+      getForwardSlice(op.getCond().getArgument(i), &condSlice);
+      getForwardSlice(op.getBody().getArgument(i), &bodySlice);
+    }
+
+    condSlice.remove(op.getCond().front().getTerminator());
+    bodySlice.remove(op.getBody().front().getTerminator());
+    replaceTerminator(rewriter, op.getCond(), deadResults);
+    replaceTerminator(rewriter, op.getBody(), deadResults);
+
+    // Erase users before producers.
+    condSlice = mlir::topologicalSort(condSlice);
+    bodySlice = mlir::topologicalSort(bodySlice);
+    for (Operation *erasable : llvm::reverse(condSlice))
+      rewriter.eraseOp(erasable);
+    for (Operation *erasable : llvm::reverse(bodySlice))
+      rewriter.eraseOp(erasable);
+
+    SmallVector<Value> operands;
+    SmallVector<Type> resultTypes;
+    SmallVector<Location> condBlockArgLocs, bodyBlockArgsLocs;
+    for (auto &&[i, operand, resultType] :
+         llvm::enumerate(op->getOperands(), op.getResultTypes())) {
+      if (llvm::is_contained(deadResults, i))
+        continue;
+
+      operands.push_back(operand);
+      resultTypes.push_back(resultType);
+      condBlockArgLocs.push_back(op.getCond().getArgument(i).getLoc());
+      bodyBlockArgsLocs.push_back(op.getBody().getArgument(i).getLoc());
+    }
+
+    auto updated = rewriter.create<mlir::stablehlo::WhileOp>(
+        op->getLoc(), resultTypes, operands, op->getAttrs());
+    SmallVector<Value> resultReplacements;
+    for (int64_t old = 0, upd = 0, end = op->getNumResults(); old < end;
+         ++old) {
+      if (llvm::is_contained(deadResults, old)) {
+        resultReplacements.push_back(nullptr);
+        continue;
+      }
+      resultReplacements.push_back(updated->getResult(upd));
+      ++upd;
+    }
+
+    // Iterate in reverse so earlier argument numbers stay valid while erasing.
+    for (int64_t i : llvm::reverse(deadResults))
+      op.getCond().eraseArgument(i);
+    rewriter.inlineRegionBefore(op.getCond(), updated.getCond(),
+                                updated.getCond().begin());
+
+    for (int64_t i : llvm::reverse(deadResults))
+      op.getBody().eraseArgument(i);
+    rewriter.inlineRegionBefore(op.getBody(), updated.getBody(),
+                                updated.getBody().begin());
+
+    rewriter.replaceOp(op, resultReplacements);
+    return success();
+  }
+};
+
 struct NegativePadToSlice final : OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
@@ -7394,19 +7537,38 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
   }
   patterns.add((no_nan || all_finite),
                context);
-  patterns.add(
-      context);
+  // clang-format off
+  patterns.add<
+    BroadcastInDimOpCanon,
+    ChainedDynamicBroadcastInDimCanonicalization,
+    CompareOpCanon,
+    ConjComplexNegate,
+    ConvertOpCanon,
+    DivideSqrtToMultiplyRsqrt,
+    DynamicBroadcastInDimAllDimsNonExpanding,
+    DynamicBroadcastInDimOpNotActuallyDynamic,
+    DynamicGatherOpIsNotDynamic,
+    DynamicReshapeOpCanon,
+    EmptyReduceOpCanon,
+    GatherOpCanon,
+    GetDimensionSizeOpCanon,
+    GetTupleElementOpCanon,
+    IfInline,
+    IfToSelect,
+    ImagOpCanon,
+    MergeConsecutiveReshapes,
+    NoopReduceOpCanon,
+    RealOpCanon,
+    ReorderElementwiseAndShapeOp,
+    ReshapeOpCanon,
+    SelectOpUsedWithinIf,
+    TransposeBroadcastInDimToBroadcastInDim,
+    TransposeIsReshape,
+    WhileDeadResults,
+    WhileSimplify,
+    ZeroExtentTensorCanon
+  >(context);
+  // clang-format on
   patterns.add(max_constant_expansion, context,
                PatternBenefit(65000));
   patterns.add(max_constant_expansion, context,
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
index c5ac4c660..ccee26f52 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -809,6 +809,10 @@ def ApplyShiftRightLogicalSimplifyPatterns : EnzymeHLOPatternOp<
     "shift_right_logical_simplify"> {
   let patterns = ["ShiftRightLogicalSimplify"];
 }
+def WhileDeadResultPatterns : EnzymeHLOPatternOp<
+    "while_deadresult"> {
+  let patterns = ["WhileDeadResults"];
+}
 def ApplyRemSimplifyPatterns : EnzymeHLOPatternOp<
     "rem_simplify"> {
   let patterns = ["RemSimplify"];
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index 2d60fef4a..542ead293 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -283,6 +283,7 @@ def hlo_opts():
 if_inline<1>;
 if_to_select<1>;
 while_simplify<1>;
+while_deadresult<1>;
 dot_reshape_pad<1>;
 pad_dot_general<1>(1);
diff --git a/test/lit_tests/whiledeadarg.mlir b/test/lit_tests/whiledeadarg.mlir
new file mode 100644
index 000000000..84be8854c
--- /dev/null
+++ b/test/lit_tests/whiledeadarg.mlir
@@ -0,0 +1,55 @@
+// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
+
+// CHECK-LABEL: @while_deadarg
+func.func @while_deadarg(%arg0: tensor<2x6x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>, %arg5: tensor<2xui64>) -> (tensor<2x3xf32>, tensor<2xui64>, tensor<2x6x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>) {
+  %c = stablehlo.constant dense<5> : tensor<i64>
+  %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %c_0 = stablehlo.constant dense<2> : tensor<i64>
+  %c_1 = stablehlo.constant dense<1> : tensor<i64>
+  %c_2 = stablehlo.constant dense<0> : tensor<i64>
+  %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x6x3xf32>) -> tensor<3x6x2xf32>
+  %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
+  %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
+  %3 = stablehlo.slice %0 [0:3, 0:1, 0:2] : (tensor<3x6x2xf32>) -> tensor<3x1x2xf32>
+  %4 = stablehlo.transpose %3, dims = [2, 1, 0] : (tensor<3x1x2xf32>) -> tensor<2x1x3xf32>
+  %5 = stablehlo.reshape %4 : (tensor<2x1x3xf32>) -> tensor<2x3xf32>
+  %6 = stablehlo.broadcast_in_dim %arg4, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
+  %7 = stablehlo.dot_general %arg1, %5, contracting_dims = [0] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32>
+  %8 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
+  %9 = stablehlo.add %7, %8 : tensor<3x2xf32>
+  %10 = stablehlo.add %6, %9 : tensor<3x2xf32>
+  %11 = stablehlo.tanh %10 : tensor<3x2xf32>
+  %12 = stablehlo.reshape %11 : (tensor<3x2xf32>) -> tensor<3x2x1xf32>
+  %13 = stablehlo.pad %12, %cst, low = [0, 0, 0], high = [0, 0, 5], interior = [0, 0, 0] : (tensor<3x2x1xf32>, tensor<f32>) -> tensor<3x2x6xf32>
+
+  // CHECK: %{{.+}}:8 = stablehlo.while
+  %14:9 = stablehlo.while(%iterArg = %c_2, %iterArg_3 = %13, %iterArg_4 = %1, %iterArg_5 = %2, %iterArg_6 = %arg3, %iterArg_7 = %arg4, %iterArg_8 = %arg5, %iterArg_9 = %11, %iterArg_10 = %0) : tensor<i64>, tensor<3x2x6xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x2xf32>, tensor<3x6x2xf32>
+   cond {
+    %19 = stablehlo.compare LT, %iterArg, %c : (tensor<i64>, tensor<i64>) -> tensor<i1>
+    stablehlo.return %19 : tensor<i1>
+  } do {
+    %19 = stablehlo.add %c_0, %iterArg : tensor<i64>
+    %20 = stablehlo.subtract %19, %c_1 : tensor<i64>
+    %21 = stablehlo.dynamic_slice %iterArg_10, %c_2, %20, %c_2, sizes = [3, 1, 2] : (tensor<3x6x2xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<3x1x2xf32>
+    %22 = stablehlo.transpose %21, dims = [2, 1, 0] : (tensor<3x1x2xf32>) -> tensor<2x1x3xf32>
+    %23 = stablehlo.reshape %22 : (tensor<2x1x3xf32>) -> tensor<2x3xf32>
+    %24 = stablehlo.dot_general %iterArg_5, %iterArg_9, contracting_dims = [1] x [0] : (tensor<3x3xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
+    %25 = stablehlo.broadcast_in_dim %iterArg_7, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
+    %26 = stablehlo.add %24, %25 : tensor<3x2xf32>
+    %27 = stablehlo.dot_general %iterArg_4, %23, contracting_dims = [1] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32>
+    %28 = stablehlo.broadcast_in_dim %iterArg_6, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
+    %29 = stablehlo.add %27, %28 : tensor<3x2xf32>
+    %30 = stablehlo.add %26, %29 : tensor<3x2xf32>
+    %31 = stablehlo.tanh %30 : tensor<3x2xf32>
+    %32 = stablehlo.reshape %31 : (tensor<3x2xf32>) -> tensor<3x2x1xf32>
+    // CHECK-NOT: dynamic_update_slice
+    %33 = stablehlo.dynamic_update_slice %iterArg_3, %32, %c_2, %c_2, %20 : (tensor<3x2x6xf32>, tensor<3x2x1xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<3x2x6xf32>
+    %34 = stablehlo.add %iterArg, %c_1 : tensor<i64>
+    stablehlo.return %34, %33, %iterArg_4, %iterArg_5, %iterArg_6, %iterArg_7, %iterArg_8, %31, %iterArg_10 : tensor<i64>, tensor<3x2x6xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x2xf32>, tensor<3x6x2xf32>
+  }
+  %15 = stablehlo.transpose %14#7, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32>
+  %16 = stablehlo.transpose %14#8, dims = [2, 1, 0] : (tensor<3x6x2xf32>) -> tensor<2x6x3xf32>
+  %17 = stablehlo.transpose %14#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
+  %18 = stablehlo.transpose %14#3, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
+  return %15, %14#6, %16, %17, %18, %14#4, %14#5 : tensor<2x3xf32>, tensor<2xui64>, tensor<2x6x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>
+}