Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/codebase search improvements #199

Merged
merged 3 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Changed

- Performance improvements and cleaner use of compression retrievers on `CodebaseSearchAgent`.

### Added

- Codebase search now allows configuring how many results are returned by the search.

### Fixed

- Tantivy index volume was not pointing to the correct folder on `docker-compose.yml`.
- `ReviewAddressorManager` was committing duplicated file changes, leading to errors calling the repo client API.

### Chore

- Removed unused code on codebase search. This was left after the latest major refactor.

## [0.1.0-alpha.20] - 2025-01-19

### Upgrade Guide
Expand Down
40 changes: 7 additions & 33 deletions daiv/automation/agents/codebase_search/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
import logging
from typing import TYPE_CHECKING

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMListwiseRerank
from langchain_core.documents import Document
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.runnables import Runnable

from automation.agents import BaseAgent
from automation.conf import settings
from automation.retrievers import MultiQueryRephraseRetriever

if TYPE_CHECKING:
from collections.abc import Sequence

from langchain_core.retrievers import BaseRetriever

logger = logging.getLogger("daiv.agents")
Expand All @@ -37,34 +36,9 @@ def compile(self) -> Runnable:
Returns:
Runnable: The compiled agent
"""
return {
"query": RunnablePassthrough(),
"documents": MultiQueryRephraseRetriever.from_llm(self.retriever, llm=self.model),
} | RunnableLambda(
lambda inputs: self._compress_documents(inputs["documents"], inputs["query"]), name="compress_documents"
return ContextualCompressionRetriever(
base_compressor=LLMListwiseRerank.from_llm(
llm=self.get_model(temperature=0), top_n=settings.CODEBASE_SEARCH_TOP_N
),
base_retriever=MultiQueryRephraseRetriever.from_llm(self.retriever, llm=self.get_model(temperature=0.3)),
)

def get_model_kwargs(self) -> dict:
"""
Get the model kwargs with a redefined temperature to make the model more creative.

Returns:
dict: The model kwargs
"""
kwargs = super().get_model_kwargs()
kwargs["temperature"] = 0.5
return kwargs

def _compress_documents(self, documents: list[Document], query: str) -> Sequence[Document]:
"""
Compress the documents using a listwise reranker.

Args:
documents (Sequence[Document]): The documents to compress
query (str): The search query string

Returns:
Sequence[Document]: The compressed documents
"""
reranker = LLMListwiseRerank.from_llm(llm=self.model, top_n=5)
return reranker.compress_documents(documents, query)
51 changes: 0 additions & 51 deletions daiv/automation/agents/codebase_search/prompts.py

This file was deleted.

19 changes: 0 additions & 19 deletions daiv/automation/agents/codebase_search/schemas.py

This file was deleted.

33 changes: 0 additions & 33 deletions daiv/automation/agents/codebase_search/state.py

This file was deleted.

4 changes: 2 additions & 2 deletions daiv/automation/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class AutomationSettings(BaseSettings):
),
)
# Codebase search settings
CODEBASE_SEARCH_MAX_TRANSFORMATIONS: int = Field(
default=2, description="Maximum number of transformations to apply to the query."
CODEBASE_SEARCH_TOP_N: int = Field(
default=10, description="Maximum number of documents to return from the codebase search."
)


Expand Down
65 changes: 49 additions & 16 deletions daiv/automation/retrievers.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,42 @@
import textwrap
from itertools import chain
from operator import itemgetter
from typing import override

from langchain.retrievers.multi_query import DEFAULT_QUERY_PROMPT, LineListOutputParser, MultiQueryRetriever
from langchain.retrievers.multi_query import LineListOutputParser, MultiQueryRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.messages import SystemMessage
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnableLambda, RunnableParallel

OUTPUT_FORMAT_PROMPT = (
"\nThe output should be a list of queries, separated by newlines, with no numbering or additional formatting."
DEFAULT_QUERY_PROMPT = PromptTemplate(
input_variables=["question"],
template=textwrap.dedent(
"""\
You are an AI language model assistant. Your task is to generate 3 different versions of the given user question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question, your goal is to help the user overcome some of the limitations of distance-based similarity search. Provide these alternative
questions separated by newlines.
The output should be a list of queries, separated by newlines, with no numbering or additional formatting.
Original question: {question}
""" # NOQA: E501
),
)

REPHRASE_QUERY_PROMPT = PromptTemplate.from_template(
"""\
You are an assistant tasked with taking 3 natural language queries from a user and converting them into 3 queries for a vectorstore.
In this process, you strip out information that is not relevant for the retrieval task. Here are the user queries: {% for query in queries %}
- {{ query }}
{% endfor %}
""" # noqa: E501
+ OUTPUT_FORMAT_PROMPT,
template_format="jinja2",
REPHRASE_SYSTEM = SystemMessage(
textwrap.dedent(
"""\
You are an assistant tasked with taking 3 coding-related queries in natural language from a user and converting them into 3 queries optimized for a semantic search on a vector database.
In this process, you strip out information that is not relevant to improve semantic matching.
The output should be a list of queries, separated by newlines, with no numbering or additional formatting.
""" # NOQA: E501
)
)

REPHRASE_HUMAN = HumanMessagePromptTemplate.from_template(
"{% for query in queries %}{{ query }}\n{% endfor %}", template_format="jinja2"
)


Expand Down Expand Up @@ -47,10 +64,10 @@ def from_llm(
Returns:
MultiQueryRephraseRetriever
"""
prompt.template += OUTPUT_FORMAT_PROMPT
rephrase_prompt = ChatPromptTemplate.from_messages([REPHRASE_SYSTEM, REPHRASE_HUMAN])

