diff --git a/lazyllm/tools/rag/chroma_store.py b/lazyllm/tools/rag/chroma_store.py index 8a7969b5..1c8bdf88 100644 --- a/lazyllm/tools/rag/chroma_store.py +++ b/lazyllm/tools/rag/chroma_store.py @@ -43,6 +43,10 @@ def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> Non self._db_client.delete_collection(name=group_name) return self._map_store.remove_nodes(group_name, uids) + @override + def update_doc_meta(self, filepath: str, metadata: dict) -> None: + self._map_store.update_doc_meta(filepath, metadata) + @override def get_nodes(self, group_name: str, uids: List[str] = None) -> List[DocNode]: return self._map_store.get_nodes(group_name, uids) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index ef3b6c2b..4a3a7e09 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -223,6 +223,13 @@ def add_reader(self, pattern: str, func: Optional[Callable] = None): def worker(self): while True: + # Apply meta changes + rows = self._dlm.fetch_docs_changed_meta(self._kb_group_name) + if rows: + for row in rows: + new_meta_dict = json.loads(row[1]) if row[1] else {} + self.store.update_doc_meta(row[0], new_meta_dict) + docs = self._dlm.get_docs_need_reparse(group=self._kb_group_name) if docs: filepaths = [doc.path for doc in docs] diff --git a/lazyllm/tools/rag/doc_manager.py b/lazyllm/tools/rag/doc_manager.py index 47334628..10603086 100644 --- a/lazyllm/tools/rag/doc_manager.py +++ b/lazyllm/tools/rag/doc_manager.py @@ -1,6 +1,6 @@ import os import json -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Union from pydantic import BaseModel, Field from starlette.responses import RedirectResponse @@ -213,5 +213,146 @@ def delete_files_from_group(self, request: FileGroupRequest): except Exception as e: return BaseResponse(code=500, msg=str(e), data=None) + class AddMetadataRequest(BaseModel): + doc_ids: List[str] + kv_pair: Dict[str, Union[bool, int, float, str, list]] + + @app.post("/add_metadata") + def add_metadata(self, add_metadata_request: AddMetadataRequest): + doc_ids = add_metadata_request.doc_ids + kv_pair = add_metadata_request.kv_pair + try: + docs = self._manager.get_docs(doc_ids) + if not docs: + return BaseResponse(code=400, msg="Failed, no doc found") + doc_meta = {} + for doc in docs: + meta_dict = json.loads(doc.meta) if doc.meta else {} + for k, v in kv_pair.items(): + if k not in meta_dict or not meta_dict[k]: + meta_dict[k] = v + elif isinstance(meta_dict[k], list): + meta_dict[k].extend(v) if isinstance(v, list) else meta_dict[k].append(v) + else: + meta_dict[k] = ([meta_dict[k]] + v) if isinstance(v, list) else [meta_dict[k], v] + doc_meta[doc.doc_id] = meta_dict + self._manager.set_docs_new_meta(doc_meta) + return BaseResponse(data=None) + except Exception as e: + return BaseResponse(code=500, msg=str(e), data=None) + + class DeleteMetadataRequest(BaseModel): + doc_ids: List[str] + keys: Optional[List[str]] = Field(None) + kv_pair: Optional[Dict[str, Union[bool, int, float, str, list]]] = Field(None) + + def _inplace_del_meta(self, meta_dict, kv_pair: Dict[str, Union[None, bool, int, float, str, list]]): + # alert: meta_dict is not a deepcopy + for k, v in kv_pair.items(): + if k not in meta_dict: + continue + if v is None: + meta_dict.pop(k, None) + elif isinstance(meta_dict[k], list): + if isinstance(v, (bool, int, float, str)): + v = [v] + # delete v exists in meta_dict[k] + meta_dict[k] = list(set(meta_dict[k]) - set(v)) + else: + # old meta[k] not a list, use v as condition to delete the key + if meta_dict[k] == v: + meta_dict.pop(k, None) + + @app.post("/delete_metadata_item") + def delete_metadata_item(self, del_metadata_request: DeleteMetadataRequest): + doc_ids = del_metadata_request.doc_ids + kv_pair = del_metadata_request.kv_pair + keys = del_metadata_request.keys + try: + if keys is not None: + # convert keys to kv_pair + if kv_pair: + kv_pair.update({k: None for k in keys}) + else: + kv_pair = {k: None for k in keys} + if not kv_pair: + # clear metadata + self._manager.set_docs_new_meta({doc_id: {} for doc_id in doc_ids}) + else: + docs = self._manager.get_docs(doc_ids) + if not docs: + return BaseResponse(code=400, msg="Failed, no doc found") + doc_meta = {} + for doc in docs: + meta_dict = json.loads(doc.meta) if doc.meta else {} + self._inplace_del_meta(meta_dict, kv_pair) + doc_meta[doc.doc_id] = meta_dict + self._manager.set_docs_new_meta(doc_meta) + return BaseResponse(data=None) + except Exception as e: + return BaseResponse(code=500, msg=str(e), data=None) + + class UpdateMetadataRequest(BaseModel): + doc_ids: List[str] + kv_pair: Dict[str, Union[bool, int, float, str, list]] + + @app.post("/update_or_create_metadata_keys") + def update_or_create_metadata_keys(self, update_metadata_request: UpdateMetadataRequest): + doc_ids = update_metadata_request.doc_ids + kv_pair = update_metadata_request.kv_pair + try: + docs = self._manager.get_docs(doc_ids) + if not docs: + return BaseResponse(code=400, msg="Failed, no doc found") + for doc in docs: + doc_meta = {} + meta_dict = json.loads(doc.meta) if doc.meta else {} + for k, v in kv_pair.items(): + meta_dict[k] = v + doc_meta[doc.doc_id] = meta_dict + self._manager.set_docs_new_meta(doc_meta) + return BaseResponse(data=None) + except Exception as e: + return BaseResponse(code=500, msg=str(e), data=None) + + class ResetMetadataRequest(BaseModel): + doc_ids: List[str] + new_meta: Dict[str, Union[bool, int, float, str, list]] + + @app.post("/reset_metadata") + def reset_metadata(self, reset_metadata_request: ResetMetadataRequest): + doc_ids = reset_metadata_request.doc_ids + new_meta = reset_metadata_request.new_meta + try: + docs = self._manager.get_docs(doc_ids) + if not docs: + return BaseResponse(code=400, msg="Failed, no doc found") + self._manager.set_docs_new_meta({doc.doc_id: new_meta for doc in docs}) + return BaseResponse(data=None) + except Exception as e: + return BaseResponse(code=500, msg=str(e), data=None) + + class QueryMetadataRequest(BaseModel): + doc_id: str + key: Optional[str] = None + + @app.post("/query_metadata") + def query_metadata(self, query_metadata_request: QueryMetadataRequest): + doc_id = query_metadata_request.doc_id + key = query_metadata_request.key + try: + docs = self._manager.get_docs(doc_id) + if not docs: + return BaseResponse(data=None) + doc = docs[0] + meta_dict = json.loads(doc.meta) if doc.meta else {} + if not key: + return BaseResponse(data=meta_dict) + if key not in meta_dict: + return BaseResponse(code=400, msg=f"Failed, key {key} does not exist") + return BaseResponse(data=meta_dict[key]) + except Exception as e: + return BaseResponse(code=500, msg=str(e), data=None) + def __repr__(self): return lazyllm.make_repr("Module", "DocManager") diff --git a/lazyllm/tools/rag/global_metadata.py b/lazyllm/tools/rag/global_metadata.py index 653305b5..d7bd3ae2 100644 --- a/lazyllm/tools/rag/global_metadata.py +++ b/lazyllm/tools/rag/global_metadata.py @@ -19,3 +19,6 @@ def __init__(self, data_type: int, element_type: Optional[int] = None, RAG_DOC_CREATION_DATE = 'creation_date' RAG_DOC_LAST_MODIFIED_DATE = 'last_modified_date' RAG_DOC_LAST_ACCESSED_DATE = 'last_accessed_date' + +RAG_SYSTEM_META_KEYS = set([RAG_DOC_ID, RAG_DOC_PATH, RAG_DOC_FILE_NAME, RAG_DOC_FILE_TYPE, RAG_DOC_FILE_SIZE, + RAG_DOC_CREATION_DATE, RAG_DOC_LAST_MODIFIED_DATE, RAG_DOC_LAST_ACCESSED_DATE]) diff --git a/lazyllm/tools/rag/map_store.py b/lazyllm/tools/rag/map_store.py index 84d3e850..234fe4ff 100644 --- a/lazyllm/tools/rag/map_store.py +++ b/lazyllm/tools/rag/map_store.py @@ -5,6 +5,7 @@ from .utils import _FileNodeIndex from .default_index import DefaultIndex from lazyllm.common import override +from .global_metadata import RAG_SYSTEM_META_KEYS def _update_indices(name2index: Dict[str, IndexBase], nodes: List[DocNode]) -> None: for index in name2index.values(): @@ -33,6 +34,20 @@ def update_nodes(self, nodes: List[DocNode]) -> None: self._group2docs[node._group][node._uid] = node _update_indices(self._name2index, nodes) + @override + def update_doc_meta(self, filepath: str, metadata: dict) -> None: + doc_nodes: List[DocNode] = self._name2index['file_node_map'].query([filepath]) + if not doc_nodes: + return + root_node = doc_nodes[0].root_node + keys_to_delete = [] + for k in root_node.global_metadata: + if not (k in RAG_SYSTEM_META_KEYS or k in metadata): + keys_to_delete.append(k) + for k in keys_to_delete: + root_node.global_metadata.pop(k) + root_node.global_metadata.update(metadata) + @override def remove_nodes(self, group_name: str, uids: List[str] = None) -> None: if uids: diff --git a/lazyllm/tools/rag/milvus_store.py b/lazyllm/tools/rag/milvus_store.py index 618d2da6..e2691d41 100644 --- a/lazyllm/tools/rag/milvus_store.py +++ b/lazyllm/tools/rag/milvus_store.py @@ -165,6 +165,10 @@ def update_nodes(self, nodes: List[DocNode]) -> None: self._map_store.update_nodes(nodes) + @override + def update_doc_meta(self, filepath: str, metadata: dict) -> None: + self._map_store.update_doc_meta(filepath, metadata) + @override def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: if uids: diff --git a/lazyllm/tools/rag/store_base.py b/lazyllm/tools/rag/store_base.py index 08992ad9..ba59d747 100644 --- a/lazyllm/tools/rag/store_base.py +++ b/lazyllm/tools/rag/store_base.py @@ -15,6 +15,10 @@ class StoreBase(ABC): def update_nodes(self, nodes: List[DocNode]) -> None: pass + @abstractmethod + def update_doc_meta(self, filepath: str, metadata: dict) -> None: + pass + @abstractmethod def remove_nodes(self, group_name: str, uids: Optional[List[str]] = None) -> None: pass diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index c9ba1825..072a8680 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -11,7 +11,7 @@ from lazyllm.common.queue import sqlite3_check_threadsafety import sqlalchemy from sqlalchemy.orm import DeclarativeBase, sessionmaker -from sqlalchemy import Column, select, insert, update, Row +from sqlalchemy import Column, select, insert, update, Row, bindparam from sqlalchemy.exc import NoResultFound import pydantic @@ -69,7 +69,9 @@ class KBGroup(KBDataBase): group_id = Column(sqlalchemy.Integer, primary_key=True, autoincrement=True) group_name = Column(sqlalchemy.String, nullable=False, unique=True) +DocMetaChangedRow = Row GroupDocPartRow = Row + class KBGroupDocuments(KBDataBase): __tablename__ = "kb_group_documents" @@ -79,6 +81,7 @@ class KBGroupDocuments(KBDataBase): status = Column(sqlalchemy.Text, nullable=True) log = Column(sqlalchemy.Text, nullable=True) need_reparse = Column(sqlalchemy.Boolean, default=False, nullable=False) + new_meta = Column(sqlalchemy.Text, nullable=True) # unique constraint __table_args__ = (sqlalchemy.UniqueConstraint('doc_id', 'group_name', name='uq_doc_to_group'),) @@ -148,6 +151,15 @@ def list_files(self, limit: Optional[int] = None, details: bool = False, status: Union[str, List[str]] = Status.all, exclude_status: Optional[Union[str, List[str]]] = None): pass + @abstractmethod + def get_docs(self, doc_ids: List[str]) -> List[KBDocument]: pass + + @abstractmethod + def set_docs_new_meta(self, doc_meta: Dict[str, dict]): pass + + @abstractmethod + def fetch_docs_changed_meta(self, group: str) -> List[DocMetaChangedRow]: pass + @abstractmethod def list_all_kb_group(self): pass @@ -312,6 +324,41 @@ def list_files(self, limit: Optional[int] = None, details: bool = False, cursor = conn.execute(query, params) return cursor.fetchall() if details else [row[0] for row in cursor] + def get_docs(self, doc_ids: List[str]) -> List[KBDocument]: + with self._db_lock, self._Session() as session: + docs = session.query(KBDocument).filter(KBDocument.doc_id.in_(doc_ids)).all() + return docs + return [] + + def set_docs_new_meta(self, doc_meta: Dict[str, dict]): + data_to_update = [{"_doc_id": k, "_meta": json.dumps(v)} for k, v in doc_meta.items()] + with self._db_lock, self._Session() as session: + # Use sqlalchemy core bulk update + stmt = KBDocument.__table__.update().where( + KBDocument.doc_id == bindparam("_doc_id")).values(meta=bindparam("_meta")) + session.execute(stmt, data_to_update) + session.commit() + + stmt = KBGroupDocuments.__table__.update().where( + KBGroupDocuments.doc_id == bindparam("_doc_id"), + KBGroupDocuments.status != DocListManager.Status.waiting).values(new_meta=bindparam("_meta")) + session.execute(stmt, data_to_update) + session.commit() + + def fetch_docs_changed_meta(self, group: str) -> List[DocMetaChangedRow]: + rows = [] + conds = [KBGroupDocuments.group_name == group, KBGroupDocuments.new_meta.isnot(None)] + with self._db_lock, self._Session() as session: + rows = ( + session.query(KBDocument.path, KBGroupDocuments.new_meta) + .join(KBGroupDocuments, KBDocument.doc_id == KBGroupDocuments.doc_id) + .filter(*conds).all() + ) + stmt = update(KBGroupDocuments).where(sqlalchemy.and_(*conds)).values(new_meta=None) + session.execute(stmt) + session.commit() + return rows + def list_all_kb_group(self): with self._db_lock, sqlite3.connect(self._db_path, check_same_thread=self._check_same_thread) as conn: cursor = conn.execute("SELECT group_name FROM document_groups") diff --git a/tests/basic_tests/test_document.py b/tests/basic_tests/test_document.py index 66f142eb..deb706ec 100644 --- a/tests/basic_tests/test_document.py +++ b/tests/basic_tests/test_document.py @@ -228,6 +228,45 @@ def test_delete_files_in_store(self): nodes = self.doc_impl.store.get_nodes(LAZY_ROOT_NAME) assert len(nodes) == 1 assert nodes[0].global_metadata[RAG_DOC_ID] == test2_docid + cur_meta_dict = nodes[0].global_metadata + + url = f'{self.doc_server_addr}/add_metadata' + response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={"title": "title2"})) + assert response.status_code == 200 and response.json().get('code') == 200 + time.sleep(20) + assert cur_meta_dict["title"] == "title2" + + response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={"title": "TITLE2"})) + assert response.status_code == 200 and response.json().get('code') == 200 + time.sleep(20) + assert cur_meta_dict["title"] == ["title2", "TITLE2"] + + url = f'{self.doc_server_addr}/delete_metadata_item' + response = httpx.post(url, json=dict(doc_ids=[test2_docid], keys=["signature"])) + assert response.status_code == 200 and response.json().get('code') == 200 + time.sleep(20) + assert "signature" not in cur_meta_dict + + response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={"title": "TITLE2"})) + assert response.status_code == 200 and response.json().get('code') == 200 + time.sleep(20) + assert cur_meta_dict["title"] == ["title2"] + + url = f'{self.doc_server_addr}/update_or_create_metadata_keys' + response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={"signature": "signature2"})) + assert response.status_code == 200 and response.json().get('code') == 200 + time.sleep(20) + assert cur_meta_dict["signature"] == "signature2" + + url = f'{self.doc_server_addr}/reset_metadata' + response = httpx.post(url, json=dict(doc_ids=[test2_docid], + new_meta={"author": "author2", "signature": "signature_new"})) + assert response.status_code == 200 and response.json().get('code') == 200 + time.sleep(20) + assert cur_meta_dict["signature"] == "signature_new" and cur_meta_dict["author"] == "author2" + + url = f'{self.doc_server_addr}/query_metadata' + response = httpx.post(url, json=dict(doc_id=test2_docid)) # make sure that only one file is left response = httpx.get(f'{self.doc_server_addr}/list_files')