diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index d3209e872..6d7ff6d40 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -25,9 +25,20 @@ env: jobs: lint: + strategy: + matrix: + version: [15] runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - name: Prepare + run: | + sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' + wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + sudo apt-get update + sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }} + cargo install cargo-pgrx --git https://github.com/tensorchord/pgrx.git --rev $(cat Cargo.toml | grep "pgrx =" | awk -F'rev = "' '{print $2}' | cut -d'"' -f1) + cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config - name: Format check run: cargo fmt --check - name: Semantic check diff --git a/Cargo.toml b/Cargo.toml index 251bbc620..c5728ce45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vectors" -version = "0.0.0" +version = "0.1.1" edition = "2021" [lib] @@ -18,7 +18,7 @@ pg_test = [] [dependencies] pgrx = { git = "https://github.com/tensorchord/pgrx.git", rev = "c0d11a8b78b0d707a5e9106bc4d5f66395ca9a2e" } -openai_api_rust = "0.1.8" +openai_api_rust = { git = "https://github.com/tensorchord/openai-api.git", rev = "228d54b6002e98257b3c81501a054942342f585f" } static_assertions = "1.1.0" libc = "~0.2" serde = "1.0.163" @@ -33,15 +33,11 @@ dashmap = "5.4.0" parking_lot = "0.12.1" memoffset = "0.9.0" serde_json = "1" -tokio = { version = "1", features = ["full"] } thiserror = "1.0.40" -anyhow = { version = "1.0.71", features = ["backtrace"] } -async-channel = "1.8.0" tempfile = "3.6.0" cstr = "0.2.11" arrayvec = { version = "0.7.3", features = ["serde"] } memmap2 = "0.7.0" -tokio-stream = { version = "0.1.14", 
features = ["fs"] } validator = { version = "0.16.1", features = ["derive"] } toml = "0.7.6" diff --git a/README.md b/README.md index 338fb94a1..fbb37fe11 100644 --- a/README.md +++ b/README.md @@ -104,13 +104,18 @@ You can then populate the table with vector data as follows. INSERT INTO items (embedding) VALUES ('[1,2,3]'), ('[4,5,6]'); + +-- or insert values using a casting from array to vector + +INSERT INTO items (embedding) +VALUES (ARRAY[1, 2, 3]::real[]), (ARRAY[4, 5, 6]::real[]); ``` We support three operators to calculate the distance between two vectors. - `<->`: squared Euclidean distance, defined as $\Sigma (x_i - y_i) ^ 2$. - `<#>`: negative dot product distance, defined as $- \Sigma x_iy_i$. -- `<=>`: negative squared cosine distance, defined as $- \frac{(\Sigma x_iy_i)^2}{\Sigma x_i^2 \Sigma y_i^2}$. +- `<=>`: negative cosine distance, defined as $- \frac{\Sigma x_iy_i}{\sqrt{\Sigma x_i^2 \Sigma y_i^2}}$. ```sql -- call the distance function through operators @@ -142,12 +147,10 @@ You can create an index, using squared Euclidean distance with the following SQL CREATE INDEX ON items USING vectors (embedding l2_ops) WITH (options = $$ capacity = 2097152 -size_ram = 4294967296 -storage_vectors = "ram" +[vectors] +memmap = "ram" [algorithm.hnsw] -storage = "ram" -m = 32 -ef = 256 +memmap = "ram" $$); --- Or using IVFFlat algorithm. @@ -155,10 +158,10 @@ $$); CREATE INDEX ON items USING vectors (embedding l2_ops) WITH (options = $$ capacity = 2097152 -size_ram = 2147483648 -storage_vectors = "ram" +[vectors] +memmap = "ram" [algorithm.ivf] -storage = "ram" +memmap = "ram" nlist = 1000 nprobe = 10 $$); @@ -203,15 +206,14 @@ We utilize TOML syntax to express the index's configuration. Here's what each ke | Key | Type | Description | | ---------------------- | ------- | --------------------------------------------------------------------------------------------------------------------- | | capacity | integer | The index's capacity. 
The value should be greater than the number of rows in your table. | -| size_ram | integer | (Optional) The maximum amount of memory the persisent part of index can occupy. | -| size_disk | integer | (Optional) The maximum amount of disk-backed memory-mapped file size the persisent part of index can occupy. | -| storage_vectors | string | `ram` ensures that the vectors always stays in memory while `disk` suggests otherwise. | +| vectors | table | Configuration of background process vector storage. | +| vectors.memmap | string | (Optional) `ram` ensures that the vectors always stay in memory while `disk` suggests otherwise. | | algorithm.ivf | table | If this table is set, the IVF algorithm will be used for the index. | -| algorithm.ivf.storage | string | (Optional) `ram` ensures that the persisent part of algorithm always stays in memory while `disk` suggests otherwise. | -| algorithm.ivf.nlist | integer | (Optional) Number of cluster units. | -| algorithm.ivf.nprobe | integer | (Optional) Number of units to query. | +| algorithm.ivf.memmap | string | (Optional) `ram` ensures that the persistent part of algorithm always stays in memory while `disk` suggests otherwise. | +| algorithm.ivf.nlist | integer | Number of cluster units. | +| algorithm.ivf.nprobe | integer | Number of units to query. | | algorithm.hnsw | table | If this table is set, the HNSW algorithm will be used for the index. | -| algorithm.hnsw.storage | string | (Optional) `ram` ensures that the persisent part of algorithm always stays in memory while `disk` suggests otherwise. | +| algorithm.hnsw.memmap | string | (Optional) `ram` ensures that the persistent part of algorithm always stays in memory while `disk` suggests otherwise. | | algorithm.hnsw.m | integer | (Optional) Maximum degree of the node. | | algorithm.hnsw.ef | integer | (Optional) Search scope in building. 
| @@ -229,10 +231,10 @@ UPDATE documents SET embedding = ai_embedding_vector(content) WHERE length(embed CREATE INDEX ON documents USING vectors (embedding l2_ops) WITH (options = $$ capacity = 2097152 -size_ram = 4294967296 -storage_vectors = "ram" +[vectors] +memmap = "ram" [algorithm.hnsw] -storage = "ram" +memmap = "ram" m = 32 ef = 256 $$); diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 3bd252ccd..bd7fb5050 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,4 +1,4 @@ [toolchain] -channel = "nightly-2023-05-09" -components = ["rustfmt", "clippy"] +channel = "nightly-2023-08-03" +components = ["rustfmt", "clippy", "miri"] targets = ["x86_64-unknown-linux-gnu"] diff --git a/src/algorithms/flat.rs b/src/algorithms/flat.rs index dbe395d11..04b403832 100644 --- a/src/algorithms/flat.rs +++ b/src/algorithms/flat.rs @@ -1,49 +1,81 @@ -use crate::algorithms::Vectors; -use crate::memory::Address; +use super::utils::filtered_fixed_heap::FilteredFixedHeap; +use super::Algo; +use crate::bgworker::index::IndexOptions; +use crate::bgworker::storage::Storage; +use crate::bgworker::storage::StoragePreallocator; +use crate::bgworker::vectors::Vectors; use crate::prelude::*; -use crate::utils::fixed_heap::FixedHeap; use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; use std::sync::Arc; +use thiserror::Error; + +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum FlatError { + // +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FlatOptions {} -pub struct Flat { - distance: Distance, +pub struct Flat { vectors: Arc, + _maker: PhantomData, } -impl Algorithm for Flat { - type Options = FlatOptions; +impl Algo for Flat { + type Error = FlatError; - fn build(options: Options, vectors: Arc, _: usize) -> anyhow::Result { + type Save = (); + + fn prebuild(_: &mut StoragePreallocator, _: IndexOptions) -> Result<(), Self::Error> { + Ok(()) + } + + fn build( + _: &mut Storage, + _: IndexOptions, + vectors: Arc, + _: 
usize, + ) -> Result { Ok(Self { - distance: options.distance, vectors, + _maker: PhantomData, }) } - fn address(&self) -> Address { - Address::DANGLING - } + fn save(&self) {} - fn load(options: Options, vectors: Arc, _: Address) -> anyhow::Result { + fn load( + _: &mut Storage, + _: IndexOptions, + vectors: Arc, + _: (), + ) -> Result { Ok(Self { - distance: options.distance, vectors, + _maker: PhantomData, }) } - fn insert(&self, _: usize) -> anyhow::Result<()> { + fn insert(&self, _: usize) -> Result<(), FlatError> { Ok(()) } - fn search(&self, (vector, k): (Box<[Scalar]>, usize)) -> anyhow::Result> { - let mut result = FixedHeap::<(Scalar, u64)>::new(k); + fn search( + &self, + target: Box<[Scalar]>, + k: usize, + filter: F, + ) -> Result, FlatError> + where + F: FnMut(u64) -> bool, + { + let mut result = FilteredFixedHeap::new(k, filter); for i in 0..self.vectors.len() { let this_vector = self.vectors.get_vector(i); let this_data = self.vectors.get_data(i); - let dis = self.distance.distance(&vector, this_vector); + let dis = D::distance(&target, this_vector); result.push((dis, this_data)); } Ok(result.into_sorted_vec()) diff --git a/src/algorithms/flat_q.rs b/src/algorithms/flat_q.rs new file mode 100644 index 000000000..ac5227aca --- /dev/null +++ b/src/algorithms/flat_q.rs @@ -0,0 +1,126 @@ +use super::impls::quantization::*; +use super::utils::filtered_fixed_heap::FilteredFixedHeap; +use super::Algo; +use crate::bgworker::index::IndexOptions; +use crate::bgworker::storage::Storage; +use crate::bgworker::storage::StoragePreallocator; +use crate::bgworker::vectors::Vectors; +use crate::prelude::*; +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; +use std::sync::Arc; +use thiserror::Error; + +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum FlatQError { + #[error("Quantization {0}")] + Quantization(#[from] QuantizationError), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FlatQOptions { + pub memmap: 
Memmap, + pub sample_size: usize, +} + +pub struct FlatQ { + vectors: Arc, + implementation: QuantizationImpl, + _maker: PhantomData, +} + +impl Algo for FlatQ { + type Error = FlatQError; + + type Save = Q; + + fn prebuild( + storage: &mut StoragePreallocator, + options: IndexOptions, + ) -> Result<(), Self::Error> { + let flat_q_options = options.algorithm.clone().unwrap_flat_q(); + QuantizationImpl::::prebuild( + storage, + options.dims, + options.capacity, + flat_q_options.memmap, + )?; + Ok(()) + } + + fn build( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + n: usize, + ) -> Result { + let flat_q_options = options.algorithm.clone().unwrap_flat_q(); + let implementation = QuantizationImpl::new( + storage, + vectors.clone(), + options.dims, + n, + flat_q_options.sample_size, + options.capacity, + flat_q_options.memmap, + )?; + Ok(Self { + vectors, + implementation, + _maker: PhantomData, + }) + } + + fn save(&self) -> Q { + self.implementation.save() + } + + fn load( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + save: Q, + ) -> Result { + let flat_q_options = options.algorithm.clone().unwrap_flat_q(); + Ok(Self { + vectors: vectors.clone(), + implementation: QuantizationImpl::load( + storage, + vectors, + save, + options.capacity, + flat_q_options.memmap, + )?, + _maker: PhantomData, + }) + } + + fn insert(&self, x: usize) -> Result<(), FlatQError> { + self.implementation.insert(x)?; + Ok(()) + } + + fn search( + &self, + target: Box<[Scalar]>, + k: usize, + filter: F, + ) -> Result, FlatQError> + where + F: FnMut(u64) -> bool, + { + let mut result = FilteredFixedHeap::new(k, filter); + let vector = self.implementation.process(&target); + for i in 0..self.vectors.len() { + let this_vector = self.implementation.get_vector(i); + let this_data = self.vectors.get_data(i); + let dis = self.implementation.distance(&vector, this_vector); + result.push((dis, this_data)); + } + let mut output = Vec::new(); + for (i, j) in 
result.into_sorted_vec().into_iter() { + output.push((i, j)); + } + Ok(output) + } +} diff --git a/src/algorithms/hnsw.rs b/src/algorithms/hnsw.rs index 1ff93e36f..d4acf7482 100644 --- a/src/algorithms/hnsw.rs +++ b/src/algorithms/hnsw.rs @@ -1,16 +1,25 @@ use super::impls::hnsw::HnswImpl; -use crate::algorithms::Vectors; -use crate::memory::using; -use crate::memory::Address; +use super::Algo; +use crate::bgworker::index::IndexOptions; +use crate::bgworker::storage::Storage; +use crate::bgworker::storage::StoragePreallocator; +use crate::bgworker::vectors::Vectors; use crate::prelude::*; use serde::{Deserialize, Serialize}; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::sync::Arc; +use thiserror::Error; + +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum HnswError { + // +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HnswOptions { - pub storage: Storage, + #[serde(default = "HnswOptions::default_memmap")] + pub memmap: Memmap, #[serde(default = "HnswOptions::default_build_threads")] pub build_threads: usize, #[serde(default = "HnswOptions::default_max_threads")] @@ -22,6 +31,9 @@ pub struct HnswOptions { } impl HnswOptions { + fn default_memmap() -> Memmap { + Memmap::Ram + } fn default_build_threads() -> usize { std::thread::available_parallelism().unwrap().get() } @@ -36,35 +48,51 @@ impl HnswOptions { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct HnswPersistent { - address: Address, +pub struct Hnsw { + implementation: HnswImpl, } -pub struct Hnsw { - implementation: HnswImpl, -} +impl Algo for Hnsw { + type Error = HnswError; + + type Save = (); -impl Algorithm for Hnsw { - type Options = HnswOptions; + fn prebuild( + storage: &mut StoragePreallocator, + options: IndexOptions, + ) -> Result<(), Self::Error> { + let hnsw_options = options.algorithm.clone().unwrap_hnsw(); + HnswImpl::::prebuild( + storage, + options.capacity, + hnsw_options.m, + hnsw_options.memmap, + )?; + Ok(()) + 
} - fn build(options: Options, vectors: Arc, n: usize) -> anyhow::Result { + fn build( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + n: usize, + ) -> Result { let hnsw_options = options.algorithm.clone().unwrap_hnsw(); let implementation = HnswImpl::new( + storage, vectors, options.dims, - options.distance, options.capacity, hnsw_options.max_threads, hnsw_options.m, hnsw_options.ef_construction, - hnsw_options.storage, + hnsw_options.memmap, )?; let i = AtomicUsize::new(0); - using().scope(|scope| -> anyhow::Result<()> { + std::thread::scope(|scope| -> Result<(), HnswError> { let mut handles = Vec::new(); for _ in 0..hnsw_options.build_threads { - handles.push(scope.spawn(|| -> anyhow::Result<()> { + handles.push(scope.spawn(|| -> Result<(), HnswError> { loop { let i = i.fetch_add(1, Ordering::Relaxed); if i >= n { @@ -72,40 +100,50 @@ impl Algorithm for Hnsw { } implementation.insert(i)?; } - anyhow::Result::Ok(()) + Result::Ok(()) })); } for handle in handles.into_iter() { handle.join().unwrap()?; } - anyhow::Result::Ok(()) + Result::Ok(()) })?; Ok(Self { implementation }) } - fn address(&self) -> Address { - self.implementation.address - } + fn save(&self) {} - fn load(options: Options, vectors: Arc, address: Address) -> anyhow::Result { + fn load( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + (): (), + ) -> Result { let hnsw_options = options.algorithm.unwrap_hnsw(); let implementation = HnswImpl::load( + storage, vectors, - options.distance, options.dims, options.capacity, hnsw_options.max_threads, hnsw_options.m, hnsw_options.ef_construction, - address, - hnsw_options.storage, + hnsw_options.memmap, )?; Ok(Self { implementation }) } - fn insert(&self, insert: usize) -> anyhow::Result<()> { + fn insert(&self, insert: usize) -> Result<(), HnswError> { self.implementation.insert(insert) } - fn search(&self, search: (Box<[Scalar]>, usize)) -> anyhow::Result> { - self.implementation.search(search) + fn search( + &self, + 
target: Box<[Scalar]>, + k: usize, + filter: F, + ) -> Result, HnswError> + where + F: FnMut(u64) -> bool, + { + self.implementation.search(target, k, filter) } } diff --git a/src/algorithms/impls/kmeans.rs b/src/algorithms/impls/elkan_k_means.rs similarity index 88% rename from src/algorithms/impls/kmeans.rs rename to src/algorithms/impls/elkan_k_means.rs index 9a1074d8c..508b4238b 100644 --- a/src/algorithms/impls/kmeans.rs +++ b/src/algorithms/impls/elkan_k_means.rs @@ -1,26 +1,29 @@ use crate::prelude::*; -use crate::utils::vec2::Vec2; + +use crate::algorithms::utils::vec2::Vec2; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; +use std::marker::PhantomData; use std::ops::{Index, IndexMut}; -pub struct Kmeans { - distance: Distance, +pub struct ElkanKMeans { dims: u16, c: usize, - centroids: Vec2, + pub centroids: Vec2, lowerbound: Square, upperbound: Vec, assign: Vec, rand: StdRng, samples: Vec2, + _maker: PhantomData, } const DELTA: f32 = 1.0 / 1024.0; -impl Kmeans { - pub fn new(distance: Distance, dims: u16, c: usize, samples: Vec2) -> Self { +impl ElkanKMeans { + pub fn new(c: usize, samples: Vec2) -> Self { let n = samples.len(); + let dims = samples.dims(); let mut rand = StdRng::from_entropy(); let mut centroids = Vec2::new(dims, c); @@ -34,7 +37,7 @@ impl Kmeans { for i in 0..c { let mut sum = Scalar::Z; for j in 0..n { - let dis = distance.kmeans_distance(&samples[j], ¢roids[i]); + let dis = D::elkan_k_means_distance(&samples[j], ¢roids[i]); lowerbound[(j, i)] = dis; if dis * dis < weight[j] { weight[j] = dis * dis; @@ -72,7 +75,6 @@ impl Kmeans { } Self { - distance, dims, c, centroids, @@ -81,12 +83,13 @@ impl Kmeans { assign, rand, samples, + _maker: PhantomData, } } pub fn iterate(&mut self) -> bool { let c = self.c; - let f = |lhs: &[Scalar], rhs: &[Scalar]| self.distance.kmeans_distance(lhs, rhs); + let f = |lhs: &[Scalar], rhs: &[Scalar]| D::elkan_k_means_distance(lhs, rhs); let dims = self.dims; let samples = &self.samples; let rand = 
&mut self.rand; @@ -154,7 +157,7 @@ impl Kmeans { } // Step 4, 7 - let old = std::mem::replace(centroids, Vec2::new(dims, n)); + let old = std::mem::replace(centroids, Vec2::new(dims, c)); let mut count = vec![Scalar::Z; c]; centroids.fill(Scalar::Z); for i in 0..n { @@ -171,18 +174,6 @@ impl Kmeans { centroids[i][dim] /= count[i]; } } - /* - let mut sum = 0.0; - for i in 0..c { - sum += count[i].0 as f32; - } - let average = sum / c as f32; - let mut cov = 0.0; - for i in 0..c { - cov += (count[i].0 as f32 - average) * (count[i].0 as f32 - average); - } - log::error!("COV = {}", cov); - */ for i in 0..c { if count[i] != Scalar::Z { continue; @@ -210,7 +201,7 @@ impl Kmeans { count[o] = count[o] - count[i]; } for i in 0..c { - self.distance.kmeans_normalize(&mut centroids[i]); + D::elkan_k_means_normalize(&mut centroids[i]); } // Step 5, 6 diff --git a/src/algorithms/impls/hnsw.rs b/src/algorithms/impls/hnsw.rs index 2da142539..91fa26dd6 100644 --- a/src/algorithms/impls/hnsw.rs +++ b/src/algorithms/impls/hnsw.rs @@ -1,68 +1,130 @@ -use crate::algorithms::Vectors; -use crate::memory::Address; -use crate::memory::PBox; -use crate::memory::Persistent; -use crate::memory::Ptr; +use crate::algorithms::hnsw::HnswError; +use crate::bgworker::storage::Storage; +use crate::bgworker::storage::StoragePreallocator; +use crate::bgworker::storage_mmap::MmapBox; +use crate::bgworker::vectors::Vectors; use crate::prelude::*; -use crate::utils::parray::PArray; -use crate::utils::semaphore::Semaphore; -use crate::utils::unsafe_once::UnsafeOnce; + +use crate::algorithms::utils::filtered_fixed_heap::FilteredFixedHeap; +use crate::algorithms::utils::semaphore::Semaphore; use parking_lot::RwLock; +use parking_lot::RwLockReadGuard; +use parking_lot::RwLockWriteGuard; use rand::Rng; +use std::cell::UnsafeCell; use std::cmp::Reverse; use std::collections::BinaryHeap; +use std::marker::PhantomData; use std::ops::RangeInclusive; use std::sync::Arc; -type Vertex = PBox<[RwLock>]>; - -pub 
struct Root { - vertexs: PBox<[UnsafeOnce]>, - entry: RwLock>, +#[derive(Debug, Clone, Copy)] +struct VertexIndexer { + offset: usize, + capacity: usize, } -static_assertions::assert_impl_all!(Root: Persistent); +#[derive(Debug, Clone, Copy)] +struct EdgesIndexer { + offset: usize, + capacity: usize, + len: usize, +} -pub struct HnswImpl { - pub address: Address, - root: &'static Root, +pub struct HnswImpl { + indexers: MmapBox<[VertexIndexer]>, + vertexs: MmapBox<[RwLock]>, + edges: MmapBox<[UnsafeCell<(Scalar, usize)>]>, + entry: MmapBox>>, vectors: Arc, - distance: Distance, dims: u16, m: usize, ef_construction: usize, visited: Semaphore, - storage: Storage, + _maker: PhantomData, } -unsafe impl Send for HnswImpl {} -unsafe impl Sync for HnswImpl {} +unsafe impl Send for HnswImpl {} +unsafe impl Sync for HnswImpl {} -impl HnswImpl { +impl HnswImpl { + pub fn prebuild( + storage: &mut StoragePreallocator, + capacity: usize, + m: usize, + memmap: Memmap, + ) -> Result<(), HnswError> { + let len_indexers = capacity; + let len_vertexs = capacity * 2; + let len_edges = capacity * 2 * (2 * m); + storage.palloc_mmap_slice::(memmap, len_indexers); + storage.palloc_mmap_slice::>(memmap, len_vertexs); + storage.palloc_mmap_slice::>(memmap, len_edges); + storage.palloc_mmap::>>(memmap); + Ok(()) + } pub fn new( + storage: &mut Storage, vectors: Arc, dims: u16, - distance: Distance, capacity: usize, max_threads: usize, m: usize, ef_construction: usize, - storage: Storage, - ) -> anyhow::Result { - let ptr = PBox::new( - Root { - vertexs: unsafe { PBox::new_zeroed_slice(capacity, storage)?.assume_init() }, - entry: RwLock::new(None), - }, - storage, - )? 
- .into_raw(); + memmap: Memmap, + ) -> Result { + let len_indexers = capacity; + let len_vertexs = capacity * 2; + let len_edges = capacity * 2 * (2 * m); + let mut indexers = unsafe { + storage + .alloc_mmap_slice::(memmap, len_indexers) + .assume_init() + }; + let mut vertexs = unsafe { + storage + .alloc_mmap_slice::>(memmap, len_vertexs) + .assume_init() + }; + let edges = unsafe { + storage + .alloc_mmap_slice::>(memmap, len_edges) + .assume_init() + }; + let entry = unsafe { + let mut entry = storage.alloc_mmap::>>(memmap); + entry.write(RwLock::new(None)); + entry.assume_init() + }; + { + let mut offset_vertexs = 0usize; + let mut offset_edges = 0usize; + for i in 0..capacity { + let levels = generate_random_levels(m, 63); + let capacity_vertexs = levels as usize + 1; + for j in 0..=levels { + let capacity_edges = size_of_a_layer(m, j); + vertexs[offset_vertexs + j as usize] = RwLock::new(EdgesIndexer { + offset: offset_edges, + capacity: capacity_edges, + len: 0, + }); + offset_edges += capacity_edges; + } + indexers[i] = VertexIndexer { + offset: offset_vertexs, + capacity: capacity_vertexs, + }; + offset_vertexs += capacity_vertexs; + } + } Ok(Self { - address: ptr.address(), - root: unsafe { ptr.as_ref() }, + indexers, + vertexs, + edges, + entry, vectors, dims, - distance, visited: { let semaphore = Semaphore::::new(); for _ in 0..max_threads { @@ -72,25 +134,28 @@ impl HnswImpl { }, m, ef_construction, - storage, + _maker: PhantomData, }) } pub fn load( + storage: &mut Storage, vectors: Arc, - distance: Distance, dims: u16, capacity: usize, max_threads: usize, m: usize, ef_construction: usize, - address: Address, - storage: Storage, - ) -> anyhow::Result { + memmap: Memmap, + ) -> Result { + let len_indexers = capacity; + let len_vertexs = capacity * 2; + let len_edges = capacity * 2 * (2 * m); Ok(Self { - address, - root: unsafe { Ptr::new(address, ()).as_ref() }, + indexers: unsafe { storage.alloc_mmap_slice(memmap, len_indexers).assume_init() }, + 
vertexs: unsafe { storage.alloc_mmap_slice(memmap, len_vertexs).assume_init() }, + edges: unsafe { storage.alloc_mmap_slice(memmap, len_edges).assume_init() }, + entry: unsafe { storage.alloc_mmap(memmap).assume_init() }, vectors, - distance, dims, m, ef_construction, @@ -101,48 +166,74 @@ impl HnswImpl { } semaphore }, - storage, + _maker: PhantomData, }) } - pub fn search( + pub fn search( &self, - (x_vector, k): (Box<[Scalar]>, usize), - ) -> anyhow::Result> { - anyhow::ensure!(x_vector.len() == self.dims as usize); - let entry = *self.root.entry.read(); - let Some(u) = entry else { return Ok(Vec::new()) }; + target: Box<[Scalar]>, + k: usize, + filter: F, + ) -> Result, HnswError> + where + F: FnMut(u64) -> bool, + { + assert!(target.len() == self.dims as usize); + let entry = *self.entry.read(); + let Some(u) = entry else { + return Ok(Vec::new()); + }; let top = self._levels(u); - let u = self._go(1..=top, u, &x_vector); + let u = self._go(1..=top, u, &target); let mut visited = self.visited.acquire(); - let mut result = self._search(&mut visited, &x_vector, u, k, 0); - result.sort(); - Ok(result - .iter() - .map(|&(score, u)| (score, self.vectors.get_data(u))) - .collect::>()) + let result = self._filtered_search(&mut visited, &target, u, k, 0, filter); + Ok(result) } - pub fn insert(&self, x: usize) -> anyhow::Result<()> { + pub fn insert(&self, x: usize) -> Result<(), HnswError> { let mut visited = self.visited.acquire(); self._insert(&mut visited, x) } - fn _go(&self, levels: RangeInclusive, u: usize, target: &[Scalar]) -> usize { - let mut u = u; + fn _vertex(&self, i: usize) -> &[RwLock] { + let VertexIndexer { offset, capacity } = self.indexers[i]; + &self.vertexs[offset..][..capacity] + } + fn _edges<'a>(&self, guard: &'a RwLockReadGuard) -> &'a [(Scalar, usize)] { + unsafe { + let raw = self.edges[guard.offset..][..guard.len].as_ptr(); + std::slice::from_raw_parts(raw.cast(), guard.len) + } + } + #[allow(clippy::needless_pass_by_ref_mut)] + fn 
_edges_mut<'a>( + &self, + guard: &'a mut RwLockWriteGuard, + ) -> &'a mut [(Scalar, usize)] { + unsafe { + let raw = self.edges[guard.offset..][..guard.len].as_ptr(); + std::slice::from_raw_parts_mut(raw.cast_mut().cast(), guard.len) + } + } + fn _edges_clear(&self, guard: &mut RwLockWriteGuard) { + guard.len = 0; + } + fn _edges_append(&self, guard: &mut RwLockWriteGuard, data: (Scalar, usize)) { + if guard.capacity == guard.len { + panic!("Array is full. The capacity is {}.", guard.capacity); + } unsafe { - std::intrinsics::prefetch_read_data(target.as_ptr(), 3); + self.edges[guard.offset + guard.len].get().write(data); } + guard.len += 1; + } + fn _go(&self, levels: RangeInclusive, u: usize, target: &[Scalar]) -> usize { + let mut u = u; let mut u_dis = self._dist0(u, target); for i in levels.rev() { let mut changed = true; while changed { changed = false; - unsafe { - std::intrinsics::prefetch_read_data( - self.vectors.get_vector(u).as_ref().as_ptr(), - 3, - ); - } - let guard = self.root.vertexs[u][i as usize].read(); - for (_, v) in guard.iter().copied() { + let guard = self._vertex(u)[i as usize].read(); + for &(_, v) in self._edges(&guard).iter() { let v_dis = self._dist0(v, target); if v_dis < u_dis { u = v; @@ -154,10 +245,9 @@ impl HnswImpl { } u } - fn _insert(&self, visited: &mut Visited, insert: usize) -> anyhow::Result<()> { - let vertexs = self.root.vertexs.as_ref(); - let vector = self.vectors.get_vector(insert); - let levels = generate_random_levels(self.m, 63); + fn _insert(&self, visited: &mut Visited, id: usize) -> Result<(), HnswError> { + let target = self.vectors.get_vector(id); + let levels = self._levels(id); let entry; let lock = { let cond = move |global: Option| { @@ -167,10 +257,10 @@ impl HnswImpl { true } }; - let lock = self.root.entry.read(); + let lock = self.entry.read(); if cond(*lock) { drop(lock); - let lock = self.root.entry.write(); + let lock = self.entry.write(); entry = *lock; if cond(*lock) { Some(lock) @@ -183,115 
+273,76 @@ impl HnswImpl { } }; let Some(mut u) = entry else { - let vertex = { - let mut vertex = PBox::new_uninit_slice(1 + levels as usize, self.storage)?; - for i in 0..=levels { - let array = PArray::new(1 + size_of_a_layer(self.m, i), self.storage)?; - vertex[i as usize].write(RwLock::new(array)); - } - unsafe { vertex.assume_init() } - }; - vertexs[insert].set(vertex); - *lock.unwrap() = Some(insert); + if let Some(mut lock) = lock { + *lock = Some(id); + } return Ok(()); }; let top = self._levels(u); if top > levels { - u = self._go(levels + 1..=top, u, vector); + u = self._go(levels + 1..=top, u, target); } let mut layers = Vec::with_capacity(1 + levels as usize); for i in (0..=std::cmp::min(levels, top)).rev() { - let mut layer = self._search(visited, vector, u, self.ef_construction, i); - layer.sort(); - self._select0(&mut layer, size_of_a_layer(self.m, i))?; - u = layer.first().unwrap().1; - layers.push(layer); + let mut edges = self._search(visited, target, u, self.ef_construction, i); + edges.sort(); + edges = self._select(edges, size_of_a_layer(self.m, i))?; + u = edges.first().unwrap().1; + layers.push(edges); } layers.reverse(); layers.resize_with(1 + levels as usize, Vec::new); let backup = layers.iter().map(|x| x.to_vec()).collect::>(); - let vertex = { - let mut vertex = PBox::new_uninit_slice(1 + levels as usize, self.storage)?; - for i in 0..=levels { - let mut array = PArray::new(1 + size_of_a_layer(self.m, i), self.storage)?; - for &x in layers[i as usize].iter() { - array.push(x)?; - } - vertex[i as usize].write(RwLock::new(array)); + for i in 0..=levels { + let mut guard = self._vertex(id)[i as usize].write(); + let edges = layers[i as usize].as_slice(); + self._edges_clear(&mut guard); + for &edge in edges { + self._edges_append(&mut guard, edge); } - unsafe { vertex.assume_init() } - }; - vertexs[insert].set(vertex); + } for (i, layer) in backup.into_iter().enumerate() { let i = i as u8; for (n_dis, n) in layer.iter().copied() { - let 
mut guard = vertexs[n][i as usize].write(); - orderedly_insert(&mut guard, (n_dis, insert))?; - self._select1(&mut guard, size_of_a_layer(self.m, i))?; + let mut guard = self._vertex(n)[i as usize].write(); + let element = (n_dis, id); + let mut edges = self._edges_mut(&mut guard).to_vec(); + let (Ok(index) | Err(index)) = edges.binary_search(&element); + edges.insert(index, element); + edges = self._select(edges, size_of_a_layer(self.m, i))?; + self._edges_clear(&mut guard); + for &edge in edges.iter() { + self._edges_append(&mut guard, edge); + } } } if let Some(mut lock) = lock { - *lock = Some(insert); - } - Ok(()) - } - fn _select0(&self, v: &mut Vec<(Scalar, usize)>, size: usize) -> anyhow::Result<()> { - unsafe { - std::intrinsics::prefetch_read_data(v.as_ptr(), 3); - } - if v.len() <= size { - return Ok(()); - } - let cloned = v.to_vec(); - v.clear(); - for (u_dis, u) in cloned.iter().copied() { - if v.len() == size { - break; - } - unsafe { - std::intrinsics::prefetch_read_data( - self.vectors.get_vector(u).as_ref().as_ptr(), - 3, - ); - } - let check = v - .iter() - .map(|&(_, v)| self._dist1(u, v)) - .all(|dist| dist > u_dis); - if check { - v.push((u_dis, u)); - } + *lock = Some(id); } Ok(()) } - fn _select1(&self, v: &mut PArray<(Scalar, usize)>, size: usize) -> anyhow::Result<()> { - unsafe { - std::intrinsics::prefetch_read_data(v.as_ptr(), 3); - } - if v.len() <= size { - return Ok(()); + fn _select( + &self, + input: Vec<(Scalar, usize)>, + size: usize, + ) -> Result, HnswError> { + if input.len() <= size { + return Ok(input); } - let cloned = v.to_vec(); - v.clear(); - for (u_dis, u) in cloned.iter().copied() { - if v.len() == size { + let mut res = Vec::new(); + for (u_dis, u) in input.iter().copied() { + if res.len() == size { break; } - unsafe { - std::intrinsics::prefetch_read_data( - self.vectors.get_vector(u).as_ref().as_ptr(), - 3, - ); - } - let check = v + let check = res .iter() .map(|&(_, v)| self._dist1(u, v)) .all(|dist| dist > 
u_dis); if check { - v.push((u_dis, u)).unwrap(); + res.push((u_dis, u)); } } - Ok(()) + Ok(res) } fn _search( &self, @@ -302,9 +353,6 @@ impl HnswImpl { i: u8, ) -> Vec<(Scalar, usize)> { assert!(k > 0); - unsafe { - std::intrinsics::prefetch_read_data(target.as_ptr(), 3); - } let mut bound = Scalar::INFINITY; let mut visited = visited.new_version(); let mut candidates = BinaryHeap::>::new(); @@ -323,18 +371,12 @@ impl HnswImpl { if u_dis > bound { break; } - let guard = self.root.vertexs[u][i as usize].read(); - for (_, v) in guard.iter().copied() { + let guard = self._vertex(u)[i as usize].read(); + for &(_, v) in self._edges(&guard).iter() { if visited.test(v) { continue; } visited.set(v); - unsafe { - std::intrinsics::prefetch_read_data( - self.vectors.get_vector(v).as_ref().as_ptr(), - 3, - ); - } let v_dis = self._dist0(v, target); if v_dis > bound { continue; @@ -351,17 +393,57 @@ impl HnswImpl { } results.into_vec() } + fn _filtered_search( + &self, + visited: &mut Visited, + target: &[Scalar], + s: usize, + k: usize, + i: u8, + filter: F, + ) -> Vec<(Scalar, u64)> + where + F: FnMut(u64) -> bool, + { + assert!(k > 0); + let mut visited = visited.new_version(); + let mut candidates = BinaryHeap::>::new(); + let mut results = FilteredFixedHeap::new(k, filter); + let s_dis = self._dist0(s, target); + visited.set(s); + candidates.push(Reverse((s_dis, s))); + results.push((s_dis, self.vectors.get_data(s))); + while let Some(Reverse((u_dis, u))) = candidates.pop() { + if u_dis > results.bound() { + break; + } + let guard = self._vertex(u)[i as usize].read(); + for &(_, v) in self._edges(&guard).iter() { + if visited.test(v) { + continue; + } + visited.set(v); + let v_dis = self._dist0(v, target); + if v_dis > results.bound() { + continue; + } + candidates.push(Reverse((v_dis, v))); + results.push((v_dis, self.vectors.get_data(v))); + } + } + results.into_sorted_vec() + } fn _dist0(&self, u: usize, target: &[Scalar]) -> Scalar { let u = 
self.vectors.get_vector(u); - self.distance.distance(u, target) + D::distance(u, target) } fn _dist1(&self, u: usize, v: usize) -> Scalar { let u = self.vectors.get_vector(u); let v = self.vectors.get_vector(v); - self.distance.distance(u, v) + D::distance(u, v) } fn _levels(&self, u: usize) -> u8 { - self.root.vertexs[u].len() as u8 - 1 + self._vertex(u).len() as u8 - 1 } } @@ -380,12 +462,6 @@ fn size_of_a_layer(m: usize, i: u8) -> usize { } } -pub fn orderedly_insert(a: &mut PArray, element: T) -> anyhow::Result { - let (Ok(index) | Err(index)) = a.binary_search(&element); - a.insert(index, element)?; - Ok(index) -} - pub struct Visited { version: usize, data: Box<[usize]>, diff --git a/src/algorithms/impls/ivf.rs b/src/algorithms/impls/ivf.rs index a147ea927..465d34e4d 100644 --- a/src/algorithms/impls/ivf.rs +++ b/src/algorithms/impls/ivf.rs @@ -1,45 +1,53 @@ -use super::kmeans::Kmeans; -use crate::algorithms::Vectors; -use crate::memory::Address; -use crate::memory::PBox; -use crate::memory::Persistent; -use crate::memory::Ptr; +use super::elkan_k_means::ElkanKMeans; +use crate::algorithms::ivf::IvfError; +use crate::algorithms::utils::filtered_fixed_heap::FilteredFixedHeap; +use crate::algorithms::utils::fixed_heap::FixedHeap; +use crate::algorithms::utils::mmap_vec2::MmapVec2; +use crate::algorithms::utils::vec2::Vec2; +use crate::bgworker::storage::{Storage, StoragePreallocator}; +use crate::bgworker::storage_mmap::MmapBox; +use crate::bgworker::vectors::Vectors; use crate::prelude::*; -use crate::utils::fixed_heap::FixedHeap; -use crate::utils::unsafe_once::UnsafeOnce; -use crate::utils::vec2::Vec2; use crossbeam::atomic::AtomicCell; use rand::seq::index::sample; use rand::thread_rng; +use std::cell::UnsafeCell; +use std::marker::PhantomData; use std::sync::Arc; -struct List { - centroid: PBox<[Scalar]>, - head: AtomicCell>, -} - type Vertex = Option; -pub struct Root { - lists: PBox<[List]>, - vertexs: PBox<[UnsafeOnce]>, -} - 
-static_assertions::assert_impl_all!(Root: Persistent); - -pub struct IvfImpl { - pub address: Address, - root: &'static Root, +pub struct IvfImpl { + centroids: MmapVec2, + heads: MmapBox<[AtomicCell>]>, + vertexs: MmapBox<[UnsafeCell]>, + // vectors: Arc, - distance: Distance, nprobe: usize, + nlist: usize, + _maker: PhantomData, } -impl IvfImpl { +unsafe impl Send for IvfImpl {} +unsafe impl Sync for IvfImpl {} + +impl IvfImpl { + pub fn prebuild( + storage: &mut StoragePreallocator, + dims: u16, + nlist: usize, + capacity: usize, + memmap: Memmap, + ) -> Result<(), IvfError> { + MmapVec2::prebuild(storage, dims, nlist); + storage.palloc_mmap_slice::>>(memmap, nlist); + storage.palloc_mmap_slice::>(memmap, capacity); + Ok(()) + } pub fn new( + storage: &mut Storage, vectors: Arc, dims: u16, - distance: Distance, n: usize, nlist: usize, nsample: usize, @@ -47,115 +55,127 @@ impl IvfImpl { least_iterations: usize, iterations: usize, capacity: usize, - storage: Storage, - ) -> anyhow::Result { + memmap: Memmap, + ) -> Result { let m = std::cmp::min(nsample, n); let f = sample(&mut thread_rng(), n, m).into_vec(); let mut samples = Vec2::new(dims, m); for i in 0..m { samples[i].copy_from_slice(vectors.get_vector(f[i])); - distance.kmeans_normalize(&mut samples[i]); + D::elkan_k_means_normalize(&mut samples[i]); } - let mut kmeans = Kmeans::new(distance, dims, nlist, samples); + let mut k_means = ElkanKMeans::::new(nlist, samples); for _ in 0..least_iterations { - kmeans.iterate(); + k_means.iterate(); } for _ in least_iterations..iterations { - if kmeans.iterate() { + if k_means.iterate() { break; } } - let centroids = kmeans.finish(); - let ptr = PBox::new( - Root { - lists: { - let mut lists = PBox::new_zeroed_slice(nlist, storage)?; - for i in 0..nlist { - lists[i].write(List { - centroid: { - let mut centroid = unsafe { - PBox::new_zeroed_slice(dims as _, storage)?.assume_init() - }; - centroid.copy_from_slice(¢roids[i]); - centroid - }, - head: 
AtomicCell::new(None), - }); - } - unsafe { lists.assume_init() } - }, - vertexs: { - let vertexs = PBox::new_zeroed_slice(capacity, storage)?; - unsafe { vertexs.assume_init() } - }, - }, - storage, - )? - .into_raw(); + let k_means = k_means.finish(); + let centroids = { + let mut centroids = MmapVec2::build(storage, dims, nlist); + for i in 0..nlist { + centroids[i].copy_from_slice(&k_means[i]); + } + centroids + }; + let heads = { + let mut heads = storage.alloc_mmap_slice(memmap, nlist); + for i in 0..nlist { + heads[i].write(AtomicCell::new(None)); + } + unsafe { heads.assume_init() } + }; + let vertexs = { + let mut vertexs = storage.alloc_mmap_slice(memmap, capacity); + for i in 0..capacity { + vertexs[i].write(UnsafeCell::new(None)); + } + unsafe { vertexs.assume_init() } + }; Ok(Self { - address: ptr.address(), - root: unsafe { ptr.as_ref() }, + centroids, + heads, + vertexs, + // vectors, - distance, nprobe, + nlist, + _maker: PhantomData, }) } pub fn load( + storage: &mut Storage, + dims: u16, vectors: Arc, - distance: Distance, - address: Address, + nlist: usize, nprobe: usize, - ) -> anyhow::Result { + capacity: usize, + memmap: Memmap, + ) -> Result { Ok(Self { - address, - root: unsafe { Ptr::new(address, ()).as_ref() }, + centroids: MmapVec2::load(storage, dims, nlist), + heads: unsafe { storage.alloc_mmap_slice(memmap, nlist).assume_init() }, + vertexs: unsafe { storage.alloc_mmap_slice(memmap, capacity).assume_init() }, vectors, - distance, nprobe, + nlist, + _maker: PhantomData, }) } - pub fn search( + pub fn search( &self, - (mut x_vector, k): (Box<[Scalar]>, usize), - ) -> anyhow::Result> { + mut target: Box<[Scalar]>, + k: usize, + filter: F, + ) -> Result, IvfError> + where + F: FnMut(u64) -> bool, + { let vectors = self.vectors.as_ref(); - self.distance.kmeans_normalize(&mut x_vector); + D::elkan_k_means_normalize(&mut target); let mut lists = FixedHeap::new(self.nprobe); - for (i, list) in self.root.lists.iter().enumerate() { - let dis = 
self.distance.kmeans_distance(&x_vector, &list.centroid); + for i in 0..self.nlist { + let centroid = &self.centroids[i]; + let dis = D::elkan_k_means_distance(&target, centroid); lists.push((dis, i)); } - let mut result = FixedHeap::new(k); + let mut result = FilteredFixedHeap::new(k, filter); for (_, i) in lists.into_vec().into_iter() { - let mut cursor = self.root.lists[i].head.load(); + let mut cursor = self.heads[i].load(); while let Some(u) = cursor { let u_vector = vectors.get_vector(u); let u_data = vectors.get_data(u); - let u_dis = self.distance.distance(&x_vector, u_vector); + let u_dis = D::distance(&target, u_vector); result.push((u_dis, u_data)); - cursor = *self.root.vertexs[u]; + cursor = unsafe { *self.vertexs[u].get() }; } } Ok(result.into_sorted_vec()) } - pub fn insert(&self, x: usize) -> anyhow::Result<()> { + pub fn insert(&self, x: usize) -> Result<(), IvfError> { self._insert(x)?; Ok(()) } - pub fn _insert(&self, x: usize) -> anyhow::Result<()> { - let vertexs = self.root.vertexs.as_ref(); - let mut x_vector = self.vectors.get_vector(x).to_vec(); - self.distance.kmeans_normalize(&mut x_vector); + pub fn _insert(&self, x: usize) -> Result<(), IvfError> { + let vertexs = self.vertexs.as_ref(); + let mut target = self.vectors.get_vector(x).to_vec(); + D::elkan_k_means_normalize(&mut target); let mut result = (Scalar::INFINITY, 0); - for (i, list) in self.root.lists.iter().enumerate() { - let dis = self.distance.kmeans_distance(&x_vector, &list.centroid); + for i in 0..self.nlist { + let centroid = &self.centroids[i]; + let dis = D::elkan_k_means_distance(&target, centroid); result = std::cmp::min(result, (dis, i)); } loop { - let next = self.root.lists[result.1].head.load(); - vertexs[x].set(next); - let list = &self.root.lists[result.1]; - if list.head.compare_exchange(next, Some(x)).is_ok() { + let next = self.heads[result.1].load(); + unsafe { + vertexs[x].get().write(next); + } + let head = &self.heads[result.1]; + if 
head.compare_exchange(next, Some(x)).is_ok() { break; } } diff --git a/src/algorithms/impls/mod.rs b/src/algorithms/impls/mod.rs index 63ad1fcaf..af253944b 100644 --- a/src/algorithms/impls/mod.rs +++ b/src/algorithms/impls/mod.rs @@ -1,3 +1,4 @@ +pub mod elkan_k_means; pub mod hnsw; pub mod ivf; -pub mod kmeans; +pub mod quantization; diff --git a/src/algorithms/impls/quantization.rs b/src/algorithms/impls/quantization.rs new file mode 100644 index 000000000..f25b42ea3 --- /dev/null +++ b/src/algorithms/impls/quantization.rs @@ -0,0 +1,320 @@ +use super::elkan_k_means::ElkanKMeans; +use crate::algorithms::utils::vec2::Vec2; +use crate::bgworker::storage::{Storage, StoragePreallocator}; +use crate::bgworker::storage_mmap::MmapBox; +use crate::bgworker::vectors::Vectors; +use crate::prelude::*; +use rand::seq::index::sample; +use rand::thread_rng; +use serde::{Deserialize, Serialize}; +use std::cell::UnsafeCell; +use std::marker::PhantomData; +use std::mem::MaybeUninit; +use std::ops::{Index, IndexMut}; +use std::sync::Arc; +use thiserror::Error; + +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum QuantizationError { + // +} + +pub struct QuantizationImpl { + data: MmapBox<[UnsafeCell>]>, + vectors: Arc, + width: usize, + quantization: Q, +} + +impl QuantizationImpl { + pub fn prebuild( + storage: &mut StoragePreallocator, + dims: u16, + capacity: usize, + memmap: Memmap, + ) -> Result<(), QuantizationError> { + let width = Q::width_dims(dims); + storage.palloc_mmap_slice::>>(memmap, width * capacity); + Ok(()) + } + pub fn new( + storage: &mut Storage, + vectors: Arc, + dims: u16, + n: usize, + nsample: usize, + capacity: usize, + memmap: Memmap, + ) -> Result { + let m = std::cmp::min(n, nsample); + let f = sample(&mut thread_rng(), n, m).into_vec(); + let mut samples = Vec2::new(dims, m); + for i in 0..m { + samples[i].copy_from_slice(vectors.get_vector(f[i])); + } + let quantization = Q::build(samples); + let width = quantization.width(); + let 
data = unsafe { + storage + .alloc_mmap_slice(memmap, width * capacity) + .assume_init() + }; + for i in 0..n { + let p = quantization.process(vectors.get_vector(i)); + unsafe { + std::ptr::copy_nonoverlapping( + p.as_ptr(), + data[i * width..][..width].as_ptr() as *mut u8, + width, + ); + } + } + Ok(Self { + data, + width, + vectors, + quantization, + }) + } + pub fn save(&self) -> Q { + self.quantization.clone() + } + pub fn load( + storage: &mut Storage, + vectors: Arc, + quantization: Q, + capacity: usize, + memmap: Memmap, + ) -> Result { + let width = quantization.width(); + Ok(Self { + data: unsafe { + storage + .alloc_mmap_slice(memmap, width * capacity) + .assume_init() + }, + vectors, + width: quantization.width(), + quantization, + }) + } + pub fn insert(&self, x: usize) -> Result<(), QuantizationError> { + let p = self.quantization.process(self.vectors.get_vector(x)); + unsafe { + std::ptr::copy_nonoverlapping( + p.as_ptr(), + self.data[x * self.width..][..self.width].as_ptr() as *mut u8, + self.width, + ); + } + Ok(()) + } + pub fn process(&self, vector: &[Scalar]) -> Vec { + self.quantization.process(vector) + } + pub fn distance(&self, lhs: &[u8], rhs: &[u8]) -> Scalar { + self.quantization.distance(lhs, rhs) + } + pub fn get_vector(&self, i: usize) -> &[u8] { + unsafe { assume_immutable_init(&self.data[i * self.width..][..self.width]) } + } +} + +pub trait Quantization: Clone + serde::Serialize + for<'a> serde::Deserialize<'a> { + fn build(samples: Vec2) -> Self + where + Self: Sized; + fn process(&self, point: &[Scalar]) -> Vec; + fn distance(&self, lhs: &[u8], rhs: &[u8]) -> Scalar; + fn width_dims(dims: u16) -> usize; + fn width(&self) -> usize; +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ScalarQuantization { + dims: u16, + max: Vec, + min: Vec, + _maker: PhantomData, +} + +impl Quantization for ScalarQuantization { + fn build(samples: Vec2) -> Self { + let dims = samples.dims(); + let mut max = 
vec![Scalar::NEG_INFINITY; dims as _]; + let mut min = vec![Scalar::INFINITY; dims as _]; + for i in 0..samples.len() { + for j in 0..dims as usize { + max[j] = std::cmp::max(max[j], samples[i][j]); + min[j] = std::cmp::max(min[j], samples[i][j]); + } + } + Self { + dims, + max, + min, + _maker: PhantomData, + } + } + + fn process(&self, vector: &[Scalar]) -> Vec { + let dims = self.dims; + assert!(dims as usize == vector.len()); + let mut result = vec![0u8; dims as usize]; + for i in 0..dims as usize { + let w = ((vector[i] - self.min[i]) / (self.max[i] - self.min[i]) * 256.0).0 as u32; + result[i] = w.clamp(0, 255) as u8; + } + result + } + + fn distance(&self, lhs: &[u8], rhs: &[u8]) -> Scalar { + let dims = self.dims; + assert!(dims as usize == lhs.len()); + assert!(dims as usize == rhs.len()); + let mut result = D::QUANTIZATION_INITIAL_STATE; + for i in 0..dims as usize { + let lhs = Scalar(lhs[i] as Float) * (self.max[i] - self.min[i]) + self.min[i]; + let rhs = Scalar(rhs[i] as Float) * (self.max[i] - self.min[i]) + self.min[i]; + result = D::quantization_merge(result, D::quantization_new(&[lhs], &[rhs])); + } + D::quantization_finish(result) + } + + fn width_dims(dims: u16) -> usize { + dims as usize + } + + fn width(&self) -> usize { + self.dims as usize + } +} + +const DIV: u16 = 2; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ProductQuantization { + dims: u16, + centroids: Vec, + matrixs: Vec>, +} + +impl Quantization for ProductQuantization { + fn build(samples: Vec2) -> Self { + let n = samples.len(); + let dims = samples.dims(); + let width = dims.div_ceil(DIV); + let mut centroids = Vec::with_capacity(width as usize); + let mut matrixs = Vec::with_capacity(width as usize); + for i in 0..width { + let subdims = std::cmp::min(DIV, dims - DIV * i); + let mut subsamples = Vec2::new(subdims, n); + for j in 0..n { + let src = &samples[j][(i * DIV) as usize..][..subdims as usize]; + subsamples[j].copy_from_slice(src); + } + 
let mut k_means = ElkanKMeans::::new(256, subsamples); + for _ in 0..200 { + if k_means.iterate() { + break; + } + } + let centroid = k_means.finish(); + let mut matrix = ProductQuantizationMatrix::new(D::QUANTIZATION_INITIAL_STATE); + for i in 0u8..=255 { + for j in i..=255 { + let state = D::quantization_new(¢roid[i as usize], ¢roid[j as usize]); + matrix[(i, j)] = state; + matrix[(j, i)] = state; + } + } + centroids.push(centroid); + matrixs.push(matrix); + } + Self { + dims, + centroids, + matrixs, + } + } + + fn process(&self, vector: &[Scalar]) -> Vec { + let dims = self.dims; + assert!(dims as usize == vector.len()); + let width = dims.div_ceil(DIV); + let mut result = Vec::with_capacity(width as usize); + for i in 0..width { + let subdims = std::cmp::min(DIV, dims - DIV * i); + let mut minimal = Scalar::INFINITY; + let mut target = 0u8; + for j in 0u8..=255 { + let left = &vector[(i * DIV) as usize..][..subdims as usize]; + let right = &self.centroids[i as usize][j as usize]; + let dis = L2::distance(left, right); + if dis < minimal { + minimal = dis; + target = j; + } + } + result.push(target); + } + result + } + + fn distance(&self, lhs: &[u8], rhs: &[u8]) -> Scalar { + let dims = self.dims; + let width = dims.div_ceil(DIV); + assert!(lhs.len() == width as usize); + assert!(rhs.len() == width as usize); + let mut result = D::QUANTIZATION_INITIAL_STATE; + for i in 0..width { + let delta = self.matrixs[i as usize][(lhs[i as usize], rhs[i as usize])]; + result = D::quantization_merge(result, delta); + } + D::quantization_finish(result) + } + + fn width_dims(dims: u16) -> usize { + dims.div_ceil(DIV) as usize + } + + fn width(&self) -> usize { + self.dims.div_ceil(DIV) as usize + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct ProductQuantizationMatrix(Box<[T]>); + +impl ProductQuantizationMatrix { + pub fn new(initial: T) -> Self { + let mut inner = Box::<[T]>::new_uninit_slice(65536); + unsafe { + let ptr = inner.as_mut_ptr() 
as *mut T; + for i in 0..65536 { + ptr.add(i).write(initial); + } + } + let inner = unsafe { inner.assume_init() }; + Self(inner) + } +} + +impl Index<(u8, u8)> for ProductQuantizationMatrix { + type Output = T; + + fn index(&self, (x, y): (u8, u8)) -> &Self::Output { + &self.0[x as usize * 256 + y as usize] + } +} + +impl IndexMut<(u8, u8)> for ProductQuantizationMatrix { + fn index_mut(&mut self, (x, y): (u8, u8)) -> &mut Self::Output { + &mut self.0[x as usize * 256 + y as usize] + } +} + +unsafe fn assume_immutable_init(slice: &[UnsafeCell>]) -> &[T] { + let p = slice.as_ptr().cast::>() as *const T; + std::slice::from_raw_parts(p, slice.len()) +} diff --git a/src/algorithms/ivf.rs b/src/algorithms/ivf.rs index f1e66a4d1..00f59fb90 100644 --- a/src/algorithms/ivf.rs +++ b/src/algorithms/ivf.rs @@ -1,27 +1,39 @@ use super::impls::ivf::IvfImpl; -use crate::algorithms::Vectors; -use crate::memory::using; -use crate::memory::Address; +use super::Algo; +use crate::bgworker::index::IndexOptions; +use crate::bgworker::storage::Storage; +use crate::bgworker::storage::StoragePreallocator; +use crate::bgworker::vectors::Vectors; use crate::prelude::*; use serde::{Deserialize, Serialize}; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::sync::Arc; +use thiserror::Error; + +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum IvfError { + // +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct IvfOptions { - pub storage: Storage, + #[serde(default = "IvfOptions::default_memmap")] + pub memmap: Memmap, #[serde(default = "IvfOptions::default_build_threads")] pub build_threads: usize, - pub nlist: usize, - pub nprobe: usize, #[serde(default = "IvfOptions::default_least_iterations")] pub least_iterations: usize, #[serde(default = "IvfOptions::default_iterations")] pub iterations: usize, + pub nlist: usize, + pub nprobe: usize, } impl IvfOptions { + fn default_memmap() -> Memmap { + Memmap::Ram + } fn default_build_threads() -> 
usize { std::thread::available_parallelism().unwrap().get() } @@ -33,19 +45,41 @@ impl IvfOptions { } } -pub struct Ivf { - implementation: IvfImpl, +pub struct Ivf { + implementation: IvfImpl, } -impl Algorithm for Ivf { - type Options = IvfOptions; +impl Algo for Ivf { + type Error = IvfError; + + type Save = (); - fn build(options: Options, vectors: Arc, n: usize) -> anyhow::Result { + fn prebuild( + storage: &mut StoragePreallocator, + options: IndexOptions, + ) -> Result<(), Self::Error> { + let ivf_options = options.algorithm.clone().unwrap_ivf(); + IvfImpl::::prebuild( + storage, + options.dims, + ivf_options.nlist, + options.capacity, + ivf_options.memmap, + )?; + Ok(()) + } + + fn build( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + n: usize, + ) -> Result { let ivf_options = options.algorithm.clone().unwrap_ivf(); let implementation = IvfImpl::new( + storage, vectors.clone(), options.dims, - options.distance, n, ivf_options.nlist, ivf_options.nlist * 50, @@ -53,13 +87,13 @@ impl Algorithm for Ivf { ivf_options.least_iterations, ivf_options.iterations, options.capacity, - ivf_options.storage, + ivf_options.memmap, )?; let i = AtomicUsize::new(0); - using().scope(|scope| -> anyhow::Result<()> { + std::thread::scope(|scope| -> Result<(), IvfError> { let mut handles = Vec::new(); for _ in 0..ivf_options.build_threads { - handles.push(scope.spawn(|| -> anyhow::Result<()> { + handles.push(scope.spawn(|| -> Result<(), IvfError> { loop { let i = i.fetch_add(1, Ordering::Relaxed); if i >= n { @@ -67,28 +101,47 @@ impl Algorithm for Ivf { } implementation.insert(i)?; } - anyhow::Result::Ok(()) + Result::Ok(()) })); } for handle in handles.into_iter() { handle.join().unwrap()?; } - anyhow::Result::Ok(()) + Result::Ok(()) })?; Ok(Self { implementation }) } - fn address(&self) -> Address { - self.implementation.address - } - fn load(options: Options, vectors: Arc, address: Address) -> anyhow::Result { + fn save(&self) {} + fn load( + storage: 
&mut Storage, + options: IndexOptions, + vectors: Arc, + (): (), + ) -> Result { let ivf_options = options.algorithm.clone().unwrap_ivf(); - let implementation = IvfImpl::load(vectors, options.distance, address, ivf_options.nprobe)?; + let implementation = IvfImpl::load( + storage, + options.dims, + vectors, + ivf_options.nlist, + ivf_options.nprobe, + options.capacity, + ivf_options.memmap, + )?; Ok(Self { implementation }) } - fn insert(&self, insert: usize) -> anyhow::Result<()> { + fn insert(&self, insert: usize) -> Result<(), IvfError> { self.implementation.insert(insert) } - fn search(&self, search: (Box<[Scalar]>, usize)) -> anyhow::Result> { - self.implementation.search(search) + fn search( + &self, + target: Box<[Scalar]>, + k: usize, + filter: F, + ) -> Result, IvfError> + where + F: FnMut(u64) -> bool, + { + self.implementation.search(target, k, filter) } } diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index cc24baf58..fddc91c14 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -1,63 +1,315 @@ mod flat; +mod flat_q; mod hnsw; mod impls; mod ivf; -mod vectors; +mod utils; -pub use flat::Flat; -pub use hnsw::Hnsw; -pub use ivf::Ivf; -pub use vectors::Vectors; +pub use flat::{Flat, FlatOptions}; +pub use flat_q::{FlatQ, FlatQOptions}; +pub use hnsw::{Hnsw, HnswOptions}; +pub use ivf::{Ivf, IvfOptions}; -use crate::memory::Address; +use self::flat::FlatError; +use self::flat_q::FlatQError; +use self::hnsw::HnswError; +use self::impls::quantization::ProductQuantization; +use self::ivf::IvfError; +use crate::bgworker::index::IndexOptions; +use crate::bgworker::storage::{Storage, StoragePreallocator}; +use crate::bgworker::vectors::Vectors; use crate::prelude::*; +use serde::{Deserialize, Serialize}; use std::sync::Arc; +use thiserror::Error; -pub enum DynAlgorithm { - Hnsw(Hnsw), - Flat(Flat), - Ivf(Ivf), +#[derive(Debug, Clone, Serialize, Deserialize, Error)] +pub enum AlgorithmError { + #[error("HNSW {0}")] + Hnsw(#[from] 
HnswError), + #[error("Flat {0}")] + Flat(#[from] FlatError), + #[error("FlatQ {0}")] + FlatQ(#[from] FlatQError), + #[error("Ivf {0}")] + Ivf(#[from] IvfError), } -impl DynAlgorithm { - pub fn build(options: Options, vectors: Arc, n: usize) -> anyhow::Result { +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AlgorithmOptions { + Hnsw(HnswOptions), + Ivf(IvfOptions), + Flat(FlatOptions), + FlatQ(FlatQOptions), +} + +impl AlgorithmOptions { + pub fn unwrap_hnsw(self) -> HnswOptions { + use AlgorithmOptions::*; + match self { + Hnsw(x) => x, + _ => unreachable!(), + } + } + #[allow(dead_code)] + pub fn unwrap_flat(self) -> FlatOptions { + use AlgorithmOptions::*; + match self { + Flat(x) => x, + _ => unreachable!(), + } + } + pub fn unwrap_flat_q(self) -> FlatQOptions { + use AlgorithmOptions::*; + match self { + FlatQ(x) => x, + _ => unreachable!(), + } + } + pub fn unwrap_ivf(self) -> IvfOptions { + use AlgorithmOptions::*; + match self { + Ivf(x) => x, + _ => unreachable!(), + } + } +} + +pub enum Algorithm { + HnswL2(Hnsw), + HnswCosine(Hnsw), + HnswDot(Hnsw), + FlatL2(Flat), + FlatCosine(Flat), + FlatDot(Flat), + FlatPqL2(FlatQ>), + FlatPqCosine(FlatQ>), + FlatPqDot(FlatQ>), + IvfL2(Ivf), + IvfCosine(Ivf), + IvfDot(Ivf), +} + +impl Algorithm { + pub fn prebuild( + storage: &mut StoragePreallocator, + options: IndexOptions, + ) -> Result<(), AlgorithmError> { use AlgorithmOptions as O; - match options.algorithm { - O::Hnsw(_) => Hnsw::build(options, vectors, n).map(Self::Hnsw), - O::Flat(_) => Flat::build(options, vectors, n).map(Self::Flat), - O::Ivf(_) => Ivf::build(options, vectors, n).map(Self::Ivf), + match (options.algorithm.clone(), options.distance) { + (O::Hnsw(_), Distance::L2) => Ok(Hnsw::::prebuild(storage, options)?), + (O::Hnsw(_), Distance::Cosine) => Ok(Hnsw::::prebuild(storage, options)?), + (O::Hnsw(_), Distance::Dot) => Ok(Hnsw::::prebuild(storage, options)?), + (O::Flat(_), Distance::L2) => 
Ok(Flat::::prebuild(storage, options)?), + (O::Flat(_), Distance::Cosine) => Ok(Flat::::prebuild(storage, options)?), + (O::Flat(_), Distance::Dot) => Ok(Flat::::prebuild(storage, options)?), + (O::FlatQ(_), Distance::L2) => Ok(FlatQ::>::prebuild( + storage, options, + )?), + (O::FlatQ(_), Distance::Cosine) => Ok( + FlatQ::>::prebuild(storage, options)?, + ), + (O::FlatQ(_), Distance::Dot) => Ok(FlatQ::>::prebuild( + storage, options, + )?), + (O::Ivf(_), Distance::L2) => Ok(Ivf::::prebuild(storage, options)?), + (O::Ivf(_), Distance::Cosine) => Ok(Ivf::::prebuild(storage, options)?), + (O::Ivf(_), Distance::Dot) => Ok(Ivf::::prebuild(storage, options)?), } } - pub fn address(&self) -> Address { - use DynAlgorithm::*; + pub fn build( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + n: usize, + ) -> Result { + use AlgorithmOptions as O; + match (options.algorithm.clone(), options.distance) { + (O::Hnsw(_), Distance::L2) => { + Ok(Hnsw::build(storage, options, vectors, n).map(Self::HnswL2)?) + } + (O::Hnsw(_), Distance::Cosine) => { + Ok(Hnsw::build(storage, options, vectors, n).map(Self::HnswCosine)?) + } + (O::Hnsw(_), Distance::Dot) => { + Ok(Hnsw::build(storage, options, vectors, n).map(Self::HnswDot)?) + } + (O::Flat(_), Distance::L2) => { + Ok(Flat::build(storage, options, vectors, n).map(Self::FlatL2)?) + } + (O::Flat(_), Distance::Cosine) => { + Ok(Flat::build(storage, options, vectors, n).map(Self::FlatCosine)?) + } + (O::Flat(_), Distance::Dot) => { + Ok(FlatQ::build(storage, options, vectors, n).map(Self::FlatPqDot)?) + } + (O::FlatQ(_), Distance::L2) => { + Ok(FlatQ::build(storage, options, vectors, n).map(Self::FlatPqL2)?) + } + (O::FlatQ(_), Distance::Cosine) => { + Ok(FlatQ::build(storage, options, vectors, n).map(Self::FlatPqCosine)?) + } + (O::FlatQ(_), Distance::Dot) => { + Ok(FlatQ::build(storage, options, vectors, n).map(Self::FlatPqDot)?) 
+ } + (O::Ivf(_), Distance::L2) => { + Ok(Ivf::build(storage, options, vectors, n).map(Self::IvfL2)?) + } + (O::Ivf(_), Distance::Cosine) => { + Ok(Ivf::build(storage, options, vectors, n).map(Self::IvfCosine)?) + } + (O::Ivf(_), Distance::Dot) => { + Ok(Ivf::build(storage, options, vectors, n).map(Self::IvfDot)?) + } + } + } + pub fn save(&self) -> Vec { + use Algorithm::*; match self { - Hnsw(sel) => sel.address(), - Flat(sel) => sel.address(), - Ivf(sel) => sel.address(), + HnswL2(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + HnswCosine(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + HnswDot(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + FlatL2(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + FlatCosine(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + FlatDot(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + FlatPqL2(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + FlatPqCosine(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + FlatPqDot(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + IvfL2(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + IvfCosine(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), + IvfDot(sel) => bincode::serialize(&sel.save()).expect("Failed to serialize."), } } - pub fn load(options: Options, vectors: Arc, address: Address) -> anyhow::Result { + pub fn load( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + save: Vec, + ) -> Result { use AlgorithmOptions as O; - match options.algorithm { - O::Hnsw(_) => Ok(Self::Hnsw(Hnsw::load(options, vectors, address)?)), - O::Flat(_) => Ok(Self::Flat(Flat::load(options, vectors, address)?)), - O::Ivf(_) => Ok(Self::Ivf(Ivf::load(options, vectors, address)?)), + match (options.algorithm.clone(), options.distance) { + 
(O::Hnsw(_), Distance::L2) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Hnsw::load(storage, options, vectors, save).map(Self::HnswL2)?) + } + (O::Hnsw(_), Distance::Cosine) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Hnsw::load(storage, options, vectors, save).map(Self::HnswCosine)?) + } + (O::Hnsw(_), Distance::Dot) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Hnsw::load(storage, options, vectors, save).map(Self::HnswDot)?) + } + (O::Flat(_), Distance::L2) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Flat::load(storage, options, vectors, save).map(Self::FlatL2)?) + } + (O::Flat(_), Distance::Cosine) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Flat::load(storage, options, vectors, save).map(Self::FlatCosine)?) + } + (O::Flat(_), Distance::Dot) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Flat::load(storage, options, vectors, save).map(Self::FlatDot)?) + } + (O::FlatQ(_), Distance::L2) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(FlatQ::load(storage, options, vectors, save).map(Self::FlatPqL2)?) + } + (O::FlatQ(_), Distance::Cosine) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(FlatQ::load(storage, options, vectors, save).map(Self::FlatPqCosine)?) + } + (O::FlatQ(_), Distance::Dot) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(FlatQ::load(storage, options, vectors, save).map(Self::FlatPqDot)?) + } + (O::Ivf(_), Distance::L2) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Ivf::load(storage, options, vectors, save).map(Self::IvfL2)?) 
+ } + (O::Ivf(_), Distance::Cosine) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Ivf::load(storage, options, vectors, save).map(Self::IvfCosine)?) + } + (O::Ivf(_), Distance::Dot) => { + let save = bincode::deserialize(&save).expect("Failed to deserialize."); + Ok(Ivf::load(storage, options, vectors, save).map(Self::IvfDot)?) + } } } - pub fn insert(&self, insert: usize) -> anyhow::Result<()> { - use DynAlgorithm::*; + pub fn insert(&self, insert: usize) -> Result<(), AlgorithmError> { + use Algorithm::*; match self { - Hnsw(sel) => sel.insert(insert), - Flat(sel) => sel.insert(insert), - Ivf(sel) => sel.insert(insert), + HnswL2(sel) => Ok(sel.insert(insert)?), + HnswCosine(sel) => Ok(sel.insert(insert)?), + HnswDot(sel) => Ok(sel.insert(insert)?), + FlatL2(sel) => Ok(sel.insert(insert)?), + FlatCosine(sel) => Ok(sel.insert(insert)?), + FlatDot(sel) => Ok(sel.insert(insert)?), + FlatPqL2(sel) => Ok(sel.insert(insert)?), + FlatPqCosine(sel) => Ok(sel.insert(insert)?), + FlatPqDot(sel) => Ok(sel.insert(insert)?), + IvfL2(sel) => Ok(sel.insert(insert)?), + IvfCosine(sel) => Ok(sel.insert(insert)?), + IvfDot(sel) => Ok(sel.insert(insert)?), } } - pub fn search(&self, search: (Box<[Scalar]>, usize)) -> anyhow::Result> { - use DynAlgorithm::*; + pub fn search( + &self, + target: Box<[Scalar]>, + k: usize, + filter: F, + ) -> Result, AlgorithmError> + where + F: FnMut(u64) -> bool, + { + use Algorithm::*; match self { - Hnsw(sel) => sel.search(search), - Flat(sel) => sel.search(search), - Ivf(sel) => sel.search(search), + HnswL2(sel) => Ok(sel.search(target, k, filter)?), + HnswCosine(sel) => Ok(sel.search(target, k, filter)?), + HnswDot(sel) => Ok(sel.search(target, k, filter)?), + FlatL2(sel) => Ok(sel.search(target, k, filter)?), + FlatCosine(sel) => Ok(sel.search(target, k, filter)?), + FlatDot(sel) => Ok(sel.search(target, k, filter)?), + FlatPqL2(sel) => Ok(sel.search(target, k, filter)?), + FlatPqCosine(sel) => 
Ok(sel.search(target, k, filter)?), + FlatPqDot(sel) => Ok(sel.search(target, k, filter)?), + IvfL2(sel) => Ok(sel.search(target, k, filter)?), + IvfCosine(sel) => Ok(sel.search(target, k, filter)?), + IvfDot(sel) => Ok(sel.search(target, k, filter)?), } } } + +pub trait Algo: Sized { + type Error: std::error::Error + serde::Serialize + for<'a> serde::Deserialize<'a>; + type Save: serde::Serialize + for<'a> serde::Deserialize<'a>; + fn prebuild( + storage: &mut StoragePreallocator, + options: IndexOptions, + ) -> Result<(), Self::Error>; + fn build( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + n: usize, + ) -> Result; + fn save(&self) -> Self::Save; + fn load( + storage: &mut Storage, + options: IndexOptions, + vectors: Arc, + save: Self::Save, + ) -> Result; + fn insert(&self, i: usize) -> Result<(), Self::Error>; + fn search( + &self, + target: Box<[Scalar]>, + k: usize, + filter: F, + ) -> Result, Self::Error> + where + F: FnMut(u64) -> bool; +} diff --git a/src/algorithms/utils/filtered_fixed_heap.rs b/src/algorithms/utils/filtered_fixed_heap.rs new file mode 100644 index 000000000..f10e58cd8 --- /dev/null +++ b/src/algorithms/utils/filtered_fixed_heap.rs @@ -0,0 +1,48 @@ +use crate::prelude::*; +use std::collections::BinaryHeap; + +type T = (Scalar, u64); + +#[derive(Debug, Clone)] +pub struct FilteredFixedHeap { + size: usize, + heap: BinaryHeap, + f: F, +} + +impl FilteredFixedHeap +where + F: FnMut(u64) -> bool, +{ + pub fn new(size: usize, f: F) -> Self { + Self { + size, + heap: BinaryHeap::::with_capacity(size), + f, + } + } + pub fn push(&mut self, item: T) { + if self.heap.len() < self.size { + if (self.f)(item.1) { + self.heap.push(item); + } + } else if self.heap.peek().unwrap() > &item { + if (self.f)(item.1) { + self.heap.pop(); + self.heap.push(item); + } + } + } + pub fn bound(&mut self) -> Scalar { + if self.heap.len() < self.size { + Scalar::INFINITY + } else { + self.heap.peek().unwrap().0 + } + } + pub fn 
into_sorted_vec(self) -> Vec { + let mut vec = self.heap.into_vec(); + vec.sort(); + vec + } +} diff --git a/src/utils/fixed_heap.rs b/src/algorithms/utils/fixed_heap.rs similarity index 80% rename from src/utils/fixed_heap.rs rename to src/algorithms/utils/fixed_heap.rs index 449af79d1..485108223 100644 --- a/src/utils/fixed_heap.rs +++ b/src/algorithms/utils/fixed_heap.rs @@ -1,5 +1,6 @@ use std::collections::BinaryHeap; +#[derive(Debug, Clone)] pub struct FixedHeap { size: usize, heap: BinaryHeap, @@ -21,9 +22,4 @@ impl FixedHeap { pub fn into_vec(self) -> Vec { self.heap.into_vec() } - pub fn into_sorted_vec(self) -> Vec { - let mut vec = self.heap.into_vec(); - vec.sort(); - vec - } } diff --git a/src/algorithms/utils/mmap_vec2.rs b/src/algorithms/utils/mmap_vec2.rs new file mode 100644 index 000000000..995783bb8 --- /dev/null +++ b/src/algorithms/utils/mmap_vec2.rs @@ -0,0 +1,60 @@ +use crate::bgworker::storage::{Storage, StoragePreallocator}; +use crate::bgworker::storage_mmap::MmapBox; +use crate::prelude::*; +use std::ops::{Deref, DerefMut, Index, IndexMut}; + +#[derive(Debug)] +pub struct MmapVec2 { + dims: u16, + v: MmapBox<[Scalar]>, +} + +impl MmapVec2 { + pub fn prebuild(storage: &mut StoragePreallocator, dims: u16, n: usize) { + storage.palloc_mmap_slice::(Memmap::Ram, dims as usize * n); + } + pub fn build(storage: &mut Storage, dims: u16, n: usize) -> Self { + let v = unsafe { + storage + .alloc_mmap_slice(Memmap::Ram, dims as usize * n) + .assume_init() + }; + Self { dims, v } + } + pub fn load(storage: &mut Storage, dims: u16, n: usize) -> Self { + let v = unsafe { + storage + .alloc_mmap_slice(Memmap::Ram, dims as usize * n) + .assume_init() + }; + Self { dims, v } + } +} + +impl Index for MmapVec2 { + type Output = [Scalar]; + + fn index(&self, index: usize) -> &Self::Output { + &self.v[self.dims as usize * index..][..self.dims as usize] + } +} + +impl IndexMut for MmapVec2 { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut 
self.v[self.dims as usize * index..][..self.dims as usize] + } +} + +impl Deref for MmapVec2 { + type Target = [Scalar]; + + fn deref(&self) -> &Self::Target { + self.v.deref() + } +} + +impl DerefMut for MmapVec2 { + fn deref_mut(&mut self) -> &mut Self::Target { + self.v.deref_mut() + } +} diff --git a/src/algorithms/utils/mod.rs b/src/algorithms/utils/mod.rs new file mode 100644 index 000000000..6805353ff --- /dev/null +++ b/src/algorithms/utils/mod.rs @@ -0,0 +1,5 @@ +pub mod filtered_fixed_heap; +pub mod fixed_heap; +pub mod mmap_vec2; +pub mod semaphore; +pub mod vec2; diff --git a/src/utils/semaphore.rs b/src/algorithms/utils/semaphore.rs similarity index 100% rename from src/utils/semaphore.rs rename to src/algorithms/utils/semaphore.rs diff --git a/src/utils/vec2.rs b/src/algorithms/utils/vec2.rs similarity index 87% rename from src/utils/vec2.rs rename to src/algorithms/utils/vec2.rs index dabcd1496..d2bdd828c 100644 --- a/src/utils/vec2.rs +++ b/src/algorithms/utils/vec2.rs @@ -1,7 +1,7 @@ use crate::prelude::*; use std::ops::{Deref, DerefMut, Index, IndexMut}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Vec2 { dims: u16, v: Box<[Scalar]>, @@ -14,7 +14,6 @@ impl Vec2 { v: unsafe { Box::new_zeroed_slice(dims as usize * n).assume_init() }, } } - #[allow(dead_code)] pub fn dims(&self) -> u16 { self.dims } @@ -31,11 +30,6 @@ impl Vec2 { } } } - #[allow(dead_code)] - pub fn iter(&self) -> impl Iterator { - let n = self.len(); - (0..n).map(|i| self.index(i)) - } } impl Index for Vec2 { diff --git a/src/algorithms/vectors.rs b/src/algorithms/vectors.rs deleted file mode 100644 index 2ab83f3e5..000000000 --- a/src/algorithms/vectors.rs +++ /dev/null @@ -1,116 +0,0 @@ -use crate::memory::Address; -use crate::memory::PBox; -use crate::memory::Persistent; -use crate::memory::Ptr; -use crate::prelude::*; -use std::cell::UnsafeCell; -use std::mem::MaybeUninit; -use std::sync::atomic::AtomicUsize; -use 
std::sync::atomic::Ordering; - -type Boxed = PBox<[UnsafeCell>]>; - -pub struct Root { - len: AtomicUsize, - inflight: AtomicUsize, - // ---------------------- - data: Boxed, - vector: Boxed, -} - -static_assertions::assert_impl_all!(Root: Persistent); - -pub struct Vectors { - address: Address, - root: &'static Root, - dims: u16, - capacity: usize, -} - -impl Vectors { - pub fn build(options: Options) -> anyhow::Result { - let storage = options.storage_vectors; - let ptr = PBox::new( - unsafe { - let len_data = options.capacity; - let len_vector = options.capacity * options.dims as usize; - Root { - len: AtomicUsize::new(0), - inflight: AtomicUsize::new(0), - data: PBox::new_uninit_slice(len_data, storage)?.assume_init(), - vector: PBox::new_uninit_slice(len_vector, storage)?.assume_init(), - } - }, - storage, - )? - .into_raw(); - let root = unsafe { ptr.as_ref() }; - let address = ptr.address(); - Ok(Self { - root, - dims: options.dims, - capacity: options.capacity, - address, - }) - } - pub fn address(&self) -> Address { - self.address - } - pub fn load(options: Options, address: Address) -> anyhow::Result { - Ok(Self { - root: unsafe { Ptr::new(address, ()).as_ref() }, - capacity: options.capacity, - dims: options.dims, - address, - }) - } - pub fn put(&self, data: u64, vector: &[Scalar]) -> anyhow::Result { - // If the index is approaching to `usize::MAX`, it will break. But it will not likely happen. 
- let i = self.root.inflight.fetch_add(1, Ordering::AcqRel); - if i >= self.capacity { - self.root.inflight.store(self.capacity, Ordering::Release); - anyhow::bail!("Full."); - } - unsafe { - let uninit_data = &mut *self.root.data[i].get(); - MaybeUninit::write(uninit_data, data); - let uninit_vector = - assume_mutable(&self.root.vector[i * self.dims as usize..][..self.dims as usize]); - uninit_vector.copy_from_slice(std::slice::from_raw_parts( - vector.as_ptr() as *const MaybeUninit, - vector.len(), - )); - } - while self - .root - .len - .compare_exchange_weak(i, i + 1, Ordering::AcqRel, Ordering::Relaxed) - .is_err() - { - std::hint::spin_loop(); - } - Ok(i) - } - pub fn len(&self) -> usize { - self.root.len.load(Ordering::Acquire) - } - pub fn get_data(&self, i: usize) -> u64 { - unsafe { (*self.root.data[i].get()).assume_init_read() } - } - pub fn get_vector(&self, i: usize) -> &[Scalar] { - unsafe { - assume_immutable_init(&self.root.vector[i * self.dims as usize..][..self.dims as usize]) - } - } -} - -#[allow(clippy::mut_from_ref)] -unsafe fn assume_mutable(slice: &[UnsafeCell]) -> &mut [T] { - let p = slice.as_ptr().cast::>() as *mut T; - std::slice::from_raw_parts_mut(p, slice.len()) -} - -unsafe fn assume_immutable_init(slice: &[UnsafeCell>]) -> &[T] { - let p = slice.as_ptr().cast::>() as *const T; - std::slice::from_raw_parts(p, slice.len()) -} diff --git a/src/bgworker/filter_delete.rs b/src/bgworker/filter_delete.rs new file mode 100644 index 000000000..58a87c005 --- /dev/null +++ b/src/bgworker/filter_delete.rs @@ -0,0 +1,68 @@ +use crate::prelude::*; +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; + +pub struct FilterDelete { + data: DashMap, +} + +impl FilterDelete { + pub fn new() -> Self { + Self { + data: DashMap::new(), + } + } + pub fn filter(&self, x: u64) -> Option { + let p = Pointer::from_u48(x >> 16); + let version = x as u16; + if let Some(cell) = self.data.get(&p) { + let (current_version, current_existence) = cell.value(); + 
if version < *current_version { + None + } else { + debug_assert!(version == *current_version); + debug_assert!(*current_existence); + Some(p) + } + } else { + debug_assert!(version == 0); + Some(p) + } + } + pub fn on_deleting(&self, p: Pointer) -> bool { + match self.data.entry(p) { + Entry::Occupied(mut entry) => { + let (current_version, current_existence) = entry.get_mut(); + if *current_existence { + *current_version = *current_version + 1; + *current_existence = false; + true + } else { + false + } + } + Entry::Vacant(entry) => { + let current_version = 1u16; + let current_existence = false; + entry.insert((current_version, current_existence)); + true + } + } + } + pub fn on_inserting(&self, p: Pointer) -> u64 { + match self.data.entry(p) { + Entry::Occupied(mut entry) => { + let (current_version, current_existence) = entry.get_mut(); + debug_assert!(*current_existence == false); + *current_existence = true; + p.as_u48() << 16 | *current_version as u64 + } + Entry::Vacant(entry) => { + let current_version = 0u16; + let current_existence = true; + entry.insert((current_version, current_existence)); + p.as_u48() << 16 | current_version as u64 + } + } + } +} diff --git a/src/bgworker/index.rs b/src/bgworker/index.rs index 4d5f566f7..2de808f26 100644 --- a/src/bgworker/index.rs +++ b/src/bgworker/index.rs @@ -1,292 +1,242 @@ -use super::wal::WalSync; +use super::filter_delete::FilterDelete; +use super::storage::Storage; +use super::storage::StoragePreallocator; +use super::vectors::Vectors; +use super::wal::Wal; use super::wal::WalWriter; -use crate::algorithms::DynAlgorithm; -use crate::algorithms::Vectors; -use crate::memory::given; -use crate::memory::Address; -use crate::memory::Context; -use crate::memory::ContextOptions; +use crate::algorithms::Algorithm; +use crate::algorithms::AlgorithmError; +use crate::algorithms::AlgorithmOptions; +use crate::bgworker::vectors::VectorsOptions; +use crate::ipc::server::Build; +use crate::ipc::server::Search; +use 
crate::ipc::ServerIpcError; use crate::prelude::*; -use dashmap::DashMap; use serde::{Deserialize, Serialize}; +use std::io::ErrorKind; use std::path::Path; -use std::ptr::NonNull; use std::sync::Arc; -use tokio::io::ErrorKind; -use tokio_stream::StreamExt; +use thiserror::Error; +use validator::Validate; + +#[derive(Debug, Clone, Serialize, Deserialize, Error)] +pub enum IndexError { + #[error("Algorithm {0}")] + Algorithm(#[from] AlgorithmError), + #[error("Ipc {0}")] + Ipc(#[from] ServerIpcError), +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +pub struct IndexOptions { + #[validate(range(min = 1))] + pub dims: u16, + pub distance: Distance, + #[validate(range(min = 1))] + pub capacity: usize, + pub vectors: VectorsOptions, + pub algorithm: AlgorithmOptions, +} pub struct Index { #[allow(dead_code)] id: Id, #[allow(dead_code)] - options: Options, + options: IndexOptions, vectors: Arc, - algo: DynAlgorithm, - version: IndexVersion, + algo: Algorithm, + filter_delete: FilterDelete, wal: WalWriter, #[allow(dead_code)] - context: Arc, + storage: Storage, } impl Index { - pub async fn drop(id: Id) -> anyhow::Result<()> { - use tokio_stream::wrappers::ReadDirStream; - let mut stream = ReadDirStream::new(tokio::fs::read_dir(".").await?); - while let Some(f) = stream.next().await { - let filename = f? 
- .file_name() - .into_string() - .map_err(|_| anyhow::anyhow!("Bad filename."))?; - if filename.starts_with(&format!("{}_", id.as_u32())) { - remove_file_if_exists(filename).await?; + pub fn clean(id: Id) { + for f in std::fs::read_dir(".").expect("Failed to clean.") { + let f = f.unwrap(); + if let Some(filename) = f.file_name().to_str() { + if filename.starts_with(&format!("{}_", id.as_u32())) { + remove_file_if_exists(filename).expect("Failed to delete."); + } } } - Ok(()) } - pub async fn build( + pub fn prebuild(options: IndexOptions) -> Result { + let mut storage = StoragePreallocator::new(); + Vectors::prebuild(&mut storage, options.clone()); + Algorithm::prebuild(&mut storage, options.clone())?; + Ok(storage) + } + pub fn build( id: Id, - options: Options, - data: async_channel::Receiver<(Box<[Scalar]>, Pointer)>, - ) -> anyhow::Result { - Self::drop(id).await?; - tokio::task::block_in_place(|| -> anyhow::Result<_> { - let context = Context::build(ContextOptions { - block_ram: (options.size_ram, format!("{}_data_ram", id.as_u32())), - block_disk: (options.size_disk, format!("{}_data_disk", id.as_u32())), - })?; - let _given = unsafe { given(NonNull::new_unchecked(Arc::as_ptr(&context).cast_mut())) }; - let vectors = Arc::new(Vectors::build(options.clone())?); - while let Ok((vector, p)) = data.recv_blocking() { - let data = p.as_u48() << 16; - vectors.put(data, &vector)?; - } - let algo = DynAlgorithm::build(options.clone(), vectors.clone(), vectors.len())?; - context.persist()?; - let version = IndexVersion::new(); - let wal = { - let path_wal = format!("{}_wal", id.as_u32()); - let mut wal = WalSync::create(path_wal)?; - let log = LogMeta { - options: options.clone(), - address_algorithm: algo.address(), - address_vectors: vectors.address(), - }; - wal.write(&log.bincode()?)?; - WalWriter::spawn(wal.into_async())? 
+ options: IndexOptions, + server_build: &mut Build, + ) -> Result { + Self::clean(id); + let storage_preallocator = Self::prebuild(options.clone())?; + let mut storage = Storage::build(id, storage_preallocator); + let vectors = Arc::new(Vectors::build(&mut storage, options.clone())); + while let Some((vector, p)) = server_build.next().expect("IPC error.") { + let data = p.as_u48() << 16; + vectors.put(data, &vector); + } + let algo = Algorithm::build( + &mut storage, + options.clone(), + vectors.clone(), + vectors.len(), + )?; + storage.persist(); + let filter_delete = FilterDelete::new(); + let wal = { + let path_wal = format!("{}_wal", id.as_u32()); + let mut wal = Wal::create(path_wal); + let log = LogFirst { + options: options.clone(), + save_algorithm: algo.save(), }; - Ok(Self { - id, - options, - vectors, - algo, - version, - wal, - context, - }) + wal.write(&log.bincode()); + wal + }; + Ok(Self { + id, + options, + vectors, + algo, + filter_delete, + wal: WalWriter::spawn(wal), + storage, }) } - pub async fn load(id: Id) -> anyhow::Result { - tokio::task::block_in_place(|| { - let mut wal = WalSync::open(format!("{}_wal", id.as_u32()))?; - let LogMeta { - options, - address_vectors, - address_algorithm, - } = wal - .read()? - .ok_or(anyhow::anyhow!("The index is broken."))? - .deserialize::()?; - let context = Context::load(ContextOptions { - block_ram: (options.size_ram, format!("{}_data_ram", id.as_u32())), - block_disk: (options.size_disk, format!("{}_data_disk", id.as_u32())), - })?; - let _given = unsafe { given(NonNull::new_unchecked(Arc::as_ptr(&context).cast_mut())) }; - let vectors = Arc::new(Vectors::load(options.clone(), address_vectors)?); - let algo = DynAlgorithm::load(options.clone(), vectors.clone(), address_algorithm)?; - let version = IndexVersion::new(); - loop { - let Some(replay) = wal.read()? else { break }; - match replay.deserialize::()? 
{ - LogReplay::Insert { vector, p } => { - let data = version.insert(p); - let index = vectors.put(data, &vector)?; - algo.insert(index)?; - } - LogReplay::Delete { p } => { - version.remove(p); - } + pub fn load(id: Id) -> Self { + let mut storage = Storage::load(id); + let mut wal = Wal::open(format!("{}_wal", id.as_u32())); + let LogFirst { + options, + save_algorithm, + } = wal + .read() + .expect("The index is broken.") + .deserialize::(); + let vectors = Arc::new(Vectors::load(&mut storage, options.clone())); + let algo = Algorithm::load( + &mut storage, + options.clone(), + vectors.clone(), + save_algorithm, + ) + .expect("Failed to load the algorithm."); + let filter_delete = FilterDelete::new(); + loop { + let Some(replay) = wal.read() else { break }; + match replay.deserialize::() { + LogFollowing::Insert { vector, p } => { + let data = filter_delete.on_inserting(p); + let index = vectors.put(data, &vector); + algo.insert(index).expect("Failed to reinsert."); + } + LogFollowing::Delete { p } => { + filter_delete.on_deleting(p); } } - wal.truncate()?; - wal.flush()?; - let wal = WalWriter::spawn(wal.into_async())?; - Ok(Self { - id, - options, - algo, - version, - wal, - vectors, - context, - }) - }) + } + wal.truncate(); + wal.flush(); + Self { + id, + options, + algo, + filter_delete, + wal: WalWriter::spawn(wal), + vectors, + storage, + } } - pub async fn insert(&self, (vector, p): (Box<[Scalar]>, Pointer)) -> anyhow::Result<()> { - tokio::task::block_in_place(|| -> anyhow::Result<()> { - let _given = unsafe { - given(NonNull::new_unchecked( - Arc::as_ptr(&self.context).cast_mut(), - )) - }; - let data = self.version.insert(p); - let index = self.vectors.put(data, &vector)?; - self.algo.insert(index)?; - anyhow::Result::Ok(()) - })?; - let bytes = LogReplay::Insert { vector, p }.bincode()?; - self.wal.write(bytes).await?; + pub fn insert(&self, (vector, p): (Box<[Scalar]>, Pointer)) -> Result<(), IndexError> { + let data = 
self.filter_delete.on_inserting(p); + let index = self.vectors.put(data, &vector); + self.algo.insert(index)?; + let bytes = LogFollowing::Insert { vector, p }.bincode(); + self.wal.write(bytes); Ok(()) } - pub async fn delete(&self, delete: Pointer) -> anyhow::Result<()> { - self.version.remove(delete); - let bytes = LogReplay::Delete { p: delete }.bincode()?; - self.wal.write(bytes).await?; + pub fn delete(&self, delete: Pointer) -> Result<(), IndexError> { + self.filter_delete.on_deleting(delete); + let bytes = LogFollowing::Delete { p: delete }.bincode(); + self.wal.write(bytes); Ok(()) } - pub async fn search(&self, search: (Box<[Scalar]>, usize)) -> anyhow::Result> { - let result = tokio::task::block_in_place(|| -> anyhow::Result<_> { - let _given = unsafe { - given(NonNull::new_unchecked( - Arc::as_ptr(&self.context).cast_mut(), - )) - }; - let result = self.algo.search(search)?; - let result = result - .into_iter() - .filter_map(|(_, x)| self.version.filter(x)) - .collect(); - Ok(result) - })?; + pub fn search( + &self, + target: Box<[Scalar]>, + k: usize, + server_search: &mut Search, + ) -> Result, IndexError> { + let filter = |p| { + if let Some(p) = self.filter_delete.filter(p) { + server_search.check(p).expect("IPC error.") + } else { + false + } + }; + let result = self.algo.search(target, k, filter)?; + let result = result + .into_iter() + .filter_map(|(_, x)| self.filter_delete.filter(x)) + .collect(); Ok(result) } - pub async fn flush(&self) -> anyhow::Result<()> { - self.wal.flush().await?; - Ok(()) + pub fn flush(&self) { + self.wal.flush(); } - pub async fn shutdown(self) -> anyhow::Result<()> { - self.wal.shutdown().await?; - Ok(()) - } -} - -struct IndexVersion { - data: DashMap, -} - -impl IndexVersion { - pub fn new() -> Self { - Self { - data: DashMap::new(), - } - } - pub fn filter(&self, x: u64) -> Option { - let p = Pointer::from_u48(x >> 16); - let v = x as u16; - if let Some(guard) = self.data.get(&p) { - let (cv, cve) = guard.value(); 
- debug_assert!(v < *cv || (v == *cv && *cve)); - if v == *cv { - Some(p) - } else { - None - } - } else { - debug_assert!(v == 0); - Some(p) - } - } - pub fn insert(&self, p: Pointer) -> u64 { - if let Some(mut guard) = self.data.get_mut(&p) { - let (cv, cve) = guard.value_mut(); - debug_assert!(*cve == false); - *cve = true; - p.as_u48() << 16 | *cv as u64 - } else { - self.data.insert(p, (0, true)); - p.as_u48() << 16 | 0 - } - } - pub fn remove(&self, p: Pointer) { - if let Some(mut guard) = self.data.get_mut(&p) { - let (cv, cve) = guard.value_mut(); - if *cve == true { - *cv = *cv + 1; - *cve = false; - } - } else { - self.data.insert(p, (1, false)); - } + pub fn shutdown(&mut self) { + self.wal.shutdown(); } } #[derive(Serialize, Deserialize, Debug, Clone)] -struct LogMeta { - options: Options, - address_vectors: Address, - address_algorithm: Address, +struct LogFirst { + options: IndexOptions, + save_algorithm: Vec, } #[derive(Serialize, Deserialize, Debug, Clone)] -enum LogReplay { +enum LogFollowing { Insert { vector: Box<[Scalar]>, p: Pointer }, Delete { p: Pointer }, } -pub struct Load { - inner: Option, +fn remove_file_if_exists(path: impl AsRef) -> std::io::Result<()> { + match std::fs::remove_file(path) { + Ok(()) => Ok(()), + Err(e) if e.kind() == ErrorKind::NotFound => Ok(()), + Err(e) => Err(e), + } } -impl Load { - pub fn new() -> Self { - Self { inner: None } - } - pub fn get(&self) -> anyhow::Result<&T> { - self.inner - .as_ref() - .ok_or(anyhow::anyhow!("The index is not loaded.")) - } - #[allow(dead_code)] - pub fn get_mut(&mut self) -> anyhow::Result<&mut T> { - self.inner - .as_mut() - .ok_or(anyhow::anyhow!("The index is not loaded.")) - } - pub fn load(&mut self, x: T) { - assert!(self.inner.is_none()); - self.inner = Some(x); - } - pub fn unload(&mut self) -> T { - assert!(self.inner.is_some()); - self.inner.take().unwrap() - } - pub fn is_loaded(&self) -> bool { - self.inner.is_some() - } - pub fn is_unloaded(&self) -> bool { - 
self.inner.is_none() +trait BincodeDeserialize { + fn deserialize<'a, T: Deserialize<'a>>(&'a self) -> T; +} + +impl BincodeDeserialize for [u8] { + fn deserialize<'a, T: Deserialize<'a>>(&'a self) -> T { + bincode::deserialize::(self).expect("Failed to deserialize.") } } -async fn remove_file_if_exists(path: impl AsRef) -> std::io::Result<()> { - match tokio::fs::remove_file(path).await { - Ok(()) => Ok(()), - Err(e) if e.kind() == ErrorKind::NotFound => Ok(()), - Err(e) => Err(e), +trait Bincode: Sized { + fn bincode(&self) -> Vec; +} + +impl Bincode for T { + fn bincode(&self) -> Vec { + bincode::serialize(self).expect("Failed to serialize.") } } diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index cf53677a7..99cb5b5df 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -1,85 +1,23 @@ -mod index; -mod session; -mod wal; +pub mod filter_delete; +pub mod index; +pub mod storage; +pub mod storage_mmap; +pub mod vectors; +pub mod wal; -pub use session::Client; -pub use session::ClientBuild; - -use self::index::Load; -use crate::postgres::PORT; +use self::index::IndexError; +use crate::ipc::server::RpcHandler; +use crate::ipc::ServerIpcError; use crate::prelude::*; use dashmap::DashMap; use index::Index; use std::fs::OpenOptions; use std::mem::MaybeUninit; -use tokio::sync::RwLock; - -struct Global { - indexes: DashMap>>, -} - -static mut GLOBAL: MaybeUninit = MaybeUninit::uninit(); +use thiserror::Error; #[no_mangle] extern "C" fn vectors_main(_arg: pgrx::pg_sys::Datum) -> ! 
{ - match std::panic::catch_unwind(|| { - std::fs::create_dir_all("pg_vectors").expect("Failed to create the directory."); - std::env::set_current_dir("pg_vectors").expect("Failed to set the current variable."); - unsafe { - let global = Global { - indexes: DashMap::new(), - }; - (GLOBAL.as_ptr() as *mut Global).write(global); - } - let logging = OpenOptions::new() - .create(true) - .append(true) - .open("_log") - .expect("The logging file is failed to open."); - env_logger::builder() - .target(env_logger::Target::Pipe(Box::new(logging))) - .init(); - std::panic::set_hook(Box::new(|info| { - let backtrace = std::backtrace::Backtrace::capture(); - log::error!("Process panickied. {:?}. Backtrace. {}.", info, backtrace); - })); - let runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .expect("The tokio runtime is failed to build."); - let listener = runtime - .block_on(async { tokio::net::TcpListener::bind(("0.0.0.0", PORT.get() as u16)).await }) - .expect("The listening port is failed to bind."); - runtime.spawn(async move { - while let Ok((stream, _)) = listener.accept().await { - tokio::task::spawn(async move { - if let Err(e) = session::server_main(stream).await { - log::error!("Session panickied. {}. {}", e, e.backtrace()); - } - }); - } - }); - loop { - let mut sig: i32 = 0; - unsafe { - let mut set: libc::sigset_t = std::mem::zeroed(); - libc::sigemptyset(&mut set); - libc::sigaddset(&mut set, libc::SIGHUP); - libc::sigaddset(&mut set, libc::SIGTERM); - libc::sigwait(&set, &mut sig); - } - match sig { - libc::SIGHUP => { - std::process::exit(0); - } - libc::SIGTERM => { - std::process::exit(0); - } - _ => (), - } - std::thread::yield_now(); - } - }) { + match std::panic::catch_unwind(thread_main) { Ok(never) => never, Err(_) => { log::error!("The background process crashed."); @@ -88,18 +26,154 @@ extern "C" fn vectors_main(_arg: pgrx::pg_sys::Datum) -> ! 
{ } } -fn global() -> &'static Global { - unsafe { GLOBAL.assume_init_ref() } +fn thread_main() -> ! { + std::fs::create_dir_all("pg_vectors").expect("Failed to create the directory."); + std::env::set_current_dir("pg_vectors").expect("Failed to set the current variable."); + unsafe { + INDEXES.as_mut_ptr().write(DashMap::new()); + } + let logging = OpenOptions::new() + .create(true) + .append(true) + .open("_log") + .expect("The logging file is failed to open."); + env_logger::builder() + .target(env_logger::Target::Pipe(Box::new(logging))) + .init(); + std::panic::set_hook(Box::new(|info| { + let backtrace = std::backtrace::Backtrace::capture(); + log::error!("Process panickied. {:?}. Backtrace. {}.", info, backtrace); + })); + std::thread::spawn(|| thread_listening()); + loop { + let mut sig: i32 = 0; + unsafe { + let mut set: libc::sigset_t = std::mem::zeroed(); + libc::sigemptyset(&mut set); + libc::sigaddset(&mut set, libc::SIGHUP); + libc::sigaddset(&mut set, libc::SIGTERM); + libc::sigwait(&set, &mut sig); + } + match sig { + libc::SIGHUP => { + std::process::exit(0); + } + libc::SIGTERM => { + std::process::exit(0); + } + _ => (), + } + std::thread::yield_now(); + } +} + +static mut INDEXES: MaybeUninit> = MaybeUninit::uninit(); + +fn thread_listening() { + let listener = crate::ipc::listen(); + for rpc_handler in listener { + std::thread::spawn(move || { + if let Err(e) = thread_session(rpc_handler) { + log::error!("Session exited. 
{}.", e); + } + }); + } +} + +#[derive(Debug, Clone, Error)] +pub enum SessionError { + #[error("Ipc")] + Ipc(#[from] ServerIpcError), + #[error("Index")] + Index(#[from] IndexError), } -async fn find_index(id: Id) -> anyhow::Result<&'static RwLock>> { - use dashmap::mapref::entry::Entry; - match global().indexes.try_entry(id).unwrap() { - Entry::Occupied(x) => Ok(x.get()), - Entry::Vacant(x) => { - let reference = Box::leak(Box::new(RwLock::new(Load::new()))); - x.insert(reference); - Ok(reference) +fn thread_session(mut rpc_handler: RpcHandler) -> Result<(), SessionError> { + use crate::ipc::server::RpcHandle; + loop { + match rpc_handler.handle()? { + RpcHandle::Build { id, options, mut x } => { + use dashmap::mapref::entry::Entry; + let indexes = unsafe { INDEXES.assume_init_ref() }; + match indexes.entry(id) { + Entry::Occupied(entry) => entry.into_ref(), + Entry::Vacant(entry) => { + let index = Index::build(id, options, &mut x)?; + entry.insert(index) + } + }; + rpc_handler = x.leave()?; + } + RpcHandle::Insert { id, insert, x } => { + let indexes = unsafe { INDEXES.assume_init_ref() }; + let index = indexes.get(&id).expect("Not load."); + index.insert(insert)?; + rpc_handler = x.leave()?; + } + RpcHandle::Delete { id, delete, x } => { + let indexes = unsafe { INDEXES.assume_init_ref() }; + let index = indexes.get(&id).expect("Not load."); + index.delete(delete)?; + rpc_handler = x.leave()?; + } + RpcHandle::Search { + id, + target, + k, + mut x, + } => { + let indexes = unsafe { INDEXES.assume_init_ref() }; + let index = indexes.get(&id).expect("Not load."); + let result = index.search(target, k, &mut x)?; + rpc_handler = x.leave(result)?; + } + RpcHandle::Load { id, x } => { + use dashmap::mapref::entry::Entry; + let indexes: &DashMap = unsafe { INDEXES.assume_init_ref() }; + match indexes.entry(id) { + Entry::Occupied(entry) => entry.into_ref(), + Entry::Vacant(entry) => { + let index = Index::load(id); + entry.insert(index) + } + }; + rpc_handler = 
x.leave()?; + } + RpcHandle::Unload { id, x } => { + use dashmap::mapref::entry::Entry; + let indexes: &DashMap = unsafe { INDEXES.assume_init_ref() }; + match indexes.entry(id) { + Entry::Occupied(mut entry) => { + entry.get_mut().shutdown(); + entry.remove(); + } + Entry::Vacant(_) => (), + }; + rpc_handler = x.leave()?; + } + RpcHandle::Flush { id, x } => { + let indexes = unsafe { INDEXES.assume_init_ref() }; + let index = indexes.get(&id).expect("Not load."); + index.flush(); + rpc_handler = x.leave()?; + } + RpcHandle::Clean { id, x } => { + use dashmap::mapref::entry::Entry; + let indexes: &DashMap = unsafe { INDEXES.assume_init_ref() }; + match indexes.entry(id) { + Entry::Occupied(mut entry) => { + entry.get_mut().shutdown(); + entry.remove(); + Index::clean(id); + } + Entry::Vacant(_entry) => { + Index::clean(id); + } + }; + rpc_handler = x.leave()?; + } + RpcHandle::Leave {} => break, } } + Ok(()) } diff --git a/src/bgworker/session.rs b/src/bgworker/session.rs deleted file mode 100644 index 812fcb7a3..000000000 --- a/src/bgworker/session.rs +++ /dev/null @@ -1,470 +0,0 @@ -use super::find_index; -use super::index::Index; -use crate::prelude::*; -use serde::{Deserialize, Serialize}; -use std::{ - io::ErrorKind, - time::{Duration, Instant}, -}; - -#[derive(Debug, Serialize, Deserialize)] -pub enum ClientPacket { - // requests - Build0 { - id: Id, - options: Options, - }, - Build1((Box<[Scalar]>, Pointer)), - Build2, - Load { - id: Id, - }, - Unload { - id: Id, - }, - Insert { - id: Id, - insert: (Box<[Scalar]>, Pointer), - }, - Delete { - id: Id, - delete: Pointer, - }, - Search { - id: Id, - search: (Box<[Scalar]>, usize), - }, - Flush { - id: Id, - }, - Drop { - id: Id, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum ServerPacket { - Reset(String), - // responses - Build {}, - Load {}, - Unload {}, - Insert {}, - Delete {}, - Search { result: Vec }, - Flush {}, - Drop {}, -} - -struct Server { - read: tokio::io::BufReader, - write: 
tokio::io::BufWriter, -} - -impl Server { - fn new(stream: tokio::net::TcpStream) -> Self { - let (read, write) = stream.into_split(); - let read = tokio::io::BufReader::new(read); - let write = tokio::io::BufWriter::new(write); - Self { read, write } - } - async fn recv(&mut self) -> anyhow::Result { - use tokio::io::AsyncReadExt; - let packet_size = self.read.read_u16().await?; - let mut buffer = vec![0u8; packet_size as usize]; - self.read.read_exact(&mut buffer).await?; - buffer.deserialize() - } - async fn send(&mut self, maybe: anyhow::Result) -> anyhow::Result<()> { - use tokio::io::AsyncWriteExt; - let packet = match maybe { - Ok(packet) => packet, - Err(e) => ServerPacket::Reset(e.to_string()), - }; - let packet = packet.bincode()?; - anyhow::ensure!(packet.len() <= u16::MAX as usize); - let packet_size = packet.len() as u16; - self.write.write_u16(packet_size).await?; - self.write.write_all(&packet).await?; - Ok(()) - } - async fn flush(&mut self) -> anyhow::Result<()> { - use tokio::io::AsyncWriteExt; - self.write.flush().await?; - Ok(()) - } -} - -pub async fn server_main(stream: tokio::net::TcpStream) -> anyhow::Result<()> { - let mut server = Server::new(stream); - loop { - let packet = server.recv().await?; - match packet { - ClientPacket::Build0 { id, options } => { - let (tx, rx) = async_channel::bounded(65536); - let maybe = { - let data = tokio::spawn(async move { - loop { - let packet = server.recv().await?; - match packet { - ClientPacket::Build1(data) => { - tx.send(data).await?; - } - ClientPacket::Build2 => { - drop(tx); - return anyhow::Result::Ok(server); - } - _ => anyhow::bail!("Bad state."), - } - } - }); - handler_build(id, options, rx).await?; - server = data.await??; - Ok(ServerPacket::Build {}) - }; - server.send(maybe).await?; - server.flush().await?; - } - ClientPacket::Load { id } => { - let maybe = async { - handler_load(id).await?; - Ok(ServerPacket::Load {}) - } - .await; - server.send(maybe).await?; - server.flush().await?; - 
} - ClientPacket::Unload { id } => { - let maybe = async { - handler_unload(id).await?; - Ok(ServerPacket::Unload {}) - } - .await; - server.send(maybe).await?; - server.flush().await?; - } - ClientPacket::Insert { id, insert } => { - let maybe = async { - handler_insert(id, insert).await?; - Ok(ServerPacket::Insert {}) - } - .await; - server.send(maybe).await?; - server.flush().await?; - } - ClientPacket::Delete { id, delete } => { - let maybe = async { - handler_delete(id, delete).await?; - Ok(ServerPacket::Delete {}) - } - .await; - server.send(maybe).await?; - server.flush().await?; - } - ClientPacket::Search { id, search } => { - let maybe = async { - let data = handler_search(id, search).await?; - Ok(ServerPacket::Search { result: data }) - } - .await; - server.send(maybe).await?; - server.flush().await?; - } - ClientPacket::Flush { id } => { - let maybe = async { - handler_flush(id).await?; - Ok(ServerPacket::Flush {}) - } - .await; - server.send(maybe).await?; - server.flush().await?; - } - ClientPacket::Drop { id } => { - let maybe = async { - handler_drop(id).await?; - Ok(ServerPacket::Drop {}) - } - .await; - server.send(maybe).await?; - server.flush().await?; - } - _ => anyhow::bail!("Bad state."), - } - } -} - -async fn handler_build( - id: Id, - options: Options, - data: async_channel::Receiver<(Box<[Scalar]>, Pointer)>, -) -> anyhow::Result<()> { - let index = find_index(id).await?; - let mut guard = index.write().await; - if guard.is_unloaded() { - guard.load(Index::build(id, options, data).await?); - } - Ok(()) -} - -async fn handler_load(id: Id) -> anyhow::Result<()> { - let index = find_index(id).await?; - let mut guard = index.write().await; - if guard.is_unloaded() { - guard.load(Index::load(id).await?); - } - Ok(()) -} - -async fn handler_unload(id: Id) -> anyhow::Result<()> { - let index = find_index(id).await?; - let mut guard = index.write().await; - if guard.is_loaded() { - guard.unload().shutdown().await?; - } - Ok(()) -} - -async fn 
handler_insert(id: Id, insert: (Box<[Scalar]>, Pointer)) -> anyhow::Result<()> { - let index = find_index(id).await?; - let index = index.read().await; - let index = index.get()?; - index.insert(insert).await?; - Ok(()) -} - -async fn handler_delete(id: Id, delete: Pointer) -> anyhow::Result<()> { - let index = find_index(id).await?; - let index = index.read().await; - let index = index.get()?; - index.delete(delete).await?; - Ok(()) -} - -async fn handler_search( - id: Id, - (x_vector, k): (Box<[Scalar]>, usize), -) -> anyhow::Result> { - let index = find_index(id).await?; - let index = index.read().await; - let index = index.get()?; - let data = index.search((x_vector, k)).await?; - Ok(data) -} - -async fn handler_flush(id: Id) -> anyhow::Result<()> { - let index = find_index(id).await?; - let index = index.read().await; - let index = index.get()?; - index.flush().await?; - Ok(()) -} - -async fn handler_drop(id: Id) -> anyhow::Result<()> { - let index = find_index(id).await?; - let mut guard = index.write().await; - if guard.is_loaded() { - let x = guard.unload(); - x.shutdown().await?; - Index::drop(id).await?; - } - Ok(()) -} - -pub struct Client { - read: std::io::BufReader, - write: std::io::BufWriter, -} - -impl Client { - pub fn new(tcp: std::net::TcpStream) -> anyhow::Result { - let read = std::io::BufReader::new(tcp.try_clone()?); - let write = std::io::BufWriter::new(tcp.try_clone()?); - Ok(Self { read, write }) - } - - fn _recv(&mut self) -> anyhow::Result { - use byteorder::BigEndian as E; - use byteorder::ReadBytesExt; - use std::io::Read; - let packet_size = self.read.read_u16::()?; - let mut buffer = vec![0u8; packet_size as usize]; - self.read.read_exact(&mut buffer)?; - buffer.deserialize() - } - - fn _send(&mut self, packet: ClientPacket) -> anyhow::Result<()> { - use byteorder::BigEndian as E; - use byteorder::WriteBytesExt; - use std::io::Write; - let packet = packet.bincode()?; - anyhow::ensure!(packet.len() <= u16::MAX as usize); - let 
packet_size = packet.len() as u16; - self.write.write_u16::(packet_size)?; - self.write.write_all(&packet)?; - Ok(()) - } - - fn _test(&mut self) -> anyhow::Result { - if !self.read.buffer().is_empty() { - return Ok(true); - } - unsafe { - use std::os::fd::AsRawFd; - let mut buf = [0u8]; - let result = libc::recv( - self.read.get_mut().as_raw_fd(), - buf.as_mut_ptr() as _, - 1, - libc::MSG_PEEK | libc::MSG_DONTWAIT, - ); - match result { - -1 => { - let err = std::io::Error::last_os_error(); - if err.kind() == ErrorKind::WouldBlock { - Ok(false) - } else { - Err(err.into()) - } - } - 0 => { - // TCP stream is closed. - Ok(false) - } - 1 => Ok(true), - _ => unreachable!(), - } - } - } - - fn _flush(&mut self) -> anyhow::Result<()> { - use std::io::Write; - self.write.flush()?; - Ok(()) - } - - pub fn build(&mut self, id: Id, options: Options) -> anyhow::Result { - self._send(ClientPacket::Build0 { id, options })?; - Ok(ClientBuild { - last: Instant::now(), - client: self, - }) - } - - pub fn insert(&mut self, id: Id, insert: (Box<[Scalar]>, Pointer)) -> anyhow::Result<()> { - self._send(ClientPacket::Insert { id, insert })?; - self._flush()?; - let packet = self._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - ServerPacket::Insert {} => Ok(()), - _ => anyhow::bail!("Bad state."), - } - } - - pub fn delete(&mut self, id: Id, delete: Pointer) -> anyhow::Result<()> { - self._send(ClientPacket::Delete { id, delete })?; - self._flush()?; - let packet = self._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - ServerPacket::Delete {} => Ok(()), - _ => anyhow::bail!("Bad state."), - } - } - - pub fn search( - &mut self, - id: Id, - search: (Box<[Scalar]>, usize), - ) -> anyhow::Result> { - self._send(ClientPacket::Search { id, search })?; - self._flush()?; - let packet = self._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - ServerPacket::Search { result } => Ok(result), - _ => anyhow::bail!("Bad 
state."), - } - } - - pub fn load(&mut self, id: Id) -> anyhow::Result<()> { - self._send(ClientPacket::Load { id })?; - self._flush()?; - let packet = self._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - ServerPacket::Load {} => Ok(()), - _ => anyhow::bail!("Bad state."), - } - } - - pub fn unload(&mut self, id: Id) -> anyhow::Result<()> { - self._send(ClientPacket::Unload { id })?; - self._flush()?; - let packet = self._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - ServerPacket::Unload {} => Ok(()), - _ => anyhow::bail!("Bad state."), - } - } - - pub fn flush(&mut self, id: Id) -> anyhow::Result<()> { - self._send(ClientPacket::Flush { id })?; - self._flush()?; - let packet = self._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - ServerPacket::Flush {} => Ok(()), - _ => anyhow::bail!("Bad state."), - } - } - - pub fn drop(&mut self, id: Id) -> anyhow::Result<()> { - self._send(ClientPacket::Drop { id })?; - self._flush()?; - let packet = self._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - ServerPacket::Drop {} => Ok(()), - _ => anyhow::bail!("Bad state."), - } - } -} - -pub struct ClientBuild<'a> { - last: Instant, - client: &'a mut Client, -} - -impl<'a> ClientBuild<'a> { - #[allow(clippy::never_loop)] - fn _process(&mut self) -> anyhow::Result<()> { - if self.last.elapsed() > Duration::from_millis(200) { - while self.client._test()? 
{ - let packet = self.client._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - _ => anyhow::bail!("Bad state."), - } - } - self.last = Instant::now(); - } - Ok(()) - } - pub fn next(&mut self, data: (Box<[Scalar]>, Pointer)) -> anyhow::Result<()> { - self._process()?; - self.client._send(ClientPacket::Build1(data))?; - Ok(()) - } - pub fn finish(self) -> anyhow::Result<()> { - self.client._send(ClientPacket::Build2)?; - self.client._flush()?; - let packet = self.client._recv()?; - match packet { - ServerPacket::Reset(e) => anyhow::bail!(e), - ServerPacket::Build {} => Ok(()), - _ => anyhow::bail!("Bad state."), - } - } -} diff --git a/src/bgworker/storage.rs b/src/bgworker/storage.rs new file mode 100644 index 000000000..4413115a8 --- /dev/null +++ b/src/bgworker/storage.rs @@ -0,0 +1,68 @@ +use super::storage_mmap::{MmapBox, StorageMmap}; +use crate::bgworker::storage_mmap; +use crate::prelude::*; +use std::mem::MaybeUninit; + +#[derive(Debug, Clone)] +pub enum StoragePreallocatorElement { + Mmap(storage_mmap::StorageMmapPreallocatorElement), +} + +pub struct StoragePreallocator { + sequence: Vec, +} + +impl StoragePreallocator { + pub fn new() -> Self { + Self { + sequence: Vec::new(), + } + } + pub fn palloc_mmap(&mut self, memmap: Memmap) { + use StoragePreallocatorElement::Mmap; + self.sequence + .push(Mmap(storage_mmap::prealloc::(memmap))); + } + pub fn palloc_mmap_slice(&mut self, memmap: Memmap, len: usize) { + use StoragePreallocatorElement::Mmap; + self.sequence + .push(Mmap(storage_mmap::prealloc_slice::(memmap, len))); + } +} + +pub struct Storage { + storage_mmap: StorageMmap, +} + +impl Storage { + pub fn build(id: Id, preallocator: StoragePreallocator) -> Self { + let mmap_iter = preallocator + .sequence + .iter() + .filter_map(|x| { + use StoragePreallocatorElement::Mmap; + #[allow(unreachable_patterns)] + match x.clone() { + Mmap(x) => Some(x), + _ => None, + } + }) + .collect::>() + .into_iter(); + let storage_mmap = 
StorageMmap::build(id, mmap_iter); + Self { storage_mmap } + } + pub fn load(id: Id) -> Self { + let storage_mmap = StorageMmap::load(id); + Self { storage_mmap } + } + pub fn alloc_mmap(&mut self, memmap: Memmap) -> MmapBox> { + self.storage_mmap.alloc_mmap(memmap) + } + pub fn alloc_mmap_slice(&mut self, memmap: Memmap, len: usize) -> MmapBox<[MaybeUninit]> { + self.storage_mmap.alloc_mmap_slice(memmap, len) + } + pub fn persist(&self) { + self.storage_mmap.persist(); + } +} diff --git a/src/bgworker/storage_mmap.rs b/src/bgworker/storage_mmap.rs new file mode 100644 index 000000000..bf7e525b2 --- /dev/null +++ b/src/bgworker/storage_mmap.rs @@ -0,0 +1,231 @@ +use crate::prelude::{Id, Memmap}; +use cstr::cstr; +use memmap2::MmapMut; +use std::alloc::Layout; +use std::borrow::{Borrow, BorrowMut}; +use std::fmt::Debug; +use std::fs::{File, OpenOptions}; +use std::mem::MaybeUninit; +use std::ops::{Deref, DerefMut}; +use std::os::fd::FromRawFd; + +pub unsafe auto trait Pointerless {} + +impl !Pointerless for *const T {} +impl !Pointerless for *mut T {} +impl !Pointerless for &'_ T {} +impl !Pointerless for &'_ mut T {} + +pub type StorageMmapPreallocatorElement = (Memmap, Layout); + +pub fn prealloc(memmap: Memmap) -> StorageMmapPreallocatorElement { + (memmap, std::alloc::Layout::new::()) +} + +pub fn prealloc_slice(memmap: Memmap, len: usize) -> StorageMmapPreallocatorElement { + (memmap, std::alloc::Layout::array::(len).unwrap()) +} + +pub struct StorageMmap { + block_ram: Block, + block_disk: Block, +} + +impl StorageMmap { + pub fn build(id: Id, iter: impl Iterator) -> Self { + let mut size_ram = 0usize; + let mut size_disk = 0usize; + for (memmap, layout) in iter { + match memmap { + Memmap::Ram => { + size_ram = size_ram.next_multiple_of(layout.align()); + size_ram += layout.size(); + } + Memmap::Disk => { + size_disk = size_disk.next_multiple_of(layout.align()); + size_disk += layout.size(); + } + } + } + let size_ram = size_ram.next_multiple_of(4096); + let 
size_disk = size_disk.next_multiple_of(4096); + let block_ram = Block::build(size_ram, format!("{}_ram", id), Memmap::Ram); + let block_disk = Block::build(size_disk, format!("{}_disk", id), Memmap::Disk); + Self { + block_ram, + block_disk, + } + } + pub fn load(id: Id) -> Self { + let block_ram = Block::load(format!("{}_ram", id), Memmap::Ram); + let block_disk = Block::load(format!("{}_disk", id), Memmap::Disk); + Self { + block_ram, + block_disk, + } + } + pub fn alloc_mmap(&mut self, memmap: Memmap) -> MmapBox> { + let ptr = match memmap { + Memmap::Ram => self.block_ram.allocate(std::alloc::Layout::new::()), + Memmap::Disk => self.block_disk.allocate(std::alloc::Layout::new::()), + }; + MmapBox(ptr.cast()) + } + pub fn alloc_mmap_slice(&mut self, memmap: Memmap, len: usize) -> MmapBox<[MaybeUninit]> { + let ptr = match memmap { + Memmap::Ram => self + .block_ram + .allocate(std::alloc::Layout::array::(len).unwrap()), + Memmap::Disk => self + .block_disk + .allocate(std::alloc::Layout::array::(len).unwrap()), + }; + MmapBox(unsafe { std::slice::from_raw_parts_mut(ptr.cast(), len) }) + } + pub fn persist(&self) { + self.block_ram.persist(); + self.block_disk.persist(); + } +} + +struct Block { + path: String, + mmap: MmapMut, + cursor: usize, +} + +impl Block { + fn build(size: usize, path: String, memmap: Memmap) -> Self { + assert!(size % 4096 == 0); + let file = tempfile(memmap).expect("Failed to create temp file."); + file.set_len(size as u64) + .expect("Failed to resize the file."); + let mmap = unsafe { MmapMut::map_mut(&file).expect("Failed to create mmap.") }; + let _ = mmap.advise(memmap2::Advice::WillNeed); + Self { + path, + mmap, + cursor: 0, + } + } + + fn load(path: String, memmap: Memmap) -> Self { + let mut file = tempfile(memmap).expect("Failed to create temp file."); + let mut persistent_file = std::fs::OpenOptions::new() + .read(true) + .open(&path) + .expect("Failed to read index."); + std::io::copy(&mut persistent_file, &mut 
file).expect("Failed to write temp file."); + let mmap = unsafe { MmapMut::map_mut(&file).expect("Failed to create mmap.") }; + Self { + path, + mmap, + cursor: 0, + } + } + + fn allocate(&mut self, layout: Layout) -> *mut () { + self.cursor = self.cursor.next_multiple_of(layout.align()); + let offset = self.cursor; + self.cursor += layout.size(); + assert!(self.cursor <= self.mmap.len()); + unsafe { self.mmap.as_ptr().add(offset).cast_mut().cast() } + } + + fn persist(&self) { + use std::io::Write; + let mut persistent_file = OpenOptions::new() + .create(true) + .read(true) + .write(true) + .truncate(true) + .open(&self.path) + .expect("Failed to open the persistent file."); + persistent_file + .write_all(self.mmap.as_ref()) + .expect("Failed to write the persisent file."); + persistent_file + .sync_all() + .expect("Failed to write the persisent file."); + } +} + +pub struct MmapBox(*mut T); + +impl MmapBox> { + pub unsafe fn assume_init(self) -> MmapBox { + MmapBox(self.0.cast()) + } +} + +impl MmapBox<[MaybeUninit]> { + pub unsafe fn assume_init(self) -> MmapBox<[T]> { + MmapBox(std::ptr::from_raw_parts_mut( + self.0.cast(), + std::ptr::metadata(self.0), + )) + } +} + +unsafe impl Send for MmapBox {} +unsafe impl Sync for MmapBox {} + +impl Deref for MmapBox { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.0 } + } +} + +impl DerefMut for MmapBox { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.0 } + } +} + +impl AsRef for MmapBox { + fn as_ref(&self) -> &T { + unsafe { &*self.0 } + } +} + +impl AsMut for MmapBox { + fn as_mut(&mut self) -> &mut T { + unsafe { &mut *self.0 } + } +} + +impl Borrow for MmapBox { + fn borrow(&self) -> &T { + unsafe { &*self.0 } + } +} + +impl BorrowMut for MmapBox { + fn borrow_mut(&mut self) -> &mut T { + unsafe { &mut *self.0 } + } +} + +impl Debug for MmapBox { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Debug::fmt(self.deref(), f) + } +} + +fn 
tempfile(memmap: Memmap) -> std::io::Result { + use Memmap::*; + let file = match memmap { + Disk => tempfile::tempfile()?, + Ram => unsafe { + let fd = libc::memfd_create(cstr!("file").as_ptr(), 0); + if fd != -1 { + File::from_raw_fd(fd) + } else { + return Err(std::io::Error::last_os_error()); + } + }, + }; + Ok(file) +} diff --git a/src/bgworker/vectors.rs b/src/bgworker/vectors.rs new file mode 100644 index 000000000..11a76649c --- /dev/null +++ b/src/bgworker/vectors.rs @@ -0,0 +1,131 @@ +use super::storage::Storage; +use super::storage::StoragePreallocator; +use super::storage_mmap::MmapBox; +use crate::bgworker::index::IndexOptions; +use crate::prelude::*; +use serde::{Deserialize, Serialize}; +use std::cell::UnsafeCell; +use std::mem::MaybeUninit; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VectorsOptions { + #[serde(default = "VectorsOptions::default_memmap")] + pub memmap: Memmap, +} + +impl VectorsOptions { + fn default_memmap() -> Memmap { + Memmap::Ram + } +} + +type Boxed = MmapBox<[UnsafeCell>]>; + +pub struct Vectors { + dims: u16, + capacity: usize, + // + len: MmapBox, + inflight: MmapBox, + data: Boxed, + vector: Boxed, +} + +unsafe impl Send for Vectors {} +unsafe impl Sync for Vectors {} + +impl Vectors { + pub fn prebuild(storage: &mut StoragePreallocator, options: IndexOptions) { + let memmap = options.vectors.memmap; + let len_data = options.capacity; + let len_vector = options.capacity * options.dims as usize; + storage.palloc_mmap::(memmap); + storage.palloc_mmap::(memmap); + storage.palloc_mmap_slice::>>(memmap, len_data); + storage.palloc_mmap_slice::>>(memmap, len_vector); + } + pub fn build(storage: &mut Storage, options: IndexOptions) -> Self { + let memmap = options.vectors.memmap; + let len_data = options.capacity; + let len_vector = options.capacity * options.dims as usize; + Self { + dims: options.dims, + capacity: options.capacity, + len: 
unsafe { + let mut len = storage.alloc_mmap(memmap); + len.write(AtomicUsize::new(0)); + len.assume_init() + }, + inflight: unsafe { + let mut inflight = storage.alloc_mmap(memmap); + inflight.write(AtomicUsize::new(0)); + inflight.assume_init() + }, + data: unsafe { storage.alloc_mmap_slice(memmap, len_data).assume_init() }, + vector: unsafe { storage.alloc_mmap_slice(memmap, len_vector).assume_init() }, + } + } + pub fn load(storage: &mut Storage, options: IndexOptions) -> Self { + let memmap = options.vectors.memmap; + let len_data = options.capacity; + let len_vector = options.capacity * options.dims as usize; + Self { + capacity: options.capacity, + dims: options.dims, + len: unsafe { storage.alloc_mmap(memmap).assume_init() }, + inflight: unsafe { storage.alloc_mmap(memmap).assume_init() }, + data: unsafe { storage.alloc_mmap_slice(memmap, len_data).assume_init() }, + vector: unsafe { storage.alloc_mmap_slice(memmap, len_vector).assume_init() }, + } + } + pub fn put(&self, data: u64, vector: &[Scalar]) -> usize { + // If the index is approaching to `usize::MAX`, it will break. But it will not likely happen. 
+ let i = self.inflight.fetch_add(1, Ordering::AcqRel); + if i >= self.capacity { + self.inflight.store(self.capacity, Ordering::Release); + panic!("The capacity is used up."); + } + unsafe { + let uninit_data = &mut *self.data[i].get(); + MaybeUninit::write(uninit_data, data); + let slice_vector = &self.vector[i * self.dims as usize..][..self.dims as usize]; + let uninit_vector = assume_mutable(slice_vector); + uninit_vector.copy_from_slice(std::slice::from_raw_parts( + vector.as_ptr() as *const MaybeUninit, + vector.len(), + )); + } + while self + .len + .compare_exchange_weak(i, i + 1, Ordering::AcqRel, Ordering::Relaxed) + .is_err() + { + std::hint::spin_loop(); + } + i + } + pub fn len(&self) -> usize { + self.len.load(Ordering::Acquire) + } + pub fn get_data(&self, i: usize) -> u64 { + unsafe { (*self.data[i].get()).assume_init_read() } + } + pub fn get_vector(&self, i: usize) -> &[Scalar] { + unsafe { + assume_immutable_init(&self.vector[i * self.dims as usize..][..self.dims as usize]) + } + } +} + +#[allow(clippy::mut_from_ref)] +unsafe fn assume_mutable(slice: &[UnsafeCell]) -> &mut [T] { + let p = slice.as_ptr().cast::>() as *mut T; + std::slice::from_raw_parts_mut(p, slice.len()) +} + +unsafe fn assume_immutable_init(slice: &[UnsafeCell>]) -> &[T] { + let p = slice.as_ptr().cast::>() as *const T; + std::slice::from_raw_parts(p, slice.len()) +} diff --git a/src/bgworker/wal.rs b/src/bgworker/wal.rs index 574bd59f8..d985f030e 100644 --- a/src/bgworker/wal.rs +++ b/src/bgworker/wal.rs @@ -1,10 +1,11 @@ use byteorder::NativeEndian as N; use crc32fast::hash as crc32; use std::path::Path; +use std::thread::JoinHandle; /* +----------+-----------+---------+ -| CRC (4B) | Size (2B) | Payload | +| CRC (4B) | Size (4B) | Payload | +----------+-----------+---------+ */ @@ -16,192 +17,168 @@ pub enum WalStatus { Flush, } -pub struct WalSync { +pub struct Wal { file: std::fs::File, offset: usize, status: WalStatus, } -impl WalSync { - pub fn open(path: impl AsRef) -> 
anyhow::Result { +impl Wal { + pub fn open(path: impl AsRef) -> Self { use WalStatus::*; let file = std::fs::OpenOptions::new() .create(true) .write(true) .read(true) - .open(path)?; - Ok(Self { + .open(path) + .expect("Failed to open wal."); + Self { file, offset: 0, status: Read, - }) + } } - pub fn create(path: impl AsRef) -> anyhow::Result { + pub fn create(path: impl AsRef) -> Self { use WalStatus::*; let file = std::fs::OpenOptions::new() .create(true) .write(true) .read(true) .truncate(true) - .open(path)?; - Ok(Self { + .open(path) + .expect("Failed to create wal."); + Self { file, offset: 0, status: Write, - }) + } } - pub fn read(&mut self) -> anyhow::Result>> { + pub fn read(&mut self) -> Option> { use byteorder::ReadBytesExt; use std::io::Read; - use std::io::{Error, ErrorKind}; use WalStatus::*; - let Read = self.status else { panic!("Operation not permitted.") }; - let maybe_error = (|| -> std::io::Result> { - let crc = self.file.read_u32::()?; - let len = self.file.read_u16::()?; - let mut data = vec![0u8; len as usize]; - self.file.read_exact(&mut data)?; - if crc32(&data) == crc { - self.offset += 4 + 2 + data.len(); - Ok(data) - } else { - Err(Error::new(ErrorKind::UnexpectedEof, "Bad crc."))? - } - })(); - match maybe_error { - Ok(data) => Ok(Some(data)), - Err(error) if error.kind() == ErrorKind::UnexpectedEof => { - self.status = WalStatus::Truncate; - Ok(None) - } - Err(error) => anyhow::bail!(error), + let Read = self.status else { + panic!("Operation not permitted.") + }; + macro_rules! 
resolve_eof { + ($t: expr) => { + match $t { + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + self.status = WalStatus::Truncate; + return None; + } + Err(e) => panic!("{}", e), + Ok(e) => e, + } + }; + } + let crc = resolve_eof!(self.file.read_u32::()); + let len = resolve_eof!(self.file.read_u32::()); + let mut data = vec![0u8; len as usize]; + resolve_eof!(self.file.read_exact(&mut data)); + if crc32(&data) != crc { + self.status = WalStatus::Truncate; + return None; } + self.offset += 4 + 4 + data.len(); + Some(data) } - pub fn truncate(&mut self) -> anyhow::Result<()> { + pub fn truncate(&mut self) { use WalStatus::*; - let Truncate = self.status else { panic!("Operation not permitted.") }; - self.file.set_len(self.offset as _)?; - self.file.sync_all()?; + let Truncate = self.status else { + panic!("Operation not permitted.") + }; + self.file + .set_len(self.offset as _) + .expect("Failed to truncate wal."); + self.file.sync_all().expect("Failed to flush wal."); self.status = WalStatus::Flush; - Ok(()) } - pub fn write(&mut self, bytes: &[u8]) -> anyhow::Result<()> { + pub fn write(&mut self, bytes: &[u8]) { use byteorder::WriteBytesExt; use std::io::Write; use WalStatus::*; - let (Write | Flush) = self.status else { panic!("Operation not permitted.") }; - self.file.write_u32::(crc32(bytes))?; - self.file.write_u16::(bytes.len() as _)?; - self.file.write_all(bytes)?; - self.offset += 4 + 2 + bytes.len(); + let (Write | Flush) = self.status else { + panic!("Operation not permitted.") + }; + self.file + .write_u32::(crc32(bytes)) + .expect("Failed to write wal."); + self.file + .write_u32::(bytes.len() as _) + .expect("Failed to write wal."); + self.file.write_all(bytes).expect("Failed to write wal."); + self.offset += 4 + 4 + bytes.len(); self.status = WalStatus::Write; - Ok(()) } - pub fn flush(&mut self) -> anyhow::Result<()> { + pub fn flush(&mut self) { use WalStatus::*; - let (Write | Flush) = self.status else { panic!("Operation not 
permitted.") }; - self.file.sync_all()?; + let (Write | Flush) = self.status else { + panic!("Operation not permitted.") + }; + self.file.sync_all().expect("Failed to flush wal."); self.status = WalStatus::Flush; - Ok(()) - } - pub fn into_async(self) -> WalAsync { - WalAsync { - file: tokio::fs::File::from_std(self.file), - offset: self.offset, - status: self.status, - } } } -pub struct WalAsync { - file: tokio::fs::File, - offset: usize, - status: WalStatus, +pub struct WalWriter { + #[allow(dead_code)] + handle: JoinHandle, + tx: crossbeam::channel::Sender, } -impl WalAsync { - pub async fn write(&mut self, bytes: &[u8]) -> anyhow::Result<()> { - use tokio::io::AsyncWriteExt; +impl WalWriter { + pub fn spawn(wal: Wal) -> WalWriter { use WalStatus::*; - let (Write | Flush) = self.status else { panic!("Operation not permitted.") }; - self.file.write_u32(crc32(bytes)).await?; - self.file.write_u16(bytes.len() as _).await?; - self.file.write_all(bytes).await?; - self.offset += 4 + 2 + bytes.len(); - self.status = WalStatus::Write; - Ok(()) + let (Write | Flush) = wal.status else { + panic!("Operation not permitted.") + }; + let (tx, rx) = crossbeam::channel::bounded(256); + let handle = std::thread::spawn(move || thread_wal(wal, rx)); + WalWriter { handle, tx } } - pub async fn flush(&mut self) -> anyhow::Result<()> { - use WalStatus::*; - let (Write | Flush) = self.status else { panic!("Operation not permitted.") }; - self.file.sync_all().await?; - self.status = WalStatus::Flush; - Ok(()) + pub fn write(&self, data: Vec) { + self.tx + .send(WalWriterMessage::Write(data)) + .expect("Wal thread exited."); + } + pub fn flush(&self) { + let (tx, rx) = crossbeam::channel::bounded::(0); + self.tx + .send(WalWriterMessage::Flush(tx)) + .expect("Wal thread exited."); + let _ = rx.recv(); + } + pub fn shutdown(&mut self) { + let (tx, rx) = crossbeam::channel::bounded::(0); + self.tx + .send(WalWriterMessage::Shutdown(tx)) + .expect("Wal thread exited."); + let _ = 
rx.recv(); } } enum WalWriterMessage { Write(Vec), - Flush(tokio::sync::oneshot::Sender<()>), -} - -pub struct WalWriter { - tx: Option>, - handle: tokio::task::JoinHandle>, + Flush(crossbeam::channel::Sender), + Shutdown(crossbeam::channel::Sender), } -impl WalWriter { - pub fn spawn(mut wal: WalAsync) -> anyhow::Result { - use WalStatus::*; - anyhow::ensure!(matches!(wal.status, Write | Flush)); - let (tx, mut rx) = tokio::sync::mpsc::channel(4096); - let handle = tokio::task::spawn(async move { - while let Some(r) = rx.recv().await { - use WalWriterMessage::*; - match r { - Write(bytes) => { - wal.write(&bytes).await?; - } - Flush(callback) => { - wal.flush().await?; - let _ = callback.send(()); - } - } +fn thread_wal(mut wal: Wal, rx: crossbeam::channel::Receiver) -> Wal { + while let Ok(message) = rx.recv() { + match message { + WalWriterMessage::Write(data) => { + wal.write(&data); } - Ok(()) - }); - Ok(Self { - tx: Some(tx), - handle, - }) - } - pub async fn write(&self, bytes: Vec) -> anyhow::Result<()> { - use WalWriterMessage::*; - self.tx - .as_ref() - .unwrap() - .send(Write(bytes)) - .await - .ok() - .ok_or(anyhow::anyhow!("The WAL thread exited."))?; - Ok(()) - } - pub async fn flush(&self) -> anyhow::Result<()> { - use WalWriterMessage::*; - let (tx, rx) = tokio::sync::oneshot::channel(); - self.tx - .as_ref() - .unwrap() - .send(Flush(tx)) - .await - .ok() - .ok_or(anyhow::anyhow!("The WAL thread exited."))?; - rx.await?; - Ok(()) - } - pub async fn shutdown(mut self) -> anyhow::Result<()> { - self.tx.take(); - self.handle.await??; - Ok(()) + WalWriterMessage::Flush(_callback) => { + wal.flush(); + } + WalWriterMessage::Shutdown(_callback) => { + wal.flush(); + return wal; + } + } } + wal.flush(); + wal } diff --git a/src/embedding/mod.rs b/src/embedding/mod.rs new file mode 100644 index 000000000..dfcc87ee7 --- /dev/null +++ b/src/embedding/mod.rs @@ -0,0 +1,5 @@ +pub(crate) type Embedding = Vec; +pub(crate) type Embeddings = Vec; + +mod openai; 
+mod udf; diff --git a/src/embedding.rs b/src/embedding/openai.rs similarity index 93% rename from src/embedding.rs rename to src/embedding/openai.rs index 9dfabf739..b25c94a90 100644 --- a/src/embedding.rs +++ b/src/embedding/openai.rs @@ -3,11 +3,11 @@ use openai_api_rust::{ Auth, OpenAI, }; -pub(crate) type Embedding = Vec; -pub(crate) type Embeddings = Vec; +use super::Embeddings; #[cfg(test)] use mockall::automock; + #[cfg_attr(test, automock)] pub(crate) trait EmbeddingCreator { fn create_embeddings(&self, input: Vec) -> Result; @@ -77,13 +77,11 @@ impl EmbeddingCreator for OpenAIEmbedding { } #[cfg(test)] mod tests { + use crate::embedding::openai::{EmbeddingCreator, EmbeddingModel, OpenAIEmbedding}; + use crate::embedding::Embedding; use httpmock::MockServer; use serde_json::json; - use crate::embedding::{Embedding, EmbeddingCreator}; - - use super::OpenAIEmbedding; - #[test] fn test_create_embeddings() { let input: String = "hello".to_string(); @@ -97,7 +95,7 @@ mod tests { let client = OpenAIEmbedding::new( "".to_string(), - crate::embedding::EmbeddingModel::Ada002, + EmbeddingModel::Ada002, server.base_url() + "/", ); diff --git a/src/udf.rs b/src/embedding/udf.rs similarity index 79% rename from src/udf.rs rename to src/embedding/udf.rs index 1812d9265..cc6bfb14c 100644 --- a/src/udf.rs +++ b/src/embedding/udf.rs @@ -1,12 +1,14 @@ +use super::openai::{EmbeddingCreator, OpenAIEmbedding}; +use super::Embedding; +use crate::postgres::datatype::Vector; +use crate::postgres::datatype::VectorOutput; +use crate::postgres::gucs::OPENAI_API_KEY_GUC; +use crate::prelude::Float; +use crate::prelude::Scalar; use pgrx::prelude::*; -use crate::{ - embedding::{Embedding, EmbeddingCreator, OpenAIEmbedding}, - postgres::OPENAI_API_KEY_GUC, -}; - #[pg_extern] -fn ai_embedding_vector(input: String) -> Embedding { +fn ai_embedding_vector(input: String) -> VectorOutput { let api_key = match OPENAI_API_KEY_GUC.get() { Some(key) => key .to_str() @@ -21,7 +23,13 @@ fn 
ai_embedding_vector(input: String) -> Embedding { let openai_embedding = OpenAIEmbedding::new_ada002(api_key); match ai_embedding_vector_inner(input, openai_embedding) { - Ok(embedding) => embedding, + Ok(embedding) => { + let embedding = embedding + .into_iter() + .map(|x| Scalar(x as Float)) + .collect::>(); + Vector::new_in_postgres(&embedding) + } Err(e) => { error!("{}", e) } @@ -43,8 +51,8 @@ fn ai_embedding_vector_inner( #[cfg(test)] mod tests { - use crate::embedding::MockEmbeddingCreator; - use crate::udf::ai_embedding_vector_inner; + use crate::embedding::openai::MockEmbeddingCreator; + use crate::embedding::udf::ai_embedding_vector_inner; use mockall::predicate::eq; // We need to mock embedding since it requires an API key. diff --git a/src/ipc/client.rs b/src/ipc/client.rs new file mode 100644 index 000000000..936b6ba68 --- /dev/null +++ b/src/ipc/client.rs @@ -0,0 +1,167 @@ +use crate::bgworker::index::IndexOptions; +use crate::ipc::packet::*; +use crate::ipc::transport::Socket; +use crate::ipc::ClientIpcError; +use crate::prelude::*; + +pub struct Rpc { + socket: Socket, +} + +impl Rpc { + pub(super) fn new(socket: Socket) -> Self { + Self { socket } + } + pub fn build(mut self, id: Id, options: IndexOptions) -> Result { + let packet = RpcPacket::Build { id, options }; + self.socket.client_send(packet)?; + Ok(BuildHandler { + socket: self.socket, + reach: false, + }) + } + pub fn search( + mut self, + id: Id, + target: Box<[Scalar]>, + k: usize, + ) -> Result { + let packet = RpcPacket::Search { id, target, k }; + self.socket.client_send(packet)?; + Ok(SearchHandler { + socket: self.socket, + }) + } + pub fn insert( + &mut self, + id: Id, + insert: (Box<[Scalar]>, Pointer), + ) -> Result<(), ClientIpcError> { + let packet = RpcPacket::Insert { id, insert }; + self.socket.client_send(packet)?; + let InsertPacket::Leave {} = self.socket.client_recv::()?; + Ok(()) + } + pub fn delete(&mut self, id: Id, delete: Pointer) -> Result<(), ClientIpcError> { + 
let packet = RpcPacket::Delete { id, delete }; + self.socket.client_send(packet)?; + let DeletePacket::Leave {} = self.socket.client_recv::()?; + Ok(()) + } + pub fn load(&mut self, id: Id) -> Result<(), ClientIpcError> { + let packet = RpcPacket::Load { id }; + self.socket.client_send(packet)?; + let LoadPacket::Leave {} = self.socket.client_recv::()?; + Ok(()) + } + pub fn unload(&mut self, id: Id) -> Result<(), ClientIpcError> { + let packet = RpcPacket::Unload { id }; + self.socket.client_send(packet)?; + let UnloadPacket::Leave {} = self.socket.client_recv::()?; + Ok(()) + } + pub fn flush(&mut self, id: Id) -> Result<(), ClientIpcError> { + let packet = RpcPacket::Flush { id }; + self.socket.client_send(packet)?; + let FlushPacket::Leave {} = self.socket.client_recv::()?; + Ok(()) + } + pub fn clean(&mut self, id: Id) -> Result<(), ClientIpcError> { + let packet = RpcPacket::Clean { id }; + self.socket.client_send(packet)?; + let CleanPacket::Leave {} = self.socket.client_recv::()?; + Ok(()) + } +} + +pub struct BuildHandler { + reach: bool, + socket: Socket, +} + +impl BuildHandler { + pub fn handle(mut self) -> Result { + if !self.reach { + Ok(BuildHandle::Next { + x: Next { + socket: self.socket, + }, + }) + } else { + Ok(match self.socket.client_recv::()? 
{ + BuildPacket::Leave {} => BuildHandle::Leave { + x: Rpc { + socket: self.socket, + }, + }, + _ => unreachable!(), + }) + } + } +} + +pub enum BuildHandle { + Next { x: Next }, + Leave { x: Rpc }, +} + +pub struct Next { + socket: Socket, +} + +impl Next { + pub fn leave( + mut self, + data: Option<(Box<[Scalar]>, Pointer)>, + ) -> Result { + let end = data.is_none(); + let packet = NextPacket::Leave { data }; + self.socket.client_send(packet)?; + Ok(BuildHandler { + socket: self.socket, + reach: end, + }) + } +} + +pub enum SearchHandle { + Check { p: Pointer, x: Check }, + Leave { result: Vec, x: Rpc }, +} + +pub struct SearchHandler { + socket: Socket, +} + +impl SearchHandler { + pub fn handle(mut self) -> Result { + Ok(match self.socket.client_recv::()? { + SearchPacket::Check { p } => SearchHandle::Check { + p, + x: Check { + socket: self.socket, + }, + }, + SearchPacket::Leave { result } => SearchHandle::Leave { + result, + x: Rpc { + socket: self.socket, + }, + }, + }) + } +} + +pub struct Check { + socket: Socket, +} + +impl Check { + pub fn leave(mut self, result: bool) -> Result { + let packet = CheckPacket::Leave { result }; + self.socket.client_send(packet)?; + Ok(SearchHandler { + socket: self.socket, + }) + } +} diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs new file mode 100644 index 000000000..4d251b44e --- /dev/null +++ b/src/ipc/mod.rs @@ -0,0 +1,38 @@ +pub mod client; +mod packet; +pub mod server; +mod transport; + +use self::client::Rpc; +use self::server::RpcHandler; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum ServerIpcError { + #[error("The connection is closed.")] + Closed, + #[error("Server encounters an error.")] + Server, +} + +#[derive(Debug, Error, Serialize, Deserialize)] +pub enum ClientIpcError { + #[error("The connection is closed.")] + Closed, + #[error("Server encounters an error.")] + Server, +} + +pub fn listen() -> impl Iterator { + let mut 
listener = self::transport::Listener::new(); + std::iter::from_fn(move || { + let socket = listener.accept(); + Some(self::server::RpcHandler::new(socket)) + }) +} + +pub fn connect() -> Rpc { + let socket = self::transport::Socket::new(); + self::client::Rpc::new(socket) +} diff --git a/src/ipc/packet.rs b/src/ipc/packet.rs new file mode 100644 index 000000000..c38dfff9e --- /dev/null +++ b/src/ipc/packet.rs @@ -0,0 +1,91 @@ +use crate::bgworker::index::IndexOptions; +use crate::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub enum RpcPacket { + Build { + id: Id, + options: IndexOptions, + }, + Insert { + id: Id, + insert: (Box<[Scalar]>, Pointer), + }, + Delete { + id: Id, + delete: Pointer, + }, + Search { + id: Id, + target: Box<[Scalar]>, + k: usize, + }, + Load { + id: Id, + }, + Unload { + id: Id, + }, + Flush { + id: Id, + }, + Clean { + id: Id, + }, + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum BuildPacket { + Next {}, + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum NextPacket { + Leave { + data: Option<(Box<[Scalar]>, Pointer)>, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum SearchPacket { + Check { p: Pointer }, + Leave { result: Vec }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum CheckPacket { + Leave { result: bool }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum InsertPacket { + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum DeletePacket { + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum LoadPacket { + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum UnloadPacket { + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum FlushPacket { + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum CleanPacket { + Leave {}, +} diff --git a/src/ipc/server.rs b/src/ipc/server.rs new file mode 100644 index 000000000..a9d2488cd --- /dev/null +++ 
b/src/ipc/server.rs @@ -0,0 +1,248 @@ +use crate::bgworker::index::IndexOptions; +use crate::ipc::packet::*; +use crate::ipc::transport::Socket; +use crate::ipc::ServerIpcError; +use crate::prelude::*; + +pub struct RpcHandler { + socket: Socket, +} + +impl RpcHandler { + pub(super) fn new(socket: Socket) -> Self { + Self { socket } + } + pub fn handle(mut self) -> Result { + Ok(match self.socket.server_recv::()? { + RpcPacket::Build { id, options } => RpcHandle::Build { + id, + options, + x: Build { + socket: self.socket, + reach: false, + }, + }, + RpcPacket::Insert { id, insert } => RpcHandle::Insert { + id, + insert, + x: Insert { + socket: self.socket, + }, + }, + RpcPacket::Delete { id, delete } => RpcHandle::Delete { + id, + delete, + x: Delete { + socket: self.socket, + }, + }, + RpcPacket::Search { id, target, k } => RpcHandle::Search { + id, + target, + k, + x: Search { + socket: self.socket, + }, + }, + RpcPacket::Load { id } => RpcHandle::Load { + id, + x: Load { + socket: self.socket, + }, + }, + RpcPacket::Unload { id } => RpcHandle::Unload { + id, + x: Unload { + socket: self.socket, + }, + }, + RpcPacket::Flush { id } => RpcHandle::Flush { + id, + x: Flush { + socket: self.socket, + }, + }, + RpcPacket::Clean { id } => RpcHandle::Clean { + id, + x: Clean { + socket: self.socket, + }, + }, + RpcPacket::Leave {} => RpcHandle::Leave {}, + }) + } +} + +pub enum RpcHandle { + Build { + id: Id, + options: IndexOptions, + x: Build, + }, + Search { + id: Id, + target: Box<[Scalar]>, + k: usize, + x: Search, + }, + Insert { + id: Id, + insert: (Box<[Scalar]>, Pointer), + x: Insert, + }, + Delete { + id: Id, + delete: Pointer, + x: Delete, + }, + Load { + id: Id, + x: Load, + }, + Unload { + id: Id, + x: Unload, + }, + Flush { + id: Id, + x: Flush, + }, + Clean { + id: Id, + x: Clean, + }, + Leave {}, +} + +pub struct Build { + socket: Socket, + reach: bool, +} + +impl Build { + pub fn next(&mut self) -> Result, Pointer)>, ServerIpcError> { + if !self.reach { 
+ let packet = self.socket.server_recv::()?; + match packet { + NextPacket::Leave { data: Some(data) } => Ok(Some(data)), + NextPacket::Leave { data: None } => { + self.reach = true; + Ok(None) + } + } + } else { + Ok(None) + } + } + pub fn leave(mut self) -> Result { + let packet = BuildPacket::Leave {}; + self.socket.server_send(packet)?; + Ok(RpcHandler { + socket: self.socket, + }) + } +} + +pub struct Insert { + socket: Socket, +} + +impl Insert { + pub fn leave(mut self) -> Result { + let packet = InsertPacket::Leave {}; + self.socket.server_send(packet)?; + Ok(RpcHandler { + socket: self.socket, + }) + } +} + +pub struct Delete { + socket: Socket, +} + +impl Delete { + pub fn leave(mut self) -> Result { + let packet = DeletePacket::Leave {}; + self.socket.server_send(packet)?; + Ok(RpcHandler { + socket: self.socket, + }) + } +} + +pub struct Search { + socket: Socket, +} + +impl Search { + pub fn check(&mut self, p: Pointer) -> Result { + let packet = SearchPacket::Check { p }; + self.socket.server_send(packet)?; + let CheckPacket::Leave { result } = self.socket.server_recv::()?; + Ok(result) + } + pub fn leave(mut self, result: Vec) -> Result { + let packet = SearchPacket::Leave { result }; + self.socket.server_send(packet)?; + Ok(RpcHandler { + socket: self.socket, + }) + } +} + +pub struct Load { + socket: Socket, +} + +impl Load { + pub fn leave(mut self) -> Result { + let packet = LoadPacket::Leave {}; + self.socket.server_send(packet)?; + Ok(RpcHandler { + socket: self.socket, + }) + } +} + +pub struct Unload { + socket: Socket, +} + +impl Unload { + pub fn leave(mut self) -> Result { + let packet = UnloadPacket::Leave {}; + self.socket.server_send(packet)?; + Ok(RpcHandler { + socket: self.socket, + }) + } +} + +pub struct Flush { + socket: Socket, +} + +impl Flush { + pub fn leave(mut self) -> Result { + let packet = FlushPacket::Leave {}; + self.socket.server_send(packet)?; + Ok(RpcHandler { + socket: self.socket, + }) + } +} + +pub struct Clean { 
+ socket: Socket, +} + +impl Clean { + pub fn leave(mut self) -> Result { + let packet = CleanPacket::Leave {}; + self.socket.server_send(packet)?; + Ok(RpcHandler { + socket: self.socket, + }) + } +} diff --git a/src/ipc/transport.rs b/src/ipc/transport.rs new file mode 100644 index 000000000..fc18f6dcb --- /dev/null +++ b/src/ipc/transport.rs @@ -0,0 +1,121 @@ +use crate::ipc::ClientIpcError; +use crate::ipc::ServerIpcError; +use byteorder::{ReadBytesExt, WriteBytesExt}; +use serde::{Deserialize, Serialize}; +use std::io::ErrorKind; +use std::io::{Read, Write}; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::path::Path; + +macro_rules! resolve_server_closed { + ($t: expr) => { + match $t { + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Err(ServerIpcError::Closed) + } + Err(e) => panic!("{}", e), + Ok(e) => e, + } + }; +} + +macro_rules! resolve_client_closed { + ($t: expr) => { + match $t { + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Err(ClientIpcError::Closed) + } + Err(e) => panic!("{}", e), + Ok(e) => e, + } + }; +} + +pub(super) struct Listener { + listener: UnixListener, +} + +impl Listener { + pub fn new() -> Self { + let path = "./_socket"; + remove_file_if_exists(&path).expect("Failed to bind."); + let listener = UnixListener::bind(&path).expect("Failed to bind."); + Self { listener } + } + pub fn accept(&mut self) -> Socket { + let (stream, _) = self.listener.accept().expect("Failed to listen."); + Socket { + stream: Some(stream), + } + } +} + +pub(super) struct Socket { + stream: Option, +} + +impl Socket { + pub fn new() -> Self { + let path = "./pg_vectors/_socket"; + let stream = UnixStream::connect(path).expect("Failed to bind."); + Socket { + stream: Some(stream), + } + } + pub fn server_send(&mut self, packet: T) -> Result<(), ServerIpcError> + where + T: Serialize, + { + use byteorder::NativeEndian as N; + let stream = self.stream.as_mut().ok_or(ServerIpcError::Closed)?; + let 
buffer = bincode::serialize(&packet).expect("Failed to serialize"); + let len = u32::try_from(buffer.len()).expect("Packet is too large."); + resolve_server_closed!(stream.write_u32::(len)); + resolve_server_closed!(stream.write_all(&buffer)); + Ok(()) + } + pub fn client_recv(&mut self) -> Result + where + T: for<'a> Deserialize<'a>, + { + use byteorder::NativeEndian as N; + let stream = self.stream.as_mut().ok_or(ClientIpcError::Closed)?; + let len = resolve_client_closed!(stream.read_u32::()); + let mut buffer = vec![0u8; len as usize]; + resolve_client_closed!(stream.read_exact(&mut buffer)); + let packet = bincode::deserialize(&buffer).expect("Failed to deserialize."); + Ok(packet) + } + pub fn client_send(&mut self, packet: T) -> Result<(), ClientIpcError> + where + T: Serialize, + { + use byteorder::NativeEndian as N; + let stream = self.stream.as_mut().ok_or(ClientIpcError::Closed)?; + let buffer = bincode::serialize(&packet).expect("Failed to serialize"); + let len = u32::try_from(buffer.len()).expect("Packet is too large."); + resolve_client_closed!(stream.write_u32::(len)); + resolve_client_closed!(stream.write_all(&buffer)); + Ok(()) + } + pub fn server_recv(&mut self) -> Result + where + T: for<'a> Deserialize<'a>, + { + use byteorder::NativeEndian as N; + let stream = self.stream.as_mut().ok_or(ServerIpcError::Closed)?; + let len = resolve_server_closed!(stream.read_u32::()); + let mut buffer = vec![0u8; len as usize]; + resolve_server_closed!(stream.read_exact(&mut buffer)); + let packet = bincode::deserialize(&buffer).expect("Failed to deserialize."); + Ok(packet) + } +} + +fn remove_file_if_exists(path: impl AsRef) -> std::io::Result<()> { + match std::fs::remove_file(path) { + Ok(()) => Ok(()), + Err(e) if e.kind() == ErrorKind::NotFound => Ok(()), + Err(e) => Err(e), + } +} diff --git a/src/lib.rs b/src/lib.rs index 9d8b9788f..f23192146 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,16 +8,17 @@ #![feature(negative_impls)] 
#![feature(ptr_metadata)] #![feature(new_uninit)] -#![feature(maybe_uninit_slice)] +#![feature(int_roundings)] +#![feature(never_type)] +#![allow(clippy::complexity)] +#![allow(clippy::style)] mod algorithms; mod bgworker; mod embedding; -mod memory; +mod ipc; mod postgres; mod prelude; -mod udf; -mod utils; pgrx::pg_module_magic!(); @@ -38,9 +39,8 @@ pub unsafe extern "C" fn _PG_init() { .set_function("vectors_main") .set_library("vectors") .set_argument(None) - .set_start_time(BgWorkerStartTime::ConsistentState) - .enable_spi_access() .enable_shmem_access(None) + .set_start_time(BgWorkerStartTime::PostmasterStart) .load(); self::postgres::init(); } diff --git a/src/memory/block.rs b/src/memory/block.rs deleted file mode 100644 index da40ebec6..000000000 --- a/src/memory/block.rs +++ /dev/null @@ -1,152 +0,0 @@ -use crate::prelude::*; -use cstr::cstr; -use memmap2::MmapMut; -use std::alloc::{AllocError, Layout}; -use std::fs::{File, OpenOptions}; -use std::os::fd::FromRawFd; -use std::sync::atomic::{AtomicUsize, Ordering}; - -pub struct Block { - #[allow(dead_code)] - size: usize, - path: String, - mmap: MmapMut, - bump: Bump, -} - -impl Block { - pub fn build(size: usize, path: String, storage: Storage) -> anyhow::Result { - anyhow::ensure!(size % 4096 == 0); - let file = tempfile(storage)?; - file.set_len(size as u64)?; - let mut mmap = unsafe { MmapMut::map_mut(&file) }?; - mmap.advise(memmap2::Advice::WillNeed)?; - let bump = unsafe { Bump::build(size, mmap.as_mut_ptr()) }; - Ok(Self { - size, - path, - mmap, - bump, - }) - } - - pub fn load(size: usize, path: String, storage: Storage) -> anyhow::Result { - anyhow::ensure!(size % 4096 == 0); - let mut file = tempfile(storage)?; - let mut persistent_file = std::fs::OpenOptions::new().read(true).open(&path)?; - std::io::copy(&mut persistent_file, &mut file)?; - let mut mmap = unsafe { MmapMut::map_mut(&file) }?; - let bump = unsafe { Bump::load(size, mmap.as_mut_ptr()) }; - Ok(Self { - size, - path, - mmap, - 
bump, - }) - } - - pub fn persist(&self) -> anyhow::Result<()> { - use std::io::Write; - let mut persistent_file = OpenOptions::new() - .create(true) - .read(true) - .write(true) - .truncate(true) - .open(&self.path)?; - persistent_file.write_all(self.mmap.as_ref())?; - persistent_file.sync_all()?; - Ok(()) - } - - pub fn address(&self) -> usize { - self.mmap.as_ptr() as usize - } - - pub fn allocate(&self, layout: Layout) -> Result { - self.bump.allocate(layout) - } - - pub fn allocate_zeroed(&self, layout: Layout) -> Result { - self.bump.allocate_zeroed(layout) - } -} - -pub struct Bump { - size: usize, - space: *mut Header, -} - -impl Bump { - pub unsafe fn build(size: usize, addr: *mut u8) -> Self { - assert!(size >= 4096); - let space = addr.cast::
(); - space.write(Header { - cursor: AtomicUsize::new(4096), - objects: AtomicUsize::new(0), - }); - Self { size, space } - } - pub unsafe fn load(size: usize, addr: *mut u8) -> Self { - assert!(size >= 4096); - let space = addr.cast::
(); - Self { size, space } - } - pub fn allocate(&self, layout: Layout) -> Result { - if layout.size() == 0 { - return Ok(0); - } - if layout.align() > 128 { - return Err(AllocError); - } - let mut old = unsafe { (*self.space).cursor.load(Ordering::Relaxed) }; - let offset = loop { - let offset = (old + layout.align() - 1) & !(layout.align() - 1); - let new = offset + layout.size(); - if new > self.size { - return Err(AllocError); - } - let exchange = unsafe { - (*self.space).cursor.compare_exchange_weak( - old, - new, - Ordering::Relaxed, - Ordering::Relaxed, - ) - }; - let Err(_old) = exchange else { break offset }; - old = _old; - }; - unsafe { - (*self.space).objects.fetch_add(1, Ordering::Relaxed); - } - Ok(offset) - } - pub fn allocate_zeroed(&self, layout: Layout) -> Result { - self.allocate(layout) - } -} - -unsafe impl Send for Bump {} -unsafe impl Sync for Bump {} - -#[repr(C)] -struct Header { - cursor: AtomicUsize, - objects: AtomicUsize, -} - -fn tempfile(storage: Storage) -> anyhow::Result { - use Storage::*; - let file = match storage { - Disk => tempfile::tempfile()?, - Ram => unsafe { - let fd = libc::memfd_create(cstr!("file").as_ptr(), 0); - if fd != -1 { - File::from_raw_fd(fd) - } else { - anyhow::bail!(std::io::Error::last_os_error()); - } - }, - }; - Ok(file) -} diff --git a/src/memory/mod.rs b/src/memory/mod.rs deleted file mode 100644 index 98cdc46b5..000000000 --- a/src/memory/mod.rs +++ /dev/null @@ -1,231 +0,0 @@ -mod block; -mod pbox; - -pub use pbox::PBox; - -use self::block::Block; -use crate::prelude::*; -use serde::{Deserialize, Serialize}; -use std::alloc::{AllocError, Layout}; -use std::cell::Cell; -use std::fmt::Debug; -use std::ptr::{NonNull, Pointee}; -use std::sync::Arc; -use std::thread::{Scope, ScopedJoinHandle}; - -pub unsafe auto trait Persistent {} - -impl !Persistent for *const T {} -impl !Persistent for *mut T {} -impl !Persistent for &'_ T {} -impl !Persistent for &'_ mut T {} - -#[repr(transparent)] -#[derive(Clone, 
Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] -pub struct Address(usize); - -impl Debug for Address { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "({:?}, {:#x})", self.storage(), self.offset()) - } -} - -impl Address { - pub fn storage(self) -> Storage { - use Storage::*; - if Ram as usize == (self.0 >> 63) { - Storage::Ram - } else { - Storage::Disk - } - } - pub fn offset(self) -> usize { - self.0 & ((1usize << 63) - 1) - } - pub const fn new(storage: Storage, offset: usize) -> Self { - debug_assert!(offset < (1 << 63)); - Self((storage as usize) << 63 | offset << 0) - } - pub const DANGLING: Self = Address(usize::MAX); -} - -#[repr(C)] -#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Ptr { - address: Address, - metadata: ::Metadata, -} - -impl Debug for Ptr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if std::mem::size_of::<::Metadata>() == 0 { - write!(f, "({:?})", self.address()) - } else if std::mem::size_of::<::Metadata>() == std::mem::size_of::() { - let metadata = unsafe { std::mem::transmute_copy::<_, usize>(&self.metadata()) }; - write!(f, "({:?}, {:#x})", self.address(), metadata) - } else { - write!(f, "({:?}, ?)", self.address()) - } - } -} - -impl Clone for Ptr { - fn clone(&self) -> Self { - *self - } -} - -impl Copy for Ptr {} - -impl Ptr { - pub fn storage(self) -> Storage { - self.address.storage() - } - pub fn offset(self) -> usize { - self.address.offset() - } - pub fn address(self) -> Address { - self.address - } - pub fn metadata(self) -> ::Metadata { - self.metadata - } - pub fn new(address: Address, metadata: ::Metadata) -> Self { - Self { address, metadata } - } - pub fn cast(self) -> Ptr { - Ptr::new(self.address, ()) - } - pub fn from_raw_parts(data_address: Ptr<()>, metadata: ::Metadata) -> Self { - Ptr::new(data_address.address(), metadata) - } - pub fn as_ptr(self) -> *const T { - let data_address = (OFFSETS[self.storage() as 
usize].get() + self.offset()) as _; - let metadata = self.metadata(); - std::ptr::from_raw_parts(data_address, metadata) - } - pub fn as_mut_ptr(self) -> *mut T { - let data_address = (OFFSETS[self.storage() as usize].get() + self.offset()) as _; - let metadata = self.metadata(); - std::ptr::from_raw_parts_mut(data_address, metadata) - } - pub unsafe fn as_ref<'a>(self) -> &'a T { - &*self.as_ptr() - } - pub unsafe fn as_mut<'a>(self) -> &'a mut T { - &mut *self.as_mut_ptr() - } -} - -#[thread_local] -static CONTEXT: Cell>> = Cell::new(None); - -#[thread_local] -static OFFSETS: [Cell; 2] = [Cell::new(0), Cell::new(0)]; - -pub unsafe fn given(p: NonNull) -> impl Drop { - pub struct Given; - impl Drop for Given { - fn drop(&mut self) { - CONTEXT.take(); - } - } - let given = Given; - CONTEXT.set(Some(p)); - OFFSETS[0].set(p.as_ref().block_ram.address()); - OFFSETS[1].set(p.as_ref().block_disk.address()); - given -} - -pub fn using<'a>() -> &'a Context { - let ctx = CONTEXT.get().expect("Never given a context to use."); - unsafe { ctx.as_ref() } -} - -pub struct Context { - block_ram: Block, - block_disk: Block, -} - -impl Context { - pub fn build(options: ContextOptions) -> anyhow::Result> { - let block_ram = Block::build(options.block_ram.0, options.block_ram.1, Storage::Ram)?; - let block_disk = Block::build(options.block_disk.0, options.block_disk.1, Storage::Disk)?; - Ok(Arc::new(Self { - block_ram, - block_disk, - })) - } - pub fn load(options: ContextOptions) -> anyhow::Result> { - let block_ram = Block::load(options.block_ram.0, options.block_ram.1, Storage::Ram)?; - let block_disk = Block::load(options.block_disk.0, options.block_disk.1, Storage::Disk)?; - Ok(Arc::new(Self { - block_ram, - block_disk, - })) - } - pub fn persist(&self) -> anyhow::Result<()> { - self.block_ram.persist()?; - self.block_disk.persist()?; - Ok(()) - } - pub fn allocate(&self, storage: Storage, layout: Layout) -> Result, AllocError> { - use Storage::*; - let offset = match storage { 
- Ram => self.block_ram.allocate(layout), - Disk => self.block_disk.allocate(layout), - }?; - let address = Address::new(storage, offset); - let ptr = Ptr::new(address, ()); - Ok(ptr) - } - pub fn allocate_zeroed(&self, storage: Storage, layout: Layout) -> Result, AllocError> { - use Storage::*; - let offset = match storage { - Ram => self.block_ram.allocate_zeroed(layout), - Disk => self.block_disk.allocate_zeroed(layout), - }?; - let address = Address::new(storage, offset); - let ptr = Ptr::new(address, ()); - Ok(ptr) - } - pub fn scope<'env, F, T>(&self, f: F) -> T - where - F: for<'scope> FnOnce(&'scope ContextScope<'scope, 'env>) -> T, - { - std::thread::scope(|scope| { - f(unsafe { std::mem::transmute::<&Scope, &ContextScope>(scope) }) - }) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ContextOptions { - pub block_ram: (usize, String), - pub block_disk: (usize, String), -} - -#[repr(transparent)] -pub struct ContextScope<'scope, 'env: 'scope>(Scope<'scope, 'env>); - -impl<'scope, 'env: 'scope> ContextScope<'scope, 'env> { - pub fn spawn(&'scope self, f: F) -> ScopedJoinHandle<'scope, T> - where - F: FnOnce() -> T + Send + 'scope, - T: Send + 'scope, - { - struct AssertSend(T); - impl AssertSend { - fn cosume(self) -> T { - self.0 - } - } - unsafe impl Send for AssertSend {} - let wrapped = AssertSend(CONTEXT.get().unwrap()); - self.0.spawn(move || { - let context = wrapped.cosume(); - let _given = unsafe { given(context) }; - f() - }) - } -} diff --git a/src/memory/pbox.rs b/src/memory/pbox.rs deleted file mode 100644 index 003154aec..000000000 --- a/src/memory/pbox.rs +++ /dev/null @@ -1,115 +0,0 @@ -use crate::memory::{using, Ptr}; -use crate::prelude::Storage; -use std::alloc::Layout; -use std::borrow::{Borrow, BorrowMut}; -use std::fmt::Debug; -use std::mem::MaybeUninit; -use std::ops::{Deref, DerefMut}; - -pub struct PBox(Ptr); - -impl PBox { - pub fn new(t: T, storage: Storage) -> anyhow::Result { - let ptr = using() - 
.allocate(storage, std::alloc::Layout::new::())? - .cast::(); - unsafe { - ptr.as_mut_ptr().write(t); - } - Ok(Self(ptr)) - } -} - -impl PBox { - pub fn into_raw(self) -> Ptr { - let raw = self.0; - std::mem::forget(self); - raw - } - #[allow(dead_code)] - pub fn from_raw(raw: Ptr) -> Self { - Self(raw) - } -} - -impl Deref for PBox { - type Target = T; - - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - -impl DerefMut for PBox { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { self.0.as_mut() } - } -} - -impl AsRef for PBox { - fn as_ref(&self) -> &T { - unsafe { self.0.as_ref() } - } -} - -impl AsMut for PBox { - fn as_mut(&mut self) -> &mut T { - unsafe { self.0.as_mut() } - } -} - -impl Borrow for PBox { - fn borrow(&self) -> &T { - unsafe { self.0.as_ref() } - } -} - -impl BorrowMut for PBox { - fn borrow_mut(&mut self) -> &mut T { - unsafe { self.0.as_mut() } - } -} - -impl Debug for PBox { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Debug::fmt(self.deref(), f) - } -} - -impl PBox> { - #[allow(dead_code)] - pub fn new_uninit(storage: Storage) -> anyhow::Result>> { - let ptr = using() - .allocate(storage, std::alloc::Layout::new::())? 
- .cast::>(); - Ok(Self(ptr)) - } - #[allow(dead_code)] - pub unsafe fn assume_init(self) -> PBox { - let ptr = PBox::into_raw(self); - PBox(Ptr::new(ptr.address(), ())) - } -} - -impl PBox<[MaybeUninit]> { - pub fn new_uninit_slice( - len: usize, - storage: Storage, - ) -> anyhow::Result]>> { - let ptr = using().allocate(storage, Layout::array::(len)?)?; - let ptr = Ptr::from_raw_parts(ptr, len); - Ok(PBox(ptr)) - } - pub fn new_zeroed_slice( - len: usize, - storage: Storage, - ) -> anyhow::Result]>> { - let ptr = using().allocate_zeroed(storage, Layout::array::(len)?)?; - let ptr = Ptr::from_raw_parts(ptr, len); - Ok(PBox(ptr)) - } - pub unsafe fn assume_init(self) -> PBox<[T]> { - let ptr = PBox::into_raw(self); - PBox(Ptr::new(ptr.address(), ptr.metadata())) - } -} diff --git a/src/postgres/casts.rs b/src/postgres/casts.rs new file mode 100644 index 000000000..f83686103 --- /dev/null +++ b/src/postgres/casts.rs @@ -0,0 +1,19 @@ +use super::datatype::{Vector, VectorInput, VectorOutput, VectorTypmod}; +use crate::prelude::Scalar; + +#[pgrx::pg_extern] +fn cast_array_to_vector(array: pgrx::Array, typmod: i32, _explicit: bool) -> VectorOutput { + assert!(!array.is_empty()); + assert!(!array.contains_nulls()); + let typmod = VectorTypmod::parse_from_i32(typmod).unwrap(); + let mut data = Vec::with_capacity(typmod.dims().unwrap_or_default() as usize); + for x in array.iter_deny_null() { + data.push(x); + } + Vector::new_in_postgres(&data) +} + +#[pgrx::pg_extern] +fn cast_vector_to_array<'a>(vector: VectorInput<'a>, _typmod: i32, _explicit: bool) -> Vec { + vector.data().to_vec() +} diff --git a/src/postgres/datatype.rs b/src/postgres/datatype.rs index 479bb067f..6a8caa149 100644 --- a/src/postgres/datatype.rs +++ b/src/postgres/datatype.rs @@ -428,118 +428,3 @@ fn vector_typmod_out(typmod: i32) -> CString { None => CString::new("()").unwrap(), } } - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(+)] -#[pgrx::commutator(+)] -fn 
operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - let n = lhs.len(); - let mut v = Vector::new_zeroed(n); - for i in 0..n { - v[i] = lhs[i] + rhs[i]; - } - v.copy_into_postgres() -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(-)] -fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - let n = lhs.len(); - let mut v = Vector::new_zeroed(n); - for i in 0..n { - v[i] = lhs[i] - rhs[i]; - } - v.copy_into_postgres() -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(<)] -#[pgrx::negator(>=)] -#[pgrx::commutator(>)] -#[pgrx::restrict(scalarltsel)] -#[pgrx::join(scalarltjoinsel)] -fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - lhs.deref() < rhs.deref() -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(<=)] -#[pgrx::negator(>)] -#[pgrx::commutator(>=)] -#[pgrx::restrict(scalarltsel)] -#[pgrx::join(scalarltjoinsel)] -fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - lhs.deref() <= rhs.deref() -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(>)] -#[pgrx::negator(<=)] -#[pgrx::commutator(<)] -#[pgrx::restrict(scalargtsel)] -#[pgrx::join(scalargtjoinsel)] -fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - lhs.deref() > rhs.deref() -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(>=)] -#[pgrx::negator(<)] -#[pgrx::commutator(<=)] -#[pgrx::restrict(scalargtsel)] -#[pgrx::join(scalargtjoinsel)] -fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { - 
assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - lhs.deref() >= rhs.deref() -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(=)] -#[pgrx::negator(<>)] -#[pgrx::commutator(=)] -#[pgrx::restrict(eqsel)] -#[pgrx::join(eqjoinsel)] -fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - lhs.deref() == rhs.deref() -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(<>)] -#[pgrx::negator(=)] -#[pgrx::commutator(<>)] -#[pgrx::restrict(eqsel)] -#[pgrx::join(eqjoinsel)] -fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - lhs.deref() != rhs.deref() -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(<=>)] -#[pgrx::commutator(<=>)] -fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - Distance::Cosine.distance(&lhs, &rhs) -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(<#>)] -#[pgrx::commutator(<#>)] -fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - Distance::Dot.distance(&lhs, &rhs) -} - -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] -#[pgrx::opname(<->)] -#[pgrx::commutator(<->)] -fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { - assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); - Distance::L2.distance(&lhs, &rhs) -} diff --git a/src/postgres/functions.rs b/src/postgres/functions.rs new file mode 100644 index 000000000..327ae2bbb --- /dev/null +++ b/src/postgres/functions.rs @@ -0,0 +1,20 @@ +use crate::postgres::hook_transaction::client; +use crate::prelude::*; + +#[pgrx::pg_extern(strict)] +unsafe fn vectors_load(oid: pgrx::pg_sys::Oid) 
{ + let id = Id::from_sys(oid); + client(|mut rpc| { + rpc.load(id).unwrap(); + rpc + }) +} + +#[pgrx::pg_extern(strict)] +unsafe fn vectors_unload(oid: pgrx::pg_sys::Oid) { + let id = Id::from_sys(oid); + client(|mut rpc| { + rpc.unload(id).unwrap(); + rpc + }) +} diff --git a/src/postgres/gucs.rs b/src/postgres/gucs.rs index 11fe2ca86..a65fd2fdf 100644 --- a/src/postgres/gucs.rs +++ b/src/postgres/gucs.rs @@ -4,8 +4,6 @@ use std::ffi::CStr; pub static OPENAI_API_KEY_GUC: GucSetting> = GucSetting::>::new(None); -pub static PORT: GucSetting = GucSetting::::new(33509); - pub static K: GucSetting = GucSetting::::new(64); pub unsafe fn init() { @@ -17,16 +15,6 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); - GucRegistry::define_int_guc( - "vectors.port", - "The port for the background worker to listen.", - "If the system runs two or more Postgres clusters, ports should be set with different values.", - &PORT, - 1, - u16::MAX as _, - GucContext::Postmaster, - GucFlags::default(), - ); GucRegistry::define_int_guc( "vectors.k", "The number of nearest neighbors to return for searching.", diff --git a/src/postgres/hook_executor.rs b/src/postgres/hook_executor.rs new file mode 100644 index 000000000..0e45d19d8 --- /dev/null +++ b/src/postgres/hook_executor.rs @@ -0,0 +1,119 @@ +use super::hook_transaction::drop_if_commit; +use crate::postgres::index_scan::Scanner; +use crate::prelude::*; +use std::ptr::null_mut; + +type PlanstateTreeWalker = + unsafe extern "C" fn(*mut pgrx::pg_sys::PlanState, *mut libc::c_void) -> bool; + +pub unsafe fn post_executor_start(query_desc: *mut pgrx::pg_sys::QueryDesc) { + // Before Postgres 16, type defination of `PlanstateTreeWalker` in the source code is incorrect. 
+ let planstate = (*query_desc).planstate; + let context = null_mut(); + rewrite_plan_state(planstate, context); +} + +pub unsafe fn pre_process_utility(pstmt: *mut pgrx::pg_sys::PlannedStmt) { + unsafe { + let utility_statement = pgrx::PgBox::from_pg((*pstmt).utilityStmt); + + let is_drop = pgrx::is_a(utility_statement.as_ptr(), pgrx::pg_sys::NodeTag_T_DropStmt); + + if is_drop { + let stat_drop = + pgrx::PgBox::from_pg(utility_statement.as_ptr() as *mut pgrx::pg_sys::DropStmt); + + match stat_drop.removeType { + pgrx::pg_sys::ObjectType_OBJECT_TABLE | pgrx::pg_sys::ObjectType_OBJECT_INDEX => { + let objects = pgrx::PgList::::from_pg(stat_drop.objects); + for object in objects.iter_ptr() { + let mut rel = std::ptr::null_mut(); + let address = pgrx::pg_sys::get_object_address( + stat_drop.removeType, + object, + &mut rel, + pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE, + stat_drop.missing_ok, + ); + + if address.objectId == pgrx::pg_sys::InvalidOid { + continue; + } + + match stat_drop.removeType { + pgrx::pg_sys::ObjectType_OBJECT_TABLE => { + // Memory leak here? + let list = pgrx::pg_sys::RelationGetIndexList(rel); + let list = pgrx::PgList::::from_pg(list); + for index in list.iter_oid() { + drop_if_commit(Id::from_sys(index)); + } + pgrx::pg_sys::relation_close( + rel, + pgrx::pg_sys::AccessExclusiveLock as _, + ); + } + pgrx::pg_sys::ObjectType_OBJECT_INDEX => { + drop_if_commit(Id::from_sys((*rel).rd_id)); + pgrx::pg_sys::relation_close( + rel, + pgrx::pg_sys::AccessExclusiveLock as _, + ); + } + _ => unreachable!(), + } + } + } + + _ => {} + } + } + } +} + +#[pgrx::pg_guard] +unsafe extern "C" fn rewrite_plan_state( + node: *mut pgrx::pg_sys::PlanState, + context: *mut libc::c_void, +) -> bool { + match (*node).type_ { + pgrx::pg_sys::NodeTag_T_IndexScanState => { + let node = node as *mut pgrx::pg_sys::IndexScanState; + let index_relation = (*node).iss_RelationDesc; + // Check the pointer of `amvalidate`. 
+ if (*(*index_relation).rd_indam).amvalidate == Some(super::index::amvalidate) { + // The logic is copied from Postgres source code. + if (*node).iss_ScanDesc.is_null() { + (*node).iss_ScanDesc = pgrx::pg_sys::index_beginscan( + (*node).ss.ss_currentRelation, + (*node).iss_RelationDesc, + (*(*node).ss.ps.state).es_snapshot, + (*node).iss_NumScanKeys, + (*node).iss_NumOrderByKeys, + ); + if (*node).iss_NumRuntimeKeys == 0 || (*node).iss_RuntimeKeysReady { + pgrx::pg_sys::index_rescan( + (*node).iss_ScanDesc, + (*node).iss_ScanKeys, + (*node).iss_NumScanKeys, + (*node).iss_OrderByKeys, + (*node).iss_NumOrderByKeys, + ); + } + // inject + let scanner = &mut *((*(*node).iss_ScanDesc).opaque as *mut Scanner); + let Scanner::Initial { + index_scan_state, .. + } = scanner + else { + unreachable!() + }; + *index_scan_state = Some(node); + } + } + } + _ => (), + } + let walker = std::mem::transmute::(rewrite_plan_state); + pgrx::pg_sys::planstate_tree_walker(node, Some(walker), context) +} diff --git a/src/postgres/hook_transaction.rs b/src/postgres/hook_transaction.rs new file mode 100644 index 000000000..f94a9222c --- /dev/null +++ b/src/postgres/hook_transaction.rs @@ -0,0 +1,53 @@ +use crate::ipc::client::Rpc; +use crate::ipc::connect; +use crate::prelude::*; +use std::cell::RefCell; +use std::collections::BTreeSet; + +#[thread_local] +static FLUSH_IF_COMMIT: RefCell> = RefCell::new(BTreeSet::new()); + +#[thread_local] +static DROP_IF_COMMIT: RefCell> = RefCell::new(BTreeSet::new()); + +#[thread_local] +static CLIENT: RefCell> = RefCell::new(None); + +pub fn aborting() { + *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); + *DROP_IF_COMMIT.borrow_mut() = BTreeSet::new(); +} + +pub fn committing() { + client(|mut rpc| { + for id in FLUSH_IF_COMMIT.borrow().iter().copied() { + rpc.flush(id).unwrap(); + } + + for id in DROP_IF_COMMIT.borrow().iter().copied() { + rpc.clean(id).unwrap(); + } + + rpc + }); + *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); + 
*DROP_IF_COMMIT.borrow_mut() = BTreeSet::new(); +} + +pub fn drop_if_commit(id: Id) { + DROP_IF_COMMIT.borrow_mut().insert(id); +} + +pub fn flush_if_commit(id: Id) { + FLUSH_IF_COMMIT.borrow_mut().insert(id); +} + +pub fn client(f: F) +where + F: FnOnce(Rpc) -> Rpc, +{ + let mut guard = CLIENT.borrow_mut(); + let client = guard.take().unwrap_or_else(|| connect()); + let client = f(client); + *guard = Some(client); +} diff --git a/src/postgres/hooks.rs b/src/postgres/hooks.rs index d5852d34e..a2a7d2797 100644 --- a/src/postgres/hooks.rs +++ b/src/postgres/hooks.rs @@ -1,17 +1,27 @@ -use crate::bgworker::Client; -use crate::postgres::gucs::PORT; -use crate::prelude::*; -use parking_lot::{Mutex, MutexGuard}; -use pgrx::once_cell::sync::Lazy; use pgrx::PgHooks; -use std::cell::RefCell; -use std::collections::BTreeSet; struct Hooks; static mut HOOKS: Hooks = Hooks; impl PgHooks for Hooks { + fn executor_start( + &mut self, + query_desc: pgrx::PgBox, + eflags: i32, + prev_hook: fn( + query_desc: pgrx::PgBox, + eflags: i32, + ) -> pgrx::HookResult<()>, + ) -> pgrx::HookResult<()> { + let pointer = query_desc.as_ptr(); + let result = prev_hook(query_desc, eflags); + unsafe { + super::hook_executor::post_executor_start(pointer); + } + result + } + fn process_utility_hook( &mut self, pstmt: pgrx::PgBox, @@ -34,133 +44,29 @@ impl PgHooks for Hooks { ) -> pgrx::HookResult<()>, ) -> pgrx::HookResult<()> { unsafe { - let utility_statement = pgrx::PgBox::from_pg(pstmt.utilityStmt); - - let is_drop = pgrx::is_a(utility_statement.as_ptr(), pgrx::pg_sys::NodeTag_T_DropStmt); - - if is_drop { - let stat_drop = - pgrx::PgBox::from_pg(utility_statement.as_ptr() as *mut pgrx::pg_sys::DropStmt); - - match stat_drop.removeType { - pgrx::pg_sys::ObjectType_OBJECT_TABLE - | pgrx::pg_sys::ObjectType_OBJECT_INDEX => { - let objects = - pgrx::PgList::::from_pg(stat_drop.objects); - for object in objects.iter_ptr() { - let mut rel = std::ptr::null_mut(); - let address = 
pgrx::pg_sys::get_object_address( - stat_drop.removeType, - object, - &mut rel, - pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE, - stat_drop.missing_ok, - ); - - if address.objectId == pgrx::pg_sys::InvalidOid { - continue; - } - - match stat_drop.removeType { - pgrx::pg_sys::ObjectType_OBJECT_TABLE => { - // Memory leak here? - let list = pgrx::pg_sys::RelationGetIndexList(rel); - let list = pgrx::PgList::::from_pg(list); - for index in list.iter_oid() { - drop_if_commit(Id::from_sys(index)); - } - pgrx::pg_sys::relation_close( - rel, - pgrx::pg_sys::AccessExclusiveLock as _, - ); - } - pgrx::pg_sys::ObjectType_OBJECT_INDEX => { - drop_if_commit(Id::from_sys((*rel).rd_id)); - pgrx::pg_sys::relation_close( - rel, - pgrx::pg_sys::AccessExclusiveLock as _, - ); - } - _ => unreachable!(), - } - } - } - - _ => {} - } - prev_hook( - pstmt, - query_string, - read_only_tree, - context, - params, - query_env, - dest, - completion_tag, - ) - } else { - prev_hook( - pstmt, - query_string, - read_only_tree, - context, - params, - query_env, - dest, - completion_tag, - ) - } + super::hook_executor::pre_process_utility(pstmt.as_ptr()); + prev_hook( + pstmt, + query_string, + read_only_tree, + context, + params, + query_env, + dest, + completion_tag, + ) } } } -pub fn drop_if_commit(id: Id) { - DROP_IF_COMMIT.borrow_mut().insert(id); -} - -pub fn flush_if_commit(id: Id) { - FLUSH_IF_COMMIT.borrow_mut().insert(id); -} - -pub fn client() -> MutexGuard<'static, Lazy> { - CLIENT.lock() -} - -#[thread_local] -static FLUSH_IF_COMMIT: RefCell> = RefCell::new(BTreeSet::new()); - -#[thread_local] -static DROP_IF_COMMIT: RefCell> = RefCell::new(BTreeSet::new()); - -static CLIENT: Mutex> = Mutex::new(Lazy::new(lazy_client)); - -fn lazy_client() -> Client { - let stream = std::net::TcpStream::connect(("0.0.0.0", PORT.get() as u16)).unwrap(); - Client::new(stream).unwrap() -} - #[pgrx::pg_guard] unsafe extern "C" fn xact_callback(event: pgrx::pg_sys::XactEvent, _data: 
pgrx::void_mut_ptr) { match event { pgrx::pg_sys::XactEvent_XACT_EVENT_ABORT => { - *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); - *DROP_IF_COMMIT.borrow_mut() = BTreeSet::new(); - *CLIENT.lock() = Lazy::new(lazy_client); + super::hook_transaction::aborting(); } pgrx::pg_sys::XactEvent_XACT_EVENT_PRE_COMMIT => { - let mut client = CLIENT.lock(); - let client = &mut *client; - - for id in FLUSH_IF_COMMIT.borrow().iter().copied() { - client.flush(id).unwrap(); - } - - for id in DROP_IF_COMMIT.borrow().iter().copied() { - client.drop(id).unwrap(); - } - - *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); - *DROP_IF_COMMIT.borrow_mut() = BTreeSet::new(); + super::hook_transaction::committing(); } _ => {} } diff --git a/src/postgres/index.rs b/src/postgres/index.rs index 8b81aeed5..25e8661ee 100644 --- a/src/postgres/index.rs +++ b/src/postgres/index.rs @@ -1,71 +1,18 @@ -use crate::bgworker::ClientBuild; +use super::index_build; +use super::index_scan; +use super::index_setup; +use super::index_update; use crate::postgres::datatype::VectorInput; -use crate::postgres::datatype::VectorTypmod; -use crate::postgres::gucs::K; -use crate::postgres::hooks::client; -use crate::postgres::hooks::flush_if_commit; use crate::prelude::*; -use pg_sys::Datum; -use pgrx::prelude::*; -use serde::{Deserialize, Serialize}; use std::cell::Cell; -use std::ffi::CStr; -use validator::Validate; #[thread_local] -static RELOPT_KIND: Cell = Cell::new(0); - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct PartialOptions { - capacity: usize, - #[serde(default = "PartialOptions::default_size_ram")] - size_ram: usize, - #[serde(default = "PartialOptions::default_size_disk")] - size_disk: usize, - storage_vectors: Storage, - algorithm: AlgorithmOptions, -} - -impl PartialOptions { - fn default_size_ram() -> usize { - 16384 - } - fn default_size_disk() -> usize { - 16384 - } -} - -#[derive(Copy, Clone, Debug)] -#[repr(C)] -struct PartialOptionsHelper { - vl_len_: i32, - offset: i32, -} - 
-impl PartialOptionsHelper { - unsafe fn get(this: *const Self) -> PartialOptions { - if (*this).offset == 0 { - panic!("`options` cannot be null.") - } else { - let ptr = (this as *const std::os::raw::c_char).offset((*this).offset as isize); - toml::from_str::(CStr::from_ptr(ptr).to_str().unwrap()).unwrap() - } - } -} - -struct BuildState<'a> { - build: ClientBuild<'a>, - ntuples: f64, -} - -struct ScanState { - data: Option>, -} +static RELOPT_KIND: Cell = Cell::new(0); pub unsafe fn init() { - use pg_sys::AsPgCStr; - RELOPT_KIND.set(pg_sys::add_reloption_kind()); - pg_sys::add_string_reloption( + use pgrx::pg_sys::AsPgCStr; + RELOPT_KIND.set(pgrx::pg_sys::add_reloption_kind()); + pgrx::pg_sys::add_string_reloption( RELOPT_KIND.get(), "options".as_pg_cstr(), "".as_pg_cstr(), @@ -73,20 +20,25 @@ pub unsafe fn init() { None, #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] { - pg_sys::AccessExclusiveLock as pg_sys::LOCKMODE + pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE }, ); } -#[pg_extern(sql = " +#[pgrx::pg_extern(sql = " CREATE OR REPLACE FUNCTION vectors_amhandler(internal) RETURNS index_am_handler PARALLEL SAFE IMMUTABLE STRICT LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@'; CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler; COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method'; ", requires = ["vector"])] -fn vectors_amhandler(_fcinfo: pg_sys::FunctionCallInfo) -> PgBox { - let mut am_routine = - unsafe { PgBox::::alloc_node(pg_sys::NodeTag_T_IndexAmRoutine) }; +fn vectors_amhandler( + _fcinfo: pgrx::pg_sys::FunctionCallInfo, +) -> pgrx::PgBox { + let mut am_routine = unsafe { + pgrx::PgBox::::alloc_node( + pgrx::pg_sys::NodeTag_T_IndexAmRoutine, + ) + }; am_routine.amstrategies = 1; am_routine.amsupport = 0; @@ -126,20 +78,23 @@ fn vectors_amhandler(_fcinfo: pg_sys::FunctionCallInfo) -> PgBox bool { - validate_opclass(opclass_oid); +#[pgrx::pg_guard] +pub unsafe extern "C" 
fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool { + index_setup::convert_opclass_to_distance(opclass_oid); true } #[cfg(any(feature = "pg11", feature = "pg12"))] -#[pg_guard] -unsafe extern "C" fn amoptions(reloptions: pg_sys::Datum, validate: bool) -> *mut pg_sys::bytea { +#[pgrx::pg_guard] +pub unsafe extern "C" fn amoptions( + reloptions: pg_sys::Datum, + validate: bool, +) -> *mut pg_sys::bytea { use pg_sys::AsPgCStr; let tab: &[pg_sys::relopt_parse_elt] = &[pg_sys::relopt_parse_elt { optname: "options".as_pg_cstr(), opttype: pg_sys::relopt_type_RELOPT_TYPE_STRING, - offset: memoffset::offset_of!(PartialOptionsHelper, offset) as i32, + offset: index_setup::helper_offset() as i32, }]; let mut noptions = 0; let options = pg_sys::parseRelOptions(reloptions, validate, RELOPT_KIND.get(), &mut noptions); @@ -149,14 +104,10 @@ unsafe extern "C" fn amoptions(reloptions: pg_sys::Datum, validate: bool) -> *mu for relopt in std::slice::from_raw_parts_mut(options, noptions as usize) { relopt.gen.as_mut().unwrap().lockmode = pg_sys::AccessExclusiveLock as pg_sys::LOCKMODE; } - let rdopts = pg_sys::allocateReloptStruct( - std::mem::size_of::(), - options, - noptions, - ); + let rdopts = pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions); pg_sys::fillRelOptions( rdopts, - std::mem::size_of::(), + index_setup::helper_size(), options, noptions, validate, @@ -168,33 +119,37 @@ unsafe extern "C" fn amoptions(reloptions: pg_sys::Datum, validate: bool) -> *mu } #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] -#[pg_guard] -unsafe extern "C" fn amoptions(reloptions: pg_sys::Datum, validate: bool) -> *mut pg_sys::bytea { - use pg_sys::AsPgCStr; - let tab: &[pg_sys::relopt_parse_elt] = &[pg_sys::relopt_parse_elt { +#[pgrx::pg_guard] +pub unsafe extern "C" fn amoptions( + reloptions: pgrx::pg_sys::Datum, + validate: bool, +) -> *mut pgrx::pg_sys::bytea { + use pgrx::pg_sys::AsPgCStr; + + let tab: 
&[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { optname: "options".as_pg_cstr(), - opttype: pg_sys::relopt_type_RELOPT_TYPE_STRING, - offset: memoffset::offset_of!(PartialOptionsHelper, offset) as i32, + opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING, + offset: index_setup::helper_offset() as i32, }]; - let rdopts = pg_sys::build_reloptions( + let rdopts = pgrx::pg_sys::build_reloptions( reloptions, validate, RELOPT_KIND.get(), - std::mem::size_of::(), + index_setup::helper_size(), tab.as_ptr(), tab.len() as _, ); - rdopts as *mut pg_sys::bytea + rdopts as *mut pgrx::pg_sys::bytea } -#[pg_guard] -unsafe extern "C" fn amcostestimate( - _root: *mut pg_sys::PlannerInfo, - path: *mut pg_sys::IndexPath, +#[pgrx::pg_guard] +pub unsafe extern "C" fn amcostestimate( + _root: *mut pgrx::pg_sys::PlannerInfo, + path: *mut pgrx::pg_sys::IndexPath, _loop_count: f64, - index_startup_cost: *mut pg_sys::Cost, - index_total_cost: *mut pg_sys::Cost, - index_selectivity: *mut pg_sys::Selectivity, + index_startup_cost: *mut pgrx::pg_sys::Cost, + index_total_cost: *mut pgrx::pg_sys::Cost, + index_selectivity: *mut pgrx::pg_sys::Selectivity, index_correlation: *mut f64, index_pages: *mut f64, ) { @@ -213,88 +168,27 @@ unsafe extern "C" fn amcostestimate( *index_pages = 0.0; } -#[pg_guard] -unsafe extern "C" fn ambuild( - heap_relation: pg_sys::Relation, - index_relation: pg_sys::Relation, - _index_info: *mut pg_sys::IndexInfo, -) -> *mut pg_sys::IndexBuildResult { - let oid = (*index_relation).rd_id; - let id = Id::from_sys(oid); - flush_if_commit(id); - let options = options(index_relation); - let mut client = client(); - let mut state = BuildState { - build: client.build(id, options).unwrap(), - ntuples: 0.0, - }; - #[cfg(any(feature = "pg11", feature = "pg12"))] - #[pg_guard] - unsafe extern "C" fn callback( - _index_relation: pg_sys::Relation, - htup: pg_sys::HeapTuple, - values: *mut pg_sys::Datum, - is_null: *mut bool, - _tuple_is_alive: bool, - 
state: *mut std::os::raw::c_void, - ) { - let ctid = &(*htup).t_self; - let state = &mut *(state as *mut BuildState); - let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let data = ( - pgvector.to_vec().into_boxed_slice(), - Pointer::from_sys(*ctid), - ); - state.build.next(data).unwrap(); - state.ntuples += 1.0; - } - #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] - #[pg_guard] - unsafe extern "C" fn callback( - _index_relation: pg_sys::Relation, - ctid: pg_sys::ItemPointer, - values: *mut pg_sys::Datum, - is_null: *mut bool, - _tuple_is_alive: bool, - state: *mut std::os::raw::c_void, - ) { - let state = &mut *(state as *mut BuildState); - let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let data = ( - pgvector.to_vec().into_boxed_slice(), - Pointer::from_sys(*ctid), - ); - state.build.next(data).unwrap(); - } - let index_info = pg_sys::BuildIndexInfo(index_relation); - pg_sys::IndexBuildHeapScan( - heap_relation, - index_relation, - index_info, - Some(callback), - &mut state, - ); - state.build.finish().unwrap(); - let mut result = PgBox::::alloc0(); - result.heap_tuples = state.ntuples; +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambuild( + heap_relation: pgrx::pg_sys::Relation, + index_relation: pgrx::pg_sys::Relation, + index_info: *mut pgrx::pg_sys::IndexInfo, +) -> *mut pgrx::pg_sys::IndexBuildResult { + index_build::build(index_relation, Some((heap_relation, index_info))); + let mut result = pgrx::PgBox::::alloc0(); + result.heap_tuples = 0.0; result.index_tuples = 0.0; result.into_pg() } -#[pg_guard] -unsafe extern "C" fn ambuildempty(index_relation: pg_sys::Relation) { - let oid = (*index_relation).rd_id; - let id = Id::from_sys(oid); - flush_if_commit(id); - let options = options(index_relation); - let mut client = client(); - let build = client.build(id, options).unwrap(); - build.finish().unwrap(); +#[pgrx::pg_guard] +pub unsafe extern "C" fn 
ambuildempty(index_relation: pgrx::pg_sys::Relation) { + index_build::build(index_relation, None); } #[cfg(any(feature = "pg11", feature = "pg12", feature = "pg13"))] #[pg_guard] -unsafe extern "C" fn aminsert( +pub unsafe extern "C" fn aminsert( index_relation: pg_sys::Relation, values: *mut pg_sys::Datum, is_null: *mut bool, @@ -303,137 +197,71 @@ unsafe extern "C" fn aminsert( _check_unique: pg_sys::IndexUniqueCheck, _index_info: *mut pg_sys::IndexInfo, ) -> bool { - _aminsert(index_relation, values, is_null, heap_tid) + use pgrx::FromDatum; + let oid = (*index_relation).rd_id; + let id = Id::from_sys(oid); + let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); + let vector = vector.data().to_vec().into_boxed_slice(); + index_update::update_insert(id, vector, heap_tid); + true } #[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16"))] -#[pg_guard] -unsafe extern "C" fn aminsert( - index_relation: pg_sys::Relation, - values: *mut pg_sys::Datum, +#[pgrx::pg_guard] +pub unsafe extern "C" fn aminsert( + index_relation: pgrx::pg_sys::Relation, + values: *mut pgrx::pg_sys::Datum, is_null: *mut bool, - heap_tid: pg_sys::ItemPointer, - _heap_relation: pg_sys::Relation, - _check_unique: pg_sys::IndexUniqueCheck, + heap_tid: pgrx::pg_sys::ItemPointer, + _heap_relation: pgrx::pg_sys::Relation, + _check_unique: pgrx::pg_sys::IndexUniqueCheck, _index_unchanged: bool, - _index_info: *mut pg_sys::IndexInfo, -) -> bool { - _aminsert(index_relation, values, is_null, heap_tid) -} - -#[pg_guard] -unsafe extern "C" fn _aminsert( - index_relation: pg_sys::Relation, - values: *mut pg_sys::Datum, - is_null: *mut bool, - heap_tid: pg_sys::ItemPointer, + _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { + use pgrx::FromDatum; let oid = (*index_relation).rd_id; let id = Id::from_sys(oid); - flush_if_commit(id); - let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let vector = 
pgvector.data().to_vec().into_boxed_slice(); - let p = Pointer::from_sys(*heap_tid); - client().insert(id, (vector, p)).unwrap(); + let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); + let vector = vector.data().to_vec().into_boxed_slice(); + index_update::update_insert(id, vector, heap_tid); true } -#[pg_guard] -unsafe extern "C" fn ambeginscan( - index_relation: pg_sys::Relation, +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambeginscan( + index_relation: pgrx::pg_sys::Relation, n_keys: std::os::raw::c_int, n_order_bys: std::os::raw::c_int, -) -> pg_sys::IndexScanDesc { - let mut scan = PgBox::from_pg(pg_sys::RelationGetIndexScan( - index_relation, - n_keys, - n_order_bys, - )); - - let state = ScanState { data: None }; - - scan.opaque = pgrx::PgMemoryContexts::CurrentMemoryContext.leak_and_drop_on_delete(state) - as pgrx::void_mut_ptr; - - scan.into_pg() +) -> pgrx::pg_sys::IndexScanDesc { + index_scan::make_scan(index_relation, n_keys, n_order_bys) } -#[pg_guard] -unsafe extern "C" fn amrescan( - scan: pg_sys::IndexScanDesc, - keys: pg_sys::ScanKey, +#[pgrx::pg_guard] +pub unsafe extern "C" fn amrescan( + scan: pgrx::pg_sys::IndexScanDesc, + keys: pgrx::pg_sys::ScanKey, n_keys: std::os::raw::c_int, - orderbys: pg_sys::ScanKey, + orderbys: pgrx::pg_sys::ScanKey, n_orderbys: std::os::raw::c_int, ) { - let oid = (*(*scan).indexRelation).rd_id; - let id = Id::from_sys(oid); - if n_orderbys > 0 { - let orderbys = std::slice::from_raw_parts_mut(orderbys, n_orderbys as usize); - std::ptr::copy(orderbys.as_ptr(), (*scan).orderByData, orderbys.len()); - } - if n_keys > 0 { - let keys = std::slice::from_raw_parts_mut(keys, n_keys as usize); - std::ptr::copy(keys.as_ptr(), (*scan).keyData, keys.len()); - } - if (*scan).numberOfOrderBys > 0 { - use pg_sys::{palloc, palloc0}; - let size_datum = std::mem::size_of::(); - let size_bool = std::mem::size_of::(); - let orderbyvals = palloc0(size_datum * (*scan).numberOfOrderBys as usize) as *mut Datum; 
- let orderbynulls = palloc(size_bool * (*scan).numberOfOrderBys as usize) as *mut bool; - orderbynulls.write_bytes(1, (*scan).numberOfOrderBys as usize); - (*scan).xs_orderbyvals = orderbyvals; - (*scan).xs_orderbynulls = orderbynulls; - } - assert!(n_orderbys == 1, "Not supported."); - let state = &mut *((*scan).opaque as *mut ScanState); - let scan_vector = (*orderbys.add(0)).sk_argument; - let dt_vector = VectorInput::from_datum(scan_vector, false).unwrap(); - let vector = dt_vector.data(); - state.data = { - let k = K.get() as _; - let mut data = client() - .search(id, (vector.to_vec().into_boxed_slice(), k)) - .unwrap(); - data.reverse(); - Some(data) - }; + index_scan::start_scan(scan, keys, n_keys, orderbys, n_orderbys); } -#[pg_guard] -unsafe extern "C" fn amgettuple( - scan: pg_sys::IndexScanDesc, - _direction: pg_sys::ScanDirection, +#[pgrx::pg_guard] +pub unsafe extern "C" fn amgettuple( + scan: pgrx::pg_sys::IndexScanDesc, + direction: pgrx::pg_sys::ScanDirection, ) -> bool { - (*scan).xs_recheck = false; - (*scan).xs_recheckorderby = false; - let state = &mut *((*scan).opaque as *mut ScanState); - if let Some(data) = state.data.as_mut() { - if let Some(p) = data.pop() { - #[cfg(any(feature = "pg11"))] - { - (*scan).xs_ctup.t_self = p.into_sys(); - } - #[cfg(not(feature = "pg11"))] - { - (*scan).xs_heaptid = p.into_sys(); - } - true - } else { - false - } - } else { - unreachable!() - } + assert!(direction == pgrx::pg_sys::ScanDirection_ForwardScanDirection); + index_scan::next_scan(scan) } -#[pg_guard] -extern "C" fn amendscan(_scan: pg_sys::IndexScanDesc) {} +#[pgrx::pg_guard] +pub extern "C" fn amendscan(_scan: pgrx::pg_sys::IndexScanDesc) {} #[cfg(any(feature = "pg11", feature = "pg12"))] #[pg_guard] -unsafe extern "C" fn ambulkdelete( +pub unsafe extern "C" fn ambulkdelete( info: *mut pg_sys::IndexVacuumInfo, _stats: *mut pg_sys::IndexBulkDeleteResult, _callback: pg_sys::IndexBulkDeleteCallback, @@ -464,7 +292,6 @@ unsafe extern "C" fn 
ambulkdelete( } let oid = (*(*info).index).rd_id; let id = Id::from_sys(oid); - flush_if_commit(id); let items = callback_state as *mut LVRelStats; let deletes = std::slice::from_raw_parts((*items).dead_tuples, (*items).num_dead_tuples as usize) @@ -472,16 +299,14 @@ unsafe extern "C" fn ambulkdelete( .copied() .map(Pointer::from_sys) .collect::>(); - for message in deletes { - client().delete(id, message).unwrap(); - } + update_delete(id, deletes); let result = PgBox::::alloc0(); result.into_pg() } #[cfg(any(feature = "pg13", feature = "pg14"))] #[pg_guard] -unsafe extern "C" fn ambulkdelete( +pub unsafe extern "C" fn ambulkdelete( info: *mut pg_sys::IndexVacuumInfo, _stats: *mut pg_sys::IndexBulkDeleteResult, _callback: pg_sys::IndexBulkDeleteCallback, @@ -496,7 +321,6 @@ unsafe extern "C" fn ambulkdelete( } let oid = (*(*info).index).rd_id; let id = Id::from_sys(oid); - flush_if_commit(id); let items = callback_state as *mut LVDeadTuples; let deletes = (*items) .itemptrs @@ -505,25 +329,22 @@ unsafe extern "C" fn ambulkdelete( .copied() .map(Pointer::from_sys) .collect::>(); - for message in deletes { - client().delete(id, message).unwrap(); - } + update_delete(id, deletes); let result = PgBox::::alloc0(); result.into_pg() } #[cfg(any(feature = "pg15", feature = "pg16"))] -#[pg_guard] -unsafe extern "C" fn ambulkdelete( - info: *mut pg_sys::IndexVacuumInfo, - _stats: *mut pg_sys::IndexBulkDeleteResult, - _callback: pg_sys::IndexBulkDeleteCallback, +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambulkdelete( + info: *mut pgrx::pg_sys::IndexVacuumInfo, + _stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, + _callback: pgrx::pg_sys::IndexBulkDeleteCallback, callback_state: *mut std::os::raw::c_void, -) -> *mut pg_sys::IndexBulkDeleteResult { +) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { let oid = (*(*info).index).rd_id; let id = Id::from_sys(oid); - flush_if_commit(id); - let items = callback_state as *mut pg_sys::VacDeadItems; + let items = callback_state as *mut 
pgrx::pg_sys::VacDeadItems; let deletes = (*items) .items .as_slice((*items).num_items as usize) @@ -531,112 +352,16 @@ unsafe extern "C" fn ambulkdelete( .copied() .map(Pointer::from_sys) .collect::>(); - for message in deletes { - client().delete(id, message).unwrap(); - } - let result = PgBox::::alloc0(); + index_update::update_delete(id, deletes); + let result = pgrx::PgBox::::alloc0(); result.into_pg() } -#[pg_guard] -unsafe extern "C" fn amvacuumcleanup( - _info: *mut pg_sys::IndexVacuumInfo, - _stats: *mut pg_sys::IndexBulkDeleteResult, -) -> *mut pg_sys::IndexBulkDeleteResult { - let result = PgBox::::alloc0(); +#[pgrx::pg_guard] +pub unsafe extern "C" fn amvacuumcleanup( + _info: *mut pgrx::pg_sys::IndexVacuumInfo, + _stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, +) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { + let result = pgrx::PgBox::::alloc0(); result.into_pg() } - -unsafe fn options(index_relation: pg_sys::Relation) -> Options { - let nkeys = (*(*index_relation).rd_index).indnkeyatts; - let opfamily = (*index_relation).rd_opfamily.read(); - let typmod = (*(*(*index_relation).rd_att).attrs.as_ptr().add(0)).type_mod(); - let options = (*index_relation).rd_options as *mut PartialOptionsHelper; - if nkeys != 1 { - panic!("Only supports exactly one key column."); - } - if options.is_null() { - panic!("The options is null."); - } - let typmod = VectorTypmod::parse_from_i32(typmod).unwrap(); - let options = PartialOptionsHelper::get(options); - let options = Options { - dims: typmod.dims().expect("Column does not have dimensions."), - distance: validate_opfamily(opfamily), - capacity: options.capacity, - size_disk: options.size_disk, - size_ram: options.size_ram, - storage_vectors: options.storage_vectors, - algorithm: options.algorithm, - }; - options.validate().expect("The options is invalid."); - options -} - -fn regoperatorin(name: &str) -> pg_sys::Oid { - let cstr = std::ffi::CString::new(name).expect("specified name has embedded NULL byte"); - 
unsafe { - pgrx::direct_function_call::( - pg_sys::regoperatorin, - &[cstr.as_c_str().into_datum()], - ) - .expect("operator lookup returned NULL") - } -} - -unsafe fn validate_opclass(opclass: pg_sys::Oid) -> Distance { - let tup = pg_sys::SearchSysCache1(pg_sys::SysCacheIdentifier_CLAOID as _, opclass.into()); - if tup.is_null() { - panic!("cache lookup failed for operator class {opclass}"); - } - let classform = pg_sys::GETSTRUCT(tup).cast::(); - let opfamily = (*classform).opcfamily; - let distance = validate_opfamily(opfamily); - pg_sys::ReleaseSysCache(tup); - distance -} - -unsafe fn validate_opfamily(opfamily: pg_sys::Oid) -> Distance { - let tup = pg_sys::SearchSysCache1(pg_sys::SysCacheIdentifier_OPFAMILYOID as _, opfamily.into()); - if tup.is_null() { - panic!("cache lookup failed for operator family {opfamily}"); - } - let oprlist = pg_sys::SearchSysCacheList( - pg_sys::SysCacheIdentifier_AMOPSTRATEGY as _, - 1, - opfamily.into(), - 0.into(), - 0.into(), - ); - assert!((*oprlist).n_members == 1); - let member = (*oprlist).members.as_slice(1)[0]; - let oprtup = &mut (*member).tuple; - let oprform = pg_sys::GETSTRUCT(oprtup).cast::(); - assert!((*oprform).amopstrategy == 1); - assert!((*oprform).amoppurpose == pg_sys::AMOP_ORDER as i8); - let opropr = (*oprform).amopopr; - let distance = if opropr == regoperatorin("<->(vector,vector)") { - Distance::L2 - } else if opropr == regoperatorin("<#>(vector,vector)") { - Distance::Dot - } else if opropr == regoperatorin("<=>(vector,vector)") { - Distance::Cosine - } else { - panic!("Unsupported operator.") - }; - pg_sys::ReleaseCatCacheList(oprlist); - pg_sys::ReleaseSysCache(tup); - distance -} - -#[pg_extern(strict)] -unsafe fn vectors_load(oid: pg_sys::Oid) { - let id = Id::from_sys(oid); - client().load(id).unwrap(); -} - -#[pg_extern(strict)] -unsafe fn vectors_unload(oid: pg_sys::Oid) { - let id = Id::from_sys(oid); - client().unload(id).unwrap(); -} diff --git a/src/postgres/index_build.rs 
b/src/postgres/index_build.rs new file mode 100644 index 000000000..16ebdcded --- /dev/null +++ b/src/postgres/index_build.rs @@ -0,0 +1,88 @@ +use super::hook_transaction::{client, flush_if_commit}; +use crate::ipc::client::{BuildHandle, BuildHandler}; +use crate::postgres::index_setup::options; +use crate::prelude::*; + +pub struct Builder { + pub build_handler: Option, + pub ntuples: f64, +} + +pub unsafe fn build( + index: pgrx::pg_sys::Relation, + data: Option<(pgrx::pg_sys::Relation, *mut pgrx::pg_sys::IndexInfo)>, +) { + let oid = (*index).rd_id; + let id = Id::from_sys(oid); + flush_if_commit(id); + let options = options(index); + client(|rpc| { + let build_handler = rpc.build(id, options).unwrap(); + let mut builder = Builder { + build_handler: Some(build_handler), + ntuples: 0.0, + }; + if let Some((heap, index_info)) = data { + pgrx::pg_sys::IndexBuildHeapScan(heap, index, index_info, Some(callback), &mut builder); + } + let build_handler = builder.build_handler.take().unwrap(); + let BuildHandle::Next { x } = build_handler.handle().unwrap() else { + panic!("Invaild state.") + }; + let build_handler = x.leave(None).unwrap(); + let BuildHandle::Leave { x } = build_handler.handle().unwrap() else { + panic!("Invaild state.") + }; + x + }); +} + +#[cfg(any(feature = "pg11", feature = "pg12"))] +#[pg_guard] +unsafe extern "C" fn callback( + _index_relation: pg_sys::Relation, + htup: pg_sys::HeapTuple, + values: *mut pg_sys::Datum, + is_null: *mut bool, + _tuple_is_alive: bool, + state: *mut std::os::raw::c_void, +) { + use super::datatype::VectorInput; + use pgrx::FromDatum; + + let ctid = &(*htup).t_self; + let state = &mut *(state as *mut Builder); + let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); + let data = ( + pgvector.to_vec().into_boxed_slice(), + Pointer::from_sys(*ctid), + ); + (*state.build).build.next(data).unwrap(); + state.ntuples += 1.0; +} + +#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", 
feature = "pg16"))] +#[pgrx::pg_guard] +unsafe extern "C" fn callback( + _index_relation: pgrx::pg_sys::Relation, + ctid: pgrx::pg_sys::ItemPointer, + values: *mut pgrx::pg_sys::Datum, + is_null: *mut bool, + _tuple_is_alive: bool, + state: *mut std::os::raw::c_void, +) { + use super::datatype::VectorInput; + use pgrx::FromDatum; + + let state = &mut *(state as *mut Builder); + let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); + let data = ( + pgvector.to_vec().into_boxed_slice(), + Pointer::from_sys(*ctid), + ); + let BuildHandle::Next { x } = state.build_handler.take().unwrap().handle().unwrap() else { + panic!("Invaild state.") + }; + state.build_handler = Some(x.leave(Some(data)).unwrap()); + state.ntuples += 1.0; +} diff --git a/src/postgres/index_scan.rs b/src/postgres/index_scan.rs new file mode 100644 index 000000000..c66f506b3 --- /dev/null +++ b/src/postgres/index_scan.rs @@ -0,0 +1,294 @@ +use crate::postgres::datatype::VectorInput; +use crate::postgres::gucs::K; +use crate::prelude::*; +use pgrx::FromDatum; + +use super::hook_transaction::client; + +#[derive(Debug, Clone)] +pub enum Scanner { + Initial { + // fields to be filled by amhandler and hook + vector: Option>, + index_scan_state: Option<*mut pgrx::pg_sys::IndexScanState>, + }, + Type0 { + data: Vec, + }, + Type1 { + index_scan_state: *mut pgrx::pg_sys::IndexScanState, + data: Vec, + }, +} + +pub unsafe fn make_scan( + index_relation: pgrx::pg_sys::Relation, + n_keys: std::os::raw::c_int, + n_orderbys: std::os::raw::c_int, +) -> pgrx::pg_sys::IndexScanDesc { + use pgrx::PgMemoryContexts; + + assert!(n_keys == 0); + assert!(n_orderbys == 1); + + let scan = pgrx::pg_sys::RelationGetIndexScan(index_relation, n_keys, n_orderbys); + + (*scan).xs_recheck = false; + (*scan).xs_recheckorderby = false; + + let scanner = Scanner::Initial { + vector: None, + index_scan_state: None, + }; + + (*scan).opaque = 
PgMemoryContexts::CurrentMemoryContext.leak_and_drop_on_delete(scanner) as _; + + scan +} + +pub unsafe fn start_scan( + scan: pgrx::pg_sys::IndexScanDesc, + keys: pgrx::pg_sys::ScanKey, + n_keys: std::os::raw::c_int, + orderbys: pgrx::pg_sys::ScanKey, + n_orderbys: std::os::raw::c_int, +) { + use Scanner::*; + + assert!((*scan).numberOfKeys == n_keys); + assert!((*scan).numberOfOrderBys == n_orderbys); + assert!(n_keys == 0); + assert!(n_orderbys == 1); + + if n_keys > 0 { + std::ptr::copy(keys, (*scan).keyData, n_keys as usize); + } + if n_orderbys > 0 { + std::ptr::copy(orderbys, (*scan).orderByData, n_orderbys as usize); + } + if n_orderbys > 0 { + let size = std::mem::size_of::(); + let size = size * (*scan).numberOfOrderBys as usize; + let data = pgrx::pg_sys::palloc0(size) as *mut _; + (*scan).xs_orderbyvals = data; + } + if n_orderbys > 0 { + let size = std::mem::size_of::(); + let size = size * (*scan).numberOfOrderBys as usize; + let data = pgrx::pg_sys::palloc(size) as *mut bool; + data.write_bytes(1, (*scan).numberOfOrderBys as usize); + (*scan).xs_orderbynulls = data; + } + let orderby = orderbys.add(0); + let argument = (*orderby).sk_argument; + let vector = VectorInput::from_datum(argument, false).unwrap(); + let vector = vector.to_vec().into_boxed_slice(); + + let last = (*((*scan).opaque as *mut Scanner)).clone(); + let scanner = (*scan).opaque as *mut Scanner; + + match last { + Initial { + index_scan_state, .. + } => { + *scanner = Initial { + vector: Some(vector), + index_scan_state, + }; + } + Type0 { data: _ } => { + *scanner = Initial { + vector: Some(vector), + index_scan_state: None, + }; + } + Type1 { + index_scan_state, + data: _, + } => { + *scanner = Initial { + vector: Some(vector), + index_scan_state: Some(index_scan_state), + }; + } + } +} + +pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { + let scanner = &mut *((*scan).opaque as *mut Scanner); + if matches!(scanner, Scanner::Initial { .. 
}) { + let Scanner::Initial { + vector, + index_scan_state, + } = std::mem::replace( + scanner, + Scanner::Initial { + vector: None, + index_scan_state: None, + }, + ) + else { + unreachable!() + }; + let oid = (*(*scan).indexRelation).rd_id; + let id = Id::from_sys(oid); + let vector = vector.expect("`rescan` is never called."); + if let Some(index_scan_state) = index_scan_state { + client(|rpc| { + let k = K.get() as _; + let mut handler = rpc.search(id, vector, k).unwrap(); + let mut res; + let rpc = loop { + use crate::ipc::client::SearchHandle::*; + match handler.handle().unwrap() { + Check { p, x } => { + let result = check(index_scan_state, p); + handler = x.leave(result).unwrap(); + } + Leave { result, x } => { + res = result; + break x; + } + } + }; + res.reverse(); + *scanner = Scanner::Type1 { + index_scan_state, + data: res, + }; + rpc + }); + } else { + pgrx::warning!("Fallback to post-filter."); + client(|rpc| { + let k = K.get() as _; + let mut handler = rpc.search(id, vector, k).unwrap(); + let mut res; + let rpc = loop { + use crate::ipc::client::SearchHandle::*; + match handler.handle().unwrap() { + Check { p: _, x } => { + handler = x.leave(true).unwrap(); + } + Leave { result, x } => { + res = result; + break x; + } + } + }; + res.reverse(); + *scanner = Scanner::Type0 { data: res }; + rpc + }); + } + } + match scanner { + Scanner::Initial { .. } => unreachable!(), + Scanner::Type0 { data } => { + if let Some(p) = data.pop() { + #[cfg(feature = "pg11")] + { + (*scan).xs_ctup.t_self = p.into_sys(); + } + #[cfg(not(feature = "pg11"))] + { + (*scan).xs_heaptid = p.into_sys(); + } + true + } else { + false + } + } + Scanner::Type1 { data, .. 
} => { + if let Some(p) = data.pop() { + #[cfg(feature = "pg11")] + { + (*scan).xs_ctup.t_self = p.into_sys(); + } + #[cfg(not(feature = "pg11"))] + { + (*scan).xs_heaptid = p.into_sys(); + } + true + } else { + false + } + } + } +} + +unsafe fn execute_boolean_qual( + state: *mut pgrx::pg_sys::ExprState, + econtext: *mut pgrx::pg_sys::ExprContext, +) -> bool { + use pgrx::PgMemoryContexts; + if state.is_null() { + return true; + } + assert!((*state).flags & pgrx::pg_sys::EEO_FLAG_IS_QUAL as u8 != 0); + let mut is_null = true; + pgrx::pg_sys::MemoryContextReset((*econtext).ecxt_per_tuple_memory); + let ret = PgMemoryContexts::For((*econtext).ecxt_per_tuple_memory) + .switch_to(|_| (*state).evalfunc.unwrap()(state, econtext, &mut is_null)); + assert!(!is_null); + bool::from_datum(ret, is_null).unwrap() +} + +unsafe fn check_quals(node: *mut pgrx::pg_sys::IndexScanState) -> bool { + let slot = (*node).ss.ss_ScanTupleSlot; + let econtext = (*node).ss.ps.ps_ExprContext; + (*econtext).ecxt_scantuple = slot; + if (*node).ss.ps.qual.is_null() { + return true; + } + let state = (*node).ss.ps.qual; + let econtext = (*node).ss.ps.ps_ExprContext; + execute_boolean_qual(state, econtext) +} + +unsafe fn check_mvcc(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool { + let scan_desc = (*node).iss_ScanDesc; + let heap_fetch = (*scan_desc).xs_heapfetch; + let index_relation = (*heap_fetch).rel; + let rd_tableam = (*index_relation).rd_tableam; + let snapshot = (*scan_desc).xs_snapshot; + let index_fetch_tuple = (*rd_tableam).index_fetch_tuple.unwrap(); + let mut all_dead = false; + let slot = (*node).ss.ss_ScanTupleSlot; + let mut heap_continue = false; + let found = index_fetch_tuple( + heap_fetch, + &mut p.into_sys(), + snapshot, + slot, + &mut heap_continue, + &mut all_dead, + ); + if found { + return true; + } + while heap_continue { + let found = index_fetch_tuple( + heap_fetch, + &mut p.into_sys(), + snapshot, + slot, + &mut heap_continue, + &mut all_dead, + ); + if 
found { + return true; + } + } + false +} + +unsafe fn check(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool { + if !check_mvcc(node, p) { + return false; + } + if !check_quals(node) { + return false; + } + true +} diff --git a/src/postgres/index_setup.rs b/src/postgres/index_setup.rs new file mode 100644 index 000000000..28f857e01 --- /dev/null +++ b/src/postgres/index_setup.rs @@ -0,0 +1,125 @@ +use crate::algorithms::AlgorithmOptions; +use crate::bgworker::index::IndexOptions; +use crate::bgworker::vectors::VectorsOptions; +use crate::postgres::datatype::VectorTypmod; +use crate::prelude::*; +use serde::{Deserialize, Serialize}; +use std::ffi::CStr; +use validator::Validate; + +pub fn helper_offset() -> usize { + memoffset::offset_of!(Helper, offset) +} + +pub fn helper_size() -> usize { + std::mem::size_of::() +} + +pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distance { + let opclass_cache_id = pgrx::pg_sys::SysCacheIdentifier_CLAOID as _; + let tuple = pgrx::pg_sys::SearchSysCache1(opclass_cache_id, opclass.into()); + assert!( + !tuple.is_null(), + "cache lookup failed for operator class {opclass}" + ); + let classform = pgrx::pg_sys::GETSTRUCT(tuple).cast::(); + let opfamily = (*classform).opcfamily; + let distance = convert_opfamily_to_distance(opfamily); + pgrx::pg_sys::ReleaseSysCache(tuple); + distance +} + +pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Distance { + let opfamily_cache_id = pgrx::pg_sys::SysCacheIdentifier_OPFAMILYOID as _; + let opstrategy_cache_id = pgrx::pg_sys::SysCacheIdentifier_AMOPSTRATEGY as _; + let tuple = pgrx::pg_sys::SearchSysCache1(opfamily_cache_id, opfamily.into()); + assert!( + !tuple.is_null(), + "cache lookup failed for operator family {opfamily}" + ); + let list = pgrx::pg_sys::SearchSysCacheList( + opstrategy_cache_id, + 1, + opfamily.into(), + 0.into(), + 0.into(), + ); + assert!((*list).n_members == 1); + let member = 
(*list).members.as_slice(1)[0]; + let member_tuple = &mut (*member).tuple; + let amop = pgrx::pg_sys::GETSTRUCT(member_tuple).cast::(); + assert!((*amop).amopstrategy == 1); + assert!((*amop).amoppurpose == pgrx::pg_sys::AMOP_ORDER as i8); + let operator = (*amop).amopopr; + let distance; + if operator == regoperatorin("<->(vector,vector)") { + distance = Distance::L2; + } else if operator == regoperatorin("<#>(vector,vector)") { + distance = Distance::Dot; + } else if operator == regoperatorin("<=>(vector,vector)") { + distance = Distance::Cosine; + } else { + panic!("Unsupported operator.") + }; + pgrx::pg_sys::ReleaseCatCacheList(list); + pgrx::pg_sys::ReleaseSysCache(tuple); + distance +} + +pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions { + let nkeysatts = (*(*index_relation).rd_index).indnkeyatts; + assert!(nkeysatts == 1, "Can not be built on multicolumns."); + // get distance + let opfamily = (*index_relation).rd_opfamily.read(); + let distance = convert_opfamily_to_distance(opfamily); + // get dims + let attrs = (*(*index_relation).rd_att).attrs.as_slice(1); + let attr = &attrs[0]; + let typmod = VectorTypmod::parse_from_i32(attr.type_mod()).unwrap(); + let dims = typmod.dims().expect("Column does not have dimensions."); + // get other options + let parsed = get_parsed_from_varlena((*index_relation).rd_options); + let options = IndexOptions { + dims, + distance, + capacity: parsed.capacity, + vectors: parsed.vectors, + algorithm: parsed.algorithm, + }; + options.validate().expect("The options is invalid."); + options +} + +#[derive(Copy, Clone, Debug)] +#[repr(C)] +struct Helper { + pub vl_len_: i32, + pub offset: i32, +} + +unsafe fn get_parsed_from_varlena(helper: *const pgrx::pg_sys::varlena) -> Parsed { + let helper = helper as *const Helper; + assert!((*helper).offset != 0, "`options` cannot be null."); + let ptr = (helper as *const libc::c_char).offset((*helper).offset as isize); + let cstr = CStr::from_ptr(ptr); + 
toml::from_str::(cstr.to_str().unwrap()).unwrap() +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct Parsed { + capacity: usize, + vectors: VectorsOptions, + algorithm: AlgorithmOptions, +} + +fn regoperatorin(name: &str) -> pgrx::pg_sys::Oid { + use pgrx::IntoDatum; + let cstr = std::ffi::CString::new(name).expect("specified name has embedded NULL byte"); + unsafe { + pgrx::direct_function_call::( + pgrx::pg_sys::regoperatorin, + &[cstr.as_c_str().into_datum()], + ) + .expect("operator lookup returned NULL") + } +} diff --git a/src/postgres/index_update.rs b/src/postgres/index_update.rs new file mode 100644 index 000000000..d12508b23 --- /dev/null +++ b/src/postgres/index_update.rs @@ -0,0 +1,21 @@ +use crate::postgres::hook_transaction::{client, flush_if_commit}; +use crate::prelude::*; + +pub unsafe fn update_insert(id: Id, vector: Box<[Scalar]>, tid: pgrx::pg_sys::ItemPointer) { + flush_if_commit(id); + let p = Pointer::from_sys(*tid); + client(|mut rpc| { + rpc.insert(id, (vector, p)).unwrap(); + rpc + }) +} + +pub fn update_delete(id: Id, deletes: Vec) { + flush_if_commit(id); + client(|mut rpc| { + for message in deletes { + rpc.delete(id, message).unwrap(); + } + rpc + }) +} diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index 5049643a8..43443b0bf 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -1,11 +1,16 @@ -mod datatype; -mod gucs; +mod casts; +pub mod datatype; +mod functions; +pub mod gucs; +mod hook_executor; +mod hook_transaction; mod hooks; mod index; - -pub use gucs::K; -pub use gucs::OPENAI_API_KEY_GUC; -pub use gucs::PORT; +mod index_build; +mod index_scan; +mod index_setup; +mod index_update; +mod operators; pub unsafe fn init() { self::gucs::init(); diff --git a/src/postgres/operators.rs b/src/postgres/operators.rs new file mode 100644 index 000000000..94957952e --- /dev/null +++ b/src/postgres/operators.rs @@ -0,0 +1,118 @@ +use crate::postgres::datatype::{Vector, VectorInput, VectorOutput}; +use 
crate::prelude::*; +use std::ops::Deref; + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(+)] +#[pgrx::commutator(+)] +fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + let n = lhs.len(); + let mut v = Vector::new_zeroed(n); + for i in 0..n { + v[i] = lhs[i] + rhs[i]; + } + v.copy_into_postgres() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(-)] +fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + let n = lhs.len(); + let mut v = Vector::new_zeroed(n); + for i in 0..n { + v[i] = lhs[i] - rhs[i]; + } + v.copy_into_postgres() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(<)] +#[pgrx::negator(>=)] +#[pgrx::commutator(>)] +#[pgrx::restrict(scalarltsel)] +#[pgrx::join(scalarltjoinsel)] +fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + lhs.deref() < rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(<=)] +#[pgrx::negator(>)] +#[pgrx::commutator(>=)] +#[pgrx::restrict(scalarltsel)] +#[pgrx::join(scalarltjoinsel)] +fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + lhs.deref() <= rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(>)] +#[pgrx::negator(<=)] +#[pgrx::commutator(<)] +#[pgrx::restrict(scalargtsel)] +#[pgrx::join(scalargtjoinsel)] +fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + lhs.deref() > rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(>=)] +#[pgrx::negator(<)] 
+#[pgrx::commutator(<=)] +#[pgrx::restrict(scalargtsel)] +#[pgrx::join(scalargtjoinsel)] +fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + lhs.deref() >= rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(=)] +#[pgrx::negator(<>)] +#[pgrx::commutator(=)] +#[pgrx::restrict(eqsel)] +#[pgrx::join(eqjoinsel)] +fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + lhs.deref() == rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(<>)] +#[pgrx::negator(=)] +#[pgrx::commutator(<>)] +#[pgrx::restrict(eqsel)] +#[pgrx::join(eqjoinsel)] +fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + lhs.deref() != rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(<=>)] +#[pgrx::commutator(<=>)] +fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + Cosine::distance(&lhs, &rhs) +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(<#>)] +#[pgrx::commutator(<#>)] +fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + Dot::distance(&lhs, &rhs) +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::opname(<->)] +#[pgrx::commutator(<->)] +fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { + assert_eq!(lhs.len(), rhs.len(), "Invaild operation."); + L2::distance(&lhs, &rhs) +} diff --git a/src/prelude.rs b/src/prelude.rs deleted file mode 100644 index b38a026bc..000000000 --- a/src/prelude.rs +++ /dev/null @@ -1,126 +0,0 @@ -use crate::algorithms::Flat; -use 
crate::algorithms::Hnsw; -use crate::algorithms::Ivf; -use crate::algorithms::Vectors; -use crate::memory::Address; -use serde::{Deserialize, Serialize}; -use std::sync::Arc; -use validator::Validate; - -pub use crate::utils::scalar::Float; -pub use crate::utils::scalar::Scalar; - -pub use crate::utils::bincode::Bincode; -pub use crate::utils::bincode::BincodeDeserialize; - -pub use crate::utils::distance::Distance; - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub struct Id { - newtype: u32, -} - -impl Id { - pub fn from_sys(sys: pgrx::pg_sys::Oid) -> Self { - Self { - newtype: sys.as_u32(), - } - } - pub fn as_u32(self) -> u32 { - self.newtype - } -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct Pointer { - newtype: u64, -} - -impl Pointer { - pub fn from_sys(sys: pgrx::pg_sys::ItemPointerData) -> Self { - let mut newtype = 0; - newtype |= (sys.ip_blkid.bi_hi as u64) << 32; - newtype |= (sys.ip_blkid.bi_lo as u64) << 16; - newtype |= (sys.ip_posid as u64) << 0; - Self { newtype } - } - pub fn into_sys(self) -> pgrx::pg_sys::ItemPointerData { - pgrx::pg_sys::ItemPointerData { - ip_blkid: pgrx::pg_sys::BlockIdData { - bi_hi: ((self.newtype >> 32) & 0xffff) as u16, - bi_lo: ((self.newtype >> 16) & 0xffff) as u16, - }, - ip_posid: ((self.newtype >> 0) & 0xffff) as u16, - } - } - pub fn from_u48(value: u64) -> Self { - assert!(value < (1u64 << 48)); - Self { newtype: value } - } - pub fn as_u48(self) -> u64 { - self.newtype - } -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -#[repr(u8)] -#[serde(rename_all = "snake_case")] -pub enum Storage { - Ram = 0, - Disk = 1, -} - -pub trait Algorithm: Sized { - type Options: Clone + Serialize + for<'a> Deserialize<'a>; - fn build(options: Options, vectors: Arc, n: usize) -> anyhow::Result; - fn address(&self) -> Address; - fn load(options: Options, vectors: Arc, address: Address) -> anyhow::Result; - fn insert(&self, i: 
usize) -> anyhow::Result<()>; - fn search(&self, search: (Box<[Scalar]>, usize)) -> anyhow::Result>; -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -pub struct Options { - pub dims: u16, - pub distance: Distance, - pub capacity: usize, - #[validate(range(min = 16384))] - pub size_ram: usize, - #[validate(range(min = 16384))] - pub size_disk: usize, - pub storage_vectors: Storage, - pub algorithm: AlgorithmOptions, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum AlgorithmOptions { - Hnsw(::Options), - Flat(::Options), - Ivf(::Options), -} - -impl AlgorithmOptions { - pub fn unwrap_hnsw(self) -> ::Options { - use AlgorithmOptions::*; - match self { - Hnsw(x) => x, - _ => unreachable!(), - } - } - #[allow(dead_code)] - pub fn unwrap_flat(self) -> ::Options { - use AlgorithmOptions::*; - match self { - Flat(x) => x, - _ => unreachable!(), - } - } - pub fn unwrap_ivf(self) -> ::Options { - use AlgorithmOptions::*; - match self { - Ivf(x) => x, - _ => unreachable!(), - } - } -} diff --git a/src/prelude/distance.rs b/src/prelude/distance.rs new file mode 100644 index 000000000..061e02633 --- /dev/null +++ b/src/prelude/distance.rs @@ -0,0 +1,288 @@ +use crate::prelude::*; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +mod sealed { + pub trait Sealed {} +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Distance { + L2, + Cosine, + Dot, +} + +pub trait DistanceFamily: sealed::Sealed + Copy + Default + Send + Sync + Unpin + 'static { + fn distance(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar; + // elkan k means + fn elkan_k_means_normalize(vector: &mut [Scalar]); + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar; + type QuantizationState: Debug + + Copy + + Send + + Sync + + serde::Serialize + + for<'a> serde::Deserialize<'a>; + // quantization + const QUANTIZATION_INITIAL_STATE: Self::QuantizationState; + fn 
quantization_new(lhs: &[Scalar], rhs: &[Scalar]) -> Self::QuantizationState; + fn quantization_merge( + lhs: Self::QuantizationState, + rhs: Self::QuantizationState, + ) -> Self::QuantizationState; + fn quantization_append( + state: Self::QuantizationState, + lhs: Scalar, + rhs: Scalar, + ) -> Self::QuantizationState; + fn quantization_finish(state: Self::QuantizationState) -> Scalar; +} + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +pub struct L2; + +impl sealed::Sealed for L2 {} + +impl DistanceFamily for L2 { + #[inline(always)] + fn distance(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { + distance_squared_l2(lhs, rhs) + } + + #[inline(always)] + fn elkan_k_means_normalize(_: &mut [Scalar]) {} + + #[inline(always)] + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { + distance_squared_l2(lhs, rhs) + } + + type QuantizationState = Scalar; + + const QUANTIZATION_INITIAL_STATE: Scalar = Scalar::Z; + + #[inline(always)] + fn quantization_new(lhs: &[Scalar], rhs: &[Scalar]) -> Self::QuantizationState { + distance_squared_l2(lhs, rhs) + } + + #[inline(always)] + fn quantization_merge(lhs: Scalar, rhs: Scalar) -> Scalar { + lhs + rhs + } + + #[inline(always)] + fn quantization_finish(state: Scalar) -> Scalar { + state + } + + #[inline(always)] + fn quantization_append( + result: Self::QuantizationState, + lhs: Scalar, + rhs: Scalar, + ) -> Self::QuantizationState { + result + (lhs - rhs) * (lhs - rhs) + } +} + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +pub struct Cosine; + +impl sealed::Sealed for Cosine {} + +impl DistanceFamily for Cosine { + #[inline(always)] + fn distance(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { + distance_cosine(lhs, rhs) * (-1.0) + } + + #[inline(always)] + fn elkan_k_means_normalize(vector: &mut [Scalar]) { + l2_normalize(vector) + } + + #[inline(always)] + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { + distance_dot(lhs, rhs).acos() + } + + type 
QuantizationState = (Scalar, Scalar, Scalar); + + const QUANTIZATION_INITIAL_STATE: (Scalar, Scalar, Scalar) = (Scalar::Z, Scalar::Z, Scalar::Z); + + #[inline(always)] + fn quantization_new(lhs: &[Scalar], rhs: &[Scalar]) -> (Scalar, Scalar, Scalar) { + xy_x2_y2(lhs, rhs) + } + + #[inline(always)] + fn quantization_merge( + (l_xy, l_x2, l_y2): (Scalar, Scalar, Scalar), + (r_xy, r_x2, r_y2): (Scalar, Scalar, Scalar), + ) -> (Scalar, Scalar, Scalar) { + (l_xy + r_xy, l_x2 + r_x2, l_y2 + r_y2) + } + + #[inline(always)] + fn quantization_finish((xy, x2, y2): (Scalar, Scalar, Scalar)) -> Scalar { + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[inline(always)] + fn quantization_append( + (xy, x2, y2): Self::QuantizationState, + x: Scalar, + y: Scalar, + ) -> Self::QuantizationState { + (xy + x * y, x2 + x * x, y2 + y * y) + } +} + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)] +pub struct Dot; + +impl sealed::Sealed for Dot {} + +impl DistanceFamily for Dot { + #[inline(always)] + fn distance(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { + distance_dot(lhs, rhs) * (-1.0) + } + + #[inline(always)] + fn elkan_k_means_normalize(vector: &mut [Scalar]) { + l2_normalize(vector) + } + + #[inline(always)] + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { + distance_dot(lhs, rhs).acos() + } + + type QuantizationState = Scalar; + + const QUANTIZATION_INITIAL_STATE: Scalar = Scalar::Z; + + #[inline(always)] + fn quantization_new(lhs: &[Scalar], rhs: &[Scalar]) -> Self::QuantizationState { + distance_dot(lhs, rhs) + } + + #[inline(always)] + fn quantization_merge(lhs: Scalar, rhs: Scalar) -> Scalar { + lhs + rhs + } + + #[inline(always)] + fn quantization_finish(state: Scalar) -> Scalar { + state * (-1.0) + } + + #[inline(always)] + fn quantization_append( + result: Self::QuantizationState, + x: Scalar, + y: Scalar, + ) -> Self::QuantizationState { + result + x * y + } +} + +#[inline(always)] +fn distance_squared_l2(lhs: &[Scalar], rhs: &[Scalar]) -> 
Scalar { + if lhs.len() != rhs.len() { + panic!( + "different vector dimensions {} and {}.", + lhs.len(), + rhs.len() + ); + } + let n = lhs.len(); + let mut d2 = Scalar::Z; + for i in 0..n { + let d = lhs[i] - rhs[i]; + d2 += d * d; + } + d2 +} + +#[inline(always)] +fn distance_cosine(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { + if lhs.len() != rhs.len() { + panic!( + "different vector dimensions {} and {}.", + lhs.len(), + rhs.len() + ); + } + let n = lhs.len(); + let mut xy = Scalar::Z; + let mut x2 = Scalar::Z; + let mut y2 = Scalar::Z; + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +fn distance_dot(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { + if lhs.len() != rhs.len() { + panic!( + "different vector dimensions {} and {}.", + lhs.len(), + rhs.len() + ); + } + let n = lhs.len(); + let mut xy = Scalar::Z; + for i in 0..n { + xy += lhs[i] * rhs[i]; + } + xy +} + +#[inline(always)] +fn xy_x2_y2(lhs: &[Scalar], rhs: &[Scalar]) -> (Scalar, Scalar, Scalar) { + if lhs.len() != rhs.len() { + panic!( + "different vector dimensions {} and {}.", + lhs.len(), + rhs.len() + ); + } + let n = lhs.len(); + let mut xy = Scalar::Z; + let mut x2 = Scalar::Z; + let mut y2 = Scalar::Z; + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + (xy, x2, y2) +} + +#[inline(always)] +fn length(vector: &[Scalar]) -> Scalar { + let n = vector.len(); + let mut dot = Scalar::Z; + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +fn l2_normalize(vector: &mut [Scalar]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} diff --git a/src/prelude/mod.rs b/src/prelude/mod.rs new file mode 100644 index 000000000..3f48addae --- /dev/null +++ b/src/prelude/mod.rs @@ -0,0 +1,17 @@ +mod distance; +mod scalar; +mod sys; + +pub use self::distance::{Cosine, Distance, DistanceFamily, 
Dot, L2}; +pub use self::scalar::{Float, Scalar}; +pub use self::sys::{Id, Pointer}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[repr(u8)] +#[serde(rename_all = "snake_case")] +pub enum Memmap { + Ram = 0, + Disk = 1, +} diff --git a/src/utils/scalar.rs b/src/prelude/scalar.rs similarity index 98% rename from src/utils/scalar.rs rename to src/prelude/scalar.rs index 4124f789c..ee062d3d1 100644 --- a/src/utils/scalar.rs +++ b/src/prelude/scalar.rs @@ -20,6 +20,7 @@ pub struct Scalar(pub Float); impl Scalar { pub const INFINITY: Self = Self(Float::INFINITY); + pub const NEG_INFINITY: Self = Self(Float::NEG_INFINITY); pub const NAN: Self = Self(Float::NAN); pub const Z: Self = Self(0.0); @@ -69,7 +70,7 @@ impl Eq for Scalar {} impl PartialOrd for Scalar { #[inline(always)] fn partial_cmp(&self, other: &Self) -> Option { - Some(self.0.total_cmp(&other.0)) + Some(Ord::cmp(self, other)) } } diff --git a/src/prelude/sys.rs b/src/prelude/sys.rs new file mode 100644 index 000000000..7535a6f2b --- /dev/null +++ b/src/prelude/sys.rs @@ -0,0 +1,55 @@ +use serde::{Deserialize, Serialize}; +use std::fmt::Display; + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Id { + newtype: u32, +} + +impl Id { + pub fn from_sys(sys: pgrx::pg_sys::Oid) -> Self { + Self { + newtype: sys.as_u32(), + } + } + pub fn as_u32(self) -> u32 { + self.newtype + } +} + +impl Display for Id { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "{}", self.as_u32()) + } +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct Pointer { + newtype: u64, +} + +impl Pointer { + pub fn from_sys(sys: pgrx::pg_sys::ItemPointerData) -> Self { + let mut newtype = 0; + newtype |= (sys.ip_blkid.bi_hi as u64) << 32; + newtype |= (sys.ip_blkid.bi_lo as u64) << 16; + newtype |= (sys.ip_posid as u64) << 0; + Self { newtype } + } + pub fn 
into_sys(self) -> pgrx::pg_sys::ItemPointerData { + pgrx::pg_sys::ItemPointerData { + ip_blkid: pgrx::pg_sys::BlockIdData { + bi_hi: ((self.newtype >> 32) & 0xffff) as u16, + bi_lo: ((self.newtype >> 16) & 0xffff) as u16, + }, + ip_posid: ((self.newtype >> 0) & 0xffff) as u16, + } + } + pub fn from_u48(value: u64) -> Self { + assert!(value < (1u64 << 48)); + Self { newtype: value } + } + pub fn as_u48(self) -> u64 { + self.newtype + } +} diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index 5cdc0b26b..ffc5c222e 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -1,3 +1,9 @@ +CREATE CAST (real[] AS vector) + WITH FUNCTION cast_array_to_vector(real[], integer, boolean) AS IMPLICIT; + +CREATE CAST (vector AS real[]) + WITH FUNCTION cast_vector_to_array(vector, integer, boolean) AS IMPLICIT; + CREATE OPERATOR CLASS l2_ops FOR TYPE vector USING vectors AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops; diff --git a/src/utils/bincode.rs b/src/utils/bincode.rs deleted file mode 100644 index 4a7c72462..000000000 --- a/src/utils/bincode.rs +++ /dev/null @@ -1,23 +0,0 @@ -use serde::{Deserialize, Serialize}; - -pub trait BincodeDeserialize { - fn deserialize<'a, T: Deserialize<'a>>(&'a self) -> anyhow::Result; -} - -impl BincodeDeserialize for [u8] { - fn deserialize<'a, T: Deserialize<'a>>(&'a self) -> anyhow::Result { - let t = bincode::deserialize::(self)?; - Ok(t) - } -} - -pub trait Bincode: Sized { - fn bincode(&self) -> anyhow::Result>; -} - -impl Bincode for T { - fn bincode(&self) -> anyhow::Result> { - let bytes = bincode::serialize(self)?; - Ok(bytes) - } -} diff --git a/src/utils/distance.rs b/src/utils/distance.rs deleted file mode 100644 index 021785aea..000000000 --- a/src/utils/distance.rs +++ /dev/null @@ -1,97 +0,0 @@ -use crate::prelude::*; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub enum Distance { - L2, - Cosine, - Dot, -} - -impl Distance { - #[inline(always)] - pub fn 
distance(self, lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - match self { - Distance::L2 => distance_squared_l2(lhs, rhs), - Distance::Cosine => distance_squared_cosine(lhs, rhs) * (-1.0), - Distance::Dot => distance_dot(lhs, rhs) * (-1.0), - } - } - #[inline(always)] - pub fn kmeans_normalize(self, vector: &mut [Scalar]) { - match self { - Distance::L2 => (), - Distance::Cosine | Distance::Dot => l2_normalize(vector), - } - } - #[inline(always)] - pub fn kmeans_distance(self, lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - match self { - Distance::L2 => distance_squared_l2(lhs, rhs), - Distance::Cosine | Distance::Dot => distance_dot(lhs, rhs).acos(), - } - } -} - -#[inline(always)] -fn distance_squared_l2(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - return Scalar::NAN; - } - let n = lhs.len(); - let mut result = Scalar::Z; - for i in 0..n { - let diff = lhs[i] - rhs[i]; - result += diff * diff; - } - result -} - -#[inline(always)] -fn distance_squared_cosine(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - return Scalar::NAN; - } - let n = lhs.len(); - let mut dot = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..n { - dot += lhs[i] * rhs[i]; - x2 += lhs[i] * lhs[i]; - y2 += rhs[i] * rhs[i]; - } - (dot * dot) / (x2 * y2) -} - -#[inline(always)] -fn distance_dot(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - return Scalar::NAN; - } - let n = lhs.len(); - let mut dot = Scalar::Z; - for i in 0..n { - dot += lhs[i] * rhs[i]; - } - dot -} - -#[inline(always)] -fn length(vector: &[Scalar]) -> Scalar { - let n = vector.len(); - let mut dot = Scalar::Z; - for i in 0..n { - dot += vector[i] * vector[i]; - } - dot.sqrt() -} - -#[inline(always)] -fn l2_normalize(vector: &mut [Scalar]) { - let n = vector.len(); - let l = length(vector); - for i in 0..n { - vector[i] /= l; - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs deleted file mode 100644 index 
28f290661..000000000 --- a/src/utils/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub mod bincode; -pub mod distance; -pub mod scalar; - -pub mod fixed_heap; -pub mod parray; -pub mod semaphore; -pub mod unsafe_once; -pub mod vec2; diff --git a/src/utils/parray.rs b/src/utils/parray.rs deleted file mode 100644 index 2e5d63e76..000000000 --- a/src/utils/parray.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::memory::PBox; -use crate::prelude::*; -use std::fmt::Debug; -use std::mem::MaybeUninit; -use std::ops::{Deref, DerefMut}; - -pub struct PArray { - data: PBox<[MaybeUninit]>, - len: usize, -} - -impl PArray { - pub fn new(capacity: usize, storage: Storage) -> anyhow::Result { - Ok(Self { - data: PBox::new_uninit_slice(capacity, storage)?, - len: 0, - }) - } - pub fn clear(&mut self) { - self.len = 0; - } - pub fn capacity(&self) -> usize { - self.data.len() - } - pub fn len(&self) -> usize { - self.len - } - pub fn insert(&mut self, index: usize, element: T) -> anyhow::Result<()> { - assert!(index <= self.len); - if self.len == self.capacity() { - anyhow::bail!("The vector is full."); - } - unsafe { - if index < self.len { - let p = self.data.as_ptr().add(index).cast_mut(); - std::ptr::copy(p, p.add(1), self.len - index); - } - self.data[index].write(element); - self.len += 1; - } - Ok(()) - } - pub fn push(&mut self, element: T) -> anyhow::Result<()> { - if self.capacity() == self.len { - anyhow::bail!("The vector is full."); - } - let index = self.len; - self.data[index].write(element); - self.len += 1; - Ok(()) - } - #[allow(dead_code)] - pub fn pop(&mut self) -> Option { - if self.len == 0 { - return None; - } - let value; - unsafe { - self.len -= 1; - value = self.data[self.len].assume_init_read(); - } - Some(value) - } -} - -impl Deref for PArray { - type Target = [T]; - - fn deref(&self) -> &Self::Target { - unsafe { MaybeUninit::slice_assume_init_ref(&self.data[..self.len]) } - } -} - -impl DerefMut for PArray { - fn deref_mut(&mut self) -> &mut Self::Target { - 
unsafe { MaybeUninit::slice_assume_init_mut(&mut self.data[..self.len]) } - } -} - -impl Debug for PArray { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut list = f.debug_list(); - list.entries(self.deref()); - list.finish() - } -} diff --git a/src/utils/unsafe_once.rs b/src/utils/unsafe_once.rs deleted file mode 100644 index 3bce856c8..000000000 --- a/src/utils/unsafe_once.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::cell::UnsafeCell; -use std::mem::MaybeUninit; -use std::ops::Deref; - -#[repr(C)] -pub struct UnsafeOnce { - inner: UnsafeCell>, -} - -impl UnsafeOnce { - #[allow(unused)] - pub const unsafe fn new() -> Self { - Self { - inner: UnsafeCell::new(MaybeUninit::uninit()), - } - } - pub fn set(&self, data: T) { - unsafe { - (*self.inner.get()).write(data); - } - } -} - -impl Deref for UnsafeOnce { - type Target = T; - - fn deref(&self) -> &T { - unsafe { (*self.inner.get()).assume_init_ref() } - } -}