Improve type_caster for floating-point types. #829
Changes from 1 commit
1e008be
b025460
2057e51
59e449e
44a3dca
b26f83b
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,10 +126,22 @@ template <typename T> | |
struct type_caster<T, enable_if_t<std::is_arithmetic_v<T> && !is_std_char_v<T>>> { | ||
NB_INLINE bool from_python(handle src, uint8_t flags, cleanup_list *) noexcept { | ||
if constexpr (std::is_floating_point_v<T>) { | ||
if constexpr (sizeof(T) == 8) | ||
return detail::load_f64(src.ptr(), flags, &value); | ||
else | ||
return detail::load_f32(src.ptr(), flags, &value); | ||
if constexpr (sizeof(T) == 8) { | ||
// Assume T, double, and Python float are all IEEE 754 binary64 | ||
return detail::load_f64(src.ptr(), flags, (double *) &value); | ||
} else { | ||
double d; | ||
if (!detail::load_f64(src.ptr(), flags, &d)) | ||
return false; | ||
T result = static_cast<T>(d); | ||
if ((flags & (uint8_t) cast_flags::convert) | ||
|| static_cast<double>(result) == d | ||
|| (result != result && d != d)) { | ||
What does

I intend for the caster to work for any floating-point type. The type Maybe we could just check
||
value = result; | ||
return true; | ||
} | ||
return false; | ||
} | ||
} else { | ||
if constexpr (std::is_signed_v<T>) { | ||
if constexpr (sizeof(T) == 8) | ||
|
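The narrowing logic in the new `else` branch can be tried in isolation. The following is a minimal sketch, not the nanobind implementation: `narrow_checked` and `demo_narrow_checked` are hypothetical names, and the `convert` parameter stands in for the `cast_flags::convert` bit. It accepts the narrowed value when conversion is explicitly allowed, when the value survives a round trip through `double`, or when both the input and the result are NaN.

```cpp
#include <cassert>
#include <limits>

// Hypothetical helper (not part of nanobind): narrow a double to T,
// mirroring the three acceptance conditions shown in the diff.
template <typename T>
bool narrow_checked(double d, bool convert, T *out) noexcept {
    T result = static_cast<T>(d);
    if (convert                              // lossy conversion explicitly allowed
        || static_cast<double>(result) == d  // value survives the round trip
        || (result != result && d != d)) {   // NaN compares unequal to itself
        *out = result;
        return true;
    }
    return false;
}

// Usage sketch: 0.5 is exactly representable as a float, 0.1 is not.
inline void demo_narrow_checked() {
    float f = 0.0f;
    assert(narrow_checked<float>(0.5, false, &f) && f == 0.5f);
    assert(!narrow_checked<float>(0.1, false, &f));
    assert(narrow_checked<float>(0.1, true, &f));
}
```

The NaN clause is needed because a NaN never compares equal to anything, so the round-trip test alone would reject it. The sketch assumes the compiler is not run with options that relax NaN semantics.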
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -904,18 +904,16 @@ bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept { | |
|
||
#if !defined(Py_LIMITED_API) | ||
if (NB_LIKELY(is_float)) { | ||
*out = (double) PyFloat_AS_DOUBLE(o); | ||
*out = PyFloat_AS_DOUBLE(o); | ||
return true; | ||
} | ||
|
||
is_float = false; | ||
Can you re-enable this assignment? I am not sure that all compilers will understand that `is_float` can only be false following this conditional. Having the assignment guarantees that constant propagation will remove the check below.

Done. I had assumed this was an old workaround for a specific compiler issue and was no longer needed. Honestly, I think it's better not to have this since it only applies in a code path that is not `NB_LIKELY`. Feel free to change your mind; I'm happy to revert this latest commit. :)

I prefer to have it. I don't think it can do any harm in release mode, and debug mode performance is in any case meaningless.
||
#endif | ||
|
||
if (is_float || (flags & (uint8_t) cast_flags::convert)) { | ||
double result = PyFloat_AsDouble(o); | ||
|
||
if (result != -1.0 || !PyErr_Occurred()) { | ||
*out = (double) result; | ||
*out = result; | ||
return true; | ||
} else { | ||
PyErr_Clear(); | ||
|
@@ -925,31 +923,6 @@ bool load_f64(PyObject *o, uint8_t flags, double *out) noexcept { | |
return false; | ||
} | ||
|
||
bool load_f32(PyObject *o, uint8_t flags, float *out) noexcept { | ||
bool is_float = PyFloat_CheckExact(o); | ||
|
||
#if !defined(Py_LIMITED_API) | ||
if (NB_LIKELY(is_float)) { | ||
*out = (float) PyFloat_AS_DOUBLE(o); | ||
return true; | ||
} | ||
|
||
is_float = false; | ||
||
#endif | ||
|
||
if (is_float || (flags & (uint8_t) cast_flags::convert)) { | ||
double result = PyFloat_AsDouble(o); | ||
|
||
if (result != -1.0 || !PyErr_Occurred()) { | ||
*out = (float) result; | ||
return true; | ||
} else { | ||
PyErr_Clear(); | ||
} | ||
} | ||
|
||
return false; | ||
} | ||
|
||
#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION) && PY_VERSION_HEX < 0x030c0000 | ||
// Direct access for compact integers. These functions are | ||
|
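The `is_float = false;` assignment discussed in the review can be illustrated without the Python C API. This is only a sketch under stated assumptions: `Obj`, `FAST_PATH`, `load`, and `demo_load` are hypothetical stand-ins (for a Python object, for `!defined(Py_LIMITED_API)`, and for `load_f64`), chosen to show the shape of the control flow rather than the real function.

```cpp
#include <cassert>

// Hypothetical stand-in for a Python object.
struct Obj {
    bool exact_float;   // plays the role of PyFloat_CheckExact(o)
    bool convertible;   // whether a slow-path conversion would succeed
    double payload;
};

#define FAST_PATH 1     // plays the role of !defined(Py_LIMITED_API)

bool load(const Obj &o, bool convert, double *out) noexcept {
    bool is_float = o.exact_float;

#if FAST_PATH
    if (is_float) {
        *out = o.payload;
        return true;
    }
    // Reaching this line implies is_float is false. Stating it explicitly
    // lets constant propagation reduce the test below to just `convert`.
    is_float = false;
#endif

    if (is_float || convert) {
        if (o.convertible) {
            *out = o.payload;
            return true;
        }
    }
    return false;
}

inline void demo_load() {
    double d = 0.0;
    assert(load({true, true, 1.5}, false, &d) && d == 1.5);   // fast path
    assert(!load({false, true, 2.0}, false, &d));             // no implicit conversion
    assert(load({false, true, 2.0}, true, &d) && d == 2.0);   // conversion allowed
}
```

The assignment does not change behavior. Its only purpose is to make the invariant visible to the optimizer when the fast path is compiled in.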
I think that this would be better to still keep in a dedicated `load_f32` routine with the double precision bits inlined. The goal is to keep binding code small that calls `load_f32` thousands of times.

I restored `load_f32`. Note that I statically assert that both `double` and `float` adhere to ISO/IEC 60559 as documented here, so I only check `d != d` since I know that if `d` is `NaN`, then the conversion to `float` will give `NaN`. Hopefully, these assertions are true everywhere, or else I have some thinking to do....

In the case of `double`, the caster only checks `sizeof(T) == sizeof(double)`. The assumption (as documented in the comment) is that this is ISO/IEC 60559 (i.e., IEEE 754) binary64. Hopefully, this is always true for systems of interest. The good news is this branch will be taken for `std::float64_t` as well as for `double`. If you like, I'm happy to use `std::numeric_limits` in the test, but I hesitated to include `<limits>` since it's 1900 lines.

I used `std::is_same_v<T, float>` in the test for `float` since TensorFloat-32 is the same size as `float` but is a different representation. So, `std::float32_t` will not take this branch. (Of course, it will still be correct, but it will use the last branch. (Without this PR, it doesn't work at all.))

I did include `<limits>` in `common.cpp` since it's only one file and it's already included transitively by `nb_internals.h`, which includes `tsl/robin_map.h`, which includes `tsl/robin_hash.h`, which includes `<limits>`.
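For readers following the discussion, the static assertions and the three-way branch described in the comment above might look roughly like the sketch below. It is an illustration only: `Path`, `select_path`, `narrow_f32`, and `demo_select_path` are hypothetical names that do not appear in the PR, and the real code may differ in detail.

```cpp
#include <cassert>
#include <limits>
#include <type_traits>

// Both types are assumed to follow ISO/IEC 60559 (IEEE 754).
static_assert(std::numeric_limits<double>::is_iec559 &&
              std::numeric_limits<float>::is_iec559,
              "double and float must be IEC 60559 types");

// Hypothetical labels for the three branches of the caster.
enum class Path { f64, f32, generic };

template <typename T>
constexpr Path select_path() {
    static_assert(std::is_floating_point_v<T>);
    if constexpr (sizeof(T) == sizeof(double))
        return Path::f64;      // size check only, assumes binary64
    else if constexpr (std::is_same_v<T, float>)
        return Path::f32;      // exact type match, not a size check
    else
        return Path::generic;  // any other floating-point type
}

// Given the assertion above, a NaN double narrows to a NaN float,
// so checking d != d alone covers the NaN case.
inline bool narrow_f32(double d, float *out) noexcept {
    float f = static_cast<float>(d);
    if (static_cast<double>(f) == d || d != d) {
        *out = f;
        return true;
    }
    return false;
}

inline void demo_select_path() {
    static_assert(select_path<double>() == Path::f64);
    static_assert(select_path<float>() == Path::f32);
    float f = 0.0f;
    assert(narrow_f32(0.25, &f) && f == 0.25f);
}
```

Which branch `long double` takes in this sketch depends on the platform, since its size is implementation-defined.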