
Commit

fix
oppenheimli committed Aug 20, 2024
1 parent f4234b4 commit 7aaabe7
Showing 4 changed files with 257 additions and 49 deletions.
140 changes: 131 additions & 9 deletions include/merlin/core_kernels.cuh
@@ -927,38 +927,160 @@ __global__ void dump_kernel_v2(const Table<K, V, S>* __restrict table,

for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
size_t bkt_idx = (ii + offset) / bucket_max_size;
- int key_idx = (ii + offset) % bucket_max_size;
- int leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
+ size_t key_idx = (ii + offset) % bucket_max_size;
+ size_t leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);

const K key =
(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);

bool match =
(!IS_RESERVED_KEY<K>(key)) && pred(key, score, pattern, threshold);
unsigned int vote = g.ballot(match);
int tile_cnt = __popc(vote);
- int tile_offset = 0;
+ size_t tile_offset = 0;
if (g.thread_rank() == 0) {
- tile_offset = static_cast<int>(
-     atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt)));
+ tile_offset = atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt));
}
tile_offset = g.shfl(tile_offset, 0);
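// bias_g is this thread's rank among the tile's matches: tile_cnt minus the
// number of matched lanes at or above this lane in the ballot mask.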
int bias_g = tile_cnt - __popc(vote >> (key_idx % TILE_SIZE));

if (match) {
d_key[tile_offset + bias_g] = key;
if (d_score) {
d_score[tile_offset + bias_g] = score;
}
}

#pragma unroll
for (int r = 0; r < TILE_SIZE; r++) {
unsigned int biased_vote = vote >> r;
bool cur_match = biased_vote & 1;
if (cur_match) {
int bias = tile_cnt - __popc(biased_vote);
size_t cur_idx = leading_key_idx + r;

for (int j = g.thread_rank(); j < dim; j += TILE_SIZE) {
d_val[(tile_offset + bias) * dim + j] =
bucket->vectors[cur_idx * dim + j];
}
}
}
}
}

template <class K, class V, class S,
template <typename, typename> class PredFunctor>
__global__ void size_if_kernel(const Table<K, V, S>* __restrict table,
Bucket<K, V, S>* buckets, const K pattern,
const S threshold, size_t* d_counter) {
extern __shared__ unsigned char s[];
KVM<K, V, S>* const block_tuples{reinterpret_cast<KVM<K, V, S>*>(s)};

const size_t bucket_max_size{table->bucket_max_size};

size_t local_acc = 0;
__shared__ size_t block_acc;
PredFunctor<K, S> pred;

const size_t tid{blockIdx.x * blockDim.x + threadIdx.x};

if (threadIdx.x == 0) {
block_acc = 0;
}
__syncthreads();

for (size_t i = tid; i < table->capacity; i += blockDim.x * gridDim.x) {
Bucket<K, V, S>* const bucket{&buckets[i / bucket_max_size]};

const int key_idx{static_cast<int>(i % bucket_max_size)};
const K key{(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed)};
S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);

if ((!IS_RESERVED_KEY(key)) && pred(key, score, pattern, threshold)) {
++local_acc;
}
}
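// Combine the per-thread counts in shared memory first, so only one
// atomicAdd per block touches the global counter.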
atomicAdd(&block_acc, local_acc);
__syncthreads();

if (threadIdx.x == 0) {
atomicAdd(d_counter, block_acc);
}
}

template <class K, class V, class S,
template <typename, typename> class PredFunctor, int TILE_SIZE>
__global__ void dump_kernel_v3(const Table<K, V, S>* __restrict table,
Bucket<K, V, S>* buckets, const K pattern,
const S threshold, K* d_key, V* __restrict d_val,
S* __restrict d_score, const size_t offset,
const size_t search_length,
size_t* d_dump_counter) {
const size_t bucket_max_size = table->bucket_max_size;
int dim = table->dim;
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());

PredFunctor<K, S> pred;

__shared__ int block_cnt;
__shared__ size_t block_offset;

size_t tid = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);

for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
size_t bkt_idx = (ii + offset) / bucket_max_size;
size_t key_idx = (ii + offset) % bucket_max_size;
size_t leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);

const K key =
(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);

if (threadIdx.x == 0) {
block_cnt = 0;
}
__syncthreads();

bool match =
(!IS_RESERVED_KEY<K>(key)) && pred(key, score, pattern, threshold);
unsigned int vote = g.ballot(match);
int tile_cnt = __popc(vote);
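// Two-level offset reservation follows: each tile reserves tile_cnt slots in
// the per-block counter, then one thread per block reserves the block's whole
// range in the global dump counter.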

int in_block_tile_offset = 0;
if (g.thread_rank() == 0) {
in_block_tile_offset =
atomicAdd(reinterpret_cast<int*>(&block_cnt), tile_cnt);
}
in_block_tile_offset = g.shfl(in_block_tile_offset, 0);
__syncthreads();

if (threadIdx.x == 0) {
block_offset = atomicAdd(d_dump_counter, static_cast<size_t>(block_cnt));
}
__syncthreads();

int tile_offset = block_offset + in_block_tile_offset;
int bias_g = tile_cnt - __popc(vote >> (key_idx % TILE_SIZE));

if (match) {
- d_key[tile_offset + key_idx] = key;
+ d_key[tile_offset + bias_g] = key;
if (d_score) {
- d_score[tile_offset + key_idx] = score;
+ d_score[tile_offset + bias_g] = score;
}
}

#pragma unroll
for (int r = 0; r < TILE_SIZE; r++) {
- bool cur_match = vote >> r & 1;
+ unsigned int biased_vote = vote >> r;
+ bool cur_match = biased_vote & 1;
if (cur_match) {
int bias = tile_cnt - __popc(biased_vote);
int cur_idx = leading_key_idx + r;
for (int j = g.thread_rank(); j < dim; j += TILE_SIZE) {
- d_val[(tile_offset + cur_idx) * dim + j] =
+ d_val[(tile_offset + bias) * dim + j] =
bucket->vectors[cur_idx * dim + j];
}
}
32 changes: 28 additions & 4 deletions include/merlin_hashtable.cuh
@@ -918,6 +918,7 @@ class HashTable : public HashTableBase<K, V, S> {
CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, options_.device_id));
shared_mem_size_ = deviceProp.sharedMemPerBlock;
sm_cnt_ = deviceProp.multiProcessorCount;
max_threads_per_block_ = deviceProp.maxThreadsPerBlock;
create_table<key_type, value_type, score_type>(
&table_, allocator_, options_.dim, options_.init_capacity,
options_.max_capacity, options_.max_hbm_for_vectors,
@@ -2621,10 +2622,10 @@ class HashTable : public HashTableBase<K, V, S> {
offset % TILE_SIZE == 0 && n % TILE_SIZE == 0;

if (match_fast_cond) {
- int grid_size = std::min(sm_cnt_, static_cast<int>(SAFE_GET_GRID_SIZE(
-     n, options_.block_size)));
- const int TILE_SIZE = 8;

+ int grid_size = std::min(
+     sm_cnt_ * max_threads_per_block_ / options_.block_size,
+     static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));
+ const int TILE_SIZE = 32;
dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values,
@@ -2687,6 +2688,28 @@ class HashTable : public HashTableBase<K, V, S> {
return h_size;
}

/**
 * @brief Counts the number of keys that match PredFunctor.
 *
 * @param pattern The key pattern passed to PredFunctor.
 * @param threshold The score threshold passed to PredFunctor.
 * @param d_counter Device pointer that receives the number of matching keys.
 * @param stream The CUDA stream that is used to execute the operation.
 */
template <template <typename, typename> class PredFunctor>
void size_if(const key_type& pattern, const score_type& threshold,
size_type* d_counter, cudaStream_t stream = 0) const {
read_shared_lock lock(mutex_, stream);
CUDA_CHECK(cudaMemsetAsync(d_counter, 0, sizeof(size_type), stream));

size_t grid_size = SAFE_GET_GRID_SIZE(capacity(), options_.block_size);
grid_size = std::min(grid_size,
static_cast<size_t>(sm_cnt_ * max_threads_per_block_ /
options_.block_size));
size_if_kernel<key_type, value_type, score_type, PredFunctor>
<<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, d_counter);
CudaCheckError();
}
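A minimal usage sketch for the size_if API above (not part of the commit): the predicate ScoreAtLeast, the table instance, and the names d_match_count and min_score are illustrative assumptions, and the exact operator() signature may differ from the library's own predicate functors; only the size_if signature is taken from the diff.

// Hypothetical predicate: keep entries whose score is at least `threshold`;
// `pattern` is unused here but kept to satisfy the PredFunctor interface.
template <class K, class S>
struct ScoreAtLeast {
  __forceinline__ __device__ bool operator()(const K& key, const S& score,
                                             const K& pattern,
                                             const S& threshold) const {
    return score >= threshold;
  }
};

// Host side (illustrative): count matching keys without exporting them.
size_t* d_match_count = nullptr;
CUDA_CHECK(cudaMalloc(&d_match_count, sizeof(size_t)));
table.size_if<ScoreAtLeast>(/*pattern=*/0, /*threshold=*/min_score,
                            d_match_count, stream);
size_t h_match_count = 0;
CUDA_CHECK(cudaMemcpyAsync(&h_match_count, d_match_count, sizeof(size_t),
                           cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaFree(d_match_count));

Because size_if_kernel only counts, no value vectors are copied; the counter is zeroed inside size_if itself via cudaMemsetAsync before the kernel launch.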

/**
* @brief Returns the hash table capacity.
*
@@ -3057,6 +3080,7 @@ class HashTable : public HashTableBase<K, V, S> {
TableCore* d_table_ = nullptr;
size_t shared_mem_size_ = 0;
int sm_cnt_ = 0;
int max_threads_per_block_ = 0;
std::atomic_bool reach_max_capacity_{false};
bool initialized_ = false;
mutable group_shared_mutex mutex_;
