Commit

store api breaking changes
ouonline committed Oct 17, 2024
1 parent e42c8b9 commit 5fefecc
Showing 8 changed files with 349 additions and 92 deletions.
19 changes: 10 additions & 9 deletions lazyllm/tools/rag/doc_impl.py
@@ -56,7 +56,7 @@ def _lazy_init(self) -> None:
         if not self.store.has_nodes(LAZY_ROOT_NAME):
             ids, pathes = self._list_files()
             root_nodes = self._reader.load_data(pathes)
-            self.store.add_nodes(root_nodes)
+            self.store.update_nodes(root_nodes)
             if self._dlm: self._dlm.update_kb_group_file_status(
                 ids, DocListManager.Status.success, group=self._kb_group_name)
             LOG.debug(f"building {LAZY_ROOT_NAME} nodes: {root_nodes}")
@@ -72,7 +72,6 @@ def _get_store(self) -> BaseStore:
             store = MapStore(node_groups=self.node_groups.keys())
         elif rag_store_type == "chroma":
             store = ChromadbStore(node_groups=self.node_groups.keys(), embed=self.embed)
-            store.try_load_store()
         else:
             raise NotImplementedError(
                 f"Not implemented store type for {rag_store_type}"
@@ -179,13 +178,15 @@ def _add_files(self, input_files: List[str]):
         self._lazy_init()
         root_nodes = self._reader.load_data(input_files)
         temp_store = self._get_store()
-        temp_store.add_nodes(root_nodes)
-        active_groups = self.store.active_groups()
-        LOG.info(f"add_files: Trying to merge store with {active_groups}")
-        for group in active_groups:
+        temp_store.update_nodes(root_nodes)
+        all_groups = self.store.all_groups()
+        LOG.info(f"add_files: Trying to merge store with {all_groups}")
+        for group in all_groups:
+            if not self.store.has_nodes(group):
+                continue
             # Duplicate group will be discarded automatically
             nodes = self._get_nodes(group, temp_store)
-            self.store.add_nodes(nodes)
+            self.store.update_nodes(nodes)
             LOG.debug(f"Merge {group} with {nodes}")

     def _delete_files(self, input_files: List[str]) -> None:
@@ -226,13 +227,13 @@ def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None:
         transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t)
         parent_nodes = self._get_nodes(node_group["parent"], store)
         nodes = transform.batch_forward(parent_nodes, group_name)
-        store.add_nodes(nodes)
+        store.update_nodes(nodes)
         LOG.debug(f"building {group_name} nodes: {nodes}")

     def _get_nodes(self, group_name: str, store: Optional[BaseStore] = None) -> List[DocNode]:
         store = store or self.store
         self._dynamic_create_nodes(group_name, store)
-        return store.traverse_nodes(group_name)
+        return store.get_nodes(group_name)

     def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_off: Union[float, Dict[str, float]],
                  index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]:
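
The doc_impl.py changes above are mechanical renames onto the reworked store interface. As a quick orientation for reviewers, here is a minimal sketch of the mapping, assuming `store` is any implementation of the new BaseStore; the group name is illustrative, not from this commit:

    from lazyllm.tools.rag.store import MapStore

    store = MapStore(node_groups=["my_group"])   # hypothetical group name
    nodes: list = []                             # stand-in for DocNodes produced by a reader
    store.update_nodes(nodes)                    # was: store.add_nodes(nodes); upserts by node.uid
    chunks = store.get_nodes("my_group")         # was: store.traverse_nodes("my_group")
    groups = [g for g in store.all_groups()      # was: store.active_groups(); callers now
              if store.has_nodes(g)]             # filter empty groups explicitly
    # store.try_load_store() is gone: ChromadbStore reloads persisted data in its own __init__.
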
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/index.py
@@ -96,7 +96,7 @@ def query(
             assert len(query) > 0, "Query should not be empty."
             query_embedding = {k: self.embed[k](query) for k in (embed_keys or self.embed.keys())}
             nodes = self._parallel_do_embedding(nodes)
-            self.store.try_save_nodes(nodes)
+            self.store.update_nodes(nodes)
             similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs)
         elif mode == "text":
             similarities = similarity_func(query, nodes, topk=topk, **kwargs)
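
With try_save_nodes removed, update_nodes is now the single write path in index.py as well: after _parallel_do_embedding fills in missing vectors at query time, the same upsert used by the add/merge paths persists them.
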
242 changes: 175 additions & 67 deletions lazyllm/tools/rag/store.py
@@ -6,6 +6,7 @@
 import chromadb
 from lazyllm import LOG, config
 from chromadb.api.models.Collection import Collection
+import pymilvus
 import threading
 import json
 import time
@@ -161,96 +162,134 @@ def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
     def to_dict(self) -> Dict:
         return dict(text=self.text, embedding=self.embedding, metadata=self.metadata)

+# ---------------------------------------------------------------------------- #
+
 class BaseStore(ABC):
-    def __init__(self, node_groups: List[str]) -> None:
-        self._store: Dict[str, Dict[str, DocNode]] = {
-            group: {} for group in node_groups
-        }
-        self._file_node_map = {}
-
-    def _add_nodes(self, nodes: List[DocNode]) -> None:
-        for node in nodes:
-            if node.group == LAZY_ROOT_NAME and "file_name" in node.metadata:
-                self._file_node_map[node.metadata["file_name"]] = node
-            self._store[node.group][node.uid] = node
-
-    def add_nodes(self, nodes: List[DocNode]) -> None:
-        self._add_nodes(nodes)
-        self.try_save_nodes(nodes)
+    @abstractmethod
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        raise NotImplementedError("not implemented yet.")

-    def has_nodes(self, group: str) -> bool:
-        return len(self._store[group]) > 0
+    @abstractmethod
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        raise NotImplementedError("not implemented yet.")

-    def get_node(self, group: str, node_id: str) -> Optional[DocNode]:
-        return self._store.get(group, {}).get(node_id)
+    @abstractmethod
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        raise NotImplementedError("not implemented yet.")

-    def traverse_nodes(self, group: str) -> List[DocNode]:
-        return list(self._store.get(group, {}).values())
+    @abstractmethod
+    def remove_nodes(self, nodes: List[DocNode]) -> None:
+        raise NotImplementedError("not implemented yet.")

     @abstractmethod
-    def try_save_nodes(self, nodes: List[DocNode]) -> None:
-        # try save nodes to persistent source
-        raise NotImplementedError("Not implemented yet.")
+    def has_nodes(self, group_name: str) -> bool:
+        raise NotImplementedError("not implemented yet.")

     @abstractmethod
-    def try_load_store(self) -> None:
-        # try load nodes from persistent source
-        raise NotImplementedError("Not implemented yet.")
+    def all_groups(self) -> List[str]:
+        raise NotImplementedError("not implemented yet.")

     @abstractmethod
-    def try_remove_nodes(self, nodes: List[DocNode]) -> None:
-        # try remove nodes in persistent source
-        raise NotImplementedError("Not implemented yet.")
+    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
+        raise NotImplementedError("not implemented yet.")

-    def active_groups(self) -> List:
-        return [group for group, nodes in self._store.items() if nodes]
+# ---------------------------------------------------------------------------- #

-    def _remove_nodes(self, nodes: List[DocNode]) -> None:
+class MapStore(BaseStore):
+    def __init__(self, node_groups: List[str]):
+        # Dict[group_name, Dict[uuid, DocNode]]
+        self._group2docs: Dict[str, Dict[str, DocNode]] = {
+            group: {} for group in node_groups
+        }
+        self._file_node_map = {}
+
+    # override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
         for node in nodes:
-            assert node.group in self._store, f"Unexpected node group {node.group}"
-            self._store[node.group].pop(node.uid, None)
+            if node.group == LAZY_ROOT_NAME and "file_name" in node.metadata:
+                self._file_node_map[node.metadata["file_name"]] = node
+            self._group2docs[node.group][node.uid] = node

+    # override
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        return self._group2docs.get(group_name, {}).get(node_id)
+
+    # override
     def remove_nodes(self, nodes: List[DocNode]) -> None:
-        self._remove_nodes(nodes)
-        self.try_remove_nodes(nodes)
+        for node in nodes:
+            assert node.group in self._group2docs, f"Unexpected node group {node.group}"
+            self._group2docs[node.group].pop(node.uid, None)
+
+    # override
+    def has_nodes(self, group_name: str) -> bool:
+        docs = self._group2docs.get(group_name)
+        return True if docs else False
+
+    # override
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        return list(self._group2docs.get(group_name, {}).values())
+
+    # override
+    def all_groups(self) -> List[str]:
+        return [group for group, nodes in self._group2docs.items()]
+
+    # override
+    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
+        nodes = []
+        for file in files:
+            if file in self._file_node_map:
+                nodes.append(self._file_node_map[file])
+        return nodes
+
+    def find_node_by_uid(self, uid: str) -> Optional[DocNode]:
+        for docs in self._group2docs.values():
+            doc = docs.get(uid)
+            if doc:
+                return doc
+        return None

-class MapStore(BaseStore):
-    def __init__(self, node_groups: List[str], *args, **kwargs):
-        super().__init__(node_groups, *args, **kwargs)
-
-    def try_save_nodes(self, nodes: List[DocNode]) -> None:
-        pass
-
-    def try_load_store(self) -> None:
-        pass
-
-    def try_remove_nodes(self, nodes: List[DocNode]) -> None:
-        pass
+# ---------------------------------------------------------------------------- #

 class ChromadbStore(BaseStore):
     def __init__(
-        self, node_groups: List[str], embed: Dict[str, Callable], *args, **kwargs
+        self, node_groups: List[str], embed: Dict[str, Callable]
     ) -> None:
-        super().__init__(node_groups, *args, **kwargs)
+        self._map_store = MapStore(node_groups)
         self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"])
         LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}")
         self._collections: Dict[str, Collection] = {
             group: self._db_client.get_or_create_collection(group)
             for group in node_groups
         }
         self._placeholder = {k: [-1] * len(e("a")) for k, e in embed.items()} if embed else {EMBED_DEFAULT_KEY: []}
+        self._load_store()
+
+    # override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        self._map_store.update_nodes(nodes)
+        self._save_nodes(nodes)
+
+    # override
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        return self._map_store.get_node(group_name, node_id)

-    def try_load_store(self) -> None:
+    # override
+    def remove_nodes(self, nodes: List[DocNode]) -> None:
+        return self._map_store.remove_nodes(nodes)
+
+    # override
+    def has_nodes(self, group_name: str) -> bool:
+        return self._map_store.has_nodes(group_name)
+
+    # override
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        return self._map_store.get_nodes(group_name)
+
+    # override
+    def all_groups(self) -> List[str]:
+        return self._map_store.all_groups()
+
+    def _load_store(self) -> None:
         if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]:
             LOG.info("No persistent data found, skip the rebuilding phrase.")
             return
@@ -259,20 +298,21 @@ def try_load_store(self) -> None:
         for group in self._collections.keys():
             results = self._peek_all_documents(group)
             nodes = self._build_nodes_from_chroma(results)
-            self._add_nodes(nodes)
+            self._map_store.update_nodes(nodes)

         # Rebuild relationships
-        for group, nodes_dict in self._store.items():
-            for node in nodes_dict.values():
+        for group_name in self._map_store.all_groups():
+            nodes = self._map_store.get_nodes(group_name)
+            for node in nodes:
                 if node.parent:
                     parent_uid = node.parent
-                    parent_node = self._find_node_by_uid(parent_uid)
+                    parent_node = self._map_store.find_node_by_uid(parent_uid)
                     node.parent = parent_node
                     parent_node.children[node.group].append(node)
-            LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}")
+            LOG.debug(f"build {group} nodes from chromadb: {nodes}")
         LOG.success("Successfully Built nodes from chromadb.")

-    def try_save_nodes(self, nodes: List[DocNode]) -> None:
+    def _save_nodes(self, nodes: List[DocNode]) -> None:
         if not nodes:
             return
         # Note: It's caller's duty to make sure this batch of nodes has the same group.
@@ -304,14 +344,9 @@ def try_save_nodes(self, nodes: List[DocNode]) -> None:
             )
             LOG.debug(f"Saved {group} nodes {ids} to chromadb.")

-    def try_remove_nodes(self, nodes: List[DocNode]) -> None:
-        pass
-
-    def _find_node_by_uid(self, uid: str) -> Optional[DocNode]:
-        for nodes_by_category in self._store.values():
-            if uid in nodes_by_category:
-                return nodes_by_category[uid]
-        raise ValueError(f"UID {uid} not found in store.")
+    # override
+    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
+        return self._map_store.get_nodes_by_files(files)

     def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]:
         nodes: List[DocNode] = []
@@ -339,3 +374,76 @@ def _peek_all_documents(self, group: str) -> Dict[str, List]:
         assert group in self._collections, f"group {group} not found."
         collection = self._collections[group]
         return collection.peek(collection.count())
+
+# ---------------------------------------------------------------------------- #
+
+class MilvusEmbeddingIndexField:
+    def __init__(self, name: str = "", dim: int = 0, type: int = pymilvus.DataType.FLOAT_VECTOR):
+        self.name = name
+        self.dim: int = dim
+        self.type: int = type
+
+class MilvusStore(BaseStore):
+    def __init__(self, node_groups: List[str], uri: str,
+                 embedding_index_info: List[MilvusEmbeddingIndexField],  # fields to be indexed by Milvus
+                 full_data_store: BaseStore):
+        self._full_data_store = full_data_store
+
+        self._primary_key = 'uid'
+        self._client = pymilvus.MilvusClient(uri=uri)
+
+        schema = self._client.create_schema(auto_id=False, enable_dynamic_field=True)
+        schema.add_field(
+            field_name=self._primary_key,
+            datatype=pymilvus.DataType.VARCHAR,
+            max_length=128,
+            is_primary=True,
+        )
+        for field in embedding_index_info:
+            schema.add_field(
+                field_name=field.name,
+                datatype=field.type,
+                dim=field.dim)
+
+        for group in node_groups:
+            if group not in self._client.list_collections():
+                self._client.create_collection(collection_name=group, schema=schema)
+
+    # override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        self._save_nodes(nodes)
+        self._full_data_store.update_nodes(nodes)
+
+    # override
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        return self._full_data_store.get_node(group_name, node_id)
+
+    # override
+    def remove_nodes(self, nodes: List[DocNode]) -> None:
+        for node in nodes:
+            self._client.delete(collection_name=node.group,
+                                filter=f'{self._primary_key} in ["{node.uid}"]')
+        self._full_data_store.remove_nodes(nodes)
+
+    # override
+    def has_nodes(self, group_name: str) -> bool:
+        return self._full_data_store.has_nodes(group_name)
+
+    # override
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        return self._full_data_store.get_nodes(group_name)
+
+    # override
+    def all_groups(self) -> List[str]:
+        return self._full_data_store.all_groups()
+
+    # override
+    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
+        return self._full_data_store.get_nodes_by_files(files)
+
+    def _save_nodes(self, nodes: List[DocNode]) -> None:
+        for node in nodes:
+            if node.embedding:
+                data = node.embedding.copy()
+                data[self._primary_key] = node.uid
+                self._client.upsert(collection_name=node.group, data=data)
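
For reviewers who want to try the new Milvus-backed store, here is a minimal wiring sketch based on the constructor above; the group name, URI, field name, and dimension are hypothetical, and local-mode Milvus Lite is assumed:

    from lazyllm.tools.rag.store import MapStore, MilvusEmbeddingIndexField, MilvusStore

    groups = ["my_group"]                          # hypothetical node groups
    full_store = MapStore(node_groups=groups)      # holds the complete DocNode data
    store = MilvusStore(
        node_groups=groups,
        uri="milvus_demo.db",                      # hypothetical local file (Milvus Lite)
        embedding_index_info=[MilvusEmbeddingIndexField(name="vec", dim=768)],
        full_data_store=full_store,
    )
    # update_nodes() upserts each node's embedding into Milvus keyed by uid, while
    # reads (get_node/get_nodes/all_groups/...) are delegated to full_data_store.

Note that _save_nodes only upserts nodes that already carry an embedding, and the indexed field names ("vec" above) must match the keys present in node.embedding.
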
1 change: 1 addition & 0 deletions requirements.full.txt
@@ -17,6 +17,7 @@ json5
 tiktoken
 spacy<=3.7.5
 chromadb
+pymilvus
 bm25s
 pystemmer
 nltk
(Diffs for the remaining changed files did not load in this view.)
