Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add support for Kokkos::complex Views #61

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ SET(libpykokkos_SOURCES
${CMAKE_CURRENT_LIST_DIR}/src/available.cpp
${CMAKE_CURRENT_LIST_DIR}/src/common.cpp
${CMAKE_CURRENT_LIST_DIR}/src/tools.cpp
${CMAKE_CURRENT_LIST_DIR}/src/execution_spaces.cpp)
${CMAKE_CURRENT_LIST_DIR}/src/execution_spaces.cpp
${CMAKE_CURRENT_LIST_DIR}/src/complex_dtypes.cpp)

SET(libpykokkos_HEADERS
${CMAKE_CURRENT_LIST_DIR}/include/libpykokkos.hpp
Expand Down
2 changes: 2 additions & 0 deletions include/fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ enum KokkosViewDataType {
Uint64,
Float32,
Float64,
ComplexFloat32,
ComplexFloat64,
ViewDataTypesEnd
};

Expand Down
1 change: 1 addition & 0 deletions include/libpykokkos.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,5 @@ void generate_atomic_variants(py::module& kokkos);
void generate_backend_versions(py::module& kokkos);
void generate_pool_variants(py::module& kokkos);
void generate_execution_spaces(py::module& kokkos);
void generate_complex_dtypes(py::module& kokkos);
void destroy_callbacks();
4 changes: 4 additions & 0 deletions include/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ VIEW_DATA_TYPE(uint32_t, Uint32, "uint32", "unsigned", "unsigned_int")
VIEW_DATA_TYPE(uint64_t, Uint64, "uint64", "unsigned_long")
VIEW_DATA_TYPE(float, Float32, "float32", "float")
VIEW_DATA_TYPE(double, Float64, "float64", "double")
VIEW_DATA_TYPE(Kokkos::complex<float>, ComplexFloat32, "complex_float32_dtype",
"complex_float_dtype")
VIEW_DATA_TYPE(Kokkos::complex<double>, ComplexFloat64, "complex_float64_dtype",
"complex_double_dtype")

//----------------------------------------------------------------------------//
// <data-type> <enum> <string identifiers>
Expand Down
25 changes: 21 additions & 4 deletions include/views.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
#include "fwd.hpp"
#include "traits.hpp"

#include <pybind11/numpy.h>
#include <Kokkos_Core.hpp>
#include <Kokkos_DynRankView.hpp>
#include <iostream>
Expand All @@ -66,6 +67,21 @@ RetT get_extents(Tp &m, std::index_sequence<Idx...>) {
template <typename Up, size_t Idx, typename Tp>
constexpr auto get_stride(Tp &m);

template <typename Tp>
inline std::string get_format() {
return py::format_descriptor<Tp>::format();
}

template <>
inline std::string get_format<Kokkos::complex<float>>() {
return py::format_descriptor<std::complex<float>>::format();
}

template <>
inline std::string get_format<Kokkos::complex<double>>() {
return py::format_descriptor<std::complex<double>>::format();
}

template <typename Up, typename Tp, size_t... Idx,
typename RetT = std::array<size_t, sizeof...(Idx)>>
RetT get_strides(Tp &m, std::index_sequence<Idx...>) {
Expand Down Expand Up @@ -321,12 +337,13 @@ void generate_view(py::module &_mod, const std::string &_name,
_view.def_buffer([_ndim](ViewT &m) -> py::buffer_info {
auto _extents = get_extents(m, std::make_index_sequence<DimIdx + 1>{});
auto _strides = get_stride<Tp>(m, std::make_index_sequence<DimIdx + 1>{});
auto _format = get_format<Tp>();
return py::buffer_info(m.data(), // Pointer to buffer
sizeof(Tp), // Size of one scalar
py::format_descriptor<Tp>::format(), // Descriptor
_ndim, // Number of dimensions
_extents, // Buffer dimensions
_strides // Strides (in bytes) for each index
_format, // Descriptor
_ndim, // Number of dimensions
_extents, // Buffer dimensions
_strides // Strides (in bytes) for each index
);
});

Expand Down
4 changes: 4 additions & 0 deletions kokkos/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def read_dtype(_dtype):
return lib.float32
elif _dtype == np.float64:
return lib.float64
elif _dtype == np.complex64:
return lib.complex_float32_dtype
elif _dtype == np.complex128:
return lib.complex_float64_dtype
except ImportError:
pass

Expand Down
201 changes: 201 additions & 0 deletions src/complex_dtypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
//@HEADER
// ************************************************************************
//
// Kokkos v. 3.0
// Copyright (2020) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Christian R. Trott (crtrott@sandia.gov)
//
// ************************************************************************
//@HEADER
*/

#include "common.hpp"

#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <Kokkos_Core.hpp>

//----------------------------------------------------------------------------//
//
// The Kokkos::complex dtypes
//
//----------------------------------------------------------------------------//

// MANAGED is prepended to the file-local offset globals declared below.
// With the CUDA or HIP backend enabled it expands to `__managed__`;
// otherwise it expands to nothing and the globals are plain variables.
// NOTE(review): presumably needed because the KOKKOS_FUNCTION
// specializations that write those globals may also be compiled for the
// device -- confirm.
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP)
# define MANAGED __managed__
#else
# define MANAGED
#endif

// Specializations of Kokkos::get<I, RT> that report the byte offsets of the
// private members `re_` and `im_` of Kokkos::complex<RT>. They are used by
// register_complex_as_numpy_dtype() further down in this file.
namespace Kokkos {

// File-local storage for the offsets. The get() specializations below return
// `const RT&&`, so the value has to live in an object that outlives the call;
// each specialization writes its global and returns an rvalue reference to
// it. The offsets are therefore stored as float/double rather than as an
// integer type.
// NOTE(review): every call overwrites the corresponding global, so
// concurrent calls would race -- the only callers visible in this file run
// during module initialization; confirm there are no others.
namespace {
MANAGED float re_float_offset;
MANAGED float im_float_offset;
MANAGED double re_double_offset;
MANAGED double im_double_offset;
} // namespace

// Need to explicitly do both float and double since we cannot
// partially specialize function templates
//
// Index convention used by the specializations: I == 2 yields the offset of
// `re_`, I == 3 yields the offset of `im_`. The complex argument is unnamed
// and never read; it only selects the overload.

// Offset of complex<float>::re_, converted to float.
template <>
KOKKOS_FUNCTION const float&& get<2, float>(const complex<float>&&) noexcept {
// Compile-time guard placed ahead of the offsetof() call below.
static_assert(std::is_standard_layout_v<complex<float>>);
re_float_offset = static_cast<float>(offsetof(complex<float>, re_));
return std::move(re_float_offset);
}

// Offset of complex<float>::im_, converted to float.
template <>
KOKKOS_FUNCTION const float&& get<3, float>(const complex<float>&&) noexcept {
static_assert(std::is_standard_layout_v<complex<float>>);
im_float_offset = static_cast<float>(offsetof(complex<float>, im_));
return std::move(im_float_offset);
}

// Offset of complex<double>::re_, converted to double.
template <>
KOKKOS_FUNCTION const double&& get<2, double>(
const complex<double>&&) noexcept {
static_assert(std::is_standard_layout_v<complex<double>>);
re_double_offset = static_cast<double>(offsetof(complex<double>, re_));
return std::move(re_double_offset);
}

// Offset of complex<double>::im_, converted to double.
template <>
KOKKOS_FUNCTION const double&& get<3, double>(
const complex<double>&&) noexcept {
static_assert(std::is_standard_layout_v<complex<double>>);
im_double_offset = static_cast<double>(offsetof(complex<double>, im_));
return std::move(im_double_offset);
}
} // namespace Kokkos

// Builds a ::pybind11::detail::field_descriptor from an explicitly supplied
// field name, byte offset and field type. The size, format string and numpy
// dtype of the field are all derived from `Type`; `Name` and `Offset` are
// passed through unchanged. The offset is taken as an argument because the
// members being described are private (see the comment in
// register_complex_as_numpy_dtype()).
#define PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND(Name, Offset, Type) \
::pybind11::detail::field_descriptor { \
Name, Offset, sizeof(Type), ::pybind11::format_descriptor<Type>::format(), \
::pybind11::detail::npy_format_descriptor<Type>::dtype() \
}

template <typename Tp>
void register_complex_as_numpy_dtype() {
/* This function registers Kokkos::complex<Tp> as a numpy datatype
* which is needed to cast Kokkos views of complex numbers to numpy
* arrays. Ideally we would just call this macro
*
* `PYBIND11_NUMPY_DTYPE(ComplexTp, re_, im_);`
*
* which builds a vector of field descriptors of the complex type.
* However this will not work because re_ and im_ are private member
* variables. The macro needs to extract their type and their offset
* within the class to work properly.
*
* Getting the type is easy since it can only be a float or double.
* Getting the offset requires calling
* `offsetof(Kokkos::complex<Tp>, re_)`, which will not work since
* we cannot access private member variables. The solution is to
* create a context in which we can access them and return the
* offset from there. This is possible by specializing a templated
* member function or friend function to Kokkos::complex since they
* can access private variables (see
* http://www.gotw.ca/gotw/076.htm).
*
* Looking at Kokkos::complex, there is the get() template function
* which we can specialize. We select this overload
*
* ```
* template <size_t I, typename RT>
* friend constexpr const RT&& get(const complex<RT>&&) noexcept;
* ```
*
* And specialize it for I == 2 for re_ and I == 3 for im_. Each
* specialization calls offsetof for the corresponding member
* variables and returns it. The original get function only works
* for I == 0 and I == 1, so these specializations will not
* interfere with it. Since the functions return rvalue references,
* we store the offsets in global variables and move them when
* returning.
*/

using ComplexTp = Kokkos::complex<Tp>;

py::ssize_t re_offset = static_cast<py::ssize_t>(
Kokkos::get<2, Tp>(static_cast<const ComplexTp&&>(ComplexTp{0.0, 0.0})));
py::ssize_t im_offset = static_cast<py::ssize_t>(
Kokkos::get<3, Tp>(static_cast<const ComplexTp&&>(ComplexTp{0.0, 0.0})));

::pybind11::detail::npy_format_descriptor<ComplexTp>::register_dtype(
::std::vector<::pybind11::detail::field_descriptor>{
PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND("re_", re_offset, Tp),
PYBIND11_FIELD_DESCRIPTOR_EX_WORKAROUND("im_", im_offset, Tp)});
}

// Exposes Kokkos::complex<Tp> to Python as a class called `_name` inside the
// module `kokkos`: two constructors, accessors for the real and imaginary
// parts, and the four arithmetic operators (plain and in-place) against both
// another complex value and a scalar of type Tp.
template <typename Tp>
void generate_complex_dtype(py::module& kokkos, const std::string& _name) {
  using complex_type = Kokkos::complex<Tp>;

  py::class_<complex_type> binding(kokkos, _name.c_str());

  // Construction from a real part only, or from real and imaginary parts.
  binding.def(py::init<Tp>());
  binding.def(py::init<Tp, Tp>());

  // Imaginary part: non-const getter, const getter, setter.
  binding.def("imag_mutable", py::overload_cast<>(&complex_type::imag));
  binding.def("imag_const",
              py::overload_cast<>(&complex_type::imag, py::const_));
  binding.def("imag_set", py::overload_cast<Tp>(&complex_type::imag));

  // Real part: non-const getter, const getter, setter.
  binding.def("real_mutable", py::overload_cast<>(&complex_type::real));
  binding.def("real_const",
              py::overload_cast<>(&complex_type::real, py::const_));
  binding.def("real_set", py::overload_cast<Tp>(&complex_type::real));

  // Addition.
  binding.def(py::self + py::self);
  binding.def(py::self + Tp());
  binding.def(py::self += py::self);
  binding.def(py::self += Tp());

  // Subtraction.
  binding.def(py::self - py::self);
  binding.def(py::self - Tp());
  binding.def(py::self -= py::self);
  binding.def(py::self -= Tp());

  // Multiplication.
  binding.def(py::self * py::self);
  binding.def(py::self * Tp());
  binding.def(py::self *= py::self);
  binding.def(py::self *= Tp());

  // Division.
  binding.def(py::self / py::self);
  binding.def(py::self / Tp());
  binding.def(py::self /= py::self);
  binding.def(py::self /= Tp());
}

// Entry point for complex-number support: adds the Python classes for
// Kokkos::complex<float> and Kokkos::complex<double> to `kokkos`, then
// registers both types as numpy dtypes.
void generate_complex_dtypes(py::module& kokkos) {
  // Python-visible scalar classes.
  const std::string single_precision_name{"complex_float32"};
  const std::string double_precision_name{"complex_float64"};
  generate_complex_dtype<float>(kokkos, single_precision_name);
  generate_complex_dtype<double>(kokkos, double_precision_name);

  // numpy dtype registration.
  register_complex_as_numpy_dtype<float>();
  register_complex_as_numpy_dtype<double>();
}
1 change: 1 addition & 0 deletions src/libpykokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,5 @@ PYBIND11_MODULE(libpykokkos, kokkos) {
generate_backend_versions(kokkos);
generate_pool_variants(kokkos);
generate_execution_spaces(kokkos);
generate_complex_dtypes(kokkos);
}
2 changes: 1 addition & 1 deletion src/variants/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ TARGET_LINK_LIBRARIES(libpykokkos-variants PUBLIC

SET(_types concrete dynamic)
SET(_variants layout memory_trait)
SET(_data_types Int8 Int16 Int32 Int64 Uint8 Uint16 Uint32 Uint64 Float32 Float64)
SET(_data_types Int8 Int16 Int32 Int64 Uint8 Uint16 Uint32 Uint64 Float32 Float64 ComplexFloat32 ComplexFloat64)

SET(layout_enums Right)
SET(memory_trait_enums Managed)
Expand Down
Loading