Skip to content

Commit

Permalink
Add show_definition tool to ACR; Add tool use tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
waleko committed Jun 21, 2024
1 parent b4d6ed1 commit 91e92e3
Show file tree
Hide file tree
Showing 24 changed files with 804 additions and 558 deletions.
13 changes: 9 additions & 4 deletions code_editing/agents/agent_codeeditor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from langchain_core.runnables import RunnableConfig, RunnableLambda

from code_editing.agents.graph_factory import GraphFactory
from code_editing.agents.run import RunOverviewManager
from code_editing.agents.utils.checkout_extractor import CheckoutExtractor
from code_editing.agents.utils.tool_factory import ToolFactory
from code_editing.code_editor import CEInput, CEOutput, CodeEditor
Expand Down Expand Up @@ -46,15 +47,19 @@ def generate_diff(self, req: CEInput, root_span) -> CEOutput:
# Context providers that help the agent to search for the code
context_providers = {k: instantiate(v, **generation_kwargs) for k, v in self.context_providers_cfg.items()}

run_overview_manager = RunOverviewManager(
**generation_kwargs,
context_providers=context_providers,
)

# Tools available to the agent
tools = self.tool_factory.build(
**generation_kwargs,
**context_providers,
run_overview_manager=run_overview_manager,
root_span=root_span, # W&B root span
)

# Build the graph runnable
app = self.graph_factory.tools(tools).build(**context_providers)
app = self.graph_factory.tools(tools).build(run_overview_manager=run_overview_manager)

# Diff collection
def to_ceoutput(state):
Expand All @@ -64,7 +69,7 @@ def to_ceoutput(state):
if viewed_lines is None:
logging.warning("No viewed lines found in the graph output")
viewed_lines = {}
return {"prediction": diff, "viewed_lines": viewed_lines}
return {"prediction": diff, "viewed_lines": viewed_lines, "run": run_overview_manager.get_run_summary()}

# Invoke the graph
return (app | RunnableLambda(to_ceoutput, name="Collect Diff")).invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from langchain_core.runnables import RunnableLambda

from code_editing.agents.collect_edit.collect_edit import CollectEditState
from code_editing.agents.context_providers.retrieval.retrieval_helper import RetrievalHelper
from code_editing.agents.graph_factory import GraphFactory
from code_editing.agents.run import RunOverviewManager
from code_editing.utils.tokenization_utils import TokenizationUtils


Expand All @@ -18,9 +20,9 @@ def __init__(self, k: Optional[int] = 10, total_context: Optional[int] = None, *
if (k is None) == (total_context is None):
raise ValueError("Either k or total_context should be provided")

def build(self, *args, retrieval_helper=None, **kwargs):
if retrieval_helper is None:
raise ValueError("Retrieval helper is not set")
def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs):
# noinspection PyTypeChecker
retrieval_helper: RetrievalHelper = run_overview_manager.get_ctx_provider("retrieval_helper")

def search(state: CollectEditState, config) -> CollectEditState:
if self.k is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from code_editing.agents.context_providers.acr_search.search_manage import SearchManager
from code_editing.agents.context_providers.acr_search.search_utils import to_relative_path
from code_editing.agents.graph_factory import GraphFactory
from code_editing.agents.run import RunOverviewManager

