Skip to content

Commit

Permalink
refactor index apis
Browse files Browse the repository at this point in the history
  • Loading branch information
ouonline committed Oct 17, 2024
1 parent 5fefecc commit 323f0ba
Show file tree
Hide file tree
Showing 32 changed files with 533 additions and 402 deletions.
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .rerank import Reranker, register_reranker
from .transform import SentenceSplitter, LLMParser, NodeTransform, TransformArgs, AdaptiveTransform
from .index import register_similarity
from .store import DocNode
from .doc_node import DocNode
from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader,
MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader)
from .dataReader import SimpleDirectoryReader
Expand Down
16 changes: 16 additions & 0 deletions lazyllm/tools/rag/base_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from .doc_node import DocNode
from abc import ABC, abstractmethod
from typing import List

class BaseIndex(ABC):
    """Abstract interface for an index kept in sync with a node store.

    Concrete indices ingest nodes via `update`, drop them via `remove`,
    and answer lookups via `query`.
    """

    @abstractmethod
    def update(self, nodes: List[DocNode]) -> None:
        """Insert or refresh the given nodes in the index.

        Note: the original declaration omitted `self`, which would have
        bound the nodes list as the instance; fixed here.
        """
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def remove(self, uids: List[str]) -> None:
        """Remove the nodes identified by `uids` from the index."""
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def query(self, *args, **kwargs) -> List[DocNode]:
        """Look up nodes; argument semantics are implementation-specific."""
        raise NotImplementedError("not implemented yet.")
53 changes: 53 additions & 0 deletions lazyllm/tools/rag/base_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from typing import Optional, List, Dict
from .doc_node import DocNode
from .base_index import BaseIndex

class BaseStore(ABC):
    """Abstract storage of DocNodes organized into named groups.

    A store also owns a registry of named BaseIndex instances that must be
    kept in sync with stored nodes; the static helpers at the bottom fan
    updates/removals out to every registered index.
    """

    @abstractmethod
    def update_nodes(self, nodes: List[DocNode]) -> None:
        """Insert or refresh the given nodes."""
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
        """Return the node with `node_id` from `group_name`, or None."""
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def get_nodes(self, group_name: str) -> List[DocNode]:
        """Return all nodes belonging to `group_name`."""
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def remove_nodes(self, uids: List[str]) -> None:
        """Remove the nodes identified by `uids`."""
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def has_nodes(self, group_name: str) -> bool:
        """Whether `group_name` contains at least one node."""
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def all_groups(self) -> List[str]:
        """Return every known group name."""
        raise NotImplementedError("not implemented yet.")

    # NOTE(review): `type` shadows the builtin, but it is part of the public
    # keyword interface (callers use `register_index(type='default', ...)`),
    # so the parameter name must stay.
    @abstractmethod
    def register_index(self, type: str, index: BaseIndex) -> None:
        """Register `index` under the name `type`."""
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def remove_index(self, type: str) -> None:
        """Unregister the index named `type`."""
        raise NotImplementedError("not implemented yet.")

    @abstractmethod
    def get_index(self, type: str) -> Optional[BaseIndex]:
        """Return the index named `type`, or None if not registered."""
        raise NotImplementedError("not implemented yet.")

    # ----- helper functions ----- #

    @staticmethod
    def _update_indices(name2index: Dict[str, BaseIndex], nodes: List[DocNode]) -> None:
        """Push `nodes` into every registered index."""
        # Keys are unused: iterate the values directly.
        for index in name2index.values():
            index.update(nodes)

    @staticmethod
    def _remove_from_indices(name2index: Dict[str, BaseIndex], uids: List[str]) -> None:
        """Remove `uids` from every registered index."""
        for index in name2index.values():
            index.remove(uids)
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/component/bm25.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Tuple
from ..store import DocNode
from ..doc_node import DocNode
import bm25s
import Stemmer
from lazyllm.thirdparty import jieba
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/dataReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pathlib import Path, PurePosixPath, PurePath
from fsspec import AbstractFileSystem
from lazyllm import ModuleBase, LOG
from .store import DocNode
from .doc_node import DocNode
from .readers import (ReaderBase, PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader,
EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader,
get_default_fs, is_default_fs)
Expand Down
3 changes: 2 additions & 1 deletion lazyllm/tools/rag/data_loaders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional, Dict
from .store import DocNode, LAZY_ROOT_NAME
from .doc_node import DocNode
from .store import LAZY_ROOT_NAME
from lazyllm import LOG
from .dataReader import SimpleDirectoryReader