output_parser = LineListOutputParser()
llm_chain = prompt | llm | output_parser | REPHRASE_QUERY_PROMPT | llm | output_parser
llm_chain = prompt | llm | output_parser | rephrase_prompt | llm | output_parser
return cls(retriever=retriever, llm_chain=llm_chain, include_original=include_original)

@override
Expand All @@ -66,3 +83,19 @@ def unique_union(self, documents: list[Document]) -> list[Document]:
"""
unique_docs: dict[str | None, Document] = {doc.metadata.get("id"): doc for doc in documents}
return list(unique_docs.values())

@override
def retrieve_documents(self, queries: list[str], run_manager: CallbackManagerForRetrieverRun) -> list[Document]:
"""
Run all LLM generated queries in parallel and return the results as a list of documents.

Args:
queries: query list

Returns:
List of retrieved Documents
"""
runnable = RunnableParallel({
f"query_{i}": itemgetter(i) | self.retriever for i, _q in enumerate(queries)
}) | RunnableLambda(lambda inputs: list(chain(*inputs.values())))
return runnable.invoke(queries, config={"callbacks": run_manager.get_child()})
7 changes: 4 additions & 3 deletions daiv/codebase/managers/review_addressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _process_review(self):
Each iteration of discussion resolution will be processed with the changes from the previous iterations,
ensuring that the file changes are processed correctly.
"""
file_changes: list[FileChange] = []
file_changes: set[FileChange] = set()
resolved_discussions: list[str] = []

merge_request_patches = self._extract_merge_request_diffs()
Expand Down Expand Up @@ -221,15 +221,16 @@ def _process_review(self):

if not state_after_run.tasks:
if files_to_commit := reviewer_addressor.get_files_to_commit():
file_changes.extend(files_to_commit)
# Use set and update method to avoid duplicates
file_changes.update(files_to_commit)

if result and ("response" not in result or not result["response"]):
# If the response is not in the result or is empty, it means the discussion was resolved,
# no further action would be needed.
resolved_discussions.append(context.discussion.id)

if file_changes:
self._commit_changes(file_changes=file_changes, thread_id=thread_id)
self._commit_changes(file_changes=list(file_changes), thread_id=thread_id)

if resolved_discussions:
for discussion_id in resolved_discussions:
Expand Down
2 changes: 1 addition & 1 deletion data/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
!.gitignore
!media/
!static/
!tantivy_index/
!tantivy_index_v1/
File renamed without changes.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ x-app-defaults: &x_app_default
volumes:
- ./data/static:/home/app/data/static
- ./data/media:/home/app/data/media
- ./data/tantivy_index:/home/app/data/tantivy_index
- ./data/tantivy_index_v1:/home/app/data/tantivy_index_v1
- ./data/models:/home/app/.cache/huggingface/hub/
- .:/home/app/src
depends_on:
Expand Down
Loading