Skip to content

Commit

Permalink
fix compilation with CUDA 12.4 (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
sekelle authored Mar 12, 2024
1 parent 5fe32ce commit 61a9674
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
15 changes: 4 additions & 11 deletions domain/include/cstone/cuda/cuda_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,15 @@ void memcpyD2D(const T* src, std::size_t n, T* dest);

void syncGpu();

#if not(defined(THRUST_MAJOR_VERSION) || defined(USE_CUDA) || defined(__CUDACC__) || defined(__HIPCC__))
// This must only be added when thrust headers are not available as device_vector is defined
// in an architecture dependent inline namespace that will clash with this forward declaration
namespace thrust
{

template<class T>
class device_allocator;

template<class T, class Alloc>
class device_vector;

} // namespace thrust
#endif

/*! @brief detection trait to determine whether a template parameter is an instance of thrust::device_vector
*
Expand All @@ -77,9 +76,3 @@ template<class Vector>
struct IsDeviceVector : public std::false_type
{
};

//! @brief detection of thrust device vectors
template<class T, class Alloc>
struct IsDeviceVector<thrust::device_vector<T, Alloc>> : public std::true_type
{
};
7 changes: 7 additions & 0 deletions domain/include/cstone/cuda/cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@
#include <thrust/device_vector.h>
#include <cuda_runtime.h>

#include "cuda_stubs.h"
#include "errorcheck.cuh"

//! @brief detection of thrust device vectors
template<class T, class Alloc>
struct IsDeviceVector<thrust::device_vector<T, Alloc>> : public std::true_type
{
};

template<class T, class Alloc>
T* rawPtr(thrust::device_vector<T, Alloc>& p)
{
Expand Down
3 changes: 2 additions & 1 deletion domain/include/cstone/cuda/cuda_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@

#if defined(USE_CUDA) || defined(__CUDACC__) || defined(__HIPCC__)
#include "cuda_utils.cuh"
#endif
#else
#include "cuda_stubs.h"
#endif
4 changes: 4 additions & 0 deletions domain/include/cstone/util/tuple.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ constexpr __host__ __device__ tuple<Ts&...> tie(Ts&... args) noexcept
namespace std
{

// Thrust tuples in CUDA are now cuda::std tuples for which structured bindings have been added in CUDA 12.4
#if (CUDART_VERSION < 12040) or defined(__HIPCC__)
template<size_t N, class... Ts>
struct tuple_element<N, thrust::tuple<Ts...>>
{
Expand All @@ -77,6 +79,8 @@ struct tuple_size<thrust::tuple<Ts...>>
static const int value = thrust::tuple_size<thrust::tuple<Ts...>>::value;
};

#endif

} // namespace std

#else
Expand Down

0 comments on commit 61a9674

Please sign in to comment.