Skip to content

Commit

Permalink
PR #22800: Change the default value of print_operand_shape_ to false …
Browse files Browse the repository at this point in the history
…and print_large_constants_ to true.

Imported from GitHub PR #22800

Operand shape in long hlo text adds redundant information, which shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options, and it is required for parsability and reproducibility. Turning this on by default. This is still controlled by debug option and the default value of that flag disables the large constants, and that behavior is not changed. Just the default print options change here.

Copybara import of the project:

--
e30dea2 by Shraiysh Vaishay <[email protected]>:

Change the default value of print_operand_shape_ to false and print_large_constants_ to true.

Operand shape in long hlo text adds redundant information, which
shouldn't be required. Changing the default value to off.

The large constants were also printed earlier by default print options,
and it is required for parsability and reproducibility. Turning this on by default.
This is still controlled by debug option and the default value of that
flag disables the large constants, and that behavior is not changed. Just the
default print options change here.

--
7008af0 by Shraiysh Vaishay <[email protected]>:

Handle tests

--
b22d5f9 by Shraiysh Vaishay <[email protected]>:

Fix more tests

--
d51579c by Shraiysh Vaishay <[email protected]>:

Fix more tests

Merging this change closes #22800

FUTURE_COPYBARA_INTEGRATE_REVIEW=#22800 from shraiysh:change_default_print_op_shape d51579c
PiperOrigin-RevId: 730428975
  • Loading branch information
