Skip to content

Commit

Permalink
bring latest main
Browse files Browse the repository at this point in the history
  • Loading branch information
hemidactylus committed Oct 1, 2024
2 parents 4086f77 + a64adcd commit 850aa26
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 49 deletions.
4 changes: 2 additions & 2 deletions libs/astradb/langchain_astradb/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
Expand Down Expand Up @@ -411,7 +411,7 @@ async def _acache_embedding(text: str) -> list[float]:
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
Expand Down
2 changes: 1 addition & 1 deletion libs/astradb/langchain_astradb/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
Expand Down
4 changes: 2 additions & 2 deletions libs/astradb/langchain_astradb/document_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=namespace,
setup_mode=SetupMode.OFF,
)
self.astra_db_env = astra_db_env
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
self.page_content_mapper = page_content_mapper
self.metadata_mapper = metadata_mapper or (
lambda _: {
"namespace": self.astra_db_env.database.namespace,
"namespace": self.astra_db_env.database.keyspace,
"api_endpoint": self.astra_db_env.database.api_endpoint,
"collection": collection_name,
}
Expand Down
4 changes: 4 additions & 0 deletions libs/astradb/langchain_astradb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
raise ValueError(msg)
kwargs["requested_indexing_policy"] = {"allow": ["_id"]}
kwargs["default_indexing_policy"] = {"allow": ["_id"]}

if "namespace" in kwargs:
kwargs["keyspace"] = kwargs.pop("namespace")

self.astra_env = _AstraDBCollectionEnvironment(
*args,
**kwargs,
Expand Down
40 changes: 20 additions & 20 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

TOKEN_ENV_VAR = "ASTRA_DB_APPLICATION_TOKEN" # noqa: S105
API_ENDPOINT_ENV_VAR = "ASTRA_DB_API_ENDPOINT"
NAMESPACE_ENV_VAR = "ASTRA_DB_KEYSPACE"
KEYSPACE_ENV_VAR = "ASTRA_DB_KEYSPACE"

# Default settings for API data operations (concurrency & similar):
# Chunk size for many-document insertions (None meaning defer to astrapy):
Expand Down Expand Up @@ -59,7 +59,7 @@ def _survey_collection(
environment: str | None = None,
astra_db_client: AstraDB | None = None,
async_astra_db_client: AsyncAstraDB | None = None,
namespace: str | None = None,
keyspace: str | None = None,
) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]:
"""Return the collection descriptor (if found) and a sample of documents."""
_astra_db_env = _AstraDBEnvironment(
Expand All @@ -68,7 +68,7 @@ def _survey_collection(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=keyspace,
)
descriptors = [
coll_d
Expand Down Expand Up @@ -117,11 +117,11 @@ def __init__(
environment: str | None = None,
astra_db_client: AstraDB | None = None,
async_astra_db_client: AsyncAstraDB | None = None,
namespace: str | None = None,
keyspace: str | None = None,
) -> None:
self.token: str | TokenProvider | None
self.api_endpoint: str | None
self.namespace: str | None
self.keyspace: str | None
self.environment: str | None

self.data_api_client: DataAPIClient
Expand Down Expand Up @@ -171,7 +171,7 @@ def __init__(
if klient is not None
}
)
_namespaces = list(
_keyspaces = list(
{
klient.namespace
for klient in [astra_db_client, async_astra_db_client]
Expand All @@ -192,17 +192,17 @@ def __init__(
"and ensure they match."
)
raise ValueError(msg)
if len(_namespaces) != 1:
if len(_keyspaces) != 1:
msg = (
"Conflicting namespaces found in the sync and async "
"AstraDB constructor parameters. Please check the tokens "
"and ensure they match."
"Conflicting keyspaces found in the sync and async "
"AstraDB constructor parameters' 'namespace' attributes. "
"Please check the keyspaces and ensure they match."
)
raise ValueError(msg)
# all good: these are 1-element lists here
self.token = _tokens[0]
self.api_endpoint = _api_endpoints[0]
self.namespace = _namespaces[0]
self.keyspace = _keyspaces[0]
else:
_token: str | TokenProvider | None
# secrets-based initialization
Expand All @@ -223,17 +223,17 @@ def __init__(
_api_endpoint = os.environ.get(API_ENDPOINT_ENV_VAR)
else:
_api_endpoint = api_endpoint
if namespace is None:
_namespace = os.environ.get(NAMESPACE_ENV_VAR)
if keyspace is None:
_keyspace = os.environ.get(KEYSPACE_ENV_VAR)
else:
_namespace = namespace
_keyspace = keyspace

self.token = _token
self.api_endpoint = _api_endpoint
self.namespace = _namespace
self.keyspace = _keyspace

# init parameters are normalized to self.{token, api_endpoint, namespace}.
# Proceed. Namespace and token can be None (resp. on Astra DB and non-Astra)
# init parameters are normalized to self.{token, api_endpoint, keyspace}.
# Proceed. Keyspace and token can be None (resp. on Astra DB and non-Astra)
if self.api_endpoint is None:
msg = (
"API endpoint for Data API not provided. "
Expand All @@ -259,7 +259,7 @@ def __init__(
self.database = self.data_api_client.get_database(
api_endpoint=self.api_endpoint,
token=self.token,
keyspace=self.namespace,
keyspace=self.keyspace,
)
self.async_database = self.database.to_async()

Expand All @@ -274,7 +274,7 @@ def __init__(
environment: str | None = None,
astra_db_client: AstraDB | None = None,
async_astra_db_client: AsyncAstraDB | None = None,
namespace: str | None = None,
keyspace: str | None = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding_dimension: int | Awaitable[int] | None = None,
Expand All @@ -290,7 +290,7 @@ def __init__(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=keyspace,
)
self.collection_name = collection_name
self.collection = self.database.get_collection(
Expand Down
4 changes: 2 additions & 2 deletions libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def __init__(
environment=self.environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=self.namespace,
keyspace=self.namespace,
)
if c_descriptor is None:
msg = f"Collection '{self.collection_name}' not found."
Expand Down Expand Up @@ -653,7 +653,7 @@ def __init__(
environment=self.environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=self.namespace,
keyspace=self.namespace,
setup_mode=_setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=_embedding_dimension,
Expand Down
2 changes: 1 addition & 1 deletion libs/astradb/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def database(
db = client.get_database(
astra_db_credentials["api_endpoint"],
token=StaticTokenProvider(astra_db_credentials["token"]),
namespace=astra_db_credentials["namespace"],
keyspace=astra_db_credentials["namespace"],
)
if not is_astra_db:
if astra_db_credentials["namespace"] is None:
Expand Down
6 changes: 3 additions & 3 deletions libs/astradb/tests/integration_tests/test_document_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_astradb_loader_base_sync(
assert content["_id"] not in ids
ids.add(content["_id"])
assert doc.metadata == {
"namespace": database.namespace,
"namespace": database.keyspace,
"api_endpoint": astra_db_credentials["api_endpoint"],
"collection": document_loader_collection.name,
}
Expand Down Expand Up @@ -189,7 +189,7 @@ async def test_astradb_loader_prefetched_async(
assert content["_id"] not in ids
ids.add(content["_id"])
assert doc.metadata == {
"namespace": database.namespace,
"namespace": database.keyspace,
"api_endpoint": astra_db_credentials["api_endpoint"],
"collection": async_document_loader_collection.name,
}
Expand Down Expand Up @@ -234,7 +234,7 @@ async def test_astradb_loader_base_async(
assert content["_id"] not in ids
ids.add(content["_id"])
assert doc.metadata == {
"namespace": database.namespace,
"namespace": database.keyspace,
"api_endpoint": astra_db_credentials["api_endpoint"],
"collection": async_document_loader_collection.name,
}
Expand Down
36 changes: 18 additions & 18 deletions libs/astradb/tests/unit_tests/test_astra_db_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from langchain_astradb.utils.astradb import (
API_ENDPOINT_ENV_VAR,
NAMESPACE_ENV_VAR,
KEYSPACE_ENV_VAR,
TOKEN_ENV_VAR,
_AstraDBEnvironment,
)
Expand Down Expand Up @@ -48,15 +48,15 @@ def test_initialization(self) -> None:
API_ENDPOINT_ENV_VAR
]
del os.environ[API_ENDPOINT_ENV_VAR]
if NAMESPACE_ENV_VAR in os.environ:
env_vars_to_restore[NAMESPACE_ENV_VAR] = os.environ[NAMESPACE_ENV_VAR]
del os.environ[NAMESPACE_ENV_VAR]
if KEYSPACE_ENV_VAR in os.environ:
env_vars_to_restore[KEYSPACE_ENV_VAR] = os.environ[KEYSPACE_ENV_VAR]
del os.environ[KEYSPACE_ENV_VAR]

# token+endpoint
env1 = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string,
namespace="n",
keyspace="n",
)

# through a core AstraDB instance
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_initialization(self) -> None:
)
with pytest.raises(
ValueError,
match="Conflicting namespaces found in the sync and async AstraDB "
match="Conflicting keyspaces found in the sync and async AstraDB "
"constructor parameters.",
), pytest.warns(DeprecationWarning):
_AstraDBEnvironment(
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_initialization(self) -> None:
os.environ[TOKEN_ENV_VAR] = "t"
env4 = _AstraDBEnvironment(
api_endpoint=a_e_string,
namespace="n",
keyspace="n",
)
del os.environ[TOKEN_ENV_VAR]
assert env1.data_api_client == env4.data_api_client
Expand All @@ -186,7 +186,7 @@ def test_initialization(self) -> None:
os.environ[API_ENDPOINT_ENV_VAR] = a_e_string
env5 = _AstraDBEnvironment(
token=FAKE_TOKEN,
namespace="n",
keyspace="n",
)
del os.environ[API_ENDPOINT_ENV_VAR]
assert env1.data_api_client == env5.data_api_client
Expand All @@ -196,19 +196,19 @@ def test_initialization(self) -> None:
# both and also namespace via env vars
os.environ[TOKEN_ENV_VAR] = FAKE_TOKEN
os.environ[API_ENDPOINT_ENV_VAR] = a_e_string
os.environ[NAMESPACE_ENV_VAR] = "n"
os.environ[KEYSPACE_ENV_VAR] = "n"
env6 = _AstraDBEnvironment()
assert env1.data_api_client == env6.data_api_client
assert env1.database == env6.database
assert env1.async_database == env6.async_database
del os.environ[TOKEN_ENV_VAR]
del os.environ[API_ENDPOINT_ENV_VAR]
del os.environ[NAMESPACE_ENV_VAR]
del os.environ[KEYSPACE_ENV_VAR]

# env vars do not interfere if client(s) passed
os.environ[TOKEN_ENV_VAR] = "NO!"
os.environ[API_ENDPOINT_ENV_VAR] = "NO!"
os.environ[NAMESPACE_ENV_VAR] = "NO!"
os.environ[KEYSPACE_ENV_VAR] = "NO!"
with pytest.warns(DeprecationWarning):
env7a = _AstraDBEnvironment(
async_astra_db_client=mock_astra_db.to_async(),
Expand Down Expand Up @@ -236,7 +236,7 @@ def test_initialization(self) -> None:
env8 = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string,
namespace="n",
keyspace="n",
)
assert env1.data_api_client == env8.data_api_client
assert env1.database == env8.database
Expand All @@ -248,8 +248,8 @@ def test_initialization(self) -> None:
del os.environ[TOKEN_ENV_VAR]
if API_ENDPOINT_ENV_VAR in os.environ:
del os.environ[API_ENDPOINT_ENV_VAR]
if NAMESPACE_ENV_VAR in os.environ:
del os.environ[NAMESPACE_ENV_VAR]
if KEYSPACE_ENV_VAR in os.environ:
del os.environ[KEYSPACE_ENV_VAR]
for env_var_name, env_var_value in env_vars_to_restore.items():
os.environ[env_var_name] = env_var_value

Expand Down Expand Up @@ -282,19 +282,19 @@ def test_env_autodetect(self) -> None:
a_env_prod = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string_prod,
namespace="n",
keyspace="n",
)
assert a_env_prod.environment == Environment.PROD
a_env_dev = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string_dev,
namespace="n",
keyspace="n",
)
assert a_env_dev.environment == Environment.DEV
a_env_other = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string_other,
namespace="n",
keyspace="n",
)
assert a_env_other.environment == Environment.OTHER

Expand All @@ -303,7 +303,7 @@ def test_env_autodetect(self) -> None:
_AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string_prod,
namespace="n",
keyspace="n",
environment=Environment.DEV,
)

Expand Down

0 comments on commit 850aa26

Please sign in to comment.