Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR #22800: Change the default value of print_operand_shape_ to false and print_large_constants_ to true. #23028

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion xla/backends/gpu/codegen/triton/dot_algorithms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,11 @@ TEST_F(AlgorithmTest, Algorithm_BF16_BF16_F32_on_BF16_input_for_multiply) {
CHECK: %[[reduce:.*]] = f32[256]{0} reduce(
)";
TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern));
TF_ASSERT_OK_AND_ASSIGN(
auto ok,
RunFileCheck(
module->ToString(HloPrintOptions().set_print_operand_shape(true)),
pattern));
ASSERT_TRUE(ok);
EXPECT_TRUE(RunAndCompareNoHloPasses(
std::move(module), ErrorSpec{/*aabs=*/1e-7, /*arel=*/1e-7}));
Expand Down
40 changes: 20 additions & 20 deletions xla/codegen/emitters/computation_partitioner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,24 @@ TEST_F(ComputationPartitionerTest, PartitionDiamonds) {

constexpr auto kExpected = R"(PartitionedComputation fused_computation:
SUBGRAPH fused_computation_add3 {
%slice3.1 = f32[2]{0} slice(f32[3]{0} %add2), slice={[0:2]}
%slice3.2 = f32[2]{0} slice(f32[3]{0} %add2), slice={[1:3]}
ROOT %add3 = f32[2]{0} add(f32[2]{0} %slice3.1, f32[2]{0} %slice3.2)
%slice3.1 = f32[2]{0} slice(%add2), slice={[0:2]}
%slice3.2 = f32[2]{0} slice(%add2), slice={[1:3]}
ROOT %add3 = f32[2]{0} add(%slice3.1, %slice3.2)
}
SUBGRAPH fused_computation_add2 {
%slice2.1 = f32[3]{0} slice(f32[4]{0} %add1), slice={[0:3]}
%slice2.2 = f32[3]{0} slice(f32[4]{0} %add1), slice={[1:4]}
ROOT %add2 = f32[3]{0} add(f32[3]{0} %slice2.1, f32[3]{0} %slice2.2)
%slice2.1 = f32[3]{0} slice(%add1), slice={[0:3]}
%slice2.2 = f32[3]{0} slice(%add1), slice={[1:4]}
ROOT %add2 = f32[3]{0} add(%slice2.1, %slice2.2)
}
SUBGRAPH fused_computation_add1 {
%slice1.1 = f32[4]{0} slice(f32[5]{0} %add0), slice={[0:4]}
%slice1.2 = f32[4]{0} slice(f32[5]{0} %add0), slice={[1:5]}
ROOT %add1 = f32[4]{0} add(f32[4]{0} %slice1.1, f32[4]{0} %slice1.2)
%slice1.1 = f32[4]{0} slice(%add0), slice={[0:4]}
%slice1.2 = f32[4]{0} slice(%add0), slice={[1:5]}
ROOT %add1 = f32[4]{0} add(%slice1.1, %slice1.2)
}
SUBGRAPH fused_computation_add0 {
%slice0.1 = f32[5]{0} slice(f32[6]{0} %param), slice={[0:5]}
%slice0.2 = f32[5]{0} slice(f32[6]{0} %param), slice={[1:6]}
ROOT %add0 = f32[5]{0} add(f32[5]{0} %slice0.1, f32[5]{0} %slice0.2)
%slice0.1 = f32[5]{0} slice(%param), slice={[0:5]}
%slice0.2 = f32[5]{0} slice(%param), slice={[1:6]}
ROOT %add0 = f32[5]{0} add(%slice0.1, %slice0.2)
}
SUBGRAPH fused_computation_param {
ROOT %param = f32[6]{0} parameter(0)
Expand Down Expand Up @@ -146,15 +146,15 @@ TEST_F(ComputationPartitionerTest, DiamondConcatenate) {

constexpr auto kExpected = R"(PartitionedComputation fused_computation:
SUBGRAPH fused_computation_concat {
%neg = f32[6]{0} negate(f32[6]{0} %log)
%neg = f32[6]{0} negate(%log)
%param2 = f32[6]{0} parameter(1)
%add = f32[6]{0} add(f32[6]{0} %log, f32[6]{0} %param2)
%exp = f32[6]{0} exponential(f32[6]{0} %add)
ROOT %concat = f32[12]{0} concatenate(f32[6]{0} %neg, f32[6]{0} %exp), dimensions={0}
%add = f32[6]{0} add(%log, %param2)
%exp = f32[6]{0} exponential(%add)
ROOT %concat = f32[12]{0} concatenate(%neg, %exp), dimensions={0}
}
SUBGRAPH fused_computation_log {
%param1 = f32[6]{0} parameter(0)
ROOT %log = f32[6]{0} log(f32[6]{0} %param1)
ROOT %log = f32[6]{0} log(%param1)
})";
EXPECT_EQ(computation.ToString(6), kExpected);
}
Expand All @@ -178,9 +178,9 @@ TEST_F(ComputationPartitionerTest, TupleRoot) {
SUBGRAPH fused_computation_root {
%p0 = f32[6]{0} parameter(0)
%p1 = f32[6]{0} parameter(1)
%add = f32[6]{0} add(f32[6]{0} %p0, f32[6]{0} %p1)
%sub = f32[6]{0} subtract(f32[6]{0} %p0, f32[6]{0} %p1)
ROOT %root = (f32[6]{0}, f32[6]{0}) tuple(f32[6]{0} %add, f32[6]{0} %sub)
%add = f32[6]{0} add(%p0, %p1)
%sub = f32[6]{0} subtract(%p0, %p1)
ROOT %root = (f32[6]{0}, f32[6]{0}) tuple(%add, %sub)
})";
EXPECT_EQ(computation.ToString(6), kExpected);
}
Expand Down
2 changes: 1 addition & 1 deletion xla/codegen/testlib/kernel_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_from_instruction(self):
"HloModule sine_module,",
"{",
"%input = s32[4]{0} parameter(0)",
"ROOT %sine = s32[4]{0} sine(s32[4]{0} %input)",
"ROOT %sine = s32[4]{0} sine(%input)",
"}",
]
self.assertContainsInOrder(
Expand Down
4 changes: 2 additions & 2 deletions xla/hlo/ir/hlo_print_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class HloPrintOptions {
: print_operand_index_annotation_interval_(5),
print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
indent_amount_(0),
print_large_constants_(false),
print_large_constants_(true),
print_only_essential_constants_(false),
print_original_value_(true),
print_metadata_(true),
Expand All @@ -62,7 +62,7 @@ class HloPrintOptions {
compact_operands_(false),
include_layout_in_shapes_(true),
print_result_shape_(true),
print_operand_shape_(true),
print_operand_shape_(false),
print_operand_names_(true),
print_program_shape_(true),
print_percent_(true),
Expand Down
30 changes: 23 additions & 7 deletions xla/hlo/parser/hlo_parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_print_options.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/parser/hlo_lexer.h"
#include "xla/hlo/testlib/pattern_matcher_gmock.h"
Expand Down Expand Up @@ -2809,9 +2810,12 @@ class HloParameterizedParserTest
: public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
// Expects "ToString(ParseHloModule(std::string)) == string", that is, parses
// the string, asserts that it succeeded, stringifies the parsed module, and
// checks that it equals the original string.
// Expects "ToString(ParseHloModule(ToString(ParseHloModule(std::string)))) ==
// string", that is, parses the string, asserts that it succeeded, stringifies
// the parsed module, parses this string to ensure that the default ToString()
// version is parsable, then stringifies the newly parsed module with
// appropriate options for original tests, and checks that it equals the
// original string.
void ExpectEqual() {
VLOG(3) << "Running HloParameterizedParserTest with short_form = "
<< short_form << ", proto_round_trip = " << proto_round_trip;
Expand All @@ -2827,9 +2831,20 @@ class HloParameterizedParserTest
ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
module = std::move(verified_module);
verified_module = std::make_unique<VerifiedHloModule>(
GetParam().test_name, config,
/*verifier_layout_sensitive=*/false,
/*allow_mixed_precision_in_hlo_verifier=*/true,
ShapeUtil::ByteSizeOfElements);
TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(
module->ToString(HloPrintOptions().set_print_operand_shape(true))));
module = std::move(verified_module);
} else {
TF_ASSERT_OK_AND_ASSIGN(module,
ParseAndReturnUnverifiedModule(original, config));
TF_ASSERT_OK_AND_ASSIGN(
module,
ParseAndReturnUnverifiedModule(module->ToString(), module->config()));
}
if (proto_round_trip) {
TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
Expand All @@ -2838,9 +2853,8 @@ class HloParameterizedParserTest
if (short_form) {
EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
} else {
EXPECT_EQ(
original,
module->ToString(HloPrintOptions().set_print_large_constants(true)));
EXPECT_EQ(original, module->ToString(
HloPrintOptions().set_print_operand_shape(true)));
}
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instr : computation->instructions()) {
Expand Down Expand Up @@ -3276,7 +3290,9 @@ ENTRY %ShortConstant.v4 () -> f32[67,89] {
)";
auto result = ParseAndReturnVerifiedModule(original);
TF_EXPECT_OK(result.status());
EXPECT_EQ(result.value()->ToString(HloPrintOptions()), original);
EXPECT_EQ(result.value()->ToString(
HloPrintOptions().set_print_large_constants(false)),
original);
}

TEST_F(HloParserTest, NegativeNan) {
Expand Down
40 changes: 20 additions & 20 deletions xla/hlo/tools/tests/generate_hlo_test_checks_test_output.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ ENTRY no_op {
// CHECK-LABEL: ENTRY %test_case
// CHECK-NEXT: %[[foo:[^ ]+]] = u8[] parameter(0)
// CHECK-NEXT: %[[bar:[^ ]+]] = u8[] parameter(1)
// CHECK-NEXT: %[[foobar:[^ ]+]] = u8[] add(u8[] %[[foo]], u8[] %[[bar]])
// CHECK-NEXT: %[[foobar:[^ ]+]] = u8[] add(%[[foo]], %[[bar]])
// CHECK-NEXT: %[[baz:[^ ]+]] = u8[] parameter(2)
// CHECK-NEXT: %[[qux:[^ ]+]] = u8[] parameter(3)
// CHECK-NEXT: %[[bazqux:[^ ]+]] = u8[] add(u8[] %[[baz]], u8[] %[[qux]])
// CHECK-NEXT: ROOT %[[foobarbazqux:[^ ]+]] = u8[] add(u8[] %[[foobar]], u8[] %[[bazqux]])
// CHECK-NEXT: %[[bazqux:[^ ]+]] = u8[] add(%[[baz]], %[[qux]])
// CHECK-NEXT: ROOT %[[foobarbazqux:[^ ]+]] = u8[] add(%[[foobar]], %[[bazqux]])

HloModule TestBasicFunctionality

Expand All @@ -51,18 +51,18 @@ ENTRY test_case {

// CHECK-LABEL: ENTRY %test_case
// CHECK-NEXT: %[[constant:[^ ]+]] = f32[] constant(1)
// CHECK-NEXT: %[[broadcast:[^ ]+]] = f32[8,7]{1,0} broadcast(f32[] %[[constant]]), dimensions={}
// CHECK-NEXT: %[[broadcast:[^ ]+]] = f32[8,7]{1,0} broadcast(%[[constant]]), dimensions={}
// CHECK-NEXT: %[[parameter_anon:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} parameter(0)
// CHECK-NEXT: %[[reshape0:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[parameter_anon]])
// CHECK-NEXT: %[[reshape:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} reshape(f32[8,7]{1,0} %[[reshape0]])
// CHECK-NEXT: %[[negate_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} negate(f32[1,8,1,7]{3,2,1,0} %[[reshape]])
// CHECK-NEXT: %[[reshape_1:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[negate_1]])
// CHECK-NEXT: %[[exponential:[^ ]+]] = f32[8,7]{1,0} exponential(f32[8,7]{1,0} %[[reshape_1]])
// CHECK-NEXT: %[[add_1:[^ ]+]] = f32[8,7]{1,0} add(f32[8,7]{1,0} %[[broadcast]], f32[8,7]{1,0} %[[exponential]])
// CHECK-NEXT: %[[divide:[^ ]+]] = f32[8,7]{1,0} divide(f32[8,7]{1,0} %[[broadcast]], f32[8,7]{1,0} %[[add_1]])
// CHECK-NEXT: %[[reshape0:[^ ]+]] = f32[8,7]{1,0} reshape(%[[parameter_anon]])
// CHECK-NEXT: %[[reshape:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} reshape(%[[reshape0]])
// CHECK-NEXT: %[[negate_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} negate(%[[reshape]])
// CHECK-NEXT: %[[reshape_1:[^ ]+]] = f32[8,7]{1,0} reshape(%[[negate_1]])
// CHECK-NEXT: %[[exponential:[^ ]+]] = f32[8,7]{1,0} exponential(%[[reshape_1]])
// CHECK-NEXT: %[[add_1:[^ ]+]] = f32[8,7]{1,0} add(%[[broadcast]], %[[exponential]])
// CHECK-NEXT: %[[divide:[^ ]+]] = f32[8,7]{1,0} divide(%[[broadcast]], %[[add_1]])
// CHECK-NEXT: %[[parameter_anon_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} parameter(1)
// CHECK-NEXT: %[[reshape1:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[parameter_anon_1]])
// CHECK-NEXT: ROOT %[[add:[^ ]+]] = f32[8,7]{1,0} add(f32[8,7]{1,0} %[[divide]], f32[8,7]{1,0} %[[reshape1]])
// CHECK-NEXT: %[[reshape1:[^ ]+]] = f32[8,7]{1,0} reshape(%[[parameter_anon_1]])
// CHECK-NEXT: ROOT %[[add:[^ ]+]] = f32[8,7]{1,0} add(%[[divide]], %[[reshape1]])

HloModule TestWithRelevantOptimizationPasses

Expand Down Expand Up @@ -93,21 +93,21 @@ ENTRY no_op {
// CHECK: %[[$foo_bar_baz_0:[^ ]+]]
// CHECK-NEXT: %[[foo_bar_baz_0_1:[^ ]+]] = f32[] parameter(0)
// CHECK-NEXT: %[[foo_bar_baz_0_2:[^ ]+]] = f32[] parameter(1)
// CHECK-NEXT: %[[foo_bar_baz_0_3:[^ ]+]] = f32[] multiply(f32[] %[[foo_bar_baz_0_1]], f32[] %[[foo_bar_baz_0_2]])
// CHECK-NEXT: ROOT %[[foobarbaz:[^ ]+]] = f32[] add(f32[] %[[foo_bar_baz_0_1]], f32[] %[[foo_bar_baz_0_3]])
// CHECK-NEXT: %[[foo_bar_baz_0_3:[^ ]+]] = f32[] multiply(%[[foo_bar_baz_0_1]], %[[foo_bar_baz_0_2]])
// CHECK-NEXT: ROOT %[[foobarbaz:[^ ]+]] = f32[] add(%[[foo_bar_baz_0_1]], %[[foo_bar_baz_0_3]])

// CHECK-LABEL: %foo.bar.baz_0
// CHECK-NEXT: %[[foo_bar_baz_0_4:[^ ]+]] = f32[] parameter(0)
// CHECK-NEXT: %[[foo_bar_baz_0_5:[^ ]+]] = f32[] parameter(1)
// CHECK-NEXT: %[[foo_bar_baz_0_6:[^ ]+]] = f32[] add(f32[] %[[foo_bar_baz_0_4]], f32[] %[[foo_bar_baz_0_5]])
// CHECK-NEXT: ROOT %[[foobarbaz_1:[^ ]+]] = f32[] multiply(f32[] %[[foo_bar_baz_0_4]], f32[] %[[foo_bar_baz_0_6]])
// CHECK-NEXT: %[[foo_bar_baz_0_6:[^ ]+]] = f32[] add(%[[foo_bar_baz_0_4]], %[[foo_bar_baz_0_5]])
// CHECK-NEXT: ROOT %[[foobarbaz_1:[^ ]+]] = f32[] multiply(%[[foo_bar_baz_0_4]], %[[foo_bar_baz_0_6]])

// CHECK-LABEL: ENTRY %foobarbaz
// CHECK-NEXT: %[[constant_0:[^ ]+]] = f32[4]{0} constant({8.7, 6.5, 4.3, 2.1})
// CHECK-NEXT: %[[constant_0_1:[^ ]+]] = f32[4]{0} constant({1.2, 3.4, 5.6, 7.8})
// CHECK-NEXT: %[[call_foo_bar_baz_0:[^ ]+]] = f32[] call(f32[4]{0} %[[constant_0]], f32[4]{0} %[[constant_0_1]]), to_apply=%foo.bar.baz_0
// CHECK-NEXT: %[[call_foo_bar_baz_0_1:[^ ]+]] = f32[] call(f32[4]{0} %[[constant_0_1]], f32[4]{0} %[[constant_0]]), to_apply=%[[$foo_bar_baz_0]]
// CHECK-NEXT: ROOT %[[sum:[^ ]+]] = f32[] add(f32[] %[[call_foo_bar_baz_0_1]], f32[] %[[call_foo_bar_baz_0_1]])
// CHECK-NEXT: %[[call_foo_bar_baz_0:[^ ]+]] = f32[] call(%[[constant_0]], %[[constant_0_1]]), to_apply=%foo.bar.baz_0
// CHECK-NEXT: %[[call_foo_bar_baz_0_1:[^ ]+]] = f32[] call(%[[constant_0_1]], %[[constant_0]]), to_apply=%[[$foo_bar_baz_0]]
// CHECK-NEXT: ROOT %[[sum:[^ ]+]] = f32[] add(%[[call_foo_bar_baz_0_1]], %[[call_foo_bar_baz_0_1]])

// Test the tool's ability to disambiguate symbols with extremely similar names.
HloModule TestSymbolNameDisambiguation
Expand Down
Loading
Loading