diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index 75821cdca7244..eb179affe6d1c 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -97,6 +97,7 @@ cc_library( "//xla/tsl/platform:status", "//xla/tsl/platform:statusor", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xla/hlo/ir/hlo_casting_utils.h b/xla/hlo/ir/hlo_casting_utils.h index a4021c400ab39..db655eaaf9ab4 100644 --- a/xla/hlo/ir/hlo_casting_utils.h +++ b/xla/hlo/ir/hlo_casting_utils.h @@ -21,17 +21,41 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_CASTING_UTILS_H_ #define XLA_HLO_IR_HLO_CASTING_UTILS_H_ +#include + +#include "absl/base/config.h" +#include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/tsl/platform/logging.h" namespace xla { +namespace cast_internal { + +template +inline const char* TypeName(T* input = nullptr) { +#ifdef ABSL_INTERNAL_HAS_RTTI + return (input != nullptr) ? typeid(*input).name() : typeid(T).name(); +#else + return "unknown (no RTTI)"; +#endif +} + +template +inline std::string WrongCastError(const HloInstruction* instr) { + return absl::StrFormat( + "HloInstruction '%s' is of type '%s' and cannot be downcasted to '%s.'", + instr->name(), TypeName(instr), TypeName()); +} +} // namespace cast_internal + // Downcasts a const HloInstruction pointer. Dies if argument is nullptr or // TargetClass::ClassOf() does not match. Similar to LLVM's cast. template const T* Cast(const HloInstruction* instr) { CHECK(instr != nullptr); - CHECK(T::ClassOf(instr)); + CHECK(T::ClassOf(instr)) << cast_internal::WrongCastError(instr); return tsl::down_cast(instr); } @@ -39,9 +63,7 @@ const T* Cast(const HloInstruction* instr) { // TargetClass::ClassOf() does not match. Similar to LLVM's cast. template T* Cast(HloInstruction* instr) { - CHECK(instr != nullptr); - CHECK(T::ClassOf(instr)); - return tsl::down_cast(instr); + return const_cast(Cast(const_cast(instr))); } // Downcasts a const HloInstruction pointer or returns nullptr if @@ -58,8 +80,7 @@ const T* DynCast(const HloInstruction* i) { // to LLVM's dyn_cast. template T* DynCast(HloInstruction* i) { - CHECK(i != nullptr); - return !T::ClassOf(i) ? nullptr : tsl::down_cast(i); + return const_cast(DynCast(const_cast(i))); } } // namespace xla diff --git a/xla/hlo/ir/hlo_casting_utils_test.cc b/xla/hlo/ir/hlo_casting_utils_test.cc index 3e7ba3544221a..3eba3d4bc6057 100644 --- a/xla/hlo/ir/hlo_casting_utils_test.cc +++ b/xla/hlo/ir/hlo_casting_utils_test.cc @@ -32,50 +32,50 @@ std::unique_ptr CreateCP() { Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); std::unique_ptr p0 = HloInstruction::CreateParameter(0, shape, "param"); - return HloInstruction::CreateCollectivePermute(shape, p0.get(), {{0, 1}}, 1); + std::unique_ptr cp = + HloInstruction::CreateCollectivePermute(shape, p0.get(), {{0, 1}}, 1); + cp->SetAndSanitizeName("test_cp"); + return cp; } +const char* kWrongCastError = + ".*ClassOf.*'test_cp'.*HloCollectivePermuteInstruction.*" + "HloAllReduceInstruction.*"; +const char* kNullptrError = ".*nullptr.*"; + TEST(HloCastingUtilsTest, Cast) { std::unique_ptr cp = CreateCP(); - HloCollectivePermuteInstruction* casted = - Cast(cp.get()); - EXPECT_NE(casted, nullptr); - - std::unique_ptr const_cp = CreateCP(); - const HloCollectivePermuteInstruction* const_casted = - Cast(const_cp.get()); - EXPECT_NE(const_casted, nullptr); + EXPECT_NE(Cast(cp.get()), nullptr); + EXPECT_DEATH(Cast(cp.get()), kWrongCastError); + cp.reset(); + EXPECT_DEATH(Cast(cp.get()), kNullptrError); } -TEST(HloCastingUtilsTest, CastDeath) { - std::unique_ptr cp = CreateCP(); - // wrong type - EXPECT_DEATH(Cast(cp.get()), ".*ClassOf.*"); - // nullptr +TEST(HloCastingUtilsTest, CastConst) { + std::unique_ptr cp = CreateCP(); + EXPECT_NE(Cast(cp.get()), nullptr); + EXPECT_DEATH(Cast(cp.get()), kWrongCastError); cp.reset(); - EXPECT_DEATH(Cast(cp.get()), ".*nullptr.*"); + EXPECT_DEATH(Cast(cp.get()), + kNullptrError); } TEST(HloCastingUtilsTest, DynCast) { std::unique_ptr cp = CreateCP(); - HloCollectivePermuteInstruction* casted = - DynCast(cp.get()); - EXPECT_NE(casted, nullptr); - - std::unique_ptr const_cp = CreateCP(); - const HloCollectivePermuteInstruction* const_casted = - DynCast(const_cp.get()); - EXPECT_NE(const_casted, nullptr); - - // wrong type - EXPECT_EQ(DynCast(CreateCP().get()), nullptr); + EXPECT_NE(DynCast(cp.get()), nullptr); + EXPECT_EQ(DynCast(cp.get()), nullptr); + cp.reset(); + EXPECT_DEATH(DynCast(cp.get()), + kNullptrError); } -TEST(HloCastingUtilsTest, DynCastDeath) { - std::unique_ptr cp = CreateCP(); +TEST(HloCastingUtilsTest, DynCastConst) { + std::unique_ptr cp = CreateCP(); + EXPECT_NE(DynCast(cp.get()), nullptr); + EXPECT_EQ(DynCast(cp.get()), nullptr); cp.reset(); - EXPECT_DEATH(DynCast(cp.get()), - ".*nullptr.*"); + EXPECT_DEATH(DynCast(cp.get()), + kNullptrError); } void BM_Cast(benchmark::State& state) {