diff --git a/libs/databricks/langchain_databricks/vectorstores.py b/libs/databricks/langchain_databricks/vectorstores.py index 7359dcf..c2c01a1 100644 --- a/libs/databricks/langchain_databricks/vectorstores.py +++ b/libs/databricks/langchain_databricks/vectorstores.py @@ -89,7 +89,6 @@ class DatabricksVectorSearch(VectorStore): from langchain_databricks.vectorstores import DatabricksVectorSearch vector_store = DatabricksVectorSearch( - endpoint="", index_name="" ) @@ -102,12 +101,24 @@ class DatabricksVectorSearch(VectorStore): from langchain_openai import OpenAIEmbeddings vector_store = DatabricksVectorSearch( - endpoint="", index_name="", embedding=OpenAIEmbeddings(), text_column="document_content" ) + .. note:: + + If you are using `databricks-vectorsearch` version earlier than 0.35, you also need to + provide the `endpoint` parameter when initializing the vector store. + + .. code-block:: python + + vector_store = DatabricksVectorSearch( + endpoint="", + index_name="", + ... + ) + Add Documents: .. code-block:: python from langchain_core.documents import Document @@ -196,8 +207,8 @@ class DatabricksVectorSearch(VectorStore): def __init__( self, - endpoint: str, index_name: str, + endpoint: Optional[str] = None, embedding: Optional[Embeddings] = None, text_column: Optional[str] = None, columns: Optional[List[str]] = None, @@ -212,7 +223,21 @@ def __init__( "Please install it with `pip install databricks-vectorsearch`." ) from e - self.index = VectorSearchClient().get_index(endpoint, index_name) + try: + self.index = VectorSearchClient().get_index( + endpoint_name=endpoint, index_name=index_name + ) + except Exception as e: + if endpoint is None and "Wrong vector search endpoint" in str(e): + raise ValueError( + "The `endpoint` parameter is required for instantiating " + "DatabricksVectorSearch with the `databricks-vectorsearch` " + "version earlier than 0.35. Please provide the endpoint " + "name or upgrade to version 0.35 or later." + ) from e + else: + raise + self._index_details = IndexDetails(self.index) _validate_embedding(embedding, self._index_details) diff --git a/libs/databricks/tests/unit_tests/test_vectorstore.py b/libs/databricks/tests/unit_tests/test_vectorstore.py index ed8654e..164cd5a 100644 --- a/libs/databricks/tests/unit_tests/test_vectorstore.py +++ b/libs/databricks/tests/unit_tests/test_vectorstore.py @@ -133,12 +133,12 @@ def embed_query(self, text: str) -> List[float]: @pytest.fixture(autouse=True) def mock_vs_client() -> Generator: - def _get_index(endpoint: str, index_name: str) -> MagicMock: + def _get_index( + endpoint_name: Optional[str] = None, + index_name: str = None, # type: ignore + ) -> MagicMock: from databricks.vector_search.client import VectorSearchIndex # type: ignore - if endpoint != ENDPOINT_NAME: - raise ValueError(f"Unknown endpoint: {endpoint}") - index = MagicMock(spec=VectorSearchIndex) index.describe.return_value = INDEX_DETAILS[index_name] index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE @@ -157,7 +157,6 @@ def init_vector_search( index_name: str, columns: Optional[List[str]] = None ) -> DatabricksVectorSearch: kwargs: Dict[str, Any] = { - "endpoint": ENDPOINT_NAME, "index_name": index_name, "columns": columns, } @@ -177,10 +176,17 @@ def test_init(index_name: str) -> None: assert vectorsearch.index.describe() == INDEX_DETAILS[index_name] +def test_init_with_endpoint_name() -> None: + vectorsearch = DatabricksVectorSearch( + endpoint=ENDPOINT_NAME, + index_name=DELTA_SYNC_INDEX, + ) + assert vectorsearch.index.describe() == INDEX_DETAILS[DELTA_SYNC_INDEX] + + def test_init_fail_text_column_mismatch() -> None: with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=DELTA_SYNC_INDEX, text_column="some_other_column", ) @@ -190,7 +196,6 @@ def test_init_fail_text_column_mismatch() -> None: def test_init_fail_no_text_column(index_name: str) -> None: with pytest.raises(ValueError, match="The `text_column` parameter is required"): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=index_name, embedding=EMBEDDING_MODEL, ) @@ -206,7 +211,6 @@ def test_init_fail_columns_not_in_schema() -> None: def test_init_fail_no_embedding(index_name: str) -> None: with pytest.raises(ValueError, match="The `embedding` parameter is required"): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=index_name, text_column="text", ) @@ -215,7 +219,6 @@ def test_init_fail_no_embedding(index_name: str) -> None: def test_init_fail_embedding_already_specified_in_source() -> None: with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=DELTA_SYNC_INDEX, embedding=EMBEDDING_MODEL, ) @@ -227,7 +230,6 @@ def test_init_fail_embedding_dim_mismatch(index_name: str) -> None: ValueError, match="embedding model's dimension '1000' does not match" ): DatabricksVectorSearch( - endpoint=ENDPOINT_NAME, index_name=index_name, text_column="text", embedding=FakeEmbeddings(1000),