SYSTEM_PROMPT = """You are a software developer maintaining a large project.
You are working on an issue submitted to your project.
Expand All @@ -28,6 +29,7 @@
"\n- search_method(method_name: str): Search for a method in the entire codebase"
"\n- search_code(code_str: str): Search for a code snippet in the entire codebase"
"\n- search_code_in_file(code_str: str, file_path: str): Search for a code snippet in a given file file"
"\n- show_definition(symbol: str, line_no: int, file_path: str): Show the definition of a symbol in a given file"
"\n\nNote that you can use multiple search APIs in one round."
"\n\nNow analyze the issue and select necessary APIs to get more context of the project. Each API call must have concrete arguments as inputs."
)
Expand All @@ -46,6 +48,7 @@
search_class(class_name: str)
search_code_in_file(code_str: str, file_path: str)
search_code(code_str: str)
show_definition(symbol: str, line_no: int, file_path: str)
Provide your answer in JSON structure like this, you should ignore the argument placeholders in api calls.
For example, search_code(code_str="str") should be search_code("str")
Expand Down Expand Up @@ -146,16 +149,28 @@ def prepare_issue_prompt(problem_stmt: str) -> str:
return result


def remove_unwanted_lines(text: str, word: str) -> str:
    """Return ``text`` with every line that contains ``word`` dropped.

    The text is split on the newline character only, and the surviving
    lines are joined back with the same character, so their order and
    content are unchanged.
    """
    surviving = (row for row in text.split("\n") if word not in row)
    return "\n".join(surviving)


class ACRRetrieval(GraphFactory):
name = "acr_retrieval"

def __init__(self, *args, max_tries: int = 5, **kwargs):
def __init__(self, *args, max_tries: int = 5, use_show_definition: bool = False, **kwargs):
# super().__init__(*args, **kwargs)
super().__init__()
self.max_tries = max_tries
self.prompt = prompt
self.proxy_prompt = PROXY_PROMPT
if not use_show_definition:
# remove corresponding line
self.prompt = remove_unwanted_lines(self.prompt, "show_definition")
self.proxy_prompt = remove_unwanted_lines(self.proxy_prompt, "show_definition")
print(self.prompt)
print(self.proxy_prompt)

def proxy_run(self, text: str) -> Optional[dict]:
messages = [SystemMessage(PROXY_PROMPT)]
messages = [SystemMessage(self.proxy_prompt)]
messages.append(HumanMessage(text))
llm: BaseChatModel = self._llm
parser = JsonOutputParser()
Expand All @@ -174,15 +189,13 @@ def proxy_run(self, text: str) -> Optional[dict]:
logging.warning("Failed to get a valid response after max tries.")
return None

def build(self, *args, retrieval_helper=None, search_manager=None, **kwargs):
if retrieval_helper is None:
raise ValueError("Retrieval helper is not set")
if search_manager is None:
raise ValueError("Search manager is not set")
def build(self, *args, run_overview_manager: RunOverviewManager, **kwargs):
# noinspection PyTypeChecker
search_manager: SearchManager = run_overview_manager.get_ctx_provider("search_manager")

workflow = StateGraph(dict)
llm: BaseChatModel = self._llm
search_text = prompt
search_text = self.prompt

iters = 0

Expand Down Expand Up @@ -239,15 +252,22 @@ def search(state):
return state

def do_search(state):
nonlocal messages, llm
nonlocal messages, llm, run_overview_manager
api_calls = state["api_calls"]
if api_calls:
tool_output = ""
for api_call in api_calls:
try:
func_name, func_args = parse_function_invocation(api_call)
function = getattr(search_manager, func_name)
res, summary, _ = function(*func_args)
try:
run_overview_manager.add_tool_use(func_name)
res, summary, ok = function(*func_args)
if not ok:
run_overview_manager.add_tool_failure(func_name)
except Exception:
run_overview_manager.add_tool_error(func_name)
raise
tool_output += f"Result of {func_name}({', '.join(func_args)}):\n{res}\n"
except Exception as e:
tool_output += f"Error in {api_call}: {e}\n"
Expand Down Expand Up @@ -279,7 +299,7 @@ def collect_context(state):
segments = search_manager.viewed_lines
ctx = {}
for file_name, st, end in segments:
fname = to_relative_path(file_name, retrieval_helper.repo_path).replace("\\", "/")
fname = to_relative_path(file_name, run_overview_manager.repo_path).replace("\\", "/")
ctx.setdefault(fname, set()).update(range(st, end + 1))
return {"collected_context": ctx}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langgraph.graph import END, StateGraph

from code_editing.agents.collect_edit.context_collectors.llm_retrieval import LLMRetrieval
from code_editing.agents.run import RunOverviewManager
from code_editing.agents.utils import PromptWrapper


Expand All @@ -17,9 +18,8 @@ def __init__(self, review_prompt: PromptWrapper, *args, max_tries: int = 5, **kw
self.review_prompt = review_prompt
self.max_tries = max_tries

def build(self, *args, retrieval_helper=None, **kwargs):
if retrieval_helper is None:
raise ValueError("Retrieval helper is not set")
def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs):
retrieval_helper = run_overview_manager.get_ctx_provider("retrieval_helper")

agent_executor = self._agent_executor(
tools=self.get_llm_retrieval_tools(retrieval_helper),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from langgraph.graph import END, StateGraph

from code_editing.agents.collect_edit.context_collectors.llm_retrieval import LLMRetrieval
from code_editing.agents.run import RunOverviewManager
from code_editing.agents.tools.common import parse_file, read_file_full
from code_editing.utils.tokenization_utils import TokenizationUtils

Expand All @@ -15,9 +16,8 @@ def __init__(self, *args, total_context: int = 10000, max_searches=10, **kwargs)
self.max_searches = max_searches
self.tok_utils = TokenizationUtils("gpt-3.5-turbo-16k")

def build(self, *args, retrieval_helper=None, **kwargs):
if retrieval_helper is None:
raise ValueError("Retrieval helper is not set")
def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs):
retrieval_helper = run_overview_manager.get_ctx_provider("retrieval_helper")

agent_executor = self._agent_executor(
tools=self.get_llm_retrieval_tools(retrieval_helper),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from langchain_core.tools import ToolException, tool

from code_editing.agents.graph_factory import GraphFactory
from code_editing.agents.run import RunOverviewManager
from code_editing.agents.tools.common import parse_file, read_file_full
from code_editing.agents.utils import PromptWrapper

Expand All @@ -13,9 +14,8 @@ def __init__(self, search_prompt: PromptWrapper, do_review: bool = True, **kwarg
self.search_prompt = search_prompt
self.do_review = do_review

def build(self, *args, retrieval_helper=None, **kwargs):
if retrieval_helper is None:
raise ValueError("Retrieval helper is not set")
def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs):
retrieval_helper = run_overview_manager.get_ctx_provider("retrieval_helper")

return (
self.search_prompt.as_runnable(to_dict=True)
Expand Down
6 changes: 3 additions & 3 deletions code_editing/agents/collect_edit/context_collectors/my_acr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from code_editing.agents.context_providers.acr_search.search_manage import SearchManager
from code_editing.agents.graph_factory import GraphFactory
from code_editing.agents.run import RunOverviewManager
from code_editing.agents.tools.common import lines_format_document

SYSTEM_PROMPT = """You are a software developer maintaining a large project.
Expand Down Expand Up @@ -171,9 +172,8 @@ def proxy_run(self, text: str) -> Optional[dict]:
logging.warning("Failed to get a valid response after max tries.")
return None

