From 7484865ef905d00121aff6a69bc3618a60aec5b7 Mon Sep 17 00:00:00 2001 From: wangzhihong Date: Mon, 14 Oct 2024 15:15:14 +0800 Subject: [PATCH] add test for doc_manager --- lazyllm/tools/rag/doc_impl.py | 2 +- lazyllm/tools/rag/utils.py | 26 +++++++++++++++++++------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index f8c35317..cfaf1e1b 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -62,7 +62,7 @@ def _lazy_init(self) -> None: if self._dlm: self._daemon = threading.Thread(target=self.worker) - self._daemon.setDaemon(True) + self._daemon.daemon = True self._daemon.start() def _get_store(self) -> BaseStore: diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index 4950ac2d..80591562 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -1,7 +1,7 @@ import os import shutil import hashlib -from typing import List, Callable, Generator, Dict, Any, Optional +from typing import List, Callable, Generator, Dict, Any, Optional, Union from abc import ABC, abstractmethod import pydantic @@ -9,6 +9,7 @@ from pydantic import BaseModel from fastapi import UploadFile from concurrent.futures import ThreadPoolExecutor, as_completed +import threading import lazyllm from lazyllm import config @@ -74,7 +75,7 @@ def get_file_status(self, fileid: str): pass @abstractmethod def update_file_status(self, fileid: str, status: str): pass @abstractmethod - def update_kb_group_file_status(self, group: str, file_ids: str, status: str): pass + def update_kb_group_file_status(self, group: str, file_ids: Union[str, List[str]], status: str): pass @abstractmethod def close(self): pass @@ -85,7 +86,12 @@ def __init__(self, path, name): root_dir = os.path.expanduser(os.path.join(config['home'], '.dbs')) os.system(f'mkdir -p {root_dir}') self._db_path = os.path.join(root_dir, f'.lazyllm_dlmanager.{self._id}.db') - self._conn = sqlite3.connect(self._db_path) + self._conns = threading.local() + + @property + def _conn(self): + if not hasattr(self._conns, 'impl'): self._conns.impl = sqlite3.connect(self._db_path) + return self._conns.impl def _init_tables(self): with self._conn: @@ -216,10 +222,16 @@ def update_file_status(self, fileid: str, status: str): with self._conn: self._conn.execute("UPDATE documents SET status = ? WHERE doc_id = ?", (status, fileid)) - def update_kb_group_file_status(self, group: str, file_ids: str, status: str): - with self._conn: - self._conn.execute("UPDATE kb_group_documents SET status = ? WHERE group_name = ? AND doc_id = ?", - (status, group, file_ids)) + def update_kb_group_file_status(self, group: str, file_ids: Union[str, List[str]], status: str): + if isinstance(file_ids, str): file_ids = [file_ids] + query = ('UPDATE kb_group_documents SET status = ? WHERE group_name = ? AND doc_id IN ' + f'({",".join("?" * len(file_ids))})') + try: + with self._conn: + self._conn.execute(query, [status, group] + file_ids) + except sqlite3.InterfaceError as e: + raise RuntimeError(f'{e}\n args are:\n {status}({type(status)}), {group}({type(group)}),' + f'{file_ids}({type(file_ids)})') def close(self): self._conn.close()