Improved resilience of agents. Fixed bugs on executions. (#219)
* Improved resilience of agents. Fixed bugs on executions.

* Fixed linting.

* Fixed unittests.
srtab authored Jan 30, 2025
1 parent 4dc9955 commit 5cd4e68
Showing 19 changed files with 347 additions and 292 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -12,16 +12,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added custom Django checks to ensure the API keys for the models configured are set.
- Added `REASONING` setting to the `chat` app to enable/disable reasoning tags on the chat completion. This allows tools like `OpenWebUI` to show a loader indicating that the agent is thinking.
- Added `external_link` to the `CodeSnippet` tool to allow linking to the codebase file on the Repository UI.
- Added fallback model support to `BaseAgent` to streamline the fallback logic across all agents.

### Changed

- Improved `README.md` to include more information about the project and how to run it locally.
- Improved `CodebaseQAAgent` logic to consider the conversation history and behave more agentically by following the ReAct pattern. The answer will now include a `References` section with links to the codebase files that were used to answer the question.
- Moved DeepSeek integration to langchain official: `langchain-deepseek-official`.
- Now `CodebaseQAAgent` will use the fallback model if the main model is not available.

### Fixed

- Filled placeholder on the `LICENSE` file with correct copyright information.
- Agents' model instantiation was not overriding the model name as expected, leading to incorrect model provider handling.
- `ReActAgent` was not calling the structured output in some cases, which was not the expected behavior. Now it always calls the structured output if one is defined, even if the agent does not call the tool.
- `FileChange` is not hashable, leading to errors when storing it in a set in `ReviewAddressorManager`. Created a structure that stores file changes without duplicating entries for the same file.
- Assessment on `ReviewAddressorAgent` was not calling the structured output in some cases, leading to errors. Now it is always called.

## [0.1.0-alpha.21] - 2025-01-24

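For context, the fallback support called out above reduces to building a second chat model and chaining it with LangChain's `with_fallbacks`. A minimal sketch of the pattern, assuming `init_chat_model` resolves the provider from the model name; the model names are illustrative, not the project's configured defaults:

```python
# Minimal sketch of the fallback pattern this commit adds around BaseAgent.
# Model names are illustrative; DAIV resolves its own from settings.
from langchain.chat_models.base import init_chat_model

primary = init_chat_model(model="claude-3-5-sonnet-latest", temperature=0)
fallback = init_chat_model(model="gpt-4o-mini", temperature=0)

# Requests sent to `resilient` are retried on the fallback model whenever
# the primary model errors out (rate limits, provider outages, ...).
resilient = primary.with_fallbacks([fallback])
print(resilient.invoke("Reply with OK.").content)
```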
2 changes: 1 addition & 1 deletion daiv/accounts/managers.py
@@ -15,7 +15,7 @@
T = TypeVar("T", bound="APIKey")


-class APIKeyManager(models.Manager, Generic[T]):
+class APIKeyManager(models.Manager, Generic[T]): # noqa: UP046
"""
Manager for the APIKey model.
"""
29 changes: 14 additions & 15 deletions daiv/automation/agents/base.py
@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from decimal import Decimal
from enum import StrEnum
-from typing import TYPE_CHECKING, TypeVar, cast
+from typing import TYPE_CHECKING, Generic, TypeVar, cast

from langchain.chat_models.base import _attempt_infer_model_provider, init_chat_model
from langchain_community.callbacks import OpenAICallbackHandler
@@ -29,28 +29,32 @@ class ModelProvider(StrEnum):
T = TypeVar("T", bound=Runnable)


-class BaseAgent[T: Runnable](ABC):
+class BaseAgent(ABC, Generic[T]): # noqa: UP046
"""
Base agent class for creating agents that interact with a model.
"""

agent: T

model_name: str = settings.GENERIC_COST_EFFICIENT_MODEL_NAME
+fallback_model_name: str | None = None

def __init__(
self,
*,
run_name: str | None = None,
model_name: str | None = None,
+fallback_model_name: str | None = None,
usage_handler: OpenAICallbackHandler | None = None,
checkpointer: PostgresSaver | None = None,
):
self.run_name = run_name or self.__class__.__name__
self.model_name = model_name or self.model_name
+self.fallback_model_name = fallback_model_name or self.fallback_model_name
self.usage_handler = usage_handler or OpenAICallbackHandler()
self.checkpointer = checkpointer
-self.model = self.get_model()
+self.model = self.get_model(model=self.model_name)
+self.fallback_model = self.get_model(model=self.fallback_model_name) if self.fallback_model_name else None
self.agent = self.compile().with_config(self.get_config())

@abstractmethod
@@ -91,15 +95,10 @@ def get_model_kwargs(self, **kwargs) -> dict:
# If needed, we can increase this value using the configurable field.
_kwargs["max_tokens"] = "2048"
_kwargs["model_kwargs"]["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
-elif model_provider == ModelProvider.DEEPSEEK:
-assert settings.DEEPSEEK_API_KEY is not None, "DEEPSEEK_API_KEY is not set"
-
-_kwargs["model_provider"] = "openai"
-_kwargs["base_url"] = settings.DEEPSEEK_API_BASE
-_kwargs["api_key"] = settings.DEEPSEEK_API_KEY
-elif model_provider == ModelProvider.GOOGLE_GENAI:
-# otherwise it will be inferred as google_vertexai
-_kwargs["model_provider"] = "google_genai"
+elif model_provider in [ModelProvider.DEEPSEEK, ModelProvider.GOOGLE_GENAI]:
+# otherwise google_genai will be inferred as google_vertexai
+# deepseek is not yet inferred by langchain
+_kwargs["model_provider"] = model_provider
return _kwargs

def get_config(self) -> RunnableConfig:
@@ -132,15 +131,15 @@ def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
"""
return self.model.get_num_tokens_from_messages(messages)

-def get_max_token_value(self) -> int:
+def get_max_token_value(self, model_name: str) -> int:
"""
Get the maximum token value for the model.
Returns:
int: The maximum token value
"""

-match BaseAgent.get_model_provider(self.model_name):
+match BaseAgent.get_model_provider(model_name):
case ModelProvider.ANTHROPIC:
return 8192

@@ -157,7 +156,7 @@ def get_max_token_value(self) -> int:
return 8192

case _:
-raise ValueError(f"Unknown provider for model {self.model_name}")
+raise ValueError(f"Unknown provider for model {model_name}")

@staticmethod
def get_model_provider(model_name: str) -> ModelProvider:
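Since `get_max_token_value` now takes the model name instead of reading `self.model_name`, callers can budget tokens for whichever model ends up serving the request. A hypothetical usage sketch; `SomeAgent` and the model names are placeholders, not part of this diff:

```python
# Hypothetical usage of the refactored helper; SomeAgent stands in for any
# BaseAgent subclass and the model names are illustrative.
agent = SomeAgent(
    model_name="claude-3-5-sonnet-latest",
    fallback_model_name="gpt-4o-mini",
)

# Budget prompts against the smaller of the two limits so that a request
# rerouted to the fallback model never overflows its context window.
budget = min(
    agent.get_max_token_value(agent.model_name),
    agent.get_max_token_value(agent.fallback_model_name),
)
```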
23 changes: 16 additions & 7 deletions daiv/automation/agents/codebase_qa/agent.py
@@ -1,5 +1,7 @@
+from __future__ import annotations

import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any, cast

from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import Runnable, RunnableLambda
@@ -9,10 +11,14 @@
from automation.agents import BaseAgent
from automation.conf import settings
from automation.tools.repository import SearchCodeSnippetsTool
-from codebase.indexes import CodebaseIndex

from .prompts import data_collection_system

+if TYPE_CHECKING:
+from langchain_core.language_models.chat_models import BaseChatModel
+
+from codebase.indexes import CodebaseIndex

logger = logging.getLogger("daiv.agents")


@@ -43,19 +49,22 @@ class CodebaseQAAgent(BaseAgent[Runnable[dict[str, Any], FinalAnswer]]):
"""

model_name = settings.CODING_COST_EFFICIENT_MODEL_NAME
+fallback_model_name = settings.GENERIC_COST_EFFICIENT_MODEL_NAME

-def __init__(self, index: CodebaseIndex):
+def __init__(self, *args, index: CodebaseIndex, **kwargs):
self.index = index
-super().__init__()
+super().__init__(*args, **kwargs)

def compile(self) -> Runnable:
-return RunnableLambda(self._execute_react_agent) | self.model.with_structured_output(FinalAnswer)
+return RunnableLambda(self._execute_react_agent) | self.model.with_structured_output(
+FinalAnswer
+).with_fallbacks([cast("BaseChatModel", self.fallback_model).with_structured_output(FinalAnswer)])

def _execute_react_agent(self, inputs):
react_agent = create_react_agent(
-self.model,
+self.model.with_fallbacks([cast("BaseChatModel", self.fallback_model)]),
tools=[SearchCodeSnippetsTool(api_wrapper=self.index)],
-state_modifier=ChatPromptTemplate.from_messages([data_collection_system, MessagesPlaceholder("messages")]),
+prompt=ChatPromptTemplate.from_messages([data_collection_system, MessagesPlaceholder("messages")]),
)
result = react_agent.invoke(inputs)

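The resilient ReAct wiring used in `compile()` above can be reproduced standalone; a sketch assuming langgraph's prebuilt helper, with illustrative models and an empty tool list (the real agent registers `SearchCodeSnippetsTool`):

```python
# Standalone sketch of the pattern above: a ReAct loop driven by a primary
# model that falls back to a second one. Models are illustrative.
from langchain.chat_models.base import init_chat_model
from langgraph.prebuilt import create_react_agent

model = init_chat_model(model="claude-3-5-sonnet-latest")
fallback = init_chat_model(model="gpt-4o-mini")

react_agent = create_react_agent(
    model.with_fallbacks([fallback]),  # the loop survives primary-model outages
    tools=[],  # the real agent passes SearchCodeSnippetsTool here
)
result = react_agent.invoke({"messages": [("user", "Where is auth handled?")]})
```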
12 changes: 9 additions & 3 deletions daiv/automation/agents/codebase_search/agent.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMListwiseRerank
@@ -13,8 +13,10 @@
from automation.retrievers import MultiQueryRephraseRetriever

if TYPE_CHECKING:
+from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.retrievers import BaseRetriever


logger = logging.getLogger("daiv.agents")


@@ -24,6 +26,7 @@ class CodebaseSearchAgent(BaseAgent[Runnable[str, list[Document]]]):
"""

model_name = settings.CODING_COST_EFFICIENT_MODEL_NAME
+fallback_model_name = settings.GENERIC_COST_EFFICIENT_MODEL_NAME

def __init__(self, retriever: BaseRetriever, rephrase: bool = True, *args, **kwargs):
self.retriever = retriever
@@ -38,13 +41,16 @@ def compile(self) -> Runnable:
Runnable: The compiled agent
"""
if self.rephrase:
-base_retriever = MultiQueryRephraseRetriever.from_llm(self.retriever, llm=self.get_model(temperature=0.3))
+base_retriever = MultiQueryRephraseRetriever.from_llm(
+self.retriever, llm=self.model.with_fallbacks([cast("BaseChatModel", self.fallback_model)])
+)
else:
base_retriever = self.retriever

return ContextualCompressionRetriever(
base_compressor=LLMListwiseRerank.from_llm(
-llm=self.get_model(temperature=0), top_n=settings.CODEBASE_SEARCH_TOP_N
+llm=self.model.with_fallbacks([cast("BaseChatModel", self.fallback_model)]),
+top_n=settings.CODEBASE_SEARCH_TOP_N,
),
base_retriever=base_retriever,
)
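For reference, the compression pipeline compiled above can be approximated with stock LangChain components; this sketch swaps the project's `MultiQueryRephraseRetriever` for LangChain's `MultiQueryRetriever` so it stays self-contained, and the model name and `top_n` are illustrative:

```python
# Self-contained sketch of the search pipeline: rephrase the query into
# several variants, retrieve, then listwise-rerank the merged results.
from langchain.chat_models.base import init_chat_model
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMListwiseRerank
from langchain.retrievers.multi_query import MultiQueryRetriever

llm = init_chat_model(model="gpt-4o-mini", temperature=0)

def build_search_retriever(base_retriever):
    rephrasing = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=llm)
    return ContextualCompressionRetriever(
        base_compressor=LLMListwiseRerank.from_llm(llm=llm, top_n=10),
        base_retriever=rephrasing,
    )
```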
16 changes: 11 additions & 5 deletions daiv/automation/agents/pipeline_fixer/agent.py
@@ -1,10 +1,11 @@
from __future__ import annotations

import logging
-from typing import Literal, cast
+from typing import TYPE_CHECKING, Literal, cast

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import END, START, CompiledStateGraph, StateGraph
-from langgraph.store.base import BaseStore
+from langgraph.store.base import BaseStore # noqa: TC002
from langgraph.store.memory import InMemoryStore

from automation.agents import BaseAgent
@@ -14,8 +15,6 @@
from automation.tools.sandbox import RunSandboxCommandsTool
from automation.tools.toolkits import ReadRepositoryToolkit, SandboxToolkit, WebSearchToolkit, WriteRepositoryToolkit
from automation.utils import file_changes_namespace
-from codebase.base import FileChange
-from codebase.clients import AllRepoClient
from codebase.indexes import CodebaseIndex
from core.config import RepositoryConfig

@@ -29,6 +28,13 @@
from .schemas import ActionPlanOutput, PipelineLogClassifierOutput
from .state import OverallState

+if TYPE_CHECKING:
+from langchain_core.runnables import RunnableConfig
+
+from codebase.base import FileChange
+from codebase.clients import AllRepoClient


logger = logging.getLogger("daiv.agents")


15 changes: 10 additions & 5 deletions daiv/automation/agents/pr_describer/agent.py
@@ -1,15 +1,21 @@
-from typing import NotRequired, TypedDict
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, NotRequired, TypedDict, cast

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable

from automation.agents import BaseAgent
from automation.conf import settings
-from codebase.base import FileChange

from .prompts import human, system
from .schemas import PullRequestDescriberOutput

+if TYPE_CHECKING:
+from langchain_core.language_models.chat_models import BaseChatModel
+
+from codebase.base import FileChange


class PullRequestDescriberInput(TypedDict):
changes: list[FileChange]
@@ -23,13 +29,12 @@ class PullRequestDescriberAgent(BaseAgent[Runnable[PullRequestDescriberInput, Pu
"""

model_name = settings.GENERIC_COST_EFFICIENT_MODEL_NAME
+fallback_model_name = settings.CODING_COST_EFFICIENT_MODEL_NAME

def compile(self) -> Runnable:
prompt = ChatPromptTemplate.from_messages([system, human]).partial(
branch_name_convention=None, extra_details={}
)
return prompt | self.model.with_structured_output(PullRequestDescriberOutput).with_fallbacks([
-self.get_model(model=settings.CODING_COST_EFFICIENT_MODEL_NAME).with_structured_output(
-PullRequestDescriberOutput
-)
+cast("BaseChatModel", self.fallback_model).with_structured_output(PullRequestDescriberOutput)
])
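The `with_structured_output(...).with_fallbacks([...])` chain above is the recurring pattern in this commit: both models are bound to the same schema before being chained, so a fallback run still returns a validated object. A minimal self-contained sketch with a placeholder schema and illustrative models:

```python
# Minimal sketch of structured output with a fallback model; the schema and
# model names are placeholders, not the project's.
from langchain.chat_models.base import init_chat_model
from pydantic import BaseModel

class Summary(BaseModel):
    title: str
    description: str

primary = init_chat_model(model="gpt-4o-mini")
fallback = init_chat_model(model="claude-3-5-haiku-latest")

chain = primary.with_structured_output(Summary).with_fallbacks(
    [fallback.with_structured_output(Summary)]
)
summary = chain.invoke("Describe a PR that renames the User model to Account.")
```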
30 changes: 14 additions & 16 deletions daiv/automation/agents/prebuilt.py
@@ -11,6 +11,7 @@
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.chat_agent_executor import AgentState
from langgraph.prebuilt.tool_node import ToolNode
+from langgraph.store.base import BaseStore # noqa: TC002
from openai import InternalServerError as OpenAIInternalServerError
from pydantic import BaseModel, ValidationError # noqa: TCH002

@@ -23,7 +24,6 @@

from langchain_core.prompts.chat import MessageLikeRepresentation
from langchain_core.tools.base import BaseTool
-from langgraph.store.base import BaseStore

from codebase.clients import AllRepoClient

@@ -52,15 +52,13 @@ def __init__(
*args,
with_structured_output: type[BaseModel] | None = None,
store: BaseStore | None = None,
-fallback_model_name: str | None = None,
**kwargs,
):
self.tool_classes = tools
self.with_structured_output = with_structured_output
self.store = store
self.structured_tool_name = None
self.state_class: type[AgentState] = AgentState
-self.fallback_model_name = fallback_model_name
if self.with_structured_output:
self.tool_classes.append(self.with_structured_output)
self.structured_tool_name = self.with_structured_output.model_json_schema()["title"]
@@ -117,9 +115,7 @@ def call_model(self, state: AgentState):
self.model_name,
self.fallback_model_name,
)
-llm_with_tools = self.get_model(model=self.fallback_model_name).bind_tools(
-self.tool_classes, **tools_kwargs
-)
+llm_with_tools = self.fallback_model.bind_tools(self.tool_classes, **tools_kwargs)
response = llm_with_tools.invoke(state["messages"])
else:
raise e
@@ -148,18 +144,20 @@ def respond(self, state: AgentState):
response = None

try:
+if not last_message.tool_calls:
+# this can happen if the agent doesn't use any tool, which is an edge case, but we need to handle it
+# the expected behavior is the agent calling the tool defined by with_structured_output
+raise ValidationError("No tool calls found in the last message.")

response = self.with_structured_output(**last_message.tool_calls[0]["args"])
except ValidationError:
logger.warning("[ReAcT] Error structuring output with tool args. Fallback to llm with_structured_output.")

-llm_with_structured_output = self.model.with_structured_output(self.with_structured_output)
-response = cast(
-"BaseModel",
-llm_with_structured_output.invoke(
-[HumanMessage(last_message.pretty_repr())],
-config={"configurable": {"model": settings.GENERIC_COST_EFFICIENT_MODEL_NAME}},
-),
-)
+llm_with_structured_output = self.get_model(
+model=settings.GENERIC_COST_EFFICIENT_MODEL_NAME
+).with_structured_output(self.with_structured_output)
+
+response = cast("BaseModel", llm_with_structured_output.invoke([HumanMessage(last_message.pretty_repr())]))

return {"response": response}

@@ -175,10 +173,10 @@ def should_continue(self, state: AgentState) -> Literal["respond", "continue", "
"""
last_message = cast("AIMessage", state["messages"][-1])

-if (
+if self.structured_tool_name and (
last_message.tool_calls
-and self.structured_tool_name
and last_message.tool_calls[0]["name"] == self.structured_tool_name
-or not last_message.tool_calls
):
return "respond"
+elif not last_message.tool_calls:
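The reworked `respond` above always yields structured output: it first tries to build the schema from the tool call and, failing that, forces `with_structured_output` on the last message. A simplified sketch of that control flow; it raises `ValueError` where the diff raises `ValidationError`, since pydantic v2's `ValidationError` cannot be constructed from a bare string, and the schema, message, and model are placeholders:

```python
# Simplified sketch of the respond() fix: prefer the structured-output tool
# call, otherwise force structured output on the raw message.
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel, ValidationError

class Output(BaseModel):
    answer: str

def respond(last_message: AIMessage, llm) -> Output:
    try:
        if not last_message.tool_calls:
            # Edge case: the agent answered without calling the output tool.
            raise ValueError("No tool calls found in the last message.")
        return Output(**last_message.tool_calls[0]["args"])
    except (ValueError, ValidationError):
        # Fall back to asking the model for the schema directly.
        structured_llm = llm.with_structured_output(Output)
        return structured_llm.invoke([HumanMessage(last_message.pretty_repr())])
```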