[Feat] add export_batch_if API
rhdong committed Oct 10, 2022
1 parent d6354d6 commit e718558
Showing 7 changed files with 303 additions and 83 deletions.
1 change: 1 addition & 0 deletions .gitmodules
@@ -1,3 +1,4 @@
[submodule "tests/googletest"]
path = tests/googletest
url = https://github.com/google/googletest.git
ignore = dirty
63 changes: 63 additions & 0 deletions include/merlin/core_kernels.cuh
@@ -1342,5 +1342,68 @@ __global__ void dump_kernel(const Table<K, V, M, DIM>* __restrict table,
}
}

/* Dump with meta for the entries that match the predicate. */
template <class K, class V, class M, size_t DIM>
__global__ void dump_kernel(const Table<K, V, M, DIM>* __restrict table,
const EraseIfPredictInternal<K, M> pred,
const K pattern, const M threshold, K* d_key,
V* __restrict d_val, M* __restrict d_meta,
const size_t offset, const size_t search_length,
size_t* d_dump_counter) {
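// Stage matching entries in shared memory: keys first, then value vectors, then metas.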
extern __shared__ unsigned char s[];
K* smem = (K*)s;
K* block_result_key = smem;
V* block_result_val = (V*)&(smem[blockDim.x]);
M* block_result_meta = (M*)&(block_result_val[blockDim.x]);
__shared__ size_t block_acc;
__shared__ size_t global_acc;
const size_t bucket_max_size = table->bucket_max_size;

const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;

if (threadIdx.x == 0) {
block_acc = 0;
}
__syncthreads();

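// Each thread inspects one bucket slot; non-empty keys that pass the predicate
// are appended to the block-local staging buffers.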
if (tid < search_length) {
int bkt_idx = (tid + offset) / bucket_max_size;
int key_idx = (tid + offset) % bucket_max_size;
Bucket<K, V, M, DIM>* bucket = &(table->buckets[bkt_idx]);

K key = bucket->keys[key_idx];
M meta = bucket->metas[key_idx].val;

if (key != EMPTY_KEY && pred(key, meta, pattern, threshold)) {
size_t local_index = atomicAdd(&block_acc, 1);
block_result_key[local_index] = bucket->keys[key_idx];
for (int i = 0; i < DIM; i++) {
atomicExch(&(block_result_val[local_index].values[i]),
bucket->vectors[key_idx].values[i]);
}
if (d_meta != nullptr) {
block_result_meta[local_index] = bucket->metas[key_idx].val;
}
}
}
__syncthreads();

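// Thread 0 reserves a contiguous range of the output arrays with a single
// atomicAdd on the global dump counter.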
if (threadIdx.x == 0) {
global_acc = atomicAdd(d_dump_counter, block_acc);
}
__syncthreads();

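// The first block_acc threads copy the staged entries into the reserved range.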
if (threadIdx.x < block_acc) {
d_key[global_acc + threadIdx.x] = block_result_key[threadIdx.x];
for (int i = 0; i < DIM; i++) {
d_val[global_acc + threadIdx.x].values[i] =
block_result_val[threadIdx.x].values[i];
}
if (d_meta != nullptr) {
d_meta[global_acc + threadIdx.x] = block_result_meta[threadIdx.x];
}
}
}

} // namespace merlin
} // namespace nv
36 changes: 0 additions & 36 deletions include/merlin/managed.cuh

This file was deleted.

2 changes: 1 addition & 1 deletion include/merlin/types.cuh
@@ -78,7 +78,7 @@ struct Table {
template <class K, class M>
using EraseIfPredictInternal =
bool (*)(const K& key, ///< iterated key in table
const M& meta, ///< iterated meta in table
M& meta, ///< iterated meta in table
const K& pattern, ///< input key from caller
const M& threshold ///< input meta from caller
);
34 changes: 0 additions & 34 deletions include/merlin/utils.cuh
@@ -324,40 +324,6 @@ void realloc_managed(P* ptr, size_t old_size, size_t new_size) {
return;
}

template <class K>
__forceinline__ __device__ constexpr bool key_compare(const K* k1,
const K* k2) {
auto __lhs_c = reinterpret_cast<unsigned char const*>(k1);
auto __rhs_c = reinterpret_cast<unsigned char const*>(k2);

#pragma unroll
for (int i = 0; i < sizeof(K); i++) {
auto const __lhs_v = *__lhs_c++;
auto const __rhs_v = *__rhs_c++;
if (__lhs_v != __rhs_v) {
return false;
}
}
return true;
}

template <class K>
__forceinline__ __device__ constexpr bool key_empty(const K* k) {
constexpr K empty_key = EMPTY_KEY;
auto __lhs_c = reinterpret_cast<unsigned char const*>(k);
auto __rhs_c = reinterpret_cast<unsigned char const*>(&empty_key);

#pragma unroll
for (int i = 0; i < sizeof(K); i++) {
auto const __lhs_v = *__lhs_c++;
auto const __rhs_v = *__rhs_c++;
if (__lhs_v != __rhs_v) {
return false;
}
}
return true;
}

template <typename mutex, uint32_t TILE_SIZE, bool THREAD_SAFE = true>
__forceinline__ __device__ void lock(
const cg::thread_block_tile<TILE_SIZE>& tile, mutex& set_mutex) {
114 changes: 106 additions & 8 deletions include/merlin_hashtable.cuh
@@ -65,26 +65,37 @@ struct HashTableOptions {
* @brief A customizable template function that indicates which keys should
* be erased from the hash table by returning `true`.
*
* @note The `erase_if` API traverses all of the items by this function and the
* items that return `true` are removed.
* @note The `erase_if` and `export_batch_if` APIs traverse all of the items
* with this function, and the items for which it returns `true` are removed
* or exported, respectively.
*
* Example:
* Example for erase_if:
*
* ```
* template <class K, class M>
* __forceinline__ __device__ bool erase_if_pred(const K& key,
* const M& meta,
* M& meta,
* const K& pattern,
* const M& threshold) {
* return ((key & 0xFFFF000000000000 == pattern) &&
* (meta < threshold));
* }
* ```
*
* Example for export_batch_if:
* ```
* template <class K, class M>
* __forceinline__ __device__ bool export_if_pred(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
* return meta >= threshold;
* }
* ```
*/
template <class K, class M>
using EraseIfPredict = bool (*)(
const K& key, ///< The traversed key in a hash table.
const M& meta, ///< The traversed meta in a hash table.
M& meta, ///< The traversed meta in a hash table.
const K& pattern, ///< The key pattern to compare with the `key` argument.
const M& threshold ///< The threshold to compare with the `meta` argument.
);
@@ -676,6 +687,7 @@ class HashTable {
*
* @param n The maximum number of exported pairs.
* @param offset The position at which to start exporting.
* @param counter The pointer to receive the number of exported key-value
* pairs.
* @param keys The keys to dump from GPU-accessible memory with shape (n).
* @param values The values to dump from GPU-accessible memory with shape
* (n, DIM).
@@ -692,13 +704,13 @@
* memory. Reducing the value for @p n is currently required if this exception
* occurs.
*/
void export_batch(size_type n, size_type offset, size_type* d_counter,
void export_batch(size_type n, size_type offset, size_type* counter,
key_type* keys, // (n)
value_type* values, // (n, DIM)
meta_type* metas = nullptr, // (n)
cudaStream_t stream = 0) const {
if (offset >= table_->capacity) {
CUDA_CHECK(cudaMemsetAsync(d_counter, 0, sizeof(size_type), stream));
CUDA_CHECK(cudaMemsetAsync(counter, 0, sizeof(size_type), stream));
return;
}
n = std::min(table_->capacity - offset, n);
@@ -724,7 +736,7 @@
dump_kernel<key_type, vector_type, meta_type, DIM>
<<<grid_size, block_size, shared_size, stream>>>(
table_, keys, reinterpret_cast<vector_type*>(values), metas, offset,
n, d_counter);
n, counter);
CudaCheckError();
}

@@ -745,6 +757,92 @@
return h_counter;
}

/**
* @brief Exports a certain number of key-value-meta tuples that match a
* specified condition from the hash table.
*
* @param n The maximum number of exported pairs.
* @param pred The predicate function with type `Pred` that returns `true` if
* the element should be exported. It should be defined like the following
* example:
*
* ```
* template <class K, class M>
* __forceinline__ __device__ bool export_if_pred(const K& key,
* M& meta,
* const K& pattern,
* const M& threshold) {
*
* return meta > threshold;
* }
* ```
*
* @param pattern The third user-defined argument to @p pred with key_type
* type.
* @param threshold The fourth user-defined argument to @p pred with meta_type
* type.
* @param offset The position at which to start exporting.
* @param d_counter The pointer to receive the number of exported
* key-value-meta tuples.
* @param keys The keys to dump from GPU-accessible memory with shape (n).
* @param values The values to dump from GPU-accessible memory with shape
* (n, DIM).
* @param metas The metas to dump from GPU-accessible memory with shape (n).
* @parblock
* If @p metas is `nullptr`, the meta for each key will not be returned.
* @endparblock
*
* @param stream The CUDA stream that is used to execute the operation.
*
* @throw CudaException If the key-value size is too large for GPU shared
* memory. Reducing the value for @p n is currently required if this exception
* occurs.
*/
void export_batch_if(Pred& pred, const key_type& pattern,
const meta_type& threshold, size_type n,
size_type offset, size_type* d_counter,
key_type* keys, // (n)
value_type* values, // (n, DIM)
meta_type* metas = nullptr, // (n)
cudaStream_t stream = 0) const {
if (offset >= table_->capacity) {
CUDA_CHECK(cudaMemsetAsync(d_counter, 0, sizeof(size_type), stream));
return;
}

Pred h_pred;

n = std::min(table_->capacity - offset, n);
size_type meta_size = (metas == nullptr ? 0 : sizeof(meta_type));

std::shared_lock<std::shared_timed_mutex> lock(mutex_, std::defer_lock);
if (!reach_max_capacity_) {
lock.lock();
}
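// Size the thread block so that the per-block staging buffers (key + value
// vector + optional meta per entry) fit into half of the available shared
// memory, capped at 1024 threads.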
const size_t block_size =
std::min(shared_mem_size_ / 2 /
(sizeof(key_type) + sizeof(vector_type) + meta_size),
1024UL);

MERLIN_CHECK(
(block_size > 0),
"[merlin-kv] block_size <= 0, the K-V-M size may be too large!");
const size_t shared_size =
(sizeof(key_type) + sizeof(vector_type) + meta_size) * block_size;
const int grid_size = (n - 1) / (block_size) + 1;

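// pred refers to a __device__ symbol holding the predicate function pointer;
// copy its value back to the host so it can be passed to the kernel by value.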
CUDA_CHECK(cudaMemcpyFromSymbolAsync(&h_pred, pred, sizeof(Pred), 0,
cudaMemcpyDeviceToHost, stream));

dump_kernel<key_type, vector_type, meta_type, DIM>
<<<grid_size, block_size, shared_size, stream>>>(
table_, h_pred, pattern, threshold, keys,
reinterpret_cast<vector_type*>(values), metas, offset, n,
d_counter);
CudaCheckError();
}

public:
/**
* @brief Indicates if the hash table has no elements.
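To illustrate how the new `export_batch_if` API is meant to be driven from the host, here is a minimal sketch. It is not part of this commit: the type aliases, the `HashTable` template arguments, the include path, the `export_if_pred` / `g_export_pred` / `export_hot_entries` names, and the assumption that the scanned length `n` covers the whole table are all illustrative.

```
#include <cuda_runtime.h>

#include "merlin_hashtable.cuh"

using K = uint64_t;  // key_type
using M = uint64_t;  // meta_type
using V = float;     // element type of the value vectors
constexpr size_t DIM = 4;
using Table = nv::merlin::HashTable<K, V, M, DIM>;

// Predicate: export every entry whose meta reaches the threshold.
template <class Key, class Meta>
__device__ bool export_if_pred(const Key& key, Meta& meta, const Key& pattern,
                               const Meta& threshold) {
  return meta >= threshold;
}

// export_batch_if reads the predicate with cudaMemcpyFromSymbolAsync, so the
// function pointer has to live in a __device__ symbol.
__device__ nv::merlin::EraseIfPredict<K, M> g_export_pred = export_if_pred<K, M>;

void export_hot_entries(const Table& table, size_t n, M threshold,
                        cudaStream_t stream) {
  // n should be at least the table capacity to scan the whole table in one call.
  K* d_keys;
  V* d_vals;
  M* d_metas;
  size_t* d_counter;
  cudaMalloc(&d_keys, n * sizeof(K));
  cudaMalloc(&d_vals, n * DIM * sizeof(V));
  cudaMalloc(&d_metas, n * sizeof(M));
  cudaMalloc(&d_counter, sizeof(size_t));
  cudaMemsetAsync(d_counter, 0, sizeof(size_t), stream);

  table.export_batch_if(g_export_pred, /*pattern=*/0, threshold, n,
                        /*offset=*/0, d_counter, d_keys, d_vals, d_metas,
                        stream);

  size_t h_counter = 0;
  cudaMemcpyAsync(&h_counter, d_counter, sizeof(size_t),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  // The first h_counter entries of d_keys / d_vals / d_metas are now valid.

  cudaFree(d_keys);
  cudaFree(d_vals);
  cudaFree(d_metas);
  cudaFree(d_counter);
}
```

The `pattern` argument is unused by this particular predicate but must still be supplied, since the predicate type always takes a key pattern and a meta threshold.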
