diff --git a/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc b/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc index 7fbb821159ff1b..cc7500c1e1deca 100644 --- a/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc +++ b/xla/backends/gpu/codegen/triton/dot_algorithms_test.cc @@ -283,7 +283,11 @@ TEST_F(AlgorithmTest, Algorithm_BF16_BF16_F32_on_BF16_input_for_multiply) { CHECK: %[[reduce:.*]] = f32[256]{0} reduce( )"; TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); - TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + TF_ASSERT_OK_AND_ASSIGN( + auto ok, + RunFileCheck( + module->ToString(HloPrintOptions().set_print_operand_shape(true)), + pattern)); ASSERT_TRUE(ok); EXPECT_TRUE(RunAndCompareNoHloPasses( std::move(module), ErrorSpec{/*aabs=*/1e-7, /*arel=*/1e-7})); diff --git a/xla/codegen/emitters/computation_partitioner_test.cc b/xla/codegen/emitters/computation_partitioner_test.cc index eaf2c68584c115..5fa85a47c018c0 100644 --- a/xla/codegen/emitters/computation_partitioner_test.cc +++ b/xla/codegen/emitters/computation_partitioner_test.cc @@ -82,24 +82,24 @@ TEST_F(ComputationPartitionerTest, PartitionDiamonds) { constexpr auto kExpected = R"(PartitionedComputation fused_computation: SUBGRAPH fused_computation_add3 { - %slice3.1 = f32[2]{0} slice(f32[3]{0} %add2), slice={[0:2]} - %slice3.2 = f32[2]{0} slice(f32[3]{0} %add2), slice={[1:3]} - ROOT %add3 = f32[2]{0} add(f32[2]{0} %slice3.1, f32[2]{0} %slice3.2) + %slice3.1 = f32[2]{0} slice(%add2), slice={[0:2]} + %slice3.2 = f32[2]{0} slice(%add2), slice={[1:3]} + ROOT %add3 = f32[2]{0} add(%slice3.1, %slice3.2) } SUBGRAPH fused_computation_add2 { - %slice2.1 = f32[3]{0} slice(f32[4]{0} %add1), slice={[0:3]} - %slice2.2 = f32[3]{0} slice(f32[4]{0} %add1), slice={[1:4]} - ROOT %add2 = f32[3]{0} add(f32[3]{0} %slice2.1, f32[3]{0} %slice2.2) + %slice2.1 = f32[3]{0} slice(%add1), slice={[0:3]} + %slice2.2 = f32[3]{0} slice(%add1), slice={[1:4]} + ROOT %add2 = f32[3]{0} add(%slice2.1, %slice2.2) } SUBGRAPH fused_computation_add1 { - %slice1.1 = f32[4]{0} slice(f32[5]{0} %add0), slice={[0:4]} - %slice1.2 = f32[4]{0} slice(f32[5]{0} %add0), slice={[1:5]} - ROOT %add1 = f32[4]{0} add(f32[4]{0} %slice1.1, f32[4]{0} %slice1.2) + %slice1.1 = f32[4]{0} slice(%add0), slice={[0:4]} + %slice1.2 = f32[4]{0} slice(%add0), slice={[1:5]} + ROOT %add1 = f32[4]{0} add(%slice1.1, %slice1.2) } SUBGRAPH fused_computation_add0 { - %slice0.1 = f32[5]{0} slice(f32[6]{0} %param), slice={[0:5]} - %slice0.2 = f32[5]{0} slice(f32[6]{0} %param), slice={[1:6]} - ROOT %add0 = f32[5]{0} add(f32[5]{0} %slice0.1, f32[5]{0} %slice0.2) + %slice0.1 = f32[5]{0} slice(%param), slice={[0:5]} + %slice0.2 = f32[5]{0} slice(%param), slice={[1:6]} + ROOT %add0 = f32[5]{0} add(%slice0.1, %slice0.2) } SUBGRAPH fused_computation_param { ROOT %param = f32[6]{0} parameter(0) @@ -146,15 +146,15 @@ TEST_F(ComputationPartitionerTest, DiamondConcatenate) { constexpr auto kExpected = R"(PartitionedComputation fused_computation: SUBGRAPH fused_computation_concat { - %neg = f32[6]{0} negate(f32[6]{0} %log) + %neg = f32[6]{0} negate(%log) %param2 = f32[6]{0} parameter(1) - %add = f32[6]{0} add(f32[6]{0} %log, f32[6]{0} %param2) - %exp = f32[6]{0} exponential(f32[6]{0} %add) - ROOT %concat = f32[12]{0} concatenate(f32[6]{0} %neg, f32[6]{0} %exp), dimensions={0} + %add = f32[6]{0} add(%log, %param2) + %exp = f32[6]{0} exponential(%add) + ROOT %concat = f32[12]{0} concatenate(%neg, %exp), dimensions={0} } SUBGRAPH fused_computation_log { %param1 = f32[6]{0} parameter(0) - ROOT %log = f32[6]{0} log(f32[6]{0} %param1) + ROOT %log = f32[6]{0} log(%param1) })"; EXPECT_EQ(computation.ToString(6), kExpected); } @@ -178,9 +178,9 @@ TEST_F(ComputationPartitionerTest, TupleRoot) { SUBGRAPH fused_computation_root { %p0 = f32[6]{0} parameter(0) %p1 = f32[6]{0} parameter(1) - %add = f32[6]{0} add(f32[6]{0} %p0, f32[6]{0} %p1) - %sub = f32[6]{0} subtract(f32[6]{0} %p0, f32[6]{0} %p1) - ROOT %root = (f32[6]{0}, f32[6]{0}) tuple(f32[6]{0} %add, f32[6]{0} %sub) + %add = f32[6]{0} add(%p0, %p1) + %sub = f32[6]{0} subtract(%p0, %p1) + ROOT %root = (f32[6]{0}, f32[6]{0}) tuple(%add, %sub) })"; EXPECT_EQ(computation.ToString(6), kExpected); } diff --git a/xla/codegen/testlib/kernel_runner_test.py b/xla/codegen/testlib/kernel_runner_test.py index 7af15ea213e927..f463840fe5fbba 100644 --- a/xla/codegen/testlib/kernel_runner_test.py +++ b/xla/codegen/testlib/kernel_runner_test.py @@ -38,7 +38,7 @@ def test_from_instruction(self): "HloModule sine_module,", "{", "%input = s32[4]{0} parameter(0)", - "ROOT %sine = s32[4]{0} sine(s32[4]{0} %input)", + "ROOT %sine = s32[4]{0} sine(%input)", "}", ] self.assertContainsInOrder( diff --git a/xla/hlo/ir/hlo_print_options.h b/xla/hlo/ir/hlo_print_options.h index 4fc3d7864b08d9..557ce7f276405f 100644 --- a/xla/hlo/ir/hlo_print_options.h +++ b/xla/hlo/ir/hlo_print_options.h @@ -51,7 +51,7 @@ class HloPrintOptions { : print_operand_index_annotation_interval_(5), print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), indent_amount_(0), - print_large_constants_(false), + print_large_constants_(true), print_only_essential_constants_(false), print_original_value_(true), print_metadata_(true), @@ -62,7 +62,7 @@ class HloPrintOptions { compact_operands_(false), include_layout_in_shapes_(true), print_result_shape_(true), - print_operand_shape_(true), + print_operand_shape_(false), print_operand_names_(true), print_program_shape_(true), print_percent_(true), diff --git a/xla/hlo/parser/hlo_parser_test.cc b/xla/hlo/parser/hlo_parser_test.cc index dbbb25c8d74d6d..75fd6c0501c0ff 100644 --- a/xla/hlo/parser/hlo_parser_test.cc +++ b/xla/hlo/parser/hlo_parser_test.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_print_options.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/parser/hlo_lexer.h" #include "xla/hlo/testlib/pattern_matcher_gmock.h" @@ -2809,9 +2810,12 @@ class HloParameterizedParserTest : public ::testing::Test, public ::testing::WithParamInterface { protected: - // Expects "ToString(ParseHloModule(std::string)) == string", that is, parses - // the string, asserts that it succeeded, stringifies the parsed module, and - // checks that it equals the original string. + // Expects "ToString(ParseHloModule(ToString(ParseHloModule(std::string)))) == + // string", that is, parses the string, asserts that it succeeded, stringifies + // the parsed module, parses this string to ensure that the default ToString() + // version is parsable, then stringifies the newly parsed module with + // appropriate options for original tests, and checks that it equals the + // original string. void ExpectEqual() { VLOG(3) << "Running HloParameterizedParserTest with short_form = " << short_form << ", proto_round_trip = " << proto_round_trip; @@ -2827,9 +2831,20 @@ class HloParameterizedParserTest ShapeUtil::ByteSizeOfElements); TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original)); module = std::move(verified_module); + verified_module = std::make_unique( + GetParam().test_name, config, + /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true, + ShapeUtil::ByteSizeOfElements); + TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule( + module->ToString(HloPrintOptions().set_print_operand_shape(true)))); + module = std::move(verified_module); } else { TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original, config)); + TF_ASSERT_OK_AND_ASSIGN( + module, + ParseAndReturnUnverifiedModule(module->ToString(), module->config())); } if (proto_round_trip) { TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto( @@ -2838,9 +2853,8 @@ class HloParameterizedParserTest if (short_form) { EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable())); } else { - EXPECT_EQ( - original, - module->ToString(HloPrintOptions().set_print_large_constants(true))); + EXPECT_EQ(original, module->ToString( + HloPrintOptions().set_print_operand_shape(true))); } for (HloComputation* computation : module->computations()) { for (HloInstruction* instr : computation->instructions()) { @@ -3276,7 +3290,9 @@ ENTRY %ShortConstant.v4 () -> f32[67,89] { )"; auto result = ParseAndReturnVerifiedModule(original); TF_EXPECT_OK(result.status()); - EXPECT_EQ(result.value()->ToString(HloPrintOptions()), original); + EXPECT_EQ(result.value()->ToString( + HloPrintOptions().set_print_large_constants(false)), + original); } TEST_F(HloParserTest, NegativeNan) { diff --git a/xla/hlo/tools/tests/generate_hlo_test_checks_test_output.hlo b/xla/hlo/tools/tests/generate_hlo_test_checks_test_output.hlo index d805fc6f13855f..21ab49a1cfabfa 100644 --- a/xla/hlo/tools/tests/generate_hlo_test_checks_test_output.hlo +++ b/xla/hlo/tools/tests/generate_hlo_test_checks_test_output.hlo @@ -27,11 +27,11 @@ ENTRY no_op { // CHECK-LABEL: ENTRY %test_case // CHECK-NEXT: %[[foo:[^ ]+]] = u8[] parameter(0) // CHECK-NEXT: %[[bar:[^ ]+]] = u8[] parameter(1) -// CHECK-NEXT: %[[foobar:[^ ]+]] = u8[] add(u8[] %[[foo]], u8[] %[[bar]]) +// CHECK-NEXT: %[[foobar:[^ ]+]] = u8[] add(%[[foo]], %[[bar]]) // CHECK-NEXT: %[[baz:[^ ]+]] = u8[] parameter(2) // CHECK-NEXT: %[[qux:[^ ]+]] = u8[] parameter(3) -// CHECK-NEXT: %[[bazqux:[^ ]+]] = u8[] add(u8[] %[[baz]], u8[] %[[qux]]) -// CHECK-NEXT: ROOT %[[foobarbazqux:[^ ]+]] = u8[] add(u8[] %[[foobar]], u8[] %[[bazqux]]) +// CHECK-NEXT: %[[bazqux:[^ ]+]] = u8[] add(%[[baz]], %[[qux]]) +// CHECK-NEXT: ROOT %[[foobarbazqux:[^ ]+]] = u8[] add(%[[foobar]], %[[bazqux]]) HloModule TestBasicFunctionality @@ -51,18 +51,18 @@ ENTRY test_case { // CHECK-LABEL: ENTRY %test_case // CHECK-NEXT: %[[constant:[^ ]+]] = f32[] constant(1) -// CHECK-NEXT: %[[broadcast:[^ ]+]] = f32[8,7]{1,0} broadcast(f32[] %[[constant]]), dimensions={} +// CHECK-NEXT: %[[broadcast:[^ ]+]] = f32[8,7]{1,0} broadcast(%[[constant]]), dimensions={} // CHECK-NEXT: %[[parameter_anon:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} parameter(0) -// CHECK-NEXT: %[[reshape0:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[parameter_anon]]) -// CHECK-NEXT: %[[reshape:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} reshape(f32[8,7]{1,0} %[[reshape0]]) -// CHECK-NEXT: %[[negate_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} negate(f32[1,8,1,7]{3,2,1,0} %[[reshape]]) -// CHECK-NEXT: %[[reshape_1:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[negate_1]]) -// CHECK-NEXT: %[[exponential:[^ ]+]] = f32[8,7]{1,0} exponential(f32[8,7]{1,0} %[[reshape_1]]) -// CHECK-NEXT: %[[add_1:[^ ]+]] = f32[8,7]{1,0} add(f32[8,7]{1,0} %[[broadcast]], f32[8,7]{1,0} %[[exponential]]) -// CHECK-NEXT: %[[divide:[^ ]+]] = f32[8,7]{1,0} divide(f32[8,7]{1,0} %[[broadcast]], f32[8,7]{1,0} %[[add_1]]) +// CHECK-NEXT: %[[reshape0:[^ ]+]] = f32[8,7]{1,0} reshape(%[[parameter_anon]]) +// CHECK-NEXT: %[[reshape:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} reshape(%[[reshape0]]) +// CHECK-NEXT: %[[negate_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} negate(%[[reshape]]) +// CHECK-NEXT: %[[reshape_1:[^ ]+]] = f32[8,7]{1,0} reshape(%[[negate_1]]) +// CHECK-NEXT: %[[exponential:[^ ]+]] = f32[8,7]{1,0} exponential(%[[reshape_1]]) +// CHECK-NEXT: %[[add_1:[^ ]+]] = f32[8,7]{1,0} add(%[[broadcast]], %[[exponential]]) +// CHECK-NEXT: %[[divide:[^ ]+]] = f32[8,7]{1,0} divide(%[[broadcast]], %[[add_1]]) // CHECK-NEXT: %[[parameter_anon_1:[^ ]+]] = f32[1,8,1,7]{3,2,1,0} parameter(1) -// CHECK-NEXT: %[[reshape1:[^ ]+]] = f32[8,7]{1,0} reshape(f32[1,8,1,7]{3,2,1,0} %[[parameter_anon_1]]) -// CHECK-NEXT: ROOT %[[add:[^ ]+]] = f32[8,7]{1,0} add(f32[8,7]{1,0} %[[divide]], f32[8,7]{1,0} %[[reshape1]]) +// CHECK-NEXT: %[[reshape1:[^ ]+]] = f32[8,7]{1,0} reshape(%[[parameter_anon_1]]) +// CHECK-NEXT: ROOT %[[add:[^ ]+]] = f32[8,7]{1,0} add(%[[divide]], %[[reshape1]]) HloModule TestWithRelevantOptimizationPasses @@ -93,21 +93,21 @@ ENTRY no_op { // CHECK: %[[$foo_bar_baz_0:[^ ]+]] // CHECK-NEXT: %[[foo_bar_baz_0_1:[^ ]+]] = f32[] parameter(0) // CHECK-NEXT: %[[foo_bar_baz_0_2:[^ ]+]] = f32[] parameter(1) -// CHECK-NEXT: %[[foo_bar_baz_0_3:[^ ]+]] = f32[] multiply(f32[] %[[foo_bar_baz_0_1]], f32[] %[[foo_bar_baz_0_2]]) -// CHECK-NEXT: ROOT %[[foobarbaz:[^ ]+]] = f32[] add(f32[] %[[foo_bar_baz_0_1]], f32[] %[[foo_bar_baz_0_3]]) +// CHECK-NEXT: %[[foo_bar_baz_0_3:[^ ]+]] = f32[] multiply(%[[foo_bar_baz_0_1]], %[[foo_bar_baz_0_2]]) +// CHECK-NEXT: ROOT %[[foobarbaz:[^ ]+]] = f32[] add(%[[foo_bar_baz_0_1]], %[[foo_bar_baz_0_3]]) // CHECK-LABEL: %foo.bar.baz_0 // CHECK-NEXT: %[[foo_bar_baz_0_4:[^ ]+]] = f32[] parameter(0) // CHECK-NEXT: %[[foo_bar_baz_0_5:[^ ]+]] = f32[] parameter(1) -// CHECK-NEXT: %[[foo_bar_baz_0_6:[^ ]+]] = f32[] add(f32[] %[[foo_bar_baz_0_4]], f32[] %[[foo_bar_baz_0_5]]) -// CHECK-NEXT: ROOT %[[foobarbaz_1:[^ ]+]] = f32[] multiply(f32[] %[[foo_bar_baz_0_4]], f32[] %[[foo_bar_baz_0_6]]) +// CHECK-NEXT: %[[foo_bar_baz_0_6:[^ ]+]] = f32[] add(%[[foo_bar_baz_0_4]], %[[foo_bar_baz_0_5]]) +// CHECK-NEXT: ROOT %[[foobarbaz_1:[^ ]+]] = f32[] multiply(%[[foo_bar_baz_0_4]], %[[foo_bar_baz_0_6]]) // CHECK-LABEL: ENTRY %foobarbaz // CHECK-NEXT: %[[constant_0:[^ ]+]] = f32[4]{0} constant({8.7, 6.5, 4.3, 2.1}) // CHECK-NEXT: %[[constant_0_1:[^ ]+]] = f32[4]{0} constant({1.2, 3.4, 5.6, 7.8}) -// CHECK-NEXT: %[[call_foo_bar_baz_0:[^ ]+]] = f32[] call(f32[4]{0} %[[constant_0]], f32[4]{0} %[[constant_0_1]]), to_apply=%foo.bar.baz_0 -// CHECK-NEXT: %[[call_foo_bar_baz_0_1:[^ ]+]] = f32[] call(f32[4]{0} %[[constant_0_1]], f32[4]{0} %[[constant_0]]), to_apply=%[[$foo_bar_baz_0]] -// CHECK-NEXT: ROOT %[[sum:[^ ]+]] = f32[] add(f32[] %[[call_foo_bar_baz_0_1]], f32[] %[[call_foo_bar_baz_0_1]]) +// CHECK-NEXT: %[[call_foo_bar_baz_0:[^ ]+]] = f32[] call(%[[constant_0]], %[[constant_0_1]]), to_apply=%foo.bar.baz_0 +// CHECK-NEXT: %[[call_foo_bar_baz_0_1:[^ ]+]] = f32[] call(%[[constant_0_1]], %[[constant_0]]), to_apply=%[[$foo_bar_baz_0]] +// CHECK-NEXT: ROOT %[[sum:[^ ]+]] = f32[] add(%[[call_foo_bar_baz_0_1]], %[[call_foo_bar_baz_0_1]]) // Test the tool's ability to disambiguate symbols with extremely similar names. HloModule TestSymbolNameDisambiguation diff --git a/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc b/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc index 033bd4d5d84cfb..374941f94d2969 100644 --- a/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc +++ b/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc @@ -49,25 +49,25 @@ ENTRY main { // CHECK: HloModule bitcast_to_smaller // CHECK: %xla.bitcast_convert_s32_10__2_s8_10_4_.17 (a.1: s32[10]) -> s8[10,4] { // CHECK: %[[VAL_0:.*]] = s32[10]{0} parameter(0) -// CHECK: %[[VAL_1:.*]] = s32[10,1]{1,0} reshape(s32[10]{0} %[[VAL_0]]) -// CHECK: %[[VAL_2:.*]] = s32[10,1]{1,0} broadcast(s32[10,1]{1,0} %[[VAL_1]]), dimensions={0,1} -// CHECK: %[[VAL_3:.*]] = s32[10]{0} reshape(s32[10,1]{1,0} %[[VAL_2]]) -// CHECK: %[[VAL_4:.*]] = s32[10,4]{1,0} broadcast(s32[10]{0} %[[VAL_3]]), dimensions={0} -// CHECK: %[[VAL_5:.*]] = u32[10,4]{1,0} bitcast-convert(s32[10,4]{1,0} %[[VAL_4]]) +// CHECK: %[[VAL_1:.*]] = s32[10,1]{1,0} reshape(%[[VAL_0]]) +// CHECK: %[[VAL_2:.*]] = s32[10,1]{1,0} broadcast(%[[VAL_1]]), dimensions={0,1} +// CHECK: %[[VAL_3:.*]] = s32[10]{0} reshape(%[[VAL_2]]) +// CHECK: %[[VAL_4:.*]] = s32[10,4]{1,0} broadcast(%[[VAL_3]]), dimensions={0} +// CHECK: %[[VAL_5:.*]] = u32[10,4]{1,0} bitcast-convert(%[[VAL_4]]) // CHECK: %[[VAL_6:.*]] = u32[] constant(8) -// CHECK: %[[VAL_7:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_6]]), dimensions={} +// CHECK: %[[VAL_7:.*]] = u32[10,4]{1,0} broadcast(%[[VAL_6]]), dimensions={} // CHECK: %[[VAL_8:.*]] = u32[10,4]{1,0} iota(), iota_dimension=1 -// CHECK: %[[VAL_9:.*]] = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %[[VAL_7]], u32[10,4]{1,0} %[[VAL_8]]) -// CHECK: %[[VAL_10:.*]] = u32[10,4]{1,0} shift-right-logical(u32[10,4]{1,0} %[[VAL_5]], u32[10,4]{1,0} %[[VAL_9]]) +// CHECK: %[[VAL_9:.*]] = u32[10,4]{1,0} multiply(%[[VAL_7]], %[[VAL_8]]) +// CHECK: %[[VAL_10:.*]] = u32[10,4]{1,0} shift-right-logical(%[[VAL_5]], %[[VAL_9]]) // CHECK: %[[VAL_11:.*]] = u32[] constant(255) -// CHECK: %[[VAL_12:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_11]]), dimensions={} -// CHECK: %[[VAL_13:.*]] = u32[10,4]{1,0} and(u32[10,4]{1,0} %[[VAL_10]], u32[10,4]{1,0} %[[VAL_12]]) -// CHECK: %[[VAL_14:.*]] = u8[10,4]{1,0} convert(u32[10,4]{1,0} %[[VAL_13]]) -// CHECK: ROOT %[[VAL_15:.*]] = s8[10,4]{1,0} bitcast-convert(u8[10,4]{1,0} %[[VAL_14]]) +// CHECK: %[[VAL_12:.*]] = u32[10,4]{1,0} broadcast(%[[VAL_11]]), dimensions={} +// CHECK: %[[VAL_13:.*]] = u32[10,4]{1,0} and(%[[VAL_10]], %[[VAL_12]]) +// CHECK: %[[VAL_14:.*]] = u8[10,4]{1,0} convert(%[[VAL_13]]) +// CHECK: ROOT %[[VAL_15:.*]] = s8[10,4]{1,0} bitcast-convert(%[[VAL_14]]) // CHECK: } // CHECK: ENTRY %main (p: s32[10]) -> s8[10,4] { // CHECK: %[[VAL_16:.*]] = s32[10]{0} parameter(0) -// CHECK: ROOT %[[VAL_17:.*]] = s8[10,4]{1,0} call(s32[10]{0} %[[VAL_16]]), to_apply=%[[VAL_18:.*]] +// CHECK: ROOT %[[VAL_17:.*]] = s8[10,4]{1,0} call(%[[VAL_16]]), to_apply=%[[VAL_18:.*]] // CHECK: } )")); } @@ -92,25 +92,25 @@ ENTRY main { // CHECK: HloModule bitcast_to_smaller, entry_computation_layout={(s64[10]{0})->s32[10,2]{1,0}} // CHECK: %xla.bitcast_convert_s64_10__2_s32_10_2_.17 (a.1: s64[10]) -> s32[10,2] { // CHECK: %[[VAL_0:.*]] = s64[10]{0} parameter(0) -// CHECK: %[[VAL_1:.*]] = s64[10,1]{1,0} reshape(s64[10]{0} %[[VAL_0]]) -// CHECK: %[[VAL_2:.*]] = s64[10,1]{1,0} broadcast(s64[10,1]{1,0} %[[VAL_1]]), dimensions={0,1} -// CHECK: %[[VAL_3:.*]] = s64[10]{0} reshape(s64[10,1]{1,0} %[[VAL_2]]) -// CHECK: %[[VAL_4:.*]] = s64[10,2]{1,0} broadcast(s64[10]{0} %[[VAL_3]]), dimensions={0} -// CHECK: %[[VAL_5:.*]] = u64[10,2]{1,0} bitcast-convert(s64[10,2]{1,0} %[[VAL_4]]) +// CHECK: %[[VAL_1:.*]] = s64[10,1]{1,0} reshape(%[[VAL_0]]) +// CHECK: %[[VAL_2:.*]] = s64[10,1]{1,0} broadcast(%[[VAL_1]]), dimensions={0,1} +// CHECK: %[[VAL_3:.*]] = s64[10]{0} reshape(%[[VAL_2]]) +// CHECK: %[[VAL_4:.*]] = s64[10,2]{1,0} broadcast(%[[VAL_3]]), dimensions={0} +// CHECK: %[[VAL_5:.*]] = u64[10,2]{1,0} bitcast-convert(%[[VAL_4]]) // CHECK: %[[VAL_6:.*]] = u64[] constant(32) -// CHECK: %[[VAL_7:.*]] = u64[10,2]{1,0} broadcast(u64[] %[[VAL_6]]), dimensions={} +// CHECK: %[[VAL_7:.*]] = u64[10,2]{1,0} broadcast(%[[VAL_6]]), dimensions={} // CHECK: %[[VAL_8:.*]] = u64[10,2]{1,0} iota(), iota_dimension=1 -// CHECK: %[[VAL_9:.*]] = u64[10,2]{1,0} multiply(u64[10,2]{1,0} %[[VAL_7]], u64[10,2]{1,0} %[[VAL_8]]) -// CHECK: %[[VAL_10:.*]] = u64[10,2]{1,0} shift-right-logical(u64[10,2]{1,0} %[[VAL_5]], u64[10,2]{1,0} %[[VAL_9]]) +// CHECK: %[[VAL_9:.*]] = u64[10,2]{1,0} multiply(%[[VAL_7]], %[[VAL_8]]) +// CHECK: %[[VAL_10:.*]] = u64[10,2]{1,0} shift-right-logical(%[[VAL_5]], %[[VAL_9]]) // CHECK: %[[VAL_11:.*]] = u64[] constant(4294967295) -// CHECK: %[[VAL_12:.*]] = u64[10,2]{1,0} broadcast(u64[] %[[VAL_11]]), dimensions={} -// CHECK: %[[VAL_13:.*]] = u64[10,2]{1,0} and(u64[10,2]{1,0} %[[VAL_10]], u64[10,2]{1,0} %[[VAL_12]]) -// CHECK: %[[VAL_14:.*]] = u32[10,2]{1,0} convert(u64[10,2]{1,0} %[[VAL_13]]) -// CHECK: ROOT %[[VAL_15:.*]] = s32[10,2]{1,0} bitcast-convert(u32[10,2]{1,0} %[[VAL_14]]) +// CHECK: %[[VAL_12:.*]] = u64[10,2]{1,0} broadcast(%[[VAL_11]]), dimensions={} +// CHECK: %[[VAL_13:.*]] = u64[10,2]{1,0} and(%[[VAL_10]], %[[VAL_12]]) +// CHECK: %[[VAL_14:.*]] = u32[10,2]{1,0} convert(%[[VAL_13]]) +// CHECK: ROOT %[[VAL_15:.*]] = s32[10,2]{1,0} bitcast-convert(%[[VAL_14]]) // CHECK: } // CHECK: ENTRY %main (p: s64[10]) -> s32[10,2] { // CHECK: %[[VAL_16:.*]] = s64[10]{0} parameter(0) -// CHECK: ROOT %[[VAL_17:.*]] = s32[10,2]{1,0} call(s64[10]{0} %[[VAL_16]]), to_apply=%[[VAL_18:.*]] +// CHECK: ROOT %[[VAL_17:.*]] = s32[10,2]{1,0} call(%[[VAL_16]]), to_apply=%[[VAL_18:.*]] // CHECK: } )")); } @@ -138,24 +138,24 @@ ENTRY main { // CHECK: %or_U32.10 (lhs.11: u32[], rhs.12: u32[]) -> u32[] { // CHECK: %[[VAL_0:.*]] = u32[] parameter(0) // CHECK: %[[VAL_1:.*]] = u32[] parameter(1) -// CHECK: ROOT %[[VAL_2:.*]] = u32[] or(u32[] %[[VAL_0]], u32[] %[[VAL_1]]) +// CHECK: ROOT %[[VAL_2:.*]] = u32[] or(%[[VAL_0]], %[[VAL_1]]) // CHECK: } // CHECK: %xla.bitcast_convert_s8_10_4__2_s32_10_.16 (a.1: s8[10,4]) -> s32[10] { // CHECK: %[[VAL_3:.*]] = s8[10,4]{1,0} parameter(0) -// CHECK: %[[VAL_4:.*]] = u8[10,4]{1,0} bitcast-convert(s8[10,4]{1,0} %[[VAL_3]]) -// CHECK: %[[VAL_5:.*]] = u32[10,4]{1,0} convert(u8[10,4]{1,0} %[[VAL_4]]) +// CHECK: %[[VAL_4:.*]] = u8[10,4]{1,0} bitcast-convert(%[[VAL_3]]) +// CHECK: %[[VAL_5:.*]] = u32[10,4]{1,0} convert(%[[VAL_4]]) // CHECK: %[[VAL_6:.*]] = u32[] constant(8) -// CHECK: %[[VAL_7:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_6]]), dimensions={} +// CHECK: %[[VAL_7:.*]] = u32[10,4]{1,0} broadcast(%[[VAL_6]]), dimensions={} // CHECK: %[[VAL_8:.*]] = u32[10,4]{1,0} iota(), iota_dimension=1 -// CHECK: %[[VAL_9:.*]] = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %[[VAL_7]], u32[10,4]{1,0} %[[VAL_8]]) -// CHECK: %[[VAL_10:.*]] = u32[10,4]{1,0} shift-left(u32[10,4]{1,0} %[[VAL_5]], u32[10,4]{1,0} %[[VAL_9]]) +// CHECK: %[[VAL_9:.*]] = u32[10,4]{1,0} multiply(%[[VAL_7]], %[[VAL_8]]) +// CHECK: %[[VAL_10:.*]] = u32[10,4]{1,0} shift-left(%[[VAL_5]], %[[VAL_9]]) // CHECK: %[[VAL_11:.*]] = u32[] constant(0) -// CHECK: %[[VAL_12:.*]] = u32[10]{0} reduce(u32[10,4]{1,0} %[[VAL_10]], u32[] %[[VAL_11]]), dimensions={1}, to_apply=%[[VAL_13:.*]] -// CHECK: ROOT %[[VAL_14:.*]] = s32[10]{0} bitcast-convert(u32[10]{0} %[[VAL_12]]) +// CHECK: %[[VAL_12:.*]] = u32[10]{0} reduce(%[[VAL_10]], %[[VAL_11]]), dimensions={1}, to_apply=%[[VAL_13:.*]] +// CHECK: ROOT %[[VAL_14:.*]] = s32[10]{0} bitcast-convert(%[[VAL_12]]) // CHECK: } // CHECK: ENTRY %main (p: s8[10,4]) -> s32[10] { // CHECK: %[[VAL_15:.*]] = s8[10,4]{1,0} parameter(0) -// CHECK: ROOT %[[VAL_16:.*]] = s32[10]{0} call(s8[10,4]{1,0} %[[VAL_15]]), to_apply=%[[VAL_17:.*]] +// CHECK: ROOT %[[VAL_16:.*]] = s32[10]{0} call(%[[VAL_15]]), to_apply=%[[VAL_17:.*]] // CHECK: } )")); } diff --git a/xla/hlo/transforms/tests/algebraic_simplifier.hlo b/xla/hlo/transforms/tests/algebraic_simplifier.hlo index 72fbad3cd3261b..3f0e18d22eabe4 100644 --- a/xla/hlo/transforms/tests/algebraic_simplifier.hlo +++ b/xla/hlo/transforms/tests/algebraic_simplifier.hlo @@ -6,9 +6,9 @@ // CHECK-LABEL: ENTRY %test // CHECK-NEXT: %[[p0:[^ ]+]] = s32[8]{0} parameter(0) // CHECK-NEXT: %[[p1:[^ ]+]] = s32[8]{0} parameter(1) -// CHECK-NEXT: %[[add:[^ ]+]] = s32[8]{0} add(s32[8]{0} %[[p0]], s32[8]{0} %[[p1]]) +// CHECK-NEXT: %[[add:[^ ]+]] = s32[8]{0} add(%[[p0]], %[[p1]]) // CHECK-NEXT: %[[p2:[^ ]+]] = s32[8]{0} parameter(2) -// CHECK-NEXT: ROOT %[[multiply:[^ ]+]] = s32[8]{0} multiply(s32[8]{0} %[[add]], s32[8]{0} %[[p2]]) +// CHECK-NEXT: ROOT %[[multiply:[^ ]+]] = s32[8]{0} multiply(%[[add]], %[[p2]]) HloModule m ENTRY test { diff --git a/xla/hlo/transforms/tests/cholesky_expander.hlo b/xla/hlo/transforms/tests/cholesky_expander.hlo index 53b19ac36b15b8..dfdd63e56e41ed 100644 --- a/xla/hlo/transforms/tests/cholesky_expander.hlo +++ b/xla/hlo/transforms/tests/cholesky_expander.hlo @@ -5,80 +5,80 @@ // CHECK: %[[$unblocked_body_15:[^ ]+]] // CHECK-NEXT: %[[parameter_16:[^ ]+]] = (s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) parameter(0) -// CHECK-NEXT: %[[get_tuple_element_17:[^ ]+]] = s32[] get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[parameter_16]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_17:[^ ]+]] = s32[] get-tuple-element(%[[parameter_16]]), index=0 // CHECK-NEXT: %[[constant_21:[^ ]+]] = s32[] constant(1) -// CHECK-NEXT: %[[add_22:[^ ]+]] = s32[] add(s32[] %[[get_tuple_element_17]], s32[] %[[constant_21]]) -// CHECK-NEXT: %[[get_tuple_element_18:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[parameter_16]]), index=1 +// CHECK-NEXT: %[[add_22:[^ ]+]] = s32[] add(%[[get_tuple_element_17]], %[[constant_21]]) +// CHECK-NEXT: %[[get_tuple_element_18:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element(%[[parameter_16]]), index=1 // CHECK-NEXT: %[[iota_24:[^ ]+]] = s32[32,4,4]{2,1,0} iota(), iota_dimension=1 // CHECK-NEXT: %[[iota_23:[^ ]+]] = s32[32,4,4]{2,1,0} iota(), iota_dimension=2 -// CHECK-NEXT: %[[compare_25:[^ ]+]] = pred[32,4,4]{2,1,0} compare(s32[32,4,4]{2,1,0} %[[iota_24]], s32[32,4,4]{2,1,0} %[[iota_23]]), direction=GE -// CHECK-NEXT: %[[broadcast_26:[^ ]+]] = s32[32,4,4]{2,1,0} broadcast(s32[] %[[get_tuple_element_17]]), dimensions={} -// CHECK-NEXT: %[[compare_27:[^ ]+]] = pred[32,4,4]{2,1,0} compare(s32[32,4,4]{2,1,0} %[[iota_23]], s32[32,4,4]{2,1,0} %[[broadcast_26]]), direction=EQ -// CHECK-NEXT: %[[and_28:[^ ]+]] = pred[32,4,4]{2,1,0} and(pred[32,4,4]{2,1,0} %[[compare_25]], pred[32,4,4]{2,1,0} %[[compare_27]]) -// CHECK-NEXT: %[[get_tuple_element_19:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[parameter_16]]), index=2 -// CHECK-NEXT: %[[dot_31:[^ ]+]] = f16[32,4,4]{2,1,0} dot(f16[32,4,4]{2,1,0} %[[get_tuple_element_19]], f16[32,4,4]{2,1,0} %[[get_tuple_element_19]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} -// CHECK-NEXT: %[[transpose_32:[^ ]+]] = f16[32,4,4]{2,1,0} transpose(f16[32,4,4]{2,1,0} %[[dot_31]]), dimensions={0,1,2} -// CHECK-NEXT: %[[subtract_33:[^ ]+]] = f16[32,4,4]{2,1,0} subtract(f16[32,4,4]{2,1,0} %[[get_tuple_element_18]], f16[32,4,4]{2,1,0} %[[transpose_32]]) +// CHECK-NEXT: %[[compare_25:[^ ]+]] = pred[32,4,4]{2,1,0} compare(%[[iota_24]], %[[iota_23]]), direction=GE +// CHECK-NEXT: %[[broadcast_26:[^ ]+]] = s32[32,4,4]{2,1,0} broadcast(%[[get_tuple_element_17]]), dimensions={} +// CHECK-NEXT: %[[compare_27:[^ ]+]] = pred[32,4,4]{2,1,0} compare(%[[iota_23]], %[[broadcast_26]]), direction=EQ +// CHECK-NEXT: %[[and_28:[^ ]+]] = pred[32,4,4]{2,1,0} and(%[[compare_25]], %[[compare_27]]) +// CHECK-NEXT: %[[get_tuple_element_19:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element(%[[parameter_16]]), index=2 +// CHECK-NEXT: %[[dot_31:[^ ]+]] = f16[32,4,4]{2,1,0} dot(%[[get_tuple_element_19]], %[[get_tuple_element_19]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} +// CHECK-NEXT: %[[transpose_32:[^ ]+]] = f16[32,4,4]{2,1,0} transpose(%[[dot_31]]), dimensions={0,1,2} +// CHECK-NEXT: %[[subtract_33:[^ ]+]] = f16[32,4,4]{2,1,0} subtract(%[[get_tuple_element_18]], %[[transpose_32]]) // CHECK-NEXT: %[[constant_34:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[dynamic_slice_35:[^ ]+]] = f16[32,1,1]{2,1,0} dynamic-slice(f16[32,4,4]{2,1,0} %[[subtract_33]], s32[] %[[constant_34]], s32[] %[[get_tuple_element_17]], s32[] %[[get_tuple_element_17]]), dynamic_slice_sizes={32,1,1} -// CHECK-NEXT: %[[sqrt_36:[^ ]+]] = f16[32,1,1]{2,1,0} sqrt(f16[32,1,1]{2,1,0} %[[dynamic_slice_35]]) -// CHECK-NEXT: %[[reshape_39:[^ ]+]] = f16[32]{0} reshape(f16[32,1,1]{2,1,0} %[[sqrt_36]]) -// CHECK-NEXT: %[[broadcast_40:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(f16[32]{0} %[[reshape_39]]), dimensions={0} -// CHECK-NEXT: %[[divide_41:[^ ]+]] = f16[32,4,4]{2,1,0} divide(f16[32,4,4]{2,1,0} %[[subtract_33]], f16[32,4,4]{2,1,0} %[[broadcast_40]]) +// CHECK-NEXT: %[[dynamic_slice_35:[^ ]+]] = f16[32,1,1]{2,1,0} dynamic-slice(%[[subtract_33]], %[[constant_34]], %[[get_tuple_element_17]], %[[get_tuple_element_17]]), dynamic_slice_sizes={32,1,1} +// CHECK-NEXT: %[[sqrt_36:[^ ]+]] = f16[32,1,1]{2,1,0} sqrt(%[[dynamic_slice_35]]) +// CHECK-NEXT: %[[reshape_39:[^ ]+]] = f16[32]{0} reshape(%[[sqrt_36]]) +// CHECK-NEXT: %[[broadcast_40:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(%[[reshape_39]]), dimensions={0} +// CHECK-NEXT: %[[divide_41:[^ ]+]] = f16[32,4,4]{2,1,0} divide(%[[subtract_33]], %[[broadcast_40]]) // CHECK-NEXT: %[[constant_29:[^ ]+]] = f16[] constant(0) -// CHECK-NEXT: %[[broadcast_30:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(f16[] %[[constant_29]]), dimensions={} -// CHECK-NEXT: %[[select_42:[^ ]+]] = f16[32,4,4]{2,1,0} select(pred[32,4,4]{2,1,0} %[[and_28]], f16[32,4,4]{2,1,0} %[[divide_41]], f16[32,4,4]{2,1,0} %[[broadcast_30]]) -// CHECK-NEXT: %[[add_43:[^ ]+]] = f16[32,4,4]{2,1,0} add(f16[32,4,4]{2,1,0} %[[select_42]], f16[32,4,4]{2,1,0} %[[get_tuple_element_19]]) -// CHECK-NEXT: %[[get_tuple_element_20:[^ ]+]] = pred[32,1,1]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[parameter_16]]), index=3 -// CHECK-NEXT: %[[compare_37:[^ ]+]] = pred[32,1,1]{2,1,0} compare(f16[32,1,1]{2,1,0} %[[sqrt_36]], f16[32,1,1]{2,1,0} %[[sqrt_36]]), direction=NE -// CHECK-NEXT: %[[or_38:[^ ]+]] = pred[32,1,1]{2,1,0} or(pred[32,1,1]{2,1,0} %[[get_tuple_element_20]], pred[32,1,1]{2,1,0} %[[compare_37]]) -// CHECK-NEXT: ROOT %[[tuple_44:[^ ]+]] = (s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) tuple(s32[] %[[add_22]], f16[32,4,4]{2,1,0} %[[get_tuple_element_18]], f16[32,4,4]{2,1,0} %[[add_43]], pred[32,1,1]{2,1,0} %[[or_38]]) +// CHECK-NEXT: %[[broadcast_30:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(%[[constant_29]]), dimensions={} +// CHECK-NEXT: %[[select_42:[^ ]+]] = f16[32,4,4]{2,1,0} select(%[[and_28]], %[[divide_41]], %[[broadcast_30]]) +// CHECK-NEXT: %[[add_43:[^ ]+]] = f16[32,4,4]{2,1,0} add(%[[select_42]], %[[get_tuple_element_19]]) +// CHECK-NEXT: %[[get_tuple_element_20:[^ ]+]] = pred[32,1,1]{2,1,0} get-tuple-element(%[[parameter_16]]), index=3 +// CHECK-NEXT: %[[compare_37:[^ ]+]] = pred[32,1,1]{2,1,0} compare(%[[sqrt_36]], %[[sqrt_36]]), direction=NE +// CHECK-NEXT: %[[or_38:[^ ]+]] = pred[32,1,1]{2,1,0} or(%[[get_tuple_element_20]], %[[compare_37]]) +// CHECK-NEXT: ROOT %[[tuple_44:[^ ]+]] = (s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) tuple(%[[add_22]], %[[get_tuple_element_18]], %[[add_43]], %[[or_38]]) // CHECK: %[[$unblocked_condition_45:[^ ]+]] // CHECK-NEXT: %[[parameter_46:[^ ]+]] = (s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) parameter(0) -// CHECK-NEXT: %[[get_tuple_element_48:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[parameter_46]]), index=1 -// CHECK-NEXT: %[[get_tuple_element_49:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[parameter_46]]), index=2 -// CHECK-NEXT: %[[get_tuple_element_50:[^ ]+]] = pred[32,1,1]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[parameter_46]]), index=3 -// CHECK-NEXT: %[[get_tuple_element_47:[^ ]+]] = s32[] get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[parameter_46]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_48:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element(%[[parameter_46]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_49:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element(%[[parameter_46]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_50:[^ ]+]] = pred[32,1,1]{2,1,0} get-tuple-element(%[[parameter_46]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_47:[^ ]+]] = s32[] get-tuple-element(%[[parameter_46]]), index=0 // CHECK-NEXT: %[[constant_51:[^ ]+]] = s32[] constant(4) -// CHECK-NEXT: ROOT %[[compare_52:[^ ]+]] = pred[] compare(s32[] %[[get_tuple_element_47]], s32[] %[[constant_51]]), direction=LT +// CHECK-NEXT: ROOT %[[compare_52:[^ ]+]] = pred[] compare(%[[get_tuple_element_47]], %[[constant_51]]), direction=LT // CHECK: %[[$xla_cholesky_f16_32_4_4__upper_70:[^ ]+]] // CHECK-NEXT: %[[constant_13:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[a_1:[^ ]+]] = f16[32,4,4]{2,1,0} parameter(0) -// CHECK-NEXT: %[[transpose_2:[^ ]+]] = f16[32,4,4]{1,2,0} transpose(f16[32,4,4]{2,1,0} %[[a_1]]), dimensions={0,2,1} -// CHECK-NEXT: %[[slice_7:[^ ]+]] = f16[32,4,4]{2,1,0} slice(f16[32,4,4]{1,2,0} %[[transpose_2]]), slice={[0:32], [0:4], [0:4]} -// CHECK-NEXT: %[[slice_8:[^ ]+]] = f16[32,4,4]{2,1,0} slice(f16[32,4,4]{2,1,0} %[[slice_7]]), slice={[0:32], [0:4], [0:4]} +// CHECK-NEXT: %[[transpose_2:[^ ]+]] = f16[32,4,4]{1,2,0} transpose(%[[a_1]]), dimensions={0,2,1} +// CHECK-NEXT: %[[slice_7:[^ ]+]] = f16[32,4,4]{2,1,0} slice(%[[transpose_2]]), slice={[0:32], [0:4], [0:4]} +// CHECK-NEXT: %[[slice_8:[^ ]+]] = f16[32,4,4]{2,1,0} slice(%[[slice_7]]), slice={[0:32], [0:4], [0:4]} // CHECK-NEXT: %[[constant_9:[^ ]+]] = f16[] constant(0) -// CHECK-NEXT: %[[broadcast_10:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(f16[] %[[constant_9]]), dimensions={} +// CHECK-NEXT: %[[broadcast_10:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(%[[constant_9]]), dimensions={} // CHECK-NEXT: %[[constant_11:[^ ]+]] = pred[] constant(false) -// CHECK-NEXT: %[[broadcast_12:[^ ]+]] = pred[32,1,1]{2,1,0} broadcast(pred[] %[[constant_11]]), dimensions={} -// CHECK-NEXT: %[[tuple_14:[^ ]+]] = (s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) tuple(s32[] %[[constant_13]], f16[32,4,4]{2,1,0} %[[slice_8]], f16[32,4,4]{2,1,0} %[[broadcast_10]], pred[32,1,1]{2,1,0} %[[broadcast_12]]) -// CHECK-NEXT: %[[while_53:[^ ]+]] = (s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) while((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[tuple_14]]), condition=%[[$unblocked_condition_45]], body=%[[$unblocked_body_15]] -// CHECK-NEXT: %[[get_tuple_element_54:[^ ]+]] = s32[] get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[while_53]]), index=0 -// CHECK-NEXT: %[[get_tuple_element_55:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[while_53]]), index=1 +// CHECK-NEXT: %[[broadcast_12:[^ ]+]] = pred[32,1,1]{2,1,0} broadcast(%[[constant_11]]), dimensions={} +// CHECK-NEXT: %[[tuple_14:[^ ]+]] = (s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) tuple(%[[constant_13]], %[[slice_8]], %[[broadcast_10]], %[[broadcast_12]]) +// CHECK-NEXT: %[[while_53:[^ ]+]] = (s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) while(%[[tuple_14]]), condition=%[[$unblocked_condition_45]], body=%[[$unblocked_body_15]] +// CHECK-NEXT: %[[get_tuple_element_54:[^ ]+]] = s32[] get-tuple-element(%[[while_53]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_55:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element(%[[while_53]]), index=1 // CHECK-NEXT: %[[constant_5:[^ ]+]] = pred[] constant(false) -// CHECK-NEXT: %[[broadcast_6:[^ ]+]] = pred[32,1,1]{2,1,0} broadcast(pred[] %[[constant_5]]), dimensions={} -// CHECK-NEXT: %[[get_tuple_element_57:[^ ]+]] = pred[32,1,1]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[while_53]]), index=3 -// CHECK-NEXT: %[[or_58:[^ ]+]] = pred[32,1,1]{2,1,0} or(pred[32,1,1]{2,1,0} %[[broadcast_6]], pred[32,1,1]{2,1,0} %[[get_tuple_element_57]]) -// CHECK-NEXT: %[[broadcast_63:[^ ]+]] = pred[32,1,1]{2,1,0} broadcast(pred[32,1,1]{2,1,0} %[[or_58]]), dimensions={0,1,2} -// CHECK-NEXT: %[[reshape_64:[^ ]+]] = pred[32]{0} reshape(pred[32,1,1]{2,1,0} %[[broadcast_63]]) -// CHECK-NEXT: %[[broadcast_65:[^ ]+]] = pred[32,4,4]{2,1,0} broadcast(pred[32]{0} %[[reshape_64]]), dimensions={0} +// CHECK-NEXT: %[[broadcast_6:[^ ]+]] = pred[32,1,1]{2,1,0} broadcast(%[[constant_5]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_57:[^ ]+]] = pred[32,1,1]{2,1,0} get-tuple-element(%[[while_53]]), index=3 +// CHECK-NEXT: %[[or_58:[^ ]+]] = pred[32,1,1]{2,1,0} or(%[[broadcast_6]], %[[get_tuple_element_57]]) +// CHECK-NEXT: %[[broadcast_63:[^ ]+]] = pred[32,1,1]{2,1,0} broadcast(%[[or_58]]), dimensions={0,1,2} +// CHECK-NEXT: %[[reshape_64:[^ ]+]] = pred[32]{0} reshape(%[[broadcast_63]]) +// CHECK-NEXT: %[[broadcast_65:[^ ]+]] = pred[32,4,4]{2,1,0} broadcast(%[[reshape_64]]), dimensions={0} // CHECK-NEXT: %[[constant_66:[^ ]+]] = f16[] constant(nan) -// CHECK-NEXT: %[[broadcast_67:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(f16[] %[[constant_66]]), dimensions={} +// CHECK-NEXT: %[[broadcast_67:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(%[[constant_66]]), dimensions={} // CHECK-NEXT: %[[constant_3:[^ ]+]] = f16[] constant(0) -// CHECK-NEXT: %[[broadcast_4:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(f16[] %[[constant_3]]), dimensions={} -// CHECK-NEXT: %[[get_tuple_element_56:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element((s32[], f16[32,4,4]{2,1,0}, f16[32,4,4]{2,1,0}, pred[32,1,1]{2,1,0}) %[[while_53]]), index=2 +// CHECK-NEXT: %[[broadcast_4:[^ ]+]] = f16[32,4,4]{2,1,0} broadcast(%[[constant_3]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_56:[^ ]+]] = f16[32,4,4]{2,1,0} get-tuple-element(%[[while_53]]), index=2 // CHECK-NEXT: %[[constant_59:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[constant_60:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[constant_61:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[dynamic_update_slice_62:[^ ]+]] = f16[32,4,4]{2,1,0} dynamic-update-slice(f16[32,4,4]{2,1,0} %[[broadcast_4]], f16[32,4,4]{2,1,0} %[[get_tuple_element_56]], s32[] %[[constant_59]], s32[] %[[constant_60]], s32[] %[[constant_61]]) -// CHECK-NEXT: %[[select_68:[^ ]+]] = f16[32,4,4]{2,1,0} select(pred[32,4,4]{2,1,0} %[[broadcast_65]], f16[32,4,4]{2,1,0} %[[broadcast_67]], f16[32,4,4]{2,1,0} %[[dynamic_update_slice_62]]) -// CHECK-NEXT: ROOT %[[transpose_69:[^ ]+]] = f16[32,4,4]{1,2,0} transpose(f16[32,4,4]{2,1,0} %[[select_68]]), dimensions={0,2,1} +// CHECK-NEXT: %[[dynamic_update_slice_62:[^ ]+]] = f16[32,4,4]{2,1,0} dynamic-update-slice(%[[broadcast_4]], %[[get_tuple_element_56]], %[[constant_59]], %[[constant_60]], %[[constant_61]]) +// CHECK-NEXT: %[[select_68:[^ ]+]] = f16[32,4,4]{2,1,0} select(%[[broadcast_65]], %[[broadcast_67]], %[[dynamic_update_slice_62]]) +// CHECK-NEXT: ROOT %[[transpose_69:[^ ]+]] = f16[32,4,4]{1,2,0} transpose(%[[select_68]]), dimensions={0,2,1} // CHECK-LABEL: ENTRY %test // CHECK-NEXT: %[[input:[^ ]+]] = f16[32,4,4]{2,1,0} parameter(0) -// CHECK-NEXT: ROOT %[[call:[^ ]+]] = f16[32,4,4]{2,1,0} call(f16[32,4,4]{2,1,0} %[[input]]), to_apply=%[[$xla_cholesky_f16_32_4_4__upper_70]] +// CHECK-NEXT: ROOT %[[call:[^ ]+]] = f16[32,4,4]{2,1,0} call(%[[input]]), to_apply=%[[$xla_cholesky_f16_32_4_4__upper_70]] HloModule CholeskyExpanderTest @@ -93,88 +93,88 @@ ENTRY test { // CHECK: %[[$unblocked_body_15:[^ ]+]] // CHECK-NEXT: %[[parameter_16:[^ ]+]] = (s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) parameter(0) -// CHECK-NEXT: %[[get_tuple_element_17:[^ ]+]] = s32[] get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[parameter_16]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_17:[^ ]+]] = s32[] get-tuple-element(%[[parameter_16]]), index=0 // CHECK-NEXT: %[[constant_21:[^ ]+]] = s32[] constant(1) -// CHECK-NEXT: %[[add_22:[^ ]+]] = s32[] add(s32[] %[[get_tuple_element_17]], s32[] %[[constant_21]]) -// CHECK-NEXT: %[[get_tuple_element_18:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[parameter_16]]), index=1 +// CHECK-NEXT: %[[add_22:[^ ]+]] = s32[] add(%[[get_tuple_element_17]], %[[constant_21]]) +// CHECK-NEXT: %[[get_tuple_element_18:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element(%[[parameter_16]]), index=1 // CHECK-NEXT: %[[iota_24:[^ ]+]] = s32[4,8,8]{2,1,0} iota(), iota_dimension=1 // CHECK-NEXT: %[[iota_23:[^ ]+]] = s32[4,8,8]{2,1,0} iota(), iota_dimension=2 -// CHECK-NEXT: %[[compare_25:[^ ]+]] = pred[4,8,8]{2,1,0} compare(s32[4,8,8]{2,1,0} %[[iota_24]], s32[4,8,8]{2,1,0} %[[iota_23]]), direction=GE -// CHECK-NEXT: %[[broadcast_26:[^ ]+]] = s32[4,8,8]{2,1,0} broadcast(s32[] %[[get_tuple_element_17]]), dimensions={} -// CHECK-NEXT: %[[compare_27:[^ ]+]] = pred[4,8,8]{2,1,0} compare(s32[4,8,8]{2,1,0} %[[iota_23]], s32[4,8,8]{2,1,0} %[[broadcast_26]]), direction=EQ -// CHECK-NEXT: %[[and_28:[^ ]+]] = pred[4,8,8]{2,1,0} and(pred[4,8,8]{2,1,0} %[[compare_25]], pred[4,8,8]{2,1,0} %[[compare_27]]) -// CHECK-NEXT: %[[get_tuple_element_19:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[parameter_16]]), index=2 -// CHECK-NEXT: %[[real_31:[^ ]+]] = f32[4,8,8]{2,1,0} real(c64[4,8,8]{2,1,0} %[[get_tuple_element_19]]) -// CHECK-NEXT: %[[imag_32:[^ ]+]] = f32[4,8,8]{2,1,0} imag(c64[4,8,8]{2,1,0} %[[get_tuple_element_19]]) -// CHECK-NEXT: %[[negate_33:[^ ]+]] = f32[4,8,8]{2,1,0} negate(f32[4,8,8]{2,1,0} %[[imag_32]]) -// CHECK-NEXT: %[[complex_34:[^ ]+]] = c64[4,8,8]{2,1,0} complex(f32[4,8,8]{2,1,0} %[[real_31]], f32[4,8,8]{2,1,0} %[[negate_33]]) -// CHECK-NEXT: %[[dot_35:[^ ]+]] = c64[4,8,8]{2,1,0} dot(c64[4,8,8]{2,1,0} %[[get_tuple_element_19]], c64[4,8,8]{2,1,0} %[[complex_34]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} -// CHECK-NEXT: %[[transpose_36:[^ ]+]] = c64[4,8,8]{2,1,0} transpose(c64[4,8,8]{2,1,0} %[[dot_35]]), dimensions={0,1,2} -// CHECK-NEXT: %[[subtract_37:[^ ]+]] = c64[4,8,8]{2,1,0} subtract(c64[4,8,8]{2,1,0} %[[get_tuple_element_18]], c64[4,8,8]{2,1,0} %[[transpose_36]]) +// CHECK-NEXT: %[[compare_25:[^ ]+]] = pred[4,8,8]{2,1,0} compare(%[[iota_24]], %[[iota_23]]), direction=GE +// CHECK-NEXT: %[[broadcast_26:[^ ]+]] = s32[4,8,8]{2,1,0} broadcast(%[[get_tuple_element_17]]), dimensions={} +// CHECK-NEXT: %[[compare_27:[^ ]+]] = pred[4,8,8]{2,1,0} compare(%[[iota_23]], %[[broadcast_26]]), direction=EQ +// CHECK-NEXT: %[[and_28:[^ ]+]] = pred[4,8,8]{2,1,0} and(%[[compare_25]], %[[compare_27]]) +// CHECK-NEXT: %[[get_tuple_element_19:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element(%[[parameter_16]]), index=2 +// CHECK-NEXT: %[[real_31:[^ ]+]] = f32[4,8,8]{2,1,0} real(%[[get_tuple_element_19]]) +// CHECK-NEXT: %[[imag_32:[^ ]+]] = f32[4,8,8]{2,1,0} imag(%[[get_tuple_element_19]]) +// CHECK-NEXT: %[[negate_33:[^ ]+]] = f32[4,8,8]{2,1,0} negate(%[[imag_32]]) +// CHECK-NEXT: %[[complex_34:[^ ]+]] = c64[4,8,8]{2,1,0} complex(%[[real_31]], %[[negate_33]]) +// CHECK-NEXT: %[[dot_35:[^ ]+]] = c64[4,8,8]{2,1,0} dot(%[[get_tuple_element_19]], %[[complex_34]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} +// CHECK-NEXT: %[[transpose_36:[^ ]+]] = c64[4,8,8]{2,1,0} transpose(%[[dot_35]]), dimensions={0,1,2} +// CHECK-NEXT: %[[subtract_37:[^ ]+]] = c64[4,8,8]{2,1,0} subtract(%[[get_tuple_element_18]], %[[transpose_36]]) // CHECK-NEXT: %[[constant_38:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[dynamic_slice_39:[^ ]+]] = c64[4,1,1]{2,1,0} dynamic-slice(c64[4,8,8]{2,1,0} %[[subtract_37]], s32[] %[[constant_38]], s32[] %[[get_tuple_element_17]], s32[] %[[get_tuple_element_17]]), dynamic_slice_sizes={4,1,1} -// CHECK-NEXT: %[[real_40:[^ ]+]] = f32[4,1,1]{2,1,0} real(c64[4,1,1]{2,1,0} %[[dynamic_slice_39]]) -// CHECK-NEXT: %[[sqrt_41:[^ ]+]] = f32[4,1,1]{2,1,0} sqrt(f32[4,1,1]{2,1,0} %[[real_40]]) +// CHECK-NEXT: %[[dynamic_slice_39:[^ ]+]] = c64[4,1,1]{2,1,0} dynamic-slice(%[[subtract_37]], %[[constant_38]], %[[get_tuple_element_17]], %[[get_tuple_element_17]]), dynamic_slice_sizes={4,1,1} +// CHECK-NEXT: %[[real_40:[^ ]+]] = f32[4,1,1]{2,1,0} real(%[[dynamic_slice_39]]) +// CHECK-NEXT: %[[sqrt_41:[^ ]+]] = f32[4,1,1]{2,1,0} sqrt(%[[real_40]]) // CHECK-NEXT: %[[constant_42:[^ ]+]] = f32[] constant(0) -// CHECK-NEXT: %[[broadcast_43:[^ ]+]] = f32[4,1,1]{2,1,0} broadcast(f32[] %[[constant_42]]), dimensions={} -// CHECK-NEXT: %[[complex_44:[^ ]+]] = c64[4,1,1]{2,1,0} complex(f32[4,1,1]{2,1,0} %[[sqrt_41]], f32[4,1,1]{2,1,0} %[[broadcast_43]]) -// CHECK-NEXT: %[[reshape_47:[^ ]+]] = c64[4]{0} reshape(c64[4,1,1]{2,1,0} %[[complex_44]]) -// CHECK-NEXT: %[[broadcast_48:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(c64[4]{0} %[[reshape_47]]), dimensions={0} -// CHECK-NEXT: %[[divide_49:[^ ]+]] = c64[4,8,8]{2,1,0} divide(c64[4,8,8]{2,1,0} %[[subtract_37]], c64[4,8,8]{2,1,0} %[[broadcast_48]]) +// CHECK-NEXT: %[[broadcast_43:[^ ]+]] = f32[4,1,1]{2,1,0} broadcast(%[[constant_42]]), dimensions={} +// CHECK-NEXT: %[[complex_44:[^ ]+]] = c64[4,1,1]{2,1,0} complex(%[[sqrt_41]], %[[broadcast_43]]) +// CHECK-NEXT: %[[reshape_47:[^ ]+]] = c64[4]{0} reshape(%[[complex_44]]) +// CHECK-NEXT: %[[broadcast_48:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(%[[reshape_47]]), dimensions={0} +// CHECK-NEXT: %[[divide_49:[^ ]+]] = c64[4,8,8]{2,1,0} divide(%[[subtract_37]], %[[broadcast_48]]) // CHECK-NEXT: %[[constant_29:[^ ]+]] = c64[] constant((0, 0)) -// CHECK-NEXT: %[[broadcast_30:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(c64[] %[[constant_29]]), dimensions={} -// CHECK-NEXT: %[[select_50:[^ ]+]] = c64[4,8,8]{2,1,0} select(pred[4,8,8]{2,1,0} %[[and_28]], c64[4,8,8]{2,1,0} %[[divide_49]], c64[4,8,8]{2,1,0} %[[broadcast_30]]) -// CHECK-NEXT: %[[add_51:[^ ]+]] = c64[4,8,8]{2,1,0} add(c64[4,8,8]{2,1,0} %[[select_50]], c64[4,8,8]{2,1,0} %[[get_tuple_element_19]]) -// CHECK-NEXT: %[[get_tuple_element_20:[^ ]+]] = pred[4,1,1]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[parameter_16]]), index=3 -// CHECK-NEXT: %[[compare_45:[^ ]+]] = pred[4,1,1]{2,1,0} compare(f32[4,1,1]{2,1,0} %[[sqrt_41]], f32[4,1,1]{2,1,0} %[[sqrt_41]]), direction=NE -// CHECK-NEXT: %[[or_46:[^ ]+]] = pred[4,1,1]{2,1,0} or(pred[4,1,1]{2,1,0} %[[get_tuple_element_20]], pred[4,1,1]{2,1,0} %[[compare_45]]) -// CHECK-NEXT: ROOT %[[tuple_52:[^ ]+]] = (s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) tuple(s32[] %[[add_22]], c64[4,8,8]{2,1,0} %[[get_tuple_element_18]], c64[4,8,8]{2,1,0} %[[add_51]], pred[4,1,1]{2,1,0} %[[or_46]]) +// CHECK-NEXT: %[[broadcast_30:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(%[[constant_29]]), dimensions={} +// CHECK-NEXT: %[[select_50:[^ ]+]] = c64[4,8,8]{2,1,0} select(%[[and_28]], %[[divide_49]], %[[broadcast_30]]) +// CHECK-NEXT: %[[add_51:[^ ]+]] = c64[4,8,8]{2,1,0} add(%[[select_50]], %[[get_tuple_element_19]]) +// CHECK-NEXT: %[[get_tuple_element_20:[^ ]+]] = pred[4,1,1]{2,1,0} get-tuple-element(%[[parameter_16]]), index=3 +// CHECK-NEXT: %[[compare_45:[^ ]+]] = pred[4,1,1]{2,1,0} compare(%[[sqrt_41]], %[[sqrt_41]]), direction=NE +// CHECK-NEXT: %[[or_46:[^ ]+]] = pred[4,1,1]{2,1,0} or(%[[get_tuple_element_20]], %[[compare_45]]) +// CHECK-NEXT: ROOT %[[tuple_52:[^ ]+]] = (s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) tuple(%[[add_22]], %[[get_tuple_element_18]], %[[add_51]], %[[or_46]]) // CHECK: %[[$unblocked_condition_53:[^ ]+]] // CHECK-NEXT: %[[parameter_54:[^ ]+]] = (s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) parameter(0) -// CHECK-NEXT: %[[get_tuple_element_56:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[parameter_54]]), index=1 -// CHECK-NEXT: %[[get_tuple_element_57:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[parameter_54]]), index=2 -// CHECK-NEXT: %[[get_tuple_element_58:[^ ]+]] = pred[4,1,1]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[parameter_54]]), index=3 -// CHECK-NEXT: %[[get_tuple_element_55:[^ ]+]] = s32[] get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[parameter_54]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_56:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element(%[[parameter_54]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_57:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element(%[[parameter_54]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_58:[^ ]+]] = pred[4,1,1]{2,1,0} get-tuple-element(%[[parameter_54]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_55:[^ ]+]] = s32[] get-tuple-element(%[[parameter_54]]), index=0 // CHECK-NEXT: %[[constant_59:[^ ]+]] = s32[] constant(8) -// CHECK-NEXT: ROOT %[[compare_60:[^ ]+]] = pred[] compare(s32[] %[[get_tuple_element_55]], s32[] %[[constant_59]]), direction=LT +// CHECK-NEXT: ROOT %[[compare_60:[^ ]+]] = pred[] compare(%[[get_tuple_element_55]], %[[constant_59]]), direction=LT // CHECK: %[[$xla_cholesky_c64_4_8_8__upper_78:[^ ]+]] // CHECK-NEXT: %[[constant_13:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[a_1:[^ ]+]] = c64[4,8,8]{2,1,0} parameter(0) -// CHECK-NEXT: %[[transpose_2:[^ ]+]] = c64[4,8,8]{1,2,0} transpose(c64[4,8,8]{2,1,0} %[[a_1]]), dimensions={0,2,1} -// CHECK-NEXT: %[[slice_7:[^ ]+]] = c64[4,8,8]{2,1,0} slice(c64[4,8,8]{1,2,0} %[[transpose_2]]), slice={[0:4], [0:8], [0:8]} -// CHECK-NEXT: %[[slice_8:[^ ]+]] = c64[4,8,8]{2,1,0} slice(c64[4,8,8]{2,1,0} %[[slice_7]]), slice={[0:4], [0:8], [0:8]} +// CHECK-NEXT: %[[transpose_2:[^ ]+]] = c64[4,8,8]{1,2,0} transpose(%[[a_1]]), dimensions={0,2,1} +// CHECK-NEXT: %[[slice_7:[^ ]+]] = c64[4,8,8]{2,1,0} slice(%[[transpose_2]]), slice={[0:4], [0:8], [0:8]} +// CHECK-NEXT: %[[slice_8:[^ ]+]] = c64[4,8,8]{2,1,0} slice(%[[slice_7]]), slice={[0:4], [0:8], [0:8]} // CHECK-NEXT: %[[constant_9:[^ ]+]] = c64[] constant((0, 0)) -// CHECK-NEXT: %[[broadcast_10:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(c64[] %[[constant_9]]), dimensions={} +// CHECK-NEXT: %[[broadcast_10:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(%[[constant_9]]), dimensions={} // CHECK-NEXT: %[[constant_11:[^ ]+]] = pred[] constant(false) -// CHECK-NEXT: %[[broadcast_12:[^ ]+]] = pred[4,1,1]{2,1,0} broadcast(pred[] %[[constant_11]]), dimensions={} -// CHECK-NEXT: %[[tuple_14:[^ ]+]] = (s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) tuple(s32[] %[[constant_13]], c64[4,8,8]{2,1,0} %[[slice_8]], c64[4,8,8]{2,1,0} %[[broadcast_10]], pred[4,1,1]{2,1,0} %[[broadcast_12]]) -// CHECK-NEXT: %[[while_61:[^ ]+]] = (s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) while((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[tuple_14]]), condition=%[[$unblocked_condition_53]], body=%[[$unblocked_body_15]] -// CHECK-NEXT: %[[get_tuple_element_62:[^ ]+]] = s32[] get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[while_61]]), index=0 -// CHECK-NEXT: %[[get_tuple_element_63:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[while_61]]), index=1 +// CHECK-NEXT: %[[broadcast_12:[^ ]+]] = pred[4,1,1]{2,1,0} broadcast(%[[constant_11]]), dimensions={} +// CHECK-NEXT: %[[tuple_14:[^ ]+]] = (s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) tuple(%[[constant_13]], %[[slice_8]], %[[broadcast_10]], %[[broadcast_12]]) +// CHECK-NEXT: %[[while_61:[^ ]+]] = (s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) while(%[[tuple_14]]), condition=%[[$unblocked_condition_53]], body=%[[$unblocked_body_15]] +// CHECK-NEXT: %[[get_tuple_element_62:[^ ]+]] = s32[] get-tuple-element(%[[while_61]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_63:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element(%[[while_61]]), index=1 // CHECK-NEXT: %[[constant_5:[^ ]+]] = pred[] constant(false) -// CHECK-NEXT: %[[broadcast_6:[^ ]+]] = pred[4,1,1]{2,1,0} broadcast(pred[] %[[constant_5]]), dimensions={} -// CHECK-NEXT: %[[get_tuple_element_65:[^ ]+]] = pred[4,1,1]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[while_61]]), index=3 -// CHECK-NEXT: %[[or_66:[^ ]+]] = pred[4,1,1]{2,1,0} or(pred[4,1,1]{2,1,0} %[[broadcast_6]], pred[4,1,1]{2,1,0} %[[get_tuple_element_65]]) -// CHECK-NEXT: %[[broadcast_71:[^ ]+]] = pred[4,1,1]{2,1,0} broadcast(pred[4,1,1]{2,1,0} %[[or_66]]), dimensions={0,1,2} -// CHECK-NEXT: %[[reshape_72:[^ ]+]] = pred[4]{0} reshape(pred[4,1,1]{2,1,0} %[[broadcast_71]]) -// CHECK-NEXT: %[[broadcast_73:[^ ]+]] = pred[4,8,8]{2,1,0} broadcast(pred[4]{0} %[[reshape_72]]), dimensions={0} +// CHECK-NEXT: %[[broadcast_6:[^ ]+]] = pred[4,1,1]{2,1,0} broadcast(%[[constant_5]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_65:[^ ]+]] = pred[4,1,1]{2,1,0} get-tuple-element(%[[while_61]]), index=3 +// CHECK-NEXT: %[[or_66:[^ ]+]] = pred[4,1,1]{2,1,0} or(%[[broadcast_6]], %[[get_tuple_element_65]]) +// CHECK-NEXT: %[[broadcast_71:[^ ]+]] = pred[4,1,1]{2,1,0} broadcast(%[[or_66]]), dimensions={0,1,2} +// CHECK-NEXT: %[[reshape_72:[^ ]+]] = pred[4]{0} reshape(%[[broadcast_71]]) +// CHECK-NEXT: %[[broadcast_73:[^ ]+]] = pred[4,8,8]{2,1,0} broadcast(%[[reshape_72]]), dimensions={0} // CHECK-NEXT: %[[constant_74:[^ ]+]] = c64[] constant((nan, 0)) -// CHECK-NEXT: %[[broadcast_75:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(c64[] %[[constant_74]]), dimensions={} +// CHECK-NEXT: %[[broadcast_75:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(%[[constant_74]]), dimensions={} // CHECK-NEXT: %[[constant_3:[^ ]+]] = c64[] constant((0, 0)) -// CHECK-NEXT: %[[broadcast_4:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(c64[] %[[constant_3]]), dimensions={} -// CHECK-NEXT: %[[get_tuple_element_64:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element((s32[], c64[4,8,8]{2,1,0}, c64[4,8,8]{2,1,0}, pred[4,1,1]{2,1,0}) %[[while_61]]), index=2 +// CHECK-NEXT: %[[broadcast_4:[^ ]+]] = c64[4,8,8]{2,1,0} broadcast(%[[constant_3]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_64:[^ ]+]] = c64[4,8,8]{2,1,0} get-tuple-element(%[[while_61]]), index=2 // CHECK-NEXT: %[[constant_67:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[constant_68:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[constant_69:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[dynamic_update_slice_70:[^ ]+]] = c64[4,8,8]{2,1,0} dynamic-update-slice(c64[4,8,8]{2,1,0} %[[broadcast_4]], c64[4,8,8]{2,1,0} %[[get_tuple_element_64]], s32[] %[[constant_67]], s32[] %[[constant_68]], s32[] %[[constant_69]]) -// CHECK-NEXT: %[[select_76:[^ ]+]] = c64[4,8,8]{2,1,0} select(pred[4,8,8]{2,1,0} %[[broadcast_73]], c64[4,8,8]{2,1,0} %[[broadcast_75]], c64[4,8,8]{2,1,0} %[[dynamic_update_slice_70]]) -// CHECK-NEXT: ROOT %[[transpose_77:[^ ]+]] = c64[4,8,8]{1,2,0} transpose(c64[4,8,8]{2,1,0} %[[select_76]]), dimensions={0,2,1} +// CHECK-NEXT: %[[dynamic_update_slice_70:[^ ]+]] = c64[4,8,8]{2,1,0} dynamic-update-slice(%[[broadcast_4]], %[[get_tuple_element_64]], %[[constant_67]], %[[constant_68]], %[[constant_69]]) +// CHECK-NEXT: %[[select_76:[^ ]+]] = c64[4,8,8]{2,1,0} select(%[[broadcast_73]], %[[broadcast_75]], %[[dynamic_update_slice_70]]) +// CHECK-NEXT: ROOT %[[transpose_77:[^ ]+]] = c64[4,8,8]{1,2,0} transpose(%[[select_76]]), dimensions={0,2,1} // CHECK-LABEL: ENTRY %test // CHECK-NEXT: %[[input:[^ ]+]] = c64[4,8,8]{2,1,0} parameter(0) -// CHECK-NEXT: ROOT %[[call:[^ ]+]] = c64[4,8,8]{2,1,0} call(c64[4,8,8]{2,1,0} %[[input]]), to_apply=%[[$xla_cholesky_c64_4_8_8__upper_78]] +// CHECK-NEXT: ROOT %[[call:[^ ]+]] = c64[4,8,8]{2,1,0} call(%[[input]]), to_apply=%[[$xla_cholesky_c64_4_8_8__upper_78]] HloModule CholeskyExpanderTest @@ -189,149 +189,149 @@ ENTRY test { // CHECK: %[[$unblocked_body_15:[^ ]+]] // CHECK-NEXT: %[[parameter_16:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) parameter(0) -// CHECK-NEXT: %[[get_tuple_element_17:[^ ]+]] = s32[] get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_16]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_17:[^ ]+]] = s32[] get-tuple-element(%[[parameter_16]]), index=0 // CHECK-NEXT: %[[constant_21:[^ ]+]] = s32[] constant(1) -// CHECK-NEXT: %[[add_22:[^ ]+]] = s32[] add(s32[] %[[get_tuple_element_17]], s32[] %[[constant_21]]) -// CHECK-NEXT: %[[get_tuple_element_18:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_16]]), index=1 +// CHECK-NEXT: %[[add_22:[^ ]+]] = s32[] add(%[[get_tuple_element_17]], %[[constant_21]]) +// CHECK-NEXT: %[[get_tuple_element_18:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[parameter_16]]), index=1 // CHECK-NEXT: %[[iota_24:[^ ]+]] = s32[1,128,128]{2,1,0} iota(), iota_dimension=1 // CHECK-NEXT: %[[iota_23:[^ ]+]] = s32[1,128,128]{2,1,0} iota(), iota_dimension=2 -// CHECK-NEXT: %[[compare_25:[^ ]+]] = pred[1,128,128]{2,1,0} compare(s32[1,128,128]{2,1,0} %[[iota_24]], s32[1,128,128]{2,1,0} %[[iota_23]]), direction=GE -// CHECK-NEXT: %[[broadcast_26:[^ ]+]] = s32[1,128,128]{2,1,0} broadcast(s32[] %[[get_tuple_element_17]]), dimensions={} -// CHECK-NEXT: %[[compare_27:[^ ]+]] = pred[1,128,128]{2,1,0} compare(s32[1,128,128]{2,1,0} %[[iota_23]], s32[1,128,128]{2,1,0} %[[broadcast_26]]), direction=EQ -// CHECK-NEXT: %[[and_28:[^ ]+]] = pred[1,128,128]{2,1,0} and(pred[1,128,128]{2,1,0} %[[compare_25]], pred[1,128,128]{2,1,0} %[[compare_27]]) -// CHECK-NEXT: %[[get_tuple_element_19:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_16]]), index=2 -// CHECK-NEXT: %[[dot_31:[^ ]+]] = f32[1,128,128]{2,1,0} dot(f32[1,128,128]{2,1,0} %[[get_tuple_element_19]], f32[1,128,128]{2,1,0} %[[get_tuple_element_19]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} -// CHECK-NEXT: %[[transpose_32:[^ ]+]] = f32[1,128,128]{2,1,0} transpose(f32[1,128,128]{2,1,0} %[[dot_31]]), dimensions={0,1,2} -// CHECK-NEXT: %[[subtract_33:[^ ]+]] = f32[1,128,128]{2,1,0} subtract(f32[1,128,128]{2,1,0} %[[get_tuple_element_18]], f32[1,128,128]{2,1,0} %[[transpose_32]]) +// CHECK-NEXT: %[[compare_25:[^ ]+]] = pred[1,128,128]{2,1,0} compare(%[[iota_24]], %[[iota_23]]), direction=GE +// CHECK-NEXT: %[[broadcast_26:[^ ]+]] = s32[1,128,128]{2,1,0} broadcast(%[[get_tuple_element_17]]), dimensions={} +// CHECK-NEXT: %[[compare_27:[^ ]+]] = pred[1,128,128]{2,1,0} compare(%[[iota_23]], %[[broadcast_26]]), direction=EQ +// CHECK-NEXT: %[[and_28:[^ ]+]] = pred[1,128,128]{2,1,0} and(%[[compare_25]], %[[compare_27]]) +// CHECK-NEXT: %[[get_tuple_element_19:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[parameter_16]]), index=2 +// CHECK-NEXT: %[[dot_31:[^ ]+]] = f32[1,128,128]{2,1,0} dot(%[[get_tuple_element_19]], %[[get_tuple_element_19]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} +// CHECK-NEXT: %[[transpose_32:[^ ]+]] = f32[1,128,128]{2,1,0} transpose(%[[dot_31]]), dimensions={0,1,2} +// CHECK-NEXT: %[[subtract_33:[^ ]+]] = f32[1,128,128]{2,1,0} subtract(%[[get_tuple_element_18]], %[[transpose_32]]) // CHECK-NEXT: %[[constant_34:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[dynamic_slice_35:[^ ]+]] = f32[1,1,1]{2,1,0} dynamic-slice(f32[1,128,128]{2,1,0} %[[subtract_33]], s32[] %[[constant_34]], s32[] %[[get_tuple_element_17]], s32[] %[[get_tuple_element_17]]), dynamic_slice_sizes={1,1,1} -// CHECK-NEXT: %[[sqrt_36:[^ ]+]] = f32[1,1,1]{2,1,0} sqrt(f32[1,1,1]{2,1,0} %[[dynamic_slice_35]]) -// CHECK-NEXT: %[[reshape_39:[^ ]+]] = f32[1]{0} reshape(f32[1,1,1]{2,1,0} %[[sqrt_36]]) -// CHECK-NEXT: %[[broadcast_40:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(f32[1]{0} %[[reshape_39]]), dimensions={0} -// CHECK-NEXT: %[[divide_41:[^ ]+]] = f32[1,128,128]{2,1,0} divide(f32[1,128,128]{2,1,0} %[[subtract_33]], f32[1,128,128]{2,1,0} %[[broadcast_40]]) +// CHECK-NEXT: %[[dynamic_slice_35:[^ ]+]] = f32[1,1,1]{2,1,0} dynamic-slice(%[[subtract_33]], %[[constant_34]], %[[get_tuple_element_17]], %[[get_tuple_element_17]]), dynamic_slice_sizes={1,1,1} +// CHECK-NEXT: %[[sqrt_36:[^ ]+]] = f32[1,1,1]{2,1,0} sqrt(%[[dynamic_slice_35]]) +// CHECK-NEXT: %[[reshape_39:[^ ]+]] = f32[1]{0} reshape(%[[sqrt_36]]) +// CHECK-NEXT: %[[broadcast_40:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(%[[reshape_39]]), dimensions={0} +// CHECK-NEXT: %[[divide_41:[^ ]+]] = f32[1,128,128]{2,1,0} divide(%[[subtract_33]], %[[broadcast_40]]) // CHECK-NEXT: %[[constant_29:[^ ]+]] = f32[] constant(0) -// CHECK-NEXT: %[[broadcast_30:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(f32[] %[[constant_29]]), dimensions={} -// CHECK-NEXT: %[[select_42:[^ ]+]] = f32[1,128,128]{2,1,0} select(pred[1,128,128]{2,1,0} %[[and_28]], f32[1,128,128]{2,1,0} %[[divide_41]], f32[1,128,128]{2,1,0} %[[broadcast_30]]) -// CHECK-NEXT: %[[add_43:[^ ]+]] = f32[1,128,128]{2,1,0} add(f32[1,128,128]{2,1,0} %[[select_42]], f32[1,128,128]{2,1,0} %[[get_tuple_element_19]]) -// CHECK-NEXT: %[[get_tuple_element_20:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_16]]), index=3 -// CHECK-NEXT: %[[compare_37:[^ ]+]] = pred[1,1,1]{2,1,0} compare(f32[1,1,1]{2,1,0} %[[sqrt_36]], f32[1,1,1]{2,1,0} %[[sqrt_36]]), direction=NE -// CHECK-NEXT: %[[or_38:[^ ]+]] = pred[1,1,1]{2,1,0} or(pred[1,1,1]{2,1,0} %[[get_tuple_element_20]], pred[1,1,1]{2,1,0} %[[compare_37]]) -// CHECK-NEXT: ROOT %[[tuple_44:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) tuple(s32[] %[[add_22]], f32[1,128,128]{2,1,0} %[[get_tuple_element_18]], f32[1,128,128]{2,1,0} %[[add_43]], pred[1,1,1]{2,1,0} %[[or_38]]) +// CHECK-NEXT: %[[broadcast_30:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(%[[constant_29]]), dimensions={} +// CHECK-NEXT: %[[select_42:[^ ]+]] = f32[1,128,128]{2,1,0} select(%[[and_28]], %[[divide_41]], %[[broadcast_30]]) +// CHECK-NEXT: %[[add_43:[^ ]+]] = f32[1,128,128]{2,1,0} add(%[[select_42]], %[[get_tuple_element_19]]) +// CHECK-NEXT: %[[get_tuple_element_20:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element(%[[parameter_16]]), index=3 +// CHECK-NEXT: %[[compare_37:[^ ]+]] = pred[1,1,1]{2,1,0} compare(%[[sqrt_36]], %[[sqrt_36]]), direction=NE +// CHECK-NEXT: %[[or_38:[^ ]+]] = pred[1,1,1]{2,1,0} or(%[[get_tuple_element_20]], %[[compare_37]]) +// CHECK-NEXT: ROOT %[[tuple_44:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) tuple(%[[add_22]], %[[get_tuple_element_18]], %[[add_43]], %[[or_38]]) // CHECK: %[[$unblocked_condition_45:[^ ]+]] // CHECK-NEXT: %[[parameter_46:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) parameter(0) -// CHECK-NEXT: %[[get_tuple_element_48:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_46]]), index=1 -// CHECK-NEXT: %[[get_tuple_element_49:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_46]]), index=2 -// CHECK-NEXT: %[[get_tuple_element_50:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_46]]), index=3 -// CHECK-NEXT: %[[get_tuple_element_47:[^ ]+]] = s32[] get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_46]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_48:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[parameter_46]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_49:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[parameter_46]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_50:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element(%[[parameter_46]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_47:[^ ]+]] = s32[] get-tuple-element(%[[parameter_46]]), index=0 // CHECK-NEXT: %[[constant_51:[^ ]+]] = s32[] constant(128) -// CHECK-NEXT: ROOT %[[compare_52:[^ ]+]] = pred[] compare(s32[] %[[get_tuple_element_47]], s32[] %[[constant_51]]), direction=LT +// CHECK-NEXT: ROOT %[[compare_52:[^ ]+]] = pred[] compare(%[[get_tuple_element_47]], %[[constant_51]]), direction=LT // CHECK: %[[$unblocked_body_82:[^ ]+]] // CHECK-NEXT: %[[parameter_83:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) parameter(0) -// CHECK-NEXT: %[[get_tuple_element_84:[^ ]+]] = s32[] get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_83]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_84:[^ ]+]] = s32[] get-tuple-element(%[[parameter_83]]), index=0 // CHECK-NEXT: %[[constant_88:[^ ]+]] = s32[] constant(1) -// CHECK-NEXT: %[[add_89:[^ ]+]] = s32[] add(s32[] %[[get_tuple_element_84]], s32[] %[[constant_88]]) -// CHECK-NEXT: %[[get_tuple_element_85:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_83]]), index=1 +// CHECK-NEXT: %[[add_89:[^ ]+]] = s32[] add(%[[get_tuple_element_84]], %[[constant_88]]) +// CHECK-NEXT: %[[get_tuple_element_85:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[parameter_83]]), index=1 // CHECK-NEXT: %[[iota_91:[^ ]+]] = s32[1,128,128]{2,1,0} iota(), iota_dimension=1 // CHECK-NEXT: %[[iota_90:[^ ]+]] = s32[1,128,128]{2,1,0} iota(), iota_dimension=2 -// CHECK-NEXT: %[[compare_92:[^ ]+]] = pred[1,128,128]{2,1,0} compare(s32[1,128,128]{2,1,0} %[[iota_91]], s32[1,128,128]{2,1,0} %[[iota_90]]), direction=GE -// CHECK-NEXT: %[[broadcast_93:[^ ]+]] = s32[1,128,128]{2,1,0} broadcast(s32[] %[[get_tuple_element_84]]), dimensions={} -// CHECK-NEXT: %[[compare_94:[^ ]+]] = pred[1,128,128]{2,1,0} compare(s32[1,128,128]{2,1,0} %[[iota_90]], s32[1,128,128]{2,1,0} %[[broadcast_93]]), direction=EQ -// CHECK-NEXT: %[[and_95:[^ ]+]] = pred[1,128,128]{2,1,0} and(pred[1,128,128]{2,1,0} %[[compare_92]], pred[1,128,128]{2,1,0} %[[compare_94]]) -// CHECK-NEXT: %[[get_tuple_element_86:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_83]]), index=2 -// CHECK-NEXT: %[[dot_98:[^ ]+]] = f32[1,128,128]{2,1,0} dot(f32[1,128,128]{2,1,0} %[[get_tuple_element_86]], f32[1,128,128]{2,1,0} %[[get_tuple_element_86]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} -// CHECK-NEXT: %[[transpose_99:[^ ]+]] = f32[1,128,128]{2,1,0} transpose(f32[1,128,128]{2,1,0} %[[dot_98]]), dimensions={0,1,2} -// CHECK-NEXT: %[[subtract_100:[^ ]+]] = f32[1,128,128]{2,1,0} subtract(f32[1,128,128]{2,1,0} %[[get_tuple_element_85]], f32[1,128,128]{2,1,0} %[[transpose_99]]) +// CHECK-NEXT: %[[compare_92:[^ ]+]] = pred[1,128,128]{2,1,0} compare(%[[iota_91]], %[[iota_90]]), direction=GE +// CHECK-NEXT: %[[broadcast_93:[^ ]+]] = s32[1,128,128]{2,1,0} broadcast(%[[get_tuple_element_84]]), dimensions={} +// CHECK-NEXT: %[[compare_94:[^ ]+]] = pred[1,128,128]{2,1,0} compare(%[[iota_90]], %[[broadcast_93]]), direction=EQ +// CHECK-NEXT: %[[and_95:[^ ]+]] = pred[1,128,128]{2,1,0} and(%[[compare_92]], %[[compare_94]]) +// CHECK-NEXT: %[[get_tuple_element_86:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[parameter_83]]), index=2 +// CHECK-NEXT: %[[dot_98:[^ ]+]] = f32[1,128,128]{2,1,0} dot(%[[get_tuple_element_86]], %[[get_tuple_element_86]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} +// CHECK-NEXT: %[[transpose_99:[^ ]+]] = f32[1,128,128]{2,1,0} transpose(%[[dot_98]]), dimensions={0,1,2} +// CHECK-NEXT: %[[subtract_100:[^ ]+]] = f32[1,128,128]{2,1,0} subtract(%[[get_tuple_element_85]], %[[transpose_99]]) // CHECK-NEXT: %[[constant_101:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[dynamic_slice_102:[^ ]+]] = f32[1,1,1]{2,1,0} dynamic-slice(f32[1,128,128]{2,1,0} %[[subtract_100]], s32[] %[[constant_101]], s32[] %[[get_tuple_element_84]], s32[] %[[get_tuple_element_84]]), dynamic_slice_sizes={1,1,1} -// CHECK-NEXT: %[[sqrt_103:[^ ]+]] = f32[1,1,1]{2,1,0} sqrt(f32[1,1,1]{2,1,0} %[[dynamic_slice_102]]) -// CHECK-NEXT: %[[reshape_106:[^ ]+]] = f32[1]{0} reshape(f32[1,1,1]{2,1,0} %[[sqrt_103]]) -// CHECK-NEXT: %[[broadcast_107:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(f32[1]{0} %[[reshape_106]]), dimensions={0} -// CHECK-NEXT: %[[divide_108:[^ ]+]] = f32[1,128,128]{2,1,0} divide(f32[1,128,128]{2,1,0} %[[subtract_100]], f32[1,128,128]{2,1,0} %[[broadcast_107]]) +// CHECK-NEXT: %[[dynamic_slice_102:[^ ]+]] = f32[1,1,1]{2,1,0} dynamic-slice(%[[subtract_100]], %[[constant_101]], %[[get_tuple_element_84]], %[[get_tuple_element_84]]), dynamic_slice_sizes={1,1,1} +// CHECK-NEXT: %[[sqrt_103:[^ ]+]] = f32[1,1,1]{2,1,0} sqrt(%[[dynamic_slice_102]]) +// CHECK-NEXT: %[[reshape_106:[^ ]+]] = f32[1]{0} reshape(%[[sqrt_103]]) +// CHECK-NEXT: %[[broadcast_107:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(%[[reshape_106]]), dimensions={0} +// CHECK-NEXT: %[[divide_108:[^ ]+]] = f32[1,128,128]{2,1,0} divide(%[[subtract_100]], %[[broadcast_107]]) // CHECK-NEXT: %[[constant_96:[^ ]+]] = f32[] constant(0) -// CHECK-NEXT: %[[broadcast_97:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(f32[] %[[constant_96]]), dimensions={} -// CHECK-NEXT: %[[select_109:[^ ]+]] = f32[1,128,128]{2,1,0} select(pred[1,128,128]{2,1,0} %[[and_95]], f32[1,128,128]{2,1,0} %[[divide_108]], f32[1,128,128]{2,1,0} %[[broadcast_97]]) -// CHECK-NEXT: %[[add_110:[^ ]+]] = f32[1,128,128]{2,1,0} add(f32[1,128,128]{2,1,0} %[[select_109]], f32[1,128,128]{2,1,0} %[[get_tuple_element_86]]) -// CHECK-NEXT: %[[get_tuple_element_87:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_83]]), index=3 -// CHECK-NEXT: %[[compare_104:[^ ]+]] = pred[1,1,1]{2,1,0} compare(f32[1,1,1]{2,1,0} %[[sqrt_103]], f32[1,1,1]{2,1,0} %[[sqrt_103]]), direction=NE -// CHECK-NEXT: %[[or_105:[^ ]+]] = pred[1,1,1]{2,1,0} or(pred[1,1,1]{2,1,0} %[[get_tuple_element_87]], pred[1,1,1]{2,1,0} %[[compare_104]]) -// CHECK-NEXT: ROOT %[[tuple_111:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) tuple(s32[] %[[add_89]], f32[1,128,128]{2,1,0} %[[get_tuple_element_85]], f32[1,128,128]{2,1,0} %[[add_110]], pred[1,1,1]{2,1,0} %[[or_105]]) +// CHECK-NEXT: %[[broadcast_97:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(%[[constant_96]]), dimensions={} +// CHECK-NEXT: %[[select_109:[^ ]+]] = f32[1,128,128]{2,1,0} select(%[[and_95]], %[[divide_108]], %[[broadcast_97]]) +// CHECK-NEXT: %[[add_110:[^ ]+]] = f32[1,128,128]{2,1,0} add(%[[select_109]], %[[get_tuple_element_86]]) +// CHECK-NEXT: %[[get_tuple_element_87:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element(%[[parameter_83]]), index=3 +// CHECK-NEXT: %[[compare_104:[^ ]+]] = pred[1,1,1]{2,1,0} compare(%[[sqrt_103]], %[[sqrt_103]]), direction=NE +// CHECK-NEXT: %[[or_105:[^ ]+]] = pred[1,1,1]{2,1,0} or(%[[get_tuple_element_87]], %[[compare_104]]) +// CHECK-NEXT: ROOT %[[tuple_111:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) tuple(%[[add_89]], %[[get_tuple_element_85]], %[[add_110]], %[[or_105]]) // CHECK: %[[$unblocked_condition_112:[^ ]+]] // CHECK-NEXT: %[[parameter_113:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) parameter(0) -// CHECK-NEXT: %[[get_tuple_element_115:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_113]]), index=1 -// CHECK-NEXT: %[[get_tuple_element_116:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_113]]), index=2 -// CHECK-NEXT: %[[get_tuple_element_117:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_113]]), index=3 -// CHECK-NEXT: %[[get_tuple_element_114:[^ ]+]] = s32[] get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[parameter_113]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_115:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[parameter_113]]), index=1 +// CHECK-NEXT: %[[get_tuple_element_116:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[parameter_113]]), index=2 +// CHECK-NEXT: %[[get_tuple_element_117:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element(%[[parameter_113]]), index=3 +// CHECK-NEXT: %[[get_tuple_element_114:[^ ]+]] = s32[] get-tuple-element(%[[parameter_113]]), index=0 // CHECK-NEXT: %[[constant_118:[^ ]+]] = s32[] constant(128) -// CHECK-NEXT: ROOT %[[compare_119:[^ ]+]] = pred[] compare(s32[] %[[get_tuple_element_114]], s32[] %[[constant_118]]), direction=LT +// CHECK-NEXT: ROOT %[[compare_119:[^ ]+]] = pred[] compare(%[[get_tuple_element_114]], %[[constant_118]]), direction=LT // CHECK: %[[$xla_cholesky_f32_1_256_256__upper_137:[^ ]+]] // CHECK-NEXT: %[[constant_13:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[a_1:[^ ]+]] = f32[1,256,256]{2,1,0} parameter(0) -// CHECK-NEXT: %[[transpose_2:[^ ]+]] = f32[1,256,256]{1,2,0} transpose(f32[1,256,256]{2,1,0} %[[a_1]]), dimensions={0,2,1} -// CHECK-NEXT: %[[slice_7:[^ ]+]] = f32[1,256,128]{2,1,0} slice(f32[1,256,256]{1,2,0} %[[transpose_2]]), slice={[0:1], [0:256], [0:128]} -// CHECK-NEXT: %[[slice_8:[^ ]+]] = f32[1,128,128]{2,1,0} slice(f32[1,256,128]{2,1,0} %[[slice_7]]), slice={[0:1], [0:128], [0:128]} +// CHECK-NEXT: %[[transpose_2:[^ ]+]] = f32[1,256,256]{1,2,0} transpose(%[[a_1]]), dimensions={0,2,1} +// CHECK-NEXT: %[[slice_7:[^ ]+]] = f32[1,256,128]{2,1,0} slice(%[[transpose_2]]), slice={[0:1], [0:256], [0:128]} +// CHECK-NEXT: %[[slice_8:[^ ]+]] = f32[1,128,128]{2,1,0} slice(%[[slice_7]]), slice={[0:1], [0:128], [0:128]} // CHECK-NEXT: %[[constant_9:[^ ]+]] = f32[] constant(0) -// CHECK-NEXT: %[[broadcast_10:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(f32[] %[[constant_9]]), dimensions={} +// CHECK-NEXT: %[[broadcast_10:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(%[[constant_9]]), dimensions={} // CHECK-NEXT: %[[constant_11:[^ ]+]] = pred[] constant(false) -// CHECK-NEXT: %[[broadcast_12:[^ ]+]] = pred[1,1,1]{2,1,0} broadcast(pred[] %[[constant_11]]), dimensions={} -// CHECK-NEXT: %[[tuple_14:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) tuple(s32[] %[[constant_13]], f32[1,128,128]{2,1,0} %[[slice_8]], f32[1,128,128]{2,1,0} %[[broadcast_10]], pred[1,1,1]{2,1,0} %[[broadcast_12]]) -// CHECK-NEXT: %[[while_53:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) while((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[tuple_14]]), condition=%[[$unblocked_condition_45]], body=%[[$unblocked_body_15]] -// CHECK-NEXT: %[[get_tuple_element_54:[^ ]+]] = s32[] get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[while_53]]), index=0 -// CHECK-NEXT: %[[get_tuple_element_55:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[while_53]]), index=1 +// CHECK-NEXT: %[[broadcast_12:[^ ]+]] = pred[1,1,1]{2,1,0} broadcast(%[[constant_11]]), dimensions={} +// CHECK-NEXT: %[[tuple_14:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) tuple(%[[constant_13]], %[[slice_8]], %[[broadcast_10]], %[[broadcast_12]]) +// CHECK-NEXT: %[[while_53:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) while(%[[tuple_14]]), condition=%[[$unblocked_condition_45]], body=%[[$unblocked_body_15]] +// CHECK-NEXT: %[[get_tuple_element_54:[^ ]+]] = s32[] get-tuple-element(%[[while_53]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_55:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[while_53]]), index=1 // CHECK-NEXT: %[[constant_80:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[slice_69:[^ ]+]] = f32[1,128,128]{2,1,0} slice(f32[1,256,256]{1,2,0} %[[transpose_2]]), slice={[0:1], [128:256], [128:256]} +// CHECK-NEXT: %[[slice_69:[^ ]+]] = f32[1,128,128]{2,1,0} slice(%[[transpose_2]]), slice={[0:1], [128:256], [128:256]} // CHECK-NEXT: %[[constant_3:[^ ]+]] = f32[] constant(0) -// CHECK-NEXT: %[[broadcast_4:[^ ]+]] = f32[1,256,256]{2,1,0} broadcast(f32[] %[[constant_3]]), dimensions={} -// CHECK-NEXT: %[[get_tuple_element_56:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[while_53]]), index=2 +// CHECK-NEXT: %[[broadcast_4:[^ ]+]] = f32[1,256,256]{2,1,0} broadcast(%[[constant_3]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_56:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[while_53]]), index=2 // CHECK-NEXT: %[[constant_59:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[constant_60:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[constant_61:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[dynamic_update_slice_62:[^ ]+]] = f32[1,256,256]{2,1,0} dynamic-update-slice(f32[1,256,256]{2,1,0} %[[broadcast_4]], f32[1,128,128]{2,1,0} %[[get_tuple_element_56]], s32[] %[[constant_59]], s32[] %[[constant_60]], s32[] %[[constant_61]]) -// CHECK-NEXT: %[[slice_63:[^ ]+]] = f32[1,128,128]{2,1,0} slice(f32[1,256,128]{2,1,0} %[[slice_7]]), slice={[0:1], [128:256], [0:128]} -// CHECK-NEXT: %[[triangular_solve_64:[^ ]+]] = f32[1,128,128]{2,1,0} triangular-solve(f32[1,128,128]{2,1,0} %[[get_tuple_element_56]], f32[1,128,128]{2,1,0} %[[slice_63]]), lower=true, transpose_a=ADJOINT +// CHECK-NEXT: %[[dynamic_update_slice_62:[^ ]+]] = f32[1,256,256]{2,1,0} dynamic-update-slice(%[[broadcast_4]], %[[get_tuple_element_56]], %[[constant_59]], %[[constant_60]], %[[constant_61]]) +// CHECK-NEXT: %[[slice_63:[^ ]+]] = f32[1,128,128]{2,1,0} slice(%[[slice_7]]), slice={[0:1], [128:256], [0:128]} +// CHECK-NEXT: %[[triangular_solve_64:[^ ]+]] = f32[1,128,128]{2,1,0} triangular-solve(%[[get_tuple_element_56]], %[[slice_63]]), lower=true, transpose_a=ADJOINT // CHECK-NEXT: %[[constant_65:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[constant_66:[^ ]+]] = s32[] constant(128) // CHECK-NEXT: %[[constant_67:[^ ]+]] = s32[] constant(0) -// CHECK-NEXT: %[[dynamic_update_slice_68:[^ ]+]] = f32[1,256,256]{2,1,0} dynamic-update-slice(f32[1,256,256]{2,1,0} %[[dynamic_update_slice_62]], f32[1,128,128]{2,1,0} %[[triangular_solve_64]], s32[] %[[constant_65]], s32[] %[[constant_66]], s32[] %[[constant_67]]) -// CHECK-NEXT: %[[slice_70:[^ ]+]] = f32[1,128,128]{2,1,0} slice(f32[1,256,256]{2,1,0} %[[dynamic_update_slice_68]]), slice={[0:1], [128:256], [0:128]} -// CHECK-NEXT: %[[slice_71:[^ ]+]] = f32[1,128,128]{2,1,0} slice(f32[1,256,256]{2,1,0} %[[dynamic_update_slice_68]]), slice={[0:1], [128:256], [0:128]} -// CHECK-NEXT: %[[dot_72:[^ ]+]] = f32[1,128,128]{2,1,0} dot(f32[1,128,128]{2,1,0} %[[slice_70]], f32[1,128,128]{2,1,0} %[[slice_71]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} -// CHECK-NEXT: %[[transpose_73:[^ ]+]] = f32[1,128,128]{2,1,0} transpose(f32[1,128,128]{2,1,0} %[[dot_72]]), dimensions={0,1,2} -// CHECK-NEXT: %[[subtract_74:[^ ]+]] = f32[1,128,128]{2,1,0} subtract(f32[1,128,128]{2,1,0} %[[slice_69]], f32[1,128,128]{2,1,0} %[[transpose_73]]) -// CHECK-NEXT: %[[slice_75:[^ ]+]] = f32[1,128,128]{2,1,0} slice(f32[1,128,128]{2,1,0} %[[subtract_74]]), slice={[0:1], [0:128], [0:128]} +// CHECK-NEXT: %[[dynamic_update_slice_68:[^ ]+]] = f32[1,256,256]{2,1,0} dynamic-update-slice(%[[dynamic_update_slice_62]], %[[triangular_solve_64]], %[[constant_65]], %[[constant_66]], %[[constant_67]]) +// CHECK-NEXT: %[[slice_70:[^ ]+]] = f32[1,128,128]{2,1,0} slice(%[[dynamic_update_slice_68]]), slice={[0:1], [128:256], [0:128]} +// CHECK-NEXT: %[[slice_71:[^ ]+]] = f32[1,128,128]{2,1,0} slice(%[[dynamic_update_slice_68]]), slice={[0:1], [128:256], [0:128]} +// CHECK-NEXT: %[[dot_72:[^ ]+]] = f32[1,128,128]{2,1,0} dot(%[[slice_70]], %[[slice_71]]), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}, operand_precision={highest,highest}, frontend_attributes={grad_x="false",grad_y="false"} +// CHECK-NEXT: %[[transpose_73:[^ ]+]] = f32[1,128,128]{2,1,0} transpose(%[[dot_72]]), dimensions={0,1,2} +// CHECK-NEXT: %[[subtract_74:[^ ]+]] = f32[1,128,128]{2,1,0} subtract(%[[slice_69]], %[[transpose_73]]) +// CHECK-NEXT: %[[slice_75:[^ ]+]] = f32[1,128,128]{2,1,0} slice(%[[subtract_74]]), slice={[0:1], [0:128], [0:128]} // CHECK-NEXT: %[[constant_76:[^ ]+]] = f32[] constant(0) -// CHECK-NEXT: %[[broadcast_77:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(f32[] %[[constant_76]]), dimensions={} +// CHECK-NEXT: %[[broadcast_77:[^ ]+]] = f32[1,128,128]{2,1,0} broadcast(%[[constant_76]]), dimensions={} // CHECK-NEXT: %[[constant_78:[^ ]+]] = pred[] constant(false) -// CHECK-NEXT: %[[broadcast_79:[^ ]+]] = pred[1,1,1]{2,1,0} broadcast(pred[] %[[constant_78]]), dimensions={} -// CHECK-NEXT: %[[tuple_81:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) tuple(s32[] %[[constant_80]], f32[1,128,128]{2,1,0} %[[slice_75]], f32[1,128,128]{2,1,0} %[[broadcast_77]], pred[1,1,1]{2,1,0} %[[broadcast_79]]) -// CHECK-NEXT: %[[while_120:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) while((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[tuple_81]]), condition=%[[$unblocked_condition_112]], body=%[[$unblocked_body_82]] -// CHECK-NEXT: %[[get_tuple_element_121:[^ ]+]] = s32[] get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[while_120]]), index=0 -// CHECK-NEXT: %[[get_tuple_element_122:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[while_120]]), index=1 +// CHECK-NEXT: %[[broadcast_79:[^ ]+]] = pred[1,1,1]{2,1,0} broadcast(%[[constant_78]]), dimensions={} +// CHECK-NEXT: %[[tuple_81:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) tuple(%[[constant_80]], %[[slice_75]], %[[broadcast_77]], %[[broadcast_79]]) +// CHECK-NEXT: %[[while_120:[^ ]+]] = (s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) while(%[[tuple_81]]), condition=%[[$unblocked_condition_112]], body=%[[$unblocked_body_82]] +// CHECK-NEXT: %[[get_tuple_element_121:[^ ]+]] = s32[] get-tuple-element(%[[while_120]]), index=0 +// CHECK-NEXT: %[[get_tuple_element_122:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[while_120]]), index=1 // CHECK-NEXT: %[[constant_5:[^ ]+]] = pred[] constant(false) -// CHECK-NEXT: %[[broadcast_6:[^ ]+]] = pred[1,1,1]{2,1,0} broadcast(pred[] %[[constant_5]]), dimensions={} -// CHECK-NEXT: %[[get_tuple_element_57:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[while_53]]), index=3 -// CHECK-NEXT: %[[or_58:[^ ]+]] = pred[1,1,1]{2,1,0} or(pred[1,1,1]{2,1,0} %[[broadcast_6]], pred[1,1,1]{2,1,0} %[[get_tuple_element_57]]) -// CHECK-NEXT: %[[get_tuple_element_124:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[while_120]]), index=3 -// CHECK-NEXT: %[[or_125:[^ ]+]] = pred[1,1,1]{2,1,0} or(pred[1,1,1]{2,1,0} %[[or_58]], pred[1,1,1]{2,1,0} %[[get_tuple_element_124]]) -// CHECK-NEXT: %[[broadcast_130:[^ ]+]] = pred[1,1,1]{2,1,0} broadcast(pred[1,1,1]{2,1,0} %[[or_125]]), dimensions={0,1,2} -// CHECK-NEXT: %[[reshape_131:[^ ]+]] = pred[1]{0} reshape(pred[1,1,1]{2,1,0} %[[broadcast_130]]) -// CHECK-NEXT: %[[broadcast_132:[^ ]+]] = pred[1,256,256]{2,1,0} broadcast(pred[1]{0} %[[reshape_131]]), dimensions={0} +// CHECK-NEXT: %[[broadcast_6:[^ ]+]] = pred[1,1,1]{2,1,0} broadcast(%[[constant_5]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_57:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element(%[[while_53]]), index=3 +// CHECK-NEXT: %[[or_58:[^ ]+]] = pred[1,1,1]{2,1,0} or(%[[broadcast_6]], %[[get_tuple_element_57]]) +// CHECK-NEXT: %[[get_tuple_element_124:[^ ]+]] = pred[1,1,1]{2,1,0} get-tuple-element(%[[while_120]]), index=3 +// CHECK-NEXT: %[[or_125:[^ ]+]] = pred[1,1,1]{2,1,0} or(%[[or_58]], %[[get_tuple_element_124]]) +// CHECK-NEXT: %[[broadcast_130:[^ ]+]] = pred[1,1,1]{2,1,0} broadcast(%[[or_125]]), dimensions={0,1,2} +// CHECK-NEXT: %[[reshape_131:[^ ]+]] = pred[1]{0} reshape(%[[broadcast_130]]) +// CHECK-NEXT: %[[broadcast_132:[^ ]+]] = pred[1,256,256]{2,1,0} broadcast(%[[reshape_131]]), dimensions={0} // CHECK-NEXT: %[[constant_133:[^ ]+]] = f32[] constant(nan) -// CHECK-NEXT: %[[broadcast_134:[^ ]+]] = f32[1,256,256]{2,1,0} broadcast(f32[] %[[constant_133]]), dimensions={} -// CHECK-NEXT: %[[get_tuple_element_123:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element((s32[], f32[1,128,128]{2,1,0}, f32[1,128,128]{2,1,0}, pred[1,1,1]{2,1,0}) %[[while_120]]), index=2 +// CHECK-NEXT: %[[broadcast_134:[^ ]+]] = f32[1,256,256]{2,1,0} broadcast(%[[constant_133]]), dimensions={} +// CHECK-NEXT: %[[get_tuple_element_123:[^ ]+]] = f32[1,128,128]{2,1,0} get-tuple-element(%[[while_120]]), index=2 // CHECK-NEXT: %[[constant_126:[^ ]+]] = s32[] constant(0) // CHECK-NEXT: %[[constant_127:[^ ]+]] = s32[] constant(128) // CHECK-NEXT: %[[constant_128:[^ ]+]] = s32[] constant(128) -// CHECK-NEXT: %[[dynamic_update_slice_129:[^ ]+]] = f32[1,256,256]{2,1,0} dynamic-update-slice(f32[1,256,256]{2,1,0} %[[dynamic_update_slice_68]], f32[1,128,128]{2,1,0} %[[get_tuple_element_123]], s32[] %[[constant_126]], s32[] %[[constant_127]], s32[] %[[constant_128]]) -// CHECK-NEXT: %[[select_135:[^ ]+]] = f32[1,256,256]{2,1,0} select(pred[1,256,256]{2,1,0} %[[broadcast_132]], f32[1,256,256]{2,1,0} %[[broadcast_134]], f32[1,256,256]{2,1,0} %[[dynamic_update_slice_129]]) -// CHECK-NEXT: ROOT %[[transpose_136:[^ ]+]] = f32[1,256,256]{1,2,0} transpose(f32[1,256,256]{2,1,0} %[[select_135]]), dimensions={0,2,1} +// CHECK-NEXT: %[[dynamic_update_slice_129:[^ ]+]] = f32[1,256,256]{2,1,0} dynamic-update-slice(%[[dynamic_update_slice_68]], %[[get_tuple_element_123]], %[[constant_126]], %[[constant_127]], %[[constant_128]]) +// CHECK-NEXT: %[[select_135:[^ ]+]] = f32[1,256,256]{2,1,0} select(%[[broadcast_132]], %[[broadcast_134]], %[[dynamic_update_slice_129]]) +// CHECK-NEXT: ROOT %[[transpose_136:[^ ]+]] = f32[1,256,256]{1,2,0} transpose(%[[select_135]]), dimensions={0,2,1} // CHECK-LABEL: ENTRY %test // CHECK-NEXT: %[[input:[^ ]+]] = f32[1,256,256]{2,1,0} parameter(0) -// CHECK-NEXT: ROOT %[[call:[^ ]+]] = f32[1,256,256]{2,1,0} call(f32[1,256,256]{2,1,0} %[[input]]), to_apply=%[[$xla_cholesky_f32_1_256_256__upper_137]] +// CHECK-NEXT: ROOT %[[call:[^ ]+]] = f32[1,256,256]{2,1,0} call(%[[input]]), to_apply=%[[$xla_cholesky_f32_1_256_256__upper_137]] HloModule CholeskyExpanderTest diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/add.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/add.mlir index cdaef75f7753eb..35696ceea86bb1 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/add.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/add.mlir @@ -12,10 +12,10 @@ func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0) // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + // CHECK-NEXT: %add.3 = f32[4] add(%Arg_0.1, %Arg_1.2) %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2) + // CHECK-NEXT: ROOT %add.4 = f32[4] add(%add.3, %Arg_1.2) %1 = "mhlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %1 : tensor<4xf32> } diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir index d525ac4fdeb4a1..40df18a01ed2d8 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir @@ -3,7 +3,9 @@ // CHECK-LABEL: HloModule dot_algorithm_f8_f8_f32 module @dot_algorithm_f8_f8_f32 { func.func @main(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { - // CHECK: f32[2,2,2] dot(f32[2,2,2] {{.*}}, f32[2,2,2] {{.*}}), {{.*}}, algorithm=dot_any_f8_any_f8_f32 + // CHECK: %[[ARG0:.+]] = f32[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[2,2,2] parameter(1) + // CHECK: f32[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_any_f8_any_f8_f32 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -25,7 +27,9 @@ module @dot_algorithm_f8_f8_f32 { // CHECK-LABEL: HloModule dot_algorithm_f8_f8_f32_fast_accum module @dot_algorithm_f8_f8_f32_fast_accum { func.func @main(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { - // CHECK: f32[2,2,2] dot(f32[2,2,2] {{.*}}, f32[2,2,2] {{.*}}), {{.*}}, algorithm=dot_any_f8_any_f8_f32_fast_accum + // CHECK: %[[ARG0:.+]] = f32[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[2,2,2] parameter(1) + // CHECK: f32[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_any_f8_any_f8_f32_fast_accum %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -47,7 +51,9 @@ module @dot_algorithm_f8_f8_f32_fast_accum { // CHECK-LABEL: HloModule dot_algorithm_f16_f16_f16 module @dot_algorithm_f16_f16_f16 { func.func @main(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { - // CHECK: f32[2,2,2] dot(f32[2,2,2] {{.*}}, f32[2,2,2] {{.*}}), {{.*}}, algorithm=dot_f16_f16_f16 + // CHECK: %[[ARG0:.+]] = f32[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[2,2,2] parameter(1) + // CHECK: f32[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_f16_f16_f16 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -69,7 +75,9 @@ module @dot_algorithm_f16_f16_f16 { // CHECK-LABEL: HloModule dot_algorithm_f16_f16_f32 module @dot_algorithm_f16_f16_f32 { func.func @main(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { - // CHECK: f32[2,2,2] dot(f32[2,2,2] {{.*}}, f32[2,2,2] {{.*}}), {{.*}}, algorithm=dot_f16_f16_f32 + // CHECK: %[[ARG0:.+]] = f32[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[2,2,2] parameter(1) + // CHECK: f32[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_f16_f16_f32 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -91,7 +99,9 @@ module @dot_algorithm_f16_f16_f32 { // CHECK-LABEL: HloModule dot_algorithm_bf16_bf16_bf16 module @dot_algorithm_bf16_bf16_bf16 { func.func @main(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16> { - // CHECK: bf16[2,2,2] dot(bf16[2,2,2] {{.*}}, bf16[2,2,2] {{.*}}), {{.*}}, algorithm=dot_bf16_bf16_bf16 + // CHECK: %[[ARG0:.+]] = bf16[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = bf16[2,2,2] parameter(1) + // CHECK: bf16[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_bf16_bf16_bf16 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -113,7 +123,9 @@ module @dot_algorithm_bf16_bf16_bf16 { // CHECK-LABEL: HloModule dot_algorithm_bf16_bf16_f32 module @dot_algorithm_bf16_bf16_f32 { func.func @main(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16> { - // CHECK: bf16[2,2,2] dot(bf16[2,2,2] {{.*}}, bf16[2,2,2] {{.*}}), {{.*}}, algorithm=dot_bf16_bf16_f32 + // CHECK: %[[ARG0:.+]] = bf16[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = bf16[2,2,2] parameter(1) + // CHECK: bf16[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_bf16_bf16_f32 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -135,7 +147,9 @@ module @dot_algorithm_bf16_bf16_f32 { // CHECK-LABEL: HloModule dot_algorithm_bf16_bf16_f32_x3 module @dot_algorithm_bf16_bf16_f32_x3 { func.func @main(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16> { - // CHECK: bf16[2,2,2] dot(bf16[2,2,2] {{.*}}, bf16[2,2,2] {{.*}}), {{.*}}, algorithm=dot_bf16_bf16_f32_x3 + // CHECK: %[[ARG0:.+]] = bf16[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = bf16[2,2,2] parameter(1) + // CHECK: bf16[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_bf16_bf16_f32_x3 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -157,7 +171,9 @@ module @dot_algorithm_bf16_bf16_f32_x3 { // CHECK-LABEL: HloModule dot_algorithm_bf16_bf16_f32_x6 module @dot_algorithm_bf16_bf16_f32_x6 { func.func @main(%arg0: tensor<2x2x2xbf16>, %arg1: tensor<2x2x2xbf16>) -> tensor<2x2x2xbf16> { - // CHECK: bf16[2,2,2] dot(bf16[2,2,2] {{.*}}, bf16[2,2,2] {{.*}}), {{.*}}, algorithm=dot_bf16_bf16_f32_x6 + // CHECK: %[[ARG0:.+]] = bf16[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = bf16[2,2,2] parameter(1) + // CHECK: bf16[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_bf16_bf16_f32_x6 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -179,7 +195,9 @@ module @dot_algorithm_bf16_bf16_f32_x6 { // CHECK-LABEL: HloModule dot_algorithm_tf32_tf32_f32 module @dot_algorithm_tf32_tf32_f32 { func.func @main(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { - // CHECK: f32[2,2,2] dot(f32[2,2,2] {{.*}}, f32[2,2,2] {{.*}}), {{.*}}, algorithm=dot_tf32_tf32_f32 + // CHECK: %[[ARG0:.+]] = f32[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[2,2,2] parameter(1) + // CHECK: f32[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_tf32_tf32_f32 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -201,7 +219,9 @@ module @dot_algorithm_tf32_tf32_f32 { // CHECK-LABEL: HloModule dot_algorithm_tf32_tf32_f32_x3 module @dot_algorithm_tf32_tf32_f32_x3 { func.func @main(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { - // CHECK: f32[2,2,2] dot(f32[2,2,2] {{.*}}, f32[2,2,2] {{.*}}), {{.*}}, algorithm=dot_tf32_tf32_f32_x3 + // CHECK: %[[ARG0:.+]] = f32[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[2,2,2] parameter(1) + // CHECK: f32[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_tf32_tf32_f32_x3 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -223,7 +243,9 @@ module @dot_algorithm_tf32_tf32_f32_x3 { // CHECK-LABEL: HloModule dot_algorithm_f32_f32_f32 module @dot_algorithm_f32_f32_f32 { func.func @main(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { - // CHECK: f32[2,2,2] dot(f32[2,2,2] {{.*}}, f32[2,2,2] {{.*}}), {{.*}}, algorithm=dot_f32_f32_f32 + // CHECK: %[[ARG0:.+]] = f32[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[2,2,2] parameter(1) + // CHECK: f32[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_f32_f32_f32 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], @@ -245,7 +267,9 @@ module @dot_algorithm_f32_f32_f32 { // CHECK-LABEL: HloModule dot_algorithm_f64_f64_f64 module @dot_algorithm_f64_f64_f64 { func.func @main(%arg0: tensor<2x2x2xf64>, %arg1: tensor<2x2x2xf64>) -> tensor<2x2x2xf64> { - // CHECK: f64[2,2,2] dot(f64[2,2,2] {{.*}}, f64[2,2,2] {{.*}}), {{.*}}, algorithm=dot_f64_f64_f64 + // CHECK: %[[ARG0:.+]] = f64[2,2,2] parameter(0) + // CHECK: %[[ARG1:.+]] = f64[2,2,2] parameter(1) + // CHECK: f64[2,2,2] dot(%[[ARG0]], %[[ARG1]]), {{.*}}, algorithm=dot_f64_f64_f64 %0 = "mhlo.dot_general"(%arg0, %arg1) <{ dot_dimension_numbers = #mhlo.dot, precision_config = [#mhlo, #mhlo], diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/call.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/call.mlir index 791dd91a8ebbfd..c688cb554a1e87 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/call.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/call.mlir @@ -4,7 +4,7 @@ module @call_with_backend_config { func.func @main(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { // CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:Arg_0.[0-9]+]]: s32[8,2]) -> s32[8,2] { // CHECK-NEXT: %[[ARG0]] = s32[8,2] parameter(0) - // CHECK-NEXT: s32[8,2] call(s32[8,2] %[[ARG0]]), to_apply=%g.{{[0-9.]+}}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"device_type":"DEVICE_TYPE_HOST","used_scoped_memory_configs":[]} + // CHECK-NEXT: s32[8,2] call(%[[ARG0]]), to_apply=%g.{{[0-9.]+}}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"device_type":"DEVICE_TYPE_HOST","used_scoped_memory_configs":[]} %0 = call @g.2(%arg0) {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> return %1 : tensor<8x2xi32> @@ -22,7 +22,7 @@ module @call_with_sharding { func.func @main(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { // CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:Arg_0.[0-9]+]]: s32[8,2]) -> s32[8,2] { // CHECK-NEXT: %[[ARG0]] = s32[8,2] parameter(0) - // CHECK-NEXT: s32[8,2] call(s32[8,2] %[[ARG0]]), to_apply=%g.{{[0-9.]+}}, sharding={devices=[2,2]<=[4]} + // CHECK-NEXT: s32[8,2] call(%[[ARG0]]), to_apply=%g.{{[0-9.]+}}, sharding={devices=[2,2]<=[4]} %0 = call @g.2(%arg0) {mhlo.sharding = "{devices=[2,2]<=[4]}"} : (tensor<8x2xi32>) -> tensor<8x2xi32> %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> return %1 : tensor<8x2xi32> @@ -40,11 +40,11 @@ module @call_with_sharding_multiple_results { func.func @main(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { // CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:Arg_0.[0-9]+]]: s32[8,2]) -> s32[8,2] { // CHECK-NEXT: %[[ARG0]] = s32[8,2] parameter(0) - // CHECK-NEXT: %[[CALL:.*]] = (s32[8,2], s32[8,2]) call(s32[8,2] %[[ARG0]]), to_apply=%g.2.2, + // CHECK-NEXT: %[[CALL:.*]] = (s32[8,2], s32[8,2]) call(%[[ARG0]]), to_apply=%g.2.2, // CHECK-SAME{LITERAL}: sharding={{maximal device=0}, {replicated}}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"device_type":"DEVICE_TYPE_HOST","used_scoped_memory_configs":[]} - // CHECK-NEXT: %[[IGNORE:.*]] = s32[8,2] get-tuple-element((s32[8,2], s32[8,2]) %[[CALL]]), index=1, sharding={replicated} - // CHECK-NEXT: %[[GET_ELEMENT:.*]] = s32[8,2] get-tuple-element((s32[8,2], s32[8,2]) %[[CALL]]), index=0, sharding={maximal device=0} - // CHECK-NEXT: ROOT %custom-call.{{[0-9]+}} = s32[8,2] custom-call(s32[8,2] %[[GET_ELEMENT]]), custom_call_target="MoveToHost" + // CHECK-NEXT: %[[IGNORE:.*]] = s32[8,2] get-tuple-element(%[[CALL]]), index=1, sharding={replicated} + // CHECK-NEXT: %[[GET_ELEMENT:.*]] = s32[8,2] get-tuple-element(%[[CALL]]), index=0, sharding={maximal device=0} + // CHECK-NEXT: ROOT %custom-call.{{[0-9]+}} = s32[8,2] custom-call(%[[GET_ELEMENT]]), custom_call_target="MoveToHost" %0:2 = call @g.2(%arg0) {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, mhlo.sharding = "{{maximal device=0}, {replicated}}"} : (tensor<8x2xi32>) -> (tensor<8x2xi32>, tensor<8x2xi32>) %1 = mhlo.custom_call @MoveToHost(%0#0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> return %1 : tensor<8x2xi32> diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/case.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/case.mlir index 8be1cf2a8914e1..9d0632ea422a7b 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/case.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/case.mlir @@ -20,17 +20,17 @@ func.func @main() -> tensor { // CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { // CHECK: %[[ARG:.*]] = f32[] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[] negate(f32[] %[[ARG]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[] negate(%[[ARG]]) // CHECK: } // CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { // CHECK: %[[ARG:.*]] = f32[] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[] copy(f32[] %[[ARG]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[] copy(%[[ARG]]) // CHECK: } // CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> f32[] { // CHECK: %[[ARG:.*]] = f32[] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[] floor(f32[] %[[ARG]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[] floor(%[[ARG]]) // CHECK: } // CHECK-LABEL: ENTRY @@ -40,7 +40,7 @@ func.func @main() -> tensor { // CHECK-DAG: %[[OPERAND_1:.*]] = f32[] constant(56) // CHECK-DAG: %[[OPERAND_2:.*]] = f32[] constant(12) // CHECK-DAG: %[[OPERAND_3:.*]] = f32[] constant(13) -// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} +// CHECK: ROOT %[[RESULT:.*]] = f32[] conditional(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} // ----- @@ -64,20 +64,20 @@ func.func @main() -> (tensor, tensor) { // CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { // CHECK: %[[ARG:.*]] = f32[] parameter(0) -// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]]) -// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]]) +// CHECK: %[[NEGATE:.*]] = f32[] negate(%[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(%[[NEGATE]], %[[NEGATE]]) // CHECK: } // CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { // CHECK: %[[ARG:.*]] = f32[] parameter(0) -// CHECK: %[[COPY:.*]] = f32[] copy(f32[] %[[ARG]]) -// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY]], f32[] %[[COPY]]) +// CHECK: %[[COPY:.*]] = f32[] copy(%[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(%[[COPY]], %[[COPY]]) // CHECK: } // CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { // CHECK: %[[ARG:.*]] = f32[] parameter(0) -// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[ARG]]) -// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]]) +// CHECK: %[[FLOOR:.*]] = f32[] floor(%[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(%[[FLOOR]], %[[FLOOR]]) // CHECK: } // CHECK-LABEL: ENTRY @@ -87,10 +87,10 @@ func.func @main() -> (tensor, tensor) { // CHECK-DAG: %[[OPERAND_1:.*]] = f32[] constant(56) // CHECK-DAG: %[[OPERAND_2:.*]] = f32[] constant(12) // CHECK-DAG: %[[OPERAND_3:.*]] = f32[] constant(13) -// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} -// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=0 -// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[TUPLE]]), index=1 -// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]]) +// CHECK: %[[TUPLE:.*]] = (f32[], f32[]) conditional(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} +// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element(%[[TUPLE]]), index=0 +// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element(%[[TUPLE]]), index=1 +// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(%[[RES_1]], %[[RES_2]]) // ----- // Test export mhlo::CaseOp with diffrent number of block-arguments (even 0). @@ -117,24 +117,24 @@ func.func @main() -> (tensor, tensor) { // CHECK: %[[NEGATE_BRANCH:.*]] ({{.*}}: f32[]) -> (f32[], f32[]) { // CHECK: %[[ARG:.*]] = f32[] parameter(0) -// CHECK: %[[NEGATE:.*]] = f32[] negate(f32[] %[[ARG]]) -// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[NEGATE]], f32[] %[[NEGATE]]) +// CHECK: %[[NEGATE:.*]] = f32[] negate(%[[ARG]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(%[[NEGATE]], %[[NEGATE]]) // CHECK: } // CHECK: %[[COPY_BRANCH:.*]] ({{.*}}: (f32[], f32[])) -> (f32[], f32[]) { // CHECK: %[[ARG:.*]] = (f32[], f32[]) parameter(0) -// CHECK-DAG: %[[GTE1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[ARG]]), index=0 -// CHECK-DAG: %[[COPY1:.*]] = f32[] copy(f32[] %[[GTE1]]) -// CHECK-DAG: %[[GTE2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[ARG]]), index=1 -// CHECK-DAG: %[[COPY2:.*]] = f32[] copy(f32[] %[[GTE2]]) -// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[COPY1]], f32[] %[[COPY2]]) +// CHECK-DAG: %[[GTE1:.*]] = f32[] get-tuple-element(%[[ARG]]), index=0 +// CHECK-DAG: %[[COPY1:.*]] = f32[] copy(%[[GTE1]]) +// CHECK-DAG: %[[GTE2:.*]] = f32[] get-tuple-element(%[[ARG]]), index=1 +// CHECK-DAG: %[[COPY2:.*]] = f32[] copy(%[[GTE2]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(%[[COPY1]], %[[COPY2]]) // CHECK: } // CHECK: %[[FLOOR_BRANCH:.*]] ({{.*}}: ()) -> (f32[], f32[]) { // CHECK: %[[ARG:.*]] = () parameter(0) // CHECK: %[[CST:.*]] = f32[] constant -// CHECK: %[[FLOOR:.*]] = f32[] floor(f32[] %[[CST]]) -// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(f32[] %[[FLOOR]], f32[] %[[FLOOR]]) +// CHECK: %[[FLOOR:.*]] = f32[] floor(%[[CST]]) +// CHECK: ROOT %[[TUPLE:.*]] = (f32[], f32[]) tuple(%[[FLOOR]], %[[FLOOR]]) // CHECK: } // CHECK-LABEL: ENTRY @@ -144,11 +144,11 @@ func.func @main() -> (tensor, tensor) { // CHECK-DAG: %[[OPERAND_1:.*]] = f32[] constant(56) // CHECK-DAG: %[[OPERAND_2:.*]] = f32[] constant(12) // CHECK-DAG: %[[OPERAND_3:.*]] = f32[] constant(13) -// CHECK-DAG: %[[TUPLE1:.*]] = (f32[], f32[]) tuple(f32[] %[[OPERAND_2]], f32[] %[[OPERAND_3]]) +// CHECK-DAG: %[[TUPLE1:.*]] = (f32[], f32[]) tuple(%[[OPERAND_2]], %[[OPERAND_3]]) // CHECK-DAG: %[[TUPLE2:.*]] = () tuple() -// CHECK: %[[COND:.*]] = (f32[], f32[]) conditional(s32[] %[[INDEX]], f32[] %[[OPERAND_1]], (f32[], f32[]) %[[TUPLE1]], () %[[TUPLE2]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} +// CHECK: %[[COND:.*]] = (f32[], f32[]) conditional(%[[INDEX]], %[[OPERAND_1]], %[[TUPLE1]], %[[TUPLE2]]), branch_computations={%[[NEGATE_BRANCH]], %[[COPY_BRANCH]], %[[FLOOR_BRANCH]]} -// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[COND]]), index=0 -// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element((f32[], f32[]) %[[COND]]), index=1 -// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(f32[] %[[RES_1]], f32[] %[[RES_2]]) +// CHECK: %[[RES_1:.*]] = f32[] get-tuple-element(%[[COND]]), index=0 +// CHECK: %[[RES_2:.*]] = f32[] get-tuple-element(%[[COND]]), index=1 +// CHECK: ROOT %[[RESULT:.*]] = (f32[], f32[]) tuple(%[[RES_1]], %[[RES_2]]) diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir index 0147971a96c32c..70ebca79e2769a 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir @@ -5,11 +5,11 @@ module @composite { // CHECK: %[[ADD:add.[0-9]+]] ([[ARG0:Arg_0.[0-9]+]]: f32[]) -> f32[] { // CHECK: %[[ARG0]] = f32[] parameter(0) // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) - // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(f32[] %[[ARG0]], f32[] %[[CONSTANT]]) + // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(%[[ARG0]], %[[CONSTANT]]) // CHECK: } // CHECK: ENTRY %main.{{[0-9]+}} () -> f32[] { // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) - // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(%[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} // CHECK: } func.func @main() -> tensor { %0 = mhlo.constant dense<4.200000e+01> : tensor @@ -41,7 +41,7 @@ module @composite { //CHECK: } //CHECK: ENTRY %main.{{[0-9]+}} () -> () { //CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) - //CHECK: %call.5 = () call(f32[] %[[CONSTANT]]), to_apply=%[[RETURN]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + //CHECK: %call.5 = () call(%[[CONSTANT]]), to_apply=%[[RETURN]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} //CHECK: ROOT %tuple.{{[0-9]+}} = () tuple() //CHECK: } func.func @main() -> () { @@ -70,15 +70,15 @@ module @composite { //CHECK: %[[ADD:add.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> (f32[], f32[]) { //CHECK: %[[ARG]] = f32[] parameter(0) //CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) - //CHECK: %[[ADDOP:add.[0-9]+]] = f32[] add(f32[] %[[ARG]], f32[] %[[CONSTANT]]) - //CHECK: ROOT %tuple.{{[0-9]+}} = (f32[], f32[]) tuple(f32[] %[[ADDOP]], f32[] %[[ADDOP]]) + //CHECK: %[[ADDOP:add.[0-9]+]] = f32[] add(%[[ARG]], %[[CONSTANT]]) + //CHECK: ROOT %tuple.{{[0-9]+}} = (f32[], f32[]) tuple(%[[ADDOP]], %[[ADDOP]]) //CHECK: } //CHECK: ENTRY %main.{{[0-9]+}} () -> (f32[], f32[]) { //CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) - //CHECK: %[[CALL:call.[0-9]+]] = (f32[], f32[]) call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} - //CHECK: %[[GTE0:get-tuple-element.[0-9]+]] = f32[] get-tuple-element((f32[], f32[]) %[[CALL]]), index=0 - //CHECK: %[[GTE1:get-tuple-element.[0-9]+]] = f32[] get-tuple-element((f32[], f32[]) %[[CALL]]), index=1 - //CHECK: ROOT %tuple.{{[0-9]+}} = (f32[], f32[]) tuple(f32[] %[[GTE0]], f32[] %[[GTE1]]) + //CHECK: %[[CALL:call.[0-9]+]] = (f32[], f32[]) call(%[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + //CHECK: %[[GTE0:get-tuple-element.[0-9]+]] = f32[] get-tuple-element(%[[CALL]]), index=0 + //CHECK: %[[GTE1:get-tuple-element.[0-9]+]] = f32[] get-tuple-element(%[[CALL]]), index=1 + //CHECK: ROOT %tuple.{{[0-9]+}} = (f32[], f32[]) tuple(%[[GTE0]], %[[GTE1]]) //CHECK: } func.func @main() -> (tensor, tensor) { %0 = mhlo.constant dense<4.200000e+01> : tensor @@ -108,11 +108,11 @@ module @composite { // CHECK: %[[ADD:add.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> f32[] { // CHECK: %[[ARG]] = f32[] parameter(0) // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) - // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(f32[] %[[ARG]], f32[] %[[CONSTANT]]) + // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(%[[ARG]], %[[CONSTANT]]) // CHECK: } // CHECK: ENTRY %main.{{[0-9]+}} () -> f32[] { // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) - // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"} + // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(%[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"} // CHECK: } func.func @main() -> tensor { %0 = mhlo.constant dense<4.200000e+01> : tensor @@ -137,11 +137,11 @@ module @composite { // CHECK: %[[ADD:add.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> f32[] { // CHECK: %[[ARG]] = f32[] parameter(0) // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) - // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(f32[] %[[ARG]], f32[] %[[CONSTANT]]) + // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(%[[ARG]], %[[CONSTANT]]) // CHECK: } // CHECK: ENTRY %main.{{[0-9]+}} () -> f32[] { // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) - // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="0"} + // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(%[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="0"} // CHECK: } func.func @main() -> tensor { %0 = mhlo.constant dense<4.200000e+01> : tensor @@ -169,11 +169,11 @@ module @composite { // CHECK: %[[ADD:add.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> f32[] { // CHECK: %[[ARG]] = f32[] parameter(0) // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) - // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(f32[] %[[ARG]], f32[] %[[CONSTANT]]) + // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(%[[ARG]], %[[CONSTANT]]) // CHECK: } // CHECK: ENTRY %main.{{[0-9]+}} () -> f32[] { // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) - // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"} + // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(%[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"} // CHECK: } func.func @main() -> tensor { %0 = mhlo.constant dense<4.200000e+01> : tensor diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/dynamic.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/dynamic.mlir index bb83b10cc3f5fb..229e2527d184a5 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/dynamic.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/dynamic.mlir @@ -10,12 +10,12 @@ func.func @main(%arg0: tensor>) func.return %4 : tensor<1x?xi64, #mhlo.type_extensions> // CHECK: %[[ARG0:.*]] = s64[<=4,1] parameter(0) // CHECK-NEXT: %[[SIZE0x1:.*]] = s32[1] constant({1}) - // CHECK-NEXT: %[[SIZE1:.*]] = s32[] get-dimension-size(s64[<=4,1] %[[ARG0]]), dimensions={0} - // CHECK-NEXT: %[[SIZE1x1:.*]] = s32[1] reshape(s32[] %[[SIZE1]]) - // CHECK-NEXT: %[[SHAPE:.*]] = s32[2] concatenate(s32[1] %[[SIZE0x1]], s32[1] %[[SIZE1x1]]), dimensions={0} - // CHECK-NEXT: %[[SHAPE0x1:.*]] = s32[1] slice(s32[2] %[[SHAPE]]), slice={[0:1]} - // CHECK-NEXT: %[[SHAPE0:.*]] = s32[] reshape(s32[1] %[[SHAPE0x1]]) - // CHECK-NEXT: %[[SHAPE1x1:.*]] = s32[1] slice(s32[2] %[[SHAPE]]), slice={[1:2]} - // CHECK-NEXT: %[[SHAPE1:.*]] = s32[] reshape(s32[1] %[[SHAPE1x1]]) - // CHECK-NEXT: ROOT %dynamic-reshape.10 = s64[1,<=4] dynamic-reshape(s64[<=4,1] %[[ARG0]], s32[] %[[SHAPE0]], s32[] %[[SHAPE1]]) + // CHECK-NEXT: %[[SIZE1:.*]] = s32[] get-dimension-size(%[[ARG0]]), dimensions={0} + // CHECK-NEXT: %[[SIZE1x1:.*]] = s32[1] reshape(%[[SIZE1]]) + // CHECK-NEXT: %[[SHAPE:.*]] = s32[2] concatenate(%[[SIZE0x1]], %[[SIZE1x1]]), dimensions={0} + // CHECK-NEXT: %[[SHAPE0x1:.*]] = s32[1] slice(%[[SHAPE]]), slice={[0:1]} + // CHECK-NEXT: %[[SHAPE0:.*]] = s32[] reshape(%[[SHAPE0x1]]) + // CHECK-NEXT: %[[SHAPE1x1:.*]] = s32[1] slice(%[[SHAPE]]), slice={[1:2]} + // CHECK-NEXT: %[[SHAPE1:.*]] = s32[] reshape(%[[SHAPE1x1]]) + // CHECK-NEXT: ROOT %dynamic-reshape.10 = s64[1,<=4] dynamic-reshape(%[[ARG0]], %[[SHAPE0]], %[[SHAPE1]]) } diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index 180c203de88226..abd39ed3de1319 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -19,7 +19,7 @@ func.func @main(%arg0: tensor<2xi1>) -> tensor<2xi1> { // CHECK: ENTRY // CHECK: %[[ARG:.*]] = pred[2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = pred[2] xor(pred[2] %[[ARG]], pred[2] %[[ARG]]) +// CHECK: ROOT %[[RESULT:.*]] = pred[2] xor(%[[ARG]], %[[ARG]]) // ----- @@ -32,7 +32,7 @@ func.func @main(%arg0: !mhlo.token, %arg1: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = token[] parameter(0) // CHECK: %[[ARG1:.*]] = token[] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = token[] after-all(token[] %[[ARG0]], token[] %[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = token[] after-all(%[[ARG0]], %[[ARG1]]) // ----- @@ -69,7 +69,7 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<5xf32> { // CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[5] reduce-scatter(f32[10] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[5] reduce-scatter(%[[ARG0]]) // CHECK-SAME: channel_id=5 // CHECK-SAME{LITERAL}: replica_groups={{0,2},{1,3}} // CHECK-SAME: dimensions={0} @@ -90,7 +90,7 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> { // CHECK: ENTRY // CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) -// CHECK: ROOT %[[OUTPUT:.*]] = f32[128,128] all-gather(f32[128,32] %[[INPUT]]) +// CHECK: ROOT %[[OUTPUT:.*]] = f32[128,128] all-gather(%[[INPUT]]) // CHECK-SAME: channel_id=1 // CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} // CHECK-SAME: dimensions={1} @@ -111,7 +111,7 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> { // CHECK: ENTRY // CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) -// CHECK: ROOT %[[OUTPUT:.*]] = f32[128,128] all-gather(f32[128,32] %[[INPUT]]) +// CHECK: ROOT %[[OUTPUT:.*]] = f32[128,128] all-gather(%[[INPUT]]) // CHECK-SAME: channel_id=1 // CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} // CHECK-SAME: dimensions={1} @@ -123,9 +123,9 @@ func.func private @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple // CHECK: %[[ARG0:.*]] = f32[8,2] parameter(0) // CHECK-NEXT: %[[ARG1:.*]] = f32[8,4] parameter(1) // CHECK-NEXT: %[[TUPLE:.*]] = (f32[8,2], f32[8,4]) tuple - // CHECK-NEXT: %[[TUPLE_ARG0:.*]] = f32[8,2] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=0 - // CHECK-NEXT: %[[TUPLE_ARG1:.*]] = f32[8,4] get-tuple-element((f32[8,2], f32[8,4]) %[[TUPLE]]), index=1 - // CHECK-NEXT: (f32[8,8], f32[8,16]) all-gather(f32[8,2] %[[TUPLE_ARG0]], f32[8,4] %[[TUPLE_ARG1]]), channel_id=1, replica_groups={{.*}}, dimensions={1} + // CHECK-NEXT: %[[TUPLE_ARG0:.*]] = f32[8,2] get-tuple-element(%[[TUPLE]]), index=0 + // CHECK-NEXT: %[[TUPLE_ARG1:.*]] = f32[8,4] get-tuple-element(%[[TUPLE]]), index=1 + // CHECK-NEXT: (f32[8,8], f32[8,16]) all-gather(%[[TUPLE_ARG0]], %[[TUPLE_ARG1]]), channel_id=1, replica_groups={{.*}}, dimensions={1} %0:2 = "mhlo.all_gather"(%arg0, %arg1) { all_gather_dim = 1 : i64, channel_handle = #mhlo.channel_handle, @@ -159,7 +159,7 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { // CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(%[[ARG0]]) // CHECK-SAME: channel_id=5 // CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} // CHECK-SAME: to_apply=%[[COMPUTATION]] @@ -188,7 +188,7 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { // CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(%[[ARG0]]) // CHECK-SAME: channel_id=5 // CHECK-SAME{LITERAL}: replica_groups={{0,2,4},{1,3,5,6}} // CHECK-SAME: to_apply=%[[COMPUTATION]] @@ -217,7 +217,7 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { // CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(%[[ARG0]]) // CHECK-SAME: channel_id=5 // CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} // CHECK-SAME: use_global_device_ids=true @@ -229,9 +229,9 @@ func.func private @main(%arg0: tensor<8xf32>, %arg1: tensor) -> tuple, %arg3: tensor): %2 = mhlo.add %arg2, %arg3 : tensor @@ -265,7 +265,7 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<5xf32> { // CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[5] reduce-scatter(f32[10] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[5] reduce-scatter(%[[ARG0]]) // CHECK-SAME: channel_id=5 // CHECK-SAME{LITERAL}: replica_groups={{0,2},{1,3}} // CHECK-SAME: dimensions={0} @@ -296,7 +296,7 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<5xf32> { // CHECK: %[[COMPUTATION:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[5] reduce-scatter(f32[10] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[5] reduce-scatter(%[[ARG0]]) // CHECK-SAME: channel_id=5 // CHECK-SAME{LITERAL}: replica_groups={{0,2},{1,3}} // CHECK-SAME: use_global_device_ids=true @@ -318,12 +318,12 @@ func.func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %mean: tenso // CHECK: [[VAL_3:%.*]] = f32[2] parameter(2) // CHECK: [[VAL_4:%.*]] = f32[2] parameter(3) // CHECK: [[VAL_5:%.*]] = f32[2,2,2,2] parameter(4) -// CHECK: [[BNG:%.*]] = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]], f32[2] [[VAL_4]], f32[2,2,2,2] [[VAL_5]]), epsilon=0.001, feature_index=0 -// CHECK: [[GTE0:%.*]] = f32[2,2,2,2] get-tuple-element((f32[2,2,2,2], f32[2], f32[2]) [[BNG]]), index=0 -// CHECK: [[GTE1:%.*]] = f32[2] get-tuple-element((f32[2,2,2,2], f32[2], f32[2]) [[BNG]]), index=1 -// CHECK: [[GTE2:%.*]] = f32[2] get-tuple-element((f32[2,2,2,2], f32[2], f32[2]) [[BNG]]), index=2 +// CHECK: [[BNG:%.*]] = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad([[VAL_1]], [[VAL_2]], [[VAL_3]], [[VAL_4]], [[VAL_5]]), epsilon=0.001, feature_index=0 +// CHECK: [[GTE0:%.*]] = f32[2,2,2,2] get-tuple-element([[BNG]]), index=0 +// CHECK: [[GTE1:%.*]] = f32[2] get-tuple-element([[BNG]]), index=1 +// CHECK: [[GTE2:%.*]] = f32[2] get-tuple-element([[BNG]]), index=2 // CHECK: ROOT -// CHECK-SAME: [[RES:%.*]] = (f32[2,2,2,2], f32[2], f32[2]) tuple(f32[2,2,2,2] [[GTE0]], f32[2] [[GTE1]], f32[2] [[GTE2]]) +// CHECK-SAME: [[RES:%.*]] = (f32[2,2,2,2], f32[2], f32[2]) tuple([[GTE0]], [[GTE1]], [[GTE2]]) // ----- @@ -339,12 +339,12 @@ func.func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: ten // CHECK: [[VAL_1:%.*]] = f32[2,2,2,2] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[2] parameter(1) // CHECK: [[VAL_3:%.*]] = f32[2] parameter(2) -// CHECK: [[BNT:%.*]] = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] [[VAL_1]], f32[2] [[VAL_2]], f32[2] [[VAL_3]]), epsilon=0.001, feature_index=3 -// CHECK: [[GTE0:%.*]] = f32[2,2,2,2] get-tuple-element((f32[2,2,2,2], f32[2], f32[2]) [[BNT]]), index=0 -// CHECK: [[GTE1:%.*]] = f32[2] get-tuple-element((f32[2,2,2,2], f32[2], f32[2]) [[BNT]]), index=1 -// CHECK: [[GTE2:%.*]] = f32[2] get-tuple-element((f32[2,2,2,2], f32[2], f32[2]) [[BNT]]), index=2 +// CHECK: [[BNT:%.*]] = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training([[VAL_1]], [[VAL_2]], [[VAL_3]]), epsilon=0.001, feature_index=3 +// CHECK: [[GTE0:%.*]] = f32[2,2,2,2] get-tuple-element([[BNT]]), index=0 +// CHECK: [[GTE1:%.*]] = f32[2] get-tuple-element([[BNT]]), index=1 +// CHECK: [[GTE2:%.*]] = f32[2] get-tuple-element([[BNT]]), index=2 // CHECK: ROOT -// CHECK-SAME: [[RES:%.*]] = (f32[2,2,2,2], f32[2], f32[2]) tuple(f32[2,2,2,2] [[GTE0]], f32[2] [[GTE1]], f32[2] [[GTE2]]) +// CHECK-SAME: [[RES:%.*]] = (f32[2,2,2,2], f32[2], f32[2]) tuple([[GTE0]], [[GTE1]], [[GTE2]]) // ----- @@ -352,22 +352,22 @@ func.func @main(%input: tensor<2x2x2x2xf32>, %scale: tensor<2xf32>, %offset: ten func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xi32>, %arg3: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { // CHECK: [[VAL_1:%.*]] = f32[4] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[4] parameter(1) - // CHECK: [[ATAN2:%.*]] = f32[4] atan2(f32[4] [[VAL_1]], f32[4] [[VAL_2]]) + // CHECK: [[ATAN2:%.*]] = f32[4] atan2([[VAL_1]], [[VAL_2]]) // CHECK: [[VAL_3:%.*]] = s32[4] parameter(2) // CHECK: [[VAL_4:%.*]] = s32[4] parameter(3) %0 = mhlo.atan2 %arg0, %arg1 : tensor<4xf32> - // CHECK: [[SHL:%.*]] = s32[4] shift-left(s32[4] [[VAL_3]], s32[4] [[VAL_4]]) + // CHECK: [[SHL:%.*]] = s32[4] shift-left([[VAL_3]], [[VAL_4]]) %1 = mhlo.shift_left %arg2, %arg3 : tensor<4xi32> - // CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic(s32[4] [[VAL_3]], s32[4] [[VAL_4]]) + // CHECK: [[SHRA:%.*]] = s32[4] shift-right-arithmetic([[VAL_3]], [[VAL_4]]) %2 = mhlo.shift_right_arithmetic %arg2, %arg3 : tensor<4xi32> - // CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical(s32[4] [[VAL_3]], s32[4] [[VAL_4]]) + // CHECK: [[SHRL:%.*]] = s32[4] shift-right-logical([[VAL_3]], [[VAL_4]]) %3 = mhlo.shift_right_logical %arg2, %arg3 : tensor<4xi32> // CHECK: ROOT - // CHECK-SAME: [[VAL_9:%.*]] = (f32[4], s32[4], s32[4], s32[4]) tuple(f32[4] [[ATAN2]], s32[4] [[SHL]], s32[4] [[SHRA]], s32[4] [[SHRL]]) + // CHECK-SAME: [[VAL_9:%.*]] = (f32[4], s32[4], s32[4], s32[4]) tuple([[ATAN2]], [[SHL]], [[SHRA]], [[SHRL]]) func.return %0, %1, %2, %3 : tensor<4xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32> } @@ -381,14 +381,14 @@ func.func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { // CHECK: ENTRY // CHECK: %[[ARG:.*]] = s32[2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[2] bitcast-convert(s32[2] %[[ARG]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[2] bitcast-convert(%[[ARG]]) // ----- // CHECK: HloModule func.func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { // CHECK: [[ARG:%.*]] = s32[4] parameter(0) - // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] [[ARG]]), dimensions={3} + // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast([[ARG]]), dimensions={3} %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>}> : (tensor<4xi32>) -> tensor<1x2x3x4xi32> func.return %0 : tensor<1x2x3x4xi32> } @@ -405,7 +405,7 @@ func.func @main(%arg0: tensor<1xf32>) -> tensor<1x10xf32> { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = f32[1] parameter(0) -// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast(f32[1] [[ARG]]), dimensions={0} +// CHECK: ROOT %broadcast.2 = f32[1,10] broadcast([[ARG]]), dimensions={0} // ----- @@ -456,19 +456,19 @@ func.func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK: %[[ARG_1]] = s32[4] parameter(0) // CHECK: %[[ARG_2]] = s32[4] parameter(1) // CHECK: ROOT -// CHECK-SAME: s32[4] add(s32[4] %[[ARG_1]], s32[4] %[[ARG_2]]) +// CHECK-SAME: s32[4] add(%[[ARG_1]], %[[ARG_2]]) // CHECK: [[CALLEE_2:%.*]] ([[ARG_3:.*]]: s32[4], [[ARG_4:.*]]: s32[4]) -> s32[4] { // CHECK: %[[ARG_3]] = s32[4] parameter(0) // CHECK: %[[ARG_4]] = s32[4] parameter(1) // CHECK: ROOT -// CHECK-SAME: s32[4] add(s32[4] %[[ARG_3]], s32[4] %[[ARG_4]]) +// CHECK-SAME: s32[4] add(%[[ARG_3]], %[[ARG_4]]) // CHECK: ENTRY [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> s32[4] { // CHECK: %[[ARG]] = s32[4] parameter(0) -// CHECK: [[CALL_OUT:%.*]] = s32[4] call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE_1]] +// CHECK: [[CALL_OUT:%.*]] = s32[4] call(%[[ARG]], %[[ARG]]), to_apply=[[CALLEE_1]] // CHECK: ROOT -// CHECK-SAME: s32[4] call(s32[4] [[CALL_OUT]], s32[4] [[CALL_OUT]]), to_apply=[[CALLEE_2]] +// CHECK-SAME: s32[4] call([[CALL_OUT]], [[CALL_OUT]]), to_apply=[[CALLEE_2]] // ----- @@ -490,11 +490,11 @@ func.func @callee(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, // CHECK: ENTRY // CHECK-SAME: [[MAIN:%.*]] ([[ARG:.*]]: s32[4]) -> (s32[4], s32[4]) { // CHECK: %[[ARG]] = s32[4] parameter(0) -// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(s32[4] %[[ARG]], s32[4] %[[ARG]]), to_apply=[[CALLEE]] -// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=0 -// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element((s32[4], s32[4]) [[CALL_OUT]]), index=1 +// CHECK: [[CALL_OUT:%.*]] = (s32[4], s32[4]) call(%[[ARG]], %[[ARG]]), to_apply=[[CALLEE]] +// CHECK: [[OUT_0:%.*]] = s32[4] get-tuple-element([[CALL_OUT]]), index=0 +// CHECK: [[OUT_1:%.*]] = s32[4] get-tuple-element([[CALL_OUT]]), index=1 // CHECK: ROOT -// CHECK-SAME: (s32[4], s32[4]) tuple(s32[4] [[OUT_0]], s32[4] [[OUT_1]]) +// CHECK-SAME: (s32[4], s32[4]) tuple([[OUT_0]], [[OUT_1]]) // ----- @@ -508,7 +508,7 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { } // CHECK: ENTRY // CHECK: [[ARG:%.*]] = f32[128,32] parameter(0) -// CHECK: ROOT [[RESULT:%.*]] = f32[128,32] collective-broadcast(f32[128,32] [[ARG]]), channel_id=1 +// CHECK: ROOT [[RESULT:%.*]] = f32[128,32] collective-broadcast([[ARG]]), channel_id=1 // CHECK-SAME{LITERAL}: replica_groups={{0,1},{2,3}} // ----- @@ -522,7 +522,7 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { } // CHECK: ENTRY // CHECK: [[ARG:%.*]] = f32[128,32] parameter(0) -// CHECK: ROOT [[RESULT:%.*]] = f32[128,32] collective-permute(f32[128,32] [[ARG]]), channel_id=1, source_target_pairs={{\{\{}}0,1},{1,2},{2,3}} +// CHECK: ROOT [[RESULT:%.*]] = f32[128,32] collective-permute([[ARG]]), channel_id=1, source_target_pairs={{\{\{}}0,1},{1,2},{2,3}} // ----- @@ -540,7 +540,7 @@ func.func @main(%arg0 : tensor<5x2xf32>, // CHECK: %[[ARG0:.*]] = f32[5,2] parameter(0) // CHECK: %[[ARG1:.*]] = f32[5,5] parameter(1) // CHECK: %[[ARG2:.*]] = f32[5,7] parameter(2) -// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(f32[5,2] %[[ARG0]], f32[5,5] %[[ARG1]], f32[5,7] %[[ARG2]]), dimensions={1} +// CHECK: ROOT %[[RESULT:.*]] = f32[5,14] concatenate(%[[ARG0]], %[[ARG1]], %[[ARG2]]), dimensions={1} // ----- @@ -558,7 +558,7 @@ func.func @main() { %cst_1 = arith.constant dense<1> : tensor<1xi32> // CHECK: %[[C:.*]] = s32[] constant(1) - // CHECK: s32[10] broadcast(s32[] %[[C]]) + // CHECK: s32[10] broadcast(%[[C]]) %cst_2 = arith.constant dense<1> : tensor<10xi32> // CHECK: s32[4] constant({1, 2, 3, 4}) @@ -644,7 +644,7 @@ func.func @main(%arg0 : tensor<100x26x26x32xf32>, %arg1 : tensor<3x3x1x32xf32>) // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[100,26,26,32] parameter(0) // CHECK: %[[ARG1:.*]] = f32[3,3,1,32] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(f32[100,26,26,32] %[[ARG0]], f32[3,3,1,32] %[[ARG1]]), +// CHECK: ROOT %[[RESULT:.*]] = f32[100,28,28,1] convolution(%[[ARG0]], %[[ARG1]]), // CHECK-SAME: window={size=3x3 pad=2_2x2_2}, // CHECK-SAME: dim_labels=b01f_01oi->b01f @@ -678,7 +678,7 @@ func.func @main(%arg0 : tensor<100x26x26x32xi8>, %arg1 : tensor<3x3x1x32xi8>) -> // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = s8[100,26,26,32] parameter(0) // CHECK: %[[ARG1:.*]] = s8[3,3,1,32] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = s32[100,28,28,1] convolution(s8[100,26,26,32] %[[ARG0]], s8[3,3,1,32] %[[ARG1]]), +// CHECK: ROOT %[[RESULT:.*]] = s32[100,28,28,1] convolution(%[[ARG0]], %[[ARG1]]), // CHECK-SAME: window={size=3x3 pad=2_2x2_2}, // CHECK-SAME: dim_labels=b01f_01oi->b01f @@ -713,7 +713,7 @@ func.func @main(%arg0 : tensor<100x26x26x32xi8>, %arg1 : tensor<3x3x1x32xi8>) -> // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = s8[100,26,26,32] parameter(0) // CHECK: %[[ARG1:.*]] = s8[3,3,1,32] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = s32[100,28,28,1] convolution(s8[100,26,26,32] %[[ARG0]], s8[3,3,1,32] %[[ARG1]]), +// CHECK: ROOT %[[RESULT:.*]] = s32[100,28,28,1] convolution(%[[ARG0]], %[[ARG1]]), // CHECK-SAME: window={size=3x3 pad=2_2x2_2 rhs_reversal=1x1}, // CHECK-SAME: dim_labels=b01f_01oi->b01f @@ -727,7 +727,7 @@ func.func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { // CHECK: ENTRY // CHECK: %[[ARG:.*]] = s32[2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(s32[2] %[[ARG]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(%[[ARG]]) // ----- @@ -754,22 +754,22 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: ENTRY // CHECK: %[[ARG:.*]] = f32[2] parameter(0) -// CHECK: %[[E5M2_VAL:.*]] = f8e5m2[2] convert(f32[2] %[[ARG]]) -// CHECK: %[[F32_VAL:.*]] = f32[2] convert(f8e5m2[2] %[[E5M2_VAL]]) -// CHECK: %[[E4M3FN_VAL:.*]] = f8e4m3fn[2] convert(f32[2] %[[F32_VAL]]) -// CHECK: %[[F32_VAL2:.*]] = f32[2] convert(f8e4m3fn[2] %[[E4M3FN_VAL]]) -// CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]]) -// CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]]) -// CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]]) -// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) -// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) -// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) -// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) -// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) -// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(f32[2] %[[F32_VAL6]]) -// CHECK: %[[F32_VAL7:.*]] = f32[2] convert(f4e2m1fn[2] %[[E2M1FN_VAL]]) -// CHECK: %[[E8M0FNU_VAL:.*]] = f8e8m0fnu[2] convert(f32[2] %[[F32_VAL7]]) -// CHECK: ROOT %[[F32_VAL8:.*]] = f32[2] convert(f8e8m0fnu[2] %[[E8M0FNU_VAL]]) +// CHECK: %[[E5M2_VAL:.*]] = f8e5m2[2] convert(%[[ARG]]) +// CHECK: %[[F32_VAL:.*]] = f32[2] convert(%[[E5M2_VAL]]) +// CHECK: %[[E4M3FN_VAL:.*]] = f8e4m3fn[2] convert(%[[F32_VAL]]) +// CHECK: %[[F32_VAL2:.*]] = f32[2] convert(%[[E4M3FN_VAL]]) +// CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(%[[F32_VAL2]]) +// CHECK: %[[F32_VAL3:.*]] = f32[2] convert(%[[E4M3FNUZ_VAL]]) +// CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(%[[F32_VAL3]]) +// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(%[[E5M2FNUZ_VAL]]) +// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(%[[F32_VAL4]]) +// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(%[[E4M3_VAL]]) +// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(%[[F32_VAL5]]) +// CHECK: %[[F32_VAL6:.*]] = f32[2] convert(%[[E3M4_VAL]]) +// CHECK: %[[E2M1FN_VAL:.*]] = f4e2m1fn[2] convert(%[[F32_VAL6]]) +// CHECK: %[[F32_VAL7:.*]] = f32[2] convert(%[[E2M1FN_VAL]]) +// CHECK: %[[E8M0FNU_VAL:.*]] = f8e8m0fnu[2] convert(%[[F32_VAL7]]) +// CHECK: ROOT %[[F32_VAL8:.*]] = f32[2] convert(%[[E8M0FNU_VAL]]) // ----- @@ -782,7 +782,7 @@ func.func @main(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xui32>) -> tensor<5x5xi // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[5,5] parameter(0) // CHECK: %[[ARG1:.*]] = u32[5,5] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = s8[5,5] stochastic-convert(f32[5,5] %[[ARG0]], u32[5,5] %[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = s8[5,5] stochastic-convert(%[[ARG0]], %[[ARG1]]) // ----- @@ -794,7 +794,7 @@ func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = s32[2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]]) +// CHECK: ROOT %[[RESULT:.*]] = s32[2] copy([[ARG]]) // ----- @@ -806,11 +806,11 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { } // CHECK: %[[SUM_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] -// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[] add(%[[ARG0]], %[[ARG1]]) // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(f32[10] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[10] all-reduce(%[[ARG0]]) // CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} // CHECK-SAME: to_apply=%[[SUM_COMPUTATION]] @@ -825,7 +825,7 @@ func.func @main(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: ENTRY // CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0) // CHECK: ROOT -// CHECK-SAME: f32[2,3] custom-call(f32[2,3] [[VAL_1]]) +// CHECK-SAME: f32[2,3] custom-call([[VAL_1]]) // CHECK-SAME: custom_call_target="SetBound" // CHECK-SAME: literal=s32[] 1 @@ -845,7 +845,7 @@ func.func @main(%arg0: tensor<6xf32>, %arg1: tensor<6xf32>, %arg2: tensor<3xi32> // CHECK: [[ARG_4:%.*]] = s32[3] parameter(4) // CHECK: [[ARG_5:%.*]] = s32[3] parameter(5) // CHECK: ROOT -// CHECK-SAME: f32[6] ragged-all-to-all(f32[6] [[ARG_0]], f32[6] [[ARG_1]], s32[3] [[ARG_2]], s32[3] [[ARG_3]], s32[3] [[ARG_4]], /*index=5*/s32[3] [[ARG_5]]) +// CHECK-SAME: f32[6] ragged-all-to-all([[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_3]], [[ARG_4]], /*index=5*/[[ARG_5]]) // CHECK-SAME{LITERAL}: channel_id=1, replica_groups={{0,1,2}} // ----- @@ -876,9 +876,9 @@ func.func public @main(%arg0: tensor<16x256xbf16>, %arg1: tensor, %arg2: te // CHECK-DAG: [[ARG1:%.*]] = s32[] parameter(1) // CHECK-DAG: [[ARG2:%.*]] = s32[16,256] parameter(2) // CHECK-DAG: [[ARG3:%.*]] = bf16[] parameter(3) -// CHECK-DAG: [[VAL0:%.*]] = (bf16[16,256], s32[16,256]) sort(bf16[16,256] [[ARG0]], s32[16,256] [[ARG2]]) -// CHECK-DAG: [[VAL1:%.*]] = s32[16,256] get-tuple-element((bf16[16,256], s32[16,256]) [[VAL0]]) -// CHECK-DAG: [[VAL2:%.*]] = s32[16,4] slice(s32[16,256] [[VAL1]]) +// CHECK-DAG: [[VAL0:%.*]] = (bf16[16,256], s32[16,256]) sort([[ARG0]], [[ARG2]]) +// CHECK-DAG: [[VAL1:%.*]] = s32[16,256] get-tuple-element([[VAL0]]) +// CHECK-DAG: [[VAL2:%.*]] = s32[16,4] slice([[VAL1]]) // ----- @@ -908,14 +908,14 @@ func.func public @main(%arg0: tensor<16x256xbf16>, %arg1: tensor, %arg2: te // CHECK: s32[] parameter(3) // CHECK: [[ARG0:%.*]] = bf16[] parameter(0) // CHECK: [[ARG1:%.*]] = bf16[] parameter(1) -// CHECK: ROOT [[VAL:%.*]] = pred[] compare(bf16[] [[ARG0]], bf16[] [[ARG1]]), direction=GT +// CHECK: ROOT [[VAL:%.*]] = pred[] compare([[ARG0]], [[ARG1]]), direction=GT // CHECK: ENTRY // CHECK-DAG: [[ARG0:%.*]] = bf16[16,256] parameter(0) // CHECK-DAG: [[ARG1:%.*]] = s32[] parameter(1) // CHECK-DAG: [[ARG2:%.*]] = s32[16,256] parameter(2) // CHECK-DAG: [[ARG3:%.*]] = bf16[] parameter(3) -// CHECK-DAG: (bf16[16,128], s32[16,128]) custom-call(bf16[16,256] [[ARG0]], s32[16,256] [[ARG2]], bf16[] [[ARG3]], s32[] [[ARG1]]), +// CHECK-DAG: (bf16[16,128], s32[16,128]) custom-call([[ARG0]], [[ARG2]], [[ARG3]], [[ARG1]]), // CHECK-SAME: custom_call_target="PartialReduce", called_computations={%top_k_gt_comparator.[[COMPARATOR]]} // CHECK-SAME: backend_config={"log2_reduction": 1, "reduction_dim": 1, "to_apply_type": "comparator", "top_k": 4, "recall_target": 0.949218} @@ -1522,7 +1522,7 @@ func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3x // CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[5,5] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]) +// CHECK-SAME: f32[1,2,3] custom-call([[VAL_1]], [[VAL_2]]) // CHECK-SAME: custom_call_target="foo" // CHECK-SAME: custom_call_has_side_effect=true // CHECK-SAME: schedule=SCHEDULE_LATEST @@ -1540,7 +1540,7 @@ func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3x // CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[5,5] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]) +// CHECK-SAME: f32[1,2,3] custom-call([[VAL_1]], [[VAL_2]]) // CHECK-SAME: custom_call_target="foo" // CHECK-SAME: custom_call_has_side_effect=true // CHECK-SAME: schedule=SCHEDULE_EARLIEST @@ -1557,7 +1557,7 @@ func.func @main(%arg0: tensor<2x3xf32>) -> tuple> { // CHECK: ENTRY // CHECK: [[ARG0:%.*]] = f32[2,3] parameter(0) // CHECK: ROOT -// CHECK-SAME: (f32[2,3]) custom-call(f32[2,3] [[ARG0]]) +// CHECK-SAME: (f32[2,3]) custom-call([[ARG0]]) // CHECK-SAME: custom_call_target="foo" // ----- @@ -1571,7 +1571,7 @@ func.func @main(%arg0: tensor<2x3xf32>) -> tuple, tensor<4x5xf16 // CHECK: ENTRY // CHECK: [[ARG0:%.*]] = f32[2,3] parameter(0) // CHECK: ROOT -// CHECK-SAME: (f32[2,3], f16[4,5]) custom-call(f32[2,3] [[ARG0]]) +// CHECK-SAME: (f32[2,3], f16[4,5]) custom-call([[ARG0]]) // CHECK-SAME: custom_call_target="foo" // ----- @@ -1584,12 +1584,12 @@ func.func @main(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<4x5xf16>) { // CHECK: ENTRY // CHECK: [[ARG0:%.*]] = f32[2,3] parameter(0) -// CHECK: [[OUTS:%.*]] = (f32[2,3], f16[4,5]) custom-call(f32[2,3] [[ARG0]]) +// CHECK: [[OUTS:%.*]] = (f32[2,3], f16[4,5]) custom-call([[ARG0]]) // CHECK-SAME: custom_call_target="foo" -// CHECK-DAG: [[OUT0:%.*]] = f32[2,3] get-tuple-element((f32[2,3], f16[4,5]) [[OUTS]]), index=0 -// CHECK-DAG: [[OUT1:%.*]] = f16[4,5] get-tuple-element((f32[2,3], f16[4,5]) [[OUTS]]), index=1 +// CHECK-DAG: [[OUT0:%.*]] = f32[2,3] get-tuple-element([[OUTS]]), index=0 +// CHECK-DAG: [[OUT1:%.*]] = f16[4,5] get-tuple-element([[OUTS]]), index=1 // CHECK: ROOT -// CHECK-SAME: (f32[2,3], f16[4,5]) tuple(f32[2,3] [[OUT0]], f16[4,5] [[OUT1]]) +// CHECK-SAME: (f32[2,3], f16[4,5]) tuple([[OUT0]], [[OUT1]]) // ----- @@ -1605,7 +1605,7 @@ func.func @main(%arg0: tensor<3xi8>, %arg1: tensor<3xi8>) -> tensor { // CHECK: %[[ARG0]] = s8[3] parameter(0) // CHECK: %[[ARG1]] = s8[3] parameter(1) // CHECK: ROOT -// CHECK-SAME: s64[] dot(s8[3] %[[ARG0]], s8[3] %[[ARG1]]), +// CHECK-SAME: s64[] dot(%[[ARG0]], %[[ARG1]]) // ----- @@ -1620,7 +1620,7 @@ func.func @main(%arg0: tensor<3xi4>, %arg1: tensor<3xi4>) -> tensor { // CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: s4[3], [[ARG_2:.*]]: s4[3]) -> s8[] // CHECK: %[[ARG_1:.*]] = s4[3] parameter(0) // CHECK: %[[ARG_2:.*]] = s4[3] parameter(1) -// CHECK: ROOT %[[DOT:.*]] = s8[] dot(s4[3] %[[ARG_1:.*]], s4[3] %[[ARG_2:.*]]) +// CHECK: ROOT %[[DOT:.*]] = s8[] dot(%[[ARG_1:.*]], %[[ARG_2:.*]]) // ----- @@ -1635,7 +1635,7 @@ func.func @main(%arg0: tensor<3xui4>, %arg1: tensor<3xui4>) -> tensor { // CHECK: [[CALLEE_1:%.*]] ([[ARG_1:.*]]: u4[3], [[ARG_2:.*]]: u4[3]) -> u8[] // CHECK: %[[ARG_1:.*]] = u4[3] parameter(0) // CHECK: %[[ARG_2:.*]] = u4[3] parameter(1) -// CHECK: ROOT %[[DOT:.*]] = u8[] dot(u4[3] %[[ARG_1:.*]], u4[3] %[[ARG_2:.*]]) +// CHECK: ROOT %[[DOT:.*]] = u8[] dot(%[[ARG_1:.*]], %[[ARG_2:.*]]) // ----- @@ -1658,7 +1658,7 @@ func.func @main(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>) -> tensor<2x2x // CHECK: %[[ARG0]] = s8[2,2,2] parameter(0) // CHECK: %[[ARG1]] = s8[2,2,3] parameter(1) // CHECK: ROOT -// CHECK-SAME: s32[2,2,3] dot(s8[2,2,2] %[[ARG0]], s8[2,2,3] %[[ARG1]]), +// CHECK-SAME: s32[2,2,3] dot(%[[ARG0]], %[[ARG1]]), // CHECK-SAME: lhs_batch_dims={0} // CHECK-SAME: lhs_contracting_dims={2} // CHECK-SAME: rhs_batch_dims={0} @@ -1668,7 +1668,10 @@ func.func @main(%arg0: tensor<2x2x2xi8>, %arg1: tensor<2x2x3xi8>) -> tensor<2x2x // CHECK: HloModule func.func @main(%arg0: tensor<10x16xbf16>, %arg1: tensor<32x20xbf16>, %meta: tensor<10x2xui16>) -> tensor<10x20xf32> { - // CHECK: dot(bf16[10,16] %{{.*}}, bf16[32,20] %{{.*}}, u16[10,2] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 + // CHECK: [[ARG0:%.*]] = bf16[10,16] parameter(0) + // CHECK: [[ARG1:%.*]] = bf16[32,20] parameter(1) + // CHECK: [[META:%.*]] = u16[10,2] parameter(2) + // CHECK: dot([[ARG0]], [[ARG1]], [[META]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sparsity=L.1@2:4 %0 = "mhlo.sparse_dot"(%arg0, %arg1, %meta) { lhs_sparsity = #mhlo.sparsity, dot_dimension_numbers = #mhlo.dot< @@ -1684,7 +1687,9 @@ func.func @main(%arg0: tensor<10x16xbf16>, %arg1: tensor<32x20xbf16>, %meta: ten // CHECK: HloModule func.func @main(%arg0: tensor<3x4xi32>, %arg1: tensor<4x5xi32>) -> tensor<3x5xi32> { // Simple einsum is lowered to HLO dot op. - // CHECK: dot(s32[3,4] %{{.*}}, s32[4,5] %{{.*}}), lhs_contracting_dims={1}, rhs_contracting_dims={0} + // CHECK: [[ARG0:%.*]] = s32[3,4] parameter(0) + // CHECK: [[ARG1:%.*]] = s32[4,5] parameter(1) + // CHECK: dot([[ARG0]], [[ARG1]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} %0 = "mhlo.einsum"(%arg0, %arg1) <{einsum_config = "ab,bc->ac"}> : (tensor<3x4xi32>, tensor<4x5xi32>) -> tensor<3x5xi32> func.return %0 : tensor<3x5xi32> } @@ -1699,7 +1704,7 @@ func.func @main(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = f32[3,9] parameter(0) -// CHECK: c64[3,5] fft(f32[3,9] [[ARG]]), fft_type=RFFT, fft_length={9} +// CHECK: c64[3,5] fft([[ARG]]), fft_type=RFFT, fft_length={9} // ----- @@ -1707,7 +1712,7 @@ func.func @main(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex> { func.func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x300xf32> { // CHECK: [[ARG0:%.*]] = f32[200,100,300] parameter(0) // CHECK: [[ARG1:%.*]] = s32[10,2] parameter(1) - // CHECK: f32[10,300] gather(f32[200,100,300] [[ARG0]], s32[10,2] [[ARG1]]) + // CHECK: f32[10,300] gather([[ARG0]], [[ARG1]]) // CHECK-SAME: offset_dims={1} // CHECK-SAME: collapsed_slice_dims={0,1} // CHECK-SAME: start_index_map={0,1} @@ -1733,7 +1738,7 @@ func.func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tens func.func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<100x200x1xi32>) -> tensor<100x200x300xf32> { // CHECK: [[ARG0:%.*]] = f32[200,100,300] parameter(0) // CHECK: [[ARG1:%.*]] = s32[100,200,1] parameter(1) - // CHECK: f32[100,200,300] gather(f32[200,100,300] [[ARG0]], s32[100,200,1] [[ARG1]]) + // CHECK: f32[100,200,300] gather([[ARG0]], [[ARG1]]) // CHECK-SAME: offset_dims={2} // CHECK-SAME: collapsed_slice_dims={} // CHECK-SAME: start_index_map={2} @@ -1768,8 +1773,8 @@ func.func @main(%arg: tensor<4x2xf32>, %size: tensor) -> tensor { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = f32[4,2] parameter(0) // CHECK: [[SIZE:%.*]] = s32[] parameter(1) -// CHECK: [[DYNAMIC:%.*]] = f32[4,<=2] set-dimension-size(f32[4,2] [[ARG]], s32[] [[SIZE]]), dimensions={1} -// CHECK: ROOT %[[RESULT:.*]] = s32[] get-dimension-size(f32[4,<=2] [[DYNAMIC]]), dimensions={1} +// CHECK: [[DYNAMIC:%.*]] = f32[4,<=2] set-dimension-size([[ARG]], [[SIZE]]), dimensions={1} +// CHECK: ROOT %[[RESULT:.*]] = s32[] get-dimension-size([[DYNAMIC]]), dimensions={1} // ----- @@ -1784,7 +1789,7 @@ func.func @main(%arg: tensor>) - // CHECK: ENTRY // CHECK: [[ARG:%.*]] = f32[<=8,4] parameter(0) // CHECK: [[SIZE:%.*]] = s32[] constant(8) -// CHECK: ROOT [[DYNAMIC:%.*]] = f32[8,4] set-dimension-size(f32[<=8,4] [[ARG]], s32[] [[SIZE]]), dimensions={0} +// CHECK: ROOT [[DYNAMIC:%.*]] = f32[8,4] set-dimension-size([[ARG]], [[SIZE]]), dimensions={0} // ----- @@ -1796,7 +1801,7 @@ func.func @main(%arg0: tuple, tensor>) -> tensor { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = (f32[], s32[]) parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element((f32[], s32[]) %[[ARG0]]), index=0 +// CHECK: ROOT %[[RESULT:.*]] = f32[] get-tuple-element(%[[ARG0]]), index=0 // ----- @@ -1811,11 +1816,11 @@ func.func @main(%arg0: !mhlo.token) -> tuple, tensor>, // CHECK: ENTRY // CHECK: [[ARG:%.*]] = token[] parameter(0) -// CHECK: [[INFEED:%.*]] = ((s32[3,3], pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" -// CHECK: [[GTE1:%.*]] = (s32[3,3], pred[]) get-tuple-element(((s32[3,3], pred[]), token[]) [[INFEED]]), index=0 -// CHECK: [[GTE2:%.*]] = s32[3,3] get-tuple-element((s32[3,3], pred[]) [[GTE1]]), index=0 -// CHECK: [[GTE3:%.*]] = pred[] get-tuple-element((s32[3,3], pred[]) [[GTE1]]), index=1 -// CHECK: [[GTE4:%.*]] = token[] get-tuple-element(((s32[3,3], pred[]), token[]) [[INFEED]]), index=1 +// CHECK: [[INFEED:%.*]] = ((s32[3,3], pred[]), token[]) infeed([[ARG]]), infeed_config="foobar" +// CHECK: [[GTE1:%.*]] = (s32[3,3], pred[]) get-tuple-element([[INFEED]]), index=0 +// CHECK: [[GTE2:%.*]] = s32[3,3] get-tuple-element([[GTE1]]), index=0 +// CHECK: [[GTE3:%.*]] = pred[] get-tuple-element([[GTE1]]), index=1 +// CHECK: [[GTE4:%.*]] = token[] get-tuple-element([[INFEED]]), index=1 // ----- @@ -1827,10 +1832,10 @@ func.func @main(%arg0: !mhlo.token) -> tensor<3x3xi32> { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = token[] parameter(0) -// CHECK: [[INFEED:%.*]] = ((s32[3,3]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" -// CHECK: [[GTE0:%.*]] = (s32[3,3]) get-tuple-element(((s32[3,3]), token[]) [[INFEED]]), index=0 -// CHECK: ROOT [[GTE1:%.*]] = s32[3,3] get-tuple-element((s32[3,3]) [[GTE0]]), index=0 -// CHECK: [[GTE2:%.*]] = token[] get-tuple-element(((s32[3,3]), token[]) [[INFEED]]), index=1 +// CHECK: [[INFEED:%.*]] = ((s32[3,3]), token[]) infeed([[ARG]]), infeed_config="foobar" +// CHECK: [[GTE0:%.*]] = (s32[3,3]) get-tuple-element([[INFEED]]), index=0 +// CHECK: ROOT [[GTE1:%.*]] = s32[3,3] get-tuple-element([[GTE0]]), index=0 +// CHECK: [[GTE2:%.*]] = token[] get-tuple-element([[INFEED]]), index=1 // ----- @@ -1843,8 +1848,8 @@ func.func @main(%arg0: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = token[] parameter(0) -// CHECK: [[INFEED:%.*]] = ((), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" -// CHECK: ROOT [[TOKEN:%.*]] = token[] get-tuple-element(((), token[]) [[INFEED]]), index=1 +// CHECK: [[INFEED:%.*]] = ((), token[]) infeed([[ARG]]), infeed_config="foobar" +// CHECK: ROOT [[TOKEN:%.*]] = token[] get-tuple-element([[INFEED]]), index=1 // ----- @@ -1875,14 +1880,14 @@ func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: [[ARG_0:%.*]] = f32[] parameter(0) // CHECK: [[ARG_1:%.*]] = f32[] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[] add(f32[] [[ARG_0]], f32[] [[ARG_1]]) +// CHECK-SAME: f32[] add([[ARG_0]], [[ARG_1]]) // CHECK: } // CHECK: ENTRY // CHECK: [[ARG_2:%.*]] = f32[4] parameter(0) // CHECK: [[ARG_3:%.*]] = f32[4] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[4] map(f32[4] [[ARG_2]], f32[4] [[ARG_3]]), dimensions={0}, to_apply=[[COMPUTATION]] +// CHECK-SAME: f32[4] map([[ARG_2]], [[ARG_3]]), dimensions={0}, to_apply=[[COMPUTATION]] // ----- @@ -1903,7 +1908,7 @@ func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xi32>) -> tensor<4xf32> { // CHECK: [[ARG_2:%.*]] = f32[4] parameter(0) // CHECK: [[ARG_3:%.*]] = s32[4] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[4] map(f32[4] [[ARG_2]], s32[4] [[ARG_3]]), dimensions={0}, to_apply=[[COMPUTATION]] +// CHECK-SAME: f32[4] map([[ARG_2]], [[ARG_3]]), dimensions={0}, to_apply=[[COMPUTATION]] // ----- @@ -1916,9 +1921,9 @@ func.func @main(%data: tensor<3xi32>, %token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: [[DATA:%.*]] = s32[3] parameter(0) -// CHECK-DAG: [[DATATUPLE:%.*]] = (s32[3]) tuple(s32[3] [[DATA]]) +// CHECK-DAG: [[DATATUPLE:%.*]] = (s32[3]) tuple([[DATA]]) // CHECK-DAG: [[TOKEN:%.*]] = token[] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = token[] outfeed((s32[3]) [[DATATUPLE]], token[] [[TOKEN]]), outfeed_shape=(s32[3]{0}), outfeed_config="foobar" +// CHECK: ROOT %[[RESULT:.*]] = token[] outfeed([[DATATUPLE]], [[TOKEN]]), outfeed_shape=(s32[3]{0}), outfeed_config="foobar" // ----- @@ -1945,15 +1950,15 @@ func.func @main(%data: tensor<3x2xi32>, %token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: [[DATA:%.*]] = s32[3,2] parameter(0) -// CHECK: [[SHARD:%.*]] = s32[3,2] custom-call(s32[3,2] [[DATA]]) +// CHECK: [[SHARD:%.*]] = s32[3,2] custom-call([[DATA]]) // CHECK-SAME: custom_call_target="Sharding" // CHECK-SAME: sharding={devices=[1,2]0,1} -// CHECK: [[FULL:%.*]] = s32[6,2] custom-call(s32[3,2] [[SHARD]]) +// CHECK: [[FULL:%.*]] = s32[6,2] custom-call([[SHARD]]) // CHECK-SAME: custom_call_target="SPMDShardToFullShape" // CHECK-SAME: sharding={devices=[1,2]0,1} -// CHECK-DAG: [[DATATUPLE:%.*]] = (s32[6,2]) tuple(s32[6,2] [[FULL]]) +// CHECK-DAG: [[DATATUPLE:%.*]] = (s32[6,2]) tuple([[FULL]]) // CHECK-DAG: [[TOKEN:%.*]] = token[] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = token[] outfeed((s32[6,2]) [[DATATUPLE]], token[] [[TOKEN]]), outfeed_shape=(s32[6,2]{1,0}), outfeed_config="foobar", +// CHECK: ROOT %[[RESULT:.*]] = token[] outfeed([[DATATUPLE]], [[TOKEN]]), outfeed_shape=(s32[6,2]{1,0}), outfeed_config="foobar", // CHECK-SAME: sharding={ // CHECK-SAME: {devices=[2,1]0,1}, {maximal device=0} // CHECK-SAME: } @@ -1969,9 +1974,9 @@ func.func @main(%data1: tensor<3xi32>, %data2: tensor<3xi32>, %token: !mhlo.toke // CHECK: ENTRY // CHECK: [[DATA1:%.*]] = s32[3] parameter(0) // CHECK: [[DATA2:%.*]] = s32[3] parameter(1) -// CHECK-DAG: [[TUPLE:%.*]] = (s32[3], s32[3]) tuple(s32[3] [[DATA1]], s32[3] [[DATA2]]) +// CHECK-DAG: [[TUPLE:%.*]] = (s32[3], s32[3]) tuple([[DATA1]], [[DATA2]]) // CHECK-DAG: [[TOKEN:%.*]] = token[] parameter(2) -// CHECK: ROOT %[[RESULT:.*]] = token[] outfeed((s32[3], s32[3]) [[TUPLE]], token[] [[TOKEN]]), outfeed_shape=(s32[3]{0}, s32[3]{0}), outfeed_config="foobar" +// CHECK: ROOT %[[RESULT:.*]] = token[] outfeed([[TUPLE]], [[TOKEN]]), outfeed_shape=(s32[3]{0}, s32[3]{0}), outfeed_config="foobar" // ----- @@ -1984,7 +1989,7 @@ func.func @main(%token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK-DAG: [[EMPTY_TUPLE:%.*]] = () tuple() // CHECK-DAG: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: ROOT [[RESULT:%.*]] = token[] outfeed(() [[EMPTY_TUPLE]], token[] [[TOKEN]]), outfeed_shape=(), outfeed_config="foobar" +// CHECK: ROOT [[RESULT:%.*]] = token[] outfeed([[EMPTY_TUPLE]], [[TOKEN]]), outfeed_shape=(), outfeed_config="foobar" // ----- @@ -1998,7 +2003,7 @@ func.func @main(%arg: tensor<4x6xf32>, %pad: tensor) -> tensor<13x19xf32> { // CHECK: [[ARG:%.*]] = f32[4,6] parameter(0) // CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1 +// CHECK-SAME: f32[13,19] pad([[ARG]], [[PADDING_VAL]]), padding=2_4_1x3_5_1 // ----- @@ -2017,8 +2022,8 @@ func.func @main(%token: !mhlo.token) -> tuple, !mhlo.token> { // CHECK: ENTRY // CHECK: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5, is_host_transfer=true -// CHECK: (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5, is_host_transfer=true +// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv([[TOKEN]]), channel_id=5, is_host_transfer=true +// CHECK: (s32[3,4], token[]) recv-done([[RECV]]), channel_id=5, is_host_transfer=true // ----- @@ -2037,8 +2042,8 @@ func.func @main(%token: !mhlo.token) -> tuple, !mhlo.token> { // CHECK: ENTRY // CHECK: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5 -// CHECK: (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5 +// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv([[TOKEN]]), channel_id=5 +// CHECK: (s32[3,4], token[]) recv-done([[RECV]]), channel_id=5 // ----- @@ -2057,10 +2062,10 @@ func.func @main(%token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK-NEXT: [[ARG:%.*]] = token[] parameter(0) -// CHECK-NEXT: [[RECV:%.*]] = ((), u32[], token[]) recv(token[] [[ARG]]), channel_id=5 -// CHECK-NEXT: [[RECV_DONE:%.*]] = ((), token[]) recv-done(((), u32[], token[]) [[RECV]]), channel_id=5 -// CHECK-NEXT: [[DATA:%.*]] = () get-tuple-element(((), token[]) [[RECV_DONE]]), index=0 -// CHECK-NEXT: ROOT [[TOKEN:%.*]] = token[] get-tuple-element(((), token[]) [[RECV_DONE]]), index=1 +// CHECK-NEXT: [[RECV:%.*]] = ((), u32[], token[]) recv([[ARG]]), channel_id=5 +// CHECK-NEXT: [[RECV_DONE:%.*]] = ((), token[]) recv-done([[RECV]]), channel_id=5 +// CHECK-NEXT: [[DATA:%.*]] = () get-tuple-element([[RECV_DONE]]), index=0 +// CHECK-NEXT: ROOT [[TOKEN:%.*]] = token[] get-tuple-element([[RECV_DONE]]), index=1 // ----- @@ -2077,16 +2082,16 @@ func.func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tens // CHECK: %[[REGION:region_[0-9]+]] // CHECK-SAME: ([[ARG_FA:.*]]: f32[], [[ARG_IA:.*]]: s32[], [[ARG_FB:.*]]: f32[], [[ARG_IB:.*]]: s32[]) -> (f32[], s32[]) -// CHECK: %[[FMAX:.*]] = f32[] maximum(f32[] %[[ARG_FA]], f32[] %[[ARG_FB]]) -// CHECK: %[[IMAX:.*]] = s32[] maximum(s32[] %[[ARG_IA]], s32[] %[[ARG_IB]]) -// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(f32[] %[[FMAX]], s32[] %[[IMAX]]) +// CHECK: %[[FMAX:.*]] = f32[] maximum(%[[ARG_FA]], %[[ARG_FB]]) +// CHECK: %[[IMAX:.*]] = s32[] maximum(%[[ARG_IA]], %[[ARG_IB]]) +// CHECK: ROOT %[[RESULT_REGION:.*]] = (f32[], s32[]) tuple(%[[FMAX]], %[[IMAX]]) // CHECK: ENTRY // CHECK-SAME: ([[ARG0:.*]]: f32[1,10], [[ARG1:.*]]: s32[1,10], [[ARG2:.*]]: f32[], [[ARG3:.*]]: s32[]) -> (f32[1], s32[1]) -// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(f32[1,10] %[[ARG0]], s32[1,10] %[[ARG1]], f32[] %[[ARG2]], s32[] %[[ARG3]]), dimensions={1}, to_apply=%[[REGION]] -// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=0 -// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element((f32[1], s32[1]) %[[RESULT]]), index=1 -// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(f32[1] %[[RESULT0]], s32[1] %[[RESULT1]]) +// CHECK: %[[RESULT:.*]] = (f32[1], s32[1]) reduce(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]), dimensions={1}, to_apply=%[[REGION]] +// CHECK: %[[RESULT0:.*]] = f32[1] get-tuple-element(%[[RESULT]]), index=0 +// CHECK: %[[RESULT1:.*]] = s32[1] get-tuple-element(%[[RESULT]]), index=1 +// CHECK: ROOT %[[RESULT:.*]] = (f32[1], s32[1]) tuple(%[[RESULT0]], %[[RESULT1]]) // ----- @@ -2108,12 +2113,12 @@ func.func @main(%arg0: tensor<2x17x31x7xi32>) -> tensor<2x5x8x7xi32> { } // CHECK: %[[MAX_COMPUTATION:.*]] ([[ARG0:.*]]: s32[], [[ARG1:.*]]: s32[]) -> s32[] -// CHECK: ROOT %[[RESULT:.*]] = s32[] maximum(s32[] %[[ARG0]], s32[] %[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = s32[] maximum(%[[ARG0]], %[[ARG1]]) // CHECK: ENTRY // CHECK-DAG: %[[ARG0:.*]] = s32[2,17,31,7] parameter(0) // CHECK-DAG: %[[INIT:.*]] = s32[] constant(-2147483648) -// CHECK: ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(s32[2,17,31,7] %[[ARG0]], s32[] %constant.2), +// CHECK: ROOT %[[RESULT:.*]] = s32[2,5,8,7] reduce-window(%[[ARG0]], %constant.2), // CHECK-SAME: window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, // CHECK-SAME: to_apply=%[[MAX_COMPUTATION]] @@ -2127,7 +2132,7 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<1x2xf32> { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[1,2] reshape(f32[2] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[1,2] reshape(%[[ARG0]]) // ----- @@ -2141,7 +2146,7 @@ func.func @main(%arg0 : tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10,11,12,13] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(f32[10,11,12,13] %[[ARG0]]), dimensions={1,2} +// CHECK: ROOT %[[RESULT:.*]] = f32[10,11,12,13] reverse(%[[ARG0]]), dimensions={1,2} // ----- @@ -2155,7 +2160,7 @@ func.func @main(%mu: tensor, %sigma: tensor) -> tensor<2x3x5xf32> { // CHECK: ENTRY // CHECK: %[[MU:.*]] = f32[] parameter(0) // CHECK: %[[SIGMA:.*]] = f32[] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[MU]], f32[] %[[SIGMA]]), distribution=rng_normal +// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(%[[MU]], %[[SIGMA]]), distribution=rng_normal // ----- @@ -2171,7 +2176,7 @@ func.func @main() -> tensor<2x3x5xf32> { // CHECK: ENTRY // CHECK-DAG: %[[A:.*]] = f32[] constant(0) // CHECK-DAG: %[[B:.*]] = f32[] constant(1) -// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(f32[] %[[A]], f32[] %[[B]]), distribution=rng_uniform +// CHECK: ROOT %[[RESULT:.*]] = f32[2,3,5] rng(%[[A]], %[[B]]), distribution=rng_uniform // ----- @@ -2200,7 +2205,7 @@ func.func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor // CHECK: [[VAL_2:%.*]] = s32[10,2] parameter(1) // CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2) // CHECK: ROOT -// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[10,2] [[VAL_2]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]] +// CHECK-SAME: f32[200,100,300] scatter([[VAL_1]], [[VAL_2]], [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]] // ----- @@ -2230,7 +2235,7 @@ func.func @main(%input_tensor: tensor<200x100x300xf32>, %scatter_indices: tensor // CHECK: [[VAL_2:%.*]] = s32[100,200,1] parameter(1) // CHECK: [[VAL_3:%.*]] = f32[100,200,300] parameter(2) // CHECK: ROOT -// CHECK-SAME: f32[200,100,300] scatter(f32[200,100,300] [[VAL_1]], s32[100,200,1] [[VAL_2]], f32[100,200,300] [[VAL_3]]), update_window_dims={2}, inserted_window_dims={}, scatter_dims_to_operand_dims={2}, input_batching_dims={0,1}, scatter_indices_batching_dims={1,0}, index_vector_dim=2, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]] +// CHECK-SAME: f32[200,100,300] scatter([[VAL_1]], [[VAL_2]], [[VAL_3]]), update_window_dims={2}, inserted_window_dims={}, scatter_dims_to_operand_dims={2}, input_batching_dims={0,1}, scatter_indices_batching_dims={1,0}, index_vector_dim=2, indices_are_sorted=true, unique_indices=true, to_apply=[[COMPUTATION]] // ----- @@ -2250,7 +2255,7 @@ func.func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi64>, %arg2: // CHECK: [[VAL_1:%.*]] = f32[200,100,300] parameter(0) // CHECK: [[VAL_2:%.*]] = s64[10,2] parameter(1) // CHECK: [[VAL_3:%.*]] = f32[10,300] parameter(2) -// CHECK: (f32[200,100,300], f32[200,100,300]) scatter(f32[200,100,300] [[VAL_1]], f32[200,100,300] [[VAL_1]], s64[10,2] [[VAL_2]], f32[10,300] [[VAL_3]], f32[10,300] [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=[[COMPUTATION]] +// CHECK: (f32[200,100,300], f32[200,100,300]) scatter([[VAL_1]], [[VAL_1]], [[VAL_2]], [[VAL_3]], [[VAL_3]]), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=[[COMPUTATION]] // ----- @@ -2258,11 +2263,11 @@ func.func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi64>, %arg2: // CHECK: HloModule func.func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK: %[[ARG0:.*]] = pred[] parameter(0) - // CHECK: %[[COND:.*]] = pred[2,3] broadcast(pred[] %[[ARG0]]), dimensions={} + // CHECK: %[[COND:.*]] = pred[2,3] broadcast(%[[ARG0]]), dimensions={} // CHECK: %[[ARG1:.*]] = s32[2,3] parameter(1) // CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2) - // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]]) + // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(%[[COND]], %[[ARG1]], %[[ARG2]]) %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> func.return %0 : tensor<2x3xi32> } @@ -2288,10 +2293,10 @@ func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) } // CHECK: %[[SELECT_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] { -// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GE, type=TOTALORDER +// CHECK: ROOT %[[RESULT:.*]] = pred[] compare(%[[ARG0]], %[[ARG1]]), direction=GE, type=TOTALORDER // CHECK: %[[SCATTER_COMPUTATION:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> f32[] { -// CHECK: ROOT %[[RESULT:.*]] = f32[] add(f32[] %[[ARG0]], f32[] %[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[] add(%[[ARG0]], %[[ARG1]]) // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[10,24,24,64] parameter(0) @@ -2299,7 +2304,7 @@ func.func @main(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>) // CHECK: %[[INIT:.*]] = f32[] constant(0) // CHECK: ROOT %[[RESULT:.*]] = f32[10,24,24,64] -// CHECK-SAME: select-and-scatter(f32[10,24,24,64] %[[ARG0]], f32[10,12,12,64] %[[ARG1]], f32[] %[[INIT]]), +// CHECK-SAME: select-and-scatter(%[[ARG0]], %[[ARG1]], %[[INIT]]), // CHECK-SAME: window={size=1x2x2x1 stride=1x2x2x1}, // CHECK-SAME: select=%[[SELECT_COMPUTATION]], scatter=%[[SCATTER_COMPUTATION]] @@ -2320,9 +2325,9 @@ func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) // CHECK: [[TOKEN:%.*]] = token[] parameter(1) -// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer=true +// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send([[ARG]], [[TOKEN]]), channel_id=5, is_host_transfer=true // CHECK: ROOT -// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5, is_host_transfer=true +// CHECK-SAME: token[] send-done([[SEND]]), channel_id=5, is_host_transfer=true // ----- @@ -2341,9 +2346,9 @@ func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) // CHECK: [[TOKEN:%.*]] = token[] parameter(1) -// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5 +// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send([[ARG]], [[TOKEN]]), channel_id=5 // CHECK: ROOT -// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5 +// CHECK-SAME: token[] send-done([[SEND]]), channel_id=5 // ----- @@ -2362,9 +2367,9 @@ func.func @main(%token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK-DAG: [[ARG:%.*]] = () tuple() // CHECK-DAG: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[SEND:%.*]] = ((), u32[], token[]) send(() [[ARG]], token[] [[TOKEN]]), channel_id=5 +// CHECK: [[SEND:%.*]] = ((), u32[], token[]) send([[ARG]], [[TOKEN]]), channel_id=5 // CHECK: ROOT -// CHECK-SAME: token[] send-done(((), u32[], token[]) [[SEND]]), channel_id=5 +// CHECK-SAME: token[] send-done([[SEND]]), channel_id=5 // ----- @@ -2378,7 +2383,7 @@ func.func @main(%arg: tensor<4x4xf32>, %size: tensor) -> tensor<4x4xf32> { // CHECK: [[ARG:%.*]] = f32[4,4] parameter(0) // CHECK: [[SIZE:%.*]] = s32[] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[4,<=4] set-dimension-size(f32[4,4] [[ARG]], s32[] [[SIZE]]), dimensions={1} +// CHECK-SAME: f32[4,<=4] set-dimension-size([[ARG]], [[SIZE]]), dimensions={1} // ----- @@ -2391,7 +2396,7 @@ func.func @main(%arg: tensor<3x4xi32>) -> tensor<1x2xi32> { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) // CHECK: ROOT -// CHECK-SAME: s32[1,2] slice(s32[3,4] [[ARG]]), slice={[1:2:1], [0:4:2]} +// CHECK-SAME: s32[1,2] slice([[ARG]]), slice={[1:2:1], [0:4:2]} // ----- @@ -2406,7 +2411,7 @@ func.func @main(%arg: tensor<3x4xi32>, %start1: tensor, %start2: tensor, %start1: tensor, %start2: tensor) -> tensor<2x1x4x3xi32> { // CHECK: [[ARG:%.*]] = s32[1,2,3,4] parameter(0) - // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] [[ARG]]), dimensions={1,0,3,2} + // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose([[ARG]]), dimensions={1,0,3,2} %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}> : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> func.return %0 : tensor<2x1x4x3xi32> } @@ -2430,7 +2435,7 @@ func.func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf3 // CHECK: [[ARG_A:%.*]] = f32[4,4] parameter(0) // CHECK: [[ARG_B:%.*]] = f32[4,3] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[4,3] triangular-solve(f32[4,4] [[ARG_A]], f32[4,3] [[ARG_B]]), left_side=true, lower=true, unit_diagonal=true, transpose_a=NO_TRANSPOSE +// CHECK-SAME: f32[4,3] triangular-solve([[ARG_A]], [[ARG_B]]), left_side=true, lower=true, unit_diagonal=true, transpose_a=NO_TRANSPOSE // ----- @@ -2443,24 +2448,24 @@ func.func @main(%arg0: tensor, %arg1 : tensor) -> tuple, t // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[] parameter(0) // CHECK: %[[ARG1:.*]] = s32[] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(f32[] %[[ARG0]], s32[] %[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = (f32[], s32[]) tuple(%[[ARG0]], %[[ARG1]]) // ----- // CHECK: HloModule func.func @main(%arg_f32: tensor<4xf32>, %arg_i32: tensor<4xi32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32>) { // CHECK: [[ARG_F32:%.*]] = f32[4] parameter(0) - // CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one(f32[4] [[ARG_F32]]) + // CHECK: [[EXPM1:%.*]] = f32[4] exponential-minus-one([[ARG_F32]]) %expm1 = "mhlo.exponential_minus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one(f32[4] [[ARG_F32]]) + // CHECK: [[LOG1P:%.*]] = f32[4] log-plus-one([[ARG_F32]]) %log1p = "mhlo.log_plus_one"(%arg_f32) : (tensor<4xf32>) -> tensor<4xf32> // CHECK: [[ARG_I32:%.*]] = s32[4] parameter(1) - // CHECK: [[NOT:%.*]] = s32[4] not(s32[4] [[ARG_I32]]) + // CHECK: [[NOT:%.*]] = s32[4] not([[ARG_I32]]) %not = "mhlo.not"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> - // CHECK: [[POPCNT:%.*]] = s32[4] popcnt(s32[4] [[ARG_I32]]) + // CHECK: [[POPCNT:%.*]] = s32[4] popcnt([[ARG_I32]]) %popcnt = "mhlo.popcnt"(%arg_i32) : (tensor<4xi32>) -> tensor<4xi32> func.return %expm1, %log1p, %not, %popcnt : tensor<4xf32>, tensor<4xf32>, tensor<4xi32>, tensor<4xi32> @@ -2473,7 +2478,7 @@ func.func @main(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: [[VAL_1:%.*]] = pred[4] parameter(0) // CHECK: [[VAL_2:%.*]] = pred[4] parameter(1) %0 = mhlo.xor %arg0, %arg1 : tensor<4xi1> - // CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor(pred[4] [[VAL_1]], pred[4] [[VAL_2]]) + // CHECK: ROOT [[VAL_3:%.*]] = pred[4] xor([[VAL_1]], [[VAL_2]]) func.return %0 : tensor<4xi1> } @@ -2490,11 +2495,11 @@ func.func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { } // CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] { -// CHECK: ROOT %compare.{{[0-9+]}} = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT +// CHECK: ROOT %compare.{{[0-9+]}} = pred[] compare(%[[ARG0]], %[[ARG1]]), direction=GT -// CHECK: [[SORT:%.+]] = (f32[16,16], s32[16,16]) sort(f32[16,16] %Arg_0.1, s32[16,16] %Arg_1.2), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] -// CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0 -// CHECK: [[GET1:%.+]] = s32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=1 +// CHECK: [[SORT:%.+]] = (f32[16,16], s32[16,16]) sort(%Arg_0.1, %Arg_1.2), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] +// CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element([[SORT]]), index=0 +// CHECK: [[GET1:%.+]] = s32[16,16] get-tuple-element([[SORT]]), index=1 // ----- @@ -2509,9 +2514,9 @@ func.func @main(%input0: tensor<16x16xf32>) { } // CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[]) -> pred[] { -// CHECK: ROOT %[[CMP:.*]] = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT +// CHECK: ROOT %[[CMP:.*]] = pred[] compare(%[[ARG0]], %[[ARG1]]), direction=GT -// CHECK: %[[RESULT:.*]] = f32[16,16] sort(f32[16,16] %Arg_0.1), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] +// CHECK: %[[RESULT:.*]] = f32[16,16] sort(%Arg_0.1), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] // ----- @@ -2533,7 +2538,7 @@ func.func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] custom-call(f32[16,16] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] custom-call(%[[ARG0]]) // CHECK-SAME: custom_call_target="Sharding" // CHECK-SAME: sharding={devices=[1,2]0,1} @@ -2568,10 +2573,10 @@ func.func @main(%arg0: tensor<2xcomplex>, %arg1: tensor<2xcomplex>) -> // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = c64[2] parameter(0) -// CHECK: %[[ABS0:.*]] = f32[2] abs(c64[2] %[[ARG0]]) +// CHECK: %[[ABS0:.*]] = f32[2] abs(%[[ARG0]]) // CHECK: %[[ARG1:.*]] = c128[2] parameter(1) -// CHECK: %[[ABS1:.*]] = f64[2] abs(c128[2] %[[ARG1]]) -// CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(f32[2] %[[ABS0]], f64[2] %[[ABS1]]) +// CHECK: %[[ABS1:.*]] = f64[2] abs(%[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = (f32[2], f64[2]) tuple(%[[ABS0]], %[[ABS1]]) // ----- @@ -2583,7 +2588,7 @@ func.func @main(%arg0: tensor<4xui8>) -> tensor<4xui8> { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = u8[4] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = u8[4] not(u8[4] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = u8[4] not(%[[ARG0]]) // ----- @@ -2596,7 +2601,7 @@ func.func @main(%arg0: tensor<4xi32>) -> tensor<*xi32> { // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = s32[4] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = s32[4] not(s32[4] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = s32[4] not(%[[ARG0]]) // ----- @@ -2629,11 +2634,11 @@ func.func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> tuple, %token: !mhlo.token) -> !mhlo.token { // CHECK: HloModule func.func @main(%arg: tensor<3xui64>) -> tuple, tensor<2x2xui32>> { // CHECK: %[[ARG0:.*]] = u64[3] parameter(0) -// CHECK: [[RNG:%.*]] = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %[[ARG0]]), algorithm=rng_philox -// CHECK: [[GTE0:%.*]] = u64[3] get-tuple-element((u64[3], u32[2,2]) [[RNG]]), index=0 -// CHECK: [[GTE1:%.*]] = u32[2,2] get-tuple-element((u64[3], u32[2,2]) [[RNG]]), index=1 +// CHECK: [[RNG:%.*]] = (u64[3], u32[2,2]) rng-bit-generator(%[[ARG0]]), algorithm=rng_philox +// CHECK: [[GTE0:%.*]] = u64[3] get-tuple-element([[RNG]]), index=0 +// CHECK: [[GTE1:%.*]] = u32[2,2] get-tuple-element([[RNG]]), index=1 // CHECK: ROOT -// CHECK-SAME: [[RES:%.*]] = (u64[3], u32[2,2]) tuple(u64[3] [[GTE0]], u32[2,2] [[GTE1]]) +// CHECK-SAME: [[RES:%.*]] = (u64[3], u32[2,2]) tuple([[GTE0]], [[GTE1]]) %0:2 = "mhlo.rng_bit_generator"(%arg) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>) %1 = "mhlo.tuple"(%0#0, %0#1) : (tensor<3xui64>, tensor<2x2xui32>) -> tuple, tensor<2x2xui32>> func.return %1 : tuple, tensor<2x2xui32>> @@ -2684,7 +2689,7 @@ func.func @main(%arg: tensor<3xui64>) -> tuple, tensor<2x2xui32>> // CHECK: HloModule func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] cbrt(f32[3,4] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] cbrt(%[[ARG0]]) %0 = "mhlo.cbrt"(%arg) : (tensor<3x4xf32>) -> tensor<3x4xf32> func.return %0 : tensor<3x4xf32> } @@ -2694,7 +2699,7 @@ func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: HloModule func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] reduce-precision(f32[3,4] %[[ARG0]]), exponent_bits=8, mantissa_bits=10 +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] reduce-precision(%[[ARG0]]), exponent_bits=8, mantissa_bits=10 %0 = "mhlo.reduce_precision"(%arg) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32> func.return %0 : tensor<3x4xf32> } @@ -2704,7 +2709,7 @@ func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: HloModule func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> { // CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = f32[3,4,1] bitcast(f32[3,4] %[[ARG0]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4,1] bitcast(%[[ARG0]]) %0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32> func.return %0 : tensor<3x4x1xf32> } @@ -2715,11 +2720,11 @@ func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> { func.func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> (tensor<4x4xf32>, tensor<3x4xf32>) { // CHECK: %[[ARG0:.*]] = f32[4,4] parameter(0) // CHECK: %[[ARG1:.*]] = f32[3,4] parameter(1) -// CHECK: %[[ARGS:.*]] = (f32[4,4], f32[3,4]) tuple(f32[4,4] %[[ARG0]], f32[3,4] %[[ARG1]]), sharding={{\{}}{replicated}, {devices=[1,2]<=[2]}} -// CHECK: %[[OPT:.*]] = (f32[4,4], f32[3,4]) opt-barrier((f32[4,4], f32[3,4]) %[[ARGS]]), sharding={{\{}}{replicated}, {devices=[1,2]<=[2]}} -// CHECK: %[[GTE0:.*]] = f32[4,4] get-tuple-element((f32[4,4], f32[3,4]) %[[OPT]]), index=0, sharding={replicated} -// CHECK: %[[GTE1:.*]] = f32[3,4] get-tuple-element((f32[4,4], f32[3,4]) %[[OPT]]), index=1, sharding={devices=[1,2]<=[2]} -// CHECK: ROOT %[[ROOT:.*]] = (f32[4,4], f32[3,4]) tuple(f32[4,4] %[[GTE0]], f32[3,4] %[[GTE1]]) +// CHECK: %[[ARGS:.*]] = (f32[4,4], f32[3,4]) tuple(%[[ARG0]], %[[ARG1]]), sharding={{\{}}{replicated}, {devices=[1,2]<=[2]}} +// CHECK: %[[OPT:.*]] = (f32[4,4], f32[3,4]) opt-barrier(%[[ARGS]]), sharding={{\{}}{replicated}, {devices=[1,2]<=[2]}} +// CHECK: %[[GTE0:.*]] = f32[4,4] get-tuple-element(%[[OPT]]), index=0, sharding={replicated} +// CHECK: %[[GTE1:.*]] = f32[3,4] get-tuple-element(%[[OPT]]), index=1, sharding={devices=[1,2]<=[2]} +// CHECK: ROOT %[[ROOT:.*]] = (f32[4,4], f32[3,4]) tuple(%[[GTE0]], %[[GTE1]]) %0, %1 = "mhlo.optimization_barrier"(%arg0, %arg1) {mhlo.sharding = "{{replicated}, {devices=[1,2]<=[2]}}"} : (tensor<4x4xf32>, tensor<3x4xf32>) -> (tensor<4x4xf32>, tensor<3x4xf32>) func.return %0, %1 : tensor<4x4xf32>, tensor<3x4xf32> } @@ -2749,7 +2754,7 @@ func.func private @main(%arg0: tensor) -> tensor { func.func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: %[[ARG0:.*]] = f32[4,4] parameter(0) // CHECK: %[[ARG1:.*]] = f32[3,4] parameter(1) -// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] triangular-solve(f32[4,4] %[[ARG0]], f32[3,4] %[[ARG1]]), lower=true, transpose_a=NO_TRANSPOSE +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] triangular-solve(%[[ARG0]], %[[ARG1]]), lower=true, transpose_a=NO_TRANSPOSE %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = false, lower = true, transpose_a = #mhlo, unit_diagonal = false} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> func.return %0: tensor<3x4xf32> } @@ -2760,18 +2765,18 @@ func.func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf3 // CHECK: %[[APPLYFN:.*]] ({{.*}}) -> (f32[], s32[]) { // CHECK: %[[A0:.*]] = f32[] parameter(0) // CHECK: %[[B0:.*]] = f32[] parameter(2) -// CHECK: %[[ADDF32:.*]] = f32[] add(f32[] %[[A0]], f32[] %[[B0]]) +// CHECK: %[[ADDF32:.*]] = f32[] add(%[[A0]], %[[B0]]) // CHECK: %[[A1:.*]] = s32[] parameter(1) // CHECK: %[[B1:.*]] = s32[] parameter(3) -// CHECK: %[[ADDS32:.*]] = s32[] add(s32[] %[[A1]], s32[] %[[B1]]) -// CHECK: ROOT %{{.*}} = (f32[], s32[]) tuple(f32[] %[[ADDF32]], s32[] %[[ADDS32]]) +// CHECK: %[[ADDS32:.*]] = s32[] add(%[[A1]], %[[B1]]) +// CHECK: ROOT %{{.*}} = (f32[], s32[]) tuple(%[[ADDF32]], %[[ADDS32]]) // CHECK: ENTRY // CHECK: %[[ARG0:.*]] = f32[4,2] parameter(0) // CHECK: %[[ARG1:.*]] = s32[4,2] parameter(1) // CHECK: %[[ARG2:.*]] = f32[] parameter(2) // CHECK: %[[ARG3:.*]] = s32[] parameter(3) -// CHECK: (f32[2,2], s32[2,2]) reduce-window(f32[4,2] %[[ARG0]], s32[4,2] %[[ARG1]], f32[] %[[ARG2]], s32[] %[[ARG3]]) +// CHECK: (f32[2,2], s32[2,2]) reduce-window(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) // CHECK-SAME: window={size=5x1 stride=3x1 pad=2_2x0_0} // CHECK-SAME: to_apply=%[[APPLYFN]] func.func @main(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor, %init1: tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) { @@ -2793,7 +2798,7 @@ func.func @main(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor) -> tensor<2xf32> { // CHECK: %[[ARG0:.*]] = f32[2] parameter(0) %0 = "mhlo.round_nearest_even"(%arg0) {} : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: round-nearest-even(f32[2] %[[ARG0]]) + // CHECK: round-nearest-even(%[[ARG0]]) func.return %0 : tensor<2xf32> } @@ -2803,7 +2808,7 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[ARG0:.*]] = f32[2] parameter(0) %0 = "mhlo.tan"(%arg0) {} : (tensor<2xf32>) -> tensor<2xf32> - // CHECK: tan(f32[2] %[[ARG0]]) + // CHECK: tan(%[[ARG0]]) func.return %0 : tensor<2xf32> } @@ -2813,7 +2818,7 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { func.func @main(%arg0: tensor<4x4xf32>) -> (tensor<4x2xf32>, tensor<4x2xi32>) { // CHECK: %[[ARG0:.*]] = f32[4,4] parameter(0) %0:2 = "mhlo.topk"(%arg0) {k = 2, largest = true} : (tensor<4x4xf32>) -> (tensor<4x2xf32>, tensor<4x2xi32>) - // CHECK: (f32[4,2], s32[4,2]) topk(f32[4,4] %[[ARG0]]), k=2, largest=true + // CHECK: (f32[4,2], s32[4,2]) topk(%[[ARG0]]), k=2, largest=true func.return %0#0, %0#1 : tensor<4x2xf32>, tensor<4x2xi32> } @@ -2855,7 +2860,7 @@ func.func @main(%arg0: tuple, tensor<2x3xf32>>, %arg1: tensor<5x func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0) // CHECK: %[[TOK:.*]] = token[] after-all() -// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] add-dependency(f32[3,4] %[[ARG0]], token[] %[[TOK]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] add-dependency(%[[ARG0]], %[[TOK]]) %token = "mhlo.after_all"() : () -> !mhlo.token %0 = "mhlo.add_dependency"(%arg, %token) : (tensor<3x4xf32>, !mhlo.token) -> tensor<3x4xf32> func.return %0 : tensor<3x4xf32> @@ -2876,7 +2881,7 @@ func.func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> attributes {execution_ // CHECK: HloModule func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[ARG0:.*]] = s32[2,2] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = s32[2,2] all-to-all(s32[2,2] %[[ARG0]]), channel_id=1, replica_groups={{.}}{1,2},{0,3}}, dimensions={1} +// CHECK: ROOT %[[RESULT:.*]] = s32[2,2] all-to-all(%[[ARG0]]), channel_id=1, replica_groups={{.}}{1,2},{0,3}}, dimensions={1} %0 = "mhlo.all_to_all"(%arg0) { concat_dimension = 1 : i64, replica_groups = dense<[[1, 2], [0, 3]]> : tensor<2x2xi64>, @@ -2891,7 +2896,7 @@ func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { func.func private @main(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32>) -> tuple, tensor<128x4xf32>> { // CHECK: %[[ARG0:.*]] = f32[128,4] parameter(0) // CHECK: %[[ARG1:.*]] = f32[128,4] parameter(1) -// CHECK: (f32[128,4], f32[128,4]) all-to-all(f32[128,4] %[[ARG0]], f32[128,4] %[[ARG1]]), channel_id=1, replica_groups={{.}}{0,1}} +// CHECK: (f32[128,4], f32[128,4]) all-to-all(%[[ARG0]], %[[ARG1]]), channel_id=1, replica_groups={{.}}{0,1}} %0:2 = "mhlo.all_to_all"(%arg0, %arg1) { replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, channel_handle = #mhlo.channel_handle @@ -2922,7 +2927,7 @@ func.func @main(%arg0: tensor<2x3xf32>, %arg1: tensor<5x5xf32>) -> tensor<1x2x3x // CHECK: [[VAL_1:%.*]] = f32[2,3] parameter(0) // CHECK: [[VAL_2:%.*]] = f32[5,5] parameter(1) // CHECK: ROOT -// CHECK-SAME: f32[1,2,3] custom-call(f32[2,3] [[VAL_1]], f32[5,5] [[VAL_2]]) +// CHECK-SAME: f32[1,2,3] custom-call([[VAL_1]], [[VAL_2]]) // CHECK-SAME: custom_call_target="foo" // CHECK-SAME: custom_call_has_side_effect=true // CHECK-SAME: api_version=API_VERSION_TYPED_FFI @@ -2948,8 +2953,8 @@ func.func @main(%operand: tensor) -> tensor { // CHECK: HloModule {{.*}}, entry_computation_layout={(f32[?,784]{1,0})->f32[?,784]{1,0}} // CHECK-EMPTY: // CHECK-NEXT: ENTRY {{.*}} ([[ARG0:.*]]: f32[?,784]) -> f32[?,784] { -// CHECK-NEXT: [[ARG0]] = f32[?,784] parameter(0) -// CHECK-NEXT: ROOT {{.*}} = f32[?,784] abs(f32[?,784] %Arg_0.1), {{.*}} +// CHECK-NEXT: %[[ARG0]] = f32[?,784] parameter(0) +// CHECK-NEXT: ROOT {{.*}} = f32[?,784] abs(%[[ARG0]]), {{.*}} // CHECK-NEXT: } // ----- diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir index add453c9a276df..70efb9c6b6d28a 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir @@ -20,12 +20,12 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> { // CHECK: ENTRY // CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) -// CHECK: %[[OUTPUT:.*]] = f32[128,128] all-gather-start(f32[128,32] %[[INPUT]]) +// CHECK: %[[OUTPUT:.*]] = f32[128,128] all-gather-start(%[[INPUT]]) // CHECK-SAME: channel_id=1 // CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} // CHECK-SAME: dimensions={1} // CHECK-SAME: use_global_device_ids=true -// CHECK: ROOT {{.*}} f32[128,128] all-gather-done(f32[128,128] %[[OUTPUT]] +// CHECK: ROOT {{.*}} f32[128,128] all-gather-done(%[[OUTPUT]] // ----- @@ -55,11 +55,11 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { // CHECK: ENTRY // CHECK: %[[INPUT:.*]] = f32[10] parameter(0) -// CHECK: %[[OUTPUT:.*]] = f32[10] all-reduce-start(f32[10] %[[INPUT]]) +// CHECK: %[[OUTPUT:.*]] = f32[10] all-reduce-start(%[[INPUT]]) // CHECK-SAME: channel_id=5 // CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} // CHECK-SAME: use_global_device_ids=true -// CHECK: ROOT {{.*}} f32[10] all-reduce-done(f32[10] %[[OUTPUT]] +// CHECK: ROOT {{.*}} f32[10] all-reduce-done(%[[OUTPUT]] // ----- @@ -125,10 +125,10 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // CHECK: ENTRY // CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) -// CHECK: %[[OUTPUT:.*]] = f32[128,32] collective-permute-start(f32[128,32] %[[INPUT]]) +// CHECK: %[[OUTPUT:.*]] = f32[128,32] collective-permute-start(%[[INPUT]]) // CHECK-SAME: channel_id=1 // CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}} -// CHECK: ROOT {{.*}} f32[128,32] collective-permute-done(f32[128,32] %[[OUTPUT]] +// CHECK: ROOT {{.*}} f32[128,32] collective-permute-done(%[[OUTPUT]] // ----- @@ -146,9 +146,9 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // CHECK: ENTRY // CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) -// CHECK: %[[OUTPUT:.*]] = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %[[INPUT]]) +// CHECK: %[[OUTPUT:.*]] = (f32[128,32], f32[128,32], u32[]) copy-start(%[[INPUT]]) // CHECK-SAME: cross_program_prefetch_index=0 -// CHECK: ROOT {{.*}} f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %[[OUTPUT]] +// CHECK: ROOT {{.*}} f32[128,32] copy-done(%[[OUTPUT]] // ----- @@ -173,8 +173,8 @@ func.func @main(%token: !mhlo.token) -> (!mhlo.token) { // CHECK: ENTRY // CHECK: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[RECV:%.*]] = ((), u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5 -// CHECK: ((), token[]) recv-done(((), u32[], token[]) [[RECV]]), channel_id=5 +// CHECK: [[RECV:%.*]] = ((), u32[], token[]) recv([[TOKEN]]), channel_id=5 +// CHECK: ((), token[]) recv-done([[RECV]]), channel_id=5 // ----- @@ -198,17 +198,17 @@ func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) { // CHECK: ENTRY // CHECK: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5, is_host_transfer +// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv([[TOKEN]]), channel_id=5, is_host_transfer // CHECK-SAME: sharding={ // CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0} // CHECK-SAME: } -// CHECK: [[RECV_DONE:%.*]] = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5, is_host_transfer +// CHECK: [[RECV_DONE:%.*]] = (s32[3,4], token[]) recv-done([[RECV]]), channel_id=5, is_host_transfer // CHECK-SAME: sharding={ // CHECK-SAME: {maximal device=0}, {maximal device=0} // CHECK-SAME: } -// CHECK: [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0, sharding={maximal device=0} -// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1, sharding={maximal device=0} -// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]]) +// CHECK: [[TUPLE0:%.*]] = s32[3,4] get-tuple-element([[RECV_DONE]]), index=0, sharding={maximal device=0} +// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element([[RECV_DONE]]), index=1, sharding={maximal device=0} +// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple([[TUPLE0]], [[TUPLE1]]) // ----- @@ -233,9 +233,9 @@ func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) // CHECK: [[TOKEN:%.*]] = token[] parameter(1) -// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer +// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send([[ARG]], [[TOKEN]]), channel_id=5, is_host_transfer // CHECK: ROOT -// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5 +// CHECK-SAME: token[] send-done([[SEND]]), channel_id=5 // ----- @@ -258,9 +258,9 @@ func.func @main(%token: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[SEND:%.*]] = ((), u32[], token[]) send(() [[UNIT:%.*]], token[] [[TOKEN]]), channel_id=5 +// CHECK: [[SEND:%.*]] = ((), u32[], token[]) send([[UNIT:%.*]], [[TOKEN]]), channel_id=5 // CHECK: ROOT -// CHECK-SAME: token[] send-done(((), u32[], token[]) [[SEND]]), channel_id=5 +// CHECK-SAME: token[] send-done([[SEND]]), channel_id=5 // ----- @@ -275,12 +275,12 @@ func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32> // CHECK: ENTRY func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) - // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]) + // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(%[[ARG0]]) // CHECK-SAME: calls=[[CALLED_COMPUTATION]] %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> - // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]]) + // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(%[[START]]) %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> - // CHECK: ROOT %{{.*}} = (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]]) + // CHECK: ROOT %{{.*}} = (f32[20]) async-done(%[[UPDATE]]) %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> tensor<20xf32> return %2 : tensor<20xf32> } @@ -300,10 +300,10 @@ func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32> // CHECK: ENTRY func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) - // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]], - // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]]) + // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(%[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]], + // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(%[[START]]) // CHECK: ROOT - // CHECK-SAME: (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]]) + // CHECK-SAME: (f32[20]) async-done(%[[UPDATE]]) %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread"} : (tensor<10xf32>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> @@ -321,12 +321,12 @@ func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { // CHECK: ENTRY func.func @main() -> tensor<1x2xf32> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "", outputs = "_retval0"}} { // CHECK: %[[AFTER_ALL:.*]] = token[] after-all() - // CHECK-NEXT: %[[RECV:.*]] = (f32[1,2], u32[], token[]) recv(token[] %[[AFTER_ALL]]), channel_id=2, is_host_transfer=true, + // CHECK-NEXT: %[[RECV:.*]] = (f32[1,2], u32[], token[]) recv(%[[AFTER_ALL]]), channel_id=2, is_host_transfer=true, // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} - // CHECK-NEXT: %[[RECV_DONE:.*]] = (f32[1,2], token[]) recv-done((f32[1,2], u32[], token[]) %[[RECV]]), channel_id=2, is_host_transfer=true, + // CHECK-NEXT: %[[RECV_DONE:.*]] = (f32[1,2], token[]) recv-done(%[[RECV]]), channel_id=2, is_host_transfer=true, // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} - // CHECK-NEXT: ROOT %[[GET_TUPLE_0:.*]] = f32[1,2] get-tuple-element((f32[1,2], token[]) %[[RECV_DONE]]), index=0, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} - // CHECK-NEXT: %[[GET_TUPLE_1:.*]] = token[] get-tuple-element((f32[1,2], token[]) %[[RECV_DONE]]), index=1, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} + // CHECK-NEXT: ROOT %[[GET_TUPLE_0:.*]] = f32[1,2] get-tuple-element(%[[RECV_DONE]]), index=0, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} + // CHECK-NEXT: %[[GET_TUPLE_1:.*]] = token[] get-tuple-element(%[[RECV_DONE]]), index=1, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_1_retvals_htod_0"} %0 = mhlo.create_token : !mhlo.token %1:2 = "mhlo.recv"(%0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "host_compute_channel_1_retvals_htod_0"}, mhlo.sharding = "\08\04"} : (!mhlo.token) -> (tensor<1x2xf32>, !mhlo.token) return %1#0 : tensor<1x2xf32> @@ -346,15 +346,15 @@ func.func @main() -> tensor<1x2xf32> attributes {allow_soft_placement = false, t func.func @main(%arg0: tensor<1x2xi64>) -> tensor<1x2xi64> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "_arg0", outputs = "_retval0"}} { // CHECK: %[[ARG0:.*]] = s64[1,2] parameter(0) // CHECK-NEXT: %[[AFTER_ALL:.*]] = token[] after-all() - // CHECK-NEXT: %[[SEND:.*]] = (s64[1,2], u32[], token[]) send(s64[1,2] %[[ARG0]], token[] %[[AFTER_ALL]]), channel_id=3, is_host_transfer=true, + // CHECK-NEXT: %[[SEND:.*]] = (s64[1,2], u32[], token[]) send(%[[ARG0]], %[[AFTER_ALL]]), channel_id=3, is_host_transfer=true, // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_args_dtoh_0"} - // CHECK-NEXT: %[[SEND_DONE:.*]] = token[] send-done((s64[1,2], u32[], token[]) %[[SEND]]), channel_id=3, is_host_transfer=true, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_args_dtoh_0"} - // CHECK-NEXT: %[[RECV:.*]] = (s64[1,2], u32[], token[]) recv(token[] %[[SEND_DONE]]), channel_id=4, is_host_transfer=true, + // CHECK-NEXT: %[[SEND_DONE:.*]] = token[] send-done(%[[SEND]]), channel_id=3, is_host_transfer=true, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_args_dtoh_0"} + // CHECK-NEXT: %[[RECV:.*]] = (s64[1,2], u32[], token[]) recv(%[[SEND_DONE]]), channel_id=4, is_host_transfer=true, // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} - // CHECK-NEXT: %[[RECV_DONE:.*]] = (s64[1,2], token[]) recv-done((s64[1,2], u32[], token[]) %[[RECV]]), channel_id=4, is_host_transfer=true, + // CHECK-NEXT: %[[RECV_DONE:.*]] = (s64[1,2], token[]) recv-done(%[[RECV]]), channel_id=4, is_host_transfer=true, // CHECK-SAME{LITERAL}: sharding={{manual}, {manual}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} - // CHECK-NEXT: ROOT %[[GET_TUPLE_0:.*]] = s64[1,2] get-tuple-element((s64[1,2], token[]) %[[RECV_DONE]]), index=0, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} - // CHECK-NEXT: %[[GET_TUPLE_1:.*]] = token[] get-tuple-element((s64[1,2], token[]) %[[RECV_DONE]]), index=1, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} + // CHECK-NEXT: ROOT %[[GET_TUPLE_0:.*]] = s64[1,2] get-tuple-element(%[[RECV_DONE]]), index=0, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} + // CHECK-NEXT: %[[GET_TUPLE_1:.*]] = token[] get-tuple-element(%[[RECV_DONE]]), index=1, sharding={manual}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="host_compute_channel_0_retvals_htod_0"} %0 = mhlo.create_token : !mhlo.token %1 = "mhlo.send"(%arg0, %0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "host_compute_channel_0_args_dtoh_0"}, mhlo.sharding = "\08\04"} : (tensor<1x2xi64>, !mhlo.token) -> !mhlo.token %2:2 = "mhlo.recv"(%1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "host_compute_channel_0_retvals_htod_0"}, mhlo.sharding = "\08\04"} : (!mhlo.token) -> (tensor<1x2xi64>, !mhlo.token) diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir index 811bc4a44f8284..5f44d0c3c4f292 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export_bounded_dynamism.mlir @@ -30,6 +30,7 @@ func.func @main(%arg0: tensor<1x?x512xf32, #mhlo.type_extensions>) -> tensor<1x?xf32, #mhlo.type_extensions> { %0 = mhlo.reshape %arg0 : (tensor>) -> tensor<1x?xf32, #mhlo.type_extensions> - // CHECK: f32[1,<=5] reshape(f32[<=5] + // CHECK: %[[ARG0:.+]] = f32[<=5] parameter(0) + // CHECK: f32[1,<=5] reshape(%[[ARG0]]) return %0 : tensor<1x?xf32, #mhlo.type_extensions> } diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir index ced6fb5e257c53..a082394aefecda 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir @@ -47,7 +47,7 @@ module @entry attributes { // CHECK: %Arg_3.4 = s32[] parameter(3) // CHECK: %Arg_0.1 = f32[2,3,4]{2,1,0} parameter(0) // CHECK: %Arg_1.2 = f32[2,3,4]{2,1,0} parameter(1) -// CHECK: ROOT %add.5 = f32[2,3,4]{2,1,0} add(f32[2,3,4]{2,1,0} %Arg_0.1, f32[2,3,4]{2,1,0} %Arg_1.2) +// CHECK: ROOT %add.5 = f32[2,3,4]{2,1,0} add(%Arg_0.1, %Arg_1.2) // CHECK: } // ----- diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/export_replicas.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/export_replicas.mlir index 5b11ed4a4c323f..f5cc357961dbc0 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/export_replicas.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/export_replicas.mlir @@ -12,4 +12,4 @@ func.func @main(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32> {mhlo.is_same // CHECK: %[[ARG0:.*]] = f32[16,16] parameter(0) // CHECK-NOT: parameter_replication={true} // CHECK: %[[ARG1:.*]] = f32[16,16] parameter(1), parameter_replication={true} -// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(f32[16,16] %[[ARG0]], f32[16,16] %[[ARG1]]) +// CHECK: ROOT %[[RESULT:.*]] = f32[16,16] add(%[[ARG0]], %[[ARG1]]) diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/function.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/function.mlir index 29ac479e024ef3..06c9c8dc82895d 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/function.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/function.mlir @@ -8,9 +8,9 @@ module @non_entry_function_shardings { // CHECK: %called_computation.{{[0-9]+}} (Arg_0.{{[0-9]+}}: s32[8,2]) -> s32[8,2] { // CHECK-NEXT: %[[ARG:.*]] = s32[8,2] parameter(0), sharding={devices=[2,2]<=[4]} - // CHECK-NEXT: %[[MULT:.*]] = s32[8,2] multiply(s32[8,2] %[[ARG]], s32[8,2] %[[ARG]]) - // CHECK-NEXT: %[[TUPLE:.*]] = (s32[8,2]) tuple(s32[8,2] %[[MULT]]) - // CHECK-NEXT: ROOT %get-tuple-element.{{[0-9]+}} = s32[8,2] get-tuple-element((s32[8,2]) %[[TUPLE]]), index=0, sharding={devices=[2,2]<=[4]} + // CHECK-NEXT: %[[MULT:.*]] = s32[8,2] multiply(%[[ARG]], %[[ARG]]) + // CHECK-NEXT: %[[TUPLE:.*]] = (s32[8,2]) tuple(%[[MULT]]) + // CHECK-NEXT: ROOT %get-tuple-element.{{[0-9]+}} = s32[8,2] get-tuple-element(%[[TUPLE]]), index=0, sharding={devices=[2,2]<=[4]} func.func private @called_computation(%arg0: tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"}) -> (tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"}) { %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32> return %0 : tensor<8x2xi32> diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/fusion.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/fusion.mlir index 62e5ab5664f28b..904ad43e490db3 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/fusion.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/fusion.mlir @@ -7,13 +7,13 @@ // CHECK: ENTRY // CHECK: %[[PARAM0:.*]] = f32[] parameter(0) // CHECK: %[[PARAM1:.*]] = f32[] parameter(1) -// CHECK: %[[FUSION0:.*]] = f32[] fusion(f32[] %[[PARAM0]], f32[] %[[PARAM1]]), kind=kLoop, calls=%[[REGION0]] -// CHECK: %[[FUSION1:.*]] = (f32[], f32[]) fusion(f32[] %[[PARAM0]], f32[] %[[PARAM1]]), kind=kLoop, calls=%[[REGION1]] -// CHECK: f32[] get-tuple-element((f32[], f32[]) %[[FUSION1]]), index=0 -// CHECK: f32[] get-tuple-element((f32[], f32[]) %[[FUSION1]]), index=1 -// CHECK: %[[FUSION2:.*]] = (f32[], f32[]) fusion(f32[] %[[PARAM0]]), kind=kLoop, calls=%[[REGION2]] -// CHECK: f32[] get-tuple-element((f32[], f32[]) %[[FUSION2]]), index=0 -// CHECK: f32[] get-tuple-element((f32[], f32[]) %[[FUSION2]]), index=1 +// CHECK: %[[FUSION0:.*]] = f32[] fusion(%[[PARAM0]], %[[PARAM1]]), kind=kLoop, calls=%[[REGION0]] +// CHECK: %[[FUSION1:.*]] = (f32[], f32[]) fusion(%[[PARAM0]], %[[PARAM1]]), kind=kLoop, calls=%[[REGION1]] +// CHECK: f32[] get-tuple-element(%[[FUSION1]]), index=0 +// CHECK: f32[] get-tuple-element(%[[FUSION1]]), index=1 +// CHECK: %[[FUSION2:.*]] = (f32[], f32[]) fusion(%[[PARAM0]]), kind=kLoop, calls=%[[REGION2]] +// CHECK: f32[] get-tuple-element(%[[FUSION2]]), index=0 +// CHECK: f32[] get-tuple-element(%[[FUSION2]]), index=1 // CHECK: } func.func @main(%arg0: tensor, %arg1: tensor) { %result = "mhlo.fusion"(%arg0, %arg1) ({ diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir index 4c414422124736..836c257c7a0ce3 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir @@ -13,10 +13,10 @@ func.func @main(%arg0: tensor) -> tuple> { // CHECK: %[[VAL0:.+]] = f32[] constant(10) %cst = arith.constant dense<1.000000e+01> : tensor - // CHECK: %[[VAL1:.+]] = pred[] compare(f32[] %[[A0]], f32[] %[[VAL0]]), direction=LT + // CHECK: %[[VAL1:.+]] = pred[] compare(%[[A0]], %[[VAL0]]), direction=LT %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: %[[VAL2:.+]] = f32[] conditional(pred[] %[[VAL1]], f32[] %[[A0]], f32[] %[[A0]]), true_computation=[[R0]], false_computation=[[R1]] + // CHECK: %[[VAL2:.+]] = f32[] conditional(%[[VAL1]], %[[A0]], %[[A0]]), true_computation=[[R0]], false_computation=[[R1]] %2 = "mhlo.if"(%0) ({ %6 = "mhlo.log"(%arg0) : (tensor) -> tensor "mhlo.return"(%6) : (tensor) -> () @@ -25,7 +25,7 @@ func.func @main(%arg0: tensor) -> tuple> { "mhlo.return"(%6) : (tensor) -> () }) : (tensor) -> tensor - // CHECK: ROOT %[[VAL3:.+]] = (f32[]) tuple(f32[] %[[VAL2]]) + // CHECK: ROOT %[[VAL3:.+]] = (f32[]) tuple(%[[VAL2]]) %3 = "mhlo.tuple"(%2) : (tensor) -> tuple> func.return %3 : tuple> } @@ -65,9 +65,9 @@ func.func @main(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: ENTRY // CHECK-DAG: %[[A0:.+]] = f32[] parameter(0) // CHECK-DAG: %[[A1:.+]] = f32[] parameter(1) -// CHECK-DAG: %[[TUPLE1:.+]] = (f32[], f32[]) tuple(f32[] %[[A0]], f32[] %[[A1]]) -// CHECK-DAG: %[[TUPLE2:.+]] = (f32[], f32[]) tuple(f32[] %[[A0]], f32[] %[[A1]]) -// CHECK-DAG: %[[COND:.+]] = (f32[], f32[]) conditional(pred[] %[[PRED:.+]], (f32[], f32[]) %[[TUPLE1]], (f32[], f32[]) %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] +// CHECK-DAG: %[[TUPLE1:.+]] = (f32[], f32[]) tuple(%[[A0]], %[[A1]]) +// CHECK-DAG: %[[TUPLE2:.+]] = (f32[], f32[]) tuple(%[[A0]], %[[A1]]) +// CHECK-DAG: %[[COND:.+]] = (f32[], f32[]) conditional(%[[PRED:.+]], %[[TUPLE1]], %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] // ----- // Test export mhlo::IfOp with multiple args, but different numbers of args for @@ -105,9 +105,9 @@ func.func @main(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-DAG: %[[A0:.+]] = f32[] parameter(0) // CHECK-DAG: %[[CST:.+]] = f32[] constant(10) // CHECK-DAG: %[[A1:.+]] = f32[] parameter(1) -// CHECK-DAG: %[[TUPLE1:.+]] = (f32[], f32[], f32[]) tuple(f32[] %[[CST]], f32[] %[[A1]], f32[] %[[A0]]) -// CHECK-DAG: %[[TUPLE2:.+]] = (f32[], f32[]) tuple(f32[] %[[A0]], f32[] %[[A1]]) -// CHECK: %[[COND:.+]] = (f32[], f32[]) conditional(pred[] %[[PRED:.+]], (f32[], f32[], f32[]) %[[TUPLE1]], (f32[], f32[]) %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] +// CHECK-DAG: %[[TUPLE1:.+]] = (f32[], f32[], f32[]) tuple(%[[CST]], %[[A1]], %[[A0]]) +// CHECK-DAG: %[[TUPLE2:.+]] = (f32[], f32[]) tuple(%[[A0]], %[[A1]]) +// CHECK: %[[COND:.+]] = (f32[], f32[]) conditional(%[[PRED:.+]], %[[TUPLE1]], %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] // ----- // Test export mhlo::IfOp with false branch having no implict captures. @@ -145,9 +145,9 @@ func.func @main(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-DAG: %[[A0:.+]] = f32[] parameter(0) // CHECK-DAG: %[[CST:.+]] = f32[] constant(10) // CHECK-DAG: %[[A1:.+]] = f32[] parameter(1) -// CHECK-DAG: %[[TUPLE1:.+]] = (f32[], f32[], f32[]) tuple(f32[] %[[CST]], f32[] %[[A1]], f32[] %[[A0]]) +// CHECK-DAG: %[[TUPLE1:.+]] = (f32[], f32[], f32[]) tuple(%[[CST]], %[[A1]], %[[A0]]) // CHECK-DAG: %[[TUPLE2:.+]] = () tuple() -// CHECK: %[[COND:.+]] = (f32[], f32[]) conditional(pred[] %[[PRED:.+]], (f32[], f32[], f32[]) %[[TUPLE1]], () %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] +// CHECK: %[[COND:.+]] = (f32[], f32[]) conditional(%[[PRED:.+]], %[[TUPLE1]], %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] // ----- // Test export mhlo::IfOp with true branch having no implict captures. @@ -186,8 +186,8 @@ func.func @main(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-DAG: %[[CST:.+]] = f32[] constant(10) // CHECK-DAG: %[[A1:.+]] = f32[] parameter(1) // CHECK-DAG: %[[TUPLE1:.+]] = () tuple() -// CHECK-DAG: %[[TUPLE2:.+]] = (f32[], f32[], f32[]) tuple(f32[] %[[CST]], f32[] %[[A1]], f32[] %[[A0]]) -// CHECK: %[[COND:.+]] = (f32[], f32[]) conditional(pred[] %[[PRED:.+]], () %[[TUPLE1]], (f32[], f32[], f32[]) %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] +// CHECK-DAG: %[[TUPLE2:.+]] = (f32[], f32[], f32[]) tuple(%[[CST]], %[[A1]], %[[A0]]) +// CHECK: %[[COND:.+]] = (f32[], f32[]) conditional(%[[PRED:.+]], %[[TUPLE1]], %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] // ----- // Test export mhlo::IfOp with both branches having no implict captures. @@ -223,7 +223,7 @@ func.func @main(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: ENTRY // CHECK: %[[TUPLE1:.+]] = () tuple() // CHECK: %[[TUPLE2:.+]] = () tuple() -// CHECK: %[[COND:.+]] = (f32[], f32[]) conditional(pred[] %[[PRED:.+]], () %[[TUPLE1]], () %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] +// CHECK: %[[COND:.+]] = (f32[], f32[]) conditional(%[[PRED:.+]], %[[TUPLE1]], %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] // ----- // Test export nested mhlo::IfOp. @@ -278,7 +278,7 @@ func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> te // CHECK-NEXT: %[[A2_EMPTY_TUPLE]] = () parameter(0) // CHECK-DAG: %[[CST2:.+]] = f32[] constant(10) // CHECK-DAG: %[[TUPLE2:.+]] = () tuple() -// CHECK: %[[COND2:.+]] = f32[] conditional(pred[] %{{.+}}, f32[] %[[CST2]], () %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] +// CHECK: %[[COND2:.+]] = f32[] conditional(%{{.+}}, %[[CST2]], %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] // CHECK: ROOT %tuple.{{[0-9]+}} = (f32[], f32[]) tuple // CHECK-NEXT: } @@ -293,5 +293,5 @@ func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> te // CHECK-DAG: %[[A1:.+]] = f32[] parameter(1) // CHECK-DAG: %[[A2:.+]] = f32[] parameter(2) // CHECK-DAG: %[[TUPLE1:.+]] = () tuple() -// CHECK-DAG: %[[TUPLE2:.+]] = (f32[], f32[]) tuple(f32[] %[[A1]], f32[] %[[A2]]) -// CHECK: (f32[], f32[]) conditional(pred[] %[[A0]], () %[[TUPLE1]], (f32[], f32[]) %[[TUPLE2]]), true_computation=[[R2]], false_computation=[[R3]] +// CHECK-DAG: %[[TUPLE2:.+]] = (f32[], f32[]) tuple(%[[A1]], %[[A2]]) +// CHECK: (f32[], f32[]) conditional(%[[A0]], %[[TUPLE1]], %[[TUPLE2]]), true_computation=[[R2]], false_computation=[[R3]] diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/int4.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/int4.mlir index 615fe126d6819c..15a6c64ba133d7 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/int4.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/int4.mlir @@ -6,9 +6,9 @@ func.func @main() -> tensor<6xi4> { // CHECK-NEXT: %[[CONSTANT:.*]] = s4[6] constant({1, -2, -3, 4, -8, 7}) %0 = mhlo.constant dense<[1, -2, -3, 4, -8, 7]> : tensor<6xi4> - // CHECK-NEXT: %[[CONVERT1:.*]] = s8[6] convert(s4[6] %[[CONSTANT]]) + // CHECK-NEXT: %[[CONVERT1:.*]] = s8[6] convert(%[[CONSTANT]]) %1 = "mhlo.convert"(%0) : (tensor<6xi4>) -> tensor<6xi8> - // CHECK-NEXT: ROOT %[[CONVERT2:.*]] = s4[6] convert(s8[6] %[[CONVERT1]]) + // CHECK-NEXT: ROOT %[[CONVERT2:.*]] = s4[6] convert(%[[CONVERT1]]) %2 = "mhlo.convert"(%1) : (tensor<6xi8>) -> tensor<6xi4> func.return %2 : tensor<6xi4> } @@ -19,9 +19,9 @@ func.func @main() -> tensor<6xi4> { func.func @main() -> tensor<4xui4> { // CHECK-NEXT: %[[CONSTANT:.*]] = u4[4] constant({1, 2, 3, 15}) %0 = mhlo.constant dense<[1, 2, 3, 15]> : tensor<4xui4> - // CHECK-NEXT: %[[CONVERT1:.*]] = u8[4] convert(u4[4] %[[CONSTANT]]) + // CHECK-NEXT: %[[CONVERT1:.*]] = u8[4] convert(%[[CONSTANT]]) %1 = "mhlo.convert"(%0) : (tensor<4xui4>) -> tensor<4xui8> - // CHECK-NEXT: ROOT %[[CONVERT2:.*]] = u4[4] convert(u8[4] %[[CONVERT1]]) + // CHECK-NEXT: ROOT %[[CONVERT2:.*]] = u4[4] convert(%[[CONVERT1]]) %2 = "mhlo.convert"(%1) : (tensor<4xui8>) -> tensor<4xui4> func.return %2 : tensor<4xui4> } diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/layouts_and_names.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/layouts_and_names.mlir index 6d908d5d6a0eaa..26cba7ca71c571 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/layouts_and_names.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/layouts_and_names.mlir @@ -46,11 +46,11 @@ func.func @main(%arg0: !mhlo.token) -> tuple, tensor>, // CHECK: ENTRY // CHECK: [[ARG:%.*]] = token[] parameter(0) -// CHECK: [[INFEED:%.*]] = ((s32[3,3]{0,1}, pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" -// CHECK: [[GTE1:%.*]] = (s32[3,3]{0,1}, pred[]) get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) [[INFEED]]), index=0 -// CHECK: [[GTE2:%.*]] = s32[3,3]{0,1} get-tuple-element((s32[3,3]{0,1}, pred[]) [[GTE1]]), index=0 -// CHECK: [[GTE3:%.*]] = pred[] get-tuple-element((s32[3,3]{0,1}, pred[]) [[GTE1]]), index=1 -// CHECK: [[GTE4:%.*]] = token[] get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) [[INFEED]]), index=1 +// CHECK: [[INFEED:%.*]] = ((s32[3,3]{0,1}, pred[]), token[]) infeed([[ARG]]), infeed_config="foobar" +// CHECK: [[GTE1:%.*]] = (s32[3,3]{0,1}, pred[]) get-tuple-element([[INFEED]]), index=0 +// CHECK: [[GTE2:%.*]] = s32[3,3]{0,1} get-tuple-element([[GTE1]]), index=0 +// CHECK: [[GTE3:%.*]] = pred[] get-tuple-element([[GTE1]]), index=1 +// CHECK: [[GTE4:%.*]] = token[] get-tuple-element([[INFEED]]), index=1 // ----- @@ -64,11 +64,11 @@ func.func @main(%arg0: !mhlo.token) -> tuple, !mhlo.token> { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = token[] parameter(0) -// CHECK: [[INFEED:%.*]] = ((s32[3,3]{0,1}), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" -// CHECK: [[GTE0:%.*]] = (s32[3,3]{0,1}) get-tuple-element(((s32[3,3]{0,1}), token[]) [[INFEED]]), index=0 -// CHECK: [[GTE1:%.*]] = s32[3,3]{0,1} get-tuple-element((s32[3,3]{0,1}) [[GTE0]]), index=0 -// CHECK: [[GTE2:%.*]] = token[] get-tuple-element(((s32[3,3]{0,1}), token[]) [[INFEED]]), index=1 -// CHECK: ROOT [[RES:%.*]] = (s32[3,3]{1,0}, token[]) tuple(s32[3,3]{0,1} [[GTE1]], token[] [[GTE2]] +// CHECK: [[INFEED:%.*]] = ((s32[3,3]{0,1}), token[]) infeed([[ARG]]), infeed_config="foobar" +// CHECK: [[GTE0:%.*]] = (s32[3,3]{0,1}) get-tuple-element([[INFEED]]), index=0 +// CHECK: [[GTE1:%.*]] = s32[3,3]{0,1} get-tuple-element([[GTE0]]), index=0 +// CHECK: [[GTE2:%.*]] = token[] get-tuple-element([[INFEED]]), index=1 +// CHECK: ROOT [[RES:%.*]] = (s32[3,3]{1,0}, token[]) tuple([[GTE1]], [[GTE2]] // ----- @@ -81,5 +81,5 @@ func.func @main(%arg0: !mhlo.token) -> !mhlo.token { // CHECK: ENTRY // CHECK: [[ARG:%.*]] = token[] parameter(0) -// CHECK: [[INFEED:%.*]] = ((), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" -// CHECK: ROOT [[GTE1:%.*]] = ((), token[]) get-tuple-element(((), token[]) [[INFEED]]), index=1 +// CHECK: [[INFEED:%.*]] = ((), token[]) infeed([[ARG]]), infeed_config="foobar" +// CHECK: ROOT [[GTE1:%.*]] = ((), token[]) get-tuple-element([[INFEED]]), index=1 diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir index 878285fd21d4d7..dd1b7221d1701f 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir @@ -8,7 +8,7 @@ // TUPLE-LABEL: ENTRY %main.{{.*}} (arg_tuple.1: (s32[4])) -> (s32[4], s32[1,2,3,4]) func.func @main(%arg0: tensor<4xi32>) -> (tensor<4xi32>, tensor<1x2x3x4xi32>) { // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0) - // CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3} + // CHECK-NEXT: %broadcast.2 = s32[1,2,3,4] broadcast(%Arg_0.1), dimensions={3} %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>}> : (tensor<4xi32>) -> tensor<1x2x3x4xi32> func.return %arg0, %0 : tensor<4xi32>, tensor<1x2x3x4xi32> } diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/ragged_dot.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/ragged_dot.mlir index 51602b1e34edb6..9020d0ca8af029 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/ragged_dot.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/ragged_dot.mlir @@ -2,7 +2,10 @@ module @ragged_dot_non_contracting { func.func @main(%lhs : tensor<19x11x5xf32>, %rhs : tensor<3x5x7xf32>, %group_sizes : tensor<19x3xi64>) -> tensor<19x11x7xf32> { - // CHECK: f32[19,11,7] ragged-dot(f32[19,11,5] {{.*}}, f32[3,5,7] {{.*}}, s64[19,3] {{.*}}), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_ragged_dims={1}, rhs_group_dims={0} + // CHECK: %[[ARG0:.+]] = f32[19,11,5] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[3,5,7] parameter(1) + // CHECK: %[[ARG2:.+]] = s64[19,3] parameter(2) + // CHECK: f32[19,11,7] ragged-dot(%[[ARG0]], %[[ARG1]], %[[ARG2]]), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_ragged_dims={1}, rhs_group_dims={0} %0 = "mhlo.ragged_dot"(%lhs, %rhs, %group_sizes) { ragged_dot_dimension_numbers = #mhlo.ragged_dot< dot_dimension_numbers = < @@ -24,7 +27,10 @@ module @ragged_dot_non_contracting { module @ragged_dot_contracting { func.func @main(%lhs : tensor<11x19x5xf32>, %rhs : tensor<19x5x7xf32>, %group_sizes : tensor<19x3xi64>) -> tensor<3x11x7xf32> { - // CHECK: f32[3,11,7] ragged-dot(f32[11,19,5] {{.*}}, f32[19,5,7] {{.*}}, s64[19,3] {{.*}}), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, lhs_ragged_dims={2} + // CHECK: %[[ARG0:.+]] = f32[11,19,5] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[19,5,7] parameter(1) + // CHECK: %[[ARG2:.+]] = s64[19,3] parameter(2) + // CHECK: f32[3,11,7] ragged-dot(%[[ARG0]], %[[ARG1]], %[[ARG2]]), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, lhs_ragged_dims={2} %0 = "mhlo.ragged_dot"(%lhs, %rhs, %group_sizes) { ragged_dot_dimension_numbers = #mhlo.ragged_dot< dot_dimension_numbers = < @@ -46,7 +52,10 @@ module @ragged_dot_contracting { module @ragged_dot_batch { func.func @main(%lhs : tensor<19x17x11x5xf32>, %rhs : tensor<19x17x5x7xf32>, %group_sizes : tensor<19x3xi64>) -> tensor<19x17x11x7xf32> { - // CHECK: f32[19,17,11,7] ragged-dot(f32[19,17,11,5] {{.*}}, f32[19,17,5,7] {{.*}}, s64[19,3] {{.*}}), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, lhs_ragged_dims={1} + // CHECK: %[[ARG0:.+]] = f32[19,17,11,5] parameter(0) + // CHECK: %[[ARG1:.+]] = f32[19,17,5,7] parameter(1) + // CHECK: %[[ARG2:.+]] = s64[19,3] parameter(2) + // CHECK: f32[19,17,11,7] ragged-dot(%[[ARG0]], %[[ARG1]], %[[ARG2]]), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, lhs_ragged_dims={1} %0 = "mhlo.ragged_dot"(%lhs, %rhs, %group_sizes) { ragged_dot_dimension_numbers = #mhlo.ragged_dot< dot_dimension_numbers = < diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir index b7255055f4b372..edb82b4cefe64c 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir @@ -16,9 +16,9 @@ func.func public @main(%arg0: tensor {mhlo.sharding = ""}, %arg1: tensor<4x // CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[5,8,128]) -> f32[5,8,128] func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) -> (tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) { // CHECK-NEXT: %Arg_0.1 = f32[5,8,128] parameter(0), sharding={devices=[1,2,1]0,1} - // CHECK-NEXT: %custom-call.2 = f32[5,8,128] custom-call(f32[5,8,128] %Arg_0.1), custom_call_target="Sharding", sharding={devices=[1,2,1]0,1} - // CHECK-NEXT: %tuple.3 = (f32[5,8,128]) tuple(f32[5,8,128] %custom-call.2) - // CHECK-NEXT: ROOT %get-tuple-element.4 = f32[5,8,128] get-tuple-element((f32[5,8,128]) %tuple.3), index=0 + // CHECK-NEXT: %custom-call.2 = f32[5,8,128] custom-call(%Arg_0.1), custom_call_target="Sharding", sharding={devices=[1,2,1]0,1} + // CHECK-NEXT: %tuple.3 = (f32[5,8,128]) tuple(%custom-call.2) + // CHECK-NEXT: ROOT %get-tuple-element.4 = f32[5,8,128] get-tuple-element(%tuple.3), index=0 // CHECK-SAME: sharding={devices=[1,2,1]0,1} %0 = "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01" @@ -31,10 +31,10 @@ func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\ // CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[4,4]) -> (f32[4,4], f32[4,4]) func.func @main(%arg0: tensor<4x4xf32>) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\03\02\01\02\22\04\00\01\02\03B\01\00"}, tensor<4x4xf32>) { // CHECK-NEXT: %Arg_0.1 = f32[4,4] parameter(0) - // CHECK-NEXT: [[RESHAPE_0:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} - // CHECK-NEXT: [[RESHAPE_1:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1) + // CHECK-NEXT: [[RESHAPE_0:%.*]] = f32[4,4] reshape(%Arg_0.1), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + // CHECK-NEXT: [[RESHAPE_1:%.*]] = f32[4,4] reshape(%Arg_0.1) // CHECK-NOT: sharding - // CHECK-NEXT: ROOT {{%.*}} = (f32[4,4], f32[4,4]) tuple(f32[4,4] [[RESHAPE_0]], f32[4,4] [[RESHAPE_1]]) + // CHECK-NEXT: ROOT {{%.*}} = (f32[4,4], f32[4,4]) tuple([[RESHAPE_0]], [[RESHAPE_1]]) // CHECK-SAME: sharding={{\{}}{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, {replicated}} return %arg0, %arg0 : tensor<4x4xf32>, tensor<4x4xf32> } @@ -44,8 +44,8 @@ func.func @main(%arg0: tensor<4x4xf32>) -> (tensor<4x4xf32> {mhlo.sharding = "\0 // CHECK-LABEL: ENTRY %main.{{.*}} () -> f32[4] func.func @main() -> (tensor<4xf32>) { // CHECK-NEXT: %constant.1 = f32[] constant(3.1415925) - // CHECK-NEXT: %broadcast.2 = f32[4] broadcast(f32[] %constant.1), dimensions={}, sharding={devices=[2]0,1} - // CHECK-NEXT: ROOT %add.3 = f32[4] add(f32[4] %broadcast.2, f32[4] %broadcast.2) + // CHECK-NEXT: %broadcast.2 = f32[4] broadcast(%constant.1), dimensions={}, sharding={devices=[2]0,1} + // CHECK-NEXT: ROOT %add.3 = f32[4] add(%broadcast.2, %broadcast.2) %0 = mhlo.constant {mhlo.sharding = "{devices=[2]0,1}"} dense<3.1415926> : tensor<4xf32> %1 = mhlo.add %0, %0 : tensor<4xf32> return %1 : tensor<4xf32> @@ -56,8 +56,8 @@ func.func @main() -> (tensor<4xf32>) { // CHECK-LABEL: ENTRY %main.{{.*}} () -> f32[12,24,36] func.func @main() -> (tensor<12x24x36xf32>) { // CHECK-NEXT: %constant.1 = f32[] constant(3.1415925) - // CHECK-NEXT: %broadcast.2 = f32[12,24,36] broadcast(f32[] %constant.1), dimensions={}, sharding={devices=[1,2,1]0,1} - // CHECK-NEXT: ROOT %add.3 = f32[12,24,36] add(f32[12,24,36] %broadcast.2, f32[12,24,36] %broadcast.2) + // CHECK-NEXT: %broadcast.2 = f32[12,24,36] broadcast(%constant.1), dimensions={}, sharding={devices=[1,2,1]0,1} + // CHECK-NEXT: ROOT %add.3 = f32[12,24,36] add(%broadcast.2, %broadcast.2) %0 = mhlo.constant {mhlo.sharding = "{devices=[1,2,1]0,1}"} dense<3.1415926> : tensor<12x24x36xf32> %1 = mhlo.add %0, %0 : tensor<12x24x36xf32> return %1 : tensor<12x24x36xf32> @@ -68,13 +68,13 @@ func.func @main() -> (tensor<12x24x36xf32>) { // CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: u64[2]) -> (u64[2], u32[512,4]) func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64> {mhlo.sharding = "{devices=[2,16]<=[32] last_tile_dim_replicate}"}, tensor<512x4xui32> {mhlo.sharding = "{devices=[4,8]<=[32]}"}) { // CHECK-NEXT: %Arg_0.1 = u64[2] parameter(0) - // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(u64[2] %Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {devices=[8,4]<=[32]}} - // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=0, sharding={replicated} - // CHECK-NEXT: %add.5 = u64[2] add(u64[2] %get-tuple-element.3, u64[2] %get-tuple-element.3) - // CHECK-NEXT: %reshape.6 = u64[2] reshape(u64[2] %add.5) - // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=1, sharding={devices=[8,4]<=[32]} - // CHECK-NEXT: %reshape.7 = u32[512,4] reshape(u32[512,4] %get-tuple-element.4) - // CHECK-NEXT: ROOT %tuple.8 = (u64[2], u32[512,4]) tuple(u64[2] %reshape.6, u32[512,4] %reshape.7), sharding={{\{}}{devices=[2,16]<=[32] last_tile_dim_replicate}, {devices=[4,8]<=[32]}} + // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(%Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {devices=[8,4]<=[32]}} + // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element(%rng-bit-generator.2), index=0, sharding={replicated} + // CHECK-NEXT: %add.5 = u64[2] add(%get-tuple-element.3, %get-tuple-element.3) + // CHECK-NEXT: %reshape.6 = u64[2] reshape(%add.5) + // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element(%rng-bit-generator.2), index=1, sharding={devices=[8,4]<=[32]} + // CHECK-NEXT: %reshape.7 = u32[512,4] reshape(%get-tuple-element.4) + // CHECK-NEXT: ROOT %tuple.8 = (u64[2], u32[512,4]) tuple(%reshape.6, %reshape.7), sharding={{\{}}{devices=[2,16]<=[32] last_tile_dim_replicate}, {devices=[4,8]<=[32]}} %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> {mhlo.sharding = "{{replicated}, {devices=[8,4]<=[32]}}"} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) %0 = mhlo.add %output_state, %output_state : tensor<2xui64> return %0, %output : tensor<2xui64>, tensor<512x4xui32> @@ -85,11 +85,11 @@ func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64> {mhlo.sharding = "{dev // CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: u64[2]) -> (u64[2], u32[512,4]) func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) { // CHECK-NEXT: %Arg_0.1 = u64[2] parameter(0) - // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(u64[2] %Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {replicated}} - // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=0, sharding={replicated} - // CHECK-NEXT: %add.5 = u64[2] add(u64[2] %get-tuple-element.3, u64[2] %get-tuple-element.3) - // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=1, sharding={replicated} - // CHECK-NEXT: ROOT %tuple.6 = (u64[2], u32[512,4]) tuple(u64[2] %add.5, u32[512,4] %get-tuple-element.4) + // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(%Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {replicated}} + // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element(%rng-bit-generator.2), index=0, sharding={replicated} + // CHECK-NEXT: %add.5 = u64[2] add(%get-tuple-element.3, %get-tuple-element.3) + // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element(%rng-bit-generator.2), index=1, sharding={replicated} + // CHECK-NEXT: ROOT %tuple.6 = (u64[2], u32[512,4]) tuple(%add.5, %get-tuple-element.4) %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> {mhlo.sharding = "{replicated}"} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) %0 = mhlo.add %output_state, %output_state : tensor<2xui64> return %0, %output : tensor<2xui64>, tensor<512x4xui32> @@ -101,17 +101,17 @@ func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) { // CHECK: %[[BODY:region_0.[0-9]+]] ([[ARG:Arg_.[0-9]+]]: s32[]) -> s32[] { // CHECK-NEXT: %[[ARG]] = s32[] parameter(0), sharding={replicated} -// CHECK-NEXT: %[[ADD:add.[0-9]+]] = s32[] add(s32[] %[[ARG]], s32[] %[[ARG]]) -// CHECK-NEXT: %[[TUPLE:tuple.[0-9]+]] = (s32[]) tuple(s32[] %[[ADD]]) -// CHECK-NEXT: ROOT %get-tuple-element.{{[0-9]+}} = s32[] get-tuple-element((s32[]) %[[TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[ADD:add.[0-9]+]] = s32[] add(%[[ARG]], %[[ARG]]) +// CHECK-NEXT: %[[TUPLE:tuple.[0-9]+]] = (s32[]) tuple(%[[ADD]]) +// CHECK-NEXT: ROOT %get-tuple-element.{{[0-9]+}} = s32[] get-tuple-element(%[[TUPLE]]), index=0, sharding={replicated} // CHECK: %[[COND:region_1.[0-9]+]] ([[ARG:Arg_.[0-9]+]]: s32[]) -> pred[] { // CHECK-NEXT: %[[ARG]] = s32[] parameter(0), sharding={replicated} -// CHECK-NEXT: ROOT %compare.{{[0-9]+}} = pred[] compare(s32[] %[[ARG]], s32[] %[[ARG]]), direction=LT +// CHECK-NEXT: ROOT %compare.{{[0-9]+}} = pred[] compare(%[[ARG]], %[[ARG]]), direction=LT // CHECK: ENTRY %main.{{[0-9]+}} ([[ARG:Arg_0.[0-9]+]]: s32[]) -> s32[] { // CHECK-NEXT: %[[ARG]] = s32[] parameter(0) -// CHECK-NEXT: ROOT %while.10 = s32[] while(s32[] %[[ARG]]), condition=%[[COND]], body=%[[BODY]], sharding={replicated} +// CHECK-NEXT: ROOT %while.10 = s32[] while(%[[ARG]]), condition=%[[COND]], body=%[[BODY]], sharding={replicated} func.func @main(%arg0: tensor) -> tensor { %0 = mhlo.while(%iterArg = %arg0) : tensor attributes {mhlo.sharding = "{replicated}"} @@ -132,33 +132,33 @@ func.func @main(%arg0: tensor) -> tensor { // CHECK: %[[BODY:region_0.[0-9]+]] ([[ARG_TUPLE:arg_tuple.[0-9]+]]: (s32[], f32[4], f32[4])) -> (s32[], f32[4], f32[4]) { // CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], f32[4]) parameter(0) // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[GTE0:get-tuple-element.[0-9]+]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={replicated} -// CHECK-NEXT: %[[GTE1:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %[[GTE2:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=2, sharding={devices=[4]<=[4]} -// CHECK-NEXT: %[[ADD:add.[0-9]+]] = f32[4] add(f32[4] %[[GTE1]], f32[4] %[[GTE2]]) -// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (s32[], f32[4], f32[4]) tuple(s32[] %[[GTE0]], f32[4] %[[ADD]], f32[4] %[[GTE2]]) +// CHECK-NEXT: %[[GTE0:get-tuple-element.[0-9]+]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE1:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE2:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=2, sharding={devices=[4]<=[4]} +// CHECK-NEXT: %[[ADD:add.[0-9]+]] = f32[4] add(%[[GTE1]], %[[GTE2]]) +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (s32[], f32[4], f32[4]) tuple(%[[GTE0]], %[[ADD]], %[[GTE2]]) // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} // CHECK: %[[COND:region_1.[0-9]+]] ([[ARG_TUPLE:arg_tuple.[0-9]+]]: (s32[], f32[4], f32[4])) -> pred[] { // CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], f32[4]) parameter(0) // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[GTE15:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %[[GTE16:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=2, sharding={devices=[4]<=[4]} -// CHECK-NEXT: %[[GTE14:get-tuple-element.[0-9]+]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={replicated} -// CHECK-NEXT: ROOT %compare.{{[0-9]+}} = pred[] compare(s32[] %[[GTE14]], s32[] %[[GTE14]]), direction=LT +// CHECK-NEXT: %[[GTE15:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE16:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=2, sharding={devices=[4]<=[4]} +// CHECK-NEXT: %[[GTE14:get-tuple-element.[0-9]+]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: ROOT %compare.{{[0-9]+}} = pred[] compare(%[[GTE14]], %[[GTE14]]), direction=LT // CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:Arg_0.[0-9]+]]: s32[], [[ARG1:Arg_1.[0-9]+]]: f32[4], [[ARG2:Arg_2.[0-9]+]]: f32[4]) -> (f32[4], f32[4]) { // CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) // CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1) // CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) -// CHECK-NEXT: %[[TUPLE:tuple.[0-9]+]] = (s32[], f32[4], f32[4]) tuple(s32[] %[[ARG0]], f32[4] %[[ARG1]], f32[4] %[[ARG2]]) +// CHECK-NEXT: %[[TUPLE:tuple.[0-9]+]] = (s32[], f32[4], f32[4]) tuple(%[[ARG0]], %[[ARG1]], %[[ARG2]]) // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[WHILE:while.[0-9]+]] = (s32[], f32[4], f32[4]) while((s32[], f32[4], f32[4]) %[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] +// CHECK-NEXT: %[[WHILE:while.[0-9]+]] = (s32[], f32[4], f32[4]) while(%[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[GTE19:get-tuple-element.[0-9]+]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=0, sharding={replicated} -// CHECK-NEXT: %[[GTE20:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %[[GTE21:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=2, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE20]], f32[4] %[[GTE21]]) +// CHECK-NEXT: %[[GTE19:get-tuple-element.[0-9]+]] = s32[] get-tuple-element(%[[WHILE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE20:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element(%[[WHILE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE21:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element(%[[WHILE]]), index=2, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[4], f32[4]) tuple(%[[GTE20]], %[[GTE21]]) func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { %0:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1, %iterArg_1 = %arg2) : tensor, tensor<4xf32>, tensor<4xf32> @@ -180,33 +180,33 @@ func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) // CHECK: %[[BODY:region_0.[0-9]+]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], f32[4])) -> (s32[], f32[4], f32[4]) { // CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], f32[4]) parameter(0) // CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} -// CHECK-NEXT: %[[GTE7:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={manual} -// CHECK-NEXT: %[[GTE8:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={manual} -// CHECK-NEXT: %[[GTE9:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=2, sharding={manual} -// CHECK-NEXT: %[[ADD:add.*]] = f32[4] add(f32[4] %[[GTE8]], f32[4] %[[GTE9]]) -// CHECK-NEXT: ROOT %tuple.{{.*}} = (s32[], f32[4], f32[4]) tuple(s32[] %[[GTE7]], f32[4] %[[ADD]], f32[4] %[[GTE9]]) +// CHECK-NEXT: %[[GTE7:get-tuple-element.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=0, sharding={manual} +// CHECK-NEXT: %[[GTE8:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=1, sharding={manual} +// CHECK-NEXT: %[[GTE9:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=2, sharding={manual} +// CHECK-NEXT: %[[ADD:add.*]] = f32[4] add(%[[GTE8]], %[[GTE9]]) +// CHECK-NEXT: ROOT %tuple.{{.*}} = (s32[], f32[4], f32[4]) tuple(%[[GTE7]], %[[ADD]], %[[GTE9]]) // CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} // CHECK: %[[COND:region_1.[0-9]+]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], f32[4])) -> pred[] { // CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], f32[4]) parameter(0) // CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} -// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={manual} -// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=2, sharding={manual} -// CHECK-NEXT: %[[GTE14:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={manual} -// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(s32[] %[[GTE14]], s32[] %[[GTE14]]), direction=LT +// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=1, sharding={manual} +// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=2, sharding={manual} +// CHECK-NEXT: %[[GTE14:get-tuple-element.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=0, sharding={manual} +// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(%[[GTE14]], %[[GTE14]]), direction=LT // CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4]) -> (f32[4], f32[4]) { // CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) // CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1) // CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) -// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], f32[4]) tuple(s32[] %[[ARG0]], f32[4] %[[ARG1]], f32[4] %[[ARG2]]) +// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], f32[4]) tuple(%[[ARG0]], %[[ARG1]], %[[ARG2]]) // CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} -// CHECK-NEXT: %[[WHILE:while.*]] = (s32[], f32[4], f32[4]) while((s32[], f32[4], f32[4]) %[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] +// CHECK-NEXT: %[[WHILE:while.*]] = (s32[], f32[4], f32[4]) while(%[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] // CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} -// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=0, sharding={manual} -// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=1, sharding={manual} -// CHECK-NEXT: %[[GTE21:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=2, sharding={manual} -// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE20]], f32[4] %[[GTE21]]) +// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = s32[] get-tuple-element(%[[WHILE]]), index=0, sharding={manual} +// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element(%[[WHILE]]), index=1, sharding={manual} +// CHECK-NEXT: %[[GTE21:get-tuple-element.*]] = f32[4] get-tuple-element(%[[WHILE]]), index=2, sharding={manual} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(%[[GTE20]], %[[GTE21]]) func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { %0:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1, %iterArg_1 = %arg2) : tensor, tensor<4xf32>, tensor<4xf32> @@ -227,29 +227,29 @@ func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) // CHECK: %[[BRANCH0:region_0.*]] ([[ARG_TUPLE:arg_tuple.*]]: (f32[4], f32[4])) -> (f32[4], f32[4]) { // CHECK-NEXT: %[[ARG_TUPLE]] = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} -// CHECK-NEXT: %[[GTE10:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %[[GTE11:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=1 -// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE10]], f32[4] %[[GTE11]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE10:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE11:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=1 +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(%[[GTE10]], %[[GTE11]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} // CHECK: %[[BRANCH1:region_1.*]] ([[ARG_TUPLE:arg_tuple.*]]: (f32[4], f32[4])) -> (f32[4], f32[4]) { // CHECK-NEXT: %[[ARG_TUPLE]] = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={replicated} -// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE15]], f32[4] %[[GTE16]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(%[[GTE15]], %[[GTE16]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} // CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4], [[ARG3:Arg_3.*]]: f32[4], [[ARG4:Arg_4.*]]: f32[4]) -> (f32[4], f32[4]) { // CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) // CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} // CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) -// CHECK-NEXT: %[[TUPLE6:tuple.*]] = (f32[4], f32[4]) tuple(f32[4] %[[ARG1]], f32[4] %[[ARG2]]), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %[[TUPLE6:tuple.*]] = (f32[4], f32[4]) tuple(%[[ARG1]], %[[ARG2]]), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} // CHECK-NEXT: %[[ARG3]] = f32[4] parameter(3), sharding={replicated} // CHECK-NEXT: %[[ARG4]] = f32[4] parameter(4), sharding={devices=[4]<=[4]} -// CHECK-NEXT: %[[TUPLE7:tuple.*]] = (f32[4], f32[4]) tuple(f32[4] %[[ARG3]], f32[4] %[[ARG4]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[COND:conditional.*]] = (f32[4], f32[4]) conditional(s32[] %[[ARG0]], (f32[4], f32[4]) %[[TUPLE6]], (f32[4], f32[4]) %[[TUPLE7]]), branch_computations={%[[BRANCH0]], %[[BRANCH1]]}, +// CHECK-NEXT: %[[TUPLE7:tuple.*]] = (f32[4], f32[4]) tuple(%[[ARG3]], %[[ARG4]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[COND:conditional.*]] = (f32[4], f32[4]) conditional(%[[ARG0]], %[[TUPLE6]], %[[TUPLE7]]), branch_computations={%[[BRANCH0]], %[[BRANCH1]]}, // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[COND]]), index=0, sharding={replicated} -// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[COND]]), index=1, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE19]], f32[4] %[[GTE20]]) +// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = f32[4] get-tuple-element(%[[COND]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element(%[[COND]]), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(%[[GTE19]], %[[GTE20]]) func.func @main(%arg0: tensor, %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, @@ -279,7 +279,7 @@ func.func @main(%arg0: tensor, // CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) // CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} // CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) -// CHECK-NEXT: ROOT %conditional.{{.*}} = f32[4] conditional(s32[] %[[ARG0]], f32[4] %[[ARG1]], f32[4] %[[ARG2]]), branch_computations={%[[BRANCH0]], %[[BRANCH1]]} +// CHECK-NEXT: ROOT %conditional.{{.*}} = f32[4] conditional(%[[ARG0]], %[[ARG1]], %[[ARG2]]), branch_computations={%[[BRANCH0]], %[[BRANCH1]]} func.func @main(%arg0: tensor, %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, %arg2: tensor<4xf32>) -> tensor<4xf32> { @@ -297,29 +297,29 @@ func.func @main(%arg0: tensor, // CHECK: %[[BRANCH0:region_0.*]] ([[ARG_TUPLE:arg_tuple.*]]: (f32[4], f32[4])) -> (f32[4], f32[4]) { // CHECK-NEXT: %[[ARG_TUPLE]] = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} -// CHECK-NEXT: %[[GTE10:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %[[GTE11:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=1 -// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE10]], f32[4] %[[GTE11]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE10:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE11:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=1 +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(%[[GTE10]], %[[GTE11]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} // CHECK: %[[BRANCH1:region_1.*]] ([[ARG_TUPLE:arg_tuple.*]]: (f32[4], f32[4])) -> (f32[4], f32[4]) { // CHECK-NEXT: %[[ARG_TUPLE]] = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={replicated} -// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE15]], f32[4] %[[GTE16]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(%[[GTE15]], %[[GTE16]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} // CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: pred[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4], [[ARG3:Arg_3.*]]: f32[4], [[ARG4:Arg_4.*]]: f32[4]) -> (f32[4], f32[4]) { // CHECK-NEXT: %[[ARG0]] = pred[] parameter(0) // CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} // CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) -// CHECK-NEXT: %[[TUPLE6:tuple.*]] = (f32[4], f32[4]) tuple(f32[4] %[[ARG1]], f32[4] %[[ARG2]]), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %[[TUPLE6:tuple.*]] = (f32[4], f32[4]) tuple(%[[ARG1]], %[[ARG2]]), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} // CHECK-NEXT: %[[ARG3]] = f32[4] parameter(3), sharding={replicated} // CHECK-NEXT: %[[ARG4]] = f32[4] parameter(4), sharding={devices=[4]<=[4]} -// CHECK-NEXT: %[[TUPLE7:tuple.*]] = (f32[4], f32[4]) tuple(f32[4] %[[ARG3]], f32[4] %[[ARG4]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(pred[] %[[ARG0]], (f32[4], f32[4]) %[[TUPLE6]], (f32[4], f32[4]) %[[TUPLE7]]), true_computation=%[[BRANCH0]], false_computation=%[[BRANCH1]], +// CHECK-NEXT: %[[TUPLE7:tuple.*]] = (f32[4], f32[4]) tuple(%[[ARG3]], %[[ARG4]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(%[[ARG0]], %[[TUPLE6]], %[[TUPLE7]]), true_computation=%[[BRANCH0]], false_computation=%[[BRANCH1]], // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated} -// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE19]], f32[4] %[[GTE20]]) +// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = f32[4] get-tuple-element(%conditional.18), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element(%conditional.18), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(%[[GTE19]], %[[GTE20]]) func.func @main(%arg0: tensor, %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, @@ -348,7 +348,7 @@ func.func @main(%arg0: tensor, // CHECK-NEXT: %[[ARG0]] = pred[] parameter(0) // CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} // CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) -// CHECK-NEXT: ROOT %conditional.{{.*}} = f32[4] conditional(pred[] %[[ARG0]], f32[4] %[[ARG1]], f32[4] %[[ARG2]]), true_computation=%[[TRUE]], false_computation=%[[FALSE]] +// CHECK-NEXT: ROOT %conditional.{{.*}} = f32[4] conditional(%[[ARG0]], %[[ARG1]], %[[ARG2]]), true_computation=%[[TRUE]], false_computation=%[[FALSE]] func.func @main(%arg0: tensor, %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir index e99a57fe7f9225..b444dbb27bdb13 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir @@ -6,10 +6,10 @@ module { %0 = "mhlo.while"(%arg0) ({ // CHECK: [[R0:%.+]] ([[A0:.+]]: s64[]) -> s64[] { // CHECK: %[[A0]] = s64[] parameter(0) - // CHECK: ROOT %add.{{.*}} = s64[] add(s64[] %[[A0]], s64[] %[[A0]]) + // CHECK: ROOT %add.{{.*}} = s64[] add(%[[A0]], %[[A0]]) // CHECK: [[R1:%.+]] ([[A0:.+]]: s64[]) -> pred[] { // CHECK: %[[A0]] = s64[] parameter(0) - // CHECK: ROOT %compare.{{.*}} = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT + // CHECK: ROOT %compare.{{.*}} = pred[] compare(%[[A0]], %[[A0]]), direction=LT ^bb0(%arg1: tensor): %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor "mhlo.return"(%1) : (tensor) -> () @@ -21,7 +21,7 @@ module { // CHECK: ENTRY %main.{{.*}} ([[A0:.+]]: s64[]) -> s64[] { // CHECK: %[[A0]] = s64[] parameter(0) - // CHECK: ROOT %while.{{.*}} = s64[] while(s64[] %[[A0]]), condition=[[R1]], body=[[R0]] + // CHECK: ROOT %while.{{.*}} = s64[] while(%[[A0]]), condition=[[R1]], body=[[R0]] func.return %0 : tensor } } @@ -32,22 +32,22 @@ module { // CHECK: [[BODY:%.+]] ([[TUPLE:.+]]: (s32[], s32[], f32[], f32[])) -> (s32[], s32[], f32[], f32[]) { // CHECK-NEXT: %[[TUPLE]] = (s32[], s32[], f32[], f32[]) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[TUPLE]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[TUPLE]]), index=1 -// CHECK-NEXT: %[[GTE_2:.*]] = f32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[TUPLE]]), index=2 -// CHECK-NEXT: %[[GTE_3:.*]] = f32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[TUPLE]]), index=3 -// CHECK-NEXT: %[[ADD:.*]] = f32[] add(f32[] %[[GTE_2]], f32[] %[[GTE_3]]) -// CHECK-NEXT: ROOT %[[TUPLE_RES:.*]] = (s32[], s32[], f32[], f32[]) tuple(s32[] %[[GTE_0]], s32[] %[[GTE_1]], f32[] %[[GTE_2]], f32[] %[[ADD]]) +// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=1 +// CHECK-NEXT: %[[GTE_2:.*]] = f32[] get-tuple-element(%[[TUPLE]]), index=2 +// CHECK-NEXT: %[[GTE_3:.*]] = f32[] get-tuple-element(%[[TUPLE]]), index=3 +// CHECK-NEXT: %[[ADD:.*]] = f32[] add(%[[GTE_2]], %[[GTE_3]]) +// CHECK-NEXT: ROOT %[[TUPLE_RES:.*]] = (s32[], s32[], f32[], f32[]) tuple(%[[GTE_0]], %[[GTE_1]], %[[GTE_2]], %[[ADD]]) // CHECK: } // CHECK: [[COND:%.+]] ([[TUPLE:.+]]: (s32[], s32[], f32[], f32[])) -> pred[] { // CHECK-NEXT: %[[TUPLE]] = (s32[], s32[], f32[], f32[]) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = f32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[TUPLE]]), index=2 -// CHECK-NEXT: %[[GTE_1:.*]] = f32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[TUPLE]]), index=3 +// CHECK-NEXT: %[[GTE_0:.*]] = f32[] get-tuple-element(%[[TUPLE]]), index=2 +// CHECK-NEXT: %[[GTE_1:.*]] = f32[] get-tuple-element(%[[TUPLE]]), index=3 // CHECK-NEXT: %[[CST_0:.*]] = s32[] constant(0) -// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[TUPLE]]), index=0 -// CHECK-NEXT: %[[GTE_3:.*]] = s32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[TUPLE]]), index=1 -// CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(s32[] %[[GTE_2]], s32[] %[[GTE_3]]), direction=LT +// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=0 +// CHECK-NEXT: %[[GTE_3:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=1 +// CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(%[[GTE_2]], %[[GTE_3]]), direction=LT // CHECK: } @@ -56,12 +56,12 @@ module { // CHECK-NEXT: %[[CST_1:.*]] = s32[] constant(100) // CHECK-NEXT: %[[CST_2:.*]] = f32[] constant(1) // CHECK-NEXT: %[[ARG_0:.*]] = f32[] parameter(0) -// CHECK-NEXT: %[[TUPLE:.*]] = (s32[], s32[], f32[], f32[]) tuple(s32[] %[[CST_0]], s32[] %[[CST_1]], f32[] %[[CST_2]], f32[] %[[ARG_0]]) -// CHECK-NEXT: %[[WHILE:.*]] = (s32[], s32[], f32[], f32[]) while((s32[], s32[], f32[], f32[]) %[[TUPLE]]), condition=[[COND]], body=[[BODY]] -// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[WHILE]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[WHILE]]), index=1 -// CHECK-NEXT: %[[GTE_2:.*]] = f32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[WHILE]]), index=2 -// CHECK-NEXT: ROOT %[[GTE_3:.*]] = f32[] get-tuple-element((s32[], s32[], f32[], f32[]) %[[WHILE]]), index=3 +// CHECK-NEXT: %[[TUPLE:.*]] = (s32[], s32[], f32[], f32[]) tuple(%[[CST_0]], %[[CST_1]], %[[CST_2]], %[[ARG_0]]) +// CHECK-NEXT: %[[WHILE:.*]] = (s32[], s32[], f32[], f32[]) while(%[[TUPLE]]), condition=[[COND]], body=[[BODY]] +// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element(%[[WHILE]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element(%[[WHILE]]), index=1 +// CHECK-NEXT: %[[GTE_2:.*]] = f32[] get-tuple-element(%[[WHILE]]), index=2 +// CHECK-NEXT: ROOT %[[GTE_3:.*]] = f32[] get-tuple-element(%[[WHILE]]), index=3 func.func @main(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor @@ -86,41 +86,41 @@ func.func @main(%arg0: tensor) -> tensor { // CHECK: [[BODY:%.+]] ([[TUPLE_0:.+]]: (s32[1], s32[2], f32[1], f32[3])) -> (s32[1], s32[2], f32[1], f32[3]) { // CHECK-NEXT: %[[TUPLE_0:.*]] = (s32[1], s32[2], f32[1], f32[3]) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = s32[1] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[2] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=1 -// CHECK-NEXT: %[[GTE_2:.*]] = f32[1] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=2 -// CHECK-NEXT: %[[GTE_3:.*]] = f32[3] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=3 -// CHECK-NEXT: %[[BDCAST_0:.*]] = f32[1] broadcast(f32[1] %[[GTE_2]]), dimensions={0} -// CHECK-NEXT: %[[RESHAPE_0:.*]] = f32[] reshape(f32[1] %[[BDCAST_0]]) -// CHECK-NEXT: %[[BDCAST_1:.*]] = f32[3] broadcast(f32[] %[[RESHAPE_0]]), dimensions={} -// CHECK-NEXT: %[[ADD:.*]] = f32[3] add(f32[3] %[[GTE_3]], f32[3] %[[BDCAST_1]]) -// CHECK-NEXT: ROOT %[[TUPLE_0:.*]] = (s32[1], s32[2], f32[1], f32[3]) tuple(s32[1] %[[GTE_0]], s32[2] %[[GTE_1]], f32[1] %[[GTE_2]], f32[3] %[[ADD]]) +// CHECK-NEXT: %[[GTE_0:.*]] = s32[1] get-tuple-element(%[[TUPLE_0]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[2] get-tuple-element(%[[TUPLE_0]]), index=1 +// CHECK-NEXT: %[[GTE_2:.*]] = f32[1] get-tuple-element(%[[TUPLE_0]]), index=2 +// CHECK-NEXT: %[[GTE_3:.*]] = f32[3] get-tuple-element(%[[TUPLE_0]]), index=3 +// CHECK-NEXT: %[[BDCAST_0:.*]] = f32[1] broadcast(%[[GTE_2]]), dimensions={0} +// CHECK-NEXT: %[[RESHAPE_0:.*]] = f32[] reshape(%[[BDCAST_0]]) +// CHECK-NEXT: %[[BDCAST_1:.*]] = f32[3] broadcast(%[[RESHAPE_0]]), dimensions={} +// CHECK-NEXT: %[[ADD:.*]] = f32[3] add(%[[GTE_3]], %[[BDCAST_1]]) +// CHECK-NEXT: ROOT %[[TUPLE_0:.*]] = (s32[1], s32[2], f32[1], f32[3]) tuple(%[[GTE_0]], %[[GTE_1]], %[[GTE_2]], %[[ADD]]) // CHECK: } // CHECK: [[COND:%.+]] ([[TUPLE_0:.+]]: (s32[1], s32[2], f32[1], f32[3])) -> pred[] { // CHECK-NEXT: %[[TUPLE_0:.*]] = (s32[1], s32[2], f32[1], f32[3]) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = f32[1] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=2 -// CHECK-NEXT: %[[GTE_1:.*]] = f32[3] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=3 -// CHECK-NEXT: %[[GTE_2:.*]] = s32[1] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=0 +// CHECK-NEXT: %[[GTE_0:.*]] = f32[1] get-tuple-element(%[[TUPLE_0]]), index=2 +// CHECK-NEXT: %[[GTE_1:.*]] = f32[3] get-tuple-element(%[[TUPLE_0]]), index=3 +// CHECK-NEXT: %[[GTE_2:.*]] = s32[1] get-tuple-element(%[[TUPLE_0]]), index=0 // CHECK-NEXT: %[[CST_0:.*]] = s32[] constant(0) -// CHECK-NEXT: %[[RED_0:.*]] = s32[] reduce(s32[1] %[[GTE_2]], s32[] %[[CST_0]]), dimensions={0}, to_apply= -// CHECK-NEXT: %[[GTE_3:.*]] = s32[2] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=1 -// CHECK-NEXT: %[[RED_1:.*]] = s32[] reduce(s32[2] %[[GTE_3]], s32[] %[[CST_0]]), dimensions={0}, to_apply= -// CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(s32[] %[[RED_0]], s32[] %[[RED_1]]), direction=LT +// CHECK-NEXT: %[[RED_0:.*]] = s32[] reduce(%[[GTE_2]], %[[CST_0]]), dimensions={0}, to_apply= +// CHECK-NEXT: %[[GTE_3:.*]] = s32[2] get-tuple-element(%[[TUPLE_0]]), index=1 +// CHECK-NEXT: %[[RED_1:.*]] = s32[] reduce(%[[GTE_3]], %[[CST_0]]), dimensions={0}, to_apply= +// CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(%[[RED_0]], %[[RED_1]]), direction=LT // CHECK: } // CHECK: ENTRY // CHECK-NEXT: %[[CST_0:.*]] = s32[1] constant({0}) // CHECK-NEXT: %[[CST_1:.*]] = s32[] constant(100) -// CHECK-NEXT: %[[BDCAST_0:.*]] = s32[2] broadcast(s32[] %[[CST_1]]), dimensions={} +// CHECK-NEXT: %[[BDCAST_0:.*]] = s32[2] broadcast(%[[CST_1]]), dimensions={} // CHECK-NEXT: %[[CST_2:.*]] = f32[1] constant({1}) // CHECK-NEXT: %[[ARG_0:.*]] = f32[3] parameter(0) -// CHECK-NEXT: %[[TUPLE:.*]] = (s32[1], s32[2], f32[1], f32[3]) tuple(s32[1] %[[CST_0]], s32[2] %[[BDCAST_0]], f32[1] %[[CST_2]], f32[3] %[[ARG_0]]) -// CHECK-NEXT: %[[WHILE:.*]] = (s32[1], s32[2], f32[1], f32[3]) while((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE]]), condition=[[COND]], body=[[BODY]] -// CHECK-NEXT: %[[GTE_0:.*]] = s32[1] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[WHILE]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[2] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[WHILE]]), index=1 -// CHECK-NEXT: %[[GTE_2:.*]] = f32[1] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[WHILE]]), index=2 -// CHECK-NEXT: ROOT %[[GTE_3:.*]] = f32[3] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[WHILE]]), index=3 +// CHECK-NEXT: %[[TUPLE:.*]] = (s32[1], s32[2], f32[1], f32[3]) tuple(%[[CST_0]], %[[BDCAST_0]], %[[CST_2]], %[[ARG_0]]) +// CHECK-NEXT: %[[WHILE:.*]] = (s32[1], s32[2], f32[1], f32[3]) while(%[[TUPLE]]), condition=[[COND]], body=[[BODY]] +// CHECK-NEXT: %[[GTE_0:.*]] = s32[1] get-tuple-element(%[[WHILE]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[2] get-tuple-element(%[[WHILE]]), index=1 +// CHECK-NEXT: %[[GTE_2:.*]] = f32[1] get-tuple-element(%[[WHILE]]), index=2 +// CHECK-NEXT: ROOT %[[GTE_3:.*]] = f32[3] get-tuple-element(%[[WHILE]]), index=3 func.func @main(%arg0: tensor<3xf32>) -> tensor<3xf32> { %0 = mhlo.constant dense<0> : tensor<1xi32> @@ -156,36 +156,36 @@ func.func @main(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: [[BODY:%.+]] ([[TUPLE:.+]]: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) { // CHECK-NEXT: %[[TUPLE:.*]] = (s32[], s32[], s32[]) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[TUPLE]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[TUPLE]]), index=1 -// CHECK-NEXT: %[[ADD:.*]] = s32[] add(s32[] %[[GTE_0]], s32[] %[[GTE_1]]) -// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[TUPLE]]), index=2 -// CHECK-NEXT: ROOT %[[TUPLE_RES:.*]] = (s32[], s32[], s32[]) tuple(s32[] %[[ADD]], s32[] %[[GTE_1]], s32[] %[[GTE_2]]) +// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=1 +// CHECK-NEXT: %[[ADD:.*]] = s32[] add(%[[GTE_0]], %[[GTE_1]]) +// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=2 +// CHECK-NEXT: ROOT %[[TUPLE_RES:.*]] = (s32[], s32[], s32[]) tuple(%[[ADD]], %[[GTE_1]], %[[GTE_2]]) // CHECK: } // CHECK: [[COND:%.+]] ([[TUPLE:.+]]: (s32[], s32[], s32[])) -> pred[] { // CHECK-NEXT: %[[TUPLE:.*]] = (s32[], s32[], s32[]) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[TUPLE]]), index=1 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[TUPLE]]), index=0 -// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[TUPLE]]), index=2 -// CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(s32[] %[[GTE_1]], s32[] %[[GTE_2]]), direction=LT +// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=1 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=0 +// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element(%[[TUPLE]]), index=2 +// CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(%[[GTE_1]], %[[GTE_2]]), direction=LT // CHECK: } // CHECK: ENTRY // CHECK-NEXT: %[[ARG_0:.*]] = (s32[], (s32[], (s32[]))) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], (s32[], (s32[]))) %[[ARG_0]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = (s32[], (s32[])) get-tuple-element((s32[], (s32[], (s32[]))) %[[ARG_0]]), index=1 -// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element((s32[], (s32[])) %[[GTE_1]]), index=0 -// CHECK-NEXT: %[[GTE_3:.*]] = (s32[]) get-tuple-element((s32[], (s32[])) %[[GTE_1]]), index=1 -// CHECK-NEXT: %[[GTE_4:.*]] = s32[] get-tuple-element((s32[]) %[[GTE_3]]), index=0 -// CHECK-NEXT: %[[TUPLE_0:.*]] = (s32[], s32[], s32[]) tuple(s32[] %[[GTE_0]], s32[] %[[GTE_2]], s32[] %[[GTE_4]]) -// CHECK-NEXT: %[[WHILE:.*]] = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %[[TUPLE_0]]), condition=[[COND]], body=[[BODY]] -// CHECK-NEXT: %[[GTE_5:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[WHILE]]), index=0 -// CHECK-NEXT: %[[GTE_6:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[WHILE]]), index=1 -// CHECK-NEXT: %[[GTE_7:.*]] = s32[] get-tuple-element((s32[], s32[], s32[]) %[[WHILE]]), index=2 -// CHECK-NEXT: %[[TUPLE_1:.*]] = (s32[]) tuple(s32[] %[[GTE_7]]) -// CHECK-NEXT: %[[TUPLE_2:.*]] = (s32[], (s32[])) tuple(s32[] %[[GTE_6]], (s32[]) %[[TUPLE_1]]) -// CHECK-NEXT: ROOT %[[TUPLE_3:.*]] = (s32[], (s32[], (s32[]))) tuple(s32[] %[[GTE_5]], (s32[], (s32[])) %[[TUPLE_2]]) +// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element(%[[ARG_0]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = (s32[], (s32[])) get-tuple-element(%[[ARG_0]]), index=1 +// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element(%[[GTE_1]]), index=0 +// CHECK-NEXT: %[[GTE_3:.*]] = (s32[]) get-tuple-element(%[[GTE_1]]), index=1 +// CHECK-NEXT: %[[GTE_4:.*]] = s32[] get-tuple-element(%[[GTE_3]]), index=0 +// CHECK-NEXT: %[[TUPLE_0:.*]] = (s32[], s32[], s32[]) tuple(%[[GTE_0]], %[[GTE_2]], %[[GTE_4]]) +// CHECK-NEXT: %[[WHILE:.*]] = (s32[], s32[], s32[]) while(%[[TUPLE_0]]), condition=[[COND]], body=[[BODY]] +// CHECK-NEXT: %[[GTE_5:.*]] = s32[] get-tuple-element(%[[WHILE]]), index=0 +// CHECK-NEXT: %[[GTE_6:.*]] = s32[] get-tuple-element(%[[WHILE]]), index=1 +// CHECK-NEXT: %[[GTE_7:.*]] = s32[] get-tuple-element(%[[WHILE]]), index=2 +// CHECK-NEXT: %[[TUPLE_1:.*]] = (s32[]) tuple(%[[GTE_7]]) +// CHECK-NEXT: %[[TUPLE_2:.*]] = (s32[], (s32[])) tuple(%[[GTE_6]], %[[TUPLE_1]]) +// CHECK-NEXT: ROOT %[[TUPLE_3:.*]] = (s32[], (s32[], (s32[]))) tuple(%[[GTE_5]], %[[TUPLE_2]]) func.func @main(%arg0: tuple, tuple, tuple>>>) -> tuple, tuple, tuple>>> { %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tuple, tuple>>>) -> tensor @@ -216,29 +216,29 @@ func.func @main(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK-NEXT: pred[] constant(false) // CHECK-NEXT: %[[ARG_0:.*]] = f32[3,3] parameter(0) // CHECK-NEXT: %[[CST_1:.*]] = f32[] constant(2) -// CHECK-NEXT: %[[BDCAST:.*]] = f32[3,3] broadcast(f32[] %[[CST_1]]), dimensions={} -// CHECK-NEXT: ROOT %[[ADD:.*]] = f32[3,3] add(f32[3,3] %[[ARG_0]], f32[3,3] %[[BDCAST]]) +// CHECK-NEXT: %[[BDCAST:.*]] = f32[3,3] broadcast(%[[CST_1]]), dimensions={} +// CHECK-NEXT: ROOT %[[ADD:.*]] = f32[3,3] add(%[[ARG_0]], %[[BDCAST]]) // CHECK: } // CHECK: [[REDUCER:%.+]] ([[ARG_0:.+]]: f32[], [[ARG_1:.+]]: f32[]) -> f32[] { // CHECK-NEXT: constant(false) // CHECK-NEXT: %[[ARG_0:.*]] = f32[] parameter(0) // CHECK-NEXT: %[[ARG_1:.*]] = f32[] parameter(1) -// CHECK-NEXT: ROOT %[[ADD:.*]] = f32[] add(f32[] %[[ARG_0]], f32[] %[[ARG_1]]) +// CHECK-NEXT: ROOT %[[ADD:.*]] = f32[] add(%[[ARG_0]], %[[ARG_1]]) // CHECK: } // CHECK: [[COND:%.+]] ([[ARG_0:.+]]: f32[3,3]) -> pred[] { // CHECK-NEXT: pred[] constant(false) // CHECK-NEXT: %[[ARG_0:.*]] = f32[3,3] parameter(0) // CHECK-NEXT: %[[CST_0:.*]] = f32[] constant(0) -// CHECK-NEXT: %[[REDUCE:.*]] = f32[] reduce(f32[3,3] %[[ARG_0]], f32[] %[[CST_0]]), dimensions={0,1}, to_apply=[[REDUCER]] +// CHECK-NEXT: %[[REDUCE:.*]] = f32[] reduce(%[[ARG_0]], %[[CST_0]]), dimensions={0,1}, to_apply=[[REDUCER]] // CHECK-NEXT: %[[CST_1:.*]] = f32[] constant(100) -// CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(f32[] %[[REDUCE]], f32[] %[[CST_1]]), direction=LT +// CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(%[[REDUCE]], %[[CST_1]]), direction=LT // CHECK: ENTRY // CHECK-NEXT: %[[CST_0:.*]] = pred[] constant(false) // CHECK-NEXT: %[[ARG_0:.*]] = f32[3,3] parameter(0) -// CHECK-NEXT: ROOT %[[WHILE:.*]] = f32[3,3] while(f32[3,3] %[[ARG_0]]), condition=[[COND]], body=[[BODY]] +// CHECK-NEXT: ROOT %[[WHILE:.*]] = f32[3,3] while(%[[ARG_0]]), condition=[[COND]], body=[[BODY]] func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { %0 = mhlo.constant dense : tensor @@ -272,29 +272,29 @@ func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { // CHECK: [[BODY:%.+]] ([[ARG_TUPLE:.+]]: (s32[], s32[])) -> (s32[], s32[]) { // CHECK-NEXT: %[[ARG_TUPLE:.*]] = (s32[], s32[]) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[ARG_TUPLE]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[ARG_TUPLE]]), index=1 -// CHECK-NEXT: %[[TUPLE_0:.*]] = (s32[], s32[]) tuple(s32[] %[[GTE_0]], s32[] %[[GTE_1]]) -// CHECK-NEXT: %[[CC:.*]] = (s32[], s32[]) custom-call(s32[] %[[GTE_0]], (s32[], s32[]) %[[TUPLE_0]]) -// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[CC]]), index=0 -// CHECK-NEXT: %[[GTE_3:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[CC]]), index=1 -// CHECK-NEXT: ROOT %[[TUPLE_1:.*]] = (s32[], s32[]) tuple(s32[] %[[GTE_2]], s32[] %[[GTE_3]]) +// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=1 +// CHECK-NEXT: %[[TUPLE_0:.*]] = (s32[], s32[]) tuple(%[[GTE_0]], %[[GTE_1]]) +// CHECK-NEXT: %[[CC:.*]] = (s32[], s32[]) custom-call(%[[GTE_0]], %[[TUPLE_0]]) +// CHECK-NEXT: %[[GTE_2:.*]] = s32[] get-tuple-element(%[[CC]]), index=0 +// CHECK-NEXT: %[[GTE_3:.*]] = s32[] get-tuple-element(%[[CC]]), index=1 +// CHECK-NEXT: ROOT %[[TUPLE_1:.*]] = (s32[], s32[]) tuple(%[[GTE_2]], %[[GTE_3]]) // CHECK: } // CHECK: [[COND:%.+]] ([[ARG_TUPLE:.+]]: (s32[], s32[])) -> pred[] { // CHECK-NEXT: %[[ARG_TUPLE:.*]] = (s32[], s32[]) parameter(0) -// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[ARG_TUPLE]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[ARG_TUPLE]]), index=1 -// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(s32[] %[[GTE_0]], s32[] %[[GTE_1]]), direction=LT +// CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=1 +// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(%[[GTE_0]], %[[GTE_1]]), direction=LT // CHECK: } // CHECK: ENTRY // CHECK-NEXT: %[[CST_0:.*]] = s32[] constant(0) // CHECK-NEXT: %[[ARG_0:.*]] = s32[] parameter(0) -// CHECK-NEXT: %[[TUPLE:.*]] = (s32[], s32[]) tuple(s32[] %[[CST_0]], s32[] %[[ARG_0]]) -// CHECK-NEXT: %[[WHILE:.*]] = (s32[], s32[]) while((s32[], s32[]) %[[TUPLE]]), condition=[[COND]], body=[[BODY]] -// CHECK-NEXT: ROOT %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[WHILE]]), index=0 -// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[WHILE]]), index=1 +// CHECK-NEXT: %[[TUPLE:.*]] = (s32[], s32[]) tuple(%[[CST_0]], %[[ARG_0]]) +// CHECK-NEXT: %[[WHILE:.*]] = (s32[], s32[]) while(%[[TUPLE]]), condition=[[COND]], body=[[BODY]] +// CHECK-NEXT: ROOT %[[GTE_0:.*]] = s32[] get-tuple-element(%[[WHILE]]), index=0 +// CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element(%[[WHILE]]), index=1 func.func @main(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor diff --git a/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir b/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir index b1ebb843cf8321..3d2dbb74c1fde8 100644 --- a/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir +++ b/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir @@ -8,18 +8,18 @@ // CHECK: %[[BODY:region_0.*]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], s32[], s32[], f32[4])) -> (s32[], f32[4], s32[], s32[], f32[4]) { // CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0) // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} -// CHECK-DAG: %[[GTE12:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[ARG_TUPLE]]), index=3 -// CHECK-DAG: %[[GTE13:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[ARG_TUPLE]]), index=4, sharding={devices=[4]<=[4]} -// CHECK-DAG: %[[ADD14:add.*]] = s32[] add(s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE12]]) -// CHECK-DAG: %[[ADD15:add.*]] = f32[4] add(f32[4] %get-tuple-element.{{.*}}, f32[4] %[[GTE13]]) -// CHECK: ROOT %tuple.{{.*}} = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %[[ADD14]], f32[4] %[[ADD15]], s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE12]], f32[4] %[[GTE13]]) +// CHECK-DAG: %[[GTE12:get-tuple-element.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=3 +// CHECK-DAG: %[[GTE13:get-tuple-element.*]] = f32[4] get-tuple-element(%[[ARG_TUPLE]]), index=4, sharding={devices=[4]<=[4]} +// CHECK-DAG: %[[ADD14:add.*]] = s32[] add(%get-tuple-element.{{.*}}, %[[GTE12]]) +// CHECK-DAG: %[[ADD15:add.*]] = f32[4] add(%get-tuple-element.{{.*}}, %[[GTE13]]) +// CHECK: ROOT %tuple.{{.*}} = (s32[], f32[4], s32[], s32[], f32[4]) tuple(%[[ADD14]], %[[ADD15]], %get-tuple-element.{{.*}}, %[[GTE12]], %[[GTE13]]) // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} // CHECK: %[[COND:region_1.*]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], s32[], s32[], f32[4])) -> pred[] { // CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0) // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} -// CHECK: %[[GTE21:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[ARG_TUPLE]]), index=2 -// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE21]]), direction=LT +// CHECK: %[[GTE21:get-tuple-element.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=2 +// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(%get-tuple-element.{{.*}}, %[[GTE21]]), direction=LT // CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4]) -> f32[4] { // CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) @@ -27,12 +27,12 @@ // CHECK-NEXT: %[[CONSTANT4:constant.*]] = s32[] constant(0) // CHECK-NEXT: %[[CONSTANT5:constant.*]] = s32[] constant(1) // CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) -// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %[[ARG0]], f32[4] %[[ARG1]], s32[] %[[CONSTANT4]], s32[] %[[CONSTANT5]], f32[4] %[[ARG2]]) +// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], s32[], s32[], f32[4]) tuple(%[[ARG0]], %[[ARG1]], %[[CONSTANT4]], %[[CONSTANT5]], %[[ARG2]]) // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[WHILE:while.25]] = (s32[], f32[4], s32[], s32[], f32[4]) while((s32[], f32[4], s32[], s32[], f32[4]) %[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] +// CHECK-NEXT: %[[WHILE:while.25]] = (s32[], f32[4], s32[], s32[], f32[4]) while(%[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] // CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %[[GTE26:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[WHILE]]), index=0, sharding={replicated} -// CHECK-NEXT: ROOT %[[GTE27:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[WHILE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE26:get-tuple-element.*]] = s32[] get-tuple-element(%[[WHILE]]), index=0 +// CHECK-NEXT: ROOT %[[GTE27:get-tuple-element.*]] = f32[4] get-tuple-element(%[[WHILE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> tensor<4xf32> { %0 = mhlo.constant dense<0> : tensor @@ -60,21 +60,21 @@ func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32> { // CHECK: %[[BODY:region_0.*]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], s32[])) -> (s32[], f32[4], s32[]) { // CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], s32[]) parameter(0) -// CHECK: %[[GTE:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], s32[]) %[[ARG_TUPLE]]), index=2 -// CHECK: %[[ADD:add.*]] = s32[] add(s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE]]) -// CHECK: ROOT %tuple.{{.*}} = (s32[], f32[4], s32[]) tuple(s32[] %[[ADD]], f32[4] %get-tuple-element.{{.*}}, s32[] %[[GTE]]) +// CHECK: %[[GTE:get-tuple-element.*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=2 +// CHECK: %[[ADD:add.*]] = s32[] add(%get-tuple-element.{{.*}}, %[[GTE]]) +// CHECK: ROOT %tuple.{{.*}} = (s32[], f32[4], s32[]) tuple(%[[ADD]], %get-tuple-element.{{.*}}, %[[GTE]]) // CHECK: %[[COND:region_1.*]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], s32[])) -> pred[] { // CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], s32[]) parameter(0) -// CHECK: %[[GTE:get-tuple-element..*]] = s32[] get-tuple-element((s32[], f32[4], s32[]) %[[ARG_TUPLE]]), index=2 -// CHECK: ROOT %compare.{{.*}} = pred[] compare(s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE]]), direction=LT +// CHECK: %[[GTE:get-tuple-element..*]] = s32[] get-tuple-element(%[[ARG_TUPLE]]), index=2 +// CHECK: ROOT %compare.{{.*}} = pred[] compare(%get-tuple-element.{{.*}}, %[[GTE]]), direction=LT // CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: s32[]) -> f32[4] { // CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) // CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1) // CHECK-NEXT: %[[ARG2]] = s32[] parameter(2) -// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], s32[]) tuple(s32[] %[[ARG0]], f32[4] %[[ARG1]], s32[] %[[ARG2]]) -// CHECK-NEXT: %while.{{.*}} = (s32[], f32[4], s32[]) while((s32[], f32[4], s32[]) %[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] +// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], s32[]) tuple(%[[ARG0]], %[[ARG1]], %[[ARG2]]) +// CHECK-NEXT: %while.{{.*}} = (s32[], f32[4], s32[]) while(%[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<4xf32> { %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor, tensor<4xf32> diff --git a/xla/hlo/translate/tests/print_layouts.mlir b/xla/hlo/translate/tests/print_layouts.mlir index 40528c15e92fc1..b7e4209ddf3572 100644 --- a/xla/hlo/translate/tests/print_layouts.mlir +++ b/xla/hlo/translate/tests/print_layouts.mlir @@ -3,11 +3,11 @@ // CHECK-LABEL: main // CHECK: ENTRY // CHECK: [[ARG:%.*]] = token[] parameter(0) -// CHECK: [[INFEED:%.*]] = ((s32[3,3]{0,1}, pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" -// CHECK: [[GTE1:%.*]] = (s32[3,3]{0,1}, pred[]) get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) [[INFEED]]), index=0 -// CHECK: [[GTE2:%.*]] = s32[3,3]{0,1} get-tuple-element((s32[3,3]{0,1}, pred[]) [[GTE1]]), index=0 -// CHECK: [[GTE3:%.*]] = pred[] get-tuple-element((s32[3,3]{0,1}, pred[]) [[GTE1]]), index=1 -// CHECK: [[GTE4:%.*]] = token[] get-tuple-element(((s32[3,3]{0,1}, pred[]), token[]) [[INFEED]]), index=1 +// CHECK: [[INFEED:%.*]] = ((s32[3,3]{0,1}, pred[]), token[]) infeed([[ARG]]), infeed_config="foobar" +// CHECK: [[GTE1:%.*]] = (s32[3,3]{0,1}, pred[]) get-tuple-element([[INFEED]]), index=0 +// CHECK: [[GTE2:%.*]] = s32[3,3]{0,1} get-tuple-element([[GTE1]]), index=0 +// CHECK: [[GTE3:%.*]] = pred[] get-tuple-element([[GTE1]]), index=1 +// CHECK: [[GTE4:%.*]] = token[] get-tuple-element([[INFEED]]), index=1 func.func @main(%arg0: !stablehlo.token) -> tuple, tensor>, !stablehlo.token> { %0:3 = "stablehlo.infeed"(%arg0) {infeed_config = "foobar", layout=[[0, 1], [0]]} : (!stablehlo.token) -> (tensor<3x3xi32>, tensor, !stablehlo.token) %1 = "stablehlo.tuple"(%0#0, %0#1) : (tensor<3x3xi32>, tensor) -> tuple, tensor> diff --git a/xla/hlo/translate/tests/simple.mlir b/xla/hlo/translate/tests/simple.mlir index 1a3e1f5246bb5a..74f3f50731615c 100644 --- a/xla/hlo/translate/tests/simple.mlir +++ b/xla/hlo/translate/tests/simple.mlir @@ -4,9 +4,9 @@ func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { // CHECK: %Arg_0.1 = f32[4] parameter(0) // CHECK: %Arg_1.2 = f32[4] parameter(1) - // CHECK: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + // CHECK: %add.3 = f32[4] add(%Arg_0.1, %Arg_1.2) %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32> - // CHECK: ROOT %dot.4 = f32[] dot(f32[4] %add.3, f32[4] %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} + // CHECK: ROOT %dot.4 = f32[] dot(%add.3, %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} %1 = stablehlo.dot %0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor func.return %1 : tensor } @@ -18,9 +18,9 @@ func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0) // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + // CHECK-NEXT: %add.3 = f32[4] add(%Arg_0.1, %Arg_1.2) %0 = "mhlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2) + // CHECK-NEXT: ROOT %add.4 = f32[4] add(%add.3, %Arg_1.2) %1 = "mhlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %1 : tensor<4xf32> } diff --git a/xla/hlo/translate/tests/vhlo_input.mlir b/xla/hlo/translate/tests/vhlo_input.mlir index f7019397c1ba4e..062fe9e0b6bd71 100644 --- a/xla/hlo/translate/tests/vhlo_input.mlir +++ b/xla/hlo/translate/tests/vhlo_input.mlir @@ -10,9 +10,9 @@ func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { // CHECK: %Arg_0.1 = f32[4] parameter(0) // CHECK: %Arg_1.2 = f32[4] parameter(1) - // CHECK: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + // CHECK: %add.3 = f32[4] add(%Arg_0.1, %Arg_1.2) %0 = stablehlo.add %arg0, %arg1 : tensor<4xf32> - // CHECK: ROOT %dot.4 = f32[] dot(f32[4] %add.3, f32[4] %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} + // CHECK: ROOT %dot.4 = f32[] dot(%add.3, %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} %1 = stablehlo.dot %0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor func.return %1 : tensor } \ No newline at end of file diff --git a/xla/hlo/utils/hlo_matchers_test.cc b/xla/hlo/utils/hlo_matchers_test.cc index 3a9261db4fe110..a4408700aaf199 100644 --- a/xla/hlo/utils/hlo_matchers_test.cc +++ b/xla/hlo/utils/hlo_matchers_test.cc @@ -65,36 +65,30 @@ TEST_F(HloMatchersTest, Test) { op::Add(op::Parameter(), op::Multiply(_, op::Parameter()))); // Negative matches: check the explanation string. - EXPECT_THAT( - Explain(add.get(), op::Parameter()), - Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))")); - EXPECT_THAT( - Explain(add.get(), op::Add(op::Parameter())), - Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply)) " - "has too many operands (got 2, want 1)")); - EXPECT_THAT( - Explain(add.get(), op::Add(op::Parameter(), op::Parameter())), - Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))" - "\noperand 1:\n\t" - "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" - "doesn't match expected:\n\t" - "parameter" - ", (%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} " - "%param))")); - EXPECT_THAT( - Explain(add.get(), - op::Add(op::Parameter(), op::Multiply(op::Add(), op::Add()))), - Eq("(%add = f32[1]{0} add(f32[1]{0} %param, f32[1]{0} %multiply))" - "\noperand 1:\n\t" - "%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} %param)\n" - "doesn't match expected:\n\t" - "multiply(add, add)" - ", (%multiply = f32[1]{0} multiply(f32[1]{0} %param, f32[1]{0} " - "%param))\n" - "operand 0:\n\t" - "%param = f32[1]{0} parameter(0)\n" - "doesn't match expected:\n\t" - "add, (%param = f32[1]{0} parameter(0))")); + EXPECT_THAT(Explain(add.get(), op::Parameter()), + Eq("(%add = f32[1]{0} add(%param, %multiply))")); + EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter())), + Eq("(%add = f32[1]{0} add(%param, %multiply)) " + "has too many operands (got 2, want 1)")); + EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter(), op::Parameter())), + Eq("(%add = f32[1]{0} add(%param, %multiply))" + "\noperand 1:\n\t" + "%multiply = f32[1]{0} multiply(%param, %param)\n" + "doesn't match expected:\n\t" + "parameter" + ", (%multiply = f32[1]{0} multiply(%param, %param))")); + EXPECT_THAT(Explain(add.get(), op::Add(op::Parameter(), + op::Multiply(op::Add(), op::Add()))), + Eq("(%add = f32[1]{0} add(%param, %multiply))" + "\noperand 1:\n\t" + "%multiply = f32[1]{0} multiply(%param, %param)\n" + "doesn't match expected:\n\t" + "multiply(add, add)" + ", (%multiply = f32[1]{0} multiply(%param, %param))\n" + "operand 0:\n\t" + "%param = f32[1]{0} parameter(0)\n" + "doesn't match expected:\n\t" + "add, (%param = f32[1]{0} parameter(0))")); } TEST_F(HloMatchersTest, CustomCallMatcher) { @@ -121,8 +115,8 @@ TEST_F(HloMatchersTest, CustomCallMatcher) { ::testing::Not(op::CustomCall(::testing::StartsWith("bar")))); EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")), - "(%custom-call = f32[1]{0} custom-call(f32[3]{0} %constant, " - "s32[3]{0} %constant), custom_call_target=\"foo_target\") " + "(%custom-call = f32[1]{0} custom-call(%constant, %constant), " + "custom_call_target=\"foo_target\") " "custom-call with call target that isn't equal to \"bar\""); EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")), R"(custom-call with call target that is equal to "foo_target")"); @@ -227,21 +221,19 @@ ENTRY DotOperationFusion_TransposeFusion { /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)); - EXPECT_THAT( - Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), - /*lhs_contracting_dim=*/0, - /*rhs_contracting_dim=*/0)), - "(%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " - "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong " - "lhs_contracting_dimensions (got {1} want {0})"); - - EXPECT_THAT( - Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), - /*lhs_contracting_dim=*/1, - /*rhs_contracting_dim=*/1)), - "(%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " - "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong " - "rhs_contracting_dimensions (got {0} want {1})"); + EXPECT_THAT(Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/0, + /*rhs_contracting_dim=*/0)), + "(%dot = f32[1,1024]{1,0} dot(%arg0, %arg1), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong " + "lhs_contracting_dimensions (got {1} want {0})"); + + EXPECT_THAT(Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1)), + "(%dot = f32[1,1024]{1,0} dot(%arg0, %arg1), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}) has wrong " + "rhs_contracting_dimensions (got {0} want {1})"); } TEST_F(HloMatchersTest, ComparisonMatcher) { @@ -268,12 +260,12 @@ TEST_F(HloMatchersTest, ComparisonMatcher) { op::Add(op::Parameter(0), op::Parameter(1)))); EXPECT_THAT(Explain(eq.get(), op::Add()), - Eq("(%compare = f32[1]{0} compare(f32[1]{0} %param.0, " - "f32[1]{0} %param.1), direction=EQ)")); - EXPECT_THAT(Explain(eq.get(), op::Ne()), - Eq("(%compare = f32[1]{0} compare(f32[1]{0} %param.0, " - "f32[1]{0} %param.1), direction=EQ) " - "has wrong comparison direction (got EQ, want NE)")); + Eq("(%compare = f32[1]{0} compare(%param.0, %param.1), " + "direction=EQ)")); + EXPECT_THAT( + Explain(eq.get(), op::Ne()), + Eq("(%compare = f32[1]{0} compare(%param.0, %param.1), " + "direction=EQ) has wrong comparison direction (got EQ, want NE)")); } TEST_F(HloMatchersTest, AsyncCopyMatcher) { @@ -300,16 +292,12 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) { EXPECT_THAT(Explain(copy_start.get(), op::AsyncCopy(2, 1, op::Parameter(0))), Eq("(%copy-start = (f32[16]{0:S(2)}, f32[16]{0:S(1)}, u32[]) " - "copy-start(f32[16]{0:S(1)} %p0))")); + "copy-start(%p0))")); EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(3, 1, op::Parameter(0))), - "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, " - "f32[16]{0:S(1)}, u32[]) " - "%copy-start)) " + "(%copy-done = f32[16]{0:S(2)} copy-done(%copy-start)) " "copies to memory space 2, expected 3"); EXPECT_THAT(Explain(copy_done.get(), op::AsyncCopy(2, 3, op::Parameter(0))), - "(%copy-done = f32[16]{0:S(2)} copy-done((f32[16]{0:S(2)}, " - "f32[16]{0:S(1)}, u32[]) " - "%copy-start)) " + "(%copy-done = f32[16]{0:S(2)} copy-done(%copy-start)) " "is in the memory space 1, expected 3"); } @@ -354,7 +342,7 @@ TEST_F(HloMatchersTest, ReplicaGroupsMatcher) { EXPECT_THAT(Explain(p0.get(), op::ReplicaGroups({})), "%param = f32[5,7]{1,0} parameter(0) not a collective op"); EXPECT_THAT(Explain(all_to_all.get(), op::ReplicaGroups({{0, 1}, {2, 3}})), - "%all-to-all = f32[5,7]{1,0} all-to-all(f32[5,7]{1,0} %param), " + "%all-to-all = f32[5,7]{1,0} all-to-all(%param), " "replica_groups={{0,2},{1,3}} has incorrect replica_groups " "(expected: {{0,1},{2,3}})"); EXPECT_THAT(all_to_all.get(), op::ReplicaGroups({{0, 2}, {1, 3}})); diff --git a/xla/service/call_inliner_test.cc b/xla/service/call_inliner_test.cc index 602a126344f28f..444acdd246da16 100644 --- a/xla/service/call_inliner_test.cc +++ b/xla/service/call_inliner_test.cc @@ -526,9 +526,9 @@ TEST_F(CallInlinerTest, DontInlineStreamAnnotationCall) { TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); absl::StatusOr filecheck_result = RunFileCheck(module->ToString({}), R"( //CHECK: %lhs.2 = f32[] constant(42) - //CHECK: %call1 = f32[] call(f32[] %lhs.2), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"} + //CHECK: %call1 = f32[] call(%lhs.2), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"} //CHECK: %rhs.2 = f32[] constant(2) - //CHECK: ROOT %add.1 = f32[] add(f32[] %call1, f32[] %rhs.2) + //CHECK: ROOT %add.1 = f32[] add(%call1, %rhs.2) )"); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(*filecheck_result); diff --git a/xla/service/collective_pipeliner_test.cc b/xla/service/collective_pipeliner_test.cc index 3d8f002be32538..ce9c87309aa4be 100644 --- a/xla/service/collective_pipeliner_test.cc +++ b/xla/service/collective_pipeliner_test.cc @@ -367,18 +367,18 @@ ENTRY entry { // CHECK: HloModule // CHECK: %while_body // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{0,5},{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}{{[}]}} - // CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) - // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]]) - // CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}}) + // CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.*}}%[[cp]], {{.*}}) + // CHECK: %[[mul:.+]] = {{.+}} multiply({{.*}}%[[dus]], {{.*}}%[[dus]]) + // CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.*}}%[[mul]], {{.*}}) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[dus2]], {{.*}}) // CHECK: } // CHECK: ENTRY %entry // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}} - // CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) - // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]]) - // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) - // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) - // CHECK: {{.+}} = {{.+}} while({{.+}} %[[tuple]]) + // CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.*}}%[[cp]], {{.*}}) + // CHECK: %[[mul:.+]] = {{.+}} multiply({{.*}}%[[ds]], {{.*}}%[[ds]]) + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.*}}%[[mul]], {{.*}}) + // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.*}}%[[dus]], {{.*}}) + // CHECK: {{.+}} = {{.+}} while({{.*}}%[[tuple]]) // CHECK: } )")); } @@ -443,18 +443,18 @@ ENTRY entry { // CHECK: HloModule // CHECK: %while_body // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6},{0,5}{{[}]}} - // CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) - // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]]) - // CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}}) + // CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.*}}%[[cp]], {{.*}}) + // CHECK: %[[mul:.+]] = {{.+}} multiply({{.*}}%[[dus]], {{.*}}%[[dus]]) + // CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.*}}%[[mul]], {{.*}}) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[dus2]], {{.*}}) // CHECK: } // CHECK: ENTRY %entry // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}} - // CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) - // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]]) - // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) - // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) - // CHECK: {{.+}} = {{.+}} while({{.+}} %[[tuple]]) + // CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.*}}%[[cp]], {{.*}}) + // CHECK: %[[mul:.+]] = {{.+}} multiply({{.*}}%[[ds]], {{.*}}%[[ds]]) + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.*}}%[[mul]], {{.*}}) + // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.*}}%[[dus]], {{.*}}) + // CHECK: {{.+}} = {{.+}} while({{.*}}%[[tuple]]) // CHECK: } )")); } @@ -1847,15 +1847,15 @@ ENTRY entry { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %while_body // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,12}{{[}]}}} - // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}}) - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.*}}%[[cp]], {{.*}}) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[dus]], {{.*}}) // CHECK: ENTRY %entry - // CHECK: %[[while:.+]] = {{.+}} while({{.+}}) - // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}} - // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}}) - // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) - // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1 + // CHECK: %[[while:.+]] = {{.+}} while({{.*}}) + // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.*}}%[[while]]), index=1 + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.*}}%[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}} + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.*}}%[[cp2]], {{.*}}) + // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.*}}%[[dus]], {{.*}}) + // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.*}}%[[tuple]]), index=1 )")); } @@ -1926,15 +1926,15 @@ ENTRY entry { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %while_body // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{7,12},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}{{[}]}}} - // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}}) - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.*}}%[[cp]], {{.*}}) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[dus]], {{.*}}) // CHECK: ENTRY %entry // CHECK: %[[while:.+]] = {{.+}} while({{.+}}) - // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}} - // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}}) - // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) - // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1 + // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.*}}%[[while]]), index=1 + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.*}}%[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}} + // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.*}}%[[cp2]], {{.*}}) + // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.*}}%[[dus]], {{.*}}) + // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.*}}%[[tuple]]), index=1 )")); } diff --git a/xla/service/dynamic_dimension_inference_test.cc b/xla/service/dynamic_dimension_inference_test.cc index caec7dc72cdc8f..944a773e09a270 100644 --- a/xla/service/dynamic_dimension_inference_test.cc +++ b/xla/service/dynamic_dimension_inference_test.cc @@ -1349,8 +1349,9 @@ ENTRY computation { /*opaque=*/std::string{}, API_VERSION_STATUS_RETURNING)); })); - absl::StatusOr filecheck_result = RunFileCheck(module_->ToString({}), - R"( + absl::StatusOr filecheck_result = RunFileCheck( + module_->ToString(HloPrintOptions().set_print_operand_shape(true)), + R"( // CHECK: compare = pred[] compare(s32[] %a_size_1, s32[] %b_size_1), direction=EQ // CHECK: compare.5 = pred[] compare(s32[] %a_size_2, s32[] %b_size_2), direction=EQ // CHECK: and.2 = pred[] and(pred[] %compare, pred[] %compare.5) diff --git a/xla/service/gather_expander_test.cc b/xla/service/gather_expander_test.cc index a7f39c326336c1..e6ea76ea4ff2cd 100644 --- a/xla/service/gather_expander_test.cc +++ b/xla/service/gather_expander_test.cc @@ -297,44 +297,44 @@ ENTRY main { const std::string expected = R"( //CHECK: (s32[], s32[5,2], s32[5,1], s32[5,1])) -> (s32[], s32[5,2], s32[5,1], s32[5,1]) { //CHECK: %[[PARAM:.*]] = (s32[], s32[5,2], s32[5,1], s32[5,1]) parameter(0) - //CHECK: %[[I:.*]] = s32[] get-tuple-element((s32[], s32[5,2], s32[5,1], s32[5,1]) %[[PARAM]]), index= + //CHECK: %[[I:.*]] = s32[] get-tuple-element(%[[PARAM]]), index= //CHECK: %[[CONSTANT1:.*]] = s32[] constant(1) - //CHECK: %[[I_PLUS_1:.*]] = s32[] add(s32[] %[[I]], s32[] %[[CONSTANT1]]) - //CHECK: %[[OPERAND:.*]] = s32[5,2] get-tuple-element((s32[], s32[5,2], s32[5,1], s32[5,1]) %[[PARAM]]), index=1 - //CHECK: %[[START_INDICES:.*]] = s32[5,1] get-tuple-element((s32[], s32[5,2], s32[5,1], s32[5,1]) %[[PARAM]]), index=2 - //CHECK: %[[RESULT:.*]] = s32[5,1] get-tuple-element((s32[], s32[5,2], s32[5,1], s32[5,1]) %[[PARAM]]), index=3 + //CHECK: %[[I_PLUS_1:.*]] = s32[] add(%[[I]], %[[CONSTANT1]]) + //CHECK: %[[OPERAND:.*]] = s32[5,2] get-tuple-element(%[[PARAM]]), index=1 + //CHECK: %[[START_INDICES:.*]] = s32[5,1] get-tuple-element(%[[PARAM]]), index=2 + //CHECK: %[[RESULT:.*]] = s32[5,1] get-tuple-element(%[[PARAM]]), index=3 - //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(s32[] %[[I]]) - //CHECK: %[[I_1D_2:.*]] = s32[1] broadcast(s32[] %[[I]]) + //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(%[[I]]) + //CHECK: %[[I_1D_2:.*]] = s32[1] broadcast(%[[I]]) //CHECK: %[[START_INDICES_INDEX_D1_PAD:.*]] = s32[] constant(0) - //CHECK: %[[START_INDICES_INDEX_VECTOR:.*]] = s32[2] pad(s32[1] %[[I_1D_2]], s32[] %[[START_INDICES_INDEX_D1_PAD]]), padding=0_1 - //CHECK: %[[START_INDICES_INDEX_D0_SLICE:.*]] = s32[1] slice(s32[2] %[[START_INDICES_INDEX_VECTOR]]), slice={[0:1]} - //CHECK: %[[START_INDICES_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_D0_SLICE]]) - //CHECK: %[[START_INDICES_INDEX_D1_SLICE:.*]] = s32[1] slice(s32[2] %[[START_INDICES_INDEX_VECTOR]]), slice={[1:2]} - //CHECK: %[[START_INDICES_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_D1_SLICE]]) - //CHECK: %[[INDEX_VECTOR:.*]] = s32[1,1] dynamic-slice(s32[5,1] %[[START_INDICES]], s32[] %[[START_INDICES_INDEX_D0]], s32[] %[[START_INDICES_INDEX_D1]]) - - //CHECK: %[[OFFSET_RAW:.*]] = s32[1] reshape(s32[1,1] %[[INDEX_VECTOR]]) - //CHECK: %[[OFFSET:.*]] = s32[1] slice(s32[1] %[[OFFSET_RAW]]) - //CHECK: %[[OPERAND_INDEX:.*]] = s32[2] concatenate(s32[1] %[[I_1D_1]], s32[1] %[[OFFSET]]) - //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(s32[2] %[[OPERAND_INDEX]]), slice={[0:1]} - //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D0_RAW]]) - //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(s32[2] %[[OPERAND_INDEX]]), slice={[1:2]} - //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D1_RAW]]) - //CHECK: %[[RESULT_SLICE_RAW0:.*]] = s32[1,1] dynamic-slice(s32[5,2] %[[OPERAND]], s32[] %[[OPERAND_INDEX_D0]], s32[] %[[OPERAND_INDEX_D1]]) - - //CHECK: %[[RESULT_SLICE_RAW1:.*]] = s32[1] reshape(s32[1,1] %[[RESULT_SLICE_RAW0]]) - //CHECK: %[[RESULT_SLICE:.*]] = s32[1,1] reshape(s32[1] %[[RESULT_SLICE_RAW1]]) + //CHECK: %[[START_INDICES_INDEX_VECTOR:.*]] = s32[2] pad(%[[I_1D_2]], %[[START_INDICES_INDEX_D1_PAD]]), padding=0_1 + //CHECK: %[[START_INDICES_INDEX_D0_SLICE:.*]] = s32[1] slice(%[[START_INDICES_INDEX_VECTOR]]), slice={[0:1]} + //CHECK: %[[START_INDICES_INDEX_D0:.*]] = s32[] reshape(%[[START_INDICES_INDEX_D0_SLICE]]) + //CHECK: %[[START_INDICES_INDEX_D1_SLICE:.*]] = s32[1] slice(%[[START_INDICES_INDEX_VECTOR]]), slice={[1:2]} + //CHECK: %[[START_INDICES_INDEX_D1:.*]] = s32[] reshape(%[[START_INDICES_INDEX_D1_SLICE]]) + //CHECK: %[[INDEX_VECTOR:.*]] = s32[1,1] dynamic-slice(%[[START_INDICES]], %[[START_INDICES_INDEX_D0]], %[[START_INDICES_INDEX_D1]]) + + //CHECK: %[[OFFSET_RAW:.*]] = s32[1] reshape(%[[INDEX_VECTOR]]) + //CHECK: %[[OFFSET:.*]] = s32[1] slice(%[[OFFSET_RAW]]) + //CHECK: %[[OPERAND_INDEX:.*]] = s32[2] concatenate(%[[I_1D_1]], %[[OFFSET]]) + //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(%[[OPERAND_INDEX]]), slice={[0:1]} + //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(%[[OPERAND_INDEX_D0_RAW]]) + //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(%[[OPERAND_INDEX]]), slice={[1:2]} + //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(%[[OPERAND_INDEX_D1_RAW]]) + //CHECK: %[[RESULT_SLICE_RAW0:.*]] = s32[1,1] dynamic-slice(%[[OPERAND]], %[[OPERAND_INDEX_D0]], %[[OPERAND_INDEX_D1]]) + + //CHECK: %[[RESULT_SLICE_RAW1:.*]] = s32[1] reshape(%[[RESULT_SLICE_RAW0]]) + //CHECK: %[[RESULT_SLICE:.*]] = s32[1,1] reshape(%[[RESULT_SLICE_RAW1]]) //CHECK: %[[RESULT_INDEX_D1_PAD:.*]] = s32[] constant(0) - //CHECK: %[[RESULT_INDEX_VECTOR:.*]] = s32[2] pad(s32[1] %[[I_1D_2]], s32[] %[[RESULT_INDEX_D1_PAD]]), padding=0_1 - //CHECK: %[[RESULT_INDEX_D0_SLICE:.*]] = s32[1] slice(s32[2] %[[RESULT_INDEX_VECTOR]]), slice={[0:1]} - //CHECK: %[[RESULT_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[RESULT_INDEX_D0_SLICE]]) - //CHECK: %[[RESULT_INDEX_D1_SLICE:.*]] = s32[1] slice(s32[2] %[[RESULT_INDEX_VECTOR]]), slice={[1:2]} - //CHECK: %[[RESULT_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[RESULT_INDEX_D1_SLICE]]) - //CHECK: %[[UPDATED_RESULT:.*]] = s32[5,1] dynamic-update-slice(s32[5,1] %[[RESULT]], s32[1,1] %[[RESULT_SLICE]], s32[] %[[RESULT_INDEX_D0]], s32[] %[[RESULT_INDEX_D1]]) - - //CHECK: ROOT %{{.*}} = (s32[], s32[5,2], s32[5,1], s32[5,1]) tuple(s32[] %[[I_PLUS_1]], s32[5,2] %[[OPERAND]], s32[5,1] %[[START_INDICES]], s32[5,1] %[[UPDATED_RESULT]]) + //CHECK: %[[RESULT_INDEX_VECTOR:.*]] = s32[2] pad(%[[I_1D_2]], %[[RESULT_INDEX_D1_PAD]]), padding=0_1 + //CHECK: %[[RESULT_INDEX_D0_SLICE:.*]] = s32[1] slice(%[[RESULT_INDEX_VECTOR]]), slice={[0:1]} + //CHECK: %[[RESULT_INDEX_D0:.*]] = s32[] reshape(%[[RESULT_INDEX_D0_SLICE]]) + //CHECK: %[[RESULT_INDEX_D1_SLICE:.*]] = s32[1] slice(%[[RESULT_INDEX_VECTOR]]), slice={[1:2]} + //CHECK: %[[RESULT_INDEX_D1:.*]] = s32[] reshape(%[[RESULT_INDEX_D1_SLICE]]) + //CHECK: %[[UPDATED_RESULT:.*]] = s32[5,1] dynamic-update-slice(%[[RESULT]], %[[RESULT_SLICE]], %[[RESULT_INDEX_D0]], %[[RESULT_INDEX_D1]]) + + //CHECK: ROOT %{{.*}} = (s32[], s32[5,2], s32[5,1], s32[5,1]) tuple(%[[I_PLUS_1]], %[[OPERAND]], %[[START_INDICES]], %[[UPDATED_RESULT]]) )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -369,38 +369,38 @@ ENTRY main { const std::string expected = R"( //CHECK: (s32[], s32[7,3,4,5], s32[70], s32[70,3])) -> (s32[], s32[7,3,4,5], s32[70], s32[70,3]) { //CHECK: %[[PARAM:.*]] = (s32[], s32[7,3,4,5], s32[70], s32[70,3]) parameter(0) - //CHECK: %[[I:.*]] = s32[] get-tuple-element((s32[], s32[7,3,4,5], s32[70], s32[70,3]) %[[PARAM]]), index=0 + //CHECK: %[[I:.*]] = s32[] get-tuple-element(%[[PARAM]]), index=0 //CHECK: %[[CONSTANT1:.*]] = s32[] constant(1) - //CHECK: %[[I_PLUS_1:.*]] = s32[] add(s32[] %[[I]], s32[] %[[CONSTANT1]]) - //CHECK: %[[OPERAND:.*]] = s32[7,3,4,5] get-tuple-element((s32[], s32[7,3,4,5], s32[70], s32[70,3]) %[[PARAM]]), index=1 - //CHECK: %[[START_INDICES:.*]] = s32[70] get-tuple-element((s32[], s32[7,3,4,5], s32[70], s32[70,3]) %[[PARAM]]), index=2 + //CHECK: %[[I_PLUS_1:.*]] = s32[] add(%[[I]], %[[CONSTANT1]]) + //CHECK: %[[OPERAND:.*]] = s32[7,3,4,5] get-tuple-element(%[[PARAM]]), index=1 + //CHECK: %[[START_INDICES:.*]] = s32[70] get-tuple-element(%[[PARAM]]), index=2 //CHECK: %[[CONSTANT7:.*]] = s32[] constant(7) - //CHECK: %[[BD0_RAW:.*]] = s32[] remainder(s32[] %[[I]], s32[] %[[CONSTANT7]]) - //CHECK: %[[BD0:.*]] = s32[1] broadcast(s32[] %[[BD0_RAW]]) + //CHECK: %[[BD0_RAW:.*]] = s32[] remainder(%[[I]], %[[CONSTANT7]]) + //CHECK: %[[BD0:.*]] = s32[1] broadcast(%[[BD0_RAW]]) //CHECK: %[[CONSTANT0:.*]] = s32[1] constant({0}) - //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(s32[] %[[I]]) - //CHECK: %[[START_INDICES_INDEX_RAW:.*]] = s32[1] slice(s32[1] %[[I_1D_1]]) - //CHECK: %[[START_INDICES_INDEX:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_RAW]]) - //CHECK: %[[INDEX_VECTOR:.*]] = s32[1] dynamic-slice(s32[70] %[[START_INDICES]], s32[] %[[START_INDICES_INDEX]]) + //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(%[[I]]) + //CHECK: %[[START_INDICES_INDEX_RAW:.*]] = s32[1] slice(%[[I_1D_1]]) + //CHECK: %[[START_INDICES_INDEX:.*]] = s32[] reshape(%[[START_INDICES_INDEX_RAW]]) + //CHECK: %[[INDEX_VECTOR:.*]] = s32[1] dynamic-slice(%[[START_INDICES]], %[[START_INDICES_INDEX]]) - //CHECK: %[[OFFSET:.*]] = s32[1] slice(s32[1] %[[INDEX_VECTOR]]) - //CHECK: %[[BD1:.*]] = s32[] divide(s32[] %[[I]], s32[] %[[CONSTANT7]]) + //CHECK: %[[OFFSET:.*]] = s32[1] slice(%[[INDEX_VECTOR]]) + //CHECK: %[[BD1:.*]] = s32[] divide(%[[I]], %[[CONSTANT7]]) //CHECK: %[[CONSTANT2:.*]] = s32[] constant(2) - //CHECK: %[[BD2_RAW:.*]] = s32[] divide(s32[] %[[BD1]], s32[] %[[CONSTANT2]]) - //CHECK: %[[BD2:.*]] = s32[1] broadcast(s32[] %[[BD2_RAW]]) - //CHECK: %[[OPERAND_INDEX:.*]] = s32[4] concatenate(s32[1] %[[BD0]], s32[1] %[[CONSTANT0]], s32[1] %[[OFFSET]], s32[1] %[[BD2]]) - - //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDEX]]), slice={[0:1]} - //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D0_RAW]]) - //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDEX]]), slice={[1:2]} - //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D1_RAW]]) - //CHECK: %[[OPERAND_INDEX_D2_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDEX]]), slice={[2:3]} - //CHECK: %[[OPERAND_INDEX_D2:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D2_RAW]]) - //CHECK: %[[OPERAND_INDEX_D3_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDEX]]), slice={[3:4]} - //CHECK: %[[OPERAND_INDEX_D3:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D3_RAW]]) - //CHECK: %{{.*}} = s32[1,3,1,1] dynamic-slice(s32[7,3,4,5] %[[OPERAND]], s32[] %[[OPERAND_INDEX_D0]], s32[] %[[OPERAND_INDEX_D1]], s32[] %[[OPERAND_INDEX_D2]], s32[] %[[OPERAND_INDEX_D3]]) + //CHECK: %[[BD2_RAW:.*]] = s32[] divide(%[[BD1]], %[[CONSTANT2]]) + //CHECK: %[[BD2:.*]] = s32[1] broadcast(%[[BD2_RAW]]) + //CHECK: %[[OPERAND_INDEX:.*]] = s32[4] concatenate(%[[BD0]], %[[CONSTANT0]], %[[OFFSET]], %[[BD2]]) + + //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(%[[OPERAND_INDEX]]), slice={[0:1]} + //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(%[[OPERAND_INDEX_D0_RAW]]) + //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(%[[OPERAND_INDEX]]), slice={[1:2]} + //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(%[[OPERAND_INDEX_D1_RAW]]) + //CHECK: %[[OPERAND_INDEX_D2_RAW:.*]] = s32[1] slice(%[[OPERAND_INDEX]]), slice={[2:3]} + //CHECK: %[[OPERAND_INDEX_D2:.*]] = s32[] reshape(%[[OPERAND_INDEX_D2_RAW]]) + //CHECK: %[[OPERAND_INDEX_D3_RAW:.*]] = s32[1] slice(%[[OPERAND_INDEX]]), slice={[3:4]} + //CHECK: %[[OPERAND_INDEX_D3:.*]] = s32[] reshape(%[[OPERAND_INDEX_D3_RAW]]) + //CHECK: %{{.*}} = s32[1,3,1,1] dynamic-slice(%[[OPERAND]], %[[OPERAND_INDEX_D0]], %[[OPERAND_INDEX_D1]], %[[OPERAND_INDEX_D2]], %[[OPERAND_INDEX_D3]]) )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index 80aff6869c3edd..2f70103178ad50 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -711,6 +711,9 @@ ENTRY main { const std::string fallback_convert_to_f16 = R"(CHECK: dot(f16{{[^)]*}}, f16{{[^)]*}}))"; + HloPrintOptions print_options = + HloPrintOptions().set_print_operand_shape(true); + { // Triton enabled, no fallback. TF_ASSERT_OK_AND_ASSIGN(auto optimized_module_no_fallback, @@ -725,7 +728,7 @@ ENTRY main { : cublas_convert_to_f16; TF_ASSERT_OK_AND_ASSIGN( bool filecheck_matched, - RunFileCheck(optimized_module_no_fallback->ToString(), + RunFileCheck(optimized_module_no_fallback->ToString(print_options), triton_expected_check)); EXPECT_TRUE(filecheck_matched); } @@ -743,9 +746,10 @@ ENTRY main { ? cublaslt_keep_types : cublas_convert_to_f16; - TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, - RunFileCheck(optimized_module_no_triton->ToString(), - blas_expected_check)); + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matched, + RunFileCheck(optimized_module_no_triton->ToString(print_options), + blas_expected_check)); EXPECT_TRUE(filecheck_matched); } @@ -755,9 +759,10 @@ ENTRY main { optimize_module(/*enable_triton=*/false, /*enable_blas=*/false, /*enable_blas_fallback=*/false)); - TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, - RunFileCheck(optimized_module_nothing->ToString(), - fallback_convert_to_f16)); + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matched, + RunFileCheck(optimized_module_nothing->ToString(print_options), + fallback_convert_to_f16)); EXPECT_TRUE(filecheck_matched); } } diff --git a/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/xla/service/gpu/gpu_p2p_pipeliner_test.cc index ad1679386acaa8..386817e39fb7be 100644 --- a/xla/service/gpu/gpu_p2p_pipeliner_test.cc +++ b/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -245,22 +245,22 @@ TEST_F(GpuP2PPipelinerTest, SendRecvForwardCycle) { // back edge and one set for the forward edge. Also check that the send/recv // target pairs and validation attributes are correct. CHECK: %[[RECV_BWD_START:.*]] = {{.*}} after-all() - CHECK: %[[RECV_BWD:.*]] = {{.*}} recv(token[] %[[RECV_BWD_START:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{[{][{]}}3,0{{[}][}]}},_xla_send_recv_validation={{[{][{]}}2,9{{[}][}]}}} - CHECK: %[[RECV_DONE_BWD:.*]] = {{.*}} recv-done({{.*}} %[[RECV_BWD:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %[[RECV_BWD:.*]] = {{.*}} recv(%[[RECV_BWD_START:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{[{][{]}}3,0{{[}][}]}},_xla_send_recv_validation={{[{][{]}}2,9{{[}][}]}}} + CHECK: %[[RECV_DONE_BWD:.*]] = {{.*}} recv-done(%[[RECV_BWD:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} CHECK: %[[RECV_FWD_START:.*]] = {{.*}} after-all() - CHECK: %[[RECV_FWD:.*]] = {{.*}} recv(token[] %[[RECV_FWD_START:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{[{][{]}}0,1},{1,2},{2,3{{[}][}]}},_xla_send_recv_validation={{[{][{]}}0,6},{0,7},{1,8{{[}][}]}}} - CHECK: %[[RECV_DONE_FWD:.*]] = {{.*}} recv-done((f32[2,2]{1,0}, u32[], token[]) %[[RECV_FWD:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} - CHECK: %[[SEND_BWD:.*]] = {{.*}} send({{.*}} %[[RECV_BWD_START:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{[{][{]}}3,0{{[}][}]}},_xla_send_recv_validation={{[{][{]}}2,9{{[}][}]}}} - CHECK: %[[SEND_DONE_BWD:.*]] = {{.*}} send-done({{.*}} %[[SEND_BWD:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: %[[SEND_FWD:.*]] = {{.*}} send({{.*}} %[[RECV_FWD_START:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{[{][{]}}0,1},{1,2},{2,3{{[}][}]}},_xla_send_recv_validation={{[{][{]}}0,6},{0,7},{1,8{{[}][}]}}} - CHECK: %[[SEND_DONE_FWD:.*]] = {{.*}} send-done({{.*}} %[[SEND_FWD:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %[[RECV_FWD:.*]] = {{.*}} recv(%[[RECV_FWD_START:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{[{][{]}}0,1},{1,2},{2,3{{[}][}]}},_xla_send_recv_validation={{[{][{]}}0,6},{0,7},{1,8{{[}][}]}}} + CHECK: %[[RECV_DONE_FWD:.*]] = {{.*}} recv-done(%[[RECV_FWD:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %[[SEND_BWD:.*]] = {{.*}} send(%[[RECV_BWD_START:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{[{][{]}}3,0{{[}][}]}},_xla_send_recv_validation={{[{][{]}}2,9{{[}][}]}}} + CHECK: %[[SEND_DONE_BWD:.*]] = {{.*}} send-done(%[[SEND_BWD:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %[[SEND_FWD:.*]] = {{.*}} send(%[[RECV_FWD_START:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{[{][{]}}0,1},{1,2},{2,3{{[}][}]}},_xla_send_recv_validation={{[{][{]}}0,6},{0,7},{1,8{{[}][}]}}} + CHECK: %[[SEND_DONE_FWD:.*]] = {{.*}} send-done(%[[SEND_FWD:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} // Check that the total iterations of the while loop in the output is 1 // fewer than the max iteration of the input HLO. CHECK: %[[WHILE_COND:.*]] (cond_param: {{.*}} CHECK-NEXT: %[[COND_PARAM:.*]] = {{.*}} parameter(0) - CHECK: %[[CURRENT_ITER:.*]] = {{.*}} get-tuple-element({{.*}} %[[COND_PARAM:.*]]), index=0 + CHECK: %[[CURRENT_ITER:.*]] = {{.*}} get-tuple-element(%[[COND_PARAM:.*]]), index=0 CHECK: %[[TWO:.*]] = {{.*}} constant(2) - CHECK: ROOT %[[COMPARE:.*]] = pred[] compare({{.*}} %[[CURRENT_ITER:.*]], {{.*}} %[[TWO:.*]]), direction=LT + CHECK: ROOT %[[COMPARE:.*]] = pred[] compare(%[[CURRENT_ITER:.*]], %[[TWO:.*]]), direction=LT // Check that after transformation, main function in ENTRY contains the // first iteration of the while loop. @@ -268,22 +268,22 @@ TEST_F(GpuP2PPipelinerTest, SendRecvForwardCycle) { // Set up dummy send and recv. CHECK: %[[RECV_BWD_DUMMY_START:.*]] = {{.*}} after-all() - CHECK: %[[RECV_BWD_DUMMY:.*]] = {{.*}} recv(token[] %[[RECV_BWD_DUMMY_START:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{[{][{]}}3,0{{[}][}]}},_xla_send_recv_validation="invalid"} - CHECK: %[[RECV_DONE_BWD_DUMMY:.*]] = {{.*}} recv-done({{.*}} %[[RECV_BWD_DUMMY:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %[[RECV_BWD_DUMMY:.*]] = {{.*}} recv(%[[RECV_BWD_DUMMY_START:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{[{][{]}}3,0{{[}][}]}},_xla_send_recv_validation="invalid"} + CHECK: %[[RECV_DONE_BWD_DUMMY:.*]] = {{.*}} recv-done(%[[RECV_BWD_DUMMY:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} // Execute what was previously iter 0 of the while loop. CHECK: %[[RECV_FWD_FIRST_ITER_START:.*]] = {{.*}} after-all() - CHECK: %[[RECV_FWD_FIRST_ITER:.*]] = {{.*}} recv(token[] %[[RECV_FWD_FIRST_ITER_START:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{[{][{]}}0,1},{1,2},{2,3{{[}][}]}},_xla_send_recv_validation={{[{][{]}}0,0},{1,0},{1,0{{[}][}]}}} - CHECK: %[[RECV_DONE_FWD_FIRST_ITER:.*]] = {{.*}} recv-done({{.*}} %[[RECV_FWD_FIRST_ITER:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} - CHECK: %[[SEND_BWD_DUMMY:.*]] = {{.*}} send({{.*}} %[[RECV_DUMMY_START:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{[{][{]}}3,0{{[}][}]}},_xla_send_recv_validation="invalid"} - CHECK: %[[SEND_DONE_BWD_DUMMY:.*]] = {{.*}} send-done({{.*}} %[[SEND_BWD_DUMMY:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK: %[[SEND_FWD_FIRST_ITER:.*]] = {{.*}} send({{.*}} %[[RECV_FWD_FIRST_ITER_START:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{[{][{]}}0,1},{1,2},{2,3{{[}][}]}},_xla_send_recv_validation={{[{][{]}}0,0},{1,0},{1,0{{[}][}]}}} - CHECK: %[[SEND_DONE_FWD_FIRST_ITER:.*]] = {{.*}} send-done({{.*}} %[[SEND_FWD_FIRST_ITER:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %[[RECV_FWD_FIRST_ITER:.*]] = {{.*}} recv(%[[RECV_FWD_FIRST_ITER_START:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{[{][{]}}0,1},{1,2},{2,3{{[}][}]}},_xla_send_recv_validation={{[{][{]}}0,0},{1,0},{1,0{{[}][}]}}} + CHECK: %[[RECV_DONE_FWD_FIRST_ITER:.*]] = {{.*}} recv-done(%[[RECV_FWD_FIRST_ITER:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} + CHECK: %[[SEND_BWD_DUMMY:.*]] = {{.*}} send(%[[RECV_DUMMY_START:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{[{][{]}}3,0{{[}][}]}},_xla_send_recv_validation="invalid"} + CHECK: %[[SEND_DONE_BWD_DUMMY:.*]] = {{.*}} send-done(%[[SEND_BWD_DUMMY:.*]]), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} + CHECK: %[[SEND_FWD_FIRST_ITER:.*]] = {{.*}} send(%[[RECV_FWD_FIRST_ITER_START:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{[{][{]}}0,1},{1,2},{2,3{{[}][}]}},_xla_send_recv_validation={{[{][{]}}0,0},{1,0},{1,0{{[}][}]}}} + CHECK: %[[SEND_DONE_FWD_FIRST_ITER:.*]] = {{.*}} send-done(%[[SEND_FWD_FIRST_ITER:.*]]), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} // Set up main loop, starting from iter 1. CHECK: %[[START_LOOP_FROM_ITER_ONE:.*]] = u32[] constant(1) - CHECK: %[[LOOP_INPUT:.*]] = {{.*}} tuple({{.*}} %[[START_LOOP_FROM_ITER_ONE:.*]]) - CHECK: %[[WHILE:.*]] = {{.*}} while({{.*}} %[[LOOP_INPUT:.*]]), {{.*}} + CHECK: %[[LOOP_INPUT:.*]] = {{.*}} tuple(%[[START_LOOP_FROM_ITER_ONE:.*]]) + CHECK: %[[WHILE:.*]] = {{.*}} while(%[[LOOP_INPUT:.*]]), {{.*}} )") .value()); } diff --git a/xla/service/gpu/tests/dot_bf16.hlo b/xla/service/gpu/tests/dot_bf16.hlo index a88d1b17befc91..dd2a75881159ba 100644 --- a/xla/service/gpu/tests/dot_bf16.hlo +++ b/xla/service/gpu/tests/dot_bf16.hlo @@ -3,8 +3,13 @@ // RUN: %if IS_ROCM %{ hlo-opt %s --platform=gpu --stage=hlo --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/mi200.txtpb --split-input-file --xla_gpu_autotune_level=0 --xla_gpu_enable_triton_gemm=false | FileCheck %s --check-prefixes=CHECK-SM80 %} -// CHECK-SM70: custom-call(f32 -// CHECK-SM80: custom-call(bf16 +// CHECK-SM70: %[[convert1:.+]] = f32[1536,6144]{1,0} convert(%{{.+}}) +// CHECK-SM70: %[[convert2:.+]] = f32[32,1536]{1,0} convert(%{{.+}}) +// CHECK-SM70: custom-call(%[[convert1]], %[[convert2]]), custom_call_target="__cublas$gemm" + +// CHECK-SM80: %[[convert:.+]] = bf16[1536,6144]{1,0} convert(%{{.+}}) +// CHECK-SM80: %[[b:.+]] = bf16[32,1536]{1,0} parameter(1) +// CHECK-SM80: custom-call(%[[convert]], %[[b]]), custom_call_target="__cublas$gemm" HloModule module @@ -17,8 +22,13 @@ ENTRY %computation1 { // ----- -// CHECK-SM70: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(f32[1536,6144]{1,0} {{.*}}, f32[32,1536]{1,0} {{.*}}), custom_call_target="__cublas$gemm" -// CHECK-SM80: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(bf16[1536,6144]{1,0} %convert.2.0, bf16[32,1536]{1,0} %b.1), custom_call_target="__cublas$gemm" +// CHECK-SM70: %[[convert1:.+]] = f32[1536,6144]{1,0} convert(%{{.+}}) +// CHECK-SM70: %[[convert2:.+]] = f32[32,1536]{1,0} convert(%{{.+}}) +// CHECK-SM70: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(%[[convert1]], %[[convert2]]), custom_call_target="__cublas$gemm" + +// CHECK-SM80: %[[convert:.+]] = bf16[1536,6144]{1,0} convert(%{{.+}}) +// CHECK-SM80: %[[b:.+]] = bf16[32,1536]{1,0} parameter(1) +// CHECK-SM80: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(%[[convert]], %[[b]]), custom_call_target="__cublas$gemm" HloModule module2 diff --git a/xla/service/gpu/transforms/all_reduce_splitter_test.cc b/xla/service/gpu/transforms/all_reduce_splitter_test.cc index 581237dd5f0479..cae04a15717ab7 100644 --- a/xla/service/gpu/transforms/all_reduce_splitter_test.cc +++ b/xla/service/gpu/transforms/all_reduce_splitter_test.cc @@ -115,15 +115,15 @@ ENTRY main { EXPECT_THAT(AllReduceSplitter().Run(module.get()), IsOkAndHolds(true)); TF_EXPECT_OK(FileCheck(module->ToString(), R"( CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0) - CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]]) + CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(%[[P0]]) CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]} CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0) - CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]]) - CHECK: %[[AR1:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]]) + CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(%[[AR0]], %[[ZERO]]) + CHECK: %[[AR1:.*]] = bf16[4096]{0} all-reduce(%[[LOCAL_REDUCE]]) CHECK-SAME: replica_groups={[[DESIRED_RGS]]} - CHECK: %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR1]], s32[] %[[_:.*]]) + CHECK: %[[DS:.*]] = bf16[1024]{0} dynamic-slice(%[[AR1]], %[[_:.*]]) CHECK-SAME: dynamic_slice_sizes={1024} - CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]]) + CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(%[[DS]]) CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}} )")); } @@ -202,14 +202,14 @@ ENTRY main { TF_EXPECT_OK(FileCheck(module->ToString(), R"( CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0) CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0) - CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]]) - CHECK: %[[AR0:.*]] = bf16[4096]{0} all-reduce(bf16[4096]{0} %[[LOCAL_REDUCE]]) + CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(%[[P0]], %[[ZERO]]) + CHECK: %[[AR0:.*]] = bf16[4096]{0} all-reduce(%[[LOCAL_REDUCE]]) CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]} - CHECK: %[[DS:.*]] = bf16[1024]{0} dynamic-slice(bf16[4096]{0} %[[AR0]], s32[] %[[_:.*]]) + CHECK: %[[DS:.*]] = bf16[1024]{0} dynamic-slice(%[[AR0]], %[[_:.*]]) CHECK-SAME: dynamic_slice_sizes={1024} - CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[DS]]) + CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(%[[DS]]) CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}} - CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]]) + CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(%[[P0]]) CHECK-SAME: replica_groups={[[DESIRED_RGS]]} CHECK: ROOT CHECK-NOT: %[[AR1]] @@ -438,13 +438,13 @@ ENTRY main { EXPECT_THAT(pipeline.Run(module.get()), IsOkAndHolds(true)); TF_EXPECT_OK(FileCheck(module->ToString(), R"( CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0) - CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]]) + CHECK: %[[AR0:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(%[[P0]]) CHECK-SAME: replica_groups={[[DESIRED_RGS:.*]]} CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0) - CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[AR0]], bf16[] %[[ZERO]]) - CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]]) + CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(%[[AR0]], %[[ZERO]]) + CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(%[[LOCAL_REDUCE]]) CHECK-SAME: replica_groups={[[DESIRED_RGS]]} - CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]]) + CHECK-NEXT: ROOT %[[AR2:.*]] = bf16[1024]{0} all-reduce(%[[REDUCE_SCATTER]]) CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}} )")); } @@ -490,11 +490,11 @@ ENTRY main { TF_EXPECT_OK(FileCheck(module->ToString(), R"( CHECK-DAG: %[[P0:.*]] = bf16[2,4096,4096]{2,1,0} parameter(0) CHECK-DAG: %[[ZERO:.*]] = bf16[] constant(0) - CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(bf16[2,4096,4096]{2,1,0} %[[P0]], bf16[] %[[ZERO]]) - CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(bf16[4096]{0} %[[LOCAL_REDUCE]]) - CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(bf16[1024]{0} %[[REDUCE_SCATTER]]) + CHECK-DAG: %[[LOCAL_REDUCE:.*]] = bf16[4096]{0} reduce(%[[P0]], %[[ZERO]]) + CHECK: %[[REDUCE_SCATTER:.*]] = bf16[1024]{0} reduce-scatter(%[[LOCAL_REDUCE]]) + CHECK-NEXT: %[[AR1:.*]] = bf16[1024]{0} all-reduce(%[[REDUCE_SCATTER]]) CHECK-SAME: replica_groups={{[{]}}{0,4},{1,5},{2,6},{3,7}{{[}]}} - CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(bf16[2,4096,4096]{2,1,0} %[[P0]]) + CHECK: %[[EXISTING_AR:.*]] = bf16[2,4096,4096]{2,1,0} all-reduce(%[[P0]]) CHECK: ROOT CHECK-NOT: %[[AR1]] CHECK-SAME: %[[EXISTING_AR]] diff --git a/xla/service/gpu/transforms/collective_select_folder_test.cc b/xla/service/gpu/transforms/collective_select_folder_test.cc index 5d7453abba64d3..dbd43a4ec06ba1 100644 --- a/xla/service/gpu/transforms/collective_select_folder_test.cc +++ b/xla/service/gpu/transforms/collective_select_folder_test.cc @@ -430,18 +430,18 @@ TEST_F(CollectiveSelectFolderTest, // CHECK: ENTRY %computation // CHECK: %[[PARAM:.*]] = (f32[8192]{0}, f32[8192]{0}) parameter(0) // CHECK: %[[OPERAND_BWD:.*]] = {{.*}} get-tuple-element - // CHECK-SAME: ({{.*}} %[[PARAM]]), index=0 + // CHECK-SAME: ({{.*}}%[[PARAM]]), index=0 // CHECK: %[[OPERAND_FWD:.*]] = {{.*}} get-tuple-element - // CHECK-SAME: ({{.*}} %[[PARAM]]), index=1 + // CHECK-SAME: ({{.*}}%[[PARAM]]), index=1 // CHECK: %[[CP_BWD:.*]] = {{.*}} collective-permute - // CHECK-SAME: ({{.*}} %[[OPERAND_BWD]]), channel_id=1, + // CHECK-SAME: ({{.*}}%[[OPERAND_BWD]]), channel_id=1, // CHECK-SAME: source_target_pairs={{\{}}{3,0}} // CHECK: %[[CP_FWD:.*]] = {{.*}} collective-permute - // CHECK-SAME: ({{.*}} %[[OPERAND_FWD]]), channel_id=2, + // CHECK-SAME: ({{.*}}%[[OPERAND_FWD]]), channel_id=2, // CHECK-SAME: source_target_pairs={{\{}}{0,1},{1,2},{2,3}} // CHECK: ROOT %{{.*}} = - // CHECK-SAME: select({{.*}} %{{.*}}, {{.*}} %[[CP_BWD]], - // CHECK-SAME: %[[CP_FWD]]) + // CHECK-SAME: select({{.*}}, {{.*}}%[[CP_BWD]], + // CHECK-SAME: {{.*}}%[[CP_FWD]]) // CHECK: } )"; TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result, @@ -476,8 +476,8 @@ TEST_F(CollectiveSelectFolderTest, DtypeConvertedPartitionId) { /*expect_change=*/true)); const absl::string_view kExpected = R"( // CHECK: %[[PARAM:.*]] = {{.*}} parameter(0) - // CHECK: %[[DATA_A:.*]] = {{.*}} get-tuple-element({{.*}} %[[PARAM]]), index=0 - // CHECK: ROOT %[[DATA_A_:.*]] = {{.*}} collective-permute({{.*}} %[[DATA_A]]) + // CHECK: %[[DATA_A:.*]] = {{.*}} get-tuple-element({{.*}}%[[PARAM]]), index=0 + // CHECK: ROOT %[[DATA_A_:.*]] = {{.*}} collective-permute({{.*}}%[[DATA_A]]) )"; TF_ASSERT_OK_AND_ASSIGN(bool filecheck_result, RunFileCheck(module->ToString(), kExpected)); diff --git a/xla/service/gpu/transforms/collective_send_recv_combiner_test.cc b/xla/service/gpu/transforms/collective_send_recv_combiner_test.cc index f164f6a44ac571..54c7c6bca13792 100644 --- a/xla/service/gpu/transforms/collective_send_recv_combiner_test.cc +++ b/xla/service/gpu/transforms/collective_send_recv_combiner_test.cc @@ -57,24 +57,24 @@ TEST_F(CollectiveSendRecvCombinerTest, TransformedWithSourceTargetPairs) { CHECK-SAME: ((f32[], u32[], token[]), (f32[], u32[], token[])) CHECK-NEXT: %[[PARAM0:.*]] = f32[] parameter(0) CHECK: %[[PARAM1:.*]] = token[] parameter(1) - CHECK: %[[SEND1:.*]] = ({{.*}}) send(f32[] %[[PARAM0]], token[] %[[PARAM1]]), channel_id=1, + CHECK: %[[SEND1:.*]] = ({{.*}}) send(%[[PARAM0]], %[[PARAM1]]), channel_id=1, CHECK-SAME: frontend_attributes{{.*}}_xla_send_recv_source_target_pairs{{.*}}0,1{{.*}}1,2{{.*}}2,3{{.*}} CHECK-NEXT: %[[PARAM2:.*]] = {{.*}} parameter(2) - CHECK: %[[RECV1:.*]] = ({{.*}}) recv({{.*}} %[[PARAM2]]), channel_id=1, + CHECK: %[[RECV1:.*]] = ({{.*}}) recv(%[[PARAM2]]), channel_id=1, CHECK-SAME: frontend_attributes{{.*}}_xla_send_recv_source_target_pairs{{.*}}0,1{{.*}}1,2{{.*}}2,3{{.*}} - CHECK-NEXT: ROOT %[[OUT:.*]] = {{.*}} tuple(({{.*}}) %[[SEND1]], ({{.*}}) %[[RECV1]]) + CHECK-NEXT: ROOT %[[OUT:.*]] = {{.*}} tuple(%[[SEND1]], %[[RECV1]]) CHECK: ENTRY %[[MAIN:.*]] () -> f32[] CHECK: %[[DATA:.*]] = {{.*}} constant(5) CHECK: %[[RECV_START:.*]] = {{.*}} after-all() - CHECK: %[[TUPLE_START:.*]] = {{.*}} async-start({{.*}} %[[DATA]], {{.*}} %[[RECV_START]], {{.*}} %[[RECV_START]]), calls=%[[WRAPPED_SEND_RECV]] - CHECK-NEXT: %[[TUPLE_DONE:.*]] = {{.*}} async-done({{.*}}%[[TUPLE_START]]) - CHECK %[[GTE2:.*]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE_DONE]]), index=1 - CHECK %[[GTE3:.*]] = {{.*}} get-tuple-element({{.*}} %[[GTE2]]), index=0 - CHECK %[[GTE4:.*]] = {{.*}} get-tuple-element({{.*}} %[[GTE2]]), index=2 - CHECK %[[TUPLE1:.*]] = {{.*}} tuple({{.*}} %[[GTE3:.*]], {{.*}} %[[GTE4]]) - CHECK ROOT %[[OUT:.*]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE1]]), index=0 - CHECK %[[GTE:.*]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE_DONE]]), index=0 - CHECK %[[GTE1:.*]] = {{.*}} get-tuple-element({{.*}} %[[GTE]]), index=2 + CHECK: %[[TUPLE_START:.*]] = {{.*}} async-start(%[[DATA]], %[[RECV_START]], %[[RECV_START]]), calls=%[[WRAPPED_SEND_RECV]] + CHECK-NEXT: %[[TUPLE_DONE:.*]] = {{.*}} async-done(%[[TUPLE_START]]) + CHECK %[[GTE2:.*]] = {{.*}} get-tuple-element(%[[TUPLE_DONE]]), index=1 + CHECK %[[GTE3:.*]] = {{.*}} get-tuple-element(%[[GTE2]]), index=0 + CHECK %[[GTE4:.*]] = {{.*}} get-tuple-element(%[[GTE2]]), index=2 + CHECK %[[TUPLE1:.*]] = {{.*}} tuple(%[[GTE3]], %[[GTE4]]) + CHECK ROOT %[[OUT:.*]] = {{.*}} get-tuple-element(%[[TUPLE1]]), index=0 + CHECK %[[GTE:.*]] = {{.*}} get-tuple-element(%[[TUPLE_DONE]]), index=0 + CHECK %[[GTE1:.*]] = {{.*}} get-tuple-element(%[[GTE]]), index=2 )")); } @@ -153,42 +153,27 @@ TEST_F(CollectiveSendRecvCombinerTest, TransformedWithControlDependency) { TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get())); EXPECT_TRUE(changed); EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( - CHECK: %[[WRAPPED_SEND_RECV:.*]] - (param0: f32[], param1: token[], param2: token[]) -> - ((f32[], u32[], token[]), (f32[], u32[], token[])) { + CHECK: %[[WRAPPED_SEND_RECV:.*]] (param0: f32[], param1: token[], param2: token[]) -> ((f32[], u32[], token[]), (f32[], u32[], token[])) { CHECK: %[[PARAM0:.*]] = f32[] parameter(0) CHECK: %[[PARAM1:.*]] = token[] parameter(1) - CHECK: %[[SEND1:.*]] = (f32[], u32[], token[]) send(f32[] %[[PARAM0:.*]], - token[] %[[PARAM1:.*]]), channel_id=1 + CHECK: %[[SEND1:.*]] = (f32[], u32[], token[]) send(%[[PARAM0]], %[[PARAM1]]), channel_id=1 CHECK: %[[PARAM2:.*]] = token[] parameter(2) - CHECK: %[[RECV1:.*]] = (f32[], u32[], token[]) - recv(token[] %[[PARAM2:.*]]), channel_id=1 - CHECK: ROOT %[[OUT:.*]] = ((f32[], u32[], token[]), - (f32[], u32[], token[])) tuple((f32[], u32[], token[]) - %[[SEND1:.*]], (f32[], u32[], token[]) %[[RECV1:.*]]) + CHECK: %[[RECV1:.*]] = (f32[], u32[], token[]) recv(%[[PARAM2]]), channel_id=1 + CHECK: ROOT %[[OUT:.*]] = ((f32[], u32[], token[]), (f32[], u32[], token[])) tuple(%[[SEND1]], %[[RECV1]]) CHECK: ENTRY %[[MAIN:.*]] () -> f32[] { CHECK: %[[DATA:.*]] = f32[] constant(5) CHECK: %[[RECV_START:.*]] = token[] after-all() - CHECK: %[[TUPLE_START:.*]] = ((f32[], token[], token[]), - ((f32[], u32[], token[]), (f32[], u32[], token[])), s32[]) - async-start(f32[] %[[DATA:.*]], token[] %[[RECV_START:.*]], - token[] %[[RECV_START:.*]]), calls=%[[WRAPPED_SEND_RECV:.*]] - CHECK: %[[TUPLE_DONE:.*]] = ((f32[], u32[], token[]), - (f32[], u32[], token[])) async-done(((f32[], token[], token[]), - ((f32[], u32[], token[]), (f32[], u32[], token[])), s32[]) %[[TUPLE_START:.*]]) - CHECK %[[GTE2:.*]] = (f32[], u32[], token[]) - get-tuple-element(((f32[], u32[], token[]), - (f32[], u32[], token[])) %[[TUPLE_DONE:.*]]), index=1 - CHECK %[[GTE3:.*]] = f32[] get-tuple-element((f32[], u32[], token[]) %[[GTE2:.*]]), index=0 - CHECK %[[GTE4:.*]] = token[] get-tuple-element((f32[], u32[], token[]) %[[GTE2:.*]]), index=2 - CHECK %[[TUPLE1:.*]] = (f32[], token[]) tuple(f32[] %[[GTE3:.*]], token[] %[[GTE4:.*]]), - control-predecessors={%[[TUPLE_START:.*]]} - CHECK ROOT %[[OUT:.*]] = f32[] get-tuple-element((f32[], token[]) %[[TUPLE1:.*]]), index=0 - CHECK %[[GTE:.*]] = (f32[], u32[], token[]) - get-tuple-element(((f32[], u32[], token[]), (f32[], u32[], token[])) %[[TUPLE_DONE:.*]]), index=0 - CHECK %[[GTE1:.*]] = token[] get-tuple-element((f32[], u32[], token[]) %[[GTE:.*]]), index=2 + CHECK: %[[TUPLE_START:.*]] = ((f32[], token[], token[]), ((f32[], u32[], token[]), (f32[], u32[], token[])), s32[]) async-start(%[[DATA]], %[[RECV_START]], %[[RECV_START]]), calls=%[[WRAPPED_SEND_RECV]] + CHECK: %[[TUPLE_DONE:.*]] = ((f32[], u32[], token[]), (f32[], u32[], token[])) async-done(%[[TUPLE_START]]) + CHECK %[[GTE2:.*]] = (f32[], u32[], token[]) get-tuple-element(%[[TUPLE_DONE]], index=1) + CHECK %[[GTE3:.*]] = f32[] get-tuple-element(%[[GTE2]], index=0) + CHECK %[[GTE4:.*]] = token[] get-tuple-element(%[[GTE2]], index=2) + CHECK %[[TUPLE1:.*]] = (f32[], token[]) tuple(%[[GTE3]], %[[GTE4]]), control-predecessors={%[[TUPLE_START]]} + CHECK ROOT %[[OUT:.*]] = f32[] get-tuple-element(%[[TUPLE1]], index=0) + CHECK %[[GTE:.*]] = (f32[], u32[], token[]) get-tuple-element(%[[TUPLE_DONE]], index=0) + CHECK %[[GTE1:.*]] = token[] get-tuple-element(%[[GTE]], index=2) )")); } @@ -226,15 +211,15 @@ TEST_F(CollectiveSendRecvCombinerTest, TransformedWithMultipleSendRecv) { CHECK-SAME: (f32[], u32[], token[])) CHECK-NEXT: %[[PARAM0:.*]] = {{.*}} parameter(0) CHECK: %[[PARAM1:.*]] = {{.*}} parameter(1) - CHECK: %[[SEND1:.*]] = {{.*}} send({{.*}} %[[PARAM0]], {{.*}} %[[PARAM1]]), channel_id=1 + CHECK: %[[SEND1:.*]] = {{.*}} send(%[[PARAM0]], %[[PARAM1]]), channel_id=1 CHECK: %[[PARAM2:.*]] = f32[] parameter(2) CHECK: %[[PARAM3:.*]] = {{.*}} parameter(3) - CHECK: %[[SEND2:.*]] = {{.*}} send({{.*}} %[[PARAM2]], {{.*}} %[[PARAM3]]), channel_id=2 + CHECK: %[[SEND2:.*]] = {{.*}} send(%[[PARAM2]], %[[PARAM3]]), channel_id=2 CHECK: %[[PARAM4:.*]] = {{.*}} parameter(4) - CHECK: %[[RECV1:.*]] = {{.*}} recv({{.*}} %[[PARAM4]]), channel_id=1 + CHECK: %[[RECV1:.*]] = {{.*}} recv(%[[PARAM4]]), channel_id=1 CHECK: %[[PARAM5:.*]] = {{.*}} parameter(5) - CHECK: %[[RECV2:.*]] = {{.*}} recv({{.*}} %[[PARAM5]]), channel_id=2 - CHECK: ROOT %[[OUT:.*]] = {{.*}} tuple({{.*}} %[[SEND1]], {{.*}} %[[SEND2]], {{.*}} %[[RECV1]], {{.*}} %[[RECV2]]) + CHECK: %[[RECV2:.*]] = {{.*}} recv(%[[PARAM5]]), channel_id=2 + CHECK: ROOT %[[OUT:.*]] = {{.*}} tuple(%[[SEND1]], %[[SEND2]], %[[RECV1]], %[[RECV2]]) CHECK: ENTRY %[[MAIN:.*]] () -> (f32[], f32[]) CHECK: %[[DATA1:.*]] = {{.*}} constant(1) @@ -242,23 +227,23 @@ TEST_F(CollectiveSendRecvCombinerTest, TransformedWithMultipleSendRecv) { CHECK: %[[DATA2:.*]] = {{.*}} constant(2) CHECK: %[[AFTER_ALL2:.*]] = {{.*}} after-all() CHECK: %[[TUPLE_START:.*]] = {{.*}} async-start{{.*}}calls=%[[WRAPPED_SEND_RECV]] - CHECK: %[[TUPLE_DONE:.*]] = {{.*}} async-done({{.*}} %[[TUPLE_START]]) - CHECK %[[GTE4:.*]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE_DONE]]), index=2 - CHECK %[[GTE5:.*]] = {{.*}} get-tuple-element({{.*}} %[[GTE4]]), index=0 - CHECK %[[GTE6:.*]] = {{.*}} get-tuple-element({{.*}} %[[GTE4]]), index=2 - CHECK %[[[TUPLE1:.*]]]] = {{.*}} tuple({{.*}} %[[GTE5]], {{.*}} %[[GTE6]]), control-predecessors={%[[TUPLE_START]]]} - CHECK %[[DATA_OUT1:.*]]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE1]]), index=0 - - CHECK %[[GTE7:.*]]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE_DONE]]), index=3 - CHECK %[[GTE8:.*]] = {{.*}} get-tuple-element({{.*}} %[[GTE7]]), index=0 - CHECK %[[GTE9:.*]]] = {{.*}} get-tuple-element({{.*}} %[[GTE7]]), index=2 - CHECK %[[TUPLE2:.*]] = {{.*}} tuple({{.*}} %[[GTE8]], {{.*}} %[[GTE9]]), control-predecessors={%[[TUPLE_START]]} - CHECK %[[DATA_OUT2:.*]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE2]]), index=0 - CHECK ROOT %[[OUT:.*]] = {{.*}} tuple({{.*}} %[[DATA_OUT1]], {{.*}} %[[DATA_OUT2]]) - CHECK %[[GTE:.*]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE_DONE]]]), index=0 - CHECK %[[GTE1:.*]] = {{.*}} get-tuple-element({{.*}} %[[GTE]]]), index=2 - CHECK %[[GTE2:.*]] = {{.*}} get-tuple-element({{.*}} %[[TUPLE_DONE]]), index=1 - CHECK %[[GTE3:.*]] = {{.*}} get-tuple-element({{.*}} %[[GTE2]]), index=2 + CHECK: %[[TUPLE_DONE:.*]] = {{.*}} async-done(%[[TUPLE_START]]) + CHECK %[[GTE4:.*]] = {{.*}} get-tuple-element(%[[TUPLE_DONE]]), index=2 + CHECK %[[GTE5:.*]] = {{.*}} get-tuple-element(%[[GTE4]]), index=0 + CHECK %[[GTE6:.*]] = {{.*}} get-tuple-element(%[[GTE4]]), index=2 + CHECK %[[[TUPLE1:.*]]]] = {{.*}} tuple(%[[GTE5]], %[[GTE6]]), control-predecessors={%[[TUPLE_START]]]} + CHECK %[[DATA_OUT1:.*]]] = {{.*}} get-tuple-element(%[[TUPLE1]]), index=0 + + CHECK %[[GTE7:.*]]] = {{.*}} get-tuple-element(%[[TUPLE_DONE]]), index=3 + CHECK %[[GTE8:.*]] = {{.*}} get-tuple-element(%[[GTE7]]), index=0 + CHECK %[[GTE9:.*]]] = {{.*}} get-tuple-element(%[[GTE7]]), index=2 + CHECK %[[TUPLE2:.*]] = {{.*}} tuple(%[[GTE8]], %[[GTE9]]), control-predecessors={%[[TUPLE_START]]} + CHECK %[[DATA_OUT2:.*]] = {{.*}} get-tuple-element(%[[TUPLE2]]), index=0 + CHECK ROOT %[[OUT:.*]] = {{.*}} tuple(%[[DATA_OUT1]], %[[DATA_OUT2]]) + CHECK %[[GTE:.*]] = {{.*}} get-tuple-element(%[[TUPLE_DONE]]), index=0 + CHECK %[[GTE1:.*]] = {{.*}} get-tuple-element(%[[GTE]]), index=2 + CHECK %[[GTE2:.*]] = {{.*}} get-tuple-element(%[[TUPLE_DONE]]), index=1 + CHECK %[[GTE3:.*]] = {{.*}} get-tuple-element(%[[GTE2]]), index=2 )")); } } // namespace diff --git a/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc b/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc index d9002f2bb7bb70..a05a4d6e317909 100644 --- a/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc +++ b/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc @@ -943,10 +943,10 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body {{.+}} { // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}} - // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) - // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}} - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) + // CHECK: %[[out1:.+]] = {{.+}} tuple({{.*}}%[[cp1]], {{.*}}) + // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.*}}%[[out1]]), index=0 + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.*}}%[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}} + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[cp2]], {{.*}}) // CHECK: } // CHECK: ENTRY %main {{.+}} { // CHECK-NOT: collective-permute @@ -995,15 +995,15 @@ ENTRY main { VLOG(1) << module->ToString(); EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body - // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6},{3,6}{{[}]}}} - // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) - // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]) - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6}{{[}]}}} - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) + // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.*}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6},{3,6}{{[}]}}} + // CHECK: %[[out1:.+]] = {{.+}} tuple({{.*}}%[[cp1]], {{.*}}) + // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.*}}%[[out1]]) + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.*}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6}{{[}]}}} + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[cp2]], {{.*}}) // CHECK: ENTRY %main {{.+}} { - // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}} - // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}}) - // CHECK: %[[while:.+]] = {{.+}} while({{.+}} %[[out_peeled]]) + // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.*}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}} + // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.*}}%[[cp_peeled]], {{.*}}) + // CHECK: %[[while:.+]] = {{.+}} while({{.*}}%[[out_peeled]]) // CHECK: } )")); } @@ -1049,11 +1049,11 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body - // CHECK: %[[cp1:.+]] = f32[] collective-permute(f32[] %param_0), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{4,6},{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3}{{[}]}}} - // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) - // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3},{0,2}{{[}]}}} - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) + // CHECK: %[[cp1:.+]] = f32[] collective-permute(%param_0), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{4,6},{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3}{{[}]}}} + // CHECK: %[[out1:.+]] = {{.+}} tuple({{.*}}%[[cp1]], {{.*}}) + // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.*}}%[[out1]]), index=0 + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.*}}%[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3},{0,2}{{[}]}}} + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[cp2]], {{.*}}) // CHECK: ENTRY %main // CHECK-NOT: collective-permute // CHECK: } @@ -1102,15 +1102,15 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3}{{[}]}}} - // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) - // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3},{0,2}{{[}]}}} - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) + // CHECK: %[[out1:.+]] = {{.+}} tuple({{.*}}%[[cp1]], {{.*}}) + // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.*}}%[[out1]]), index=0 + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.*}}%[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3},{0,2}{{[}]}}} + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[cp2]], {{.*}}) // CHECK: } // CHECK: ENTRY %main // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}} - // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}}) - // CHECK: ROOT {{.+}} = {{.+}} while({{.+}} %[[out_peeled]]) + // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.*}}%[[cp_peeled]], {{.*}}) + // CHECK: ROOT {{.+}} = {{.+}} while({{.*}}%[[out_peeled]]) // CHECK: } )")); } @@ -1158,12 +1158,12 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body // CHECK: %[[cp_start1:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}} - // CHECK: %[[cp1:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start1]]) - // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) - // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0 - // CHECK: %[[cp_start2:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}} - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start2]]) - // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) + // CHECK: %[[cp1:.+]] = {{.+}} collective-permute-done({{.*}}%[[cp_start1]]) + // CHECK: %[[out1:.+]] = {{.+}} tuple({{.*}}%[[cp1]], {{.*}}) + // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.*}}%[[out1]]), index=0 + // CHECK: %[[cp_start2:.+]] = {{.+}} collective-permute-start({{.*}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}} + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute-done({{.*}}%[[cp_start2]]) + // CHECK: ROOT {{.+}} = {{.+}} tuple({{.*}}%[[cp2]], {{.*}}) // CHECK: } // CHECK: ENTRY %main // CHECK-NOT: collective-permute diff --git a/xla/service/gpu/transforms/explicit_stream_annotation_async_wrapper_test.cc b/xla/service/gpu/transforms/explicit_stream_annotation_async_wrapper_test.cc index 3b2b26e186367a..df59aeb076e709 100644 --- a/xla/service/gpu/transforms/explicit_stream_annotation_async_wrapper_test.cc +++ b/xla/service/gpu/transforms/explicit_stream_annotation_async_wrapper_test.cc @@ -56,8 +56,8 @@ TEST_F(ExplicitStreamAnnotationAsyncWrapperTest, AnnotatedOpIsWrapped) { TF_ASSERT_OK_AND_ASSIGN(bool mutated, wrapper_pass.Run(module.get())); absl::StatusOr filecheck_result = RunFileCheck(module->ToString({}), R"( // CHECK: %lhs.1 = f32[] constant(42) - // CHECK: %call-start = ((f32[]), f32[]) call-start(f32[] %lhs.1), async_execution_thread="explicit", to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"} - // CHECK: ROOT %call-done = f32[] call-done(((f32[]), f32[]) %call-start), frontend_attributes={_xla_stream_annotation="1"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false} + // CHECK: %call-start = ((f32[]), f32[]) call-start(%lhs.1), async_execution_thread="explicit", to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"} + // CHECK: ROOT %call-done = f32[] call-done(%call-start), frontend_attributes={_xla_stream_annotation="1"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false} )"); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(*filecheck_result); @@ -97,10 +97,10 @@ TEST_F(ExplicitStreamAnnotationAsyncWrapperTest, OverlappingGemms) { TF_ASSERT_OK_AND_ASSIGN(bool mutated, wrapper_pass.Run(module.get())); absl::StatusOr filecheck_result = RunFileCheck(module->ToString({}), R"( - // CHECK: %call-start = ((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) call-start(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y), async_execution_thread="explicit", to_apply=%gemm1, frontend_attributes={_xla_stream_annotation="1"} - // CHECK: %call-done = f32[2048,2048]{1,0} call-done(((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) %call-start), frontend_attributes={_xla_stream_annotation="1"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false} - // CHECK: %call-start.1 = ((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) call-start(f32[2048,2048]{1,0} %x, f32[2048,2048]{1,0} %y), async_execution_thread="explicit", to_apply=%gemm2, frontend_attributes={_xla_stream_annotation="2"} - // CHECK: ROOT %call-done.1 = f32[2048,2048]{1,0} call-done(((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) %call-start.1), frontend_attributes={_xla_stream_annotation="2"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false} + // CHECK: %call-start = ((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) call-start(%x, %y), async_execution_thread="explicit", to_apply=%gemm1, frontend_attributes={_xla_stream_annotation="1"} + // CHECK: %call-done = f32[2048,2048]{1,0} call-done(%call-start), frontend_attributes={_xla_stream_annotation="1"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false} + // CHECK: %call-start.1 = ((f32[2048,2048]{1,0}, f32[2048,2048]{1,0}), f32[2048,2048]{1,0}) call-start(%x, %y), async_execution_thread="explicit", to_apply=%gemm2, frontend_attributes={_xla_stream_annotation="2"} + // CHECK: ROOT %call-done.1 = f32[2048,2048]{1,0} call-done(%call-start.1), frontend_attributes={_xla_stream_annotation="2"}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"force_earliest_schedule":false} )"); TF_ASSERT_OK(filecheck_result.status()); EXPECT_TRUE(*filecheck_result); diff --git a/xla/service/gpu/transforms/gemm_fusion_test.cc b/xla/service/gpu/transforms/gemm_fusion_test.cc index 492e27f8d83c3e..f414d7cf32b25c 100644 --- a/xla/service/gpu/transforms/gemm_fusion_test.cc +++ b/xla/service/gpu/transforms/gemm_fusion_test.cc @@ -723,16 +723,16 @@ ENTRY e { MatchHloModule(*module, R"( CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1) -CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]) +CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(%[[P0]], %[[P1]]) CHECK-DAG: %[[P2:.*]] = f32[2,4]{1,0} parameter(2) CHECK-DAG: %[[P3:.*]] = f32[2,4]{1,0} parameter(3) -CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P2]], f32[2,4]{1,0} %[[P3]]) -CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]]) +CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(%[[P2]], %[[P3]]) +CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(%[[ADD0]], %[[ADD1]]) CHECK: ENTRY CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1) CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} -CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P1]]), +CHECK-SAME: fusion(%[[P0]], %[[P1]], %[[P0]], %[[P1]]), CHECK-SAME: kind=kCustom CHECK-SAME: __triton_gemm })"); @@ -756,14 +756,14 @@ ENTRY e { MatchHloModule(*module, R"( CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) -CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]]) +CHECK-DAG: %[[ADD0:.*]] = f32[2,4]{1,0} add(%[[P0]], %[[P0]]) CHECK-DAG: %[[P1:.*]] = f32[2,4]{1,0} parameter(1) -CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(f32[2,4]{1,0} %[[P1]], f32[2,4]{1,0} %[[P1]]) -CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(f32[2,4]{1,0} %[[ADD0]], f32[2,4]{1,0} %[[ADD1]]) +CHECK-DAG: %[[ADD1:.*]] = f32[2,4]{1,0} add(%[[P1]], %[[P1]]) +CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} dot(%[[ADD0]], %[[ADD1]]) CHECK: ENTRY CHECK-DAG: %[[P0:.*]] = f32[2,4]{1,0} parameter(0) CHECK-DAG: ROOT {{.*}} = f32[2,2]{1,0} -CHECK-SAME: fusion(f32[2,4]{1,0} %[[P0]], f32[2,4]{1,0} %[[P0]]) +CHECK-SAME: fusion(%[[P0]], %[[P0]]) CHECK-SAME: kind=kCustom CHECK-SAME: __triton_gemm })"); @@ -789,15 +789,15 @@ ENTRY e { MatchHloModule(*module, R"( CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0) CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1) -CHECK-DAG: %[[NEGATE:.*]] = f32[4,4]{1,0} negate(f32[4,4]{1,0} %[[P0]]) -CHECK-DAG: %[[SINE:.*]] = f32[4,4]{1,0} sine(f32[4,4]{1,0} %[[NEGATE]]) -CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[NEGATE]], f32[4,4]{1,0} %[[SINE]]) -CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P1]]) +CHECK-DAG: %[[NEGATE:.*]] = f32[4,4]{1,0} negate(%[[P0]]) +CHECK-DAG: %[[SINE:.*]] = f32[4,4]{1,0} sine(%[[NEGATE]]) +CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(%[[NEGATE]], %[[SINE]]) +CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(%[[ADD]], %[[P1]]) CHECK: ENTRY CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0) CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1) CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} -CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]]) +CHECK-SAME: fusion(%[[P0]], %[[P1]]) CHECK-SAME: kind=kCustom CHECK-SAME: __triton_gemm })"); @@ -823,14 +823,14 @@ ENTRY e { CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0) CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1) CHECK-DAG: %[[P2:.*]] = f32[4,4]{1,0} parameter(2) -CHECK-DAG: %[[TRANSPOSE:.*]] = f32[4,4]{1,0} transpose(f32[4,4]{1,0} %[[P1]]) -CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[TRANSPOSE]]) -CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(f32[4,4]{1,0} %[[ADD]], f32[4,4]{1,0} %[[P2]]) +CHECK-DAG: %[[TRANSPOSE:.*]] = f32[4,4]{1,0} transpose(%[[P1]]) +CHECK-DAG: %[[ADD:.*]] = f32[4,4]{1,0} add(%[[P0]], %[[TRANSPOSE]]) +CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} dot(%[[ADD]], %[[P2]]) CHECK: ENTRY CHECK-DAG: %[[P0:.*]] = f32[4,4]{1,0} parameter(0) CHECK-DAG: %[[P1:.*]] = f32[4,4]{1,0} parameter(1) CHECK-DAG: ROOT {{.*}} = f32[4,4]{1,0} -CHECK-SAME: fusion(f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P0]], f32[4,4]{1,0} %[[P1]]) +CHECK-SAME: fusion(%[[P0]], %[[P0]], %[[P1]]) CHECK-SAME: kind=kCustom CHECK-SAME: __triton_gemm })"); @@ -1301,7 +1301,7 @@ ENTRY e { ; CHECK-LABEL: ENTRY %e ({{.*}}: f16[2,10], {{.*}}: f16[10,2]) -> f16[10,10] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,10]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f16[10,2]{1,0} parameter(1) -; CHECK: ROOT {{.*}} = f16[10,10]{1,0} dot(f16[2,10]{1,0} [[P0]], f16[10,2]{1,0} [[P1]]) +; CHECK: ROOT {{.*}} = f16[10,10]{1,0} dot([[P0]], [[P1]]) })"); } @@ -1324,7 +1324,7 @@ ENTRY e { ; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,18]{1,0} parameter(0) ; CHECK-NEXT: [[P1:%[^ ]+]] = f16[50,2]{1,0} parameter(1) ; CHECK: ROOT {{.*}} = f16[18,50]{1,0} -; CHECK: fusion(f16[2,18]{1,0} [[P0]], f16[50,2]{1,0} [[P1]]), +; CHECK: fusion([[P0]], [[P1]]), ; CHECK: kind=kCustom ; CHECK: __triton_gemm })"); @@ -1351,7 +1351,7 @@ ENTRY main { ; CHECK-NEXT: [[P1:%[^ ]+]] = f16[32,2]{1,0} parameter(1) ; CHECK-NEXT: [[META:%[^ ]+]] = u16[2,2]{1,0} parameter(2) ; CHECK: ROOT {{.*}} = f32[2,2]{1,0} -; CHECK-SAME: fusion(f16[2,16]{1,0} [[P0]], f16[32,2]{1,0} [[P1]], u16[2,2]{1,0} [[META]]), +; CHECK-SAME: fusion([[P0]], [[P1]], [[META]]), ; CHECK-SAME: kind=kCustom ; CHECK-SAME: __triton_gemm })"); @@ -1423,7 +1423,9 @@ TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) { CHECK: gemm_fusion_dot_computation CHECK: %parameter_0 = s4[8,1024]{1,0} parameter(0) CHECK: ENTRY -CHECK-DAG: ROOT {{.*}} = bf16[8,4]{1,0} fusion(s4[8,1024]{1,0} %lhs_concat, bf16[1024,4]{1,0} %rhs) +CHECK-DAG: %[[LHS_CONCAT:.*]] = s4[8,1024]{1,0} concatenate(%{{.+}}, %{{.+}}), dimensions={0} +CHECK-DAG: %[[RHS:.*]] = bf16[1024,4]{1,0} parameter(2) +CHECK-DAG: ROOT {{.*}} = bf16[8,4]{1,0} fusion(%[[LHS_CONCAT]], %[[RHS]]) })"); } @@ -1447,7 +1449,9 @@ TEST_F(SmallDotGemmFusionTest, Int4ConvertPlusNegateIsRewritten) { CHECK: gemm_fusion_dot_computation CHECK: %parameter_0 = s4[8,1024]{1,0} parameter(0) CHECK: ENTRY -CHECK-DAG: ROOT {{.*}} = f32[8,4]{1,0} fusion(s4[8,1024]{1,0} %lhs, f32[1024,4]{1,0} %rhs) +CHECK-DAG: %[[LHS:.+]] = s4[8,1024]{1,0} parameter(0) +CHECK-DAG: %[[RHS:.+]] = f32[1024,4]{1,0} parameter(1) +CHECK-DAG: ROOT {{.*}} = f32[8,4]{1,0} fusion(%[[LHS]], %[[RHS]]) })"); } diff --git a/xla/service/gpu/transforms/windowed_einsum_handler_test.cc b/xla/service/gpu/transforms/windowed_einsum_handler_test.cc index 14107e0b07cfc9..19efed83bcaef2 100644 --- a/xla/service/gpu/transforms/windowed_einsum_handler_test.cc +++ b/xla/service/gpu/transforms/windowed_einsum_handler_test.cc @@ -337,49 +337,49 @@ ENTRY main.9_spmd { CHECK: ENTRY CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} parameter(1) -CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [6144:8192]} -CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]), +CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [6144:8192]} +CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(%[[SLICE0]]), CHECK: replica_groups={ CHECK: {0,1,2,3},{4,5,6,7} CHECK: } CHECK: dimensions={1} CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0) -CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]} -CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]} +CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(%[[A2A0:.*]], %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [4096:6144]} -CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]), +CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [4096:6144]} +CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(%[[SLICE1]]), CHECK: replica_groups={ CHECK: {0,1,2,3},{4,5,6,7} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]} -CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]} +CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(%[[A2A1:.*]], %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [2048:4096]} -CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]), +CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [2048:4096]} +CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(%[[SLICE2]]), CHECK: replica_groups={ CHECK: {0,1,2,3},{4,5,6,7} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]} -CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]} +CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(%[[A2A2:.*]], %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:2048]} -CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]), +CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [0:2048]} +CHECK: %[[A2A3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(%[[SLICE3]]), CHECK: replica_groups={ CHECK: {0,1,2,3},{4,5,6,7} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]} -CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]} +CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(%[[A2A3:.*]], %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false} CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0) -CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={} -CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["5"],"force_earliest_schedule":false} -CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false} -CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false} +CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(%[[CONSTANT:.*]]), dimensions={} +CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(%[[DOT0:.*]], %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["5"],"force_earliest_schedule":false} +CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(%[[DOT1:.*]], %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false} +CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(%[[DOT2:.*]], %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false} -CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]]) +CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(%[[DOT3:.*]], %[[ADD2:.*]]) )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -412,49 +412,49 @@ ENTRY main.9_spmd { CHECK: ENTRY CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(1) -CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]} +CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]} CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0) -CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [24576:32768]} -CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]), +CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [0:8192], [24576:32768]} +CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(%[[SLICE0:.*]], %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(%[[DOT0:.*]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]} -CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [16384:24576]} -CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]), +CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]} +CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [0:8192], [16384:24576]} +CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(%[[SLICE1:.*]], %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(%[[DOT1:.*]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]} -CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [8192:16384]} -CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]), +CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]} +CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [0:8192], [8192:16384]} +CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(%[[SLICE2:.*]], %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(%[[DOT2:.*]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]} -CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]} -CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]), +CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]} +CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]} +CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(%[[SLICE3:.*]], %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={2}, backend_config={"operation_queue_id":"5","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK: %[[A2A3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(%[[DOT3:.*]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0) -CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={} -CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]]) -CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]]) -CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]]) +CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(%[[CONSTANT:.*]]), dimensions={} +CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(%[[A2A0:.*]], %[[BROADCAST:.*]]) +CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(%[[A2A1:.*]], %[[ADD0:.*]]) +CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(%[[A2A2:.*]], %[[ADD1:.*]]) -CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]]) +CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(%[[A2A3:.*]], %[[ADD2:.*]]) )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -491,55 +491,55 @@ ENTRY main.9_spmd { const char* kExpected = R"( CHECK: ENTRY CHECK-DAG: %[[P1:.*]] = bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} parameter(1) -CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0} %[[P1:.*]]), dimensions={0,3,1,2,4,5} -CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} %[[TRANSPOSE0:.*]]) -CHECK-DAG: %[[RESHAPE1:.*]] = bf16[4,8192,1,2048]{3,2,1,0} reshape(bf16[1,4,8192,1,2048]{4,3,2,1,0} %[[RESHAPE0:.*]]) -CHECK-DAG: %[[TRANSPOSE1:.*]] = bf16[1,4,2048,8192]{2,0,3,1} transpose(bf16[4,8192,1,2048]{3,2,1,0} %[[RESHAPE1:.*]]), dimensions={2,0,3,1} -CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{2,0,3,1} %[[TRANSPOSE1:.*]]) - -CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [6144:8192]} -CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE0]]), +CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[1,4,1,8192,1,2048]{5,4,1,3,2,0} transpose(%[[P1:.*]]), dimensions={0,3,1,2,4,5} +CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,8192,1,2048]{4,3,2,1,0} reshape(%[[TRANSPOSE0:.*]]) +CHECK-DAG: %[[RESHAPE1:.*]] = bf16[4,8192,1,2048]{3,2,1,0} reshape(%[[RESHAPE0:.*]]) +CHECK-DAG: %[[TRANSPOSE1:.*]] = bf16[1,4,2048,8192]{2,0,3,1} transpose(%[[RESHAPE1:.*]]), dimensions={2,0,3,1} +CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(%[[TRANSPOSE1:.*]]) + +CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(%[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [6144:8192]} +CHECK: %[[A2A0:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(%[[SLICE0]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} CHECK-DAG: %[[P0:.*]] = bf16[1,8192,32768]{2,1,0} parameter(0) -CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]} -CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A0:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK-DAG: %[[SLICE4:.*]] = bf16[1,2048,32768]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [6144:8192], [0:32768]} +CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(%[[A2A0:.*]], %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [4096:6144]} -CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE1]]), +CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(%[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [4096:6144]} +CHECK: %[[A2A1:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(%[[SLICE1]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]} -CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A1:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK-DAG: %[[SLICE5:.*]] = bf16[1,2048,32768]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [4096:6144], [0:32768]} +CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(%[[A2A1:.*]], %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"8","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [2048:4096]} -CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE2]]), +CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(%[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [2048:4096]} +CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(%[[SLICE2]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]} -CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A2:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK-DAG: %[[SLICE6:.*]] = bf16[1,2048,32768]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [2048:4096], [0:32768]} +CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(%[[A2A2:.*]], %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"7","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [0:2048]} -CHECK: %[[A2A2:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(bf16[1,4,2048,2048]{3,2,1,0} %[[SLICE3]]), +CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} slice(%[[COPY:.*]]), slice={[0:1], [0:4], [0:2048], [0:2048]} +CHECK: %[[A2A3:.*]] = bf16[1,4,2048,2048]{3,2,1,0} all-to-all(%[[SLICE3]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(bf16[1,8192,32768]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]} -CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(bf16[1,4,2048,2048]{3,2,1,0} %[[A2A3:.*]], bf16[1,2048,32768]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK-DAG: %[[SLICE7:.*]] = bf16[1,2048,32768]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [0:2048], [0:32768]} +CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,32768]{3,2,1,0} dot(%[[A2A3:.*]], %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"6","wait_on_operation_queues":[],"force_earliest_schedule":false} CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0) -CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={} -CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT0:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false} -CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT1:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false} -CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT2:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["8"],"force_earliest_schedule":false} +CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,32768]{3,2,1,0} broadcast(%[[CONSTANT:.*]]), dimensions={} +CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(%[[DOT0:.*]], %[[BROADCAST:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["6"],"force_earliest_schedule":false} +CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(%[[DOT1:.*]], %[[ADD0:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["7"],"force_earliest_schedule":false} +CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,32768]{3,2,1,0} add(%[[DOT2:.*]], %[[ADD1:.*]]), backend_config={"operation_queue_id":"0","wait_on_operation_queues":["8"],"force_earliest_schedule":false} -CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,1,0} %[[DOT3:.*]], bf16[1,4,2048,32768]{3,2,1,0} %[[ADD2:.*]]) +CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(%[[DOT3:.*]], %[[ADD2:.*]]) )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -578,55 +578,55 @@ ENTRY main.9_spmd { CHECK: ENTRY CHECK-DAG: %[[P1:.*]] = bf16[1,4,2048,32768]{3,2,1,0} parameter(0) -CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]} +CHECK-DAG: %[[SLICE0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [24576:32768]} CHECK-DAG: %[[P0:.*]] = bf16[1,32768,8192]{2,1,0} parameter(1) -CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [24576:32768], [0:8192]} -CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE0:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"12","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT0:.*]]), +CHECK-DAG: %[[SLICE4:.*]] = bf16[1,8192,8192]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [24576:32768], [0:8192]} +CHECK-DAG: %[[DOT0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(%[[SLICE0:.*]], %[[SLICE4:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"12","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK: %[[A2A0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(%[[DOT0:.*]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]} -CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [16384:24576], [0:8192]} -CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE1:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"11","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT1:.*]]), +CHECK-DAG: %[[SLICE1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [16384:24576]} +CHECK-DAG: %[[SLICE5:.*]] = bf16[1,8192,8192]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [16384:24576], [0:8192]} +CHECK-DAG: %[[DOT1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(%[[SLICE1:.*]], %[[SLICE5:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"11","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK: %[[A2A1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(%[[DOT1:.*]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]} -CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [8192:16384], [0:8192]} -CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE2:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"10","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT2:.*]]), +CHECK-DAG: %[[SLICE2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [8192:16384]} +CHECK-DAG: %[[SLICE6:.*]] = bf16[1,8192,8192]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [8192:16384], [0:8192]} +CHECK-DAG: %[[DOT2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(%[[SLICE2:.*]], %[[SLICE6:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"10","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(%[[DOT2:.*]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} -CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(bf16[1,4,2048,32768]{3,2,1,0} %[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]} -CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(bf16[1,32768,8192]{2,1,0} %[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]} -CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(bf16[1,4,2048,8192]{3,2,1,0} %[[SLICE3:.*]], bf16[1,8192,8192]{2,1,0} %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false} -CHECK: %[[A2A2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(bf16[1,4,2048,8192]{3,2,1,0} %[[DOT3:.*]]), +CHECK-DAG: %[[SLICE3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} slice(%[[P1]]), slice={[0:1], [0:4], [0:2048], [0:8192]} +CHECK-DAG: %[[SLICE7:.*]] = bf16[1,8192,8192]{2,1,0} slice(%[[P0:.*]]), slice={[0:1], [0:8192], [0:8192]} +CHECK-DAG: %[[DOT3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} dot(%[[SLICE3:.*]], %[[SLICE7:.*]]), lhs_batch_dims={0}, lhs_contracting_dims={3}, rhs_batch_dims={0}, rhs_contracting_dims={1}, backend_config={"operation_queue_id":"9","wait_on_operation_queues":[],"force_earliest_schedule":false} +CHECK: %[[A2A3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} all-to-all(%[[DOT3:.*]]), CHECK: replica_groups={ CHECK: {0,1,2,3} CHECK: } CHECK: dimensions={1} CHECK-DAG: %[[CONSTANT:.*]] = bf16[] constant(0) -CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(bf16[] %[[CONSTANT:.*]]), dimensions={} -CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A0:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[BROADCAST:.*]]) -CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A1:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD0:.*]]) -CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A2:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD1:.*]]) -CHECK-DAG: %[[ADD3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,0} %[[A2A3:.*]], bf16[1,4,2048,8192]{3,2,1,0} %[[ADD2:.*]]) - -CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(bf16[1,4,2048,8192]{3,2,1,0} %[[ADD3:.*]]) -CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[4,1,2048,8192]{3,2,0,1} transpose(bf16[1,4,2048,8192]{3,2,1,0} %[[COPY:.*]]), dimensions={1,0,2,3} -CHECK-DAG: %[[COPY1:.*]] = bf16[4,1,2048,8192]{3,2,1,0} copy(bf16[4,1,2048,8192]{3,2,0,1} %[[TRANSPOSE0:.*]]) -CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(bf16[4,1,2048,8192]{3,2,1,0} %[[COPY1:.*]]) - -CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,2048,8192]{4,3,2,1,0} %[[RESHAPE0:.*]]) +CHECK-DAG: %[[BROADCAST:.*]] = bf16[1,4,2048,8192]{3,2,1,0} broadcast(%[[CONSTANT:.*]]), dimensions={} +CHECK-DAG: %[[ADD0:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(%[[A2A0:.*]], %[[BROADCAST:.*]]) +CHECK-DAG: %[[ADD1:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(%[[A2A1:.*]], %[[ADD0:.*]]) +CHECK-DAG: %[[ADD2:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(%[[A2A2:.*]], %[[ADD1:.*]]) +CHECK-DAG: %[[ADD3:.*]] = bf16[1,4,2048,8192]{3,2,1,0} add(%[[A2A3:.*]], %[[ADD2:.*]]) + +CHECK-DAG: %[[COPY:.*]] = bf16[1,4,2048,8192]{3,2,1,0} copy(%[[ADD3:.*]]) +CHECK-DAG: %[[TRANSPOSE0:.*]] = bf16[4,1,2048,8192]{3,2,0,1} transpose(%[[COPY:.*]]), dimensions={1,0,2,3} +CHECK-DAG: %[[COPY1:.*]] = bf16[4,1,2048,8192]{3,2,1,0} copy(%[[TRANSPOSE0:.*]]) +CHECK-DAG: %[[RESHAPE0:.*]] = bf16[1,4,1,2048,8192]{4,3,2,1,0} reshape(%[[COPY1:.*]]) + +CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(%[[RESHAPE0:.*]]) )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/xla/service/hlo_computation_test.cc b/xla/service/hlo_computation_test.cc index d1cf9e9ca3529b..b8cbfb028fe2b1 100644 --- a/xla/service/hlo_computation_test.cc +++ b/xla/service/hlo_computation_test.cc @@ -705,8 +705,8 @@ TEST_F(HloComputationTest, Stringification) { R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) - %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} - ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + %transpose = f32[10,20]{1,0} transpose(%y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(%x, %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} }, execution_thread="MainThread")"; EXPECT_EQ(computation->ToString(options), expected_computation); } @@ -742,8 +742,8 @@ TEST_F(HloComputationTest, StringificationIndent) { R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) - %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} - ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + %transpose = f32[10,20]{1,0} transpose(%y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(%x, %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} }, execution_thread="MainThread")"; EXPECT_EQ(computation->ToString(options), expected_computation); } @@ -778,8 +778,8 @@ TEST_F(HloComputationTest, StringificationCanonical) { R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { %x = f32[5,10]{1,0} parameter(0) %y = f32[20,10]{1,0} parameter(1) - %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} - ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + %transpose = f32[10,20]{1,0} transpose(%y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(%x, %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} }, execution_thread="MainThread")"; EXPECT_EQ(computation->ToString(options), expected_computation1); diff --git a/xla/service/hlo_instruction_test.cc b/xla/service/hlo_instruction_test.cc index 2a197684c37c0a..d464f7c182bf76 100644 --- a/xla/service/hlo_instruction_test.cc +++ b/xla/service/hlo_instruction_test.cc @@ -1670,8 +1670,8 @@ TEST_F(HloInstructionTest, StringifyDot) { auto options = HloPrintOptions().set_print_metadata(false); EXPECT_EQ(dot->ToString(options), - "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " - "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); + "%dot = f32[5,20]{1,0} dot(%x, %transpose), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); auto options2 = HloPrintOptions() .set_print_metadata(false) @@ -1707,10 +1707,10 @@ TEST_F(HloInstructionTest, StringifySparseDot) { ShapeUtil::MakeShape(F32, {5, 20}), x, y, dot_dnums, DefaultPrecisionConfig(2), {sparsity_descriptor}, meta_operands)); - EXPECT_EQ(dot->ToString(), - "%dot = f32[5,20]{1,0} dot(f32[5,16]{1,0} %x, f32[32,20]{1,0} %y, " - "u16[5,2]{1,0} %meta), lhs_contracting_dims={1}, " - "rhs_contracting_dims={0}, sparsity=L.1@2:4"); + EXPECT_EQ( + dot->ToString(), + "%dot = f32[5,20]{1,0} dot(%x, %y, %meta), lhs_contracting_dims={1}, " + "rhs_contracting_dims={0}, sparsity=L.1@2:4"); } TEST_F(HloInstructionTest, StringifyConditional) { @@ -1742,8 +1742,7 @@ TEST_F(HloInstructionTest, StringifyConditional) { builder.AddInstruction(HloInstruction::CreateConditional( sout, pred, x, computation, x, computation)); EXPECT_EQ(conditional->ToString(options), - "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, " - "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), " + "%conditional = f32[5,20]{1,0} conditional(%constant, %x, %x), " "true_computation=%TransposeDot, false_computation=%TransposeDot"); } @@ -1773,8 +1772,8 @@ TEST_F(HloInstructionTest, StringifyWhile) { HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); EXPECT_EQ(loop->ToString(options), - "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " - "condition=%TransposeDot, body=%TransposeDot"); + "%while = f32[5,20]{1,0} while(%x), condition=%TransposeDot, " + "body=%TransposeDot"); } TEST_F(HloInstructionTest, GetSetStatisticsViz) { @@ -1825,8 +1824,7 @@ TEST_F(HloInstructionTest, StringifyStatisticsViz) { // Empty statistics viz must not print "statistics={}" add->set_statistics_viz({}); - EXPECT_EQ(add->ToString(), - "%add = f32[5,10]{1,0} add(f32[5,10]{1,0} %x, f32[5,10]{1,0} %y)"); + EXPECT_EQ(add->ToString(), "%add = f32[5,10]{1,0} add(%x, %y)"); auto CreateStatisticsVizWithStatistics = [](int64_t stat_index_to_visualize, @@ -1855,7 +1853,7 @@ TEST_F(HloInstructionTest, StringifyStatisticsViz) { 1, {{"stat-1", 33.0}, {"stat-2", 44.0}})); EXPECT_EQ(add->ToString(), - "%add = f32[5,10]{1,0} add(f32[5,10]{1,0} %x, f32[5,10]{1,0} %y), " + "%add = f32[5,10]{1,0} add(%x, %y), " "statistics={visualizing_index=1,stat-1=33,stat-2=44}"); } @@ -1888,8 +1886,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " - "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), " + "gather(%input_tensor, %start_indices), " "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " "start_index_map={0,1,2,3,4}, " "index_vector_dim=4, slice_sizes={30,29,28,27,26}"); @@ -1924,8 +1921,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " - "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), " + "gather(%input_tensor, %start_indices), " "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, " "start_index_map={0,1,2,3,4}, " "index_vector_dim=2, slice_sizes={30,29,28,27,26}"); @@ -1971,15 +1967,12 @@ TEST_F(HloInstructionTest, StringifyScatter) { /*unique_indices=*/false)); module->AddEntryComputation(builder.Build()); - EXPECT_EQ( - scatter_instruction->ToString(), - "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} " - "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " - "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, " - "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), " - "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, " - "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, " - "to_apply=%Scatter.update"); + EXPECT_EQ(scatter_instruction->ToString(), + "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} " + "scatter(%input_tensor, %scatter_indices, %scatter_updates), " + "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, " + "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, " + "to_apply=%Scatter.update"); } TEST_F(HloInstructionTest, StringifyAsyncOps) { @@ -2017,9 +2010,9 @@ TEST_F(HloInstructionTest, StringifyAsyncOps) { ENTRY %Entry (p0: f32[10]) -> f32[20] { %p0 = f32[10]{0} parameter(0) - %custom-call-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo" - %custom-call-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-start) - ROOT %custom-call-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-update) + %custom-call-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(%p0), async_execution_thread="parallel_thread", custom_call_target="foo" + %custom-call-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(%custom-call-start) + ROOT %custom-call-done = f32[20]{0} custom-call-done(%custom-call-update) } )"; @@ -2032,14 +2025,14 @@ ENTRY %Entry (p0: f32[10]) -> f32[20] { %AsyncOp (p0.1: f32[10]) -> f32[20] { %p0.1 = f32[10]{0} parameter(0) - ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %p0.1), custom_call_target="foo" + ROOT %custom-call = f32[20]{0} custom-call(%p0.1), custom_call_target="foo" }, execution_thread="parallel_thread" ENTRY %Entry (p0: f32[10]) -> f32[20] { %p0 = f32[10]{0} parameter(0) - %custom-call-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", calls=%AsyncOp - %custom-call-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-start) - ROOT %custom-call-done = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %custom-call-update) + %custom-call-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(%p0), async_execution_thread="parallel_thread", calls=%AsyncOp + %custom-call-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(%custom-call-start) + ROOT %custom-call-done = f32[20]{0} async-done(%custom-call-update) } )"; @@ -2101,14 +2094,14 @@ TEST_F(HloInstructionTest, StringifyAsyncOpsWithReduceScatter) { %add (p0: f32[], p1: f32[]) -> f32[] { %p0 = f32[] parameter(0) %p1 = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %p0, f32[] %p1) + ROOT %add = f32[] add(%p0, %p1) }, execution_thread="parallel_thread" ENTRY %Entry (pentry: f32[20]) -> f32[10] { %pentry = f32[20]{0} parameter(0) - %reduce-scatter-start = ((f32[20]{0}), f32[10]{0}) reduce-scatter-start(f32[20]{0} %pentry), async_execution_thread="parallel_thread", replica_groups={}, dimensions={0}, to_apply=%add - %reduce-scatter-update = ((f32[20]{0}), f32[10]{0}) reduce-scatter-update(((f32[20]{0}), f32[10]{0}) %reduce-scatter-start) - ROOT %reduce-scatter-done = f32[10]{0} reduce-scatter-done(((f32[20]{0}), f32[10]{0}) %reduce-scatter-update) + %reduce-scatter-start = ((f32[20]{0}), f32[10]{0}) reduce-scatter-start(%pentry), async_execution_thread="parallel_thread", replica_groups={}, dimensions={0}, to_apply=%add + %reduce-scatter-update = ((f32[20]{0}), f32[10]{0}) reduce-scatter-update(%reduce-scatter-start) + ROOT %reduce-scatter-done = f32[10]{0} reduce-scatter-done(%reduce-scatter-update) } )"; @@ -2123,19 +2116,19 @@ ENTRY %Entry (pentry: f32[20]) -> f32[10] { %add (p0: f32[], p1: f32[]) -> f32[] { %p0 = f32[] parameter(0) %p1 = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %p0, f32[] %p1) + ROOT %add = f32[] add(%p0, %p1) }, execution_thread="parallel_thread" %AsyncOp (pasync: f32[20]) -> f32[10] { %pasync = f32[20]{0} parameter(0) - ROOT %reduce-scatter = f32[10]{0} reduce-scatter(f32[20]{0} %pasync), replica_groups={}, dimensions={0}, to_apply=%add + ROOT %reduce-scatter = f32[10]{0} reduce-scatter(%pasync), replica_groups={}, dimensions={0}, to_apply=%add }, execution_thread="parallel_thread" ENTRY %Entry (pentry: f32[20]) -> f32[10] { %pentry = f32[20]{0} parameter(0) - %reduce-scatter-start = ((f32[20]{0}), f32[10]{0}) async-start(f32[20]{0} %pentry), async_execution_thread="parallel_thread", calls=%AsyncOp - %reduce-scatter-update = ((f32[20]{0}), f32[10]{0}) async-update(((f32[20]{0}), f32[10]{0}) %reduce-scatter-start) - ROOT %reduce-scatter-done = f32[10]{0} async-done(((f32[20]{0}), f32[10]{0}) %reduce-scatter-update) + %reduce-scatter-start = ((f32[20]{0}), f32[10]{0}) async-start(%pentry), async_execution_thread="parallel_thread", calls=%AsyncOp + %reduce-scatter-update = ((f32[20]{0}), f32[10]{0}) async-update(%reduce-scatter-start) + ROOT %reduce-scatter-done = f32[10]{0} async-done(%reduce-scatter-update) } )"; @@ -3345,20 +3338,20 @@ TEST_F(HloInstructionTest, PrintUnaryWithResultAccuracy) { HloInstruction* exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, x, result_accuracy)); EXPECT_EQ(exp->ToString(), - "%exponential = f32[] exponential(f32[] %x), " + "%exponential = f32[] exponential(%x), " "result_accuracy={tolerance={atol=0,rtol=0.4,ulps=0}}"); EXPECT_TRUE(exp->has_result_accuracy()); HloInstruction* exp_no_result_accuracy = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, x)); EXPECT_EQ(exp_no_result_accuracy->ToString(), - "%exponential = f32[] exponential(f32[] %x)"); + "%exponential = f32[] exponential(%x)"); EXPECT_FALSE(exp_no_result_accuracy->has_result_accuracy()); HloInstruction* exp_default_set = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, x)); // Setting the mode to DEFAULT is the same as not setting it at all. exp_default_set->set_result_accuracy(ResultAccuracy()); EXPECT_EQ(exp_default_set->ToString(), - "%exponential = f32[] exponential(f32[] %x)"); + "%exponential = f32[] exponential(%x)"); EXPECT_FALSE(exp_default_set->has_result_accuracy()); } diff --git a/xla/service/layout_normalization_test.cc b/xla/service/layout_normalization_test.cc index fb40185f38d4a9..7611ce70e78db1 100644 --- a/xla/service/layout_normalization_test.cc +++ b/xla/service/layout_normalization_test.cc @@ -582,14 +582,15 @@ HloModule module ENTRY main { p = f32[5,4]{0,1} parameter(0) - c = f32[5,4]{0,1} constant({...}) + c = f32[5,4]{0,1} constant({{1,2,3,4},{5,6,7,8},{9,10,11,12},{13,14,15,16},{17,18,19,20}}) ROOT o = f32[5,4]{0,1} add(p, c) } )"; CheckLayoutNormalization(hlo, R"( // CHECK: [[p_0:%[^ ]+]] = f32[5,4]{0,1} parameter(0) // CHECK-NEXT: [[bitcast_1:%[^ ]+]] = f32[4,5]{1,0} bitcast([[p_0]]) -// CHECK-NEXT: [[constant_2:%[^ ]+]] = f32[4,5]{1,0} constant({...}) +// CHECK-NEXT: [[constant_2:%[^ ]+]] = f32[4,5]{1,0} constant( +// CHECK-SAME{LITERAL}: { { 1, 5, 9, 13, 17 }, { 2, 6, 10, 14, 18 }, { 3, 7, 11, 15, 19 }, { 4, 8, 12, 16, 20 } }) // CHECK-NEXT: [[add_3:%[^ ]+]] = f32[4,5]{1,0} add([[bitcast_1]], [[constant_2]]) // CHECK-NEXT: ROOT [[bitcast_3_4:%[^ ]+]] = f32[5,4]{0,1} bitcast([[add_3]]) )"); @@ -600,7 +601,7 @@ TEST_F(LayoutNormalizationTest, ConstantAvoidRevisitOfUser) { HloModule module ENTRY main { - c = f32[5,4]{0,1} constant({...}) + c = f32[5,4]{0,1} constant({{1,2,3,4},{5,6,7,8},{9,10,11,12},{13,14,15,16},{17,18,19,20}}) s = f32[5,4]{0,1} sine(c) t = f32[5,4]{0,1} tanh(s) ROOT o = f32[5,4]{0,1} add(s, t) @@ -610,7 +611,8 @@ ENTRY main { // run into a CHECK failure, because the constant was normalized in-place and // therefore would not be revisited. CheckLayoutNormalization(hlo, R"( -// CHECK: [[constant_2:%[^ ]+]] = f32[4,5]{1,0} constant({...}) +// CHECK: [[constant_2:%[^ ]+]] = f32[4,5]{1,0} constant( +// CHECK-SAME{LITERAL}: { { 1, 5, 9, 13, 17 }, { 2, 6, 10, 14, 18 }, { 3, 7, 11, 15, 19 }, { 4, 8, 12, 16, 20 } }) // CHECK-NEXT: [[sine:%[^ ]+]] = f32[4,5]{1,0} sine([[constant_2]]) // CHECK-NEXT: [[bitcast_1:%[^ ]+]] = f32[5,4]{0,1} bitcast([[sine]]) // CHECK-NEXT: [[bitcast_2:%[^ ]+]] = f32[4,5]{1,0} bitcast([[bitcast_1]]) diff --git a/xla/service/pattern_matcher_test.cc b/xla/service/pattern_matcher_test.cc index c8ecfe16029189..fe3d92d29c12b3 100644 --- a/xla/service/pattern_matcher_test.cc +++ b/xla/service/pattern_matcher_test.cc @@ -866,7 +866,7 @@ TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) { "HloInstruction has opcode constant, expected anything else\n" "in c = s32[] constant(0)\n" "in operand 1\n" - "in a = s32[] add(s32[] c, s32[] c)"); + "in a = s32[] add(c, c)"); EXPECT_DESC_AND_EXPLANATION( iota, m::Op().WithFusionKind(HloInstruction::FusionKind::kLoop), "an HloInstruction with fusion kind kLoop", @@ -920,7 +920,7 @@ TEST_F(PatternMatcherTest, HloInstructionDescribeToAndExplain) { "HloInstruction doesn't have opcode iota\n" "in c = s32[] constant(0)\n" "in operand 0\n" - "in a = s32[] add(s32[] c, s32[] c)"); + "in a = s32[] add(c, c)"); EXPECT_DESC_AND_EXPLANATION( constant, m::Op().WithPredicate(HloPredicateFalse), @@ -955,7 +955,7 @@ TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) { "does not match RHS:\n" " - HloInstruction not named \"bar\"\n" " in c = s32[] constant(0)\n" - "in a = s32[] add(s32[] b, s32[] c)"); + "in a = s32[] add(b, c)"); EXPECT_DESC_AND_EXPLANATION( SetName("a", @@ -982,7 +982,7 @@ TEST_F(PatternMatcherTest, HloInstructionMatcherAnyOrderDescribeTo) { "does not match LHS:\n" " - HloInstruction doesn't have opcode constant\n" " in p = s32[] parameter(0)\n" - "in a = s32[] add(s32[] p, s32[] c)"); + "in a = s32[] add(p, c)"); } TEST_F(PatternMatcherTest, AnyOfMatcherDescribeToAndExplain) { @@ -1064,8 +1064,8 @@ TEST_F(PatternMatcherTest, OneUseAndOneUser) { const char* kMultipleUserExplanation = "HloInstruction has 2 users, but expected exactly one.\n" "All users:\n" - " - r = f32[1]{0} reshape(f32[] p0)\n" - " - r1 = f32[1]{0} reshape(f32[] p0)\n" + " - r = f32[1]{0} reshape(p0)\n" + " - r1 = f32[1]{0} reshape(p0)\n" "in p0 = f32[] parameter(0)"; EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), kMultipleUserExplanation); @@ -1080,7 +1080,7 @@ TEST_F(PatternMatcherTest, OneUseAndOneUser) { EXPECT_FALSE(Match(param.get(), m::Op().WithOneUse())); EXPECT_EQ(Explanation(param.get(), m::Op().WithOneUse()), "HloInstruction is used 2 times by its user, but is expected to be " - "used just once: add = f32[] add(f32[] p0, f32[] p0)\n" + "used just once: add = f32[] add(p0, p0)\n" "in p0 = f32[] parameter(0)"); } @@ -1116,7 +1116,7 @@ TEST_F(PatternMatcherTest, MatchSingleUserOnlyUnaryOpTwoUsers) { EXPECT_EQ(Explanation(bitcast.get(), m::Bitcast(m::Op()), /*single_user_only=*/true), "Operand 0 of HloInstruction has 2 users. Expected 1.\nin bitcast " - "= f32[1]{0} bitcast(f32[] p)"); + "= f32[1]{0} bitcast(p)"); } TEST_F(PatternMatcherTest, MatchSingleUserOnlyBinaryOpOneUser) { @@ -1153,7 +1153,7 @@ TEST_F(PatternMatcherTest, MatchSingleUserOnlyBinaryOpTwoUsers) { EXPECT_EQ(Explanation(mul.get(), m::Multiply(m::Op(), m::Op()), /*single_user_only=*/true), "Operand 1 of HloInstruction has 2 users. Expected 1.\nin mul = " - "f32[] multiply(f32[] p1, f32[] p0)"); + "f32[] multiply(p1, p0)"); EXPECT_FALSE(MatchSingleUserOnly(add.get(), m::Add(m::Op(), m::Op()))); EXPECT_FALSE( @@ -1161,7 +1161,7 @@ TEST_F(PatternMatcherTest, MatchSingleUserOnlyBinaryOpTwoUsers) { EXPECT_EQ(Explanation(add.get(), m::Add(m::Op(), m::Op()), /*single_user_only=*/true), "Operand 0 of HloInstruction has 2 users. Expected 1.\nin add = " - "f32[] add(f32[] p0, f32[] p0)"); + "f32[] add(p0, p0)"); } TEST_F(PatternMatcherTest, MatchSingleUserOnlyBinaryOpTwoUsersLowerLevel) { @@ -1195,7 +1195,7 @@ TEST_F(PatternMatcherTest, MatchSingleUserOnlyBinaryOpTwoUsersLowerLevel) { EXPECT_EQ(Explanation(add.get(), m::Add(m::Op(), m::Op()), /*single_user_only=*/true), "Operand 0 of HloInstruction has 2 users. Expected 1.\nin add = " - "f32[] add(f32[] p0, f32[] p0)"); + "f32[] add(p0, p0)"); } TEST_F(PatternMatcherTest, Comparison) { @@ -1237,8 +1237,7 @@ TEST_F(PatternMatcherTest, Comparison) { " * which has exactly one user (but possibly is used " "multiple times by that instruction)", "HloInstruction is not comparison NE\n" - "in compare = f32[1]{0} compare(f32[1]{0} param.0, f32[1]{0} param.1), " - "direction=EQ"); + "in compare = f32[1]{0} compare(param.0, param.1), direction=EQ"); } TEST_F(PatternMatcherTest, ConvDnums) { @@ -1403,9 +1402,8 @@ TEST_F(PatternMatcherTest, TestWithContractingDims) { " * with opcode dot AND\n" " * with lhs_contracting_dims {1} and rhs_contracting_dims {0,1}", "rhs_contracting_dimensions {0} don't match expected {0,1}\n" - "in dot1 = f32[2048,33708]{1,0} dot(f32[2048,1024]{1,0} param1, " - "f32[1024,33708]{1,0} param2), lhs_contracting_dims={1}, " - "rhs_contracting_dims={0}"); + "in dot1 = f32[2048,33708]{1,0} dot(param1, param2), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); } TEST_F(PatternMatcherTest, TestWithReplicaGroups) { @@ -1435,7 +1433,7 @@ TEST_F(PatternMatcherTest, TestWithReplicaGroups) { " * with replica_group {{1,0},{3,2}}", "replica_group {{0,1},{2,3}} don't match expected with replica_group " "{{1,0},{3,2}}\n" - "in all-reduce = f32[128,32]{0,1} all-reduce(f32[128,32]{0,1} input), " + "in all-reduce = f32[128,32]{0,1} all-reduce(input), " "replica_groups={{0,1},{2,3}}, to_apply=add"); } @@ -1490,8 +1488,7 @@ TEST_F(PatternMatcherTest, TestWithControlDeps) { "an HloInstruction with control predecessors {mul} and control " "successors {div}", "HloInstruction expected to have control successors {div} but has {}\n" - "in div = f32[4]{0} divide(f32[4]{0} p0, f32[4]{0} p1), " - "control-predecessors={mul}"); + "in div = f32[4]{0} divide(p0, p1), control-predecessors={mul}"); } } // namespace diff --git a/xla/service/scatter_expander_test.cc b/xla/service/scatter_expander_test.cc index 664f0112068fb8..587780a32ff3cf 100644 --- a/xla/service/scatter_expander_test.cc +++ b/xla/service/scatter_expander_test.cc @@ -170,38 +170,38 @@ HloModule TensorFlowScatter const std::string expected = R"( //CHECK: (s32[], s32[5,3,2,2], s32[30], s32[30,2])) -> (s32[], s32[5,3,2,2], s32[30], s32[30,2]) { //CHECK: %[[PARAM:.*]] = (s32[], s32[5,3,2,2], s32[30], s32[30,2]) parameter(0) - //CHECK: %[[I:.*]] = s32[] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=0 + //CHECK: %[[I:.*]] = s32[] get-tuple-element(%[[PARAM]]), index=0 //CHECK: %[[CONSTANT1:.*]] = s32[] constant(1) - //CHECK: %[[I_PLUS_1:.*]] = s32[] add(s32[] %[[I]], s32[] %[[CONSTANT1]]) - //CHECK: %[[OPERAND:.*]] = s32[5,3,2,2] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=1 + //CHECK: %[[I_PLUS_1:.*]] = s32[] add(%[[I]], %[[CONSTANT1]]) + //CHECK: %[[OPERAND:.*]] = s32[5,3,2,2] get-tuple-element(%[[PARAM]]), index=1 //CHECK: %[[CONSTANT0:.*]] = s32[] constant(0) - //CHECK: %[[OPERAND_INDICES_LOWER_BOUND:.*]] = s32[4] broadcast(s32[] %[[CONSTANT0]]) + //CHECK: %[[OPERAND_INDICES_LOWER_BOUND:.*]] = s32[4] broadcast(%[[CONSTANT0]]) //CHECK: %[[CONSTANT5:.*]] = s32[] constant(5) - //CHECK: %[[REMAINDER:.*]] = s32[] remainder(s32[] %[[I]], s32[] %[[CONSTANT5]]) - //CHECK: %[[BD2:.*]] = s32[1] broadcast(s32[] %[[REMAINDER]]) - //CHECK: %[[START_INDICES:.*]] = s32[30] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=2 - //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(s32[] %[[I]]) - //CHECK: %[[START_INDICES_INDEX_RAW:.*]] = s32[1] slice(s32[1] %[[I_1D_1]]) - //CHECK: %[[START_INDICES_INDEX:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_RAW]]) - //CHECK: %[[INDEX_VECTOR:.*]] = s32[1] dynamic-slice(s32[30] %[[START_INDICES]], s32[] %[[START_INDICES_INDEX]]) - - //CHECK: %[[SCATTER_INDEX:.*]] = s32[1] slice(s32[1] %[[INDEX_VECTOR]]) + //CHECK: %[[REMAINDER:.*]] = s32[] remainder(%[[I]], %[[CONSTANT5]]) + //CHECK: %[[BD2:.*]] = s32[1] broadcast(%[[REMAINDER]]) + //CHECK: %[[START_INDICES:.*]] = s32[30] get-tuple-element(%[[PARAM]]), index=2 + //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(%[[I]]) + //CHECK: %[[START_INDICES_INDEX_RAW:.*]] = s32[1] slice(%[[I_1D_1]]) + //CHECK: %[[START_INDICES_INDEX:.*]] = s32[] reshape(%[[START_INDICES_INDEX_RAW]]) + //CHECK: %[[INDEX_VECTOR:.*]] = s32[1] dynamic-slice(%[[START_INDICES]], %[[START_INDICES_INDEX]]) + + //CHECK: %[[SCATTER_INDEX:.*]] = s32[1] slice(%[[INDEX_VECTOR]]) //CHECK: %[[CONSTANT0_2:.*]] = s32[1] constant({0}) - //CHECK: %[[BD_0_1:.*]] = s32[] divide(s32[] %[[I]], s32[] %[[CONSTANT5]]) + //CHECK: %[[BD_0_1:.*]] = s32[] divide(%[[I]], %[[CONSTANT5]]) //CHECK: %[[CONSTANT3:.*]] = s32[] constant(3) - //CHECK: %[[BD0_RAW:.*]] = s32[] divide(s32[] %[[BD_0_1]], s32[] %[[CONSTANT3]]) - //CHECK: %[[BD0:.*]] = s32[1] broadcast(s32[] %[[BD0_RAW]]) - //CHECK: %[[OPERAND_INDICES:.*]] = s32[4] concatenate(s32[1] %[[BD2]], s32[1] %[[SCATTER_INDEX]], s32[1] %[[CONSTANT0_2]], s32[1] %[[BD0]]) - //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[0:1]} - //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D0_RAW]]) - //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[1:2]} - //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D1_RAW]]) - //CHECK: %[[OPERAND_INDEX_D2_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[2:3]} - //CHECK: %[[OPERAND_INDEX_D2:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D2_RAW]]) - //CHECK: %[[OPERAND_INDEX_D3_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[3:4]} - //CHECK: %[[OPERAND_INDEX_D3:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D3_RAW]]) - //CHECK: %{{.*}} = s32[1,1,2,1] dynamic-slice(s32[5,3,2,2] %[[OPERAND]], s32[] %[[OPERAND_INDEX_D0]], s32[] %[[OPERAND_INDEX_D1]], s32[] %[[OPERAND_INDEX_D2]], s32[] %[[OPERAND_INDEX_D3]]) + //CHECK: %[[BD0_RAW:.*]] = s32[] divide(%[[BD_0_1]], %[[CONSTANT3]]) + //CHECK: %[[BD0:.*]] = s32[1] broadcast(%[[BD0_RAW]]) + //CHECK: %[[OPERAND_INDICES:.*]] = s32[4] concatenate(%[[BD2]], %[[SCATTER_INDEX]], %[[CONSTANT0_2]], %[[BD0]]) + //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(%[[OPERAND_INDICES]]), slice={[0:1]} + //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(%[[OPERAND_INDEX_D0_RAW]]) + //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(%[[OPERAND_INDICES]]), slice={[1:2]} + //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(%[[OPERAND_INDEX_D1_RAW]]) + //CHECK: %[[OPERAND_INDEX_D2_RAW:.*]] = s32[1] slice(%[[OPERAND_INDICES]]), slice={[2:3]} + //CHECK: %[[OPERAND_INDEX_D2:.*]] = s32[] reshape(%[[OPERAND_INDEX_D2_RAW]]) + //CHECK: %[[OPERAND_INDEX_D3_RAW:.*]] = s32[1] slice(%[[OPERAND_INDICES]]), slice={[3:4]} + //CHECK: %[[OPERAND_INDEX_D3:.*]] = s32[] reshape(%[[OPERAND_INDEX_D3_RAW]]) + //CHECK: %{{.*}} = s32[1,1,2,1] dynamic-slice(%[[OPERAND]], %[[OPERAND_INDEX_D0]], %[[OPERAND_INDEX_D1]], %[[OPERAND_INDEX_D2]], %[[OPERAND_INDEX_D3]]) )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, diff --git a/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/xla/service/spmd/shardy/shardy_xla_pass_test.cc index 1260a70c53821f..4f9f3a22dd160a 100644 --- a/xla/service/spmd/shardy/shardy_xla_pass_test.cc +++ b/xla/service/spmd/shardy/shardy_xla_pass_test.cc @@ -688,7 +688,7 @@ ENTRY %main.0 (Arg_0.0: s64[2]) -> s64[2] { const char* const expected = R"( // CHECK: ENTRY %main.3 (Arg_0.1: s64[2]) -> s64[2] { // CHECK-NEXT: ROOT %Arg_0.1 = s64[2] parameter(0) - // CHECK-NEXT{LITERAL}: %custom-call.2 = () custom-call(s64[2] %Arg_0.1), custom_call_target="xla_ffi_python_cpu_callback", + // CHECK-NEXT{LITERAL}: %custom-call.2 = () custom-call(%Arg_0.1), custom_call_target="xla_ffi_python_cpu_callback", // CHECK-SAME{LITERAL}: operand_layout_constraints={s64[2]{0}}, custom_call_has_side_effect=true, api_version=API_VERSION_TYPED_FFI, // CHECK-SAME{LITERAL}: sharding={{maximal device=0}} )"; diff --git a/xla/tools/hlo_opt/tests/cpu_hlo_pass.hlo b/xla/tools/hlo_opt/tests/cpu_hlo_pass.hlo index 14b5afb1e2e65a..f35aed95cf20ea 100644 --- a/xla/tools/hlo_opt/tests/cpu_hlo_pass.hlo +++ b/xla/tools/hlo_opt/tests/cpu_hlo_pass.hlo @@ -6,10 +6,10 @@ // CHECK-LABEL: ENTRY %DotOperationFusion_TransposeFusion // CHECK-NEXT: %[[arg0:[^ ]+]] = f32[256,1]{1,0} parameter(0) -// CHECK-NEXT: %[[transpose:[^ ]+]] = f32[1,256]{1,0} transpose(f32[256,1]{1,0} %[[arg0]]), dimensions={1,0} +// CHECK-NEXT: %[[transpose:[^ ]+]] = f32[1,256]{1,0} transpose(%[[arg0]]), dimensions={1,0} // CHECK-NEXT: %[[arg1:[^ ]+]] = f32[256,1024]{1,0} parameter(1) -// CHECK-NEXT: %[[exponential:[^ ]+]] = f32[256,1024]{1,0} exponential(f32[256,1024]{1,0} %[[arg1]]) -// CHECK-NEXT: ROOT %[[dot:[^ ]+]] = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %[[transpose]], f32[256,1024]{1,0} %[[exponential]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} +// CHECK-NEXT: %[[exponential:[^ ]+]] = f32[256,1024]{1,0} exponential(%[[arg1]]) +// CHECK-NEXT: ROOT %[[dot:[^ ]+]] = f32[1,1024]{1,0} dot(%[[transpose]], %[[exponential]]), lhs_contracting_dims={1}, rhs_contracting_dims={0} HloModule DotOperationFusion_TransposeFusion diff --git a/xla/tools/hlo_opt/tests/gpu_hlo_pass.hlo b/xla/tools/hlo_opt/tests/gpu_hlo_pass.hlo index de976957809534..8df2cdf621eb15 100644 --- a/xla/tools/hlo_opt/tests/gpu_hlo_pass.hlo +++ b/xla/tools/hlo_opt/tests/gpu_hlo_pass.hlo @@ -4,35 +4,35 @@ HloModule Algorithm3xBF16 // CHECK-LABEL: HloModule Algorithm3xBF16, entry_computation_layout={(f32[128,128]{1,0}, f32[128,128]{1,0})->f32[128,128]{1,0}} // CHECK: ENTRY %e (p0: f32[128,128], p1: f32[128,128]) -> f32[128,128] { // CHECK-NEXT: %[[constant2:.*]] = f32[] constant(inf) -// CHECK-NEXT: %[[broadcast2:.*]] = f32[128,128]{1,0} broadcast(f32[] %[[constant2]]), dimensions={} +// CHECK-NEXT: %[[broadcast2:.*]] = f32[128,128]{1,0} broadcast(%[[constant2]]), dimensions={} // CHECK-NEXT: %[[p0:.*]] = f32[128,128]{1,0} parameter(0) -// CHECK-NEXT: %[[bitcastconvert:.*]] = u32[128,128]{1,0} bitcast-convert(f32[128,128]{1,0} %[[p0]]) +// CHECK-NEXT: %[[bitcastconvert:.*]] = u32[128,128]{1,0} bitcast-convert(%[[p0]]) // CHECK-NEXT: %[[constant:.*]] = u32[] constant(4294901760) -// CHECK-NEXT: %[[broadcast:.*]] = u32[128,128]{1,0} broadcast(u32[] %[[constant]]), dimensions={} -// CHECK-NEXT: %[[and:.*]] = u32[128,128]{1,0} and(u32[128,128]{1,0} %[[bitcastconvert]], u32[128,128]{1,0} %[[broadcast]]) -// CHECK-NEXT: %[[bitcastconvert1:.*]] = f32[128,128]{1,0} bitcast-convert(u32[128,128]{1,0} %[[and]]) -// CHECK-NEXT: %[[subtract:.*]] = f32[128,128]{1,0} subtract(f32[128,128]{1,0} %[[p0]], f32[128,128]{1,0} %[[bitcastconvert1]]) -// CHECK-NEXT: %[[convert1:.*]] = bf16[128,128]{1,0} convert(f32[128,128]{1,0} %[[subtract]]) +// CHECK-NEXT: %[[broadcast:.*]] = u32[128,128]{1,0} broadcast(%[[constant]]), dimensions={} +// CHECK-NEXT: %[[and:.*]] = u32[128,128]{1,0} and(%[[bitcastconvert]], %[[broadcast]]) +// CHECK-NEXT: %[[bitcastconvert1:.*]] = f32[128,128]{1,0} bitcast-convert(%[[and]]) +// CHECK-NEXT: %[[subtract:.*]] = f32[128,128]{1,0} subtract(%[[p0]], %[[bitcastconvert1]]) +// CHECK-NEXT: %[[convert1:.*]] = bf16[128,128]{1,0} convert(%[[subtract]]) // CHECK-NEXT: %[[p1:.*]] = f32[128,128]{1,0} parameter(1) -// CHECK-NEXT: %[[bitcastconvert2:.*]] = u32[128,128]{1,0} bitcast-convert(f32[128,128]{1,0} %[[p1]]) +// CHECK-NEXT: %[[bitcastconvert2:.*]] = u32[128,128]{1,0} bitcast-convert(%[[p1]]) // CHECK-NEXT: %[[constant1:.*]] = u32[] constant(4294901760) -// CHECK-NEXT: %[[broadcast1:.*]] = u32[128,128]{1,0} broadcast(u32[] %[[constant1]]), dimensions={} -// CHECK-NEXT: %[[and1:.*]] = u32[128,128]{1,0} and(u32[128,128]{1,0} %[[bitcastconvert2]], u32[128,128]{1,0} %[[broadcast1]]) -// CHECK-NEXT: %[[bitcastconvert3:.*]] = f32[128,128]{1,0} bitcast-convert(u32[128,128]{1,0} %[[and1]]) -// CHECK-NEXT: %[[convert2:.*]] = bf16[128,128]{1,0} convert(f32[128,128]{1,0} %[[bitcastconvert3]]) -// CHECK-NEXT: %[[dot1:.*]] = f32[128,128]{1,0} dot(bf16[128,128]{1,0} %[[convert1]], bf16[128,128]{1,0} %[[convert2]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32 -// CHECK-NEXT: %[[convert:.*]] = bf16[128,128]{1,0} convert(f32[128,128]{1,0} %[[bitcastconvert1]]) -// CHECK-NEXT: %[[subtract1:.*]] = f32[128,128]{1,0} subtract(f32[128,128]{1,0} %[[p1]], f32[128,128]{1,0} %[[bitcastconvert3]]) -// CHECK-NEXT: %[[convert3:.*]] = bf16[128,128]{1,0} convert(f32[128,128]{1,0} %[[subtract1]]) -// CHECK-NEXT: %[[dot2:.*]] = f32[128,128]{1,0} dot(bf16[128,128]{1,0} %[[convert]], bf16[128,128]{1,0} %[[convert3]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32 -// CHECK-NEXT: %[[add:.*]] = f32[128,128]{1,0} add(f32[128,128]{1,0} %[[dot1]], f32[128,128]{1,0} %[[dot2]]) -// CHECK-NEXT: %[[abs:.*]] = f32[128,128]{1,0} abs(f32[128,128]{1,0} %[[add]]) -// CHECK-NEXT: %[[compare:.*]] = pred[128,128]{1,0} compare(f32[128,128]{1,0} %[[broadcast2]], f32[128,128]{1,0} %[[abs]]), direction=GE +// CHECK-NEXT: %[[broadcast1:.*]] = u32[128,128]{1,0} broadcast(%[[constant1]]), dimensions={} +// CHECK-NEXT: %[[and1:.*]] = u32[128,128]{1,0} and(%[[bitcastconvert2]], %[[broadcast1]]) +// CHECK-NEXT: %[[bitcastconvert3:.*]] = f32[128,128]{1,0} bitcast-convert(%[[and1]]) +// CHECK-NEXT: %[[convert2:.*]] = bf16[128,128]{1,0} convert(%[[bitcastconvert3]]) +// CHECK-NEXT: %[[dot1:.*]] = f32[128,128]{1,0} dot(%[[convert1]], %[[convert2]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32 +// CHECK-NEXT: %[[convert:.*]] = bf16[128,128]{1,0} convert(%[[bitcastconvert1]]) +// CHECK-NEXT: %[[subtract1:.*]] = f32[128,128]{1,0} subtract(%[[p1]], %[[bitcastconvert3]]) +// CHECK-NEXT: %[[convert3:.*]] = bf16[128,128]{1,0} convert(%[[subtract1]]) +// CHECK-NEXT: %[[dot2:.*]] = f32[128,128]{1,0} dot(%[[convert]], %[[convert3]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32 +// CHECK-NEXT: %[[add:.*]] = f32[128,128]{1,0} add(%[[dot1]], %[[dot2]]) +// CHECK-NEXT: %[[abs:.*]] = f32[128,128]{1,0} abs(%[[add]]) +// CHECK-NEXT: %[[compare:.*]] = pred[128,128]{1,0} compare(%[[broadcast2]], %[[abs]]), direction=GE // CHECK-NEXT: %[[constant3:.*]] = f32[] constant(0) -// CHECK-NEXT: %[[broadcast3:.*]] = f32[128,128]{1,0} broadcast(f32[] %[[constant3]]), dimensions={} -// CHECK-NEXT: %[[select:.*]] = f32[128,128]{1,0} select(pred[128,128]{1,0} %[[compare]], f32[128,128]{1,0} %[[add]], f32[128,128]{1,0} %[[broadcast3]]) -// CHECK-NEXT: %[[dot3:.*]] = f32[128,128]{1,0} dot(bf16[128,128]{1,0} %[[convert]], bf16[128,128]{1,0} %[[convert2]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32 -// CHECK-NEXT: ROOT %[[add1:.*]] = f32[128,128]{1,0} add(f32[128,128]{1,0} %[[select]], f32[128,128]{1,0} %[[dot3]]) +// CHECK-NEXT: %[[broadcast3:.*]] = f32[128,128]{1,0} broadcast(%[[constant3]]), dimensions={} +// CHECK-NEXT: %[[select:.*]] = f32[128,128]{1,0} select(%[[compare]], %[[add]], %[[broadcast3]]) +// CHECK-NEXT: %[[dot3:.*]] = f32[128,128]{1,0} dot(%[[convert]], %[[convert2]]), lhs_contracting_dims={1}, rhs_contracting_dims={0}, algorithm=dot_bf16_bf16_f32 +// CHECK-NEXT: ROOT %[[add1:.*]] = f32[128,128]{1,0} add(%[[select]], %[[dot3]]) // CHECK-NEXT: } ENTRY e { p0 = f32[128,128] parameter(0) diff --git a/xla/tools/tests/hlo_expand_test.cc b/xla/tools/tests/hlo_expand_test.cc index 7a92bccc734036..33db7cf4bfd217 100644 --- a/xla/tools/tests/hlo_expand_test.cc +++ b/xla/tools/tests/hlo_expand_test.cc @@ -66,7 +66,7 @@ TEST_F(HloExpandTest, CholeskyHlo) { ENTRY %main.3 () -> f64[3,3] { %constant.1 = f64[3,3]{1,0} constant({ { 1, 2, 3 }, { 2, 20, 26 }, { 3, 26, 70 } }) - ROOT %cholesky.2 = f64[3,3]{1,0} cholesky(f64[3,3]{1,0} %constant.1), lower=true + ROOT %cholesky.2 = f64[3,3]{1,0} cholesky(%constant.1), lower=true })"; EXPECT_TRUE(exited_normally_); @@ -85,16 +85,16 @@ TEST_F(HloExpandTest, SpmdHlo) { ENTRY %entry_spmd (param: f32[24,64], param.1: f32[39296,64]) -> f32[24,19648] { %param = f32[24,64]{1,0} parameter(0), sharding={replicated} - %lhs.copy.1 = f32[24,64]{1,0} copy(f32[24,64]{1,0} %param) + %lhs.copy.1 = f32[24,64]{1,0} copy(%param) %param.1 = f32[39296,64]{1,0} parameter(1), sharding={replicated} %constant = s32[2]{0} constant({0, 19648}) %partition-id = u32[] partition-id() - %dynamic-slice = s32[1]{0} dynamic-slice(s32[2]{0} %constant, u32[] %partition-id), dynamic_slice_sizes={1} - %reshape = s32[] reshape(s32[1]{0} %dynamic-slice) + %dynamic-slice = s32[1]{0} dynamic-slice(%constant, %partition-id), dynamic_slice_sizes={1} + %reshape = s32[] reshape(%dynamic-slice) %constant.1 = s32[] constant(0) - %dynamic-slice.1 = f32[19648,64]{1,0} dynamic-slice(f32[39296,64]{1,0} %param.1, s32[] %reshape, s32[] %constant.1), dynamic_slice_sizes={19648,64} - %rhs.copy.1 = f32[19648,64]{1,0} copy(f32[19648,64]{1,0} %dynamic-slice.1) - ROOT %dot.1 = f32[24,19648]{1,0} dot(f32[24,64]{1,0} %lhs.copy.1, f32[19648,64]{1,0} %rhs.copy.1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + %dynamic-slice.1 = f32[19648,64]{1,0} dynamic-slice(%param.1, %reshape, %constant.1), dynamic_slice_sizes={19648,64} + %rhs.copy.1 = f32[19648,64]{1,0} copy(%dynamic-slice.1) + ROOT %dot.1 = f32[24,19648]{1,0} dot(%lhs.copy.1, %rhs.copy.1), lhs_contracting_dims={1}, rhs_contracting_dims={1} })"; EXPECT_TRUE(exited_normally_);