Commit
address PR feedback
maxjakob committed Apr 29, 2024
1 parent 6f81af9 commit 881d56c
Showing 16 changed files with 121 additions and 151 deletions.
18 changes: 9 additions & 9 deletions elasticsearch/helpers/vectorstore/__init__.py
@@ -42,21 +42,21 @@
from elasticsearch.helpers.vectorstore._utils import DistanceMetric

__all__ = [
"BM25Strategy",
"DenseVectorStrategy",
"DenseVectorScriptScoreStrategy",
"ElasticsearchEmbeddings",
"EmbeddingService",
"RetrievalStrategy",
"SparseVectorStrategy",
"VectorStore",
"AsyncBM25Strategy",
"AsyncDenseVectorStrategy",
"AsyncDenseVectorScriptScoreStrategy",
"AsyncDenseVectorStrategy",
"AsyncElasticsearchEmbeddings",
"AsyncEmbeddingService",
"AsyncRetrievalStrategy",
"AsyncSparseVectorStrategy",
"AsyncVectorStore",
"BM25Strategy",
"DenseVectorScriptScoreStrategy",
"DenseVectorStrategy",
"DistanceMetric",
"ElasticsearchEmbeddings",
"EmbeddingService",
"RetrievalStrategy",
"SparseVectorStrategy",
"VectorStore",
]
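
The reordering above is purely alphabetical; functionally, both the sync and async helpers remain importable from the package root. A minimal import sketch using names taken from the __all__ list above (illustrative only):

# Illustrative consumer imports; the names come from the __all__ list shown above.
from elasticsearch.helpers.vectorstore import (
    AsyncVectorStore,
    DenseVectorStrategy,
    DistanceMetric,
    VectorStore,
)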
elasticsearch/helpers/vectorstore/_async/embedding_service.py
@@ -75,8 +75,7 @@ def __init__(
self.input_field = input_field

async def embed_documents(self, texts: List[str]) -> List[List[float]]:
result = await self._embedding_func(texts)
return result
return await self._embedding_func(texts)

async def embed_query(self, text: str) -> List[float]:
result = await self._embedding_func([text])
26 changes: 13 additions & 13 deletions elasticsearch/helpers/vectorstore/_async/strategies.py
@@ -241,13 +241,13 @@ def es_mappings_settings(
num_dimensions: Optional[int],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if self.distance is DistanceMetric.COSINE:
similarityAlgo = "cosine"
similarity = "cosine"
elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
similarityAlgo = "l2_norm"
similarity = "l2_norm"
elif self.distance is DistanceMetric.DOT_PRODUCT:
similarityAlgo = "dot_product"
similarity = "dot_product"
elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
similarityAlgo = "max_inner_product"
similarity = "max_inner_product"
else:
raise ValueError(f"Similarity {self.distance} not supported.")

@@ -257,7 +257,7 @@
"type": "dense_vector",
"dims": num_dimensions,
"index": True,
"similarity": similarityAlgo,
"similarity": similarity,
},
}
}
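
For review context, here is roughly the dense_vector mapping that the renamed `similarity` variable feeds into for the cosine case — a sketch with a hypothetical field name and dimension count, mirroring the dict above:

# Hypothetical values; only the structure mirrors the mapping built above.
vector_field = "embedding"
num_dimensions = 384

mappings = {
    "properties": {
        vector_field: {
            "type": "dense_vector",
            "dims": num_dimensions,
            "index": True,
            "similarity": "cosine",  # value of the renamed `similarity` variable
        },
    }
}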
@@ -326,18 +326,18 @@ def es_query(
raise ValueError("specify a query_vector")

if self.distance is DistanceMetric.COSINE:
similarityAlgo = (
similarity_algo = (
f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0"
)
elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
similarityAlgo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))"
similarity_algo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))"
elif self.distance is DistanceMetric.DOT_PRODUCT:
similarityAlgo = f"""
similarity_algo = f"""
double value = dotProduct(params.query_vector, '{vector_field}');
return sigmoid(1, Math.E, -value);
"""
elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
similarityAlgo = f"""
similarity_algo = f"""
double value = dotProduct(params.query_vector, '{vector_field}');
if (dotProduct < 0) {{
return 1 / (1 + -1 * dotProduct);
@@ -347,16 +347,16 @@
else:
raise ValueError(f"Similarity {self.distance} not supported.")

queryBool: Dict[str, Any] = {"match_all": {}}
query_bool: Dict[str, Any] = {"match_all": {}}
if filter:
queryBool = {"bool": {"filter": filter}}
query_bool = {"bool": {"filter": filter}}

return {
"query": {
"script_score": {
"query": queryBool,
"query": query_bool,
"script": {
"source": similarityAlgo,
"source": similarity_algo,
"params": {"query_vector": query_vector},
},
},
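Similarly, the renamed `query_bool` and `similarity_algo` variables assemble the script_score body returned by es_query. A sketch of the cosine case with a hypothetical filter, field name, and query vector (the return value is truncated in the hunk above, so only the visible part is shown):

# Hypothetical inputs; the structure mirrors the return statement above.
vector_field = "embedding"
query_vector = [0.1, 0.2, 0.3]
filter = [{"term": {"category": "news"}}]

expected_body = {
    "query": {
        "script_score": {
            "query": {"bool": {"filter": filter}},  # query_bool; {"match_all": {}} without a filter
            "script": {
                "source": f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0",
                "params": {"query_vector": query_vector},
            },
        },
    },
}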
4 changes: 2 additions & 2 deletions elasticsearch/helpers/vectorstore/_async/vectorstore.py
@@ -22,10 +22,10 @@
from elasticsearch import AsyncElasticsearch
from elasticsearch._version import __versionstr__ as lib_version
from elasticsearch.helpers import BulkIndexError, async_bulk
from elasticsearch.helpers.vectorstore._async.embedding_service import (
from elasticsearch.helpers.vectorstore import (
AsyncEmbeddingService,
AsyncRetrievalStrategy,
)
from elasticsearch.helpers.vectorstore._async.strategies import AsyncRetrievalStrategy
from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)
3 changes: 1 addition & 2 deletions elasticsearch/helpers/vectorstore/_sync/embedding_service.py
@@ -75,8 +75,7 @@ def __init__(
self.input_field = input_field

def embed_documents(self, texts: List[str]) -> List[List[float]]:
result = self._embedding_func(texts)
return result
return self._embedding_func(texts)

def embed_query(self, text: str) -> List[float]:
result = self._embedding_func([text])
26 changes: 13 additions & 13 deletions elasticsearch/helpers/vectorstore/_sync/strategies.py
@@ -241,13 +241,13 @@ def es_mappings_settings(
num_dimensions: Optional[int],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if self.distance is DistanceMetric.COSINE:
similarityAlgo = "cosine"
similarity = "cosine"
elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
similarityAlgo = "l2_norm"
similarity = "l2_norm"
elif self.distance is DistanceMetric.DOT_PRODUCT:
similarityAlgo = "dot_product"
similarity = "dot_product"
elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
similarityAlgo = "max_inner_product"
similarity = "max_inner_product"
else:
raise ValueError(f"Similarity {self.distance} not supported.")

@@ -257,7 +257,7 @@
"type": "dense_vector",
"dims": num_dimensions,
"index": True,
"similarity": similarityAlgo,
"similarity": similarity,
},
}
}
@@ -326,18 +326,18 @@ def es_query(
raise ValueError("specify a query_vector")

if self.distance is DistanceMetric.COSINE:
similarityAlgo = (
similarity_algo = (
f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0"
)
elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
similarityAlgo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))"
similarity_algo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))"
elif self.distance is DistanceMetric.DOT_PRODUCT:
similarityAlgo = f"""
similarity_algo = f"""
double value = dotProduct(params.query_vector, '{vector_field}');
return sigmoid(1, Math.E, -value);
"""
elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
similarityAlgo = f"""
similarity_algo = f"""
double value = dotProduct(params.query_vector, '{vector_field}');
if (dotProduct < 0) {{
return 1 / (1 + -1 * dotProduct);
@@ -347,16 +347,16 @@
else:
raise ValueError(f"Similarity {self.distance} not supported.")

queryBool: Dict[str, Any] = {"match_all": {}}
query_bool: Dict[str, Any] = {"match_all": {}}
if filter:
queryBool = {"bool": {"filter": filter}}
query_bool = {"bool": {"filter": filter}}

return {
"query": {
"script_score": {
"query": queryBool,
"query": query_bool,
"script": {
"source": similarityAlgo,
"source": similarity_algo,
"params": {"query_vector": query_vector},
},
},
3 changes: 1 addition & 2 deletions elasticsearch/helpers/vectorstore/_sync/vectorstore.py
@@ -22,8 +22,7 @@
from elasticsearch import Elasticsearch
from elasticsearch._version import __versionstr__ as lib_version
from elasticsearch.helpers import BulkIndexError, bulk
from elasticsearch.helpers.vectorstore._sync.embedding_service import EmbeddingService
from elasticsearch.helpers.vectorstore._sync.strategies import RetrievalStrategy
from elasticsearch.helpers.vectorstore import EmbeddingService, RetrievalStrategy
from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion elasticsearch/helpers/vectorstore/_utils.py
@@ -112,5 +112,5 @@ def _raise_missing_mmr_deps_error(parent_error: ModuleNotFoundError) -> None:
raise ModuleNotFoundError(
f"Failed to compute maximal marginal relevance because the required "
f"module '{parent_error.name}' is missing. You can install it by running: "
f"'{sys.executable} -m pip install elasticsearch[mmr]'"
f"'{sys.executable} -m pip install elasticsearch[vectorstore_mmr]'"
) from parent_error
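
The updated message points users at the renamed extra, which maps to the optional numpy/simsimd dependencies declared in setup.py below. A rough sketch of the import guard this helper backs (assumed shape, reconstructed from the error text above):

import sys

# Assumed structure: import the optional MMR dependencies and re-raise with the
# install hint from the message above if they are missing.
try:
    import numpy  # noqa: F401
    import simsimd  # noqa: F401
except ModuleNotFoundError as err:
    raise ModuleNotFoundError(
        f"Failed to compute maximal marginal relevance because the required "
        f"module '{err.name}' is missing. You can install it by running: "
        f"'{sys.executable} -m pip install elasticsearch[vectorstore_mmr]'"
    ) from err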
6 changes: 4 additions & 2 deletions noxfile.py
@@ -48,7 +48,9 @@ def pytest_argv():

@nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"])
def test(session):
session.install(".[async,requests,orjson,mmr]", env=INSTALL_ENV, silent=False)
session.install(
".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV, silent=False
)
session.install("-r", "dev-requirements.txt", silent=False)

session.run(*pytest_argv())
@@ -95,7 +97,7 @@ def lint(session):
session.run("flake8", *SOURCE_FILES)
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)

session.install(".[async,requests,orjson,mmr]", env=INSTALL_ENV)
session.install(".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV)

# Run mypy on the package and then the type examples separately for
# the two different mypy use-cases, ourselves and our users.
3 changes: 2 additions & 1 deletion setup.py
@@ -92,6 +92,7 @@
"requests": ["requests>=2.4.0, <3.0.0"],
"async": ["aiohttp>=3,<4"],
"orjson": ["orjson>=3"],
"mmr": ["numpy>=1", "simsimd>=3"],
# Maximal Marginal Relevance (MMR) for search results
"vectorstore_mmr": ["numpy>=1", "simsimd>=3"],
},
)
@@ -14,3 +14,68 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import List

from elastic_transport import Transport

from elasticsearch.helpers.vectorstore import EmbeddingService


class RequestSavingTransport(Transport):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.requests: list = []

def perform_request(self, *args, **kwargs):
self.requests.append(kwargs)
return super().perform_request(*args, **kwargs)


class FakeEmbeddings(EmbeddingService):
"""Fake embeddings functionality for testing."""

def __init__(self, dimensionality: int = 10) -> None:
self.dimensionality = dimensionality

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings. Embeddings encode each text as its index."""
return [
[float(1.0)] * (self.dimensionality - 1) + [float(i)]
for i in range(len(texts))
]

def embed_query(self, text: str) -> List[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents.
"""
return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]


class ConsistentFakeEmbeddings(FakeEmbeddings):
"""Fake embeddings which remember all the texts seen so far to return consistent
vectors for the same texts."""

def __init__(self, dimensionality: int = 10) -> None:
self.known_texts: List[str] = []
self.dimensionality = dimensionality

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return consistent embeddings for each text seen so far."""
out_vectors = []
for text in texts:
if text not in self.known_texts:
self.known_texts.append(text)
vector = [float(1.0)] * (self.dimensionality - 1) + [
float(self.known_texts.index(text))
]
out_vectors.append(vector)
return out_vectors

def embed_query(self, text: str) -> List[float]:
"""Return consistent embeddings for the text, if seen before, or a constant
one if the text is unknown."""
result = self.embed_documents([text])
return result[0]
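
A quick usage sketch of the fakes above, demonstrating the index-encoding behavior the docstrings describe (assuming this test helper module is importable):

embeddings = ConsistentFakeEmbeddings(dimensionality=3)

# Each previously unseen text is assigned the next index as its last component.
print(embeddings.embed_documents(["foo", "bar"]))  # [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]

# Texts seen before keep their original vector.
print(embeddings.embed_documents(["bar"]))  # [[1.0, 1.0, 1.0]]
print(embeddings.embed_query("foo"))  # [1.0, 1.0, 0.0]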