From 7f498f71c0712d5ab8e29b3e72bb105ac25f1da7 Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Fri, 11 Oct 2024 11:00:45 -0500 Subject: [PATCH] ENH: register Kokkos::complex as a numpy dtype --- src/complex_dtypes.cpp | 103 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/src/complex_dtypes.cpp b/src/complex_dtypes.cpp index b8a8a4e..fb0051c 100644 --- a/src/complex_dtypes.cpp +++ b/src/complex_dtypes.cpp @@ -44,6 +44,7 @@ #include "common.hpp" +#include #include #include @@ -53,6 +54,105 @@ // //----------------------------------------------------------------------------// +namespace Kokkos { + +namespace { +float re_float_offset; +float im_float_offset; +double re_double_offset; +double im_double_offset; +} // namespace + +// Need to explicitly do both float and double since we cannot +// partially specialize function templates +template <> +constexpr const float&& get<2, float>(const complex&&) noexcept { + static_assert(std::is_standard_layout_v>); + re_float_offset = static_cast(offsetof(complex, re_)); + return std::move(re_float_offset); +} + +template <> +constexpr const float&& get<3, float>(const complex&&) noexcept { + static_assert(std::is_standard_layout_v>); + im_float_offset = static_cast(offsetof(complex, im_)); + return std::move(im_float_offset); +} + +template <> +constexpr const double&& get<2, double>(const complex&&) noexcept { + static_assert(std::is_standard_layout_v>); + re_double_offset = static_cast(offsetof(complex, re_)); + return std::move(re_double_offset); +} + +template <> +constexpr const double&& get<3, double>(const complex&&) noexcept { + static_assert(std::is_standard_layout_v>); + im_double_offset = static_cast(offsetof(complex, im_)); + return std::move(im_double_offset); +} +} // namespace Kokkos + +#define PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND(Name, Offset, Type) \ + ::pybind11::detail::field_descriptor { \ + Name, Offset, sizeof(Type), ::pybind11::format_descriptor::format(), \ + ::pybind11::detail::npy_format_descriptor::dtype() \ + } + +template +void register_complex_as_numpy_dtype() { + /* This function registers Kokkos::complex as a numpy datatype + * which is needed to cast Kokkos views of complex numbers to numpy + * arrays. Ideally we would just call this macro + * + * `PYBIND11_NUMPY_DTYPE(ComplexTp, re_, im_);` + * + * which builds a vector of field descriptors of the complex type. + * However this will not work because re_ and im_ are private member + * variables. The macro needs to extract their type and their offset + * within the class to work properly. + * + * Getting the type is easy since it can only be a float or double. + * Getting the offset requires calling + * `offsetof(Kokkos::complex, re_)`, which will not work since + * we cannot access private member variables. The solution is to + * create a context in which we can access them and return the + * offset from there. This is possible by specializing a templated + * member function or friend function to Kokkos::complex since they + * can access private variables (see + * http://www.gotw.ca/gotw/076.htm). + * + * Looking at Kokkos::complex, there is the get() template function + * which we can specialize. We select this overload + * + * ``` + * template + * friend constexpr const RT&& get(const complex&&) noexcept; + * ``` + * + * And specialize it for I == 2 for re_ and I == 3 for im_. Each + * specialization calls offsetof for the corresponding member + * variables and returns it. The original get function only works + * for I == 0 and I == 1, so these specializations will not + * interfere with it. Since the functions return rvalue references, + * we store the offsets in global variables and move them when + * returning. + */ + + using ComplexTp = Kokkos::complex; + + py::ssize_t re_offset = static_cast( + Kokkos::get<2, Tp>(static_cast(ComplexTp{0.0, 0.0}))); + py::ssize_t im_offset = static_cast( + Kokkos::get<3, Tp>(static_cast(ComplexTp{0.0, 0.0}))); + + ::pybind11::detail::npy_format_descriptor::register_dtype( + ::std::vector<::pybind11::detail::field_descriptor>{ + PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND("re_", re_offset, Tp), + PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND("im_", im_offset, Tp)}); +} + template void generate_complex_dtype(py::module& kokkos, const std::string& _name) { using ComplexTp = Kokkos::complex; @@ -79,4 +179,7 @@ void generate_complex_dtype(py::module& kokkos, const std::string& _name) { void generate_complex_dtypes(py::module& kokkos) { generate_complex_dtype(kokkos, "complex_float32"); generate_complex_dtype(kokkos, "complex_float64"); + + register_complex_as_numpy_dtype(); + register_complex_as_numpy_dtype(); } \ No newline at end of file