diskann support new data type (fp16/bf16) #393

Merged 1 commit on Feb 27, 2024
2 changes: 2 additions & 0 deletions include/knowhere/operands.h
@@ -139,6 +139,8 @@ template <typename InType, typename... Types>
using TypeMatch = std::bool_constant<(... | std::is_same_v<InType, Types>)>;
template <typename InType>
using KnowhereDataTypeCheck = TypeMatch<InType, bin1, fp16, fp32, bf16>;
template <typename InType>
using KnowhereFloatTypeCheck = TypeMatch<InType, fp16, fp32, bf16>;

template <typename T>
struct MockData {
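For context, the new trait composes with the existing `TypeMatch` fold expression: `KnowhereFloatTypeCheck<T>` is true exactly when `T` is one of the floating-point element types, which is what the relaxed DiskANN `static_assert` below relies on. A minimal standalone sketch of the mechanics (the `fp16`/`bf16` structs here are placeholders, not knowhere's real types):

```cpp
// Standalone sketch of the TypeMatch / KnowhereFloatTypeCheck pattern.
#include <cstdint>
#include <type_traits>

struct fp16 {};           // placeholder for knowhere::fp16
struct bf16 {};           // placeholder for knowhere::bf16
using fp32 = float;

template <typename InType, typename... Types>
using TypeMatch = std::bool_constant<(... | std::is_same_v<InType, Types>)>;

template <typename InType>
using KnowhereFloatTypeCheck = TypeMatch<InType, fp16, fp32, bf16>;

static_assert(KnowhereFloatTypeCheck<fp32>::value, "fp32 is accepted");
static_assert(KnowhereFloatTypeCheck<bf16>::value, "bf16 is accepted");
static_assert(!KnowhereFloatTypeCheck<std::int8_t>::value,
              "int8 is rejected at compile time");
```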
5 changes: 4 additions & 1 deletion src/index/diskann/diskann.cc
@@ -32,7 +32,8 @@
namespace knowhere {
template <typename DataType>
class DiskANNIndexNode : public IndexNode {
static_assert(std::is_same_v<DataType, fp32>, "DiskANN only support float");
static_assert(KnowhereFloatTypeCheck<DataType>::value,
"DiskANN only support floating point data type(float32, float16, bfloat16)");

public:
using DistType = float;
@@ -694,4 +695,6 @@ DiskANNIndexNode<DataType>::GetCachedNodeNum(const float cache_dram_budget, cons
}

KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp32);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp16);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, bf16);
} // namespace knowhere
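With the `static_assert` relaxed to `KnowhereFloatTypeCheck`, the index node can be registered once per floating-point element type. The sketch below is hypothetical and is not the body of `KNOWHERE_SIMPLE_REGISTER_GLOBAL`; it only illustrates the usual shape of such a macro, where each invocation installs one typed factory under an (index name, element type) key:

```cpp
// Hypothetical registration-macro sketch; all names here are illustrative.
#include <functional>
#include <map>
#include <memory>
#include <string>

struct IndexNode { virtual ~IndexNode() = default; };   // stand-in base class
template <typename DataType>
struct DiskANNIndexNode : IndexNode {};                  // stand-in index node

using Creator = std::function<std::unique_ptr<IndexNode>()>;

inline std::map<std::string, Creator>& Registry() {      // global name -> factory map
    static std::map<std::string, Creator> r;
    return r;
}

#define SIMPLE_REGISTER_GLOBAL(name, node, dtype)                  \
    static const bool registered_##name##_##dtype = [] {           \
        Registry()[#name "_" #dtype] = [] {                        \
            return std::make_unique<node<dtype>>();                \
        };                                                         \
        return true;                                               \
    }()

using fp32 = float;
struct fp16 {};
struct bf16 {};

SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp32);
SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp16);
SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, bf16);
```

Instantiating the node with an element type outside `KnowhereFloatTypeCheck` still fails at compile time, so only these three variants can exist.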
3 changes: 2 additions & 1 deletion thirdparty/DiskANN/include/diskann/distance.h
@@ -2,10 +2,11 @@
#include <cstdint>
#include "simd/hook.h"
#include "diskann/utils.h"
#include "knowhere/operands.h"
namespace diskann {

template<typename T>
using DISTFUN = T (*)(const T *, const T *, size_t);
using DISTFUN = std::function<float(const T *, const T *, size_t)>;

template<typename T>
DISTFUN<T> get_distance_function(Metric m);
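Switching the alias from a raw function pointer returning `T` to a `std::function` returning `float` does two things: the accumulated distance no longer has to be narrowed back into a half-precision `T`, and capturing callables (such as a wrapper that closes over per-vector norms, in the spirit of `dist_cmp_wrap` below) become storable. A sketch under those assumptions; `make_scaled_ip` is illustrative, not an existing function:

```cpp
#include <cstddef>
#include <functional>
#include <vector>

template <typename T>
using DISTFUN = std::function<float(const T*, const T*, std::size_t)>;

// A capturing lambda is representable as DISTFUN<T>, but could not be stored
// in the old raw-pointer alias T (*)(const T*, const T*, size_t).
template <typename T>
DISTFUN<T> make_scaled_ip(std::vector<float> norms, std::size_t id) {
    return [norms = std::move(norms), id](const T* x, const T* y,
                                          std::size_t d) -> float {
        float ip = 0;
        for (std::size_t i = 0; i < d; i++) {
            ip += (float) x[i] * (float) y[i];   // widen before multiply-accumulate
        }
        return ip / norms[id];
    };
}
```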
2 changes: 1 addition & 1 deletion thirdparty/DiskANN/include/diskann/index.h
@@ -351,7 +351,7 @@ namespace diskann {
size_t _num_frozen_pts = 0;
bool _has_built = false;
DISTFUN<T> _func = nullptr;
std::function<T(const T *, const T *, size_t)> _distance;
DISTFUN<T> _distance;
unsigned _width = 0;
unsigned _ep = 0;
size_t _max_range_of_loaded_graph = 0;
2 changes: 1 addition & 1 deletion thirdparty/DiskANN/include/diskann/pq_flash_index.h
@@ -233,7 +233,7 @@ namespace diskann {
DISTFUN<T> dist_cmp;
DISTFUN<float> dist_cmp_float;

T dist_cmp_wrap(const T *x, const T *y, size_t d, int32_t u) {
float dist_cmp_wrap(const T *x, const T *y, size_t d, int32_t u) {
if (metric == Metric::COSINE) {
return dist_cmp(x, y, d) / base_norms[u];
} else {
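`dist_cmp_wrap` now returns `float` regardless of `T`, and on the COSINE path it rescales the comparator output by the stored norm of base vector `u`; the COSINE branch in `distance.cpp` (further down) reuses the plain inner-product kernel, so this wrapper is where the base-side norm re-enters the score. A reduced, free-function sketch of that control flow (`dist_cmp_wrap_sketch` is illustrative only):

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

template <typename T>
using DISTFUN = std::function<float(const T*, const T*, std::size_t)>;

// The raw comparator value is divided by base_norms[u] only on the cosine path.
template <typename T>
float dist_cmp_wrap_sketch(const DISTFUN<T>& dist_cmp,
                           const std::vector<float>& base_norms, bool is_cosine,
                           const T* x, const T* y, std::size_t d, int32_t u) {
    float raw = dist_cmp(x, y, d);
    return is_cosine ? raw / base_norms[u] : raw;
}
```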
46 changes: 28 additions & 18 deletions thirdparty/DiskANN/include/diskann/utils.h
@@ -36,6 +36,7 @@ typedef int FileHandle;
#include "ann_exception.h"
#include "common_includes.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/operands.h"

// taken from
// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
@@ -61,7 +62,7 @@ typedef int FileHandle;
#define COMPLETION_PERCENT 10

inline bool file_exists(const std::string& name, bool dirCheck = false) {
int val;
int val;
struct stat buffer;
val = stat(name.c_str(), &buffer);

@@ -528,7 +529,8 @@ namespace diskann {

for (size_t i = 0; i < npts; i++) {
reader.read((char*) (data + i * rounded_dim), dim * sizeof(T));
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
memset((void*) (data + i * rounded_dim + dim), 0,
(rounded_dim - dim) * sizeof(T));
}
stream << " done." << std::endl;
LOG_KNOWHERE_DEBUG_ << stream.str();
@@ -574,6 +576,9 @@
template<typename T>
float prepare_base_for_inner_products(const std::string in_file,
const std::string out_file) {
static_assert(
knowhere::KnowhereFloatTypeCheck<T>::value,
"prepare_base_for_inner_products only support fp16, bf16, fp32.");
LOG_KNOWHERE_DEBUG_
<< "Pre-processing base file by adding extra coordinate";
std::ifstream in_reader(in_file.c_str(), std::ios::binary);
@@ -597,10 +602,11 @@
size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE;
std::unique_ptr<T[]> in_block_data =
std::make_unique<T[]>(block_size * in_dims);
std::unique_ptr<float[]> out_block_data =
std::make_unique<float[]>(block_size * out_dims);
std::unique_ptr<T[]> out_block_data =
std::make_unique<T[]>(block_size * out_dims);

std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims);
std::memset((void*) out_block_data.get(), 0,
sizeof(T) * block_size * out_dims);
_u64 num_blocks = DIV_ROUND_UP(npts, block_size);

std::vector<float> norms(npts, 0);
@@ -633,24 +639,25 @@
for (_u64 p = 0; p < block_pts; p++) {
for (_u64 j = 0; j < in_dims; j++) {
out_block_data[p * out_dims + j] =
in_block_data[p * in_dims + j] / max_norm;
(T) (((float) in_block_data[p * in_dims + j]) / max_norm);
}
float res = 1 - (norms[start_id + p] / (max_norm * max_norm));
res = res <= 0 ? 0 : std::sqrt(res);
out_block_data[p * out_dims + out_dims - 1] = res;
out_block_data[p * out_dims + out_dims - 1] = (T) res;
}
out_writer.write((char*) out_block_data.get(),
block_pts * out_dims * sizeof(float));
block_pts * out_dims * sizeof(T));
}
out_writer.close();
return max_norm;
}

template<typename T>
std::vector<float> prepare_base_for_cosine(const std::string in_file,
const std::string out_file) {
LOG_KNOWHERE_DEBUG_
<< "Pre-processing base file by normalizing";
const std::string out_file) {
static_assert(knowhere::KnowhereFloatTypeCheck<T>::value,
"prepare_base_for_cosine only support fp16, bf16, fp32.");
LOG_KNOWHERE_DEBUG_ << "Pre-processing base file by normalizing";
std::ifstream in_reader(in_file.c_str(), std::ios::binary);
std::ofstream out_writer(out_file.c_str(), std::ios::binary);
_u64 npts, in_dims, out_dims;
@@ -671,10 +678,11 @@
size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE;
std::unique_ptr<T[]> in_block_data =
std::make_unique<T[]>(block_size * in_dims);
std::unique_ptr<float[]> out_block_data =
std::make_unique<float[]>(block_size * out_dims);
std::unique_ptr<T[]> out_block_data =
std::make_unique<T[]>(block_size * out_dims);

std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims);
std::memset((void*) out_block_data.get(), 0,
sizeof(T) * block_size * out_dims);
_u64 num_blocks = DIV_ROUND_UP(npts, block_size);

std::vector<float> norms(npts, 0);
@@ -707,11 +715,12 @@
for (_u64 p = 0; p < block_pts; p++) {
for (_u64 j = 0; j < in_dims; j++) {
out_block_data[p * out_dims + j] =
in_block_data[p * in_dims + j] / norms[start_id + p];
(T) (((float) in_block_data[p * in_dims + j]) /
norms[start_id + p]);
}
}
out_writer.write((char*) out_block_data.get(),
block_pts * out_dims * sizeof(float));
block_pts * out_dims * sizeof(T));
}
out_writer.close();

@@ -796,7 +805,8 @@ namespace diskann {

for (size_t i = 0; i < npts; i++) {
reader.read((char*) (data + i * rounded_dim), dim * sizeof(T));
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
memset((void*) (data + i * rounded_dim + dim), 0,
(rounded_dim - dim) * sizeof(T));
}
}

@@ -828,7 +838,7 @@ namespace diskann {
float* read_buf, _u64 npts, _u64 ndims);

void normalize_data_file(const std::string& inFileName,
const std::string& outFileName);
const std::string& outFileName);

inline std::string get_pq_pivots_filename(const std::string& prefix) {
return prefix + "_pq_pivots.bin";
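Two patterns repeat through these preprocessing routines: output blocks are now allocated and written in the element type `T` (arithmetic stays in `float`, with explicit casts back to `T`), and the `memset` destinations are cast to `void*`, presumably to keep the calls warning-free once `T` can be a non-trivial fp16/bf16 class type. For the inner-product path, the appended coordinate `sqrt(1 - ||x||^2 / M^2)` (with `M` the maximum base norm) is the standard reduction from maximum-inner-product search to L2 search. A per-vector sketch under those assumptions (`ip_transform_one` is illustrative, not a function in the diff):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Arithmetic in float, storage narrowed back to T, plus one extra coordinate
// that turns inner-product search into an L2 search over in_dims + 1 values.
template <typename T>
std::vector<T> ip_transform_one(const T* x, std::size_t in_dims, float max_norm) {
    std::vector<T> out(in_dims + 1);
    float norm_sq = 0;
    for (std::size_t j = 0; j < in_dims; j++) {
        float v = (float) x[j];
        norm_sq += v * v;
        out[j] = (T) (v / max_norm);             // scale, then narrow to T
    }
    float res = 1 - norm_sq / (max_norm * max_norm);
    out[in_dims] = (T) (res <= 0 ? 0 : std::sqrt(res));
    return out;
}
```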
116 changes: 80 additions & 36 deletions thirdparty/DiskANN/src/aux_utils.cpp
@@ -227,7 +227,8 @@ namespace diskann {
diskann::alloc_aligned(((void **) &warmup),
warmup_num * warmup_aligned_dim * sizeof(T),
8 * sizeof(T));
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
std::memset((void *) warmup, 0,
warmup_num * warmup_aligned_dim * sizeof(T));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-128, 127);
@@ -760,7 +761,6 @@ namespace diskann {
}

save_bin<uint32_t>(cache_file, node_list.data(), num_nodes_to_cache, 1);

}

// General purpose support for DiskANN interface
@@ -1063,7 +1063,7 @@ namespace diskann {

template<typename T>
int build_disk_index(const BuildConfig &config) {
if (!std::is_same<T, float>::value &&
if (!knowhere::KnowhereFloatTypeCheck<T>::value &&
(config.compare_metric == diskann::Metric::INNER_PRODUCT ||
config.compare_metric == diskann::Metric::COSINE)) {
std::stringstream stream;
@@ -1301,6 +1301,12 @@ namespace diskann {
const std::string mem_index_file,
const std::string output_file,
const std::string reorder_data_file);
template void create_disk_layout<knowhere::fp16>(
const std::string base_file, const std::string mem_index_file,
const std::string output_file, const std::string reorder_data_file);
template void create_disk_layout<knowhere::bf16>(
const std::string base_file, const std::string mem_index_file,
const std::string output_file, const std::string reorder_data_file);

template int8_t *load_warmup<int8_t>(const std::string &cache_warmup_file,
uint64_t &warmup_num,
@@ -1313,51 +1319,77 @@ namespace diskann {
template float *load_warmup<float>(const std::string &cache_warmup_file,
uint64_t &warmup_num, uint64_t warmup_dim,
uint64_t warmup_aligned_dim);

template uint32_t optimize_beamwidth<int8_t>(
std::unique_ptr<diskann::PQFlashIndex<int8_t>> &pFlashIndex,
int8_t *tuning_sample, _u64 tuning_sample_num,
template knowhere::fp16 *load_warmup<knowhere::fp16>(
const std::string &cache_warmup_file, uint64_t &warmup_num,
uint64_t warmup_dim, uint64_t warmup_aligned_dim);
template knowhere::bf16 *load_warmup<knowhere::bf16>(
const std::string &cache_warmup_file, uint64_t &warmup_num,
uint64_t warmup_dim, uint64_t warmup_aligned_dim);

// knowhere not support uint8/int8 diskann
// template uint32_t optimize_beamwidth<int8_t>(
// std::unique_ptr<diskann::PQFlashIndex<int8_t>> &pFlashIndex,
// int8_t *tuning_sample, _u64 tuning_sample_num,
// _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
// uint32_t start_bw);
// template uint32_t optimize_beamwidth<uint8_t>(
// std::unique_ptr<diskann::PQFlashIndex<uint8_t>> &pFlashIndex,
// uint8_t *tuning_sample, _u64 tuning_sample_num,
// _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
// uint32_t start_bw);
template uint32_t optimize_beamwidth<float>(
std::unique_ptr<diskann::PQFlashIndex<float>> &pFlashIndex,
float *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template uint32_t optimize_beamwidth<uint8_t>(
std::unique_ptr<diskann::PQFlashIndex<uint8_t>> &pFlashIndex,
uint8_t *tuning_sample, _u64 tuning_sample_num,
template uint32_t optimize_beamwidth<knowhere::fp16>(
std::unique_ptr<diskann::PQFlashIndex<knowhere::fp16>> &pFlashIndex,
knowhere::fp16 *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template uint32_t optimize_beamwidth<float>(
std::unique_ptr<diskann::PQFlashIndex<float>> &pFlashIndex,
float *tuning_sample, _u64 tuning_sample_num,
template uint32_t optimize_beamwidth<knowhere::bf16>(
std::unique_ptr<diskann::PQFlashIndex<knowhere::bf16>> &pFlashIndex,
knowhere::bf16 *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);

template int build_disk_index<int8_t>(const BuildConfig &config);
template int build_disk_index<uint8_t>(const BuildConfig &config);
// not support build uint8/int8 diskindex in knowhere
// template int build_disk_index<int8_t>(const BuildConfig &config);
// template int build_disk_index<uint8_t>(const BuildConfig &config);
template int build_disk_index<float>(const BuildConfig &config);
template int build_disk_index<knowhere::fp16>(const BuildConfig &config);
template int build_disk_index<knowhere::bf16>(const BuildConfig &config);

template std::unique_ptr<diskann::Index<int8_t>>
build_merged_vamana_index<int8_t>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
build_merged_vamana_index<int8_t>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);
template std::unique_ptr<diskann::Index<float>>
build_merged_vamana_index<float>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
build_merged_vamana_index<float>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);
template std::unique_ptr<diskann::Index<uint8_t>>
build_merged_vamana_index<uint8_t>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
build_merged_vamana_index<uint8_t>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);
template std::unique_ptr<diskann::Index<knowhere::fp16>>
build_merged_vamana_index<knowhere::fp16>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);
template std::unique_ptr<diskann::Index<knowhere::bf16>>
build_merged_vamana_index<knowhere::bf16>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);

template void generate_cache_list_from_graph_with_pq<int8_t>(
_u64 num_nodes_to_cache, unsigned R, const diskann::Metric compare_metric,
@@ -1377,4 +1409,16 @@ namespace diskann {
const std::string &pq_compressed_code_path, const unsigned entry_point,
const std::vector<std::vector<unsigned>> &graph,
const std::string &cache_file);
template void generate_cache_list_from_graph_with_pq<knowhere::fp16>(
_u64 num_nodes_to_cache, unsigned R, const diskann::Metric compare_metric,
const std::string &sample_file, const std::string &pq_pivots_path,
const std::string &pq_compressed_code_path, const unsigned entry_point,
const std::vector<std::vector<unsigned>> &graph,
const std::string &cache_file);
template void generate_cache_list_from_graph_with_pq<knowhere::bf16>(
_u64 num_nodes_to_cache, unsigned R, const diskann::Metric compare_metric,
const std::string &sample_file, const std::string &pq_pivots_path,
const std::string &pq_compressed_code_path, const unsigned entry_point,
const std::vector<std::vector<unsigned>> &graph,
const std::string &cache_file);
}; // namespace diskann
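The new lines in this file are explicit instantiations: these templates are defined in the `.cpp`, so every element type knowhere exposes has to be named once or its symbols never get emitted and callers fail at link time. A generic illustration of that pattern (names are made up, not knowhere's):

```cpp
// sketch.h -- declaration visible to callers
template <typename T>
int build_disk_index_sketch(const char* config_path);

// sketch.cpp -- definition plus one explicit instantiation per supported type
template <typename T>
int build_disk_index_sketch(const char* /*config_path*/) {
    return 0;  // real work elided
}

template int build_disk_index_sketch<float>(const char*);
// Supporting a new element type means adding a matching line, e.g.:
// template int build_disk_index_sketch<my_fp16>(const char*);
```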
17 changes: 11 additions & 6 deletions thirdparty/DiskANN/src/distance.cpp
@@ -5,15 +5,16 @@ namespace diskann {
template<typename T>
DISTFUN<T> get_distance_function(diskann::Metric m) {
if (m == diskann::Metric::L2) {
return [](const T* x, const T* y, size_t size) -> T {
return [](const T* x, const T* y, size_t size) -> float {
float res = 0;
for (size_t i = 0; i < size; i++) {
res += ((float) x[i] - (float) y[i]) * ((float) x[i] - (float) y[i]);
}
return res;
};
} else if (m == diskann::Metric::INNER_PRODUCT) {
return [](const T* x, const T* y, size_t size) -> T {
} else if (m == diskann::Metric::INNER_PRODUCT ||
m == diskann::Metric::COSINE) {
return [](const T* x, const T* y, size_t size) -> float {
float res = 0;
for (size_t i = 0; i < size; i++) {
res += (float) x[i] * (float) y[i];
@@ -65,11 +66,15 @@
}
}

template DISTFUN<float> get_distance_function(diskann::Metric m);
template DISTFUN<uint8_t> get_distance_function(diskann::Metric m);
template DISTFUN<int8_t> get_distance_function(diskann::Metric m);
template DISTFUN<float> get_distance_function(diskann::Metric m);
template DISTFUN<uint8_t> get_distance_function(diskann::Metric m);
template DISTFUN<int8_t> get_distance_function(diskann::Metric m);
template DISTFUN<knowhere::fp16> get_distance_function(diskann::Metric m);
template DISTFUN<knowhere::bf16> get_distance_function(diskann::Metric m);

template float norm_l2sqr(const float*, size_t);
template float norm_l2sqr(const uint8_t*, size_t);
template float norm_l2sqr(const int8_t*, size_t);
template float norm_l2sqr(const knowhere::fp16*, size_t);
template float norm_l2sqr(const knowhere::bf16*, size_t);
} // namespace diskann
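The scalar fallback kernels now widen every operand to `float` before the multiply-accumulate and return `float`, so fp16/bf16 inputs neither overflow nor round inside the accumulator, and COSINE simply reuses the inner-product kernel (base-side normalization is handled by the preprocessing and `dist_cmp_wrap` above). A minimal sketch of the widened L2 kernel:

```cpp
#include <cstddef>

// Each element is converted to float before the arithmetic, so a half-precision
// element type never takes part in the accumulation itself.
template <typename T>
float l2_sqr_widened(const T* x, const T* y, std::size_t size) {
    float res = 0;
    for (std::size_t i = 0; i < size; i++) {
        float d = (float) x[i] - (float) y[i];   // widen first
        res += d * d;
    }
    return res;
}
```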