Skip to content

Commit

Permalink
diskann support new data type(fp16/bf16)
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Feb 21, 2024
1 parent b4149ed commit 4c951e1
Show file tree
Hide file tree
Showing 11 changed files with 263 additions and 82 deletions.
2 changes: 2 additions & 0 deletions include/knowhere/operands.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ template <typename InType, typename... Types>
using TypeMatch = std::bool_constant<(... | std::is_same_v<InType, Types>)>;
template <typename InType>
using KnowhereDataTypeCheck = TypeMatch<InType, bin1, fp16, fp32, bf16>;
template <typename InType>
using KnowhereFloatTypeCheck = TypeMatch<InType, fp16, fp32, bf16>;

template <typename T>
struct MockData {
Expand Down
5 changes: 4 additions & 1 deletion src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
namespace knowhere {
template <typename DataType>
class DiskANNIndexNode : public IndexNode {
static_assert(std::is_same_v<DataType, fp32>, "DiskANN only support float");
static_assert(KnowhereFloatTypeCheck<DataType>::value,
"DiskANN only support floating point data type(float32, float16, bfloat16)");

public:
using DistType = float;
Expand Down Expand Up @@ -697,4 +698,6 @@ DiskANNIndexNode<DataType>::GetCachedNodeNum(const float cache_dram_budget, cons
}

KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp32);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp16);
KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, bf16);
} // namespace knowhere
3 changes: 2 additions & 1 deletion thirdparty/DiskANN/include/diskann/distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
#include <cstdint>
#include "simd/hook.h"
#include "diskann/utils.h"
#include "knowhere/operands.h"
namespace diskann {

template<typename T>
using DISTFUN = T (*)(const T *, const T *, size_t);
using DISTFUN = std::function<float(const T *, const T *, size_t)>;

template<typename T>
DISTFUN<T> get_distance_function(Metric m);
Expand Down
2 changes: 1 addition & 1 deletion thirdparty/DiskANN/include/diskann/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ namespace diskann {
size_t _num_frozen_pts = 0;
bool _has_built = false;
DISTFUN<T> _func = nullptr;
std::function<T(const T *, const T *, size_t)> _distance;
DISTFUN<T> _distance;
unsigned _width = 0;
unsigned _ep = 0;
size_t _max_range_of_loaded_graph = 0;
Expand Down
2 changes: 1 addition & 1 deletion thirdparty/DiskANN/include/diskann/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ namespace diskann {
DISTFUN<T> dist_cmp;
DISTFUN<float> dist_cmp_float;

T dist_cmp_wrap(const T *x, const T *y, size_t d, int32_t u) {
float dist_cmp_wrap(const T *x, const T *y, size_t d, int32_t u) {
if (metric == Metric::COSINE) {
return dist_cmp(x, y, d) / base_norms[u];
} else {
Expand Down
55 changes: 37 additions & 18 deletions thirdparty/DiskANN/include/diskann/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ typedef int FileHandle;
#include "ann_exception.h"
#include "common_includes.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/operands.h"

// taken from
// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
Expand All @@ -61,7 +62,7 @@ typedef int FileHandle;
#define COMPLETION_PERCENT 10

inline bool file_exists(const std::string& name, bool dirCheck = false) {
int val;
int val;
struct stat buffer;
val = stat(name.c_str(), &buffer);

Expand Down Expand Up @@ -537,7 +538,8 @@ namespace diskann {

for (size_t i = 0; i < npts; i++) {
reader.read((char*) (data + i * rounded_dim), dim * sizeof(T));
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
memset((void*) (data + i * rounded_dim + dim), 0,
(rounded_dim - dim) * sizeof(T));
}
stream << " done." << std::endl;
LOG_KNOWHERE_DEBUG_ << stream.str();
Expand Down Expand Up @@ -583,6 +585,13 @@ namespace diskann {
template<typename T>
float prepare_base_for_inner_products(const std::string in_file,
const std::string out_file) {
if (!knowhere::KnowhereFloatTypeCheck<T>::value) {
std::stringstream stream;
stream << "DiskANN currently only supports floating point(float32, "
"float16, bfloat16) for IP."
<< std::endl;
throw diskann::ANNException(stream.str(), -1);
}
LOG_KNOWHERE_DEBUG_
<< "Pre-processing base file by adding extra coordinate";
std::ifstream in_reader(in_file.c_str(), std::ios::binary);
Expand All @@ -606,10 +615,11 @@ namespace diskann {
size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE;
std::unique_ptr<T[]> in_block_data =
std::make_unique<T[]>(block_size * in_dims);
std::unique_ptr<float[]> out_block_data =
std::make_unique<float[]>(block_size * out_dims);
std::unique_ptr<T[]> out_block_data =
std::make_unique<T[]>(block_size * out_dims);

std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims);
std::memset((void*) out_block_data.get(), 0,
sizeof(T) * block_size * out_dims);
_u64 num_blocks = DIV_ROUND_UP(npts, block_size);

std::vector<float> norms(npts, 0);
Expand Down Expand Up @@ -642,24 +652,30 @@ namespace diskann {
for (_u64 p = 0; p < block_pts; p++) {
for (_u64 j = 0; j < in_dims; j++) {
out_block_data[p * out_dims + j] =
in_block_data[p * in_dims + j] / max_norm;
(T) (((float) in_block_data[p * in_dims + j]) / max_norm);
}
float res = 1 - (norms[start_id + p] / (max_norm * max_norm));
res = res <= 0 ? 0 : std::sqrt(res);
out_block_data[p * out_dims + out_dims - 1] = res;
out_block_data[p * out_dims + out_dims - 1] = (T) res;
}
out_writer.write((char*) out_block_data.get(),
block_pts * out_dims * sizeof(float));
block_pts * out_dims * sizeof(T));
}
out_writer.close();
return max_norm;
}

template<typename T>
std::vector<float> prepare_base_for_cosine(const std::string in_file,
const std::string out_file) {
LOG_KNOWHERE_DEBUG_
<< "Pre-processing base file by normalizing";
const std::string out_file) {
if (!knowhere::KnowhereFloatTypeCheck<T>::value) {
std::stringstream stream;
stream << "DiskANN currently only supports floating point(float32, "
"float16, bfloat16) for Cosine."
<< std::endl;
throw diskann::ANNException(stream.str(), -1);
}
LOG_KNOWHERE_DEBUG_ << "Pre-processing base file by normalizing";
std::ifstream in_reader(in_file.c_str(), std::ios::binary);
std::ofstream out_writer(out_file.c_str(), std::ios::binary);
_u64 npts, in_dims, out_dims;
Expand All @@ -680,10 +696,11 @@ namespace diskann {
size_t block_size = npts <= BLOCK_SIZE ? npts : BLOCK_SIZE;
std::unique_ptr<T[]> in_block_data =
std::make_unique<T[]>(block_size * in_dims);
std::unique_ptr<float[]> out_block_data =
std::make_unique<float[]>(block_size * out_dims);
std::unique_ptr<T[]> out_block_data =
std::make_unique<T[]>(block_size * out_dims);

std::memset(out_block_data.get(), 0, sizeof(float) * block_size * out_dims);
std::memset((void*) out_block_data.get(), 0,
sizeof(T) * block_size * out_dims);
_u64 num_blocks = DIV_ROUND_UP(npts, block_size);

std::vector<float> norms(npts, 0);
Expand Down Expand Up @@ -716,11 +733,12 @@ namespace diskann {
for (_u64 p = 0; p < block_pts; p++) {
for (_u64 j = 0; j < in_dims; j++) {
out_block_data[p * out_dims + j] =
in_block_data[p * in_dims + j] / norms[start_id + p];
(T) (((float) in_block_data[p * in_dims + j]) /
norms[start_id + p]);
}
}
out_writer.write((char*) out_block_data.get(),
block_pts * out_dims * sizeof(float));
block_pts * out_dims * sizeof(T));
}
out_writer.close();

Expand Down Expand Up @@ -805,7 +823,8 @@ namespace diskann {

for (size_t i = 0; i < npts; i++) {
reader.read((char*) (data + i * rounded_dim), dim * sizeof(T));
memset(data + i * rounded_dim + dim, 0, (rounded_dim - dim) * sizeof(T));
memset((void*) (data + i * rounded_dim + dim), 0,
(rounded_dim - dim) * sizeof(T));
}
}

Expand Down Expand Up @@ -837,7 +856,7 @@ namespace diskann {
float* read_buf, _u64 npts, _u64 ndims);

void normalize_data_file(const std::string& inFileName,
const std::string& outFileName);
const std::string& outFileName);

inline std::string get_pq_pivots_filename(const std::string& prefix) {
return prefix + "_pq_pivots.bin";
Expand Down
89 changes: 65 additions & 24 deletions thirdparty/DiskANN/src/aux_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ namespace diskann {
diskann::alloc_aligned(((void **) &warmup),
warmup_num * warmup_aligned_dim * sizeof(T),
8 * sizeof(T));
std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
std::memset((void*) warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dis(-128, 127);
Expand Down Expand Up @@ -760,7 +760,6 @@ namespace diskann {
}

save_bin<uint32_t>(cache_file, node_list.data(), num_nodes_to_cache, 1);

}

// General purpose support for DiskANN interface
Expand Down Expand Up @@ -1064,7 +1063,7 @@ namespace diskann {

template<typename T>
int build_disk_index(const BuildConfig &config) {
if (!std::is_same<T, float>::value &&
if (!knowhere::KnowhereFloatTypeCheck<T>::value &&
(config.compare_metric == diskann::Metric::INNER_PRODUCT ||
config.compare_metric == diskann::Metric::COSINE)) {
std::stringstream stream;
Expand Down Expand Up @@ -1305,6 +1304,12 @@ namespace diskann {
const std::string mem_index_file,
const std::string output_file,
const std::string reorder_data_file);
template void create_disk_layout<knowhere::fp16>(
const std::string base_file, const std::string mem_index_file,
const std::string output_file, const std::string reorder_data_file);
template void create_disk_layout<knowhere::bf16>(
const std::string base_file, const std::string mem_index_file,
const std::string output_file, const std::string reorder_data_file);

template int8_t *load_warmup<int8_t>(const std::string &cache_warmup_file,
uint64_t &warmup_num,
Expand All @@ -1317,6 +1322,12 @@ namespace diskann {
template float *load_warmup<float>(const std::string &cache_warmup_file,
uint64_t &warmup_num, uint64_t warmup_dim,
uint64_t warmup_aligned_dim);
template knowhere::fp16 *load_warmup<knowhere::fp16>(
const std::string &cache_warmup_file, uint64_t &warmup_num,
uint64_t warmup_dim, uint64_t warmup_aligned_dim);
template knowhere::bf16 *load_warmup<knowhere::bf16>(
const std::string &cache_warmup_file, uint64_t &warmup_num,
uint64_t warmup_dim, uint64_t warmup_aligned_dim);

template uint32_t optimize_beamwidth<int8_t>(
std::unique_ptr<diskann::PQFlashIndex<int8_t>> &pFlashIndex,
Expand All @@ -1333,35 +1344,53 @@ namespace diskann {
float *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template uint32_t optimize_beamwidth<knowhere::fp16>(
std::unique_ptr<diskann::PQFlashIndex<knowhere::fp16>> &pFlashIndex,
knowhere::fp16 *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template uint32_t optimize_beamwidth<knowhere::bf16>(
std::unique_ptr<diskann::PQFlashIndex<knowhere::bf16>> &pFlashIndex,
knowhere::bf16 *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);

template int build_disk_index<int8_t>(const BuildConfig &config);
template int build_disk_index<uint8_t>(const BuildConfig &config);
template int build_disk_index<float>(const BuildConfig &config);
template int build_disk_index<knowhere::fp16>(const BuildConfig &config);
template int build_disk_index<knowhere::bf16>(const BuildConfig &config);

template std::unique_ptr<diskann::Index<int8_t>>
build_merged_vamana_index<int8_t>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
build_merged_vamana_index<int8_t>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);
template std::unique_ptr<diskann::Index<float>>
build_merged_vamana_index<float>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
build_merged_vamana_index<float>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);
template std::unique_ptr<diskann::Index<uint8_t>>
build_merged_vamana_index<uint8_t>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
build_merged_vamana_index<uint8_t>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);
template std::unique_ptr<diskann::Index<knowhere::fp16>>
build_merged_vamana_index<knowhere::fp16>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);
template std::unique_ptr<diskann::Index<knowhere::bf16>>
build_merged_vamana_index<knowhere::bf16>(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget, std::string mem_index_path,
std::string medoids_path, std::string centroids_file);

template void generate_cache_list_from_graph_with_pq<int8_t>(
_u64 num_nodes_to_cache, unsigned R, const diskann::Metric compare_metric,
Expand All @@ -1381,4 +1410,16 @@ namespace diskann {
const std::string &pq_compressed_code_path, const unsigned entry_point,
const std::vector<std::vector<unsigned>> &graph,
const std::string &cache_file);
template void generate_cache_list_from_graph_with_pq<knowhere::fp16>(
_u64 num_nodes_to_cache, unsigned R, const diskann::Metric compare_metric,
const std::string &sample_file, const std::string &pq_pivots_path,
const std::string &pq_compressed_code_path, const unsigned entry_point,
const std::vector<std::vector<unsigned>> &graph,
const std::string &cache_file);
template void generate_cache_list_from_graph_with_pq<knowhere::bf16>(
_u64 num_nodes_to_cache, unsigned R, const diskann::Metric compare_metric,
const std::string &sample_file, const std::string &pq_pivots_path,
const std::string &pq_compressed_code_path, const unsigned entry_point,
const std::vector<std::vector<unsigned>> &graph,
const std::string &cache_file);
}; // namespace diskann
17 changes: 11 additions & 6 deletions thirdparty/DiskANN/src/distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ namespace diskann {
template<typename T>
DISTFUN<T> get_distance_function(diskann::Metric m) {
if (m == diskann::Metric::L2) {
return [](const T* x, const T* y, size_t size) -> T {
return [](const T* x, const T* y, size_t size) -> float {
float res = 0;
for (size_t i = 0; i < size; i++) {
res += ((float) x[i] - (float) y[i]) * ((float) x[i] - (float) y[i]);
}
return res;
};
} else if (m == diskann::Metric::INNER_PRODUCT) {
return [](const T* x, const T* y, size_t size) -> T {
} else if (m == diskann::Metric::INNER_PRODUCT ||
m == diskann::Metric::COSINE) {
return [](const T* x, const T* y, size_t size) -> float {
float res = 0;
for (size_t i = 0; i < size; i++) {
res += (float) x[i] * (float) y[i];
Expand Down Expand Up @@ -65,11 +66,15 @@ namespace diskann {
}
}

template DISTFUN<float> get_distance_function(diskann::Metric m);
template DISTFUN<uint8_t> get_distance_function(diskann::Metric m);
template DISTFUN<int8_t> get_distance_function(diskann::Metric m);
template DISTFUN<float> get_distance_function(diskann::Metric m);
template DISTFUN<uint8_t> get_distance_function(diskann::Metric m);
template DISTFUN<int8_t> get_distance_function(diskann::Metric m);
template DISTFUN<knowhere::fp16> get_distance_function(diskann::Metric m);
template DISTFUN<knowhere::bf16> get_distance_function(diskann::Metric m);

template float norm_l2sqr(const float*, size_t);
template float norm_l2sqr(const uint8_t*, size_t);
template float norm_l2sqr(const int8_t*, size_t);
template float norm_l2sqr(const knowhere::fp16*, size_t);
template float norm_l2sqr(const knowhere::bf16*, size_t);
} // namespace diskann
Loading

0 comments on commit 4c951e1

Please sign in to comment.