Enzyme ops removal for stablehlo.if (#224)
* Enzyme ops removal for stablehlo.if

* Add test for ops removal in stablehlo.if

* Fix previous if test
Pangoraw authored Jan 8, 2025
1 parent 85612ea commit aa23e9f
Showing 3 changed files with 234 additions and 18 deletions.
196 changes: 196 additions & 0 deletions src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -15,6 +15,7 @@
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"
#include "Enzyme/MLIR/Passes/RemovalUtils.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

@@ -1822,6 +1823,199 @@ class AutoDiffSort
  }
};

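// Rewrites the enzyme.set/get/push ops contained in `block` (a branch of a
// stablehlo.if): sets and gets are folded through `mapping`, pushes of values
// defined outside the block are hoisted before the if (moving the matching
// pop before the reverse if), and pushes of block-local values are collected
// into `caches` for the caller to rewrite.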
static void removalBlockExplore(Block *block, IRMapping &mapping,
                                OpBuilder &builder,
                                llvm::SmallDenseSet<Value> &gradients,
                                llvm::DenseMap<Value, CacheInfo> &caches) {
  for (auto it = block->begin(), e = block->end(); it != e;) {
    Operation *op = &*it;

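    // Record the value written to the gradient; later gets in this block will
    // see it through the mapping.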
    if (auto setOp = dyn_cast<enzyme::SetOp>(op)) {
      auto grad = setOp.getGradient();
      auto value = setOp.getValue();
      mapping.map(grad, value);
      gradients.insert(grad);
    }

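    // Replace gets with the value from the mapping, falling back to a get
    // created before the if when the gradient has not been set in this block.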
    if (auto getOp = dyn_cast<enzyme::GetOp>(op)) {
      auto grad = getOp.getGradient();
      Value value = mapping.lookupOrNull(grad);
      if (!value) {
        value = builder.create<enzyme::GetOp>(
            getOp->getLoc(), getOp.getResult().getType(), grad);
        mapping.map(grad, value);
      }
      getOp.getResult().replaceAllUsesWith(value);
    }

    if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
      CacheInfo info(pushOp.getCache());

      Value pushedValue = info.pushedValue();

      // If the pushed value is defined outside this block, we can simply push
      // it before the if.
      if (pushedValue.getParentBlock() != block) {
        builder.create<enzyme::PushOp>(pushOp->getLoc(), pushOp.getCache(),
                                       pushedValue);

        ++it; // Advance the iterator first so the erase keeps it valid.
        pushOp->erase();

        // Move the pop before the reverse if.
        OpBuilder::InsertionGuard guard(builder);
        builder.setInsertionPoint(info.popOp->getParentOp());

        auto newPop = builder.create<enzyme::PopOp>(
            info.popOp->getLoc(), pushedValue.getType(),
            info.popOp.getCache());
        info.popOp.getResult().replaceAllUsesWith(newPop);
        info.popOp->erase();

        continue;
      }

      // Otherwise, record the pushed value; caches pushing the same value are
      // merged into a single entry.
      if (caches.contains(pushedValue)) {
        info = info.merge(caches.lookup(pushedValue));
      }
      caches[pushedValue] = info;
    }

    ++it;
  }
}

struct IfOpEnzymeOpsRemover
    : public EnzymeOpsRemoverOpInterface::ExternalModel<IfOpEnzymeOpsRemover,
                                                        stablehlo::IfOp> {
  LogicalResult removeEnzymeOps(Operation *op) const {
    // Gradients:
    //
    // For each enzyme.set in a branch, we instead set after the if, using
    // the value returned from the branch:
    //
    // if %pred {
    //   enzyme.set %grad, %2
    // } else {
    // }
    //
    // becomes
    //
    // %0 = enzyme.get %grad
    // %1 = if %pred {
    //   return %2
    // } else {
    //   return %0
    // }
    // enzyme.set %grad, %1
    //
    // For each enzyme.get in a branch, we read the gradient before the if
    // and use that value instead.

    // Caches:
    //
    // For each push in a branch, we instead push after the if, returning a
    // dummy value from the other branch.
    //
    // For each pop in the reverse if, we pop before the if instead of inside
    // a branch.
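    //
    // For example (an illustrative sketch mirroring the gradient example
    // above; %pred, %x, and the cache are placeholder names):
    //
    // if %pred {
    //   enzyme.push %cache, %x
    // } else {
    // }
    //
    // becomes
    //
    // %0 = if %pred {
    //   return %x
    // } else {
    //   return %dummy  // null value of the same type
    // }
    // enzyme.push %cache, %0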

    auto ifOp = cast<IfOp>(op);

    Block *trueBlock = &ifOp.getTrueBranch().front(),
          *falseBlock = &ifOp.getFalseBranch().front();

    if (enzyme::removeOpsWithinBlock(trueBlock).failed() ||
        enzyme::removeOpsWithinBlock(falseBlock).failed()) {
      return failure();
    }

    // Gradients whose value is set in either branch.
    llvm::SmallDenseSet<Value> gradients;

    // We assume each value is pushed by only one of the two branches.
    llvm::DenseMap<Value, CacheInfo> pushedCaches;

    // Maps gradients to their values within each branch.
    IRMapping trueMapping, falseMapping;
    OpBuilder builder(ifOp);

    removalBlockExplore(trueBlock, trueMapping, builder, gradients,
                        pushedCaches);
    removalBlockExplore(falseBlock, falseMapping, builder, gradients,
                        pushedCaches);

    Operation *trueTerm = trueBlock->getTerminator();
    Operation *falseTerm = falseBlock->getTerminator();

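    // Yield each gradient's final value from both branches; a branch that
    // does not set the gradient yields its value read before the if.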
    for (auto grad : gradients) {
      auto trueValue = trueMapping.lookupOrNull(grad);
      if (!trueValue) {
        trueValue = builder.create<enzyme::GetOp>(
            grad.getLoc(),
            grad.getType().cast<enzyme::GradientType>().getBasetype(), grad);
      }
      trueTerm->insertOperands(trueTerm->getNumOperands(),
                               ValueRange(trueValue));

      auto falseValue = falseMapping.lookupOrNull(grad);
      if (!falseValue) {
        falseValue = builder.create<enzyme::GetOp>(
            grad.getLoc(),
            grad.getType().cast<enzyme::GradientType>().getBasetype(), grad);
      }
      falseTerm->insertOperands(falseTerm->getNumOperands(),
                                ValueRange(falseValue));
    }

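    // Yield each pushed value from its defining branch and a null value of
    // the same type from the other branch.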
    for (auto &[pushedValue, info] : pushedCaches) {
      Value dummy =
          pushedValue.getType().cast<AutoDiffTypeInterface>().createNullValue(
              builder, pushedValue.getLoc());

      Value trueValue =
          pushedValue.getParentBlock() == trueBlock ? pushedValue : dummy;
      Value falseValue =
          pushedValue.getParentBlock() == falseBlock ? pushedValue : dummy;

      trueTerm->insertOperands(trueTerm->getNumOperands(),
                               ValueRange(trueValue));
      falseTerm->insertOperands(falseTerm->getNumOperands(),
                                ValueRange(falseValue));
    }

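    // Recreate the if with the extended result list, move both bodies over,
    // and set each gradient from the corresponding new result.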
    auto newIf = builder.create<stablehlo::IfOp>(
        ifOp->getLoc(), trueTerm->getOperandTypes(), ifOp.getPred());
    newIf.getTrueBranch().takeBody(ifOp.getTrueBranch());
    newIf.getFalseBranch().takeBody(ifOp.getFalseBranch());

    size_t idx = ifOp->getNumResults();
    for (auto grad : gradients) {
      builder.create<enzyme::SetOp>(grad.getLoc(), grad,
                                    newIf->getResult(idx));
      idx++;
    }

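    // Re-emit each push after the if and hoist the matching pop before the
    // reverse if.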
    for (auto &[pushedValue, info] : pushedCaches) {
      builder.create<enzyme::PushOp>(info.pushOp->getLoc(),
                                     info.initOp.getResult(),
                                     newIf->getResult(idx));
      info.pushOp->erase();

      OpBuilder::InsertionGuard guard(builder);
      builder.setInsertionPoint(info.popOp->getParentOp());

      auto newPop = builder.create<enzyme::PopOp>(
          info.popOp->getLoc(), info.popOp.getResult().getType(),
          info.popOp.getCache());
      info.popOp.getResult().replaceAllUsesWith(newPop);
      info.popOp->erase();

      idx++;
    }

    // Forward the original results to the new if before erasing the old op,
    // so no dangling uses remain.
    ifOp->replaceAllUsesWith(
        newIf->getResults().slice(0, ifOp->getNumResults()));
    ifOp->erase();

    return success();
  }
};

} // namespace

void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
@@ -1832,6 +2026,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(

    // SortOp::attachInterface<AutoDiffSort>(*context);

    IfOp::attachInterface<IfOpEnzymeOpsRemover>(*context);

    WhileOp::attachInterface<ADDataFlowWhileOp>(*context);
    SortOp::attachInterface<ADDataFlowSortOp>(*context);
    ScatterOp::attachInterface<ADDataFlowScatterOp>(*context);
26 changes: 8 additions & 18 deletions test/lit_tests/diffrules/stablehlo/if.mlir
@@ -31,23 +31,13 @@ module {

 // REVERSE: func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<i1>, %arg2: tensor<10xf32>) -> tensor<10xf32> {
 // REVERSE-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<10xf32>
-// REVERSE-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<10xf32>>
-// REVERSE-NEXT: "enzyme.set"(%0, %cst) : (!enzyme.Gradient<tensor<10xf32>>, tensor<10xf32>) -> ()
-// REVERSE-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<10xf32>>
-// REVERSE-NEXT: "enzyme.set"(%1, %cst) : (!enzyme.Gradient<tensor<10xf32>>, tensor<10xf32>) -> ()
-// REVERSE-NEXT: %2 = arith.addf %arg2, %cst : tensor<10xf32>
-// REVERSE-NEXT: "stablehlo.if"(%arg1) ({
-// REVERSE-NEXT: %4 = "enzyme.get"(%1) : (!enzyme.Gradient<tensor<10xf32>>) -> tensor<10xf32>
-// REVERSE-NEXT: %5 = arith.addf %4, %2 : tensor<10xf32>
-// REVERSE-NEXT: "enzyme.set"(%1, %5) : (!enzyme.Gradient<tensor<10xf32>>, tensor<10xf32>) -> ()
-// REVERSE-NEXT: "enzyme.set"(%1, %cst) : (!enzyme.Gradient<tensor<10xf32>>, tensor<10xf32>) -> ()
-// REVERSE-NEXT: %6 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<10xf32>>) -> tensor<10xf32>
-// REVERSE-NEXT: %7 = arith.addf %6, %5 : tensor<10xf32>
-// REVERSE-NEXT: "enzyme.set"(%0, %7) : (!enzyme.Gradient<tensor<10xf32>>, tensor<10xf32>) -> ()
-// REVERSE-NEXT: stablehlo.return
+// REVERSE-NEXT: %0 = arith.addf %arg2, %cst : tensor<10xf32>
+// REVERSE-NEXT: %1:2 = "stablehlo.if"(%arg1) ({
+// REVERSE-NEXT: %2 = arith.addf %0, %cst : tensor<10xf32>
+// REVERSE-NEXT: %3 = arith.addf %2, %cst : tensor<10xf32>
+// REVERSE-NEXT: stablehlo.return %3, %cst : tensor<10xf32>, tensor<10xf32>
 // REVERSE-NEXT: }, {
-// REVERSE-NEXT: stablehlo.return
-// REVERSE-NEXT: }) : (tensor<i1>) -> ()
-// REVERSE-NEXT: %3 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<10xf32>>) -> tensor<10xf32>
-// REVERSE-NEXT: return %3 : tensor<10xf32>
+// REVERSE-NEXT: stablehlo.return %cst, %cst : tensor<10xf32>, tensor<10xf32>
+// REVERSE-NEXT: }) : (tensor<i1>) -> (tensor<10xf32>, tensor<10xf32>)
+// REVERSE-NEXT: return %1#0 : tensor<10xf32>
 // REVERSE-NEXT: }
30 changes: 30 additions & 0 deletions test/lit_tests/diffrules/stablehlo/if_remove.mlir
@@ -0,0 +1,30 @@
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_active,enzyme_const retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --enzyme-hlo-opt | FileCheck %s --check-prefix=REVERSE

module {
  func.func @main(%arg0: tensor<10xf32>, %pred: tensor<i1>) -> tensor<10xf32> {
    %cst = stablehlo.constant dense<1.0> : tensor<10xf32>

    %0 = "stablehlo.if"(%pred) ({
      %1 = stablehlo.multiply %arg0, %cst : tensor<10xf32>
      %2 = stablehlo.multiply %1, %1 : tensor<10xf32>
      "stablehlo.return"(%2) : (tensor<10xf32>) -> ()
    }, {
      "stablehlo.return"(%cst) : (tensor<10xf32>) -> ()
    }) : (tensor<i1>) -> tensor<10xf32>

    return %0 : tensor<10xf32>
  }
}

// REVERSE: func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<i1>, %arg2: tensor<10xf32>) -> tensor<10xf32> {
// REVERSE-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf32>
// REVERSE-NEXT: %0 = stablehlo.select %arg1, %arg0, %cst : tensor<i1>, tensor<10xf32>
// REVERSE-NEXT: %1 = "stablehlo.if"(%arg1) ({
// REVERSE-NEXT: %2 = stablehlo.multiply %arg2, %0 : tensor<10xf32>
// REVERSE-NEXT: %3 = stablehlo.add %2, %2 : tensor<10xf32>
// REVERSE-NEXT: stablehlo.return %3 : tensor<10xf32>
// REVERSE-NEXT: }, {
// REVERSE-NEXT: stablehlo.return %cst : tensor<10xf32>
// REVERSE-NEXT: }) : (tensor<i1>) -> tensor<10xf32>
// REVERSE-NEXT: return %1 : tensor<10xf32>
// REVERSE-NEXT: }
