Feat/rag references (#192)
* fix(FileUploadService): simplify get_file_infos

* chore(rag_scoring): format

* fix(FileUploadService): simplify

* feat(VectorService): add optional document metadata arg

* feat(upload_and_vectorize_handler): extract metadata per chunk and save it to vector

* feat(RagScoringPipeline): propagate and expose document metadata to LLM context (references etc)

* fix(upload_and_vectorize_handler): clean up

* chore(file_parser.py): refactor

* feat(file_parser.py): allow parsing remote files (by url)

* feat(CollectionMetadataModel): add urls

* feat(upload_and_vectorize_handler): aggregate urls & files

* feat(upload_view): add textarea for urls

* fix(upload_and_vectorize_handler): include deps

* fix(upload_and_vectorize_handler): make files optional

* add(HTMLParser): extract and include page title as metadata

* feat(upload_view): add no-wrap css and increase textarea rows

* add(file_parser): error handler & null parser

* fix(upload_view): move up urls

* chore(upload_and_vectorize_handler): simplify

* fix(HTMLParser): rm self from args

* fix(upload_and_vectorize_handler): only create collection when we have chunks

* fix(after rebase): save meta data

* feat(SvelteMarkdownWrapper): add custom link renderer (in order to target blank)

* output context as json

* fix(upload_and_vectorize_handler): allow urls or files (or a combination of both)
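Taken together, these commits thread document metadata (filename, URL, page number, page title) from upload through vectorization and into the LLM context, where each retrieved chunk is now serialized as a JSON object. A minimal sketch of one resulting context entry (sample values invented; the key whitelist comes from upload_and_vectorize_handler in the diff below):

import json

# Invented sample chunk; metadata keys mirror the handler's whitelist:
# filename, url, page_number, page_name.
chunk_text = 'Refunds are processed within 14 days.'
chunk_meta = {'filename': 'policy.pdf', 'page_number': 3}

# Same construction as the new rag_postprocess: text plus metadata, one JSON object per chunk.
entry = json.dumps({**{'text': chunk_text}, **chunk_meta})
# -> {"text": "Refunds are processed within 14 days.", "filename": "policy.pdf", "page_number": 3}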
nRamstedt authored Dec 3, 2024
1 parent 8890bcd commit fec142d
Showing 9 changed files with 239 additions and 87 deletions.
@@ -1,11 +1,11 @@
 import json
-from typing import Any, AsyncGenerator
+from collections.abc import AsyncGenerator
+from typing import Any
 
-from langstream import Stream
+from langstream import Stream, as_async_generator
 
-from fai_backend.assistant.protocol import IAssistantPipelineStrategy, IAssistantContextStore
+from fai_backend.assistant.protocol import IAssistantContextStore, IAssistantPipelineStrategy
 from fai_backend.collection.dependencies import get_collection_service
-from fai_backend.llm.service import query_vector
 from fai_backend.projects.dependencies import get_project_service
 from fai_backend.vector.factory import vector_db
 from fai_backend.vector.service import VectorService
@@ -19,25 +19,22 @@ async def create_pipeline(
         async def run_rag_stream(query: list[str]):
             collection_id = context_store.get_mutable().files_collection_id
             vector_service = VectorService(vector_db=vector_db, collection_meta_service=get_collection_service())
-            vector_db_query_result = await query_vector(
-                vector_service=vector_service,
+
+            result = await vector_service.query_from_collection(
                 collection_name=collection_id,
-                query=query[0],
+                query_texts=[query[0]],
                 n_results=10,
             )
 
-            documents: [str] = []
-
-            def store_and_return_document(document: str):
-                documents.append(document)
-                return document
+            documents, documents_metadata = result['documents'][0], result['metadatas'][0]
 
             def append_score_to_documents(scores):
-                z = zip(documents, [s[0] for s in scores])
+                z = zip(documents, documents_metadata, [s[0] for s in scores])
                 return z
 
             def sort_and_slice_documents(scored_documents, slice_size: int):
                 first_element = list(scored_documents)[0]
-                sorted_scores = sorted(first_element, key=lambda x: x[1], reverse=True)
+                sorted_scores = sorted(first_element, key=lambda x: x[2], reverse=True)
                 return sorted_scores[:slice_size]
 
             projects = await get_project_service().read_projects()
@@ -51,7 +48,7 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]:
                 stream = await assistant.create_stream()
                 scoring_context_store.get_mutable().rag_document = document
 
-                full = ""
+                full = ''
                 async for o in stream(query[0]):
                     if o.final:
                         full += o.data
@@ -60,21 +57,27 @@ async def scoring_stream(document: str) -> AsyncGenerator[str, None]:
                 yield score
 
             full_stream = (
-                vector_db_query_result
-                .map(store_and_return_document)
+                Stream[None, str](
+                    'QueryVectorStream',
+                    lambda _: as_async_generator(*documents)
+                )
                 .map(scoring_stream)
                 .gather()
                 .and_then(append_score_to_documents)
                 .and_then(lambda scored_documents: sort_and_slice_documents(scored_documents, 6))
-                .and_then(lambda results: {"query": query, "results": results[0]})
+                .and_then(lambda results: {'query': query, 'results': results[0]})
             )
 
             async for r in full_stream(query[0]):
                 yield r
 
         def rag_postprocess(in_data: Any):
             results: list[str] = in_data[0]['results']
-            concatenated = "\n".join([s for (s, _) in results])
+            concatenated = '\n\n'.join([json.dumps({
+                **{'text': s},
+                **m
+            }) for (s, m, _) in results])
 
             context_store.get_mutable().rag_output = concatenated
             return concatenated
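For reference, a small worked example of the new scoring flow (documents, metadata, and scores invented; the gather/first-element plumbing is omitted): tuples now carry metadata in the middle slot, which is why sorting keys on index 2 and rag_postprocess unpacks three-tuples.

documents = ['chunk a', 'chunk b', 'chunk c']
documents_metadata = [{'filename': 'a.pdf'}, {'url': 'https://example.com'}, {'filename': 'c.pdf'}]
scores = [[0.2], [0.9], [0.5]]  # one score list per document, as collected from the scoring streams

# Mirrors append_score_to_documents and sort_and_slice_documents above.
scored = zip(documents, documents_metadata, [s[0] for s in scores])
top = sorted(scored, key=lambda x: x[2], reverse=True)[:6]
# -> [('chunk b', {'url': 'https://example.com'}, 0.9),
#     ('chunk c', {'filename': 'c.pdf'}, 0.5),
#     ('chunk a', {'filename': 'a.pdf'}, 0.2)]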
1 change: 1 addition & 0 deletions fai-rag-app/fai-backend/fai_backend/collection/models.py
@@ -6,6 +6,7 @@ class CollectionMetadataModel(Document):
     label: str = ''
     description: str = ''
     embedding_model: str | None = ''
+    urls: list[str] | None = None
 
     class Settings:
         name = 'collections'
2 changes: 2 additions & 0 deletions fai-rag-app/fai-backend/fai_backend/collection/service.py
@@ -15,13 +15,15 @@ async def create_collection_metadata(
             label: str,
             description: str = '',
             embedding_model: str | None = None,
+            urls: list[str] | None = None
 
     ):
         collection_metadata = CollectionMetadataModel(
             collection_id=collection_id,
             label=label,
             description=description,
             embedding_model=embedding_model,
+            urls=urls
         )
 
         return await self.repo.create(collection_metadata)
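A hypothetical call site (argument values invented; settings and collection_service as used elsewhere in this diff) showing how the new optional urls argument reaches the stored CollectionMetadataModel:

# Runs inside an async handler, as in routes.py below.
await collection_service.create_collection_metadata(
    collection_id='uploads-2024-12-03',   # invented value
    label='Customer docs',                # invented value
    description='',
    embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL,
    urls=['https://example.com/faq'],     # new field, persisted on the metadata document
)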
90 changes: 61 additions & 29 deletions fai-rag-app/fai-backend/fai_backend/documents/routes.py
@@ -1,29 +1,27 @@
 from tempfile import NamedTemporaryFile
 from typing import Annotated
 
-from fastapi import APIRouter, Depends, Form, Security, UploadFile
+from fastapi import APIRouter, Depends, File, Form, Security, UploadFile
 
 from fai_backend.collection.dependencies import get_collection_service
 from fai_backend.collection.service import CollectionService
 from fai_backend.config import settings
 from fai_backend.dependencies import get_authenticated_user, get_page_template_for_logged_in_users, get_project_user
 from fai_backend.files.dependecies import get_file_upload_service
-from fai_backend.files.file_parser import ParserFactory
+from fai_backend.files.file_parser import ParserFactory, is_url
 from fai_backend.files.service import FileUploadService
 from fai_backend.framework import components as c
 from fai_backend.framework import events as e
 from fai_backend.logger.route_class import APIRouter as LoggingAPIRouter
 from fai_backend.phrase import phrase as _
 from fai_backend.projects.dependencies import get_project_service, list_projects_request
 from fai_backend.projects.schema import ProjectResponse
 from fai_backend.projects.service import ProjectService
 from fai_backend.schema import ProjectUser
 from fai_backend.vector.dependencies import get_vector_service
 from fai_backend.vector.service import VectorService
 
 router_base = APIRouter()
 
 router = APIRouter(
     prefix='/api',
     tags=['Documents'],
     route_class=LoggingAPIRouter,
 )


@@ -88,17 +86,23 @@ def upload_view(view=Depends(get_page_template_for_logged_in_users)) -> list:
                 components=[
                     c.InputField(
                         name='collection_label',
-                        title=_('input_fileupload_collection_label',
-                                'Collection label (optional)'),
-                        placeholder=_('input_fileupload_collection_placeholder',
-                                      'Collection label (optional)'),
+                        label=_('Collection label (optional)'),
+                        placeholder=_('Collection label (optional)'),
                         required=False,
                         html_type='text',
                     ),
+                    c.Textarea(
+                        name='urls',
+                        placeholder=_('urls', 'URLs'),
+                        label=_('urls', 'URLs'),
+                        required=False,
+                        class_name='whitespace-nowrap',
+                        rows=6
+                    ),
                     c.FileInput(
                         name='files',
-                        title=_('file', 'File'),
-                        required=True,
+                        label=_('file', 'File'),
+                        required=False,
                         multiple=True,
                         file_size_limit=settings.FILE_SIZE_LIMIT,
                     ),
@@ -108,42 +112,70 @@ def upload_view(view=Depends(get_page_template_for_logged_in_users)) -> list:
                         class_name='btn btn-primary',
                     ),
                 ],
-                class_name='card bg-base-100 w-full max-w-6xl',
+                class_name='card-body',
             ),
         ]),
+    ], class_name='card bg-base-100 w-full max-w-xl'),
     ])], _('upload_documents', 'Upload documents'))


 @router.post('/documents/upload_and_vectorize', response_model=list, response_model_exclude_none=True)
 async def upload_and_vectorize_handler(
-        files: list[UploadFile],
+        files: Annotated[
+            list[UploadFile], File(description='Multiple files as UploadFile')
+        ],
         collection_label: str = Form(None),
+        urls: str = Form(None),
         project_user: ProjectUser = Depends(get_project_user),
         file_service: FileUploadService = Depends(get_file_upload_service),
         vector_service: VectorService = Depends(get_vector_service),
         view=Depends(get_page_template_for_logged_in_users),
         projects: list[ProjectResponse] = Depends(list_projects_request),
         project_service: ProjectService = Depends(get_project_service),
        collection_service: CollectionService = Depends(get_collection_service),
 ) -> list:
-    upload_path = file_service.save_files(project_user.project_id, files)
-
-    upload_directory_name = upload_path.split('/')[-1]
-    await vector_service.create_collection(collection_name=upload_directory_name,
-                                           embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL)
+    list_of_files = [file for file in files if len(file.filename) > 0]
+    list_of_urls = [url for url in (urls or '').splitlines() if is_url(url)]
+    upload_path = file_service.save_files(project_user.project_id, list_of_files)
+    collection_name = upload_path.split('/')[-1]
 
+    chunks = [
+        {
+            'document': element.text,
+            'document_meta': {
+                key: value
+                for key, value in {**element.metadata.to_dict(), }.items()
+                if key in ['filename', 'url', 'page_number', 'page_name']
+            }
+        }
+        for file_or_url in [
+            *[file.path for file in file_service.get_file_infos(upload_path)],
+            *[url for url in list_of_urls]
+        ]
+        for element in ParserFactory.get_parser(file_or_url).parse(file_or_url) if element
+    ]
+
+    if len(chunks) == 0:
+        return view(
+            c.FireEvent(event=e.GoToEvent(url='/documents/upload')),
+            _('submit_a_question', 'Create Question'),
+        )
+
+    await vector_service.create_collection(
+        collection_name=collection_name,
+        embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL
+    )
 
-    parsed_files = file_service.parse_files(upload_path)
     await vector_service.add_documents_without_id_to_empty_collection(
-        collection_name=upload_directory_name,
-        documents=parsed_files,
-        embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL
+        collection_name=collection_name,
+        documents=[chunk['document'] for chunk in chunks],
+        embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL,
+        documents_metadata=[chunk['document_meta'] for chunk in chunks],
     )
 
     await collection_service.create_collection_metadata(
-        collection_id=upload_directory_name or '',
+        collection_id=collection_name or '',
         label=collection_label or '',
         description='',
         embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL,
+        urls=[url for url in list_of_urls]
     )
 
     return view(
@@ -161,5 +193,5 @@ def parse_document(
         temp.flush()
         parser = ParserFactory.get_parser(temp.name)
         parsed = parser.parse(temp.name)
-        joined = "\n\n".join([p.text for p in parsed])
+        joined = '\n\n'.join([p.text for p in parsed])
         return joined
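The handler filters pasted textarea lines through is_url, imported from file_parser.py; that file's diff is not rendered on this page, so the following is only a plausible stand-in, assuming absolute http(s) URLs are the accepted shape:

from urllib.parse import urlparse

def is_url(value: str) -> bool:
    # Assumed behavior: accept only absolute http(s) URLs with a host.
    parsed = urlparse(value.strip())
    return parsed.scheme in ('http', 'https') and bool(parsed.netloc)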
(The remaining changed files are not rendered in this view.)
