Skip to content

Commit

Permalink
Simplify dtype registration logic
Browse files Browse the repository at this point in the history
The checks for already registered types were added in the precursor to
ml_dtypes, when both JAX and tensorflow registered their own copies of
bfloat16. Since that is no longer a concern, we can remove this logic.
  • Loading branch information
jakevdp committed Dec 20, 2023
1 parent 29edcb3 commit 5680083
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 54 deletions.
19 changes: 3 additions & 16 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -861,22 +861,9 @@ bool RegisterFloatUFuncs(PyObject* numpy) {
return ok;
}

// TODO(jakevdp): simplify the following. The already_registered check is no
// longer necessary, and heap allocation is probably not important any longer.
//
// Returns true if the numpy type for T is successfully registered, including if
// it was already registered (e.g. by a different library). If
// `already_registered` is non-null, it's set to true if the type was already
// registered and false otherwise.
template <typename T>
bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) {
if (already_registered != nullptr) {
*already_registered = false;
}
// It's important that we heap-allocate our type. This is because tp_name
// is not a fully-qualified name for a heap-allocated type.
// Existing implementations in JAX and TensorFlow look for "bfloat16",
// not "ml_dtypes.bfloat16" when searching for an implementation.
template <typename T>
bool RegisterFloatDtype(PyObject* numpy) {
// TODO(jakevdp): simplify this; we no longer need heap allocation.
Safe_PyObjectPtr name =
make_safe(PyUnicode_FromString(TypeDescriptor<T>::kTypeName));
Safe_PyObjectPtr qualname =
Expand Down
49 changes: 11 additions & 38 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,61 +226,34 @@ bool Initialize() {
if (!RegisterFloatDtype<bfloat16>(numpy.get())) {
return false;
}
bool float8_e4m3b11fnuz_already_registered;
if (!RegisterFloatDtype<float8_e4m3b11fnuz>(
numpy.get(), &float8_e4m3b11fnuz_already_registered)) {
if (!mRegisterFloatDtype<float8_e4m3b11fnuz>(numpy.get())) {
return false;
}
bool float8_e4m3fn_already_registered;
if (!ml_dtypes::RegisterFloatDtype<float8_e4m3fn>(
numpy.get(), &float8_e4m3fn_already_registered)) {
if (!RegisterFloatDtype<float8_e4m3fn>(numpy.get())) {
return false;
}
bool float8_e4m3fnuz_already_registered;
if (!ml_dtypes::RegisterFloatDtype<float8_e4m3fnuz>(
numpy.get(), &float8_e4m3fnuz_already_registered)) {
if (!RegisterFloatDtype<float8_e4m3fnuz>(numpy.get())) {
return false;
}
bool float8_e5m2_already_registered;
if (!ml_dtypes::RegisterFloatDtype<float8_e5m2>(
numpy.get(), &float8_e5m2_already_registered)) {
if (!RegisterFloatDtype<float8_e5m2>(numpy.get())) {
return false;
}
bool float8_e5m2fnuz_already_registered;
if (!ml_dtypes::RegisterFloatDtype<float8_e5m2fnuz>(
numpy.get(), &float8_e5m2fnuz_already_registered)) {
if (!RegisterFloatDtype<float8_e5m2fnuz>(numpy.get())) {
return false;
}

if (!ml_dtypes::RegisterInt4Dtype<int4>(numpy.get())) {
if (!RegisterInt4Dtype<int4>(numpy.get())) {
return false;
}

if (!ml_dtypes::RegisterInt4Dtype<uint4>(numpy.get())) {
if (!RegisterInt4Dtype<uint4>(numpy.get())) {
return false;
}

// Casts between bfloat16 and float8_e4m3b11fnuz. Only perform the cast if
// float8_e4m3b11fnuz hasn't been previously registered, presumably by a
// different library. In this case, we assume the cast has also already been
// registered, and registering it again can cause segfaults due to accessing
// an uninitialized type descriptor in this library.
if (!float8_e4m3b11fnuz_already_registered &&
!RegisterCustomFloatCast<float8_e4m3b11fnuz, bfloat16>()) {
return false;
}
if (!float8_e4m3fnuz_already_registered &&
!float8_e5m2fnuz_already_registered &&
!RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e5m2fnuz>()) {
return false;
}
if (!float8_e4m3fn_already_registered && !float8_e5m2_already_registered &&
!RegisterCustomFloatCast<float8_e4m3fn, float8_e5m2>()) {
return false;
}
// Register casts between pairs of custom float dtypes.
bool success = true;
// Continue trying to register casts, just in case some types are not
// registered (i.e. float8_e4m3b11fnuz)
success &= RegisterCustomFloatCast<float8_e4m3b11fnuz, bfloat16>();
success &= RegisterTwoWayCustomCast<float8_e4m3fnuz, float8_e5m2fnuz>();
success &= RegisterCustomFloatCast<float8_e4m3fn, float8_e5m2>();
success &= RegisterTwoWayCustomCast<float8_e4m3b11fnuz, float8_e4m3fn>();
success &= RegisterTwoWayCustomCast<float8_e4m3b11fnuz, float8_e5m2>();
success &= RegisterTwoWayCustomCast<bfloat16, float8_e4m3fn>();
Expand Down

0 comments on commit 5680083

Please sign in to comment.