Skip to content

Commit

Permalink
auto environment (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
hemidactylus authored Oct 1, 2024
1 parent a64adcd commit 8227c0b
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 6 deletions.
39 changes: 33 additions & 6 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import langchain_core
from astrapy import AsyncDatabase, DataAPIClient, Database
from astrapy.admin import parse_api_endpoint
from astrapy.constants import Environment
from astrapy.exceptions import DataAPIException

if TYPE_CHECKING:
Expand Down Expand Up @@ -60,7 +62,7 @@ def _survey_collection(
keyspace: str | None = None,
) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]:
"""Return the collection descriptor (if found) and a sample of documents."""
_environment = _AstraDBEnvironment(
_astra_db_env = _AstraDBEnvironment(
token=token,
api_endpoint=api_endpoint,
environment=environment,
Expand All @@ -70,21 +72,43 @@ def _survey_collection(
)
descriptors = [
coll_d
for coll_d in _environment.database.list_collections()
for coll_d in _astra_db_env.database.list_collections()
if coll_d.name == collection_name
]
if not descriptors:
return None, []
descriptor = descriptors[0]
# fetch some documents
document_ite = _environment.database.get_collection(collection_name).find(
document_ite = _astra_db_env.database.get_collection(collection_name).find(
filter={},
projection={"*": True},
limit=SURVEY_NUMBER_OF_DOCUMENTS,
)
return (descriptor, list(document_ite))


def _normalize_data_api_environment(
arg_environment: str | None,
api_endpoint: str,
) -> str:
_environment: str
if arg_environment is not None:
return arg_environment
parsed_endpoint = parse_api_endpoint(api_endpoint)
if parsed_endpoint is None:
logger.info(
"Detecting API environment '%s' from supplied endpoint",
Environment.OTHER,
)
return Environment.OTHER

logger.info(
"Detecting API environment '%s' from supplied endpoint",
parsed_endpoint.environment,
)
return parsed_endpoint.environment


class _AstraDBEnvironment:
def __init__(
self,
Expand Down Expand Up @@ -164,7 +188,7 @@ def __init__(
if len(_api_endpoints) != 1:
msg = (
"Conflicting API endpoints found in the sync and async "
"AstraDB constructor parameters. Please check the endpoints "
"AstraDB constructor parameters. Please check the tokens "
"and ensure they match."
)
raise ValueError(msg)
Expand Down Expand Up @@ -208,8 +232,6 @@ def __init__(
self.api_endpoint = _api_endpoint
self.keyspace = _keyspace

self.environment = environment

# init parameters are normalized to self.{token, api_endpoint, keyspace}.
# Proceed. Keyspace and token can be None (resp. on Astra DB and non-Astra)
if self.api_endpoint is None:
Expand All @@ -220,6 +242,11 @@ def __init__(
)
raise ValueError(msg)

self.environment = _normalize_data_api_environment(
environment,
self.api_endpoint,
)

# create the clients
caller_name = "langchain"
caller_version = getattr(langchain_core, "__version__", None)
Expand Down
72 changes: 72 additions & 0 deletions libs/astradb/tests/unit_tests/test_astra_db_environment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import pytest
from astrapy.constants import Environment
from astrapy.db import AstraDB

from langchain_astradb.utils.astradb import (
Expand Down Expand Up @@ -251,3 +252,74 @@ def test_initialization(self) -> None:
del os.environ[KEYSPACE_ENV_VAR]
for env_var_name, env_var_value in env_vars_to_restore.items():
os.environ[env_var_name] = env_var_value

def test_env_autodetect(self) -> None:
a_e_string_prod = (
"https://01234567-89ab-cdef-0123-456789abcdef-us-east1"
".apps.astra.datastax.com"
)
a_e_string_dev = (
"https://01234567-89ab-cdef-0123-456789abcdef-us-east1"
".apps.astra-dev.datastax.com"
)
a_e_string_other = "http://localhost:1234"
mock_astra_db_prod = AstraDB(
token=FAKE_TOKEN,
api_endpoint=a_e_string_prod,
namespace="n",
)
mock_astra_db_dev = AstraDB(
token=FAKE_TOKEN,
api_endpoint=a_e_string_dev,
namespace="n",
)
mock_astra_db_other = AstraDB(
token=FAKE_TOKEN,
api_endpoint=a_e_string_other,
namespace="n",
)

a_env_prod = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string_prod,
keyspace="n",
)
assert a_env_prod.environment == Environment.PROD
a_env_dev = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string_dev,
keyspace="n",
)
assert a_env_dev.environment == Environment.DEV
a_env_other = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string_other,
keyspace="n",
)
assert a_env_other.environment == Environment.OTHER

# a funny case
with pytest.raises(ValueError, match="mismatch"):
_AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string_prod,
keyspace="n",
environment=Environment.DEV,
)

# initialization using the core clients
with pytest.warns(DeprecationWarning):
a_env_prod_core = _AstraDBEnvironment(
astra_db_client=mock_astra_db_prod,
)
assert a_env_prod_core.environment == Environment.PROD
with pytest.warns(DeprecationWarning):
a_env_dev_core = _AstraDBEnvironment(
astra_db_client=mock_astra_db_dev,
)
assert a_env_dev_core.environment == Environment.DEV
with pytest.warns(DeprecationWarning):
a_env_other_core = _AstraDBEnvironment(
astra_db_client=mock_astra_db_other,
)
assert a_env_other_core.environment == Environment.OTHER

0 comments on commit 8227c0b

Please sign in to comment.