Skip to content

Commit

Permalink
s
Browse files Browse the repository at this point in the history
  • Loading branch information
ouonline committed Oct 15, 2024
1 parent 303e0c7 commit 39f3738
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
6 changes: 3 additions & 3 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _lazy_init(self) -> None:

self.store = self._get_store()
self.index = DefaultIndex(self.embed, self.store)
if not self.store.has_node(LAZY_ROOT_NAME):
if not self.store.has_nodes(LAZY_ROOT_NAME):
ids, pathes = self._list_files()
root_nodes = self._reader.load_data(pathes)
self.store.update_nodes(root_nodes)
Expand Down Expand Up @@ -179,7 +179,7 @@ def _add_files(self, input_files: List[str]):
all_groups = self.store.all_groups()
LOG.info(f"add_files: Trying to merge store with {all_groups}")
for group in all_groups:
if not self.store.has_node(group):
if not self.store.has_nodes(group):
continue
# Duplicate group will be discarded automatically
nodes = self._get_nodes(group, temp_store)
Expand Down Expand Up @@ -214,7 +214,7 @@ def gather_children(node: DocNode):
LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}")

def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None:
if store.has_node(group_name):
if store.has_nodes(group_name):
return
node_group = self.node_groups.get(group_name)
if node_group is None:
Expand Down
17 changes: 9 additions & 8 deletions lazyllm/tools/rag/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def remove_nodes(self, nodes: List[DocNode]) -> None:
raise NotImplementedError("not implemented yet.")

@abstractmethod
def has_node(self, group_name: str) -> bool:
def has_nodes(self, group_name: str) -> bool:
raise NotImplementedError("not implemented yet.")

@abstractmethod
Expand Down Expand Up @@ -221,7 +221,7 @@ def remove_nodes(self, nodes: List[DocNode]) -> None:
self._group2docs[node.group].pop(node.uid, None)

# override
def has_node(self, group_name: str) -> bool:
def has_nodes(self, group_name: str) -> bool:
docs = self._group2docs.get(group_name)
return True if docs else False

Expand Down Expand Up @@ -278,8 +278,8 @@ def remove_nodes(self, nodes: List[DocNode]) -> None:
return self._map_store.remove_nodes(nodes)

# override
def has_node(self, group_name: str) -> bool:
return self._map_store.has_node(group_name)
def has_nodes(self, group_name: str) -> bool:
return self._map_store.has_nodes(group_name)

# override
def get_nodes(self, group_name: str) -> List[DocNode]:
Expand All @@ -302,13 +302,14 @@ def _load_store(self) -> None:

# Rebuild relationships
for group_name in self._map_store.all_groups():
for node in self._map_store.get_nodes(group_name):
nodes = self._map_store.get_nodes(group_name)
for node in nodes:
if node.parent:
parent_uid = node.parent
parent_node = self._map_store.find_node_by_uid(parent_uid)
node.parent = parent_node
parent_node.children[node.group].append(node)
LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}")
LOG.debug(f"build {group} nodes from chromadb: {nodes}")
LOG.success("Successfully Built nodes from chromadb.")

def _save_nodes(self, nodes: List[DocNode]) -> None:
Expand Down Expand Up @@ -425,8 +426,8 @@ def remove_nodes(self, nodes: List[DocNode]) -> None:
self._full_data_store.remove_nodes(nodes)

# override
def has_node(self, group_name: str) -> bool:
return self._full_data_store.has_node(group_name)
def has_nodes(self, group_name: str) -> bool:
return self._full_data_store.has_nodes(group_name)

# override
def get_nodes(self, group_name: str) -> List[DocNode]:
Expand Down
12 changes: 6 additions & 6 deletions tests/basic_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def test_group_others(self):
node1 = DocNode(uid="1", text="text1", group="group1", parent=None)
node2 = DocNode(uid="2", text="text2", group="group1", parent=node1)
self.store.update_nodes([node1, node2])
self.assertEqual(self.store.has_node("group1"), True)
self.assertEqual(self.store.has_node("group2"), True)
self.assertEqual(self.store.has_nodes("group1"), True)
self.assertEqual(self.store.has_nodes("group2"), True)

class TestMapStore(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -134,8 +134,8 @@ def test_all_groups(self):

def test_group_others(self):
self.store.update_nodes([self.node1, self.node2])
self.assertEqual(self.store.has_node("group1"), True)
self.assertEqual(self.store.has_node("group2"), False)
self.assertEqual(self.store.has_nodes("group1"), True)
self.assertEqual(self.store.has_nodes("group2"), False)

class TestMilvusStore(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -210,8 +210,8 @@ def test_all_groups(self):

def test_group_others(self):
self.store.update_nodes([self.node1, self.node2])
self.assertEqual(self.store.has_node("group1"), True)
self.assertEqual(self.store.has_node("group2"), False)
self.assertEqual(self.store.has_nodes("group1"), True)
self.assertEqual(self.store.has_nodes("group2"), False)

if __name__ == "__main__":
unittest.main()

0 comments on commit 39f3738

Please sign in to comment.