-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Major improvements on chat completion. (#198)
* Major improvements on chat completion. * Updated changelog. * Fixed missing migrations. Updated packages.
- Loading branch information
Showing
18 changed files
with
611 additions
and
582 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,4 +2,4 @@ | |
|
||
|
||
class OverallState(MessagesState): | ||
pass | ||
context: str |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,148 +1,70 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import cast | ||
from typing import TYPE_CHECKING | ||
|
||
from langgraph.constants import Send | ||
from langgraph.graph import END, START, StateGraph | ||
from langgraph.graph.state import CompiledStateGraph | ||
from langchain.retrievers.document_compressors import LLMListwiseRerank | ||
from langchain_core.documents import Document | ||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough | ||
|
||
from automation.agents import BaseAgent | ||
from automation.agents.codebase_search.schemas import GradeDocumentsOutput, ImprovedQueryOutput | ||
from automation.conf import settings | ||
from codebase.indexes import CodebaseIndex | ||
from automation.retrievers import MultiQueryRephraseRetriever | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Sequence | ||
|
||
from .prompts import grade_human, grade_system, re_write_human, re_write_system | ||
from .state import GradeDocumentState, OverallState | ||
from langchain_core.retrievers import BaseRetriever | ||
|
||
logger = logging.getLogger("daiv.agents") | ||
|
||
|
||
class CodebaseSearchAgent(BaseAgent[CompiledStateGraph]): | ||
class CodebaseSearchAgent(BaseAgent[Runnable[str, list[Document]]]): | ||
""" | ||
Agent to search for code snippets in the codebase. | ||
""" | ||
|
||
def __init__(self, index: CodebaseIndex, source_repo_id: str | None = None, source_ref: str | None = None): | ||
super().__init__() | ||
self.index = index | ||
self.source_repo_id = source_repo_id | ||
self.source_ref = source_ref | ||
|
||
def compile(self) -> CompiledStateGraph: | ||
workflow = StateGraph(OverallState) | ||
|
||
# Add nodes | ||
workflow.add_node("retrieve", self.retrieve) | ||
workflow.add_node("grade_document", self.grade_document) | ||
workflow.add_node("post_grade_document", self.post_grade_document) | ||
workflow.add_node("transform_query", self.transform_query) | ||
|
||
# Add edges | ||
workflow.add_edge(START, "retrieve") | ||
workflow.add_conditional_edges( | ||
"retrieve", self.should_grade_documents, ["transform_query", "grade_document", END] | ||
) | ||
workflow.add_edge("grade_document", "post_grade_document") | ||
workflow.add_conditional_edges("post_grade_document", self.should_transform_query, ["transform_query", END]) | ||
workflow.add_edge("transform_query", "retrieve") | ||
model_name = settings.CODING_COST_EFFICIENT_MODEL_NAME | ||
|
||
return workflow.compile() | ||
def __init__(self, retriever: BaseRetriever, *args, **kwargs): | ||
self.retriever = retriever | ||
super().__init__(*args, **kwargs) | ||
|
||
def retrieve(self, state: OverallState): | ||
def compile(self) -> Runnable: | ||
""" | ||
Retrieve documents from the codebase index. | ||
Compile the agent into a Runnable. | ||
Args: | ||
state (GraphState): The current state of the graph. | ||
Returns: | ||
Runnable: The compiled agent | ||
""" | ||
if self.source_repo_id and self.source_ref: | ||
# we need to update the index before retrieving the documents | ||
# because the codebase search agent needs to search for the codebase changes | ||
# and we need to make sure the index is updated before the agent starts retrieving the documents | ||
self.index.update(self.source_repo_id, self.source_ref) | ||
|
||
if self.source_repo_id and self.source_ref: | ||
documents = self.index.search(self.source_repo_id, self.source_ref, state["query"]) | ||
else: | ||
documents = self.index.search_all(state["query"]) | ||
|
||
return {"documents": documents, "iterations": state.get("iterations", 0) + 1} | ||
return { | ||
"query": RunnablePassthrough(), | ||
"documents": MultiQueryRephraseRetriever.from_llm(self.retriever, llm=self.model), | ||
} | RunnableLambda( | ||
lambda inputs: self._compress_documents(inputs["documents"], inputs["query"]), name="compress_documents" | ||
) | ||
|
||
def grade_document(self, state: GradeDocumentState): | ||
def get_model_kwargs(self) -> dict: | ||
""" | ||
Grade the relevance of the retrieved document to the query. | ||
Get the model kwargs with a redefined temperature to make the model more creative. | ||
Args: | ||
state (GraphState): The current state of the graph. | ||
""" | ||
grader_agent = self.model.with_structured_output(GradeDocumentsOutput, method="json_schema") | ||
|
||
messages = [ | ||
grade_system, | ||
grade_human.format( | ||
query=state["query"], query_intent=state["query_intent"], document=state["document"].page_content | ||
), | ||
] | ||
response = cast("GradeDocumentsOutput", grader_agent.invoke(messages)) | ||
|
||
if response.binary_score: | ||
logger.info("[grade_document] Document '%s' is relevant to the query", state["document"].metadata["source"]) | ||
return {"documents": []} | ||
return {"documents": [state["document"]]} | ||
|
||
def post_grade_document(self, state: OverallState): | ||
""" | ||
Post-process the grade of the document. | ||
Returns: | ||
dict: The model kwargs | ||
""" | ||
return {"documents": []} | ||
kwargs = super().get_model_kwargs() | ||
kwargs["temperature"] = 0.5 | ||
return kwargs | ||
|
||
def transform_query(self, state: OverallState): | ||
def _compress_documents(self, documents: list[Document], query: str) -> Sequence[Document]: | ||
""" | ||
Transform the query to improve retrieval. | ||
Compress the documents using a listwise reranker. | ||
Args: | ||
state (GraphState): The current state of the graph. | ||
""" | ||
messages = [re_write_system, re_write_human.format(query=state["query"], query_intent=state["query_intent"])] | ||
|
||
query_rewriter = self.model.with_structured_output(ImprovedQueryOutput, method="json_schema") | ||
response = cast( | ||
"ImprovedQueryOutput", query_rewriter.invoke(messages, config={"configurable": {"temperature": 0.7}}) | ||
) | ||
|
||
logger.info("[transform_query] Query '%s' improved to '%s'", state["query"], response.query) | ||
|
||
return {"query": response.query} | ||
documents (Sequence[Document]): The documents to compress | ||
query (str): The search query string | ||
def should_grade_documents(self, state: OverallState): | ||
Returns: | ||
Sequence[Document]: The compressed documents | ||
""" | ||
Check if we should transform the query. | ||
""" | ||
if not state["documents"]: | ||
if state["iterations"] < settings.CODEBASE_SEARCH_MAX_TRANSFORMATIONS: | ||
logger.info("[should_grade_documents] No documents retrieved. Moving to transform_query state.") | ||
return "transform_query" | ||
else: | ||
logger.info("[should_grade_documents] No documents retrieved. Ending the process.") | ||
return END | ||
|
||
logger.info("[should_grade_documents] Documents retrieved. Moving to grade_documents state.") | ||
return [ | ||
Send( | ||
"grade_document", {"document": document, "query": state["query"], "query_intent": state["query_intent"]} | ||
) | ||
for document in state["documents"] | ||
] | ||
|
||
def should_transform_query(self, state: OverallState): | ||
""" | ||
Check if we should transform the query. | ||
""" | ||
logger.info( | ||
"[should_transform_query] %d documents found (iteration: %d)", len(state["documents"]), state["iterations"] | ||
) | ||
if not state["documents"] and state["iterations"] < settings.CODEBASE_SEARCH_MAX_TRANSFORMATIONS: | ||
logger.info("[should_transform_query] No relevant documents found. Moving to transform_query state.") | ||
return "transform_query" | ||
if not state["documents"]: | ||
logger.info("[should_transform_query] No relevant documents found. Ending the process.") | ||
return END | ||
reranker = LLMListwiseRerank.from_llm(llm=self.model, top_n=5) | ||
return reranker.compress_documents(documents, query) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.