Skip to content
This repository has been archived by the owner on Jan 9, 2024. It is now read-only.

change insert_one() to insert_many() in mongo_db save function. #455

Merged
merged 15 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RAGchain/DB/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def create_or_load(self, *args, **kwargs):
pass

@abstractmethod
def save(self, passages: List[Passage], upsert: bool = False):
    """Save passages to the database.

    :param passages: passages to persist.
    :param upsert: when True, implementations replace any stored passage whose
        id matches an incoming one; when False, a duplicate id is an error
        (the exact exception type is backend-specific).
    """
    pass

Expand Down
37 changes: 25 additions & 12 deletions RAGchain/DB/mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from uuid import UUID

import pymongo
from pymongo import UpdateOne

from RAGchain import linker
from RAGchain.DB.base import BaseDB
Expand Down Expand Up @@ -54,19 +55,31 @@ def create_or_load(self):
else:
self.create()

def save(self, passages: List[Passage], upsert: bool = False):
    """Save passages to the MongoDB collection and register them with the linker.

    :param passages: passages to persist.
    :param upsert: when True, passages whose ids already exist in the
        collection overwrite the stored documents and only the new ids are
        inserted. When False, a duplicate id makes ``insert_many`` raise
        ``pymongo.errors.BulkWriteError``.
    """
    # Serialize everything once up front; both branches below reuse these.
    dict_passages = [passage.to_dict() for passage in passages]
    id_list = [str(passage.id) for passage in passages]
    # get_db_origin() is loop-invariant, so call it a single time.
    db_origin = self.get_db_origin()
    db_origin_list = [db_origin.to_dict() for _ in passages]

    # save to 'mongodb'
    if upsert:
        # ids among the incoming passages that are already stored;
        # a set gives O(1) membership tests below instead of O(n) list scans.
        existing_ids = {doc['_id'] for doc in
                        self.collection.find({'_id': {'$in': id_list}}, {'_id': 1})}
        passages_by_id = dict(zip(id_list, dict_passages))
        if existing_ids:
            requests = [UpdateOne({'_id': _id}, {'$set': passages_by_id[_id]}, upsert=True)
                        for _id in id_list if _id in existing_ids]
            self.collection.bulk_write(requests)
        new_passages = [passages_by_id[_id] for _id in id_list if _id not in existing_ids]
        if new_passages:
            self.collection.insert_many(new_passages)
    elif dict_passages:  # insert_many raises InvalidOperation on an empty list
        self.collection.insert_many(dict_passages)

    # save to 'linker'
    linker.put_json(id_list, db_origin_list)

def fetch(self, ids: List[UUID]) -> List[Passage]:
Expand Down
28 changes: 18 additions & 10 deletions RAGchain/DB/pickle_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,28 @@ def create_or_load(self):
else:
self.create()

def save(self, passages: List[Passage], upsert: bool = False):
    """Save passages to the pickle database and register them with the linker.

    :param passages: passages to persist.
    :param upsert: when True, stored passages whose ids match incoming ones
        are replaced by the incoming versions. When False, any duplicate id
        raises ValueError.
    """
    uuid_id_list = [passage.id for passage in passages]
    str_id_list = [str(uuid_id) for uuid_id in uuid_id_list]
    duplicate_ids = [doc.id for doc in self.fetch(uuid_id_list)]

    # save to pickleDB
    if duplicate_ids:
        if upsert:
            # Drop the STORED versions by id so the incoming passages replace
            # them even when their contents differ. (Removing the incoming
            # objects via list.remove relies on full-field equality and fails
            # when the stored passage's content has changed.)
            duplicate_id_set = set(duplicate_ids)
            self.db = [doc for doc in self.db if doc.id not in duplicate_id_set]
        else:
            raise ValueError(f'{duplicate_ids} already exists')
    self.db.extend(passages)
    self._write_pickle()

    # save to linker
    db_origin = self.get_db_origin()  # loop-invariant; compute once
    db_origin_list = [db_origin.to_dict() for _ in passages]
    linker.put_json(str_id_list, db_origin_list)

def fetch(self, ids: List[UUID]) -> List[Passage]:
"""Retrieves the Passage objects from the database based on the given list of passage IDs."""
Expand Down
24 changes: 23 additions & 1 deletion tests/RAGchain/DB/test_base_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime
from typing import List
from typing import List, Type

