From 7d122f64aac643f443c7b15fe0f6ffa0dff8c05e Mon Sep 17 00:00:00 2001 From: "yusheng.ma" Date: Sat, 24 Feb 2024 23:26:50 +0800 Subject: [PATCH] The thread pool size is related to the number of gpu. Signed-off-by: yusheng.ma --- CMakeLists.txt | 4 ++-- src/index/gpu_raft/gpu_raft_brute_force.cc | 13 +++++++++++-- src/index/gpu_raft/gpu_raft_cagra.cc | 13 +++++++++++-- src/index/gpu_raft/gpu_raft_ivf_flat.cc | 15 +++++++++++---- src/index/gpu_raft/gpu_raft_ivf_pq.cc | 19 +++++++++++++++++-- 5 files changed, 52 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8c2f574e7..530c581dc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,8 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/") include(GNUInstallDirs) include(ExternalProject) include(cmake/utils/utils.cmake) +include(cmake/utils/compile_flags.cmake) +include(cmake/utils/platform_check.cmake) knowhere_option(WITH_RAFT "Build with RAFT indexes" OFF) if (WITH_RAFT) @@ -88,8 +90,6 @@ list( APPEND CMAKE_MODULE_PATH ${CMAKE_BINARY_DIR}/) add_definitions(-DNOT_COMPILE_FOR_SWIG) -include(cmake/utils/compile_flags.cmake) -include(cmake/utils/platform_check.cmake) include(cmake/libs/libfaiss.cmake) include(cmake/libs/libhnsw.cmake) diff --git a/src/index/gpu_raft/gpu_raft_brute_force.cc b/src/index/gpu_raft/gpu_raft_brute_force.cc index e2b05408c..1cddd7580 100644 --- a/src/index/gpu_raft/gpu_raft_brute_force.cc +++ b/src/index/gpu_raft/gpu_raft_brute_force.cc @@ -21,8 +21,17 @@ #include "gpu_raft.h" #include "knowhere/factory.h" #include "knowhere/index_node_thread_pool_wrapper.h" +#include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, cuda_concurrent_size); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, cuda_concurrent_size); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_BRUTE_FORCE, GpuRaftBruteForceIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_cagra.cc b/src/index/gpu_raft/gpu_raft_cagra.cc index 350ee6dd9..baaa5947b 100644 --- a/src/index/gpu_raft/gpu_raft_cagra.cc +++ b/src/index/gpu_raft/gpu_raft_cagra.cc @@ -21,8 +21,17 @@ #include "gpu_raft.h" #include "knowhere/factory.h" #include "knowhere/index_node_thread_pool_wrapper.h" +#include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_CAGRA, GpuRaftCagraIndexNode, fp32, cuda_concurrent_size); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CAGRA, GpuRaftCagraIndexNode, fp32, cuda_concurrent_size); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_CAGRA, GpuRaftCagraIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_CAGRA, GpuRaftCagraIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_ivf_flat.cc b/src/index/gpu_raft/gpu_raft_ivf_flat.cc index 0f986be2e..7631147f0 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_flat.cc +++ b/src/index/gpu_raft/gpu_raft_ivf_flat.cc @@ -21,9 +21,16 @@ #include "gpu_raft.h" #include "knowhere/factory.h" #include "knowhere/index_node_thread_pool_wrapper.h" - +#include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, cuda_concurrent_size); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, cuda_concurrent_size); - +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_FLAT, GpuRaftIvfFlatIndexNode, fp32, []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; +}()); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_ivf_pq.cc b/src/index/gpu_raft/gpu_raft_ivf_pq.cc index 95083edac..43d87bd24 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_pq.cc +++ b/src/index/gpu_raft/gpu_raft_ivf_pq.cc @@ -21,8 +21,23 @@ #include "gpu_raft.h" #include "knowhere/factory.h" #include "knowhere/index_node_thread_pool_wrapper.h" +#include "raft/util/cuda_rt_essentials.hpp" namespace knowhere { -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, cuda_concurrent_size); -KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, cuda_concurrent_size); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_RAFT_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, + []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; + }() + +); +KNOWHERE_REGISTER_GLOBAL_WITH_THREAD_POOL(GPU_IVF_PQ, GpuRaftIvfPqIndexNode, fp32, + []() { + int count; + RAFT_CUDA_TRY(cudaGetDeviceCount(&count)); + return count * 8; + }() + +); } // namespace knowhere