Consolidate passes that operate on imported HLO
PiperOrigin-RevId: 728706709
GleasonK authored and Google-ML-Automation committed Feb 28, 2025
1 parent f0dddfc commit 7fdeab8
Showing 9 changed files with 373 additions and 299 deletions.
75 changes: 75 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -50,6 +50,43 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplific
}

// -----
@@ -1908,6 +1917,19 @@

// -----

+// CHECK-LABEL: @side_effecting_custom_call
+func.func @side_effecting_custom_call(%arg0: tensor<0xf32>) -> (tensor<0xf32>, tensor<0xf32>) {
+ // CHECK: %[[CST:.*]] = stablehlo.constant dense<> : tensor<0xf32>
+ // CHECK-NEXT: %[[CC:.*]] = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = true} : (tensor<0xf32>) -> tensor<0xf32>
+ %0 = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = true} : (tensor<0xf32>) -> tensor<0xf32>
+ // CHECK-NOT: stablehlo.custom_call{{.*}}has_side_effect = false
+ %1 = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = false} : (tensor<0xf32>) -> tensor<0xf32>
+ // CHECK: return %[[CC]], %[[CST]]
+ return %0, %1 : tensor<0xf32>, tensor<0xf32>
+}
+
+// -----
+
/////////
// Generic Shape Ops

diff --ruN a/stablehlo/stablehlo/transforms/optimization/Passes.h b/stablehlo/stablehlo/transforms/optimization/Passes.h
--- stablehlo/stablehlo/transforms/optimization/Passes.h
+++ stablehlo/stablehlo/transforms/optimization/Passes.h
@@ -50,6 +50,13 @@
MLIRContext *context,
bool foldFloat = false,
PatternBenefit benefit = 1);
+
+/// Some workloads in XLA import StableHLO from HLO. Since there are a few
+/// differences in HLO (no implicit captures, lots of tuples, etc.), this
+/// set of patterns brings the imported HLO back to a more canonical form
+/// without applying a full set of graph simplifications.
+void populateStablehloHloImportCanonicalizationPatterns(
+ MLIRContext *context, RewritePatternSet *patterns);
} // namespace stablehlo
} // namespace mlir

diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
@@ -68,4 +105,42 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
return rewriter.notifyMatchFailure(op, "operand is not empty tensor");

if (resultTy.hasStaticShape()) {
@@ -1399,6 +1403,12 @@
return rewriter.notifyMatchFailure(op, "not stablehlo");
if (isa<ConstantOp>(op))
return rewriter.notifyMatchFailure(op, "op is empty constant");
+
+ // Skip ops that have memory effects, similar to XLA's zero-extent
+ // simplification; replacing these doesn't save any computation.
+ auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (effectInterface && !effectInterface.hasNoEffect())
+ return rewriter.notifyMatchFailure(op, "op has memory effect");

// If the result is a zero-extent tensor, replace the whole op with an empty
// constant.
@@ -1528,6 +1538,12 @@
DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
}

+void populateStablehloHloImportCanonicalizationPatterns(
+ MLIRContext *context, RewritePatternSet *patterns) {
+ patterns->add<TupleIsRepacking, TupleIsUnpacked, WhileOpImplicitCapture>(
+ context);
+}
+
std::unique_ptr<Pass> createStablehloAggressiveSimplificationPass(
GreedyRewriteConfig config) {
return std::make_unique<StablehloAggressiveSimplificationPass>(config);
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
@@ -411,7 +411,7 @@
// GetTupleElementOp

// Pattern: get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i
-def : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx),
+def TupleIsUnpacked : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx),
(GetOperandN $tuple, $idx)>;

////////

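For illustration (a hypothetical sketch, not part of this commit; function and value names are invented), the newly named TupleIsUnpacked pattern folds a get_tuple_element of a freshly built tuple straight to the matching operand:

// Before: the tuple is built only to be unpacked again.
func.func @unpack(%arg0: tensor<f32>, %arg1: tensor<i32>) -> tensor<f32> {
  %t = stablehlo.tuple %arg0, %arg1 : tuple<tensor<f32>, tensor<i32>>
  %e = stablehlo.get_tuple_element %t[0] : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
  return %e : tensor<f32>
}

// After TupleIsUnpacked (the unused tuple is then erased as trivially dead):
func.func @unpack(%arg0: tensor<f32>, %arg1: tensor<i32>) -> tensor<f32> {
  return %arg0 : tensor<f32>
}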
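Likewise, a hedged sketch (also not from the commit) of the WhileOpImplicitCapture pattern registered alongside it: a constant that imported HLO threads through the loop as an explicit iteration argument becomes an ordinary implicit capture:

// Before: %c is carried around the loop even though it never changes.
func.func @while_capture(%arg0: tensor<i64>) -> tensor<i64> {
  %c = stablehlo.constant dense<1> : tensor<i64>
  %0:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %c) : tensor<i64>, tensor<i64>
   cond {
    %1 = stablehlo.compare LT, %iterArg, %iterArg_0 : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %1 : tensor<i1>
  } do {
    %1 = stablehlo.add %iterArg, %iterArg_0 : tensor<i64>
    stablehlo.return %1, %iterArg_0 : tensor<i64>, tensor<i64>
  }
  return %0#0 : tensor<i64>
}

// After WhileOpImplicitCapture: the loop-invariant constant is captured
// implicitly and the extra iteration argument disappears.
func.func @while_capture(%arg0: tensor<i64>) -> tensor<i64> {
  %c = stablehlo.constant dense<1> : tensor<i64>
  %0 = stablehlo.while(%iterArg = %arg0) : tensor<i64>
   cond {
    %1 = stablehlo.compare LT, %iterArg, %c : (tensor<i64>, tensor<i64>) -> tensor<i1>
    stablehlo.return %1 : tensor<i1>
  } do {
    %1 = stablehlo.add %iterArg, %c : tensor<i64>
    stablehlo.return %1 : tensor<i64>
  }
  return %0 : tensor<i64>
}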
3 changes: 1 addition & 2 deletions xla/mlir_hlo/BUILD
@@ -1087,8 +1087,7 @@ cc_library(
"stablehlo_ext/transforms/sdy_refine_shapes.cpp",
"stablehlo_ext/transforms/stablehlo_add_quant_dequant_conv.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_entry_function_tuples.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_from_hlo_import.cpp",
"stablehlo_ext/transforms/stablehlo_legalize_quant_composite.cpp",
"stablehlo_ext/transforms/stablehlo_prepare_for_hlo_export.cpp",
"stablehlo_ext/transforms/stablehlo_refine_shapes.cpp",
28 changes: 20 additions & 8 deletions xla/mlir_hlo/stablehlo_ext/transforms/passes.td
@@ -52,18 +52,30 @@ def StablehloPrepareForHloExportPass : Pass<"stablehlo-ext-prepare-for-hlo-expor
}];
}

def StablehloFlattenTuplePass : Pass<"stablehlo-ext-flatten-tuple", "func::FuncOp"> {
let summary = "Flatten tuples in operands and results of operators that "
"support both tuple and variadic type.";
}
def StablehloCanonicalizeFromHloImportPass : Pass<"stablehlo-ext-canonicalize-from-hlo-import", "mlir::func::FuncOp"> {
let summary = "Simplify StableHLO imported from HLO";

let dependentDialects = ["stablehlo::StablehloDialect"];

let description = [{
This pass simplifies StableHLO imported from HLO. It applies a subset of
the graph simplification patterns and is intended to bring the imported HLO
back to a more canonical form without applying the full set of graph
simplifications.

Namely, this pass:
* Simplifies tuples, undoing `tuple(get_tuple_element)` and
`get_tuple_element(tuple)`.
* Converts WhileOp explicitly captured constants to implicit captures.
* Flattens tuples in operands and results of operators that support both
tuple and variadic type.
* Flattens tuples in the entry function of the module.
}];

def StablehloFlattenEntryFunctionTuplesPass : Pass<"stablehlo-ext-expand-flatten-entry-function-tuples", "ModuleOp"> {
let summary = "Flatten HLO tuple for the entry function of the module.";
let options = [
Option<"entryFunctionNameOption", "entry-function", "std::string",
/*default=*/"", "the name of entry function of the module">,
/*default=*/[{"main"}], "the name of entry function of the module">,
];
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
}

def StablehloLegalizeQuantCompositePass : Pass<"stablehlo-ext-legalize-quant-composite", "ModuleOp"> {
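A hedged end-to-end sketch of the behavior summarized above (module invented for illustration, not part of the commit): on a function named by entry-function, the pass flattens tupled signatures and folds away the tuple round-trips in the body:

// Before: an imported entry function with a tupled signature.
func.func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tuple<tensor<f32>> {
  %0 = stablehlo.get_tuple_element %arg0[0] : (tuple<tensor<f32>, tensor<i32>>) -> tensor<f32>
  %1 = stablehlo.tuple %0 : tuple<tensor<f32>>
  return %1 : tuple<tensor<f32>>
}

// After stablehlo-ext-canonicalize-from-hlo-import: tuples are gone from
// both the signature and the body.
func.func @main(%arg0: tensor<f32>, %arg1: tensor<i32>) -> tensor<f32> {
  return %arg0 : tensor<f32>
}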
xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_from_hlo_import.cpp
@@ -13,41 +13,128 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for flattening tuples in HLO ops.
// This file implements logic for some optimizations to reduce size on export.

#include <cassert>
#include <memory>
#include <iterator>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/transforms/optimization/Passes.h"
#include "stablehlo_ext/transforms/passes.h" // NOLINT: Used in passes.h.inc

#define DEBUG_TYPE "stablehlo-ext-canonicalize-from-hlo-import"

namespace mlir {
namespace stablehlo_ext {

#define GEN_PASS_DEF_STABLEHLOFLATTENTUPLEPASS
#define GEN_PASS_DEF_STABLEHLOCANONICALIZEFROMHLOIMPORTPASS
#include "stablehlo_ext/transforms/passes.h.inc"

namespace {

/////////////
// Flatten Tuples in entry computation

// Expands the stablehlo.tuple used in the return op and updates the function
// signature accordingly.
void expandTupledTensorInReturnOp(func::FuncOp func) {
FunctionType oldFuncType = func.getFunctionType();
// Update input signatures.
// Tuples in the function inputs are flattened as well, so a tuple input is
// expanded and then repacked as follows:
// func_1(%arg0: tuple<input1, input2>) =>
//
// func_1(%arg0: <input1>, %arg1: <input2>) {
// %0 = stablehlo.tuple(%arg0, %arg1)
// }
SmallVector<Type, 4> expandedInputTypes;
SmallVector<BlockArgument, 20> funcArguments(func.getArguments().begin(),
func.getArguments().end());
for (auto argument : funcArguments) {
auto type = argument.getType();
auto tupleType = mlir::dyn_cast_or_null<TupleType>(type);
if (!tupleType) {
expandedInputTypes.push_back(type);
} else {
// We need to
// 1) expand the tuple
// 2) insert a new tuple
// 3) rewire the new tuple
int originalArgumentIndex = argument.getArgNumber();
int argumentIndex = originalArgumentIndex;
SmallVector<Value, 4> flattenedOperands;
// Insert the flattened arguments after the original tuple argument.
Location loc = func.getBody().getLoc();
for (auto flattenedType : tupleType.getTypes()) {
expandedInputTypes.push_back(flattenedType);
func.insertArgument(++argumentIndex, flattenedType, {}, loc);
flattenedOperands.push_back(func.getArgument(argumentIndex));
}

// Construct a new tuple and rewire it.
OpBuilder builder(func.getBody());
builder.setInsertionPointToStart(&func.getBody().front());
auto newTuple =
builder.create<stablehlo::TupleOp>(loc, tupleType, flattenedOperands);
func.getArgument(originalArgumentIndex).replaceAllUsesWith(newTuple);

// Now that the original argument has been rewired, it can safely be
// erased.
func.eraseArgument(originalArgumentIndex);
}
}

// Update output signatures.
auto returnOp = cast<mlir::func::ReturnOp>(func.getBody().back().back());
OpBuilder builder(returnOp);

// Expand all tuples in old return operands.
SmallVector<Value, 4> expandedReturnOperands;
SmallVector<Type, 4> expandedResultTypes;
for (auto value : returnOp.getOperands()) {
if (auto tupleTy = mlir::dyn_cast<TupleType>(value.getType())) {
llvm::copy(tupleTy.getTypes(), std::back_inserter(expandedResultTypes));
for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) {
expandedReturnOperands.push_back(
builder.createOrFold<stablehlo::GetTupleElementOp>(
value.getLoc(), ty, value, index));
}
} else {
expandedReturnOperands.push_back(value);
expandedResultTypes.push_back(value.getType());
}
}

if (returnOp.getOperands() == expandedReturnOperands) return;

builder.create<mlir::func::ReturnOp>(returnOp.getLoc(),
expandedReturnOperands);
returnOp.erase();
auto newFuncType = FunctionType::get(oldFuncType.getContext(),
expandedInputTypes, expandedResultTypes);
func.setType(newFuncType);
}

/////////////
// Flatten Tuples in Custom Calls

// Calculates the flattened types of a value.
void flattenTupleType(Value value, llvm::SmallVectorImpl<Type> &types) {
if (!mlir::isa<TupleType>(value.getType())) {
@@ -132,27 +219,46 @@ struct FlattenCustomCallOp : public OpRewritePattern<stablehlo::CustomCallOp> {
}
};

class StablehloFlattenTuplePass
: public impl::StablehloFlattenTuplePassBase<StablehloFlattenTuplePass> {
public:
// Simplify a model after HLO import.
struct StablehloCanonicalizeFromHloImportPass
: public impl::StablehloCanonicalizeFromHloImportPassBase<
StablehloCanonicalizeFromHloImportPass> {
using StablehloCanonicalizeFromHloImportPassBase::
StablehloCanonicalizeFromHloImportPassBase;

void runOnOperation() override {
// If this is the entry function, flatten its argument and result tuples
func::FuncOp func = getOperation();
if (func.getName() == entryFunctionNameOption.getValue()) {
// Recursively expand tuples until all of them are gone.
while (
llvm::any_of(llvm::concat<const Type>(func.getArgumentTypes(),
func.getResultTypes()),
[](Type type) { return mlir::isa<TupleType>(type); })) {
expandTupledTensorInReturnOp(func);
}
}

// Flatten tuples in the function body
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<FlattenCustomCallOp>(context);
stablehlo::populateStablehloHloImportCanonicalizationPatterns(context,
&patterns);

// Apply patterns without folding
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
config.fold = false;
config.cseConstants = false;
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
config))) {
if (failed(applyPatternsGreedily(func, std::move(patterns), config)))
signalPassFailure();
}
}
};

} // namespace
} // end namespace

} // namespace stablehlo_ext
} // namespace mlir
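One detail worth noting: each call to expandTupledTensorInReturnOp above peels a single tuple level, which is why the pass loops until no tuple types remain in the signature. Signatures only, as a hypothetical sketch:

// Iteration 1 peels the outer tuple; iteration 2 removes the inner one.
func.func @main(%arg0: tuple<tuple<tensor<f32>>, tensor<i32>>) -> tensor<i32>  // before
func.func @main(%arg0: tuple<tensor<f32>>, %arg1: tensor<i32>) -> tensor<i32>  // after one expansion
func.func @main(%arg0: tensor<f32>, %arg1: tensor<i32>) -> tensor<i32>         // fixed point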