Skip to content

Commit

Permalink
Include more information about the caller
Browse files Browse the repository at this point in the history
When making calls to AstraDB, we include a `User-Agent`
header identifying the API caller. This header is currently
hard-coded to "langchain". This change threads a
`caller_name` parameter throughout the API, with each class defaulting
to a value that appends a suffix identifying which class created the API client:

* langchain/cache
* langchain/chat_message_history
* langchain/document_loader
* langchain/graph_vectorstore
* langchain/semantic_cache
* langchain/vectorstore

This metadata will be useful for debugging and tracking utilization of
different library features.
  • Loading branch information
kerinin committed Sep 25, 2024
1 parent 80ea19c commit 27a9810
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 1 deletion.
4 changes: 4 additions & 0 deletions libs/astradb/langchain_astradb/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
namespace: str | None = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
caller_name: str = "langchain/cache",
):
"""Cache that uses Astra DB as a backend.
Expand Down Expand Up @@ -159,6 +160,7 @@ def __init__(
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
caller_name=caller_name,
)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
Expand Down Expand Up @@ -326,6 +328,7 @@ def __init__(
embedding: Embeddings,
metric: str | None = None,
similarity_threshold: float = ASTRA_DB_SEMANTIC_CACHE_DEFAULT_THRESHOLD,
caller_name: str = "langchain/semantic_cache",
):
"""Astra DB semantic cache.
Expand Down Expand Up @@ -416,6 +419,7 @@ async def _acache_embedding(text: str) -> list[float]:
pre_delete_collection=pre_delete_collection,
embedding_dimension=embedding_dimension,
metric=metric,
caller_name=caller_name,
)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
Expand Down
2 changes: 2 additions & 0 deletions libs/astradb/langchain_astradb/chat_message_histories.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
namespace: str | None = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
caller_name: str = "langchain/chat_message_history",
) -> None:
"""Chat message history that stores history in Astra DB.
Expand Down Expand Up @@ -84,6 +85,7 @@ def __init__(
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
caller_name=caller_name,
)

self.collection = self.astra_env.collection
Expand Down
2 changes: 2 additions & 0 deletions libs/astradb/langchain_astradb/document_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
nb_prefetched: int = _NOT_SET, # type: ignore[assignment]
page_content_mapper: Callable[[dict], str] = json.dumps,
metadata_mapper: Callable[[dict], dict[str, Any]] | None = None,
caller_name: str = "langchain/document_loader",
) -> None:
"""Load DataStax Astra DB documents.
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=SetupMode.OFF,
caller_name=caller_name,
)
self.astra_db_env = astra_db_env
self.filter = filter_criteria
Expand Down
2 changes: 2 additions & 0 deletions libs/astradb/langchain_astradb/graph_vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
metadata_indexing_include: Iterable[str] | None = None,
metadata_indexing_exclude: Iterable[str] | None = None,
collection_indexing_policy: dict[str, Any] | None = None,
caller_name: str = "langchain/graph_vectorstore",
**kwargs: Any,
):
"""Create a new Graph Vector Store backed by AstraDB."""
Expand All @@ -68,6 +69,7 @@ def __init__(
metadata_indexing_include=metadata_indexing_include,
metadata_indexing_exclude=metadata_indexing_exclude,
collection_indexing_policy=collection_indexing_policy,
caller_name=caller_name,
**kwargs,
)

Expand Down
4 changes: 3 additions & 1 deletion libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
astra_db_client: AstraDB | None = None,
async_astra_db_client: AsyncAstraDB | None = None,
namespace: str | None = None,
caller_name: str = "langchain",
) -> None:
self.token: str | TokenProvider | None
self.api_endpoint: str | None
Expand Down Expand Up @@ -221,7 +222,6 @@ def __init__(
raise ValueError(msg)

# create the clients
caller_name = "langchain"
caller_version = getattr(langchain_core, "__version__", None)

self.data_api_client = DataAPIClient(
Expand Down Expand Up @@ -256,6 +256,7 @@ def __init__(
default_indexing_policy: dict[str, Any] | None = None,
collection_vector_service_options: CollectionVectorServiceOptions | None = None,
collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None,
caller_name: str = "langchain",
) -> None:
super().__init__(
token=token,
Expand All @@ -264,6 +265,7 @@ def __init__(
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
caller_name=caller_name,
)
self.collection_name = collection_name
self.collection = self.database.get_collection(
Expand Down
2 changes: 2 additions & 0 deletions libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def __init__(
content_field: str | None = None,
ignore_invalid_documents: bool = False,
autodetect_collection: bool = False,
caller_name: str = "langchain/vectorstore",
) -> None:
"""Wrapper around DataStax Astra DB for vector-store workloads.
Expand Down Expand Up @@ -664,6 +665,7 @@ def __init__(
default_indexing_policy=DEFAULT_INDEXING_OPTIONS,
collection_vector_service_options=self.collection_vector_service_options,
collection_embedding_api_key=self.collection_embedding_api_key,
caller_name=caller_name,
)

def _get_safe_embedding(self) -> Embeddings:
Expand Down

0 comments on commit 27a9810

Please sign in to comment.