Skip to content

Commit

Permalink
Fix shaped broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 27, 2024
1 parent b2d055f commit f2094b3
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions src/enzyme_ad/jax/Passes/ArithRaising.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,24 +96,19 @@ struct ArithRaisingPass : public ArithRaisingPassBase<ArithRaisingPass> {
op->walk([=](enzyme::BroadcastOp broadcastOp) {
OpBuilder builder(broadcastOp);
Value newBroadcastOp;
if (use_stablehlo) {
SmallVector<int64_t> broadcastDims;
auto shape =
broadcastOp.getInput().getType().cast<TensorType>().getShape();
broadcastDims.reserve(shape.size());
for (auto en : llvm::enumerate(shape)) {
// original dimensions end up one further because the batch dimension
// is prepended:
broadcastDims.push_back(en.index() + 1);
}
newBroadcastOp = builder.create<stablehlo::BroadcastInDimOp>(
broadcastOp.getLoc(), broadcastOp.getType(), broadcastOp.getInput(),
builder.getDenseI64ArrayAttr(broadcastDims));
} else {
newBroadcastOp = builder.create<mhlo::BroadcastOp>(
broadcastOp.getLoc(), broadcastOp.getInput(),
builder.getI64TensorAttr({broadcastOp.getWidth()}));
assert(use_stablehlo);
SmallVector<int64_t> broadcastDims;
auto shape =
broadcastOp.getInput().getType().cast<TensorType>().getShape();
broadcastDims.reserve(shape.size());
for (auto en : llvm::enumerate(shape)) {
// original dimensions end up one further because the batch dimension
// is prepended:
broadcastDims.push_back(en.index() + 1);
}
newBroadcastOp = builder.create<stablehlo::BroadcastInDimOp>(
broadcastOp.getLoc(), broadcastOp.getType(), broadcastOp.getInput(),
builder.getDenseI64ArrayAttr(broadcastDims));
broadcastOp.replaceAllUsesWith(newBroadcastOp);
broadcastOp.erase();
});
Expand Down

0 comments on commit f2094b3

Please sign in to comment.