diff --git a/fai-rag-app/fai-backend/fai_backend/assistant/pipeline/rag_scoring.py b/fai-rag-app/fai-backend/fai_backend/assistant/pipeline/rag_scoring.py
index cf9bc729..2b8c0c93 100644
--- a/fai-rag-app/fai-backend/fai_backend/assistant/pipeline/rag_scoring.py
+++ b/fai-rag-app/fai-backend/fai_backend/assistant/pipeline/rag_scoring.py
@@ -1,11 +1,11 @@
import json
-from typing import Any, AsyncGenerator
+from collections.abc import AsyncGenerator
+from typing import Any
-from langstream import Stream
+from langstream import Stream, as_async_generator
-from fai_backend.assistant.protocol import IAssistantPipelineStrategy, IAssistantContextStore
+from fai_backend.assistant.protocol import IAssistantContextStore, IAssistantPipelineStrategy
from fai_backend.collection.dependencies import get_collection_service
-from fai_backend.llm.service import query_vector
from fai_backend.projects.dependencies import get_project_service
from fai_backend.vector.factory import vector_db
from fai_backend.vector.service import VectorService
@@ -19,25 +19,22 @@ async def create_pipeline(
async def run_rag_stream(query: list[str]):
collection_id = context_store.get_mutable().files_collection_id
vector_service = VectorService(vector_db=vector_db, collection_meta_service=get_collection_service())
- vector_db_query_result = await query_vector(
- vector_service=vector_service,
+
+ result = await vector_service.query_from_collection(
collection_name=collection_id,
- query=query[0],
+ query_texts=[query[0]],
+ n_results=10,
)
- documents: [str] = []
-
- def store_and_return_document(document: str):
- documents.append(document)
- return document
+ documents, documents_metadata = result['documents'][0], result['metadatas'][0]
def append_score_to_documents(scores):
- z = zip(documents, [s[0] for s in scores])
+ z = zip(documents, documents_metadata, [s[0] for s in scores])
return z
def sort_and_slice_documents(scored_documents, slice_size: int):
first_element = list(scored_documents)[0]
- sorted_scores = sorted(first_element, key=lambda x: x[1], reverse=True)
+ sorted_scores = sorted(first_element, key=lambda x: x[2], reverse=True)
return sorted_scores[:slice_size]
projects = await get_project_service().read_projects()
@@ -51,7 +48,7 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]:
stream = await assistant.create_stream()
scoring_context_store.get_mutable().rag_document = document
- full = ""
+ full = ''
async for o in stream(query[0]):
if o.final:
full += o.data
@@ -60,13 +57,15 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]:
yield score
full_stream = (
- vector_db_query_result
- .map(store_and_return_document)
+ Stream[None, str](
+ 'QueryVectorStream',
+ lambda _: as_async_generator(*documents)
+ )
.map(scoring_stream)
.gather()
.and_then(append_score_to_documents)
.and_then(lambda scored_documents: sort_and_slice_documents(scored_documents, 6))
- .and_then(lambda results: {"query": query, "results": results[0]})
+ .and_then(lambda results: {'query': query, 'results': results[0]})
)
async for r in full_stream(query[0]):
@@ -74,7 +73,11 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]:
def rag_postprocess(in_data: Any):
results: list[str] = in_data[0]['results']
- concatenated = "\n".join([s for (s, _) in results])
+ concatenated = '\n\n'.join([json.dumps({
+ **{'text': s},
+ **m
+ }) for (s, m, _) in results])
+
context_store.get_mutable().rag_output = concatenated
return concatenated
diff --git a/fai-rag-app/fai-backend/fai_backend/collection/models.py b/fai-rag-app/fai-backend/fai_backend/collection/models.py
index 488c8c83..3c86cfed 100644
--- a/fai-rag-app/fai-backend/fai_backend/collection/models.py
+++ b/fai-rag-app/fai-backend/fai_backend/collection/models.py
@@ -6,6 +6,7 @@ class CollectionMetadataModel(Document):
label: str = ''
description: str = ''
embedding_model: str | None = ''
+ urls: list[str] | None = None
class Settings:
name = 'collections'
diff --git a/fai-rag-app/fai-backend/fai_backend/collection/service.py b/fai-rag-app/fai-backend/fai_backend/collection/service.py
index 9f728647..b4c868a0 100644
--- a/fai-rag-app/fai-backend/fai_backend/collection/service.py
+++ b/fai-rag-app/fai-backend/fai_backend/collection/service.py
@@ -15,6 +15,7 @@ async def create_collection_metadata(
label: str,
description: str = '',
embedding_model: str | None = None,
+ urls: list[str] | None = None
):
collection_metadata = CollectionMetadataModel(
@@ -22,6 +23,7 @@ async def create_collection_metadata(
label=label,
description=description,
embedding_model=embedding_model,
+ urls=urls
)
return await self.repo.create(collection_metadata)
diff --git a/fai-rag-app/fai-backend/fai_backend/documents/routes.py b/fai-rag-app/fai-backend/fai_backend/documents/routes.py
index e4c249b9..d9d8049b 100644
--- a/fai-rag-app/fai-backend/fai_backend/documents/routes.py
+++ b/fai-rag-app/fai-backend/fai_backend/documents/routes.py
@@ -1,29 +1,27 @@
from tempfile import NamedTemporaryFile
+from typing import Annotated
-from fastapi import APIRouter, Depends, Form, Security, UploadFile
+from fastapi import APIRouter, Depends, File, Form, Security, UploadFile
from fai_backend.collection.dependencies import get_collection_service
from fai_backend.collection.service import CollectionService
from fai_backend.config import settings
from fai_backend.dependencies import get_authenticated_user, get_page_template_for_logged_in_users, get_project_user
from fai_backend.files.dependecies import get_file_upload_service
-from fai_backend.files.file_parser import ParserFactory
+from fai_backend.files.file_parser import ParserFactory, is_url
from fai_backend.files.service import FileUploadService
from fai_backend.framework import components as c
from fai_backend.framework import events as e
-from fai_backend.logger.route_class import APIRouter as LoggingAPIRouter
from fai_backend.phrase import phrase as _
-from fai_backend.projects.dependencies import get_project_service, list_projects_request
-from fai_backend.projects.schema import ProjectResponse
-from fai_backend.projects.service import ProjectService
from fai_backend.schema import ProjectUser
from fai_backend.vector.dependencies import get_vector_service
from fai_backend.vector.service import VectorService
+router_base = APIRouter()
+
router = APIRouter(
prefix='/api',
tags=['Documents'],
- route_class=LoggingAPIRouter,
)
@@ -88,17 +86,23 @@ def upload_view(view=Depends(get_page_template_for_logged_in_users)) -> list:
components=[
c.InputField(
name='collection_label',
- title=_('input_fileupload_collection_label',
- 'Collection label (optional)'),
- placeholder=_('input_fileupload_collection_placeholder',
- 'Collection label (optional)'),
+ label=_('Collection label (optional)'),
+ placeholder=_('Collection label (optional)'),
required=False,
html_type='text',
),
+ c.Textarea(
+ name='urls',
+ placeholder=_('urls', 'URLs'),
+ label=_('urls', 'URLs'),
+ required=False,
+ class_name='whitespace-nowrap',
+ rows=6
+ ),
c.FileInput(
name='files',
- title=_('file', 'File'),
- required=True,
+ label=_('file', 'File'),
+ required=False,
multiple=True,
file_size_limit=settings.FILE_SIZE_LIMIT,
),
@@ -108,42 +112,70 @@ def upload_view(view=Depends(get_page_template_for_logged_in_users)) -> list:
class_name='btn btn-primary',
),
],
- class_name='card bg-base-100 w-full max-w-6xl',
+ class_name='card-body',
),
- ]),
+ ], class_name='card bg-base-100 w-full max-w-xl'),
])], _('upload_documents', 'Upload documents'))
@router.post('/documents/upload_and_vectorize', response_model=list, response_model_exclude_none=True)
async def upload_and_vectorize_handler(
- files: list[UploadFile],
+ files: Annotated[
+ list[UploadFile], File(description='Multiple files as UploadFile')
+ ],
collection_label: str = Form(None),
+ urls: str = Form(None),
project_user: ProjectUser = Depends(get_project_user),
file_service: FileUploadService = Depends(get_file_upload_service),
vector_service: VectorService = Depends(get_vector_service),
view=Depends(get_page_template_for_logged_in_users),
- projects: list[ProjectResponse] = Depends(list_projects_request),
- project_service: ProjectService = Depends(get_project_service),
collection_service: CollectionService = Depends(get_collection_service),
) -> list:
- upload_path = file_service.save_files(project_user.project_id, files)
-
- upload_directory_name = upload_path.split('/')[-1]
- await vector_service.create_collection(collection_name=upload_directory_name,
- embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL)
+    # UploadFile.filename may be None when the field is submitted empty;
+    # guard with truthiness instead of len() (len(None) raises TypeError).
+    # Strip each pasted line so trailing '\r'/spaces don't defeat is_url().
+    list_of_files = [file for file in files if file.filename]
+    list_of_urls = [url.strip() for url in (urls or '').splitlines() if is_url(url.strip())]
+ upload_path = file_service.save_files(project_user.project_id, list_of_files)
+ collection_name = upload_path.split('/')[-1]
+
+    # Parse every uploaded file and every submitted URL into text chunks,
+    # keeping only the metadata keys the RAG pipeline surfaces to users.
+    meta_keys = {'filename', 'url', 'page_number', 'page_name'}
+    sources = [file.path for file in file_service.get_file_infos(upload_path)] + list_of_urls
+    chunks = [
+        {
+            'document': element.text,
+            'document_meta': {
+                key: value
+                for key, value in element.metadata.to_dict().items()
+                if key in meta_keys
+            },
+        }
+        for source in sources
+        for element in ParserFactory.get_parser(source).parse(source) if element
+    ]
+
+ if len(chunks) == 0:
+ return view(
+ c.FireEvent(event=e.GoToEvent(url='/documents/upload')),
+ _('submit_a_question', 'Create Question'),
+ )
+
+ await vector_service.create_collection(
+ collection_name=collection_name,
+ embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL
+ )
- parsed_files = file_service.parse_files(upload_path)
await vector_service.add_documents_without_id_to_empty_collection(
- collection_name=upload_directory_name,
- documents=parsed_files,
- embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL
+ collection_name=collection_name,
+ documents=[chunk['document'] for chunk in chunks],
+ embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL,
+ documents_metadata=[chunk['document_meta'] for chunk in chunks],
)
await collection_service.create_collection_metadata(
- collection_id=upload_directory_name or '',
+ collection_id=collection_name or '',
label=collection_label or '',
description='',
embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL,
+ urls=[url for url in list_of_urls]
)
return view(
@@ -161,5 +193,5 @@ def parse_document(
temp.flush()
parser = ParserFactory.get_parser(temp.name)
parsed = parser.parse(temp.name)
- joined = "\n\n".join([p.text for p in parsed])
+ joined = '\n\n'.join([p.text for p in parsed])
return joined
diff --git a/fai-rag-app/fai-backend/fai_backend/files/file_parser.py b/fai-rag-app/fai-backend/fai_backend/files/file_parser.py
index 93d3dacc..b544c7bd 100644
--- a/fai-rag-app/fai-backend/fai_backend/files/file_parser.py
+++ b/fai-rag-app/fai-backend/fai_backend/files/file_parser.py
@@ -1,12 +1,47 @@
-from abc import abstractmethod, ABC
+from abc import ABC, abstractmethod
+from collections.abc import Callable
+from functools import wraps
+from typing import Any
+from urllib.parse import urlparse
import magic
+import requests
from unstructured.documents.elements import Element
from unstructured.partition.docx import partition_docx
+from unstructured.partition.html import partition_html
from unstructured.partition.md import partition_md
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.xlsx import partition_xlsx
-from unstructured.partition.html import partition_html
+
+from fai_backend.config import settings
+from fai_backend.logger.console import console
+
+
+def error_handler(default_return: Any = None):
+ def decorator(func: Callable):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ console.log(f'Error in {func.__name__}:\n{e}')
+ return default_return
+
+ return wrapper
+
+ return decorator
+
+
+def is_url(string: str) -> bool:
+    """Return True if *string* is an http(s) URL with a non-empty host."""
+    parsed = urlparse(string)
+    # Require a netloc so bare schemes like 'http://' or 'https:' are rejected.
+    return parsed.scheme in {'http', 'https'} and bool(parsed.netloc)
+
+
+@error_handler(default_return='null')
+def get_mime_type(file_path: str) -> str:
+    """Return the MIME type of a local file or remote URL ('null' on failure)."""
+    if is_url(file_path):
+        # A timeout prevents an unreachable host from hanging the upload
+        # request indefinitely; SSL verification is only relaxed in development.
+        response = requests.get(file_path, timeout=30, verify=settings.ENV_MODE != 'development')
+        return magic.from_buffer(response.content, mime=True)
+    return magic.from_file(file_path, mime=True)
class AbstractDocumentParser(ABC):
@@ -16,44 +51,95 @@ def parse(self, filename: str) -> list[Element]:
class DocxParser(AbstractDocumentParser):
- def parse(self, filename: str):
- return partition_docx(filename, chunking_strategy="basic")
+ @error_handler(default_return=[])
+ def parse(self, filename: str) -> list[Element]:
+ if is_url(filename):
+ return partition_docx(
+ url=filename,
+ ssl_verify=settings.ENV_MODE != 'development',
+ chunking_strategy='basic'
+ )
+ return partition_docx(filename, chunking_strategy='basic')
class PDFParser(AbstractDocumentParser):
- def parse(self, filename: str):
- return partition_pdf(filename, chunking_strategy="basic")
+ @error_handler(default_return=[])
+ def parse(self, filename: str) -> list[Element]:
+ if is_url(filename):
+ return partition_pdf(
+ url=filename,
+ ssl_verify=settings.ENV_MODE != 'development',
+ chunking_strategy='basic'
+ )
+ return partition_pdf(filename, chunking_strategy='basic')
class MarkdownParser(AbstractDocumentParser):
- def parse(self, filename: str):
- return partition_md(filename)
+ @error_handler(default_return=[])
+ def parse(self, filename: str) -> list[Element]:
+ if is_url(filename):
+ return partition_md(
+ url=filename,
+ ssl_verify=settings.ENV_MODE != 'development',
+ chunking_strategy='basic'
+ )
+ return partition_md(filename, chunking_strategy='basic')
class ExcelParser(AbstractDocumentParser):
- def parse(self, filename: str):
+ @error_handler(default_return=[])
+ def parse(self, filename: str) -> list[Element]:
+ if is_url(filename):
+ return partition_xlsx(
+ url=filename,
+ ssl_verify=settings.ENV_MODE != 'development'
+ )
return partition_xlsx(filename)
class HTMLParser(AbstractDocumentParser):
- def parse(self, filename: str):
- return partition_html(filename)
+ @staticmethod
+ def _parse_html(filename: str) -> list[Element]:
+ if is_url(filename):
+ return partition_html(
+ url=filename,
+ ssl_verify=settings.ENV_MODE != 'development',
+ chunking_strategy='basic'
+ )
+ return partition_html(filename, chunking_strategy='basic')
+
+ @error_handler(default_return=[])
+ def parse(self, filename: str) -> list[Element]:
+ chunks = self._parse_html(filename)
+ title = chunks[0].metadata.orig_elements[0].text if len(chunks) > 0 else None
+ for chunk in chunks:
+ chunk.metadata.page_name = title
+ return chunks
+
+
+class NullParser(AbstractDocumentParser):
+ def parse(self, filename: str) -> list[Element]:
+ return []
class ParserFactory:
+ MIME_TYPE_MAPPING: dict[str, type[AbstractDocumentParser]] = {
+ 'application/pdf': PDFParser,
+ 'text/plain': MarkdownParser,
+ 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': DocxParser,
+ 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': ExcelParser,
+ 'text/html': HTMLParser,
+ 'null': NullParser
+ }
+
@staticmethod
def get_parser(file_path: str) -> AbstractDocumentParser:
- mime_type = magic.from_file(file_path, mime=True)
-
- if mime_type == 'application/pdf':
- return PDFParser()
- if mime_type == 'text/plain':
- return MarkdownParser()
- if mime_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
- return DocxParser()
- if mime_type == 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet':
- return ExcelParser()
- if mime_type == 'text/html':
- return HTMLParser()
-
- raise ValueError(f'Unsupported file type: {mime_type}')
+ console.log(f'Getting parser for file: {file_path}')
+ mime_type = get_mime_type(file_path)
+ parser_cls = ParserFactory.MIME_TYPE_MAPPING.get(mime_type)
+
+ if parser_cls is None:
+ console.log(f'Unsupported file type: {mime_type}')
+ return NullParser()
+
+ return parser_cls()
diff --git a/fai-rag-app/fai-backend/fai_backend/files/service.py b/fai-rag-app/fai-backend/fai_backend/files/service.py
index a57c1281..a65486f6 100644
--- a/fai-rag-app/fai-backend/fai_backend/files/service.py
+++ b/fai-rag-app/fai-backend/fai_backend/files/service.py
@@ -29,13 +29,16 @@ def _generate_upload_path(self, project_id: str) -> str:
def save_files(self, project_id: str, files: list[UploadFile]) -> str:
upload_path = self._generate_upload_path(project_id)
+ if not files:
+ return upload_path
+
for file in files:
file_location = os.path.join(upload_path, file.filename)
with open(file_location, 'wb+') as file_object:
file_object.write(file.file.read())
return upload_path
- def get_file_infos(self, directory_path, upload_date: datetime) -> list[FileInfo]:
+ def get_file_infos(self, directory_path) -> list[FileInfo]:
file_infos = []
for file_name in os.listdir(directory_path):
file_path = os.path.join(directory_path, file_name)
@@ -49,7 +52,7 @@ def get_file_infos(self, directory_path, upload_date: datetime) -> list[FileInfo
collection=file_path.split('/')[-2], # TODO: niceify
mime_type=mime_type or 'application/octet-stream',
last_modified=datetime.fromtimestamp(stat.st_mtime),
- upload_date=upload_date,
+ upload_date=datetime.fromtimestamp(os.path.getctime(directory_path)),
created_date=datetime.fromtimestamp(stat.st_ctime)
))
@@ -64,7 +67,7 @@ def list_files(self, project_id: str) -> list[FileInfo]:
full_paths = [os.path.join(self.upload_dir, path) for path in project_directories]
all_files = [file for path in full_paths for file in
- self.get_file_infos(path, datetime.fromtimestamp(os.path.getctime(path)))]
+ self.get_file_infos(path)]
return sorted(all_files, key=lambda x: x.upload_date, reverse=True)
@@ -81,8 +84,7 @@ def get_latest_upload_path(self, project_id: str) -> str | None:
def parse_files(self, src_directory_path: str) -> list[str]:
parsed_files = []
- upload_date = datetime.fromtimestamp(os.path.getctime(src_directory_path))
- files = self.get_file_infos(src_directory_path, upload_date)
+ files = self.get_file_infos(src_directory_path)
for file in files:
parser = ParserFactory.get_parser(file.path)
diff --git a/fai-rag-app/fai-backend/fai_backend/vector/service.py b/fai-rag-app/fai-backend/fai_backend/vector/service.py
index 044da0bd..355feacd 100644
--- a/fai-rag-app/fai-backend/fai_backend/vector/service.py
+++ b/fai-rag-app/fai-backend/fai_backend/vector/service.py
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
from typing import Optional
from fai_backend.collection.service import CollectionService
@@ -26,12 +27,15 @@ async def add_to_collection(
collection_name: str,
ids: OneOrMany[str],
documents: Optional[OneOrMany[str]],
- embedding_model: str | None = None
+ embedding_model: str | None = None,
+ documents_metadata: Mapping[str, str | int | float | bool] | list[
+ Mapping[str, str | int | float | bool]] | None = None,
) -> None:
await self.vector_db.add(
collection_name=collection_name,
ids=ids,
documents=documents,
+ metadatas=documents_metadata,
embedding_function=await EmbeddingFnFactory.create(embedding_model)
)
@@ -39,7 +43,9 @@ async def add_documents_without_id_to_empty_collection(
self,
collection_name: str,
documents: list[str],
- embedding_model: str | None = None
+ embedding_model: str | None = None,
+ documents_metadata: Mapping[str, str | int | float | bool] | list[
+ Mapping[str, str | int | float | bool]] | None = None,
) -> None:
"""
Add documents to a collection without specifying ID's
@@ -47,11 +53,13 @@ async def add_documents_without_id_to_empty_collection(
The collection should be empty before calling this method to avoid ID conflicts.
"""
ids = [str(i) for i in range(len(documents))]
+
await self.add_to_collection(
collection_name=collection_name,
ids=ids,
documents=documents,
- embedding_model=embedding_model
+ embedding_model=embedding_model,
+ documents_metadata=documents_metadata,
)
async def query_from_collection(
diff --git a/fai-rag-app/fai-frontend/src/lib/components/SvelteMarkdownWrapper.svelte b/fai-rag-app/fai-frontend/src/lib/components/SvelteMarkdownWrapper.svelte
index 3503606c..061853ab 100644
--- a/fai-rag-app/fai-frontend/src/lib/components/SvelteMarkdownWrapper.svelte
+++ b/fai-rag-app/fai-frontend/src/lib/components/SvelteMarkdownWrapper.svelte
@@ -1,17 +1,18 @@
-
+
diff --git a/fai-rag-app/fai-frontend/src/lib/components/markdown/Link.svelte b/fai-rag-app/fai-frontend/src/lib/components/markdown/Link.svelte
new file mode 100644
index 00000000..4199cc53
--- /dev/null
+++ b/fai-rag-app/fai-frontend/src/lib/components/markdown/Link.svelte
@@ -0,0 +1,17 @@
+
+
+