From 4086f77a55f4ec6b7a7a5daacb43b6c6e25ede8d Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Tue, 1 Oct 2024 17:26:11 +0200 Subject: [PATCH] auto environment --- .../langchain_astradb/utils/astradb.py | 37 ++++++++-- .../unit_tests/test_astra_db_environment.py | 72 +++++++++++++++++++ 2 files changed, 104 insertions(+), 5 deletions(-) diff --git a/libs/astradb/langchain_astradb/utils/astradb.py b/libs/astradb/langchain_astradb/utils/astradb.py index 6c67d34..e72f88a 100644 --- a/libs/astradb/langchain_astradb/utils/astradb.py +++ b/libs/astradb/langchain_astradb/utils/astradb.py @@ -14,6 +14,8 @@ import langchain_core from astrapy import AsyncDatabase, DataAPIClient, Database +from astrapy.admin import parse_api_endpoint +from astrapy.constants import Environment from astrapy.exceptions import DataAPIException if TYPE_CHECKING: @@ -60,7 +62,7 @@ def _survey_collection( namespace: str | None = None, ) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]: """Return the collection descriptor (if found) and a sample of documents.""" - _environment = _AstraDBEnvironment( + _astra_db_env = _AstraDBEnvironment( token=token, api_endpoint=api_endpoint, environment=environment, @@ -70,14 +72,14 @@ def _survey_collection( ) descriptors = [ coll_d - for coll_d in _environment.database.list_collections() + for coll_d in _astra_db_env.database.list_collections() if coll_d.name == collection_name ] if not descriptors: return None, [] descriptor = descriptors[0] # fetch some documents - document_ite = _environment.database.get_collection(collection_name).find( + document_ite = _astra_db_env.database.get_collection(collection_name).find( filter={}, projection={"*": True}, limit=SURVEY_NUMBER_OF_DOCUMENTS, @@ -85,6 +87,28 @@ def _survey_collection( return (descriptor, list(document_ite)) +def _normalize_data_api_environment( + arg_environment: str | None, + api_endpoint: str, +) -> str: + _environment: str + if arg_environment is not None: + return arg_environment + parsed_endpoint = parse_api_endpoint(api_endpoint) + if parsed_endpoint is None: + logger.info( + "Detecting API environment '%s' from supplied endpoint", + Environment.OTHER, + ) + return Environment.OTHER + + logger.info( + "Detecting API environment '%s' from supplied endpoint", + parsed_endpoint.environment, + ) + return parsed_endpoint.environment + + class _AstraDBEnvironment: def __init__( self, @@ -208,8 +232,6 @@ def __init__( self.api_endpoint = _api_endpoint self.namespace = _namespace - self.environment = environment - # init parameters are normalized to self.{token, api_endpoint, namespace}. # Proceed. Namespace and token can be None (resp. on Astra DB and non-Astra) if self.api_endpoint is None: @@ -220,6 +242,11 @@ def __init__( ) raise ValueError(msg) + self.environment = _normalize_data_api_environment( + environment, + self.api_endpoint, + ) + # create the clients caller_name = "langchain" caller_version = getattr(langchain_core, "__version__", None) diff --git a/libs/astradb/tests/unit_tests/test_astra_db_environment.py b/libs/astradb/tests/unit_tests/test_astra_db_environment.py index 7d0137c..a105710 100644 --- a/libs/astradb/tests/unit_tests/test_astra_db_environment.py +++ b/libs/astradb/tests/unit_tests/test_astra_db_environment.py @@ -1,6 +1,7 @@ import os import pytest +from astrapy.constants import Environment from astrapy.db import AstraDB from langchain_astradb.utils.astradb import ( @@ -251,3 +252,74 @@ def test_initialization(self) -> None: del os.environ[NAMESPACE_ENV_VAR] for env_var_name, env_var_value in env_vars_to_restore.items(): os.environ[env_var_name] = env_var_value + + def test_env_autodetect(self) -> None: + a_e_string_prod = ( + "https://01234567-89ab-cdef-0123-456789abcdef-us-east1" + ".apps.astra.datastax.com" + ) + a_e_string_dev = ( + "https://01234567-89ab-cdef-0123-456789abcdef-us-east1" + ".apps.astra-dev.datastax.com" + ) + a_e_string_other = "http://localhost:1234" + mock_astra_db_prod = AstraDB( + token=FAKE_TOKEN, + api_endpoint=a_e_string_prod, + namespace="n", + ) + mock_astra_db_dev = AstraDB( + token=FAKE_TOKEN, + api_endpoint=a_e_string_dev, + namespace="n", + ) + mock_astra_db_other = AstraDB( + token=FAKE_TOKEN, + api_endpoint=a_e_string_other, + namespace="n", + ) + + a_env_prod = _AstraDBEnvironment( + token=FAKE_TOKEN, + api_endpoint=a_e_string_prod, + namespace="n", + ) + assert a_env_prod.environment == Environment.PROD + a_env_dev = _AstraDBEnvironment( + token=FAKE_TOKEN, + api_endpoint=a_e_string_dev, + namespace="n", + ) + assert a_env_dev.environment == Environment.DEV + a_env_other = _AstraDBEnvironment( + token=FAKE_TOKEN, + api_endpoint=a_e_string_other, + namespace="n", + ) + assert a_env_other.environment == Environment.OTHER + + # a funny case + with pytest.raises(ValueError, match="mismatch"): + _AstraDBEnvironment( + token=FAKE_TOKEN, + api_endpoint=a_e_string_prod, + namespace="n", + environment=Environment.DEV, + ) + + # initialization using the core clients + with pytest.warns(DeprecationWarning): + a_env_prod_core = _AstraDBEnvironment( + astra_db_client=mock_astra_db_prod, + ) + assert a_env_prod_core.environment == Environment.PROD + with pytest.warns(DeprecationWarning): + a_env_dev_core = _AstraDBEnvironment( + astra_db_client=mock_astra_db_dev, + ) + assert a_env_dev_core.environment == Environment.DEV + with pytest.warns(DeprecationWarning): + a_env_other_core = _AstraDBEnvironment( + astra_db_client=mock_astra_db_other, + ) + assert a_env_other_core.environment == Environment.OTHER