From fec142d6e7e23d239686bdc726aa8278a6bdf875 Mon Sep 17 00:00:00 2001 From: Nikolas Ramstedt Date: Tue, 3 Dec 2024 16:28:19 +0100 Subject: [PATCH] Feat/rag references (#192) * fix(FileUploadService): simplify get_file_infos............ * chore(rag_scoring): format * fix(FileUploadService): simplify * feat(VectorService): add optional document metadata arg * feat(upload_and_vectorize_handler): extract metadata per chunk and save it to vector * feat(RagScoringPipeline): propagate and expose document metadata to LLM context (references etc) * fix(upload_and_vectorize_handler): clean up * chore(file_parser.py): refactor * feat(file_parser.py): allow parsing remote files (by url) * feat(CollectionMetadataModel): add urls * feat(upload_and_vectorize_handler): aggregate urls & files * feat(upload_view): add textarea for urls * fix(upload_and_vectorize_handler): include deps * fix(upload_and_vectorize_handler): make files optional * add(HTMLParser): extract and include page title as metadata * feat(upload_view): add no-wrap css and increase textarea rows * add(file_parser): error handler & null parser * fix(upload_view): move up urls * chore(upload_and_vectorize_handler): simplify * fix(HTMLParser): rm self from args * fix(upload_and_vectorize_handler): only create collection when we have chunks * fix(after rebase): save meta data * feat(SvelteMarkdownWrapper): add custom link renderer (in order to target blank) * output context as json * fix(upload_and_vectorize_handler): allow urls or files (or a combination of both) --- .../assistant/pipeline/rag_scoring.py | 41 +++--- .../fai_backend/collection/models.py | 1 + .../fai_backend/collection/service.py | 2 + .../fai_backend/documents/routes.py | 90 ++++++++---- .../fai_backend/files/file_parser.py | 136 ++++++++++++++---- .../fai-backend/fai_backend/files/service.py | 12 +- .../fai-backend/fai_backend/vector/service.py | 14 +- .../components/SvelteMarkdownWrapper.svelte | 13 +- .../src/lib/components/markdown/Link.svelte | 17 +++ 9 files changed, 239 insertions(+), 87 deletions(-) create mode 100644 fai-rag-app/fai-frontend/src/lib/components/markdown/Link.svelte diff --git a/fai-rag-app/fai-backend/fai_backend/assistant/pipeline/rag_scoring.py b/fai-rag-app/fai-backend/fai_backend/assistant/pipeline/rag_scoring.py index cf9bc729..2b8c0c93 100644 --- a/fai-rag-app/fai-backend/fai_backend/assistant/pipeline/rag_scoring.py +++ b/fai-rag-app/fai-backend/fai_backend/assistant/pipeline/rag_scoring.py @@ -1,11 +1,11 @@ import json -from typing import Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any -from langstream import Stream +from langstream import Stream, as_async_generator -from fai_backend.assistant.protocol import IAssistantPipelineStrategy, IAssistantContextStore +from fai_backend.assistant.protocol import IAssistantContextStore, IAssistantPipelineStrategy from fai_backend.collection.dependencies import get_collection_service -from fai_backend.llm.service import query_vector from fai_backend.projects.dependencies import get_project_service from fai_backend.vector.factory import vector_db from fai_backend.vector.service import VectorService @@ -19,25 +19,22 @@ async def create_pipeline( async def run_rag_stream(query: list[str]): collection_id = context_store.get_mutable().files_collection_id vector_service = VectorService(vector_db=vector_db, collection_meta_service=get_collection_service()) - vector_db_query_result = await query_vector( - vector_service=vector_service, + + result = await vector_service.query_from_collection( collection_name=collection_id, - query=query[0], + query_texts=[query[0]], + n_results=10, ) - documents: [str] = [] - - def store_and_return_document(document: str): - documents.append(document) - return document + documents, documents_metadata = result['documents'][0], result['metadatas'][0] def append_score_to_documents(scores): - z = zip(documents, [s[0] for s in scores]) + z = zip(documents, documents_metadata, [s[0] for s in scores]) return z def sort_and_slice_documents(scored_documents, slice_size: int): first_element = list(scored_documents)[0] - sorted_scores = sorted(first_element, key=lambda x: x[1], reverse=True) + sorted_scores = sorted(first_element, key=lambda x: x[2], reverse=True) return sorted_scores[:slice_size] projects = await get_project_service().read_projects() @@ -51,7 +48,7 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]: stream = await assistant.create_stream() scoring_context_store.get_mutable().rag_document = document - full = "" + full = '' async for o in stream(query[0]): if o.final: full += o.data @@ -60,13 +57,15 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]: yield score full_stream = ( - vector_db_query_result - .map(store_and_return_document) + Stream[None, str]( + 'QueryVectorStream', + lambda _: as_async_generator(*documents) + ) .map(scoring_stream) .gather() .and_then(append_score_to_documents) .and_then(lambda scored_documents: sort_and_slice_documents(scored_documents, 6)) - .and_then(lambda results: {"query": query, "results": results[0]}) + .and_then(lambda results: {'query': query, 'results': results[0]}) ) async for r in full_stream(query[0]): @@ -74,7 +73,11 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]: def rag_postprocess(in_data: Any): results: list[str] = in_data[0]['results'] - concatenated = "\n".join([s for (s, _) in results]) + concatenated = '\n\n'.join([json.dumps({ + **{'text': s}, + **m + }) for (s, m, _) in results]) + context_store.get_mutable().rag_output = concatenated return concatenated diff --git a/fai-rag-app/fai-backend/fai_backend/collection/models.py b/fai-rag-app/fai-backend/fai_backend/collection/models.py index 488c8c83..3c86cfed 100644 --- a/fai-rag-app/fai-backend/fai_backend/collection/models.py +++ b/fai-rag-app/fai-backend/fai_backend/collection/models.py @@ -6,6 +6,7 @@ class CollectionMetadataModel(Document): label: str = '' description: str = '' embedding_model: str | None = '' + urls: list[str] | None = None class Settings: name = 'collections' diff --git a/fai-rag-app/fai-backend/fai_backend/collection/service.py b/fai-rag-app/fai-backend/fai_backend/collection/service.py index 9f728647..b4c868a0 100644 --- a/fai-rag-app/fai-backend/fai_backend/collection/service.py +++ b/fai-rag-app/fai-backend/fai_backend/collection/service.py @@ -15,6 +15,7 @@ async def create_collection_metadata( label: str, description: str = '', embedding_model: str | None = None, + urls: list[str] | None = None ): collection_metadata = CollectionMetadataModel( @@ -22,6 +23,7 @@ async def create_collection_metadata( label=label, description=description, embedding_model=embedding_model, + urls=urls ) return await self.repo.create(collection_metadata) diff --git a/fai-rag-app/fai-backend/fai_backend/documents/routes.py b/fai-rag-app/fai-backend/fai_backend/documents/routes.py index e4c249b9..d9d8049b 100644 --- a/fai-rag-app/fai-backend/fai_backend/documents/routes.py +++ b/fai-rag-app/fai-backend/fai_backend/documents/routes.py @@ -1,29 +1,27 @@ from tempfile import NamedTemporaryFile +from typing import Annotated -from fastapi import APIRouter, Depends, Form, Security, UploadFile +from fastapi import APIRouter, Depends, File, Form, Security, UploadFile from fai_backend.collection.dependencies import get_collection_service from fai_backend.collection.service import CollectionService from fai_backend.config import settings from fai_backend.dependencies import get_authenticated_user, get_page_template_for_logged_in_users, get_project_user from fai_backend.files.dependecies import get_file_upload_service -from fai_backend.files.file_parser import ParserFactory +from fai_backend.files.file_parser import ParserFactory, is_url from fai_backend.files.service import FileUploadService from fai_backend.framework import components as c from fai_backend.framework import events as e -from fai_backend.logger.route_class import APIRouter as LoggingAPIRouter from fai_backend.phrase import phrase as _ -from fai_backend.projects.dependencies import get_project_service, list_projects_request -from fai_backend.projects.schema import ProjectResponse -from fai_backend.projects.service import ProjectService from fai_backend.schema import ProjectUser from fai_backend.vector.dependencies import get_vector_service from fai_backend.vector.service import VectorService +router_base = APIRouter() + router = APIRouter( prefix='/api', tags=['Documents'], - route_class=LoggingAPIRouter, ) @@ -88,17 +86,23 @@ def upload_view(view=Depends(get_page_template_for_logged_in_users)) -> list: components=[ c.InputField( name='collection_label', - title=_('input_fileupload_collection_label', - 'Collection label (optional)'), - placeholder=_('input_fileupload_collection_placeholder', - 'Collection label (optional)'), + label=_('Collection label (optional)'), + placeholder=_('Collection label (optional)'), required=False, html_type='text', ), + c.Textarea( + name='urls', + placeholder=_('urls', 'URLs'), + label=_('urls', 'URLs'), + required=False, + class_name='whitespace-nowrap', + rows=6 + ), c.FileInput( name='files', - title=_('file', 'File'), - required=True, + label=_('file', 'File'), + required=False, multiple=True, file_size_limit=settings.FILE_SIZE_LIMIT, ), @@ -108,42 +112,70 @@ def upload_view(view=Depends(get_page_template_for_logged_in_users)) -> list: class_name='btn btn-primary', ), ], - class_name='card bg-base-100 w-full max-w-6xl', + class_name='card-body', ), - ]), + ], class_name='card bg-base-100 w-full max-w-xl'), ])], _('upload_documents', 'Upload documents')) @router.post('/documents/upload_and_vectorize', response_model=list, response_model_exclude_none=True) async def upload_and_vectorize_handler( - files: list[UploadFile], + files: Annotated[ + list[UploadFile], File(description='Multiple files as UploadFile') + ], collection_label: str = Form(None), + urls: str = Form(None), project_user: ProjectUser = Depends(get_project_user), file_service: FileUploadService = Depends(get_file_upload_service), vector_service: VectorService = Depends(get_vector_service), view=Depends(get_page_template_for_logged_in_users), - projects: list[ProjectResponse] = Depends(list_projects_request), - project_service: ProjectService = Depends(get_project_service), collection_service: CollectionService = Depends(get_collection_service), ) -> list: - upload_path = file_service.save_files(project_user.project_id, files) - - upload_directory_name = upload_path.split('/')[-1] - await vector_service.create_collection(collection_name=upload_directory_name, - embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL) + list_of_files = [file for file in files if len(file.filename) > 0] + list_of_urls = [url for url in (urls or '').splitlines() if is_url(url)] + upload_path = file_service.save_files(project_user.project_id, list_of_files) + collection_name = upload_path.split('/')[-1] + + chunks = [ + { + 'document': element.text, + 'document_meta': { + key: value + for key, value in {**element.metadata.to_dict(), }.items() + if key in ['filename', 'url', 'page_number', 'page_name'] + } + } + for file_or_url in [ + *[file.path for file in file_service.get_file_infos(upload_path)], + *[url for url in list_of_urls] + ] + for element in ParserFactory.get_parser(file_or_url).parse(file_or_url) if element + ] + + if len(chunks) == 0: + return view( + c.FireEvent(event=e.GoToEvent(url='/documents/upload')), + _('submit_a_question', 'Create Question'), + ) + + await vector_service.create_collection( + collection_name=collection_name, + embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL + ) - parsed_files = file_service.parse_files(upload_path) await vector_service.add_documents_without_id_to_empty_collection( - collection_name=upload_directory_name, - documents=parsed_files, - embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL + collection_name=collection_name, + documents=[chunk['document'] for chunk in chunks], + embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL, + documents_metadata=[chunk['document_meta'] for chunk in chunks], ) await collection_service.create_collection_metadata( - collection_id=upload_directory_name or '', + collection_id=collection_name or '', label=collection_label or '', description='', embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL, + urls=[url for url in list_of_urls] ) return view( @@ -161,5 +193,5 @@ def parse_document( temp.flush() parser = ParserFactory.get_parser(temp.name) parsed = parser.parse(temp.name) - joined = "\n\n".join([p.text for p in parsed]) + joined = '\n\n'.join([p.text for p in parsed]) return joined diff --git a/fai-rag-app/fai-backend/fai_backend/files/file_parser.py b/fai-rag-app/fai-backend/fai_backend/files/file_parser.py index 93d3dacc..b544c7bd 100644 --- a/fai-rag-app/fai-backend/fai_backend/files/file_parser.py +++ b/fai-rag-app/fai-backend/fai_backend/files/file_parser.py @@ -1,12 +1,47 @@ -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import wraps +from typing import Any +from urllib.parse import urlparse import magic +import requests from unstructured.documents.elements import Element from unstructured.partition.docx import partition_docx +from unstructured.partition.html import partition_html from unstructured.partition.md import partition_md from unstructured.partition.pdf import partition_pdf from unstructured.partition.xlsx import partition_xlsx -from unstructured.partition.html import partition_html + +from fai_backend.config import settings +from fai_backend.logger.console import console + + +def error_handler(default_return: Any = None): + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + console.log(f'Error in {func.__name__}:\n{e}') + return default_return + + return wrapper + + return decorator + + +def is_url(string: str) -> bool: + parsed = urlparse(string) + return parsed.scheme in {'http', 'https'} + + +@error_handler(default_return='null') +def get_mime_type(file_path: str) -> str: + if is_url(file_path): + return magic.from_buffer(requests.get(file_path, verify=settings.ENV_MODE != 'development').content, mime=True) + return magic.from_file(file_path, mime=True) class AbstractDocumentParser(ABC): @@ -16,44 +51,95 @@ def parse(self, filename: str) -> list[Element]: class DocxParser(AbstractDocumentParser): - def parse(self, filename: str): - return partition_docx(filename, chunking_strategy="basic") + @error_handler(default_return=[]) + def parse(self, filename: str) -> list[Element]: + if is_url(filename): + return partition_docx( + url=filename, + ssl_verify=settings.ENV_MODE != 'development', + chunking_strategy='basic' + ) + return partition_docx(filename, chunking_strategy='basic') class PDFParser(AbstractDocumentParser): - def parse(self, filename: str): - return partition_pdf(filename, chunking_strategy="basic") + @error_handler(default_return=[]) + def parse(self, filename: str) -> list[Element]: + if is_url(filename): + return partition_pdf( + url=filename, + ssl_verify=settings.ENV_MODE != 'development', + chunking_strategy='basic' + ) + return partition_pdf(filename, chunking_strategy='basic') class MarkdownParser(AbstractDocumentParser): - def parse(self, filename: str): - return partition_md(filename) + @error_handler(default_return=[]) + def parse(self, filename: str) -> list[Element]: + if is_url(filename): + return partition_md( + url=filename, + ssl_verify=settings.ENV_MODE != 'development', + chunking_strategy='basic' + ) + return partition_md(filename, chunking_strategy='basic') class ExcelParser(AbstractDocumentParser): - def parse(self, filename: str): + @error_handler(default_return=[]) + def parse(self, filename: str) -> list[Element]: + if is_url(filename): + return partition_xlsx( + url=filename, + ssl_verify=settings.ENV_MODE != 'development' + ) return partition_xlsx(filename) class HTMLParser(AbstractDocumentParser): - def parse(self, filename: str): - return partition_html(filename) + @staticmethod + def _parse_html(filename: str) -> list[Element]: + if is_url(filename): + return partition_html( + url=filename, + ssl_verify=settings.ENV_MODE != 'development', + chunking_strategy='basic' + ) + return partition_html(filename, chunking_strategy='basic') + + @error_handler(default_return=[]) + def parse(self, filename: str) -> list[Element]: + chunks = self._parse_html(filename) + title = chunks[0].metadata.orig_elements[0].text if len(chunks) > 0 else None + for chunk in chunks: + chunk.metadata.page_name = title + return chunks + + +class NullParser(AbstractDocumentParser): + def parse(self, filename: str) -> list[Element]: + return [] class ParserFactory: + MIME_TYPE_MAPPING: dict[str, type[AbstractDocumentParser]] = { + 'application/pdf': PDFParser, + 'text/plain': MarkdownParser, + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': DocxParser, + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': ExcelParser, + 'text/html': HTMLParser, + 'null': NullParser + } + @staticmethod def get_parser(file_path: str) -> AbstractDocumentParser: - mime_type = magic.from_file(file_path, mime=True) - - if mime_type == 'application/pdf': - return PDFParser() - if mime_type == 'text/plain': - return MarkdownParser() - if mime_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': - return DocxParser() - if mime_type == 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': - return ExcelParser() - if mime_type == 'text/html': - return HTMLParser() - - raise ValueError(f'Unsupported file type: {mime_type}') + console.log(f'Getting parser for file: {file_path}') + mime_type = get_mime_type(file_path) + parser_cls = ParserFactory.MIME_TYPE_MAPPING.get(mime_type) + + if parser_cls is None: + console.log(f'Unsupported file type: {mime_type}') + return NullParser() + + return parser_cls() diff --git a/fai-rag-app/fai-backend/fai_backend/files/service.py b/fai-rag-app/fai-backend/fai_backend/files/service.py index a57c1281..a65486f6 100644 --- a/fai-rag-app/fai-backend/fai_backend/files/service.py +++ b/fai-rag-app/fai-backend/fai_backend/files/service.py @@ -29,13 +29,16 @@ def _generate_upload_path(self, project_id: str) -> str: def save_files(self, project_id: str, files: list[UploadFile]) -> str: upload_path = self._generate_upload_path(project_id) + if not files: + return upload_path + for file in files: file_location = os.path.join(upload_path, file.filename) with open(file_location, 'wb+') as file_object: file_object.write(file.file.read()) return upload_path - def get_file_infos(self, directory_path, upload_date: datetime) -> list[FileInfo]: + def get_file_infos(self, directory_path) -> list[FileInfo]: file_infos = [] for file_name in os.listdir(directory_path): file_path = os.path.join(directory_path, file_name) @@ -49,7 +52,7 @@ def get_file_infos(self, directory_path, upload_date: datetime) -> list[FileInfo collection=file_path.split('/')[-2], # TODO: niceify mime_type=mime_type or 'application/octet-stream', last_modified=datetime.fromtimestamp(stat.st_mtime), - upload_date=upload_date, + upload_date=datetime.fromtimestamp(os.path.getctime(directory_path)), created_date=datetime.fromtimestamp(stat.st_ctime) )) @@ -64,7 +67,7 @@ def list_files(self, project_id: str) -> list[FileInfo]: full_paths = [os.path.join(self.upload_dir, path) for path in project_directories] all_files = [file for path in full_paths for file in - self.get_file_infos(path, datetime.fromtimestamp(os.path.getctime(path)))] + self.get_file_infos(path)] return sorted(all_files, key=lambda x: x.upload_date, reverse=True) @@ -81,8 +84,7 @@ def get_latest_upload_path(self, project_id: str) -> str | None: def parse_files(self, src_directory_path: str) -> list[str]: parsed_files = [] - upload_date = datetime.fromtimestamp(os.path.getctime(src_directory_path)) - files = self.get_file_infos(src_directory_path, upload_date) + files = self.get_file_infos(src_directory_path) for file in files: parser = ParserFactory.get_parser(file.path) diff --git a/fai-rag-app/fai-backend/fai_backend/vector/service.py b/fai-rag-app/fai-backend/fai_backend/vector/service.py index 044da0bd..355feacd 100644 --- a/fai-rag-app/fai-backend/fai_backend/vector/service.py +++ b/fai-rag-app/fai-backend/fai_backend/vector/service.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import Optional from fai_backend.collection.service import CollectionService @@ -26,12 +27,15 @@ async def add_to_collection( collection_name: str, ids: OneOrMany[str], documents: Optional[OneOrMany[str]], - embedding_model: str | None = None + embedding_model: str | None = None, + documents_metadata: Mapping[str, str | int | float | bool] | list[ + Mapping[str, str | int | float | bool]] | None = None, ) -> None: await self.vector_db.add( collection_name=collection_name, ids=ids, documents=documents, + metadatas=documents_metadata, embedding_function=await EmbeddingFnFactory.create(embedding_model) ) @@ -39,7 +43,9 @@ async def add_documents_without_id_to_empty_collection( self, collection_name: str, documents: list[str], - embedding_model: str | None = None + embedding_model: str | None = None, + documents_metadata: Mapping[str, str | int | float | bool] | list[ + Mapping[str, str | int | float | bool]] | None = None, ) -> None: """ Add documents to a collection without specifying ID's @@ -47,11 +53,13 @@ async def add_documents_without_id_to_empty_collection( The collection should be empty before calling this method to avoid ID conflicts. """ ids = [str(i) for i in range(len(documents))] + await self.add_to_collection( collection_name=collection_name, ids=ids, documents=documents, - embedding_model=embedding_model + embedding_model=embedding_model, + documents_metadata=documents_metadata, ) async def query_from_collection( diff --git a/fai-rag-app/fai-frontend/src/lib/components/SvelteMarkdownWrapper.svelte b/fai-rag-app/fai-frontend/src/lib/components/SvelteMarkdownWrapper.svelte index 3503606c..061853ab 100644 --- a/fai-rag-app/fai-frontend/src/lib/components/SvelteMarkdownWrapper.svelte +++ b/fai-rag-app/fai-frontend/src/lib/components/SvelteMarkdownWrapper.svelte @@ -1,17 +1,18 @@ - + diff --git a/fai-rag-app/fai-frontend/src/lib/components/markdown/Link.svelte b/fai-rag-app/fai-frontend/src/lib/components/markdown/Link.svelte new file mode 100644 index 00000000..4199cc53 --- /dev/null +++ b/fai-rag-app/fai-frontend/src/lib/components/markdown/Link.svelte @@ -0,0 +1,17 @@ + + +