Skip to content

Commit

Permalink
Avoid call to PyArray_TypeNumFromName, which is removed in NumPy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 20, 2023
1 parent c83659f commit e5da001
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions ml_dtypes/_src/custom_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -876,17 +876,13 @@ bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) {
// assumes that the other module has a sufficiently complete bfloat16
// implementation. The only known NumPy bfloat16 extension at the time of
// writing is this one (distributed in TF and JAX).
// TODO(phawkins): distribute the bfloat16 extension as its own pip package,
// so we can unambiguously refer to a single canonical definition of bfloat16.
int typenum =
PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
if (typenum != NPY_NOTYPE) {
PyArray_Descr* descr = PyArray_DescrFromType(typenum);
PyArray_Descr *descr = NULL;
if (PyArray_DescrConverter(PyUnicode_FromString(TypeDescriptor<T>::kTypeName), &descr) != NPY_FAIL) {
// The test for an argmax function here is to verify that the
// bfloat16 implementation is sufficiently new, and, say, not from
// an older version of TF or JAX.
if (descr && descr->f && descr->f->argmax) {
TypeDescriptor<T>::npy_type = typenum;
TypeDescriptor<T>::npy_type = descr->type_num;
TypeDescriptor<T>::type_ptr = reinterpret_cast<PyObject*>(descr->typeobj);
if (already_registered != nullptr) {
*already_registered = true;
Expand All @@ -897,7 +893,7 @@ bool RegisterFloatDtype(PyObject* numpy, bool* already_registered = nullptr) {

// It's important that we heap-allocate our type. This is because tp_name
// is not a fully-qualified name for a heap-allocated type, and
// PyArray_TypeNumFromName() (above) looks at the tp_name field to find
// PyArray_DescrConverter() (above) looks at the tp_name field to find
// types. Existing implementations in JAX and TensorFlow look for "bfloat16",
// not "ml_dtypes.bfloat16" when searching for an implementation.
Safe_PyObjectPtr name =
Expand Down

0 comments on commit e5da001

Please sign in to comment.