-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* ElasticsearchStore * Update elasticsearch/store/_utilities.py Co-authored-by: Quentin Pradet <[email protected]> * rename; depend on client; async only * generate _sync files * add cleanup step for _sync generation * fix formatting * more linting fixes * batch embedding call; infer num_dimensions * revert accidental changes * keep field names only in store; apply metadata mappings in store * fix typos in file names * use `elasticsearch_url` fixture; create conftest.py * export relevant classes * remove Semantic strategy wait for `semantic_text` to land * es_query is sync * async strategies * cleanup old file * add docker-compose service with model deployment * optional dependencies for MMR * only test sync parts * cleanup unasync script * nox: install optional deps * fix tests with requests remembering Transport * fix numpy typing * add user agent default argument * move to `elasticsearch.helpers.vectorstore` * use Protocol over ABC * revert Protocol change because Python 3.7 * address PR feedback: - Strategy suffix - Sphinx docstrings - add user agent to EmbeddingService - raise ConflictError - various cleanup * improve docstring * fix metadata mappings issue * address PR feedback * add error tests for strategies * canonical names, keyword args only * fix sparse vector strategy bug (duplicate `size`) * all wildcard deletes in compose ES --------- Co-authored-by: Quentin Pradet <[email protected]> (cherry picked from commit c2b0ca3) Co-authored-by: Max Jakob <[email protected]>
- Loading branch information
1 parent
90aa5cb
commit 590ee66
Showing
24 changed files
with
3,543 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Licensed to Elasticsearch B.V. under one or more contributor | ||
# license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright | ||
# ownership. Elasticsearch B.V. licenses this file to you under | ||
# the Apache License, Version 2.0 (the "License"); you may | ||
# not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from elasticsearch.helpers.vectorstore._async.embedding_service import ( | ||
AsyncElasticsearchEmbeddings, | ||
AsyncEmbeddingService, | ||
) | ||
from elasticsearch.helpers.vectorstore._async.strategies import ( | ||
AsyncBM25Strategy, | ||
AsyncDenseVectorScriptScoreStrategy, | ||
AsyncDenseVectorStrategy, | ||
AsyncRetrievalStrategy, | ||
AsyncSparseVectorStrategy, | ||
) | ||
from elasticsearch.helpers.vectorstore._async.vectorstore import AsyncVectorStore | ||
from elasticsearch.helpers.vectorstore._sync.embedding_service import ( | ||
ElasticsearchEmbeddings, | ||
EmbeddingService, | ||
) | ||
from elasticsearch.helpers.vectorstore._sync.strategies import ( | ||
BM25Strategy, | ||
DenseVectorScriptScoreStrategy, | ||
DenseVectorStrategy, | ||
RetrievalStrategy, | ||
SparseVectorStrategy, | ||
) | ||
from elasticsearch.helpers.vectorstore._sync.vectorstore import VectorStore | ||
from elasticsearch.helpers.vectorstore._utils import DistanceMetric | ||
|
||
__all__ = [ | ||
"AsyncBM25Strategy", | ||
"AsyncDenseVectorScriptScoreStrategy", | ||
"AsyncDenseVectorStrategy", | ||
"AsyncElasticsearchEmbeddings", | ||
"AsyncEmbeddingService", | ||
"AsyncRetrievalStrategy", | ||
"AsyncSparseVectorStrategy", | ||
"AsyncVectorStore", | ||
"BM25Strategy", | ||
"DenseVectorScriptScoreStrategy", | ||
"DenseVectorStrategy", | ||
"DistanceMetric", | ||
"ElasticsearchEmbeddings", | ||
"EmbeddingService", | ||
"RetrievalStrategy", | ||
"SparseVectorStrategy", | ||
"VectorStore", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Licensed to Elasticsearch B.V. under one or more contributor | ||
# license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright | ||
# ownership. Elasticsearch B.V. licenses this file to you under | ||
# the Apache License, Version 2.0 (the "License"); you may | ||
# not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Licensed to Elasticsearch B.V. under one or more contributor | ||
# license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright | ||
# ownership. Elasticsearch B.V. licenses this file to you under | ||
# the Apache License, Version 2.0 (the "License"); you may | ||
# not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from elasticsearch import AsyncElasticsearch, BadRequestError, NotFoundError | ||
|
||
|
||
async def model_must_be_deployed(client: AsyncElasticsearch, model_id: str) -> None: | ||
""" | ||
:raises [NotFoundError]: if the model is neither downloaded nor deployed. | ||
:raises [ConflictError]: if the model is downloaded but not yet deployed. | ||
""" | ||
doc = {"text_field": f"test if the model '{model_id}' is deployed"} | ||
try: | ||
await client.ml.infer_trained_model(model_id=model_id, docs=[doc]) | ||
except BadRequestError: | ||
# The model is deployed but expects a different input field name. | ||
pass | ||
|
||
|
||
async def model_is_deployed(client: AsyncElasticsearch, model_id: str) -> bool: | ||
try: | ||
await model_must_be_deployed(client, model_id) | ||
return True | ||
except NotFoundError: | ||
return False |
89 changes: 89 additions & 0 deletions
89
elasticsearch/helpers/vectorstore/_async/embedding_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Licensed to Elasticsearch B.V. under one or more contributor | ||
# license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright | ||
# ownership. Elasticsearch B.V. licenses this file to you under | ||
# the Apache License, Version 2.0 (the "License"); you may | ||
# not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import List | ||
|
||
from elasticsearch import AsyncElasticsearch | ||
from elasticsearch._version import __versionstr__ as lib_version | ||
|
||
|
||
class AsyncEmbeddingService(ABC): | ||
@abstractmethod | ||
async def embed_documents(self, texts: List[str]) -> List[List[float]]: | ||
"""Generate embeddings for a list of documents. | ||
:param texts: A list of document strings to generate embeddings for. | ||
:return: A list of embeddings, one for each document in the input. | ||
""" | ||
|
||
@abstractmethod | ||
async def embed_query(self, query: str) -> List[float]: | ||
"""Generate an embedding for a single query text. | ||
:param text: The query text to generate an embedding for. | ||
:return: The embedding for the input query text. | ||
""" | ||
|
||
|
||
class AsyncElasticsearchEmbeddings(AsyncEmbeddingService): | ||
"""Elasticsearch as a service for embedding model inference. | ||
You need to have an embedding model downloaded and deployed in Elasticsearch: | ||
- https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html | ||
- https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html | ||
""" # noqa: E501 | ||
|
||
def __init__( | ||
self, | ||
*, | ||
client: AsyncElasticsearch, | ||
model_id: str, | ||
input_field: str = "text_field", | ||
user_agent: str = f"elasticsearch-py-es/{lib_version}", | ||
): | ||
""" | ||
:param agent_header: user agent header specific to the 3rd party integration. | ||
Used for usage tracking in Elastic Cloud. | ||
:param model_id: The model_id of the model deployed in the Elasticsearch cluster. | ||
:param input_field: The name of the key for the input text field in the | ||
document. Defaults to 'text_field'. | ||
:param client: Elasticsearch client connection. Alternatively specify the | ||
Elasticsearch connection with the other es_* parameters. | ||
""" | ||
# Add integration-specific usage header for tracking usage in Elastic Cloud. | ||
# client.options preserves existing (non-user-agent) headers. | ||
client = client.options(headers={"User-Agent": user_agent}) | ||
|
||
self.client = client | ||
self.model_id = model_id | ||
self.input_field = input_field | ||
|
||
async def embed_documents(self, texts: List[str]) -> List[List[float]]: | ||
return await self._embedding_func(texts) | ||
|
||
async def embed_query(self, text: str) -> List[float]: | ||
result = await self._embedding_func([text]) | ||
return result[0] | ||
|
||
async def _embedding_func(self, texts: List[str]) -> List[List[float]]: | ||
response = await self.client.ml.infer_trained_model( | ||
model_id=self.model_id, docs=[{self.input_field: text} for text in texts] | ||
) | ||
return [doc["predicted_value"] for doc in response["inference_results"]] |
Oops, something went wrong.