From 656d466da47de4d737c855fd7b46ead5ea6f03ef Mon Sep 17 00:00:00 2001
From: Evaldas Buinauskas
Date: Fri, 5 Apr 2024 10:02:56 +0300
Subject: [PATCH] feat: add kNN search, query support

---
 src/search/knn/mod.rs                       | 283 ++++++++++++++++++++
 src/search/mod.rs                           |   2 +
 src/search/queries/mod.rs                   |   1 +
 src/search/queries/specialized/knn_query.rs | 147 ++++++++++
 src/search/queries/specialized/mod.rs       |   2 +
 src/search/request.rs                       |  13 +-
 6 files changed, 447 insertions(+), 1 deletion(-)
 create mode 100644 src/search/knn/mod.rs
 create mode 100644 src/search/queries/specialized/knn_query.rs

diff --git a/src/search/knn/mod.rs b/src/search/knn/mod.rs
new file mode 100644
index 0000000..e994153
--- /dev/null
+++ b/src/search/knn/mod.rs
@@ -0,0 +1,283 @@
+//! A k-nearest neighbor (kNN) search finds the k nearest vectors to a query vector, as measured by a similarity metric.
+//! Common use cases for kNN include:
+//! - Relevance ranking based on natural language processing (NLP) algorithms
+//! - Product recommendations and recommendation engines
+//! - Similarity search for images or videos
+//!
+//! <https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html>
+
+use crate::search::*;
+use crate::util::*;
+use serde::Serialize;
+
+/// Performs a k-nearest neighbor (kNN) search and returns the matching documents.
+///
+/// The kNN search API performs a k-nearest neighbor (kNN) search on a `dense_vector` field. Given a query vector, it
+/// finds the _k_ closest vectors and returns those documents as search hits.
+///
+/// Elasticsearch uses the HNSW algorithm to support efficient kNN search. Like most kNN algorithms, HNSW is an
+/// approximate method that sacrifices result accuracy for improved search speed. This means the results returned are
+/// not always the true _k_ closest neighbors.
+///
+/// The kNN search API supports restricting the search using a filter. The search will return the top `k` documents
+/// that also match the filter query.
+///
+/// To create a kNN search with a query vector or query vector builder:
+/// ```
+/// # use elasticsearch_dsl::*;
+/// # let search =
+/// Search::new()
+///     .knn(Knn::query_vector("test1", vec![1.0, 2.0, 3.0]))
+///     .knn(Knn::query_vector_builder("test3", TextEmbedding::new("my-text-embedding-model", "The opposite of pink")));
+/// ```
+///
+#[derive(Debug, Clone, PartialEq, Serialize)]
+pub struct Knn {
+    field: String,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    query_vector: Option<Vec<f32>>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    query_vector_builder: Option<QueryVectorBuilder>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    k: Option<u32>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    num_candidates: Option<u32>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    filter: Option<Box<Query>>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    similarity: Option<f32>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    boost: Option<f32>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    _name: Option<String>,
+}
+
+impl Knn {
+    /// Creates an instance of [`Knn`] search with a query vector
+    ///
+    /// - `field` - The name of the vector field to search against. Must be a dense_vector field with indexing enabled.
+    /// - `query_vector` - Query vector. Must have the same number of dimensions as the vector field you are searching
+    ///   against.
+    pub fn query_vector<T>(field: T, query_vector: Vec<f32>) -> Self
+    where
+        T: ToString,
+    {
+        Self {
+            field: field.to_string(),
+            query_vector: Some(query_vector),
+            query_vector_builder: None,
+            k: None,
+            num_candidates: None,
+            filter: None,
+            similarity: None,
+            boost: None,
+            _name: None,
+        }
+    }
+
+    /// Creates an instance of [`Knn`] search with a query vector builder
+    ///
+    /// - `field` - The name of the vector field to search against. Must be a dense_vector field with indexing enabled.
+    /// - `query_vector_builder` - A configuration object indicating how to build a query_vector before executing the request.
+    pub fn query_vector_builder<T, U>(field: T, query_vector_builder: U) -> Self
+    where
+        T: ToString,
+        U: Into<QueryVectorBuilder>,
+    {
+        Self {
+            field: field.to_string(),
+            query_vector: None,
+            query_vector_builder: Some(query_vector_builder.into()),
+            k: None,
+            num_candidates: None,
+            filter: None,
+            similarity: None,
+            boost: None,
+            _name: None,
+        }
+    }
+
+    /// Number of nearest neighbors to return as top hits. This value must be less than `num_candidates`.
+    ///
+    /// Defaults to `size`.
+    pub fn k(mut self, k: u32) -> Self {
+        self.k = Some(k);
+        self
+    }
+
+    /// The number of nearest neighbor candidates to consider per shard. Cannot exceed 10,000. Elasticsearch collects
+    /// `num_candidates` results from each shard, then merges them to find the top results. Increasing `num_candidates`
+    /// tends to improve the accuracy of the final results. Defaults to `Math.min(1.5 * size, 10_000)`.
+    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
+        self.num_candidates = Some(num_candidates);
+        self
+    }
+
+    /// Query to filter the documents that can match. The kNN search will return the top documents that also match
+    /// this filter. The value can be a single query or a list of queries. If `filter` is not provided, all documents
+    /// are allowed to match.
+    ///
+    /// The filter is a pre-filter, meaning that it is applied **during** the approximate kNN search to ensure that
+    /// `num_candidates` matching documents are returned.
+    pub fn filter<T>(mut self, filter: T) -> Self
+    where
+        T: Into<Query>,
+    {
+        self.filter = Some(Box::new(filter.into()));
+        self
+    }
+
+    /// The minimum similarity required for a document to be considered a match. The similarity value calculated
+    /// relates to the raw similarity used, not the document score. The matched documents are then scored according
+    /// to similarity and the provided boost is applied.
+    pub fn similarity(mut self, similarity: f32) -> Self {
+        self.similarity = Some(similarity);
+        self
+    }
+
+    add_boost_and_name!();
+}
+
+impl ShouldSkip for Knn {}
+
+/// A configuration object indicating how to build a query_vector before executing the request.
+///
+/// Currently, the only supported builder is [`TextEmbedding`].
+///
+///
+#[derive(Debug, Clone, PartialEq, Serialize)]
+#[serde(rename_all = "snake_case")]
+pub enum QueryVectorBuilder {
+    /// The natural language processing task to perform.
+    TextEmbedding(TextEmbedding),
+}
+
+/// The natural language processing task to perform.
+#[derive(Debug, Clone, PartialEq, Serialize)]
+pub struct TextEmbedding {
+    model_id: String,
+    model_text: String,
+}
+
+impl From<TextEmbedding> for QueryVectorBuilder {
+    fn from(embedding: TextEmbedding) -> Self {
+        Self::TextEmbedding(embedding)
+    }
+}
+
+impl TextEmbedding {
+    /// Creates an instance of [`TextEmbedding`]
+    /// - `model_id` - The ID of the text embedding model to use to generate the dense vectors from the query string.
+    ///   Use the same model that generated the embeddings from the input text in the index you search against. You can
+    ///   use the value of the deployment_id instead in the model_id argument.
+    /// - `model_text` - The query string from which the model generates the dense vector representation.
+    pub fn new<T, U>(model_id: T, model_text: U) -> Self
+    where
+        T: ToString,
+        U: ToString,
+    {
+        Self {
+            model_id: model_id.to_string(),
+            model_text: model_text.to_string(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn serialization() {
+        assert_serialize(
+            Search::new()
+                .knn(Knn::query_vector("test1", vec![1.0, 2.0, 3.0]))
+                .knn(
+                    Knn::query_vector("test2", vec![4.0, 5.0, 6.0])
+                        .k(3)
+                        .num_candidates(100)
+                        .filter(Query::term("field", "value"))
+                        .similarity(0.5)
+                        .boost(2.0)
+                        .name("test2"),
+                )
+                .knn(Knn::query_vector_builder(
+                    "test3",
+                    TextEmbedding::new("my-text-embedding-model", "The opposite of pink"),
+                ))
+                .knn(
+                    Knn::query_vector_builder(
+                        "test4",
+                        TextEmbedding::new("my-text-embedding-model", "The opposite of blue"),
+                    )
+                    .k(5)
+                    .num_candidates(200)
+                    .filter(Query::term("field", "value"))
+                    .similarity(0.7)
+                    .boost(2.1)
+                    .name("test4"),
+                ),
+            json!({
+                "knn": [
+                    {
+                        "field": "test1",
+                        "query_vector": [1.0, 2.0, 3.0]
+                    },
+                    {
+                        "field": "test2",
+                        "query_vector": [4.0, 5.0, 6.0],
+                        "k": 3,
+                        "num_candidates": 100,
+                        "filter": {
+                            "term": {
+                                "field": {
+                                    "value": "value"
+                                }
+                            }
+                        },
+                        "similarity": 0.5,
+                        "boost": 2.0,
+                        "_name": "test2"
+                    },
+                    {
+                        "field": "test3",
+                        "query_vector_builder": {
+                            "text_embedding": {
+                                "model_id": "my-text-embedding-model",
+                                "model_text": "The opposite of pink"
+                            }
+                        }
+                    },
+                    {
+                        "field": "test4",
+                        "query_vector_builder": {
+                            "text_embedding": {
+                                "model_id": "my-text-embedding-model",
+                                "model_text": "The opposite of blue"
+                            }
+                        },
+                        "k": 5,
+                        "num_candidates": 200,
+                        "filter": {
+                            "term": {
+                                "field": {
+                                    "value": "value"
+                                }
+                            }
+                        },
+                        "similarity": 0.7,
+                        "boost": 2.1,
+                        "_name": "test4"
+                    }
+                ]
+            }),
+        );
+    }
+}
diff --git a/src/search/mod.rs b/src/search/mod.rs
index 2c73cf8..4a53f07 100644
--- a/src/search/mod.rs
+++ b/src/search/mod.rs
@@ -15,6 +15,7 @@ mod response;
 // Public modules
 pub mod aggregations;
 pub mod highlight;
+pub mod knn;
 pub mod params;
 pub mod queries;
 pub mod request;
@@ -27,6 +28,7 @@ pub mod suggesters;
 // Public re-exports
 pub use self::aggregations::*;
 pub use self::highlight::*;
+pub use self::knn::*;
 pub use self::params::*;
 pub use self::queries::params::*;
 pub use self::queries::*;
diff --git a/src/search/queries/mod.rs b/src/search/queries/mod.rs
index a46e226..c6db471 100644
--- a/src/search/queries/mod.rs
+++ b/src/search/queries/mod.rs
@@ -204,6 +204,7 @@ query!(
     SpanOr(SpanOrQuery),
     SpanTerm(SpanTermQuery),
     SpanWithin(SpanWithinQuery),
+    Knn(KnnQuery),
 );

 #[cfg(test)]
diff --git a/src/search/queries/specialized/knn_query.rs b/src/search/queries/specialized/knn_query.rs
new file mode 100644
index 0000000..49a2ca6
--- /dev/null
+++ b/src/search/queries/specialized/knn_query.rs
@@ -0,0 +1,147 @@
+use crate::search::*;
+use crate::util::*;
+use serde::Serialize;
+
+/// Finds the _k_ nearest vectors to a query vector, as measured by a similarity metric. _knn_ query finds nearest
+/// vectors through approximate search on indexed dense_vectors. The preferred way to do approximate kNN search is
+/// through the
+/// [top level knn section](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) of a
+/// search request.
+/// _knn_ query is reserved for expert cases, where there is a need to combine this query with other queries.
+///
+/// > `knn` query doesn’t have a separate `k` parameter. `k` is defined by the `size` parameter of a search request,
+/// > similar to other queries. `knn` query collects `num_candidates` results from each shard, then merges them to get
+/// > the top `size` results.
+///
+/// To create a knn query:
+/// ```
+/// # use elasticsearch_dsl::queries::*;
+/// # let query =
+/// Query::knn("test", vec![1.0, 2.0, 3.0]);
+/// ```
+///
+#[derive(Debug, Clone, PartialEq, Serialize)]
+#[serde(remote = "Self")]
+pub struct KnnQuery {
+    field: String,
+
+    query_vector: Vec<f32>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    num_candidates: Option<u32>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    filter: Option<Box<Query>>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    similarity: Option<f32>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    boost: Option<f32>,
+
+    #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
+    _name: Option<String>,
+}
+
+impl KnnQuery {
+    /// The number of nearest neighbor candidates to consider per shard. Cannot exceed 10,000. Elasticsearch collects
+    /// `num_candidates` results from each shard, then merges them to find the top results. Increasing `num_candidates`
+    /// tends to improve the accuracy of the final results. Defaults to `Math.min(1.5 * size, 10_000)`.
+    pub fn num_candidates(mut self, num_candidates: u32) -> Self {
+        self.num_candidates = Some(num_candidates);
+        self
+    }
+
+    /// Query to filter the documents that can match. The kNN search will return the top documents that also match
+    /// this filter. The value can be a single query or a list of queries. If `filter` is not provided, all documents
+    /// are allowed to match.
+    ///
+    /// The filter is a pre-filter, meaning that it is applied **during** the approximate kNN search to ensure that
+    /// `num_candidates` matching documents are returned.
+    pub fn filter<T>(mut self, filter: T) -> Self
+    where
+        T: Into<Query>,
+    {
+        self.filter = Some(Box::new(filter.into()));
+        self
+    }
+
+    /// The minimum similarity required for a document to be considered a match. The similarity value calculated
+    /// relates to the raw similarity used, not the document score. The matched documents are then scored according
+    /// to similarity and the provided boost is applied.
+    pub fn similarity(mut self, similarity: f32) -> Self {
+        self.similarity = Some(similarity);
+        self
+    }
+
+    add_boost_and_name!();
+}
+
+impl ShouldSkip for KnnQuery {}
+
+serialize_with_root!("knn": KnnQuery);
+
+impl Query {
+    /// Creates an instance of [`KnnQuery`]
+    ///
+    /// - `field` - The name of the vector field to search against. Must be a dense_vector field with indexing enabled.
+    /// - `query_vector` - Query vector. Must have the same number of dimensions as the vector field you are searching
+    ///   against.
+    pub fn knn<T>(field: T, query_vector: Vec<f32>) -> KnnQuery
+    where
+        T: ToString,
+    {
+        KnnQuery {
+            field: field.to_string(),
+            query_vector,
+            num_candidates: None,
+            filter: None,
+            similarity: None,
+            boost: None,
+            _name: None,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn serialization() {
+        assert_serialize_query(
+            Query::knn("test", vec![1.0, 2.0, 3.0]),
+            json!({
+                "knn": {
+                    "field": "test",
+                    "query_vector": [1.0, 2.0, 3.0]
+                }
+            }),
+        );
+
+        assert_serialize_query(
+            Query::knn("test", vec![1.0, 2.0, 3.0])
+                .num_candidates(100)
+                .filter(Query::term("field", "value"))
+                .similarity(0.5)
+                .boost(2.0)
+                .name("test"),
+            json!({
+                "knn": {
+                    "field": "test",
+                    "query_vector": [1.0, 2.0, 3.0],
+                    "num_candidates": 100,
+                    "filter": {
+                        "term": {
+                            "field": {
+                                "value": "value"
+                            }
+                        }
+                    },
+                    "similarity": 0.5,
+                    "boost": 2.0,
+                    "_name": "test"
+                }
+            }),
+        );
+    }
+}
diff --git a/src/search/queries/specialized/mod.rs b/src/search/queries/specialized/mod.rs
index 9677cf4..c1a7902 100644
--- a/src/search/queries/specialized/mod.rs
+++ b/src/search/queries/specialized/mod.rs
@@ -1,6 +1,7 @@
 //! This group contains queries which do not fit into the other groups

 mod distance_feature_query;
+mod knn_query;
 mod more_like_this_query;
 mod percolate_lookup_query;
 mod percolate_query;
@@ -11,6 +12,7 @@ mod script_score_query;
 mod wrapper_query;

 pub use self::distance_feature_query::*;
+pub use self::knn_query::*;
 pub use self::more_like_this_query::*;
 pub use self::percolate_lookup_query::*;
 pub use self::percolate_query::*;
diff --git a/src/search/request.rs b/src/search/request.rs
index 83e5ecf..d123f03 100644
--- a/src/search/request.rs
+++ b/src/search/request.rs
@@ -72,6 +72,9 @@ pub struct Search {
     #[serde(skip_serializing_if = "ShouldSkip::should_skip")]
     timeout: Option
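
For orientation, a minimal usage sketch of the top-level kNN search added in `src/search/knn/mod.rs`. The field name, vector values, filter, and the `serde_json` dependency used for printing are illustrative assumptions, not part of the patch:

```rust
use elasticsearch_dsl::*;

fn main() {
    // Approximate kNN on a dense_vector field, pre-filtered so that the
    // `num_candidates` candidates gathered per shard already satisfy the filter.
    let search = Search::new()
        .knn(
            Knn::query_vector("image_vector", vec![0.12, 0.83, 0.40])
                .k(10)
                .num_candidates(100)
                .filter(Query::term("category", "shoes"))
                .similarity(0.5),
        )
        .size(10);

    // Serializes to a top-level `knn` section, as exercised by the
    // serialization test in this patch.
    println!("{}", serde_json::to_string_pretty(&search).unwrap());
}
```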
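And a sketch of the expert case described in `knn_query.rs`: combining the `knn` query with other queries, where `k` is taken from the request `size`. This assumes the crate's existing `bool` query builder (`Query::bool()` with `.must`/`.filter`); field names and vectors are again illustrative:

```rust
use elasticsearch_dsl::*;

fn main() {
    // Hybrid search: vector similarity combined with a structured filter.
    // The `knn` query has no `k` of its own; the top `size` hits are returned.
    let search = Search::new()
        .query(
            Query::bool()
                .must(Query::knn("image_vector", vec![0.12, 0.83, 0.40]).num_candidates(100))
                .filter(Query::term("category", "shoes")),
        )
        .size(10);

    println!("{}", serde_json::to_string_pretty(&search).unwrap());
}
```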