Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Updates MongoDBGraphStore.__init__ signature #83

Merged
merged 4 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libs/langchain-mongodb/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

---

## Changes in version 0.5 (2025/02/11)
## Changes in version 0.5 (2025/02/25)

- Added GraphRAG support via `MongoDBGraphStore`

Expand Down
83 changes: 66 additions & 17 deletions libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,15 @@ class MongoDBGraphStore:
from completely different sources.
- "Jane Smith works with John Doe."
- "Jane Smith works at MongoDB."

"""

def __init__(
self,
collection: Collection,
*,
connection_string: Optional[str] = None,
database_name: Optional[str] = None,
collection_name: Optional[str] = None,
collection: Optional[Collection] = None,
entity_extraction_model: BaseChatModel,
entity_prompt: ChatPromptTemplate = None,
query_prompt: ChatPromptTemplate = None,
Expand All @@ -106,7 +109,11 @@ def __init__(
):
"""
Args:
collection: Collection representing an Entity Graph.
connection_string: A valid MongoDB connection URI.
database_name: The name of the database to connect to.
collection_name: The name of the collection to connect to.
collection: A Collection that will represent a Knowledge Graph.
Note: pass either a Collection, or connection_string together with database_name and collection_name. Supplying both a connection_string and a Collection raises ValueError.
entity_extraction_model: LLM for converting documents into Graph of Entities and Relationships.
entity_prompt: Prompt to fill graph store with entities following schema.
Defaults to .prompts.ENTITY_EXTRACTION_INSTRUCTIONS
Expand All @@ -122,6 +129,62 @@ def __init__(
- If "warn", the default, documents will be inserted but errors logged.
- If "error", an exception will be raised if any document does not match the schema.
"""
self._schema = deepcopy(entity_schema)
collection_existed = True
if connection_string and collection is not None:
raise ValueError(
"Pass one of: connection_string, database_name, and collection_name"
"OR a MongoDB Collection."
)
if collection is None: # collection is specified by uri and names
client: MongoClient = MongoClient(
connection_string,
driver=DriverInfo(
name="Langchain", version=version("langchain-mongodb")
),
)
db = client[database_name]
if collection_name not in db.list_collection_names():
validator = {"$jsonSchema": self._schema} if validate else None
collection = client[database_name].create_collection(
collection_name,
validator=validator,
validationAction=validation_action,
)
collection_existed = False
else:
collection = db[collection_name]
else:
if not isinstance(collection, Collection):
raise ValueError(
"collection must be a MongoDB Collection. "
"Consider using connection_string, database_name, and collection_name."
)

if validate and collection_existed:
# first check for existing validator
collection_info = collection.database.command(
"listCollections", filter={"name": collection.name}
)
collection_options = collection_info.get("cursor", {}).get("firstBatch", [])
validator = collection_options[0].get("options", {}).get("validator", None)
if not validator:
try:
collection.database.command(
"collMod",
collection.name,
validator={"$jsonSchema": self._schema},
validationAction=validation_action,
)
except OperationFailure:
logger.warning(
"Validation will NOT be performed. "
"User must be DB Admin to add validation **after** a Collection is created. \n"
"Please add validator when you create collection: "
"db.create_collection.(coll_name, validator={'$jsonSchema': schema.entity_schema})"
)
self.collection = collection

self.entity_extraction_model = entity_extraction_model
self.entity_prompt = (
prompts.entity_prompt if entity_prompt is None else entity_prompt
Expand All @@ -145,20 +208,6 @@ def __init__(
] = allowed_relationship_types
else:
self.allowed_relationship_types = []
if validate:
try:
collection.database.command(
"collMod",
collection.name,
validator={"$jsonSchema": self._schema},
validationAction=validation_action,
)
except OperationFailure:
logger.warning(
"Validation will NOT be performed. User must be DB Admin to add validation **after** a Collection is created. \n"
"Please add validator when you create collection: db.create_collection.(coll_name, validator={'$jsonSchema': self._schema})"
)
self.collection = collection

# Include examples
if entity_examples is None:
Expand Down
1 change: 1 addition & 0 deletions libs/langchain-mongodb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dev = [
"langchain-openai>=0.2.14",
"langchain-community>=0.3.14",
"pypdf>=5.0.1",
"flaky>=3.8.1",
]

[tool.pytest.ini_options]
Expand Down
103 changes: 86 additions & 17 deletions libs/langchain-mongodb/tests/integration_tests/test_graphrag.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

import pytest
from flaky import flaky
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
Expand Down Expand Up @@ -117,7 +118,10 @@ def entity_example():
@pytest.fixture(scope="module")
def graph_store(collection, entity_extraction_model, documents) -> MongoDBGraphStore:
store = MongoDBGraphStore(
collection, entity_extraction_model, entity_prompt, query_prompt
collection=collection,
entity_extraction_model=entity_extraction_model,
entity_prompt=entity_prompt,
query_prompt=query_prompt,
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
Expand All @@ -136,6 +140,7 @@ def test_add_docs_store(graph_store):
assert 4 <= len(extracted_entities) < 8


@flaky
def test_extract_entity_names(graph_store, query_connection):
query_entity_names = graph_store.extract_entity_names(query_connection)
assert set(query_entity_names) == {"John Doe", "Jane Smith"}
Expand All @@ -145,6 +150,7 @@ def test_extract_entity_names(graph_store, query_connection):
assert len(no_names) == 0


@flaky
def test_related_entities(graph_store):
entity_names = ["John Doe", "Jane Smith"]
related_entities = graph_store.related_entities(entity_names)
Expand All @@ -155,29 +161,39 @@ def test_related_entities(graph_store):
assert len(no_entities) == 0


@flaky
def test_additional_entity_examples(entity_extraction_model, entity_example, documents):
# Test additional examples
# First, create one client just to drop any existing collections
client = MongoClient(MONGODB_URI)
db = client[DB_NAME]
clxn_name = f"{COLLECTION_NAME}_addl_examples"
db[clxn_name].drop()
collection = db.create_collection(clxn_name)
client[DB_NAME][clxn_name].drop()
# Test additional examples
store_with_addl_examples = MongoDBGraphStore(
collection, entity_extraction_model, entity_examples=entity_example
connection_string=MONGODB_URI,
database_name=DB_NAME,
collection_name=clxn_name,
entity_extraction_model=entity_extraction_model,
entity_prompt=entity_prompt,
query_prompt=query_prompt,
entity_examples=entity_example,
)
store_with_addl_examples.collection.drop()

store_with_addl_examples.add_documents(documents)
entity_names = ["ACME Corporation", "GreenTech Ltd."]
new_entities = store_with_addl_examples.related_entities(entity_names)
assert len(new_entities) >= 2


@flaky
def test_chat_response(graph_store, query_connection):
    """Query the graph store built by the ``graph_store`` fixture via chat.

    The response must be an ``AIMessage`` whose text mentions
    "acme corporation" (compared case-insensitively).
    """
    response = graph_store.chat_response(query_connection)
    assert isinstance(response, AIMessage)
    # Lower-case the text so the check does not depend on the model's casing.
    lowered_text = response.content.lower()
    assert "acme corporation" in lowered_text


@flaky
def test_similarity_search(graph_store, query_connection):
docs = graph_store.similarity_search(query_connection)
assert len(docs) >= 4
Expand All @@ -186,33 +202,78 @@ def test_similarity_search(graph_store, query_connection):
assert any("attributes" in d.keys() for d in docs)


@flaky
def test_validator(documents, entity_extraction_model):
# Case 1. No existing collection.
client = MongoClient(MONGODB_URI)
clxn_name = "langchain_test_graphrag_validation"
clxn_name = f"{COLLECTION_NAME}_validation"
client[DB_NAME][clxn_name].drop()
clxn = client[DB_NAME].create_collection(clxn_name)
# now we call with validation that can be added without db admin privileges
store = MongoDBGraphStore(
clxn, entity_extraction_model, validate=True, validation_action="error"
connection_string=MONGODB_URI,
database_name=DB_NAME,
collection_name=clxn_name,
entity_extraction_model=entity_extraction_model,
validate=True,
validation_action="error",
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
entities = store.collection.find({}).to_list()
# Using subset because SolarGrid Initiative is not always considered an entity
assert {"Person", "Organization"}.issubset(set(e["type"] for e in entities))
client.close()

# Case 2: Existing collection with a validator
client = MongoClient(MONGODB_URI)
clxn_name = f"{COLLECTION_NAME}_validation"
collection = client[DB_NAME][clxn_name]
collection.delete_many({})

store = MongoDBGraphStore(
collection=collection,
entity_extraction_model=entity_extraction_model,
validate=True,
validation_action="error",
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
collection.drop()
client.close()

# Case 3: Existing collection without a validator
client = MongoClient(MONGODB_URI)
clxn_name = f"{COLLECTION_NAME}_validation"
collection = client[DB_NAME].create_collection(clxn_name)
store = MongoDBGraphStore(
collection=collection,
entity_extraction_model=entity_extraction_model,
validate=True,
validation_action="error",
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
client.close()


@flaky
def test_allowed_entity_types(documents, entity_extraction_model):
"""Add allowed_entity_types. Use the validator to confirm behaviour."""
allowed_entity_types = ["Person"]
# drop collection
client = MongoClient(MONGODB_URI)
collection_name = f"{COLLECTION_NAME}_allowed_entity_types"
client[DB_NAME][collection_name].drop()
collection = client[DB_NAME].create_collection(collection_name)
# create knowledge graph with only allowed_entity_types
# this changes the schema at runtime
store = MongoDBGraphStore(
collection,
entity_extraction_model,
allowed_entity_types=allowed_entity_types,
validate=True,
validation_action="error",
connection_string=MONGODB_URI,
database_name=DB_NAME,
collection_name=collection_name,
entity_extraction_model=entity_extraction_model,
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
Expand All @@ -223,16 +284,24 @@ def test_allowed_entity_types(documents, entity_extraction_model):
all([len(e["relationships"].get("attributes", [])) == 0 for e in entities])


@flaky
def test_allowed_relationship_types(documents, entity_extraction_model):
# drop collection
client = MongoClient(MONGODB_URI)
collection_name = f"{COLLECTION_NAME}_allowed_relationship_types"
client[DB_NAME][collection_name].drop()
collection = client[DB_NAME].create_collection(collection_name)
clxn_name = f"{COLLECTION_NAME}_allowed_relationship_types"
client[DB_NAME][clxn_name].drop()
collection = client[DB_NAME].create_collection(clxn_name)
collection.drop()
# create knowledge graph with only allowed_relationship_types=["partner"]
# this changes the schema at runtime
store = MongoDBGraphStore(
collection,
entity_extraction_model,
allowed_relationship_types=["partner"],
validate=True,
validation_action="error",
connection_string=MONGODB_URI,
database_name=DB_NAME,
collection_name=clxn_name,
entity_extraction_model=entity_extraction_model,
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
Expand Down
11 changes: 11 additions & 0 deletions libs/langchain-mongodb/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading