Skip to content

Commit

Permalink
Refactor/rag references (#203)
Browse files Browse the repository at this point in the history
* fix(get_mime_type): verify unless dev env

* force unique id per element

* parameterize document ids

* add more-itertools

* refactor(upload_and_vectorize_handler): yield and generate embeddings in chunks

* fix: poetry lock

---------

Co-authored-by: Nikolas Ramstedt <[email protected]>
  • Loading branch information
zilaei and nRamstedt authored Dec 5, 2024
1 parent 59b8d47 commit b969502
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 46 deletions.
58 changes: 30 additions & 28 deletions fai-rag-app/fai-backend/fai_backend/documents/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Annotated

from fastapi import APIRouter, Depends, File, Form, Security, UploadFile
from more_itertools import chunked

from fai_backend.collection.dependencies import get_collection_service
from fai_backend.collection.service import CollectionService
Expand Down Expand Up @@ -131,44 +132,45 @@ async def upload_and_vectorize_handler(
view=Depends(get_page_template_for_logged_in_users),
collection_service: CollectionService = Depends(get_collection_service),
) -> list:
def generate_chunks(file_paths_or_urls: list[str]):
    """Lazily yield one vector-store record per non-empty parsed element.

    Every entry is parsed with the parser that ``ParserFactory.get_parser``
    selects for it. Elements without text are skipped so that no empty
    documents are produced.

    Args:
        file_paths_or_urls: Local file paths and/or URLs to parse.

    Yields:
        A dict with the keys ``document`` (the element text),
        ``document_meta`` (the element metadata restricted to ``filename``,
        ``url``, ``page_number`` and ``page_name``) and ``document_id``
        (the element's id).
    """
    # Only these metadata fields are kept alongside each document.
    kept_meta_keys = ('filename', 'url', 'page_number', 'page_name')

    for file_or_url in file_paths_or_urls:
        for element in ParserFactory.get_parser(file_or_url).parse(file_or_url):
            # Truthiness check: skips empty text and also tolerates a missing
            # (None) text instead of raising from len().
            if not element.text:
                continue

            yield {
                'document': element.text,
                'document_meta': {
                    key: value
                    for key, value in element.metadata.to_dict().items()
                    if key in kept_meta_keys
                },
                'document_id': element.id,
            }

list_of_files = [file for file in files if len(file.filename) > 0]
list_of_urls = [url for url in (urls or '').splitlines() if is_url(url)]

upload_path = file_service.save_files(project_user.project_id, list_of_files)
collection_name = upload_path.split('/')[-1]

chunks = [
{
'document': element.text,
'document_meta': {
key: value
for key, value in {**element.metadata.to_dict(), }.items()
if key in ['filename', 'url', 'page_number', 'page_name']
}
}
for file_or_url in [
*[file.path for file in file_service.get_file_infos(upload_path)],
*[url for url in list_of_urls]
]
for element in ParserFactory.get_parser(file_or_url).parse(file_or_url) if element
]

if len(chunks) == 0:
return view(
c.FireEvent(event=e.GoToEvent(url='/documents/upload')),
_('submit_a_question', 'Create Question'),
)

await vector_service.create_collection(
collection_name=collection_name,
embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL
)

await vector_service.add_documents_without_id_to_empty_collection(
collection_name=collection_name,
documents=[chunk['document'] for chunk in chunks],
embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL,
documents_metadata=[chunk['document_meta'] for chunk in chunks],
)
for batch in chunked(
generate_chunks([
*[file.path for file in file_service.get_file_infos(upload_path)],
*list_of_urls
]),
100
):
await vector_service.add_documents_without_id_to_empty_collection(
collection_name=collection_name,
documents=[chunk['document'] for chunk in batch],
embedding_model=settings.APP_VECTOR_DB_EMBEDDING_MODEL,
documents_metadata=[chunk['document_meta'] for chunk in batch],
document_ids=[chunk['document_id'] for chunk in batch]
)

await collection_service.create_collection_metadata(
collection_id=collection_name or '',
Expand Down
27 changes: 16 additions & 11 deletions fai-rag-app/fai-backend/fai_backend/files/file_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def is_url(string: str) -> bool:
@error_handler(default_return='null')
def get_mime_type(file_path: str) -> str:
    """Return the MIME type of a local file or of the content behind a URL.

    Args:
        file_path: A filesystem path, or a URL (as decided by ``is_url``).

    Returns:
        The MIME type string reported by ``magic``.
    """
    if is_url(file_path):
        # Verify TLS certificates everywhere except in the development
        # environment, matching the ``ssl_verify`` setting the parsers use.
        response = requests.get(file_path, verify=settings.ENV_MODE != 'development')
        return magic.from_buffer(response.content, mime=True)
    return magic.from_file(file_path, mime=True)


Expand All @@ -57,9 +57,10 @@ def parse(self, filename: str) -> list[Element]:
return partition_docx(
url=filename,
ssl_verify=settings.ENV_MODE != 'development',
chunking_strategy='basic'
chunking_strategy='basic',
unique_element_ids=True
)
return partition_docx(filename, chunking_strategy='basic')
return partition_docx(filename, chunking_strategy='basic', unique_element_ids=True)


class PDFParser(AbstractDocumentParser):
Expand All @@ -69,9 +70,10 @@ def parse(self, filename: str) -> list[Element]:
return partition_pdf(
url=filename,
ssl_verify=settings.ENV_MODE != 'development',
chunking_strategy='basic'
chunking_strategy='basic',
unique_element_ids=True
)
return partition_pdf(filename, chunking_strategy='basic')
return partition_pdf(filename, chunking_strategy='basic', unique_element_ids=True)


class MarkdownParser(AbstractDocumentParser):
Expand All @@ -81,9 +83,10 @@ def parse(self, filename: str) -> list[Element]:
return partition_md(
url=filename,
ssl_verify=settings.ENV_MODE != 'development',
chunking_strategy='basic'
chunking_strategy='basic',
unique_element_ids=True
)
return partition_md(filename, chunking_strategy='basic')
return partition_md(filename, chunking_strategy='basic', unique_element_ids=True)


class ExcelParser(AbstractDocumentParser):
Expand All @@ -92,9 +95,10 @@ def parse(self, filename: str) -> list[Element]:
if is_url(filename):
return partition_xlsx(
url=filename,
ssl_verify=settings.ENV_MODE != 'development'
ssl_verify=settings.ENV_MODE != 'development',
unique_element_ids=True
)
return partition_xlsx(filename)
return partition_xlsx(filename, unique_element_ids=True)


class HTMLParser(AbstractDocumentParser):
Expand All @@ -104,9 +108,10 @@ def _parse_html(filename: str) -> list[Element]:
return partition_html(
url=filename,
ssl_verify=settings.ENV_MODE != 'development',
chunking_strategy='basic'
chunking_strategy='basic',
unique_element_ids=True
)
return partition_html(filename, chunking_strategy='basic')
return partition_html(filename, chunking_strategy='basic', unique_element_ids=True)

@error_handler(default_return=[])
def parse(self, filename: str) -> list[Element]:
Expand Down
3 changes: 2 additions & 1 deletion fai-rag-app/fai-backend/fai_backend/vector/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ async def add_documents_without_id_to_empty_collection(
embedding_model: str | None = None,
documents_metadata: Mapping[str, str | int | float | bool] | list[
Mapping[str, str | int | float | bool]] | None = None,
document_ids: list[str] | None = None
) -> None:
"""
Add documents to a collection, generating sequential IDs unless document_ids is provided.
When IDs are generated, the collection should be empty before calling this method to avoid ID conflicts.
"""
ids = [str(i) for i in range(len(documents))]
ids = [str(i) for i in range(len(documents))] if document_ids is None else document_ids

await self.add_to_collection(
collection_name=collection_name,
Expand Down
18 changes: 12 additions & 6 deletions fai-rag-app/fai-backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions fai-rag-app/fai-backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ openpyxl = "~3.1.5"
tiktoken = "^0.8.0"
pyjwt = "~2.9.0"
cryptography = "^43.0.3"
more-itertools = "^10.5.0"

[tool.poetry.group.unstructured.dependencies]
unstructured = { extras = ["md", "pdf", "docx"], version = "0.13.7" }
Expand Down

0 comments on commit b969502

Please sign in to comment.