Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use astrapy 1.5+ naming conventions #87

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libs/astradb/langchain_astradb/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
Expand Down Expand Up @@ -411,7 +411,7 @@ async def _acache_embedding(text: str) -> list[float]:
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
Expand Down
2 changes: 1 addition & 1 deletion libs/astradb/langchain_astradb/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
Expand Down
4 changes: 2 additions & 2 deletions libs/astradb/langchain_astradb/document_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=namespace,
setup_mode=SetupMode.OFF,
)
self.astra_db_env = astra_db_env
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
self.page_content_mapper = page_content_mapper
self.metadata_mapper = metadata_mapper or (
lambda _: {
"namespace": self.astra_db_env.database.namespace,
"namespace": self.astra_db_env.database.keyspace,
"api_endpoint": self.astra_db_env.database.api_endpoint,
"collection": collection_name,
}
Expand Down
4 changes: 4 additions & 0 deletions libs/astradb/langchain_astradb/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
raise ValueError(msg)
kwargs["requested_indexing_policy"] = {"allow": ["_id"]}
kwargs["default_indexing_policy"] = {"allow": ["_id"]}

if "namespace" in kwargs:
kwargs["keyspace"] = kwargs.pop("namespace")

self.astra_env = _AstraDBCollectionEnvironment(
*args,
**kwargs,
Expand Down
42 changes: 21 additions & 21 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

TOKEN_ENV_VAR = "ASTRA_DB_APPLICATION_TOKEN" # noqa: S105
API_ENDPOINT_ENV_VAR = "ASTRA_DB_API_ENDPOINT"
NAMESPACE_ENV_VAR = "ASTRA_DB_KEYSPACE"
KEYSPACE_ENV_VAR = "ASTRA_DB_KEYSPACE"

# Default settings for API data operations (concurrency & similar):
# Chunk size for many-document insertions (None meaning defer to astrapy):
Expand Down Expand Up @@ -57,7 +57,7 @@ def _survey_collection(
environment: str | None = None,
astra_db_client: AstraDB | None = None,
async_astra_db_client: AsyncAstraDB | None = None,
namespace: str | None = None,
keyspace: str | None = None,
) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]:
"""Return the collection descriptor (if found) and a sample of documents."""
_environment = _AstraDBEnvironment(
Expand All @@ -66,7 +66,7 @@ def _survey_collection(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=keyspace,
)
descriptors = [
coll_d
Expand All @@ -93,11 +93,11 @@ def __init__(
environment: str | None = None,
astra_db_client: AstraDB | None = None,
async_astra_db_client: AsyncAstraDB | None = None,
namespace: str | None = None,
keyspace: str | None = None,
) -> None:
self.token: str | TokenProvider | None
self.api_endpoint: str | None
self.namespace: str | None
self.keyspace: str | None
self.environment: str | None

self.data_api_client: DataAPIClient
Expand Down Expand Up @@ -147,7 +147,7 @@ def __init__(
if klient is not None
}
)
_namespaces = list(
_keyspaces = list(
{
klient.namespace
for klient in [astra_db_client, async_astra_db_client]
Expand All @@ -164,21 +164,21 @@ def __init__(
if len(_api_endpoints) != 1:
msg = (
"Conflicting API endpoints found in the sync and async "
"AstraDB constructor parameters. Please check the tokens "
"AstraDB constructor parameters. Please check the endpoints "
"and ensure they match."
)
raise ValueError(msg)
if len(_namespaces) != 1:
if len(_keyspaces) != 1:
msg = (
"Conflicting namespaces found in the sync and async "
"AstraDB constructor parameters. Please check the tokens "
"and ensure they match."
"Conflicting keyspaces found in the sync and async "
"AstraDB constructor parameters' 'namespace' attributes. "
"Please check the keyspaces and ensure they match."
)
raise ValueError(msg)
# all good: these are 1-element lists here
self.token = _tokens[0]
self.api_endpoint = _api_endpoints[0]
self.namespace = _namespaces[0]
self.keyspace = _keyspaces[0]
else:
_token: str | TokenProvider | None
# secrets-based initialization
Expand All @@ -199,19 +199,19 @@ def __init__(
_api_endpoint = os.environ.get(API_ENDPOINT_ENV_VAR)
else:
_api_endpoint = api_endpoint
if namespace is None:
_namespace = os.environ.get(NAMESPACE_ENV_VAR)
if keyspace is None:
_keyspace = os.environ.get(KEYSPACE_ENV_VAR)
else:
_namespace = namespace
_keyspace = keyspace

self.token = _token
self.api_endpoint = _api_endpoint
self.namespace = _namespace
self.keyspace = _keyspace

self.environment = environment

# init parameters are normalized to self.{token, api_endpoint, namespace}.
# Proceed. Namespace and token can be None (resp. on Astra DB and non-Astra)
# init parameters are normalized to self.{token, api_endpoint, keyspace}.
# Proceed. Keyspace and token can be None (resp. on Astra DB and non-Astra)
if self.api_endpoint is None:
msg = (
"API endpoint for Data API not provided. "
Expand All @@ -232,7 +232,7 @@ def __init__(
self.database = self.data_api_client.get_database(
api_endpoint=self.api_endpoint,
token=self.token,
keyspace=self.namespace,
keyspace=self.keyspace,
)
self.async_database = self.database.to_async()

Expand All @@ -247,7 +247,7 @@ def __init__(
environment: str | None = None,
astra_db_client: AstraDB | None = None,
async_astra_db_client: AsyncAstraDB | None = None,
namespace: str | None = None,
keyspace: str | None = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding_dimension: int | Awaitable[int] | None = None,
Expand All @@ -263,7 +263,7 @@ def __init__(
environment=environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
keyspace=keyspace,
)
self.collection_name = collection_name
self.collection = self.database.get_collection(
Expand Down
4 changes: 2 additions & 2 deletions libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def __init__(
environment=self.environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=self.namespace,
keyspace=self.namespace,
)
if c_descriptor is None:
msg = f"Collection '{self.collection_name}' not found."
Expand Down Expand Up @@ -653,7 +653,7 @@ def __init__(
environment=self.environment,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=self.namespace,
keyspace=self.namespace,
setup_mode=_setup_mode,
pre_delete_collection=pre_delete_collection,
embedding_dimension=_embedding_dimension,
Expand Down
2 changes: 1 addition & 1 deletion libs/astradb/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def database(
db = client.get_database(
astra_db_credentials["api_endpoint"],
token=StaticTokenProvider(astra_db_credentials["token"]),
namespace=astra_db_credentials["namespace"],
keyspace=astra_db_credentials["namespace"],
)
if not is_astra_db:
if astra_db_credentials["namespace"] is None:
Expand Down
6 changes: 3 additions & 3 deletions libs/astradb/tests/integration_tests/test_document_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_astradb_loader_base_sync(
assert content["_id"] not in ids
ids.add(content["_id"])
assert doc.metadata == {
"namespace": database.namespace,
"namespace": database.keyspace,
"api_endpoint": astra_db_credentials["api_endpoint"],
"collection": document_loader_collection.name,
}
Expand Down Expand Up @@ -189,7 +189,7 @@ async def test_astradb_loader_prefetched_async(
assert content["_id"] not in ids
ids.add(content["_id"])
assert doc.metadata == {
"namespace": database.namespace,
"namespace": database.keyspace,
"api_endpoint": astra_db_credentials["api_endpoint"],
"collection": async_document_loader_collection.name,
}
Expand Down Expand Up @@ -234,7 +234,7 @@ async def test_astradb_loader_base_async(
assert content["_id"] not in ids
ids.add(content["_id"])
assert doc.metadata == {
"namespace": database.namespace,
"namespace": database.keyspace,
"api_endpoint": astra_db_credentials["api_endpoint"],
"collection": async_document_loader_collection.name,
}
Expand Down
28 changes: 14 additions & 14 deletions libs/astradb/tests/unit_tests/test_astra_db_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from langchain_astradb.utils.astradb import (
API_ENDPOINT_ENV_VAR,
NAMESPACE_ENV_VAR,
KEYSPACE_ENV_VAR,
TOKEN_ENV_VAR,
_AstraDBEnvironment,
)
Expand Down Expand Up @@ -47,15 +47,15 @@ def test_initialization(self) -> None:
API_ENDPOINT_ENV_VAR
]
del os.environ[API_ENDPOINT_ENV_VAR]
if NAMESPACE_ENV_VAR in os.environ:
env_vars_to_restore[NAMESPACE_ENV_VAR] = os.environ[NAMESPACE_ENV_VAR]
del os.environ[NAMESPACE_ENV_VAR]
if KEYSPACE_ENV_VAR in os.environ:
env_vars_to_restore[KEYSPACE_ENV_VAR] = os.environ[KEYSPACE_ENV_VAR]
del os.environ[KEYSPACE_ENV_VAR]

# token+endpoint
env1 = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string,
namespace="n",
keyspace="n",
)

# through a core AstraDB instance
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_initialization(self) -> None:
)
with pytest.raises(
ValueError,
match="Conflicting namespaces found in the sync and async AstraDB "
match="Conflicting keyspaces found in the sync and async AstraDB "
"constructor parameters.",
), pytest.warns(DeprecationWarning):
_AstraDBEnvironment(
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_initialization(self) -> None:
os.environ[TOKEN_ENV_VAR] = "t"
env4 = _AstraDBEnvironment(
api_endpoint=a_e_string,
namespace="n",
keyspace="n",
)
del os.environ[TOKEN_ENV_VAR]
assert env1.data_api_client == env4.data_api_client
Expand All @@ -185,7 +185,7 @@ def test_initialization(self) -> None:
os.environ[API_ENDPOINT_ENV_VAR] = a_e_string
env5 = _AstraDBEnvironment(
token=FAKE_TOKEN,
namespace="n",
keyspace="n",
)
del os.environ[API_ENDPOINT_ENV_VAR]
assert env1.data_api_client == env5.data_api_client
Expand All @@ -195,19 +195,19 @@ def test_initialization(self) -> None:
# both and also namespace via env vars
os.environ[TOKEN_ENV_VAR] = FAKE_TOKEN
os.environ[API_ENDPOINT_ENV_VAR] = a_e_string
os.environ[NAMESPACE_ENV_VAR] = "n"
os.environ[KEYSPACE_ENV_VAR] = "n"
env6 = _AstraDBEnvironment()
assert env1.data_api_client == env6.data_api_client
assert env1.database == env6.database
assert env1.async_database == env6.async_database
del os.environ[TOKEN_ENV_VAR]
del os.environ[API_ENDPOINT_ENV_VAR]
del os.environ[NAMESPACE_ENV_VAR]
del os.environ[KEYSPACE_ENV_VAR]

# env vars do not interfere if client(s) passed
os.environ[TOKEN_ENV_VAR] = "NO!"
os.environ[API_ENDPOINT_ENV_VAR] = "NO!"
os.environ[NAMESPACE_ENV_VAR] = "NO!"
os.environ[KEYSPACE_ENV_VAR] = "NO!"
with pytest.warns(DeprecationWarning):
env7a = _AstraDBEnvironment(
async_astra_db_client=mock_astra_db.to_async(),
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_initialization(self) -> None:
env8 = _AstraDBEnvironment(
token=FAKE_TOKEN,
api_endpoint=a_e_string,
namespace="n",
keyspace="n",
)
assert env1.data_api_client == env8.data_api_client
assert env1.database == env8.database
Expand All @@ -247,7 +247,7 @@ def test_initialization(self) -> None:
del os.environ[TOKEN_ENV_VAR]
if API_ENDPOINT_ENV_VAR in os.environ:
del os.environ[API_ENDPOINT_ENV_VAR]
if NAMESPACE_ENV_VAR in os.environ:
del os.environ[NAMESPACE_ENV_VAR]
if KEYSPACE_ENV_VAR in os.environ:
del os.environ[KEYSPACE_ENV_VAR]
for env_var_name, env_var_value in env_vars_to_restore.items():
os.environ[env_var_name] = env_var_value
Loading