Skip to content

Commit

Permalink
Bring back expanded error message.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729187353
  • Loading branch information
toli-y authored and Google-ML-Automation committed Feb 28, 2025
1 parent 517396f commit a58349b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 36 deletions.
1 change: 1 addition & 0 deletions xla/hlo/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ cc_library(
"//xla/tsl/platform:status",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
33 changes: 27 additions & 6 deletions xla/hlo/ir/hlo_casting_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,49 @@ limitations under the License.
#ifndef XLA_HLO_IR_HLO_CASTING_UTILS_H_
#define XLA_HLO_IR_HLO_CASTING_UTILS_H_

#include <string>

#include "absl/base/config.h"
#include "absl/log/check.h"
#include "absl/strings/str_format.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/tsl/platform/logging.h"

namespace xla {

namespace cast_internal {

template <typename T>
inline const char* TypeName(T* input = nullptr) {
#ifdef ABSL_INTERNAL_HAS_RTTI
return (input != nullptr) ? typeid(*input).name() : typeid(T).name();
#else
return "unknown (no RTTI)";
#endif
}

template <typename T>
inline std::string WrongCastError(const HloInstruction* instr) {
return absl::StrFormat(
"HloInstruction '%s' is of type '%s' and cannot be downcasted to '%s.'",
instr->name(), TypeName(instr), TypeName<T>());
}
} // namespace cast_internal

// Downcasts a const HloInstruction pointer. Dies if argument is nullptr or
// TargetClass::ClassOf() does not match. Similar to LLVM's cast.
template <typename T>
const T* Cast(const HloInstruction* instr) {
CHECK(instr != nullptr);
CHECK(T::ClassOf(instr));
CHECK(T::ClassOf(instr)) << cast_internal::WrongCastError<T>(instr);
return tsl::down_cast<const T*>(instr);
}

// Downcasts a non-const HloInstruction pointer. Dies if argument is nullptr or
// TargetClass::ClassOf() does not match. Similar to LLVM's cast.
template <typename T>
T* Cast(HloInstruction* instr) {
CHECK(instr != nullptr);
CHECK(T::ClassOf(instr));
return tsl::down_cast<T*>(instr);
return const_cast<T*>(Cast<T>(const_cast<const HloInstruction*>(instr)));
}

// Downcasts a const HloInstruction pointer or returns nullptr if
Expand All @@ -58,8 +80,7 @@ const T* DynCast(const HloInstruction* i) {
// to LLVM's dyn_cast.
template <typename T>
T* DynCast(HloInstruction* i) {
CHECK(i != nullptr);
return !T::ClassOf(i) ? nullptr : tsl::down_cast<T*>(i);
return const_cast<T*>(DynCast<T>(const_cast<const HloInstruction*>(i)));
}

} // namespace xla
Expand Down
60 changes: 30 additions & 30 deletions xla/hlo/ir/hlo_casting_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,50 @@ std::unique_ptr<HloInstruction> CreateCP() {
Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
std::unique_ptr<HloInstruction> p0 =
HloInstruction::CreateParameter(0, shape, "param");
return HloInstruction::CreateCollectivePermute(shape, p0.get(), {{0, 1}}, 1);
std::unique_ptr<HloInstruction> cp =
HloInstruction::CreateCollectivePermute(shape, p0.get(), {{0, 1}}, 1);
cp->SetAndSanitizeName("test_cp");
return cp;
}

const char* kWrongCastError =
".*ClassOf.*'test_cp'.*HloCollectivePermuteInstruction.*"
"HloAllReduceInstruction.*";
const char* kNullptrError = ".*nullptr.*";

TEST(HloCastingUtilsTest, Cast) {
std::unique_ptr<HloInstruction> cp = CreateCP();
HloCollectivePermuteInstruction* casted =
Cast<HloCollectivePermuteInstruction>(cp.get());
EXPECT_NE(casted, nullptr);

std::unique_ptr<const HloInstruction> const_cp = CreateCP();
const HloCollectivePermuteInstruction* const_casted =
Cast<const HloCollectivePermuteInstruction>(const_cp.get());
EXPECT_NE(const_casted, nullptr);
EXPECT_NE(Cast<HloCollectivePermuteInstruction>(cp.get()), nullptr);
EXPECT_DEATH(Cast<HloAllReduceInstruction>(cp.get()), kWrongCastError);
cp.reset();
EXPECT_DEATH(Cast<HloCollectivePermuteInstruction>(cp.get()), kNullptrError);
}

TEST(HloCastingUtilsTest, CastDeath) {
std::unique_ptr<HloInstruction> cp = CreateCP();
// wrong type
EXPECT_DEATH(Cast<HloAllReduceInstruction>(cp.get()), ".*ClassOf.*");
// nullptr
TEST(HloCastingUtilsTest, CastConst) {
std::unique_ptr<const HloInstruction> cp = CreateCP();
EXPECT_NE(Cast<const HloCollectivePermuteInstruction>(cp.get()), nullptr);
EXPECT_DEATH(Cast<const HloAllReduceInstruction>(cp.get()), kWrongCastError);
cp.reset();
EXPECT_DEATH(Cast<HloCollectivePermuteInstruction>(cp.get()), ".*nullptr.*");
EXPECT_DEATH(Cast<const HloCollectivePermuteInstruction>(cp.get()),
kNullptrError);
}

TEST(HloCastingUtilsTest, DynCast) {
std::unique_ptr<HloInstruction> cp = CreateCP();
HloCollectivePermuteInstruction* casted =
DynCast<HloCollectivePermuteInstruction>(cp.get());
EXPECT_NE(casted, nullptr);

std::unique_ptr<const HloInstruction> const_cp = CreateCP();
const HloCollectivePermuteInstruction* const_casted =
DynCast<const HloCollectivePermuteInstruction>(const_cp.get());
EXPECT_NE(const_casted, nullptr);

// wrong type
EXPECT_EQ(DynCast<HloAllReduceInstruction>(CreateCP().get()), nullptr);
EXPECT_NE(DynCast<HloCollectivePermuteInstruction>(cp.get()), nullptr);
EXPECT_EQ(DynCast<HloAllReduceInstruction>(cp.get()), nullptr);
cp.reset();
EXPECT_DEATH(DynCast<HloCollectivePermuteInstruction>(cp.get()),
kNullptrError);
}

TEST(HloCastingUtilsTest, DynCastDeath) {
std::unique_ptr<HloInstruction> cp = CreateCP();
TEST(HloCastingUtilsTest, DynCastConst) {
std::unique_ptr<const HloInstruction> cp = CreateCP();
EXPECT_NE(DynCast<const HloCollectivePermuteInstruction>(cp.get()), nullptr);
EXPECT_EQ(DynCast<const HloAllReduceInstruction>(cp.get()), nullptr);
cp.reset();
EXPECT_DEATH(DynCast<HloCollectivePermuteInstruction>(cp.get()),
".*nullptr.*");
EXPECT_DEATH(DynCast<const HloCollectivePermuteInstruction>(cp.get()),
kNullptrError);
}

void BM_Cast(benchmark::State& state) {
Expand Down

0 comments on commit a58349b

Please sign in to comment.