feat: Embeddings pipeline improvements #33

Merged
1 change: 1 addition & 0 deletions api/delete.py
@@ -1,6 +1,7 @@
from fastapi import APIRouter

from models.delete import RequestPayload, ResponsePayload
from service.embedding import get_encoder
from service.vector_database import VectorService, get_vector_service

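The route body itself is collapsed in this diff. As a rough sketch of how the new import is presumably wired in (the payload fields and the get_vector_service signature are assumptions; only the get_encoder factory signature is confirmed later in this diff):

from fastapi import APIRouter

from models.delete import RequestPayload, ResponsePayload
from service.embedding import get_encoder
from service.vector_database import get_vector_service

router = APIRouter()


@router.delete("/delete")  # route path is a guess
async def delete(payload: RequestPayload) -> ResponsePayload:
    # get_encoder(*, encoder_type=...) matches the factory defined in
    # service/embedding.py below; payload.encoder is an assumed field name.
    encoder = get_encoder(encoder_type=payload.encoder)
    vector_service = get_vector_service(
        index_name=payload.index_name, encoder=encoder  # assumed signature
    )
    await vector_service.delete(file_url=payload.file_url)
    return ResponsePayload(success=True)  # assumed constructor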
24 changes: 22 additions & 2 deletions dev/embedding.ipynb
@@ -40,7 +40,27 @@
"metadata": {},
"outputs": [],
"source": [
"docs = await embedding_service.generate_chunks()"
"elements = await embedding_service._download_and_extract_elements(file, strategy=\"auto\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for element in elements:\n",
" print(type(element))\n",
" # print(f\"Text: {element.text}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"docs = await embedding_service.generate_chunks(strategy=\"auto\")"
]
},
{
@@ -68,7 +88,7 @@
" print(colored_text)\n",
" concatenated_document += chunk + \" \"\n",
"\n",
"print(\"\\nConcatenated Document:\\n\", concatenated_document)"
"# print(\"\\nConcatenated Document:\\n\", concatenated_document)"
]
},
{
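The notebook's strategy="auto" flows down to unstructured's partition call. A minimal standalone sketch of the same call outside the service class (the file path is a placeholder):

from unstructured.partition.auto import partition

# "auto" lets unstructured pick a strategy per file type; "hi_res" (the
# service's previous hard-coded value) forces slower, layout-aware parsing,
# and "fast" skips model-based layout detection entirely.
elements = partition(filename="example.pdf", strategy="auto")  # placeholder path
for element in elements:
    print(type(element).__name__, element.text[:80])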
10 changes: 10 additions & 0 deletions dev/walkthrough.ipynb
@@ -91,6 +91,16 @@
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = query_response.json().get('data', [])\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
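The new cell reads the response defensively: .get('data', []) yields an empty list instead of raising a KeyError when the key is missing. A tiny sketch of the same pattern with a hypothetical call (the walkthrough's actual endpoint and payload are not visible in this diff):

import requests

# Hypothetical endpoint and payload, for illustration only.
query_response = requests.post(
    "http://localhost:8000/api/v1/query",
    json={"input": "What is this document about?"},
)
data = query_response.json().get("data", [])  # [] when the key is absent
print(data)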
33 changes: 9 additions & 24 deletions service/embedding.py
@@ -8,6 +8,7 @@
import requests
from tqdm import tqdm
from unstructured.chunking.title import chunk_by_title
from unstructured.documents.elements import Element
from unstructured.partition.auto import partition

import encoders
@@ -41,8 +42,8 @@ def _get_datasource_suffix(self, type: str) -> str:
raise ValueError("Unsupported datasource type")

async def _download_and_extract_elements(
self, file, strategy="hi_res"
) -> List[Any]:
self, file, strategy: Optional[str] = "hi_res"
) -> List[Element]:
"""
Downloads the file and extracts elements using the partition function.
Returns a list of unstructured elements.
@@ -84,33 +85,17 @@ async def generate_document(
except Exception as e:
logger.error(f"Error loading document {file.url}: {e}")

async def generate_summary_document(
self, documents: List[BaseDocument]
) -> List[BaseDocument]:
pbar = tqdm(total=len(documents), desc="Summarizing documents")
pages = {}
for document in documents:
page_number = document.metadata.get("page_number")
if page_number not in pages:
doc = copy.deepcopy(document)
doc.text = await completion(document=doc)
pages[page_number] = doc
else:
pages[page_number].text += document.text
pbar.update()
pbar.close()
summary_documents = list(pages.values())
return summary_documents

async def generate_chunks(self) -> List[BaseDocumentChunk]:
async def generate_chunks(self, strategy: Optional[str]) -> List[BaseDocumentChunk]:
doc_chunks = []
for file in tqdm(self.files, desc="Generating chunks"):
try:
elements = await self._download_and_extract_elements(file)
elements = await self._download_and_extract_elements(file, strategy)
document = await self.generate_document(file, elements)
if not document:
continue
chunks = chunk_by_title(elements)
chunks = chunk_by_title(
elements, max_characters=500, combine_text_under_n_chars=0
)
for chunk in chunks:
# Ensure all metadata values are of a type acceptable to Pinecone
sanitized_metadata = {
@@ -225,4 +210,4 @@ def get_encoder(*, encoder_type: EncoderEnum) -> encoders.BaseEncoder:
encoder_class = encoder_mapping.get(encoder_type)
if encoder_class is None:
raise ValueError(f"Unsupported encoder: {encoder_type}")
return encoder_class()
return encoder_class()
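The generate_chunks change above pins unstructured's title-based chunking parameters rather than relying on library defaults. A minimal sketch of the same call in isolation (placeholder file path):

from unstructured.chunking.title import chunk_by_title
from unstructured.partition.auto import partition

elements = partition(filename="example.pdf", strategy="auto")  # placeholder path

# max_characters=500 hard-caps each chunk at 500 characters;
# combine_text_under_n_chars=0 keeps small title-delimited sections as
# separate chunks instead of merging them with their neighbours.
chunks = chunk_by_title(elements, max_characters=500, combine_text_under_n_chars=0)
for chunk in chunks:
    print(len(chunk.text), chunk.text[:60])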
14 changes: 14 additions & 0 deletions service/vector_database.py
@@ -37,6 +37,20 @@ async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]:
# async def convert_to_rerank_format():
# pass

# TODO: make it default method instead of abstract
# async def convert_to_rerank_format(self, chunks: List[BaseDocumentChunk]):
# docs = [
# {
# "content": chunk.text,
# "page_label": (
# chunk.metadata.get("page_number", "") if chunk.metadata else ""
# ),
# "file_url": chunk.doc_url,
# }
# for chunk in chunks
# ]
# return docs

@abstractmethod
async def delete(self, file_url: str):
pass
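Per the TODO above, the commented-out block would become a concrete default method rather than an abstract one. A sketch of the body as a plain function, assuming BaseDocumentChunk exposes text, metadata, and doc_url as it does elsewhere in this diff:

from typing import Any, Dict, List


def convert_to_rerank_format(chunks: List["BaseDocumentChunk"]) -> List[Dict[str, Any]]:
    # Mirrors the commented-out default above: flatten each chunk into the
    # dict shape a reranker consumes.
    return [
        {
            "content": chunk.text,
            "page_label": (
                chunk.metadata.get("page_number", "") if chunk.metadata else ""
            ),
            "file_url": chunk.doc_url,
        }
        for chunk in chunks
    ]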