-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
167 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#include <math.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
#include <algorithm> | ||
#include <array> | ||
#include <iostream> | ||
#include <thread> | ||
#include <vector> | ||
#include <map> | ||
#include <unordered_map> | ||
#include "merlin_hashtable.cuh" | ||
#include "merlin/types.cuh" | ||
#include "test_util.cuh" | ||
|
||
using K = uint64_t; | ||
using V = float; | ||
using S = uint64_t; | ||
using i64 = int64_t; | ||
using u64 = uint64_t; | ||
using f32 = float; | ||
using EvictStrategy = nv::merlin::EvictStrategy; | ||
using TableOptions = nv::merlin::HashTableOptions; | ||
|
||
template <class K, class S> | ||
struct ExportIfPredFunctor { | ||
__forceinline__ __device__ bool operator()(const K& key, S& score, | ||
const K& pattern, | ||
const S& threshold) { | ||
return score < threshold; | ||
} | ||
}; | ||
|
||
void test_export_batch_if() { | ||
constexpr uint64_t CAP = 1024ul; | ||
size_t n = 256; | ||
size_t n0 = 127; | ||
size_t n1 = 128; | ||
size_t n2 = 163; | ||
size_t dim = 32; | ||
size_t table_size = 0; | ||
i64 pattern = 0; | ||
u64 threshold = 40; | ||
|
||
cudaStream_t stream; | ||
CUDA_CHECK(cudaStreamCreate(&stream)); | ||
|
||
TableOptions options; | ||
options.init_capacity = CAP; | ||
options.max_capacity = CAP; | ||
options.dim = dim; | ||
options.max_hbm_for_vectors = nv::merlin::GB(100); | ||
using Table = nv::merlin::HashTable<i64, f32, u64, EvictStrategy::kCustomized>; | ||
|
||
std::unique_ptr<Table> table = std::make_unique<Table>(); | ||
table->init(options); | ||
|
||
test_util::KVMSBuffer<i64, f32, u64> buffer0; | ||
buffer0.Reserve(n0, dim, stream); | ||
buffer0.ToRange(0, 1, stream); | ||
buffer0.Setscore((u64)15, stream); | ||
table->insert_or_assign(n0, buffer0.keys_ptr(), buffer0.values_ptr(), buffer0.scores_ptr(), stream, true, false); | ||
table_size = table->size(stream); | ||
CUDA_CHECK(cudaStreamSynchronize(stream)); | ||
MERLIN_EXPECT_TRUE(table_size == n0, "Invalid table size."); | ||
|
||
test_util::KVMSBuffer<i64, f32, u64> buffer1; | ||
buffer1.Reserve(n1, dim, stream); | ||
buffer1.ToRange(n0, 1, stream); | ||
buffer1.Setscore((u64)30, stream); | ||
table->insert_or_assign(n1, buffer1.keys_ptr(), buffer1.values_ptr(), buffer1.scores_ptr(), stream, true, false); | ||
table_size = table->size(stream); | ||
CUDA_CHECK(cudaStreamSynchronize(stream)); | ||
MERLIN_EXPECT_TRUE(table_size == n0 + n1, "Invalid table size."); | ||
|
||
test_util::KVMSBuffer<i64, f32, u64> buffer2; | ||
buffer2.Reserve(n2, dim, stream); | ||
buffer2.ToRange(n0 + n1, 1, stream); | ||
buffer2.Setscore((u64)45, stream); | ||
table->insert_or_assign(n2, buffer2.keys_ptr(), buffer2.values_ptr(), buffer2.scores_ptr(), stream, true, false); | ||
table_size = table->size(stream); | ||
CUDA_CHECK(cudaStreamSynchronize(stream)); | ||
MERLIN_EXPECT_TRUE(table_size == n0 + n1 + n2, "Invalid table size."); | ||
|
||
test_util::KVMSBuffer<i64, f32, u64> buffer_out; | ||
buffer_out.Reserve(CAP, dim, stream); | ||
buffer_out.ToZeros(stream); | ||
|
||
size_t* d_cnt = nullptr; | ||
size_t h_cnt = 0; | ||
CUDA_CHECK(cudaMallocAsync(&d_cnt, sizeof(size_t), stream)); | ||
CUDA_CHECK(cudaMemsetAsync(d_cnt, 0, sizeof(size_t), stream)); | ||
CUDA_CHECK(cudaStreamSynchronize(stream)); | ||
table->export_batch_if<ExportIfPredFunctor>(pattern, threshold, | ||
static_cast<size_t>(CAP), 0, | ||
d_cnt, buffer_out.keys_ptr(), | ||
buffer_out.values_ptr(), | ||
buffer_out.scores_ptr(), | ||
stream); | ||
CUDA_CHECK(cudaMemcpyAsync(&h_cnt, d_cnt, sizeof(size_t), cudaMemcpyDeviceToHost, stream)); | ||
CUDA_CHECK(cudaStreamSynchronize(stream)); | ||
MERLIN_EXPECT_TRUE(h_cnt == n0 + n1, "export_batch_if get invalid cnt."); | ||
|
||
buffer_out.SyncData(false, stream); | ||
CUDA_CHECK(cudaStreamSynchronize(stream)); | ||
|
||
std::unordered_map<i64, u64> record; | ||
for (size_t i = 0; i < h_cnt; i++) { | ||
i64 key = buffer_out.keys_ptr(false)[i]; | ||
u64 score = buffer_out.scores_ptr(false)[i]; | ||
MERLIN_EXPECT_TRUE(key == static_cast<i64>(score), ""); | ||
record[key] = score; | ||
for (int j = 0; j < dim; j++) { | ||
f32 value = buffer_out.values_ptr(false)[i * dim + j]; | ||
MERLIN_EXPECT_TRUE(key == static_cast<i64>(value), ""); | ||
} | ||
} | ||
MERLIN_EXPECT_TRUE(record.size() == n0 + n1 + n2, ""); | ||
printf("done\n"); | ||
} | ||
|
||
int main() { | ||
test_export_batch_if(); | ||
return 0; | ||
} |