Skip to content

Commit

Permalink
Make 'endpoint' parameter optional for DatabricksVectorSearch
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <[email protected]>
  • Loading branch information
B-Step62 committed Sep 20, 2024
1 parent a743461 commit 4380112
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
33 changes: 29 additions & 4 deletions libs/databricks/langchain_databricks/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ class DatabricksVectorSearch(VectorStore):
from langchain_databricks.vectorstores import DatabricksVectorSearch
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>"
)
Expand All @@ -102,12 +101,24 @@ class DatabricksVectorSearch(VectorStore):
from langchain_openai import OpenAIEmbeddings
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>",
embedding=OpenAIEmbeddings(),
text_column="document_content"
)
.. note::
If you are using `databricks-vectorsearch` version earlier than 0.35, you also need to
provide the `endpoint` parameter when initializing the vector store.
.. code-block:: python
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>",
...
)
Add Documents:
.. code-block:: python
from langchain_core.documents import Document
Expand Down Expand Up @@ -196,8 +207,8 @@ class DatabricksVectorSearch(VectorStore):

def __init__(
self,
endpoint: str,
index_name: str,
endpoint: Optional[str] = None,
embedding: Optional[Embeddings] = None,
text_column: Optional[str] = None,
columns: Optional[List[str]] = None,
Expand All @@ -212,7 +223,21 @@ def __init__(
"Please install it with `pip install databricks-vectorsearch`."
) from e

self.index = VectorSearchClient().get_index(endpoint, index_name)
try:
self.index = VectorSearchClient().get_index(
endpoint_name=endpoint, index_name=index_name
)
except Exception as e:
if endpoint is None and "Wrong vector search endpoint" in str(e):
raise ValueError(
"The `endpoint` parameter is required for instantiating "
"DatabricksVectorSearch with the `databricks-vectorsearch` "
"version earlier than 0.35. Please provide the endpoint "
"name or upgrade to version 0.35 or later."
) from e
else:
raise

self._index_details = IndexDetails(self.index)

_validate_embedding(embedding, self._index_details)
Expand Down
22 changes: 12 additions & 10 deletions libs/databricks/tests/unit_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ def embed_query(self, text: str) -> List[float]:

@pytest.fixture(autouse=True)
def mock_vs_client() -> Generator:
def _get_index(endpoint: str, index_name: str) -> MagicMock:
def _get_index(
endpoint_name: Optional[str] = None,
index_name: str = None, # type: ignore
) -> MagicMock:
from databricks.vector_search.client import VectorSearchIndex # type: ignore

if endpoint != ENDPOINT_NAME:
raise ValueError(f"Unknown endpoint: {endpoint}")

index = MagicMock(spec=VectorSearchIndex)
index.describe.return_value = INDEX_DETAILS[index_name]
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
Expand All @@ -157,7 +157,6 @@ def init_vector_search(
index_name: str, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
kwargs: Dict[str, Any] = {
"endpoint": ENDPOINT_NAME,
"index_name": index_name,
"columns": columns,
}
Expand All @@ -177,10 +176,17 @@ def test_init(index_name: str) -> None:
assert vectorsearch.index.describe() == INDEX_DETAILS[index_name]


def test_init_with_endpoint_name() -> None:
vectorsearch = DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
)
assert vectorsearch.index.describe() == INDEX_DETAILS[DELTA_SYNC_INDEX]


def test_init_fail_text_column_mismatch() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
text_column="some_other_column",
)
Expand All @@ -190,7 +196,6 @@ def test_init_fail_text_column_mismatch() -> None:
def test_init_fail_no_text_column(index_name: str) -> None:
with pytest.raises(ValueError, match="The `text_column` parameter is required"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
embedding=EMBEDDING_MODEL,
)
Expand All @@ -206,7 +211,6 @@ def test_init_fail_columns_not_in_schema() -> None:
def test_init_fail_no_embedding(index_name: str) -> None:
with pytest.raises(ValueError, match="The `embedding` parameter is required"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
text_column="text",
)
Expand All @@ -215,7 +219,6 @@ def test_init_fail_no_embedding(index_name: str) -> None:
def test_init_fail_embedding_already_specified_in_source() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
embedding=EMBEDDING_MODEL,
)
Expand All @@ -227,7 +230,6 @@ def test_init_fail_embedding_dim_mismatch(index_name: str) -> None:
ValueError, match="embedding model's dimension '1000' does not match"
):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
text_column="text",
embedding=FakeEmbeddings(1000),
Expand Down

0 comments on commit 4380112

Please sign in to comment.