diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py
index 783e4744..9e587073 100644
--- a/lazyllm/tools/rag/__init__.py
+++ b/lazyllm/tools/rag/__init__.py
@@ -3,7 +3,7 @@
 from .rerank import Reranker, register_reranker
 from .transform import SentenceSplitter, LLMParser, NodeTransform, TransformArgs, AdaptiveTransform
 from .index import register_similarity
-from .store import DocNode
+from .doc_node import DocNode
 from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader,
                       MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader)
 from .dataReader import SimpleDirectoryReader
diff --git a/lazyllm/tools/rag/base_index.py b/lazyllm/tools/rag/base_index.py
new file mode 100644
index 00000000..7630ec6d
--- /dev/null
+++ b/lazyllm/tools/rag/base_index.py
@@ -0,0 +1,16 @@
+from .doc_node import DocNode
+from abc import ABC, abstractmethod
+from typing import List
+
+class BaseIndex(ABC):
+    @abstractmethod
+    def update(self, nodes: List[DocNode]) -> None:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def remove(self, uids: List[str]) -> None:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def query(self, *args, **kwargs) -> List[DocNode]:
+        raise NotImplementedError("not implemented yet.")
diff --git a/lazyllm/tools/rag/base_store.py b/lazyllm/tools/rag/base_store.py
new file mode 100644
index 00000000..fd7808cc
--- /dev/null
+++ b/lazyllm/tools/rag/base_store.py
@@ -0,0 +1,53 @@
+from abc import ABC, abstractmethod
+from typing import Optional, List, Dict
+from .doc_node import DocNode
+from .base_index import BaseIndex
+
+class BaseStore(ABC):
+    @abstractmethod
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def remove_nodes(self, uids: List[str]) -> None:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def has_nodes(self, group_name: str) -> bool:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def all_groups(self) -> List[str]:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def register_index(self, type: str, index: BaseIndex) -> None:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def remove_index(self, type: str) -> None:
+        raise NotImplementedError("not implemented yet.")
+
+    @abstractmethod
+    def get_index(self, type: str) -> Optional[BaseIndex]:
+        raise NotImplementedError("not implemented yet.")
+
+    # ----- helper functions ----- #
+
+    @staticmethod
+    def _update_indices(name2index: Dict[str, BaseIndex], nodes: List[DocNode]) -> None:
+        for _, index in name2index.items():
+            index.update(nodes)
+
+    @staticmethod
+    def _remove_from_indices(name2index: Dict[str, BaseIndex], uids: List[str]) -> None:
+        for _, index in name2index.items():
+            index.remove(uids)
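Taken together, `BaseIndex` and `BaseStore` split retrieval structures from node storage: a store owns the nodes and fans every update or removal out to whatever indices are registered on it. A minimal sketch of a custom index against the `BaseIndex` contract; the `KeywordIndex` class is hypothetical, not part of this PR, and exists purely to illustrate the interface:

from typing import List
from lazyllm.tools.rag.doc_node import DocNode
from lazyllm.tools.rag.base_index import BaseIndex

class KeywordIndex(BaseIndex):
    """Hypothetical example: naive substring search over node text."""
    def __init__(self):
        self._nodes: List[DocNode] = []

    def update(self, nodes: List[DocNode]) -> None:
        # Drop any stale copies of the incoming nodes, then append them.
        incoming = {n.uid for n in nodes}
        self._nodes = [n for n in self._nodes if n.uid not in incoming] + list(nodes)

    def remove(self, uids: List[str]) -> None:
        self._nodes = [n for n in self._nodes if n.uid not in uids]

    def query(self, keyword: str) -> List[DocNode]:
        return [n for n in self._nodes if keyword in (n.text or "")]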
diff --git a/lazyllm/tools/rag/component/bm25.py b/lazyllm/tools/rag/component/bm25.py
index 171c5d97..56881869 100644
--- a/lazyllm/tools/rag/component/bm25.py
+++ b/lazyllm/tools/rag/component/bm25.py
@@ -1,5 +1,5 @@
 from typing import List, Tuple
-from ..store import DocNode
+from ..doc_node import DocNode
 import bm25s
 import Stemmer
 from lazyllm.thirdparty import jieba
diff --git a/lazyllm/tools/rag/dataReader.py b/lazyllm/tools/rag/dataReader.py
index 116f3559..319525e0 100644
--- a/lazyllm/tools/rag/dataReader.py
+++ b/lazyllm/tools/rag/dataReader.py
@@ -14,7 +14,7 @@
 from pathlib import Path, PurePosixPath, PurePath
 from fsspec import AbstractFileSystem
 from lazyllm import ModuleBase, LOG
-from .store import DocNode
+from .doc_node import DocNode
 from .readers import (ReaderBase, PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader,
                       EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader,
                       VideoAudioReader, get_default_fs, is_default_fs)
diff --git a/lazyllm/tools/rag/data_loaders.py b/lazyllm/tools/rag/data_loaders.py
index 02a11c9a..0212fc17 100644
--- a/lazyllm/tools/rag/data_loaders.py
+++ b/lazyllm/tools/rag/data_loaders.py
@@ -1,5 +1,6 @@
 from typing import List, Optional, Dict
-from .store import DocNode, LAZY_ROOT_NAME
+from .doc_node import DocNode
+from .store import LAZY_ROOT_NAME
 from lazyllm import LOG
 from .dataReader import SimpleDirectoryReader
diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py
index 0e5c0c6a..cf14753e 100644
--- a/lazyllm/tools/rag/doc_impl.py
+++ b/lazyllm/tools/rag/doc_impl.py
@@ -8,12 +8,38 @@
 from .store import MapStore, DocNode, ChromadbStore, LAZY_ROOT_NAME, BaseStore
 from .data_loaders import DirectoryReader
 from .index import DefaultIndex
+from .base_index import BaseIndex
 from .utils import DocListManager
 import threading
 import time
 
 _transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser)
 
+class FileNodeIndex(BaseIndex):
+    def __init__(self):
+        self._file_node_map = {}
+
+    # override
+    def update(self, nodes: List[DocNode]) -> None:
+        for node in nodes:
+            if node.group != LAZY_ROOT_NAME:
+                continue
+            file_name = node.metadata.get("file_name")
+            if file_name:
+                self._file_node_map[file_name] = node
+
+    # override
+    def remove(self, uids: List[str]) -> None:
+        left = {k: v for k, v in self._file_node_map.items() if v.uid not in uids}
+        self._file_node_map = left
+
+    # override
+    def query(self, files: List[str]) -> List[DocNode]:
+        ret = []
+        for file in files:
+            ret.append(self._file_node_map.get(file))
+        return ret
+
 def embed_wrapper(func):
     if not func:
@@ -44,6 +70,12 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None):
         self.embed = {k: embed_wrapper(e) for k, e in embed.items()}
         self.store = None
 
+    def _create_file_node_index(self, store) -> FileNodeIndex:
+        index = FileNodeIndex()
+        for group in store.all_groups():
+            index.update(store.get_nodes(group))
+        return index
+
     @once_wrapper(reset_on_pickle=True)
     def _lazy_init(self) -> None:
         node_groups = DocImpl._builtin_node_groups.copy()
@@ -52,7 +84,9 @@
         self.node_groups = node_groups
         self.store = self._get_store()
-        self.index = DefaultIndex(self.embed, self.store)
+        self.store.register_index(type='default', index=DefaultIndex(self.embed, self.store))
+        self.store.register_index(type='file_node_map', index=self._create_file_node_index(self.store))
+
         if not self.store.has_nodes(LAZY_ROOT_NAME):
             ids, pathes = self._list_files()
             root_nodes = self._reader.load_data(pathes)
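With both indices registered in `_lazy_init`, later call sites reduce to index lookups on the store. A self-contained sketch of the same flow using only the classes above (the file names are arbitrary):

from lazyllm.tools.rag.doc_impl import FileNodeIndex
from lazyllm.tools.rag.doc_node import DocNode
from lazyllm.tools.rag.store import MapStore, LAZY_ROOT_NAME

store = MapStore([LAZY_ROOT_NAME])
store.register_index(type='file_node_map', index=FileNodeIndex())

# Root nodes carry their source file in metadata; update_nodes() fans
# the update out to every registered index.
root = DocNode(group=LAZY_ROOT_NAME, metadata={"file_name": "a.txt"})
store.update_nodes([root])

docs = store.get_index('file_node_map').query(["a.txt", "missing.txt"])
# -> [root, None]: query() keeps one slot per requested file and yields
#    None for files the index has never seen, so callers must filter.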
{input_files} and nodes {docs}") if len(docs) == 0: return self._delete_nodes_recursively(docs) def _delete_nodes_recursively(self, root_nodes: List[DocNode]) -> None: - nodes_to_delete = defaultdict(list) - nodes_to_delete[LAZY_ROOT_NAME] = root_nodes + uids_to_delete = defaultdict(list) + uids_to_delete[LAZY_ROOT_NAME] = [node.uid for node in root_nodes] # Gather all nodes to be deleted including their children def gather_children(node: DocNode): for children_group, children_list in node.children.items(): for child in children_list: - nodes_to_delete[children_group].append(child) + uids_to_delete[children_group].append(child.uid) gather_children(child) for node in root_nodes: gather_children(node) # Delete nodes in all groups - for group, node_uids in nodes_to_delete.items(): + for group, node_uids in uids_to_delete.items(): self.store.remove_nodes(node_uids) LOG.debug(f"Removed nodes from group {group} for node IDs: {node_uids}") @@ -241,7 +275,7 @@ def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_ if index: assert index == "default", "we only support default index currently" nodes = self._get_nodes(group_name) - return self.index.query( + return self.store.get_index('default').query( query, nodes, similarity, similarity_cut_off, topk, embed_keys, **similarity_kws ) diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py new file mode 100644 index 00000000..21f3e1a6 --- /dev/null +++ b/lazyllm/tools/rag/doc_node.py @@ -0,0 +1,151 @@ +from typing import Optional, Dict, Any, Union, Callable, List +from enum import Enum, auto +from collections import defaultdict +from lazyllm import config +import uuid +import threading +import time + +class MetadataMode(str, Enum): + ALL = auto() + EMBED = auto() + LLM = auto() + NONE = auto() + + +class DocNode: + def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None, + embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None, + metadata: Optional[Dict[str, Any]] = None, classfication: Optional[str] = None): + self.uid: str = uid if uid else str(uuid.uuid4()) + self.text: Optional[str] = text + self.group: Optional[str] = group + self.embedding: Optional[Dict[str, List[float]]] = embedding or None + self._metadata: Dict[str, Any] = metadata or {} + # Metadata keys that are excluded from text for the embed model. + self._excluded_embed_metadata_keys: List[str] = [] + # Metadata keys that are excluded from text for the LLM. 
+        self._excluded_llm_metadata_keys: List[str] = []
+        self.parent: Optional["DocNode"] = parent
+        self.children: Dict[str, List["DocNode"]] = defaultdict(list)
+        self.is_saved: bool = False
+        self._docpath = None
+        self._lock = threading.Lock()
+        self._embedding_state = set()
+        # store will create index cache for classfication to speed up retrieve
+        self._classfication = classfication
+
+    @property
+    def root_node(self) -> Optional["DocNode"]:
+        root = self.parent
+        while root and root.parent:
+            root = root.parent
+        return root or self
+
+    @property
+    def metadata(self) -> Dict:
+        return self.root_node._metadata
+
+    @metadata.setter
+    def metadata(self, metadata: Dict) -> None:
+        self._metadata = metadata
+
+    @property
+    def excluded_embed_metadata_keys(self) -> List:
+        return self.root_node._excluded_embed_metadata_keys
+
+    @excluded_embed_metadata_keys.setter
+    def excluded_embed_metadata_keys(self, excluded_embed_metadata_keys: List) -> None:
+        self._excluded_embed_metadata_keys = excluded_embed_metadata_keys
+
+    @property
+    def excluded_llm_metadata_keys(self) -> List:
+        return self.root_node._excluded_llm_metadata_keys
+
+    @excluded_llm_metadata_keys.setter
+    def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None:
+        self._excluded_llm_metadata_keys = excluded_llm_metadata_keys
+
+    @property
+    def docpath(self) -> str:
+        return self.root_node._docpath or ''
+
+    @docpath.setter
+    def docpath(self, path):
+        assert not self.parent, 'Only root node can set docpath'
+        self._docpath = str(path)
+
+    def get_children_str(self) -> str:
+        return str(
+            {key: [node.uid for node in nodes] for key, nodes in self.children.items()}
+        )
+
+    def get_parent_id(self) -> str:
+        return self.parent.uid if self.parent else ""
+
+    def __str__(self) -> str:
+        return (
+            f"DocNode(id: {self.uid}, group: {self.group}, text: {self.get_text()}) parent: {self.get_parent_id()}, "
+            f"children: {self.get_children_str()}"
+        )
+
+    def __repr__(self) -> str:
+        return str(self) if config["debug"] else f'<Node id={self.uid}>'
+
+    def __eq__(self, other):
+        if isinstance(other, DocNode):
+            return self.uid == other.uid
+        return False
+
+    def __hash__(self):
+        return hash(self.uid)
+
+    def has_missing_embedding(self, embed_keys: Union[str, List[str]]) -> List[str]:
+        if isinstance(embed_keys, str): embed_keys = [embed_keys]
+        assert len(embed_keys) > 0, "The embed_keys to be checked must be passed in."
+        if self.embedding is None: return embed_keys
+        return [k for k in embed_keys if k not in self.embedding.keys() or self.embedding.get(k, [-1])[0] == -1]
+
+    def do_embedding(self, embed: Dict[str, Callable]) -> None:
+        generate_embed = {k: e(self.get_text(MetadataMode.EMBED)) for k, e in embed.items()}
+        with self._lock:
+            self.embedding = self.embedding or {}
+            self.embedding = {**self.embedding, **generate_embed}
+            self.is_saved = False
+
+    def check_embedding_state(self, embed_key: str) -> None:
+        while True:
+            with self._lock:
+                if not self.has_missing_embedding(embed_key):
+                    self._embedding_state.discard(embed_key)
+                    break
+            time.sleep(1)
+
+    def get_content(self) -> str:
+        return self.get_text(MetadataMode.LLM)
+
+    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
+        """Metadata info string."""
+        if mode == MetadataMode.NONE:
+            return ""
+
+        metadata_keys = set(self.metadata.keys())
+        if mode == MetadataMode.LLM:
+            for key in self.excluded_llm_metadata_keys:
+                if key in metadata_keys:
+                    metadata_keys.remove(key)
+        elif mode == MetadataMode.EMBED:
+            for key in self.excluded_embed_metadata_keys:
+                if key in metadata_keys:
+                    metadata_keys.remove(key)
+
+        return "\n".join([f"{key}: {self.metadata[key]}" for key in metadata_keys])
+
+    def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
+        metadata_str = self.get_metadata_str(metadata_mode).strip()
+        if not metadata_str:
+            return self.text if self.text else ""
+        return f"{metadata_str}\n\n{self.text}".strip()
+
+    def to_dict(self) -> Dict:
+        return dict(text=self.text, embedding=self.embedding, metadata=self.metadata)
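`doc_node.py` extracts `DocNode` and `MetadataMode` from `store.py`, breaking the store/index/node import cycle. How the metadata modes combine with node text (a small usage sketch; the values are arbitrary):

from lazyllm.tools.rag.doc_node import DocNode, MetadataMode

node = DocNode(text="hello world", metadata={"file_name": "a.txt", "author": "internal"})
node.excluded_llm_metadata_keys = ["author"]

node.get_text()                  # "hello world" -- MetadataMode.NONE is the default
node.get_text(MetadataMode.LLM)  # "file_name: a.txt\n\nhello world" -- excluded keys dropped
node.get_content()               # same as get_text(MetadataMode.LLM)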
diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py
index 78a7a59e..36d48613 100644
--- a/lazyllm/tools/rag/index.py
+++ b/lazyllm/tools/rag/index.py
@@ -1,10 +1,15 @@
 import concurrent
 import os
 from typing import List, Callable, Optional, Dict, Union, Tuple
-from .store import DocNode, BaseStore
+from .doc_node import DocNode
+from .base_store import BaseStore
+from .base_index import BaseIndex
 import numpy as np
 from .component.bm25 import BM25
 from lazyllm import LOG, config, ThreadPoolExecutor
+import pymilvus
+
+# ---------------------------------------------------------------------------- #
 
 # min(32, (os.cpu_count() or 1) + 4) is the default number of workers for ThreadPoolExecutor
 config.add(
@@ -14,8 +19,9 @@
     "MAX_EMBEDDING_WORKERS",
 )
 
+# ---------------------------------------------------------------------------- #
 
-class DefaultIndex:
+class DefaultIndex(BaseIndex):
     """Default Index, registered for similarity functions"""
 
     registered_similarity = dict()
@@ -55,6 +61,7 @@
         return decorator(func) if func else decorator
 
+    # TODO XXX returns modified nodes
     def _parallel_do_embedding(self, nodes: List[DocNode]) -> List[DocNode]:
         with ThreadPoolExecutor(config["max_embedding_workers"]) as executor:
             futures = []
@@ -74,6 +81,15 @@
             future.result()
         return nodes
 
+    # override
+    def update(self, nodes: List[DocNode]) -> None:
+        pass
+
+    # override
+    def remove(self, uids: List[str]) -> None:
+        pass
+
+    # override
     def query(
         self,
         query: str,
@@ -150,3 +166,94 @@
     batch: bool = False,
 ) -> Callable:
     return DefaultIndex.register_similarity(func, mode, descend, batch)
+
+# ---------------------------------------------------------------------------- #
+
+class MilvusEmbeddingField:
+    def __init__(self, name: str, dim: int, data_type: int, index_type: str,
+                 metric_type: str, index_params: Optional[dict] = None):
+        self.name = name
+        self.dim = dim
+        self.data_type = data_type
+        self.index_type = index_type
+        self.metric_type = metric_type
+        self.index_params = index_params or {}
+
+class MilvusIndex(BaseIndex):
+    def __init__(self, group_embedding_fields: Dict[str, List[MilvusEmbeddingField]],
+                 uri: str, full_data_store: BaseStore):
+        self._full_data_store = full_data_store
+
+        self._primary_key = 'uid'
+        self._group_key = 'group_name'
+        self._client = pymilvus.MilvusClient(uri=uri)
+
+        for group_name, embedding_fields in group_embedding_fields.items():
+            schema = self._client.create_schema(auto_id=False, enable_dynamic_field=False)
+            schema.add_field(
+                field_name=self._primary_key,
+                datatype=pymilvus.DataType.VARCHAR,
+                max_length=128,
+                is_primary=True,
+            )
+            schema.add_field(
+                field_name=self._group_key,
+                datatype=pymilvus.DataType.VARCHAR,
+                max_length=128,
+            )
+            for field in embedding_fields:
+                schema.add_field(
+                    field_name=field.name,
+                    datatype=field.data_type,
+                    dim=field.dim)
+
+            index_params = self._client.prepare_index_params()
+            for field in embedding_fields:
+                index_params.add_index(field_name=field.name, index_type=field.index_type,
+                                       metric_type=field.metric_type, params=field.index_params)
+
+            self._client.create_collection(collection_name=group_name, schema=schema,
+                                           index_params=index_params)
+
+    # override
+    def update(self, nodes: List[DocNode]) -> None:
+        for node in nodes:
+            if node.embedding:
+                data = node.embedding.copy()
+                data[self._primary_key] = node.uid
+                data[self._group_key] = node.group
+                self._client.upsert(collection_name=node.group, data=data)
+
+    # override
+    def remove(self, uids: List[str]) -> None:
+        for group in self._client.list_collections():
+            self._client.delete(collection_name=group,
+                                filter=f'{self._primary_key} in {uids}')
+
+    # override
+    def query(self,
+              group_name: str,
+              data: List[float],
+              filter: str = "",
+              limit: int = 10,
+              search_params: Optional[dict] = None,
+              timeout: Optional[float] = None,
+              partition_names: Optional[List[str]] = None,
+              anns_field: Optional[str] = None) -> List[DocNode]:
+        results = self._client.search(collection_name=group_name, data=[data],
+                                      filter=filter, limit=limit,
+                                      output_fields=[self._group_key],
+                                      search_params=search_params,
+                                      timeout=timeout, partition_names=partition_names,
+                                      anns_field=anns_field)
+        if len(results) == 0:
+            return []
+
+        docs = []
+        for result in results[0]:
+            uid = result['id']
+            group_name = result['entity'][self._group_key]
+            doc = self._full_data_store.get_node(group_name, uid)
+            if doc:
+                docs.append(doc)
+        return docs
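`MilvusIndex` keeps only `uid`, `group_name`, and the vector fields in Milvus; every hit is resolved back to a full `DocNode` through the backing store, so text and metadata never round-trip through Milvus. A minimal local setup against Milvus Lite, mirroring the tests further down (the `milvus_demo.db` path is arbitrary):

import pymilvus
from lazyllm.tools.rag.store import MapStore, LAZY_ROOT_NAME
from lazyllm.tools.rag.index import MilvusIndex, MilvusEmbeddingField

fields = [MilvusEmbeddingField(name="vec1", dim=3, data_type=pymilvus.DataType.FLOAT_VECTOR,
                               index_type="HNSW", metric_type="IP")]
store = MapStore([LAZY_ROOT_NAME, "group1"])
index = MilvusIndex(group_embedding_fields={"group1": fields},
                    uri="milvus_demo.db", full_data_store=store)
store.register_index(type='milvus', index=index)

# update_nodes() fans out to every registered index, so embedded nodes are
# mirrored into Milvus automatically; query() searches one vector field and
# returns the full DocNodes from the backing store.
hits = index.query(group_name="group1", data=[1.0, 2.0, 3.0], limit=5, anns_field="vec1")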
diff --git a/lazyllm/tools/rag/readers/docxReader.py b/lazyllm/tools/rag/readers/docxReader.py
index ff472013..7e0ab3b6 100644
--- a/lazyllm/tools/rag/readers/docxReader.py
+++ b/lazyllm/tools/rag/readers/docxReader.py
@@ -3,7 +3,7 @@
 from typing import Dict, Optional, List
 
 from .readerBase import LazyLLMReaderBase
-from ..store import DocNode
+from ..doc_node import DocNode
 
 class DocxReader(LazyLLMReaderBase):
     def _load_data(self, file: Path, extra_info: Optional[Dict] = None,
diff --git a/lazyllm/tools/rag/readers/epubReader.py b/lazyllm/tools/rag/readers/epubReader.py
index 0e208dbf..747c3402 100644
--- a/lazyllm/tools/rag/readers/epubReader.py
+++ b/lazyllm/tools/rag/readers/epubReader.py
@@ -3,7 +3,7 @@
 from fsspec import AbstractFileSystem
 
 from .readerBase import LazyLLMReaderBase
-from ..store import DocNode
+from ..doc_node import DocNode
 from lazyllm import LOG
 
 class EpubReader(LazyLLMReaderBase):
diff --git a/lazyllm/tools/rag/readers/hwpReader.py b/lazyllm/tools/rag/readers/hwpReader.py
index 9678336e..35f33b9c 100644
--- a/lazyllm/tools/rag/readers/hwpReader.py
+++ b/lazyllm/tools/rag/readers/hwpReader.py
@@ -5,7 +5,7 @@
 import zlib
 
 from .readerBase import LazyLLMReaderBase
-from ..store import DocNode
+from ..doc_node import DocNode
 from lazyllm import LOG
 
 class HWPReader(LazyLLMReaderBase):
diff --git a/lazyllm/tools/rag/readers/imageReader.py b/lazyllm/tools/rag/readers/imageReader.py
index ee610bbc..fe05f57f 100644
--- a/lazyllm/tools/rag/readers/imageReader.py
+++ b/lazyllm/tools/rag/readers/imageReader.py
@@ -7,7 +7,7 @@
 from PIL import Image
 
 from .readerBase import LazyLLMReaderBase, infer_torch_device
-from ..store import DocNode
+from ..doc_node import DocNode
 
 def img_2_b64(image: Image, format: str = "JPEG") -> str:
     buff = BytesIO()
diff --git a/lazyllm/tools/rag/readers/ipynbReader.py b/lazyllm/tools/rag/readers/ipynbReader.py
index 66c0e192..90e0cc5c 100644
--- a/lazyllm/tools/rag/readers/ipynbReader.py
+++ b/lazyllm/tools/rag/readers/ipynbReader.py
@@ -4,7 +4,7 @@
 from fsspec import AbstractFileSystem
 
 from .readerBase import LazyLLMReaderBase
-from ..store import DocNode
+from ..doc_node import DocNode
 
 class IPYNBReader(LazyLLMReaderBase):
     def __init__(self, parser_config: Optional[Dict] = None, concatenate: bool = False, return_trace: bool = True):
diff --git a/lazyllm/tools/rag/readers/markdownReader.py b/lazyllm/tools/rag/readers/markdownReader.py
index c1748f55..0184576b 100644
--- a/lazyllm/tools/rag/readers/markdownReader.py
+++ b/lazyllm/tools/rag/readers/markdownReader.py
@@ -5,7 +5,7 @@
 from typing import Dict, List, Optional, Tuple
 
 from .readerBase import LazyLLMReaderBase
-from ..store import DocNode
+from ..doc_node import DocNode
 
 class MarkdownReader(LazyLLMReaderBase):
     def __init__(self, remove_hyperlinks: bool = True, remove_images: bool = True, return_trace: bool = True) -> None:
diff --git a/lazyllm/tools/rag/readers/mboxreader.py b/lazyllm/tools/rag/readers/mboxreader.py
index 567854c8..3fab832c 100644
--- a/lazyllm/tools/rag/readers/mboxreader.py
+++ b/lazyllm/tools/rag/readers/mboxreader.py
@@ -3,7 +3,7 @@
 from fsspec import AbstractFileSystem
 
 from .readerBase import LazyLLMReaderBase
-from ..store import DocNode
+from ..doc_node import DocNode
 from lazyllm import LOG
 
 class MboxReader(LazyLLMReaderBase):
diff --git a/lazyllm/tools/rag/readers/pandasReader.py b/lazyllm/tools/rag/readers/pandasReader.py
index bbe2cb60..e3ad327a 100644
--- a/lazyllm/tools/rag/readers/pandasReader.py
+++ b/lazyllm/tools/rag/readers/pandasReader.py
@@ -5,7 +5,7 @@
 import pandas as pd
 
 from .readerBase import LazyLLMReaderBase
-from ..store import DocNode
+from ..doc_node import DocNode
 
 class PandasCSVReader(LazyLLMReaderBase):
     def __init__(self, concat_rows: bool = True, col_joiner: str = ", ", row_joiner: str = "\n",
diff --git a/lazyllm/tools/rag/readers/pdfReader.py b/lazyllm/tools/rag/readers/pdfReader.py
index 8982c424..a0a4043f 100644
--- a/lazyllm/tools/rag/readers/pdfReader.py
+++ b/lazyllm/tools/rag/readers/pdfReader.py
@@ -5,7 +5,7 @@
 from fsspec import AbstractFileSystem
 
 from .readerBase import LazyLLMReaderBase, get_default_fs, is_default_fs
-from ..store import DocNode
+from ..doc_node import DocNode
 
 RETRY_TIMES = 3
diff --git a/lazyllm/tools/rag/readers/pptxReader.py b/lazyllm/tools/rag/readers/pptxReader.py
index 8085844d..3ae216e8 100644
--- a/lazyllm/tools/rag/readers/pptxReader.py
+++ b/lazyllm/tools/rag/readers/pptxReader.py
@@ -5,7 +5,7 @@
 from typing import Optional, Dict, List
 
 from .readerBase import LazyLLMReaderBase, infer_torch_device
-from ..store import DocNode
+from ..doc_node import DocNode
 
 class PPTXReader(LazyLLMReaderBase):
     def __init__(self, return_trace: bool = True) -> None:
diff --git a/lazyllm/tools/rag/readers/readerBase.py b/lazyllm/tools/rag/readers/readerBase.py
index 70515e52..3e9d0355 100644
--- a/lazyllm/tools/rag/readers/readerBase.py
+++ b/lazyllm/tools/rag/readers/readerBase.py
@@ -3,7 +3,7 @@
 from typing import Iterable, List
 
 from ....common import LazyLLMRegisterMetaClass
-from ..store import DocNode
+from ..doc_node import DocNode
 from lazyllm.module import ModuleBase
 
 class LazyLLMReaderBase(ModuleBase, metaclass=LazyLLMRegisterMetaClass):
diff --git a/lazyllm/tools/rag/readers/videoAudioReader.py b/lazyllm/tools/rag/readers/videoAudioReader.py
index 02236e75..bdd41e1d 100644
--- a/lazyllm/tools/rag/readers/videoAudioReader.py
+++ b/lazyllm/tools/rag/readers/videoAudioReader.py
@@ -3,7 +3,7 @@
 from fsspec import AbstractFileSystem
 
 from .readerBase import LazyLLMReaderBase
-from ..store import DocNode
+from ..doc_node import DocNode
 
 class VideoAudioReader(LazyLLMReaderBase):
     def __init__(self, model_version: str = "base", return_trace: bool = True) -> None:
diff --git a/lazyllm/tools/rag/rerank.py b/lazyllm/tools/rag/rerank.py
index 8253caf7..bf0eadc1 100644
--- a/lazyllm/tools/rag/rerank.py
+++ b/lazyllm/tools/rag/rerank.py
@@ -1,8 +1,8 @@
 from functools import lru_cache
 from typing import Callable, List, Optional, Union
 from lazyllm import ModuleBase, config, LOG
-from lazyllm.tools.rag.store import DocNode, MetadataMode
 from lazyllm.components.utils.downloader import ModelManager
+from .doc_node import DocNode, MetadataMode
 from .retriever import _PostProcess
 
 import numpy as np
diff --git a/lazyllm/tools/rag/retriever.py b/lazyllm/tools/rag/retriever.py
index b826bdbe..5443f2ee 100644
--- a/lazyllm/tools/rag/retriever.py
+++ b/lazyllm/tools/rag/retriever.py
@@ -1,5 +1,5 @@
 from lazyllm import ModuleBase, pipeline, once_wrapper
-from .store import DocNode
+from .doc_node import DocNode
 from .document import Document, DocImpl
 from typing import List, Optional, Union, Dict
diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py
index 57cee198..3821557c 100644
--- a/lazyllm/tools/rag/store.py
+++ b/lazyllm/tools/rag/store.py
@@ -1,198 +1,19 @@
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from enum import Enum, auto
-import uuid
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional
 import chromadb
 from lazyllm import LOG, config
 from chromadb.api.models.Collection import Collection
-import pymilvus
-import threading
+from .base_store import BaseStore
+from .base_index import BaseIndex
+from .doc_node import DocNode
 import json
-import time
 
+# ---------------------------------------------------------------------------- #
 LAZY_ROOT_NAME = "lazyllm_root"
 EMBED_DEFAULT_KEY = '__default__'
 config.add("rag_store_type", str, "map", "RAG_STORE_TYPE")  # "map", "chroma"
 config.add("rag_persistent_path", str, "./lazyllm_chroma", "RAG_PERSISTENT_PATH")
-
-
-class MetadataMode(str, Enum):
-    ALL = auto()
-    EMBED = auto()
-    LLM = auto()
-    NONE = auto()
-
-
-class DocNode:
-    def __init__(self, uid: Optional[str] = None, text: Optional[str] = None, group: Optional[str] = None,
-                 embedding: Optional[Dict[str, List[float]]] = None, parent: Optional["DocNode"] = None,
-                 metadata: Optional[Dict[str, Any]] = None, classfication: Optional[str] = None):
-        self.uid: str = uid if uid else str(uuid.uuid4())
-        self.text: Optional[str] = text
-        self.group: Optional[str] = group
-        self.embedding: Optional[Dict[str, List[float]]] = embedding or None
-        self._metadata: Dict[str, Any] = metadata or {}
-        # Metadata keys that are excluded from text for the embed model.
-        self._excluded_embed_metadata_keys: List[str] = []
-        # Metadata keys that are excluded from text for the LLM.
-        self._excluded_llm_metadata_keys: List[str] = []
-        self.parent: Optional["DocNode"] = parent
-        self.children: Dict[str, List["DocNode"]] = defaultdict(list)
-        self.is_saved: bool = False
-        self._docpath = None
-        self._lock = threading.Lock()
-        self._embedding_state = set()
-        # store will create index cache for classfication to speed up retrieve
-        self._classfication = classfication
-
-    @property
-    def root_node(self) -> Optional["DocNode"]:
-        root = self.parent
-        while root and root.parent:
-            root = root.parent
-        return root or self
-
-    @property
-    def metadata(self) -> Dict:
-        return self.root_node._metadata
-
-    @metadata.setter
-    def metadata(self, metadata: Dict) -> None:
-        self._metadata = metadata
-
-    @property
-    def excluded_embed_metadata_keys(self) -> List:
-        return self.root_node._excluded_embed_metadata_keys
-
-    @excluded_embed_metadata_keys.setter
-    def excluded_embed_metadata_keys(self, excluded_embed_metadata_keys: List) -> None:
-        self._excluded_embed_metadata_keys = excluded_embed_metadata_keys
-
-    @property
-    def excluded_llm_metadata_keys(self) -> List:
-        return self.root_node._excluded_llm_metadata_keys
-
-    @excluded_llm_metadata_keys.setter
-    def excluded_llm_metadata_keys(self, excluded_llm_metadata_keys: List) -> None:
-        self._excluded_llm_metadata_keys = excluded_llm_metadata_keys
-
-    @property
-    def docpath(self) -> str:
-        return self.root_node._docpath or ''
-
-    @docpath.setter
-    def docpath(self, path):
-        assert not self.parent, 'Only root node can set docpath'
-        self._docpath = str(path)
-
-    def get_children_str(self) -> str:
-        return str(
-            {key: [node.uid for node in nodes] for key, nodes in self.children.items()}
-        )
-
-    def get_parent_id(self) -> str:
-        return self.parent.uid if self.parent else ""
-
-    def __str__(self) -> str:
-        return (
-            f"DocNode(id: {self.uid}, group: {self.group}, text: {self.get_text()}) parent: {self.get_parent_id()}, "
-            f"children: {self.get_children_str()}"
-        )
-
-    def __repr__(self) -> str:
-        return str(self) if config["debug"] else f'<Node id={self.uid}>'
-
-    def __eq__(self, other):
-        if isinstance(other, DocNode):
-            return self.uid == other.uid
-        return False
-
-    def __hash__(self):
-        return hash(self.uid)
-
-    def has_missing_embedding(self, embed_keys: Union[str, List[str]]) -> List[str]:
-        if isinstance(embed_keys, str): embed_keys = [embed_keys]
-        assert len(embed_keys) > 0, "The ebmed_keys to be checked must be passed in."
-        if self.embedding is None: return embed_keys
-        return [k for k in embed_keys if k not in self.embedding.keys() or self.embedding.get(k, [-1])[0] == -1]
-
-    def do_embedding(self, embed: Dict[str, Callable]) -> None:
-        generate_embed = {k: e(self.get_text(MetadataMode.EMBED)) for k, e in embed.items()}
-        with self._lock:
-            self.embedding = self.embedding or {}
-            self.embedding = {**self.embedding, **generate_embed}
-            self.is_saved = False
-
-    def check_embedding_state(self, embed_key: str) -> None:
-        while True:
-            with self._lock:
-                if not self.has_missing_embedding(embed_key):
-                    self._embedding_state.discard(embed_key)
-                    break
-            time.sleep(1)
-
-    def get_content(self) -> str:
-        return self.get_text(MetadataMode.LLM)
-
-    def get_metadata_str(self, mode: MetadataMode = MetadataMode.ALL) -> str:
-        """Metadata info string."""
-        if mode == MetadataMode.NONE:
-            return ""
-
-        metadata_keys = set(self.metadata.keys())
-        if mode == MetadataMode.LLM:
-            for key in self.excluded_llm_metadata_keys:
-                if key in metadata_keys:
-                    metadata_keys.remove(key)
-        elif mode == MetadataMode.EMBED:
-            for key in self.excluded_embed_metadata_keys:
-                if key in metadata_keys:
-                    metadata_keys.remove(key)
-
-        return "\n".join([f"{key}: {self.metadata[key]}" for key in metadata_keys])
-
-    def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
-        metadata_str = self.get_metadata_str(metadata_mode).strip()
-        if not metadata_str:
-            return self.text if self.text else ""
-        return f"{metadata_str}\n\n{self.text}".strip()
-
-    def to_dict(self) -> Dict:
-        return dict(text=self.text, embedding=self.embedding, metadata=self.metadata)
-
-# ---------------------------------------------------------------------------- #
-
-class BaseStore(ABC):
-    @abstractmethod
-    def update_nodes(self, nodes: List[DocNode]) -> None:
-        raise NotImplementedError("not implemented yet.")
-
-    @abstractmethod
-    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
-        raise NotImplementedError("not implemented yet.")
-
-    @abstractmethod
-    def get_nodes(self, group_name: str) -> List[DocNode]:
-        raise NotImplementedError("not implemented yet.")
-
-    @abstractmethod
-    def remove_nodes(self, nodes: List[DocNode]) -> None:
-        raise NotImplementedError("not implemented yet.")
-
-    @abstractmethod
-    def has_nodes(self, group_name: str) -> bool:
-        raise NotImplementedError("not implemented yet.")
-
-    @abstractmethod
-    def all_groups(self) -> List[str]:
-        raise NotImplementedError("not implemented yet.")
-
-    @abstractmethod
-    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
-        raise NotImplementedError("not implemented yet.")
-
 # ---------------------------------------------------------------------------- #
 
 class MapStore(BaseStore):
@@ -201,24 +22,26 @@ def __init__(self, node_groups: List[str]):
         self._group2docs: Dict[str, Dict[str, DocNode]] = {
             group: {} for group in node_groups
         }
-        self._file_node_map = {}
+        self._name2index = {}
 
     # override
     def update_nodes(self, nodes: List[DocNode]) -> None:
         for node in nodes:
-            if node.group == LAZY_ROOT_NAME and "file_name" in node.metadata:
-                self._file_node_map[node.metadata["file_name"]] = node
             self._group2docs[node.group][node.uid] = node
 
+        self._update_indices(self._name2index, nodes)
+
     # override
     def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
         return self._group2docs.get(group_name, {}).get(node_id)
 
     # override
-    def remove_nodes(self, nodes: List[DocNode]) -> None:
-        for node in nodes:
-            assert node.group in self._group2docs, f"Unexpected node group {node.group}"
group {node.group}" - self._group2docs[node.group].pop(node.uid, None) + def remove_nodes(self, uids: List[str]) -> None: + for _, docs in self._group2docs.items(): + for uid in uids: + docs.pop(uid, None) + + self._remove_from_indices(self._name2index, uids) # override def has_nodes(self, group_name: str) -> bool: @@ -234,12 +57,16 @@ def all_groups(self) -> List[str]: return [group for group, nodes in self._group2docs.items()] # override - def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: - nodes = [] - for file in files: - if file in self._file_node_map: - nodes.append(self._file_node_map[file]) - return nodes + def register_index(self, type: str, index: BaseIndex) -> None: + self._name2index[type] = index + + # override + def remove_index(self, type: str) -> None: + self._name2index.pop(type, None) + + # override + def get_index(self, type: str) -> Optional[BaseIndex]: + return self._name2index.get(type) def find_node_by_uid(self, uid: str) -> Optional[DocNode]: for docs in self._group2docs.values(): @@ -274,8 +101,8 @@ def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]: return self._map_store.get_node(group_name, node_id) # override - def remove_nodes(self, nodes: List[DocNode]) -> None: - return self._map_store.remove_nodes(nodes) + def remove_nodes(self, uids: List[str]) -> None: + return self._map_store.remove_nodes(uids) # override def has_nodes(self, group_name: str) -> bool: @@ -289,6 +116,18 @@ def get_nodes(self, group_name: str) -> List[DocNode]: def all_groups(self) -> List[str]: return self._map_store.all_groups() + # override + def register_index(self, type: str, index: BaseIndex) -> None: + self._map_store.register_index(type, index) + + # override + def remove_index(self, type: str) -> Optional[BaseIndex]: + return self._map_store.remove_index(type) + + # override + def get_index(self, type: str) -> Optional[BaseIndex]: + return self._map_store.get_index(type) + def _load_store(self) -> None: if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]: LOG.info("No persistent data found, skip the rebuilding phrase.") @@ -344,10 +183,6 @@ def _save_nodes(self, nodes: List[DocNode]) -> None: ) LOG.debug(f"Saved {group} nodes {ids} to chromadb.") - # override - def get_nodes_by_files(self, files: List[str]) -> List[DocNode]: - return self._map_store.get_nodes_by_files(files) - def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]: nodes: List[DocNode] = [] for i, uid in enumerate(results['ids']): @@ -374,76 +209,3 @@ def _peek_all_documents(self, group: str) -> Dict[str, List]: assert group in self._collections, f"group {group} not found." 
         collection = self._collections[group]
         return collection.peek(collection.count())
-
-# ---------------------------------------------------------------------------- #
-
-class MilvusEmbeddingIndexField:
-    def __init__(self, name: str = "", dim: int = 0, type: int = pymilvus.DataType.FLOAT_VECTOR):
-        self.name = name
-        self.dim: int = dim
-        self.type: int = type
-
-class MilvusStore(BaseStore):
-    def __init__(self, node_groups: List[str], uri: str,
-                 embedding_index_info: List[MilvusEmbeddingIndexField],  # fields to be indexed by Milvus
-                 full_data_store: BaseStore):
-        self._full_data_store = full_data_store
-
-        self._primary_key = 'uid'
-        self._client = pymilvus.MilvusClient(uri=uri)
-
-        schema = self._client.create_schema(auto_id=False, enable_dynamic_field=True)
-        schema.add_field(
-            field_name=self._primary_key,
-            datatype=pymilvus.DataType.VARCHAR,
-            max_length=128,
-            is_primary=True,
-        )
-        for field in embedding_index_info:
-            schema.add_field(
-                field_name=field.name,
-                datatype=field.type,
-                dim=field.dim)
-
-        for group in node_groups:
-            if group not in self._client.list_collections():
-                self._client.create_collection(collection_name=group, schema=schema)
-
-    # override
-    def update_nodes(self, nodes: List[DocNode]) -> None:
-        self._save_nodes(nodes)
-        self._full_data_store.update_nodes(nodes)
-
-    # override
-    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
-        return self._full_data_store.get_node(group_name, node_id)
-
-    # override
-    def remove_nodes(self, nodes: List[DocNode]) -> None:
-        for node in nodes:
-            self._client.delete(collection_name=node.group,
-                                filter=f'{self._primary_key} in ["{node.uid}"]')
-        self._full_data_store.remove_nodes(nodes)
-
-    # override
-    def has_nodes(self, group_name: str) -> bool:
-        return self._full_data_store.has_nodes(group_name)
-
-    # override
-    def get_nodes(self, group_name: str) -> List[DocNode]:
-        return self._full_data_store.get_nodes(group_name)
-
-    # override
-    def all_groups(self) -> List[str]:
-        return self._full_data_store.all_groups()
-
-    # override
-    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
-        return self._full_data_store.get_nodes_by_files(files)
-
-    def _save_nodes(self, nodes: List[DocNode]) -> None:
-        for node in nodes:
-            if node.embedding:
-                data = node.embedding.copy()
-                data[self._primary_key] = node.uid
-                self._client.upsert(collection_name=node.group, data=data)
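Net effect of the store.py rewrite: `DocNode` and `MetadataMode` move to `doc_node.py`, the `BaseStore` ABC moves to `base_store.py` and gains index management, the Milvus code reappears as an index in `index.py`, and `remove_nodes` switches from node objects to uids. A migration sketch for call sites, assuming a store with the `file_node_map` index registered and `nodes`/`files` lists in scope:

# Before this refactor, callers held DocNode objects and used a
# store-specific helper:
#     store.remove_nodes(nodes)
#     docs = store.get_nodes_by_files(files)
# After it, removal is by uid and file lookup is just another index:
store.remove_nodes([node.uid for node in nodes])
docs = store.get_index('file_node_map').query(files)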
diff --git a/lazyllm/tools/rag/transform.py b/lazyllm/tools/rag/transform.py
index 38a97bd2..62b516a8 100644
--- a/lazyllm/tools/rag/transform.py
+++ b/lazyllm/tools/rag/transform.py
@@ -12,7 +12,7 @@
 import nltk
 import tiktoken
 
-from .store import DocNode, MetadataMode
+from .doc_node import DocNode, MetadataMode
 from lazyllm import LOG, TrainableModule, ThreadPoolExecutor
diff --git a/tests/advanced_tests/standard_test/test_reranker.py b/tests/advanced_tests/standard_test/test_reranker.py
index ba89b667..3628a291 100644
--- a/tests/advanced_tests/standard_test/test_reranker.py
+++ b/tests/advanced_tests/standard_test/test_reranker.py
@@ -1,6 +1,6 @@
 import unittest
 from unittest.mock import patch, MagicMock
-from lazyllm.tools.rag.store import DocNode
+from lazyllm.tools.rag.doc_node import DocNode
 from lazyllm.tools.rag.rerank import Reranker, register_reranker
diff --git a/tests/basic_tests/test_bm25.py b/tests/basic_tests/test_bm25.py
index 1fc9303f..0172e73a 100644
--- a/tests/basic_tests/test_bm25.py
+++ b/tests/basic_tests/test_bm25.py
@@ -1,6 +1,6 @@
 import unittest
 from lazyllm.tools.rag.component.bm25 import BM25
-from lazyllm.tools.rag.store import DocNode
+from lazyllm.tools.rag.doc_node import DocNode
 import numpy as np
diff --git a/tests/basic_tests/test_doc_node.py b/tests/basic_tests/test_doc_node.py
index 5c92018c..e49aff18 100644
--- a/tests/basic_tests/test_doc_node.py
+++ b/tests/basic_tests/test_doc_node.py
@@ -1,5 +1,5 @@
 from unittest.mock import MagicMock
-from lazyllm.tools.rag.store import DocNode, MetadataMode
+from lazyllm.tools.rag.doc_node import DocNode, MetadataMode
 
 class TestDocNode:
diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py
index 8ea52582..dbd67954 100644
--- a/tests/basic_tests/test_document.py
+++ b/tests/basic_tests/test_document.py
@@ -1,7 +1,8 @@
 import lazyllm
-from lazyllm.tools.rag.doc_impl import DocImpl
+from lazyllm.tools.rag.doc_impl import DocImpl, FileNodeIndex
 from lazyllm.tools.rag.transform import SentenceSplitter
-from lazyllm.tools.rag.store import DocNode, LAZY_ROOT_NAME
+from lazyllm.tools.rag.store import LAZY_ROOT_NAME
+from lazyllm.tools.rag.doc_node import DocNode
 from lazyllm.tools.rag import Document, Retriever, TransformArgs, AdaptiveTransform
 from lazyllm.launcher import cleanup
 from unittest.mock import MagicMock
@@ -151,5 +152,38 @@
         assert len(nodes3) == 3
 
 
+class TestFileNodeIndex(unittest.TestCase):
+    def setUp(self):
+        self.index = FileNodeIndex()
+        self.node1 = DocNode(uid='1', group=LAZY_ROOT_NAME, metadata={"file_name": "d1"})
+        self.node2 = DocNode(uid='2', group=LAZY_ROOT_NAME, metadata={"file_name": "d2"})
+        self.files = [self.node1.metadata['file_name'], self.node2.metadata['file_name']]
+
+    def test_update(self):
+        self.index.update([self.node1, self.node2])
+
+        nodes = self.index.query(self.files)
+        assert len(nodes) == len(self.files)
+
+        ret = [node.metadata['file_name'] for node in nodes]
+        assert set(ret) == set(self.files)
+
+    def test_remove(self):
+        self.index.update([self.node1, self.node2])
+
+        self.index.remove([self.node2.uid])
+        ret = self.index.query([self.node2.metadata['file_name']])
+        assert len(ret) == 1
+        assert ret[0] is None
+
+    def test_query(self):
+        self.index.update([self.node1, self.node2])
+        ret = self.index.query([self.node2.metadata['file_name']])
+        assert len(ret) == 1
+        assert ret[0] is self.node2
+        ret = self.index.query([self.node1.metadata['file_name']])
+        assert len(ret) == 1
+        assert ret[0] is self.node1
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/basic_tests/test_index.py b/tests/basic_tests/test_index.py
index 8a5e6178..c3439bc9 100644
--- a/tests/basic_tests/test_index.py
+++ b/tests/basic_tests/test_index.py
@@ -1,8 +1,12 @@
+import os
 import time
 import unittest
+import tempfile
+import pymilvus
 from unittest.mock import MagicMock
-from lazyllm.tools.rag.store import DocNode, MapStore
-from lazyllm.tools.rag.index import DefaultIndex, register_similarity
+from lazyllm.tools.rag.store import MapStore, LAZY_ROOT_NAME
+from lazyllm.tools.rag.doc_node import DocNode
+from lazyllm.tools.rag.index import DefaultIndex, register_similarity, MilvusIndex, MilvusEmbeddingField
 
 class TestDefaultIndex(unittest.TestCase):
@@ -98,5 +102,60 @@
         self.assertEqual(len(results), 1)
         self.assertIn(self.doc_node_2, results)
 
+class TestMilvusIndex(unittest.TestCase):
+    def setUp(self):
+        embedding_fields = [
+            MilvusEmbeddingField(name="vec1", dim=3, data_type=pymilvus.DataType.FLOAT_VECTOR,
+                                 index_type="HNSW", metric_type="IP"),
MilvusEmbeddingField(name="vec2", dim=5, data_type=pymilvus.DataType.FLOAT_VECTOR, + index_type="HNSW", metric_type="IP"), + ] + group_embedding_fields = { + "group1": embedding_fields, + "group2": embedding_fields, + } + + self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] + _, self.store_file = tempfile.mkstemp(suffix=".db") + + self.map_store = MapStore(self.node_groups) + self.index = MilvusIndex(group_embedding_fields=group_embedding_fields, + uri=self.store_file, full_data_store=self.map_store) + self.map_store.register_index(type='milvus', index=self.index) + + self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, + embedding={"vec1": [1.0, 2.0, 3.0], "vec2": [4.0, 5.0, 6.0, 7.0, 8.0]}) + self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, + embedding={"vec1": [100.0, 200.0, 300.0], "vec2": [400.0, 500.0, 600.0, 700.0, 800.0]}) + + def tearDown(self): + os.remove(self.store_file) + + def test_update_and_query(self): + self.map_store.update_nodes([self.node1]) + ret = self.index.query(group_name='group1', data=[100.0, 200.0, 300.0], limit=1, + anns_field='vec1') + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node1.uid) + + self.map_store.update_nodes([self.node2]) + ret = self.index.query(group_name='group1', data=[100.0, 200.0, 300.0], limit=1, + anns_field='vec1') + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node2.uid) + + def test_remove_and_query(self): + self.map_store.update_nodes([self.node1, self.node2]) + ret = self.index.query(group_name='group1', data=[100.0, 200.0, 300.0], limit=1, + anns_field='vec1') + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node2.uid) + + self.map_store.remove_nodes([self.node2.uid]) + ret = self.index.query(group_name='group1', data=[100.0, 200.0, 300.0], limit=1, + anns_field='vec1') + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0].uid, self.node1.uid) + if __name__ == "__main__": unittest.main() diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py index facaa39b..11aa9529 100644 --- a/tests/basic_tests/test_store.py +++ b/tests/basic_tests/test_store.py @@ -2,16 +2,9 @@ import shutil import unittest import lazyllm -import tempfile -import pymilvus from unittest.mock import MagicMock -from lazyllm.tools.rag.store import ( - DocNode, - MapStore, - ChromadbStore, - MilvusStore, MilvusEmbeddingIndexField, - LAZY_ROOT_NAME -) +from lazyllm.tools.rag.store import MapStore, ChromadbStore, LAZY_ROOT_NAME +from lazyllm.tools.rag.doc_node import DocNode def clear_directory(directory_path): @@ -112,13 +105,13 @@ def test_remove_nodes(self): n1 = self.store.get_node("group1", "1") assert n1.text == self.node1.text - self.store.remove_nodes([n1]) + self.store.remove_nodes(["1"]) n1 = self.store.get_node("group1", "1") assert not n1 n2 = self.store.get_node("group1", "2") assert n2.text == self.node2.text - self.store.remove_nodes([n2]) + self.store.remove_nodes(["2"]) n2 = self.store.get_node("group1", "2") assert not n2 @@ -136,82 +129,3 @@ def test_group_others(self): self.store.update_nodes([self.node1, self.node2]) self.assertEqual(self.store.has_nodes("group1"), True) self.assertEqual(self.store.has_nodes("group2"), False) - -class TestMilvusStore(unittest.TestCase): - def setUp(self): - self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"] - - self.map_store = MapStore(self.node_groups) - - index_field_list = [ - MilvusEmbeddingIndexField("vec1", 3, pymilvus.DataType.FLOAT_VECTOR), - 
MilvusEmbeddingIndexField("vec2", 5, pymilvus.DataType.FLOAT_VECTOR), - ] - _, self.store_file = tempfile.mkstemp(suffix=".db") - self.store = MilvusStore(node_groups=self.node_groups, uri=self.store_file, - embedding_index_info=index_field_list, - full_data_store=self.map_store) - - self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None, - embedding={"vec1": [1, 2, 3], "vec2": [4, 5, 6, 7, 8]}) - self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1, - embedding={"vec1": [100, 200, 300], "vec2": [400, 500, 600, 700, 800]}) - self.nodes = [self.node1, self.node2] - - def tearDown(self): - os.remove(self.store_file) - - def test_update_nodes(self): - self.store.update_nodes(self.nodes) - - nodes = self.store.get_nodes("group1") - self.assertEqual(len(nodes), len(self.nodes)) - - counter = 0 - for node in nodes: - for expected_node in self.nodes: - if node.uid == expected_node.uid: - self.assertEqual(node.text, expected_node.text) - counter += 1 - break - self.assertEqual(counter, len(self.nodes)) - - def test_get_node(self): - self.store.update_nodes(self.nodes) - n1 = self.store.get_node("group1", "1") - self.assertEqual(n1.text, self.node1.text) - n2 = self.store.get_node("group1", "2") - self.assertEqual(n2.text, self.node2.text) - - def test_remove_nodes(self): - self.store.update_nodes(self.nodes) - - n1 = self.store.get_node("group1", "1") - self.assertEqual(n1.text, self.node1.text) - self.store.remove_nodes([n1]) - n1 = self.store.get_node("group1", "1") - self.assertEqual(n1, None) - - n2 = self.store.get_node("group1", "2") - self.assertEqual(n2.text, self.node2.text) - self.store.remove_nodes([n2]) - n2 = self.store.get_node("group1", "2") - self.assertEqual(n2, None) - - def test_get_nodes(self): - self.store.update_nodes(self.nodes) - ids = set([self.node1.uid, self.node2.uid]) - - docs = self.store.get_nodes("group1") - self.assertEqual(ids, set([doc.uid for doc in docs])) - - def test_all_groups(self): - self.assertEqual(set(self.store.all_groups()), set(self.node_groups)) - - def test_group_others(self): - self.store.update_nodes([self.node1, self.node2]) - self.assertEqual(self.store.has_nodes("group1"), True) - self.assertEqual(self.store.has_nodes("group2"), False) - -if __name__ == "__main__": - unittest.main() diff --git a/tests/basic_tests/test_transform.py b/tests/basic_tests/test_transform.py index 5b6a8f29..47f60d60 100644 --- a/tests/basic_tests/test_transform.py +++ b/tests/basic_tests/test_transform.py @@ -1,6 +1,6 @@ import lazyllm from lazyllm.tools.rag.transform import SentenceSplitter -from lazyllm.tools.rag.store import DocNode +from lazyllm.tools.rag.doc_node import DocNode class TestSentenceSplitter: