Skip to content

Commit

Permalink
INTPYTHON-504 Add DriverInfo to MongoClients (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
caseyclements authored Feb 5, 2025
1 parent ccdef94 commit 88a0858
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from importlib.metadata import version
from typing import Dict, List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
Expand All @@ -9,6 +10,7 @@
messages_from_dict,
)
from pymongo import MongoClient, errors
from pymongo.driver_info import DriverInfo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,7 +114,12 @@ def __init__(
self.client = client
elif connection_string:
try:
self.client = MongoClient(connection_string)
self.client = MongoClient(
connection_string,
driver=DriverInfo(
name="Langchain", version=version("langchain-mongodb")
),
)
except errors.ConnectionFailure as error:
logger.error(error)
else:
Expand Down
7 changes: 6 additions & 1 deletion libs/langchain-mongodb/langchain_mongodb/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

import functools
import warnings
from importlib.metadata import version
from typing import Any, Dict, List, Optional, Sequence

from langchain_core.indexing.base import RecordManager
from langchain_core.runnables.config import run_in_executor
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.errors import OperationFailure


Expand Down Expand Up @@ -47,7 +49,10 @@ def from_connection_string(
Returns:
A new MongoDBRecordManager instance.
"""
client: MongoClient = MongoClient(connection_string)
client: MongoClient = MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
db_name, collection_name = namespace.split(".")
collection = client[db_name][collection_name]
return cls(collection=collection)
Expand Down
7 changes: 6 additions & 1 deletion libs/langchain-mongodb/langchain_mongodb/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from __future__ import annotations

import logging
from importlib.metadata import version
from typing import Dict, List, Optional, Sequence

from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from langchain_core.runnables.config import run_in_executor
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,7 +82,10 @@ def from_connection_string(
include_db_collection_in_metadata (bool): Flag to include database and
collection names in metadata.
"""
client = MongoClient(connection_string)
client = MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
collection = client[db_name][collection_name]
return MongoDBLoader(
collection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def from_connection_string(
"""
client: MongoClient = MongoClient(
connection_string,
driver=DriverInfo(name="langchain", version=version("langchain-mongodb")),
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
collection = client[database_name][collection_name]
vectorstore = MongoDBAtlasVectorSearch(
Expand Down
21 changes: 6 additions & 15 deletions libs/langchain-mongodb/langchain_mongodb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
.. code-block:: python
import getpass
MONGODB_ATLAS_CLUSTER_URI = getpass.getpass("MongoDB Atlas Cluster URI:")
MONGODB_ATLAS_CONNECTION_STRING = getpass.getpass("MongoDB Atlas Connection String:")
Key init args — indexing params:
embedding: Embeddings
Expand All @@ -99,20 +99,11 @@ class MongoDBAtlasVectorSearch(VectorStore):
from pymongo import MongoClient
from langchain_openai import OpenAIEmbeddings
# initialize MongoDB python client
client = MongoClient(MONGODB_ATLAS_CLUSTER_URI)
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_vectorstores"
ATLAS_VECTOR_SEARCH_INDEX_NAME = "langchain-test-index-vectorstores"
MONGODB_COLLECTION = client[DB_NAME][COLLECTION_NAME]
vector_store = MongoDBAtlasVectorSearch(
collection=MONGODB_COLLECTION,
vector_store = MongoDBAtlasVectorSearch.from_connection_string(
connection_string=os=MONGODB_ATLAS_CONNECTION_STRING,
namespace="db_name.collection_name",
embedding=OpenAIEmbeddings(),
index_name=ATLAS_VECTOR_SEARCH_INDEX_NAME,
relevance_score_fn="cosine",
index_name="vector_index",
)
Add Documents:
Expand Down Expand Up @@ -279,7 +270,7 @@ def from_connection_string(
"""
client: MongoClient = MongoClient(
connection_string,
driver=DriverInfo(name="Langchain", version=version("langchain")),
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
db_name, collection_name = namespace.split(".")
collection = client[db_name][collection_name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_1clxn_retriever(
# Setup
client: MongoClient = MongoClient(
connection_string,
driver=DriverInfo(name="langchain", version=version("langchain-mongodb")),
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
db = client[DB_NAME]
combined_clxn = db[COLLECTION_NAME]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from importlib.metadata import version

import mongomock
import pytest
from langchain.memory import ConversationBufferMemory # type: ignore[import-not-found]
from langchain_core.messages import message_to_dict
from pymongo.driver_info import DriverInfo
from pytest_mock import MockerFixture

from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
Expand Down Expand Up @@ -59,7 +61,10 @@ def test_init_with_connection_string(mocker: MockerFixture) -> None:
collection_name="test-collection",
)

mock_mongo_client.assert_called_once_with("mongodb://localhost:27017/")
mock_mongo_client.assert_called_once_with(
"mongodb://localhost:27017/",
driver=DriverInfo(name="Langchain", version=version("langchain-mongodb")),
)
assert history.session_id == "test-session"
assert history.database_name == "test-database"
assert history.collection_name == "test-collection"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from importlib.metadata import version
from typing import (
Any,
Optional,
Expand All @@ -8,6 +9,7 @@
from langchain_core.runnables import RunnableConfig
from pymongo import MongoClient, UpdateOne
from pymongo.database import Database as MongoDatabase
from pymongo.driver_info import DriverInfo

from langgraph.checkpoint.base import (
WRITES_IDX_MAP,
Expand Down Expand Up @@ -88,7 +90,12 @@ def from_conn_string(
"""
client: Optional[MongoClient] = None
try:
client = MongoClient(conn_string)
client = MongoClient(
conn_string,
driver=DriverInfo(
name="Langgraph", version=version("langgraph-checkpoint-mongodb")
),
)
yield MongoDBSaver(
client,
db_name,
Expand Down

0 comments on commit 88a0858

Please sign in to comment.