def build(self, *args, retrieval_helper=None, **kwargs):
if retrieval_helper is None:
raise ValueError("Retrieval helper is not set")
def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs):
retrieval_helper = run_overview_manager.get_ctx_provider("retrieval_helper")

search_manager = RetrievalSearchManager(retrieval_helper)
workflow = StateGraph(dict)
Expand Down
5 changes: 3 additions & 2 deletions code_editing/agents/collect_edit/editors/simple_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from code_editing.agents.collect_edit.collect_edit import CollectEditState
from code_editing.agents.collect_edit.editors.util import MarkdownOutputParser, process_edit
from code_editing.agents.graph_factory import GraphFactory
from code_editing.agents.run import RunOverviewManager
from code_editing.agents.tools.common import parse_file, write_file_full
from code_editing.agents.utils import PromptWrapper

Expand All @@ -24,7 +25,7 @@ def __init__(self, edit_prompt: PromptWrapper):
super().__init__()
self.edit_prompt = edit_prompt

def build(self, *args, retrieval_helper=None, **kwargs):
def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs):
workflow = StateGraph(EditorState)

agent_executor = self._agent_executor(
Expand Down Expand Up @@ -65,7 +66,7 @@ def edit_lambda(_: str, snippet: str) -> str:
pass
raise OutputParserException("Failed to edit the code")

file = parse_file(file_name, retrieval_helper.repo_path)
file = parse_file(file_name, run_overview_manager.repo_path)
new_code = process_edit(file, lines, edit_lambda)

# Check linter
Expand Down
54 changes: 54 additions & 0 deletions code_editing/agents/context_providers/acr_search/search_manage.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Original: https://github.com/nus-apr/auto-code-rover/blob/main/app/search/search_manage.py
import os.path
from collections import defaultdict, namedtuple
from collections.abc import MutableMapping
from typing import List, Tuple

import jedi

from code_editing.agents.context_providers.acr_search import search_utils
from code_editing.agents.context_providers.acr_search.search_utils import SearchResult

Expand Down Expand Up @@ -39,6 +42,8 @@ def __init__(self, repo_path: str, show_lineno: bool = False, **kwargs):
self.is_tracking = False
self.show_lineno = show_lineno

self.jedi_project = jedi.Project(self.project_path)

def _build_index(self):
"""
With all source code of the project, build two indexes:
Expand Down Expand Up @@ -454,7 +459,56 @@ def search_code_in_file(self, code_str: str, file_name: str) -> tuple[str, str,
tool_output += f"- Search result {idx + 1}:\n```\n{res_str}\n```\n"
return tool_output, summary, True

def show_definition(self, symbol: str, line_number: int, file_path: str) -> tuple[str, str, bool]:
    """
    Show the definition of the symbol at the given line number in the file.

    Args:
        symbol: Text of the symbol to resolve; it must occur on the given line.
        line_number: 1-based line number inside ``file_path``. Numeric strings
            are accepted and converted with ``int``.
        file_path: Path of the file, joined onto ``self.project_path``.

    Returns:
        A ``(tool_output, summary, ok)`` tuple. ``ok`` is False when the symbol
        is not on the line or jedi infers nothing. When the definition has no
        source inside the project, only the inferred full name is reported.
    """
    line_number = int(line_number)
    full_file_path = os.path.join(self.project_path, file_path)
    # Fetch just the requested line so the symbol's column can be located.
    line = search_utils.get_code_snippets(full_file_path, line_number, line_number, show_lineno=False)
    if symbol not in line:
        tool_output = f"The symbol `{symbol}` does not appear in line {line_number} of file {file_path}: {line}."
        summary = tool_output
        return tool_output, summary, False

    # Column of the first occurrence of the symbol text on that line.
    col_offset = line.index(symbol)
    jedi_script = jedi.Script(path=full_file_path, project=self.jedi_project)
    definitions = jedi_script.infer(line_number, col_offset)
    if not definitions:
        tool_output = f"Could not find definition of symbol `{symbol}` in line {line_number} of file {file_path}."
        summary = tool_output
        return tool_output, summary, False

    # Only the first inferred definition is reported.
    definition = definitions[0]
    module_path = definition.module_path
    start_pos = definition.get_definition_start_position()
    end_pos = definition.get_definition_end_position()

    # jedi may return None for the module path or the positions when a
    # definition has no source location. Treat that like a definition
    # outside the project and report its name instead of crashing.
    has_source = module_path is not None and start_pos is not None and end_pos is not None
    if not has_source or not is_subfolder(self.project_path, module_path):
        tool_output = f"Type of symbol `{symbol}` is {definition.full_name}"
        summary = tool_output
        return tool_output, summary, True

    # module_path is a path object; pass a plain string on, as the
    # retrieve_code_snippet signature expects.
    full_path = str(module_path)
    rel_path = search_utils.to_relative_path(full_path, self.project_path)
    code = self.retrieve_code_snippet(full_path, start_pos[0], end_pos[0])

    tool_output = f"Found definition of symbol `{symbol}` in line {line_number} of file {file_path}:\n\n"
    tool_output += f"File: {rel_path}\n"
    tool_output += f"Code snippet:\n```\n{code}\n```"
    summary = tool_output
    return tool_output, summary, True

def retrieve_code_snippet(self, file_path: str, start_line: int, end_line: int) -> str:
    """Return the source lines ``start_line``..``end_line`` of ``file_path``.

    While tracking is enabled, the requested span is appended to
    ``self.viewed_lines`` before the file is read. Line numbers appear in
    the snippet only if ``self.show_lineno`` is set.
    """
    span = (file_path, start_line, end_line)
    if self.is_tracking:
        self.viewed_lines.append(span)
    snippet = search_utils.get_code_snippets(*span, show_lineno=self.show_lineno)
    return snippet


def is_subfolder(folder, potential_subfolder):
    """Return True if ``potential_subfolder`` is ``folder`` itself or lies inside it.

    Both paths are made absolute first. The comparison is done on whole path
    components, so ``/a/bc`` is not considered to be inside ``/a/b``.

    Args:
        folder: Path of the containing directory.
        potential_subfolder: Path to test. ``None`` (no path known) is
            treated as not being inside ``folder``.

    Returns:
        True when the common path of the two equals ``folder``, else False.
    """
    if potential_subfolder is None:
        return False

    # Get absolute, normalised paths
    folder = os.path.abspath(folder)
    potential_subfolder = os.path.abspath(potential_subfolder)

    try:
        # The candidate is inside the folder exactly when the folder is the
        # longest path prefix shared by both.
        return os.path.commonpath([folder]) == os.path.commonpath([folder, potential_subfolder])
    except ValueError:
        # commonpath raises when the paths cannot be compared, for example
        # when they are on different drives; then there is no containment.
        return False
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,13 @@ def get_code_snippets(file_full_path: str, start: int, end: int, show_lineno: bo
end (int): End line number. (1-based)
"""
with open(file_full_path) as f:
file_content = f.readlines()
file_content = f.read().split("\n")
snippet = ""
for i in range(start - 1, end):
for i in range(start - 1, min(end, len(file_content))):
if show_lineno:
snippet += f"{i+1} {file_content[i]}"
snippet += f"{i+1} {file_content[i]}\n"
else:
snippet += file_content[i]
snippet += file_content[i] + "\n"
return snippet


Expand Down
3 changes: 2 additions & 1 deletion code_editing/agents/graph_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from langchain_core.tools import BaseTool
from typing_extensions import Self

from code_editing.agents.run import RunOverviewManager
from code_editing.agents.tools.common import dummy


Expand All @@ -32,7 +33,7 @@ def __init__(self):
self._llm: Optional[BaseChatModel] = None

@abstractmethod
def build(self, *args, **kwargs) -> Runnable[AgentInput, Any]:
def build(self, run_overview_manager: RunOverviewManager, *args, **kwargs) -> Runnable[AgentInput, Any]:
pass

# Utility functions for the derived classes
Expand Down
Loading

0 comments on commit 91e92e3

Please sign in to comment.