# Patch summary (commentary only; this text precedes the first "diff --git" header and is not part of any hunk).
#
# CMakeLists.txt
#   - Adds include(cmake/utils/compile_flags.cmake) and include(cmake/utils/platform_check.cmake)
#     directly after include(cmake/utils/utils.cmake), ahead of knowhere_option(WITH_RAFT ...).
#   - Removes the same two include() lines from their previous position after
#     add_definitions(-DNOT_COMPILE_FOR_SWIG), so the net effect is that they are evaluated earlier.
#
# src/index/gpu_raft/gpu_raft_{brute_force,cagra,ivf_flat,ivf_pq}.cc
#   - Each file gains #include "raft/util/cuda_rt_essentials.hpp".
#   - In all eight KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL registrations (GPU_RAFT_BRUTE_FORCE,
#     GPU_BRUTE_FORCE, GPU_RAFT_CAGRA, GPU_CAGRA, GPU_RAFT_IVF_FLAT, GPU_IVF_FLAT, GPU_RAFT_IVF_PQ,
#     GPU_IVF_PQ) the last argument changes from cuda_concurrent_size to an immediately-invoked
#     lambda that calls cudaGetDeviceCount(&count) inside RAFT_CUDA_TRY and returns count * 8.
#   - The ivf_flat hunk additionally drops two blank lines.
#
# NOTE(review): the identical lambda is repeated eight times; a single shared helper would keep the
#   copies from drifting apart -- confirm whether a common header is available for it.
# NOTE(review): the multiplier 8 is not explained anywhere in this patch -- confirm its intent.
# NOTE(review): RAFT_CUDA_TRY presumably raises when cudaGetDeviceCount fails; the macro's definition
#   and the point at which the registration argument is evaluated are not visible here -- verify the
#   behaviour on hosts without a usable CUDA device.
# NOTE(review): the patch text below is whitespace-collapsed (hunk lines are joined by spaces); it is
#   reproduced unchanged.
diff --git a/CMakeLists.txt b/CMakeLists.txt index 8c2f574e7..530c581dc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,8 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/") include(GNUInstallDirs) include(ExternalProject) include(cmake/utils/utils.cmake) +include(cmake/utils/compile_flags.cmake) +include(cmake/utils/platform_check.cmake) knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF) if (WITH_RAFT) @@ -88,8 +90,6 @@ list( APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}/) add_definitions(-DNOT_COMPILE_FOR_SWIG) -include(cmake/utils/compile_flags.cmake) -include(cmake/utils/platform_check.cmake) include(cmake/libs/libfaiss.cmake) include(cmake/libs/libhnsw.cmake) diff --git a/src/index/gpu_raft/gpu_raft_brute_force.cc b/src/index/gpu_raft/gpu_raft_brute_force.cc index e2b05408c..1cddd7580 100644 --- a/src/index/gpu_raft/gpu_raft_brute_force.cc +++ b/src/index/gpu_raft/gpu_raft_brute_force.cc @@ -21,8 +21,17 @@ #include "gpu_raft.h" #include "knowhere/factory.h" #include "knowhere/index_node_thread_pool_wrapper.h" +#include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, cuda_concurrent_size); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, cuda_concurrent_size); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_cagra.cc b/src/index/gpu_raft/gpu_raft_cagra.cc index 350ee6dd9..baaa5947b 100644 --- a/src/index/gpu_raft/gpu_raft_cagra.cc +++ b/src/index/gpu_raft/gpu_raft_cagra.cc @@ -21,8 
+21,17 @@ #include "gpu_raft.h" #include "knowhere/factory.h" #include "knowhere/index_node_thread_pool_wrapper.h" +#include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_CAGRA, GpuRaftCagraIndexNode, fp32, cuda_concurrent_size); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CAGRA, GpuRaftCagraIndexNode, fp32, cuda_concurrent_size); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_CAGRA, GpuRaftCagraIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CAGRA, GpuRaftCagraIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_ivf_flat.cc b/src/index/gpu_raft/gpu_raft_ivf_flat.cc index 0f986be2e..7631147f0 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_flat.cc +++ b/src/index/gpu_raft/gpu_raft_ivf_flat.cc @@ -21,9 +21,16 @@ #include "gpu_raft.h" #include "knowhere/factory.h" #include "knowhere/index_node_thread_pool_wrapper.h" - +#include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, cuda_concurrent_size); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, cuda_concurrent_size); - +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_ivf_pq.cc b/src/index/gpu_raft/gpu_raft_ivf_pq.cc index 95083edac..43d87bd24 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_pq.cc +++ 
b/src/index/gpu_raft/gpu_raft_ivf_pq.cc @@ -21,8 +21,23 @@ #include "gpu_raft.h" #include "knowhere/factory.h" #include "knowhere/index_node_thread_pool_wrapper.h" +#include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, cuda_concurrent_size); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, cuda_concurrent_size); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, + []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; + }() + +); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, + []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; + }() + +); } // namespace knowhere