feat(save/load): Save and load with meta files.
Fix the following bugs:
  1. fix(save/load): use of host memory after it has been freed.
  2. fix(test): missing thread offset in test case.
  3. fix: counter not freed in export_batch.
  4. fix: acquire lock when calling save/load.

feat: Allow insert_or_assign and accum_or_assign to ignore the table's evict strategy constraint.
Lifann authored and rhdong committed Sep 27, 2022
1 parent ea5e7a6 commit d6354d6
Showing 5 changed files with 160 additions and 108 deletions.
42 changes: 37 additions & 5 deletions include/merlin/types.cuh
@@ -16,6 +16,7 @@

#pragma once

#include <stddef.h>
#include <cuda/std/semaphore>

namespace nv {
@@ -83,10 +84,12 @@ using EraseIfPredictInternal =
);

/**
* The abstract class of KV file.
* An abstract class that provides the interface between nv::merlin::HashTable
* and a file, enabling the table to save to or load from the file by
* overriding the `read` and `write` methods.
*
* @tparam K The data type of the key.
* @tparam V The data type of the vector's item type.
* @tparam V The data type of the vector's elements.
* The element data type should be a basic data type of C++/CUDA.
* @tparam M The data type for `meta`.
* The currently supported data type is only `uint64_t`.
@@ -97,9 +100,38 @@ template <class K, class V, class M, size_t D>
class BaseKVFile {
public:
virtual ~BaseKVFile() {}
virtual ssize_t Read(size_t n, K* keys, V* vectors, M* metas) = 0;
virtual ssize_t Write(size_t n, const K* keys, const V* vectors,
const M* metas) = 0;

/**
* Read from the file and fill the keys, vectors, and metas buffers.
* When called from the table's save/load methods, the implementation
* can assume that the keys, vectors, and metas buffers are already
* pre-allocated.
*
* @param n The number of KV pairs expected to be read. `size_t` is used
* here to adapt to various filesystems and formats.
* @param keys Pointer to the pre-allocated buffer for keys.
* @param vectors Pointer to the pre-allocated buffer for vectors.
* @param metas Pointer to the pre-allocated buffer for metas.
*
* @return The number of KV pairs successfully read.
*/
virtual size_t read(size_t n, K* keys, V* vectors, M* metas) = 0;

/**
* Write keys, vectors, and metas from the table to the file. This
* abstract method receives a batch of KV pairs from the table and
* writes them to the file.
*
* @param n The number of KV pairs to be written. `size_t` is used
* here to adapt to various filesystems and formats.
* @param keys The keys to be written to the file.
* @param vectors The vectors of values to be written to the file.
* @param metas The metas to be written to the file.
*
* @return The number of KV pairs successfully written.
*/
virtual size_t write(size_t n, const K* keys, const V* vectors,
const M* metas) = 0;
};

} // namespace merlin
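Not part of the commit: the sketch below shows one way the BaseKVFile interface above could be implemented against plain binary files on the host. The class name LocalKVFile, the three-file layout, the include path, and the open()/close() helpers are illustrative assumptions; only the read/write signatures come from the header shown above.

// Illustrative sketch only: a BaseKVFile backed by three flat binary files.
// The class name, file layout, and open()/close() helpers are assumptions.
#include <cstdio>
#include <string>

#include "merlin/types.cuh"

template <class K, class V, class M, size_t D>
class LocalKVFile : public nv::merlin::BaseKVFile<K, V, M, D> {
 public:
  // Open "<prefix>.keys", "<prefix>.vectors", and "<prefix>.metas".
  bool open(const std::string& prefix, const char* mode) {
    keys_ = std::fopen((prefix + ".keys").c_str(), mode);
    vectors_ = std::fopen((prefix + ".vectors").c_str(), mode);
    metas_ = std::fopen((prefix + ".metas").c_str(), mode);
    return keys_ != nullptr && vectors_ != nullptr && metas_ != nullptr;
  }

  void close() {
    if (keys_) std::fclose(keys_);
    if (vectors_) std::fclose(vectors_);
    if (metas_) std::fclose(metas_);
    keys_ = vectors_ = metas_ = nullptr;
  }

  // Fill the pre-allocated buffers and report how many pairs were read.
  size_t read(size_t n, K* keys, V* vectors, M* metas) override {
    size_t nread = std::fread(keys, sizeof(K), n, keys_);
    std::fread(vectors, sizeof(V) * D, nread, vectors_);
    std::fread(metas, sizeof(M), nread, metas_);
    return nread;
  }

  // Write one batch and report how many pairs were written.
  size_t write(size_t n, const K* keys, const V* vectors,
               const M* metas) override {
    size_t nwritten = std::fwrite(keys, sizeof(K), n, keys_);
    std::fwrite(vectors, sizeof(V) * D, nwritten, vectors_);
    std::fwrite(metas, sizeof(M), nwritten, metas_);
    return nwritten;
  }

 private:
  std::FILE* keys_ = nullptr;
  std::FILE* vectors_ = nullptr;
  std::FILE* metas_ = nullptr;
};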
65 changes: 42 additions & 23 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
@@ -210,12 +210,18 @@ class HashTable {
*
* @param stream The CUDA stream that is used to execute the operation.
*
* @param ignore_evict_strategy A boolean option indicating whether
* insert_or_assign should ignore the table's evict strategy for the
* given metas. If true, the metas are not checked against the evict
* strategy. If false, the metas must conform to the table's evict
* strategy.
*/
void insert_or_assign(size_type n,
const key_type* keys, // (n)
const value_type* values, // (n, DIM)
const meta_type* metas = nullptr, // (n)
cudaStream_t stream = 0) {
cudaStream_t stream = 0,
bool ignore_evict_strategy = false) {
if (n == 0) {
return;
}
@@ -224,7 +230,9 @@ class HashTable {
reserve(capacity() * 2);
}

check_evict_strategy(metas);
if (!ignore_evict_strategy) {
check_evict_strategy(metas);
}

if (is_fast_mode()) {
const size_t block_size = 128;
@@ -351,13 +359,20 @@
*
* @param stream The CUDA stream that is used to execute the operation.
*
* @param ignore_evict_strategy A boolean option indicating whether
* accum_or_assign should ignore the table's evict strategy for the
* given metas. If true, the metas are not checked against the evict
* strategy. If false, the metas must conform to the table's evict
* strategy.
*
*/
void accum_or_assign(size_type n,
const key_type* keys, // (n)
const value_type* value_or_deltas, // (n, DIM)
const bool* accum_or_assigns, // (n)
const meta_type* metas = nullptr, // (n)
cudaStream_t stream = 0) {
cudaStream_t stream = 0,
bool ignore_evict_strategy = false) {
if (n == 0) {
return;
}
@@ -366,7 +381,9 @@
reserve(capacity() * 2);
}

check_evict_strategy(metas);
if (!ignore_evict_strategy) {
check_evict_strategy(metas);
}

vector_type** dst;
int* src_offset;
@@ -723,6 +740,7 @@ class HashTable {
export_batch(n, offset, d_counter, keys, values, metas, stream);
CUDA_CHECK(cudaMemcpyAsync(&h_counter, d_counter, sizeof(size_type),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaFreeAsync(d_counter, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
return h_counter;
}
@@ -830,17 +848,18 @@
}

/**
* @brief Save table to an abstract file.
* @brief Save the keys, vectors, and metas in the table to a file or files.
*
* @param file An BaseKVFile object defined the file format within filesystem.
* @param file A BaseKVFile object that defines the file format on the host filesystem.
* @param buffer_size The size of the buffer used for saving, in bytes.
* @param stream The CUDA stream used to execute the operation.
*
* @return Number of keys saved to file.
* @return Number of KV pairs saved to file.
*/
size_type save(BaseKVFile<K, V, M, DIM>* file,
size_type buffer_size = 1048576,
cudaStream_t stream = 0) const {
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
K* d_keys = nullptr;
V* d_vectors = nullptr;
M* d_metas = nullptr;
Expand Down Expand Up @@ -883,15 +902,15 @@ class HashTable {
CUDA_CHECK(cudaMemcpyAsync(h_metas, d_metas, sizeof(M) * nkeys,
cudaMemcpyDeviceToHost, stream));

size_type counter = nkeys;
size_type counter = 0;
CUDA_CHECK(cudaStreamSynchronize(stream));

for (size_type offset = batch_pairs_num; offset < total_size;
offset += batch_pairs_num) {
CUDA_CHECK(cudaMemsetAsync(d_next_nkeys, 0, sizeof(size_type), stream));
export_batch(batch_pairs_num, offset, d_next_nkeys, d_keys, d_vectors,
d_metas, stream);
file->Write(nkeys, h_keys, h_vectors, h_metas);
counter += file->write(nkeys, h_keys, h_vectors, h_metas);
CUDA_CHECK(cudaMemcpyAsync(&nkeys, d_next_nkeys, sizeof(size_type),
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -901,7 +920,6 @@
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaMemcpyAsync(h_metas, d_metas, sizeof(M) * nkeys,
cudaMemcpyDeviceToHost, stream));
counter += nkeys;
CUDA_CHECK(cudaStreamSynchronize(stream));
}

@@ -913,26 +931,27 @@
CUDA_CHECK(cudaMemcpyAsync(h_metas, d_metas, sizeof(M) * nkeys,
cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
file->Write(nkeys, h_keys, h_vectors, h_metas);
counter += file->write(nkeys, h_keys, h_vectors, h_metas);
}

CUDA_FREE_POINTERS(stream, d_keys, d_vectors, d_metas, d_next_nkeys, h_keys,
h_vectors, h_metas);
CUDA_FREE_POINTERS(stream, d_keys, d_vectors, d_metas, d_next_nkeys);
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_FREE_POINTERS(stream, h_keys, h_vectors, h_metas);
return counter;
}

/**
* @brief Load file and restore table.
* @brief Load keys, vectors, and metas from a file into the table.
*
* @param file An BaseKVFile object defined the file format within filesystem.
* @param file A BaseKVFile object that defines the file format on the host filesystem.
* @param buffer_size The size of the buffer used for loading, in bytes.
* @param stream The CUDA stream used to execute the operation.
*
* @return Number of keys loaded from file.
*/
size_type load(BaseKVFile<K, V, M, DIM>* file,
size_type buffer_size = 1048576, cudaStream_t stream = 0) {
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
K* d_keys = nullptr;
V* d_vectors = nullptr;
M* d_metas = nullptr;
@@ -945,8 +964,8 @@
CUDA_CHECK(cudaMallocHost(&h_keys, sizeof(K) * batch_pairs_num));
CUDA_CHECK(cudaMallocHost(&h_vectors, sizeof(V) * batch_pairs_num * DIM));
CUDA_CHECK(cudaMallocHost(&h_metas, sizeof(M) * batch_pairs_num));
size_type nkeys = file->Read(batch_pairs_num, h_keys, h_vectors, h_metas);
size_type counts = nkeys;
size_type nkeys = file->read(batch_pairs_num, h_keys, h_vectors, h_metas);
size_type counter = nkeys;
if (nkeys == 0) {
CUDA_FREE_POINTERS(stream, h_keys, h_vectors, h_metas);
CUDA_CHECK(cudaStreamSynchronize(stream));
@@ -964,16 +983,16 @@
cudaMemcpyHostToDevice, stream));
CUDA_CHECK(cudaMemcpyAsync(d_metas, h_metas, sizeof(M) * nkeys,
cudaMemcpyHostToDevice, stream));
insert_or_assign(nkeys, d_keys, d_vectors, d_metas, stream);
nkeys = file->Read(batch_pairs_num, h_keys, h_vectors, h_metas);
counts += nkeys;
insert_or_assign(nkeys, d_keys, d_vectors, d_metas, stream, true);
nkeys = file->read(batch_pairs_num, h_keys, h_vectors, h_metas);
counter += nkeys;
CUDA_CHECK(cudaStreamSynchronize(stream));
} while (nkeys > 0);

CUDA_FREE_POINTERS(stream, d_keys, d_vectors, d_metas, h_keys, h_vectors,
h_metas);
CUDA_FREE_POINTERS(stream, d_keys, d_vectors, d_metas);
CUDA_CHECK(cudaStreamSynchronize(stream));
return counts;
CUDA_FREE_POINTERS(stream, h_keys, h_vectors, h_metas);
return counter;
}

private:
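Also not part of the commit: a hedged sketch of how the new save/load paths might be exercised end to end with the LocalKVFile sketch above. The helper name and the generic `Table` template parameter are assumptions (the table is expected to be an already-initialized nv::merlin::HashTable with matching <K, V, M, DIM> parameters); the save/load signatures and the default 1 MB buffer size come from the diff, and both calls now take the table lock internally (per fix 4 above).

// Illustrative sketch only: round-trip a table through save() and load().
// `Table` is assumed to behave like nv::merlin::HashTable<K, V, M, DIM>,
// already created and initialized by the caller.
#include <cuda_runtime.h>
#include <string>

template <class Table, class K, class V, class M, size_t DIM>
size_t checkpoint_and_restore(Table& table, const std::string& prefix,
                              cudaStream_t stream) {
  LocalKVFile<K, V, M, DIM> file;

  // Dump the whole table; save() batches KV pairs through pinned host
  // buffers and returns the number of pairs handed to file.write().
  if (!file.open(prefix, "wb")) return 0;
  size_t saved = table.save(&file, /*buffer_size=*/1048576, stream);
  file.close();

  // Read the pairs back; load() feeds each batch to insert_or_assign with
  // ignore_evict_strategy = true, so restored metas are accepted as-is.
  if (!file.open(prefix, "rb")) return 0;
  size_t loaded = table.load(&file, /*buffer_size=*/1048576, stream);
  file.close();

  return loaded == saved ? loaded : 0;
}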