Skip to content

Commit

Permalink
[Feat] Set the bucket count to the minimum needed to save memory when no rehash will occur
Browse files Browse the repository at this point in the history
- Add bucket_count API
  • Loading branch information
rhdong committed Apr 1, 2023
1 parent 9260627 commit f8a15cc
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
24 changes: 16 additions & 8 deletions include/merlin/core_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,23 @@ void create_table(Table<K, V, M>** table, const size_t dim,
CUDA_CHECK(cudaMemset(*table, 0, sizeof(Table<K, V, M>)));
(*table)->dim = dim;
(*table)->bucket_max_size = bucket_max_size;
(*table)->max_size = max_size;
(*table)->max_size = std::max(init_size, max_size);
(*table)->tile_size = tile_size;
(*table)->is_pure_hbm = true;
(*table)->bytes_per_slice = get_slice_size<K, V, M>(table);

(*table)->buckets_num = 1;
while ((*table)->buckets_num * (*table)->bucket_max_size < init_size) {
(*table)->buckets_num *= 2;
// The bucket number will be the minimum needed for saving memory if no
// rehash.
if ((init_size * 2) > (*table)->max_size) {
(*table)->buckets_num =
1 + (((*table)->max_size - 1) / (*table)->bucket_max_size);
} else {
(*table)->buckets_num = 1;
while ((*table)->buckets_num * (*table)->bucket_max_size < init_size) {
(*table)->buckets_num *= 2;
}
}

(*table)->capacity = (*table)->buckets_num * (*table)->bucket_max_size;
(*table)->max_hbm_for_vectors = max_hbm_for_vectors;
(*table)->remaining_hbm_for_vectors = max_hbm_for_vectors;
Expand Down Expand Up @@ -341,7 +349,7 @@ __forceinline__ __device__ void defragmentation_for_rehash(
break;
}
hashed_key = Murmur3HashDevice(find_key);
global_idx = hashed_key & (buckets_num * bucket_max_size - 1);
global_idx = hashed_key % (buckets_num * bucket_max_size);
start_idx = global_idx % bucket_max_size;

if ((start_idx <= empty_pos && empty_pos < key_idx) ||
Expand Down Expand Up @@ -526,7 +534,7 @@ __global__ void rehash_kernel_for_fast_mode(
if (target_key != static_cast<K>(EMPTY_KEY) &&
target_key != static_cast<K>(RECLAIM_KEY)) {
K hashed_key = Murmur3HashDevice(target_key);
global_idx = hashed_key & (buckets_num * bucket_max_size - 1);
global_idx = hashed_key % (buckets_num * bucket_max_size);
uint32_t new_bkt_idx = global_idx / bucket_max_size;
if (new_bkt_idx != bkt_idx) {
start_idx = global_idx % bucket_max_size;
Expand Down Expand Up @@ -768,7 +776,7 @@ __forceinline__ __device__ Bucket<K, V, M>* get_key_position(
Bucket<K, V, M>* __restrict buckets, const K key, size_t& bkt_idx,
size_t& start_idx, const size_t buckets_num, const size_t bucket_max_size) {
uint32_t hashed_key = Murmur3HashDevice(key);
size_t global_idx = hashed_key & (buckets_num * bucket_max_size - 1);
size_t global_idx = hashed_key % (buckets_num * bucket_max_size);
bkt_idx = global_idx / bucket_max_size;
start_idx = global_idx % bucket_max_size;
return buckets + bkt_idx;
Expand Down Expand Up @@ -1506,7 +1514,7 @@ __global__ void accum_kernel(
size_t key_idx = t / TILE_SIZE;
K insert_key = *(keys + key_idx);
K hashed_key = Murmur3HashDevice(insert_key);
size_t global_idx = hashed_key & (buckets_num * bucket_max_size - 1);
size_t global_idx = hashed_key % (buckets_num * bucket_max_size);
size_t bkt_idx = global_idx / bucket_max_size;
size_t start_idx = global_idx % bucket_max_size;

Expand Down
9 changes: 8 additions & 1 deletion include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class HashTable {
*
* @param ignore_evict_strategy A boolean option indicating whether if
* the insert_or_assign ignores the evict strategy of table with current
* metas anyway. If true, it does not check whether the metas confroms to
* metas anyway. If true, it does not check whether the metas conforms to
* the evict strategy. If false, it requires the metas follow the evict
* strategy of table.
*/
Expand Down Expand Up @@ -1280,6 +1280,13 @@ class HashTable {
return options_.max_bucket_size;
}

/**
 * @brief Gets how many buckets the table currently holds.
 *
 * @return The current number of buckets in the table.
 */
size_type bucket_count() const noexcept {
  return table_->buckets_num;
}

/**
* @brief Save keys, vectors, metas in table to file or files.
*
Expand Down
15 changes: 11 additions & 4 deletions tests/merlin_hashtable_test.cc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ template <class K, class M>
__device__ Table::Pred ExportIfPred = export_if_pred<K, M>;

void test_basic(size_t max_hbm_for_vectors) {
constexpr uint64_t INIT_CAPACITY = 64 * 1024 * 1024UL;
constexpr uint64_t BUCKET_MAX_SIZE = 128;
constexpr uint64_t INIT_CAPACITY = 64 * 1024 * 1024UL - (128 + 1);
constexpr uint64_t MAX_CAPACITY = INIT_CAPACITY;
constexpr uint64_t KEY_NUM = 1 * 1024 * 1024UL;
constexpr uint64_t TEST_TIMES = 1;
Expand All @@ -71,6 +72,7 @@ void test_basic(size_t max_hbm_for_vectors) {
options.init_capacity = INIT_CAPACITY;
options.max_capacity = MAX_CAPACITY;
options.dim = DIM;
options.max_bucket_size = BUCKET_MAX_SIZE;
options.max_hbm_for_vectors = nv::merlin::GB(max_hbm_for_vectors);
options.evict_strategy = nv::merlin::EvictStrategy::kCustomized;

Expand Down Expand Up @@ -113,6 +115,9 @@ void test_basic(size_t max_hbm_for_vectors) {
for (int i = 0; i < TEST_TIMES; i++) {
std::unique_ptr<Table> table = std::make_unique<Table>();
table->init(options);

ASSERT_EQ(table->bucket_count(),
524287); // 1 + (INIT_CAPACITY / options.bucket_max_size)
total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
ASSERT_EQ(total_size, 0);
Expand Down Expand Up @@ -748,7 +753,7 @@ void test_rehash_on_big_batch(size_t max_hbm_for_vectors) {

void test_dynamic_rehash_on_multi_threads(size_t max_hbm_for_vectors) {
constexpr uint64_t BUCKET_MAX_SIZE = 128ul;
constexpr uint64_t INIT_CAPACITY = 4 * 1024;
constexpr uint64_t INIT_CAPACITY = 4 * 1024 - BUCKET_MAX_SIZE - 1;
constexpr uint64_t MAX_CAPACITY = 16 * 1024 * INIT_CAPACITY;
constexpr uint64_t KEY_NUM = 256;
constexpr uint64_t THREAD_N = 8;
Expand All @@ -767,6 +772,7 @@ void test_dynamic_rehash_on_multi_threads(size_t max_hbm_for_vectors) {

std::shared_ptr<Table> table = std::make_shared<Table>();
table->init(options);
ASSERT_EQ(table->bucket_count(), 32);

auto worker_function = [&table, KEY_NUM, options](int task_n) {
K* h_keys;
Expand All @@ -790,7 +796,7 @@ void test_dynamic_rehash_on_multi_threads(size_t max_hbm_for_vectors) {
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));

while (table->capacity() < MAX_CAPACITY) {
while (table->capacity() * 2 < MAX_CAPACITY) {
test_util::create_random_keys<K, M, V, DIM>(h_keys, nullptr, h_vectors,
KEY_NUM);
CUDA_CHECK(cudaMemcpy(d_keys, h_keys, KEY_NUM * sizeof(K),
Expand Down Expand Up @@ -862,7 +868,7 @@ void test_dynamic_rehash_on_multi_threads(size_t max_hbm_for_vectors) {
for (auto& th : threads) {
th.join();
}
ASSERT_EQ(table->capacity(), MAX_CAPACITY);
ASSERT_GE(table->capacity() * 2, MAX_CAPACITY);
}

void test_export_batch_if(size_t max_hbm_for_vectors) {
Expand Down Expand Up @@ -1586,6 +1592,7 @@ void test_evict_strategy_customized_advanced(size_t max_hbm_for_vectors) {
for (int i = 0; i < TEST_TIMES; i++) {
std::unique_ptr<Table> table = std::make_unique<Table>();
table->init(options);
ASSERT_EQ(table->bucket_count(), BUCKET_NUM);

total_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
Expand Down

0 comments on commit f8a15cc

Please sign in to comment.