diff --git a/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py b/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py index 24bb3852..1e4a6e40 100644 --- a/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py +++ b/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py @@ -1,5 +1,6 @@ import json import logging +from importlib.metadata import version from typing import Dict, List, Optional from langchain_core.chat_history import BaseChatMessageHistory @@ -9,6 +10,7 @@ messages_from_dict, ) from pymongo import MongoClient, errors +from pymongo.driver_info import DriverInfo logger = logging.getLogger(__name__) @@ -112,7 +114,12 @@ def __init__( self.client = client elif connection_string: try: - self.client = MongoClient(connection_string) + self.client = MongoClient( + connection_string, + driver=DriverInfo( + name="Langchain", version=version("langchain-mongodb") + ), + ) except errors.ConnectionFailure as error: logger.error(error) else: diff --git a/libs/langchain-mongodb/langchain_mongodb/indexes.py b/libs/langchain-mongodb/langchain_mongodb/indexes.py index 64cf63e1..d422c937 100644 --- a/libs/langchain-mongodb/langchain_mongodb/indexes.py +++ b/libs/langchain-mongodb/langchain_mongodb/indexes.py @@ -3,12 +3,14 @@ import functools import warnings +from importlib.metadata import version from typing import Any, Dict, List, Optional, Sequence from langchain_core.indexing.base import RecordManager from langchain_core.runnables.config import run_in_executor from pymongo import MongoClient from pymongo.collection import Collection +from pymongo.driver_info import DriverInfo from pymongo.errors import OperationFailure @@ -47,7 +49,10 @@ def from_connection_string( Returns: A new MongoDBRecordManager instance. """ - client: MongoClient = MongoClient(connection_string) + client: MongoClient = MongoClient( + connection_string, + driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), + ) db_name, collection_name = namespace.split(".") collection = client[db_name][collection_name] return cls(collection=collection) diff --git a/libs/langchain-mongodb/langchain_mongodb/loaders.py b/libs/langchain-mongodb/langchain_mongodb/loaders.py index 3c4e5ac8..0112dacc 100644 --- a/libs/langchain-mongodb/langchain_mongodb/loaders.py +++ b/libs/langchain-mongodb/langchain_mongodb/loaders.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging +from importlib.metadata import version from typing import Dict, List, Optional, Sequence from langchain_community.document_loaders.base import BaseLoader @@ -9,6 +10,7 @@ from langchain_core.runnables.config import run_in_executor from pymongo import MongoClient from pymongo.collection import Collection +from pymongo.driver_info import DriverInfo logger = logging.getLogger(__name__) @@ -80,7 +82,10 @@ def from_connection_string( include_db_collection_in_metadata (bool): Flag to include database and collection names in metadata. """ - client = MongoClient(connection_string) + client = MongoClient( + connection_string, + driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), + ) collection = client[db_name][collection_name] return MongoDBLoader( collection, diff --git a/libs/langchain-mongodb/langchain_mongodb/retrievers/parent_document.py b/libs/langchain-mongodb/langchain_mongodb/retrievers/parent_document.py index 0f475c48..5b5ffe94 100644 --- a/libs/langchain-mongodb/langchain_mongodb/retrievers/parent_document.py +++ b/libs/langchain-mongodb/langchain_mongodb/retrievers/parent_document.py @@ -168,7 +168,7 @@ def from_connection_string( """ client: MongoClient = MongoClient( connection_string, - driver=DriverInfo(name="langchain", version=version("langchain-mongodb")), + driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), ) collection = client[database_name][collection_name] vectorstore = MongoDBAtlasVectorSearch( diff --git a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py index 58271130..6d816b1b 100644 --- a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py +++ b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py @@ -79,7 +79,7 @@ class MongoDBAtlasVectorSearch(VectorStore): .. code-block:: python import getpass - MONGODB_ATLAS_CLUSTER_URI = getpass.getpass("MongoDB Atlas Cluster URI:") + MONGODB_ATLAS_CONNECTION_STRING = getpass.getpass("MongoDB Atlas Connection String:") Key init args — indexing params: embedding: Embeddings @@ -99,20 +99,11 @@ class MongoDBAtlasVectorSearch(VectorStore): from pymongo import MongoClient from langchain_openai import OpenAIEmbeddings - # initialize MongoDB python client - client = MongoClient(MONGODB_ATLAS_CLUSTER_URI) - - DB_NAME = "langchain_test_db" - COLLECTION_NAME = "langchain_test_vectorstores" - ATLAS_VECTOR_SEARCH_INDEX_NAME = "langchain-test-index-vectorstores" - - MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME] - - vector_store = MongoDBAtlasVectorSearch( - collection=MONGODB_COLLECTION, + vector_store = MongoDBAtlasVectorSearch.from_connection_string( + connection_string=os=MONGODB_ATLAS_CONNECTION_STRING, + namespace="db_name.collection_name", embedding=OpenAIEmbeddings(), - index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME, - relevance_score_fn="cosine", + index_name="vector_index", ) Add Documents: @@ -279,7 +270,7 @@ def from_connection_string( """ client: MongoClient = MongoClient( connection_string, - driver=DriverInfo(name="Langchain", version=version("langchain")), + driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), ) db_name, collection_name = namespace.split(".") collection = client[db_name][collection_name] diff --git a/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py b/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py index 57a8a537..5307a8e4 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_parent_document.py @@ -33,7 +33,7 @@ def test_1clxn_retriever( # Setup client: MongoClient = MongoClient( connection_string, - driver=DriverInfo(name="langchain", version=version("langchain-mongodb")), + driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), ) db = client[DB_NAME] combined_clxn = db[COLLECTION_NAME] diff --git a/libs/langchain-mongodb/tests/unit_tests/test_chat_message_histories.py b/libs/langchain-mongodb/tests/unit_tests/test_chat_message_histories.py index fd4e2f8b..355ad83a 100644 --- a/libs/langchain-mongodb/tests/unit_tests/test_chat_message_histories.py +++ b/libs/langchain-mongodb/tests/unit_tests/test_chat_message_histories.py @@ -1,9 +1,11 @@ import json +from importlib.metadata import version import mongomock import pytest from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found] from langchain_core.messages import message_to_dict +from pymongo.driver_info import DriverInfo from pytest_mock import MockerFixture from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory @@ -59,7 +61,10 @@ def test_init_with_connection_string(mocker: MockerFixture) -> None: collection_name="test-collection", ) - mock_mongo_client.assert_called_once_with("mongodb://localhost:27017/") + mock_mongo_client.assert_called_once_with( + "mongodb://localhost:27017/", + driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")), + ) assert history.session_id == "test-session" assert history.database_name == "test-database" assert history.collection_name == "test-collection" diff --git a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py index 98039430..d8b2c327 100644 --- a/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py +++ b/libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py @@ -1,5 +1,6 @@ from collections.abc import Iterator, Sequence from contextlib import contextmanager +from importlib.metadata import version from typing import ( Any, Optional, @@ -8,6 +9,7 @@ from langchain_core.runnables import RunnableConfig from pymongo import MongoClient, UpdateOne from pymongo.database import Database as MongoDatabase +from pymongo.driver_info import DriverInfo from langgraph.checkpoint.base import ( WRITES_IDX_MAP, @@ -88,7 +90,12 @@ def from_conn_string( """ client: Optional[MongoClient] = None try: - client = MongoClient(conn_string) + client = MongoClient( + conn_string, + driver=DriverInfo( + name="Langgraph", version=version("langgraph-checkpoint-mongodb") + ), + ) yield MongoDBSaver( client, db_name,