Skip to content

Commit

Permalink
enzyme.broadcast conversion (#197)
Browse files Browse the repository at this point in the history
* add enzyme.broadcast to `stablehlo.broadcast_in_dim`/`mhlo.broadcast` conversion in `arith-raise`.

* test

* formatting
  • Loading branch information
jumerckx authored Dec 27, 2024
1 parent f550542 commit b2d055f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
26 changes: 26 additions & 0 deletions src/enzyme_ad/jax/Passes/ArithRaising.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
// ops.
//===----------------------------------------------------------------------===//

#include "Enzyme/MLIR/Dialect/Dialect.h"
#include "Enzyme/MLIR/Dialect/Ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -91,6 +93,30 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
constOp.erase();
}
});
op->walk([=](enzyme::BroadcastOp broadcastOp) {
OpBuilder builder(broadcastOp);
Value newBroadcastOp;
if (use_stablehlo) {
SmallVector<int64_t> broadcastDims;
auto shape =
broadcastOp.getInput().getType().cast<TensorType>().getShape();
broadcastDims.reserve(shape.size());
for (auto en : llvm::enumerate(shape)) {
// original dimensions end up one further because the batch dimension
// is prepended:
broadcastDims.push_back(en.index() + 1);
}
newBroadcastOp = builder.create<stablehlo::BroadcastInDimOp>(
broadcastOp.getLoc(), broadcastOp.getType(), broadcastOp.getInput(),
builder.getDenseI64ArrayAttr(broadcastDims));
} else {
newBroadcastOp = builder.create<mhlo::BroadcastOp>(
broadcastOp.getLoc(), broadcastOp.getInput(),
builder.getI64TensorAttr({broadcastOp.getWidth()}));
}
broadcastOp.replaceAllUsesWith(newBroadcastOp);
broadcastOp.erase();
});
}
};

Expand Down
3 changes: 2 additions & 1 deletion src/enzyme_ad/jax/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def ArithRaisingPass : Pass<"arith-raise"> {
"arith::ArithDialect",
"mhlo::MhloDialect",
"stablehlo::StablehloDialect",
"chlo::ChloDialect"
"chlo::ChloDialect",
"enzyme::EnzymeDialect",
];
let constructor = "mlir::enzyme::createArithRaisingPass()";
let options = [
Expand Down
16 changes: 16 additions & 0 deletions test/lit_tests/broadcastdiff.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: enzymexlamlir-opt --arith-raise %s | FileCheck %s

module {
func.func @main(%arg0: tensor<f64>, %arg1: tensor<2xf64>) -> tensor<2xf64> {
%0 = "enzyme.broadcast"(%arg0) <{width = 2 : i64}> : (tensor<f64>) -> tensor<2xf64>
%1 = arith.addf %0, %arg1 : tensor<2xf64>
return %1 : tensor<2xf64>
}
}

// CHECK: func.func @main(%arg0: tensor<f64>, %arg1: tensor<2xf64>) -> tensor<2xf64> {
// CHECK-NEXT: %[[i0:.+]] = stablehlo.broadcast_in_dim %arg0, dims = [] : (tensor<f64>) -> tensor<2xf64>
// CHECK-NEXT: %[[i1:.+]] = stablehlo.add %[[i0:.+]], %arg1 : tensor<2xf64>
// CHECK-NEXT: return %[[i1:.+]] : tensor<2xf64>
// CHECK-NEXT: }
// CHECK-NEXT: }

0 comments on commit b2d055f

Please sign in to comment.