Skip to content

Commit

Permalink
async embeddings integration and some async vectorstore methods imple…
Browse files Browse the repository at this point in the history
…mentation
  • Loading branch information
giacbrd committed Oct 3, 2024
1 parent 6488e28 commit 8e32233
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 4 deletions.
44 changes: 43 additions & 1 deletion libs/elasticsearch/langchain_elasticsearch/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, List, Optional

from elasticsearch import Elasticsearch
from elasticsearch.helpers.vectorstore import EmbeddingService
from elasticsearch.helpers.vectorstore import AsyncEmbeddingService, EmbeddingService
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_env

Expand Down Expand Up @@ -249,3 +249,45 @@ def embed_query(self, text: str) -> List[float]:
List[float]: The embedding for the input query text.
"""
return self._langchain_embeddings.embed_query(text)


class AsyncEmbeddingServiceAdapter(AsyncEmbeddingService):
    """
    Wrap a LangChain ``Embeddings`` object so it can be used wherever the
    ``AsyncEmbeddingService`` interface from elasticsearch.helpers.vectorstore
    is required. Every call is forwarded to the wrapped object's async methods.
    """

    def __init__(self, langchain_embeddings: Embeddings):
        # The wrapped LangChain embeddings; all embedding calls delegate to it.
        self._langchain_embeddings = langchain_embeddings

    def __eq__(self, other):  # type: ignore[no-untyped-def]
        # Anything that is not an instance of this adapter's class never
        # compares equal.
        if not isinstance(other, self.__class__):
            return False
        # Otherwise equality is decided by comparing the instance attributes.
        return self.__dict__ == other.__dict__

    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a batch of document texts.

        Args:
            texts (List[str]): The document text strings to embed.

        Returns:
            List[List[float]]: The result of the wrapped embeddings'
                ``aembed_documents`` call for ``texts``.
        """
        vectors = await self._langchain_embeddings.aembed_documents(texts)
        return vectors

    async def embed_query(self, text: str) -> List[float]:
        """
        Embed a single query text.

        Args:
            text (str): The query text to embed.

        Returns:
            List[float]: The result of the wrapped embeddings'
                ``aembed_query`` call for ``text``.
        """
        vector = await self._langchain_embeddings.aembed_query(text)
        return vector
135 changes: 132 additions & 3 deletions libs/elasticsearch/langchain_elasticsearch/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Union,
)

