Skip to content

Commit

Permalink
Consolidate passes that operate on imported HLO
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 728706709
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Feb 28, 2025
1 parent a83a2d3 commit 9d3c847
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 336 deletions.
37 changes: 0 additions & 37 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,6 @@ diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo_bounded_dynamism.mlir b/sta
%c = stablehlo.constant dense<1> : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
return %c : tensor<1x?xf32, #stablehlo.bounds<?, 5>>
}
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
@@ -924,6 +924,15 @@
// CHECK: %[[RES:.+]] = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f32>) -> tensor<7x2xf32>
// CHECK: return %[[RES]]
return %0 : tensor<7x2xf32>
+}
+
+// Can't do anything with the dynamic shape, but shouldn't crash.
+// CHECK-LABEL: @dynamic_pad
+func.func @dynamic_pad(%arg0: tensor<?x2x3xi1>, %arg1: tensor<i1>) -> tensor<?x2x1xi1> {
+ %0 = stablehlo.pad %arg0, %arg1, low = [0, 0, -1], high = [0, 0, -1], interior = [0, 0, 0] : (tensor<?x2x3xi1>, tensor<i1>) -> tensor<?x2x1xi1>
+ // CHECK-NEXT: %[[RES:.+]] = stablehlo.pad %arg0, %arg1, low = [0, 0, -1], high = [0, 0, -1], interior = [0, 0, 0] : (tensor<?x2x3xi1>, tensor<i1>) -> tensor<?x2x1xi1>
+ // CHECK-NEXT: return %[[RES]]
+ return %0 : tensor<?x2x1xi1>
}

// -----
diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_legalize_qdq_to_quantized_op.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_legalize_qdq_to_quantized_op.mlir
--- stablehlo/stablehlo/tests/transforms/stablehlo_legalize_qdq_to_quantized_op.mlir
+++ stablehlo/stablehlo/tests/transforms/stablehlo_legalize_qdq_to_quantized_op.mlir
Expand Down Expand Up @@ -95,22 +76,4 @@ diff --ruN a/stablehlo/stablehlo/transforms/StablehloLegalizeQDQToQuantizedOp.cp
llvm::SmallVector<Value> quantizedComputeOpOperands;
for (const Value& operand : computeOp->getOperands()) {
auto* definingOp = operand.getDefiningOp();
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
@@ -934,8 +934,12 @@
auto padVal = op.getPaddingValue();

auto resultTy = cast<RankedTensorType>(op.getType());
-
- if (cast<ShapedType>(operand.getType()).getNumElements() != 0)
+ auto operandTy = cast<RankedTensorType>(operand.getType());
+
+ if (!operandTy.hasStaticShape())
+ return rewriter.notifyMatchFailure(op, "operand shape is dynamic");
+
+ if (operandTy.getNumElements() != 0)
return rewriter.notifyMatchFailure(op, "operand is not empty tensor");

if (resultTy.hasStaticShape()) {

3 changes: 1 addition & 2 deletions xla/mlir_hlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1087,8 +1087,7 @@ cc_library(
"stablehlo_ext/transforms/sdy_refine_shapes.cpp",
"stablehlo_ext/transforms/stablehlo_add_quant_dequant_conv.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_entry_function_tuples.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_from_hlo_import.cpp",
"stablehlo_ext/transforms/stablehlo_legalize_quant_composite.cpp",
"stablehlo_ext/transforms/stablehlo_prepare_for_hlo_export.cpp",
"stablehlo_ext/transforms/stablehlo_refine_shapes.cpp",
Expand Down
28 changes: 20 additions & 8 deletions xla/mlir_hlo/stablehlo_ext/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,30 @@ def StablehloPrepareForHloExportPass : Pass<"stablehlo-ext-prepare-for-hlo-expor
}];
}