import pytest

from RAGchain.DB.base import BaseDB
from RAGchain.schema import Passage
Expand Down Expand Up @@ -41,6 +43,19 @@
importance=-1,
previous_passage_id='test_id_3',
next_passage_id=None,
metadata_etc={'test': 'test4'}
)
]

# A passage that reuses id 'test_id_3' from TEST_PASSAGES, used to exercise
# duplicate-id handling (error on plain save vs. replacement on upsert).
DUPLICATE_PASSAGE: List[Passage] = [
    Passage(
        id='test_id_3',
        content='Duplicate test',
        filepath='./test/duplicate_file.txt',
        content_datetime=datetime(2022, 3, 6),
        importance=-1,
        previous_passage_id='test_id_2',
        next_passage_id=None,
        metadata_etc={'test': 'test3'}
    )
]
Expand Down Expand Up @@ -107,3 +122,10 @@ def search_test_base(db: BaseDB):
assert len(test_result_11) == 2
assert 'test_id_2' in [passage.id for passage in test_result_11]
assert 'test_id_4' in [passage.id for passage in test_result_11]


def duplicate_id_test_base(db: BaseDB, error_type: Type[Exception]):
    """Shared check: saving a duplicate id raises, while upsert replaces it."""
    # Plain save of an already-stored id must fail with the backend's error.
    with pytest.raises(error_type):
        db.save(DUPLICATE_PASSAGE)
    # Upsert must succeed and leave the incoming passage as the stored one.
    db.save(DUPLICATE_PASSAGE, upsert=True)
    fetched = db.fetch([DUPLICATE_PASSAGE[0].id])
    assert fetched == DUPLICATE_PASSAGE
13 changes: 9 additions & 4 deletions tests/RAGchain/DB/test_mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import pytest

import test_base_db
from pymongo.errors import BulkWriteError
from RAGchain.DB import MongoDB
from test_base_db import TEST_PASSAGES, fetch_test_base, search_test_base, duplicate_id_test_base


@pytest.fixture(scope='module')
Expand All @@ -14,7 +15,7 @@ def mongo_db():
db_name=os.getenv('MONGO_DB_NAME'),
collection_name=os.getenv('MONGO_COLLECTION_NAME'))
mongo_db.create_or_load()
mongo_db.save(test_base_db.TEST_PASSAGES)
mongo_db.save(TEST_PASSAGES)
yield mongo_db
mongo_db.collection.drop()
assert mongo_db.collection_name not in mongo_db.db.list_collection_names()
Expand All @@ -25,12 +26,16 @@ def test_create_or_load(mongo_db):


def test_fetch(mongo_db):
    # Shared fetch assertions live in test_base_db; call the imported name
    # directly (the module-qualified form was the pre-diff call).
    fetch_test_base(mongo_db)


def test_db_type(mongo_db):
    # The MongoDB backend must report its type identifier as 'mongo_db'.
    assert mongo_db.db_type == 'mongo_db'


def test_search(mongo_db):
    # Shared search assertions live in test_base_db; call the imported name
    # directly (the module-qualified form was the pre-diff call).
    search_test_base(mongo_db)


def test_duplicate_id(mongo_db):
    # Duplicate ids: plain save must raise BulkWriteError; upsert must replace.
    duplicate_id_test_base(mongo_db, BulkWriteError)
6 changes: 5 additions & 1 deletion tests/RAGchain/DB/test_pickle_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from RAGchain.DB import PickleDB
from test_base_db import fetch_test_base, TEST_PASSAGES, search_test_base
from test_base_db import fetch_test_base, TEST_PASSAGES, search_test_base, duplicate_id_test_base


@pytest.fixture(scope='module')
Expand Down Expand Up @@ -35,3 +35,7 @@ def test_db_type(pickle_db):

def test_search(pickle_db):
    # Delegates to the shared search assertions in test_base_db.
    search_test_base(pickle_db)


def test_duplicate_id(pickle_db):
    # Duplicate ids: plain save must raise ValueError; upsert must replace.
    duplicate_id_test_base(pickle_db, ValueError)