This repository has been archived by the owner on Jan 9, 2024. It is now read-only.

change insert_one() to insert_many() in mongo_db save function. #455

Merged · 15 commits · Jan 6, 2024
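The core change: instead of calling insert_one() once per passage, save() now batches all passages into a single insert_many() call, and a new upsert flag routes already-present ids through a bulk_write() of UpdateOne operations. A minimal standalone sketch of that pattern with plain pymongo — the connection string, database, and collection names here are illustrative assumptions, not RAGchain's actual configuration:

from pymongo import MongoClient, UpdateOne

# Assumed local MongoDB instance; the names below are hypothetical.
collection = MongoClient("mongodb://localhost:27017")["demo_db"]["passages"]

docs = [
    {"_id": "id_1", "content": "first passage"},
    {"_id": "id_2", "content": "second passage"},
]

# Before: one network round trip per document.
# for doc in docs:
#     collection.insert_one(doc)

# After: a single round trip for the whole batch.
collection.insert_many(docs)

# Upsert path: re-saving the same ids via insert_many would raise
# BulkWriteError on the duplicate _id values, so existing documents are
# updated with one bulk_write of UpdateOne operations instead.
edited = [{**d, "content": d["content"] + " (edited)"} for d in docs]
requests = [UpdateOne({"_id": d["_id"]}, {"$set": d}, upsert=True) for d in edited]
collection.bulk_write(requests)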
37 changes: 25 additions & 12 deletions RAGchain/DB/mongo_db.py
@@ -3,6 +3,7 @@
 from uuid import UUID
 
 import pymongo
+from pymongo import UpdateOne
 
 from RAGchain import linker
 from RAGchain.DB.base import BaseDB
@@ -54,19 +55,31 @@ def create_or_load(self):
         else:
             self.create()
 
-    def save(self, passages: List[Passage]):
+    def save(self, passages: List[Passage], upsert: bool = False):
         """Saves the passages to MongoDB collection."""
-        id_list = []
-        db_origin_list = []
-        for passage in passages:
-            # save to mongoDB
-            passage_to_dict = passage.to_dict()
-            self.collection.insert_one(passage_to_dict)
-            # save to redisDB
-            db_origin = self.get_db_origin()
-            db_origin_dict = db_origin.to_dict()
-            id_list.append(str(passage.id))
-            db_origin_list.append(db_origin_dict)
+        # Set up documents for saving to 'mongodb'
+        dict_passages = [passage.to_dict() for passage in passages]
+        # Set up ids and origins for saving to 'linker'
+        id_list = [str(passage.id) for passage in passages]
+        db_origin_list = [self.get_db_origin().to_dict() for _ in passages]
+
+        # save to 'mongodb'
+        if upsert:
+            db_id_list = [doc['_id'] for doc in self.collection.find({'_id': {'$in': id_list}}, {'_id': 1})]
+            # Build a dictionary of passages keyed by id
+            dict_passages_dict = {_id: dict_passages[i] for i, _id in enumerate(id_list)}
+            if len(db_id_list) > 0:
+                requests = [UpdateOne({'_id': _id},
+                                      {'$set': dict_passages_dict[_id]}, upsert=True) for _id in db_id_list]
+                self.collection.bulk_write(requests)
+            not_duplicated_ids = [_id for _id in id_list if _id not in db_id_list]
+            not_duplicated_passages = [dict_passages_dict[_id] for _id in not_duplicated_ids]
+            if len(not_duplicated_passages) > 0:
+                self.collection.insert_many(not_duplicated_passages)
+        else:
+            self.collection.insert_many(dict_passages)
+
+        # save to 'linker'
         linker.put_json(id_list, db_origin_list)
 
     def fetch(self, ids: List[UUID]) -> List[Passage]:
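From the caller's side, the new flag makes repeated saves idempotent. A brief usage sketch, assuming a reachable MongoDB instance — the constructor parameter names and the Passage import path are assumptions for illustration, not confirmed by this diff:

from RAGchain.DB import MongoDB
from RAGchain.schema import Passage  # import path assumed

db = MongoDB(mongo_url="mongodb://localhost:27017",
             db_name="demo_db", collection_name="passages")  # parameter names assumed
db.create_or_load()

passages = [Passage(id="id_1", content="hello", filepath="./a.txt")]
db.save(passages)               # plain insert_many; duplicate ids raise BulkWriteError
db.save(passages, upsert=True)  # updates the existing documents instead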
2 changes: 1 addition & 1 deletion RAGchain/DB/pickle_db.py
@@ -59,7 +59,7 @@ def save(self, passages: List[Passage]):
         # save to pickleDB
         self.db.extend(passages)
         self._write_pickle()
-        # save to redisDB
+        # save to linker
         db_origin = self.get_db_origin()
         db_origin_dict = db_origin.to_dict()
         id_list = []
13 changes: 13 additions & 0 deletions tests/RAGchain/DB/test_base_db.py
@@ -41,6 +41,19 @@
         importance=-1,
         previous_passage_id='test_id_3',
         next_passage_id=None,
         metadata_etc={'test': 'test4'}
     )
 ]
+
+DUPLICATE_PASSAGE: List[Passage] = [
+    Passage(
+        id='test_id_3',
+        content='Duplicate test',
+        filepath='./test/duplicate_file.txt',
+        content_datetime=datetime(2022, 3, 6),
+        importance=-1,
+        previous_passage_id='test_id_2',
+        next_passage_id=None,
+        metadata_etc={'test': 'test3'}
+    )
+]
8 changes: 8 additions & 0 deletions tests/RAGchain/DB/test_mongo_db.py
@@ -3,6 +3,7 @@
 import pytest
 
 import test_base_db
+from pymongo.errors import BulkWriteError
 from RAGchain.DB import MongoDB
 
 
@@ -34,3 +35,10 @@ def test_db_type(mongo_db):
 
 def test_search(mongo_db):
     test_base_db.search_test_base(mongo_db)
+
+
+def test_duplicate_id(mongo_db):
+    with pytest.raises(BulkWriteError):
+        mongo_db.save(test_base_db.DUPLICATE_PASSAGE)
+    mongo_db.save(test_base_db.DUPLICATE_PASSAGE, upsert=True)
+    assert mongo_db.fetch([test_base_db.DUPLICATE_PASSAGE[0].id]) == test_base_db.DUPLICATE_PASSAGE
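The first assertion in test_duplicate_id relies on MongoDB's unique index on _id: DUPLICATE_PASSAGE reuses id 'test_id_3', so the plain save must fail, while the upsert save must overwrite the stored document. A minimal reproduction with plain pymongo of how insert_many surfaces the duplicate-key failure; connection and collection names are assumed:

from pymongo import MongoClient
from pymongo.errors import BulkWriteError

coll = MongoClient("mongodb://localhost:27017")["demo_db"]["dup_demo"]  # assumed names
coll.delete_many({})
coll.insert_many([{"_id": "test_id_3", "content": "original"}])

try:
    coll.insert_many([{"_id": "test_id_3", "content": "duplicate"}])
except BulkWriteError as exc:
    # Each write error carries code 11000 (duplicate key) for the offending doc.
    print(exc.details["writeErrors"][0]["code"])  # 11000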