def StablehloFlattenTuplePass : Pass<"stablehlo-ext-flatten-tuple", "func::FuncOp"> {
let summary = "Flatten tuples in operands and results of operators that "
"support both tuple and variadic type.";
}
def StablehloCanonicalizeFromHloImportPass : Pass<"stablehlo-ext-canonicalize-from-hlo-import", "mlir::func::FuncOp"> {
let summary = "Simplify StableHLO imported from HLO";

let dependentDialects = ["stablehlo::StablehloDialect"];

let description = [{
This pass simplifies StableHLO imported from HLO. This pass is a subset of
the graph simplification passes and is intended to bring the imported HLO
back to a more canonical form without applying a full set of graph
simplifications.

Namely, this pass:
* Simplifies tuples, undoing `tuple(get_tuple_element)` and
`get_tuple_element(tuple)`.
* Converts constants that are explicitly captured by WhileOp into implicit captures.
* Flattens tuples in operands and results of operators that support both
tuple and variadic type.
* Flattens tuples in entry function of the module.
}];

def StablehloFlattenEntryFunctionTuplesPass : Pass<"stablehlo-ext-expand-flatten-entry-function-tuples", "ModuleOp"> {
let summary = "Flatten HLO tuple for the entry function of the module.";
let options = [
Option<"entryFunctionNameOption", "entry-function", "std::string",
/*default=*/"", "the name of entry function of the module">,
/*default=*/[{"main"}], "the name of entry function of the module">,
];
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
}

def StablehloLegalizeQuantCompositePass : Pass<"stablehlo-ext-legalize-quant-composite", "ModuleOp"> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,128 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for flattening tuples in HLO ops.
// This file implements canonicalization logic applied to StableHLO produced by the HLO importer.

#include <cassert>
#include <memory>
#include <iterator>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/optimization/Passes.h"
#include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc

#define DEBUG_TYPE "stablehlo-ext-canonicalize-from-hlo-import"

namespace mlir {
namespace stablehlo_ext {

#define GEN_PASS_DEF_STABLEHLOFLATTENTUPLEPASS
#define GEN_PASS_DEF_STABLEHLOCANONICALIZEFROMHLOIMPORTPASS
#include "stablehlo_ext/transforms/passes.h.inc"

namespace {

/////////////
// Flatten Tuples in entry computation

// Flattens tuple-typed arguments and results of `func` and updates the
// function signature accordingly:
//   - Each tuple block argument is replaced by one argument per tuple element;
//     a stablehlo.tuple is rebuilt at the top of the entry block and all uses
//     of the original tuple argument are rewired to it.
//   - Each tuple operand of the terminating func.return is replaced by
//     per-element stablehlo.get_tuple_element values.
// For example, on the input side:
//   func_1(%arg0: tuple<input1, input2>) =>
//   func_1(%arg0: <input1>, %arg1: <input2>) {
//     %0 = stablehlo.tuple(%arg0, %arg1)
//   }
// Only one level of tuple nesting is unwrapped per invocation; the caller is
// expected to re-run this until no tuple types remain in the signature.
// NOTE(review): assumes the last op of the body's last block is a
// func.return — the cast below asserts otherwise.
void expandTupledTensorInReturnOp(func::FuncOp func) {
  // Captured before mutation: supplies the MLIRContext for the rebuilt type.
  FunctionType oldFuncType = func.getFunctionType();
  // --- Update input signature. ---
  // Snapshot the current arguments: the loop below inserts and erases
  // arguments while iterating, so we must not iterate the live list.
  SmallVector<Type, 4> expandedInputTypes;
  SmallVector<BlockArgument, 20> funcArguments(func.getArguments().begin(),
                                               func.getArguments().end());
  for (auto argument : funcArguments) {
    auto type = argument.getType();
    auto tupleType = mlir::dyn_cast_or_null<TupleType>(type);
    if (!tupleType) {
      // Non-tuple arguments pass through unchanged.
      expandedInputTypes.push_back(type);
    } else {
      // For a tuple argument we:
      // 1) insert one new argument per element,
      // 2) materialize a stablehlo.tuple from those new arguments,
      // 3) rewire all uses of the old argument to the new tuple, then erase
      //    the old argument.
      // getArgNumber() is re-queried here (not cached across iterations)
      // because earlier insert/erase calls shift argument indices.
      int originalArgumentIndex = argument.getArgNumber();
      int argumentIndex = originalArgumentIndex;
      SmallVector<Value, 4> flattenedOperands;
      // Insert the flattened element arguments immediately after the
      // original tuple argument so relative argument order is preserved.
      Location loc = func.getBody().getLoc();
      for (auto flattenedType : tupleType.getTypes()) {
        expandedInputTypes.push_back(flattenedType);
        func.insertArgument(++argumentIndex, flattenedType, {}, loc);
        flattenedOperands.push_back(func.getArgument(argumentIndex));
      }

      // Rebuild the tuple value at the top of the entry block and rewire
      // every use of the original argument to it.
      OpBuilder builder(func.getBody());
      builder.setInsertionPointToStart(&func.getBody().front());
      auto newTuple =
          builder.create<stablehlo::TupleOp>(loc, tupleType, flattenedOperands);
      func.getArgument(originalArgumentIndex).replaceAllUsesWith(newTuple);

      // The original argument now has no uses, so it is safe to erase.
      func.eraseArgument(originalArgumentIndex);
    }
  }

  // --- Update output signature. ---
  auto returnOp = cast<mlir::func::ReturnOp>(func.getBody().back().back());
  OpBuilder builder(returnOp);

  // Expand each tuple return operand into its elements via
  // get_tuple_element; non-tuple operands are forwarded unchanged.
  SmallVector<Value, 4> expandedReturnOperands;
  SmallVector<Type, 4> expandedResultTypes;
  for (auto value : returnOp.getOperands()) {
    if (auto tupleTy = mlir::dyn_cast<TupleType>(value.getType())) {
      llvm::copy(tupleTy.getTypes(), std::back_inserter(expandedResultTypes));
      for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) {
        // createOrFold lets a get_tuple_element of a just-built tuple fold
        // directly to the underlying element value.
        expandedReturnOperands.push_back(
            builder.createOrFold<stablehlo::GetTupleElementOp>(
                value.getLoc(), ty, value, index));
      }
    } else {
      expandedReturnOperands.push_back(value);
      expandedResultTypes.push_back(value.getType());
    }
  }

  // Nothing changed (no tuple results): keep the existing return op.
  // Note: the input-side expansion above already took effect even on this
  // early-return path, but in that case the operand list is unchanged and
  // argument rewiring did not alter any return operand identities.
  if (returnOp.getOperands() == expandedReturnOperands) return;

  // Replace the return op and install the flattened function type.
  builder.create<mlir::func::ReturnOp>(returnOp.getLoc(),
                                       expandedReturnOperands);
  returnOp.erase();
  auto newFuncType = FunctionType::get(oldFuncType.getContext(),
                                       expandedInputTypes, expandedResultTypes);
  func.setType(newFuncType);
}

/////////////
// Flatten Tuples in Custom Calls

// Calculates the flatten types of a value.
void flattenTupleType(Value value, llvm::SmallVectorImpl<Type> &types) {
if (!mlir::isa<TupleType>(value.getType())) {
Expand Down Expand Up @@ -132,27 +219,46 @@ struct FlattenCustomCallOp : public OpRewritePattern<stablehlo::CustomCallOp> {
}
};

class StablehloFlattenTuplePass
: public impl::StablehloFlattenTuplePassBase<StablehloFlattenTuplePass> {
public:
// Simplify a model after HLO import.
struct StablehloCanonicalizeFromHloImportPass
: public impl::StablehloCanonicalizeFromHloImportPassBase<
StablehloCanonicalizeFromHloImportPass> {
using StablehloCanonicalizeFromHloImportPassBase::
StablehloCanonicalizeFromHloImportPassBase;

void runOnOperation() override {
// If entry function, flatten the input tuples
func::FuncOp func = getOperation();
if (func.getName() == entryFunctionNameOption.getValue()) {
// Recursively expand tuples until all of them are gone.
while (
llvm::any_of(llvm::concat<const Type>(func.getArgumentTypes(),
func.getResultTypes()),
[](Type type) { return mlir::isa<TupleType>(type); })) {
expandTupledTensorInReturnOp(func);
}
}

// Flatten tuples in function body
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<FlattenCustomCallOp>(context);
stablehlo::populateStablehloHloImportCanonicalizationPatterns(context,
&patterns);

// Apply patterns without folding
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.fold = false;
config.cseConstants = false;
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
config))) {
if (failed(applyPatternsGreedily(func, std::move(patterns), config)))
signalPassFailure();
}
}
};

} // namespace
} // end namespace

} // namespace stablehlo_ext
} // namespace mlir
Loading

0 comments on commit 9d3c847

Please sign in to comment.