Skip to content

Commit

Permalink
PR #23181: Annotate loops: start, step, induction variable.
Browse files Browse the repository at this point in the history
Imported from GitHub PR #23181

Currently, we annotate while loops that have a known trip count. This PR additionally records the start, step, and induction variable, making it easy to see when a while loop is actually a for loop.

For the larger context see
main...jreiffers:xla:memcpy and the companion document
https://docs.google.com/document/d/1E2_Jt_Dw4VbPXPVktNWhtsDEtIpN-kurCMGGyvMV4JA/edit?tab=t.0.
Copybara import of the project:

--
a5f8f4e by Johannes Reifferscheid <[email protected]>:

Annotate loops: start, step, induction variable.

Currently, we annotate while loops that have a known trip count.
This PR additionally records the start, step, and induction variable,
making it easy to see when a while loop is actually a for loop.

For the larger context see
main...jreiffers:xla:memcpy and
the companion document
https://docs.google.com/document/d/1E2_Jt_Dw4VbPXPVktNWhtsDEtIpN-kurCMGGyvMV4JA/edit?tab=t.0.

Merging this change closes #23181

FUTURE_COPYBARA_INTEGRATE_REVIEW=#23181 from jreiffers:while-annotator a5f8f4e
PiperOrigin-RevId: 731860761
  • Loading branch information
jreiffers authored and Google-ML-Automation committed Feb 28, 2025
1 parent 0c04616 commit 5a2abfc
Show file tree
Hide file tree
Showing 13 changed files with 484 additions and 313 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/cpu_benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
defaults:
run:
shell: bash
timeout-minutes: 360
timeout-minutes: 540
steps:
- name: Print machine specs
run: |
Expand Down Expand Up @@ -108,6 +108,7 @@ jobs:
cd tmp_hlo
wget https://storage.googleapis.com/xla-benchmarking-temp/gemma2_2b_keras_jax.hlo
cd ..
./bazel-bin/xla/tools/run_hlo_module --input_format=hlo --platform=CPU tmp_hlo/gemma2_2b_keras_jax.hlo
- name: Compute the cost of gemma2_2b_keras_jax.hlo
run: |
Expand Down
75 changes: 75 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,43 @@ diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_aggressive_simplific
}

// -----
@@ -1908,6 +1917,19 @@

// -----

+// CHECK-LABEL: @side_effecting_custom_call
+func.func @side_effecting_custom_call(%arg0: tensor<0xf32>) -> (tensor<0xf32>, tensor<0xf32>) {
+ // CHECK: %[[CST:.*]] = stablehlo.constant dense<> : tensor<0xf32>
+ // CHECK-NEXT: %[[CC:.*]] = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = true} : (tensor<0xf32>) -> tensor<0xf32>
+ %0 = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = true} : (tensor<0xf32>) -> tensor<0xf32>
+ // CHECK-NOT: stablehlo.custom_call{{.*}}has_side_effect = false
+ %1 = stablehlo.custom_call @foo(%arg0) {api_version = 0 : i32, has_side_effect = false} : (tensor<0xf32>) -> tensor<0xf32>
+ // CHECK: return %[[CC]], %[[CST]]
+ return %0, %1 : tensor<0xf32>, tensor<0xf32>
+}
+
+// -----
+
/////////
// Generic Shape Ops

diff --ruN a/stablehlo/stablehlo/transforms/optimization/Passes.h b/stablehlo/stablehlo/transforms/optimization/Passes.h
--- stablehlo/stablehlo/transforms/optimization/Passes.h
+++ stablehlo/stablehlo/transforms/optimization/Passes.h
@@ -50,6 +50,13 @@
MLIRContext *context,
bool foldFloat = false,
PatternBenefit benefit = 1);
+
+/// Some workloads in XLA import StableHLO from HLO. Since there are a few
+/// differences in HLO (no implicit captures, lots of tuples, etc.), this
+/// set of patterns brings the imported HLO back to a more canonical form
+/// without applying a full set of graph simplifications.
+void populateStablehloHloImportCanonicalizationPatterns(
+ MLIRContext *context, RewritePatternSet *patterns);
} // namespace stablehlo
} // namespace mlir

diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp
Expand All @@ -68,4 +105,42 @@ diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimp
return rewriter.notifyMatchFailure(op, "operand is not empty tensor");

if (resultTy.hasStaticShape()) {
@@ -1399,6 +1403,12 @@
return rewriter.notifyMatchFailure(op, "not stablehlo");
if (isa<ConstantOp>(op))
return rewriter.notifyMatchFailure(op, "op is empty constant");
+
+ // Skip ops that have memory effects, similar to XLA's zero extent
+ // simplification, replacing these doesn't save any computation.
+ auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+ if (effectInterface && !effectInterface.hasNoEffect())
+ return rewriter.notifyMatchFailure(op, "op has memory effect");

// If the result is a zero-extent tensor, replace the whole op with an empty
// constant.
@@ -1528,6 +1538,12 @@
DynamicReshapeOpIsStatic, DynamicIotaIsStatic>(context);
}

+void populateStablehloHloImportCanonicalizationPatterns(
+ MLIRContext *context, RewritePatternSet *patterns) {
+ patterns->add<TupleIsRepacking, TupleIsUnpacked, WhileOpImplicitCapture>(
+ context);
+}
+
std::unique_ptr<Pass> createStablehloAggressiveSimplificationPass(
GreedyRewriteConfig config) {
return std::make_unique<StablehloAggressiveSimplificationPass>(config);
diff --ruN a/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
--- stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
+++ stablehlo/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td
@@ -411,7 +411,7 @@
// GetTupleElementOp

// Pattern: get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i
-def : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx),
+def TupleIsUnpacked : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx),
(GetOperandN $tuple, $idx)>;

////////

29 changes: 27 additions & 2 deletions xla/hlo/transforms/while_loop_trip_count_annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,34 @@ absl::StatusOr<bool> WhileLoopTripCountAnnotator::Run(
if (instr->opcode() != HloOpcode::kWhile) {
continue;
}
if (auto trip_count = ComputeWhileLoopTripCount(instr)) {

if (auto induction_variable_index = GetLoopInductionVarTupleIdx(instr)) {
// The following analyses all need the induction variable index.
WhileLoopBackendConfig config;
config.mutable_known_trip_count()->set_n(*trip_count);

config.mutable_known_induction_variable()->set_tuple_index(
*induction_variable_index);
if (auto range = MatchTrivialLoopRange(instr);
range.has_value() && range->IsBounded() && range->IsStepKnown() &&
// We store the values in signed integers, so we need to verify
// they fit.
range->max()->GetSignedValue() >= 0 &&
range->min().GetSignedValue() >= 0 &&
range->step()->GetSignedValue() > 0) {
int64_t max = range->max()->GetUnsignedValue();
int64_t min = range->min().GetUnsignedValue();
int64_t step = range->step()->GetSignedValue();
int64_t trip_count = (max - min) / step + 1;

config.mutable_known_trip_count()->set_n(trip_count);
config.mutable_known_init_step()->set_init(min);
config.mutable_known_init_step()->set_step(step);
} else if (auto trip_count = ComputeWhileLoopTripCount(instr)) {
// If this is not a trivial loop, it might still be possible to brute
// force the trip count.
config.mutable_known_trip_count()->set_n(*trip_count);
}

TF_RETURN_IF_ERROR(instr->set_backend_config(config));
changed = true;
}
Expand Down
75 changes: 64 additions & 11 deletions xla/hlo/transforms/while_loop_trip_count_annotator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) {
m->entry_computation()
->root_instruction()
->backend_config<WhileLoopBackendConfig>());
EXPECT_EQ(10, config.known_trip_count().n());
EXPECT_TRUE(config.has_known_induction_variable());
EXPECT_TRUE(config.has_known_init_step());
EXPECT_EQ(config.known_trip_count().n(), 10);
EXPECT_EQ(config.known_induction_variable().tuple_index(), 0);
EXPECT_EQ(config.known_init_step().init(), 0);
EXPECT_EQ(config.known_init_step().step(), 1);
}

TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) {
Expand Down Expand Up @@ -95,25 +100,25 @@ TEST_F(TripCountAnnotatorTest, KnownLargeTripCount) {
m->entry_computation()
->root_instruction()
->backend_config<WhileLoopBackendConfig>());
EXPECT_EQ(1000000, config.known_trip_count().n());
EXPECT_EQ(config.known_trip_count().n(), 1000000);
}

TEST_F(TripCountAnnotatorTest, NonzeroStart) {
TEST_F(TripCountAnnotatorTest, NonzeroStartStep) {
const char* kModuleStr = R"(
HloModule test
Body {
param = (s32[]) parameter(0)
i = s32[] get-tuple-element(param), index=0
one = s32[] constant(1)
i_plus_one = s32[] add(i, one)
ROOT tuple = (s32[]) tuple(i_plus_one)
two = s32[] constant(2)
i_plus_two = s32[] add(i, two)
ROOT tuple = (s32[]) tuple(i_plus_two)
}
Cond {
param = (s32[]) parameter(0)
i = s32[] get-tuple-element(param), index=0
trip_count = s32[] constant(1000000)
ROOT done = pred[] compare(i, trip_count), direction=LT
max_i = s32[] constant(1000000)
ROOT done = pred[] compare(i, max_i), direction=LT
}
ENTRY test {
Expand All @@ -131,7 +136,10 @@ TEST_F(TripCountAnnotatorTest, NonzeroStart) {
m->entry_computation()
->root_instruction()
->backend_config<WhileLoopBackendConfig>());
EXPECT_EQ(999990, config.known_trip_count().n());
EXPECT_EQ(config.known_trip_count().n(), 499995);
EXPECT_TRUE(config.has_known_init_step());
EXPECT_EQ(config.known_init_step().init(), 10);
EXPECT_EQ(config.known_init_step().step(), 2);
}

TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) {
Expand Down Expand Up @@ -167,7 +175,7 @@ TEST_F(TripCountAnnotatorTest, LessThanOrEqualTo) {
m->entry_computation()
->root_instruction()
->backend_config<WhileLoopBackendConfig>());
EXPECT_EQ(999991, config.known_trip_count().n());
EXPECT_EQ(config.known_trip_count().n(), 999991);
}

TEST_F(TripCountAnnotatorTest, Int64Overflow) {
Expand Down Expand Up @@ -200,7 +208,52 @@ TEST_F(TripCountAnnotatorTest, Int64Overflow) {
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
WhileLoopTripCountAnnotator pass;
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get()));
EXPECT_FALSE(changed);
EXPECT_TRUE(changed);

TF_ASSERT_OK_AND_ASSIGN(auto config,
m->entry_computation()
->root_instruction()
->backend_config<WhileLoopBackendConfig>());
EXPECT_FALSE(config.has_known_trip_count());
EXPECT_FALSE(config.has_known_init_step());
EXPECT_TRUE(config.has_known_induction_variable());
EXPECT_EQ(config.known_induction_variable().tuple_index(), 0);
}

