Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates MongoDBGraphStore.__init__ signature #83

Merged
merged 4 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 52 additions & 17 deletions libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,15 @@ class MongoDBGraphStore:
from completely different sources.
- "Jane Smith works with John Doe."
- "Jane Smith works at MongoDB."

"""

def __init__(
self,
collection: Collection,
*,
connection_string: Optional[str] = None,
database_name: Optional[str] = None,
collection_name: Optional[str] = None,
collection: Optional[Collection] = None,
entity_extraction_model: BaseChatModel,
entity_prompt: ChatPromptTemplate = None,
query_prompt: ChatPromptTemplate = None,
Expand All @@ -106,7 +109,11 @@ def __init__(
):
"""
Args:
collection: Collection representing an Entity Graph.
connection_string: A valid MongoDB connection URI.
database_name: The name of the database to connect to.
collection_name: The name of the collection to connect to.
collection: A Collection that will represent a Knowledge Graph.
** One may pass a Collection in lieu of connection_string, database_name, and collection_name.
entity_extraction_model: LLM for converting documents into Graph of Entities and Relationships.
entity_prompt: Prompt to fill graph store with entities following schema.
Defaults to .prompts.ENTITY_EXTRACTION_INSTRUCTIONS
Expand All @@ -122,6 +129,48 @@ def __init__(
- If "warn", the default, documents will be inserted but errors logged.
- If "error", an exception will be raised if any document does not match the schema.
"""
self._schema = deepcopy(entity_schema)
if connection_string and collection is not None:
raise ValueError(
"Pass one of: connection_string, database_name, and collection_name"
"OR a MongoDB Collection."
)
if collection is not None:
if not isinstance(collection, Collection):
raise ValueError(
"collection must be a MongoDB Collection. "
"Consider using connection_string, database_name, and collection_name."
)
if validate:
try:
collection.database.command(
"collMod",
collection.name,
validator={"$jsonSchema": self._schema},
validationAction=validation_action,
)
except OperationFailure:
logger.warning(
"Validation will NOT be performed. "
"User must be DB Admin to add validation **after** a Collection is created. \n"
"Please add validator when you create collection: "
"db.create_collection.(coll_name, validator={'$jsonSchema': schema.entity_schema})"
)
else:
client: MongoClient = MongoClient(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't assume that the collection doesn't already exist just because we're creating the client.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@blink1073 This is inside a warning message.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My question is: if the collection already exists, what is the behavior here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I see what happens now. There would be no validation added. I think we need to always attempt to add validation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do attempt. It's in a try/except block. Let's discuss on slack if we want to delay release of a version for this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not attempt for a collection that already exists, which may have been created by the user, no?

connection_string,
driver=DriverInfo(
name="Langchain", version=version("langchain-mongodb")
),
)
db = client[database_name]
if collection_name not in db.list_collection_names():
validator = {"$jsonSchema": self._schema} if validate else None
collection = client[database_name].create_collection(
collection_name, validator=validator
)
self.collection = collection

self.entity_extraction_model = entity_extraction_model
self.entity_prompt = (
prompts.entity_prompt if entity_prompt is None else entity_prompt
Expand All @@ -145,20 +194,6 @@ def __init__(
] = allowed_relationship_types
else:
self.allowed_relationship_types = []
if validate:
try:
collection.database.command(
"collMod",
collection.name,
validator={"$jsonSchema": self._schema},
validationAction=validation_action,
)
except OperationFailure:
logger.warning(
"Validation will NOT be performed. User must be DB Admin to add validation **after** a Collection is created. \n"
"Please add validator when you create collection: db.create_collection.(coll_name, validator={'$jsonSchema': self._schema})"
)
self.collection = collection

# Include examples
if entity_examples is None:
Expand Down
62 changes: 45 additions & 17 deletions libs/langchain-mongodb/tests/integration_tests/test_graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ def entity_example():
@pytest.fixture(scope="module")
def graph_store(collection, entity_extraction_model, documents) -> MongoDBGraphStore:
store = MongoDBGraphStore(
collection, entity_extraction_model, entity_prompt, query_prompt
collection=collection,
entity_extraction_model=entity_extraction_model,
entity_prompt=entity_prompt,
query_prompt=query_prompt,
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
Expand Down Expand Up @@ -156,15 +159,22 @@ def test_related_entities(graph_store):


def test_additional_entity_examples(entity_extraction_model, entity_example, documents):
# Test additional examples
# First, create one client just to drop any existing collections
client = MongoClient(MONGODB_URI)
db = client[DB_NAME]
clxn_name = f"{COLLECTION_NAME}_addl_examples"
db[clxn_name].drop()
collection = db.create_collection(clxn_name)
client[DB_NAME][clxn_name].drop()
# Test additional examples
store_with_addl_examples = MongoDBGraphStore(
collection, entity_extraction_model, entity_examples=entity_example
connection_string=MONGODB_URI,
database_name=DB_NAME,
collection_name=clxn_name,
entity_extraction_model=entity_extraction_model,
entity_prompt=entity_prompt,
query_prompt=query_prompt,
entity_examples=entity_example,
)
store_with_addl_examples.collection.drop()

store_with_addl_examples.add_documents(documents)
entity_names = ["ACME Corporation", "GreenTech Ltd."]
new_entities = store_with_addl_examples.related_entities(entity_names)
Expand All @@ -187,12 +197,18 @@ def test_similarity_search(graph_store, query_connection):


def test_validator(documents, entity_extraction_model):
# First, create one client just to drop any existing collections
client = MongoClient(MONGODB_URI)
clxn_name = "langchain_test_graphrag_validation"
clxn_name = f"{COLLECTION_NAME}_validation"
client[DB_NAME][clxn_name].drop()
clxn = client[DB_NAME].create_collection(clxn_name)
# now we call with validation that can be added without db admin privileges
store = MongoDBGraphStore(
clxn, entity_extraction_model, validate=True, validation_action="error"
connection_string=MONGODB_URI,
database_name=DB_NAME,
collection_name=clxn_name,
entity_extraction_model=entity_extraction_model,
validate=True,
validation_action="error",
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
Expand All @@ -204,15 +220,20 @@ def test_validator(documents, entity_extraction_model):
def test_allowed_entity_types(documents, entity_extraction_model):
"""Add allowed_entity_types. Use the validator to confirm behaviour."""
allowed_entity_types = ["Person"]
# drop collection
client = MongoClient(MONGODB_URI)
collection_name = f"{COLLECTION_NAME}_allowed_entity_types"
client[DB_NAME][collection_name].drop()
collection = client[DB_NAME].create_collection(collection_name)
# create knowledge graph with only allowed_entity_types
# this changes the schema at runtime
store = MongoDBGraphStore(
collection,
entity_extraction_model,
allowed_entity_types=allowed_entity_types,
validate=True,
validation_action="error",
connection_string=MONGODB_URI,
database_name=DB_NAME,
collection_name=collection_name,
entity_extraction_model=entity_extraction_model,
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
Expand All @@ -224,15 +245,22 @@ def test_allowed_entity_types(documents, entity_extraction_model):


def test_allowed_relationship_types(documents, entity_extraction_model):
# drop collection
client = MongoClient(MONGODB_URI)
collection_name = f"{COLLECTION_NAME}_allowed_relationship_types"
client[DB_NAME][collection_name].drop()
collection = client[DB_NAME].create_collection(collection_name)
clxn_name = f"{COLLECTION_NAME}_allowed_relationship_types"
client[DB_NAME][clxn_name].drop()
collection = client[DB_NAME].create_collection(clxn_name)
collection.drop()
# create knowledge graph with only allowed_relationship_types=["partner"]
# this changes the schema at runtime
store = MongoDBGraphStore(
collection,
entity_extraction_model,
allowed_relationship_types=["partner"],
validate=True,
validation_action="error",
connection_string=MONGODB_URI,
database_name=DB_NAME,
collection_name=clxn_name,
entity_extraction_model=entity_extraction_model,
)
bulkwrite_results = store.add_documents(documents)
assert len(bulkwrite_results) == len(documents)
Expand Down
Loading