shraiysh authored and Google-ML-Automation committed Mar 1, 2025
1 parent 92b05a5 commit 77fd8e5
Show file tree
Hide file tree
Showing 57 changed files with 1,463 additions and 1,419 deletions.
6 changes: 5 additions & 1 deletion xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,11 @@ TEST_F(AlgorithmTest, Algorithm_BF16_BF16_F32_on_BF16_input_for_multiply) {
CHECK: %[[reduce:.*]] = f32[256]{0} reduce(
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
TF_ASSERT_OK_AND_ASSIGN(
auto ok,
RunFileCheck(
module->ToString(HloPrintOptions().set_print_operand_shape(true)),
pattern));
ASSERT_TRUE(ok);
EXPECT_TRUE(RunAndCompareNoHloPasses(
std::move(module), ErrorSpec{/*aabs=*/1e-7, /*arel=*/1e-7}));
Expand Down
40 changes: 20 additions & 20 deletions xla/codegen/emitters/computation_partitioner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,24 @@ TEST_F(ComputationPartitionerTest, PartitionDiamonds) {

constexpr auto kExpected = R"(PartitionedComputation fused_computation:
SUBGRAPH fused_computation_add3 {
%slice3.1 = f32[2]{0} slice(f32[3]{0} %add2), slice={[0:2]}
%slice3.2 = f32[2]{0} slice(f32[3]{0} %add2), slice={[1:3]}
ROOT %add3 = f32[2]{0} add(f32[2]{0} %slice3.1, f32[2]{0} %slice3.2)
%slice3.1 = f32[2]{0} slice(%add2), slice={[0:2]}
%slice3.2 = f32[2]{0} slice(%add2), slice={[1:3]}
ROOT %add3 = f32[2]{0} add(%slice3.1, %slice3.2)
}
SUBGRAPH fused_computation_add2 {
%slice2.1 = f32[3]{0} slice(f32[4]{0} %add1), slice={[0:3]}
%slice2.2 = f32[3]{0} slice(f32[4]{0} %add1), slice={[1:4]}
ROOT %add2 = f32[3]{0} add(f32[3]{0} %slice2.1, f32[3]{0} %slice2.2)
%slice2.1 = f32[3]{0} slice(%add1), slice={[0:3]}
%slice2.2 = f32[3]{0} slice(%add1), slice={[1:4]}
ROOT %add2 = f32[3]{0} add(%slice2.1, %slice2.2)
}
SUBGRAPH fused_computation_add1 {
%slice1.1 = f32[4]{0} slice(f32[5]{0} %add0), slice={[0:4]}
%slice1.2 = f32[4]{0} slice(f32[5]{0} %add0), slice={[1:5]}
ROOT %add1 = f32[4]{0} add(f32[4]{0} %slice1.1, f32[4]{0} %slice1.2)
%slice1.1 = f32[4]{0} slice(%add0), slice={[0:4]}
%slice1.2 = f32[4]{0} slice(%add0), slice={[1:5]}
ROOT %add1 = f32[4]{0} add(%slice1.1, %slice1.2)
}
SUBGRAPH fused_computation_add0 {
%slice0.1 = f32[5]{0} slice(f32[6]{0} %param), slice={[0:5]}
%slice0.2 = f32[5]{0} slice(f32[6]{0} %param), slice={[1:6]}
ROOT %add0 = f32[5]{0} add(f32[5]{0} %slice0.1, f32[5]{0} %slice0.2)
%slice0.1 = f32[5]{0} slice(%param), slice={[0:5]}
%slice0.2 = f32[5]{0} slice(%param), slice={[1:6]}
ROOT %add0 = f32[5]{0} add(%slice0.1, %slice0.2)
}
SUBGRAPH fused_computation_param {
ROOT %param = f32[6]{0} parameter(0)
Expand Down Expand Up @@ -146,15 +146,15 @@ TEST_F(ComputationPartitionerTest, DiamondConcatenate) {

constexpr auto kExpected = R"(PartitionedComputation fused_computation:
SUBGRAPH fused_computation_concat {
%neg = f32[6]{0} negate(f32[6]{0} %log)
%neg = f32[6]{0} negate(%log)
%param2 = f32[6]{0} parameter(1)
%add = f32[6]{0} add(f32[6]{0} %log, f32[6]{0} %param2)
%exp = f32[6]{0} exponential(f32[6]{0} %add)
ROOT %concat = f32[12]{0} concatenate(f32[6]{0} %neg, f32[6]{0} %exp), dimensions={0}
%add = f32[6]{0} add(%log, %param2)
%exp = f32[6]{0} exponential(%add)
ROOT %concat = f32[12]{0} concatenate(%neg, %exp), dimensions={0}
}
SUBGRAPH fused_computation_log {
%param1 = f32[6]{0} parameter(0)
ROOT %log = f32[6]{0} log(f32[6]{0} %param1)
ROOT %log = f32[6]{0} log(%param1)
})";
EXPECT_EQ(computation.ToString(6), kExpected);
}
Expand All @@ -178,9 +178,9 @@ TEST_F(ComputationPartitionerTest, TupleRoot) {
SUBGRAPH fused_computation_root {
%p0 = f32[6]{0} parameter(0)
%p1 = f32[6]{0} parameter(1)
%add = f32[6]{0} add(f32[6]{0} %p0, f32[6]{0} %p1)
%sub = f32[6]{0} subtract(f32[6]{0} %p0, f32[6]{0} %p1)
ROOT %root = (f32[6]{0}, f32[6]{0}) tuple(f32[6]{0} %add, f32[6]{0} %sub)
%add = f32[6]{0} add(%p0, %p1)
%sub = f32[6]{0} subtract(%p0, %p1)
ROOT %root = (f32[6]{0}, f32[6]{0}) tuple(%add, %sub)
})";
EXPECT_EQ(computation.ToString(6), kExpected);
}
Expand Down
2 changes: 1 addition & 1 deletion xla/codegen/testlib/kernel_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_from_instruction(self):
"HloModule sine_module,",
"{",
"%input = s32[4]{0} parameter(0)",
"ROOT %sine = s32[4]{0} sine(s32[4]{0} %input)",
"ROOT %sine = s32[4]{0} sine(%input)",
"}",
]
self.assertContainsInOrder(
Expand Down
4 changes: 2 additions & 2 deletions xla/hlo/ir/hlo_print_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class HloPrintOptions {
: print_operand_index_annotation_interval_(5),
print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
indent_amount_(0),
print_large_constants_(false),
print_large_constants_(true),
print_only_essential_constants_(false),
print_original_value_(true),
print_metadata_(true),
Expand All @@ -62,7 +62,7 @@ class HloPrintOptions {
compact_operands_(false),
include_layout_in_shapes_(true),
print_result_shape_(true),
print_operand_shape_(true),
print_operand_shape_(false),
print_operand_names_(true),
print_program_shape_(true),
print_percent_(true),
Expand Down
30 changes: 23 additions & 7 deletions xla/hlo/parser/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_print_options.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/parser/hlo_lexer.h"
#include "xla/hlo/testlib/pattern_matcher_gmock.h"
Expand Down Expand Up @@ -2809,9 +2810,12 @@ class HloParameterizedParserTest
: public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
// Expects "ToString(ParseHloModule(std::string)) == string", that is, parses
// the string, asserts that it succeeded, stringifies the parsed module, and
// checks that it equals the original string.
// Expects "ToString(ParseHloModule(ToString(ParseHloModule(std::string)))) ==
// string", that is, parses the string, asserts that it succeeded, stringifies
// the parsed module, parses this string to ensure that the default ToString()
// version is parsable, then stringifies the newly parsed module with
// appropriate options for original tests, and checks that it equals the
// original string.
void ExpectEqual() {
VLOG(3) << "Running HloParameterizedParserTest with short_form = "
<< short_form << ", proto_round_trip = " << proto_round_trip;
Expand All @@ -2827,9 +2831,20 @@ class HloParameterizedParserTest
ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
module = std::move(verified_module);
verified_module = std::make_unique<VerifiedHloModule>(
GetParam().test_name, config,
/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(
module->ToString(HloPrintOptions().set_print_operand_shape(true))));
module = std::move(verified_module);
} else {
TF_ASSERT_OK_AND_ASSIGN(module,
ParseAndReturnUnverifiedModule(original, config));
TF_ASSERT_OK_AND_ASSIGN(
module,
ParseAndReturnUnverifiedModule(module->ToString(), module->config()));
}
if (proto_round_trip) {
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
Expand All @@ -2838,9 +2853,8 @@ class HloParameterizedParserTest
if (short_form) {
EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
} else {
EXPECT_EQ(
original,
module->ToString(HloPrintOptions().set_print_large_constants(true)));
EXPECT_EQ(original, module->ToString(
HloPrintOptions().set_print_operand_shape(true)));
}
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instr : computation->instructions()) {
Expand Down Expand Up @@ -3276,7 +3290,9 @@ ENTRY %ShortConstant.v4 () -> f32[67,89] {
)";
auto result = ParseAndReturnVerifiedModule(original);
TF_EXPECT_OK(result.status());
EXPECT_EQ(result.value()->ToString(HloPrintOptions()), original);
EXPECT_EQ(result.value()->ToString(
HloPrintOptions().set_print_large_constants(false)),
original);
}

TEST_F(HloParserTest, NegativeNan) {
Expand Down
40 changes: 20 additions & 20 deletions xla/hlo/tools/tests/generate_hlo_test_checks_test_output.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ ENTRY no_op {
// CHECK-LABEL: ENTRY %test_case
// CHECK-NEXT: %[[foo:[^ ]+]] = u8[] parameter(0)
// CHECK-NEXT: %[[bar:[^ ]+]] = u8[] parameter(1)
// CHECK-NEXT: %[[foobar:[^ ]+]] = u8[] add(u8[] %[[foo]], u8[] %[[bar]])
// CHECK-NEXT: %[[foobar:[^ ]+]] = u8[] add(%[[foo]], %[[bar]])
// CHECK-NEXT: %[[baz:[^ ]+]] = u8[] parameter(2)
// CHECK-NEXT: %[[qux:[^ ]+]] = u8[] parameter(3)
// CHECK-NEXT: %[[bazqux:[^ ]+]] = u8[] add(u8[] %[[baz]], u8[] %[[qux]])
// CHECK-NEXT: ROOT %[[foobarbazqux:[^ ]+]] = u8[] add(u8[] %[[foobar]], u8[] %[[bazqux]])
// CHECK-NEXT: %[[bazqux:[^ ]+]] = u8[] add(%[[baz]], %[[qux]])
// CHECK-NEXT: ROOT %[[foobarbazqux:[^ ]+]] = u8[] add(%[[foobar]], %[[bazqux]])

HloModule TestBasicFunctionality

Expand All @@ -51,18 +51,18 @@ ENTRY test_case {

// CHECK-LABEL: ENTRY %test_case
// CHECK-NEXT: %[[constant:[^ ]+]] = f32[] constant(1)
// CHECK-NEXT: %[[broadcast:[^ ]+]] = f32[8,7]{1,0} broadcast(f32[] %[[constant]]), dimensions={}
// CHECK-NEXT: %[[broadcast:[^ ]+]] = f32[8,7]{1,0} broadcast(%[[constant]]), dimensions={}
// CHECK-NEXT: %[[parameter_anon:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} parameter(0)
// CHECK-NEXT: %[[reshape0:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[parameter_anon]])
// CHECK-NEXT: %[[reshape:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} reshape(f32[8,7]{1,0} %[[reshape0]])
// CHECK-NEXT: %[[negate_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} negate(f32[1,8,1,7]{3,2,1,0} %[[reshape]])
// CHECK-NEXT: %[[reshape_1:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[negate_1]])
// CHECK-NEXT: %[[exponential:[^ ]+]] = f32[8,7]{1,0} exponential(f32[8,7]{1,0} %[[reshape_1]])
// CHECK-NEXT: %[[add_1:[^ ]+]] = f32[8,7]{1,0} add(f32[8,7]{1,0} %[[broadcast]], f32[8,7]{1,0} %[[exponential]])
// CHECK-NEXT: %[[divide:[^ ]+]] = f32[8,7]{1,0} divide(f32[8,7]{1,0} %[[broadcast]], f32[8,7]{1,0} %[[add_1]])
// CHECK-NEXT: %[[reshape0:[^ ]+]] = f32[8,7]{1,0} reshape(%[[parameter_anon]])
// CHECK-NEXT: %[[reshape:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} reshape(%[[reshape0]])
// CHECK-NEXT: %[[negate_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} negate(%[[reshape]])
// CHECK-NEXT: %[[reshape_1:[^ ]+]] = f32[8,7]{1,0} reshape(%[[negate_1]])
// CHECK-NEXT: %[[exponential:[^ ]+]] = f32[8,7]{1,0} exponential(%[[reshape_1]])
// CHECK-NEXT: %[[add_1:[^ ]+]] = f32[8,7]{1,0} add(%[[broadcast]], %[[exponential]])
// CHECK-NEXT: %[[divide:[^ ]+]] = f32[8,7]{1,0} divide(%[[broadcast]], %[[add_1]])
// CHECK-NEXT: %[[parameter_anon_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} parameter(1)
// CHECK-NEXT: %[[reshape1:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[parameter_anon_1]])
// CHECK-NEXT: ROOT %[[add:[^ ]+]] = f32[8,7]{1,0} add(f32[8,7]{1,0} %[[divide]], f32[8,7]{1,0} %[[reshape1]])
// CHECK-NEXT: %[[reshape1:[^ ]+]] = f32[8,7]{1,0} reshape(%[[parameter_anon_1]])
// CHECK-NEXT: ROOT %[[add:[^ ]+]] = f32[8,7]{1,0} add(%[[divide]], %[[reshape1]])

HloModule TestWithRelevantOptimizationPasses

Expand Down Expand Up @@ -93,21 +93,21 @@ ENTRY no_op {
// CHECK: %[[$foo_bar_baz_0:[^ ]+]]
// CHECK-NEXT: %[[foo_bar_baz_0_1:[^ ]+]] = f32[] parameter(0)
// CHECK-NEXT: %[[foo_bar_baz_0_2:[^ ]+]] = f32[] parameter(1)
// CHECK-NEXT: %[[foo_bar_baz_0_3:[^ ]+]] = f32[] multiply(f32[] %[[foo_bar_baz_0_1]], f32[] %[[foo_bar_baz_0_2]])
// CHECK-NEXT: ROOT %[[foobarbaz:[^ ]+]] = f32[] add(f32[] %[[foo_bar_baz_0_1]], f32[] %[[foo_bar_baz_0_3]])
// CHECK-NEXT: %[[foo_bar_baz_0_3:[^ ]+]] = f32[] multiply(%[[foo_bar_baz_0_1]], %[[foo_bar_baz_0_2]])
// CHECK-NEXT: ROOT %[[foobarbaz:[^ ]+]] = f32[] add(%[[foo_bar_baz_0_1]], %[[foo_bar_baz_0_3]])

// CHECK-LABEL: %foo.bar.baz_0
// CHECK-NEXT: %[[foo_bar_baz_0_4:[^ ]+]] = f32[] parameter(0)
// CHECK-NEXT: %[[foo_bar_baz_0_5:[^ ]+]] = f32[] parameter(1)
// CHECK-NEXT: %[[foo_bar_baz_0_6:[^ ]+]] = f32[] add(f32[] %[[foo_bar_baz_0_4]], f32[] %[[foo_bar_baz_0_5]])
// CHECK-NEXT: ROOT %[[foobarbaz_1:[^ ]+]] = f32[] multiply(f32[] %[[foo_bar_baz_0_4]], f32[] %[[foo_bar_baz_0_6]])
// CHECK-NEXT: %[[foo_bar_baz_0_6:[^ ]+]] = f32[] add(%[[foo_bar_baz_0_4]], %[[foo_bar_baz_0_5]])
// CHECK-NEXT: ROOT %[[foobarbaz_1:[^ ]+]] = f32[] multiply(%[[foo_bar_baz_0_4]], %[[foo_bar_baz_0_6]])

// CHECK-LABEL: ENTRY %foobarbaz
// CHECK-NEXT: %[[constant_0:[^ ]+]] = f32[4]{0} constant({8.7, 6.5, 4.3, 2.1})
// CHECK-NEXT: %[[constant_0_1:[^ ]+]] = f32[4]{0} constant({1.2, 3.4, 5.6, 7.8})
// CHECK-NEXT: %[[call_foo_bar_baz_0:[^ ]+]] = f32[] call(f32[4]{0} %[[constant_0]], f32[4]{0} %[[constant_0_1]]), to_apply=%foo.bar.baz_0
// CHECK-NEXT: %[[call_foo_bar_baz_0_1:[^ ]+]] = f32[] call(f32[4]{0} %[[constant_0_1]], f32[4]{0} %[[constant_0]]), to_apply=%[[$foo_bar_baz_0]]
// CHECK-NEXT: ROOT %[[sum:[^ ]+]] = f32[] add(f32[] %[[call_foo_bar_baz_0_1]], f32[] %[[call_foo_bar_baz_0_1]])
// CHECK-NEXT: %[[call_foo_bar_baz_0:[^ ]+]] = f32[] call(%[[constant_0]], %[[constant_0_1]]), to_apply=%foo.bar.baz_0
// CHECK-NEXT: %[[call_foo_bar_baz_0_1:[^ ]+]] = f32[] call(%[[constant_0_1]], %[[constant_0]]), to_apply=%[[$foo_bar_baz_0]]
// CHECK-NEXT: ROOT %[[sum:[^ ]+]] = f32[] add(%[[call_foo_bar_baz_0_1]], %[[call_foo_bar_baz_0_1]])

// Test the tool's ability to disambiguate symbols with extremely similar names.
HloModule TestSymbolNameDisambiguation
Expand Down
Loading

0 comments on commit 77fd8e5

Please sign in to comment.