TEST_F(TripCountAnnotatorTest, NonZeroTupleIndex) {
const char* kModuleStr = R"(
HloModule test
Body {
param = (s32[], s32[]) parameter(0)
i = s32[] get-tuple-element(param), index=1
one = s32[] constant(1)
i_plus_one = s32[] add(i, one)
ROOT tuple = (s32[], s32[]) tuple(one, i_plus_one)
}
Cond {
param = (s32[], s32[]) parameter(0)
i = s32[] get-tuple-element(param), index=1
trip_count = s32[] constant(10)
ROOT done = pred[] compare(i, trip_count), direction=LT
}
ENTRY test {
i_start = s32[] constant(0)
initial_tuple = (s32[], s32[]) tuple(i_start, i_start)
ROOT while = (s32[], s32[]) while(initial_tuple), condition=Cond, body=Body
})";

TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
WhileLoopTripCountAnnotator pass;
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, m.get()));
ASSERT_TRUE(changed);

TF_ASSERT_OK_AND_ASSIGN(auto config,
m->entry_computation()
->root_instruction()
->backend_config<WhileLoopBackendConfig>());
EXPECT_EQ(config.known_induction_variable().tuple_index(), 1);
}

} // namespace
Expand Down
3 changes: 1 addition & 2 deletions xla/mlir_hlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1087,8 +1087,7 @@ cc_library(
"stablehlo_ext/transforms/sdy_refine_shapes.cpp",
"stablehlo_ext/transforms/stablehlo_add_quant_dequant_conv.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_entry_function_tuples.cpp",
"stablehlo_ext/transforms/stablehlo_flatten_tuple.cpp",
"stablehlo_ext/transforms/stablehlo_canonicalize_from_hlo_import.cpp",
"stablehlo_ext/transforms/stablehlo_legalize_quant_composite.cpp",
"stablehlo_ext/transforms/stablehlo_prepare_for_hlo_export.cpp",
"stablehlo_ext/transforms/stablehlo_refine_shapes.cpp",
Expand Down
28 changes: 20 additions & 8 deletions xla/mlir_hlo/stablehlo_ext/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,30 @@ def StablehloPrepareForHloExportPass : Pass<"stablehlo-ext-prepare-for-hlo-expor
}];
}

def StablehloFlattenTuplePass : Pass<"stablehlo-ext-flatten-tuple", "func::FuncOp"> {
let summary = "Flatten tuples in operands and results of operators that "
"support both tuple and variadic type.";
}
def StablehloCanonicalizeFromHloImportPass : Pass<"stablehlo-ext-canonicalize-from-hlo-import", "mlir::func::FuncOp"> {
let summary = "Simplify StableHLO imported from HLO";

let dependentDialects = ["stablehlo::StablehloDialect"];

let description = [{
This pass simplifies StableHLO imported from HLO. This pass is a subset of
the graph simplification passes and is intended to bring the imported HLO
back to a more canonical form without applying a full set of graph
simplifications.

Namely, this pass:
* Simplifies tuples, undoing `tuple(get_tuple_element)` and
`get_tuple_element(tuple)`.
* Converts WhileOp explicit captured constants to implicit captures.
* Flattens tuples in operands and results of operators that support both
tuple and variadic type.
* Flattens tuples in entry function of the module.
}];

def StablehloFlattenEntryFunctionTuplesPass : Pass<"stablehlo-ext-expand-flatten-entry-function-tuples", "ModuleOp"> {
let summary = "Flatten HLO tuple for the entry function of the module.";
let options = [
Option<"entryFunctionNameOption", "entry-function", "std::string",
/*default=*/"", "the name of entry function of the module">,
/*default=*/[{"main"}], "the name of entry function of the module">,
];
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
}

def StablehloLegalizeQuantCompositePass : Pass<"stablehlo-ext-legalize-quant-composite", "ModuleOp"> {
Expand Down
Loading

0 comments on commit 5a2abfc

Please sign in to comment.