Skip to content

Commit

Permalink
Integrate VectorStore from Elasticsearch client
Browse files Browse the repository at this point in the history
  • Loading branch information
maxjakob committed May 13, 2024
1 parent ce59634 commit 50c80e2
Show file tree
Hide file tree
Showing 8 changed files with 439 additions and 416 deletions.
76 changes: 59 additions & 17 deletions llama-index-finetuning/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,62 @@
"""Test finetuning engine."""
import pkgutil

import pytest

from llama_index.vector_stores.elasticsearch.base import (
_mode_must_match_retrieval_strategy,
VectorStoreQueryMode,
AsyncRetrievalStrategy,
AsyncSparseVectorStrategy,
AsyncBM25Strategy,
AsyncDenseVectorStrategy,
)


def test_mode_must_match_retrieval_strategy() -> None:
# DEFAULT mode should never raise any exception
mode = VectorStoreQueryMode.DEFAULT
retrieval_strategy = AsyncRetrievalStrategy()
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)

# AsyncSparseVectorStrategy with mode SPARSE should not raise any exception
mode = VectorStoreQueryMode.SPARSE
retrieval_strategy = AsyncSparseVectorStrategy()
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)

# AsyncBM25Strategy with TEXT_SEARCH should not raise any exception
mode = VectorStoreQueryMode.TEXT_SEARCH
retrieval_strategy = AsyncBM25Strategy()
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)

# AsyncDenseVectorStrategy with mode HYBRID should not raise any exception
mode = VectorStoreQueryMode.HYBRID
retrieval_strategy = AsyncDenseVectorStrategy()
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)

# UNKNOWN mode should raise NotImplementedError
mode = VectorStoreQueryMode.UNKNOWN
retrieval_strategy = AsyncRetrievalStrategy()
with pytest.raises(NotImplementedError):
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)

# if mode is SPARSE and strategy is not AsyncSparseVectorStrategy, should raise ValueError
mode = VectorStoreQueryMode.SPARSE
retrieval_strategy = AsyncRetrievalStrategy()
with pytest.raises(ValueError):
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)

# if mode is HYBRID and strategy is not AsyncDenseVectorStrategy, should raise ValueError
mode = VectorStoreQueryMode.HYBRID
retrieval_strategy = AsyncRetrievalStrategy()
with pytest.raises(ValueError):
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)

# if mode is HYBRID and strategy is not AsyncDenseVectorStrategy, should raise ValueError
mode = VectorStoreQueryMode.HYBRID
retrieval_strategy = AsyncRetrievalStrategy()
with pytest.raises(ValueError):
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)

def test_torch_imports() -> None:
"""Test that torch is an optional dependency."""
# importing fine-tuning modules should be ok
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine # noqa
from llama_index.finetuning import OpenAIFinetuneEngine # noqa
from llama_index.finetuning import SentenceTransformersFinetuneEngine # noqa

# if torch isn't installed, then these should fail
if pkgutil.find_loader("torch") is None:
with pytest.raises(ModuleNotFoundError):
from llama_index.embeddings.adapter.utils import LinearLayer
else:
# else, importing these should be ok
from llama_index.embeddings.adapter.utils import LinearLayer # noqa
# if mode is HYBRID and strategy is AsyncDenseVectorStrategy but hybrid is not enabled, should raise ValueError
mode = VectorStoreQueryMode.HYBRID
retrieval_strategy = AsyncDenseVectorStrategy(hybrid=False)
with pytest.raises(ValueError):
_mode_must_match_retrieval_strategy(mode, retrieval_strategy)
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pytest.
pytest tests
poetry run pytest tests

watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
from llama_index.vector_stores.elasticsearch.base import ElasticsearchStore

__all__ = ["ElasticsearchStore"]
from elasticsearch.helpers.vectorstore import (
AsyncBM25Strategy,
AsyncSparseVectorStrategy,
AsyncDenseVectorStrategy,
AsyncRetrievalStrategy,
)

__all__ = [
"AsyncBM25Strategy",
"AsyncDenseVectorStrategy",
"AsyncRetrievalStrategy",
"AsyncSparseVectorStrategy",
"ElasticsearchStore",
]
Loading

0 comments on commit 50c80e2

Please sign in to comment.