diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py
index 7f98f4e3..0e5c0c6a 100644
--- a/lazyllm/tools/rag/doc_impl.py
+++ b/lazyllm/tools/rag/doc_impl.py
@@ -56,7 +56,7 @@ def _lazy_init(self) -> None:
         if not self.store.has_nodes(LAZY_ROOT_NAME):
             ids, pathes = self._list_files()
             root_nodes = self._reader.load_data(pathes)
-            self.store.add_nodes(root_nodes)
+            self.store.update_nodes(root_nodes)
             if self._dlm:
                 self._dlm.update_kb_group_file_status(
                     ids, DocListManager.Status.success, group=self._kb_group_name)
             LOG.debug(f"building {LAZY_ROOT_NAME} nodes: {root_nodes}")
@@ -72,7 +72,6 @@ def _get_store(self) -> BaseStore:
             store = MapStore(node_groups=self.node_groups.keys())
         elif rag_store_type == "chroma":
            store = ChromadbStore(node_groups=self.node_groups.keys(), embed=self.embed)
-            store.try_load_store()
         else:
             raise NotImplementedError(
                 f"Not implemented store type for {rag_store_type}"
             )
@@ -179,13 +178,15 @@ def _add_files(self, input_files: List[str]):
         self._lazy_init()
         root_nodes = self._reader.load_data(input_files)
         temp_store = self._get_store()
-        temp_store.add_nodes(root_nodes)
-        active_groups = self.store.active_groups()
-        LOG.info(f"add_files: Trying to merge store with {active_groups}")
-        for group in active_groups:
+        temp_store.update_nodes(root_nodes)
+        all_groups = self.store.all_groups()
+        LOG.info(f"add_files: Trying to merge store with {all_groups}")
+        for group in all_groups:
+            if not self.store.has_nodes(group):
+                continue
             # Duplicate group will be discarded automatically
             nodes = self._get_nodes(group, temp_store)
-            self.store.add_nodes(nodes)
+            self.store.update_nodes(nodes)
             LOG.debug(f"Merge {group} with {nodes}")
 
     def _delete_files(self, input_files: List[str]) -> None:
@@ -226,13 +227,13 @@ def _dynamic_create_nodes(self, group_name: str, store: BaseStore) -> None:
             transform = AdaptiveTransform(t) if isinstance(t, list) or t.pattern else make_transform(t)
             parent_nodes = self._get_nodes(node_group["parent"], store)
             nodes = transform.batch_forward(parent_nodes, group_name)
-            store.add_nodes(nodes)
+            store.update_nodes(nodes)
             LOG.debug(f"building {group_name} nodes: {nodes}")
 
     def _get_nodes(self, group_name: str, store: Optional[BaseStore] = None) -> List[DocNode]:
         store = store or self.store
         self._dynamic_create_nodes(group_name, store)
-        return store.traverse_nodes(group_name)
+        return store.get_nodes(group_name)
 
     def retrieve(self, query: str, group_name: str, similarity: str, similarity_cut_off: Union[float, Dict[str, float]],
                  index: str, topk: int, similarity_kws: dict, embed_keys: Optional[List[str]] = None) -> List[DocNode]:
diff --git a/lazyllm/tools/rag/index.py b/lazyllm/tools/rag/index.py
index d9f6ad24..78a7a59e 100644
--- a/lazyllm/tools/rag/index.py
+++ b/lazyllm/tools/rag/index.py
@@ -96,7 +96,7 @@ def query(
         assert len(query) > 0, "Query should not be empty."
         query_embedding = {k: self.embed[k](query) for k in (embed_keys or self.embed.keys())}
         nodes = self._parallel_do_embedding(nodes)
-        self.store.try_save_nodes(nodes)
+        self.store.update_nodes(nodes)
         similarities = similarity_func(query_embedding, nodes, topk=topk, **kwargs)
     elif mode == "text":
         similarities = similarity_func(query, nodes, topk=topk, **kwargs)
diff --git a/lazyllm/tools/rag/store.py b/lazyllm/tools/rag/store.py
index 8e5c30c2..57cee198 100644
--- a/lazyllm/tools/rag/store.py
+++ b/lazyllm/tools/rag/store.py
@@ -6,6 +6,7 @@
 import chromadb
 from lazyllm import LOG, config
 from chromadb.api.models.Collection import Collection
+import pymilvus
 import threading
 import json
 import time
@@ -161,60 +162,78 @@ def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
     def to_dict(self) -> Dict:
         return dict(text=self.text, embedding=self.embedding, metadata=self.metadata)
 
+# ---------------------------------------------------------------------------- #
 
 class BaseStore(ABC):
-    def __init__(self, node_groups: List[str]) -> None:
-        self._store: Dict[str, Dict[str, DocNode]] = {
-            group: {} for group in node_groups
-        }
-        self._file_node_map = {}
-
-    def _add_nodes(self, nodes: List[DocNode]) -> None:
-        for node in nodes:
-            if node.group == LAZY_ROOT_NAME and "file_name" in node.metadata:
-                self._file_node_map[node.metadata["file_name"]] = node
-            self._store[node.group][node.uid] = node
-
-    def add_nodes(self, nodes: List[DocNode]) -> None:
-        self._add_nodes(nodes)
-        self.try_save_nodes(nodes)
+    @abstractmethod
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        raise NotImplementedError("not implemented yet.")
 
-    def has_nodes(self, group: str) -> bool:
-        return len(self._store[group]) > 0
+    @abstractmethod
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        raise NotImplementedError("not implemented yet.")
 
-    def get_node(self, group: str, node_id: str) -> Optional[DocNode]:
-        return self._store.get(group, {}).get(node_id)
+    @abstractmethod
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        raise NotImplementedError("not implemented yet.")
 
-    def traverse_nodes(self, group: str) -> List[DocNode]:
-        return list(self._store.get(group, {}).values())
+    @abstractmethod
+    def remove_nodes(self, nodes: List[DocNode]) -> None:
+        raise NotImplementedError("not implemented yet.")
 
     @abstractmethod
-    def try_save_nodes(self, nodes: List[DocNode]) -> None:
-        # try save nodes to persistent source
-        raise NotImplementedError("Not implemented yet.")
+    def has_nodes(self, group_name: str) -> bool:
+        raise NotImplementedError("not implemented yet.")
 
     @abstractmethod
-    def try_load_store(self) -> None:
-        # try load nodes from persistent source
-        raise NotImplementedError("Not implemented yet.")
+    def all_groups(self) -> List[str]:
+        raise NotImplementedError("not implemented yet.")
 
     @abstractmethod
-    def try_remove_nodes(self, nodes: List[DocNode]) -> None:
-        # try remove nodes in persistent source
-        raise NotImplementedError("Not implemented yet.")
+    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
+        raise NotImplementedError("not implemented yet.")
 
-    def active_groups(self) -> List:
-        return [group for group, nodes in self._store.items() if nodes]
+# ---------------------------------------------------------------------------- #
 
-    def _remove_nodes(self, nodes: List[DocNode]) -> None:
+class MapStore(BaseStore):
+    def __init__(self, node_groups: List[str]):
+        # Dict[group_name, Dict[uuid, DocNode]]
+        self._group2docs: Dict[str, Dict[str, DocNode]] = {
+            group: {} for group in node_groups
+        }
+        self._file_node_map = {}
+
+    # override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
         for node in nodes:
-            assert node.group in self._store, f"Unexpected node group {node.group}"
-            self._store[node.group].pop(node.uid, None)
+            if node.group == LAZY_ROOT_NAME and "file_name" in node.metadata:
+                self._file_node_map[node.metadata["file_name"]] = node
+            self._group2docs[node.group][node.uid] = node
+
+    # override
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        return self._group2docs.get(group_name, {}).get(node_id)
 
+    # override
     def remove_nodes(self, nodes: List[DocNode]) -> None:
-        self._remove_nodes(nodes)
-        self.try_remove_nodes(nodes)
+        for node in nodes:
+            assert node.group in self._group2docs, f"Unexpected node group {node.group}"
+            self._group2docs[node.group].pop(node.uid, None)
+
+    # override
+    def has_nodes(self, group_name: str) -> bool:
+        docs = self._group2docs.get(group_name)
+        return bool(docs)
+
+    # override
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        return list(self._group2docs.get(group_name, {}).values())
+
+    # override
+    def all_groups(self) -> List[str]:
+        return list(self._group2docs.keys())
 
+    # override
     def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
         nodes = []
         for file in files:
@@ -222,26 +241,20 @@ def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
                 nodes.append(self._file_node_map[file])
         return nodes
 
+    def find_node_by_uid(self, uid: str) -> Optional[DocNode]:
+        for docs in self._group2docs.values():
+            doc = docs.get(uid)
+            if doc:
+                return doc
+        return None
 
-class MapStore(BaseStore):
-    def __init__(self, node_groups: List[str], *args, **kwargs):
-        super().__init__(node_groups, *args, **kwargs)
-
-    def try_save_nodes(self, nodes: List[DocNode]) -> None:
-        pass
-
-    def try_load_store(self) -> None:
-        pass
-
-    def try_remove_nodes(self, nodes: List[DocNode]) -> None:
-        pass
-
+# ---------------------------------------------------------------------------- #
 
 class ChromadbStore(BaseStore):
     def __init__(
-        self, node_groups: List[str], embed: Dict[str, Callable], *args, **kwargs
+        self, node_groups: List[str], embed: Dict[str, Callable]
     ) -> None:
-        super().__init__(node_groups, *args, **kwargs)
+        self._map_store = MapStore(node_groups)
         self._db_client = chromadb.PersistentClient(path=config["rag_persistent_path"])
         LOG.success(f"Initialzed chromadb in path: {config['rag_persistent_path']}")
         self._collections: Dict[str, Collection] = {
@@ -249,8 +262,34 @@ def __init__(
             for group in node_groups
         }
         self._placeholder = {k: [-1] * len(e("a")) for k, e in embed.items()} if embed else {EMBED_DEFAULT_KEY: []}
+        self._load_store()
+
+    # override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        self._map_store.update_nodes(nodes)
+        self._save_nodes(nodes)
+
+    # override
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        return self._map_store.get_node(group_name, node_id)
 
-    def try_load_store(self) -> None:
+    # override
+    def remove_nodes(self, nodes: List[DocNode]) -> None:
+        return self._map_store.remove_nodes(nodes)
+
+    # override
+    def has_nodes(self, group_name: str) -> bool:
+        return self._map_store.has_nodes(group_name)
+
+    # override
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        return self._map_store.get_nodes(group_name)
+
+    # override
+    def all_groups(self) -> List[str]:
+        return self._map_store.all_groups()
+
+    def _load_store(self) -> None:
         if not self._collections[LAZY_ROOT_NAME].peek(1)["ids"]:
             LOG.info("No persistent data found, skip the rebuilding phrase.")
             return
@@ -259,20 +298,21 @@ def try_load_store(self) -> None:
         for group in self._collections.keys():
             results = self._peek_all_documents(group)
             nodes = self._build_nodes_from_chroma(results)
-            self._add_nodes(nodes)
+            self._map_store.update_nodes(nodes)
 
         # Rebuild relationships
-        for group, nodes_dict in self._store.items():
-            for node in nodes_dict.values():
+        for group_name in self._map_store.all_groups():
+            nodes = self._map_store.get_nodes(group_name)
+            for node in nodes:
                 if node.parent:
                     parent_uid = node.parent
-                    parent_node = self._find_node_by_uid(parent_uid)
+                    parent_node = self._map_store.find_node_by_uid(parent_uid)
                     node.parent = parent_node
                     parent_node.children[node.group].append(node)
-            LOG.debug(f"build {group} nodes from chromadb: {nodes_dict.values()}")
+            LOG.debug(f"build {group_name} nodes from chromadb: {nodes}")
         LOG.success("Successfully Built nodes from chromadb.")
 
-    def try_save_nodes(self, nodes: List[DocNode]) -> None:
+    def _save_nodes(self, nodes: List[DocNode]) -> None:
         if not nodes:
             return
         # Note: It's caller's duty to make sure this batch of nodes has the same group.
@@ -304,14 +344,9 @@ def try_save_nodes(self, nodes: List[DocNode]) -> None:
         )
         LOG.debug(f"Saved {group} nodes {ids} to chromadb.")
 
-    def try_remove_nodes(self, nodes: List[DocNode]) -> None:
-        pass
-
-    def _find_node_by_uid(self, uid: str) -> Optional[DocNode]:
-        for nodes_by_category in self._store.values():
-            if uid in nodes_by_category:
-                return nodes_by_category[uid]
-        raise ValueError(f"UID {uid} not found in store.")
+    # override
+    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
+        return self._map_store.get_nodes_by_files(files)
 
     def _build_nodes_from_chroma(self, results: Dict[str, List]) -> List[DocNode]:
         nodes: List[DocNode] = []
@@ -339,3 +374,76 @@ def _peek_all_documents(self, group: str) -> Dict[str, List]:
         assert group in self._collections, f"group {group} not found."
         collection = self._collections[group]
         return collection.peek(collection.count())
+
+# ---------------------------------------------------------------------------- #
+
+class MilvusEmbeddingIndexField:
+    def __init__(self, name: str = "", dim: int = 0, type: int = pymilvus.DataType.FLOAT_VECTOR):
+        self.name = name
+        self.dim: int = dim
+        self.type: int = type
+
+class MilvusStore(BaseStore):
+    def __init__(self, node_groups: List[str], uri: str,
+                 embedding_index_info: List[MilvusEmbeddingIndexField],  # fields to be indexed by Milvus
+                 full_data_store: BaseStore):
+        self._full_data_store = full_data_store
+
+        self._primary_key = 'uid'
+        self._client = pymilvus.MilvusClient(uri=uri)
+
+        schema = self._client.create_schema(auto_id=False, enable_dynamic_field=True)
+        schema.add_field(
+            field_name=self._primary_key,
+            datatype=pymilvus.DataType.VARCHAR,
+            max_length=128,
+            is_primary=True,
+        )
+        for field in embedding_index_info:
+            schema.add_field(
+                field_name=field.name,
+                datatype=field.type,
+                dim=field.dim)
+
+        for group in node_groups:
+            if group not in self._client.list_collections():
+                self._client.create_collection(collection_name=group, schema=schema)
+
+    # override
+    def update_nodes(self, nodes: List[DocNode]) -> None:
+        self._save_nodes(nodes)
+        self._full_data_store.update_nodes(nodes)
+
+    # override
+    def get_node(self, group_name: str, node_id: str) -> Optional[DocNode]:
+        return self._full_data_store.get_node(group_name, node_id)
+
+    # override
+    def remove_nodes(self, nodes: List[DocNode]) -> None:
+        for node in nodes:
+            self._client.delete(collection_name=node.group,
+                                filter=f'{self._primary_key} in ["{node.uid}"]')
+        self._full_data_store.remove_nodes(nodes)
+
+    # override
+    def has_nodes(self, group_name: str) -> bool:
+        return self._full_data_store.has_nodes(group_name)
+
+    # override
+    def get_nodes(self, group_name: str) -> List[DocNode]:
+        return self._full_data_store.get_nodes(group_name)
+
+    # override
+    def all_groups(self) -> List[str]:
+        return self._full_data_store.all_groups()
+
+    # override
+    def get_nodes_by_files(self, files: List[str]) -> List[DocNode]:
+        return self._full_data_store.get_nodes_by_files(files)
+
+    def _save_nodes(self, nodes: List[DocNode]) -> None:
+        for node in nodes:
+            if node.embedding:
+                data = node.embedding.copy()
+                data[self._primary_key] = node.uid
+                self._client.upsert(collection_name=node.group, data=data)
diff --git a/requirements.full.txt b/requirements.full.txt
index 8c8387f3..88d0c584 100644
--- a/requirements.full.txt
+++ b/requirements.full.txt
@@ -17,6 +17,7 @@ json5
 tiktoken
 spacy<=3.7.5
 chromadb
+pymilvus
 bm25s
 pystemmer
 nltk
diff --git a/requirements.txt b/requirements.txt
index b8010c2b..0106213a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,6 +17,7 @@ json5
 tiktoken
 spacy<=3.7.5
 chromadb
+pymilvus
 bm25s
 pystemmer
 nltk
@@ -30,4 +31,4 @@ sqlalchemy
 psutil
 pypdf
 pytest
-numpy==1.26.4
\ No newline at end of file
+numpy==1.26.4
diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py
index e62dc277..8ea52582 100644
--- a/tests/basic_tests/test_document.py
+++ b/tests/basic_tests/test_document.py
@@ -59,16 +59,16 @@ def test_retrieve(self):
     def test_add_files(self):
         assert self.doc_impl.store is None
         self.doc_impl._lazy_init()
-        assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 1
+        assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 1
         new_doc = DocNode(text="new dummy text", group=LAZY_ROOT_NAME)
         new_doc.metadata = {"file_name": "new_file.txt"}
         self.mock_directory_reader.load_data.return_value = [new_doc]
         self.doc_impl._add_files(["new_file.txt"])
-        assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 2
+        assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 2
 
     def test_delete_files(self):
         self.doc_impl._delete_files(["dummy_file.txt"])
-        assert len(self.doc_impl.store.traverse_nodes(LAZY_ROOT_NAME)) == 0
+        assert len(self.doc_impl.store.get_nodes(LAZY_ROOT_NAME)) == 0
 
 
 class TestDocument(unittest.TestCase):
diff --git a/tests/basic_tests/test_store.py b/tests/basic_tests/test_store.py
index 1a6edfc9..facaa39b 100644
--- a/tests/basic_tests/test_store.py
+++ b/tests/basic_tests/test_store.py
@@ -2,8 +2,16 @@
 import shutil
 import unittest
 import lazyllm
+import tempfile
+import pymilvus
 from unittest.mock import MagicMock
-from lazyllm.tools.rag.store import DocNode, ChromadbStore, LAZY_ROOT_NAME
+from lazyllm.tools.rag.store import (
+    DocNode,
+    MapStore,
+    ChromadbStore,
+    MilvusStore, MilvusEmbeddingIndexField,
+    LAZY_ROOT_NAME
+)
 
 
 def clear_directory(directory_path):
@@ -26,7 +34,7 @@ def setUp(self):
         self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"]
         self.embed = {"default": MagicMock(side_effect=lambda text: [0.1, 0.2, 0.3])}
         self.store = ChromadbStore(self.node_groups, self.embed)
-        self.store.add_nodes(
+        self.store.update_nodes(
             [DocNode(uid="1", text="text1", group=LAZY_ROOT_NAME, parent=None)],
         )
 
@@ -37,36 +45,173 @@ def tearDownClass(cls):
     def test_initialization(self):
         self.assertEqual(set(self.store._collections.keys()), set(self.node_groups))
 
-    def test_add_and_traverse_nodes(self):
+    def test_add_and_get_nodes(self):
         node1 = DocNode(uid="1", text="text1", group="group1")
         node2 = DocNode(uid="2", text="text2", group="group2")
-        self.store.add_nodes([node1, node2])
-        nodes = self.store.traverse_nodes("group1")
+        self.store.update_nodes([node1, node2])
+        nodes = self.store.get_nodes("group1")
         self.assertEqual(nodes, [node1])
 
     def test_save_nodes(self):
         node1 = DocNode(uid="1", text="text1", group="group1")
         node2 = DocNode(uid="2", text="text2", group="group2")
-        self.store.add_nodes([node1, node2])
+        self.store.update_nodes([node1, node2])
         collection = self.store._collections["group1"]
         self.assertEqual(collection.peek(collection.count())["ids"], ["1", "2"])
 
-    def test_try_load_store(self):
+    def test_load_store(self):
         # Set up initial data to be loaded
         node1 = DocNode(uid="1", text="text1", group="group1", parent=None)
         node2 = DocNode(uid="2", text="text2", group="group1", parent=node1)
-        self.store.add_nodes([node1, node2])
+        self.store.update_nodes([node1, node2])
 
         # Reset store and load from "persistent" storage
-        self.store._store = {group: {} for group in self.node_groups}
-        self.store.try_load_store()
+        self.store._map_store._group2docs = {group: {} for group in self.node_groups}
+        self.store._load_store()
 
-        nodes = self.store.traverse_nodes("group1")
+        nodes = self.store.get_nodes("group1")
         self.assertEqual(len(nodes), 2)
         self.assertEqual(nodes[0].uid, "1")
         self.assertEqual(nodes[1].uid, "2")
         self.assertEqual(nodes[1].parent.uid, "1")
 
+    def test_all_groups(self):
+        self.assertEqual(set(self.store.all_groups()), set(self.node_groups))
+
+    def test_group_others(self):
+        node1 = DocNode(uid="1", text="text1", group="group1", parent=None)
+        node2 = DocNode(uid="2", text="text2", group="group1", parent=node1)
+        self.store.update_nodes([node1, node2])
+        self.assertEqual(self.store.has_nodes("group1"), True)
+        self.assertEqual(self.store.has_nodes("group2"), True)
+
+class TestMapStore(unittest.TestCase):
+    def setUp(self):
+        self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"]
+        self.store = MapStore(self.node_groups)
+        self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None)
+        self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1)
+
+    def test_update_nodes(self):
+        self.store.update_nodes([self.node1, self.node2])
+        nodes = self.store.get_nodes("group1")
+        self.assertEqual(len(nodes), 2)
+        self.assertEqual(nodes[0].uid, "1")
+        self.assertEqual(nodes[1].uid, "2")
+        self.assertEqual(nodes[1].parent.uid, "1")
+
+    def test_get_node(self):
+        self.store.update_nodes([self.node1, self.node2])
+        n1 = self.store.get_node("group1", "1")
+        assert n1.text == self.node1.text
+        n2 = self.store.get_node("group1", "2")
+        assert n2.text == self.node2.text
+
+    def test_remove_nodes(self):
+        self.store.update_nodes([self.node1, self.node2])
+
+        n1 = self.store.get_node("group1", "1")
+        assert n1.text == self.node1.text
+        self.store.remove_nodes([n1])
+        n1 = self.store.get_node("group1", "1")
+        assert not n1
+
+        n2 = self.store.get_node("group1", "2")
+        assert n2.text == self.node2.text
+        self.store.remove_nodes([n2])
+        n2 = self.store.get_node("group1", "2")
+        assert not n2
+
+    def test_get_nodes(self):
+        self.store.update_nodes([self.node1, self.node2])
+        ids = set([self.node1.uid, self.node2.uid])
+
+        docs = self.store.get_nodes("group1")
+        self.assertEqual(ids, set([doc.uid for doc in docs]))
+
+    def test_all_groups(self):
+        self.assertEqual(set(self.store.all_groups()), set(self.node_groups))
+
+    def test_group_others(self):
+        self.store.update_nodes([self.node1, self.node2])
+        self.assertEqual(self.store.has_nodes("group1"), True)
+        self.assertEqual(self.store.has_nodes("group2"), False)
+
+class TestMilvusStore(unittest.TestCase):
+    def setUp(self):
+        self.node_groups = [LAZY_ROOT_NAME, "group1", "group2"]
+
+        self.map_store = MapStore(self.node_groups)
+
+        index_field_list = [
+            MilvusEmbeddingIndexField("vec1", 3, pymilvus.DataType.FLOAT_VECTOR),
+            MilvusEmbeddingIndexField("vec2", 5, pymilvus.DataType.FLOAT_VECTOR),
+        ]
+        _, self.store_file = tempfile.mkstemp(suffix=".db")
+        self.store = MilvusStore(node_groups=self.node_groups, uri=self.store_file,
+                                 embedding_index_info=index_field_list,
+                                 full_data_store=self.map_store)
+
+        self.node1 = DocNode(uid="1", text="text1", group="group1", parent=None,
+                             embedding={"vec1": [1, 2, 3], "vec2": [4, 5, 6, 7, 8]})
+        self.node2 = DocNode(uid="2", text="text2", group="group1", parent=self.node1,
+                             embedding={"vec1": [100, 200, 300], "vec2": [400, 500, 600, 700, 800]})
+        self.nodes = [self.node1, self.node2]
+
+    def tearDown(self):
+        os.remove(self.store_file)
+
+    def test_update_nodes(self):
+        self.store.update_nodes(self.nodes)
+
+        nodes = self.store.get_nodes("group1")
+        self.assertEqual(len(nodes), len(self.nodes))
+
+        counter = 0
+        for node in nodes:
+            for expected_node in self.nodes:
+                if node.uid == expected_node.uid:
+                    self.assertEqual(node.text, expected_node.text)
+                    counter += 1
+                    break
+        self.assertEqual(counter, len(self.nodes))
+
+    def test_get_node(self):
+        self.store.update_nodes(self.nodes)
+        n1 = self.store.get_node("group1", "1")
+        self.assertEqual(n1.text, self.node1.text)
+        n2 = self.store.get_node("group1", "2")
+        self.assertEqual(n2.text, self.node2.text)
+
+    def test_remove_nodes(self):
+        self.store.update_nodes(self.nodes)
+
+        n1 = self.store.get_node("group1", "1")
+        self.assertEqual(n1.text, self.node1.text)
+        self.store.remove_nodes([n1])
+        n1 = self.store.get_node("group1", "1")
+        self.assertEqual(n1, None)
+
+        n2 = self.store.get_node("group1", "2")
+        self.assertEqual(n2.text, self.node2.text)
+        self.store.remove_nodes([n2])
+        n2 = self.store.get_node("group1", "2")
+        self.assertEqual(n2, None)
+
+    def test_get_nodes(self):
+        self.store.update_nodes(self.nodes)
+        ids = set([self.node1.uid, self.node2.uid])
+
+        docs = self.store.get_nodes("group1")
+        self.assertEqual(ids, set([doc.uid for doc in docs]))
+
+    def test_all_groups(self):
+        self.assertEqual(set(self.store.all_groups()), set(self.node_groups))
+
+    def test_group_others(self):
+        self.store.update_nodes([self.node1, self.node2])
+        self.assertEqual(self.store.has_nodes("group1"), True)
+        self.assertEqual(self.store.has_nodes("group2"), False)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/requirements.txt b/tests/requirements.txt
index 90ba7fd6..5a33e8db 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -3,3 +3,4 @@ docx2txt
 olefile
 pytest-rerunfailures
 pytest-order
+pymilvus
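
---
Reviewer note (not part of the patch): this change collapses the old add_nodes/try_save_nodes/try_load_store trio into a single update_nodes plus private _save_nodes/_load_store, and turns BaseStore into a pure interface. MilvusStore holds only the embedding index and delegates full DocNode payloads to the wrapped full_data_store, mirroring how ChromadbStore now composes a MapStore instead of inheriting a dict. A minimal usage sketch of the new API follows; all class and method names come from this diff, while the "vec" field, its dimension 3, and the local db path are illustrative assumptions:

    import pymilvus
    from lazyllm.tools.rag.store import (
        DocNode, MapStore, MilvusStore, MilvusEmbeddingIndexField, LAZY_ROOT_NAME,
    )

    groups = [LAZY_ROOT_NAME, "block"]
    store = MilvusStore(
        node_groups=groups,
        uri="./milvus_demo.db",  # hypothetical local file, as in the tests (Milvus Lite)
        embedding_index_info=[MilvusEmbeddingIndexField("vec", 3, pymilvus.DataType.FLOAT_VECTOR)],
        full_data_store=MapStore(groups),  # keeps the full DocNode payloads in memory
    )

    node = DocNode(uid="n1", text="hello", group="block", embedding={"vec": [0.1, 0.2, 0.3]})
    store.update_nodes([node])   # upserts the vector into Milvus and the node into the MapStore
    assert store.has_nodes("block")
    assert store.get_node("block", "n1").text == "hello"
    store.remove_nodes([node])   # deletes from both the index and the full-data store
    assert store.get_node("block", "n1") is None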