Skip to content

Commit

Permalink
Reland: pattern to remove dead results of stablehlo.while
Browse files Browse the repository at this point in the history
Closes #266.
  • Loading branch information
ftynse authored and wsmoses committed Jan 27, 2025
1 parent fd5517f commit 633870e
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 13 deletions.
188 changes: 175 additions & 13 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
// ops.
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -1449,6 +1451,147 @@ struct ShiftRightLogicalSimplify final
}
};

/// Rewrite pattern that removes dead results of `stablehlo.while` together
/// with the corresponding iteration arguments and terminator operands.
/// A result is considered dead when it has no external uses, the matching
/// condition-region argument does not influence the loop condition, and the
/// matching body-region argument does not feed any *other* (live) result
/// through the body terminator.
struct WhileDeadResults final : OpRewritePattern<mlir::stablehlo::WhileOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Returns true if `result` of the enclosing while op can be removed
  /// safely, i.e. erasing it (and its region arguments) cannot change the
  /// loop's observable behavior.
  bool isLoopResultDead(OpResult result) const {
    // Not dead if the result is in use.
    if (!result.use_empty())
      return false;

    // Or if the corresponding argument is being used in computing the
    // condition.
    auto whileOp = cast<mlir::stablehlo::WhileOp>(result.getOwner());
    Value condArgument =
        whileOp.getCond().getArgument(result.getResultNumber());
    SetVector<Operation *> forwardSlice;
    getForwardSlice(condArgument, &forwardSlice);
    // Only pure ops may be erased later; a side-effecting user keeps the
    // argument alive.
    if (!llvm::all_of(forwardSlice, mlir::isPure))
      return false;
    if (forwardSlice.contains(whileOp.getCond().front().getTerminator()))
      return false;

    // Or in computing another result. We first do a fast-path check of having
    // the argument not influencing the terminator operation, before going into
    // finer-grain analysis.
    //
    // TODO: it is possible that this argument does influence another terminator
    // operand, but that operand in turn corresponds to a dead value, but
    // handling that would require more complex logic of detecting dead cycles
    // of value chains.
    forwardSlice.clear();
    assert(llvm::hasSingleElement(whileOp.getBody()));
    Value bodyArgument =
        whileOp.getBody().getArgument(result.getResultNumber());
    getForwardSlice(bodyArgument, &forwardSlice);
    if (!llvm::all_of(forwardSlice, mlir::isPure))
      return false;

    // Fast path: the argument never reaches the body terminator at all.
    Operation *bodyTerminator = whileOp.getBody().front().getTerminator();
    if (!forwardSlice.contains(bodyTerminator))
      return true;

    // Slow path: check each *other* terminator operand for a transitive
    // dependence on this body argument.
    for (OpOperand &terminatorOperand : bodyTerminator->getOpOperands()) {
      // The operand feeding this same result position is allowed to depend
      // on the argument — that is precisely the dead self-cycle.
      if (terminatorOperand.getOperandNumber() == result.getResultNumber())
        continue;

      SetVector<Operation *> backwardSlice;
      BackwardSliceOptions options;
      // Block arguments are checked explicitly via operand containment below.
      options.omitBlockArguments = true;
      getBackwardSlice(terminatorOperand.get(), &backwardSlice, options);
      for (Operation *op : backwardSlice) {
        if (llvm::is_contained(op->getOperands(), bodyArgument))
          return false;
      }
    }
    return true;
  }

  /// Rebuilds the terminator of `region` as a `stablehlo.return` that drops
  /// the operands at the positions listed in `deadResults`.
  void replaceTerminator(PatternRewriter &rewriter, Region &region,
                         ArrayRef<int64_t> deadResults) const {
    Operation *terminator = region.front().getTerminator();
    SmallVector<Value> terminatorOperands;
    for (auto &&[i, operand] : llvm::enumerate(terminator->getOperands())) {
      if (!llvm::is_contained(deadResults, i))
        terminatorOperands.push_back(operand);
    }
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(terminator);
    rewriter.replaceOpWithNewOp<mlir::stablehlo::ReturnOp>(
        terminator, TypeRange(), terminatorOperands, terminator->getAttrs());
  }

  LogicalResult matchAndRewrite(mlir::stablehlo::WhileOp op,
                                PatternRewriter &rewriter) const override {
    // Collect positions of all removable results.
    SmallVector<int64_t> deadResults;
    for (OpResult result : op.getResults()) {
      if (!isLoopResultDead(result))
        continue;

      deadResults.push_back(result.getResultNumber());
    }
    if (deadResults.empty())
      return failure();

    // Gather the ops transitively computed from the dead arguments in both
    // regions so they can be erased. Terminators are rebuilt separately.
    SetVector<Operation *> condSlice, bodySlice;
    for (int64_t i : deadResults) {
      getForwardSlice(op.getCond().getArgument(i), &condSlice);
      getForwardSlice(op.getBody().getArgument(i), &bodySlice);
    }
    condSlice.remove(op.getCond().front().getTerminator());
    bodySlice.remove(op.getBody().front().getTerminator());
    replaceTerminator(rewriter, op.getCond(), deadResults);
    replaceTerminator(rewriter, op.getBody(), deadResults);

    // Erase in reverse topological order so each op's users are erased
    // before the op itself.
    condSlice = mlir::topologicalSort(condSlice);
    bodySlice = mlir::topologicalSort(bodySlice);
    for (Operation *erasable : llvm::reverse(condSlice))
      rewriter.eraseOp(erasable);
    for (Operation *erasable : llvm::reverse(bodySlice))
      rewriter.eraseOp(erasable);

    // Assemble the operand/result lists of the slimmed-down while op,
    // keeping only the live positions.
    SmallVector<Value> operands;
    SmallVector<Type> resultTypes;
    SmallVector<Location> condBlockArgLocs, bodyBlockArgsLocs;
    for (auto &&[i, operand, resultType] :
         llvm::enumerate(op->getOperands(), op.getResultTypes())) {
      if (llvm::is_contained(deadResults, i))
        continue;

      operands.push_back(operand);
      resultTypes.push_back(resultType);
      condBlockArgLocs.push_back(op.getCond().getArgument(i).getLoc());
      bodyBlockArgsLocs.push_back(op.getBody().getArgument(i).getLoc());
    }

    auto updated = rewriter.create<mlir::stablehlo::WhileOp>(
        op->getLoc(), resultTypes, operands, op->getAttrs());
    // Map surviving old results to the new op's results; dead positions get
    // a null replacement (legal since they have no uses by construction).
    SmallVector<Value> resultReplacements;
    for (int64_t old = 0, upd = 0, end = op->getNumResults(); old < end;
         ++old) {
      if (llvm::is_contained(deadResults, old)) {
        resultReplacements.push_back(nullptr);
        continue;
      }
      resultReplacements.push_back(updated->getResult(upd));
      ++upd;
    }

    // Drop the dead block arguments (reverse order keeps remaining indices
    // valid) and move each region into the new op.
    for (int64_t i : llvm::reverse(deadResults))
      op.getCond().eraseArgument(i);
    rewriter.inlineRegionBefore(op.getCond(), updated.getCond(),
                                updated.getCond().begin());

    for (int64_t i : llvm::reverse(deadResults))
      op.getBody().eraseArgument(i);
    rewriter.inlineRegionBefore(op.getBody(), updated.getBody(),
                                updated.getBody().begin());

    rewriter.replaceOp(op, resultReplacements);
    return success();
  }
};

