Skip to content

Commit

Permalink
Fixed max_token on replacer agent.
Browse files Browse the repository at this point in the history
  • Loading branch information
srtab committed Nov 1, 2024
1 parent fc648d0 commit 838f4c6
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 140 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,4 @@ cython_debug/
#.idea/

config.secrets.env
state-snapshots.txt
33 changes: 31 additions & 2 deletions daiv/automation/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from abc import ABC, abstractmethod
from decimal import Decimal
from enum import Enum
from enum import StrEnum
from typing import TYPE_CHECKING, Generic, TypeVar, cast

from langchain.chat_models.base import BaseChatModel, _attempt_infer_model_provider, init_chat_model
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_community.callbacks import OpenAICallbackHandler
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_openai.chat_models import ChatOpenAI
from pydantic import BaseModel
Expand All @@ -22,7 +24,7 @@
GENERIC_COST_EFFICIENT_MODEL_NAME = "gpt-4o-mini-2024-07-18"


class ModelProvider(Enum):
class ModelProvider(StrEnum):
ANTHROPIC = "anthropic"
OPENAI = "openai"

Expand Down Expand Up @@ -107,13 +109,40 @@ def draw_mermaid(self):
"""
return self.agent.get_graph().draw_mermaid()

def get_num_tokens(self, text: str) -> int:
"""
Get the number of tokens in a text.
Args:
text (str): The text
Returns:
int: The number of tokens
"""
if _attempt_infer_model_provider(self.model_name) == ModelProvider.ANTHROPIC:
return cast(ChatAnthropic, self.model)._client.count_tokens(text)
return self.model.get_num_tokens(text)

def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
"""
Get the number of tokens from a list of messages.
Args:
messages (list[BaseMessage]): The messages
Returns:
int: The number of tokens
"""
return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)

def get_max_token_value(self) -> int:
"""
Get the maximum token value for the model.
Returns:
int: The maximum token value
"""

match _attempt_infer_model_provider(self.model_name):
case ModelProvider.ANTHROPIC:
return 8192 if self.model_name.startswith("claude-3-5-sonnet") else 4096
Expand Down
31 changes: 18 additions & 13 deletions daiv/automation/agents/snippet_replacer/agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from functools import cached_property
from typing import TypedDict

from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda
from pydantic import BaseModel

from automation.agents import CODING_COST_EFFICIENT_MODEL_NAME, BaseAgent

Expand All @@ -12,14 +12,14 @@
from .utils import find_original_snippet


class SnippetReplacerInput(BaseModel):
class SnippetReplacerInput(TypedDict):
"""
Input for the SnippetReplacerAgent.
"""

original_snippet: str = ""
replacement_snippet: str = ""
content: str = ""
original_snippet: str
replacement_snippet: str
content: str


class SnippetReplacerAgent(BaseAgent[Runnable[SnippetReplacerInput, SnippetReplacerOutput | str]]):
Expand All @@ -36,26 +36,31 @@ def compile(self) -> Runnable:
Returns:
CompiledStateGraph | Runnable: The compiled agent.
"""
return self._prompt | RunnableLambda(self._route)
return RunnableLambda(self._route) | RunnableLambda(self._post_process)

def _route(self, input_data: SnippetReplacerInput) -> Runnable:
if self.validate_max_token_not_exceeded(input_data):
return self.model.with_structured_output(SnippetReplacerOutput, method="json_schema")
return self._prompt | self.model.with_structured_output(SnippetReplacerOutput, method="json_schema")
return RunnableLambda(self._replace_content_snippet)

def _replace_content_snippet(self, input_data: SnippetReplacerInput) -> SnippetReplacerOutput | str:
original_snippet_found = find_original_snippet(
input_data.original_snippet, input_data.content, initial_line_threshold=1
input_data["original_snippet"], input_data["content"], initial_line_threshold=1
)
if not original_snippet_found:
return "error: Original snippet not found."

replaced_content = input_data.content.replace(original_snippet_found, input_data.replacement_snippet)
replaced_content = input_data["content"].replace(original_snippet_found, input_data["replacement_snippet"])
if not replaced_content:
return "error: Snippet replacement failed."

return SnippetReplacerOutput(content=replaced_content)

def _post_process(self, output: SnippetReplacerOutput | str) -> SnippetReplacerOutput | str:
if isinstance(output, SnippetReplacerOutput) and not output.content.endswith("\n"):
output.content += "\n"
return output

def validate_max_token_not_exceeded(self, input_data: SnippetReplacerInput) -> bool: # noqa: A002
"""
Validate that the messages does not exceed the maximum token value of the model.
Expand All @@ -67,12 +72,12 @@ def validate_max_token_not_exceeded(self, input_data: SnippetReplacerInput) -> b
bool: True if the text does not exceed the maximum token value, False otherwise
"""
prompt = self._prompt
filled_messages = prompt.invoke(input_data.model_dump()).to_messages()
empty_messages = prompt.invoke(SnippetReplacerInput().model_dump()).to_messages()
filled_messages = prompt.invoke(input_data).to_messages()
empty_messages = prompt.invoke({"original_snippet": "", "replacement_snippet": "", "content": ""}).to_messages()
# get the number of tokens used in the messages
used_tokens = self.model.get_num_tokens_from_messages(filled_messages)
used_tokens = self.get_num_tokens_from_messages(filled_messages)
# try to anticipate the number of tokens needed for the output
estimated_needed_tokens = used_tokens - self.model.get_num_tokens_from_messages(empty_messages)
estimated_needed_tokens = used_tokens - self.get_num_tokens_from_messages(empty_messages)
return estimated_needed_tokens <= self.get_max_token_value() - used_tokens

