-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrate VectorStore from Elasticsearch client
- Loading branch information
Showing
8 changed files
with
439 additions
and
416 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,62 @@ | ||
"""Test finetuning engine.""" | ||
import pkgutil | ||
|
||
import pytest | ||
|
||
from llama_index.vector_stores.elasticsearch.base import ( | ||
_mode_must_match_retrieval_strategy, | ||
VectorStoreQueryMode, | ||
AsyncRetrievalStrategy, | ||
AsyncSparseVectorStrategy, | ||
AsyncBM25Strategy, | ||
AsyncDenseVectorStrategy, | ||
) | ||
|
||
|
||
def test_mode_must_match_retrieval_strategy() -> None: | ||
# DEFAULT mode should never raise any exception | ||
mode = VectorStoreQueryMode.DEFAULT | ||
retrieval_strategy = AsyncRetrievalStrategy() | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) | ||
|
||
# AsyncSparseVectorStrategy with mode SPARSE should not raise any exception | ||
mode = VectorStoreQueryMode.SPARSE | ||
retrieval_strategy = AsyncSparseVectorStrategy() | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) | ||
|
||
# AsyncBM25Strategy with TEXT_SEARCH should not raise any exception | ||
mode = VectorStoreQueryMode.TEXT_SEARCH | ||
retrieval_strategy = AsyncBM25Strategy() | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) | ||
|
||
# AsyncDenseVectorStrategy with mode HYBRID should not raise any exception | ||
mode = VectorStoreQueryMode.HYBRID | ||
retrieval_strategy = AsyncDenseVectorStrategy() | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) | ||
|
||
# UNKNOWN mode should raise NotImplementedError | ||
mode = VectorStoreQueryMode.UNKNOWN | ||
retrieval_strategy = AsyncRetrievalStrategy() | ||
with pytest.raises(NotImplementedError): | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) | ||
|
||
# if mode is SPARSE and strategy is not AsyncSparseVectorStrategy, should raise ValueError | ||
mode = VectorStoreQueryMode.SPARSE | ||
retrieval_strategy = AsyncRetrievalStrategy() | ||
with pytest.raises(ValueError): | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) | ||
|
||
# if mode is HYBRID and strategy is not AsyncDenseVectorStrategy, should raise ValueError | ||
mode = VectorStoreQueryMode.HYBRID | ||
retrieval_strategy = AsyncRetrievalStrategy() | ||
with pytest.raises(ValueError): | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) | ||
|
||
# if mode is HYBRID and strategy is not AsyncDenseVectorStrategy, should raise ValueError | ||
mode = VectorStoreQueryMode.HYBRID | ||
retrieval_strategy = AsyncRetrievalStrategy() | ||
with pytest.raises(ValueError): | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) | ||
|
||
def test_torch_imports() -> None: | ||
"""Test that torch is an optional dependency.""" | ||
# importing fine-tuning modules should be ok | ||
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine # noqa | ||
from llama_index.finetuning import OpenAIFinetuneEngine # noqa | ||
from llama_index.finetuning import SentenceTransformersFinetuneEngine # noqa | ||
|
||
# if torch isn't installed, then these should fail | ||
if pkgutil.find_loader("torch") is None: | ||
with pytest.raises(ModuleNotFoundError): | ||
from llama_index.embeddings.adapter.utils import LinearLayer | ||
else: | ||
# else, importing these should be ok | ||
from llama_index.embeddings.adapter.utils import LinearLayer # noqa | ||
# if mode is HYBRID and strategy is AsyncDenseVectorStrategy but hybrid is not enabled, should raise ValueError | ||
mode = VectorStoreQueryMode.HYBRID | ||
retrieval_strategy = AsyncDenseVectorStrategy(hybrid=False) | ||
with pytest.raises(ValueError): | ||
_mode_must_match_retrieval_strategy(mode, retrieval_strategy) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 14 additions & 1 deletion
15
...ama-index-vector-stores-elasticsearch/llama_index/vector_stores/elasticsearch/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,16 @@ | ||
from llama_index.vector_stores.elasticsearch.base import ElasticsearchStore | ||
|
||
__all__ = ["ElasticsearchStore"] | ||
from elasticsearch.helpers.vectorstore import ( | ||
AsyncBM25Strategy, | ||
AsyncSparseVectorStrategy, | ||
AsyncDenseVectorStrategy, | ||
AsyncRetrievalStrategy, | ||
) | ||
|
||
__all__ = [ | ||
"AsyncBM25Strategy", | ||
"AsyncDenseVectorStrategy", | ||
"AsyncRetrievalStrategy", | ||
"AsyncSparseVectorStrategy", | ||
"ElasticsearchStore", | ||
] |
Oops, something went wrong.