struct NegativePadToSlice final : OpRewritePattern<mlir::stablehlo::PadOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -7394,19 +7537,38 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
}
patterns.add<NoNanAddSubSimplify>((no_nan || all_finite), context);

patterns.add<CompareOpCanon, BroadcastInDimOpCanon,
TransposeBroadcastInDimToBroadcastInDim, ConvertOpCanon,
DynamicBroadcastInDimOpNotActuallyDynamic,
ChainedDynamicBroadcastInDimCanonicalization,
DynamicBroadcastInDimAllDimsNonExpanding, NoopReduceOpCanon,
EmptyReduceOpCanon, DynamicReshapeOpCanon,
GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
ConjComplexNegate, GetDimensionSizeOpCanon, GatherOpCanon,
ReshapeOpCanon, MergeConsecutiveReshapes, TransposeIsReshape,
SelectOpUsedWithinIf, IfInline, IfToSelect, WhileSimplify,
ZeroExtentTensorCanon, ReorderElementwiseAndShapeOp,
DynamicGatherOpIsNotDynamic, DivideSqrtToMultiplyRsqrt>(
context);
// clang-format off
patterns.add<
BroadcastInDimOpCanon,
ChainedDynamicBroadcastInDimCanonicalization,
CompareOpCanon,
ConjComplexNegate,
ConvertOpCanon,
DivideSqrtToMultiplyRsqrt,
DynamicBroadcastInDimAllDimsNonExpanding,
DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicGatherOpIsNotDynamic,
DynamicReshapeOpCanon,
EmptyReduceOpCanon,
GatherOpCanon,
GetDimensionSizeOpCanon,
GetTupleElementOpCanon,
IfInline,
IfToSelect,
ImagOpCanon,
MergeConsecutiveReshapes,
NoopReduceOpCanon,
RealOpCanon,
ReorderElementwiseAndShapeOp,
ReshapeOpCanon,
SelectOpUsedWithinIf,
TransposeBroadcastInDimToBroadcastInDim,
TransposeIsReshape,
WhileDeadResults,
WhileSimplify,
ZeroExtentTensorCanon
>(context);
// clang-format on
patterns.add<SelectOpCanon>(max_constant_expansion, context,
PatternBenefit(65000));
patterns.add<ConcatenateOpCanon>(max_constant_expansion, context,
Expand Down
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,10 @@ def ApplyShiftRightLogicalSimplifyPatterns : EnzymeHLOPatternOp<
"shift_right_logical_simplify"> {
let patterns = ["ShiftRightLogicalSimplify"];
}
// Exposes the WhileDeadResults rewrite as the `while_deadresult` transform
// option. NOTE(review): sibling defs use an "Apply...Patterns" name prefix
// (e.g. ApplyRemSimplifyPatterns); this def omits it — consider renaming for
// consistency if no external references depend on the current name.
def WhileDeadResultPatterns : EnzymeHLOPatternOp<
    "while_deadresult"> {
  let patterns = ["WhileDeadResults"];
}
def ApplyRemSimplifyPatterns : EnzymeHLOPatternOp<
"rem_simplify"> {
let patterns = ["RemSimplify"];
Expand Down
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def hlo_opts():
if_inline<1>;
if_to_select<1>;
while_simplify<1>;
while_deadresult<1>;
dot_reshape_pad<1>;
pad_dot_general<1>(1);
Expand Down
55 changes: 55 additions & 0 deletions test/lit_tests/whiledeadarg.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

// Tests the WhileDeadResults pattern: of the 9-result while loop below, only
// result %14#1 (the buffer threaded through dynamic_update_slice) is never
// used after the loop and does not feed the condition or any live result, so
// the pattern should shrink the loop to 8 results and erase the
// dynamic_update_slice chain that computed the dead value.

// CHECK-LABEL: @while_deadarg
func.func @while_deadarg(%arg0: tensor<2x6x3xf32>, %arg1: tensor<3x3xf32>, %arg2: tensor<3x3xf32>, %arg3: tensor<3xf32>, %arg4: tensor<3xf32>, %arg5: tensor<2xui64>) -> (tensor<2x3xf32>, tensor<2xui64>, tensor<2x6x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>) {
  %c = stablehlo.constant dense<5> : tensor<i64>
  %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
  %c_0 = stablehlo.constant dense<2> : tensor<i64>
  %c_1 = stablehlo.constant dense<1> : tensor<i64>
  %c_2 = stablehlo.constant dense<0> : tensor<i64>
  %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x6x3xf32>) -> tensor<3x6x2xf32>
  %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
  %2 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
  %3 = stablehlo.slice %0 [0:3, 0:1, 0:2] : (tensor<3x6x2xf32>) -> tensor<3x1x2xf32>
  %4 = stablehlo.transpose %3, dims = [2, 1, 0] : (tensor<3x1x2xf32>) -> tensor<2x1x3xf32>
  %5 = stablehlo.reshape %4 : (tensor<2x1x3xf32>) -> tensor<2x3xf32>
  %6 = stablehlo.broadcast_in_dim %arg4, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
  %7 = stablehlo.dot_general %arg1, %5, contracting_dims = [0] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32>
  %8 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
  %9 = stablehlo.add %7, %8 : tensor<3x2xf32>
  %10 = stablehlo.add %6, %9 : tensor<3x2xf32>
  %11 = stablehlo.tanh %10 : tensor<3x2xf32>
  %12 = stablehlo.reshape %11 : (tensor<3x2xf32>) -> tensor<3x2x1xf32>
  %13 = stablehlo.pad %12, %cst, low = [0, 0, 0], high = [0, 0, 5], interior = [0, 0, 0] : (tensor<3x2x1xf32>, tensor<f32>) -> tensor<3x2x6xf32>

  // One result (the %iterArg_3 slot) must be removed: 9 -> 8.
  // CHECK: %{{.+}}:8 = stablehlo.while
  %14:9 = stablehlo.while(%iterArg = %c_2, %iterArg_3 = %13, %iterArg_4 = %1, %iterArg_5 = %2, %iterArg_6 = %arg3, %iterArg_7 = %arg4, %iterArg_8 = %arg5, %iterArg_9 = %11, %iterArg_10 = %0) : tensor<i64>, tensor<3x2x6xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x2xf32>, tensor<3x6x2xf32>
   cond {
    %19 = stablehlo.compare LT, %iterArg, %c : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %19 : tensor<i1>
  } do {
    %19 = stablehlo.add %c_0, %iterArg : tensor<i64>
    %20 = stablehlo.subtract %19, %c_1 : tensor<i64>
    %21 = stablehlo.dynamic_slice %iterArg_10, %c_2, %20, %c_2, sizes = [3, 1, 2] : (tensor<3x6x2xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<3x1x2xf32>
    %22 = stablehlo.transpose %21, dims = [2, 1, 0] : (tensor<3x1x2xf32>) -> tensor<2x1x3xf32>
    %23 = stablehlo.reshape %22 : (tensor<2x1x3xf32>) -> tensor<2x3xf32>
    %24 = stablehlo.dot_general %iterArg_5, %iterArg_9, contracting_dims = [1] x [0] : (tensor<3x3xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
    %25 = stablehlo.broadcast_in_dim %iterArg_7, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
    %26 = stablehlo.add %24, %25 : tensor<3x2xf32>
    %27 = stablehlo.dot_general %iterArg_4, %23, contracting_dims = [1] x [1] : (tensor<3x3xf32>, tensor<2x3xf32>) -> tensor<3x2xf32>
    %28 = stablehlo.broadcast_in_dim %iterArg_6, dims = [0] : (tensor<3xf32>) -> tensor<3x2xf32>
    %29 = stablehlo.add %27, %28 : tensor<3x2xf32>
    %30 = stablehlo.add %26, %29 : tensor<3x2xf32>
    %31 = stablehlo.tanh %30 : tensor<3x2xf32>
    %32 = stablehlo.reshape %31 : (tensor<3x2xf32>) -> tensor<3x2x1xf32>
    // The update feeding only the dead result must be erased with it.
    // CHECK-NOT: dynamic_update_slice
    %33 = stablehlo.dynamic_update_slice %iterArg_3, %32, %c_2, %c_2, %20 : (tensor<3x2x6xf32>, tensor<3x2x1xf32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<3x2x6xf32>
    %34 = stablehlo.add %iterArg, %c_1 : tensor<i64>
    stablehlo.return %34, %33, %iterArg_4, %iterArg_5, %iterArg_6, %iterArg_7, %iterArg_8, %31, %iterArg_10 : tensor<i64>, tensor<3x2x6xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x2xf32>, tensor<3x6x2xf32>
  }
  %15 = stablehlo.transpose %14#7, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32>
  %16 = stablehlo.transpose %14#8, dims = [2, 1, 0] : (tensor<3x6x2xf32>) -> tensor<2x6x3xf32>
  %17 = stablehlo.transpose %14#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
  %18 = stablehlo.transpose %14#3, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
  return %15, %14#6, %16, %17, %18, %14#4, %14#5 : tensor<2x3xf32>, tensor<2xui64>, tensor<2x6x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>
}

0 comments on commit 633870e

Please sign in to comment.