@cached_property
Expand Down
70 changes: 27 additions & 43 deletions daiv/automation/tools/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

import logging
import textwrap
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

from langchain_core.prompts.string import jinja2_formatter
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from automation.agents.codebase_search import CodebaseSearchAgent
from automation.agents.snippet_replacer.agent import SnippetReplacerAgent
from automation.agents.snippet_replacer.schemas import SnippetReplacerOutput
from codebase.base import FileChange, FileChangeAction
from codebase.clients import RepoClient
from codebase.indexes import CodebaseIndex
Expand Down Expand Up @@ -51,11 +50,14 @@ class SearchCodeSnippetsTool(BaseTool):
""" # noqa: E501
).format(retrieve_file_content_name=RETRIEVE_FILE_CONTENT_NAME)
args_schema: type[BaseModel] = SearchCodeSnippetsInput
handle_validation_error: bool = True

source_repo_id: str = Field(description="The repository ID to search in.")
source_ref: str = Field(description="The branch or commit to search in.")
source_repo_id: str = Field(..., description="The repository ID to search in.")
source_ref: str = Field(..., description="The branch or commit to search in.")

api_wrapper: CodebaseIndex = Field(default_factory=lambda: CodebaseIndex(repo_client=RepoClient.create_instance()))
api_wrapper: CodebaseIndex = Field(
..., default_factory=lambda: CodebaseIndex(repo_client=RepoClient.create_instance())
)

def _run(self, query: str, intent: str, **kwargs) -> str:
"""
Expand Down Expand Up @@ -100,10 +102,12 @@ class BaseRepositoryTool(BaseTool):
Base class for repository interaction tools.
"""

source_repo_id: str = Field(description="The repository ID to search in.")
source_ref: str = Field(description="The branch or commit to search in.")
handle_validation_error: bool = True

source_repo_id: str = Field(..., description="The repository ID to search in.")
source_ref: str = Field(..., description="The branch or commit to search in.")

api_wrapper: RepoClient = Field(default_factory=RepoClient.create_instance)
api_wrapper: RepoClient = Field(..., default_factory=RepoClient.create_instance)

def _get_file_content(self, file_path: str, store: BaseStore) -> str | None:
"""
Expand Down Expand Up @@ -246,58 +250,38 @@ def _run(
if not (repo_file_content := self._get_file_content(file_path, store)):
return f"error: File {file_path} not found."

replaced_content = self._replace_content(original_snippet, replacement_snippet, repo_file_content)
replacer = SnippetReplacerAgent()
result = replacer.agent.invoke({
"original_snippet": original_snippet,
"replacement_snippet": replacement_snippet,
"content": repo_file_content,
})

if isinstance(result, str):
# It means, and error occurred during the replacement.
return result

if file_change:
file_change.content = replaced_content
file_change.content = result.content
file_change.commit_messages.append(commit_message)
else:
file_change = FileChange(
action=FileChangeAction.UPDATE,
file_path=file_path,
content=replaced_content,
content=result.content,
commit_messages=[commit_message],
)

store.put(("file_changes", self.source_repo_id, self.source_ref), file_path, {"data": file_change})

return "success: Snippet replaced."

def _replace_content(self, original_snippet: str, replacement_snippet: str, content: str) -> str:
"""
Replaces a snippet in a file with the provided replacement.
Args:
original_snippet: The original snippet to replace.
replacement_snippet: The replacement snippet.
content: The content of the file to replace the snippet in.
Returns:
The content of the file with the snippet replaced.
"""
replacer = SnippetReplacerAgent()

result = cast(
SnippetReplacerOutput,
replacer.agent.invoke({
"original_snippet": original_snippet,
"replacement_snippet": replacement_snippet,
"content": content,
}),
)

# Add a trailing snippet to the new snippet to match the original snippet if there isn't already one.
if not result.content.endswith("\n"):
result.content += "\n"

return result.content


class CreateNewRepositoryFileTool(BaseRepositoryTool):
name: str = CREATE_NEW_REPOSITORY_FILE_NAME
description: str = textwrap.dedent(
"""\
Create a new file within the repository with the provided content. Use this tool only to create files that do not already exist in the repository. Do not use this tool to overwrite or modify existing files. Ensure that the file path does not point to an existing file in the repository. Necessary directories should already exist in the repository; this tool does not create directories.
Create a new file within the repository with the provided file content. Use this tool only to create files that do not already exist in the repository. Do not use this tool to overwrite or modify existing files. Ensure that the file path does not point to an existing file in the repository. Necessary directories should already exist in the repository; this tool does not create directories.
""" # noqa: E501
)

Expand All @@ -306,7 +290,7 @@ class CreateNewRepositoryFileTool(BaseRepositoryTool):
def _run(
self,
file_path: str,
content: str,
file_content: str,
commit_message: str,
store: BaseStore,
run_manager: CallbackManagerForToolRun | None = None,
Expand Down Expand Up @@ -336,7 +320,7 @@ def _run(
"data": FileChange(
action=FileChangeAction.CREATE,
file_path=file_path,
content=content,
content=file_content,
commit_messages=[commit_message],
)
},
Expand Down
Loading

0 comments on commit 838f4c6

Please sign in to comment.