Expand Down
48 changes: 41 additions & 7 deletions lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,38 @@
from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, BaseStore
from .data_loaders import DirectoryReader
from .index import DefaultIndex
from .base_index import BaseIndex
from .utils import DocListManager
import threading
import time

_transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser)

class FileNodeIndex(BaseIndex):
    """Index mapping a source file name to its root DocNode.

    Only nodes in the LAZY_ROOT_NAME group are recorded, keyed by their
    'file_name' metadata entry.
    """

    def __init__(self):
        self._file_node_map = {}

    # override
    def update(self, nodes: List[DocNode]) -> None:
        """Record root-group nodes keyed by their 'file_name' metadata."""
        for node in nodes:
            if node.group != LAZY_ROOT_NAME:
                continue
            file_name = node.metadata.get("file_name")
            if file_name:
                self._file_node_map[file_name] = node

    # override
    def remove(self, uids: List[str]) -> None:
        """Drop every mapping whose node uid is in `uids`."""
        self._file_node_map = {k: v for k, v in self._file_node_map.items() if v.uid not in uids}

    # override
    def query(self, files: List[str]) -> List[DocNode]:
        """Return the root nodes for `files`, skipping unknown files.

        Fix: the previous version appended `dict.get` results unconditionally,
        so unknown files produced `None` placeholders — making the result
        non-empty and crashing callers that recurse over the nodes
        (e.g. the file-deletion path).
        """
        return [self._file_node_map[f] for f in files if f in self._file_node_map]


def embed_wrapper(func):
if not func:
Expand Down Expand Up @@ -44,6 +70,12 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N
self.embed = {k: embed_wrapper(e) for k, e in embed.items()}
self.store = None

def _create_file_node_index(self, store) -> FileNodeIndex:
    """Build a FileNodeIndex populated from every group already in `store`."""
    file_index = FileNodeIndex()
    for group_name in store.all_groups():
        file_index.update(store.get_nodes(group_name))
    return file_index

@once_wrapper(reset_on_pickle=True)
def _lazy_init(self) -> None:
node_groups = DocImpl._builtin_node_groups.copy()
Expand All @@ -52,7 +84,9 @@ def _lazy_init(self) -> None:
self.node_groups = node_groups

self.store = self._get_store()
self.index = DefaultIndex(self.embed, self.store)
self.store.register_index(type='default', index=DefaultIndex(self.embed, self.store))
self.store.register_index(type='file_node_map', index=self._create_file_node_index(self.store))

if not self.store.has_nodes(LAZY_ROOT_NAME):
ids, pathes = self._list_files()
root_nodes = self._reader.load_data(pathes)
Expand Down Expand Up @@ -191,28 +225,28 @@ def _add_files(self, input_files: List[str]):

def _delete_files(self, input_files: List[str]) -> None:
    """Remove the documents for `input_files` together with all derived nodes.

    Root nodes are looked up through the 'file_node_map' index; if none of
    the files are known, nothing is deleted.
    (Fix: the pasted span contained diff residue — both the removed
    `get_nodes_by_files` call and its replacement; this is the coherent
    post-refactor version.)
    """
    self._lazy_init()
    docs = self.store.get_index('file_node_map').query(input_files)
    LOG.info(f"delete_files: removing documents {input_files} and nodes {docs}")
    if len(docs) == 0:
        return
    self._delete_nodes_recursively(docs)

def _delete_nodes_recursively(self, root_nodes: List[DocNode]) -> None:
    """Delete `root_nodes` and, transitively, all of their children.

    Node uids are first gathered per group, then removed from the store.
    (Fix: the pasted span contained diff residue — both the old node-based
    and the new uid-based collection lines; this is the coherent uid-based
    version.)
    """
    uids_to_delete = defaultdict(list)
    uids_to_delete[LAZY_ROOT_NAME] = [node.uid for node in root_nodes]

    # Gather all nodes to be deleted including their children.
    def gather_children(node: DocNode) -> None:
        for children_group, children_list in node.children.items():
            for child in children_list:
                uids_to_delete[children_group].append(child.uid)
                gather_children(child)

    for node in root_nodes:
        gather_children(node)

    # Delete the collected uids, group by group.
    for group, node_uids in uids_to_delete.items():
        self.store.remove_nodes(node_uids)
        LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}")

