Commit 303e0c7

ouonline committed Oct 15, 2024
1 parent 3ae16ed
Showing 5 changed files with 188 additions and 239 deletions.
14 changes: 8 additions & 6 deletions lazyllm/tools/rag/doc_impl.py
@@ -53,7 +53,7 @@ def _lazy_init(self) -> None:
 
         self.store = self._get_store()
         self.index = DefaultIndex(self.embed, self.store)
-        if not self.store.has_group(LAZY_ROOT_NAME):
+        if not self.store.has_node(LAZY_ROOT_NAME):
             ids, pathes = self._list_files()
             root_nodes = self._reader.load_data(pathes)
             self.store.update_nodes(root_nodes)
@@ -176,9 +176,11 @@ def _add_files(self, input_files: List[str]):
         root_nodes = self._reader.load_data(input_files)
         temp_store = self._get_store()
         temp_store.update_nodes(root_nodes)
-        active_groups = self.store.active_groups()
-        LOG.info(f"add_files: Trying to merge store with {active_groups}")
-        for group in active_groups:
+        all_groups = self.store.all_groups()
+        LOG.info(f"add_files: Trying to merge store with {all_groups}")
+        for group in all_groups:
+            if not self.store.has_node(group):
+                continue
             # Duplicate group will be discarded automatically
             nodes = self._get_nodes(group, temp_store)
             self.store.update_nodes(nodes)
@@ -212,7 +214,7 @@ def gather_children(node: DocNode):
             LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}")
 
     def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None:
-        if store.has_group(group_name):
+        if store.has_node(group_name):
             return
         node_group = self.node_groups.get(group_name)
         if node_group is None:
@@ -228,7 +230,7 @@ def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None:
     def _get_nodes(self, group_name: str, store: Optional[BaseStore] = None) -> List[DocNode]:
         store = store or self.store
         self._dynamic_create_nodes(group_name, store)
-        return store.traverse_nodes(group_name)
+        return store.get_nodes(group_name)
 
     def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_off: Union[float, Dict[str, float]],
                  index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]:
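In doc_impl.py, the group-existence check has_group() is replaced by the node-presence check has_node(), and _add_files() now iterates all_groups() (active_groups() is gone), so groups without materialized nodes must be skipped explicitly. A minimal sketch of the resulting merge flow; merge_new_files and its two store arguments are illustrative names, not from the commit:

# Sketch of the _add_files merge loop under the new store API.
def merge_new_files(main_store, temp_store):
    for group in main_store.all_groups():
        # all_groups() lists every registered group, including ones with
        # no materialized nodes, hence the explicit has_node() guard.
        if not main_store.has_node(group):
            continue
        # duplicate nodes in a group are discarded by update_nodes()
        main_store.update_nodes(temp_store.get_nodes(group))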
181 changes: 68 additions & 113 deletions lazyllm/tools/rag/store.py
@@ -174,18 +174,20 @@ def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
         raise NotImplementedError("not implemented yet.")
 
     @abstractmethod
-    def remove_nodes(self, nodes: List[DocNode]) -> None:
+    def get_nodes(self, group_name: str) -> List[DocNode]:
         raise NotImplementedError("not implemented yet.")
 
     @abstractmethod
-    def has_group(self, group_name: str) -> bool:
+    def remove_nodes(self, nodes: List[DocNode]) -> None:
         raise NotImplementedError("not implemented yet.")
 
     @abstractmethod
-    def traverse_group(self, group_name: str) -> List[DocNode]:
+    def has_node(self, group_name: str) -> bool:
         raise NotImplementedError("not implemented yet.")
 
-    # XXX NOTE the following APIs should be private.
     @abstractmethod
+    def all_groups(self) -> List[str]:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
     def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
@@ -219,24 +221,17 @@ def remove_nodes(self, nodes: List[DocNode]) -> None:
                 self._group2docs[node.group].pop(node.uid, None)
 
     # override
-    def has_group(self, group_name: str) -> bool:
-        return group_name in self._group2docs
+    def has_node(self, group_name: str) -> bool:
+        docs = self._group2docs.get(group_name)
+        return True if docs else False
 
     # override
-    def traverse_group(self, group_name: str) -> List[DocNode]:
+    def get_nodes(self, group_name: str) -> List[DocNode]:
         return list(self._group2docs.get(group_name, {}).values())
 
-    def get_group_docs(self) -> Dict[str, Dict[str, DocNode]]:
-        return self._group2docs
-
-    def find_node_by_uid(self, uid: str) -> Optional[DocNode]:
-        for docs in self._group2docs.values():
-            doc = docs.get(uid)
-            if doc:
-                return doc
-        return None
-
-    # XXX NOTE the following APIs should be private.
+    # override
+    def all_groups(self) -> List[str]:
+        return [group for group, nodes in self._group2docs.items()]
 
     # override
     def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
@@ -246,6 +241,13 @@ def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
                 nodes.append(self._file_node_map[file])
         return nodes
 
+    def find_node_by_uid(self, uid: str) -> Optional[DocNode]:
+        for docs in self._group2docs.values():
+            doc = docs.get(uid)
+            if doc:
+                return doc
+        return None
+
 # ---------------------------------------------------------------------------- #
 
 class ChromadbStore(BaseStore):
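The MapStore changes also shift semantics: the old has_group() tested whether a group key was registered, while the new has_node() tests whether the group actually contains nodes. A quick illustration, assuming the constructor pre-registers each group with an empty dict (as the old has_group() implied):

from lazyllm.tools.rag.store import MapStore

store = MapStore(["root", "chunks"])
print(store.all_groups())      # ['root', 'chunks'] -- every registered group
print(store.has_node("root"))  # False: group is registered but holds no nodes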
@@ -276,12 +278,16 @@ def remove_nodes(self, nodes: List[DocNode]) -> None:
         return self._map_store.remove_nodes(nodes)
 
     # override
-    def has_group(self, group_name: str) -> bool:
-        return self._map_store.has_group(group_name)
+    def has_node(self, group_name: str) -> bool:
+        return self._map_store.has_node(group_name)
 
     # override
-    def traverse_group(self, group_name: str) -> List[DocNode]:
-        return self._map_store.traverse_group(group_name)
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        return self._map_store.get_nodes(group_name)
+
+    # override
+    def all_groups(self) -> List[str]:
+        return self._map_store.all_groups()
 
     def _load_store(self) -> None:
         if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]:
@@ -295,9 +301,8 @@ def _load_store(self) -> None:
             self._map_store.update_nodes(nodes)
 
         # Rebuild relationships
-        group2docs = self._map_store.get_group_docs()
-        for group, nodes_dict in group2docs.items():
-            for node in nodes_dict.values():
+        for group_name in self._map_store.all_groups():
+            for node in self._map_store.get_nodes(group_name):
                 if node.parent:
                     parent_uid = node.parent
                     parent_node = self._map_store.find_node_by_uid(parent_uid)
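The relationship rebuild in ChromadbStore._load_store() now walks the public all_groups()/get_nodes() pair instead of reaching into the removed get_group_docs() dict. Condensed, the idea is (a sketch; store stands in for self._map_store, and the rest of the hunk is truncated in this view):

# Sketch: after loading, each node's `parent` field still holds the
# parent's uid; resolve it to the actual DocNode via the store API.
for group_name in store.all_groups():
    for node in store.get_nodes(group_name):
        if node.parent:
            node.parent = store.find_node_by_uid(node.parent)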
@@ -338,7 +343,6 @@ def _save_nodes(self, nodes: List[DocNode]) -> None:
             )
             LOG.debug(f"Saved {group} nodes {ids} to chromadb.")
 
-    # XXX NOTE should be private.
     # override
     def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
         return self._map_store.get_nodes_by_files(files)
@@ -372,122 +376,73 @@ def _peek_all_documents(self, group: str) -> Dict[str, List]:
 
 # ---------------------------------------------------------------------------- #
 
+class MilvusEmbeddingIndexField:
+    def __init__(self, name: str = "", dim: int = 0, type: int = pymilvus.DataType.FLOAT_VECTOR):
+        self.name = name
+        self.dim: int = dim
+        self.type: int = type
+
 class MilvusStore(BaseStore):
-    def __init__(self, node_groups: List[str], uri: str):
+    def __init__(self, node_groups: List[str], uri: str,
+                 embedding_index_info: List[MilvusEmbeddingIndexField],  # fields to be indexed by Milvus
+                 full_data_store: BaseStore):
+        self._full_data_store = full_data_store
+
+        self._primary_key = 'uid'
         self._client = pymilvus.MilvusClient(uri=uri)
 
         schema = self._client.create_schema(auto_id=False, enable_dynamic_field=True)
         schema.add_field(
-            field_name='uid',
+            field_name=self._primary_key,
             datatype=pymilvus.DataType.VARCHAR,
-            max_length=65535,
+            max_length=128,
             is_primary=True,
         )
-        schema.add_field(
-            field_name='text',
-            datatype=pymilvus.DataType.VARCHAR,
-            max_length=65535,
-        )
-        schema.add_field(
-            field_name='group',
-            datatype=pymilvus.DataType.VARCHAR,
-            max_length=1024,
-        )
-        schema.add_field(
-            field_name='embedding_json',
-            datatype=pymilvus.DataType.VARCHAR,
-            max_length=65535,
-        )
-        schema.add_field(
-            field_name='parent_uid',
-            datatype=pymilvus.DataType.VARCHAR,
-            max_length=1024,
-        )
-        schema.add_field(
-            field_name='metadata_json',
-            datatype=pymilvus.DataType.VARCHAR,
-            max_length=65535,
-        )
+        for field in embedding_index_info:
+            schema.add_field(
+                field_name=field.name,
+                datatype=field.type,
+                dim=field.dim)
 
         for group in node_groups:
             if group not in self._client.list_collections():
                 self._client.create_collection(collection_name=group, schema=schema)
 
-        self._map_store = MapStore(node_groups)
-        self._load_store()
-
     # override
     def update_nodes(self, nodes: List[DocNode]) -> None:
-        self._map_store.update_nodes(nodes)
         self._save_nodes(nodes)
+        self._full_data_store.update_nodes(nodes)
 
     # override
     def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
-        return self._map_store.get_node(group_name, node_id)
+        return self._full_data_store.get_node(group_name, node_id)
 
     # override
     def remove_nodes(self, nodes: List[DocNode]) -> None:
         for node in nodes:
-            self._client.delete(collection_name=node.group, filter=f'uid in ["{node.uid}"]')
-        self._map_store.remove_nodes(nodes)
+            self._client.delete(collection_name=node.group,
+                                filter=f'{self._primary_key} in ["{node.uid}"]')
+        self._full_data_store.remove_nodes(nodes)
 
     # override
-    def has_group(self, group_name: str) -> bool:
-        return self._map_store.has_group(group_name)
+    def has_node(self, group_name: str) -> bool:
+        return self._full_data_store.has_node(group_name)
 
     # override
-    def traverse_group(self, group_name: str) -> List[DocNode]:
-        return self._map_store.traverse_group(group_name)
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        return self._full_data_store.get_nodes(group_name)
 
-    def _load_store(self) -> None:
-        groups = self._client.list_collections()
-        for group in groups:
-            results = self._client.query(collection_name=group,
-                                         output_fields=["count(*)"],
-                                         limit=1)
-            if len(results) != 1:
-                raise ValueError(f"query count(*) of collection [{group}] failed.")
-
-            count = int(results[0]['count(*)'])
-            if count == 0:
-                continue
+    # override
+    def all_groups(self) -> List[str]:
+        return self._full_data_store.all_groups()
 
-            results = self._client.query(collection_name=group,
-                                         query_expression="uid in []",
-                                         output_fields=["*"],
-                                         limit=count)
-            for record in results:
-                doc_node = DocNode(
-                    uid=record['uid'],
-                    text=record['text'],
-                    group=record['group'],
-                    embedding=json.loads(record['embedding_json']),
-                    parent=record['parent_uid'],  # NOTE: will be updated later
-                    metadata=json.loads(record['metadata_json']),
-                )
-                self._map_store.update_nodes([doc_node])
-
-        # update doc's parent
-        group_docs = self._map_store.get_group_docs()
-        for group, docs in group_docs.items():
-            for uid, doc in docs.items():
-                # before find_node_by_uid() `doc.parent` is the parent's uid
-                doc.parent = self._map_store.find_node_by_uid(doc.parent)
+    # override
+    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
+        return self._full_data_store.get_nodes_by_files(files)
 
     def _save_nodes(self, nodes: List[DocNode]) -> None:
         for node in nodes:
-            data = {
-                'uid': node.uid,
-                'text': node.text,
-                'group': node.group,
-                'embedding_json': json.dumps(node.embedding),
-                'parent_uid': node.parent.uid if node.parent else '',
-                'metadata_json': json.dumps(node._metadata)
-            }
-            self._client.upsert(collection_name=node.group, data=data)
-
-    # XXX NOTE the following APIs should be private.
-
-    # override
-    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
-        return self._map_store.get_nodes_by_files(files)
+            if node.embedding:
+                data = node.embedding.copy()
+                data[self._primary_key] = node.uid
+                self._client.upsert(collection_name=node.group, data=data)
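The MilvusStore rewrite splits storage in two: the Milvus collections now hold only the primary key plus the vector fields declared through MilvusEmbeddingIndexField, while everything else (text, metadata, parent/child links) is delegated to the injected full_data_store. Consequently _save_nodes() upserts node.embedding directly, so each embedding dict's keys must match the declared field names. A hypothetical wiring; the group names, field name, dimension, and local uri below are illustrative, not from the commit:

import pymilvus
from lazyllm.tools.rag.store import (MapStore, MilvusStore,
                                     MilvusEmbeddingIndexField)

groups = ["root", "chunks"]

# Full node data lives in a MapStore; Milvus indexes one 512-dim
# dense vector per node under the field name "dense".
store = MilvusStore(
    node_groups=groups,
    uri="./milvus_demo.db",  # local Milvus Lite file; a server uri also works
    embedding_index_info=[
        MilvusEmbeddingIndexField(name="dense", dim=512,
                                  type=pymilvus.DataType.FLOAT_VECTOR),
    ],
    full_data_store=MapStore(groups),
)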
6 changes: 3 additions & 3 deletions tests/basic_tests/test_document.py
@@ -59,16 +59,16 @@ def test_retrieve(self):
     def test_add_files(self):
         assert self.doc_impl.store is None
         self.doc_impl._lazy_init()
-        assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 1
+        assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 1
         new_doc = DocNode(text="new dummy text", group=LAZY_ROOT_NAME)
         new_doc.metadata = {"file_name": "new_file.txt"}
         self.mock_directory_reader.load_data.return_value = [new_doc]
         self.doc_impl._add_files(["new_file.txt"])
-        assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 2
+        assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 2
 
     def test_delete_files(self):
         self.doc_impl._delete_files(["dummy_file.txt"])
-        assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 0
+        assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 0
 
 
 class TestDocument(unittest.TestCase):
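The test updates only track the rename (traverse_nodes becomes get_nodes). Assuming the test directories are importable packages, the updated file can be run with `python -m unittest tests.basic_tests.test_document`.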