# Review notes (text before the first `diff --git` header is ignored by patch tools)
#
# What this patch does, as far as the hunks show:
#   * Narrows the validated range of `ProductQuantizationOptions::ratio` from 1..=1024 to 1..=8.
#   * Adds `tcentroids` to `ProductQuantizer`, filled as `tcentroids[(i, j)] = centroids[(j, i)]`
#     with shape `(dims, 1 << bits)`, and passes it to `preprocess` / `fscan_preprocess`.
#   * Replaces the body of `preprocess` for VectDot and VectL2 with an `internal` helper that is
#     dispatched on (ratio, bits) through a 32-arm `match`, and that builds the lookup table with
#     raw-pointer reads/writes followed by `set_len`.
#
# Fix applied in this revision of the patch:
#   * In the "slow path" (`dims < 32`) of both `internal` helpers the inner loop bound was
#     `std::cmp::min(RATIO, dims - j * RATIO)`. `j` is the centroid index (0..1 << BITS); the
#     number of remaining elements depends on the sub-vector index `i`, exactly as the removed
#     code computed `min(ratio, dims - ratio * i)`. With `j`, the subtraction overflows `usize`
#     once `j * RATIO > dims` (panic in debug builds, wrap-around in release builds, after which
#     the last partial sub-vector is read out of bounds through raw pointers), and where it does
#     not overflow it sums too few elements. The bound is now `dims - i * RATIO`; because
#     `i < dims.div_ceil(RATIO)`, `i * RATIO < dims` and the subtraction cannot overflow.
#
# NOTE(review): generic argument lists look stripped in this copy of the patch (for example
#   `Vec::::with_capacity`, `fn internal(` being called as `internal::<1, 1, S>`, and identical
#   `-`/`+` lines in the `access` hunks). They are left exactly as found; restore them from the
#   original commit before applying.
# NOTE(review): line breaks and indentation were reconstructed from a whitespace-collapsed copy.
#   Hunk line counts match every `@@` header, but context whitespace is assumed to follow
#   rustfmt; apply with whitespace tolerance if a hunk fails to match.
#
diff --git a/crates/base/src/index.rs b/crates/base/src/index.rs
index 66219c0e..344d384c 100644
--- a/crates/base/src/index.rs
+++ b/crates/base/src/index.rs
@@ -476,7 +476,7 @@ impl Default for ScalarQuantizationOptions {
 #[validate(schema(function = "Self::validate_self"))]
 pub struct ProductQuantizationOptions {
     #[serde(default = "ProductQuantizationOptions::default_ratio")]
-    #[validate(range(min = 1, max = 1024))]
+    #[validate(range(min = 1, max = 8))]
     pub ratio: u32,
     #[serde(default = "ProductQuantizationOptions::default_bits")]
     pub bits: u32,
diff --git a/crates/quantization/src/product.rs b/crates/quantization/src/product.rs
index ad50de99..00f6b37c 100644
--- a/crates/quantization/src/product.rs
+++ b/crates/quantization/src/product.rs
@@ -34,6 +34,7 @@ pub struct ProductQuantizer {
     ratio: u32,
     bits: u32,
     centroids: Vec2,
+    tcentroids: Vec2,
 }
 
 impl Quantizer for ProductQuantizer {
@@ -76,11 +77,18 @@ impl Quantizer for ProductQuantizer {
                     .copy_from_slice(&points[i as usize][(j,)]);
             }
         }
+        let mut tcentroids = Vec2::zeros((dims as usize, 1 << bits));
+        for i in 0..dims as usize {
+            for j in 0_usize..(1 << bits) {
+                tcentroids[(i, j)] = centroids[(j, i)];
+            }
+        }
         Self {
             dims,
             ratio,
             bits,
             centroids,
+            tcentroids,
         }
     }
 
@@ -140,13 +148,7 @@ impl Quantizer for ProductQuantizer {
     type Lut = Vec;
 
     fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut {
-        O::preprocess(
-            self.dims,
-            self.ratio,
-            self.bits,
-            self.centroids.as_slice(),
-            vector,
-        )
+        O::preprocess(self.dims, self.ratio, self.bits, &self.tcentroids, vector)
     }
 
     fn process(&self, lut: &Self::Lut, code: &[u8], _: Borrowed<'_, O>) -> Distance {
@@ -161,13 +163,7 @@ impl Quantizer for ProductQuantizer {
     );
 
     fn fscan_preprocess(&self, vector: Borrowed<'_, O>) -> Self::FLut {
-        O::fscan_preprocess(
-            self.dims,
-            self.ratio,
-            self.bits,
-            self.centroids.as_slice(),
-            vector,
-        )
+        O::fscan_preprocess(self.dims, self.ratio, self.bits, &self.tcentroids, vector)
     }
 
     fn fscan_process(&self, flut: &Self::FLut, code: &[u8]) -> [Distance; 32] {
@@ -301,7 +297,7 @@ pub trait OperatorProductQuantization: Operator {
         dims: u32,
         ratio: u32,
         bits: u32,
-        centroids: &[Self::Scalar],
+        tcentroids: &Vec2,
         vector: Borrowed<'_, Self>,
     ) -> Vec;
     fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance;
@@ -309,7 +305,7 @@ pub trait OperatorProductQuantization: Operator {
         dims: u32,
         ratio: u32,
         bits: u32,
-        centroids: &[Self::Scalar],
+        tcentroids: &Vec2,
         vector: Borrowed<'_, Self>,
     ) -> (u32, f32, f32, Vec);
     fn fscan_process(flut: &(u32, f32, f32, Vec), code: &[u8]) -> [Distance; 32];
@@ -324,7 +320,7 @@ impl OperatorProductQuantization for VectDot {
         dims: u32,
         ratio: u32,
         bits: u32,
-        centroids: &Vec2,
+        centroids: &Vec2,
         vector: Borrowed<'_, Self>,
     ) -> Vec {
         let mut code = Vec::with_capacity(dims.div_ceil(ratio) as _);
@@ -350,23 +346,121 @@ impl OperatorProductQuantization for VectDot {
         dims: u32,
         ratio: u32,
         bits: u32,
-        centroids: &[Self::Scalar],
+        tcentroids: &Vec2,
         vector: Borrowed<'_, Self>,
     ) -> Vec {
-        let mut xy = Vec::with_capacity((dims.div_ceil(ratio) as usize) * (1 << bits));
-        for i in 0..dims.div_ceil(ratio) {
-            let subdims = std::cmp::min(ratio, dims - ratio * i);
-            xy.extend((0_usize..1 << bits).map(|k| {
-                let mut xy = 0.0f32;
-                for i in ratio * i..ratio * i + subdims {
-                    let x = vector.slice()[i as usize].to_f32();
-                    let y = centroids[(k as u32 * dims + i) as usize].to_f32();
-                    xy += x * y;
+        #[inline(never)]
+        fn internal(
+            dims: usize,
+            tcentroids: &[S],
+            vector: &[S],
+        ) -> Vec {
+            // code below needs special care, any minor changes would result in huge performance degradation
+            // For example:
+            // * calling `Vec::with_capacity` with parameter `dims.div_ceil(RATIO) * (1 << BITS)`
+            // * move `assert!(dims <= 65535)` after allocation
+            // * change pointer arithmetic to `get_unchecked` or `std::hint::assert_unchecked`
+            // * change parameters from slices to pointers
+            assert!(dims <= 65535);
+            assert!(tcentroids.len() == dims * (1 << BITS));
+            assert!(vector.len() == dims);
+            let mut table = Vec::::with_capacity((dims / RATIO) * (1 << BITS) + (1 << BITS));
+            if dims >= 32 {
+                // fast path
+                for i in 0..dims / RATIO {
+                    for j in 0..1 << BITS {
+                        let mut value = 0.0f32;
+                        for k in 0..RATIO {
+                            let idx_x = RATIO * i + k;
+                            let idx_y = (RATIO * i + k) * (1 << BITS) + j;
+                            let x = unsafe { vector.as_ptr().add(idx_x).read() };
+                            let y = unsafe { tcentroids.as_ptr().add(idx_y).read() };
+                            let xy = x.to_f32() * y.to_f32();
+                            value += xy;
+                        }
+                        unsafe {
+                            table.as_mut_ptr().add(i * (1 << BITS) + j).write(value);
+                        }
+                    }
                 }
-                xy
-            }));
+                if dims % RATIO != 0 {
+                    let i = dims / RATIO;
+                    for j in 0..1 << BITS {
+                        let mut value = 0.0f32;
+                        for k in 0..dims % RATIO {
+                            let idx_x = RATIO * i + k;
+                            let idx_y = (RATIO * i + k) * (1 << BITS) + j;
+                            let x = unsafe { vector.as_ptr().add(idx_x).read() };
+                            let y = unsafe { tcentroids.as_ptr().add(idx_y).read() };
+                            let xy = x.to_f32() * y.to_f32();
+                            value += xy;
+                        }
+                        unsafe {
+                            table.as_mut_ptr().add(i * (1 << BITS) + j).write(value);
+                        }
+                    }
+                }
+            } else {
+                // slow path
+                for i in 0..dims.div_ceil(RATIO) {
+                    for j in 0..1 << BITS {
+                        let mut value = 0.0f32;
+                        for k in 0..std::cmp::min(RATIO, dims - i * RATIO) {
+                            let idx_x = RATIO * i + k;
+                            let idx_y = (RATIO * i + k) * (1 << BITS) + j;
+                            let x = unsafe { vector.as_ptr().add(idx_x).read() };
+                            let y = unsafe { tcentroids.as_ptr().add(idx_y).read() };
+                            let xy = x.to_f32() * y.to_f32();
+                            value += xy;
+                        }
+                        unsafe {
+                            table.as_mut_ptr().add(i * (1 << BITS) + j).write(value);
+                        }
+                    }
+                }
+            }
+            unsafe {
+                table.set_len(dims.div_ceil(RATIO) * (1 << BITS));
+            }
+            table
+        }
+        assert!((1..=8).contains(&ratio) && (bits == 1 || bits == 2 || bits == 4 || bits == 8));
+        let no = (ratio - 1) * 4 + bits.ilog2();
+        match no {
+            0 => internal::<1, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            1 => internal::<1, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            2 => internal::<1, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            3 => internal::<1, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            4 => internal::<2, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            5 => internal::<2, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            6 => internal::<2, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            7 => internal::<2, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            8 => internal::<3, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            9 => internal::<3, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            10 => internal::<3, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            11 => internal::<3, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            12 => internal::<4, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            13 => internal::<4, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            14 => internal::<4, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            15 => internal::<4, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            16 => internal::<5, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            17 => internal::<5, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            18 => internal::<5, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            19 => internal::<5, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            20 => internal::<6, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            21 => internal::<6, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            22 => internal::<6, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            23 => internal::<6, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            24 => internal::<7, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            25 => internal::<7, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            26 => internal::<7, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            27 => internal::<7, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            28 => internal::<8, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            29 => internal::<8, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            30 => internal::<8, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            31 => internal::<8, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            32.. => unreachable!(),
         }
-        xy
     }
     fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance {
         fn internal(dims: u32, ratio: u32, t: &[f32], f: F) -> Distance
@@ -398,7 +492,7 @@ impl OperatorProductQuantization for VectDot {
         dims: u32,
         ratio: u32,
         bits: u32,
-        centroids: &[Self::Scalar],
+        centroids: &Vec2,
         vector: Borrowed<'_, Self>,
     ) -> (u32, f32, f32, Vec) {
         let (k, b, t) = quantize::<255>(&Self::preprocess(dims, ratio, bits, centroids, vector));
@@ -420,7 +514,7 @@ impl OperatorProductQuantization for VectL2 {
         dims: u32,
         ratio: u32,
         bits: u32,
-        centroids: &Vec2,
+        centroids: &Vec2,
         vector: Borrowed<'_, Self>,
     ) -> Vec {
         let mut code = Vec::with_capacity(dims.div_ceil(ratio) as _);
@@ -446,24 +540,121 @@ impl OperatorProductQuantization for VectL2 {
         dims: u32,
         ratio: u32,
         bits: u32,
-        centroids: &[Self::Scalar],
+        tcentroids: &Vec2,
         vector: Borrowed<'_, Self>,
     ) -> Vec {
-        let mut d2 = Vec::with_capacity((dims.div_ceil(ratio) as usize) * (1 << bits));
-        for i in 0..dims.div_ceil(ratio) {
-            let subdims = std::cmp::min(ratio, dims - ratio * i);
-            d2.extend((0_usize..1 << bits).map(|k| {
-                let mut d2 = 0.0f32;
-                for i in ratio * i..ratio * i + subdims {
-                    let x = vector.slice()[i as usize].to_f32();
-                    let y = centroids[(k as u32 * dims + i) as usize].to_f32();
-                    let d = x - y;
-                    d2 += d * d;
+        #[inline(never)]
+        fn internal(
+            dims: usize,
+            tcentroids: &[S],
+            vector: &[S],
+        ) -> Vec {
+            // code below needs special care, any minor changes would result in huge performance degradation
+            // For example:
+            // * calling `Vec::with_capacity` with parameter `dims.div_ceil(RATIO) * (1 << BITS)`
+            // * move `assert!(dims <= 65535)` after allocation
+            // * change pointer arithmetic to `get_unchecked` or `std::hint::assert_unchecked`
+            // * change parameters from slices to pointers
+            assert!(dims <= 65535);
+            assert!(tcentroids.len() == dims * (1 << BITS));
+            assert!(vector.len() == dims);
+            let mut table = Vec::::with_capacity((dims / RATIO) * (1 << BITS) + (1 << BITS));
+            if dims >= 32 {
+                // fast path
+                for i in 0..dims / RATIO {
+                    for j in 0..1 << BITS {
+                        let mut value = 0.0f32;
+                        for k in 0..RATIO {
+                            let idx_x = RATIO * i + k;
+                            let idx_y = (RATIO * i + k) * (1 << BITS) + j;
+                            let x = unsafe { vector.as_ptr().add(idx_x).read() };
+                            let y = unsafe { tcentroids.as_ptr().add(idx_y).read() };
+                            let d = x.to_f32() - y.to_f32();
+                            value += d * d;
+                        }
+                        unsafe {
+                            table.as_mut_ptr().add(i * (1 << BITS) + j).write(value);
+                        }
+                    }
+                }
+                if dims % RATIO != 0 {
+                    let i = dims / RATIO;
+                    for j in 0..1 << BITS {
+                        let mut value = 0.0f32;
+                        for k in 0..dims % RATIO {
+                            let idx_x = RATIO * i + k;
+                            let idx_y = (RATIO * i + k) * (1 << BITS) + j;
+                            let x = unsafe { vector.as_ptr().add(idx_x).read() };
+                            let y = unsafe { tcentroids.as_ptr().add(idx_y).read() };
+                            let d = x.to_f32() - y.to_f32();
+                            value += d * d;
+                        }
+                        unsafe {
+                            table.as_mut_ptr().add(i * (1 << BITS) + j).write(value);
+                        }
+                    }
+                }
+            } else {
+                // slow path
+                for i in 0..dims.div_ceil(RATIO) {
+                    for j in 0..1 << BITS {
+                        let mut value = 0.0f32;
+                        for k in 0..std::cmp::min(RATIO, dims - i * RATIO) {
+                            let idx_x = RATIO * i + k;
+                            let idx_y = (RATIO * i + k) * (1 << BITS) + j;
+                            let x = unsafe { vector.as_ptr().add(idx_x).read() };
+                            let y = unsafe { tcentroids.as_ptr().add(idx_y).read() };
+                            let d = x.to_f32() - y.to_f32();
+                            value += d * d;
+                        }
+                        unsafe {
+                            table.as_mut_ptr().add(i * (1 << BITS) + j).write(value);
+                        }
+                    }
                 }
-                d2
-            }));
+            }
+            unsafe {
+                table.set_len(dims.div_ceil(RATIO) * (1 << BITS));
+            }
+            table
+        }
+        assert!((1..=8).contains(&ratio) && (bits == 1 || bits == 2 || bits == 4 || bits == 8));
+        let no = (ratio - 1) * 4 + bits.ilog2();
+        match no {
+            0 => internal::<1, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            1 => internal::<1, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            2 => internal::<1, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            3 => internal::<1, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            4 => internal::<2, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            5 => internal::<2, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            6 => internal::<2, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            7 => internal::<2, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            8 => internal::<3, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            9 => internal::<3, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            10 => internal::<3, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            11 => internal::<3, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            12 => internal::<4, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            13 => internal::<4, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            14 => internal::<4, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            15 => internal::<4, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            16 => internal::<5, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            17 => internal::<5, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            18 => internal::<5, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            19 => internal::<5, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            20 => internal::<6, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            21 => internal::<6, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            22 => internal::<6, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            23 => internal::<6, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            24 => internal::<7, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            25 => internal::<7, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            26 => internal::<7, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            27 => internal::<7, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            28 => internal::<8, 1, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            29 => internal::<8, 2, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            30 => internal::<8, 4, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            31 => internal::<8, 8, S>(dims as _, tcentroids.as_slice(), vector.slice()),
+            32.. => unreachable!(),
         }
-        d2
     }
     fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance {
         fn internal(dims: u32, ratio: u32, t: &[f32], f: F) -> Distance
@@ -495,7 +686,7 @@ impl OperatorProductQuantization for VectL2 {
         dims: u32,
         ratio: u32,
         bits: u32,
-        centroids: &[Self::Scalar],
+        centroids: &Vec2,
         vector: Borrowed<'_, Self>,
     ) -> (u32, f32, f32, Vec) {
         let (k, b, t) = quantize::<255>(&Self::preprocess(dims, ratio, bits, centroids, vector));
@@ -529,7 +720,7 @@ macro_rules! unimpl_operator_product_quantization {
                 _: u32,
                 _: u32,
                 _: u32,
-                _: &[Self::Scalar],
+                _: &Vec2,
                 _: Borrowed<'_, Self>,
             ) -> Vec {
                 unimplemented!()
@@ -542,7 +733,7 @@ macro_rules! unimpl_operator_product_quantization {
                 _: u32,
                 _: u32,
                 _: u32,
-                _: &[Self::Scalar],
+                _: &Vec2,
                 _: Borrowed<'_, Self>,
             ) -> (u32, f32, f32, Vec) {
                 unimplemented!()