diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 183af705..c1b29d08 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -53,7 +53,7 @@ def _lazy_init(self) -> None: self.store = self._get_store() self.index = DefaultIndex(self.embed, self.store) - if not self.store.has_group(LAZY_ROOT_NAME): + if not self.store.has_node(LAZY_ROOT_NAME): ids, pathes = self._list_files() root_nodes = self._reader.load_data(pathes) self.store.update_nodes(root_nodes) @@ -176,9 +176,11 @@ def _add_files(self, input_files: List[str]): root_nodes = self._reader.load_data(input_files) temp_store = self._get_store() temp_store.update_nodes(root_nodes) - active_groups = self.store.active_groups() - LOG.info(f"add_files: Trying to merge store with {active_groups}") - for group in active_groups: + all_groups = self.store.all_groups() + LOG.info(f"add_files: Trying to merge store with {all_groups}") + for group in all_groups: + if not self.store.has_node(group): + continue # Duplicate group will be discarded automatically nodes = self._get_nodes(group, temp_store) self.store.update_nodes(nodes) @@ -212,7 +214,7 @@ def gather_children(node: DocNode): LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}") def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None: - if store.has_group(group_name): + if store.has_node(group_name): return node_group = self.node_groups.get(group_name) if node_group is None: @@ -228,7 +230,7 @@ def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None: def _get_nodes(self, group_name: str, store: Optional[BaseStore] = None) -> List[DocNode]: store = store or self.store self._dynamic_create_nodes(group_name, store) - return store.traverse_nodes(group_name) + return store.get_nodes(group_name) def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_off: Union[float, Dict[str, float]], index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]: diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 558f5bcc..239a0ba2 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -174,18 +174,20 @@ def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]: raise NotImplementedError("not implemented yet.") @abstractmethod - def remove_nodes(self, nodes: List[DocNode]) -> None: + def get_nodes(self, group_name: str) -> List[DocNode]: raise NotImplementedError("not implemented yet.") @abstractmethod - def has_group(self, group_name: str) -> bool: + def remove_nodes(self, nodes: List[DocNode]) -> None: raise NotImplementedError("not implemented yet.") @abstractmethod - def traverse_group(self, group_name: str) -> List[DocNode]: + def has_node(self, group_name: str) -> bool: raise NotImplementedError("not implemented yet.") - # XXX NOTE the following APIs should be private. + @abstractmethod + def all_groups(self) -> List[str]: + raise NotImplementedError("not implemented yet.") @abstractmethod def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: @@ -219,24 +221,17 @@ def remove_nodes(self, nodes: List[DocNode]) -> None: self._group2docs[node.group].pop(node.uid, None) # override - def has_group(self, group_name: str) -> bool: - return group_name in self._group2docs + def has_node(self, group_name: str) -> bool: + docs = self._group2docs.get(group_name) + return True if docs else False # override - def traverse_group(self, group_name: str) -> List[DocNode]: + def get_nodes(self, group_name: str) -> List[DocNode]: return list(self._group2docs.get(group_name, {}).values()) - def get_group_docs(self) -> Dict[str, Dict[str, DocNode]]: - return self._group2docs - - def find_node_by_uid(self, uid: str) -> Optional[DocNode]: - for docs in self._group2docs.values(): - doc = docs.get(uid) - if doc: - return doc - return None - - # XXX NOTE the following APIs should be private. + # override + def all_groups(self) -> List[str]: + return [group for group, nodes in self._group2docs.items()] # override def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: @@ -246,6 +241,13 @@ def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: nodes.append(self._file_node_map[file]) return nodes + def find_node_by_uid(self, uid: str) -> Optional[DocNode]: + for docs in self._group2docs.values(): + doc = docs.get(uid) + if doc: + return doc + return None + # ---------------------------------------------------------------------------- # class ChromadbStore(BaseStore): @@ -276,12 +278,16 @@ def remove_nodes(self, nodes: List[DocNode]) -> None: return self._map_store.remove_nodes(nodes) # override - def has_group(self, group_name: str) -> bool: - return self._map_store.has_group(group_name) + def has_node(self, group_name: str) -> bool: + return self._map_store.has_node(group_name) # override - def traverse_group(self, group_name: str) -> List[DocNode]: - return self._map_store.traverse_group(group_name) + def get_nodes(self, group_name: str) -> List[DocNode]: + return self._map_store.get_nodes(group_name) + + # override + def all_groups(self) -> List[str]: + return self._map_store.all_groups() def _load_store(self) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: @@ -295,9 +301,8 @@ def _load_store(self) -> None: self._map_store.update_nodes(nodes) # Rebuild relationships - group2docs = self._map_store.get_group_docs() - for group, nodes_dict in group2docs.items(): - for node in nodes_dict.values(): + for group_name in self._map_store.all_groups(): + for node in self._map_store.get_nodes(group_name): if node.parent: parent_uid = node.parent parent_node = self._map_store.find_node_by_uid(parent_uid) @@ -338,7 +343,6 @@ def _save_nodes(self, nodes: List[DocNode]) -> None: ) LOG.debug(f"Saved {group} nodes {ids} to chromadb.") - # XXX NOTE should be private. # override def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: return self._map_store.get_nodes_by_files(files) @@ -372,122 +376,73 @@ def _peek_all_documents(self, group: str) -> Dict[str, List]: # ---------------------------------------------------------------------------- # +class MilvusEmbeddingIndexField: + def __init__(self, name: str = "", dim: int = 0, type: int = pymilvus.DataType.FLOAT_VECTOR): + self.name = name + self.dim: int = dim + self.type: int = type + class MilvusStore(BaseStore): - def __init__(self, node_groups: List[str], uri: str): + def __init__(self, node_groups: List[str], uri: str, + embedding_index_info: List[MilvusEmbeddingIndexField], # fields to be indexed by Milvus + full_data_store: BaseStore): + self._full_data_store = full_data_store + + self._primary_key = 'uid' self._client = pymilvus.MilvusClient(uri=uri) schema = self._client.create_schema(auto_id=False, enable_dynamic_field=True) schema.add_field( - field_name='uid', + field_name=self._primary_key, datatype=pymilvus.DataType.VARCHAR, - max_length=65535, + max_length=128, is_primary=True, ) - schema.add_field( - field_name='text', - datatype=pymilvus.DataType.VARCHAR, - max_length=65535, - ) - schema.add_field( - field_name='group', - datatype=pymilvus.DataType.VARCHAR, - max_length=1024, - ) - schema.add_field( - field_name='embedding_json', - datatype=pymilvus.DataType.VARCHAR, - max_length=65535, - ) - schema.add_field( - field_name='parent_uid', - datatype=pymilvus.DataType.VARCHAR, - max_length=1024, - ) - schema.add_field( - field_name='metadata_json', - datatype=pymilvus.DataType.VARCHAR, - max_length=65535, - ) + for field in embedding_index_info: + schema.add_field( + field_name=field.name, + datatype=field.type, + dim=field.dim) for group in node_groups: if group not in self._client.list_collections(): self._client.create_collection(collection_name=group, schema=schema) - self._map_store = MapStore(node_groups) - self._load_store() - # override def update_nodes(self, nodes: List[DocNode]) -> None: - self._map_store.update_nodes(nodes) self._save_nodes(nodes) + self._full_data_store.update_nodes(nodes) # override def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]: - return self._map_store.get_node(group_name, node_id) + return self._full_data_store.get_node(group_name, node_id) # override def remove_nodes(self, nodes: List[DocNode]) -> None: for node in nodes: - self._client.delete(collection_name=node.group, filter=f'uid in ["{node.uid}"]') - self._map_store.remove_nodes(nodes) + self._client.delete(collection_name=node.group, + filter=f'{self._primary_key} in ["{node.uid}"]') + self._full_data_store.remove_nodes(nodes) # override - def has_group(self, group_name: str) -> bool: - return self._map_store.has_group(group_name) + def has_node(self, group_name: str) -> bool: + return self._full_data_store.has_node(group_name) # override - def traverse_group(self, group_name: str) -> List[DocNode]: - return self._map_store.traverse_group(group_name) + def get_nodes(self, group_name: str) -> List[DocNode]: + return self._full_data_store.get_nodes(group_name) - def _load_store(self) -> None: - groups = self._client.list_collections() - for group in groups: - results = self._client.query(collection_name=group, - output_fields=["count(*)"], - limit=1) - if len(results) != 1: - raise ValueError(f"query count(*) of collection [{group}] failed.") - - count = int(results[0]['count(*)']) - if count == 0: - continue + # override + def all_groups(self) -> List[str]: + return self._full_data_store.all_groups() - results = self._client.query(collection_name=group, - query_expression="uid in []", - output_fields=["*"], - limit=count) - for record in results: - doc_node = DocNode( - uid=record['uid'], - text=record['text'], - group=record['group'], - embedding=json.loads(record['embedding_json']), - parent=record['parent_uid'], # NOTE: will be updated later - metadata=json.loads(record['metadata_json']), - ) - self._map_store.update_nodes([doc_node]) - - # update doc's parent - group_docs = self._map_store.get_group_docs() - for group, docs in group_docs.items(): - for uid, doc in docs.items(): - # before find_node_by_uid() `doc.parent` is the parent's uid - doc.parent = self._map_store.find_node_by_uid(doc.parent) + # override + def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: + return self._full_data_store.get_nodes_by_files(files) def _save_nodes(self, nodes: List[DocNode]) -> None: for node in nodes: - data = { - 'uid': node.uid, - 'text': node.text, - 'group': node.group, - 'embedding_json': json.dumps(node.embedding), - 'parent_uid': node.parent.uid if node.parent else '', - 'metadata_json': json.dumps(node._metadata) - } - self._client.upsert(collection_name=node.group, data=data) - - # XXX NOTE the following APIs should be private. - - # override - def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: - return self._map_store.get_nodes_by_files(files) + if node.embedding: + data = node.embedding.copy() + data[self._primary_key] = node.uid + self._client.upsert(collection_name=node.group, data=data) diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index e62dc277..8ea52582 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -59,16 +59,16 @@ def test_retrieve(self): def test_add_files(self): assert self.doc_impl.store is None self.doc_impl._lazy_init() - assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 1 + assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 1 new_doc = DocNode(text="new dummy text", group=LAZY_ROOT_NAME) new_doc.metadata = {"file_name": "new_file.txt"} self.mock_directory_reader.load_data.return_value = [new_doc] self.doc_impl._add_files(["new_file.txt"]) - assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 2 + assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 2 def test_delete_files(self): self.doc_impl._delete_files(["dummy_file.txt"]) - assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 0 + assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 0 class TestDocument(unittest.TestCase): diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index e07a6384..20d8a198 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -3,8 +3,15 @@ import unittest import lazyllm import tempfile +import pymilvus from unittest.mock import MagicMock -from lazyllm.tools.rag.store import DocNode, MapStore, ChromadbStore, MilvusStore, LAZY_ROOT_NAME +from lazyllm.tools.rag.store import ( + DocNode, + MapStore, + ChromadbStore, + MilvusStore, MilvusEmbeddingIndexField, + LAZY_ROOT_NAME +) def clear_directory(directory_path): @@ -38,11 +45,11 @@ def tearDownClass(cls): def test_initialization(self): self.assertEqual(set(self.store._collections.keys()), set(self.node_groups)) - def test_add_and_traverse_group(self): + def test_add_and_get_nodes(self): node1 = DocNode(uid="1", text="text1", group="group1") node2 = DocNode(uid="2", text="text2", group="group2") self.store.update_nodes([node1, node2]) - nodes = self.store.traverse_group("group1") + nodes = self.store.get_nodes("group1") self.assertEqual(nodes, [node1]) def test_save_nodes(self): @@ -59,168 +66,152 @@ def test_load_store(self): self.store.update_nodes([node1, node2]) # Reset store and load from "persistent" storage - self.store._mapstore._group2docs = {group: {} for group in self.node_groups} + self.store._map_store._group2docs = {group: {} for group in self.node_groups} self.store._load_store() - nodes = self.store.traverse_group("group1") + nodes = self.store.get_nodes("group1") self.assertEqual(len(nodes), 2) self.assertEqual(nodes[0].uid, "1") self.assertEqual(nodes[1].uid, "2") self.assertEqual(nodes[1].parent.uid, "1") -class TestMapStore(unittest.TestCase): - def test_update_nodes(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - store = MapStore(node_groups) + def test_all_groups(self): + self.assertEqual(set(self.store.all_groups()), set(self.node_groups)) + def test_group_others(self): node1 = DocNode(uid="1", text="text1", group="group1", parent=None) node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - store.update_nodes([node1, node2]) + self.store.update_nodes([node1, node2]) + self.assertEqual(self.store.has_node("group1"), True) + self.assertEqual(self.store.has_node("group2"), True) + +class TestMapStore(unittest.TestCase): + def setUp(self): + self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] + self.store = MapStore(self.node_groups) + self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None) + self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1) - nodes = store.traverse_group("group1") + def test_update_nodes(self): + self.store.update_nodes([self.node1, self.node2]) + nodes = self.store.get_nodes("group1") self.assertEqual(len(nodes), 2) self.assertEqual(nodes[0].uid, "1") self.assertEqual(nodes[1].uid, "2") self.assertEqual(nodes[1].parent.uid, "1") def test_get_node(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - store = MapStore(node_groups) - - node1 = DocNode(uid="1", text="text1", group="group1", parent=None) - node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - store.update_nodes([node1, node2]) - - n1 = store.get_node("group1", "1") - assert n1.text == node1.text - - n2 = store.get_node("group1", "2") - assert n2.text == node2.text + self.store.update_nodes([self.node1, self.node2]) + n1 = self.store.get_node("group1", "1") + assert n1.text == self.node1.text + n2 = self.store.get_node("group1", "2") + assert n2.text == self.node2.text def test_remove_nodes(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - store = MapStore(node_groups) - - node1 = DocNode(uid="1", text="text1", group="group1", parent=None) - node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - store.update_nodes([node1, node2]) + self.store.update_nodes([self.node1, self.node2]) - n1 = store.get_node("group1", "1") - assert n1.text == node1.text - store.remove_nodes([n1]) - n1 = store.get_node("group1", "1") + n1 = self.store.get_node("group1", "1") + assert n1.text == self.node1.text + self.store.remove_nodes([n1]) + n1 = self.store.get_node("group1", "1") assert not n1 - n2 = store.get_node("group1", "2") - assert n2.text == node2.text - store.remove_nodes([n2]) - n2 = store.get_node("group1", "2") + n2 = self.store.get_node("group1", "2") + assert n2.text == self.node2.text + self.store.remove_nodes([n2]) + n2 = self.store.get_node("group1", "2") assert not n2 - def test_traverse_group(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - store = MapStore(node_groups) - - node1 = DocNode(uid="1", text="text1", group="group1", parent=None) - node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - store.update_nodes([node1, node2]) - ids = set([node1.uid, node2.uid]) + def test_get_nodes(self): + self.store.update_nodes([self.node1, self.node2]) + ids = set([self.node1.uid, self.node2.uid]) - docs = store.traverse_group("group1") + docs = self.store.get_nodes("group1") self.assertEqual(ids, set([doc.uid for doc in docs])) + def test_all_groups(self): + self.assertEqual(set(self.store.all_groups()), set(self.node_groups)) + def test_group_others(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - store = MapStore(node_groups) - self.assertEqual(store.has_group("group1"), True) - self.assertEqual(store.has_group("group2"), True) + self.store.update_nodes([self.node1, self.node2]) + self.assertEqual(self.store.has_node("group1"), True) + self.assertEqual(self.store.has_node("group2"), False) class TestMilvusStore(unittest.TestCase): - def test_update_nodes(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - _, store_file = tempfile.mkstemp(suffix=".db") - store = MilvusStore(node_groups, store_file) - - node1 = DocNode(uid="1", text="text1", group="group1", parent=None) - node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - store.update_nodes([node1, node2]) + def setUp(self): + self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - # Reset store and load from "persistent" storage - store._map_store._group2docs = {group: {} for group in node_groups} - store._load_store() + self.map_store = MapStore(self.node_groups) - nodes = store.traverse_group("group1") - self.assertEqual(len(nodes), 2) - self.assertEqual(nodes[0].uid, "1") - self.assertEqual(nodes[1].uid, "2") - self.assertEqual(nodes[1].parent.uid, "1") + index_field_list = [ + MilvusEmbeddingIndexField("vec1", 3, pymilvus.DataType.FLOAT_VECTOR), + MilvusEmbeddingIndexField("vec2", 5, pymilvus.DataType.FLOAT_VECTOR), + ] + _, self.store_file = tempfile.mkstemp(suffix=".db") + self.store = MilvusStore(node_groups=self.node_groups, uri=self.store_file, + embedding_index_info=index_field_list, + full_data_store=self.map_store) - os.remove(store_file) + self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, + embedding={"vec1": [1, 2, 3], "vec2": [4, 5, 6, 7, 8]}) + self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, + embedding={"vec1": [100, 200, 300], "vec2": [400, 500, 600, 700, 800]}) + self.nodes = [self.node1, self.node2] - def test_get_node(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - _, store_file = tempfile.mkstemp(suffix=".db") - store = MilvusStore(node_groups, store_file) + def tearDown(self): + os.remove(self.store_file) - node1 = DocNode(uid="1", text="text1", group="group1", parent=None) - node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - store.update_nodes([node1, node2]) + def test_update_nodes(self): + self.store.update_nodes(self.nodes) - n1 = store.get_node("group1", "1") - assert n1.text == node1.text + nodes = self.store.get_nodes("group1") + self.assertEqual(len(nodes), len(self.nodes)) - n2 = store.get_node("group1", "2") - assert n2.text == node2.text + counter = 0 + for node in nodes: + for expected_node in self.nodes: + if node.uid == expected_node.uid: + self.assertEqual(node.text, expected_node.text) + counter += 1 + break + self.assertEqual(counter, len(self.nodes)) - os.remove(store_file) + def test_get_node(self): + self.store.update_nodes(self.nodes) + n1 = self.store.get_node("group1", "1") + self.assertEqual(n1.text, self.node1.text) + n2 = self.store.get_node("group1", "2") + self.assertEqual(n2.text, self.node2.text) def test_remove_nodes(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - _, store_file = tempfile.mkstemp(suffix=".db") - store = MilvusStore(node_groups, store_file) - - node1 = DocNode(uid="1", text="text1", group="group1", parent=None) - node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - store.update_nodes([node1, node2]) - - n1 = store.get_node("group1", "1") - assert n1.text == node1.text - store.remove_nodes([n1]) - n1 = store.get_node("group1", "1") - assert not n1 + self.store.update_nodes(self.nodes) - n2 = store.get_node("group1", "2") - assert n2.text == node2.text - store.remove_nodes([n2]) - n2 = store.get_node("group1", "2") - assert not n2 - - os.remove(store_file) + n1 = self.store.get_node("group1", "1") + self.assertEqual(n1.text, self.node1.text) + self.store.remove_nodes([n1]) + n1 = self.store.get_node("group1", "1") + self.assertEqual(n1, None) - def test_traverse_group(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - _, store_file = tempfile.mkstemp(suffix=".db") - store = MilvusStore(node_groups, store_file) + n2 = self.store.get_node("group1", "2") + self.assertEqual(n2.text, self.node2.text) + self.store.remove_nodes([n2]) + n2 = self.store.get_node("group1", "2") + self.assertEqual(n2, None) - node1 = DocNode(uid="1", text="text1", group="group1", parent=None) - node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) - store.update_nodes([node1, node2]) - ids = set([node1.uid, node2.uid]) + def test_get_nodes(self): + self.store.update_nodes(self.nodes) + ids = set([self.node1.uid, self.node2.uid]) - docs = store.traverse_group("group1") + docs = self.store.get_nodes("group1") self.assertEqual(ids, set([doc.uid for doc in docs])) - os.remove(store_file) + def test_all_groups(self): + self.assertEqual(set(self.store.all_groups()), set(self.node_groups)) def test_group_others(self): - node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - _, store_file = tempfile.mkstemp(suffix=".db") - store = MilvusStore(node_groups, store_file) - - self.assertEqual(store.has_group("group1"), True) - self.assertEqual(store.has_group("group2"), True) - - os.remove(store_file) + self.store.update_nodes([self.node1, self.node2]) + self.assertEqual(self.store.has_node("group1"), True) + self.assertEqual(self.store.has_node("group2"), False) if __name__ == "__main__": unittest.main() diff --git a/tests/requirements.txt b/tests/requirements.txt index fe37e5e2..79e0ccc9 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,3 +2,4 @@ wikipedia docx2txt olefile pytest-rerunfailures +pymilvus