Commit

E2E Hybrid retrieval
svilupp authored May 28, 2024
1 parent bf1eb85 commit 608bdf0
Showing 13 changed files with 199 additions and 107 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.27.0]

### Added
- Added a keyword-based search similarity to RAGTools to serve both for baseline evaluation and for advanced performance (by having a hybrid index with both embeddings and BM25). Works only for retrieval for now (no E2E `airag` support yet) See `?RT.KeywordsIndexer` and `?RT.BM25Similarity` for more information, to build use `build_index(KeywordsIndexer(), texts)` or convert an existing embeddings-based index `ChunkKeywordsIndex(index)`.
- Added a keyword-based search similarity to RAGTools to serve both for baseline evaluation and for advanced performance (by having a hybrid index with both embeddings and BM25). See `?RT.KeywordsIndexer` and `?RT.BM25Similarity` for more information. To build one, use `build_index(KeywordsIndexer(), texts)`, or convert an existing embeddings-based index with `ChunkKeywordsIndex(index)`.
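
  A minimal sketch of the two ways to obtain a keywords index (assuming `texts` is your own `Vector{String}` of documents; the embeddings path additionally requires a configured embedding model/API key):

  ```julia
  using LinearAlgebra, SparseArrays, Unicode # needed to load the RAGTools extensions
  using PromptingTools.Experimental.RAGTools

  texts = ["PromptingTools is a Julia package.", "BM25 ranks chunks by keyword overlap."]

  # Option 1: build a BM25 keywords index from scratch (no API calls needed)
  keywords_index = build_index(KeywordsIndexer(), texts)

  # Option 2: convert an existing embeddings-based index
  # embeddings_index = build_index(texts) # calls an embedding model
  # keywords_index = ChunkKeywordsIndex(embeddings_index)
  ```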

### Updated
- For naming consistency, `ChunkIndex` in RAGTools has been renamed to `ChunkEmbeddingsIndex` (with an alias `ChunkIndex` for backwards compatibility). There are now two main index types: `ChunkEmbeddingsIndex` and `ChunkKeywordsIndex` (=BM25), which can be combined into a `MultiIndex` to serve as a hybrid index.
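
  The two index types can then be assembled into a hybrid index; a minimal sketch (assuming `texts` holds your documents and an embedding model/API key is configured):

  ```julia
  using LinearAlgebra, SparseArrays, Unicode # needed to load the RAGTools extensions
  using PromptingTools.Experimental.RAGTools

  texts = ["..."] # placeholder for your documents
  emb_index = build_index(texts)                  # ChunkEmbeddingsIndex (formerly `ChunkIndex`)
  kw_index = ChunkKeywordsIndex(emb_index)        # BM25-backed keywords index over the same chunks
  multi_index = MultiIndex([emb_index, kw_index]) # hybrid index for retrieval
  ```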
4 changes: 3 additions & 1 deletion docs/src/examples/building_RAG.md
@@ -9,9 +9,11 @@ Let's build a Retrieval-Augmented Generation (RAG) chatbot, tailored to navigate

If you're not familiar with "RAG", start with this [article](https://towardsdatascience.com/add-your-own-data-to-an-llm-using-retrieval-augmented-generation-rag-b1958bf56a5a).

Note: You must first import `LinearAlgebra`, `SparseArrays`, and `Unicode` to use this example!


````julia
using LinearAlgebra, SparseArrays
using LinearAlgebra, SparseArrays, Unicode
using PromptingTools
using PromptingTools.Experimental.RAGTools
## Note: The RAGTools module is still experimental and will change in the future. Ideally, it will be cleaned up and moved to a dedicated package
2 changes: 1 addition & 1 deletion docs/src/extra_tools/rag_tools_intro.md
@@ -14,7 +14,7 @@ Import the module as follows:

```julia
# required dependencies to load the necessary extensions!!!
using LinearAlgebra, SparseArrays
using LinearAlgebra, SparseArrays, Unicode
using PromptingTools.Experimental.RAGTools
# to access unexported functionality
const RT = PromptingTools.Experimental.RAGTools
4 changes: 3 additions & 1 deletion examples/building_RAG.jl
@@ -5,8 +5,10 @@

# If you're not familiar with "RAG", start with this [article](https://towardsdatascience.com/add-your-own-data-to-an-llm-using-retrieval-augmented-generation-rag-b1958bf56a5a).

# Note: You must first import `LinearAlgebra`, `SparseArrays`, and `Unicode` to use this example!

## Imports
using LinearAlgebra, SparseArrays
using LinearAlgebra, SparseArrays, Unicode
using PromptingTools
## Note: RAGTools is still experimental and will change in the future. Ideally, it will be cleaned up and moved to a dedicated package
using PromptingTools.Experimental.RAGTools
2 changes: 1 addition & 1 deletion src/Experimental/RAGTools/RAGTools.jl
@@ -3,7 +3,7 @@
Provides Retrieval-Augmented Generation (RAG) functionality.
Requires: LinearAlgebra, SparseArrays, PromptingTools for proper functionality.
Requires: LinearAlgebra, SparseArrays, Unicode, PromptingTools for proper functionality.
This module is experimental and may change at any time. It is intended to be moved to a separate package in the future.
"""
97 changes: 64 additions & 33 deletions src/Experimental/RAGTools/generation.jl
@@ -8,20 +8,20 @@ struct ContextEnumerator <: AbstractContextBuilder end

"""
build_context(contexter::ContextEnumerator,
index::AbstractChunkIndex, candidates::CandidateChunks;
index::AbstractDocumentIndex, candidates::AbstractCandidateChunks;
verbose::Bool = true,
chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...)
build_context!(contexter::ContextEnumerator,
index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...)
index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...)
Build context strings for each position in `candidates` considering a window margin around each position.
If the mutating version (`build_context!`) is used, it will use `result.reranked_candidates` to update the `result.context` field.
# Arguments
- `contexter::ContextEnumerator`: The method to use for building the context. Enumerates the snippets.
- `index::ChunkIndex`: The index containing chunks and sources.
- `candidates::CandidateChunks`: Candidate chunks which contain positions to extract context from.
- `index::AbstractDocumentIndex`: The index containing chunks and sources.
- `candidates::AbstractCandidateChunks`: Candidate chunks which contain positions to extract context from.
- `verbose::Bool`: If `true`, enables verbose logging.
- `chunks_window_margin::Tuple{Int, Int}`: A tuple indicating the margin (before, after) around each position to include in the context.
Defaults to `(1,1)`, which means 1 preceding and 1 succeeding chunk will be included. With `(0,0)`, only the matching chunks will be included.
@@ -37,32 +37,40 @@ context = build_context(ContextEnumerator(), index, candidates; chunks_window_ma
```
"""
function build_context(contexter::ContextEnumerator,
index::AbstractChunkIndex, candidates::CandidateChunks;
index::AbstractDocumentIndex, candidates::AbstractCandidateChunks;
verbose::Bool = true,
chunks_window_margin::Tuple{Int, Int} = (1, 1), kwargs...)
## Checks
@assert chunks_window_margin[1] >= 0 && chunks_window_margin[2] >= 0 "Both `chunks_window_margin` values must be non-negative"

context = String[]
for (i, position) in enumerate(candidates.positions)
chunks_ = chunks(index)[max(1, position - chunks_window_margin[1]):min(end,
## select the right index
id = candidates isa MultiCandidateChunks ? candidates.index_ids[i] :
candidates.index_id
index_ = index isa AbstractChunkIndex ? index : index[id]
isnothing(index_) && continue
##
chunks_ = chunks(index_)[
max(1, position - chunks_window_margin[1]):min(end,
position + chunks_window_margin[2])]
## Check if surrounding chunks are from the same source
is_same_source = sources(index)[max(1, position - chunks_window_margin[1]):min(end,
position + chunks_window_margin[2])] .== sources(index)[position]
is_same_source = sources(index_)[
max(1, position - chunks_window_margin[1]):min(end,
position + chunks_window_margin[2])] .== sources(index_)[position]
push!(context, "$(i). $(join(chunks_[is_same_source], "\n"))")
end
return context
end

function build_context!(contexter::AbstractContextBuilder,
index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...)
index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...)
throw(ArgumentError("Contexter $(typeof(contexter)) not implemented"))
end

# Mutating version that dispatches on the result to the underlying implementation
function build_context!(contexter::ContextEnumerator,
index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...)
index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...)
result.context = build_context(contexter, index, result.reranked_candidates; kwargs...)
return result
end
@@ -77,14 +85,14 @@ Default method for `answer!` method. Generates an answer using the `aigenerate`
struct SimpleAnswerer <: AbstractAnswerer end

function answer!(
answerer::AbstractAnswerer, index::AbstractChunkIndex, result::AbstractRAGResult;
answerer::AbstractAnswerer, index::AbstractDocumentIndex, result::AbstractRAGResult;
kwargs...)
throw(ArgumentError("Answerer $(typeof(answerer)) not implemented"))
end

"""
answer!(
answerer::SimpleAnswerer, index::AbstractChunkIndex, result::AbstractRAGResult;
answerer::SimpleAnswerer, index::AbstractDocumentIndex, result::AbstractRAGResult;
model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
template::Symbol = :RAGAnswerFromContext,
cost_tracker = Threads.Atomic{Float64}(0.0),
@@ -97,7 +105,7 @@ Generates an answer using the `aigenerate` function with the provided `result.co
# Arguments
- `answerer::SimpleAnswerer`: The method to use for generating the answer. Uses `aigenerate`.
- `index::AbstractChunkIndex`: The index containing chunks and sources.
- `index::AbstractDocumentIndex`: The index containing chunks and sources.
- `result::AbstractRAGResult`: The result containing the context and question to generate the answer for.
- `model::AbstractString`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`.
- `verbose::Bool`: If `true`, enables verbose logging.
@@ -106,7 +114,7 @@ Generates an answer using the `aigenerate` function with the provided `result.co
"""
function answer!(
answerer::SimpleAnswerer, index::AbstractChunkIndex, result::AbstractRAGResult;
answerer::SimpleAnswerer, index::AbstractDocumentIndex, result::AbstractRAGResult;
model::AbstractString = PT.MODEL_CHAT, verbose::Bool = true,
template::Symbol = :RAGAnswerFromContext,
cost_tracker = Threads.Atomic{Float64}(0.0),
@@ -154,7 +162,7 @@ Refines the answer by executing a web search using the Tavily API. This method a
struct TavilySearchRefiner <: AbstractRefiner end

function refine!(
refiner::AbstractRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
refiner::AbstractRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
kwargs...)
throw(ArgumentError("Refiner $(typeof(refiner)) not implemented"))
end
@@ -167,7 +175,7 @@ end
Simple no-op function for `refine`. It simply copies the `result.answer` and `result.conversations[:answer]` without any changes.
"""
function refine!(
refiner::NoRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
refiner::NoRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
kwargs...)
result.final_answer = result.answer
if haskey(result.conversations, :answer)
@@ -178,7 +186,7 @@ end

"""
refine!(
refiner::SimpleRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
template::Symbol = :RAGAnswerRefiner,
@@ -194,15 +202,15 @@ This method uses the same context as the original answer, however, it can be mod
# Arguments
- `refiner::SimpleRefiner`: The method to use for refining the answer. Uses `aigenerate`.
- `index::AbstractChunkIndex`: The index containing chunks and sources.
- `index::AbstractDocumentIndex`: The index containing chunks and sources.
- `result::AbstractRAGResult`: The result containing the context and question to generate the answer for.
- `model::AbstractString`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`.
- `verbose::Bool`: If `true`, enables verbose logging.
- `template::Symbol`: The template to use for the `aigenerate` function. Defaults to `:RAGAnswerRefiner`.
- `cost_tracker`: An atomic counter to track the cost of the operation.
"""
function refine!(
refiner::SimpleRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
refiner::SimpleRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
template::Symbol = :RAGAnswerRefiner,
@@ -232,7 +240,7 @@ end

"""
refine!(
refiner::TavilySearchRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
include_answer::Bool = true,
@@ -253,7 +261,7 @@ Note: The web results and web answer (if requested) will be added to the context
# Arguments
- `refiner::TavilySearchRefiner`: The method to use for refining the answer. Uses `aigenerate` with a web search template.
- `index::AbstractChunkIndex`: The index containing chunks and sources.
- `index::AbstractDocumentIndex`: The index containing chunks and sources.
- `result::AbstractRAGResult`: The result containing the context and question to generate the answer for.
- `model::AbstractString`: The model to use for generating the answer. Defaults to `PT.MODEL_CHAT`.
- `include_answer::Bool`: If `true`, includes the answer from Tavily in the web search.
@@ -280,7 +288,7 @@ pprint(result)
```
"""
function refine!(
refiner::TavilySearchRefiner, index::AbstractChunkIndex, result::AbstractRAGResult;
refiner::TavilySearchRefiner, index::AbstractDocumentIndex, result::AbstractRAGResult;
verbose::Bool = true,
model::AbstractString = PT.MODEL_CHAT,
include_answer::Bool = true,
@@ -345,13 +353,13 @@ Overload this method to add custom postprocessing steps, eg, logging, saving con
"""
struct NoPostprocessor <: AbstractPostprocessor end

function postprocess!(postprocessor::AbstractPostprocessor, index::AbstractChunkIndex,
function postprocess!(postprocessor::AbstractPostprocessor, index::AbstractDocumentIndex,
result::AbstractRAGResult; kwargs...)
throw(ArgumentError("Postprocessor $(typeof(postprocessor)) not implemented"))
end

function postprocess!(
::NoPostprocessor, index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...)
::NoPostprocessor, index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...)
return result
end

@@ -386,7 +394,7 @@ end

"""
generate!(
generator::AbstractGenerator, index::AbstractChunkIndex, result::AbstractRAGResult;
generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult;
verbose::Integer = 1,
api_kwargs::NamedTuple = NamedTuple(),
contexter::AbstractContextBuilder = generator.contexter,
@@ -416,7 +424,7 @@ Returns the mutated `result` with the `result.final_answer` and the full convers
# Arguments
- `generator::AbstractGenerator`: The `generator` to use for generating the answer. Can be `SimpleGenerator` or `AdvancedGenerator`.
- `index::AbstractChunkIndex`: The index containing chunks and sources.
- `index::AbstractDocumentIndex`: The index containing chunks and sources.
- `result::AbstractRAGResult`: The result containing the context and question to generate the answer for.
- `verbose::Integer`: If >0, enables verbose logging.
- `api_kwargs::NamedTuple`: API parameters that will be forwarded to ALL of the API calls (`aiembed`, `aigenerate`, and `aiextract`).
@@ -451,7 +459,7 @@ result = generate!(index, result)
```
"""
function generate!(
generator::AbstractGenerator, index::AbstractChunkIndex, result::AbstractRAGResult;
generator::AbstractGenerator, index::AbstractDocumentIndex, result::AbstractRAGResult;
verbose::Integer = 1,
api_kwargs::NamedTuple = NamedTuple(),
contexter::AbstractContextBuilder = generator.contexter,
@@ -494,7 +502,7 @@ end

# Set default behavior
DEFAULT_GENERATOR = SimpleGenerator()
function generate!(index::AbstractChunkIndex, result::AbstractRAGResult; kwargs...)
function generate!(index::AbstractDocumentIndex, result::AbstractRAGResult; kwargs...)
return generate!(DEFAULT_GENERATOR, index, result; kwargs...)
end

@@ -514,7 +522,7 @@ To customize the components, replace corresponding fields for each step of the R
end

"""
airag(cfg::AbstractRAGConfig, index::AbstractChunkIndex;
airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
question::AbstractString,
verbose::Integer = 1, return_all::Bool = false,
api_kwargs::NamedTuple = NamedTuple(),
@@ -533,7 +541,7 @@ Eg, use `subtypes(AbstractRetriever)` to find the available options.
# Arguments
- `cfg::AbstractRAGConfig`: The configuration for the RAG pipeline. Defaults to `RAGConfig()`, where you can swap sub-types to customize the pipeline.
- `index::AbstractChunkIndex`: The chunk index to search for relevant text.
- `index::AbstractDocumentIndex`: The chunk index to search for relevant text.
- `question::AbstractString`: The question to be answered.
- `return_all::Bool`: If `true`, returns the details used for RAG along with the response.
- `verbose::Integer`: If `>0`, enables verbose logging. The higher the number, the more nested functions will log.
@@ -559,7 +567,7 @@ Eg, use `subtypes(AbstractRetriever)` to find the available options.
- If `return_all` is `false`, returns the generated message (`msg`).
- If `return_all` is `true`, returns the detail of the full pipeline in `RAGResult` (see the docs).
See also `build_index`, `retrieve`, `generate!`, `RAGResult`, `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`.
See also `build_index`, `retrieve`, `generate!`, `RAGResult`, `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`, `ChunkKeywordsIndex`.
# Examples
@@ -612,9 +620,32 @@ kwargs = (
result = airag(cfg, index, question; kwargs...)
```
If you want to use hybrid retrieval (embeddings + BM25), you can easily create an additional index based on keywords
and pass them both into a `MultiIndex`.
You need to provide an explicit config, so the pipeline knows how to handle each index in the search similarity phase (`finder`).
```julia
# assume `index` is your existing embeddings-based index (eg, from `build_index(texts)`)
# create the multi-index with the keywords index
index_keywords = ChunkKeywordsIndex(index)
multi_index = MultiIndex([index, index_keywords])
# define the similarity measures for the indices that you have (same order)
finder = RT.MultiFinder([RT.CosineSimilarity(), RT.BM25Similarity()])
cfg = RAGConfig(; retriever=AdvancedRetriever(; processor=RT.KeywordsProcessor(), finder))
# Run the pipeline with the new hybrid retrieval (return the `RAGResult` to see the details)
result = airag(cfg, multi_index; question, return_all=true)
# Pretty-print the result
PT.pprint(result)
```
For easier manipulation of nested kwargs, see utilities `getpropertynested`, `setpropertynested`, `merge_kwargs_nested`.
"""
function airag(cfg::AbstractRAGConfig, index::AbstractChunkIndex;
function airag(cfg::AbstractRAGConfig, index::AbstractDocumentIndex;
question::AbstractString,
verbose::Integer = 1, return_all::Bool = false,
api_kwargs::NamedTuple = NamedTuple(),
@@ -656,7 +687,7 @@ end

# Default behavior
const DEFAULT_RAG_CONFIG = RAGConfig()
function airag(index::AbstractChunkIndex; question::AbstractString, kwargs...)
function airag(index::AbstractDocumentIndex; question::AbstractString, kwargs...)
return airag(DEFAULT_RAG_CONFIG, index; question, kwargs...)
end

