Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic environment from data api endpoint (if not supplied) #88

Merged
merged 2 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import langchain_core
from astrapy import AsyncDatabase, DataAPIClient, Database
from astrapy.admin import parse_api_endpoint
from astrapy.constants import Environment
from astrapy.exceptions import DataAPIException

if TYPE_CHECKING:
Expand Down Expand Up @@ -60,7 +62,7 @@ def _survey_collection(
keyspace: str | None = None,
) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]:
"""Return the collection descriptor (if found) and a sample of documents."""
_environment = _AstraDBEnvironment(
_astra_db_env = _AstraDBEnvironment(
token=token,
api_endpoint=api_endpoint,
environment=environment,
Expand All @@ -70,21 +72,43 @@ def _survey_collection(
)
descriptors = [
coll_d
for coll_d in _environment.database.list_collections()
for coll_d in _astra_db_env.database.list_collections()
if coll_d.name == collection_name
]
if not descriptors:
return None, []
descriptor = descriptors[0]
# fetch some documents
document_ite = _environment.database.get_collection(collection_name).find(
document_ite = _astra_db_env.database.get_collection(collection_name).find(
filter={},
projection={"*": True},
limit=SURVEY_NUMBER_OF_DOCUMENTS,
)
return (descriptor, list(document_ite))


def _normalize_data_api_environment(
arg_environment: str | None,
api_endpoint: str,
) -> str:
_environment: str
if arg_environment is not None:
return arg_environment
parsed_endpoint = parse_api_endpoint(api_endpoint)
if parsed_endpoint is None:
logger.info(
"Detecting API environment '%s' from supplied endpoint",
Environment.OTHER,
)
return Environment.OTHER

logger.info(
"Detecting API environment '%s' from supplied endpoint",
parsed_endpoint.environment,
)
return parsed_endpoint.environment


class _AstraDBEnvironment:
def __init__(
self,
Expand Down Expand Up @@ -164,7 +188,7 @@ def __init__(
if len(_api_endpoints) != 1:
msg = (
"Conflicting API endpoints found in the sync and async "
"AstraDB constructor parameters. Please check the endpoints "
"AstraDB constructor parameters. Please check the endpoints "
"and ensure they match."
)
raise ValueError(msg)
Expand Down Expand Up @@ -208,8 +232,6 @@ def __init__(
self.api_endpoint = _api_endpoint
self.keyspace = _keyspace

self.environment = environment

# init parameters are normalized to self.{token, api_endpoint, keyspace}.
# Proceed. Keyspace and token can be None (resp. on Astra DB and non-Astra)
if self.api_endpoint is None:
Expand All @@ -220,6 +242,11 @@ def __init__(
)
raise ValueError(msg)

self.environment = _normalize_data_api_environment(
environment,
self.api_endpoint,
)

# create the clients
caller_name = "langchain"
caller_version = getattr(langchain_core, "__version__", None)
Expand Down
72 changes: 72 additions & 0 deletions libs/astradb/tests/unit_tests/test_astra_db_environment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import pytest
from astrapy.constants import Environment
from astrapy.db import AstraDB

from langchain_astradb.utils.astradb import (
Expand Down Expand Up @@ -251,3 +252,74 @@ def test_initialization(self) -> None:
del os.environ[KEYSPACE_ENV_VAR]
for env_var_name, env_var_value in env_vars_to_restore.items():
os.environ[env_var_name] = env_var_value

def test_env_autodetect(self) -> None:
    """Check how ``_AstraDBEnvironment`` resolves its ``environment``.

    Covers three cases: detection from the API endpoint when no environment
    is passed, precedence of an explicitly passed environment, and detection
    when initializing from the (deprecated) core ``AstraDB`` clients.
    """
    a_e_string_prod = (
        "https://01234567-89ab-cdef-0123-456789abcdef-us-east1"
        ".apps.astra.datastax.com"
    )
    a_e_string_dev = (
        "https://01234567-89ab-cdef-0123-456789abcdef-us-east1"
        ".apps.astra-dev.datastax.com"
    )
    a_e_string_other = "http://localhost:1234"
    mock_astra_db_prod = AstraDB(
        token=FAKE_TOKEN,
        api_endpoint=a_e_string_prod,
        namespace="n",
    )
    mock_astra_db_dev = AstraDB(
        token=FAKE_TOKEN,
        api_endpoint=a_e_string_dev,
        namespace="n",
    )
    mock_astra_db_other = AstraDB(
        token=FAKE_TOKEN,
        api_endpoint=a_e_string_other,
        namespace="n",
    )

    # no environment supplied: it is detected from the endpoint
    a_env_prod = _AstraDBEnvironment(
        token=FAKE_TOKEN,
        api_endpoint=a_e_string_prod,
        keyspace="n",
    )
    assert a_env_prod.environment == Environment.PROD
    a_env_dev = _AstraDBEnvironment(
        token=FAKE_TOKEN,
        api_endpoint=a_e_string_dev,
        keyspace="n",
    )
    assert a_env_dev.environment == Environment.DEV
    a_env_other = _AstraDBEnvironment(
        token=FAKE_TOKEN,
        api_endpoint=a_e_string_other,
        keyspace="n",
    )
    assert a_env_other.environment == Environment.OTHER

    # a funny case: an explicitly supplied environment is kept as-is, even
    # when it contradicts what the endpoint would suggest (the normalization
    # step returns a non-None environment without inspecting the endpoint).
    a_env_prod_forceddev = _AstraDBEnvironment(
        token=FAKE_TOKEN,
        api_endpoint=a_e_string_prod,
        keyspace="n",
        environment=Environment.DEV,
    )
    assert a_env_prod_forceddev.environment == Environment.DEV

    # initialization using the core clients
    with pytest.warns(DeprecationWarning):
        a_env_prod_core = _AstraDBEnvironment(
            astra_db_client=mock_astra_db_prod,
        )
    assert a_env_prod_core.environment == Environment.PROD
    with pytest.warns(DeprecationWarning):
        a_env_dev_core = _AstraDBEnvironment(
            astra_db_client=mock_astra_db_dev,
        )
    assert a_env_dev_core.environment == Environment.DEV
    with pytest.warns(DeprecationWarning):
        a_env_other_core = _AstraDBEnvironment(
            astra_db_client=mock_astra_db_other,
        )
    assert a_env_other_core.environment == Environment.OTHER
Loading