feat: let LLM choose whether to retrieve context (#62)
lsorber authored Dec 15, 2024
1 parent b02c5a0 commit 574e407
Showing 12 changed files with 451 additions and 155 deletions.
42 changes: 33 additions & 9 deletions README.md
@@ -159,9 +159,35 @@ insert_document(Path("Special Relativity.pdf"), config=my_config)

### 3. Searching and Retrieval-Augmented Generation (RAG)

#### 3.1 Simple RAG pipeline
#### 3.1 Dynamically routed RAG

Now you can run a simple but powerful RAG pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:
Now you can run a dynamically routed RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking a retrieval tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans (each of which is a list of consecutive chunks). The retrieval results are sent to the `on_retrieval` callback and are also appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the assistant response:

```python
from raglite import rag

# Create a user message:
messages = [] # Or start with an existing message history.
messages.append({
    "role": "user",
    "content": "How is intelligence measured?"
})

# Let the LLM decide whether to search the database by providing it with a retrieval tool.
# If requested, RAGLite then uses hybrid search and reranking to append RAG context to the message history.
# Finally, the assistant response is streamed and appended to the message history.
chunk_spans = []
stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=my_config)
for update in stream:
    print(update, end="")

# Access the documents referenced in the RAG context:
documents = [chunk_span.document for chunk_span in chunk_spans]
```

#### 3.2 Programmable RAG

If you need manual control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:

```python
from raglite import create_rag_instruction, rag, retrieve_rag_context
@@ -174,21 +200,19 @@ chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=my_co
messages = [] # Or start with an existing message history.
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))

# Stream the RAG response:
# Stream the RAG response and append it to the message history:
stream = rag(messages, config=my_config)
for update in stream:
    print(update, end="")

# Access the documents cited in the RAG response:
# Access the documents referenced in the RAG context:
documents = [chunk_span.document for chunk_span in chunk_spans]
```

#### 3.2 Advanced RAG pipeline

> [!TIP]
> 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks.
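
A minimal sketch of that tip, using the `hybrid_search`, `retrieve_chunks`, `rerank_chunks`, and `retrieve_chunk_spans` helpers that appear elsewhere in this diff (`user_prompt` and `my_config` are assumed to be defined as in the surrounding examples):

```python
from raglite import hybrid_search, rerank_chunks, retrieve_chunk_spans, retrieve_chunks

# Search for a larger set of 20 candidate chunks with hybrid search.
chunk_ids, _ = hybrid_search(query=user_prompt, num_results=20, config=my_config)
chunks = retrieve_chunks(chunk_ids=chunk_ids, config=my_config)

# Rerank the candidates, then keep the top 5 and group them into chunk spans.
chunks = rerank_chunks(query=user_prompt, chunk_ids=chunks, config=my_config)
chunk_spans = retrieve_chunk_spans(chunks[:5], config=my_config)
```
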
In addition to the simple RAG pipeline, RAGLite also offers more advanced control over the individual steps of the pipeline. A full pipeline consists of several steps:
RAGLite also offers more advanced control over the individual steps of a full RAG pipeline:

1. Searching for relevant chunks with keyword, vector, or hybrid search
2. Retrieving the chunks from the database
@@ -229,14 +253,14 @@ from raglite import create_rag_instruction
messages = [] # Or start with an existing message history.
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))

# Stream the RAG response:
# Stream the RAG response and append it to the message history:
from raglite import rag

stream = rag(messages, config=my_config)
for update in stream:
    print(update, end="")

# Access the documents cited in the RAG response:
# Access the documents referenced in the RAG context:
documents = [chunk_span.document for chunk_span in chunk_spans]
```

8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -32,7 +32,7 @@ scipy = ">=1.5.0"
spacy = ">=3.7.0,<3.8.0"
# Large Language Models:
huggingface-hub = ">=0.22.0"
litellm = ">=1.47.1"
litellm = ">=1.48.4"
llama-cpp-python = ">=0.3.2"
pydantic = ">=2.7.0"
# Approximate Nearest Neighbors:
55 changes: 20 additions & 35 deletions src/raglite/_chainlit.py
@@ -6,22 +6,11 @@
import chainlit as cl
from chainlit.input_widget import Switch, TextInput

from raglite import (
RAGLiteConfig,
async_rag,
create_rag_instruction,
hybrid_search,
insert_document,
rerank_chunks,
retrieve_chunk_spans,
retrieve_chunks,
)
from raglite import RAGLiteConfig, async_rag, hybrid_search, insert_document, rerank_chunks
from raglite._markdown import document_to_markdown

async_insert_document = cl.make_async(insert_document)
async_hybrid_search = cl.make_async(hybrid_search)
async_retrieve_chunks = cl.make_async(retrieve_chunks)
async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
async_rerank_chunks = cl.make_async(rerank_chunks)


@@ -93,31 +82,27 @@ async def handle_message(user_message: cl.Message) -> None:
for i, attachment in enumerate(inline_attachments)
)
+ f"\n\n{user_message.content}"
)
# Search for relevant contexts for RAG.
async with cl.Step(name="search", type="retrieval") as step:
step.input = user_message.content
chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
step.output = chunks
step.elements = [ # Show the top chunks inline.
cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
]
await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
# Rerank the chunks and group them into chunk spans.
async with cl.Step(name="rerank", type="rerank") as step:
step.input = chunks
chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
step.output = chunk_spans
step.elements = [ # Show the top chunk spans inline.
cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
]
await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
).strip()
# Stream the LLM response.
assistant_message = cl.Message(content="")
chunk_spans = []
messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call]
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
async for token in async_rag(messages, config=config):
messages.append({"role": "user", "content": user_prompt})
async for token in async_rag(
messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config
):
await assistant_message.stream_token(token)
# Append RAG sources, if any.
if chunk_spans:
rag_sources: dict[str, list[str]] = {}
for chunk_span in chunk_spans:
rag_sources.setdefault(chunk_span.document.id, [])
rag_sources[chunk_span.document.id].append(str(chunk_span))
assistant_message.content += "\n\nSources: " + ", ".join( # Rendered as hyperlinks.
f"[{i + 1}]" for i in range(len(rag_sources))
)
assistant_message.elements = [ # Markdown content is rendered in sidebar.
cl.Text(name=f"[{i + 1}]", content="\n\n---\n\n".join(content), display="side") # type: ignore[misc]
for i, (_, content) in enumerate(rag_sources.items())
]
await assistant_message.update() # type: ignore[no-untyped-call]
31 changes: 27 additions & 4 deletions src/raglite/_database.py
@@ -169,18 +169,41 @@ def to_xml(self, index: int | None = None) -> str:
if not self.chunks:
return ""
index_attribute = f' index="{index}"' if index is not None else ""
xml = "\n".join(
xml_document = "\n".join(
[
f'<document{index_attribute} id="{self.document.id}">',
f"<source>{self.document.url if self.document.url else self.document.filename}</source>",
f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[0].id}">',
f"<heading>\n{escape(self.chunks[0].headings.strip())}\n</heading>",
f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[-1].id}">',
f"<headings>\n{escape(self.chunks[0].headings.strip())}\n</headings>",
f"<content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</content>",
"</span>",
"</document>",
]
)
return xml
return xml_document

def to_json(self, index: int | None = None) -> str:
"""Convert this chunk span to a JSON representation.
The JSON representation follows Anthropic's best practices [1].
[1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
"""
if not self.chunks:
return "{}"
index_attribute = {"index": index} if index is not None else {}
json_document = {
**index_attribute,
"id": self.document.id,
"source": self.document.url if self.document.url else self.document.filename,
"span": {
"from_chunk_id": self.chunks[0].id,
"to_chunk_id": self.chunks[-1].id,
"headings": self.chunks[0].headings.strip(),
"content": "".join(chunk.body for chunk in self.chunks).strip(),
},
}
return json.dumps(json_document)

@property
def content(self) -> str:
5 changes: 2 additions & 3 deletions src/raglite/_extract.py
@@ -34,13 +34,12 @@ class MyNameResponse(BaseModel):
# Load the default config if not provided.
config = config or RAGLiteConfig()
# Check if the LLM supports the response format.
llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
llm_supports_response_format = "response_format" in (
get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []
get_supported_openai_params(model=config.llm) or []
)
# Update the system prompt with the JSON schema of the return type to help the LLM.
system_prompt = getattr(return_type, "system_prompt", "").strip()
if not llm_supports_response_format or llm_provider == "llama-cpp-python":
if not llm_supports_response_format or config.llm.startswith("llama-cpp-python"):
system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}"
# Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
# is disabled by default because it only supports a subset of JSON schema features [2].