diff --git a/CMakeLists.txt b/CMakeLists.txt index 80482b7..36ef673 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,7 +44,8 @@ SET(libpykokkos_SOURCES ${CMAKE_CURRENT_LIST_DIR}/src/available.cpp ${CMAKE_CURRENT_LIST_DIR}/src/common.cpp ${CMAKE_CURRENT_LIST_DIR}/src/tools.cpp - ${CMAKE_CURRENT_LIST_DIR}/src/execution_spaces.cpp) + ${CMAKE_CURRENT_LIST_DIR}/src/execution_spaces.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/complex_dtypes.cpp) SET(libpykokkos_HEADERS ${CMAKE_CURRENT_LIST_DIR}/include/libpykokkos.hpp diff --git a/include/fwd.hpp b/include/fwd.hpp index ee6bb31..6f602d8 100644 --- a/include/fwd.hpp +++ b/include/fwd.hpp @@ -134,6 +134,8 @@ enum KokkosViewDataType { Uint64, Float32, Float64, + ComplexFloat32, + ComplexFloat64, ViewDataTypesEnd }; diff --git a/include/libpykokkos.hpp b/include/libpykokkos.hpp index 18aed0e..271bc8a 100644 --- a/include/libpykokkos.hpp +++ b/include/libpykokkos.hpp @@ -58,4 +58,5 @@ void generate_atomic_variants(py::module& kokkos); void generate_backend_versions(py::module& kokkos); void generate_pool_variants(py::module& kokkos); void generate_execution_spaces(py::module& kokkos); +void generate_complex_dtypes(py::module& kokkos); void destroy_callbacks(); diff --git a/include/traits.hpp b/include/traits.hpp index 83f9d1b..86f6fae 100644 --- a/include/traits.hpp +++ b/include/traits.hpp @@ -85,6 +85,10 @@ VIEW_DATA_TYPE(uint32_t, Uint32, "uint32", "unsigned", "unsigned_int") VIEW_DATA_TYPE(uint64_t, Uint64, "uint64", "unsigned_long") VIEW_DATA_TYPE(float, Float32, "float32", "float") VIEW_DATA_TYPE(double, Float64, "float64", "double") +VIEW_DATA_TYPE(Kokkos::complex, ComplexFloat32, "complex_float32_dtype", + "complex_float_dtype") +VIEW_DATA_TYPE(Kokkos::complex, ComplexFloat64, "complex_float64_dtype", + "complex_double_dtype") //----------------------------------------------------------------------------// // diff --git a/include/views.hpp b/include/views.hpp index 309d287..0731908 100644 --- a/include/views.hpp +++ b/include/views.hpp @@ -51,6 +51,7 @@ #include "fwd.hpp" #include "traits.hpp" +#include #include #include #include @@ -66,6 +67,21 @@ RetT get_extents(Tp &m, std::index_sequence) { template constexpr auto get_stride(Tp &m); +template +inline std::string get_format() { + return py::format_descriptor::format(); +} + +template <> +inline std::string get_format>() { + return py::format_descriptor>::format(); +} + +template <> +inline std::string get_format>() { + return py::format_descriptor>::format(); +} + template > RetT get_strides(Tp &m, std::index_sequence) { @@ -321,12 +337,13 @@ void generate_view(py::module &_mod, const std::string &_name, _view.def_buffer([_ndim](ViewT &m) -> py::buffer_info { auto _extents = get_extents(m, std::make_index_sequence{}); auto _strides = get_stride(m, std::make_index_sequence{}); + auto _format = get_format(); return py::buffer_info(m.data(), // Pointer to buffer sizeof(Tp), // Size of one scalar - py::format_descriptor::format(), // Descriptor - _ndim, // Number of dimensions - _extents, // Buffer dimensions - _strides // Strides (in bytes) for each index + _format, // Descriptor + _ndim, // Number of dimensions + _extents, // Buffer dimensions + _strides // Strides (in bytes) for each index ); }); diff --git a/kokkos/utility.py b/kokkos/utility.py index 8bb2ee5..56046ae 100644 --- a/kokkos/utility.py +++ b/kokkos/utility.py @@ -101,6 +101,10 @@ def read_dtype(_dtype): return lib.float32 elif _dtype == np.float64: return lib.float64 + elif _dtype == np.complex64: + return lib.complex_float32_dtype + elif _dtype == np.complex128: + return lib.complex_float64_dtype except ImportError: pass diff --git a/src/complex_dtypes.cpp b/src/complex_dtypes.cpp new file mode 100644 index 0000000..c2ce103 --- /dev/null +++ b/src/complex_dtypes.cpp @@ -0,0 +1,201 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Christian R. Trott (crtrott@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#include "common.hpp" + +#include +#include +#include + +//----------------------------------------------------------------------------// +// +// The Kokkos::complex dtypes +// +//----------------------------------------------------------------------------// + +#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) +# define MANAGED __managed__ +#else +# define MANAGED +#endif + +namespace Kokkos { + +namespace { +MANAGED float re_float_offset; +MANAGED float im_float_offset; +MANAGED double re_double_offset; +MANAGED double im_double_offset; +} // namespace + +// Need to explicitly do both float and double since we cannot +// partially specialize function templates +template <> +KOKKOS_FUNCTION const float&& get<2, float>(const complex&&) noexcept { + static_assert(std::is_standard_layout_v>); + re_float_offset = static_cast(offsetof(complex, re_)); + return std::move(re_float_offset); +} + +template <> +KOKKOS_FUNCTION const float&& get<3, float>(const complex&&) noexcept { + static_assert(std::is_standard_layout_v>); + im_float_offset = static_cast(offsetof(complex, im_)); + return std::move(im_float_offset); +} + +template <> +KOKKOS_FUNCTION const double&& get<2, double>( + const complex&&) noexcept { + static_assert(std::is_standard_layout_v>); + re_double_offset = static_cast(offsetof(complex, re_)); + return std::move(re_double_offset); +} + +template <> +KOKKOS_FUNCTION const double&& get<3, double>( + const complex&&) noexcept { + static_assert(std::is_standard_layout_v>); + im_double_offset = static_cast(offsetof(complex, im_)); + return std::move(im_double_offset); +} +} // namespace Kokkos + +#define PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND(Name, Offset, Type) \ + ::pybind11::detail::field_descriptor { \ + Name, Offset, sizeof(Type), ::pybind11::format_descriptor::format(), \ + ::pybind11::detail::npy_format_descriptor::dtype() \ + } + +template +void register_complex_as_numpy_dtype() { + /* This function registers Kokkos::complex as a numpy datatype + * which is needed to cast Kokkos views of complex numbers to numpy + * arrays. Ideally we would just call this macro + * + * `PYBIND11_NUMPY_DTYPE(ComplexTp, re_, im_);` + * + * which builds a vector of field descriptors of the complex type. + * However this will not work because re_ and im_ are private member + * variables. The macro needs to extract their type and their offset + * within the class to work properly. + * + * Getting the type is easy since it can only be a float or double. + * Getting the offset requires calling + * `offsetof(Kokkos::complex, re_)`, which will not work since + * we cannot access private member variables. The solution is to + * create a context in which we can access them and return the + * offset from there. This is possible by specializing a templated + * member function or friend function to Kokkos::complex since they + * can access private variables (see + * http://www.gotw.ca/gotw/076.htm). + * + * Looking at Kokkos::complex, there is the get() template function + * which we can specialize. We select this overload + * + * ``` + * template + * friend constexpr const RT&& get(const complex&&) noexcept; + * ``` + * + * And specialize it for I == 2 for re_ and I == 3 for im_. Each + * specialization calls offsetof for the corresponding member + * variables and returns it. The original get function only works + * for I == 0 and I == 1, so these specializations will not + * interfere with it. Since the functions return rvalue references, + * we store the offsets in global variables and move them when + * returning. + */ + + using ComplexTp = Kokkos::complex; + + py::ssize_t re_offset = static_cast( + Kokkos::get<2, Tp>(static_cast(ComplexTp{0.0, 0.0}))); + py::ssize_t im_offset = static_cast( + Kokkos::get<3, Tp>(static_cast(ComplexTp{0.0, 0.0}))); + + ::pybind11::detail::npy_format_descriptor::register_dtype( + ::std::vector<::pybind11::detail::field_descriptor>{ + PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND("re_", re_offset, Tp), + PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND("im_", im_offset, Tp)}); +} + +template +void generate_complex_dtype(py::module& kokkos, const std::string& _name) { + using ComplexTp = Kokkos::complex; + + py::class_(kokkos, _name.c_str()) + .def(py::init()) // Constructor for real part only + .def(py::init()) // Constructor for real and imaginary parts + .def("imag_mutable", py::overload_cast<>(&ComplexTp::imag)) + .def("imag_const", py::overload_cast<>(&ComplexTp::imag, py::const_)) + .def("imag_set", py::overload_cast(&ComplexTp::imag)) + .def("real_mutable", py::overload_cast<>(&ComplexTp::real)) + .def("real_const", py::overload_cast<>(&ComplexTp::real, py::const_)) + .def("real_set", py::overload_cast(&ComplexTp::real)) + .def(py::self + py::self) + .def(py::self + Tp()) + .def(py::self += py::self) + .def(py::self += Tp()) + .def(py::self - py::self) + .def(py::self - Tp()) + .def(py::self -= py::self) + .def(py::self -= Tp()) + .def(py::self * py::self) + .def(py::self * Tp()) + .def(py::self *= py::self) + .def(py::self *= Tp()) + .def(py::self / py::self) + .def(py::self / Tp()) + .def(py::self /= py::self) + .def(py::self /= Tp()); +} + +void generate_complex_dtypes(py::module& kokkos) { + generate_complex_dtype(kokkos, "complex_float32"); + generate_complex_dtype(kokkos, "complex_float64"); + + register_complex_as_numpy_dtype(); + register_complex_as_numpy_dtype(); +} \ No newline at end of file diff --git a/src/libpykokkos.cpp b/src/libpykokkos.cpp index fd8ff91..84a009c 100644 --- a/src/libpykokkos.cpp +++ b/src/libpykokkos.cpp @@ -116,4 +116,5 @@ PYBIND11_MODULE(libpykokkos, kokkos) { generate_backend_versions(kokkos); generate_pool_variants(kokkos); generate_execution_spaces(kokkos); + generate_complex_dtypes(kokkos); } diff --git a/src/variants/CMakeLists.txt b/src/variants/CMakeLists.txt index 955d462..9515aa4 100644 --- a/src/variants/CMakeLists.txt +++ b/src/variants/CMakeLists.txt @@ -26,7 +26,7 @@ TARGET_LINK_LIBRARIES(libpykokkos-variants PUBLIC SET(_types concrete dynamic) SET(_variants layout memory_trait) -SET(_data_types Int8 Int16 Int32 Int64 Uint8 Uint16 Uint32 Uint64 Float32 Float64) +SET(_data_types Int8 Int16 Int32 Int64 Uint8 Uint16 Uint32 Uint64 Float32 Float64 ComplexFloat32 ComplexFloat64) SET(layout_enums Right) SET(memory_trait_enums Managed)