Expand Down Expand Up @@ -241,7 +275,7 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_
if index:
assert index == "default", "we only support default index currently"
nodes = self._get_nodes(group_name)
return self.index.query(
return self.store.get_index('default').query(
query, nodes, similarity, similarity_cut_off, topk, embed_keys, **similarity_kws
)

Expand Down
151 changes: 151 additions & 0 deletions lazyllm/tools/rag/doc_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from typing import Optional, Dict, Any, Union, Callable, List
from enum import Enum, auto
from collections import defaultdict
from lazyllm import config
import uuid
import threading
import time

class MetadataMode(str, Enum):
    """Selects which metadata keys are rendered into a node's text.

    NOTE(review): with the ``str`` mixin, ``auto()`` yields the string
    values "1".."4" rather than the member names; members are compared by
    identity here, so the values themselves are not relied upon.
    """
    ALL = auto()
    EMBED = auto()
    LLM = auto()
    NONE = auto()


class DocNode:
    """One node of a parsed document tree.

    A node holds a text chunk, optional per-embed-key embeddings, a link to
    its parent, and named groups of children. Metadata, the exclusion key
    lists, and the document path live on the root node and are shared by
    the whole tree (the read properties below delegate to `root_node`).
    """

    def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None,
                 embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None,
                 metadata: Optional[Dict[str, Any]] = None, classfication: Optional[str] = None):
        self.uid: str = uid if uid else str(uuid.uuid4())
        self.text: Optional[str] = text
        self.group: Optional[str] = group
        # An empty embedding dict is normalized to None ("embedding or None").
        self.embedding: Optional[Dict[str, List[float]]] = embedding or None
        self._metadata: Dict[str, Any] = metadata or {}
        # Metadata keys that are excluded from text for the embed model.
        self._excluded_embed_metadata_keys: List[str] = []
        # Metadata keys that are excluded from text for the LLM.
        self._excluded_llm_metadata_keys: List[str] = []
        self.parent: Optional["DocNode"] = parent
        self.children: Dict[str, List["DocNode"]] = defaultdict(list)
        # Cleared in do_embedding(); presumably consumed by the store to
        # decide whether the node needs persisting — confirm against stores.
        self.is_saved: bool = False
        self._docpath = None
        # Guards embedding reads/writes across threads.
        self._lock = threading.Lock()
        # Embed keys whose embeddings are being computed elsewhere; polled
        # by check_embedding_state().
        self._embedding_state = set()
        # store will create index cache for classfication to speed up retrieve
        # (note: 'classfication' spelling is part of the public signature).
        self._classfication = classfication

    @property
    def root_node(self) -> Optional["DocNode"]:
        """The top-most ancestor of this node (self if it has no parent)."""
        root = self.parent
        while root and root.parent:
            root = root.parent
        return root or self

    @property
    def metadata(self) -> Dict:
        """Metadata shared by the whole tree (read from the root node)."""
        return self.root_node._metadata

    @metadata.setter
    def metadata(self, metadata: Dict) -> None:
        # NOTE(review): the setter writes to *this* node while the getter
        # reads from the root — setting on a non-root node has no visible
        # effect through the getter. Kept as-is; confirm intent.
        self._metadata = metadata

    @property
    def excluded_embed_metadata_keys(self) -> List:
        """Keys hidden from the embed model's text (read from the root)."""
        return self.root_node._excluded_embed_metadata_keys

    @excluded_embed_metadata_keys.setter
    def excluded_embed_metadata_keys(self, excluded_embed_metadata_keys: List) -> None:
        self._excluded_embed_metadata_keys = excluded_embed_metadata_keys

    @property
    def excluded_llm_metadata_keys(self) -> List:
        """Keys hidden from the LLM's text (read from the root)."""
        return self.root_node._excluded_llm_metadata_keys

    @excluded_llm_metadata_keys.setter
    def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None:
        self._excluded_llm_metadata_keys = excluded_llm_metadata_keys

    @property
    def docpath(self) -> str:
        """Source document path, stored on the root node ('' if unset)."""
        return self.root_node._docpath or ''

    @docpath.setter
    def docpath(self, path):
        assert not self.parent, 'Only root node can set docpath'
        self._docpath = str(path)

    def get_children_str(self) -> str:
        """Debug string: children group names mapped to child uids."""
        return str(
            {key: [node.uid for node in nodes] for key, nodes in self.children.items()}
        )

    def get_parent_id(self) -> str:
        """The parent's uid, or '' for a root node."""
        return self.parent.uid if self.parent else ""

    def __str__(self) -> str:
        return (
            f"DocNode(id: {self.uid}, group: {self.group}, text: {self.get_text()}) parent: {self.get_parent_id()}, "
            f"children: {self.get_children_str()}"
        )

    def __repr__(self) -> str:
        # Full dump only in debug mode; keep repr short otherwise.
        return str(self) if config["debug"] else f'<Node id={self.uid}>'

    def __eq__(self, other):
        # Nodes are identified by uid alone.
        if isinstance(other, DocNode):
            return self.uid == other.uid
        return False

    def __hash__(self):
        return hash(self.uid)

    def has_missing_embedding(self, embed_keys: Union[str, List[str]]) -> List[str]:
        """Return the subset of `embed_keys` that have no usable embedding.

        An embedding whose first element is -1 counts as missing (used as a
        placeholder sentinel).
        """
        if isinstance(embed_keys, str): embed_keys = [embed_keys]
        assert len(embed_keys) > 0, "The embed_keys to be checked must be passed in."  # fixed typo: "ebmed_keys"
        if self.embedding is None: return embed_keys
        return [k for k in embed_keys if k not in self.embedding.keys() or self.embedding.get(k, [-1])[0] == -1]

    def do_embedding(self, embed: Dict[str, Callable]) -> None:
        """Compute embeddings for every key in `embed` and merge them in.

        The (possibly slow) embed calls run outside the lock; only the merge
        into `self.embedding` is locked.
        """
        generate_embed = {k: e(self.get_text(MetadataMode.EMBED)) for k, e in embed.items()}
        with self._lock:
            self.embedding = self.embedding or {}
            self.embedding = {**self.embedding, **generate_embed}
            self.is_saved = False

    def check_embedding_state(self, embed_key: str) -> None:
        """Block (polling once per second) until `embed_key` has an embedding.

        Assumes another thread fills in the embedding; on success the key is
        discarded from `_embedding_state`.
        """
        while True:
            with self._lock:
                if not self.has_missing_embedding(embed_key):
                    self._embedding_state.discard(embed_key)
                    break
            time.sleep(1)

    def get_content(self) -> str:
        """Text as presented to the LLM (metadata included, minus LLM-excluded keys)."""
        return self.get_text(MetadataMode.LLM)

    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
        """Metadata info string, filtered according to `mode`."""
        if mode == MetadataMode.NONE:
            return ""

        metadata_keys = set(self.metadata.keys())
        if mode == MetadataMode.LLM:
            for key in self.excluded_llm_metadata_keys:
                if key in metadata_keys:
                    metadata_keys.remove(key)
        elif mode == MetadataMode.EMBED:
            for key in self.excluded_embed_metadata_keys:
                if key in metadata_keys:
                    metadata_keys.remove(key)

        # NOTE(review): iterating a set — line order of the metadata keys is
        # unspecified.
        return "\n".join([f"{key}: {self.metadata[key]}" for key in metadata_keys])

    def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
        """The node text, optionally prefixed with a metadata block."""
        metadata_str = self.get_metadata_str(metadata_mode).strip()
        if not metadata_str:
            return self.text if self.text else ""
        return f"{metadata_str}\n\n{self.text}".strip()

    def to_dict(self) -> Dict:
        """Serializable view: text, embedding, and (root) metadata."""
        return dict(text=self.text, embedding=self.embedding, metadata=self.metadata)
Loading

0 comments on commit 323f0ba

Please sign in to comment.