Skip to content

Commit

Permalink
add test for doc_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
wzh1994 committed Oct 14, 2024
1 parent f6f75fc commit 7484865
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/doc_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _lazy_init(self) -> None:

if self._dlm:
self._daemon = threading.Thread(target=self.worker)
self._daemon.setDaemon(True)
self._daemon.daemon = True
self._daemon.start()

def _get_store(self) -> BaseStore:
Expand Down
26 changes: 19 additions & 7 deletions lazyllm/tools/rag/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import shutil
import hashlib
from typing import List, Callable, Generator, Dict, Any, Optional
from typing import List, Callable, Generator, Dict, Any, Optional, Union
from abc import ABC, abstractmethod

import pydantic
import sqlite3
from pydantic import BaseModel
from fastapi import UploadFile
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

import lazyllm
from lazyllm import config
Expand Down Expand Up @@ -74,7 +75,7 @@ def get_file_status(self, fileid: str): pass
@abstractmethod
def update_file_status(self, fileid: str, status: str): pass
@abstractmethod
def update_kb_group_file_status(self, group: str, file_ids: str, status: str): pass
def update_kb_group_file_status(self, group: str, file_ids: Union[str, List[str]], status: str): pass
@abstractmethod
def close(self): pass

Expand All @@ -85,7 +86,12 @@ def __init__(self, path, name):
root_dir = os.path.expanduser(os.path.join(config['home'], '.dbs'))
os.system(f'mkdir -p {root_dir}')
self._db_path = os.path.join(root_dir, f'.lazyllm_dlmanager.{self._id}.db')
self._conn = sqlite3.connect(self._db_path)
self._conns = threading.local()

@property
def _conn(self):
if not hasattr(self._conns, 'impl'): self._conns.impl = sqlite3.connect(self._db_path)
return self._conns.impl

def _init_tables(self):
with self._conn:
Expand Down Expand Up @@ -216,10 +222,16 @@ def update_file_status(self, fileid: str, status: str):
with self._conn:
self._conn.execute("UPDATE documents SET status = ? WHERE doc_id = ?", (status, fileid))

def update_kb_group_file_status(self, group: str, file_ids: str, status: str):
with self._conn:
self._conn.execute("UPDATE kb_group_documents SET status = ? WHERE group_name = ? AND doc_id = ?",
(status, group, file_ids))
def update_kb_group_file_status(self, group: str, file_ids: Union[str, List[str]], status: str):
if isinstance(file_ids, str): file_ids = [file_ids]
query = ('UPDATE kb_group_documents SET status = ? WHERE group_name = ? AND doc_id IN '
f'({",".join("?" * len(file_ids))})')
try:
with self._conn:
self._conn.execute(query, [status, group] + file_ids)
except sqlite3.InterfaceError as e:
raise RuntimeError(f'{e}\n args are:\n {status}({type(status)}), {group}({type(group)}),'
f'{file_ids}({type(file_ids)})')

def close(self):
self._conn.close()
Expand Down

0 comments on commit 7484865

Please sign in to comment.