From 39f3738b31325b10722dcbae5912cf02ed85e5be Mon Sep 17 00:00:00 2001 From: ouguoyu Date: Tue, 15 Oct 2024 15:34:58 +0800 Subject: [PATCH] s --- lazyllm/tools/rag/doc_impl.py | 6 +++--- lazyllm/tools/rag/store.py | 17 +++++++++-------- tests/basic_tests/test_store.py | 12 ++++++------ 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index c1b29d08..1de968c5 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -53,7 +53,7 @@ def _lazy_init(self) -> None: self.store = self._get_store() self.index = DefaultIndex(self.embed, self.store) - if not self.store.has_node(LAZY_ROOT_NAME): + if not self.store.has_nodes(LAZY_ROOT_NAME): ids, pathes = self._list_files() root_nodes = self._reader.load_data(pathes) self.store.update_nodes(root_nodes) @@ -179,7 +179,7 @@ def _add_files(self, input_files: List[str]): all_groups = self.store.all_groups() LOG.info(f"add_files: Trying to merge store with {all_groups}") for group in all_groups: - if not self.store.has_node(group): + if not self.store.has_nodes(group): continue # Duplicate group will be discarded automatically nodes = self._get_nodes(group, temp_store) @@ -214,7 +214,7 @@ def gather_children(node: DocNode): LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}") def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None: - if store.has_node(group_name): + if store.has_nodes(group_name): return node_group = self.node_groups.get(group_name) if node_group is None: diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py index 239a0ba2..57cee198 100644 --- a/lazyllm/tools/rag/store.py +++ b/lazyllm/tools/rag/store.py @@ -182,7 +182,7 @@ def remove_nodes(self, nodes: List[DocNode]) -> None: raise NotImplementedError("not implemented yet.") @abstractmethod - def has_node(self, group_name: str) -> bool: + def has_nodes(self, group_name: str) -> bool: raise NotImplementedError("not implemented yet.") @abstractmethod @@ -221,7 +221,7 @@ def remove_nodes(self, nodes: List[DocNode]) -> None: self._group2docs[node.group].pop(node.uid, None) # override - def has_node(self, group_name: str) -> bool: + def has_nodes(self, group_name: str) -> bool: docs = self._group2docs.get(group_name) return True if docs else False @@ -278,8 +278,8 @@ def remove_nodes(self, nodes: List[DocNode]) -> None: return self._map_store.remove_nodes(nodes) # override - def has_node(self, group_name: str) -> bool: - return self._map_store.has_node(group_name) + def has_nodes(self, group_name: str) -> bool: + return self._map_store.has_nodes(group_name) # override def get_nodes(self, group_name: str) -> List[DocNode]: @@ -302,13 +302,14 @@ def _load_store(self) -> None: # Rebuild relationships for group_name in self._map_store.all_groups(): - for node in self._map_store.get_nodes(group_name): + nodes = self._map_store.get_nodes(group_name) + for node in nodes: if node.parent: parent_uid = node.parent parent_node = self._map_store.find_node_by_uid(parent_uid) node.parent = parent_node parent_node.children[node.group].append(node) - LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}") + LOG.debug(f"build {group} nodes from chromadb: {nodes}") LOG.success("Successfully Built nodes from chromadb.") def _save_nodes(self, nodes: List[DocNode]) -> None: @@ -425,8 +426,8 @@ def remove_nodes(self, nodes: List[DocNode]) -> None: self._full_data_store.remove_nodes(nodes) # override - def has_node(self, group_name: str) -> bool: - return self._full_data_store.has_node(group_name) + def has_nodes(self, group_name: str) -> bool: + return self._full_data_store.has_nodes(group_name) # override def get_nodes(self, group_name: str) -> List[DocNode]: diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index 20d8a198..facaa39b 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -82,8 +82,8 @@ def test_group_others(self): node1 = DocNode(uid="1", text="text1", group="group1", parent=None) node2 = DocNode(uid="2", text="text2", group="group1", parent=node1) self.store.update_nodes([node1, node2]) - self.assertEqual(self.store.has_node("group1"), True) - self.assertEqual(self.store.has_node("group2"), True) + self.assertEqual(self.store.has_nodes("group1"), True) + self.assertEqual(self.store.has_nodes("group2"), True) class TestMapStore(unittest.TestCase): def setUp(self): @@ -134,8 +134,8 @@ def test_all_groups(self): def test_group_others(self): self.store.update_nodes([self.node1, self.node2]) - self.assertEqual(self.store.has_node("group1"), True) - self.assertEqual(self.store.has_node("group2"), False) + self.assertEqual(self.store.has_nodes("group1"), True) + self.assertEqual(self.store.has_nodes("group2"), False) class TestMilvusStore(unittest.TestCase): def setUp(self): @@ -210,8 +210,8 @@ def test_all_groups(self): def test_group_others(self): self.store.update_nodes([self.node1, self.node2]) - self.assertEqual(self.store.has_node("group1"), True) - self.assertEqual(self.store.has_node("group2"), False) + self.assertEqual(self.store.has_nodes("group1"), True) + self.assertEqual(self.store.has_nodes("group2"), False) if __name__ == "__main__": unittest.main()