diff --git a/libs/elasticsearch/langchain_elasticsearch/vectorstores.py b/libs/elasticsearch/langchain_elasticsearch/vectorstores.py index b3c0b5b..80761b8 100644 --- a/libs/elasticsearch/langchain_elasticsearch/vectorstores.py +++ b/libs/elasticsearch/langchain_elasticsearch/vectorstores.py @@ -70,6 +70,7 @@ def index( self, dims_length: Union[int, None], vector_query_field: str, + text_field: str, similarity: Union[DistanceStrategy, None], ) -> Dict: """ @@ -80,6 +81,7 @@ def index( or None if not using vector-based query. vector_query_field: The field containing the vector representations in the index. + text_field: The field containing the text data in the index. similarity: The similarity strategy to use, or None if not using one. @@ -210,6 +212,7 @@ def index( self, dims_length: Union[int, None], vector_query_field: str, + text_field: str, similarity: Union[DistanceStrategy, None], ) -> Dict: """Create the mapping for the Elasticsearch index.""" @@ -289,6 +292,7 @@ def index( self, dims_length: Union[int, None], vector_query_field: str, + text_field: str, similarity: Union[DistanceStrategy, None], ) -> Dict: """Create the mapping for the Elasticsearch index.""" @@ -372,6 +376,7 @@ def index( self, dims_length: Union[int, None], vector_query_field: str, + text_field: str, similarity: Union[DistanceStrategy, None], ) -> Dict: return { @@ -389,6 +394,76 @@ def require_inference(self) -> bool: return False +class BM25RetrievalStrategy(BaseRetrievalStrategy): + """Retrieval strategy using the native BM25 algorithm of Elasticsearch.""" + + def __init__(self, k1: Union[float, None] = None, b: Union[float, None] = None): + self.k1 = k1 + self.b = b + + def query( + self, + query_vector: Union[List[float], None], + query: Union[str, None], + k: int, + fetch_k: int, + vector_query_field: str, + text_field: str, + filter: List[dict], + similarity: Union[DistanceStrategy, None], + ) -> Dict: + return { + "query": { + "bool": { + "must": [ + { + "match": { + text_field: { 
+ "query": query, + } + }, + }, + ], + "filter": filter, + }, + }, + } + + def index( + self, + dims_length: Union[int, None], + vector_query_field: str, + text_field: str, + similarity: Union[DistanceStrategy, None], + ) -> Dict: + mappings: Dict = { + "properties": { + text_field: { + "type": "text", + "similarity": "custom_bm25", + }, + }, + } + settings: Dict = { + "similarity": { + "custom_bm25": { + "type": "BM25", + }, + }, + } + + if self.k1 is not None: + settings["similarity"]["custom_bm25"]["k1"] = self.k1 + + if self.b is not None: + settings["similarity"]["custom_bm25"]["b"] = self.b + + return {"mappings": mappings, "settings": settings} + + def require_inference(self) -> bool: + return False + + class ElasticsearchStore(VectorStore): """`Elasticsearch` vector store. @@ -905,6 +980,7 @@ def _create_index_if_not_exists( indexSettings = self.strategy.index( vector_query_field=self.vector_query_field, + text_field=self.query_field, dims_length=dims_length, similarity=self.distance_strategy, ) @@ -1284,3 +1360,17 @@ def SparseVectorRetrievalStrategy( deployed to Elasticsearch. """ return SparseRetrievalStrategy(model_id=model_id) + + @staticmethod + def BM25RetrievalStrategy( + k1: Union[float, None] = None, b: Union[float, None] = None + ) -> "BM25RetrievalStrategy": + """Used to apply BM25 without vector search. + + Args: + k1: Optional. This corresponds to the BM25 parameter, k1. Default is None, + which uses the default setting of Elasticsearch. + b: Optional. This corresponds to the BM25 parameter, b. Default is None, + which uses the default setting of Elasticsearch. 
+ """ + return BM25RetrievalStrategy(k1=k1, b=b) diff --git a/libs/elasticsearch/tests/integration_tests/test_vectorstores.py b/libs/elasticsearch/tests/integration_tests/test_vectorstores.py index 3ac6d45..2d888fa 100644 --- a/libs/elasticsearch/tests/integration_tests/test_vectorstores.py +++ b/libs/elasticsearch/tests/integration_tests/test_vectorstores.py @@ -777,6 +777,67 @@ def test_elasticsearch_with_relevance_score( ) assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)] + def test_similarity_search_bm25_search( + self, elasticsearch_connection: dict, index_name: str + ) -> None: + """Test end to end using the BM25 retrieval strategy.""" + texts = ["foo", "bar", "baz"] + docsearch = ElasticsearchStore.from_texts( + texts, + None, + **elasticsearch_connection, + index_name=index_name, + strategy=ElasticsearchStore.BM25RetrievalStrategy(), + ) + + def assert_query(query_body: dict, query: str) -> dict: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text": {"query": "foo"}}}], + "filter": [], + } + } + } + return query_body + + output = docsearch.similarity_search("foo", k=1, custom_query=assert_query) + assert output == [Document(page_content="foo")] + + def test_similarity_search_bm25_search_with_filter( + self, elasticsearch_connection: dict, index_name: str + ) -> None: + """Test end to end using the BM25 retrieval strategy with metadata.""" + texts = ["foo", "foo", "foo"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = ElasticsearchStore.from_texts( + texts, + None, + **elasticsearch_connection, + index_name=index_name, + metadatas=metadatas, + strategy=ElasticsearchStore.BM25RetrievalStrategy(), + ) + + def assert_query(query_body: dict, query: str) -> dict: + assert query_body == { + "query": { + "bool": { + "must": [{"match": {"text": {"query": "foo"}}}], + "filter": [{"term": {"metadata.page": 1}}], + } + } + } + return query_body + + output = docsearch.similarity_search( + "foo", 
k=3, + custom_query=assert_query, + filter=[{"term": {"metadata.page": 1}}], + ) + assert output == [Document(page_content="foo", metadata={"page": 1})] + def test_elasticsearch_with_relevance_threshold( self, elasticsearch_connection: dict, index_name: str ) -> None: