From d17ad366ca37f0ab3a67476e33f6f5dde1234ef3 Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Tue, 15 Oct 2024 10:41:42 -0500 Subject: [PATCH] ENH: fix issue with getting Kokkos::complex offsets when CUDA is enabled --- src/complex_dtypes.cpp | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/complex_dtypes.cpp b/src/complex_dtypes.cpp index eda9c97..c2ce103 100644 --- a/src/complex_dtypes.cpp +++ b/src/complex_dtypes.cpp @@ -54,40 +54,48 @@ // //----------------------------------------------------------------------------// +#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) +# define MANAGED __managed__ +#else +# define MANAGED +#endif + namespace Kokkos { namespace { -float re_float_offset; -float im_float_offset; -double re_double_offset; -double im_double_offset; +MANAGED float re_float_offset; +MANAGED float im_float_offset; +MANAGED double re_double_offset; +MANAGED double im_double_offset; } // namespace // Need to explicitly do both float and double since we cannot // partially specialize function templates template <> -constexpr const float&& get<2, float>(const complex&&) noexcept { +KOKKOS_FUNCTION const float&& get<2, float>(const complex&&) noexcept { static_assert(std::is_standard_layout_v>); re_float_offset = static_cast(offsetof(complex, re_)); return std::move(re_float_offset); } template <> -constexpr const float&& get<3, float>(const complex&&) noexcept { +KOKKOS_FUNCTION const float&& get<3, float>(const complex&&) noexcept { static_assert(std::is_standard_layout_v>); im_float_offset = static_cast(offsetof(complex, im_)); return std::move(im_float_offset); } template <> -constexpr const double&& get<2, double>(const complex&&) noexcept { +KOKKOS_FUNCTION const double&& get<2, double>( + const complex&&) noexcept { static_assert(std::is_standard_layout_v>); re_double_offset = static_cast(offsetof(complex, re_)); return std::move(re_double_offset); } template <> -constexpr const double&& get<3, double>(const complex&&) noexcept { +KOKKOS_FUNCTION const double&& get<3, double>( + const complex&&) noexcept { static_assert(std::is_standard_layout_v>); im_double_offset = static_cast(offsetof(complex, im_)); return std::move(im_double_offset);