diff --git a/.gitmodules b/.gitmodules
index d41dd11d1..44d6d26f4 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,4 @@
 [submodule "tests/googletest"]
 	path = tests/googletest
 	url = https://github.com/google/googletest.git
+	ignore = dirty
diff --git a/include/merlin/core_kernels.cuh b/include/merlin/core_kernels.cuh
index 29c6afdcb..89834475b 100644
--- a/include/merlin/core_kernels.cuh
+++ b/include/merlin/core_kernels.cuh
@@ -1342,5 +1342,68 @@ __global__ void dump_kernel(const Table<K, V, M, DIM>* __restrict table,
   }
 }
 
+/* Dump with meta. */
+template <class K, class V, class M, size_t DIM>
+__global__ void dump_kernel(const Table<K, V, M, DIM>* __restrict table,
+                            const EraseIfPredictInternal<K, M> pred,
+                            const K pattern, const M threshold, K* d_key,
+                            V* __restrict d_val, M* __restrict d_meta,
+                            const size_t offset, const size_t search_length,
+                            size_t* d_dump_counter) {
+  extern __shared__ unsigned char s[];
+  K* smem = (K*)s;
+  K* block_result_key = smem;
+  V* block_result_val = (V*)&(smem[blockDim.x]);
+  M* block_result_meta = (M*)&(block_result_val[blockDim.x]);
+  __shared__ size_t block_acc;
+  __shared__ size_t global_acc;
+  const size_t bucket_max_size = table->bucket_max_size;
+
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (threadIdx.x == 0) {
+    block_acc = 0;
+  }
+  __syncthreads();
+
+  if (tid < search_length) {
+    int bkt_idx = (tid + offset) / bucket_max_size;
+    int key_idx = (tid + offset) % bucket_max_size;
+    Bucket<K, V, M, DIM>* bucket = &(table->buckets[bkt_idx]);
+
+    K key = bucket->keys[key_idx];
+    M meta = bucket->metas[key_idx].val;
+
+    if (key != EMPTY_KEY && pred(key, meta, pattern, threshold)) {
+      size_t local_index = atomicAdd(&block_acc, 1);
+      block_result_key[local_index] = bucket->keys[key_idx];
+      for (int i = 0; i < DIM; i++) {
+        atomicExch(&(block_result_val[local_index].values[i]),
+                   bucket->vectors[key_idx].values[i]);
+      }
+      if (d_meta != nullptr) {
+        block_result_meta[local_index] = bucket->metas[key_idx].val;
+      }
+    }
+  }
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    global_acc = atomicAdd(d_dump_counter, block_acc);
+  }
+  __syncthreads();
+
+  if (threadIdx.x < block_acc) {
+    d_key[global_acc + threadIdx.x] = block_result_key[threadIdx.x];
+    for (int i = 0; i < DIM; i++) {
+      d_val[global_acc + threadIdx.x].values[i] =
+          block_result_val[threadIdx.x].values[i];
+    }
+    if (d_meta != nullptr) {
+      d_meta[global_acc + threadIdx.x] = block_result_meta[threadIdx.x];
+    }
+  }
+}
+
 }  // namespace merlin
 }  // namespace nv
diff --git a/include/merlin/managed.cuh b/include/merlin/managed.cuh
deleted file mode 100644
index 40ea7239b..000000000
--- a/include/merlin/managed.cuh
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Copyright (c) 2022, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <cuda_runtime_api.h>
-
-namespace nv {
-namespace merlin {
-
-struct managed {
-  static void* operator new(size_t n) {
-    void* ptr = 0;
-    cudaError_t result = cudaMallocManaged(&ptr, n);
-    if (cudaSuccess != result || 0 == ptr) throw std::bad_alloc();
-    return ptr;
-  }
-
-  static void operator delete(void* ptr) noexcept { cudaFree(ptr); }
-};
-
-}  // namespace merlin
-}  // namespace nv
\ No newline at end of file
diff --git a/include/merlin/types.cuh b/include/merlin/types.cuh
index e0e4ead10..69d544e5f 100644
--- a/include/merlin/types.cuh
+++ b/include/merlin/types.cuh
@@ -78,7 +78,7 @@ struct Table {
 template <class K, class M>
 using EraseIfPredictInternal =
     bool (*)(const K& key,       ///< iterated key in table
-             const M& meta,      ///< iterated meta in table
+             M& meta,            ///< iterated meta in table
              const K& pattern,   ///< input key from caller
             const M& threshold  ///< input meta from caller
     );
diff --git a/include/merlin/utils.cuh b/include/merlin/utils.cuh
index 38caeb176..aea8df874 100644
--- a/include/merlin/utils.cuh
+++ b/include/merlin/utils.cuh
@@ -324,40 +324,6 @@ void realloc_managed(P* ptr, size_t old_size, size_t new_size) {
   return;
 }
 
-template <class K>
-__forceinline__ __device__ constexpr bool key_compare(const K* k1,
-                                                      const K* k2) {
-  auto __lhs_c = reinterpret_cast<unsigned char const*>(k1);
-  auto __rhs_c = reinterpret_cast<unsigned char const*>(k2);
-
-#pragma unroll
-  for (int i = 0; i < sizeof(K); i++) {
-    auto const __lhs_v = *__lhs_c++;
-    auto const __rhs_v = *__rhs_c++;
-    if (__lhs_v != __rhs_v) {
-      return false;
-    }
-  }
-  return true;
-}
-
-template <class K>
-__forceinline__ __device__ constexpr bool key_empty(const K* k) {
-  constexpr K empty_key = EMPTY_KEY;
-  auto __lhs_c = reinterpret_cast<unsigned char const*>(k);
-  auto __rhs_c = reinterpret_cast<unsigned char const*>(&empty_key);
-
-#pragma unroll
-  for (int i = 0; i < sizeof(K); i++) {
-    auto const __lhs_v = *__lhs_c++;
-    auto const __rhs_v = *__rhs_c++;
-    if (__lhs_v != __rhs_v) {
-      return false;
-    }
-  }
-  return true;
-}
-
 template <class mutex, int TILE_SIZE>
 __forceinline__ __device__ void lock(
     const cg::thread_block_tile<TILE_SIZE>& tile, mutex& set_mutex) {
diff --git a/include/merlin_hashtable.cuh b/include/merlin_hashtable.cuh
index b9da7842e..d8220c732 100644
--- a/include/merlin_hashtable.cuh
+++ b/include/merlin_hashtable.cuh
@@ -65,26 +65,37 @@ struct HashTableOptions {
  * @brief A customizable template function indicates which keys should be
  * erased from the hash table by returning `true`.
  *
- * @note The `erase_if` API traverses all of the items by this function and the
- * items that return `true` are removed.
+ * @note The `erase_if` and `export_batch_if` APIs traverse all items with this
+ * function; the items for which it returns `true` are removed or exported.
  *
- * Example:
+ * Example for erase_if:
  *
  * ```
  * template <class K, class M>
  * __forceinline__ __device__ bool erase_if_pred(const K& key,
- *                                               const M& meta,
+ *                                               M& meta,
  *                                               const K& pattern,
  *                                               const M& threshold) {
  *   return ((key & 0xFFFF000000000000 == pattern) &&
  *           (meta < threshold));
  * }
 * ```
+ *
+ * Example for export_batch_if:
+ * ```
+ * template <class K, class M>
+ * __forceinline__ __device__ bool export_if_pred(const K& key,
+ *                                                M& meta,
+ *                                                const K& pattern,
+ *                                                const M& threshold) {
+ *   return meta >= threshold;
+ * }
+ * ```
  */
 template <class K, class M>
 using EraseIfPredict = bool (*)(
     const K& key,       ///< The traversed key in a hash table.
-    const M& meta,      ///< The traversed meta in a hash table.
+    M& meta,            ///< The traversed meta in a hash table.
     const K& pattern,   ///< The key pattern to compare with the `key` argument.
     const M& threshold  ///< The threshold to compare with the `meta` argument.
 );
@@ -676,6 +687,7 @@ class HashTable {
    *
    * @param n The maximum number of exported pairs.
    * @param offset The position of the key to remove.
+   * @param counter The number of exported pairs; must be GPU-accessible memory.
    * @param keys The keys to dump from GPU-accessible memory with shape (n).
    * @param values The values to dump from GPU-accessible memory with shape
    * (n, DIM).
@@ -692,13 +704,13 @@ class HashTable {
    * memory. Reducing the value for @p n is currently required if this exception
    * occurs.
    */
-  void export_batch(size_type n, size_type offset, size_type* d_counter,
+  void export_batch(size_type n, size_type offset, size_type* counter,
                     key_type* keys,              // (n)
                     value_type* values,          // (n, DIM)
                     meta_type* metas = nullptr,  // (n)
                     cudaStream_t stream = 0) const {
     if (offset >= table_->capacity) {
-      CUDA_CHECK(cudaMemsetAsync(d_counter, 0, sizeof(size_type), stream));
+      CUDA_CHECK(cudaMemsetAsync(counter, 0, sizeof(size_type), stream));
       return;
     }
     n = std::min(table_->capacity - offset, n);
@@ -724,7 +736,7 @@ class HashTable {
     dump_kernel<key_type, vector_type, meta_type, DIM>
         <<<grid_size, block_size, shared_size, stream>>>(
             table_, keys, reinterpret_cast<vector_type*>(values), metas, offset,
-            n, d_counter);
+            n, counter);
 
     CudaCheckError();
   }
@@ -745,6 +757,92 @@ class HashTable {
     return h_counter;
   }
 
+  /**
+   * @brief Exports a certain number of key-value-meta tuples that match a
+   * specified condition from the hash table.
+   *
+   * @param n The maximum number of exported pairs.
+   * The value for @p pred should be a function with type `Pred` defined like
+   * the following example:
+   *
+   * ```
+   * template <class K, class M>
+   * __forceinline__ __device__ bool export_if_pred(const K& key,
+   *                                                M& meta,
+   *                                                const K& pattern,
+   *                                                const M& threshold) {
+   *
+   *   return meta > threshold;
+   * }
+   * ```
+   *
+   * @param pred The predicate function with type Pred that returns `true` if
+   * the element should be exported.
+   * @param pattern The third user-defined argument to @p pred with key_type
+   * type.
+   * @param threshold The fourth user-defined argument to @p pred with meta_type
+   * type.
+   * @param offset The position at which the export starts.
+   * @param keys The keys to dump from GPU-accessible memory with shape (n).
+   * @param values The values to dump from GPU-accessible memory with shape
+   * (n, DIM).
+   * @param metas The metas to dump from GPU-accessible memory with shape (n).
+   * @parblock
+   * If @p metas is `nullptr`, the meta for each key will not be returned.
+   * @endparblock
+   *
+   * @param stream The CUDA stream that is used to execute the operation.
+   *
+   * @param d_counter The number of exported pairs; must be GPU-accessible memory.
+   *
+   * @throw CudaException If the key-value size is too large for GPU shared
+   * memory. Reducing the value for @p n is currently required if this exception
+   * occurs.
+   */
+  void export_batch_if(Pred& pred, const key_type& pattern,
+                       const meta_type& threshold, size_type n,
+                       size_type offset, size_type* d_counter,
+                       key_type* keys,              // (n)
+                       value_type* values,          // (n, DIM)
+                       meta_type* metas = nullptr,  // (n)
+                       cudaStream_t stream = 0) const {
+    if (offset >= table_->capacity) {
+      CUDA_CHECK(cudaMemsetAsync(d_counter, 0, sizeof(size_type), stream));
+      return;
+    }
+
+    Pred h_pred;
+
+    n = std::min(table_->capacity - offset, n);
+    size_type meta_size = (metas == nullptr ? 0 : sizeof(meta_type));
+
+    std::shared_lock<std::shared_timed_mutex> lock(mutex_, std::defer_lock);
+    if (!reach_max_capacity_) {
+      lock.lock();
+    }
+    const size_t block_size =
+        std::min(shared_mem_size_ / 2 /
+                     (sizeof(key_type) + sizeof(vector_type) + meta_size),
+                 1024UL);
+
+    MERLIN_CHECK(
+        (block_size > 0),
+        "[merlin-kv] block_size <= 0, the K-V-M size may be too large!");
+    const size_t shared_size =
+        (sizeof(key_type) + sizeof(vector_type) + meta_size) * block_size;
+    const int grid_size = (n - 1) / (block_size) + 1;
+
+    CUDA_CHECK(cudaMemcpyFromSymbolAsync(&h_pred, pred, sizeof(Pred), 0,
+                                         cudaMemcpyDeviceToHost, stream));
+
+    dump_kernel<key_type, vector_type, meta_type, DIM>
+        <<<grid_size, block_size, shared_size, stream>>>(
+            table_, h_pred, pattern, threshold, keys,
+            reinterpret_cast<vector_type*>(values), metas, offset, n,
+            d_counter);
+    CudaCheckError();
+  }
+
  public:
   /**
    * @brief Indicates if the hash table has no elements.
diff --git a/tests/merlin_hashtable_test.cc.cu b/tests/merlin_hashtable_test.cc.cu
index eda9516ad..ed5c959cf 100644
--- a/tests/merlin_hashtable_test.cc.cu
+++ b/tests/merlin_hashtable_test.cc.cu
@@ -124,15 +124,24 @@ using Table = nv::merlin::HashTable<K, float, M, DIM>;
 using TableOptions = nv::merlin::HashTableOptions;
 
 template <class K, class M>
-__forceinline__ __device__ bool erase_if_pred(const K& key, const M& meta,
+__forceinline__ __device__ bool erase_if_pred(const K& key, M& meta,
                                               const K& pattern,
                                               const M& threshold) {
   return ((key & 0x7f > pattern) && (meta > threshold));
 }
 
-/* A demo of Pred for erase_if */
 template <class K, class M>
-__device__ Table::Pred pred = erase_if_pred<K, M>;
+__device__ Table::Pred EraseIfPred = erase_if_pred<K, M>;
+
+template <class K, class M>
+__forceinline__ __device__ bool export_if_pred(const K& key, M& meta,
+                                               const K& pattern,
+                                               const M& threshold) {
+  return meta > threshold;
+}
+
+template <class K, class M>
+__device__ Table::Pred ExportIfPred = export_if_pred<K, M>;
 
 void test_basic() {
   constexpr uint64_t INIT_CAPACITY = 64 * 1024 * 1024UL;
@@ -340,7 +349,8 @@ void test_erase_if_pred() {
   K pattern = 100;
   M threshold = 0;
 
-  size_t erase_num = table->erase_if(pred<K, M>, pattern, threshold, stream);
+  size_t erase_num =
+      table->erase_if(EraseIfPred<K, M>, pattern, threshold, stream);
   total_size = table->size(stream);
   CUDA_CHECK(cudaStreamSynchronize(stream));
   ASSERT_TRUE((erase_num + total_size) == BUCKET_MAX_SIZE);
@@ -637,9 +647,127 @@ void test_dynamic_rehash_on_multi_threads() {
   ASSERT_TRUE(table->capacity() == MAX_CAPACITY);
 }
 
+void test_export_batch_if() {
+  constexpr uint64_t INIT_CAPACITY = 256UL;
+  constexpr uint64_t MAX_CAPACITY = INIT_CAPACITY;
+  constexpr uint64_t KEY_NUM = 128UL;
+  constexpr uint64_t TEST_TIMES = 1;
+
+  K* h_keys;
+  M* h_metas;
+  Vector* h_vectors;
+  size_t h_dump_counter = 0;
+
+  TableOptions options;
+
+  options.init_capacity = INIT_CAPACITY;
+  options.max_capacity = MAX_CAPACITY;
+  options.max_hbm_for_vectors = nv::merlin::GB(16);
+  options.evict_strategy = nv::merlin::EvictStrategy::kCustomized;
+
+  std::unique_ptr<Table> table = std::make_unique<Table>();
+  table->init(options);
+
+  CUDA_CHECK(cudaMallocHost(&h_keys, KEY_NUM * sizeof(K)));
+  CUDA_CHECK(cudaMallocHost(&h_metas, KEY_NUM * sizeof(M)));
+  CUDA_CHECK(cudaMallocHost(&h_vectors, KEY_NUM * sizeof(Vector)));
+
+  K* d_keys;
+  M* d_metas = nullptr;
+  Vector* d_vectors;
+  bool* d_found;
+  size_t* d_dump_counter;
+
+  CUDA_CHECK(cudaMalloc(&d_keys, KEY_NUM * sizeof(K)));
+  CUDA_CHECK(cudaMalloc(&d_metas, KEY_NUM * sizeof(M)));
+  CUDA_CHECK(cudaMalloc(&d_vectors, KEY_NUM * sizeof(Vector)));
+  CUDA_CHECK(cudaMalloc(&d_found, KEY_NUM * sizeof(bool)));
+  CUDA_CHECK(cudaMalloc(&d_dump_counter, sizeof(size_t)));
+
+  cudaStream_t stream;
+  CUDA_CHECK(cudaStreamCreate(&stream));
+
+  uint64_t total_size = 0;
+  for (int i = 0; i < TEST_TIMES; i++) {
+    create_random_keys<K, M, float, DIM>(
+        h_keys, h_metas, reinterpret_cast<float*>(h_vectors), KEY_NUM);
+
+    CUDA_CHECK(cudaMemcpy(d_keys, h_keys, KEY_NUM * sizeof(K),
+                          cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(d_metas, h_metas, KEY_NUM * sizeof(M),
+                          cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(d_vectors, h_vectors, KEY_NUM * sizeof(Vector),
+                          cudaMemcpyHostToDevice));
+
+    total_size = table->size(stream);
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+    ASSERT_TRUE(total_size == 0);
+
+    table->insert_or_assign(
+        KEY_NUM, d_keys, reinterpret_cast<float*>(d_vectors), d_metas, stream);
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    total_size = table->size(stream);
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+    ASSERT_TRUE(total_size == KEY_NUM);
+
+    K pattern = 100;
+    M threshold = h_metas[size_t(KEY_NUM / 2)];
+
+    table->export_batch_if(ExportIfPred<K, M>, pattern, threshold,
+                           table->capacity(), 0, d_dump_counter, d_keys,
+                           reinterpret_cast<float*>(d_vectors), d_metas,
+                           stream);
+
+    CUDA_CHECK(cudaMemcpy(&h_dump_counter, d_dump_counter, sizeof(size_t),
+                          cudaMemcpyDeviceToHost));
+
+    size_t expected_export_count = 0;
+    for (int i = 0; i < KEY_NUM; i++) {
+      if (h_metas[i] > threshold) expected_export_count++;
+    }
+    ASSERT_TRUE(expected_export_count == h_dump_counter);
+
+    CUDA_CHECK(cudaMemset(h_metas, 0, KEY_NUM * sizeof(M)));
+    CUDA_CHECK(cudaMemset(h_vectors, 0, KEY_NUM * sizeof(Vector)));
+
+    CUDA_CHECK(cudaMemcpy(h_metas, d_metas, KEY_NUM * sizeof(M),
+                          cudaMemcpyDeviceToHost));
+    CUDA_CHECK(cudaMemcpy(h_vectors, d_vectors, KEY_NUM * sizeof(Vector),
+                          cudaMemcpyDeviceToHost));
+
+    for (int i = 0; i < h_dump_counter; i++) {
+      ASSERT_TRUE(h_metas[i] > threshold);
+    }
+
+    table->clear(stream);
+    total_size = table->size(stream);
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+    ASSERT_TRUE(total_size == 0);
+  }
+  CUDA_CHECK(cudaDeviceSynchronize());
+  CUDA_CHECK(cudaStreamDestroy(stream));
+
+  CUDA_CHECK(cudaMemcpy(h_vectors, d_vectors, KEY_NUM * sizeof(Vector),
+                        cudaMemcpyDeviceToHost));
+
+  CUDA_CHECK(cudaFreeHost(h_keys));
+  CUDA_CHECK(cudaFreeHost(h_metas));
+
+  CUDA_CHECK(cudaFree(d_keys));
+  CUDA_CHECK(cudaFree(d_metas));
+  CUDA_CHECK(cudaFree(d_vectors));
+  CUDA_CHECK(cudaFree(d_found));
+  CUDA_CHECK(cudaFree(d_dump_counter));
+  CUDA_CHECK(cudaDeviceSynchronize());
+
+  CudaCheckError();
+}
+
 TEST(MerlinHashTableTest, test_basic) { test_basic(); }
 TEST(MerlinHashTableTest, test_erase_if_pred) { test_erase_if_pred(); }
 TEST(MerlinHashTableTest, test_rehash) { test_rehash(); }
 TEST(MerlinHashTableTest, test_dynamic_rehash_on_multi_threads) {
   test_dynamic_rehash_on_multi_threads();
 }
+TEST(MerlinHashTableTest, test_export_batch_if) { test_export_batch_if(); }
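
Usage note (not part of the diff): the sketch below condenses the `export_batch_if` call sequence exercised by `test_export_batch_if` into a standalone host-side helper. It is a minimal sketch, assuming the `K`/`M`/`DIM` aliases and the `ExportIfPred` variable template from `tests/merlin_hashtable_test.cc.cu`, preallocated device buffers sized to `table.capacity()`, and the hypothetical helper name `export_above_threshold`.

```
// Minimal sketch, not part of the diff: drive the new export_batch_if API.
// Assumes d_keys/d_values/d_metas hold room for table.capacity() entries.
#include <cuda_runtime_api.h>

#include "merlin_hashtable.cuh"

constexpr size_t DIM = 4;
using K = uint64_t;
using M = uint64_t;
using Table = nv::merlin::HashTable<K, float, M, DIM>;

// Device-side predicate: keep entries whose meta exceeds the threshold.
template <class K, class M>
__forceinline__ __device__ bool export_if_pred(const K& key, M& meta,
                                               const K& pattern,
                                               const M& threshold) {
  return meta > threshold;
}

// Device symbol holding the predicate pointer; export_batch_if copies it back
// to the host with cudaMemcpyFromSymbolAsync before launching dump_kernel.
template <class K, class M>
__device__ Table::Pred ExportIfPred = export_if_pred<K, M>;

// Returns the number of exported tuples.
size_t export_above_threshold(Table& table, M threshold, K* d_keys,
                              float* d_values, M* d_metas,
                              cudaStream_t stream) {
  size_t* d_counter = nullptr;
  CUDA_CHECK(cudaMalloc(&d_counter, sizeof(size_t)));

  // Scan the whole table: n = capacity, offset = 0. The pattern argument is
  // unused by this predicate, so any value works.
  table.export_batch_if(ExportIfPred<K, M>, /*pattern=*/0, threshold,
                        table.capacity(), /*offset=*/0, d_counter, d_keys,
                        d_values, d_metas, stream);

  size_t h_counter = 0;
  CUDA_CHECK(cudaMemcpyAsync(&h_counter, d_counter, sizeof(size_t),
                             cudaMemcpyDeviceToHost, stream));
  CUDA_CHECK(cudaStreamSynchronize(stream));
  CUDA_CHECK(cudaFree(d_counter));
  return h_counter;
}
```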