from elasticsearch import AsyncElasticsearch, Elasticsearch
from elasticsearch import Elasticsearch
from elasticsearch.helpers.vectorstore import (
AsyncBM25Strategy,
AsyncDenseVectorScriptScoreStrategy,
Expand Down Expand Up @@ -43,7 +43,10 @@
create_elasticsearch_async_client,
create_elasticsearch_client,
)
from langchain_elasticsearch.embeddings import EmbeddingServiceAdapter
from langchain_elasticsearch.embeddings import (
AsyncEmbeddingServiceAdapter,
EmbeddingServiceAdapter,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -530,6 +533,7 @@ def _convert_retrieval_strategy(
)


# FIXME: these maps must be kept up to date with new strategy classes added to the
# Elasticsearch library

Check failure on line 536 in libs/elasticsearch/langchain_elasticsearch/vectorstores.py

View workflow job for this annotation

GitHub Actions / cd libs/elasticsearch / make lint #3.8

Ruff (E501)

langchain_elasticsearch/vectorstores.py:536:89: E501 Line too long (89 > 88)

Check failure on line 536 in libs/elasticsearch/langchain_elasticsearch/vectorstores.py

View workflow job for this annotation

GitHub Actions / cd libs/elasticsearch / make lint #3.11

Ruff (E501)

langchain_elasticsearch/vectorstores.py:536:89: E501 Line too long (89 > 88)
_sync_to_async_strategy_map: Dict[
Type[RetrievalStrategy], Type[AsyncRetrievalStrategy]
] = {
Expand Down Expand Up @@ -875,16 +879,21 @@ def __init__(
)

self._async_store = None
self._async_embedding_service = None
if es_async_connection is not None:
async_embedding_service = None
if embedding:
async_embedding_service = AsyncEmbeddingServiceAdapter(embedding)
self._async_store = AsyncVectorStore(
client=es_async_connection,
index=index_name,
retrieval_strategy=async_strategy,
embedding_service=embedding_service,
embedding_service=async_embedding_service,
text_field=query_field,
vector_field=vector_query_field,
user_agent=user_agent("langchain-py-vs"),
)
self._async_embedding_service = async_embedding_service

self.embedding = embedding
self.client = self._store.client
Expand Down Expand Up @@ -959,6 +968,56 @@ def similarity_search(
)
return [doc for doc, _score in docs]

async def asimilarity_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 50,
    filter: Optional[List[dict]] = None,
    *,
    custom_query: Optional[
        Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]]
    ] = None,
    doc_builder: Optional[Callable[[Dict], Document]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Asynchronously return the Elasticsearch documents most similar to query.

    When an async store is configured the search runs on it; otherwise the
    call is delegated to the parent class implementation.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to knn num_candidates.
        filter: Array of Elasticsearch filter clauses to apply to the query.
        custom_query: Optional callable forwarded to the search as
            ``custom_query``.
        doc_builder: Optional callable used when converting hits to Documents.
        **kwargs: Extra arguments; only forwarded on the parent-class
            fallback path.

    Returns:
        List of Documents most similar to the query,
        in descending order of similarity.
    """
    async_store = self._async_store

    # No async client was configured: defer to the parent class.
    if async_store is None:
        return await super().asimilarity_search(
            query=query,
            k=k,
            fetch_k=fetch_k,
            filter=filter,
            custom_query=custom_query,
            doc_builder=doc_builder,
            **kwargs,
        )

    raw_hits = await async_store.search(
        query=query,
        k=k,
        num_candidates=fetch_k,
        filter=filter,
        custom_query=custom_query,
    )
    scored_docs = _hits_to_docs_scores(
        hits=raw_hits,
        content_field=self.query_field,
        doc_builder=doc_builder,
    )
    # Drop the scores; callers of this method only get the documents.
    return [document for document, _ in scored_docs]

def max_marginal_relevance_search(
self,
query: str,
Expand Down Expand Up @@ -1017,6 +1076,76 @@ def max_marginal_relevance_search(

return [doc for doc, _score in docs_scores]

async def amax_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    fields: Optional[List[str]] = None,
    *,
    custom_query: Optional[
        Callable[[Dict[str, Any], Optional[str]], Dict[str, Any]]
    ] = None,
    doc_builder: Optional[Callable[[Dict], Document]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Asynchronously return docs selected using maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents. When an async store is configured the search
    runs on it; otherwise the call is delegated to the parent class
    implementation.

    Args:
        query (str): Text to look up documents similar to.
        k (int): Number of Documents to return. Defaults to 4.
        fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
        lambda_mult (float): Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        fields: Other fields to get from elasticsearch source. These fields
            will be added to the document metadata.
        custom_query: Optional callable forwarded to the search as
            ``custom_query``.
        doc_builder: Optional callable used when converting hits to Documents.
        **kwargs: Extra arguments; only forwarded on the parent-class
            fallback path.

    Returns:
        List[Document]: A list of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If an async store is configured but no async embedding
            service is available.
    """
    async_store = self._async_store

    # No async client was configured: defer to the parent class.
    if async_store is None:
        return await super().amax_marginal_relevance_search(
            query=query,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            fields=fields,
            custom_query=custom_query,
            doc_builder=doc_builder,
            **kwargs,
        )

    embedding_service = self._async_embedding_service
    if embedding_service is None:
        raise ValueError(
            "maximal marginal relevance search requires an embedding service."
        )

    raw_hits = await async_store.max_marginal_relevance_search(
        embedding_service=embedding_service,
        query=query,
        vector_field=self.vector_query_field,
        k=k,
        num_candidates=fetch_k,
        lambda_mult=lambda_mult,
        fields=fields,
        custom_query=custom_query,
    )
    scored_docs = _hits_to_docs_scores(
        hits=raw_hits,
        content_field=self.query_field,
        fields=fields,
        doc_builder=doc_builder,
    )
    # Drop the scores; callers of this method only get the documents.
    return [document for document, _ in scored_docs]

@staticmethod
def _identity_fn(score: float) -> float:
    """Return ``score`` unchanged."""
    return score
Expand Down

0 comments on commit 8e32233

Please sign in to comment.