Major improvements on chat completion. (#198)
* Major improvements on chat completion.

* Updated changelog.

* Fixed missing migrations. Updated packages.
srtab authored Jan 20, 2025
1 parent 845276b commit 463a662
Showing 18 changed files with 611 additions and 582 deletions.
38 changes: 33 additions & 5 deletions CHANGELOG.md
@@ -5,24 +5,52 @@
 All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

-## Unreleased
+## [0.1.0-alpha.20] - 2025-01-19
+
+### Upgrade Guide
+
+- There were substantial changes to the codebase search engines, which were rewritten to improve the quality of the search results. You will need to run the command `update_index` with `reset_all=True` to update the indexes.

 ### Added

 - Show comment on the issue when the agent is replanning the tasks.
-- Repositories file paths are now being indexed too, allowing the agent to search for file paths. **You must run command `update_index` with `reset_all=True` to update the indexes to include file paths.**
+- Repositories file paths are now being indexed too, allowing the agent to search for file paths.
 - New option `reset_all` added to the `update_index` command to allow resetting indexes from all branches, not only the default branch.

 ### Changed

 - Improved `CodebaseQAAgent` response, even when no tools are called. Added a web search tool so the agent can search for answers when the codebase doesn't have the information.
-- Changed models url paths on Chat API from `api/v1/chat/models` -> `api/v1/models` to be more consistent with OpenAI API. **This is a breaking change, as it will affect all clients using the Chat API.**
+- Rewritten the lexical search engine to have a single index, allowing the agent to search for answers in multiple repositories at once.
+- Rewritten `CodebaseSearchAgent` to improve the quality of search results by using techniques such as generating multiple queries, rephrasing those queries, and compressing documents with listwise reranking.
+- Simplified `CodebaseQAAgent` to use `CodebaseSearchAgent` to search for answers instead of binding tools to the agent.
+- Changed the models URL path on the Chat API from `api/v1/chat/models` -> `api/v1/models` to be more consistent with the OpenAI API.
+- Default embedding model changed to `text-embedding-3-large` to improve the quality of the search results. The dimension of stored vectors was increased from 1536 to 2000.

 ### Fixed

 - Issues with images were not being processed correctly, leading to errors interpreting the image from the description.
 - Chat API was not returning the correct response structure when not in streaming mode.
-- Codebase retriever used for non-scoped indexes on `CodebaseQAAgent` was returning duplicate documents, from different branches. Now it's filtering always by repository default branch.
+- Codebase retriever used for non-scoped indexes on `CodebaseSearchAgent` was returning duplicate documents from different refs. Now it always filters by the repository default branch.

+### Chore
+
+- Updated dependencies:
+  - `django` from 5.1.4 to 5.1.5
+  - `duckduckgo-search` from 7.2.0 to 7.2.1
+  - `httpx` from 0.27.2 to 0.28.1
+  - `langchain-anthropic` from 0.3.2 to 0.3.3
+  - `langchain-openai` from 0.3.0 to 0.3.1
+  - `langchain-text-splitters` from 0.3.4 to 0.3.5
+  - `langgraph` from 0.2.61 to 0.2.64
+  - `langgraph-checkpoint-postgres` from 2.0.9 to 2.0.13
+  - `langsmith` from 0.2.10 to 0.2.11
+  - `psycopg` from 3.2.3 to 3.2.4
+  - `pyopenssl` from 24.3 to 25
+  - `pydantic` from 2.10.4 to 2.10.5
+  - `pytest-asyncio` from 0.25.1 to 0.25.2
+  - `python-gitlab` from 5.3.0 to 5.3.1
+  - `ruff` from 0.8.7 to 0.9.2
+  - `sentry-sdk` from 2.19.2 to 2.20
+  - `watchfiles` from 1.0.3 to 1.0.4

 ## [0.1.0-alpha.19] - 2025-01-10
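A note on the upgrade guide above: DAIV is a Django project, so the reindex can presumably be triggered from a shell or programmatically. A sketch with Django's `call_command` — the `update_index` command and its `reset_all` option come from the changelog, while the exact invocation style is an assumption:

    # Sketch only: assumes `update_index` is a Django management command that
    # exposes `reset_all` as a boolean option, as the changelog suggests.
    from django.core.management import call_command

    call_command("update_index", reset_all=True)

From a shell this would likely be `python manage.py update_index --reset-all`; check the command's `--help` for the actual flag spelling.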
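On the embedding change: `text-embedding-3-large` natively emits 3072-dimensional vectors but accepts a requested output size, which is presumably how the stored dimension of 2000 is reached. A sketch with `langchain-openai` (the `dimensions` parameter is real; its use here is inferred from the changelog, not read from this diff):

    from langchain_openai import OpenAIEmbeddings

    # Ask the API for shortened 2000-dimensional vectors, matching the new
    # stored vector size mentioned in the changelog.
    embeddings = OpenAIEmbeddings(model="text-embedding-3-large", dimensions=2000)
    vector = embeddings.embed_query("codebase search")
    assert len(vector) == 2000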
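On the breaking Chat API change: clients that list models must switch from `api/v1/chat/models` to `api/v1/models`. A hedged example using `httpx` (already a project dependency); the host is a placeholder:

    import httpx

    BASE_URL = "http://localhost:8000"  # placeholder DAIV host

    # Old path (no longer served): f"{BASE_URL}/api/v1/chat/models"
    response = httpx.get(f"{BASE_URL}/api/v1/models")
    response.raise_for_status()
    print(response.json())  # expected to mirror OpenAI's `GET /v1/models` shape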
3 changes: 2 additions & 1 deletion daiv/automation/agents/base.py
@@ -5,14 +5,15 @@
 from enum import StrEnum
 from typing import TYPE_CHECKING, Generic, TypeVar, cast

-from langchain.chat_models.base import BaseChatModel, _attempt_infer_model_provider, init_chat_model
+from langchain.chat_models.base import _attempt_infer_model_provider, init_chat_model
 from langchain_community.callbacks import OpenAICallbackHandler
 from langchain_core.runnables import Runnable, RunnableConfig
 from pydantic import BaseModel

 from automation.conf import settings

 if TYPE_CHECKING:
+    from langchain_core.language_models.chat_models import BaseChatModel
     from langchain_core.messages import BaseMessage
     from langchain_openai.chat_models import ChatOpenAI
     from langgraph.checkpoint.postgres import PostgresSaver
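The hunk above moves `BaseChatModel` from a runtime import into the `TYPE_CHECKING` block: the name is only used in annotations, so the import can be deferred to type-checking time. A minimal sketch of the pattern (the `describe` function is illustrative, not from this repo):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported for static type checkers only; never executed at runtime.
        from langchain_core.language_models.chat_models import BaseChatModel

    def describe(model: BaseChatModel) -> str:
        # With `from __future__ import annotations`, annotations are strings
        # at runtime, so the guarded import is all the checker needs.
        return type(model).__name__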
73 changes: 32 additions & 41 deletions daiv/automation/agents/codebase_qa/agent.py
@@ -2,15 +2,14 @@

 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.state import CompiledStateGraph
-from langgraph.prebuilt import ToolNode, tools_condition

 from automation.agents import BaseAgent
 from automation.conf import settings
 from automation.tools.repository import SearchCodeSnippetsTool
-from automation.tools.toolkits import WebSearchToolkit
+from automation.tools.web_search import WebSearchTool
 from codebase.indexes import CodebaseIndex

-from .prompts import system, system_query_or_respond
+from .prompts import system
 from .state import OverallState

 logger = logging.getLogger("daiv.agents")
@@ -21,64 +20,56 @@ class CodebaseQAAgent(BaseAgent[CompiledStateGraph]):
     Agent to answer questions about the codebase.
     """

-    model_name = settings.GENERIC_COST_EFFICIENT_MODEL_NAME
+    model_name = settings.CODING_COST_EFFICIENT_MODEL_NAME

     def __init__(self, index: CodebaseIndex):
         self.index = index
-        self.tools = [SearchCodeSnippetsTool(api_wrapper=index)] + WebSearchToolkit.create_instance().get_tools()
         super().__init__()

+    def get_model_kwargs(self) -> dict:
+        kwargs = super().get_model_kwargs()
+        kwargs["temperature"] = 0.3
+        return kwargs
+
     def compile(self) -> CompiledStateGraph:
         workflow = StateGraph(OverallState)

         # Add nodes
-        workflow.add_node("query_or_respond", self.query_or_respond)
-        workflow.add_node("tools", ToolNode(self.tools))
+        workflow.add_node("retrieve", self.retrieve)
         workflow.add_node("generate", self.generate)

         # Add edges
-        workflow.add_edge(START, "query_or_respond")
-        workflow.add_conditional_edges("query_or_respond", tools_condition, {END: END, "tools": "tools"})
-        workflow.add_edge("tools", "generate")
+        workflow.add_edge(START, "retrieve")
+        workflow.add_edge("retrieve", "generate")
         workflow.add_edge("generate", END)

         return workflow.compile()

-    def query_or_respond(self, state: OverallState):
+    def retrieve(self, state: OverallState):
         """
-        Generate tool call for retrieval or respond.
+        Retrieve the context.
         """
-        llm_with_tools = self.model.bind_tools(self.tools)
-        response = llm_with_tools.invoke([system_query_or_respond] + state["messages"])
-        return {"messages": [response]}
+        context = SearchCodeSnippetsTool(api_wrapper=self.index).invoke({
+            "query": state["messages"][-1].content,
+            "intent": "Searching the codebase.",
+        })
+        if not context:
+            context = WebSearchTool().invoke({
+                "query": state["messages"][-1].content,
+                "intent": "No code snippets found, searching the web.",
+            })
+        return {"context": context}

     def generate(self, state: OverallState):
         """
         Generate answer.
         """
-        tool_messages = []
-        for message in reversed(state["messages"]):
-            if message.type == "tool":
-                tool_messages.append(message)
-
-        if tool_messages:
-            docs_content = "\n\n".join(doc.content for doc in tool_messages[::-1])
-
-            conversation_messages = [
-                message
-                for message in state["messages"]
-                if message.type in ("human", "system") or (message.type == "ai" and not message.tool_calls)
-            ]
-            prompt = [
-                system.format(
-                    context=docs_content,
-                    codebase_client=self.index.repo_client.client_slug,
-                    codebase_url=self.index.repo_client.codebase_url,
-                )
-            ] + conversation_messages
-
-            response = self.model.invoke(prompt)
-        else:
-            response = state["messages"][-1]
-
-        return {"messages": [response]}
+        prompt = [
+            system.format(
+                context=state["context"],
+                codebase_client=self.index.repo_client.client_slug,
+                codebase_url=self.index.repo_client.codebase_url,
+            )
+        ] + state["messages"]
+
+        return {"messages": [self.model.invoke(prompt)]}
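With this rewrite the QA graph is a straight line (START → retrieve → generate → END), and the retrieved context travels in the state's `context` key instead of in tool messages. A usage sketch, assuming an already-constructed `CodebaseIndex` (the wiring below is illustrative, not taken from this diff):

    from langchain_core.messages import HumanMessage

    agent = CodebaseQAAgent(index=codebase_index)  # codebase_index: a CodebaseIndex built elsewhere
    graph = agent.compile()

    result = graph.invoke({"messages": [HumanMessage("Where is the Chat API router defined?")]})
    print(result["messages"][-1].content)  # the generated answer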
2 changes: 1 addition & 1 deletion daiv/automation/agents/codebase_qa/state.py
@@ -2,4 +2,4 @@


 class OverallState(MessagesState):
-    pass
+    context: str
160 changes: 41 additions & 119 deletions daiv/automation/agents/codebase_search/agent.py
@@ -1,148 +1,70 @@
 from __future__ import annotations

 import logging
-from typing import cast
+from typing import TYPE_CHECKING

-from langgraph.constants import Send
-from langgraph.graph import END, START, StateGraph
-from langgraph.graph.state import CompiledStateGraph
+from langchain.retrievers.document_compressors import LLMListwiseRerank
+from langchain_core.documents import Document
+from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough

 from automation.agents import BaseAgent
-from automation.agents.codebase_search.schemas import GradeDocumentsOutput, ImprovedQueryOutput
 from automation.conf import settings
-from codebase.indexes import CodebaseIndex
+from automation.retrievers import MultiQueryRephraseRetriever

-from .prompts import grade_human, grade_system, re_write_human, re_write_system
-from .state import GradeDocumentState, OverallState
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from langchain_core.retrievers import BaseRetriever

 logger = logging.getLogger("daiv.agents")


-class CodebaseSearchAgent(BaseAgent[CompiledStateGraph]):
+class CodebaseSearchAgent(BaseAgent[Runnable[str, list[Document]]]):
     """
     Agent to search for code snippets in the codebase.
     """

-    def __init__(self, index: CodebaseIndex, source_repo_id: str | None = None, source_ref: str | None = None):
-        super().__init__()
-        self.index = index
-        self.source_repo_id = source_repo_id
-        self.source_ref = source_ref
-
-    def compile(self) -> CompiledStateGraph:
-        workflow = StateGraph(OverallState)
-
-        # Add nodes
-        workflow.add_node("retrieve", self.retrieve)
-        workflow.add_node("grade_document", self.grade_document)
-        workflow.add_node("post_grade_document", self.post_grade_document)
-        workflow.add_node("transform_query", self.transform_query)
-
-        # Add edges
-        workflow.add_edge(START, "retrieve")
-        workflow.add_conditional_edges(
-            "retrieve", self.should_grade_documents, ["transform_query", "grade_document", END]
-        )
-        workflow.add_edge("grade_document", "post_grade_document")
-        workflow.add_conditional_edges("post_grade_document", self.should_transform_query, ["transform_query", END])
-        workflow.add_edge("transform_query", "retrieve")
+    model_name = settings.CODING_COST_EFFICIENT_MODEL_NAME

-        return workflow.compile()
+    def __init__(self, retriever: BaseRetriever, *args, **kwargs):
+        self.retriever = retriever
+        super().__init__(*args, **kwargs)

-    def retrieve(self, state: OverallState):
+    def compile(self) -> Runnable:
         """
-        Retrieve documents from the codebase index.
+        Compile the agent into a Runnable.

-        Args:
-            state (GraphState): The current state of the graph.
+        Returns:
+            Runnable: The compiled agent
         """
-        if self.source_repo_id and self.source_ref:
-            # we need to update the index before retrieving the documents
-            # because the codebase search agent needs to search for the codebase changes
-            # and we need to make sure the index is updated before the agent starts retrieving the documents
-            self.index.update(self.source_repo_id, self.source_ref)
-
-        if self.source_repo_id and self.source_ref:
-            documents = self.index.search(self.source_repo_id, self.source_ref, state["query"])
-        else:
-            documents = self.index.search_all(state["query"])
-
-        return {"documents": documents, "iterations": state.get("iterations", 0) + 1}
+        return {
+            "query": RunnablePassthrough(),
+            "documents": MultiQueryRephraseRetriever.from_llm(self.retriever, llm=self.model),
+        } | RunnableLambda(
+            lambda inputs: self._compress_documents(inputs["documents"], inputs["query"]), name="compress_documents"
+        )

-    def grade_document(self, state: GradeDocumentState):
+    def get_model_kwargs(self) -> dict:
         """
-        Grade the relevance of the retrieved document to the query.
+        Get the model kwargs with a redefined temperature to make the model more creative.

-        Args:
-            state (GraphState): The current state of the graph.
-        """
-        grader_agent = self.model.with_structured_output(GradeDocumentsOutput, method="json_schema")
-
-        messages = [
-            grade_system,
-            grade_human.format(
-                query=state["query"], query_intent=state["query_intent"], document=state["document"].page_content
-            ),
-        ]
-        response = cast("GradeDocumentsOutput", grader_agent.invoke(messages))
-
-        if response.binary_score:
-            logger.info("[grade_document] Document '%s' is relevant to the query", state["document"].metadata["source"])
-            return {"documents": []}
-        return {"documents": [state["document"]]}
-
-    def post_grade_document(self, state: OverallState):
-        """
-        Post-process the grade of the document.
+        Returns:
+            dict: The model kwargs
         """
-        return {"documents": []}
+        kwargs = super().get_model_kwargs()
+        kwargs["temperature"] = 0.5
+        return kwargs

-    def transform_query(self, state: OverallState):
+    def _compress_documents(self, documents: list[Document], query: str) -> Sequence[Document]:
         """
-        Transform the query to improve retrieval.
+        Compress the documents using a listwise reranker.

-        Args:
-            state (GraphState): The current state of the graph.
-        """
-        messages = [re_write_system, re_write_human.format(query=state["query"], query_intent=state["query_intent"])]
-
-        query_rewriter = self.model.with_structured_output(ImprovedQueryOutput, method="json_schema")
-        response = cast(
-            "ImprovedQueryOutput", query_rewriter.invoke(messages, config={"configurable": {"temperature": 0.7}})
-        )
-
-        logger.info("[transform_query] Query '%s' improved to '%s'", state["query"], response.query)
-
-        return {"query": response.query}
-
-    def should_grade_documents(self, state: OverallState):
-        """
-        Check if we should transform the query.
-        """
-        if not state["documents"]:
-            if state["iterations"] < settings.CODEBASE_SEARCH_MAX_TRANSFORMATIONS:
-                logger.info("[should_grade_documents] No documents retrieved. Moving to transform_query state.")
-                return "transform_query"
-            else:
-                logger.info("[should_grade_documents] No documents retrieved. Ending the process.")
-                return END
-
-        logger.info("[should_grade_documents] Documents retrieved. Moving to grade_documents state.")
-        return [
-            Send(
-                "grade_document", {"document": document, "query": state["query"], "query_intent": state["query_intent"]}
-            )
-            for document in state["documents"]
-        ]
-
-    def should_transform_query(self, state: OverallState):
-        """
-        Check if we should transform the query.
-        """
-        logger.info(
-            "[should_transform_query] %d documents found (iteration: %d)", len(state["documents"]), state["iterations"]
-        )
-        if not state["documents"] and state["iterations"] < settings.CODEBASE_SEARCH_MAX_TRANSFORMATIONS:
-            logger.info("[should_transform_query] No relevant documents found. Moving to transform_query state.")
-            return "transform_query"
-        if not state["documents"]:
-            logger.info("[should_transform_query] No relevant documents found. Ending the process.")
-            return END
+        Args:
+            documents (Sequence[Document]): The documents to compress
+            query (str): The search query string
+
+        Returns:
+            Sequence[Document]: The compressed documents
+        """
+        reranker = LLMListwiseRerank.from_llm(llm=self.model, top_n=5)
+        return reranker.compress_documents(documents, query)
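`MultiQueryRephraseRetriever` is DAIV's own retriever, but the pipeline it feeds — fan the query out into several rephrasings, retrieve for each, then compress the merged results with an LLM listwise reranker — can be approximated with stock LangChain components. A rough sketch under that assumption (the model choice and `base_retriever` are placeholders):

    from langchain.retrievers.document_compressors import LLMListwiseRerank
    from langchain.retrievers.multi_query import MultiQueryRetriever
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)  # placeholder model

    # base_retriever: any LangChain BaseRetriever (e.g. a vector store retriever).
    multi_query = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=llm)

    query = "where are GitLab webhooks handled?"
    documents = multi_query.invoke(query)

    # Listwise rerank: the LLM orders candidates by relevance and keeps the
    # top 5, mirroring _compress_documents above.
    reranker = LLMListwiseRerank.from_llm(llm=llm, top_n=5)
    top_documents = reranker.compress_documents(documents, query)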
2 changes: 1 addition & 1 deletion daiv/automation/agents/issue_addressor/agent.py
@@ -6,6 +6,7 @@
 from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.state import CompiledStateGraph
+from langgraph.store.base import BaseStore  # noqa: TC002
 from langgraph.store.memory import InMemoryStore

 from automation.agents import BaseAgent
@@ -33,7 +34,6 @@

 if TYPE_CHECKING:
     from langchain_core.runnables import RunnableConfig
-    from langgraph.store.base import BaseStore

     from codebase.base import FileChange
     from codebase.clients import AllRepoClient
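These two hunks move the `BaseStore` import out of the `TYPE_CHECKING` block, with `# noqa: TC002` silencing the linter rule that would push it back. A plausible reason (an assumption here, not stated in the diff) is that LangGraph resolves node-signature annotations at runtime to decide what to inject, so the name must be importable when the graph is built:

    from langgraph.store.base import BaseStore  # noqa: TC002 — needed at runtime

    # Hypothetical node: if `BaseStore` lived only under TYPE_CHECKING, resolving
    # this annotation at graph-build time would raise a NameError.
    def plan_issue(state: dict, *, store: BaseStore) -> dict:
        store.put(("plans",), "issue-42", {"steps": ["triage", "patch"]})
        return state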