Skip to content

Commit

Permalink
enable reserve key
Browse files Browse the repository at this point in the history
  • Loading branch information
Julian Qian authored and rhdong committed May 17, 2024
1 parent d87034d commit e4aba46
Show file tree
Hide file tree
Showing 26 changed files with 481 additions and 236 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
.idea
.vscode
build

.clwb
cmake-build-debug/
docs/build
docs/source/README.md
docs/source/CONTRIBUTING.md
Expand Down
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,8 @@ add_executable(find_with_missed_keys_test tests/find_with_missed_keys_test.cc.cu
target_compile_features(find_with_missed_keys_test PUBLIC cxx_std_14)
set_target_properties(find_with_missed_keys_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(find_with_missed_keys_test gtest_main)

add_executable(reserved_keys_test tests/reserved_keys_test.cc.cu)
target_compile_features(reserved_keys_test PUBLIC cxx_std_14)
set_target_properties(reserved_keys_test PROPERTIES CUDA_ARCHITECTURES OFF)
TARGET_LINK_LIBRARIES(reserved_keys_test gtest_main)
45 changes: 29 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,30 @@ The `score_type` must be `uint64_t`. For more detail, please refer to [`class Ev

It's recommended to keep the default configuration for the options ending with `*`.

| Name | Type | Default | Description |
|:------------------------|:----------------|:--------|:------------------------------------------------------|
| __init_capacity__ | size_t | 0 | The initial capacity of the hash table. |
| __max_capacity__ | size_t | 0 | The maximum capacity of the hash table. |
| __max_hbm_for_vectors__ | size_t | 0 | The maximum HBM for vectors, in bytes. |
| __dim__ | size_t | 64 | The dimension of the value vectors. |
| __max_bucket_size*__ | size_t | 128 | The length of each bucket. |
| __max_load_factor*__ | float | 0.5f | The max load factor before rehashing. |
| __block_size*__ | int | 128 | The default block size for CUDA kernels. |
| __io_block_size*__ | int | 1024 | The block size for IO CUDA kernels. |
| __device_id*__ | int | -1 | The ID of device. Managed internally when set to `-1` |
| __io_by_cpu*__ | bool | false | The flag indicating if the CPU handles IO. |

| Name | Type | Default | Description |
|:---------------------------|:-------|:--------|:------------------------------------------------------|
| __init_capacity__ | size_t | 0 | The initial capacity of the hash table. |
| __max_capacity__ | size_t | 0 | The maximum capacity of the hash table. |
| __max_hbm_for_vectors__ | size_t | 0 | The maximum HBM for vectors, in bytes. |
| __dim__ | size_t | 64 | The dimension of the value vectors. |
| __max_bucket_size*__ | size_t | 128 | The length of each bucket. |
| __max_load_factor*__ | float | 0.5f | The max load factor before rehashing. |
| __block_size*__ | int | 128 | The default block size for CUDA kernels. |
| __io_block_size*__ | int | 1024 | The block size for IO CUDA kernels. |
| __device_id*__ | int | -1 | The ID of device. Managed internally when set to `-1` |
| __io_by_cpu*__ | bool | false | The flag indicating if the CPU handles IO. |
| __reserved_key_start_bit__ | int | 0 | The start bit offset of the reserved keys within the 64-bit key. |

#### Reserved Keys
- The keys `0xFFFFFFFFFFFFFFFD`, `0xFFFFFFFFFFFFFFFE`, and `0xFFFFFFFFFFFFFFFF` are reserved for internal use.
- Set `options.reserved_key_start_bit` to change the reserved keys if the default ones conflict with your keys.
The valid range of `reserved_key_start_bit` is [0, 62]; the default value is 0, which selects the default reserved keys.
`reserved_key_start_bit = 1` means that bits 1 and 2 (counting from the least significant bit, which is bit 0) are used to encode the reserved keys.
In binary format a reserved key then looks like `111~11xx0`: bit 0 is 0 and all the remaining bits are 1, so in this case the new reserved keys are
`0xFFFFFFFFFFFFFFFE`, `0xFFFFFFFFFFFFFFFC`, `0xFFFFFFFFFFFFFFF8`, and `0xFFFFFFFFFFFFFFFA`.
With `reserved_key_start_bit = 2`, the binary format is `111~11xx10`; bit 0 is always 0 for any `reserved_key_start_bit != 0`.
- If you change `reserved_key_start_bit`, you must use the same value when saving and loading the table.
For more detail, please refer to [`init_reserved_keys`](https://github.com/search?q=repo%3ANVIDIA-Merlin%2FHierarchicalKV%20init_reserved_keys&type=code).
For more detail, please refer to [`struct HashTableOptions`](https://github.com/NVIDIA-Merlin/HierarchicalKV/blob/master/include/merlin_hashtable.cuh#L60).

### How to use:
Expand Down Expand Up @@ -142,13 +153,10 @@ int main(int argc, char *argv[])
}

```
### Usage restrictions
- The `key_type` must be `int64_t` or `uint64_t`.
- The `score_type` must be `uint64_t`.
- The keys `0xFFFFFFFFFFFFFFFC`, `0xFFFFFFFFFFFFFFFD`, `0xFFFFFFFFFFFFFFFE`, and `0xFFFFFFFFFFFFFFFF` are reserved for internal use.
## Contributors
HierarchicalKV is co-maintained by [NVIDIA Merlin Team](https://github.com/NVIDIA-Merlin) and NVIDIA product end-users,
Expand All @@ -172,6 +180,11 @@ cd HierarchicalKV && mkdir -p build && cd build
cmake -DCMAKE_BUILD_TYPE=Release -Dsm=80 .. && make -j
```

For Debug:
```shell
cmake -DCMAKE_BUILD_TYPE=Debug -Dsm=80 .. && make -j
```

For Benchmark:
```shell
./merlin_hashtable_benchmark
Expand Down
8 changes: 4 additions & 4 deletions include/merlin/core_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
t += blockDim.x * gridDim.x) {
int key_idx = t / TILE_SIZE;
K find_key = keys[key_idx];
if (IS_RESERVED_KEY(find_key)) continue;
if (IS_RESERVED_KEY<K>(find_key)) continue;

int key_pos = -1;

Expand Down Expand Up @@ -719,7 +719,7 @@ __global__ void remove_kernel(const Table<K, V, S>* __restrict table,
bucket->keys(key_offset)->load(cuda::std::memory_order_relaxed);
current_score =
bucket->scores(key_offset)->load(cuda::std::memory_order_relaxed);
if (!IS_RESERVED_KEY(current_key)) {
if (!IS_RESERVED_KEY<K>(current_key)) {
if (pred(current_key, current_score, pattern, threshold)) {
atomicAdd(count, 1);
key_pos = key_offset;
Expand Down Expand Up @@ -782,7 +782,7 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table,
const int key_idx{static_cast<int>((tid + offset) % bucket_max_size)};
const K key{(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed)};

if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
size_t local_index{atomicAdd(&block_acc, 1)};
block_tuples[local_index] = {
key, &bucket->vectors[key_idx * dim],
Expand Down Expand Up @@ -846,7 +846,7 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table,
(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);

if (!IS_RESERVED_KEY(key) && pred(key, score, pattern, threshold)) {
if (!IS_RESERVED_KEY<K>(key) && pred(key, score, pattern, threshold)) {
size_t local_index = atomicAdd(&block_acc, 1);
block_result_key[local_index] = key;
for (int i = 0; i < dim; i++) {
Expand Down
4 changes: 2 additions & 2 deletions include/merlin/core_kernels/accum_or_assign.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ __global__ void accum_or_assign_kernel_with_io(

const K insert_key = keys[key_idx];

if (IS_RESERVED_KEY(insert_key)) continue;
if (IS_RESERVED_KEY<K>(insert_key)) continue;

const S insert_score =
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);
Expand Down Expand Up @@ -222,7 +222,7 @@ __global__ void accum_or_assign_kernel(

const K insert_key = keys[key_idx];

if (IS_RESERVED_KEY(insert_key)) continue;
if (IS_RESERVED_KEY<K>(insert_key)) continue;

const S insert_score =
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);
Expand Down
2 changes: 1 addition & 1 deletion include/merlin/core_kernels/contains.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ __global__ void contains_kernel(const Table<K, V, S>* __restrict table,
int key_idx = t / TILE_SIZE;

const K find_key = keys[key_idx];
if (IS_RESERVED_KEY(find_key)) continue;
if (IS_RESERVED_KEY<K>(find_key)) continue;

int key_pos = -1;
int src_lane = -1;
Expand Down
12 changes: 6 additions & 6 deletions include/merlin/core_kernels/find_or_insert.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ __global__ void tlp_v1_find_or_insert_kernel_with_io(
key = keys[kv_idx];
score = ScoreFunctor::desired_when_missed(scores, kv_idx, global_epoch);

if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -272,7 +272,7 @@ __global__ void tlp_v2_find_or_insert_kernel_with_io(
key = keys[kv_idx];
score = ScoreFunctor::desired_when_missed(scores, kv_idx, global_epoch);

if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -621,7 +621,7 @@ __global__ void pipeline_find_or_insert_kernel_with_io(
S* sm_param_scores = SMM::param_scores(smem);
__pipeline_memcpy_async(sm_param_scores + tx, scores + kv_idx, sizeof(S));
}
if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -1326,7 +1326,7 @@ __global__ void find_or_insert_kernel_with_io(

const K find_or_insert_key = keys[key_idx];

if (IS_RESERVED_KEY(find_or_insert_key)) continue;
if (IS_RESERVED_KEY<K>(find_or_insert_key)) continue;

const S find_or_insert_score =
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);
Expand Down Expand Up @@ -1463,7 +1463,7 @@ __global__ void find_or_insert_kernel_lock_key_hybrid(

score = ScoreFunctor::desired_when_missed(scores, kv_idx, global_epoch);

if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -1713,7 +1713,7 @@ __global__ void find_or_insert_kernel(

const K find_or_insert_key = keys[key_idx];

if (IS_RESERVED_KEY(find_or_insert_key)) continue;
if (IS_RESERVED_KEY<K>(find_or_insert_key)) continue;

const S find_or_insert_score =
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);
Expand Down
4 changes: 2 additions & 2 deletions include/merlin/core_kernels/find_ptr_or_insert.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ __global__ void find_or_insert_ptr_kernel_lock_key(
key = keys[kv_idx];
score = ScoreFunctor::desired_when_missed(scores, kv_idx, global_epoch);

if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -283,7 +283,7 @@ __global__ void find_ptr_or_insert_kernel(

const K find_or_insert_key = keys[key_idx];

if (IS_RESERVED_KEY(find_or_insert_key)) continue;
if (IS_RESERVED_KEY<K>(find_or_insert_key)) continue;

const S find_or_insert_score =
ScoreFunctor::desired_when_missed(scores, key_idx, global_epoch);
Expand Down
6 changes: 3 additions & 3 deletions include/merlin/core_kernels/lookup.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ __global__ void lookup_kernel_with_io(
int key_idx = t / TILE_SIZE;

const K find_key = keys[key_idx];
if (IS_RESERVED_KEY(find_key)) continue;
if (IS_RESERVED_KEY<K>(find_key)) continue;

V* find_value = values + key_idx * dim;

Expand Down Expand Up @@ -1015,7 +1015,7 @@ __device__ void tlp_lookup_kernel_hybrid_impl(
if (kv_idx < n) {
key = keys[kv_idx];
if (dst_offset) dst_offset[kv_idx] = kv_idx;
if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -1140,7 +1140,7 @@ __device__ void lookup_kernel_impl(
int key_idx = t / TILE_SIZE;

const K find_key = keys[key_idx];
if (IS_RESERVED_KEY(find_key)) continue;
if (IS_RESERVED_KEY<K>(find_key)) continue;

int key_pos = -1;
int src_lane = -1;
Expand Down
4 changes: 2 additions & 2 deletions include/merlin/core_kernels/lookup_ptr.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ __global__ void tlp_lookup_ptr_kernel_with_filter(
uint32_t key_pos = {0};
if (kv_idx < n) {
key = keys[kv_idx];
if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -140,7 +140,7 @@ __global__ void lookup_ptr_kernel(const Table<K, V, S>* __restrict table,
int key_idx = t / TILE_SIZE;

const K find_key = keys[key_idx];
if (IS_RESERVED_KEY(find_key)) continue;
if (IS_RESERVED_KEY<K>(find_key)) continue;

int key_pos = -1;
int src_lane = -1;
Expand Down
8 changes: 4 additions & 4 deletions include/merlin/core_kernels/update.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ __global__ void tlp_update_kernel_with_io(
if (kv_idx < n) {
key = keys[kv_idx];

if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -680,7 +680,7 @@ __global__ void update_kernel_with_io(

const K update_key = keys[key_idx];

if (IS_RESERVED_KEY(update_key)) continue;
if (IS_RESERVED_KEY<K>(update_key)) continue;

const V* update_value = values + key_idx * dim;

Expand Down Expand Up @@ -773,7 +773,7 @@ __global__ void tlp_update_kernel_hybrid(
if (kv_idx < n) {
key = keys[kv_idx];
if (src_offset) src_offset[kv_idx] = kv_idx;
if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -878,7 +878,7 @@ __global__ void update_kernel(const Table<K, V, S>* __restrict table,

const K update_key = keys[key_idx];

if (IS_RESERVED_KEY(update_key)) continue;
if (IS_RESERVED_KEY<K>(update_key)) continue;

size_t bkt_idx = 0;
size_t start_idx = 0;
Expand Down
4 changes: 2 additions & 2 deletions include/merlin/core_kernels/update_score.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ __global__ void tlp_update_score_kernel(Bucket<K, V, S>* __restrict__ buckets,
if (kv_idx < n) {
key = keys[kv_idx];

if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -549,7 +549,7 @@ __global__ void update_score_kernel(const Table<K, V, S>* __restrict table,

const K update_key = keys[key_idx];

if (IS_RESERVED_KEY(update_key)) continue;
if (IS_RESERVED_KEY<K>(update_key)) continue;

size_t bkt_idx = 0;
size_t start_idx = 0;
Expand Down
8 changes: 4 additions & 4 deletions include/merlin/core_kernels/update_values.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ __global__ void tlp_update_values_kernel_with_io(
if (kv_idx < n) {
key = keys[kv_idx];

if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -638,7 +638,7 @@ __global__ void update_values_kernel_with_io(

const K update_key = keys[key_idx];

if (IS_RESERVED_KEY(update_key)) continue;
if (IS_RESERVED_KEY<K>(update_key)) continue;

const V* update_value = values + key_idx * dim;

Expand Down Expand Up @@ -724,7 +724,7 @@ __global__ void tlp_update_values_kernel_hybrid(
if (kv_idx < n) {
key = keys[kv_idx];
if (src_offset) src_offset[kv_idx] = kv_idx;
if (!IS_RESERVED_KEY(key)) {
if (!IS_RESERVED_KEY<K>(key)) {
const K hashed_key = Murmur3HashDevice(key);
target_digests = digests_from_hashed<K>(hashed_key);
uint64_t global_idx =
Expand Down Expand Up @@ -823,7 +823,7 @@ __global__ void update_values_kernel(const Table<K, V, S>* __restrict table,

const K update_key = keys[key_idx];

if (IS_RESERVED_KEY(update_key)) continue;
if (IS_RESERVED_KEY<K>(update_key)) continue;

size_t bkt_idx = 0;
size_t start_idx = 0;
Expand Down
Loading

0 comments on commit e4aba46

Please sign in to comment.